diff options
Diffstat (limited to 'pkg')
982 files changed, 90975 insertions, 129744 deletions
diff --git a/pkg/abi/BUILD b/pkg/abi/BUILD deleted file mode 100644 index 839f822eb..000000000 --- a/pkg/abi/BUILD +++ /dev/null @@ -1,13 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "abi", - srcs = [ - "abi.go", - "abi_linux.go", - "flag.go", - ], - visibility = ["//:sandbox"], -) diff --git a/pkg/abi/abi_linux_state_autogen.go b/pkg/abi/abi_linux_state_autogen.go new file mode 100644 index 000000000..8a1390439 --- /dev/null +++ b/pkg/abi/abi_linux_state_autogen.go @@ -0,0 +1,6 @@ +// automatically generated by stateify. + +//go:build linux +// +build linux + +package abi diff --git a/pkg/abi/abi_state_autogen.go b/pkg/abi/abi_state_autogen.go new file mode 100644 index 000000000..d54002c3b --- /dev/null +++ b/pkg/abi/abi_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package abi diff --git a/pkg/abi/linux/BUILD b/pkg/abi/linux/BUILD deleted file mode 100644 index 3576396c1..000000000 --- a/pkg/abi/linux/BUILD +++ /dev/null @@ -1,95 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -# Package linux contains the constants and types needed to interface with a -# Linux kernel. It should be used instead of syscall or golang.org/x/sys/unix -# when the host OS may not be Linux. - -package(licenses = ["notice"]) - -go_library( - name = "linux", - srcs = [ - "aio.go", - "arch_amd64.go", - "audit.go", - "bpf.go", - "capability.go", - "clone.go", - "context.go", - "dev.go", - "elf.go", - "epoll.go", - "epoll_amd64.go", - "epoll_arm64.go", - "errqueue.go", - "eventfd.go", - "exec.go", - "fadvise.go", - "fcntl.go", - "file.go", - "file_amd64.go", - "file_arm64.go", - "fs.go", - "fuse.go", - "futex.go", - "inotify.go", - "ioctl.go", - "ioctl_tun.go", - "ip.go", - "ipc.go", - "limits.go", - "linux.go", - "membarrier.go", - "mm.go", - "msgqueue.go", - "netdevice.go", - "netfilter.go", - "netfilter_ipv6.go", - "netlink.go", - "netlink_route.go", - "poll.go", - "prctl.go", - "ptrace.go", - "ptrace_amd64.go", - "ptrace_arm64.go", - "rseq.go", - "rusage.go", - "sched.go", - "seccomp.go", - "sem.go", - "sem_amd64.go", - "sem_arm64.go", - "shm.go", - "signal.go", - "signalfd.go", - "socket.go", - "splice.go", - "tcp.go", - "time.go", - "timer.go", - "tty.go", - "uio.go", - "utsname.go", - "wait.go", - "xattr.go", - ], - marshal = True, - visibility = ["//visibility:public"], - deps = [ - "//pkg/abi", - "//pkg/bits", - "//pkg/context", - "//pkg/hostarch", - "//pkg/marshal", - "//pkg/marshal/primitive", - ], -) - -go_test( - name = "linux_test", - size = "small", - srcs = [ - "netfilter_test.go", - ], - library = ":linux", -) diff --git a/pkg/abi/linux/errno/BUILD b/pkg/abi/linux/errno/BUILD deleted file mode 100644 index d003342d5..000000000 --- a/pkg/abi/linux/errno/BUILD +++ /dev/null @@ -1,9 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "errno", - srcs = ["errno.go"], - visibility = ["//visibility:public"], -) diff --git a/pkg/abi/linux/errno/errno_state_autogen.go b/pkg/abi/linux/errno/errno_state_autogen.go new file mode 100644 index 000000000..4c4ae64a6 --- /dev/null +++ b/pkg/abi/linux/errno/errno_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package errno diff --git a/pkg/abi/linux/linux_abi_autogen_unsafe.go b/pkg/abi/linux/linux_abi_autogen_unsafe.go new file mode 100644 index 000000000..11cf350d2 --- /dev/null +++ b/pkg/abi/linux/linux_abi_autogen_unsafe.go @@ -0,0 +1,15222 @@ +// Automatically generated marshal implementation. See tools/go_marshal. + +package linux + +import ( + "gvisor.dev/gvisor/pkg/gohacks" + "gvisor.dev/gvisor/pkg/hostarch" + "gvisor.dev/gvisor/pkg/marshal" + "io" + "reflect" + "runtime" + "unsafe" +) + +// Marshallable types used by this file. +var _ marshal.Marshallable = (*BPFInstruction)(nil) +var _ marshal.Marshallable = (*CapUserData)(nil) +var _ marshal.Marshallable = (*CapUserHeader)(nil) +var _ marshal.Marshallable = (*ClockT)(nil) +var _ marshal.Marshallable = (*ControlMessageCredentials)(nil) +var _ marshal.Marshallable = (*ControlMessageHeader)(nil) +var _ marshal.Marshallable = (*ControlMessageIPPacketInfo)(nil) +var _ marshal.Marshallable = (*DigestMetadata)(nil) +var _ marshal.Marshallable = (*ElfHeader64)(nil) +var _ marshal.Marshallable = (*ElfProg64)(nil) +var _ marshal.Marshallable = (*ElfSection64)(nil) +var _ marshal.Marshallable = (*ErrorName)(nil) +var _ marshal.Marshallable = (*ExtensionName)(nil) +var _ marshal.Marshallable = (*FOwnerEx)(nil) +var _ marshal.Marshallable = (*FUSEAttr)(nil) +var _ marshal.Marshallable = (*FUSECreateMeta)(nil) +var _ marshal.Marshallable = (*FUSEDirentMeta)(nil) +var _ marshal.Marshallable = (*FUSEEntryOut)(nil) +var _ marshal.Marshallable = (*FUSEGetAttrIn)(nil) +var _ marshal.Marshallable = (*FUSEGetAttrOut)(nil) +var _ marshal.Marshallable = (*FUSEHeaderIn)(nil) +var _ marshal.Marshallable = (*FUSEHeaderOut)(nil) +var _ marshal.Marshallable = (*FUSEInitIn)(nil) +var _ marshal.Marshallable = (*FUSEInitOut)(nil) +var _ marshal.Marshallable = (*FUSEMkdirMeta)(nil) +var _ marshal.Marshallable = (*FUSEMknodMeta)(nil) +var _ marshal.Marshallable = (*FUSEOpID)(nil) +var _ marshal.Marshallable = (*FUSEOpcode)(nil) +var _ marshal.Marshallable = (*FUSEOpenIn)(nil) +var _ marshal.Marshallable = (*FUSEOpenOut)(nil) +var _ marshal.Marshallable = (*FUSEReadIn)(nil) +var _ marshal.Marshallable = (*FUSEReleaseIn)(nil) +var _ marshal.Marshallable = (*FUSESetAttrIn)(nil) +var _ marshal.Marshallable = (*FUSEWriteIn)(nil) +var _ marshal.Marshallable = (*FUSEWriteOut)(nil) +var _ marshal.Marshallable = (*Flock)(nil) +var _ marshal.Marshallable = (*IFConf)(nil) +var _ marshal.Marshallable = (*IFReq)(nil) +var _ marshal.Marshallable = (*IOCallback)(nil) +var _ marshal.Marshallable = (*IOEvent)(nil) +var _ marshal.Marshallable = (*IP6TEntry)(nil) +var _ marshal.Marshallable = (*IP6TIP)(nil) +var _ marshal.Marshallable = (*IP6TReplace)(nil) +var _ marshal.Marshallable = (*IPCPerm)(nil) +var _ marshal.Marshallable = (*IPTEntry)(nil) +var _ marshal.Marshallable = (*IPTGetEntries)(nil) +var _ marshal.Marshallable = (*IPTGetinfo)(nil) +var _ marshal.Marshallable = (*IPTIP)(nil) +var _ marshal.Marshallable = (*IPTOwnerInfo)(nil) +var _ marshal.Marshallable = (*IPTReplace)(nil) +var _ marshal.Marshallable = (*Inet6Addr)(nil) +var _ marshal.Marshallable = (*Inet6MulticastRequest)(nil) +var _ marshal.Marshallable = (*InetAddr)(nil) +var _ marshal.Marshallable = (*InetMulticastRequest)(nil) +var _ marshal.Marshallable = (*InetMulticastRequestWithNIC)(nil) +var _ marshal.Marshallable = (*InterfaceAddrMessage)(nil) +var _ marshal.Marshallable = (*InterfaceInfoMessage)(nil) +var _ marshal.Marshallable = (*ItimerVal)(nil) +var _ marshal.Marshallable = (*Itimerspec)(nil) +var _ marshal.Marshallable = (*KernelIP6TEntry)(nil) +var _ marshal.Marshallable = (*KernelIP6TGetEntries)(nil) +var _ marshal.Marshallable = (*KernelIPTEntry)(nil) +var _ marshal.Marshallable = (*KernelIPTGetEntries)(nil) +var _ marshal.Marshallable = (*Linger)(nil) +var _ marshal.Marshallable = (*MsgBuf)(nil) +var _ marshal.Marshallable = (*MsgInfo)(nil) +var _ marshal.Marshallable = (*MsqidDS)(nil) +var _ marshal.Marshallable = (*NFNATRange)(nil) +var _ marshal.Marshallable = (*NetlinkAttrHeader)(nil) +var _ marshal.Marshallable = (*NetlinkErrorMessage)(nil) +var _ marshal.Marshallable = (*NetlinkMessageHeader)(nil) +var _ marshal.Marshallable = (*NfNATIPV4MultiRangeCompat)(nil) +var _ marshal.Marshallable = (*NfNATIPV4Range)(nil) +var _ marshal.Marshallable = (*NumaPolicy)(nil) +var _ marshal.Marshallable = (*PollFD)(nil) +var _ marshal.Marshallable = (*RSeqCriticalSection)(nil) +var _ marshal.Marshallable = (*RobustListHead)(nil) +var _ marshal.Marshallable = (*RouteMessage)(nil) +var _ marshal.Marshallable = (*Rusage)(nil) +var _ marshal.Marshallable = (*SeccompData)(nil) +var _ marshal.Marshallable = (*SemInfo)(nil) +var _ marshal.Marshallable = (*Sembuf)(nil) +var _ marshal.Marshallable = (*ShmInfo)(nil) +var _ marshal.Marshallable = (*ShmParams)(nil) +var _ marshal.Marshallable = (*ShmidDS)(nil) +var _ marshal.Marshallable = (*SigAction)(nil) +var _ marshal.Marshallable = (*Sigevent)(nil) +var _ marshal.Marshallable = (*SignalInfo)(nil) +var _ marshal.Marshallable = (*SignalSet)(nil) +var _ marshal.Marshallable = (*SignalStack)(nil) +var _ marshal.Marshallable = (*SignalfdSiginfo)(nil) +var _ marshal.Marshallable = (*SockAddrInet)(nil) +var _ marshal.Marshallable = (*SockAddrInet6)(nil) +var _ marshal.Marshallable = (*SockAddrLink)(nil) +var _ marshal.Marshallable = (*SockAddrNetlink)(nil) +var _ marshal.Marshallable = (*SockAddrUnix)(nil) +var _ marshal.Marshallable = (*SockErrCMsgIPv4)(nil) +var _ marshal.Marshallable = (*SockErrCMsgIPv6)(nil) +var _ marshal.Marshallable = (*SockExtendedErr)(nil) +var _ marshal.Marshallable = (*Statfs)(nil) +var _ marshal.Marshallable = (*Statx)(nil) +var _ marshal.Marshallable = (*StatxTimestamp)(nil) +var _ marshal.Marshallable = (*Sysinfo)(nil) +var _ marshal.Marshallable = (*TCPInfo)(nil) +var _ marshal.Marshallable = (*TableName)(nil) +var _ marshal.Marshallable = (*Termios)(nil) +var _ marshal.Marshallable = (*TimeT)(nil) +var _ marshal.Marshallable = (*TimerID)(nil) +var _ marshal.Marshallable = (*Timespec)(nil) +var _ marshal.Marshallable = (*Timeval)(nil) +var _ marshal.Marshallable = (*Tms)(nil) +var _ marshal.Marshallable = (*Utime)(nil) +var _ marshal.Marshallable = (*UtsName)(nil) +var _ marshal.Marshallable = (*WindowSize)(nil) +var _ marshal.Marshallable = (*Winsize)(nil) +var _ marshal.Marshallable = (*XTCounters)(nil) +var _ marshal.Marshallable = (*XTEntryMatch)(nil) +var _ marshal.Marshallable = (*XTEntryTarget)(nil) +var _ marshal.Marshallable = (*XTErrorTarget)(nil) +var _ marshal.Marshallable = (*XTGetRevision)(nil) +var _ marshal.Marshallable = (*XTRedirectTarget)(nil) +var _ marshal.Marshallable = (*XTSNATTarget)(nil) +var _ marshal.Marshallable = (*XTStandardTarget)(nil) +var _ marshal.Marshallable = (*XTTCP)(nil) +var _ marshal.Marshallable = (*XTUDP)(nil) + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (i *IOCallback) SizeBytes() int { + return 64 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (i *IOCallback) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint64(dst[:8], uint64(i.Data)) + dst = dst[8:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(i.Key)) + dst = dst[4:] + // Padding: dst[:sizeof(uint32)] ~= uint32(0) + dst = dst[4:] + hostarch.ByteOrder.PutUint16(dst[:2], uint16(i.OpCode)) + dst = dst[2:] + hostarch.ByteOrder.PutUint16(dst[:2], uint16(i.ReqPrio)) + dst = dst[2:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(i.FD)) + dst = dst[4:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(i.Buf)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(i.Bytes)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(i.Offset)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(i.Reserved2)) + dst = dst[8:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(i.Flags)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(i.ResFD)) + dst = dst[4:] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (i *IOCallback) UnmarshalBytes(src []byte) { + i.Data = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + i.Key = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + // Padding: var _ uint32 ~= src[:sizeof(uint32)] + src = src[4:] + i.OpCode = uint16(hostarch.ByteOrder.Uint16(src[:2])) + src = src[2:] + i.ReqPrio = int16(hostarch.ByteOrder.Uint16(src[:2])) + src = src[2:] + i.FD = int32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + i.Buf = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + i.Bytes = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + i.Offset = int64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + i.Reserved2 = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + i.Flags = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + i.ResFD = int32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (i *IOCallback) Packed() bool { + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (i *IOCallback) MarshalUnsafe(dst []byte) { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(i), uintptr(i.SizeBytes())) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (i *IOCallback) UnmarshalUnsafe(src []byte) { + gohacks.Memmove(unsafe.Pointer(i), unsafe.Pointer(&src[0]), uintptr(i.SizeBytes())) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (i *IOCallback) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(i))) + hdr.Len = i.SizeBytes() + hdr.Cap = i.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that i + // must live until the use above. + runtime.KeepAlive(i) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (i *IOCallback) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return i.CopyOutN(cc, addr, i.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (i *IOCallback) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(i))) + hdr.Len = i.SizeBytes() + hdr.Cap = i.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that i + // must live until the use above. + runtime.KeepAlive(i) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (i *IOCallback) WriteTo(writer io.Writer) (int64, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(i))) + hdr.Len = i.SizeBytes() + hdr.Cap = i.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that i + // must live until the use above. + runtime.KeepAlive(i) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (i *IOEvent) SizeBytes() int { + return 32 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (i *IOEvent) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint64(dst[:8], uint64(i.Data)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(i.Obj)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(i.Result)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(i.Result2)) + dst = dst[8:] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (i *IOEvent) UnmarshalBytes(src []byte) { + i.Data = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + i.Obj = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + i.Result = int64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + i.Result2 = int64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (i *IOEvent) Packed() bool { + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (i *IOEvent) MarshalUnsafe(dst []byte) { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(i), uintptr(i.SizeBytes())) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (i *IOEvent) UnmarshalUnsafe(src []byte) { + gohacks.Memmove(unsafe.Pointer(i), unsafe.Pointer(&src[0]), uintptr(i.SizeBytes())) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (i *IOEvent) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(i))) + hdr.Len = i.SizeBytes() + hdr.Cap = i.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that i + // must live until the use above. + runtime.KeepAlive(i) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (i *IOEvent) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return i.CopyOutN(cc, addr, i.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (i *IOEvent) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(i))) + hdr.Len = i.SizeBytes() + hdr.Cap = i.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that i + // must live until the use above. + runtime.KeepAlive(i) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (i *IOEvent) WriteTo(writer io.Writer) (int64, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(i))) + hdr.Len = i.SizeBytes() + hdr.Cap = i.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that i + // must live until the use above. + runtime.KeepAlive(i) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (b *BPFInstruction) SizeBytes() int { + return 8 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (b *BPFInstruction) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint16(dst[:2], uint16(b.OpCode)) + dst = dst[2:] + dst[0] = byte(b.JumpIfTrue) + dst = dst[1:] + dst[0] = byte(b.JumpIfFalse) + dst = dst[1:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(b.K)) + dst = dst[4:] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (b *BPFInstruction) UnmarshalBytes(src []byte) { + b.OpCode = uint16(hostarch.ByteOrder.Uint16(src[:2])) + src = src[2:] + b.JumpIfTrue = uint8(src[0]) + src = src[1:] + b.JumpIfFalse = uint8(src[0]) + src = src[1:] + b.K = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (b *BPFInstruction) Packed() bool { + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (b *BPFInstruction) MarshalUnsafe(dst []byte) { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(b), uintptr(b.SizeBytes())) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (b *BPFInstruction) UnmarshalUnsafe(src []byte) { + gohacks.Memmove(unsafe.Pointer(b), unsafe.Pointer(&src[0]), uintptr(b.SizeBytes())) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (b *BPFInstruction) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(b))) + hdr.Len = b.SizeBytes() + hdr.Cap = b.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that b + // must live until the use above. + runtime.KeepAlive(b) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (b *BPFInstruction) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return b.CopyOutN(cc, addr, b.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (b *BPFInstruction) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(b))) + hdr.Len = b.SizeBytes() + hdr.Cap = b.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that b + // must live until the use above. + runtime.KeepAlive(b) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (b *BPFInstruction) WriteTo(writer io.Writer) (int64, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(b))) + hdr.Len = b.SizeBytes() + hdr.Cap = b.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that b + // must live until the use above. + runtime.KeepAlive(b) // escapes: replaced by intrinsic. + return int64(length), err +} + +// CopyBPFInstructionSliceIn copies in a slice of BPFInstruction objects from the task's memory. +func CopyBPFInstructionSliceIn(cc marshal.CopyContext, addr hostarch.Addr, dst []BPFInstruction) (int, error) { + count := len(dst) + if count == 0 { + return 0, nil + } + size := (*BPFInstruction)(nil).SizeBytes() + + ptr := unsafe.Pointer(&dst) + val := gohacks.Noescape(unsafe.Pointer((*reflect.SliceHeader)(ptr).Data)) + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(val) + hdr.Len = size * count + hdr.Cap = size * count + + length, err := cc.CopyInBytes(addr, buf) + // Since we bypassed the compiler's escape analysis, indicate that dst + // must live until the use above. + runtime.KeepAlive(dst) // escapes: replaced by intrinsic. + return length, err +} + +// CopyBPFInstructionSliceOut copies a slice of BPFInstruction objects to the task's memory. +func CopyBPFInstructionSliceOut(cc marshal.CopyContext, addr hostarch.Addr, src []BPFInstruction) (int, error) { + count := len(src) + if count == 0 { + return 0, nil + } + size := (*BPFInstruction)(nil).SizeBytes() + + ptr := unsafe.Pointer(&src) + val := gohacks.Noescape(unsafe.Pointer((*reflect.SliceHeader)(ptr).Data)) + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(val) + hdr.Len = size * count + hdr.Cap = size * count + + length, err := cc.CopyOutBytes(addr, buf) + // Since we bypassed the compiler's escape analysis, indicate that src + // must live until the use above. + runtime.KeepAlive(src) // escapes: replaced by intrinsic. + return length, err +} + +// MarshalUnsafeBPFInstructionSlice is like BPFInstruction.MarshalUnsafe, but for a []BPFInstruction. +func MarshalUnsafeBPFInstructionSlice(src []BPFInstruction, dst []byte) (int, error) { + count := len(src) + if count == 0 { + return 0, nil + } + size := (*BPFInstruction)(nil).SizeBytes() + + dst = dst[:size*count] + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(&src[0]), uintptr(len(dst))) + return size * count, nil +} + +// UnmarshalUnsafeBPFInstructionSlice is like BPFInstruction.UnmarshalUnsafe, but for a []BPFInstruction. +func UnmarshalUnsafeBPFInstructionSlice(dst []BPFInstruction, src []byte) (int, error) { + count := len(dst) + if count == 0 { + return 0, nil + } + size := (*BPFInstruction)(nil).SizeBytes() + + src = src[:(size*count)] + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(&src[0]), uintptr(len(src))) + return count*size, nil +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (c *CapUserData) SizeBytes() int { + return 12 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (c *CapUserData) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint32(dst[:4], uint32(c.Effective)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(c.Permitted)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(c.Inheritable)) + dst = dst[4:] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (c *CapUserData) UnmarshalBytes(src []byte) { + c.Effective = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + c.Permitted = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + c.Inheritable = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (c *CapUserData) Packed() bool { + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (c *CapUserData) MarshalUnsafe(dst []byte) { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(c), uintptr(c.SizeBytes())) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (c *CapUserData) UnmarshalUnsafe(src []byte) { + gohacks.Memmove(unsafe.Pointer(c), unsafe.Pointer(&src[0]), uintptr(c.SizeBytes())) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (c *CapUserData) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(c))) + hdr.Len = c.SizeBytes() + hdr.Cap = c.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that c + // must live until the use above. + runtime.KeepAlive(c) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (c *CapUserData) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return c.CopyOutN(cc, addr, c.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (c *CapUserData) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(c))) + hdr.Len = c.SizeBytes() + hdr.Cap = c.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that c + // must live until the use above. + runtime.KeepAlive(c) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (c *CapUserData) WriteTo(writer io.Writer) (int64, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(c))) + hdr.Len = c.SizeBytes() + hdr.Cap = c.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that c + // must live until the use above. + runtime.KeepAlive(c) // escapes: replaced by intrinsic. + return int64(length), err +} + +// CopyCapUserDataSliceIn copies in a slice of CapUserData objects from the task's memory. +func CopyCapUserDataSliceIn(cc marshal.CopyContext, addr hostarch.Addr, dst []CapUserData) (int, error) { + count := len(dst) + if count == 0 { + return 0, nil + } + size := (*CapUserData)(nil).SizeBytes() + + ptr := unsafe.Pointer(&dst) + val := gohacks.Noescape(unsafe.Pointer((*reflect.SliceHeader)(ptr).Data)) + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(val) + hdr.Len = size * count + hdr.Cap = size * count + + length, err := cc.CopyInBytes(addr, buf) + // Since we bypassed the compiler's escape analysis, indicate that dst + // must live until the use above. + runtime.KeepAlive(dst) // escapes: replaced by intrinsic. + return length, err +} + +// CopyCapUserDataSliceOut copies a slice of CapUserData objects to the task's memory. +func CopyCapUserDataSliceOut(cc marshal.CopyContext, addr hostarch.Addr, src []CapUserData) (int, error) { + count := len(src) + if count == 0 { + return 0, nil + } + size := (*CapUserData)(nil).SizeBytes() + + ptr := unsafe.Pointer(&src) + val := gohacks.Noescape(unsafe.Pointer((*reflect.SliceHeader)(ptr).Data)) + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(val) + hdr.Len = size * count + hdr.Cap = size * count + + length, err := cc.CopyOutBytes(addr, buf) + // Since we bypassed the compiler's escape analysis, indicate that src + // must live until the use above. + runtime.KeepAlive(src) // escapes: replaced by intrinsic. + return length, err +} + +// MarshalUnsafeCapUserDataSlice is like CapUserData.MarshalUnsafe, but for a []CapUserData. +func MarshalUnsafeCapUserDataSlice(src []CapUserData, dst []byte) (int, error) { + count := len(src) + if count == 0 { + return 0, nil + } + size := (*CapUserData)(nil).SizeBytes() + + dst = dst[:size*count] + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(&src[0]), uintptr(len(dst))) + return size * count, nil +} + +// UnmarshalUnsafeCapUserDataSlice is like CapUserData.UnmarshalUnsafe, but for a []CapUserData. +func UnmarshalUnsafeCapUserDataSlice(dst []CapUserData, src []byte) (int, error) { + count := len(dst) + if count == 0 { + return 0, nil + } + size := (*CapUserData)(nil).SizeBytes() + + src = src[:(size*count)] + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(&src[0]), uintptr(len(src))) + return count*size, nil +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (c *CapUserHeader) SizeBytes() int { + return 8 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (c *CapUserHeader) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint32(dst[:4], uint32(c.Version)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(c.Pid)) + dst = dst[4:] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (c *CapUserHeader) UnmarshalBytes(src []byte) { + c.Version = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + c.Pid = int32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (c *CapUserHeader) Packed() bool { + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (c *CapUserHeader) MarshalUnsafe(dst []byte) { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(c), uintptr(c.SizeBytes())) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (c *CapUserHeader) UnmarshalUnsafe(src []byte) { + gohacks.Memmove(unsafe.Pointer(c), unsafe.Pointer(&src[0]), uintptr(c.SizeBytes())) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (c *CapUserHeader) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(c))) + hdr.Len = c.SizeBytes() + hdr.Cap = c.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that c + // must live until the use above. + runtime.KeepAlive(c) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (c *CapUserHeader) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return c.CopyOutN(cc, addr, c.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (c *CapUserHeader) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(c))) + hdr.Len = c.SizeBytes() + hdr.Cap = c.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that c + // must live until the use above. + runtime.KeepAlive(c) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (c *CapUserHeader) WriteTo(writer io.Writer) (int64, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(c))) + hdr.Len = c.SizeBytes() + hdr.Cap = c.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that c + // must live until the use above. + runtime.KeepAlive(c) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (e *ElfHeader64) SizeBytes() int { + return 48 + + 1*16 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (e *ElfHeader64) MarshalBytes(dst []byte) { + for idx := 0; idx < 16; idx++ { + dst[0] = byte(e.Ident[idx]) + dst = dst[1:] + } + hostarch.ByteOrder.PutUint16(dst[:2], uint16(e.Type)) + dst = dst[2:] + hostarch.ByteOrder.PutUint16(dst[:2], uint16(e.Machine)) + dst = dst[2:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(e.Version)) + dst = dst[4:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(e.Entry)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(e.Phoff)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(e.Shoff)) + dst = dst[8:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(e.Flags)) + dst = dst[4:] + hostarch.ByteOrder.PutUint16(dst[:2], uint16(e.Ehsize)) + dst = dst[2:] + hostarch.ByteOrder.PutUint16(dst[:2], uint16(e.Phentsize)) + dst = dst[2:] + hostarch.ByteOrder.PutUint16(dst[:2], uint16(e.Phnum)) + dst = dst[2:] + hostarch.ByteOrder.PutUint16(dst[:2], uint16(e.Shentsize)) + dst = dst[2:] + hostarch.ByteOrder.PutUint16(dst[:2], uint16(e.Shnum)) + dst = dst[2:] + hostarch.ByteOrder.PutUint16(dst[:2], uint16(e.Shstrndx)) + dst = dst[2:] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (e *ElfHeader64) UnmarshalBytes(src []byte) { + for idx := 0; idx < 16; idx++ { + e.Ident[idx] = src[0] + src = src[1:] + } + e.Type = uint16(hostarch.ByteOrder.Uint16(src[:2])) + src = src[2:] + e.Machine = uint16(hostarch.ByteOrder.Uint16(src[:2])) + src = src[2:] + e.Version = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + e.Entry = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + e.Phoff = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + e.Shoff = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + e.Flags = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + e.Ehsize = uint16(hostarch.ByteOrder.Uint16(src[:2])) + src = src[2:] + e.Phentsize = uint16(hostarch.ByteOrder.Uint16(src[:2])) + src = src[2:] + e.Phnum = uint16(hostarch.ByteOrder.Uint16(src[:2])) + src = src[2:] + e.Shentsize = uint16(hostarch.ByteOrder.Uint16(src[:2])) + src = src[2:] + e.Shnum = uint16(hostarch.ByteOrder.Uint16(src[:2])) + src = src[2:] + e.Shstrndx = uint16(hostarch.ByteOrder.Uint16(src[:2])) + src = src[2:] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (e *ElfHeader64) Packed() bool { + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (e *ElfHeader64) MarshalUnsafe(dst []byte) { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(e), uintptr(e.SizeBytes())) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (e *ElfHeader64) UnmarshalUnsafe(src []byte) { + gohacks.Memmove(unsafe.Pointer(e), unsafe.Pointer(&src[0]), uintptr(e.SizeBytes())) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (e *ElfHeader64) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(e))) + hdr.Len = e.SizeBytes() + hdr.Cap = e.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that e + // must live until the use above. + runtime.KeepAlive(e) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (e *ElfHeader64) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return e.CopyOutN(cc, addr, e.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (e *ElfHeader64) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(e))) + hdr.Len = e.SizeBytes() + hdr.Cap = e.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that e + // must live until the use above. + runtime.KeepAlive(e) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (e *ElfHeader64) WriteTo(writer io.Writer) (int64, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(e))) + hdr.Len = e.SizeBytes() + hdr.Cap = e.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that e + // must live until the use above. + runtime.KeepAlive(e) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (e *ElfProg64) SizeBytes() int { + return 56 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (e *ElfProg64) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint32(dst[:4], uint32(e.Type)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(e.Flags)) + dst = dst[4:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(e.Off)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(e.Vaddr)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(e.Paddr)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(e.Filesz)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(e.Memsz)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(e.Align)) + dst = dst[8:] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (e *ElfProg64) UnmarshalBytes(src []byte) { + e.Type = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + e.Flags = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + e.Off = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + e.Vaddr = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + e.Paddr = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + e.Filesz = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + e.Memsz = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + e.Align = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (e *ElfProg64) Packed() bool { + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (e *ElfProg64) MarshalUnsafe(dst []byte) { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(e), uintptr(e.SizeBytes())) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (e *ElfProg64) UnmarshalUnsafe(src []byte) { + gohacks.Memmove(unsafe.Pointer(e), unsafe.Pointer(&src[0]), uintptr(e.SizeBytes())) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (e *ElfProg64) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(e))) + hdr.Len = e.SizeBytes() + hdr.Cap = e.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that e + // must live until the use above. + runtime.KeepAlive(e) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (e *ElfProg64) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return e.CopyOutN(cc, addr, e.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (e *ElfProg64) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(e))) + hdr.Len = e.SizeBytes() + hdr.Cap = e.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that e + // must live until the use above. + runtime.KeepAlive(e) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (e *ElfProg64) WriteTo(writer io.Writer) (int64, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(e))) + hdr.Len = e.SizeBytes() + hdr.Cap = e.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that e + // must live until the use above. + runtime.KeepAlive(e) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (e *ElfSection64) SizeBytes() int { + return 64 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (e *ElfSection64) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint32(dst[:4], uint32(e.Name)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(e.Type)) + dst = dst[4:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(e.Flags)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(e.Addr)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(e.Off)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(e.Size)) + dst = dst[8:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(e.Link)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(e.Info)) + dst = dst[4:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(e.Addralign)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(e.Entsize)) + dst = dst[8:] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (e *ElfSection64) UnmarshalBytes(src []byte) { + e.Name = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + e.Type = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + e.Flags = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + e.Addr = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + e.Off = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + e.Size = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + e.Link = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + e.Info = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + e.Addralign = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + e.Entsize = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (e *ElfSection64) Packed() bool { + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (e *ElfSection64) MarshalUnsafe(dst []byte) { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(e), uintptr(e.SizeBytes())) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (e *ElfSection64) UnmarshalUnsafe(src []byte) { + gohacks.Memmove(unsafe.Pointer(e), unsafe.Pointer(&src[0]), uintptr(e.SizeBytes())) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (e *ElfSection64) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(e))) + hdr.Len = e.SizeBytes() + hdr.Cap = e.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that e + // must live until the use above. + runtime.KeepAlive(e) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (e *ElfSection64) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return e.CopyOutN(cc, addr, e.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (e *ElfSection64) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(e))) + hdr.Len = e.SizeBytes() + hdr.Cap = e.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that e + // must live until the use above. + runtime.KeepAlive(e) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (e *ElfSection64) WriteTo(writer io.Writer) (int64, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(e))) + hdr.Len = e.SizeBytes() + hdr.Cap = e.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that e + // must live until the use above. + runtime.KeepAlive(e) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (s *SockErrCMsgIPv4) SizeBytes() int { + return 0 + + (*SockExtendedErr)(nil).SizeBytes() + + (*SockAddrInet)(nil).SizeBytes() +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (s *SockErrCMsgIPv4) MarshalBytes(dst []byte) { + s.SockExtendedErr.MarshalBytes(dst[:s.SockExtendedErr.SizeBytes()]) + dst = dst[s.SockExtendedErr.SizeBytes():] + s.Offender.MarshalBytes(dst[:s.Offender.SizeBytes()]) + dst = dst[s.Offender.SizeBytes():] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (s *SockErrCMsgIPv4) UnmarshalBytes(src []byte) { + s.SockExtendedErr.UnmarshalBytes(src[:s.SockExtendedErr.SizeBytes()]) + src = src[s.SockExtendedErr.SizeBytes():] + s.Offender.UnmarshalBytes(src[:s.Offender.SizeBytes()]) + src = src[s.Offender.SizeBytes():] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (s *SockErrCMsgIPv4) Packed() bool { + return s.Offender.Packed() && s.SockExtendedErr.Packed() +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (s *SockErrCMsgIPv4) MarshalUnsafe(dst []byte) { + if s.Offender.Packed() && s.SockExtendedErr.Packed() { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(s), uintptr(s.SizeBytes())) + } else { + // Type SockErrCMsgIPv4 doesn't have a packed layout in memory, fallback to MarshalBytes. + s.MarshalBytes(dst) + } +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (s *SockErrCMsgIPv4) UnmarshalUnsafe(src []byte) { + if s.Offender.Packed() && s.SockExtendedErr.Packed() { + gohacks.Memmove(unsafe.Pointer(s), unsafe.Pointer(&src[0]), uintptr(s.SizeBytes())) + } else { + // Type SockErrCMsgIPv4 doesn't have a packed layout in memory, fallback to UnmarshalBytes. + s.UnmarshalBytes(src) + } +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (s *SockErrCMsgIPv4) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + if !s.Offender.Packed() && s.SockExtendedErr.Packed() { + // Type SockErrCMsgIPv4 doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := cc.CopyScratchBuffer(s.SizeBytes()) // escapes: okay. + s.MarshalBytes(buf) // escapes: fallback. + return cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(s))) + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until the use above. + runtime.KeepAlive(s) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (s *SockErrCMsgIPv4) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return s.CopyOutN(cc, addr, s.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (s *SockErrCMsgIPv4) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + if !s.Offender.Packed() && s.SockExtendedErr.Packed() { + // Type SockErrCMsgIPv4 doesn't have a packed layout in memory, fall back to UnmarshalBytes. + buf := cc.CopyScratchBuffer(s.SizeBytes()) // escapes: okay. + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Unmarshal unconditionally. If we had a short copy-in, this results in a + // partially unmarshalled struct. + s.UnmarshalBytes(buf) // escapes: fallback. + return length, err + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(s))) + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until the use above. + runtime.KeepAlive(s) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (s *SockErrCMsgIPv4) WriteTo(writer io.Writer) (int64, error) { + if !s.Offender.Packed() && s.SockExtendedErr.Packed() { + // Type SockErrCMsgIPv4 doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := make([]byte, s.SizeBytes()) + s.MarshalBytes(buf) + length, err := writer.Write(buf) + return int64(length), err + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(s))) + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until the use above. + runtime.KeepAlive(s) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (s *SockErrCMsgIPv6) SizeBytes() int { + return 0 + + (*SockExtendedErr)(nil).SizeBytes() + + (*SockAddrInet6)(nil).SizeBytes() +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (s *SockErrCMsgIPv6) MarshalBytes(dst []byte) { + s.SockExtendedErr.MarshalBytes(dst[:s.SockExtendedErr.SizeBytes()]) + dst = dst[s.SockExtendedErr.SizeBytes():] + s.Offender.MarshalBytes(dst[:s.Offender.SizeBytes()]) + dst = dst[s.Offender.SizeBytes():] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (s *SockErrCMsgIPv6) UnmarshalBytes(src []byte) { + s.SockExtendedErr.UnmarshalBytes(src[:s.SockExtendedErr.SizeBytes()]) + src = src[s.SockExtendedErr.SizeBytes():] + s.Offender.UnmarshalBytes(src[:s.Offender.SizeBytes()]) + src = src[s.Offender.SizeBytes():] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (s *SockErrCMsgIPv6) Packed() bool { + return s.Offender.Packed() && s.SockExtendedErr.Packed() +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (s *SockErrCMsgIPv6) MarshalUnsafe(dst []byte) { + if s.Offender.Packed() && s.SockExtendedErr.Packed() { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(s), uintptr(s.SizeBytes())) + } else { + // Type SockErrCMsgIPv6 doesn't have a packed layout in memory, fallback to MarshalBytes. + s.MarshalBytes(dst) + } +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (s *SockErrCMsgIPv6) UnmarshalUnsafe(src []byte) { + if s.Offender.Packed() && s.SockExtendedErr.Packed() { + gohacks.Memmove(unsafe.Pointer(s), unsafe.Pointer(&src[0]), uintptr(s.SizeBytes())) + } else { + // Type SockErrCMsgIPv6 doesn't have a packed layout in memory, fallback to UnmarshalBytes. + s.UnmarshalBytes(src) + } +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (s *SockErrCMsgIPv6) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + if !s.Offender.Packed() && s.SockExtendedErr.Packed() { + // Type SockErrCMsgIPv6 doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := cc.CopyScratchBuffer(s.SizeBytes()) // escapes: okay. + s.MarshalBytes(buf) // escapes: fallback. + return cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(s))) + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until the use above. + runtime.KeepAlive(s) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (s *SockErrCMsgIPv6) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return s.CopyOutN(cc, addr, s.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (s *SockErrCMsgIPv6) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + if !s.Offender.Packed() && s.SockExtendedErr.Packed() { + // Type SockErrCMsgIPv6 doesn't have a packed layout in memory, fall back to UnmarshalBytes. + buf := cc.CopyScratchBuffer(s.SizeBytes()) // escapes: okay. + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Unmarshal unconditionally. If we had a short copy-in, this results in a + // partially unmarshalled struct. + s.UnmarshalBytes(buf) // escapes: fallback. + return length, err + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(s))) + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until the use above. + runtime.KeepAlive(s) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (s *SockErrCMsgIPv6) WriteTo(writer io.Writer) (int64, error) { + if !s.Offender.Packed() && s.SockExtendedErr.Packed() { + // Type SockErrCMsgIPv6 doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := make([]byte, s.SizeBytes()) + s.MarshalBytes(buf) + length, err := writer.Write(buf) + return int64(length), err + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(s))) + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until the use above. + runtime.KeepAlive(s) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (s *SockExtendedErr) SizeBytes() int { + return 16 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (s *SockExtendedErr) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint32(dst[:4], uint32(s.Errno)) + dst = dst[4:] + dst[0] = byte(s.Origin) + dst = dst[1:] + dst[0] = byte(s.Type) + dst = dst[1:] + dst[0] = byte(s.Code) + dst = dst[1:] + dst[0] = byte(s.Pad) + dst = dst[1:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(s.Info)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(s.Data)) + dst = dst[4:] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (s *SockExtendedErr) UnmarshalBytes(src []byte) { + s.Errno = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + s.Origin = uint8(src[0]) + src = src[1:] + s.Type = uint8(src[0]) + src = src[1:] + s.Code = uint8(src[0]) + src = src[1:] + s.Pad = uint8(src[0]) + src = src[1:] + s.Info = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + s.Data = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (s *SockExtendedErr) Packed() bool { + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (s *SockExtendedErr) MarshalUnsafe(dst []byte) { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(s), uintptr(s.SizeBytes())) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (s *SockExtendedErr) UnmarshalUnsafe(src []byte) { + gohacks.Memmove(unsafe.Pointer(s), unsafe.Pointer(&src[0]), uintptr(s.SizeBytes())) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (s *SockExtendedErr) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(s))) + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until the use above. + runtime.KeepAlive(s) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (s *SockExtendedErr) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return s.CopyOutN(cc, addr, s.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (s *SockExtendedErr) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(s))) + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until the use above. + runtime.KeepAlive(s) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (s *SockExtendedErr) WriteTo(writer io.Writer) (int64, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(s))) + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until the use above. + runtime.KeepAlive(s) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (f *FOwnerEx) SizeBytes() int { + return 8 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (f *FOwnerEx) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint32(dst[:4], uint32(f.Type)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(f.PID)) + dst = dst[4:] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (f *FOwnerEx) UnmarshalBytes(src []byte) { + f.Type = int32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + f.PID = int32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (f *FOwnerEx) Packed() bool { + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (f *FOwnerEx) MarshalUnsafe(dst []byte) { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(f), uintptr(f.SizeBytes())) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (f *FOwnerEx) UnmarshalUnsafe(src []byte) { + gohacks.Memmove(unsafe.Pointer(f), unsafe.Pointer(&src[0]), uintptr(f.SizeBytes())) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (f *FOwnerEx) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(f))) + hdr.Len = f.SizeBytes() + hdr.Cap = f.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that f + // must live until the use above. + runtime.KeepAlive(f) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (f *FOwnerEx) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return f.CopyOutN(cc, addr, f.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (f *FOwnerEx) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(f))) + hdr.Len = f.SizeBytes() + hdr.Cap = f.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that f + // must live until the use above. + runtime.KeepAlive(f) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (f *FOwnerEx) WriteTo(writer io.Writer) (int64, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(f))) + hdr.Len = f.SizeBytes() + hdr.Cap = f.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that f + // must live until the use above. + runtime.KeepAlive(f) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (f *Flock) SizeBytes() int { + return 24 + + 1*4 + + 1*4 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (f *Flock) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint16(dst[:2], uint16(f.Type)) + dst = dst[2:] + hostarch.ByteOrder.PutUint16(dst[:2], uint16(f.Whence)) + dst = dst[2:] + // Padding: dst[:sizeof(byte)*4] ~= [4]byte{0} + dst = dst[1*(4):] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(f.Start)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(f.Len)) + dst = dst[8:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(f.PID)) + dst = dst[4:] + // Padding: dst[:sizeof(byte)*4] ~= [4]byte{0} + dst = dst[1*(4):] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (f *Flock) UnmarshalBytes(src []byte) { + f.Type = int16(hostarch.ByteOrder.Uint16(src[:2])) + src = src[2:] + f.Whence = int16(hostarch.ByteOrder.Uint16(src[:2])) + src = src[2:] + // Padding: ~ copy([4]byte(f._), src[:sizeof(byte)*4]) + src = src[1*(4):] + f.Start = int64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + f.Len = int64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + f.PID = int32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + // Padding: ~ copy([4]byte(f._), src[:sizeof(byte)*4]) + src = src[1*(4):] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (f *Flock) Packed() bool { + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (f *Flock) MarshalUnsafe(dst []byte) { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(f), uintptr(f.SizeBytes())) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (f *Flock) UnmarshalUnsafe(src []byte) { + gohacks.Memmove(unsafe.Pointer(f), unsafe.Pointer(&src[0]), uintptr(f.SizeBytes())) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (f *Flock) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(f))) + hdr.Len = f.SizeBytes() + hdr.Cap = f.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that f + // must live until the use above. + runtime.KeepAlive(f) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (f *Flock) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return f.CopyOutN(cc, addr, f.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (f *Flock) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(f))) + hdr.Len = f.SizeBytes() + hdr.Cap = f.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that f + // must live until the use above. + runtime.KeepAlive(f) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (f *Flock) WriteTo(writer io.Writer) (int64, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(f))) + hdr.Len = f.SizeBytes() + hdr.Cap = f.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that f + // must live until the use above. + runtime.KeepAlive(f) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (s *Statx) SizeBytes() int { + return 80 + + (*StatxTimestamp)(nil).SizeBytes() + + (*StatxTimestamp)(nil).SizeBytes() + + (*StatxTimestamp)(nil).SizeBytes() + + (*StatxTimestamp)(nil).SizeBytes() +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (s *Statx) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint32(dst[:4], uint32(s.Mask)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(s.Blksize)) + dst = dst[4:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(s.Attributes)) + dst = dst[8:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(s.Nlink)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(s.UID)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(s.GID)) + dst = dst[4:] + hostarch.ByteOrder.PutUint16(dst[:2], uint16(s.Mode)) + dst = dst[2:] + // Padding: dst[:sizeof(uint16)] ~= uint16(0) + dst = dst[2:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(s.Ino)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(s.Size)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(s.Blocks)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(s.AttributesMask)) + dst = dst[8:] + s.Atime.MarshalBytes(dst[:s.Atime.SizeBytes()]) + dst = dst[s.Atime.SizeBytes():] + s.Btime.MarshalBytes(dst[:s.Btime.SizeBytes()]) + dst = dst[s.Btime.SizeBytes():] + s.Ctime.MarshalBytes(dst[:s.Ctime.SizeBytes()]) + dst = dst[s.Ctime.SizeBytes():] + s.Mtime.MarshalBytes(dst[:s.Mtime.SizeBytes()]) + dst = dst[s.Mtime.SizeBytes():] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(s.RdevMajor)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(s.RdevMinor)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(s.DevMajor)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(s.DevMinor)) + dst = dst[4:] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (s *Statx) UnmarshalBytes(src []byte) { + s.Mask = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + s.Blksize = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + s.Attributes = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.Nlink = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + s.UID = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + s.GID = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + s.Mode = uint16(hostarch.ByteOrder.Uint16(src[:2])) + src = src[2:] + // Padding: var _ uint16 ~= src[:sizeof(uint16)] + src = src[2:] + s.Ino = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.Size = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.Blocks = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.AttributesMask = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.Atime.UnmarshalBytes(src[:s.Atime.SizeBytes()]) + src = src[s.Atime.SizeBytes():] + s.Btime.UnmarshalBytes(src[:s.Btime.SizeBytes()]) + src = src[s.Btime.SizeBytes():] + s.Ctime.UnmarshalBytes(src[:s.Ctime.SizeBytes()]) + src = src[s.Ctime.SizeBytes():] + s.Mtime.UnmarshalBytes(src[:s.Mtime.SizeBytes()]) + src = src[s.Mtime.SizeBytes():] + s.RdevMajor = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + s.RdevMinor = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + s.DevMajor = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + s.DevMinor = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (s *Statx) Packed() bool { + return s.Atime.Packed() && s.Btime.Packed() && s.Ctime.Packed() && s.Mtime.Packed() +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (s *Statx) MarshalUnsafe(dst []byte) { + if s.Atime.Packed() && s.Btime.Packed() && s.Ctime.Packed() && s.Mtime.Packed() { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(s), uintptr(s.SizeBytes())) + } else { + // Type Statx doesn't have a packed layout in memory, fallback to MarshalBytes. + s.MarshalBytes(dst) + } +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (s *Statx) UnmarshalUnsafe(src []byte) { + if s.Atime.Packed() && s.Btime.Packed() && s.Ctime.Packed() && s.Mtime.Packed() { + gohacks.Memmove(unsafe.Pointer(s), unsafe.Pointer(&src[0]), uintptr(s.SizeBytes())) + } else { + // Type Statx doesn't have a packed layout in memory, fallback to UnmarshalBytes. + s.UnmarshalBytes(src) + } +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (s *Statx) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + if !s.Atime.Packed() && s.Btime.Packed() && s.Ctime.Packed() && s.Mtime.Packed() { + // Type Statx doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := cc.CopyScratchBuffer(s.SizeBytes()) // escapes: okay. + s.MarshalBytes(buf) // escapes: fallback. + return cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(s))) + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until the use above. + runtime.KeepAlive(s) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (s *Statx) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return s.CopyOutN(cc, addr, s.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (s *Statx) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + if !s.Atime.Packed() && s.Btime.Packed() && s.Ctime.Packed() && s.Mtime.Packed() { + // Type Statx doesn't have a packed layout in memory, fall back to UnmarshalBytes. + buf := cc.CopyScratchBuffer(s.SizeBytes()) // escapes: okay. + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Unmarshal unconditionally. If we had a short copy-in, this results in a + // partially unmarshalled struct. + s.UnmarshalBytes(buf) // escapes: fallback. + return length, err + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(s))) + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until the use above. + runtime.KeepAlive(s) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (s *Statx) WriteTo(writer io.Writer) (int64, error) { + if !s.Atime.Packed() && s.Btime.Packed() && s.Ctime.Packed() && s.Mtime.Packed() { + // Type Statx doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := make([]byte, s.SizeBytes()) + s.MarshalBytes(buf) + length, err := writer.Write(buf) + return int64(length), err + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(s))) + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until the use above. + runtime.KeepAlive(s) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (s *Statfs) SizeBytes() int { + return 80 + + 4*2 + + 8*4 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (s *Statfs) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint64(dst[:8], uint64(s.Type)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(s.BlockSize)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(s.Blocks)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(s.BlocksFree)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(s.BlocksAvailable)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(s.Files)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(s.FilesFree)) + dst = dst[8:] + for idx := 0; idx < 2; idx++ { + hostarch.ByteOrder.PutUint32(dst[:4], uint32(s.FSID[idx])) + dst = dst[4:] + } + hostarch.ByteOrder.PutUint64(dst[:8], uint64(s.NameLength)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(s.FragmentSize)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(s.Flags)) + dst = dst[8:] + for idx := 0; idx < 4; idx++ { + hostarch.ByteOrder.PutUint64(dst[:8], uint64(s.Spare[idx])) + dst = dst[8:] + } +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (s *Statfs) UnmarshalBytes(src []byte) { + s.Type = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.BlockSize = int64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.Blocks = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.BlocksFree = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.BlocksAvailable = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.Files = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.FilesFree = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + for idx := 0; idx < 2; idx++ { + s.FSID[idx] = int32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + } + s.NameLength = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.FragmentSize = int64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.Flags = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + for idx := 0; idx < 4; idx++ { + s.Spare[idx] = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + } +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (s *Statfs) Packed() bool { + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (s *Statfs) MarshalUnsafe(dst []byte) { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(s), uintptr(s.SizeBytes())) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (s *Statfs) UnmarshalUnsafe(src []byte) { + gohacks.Memmove(unsafe.Pointer(s), unsafe.Pointer(&src[0]), uintptr(s.SizeBytes())) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (s *Statfs) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(s))) + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until the use above. + runtime.KeepAlive(s) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (s *Statfs) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return s.CopyOutN(cc, addr, s.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (s *Statfs) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(s))) + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until the use above. + runtime.KeepAlive(s) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (s *Statfs) WriteTo(writer io.Writer) (int64, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(s))) + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until the use above. + runtime.KeepAlive(s) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (f *FUSEAttr) SizeBytes() int { + return 88 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (f *FUSEAttr) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint64(dst[:8], uint64(f.Ino)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(f.Size)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(f.Blocks)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(f.Atime)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(f.Mtime)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(f.Ctime)) + dst = dst[8:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(f.AtimeNsec)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(f.MtimeNsec)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(f.CtimeNsec)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(f.Mode)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(f.Nlink)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(f.UID)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(f.GID)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(f.Rdev)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(f.BlkSize)) + dst = dst[4:] + // Padding: dst[:sizeof(uint32)] ~= uint32(0) + dst = dst[4:] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (f *FUSEAttr) UnmarshalBytes(src []byte) { + f.Ino = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + f.Size = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + f.Blocks = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + f.Atime = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + f.Mtime = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + f.Ctime = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + f.AtimeNsec = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + f.MtimeNsec = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + f.CtimeNsec = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + f.Mode = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + f.Nlink = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + f.UID = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + f.GID = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + f.Rdev = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + f.BlkSize = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + // Padding: var _ uint32 ~= src[:sizeof(uint32)] + src = src[4:] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (f *FUSEAttr) Packed() bool { + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (f *FUSEAttr) MarshalUnsafe(dst []byte) { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(f), uintptr(f.SizeBytes())) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (f *FUSEAttr) UnmarshalUnsafe(src []byte) { + gohacks.Memmove(unsafe.Pointer(f), unsafe.Pointer(&src[0]), uintptr(f.SizeBytes())) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (f *FUSEAttr) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(f))) + hdr.Len = f.SizeBytes() + hdr.Cap = f.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that f + // must live until the use above. + runtime.KeepAlive(f) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (f *FUSEAttr) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return f.CopyOutN(cc, addr, f.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (f *FUSEAttr) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(f))) + hdr.Len = f.SizeBytes() + hdr.Cap = f.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that f + // must live until the use above. + runtime.KeepAlive(f) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (f *FUSEAttr) WriteTo(writer io.Writer) (int64, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(f))) + hdr.Len = f.SizeBytes() + hdr.Cap = f.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that f + // must live until the use above. + runtime.KeepAlive(f) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (f *FUSECreateMeta) SizeBytes() int { + return 16 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (f *FUSECreateMeta) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint32(dst[:4], uint32(f.Flags)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(f.Mode)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(f.Umask)) + dst = dst[4:] + // Padding: dst[:sizeof(uint32)] ~= uint32(0) + dst = dst[4:] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (f *FUSECreateMeta) UnmarshalBytes(src []byte) { + f.Flags = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + f.Mode = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + f.Umask = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + // Padding: var _ uint32 ~= src[:sizeof(uint32)] + src = src[4:] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (f *FUSECreateMeta) Packed() bool { + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (f *FUSECreateMeta) MarshalUnsafe(dst []byte) { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(f), uintptr(f.SizeBytes())) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (f *FUSECreateMeta) UnmarshalUnsafe(src []byte) { + gohacks.Memmove(unsafe.Pointer(f), unsafe.Pointer(&src[0]), uintptr(f.SizeBytes())) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (f *FUSECreateMeta) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(f))) + hdr.Len = f.SizeBytes() + hdr.Cap = f.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that f + // must live until the use above. + runtime.KeepAlive(f) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (f *FUSECreateMeta) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return f.CopyOutN(cc, addr, f.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (f *FUSECreateMeta) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(f))) + hdr.Len = f.SizeBytes() + hdr.Cap = f.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that f + // must live until the use above. + runtime.KeepAlive(f) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (f *FUSECreateMeta) WriteTo(writer io.Writer) (int64, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(f))) + hdr.Len = f.SizeBytes() + hdr.Cap = f.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that f + // must live until the use above. + runtime.KeepAlive(f) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (f *FUSEDirentMeta) SizeBytes() int { + return 24 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (f *FUSEDirentMeta) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint64(dst[:8], uint64(f.Ino)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(f.Off)) + dst = dst[8:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(f.NameLen)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(f.Type)) + dst = dst[4:] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (f *FUSEDirentMeta) UnmarshalBytes(src []byte) { + f.Ino = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + f.Off = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + f.NameLen = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + f.Type = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (f *FUSEDirentMeta) Packed() bool { + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (f *FUSEDirentMeta) MarshalUnsafe(dst []byte) { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(f), uintptr(f.SizeBytes())) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (f *FUSEDirentMeta) UnmarshalUnsafe(src []byte) { + gohacks.Memmove(unsafe.Pointer(f), unsafe.Pointer(&src[0]), uintptr(f.SizeBytes())) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (f *FUSEDirentMeta) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(f))) + hdr.Len = f.SizeBytes() + hdr.Cap = f.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that f + // must live until the use above. + runtime.KeepAlive(f) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (f *FUSEDirentMeta) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return f.CopyOutN(cc, addr, f.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (f *FUSEDirentMeta) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(f))) + hdr.Len = f.SizeBytes() + hdr.Cap = f.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that f + // must live until the use above. + runtime.KeepAlive(f) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (f *FUSEDirentMeta) WriteTo(writer io.Writer) (int64, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(f))) + hdr.Len = f.SizeBytes() + hdr.Cap = f.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that f + // must live until the use above. + runtime.KeepAlive(f) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (f *FUSEEntryOut) SizeBytes() int { + return 40 + + (*FUSEAttr)(nil).SizeBytes() +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (f *FUSEEntryOut) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint64(dst[:8], uint64(f.NodeID)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(f.Generation)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(f.EntryValid)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(f.AttrValid)) + dst = dst[8:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(f.EntryValidNSec)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(f.AttrValidNSec)) + dst = dst[4:] + f.Attr.MarshalBytes(dst[:f.Attr.SizeBytes()]) + dst = dst[f.Attr.SizeBytes():] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (f *FUSEEntryOut) UnmarshalBytes(src []byte) { + f.NodeID = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + f.Generation = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + f.EntryValid = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + f.AttrValid = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + f.EntryValidNSec = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + f.AttrValidNSec = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + f.Attr.UnmarshalBytes(src[:f.Attr.SizeBytes()]) + src = src[f.Attr.SizeBytes():] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (f *FUSEEntryOut) Packed() bool { + return f.Attr.Packed() +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (f *FUSEEntryOut) MarshalUnsafe(dst []byte) { + if f.Attr.Packed() { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(f), uintptr(f.SizeBytes())) + } else { + // Type FUSEEntryOut doesn't have a packed layout in memory, fallback to MarshalBytes. + f.MarshalBytes(dst) + } +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (f *FUSEEntryOut) UnmarshalUnsafe(src []byte) { + if f.Attr.Packed() { + gohacks.Memmove(unsafe.Pointer(f), unsafe.Pointer(&src[0]), uintptr(f.SizeBytes())) + } else { + // Type FUSEEntryOut doesn't have a packed layout in memory, fallback to UnmarshalBytes. + f.UnmarshalBytes(src) + } +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (f *FUSEEntryOut) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + if !f.Attr.Packed() { + // Type FUSEEntryOut doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := cc.CopyScratchBuffer(f.SizeBytes()) // escapes: okay. + f.MarshalBytes(buf) // escapes: fallback. + return cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(f))) + hdr.Len = f.SizeBytes() + hdr.Cap = f.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that f + // must live until the use above. + runtime.KeepAlive(f) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (f *FUSEEntryOut) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return f.CopyOutN(cc, addr, f.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (f *FUSEEntryOut) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + if !f.Attr.Packed() { + // Type FUSEEntryOut doesn't have a packed layout in memory, fall back to UnmarshalBytes. + buf := cc.CopyScratchBuffer(f.SizeBytes()) // escapes: okay. + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Unmarshal unconditionally. If we had a short copy-in, this results in a + // partially unmarshalled struct. + f.UnmarshalBytes(buf) // escapes: fallback. + return length, err + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(f))) + hdr.Len = f.SizeBytes() + hdr.Cap = f.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that f + // must live until the use above. + runtime.KeepAlive(f) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (f *FUSEEntryOut) WriteTo(writer io.Writer) (int64, error) { + if !f.Attr.Packed() { + // Type FUSEEntryOut doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := make([]byte, f.SizeBytes()) + f.MarshalBytes(buf) + length, err := writer.Write(buf) + return int64(length), err + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(f))) + hdr.Len = f.SizeBytes() + hdr.Cap = f.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that f + // must live until the use above. + runtime.KeepAlive(f) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (f *FUSEGetAttrIn) SizeBytes() int { + return 16 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (f *FUSEGetAttrIn) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint32(dst[:4], uint32(f.GetAttrFlags)) + dst = dst[4:] + // Padding: dst[:sizeof(uint32)] ~= uint32(0) + dst = dst[4:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(f.Fh)) + dst = dst[8:] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (f *FUSEGetAttrIn) UnmarshalBytes(src []byte) { + f.GetAttrFlags = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + // Padding: var _ uint32 ~= src[:sizeof(uint32)] + src = src[4:] + f.Fh = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (f *FUSEGetAttrIn) Packed() bool { + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (f *FUSEGetAttrIn) MarshalUnsafe(dst []byte) { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(f), uintptr(f.SizeBytes())) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (f *FUSEGetAttrIn) UnmarshalUnsafe(src []byte) { + gohacks.Memmove(unsafe.Pointer(f), unsafe.Pointer(&src[0]), uintptr(f.SizeBytes())) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (f *FUSEGetAttrIn) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(f))) + hdr.Len = f.SizeBytes() + hdr.Cap = f.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that f + // must live until the use above. + runtime.KeepAlive(f) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (f *FUSEGetAttrIn) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return f.CopyOutN(cc, addr, f.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (f *FUSEGetAttrIn) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(f))) + hdr.Len = f.SizeBytes() + hdr.Cap = f.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that f + // must live until the use above. + runtime.KeepAlive(f) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (f *FUSEGetAttrIn) WriteTo(writer io.Writer) (int64, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(f))) + hdr.Len = f.SizeBytes() + hdr.Cap = f.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that f + // must live until the use above. + runtime.KeepAlive(f) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (f *FUSEGetAttrOut) SizeBytes() int { + return 16 + + (*FUSEAttr)(nil).SizeBytes() +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (f *FUSEGetAttrOut) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint64(dst[:8], uint64(f.AttrValid)) + dst = dst[8:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(f.AttrValidNsec)) + dst = dst[4:] + // Padding: dst[:sizeof(uint32)] ~= uint32(0) + dst = dst[4:] + f.Attr.MarshalBytes(dst[:f.Attr.SizeBytes()]) + dst = dst[f.Attr.SizeBytes():] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (f *FUSEGetAttrOut) UnmarshalBytes(src []byte) { + f.AttrValid = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + f.AttrValidNsec = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + // Padding: var _ uint32 ~= src[:sizeof(uint32)] + src = src[4:] + f.Attr.UnmarshalBytes(src[:f.Attr.SizeBytes()]) + src = src[f.Attr.SizeBytes():] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (f *FUSEGetAttrOut) Packed() bool { + return f.Attr.Packed() +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (f *FUSEGetAttrOut) MarshalUnsafe(dst []byte) { + if f.Attr.Packed() { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(f), uintptr(f.SizeBytes())) + } else { + // Type FUSEGetAttrOut doesn't have a packed layout in memory, fallback to MarshalBytes. + f.MarshalBytes(dst) + } +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (f *FUSEGetAttrOut) UnmarshalUnsafe(src []byte) { + if f.Attr.Packed() { + gohacks.Memmove(unsafe.Pointer(f), unsafe.Pointer(&src[0]), uintptr(f.SizeBytes())) + } else { + // Type FUSEGetAttrOut doesn't have a packed layout in memory, fallback to UnmarshalBytes. + f.UnmarshalBytes(src) + } +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (f *FUSEGetAttrOut) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + if !f.Attr.Packed() { + // Type FUSEGetAttrOut doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := cc.CopyScratchBuffer(f.SizeBytes()) // escapes: okay. + f.MarshalBytes(buf) // escapes: fallback. + return cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(f))) + hdr.Len = f.SizeBytes() + hdr.Cap = f.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that f + // must live until the use above. + runtime.KeepAlive(f) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (f *FUSEGetAttrOut) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return f.CopyOutN(cc, addr, f.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (f *FUSEGetAttrOut) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + if !f.Attr.Packed() { + // Type FUSEGetAttrOut doesn't have a packed layout in memory, fall back to UnmarshalBytes. + buf := cc.CopyScratchBuffer(f.SizeBytes()) // escapes: okay. + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Unmarshal unconditionally. If we had a short copy-in, this results in a + // partially unmarshalled struct. + f.UnmarshalBytes(buf) // escapes: fallback. + return length, err + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(f))) + hdr.Len = f.SizeBytes() + hdr.Cap = f.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that f + // must live until the use above. + runtime.KeepAlive(f) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (f *FUSEGetAttrOut) WriteTo(writer io.Writer) (int64, error) { + if !f.Attr.Packed() { + // Type FUSEGetAttrOut doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := make([]byte, f.SizeBytes()) + f.MarshalBytes(buf) + length, err := writer.Write(buf) + return int64(length), err + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(f))) + hdr.Len = f.SizeBytes() + hdr.Cap = f.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that f + // must live until the use above. + runtime.KeepAlive(f) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (f *FUSEHeaderIn) SizeBytes() int { + return 28 + + (*FUSEOpcode)(nil).SizeBytes() + + (*FUSEOpID)(nil).SizeBytes() +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (f *FUSEHeaderIn) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint32(dst[:4], uint32(f.Len)) + dst = dst[4:] + f.Opcode.MarshalBytes(dst[:f.Opcode.SizeBytes()]) + dst = dst[f.Opcode.SizeBytes():] + f.Unique.MarshalBytes(dst[:f.Unique.SizeBytes()]) + dst = dst[f.Unique.SizeBytes():] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(f.NodeID)) + dst = dst[8:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(f.UID)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(f.GID)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(f.PID)) + dst = dst[4:] + // Padding: dst[:sizeof(uint32)] ~= uint32(0) + dst = dst[4:] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (f *FUSEHeaderIn) UnmarshalBytes(src []byte) { + f.Len = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + f.Opcode.UnmarshalBytes(src[:f.Opcode.SizeBytes()]) + src = src[f.Opcode.SizeBytes():] + f.Unique.UnmarshalBytes(src[:f.Unique.SizeBytes()]) + src = src[f.Unique.SizeBytes():] + f.NodeID = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + f.UID = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + f.GID = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + f.PID = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + // Padding: var _ uint32 ~= src[:sizeof(uint32)] + src = src[4:] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (f *FUSEHeaderIn) Packed() bool { + return f.Opcode.Packed() && f.Unique.Packed() +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (f *FUSEHeaderIn) MarshalUnsafe(dst []byte) { + if f.Opcode.Packed() && f.Unique.Packed() { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(f), uintptr(f.SizeBytes())) + } else { + // Type FUSEHeaderIn doesn't have a packed layout in memory, fallback to MarshalBytes. + f.MarshalBytes(dst) + } +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (f *FUSEHeaderIn) UnmarshalUnsafe(src []byte) { + if f.Opcode.Packed() && f.Unique.Packed() { + gohacks.Memmove(unsafe.Pointer(f), unsafe.Pointer(&src[0]), uintptr(f.SizeBytes())) + } else { + // Type FUSEHeaderIn doesn't have a packed layout in memory, fallback to UnmarshalBytes. + f.UnmarshalBytes(src) + } +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (f *FUSEHeaderIn) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + if !f.Opcode.Packed() && f.Unique.Packed() { + // Type FUSEHeaderIn doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := cc.CopyScratchBuffer(f.SizeBytes()) // escapes: okay. + f.MarshalBytes(buf) // escapes: fallback. + return cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(f))) + hdr.Len = f.SizeBytes() + hdr.Cap = f.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that f + // must live until the use above. + runtime.KeepAlive(f) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (f *FUSEHeaderIn) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return f.CopyOutN(cc, addr, f.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (f *FUSEHeaderIn) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + if !f.Opcode.Packed() && f.Unique.Packed() { + // Type FUSEHeaderIn doesn't have a packed layout in memory, fall back to UnmarshalBytes. + buf := cc.CopyScratchBuffer(f.SizeBytes()) // escapes: okay. + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Unmarshal unconditionally. If we had a short copy-in, this results in a + // partially unmarshalled struct. + f.UnmarshalBytes(buf) // escapes: fallback. + return length, err + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(f))) + hdr.Len = f.SizeBytes() + hdr.Cap = f.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that f + // must live until the use above. + runtime.KeepAlive(f) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (f *FUSEHeaderIn) WriteTo(writer io.Writer) (int64, error) { + if !f.Opcode.Packed() && f.Unique.Packed() { + // Type FUSEHeaderIn doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := make([]byte, f.SizeBytes()) + f.MarshalBytes(buf) + length, err := writer.Write(buf) + return int64(length), err + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(f))) + hdr.Len = f.SizeBytes() + hdr.Cap = f.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that f + // must live until the use above. + runtime.KeepAlive(f) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (f *FUSEHeaderOut) SizeBytes() int { + return 8 + + (*FUSEOpID)(nil).SizeBytes() +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (f *FUSEHeaderOut) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint32(dst[:4], uint32(f.Len)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(f.Error)) + dst = dst[4:] + f.Unique.MarshalBytes(dst[:f.Unique.SizeBytes()]) + dst = dst[f.Unique.SizeBytes():] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (f *FUSEHeaderOut) UnmarshalBytes(src []byte) { + f.Len = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + f.Error = int32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + f.Unique.UnmarshalBytes(src[:f.Unique.SizeBytes()]) + src = src[f.Unique.SizeBytes():] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (f *FUSEHeaderOut) Packed() bool { + return f.Unique.Packed() +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (f *FUSEHeaderOut) MarshalUnsafe(dst []byte) { + if f.Unique.Packed() { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(f), uintptr(f.SizeBytes())) + } else { + // Type FUSEHeaderOut doesn't have a packed layout in memory, fallback to MarshalBytes. + f.MarshalBytes(dst) + } +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (f *FUSEHeaderOut) UnmarshalUnsafe(src []byte) { + if f.Unique.Packed() { + gohacks.Memmove(unsafe.Pointer(f), unsafe.Pointer(&src[0]), uintptr(f.SizeBytes())) + } else { + // Type FUSEHeaderOut doesn't have a packed layout in memory, fallback to UnmarshalBytes. + f.UnmarshalBytes(src) + } +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (f *FUSEHeaderOut) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + if !f.Unique.Packed() { + // Type FUSEHeaderOut doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := cc.CopyScratchBuffer(f.SizeBytes()) // escapes: okay. + f.MarshalBytes(buf) // escapes: fallback. + return cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(f))) + hdr.Len = f.SizeBytes() + hdr.Cap = f.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that f + // must live until the use above. + runtime.KeepAlive(f) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (f *FUSEHeaderOut) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return f.CopyOutN(cc, addr, f.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (f *FUSEHeaderOut) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + if !f.Unique.Packed() { + // Type FUSEHeaderOut doesn't have a packed layout in memory, fall back to UnmarshalBytes. + buf := cc.CopyScratchBuffer(f.SizeBytes()) // escapes: okay. + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Unmarshal unconditionally. If we had a short copy-in, this results in a + // partially unmarshalled struct. + f.UnmarshalBytes(buf) // escapes: fallback. + return length, err + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(f))) + hdr.Len = f.SizeBytes() + hdr.Cap = f.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that f + // must live until the use above. + runtime.KeepAlive(f) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (f *FUSEHeaderOut) WriteTo(writer io.Writer) (int64, error) { + if !f.Unique.Packed() { + // Type FUSEHeaderOut doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := make([]byte, f.SizeBytes()) + f.MarshalBytes(buf) + length, err := writer.Write(buf) + return int64(length), err + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(f))) + hdr.Len = f.SizeBytes() + hdr.Cap = f.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that f + // must live until the use above. + runtime.KeepAlive(f) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (f *FUSEInitIn) SizeBytes() int { + return 16 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (f *FUSEInitIn) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint32(dst[:4], uint32(f.Major)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(f.Minor)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(f.MaxReadahead)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(f.Flags)) + dst = dst[4:] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (f *FUSEInitIn) UnmarshalBytes(src []byte) { + f.Major = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + f.Minor = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + f.MaxReadahead = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + f.Flags = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (f *FUSEInitIn) Packed() bool { + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (f *FUSEInitIn) MarshalUnsafe(dst []byte) { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(f), uintptr(f.SizeBytes())) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (f *FUSEInitIn) UnmarshalUnsafe(src []byte) { + gohacks.Memmove(unsafe.Pointer(f), unsafe.Pointer(&src[0]), uintptr(f.SizeBytes())) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (f *FUSEInitIn) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(f))) + hdr.Len = f.SizeBytes() + hdr.Cap = f.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that f + // must live until the use above. + runtime.KeepAlive(f) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (f *FUSEInitIn) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return f.CopyOutN(cc, addr, f.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (f *FUSEInitIn) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(f))) + hdr.Len = f.SizeBytes() + hdr.Cap = f.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that f + // must live until the use above. + runtime.KeepAlive(f) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (f *FUSEInitIn) WriteTo(writer io.Writer) (int64, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(f))) + hdr.Len = f.SizeBytes() + hdr.Cap = f.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that f + // must live until the use above. + runtime.KeepAlive(f) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (f *FUSEInitOut) SizeBytes() int { + return 32 + + 4*8 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (f *FUSEInitOut) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint32(dst[:4], uint32(f.Major)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(f.Minor)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(f.MaxReadahead)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(f.Flags)) + dst = dst[4:] + hostarch.ByteOrder.PutUint16(dst[:2], uint16(f.MaxBackground)) + dst = dst[2:] + hostarch.ByteOrder.PutUint16(dst[:2], uint16(f.CongestionThreshold)) + dst = dst[2:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(f.MaxWrite)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(f.TimeGran)) + dst = dst[4:] + hostarch.ByteOrder.PutUint16(dst[:2], uint16(f.MaxPages)) + dst = dst[2:] + // Padding: dst[:sizeof(uint16)] ~= uint16(0) + dst = dst[2:] + // Padding: dst[:sizeof(uint32)*8] ~= [8]uint32{0} + dst = dst[4*(8):] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (f *FUSEInitOut) UnmarshalBytes(src []byte) { + f.Major = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + f.Minor = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + f.MaxReadahead = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + f.Flags = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + f.MaxBackground = uint16(hostarch.ByteOrder.Uint16(src[:2])) + src = src[2:] + f.CongestionThreshold = uint16(hostarch.ByteOrder.Uint16(src[:2])) + src = src[2:] + f.MaxWrite = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + f.TimeGran = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + f.MaxPages = uint16(hostarch.ByteOrder.Uint16(src[:2])) + src = src[2:] + // Padding: var _ uint16 ~= src[:sizeof(uint16)] + src = src[2:] + // Padding: ~ copy([8]uint32(f._), src[:sizeof(uint32)*8]) + src = src[4*(8):] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (f *FUSEInitOut) Packed() bool { + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (f *FUSEInitOut) MarshalUnsafe(dst []byte) { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(f), uintptr(f.SizeBytes())) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (f *FUSEInitOut) UnmarshalUnsafe(src []byte) { + gohacks.Memmove(unsafe.Pointer(f), unsafe.Pointer(&src[0]), uintptr(f.SizeBytes())) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (f *FUSEInitOut) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(f))) + hdr.Len = f.SizeBytes() + hdr.Cap = f.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that f + // must live until the use above. + runtime.KeepAlive(f) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (f *FUSEInitOut) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return f.CopyOutN(cc, addr, f.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (f *FUSEInitOut) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(f))) + hdr.Len = f.SizeBytes() + hdr.Cap = f.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that f + // must live until the use above. + runtime.KeepAlive(f) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (f *FUSEInitOut) WriteTo(writer io.Writer) (int64, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(f))) + hdr.Len = f.SizeBytes() + hdr.Cap = f.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that f + // must live until the use above. + runtime.KeepAlive(f) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (f *FUSEMkdirMeta) SizeBytes() int { + return 8 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (f *FUSEMkdirMeta) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint32(dst[:4], uint32(f.Mode)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(f.Umask)) + dst = dst[4:] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (f *FUSEMkdirMeta) UnmarshalBytes(src []byte) { + f.Mode = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + f.Umask = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (f *FUSEMkdirMeta) Packed() bool { + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (f *FUSEMkdirMeta) MarshalUnsafe(dst []byte) { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(f), uintptr(f.SizeBytes())) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (f *FUSEMkdirMeta) UnmarshalUnsafe(src []byte) { + gohacks.Memmove(unsafe.Pointer(f), unsafe.Pointer(&src[0]), uintptr(f.SizeBytes())) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (f *FUSEMkdirMeta) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(f))) + hdr.Len = f.SizeBytes() + hdr.Cap = f.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that f + // must live until the use above. + runtime.KeepAlive(f) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (f *FUSEMkdirMeta) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return f.CopyOutN(cc, addr, f.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (f *FUSEMkdirMeta) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(f))) + hdr.Len = f.SizeBytes() + hdr.Cap = f.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that f + // must live until the use above. + runtime.KeepAlive(f) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (f *FUSEMkdirMeta) WriteTo(writer io.Writer) (int64, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(f))) + hdr.Len = f.SizeBytes() + hdr.Cap = f.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that f + // must live until the use above. + runtime.KeepAlive(f) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (f *FUSEMknodMeta) SizeBytes() int { + return 16 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (f *FUSEMknodMeta) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint32(dst[:4], uint32(f.Mode)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(f.Rdev)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(f.Umask)) + dst = dst[4:] + // Padding: dst[:sizeof(uint32)] ~= uint32(0) + dst = dst[4:] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (f *FUSEMknodMeta) UnmarshalBytes(src []byte) { + f.Mode = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + f.Rdev = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + f.Umask = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + // Padding: var _ uint32 ~= src[:sizeof(uint32)] + src = src[4:] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (f *FUSEMknodMeta) Packed() bool { + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (f *FUSEMknodMeta) MarshalUnsafe(dst []byte) { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(f), uintptr(f.SizeBytes())) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (f *FUSEMknodMeta) UnmarshalUnsafe(src []byte) { + gohacks.Memmove(unsafe.Pointer(f), unsafe.Pointer(&src[0]), uintptr(f.SizeBytes())) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (f *FUSEMknodMeta) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(f))) + hdr.Len = f.SizeBytes() + hdr.Cap = f.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that f + // must live until the use above. + runtime.KeepAlive(f) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (f *FUSEMknodMeta) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return f.CopyOutN(cc, addr, f.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (f *FUSEMknodMeta) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(f))) + hdr.Len = f.SizeBytes() + hdr.Cap = f.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that f + // must live until the use above. + runtime.KeepAlive(f) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (f *FUSEMknodMeta) WriteTo(writer io.Writer) (int64, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(f))) + hdr.Len = f.SizeBytes() + hdr.Cap = f.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that f + // must live until the use above. + runtime.KeepAlive(f) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +//go:nosplit +func (f *FUSEOpID) SizeBytes() int { + return 8 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (f *FUSEOpID) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint64(dst[:8], uint64(*f)) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (f *FUSEOpID) UnmarshalBytes(src []byte) { + *f = FUSEOpID(uint64(hostarch.ByteOrder.Uint64(src[:8]))) +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (f *FUSEOpID) Packed() bool { + // Scalar newtypes are always packed. + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (f *FUSEOpID) MarshalUnsafe(dst []byte) { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(f), uintptr(f.SizeBytes())) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (f *FUSEOpID) UnmarshalUnsafe(src []byte) { + gohacks.Memmove(unsafe.Pointer(f), unsafe.Pointer(&src[0]), uintptr(f.SizeBytes())) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (f *FUSEOpID) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(f))) + hdr.Len = f.SizeBytes() + hdr.Cap = f.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that f + // must live until the use above. + runtime.KeepAlive(f) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (f *FUSEOpID) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return f.CopyOutN(cc, addr, f.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (f *FUSEOpID) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(f))) + hdr.Len = f.SizeBytes() + hdr.Cap = f.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that f + // must live until the use above. + runtime.KeepAlive(f) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (f *FUSEOpID) WriteTo(w io.Writer) (int64, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(f))) + hdr.Len = f.SizeBytes() + hdr.Cap = f.SizeBytes() + + length, err := w.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that f + // must live until the use above. + runtime.KeepAlive(f) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +//go:nosplit +func (f *FUSEOpcode) SizeBytes() int { + return 4 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (f *FUSEOpcode) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint32(dst[:4], uint32(*f)) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (f *FUSEOpcode) UnmarshalBytes(src []byte) { + *f = FUSEOpcode(uint32(hostarch.ByteOrder.Uint32(src[:4]))) +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (f *FUSEOpcode) Packed() bool { + // Scalar newtypes are always packed. + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (f *FUSEOpcode) MarshalUnsafe(dst []byte) { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(f), uintptr(f.SizeBytes())) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (f *FUSEOpcode) UnmarshalUnsafe(src []byte) { + gohacks.Memmove(unsafe.Pointer(f), unsafe.Pointer(&src[0]), uintptr(f.SizeBytes())) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (f *FUSEOpcode) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(f))) + hdr.Len = f.SizeBytes() + hdr.Cap = f.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that f + // must live until the use above. + runtime.KeepAlive(f) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (f *FUSEOpcode) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return f.CopyOutN(cc, addr, f.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (f *FUSEOpcode) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(f))) + hdr.Len = f.SizeBytes() + hdr.Cap = f.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that f + // must live until the use above. + runtime.KeepAlive(f) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (f *FUSEOpcode) WriteTo(w io.Writer) (int64, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(f))) + hdr.Len = f.SizeBytes() + hdr.Cap = f.SizeBytes() + + length, err := w.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that f + // must live until the use above. + runtime.KeepAlive(f) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (f *FUSEOpenIn) SizeBytes() int { + return 8 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (f *FUSEOpenIn) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint32(dst[:4], uint32(f.Flags)) + dst = dst[4:] + // Padding: dst[:sizeof(uint32)] ~= uint32(0) + dst = dst[4:] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (f *FUSEOpenIn) UnmarshalBytes(src []byte) { + f.Flags = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + // Padding: var _ uint32 ~= src[:sizeof(uint32)] + src = src[4:] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (f *FUSEOpenIn) Packed() bool { + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (f *FUSEOpenIn) MarshalUnsafe(dst []byte) { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(f), uintptr(f.SizeBytes())) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (f *FUSEOpenIn) UnmarshalUnsafe(src []byte) { + gohacks.Memmove(unsafe.Pointer(f), unsafe.Pointer(&src[0]), uintptr(f.SizeBytes())) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (f *FUSEOpenIn) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(f))) + hdr.Len = f.SizeBytes() + hdr.Cap = f.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that f + // must live until the use above. + runtime.KeepAlive(f) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (f *FUSEOpenIn) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return f.CopyOutN(cc, addr, f.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (f *FUSEOpenIn) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(f))) + hdr.Len = f.SizeBytes() + hdr.Cap = f.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that f + // must live until the use above. + runtime.KeepAlive(f) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (f *FUSEOpenIn) WriteTo(writer io.Writer) (int64, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(f))) + hdr.Len = f.SizeBytes() + hdr.Cap = f.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that f + // must live until the use above. + runtime.KeepAlive(f) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (f *FUSEOpenOut) SizeBytes() int { + return 16 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (f *FUSEOpenOut) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint64(dst[:8], uint64(f.Fh)) + dst = dst[8:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(f.OpenFlag)) + dst = dst[4:] + // Padding: dst[:sizeof(uint32)] ~= uint32(0) + dst = dst[4:] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (f *FUSEOpenOut) UnmarshalBytes(src []byte) { + f.Fh = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + f.OpenFlag = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + // Padding: var _ uint32 ~= src[:sizeof(uint32)] + src = src[4:] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (f *FUSEOpenOut) Packed() bool { + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (f *FUSEOpenOut) MarshalUnsafe(dst []byte) { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(f), uintptr(f.SizeBytes())) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (f *FUSEOpenOut) UnmarshalUnsafe(src []byte) { + gohacks.Memmove(unsafe.Pointer(f), unsafe.Pointer(&src[0]), uintptr(f.SizeBytes())) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (f *FUSEOpenOut) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(f))) + hdr.Len = f.SizeBytes() + hdr.Cap = f.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that f + // must live until the use above. + runtime.KeepAlive(f) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (f *FUSEOpenOut) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return f.CopyOutN(cc, addr, f.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (f *FUSEOpenOut) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(f))) + hdr.Len = f.SizeBytes() + hdr.Cap = f.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that f + // must live until the use above. + runtime.KeepAlive(f) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (f *FUSEOpenOut) WriteTo(writer io.Writer) (int64, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(f))) + hdr.Len = f.SizeBytes() + hdr.Cap = f.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that f + // must live until the use above. + runtime.KeepAlive(f) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (f *FUSEReadIn) SizeBytes() int { + return 40 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (f *FUSEReadIn) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint64(dst[:8], uint64(f.Fh)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(f.Offset)) + dst = dst[8:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(f.Size)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(f.ReadFlags)) + dst = dst[4:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(f.LockOwner)) + dst = dst[8:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(f.Flags)) + dst = dst[4:] + // Padding: dst[:sizeof(uint32)] ~= uint32(0) + dst = dst[4:] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (f *FUSEReadIn) UnmarshalBytes(src []byte) { + f.Fh = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + f.Offset = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + f.Size = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + f.ReadFlags = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + f.LockOwner = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + f.Flags = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + // Padding: var _ uint32 ~= src[:sizeof(uint32)] + src = src[4:] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (f *FUSEReadIn) Packed() bool { + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (f *FUSEReadIn) MarshalUnsafe(dst []byte) { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(f), uintptr(f.SizeBytes())) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (f *FUSEReadIn) UnmarshalUnsafe(src []byte) { + gohacks.Memmove(unsafe.Pointer(f), unsafe.Pointer(&src[0]), uintptr(f.SizeBytes())) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (f *FUSEReadIn) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(f))) + hdr.Len = f.SizeBytes() + hdr.Cap = f.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that f + // must live until the use above. + runtime.KeepAlive(f) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (f *FUSEReadIn) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return f.CopyOutN(cc, addr, f.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (f *FUSEReadIn) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(f))) + hdr.Len = f.SizeBytes() + hdr.Cap = f.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that f + // must live until the use above. + runtime.KeepAlive(f) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (f *FUSEReadIn) WriteTo(writer io.Writer) (int64, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(f))) + hdr.Len = f.SizeBytes() + hdr.Cap = f.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that f + // must live until the use above. + runtime.KeepAlive(f) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (f *FUSEReleaseIn) SizeBytes() int { + return 24 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (f *FUSEReleaseIn) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint64(dst[:8], uint64(f.Fh)) + dst = dst[8:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(f.Flags)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(f.ReleaseFlags)) + dst = dst[4:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(f.LockOwner)) + dst = dst[8:] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (f *FUSEReleaseIn) UnmarshalBytes(src []byte) { + f.Fh = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + f.Flags = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + f.ReleaseFlags = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + f.LockOwner = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (f *FUSEReleaseIn) Packed() bool { + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (f *FUSEReleaseIn) MarshalUnsafe(dst []byte) { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(f), uintptr(f.SizeBytes())) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (f *FUSEReleaseIn) UnmarshalUnsafe(src []byte) { + gohacks.Memmove(unsafe.Pointer(f), unsafe.Pointer(&src[0]), uintptr(f.SizeBytes())) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (f *FUSEReleaseIn) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(f))) + hdr.Len = f.SizeBytes() + hdr.Cap = f.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that f + // must live until the use above. + runtime.KeepAlive(f) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (f *FUSEReleaseIn) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return f.CopyOutN(cc, addr, f.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (f *FUSEReleaseIn) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(f))) + hdr.Len = f.SizeBytes() + hdr.Cap = f.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that f + // must live until the use above. + runtime.KeepAlive(f) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (f *FUSEReleaseIn) WriteTo(writer io.Writer) (int64, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(f))) + hdr.Len = f.SizeBytes() + hdr.Cap = f.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that f + // must live until the use above. + runtime.KeepAlive(f) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (f *FUSESetAttrIn) SizeBytes() int { + return 88 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (f *FUSESetAttrIn) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint32(dst[:4], uint32(f.Valid)) + dst = dst[4:] + // Padding: dst[:sizeof(uint32)] ~= uint32(0) + dst = dst[4:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(f.Fh)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(f.Size)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(f.LockOwner)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(f.Atime)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(f.Mtime)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(f.Ctime)) + dst = dst[8:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(f.AtimeNsec)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(f.MtimeNsec)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(f.CtimeNsec)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(f.Mode)) + dst = dst[4:] + // Padding: dst[:sizeof(uint32)] ~= uint32(0) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(f.UID)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(f.GID)) + dst = dst[4:] + // Padding: dst[:sizeof(uint32)] ~= uint32(0) + dst = dst[4:] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (f *FUSESetAttrIn) UnmarshalBytes(src []byte) { + f.Valid = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + // Padding: var _ uint32 ~= src[:sizeof(uint32)] + src = src[4:] + f.Fh = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + f.Size = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + f.LockOwner = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + f.Atime = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + f.Mtime = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + f.Ctime = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + f.AtimeNsec = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + f.MtimeNsec = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + f.CtimeNsec = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + f.Mode = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + // Padding: var _ uint32 ~= src[:sizeof(uint32)] + src = src[4:] + f.UID = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + f.GID = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + // Padding: var _ uint32 ~= src[:sizeof(uint32)] + src = src[4:] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (f *FUSESetAttrIn) Packed() bool { + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (f *FUSESetAttrIn) MarshalUnsafe(dst []byte) { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(f), uintptr(f.SizeBytes())) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (f *FUSESetAttrIn) UnmarshalUnsafe(src []byte) { + gohacks.Memmove(unsafe.Pointer(f), unsafe.Pointer(&src[0]), uintptr(f.SizeBytes())) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (f *FUSESetAttrIn) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(f))) + hdr.Len = f.SizeBytes() + hdr.Cap = f.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that f + // must live until the use above. + runtime.KeepAlive(f) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (f *FUSESetAttrIn) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return f.CopyOutN(cc, addr, f.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (f *FUSESetAttrIn) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(f))) + hdr.Len = f.SizeBytes() + hdr.Cap = f.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that f + // must live until the use above. + runtime.KeepAlive(f) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (f *FUSESetAttrIn) WriteTo(writer io.Writer) (int64, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(f))) + hdr.Len = f.SizeBytes() + hdr.Cap = f.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that f + // must live until the use above. + runtime.KeepAlive(f) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (f *FUSEWriteIn) SizeBytes() int { + return 40 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (f *FUSEWriteIn) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint64(dst[:8], uint64(f.Fh)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(f.Offset)) + dst = dst[8:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(f.Size)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(f.WriteFlags)) + dst = dst[4:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(f.LockOwner)) + dst = dst[8:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(f.Flags)) + dst = dst[4:] + // Padding: dst[:sizeof(uint32)] ~= uint32(0) + dst = dst[4:] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (f *FUSEWriteIn) UnmarshalBytes(src []byte) { + f.Fh = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + f.Offset = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + f.Size = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + f.WriteFlags = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + f.LockOwner = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + f.Flags = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + // Padding: var _ uint32 ~= src[:sizeof(uint32)] + src = src[4:] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (f *FUSEWriteIn) Packed() bool { + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (f *FUSEWriteIn) MarshalUnsafe(dst []byte) { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(f), uintptr(f.SizeBytes())) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (f *FUSEWriteIn) UnmarshalUnsafe(src []byte) { + gohacks.Memmove(unsafe.Pointer(f), unsafe.Pointer(&src[0]), uintptr(f.SizeBytes())) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (f *FUSEWriteIn) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(f))) + hdr.Len = f.SizeBytes() + hdr.Cap = f.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that f + // must live until the use above. + runtime.KeepAlive(f) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (f *FUSEWriteIn) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return f.CopyOutN(cc, addr, f.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (f *FUSEWriteIn) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(f))) + hdr.Len = f.SizeBytes() + hdr.Cap = f.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that f + // must live until the use above. + runtime.KeepAlive(f) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (f *FUSEWriteIn) WriteTo(writer io.Writer) (int64, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(f))) + hdr.Len = f.SizeBytes() + hdr.Cap = f.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that f + // must live until the use above. + runtime.KeepAlive(f) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (f *FUSEWriteOut) SizeBytes() int { + return 8 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (f *FUSEWriteOut) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint32(dst[:4], uint32(f.Size)) + dst = dst[4:] + // Padding: dst[:sizeof(uint32)] ~= uint32(0) + dst = dst[4:] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (f *FUSEWriteOut) UnmarshalBytes(src []byte) { + f.Size = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + // Padding: var _ uint32 ~= src[:sizeof(uint32)] + src = src[4:] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (f *FUSEWriteOut) Packed() bool { + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (f *FUSEWriteOut) MarshalUnsafe(dst []byte) { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(f), uintptr(f.SizeBytes())) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (f *FUSEWriteOut) UnmarshalUnsafe(src []byte) { + gohacks.Memmove(unsafe.Pointer(f), unsafe.Pointer(&src[0]), uintptr(f.SizeBytes())) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (f *FUSEWriteOut) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(f))) + hdr.Len = f.SizeBytes() + hdr.Cap = f.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that f + // must live until the use above. + runtime.KeepAlive(f) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (f *FUSEWriteOut) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return f.CopyOutN(cc, addr, f.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (f *FUSEWriteOut) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(f))) + hdr.Len = f.SizeBytes() + hdr.Cap = f.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that f + // must live until the use above. + runtime.KeepAlive(f) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (f *FUSEWriteOut) WriteTo(writer io.Writer) (int64, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(f))) + hdr.Len = f.SizeBytes() + hdr.Cap = f.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that f + // must live until the use above. + runtime.KeepAlive(f) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (r *RobustListHead) SizeBytes() int { + return 24 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (r *RobustListHead) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint64(dst[:8], uint64(r.List)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(r.FutexOffset)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(r.ListOpPending)) + dst = dst[8:] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (r *RobustListHead) UnmarshalBytes(src []byte) { + r.List = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + r.FutexOffset = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + r.ListOpPending = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (r *RobustListHead) Packed() bool { + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (r *RobustListHead) MarshalUnsafe(dst []byte) { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(r), uintptr(r.SizeBytes())) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (r *RobustListHead) UnmarshalUnsafe(src []byte) { + gohacks.Memmove(unsafe.Pointer(r), unsafe.Pointer(&src[0]), uintptr(r.SizeBytes())) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (r *RobustListHead) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(r))) + hdr.Len = r.SizeBytes() + hdr.Cap = r.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that r + // must live until the use above. + runtime.KeepAlive(r) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (r *RobustListHead) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return r.CopyOutN(cc, addr, r.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (r *RobustListHead) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(r))) + hdr.Len = r.SizeBytes() + hdr.Cap = r.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that r + // must live until the use above. + runtime.KeepAlive(r) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (r *RobustListHead) WriteTo(writer io.Writer) (int64, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(r))) + hdr.Len = r.SizeBytes() + hdr.Cap = r.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that r + // must live until the use above. + runtime.KeepAlive(r) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (d *DigestMetadata) SizeBytes() int { + return 4 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (d *DigestMetadata) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint16(dst[:2], uint16(d.DigestAlgorithm)) + dst = dst[2:] + hostarch.ByteOrder.PutUint16(dst[:2], uint16(d.DigestSize)) + dst = dst[2:] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (d *DigestMetadata) UnmarshalBytes(src []byte) { + d.DigestAlgorithm = uint16(hostarch.ByteOrder.Uint16(src[:2])) + src = src[2:] + d.DigestSize = uint16(hostarch.ByteOrder.Uint16(src[:2])) + src = src[2:] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (d *DigestMetadata) Packed() bool { + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (d *DigestMetadata) MarshalUnsafe(dst []byte) { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(d), uintptr(d.SizeBytes())) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (d *DigestMetadata) UnmarshalUnsafe(src []byte) { + gohacks.Memmove(unsafe.Pointer(d), unsafe.Pointer(&src[0]), uintptr(d.SizeBytes())) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (d *DigestMetadata) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(d))) + hdr.Len = d.SizeBytes() + hdr.Cap = d.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that d + // must live until the use above. + runtime.KeepAlive(d) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (d *DigestMetadata) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return d.CopyOutN(cc, addr, d.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (d *DigestMetadata) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(d))) + hdr.Len = d.SizeBytes() + hdr.Cap = d.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that d + // must live until the use above. + runtime.KeepAlive(d) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (d *DigestMetadata) WriteTo(writer io.Writer) (int64, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(d))) + hdr.Len = d.SizeBytes() + hdr.Cap = d.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that d + // must live until the use above. + runtime.KeepAlive(d) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (i *IPCPerm) SizeBytes() int { + return 48 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (i *IPCPerm) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint32(dst[:4], uint32(i.Key)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(i.UID)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(i.GID)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(i.CUID)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(i.CGID)) + dst = dst[4:] + hostarch.ByteOrder.PutUint16(dst[:2], uint16(i.Mode)) + dst = dst[2:] + // Padding: dst[:sizeof(uint16)] ~= uint16(0) + dst = dst[2:] + hostarch.ByteOrder.PutUint16(dst[:2], uint16(i.Seq)) + dst = dst[2:] + // Padding: dst[:sizeof(uint16)] ~= uint16(0) + dst = dst[2:] + // Padding: dst[:sizeof(uint32)] ~= uint32(0) + dst = dst[4:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(i.unused1)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(i.unused2)) + dst = dst[8:] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (i *IPCPerm) UnmarshalBytes(src []byte) { + i.Key = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + i.UID = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + i.GID = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + i.CUID = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + i.CGID = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + i.Mode = uint16(hostarch.ByteOrder.Uint16(src[:2])) + src = src[2:] + // Padding: var _ uint16 ~= src[:sizeof(uint16)] + src = src[2:] + i.Seq = uint16(hostarch.ByteOrder.Uint16(src[:2])) + src = src[2:] + // Padding: var _ uint16 ~= src[:sizeof(uint16)] + src = src[2:] + // Padding: var _ uint32 ~= src[:sizeof(uint32)] + src = src[4:] + i.unused1 = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + i.unused2 = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (i *IPCPerm) Packed() bool { + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (i *IPCPerm) MarshalUnsafe(dst []byte) { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(i), uintptr(i.SizeBytes())) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (i *IPCPerm) UnmarshalUnsafe(src []byte) { + gohacks.Memmove(unsafe.Pointer(i), unsafe.Pointer(&src[0]), uintptr(i.SizeBytes())) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (i *IPCPerm) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(i))) + hdr.Len = i.SizeBytes() + hdr.Cap = i.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that i + // must live until the use above. + runtime.KeepAlive(i) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (i *IPCPerm) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return i.CopyOutN(cc, addr, i.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (i *IPCPerm) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(i))) + hdr.Len = i.SizeBytes() + hdr.Cap = i.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that i + // must live until the use above. + runtime.KeepAlive(i) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (i *IPCPerm) WriteTo(writer io.Writer) (int64, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(i))) + hdr.Len = i.SizeBytes() + hdr.Cap = i.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that i + // must live until the use above. + runtime.KeepAlive(i) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (s *Sysinfo) SizeBytes() int { + return 78 + + 8*3 + + 1*6 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (s *Sysinfo) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint64(dst[:8], uint64(s.Uptime)) + dst = dst[8:] + for idx := 0; idx < 3; idx++ { + hostarch.ByteOrder.PutUint64(dst[:8], uint64(s.Loads[idx])) + dst = dst[8:] + } + hostarch.ByteOrder.PutUint64(dst[:8], uint64(s.TotalRAM)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(s.FreeRAM)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(s.SharedRAM)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(s.BufferRAM)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(s.TotalSwap)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(s.FreeSwap)) + dst = dst[8:] + hostarch.ByteOrder.PutUint16(dst[:2], uint16(s.Procs)) + dst = dst[2:] + // Padding: dst[:sizeof(byte)*6] ~= [6]byte{0} + dst = dst[1*(6):] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(s.TotalHigh)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(s.FreeHigh)) + dst = dst[8:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(s.Unit)) + dst = dst[4:] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (s *Sysinfo) UnmarshalBytes(src []byte) { + s.Uptime = int64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + for idx := 0; idx < 3; idx++ { + s.Loads[idx] = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + } + s.TotalRAM = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.FreeRAM = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.SharedRAM = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.BufferRAM = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.TotalSwap = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.FreeSwap = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.Procs = uint16(hostarch.ByteOrder.Uint16(src[:2])) + src = src[2:] + // Padding: ~ copy([6]byte(s._), src[:sizeof(byte)*6]) + src = src[1*(6):] + s.TotalHigh = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.FreeHigh = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.Unit = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (s *Sysinfo) Packed() bool { + return false +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (s *Sysinfo) MarshalUnsafe(dst []byte) { + // Type Sysinfo doesn't have a packed layout in memory, fallback to MarshalBytes. + s.MarshalBytes(dst) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (s *Sysinfo) UnmarshalUnsafe(src []byte) { + // Type Sysinfo doesn't have a packed layout in memory, fallback to UnmarshalBytes. + s.UnmarshalBytes(src) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (s *Sysinfo) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Type Sysinfo doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := cc.CopyScratchBuffer(s.SizeBytes()) // escapes: okay. + s.MarshalBytes(buf) // escapes: fallback. + return cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (s *Sysinfo) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return s.CopyOutN(cc, addr, s.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (s *Sysinfo) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Type Sysinfo doesn't have a packed layout in memory, fall back to UnmarshalBytes. + buf := cc.CopyScratchBuffer(s.SizeBytes()) // escapes: okay. + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Unmarshal unconditionally. If we had a short copy-in, this results in a + // partially unmarshalled struct. + s.UnmarshalBytes(buf) // escapes: fallback. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (s *Sysinfo) WriteTo(writer io.Writer) (int64, error) { + // Type Sysinfo doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := make([]byte, s.SizeBytes()) + s.MarshalBytes(buf) + length, err := writer.Write(buf) + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +//go:nosplit +func (n *NumaPolicy) SizeBytes() int { + return 4 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (n *NumaPolicy) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint32(dst[:4], uint32(*n)) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (n *NumaPolicy) UnmarshalBytes(src []byte) { + *n = NumaPolicy(int32(hostarch.ByteOrder.Uint32(src[:4]))) +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (n *NumaPolicy) Packed() bool { + // Scalar newtypes are always packed. + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (n *NumaPolicy) MarshalUnsafe(dst []byte) { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(n), uintptr(n.SizeBytes())) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (n *NumaPolicy) UnmarshalUnsafe(src []byte) { + gohacks.Memmove(unsafe.Pointer(n), unsafe.Pointer(&src[0]), uintptr(n.SizeBytes())) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (n *NumaPolicy) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(n))) + hdr.Len = n.SizeBytes() + hdr.Cap = n.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that n + // must live until the use above. + runtime.KeepAlive(n) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (n *NumaPolicy) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return n.CopyOutN(cc, addr, n.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (n *NumaPolicy) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(n))) + hdr.Len = n.SizeBytes() + hdr.Cap = n.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that n + // must live until the use above. + runtime.KeepAlive(n) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (n *NumaPolicy) WriteTo(w io.Writer) (int64, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(n))) + hdr.Len = n.SizeBytes() + hdr.Cap = n.SizeBytes() + + length, err := w.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that n + // must live until the use above. + runtime.KeepAlive(n) // escapes: replaced by intrinsic. + return int64(length), err +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (b *MsgBuf) Packed() bool { + // Type MsgBuf is dynamic so it might have slice/string headers. Hence, it is not packed. + return false +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (b *MsgBuf) MarshalUnsafe(dst []byte) { + // Type MsgBuf doesn't have a packed layout in memory, fallback to MarshalBytes. + b.MarshalBytes(dst) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (b *MsgBuf) UnmarshalUnsafe(src []byte) { + // Type MsgBuf doesn't have a packed layout in memory, fallback to UnmarshalBytes. + b.UnmarshalBytes(src) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (b *MsgBuf) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Type MsgBuf doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := cc.CopyScratchBuffer(b.SizeBytes()) // escapes: okay. + b.MarshalBytes(buf) // escapes: fallback. + return cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (b *MsgBuf) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return b.CopyOutN(cc, addr, b.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (b *MsgBuf) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Type MsgBuf doesn't have a packed layout in memory, fall back to UnmarshalBytes. + buf := cc.CopyScratchBuffer(b.SizeBytes()) // escapes: okay. + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Unmarshal unconditionally. If we had a short copy-in, this results in a + // partially unmarshalled struct. + b.UnmarshalBytes(buf) // escapes: fallback. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (b *MsgBuf) WriteTo(writer io.Writer) (int64, error) { + // Type MsgBuf doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := make([]byte, b.SizeBytes()) + b.MarshalBytes(buf) + length, err := writer.Write(buf) + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (m *MsgInfo) SizeBytes() int { + return 30 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (m *MsgInfo) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint32(dst[:4], uint32(m.MsgPool)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(m.MsgMap)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(m.MsgMax)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(m.MsgMnb)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(m.MsgMni)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(m.MsgSsz)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(m.MsgTql)) + dst = dst[4:] + hostarch.ByteOrder.PutUint16(dst[:2], uint16(m.MsgSeg)) + dst = dst[2:] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (m *MsgInfo) UnmarshalBytes(src []byte) { + m.MsgPool = int32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + m.MsgMap = int32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + m.MsgMax = int32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + m.MsgMnb = int32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + m.MsgMni = int32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + m.MsgSsz = int32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + m.MsgTql = int32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + m.MsgSeg = uint16(hostarch.ByteOrder.Uint16(src[:2])) + src = src[2:] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (m *MsgInfo) Packed() bool { + return false +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (m *MsgInfo) MarshalUnsafe(dst []byte) { + // Type MsgInfo doesn't have a packed layout in memory, fallback to MarshalBytes. + m.MarshalBytes(dst) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (m *MsgInfo) UnmarshalUnsafe(src []byte) { + // Type MsgInfo doesn't have a packed layout in memory, fallback to UnmarshalBytes. + m.UnmarshalBytes(src) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (m *MsgInfo) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Type MsgInfo doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := cc.CopyScratchBuffer(m.SizeBytes()) // escapes: okay. + m.MarshalBytes(buf) // escapes: fallback. + return cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (m *MsgInfo) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return m.CopyOutN(cc, addr, m.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (m *MsgInfo) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Type MsgInfo doesn't have a packed layout in memory, fall back to UnmarshalBytes. + buf := cc.CopyScratchBuffer(m.SizeBytes()) // escapes: okay. + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Unmarshal unconditionally. If we had a short copy-in, this results in a + // partially unmarshalled struct. + m.UnmarshalBytes(buf) // escapes: fallback. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (m *MsgInfo) WriteTo(writer io.Writer) (int64, error) { + // Type MsgInfo doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := make([]byte, m.SizeBytes()) + m.MarshalBytes(buf) + length, err := writer.Write(buf) + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (m *MsqidDS) SizeBytes() int { + return 48 + + (*IPCPerm)(nil).SizeBytes() + + (*TimeT)(nil).SizeBytes() + + (*TimeT)(nil).SizeBytes() + + (*TimeT)(nil).SizeBytes() +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (m *MsqidDS) MarshalBytes(dst []byte) { + m.MsgPerm.MarshalBytes(dst[:m.MsgPerm.SizeBytes()]) + dst = dst[m.MsgPerm.SizeBytes():] + m.MsgStime.MarshalBytes(dst[:m.MsgStime.SizeBytes()]) + dst = dst[m.MsgStime.SizeBytes():] + m.MsgRtime.MarshalBytes(dst[:m.MsgRtime.SizeBytes()]) + dst = dst[m.MsgRtime.SizeBytes():] + m.MsgCtime.MarshalBytes(dst[:m.MsgCtime.SizeBytes()]) + dst = dst[m.MsgCtime.SizeBytes():] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(m.MsgCbytes)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(m.MsgQnum)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(m.MsgQbytes)) + dst = dst[8:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(m.MsgLspid)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(m.MsgLrpid)) + dst = dst[4:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(m.unused4)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(m.unused5)) + dst = dst[8:] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (m *MsqidDS) UnmarshalBytes(src []byte) { + m.MsgPerm.UnmarshalBytes(src[:m.MsgPerm.SizeBytes()]) + src = src[m.MsgPerm.SizeBytes():] + m.MsgStime.UnmarshalBytes(src[:m.MsgStime.SizeBytes()]) + src = src[m.MsgStime.SizeBytes():] + m.MsgRtime.UnmarshalBytes(src[:m.MsgRtime.SizeBytes()]) + src = src[m.MsgRtime.SizeBytes():] + m.MsgCtime.UnmarshalBytes(src[:m.MsgCtime.SizeBytes()]) + src = src[m.MsgCtime.SizeBytes():] + m.MsgCbytes = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + m.MsgQnum = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + m.MsgQbytes = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + m.MsgLspid = int32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + m.MsgLrpid = int32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + m.unused4 = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + m.unused5 = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (m *MsqidDS) Packed() bool { + return m.MsgCtime.Packed() && m.MsgPerm.Packed() && m.MsgRtime.Packed() && m.MsgStime.Packed() +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (m *MsqidDS) MarshalUnsafe(dst []byte) { + if m.MsgCtime.Packed() && m.MsgPerm.Packed() && m.MsgRtime.Packed() && m.MsgStime.Packed() { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(m), uintptr(m.SizeBytes())) + } else { + // Type MsqidDS doesn't have a packed layout in memory, fallback to MarshalBytes. + m.MarshalBytes(dst) + } +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (m *MsqidDS) UnmarshalUnsafe(src []byte) { + if m.MsgCtime.Packed() && m.MsgPerm.Packed() && m.MsgRtime.Packed() && m.MsgStime.Packed() { + gohacks.Memmove(unsafe.Pointer(m), unsafe.Pointer(&src[0]), uintptr(m.SizeBytes())) + } else { + // Type MsqidDS doesn't have a packed layout in memory, fallback to UnmarshalBytes. + m.UnmarshalBytes(src) + } +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (m *MsqidDS) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + if !m.MsgCtime.Packed() && m.MsgPerm.Packed() && m.MsgRtime.Packed() && m.MsgStime.Packed() { + // Type MsqidDS doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := cc.CopyScratchBuffer(m.SizeBytes()) // escapes: okay. + m.MarshalBytes(buf) // escapes: fallback. + return cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(m))) + hdr.Len = m.SizeBytes() + hdr.Cap = m.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that m + // must live until the use above. + runtime.KeepAlive(m) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (m *MsqidDS) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return m.CopyOutN(cc, addr, m.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (m *MsqidDS) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + if !m.MsgCtime.Packed() && m.MsgPerm.Packed() && m.MsgRtime.Packed() && m.MsgStime.Packed() { + // Type MsqidDS doesn't have a packed layout in memory, fall back to UnmarshalBytes. + buf := cc.CopyScratchBuffer(m.SizeBytes()) // escapes: okay. + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Unmarshal unconditionally. If we had a short copy-in, this results in a + // partially unmarshalled struct. + m.UnmarshalBytes(buf) // escapes: fallback. + return length, err + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(m))) + hdr.Len = m.SizeBytes() + hdr.Cap = m.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that m + // must live until the use above. + runtime.KeepAlive(m) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (m *MsqidDS) WriteTo(writer io.Writer) (int64, error) { + if !m.MsgCtime.Packed() && m.MsgPerm.Packed() && m.MsgRtime.Packed() && m.MsgStime.Packed() { + // Type MsqidDS doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := make([]byte, m.SizeBytes()) + m.MarshalBytes(buf) + length, err := writer.Write(buf) + return int64(length), err + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(m))) + hdr.Len = m.SizeBytes() + hdr.Cap = m.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that m + // must live until the use above. + runtime.KeepAlive(m) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (i *IFConf) SizeBytes() int { + return 12 + + 1*4 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (i *IFConf) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint32(dst[:4], uint32(i.Len)) + dst = dst[4:] + // Padding: dst[:sizeof(byte)*4] ~= [4]byte{0} + dst = dst[1*(4):] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(i.Ptr)) + dst = dst[8:] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (i *IFConf) UnmarshalBytes(src []byte) { + i.Len = int32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + // Padding: ~ copy([4]byte(i._), src[:sizeof(byte)*4]) + src = src[1*(4):] + i.Ptr = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (i *IFConf) Packed() bool { + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (i *IFConf) MarshalUnsafe(dst []byte) { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(i), uintptr(i.SizeBytes())) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (i *IFConf) UnmarshalUnsafe(src []byte) { + gohacks.Memmove(unsafe.Pointer(i), unsafe.Pointer(&src[0]), uintptr(i.SizeBytes())) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (i *IFConf) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(i))) + hdr.Len = i.SizeBytes() + hdr.Cap = i.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that i + // must live until the use above. + runtime.KeepAlive(i) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (i *IFConf) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return i.CopyOutN(cc, addr, i.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (i *IFConf) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(i))) + hdr.Len = i.SizeBytes() + hdr.Cap = i.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that i + // must live until the use above. + runtime.KeepAlive(i) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (i *IFConf) WriteTo(writer io.Writer) (int64, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(i))) + hdr.Len = i.SizeBytes() + hdr.Cap = i.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that i + // must live until the use above. + runtime.KeepAlive(i) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (ifr *IFReq) SizeBytes() int { + return 0 + + 1*IFNAMSIZ + + 1*24 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (ifr *IFReq) MarshalBytes(dst []byte) { + for idx := 0; idx < IFNAMSIZ; idx++ { + dst[0] = byte(ifr.IFName[idx]) + dst = dst[1:] + } + for idx := 0; idx < 24; idx++ { + dst[0] = byte(ifr.Data[idx]) + dst = dst[1:] + } +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (ifr *IFReq) UnmarshalBytes(src []byte) { + for idx := 0; idx < IFNAMSIZ; idx++ { + ifr.IFName[idx] = src[0] + src = src[1:] + } + for idx := 0; idx < 24; idx++ { + ifr.Data[idx] = src[0] + src = src[1:] + } +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (ifr *IFReq) Packed() bool { + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (ifr *IFReq) MarshalUnsafe(dst []byte) { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(ifr), uintptr(ifr.SizeBytes())) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (ifr *IFReq) UnmarshalUnsafe(src []byte) { + gohacks.Memmove(unsafe.Pointer(ifr), unsafe.Pointer(&src[0]), uintptr(ifr.SizeBytes())) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (ifr *IFReq) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(ifr))) + hdr.Len = ifr.SizeBytes() + hdr.Cap = ifr.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that ifr + // must live until the use above. + runtime.KeepAlive(ifr) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (ifr *IFReq) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return ifr.CopyOutN(cc, addr, ifr.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (ifr *IFReq) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(ifr))) + hdr.Len = ifr.SizeBytes() + hdr.Cap = ifr.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that ifr + // must live until the use above. + runtime.KeepAlive(ifr) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (ifr *IFReq) WriteTo(writer io.Writer) (int64, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(ifr))) + hdr.Len = ifr.SizeBytes() + hdr.Cap = ifr.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that ifr + // must live until the use above. + runtime.KeepAlive(ifr) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +//go:nosplit +func (en *ErrorName) SizeBytes() int { + return 1 * XT_FUNCTION_MAXNAMELEN +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (en *ErrorName) MarshalBytes(dst []byte) { + for idx := 0; idx < XT_FUNCTION_MAXNAMELEN; idx++ { + dst[0] = byte(en[idx]) + dst = dst[1:] + } +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (en *ErrorName) UnmarshalBytes(src []byte) { + for idx := 0; idx < XT_FUNCTION_MAXNAMELEN; idx++ { + en[idx] = src[0] + src = src[1:] + } +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (en *ErrorName) Packed() bool { + // Array newtypes are always packed. + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (en *ErrorName) MarshalUnsafe(dst []byte) { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(&en[0]), uintptr(en.SizeBytes())) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (en *ErrorName) UnmarshalUnsafe(src []byte) { + gohacks.Memmove(unsafe.Pointer(en), unsafe.Pointer(&src[0]), uintptr(en.SizeBytes())) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (en *ErrorName) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(en))) + hdr.Len = en.SizeBytes() + hdr.Cap = en.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that en + // must live until the use above. + runtime.KeepAlive(en) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (en *ErrorName) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return en.CopyOutN(cc, addr, en.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (en *ErrorName) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(en))) + hdr.Len = en.SizeBytes() + hdr.Cap = en.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that en + // must live until the use above. + runtime.KeepAlive(en) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (en *ErrorName) WriteTo(w io.Writer) (int64, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(en))) + hdr.Len = en.SizeBytes() + hdr.Cap = en.SizeBytes() + + length, err := w.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that en + // must live until the use above. + runtime.KeepAlive(en) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +//go:nosplit +func (en *ExtensionName) SizeBytes() int { + return 1 * XT_EXTENSION_MAXNAMELEN +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (en *ExtensionName) MarshalBytes(dst []byte) { + for idx := 0; idx < XT_EXTENSION_MAXNAMELEN; idx++ { + dst[0] = byte(en[idx]) + dst = dst[1:] + } +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (en *ExtensionName) UnmarshalBytes(src []byte) { + for idx := 0; idx < XT_EXTENSION_MAXNAMELEN; idx++ { + en[idx] = src[0] + src = src[1:] + } +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (en *ExtensionName) Packed() bool { + // Array newtypes are always packed. + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (en *ExtensionName) MarshalUnsafe(dst []byte) { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(&en[0]), uintptr(en.SizeBytes())) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (en *ExtensionName) UnmarshalUnsafe(src []byte) { + gohacks.Memmove(unsafe.Pointer(en), unsafe.Pointer(&src[0]), uintptr(en.SizeBytes())) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (en *ExtensionName) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(en))) + hdr.Len = en.SizeBytes() + hdr.Cap = en.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that en + // must live until the use above. + runtime.KeepAlive(en) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (en *ExtensionName) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return en.CopyOutN(cc, addr, en.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (en *ExtensionName) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(en))) + hdr.Len = en.SizeBytes() + hdr.Cap = en.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that en + // must live until the use above. + runtime.KeepAlive(en) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (en *ExtensionName) WriteTo(w io.Writer) (int64, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(en))) + hdr.Len = en.SizeBytes() + hdr.Cap = en.SizeBytes() + + length, err := w.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that en + // must live until the use above. + runtime.KeepAlive(en) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (i *IPTEntry) SizeBytes() int { + return 12 + + (*IPTIP)(nil).SizeBytes() + + (*XTCounters)(nil).SizeBytes() +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (i *IPTEntry) MarshalBytes(dst []byte) { + i.IP.MarshalBytes(dst[:i.IP.SizeBytes()]) + dst = dst[i.IP.SizeBytes():] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(i.NFCache)) + dst = dst[4:] + hostarch.ByteOrder.PutUint16(dst[:2], uint16(i.TargetOffset)) + dst = dst[2:] + hostarch.ByteOrder.PutUint16(dst[:2], uint16(i.NextOffset)) + dst = dst[2:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(i.Comeback)) + dst = dst[4:] + i.Counters.MarshalBytes(dst[:i.Counters.SizeBytes()]) + dst = dst[i.Counters.SizeBytes():] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (i *IPTEntry) UnmarshalBytes(src []byte) { + i.IP.UnmarshalBytes(src[:i.IP.SizeBytes()]) + src = src[i.IP.SizeBytes():] + i.NFCache = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + i.TargetOffset = uint16(hostarch.ByteOrder.Uint16(src[:2])) + src = src[2:] + i.NextOffset = uint16(hostarch.ByteOrder.Uint16(src[:2])) + src = src[2:] + i.Comeback = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + i.Counters.UnmarshalBytes(src[:i.Counters.SizeBytes()]) + src = src[i.Counters.SizeBytes():] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (i *IPTEntry) Packed() bool { + return i.Counters.Packed() && i.IP.Packed() +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (i *IPTEntry) MarshalUnsafe(dst []byte) { + if i.Counters.Packed() && i.IP.Packed() { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(i), uintptr(i.SizeBytes())) + } else { + // Type IPTEntry doesn't have a packed layout in memory, fallback to MarshalBytes. + i.MarshalBytes(dst) + } +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (i *IPTEntry) UnmarshalUnsafe(src []byte) { + if i.Counters.Packed() && i.IP.Packed() { + gohacks.Memmove(unsafe.Pointer(i), unsafe.Pointer(&src[0]), uintptr(i.SizeBytes())) + } else { + // Type IPTEntry doesn't have a packed layout in memory, fallback to UnmarshalBytes. + i.UnmarshalBytes(src) + } +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (i *IPTEntry) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + if !i.Counters.Packed() && i.IP.Packed() { + // Type IPTEntry doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := cc.CopyScratchBuffer(i.SizeBytes()) // escapes: okay. + i.MarshalBytes(buf) // escapes: fallback. + return cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(i))) + hdr.Len = i.SizeBytes() + hdr.Cap = i.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that i + // must live until the use above. + runtime.KeepAlive(i) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (i *IPTEntry) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return i.CopyOutN(cc, addr, i.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (i *IPTEntry) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + if !i.Counters.Packed() && i.IP.Packed() { + // Type IPTEntry doesn't have a packed layout in memory, fall back to UnmarshalBytes. + buf := cc.CopyScratchBuffer(i.SizeBytes()) // escapes: okay. + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Unmarshal unconditionally. If we had a short copy-in, this results in a + // partially unmarshalled struct. + i.UnmarshalBytes(buf) // escapes: fallback. + return length, err + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(i))) + hdr.Len = i.SizeBytes() + hdr.Cap = i.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that i + // must live until the use above. + runtime.KeepAlive(i) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (i *IPTEntry) WriteTo(writer io.Writer) (int64, error) { + if !i.Counters.Packed() && i.IP.Packed() { + // Type IPTEntry doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := make([]byte, i.SizeBytes()) + i.MarshalBytes(buf) + length, err := writer.Write(buf) + return int64(length), err + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(i))) + hdr.Len = i.SizeBytes() + hdr.Cap = i.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that i + // must live until the use above. + runtime.KeepAlive(i) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (i *IPTGetEntries) SizeBytes() int { + return 4 + + (*TableName)(nil).SizeBytes() + + 1*4 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (i *IPTGetEntries) MarshalBytes(dst []byte) { + i.Name.MarshalBytes(dst[:i.Name.SizeBytes()]) + dst = dst[i.Name.SizeBytes():] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(i.Size)) + dst = dst[4:] + // Padding: dst[:sizeof(byte)*4] ~= [4]byte{0} + dst = dst[1*(4):] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (i *IPTGetEntries) UnmarshalBytes(src []byte) { + i.Name.UnmarshalBytes(src[:i.Name.SizeBytes()]) + src = src[i.Name.SizeBytes():] + i.Size = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + // Padding: ~ copy([4]byte(i._), src[:sizeof(byte)*4]) + src = src[1*(4):] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (i *IPTGetEntries) Packed() bool { + return i.Name.Packed() +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (i *IPTGetEntries) MarshalUnsafe(dst []byte) { + if i.Name.Packed() { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(i), uintptr(i.SizeBytes())) + } else { + // Type IPTGetEntries doesn't have a packed layout in memory, fallback to MarshalBytes. + i.MarshalBytes(dst) + } +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (i *IPTGetEntries) UnmarshalUnsafe(src []byte) { + if i.Name.Packed() { + gohacks.Memmove(unsafe.Pointer(i), unsafe.Pointer(&src[0]), uintptr(i.SizeBytes())) + } else { + // Type IPTGetEntries doesn't have a packed layout in memory, fallback to UnmarshalBytes. + i.UnmarshalBytes(src) + } +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (i *IPTGetEntries) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + if !i.Name.Packed() { + // Type IPTGetEntries doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := cc.CopyScratchBuffer(i.SizeBytes()) // escapes: okay. + i.MarshalBytes(buf) // escapes: fallback. + return cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(i))) + hdr.Len = i.SizeBytes() + hdr.Cap = i.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that i + // must live until the use above. + runtime.KeepAlive(i) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (i *IPTGetEntries) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return i.CopyOutN(cc, addr, i.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (i *IPTGetEntries) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + if !i.Name.Packed() { + // Type IPTGetEntries doesn't have a packed layout in memory, fall back to UnmarshalBytes. + buf := cc.CopyScratchBuffer(i.SizeBytes()) // escapes: okay. + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Unmarshal unconditionally. If we had a short copy-in, this results in a + // partially unmarshalled struct. + i.UnmarshalBytes(buf) // escapes: fallback. + return length, err + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(i))) + hdr.Len = i.SizeBytes() + hdr.Cap = i.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that i + // must live until the use above. + runtime.KeepAlive(i) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (i *IPTGetEntries) WriteTo(writer io.Writer) (int64, error) { + if !i.Name.Packed() { + // Type IPTGetEntries doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := make([]byte, i.SizeBytes()) + i.MarshalBytes(buf) + length, err := writer.Write(buf) + return int64(length), err + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(i))) + hdr.Len = i.SizeBytes() + hdr.Cap = i.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that i + // must live until the use above. + runtime.KeepAlive(i) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (i *IPTGetinfo) SizeBytes() int { + return 12 + + (*TableName)(nil).SizeBytes() + + 4*NF_INET_NUMHOOKS + + 4*NF_INET_NUMHOOKS +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (i *IPTGetinfo) MarshalBytes(dst []byte) { + i.Name.MarshalBytes(dst[:i.Name.SizeBytes()]) + dst = dst[i.Name.SizeBytes():] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(i.ValidHooks)) + dst = dst[4:] + for idx := 0; idx < NF_INET_NUMHOOKS; idx++ { + hostarch.ByteOrder.PutUint32(dst[:4], uint32(i.HookEntry[idx])) + dst = dst[4:] + } + for idx := 0; idx < NF_INET_NUMHOOKS; idx++ { + hostarch.ByteOrder.PutUint32(dst[:4], uint32(i.Underflow[idx])) + dst = dst[4:] + } + hostarch.ByteOrder.PutUint32(dst[:4], uint32(i.NumEntries)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(i.Size)) + dst = dst[4:] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (i *IPTGetinfo) UnmarshalBytes(src []byte) { + i.Name.UnmarshalBytes(src[:i.Name.SizeBytes()]) + src = src[i.Name.SizeBytes():] + i.ValidHooks = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + for idx := 0; idx < NF_INET_NUMHOOKS; idx++ { + i.HookEntry[idx] = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + } + for idx := 0; idx < NF_INET_NUMHOOKS; idx++ { + i.Underflow[idx] = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + } + i.NumEntries = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + i.Size = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (i *IPTGetinfo) Packed() bool { + return i.Name.Packed() +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (i *IPTGetinfo) MarshalUnsafe(dst []byte) { + if i.Name.Packed() { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(i), uintptr(i.SizeBytes())) + } else { + // Type IPTGetinfo doesn't have a packed layout in memory, fallback to MarshalBytes. + i.MarshalBytes(dst) + } +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (i *IPTGetinfo) UnmarshalUnsafe(src []byte) { + if i.Name.Packed() { + gohacks.Memmove(unsafe.Pointer(i), unsafe.Pointer(&src[0]), uintptr(i.SizeBytes())) + } else { + // Type IPTGetinfo doesn't have a packed layout in memory, fallback to UnmarshalBytes. + i.UnmarshalBytes(src) + } +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (i *IPTGetinfo) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + if !i.Name.Packed() { + // Type IPTGetinfo doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := cc.CopyScratchBuffer(i.SizeBytes()) // escapes: okay. + i.MarshalBytes(buf) // escapes: fallback. + return cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(i))) + hdr.Len = i.SizeBytes() + hdr.Cap = i.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that i + // must live until the use above. + runtime.KeepAlive(i) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (i *IPTGetinfo) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return i.CopyOutN(cc, addr, i.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (i *IPTGetinfo) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + if !i.Name.Packed() { + // Type IPTGetinfo doesn't have a packed layout in memory, fall back to UnmarshalBytes. + buf := cc.CopyScratchBuffer(i.SizeBytes()) // escapes: okay. + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Unmarshal unconditionally. If we had a short copy-in, this results in a + // partially unmarshalled struct. + i.UnmarshalBytes(buf) // escapes: fallback. + return length, err + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(i))) + hdr.Len = i.SizeBytes() + hdr.Cap = i.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that i + // must live until the use above. + runtime.KeepAlive(i) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (i *IPTGetinfo) WriteTo(writer io.Writer) (int64, error) { + if !i.Name.Packed() { + // Type IPTGetinfo doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := make([]byte, i.SizeBytes()) + i.MarshalBytes(buf) + length, err := writer.Write(buf) + return int64(length), err + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(i))) + hdr.Len = i.SizeBytes() + hdr.Cap = i.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that i + // must live until the use above. + runtime.KeepAlive(i) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (i *IPTIP) SizeBytes() int { + return 4 + + (*InetAddr)(nil).SizeBytes() + + (*InetAddr)(nil).SizeBytes() + + (*InetAddr)(nil).SizeBytes() + + (*InetAddr)(nil).SizeBytes() + + 1*IFNAMSIZ + + 1*IFNAMSIZ + + 1*IFNAMSIZ + + 1*IFNAMSIZ +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (i *IPTIP) MarshalBytes(dst []byte) { + i.Src.MarshalBytes(dst[:i.Src.SizeBytes()]) + dst = dst[i.Src.SizeBytes():] + i.Dst.MarshalBytes(dst[:i.Dst.SizeBytes()]) + dst = dst[i.Dst.SizeBytes():] + i.SrcMask.MarshalBytes(dst[:i.SrcMask.SizeBytes()]) + dst = dst[i.SrcMask.SizeBytes():] + i.DstMask.MarshalBytes(dst[:i.DstMask.SizeBytes()]) + dst = dst[i.DstMask.SizeBytes():] + for idx := 0; idx < IFNAMSIZ; idx++ { + dst[0] = byte(i.InputInterface[idx]) + dst = dst[1:] + } + for idx := 0; idx < IFNAMSIZ; idx++ { + dst[0] = byte(i.OutputInterface[idx]) + dst = dst[1:] + } + for idx := 0; idx < IFNAMSIZ; idx++ { + dst[0] = byte(i.InputInterfaceMask[idx]) + dst = dst[1:] + } + for idx := 0; idx < IFNAMSIZ; idx++ { + dst[0] = byte(i.OutputInterfaceMask[idx]) + dst = dst[1:] + } + hostarch.ByteOrder.PutUint16(dst[:2], uint16(i.Protocol)) + dst = dst[2:] + dst[0] = byte(i.Flags) + dst = dst[1:] + dst[0] = byte(i.InverseFlags) + dst = dst[1:] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (i *IPTIP) UnmarshalBytes(src []byte) { + i.Src.UnmarshalBytes(src[:i.Src.SizeBytes()]) + src = src[i.Src.SizeBytes():] + i.Dst.UnmarshalBytes(src[:i.Dst.SizeBytes()]) + src = src[i.Dst.SizeBytes():] + i.SrcMask.UnmarshalBytes(src[:i.SrcMask.SizeBytes()]) + src = src[i.SrcMask.SizeBytes():] + i.DstMask.UnmarshalBytes(src[:i.DstMask.SizeBytes()]) + src = src[i.DstMask.SizeBytes():] + for idx := 0; idx < IFNAMSIZ; idx++ { + i.InputInterface[idx] = src[0] + src = src[1:] + } + for idx := 0; idx < IFNAMSIZ; idx++ { + i.OutputInterface[idx] = src[0] + src = src[1:] + } + for idx := 0; idx < IFNAMSIZ; idx++ { + i.InputInterfaceMask[idx] = src[0] + src = src[1:] + } + for idx := 0; idx < IFNAMSIZ; idx++ { + i.OutputInterfaceMask[idx] = src[0] + src = src[1:] + } + i.Protocol = uint16(hostarch.ByteOrder.Uint16(src[:2])) + src = src[2:] + i.Flags = uint8(src[0]) + src = src[1:] + i.InverseFlags = uint8(src[0]) + src = src[1:] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (i *IPTIP) Packed() bool { + return i.Dst.Packed() && i.DstMask.Packed() && i.Src.Packed() && i.SrcMask.Packed() +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (i *IPTIP) MarshalUnsafe(dst []byte) { + if i.Dst.Packed() && i.DstMask.Packed() && i.Src.Packed() && i.SrcMask.Packed() { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(i), uintptr(i.SizeBytes())) + } else { + // Type IPTIP doesn't have a packed layout in memory, fallback to MarshalBytes. + i.MarshalBytes(dst) + } +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (i *IPTIP) UnmarshalUnsafe(src []byte) { + if i.Dst.Packed() && i.DstMask.Packed() && i.Src.Packed() && i.SrcMask.Packed() { + gohacks.Memmove(unsafe.Pointer(i), unsafe.Pointer(&src[0]), uintptr(i.SizeBytes())) + } else { + // Type IPTIP doesn't have a packed layout in memory, fallback to UnmarshalBytes. + i.UnmarshalBytes(src) + } +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (i *IPTIP) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + if !i.Dst.Packed() && i.DstMask.Packed() && i.Src.Packed() && i.SrcMask.Packed() { + // Type IPTIP doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := cc.CopyScratchBuffer(i.SizeBytes()) // escapes: okay. + i.MarshalBytes(buf) // escapes: fallback. + return cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(i))) + hdr.Len = i.SizeBytes() + hdr.Cap = i.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that i + // must live until the use above. + runtime.KeepAlive(i) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (i *IPTIP) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return i.CopyOutN(cc, addr, i.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (i *IPTIP) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + if !i.Dst.Packed() && i.DstMask.Packed() && i.Src.Packed() && i.SrcMask.Packed() { + // Type IPTIP doesn't have a packed layout in memory, fall back to UnmarshalBytes. + buf := cc.CopyScratchBuffer(i.SizeBytes()) // escapes: okay. + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Unmarshal unconditionally. If we had a short copy-in, this results in a + // partially unmarshalled struct. + i.UnmarshalBytes(buf) // escapes: fallback. + return length, err + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(i))) + hdr.Len = i.SizeBytes() + hdr.Cap = i.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that i + // must live until the use above. + runtime.KeepAlive(i) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (i *IPTIP) WriteTo(writer io.Writer) (int64, error) { + if !i.Dst.Packed() && i.DstMask.Packed() && i.Src.Packed() && i.SrcMask.Packed() { + // Type IPTIP doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := make([]byte, i.SizeBytes()) + i.MarshalBytes(buf) + length, err := writer.Write(buf) + return int64(length), err + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(i))) + hdr.Len = i.SizeBytes() + hdr.Cap = i.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that i + // must live until the use above. + runtime.KeepAlive(i) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (i *IPTOwnerInfo) SizeBytes() int { + return 18 + + 1*16 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (i *IPTOwnerInfo) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint32(dst[:4], uint32(i.UID)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(i.GID)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(i.PID)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(i.SID)) + dst = dst[4:] + for idx := 0; idx < 16; idx++ { + dst[0] = byte(i.Comm[idx]) + dst = dst[1:] + } + dst[0] = byte(i.Match) + dst = dst[1:] + dst[0] = byte(i.Invert) + dst = dst[1:] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (i *IPTOwnerInfo) UnmarshalBytes(src []byte) { + i.UID = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + i.GID = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + i.PID = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + i.SID = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + for idx := 0; idx < 16; idx++ { + i.Comm[idx] = src[0] + src = src[1:] + } + i.Match = uint8(src[0]) + src = src[1:] + i.Invert = uint8(src[0]) + src = src[1:] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (i *IPTOwnerInfo) Packed() bool { + return false +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (i *IPTOwnerInfo) MarshalUnsafe(dst []byte) { + // Type IPTOwnerInfo doesn't have a packed layout in memory, fallback to MarshalBytes. + i.MarshalBytes(dst) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (i *IPTOwnerInfo) UnmarshalUnsafe(src []byte) { + // Type IPTOwnerInfo doesn't have a packed layout in memory, fallback to UnmarshalBytes. + i.UnmarshalBytes(src) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (i *IPTOwnerInfo) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Type IPTOwnerInfo doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := cc.CopyScratchBuffer(i.SizeBytes()) // escapes: okay. + i.MarshalBytes(buf) // escapes: fallback. + return cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (i *IPTOwnerInfo) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return i.CopyOutN(cc, addr, i.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (i *IPTOwnerInfo) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Type IPTOwnerInfo doesn't have a packed layout in memory, fall back to UnmarshalBytes. + buf := cc.CopyScratchBuffer(i.SizeBytes()) // escapes: okay. + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Unmarshal unconditionally. If we had a short copy-in, this results in a + // partially unmarshalled struct. + i.UnmarshalBytes(buf) // escapes: fallback. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (i *IPTOwnerInfo) WriteTo(writer io.Writer) (int64, error) { + // Type IPTOwnerInfo doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := make([]byte, i.SizeBytes()) + i.MarshalBytes(buf) + length, err := writer.Write(buf) + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (i *IPTReplace) SizeBytes() int { + return 24 + + (*TableName)(nil).SizeBytes() + + 4*NF_INET_NUMHOOKS + + 4*NF_INET_NUMHOOKS +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (i *IPTReplace) MarshalBytes(dst []byte) { + i.Name.MarshalBytes(dst[:i.Name.SizeBytes()]) + dst = dst[i.Name.SizeBytes():] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(i.ValidHooks)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(i.NumEntries)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(i.Size)) + dst = dst[4:] + for idx := 0; idx < NF_INET_NUMHOOKS; idx++ { + hostarch.ByteOrder.PutUint32(dst[:4], uint32(i.HookEntry[idx])) + dst = dst[4:] + } + for idx := 0; idx < NF_INET_NUMHOOKS; idx++ { + hostarch.ByteOrder.PutUint32(dst[:4], uint32(i.Underflow[idx])) + dst = dst[4:] + } + hostarch.ByteOrder.PutUint32(dst[:4], uint32(i.NumCounters)) + dst = dst[4:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(i.Counters)) + dst = dst[8:] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (i *IPTReplace) UnmarshalBytes(src []byte) { + i.Name.UnmarshalBytes(src[:i.Name.SizeBytes()]) + src = src[i.Name.SizeBytes():] + i.ValidHooks = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + i.NumEntries = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + i.Size = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + for idx := 0; idx < NF_INET_NUMHOOKS; idx++ { + i.HookEntry[idx] = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + } + for idx := 0; idx < NF_INET_NUMHOOKS; idx++ { + i.Underflow[idx] = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + } + i.NumCounters = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + i.Counters = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (i *IPTReplace) Packed() bool { + return i.Name.Packed() +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (i *IPTReplace) MarshalUnsafe(dst []byte) { + if i.Name.Packed() { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(i), uintptr(i.SizeBytes())) + } else { + // Type IPTReplace doesn't have a packed layout in memory, fallback to MarshalBytes. + i.MarshalBytes(dst) + } +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (i *IPTReplace) UnmarshalUnsafe(src []byte) { + if i.Name.Packed() { + gohacks.Memmove(unsafe.Pointer(i), unsafe.Pointer(&src[0]), uintptr(i.SizeBytes())) + } else { + // Type IPTReplace doesn't have a packed layout in memory, fallback to UnmarshalBytes. + i.UnmarshalBytes(src) + } +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (i *IPTReplace) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + if !i.Name.Packed() { + // Type IPTReplace doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := cc.CopyScratchBuffer(i.SizeBytes()) // escapes: okay. + i.MarshalBytes(buf) // escapes: fallback. + return cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(i))) + hdr.Len = i.SizeBytes() + hdr.Cap = i.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that i + // must live until the use above. + runtime.KeepAlive(i) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (i *IPTReplace) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return i.CopyOutN(cc, addr, i.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (i *IPTReplace) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + if !i.Name.Packed() { + // Type IPTReplace doesn't have a packed layout in memory, fall back to UnmarshalBytes. + buf := cc.CopyScratchBuffer(i.SizeBytes()) // escapes: okay. + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Unmarshal unconditionally. If we had a short copy-in, this results in a + // partially unmarshalled struct. + i.UnmarshalBytes(buf) // escapes: fallback. + return length, err + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(i))) + hdr.Len = i.SizeBytes() + hdr.Cap = i.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that i + // must live until the use above. + runtime.KeepAlive(i) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (i *IPTReplace) WriteTo(writer io.Writer) (int64, error) { + if !i.Name.Packed() { + // Type IPTReplace doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := make([]byte, i.SizeBytes()) + i.MarshalBytes(buf) + length, err := writer.Write(buf) + return int64(length), err + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(i))) + hdr.Len = i.SizeBytes() + hdr.Cap = i.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that i + // must live until the use above. + runtime.KeepAlive(i) // escapes: replaced by intrinsic. + return int64(length), err +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (ke *KernelIPTEntry) Packed() bool { + // Type KernelIPTEntry is dynamic so it might have slice/string headers. Hence, it is not packed. + return false +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (ke *KernelIPTEntry) MarshalUnsafe(dst []byte) { + // Type KernelIPTEntry doesn't have a packed layout in memory, fallback to MarshalBytes. + ke.MarshalBytes(dst) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (ke *KernelIPTEntry) UnmarshalUnsafe(src []byte) { + // Type KernelIPTEntry doesn't have a packed layout in memory, fallback to UnmarshalBytes. + ke.UnmarshalBytes(src) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (ke *KernelIPTEntry) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Type KernelIPTEntry doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := cc.CopyScratchBuffer(ke.SizeBytes()) // escapes: okay. + ke.MarshalBytes(buf) // escapes: fallback. + return cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (ke *KernelIPTEntry) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return ke.CopyOutN(cc, addr, ke.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (ke *KernelIPTEntry) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Type KernelIPTEntry doesn't have a packed layout in memory, fall back to UnmarshalBytes. + buf := cc.CopyScratchBuffer(ke.SizeBytes()) // escapes: okay. + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Unmarshal unconditionally. If we had a short copy-in, this results in a + // partially unmarshalled struct. + ke.UnmarshalBytes(buf) // escapes: fallback. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (ke *KernelIPTEntry) WriteTo(writer io.Writer) (int64, error) { + // Type KernelIPTEntry doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := make([]byte, ke.SizeBytes()) + ke.MarshalBytes(buf) + length, err := writer.Write(buf) + return int64(length), err +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (ke *KernelIPTGetEntries) Packed() bool { + // Type KernelIPTGetEntries is dynamic so it might have slice/string headers. Hence, it is not packed. + return false +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (ke *KernelIPTGetEntries) MarshalUnsafe(dst []byte) { + // Type KernelIPTGetEntries doesn't have a packed layout in memory, fallback to MarshalBytes. + ke.MarshalBytes(dst) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (ke *KernelIPTGetEntries) UnmarshalUnsafe(src []byte) { + // Type KernelIPTGetEntries doesn't have a packed layout in memory, fallback to UnmarshalBytes. + ke.UnmarshalBytes(src) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (ke *KernelIPTGetEntries) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Type KernelIPTGetEntries doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := cc.CopyScratchBuffer(ke.SizeBytes()) // escapes: okay. + ke.MarshalBytes(buf) // escapes: fallback. + return cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (ke *KernelIPTGetEntries) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return ke.CopyOutN(cc, addr, ke.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (ke *KernelIPTGetEntries) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Type KernelIPTGetEntries doesn't have a packed layout in memory, fall back to UnmarshalBytes. + buf := cc.CopyScratchBuffer(ke.SizeBytes()) // escapes: okay. + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Unmarshal unconditionally. If we had a short copy-in, this results in a + // partially unmarshalled struct. + ke.UnmarshalBytes(buf) // escapes: fallback. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (ke *KernelIPTGetEntries) WriteTo(writer io.Writer) (int64, error) { + // Type KernelIPTGetEntries doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := make([]byte, ke.SizeBytes()) + ke.MarshalBytes(buf) + length, err := writer.Write(buf) + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (n *NfNATIPV4MultiRangeCompat) SizeBytes() int { + return 4 + + (*NfNATIPV4Range)(nil).SizeBytes() +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (n *NfNATIPV4MultiRangeCompat) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint32(dst[:4], uint32(n.RangeSize)) + dst = dst[4:] + n.RangeIPV4.MarshalBytes(dst[:n.RangeIPV4.SizeBytes()]) + dst = dst[n.RangeIPV4.SizeBytes():] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (n *NfNATIPV4MultiRangeCompat) UnmarshalBytes(src []byte) { + n.RangeSize = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + n.RangeIPV4.UnmarshalBytes(src[:n.RangeIPV4.SizeBytes()]) + src = src[n.RangeIPV4.SizeBytes():] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (n *NfNATIPV4MultiRangeCompat) Packed() bool { + return n.RangeIPV4.Packed() +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (n *NfNATIPV4MultiRangeCompat) MarshalUnsafe(dst []byte) { + if n.RangeIPV4.Packed() { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(n), uintptr(n.SizeBytes())) + } else { + // Type NfNATIPV4MultiRangeCompat doesn't have a packed layout in memory, fallback to MarshalBytes. + n.MarshalBytes(dst) + } +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (n *NfNATIPV4MultiRangeCompat) UnmarshalUnsafe(src []byte) { + if n.RangeIPV4.Packed() { + gohacks.Memmove(unsafe.Pointer(n), unsafe.Pointer(&src[0]), uintptr(n.SizeBytes())) + } else { + // Type NfNATIPV4MultiRangeCompat doesn't have a packed layout in memory, fallback to UnmarshalBytes. + n.UnmarshalBytes(src) + } +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (n *NfNATIPV4MultiRangeCompat) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + if !n.RangeIPV4.Packed() { + // Type NfNATIPV4MultiRangeCompat doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := cc.CopyScratchBuffer(n.SizeBytes()) // escapes: okay. + n.MarshalBytes(buf) // escapes: fallback. + return cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(n))) + hdr.Len = n.SizeBytes() + hdr.Cap = n.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that n + // must live until the use above. + runtime.KeepAlive(n) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (n *NfNATIPV4MultiRangeCompat) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return n.CopyOutN(cc, addr, n.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (n *NfNATIPV4MultiRangeCompat) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + if !n.RangeIPV4.Packed() { + // Type NfNATIPV4MultiRangeCompat doesn't have a packed layout in memory, fall back to UnmarshalBytes. + buf := cc.CopyScratchBuffer(n.SizeBytes()) // escapes: okay. + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Unmarshal unconditionally. If we had a short copy-in, this results in a + // partially unmarshalled struct. + n.UnmarshalBytes(buf) // escapes: fallback. + return length, err + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(n))) + hdr.Len = n.SizeBytes() + hdr.Cap = n.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that n + // must live until the use above. + runtime.KeepAlive(n) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (n *NfNATIPV4MultiRangeCompat) WriteTo(writer io.Writer) (int64, error) { + if !n.RangeIPV4.Packed() { + // Type NfNATIPV4MultiRangeCompat doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := make([]byte, n.SizeBytes()) + n.MarshalBytes(buf) + length, err := writer.Write(buf) + return int64(length), err + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(n))) + hdr.Len = n.SizeBytes() + hdr.Cap = n.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that n + // must live until the use above. + runtime.KeepAlive(n) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (n *NfNATIPV4Range) SizeBytes() int { + return 8 + + 1*4 + + 1*4 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (n *NfNATIPV4Range) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint32(dst[:4], uint32(n.Flags)) + dst = dst[4:] + for idx := 0; idx < 4; idx++ { + dst[0] = byte(n.MinIP[idx]) + dst = dst[1:] + } + for idx := 0; idx < 4; idx++ { + dst[0] = byte(n.MaxIP[idx]) + dst = dst[1:] + } + hostarch.ByteOrder.PutUint16(dst[:2], uint16(n.MinPort)) + dst = dst[2:] + hostarch.ByteOrder.PutUint16(dst[:2], uint16(n.MaxPort)) + dst = dst[2:] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (n *NfNATIPV4Range) UnmarshalBytes(src []byte) { + n.Flags = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + for idx := 0; idx < 4; idx++ { + n.MinIP[idx] = src[0] + src = src[1:] + } + for idx := 0; idx < 4; idx++ { + n.MaxIP[idx] = src[0] + src = src[1:] + } + n.MinPort = uint16(hostarch.ByteOrder.Uint16(src[:2])) + src = src[2:] + n.MaxPort = uint16(hostarch.ByteOrder.Uint16(src[:2])) + src = src[2:] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (n *NfNATIPV4Range) Packed() bool { + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (n *NfNATIPV4Range) MarshalUnsafe(dst []byte) { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(n), uintptr(n.SizeBytes())) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (n *NfNATIPV4Range) UnmarshalUnsafe(src []byte) { + gohacks.Memmove(unsafe.Pointer(n), unsafe.Pointer(&src[0]), uintptr(n.SizeBytes())) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (n *NfNATIPV4Range) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(n))) + hdr.Len = n.SizeBytes() + hdr.Cap = n.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that n + // must live until the use above. + runtime.KeepAlive(n) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (n *NfNATIPV4Range) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return n.CopyOutN(cc, addr, n.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (n *NfNATIPV4Range) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(n))) + hdr.Len = n.SizeBytes() + hdr.Cap = n.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that n + // must live until the use above. + runtime.KeepAlive(n) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (n *NfNATIPV4Range) WriteTo(writer io.Writer) (int64, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(n))) + hdr.Len = n.SizeBytes() + hdr.Cap = n.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that n + // must live until the use above. + runtime.KeepAlive(n) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +//go:nosplit +func (tn *TableName) SizeBytes() int { + return 1 * XT_TABLE_MAXNAMELEN +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (tn *TableName) MarshalBytes(dst []byte) { + for idx := 0; idx < XT_TABLE_MAXNAMELEN; idx++ { + dst[0] = byte(tn[idx]) + dst = dst[1:] + } +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (tn *TableName) UnmarshalBytes(src []byte) { + for idx := 0; idx < XT_TABLE_MAXNAMELEN; idx++ { + tn[idx] = src[0] + src = src[1:] + } +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (tn *TableName) Packed() bool { + // Array newtypes are always packed. + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (tn *TableName) MarshalUnsafe(dst []byte) { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(&tn[0]), uintptr(tn.SizeBytes())) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (tn *TableName) UnmarshalUnsafe(src []byte) { + gohacks.Memmove(unsafe.Pointer(tn), unsafe.Pointer(&src[0]), uintptr(tn.SizeBytes())) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (tn *TableName) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(tn))) + hdr.Len = tn.SizeBytes() + hdr.Cap = tn.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that tn + // must live until the use above. + runtime.KeepAlive(tn) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (tn *TableName) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return tn.CopyOutN(cc, addr, tn.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (tn *TableName) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(tn))) + hdr.Len = tn.SizeBytes() + hdr.Cap = tn.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that tn + // must live until the use above. + runtime.KeepAlive(tn) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (tn *TableName) WriteTo(w io.Writer) (int64, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(tn))) + hdr.Len = tn.SizeBytes() + hdr.Cap = tn.SizeBytes() + + length, err := w.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that tn + // must live until the use above. + runtime.KeepAlive(tn) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (x *XTCounters) SizeBytes() int { + return 16 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (x *XTCounters) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint64(dst[:8], uint64(x.Pcnt)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(x.Bcnt)) + dst = dst[8:] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (x *XTCounters) UnmarshalBytes(src []byte) { + x.Pcnt = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + x.Bcnt = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (x *XTCounters) Packed() bool { + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (x *XTCounters) MarshalUnsafe(dst []byte) { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(x), uintptr(x.SizeBytes())) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (x *XTCounters) UnmarshalUnsafe(src []byte) { + gohacks.Memmove(unsafe.Pointer(x), unsafe.Pointer(&src[0]), uintptr(x.SizeBytes())) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (x *XTCounters) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(x))) + hdr.Len = x.SizeBytes() + hdr.Cap = x.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that x + // must live until the use above. + runtime.KeepAlive(x) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (x *XTCounters) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return x.CopyOutN(cc, addr, x.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (x *XTCounters) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(x))) + hdr.Len = x.SizeBytes() + hdr.Cap = x.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that x + // must live until the use above. + runtime.KeepAlive(x) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (x *XTCounters) WriteTo(writer io.Writer) (int64, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(x))) + hdr.Len = x.SizeBytes() + hdr.Cap = x.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that x + // must live until the use above. + runtime.KeepAlive(x) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (x *XTEntryMatch) SizeBytes() int { + return 3 + + (*ExtensionName)(nil).SizeBytes() +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (x *XTEntryMatch) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint16(dst[:2], uint16(x.MatchSize)) + dst = dst[2:] + x.Name.MarshalBytes(dst[:x.Name.SizeBytes()]) + dst = dst[x.Name.SizeBytes():] + dst[0] = byte(x.Revision) + dst = dst[1:] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (x *XTEntryMatch) UnmarshalBytes(src []byte) { + x.MatchSize = uint16(hostarch.ByteOrder.Uint16(src[:2])) + src = src[2:] + x.Name.UnmarshalBytes(src[:x.Name.SizeBytes()]) + src = src[x.Name.SizeBytes():] + x.Revision = uint8(src[0]) + src = src[1:] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (x *XTEntryMatch) Packed() bool { + return x.Name.Packed() +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (x *XTEntryMatch) MarshalUnsafe(dst []byte) { + if x.Name.Packed() { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(x), uintptr(x.SizeBytes())) + } else { + // Type XTEntryMatch doesn't have a packed layout in memory, fallback to MarshalBytes. + x.MarshalBytes(dst) + } +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (x *XTEntryMatch) UnmarshalUnsafe(src []byte) { + if x.Name.Packed() { + gohacks.Memmove(unsafe.Pointer(x), unsafe.Pointer(&src[0]), uintptr(x.SizeBytes())) + } else { + // Type XTEntryMatch doesn't have a packed layout in memory, fallback to UnmarshalBytes. + x.UnmarshalBytes(src) + } +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (x *XTEntryMatch) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + if !x.Name.Packed() { + // Type XTEntryMatch doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := cc.CopyScratchBuffer(x.SizeBytes()) // escapes: okay. + x.MarshalBytes(buf) // escapes: fallback. + return cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(x))) + hdr.Len = x.SizeBytes() + hdr.Cap = x.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that x + // must live until the use above. + runtime.KeepAlive(x) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (x *XTEntryMatch) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return x.CopyOutN(cc, addr, x.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (x *XTEntryMatch) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + if !x.Name.Packed() { + // Type XTEntryMatch doesn't have a packed layout in memory, fall back to UnmarshalBytes. + buf := cc.CopyScratchBuffer(x.SizeBytes()) // escapes: okay. + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Unmarshal unconditionally. If we had a short copy-in, this results in a + // partially unmarshalled struct. + x.UnmarshalBytes(buf) // escapes: fallback. + return length, err + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(x))) + hdr.Len = x.SizeBytes() + hdr.Cap = x.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that x + // must live until the use above. + runtime.KeepAlive(x) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (x *XTEntryMatch) WriteTo(writer io.Writer) (int64, error) { + if !x.Name.Packed() { + // Type XTEntryMatch doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := make([]byte, x.SizeBytes()) + x.MarshalBytes(buf) + length, err := writer.Write(buf) + return int64(length), err + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(x))) + hdr.Len = x.SizeBytes() + hdr.Cap = x.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that x + // must live until the use above. + runtime.KeepAlive(x) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (x *XTEntryTarget) SizeBytes() int { + return 3 + + (*ExtensionName)(nil).SizeBytes() +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (x *XTEntryTarget) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint16(dst[:2], uint16(x.TargetSize)) + dst = dst[2:] + x.Name.MarshalBytes(dst[:x.Name.SizeBytes()]) + dst = dst[x.Name.SizeBytes():] + dst[0] = byte(x.Revision) + dst = dst[1:] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (x *XTEntryTarget) UnmarshalBytes(src []byte) { + x.TargetSize = uint16(hostarch.ByteOrder.Uint16(src[:2])) + src = src[2:] + x.Name.UnmarshalBytes(src[:x.Name.SizeBytes()]) + src = src[x.Name.SizeBytes():] + x.Revision = uint8(src[0]) + src = src[1:] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (x *XTEntryTarget) Packed() bool { + return x.Name.Packed() +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (x *XTEntryTarget) MarshalUnsafe(dst []byte) { + if x.Name.Packed() { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(x), uintptr(x.SizeBytes())) + } else { + // Type XTEntryTarget doesn't have a packed layout in memory, fallback to MarshalBytes. + x.MarshalBytes(dst) + } +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (x *XTEntryTarget) UnmarshalUnsafe(src []byte) { + if x.Name.Packed() { + gohacks.Memmove(unsafe.Pointer(x), unsafe.Pointer(&src[0]), uintptr(x.SizeBytes())) + } else { + // Type XTEntryTarget doesn't have a packed layout in memory, fallback to UnmarshalBytes. + x.UnmarshalBytes(src) + } +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (x *XTEntryTarget) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + if !x.Name.Packed() { + // Type XTEntryTarget doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := cc.CopyScratchBuffer(x.SizeBytes()) // escapes: okay. + x.MarshalBytes(buf) // escapes: fallback. + return cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(x))) + hdr.Len = x.SizeBytes() + hdr.Cap = x.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that x + // must live until the use above. + runtime.KeepAlive(x) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (x *XTEntryTarget) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return x.CopyOutN(cc, addr, x.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (x *XTEntryTarget) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + if !x.Name.Packed() { + // Type XTEntryTarget doesn't have a packed layout in memory, fall back to UnmarshalBytes. + buf := cc.CopyScratchBuffer(x.SizeBytes()) // escapes: okay. + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Unmarshal unconditionally. If we had a short copy-in, this results in a + // partially unmarshalled struct. + x.UnmarshalBytes(buf) // escapes: fallback. + return length, err + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(x))) + hdr.Len = x.SizeBytes() + hdr.Cap = x.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that x + // must live until the use above. + runtime.KeepAlive(x) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (x *XTEntryTarget) WriteTo(writer io.Writer) (int64, error) { + if !x.Name.Packed() { + // Type XTEntryTarget doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := make([]byte, x.SizeBytes()) + x.MarshalBytes(buf) + length, err := writer.Write(buf) + return int64(length), err + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(x))) + hdr.Len = x.SizeBytes() + hdr.Cap = x.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that x + // must live until the use above. + runtime.KeepAlive(x) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (x *XTErrorTarget) SizeBytes() int { + return 0 + + (*XTEntryTarget)(nil).SizeBytes() + + (*ErrorName)(nil).SizeBytes() + + 1*2 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (x *XTErrorTarget) MarshalBytes(dst []byte) { + x.Target.MarshalBytes(dst[:x.Target.SizeBytes()]) + dst = dst[x.Target.SizeBytes():] + x.Name.MarshalBytes(dst[:x.Name.SizeBytes()]) + dst = dst[x.Name.SizeBytes():] + // Padding: dst[:sizeof(byte)*2] ~= [2]byte{0} + dst = dst[1*(2):] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (x *XTErrorTarget) UnmarshalBytes(src []byte) { + x.Target.UnmarshalBytes(src[:x.Target.SizeBytes()]) + src = src[x.Target.SizeBytes():] + x.Name.UnmarshalBytes(src[:x.Name.SizeBytes()]) + src = src[x.Name.SizeBytes():] + // Padding: ~ copy([2]byte(x._), src[:sizeof(byte)*2]) + src = src[1*(2):] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (x *XTErrorTarget) Packed() bool { + return x.Name.Packed() && x.Target.Packed() +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (x *XTErrorTarget) MarshalUnsafe(dst []byte) { + if x.Name.Packed() && x.Target.Packed() { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(x), uintptr(x.SizeBytes())) + } else { + // Type XTErrorTarget doesn't have a packed layout in memory, fallback to MarshalBytes. + x.MarshalBytes(dst) + } +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (x *XTErrorTarget) UnmarshalUnsafe(src []byte) { + if x.Name.Packed() && x.Target.Packed() { + gohacks.Memmove(unsafe.Pointer(x), unsafe.Pointer(&src[0]), uintptr(x.SizeBytes())) + } else { + // Type XTErrorTarget doesn't have a packed layout in memory, fallback to UnmarshalBytes. + x.UnmarshalBytes(src) + } +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (x *XTErrorTarget) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + if !x.Name.Packed() && x.Target.Packed() { + // Type XTErrorTarget doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := cc.CopyScratchBuffer(x.SizeBytes()) // escapes: okay. + x.MarshalBytes(buf) // escapes: fallback. + return cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(x))) + hdr.Len = x.SizeBytes() + hdr.Cap = x.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that x + // must live until the use above. + runtime.KeepAlive(x) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (x *XTErrorTarget) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return x.CopyOutN(cc, addr, x.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (x *XTErrorTarget) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + if !x.Name.Packed() && x.Target.Packed() { + // Type XTErrorTarget doesn't have a packed layout in memory, fall back to UnmarshalBytes. + buf := cc.CopyScratchBuffer(x.SizeBytes()) // escapes: okay. + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Unmarshal unconditionally. If we had a short copy-in, this results in a + // partially unmarshalled struct. + x.UnmarshalBytes(buf) // escapes: fallback. + return length, err + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(x))) + hdr.Len = x.SizeBytes() + hdr.Cap = x.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that x + // must live until the use above. + runtime.KeepAlive(x) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (x *XTErrorTarget) WriteTo(writer io.Writer) (int64, error) { + if !x.Name.Packed() && x.Target.Packed() { + // Type XTErrorTarget doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := make([]byte, x.SizeBytes()) + x.MarshalBytes(buf) + length, err := writer.Write(buf) + return int64(length), err + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(x))) + hdr.Len = x.SizeBytes() + hdr.Cap = x.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that x + // must live until the use above. + runtime.KeepAlive(x) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (x *XTGetRevision) SizeBytes() int { + return 1 + + (*ExtensionName)(nil).SizeBytes() +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (x *XTGetRevision) MarshalBytes(dst []byte) { + x.Name.MarshalBytes(dst[:x.Name.SizeBytes()]) + dst = dst[x.Name.SizeBytes():] + dst[0] = byte(x.Revision) + dst = dst[1:] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (x *XTGetRevision) UnmarshalBytes(src []byte) { + x.Name.UnmarshalBytes(src[:x.Name.SizeBytes()]) + src = src[x.Name.SizeBytes():] + x.Revision = uint8(src[0]) + src = src[1:] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (x *XTGetRevision) Packed() bool { + return x.Name.Packed() +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (x *XTGetRevision) MarshalUnsafe(dst []byte) { + if x.Name.Packed() { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(x), uintptr(x.SizeBytes())) + } else { + // Type XTGetRevision doesn't have a packed layout in memory, fallback to MarshalBytes. + x.MarshalBytes(dst) + } +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (x *XTGetRevision) UnmarshalUnsafe(src []byte) { + if x.Name.Packed() { + gohacks.Memmove(unsafe.Pointer(x), unsafe.Pointer(&src[0]), uintptr(x.SizeBytes())) + } else { + // Type XTGetRevision doesn't have a packed layout in memory, fallback to UnmarshalBytes. + x.UnmarshalBytes(src) + } +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (x *XTGetRevision) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + if !x.Name.Packed() { + // Type XTGetRevision doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := cc.CopyScratchBuffer(x.SizeBytes()) // escapes: okay. + x.MarshalBytes(buf) // escapes: fallback. + return cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(x))) + hdr.Len = x.SizeBytes() + hdr.Cap = x.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that x + // must live until the use above. + runtime.KeepAlive(x) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (x *XTGetRevision) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return x.CopyOutN(cc, addr, x.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (x *XTGetRevision) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + if !x.Name.Packed() { + // Type XTGetRevision doesn't have a packed layout in memory, fall back to UnmarshalBytes. + buf := cc.CopyScratchBuffer(x.SizeBytes()) // escapes: okay. + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Unmarshal unconditionally. If we had a short copy-in, this results in a + // partially unmarshalled struct. + x.UnmarshalBytes(buf) // escapes: fallback. + return length, err + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(x))) + hdr.Len = x.SizeBytes() + hdr.Cap = x.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that x + // must live until the use above. + runtime.KeepAlive(x) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (x *XTGetRevision) WriteTo(writer io.Writer) (int64, error) { + if !x.Name.Packed() { + // Type XTGetRevision doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := make([]byte, x.SizeBytes()) + x.MarshalBytes(buf) + length, err := writer.Write(buf) + return int64(length), err + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(x))) + hdr.Len = x.SizeBytes() + hdr.Cap = x.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that x + // must live until the use above. + runtime.KeepAlive(x) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (x *XTRedirectTarget) SizeBytes() int { + return 0 + + (*XTEntryTarget)(nil).SizeBytes() + + (*NfNATIPV4MultiRangeCompat)(nil).SizeBytes() + + 1*4 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (x *XTRedirectTarget) MarshalBytes(dst []byte) { + x.Target.MarshalBytes(dst[:x.Target.SizeBytes()]) + dst = dst[x.Target.SizeBytes():] + x.NfRange.MarshalBytes(dst[:x.NfRange.SizeBytes()]) + dst = dst[x.NfRange.SizeBytes():] + // Padding: dst[:sizeof(byte)*4] ~= [4]byte{0} + dst = dst[1*(4):] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (x *XTRedirectTarget) UnmarshalBytes(src []byte) { + x.Target.UnmarshalBytes(src[:x.Target.SizeBytes()]) + src = src[x.Target.SizeBytes():] + x.NfRange.UnmarshalBytes(src[:x.NfRange.SizeBytes()]) + src = src[x.NfRange.SizeBytes():] + // Padding: ~ copy([4]byte(x._), src[:sizeof(byte)*4]) + src = src[1*(4):] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (x *XTRedirectTarget) Packed() bool { + return x.NfRange.Packed() && x.Target.Packed() +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (x *XTRedirectTarget) MarshalUnsafe(dst []byte) { + if x.NfRange.Packed() && x.Target.Packed() { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(x), uintptr(x.SizeBytes())) + } else { + // Type XTRedirectTarget doesn't have a packed layout in memory, fallback to MarshalBytes. + x.MarshalBytes(dst) + } +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (x *XTRedirectTarget) UnmarshalUnsafe(src []byte) { + if x.NfRange.Packed() && x.Target.Packed() { + gohacks.Memmove(unsafe.Pointer(x), unsafe.Pointer(&src[0]), uintptr(x.SizeBytes())) + } else { + // Type XTRedirectTarget doesn't have a packed layout in memory, fallback to UnmarshalBytes. + x.UnmarshalBytes(src) + } +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (x *XTRedirectTarget) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + if !x.NfRange.Packed() && x.Target.Packed() { + // Type XTRedirectTarget doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := cc.CopyScratchBuffer(x.SizeBytes()) // escapes: okay. + x.MarshalBytes(buf) // escapes: fallback. + return cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(x))) + hdr.Len = x.SizeBytes() + hdr.Cap = x.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that x + // must live until the use above. + runtime.KeepAlive(x) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (x *XTRedirectTarget) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return x.CopyOutN(cc, addr, x.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (x *XTRedirectTarget) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + if !x.NfRange.Packed() && x.Target.Packed() { + // Type XTRedirectTarget doesn't have a packed layout in memory, fall back to UnmarshalBytes. + buf := cc.CopyScratchBuffer(x.SizeBytes()) // escapes: okay. + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Unmarshal unconditionally. If we had a short copy-in, this results in a + // partially unmarshalled struct. + x.UnmarshalBytes(buf) // escapes: fallback. + return length, err + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(x))) + hdr.Len = x.SizeBytes() + hdr.Cap = x.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that x + // must live until the use above. + runtime.KeepAlive(x) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (x *XTRedirectTarget) WriteTo(writer io.Writer) (int64, error) { + if !x.NfRange.Packed() && x.Target.Packed() { + // Type XTRedirectTarget doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := make([]byte, x.SizeBytes()) + x.MarshalBytes(buf) + length, err := writer.Write(buf) + return int64(length), err + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(x))) + hdr.Len = x.SizeBytes() + hdr.Cap = x.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that x + // must live until the use above. + runtime.KeepAlive(x) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (x *XTSNATTarget) SizeBytes() int { + return 0 + + (*XTEntryTarget)(nil).SizeBytes() + + (*NfNATIPV4MultiRangeCompat)(nil).SizeBytes() + + 1*4 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (x *XTSNATTarget) MarshalBytes(dst []byte) { + x.Target.MarshalBytes(dst[:x.Target.SizeBytes()]) + dst = dst[x.Target.SizeBytes():] + x.NfRange.MarshalBytes(dst[:x.NfRange.SizeBytes()]) + dst = dst[x.NfRange.SizeBytes():] + // Padding: dst[:sizeof(byte)*4] ~= [4]byte{0} + dst = dst[1*(4):] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (x *XTSNATTarget) UnmarshalBytes(src []byte) { + x.Target.UnmarshalBytes(src[:x.Target.SizeBytes()]) + src = src[x.Target.SizeBytes():] + x.NfRange.UnmarshalBytes(src[:x.NfRange.SizeBytes()]) + src = src[x.NfRange.SizeBytes():] + // Padding: ~ copy([4]byte(x._), src[:sizeof(byte)*4]) + src = src[1*(4):] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (x *XTSNATTarget) Packed() bool { + return x.NfRange.Packed() && x.Target.Packed() +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (x *XTSNATTarget) MarshalUnsafe(dst []byte) { + if x.NfRange.Packed() && x.Target.Packed() { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(x), uintptr(x.SizeBytes())) + } else { + // Type XTSNATTarget doesn't have a packed layout in memory, fallback to MarshalBytes. + x.MarshalBytes(dst) + } +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (x *XTSNATTarget) UnmarshalUnsafe(src []byte) { + if x.NfRange.Packed() && x.Target.Packed() { + gohacks.Memmove(unsafe.Pointer(x), unsafe.Pointer(&src[0]), uintptr(x.SizeBytes())) + } else { + // Type XTSNATTarget doesn't have a packed layout in memory, fallback to UnmarshalBytes. + x.UnmarshalBytes(src) + } +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (x *XTSNATTarget) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + if !x.NfRange.Packed() && x.Target.Packed() { + // Type XTSNATTarget doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := cc.CopyScratchBuffer(x.SizeBytes()) // escapes: okay. + x.MarshalBytes(buf) // escapes: fallback. + return cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(x))) + hdr.Len = x.SizeBytes() + hdr.Cap = x.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that x + // must live until the use above. + runtime.KeepAlive(x) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (x *XTSNATTarget) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return x.CopyOutN(cc, addr, x.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (x *XTSNATTarget) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + if !x.NfRange.Packed() && x.Target.Packed() { + // Type XTSNATTarget doesn't have a packed layout in memory, fall back to UnmarshalBytes. + buf := cc.CopyScratchBuffer(x.SizeBytes()) // escapes: okay. + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Unmarshal unconditionally. If we had a short copy-in, this results in a + // partially unmarshalled struct. + x.UnmarshalBytes(buf) // escapes: fallback. + return length, err + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(x))) + hdr.Len = x.SizeBytes() + hdr.Cap = x.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that x + // must live until the use above. + runtime.KeepAlive(x) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (x *XTSNATTarget) WriteTo(writer io.Writer) (int64, error) { + if !x.NfRange.Packed() && x.Target.Packed() { + // Type XTSNATTarget doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := make([]byte, x.SizeBytes()) + x.MarshalBytes(buf) + length, err := writer.Write(buf) + return int64(length), err + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(x))) + hdr.Len = x.SizeBytes() + hdr.Cap = x.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that x + // must live until the use above. + runtime.KeepAlive(x) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (x *XTStandardTarget) SizeBytes() int { + return 4 + + (*XTEntryTarget)(nil).SizeBytes() + + 1*4 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (x *XTStandardTarget) MarshalBytes(dst []byte) { + x.Target.MarshalBytes(dst[:x.Target.SizeBytes()]) + dst = dst[x.Target.SizeBytes():] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(x.Verdict)) + dst = dst[4:] + // Padding: dst[:sizeof(byte)*4] ~= [4]byte{0} + dst = dst[1*(4):] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (x *XTStandardTarget) UnmarshalBytes(src []byte) { + x.Target.UnmarshalBytes(src[:x.Target.SizeBytes()]) + src = src[x.Target.SizeBytes():] + x.Verdict = int32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + // Padding: ~ copy([4]byte(x._), src[:sizeof(byte)*4]) + src = src[1*(4):] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (x *XTStandardTarget) Packed() bool { + return x.Target.Packed() +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (x *XTStandardTarget) MarshalUnsafe(dst []byte) { + if x.Target.Packed() { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(x), uintptr(x.SizeBytes())) + } else { + // Type XTStandardTarget doesn't have a packed layout in memory, fallback to MarshalBytes. + x.MarshalBytes(dst) + } +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (x *XTStandardTarget) UnmarshalUnsafe(src []byte) { + if x.Target.Packed() { + gohacks.Memmove(unsafe.Pointer(x), unsafe.Pointer(&src[0]), uintptr(x.SizeBytes())) + } else { + // Type XTStandardTarget doesn't have a packed layout in memory, fallback to UnmarshalBytes. + x.UnmarshalBytes(src) + } +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (x *XTStandardTarget) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + if !x.Target.Packed() { + // Type XTStandardTarget doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := cc.CopyScratchBuffer(x.SizeBytes()) // escapes: okay. + x.MarshalBytes(buf) // escapes: fallback. + return cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(x))) + hdr.Len = x.SizeBytes() + hdr.Cap = x.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that x + // must live until the use above. + runtime.KeepAlive(x) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (x *XTStandardTarget) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return x.CopyOutN(cc, addr, x.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (x *XTStandardTarget) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + if !x.Target.Packed() { + // Type XTStandardTarget doesn't have a packed layout in memory, fall back to UnmarshalBytes. + buf := cc.CopyScratchBuffer(x.SizeBytes()) // escapes: okay. + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Unmarshal unconditionally. If we had a short copy-in, this results in a + // partially unmarshalled struct. + x.UnmarshalBytes(buf) // escapes: fallback. + return length, err + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(x))) + hdr.Len = x.SizeBytes() + hdr.Cap = x.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that x + // must live until the use above. + runtime.KeepAlive(x) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (x *XTStandardTarget) WriteTo(writer io.Writer) (int64, error) { + if !x.Target.Packed() { + // Type XTStandardTarget doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := make([]byte, x.SizeBytes()) + x.MarshalBytes(buf) + length, err := writer.Write(buf) + return int64(length), err + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(x))) + hdr.Len = x.SizeBytes() + hdr.Cap = x.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that x + // must live until the use above. + runtime.KeepAlive(x) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (x *XTTCP) SizeBytes() int { + return 12 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (x *XTTCP) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint16(dst[:2], uint16(x.SourcePortStart)) + dst = dst[2:] + hostarch.ByteOrder.PutUint16(dst[:2], uint16(x.SourcePortEnd)) + dst = dst[2:] + hostarch.ByteOrder.PutUint16(dst[:2], uint16(x.DestinationPortStart)) + dst = dst[2:] + hostarch.ByteOrder.PutUint16(dst[:2], uint16(x.DestinationPortEnd)) + dst = dst[2:] + dst[0] = byte(x.Option) + dst = dst[1:] + dst[0] = byte(x.FlagMask) + dst = dst[1:] + dst[0] = byte(x.FlagCompare) + dst = dst[1:] + dst[0] = byte(x.InverseFlags) + dst = dst[1:] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (x *XTTCP) UnmarshalBytes(src []byte) { + x.SourcePortStart = uint16(hostarch.ByteOrder.Uint16(src[:2])) + src = src[2:] + x.SourcePortEnd = uint16(hostarch.ByteOrder.Uint16(src[:2])) + src = src[2:] + x.DestinationPortStart = uint16(hostarch.ByteOrder.Uint16(src[:2])) + src = src[2:] + x.DestinationPortEnd = uint16(hostarch.ByteOrder.Uint16(src[:2])) + src = src[2:] + x.Option = uint8(src[0]) + src = src[1:] + x.FlagMask = uint8(src[0]) + src = src[1:] + x.FlagCompare = uint8(src[0]) + src = src[1:] + x.InverseFlags = uint8(src[0]) + src = src[1:] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (x *XTTCP) Packed() bool { + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (x *XTTCP) MarshalUnsafe(dst []byte) { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(x), uintptr(x.SizeBytes())) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (x *XTTCP) UnmarshalUnsafe(src []byte) { + gohacks.Memmove(unsafe.Pointer(x), unsafe.Pointer(&src[0]), uintptr(x.SizeBytes())) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (x *XTTCP) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(x))) + hdr.Len = x.SizeBytes() + hdr.Cap = x.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that x + // must live until the use above. + runtime.KeepAlive(x) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (x *XTTCP) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return x.CopyOutN(cc, addr, x.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (x *XTTCP) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(x))) + hdr.Len = x.SizeBytes() + hdr.Cap = x.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that x + // must live until the use above. + runtime.KeepAlive(x) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (x *XTTCP) WriteTo(writer io.Writer) (int64, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(x))) + hdr.Len = x.SizeBytes() + hdr.Cap = x.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that x + // must live until the use above. + runtime.KeepAlive(x) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (x *XTUDP) SizeBytes() int { + return 10 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (x *XTUDP) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint16(dst[:2], uint16(x.SourcePortStart)) + dst = dst[2:] + hostarch.ByteOrder.PutUint16(dst[:2], uint16(x.SourcePortEnd)) + dst = dst[2:] + hostarch.ByteOrder.PutUint16(dst[:2], uint16(x.DestinationPortStart)) + dst = dst[2:] + hostarch.ByteOrder.PutUint16(dst[:2], uint16(x.DestinationPortEnd)) + dst = dst[2:] + dst[0] = byte(x.InverseFlags) + dst = dst[1:] + // Padding: dst[:sizeof(uint8)] ~= uint8(0) + dst = dst[1:] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (x *XTUDP) UnmarshalBytes(src []byte) { + x.SourcePortStart = uint16(hostarch.ByteOrder.Uint16(src[:2])) + src = src[2:] + x.SourcePortEnd = uint16(hostarch.ByteOrder.Uint16(src[:2])) + src = src[2:] + x.DestinationPortStart = uint16(hostarch.ByteOrder.Uint16(src[:2])) + src = src[2:] + x.DestinationPortEnd = uint16(hostarch.ByteOrder.Uint16(src[:2])) + src = src[2:] + x.InverseFlags = uint8(src[0]) + src = src[1:] + // Padding: var _ uint8 ~= src[:sizeof(uint8)] + src = src[1:] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (x *XTUDP) Packed() bool { + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (x *XTUDP) MarshalUnsafe(dst []byte) { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(x), uintptr(x.SizeBytes())) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (x *XTUDP) UnmarshalUnsafe(src []byte) { + gohacks.Memmove(unsafe.Pointer(x), unsafe.Pointer(&src[0]), uintptr(x.SizeBytes())) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (x *XTUDP) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(x))) + hdr.Len = x.SizeBytes() + hdr.Cap = x.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that x + // must live until the use above. + runtime.KeepAlive(x) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (x *XTUDP) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return x.CopyOutN(cc, addr, x.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (x *XTUDP) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(x))) + hdr.Len = x.SizeBytes() + hdr.Cap = x.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that x + // must live until the use above. + runtime.KeepAlive(x) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (x *XTUDP) WriteTo(writer io.Writer) (int64, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(x))) + hdr.Len = x.SizeBytes() + hdr.Cap = x.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that x + // must live until the use above. + runtime.KeepAlive(x) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (i *IP6TEntry) SizeBytes() int { + return 12 + + (*IP6TIP)(nil).SizeBytes() + + 1*4 + + (*XTCounters)(nil).SizeBytes() +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (i *IP6TEntry) MarshalBytes(dst []byte) { + i.IPv6.MarshalBytes(dst[:i.IPv6.SizeBytes()]) + dst = dst[i.IPv6.SizeBytes():] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(i.NFCache)) + dst = dst[4:] + hostarch.ByteOrder.PutUint16(dst[:2], uint16(i.TargetOffset)) + dst = dst[2:] + hostarch.ByteOrder.PutUint16(dst[:2], uint16(i.NextOffset)) + dst = dst[2:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(i.Comeback)) + dst = dst[4:] + // Padding: dst[:sizeof(byte)*4] ~= [4]byte{0} + dst = dst[1*(4):] + i.Counters.MarshalBytes(dst[:i.Counters.SizeBytes()]) + dst = dst[i.Counters.SizeBytes():] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (i *IP6TEntry) UnmarshalBytes(src []byte) { + i.IPv6.UnmarshalBytes(src[:i.IPv6.SizeBytes()]) + src = src[i.IPv6.SizeBytes():] + i.NFCache = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + i.TargetOffset = uint16(hostarch.ByteOrder.Uint16(src[:2])) + src = src[2:] + i.NextOffset = uint16(hostarch.ByteOrder.Uint16(src[:2])) + src = src[2:] + i.Comeback = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + // Padding: ~ copy([4]byte(i._), src[:sizeof(byte)*4]) + src = src[1*(4):] + i.Counters.UnmarshalBytes(src[:i.Counters.SizeBytes()]) + src = src[i.Counters.SizeBytes():] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (i *IP6TEntry) Packed() bool { + return i.Counters.Packed() && i.IPv6.Packed() +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (i *IP6TEntry) MarshalUnsafe(dst []byte) { + if i.Counters.Packed() && i.IPv6.Packed() { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(i), uintptr(i.SizeBytes())) + } else { + // Type IP6TEntry doesn't have a packed layout in memory, fallback to MarshalBytes. + i.MarshalBytes(dst) + } +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (i *IP6TEntry) UnmarshalUnsafe(src []byte) { + if i.Counters.Packed() && i.IPv6.Packed() { + gohacks.Memmove(unsafe.Pointer(i), unsafe.Pointer(&src[0]), uintptr(i.SizeBytes())) + } else { + // Type IP6TEntry doesn't have a packed layout in memory, fallback to UnmarshalBytes. + i.UnmarshalBytes(src) + } +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (i *IP6TEntry) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + if !i.Counters.Packed() && i.IPv6.Packed() { + // Type IP6TEntry doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := cc.CopyScratchBuffer(i.SizeBytes()) // escapes: okay. + i.MarshalBytes(buf) // escapes: fallback. + return cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(i))) + hdr.Len = i.SizeBytes() + hdr.Cap = i.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that i + // must live until the use above. + runtime.KeepAlive(i) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (i *IP6TEntry) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return i.CopyOutN(cc, addr, i.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (i *IP6TEntry) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + if !i.Counters.Packed() && i.IPv6.Packed() { + // Type IP6TEntry doesn't have a packed layout in memory, fall back to UnmarshalBytes. + buf := cc.CopyScratchBuffer(i.SizeBytes()) // escapes: okay. + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Unmarshal unconditionally. If we had a short copy-in, this results in a + // partially unmarshalled struct. + i.UnmarshalBytes(buf) // escapes: fallback. + return length, err + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(i))) + hdr.Len = i.SizeBytes() + hdr.Cap = i.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that i + // must live until the use above. + runtime.KeepAlive(i) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (i *IP6TEntry) WriteTo(writer io.Writer) (int64, error) { + if !i.Counters.Packed() && i.IPv6.Packed() { + // Type IP6TEntry doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := make([]byte, i.SizeBytes()) + i.MarshalBytes(buf) + length, err := writer.Write(buf) + return int64(length), err + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(i))) + hdr.Len = i.SizeBytes() + hdr.Cap = i.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that i + // must live until the use above. + runtime.KeepAlive(i) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (i *IP6TIP) SizeBytes() int { + return 5 + + (*Inet6Addr)(nil).SizeBytes() + + (*Inet6Addr)(nil).SizeBytes() + + (*Inet6Addr)(nil).SizeBytes() + + (*Inet6Addr)(nil).SizeBytes() + + 1*IFNAMSIZ + + 1*IFNAMSIZ + + 1*IFNAMSIZ + + 1*IFNAMSIZ + + 1*3 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (i *IP6TIP) MarshalBytes(dst []byte) { + i.Src.MarshalBytes(dst[:i.Src.SizeBytes()]) + dst = dst[i.Src.SizeBytes():] + i.Dst.MarshalBytes(dst[:i.Dst.SizeBytes()]) + dst = dst[i.Dst.SizeBytes():] + i.SrcMask.MarshalBytes(dst[:i.SrcMask.SizeBytes()]) + dst = dst[i.SrcMask.SizeBytes():] + i.DstMask.MarshalBytes(dst[:i.DstMask.SizeBytes()]) + dst = dst[i.DstMask.SizeBytes():] + for idx := 0; idx < IFNAMSIZ; idx++ { + dst[0] = byte(i.InputInterface[idx]) + dst = dst[1:] + } + for idx := 0; idx < IFNAMSIZ; idx++ { + dst[0] = byte(i.OutputInterface[idx]) + dst = dst[1:] + } + for idx := 0; idx < IFNAMSIZ; idx++ { + dst[0] = byte(i.InputInterfaceMask[idx]) + dst = dst[1:] + } + for idx := 0; idx < IFNAMSIZ; idx++ { + dst[0] = byte(i.OutputInterfaceMask[idx]) + dst = dst[1:] + } + hostarch.ByteOrder.PutUint16(dst[:2], uint16(i.Protocol)) + dst = dst[2:] + dst[0] = byte(i.TOS) + dst = dst[1:] + dst[0] = byte(i.Flags) + dst = dst[1:] + dst[0] = byte(i.InverseFlags) + dst = dst[1:] + // Padding: dst[:sizeof(byte)*3] ~= [3]byte{0} + dst = dst[1*(3):] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (i *IP6TIP) UnmarshalBytes(src []byte) { + i.Src.UnmarshalBytes(src[:i.Src.SizeBytes()]) + src = src[i.Src.SizeBytes():] + i.Dst.UnmarshalBytes(src[:i.Dst.SizeBytes()]) + src = src[i.Dst.SizeBytes():] + i.SrcMask.UnmarshalBytes(src[:i.SrcMask.SizeBytes()]) + src = src[i.SrcMask.SizeBytes():] + i.DstMask.UnmarshalBytes(src[:i.DstMask.SizeBytes()]) + src = src[i.DstMask.SizeBytes():] + for idx := 0; idx < IFNAMSIZ; idx++ { + i.InputInterface[idx] = src[0] + src = src[1:] + } + for idx := 0; idx < IFNAMSIZ; idx++ { + i.OutputInterface[idx] = src[0] + src = src[1:] + } + for idx := 0; idx < IFNAMSIZ; idx++ { + i.InputInterfaceMask[idx] = src[0] + src = src[1:] + } + for idx := 0; idx < IFNAMSIZ; idx++ { + i.OutputInterfaceMask[idx] = src[0] + src = src[1:] + } + i.Protocol = uint16(hostarch.ByteOrder.Uint16(src[:2])) + src = src[2:] + i.TOS = uint8(src[0]) + src = src[1:] + i.Flags = uint8(src[0]) + src = src[1:] + i.InverseFlags = uint8(src[0]) + src = src[1:] + // Padding: ~ copy([3]byte(i._), src[:sizeof(byte)*3]) + src = src[1*(3):] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (i *IP6TIP) Packed() bool { + return i.Dst.Packed() && i.DstMask.Packed() && i.Src.Packed() && i.SrcMask.Packed() +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (i *IP6TIP) MarshalUnsafe(dst []byte) { + if i.Dst.Packed() && i.DstMask.Packed() && i.Src.Packed() && i.SrcMask.Packed() { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(i), uintptr(i.SizeBytes())) + } else { + // Type IP6TIP doesn't have a packed layout in memory, fallback to MarshalBytes. + i.MarshalBytes(dst) + } +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (i *IP6TIP) UnmarshalUnsafe(src []byte) { + if i.Dst.Packed() && i.DstMask.Packed() && i.Src.Packed() && i.SrcMask.Packed() { + gohacks.Memmove(unsafe.Pointer(i), unsafe.Pointer(&src[0]), uintptr(i.SizeBytes())) + } else { + // Type IP6TIP doesn't have a packed layout in memory, fallback to UnmarshalBytes. + i.UnmarshalBytes(src) + } +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (i *IP6TIP) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + if !i.Dst.Packed() && i.DstMask.Packed() && i.Src.Packed() && i.SrcMask.Packed() { + // Type IP6TIP doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := cc.CopyScratchBuffer(i.SizeBytes()) // escapes: okay. + i.MarshalBytes(buf) // escapes: fallback. + return cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(i))) + hdr.Len = i.SizeBytes() + hdr.Cap = i.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that i + // must live until the use above. + runtime.KeepAlive(i) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (i *IP6TIP) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return i.CopyOutN(cc, addr, i.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (i *IP6TIP) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + if !i.Dst.Packed() && i.DstMask.Packed() && i.Src.Packed() && i.SrcMask.Packed() { + // Type IP6TIP doesn't have a packed layout in memory, fall back to UnmarshalBytes. + buf := cc.CopyScratchBuffer(i.SizeBytes()) // escapes: okay. + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Unmarshal unconditionally. If we had a short copy-in, this results in a + // partially unmarshalled struct. + i.UnmarshalBytes(buf) // escapes: fallback. + return length, err + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(i))) + hdr.Len = i.SizeBytes() + hdr.Cap = i.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that i + // must live until the use above. + runtime.KeepAlive(i) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (i *IP6TIP) WriteTo(writer io.Writer) (int64, error) { + if !i.Dst.Packed() && i.DstMask.Packed() && i.Src.Packed() && i.SrcMask.Packed() { + // Type IP6TIP doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := make([]byte, i.SizeBytes()) + i.MarshalBytes(buf) + length, err := writer.Write(buf) + return int64(length), err + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(i))) + hdr.Len = i.SizeBytes() + hdr.Cap = i.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that i + // must live until the use above. + runtime.KeepAlive(i) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (i *IP6TReplace) SizeBytes() int { + return 24 + + (*TableName)(nil).SizeBytes() + + 4*NF_INET_NUMHOOKS + + 4*NF_INET_NUMHOOKS +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (i *IP6TReplace) MarshalBytes(dst []byte) { + i.Name.MarshalBytes(dst[:i.Name.SizeBytes()]) + dst = dst[i.Name.SizeBytes():] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(i.ValidHooks)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(i.NumEntries)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(i.Size)) + dst = dst[4:] + for idx := 0; idx < NF_INET_NUMHOOKS; idx++ { + hostarch.ByteOrder.PutUint32(dst[:4], uint32(i.HookEntry[idx])) + dst = dst[4:] + } + for idx := 0; idx < NF_INET_NUMHOOKS; idx++ { + hostarch.ByteOrder.PutUint32(dst[:4], uint32(i.Underflow[idx])) + dst = dst[4:] + } + hostarch.ByteOrder.PutUint32(dst[:4], uint32(i.NumCounters)) + dst = dst[4:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(i.Counters)) + dst = dst[8:] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (i *IP6TReplace) UnmarshalBytes(src []byte) { + i.Name.UnmarshalBytes(src[:i.Name.SizeBytes()]) + src = src[i.Name.SizeBytes():] + i.ValidHooks = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + i.NumEntries = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + i.Size = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + for idx := 0; idx < NF_INET_NUMHOOKS; idx++ { + i.HookEntry[idx] = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + } + for idx := 0; idx < NF_INET_NUMHOOKS; idx++ { + i.Underflow[idx] = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + } + i.NumCounters = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + i.Counters = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (i *IP6TReplace) Packed() bool { + return i.Name.Packed() +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (i *IP6TReplace) MarshalUnsafe(dst []byte) { + if i.Name.Packed() { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(i), uintptr(i.SizeBytes())) + } else { + // Type IP6TReplace doesn't have a packed layout in memory, fallback to MarshalBytes. + i.MarshalBytes(dst) + } +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (i *IP6TReplace) UnmarshalUnsafe(src []byte) { + if i.Name.Packed() { + gohacks.Memmove(unsafe.Pointer(i), unsafe.Pointer(&src[0]), uintptr(i.SizeBytes())) + } else { + // Type IP6TReplace doesn't have a packed layout in memory, fallback to UnmarshalBytes. + i.UnmarshalBytes(src) + } +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (i *IP6TReplace) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + if !i.Name.Packed() { + // Type IP6TReplace doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := cc.CopyScratchBuffer(i.SizeBytes()) // escapes: okay. + i.MarshalBytes(buf) // escapes: fallback. + return cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(i))) + hdr.Len = i.SizeBytes() + hdr.Cap = i.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that i + // must live until the use above. + runtime.KeepAlive(i) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (i *IP6TReplace) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return i.CopyOutN(cc, addr, i.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (i *IP6TReplace) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + if !i.Name.Packed() { + // Type IP6TReplace doesn't have a packed layout in memory, fall back to UnmarshalBytes. + buf := cc.CopyScratchBuffer(i.SizeBytes()) // escapes: okay. + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Unmarshal unconditionally. If we had a short copy-in, this results in a + // partially unmarshalled struct. + i.UnmarshalBytes(buf) // escapes: fallback. + return length, err + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(i))) + hdr.Len = i.SizeBytes() + hdr.Cap = i.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that i + // must live until the use above. + runtime.KeepAlive(i) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (i *IP6TReplace) WriteTo(writer io.Writer) (int64, error) { + if !i.Name.Packed() { + // Type IP6TReplace doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := make([]byte, i.SizeBytes()) + i.MarshalBytes(buf) + length, err := writer.Write(buf) + return int64(length), err + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(i))) + hdr.Len = i.SizeBytes() + hdr.Cap = i.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that i + // must live until the use above. + runtime.KeepAlive(i) // escapes: replaced by intrinsic. + return int64(length), err +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (ke *KernelIP6TEntry) Packed() bool { + // Type KernelIP6TEntry is dynamic so it might have slice/string headers. Hence, it is not packed. + return false +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (ke *KernelIP6TEntry) MarshalUnsafe(dst []byte) { + // Type KernelIP6TEntry doesn't have a packed layout in memory, fallback to MarshalBytes. + ke.MarshalBytes(dst) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (ke *KernelIP6TEntry) UnmarshalUnsafe(src []byte) { + // Type KernelIP6TEntry doesn't have a packed layout in memory, fallback to UnmarshalBytes. + ke.UnmarshalBytes(src) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (ke *KernelIP6TEntry) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Type KernelIP6TEntry doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := cc.CopyScratchBuffer(ke.SizeBytes()) // escapes: okay. + ke.MarshalBytes(buf) // escapes: fallback. + return cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (ke *KernelIP6TEntry) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return ke.CopyOutN(cc, addr, ke.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (ke *KernelIP6TEntry) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Type KernelIP6TEntry doesn't have a packed layout in memory, fall back to UnmarshalBytes. + buf := cc.CopyScratchBuffer(ke.SizeBytes()) // escapes: okay. + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Unmarshal unconditionally. If we had a short copy-in, this results in a + // partially unmarshalled struct. + ke.UnmarshalBytes(buf) // escapes: fallback. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (ke *KernelIP6TEntry) WriteTo(writer io.Writer) (int64, error) { + // Type KernelIP6TEntry doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := make([]byte, ke.SizeBytes()) + ke.MarshalBytes(buf) + length, err := writer.Write(buf) + return int64(length), err +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (ke *KernelIP6TGetEntries) Packed() bool { + // Type KernelIP6TGetEntries is dynamic so it might have slice/string headers. Hence, it is not packed. + return false +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (ke *KernelIP6TGetEntries) MarshalUnsafe(dst []byte) { + // Type KernelIP6TGetEntries doesn't have a packed layout in memory, fallback to MarshalBytes. + ke.MarshalBytes(dst) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (ke *KernelIP6TGetEntries) UnmarshalUnsafe(src []byte) { + // Type KernelIP6TGetEntries doesn't have a packed layout in memory, fallback to UnmarshalBytes. + ke.UnmarshalBytes(src) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (ke *KernelIP6TGetEntries) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Type KernelIP6TGetEntries doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := cc.CopyScratchBuffer(ke.SizeBytes()) // escapes: okay. + ke.MarshalBytes(buf) // escapes: fallback. + return cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (ke *KernelIP6TGetEntries) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return ke.CopyOutN(cc, addr, ke.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (ke *KernelIP6TGetEntries) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Type KernelIP6TGetEntries doesn't have a packed layout in memory, fall back to UnmarshalBytes. + buf := cc.CopyScratchBuffer(ke.SizeBytes()) // escapes: okay. + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Unmarshal unconditionally. If we had a short copy-in, this results in a + // partially unmarshalled struct. + ke.UnmarshalBytes(buf) // escapes: fallback. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (ke *KernelIP6TGetEntries) WriteTo(writer io.Writer) (int64, error) { + // Type KernelIP6TGetEntries doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := make([]byte, ke.SizeBytes()) + ke.MarshalBytes(buf) + length, err := writer.Write(buf) + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (n *NFNATRange) SizeBytes() int { + return 8 + + (*Inet6Addr)(nil).SizeBytes() + + (*Inet6Addr)(nil).SizeBytes() +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (n *NFNATRange) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint32(dst[:4], uint32(n.Flags)) + dst = dst[4:] + n.MinAddr.MarshalBytes(dst[:n.MinAddr.SizeBytes()]) + dst = dst[n.MinAddr.SizeBytes():] + n.MaxAddr.MarshalBytes(dst[:n.MaxAddr.SizeBytes()]) + dst = dst[n.MaxAddr.SizeBytes():] + hostarch.ByteOrder.PutUint16(dst[:2], uint16(n.MinProto)) + dst = dst[2:] + hostarch.ByteOrder.PutUint16(dst[:2], uint16(n.MaxProto)) + dst = dst[2:] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (n *NFNATRange) UnmarshalBytes(src []byte) { + n.Flags = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + n.MinAddr.UnmarshalBytes(src[:n.MinAddr.SizeBytes()]) + src = src[n.MinAddr.SizeBytes():] + n.MaxAddr.UnmarshalBytes(src[:n.MaxAddr.SizeBytes()]) + src = src[n.MaxAddr.SizeBytes():] + n.MinProto = uint16(hostarch.ByteOrder.Uint16(src[:2])) + src = src[2:] + n.MaxProto = uint16(hostarch.ByteOrder.Uint16(src[:2])) + src = src[2:] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (n *NFNATRange) Packed() bool { + return n.MaxAddr.Packed() && n.MinAddr.Packed() +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (n *NFNATRange) MarshalUnsafe(dst []byte) { + if n.MaxAddr.Packed() && n.MinAddr.Packed() { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(n), uintptr(n.SizeBytes())) + } else { + // Type NFNATRange doesn't have a packed layout in memory, fallback to MarshalBytes. + n.MarshalBytes(dst) + } +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (n *NFNATRange) UnmarshalUnsafe(src []byte) { + if n.MaxAddr.Packed() && n.MinAddr.Packed() { + gohacks.Memmove(unsafe.Pointer(n), unsafe.Pointer(&src[0]), uintptr(n.SizeBytes())) + } else { + // Type NFNATRange doesn't have a packed layout in memory, fallback to UnmarshalBytes. + n.UnmarshalBytes(src) + } +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (n *NFNATRange) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + if !n.MaxAddr.Packed() && n.MinAddr.Packed() { + // Type NFNATRange doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := cc.CopyScratchBuffer(n.SizeBytes()) // escapes: okay. + n.MarshalBytes(buf) // escapes: fallback. + return cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(n))) + hdr.Len = n.SizeBytes() + hdr.Cap = n.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that n + // must live until the use above. + runtime.KeepAlive(n) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (n *NFNATRange) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return n.CopyOutN(cc, addr, n.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (n *NFNATRange) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + if !n.MaxAddr.Packed() && n.MinAddr.Packed() { + // Type NFNATRange doesn't have a packed layout in memory, fall back to UnmarshalBytes. + buf := cc.CopyScratchBuffer(n.SizeBytes()) // escapes: okay. + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Unmarshal unconditionally. If we had a short copy-in, this results in a + // partially unmarshalled struct. + n.UnmarshalBytes(buf) // escapes: fallback. + return length, err + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(n))) + hdr.Len = n.SizeBytes() + hdr.Cap = n.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that n + // must live until the use above. + runtime.KeepAlive(n) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (n *NFNATRange) WriteTo(writer io.Writer) (int64, error) { + if !n.MaxAddr.Packed() && n.MinAddr.Packed() { + // Type NFNATRange doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := make([]byte, n.SizeBytes()) + n.MarshalBytes(buf) + length, err := writer.Write(buf) + return int64(length), err + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(n))) + hdr.Len = n.SizeBytes() + hdr.Cap = n.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that n + // must live until the use above. + runtime.KeepAlive(n) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (n *NetlinkAttrHeader) SizeBytes() int { + return 4 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (n *NetlinkAttrHeader) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint16(dst[:2], uint16(n.Length)) + dst = dst[2:] + hostarch.ByteOrder.PutUint16(dst[:2], uint16(n.Type)) + dst = dst[2:] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (n *NetlinkAttrHeader) UnmarshalBytes(src []byte) { + n.Length = uint16(hostarch.ByteOrder.Uint16(src[:2])) + src = src[2:] + n.Type = uint16(hostarch.ByteOrder.Uint16(src[:2])) + src = src[2:] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (n *NetlinkAttrHeader) Packed() bool { + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (n *NetlinkAttrHeader) MarshalUnsafe(dst []byte) { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(n), uintptr(n.SizeBytes())) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (n *NetlinkAttrHeader) UnmarshalUnsafe(src []byte) { + gohacks.Memmove(unsafe.Pointer(n), unsafe.Pointer(&src[0]), uintptr(n.SizeBytes())) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (n *NetlinkAttrHeader) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(n))) + hdr.Len = n.SizeBytes() + hdr.Cap = n.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that n + // must live until the use above. + runtime.KeepAlive(n) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (n *NetlinkAttrHeader) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return n.CopyOutN(cc, addr, n.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (n *NetlinkAttrHeader) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(n))) + hdr.Len = n.SizeBytes() + hdr.Cap = n.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that n + // must live until the use above. + runtime.KeepAlive(n) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (n *NetlinkAttrHeader) WriteTo(writer io.Writer) (int64, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(n))) + hdr.Len = n.SizeBytes() + hdr.Cap = n.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that n + // must live until the use above. + runtime.KeepAlive(n) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (n *NetlinkErrorMessage) SizeBytes() int { + return 4 + + (*NetlinkMessageHeader)(nil).SizeBytes() +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (n *NetlinkErrorMessage) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint32(dst[:4], uint32(n.Error)) + dst = dst[4:] + n.Header.MarshalBytes(dst[:n.Header.SizeBytes()]) + dst = dst[n.Header.SizeBytes():] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (n *NetlinkErrorMessage) UnmarshalBytes(src []byte) { + n.Error = int32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + n.Header.UnmarshalBytes(src[:n.Header.SizeBytes()]) + src = src[n.Header.SizeBytes():] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (n *NetlinkErrorMessage) Packed() bool { + return n.Header.Packed() +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (n *NetlinkErrorMessage) MarshalUnsafe(dst []byte) { + if n.Header.Packed() { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(n), uintptr(n.SizeBytes())) + } else { + // Type NetlinkErrorMessage doesn't have a packed layout in memory, fallback to MarshalBytes. + n.MarshalBytes(dst) + } +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (n *NetlinkErrorMessage) UnmarshalUnsafe(src []byte) { + if n.Header.Packed() { + gohacks.Memmove(unsafe.Pointer(n), unsafe.Pointer(&src[0]), uintptr(n.SizeBytes())) + } else { + // Type NetlinkErrorMessage doesn't have a packed layout in memory, fallback to UnmarshalBytes. + n.UnmarshalBytes(src) + } +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (n *NetlinkErrorMessage) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + if !n.Header.Packed() { + // Type NetlinkErrorMessage doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := cc.CopyScratchBuffer(n.SizeBytes()) // escapes: okay. + n.MarshalBytes(buf) // escapes: fallback. + return cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(n))) + hdr.Len = n.SizeBytes() + hdr.Cap = n.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that n + // must live until the use above. + runtime.KeepAlive(n) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (n *NetlinkErrorMessage) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return n.CopyOutN(cc, addr, n.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (n *NetlinkErrorMessage) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + if !n.Header.Packed() { + // Type NetlinkErrorMessage doesn't have a packed layout in memory, fall back to UnmarshalBytes. + buf := cc.CopyScratchBuffer(n.SizeBytes()) // escapes: okay. + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Unmarshal unconditionally. If we had a short copy-in, this results in a + // partially unmarshalled struct. + n.UnmarshalBytes(buf) // escapes: fallback. + return length, err + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(n))) + hdr.Len = n.SizeBytes() + hdr.Cap = n.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that n + // must live until the use above. + runtime.KeepAlive(n) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (n *NetlinkErrorMessage) WriteTo(writer io.Writer) (int64, error) { + if !n.Header.Packed() { + // Type NetlinkErrorMessage doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := make([]byte, n.SizeBytes()) + n.MarshalBytes(buf) + length, err := writer.Write(buf) + return int64(length), err + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(n))) + hdr.Len = n.SizeBytes() + hdr.Cap = n.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that n + // must live until the use above. + runtime.KeepAlive(n) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (n *NetlinkMessageHeader) SizeBytes() int { + return 16 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (n *NetlinkMessageHeader) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint32(dst[:4], uint32(n.Length)) + dst = dst[4:] + hostarch.ByteOrder.PutUint16(dst[:2], uint16(n.Type)) + dst = dst[2:] + hostarch.ByteOrder.PutUint16(dst[:2], uint16(n.Flags)) + dst = dst[2:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(n.Seq)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(n.PortID)) + dst = dst[4:] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (n *NetlinkMessageHeader) UnmarshalBytes(src []byte) { + n.Length = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + n.Type = uint16(hostarch.ByteOrder.Uint16(src[:2])) + src = src[2:] + n.Flags = uint16(hostarch.ByteOrder.Uint16(src[:2])) + src = src[2:] + n.Seq = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + n.PortID = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (n *NetlinkMessageHeader) Packed() bool { + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (n *NetlinkMessageHeader) MarshalUnsafe(dst []byte) { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(n), uintptr(n.SizeBytes())) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (n *NetlinkMessageHeader) UnmarshalUnsafe(src []byte) { + gohacks.Memmove(unsafe.Pointer(n), unsafe.Pointer(&src[0]), uintptr(n.SizeBytes())) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (n *NetlinkMessageHeader) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(n))) + hdr.Len = n.SizeBytes() + hdr.Cap = n.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that n + // must live until the use above. + runtime.KeepAlive(n) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (n *NetlinkMessageHeader) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return n.CopyOutN(cc, addr, n.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (n *NetlinkMessageHeader) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(n))) + hdr.Len = n.SizeBytes() + hdr.Cap = n.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that n + // must live until the use above. + runtime.KeepAlive(n) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (n *NetlinkMessageHeader) WriteTo(writer io.Writer) (int64, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(n))) + hdr.Len = n.SizeBytes() + hdr.Cap = n.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that n + // must live until the use above. + runtime.KeepAlive(n) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (s *SockAddrNetlink) SizeBytes() int { + return 12 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (s *SockAddrNetlink) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint16(dst[:2], uint16(s.Family)) + dst = dst[2:] + // Padding: dst[:sizeof(uint16)] ~= uint16(0) + dst = dst[2:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(s.PortID)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(s.Groups)) + dst = dst[4:] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (s *SockAddrNetlink) UnmarshalBytes(src []byte) { + s.Family = uint16(hostarch.ByteOrder.Uint16(src[:2])) + src = src[2:] + // Padding: var _ uint16 ~= src[:sizeof(uint16)] + src = src[2:] + s.PortID = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + s.Groups = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (s *SockAddrNetlink) Packed() bool { + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (s *SockAddrNetlink) MarshalUnsafe(dst []byte) { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(s), uintptr(s.SizeBytes())) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (s *SockAddrNetlink) UnmarshalUnsafe(src []byte) { + gohacks.Memmove(unsafe.Pointer(s), unsafe.Pointer(&src[0]), uintptr(s.SizeBytes())) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (s *SockAddrNetlink) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(s))) + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until the use above. + runtime.KeepAlive(s) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (s *SockAddrNetlink) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return s.CopyOutN(cc, addr, s.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (s *SockAddrNetlink) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(s))) + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until the use above. + runtime.KeepAlive(s) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (s *SockAddrNetlink) WriteTo(writer io.Writer) (int64, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(s))) + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until the use above. + runtime.KeepAlive(s) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (i *InterfaceAddrMessage) SizeBytes() int { + return 8 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (i *InterfaceAddrMessage) MarshalBytes(dst []byte) { + dst[0] = byte(i.Family) + dst = dst[1:] + dst[0] = byte(i.PrefixLen) + dst = dst[1:] + dst[0] = byte(i.Flags) + dst = dst[1:] + dst[0] = byte(i.Scope) + dst = dst[1:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(i.Index)) + dst = dst[4:] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (i *InterfaceAddrMessage) UnmarshalBytes(src []byte) { + i.Family = uint8(src[0]) + src = src[1:] + i.PrefixLen = uint8(src[0]) + src = src[1:] + i.Flags = uint8(src[0]) + src = src[1:] + i.Scope = uint8(src[0]) + src = src[1:] + i.Index = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (i *InterfaceAddrMessage) Packed() bool { + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (i *InterfaceAddrMessage) MarshalUnsafe(dst []byte) { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(i), uintptr(i.SizeBytes())) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (i *InterfaceAddrMessage) UnmarshalUnsafe(src []byte) { + gohacks.Memmove(unsafe.Pointer(i), unsafe.Pointer(&src[0]), uintptr(i.SizeBytes())) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (i *InterfaceAddrMessage) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(i))) + hdr.Len = i.SizeBytes() + hdr.Cap = i.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that i + // must live until the use above. + runtime.KeepAlive(i) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (i *InterfaceAddrMessage) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return i.CopyOutN(cc, addr, i.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (i *InterfaceAddrMessage) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(i))) + hdr.Len = i.SizeBytes() + hdr.Cap = i.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that i + // must live until the use above. + runtime.KeepAlive(i) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (i *InterfaceAddrMessage) WriteTo(writer io.Writer) (int64, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(i))) + hdr.Len = i.SizeBytes() + hdr.Cap = i.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that i + // must live until the use above. + runtime.KeepAlive(i) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (i *InterfaceInfoMessage) SizeBytes() int { + return 16 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (i *InterfaceInfoMessage) MarshalBytes(dst []byte) { + dst[0] = byte(i.Family) + dst = dst[1:] + // Padding: dst[:sizeof(uint8)] ~= uint8(0) + dst = dst[1:] + hostarch.ByteOrder.PutUint16(dst[:2], uint16(i.Type)) + dst = dst[2:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(i.Index)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(i.Flags)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(i.Change)) + dst = dst[4:] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (i *InterfaceInfoMessage) UnmarshalBytes(src []byte) { + i.Family = uint8(src[0]) + src = src[1:] + // Padding: var _ uint8 ~= src[:sizeof(uint8)] + src = src[1:] + i.Type = uint16(hostarch.ByteOrder.Uint16(src[:2])) + src = src[2:] + i.Index = int32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + i.Flags = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + i.Change = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (i *InterfaceInfoMessage) Packed() bool { + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (i *InterfaceInfoMessage) MarshalUnsafe(dst []byte) { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(i), uintptr(i.SizeBytes())) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (i *InterfaceInfoMessage) UnmarshalUnsafe(src []byte) { + gohacks.Memmove(unsafe.Pointer(i), unsafe.Pointer(&src[0]), uintptr(i.SizeBytes())) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (i *InterfaceInfoMessage) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(i))) + hdr.Len = i.SizeBytes() + hdr.Cap = i.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that i + // must live until the use above. + runtime.KeepAlive(i) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (i *InterfaceInfoMessage) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return i.CopyOutN(cc, addr, i.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (i *InterfaceInfoMessage) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(i))) + hdr.Len = i.SizeBytes() + hdr.Cap = i.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that i + // must live until the use above. + runtime.KeepAlive(i) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (i *InterfaceInfoMessage) WriteTo(writer io.Writer) (int64, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(i))) + hdr.Len = i.SizeBytes() + hdr.Cap = i.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that i + // must live until the use above. + runtime.KeepAlive(i) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (r *RouteMessage) SizeBytes() int { + return 12 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (r *RouteMessage) MarshalBytes(dst []byte) { + dst[0] = byte(r.Family) + dst = dst[1:] + dst[0] = byte(r.DstLen) + dst = dst[1:] + dst[0] = byte(r.SrcLen) + dst = dst[1:] + dst[0] = byte(r.TOS) + dst = dst[1:] + dst[0] = byte(r.Table) + dst = dst[1:] + dst[0] = byte(r.Protocol) + dst = dst[1:] + dst[0] = byte(r.Scope) + dst = dst[1:] + dst[0] = byte(r.Type) + dst = dst[1:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(r.Flags)) + dst = dst[4:] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (r *RouteMessage) UnmarshalBytes(src []byte) { + r.Family = uint8(src[0]) + src = src[1:] + r.DstLen = uint8(src[0]) + src = src[1:] + r.SrcLen = uint8(src[0]) + src = src[1:] + r.TOS = uint8(src[0]) + src = src[1:] + r.Table = uint8(src[0]) + src = src[1:] + r.Protocol = uint8(src[0]) + src = src[1:] + r.Scope = uint8(src[0]) + src = src[1:] + r.Type = uint8(src[0]) + src = src[1:] + r.Flags = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (r *RouteMessage) Packed() bool { + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (r *RouteMessage) MarshalUnsafe(dst []byte) { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(r), uintptr(r.SizeBytes())) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (r *RouteMessage) UnmarshalUnsafe(src []byte) { + gohacks.Memmove(unsafe.Pointer(r), unsafe.Pointer(&src[0]), uintptr(r.SizeBytes())) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (r *RouteMessage) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(r))) + hdr.Len = r.SizeBytes() + hdr.Cap = r.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that r + // must live until the use above. + runtime.KeepAlive(r) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (r *RouteMessage) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return r.CopyOutN(cc, addr, r.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (r *RouteMessage) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(r))) + hdr.Len = r.SizeBytes() + hdr.Cap = r.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that r + // must live until the use above. + runtime.KeepAlive(r) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (r *RouteMessage) WriteTo(writer io.Writer) (int64, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(r))) + hdr.Len = r.SizeBytes() + hdr.Cap = r.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that r + // must live until the use above. + runtime.KeepAlive(r) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (p *PollFD) SizeBytes() int { + return 8 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (p *PollFD) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint32(dst[:4], uint32(p.FD)) + dst = dst[4:] + hostarch.ByteOrder.PutUint16(dst[:2], uint16(p.Events)) + dst = dst[2:] + hostarch.ByteOrder.PutUint16(dst[:2], uint16(p.REvents)) + dst = dst[2:] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (p *PollFD) UnmarshalBytes(src []byte) { + p.FD = int32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + p.Events = int16(hostarch.ByteOrder.Uint16(src[:2])) + src = src[2:] + p.REvents = int16(hostarch.ByteOrder.Uint16(src[:2])) + src = src[2:] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (p *PollFD) Packed() bool { + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (p *PollFD) MarshalUnsafe(dst []byte) { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(p), uintptr(p.SizeBytes())) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (p *PollFD) UnmarshalUnsafe(src []byte) { + gohacks.Memmove(unsafe.Pointer(p), unsafe.Pointer(&src[0]), uintptr(p.SizeBytes())) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (p *PollFD) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(p))) + hdr.Len = p.SizeBytes() + hdr.Cap = p.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that p + // must live until the use above. + runtime.KeepAlive(p) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (p *PollFD) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return p.CopyOutN(cc, addr, p.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (p *PollFD) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(p))) + hdr.Len = p.SizeBytes() + hdr.Cap = p.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that p + // must live until the use above. + runtime.KeepAlive(p) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (p *PollFD) WriteTo(writer io.Writer) (int64, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(p))) + hdr.Len = p.SizeBytes() + hdr.Cap = p.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that p + // must live until the use above. + runtime.KeepAlive(p) // escapes: replaced by intrinsic. + return int64(length), err +} + +// CopyPollFDSliceIn copies in a slice of PollFD objects from the task's memory. +func CopyPollFDSliceIn(cc marshal.CopyContext, addr hostarch.Addr, dst []PollFD) (int, error) { + count := len(dst) + if count == 0 { + return 0, nil + } + size := (*PollFD)(nil).SizeBytes() + + ptr := unsafe.Pointer(&dst) + val := gohacks.Noescape(unsafe.Pointer((*reflect.SliceHeader)(ptr).Data)) + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(val) + hdr.Len = size * count + hdr.Cap = size * count + + length, err := cc.CopyInBytes(addr, buf) + // Since we bypassed the compiler's escape analysis, indicate that dst + // must live until the use above. + runtime.KeepAlive(dst) // escapes: replaced by intrinsic. + return length, err +} + +// CopyPollFDSliceOut copies a slice of PollFD objects to the task's memory. +func CopyPollFDSliceOut(cc marshal.CopyContext, addr hostarch.Addr, src []PollFD) (int, error) { + count := len(src) + if count == 0 { + return 0, nil + } + size := (*PollFD)(nil).SizeBytes() + + ptr := unsafe.Pointer(&src) + val := gohacks.Noescape(unsafe.Pointer((*reflect.SliceHeader)(ptr).Data)) + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(val) + hdr.Len = size * count + hdr.Cap = size * count + + length, err := cc.CopyOutBytes(addr, buf) + // Since we bypassed the compiler's escape analysis, indicate that src + // must live until the use above. + runtime.KeepAlive(src) // escapes: replaced by intrinsic. + return length, err +} + +// MarshalUnsafePollFDSlice is like PollFD.MarshalUnsafe, but for a []PollFD. +func MarshalUnsafePollFDSlice(src []PollFD, dst []byte) (int, error) { + count := len(src) + if count == 0 { + return 0, nil + } + size := (*PollFD)(nil).SizeBytes() + + dst = dst[:size*count] + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(&src[0]), uintptr(len(dst))) + return size * count, nil +} + +// UnmarshalUnsafePollFDSlice is like PollFD.UnmarshalUnsafe, but for a []PollFD. +func UnmarshalUnsafePollFDSlice(dst []PollFD, src []byte) (int, error) { + count := len(dst) + if count == 0 { + return 0, nil + } + size := (*PollFD)(nil).SizeBytes() + + src = src[:(size*count)] + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(&src[0]), uintptr(len(src))) + return count*size, nil +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (r *RSeqCriticalSection) SizeBytes() int { + return 32 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (r *RSeqCriticalSection) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint32(dst[:4], uint32(r.Version)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(r.Flags)) + dst = dst[4:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(r.Start)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(r.PostCommitOffset)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(r.Abort)) + dst = dst[8:] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (r *RSeqCriticalSection) UnmarshalBytes(src []byte) { + r.Version = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + r.Flags = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + r.Start = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + r.PostCommitOffset = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + r.Abort = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (r *RSeqCriticalSection) Packed() bool { + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (r *RSeqCriticalSection) MarshalUnsafe(dst []byte) { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(r), uintptr(r.SizeBytes())) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (r *RSeqCriticalSection) UnmarshalUnsafe(src []byte) { + gohacks.Memmove(unsafe.Pointer(r), unsafe.Pointer(&src[0]), uintptr(r.SizeBytes())) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (r *RSeqCriticalSection) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(r))) + hdr.Len = r.SizeBytes() + hdr.Cap = r.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that r + // must live until the use above. + runtime.KeepAlive(r) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (r *RSeqCriticalSection) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return r.CopyOutN(cc, addr, r.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (r *RSeqCriticalSection) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(r))) + hdr.Len = r.SizeBytes() + hdr.Cap = r.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that r + // must live until the use above. + runtime.KeepAlive(r) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (r *RSeqCriticalSection) WriteTo(writer io.Writer) (int64, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(r))) + hdr.Len = r.SizeBytes() + hdr.Cap = r.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that r + // must live until the use above. + runtime.KeepAlive(r) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (r *Rusage) SizeBytes() int { + return 112 + + (*Timeval)(nil).SizeBytes() + + (*Timeval)(nil).SizeBytes() +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (r *Rusage) MarshalBytes(dst []byte) { + r.UTime.MarshalBytes(dst[:r.UTime.SizeBytes()]) + dst = dst[r.UTime.SizeBytes():] + r.STime.MarshalBytes(dst[:r.STime.SizeBytes()]) + dst = dst[r.STime.SizeBytes():] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(r.MaxRSS)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(r.IXRSS)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(r.IDRSS)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(r.ISRSS)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(r.MinFlt)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(r.MajFlt)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(r.NSwap)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(r.InBlock)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(r.OuBlock)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(r.MsgSnd)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(r.MsgRcv)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(r.NSignals)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(r.NVCSw)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(r.NIvCSw)) + dst = dst[8:] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (r *Rusage) UnmarshalBytes(src []byte) { + r.UTime.UnmarshalBytes(src[:r.UTime.SizeBytes()]) + src = src[r.UTime.SizeBytes():] + r.STime.UnmarshalBytes(src[:r.STime.SizeBytes()]) + src = src[r.STime.SizeBytes():] + r.MaxRSS = int64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + r.IXRSS = int64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + r.IDRSS = int64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + r.ISRSS = int64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + r.MinFlt = int64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + r.MajFlt = int64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + r.NSwap = int64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + r.InBlock = int64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + r.OuBlock = int64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + r.MsgSnd = int64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + r.MsgRcv = int64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + r.NSignals = int64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + r.NVCSw = int64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + r.NIvCSw = int64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (r *Rusage) Packed() bool { + return r.STime.Packed() && r.UTime.Packed() +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (r *Rusage) MarshalUnsafe(dst []byte) { + if r.STime.Packed() && r.UTime.Packed() { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(r), uintptr(r.SizeBytes())) + } else { + // Type Rusage doesn't have a packed layout in memory, fallback to MarshalBytes. + r.MarshalBytes(dst) + } +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (r *Rusage) UnmarshalUnsafe(src []byte) { + if r.STime.Packed() && r.UTime.Packed() { + gohacks.Memmove(unsafe.Pointer(r), unsafe.Pointer(&src[0]), uintptr(r.SizeBytes())) + } else { + // Type Rusage doesn't have a packed layout in memory, fallback to UnmarshalBytes. + r.UnmarshalBytes(src) + } +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (r *Rusage) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + if !r.STime.Packed() && r.UTime.Packed() { + // Type Rusage doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := cc.CopyScratchBuffer(r.SizeBytes()) // escapes: okay. + r.MarshalBytes(buf) // escapes: fallback. + return cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(r))) + hdr.Len = r.SizeBytes() + hdr.Cap = r.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that r + // must live until the use above. + runtime.KeepAlive(r) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (r *Rusage) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return r.CopyOutN(cc, addr, r.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (r *Rusage) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + if !r.STime.Packed() && r.UTime.Packed() { + // Type Rusage doesn't have a packed layout in memory, fall back to UnmarshalBytes. + buf := cc.CopyScratchBuffer(r.SizeBytes()) // escapes: okay. + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Unmarshal unconditionally. If we had a short copy-in, this results in a + // partially unmarshalled struct. + r.UnmarshalBytes(buf) // escapes: fallback. + return length, err + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(r))) + hdr.Len = r.SizeBytes() + hdr.Cap = r.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that r + // must live until the use above. + runtime.KeepAlive(r) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (r *Rusage) WriteTo(writer io.Writer) (int64, error) { + if !r.STime.Packed() && r.UTime.Packed() { + // Type Rusage doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := make([]byte, r.SizeBytes()) + r.MarshalBytes(buf) + length, err := writer.Write(buf) + return int64(length), err + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(r))) + hdr.Len = r.SizeBytes() + hdr.Cap = r.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that r + // must live until the use above. + runtime.KeepAlive(r) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (s *SeccompData) SizeBytes() int { + return 16 + + 8*6 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (s *SeccompData) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint32(dst[:4], uint32(s.Nr)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(s.Arch)) + dst = dst[4:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(s.InstructionPointer)) + dst = dst[8:] + for idx := 0; idx < 6; idx++ { + hostarch.ByteOrder.PutUint64(dst[:8], uint64(s.Args[idx])) + dst = dst[8:] + } +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (s *SeccompData) UnmarshalBytes(src []byte) { + s.Nr = int32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + s.Arch = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + s.InstructionPointer = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + for idx := 0; idx < 6; idx++ { + s.Args[idx] = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + } +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (s *SeccompData) Packed() bool { + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (s *SeccompData) MarshalUnsafe(dst []byte) { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(s), uintptr(s.SizeBytes())) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (s *SeccompData) UnmarshalUnsafe(src []byte) { + gohacks.Memmove(unsafe.Pointer(s), unsafe.Pointer(&src[0]), uintptr(s.SizeBytes())) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (s *SeccompData) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(s))) + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until the use above. + runtime.KeepAlive(s) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (s *SeccompData) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return s.CopyOutN(cc, addr, s.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (s *SeccompData) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(s))) + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until the use above. + runtime.KeepAlive(s) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (s *SeccompData) WriteTo(writer io.Writer) (int64, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(s))) + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until the use above. + runtime.KeepAlive(s) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (s *SemInfo) SizeBytes() int { + return 40 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (s *SemInfo) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint32(dst[:4], uint32(s.SemMap)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(s.SemMni)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(s.SemMns)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(s.SemMnu)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(s.SemMsl)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(s.SemOpm)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(s.SemUme)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(s.SemUsz)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(s.SemVmx)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(s.SemAem)) + dst = dst[4:] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (s *SemInfo) UnmarshalBytes(src []byte) { + s.SemMap = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + s.SemMni = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + s.SemMns = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + s.SemMnu = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + s.SemMsl = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + s.SemOpm = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + s.SemUme = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + s.SemUsz = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + s.SemVmx = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + s.SemAem = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (s *SemInfo) Packed() bool { + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (s *SemInfo) MarshalUnsafe(dst []byte) { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(s), uintptr(s.SizeBytes())) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (s *SemInfo) UnmarshalUnsafe(src []byte) { + gohacks.Memmove(unsafe.Pointer(s), unsafe.Pointer(&src[0]), uintptr(s.SizeBytes())) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (s *SemInfo) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(s))) + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until the use above. + runtime.KeepAlive(s) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (s *SemInfo) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return s.CopyOutN(cc, addr, s.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (s *SemInfo) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(s))) + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until the use above. + runtime.KeepAlive(s) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (s *SemInfo) WriteTo(writer io.Writer) (int64, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(s))) + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until the use above. + runtime.KeepAlive(s) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (s *Sembuf) SizeBytes() int { + return 6 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (s *Sembuf) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint16(dst[:2], uint16(s.SemNum)) + dst = dst[2:] + hostarch.ByteOrder.PutUint16(dst[:2], uint16(s.SemOp)) + dst = dst[2:] + hostarch.ByteOrder.PutUint16(dst[:2], uint16(s.SemFlg)) + dst = dst[2:] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (s *Sembuf) UnmarshalBytes(src []byte) { + s.SemNum = uint16(hostarch.ByteOrder.Uint16(src[:2])) + src = src[2:] + s.SemOp = int16(hostarch.ByteOrder.Uint16(src[:2])) + src = src[2:] + s.SemFlg = int16(hostarch.ByteOrder.Uint16(src[:2])) + src = src[2:] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (s *Sembuf) Packed() bool { + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (s *Sembuf) MarshalUnsafe(dst []byte) { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(s), uintptr(s.SizeBytes())) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (s *Sembuf) UnmarshalUnsafe(src []byte) { + gohacks.Memmove(unsafe.Pointer(s), unsafe.Pointer(&src[0]), uintptr(s.SizeBytes())) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (s *Sembuf) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(s))) + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until the use above. + runtime.KeepAlive(s) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (s *Sembuf) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return s.CopyOutN(cc, addr, s.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (s *Sembuf) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(s))) + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until the use above. + runtime.KeepAlive(s) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (s *Sembuf) WriteTo(writer io.Writer) (int64, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(s))) + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until the use above. + runtime.KeepAlive(s) // escapes: replaced by intrinsic. + return int64(length), err +} + +// CopySembufSliceIn copies in a slice of Sembuf objects from the task's memory. +func CopySembufSliceIn(cc marshal.CopyContext, addr hostarch.Addr, dst []Sembuf) (int, error) { + count := len(dst) + if count == 0 { + return 0, nil + } + size := (*Sembuf)(nil).SizeBytes() + + ptr := unsafe.Pointer(&dst) + val := gohacks.Noescape(unsafe.Pointer((*reflect.SliceHeader)(ptr).Data)) + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(val) + hdr.Len = size * count + hdr.Cap = size * count + + length, err := cc.CopyInBytes(addr, buf) + // Since we bypassed the compiler's escape analysis, indicate that dst + // must live until the use above. + runtime.KeepAlive(dst) // escapes: replaced by intrinsic. + return length, err +} + +// CopySembufSliceOut copies a slice of Sembuf objects to the task's memory. +func CopySembufSliceOut(cc marshal.CopyContext, addr hostarch.Addr, src []Sembuf) (int, error) { + count := len(src) + if count == 0 { + return 0, nil + } + size := (*Sembuf)(nil).SizeBytes() + + ptr := unsafe.Pointer(&src) + val := gohacks.Noescape(unsafe.Pointer((*reflect.SliceHeader)(ptr).Data)) + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(val) + hdr.Len = size * count + hdr.Cap = size * count + + length, err := cc.CopyOutBytes(addr, buf) + // Since we bypassed the compiler's escape analysis, indicate that src + // must live until the use above. + runtime.KeepAlive(src) // escapes: replaced by intrinsic. + return length, err +} + +// MarshalUnsafeSembufSlice is like Sembuf.MarshalUnsafe, but for a []Sembuf. +func MarshalUnsafeSembufSlice(src []Sembuf, dst []byte) (int, error) { + count := len(src) + if count == 0 { + return 0, nil + } + size := (*Sembuf)(nil).SizeBytes() + + dst = dst[:size*count] + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(&src[0]), uintptr(len(dst))) + return size * count, nil +} + +// UnmarshalUnsafeSembufSlice is like Sembuf.UnmarshalUnsafe, but for a []Sembuf. +func UnmarshalUnsafeSembufSlice(dst []Sembuf, src []byte) (int, error) { + count := len(dst) + if count == 0 { + return 0, nil + } + size := (*Sembuf)(nil).SizeBytes() + + src = src[:(size*count)] + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(&src[0]), uintptr(len(src))) + return count*size, nil +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (s *ShmInfo) SizeBytes() int { + return 44 + + 1*4 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (s *ShmInfo) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint32(dst[:4], uint32(s.UsedIDs)) + dst = dst[4:] + // Padding: dst[:sizeof(byte)*4] ~= [4]byte{0} + dst = dst[1*(4):] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(s.ShmTot)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(s.ShmRss)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(s.ShmSwp)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(s.SwapAttempts)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(s.SwapSuccesses)) + dst = dst[8:] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (s *ShmInfo) UnmarshalBytes(src []byte) { + s.UsedIDs = int32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + // Padding: ~ copy([4]byte(s._), src[:sizeof(byte)*4]) + src = src[1*(4):] + s.ShmTot = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.ShmRss = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.ShmSwp = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.SwapAttempts = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.SwapSuccesses = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (s *ShmInfo) Packed() bool { + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (s *ShmInfo) MarshalUnsafe(dst []byte) { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(s), uintptr(s.SizeBytes())) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (s *ShmInfo) UnmarshalUnsafe(src []byte) { + gohacks.Memmove(unsafe.Pointer(s), unsafe.Pointer(&src[0]), uintptr(s.SizeBytes())) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (s *ShmInfo) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(s))) + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until the use above. + runtime.KeepAlive(s) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (s *ShmInfo) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return s.CopyOutN(cc, addr, s.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (s *ShmInfo) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(s))) + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until the use above. + runtime.KeepAlive(s) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (s *ShmInfo) WriteTo(writer io.Writer) (int64, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(s))) + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until the use above. + runtime.KeepAlive(s) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (s *ShmParams) SizeBytes() int { + return 40 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (s *ShmParams) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint64(dst[:8], uint64(s.ShmMax)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(s.ShmMin)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(s.ShmMni)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(s.ShmSeg)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(s.ShmAll)) + dst = dst[8:] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (s *ShmParams) UnmarshalBytes(src []byte) { + s.ShmMax = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.ShmMin = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.ShmMni = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.ShmSeg = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.ShmAll = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (s *ShmParams) Packed() bool { + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (s *ShmParams) MarshalUnsafe(dst []byte) { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(s), uintptr(s.SizeBytes())) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (s *ShmParams) UnmarshalUnsafe(src []byte) { + gohacks.Memmove(unsafe.Pointer(s), unsafe.Pointer(&src[0]), uintptr(s.SizeBytes())) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (s *ShmParams) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(s))) + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until the use above. + runtime.KeepAlive(s) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (s *ShmParams) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return s.CopyOutN(cc, addr, s.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (s *ShmParams) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(s))) + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until the use above. + runtime.KeepAlive(s) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (s *ShmParams) WriteTo(writer io.Writer) (int64, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(s))) + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until the use above. + runtime.KeepAlive(s) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (s *ShmidDS) SizeBytes() int { + return 40 + + (*IPCPerm)(nil).SizeBytes() + + (*TimeT)(nil).SizeBytes() + + (*TimeT)(nil).SizeBytes() + + (*TimeT)(nil).SizeBytes() +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (s *ShmidDS) MarshalBytes(dst []byte) { + s.ShmPerm.MarshalBytes(dst[:s.ShmPerm.SizeBytes()]) + dst = dst[s.ShmPerm.SizeBytes():] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(s.ShmSegsz)) + dst = dst[8:] + s.ShmAtime.MarshalBytes(dst[:s.ShmAtime.SizeBytes()]) + dst = dst[s.ShmAtime.SizeBytes():] + s.ShmDtime.MarshalBytes(dst[:s.ShmDtime.SizeBytes()]) + dst = dst[s.ShmDtime.SizeBytes():] + s.ShmCtime.MarshalBytes(dst[:s.ShmCtime.SizeBytes()]) + dst = dst[s.ShmCtime.SizeBytes():] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(s.ShmCpid)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(s.ShmLpid)) + dst = dst[4:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(s.ShmNattach)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(s.Unused4)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(s.Unused5)) + dst = dst[8:] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (s *ShmidDS) UnmarshalBytes(src []byte) { + s.ShmPerm.UnmarshalBytes(src[:s.ShmPerm.SizeBytes()]) + src = src[s.ShmPerm.SizeBytes():] + s.ShmSegsz = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.ShmAtime.UnmarshalBytes(src[:s.ShmAtime.SizeBytes()]) + src = src[s.ShmAtime.SizeBytes():] + s.ShmDtime.UnmarshalBytes(src[:s.ShmDtime.SizeBytes()]) + src = src[s.ShmDtime.SizeBytes():] + s.ShmCtime.UnmarshalBytes(src[:s.ShmCtime.SizeBytes()]) + src = src[s.ShmCtime.SizeBytes():] + s.ShmCpid = int32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + s.ShmLpid = int32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + s.ShmNattach = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.Unused4 = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.Unused5 = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (s *ShmidDS) Packed() bool { + return s.ShmAtime.Packed() && s.ShmCtime.Packed() && s.ShmDtime.Packed() && s.ShmPerm.Packed() +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (s *ShmidDS) MarshalUnsafe(dst []byte) { + if s.ShmAtime.Packed() && s.ShmCtime.Packed() && s.ShmDtime.Packed() && s.ShmPerm.Packed() { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(s), uintptr(s.SizeBytes())) + } else { + // Type ShmidDS doesn't have a packed layout in memory, fallback to MarshalBytes. + s.MarshalBytes(dst) + } +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (s *ShmidDS) UnmarshalUnsafe(src []byte) { + if s.ShmAtime.Packed() && s.ShmCtime.Packed() && s.ShmDtime.Packed() && s.ShmPerm.Packed() { + gohacks.Memmove(unsafe.Pointer(s), unsafe.Pointer(&src[0]), uintptr(s.SizeBytes())) + } else { + // Type ShmidDS doesn't have a packed layout in memory, fallback to UnmarshalBytes. + s.UnmarshalBytes(src) + } +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (s *ShmidDS) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + if !s.ShmAtime.Packed() && s.ShmCtime.Packed() && s.ShmDtime.Packed() && s.ShmPerm.Packed() { + // Type ShmidDS doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := cc.CopyScratchBuffer(s.SizeBytes()) // escapes: okay. + s.MarshalBytes(buf) // escapes: fallback. + return cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(s))) + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until the use above. + runtime.KeepAlive(s) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (s *ShmidDS) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return s.CopyOutN(cc, addr, s.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (s *ShmidDS) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + if !s.ShmAtime.Packed() && s.ShmCtime.Packed() && s.ShmDtime.Packed() && s.ShmPerm.Packed() { + // Type ShmidDS doesn't have a packed layout in memory, fall back to UnmarshalBytes. + buf := cc.CopyScratchBuffer(s.SizeBytes()) // escapes: okay. + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Unmarshal unconditionally. If we had a short copy-in, this results in a + // partially unmarshalled struct. + s.UnmarshalBytes(buf) // escapes: fallback. + return length, err + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(s))) + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until the use above. + runtime.KeepAlive(s) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (s *ShmidDS) WriteTo(writer io.Writer) (int64, error) { + if !s.ShmAtime.Packed() && s.ShmCtime.Packed() && s.ShmDtime.Packed() && s.ShmPerm.Packed() { + // Type ShmidDS doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := make([]byte, s.SizeBytes()) + s.MarshalBytes(buf) + length, err := writer.Write(buf) + return int64(length), err + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(s))) + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until the use above. + runtime.KeepAlive(s) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (s *SigAction) SizeBytes() int { + return 24 + + (*SignalSet)(nil).SizeBytes() +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (s *SigAction) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint64(dst[:8], uint64(s.Handler)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(s.Flags)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(s.Restorer)) + dst = dst[8:] + s.Mask.MarshalBytes(dst[:s.Mask.SizeBytes()]) + dst = dst[s.Mask.SizeBytes():] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (s *SigAction) UnmarshalBytes(src []byte) { + s.Handler = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.Flags = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.Restorer = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.Mask.UnmarshalBytes(src[:s.Mask.SizeBytes()]) + src = src[s.Mask.SizeBytes():] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (s *SigAction) Packed() bool { + return s.Mask.Packed() +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (s *SigAction) MarshalUnsafe(dst []byte) { + if s.Mask.Packed() { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(s), uintptr(s.SizeBytes())) + } else { + // Type SigAction doesn't have a packed layout in memory, fallback to MarshalBytes. + s.MarshalBytes(dst) + } +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (s *SigAction) UnmarshalUnsafe(src []byte) { + if s.Mask.Packed() { + gohacks.Memmove(unsafe.Pointer(s), unsafe.Pointer(&src[0]), uintptr(s.SizeBytes())) + } else { + // Type SigAction doesn't have a packed layout in memory, fallback to UnmarshalBytes. + s.UnmarshalBytes(src) + } +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (s *SigAction) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + if !s.Mask.Packed() { + // Type SigAction doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := cc.CopyScratchBuffer(s.SizeBytes()) // escapes: okay. + s.MarshalBytes(buf) // escapes: fallback. + return cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(s))) + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until the use above. + runtime.KeepAlive(s) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (s *SigAction) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return s.CopyOutN(cc, addr, s.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (s *SigAction) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + if !s.Mask.Packed() { + // Type SigAction doesn't have a packed layout in memory, fall back to UnmarshalBytes. + buf := cc.CopyScratchBuffer(s.SizeBytes()) // escapes: okay. + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Unmarshal unconditionally. If we had a short copy-in, this results in a + // partially unmarshalled struct. + s.UnmarshalBytes(buf) // escapes: fallback. + return length, err + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(s))) + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until the use above. + runtime.KeepAlive(s) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (s *SigAction) WriteTo(writer io.Writer) (int64, error) { + if !s.Mask.Packed() { + // Type SigAction doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := make([]byte, s.SizeBytes()) + s.MarshalBytes(buf) + length, err := writer.Write(buf) + return int64(length), err + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(s))) + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until the use above. + runtime.KeepAlive(s) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (s *Sigevent) SizeBytes() int { + return 20 + + 1*44 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (s *Sigevent) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint64(dst[:8], uint64(s.Value)) + dst = dst[8:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(s.Signo)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(s.Notify)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(s.Tid)) + dst = dst[4:] + for idx := 0; idx < 44; idx++ { + dst[0] = byte(s.UnRemainder[idx]) + dst = dst[1:] + } +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (s *Sigevent) UnmarshalBytes(src []byte) { + s.Value = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.Signo = int32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + s.Notify = int32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + s.Tid = int32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + for idx := 0; idx < 44; idx++ { + s.UnRemainder[idx] = src[0] + src = src[1:] + } +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (s *Sigevent) Packed() bool { + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (s *Sigevent) MarshalUnsafe(dst []byte) { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(s), uintptr(s.SizeBytes())) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (s *Sigevent) UnmarshalUnsafe(src []byte) { + gohacks.Memmove(unsafe.Pointer(s), unsafe.Pointer(&src[0]), uintptr(s.SizeBytes())) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (s *Sigevent) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(s))) + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until the use above. + runtime.KeepAlive(s) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (s *Sigevent) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return s.CopyOutN(cc, addr, s.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (s *Sigevent) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(s))) + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until the use above. + runtime.KeepAlive(s) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (s *Sigevent) WriteTo(writer io.Writer) (int64, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(s))) + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until the use above. + runtime.KeepAlive(s) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (s *SignalInfo) SizeBytes() int { + return 16 + + 1*(128-16) +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (s *SignalInfo) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint32(dst[:4], uint32(s.Signo)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(s.Errno)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(s.Code)) + dst = dst[4:] + // Padding: dst[:sizeof(uint32)] ~= uint32(0) + dst = dst[4:] + for idx := 0; idx < (128-16); idx++ { + dst[0] = byte(s.Fields[idx]) + dst = dst[1:] + } +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (s *SignalInfo) UnmarshalBytes(src []byte) { + s.Signo = int32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + s.Errno = int32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + s.Code = int32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + // Padding: var _ uint32 ~= src[:sizeof(uint32)] + src = src[4:] + for idx := 0; idx < (128-16); idx++ { + s.Fields[idx] = src[0] + src = src[1:] + } +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (s *SignalInfo) Packed() bool { + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (s *SignalInfo) MarshalUnsafe(dst []byte) { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(s), uintptr(s.SizeBytes())) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (s *SignalInfo) UnmarshalUnsafe(src []byte) { + gohacks.Memmove(unsafe.Pointer(s), unsafe.Pointer(&src[0]), uintptr(s.SizeBytes())) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (s *SignalInfo) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(s))) + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until the use above. + runtime.KeepAlive(s) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (s *SignalInfo) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return s.CopyOutN(cc, addr, s.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (s *SignalInfo) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(s))) + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until the use above. + runtime.KeepAlive(s) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (s *SignalInfo) WriteTo(writer io.Writer) (int64, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(s))) + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until the use above. + runtime.KeepAlive(s) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +//go:nosplit +func (s *SignalSet) SizeBytes() int { + return 8 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (s *SignalSet) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint64(dst[:8], uint64(*s)) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (s *SignalSet) UnmarshalBytes(src []byte) { + *s = SignalSet(uint64(hostarch.ByteOrder.Uint64(src[:8]))) +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (s *SignalSet) Packed() bool { + // Scalar newtypes are always packed. + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (s *SignalSet) MarshalUnsafe(dst []byte) { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(s), uintptr(s.SizeBytes())) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (s *SignalSet) UnmarshalUnsafe(src []byte) { + gohacks.Memmove(unsafe.Pointer(s), unsafe.Pointer(&src[0]), uintptr(s.SizeBytes())) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (s *SignalSet) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(s))) + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until the use above. + runtime.KeepAlive(s) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (s *SignalSet) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return s.CopyOutN(cc, addr, s.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (s *SignalSet) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(s))) + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until the use above. + runtime.KeepAlive(s) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (s *SignalSet) WriteTo(w io.Writer) (int64, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(s))) + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + length, err := w.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until the use above. + runtime.KeepAlive(s) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (s *SignalStack) SizeBytes() int { + return 24 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (s *SignalStack) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint64(dst[:8], uint64(s.Addr)) + dst = dst[8:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(s.Flags)) + dst = dst[4:] + // Padding: dst[:sizeof(uint32)] ~= uint32(0) + dst = dst[4:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(s.Size)) + dst = dst[8:] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (s *SignalStack) UnmarshalBytes(src []byte) { + s.Addr = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.Flags = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + // Padding: var _ uint32 ~= src[:sizeof(uint32)] + src = src[4:] + s.Size = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (s *SignalStack) Packed() bool { + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (s *SignalStack) MarshalUnsafe(dst []byte) { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(s), uintptr(s.SizeBytes())) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (s *SignalStack) UnmarshalUnsafe(src []byte) { + gohacks.Memmove(unsafe.Pointer(s), unsafe.Pointer(&src[0]), uintptr(s.SizeBytes())) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (s *SignalStack) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(s))) + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until the use above. + runtime.KeepAlive(s) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (s *SignalStack) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return s.CopyOutN(cc, addr, s.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (s *SignalStack) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(s))) + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until the use above. + runtime.KeepAlive(s) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (s *SignalStack) WriteTo(writer io.Writer) (int64, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(s))) + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until the use above. + runtime.KeepAlive(s) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (s *SignalfdSiginfo) SizeBytes() int { + return 82 + + 1*48 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (s *SignalfdSiginfo) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint32(dst[:4], uint32(s.Signo)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(s.Errno)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(s.Code)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(s.PID)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(s.UID)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(s.FD)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(s.TID)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(s.Band)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(s.Overrun)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(s.TrapNo)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(s.Status)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(s.Int)) + dst = dst[4:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(s.Ptr)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(s.UTime)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(s.STime)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(s.Addr)) + dst = dst[8:] + hostarch.ByteOrder.PutUint16(dst[:2], uint16(s.AddrLSB)) + dst = dst[2:] + // Padding: dst[:sizeof(uint8)*48] ~= [48]uint8{0} + dst = dst[1*(48):] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (s *SignalfdSiginfo) UnmarshalBytes(src []byte) { + s.Signo = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + s.Errno = int32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + s.Code = int32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + s.PID = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + s.UID = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + s.FD = int32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + s.TID = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + s.Band = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + s.Overrun = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + s.TrapNo = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + s.Status = int32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + s.Int = int32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + s.Ptr = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.UTime = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.STime = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.Addr = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.AddrLSB = uint16(hostarch.ByteOrder.Uint16(src[:2])) + src = src[2:] + // Padding: ~ copy([48]uint8(s._), src[:sizeof(uint8)*48]) + src = src[1*(48):] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (s *SignalfdSiginfo) Packed() bool { + return false +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (s *SignalfdSiginfo) MarshalUnsafe(dst []byte) { + // Type SignalfdSiginfo doesn't have a packed layout in memory, fallback to MarshalBytes. + s.MarshalBytes(dst) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (s *SignalfdSiginfo) UnmarshalUnsafe(src []byte) { + // Type SignalfdSiginfo doesn't have a packed layout in memory, fallback to UnmarshalBytes. + s.UnmarshalBytes(src) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (s *SignalfdSiginfo) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Type SignalfdSiginfo doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := cc.CopyScratchBuffer(s.SizeBytes()) // escapes: okay. + s.MarshalBytes(buf) // escapes: fallback. + return cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (s *SignalfdSiginfo) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return s.CopyOutN(cc, addr, s.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (s *SignalfdSiginfo) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Type SignalfdSiginfo doesn't have a packed layout in memory, fall back to UnmarshalBytes. + buf := cc.CopyScratchBuffer(s.SizeBytes()) // escapes: okay. + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Unmarshal unconditionally. If we had a short copy-in, this results in a + // partially unmarshalled struct. + s.UnmarshalBytes(buf) // escapes: fallback. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (s *SignalfdSiginfo) WriteTo(writer io.Writer) (int64, error) { + // Type SignalfdSiginfo doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := make([]byte, s.SizeBytes()) + s.MarshalBytes(buf) + length, err := writer.Write(buf) + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (c *ControlMessageCredentials) SizeBytes() int { + return 12 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (c *ControlMessageCredentials) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint32(dst[:4], uint32(c.PID)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(c.UID)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(c.GID)) + dst = dst[4:] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (c *ControlMessageCredentials) UnmarshalBytes(src []byte) { + c.PID = int32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + c.UID = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + c.GID = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (c *ControlMessageCredentials) Packed() bool { + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (c *ControlMessageCredentials) MarshalUnsafe(dst []byte) { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(c), uintptr(c.SizeBytes())) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (c *ControlMessageCredentials) UnmarshalUnsafe(src []byte) { + gohacks.Memmove(unsafe.Pointer(c), unsafe.Pointer(&src[0]), uintptr(c.SizeBytes())) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (c *ControlMessageCredentials) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(c))) + hdr.Len = c.SizeBytes() + hdr.Cap = c.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that c + // must live until the use above. + runtime.KeepAlive(c) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (c *ControlMessageCredentials) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return c.CopyOutN(cc, addr, c.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (c *ControlMessageCredentials) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(c))) + hdr.Len = c.SizeBytes() + hdr.Cap = c.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that c + // must live until the use above. + runtime.KeepAlive(c) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (c *ControlMessageCredentials) WriteTo(writer io.Writer) (int64, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(c))) + hdr.Len = c.SizeBytes() + hdr.Cap = c.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that c + // must live until the use above. + runtime.KeepAlive(c) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (c *ControlMessageHeader) SizeBytes() int { + return 16 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (c *ControlMessageHeader) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint64(dst[:8], uint64(c.Length)) + dst = dst[8:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(c.Level)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(c.Type)) + dst = dst[4:] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (c *ControlMessageHeader) UnmarshalBytes(src []byte) { + c.Length = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + c.Level = int32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + c.Type = int32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (c *ControlMessageHeader) Packed() bool { + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (c *ControlMessageHeader) MarshalUnsafe(dst []byte) { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(c), uintptr(c.SizeBytes())) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (c *ControlMessageHeader) UnmarshalUnsafe(src []byte) { + gohacks.Memmove(unsafe.Pointer(c), unsafe.Pointer(&src[0]), uintptr(c.SizeBytes())) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (c *ControlMessageHeader) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(c))) + hdr.Len = c.SizeBytes() + hdr.Cap = c.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that c + // must live until the use above. + runtime.KeepAlive(c) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (c *ControlMessageHeader) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return c.CopyOutN(cc, addr, c.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (c *ControlMessageHeader) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(c))) + hdr.Len = c.SizeBytes() + hdr.Cap = c.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that c + // must live until the use above. + runtime.KeepAlive(c) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (c *ControlMessageHeader) WriteTo(writer io.Writer) (int64, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(c))) + hdr.Len = c.SizeBytes() + hdr.Cap = c.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that c + // must live until the use above. + runtime.KeepAlive(c) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (c *ControlMessageIPPacketInfo) SizeBytes() int { + return 4 + + (*InetAddr)(nil).SizeBytes() + + (*InetAddr)(nil).SizeBytes() +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (c *ControlMessageIPPacketInfo) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint32(dst[:4], uint32(c.NIC)) + dst = dst[4:] + c.LocalAddr.MarshalBytes(dst[:c.LocalAddr.SizeBytes()]) + dst = dst[c.LocalAddr.SizeBytes():] + c.DestinationAddr.MarshalBytes(dst[:c.DestinationAddr.SizeBytes()]) + dst = dst[c.DestinationAddr.SizeBytes():] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (c *ControlMessageIPPacketInfo) UnmarshalBytes(src []byte) { + c.NIC = int32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + c.LocalAddr.UnmarshalBytes(src[:c.LocalAddr.SizeBytes()]) + src = src[c.LocalAddr.SizeBytes():] + c.DestinationAddr.UnmarshalBytes(src[:c.DestinationAddr.SizeBytes()]) + src = src[c.DestinationAddr.SizeBytes():] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (c *ControlMessageIPPacketInfo) Packed() bool { + return c.DestinationAddr.Packed() && c.LocalAddr.Packed() +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (c *ControlMessageIPPacketInfo) MarshalUnsafe(dst []byte) { + if c.DestinationAddr.Packed() && c.LocalAddr.Packed() { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(c), uintptr(c.SizeBytes())) + } else { + // Type ControlMessageIPPacketInfo doesn't have a packed layout in memory, fallback to MarshalBytes. + c.MarshalBytes(dst) + } +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (c *ControlMessageIPPacketInfo) UnmarshalUnsafe(src []byte) { + if c.DestinationAddr.Packed() && c.LocalAddr.Packed() { + gohacks.Memmove(unsafe.Pointer(c), unsafe.Pointer(&src[0]), uintptr(c.SizeBytes())) + } else { + // Type ControlMessageIPPacketInfo doesn't have a packed layout in memory, fallback to UnmarshalBytes. + c.UnmarshalBytes(src) + } +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (c *ControlMessageIPPacketInfo) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + if !c.DestinationAddr.Packed() && c.LocalAddr.Packed() { + // Type ControlMessageIPPacketInfo doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := cc.CopyScratchBuffer(c.SizeBytes()) // escapes: okay. + c.MarshalBytes(buf) // escapes: fallback. + return cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(c))) + hdr.Len = c.SizeBytes() + hdr.Cap = c.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that c + // must live until the use above. + runtime.KeepAlive(c) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (c *ControlMessageIPPacketInfo) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return c.CopyOutN(cc, addr, c.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (c *ControlMessageIPPacketInfo) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + if !c.DestinationAddr.Packed() && c.LocalAddr.Packed() { + // Type ControlMessageIPPacketInfo doesn't have a packed layout in memory, fall back to UnmarshalBytes. + buf := cc.CopyScratchBuffer(c.SizeBytes()) // escapes: okay. + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Unmarshal unconditionally. If we had a short copy-in, this results in a + // partially unmarshalled struct. + c.UnmarshalBytes(buf) // escapes: fallback. + return length, err + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(c))) + hdr.Len = c.SizeBytes() + hdr.Cap = c.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that c + // must live until the use above. + runtime.KeepAlive(c) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (c *ControlMessageIPPacketInfo) WriteTo(writer io.Writer) (int64, error) { + if !c.DestinationAddr.Packed() && c.LocalAddr.Packed() { + // Type ControlMessageIPPacketInfo doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := make([]byte, c.SizeBytes()) + c.MarshalBytes(buf) + length, err := writer.Write(buf) + return int64(length), err + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(c))) + hdr.Len = c.SizeBytes() + hdr.Cap = c.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that c + // must live until the use above. + runtime.KeepAlive(c) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +//go:nosplit +func (i *Inet6Addr) SizeBytes() int { + return 1 * 16 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (i *Inet6Addr) MarshalBytes(dst []byte) { + for idx := 0; idx < 16; idx++ { + dst[0] = byte(i[idx]) + dst = dst[1:] + } +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (i *Inet6Addr) UnmarshalBytes(src []byte) { + for idx := 0; idx < 16; idx++ { + i[idx] = src[0] + src = src[1:] + } +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (i *Inet6Addr) Packed() bool { + // Array newtypes are always packed. + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (i *Inet6Addr) MarshalUnsafe(dst []byte) { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(&i[0]), uintptr(i.SizeBytes())) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (i *Inet6Addr) UnmarshalUnsafe(src []byte) { + gohacks.Memmove(unsafe.Pointer(i), unsafe.Pointer(&src[0]), uintptr(i.SizeBytes())) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (i *Inet6Addr) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(i))) + hdr.Len = i.SizeBytes() + hdr.Cap = i.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that i + // must live until the use above. + runtime.KeepAlive(i) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (i *Inet6Addr) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return i.CopyOutN(cc, addr, i.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (i *Inet6Addr) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(i))) + hdr.Len = i.SizeBytes() + hdr.Cap = i.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that i + // must live until the use above. + runtime.KeepAlive(i) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (i *Inet6Addr) WriteTo(w io.Writer) (int64, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(i))) + hdr.Len = i.SizeBytes() + hdr.Cap = i.SizeBytes() + + length, err := w.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that i + // must live until the use above. + runtime.KeepAlive(i) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (i *Inet6MulticastRequest) SizeBytes() int { + return 4 + + (*Inet6Addr)(nil).SizeBytes() +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (i *Inet6MulticastRequest) MarshalBytes(dst []byte) { + i.MulticastAddr.MarshalBytes(dst[:i.MulticastAddr.SizeBytes()]) + dst = dst[i.MulticastAddr.SizeBytes():] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(i.InterfaceIndex)) + dst = dst[4:] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (i *Inet6MulticastRequest) UnmarshalBytes(src []byte) { + i.MulticastAddr.UnmarshalBytes(src[:i.MulticastAddr.SizeBytes()]) + src = src[i.MulticastAddr.SizeBytes():] + i.InterfaceIndex = int32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (i *Inet6MulticastRequest) Packed() bool { + return i.MulticastAddr.Packed() +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (i *Inet6MulticastRequest) MarshalUnsafe(dst []byte) { + if i.MulticastAddr.Packed() { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(i), uintptr(i.SizeBytes())) + } else { + // Type Inet6MulticastRequest doesn't have a packed layout in memory, fallback to MarshalBytes. + i.MarshalBytes(dst) + } +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (i *Inet6MulticastRequest) UnmarshalUnsafe(src []byte) { + if i.MulticastAddr.Packed() { + gohacks.Memmove(unsafe.Pointer(i), unsafe.Pointer(&src[0]), uintptr(i.SizeBytes())) + } else { + // Type Inet6MulticastRequest doesn't have a packed layout in memory, fallback to UnmarshalBytes. + i.UnmarshalBytes(src) + } +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (i *Inet6MulticastRequest) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + if !i.MulticastAddr.Packed() { + // Type Inet6MulticastRequest doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := cc.CopyScratchBuffer(i.SizeBytes()) // escapes: okay. + i.MarshalBytes(buf) // escapes: fallback. + return cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(i))) + hdr.Len = i.SizeBytes() + hdr.Cap = i.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that i + // must live until the use above. + runtime.KeepAlive(i) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (i *Inet6MulticastRequest) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return i.CopyOutN(cc, addr, i.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (i *Inet6MulticastRequest) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + if !i.MulticastAddr.Packed() { + // Type Inet6MulticastRequest doesn't have a packed layout in memory, fall back to UnmarshalBytes. + buf := cc.CopyScratchBuffer(i.SizeBytes()) // escapes: okay. + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Unmarshal unconditionally. If we had a short copy-in, this results in a + // partially unmarshalled struct. + i.UnmarshalBytes(buf) // escapes: fallback. + return length, err + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(i))) + hdr.Len = i.SizeBytes() + hdr.Cap = i.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that i + // must live until the use above. + runtime.KeepAlive(i) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (i *Inet6MulticastRequest) WriteTo(writer io.Writer) (int64, error) { + if !i.MulticastAddr.Packed() { + // Type Inet6MulticastRequest doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := make([]byte, i.SizeBytes()) + i.MarshalBytes(buf) + length, err := writer.Write(buf) + return int64(length), err + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(i))) + hdr.Len = i.SizeBytes() + hdr.Cap = i.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that i + // must live until the use above. + runtime.KeepAlive(i) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +//go:nosplit +func (i *InetAddr) SizeBytes() int { + return 1 * 4 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (i *InetAddr) MarshalBytes(dst []byte) { + for idx := 0; idx < 4; idx++ { + dst[0] = byte(i[idx]) + dst = dst[1:] + } +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (i *InetAddr) UnmarshalBytes(src []byte) { + for idx := 0; idx < 4; idx++ { + i[idx] = src[0] + src = src[1:] + } +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (i *InetAddr) Packed() bool { + // Array newtypes are always packed. + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (i *InetAddr) MarshalUnsafe(dst []byte) { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(&i[0]), uintptr(i.SizeBytes())) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (i *InetAddr) UnmarshalUnsafe(src []byte) { + gohacks.Memmove(unsafe.Pointer(i), unsafe.Pointer(&src[0]), uintptr(i.SizeBytes())) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (i *InetAddr) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(i))) + hdr.Len = i.SizeBytes() + hdr.Cap = i.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that i + // must live until the use above. + runtime.KeepAlive(i) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (i *InetAddr) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return i.CopyOutN(cc, addr, i.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (i *InetAddr) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(i))) + hdr.Len = i.SizeBytes() + hdr.Cap = i.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that i + // must live until the use above. + runtime.KeepAlive(i) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (i *InetAddr) WriteTo(w io.Writer) (int64, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(i))) + hdr.Len = i.SizeBytes() + hdr.Cap = i.SizeBytes() + + length, err := w.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that i + // must live until the use above. + runtime.KeepAlive(i) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (i *InetMulticastRequest) SizeBytes() int { + return 0 + + (*InetAddr)(nil).SizeBytes() + + (*InetAddr)(nil).SizeBytes() +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (i *InetMulticastRequest) MarshalBytes(dst []byte) { + i.MulticastAddr.MarshalBytes(dst[:i.MulticastAddr.SizeBytes()]) + dst = dst[i.MulticastAddr.SizeBytes():] + i.InterfaceAddr.MarshalBytes(dst[:i.InterfaceAddr.SizeBytes()]) + dst = dst[i.InterfaceAddr.SizeBytes():] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (i *InetMulticastRequest) UnmarshalBytes(src []byte) { + i.MulticastAddr.UnmarshalBytes(src[:i.MulticastAddr.SizeBytes()]) + src = src[i.MulticastAddr.SizeBytes():] + i.InterfaceAddr.UnmarshalBytes(src[:i.InterfaceAddr.SizeBytes()]) + src = src[i.InterfaceAddr.SizeBytes():] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (i *InetMulticastRequest) Packed() bool { + return i.InterfaceAddr.Packed() && i.MulticastAddr.Packed() +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (i *InetMulticastRequest) MarshalUnsafe(dst []byte) { + if i.InterfaceAddr.Packed() && i.MulticastAddr.Packed() { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(i), uintptr(i.SizeBytes())) + } else { + // Type InetMulticastRequest doesn't have a packed layout in memory, fallback to MarshalBytes. + i.MarshalBytes(dst) + } +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (i *InetMulticastRequest) UnmarshalUnsafe(src []byte) { + if i.InterfaceAddr.Packed() && i.MulticastAddr.Packed() { + gohacks.Memmove(unsafe.Pointer(i), unsafe.Pointer(&src[0]), uintptr(i.SizeBytes())) + } else { + // Type InetMulticastRequest doesn't have a packed layout in memory, fallback to UnmarshalBytes. + i.UnmarshalBytes(src) + } +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (i *InetMulticastRequest) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + if !i.InterfaceAddr.Packed() && i.MulticastAddr.Packed() { + // Type InetMulticastRequest doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := cc.CopyScratchBuffer(i.SizeBytes()) // escapes: okay. + i.MarshalBytes(buf) // escapes: fallback. + return cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(i))) + hdr.Len = i.SizeBytes() + hdr.Cap = i.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that i + // must live until the use above. + runtime.KeepAlive(i) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (i *InetMulticastRequest) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return i.CopyOutN(cc, addr, i.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (i *InetMulticastRequest) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + if !i.InterfaceAddr.Packed() && i.MulticastAddr.Packed() { + // Type InetMulticastRequest doesn't have a packed layout in memory, fall back to UnmarshalBytes. + buf := cc.CopyScratchBuffer(i.SizeBytes()) // escapes: okay. + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Unmarshal unconditionally. If we had a short copy-in, this results in a + // partially unmarshalled struct. + i.UnmarshalBytes(buf) // escapes: fallback. + return length, err + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(i))) + hdr.Len = i.SizeBytes() + hdr.Cap = i.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that i + // must live until the use above. + runtime.KeepAlive(i) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (i *InetMulticastRequest) WriteTo(writer io.Writer) (int64, error) { + if !i.InterfaceAddr.Packed() && i.MulticastAddr.Packed() { + // Type InetMulticastRequest doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := make([]byte, i.SizeBytes()) + i.MarshalBytes(buf) + length, err := writer.Write(buf) + return int64(length), err + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(i))) + hdr.Len = i.SizeBytes() + hdr.Cap = i.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that i + // must live until the use above. + runtime.KeepAlive(i) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (i *InetMulticastRequestWithNIC) SizeBytes() int { + return 4 + + (*InetMulticastRequest)(nil).SizeBytes() +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (i *InetMulticastRequestWithNIC) MarshalBytes(dst []byte) { + i.InetMulticastRequest.MarshalBytes(dst[:i.InetMulticastRequest.SizeBytes()]) + dst = dst[i.InetMulticastRequest.SizeBytes():] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(i.InterfaceIndex)) + dst = dst[4:] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (i *InetMulticastRequestWithNIC) UnmarshalBytes(src []byte) { + i.InetMulticastRequest.UnmarshalBytes(src[:i.InetMulticastRequest.SizeBytes()]) + src = src[i.InetMulticastRequest.SizeBytes():] + i.InterfaceIndex = int32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (i *InetMulticastRequestWithNIC) Packed() bool { + return i.InetMulticastRequest.Packed() +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (i *InetMulticastRequestWithNIC) MarshalUnsafe(dst []byte) { + if i.InetMulticastRequest.Packed() { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(i), uintptr(i.SizeBytes())) + } else { + // Type InetMulticastRequestWithNIC doesn't have a packed layout in memory, fallback to MarshalBytes. + i.MarshalBytes(dst) + } +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (i *InetMulticastRequestWithNIC) UnmarshalUnsafe(src []byte) { + if i.InetMulticastRequest.Packed() { + gohacks.Memmove(unsafe.Pointer(i), unsafe.Pointer(&src[0]), uintptr(i.SizeBytes())) + } else { + // Type InetMulticastRequestWithNIC doesn't have a packed layout in memory, fallback to UnmarshalBytes. + i.UnmarshalBytes(src) + } +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (i *InetMulticastRequestWithNIC) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + if !i.InetMulticastRequest.Packed() { + // Type InetMulticastRequestWithNIC doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := cc.CopyScratchBuffer(i.SizeBytes()) // escapes: okay. + i.MarshalBytes(buf) // escapes: fallback. + return cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(i))) + hdr.Len = i.SizeBytes() + hdr.Cap = i.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that i + // must live until the use above. + runtime.KeepAlive(i) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (i *InetMulticastRequestWithNIC) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return i.CopyOutN(cc, addr, i.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (i *InetMulticastRequestWithNIC) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + if !i.InetMulticastRequest.Packed() { + // Type InetMulticastRequestWithNIC doesn't have a packed layout in memory, fall back to UnmarshalBytes. + buf := cc.CopyScratchBuffer(i.SizeBytes()) // escapes: okay. + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Unmarshal unconditionally. If we had a short copy-in, this results in a + // partially unmarshalled struct. + i.UnmarshalBytes(buf) // escapes: fallback. + return length, err + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(i))) + hdr.Len = i.SizeBytes() + hdr.Cap = i.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that i + // must live until the use above. + runtime.KeepAlive(i) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (i *InetMulticastRequestWithNIC) WriteTo(writer io.Writer) (int64, error) { + if !i.InetMulticastRequest.Packed() { + // Type InetMulticastRequestWithNIC doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := make([]byte, i.SizeBytes()) + i.MarshalBytes(buf) + length, err := writer.Write(buf) + return int64(length), err + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(i))) + hdr.Len = i.SizeBytes() + hdr.Cap = i.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that i + // must live until the use above. + runtime.KeepAlive(i) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (l *Linger) SizeBytes() int { + return 8 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (l *Linger) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint32(dst[:4], uint32(l.OnOff)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(l.Linger)) + dst = dst[4:] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (l *Linger) UnmarshalBytes(src []byte) { + l.OnOff = int32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + l.Linger = int32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (l *Linger) Packed() bool { + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (l *Linger) MarshalUnsafe(dst []byte) { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(l), uintptr(l.SizeBytes())) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (l *Linger) UnmarshalUnsafe(src []byte) { + gohacks.Memmove(unsafe.Pointer(l), unsafe.Pointer(&src[0]), uintptr(l.SizeBytes())) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (l *Linger) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(l))) + hdr.Len = l.SizeBytes() + hdr.Cap = l.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that l + // must live until the use above. + runtime.KeepAlive(l) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (l *Linger) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return l.CopyOutN(cc, addr, l.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (l *Linger) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(l))) + hdr.Len = l.SizeBytes() + hdr.Cap = l.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that l + // must live until the use above. + runtime.KeepAlive(l) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (l *Linger) WriteTo(writer io.Writer) (int64, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(l))) + hdr.Len = l.SizeBytes() + hdr.Cap = l.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that l + // must live until the use above. + runtime.KeepAlive(l) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (s *SockAddrInet) SizeBytes() int { + return 4 + + (*InetAddr)(nil).SizeBytes() + + 1*8 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (s *SockAddrInet) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint16(dst[:2], uint16(s.Family)) + dst = dst[2:] + hostarch.ByteOrder.PutUint16(dst[:2], uint16(s.Port)) + dst = dst[2:] + s.Addr.MarshalBytes(dst[:s.Addr.SizeBytes()]) + dst = dst[s.Addr.SizeBytes():] + // Padding: dst[:sizeof(uint8)*8] ~= [8]uint8{0} + dst = dst[1*(8):] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (s *SockAddrInet) UnmarshalBytes(src []byte) { + s.Family = uint16(hostarch.ByteOrder.Uint16(src[:2])) + src = src[2:] + s.Port = uint16(hostarch.ByteOrder.Uint16(src[:2])) + src = src[2:] + s.Addr.UnmarshalBytes(src[:s.Addr.SizeBytes()]) + src = src[s.Addr.SizeBytes():] + // Padding: ~ copy([8]uint8(s._), src[:sizeof(uint8)*8]) + src = src[1*(8):] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (s *SockAddrInet) Packed() bool { + return s.Addr.Packed() +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (s *SockAddrInet) MarshalUnsafe(dst []byte) { + if s.Addr.Packed() { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(s), uintptr(s.SizeBytes())) + } else { + // Type SockAddrInet doesn't have a packed layout in memory, fallback to MarshalBytes. + s.MarshalBytes(dst) + } +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (s *SockAddrInet) UnmarshalUnsafe(src []byte) { + if s.Addr.Packed() { + gohacks.Memmove(unsafe.Pointer(s), unsafe.Pointer(&src[0]), uintptr(s.SizeBytes())) + } else { + // Type SockAddrInet doesn't have a packed layout in memory, fallback to UnmarshalBytes. + s.UnmarshalBytes(src) + } +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (s *SockAddrInet) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + if !s.Addr.Packed() { + // Type SockAddrInet doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := cc.CopyScratchBuffer(s.SizeBytes()) // escapes: okay. + s.MarshalBytes(buf) // escapes: fallback. + return cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(s))) + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until the use above. + runtime.KeepAlive(s) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (s *SockAddrInet) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return s.CopyOutN(cc, addr, s.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (s *SockAddrInet) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + if !s.Addr.Packed() { + // Type SockAddrInet doesn't have a packed layout in memory, fall back to UnmarshalBytes. + buf := cc.CopyScratchBuffer(s.SizeBytes()) // escapes: okay. + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Unmarshal unconditionally. If we had a short copy-in, this results in a + // partially unmarshalled struct. + s.UnmarshalBytes(buf) // escapes: fallback. + return length, err + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(s))) + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until the use above. + runtime.KeepAlive(s) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (s *SockAddrInet) WriteTo(writer io.Writer) (int64, error) { + if !s.Addr.Packed() { + // Type SockAddrInet doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := make([]byte, s.SizeBytes()) + s.MarshalBytes(buf) + length, err := writer.Write(buf) + return int64(length), err + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(s))) + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until the use above. + runtime.KeepAlive(s) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (s *SockAddrInet6) SizeBytes() int { + return 12 + + 1*16 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (s *SockAddrInet6) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint16(dst[:2], uint16(s.Family)) + dst = dst[2:] + hostarch.ByteOrder.PutUint16(dst[:2], uint16(s.Port)) + dst = dst[2:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(s.Flowinfo)) + dst = dst[4:] + for idx := 0; idx < 16; idx++ { + dst[0] = byte(s.Addr[idx]) + dst = dst[1:] + } + hostarch.ByteOrder.PutUint32(dst[:4], uint32(s.Scope_id)) + dst = dst[4:] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (s *SockAddrInet6) UnmarshalBytes(src []byte) { + s.Family = uint16(hostarch.ByteOrder.Uint16(src[:2])) + src = src[2:] + s.Port = uint16(hostarch.ByteOrder.Uint16(src[:2])) + src = src[2:] + s.Flowinfo = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + for idx := 0; idx < 16; idx++ { + s.Addr[idx] = src[0] + src = src[1:] + } + s.Scope_id = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (s *SockAddrInet6) Packed() bool { + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (s *SockAddrInet6) MarshalUnsafe(dst []byte) { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(s), uintptr(s.SizeBytes())) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (s *SockAddrInet6) UnmarshalUnsafe(src []byte) { + gohacks.Memmove(unsafe.Pointer(s), unsafe.Pointer(&src[0]), uintptr(s.SizeBytes())) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (s *SockAddrInet6) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(s))) + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until the use above. + runtime.KeepAlive(s) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (s *SockAddrInet6) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return s.CopyOutN(cc, addr, s.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (s *SockAddrInet6) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(s))) + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until the use above. + runtime.KeepAlive(s) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (s *SockAddrInet6) WriteTo(writer io.Writer) (int64, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(s))) + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until the use above. + runtime.KeepAlive(s) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (s *SockAddrLink) SizeBytes() int { + return 12 + + 1*8 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (s *SockAddrLink) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint16(dst[:2], uint16(s.Family)) + dst = dst[2:] + hostarch.ByteOrder.PutUint16(dst[:2], uint16(s.Protocol)) + dst = dst[2:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(s.InterfaceIndex)) + dst = dst[4:] + hostarch.ByteOrder.PutUint16(dst[:2], uint16(s.ARPHardwareType)) + dst = dst[2:] + dst[0] = byte(s.PacketType) + dst = dst[1:] + dst[0] = byte(s.HardwareAddrLen) + dst = dst[1:] + for idx := 0; idx < 8; idx++ { + dst[0] = byte(s.HardwareAddr[idx]) + dst = dst[1:] + } +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (s *SockAddrLink) UnmarshalBytes(src []byte) { + s.Family = uint16(hostarch.ByteOrder.Uint16(src[:2])) + src = src[2:] + s.Protocol = uint16(hostarch.ByteOrder.Uint16(src[:2])) + src = src[2:] + s.InterfaceIndex = int32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + s.ARPHardwareType = uint16(hostarch.ByteOrder.Uint16(src[:2])) + src = src[2:] + s.PacketType = src[0] + src = src[1:] + s.HardwareAddrLen = src[0] + src = src[1:] + for idx := 0; idx < 8; idx++ { + s.HardwareAddr[idx] = src[0] + src = src[1:] + } +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (s *SockAddrLink) Packed() bool { + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (s *SockAddrLink) MarshalUnsafe(dst []byte) { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(s), uintptr(s.SizeBytes())) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (s *SockAddrLink) UnmarshalUnsafe(src []byte) { + gohacks.Memmove(unsafe.Pointer(s), unsafe.Pointer(&src[0]), uintptr(s.SizeBytes())) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (s *SockAddrLink) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(s))) + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until the use above. + runtime.KeepAlive(s) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (s *SockAddrLink) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return s.CopyOutN(cc, addr, s.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (s *SockAddrLink) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(s))) + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until the use above. + runtime.KeepAlive(s) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (s *SockAddrLink) WriteTo(writer io.Writer) (int64, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(s))) + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until the use above. + runtime.KeepAlive(s) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (s *SockAddrUnix) SizeBytes() int { + return 2 + + 1*UnixPathMax +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (s *SockAddrUnix) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint16(dst[:2], uint16(s.Family)) + dst = dst[2:] + for idx := 0; idx < UnixPathMax; idx++ { + dst[0] = byte(s.Path[idx]) + dst = dst[1:] + } +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (s *SockAddrUnix) UnmarshalBytes(src []byte) { + s.Family = uint16(hostarch.ByteOrder.Uint16(src[:2])) + src = src[2:] + for idx := 0; idx < UnixPathMax; idx++ { + s.Path[idx] = int8(src[0]) + src = src[1:] + } +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (s *SockAddrUnix) Packed() bool { + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (s *SockAddrUnix) MarshalUnsafe(dst []byte) { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(s), uintptr(s.SizeBytes())) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (s *SockAddrUnix) UnmarshalUnsafe(src []byte) { + gohacks.Memmove(unsafe.Pointer(s), unsafe.Pointer(&src[0]), uintptr(s.SizeBytes())) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (s *SockAddrUnix) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(s))) + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until the use above. + runtime.KeepAlive(s) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (s *SockAddrUnix) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return s.CopyOutN(cc, addr, s.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (s *SockAddrUnix) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(s))) + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until the use above. + runtime.KeepAlive(s) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (s *SockAddrUnix) WriteTo(writer io.Writer) (int64, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(s))) + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until the use above. + runtime.KeepAlive(s) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (t *TCPInfo) SizeBytes() int { + return 224 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (t *TCPInfo) MarshalBytes(dst []byte) { + dst[0] = byte(t.State) + dst = dst[1:] + dst[0] = byte(t.CaState) + dst = dst[1:] + dst[0] = byte(t.Retransmits) + dst = dst[1:] + dst[0] = byte(t.Probes) + dst = dst[1:] + dst[0] = byte(t.Backoff) + dst = dst[1:] + dst[0] = byte(t.Options) + dst = dst[1:] + dst[0] = byte(t.WindowScale) + dst = dst[1:] + dst[0] = byte(t.DeliveryRateAppLimited) + dst = dst[1:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(t.RTO)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(t.ATO)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(t.SndMss)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(t.RcvMss)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(t.Unacked)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(t.Sacked)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(t.Lost)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(t.Retrans)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(t.Fackets)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(t.LastDataSent)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(t.LastAckSent)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(t.LastDataRecv)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(t.LastAckRecv)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(t.PMTU)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(t.RcvSsthresh)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(t.RTT)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(t.RTTVar)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(t.SndSsthresh)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(t.SndCwnd)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(t.Advmss)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(t.Reordering)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(t.RcvRTT)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(t.RcvSpace)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(t.TotalRetrans)) + dst = dst[4:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(t.PacingRate)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(t.MaxPacingRate)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(t.BytesAcked)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(t.BytesReceived)) + dst = dst[8:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(t.SegsOut)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(t.SegsIn)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(t.NotSentBytes)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(t.MinRTT)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(t.DataSegsIn)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(t.DataSegsOut)) + dst = dst[4:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(t.DeliveryRate)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(t.BusyTime)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(t.RwndLimited)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(t.SndBufLimited)) + dst = dst[8:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(t.Delivered)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(t.DeliveredCE)) + dst = dst[4:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(t.BytesSent)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(t.BytesRetrans)) + dst = dst[8:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(t.DSACKDups)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(t.ReordSeen)) + dst = dst[4:] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (t *TCPInfo) UnmarshalBytes(src []byte) { + t.State = uint8(src[0]) + src = src[1:] + t.CaState = uint8(src[0]) + src = src[1:] + t.Retransmits = uint8(src[0]) + src = src[1:] + t.Probes = uint8(src[0]) + src = src[1:] + t.Backoff = uint8(src[0]) + src = src[1:] + t.Options = uint8(src[0]) + src = src[1:] + t.WindowScale = uint8(src[0]) + src = src[1:] + t.DeliveryRateAppLimited = uint8(src[0]) + src = src[1:] + t.RTO = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + t.ATO = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + t.SndMss = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + t.RcvMss = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + t.Unacked = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + t.Sacked = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + t.Lost = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + t.Retrans = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + t.Fackets = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + t.LastDataSent = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + t.LastAckSent = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + t.LastDataRecv = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + t.LastAckRecv = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + t.PMTU = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + t.RcvSsthresh = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + t.RTT = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + t.RTTVar = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + t.SndSsthresh = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + t.SndCwnd = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + t.Advmss = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + t.Reordering = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + t.RcvRTT = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + t.RcvSpace = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + t.TotalRetrans = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + t.PacingRate = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + t.MaxPacingRate = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + t.BytesAcked = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + t.BytesReceived = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + t.SegsOut = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + t.SegsIn = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + t.NotSentBytes = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + t.MinRTT = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + t.DataSegsIn = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + t.DataSegsOut = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + t.DeliveryRate = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + t.BusyTime = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + t.RwndLimited = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + t.SndBufLimited = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + t.Delivered = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + t.DeliveredCE = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + t.BytesSent = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + t.BytesRetrans = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + t.DSACKDups = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + t.ReordSeen = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (t *TCPInfo) Packed() bool { + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (t *TCPInfo) MarshalUnsafe(dst []byte) { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(t), uintptr(t.SizeBytes())) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (t *TCPInfo) UnmarshalUnsafe(src []byte) { + gohacks.Memmove(unsafe.Pointer(t), unsafe.Pointer(&src[0]), uintptr(t.SizeBytes())) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (t *TCPInfo) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(t))) + hdr.Len = t.SizeBytes() + hdr.Cap = t.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that t + // must live until the use above. + runtime.KeepAlive(t) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (t *TCPInfo) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return t.CopyOutN(cc, addr, t.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (t *TCPInfo) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(t))) + hdr.Len = t.SizeBytes() + hdr.Cap = t.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that t + // must live until the use above. + runtime.KeepAlive(t) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (t *TCPInfo) WriteTo(writer io.Writer) (int64, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(t))) + hdr.Len = t.SizeBytes() + hdr.Cap = t.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that t + // must live until the use above. + runtime.KeepAlive(t) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +//go:nosplit +func (c *ClockT) SizeBytes() int { + return 8 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (c *ClockT) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint64(dst[:8], uint64(*c)) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (c *ClockT) UnmarshalBytes(src []byte) { + *c = ClockT(int64(hostarch.ByteOrder.Uint64(src[:8]))) +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (c *ClockT) Packed() bool { + // Scalar newtypes are always packed. + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (c *ClockT) MarshalUnsafe(dst []byte) { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(c), uintptr(c.SizeBytes())) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (c *ClockT) UnmarshalUnsafe(src []byte) { + gohacks.Memmove(unsafe.Pointer(c), unsafe.Pointer(&src[0]), uintptr(c.SizeBytes())) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (c *ClockT) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(c))) + hdr.Len = c.SizeBytes() + hdr.Cap = c.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that c + // must live until the use above. + runtime.KeepAlive(c) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (c *ClockT) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return c.CopyOutN(cc, addr, c.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (c *ClockT) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(c))) + hdr.Len = c.SizeBytes() + hdr.Cap = c.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that c + // must live until the use above. + runtime.KeepAlive(c) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (c *ClockT) WriteTo(w io.Writer) (int64, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(c))) + hdr.Len = c.SizeBytes() + hdr.Cap = c.SizeBytes() + + length, err := w.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that c + // must live until the use above. + runtime.KeepAlive(c) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (i *ItimerVal) SizeBytes() int { + return 0 + + (*Timeval)(nil).SizeBytes() + + (*Timeval)(nil).SizeBytes() +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (i *ItimerVal) MarshalBytes(dst []byte) { + i.Interval.MarshalBytes(dst[:i.Interval.SizeBytes()]) + dst = dst[i.Interval.SizeBytes():] + i.Value.MarshalBytes(dst[:i.Value.SizeBytes()]) + dst = dst[i.Value.SizeBytes():] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (i *ItimerVal) UnmarshalBytes(src []byte) { + i.Interval.UnmarshalBytes(src[:i.Interval.SizeBytes()]) + src = src[i.Interval.SizeBytes():] + i.Value.UnmarshalBytes(src[:i.Value.SizeBytes()]) + src = src[i.Value.SizeBytes():] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (i *ItimerVal) Packed() bool { + return i.Interval.Packed() && i.Value.Packed() +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (i *ItimerVal) MarshalUnsafe(dst []byte) { + if i.Interval.Packed() && i.Value.Packed() { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(i), uintptr(i.SizeBytes())) + } else { + // Type ItimerVal doesn't have a packed layout in memory, fallback to MarshalBytes. + i.MarshalBytes(dst) + } +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (i *ItimerVal) UnmarshalUnsafe(src []byte) { + if i.Interval.Packed() && i.Value.Packed() { + gohacks.Memmove(unsafe.Pointer(i), unsafe.Pointer(&src[0]), uintptr(i.SizeBytes())) + } else { + // Type ItimerVal doesn't have a packed layout in memory, fallback to UnmarshalBytes. + i.UnmarshalBytes(src) + } +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (i *ItimerVal) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + if !i.Interval.Packed() && i.Value.Packed() { + // Type ItimerVal doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := cc.CopyScratchBuffer(i.SizeBytes()) // escapes: okay. + i.MarshalBytes(buf) // escapes: fallback. + return cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(i))) + hdr.Len = i.SizeBytes() + hdr.Cap = i.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that i + // must live until the use above. + runtime.KeepAlive(i) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (i *ItimerVal) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return i.CopyOutN(cc, addr, i.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (i *ItimerVal) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + if !i.Interval.Packed() && i.Value.Packed() { + // Type ItimerVal doesn't have a packed layout in memory, fall back to UnmarshalBytes. + buf := cc.CopyScratchBuffer(i.SizeBytes()) // escapes: okay. + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Unmarshal unconditionally. If we had a short copy-in, this results in a + // partially unmarshalled struct. + i.UnmarshalBytes(buf) // escapes: fallback. + return length, err + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(i))) + hdr.Len = i.SizeBytes() + hdr.Cap = i.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that i + // must live until the use above. + runtime.KeepAlive(i) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (i *ItimerVal) WriteTo(writer io.Writer) (int64, error) { + if !i.Interval.Packed() && i.Value.Packed() { + // Type ItimerVal doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := make([]byte, i.SizeBytes()) + i.MarshalBytes(buf) + length, err := writer.Write(buf) + return int64(length), err + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(i))) + hdr.Len = i.SizeBytes() + hdr.Cap = i.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that i + // must live until the use above. + runtime.KeepAlive(i) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (i *Itimerspec) SizeBytes() int { + return 0 + + (*Timespec)(nil).SizeBytes() + + (*Timespec)(nil).SizeBytes() +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (i *Itimerspec) MarshalBytes(dst []byte) { + i.Interval.MarshalBytes(dst[:i.Interval.SizeBytes()]) + dst = dst[i.Interval.SizeBytes():] + i.Value.MarshalBytes(dst[:i.Value.SizeBytes()]) + dst = dst[i.Value.SizeBytes():] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (i *Itimerspec) UnmarshalBytes(src []byte) { + i.Interval.UnmarshalBytes(src[:i.Interval.SizeBytes()]) + src = src[i.Interval.SizeBytes():] + i.Value.UnmarshalBytes(src[:i.Value.SizeBytes()]) + src = src[i.Value.SizeBytes():] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (i *Itimerspec) Packed() bool { + return i.Interval.Packed() && i.Value.Packed() +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (i *Itimerspec) MarshalUnsafe(dst []byte) { + if i.Interval.Packed() && i.Value.Packed() { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(i), uintptr(i.SizeBytes())) + } else { + // Type Itimerspec doesn't have a packed layout in memory, fallback to MarshalBytes. + i.MarshalBytes(dst) + } +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (i *Itimerspec) UnmarshalUnsafe(src []byte) { + if i.Interval.Packed() && i.Value.Packed() { + gohacks.Memmove(unsafe.Pointer(i), unsafe.Pointer(&src[0]), uintptr(i.SizeBytes())) + } else { + // Type Itimerspec doesn't have a packed layout in memory, fallback to UnmarshalBytes. + i.UnmarshalBytes(src) + } +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (i *Itimerspec) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + if !i.Interval.Packed() && i.Value.Packed() { + // Type Itimerspec doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := cc.CopyScratchBuffer(i.SizeBytes()) // escapes: okay. + i.MarshalBytes(buf) // escapes: fallback. + return cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(i))) + hdr.Len = i.SizeBytes() + hdr.Cap = i.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that i + // must live until the use above. + runtime.KeepAlive(i) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (i *Itimerspec) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return i.CopyOutN(cc, addr, i.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (i *Itimerspec) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + if !i.Interval.Packed() && i.Value.Packed() { + // Type Itimerspec doesn't have a packed layout in memory, fall back to UnmarshalBytes. + buf := cc.CopyScratchBuffer(i.SizeBytes()) // escapes: okay. + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Unmarshal unconditionally. If we had a short copy-in, this results in a + // partially unmarshalled struct. + i.UnmarshalBytes(buf) // escapes: fallback. + return length, err + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(i))) + hdr.Len = i.SizeBytes() + hdr.Cap = i.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that i + // must live until the use above. + runtime.KeepAlive(i) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (i *Itimerspec) WriteTo(writer io.Writer) (int64, error) { + if !i.Interval.Packed() && i.Value.Packed() { + // Type Itimerspec doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := make([]byte, i.SizeBytes()) + i.MarshalBytes(buf) + length, err := writer.Write(buf) + return int64(length), err + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(i))) + hdr.Len = i.SizeBytes() + hdr.Cap = i.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that i + // must live until the use above. + runtime.KeepAlive(i) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (sxts *StatxTimestamp) SizeBytes() int { + return 16 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (sxts *StatxTimestamp) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint64(dst[:8], uint64(sxts.Sec)) + dst = dst[8:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(sxts.Nsec)) + dst = dst[4:] + // Padding: dst[:sizeof(int32)] ~= int32(0) + dst = dst[4:] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (sxts *StatxTimestamp) UnmarshalBytes(src []byte) { + sxts.Sec = int64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + sxts.Nsec = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + // Padding: var _ int32 ~= src[:sizeof(int32)] + src = src[4:] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (sxts *StatxTimestamp) Packed() bool { + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (sxts *StatxTimestamp) MarshalUnsafe(dst []byte) { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(sxts), uintptr(sxts.SizeBytes())) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (sxts *StatxTimestamp) UnmarshalUnsafe(src []byte) { + gohacks.Memmove(unsafe.Pointer(sxts), unsafe.Pointer(&src[0]), uintptr(sxts.SizeBytes())) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (sxts *StatxTimestamp) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(sxts))) + hdr.Len = sxts.SizeBytes() + hdr.Cap = sxts.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that sxts + // must live until the use above. + runtime.KeepAlive(sxts) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (sxts *StatxTimestamp) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return sxts.CopyOutN(cc, addr, sxts.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (sxts *StatxTimestamp) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(sxts))) + hdr.Len = sxts.SizeBytes() + hdr.Cap = sxts.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that sxts + // must live until the use above. + runtime.KeepAlive(sxts) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (sxts *StatxTimestamp) WriteTo(writer io.Writer) (int64, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(sxts))) + hdr.Len = sxts.SizeBytes() + hdr.Cap = sxts.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that sxts + // must live until the use above. + runtime.KeepAlive(sxts) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +//go:nosplit +func (t *TimeT) SizeBytes() int { + return 8 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (t *TimeT) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint64(dst[:8], uint64(*t)) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (t *TimeT) UnmarshalBytes(src []byte) { + *t = TimeT(int64(hostarch.ByteOrder.Uint64(src[:8]))) +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (t *TimeT) Packed() bool { + // Scalar newtypes are always packed. + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (t *TimeT) MarshalUnsafe(dst []byte) { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(t), uintptr(t.SizeBytes())) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (t *TimeT) UnmarshalUnsafe(src []byte) { + gohacks.Memmove(unsafe.Pointer(t), unsafe.Pointer(&src[0]), uintptr(t.SizeBytes())) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (t *TimeT) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(t))) + hdr.Len = t.SizeBytes() + hdr.Cap = t.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that t + // must live until the use above. + runtime.KeepAlive(t) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (t *TimeT) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return t.CopyOutN(cc, addr, t.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (t *TimeT) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(t))) + hdr.Len = t.SizeBytes() + hdr.Cap = t.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that t + // must live until the use above. + runtime.KeepAlive(t) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (t *TimeT) WriteTo(w io.Writer) (int64, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(t))) + hdr.Len = t.SizeBytes() + hdr.Cap = t.SizeBytes() + + length, err := w.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that t + // must live until the use above. + runtime.KeepAlive(t) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +//go:nosplit +func (t *TimerID) SizeBytes() int { + return 4 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (t *TimerID) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint32(dst[:4], uint32(*t)) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (t *TimerID) UnmarshalBytes(src []byte) { + *t = TimerID(int32(hostarch.ByteOrder.Uint32(src[:4]))) +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (t *TimerID) Packed() bool { + // Scalar newtypes are always packed. + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (t *TimerID) MarshalUnsafe(dst []byte) { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(t), uintptr(t.SizeBytes())) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (t *TimerID) UnmarshalUnsafe(src []byte) { + gohacks.Memmove(unsafe.Pointer(t), unsafe.Pointer(&src[0]), uintptr(t.SizeBytes())) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (t *TimerID) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(t))) + hdr.Len = t.SizeBytes() + hdr.Cap = t.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that t + // must live until the use above. + runtime.KeepAlive(t) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (t *TimerID) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return t.CopyOutN(cc, addr, t.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (t *TimerID) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(t))) + hdr.Len = t.SizeBytes() + hdr.Cap = t.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that t + // must live until the use above. + runtime.KeepAlive(t) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (t *TimerID) WriteTo(w io.Writer) (int64, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(t))) + hdr.Len = t.SizeBytes() + hdr.Cap = t.SizeBytes() + + length, err := w.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that t + // must live until the use above. + runtime.KeepAlive(t) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (ts *Timespec) SizeBytes() int { + return 16 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (ts *Timespec) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint64(dst[:8], uint64(ts.Sec)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(ts.Nsec)) + dst = dst[8:] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (ts *Timespec) UnmarshalBytes(src []byte) { + ts.Sec = int64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + ts.Nsec = int64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (ts *Timespec) Packed() bool { + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (ts *Timespec) MarshalUnsafe(dst []byte) { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(ts), uintptr(ts.SizeBytes())) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (ts *Timespec) UnmarshalUnsafe(src []byte) { + gohacks.Memmove(unsafe.Pointer(ts), unsafe.Pointer(&src[0]), uintptr(ts.SizeBytes())) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (ts *Timespec) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(ts))) + hdr.Len = ts.SizeBytes() + hdr.Cap = ts.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that ts + // must live until the use above. + runtime.KeepAlive(ts) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (ts *Timespec) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return ts.CopyOutN(cc, addr, ts.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (ts *Timespec) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(ts))) + hdr.Len = ts.SizeBytes() + hdr.Cap = ts.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that ts + // must live until the use above. + runtime.KeepAlive(ts) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (ts *Timespec) WriteTo(writer io.Writer) (int64, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(ts))) + hdr.Len = ts.SizeBytes() + hdr.Cap = ts.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that ts + // must live until the use above. + runtime.KeepAlive(ts) // escapes: replaced by intrinsic. + return int64(length), err +} + +// CopyTimespecSliceIn copies in a slice of Timespec objects from the task's memory. +func CopyTimespecSliceIn(cc marshal.CopyContext, addr hostarch.Addr, dst []Timespec) (int, error) { + count := len(dst) + if count == 0 { + return 0, nil + } + size := (*Timespec)(nil).SizeBytes() + + ptr := unsafe.Pointer(&dst) + val := gohacks.Noescape(unsafe.Pointer((*reflect.SliceHeader)(ptr).Data)) + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(val) + hdr.Len = size * count + hdr.Cap = size * count + + length, err := cc.CopyInBytes(addr, buf) + // Since we bypassed the compiler's escape analysis, indicate that dst + // must live until the use above. + runtime.KeepAlive(dst) // escapes: replaced by intrinsic. + return length, err +} + +// CopyTimespecSliceOut copies a slice of Timespec objects to the task's memory. +func CopyTimespecSliceOut(cc marshal.CopyContext, addr hostarch.Addr, src []Timespec) (int, error) { + count := len(src) + if count == 0 { + return 0, nil + } + size := (*Timespec)(nil).SizeBytes() + + ptr := unsafe.Pointer(&src) + val := gohacks.Noescape(unsafe.Pointer((*reflect.SliceHeader)(ptr).Data)) + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(val) + hdr.Len = size * count + hdr.Cap = size * count + + length, err := cc.CopyOutBytes(addr, buf) + // Since we bypassed the compiler's escape analysis, indicate that src + // must live until the use above. + runtime.KeepAlive(src) // escapes: replaced by intrinsic. + return length, err +} + +// MarshalUnsafeTimespecSlice is like Timespec.MarshalUnsafe, but for a []Timespec. +func MarshalUnsafeTimespecSlice(src []Timespec, dst []byte) (int, error) { + count := len(src) + if count == 0 { + return 0, nil + } + size := (*Timespec)(nil).SizeBytes() + + dst = dst[:size*count] + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(&src[0]), uintptr(len(dst))) + return size * count, nil +} + +// UnmarshalUnsafeTimespecSlice is like Timespec.UnmarshalUnsafe, but for a []Timespec. +func UnmarshalUnsafeTimespecSlice(dst []Timespec, src []byte) (int, error) { + count := len(dst) + if count == 0 { + return 0, nil + } + size := (*Timespec)(nil).SizeBytes() + + src = src[:(size*count)] + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(&src[0]), uintptr(len(src))) + return count*size, nil +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (tv *Timeval) SizeBytes() int { + return 16 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (tv *Timeval) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint64(dst[:8], uint64(tv.Sec)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(tv.Usec)) + dst = dst[8:] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (tv *Timeval) UnmarshalBytes(src []byte) { + tv.Sec = int64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + tv.Usec = int64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (tv *Timeval) Packed() bool { + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (tv *Timeval) MarshalUnsafe(dst []byte) { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(tv), uintptr(tv.SizeBytes())) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (tv *Timeval) UnmarshalUnsafe(src []byte) { + gohacks.Memmove(unsafe.Pointer(tv), unsafe.Pointer(&src[0]), uintptr(tv.SizeBytes())) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (tv *Timeval) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(tv))) + hdr.Len = tv.SizeBytes() + hdr.Cap = tv.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that tv + // must live until the use above. + runtime.KeepAlive(tv) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (tv *Timeval) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return tv.CopyOutN(cc, addr, tv.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (tv *Timeval) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(tv))) + hdr.Len = tv.SizeBytes() + hdr.Cap = tv.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that tv + // must live until the use above. + runtime.KeepAlive(tv) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (tv *Timeval) WriteTo(writer io.Writer) (int64, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(tv))) + hdr.Len = tv.SizeBytes() + hdr.Cap = tv.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that tv + // must live until the use above. + runtime.KeepAlive(tv) // escapes: replaced by intrinsic. + return int64(length), err +} + +// CopyTimevalSliceIn copies in a slice of Timeval objects from the task's memory. +func CopyTimevalSliceIn(cc marshal.CopyContext, addr hostarch.Addr, dst []Timeval) (int, error) { + count := len(dst) + if count == 0 { + return 0, nil + } + size := (*Timeval)(nil).SizeBytes() + + ptr := unsafe.Pointer(&dst) + val := gohacks.Noescape(unsafe.Pointer((*reflect.SliceHeader)(ptr).Data)) + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(val) + hdr.Len = size * count + hdr.Cap = size * count + + length, err := cc.CopyInBytes(addr, buf) + // Since we bypassed the compiler's escape analysis, indicate that dst + // must live until the use above. + runtime.KeepAlive(dst) // escapes: replaced by intrinsic. + return length, err +} + +// CopyTimevalSliceOut copies a slice of Timeval objects to the task's memory. +func CopyTimevalSliceOut(cc marshal.CopyContext, addr hostarch.Addr, src []Timeval) (int, error) { + count := len(src) + if count == 0 { + return 0, nil + } + size := (*Timeval)(nil).SizeBytes() + + ptr := unsafe.Pointer(&src) + val := gohacks.Noescape(unsafe.Pointer((*reflect.SliceHeader)(ptr).Data)) + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(val) + hdr.Len = size * count + hdr.Cap = size * count + + length, err := cc.CopyOutBytes(addr, buf) + // Since we bypassed the compiler's escape analysis, indicate that src + // must live until the use above. + runtime.KeepAlive(src) // escapes: replaced by intrinsic. + return length, err +} + +// MarshalUnsafeTimevalSlice is like Timeval.MarshalUnsafe, but for a []Timeval. +func MarshalUnsafeTimevalSlice(src []Timeval, dst []byte) (int, error) { + count := len(src) + if count == 0 { + return 0, nil + } + size := (*Timeval)(nil).SizeBytes() + + dst = dst[:size*count] + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(&src[0]), uintptr(len(dst))) + return size * count, nil +} + +// UnmarshalUnsafeTimevalSlice is like Timeval.UnmarshalUnsafe, but for a []Timeval. +func UnmarshalUnsafeTimevalSlice(dst []Timeval, src []byte) (int, error) { + count := len(dst) + if count == 0 { + return 0, nil + } + size := (*Timeval)(nil).SizeBytes() + + src = src[:(size*count)] + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(&src[0]), uintptr(len(src))) + return count*size, nil +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (t *Tms) SizeBytes() int { + return 0 + + (*ClockT)(nil).SizeBytes() + + (*ClockT)(nil).SizeBytes() + + (*ClockT)(nil).SizeBytes() + + (*ClockT)(nil).SizeBytes() +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (t *Tms) MarshalBytes(dst []byte) { + t.UTime.MarshalBytes(dst[:t.UTime.SizeBytes()]) + dst = dst[t.UTime.SizeBytes():] + t.STime.MarshalBytes(dst[:t.STime.SizeBytes()]) + dst = dst[t.STime.SizeBytes():] + t.CUTime.MarshalBytes(dst[:t.CUTime.SizeBytes()]) + dst = dst[t.CUTime.SizeBytes():] + t.CSTime.MarshalBytes(dst[:t.CSTime.SizeBytes()]) + dst = dst[t.CSTime.SizeBytes():] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (t *Tms) UnmarshalBytes(src []byte) { + t.UTime.UnmarshalBytes(src[:t.UTime.SizeBytes()]) + src = src[t.UTime.SizeBytes():] + t.STime.UnmarshalBytes(src[:t.STime.SizeBytes()]) + src = src[t.STime.SizeBytes():] + t.CUTime.UnmarshalBytes(src[:t.CUTime.SizeBytes()]) + src = src[t.CUTime.SizeBytes():] + t.CSTime.UnmarshalBytes(src[:t.CSTime.SizeBytes()]) + src = src[t.CSTime.SizeBytes():] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (t *Tms) Packed() bool { + return t.CSTime.Packed() && t.CUTime.Packed() && t.STime.Packed() && t.UTime.Packed() +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (t *Tms) MarshalUnsafe(dst []byte) { + if t.CSTime.Packed() && t.CUTime.Packed() && t.STime.Packed() && t.UTime.Packed() { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(t), uintptr(t.SizeBytes())) + } else { + // Type Tms doesn't have a packed layout in memory, fallback to MarshalBytes. + t.MarshalBytes(dst) + } +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (t *Tms) UnmarshalUnsafe(src []byte) { + if t.CSTime.Packed() && t.CUTime.Packed() && t.STime.Packed() && t.UTime.Packed() { + gohacks.Memmove(unsafe.Pointer(t), unsafe.Pointer(&src[0]), uintptr(t.SizeBytes())) + } else { + // Type Tms doesn't have a packed layout in memory, fallback to UnmarshalBytes. + t.UnmarshalBytes(src) + } +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (t *Tms) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + if !t.CSTime.Packed() && t.CUTime.Packed() && t.STime.Packed() && t.UTime.Packed() { + // Type Tms doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := cc.CopyScratchBuffer(t.SizeBytes()) // escapes: okay. + t.MarshalBytes(buf) // escapes: fallback. + return cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(t))) + hdr.Len = t.SizeBytes() + hdr.Cap = t.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that t + // must live until the use above. + runtime.KeepAlive(t) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (t *Tms) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return t.CopyOutN(cc, addr, t.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (t *Tms) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + if !t.CSTime.Packed() && t.CUTime.Packed() && t.STime.Packed() && t.UTime.Packed() { + // Type Tms doesn't have a packed layout in memory, fall back to UnmarshalBytes. + buf := cc.CopyScratchBuffer(t.SizeBytes()) // escapes: okay. + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Unmarshal unconditionally. If we had a short copy-in, this results in a + // partially unmarshalled struct. + t.UnmarshalBytes(buf) // escapes: fallback. + return length, err + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(t))) + hdr.Len = t.SizeBytes() + hdr.Cap = t.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that t + // must live until the use above. + runtime.KeepAlive(t) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (t *Tms) WriteTo(writer io.Writer) (int64, error) { + if !t.CSTime.Packed() && t.CUTime.Packed() && t.STime.Packed() && t.UTime.Packed() { + // Type Tms doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := make([]byte, t.SizeBytes()) + t.MarshalBytes(buf) + length, err := writer.Write(buf) + return int64(length), err + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(t))) + hdr.Len = t.SizeBytes() + hdr.Cap = t.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that t + // must live until the use above. + runtime.KeepAlive(t) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (u *Utime) SizeBytes() int { + return 16 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (u *Utime) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint64(dst[:8], uint64(u.Actime)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(u.Modtime)) + dst = dst[8:] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (u *Utime) UnmarshalBytes(src []byte) { + u.Actime = int64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + u.Modtime = int64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (u *Utime) Packed() bool { + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (u *Utime) MarshalUnsafe(dst []byte) { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(u), uintptr(u.SizeBytes())) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (u *Utime) UnmarshalUnsafe(src []byte) { + gohacks.Memmove(unsafe.Pointer(u), unsafe.Pointer(&src[0]), uintptr(u.SizeBytes())) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (u *Utime) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(u))) + hdr.Len = u.SizeBytes() + hdr.Cap = u.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that u + // must live until the use above. + runtime.KeepAlive(u) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (u *Utime) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return u.CopyOutN(cc, addr, u.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (u *Utime) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(u))) + hdr.Len = u.SizeBytes() + hdr.Cap = u.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that u + // must live until the use above. + runtime.KeepAlive(u) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (u *Utime) WriteTo(writer io.Writer) (int64, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(u))) + hdr.Len = u.SizeBytes() + hdr.Cap = u.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that u + // must live until the use above. + runtime.KeepAlive(u) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (t *Termios) SizeBytes() int { + return 17 + + 1*NumControlCharacters +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (t *Termios) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint32(dst[:4], uint32(t.InputFlags)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(t.OutputFlags)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(t.ControlFlags)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(t.LocalFlags)) + dst = dst[4:] + dst[0] = byte(t.LineDiscipline) + dst = dst[1:] + for idx := 0; idx < NumControlCharacters; idx++ { + dst[0] = byte(t.ControlCharacters[idx]) + dst = dst[1:] + } +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (t *Termios) UnmarshalBytes(src []byte) { + t.InputFlags = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + t.OutputFlags = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + t.ControlFlags = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + t.LocalFlags = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + t.LineDiscipline = uint8(src[0]) + src = src[1:] + for idx := 0; idx < NumControlCharacters; idx++ { + t.ControlCharacters[idx] = uint8(src[0]) + src = src[1:] + } +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (t *Termios) Packed() bool { + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (t *Termios) MarshalUnsafe(dst []byte) { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(t), uintptr(t.SizeBytes())) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (t *Termios) UnmarshalUnsafe(src []byte) { + gohacks.Memmove(unsafe.Pointer(t), unsafe.Pointer(&src[0]), uintptr(t.SizeBytes())) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (t *Termios) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(t))) + hdr.Len = t.SizeBytes() + hdr.Cap = t.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that t + // must live until the use above. + runtime.KeepAlive(t) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (t *Termios) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return t.CopyOutN(cc, addr, t.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (t *Termios) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(t))) + hdr.Len = t.SizeBytes() + hdr.Cap = t.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that t + // must live until the use above. + runtime.KeepAlive(t) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (t *Termios) WriteTo(writer io.Writer) (int64, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(t))) + hdr.Len = t.SizeBytes() + hdr.Cap = t.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that t + // must live until the use above. + runtime.KeepAlive(t) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (w *WindowSize) SizeBytes() int { + return 4 + + 1*4 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (w *WindowSize) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint16(dst[:2], uint16(w.Rows)) + dst = dst[2:] + hostarch.ByteOrder.PutUint16(dst[:2], uint16(w.Cols)) + dst = dst[2:] + // Padding: dst[:sizeof(byte)*4] ~= [4]byte{0} + dst = dst[1*(4):] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (w *WindowSize) UnmarshalBytes(src []byte) { + w.Rows = uint16(hostarch.ByteOrder.Uint16(src[:2])) + src = src[2:] + w.Cols = uint16(hostarch.ByteOrder.Uint16(src[:2])) + src = src[2:] + // Padding: ~ copy([4]byte(w._), src[:sizeof(byte)*4]) + src = src[1*(4):] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (w *WindowSize) Packed() bool { + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (w *WindowSize) MarshalUnsafe(dst []byte) { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(w), uintptr(w.SizeBytes())) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (w *WindowSize) UnmarshalUnsafe(src []byte) { + gohacks.Memmove(unsafe.Pointer(w), unsafe.Pointer(&src[0]), uintptr(w.SizeBytes())) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (w *WindowSize) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(w))) + hdr.Len = w.SizeBytes() + hdr.Cap = w.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that w + // must live until the use above. + runtime.KeepAlive(w) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (w *WindowSize) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return w.CopyOutN(cc, addr, w.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (w *WindowSize) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(w))) + hdr.Len = w.SizeBytes() + hdr.Cap = w.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that w + // must live until the use above. + runtime.KeepAlive(w) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (w *WindowSize) WriteTo(writer io.Writer) (int64, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(w))) + hdr.Len = w.SizeBytes() + hdr.Cap = w.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that w + // must live until the use above. + runtime.KeepAlive(w) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (w *Winsize) SizeBytes() int { + return 8 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (w *Winsize) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint16(dst[:2], uint16(w.Row)) + dst = dst[2:] + hostarch.ByteOrder.PutUint16(dst[:2], uint16(w.Col)) + dst = dst[2:] + hostarch.ByteOrder.PutUint16(dst[:2], uint16(w.Xpixel)) + dst = dst[2:] + hostarch.ByteOrder.PutUint16(dst[:2], uint16(w.Ypixel)) + dst = dst[2:] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (w *Winsize) UnmarshalBytes(src []byte) { + w.Row = uint16(hostarch.ByteOrder.Uint16(src[:2])) + src = src[2:] + w.Col = uint16(hostarch.ByteOrder.Uint16(src[:2])) + src = src[2:] + w.Xpixel = uint16(hostarch.ByteOrder.Uint16(src[:2])) + src = src[2:] + w.Ypixel = uint16(hostarch.ByteOrder.Uint16(src[:2])) + src = src[2:] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (w *Winsize) Packed() bool { + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (w *Winsize) MarshalUnsafe(dst []byte) { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(w), uintptr(w.SizeBytes())) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (w *Winsize) UnmarshalUnsafe(src []byte) { + gohacks.Memmove(unsafe.Pointer(w), unsafe.Pointer(&src[0]), uintptr(w.SizeBytes())) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (w *Winsize) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(w))) + hdr.Len = w.SizeBytes() + hdr.Cap = w.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that w + // must live until the use above. + runtime.KeepAlive(w) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (w *Winsize) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return w.CopyOutN(cc, addr, w.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (w *Winsize) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(w))) + hdr.Len = w.SizeBytes() + hdr.Cap = w.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that w + // must live until the use above. + runtime.KeepAlive(w) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (w *Winsize) WriteTo(writer io.Writer) (int64, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(w))) + hdr.Len = w.SizeBytes() + hdr.Cap = w.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that w + // must live until the use above. + runtime.KeepAlive(w) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (u *UtsName) SizeBytes() int { + return 0 + + 1*(UTSLen+1) + + 1*(UTSLen+1) + + 1*(UTSLen+1) + + 1*(UTSLen+1) + + 1*(UTSLen+1) + + 1*(UTSLen+1) +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (u *UtsName) MarshalBytes(dst []byte) { + for idx := 0; idx < (UTSLen+1); idx++ { + dst[0] = byte(u.Sysname[idx]) + dst = dst[1:] + } + for idx := 0; idx < (UTSLen+1); idx++ { + dst[0] = byte(u.Nodename[idx]) + dst = dst[1:] + } + for idx := 0; idx < (UTSLen+1); idx++ { + dst[0] = byte(u.Release[idx]) + dst = dst[1:] + } + for idx := 0; idx < (UTSLen+1); idx++ { + dst[0] = byte(u.Version[idx]) + dst = dst[1:] + } + for idx := 0; idx < (UTSLen+1); idx++ { + dst[0] = byte(u.Machine[idx]) + dst = dst[1:] + } + for idx := 0; idx < (UTSLen+1); idx++ { + dst[0] = byte(u.Domainname[idx]) + dst = dst[1:] + } +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (u *UtsName) UnmarshalBytes(src []byte) { + for idx := 0; idx < (UTSLen+1); idx++ { + u.Sysname[idx] = src[0] + src = src[1:] + } + for idx := 0; idx < (UTSLen+1); idx++ { + u.Nodename[idx] = src[0] + src = src[1:] + } + for idx := 0; idx < (UTSLen+1); idx++ { + u.Release[idx] = src[0] + src = src[1:] + } + for idx := 0; idx < (UTSLen+1); idx++ { + u.Version[idx] = src[0] + src = src[1:] + } + for idx := 0; idx < (UTSLen+1); idx++ { + u.Machine[idx] = src[0] + src = src[1:] + } + for idx := 0; idx < (UTSLen+1); idx++ { + u.Domainname[idx] = src[0] + src = src[1:] + } +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (u *UtsName) Packed() bool { + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (u *UtsName) MarshalUnsafe(dst []byte) { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(u), uintptr(u.SizeBytes())) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (u *UtsName) UnmarshalUnsafe(src []byte) { + gohacks.Memmove(unsafe.Pointer(u), unsafe.Pointer(&src[0]), uintptr(u.SizeBytes())) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (u *UtsName) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(u))) + hdr.Len = u.SizeBytes() + hdr.Cap = u.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that u + // must live until the use above. + runtime.KeepAlive(u) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (u *UtsName) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return u.CopyOutN(cc, addr, u.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (u *UtsName) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(u))) + hdr.Len = u.SizeBytes() + hdr.Cap = u.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that u + // must live until the use above. + runtime.KeepAlive(u) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (u *UtsName) WriteTo(writer io.Writer) (int64, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(u))) + hdr.Len = u.SizeBytes() + hdr.Cap = u.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that u + // must live until the use above. + runtime.KeepAlive(u) // escapes: replaced by intrinsic. + return int64(length), err +} + diff --git a/pkg/abi/linux/linux_amd64_abi_autogen_unsafe.go b/pkg/abi/linux/linux_amd64_abi_autogen_unsafe.go new file mode 100644 index 000000000..770b19c2f --- /dev/null +++ b/pkg/abi/linux/linux_amd64_abi_autogen_unsafe.go @@ -0,0 +1,738 @@ +// Automatically generated marshal implementation. See tools/go_marshal. + +// If there are issues with build constraint aggregation, see +// tools/go_marshal/gomarshal/generator.go:writeHeader(). The constraints here +// come from the input set of files used to generate this file. This input set +// is filtered based on pre-defined file suffixes related to build constraints, +// see tools/defs.bzl:calculate_sets(). + +//go:build amd64 && amd64 && amd64 && amd64 && amd64 +// +build amd64,amd64,amd64,amd64,amd64 + +package linux + +import ( + "gvisor.dev/gvisor/pkg/gohacks" + "gvisor.dev/gvisor/pkg/hostarch" + "gvisor.dev/gvisor/pkg/marshal" + "io" + "reflect" + "runtime" + "unsafe" +) + +// Marshallable types used by this file. +var _ marshal.Marshallable = (*EpollEvent)(nil) +var _ marshal.Marshallable = (*IPCPerm)(nil) +var _ marshal.Marshallable = (*PtraceRegs)(nil) +var _ marshal.Marshallable = (*SemidDS)(nil) +var _ marshal.Marshallable = (*Stat)(nil) +var _ marshal.Marshallable = (*TimeT)(nil) +var _ marshal.Marshallable = (*Timespec)(nil) + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (e *EpollEvent) SizeBytes() int { + return 4 + + 4*2 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (e *EpollEvent) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint32(dst[:4], uint32(e.Events)) + dst = dst[4:] + for idx := 0; idx < 2; idx++ { + hostarch.ByteOrder.PutUint32(dst[:4], uint32(e.Data[idx])) + dst = dst[4:] + } +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (e *EpollEvent) UnmarshalBytes(src []byte) { + e.Events = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + for idx := 0; idx < 2; idx++ { + e.Data[idx] = int32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + } +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (e *EpollEvent) Packed() bool { + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (e *EpollEvent) MarshalUnsafe(dst []byte) { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(e), uintptr(e.SizeBytes())) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (e *EpollEvent) UnmarshalUnsafe(src []byte) { + gohacks.Memmove(unsafe.Pointer(e), unsafe.Pointer(&src[0]), uintptr(e.SizeBytes())) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (e *EpollEvent) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(e))) + hdr.Len = e.SizeBytes() + hdr.Cap = e.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that e + // must live until the use above. + runtime.KeepAlive(e) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (e *EpollEvent) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return e.CopyOutN(cc, addr, e.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (e *EpollEvent) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(e))) + hdr.Len = e.SizeBytes() + hdr.Cap = e.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that e + // must live until the use above. + runtime.KeepAlive(e) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (e *EpollEvent) WriteTo(writer io.Writer) (int64, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(e))) + hdr.Len = e.SizeBytes() + hdr.Cap = e.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that e + // must live until the use above. + runtime.KeepAlive(e) // escapes: replaced by intrinsic. + return int64(length), err +} + +// CopyEpollEventSliceIn copies in a slice of EpollEvent objects from the task's memory. +func CopyEpollEventSliceIn(cc marshal.CopyContext, addr hostarch.Addr, dst []EpollEvent) (int, error) { + count := len(dst) + if count == 0 { + return 0, nil + } + size := (*EpollEvent)(nil).SizeBytes() + + ptr := unsafe.Pointer(&dst) + val := gohacks.Noescape(unsafe.Pointer((*reflect.SliceHeader)(ptr).Data)) + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(val) + hdr.Len = size * count + hdr.Cap = size * count + + length, err := cc.CopyInBytes(addr, buf) + // Since we bypassed the compiler's escape analysis, indicate that dst + // must live until the use above. + runtime.KeepAlive(dst) // escapes: replaced by intrinsic. + return length, err +} + +// CopyEpollEventSliceOut copies a slice of EpollEvent objects to the task's memory. +func CopyEpollEventSliceOut(cc marshal.CopyContext, addr hostarch.Addr, src []EpollEvent) (int, error) { + count := len(src) + if count == 0 { + return 0, nil + } + size := (*EpollEvent)(nil).SizeBytes() + + ptr := unsafe.Pointer(&src) + val := gohacks.Noescape(unsafe.Pointer((*reflect.SliceHeader)(ptr).Data)) + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(val) + hdr.Len = size * count + hdr.Cap = size * count + + length, err := cc.CopyOutBytes(addr, buf) + // Since we bypassed the compiler's escape analysis, indicate that src + // must live until the use above. + runtime.KeepAlive(src) // escapes: replaced by intrinsic. + return length, err +} + +// MarshalUnsafeEpollEventSlice is like EpollEvent.MarshalUnsafe, but for a []EpollEvent. +func MarshalUnsafeEpollEventSlice(src []EpollEvent, dst []byte) (int, error) { + count := len(src) + if count == 0 { + return 0, nil + } + size := (*EpollEvent)(nil).SizeBytes() + + dst = dst[:size*count] + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(&src[0]), uintptr(len(dst))) + return size * count, nil +} + +// UnmarshalUnsafeEpollEventSlice is like EpollEvent.UnmarshalUnsafe, but for a []EpollEvent. +func UnmarshalUnsafeEpollEventSlice(dst []EpollEvent, src []byte) (int, error) { + count := len(dst) + if count == 0 { + return 0, nil + } + size := (*EpollEvent)(nil).SizeBytes() + + src = src[:(size*count)] + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(&src[0]), uintptr(len(src))) + return count*size, nil +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (s *Stat) SizeBytes() int { + return 72 + + (*Timespec)(nil).SizeBytes() + + (*Timespec)(nil).SizeBytes() + + (*Timespec)(nil).SizeBytes() + + 8*3 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (s *Stat) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint64(dst[:8], uint64(s.Dev)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(s.Ino)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(s.Nlink)) + dst = dst[8:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(s.Mode)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(s.UID)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(s.GID)) + dst = dst[4:] + // Padding: dst[:sizeof(int32)] ~= int32(0) + dst = dst[4:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(s.Rdev)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(s.Size)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(s.Blksize)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(s.Blocks)) + dst = dst[8:] + s.ATime.MarshalBytes(dst[:s.ATime.SizeBytes()]) + dst = dst[s.ATime.SizeBytes():] + s.MTime.MarshalBytes(dst[:s.MTime.SizeBytes()]) + dst = dst[s.MTime.SizeBytes():] + s.CTime.MarshalBytes(dst[:s.CTime.SizeBytes()]) + dst = dst[s.CTime.SizeBytes():] + // Padding: dst[:sizeof(int64)*3] ~= [3]int64{0} + dst = dst[8*(3):] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (s *Stat) UnmarshalBytes(src []byte) { + s.Dev = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.Ino = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.Nlink = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.Mode = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + s.UID = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + s.GID = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + // Padding: var _ int32 ~= src[:sizeof(int32)] + src = src[4:] + s.Rdev = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.Size = int64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.Blksize = int64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.Blocks = int64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.ATime.UnmarshalBytes(src[:s.ATime.SizeBytes()]) + src = src[s.ATime.SizeBytes():] + s.MTime.UnmarshalBytes(src[:s.MTime.SizeBytes()]) + src = src[s.MTime.SizeBytes():] + s.CTime.UnmarshalBytes(src[:s.CTime.SizeBytes()]) + src = src[s.CTime.SizeBytes():] + // Padding: ~ copy([3]int64(s._), src[:sizeof(int64)*3]) + src = src[8*(3):] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (s *Stat) Packed() bool { + return s.ATime.Packed() && s.CTime.Packed() && s.MTime.Packed() +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (s *Stat) MarshalUnsafe(dst []byte) { + if s.ATime.Packed() && s.CTime.Packed() && s.MTime.Packed() { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(s), uintptr(s.SizeBytes())) + } else { + // Type Stat doesn't have a packed layout in memory, fallback to MarshalBytes. + s.MarshalBytes(dst) + } +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (s *Stat) UnmarshalUnsafe(src []byte) { + if s.ATime.Packed() && s.CTime.Packed() && s.MTime.Packed() { + gohacks.Memmove(unsafe.Pointer(s), unsafe.Pointer(&src[0]), uintptr(s.SizeBytes())) + } else { + // Type Stat doesn't have a packed layout in memory, fallback to UnmarshalBytes. + s.UnmarshalBytes(src) + } +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (s *Stat) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + if !s.ATime.Packed() && s.CTime.Packed() && s.MTime.Packed() { + // Type Stat doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := cc.CopyScratchBuffer(s.SizeBytes()) // escapes: okay. + s.MarshalBytes(buf) // escapes: fallback. + return cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(s))) + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until the use above. + runtime.KeepAlive(s) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (s *Stat) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return s.CopyOutN(cc, addr, s.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (s *Stat) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + if !s.ATime.Packed() && s.CTime.Packed() && s.MTime.Packed() { + // Type Stat doesn't have a packed layout in memory, fall back to UnmarshalBytes. + buf := cc.CopyScratchBuffer(s.SizeBytes()) // escapes: okay. + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Unmarshal unconditionally. If we had a short copy-in, this results in a + // partially unmarshalled struct. + s.UnmarshalBytes(buf) // escapes: fallback. + return length, err + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(s))) + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until the use above. + runtime.KeepAlive(s) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (s *Stat) WriteTo(writer io.Writer) (int64, error) { + if !s.ATime.Packed() && s.CTime.Packed() && s.MTime.Packed() { + // Type Stat doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := make([]byte, s.SizeBytes()) + s.MarshalBytes(buf) + length, err := writer.Write(buf) + return int64(length), err + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(s))) + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until the use above. + runtime.KeepAlive(s) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (p *PtraceRegs) SizeBytes() int { + return 216 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (p *PtraceRegs) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint64(dst[:8], uint64(p.R15)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(p.R14)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(p.R13)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(p.R12)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(p.Rbp)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(p.Rbx)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(p.R11)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(p.R10)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(p.R9)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(p.R8)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(p.Rax)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(p.Rcx)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(p.Rdx)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(p.Rsi)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(p.Rdi)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(p.Orig_rax)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(p.Rip)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(p.Cs)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(p.Eflags)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(p.Rsp)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(p.Ss)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(p.Fs_base)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(p.Gs_base)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(p.Ds)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(p.Es)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(p.Fs)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(p.Gs)) + dst = dst[8:] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (p *PtraceRegs) UnmarshalBytes(src []byte) { + p.R15 = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + p.R14 = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + p.R13 = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + p.R12 = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + p.Rbp = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + p.Rbx = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + p.R11 = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + p.R10 = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + p.R9 = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + p.R8 = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + p.Rax = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + p.Rcx = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + p.Rdx = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + p.Rsi = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + p.Rdi = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + p.Orig_rax = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + p.Rip = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + p.Cs = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + p.Eflags = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + p.Rsp = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + p.Ss = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + p.Fs_base = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + p.Gs_base = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + p.Ds = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + p.Es = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + p.Fs = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + p.Gs = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (p *PtraceRegs) Packed() bool { + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (p *PtraceRegs) MarshalUnsafe(dst []byte) { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(p), uintptr(p.SizeBytes())) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (p *PtraceRegs) UnmarshalUnsafe(src []byte) { + gohacks.Memmove(unsafe.Pointer(p), unsafe.Pointer(&src[0]), uintptr(p.SizeBytes())) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (p *PtraceRegs) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(p))) + hdr.Len = p.SizeBytes() + hdr.Cap = p.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that p + // must live until the use above. + runtime.KeepAlive(p) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (p *PtraceRegs) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return p.CopyOutN(cc, addr, p.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (p *PtraceRegs) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(p))) + hdr.Len = p.SizeBytes() + hdr.Cap = p.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that p + // must live until the use above. + runtime.KeepAlive(p) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (p *PtraceRegs) WriteTo(writer io.Writer) (int64, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(p))) + hdr.Len = p.SizeBytes() + hdr.Cap = p.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that p + // must live until the use above. + runtime.KeepAlive(p) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (s *SemidDS) SizeBytes() int { + return 40 + + (*IPCPerm)(nil).SizeBytes() + + (*TimeT)(nil).SizeBytes() + + (*TimeT)(nil).SizeBytes() +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (s *SemidDS) MarshalBytes(dst []byte) { + s.SemPerm.MarshalBytes(dst[:s.SemPerm.SizeBytes()]) + dst = dst[s.SemPerm.SizeBytes():] + s.SemOTime.MarshalBytes(dst[:s.SemOTime.SizeBytes()]) + dst = dst[s.SemOTime.SizeBytes():] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(s.unused1)) + dst = dst[8:] + s.SemCTime.MarshalBytes(dst[:s.SemCTime.SizeBytes()]) + dst = dst[s.SemCTime.SizeBytes():] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(s.unused2)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(s.SemNSems)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(s.unused3)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(s.unused4)) + dst = dst[8:] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (s *SemidDS) UnmarshalBytes(src []byte) { + s.SemPerm.UnmarshalBytes(src[:s.SemPerm.SizeBytes()]) + src = src[s.SemPerm.SizeBytes():] + s.SemOTime.UnmarshalBytes(src[:s.SemOTime.SizeBytes()]) + src = src[s.SemOTime.SizeBytes():] + s.unused1 = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.SemCTime.UnmarshalBytes(src[:s.SemCTime.SizeBytes()]) + src = src[s.SemCTime.SizeBytes():] + s.unused2 = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.SemNSems = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.unused3 = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.unused4 = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (s *SemidDS) Packed() bool { + return s.SemCTime.Packed() && s.SemOTime.Packed() && s.SemPerm.Packed() +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (s *SemidDS) MarshalUnsafe(dst []byte) { + if s.SemCTime.Packed() && s.SemOTime.Packed() && s.SemPerm.Packed() { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(s), uintptr(s.SizeBytes())) + } else { + // Type SemidDS doesn't have a packed layout in memory, fallback to MarshalBytes. + s.MarshalBytes(dst) + } +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (s *SemidDS) UnmarshalUnsafe(src []byte) { + if s.SemCTime.Packed() && s.SemOTime.Packed() && s.SemPerm.Packed() { + gohacks.Memmove(unsafe.Pointer(s), unsafe.Pointer(&src[0]), uintptr(s.SizeBytes())) + } else { + // Type SemidDS doesn't have a packed layout in memory, fallback to UnmarshalBytes. + s.UnmarshalBytes(src) + } +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (s *SemidDS) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + if !s.SemCTime.Packed() && s.SemOTime.Packed() && s.SemPerm.Packed() { + // Type SemidDS doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := cc.CopyScratchBuffer(s.SizeBytes()) // escapes: okay. + s.MarshalBytes(buf) // escapes: fallback. + return cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(s))) + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until the use above. + runtime.KeepAlive(s) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (s *SemidDS) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return s.CopyOutN(cc, addr, s.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (s *SemidDS) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + if !s.SemCTime.Packed() && s.SemOTime.Packed() && s.SemPerm.Packed() { + // Type SemidDS doesn't have a packed layout in memory, fall back to UnmarshalBytes. + buf := cc.CopyScratchBuffer(s.SizeBytes()) // escapes: okay. + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Unmarshal unconditionally. If we had a short copy-in, this results in a + // partially unmarshalled struct. + s.UnmarshalBytes(buf) // escapes: fallback. + return length, err + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(s))) + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until the use above. + runtime.KeepAlive(s) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (s *SemidDS) WriteTo(writer io.Writer) (int64, error) { + if !s.SemCTime.Packed() && s.SemOTime.Packed() && s.SemPerm.Packed() { + // Type SemidDS doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := make([]byte, s.SizeBytes()) + s.MarshalBytes(buf) + length, err := writer.Write(buf) + return int64(length), err + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(s))) + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until the use above. + runtime.KeepAlive(s) // escapes: replaced by intrinsic. + return int64(length), err +} + diff --git a/pkg/abi/linux/linux_amd64_state_autogen.go b/pkg/abi/linux/linux_amd64_state_autogen.go new file mode 100644 index 000000000..b824c1dd9 --- /dev/null +++ b/pkg/abi/linux/linux_amd64_state_autogen.go @@ -0,0 +1,117 @@ +// automatically generated by stateify. + +//go:build amd64 && amd64 && amd64 && amd64 && amd64 +// +build amd64,amd64,amd64,amd64,amd64 + +package linux + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (p *PtraceRegs) StateTypeName() string { + return "pkg/abi/linux.PtraceRegs" +} + +func (p *PtraceRegs) StateFields() []string { + return []string{ + "R15", + "R14", + "R13", + "R12", + "Rbp", + "Rbx", + "R11", + "R10", + "R9", + "R8", + "Rax", + "Rcx", + "Rdx", + "Rsi", + "Rdi", + "Orig_rax", + "Rip", + "Cs", + "Eflags", + "Rsp", + "Ss", + "Fs_base", + "Gs_base", + "Ds", + "Es", + "Fs", + "Gs", + } +} + +func (p *PtraceRegs) beforeSave() {} + +// +checklocksignore +func (p *PtraceRegs) StateSave(stateSinkObject state.Sink) { + p.beforeSave() + stateSinkObject.Save(0, &p.R15) + stateSinkObject.Save(1, &p.R14) + stateSinkObject.Save(2, &p.R13) + stateSinkObject.Save(3, &p.R12) + stateSinkObject.Save(4, &p.Rbp) + stateSinkObject.Save(5, &p.Rbx) + stateSinkObject.Save(6, &p.R11) + stateSinkObject.Save(7, &p.R10) + stateSinkObject.Save(8, &p.R9) + stateSinkObject.Save(9, &p.R8) + stateSinkObject.Save(10, &p.Rax) + stateSinkObject.Save(11, &p.Rcx) + stateSinkObject.Save(12, &p.Rdx) + stateSinkObject.Save(13, &p.Rsi) + stateSinkObject.Save(14, &p.Rdi) + stateSinkObject.Save(15, &p.Orig_rax) + stateSinkObject.Save(16, &p.Rip) + stateSinkObject.Save(17, &p.Cs) + stateSinkObject.Save(18, &p.Eflags) + stateSinkObject.Save(19, &p.Rsp) + stateSinkObject.Save(20, &p.Ss) + stateSinkObject.Save(21, &p.Fs_base) + stateSinkObject.Save(22, &p.Gs_base) + stateSinkObject.Save(23, &p.Ds) + stateSinkObject.Save(24, &p.Es) + stateSinkObject.Save(25, &p.Fs) + stateSinkObject.Save(26, &p.Gs) +} + +func (p *PtraceRegs) afterLoad() {} + +// +checklocksignore +func (p *PtraceRegs) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &p.R15) + stateSourceObject.Load(1, &p.R14) + stateSourceObject.Load(2, &p.R13) + stateSourceObject.Load(3, &p.R12) + stateSourceObject.Load(4, &p.Rbp) + stateSourceObject.Load(5, &p.Rbx) + stateSourceObject.Load(6, &p.R11) + stateSourceObject.Load(7, &p.R10) + stateSourceObject.Load(8, &p.R9) + stateSourceObject.Load(9, &p.R8) + stateSourceObject.Load(10, &p.Rax) + stateSourceObject.Load(11, &p.Rcx) + stateSourceObject.Load(12, &p.Rdx) + stateSourceObject.Load(13, &p.Rsi) + stateSourceObject.Load(14, &p.Rdi) + stateSourceObject.Load(15, &p.Orig_rax) + stateSourceObject.Load(16, &p.Rip) + stateSourceObject.Load(17, &p.Cs) + stateSourceObject.Load(18, &p.Eflags) + stateSourceObject.Load(19, &p.Rsp) + stateSourceObject.Load(20, &p.Ss) + stateSourceObject.Load(21, &p.Fs_base) + stateSourceObject.Load(22, &p.Gs_base) + stateSourceObject.Load(23, &p.Ds) + stateSourceObject.Load(24, &p.Es) + stateSourceObject.Load(25, &p.Fs) + stateSourceObject.Load(26, &p.Gs) +} + +func init() { + state.Register((*PtraceRegs)(nil)) +} diff --git a/pkg/abi/linux/linux_arm64_abi_autogen_unsafe.go b/pkg/abi/linux/linux_arm64_abi_autogen_unsafe.go new file mode 100644 index 000000000..06c2aec98 --- /dev/null +++ b/pkg/abi/linux/linux_arm64_abi_autogen_unsafe.go @@ -0,0 +1,651 @@ +// Automatically generated marshal implementation. See tools/go_marshal. + +// If there are issues with build constraint aggregation, see +// tools/go_marshal/gomarshal/generator.go:writeHeader(). The constraints here +// come from the input set of files used to generate this file. This input set +// is filtered based on pre-defined file suffixes related to build constraints, +// see tools/defs.bzl:calculate_sets(). + +//go:build arm64 && arm64 && arm64 && arm64 +// +build arm64,arm64,arm64,arm64 + +package linux + +import ( + "gvisor.dev/gvisor/pkg/gohacks" + "gvisor.dev/gvisor/pkg/hostarch" + "gvisor.dev/gvisor/pkg/marshal" + "io" + "reflect" + "runtime" + "unsafe" +) + +// Marshallable types used by this file. +var _ marshal.Marshallable = (*EpollEvent)(nil) +var _ marshal.Marshallable = (*IPCPerm)(nil) +var _ marshal.Marshallable = (*PtraceRegs)(nil) +var _ marshal.Marshallable = (*SemidDS)(nil) +var _ marshal.Marshallable = (*Stat)(nil) +var _ marshal.Marshallable = (*TimeT)(nil) +var _ marshal.Marshallable = (*Timespec)(nil) + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (e *EpollEvent) SizeBytes() int { + return 8 + + 4*2 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (e *EpollEvent) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint32(dst[:4], uint32(e.Events)) + dst = dst[4:] + // Padding: dst[:sizeof(int32)] ~= int32(0) + dst = dst[4:] + for idx := 0; idx < 2; idx++ { + hostarch.ByteOrder.PutUint32(dst[:4], uint32(e.Data[idx])) + dst = dst[4:] + } +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (e *EpollEvent) UnmarshalBytes(src []byte) { + e.Events = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + // Padding: var _ int32 ~= src[:sizeof(int32)] + src = src[4:] + for idx := 0; idx < 2; idx++ { + e.Data[idx] = int32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + } +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (e *EpollEvent) Packed() bool { + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (e *EpollEvent) MarshalUnsafe(dst []byte) { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(e), uintptr(e.SizeBytes())) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (e *EpollEvent) UnmarshalUnsafe(src []byte) { + gohacks.Memmove(unsafe.Pointer(e), unsafe.Pointer(&src[0]), uintptr(e.SizeBytes())) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (e *EpollEvent) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(e))) + hdr.Len = e.SizeBytes() + hdr.Cap = e.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that e + // must live until the use above. + runtime.KeepAlive(e) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (e *EpollEvent) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return e.CopyOutN(cc, addr, e.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (e *EpollEvent) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(e))) + hdr.Len = e.SizeBytes() + hdr.Cap = e.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that e + // must live until the use above. + runtime.KeepAlive(e) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (e *EpollEvent) WriteTo(writer io.Writer) (int64, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(e))) + hdr.Len = e.SizeBytes() + hdr.Cap = e.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that e + // must live until the use above. + runtime.KeepAlive(e) // escapes: replaced by intrinsic. + return int64(length), err +} + +// CopyEpollEventSliceIn copies in a slice of EpollEvent objects from the task's memory. +func CopyEpollEventSliceIn(cc marshal.CopyContext, addr hostarch.Addr, dst []EpollEvent) (int, error) { + count := len(dst) + if count == 0 { + return 0, nil + } + size := (*EpollEvent)(nil).SizeBytes() + + ptr := unsafe.Pointer(&dst) + val := gohacks.Noescape(unsafe.Pointer((*reflect.SliceHeader)(ptr).Data)) + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(val) + hdr.Len = size * count + hdr.Cap = size * count + + length, err := cc.CopyInBytes(addr, buf) + // Since we bypassed the compiler's escape analysis, indicate that dst + // must live until the use above. + runtime.KeepAlive(dst) // escapes: replaced by intrinsic. + return length, err +} + +// CopyEpollEventSliceOut copies a slice of EpollEvent objects to the task's memory. +func CopyEpollEventSliceOut(cc marshal.CopyContext, addr hostarch.Addr, src []EpollEvent) (int, error) { + count := len(src) + if count == 0 { + return 0, nil + } + size := (*EpollEvent)(nil).SizeBytes() + + ptr := unsafe.Pointer(&src) + val := gohacks.Noescape(unsafe.Pointer((*reflect.SliceHeader)(ptr).Data)) + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(val) + hdr.Len = size * count + hdr.Cap = size * count + + length, err := cc.CopyOutBytes(addr, buf) + // Since we bypassed the compiler's escape analysis, indicate that src + // must live until the use above. + runtime.KeepAlive(src) // escapes: replaced by intrinsic. + return length, err +} + +// MarshalUnsafeEpollEventSlice is like EpollEvent.MarshalUnsafe, but for a []EpollEvent. +func MarshalUnsafeEpollEventSlice(src []EpollEvent, dst []byte) (int, error) { + count := len(src) + if count == 0 { + return 0, nil + } + size := (*EpollEvent)(nil).SizeBytes() + + dst = dst[:size*count] + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(&src[0]), uintptr(len(dst))) + return size * count, nil +} + +// UnmarshalUnsafeEpollEventSlice is like EpollEvent.UnmarshalUnsafe, but for a []EpollEvent. +func UnmarshalUnsafeEpollEventSlice(dst []EpollEvent, src []byte) (int, error) { + count := len(dst) + if count == 0 { + return 0, nil + } + size := (*EpollEvent)(nil).SizeBytes() + + src = src[:(size*count)] + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(&src[0]), uintptr(len(src))) + return count*size, nil +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (s *Stat) SizeBytes() int { + return 72 + + (*Timespec)(nil).SizeBytes() + + (*Timespec)(nil).SizeBytes() + + (*Timespec)(nil).SizeBytes() + + 4*2 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (s *Stat) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint64(dst[:8], uint64(s.Dev)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(s.Ino)) + dst = dst[8:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(s.Mode)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(s.Nlink)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(s.UID)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(s.GID)) + dst = dst[4:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(s.Rdev)) + dst = dst[8:] + // Padding: dst[:sizeof(uint64)] ~= uint64(0) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(s.Size)) + dst = dst[8:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(s.Blksize)) + dst = dst[4:] + // Padding: dst[:sizeof(int32)] ~= int32(0) + dst = dst[4:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(s.Blocks)) + dst = dst[8:] + s.ATime.MarshalBytes(dst[:s.ATime.SizeBytes()]) + dst = dst[s.ATime.SizeBytes():] + s.MTime.MarshalBytes(dst[:s.MTime.SizeBytes()]) + dst = dst[s.MTime.SizeBytes():] + s.CTime.MarshalBytes(dst[:s.CTime.SizeBytes()]) + dst = dst[s.CTime.SizeBytes():] + // Padding: dst[:sizeof(int32)*2] ~= [2]int32{0} + dst = dst[4*(2):] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (s *Stat) UnmarshalBytes(src []byte) { + s.Dev = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.Ino = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.Mode = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + s.Nlink = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + s.UID = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + s.GID = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + s.Rdev = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + // Padding: var _ uint64 ~= src[:sizeof(uint64)] + src = src[8:] + s.Size = int64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.Blksize = int32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + // Padding: var _ int32 ~= src[:sizeof(int32)] + src = src[4:] + s.Blocks = int64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.ATime.UnmarshalBytes(src[:s.ATime.SizeBytes()]) + src = src[s.ATime.SizeBytes():] + s.MTime.UnmarshalBytes(src[:s.MTime.SizeBytes()]) + src = src[s.MTime.SizeBytes():] + s.CTime.UnmarshalBytes(src[:s.CTime.SizeBytes()]) + src = src[s.CTime.SizeBytes():] + // Padding: ~ copy([2]int32(s._), src[:sizeof(int32)*2]) + src = src[4*(2):] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (s *Stat) Packed() bool { + return s.ATime.Packed() && s.CTime.Packed() && s.MTime.Packed() +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (s *Stat) MarshalUnsafe(dst []byte) { + if s.ATime.Packed() && s.CTime.Packed() && s.MTime.Packed() { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(s), uintptr(s.SizeBytes())) + } else { + // Type Stat doesn't have a packed layout in memory, fallback to MarshalBytes. + s.MarshalBytes(dst) + } +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (s *Stat) UnmarshalUnsafe(src []byte) { + if s.ATime.Packed() && s.CTime.Packed() && s.MTime.Packed() { + gohacks.Memmove(unsafe.Pointer(s), unsafe.Pointer(&src[0]), uintptr(s.SizeBytes())) + } else { + // Type Stat doesn't have a packed layout in memory, fallback to UnmarshalBytes. + s.UnmarshalBytes(src) + } +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (s *Stat) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + if !s.ATime.Packed() && s.CTime.Packed() && s.MTime.Packed() { + // Type Stat doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := cc.CopyScratchBuffer(s.SizeBytes()) // escapes: okay. + s.MarshalBytes(buf) // escapes: fallback. + return cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(s))) + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until the use above. + runtime.KeepAlive(s) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (s *Stat) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return s.CopyOutN(cc, addr, s.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (s *Stat) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + if !s.ATime.Packed() && s.CTime.Packed() && s.MTime.Packed() { + // Type Stat doesn't have a packed layout in memory, fall back to UnmarshalBytes. + buf := cc.CopyScratchBuffer(s.SizeBytes()) // escapes: okay. + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Unmarshal unconditionally. If we had a short copy-in, this results in a + // partially unmarshalled struct. + s.UnmarshalBytes(buf) // escapes: fallback. + return length, err + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(s))) + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until the use above. + runtime.KeepAlive(s) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (s *Stat) WriteTo(writer io.Writer) (int64, error) { + if !s.ATime.Packed() && s.CTime.Packed() && s.MTime.Packed() { + // Type Stat doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := make([]byte, s.SizeBytes()) + s.MarshalBytes(buf) + length, err := writer.Write(buf) + return int64(length), err + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(s))) + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until the use above. + runtime.KeepAlive(s) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (p *PtraceRegs) SizeBytes() int { + return 24 + + 8*31 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (p *PtraceRegs) MarshalBytes(dst []byte) { + for idx := 0; idx < 31; idx++ { + hostarch.ByteOrder.PutUint64(dst[:8], uint64(p.Regs[idx])) + dst = dst[8:] + } + hostarch.ByteOrder.PutUint64(dst[:8], uint64(p.Sp)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(p.Pc)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(p.Pstate)) + dst = dst[8:] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (p *PtraceRegs) UnmarshalBytes(src []byte) { + for idx := 0; idx < 31; idx++ { + p.Regs[idx] = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + } + p.Sp = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + p.Pc = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + p.Pstate = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (p *PtraceRegs) Packed() bool { + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (p *PtraceRegs) MarshalUnsafe(dst []byte) { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(p), uintptr(p.SizeBytes())) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (p *PtraceRegs) UnmarshalUnsafe(src []byte) { + gohacks.Memmove(unsafe.Pointer(p), unsafe.Pointer(&src[0]), uintptr(p.SizeBytes())) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (p *PtraceRegs) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(p))) + hdr.Len = p.SizeBytes() + hdr.Cap = p.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that p + // must live until the use above. + runtime.KeepAlive(p) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (p *PtraceRegs) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return p.CopyOutN(cc, addr, p.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (p *PtraceRegs) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(p))) + hdr.Len = p.SizeBytes() + hdr.Cap = p.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that p + // must live until the use above. + runtime.KeepAlive(p) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (p *PtraceRegs) WriteTo(writer io.Writer) (int64, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(p))) + hdr.Len = p.SizeBytes() + hdr.Cap = p.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that p + // must live until the use above. + runtime.KeepAlive(p) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (s *SemidDS) SizeBytes() int { + return 24 + + (*IPCPerm)(nil).SizeBytes() + + (*TimeT)(nil).SizeBytes() + + (*TimeT)(nil).SizeBytes() +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (s *SemidDS) MarshalBytes(dst []byte) { + s.SemPerm.MarshalBytes(dst[:s.SemPerm.SizeBytes()]) + dst = dst[s.SemPerm.SizeBytes():] + s.SemOTime.MarshalBytes(dst[:s.SemOTime.SizeBytes()]) + dst = dst[s.SemOTime.SizeBytes():] + s.SemCTime.MarshalBytes(dst[:s.SemCTime.SizeBytes()]) + dst = dst[s.SemCTime.SizeBytes():] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(s.SemNSems)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(s.unused3)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(s.unused4)) + dst = dst[8:] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (s *SemidDS) UnmarshalBytes(src []byte) { + s.SemPerm.UnmarshalBytes(src[:s.SemPerm.SizeBytes()]) + src = src[s.SemPerm.SizeBytes():] + s.SemOTime.UnmarshalBytes(src[:s.SemOTime.SizeBytes()]) + src = src[s.SemOTime.SizeBytes():] + s.SemCTime.UnmarshalBytes(src[:s.SemCTime.SizeBytes()]) + src = src[s.SemCTime.SizeBytes():] + s.SemNSems = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.unused3 = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.unused4 = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (s *SemidDS) Packed() bool { + return s.SemCTime.Packed() && s.SemOTime.Packed() && s.SemPerm.Packed() +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (s *SemidDS) MarshalUnsafe(dst []byte) { + if s.SemCTime.Packed() && s.SemOTime.Packed() && s.SemPerm.Packed() { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(s), uintptr(s.SizeBytes())) + } else { + // Type SemidDS doesn't have a packed layout in memory, fallback to MarshalBytes. + s.MarshalBytes(dst) + } +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (s *SemidDS) UnmarshalUnsafe(src []byte) { + if s.SemCTime.Packed() && s.SemOTime.Packed() && s.SemPerm.Packed() { + gohacks.Memmove(unsafe.Pointer(s), unsafe.Pointer(&src[0]), uintptr(s.SizeBytes())) + } else { + // Type SemidDS doesn't have a packed layout in memory, fallback to UnmarshalBytes. + s.UnmarshalBytes(src) + } +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (s *SemidDS) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + if !s.SemCTime.Packed() && s.SemOTime.Packed() && s.SemPerm.Packed() { + // Type SemidDS doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := cc.CopyScratchBuffer(s.SizeBytes()) // escapes: okay. + s.MarshalBytes(buf) // escapes: fallback. + return cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(s))) + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until the use above. + runtime.KeepAlive(s) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (s *SemidDS) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return s.CopyOutN(cc, addr, s.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (s *SemidDS) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + if !s.SemCTime.Packed() && s.SemOTime.Packed() && s.SemPerm.Packed() { + // Type SemidDS doesn't have a packed layout in memory, fall back to UnmarshalBytes. + buf := cc.CopyScratchBuffer(s.SizeBytes()) // escapes: okay. + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Unmarshal unconditionally. If we had a short copy-in, this results in a + // partially unmarshalled struct. + s.UnmarshalBytes(buf) // escapes: fallback. + return length, err + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(s))) + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until the use above. + runtime.KeepAlive(s) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (s *SemidDS) WriteTo(writer io.Writer) (int64, error) { + if !s.SemCTime.Packed() && s.SemOTime.Packed() && s.SemPerm.Packed() { + // Type SemidDS doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := make([]byte, s.SizeBytes()) + s.MarshalBytes(buf) + length, err := writer.Write(buf) + return int64(length), err + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(s))) + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until the use above. + runtime.KeepAlive(s) // escapes: replaced by intrinsic. + return int64(length), err +} + diff --git a/pkg/abi/linux/linux_arm64_state_autogen.go b/pkg/abi/linux/linux_arm64_state_autogen.go new file mode 100644 index 000000000..1994658e5 --- /dev/null +++ b/pkg/abi/linux/linux_arm64_state_autogen.go @@ -0,0 +1,48 @@ +// automatically generated by stateify. + +//go:build arm64 && arm64 && arm64 && arm64 +// +build arm64,arm64,arm64,arm64 + +package linux + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (p *PtraceRegs) StateTypeName() string { + return "pkg/abi/linux.PtraceRegs" +} + +func (p *PtraceRegs) StateFields() []string { + return []string{ + "Regs", + "Sp", + "Pc", + "Pstate", + } +} + +func (p *PtraceRegs) beforeSave() {} + +// +checklocksignore +func (p *PtraceRegs) StateSave(stateSinkObject state.Sink) { + p.beforeSave() + stateSinkObject.Save(0, &p.Regs) + stateSinkObject.Save(1, &p.Sp) + stateSinkObject.Save(2, &p.Pc) + stateSinkObject.Save(3, &p.Pstate) +} + +func (p *PtraceRegs) afterLoad() {} + +// +checklocksignore +func (p *PtraceRegs) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &p.Regs) + stateSourceObject.Load(1, &p.Sp) + stateSourceObject.Load(2, &p.Pc) + stateSourceObject.Load(3, &p.Pstate) +} + +func init() { + state.Register((*PtraceRegs)(nil)) +} diff --git a/pkg/abi/linux/linux_state_autogen.go b/pkg/abi/linux/linux_state_autogen.go new file mode 100644 index 000000000..e2043be15 --- /dev/null +++ b/pkg/abi/linux/linux_state_autogen.go @@ -0,0 +1,290 @@ +// automatically generated by stateify. + +package linux + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (i *IOEvent) StateTypeName() string { + return "pkg/abi/linux.IOEvent" +} + +func (i *IOEvent) StateFields() []string { + return []string{ + "Data", + "Obj", + "Result", + "Result2", + } +} + +func (i *IOEvent) beforeSave() {} + +// +checklocksignore +func (i *IOEvent) StateSave(stateSinkObject state.Sink) { + i.beforeSave() + stateSinkObject.Save(0, &i.Data) + stateSinkObject.Save(1, &i.Obj) + stateSinkObject.Save(2, &i.Result) + stateSinkObject.Save(3, &i.Result2) +} + +func (i *IOEvent) afterLoad() {} + +// +checklocksignore +func (i *IOEvent) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &i.Data) + stateSourceObject.Load(1, &i.Obj) + stateSourceObject.Load(2, &i.Result) + stateSourceObject.Load(3, &i.Result2) +} + +func (b *BPFInstruction) StateTypeName() string { + return "pkg/abi/linux.BPFInstruction" +} + +func (b *BPFInstruction) StateFields() []string { + return []string{ + "OpCode", + "JumpIfTrue", + "JumpIfFalse", + "K", + } +} + +func (b *BPFInstruction) beforeSave() {} + +// +checklocksignore +func (b *BPFInstruction) StateSave(stateSinkObject state.Sink) { + b.beforeSave() + stateSinkObject.Save(0, &b.OpCode) + stateSinkObject.Save(1, &b.JumpIfTrue) + stateSinkObject.Save(2, &b.JumpIfFalse) + stateSinkObject.Save(3, &b.K) +} + +func (b *BPFInstruction) afterLoad() {} + +// +checklocksignore +func (b *BPFInstruction) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &b.OpCode) + stateSourceObject.Load(1, &b.JumpIfTrue) + stateSourceObject.Load(2, &b.JumpIfFalse) + stateSourceObject.Load(3, &b.K) +} + +func (s *SigAction) StateTypeName() string { + return "pkg/abi/linux.SigAction" +} + +func (s *SigAction) StateFields() []string { + return []string{ + "Handler", + "Flags", + "Restorer", + "Mask", + } +} + +func (s *SigAction) beforeSave() {} + +// +checklocksignore +func (s *SigAction) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + stateSinkObject.Save(0, &s.Handler) + stateSinkObject.Save(1, &s.Flags) + stateSinkObject.Save(2, &s.Restorer) + stateSinkObject.Save(3, &s.Mask) +} + +func (s *SigAction) afterLoad() {} + +// +checklocksignore +func (s *SigAction) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.Handler) + stateSourceObject.Load(1, &s.Flags) + stateSourceObject.Load(2, &s.Restorer) + stateSourceObject.Load(3, &s.Mask) +} + +func (s *SignalStack) StateTypeName() string { + return "pkg/abi/linux.SignalStack" +} + +func (s *SignalStack) StateFields() []string { + return []string{ + "Addr", + "Flags", + "Size", + } +} + +func (s *SignalStack) beforeSave() {} + +// +checklocksignore +func (s *SignalStack) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + stateSinkObject.Save(0, &s.Addr) + stateSinkObject.Save(1, &s.Flags) + stateSinkObject.Save(2, &s.Size) +} + +func (s *SignalStack) afterLoad() {} + +// +checklocksignore +func (s *SignalStack) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.Addr) + stateSourceObject.Load(1, &s.Flags) + stateSourceObject.Load(2, &s.Size) +} + +func (s *SignalInfo) StateTypeName() string { + return "pkg/abi/linux.SignalInfo" +} + +func (s *SignalInfo) StateFields() []string { + return []string{ + "Signo", + "Errno", + "Code", + "Fields", + } +} + +func (s *SignalInfo) beforeSave() {} + +// +checklocksignore +func (s *SignalInfo) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + stateSinkObject.Save(0, &s.Signo) + stateSinkObject.Save(1, &s.Errno) + stateSinkObject.Save(2, &s.Code) + stateSinkObject.Save(3, &s.Fields) +} + +func (s *SignalInfo) afterLoad() {} + +// +checklocksignore +func (s *SignalInfo) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.Signo) + stateSourceObject.Load(1, &s.Errno) + stateSourceObject.Load(2, &s.Code) + stateSourceObject.Load(3, &s.Fields) +} + +func (c *ControlMessageIPPacketInfo) StateTypeName() string { + return "pkg/abi/linux.ControlMessageIPPacketInfo" +} + +func (c *ControlMessageIPPacketInfo) StateFields() []string { + return []string{ + "NIC", + "LocalAddr", + "DestinationAddr", + } +} + +func (c *ControlMessageIPPacketInfo) beforeSave() {} + +// +checklocksignore +func (c *ControlMessageIPPacketInfo) StateSave(stateSinkObject state.Sink) { + c.beforeSave() + stateSinkObject.Save(0, &c.NIC) + stateSinkObject.Save(1, &c.LocalAddr) + stateSinkObject.Save(2, &c.DestinationAddr) +} + +func (c *ControlMessageIPPacketInfo) afterLoad() {} + +// +checklocksignore +func (c *ControlMessageIPPacketInfo) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &c.NIC) + stateSourceObject.Load(1, &c.LocalAddr) + stateSourceObject.Load(2, &c.DestinationAddr) +} + +func (t *KernelTermios) StateTypeName() string { + return "pkg/abi/linux.KernelTermios" +} + +func (t *KernelTermios) StateFields() []string { + return []string{ + "InputFlags", + "OutputFlags", + "ControlFlags", + "LocalFlags", + "LineDiscipline", + "ControlCharacters", + "InputSpeed", + "OutputSpeed", + } +} + +func (t *KernelTermios) beforeSave() {} + +// +checklocksignore +func (t *KernelTermios) StateSave(stateSinkObject state.Sink) { + t.beforeSave() + stateSinkObject.Save(0, &t.InputFlags) + stateSinkObject.Save(1, &t.OutputFlags) + stateSinkObject.Save(2, &t.ControlFlags) + stateSinkObject.Save(3, &t.LocalFlags) + stateSinkObject.Save(4, &t.LineDiscipline) + stateSinkObject.Save(5, &t.ControlCharacters) + stateSinkObject.Save(6, &t.InputSpeed) + stateSinkObject.Save(7, &t.OutputSpeed) +} + +func (t *KernelTermios) afterLoad() {} + +// +checklocksignore +func (t *KernelTermios) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &t.InputFlags) + stateSourceObject.Load(1, &t.OutputFlags) + stateSourceObject.Load(2, &t.ControlFlags) + stateSourceObject.Load(3, &t.LocalFlags) + stateSourceObject.Load(4, &t.LineDiscipline) + stateSourceObject.Load(5, &t.ControlCharacters) + stateSourceObject.Load(6, &t.InputSpeed) + stateSourceObject.Load(7, &t.OutputSpeed) +} + +func (w *WindowSize) StateTypeName() string { + return "pkg/abi/linux.WindowSize" +} + +func (w *WindowSize) StateFields() []string { + return []string{ + "Rows", + "Cols", + } +} + +func (w *WindowSize) beforeSave() {} + +// +checklocksignore +func (w *WindowSize) StateSave(stateSinkObject state.Sink) { + w.beforeSave() + stateSinkObject.Save(0, &w.Rows) + stateSinkObject.Save(1, &w.Cols) +} + +func (w *WindowSize) afterLoad() {} + +// +checklocksignore +func (w *WindowSize) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &w.Rows) + stateSourceObject.Load(1, &w.Cols) +} + +func init() { + state.Register((*IOEvent)(nil)) + state.Register((*BPFInstruction)(nil)) + state.Register((*SigAction)(nil)) + state.Register((*SignalStack)(nil)) + state.Register((*SignalInfo)(nil)) + state.Register((*ControlMessageIPPacketInfo)(nil)) + state.Register((*KernelTermios)(nil)) + state.Register((*WindowSize)(nil)) +} diff --git a/pkg/abi/linux/netfilter_test.go b/pkg/abi/linux/netfilter_test.go deleted file mode 100644 index 600820a0b..000000000 --- a/pkg/abi/linux/netfilter_test.go +++ /dev/null @@ -1,48 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package linux - -import ( - "encoding/binary" - "testing" -) - -func TestSizes(t *testing.T) { - testCases := []struct { - typ interface{} - defined uintptr - }{ - {IPTEntry{}, SizeOfIPTEntry}, - {IPTGetEntries{}, SizeOfIPTGetEntries}, - {IPTGetinfo{}, SizeOfIPTGetinfo}, - {IPTIP{}, SizeOfIPTIP}, - {IPTOwnerInfo{}, SizeOfIPTOwnerInfo}, - {IPTReplace{}, SizeOfIPTReplace}, - {XTCounters{}, SizeOfXTCounters}, - {XTEntryMatch{}, SizeOfXTEntryMatch}, - {XTEntryTarget{}, SizeOfXTEntryTarget}, - {XTErrorTarget{}, SizeOfXTErrorTarget}, - {XTStandardTarget{}, SizeOfXTStandardTarget}, - {IP6TReplace{}, SizeOfIP6TReplace}, - {IP6TEntry{}, SizeOfIP6TEntry}, - {IP6TIP{}, SizeOfIP6TIP}, - } - - for _, tc := range testCases { - if calculated := uintptr(binary.Size(tc.typ)); calculated != tc.defined { - t.Errorf("%T has a defined size of %d and calculated size of %d", tc.typ, tc.defined, calculated) - } - } -} diff --git a/pkg/amutex/BUILD b/pkg/amutex/BUILD deleted file mode 100644 index 6d8b5f818..000000000 --- a/pkg/amutex/BUILD +++ /dev/null @@ -1,21 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "amutex", - srcs = ["amutex.go"], - visibility = ["//:sandbox"], - deps = [ - "//pkg/context", - "//pkg/errors/linuxerr", - ], -) - -go_test( - name = "amutex_test", - size = "small", - srcs = ["amutex_test.go"], - library = ":amutex", - deps = ["//pkg/sync"], -) diff --git a/pkg/amutex/amutex_state_autogen.go b/pkg/amutex/amutex_state_autogen.go new file mode 100644 index 000000000..5a09c71ed --- /dev/null +++ b/pkg/amutex/amutex_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package amutex diff --git a/pkg/amutex/amutex_test.go b/pkg/amutex/amutex_test.go deleted file mode 100644 index 8a3952f2a..000000000 --- a/pkg/amutex/amutex_test.go +++ /dev/null @@ -1,98 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package amutex - -import ( - "testing" - "time" - - "gvisor.dev/gvisor/pkg/sync" -) - -type sleeper struct { - ch chan struct{} -} - -func (s *sleeper) SleepStart() <-chan struct{} { - return s.ch -} - -func (*sleeper) SleepFinish(bool) { -} - -func (s *sleeper) Interrupted() bool { - return len(s.ch) != 0 -} - -func TestMutualExclusion(t *testing.T) { - var m AbortableMutex - m.Init() - - // Test mutual exclusion by running "gr" goroutines concurrently, and - // have each one increment a counter "iters" times within the critical - // section established by the mutex. - // - // If at the end of the counter is not gr * iters, then we know that - // goroutines ran concurrently within the critical section. - // - // If one of the goroutines doesn't complete, it's likely a bug that - // causes it to wait forever. - const gr = 1000 - const iters = 100000 - v := 0 - var wg sync.WaitGroup - for i := 0; i < gr; i++ { - wg.Add(1) - go func() { - for j := 0; j < iters; j++ { - m.Lock(nil) - v++ - m.Unlock() - } - wg.Done() - }() - } - - wg.Wait() - - if v != gr*iters { - t.Fatalf("Bad count: got %v, want %v", v, gr*iters) - } -} - -func TestAbortWait(t *testing.T) { - var s sleeper - var m AbortableMutex - m.Init() - - // Lock the mutex. - m.Lock(&s) - - // Lock again, but this time cancel after 500ms. - s.ch = make(chan struct{}, 1) - go func() { - time.Sleep(500 * time.Millisecond) - s.ch <- struct{}{} - }() - if v := m.Lock(&s); v { - t.Fatalf("Lock succeeded when it should have failed") - } - - // Lock again, but cancel right away. - s.ch <- struct{}{} - if v := m.Lock(&s); v { - t.Fatalf("Lock succeeded when it should have failed") - } -} diff --git a/pkg/atomicbitops/BUILD b/pkg/atomicbitops/BUILD deleted file mode 100644 index 02c0e52b9..000000000 --- a/pkg/atomicbitops/BUILD +++ /dev/null @@ -1,27 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "atomicbitops", - srcs = [ - "aligned_32bit_unsafe.go", - "aligned_64bit.go", - "atomicbitops.go", - "atomicbitops_amd64.s", - "atomicbitops_arm64.s", - "atomicbitops_noasm.go", - ], - visibility = ["//:sandbox"], -) - -go_test( - name = "atomicbitops_test", - size = "small", - srcs = [ - "aligned_test.go", - "atomicbitops_test.go", - ], - library = ":atomicbitops", - deps = ["//pkg/sync"], -) diff --git a/pkg/atomicbitops/aligned_test.go b/pkg/atomicbitops/aligned_test.go deleted file mode 100644 index e7123d2b8..000000000 --- a/pkg/atomicbitops/aligned_test.go +++ /dev/null @@ -1,35 +0,0 @@ -// Copyright 2021 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package atomicbitops - -import ( - "testing" -) - -func TestAtomiciInt64(t *testing.T) { - v := struct { - v8 int8 - v64 AlignedAtomicInt64 - }{} - v.v64.Add(1) -} - -func TestAtomicUint64(t *testing.T) { - v := struct { - v8 uint8 - v64 AlignedAtomicUint64 - }{} - v.v64.Add(1) -} diff --git a/pkg/atomicbitops/atomicbitops_32bit_unsafe_state_autogen.go b/pkg/atomicbitops/atomicbitops_32bit_unsafe_state_autogen.go new file mode 100644 index 000000000..8241867c7 --- /dev/null +++ b/pkg/atomicbitops/atomicbitops_32bit_unsafe_state_autogen.go @@ -0,0 +1,71 @@ +// automatically generated by stateify. + +//go:build arm || mips || 386 +// +build arm mips 386 + +package atomicbitops + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (aa *AlignedAtomicInt64) StateTypeName() string { + return "pkg/atomicbitops.AlignedAtomicInt64" +} + +func (aa *AlignedAtomicInt64) StateFields() []string { + return []string{ + "value", + "value32", + } +} + +func (aa *AlignedAtomicInt64) beforeSave() {} + +// +checklocksignore +func (aa *AlignedAtomicInt64) StateSave(stateSinkObject state.Sink) { + aa.beforeSave() + stateSinkObject.Save(0, &aa.value) + stateSinkObject.Save(1, &aa.value32) +} + +func (aa *AlignedAtomicInt64) afterLoad() {} + +// +checklocksignore +func (aa *AlignedAtomicInt64) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &aa.value) + stateSourceObject.Load(1, &aa.value32) +} + +func (aa *AlignedAtomicUint64) StateTypeName() string { + return "pkg/atomicbitops.AlignedAtomicUint64" +} + +func (aa *AlignedAtomicUint64) StateFields() []string { + return []string{ + "value", + "value32", + } +} + +func (aa *AlignedAtomicUint64) beforeSave() {} + +// +checklocksignore +func (aa *AlignedAtomicUint64) StateSave(stateSinkObject state.Sink) { + aa.beforeSave() + stateSinkObject.Save(0, &aa.value) + stateSinkObject.Save(1, &aa.value32) +} + +func (aa *AlignedAtomicUint64) afterLoad() {} + +// +checklocksignore +func (aa *AlignedAtomicUint64) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &aa.value) + stateSourceObject.Load(1, &aa.value32) +} + +func init() { + state.Register((*AlignedAtomicInt64)(nil)) + state.Register((*AlignedAtomicUint64)(nil)) +} diff --git a/pkg/atomicbitops/atomicbitops_64bit_state_autogen.go b/pkg/atomicbitops/atomicbitops_64bit_state_autogen.go new file mode 100644 index 000000000..b21caab17 --- /dev/null +++ b/pkg/atomicbitops/atomicbitops_64bit_state_autogen.go @@ -0,0 +1,65 @@ +// automatically generated by stateify. + +//go:build !arm && !mips && !386 +// +build !arm,!mips,!386 + +package atomicbitops + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (aa *AlignedAtomicInt64) StateTypeName() string { + return "pkg/atomicbitops.AlignedAtomicInt64" +} + +func (aa *AlignedAtomicInt64) StateFields() []string { + return []string{ + "value", + } +} + +func (aa *AlignedAtomicInt64) beforeSave() {} + +// +checklocksignore +func (aa *AlignedAtomicInt64) StateSave(stateSinkObject state.Sink) { + aa.beforeSave() + stateSinkObject.Save(0, &aa.value) +} + +func (aa *AlignedAtomicInt64) afterLoad() {} + +// +checklocksignore +func (aa *AlignedAtomicInt64) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &aa.value) +} + +func (aa *AlignedAtomicUint64) StateTypeName() string { + return "pkg/atomicbitops.AlignedAtomicUint64" +} + +func (aa *AlignedAtomicUint64) StateFields() []string { + return []string{ + "value", + } +} + +func (aa *AlignedAtomicUint64) beforeSave() {} + +// +checklocksignore +func (aa *AlignedAtomicUint64) StateSave(stateSinkObject state.Sink) { + aa.beforeSave() + stateSinkObject.Save(0, &aa.value) +} + +func (aa *AlignedAtomicUint64) afterLoad() {} + +// +checklocksignore +func (aa *AlignedAtomicUint64) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &aa.value) +} + +func init() { + state.Register((*AlignedAtomicInt64)(nil)) + state.Register((*AlignedAtomicUint64)(nil)) +} diff --git a/pkg/atomicbitops/atomicbitops_state_autogen.go b/pkg/atomicbitops/atomicbitops_state_autogen.go new file mode 100644 index 000000000..676d44439 --- /dev/null +++ b/pkg/atomicbitops/atomicbitops_state_autogen.go @@ -0,0 +1,8 @@ +// automatically generated by stateify. + +//go:build (amd64 || arm64) && !amd64 && !arm64 +// +build amd64 arm64 +// +build !amd64 +// +build !arm64 + +package atomicbitops diff --git a/pkg/atomicbitops/atomicbitops_test.go b/pkg/atomicbitops/atomicbitops_test.go deleted file mode 100644 index 73af71bb4..000000000 --- a/pkg/atomicbitops/atomicbitops_test.go +++ /dev/null @@ -1,198 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package atomicbitops - -import ( - "runtime" - "testing" - - "gvisor.dev/gvisor/pkg/sync" -) - -const iterations = 100 - -func detectRaces32(val, target uint32, fn func(*uint32, uint32)) bool { - runtime.GOMAXPROCS(100) - for n := 0; n < iterations; n++ { - x := val - var wg sync.WaitGroup - for i := uint32(0); i < 32; i++ { - wg.Add(1) - go func(a *uint32, i uint32) { - defer wg.Done() - fn(a, uint32(1<<i)) - }(&x, i) - } - wg.Wait() - if x != target { - return true - } - } - return false -} - -func detectRaces64(val, target uint64, fn func(*uint64, uint64)) bool { - runtime.GOMAXPROCS(100) - for n := 0; n < iterations; n++ { - x := val - var wg sync.WaitGroup - for i := uint64(0); i < 64; i++ { - wg.Add(1) - go func(a *uint64, i uint64) { - defer wg.Done() - fn(a, uint64(1<<i)) - }(&x, i) - } - wg.Wait() - if x != target { - return true - } - } - return false -} - -func TestOrUint32(t *testing.T) { - if detectRaces32(0x0, 0xffffffff, OrUint32) { - t.Error("Data race detected!") - } -} - -func TestAndUint32(t *testing.T) { - if detectRaces32(0xf0f0f0f0, 0x00000000, AndUint32) { - t.Error("Data race detected!") - } -} - -func TestXorUint32(t *testing.T) { - if detectRaces32(0xf0f0f0f0, 0x0f0f0f0f, XorUint32) { - t.Error("Data race detected!") - } -} - -func TestOrUint64(t *testing.T) { - if detectRaces64(0x0, 0xffffffffffffffff, OrUint64) { - t.Error("Data race detected!") - } -} - -func TestAndUint64(t *testing.T) { - if detectRaces64(0xf0f0f0f0f0f0f0f0, 0x0, AndUint64) { - t.Error("Data race detected!") - } -} - -func TestXorUint64(t *testing.T) { - if detectRaces64(0xf0f0f0f0f0f0f0f0, 0x0f0f0f0f0f0f0f0f, XorUint64) { - t.Error("Data race detected!") - } -} - -func TestCompareAndSwapUint32(t *testing.T) { - tests := []struct { - name string - prev uint32 - old uint32 - new uint32 - next uint32 - }{ - { - name: "Successful compare-and-swap with prev == new", - prev: 10, - old: 10, - new: 10, - next: 10, - }, - { - name: "Successful compare-and-swap with prev != new", - prev: 20, - old: 20, - new: 22, - next: 22, - }, - { - name: "Failed compare-and-swap with prev == new", - prev: 31, - old: 30, - new: 31, - next: 31, - }, - { - name: "Failed compare-and-swap with prev != new", - prev: 41, - old: 40, - new: 42, - next: 41, - }, - } - for _, test := range tests { - val := test.prev - prev := CompareAndSwapUint32(&val, test.old, test.new) - if got, want := prev, test.prev; got != want { - t.Errorf("%s: incorrect returned previous value: got %d, expected %d", test.name, got, want) - } - if got, want := val, test.next; got != want { - t.Errorf("%s: incorrect value stored in val: got %d, expected %d", test.name, got, want) - } - } -} - -func TestCompareAndSwapUint64(t *testing.T) { - tests := []struct { - name string - prev uint64 - old uint64 - new uint64 - next uint64 - }{ - { - name: "Successful compare-and-swap with prev == new", - prev: 0x100000000, - old: 0x100000000, - new: 0x100000000, - next: 0x100000000, - }, - { - name: "Successful compare-and-swap with prev != new", - prev: 0x200000000, - old: 0x200000000, - new: 0x200000002, - next: 0x200000002, - }, - { - name: "Failed compare-and-swap with prev == new", - prev: 0x300000001, - old: 0x300000000, - new: 0x300000001, - next: 0x300000001, - }, - { - name: "Failed compare-and-swap with prev != new", - prev: 0x400000001, - old: 0x400000000, - new: 0x400000002, - next: 0x400000001, - }, - } - for _, test := range tests { - val := test.prev - prev := CompareAndSwapUint64(&val, test.old, test.new) - if got, want := prev, test.prev; got != want { - t.Errorf("%s: incorrect returned previous value: got %d, expected %d", test.name, got, want) - } - if got, want := val, test.next; got != want { - t.Errorf("%s: incorrect value stored in val: got %d, expected %d", test.name, got, want) - } - } -} diff --git a/pkg/binary/BUILD b/pkg/binary/BUILD deleted file mode 100644 index 7ca2fda90..000000000 --- a/pkg/binary/BUILD +++ /dev/null @@ -1,16 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "binary", - srcs = ["binary.go"], - visibility = ["//:sandbox"], -) - -go_test( - name = "binary_test", - size = "small", - srcs = ["binary_test.go"], - library = ":binary", -) diff --git a/pkg/binary/binary.go b/pkg/binary/binary.go deleted file mode 100644 index 25065aef9..000000000 --- a/pkg/binary/binary.go +++ /dev/null @@ -1,266 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package binary translates between select fixed-sized types and a binary -// representation. -package binary - -import ( - "encoding/binary" - "fmt" - "io" - "reflect" -) - -// LittleEndian is the same as encoding/binary.LittleEndian. -// -// It is included here as a convenience. -var LittleEndian = binary.LittleEndian - -// BigEndian is the same as encoding/binary.BigEndian. -// -// It is included here as a convenience. -var BigEndian = binary.BigEndian - -// AppendUint16 appends the binary representation of a uint16 to buf. -func AppendUint16(buf []byte, order binary.ByteOrder, num uint16) []byte { - buf = append(buf, make([]byte, 2)...) - order.PutUint16(buf[len(buf)-2:], num) - return buf -} - -// AppendUint32 appends the binary representation of a uint32 to buf. -func AppendUint32(buf []byte, order binary.ByteOrder, num uint32) []byte { - buf = append(buf, make([]byte, 4)...) - order.PutUint32(buf[len(buf)-4:], num) - return buf -} - -// AppendUint64 appends the binary representation of a uint64 to buf. -func AppendUint64(buf []byte, order binary.ByteOrder, num uint64) []byte { - buf = append(buf, make([]byte, 8)...) - order.PutUint64(buf[len(buf)-8:], num) - return buf -} - -// Marshal appends a binary representation of data to buf. -// -// data must only contain fixed-length signed and unsigned ints, arrays, -// slices, structs and compositions of said types. data may be a pointer, -// but cannot contain pointers. -func Marshal(buf []byte, order binary.ByteOrder, data interface{}) []byte { - return marshal(buf, order, reflect.Indirect(reflect.ValueOf(data))) -} - -func marshal(buf []byte, order binary.ByteOrder, data reflect.Value) []byte { - switch data.Kind() { - case reflect.Int8: - buf = append(buf, byte(int8(data.Int()))) - case reflect.Int16: - buf = AppendUint16(buf, order, uint16(int16(data.Int()))) - case reflect.Int32: - buf = AppendUint32(buf, order, uint32(int32(data.Int()))) - case reflect.Int64: - buf = AppendUint64(buf, order, uint64(data.Int())) - - case reflect.Uint8: - buf = append(buf, byte(data.Uint())) - case reflect.Uint16: - buf = AppendUint16(buf, order, uint16(data.Uint())) - case reflect.Uint32: - buf = AppendUint32(buf, order, uint32(data.Uint())) - case reflect.Uint64: - buf = AppendUint64(buf, order, data.Uint()) - - case reflect.Array, reflect.Slice: - for i, l := 0, data.Len(); i < l; i++ { - buf = marshal(buf, order, data.Index(i)) - } - - case reflect.Struct: - for i, l := 0, data.NumField(); i < l; i++ { - buf = marshal(buf, order, data.Field(i)) - } - - default: - panic("invalid type: " + data.Type().String()) - } - return buf -} - -// Unmarshal unpacks buf into data. -// -// data must be a slice or a pointer and buf must have a length of exactly -// Size(data). data must only contain fixed-length signed and unsigned ints, -// arrays, slices, structs and compositions of said types. -func Unmarshal(buf []byte, order binary.ByteOrder, data interface{}) { - value := reflect.ValueOf(data) - switch value.Kind() { - case reflect.Ptr: - value = value.Elem() - case reflect.Slice: - default: - panic("invalid type: " + value.Type().String()) - } - buf = unmarshal(buf, order, value) - if len(buf) != 0 { - panic(fmt.Sprintf("buffer too long by %d bytes", len(buf))) - } -} - -func unmarshal(buf []byte, order binary.ByteOrder, data reflect.Value) []byte { - switch data.Kind() { - case reflect.Int8: - data.SetInt(int64(int8(buf[0]))) - buf = buf[1:] - case reflect.Int16: - data.SetInt(int64(int16(order.Uint16(buf)))) - buf = buf[2:] - case reflect.Int32: - data.SetInt(int64(int32(order.Uint32(buf)))) - buf = buf[4:] - case reflect.Int64: - data.SetInt(int64(order.Uint64(buf))) - buf = buf[8:] - - case reflect.Uint8: - data.SetUint(uint64(buf[0])) - buf = buf[1:] - case reflect.Uint16: - data.SetUint(uint64(order.Uint16(buf))) - buf = buf[2:] - case reflect.Uint32: - data.SetUint(uint64(order.Uint32(buf))) - buf = buf[4:] - case reflect.Uint64: - data.SetUint(order.Uint64(buf)) - buf = buf[8:] - - case reflect.Array, reflect.Slice: - for i, l := 0, data.Len(); i < l; i++ { - buf = unmarshal(buf, order, data.Index(i)) - } - - case reflect.Struct: - for i, l := 0, data.NumField(); i < l; i++ { - if field := data.Field(i); field.CanSet() { - buf = unmarshal(buf, order, field) - } else { - buf = buf[sizeof(field):] - } - } - - default: - panic("invalid type: " + data.Type().String()) - } - return buf -} - -// Size calculates the buffer sized needed by Marshal or Unmarshal. -// -// Size only support the types supported by Marshal. -func Size(v interface{}) uintptr { - return sizeof(reflect.Indirect(reflect.ValueOf(v))) -} - -func sizeof(data reflect.Value) uintptr { - switch data.Kind() { - case reflect.Int8, reflect.Uint8: - return 1 - case reflect.Int16, reflect.Uint16: - return 2 - case reflect.Int32, reflect.Uint32: - return 4 - case reflect.Int64, reflect.Uint64: - return 8 - - case reflect.Array, reflect.Slice: - var size uintptr - for i, l := 0, data.Len(); i < l; i++ { - size += sizeof(data.Index(i)) - } - return size - - case reflect.Struct: - var size uintptr - for i, l := 0, data.NumField(); i < l; i++ { - size += sizeof(data.Field(i)) - } - return size - - default: - panic("invalid type: " + data.Type().String()) - } -} - -// ReadUint16 reads a uint16 from r. -func ReadUint16(r io.Reader, order binary.ByteOrder) (uint16, error) { - buf := make([]byte, 2) - if _, err := io.ReadFull(r, buf); err != nil { - return 0, err - } - return order.Uint16(buf), nil -} - -// ReadUint32 reads a uint32 from r. -func ReadUint32(r io.Reader, order binary.ByteOrder) (uint32, error) { - buf := make([]byte, 4) - if _, err := io.ReadFull(r, buf); err != nil { - return 0, err - } - return order.Uint32(buf), nil -} - -// ReadUint64 reads a uint64 from r. -func ReadUint64(r io.Reader, order binary.ByteOrder) (uint64, error) { - buf := make([]byte, 8) - if _, err := io.ReadFull(r, buf); err != nil { - return 0, err - } - return order.Uint64(buf), nil -} - -// WriteUint16 writes a uint16 to w. -func WriteUint16(w io.Writer, order binary.ByteOrder, num uint16) error { - buf := make([]byte, 2) - order.PutUint16(buf, num) - _, err := w.Write(buf) - return err -} - -// WriteUint32 writes a uint32 to w. -func WriteUint32(w io.Writer, order binary.ByteOrder, num uint32) error { - buf := make([]byte, 4) - order.PutUint32(buf, num) - _, err := w.Write(buf) - return err -} - -// WriteUint64 writes a uint64 to w. -func WriteUint64(w io.Writer, order binary.ByteOrder, num uint64) error { - buf := make([]byte, 8) - order.PutUint64(buf, num) - _, err := w.Write(buf) - return err -} - -// AlignUp rounds a length up to an alignment. align must be a power of 2. -func AlignUp(length int, align uint) int { - return (length + int(align) - 1) & ^(int(align) - 1) -} - -// AlignDown rounds a length down to an alignment. align must be a power of 2. -func AlignDown(length int, align uint) int { - return length & ^(int(align) - 1) -} diff --git a/pkg/binary/binary_test.go b/pkg/binary/binary_test.go deleted file mode 100644 index 4d609a438..000000000 --- a/pkg/binary/binary_test.go +++ /dev/null @@ -1,266 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package binary - -import ( - "bytes" - "encoding/binary" - "errors" - "fmt" - "io" - "reflect" - "strings" - "testing" -) - -func newInt32(i int32) *int32 { - return &i -} - -func TestSize(t *testing.T) { - if got, want := Size(uint32(10)), uintptr(4); got != want { - t.Errorf("Got = %d, want = %d", got, want) - } -} - -func TestPanic(t *testing.T) { - tests := []struct { - name string - f func([]byte, binary.ByteOrder, interface{}) - data interface{} - want string - }{ - {"Unmarshal int", Unmarshal, 5, "invalid type: int"}, - {"Unmarshal []int", Unmarshal, []int{5}, "invalid type: int"}, - {"Marshal int", func(_ []byte, bo binary.ByteOrder, d interface{}) { Marshal(nil, bo, d) }, 5, "invalid type: int"}, - {"Marshal int[]", func(_ []byte, bo binary.ByteOrder, d interface{}) { Marshal(nil, bo, d) }, []int{5}, "invalid type: int"}, - {"Unmarshal short buffer", Unmarshal, newInt32(5), "runtime error: index out of range"}, - {"Unmarshal long buffer", func(_ []byte, bo binary.ByteOrder, d interface{}) { Unmarshal(make([]byte, 50), bo, d) }, newInt32(5), "buffer too long by 46 bytes"}, - {"marshal int", func(_ []byte, bo binary.ByteOrder, d interface{}) { marshal(nil, bo, reflect.ValueOf(d)) }, 5, "invalid type: int"}, - {"Size int", func(_ []byte, _ binary.ByteOrder, d interface{}) { Size(d) }, 5, "invalid type: int"}, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - defer func() { - r := recover() - if got := fmt.Sprint(r); !strings.HasPrefix(got, test.want) { - t.Errorf("Got recover() = %q, want prefix = %q", got, test.want) - } - }() - - test.f(nil, LittleEndian, test.data) - }) - } -} - -type inner struct { - Field int32 -} - -type outer struct { - Int8 int8 - Int16 int16 - Int32 int32 - Int64 int64 - Uint8 uint8 - Uint16 uint16 - Uint32 uint32 - Uint64 uint64 - - Slice []int32 - Array [5]int32 - Struct inner -} - -func TestMarshalUnmarshal(t *testing.T) { - want := outer{ - 1, 2, 3, 4, 5, 6, 7, 8, - []int32{9, 10, 11}, - [5]int32{12, 13, 14, 15, 16}, - inner{17}, - } - buf := Marshal(nil, LittleEndian, want) - got := outer{Slice: []int32{0, 0, 0}} - Unmarshal(buf, LittleEndian, &got) - if !reflect.DeepEqual(&got, &want) { - t.Errorf("Got = %#v, want = %#v", got, want) - } -} - -type outerBenchmark struct { - Int8 int8 - Int16 int16 - Int32 int32 - Int64 int64 - Uint8 uint8 - Uint16 uint16 - Uint32 uint32 - Uint64 uint64 - - Array [5]int32 - Struct inner -} - -func BenchmarkMarshalUnmarshal(b *testing.B) { - b.ReportAllocs() - - in := outerBenchmark{ - 1, 2, 3, 4, 5, 6, 7, 8, - [5]int32{9, 10, 11, 12, 13}, - inner{14}, - } - buf := make([]byte, Size(&in)) - out := outerBenchmark{} - - for i := 0; i < b.N; i++ { - buf := Marshal(buf[:0], LittleEndian, &in) - Unmarshal(buf, LittleEndian, &out) - } -} - -func BenchmarkReadWrite(b *testing.B) { - b.ReportAllocs() - - in := outerBenchmark{ - 1, 2, 3, 4, 5, 6, 7, 8, - [5]int32{9, 10, 11, 12, 13}, - inner{14}, - } - buf := bytes.NewBuffer(make([]byte, binary.Size(&in))) - out := outerBenchmark{} - - for i := 0; i < b.N; i++ { - buf.Reset() - if err := binary.Write(buf, LittleEndian, &in); err != nil { - b.Error("Write:", err) - } - if err := binary.Read(buf, LittleEndian, &out); err != nil { - b.Error("Read:", err) - } - } -} - -type outerPadding struct { - _ int8 - _ int16 - _ int32 - _ int64 - _ uint8 - _ uint16 - _ uint32 - _ uint64 - - _ []int32 - _ [5]int32 - _ inner -} - -func TestMarshalUnmarshalPadding(t *testing.T) { - var want outerPadding - buf := Marshal(nil, LittleEndian, want) - var got outerPadding - Unmarshal(buf, LittleEndian, &got) - if !reflect.DeepEqual(&got, &want) { - t.Errorf("Got = %#v, want = %#v", got, want) - } -} - -// Numbers with bits in every byte that distinguishable in big and little endian. -const ( - want16 = 64<<8 | 128 - want32 = 16<<24 | 32<<16 | want16 - want64 = 1<<56 | 2<<48 | 4<<40 | 8<<32 | want32 -) - -func TestReadWriteUint16(t *testing.T) { - const want = uint16(want16) - var buf bytes.Buffer - if err := WriteUint16(&buf, LittleEndian, want); err != nil { - t.Error("WriteUint16:", err) - } - got, err := ReadUint16(&buf, LittleEndian) - if err != nil { - t.Error("ReadUint16:", err) - } - if got != want { - t.Errorf("got = %d, want = %d", got, want) - } -} - -func TestReadWriteUint32(t *testing.T) { - const want = uint32(want32) - var buf bytes.Buffer - if err := WriteUint32(&buf, LittleEndian, want); err != nil { - t.Error("WriteUint32:", err) - } - got, err := ReadUint32(&buf, LittleEndian) - if err != nil { - t.Error("ReadUint32:", err) - } - if got != want { - t.Errorf("got = %d, want = %d", got, want) - } -} - -func TestReadWriteUint64(t *testing.T) { - const want = uint64(want64) - var buf bytes.Buffer - if err := WriteUint64(&buf, LittleEndian, want); err != nil { - t.Error("WriteUint64:", err) - } - got, err := ReadUint64(&buf, LittleEndian) - if err != nil { - t.Error("ReadUint64:", err) - } - if got != want { - t.Errorf("got = %d, want = %d", got, want) - } -} - -type readWriter struct { - err error -} - -func (rw *readWriter) Write([]byte) (int, error) { - return 0, rw.err -} - -func (rw *readWriter) Read([]byte) (int, error) { - return 0, rw.err -} - -func TestReadWriteError(t *testing.T) { - tests := []struct { - name string - f func(rw io.ReadWriter) error - }{ - {"WriteUint16", func(rw io.ReadWriter) error { return WriteUint16(rw, LittleEndian, 0) }}, - {"ReadUint16", func(rw io.ReadWriter) error { _, err := ReadUint16(rw, LittleEndian); return err }}, - {"WriteUint32", func(rw io.ReadWriter) error { return WriteUint32(rw, LittleEndian, 0) }}, - {"ReadUint32", func(rw io.ReadWriter) error { _, err := ReadUint32(rw, LittleEndian); return err }}, - {"WriteUint64", func(rw io.ReadWriter) error { return WriteUint64(rw, LittleEndian, 0) }}, - {"ReadUint64", func(rw io.ReadWriter) error { _, err := ReadUint64(rw, LittleEndian); return err }}, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - want := errors.New("want") - if got := test.f(&readWriter{want}); got != want { - t.Errorf("got = %v, want = %v", got, want) - } - }) - } -} diff --git a/pkg/bitmap/BUILD b/pkg/bitmap/BUILD deleted file mode 100644 index 0f1e7006d..000000000 --- a/pkg/bitmap/BUILD +++ /dev/null @@ -1,16 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "bitmap", - srcs = ["bitmap.go"], - visibility = ["//:sandbox"], -) - -go_test( - name = "bitmap_test", - size = "small", - srcs = ["bitmap_test.go"], - library = ":bitmap", -) diff --git a/pkg/bitmap/bitmap_state_autogen.go b/pkg/bitmap/bitmap_state_autogen.go new file mode 100644 index 000000000..132b27a56 --- /dev/null +++ b/pkg/bitmap/bitmap_state_autogen.go @@ -0,0 +1,39 @@ +// automatically generated by stateify. + +package bitmap + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (b *Bitmap) StateTypeName() string { + return "pkg/bitmap.Bitmap" +} + +func (b *Bitmap) StateFields() []string { + return []string{ + "numOnes", + "bitBlock", + } +} + +func (b *Bitmap) beforeSave() {} + +// +checklocksignore +func (b *Bitmap) StateSave(stateSinkObject state.Sink) { + b.beforeSave() + stateSinkObject.Save(0, &b.numOnes) + stateSinkObject.Save(1, &b.bitBlock) +} + +func (b *Bitmap) afterLoad() {} + +// +checklocksignore +func (b *Bitmap) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &b.numOnes) + stateSourceObject.Load(1, &b.bitBlock) +} + +func init() { + state.Register((*Bitmap)(nil)) +} diff --git a/pkg/bitmap/bitmap_test.go b/pkg/bitmap/bitmap_test.go deleted file mode 100644 index 76ebd779f..000000000 --- a/pkg/bitmap/bitmap_test.go +++ /dev/null @@ -1,306 +0,0 @@ -// Copyright 2021 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package bitmap - -import ( - "math" - "reflect" - "testing" -) - -// generateFilledSlice generates a slice fill with numbers. -func generateFilledSlice(min, max, length int) []uint32 { - if max == min { - return []uint32{uint32(min)} - } - if length > (max - min) { - return nil - } - randSlice := make([]uint32, length) - if length != 0 { - rangeNum := uint32((max - min) / length) - randSlice[0], randSlice[length-1] = uint32(min), uint32(max) - for i := 1; i < length-1; i++ { - randSlice[i] = randSlice[i-1] + rangeNum - } - } - return randSlice -} - -// generateFilledBitmap generates a Bitmap filled with fillNum of numbers, -// and returns the slice and bitmap. -func generateFilledBitmap(min, max, fillNum int) ([]uint32, Bitmap) { - bitmap := New(uint32(max)) - randSlice := generateFilledSlice(min, max, fillNum) - for i := 0; i < fillNum; i++ { - bitmap.Add(randSlice[i]) - } - return randSlice, bitmap -} - -func TestNewBitmap(t *testing.T) { - tests := []struct { - name string - size int - expectSize int - }{ - {"length 1", 1, 1}, - {"length 1024", 1024, 16}, - {"length 1025", 1025, 17}, - } - - for _, tt := range tests { - tt := tt - t.Run(tt.name, func(t *testing.T) { - if bitmap := New(uint32(tt.size)); len(bitmap.bitBlock) != tt.expectSize { - t.Errorf("New created bitmap with %v, bitBlock size: %d, wanted: %d", tt.name, len(bitmap.bitBlock), tt.expectSize) - } - }) - } -} - -func TestAdd(t *testing.T) { - tests := []struct { - name string - bitmapSize int - addNum int - }{ - {"Add with null bitmap.bitBlock", 0, 10}, - {"Add without extending bitBlock", 64, 10}, - {"Add without extending bitblock with margin number", 63, 64}, - {"Add with extended one block", 1024, 1025}, - {"Add with extended more then one block", 1024, 2048}, - } - - for _, tt := range tests { - tt := tt - t.Run(tt.name, func(t *testing.T) { - bitmap := New(uint32(tt.bitmapSize)) - bitmap.Add(uint32(tt.addNum)) - bitmapSlice := bitmap.ToSlice() - if bitmapSlice[0] != uint32(tt.addNum) { - t.Errorf("%v, get number: %d, wanted: %d.", tt.name, bitmapSlice[0], tt.addNum) - } - }) - } -} - -func TestRemove(t *testing.T) { - bitmap := New(uint32(1024)) - firstSlice := generateFilledSlice(0, 511, 50) - secondSlice := generateFilledSlice(512, 1024, 50) - for i := 0; i < 50; i++ { - bitmap.Add(firstSlice[i]) - bitmap.Add(secondSlice[i]) - } - - for i := 0; i < 50; i++ { - bitmap.Remove(firstSlice[i]) - } - bitmapSlice := bitmap.ToSlice() - if !reflect.DeepEqual(bitmapSlice, secondSlice) { - t.Errorf("After Remove() firstSlice, remained slice: %v, wanted: %v", bitmapSlice, secondSlice) - } - - for i := 0; i < 50; i++ { - bitmap.Remove(secondSlice[i]) - } - bitmapSlice = bitmap.ToSlice() - emptySlice := make([]uint32, 0) - if !reflect.DeepEqual(bitmapSlice, emptySlice) { - t.Errorf("After Remove secondSlice, remained slice: %v, wanted: %v", bitmapSlice, emptySlice) - } - -} - -// Verify flip bits within one bitBlock, one bit and bits cross multi bitBlocks. -func TestFlipRange(t *testing.T) { - tests := []struct { - name string - flipRangeMin int - flipRangeMax int - filledSliceLen int - }{ - {"Flip one number in bitmap", 77, 77, 1}, - {"Flip numbers within one bitBlocks", 5, 60, 20}, - {"Flip numbers that cross multi bitBlocks", 20, 1000, 300}, - } - - for _, tt := range tests { - tt := tt - t.Run(tt.name, func(t *testing.T) { - fillSlice, bitmap := generateFilledBitmap(tt.flipRangeMin, tt.flipRangeMax, tt.filledSliceLen) - flipFillSlice := make([]uint32, 0) - for i, j := tt.flipRangeMin, 0; i <= tt.flipRangeMax; i++ { - if uint32(i) != fillSlice[j] { - flipFillSlice = append(flipFillSlice, uint32(i)) - } else { - j++ - } - } - - bitmap.FlipRange(uint32(tt.flipRangeMin), uint32(tt.flipRangeMax+1)) - flipBitmapSlice := bitmap.ToSlice() - if !reflect.DeepEqual(flipFillSlice, flipBitmapSlice) { - t.Errorf("%v, flipped slice: %v, wanted: %v", tt.name, flipBitmapSlice, flipFillSlice) - } - }) - } -} - -// Verify clear bits within one bitBlock, one bit and bits cross multi bitBlocks. -func TestClearRange(t *testing.T) { - tests := []struct { - name string - clearRangeMin int - clearRangeMax int - bitmapSize int - }{ - {"ClearRange clear one number", 5, 5, 64}, - {"ClearRange clear numbers within one bitBlock", 4, 61, 64}, - {"ClearRange clear numbers cross multi bitBlocks", 20, 254, 512}, - } - - for _, tt := range tests { - tt := tt - t.Run(tt.name, func(t *testing.T) { - bitmap := New(uint32(tt.bitmapSize)) - bitmap.FlipRange(uint32(0), uint32(tt.bitmapSize)) - bitmap.ClearRange(uint32(tt.clearRangeMin), uint32(tt.clearRangeMax+1)) - clearedBitmapSlice := bitmap.ToSlice() - clearedSlice := make([]uint32, 0) - for i := 0; i < tt.bitmapSize; i++ { - if i < tt.clearRangeMin || i > tt.clearRangeMax { - clearedSlice = append(clearedSlice, uint32(i)) - } - } - if !reflect.DeepEqual(clearedSlice, clearedBitmapSlice) { - t.Errorf("%v, cleared slice: %v, wanted: %v", tt.name, clearedBitmapSlice, clearedSlice) - } - }) - - } -} - -func TestMinimum(t *testing.T) { - randSlice, bitmap := generateFilledBitmap(0, 1024, 200) - min := bitmap.Minimum() - if min != randSlice[0] { - t.Errorf("Minimum() returns: %v, wanted: %v", min, randSlice[0]) - } - - bitmap.ClearRange(uint32(0), uint32(200)) - min = bitmap.Minimum() - bitmapSlice := bitmap.ToSlice() - if min != bitmapSlice[0] { - t.Errorf("After ClearRange, Minimum() returns: %v, wanted: %v", min, bitmapSlice[0]) - } - - bitmap.FlipRange(uint32(2), uint32(11)) - min = bitmap.Minimum() - bitmapSlice = bitmap.ToSlice() - if min != bitmapSlice[0] { - t.Errorf("After Flip, Minimum() returns: %v, wanted: %v", min, bitmapSlice[0]) - } -} - -func TestMaximum(t *testing.T) { - randSlice, bitmap := generateFilledBitmap(0, 1024, 200) - max := bitmap.Maximum() - if max != randSlice[len(randSlice)-1] { - t.Errorf("Maximum() returns: %v, wanted: %v", max, randSlice[len(randSlice)-1]) - } - - bitmap.ClearRange(uint32(1000), uint32(1025)) - max = bitmap.Maximum() - bitmapSlice := bitmap.ToSlice() - if max != bitmapSlice[len(bitmapSlice)-1] { - t.Errorf("After ClearRange, Maximum() returns: %v, wanted: %v", max, bitmapSlice[len(bitmapSlice)-1]) - } - - bitmap.FlipRange(uint32(1001), uint32(1021)) - max = bitmap.Maximum() - bitmapSlice = bitmap.ToSlice() - if max != bitmapSlice[len(bitmapSlice)-1] { - t.Errorf("After Flip, Maximum() returns: %v, wanted: %v", max, bitmapSlice[len(bitmapSlice)-1]) - } -} - -func TestBitmapNumOnes(t *testing.T) { - randSlice, bitmap := generateFilledBitmap(0, 1024, 200) - bitmapOnes := bitmap.GetNumOnes() - if bitmapOnes != uint32(200) { - t.Errorf("GetNumOnes() returns: %v, wanted: %v", bitmapOnes, 200) - } - // Remove 10 numbers. - for i := 5; i < 15; i++ { - bitmap.Remove(randSlice[i]) - } - bitmapOnes = bitmap.GetNumOnes() - if bitmapOnes != uint32(190) { - t.Errorf("After Remove 10 number, GetNumOnes() returns: %v, wanted: %v", bitmapOnes, 190) - } - // Remove the 10 number again, the length supposed not change. - for i := 5; i < 15; i++ { - bitmap.Remove(randSlice[i]) - } - bitmapOnes = bitmap.GetNumOnes() - if bitmapOnes != uint32(190) { - t.Errorf("After Remove the 10 number again, GetNumOnes() returns: %v, wanted: %v", bitmapOnes, 190) - } - - // Add 10 number. - for i := 1080; i < 1090; i++ { - bitmap.Add(uint32(i)) - } - bitmapOnes = bitmap.GetNumOnes() - if bitmapOnes != uint32(200) { - t.Errorf("After Add 10 number, GetNumOnes() returns: %v, wanted: %v", bitmapOnes, 200) - } - - // Add the 10 number again, the length supposed not change. - for i := 1080; i < 1090; i++ { - bitmap.Add(uint32(i)) - } - bitmapOnes = bitmap.GetNumOnes() - if bitmapOnes != uint32(200) { - t.Errorf("After Add the 10 number again, GetNumOnes() returns: %v, wanted: %v", bitmapOnes, 200) - } - - // Flip 10 bits from 0 to 1. - bitmap.FlipRange(uint32(1025), uint32(1035)) - bitmapOnes = bitmap.GetNumOnes() - if bitmapOnes != uint32(210) { - t.Errorf("After Flip, GetNumOnes() returns: %v, wanted: %v", bitmapOnes, 210) - } - - // ClearRange numbers range from [0, 1025). - bitmap.ClearRange(uint32(0), uint32(1025)) - bitmapOnes = bitmap.GetNumOnes() - if bitmapOnes != uint32(20) { - t.Errorf("After ClearRange, GetNumOnes() returns: %v, wanted: %v", bitmapOnes, 20) - } -} - -func TestFirstZero(t *testing.T) { - bitmap := New(uint32(1000)) - bitmap.FlipRange(200, 400) - for i, j := range map[uint32]uint32{0: 0, 201: 400, 200: 400, 199: 199, 400: 400, 10000: math.MaxInt32} { - v := bitmap.FirstZero(i) - if v != j { - t.Errorf("Minimum() returns: %v, wanted: %v", v, j) - } - } -} diff --git a/pkg/bits/BUILD b/pkg/bits/BUILD deleted file mode 100644 index 63f4670d7..000000000 --- a/pkg/bits/BUILD +++ /dev/null @@ -1,55 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") -load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance") - -package(licenses = ["notice"]) - -go_library( - name = "bits", - srcs = [ - "bits.go", - "bits32.go", - "bits64.go", - "uint64_arch.go", - "uint64_arch_amd64_asm.s", - "uint64_arch_arm64_asm.s", - "uint64_arch_generic.go", - ], - visibility = ["//:sandbox"], -) - -go_template( - name = "bits_template", - srcs = ["bits_template.go"], - types = [ - "T", - ], -) - -go_template_instance( - name = "bits64", - out = "bits64.go", - package = "bits", - suffix = "64", - template = ":bits_template", - types = { - "T": "uint64", - }, -) - -go_template_instance( - name = "bits32", - out = "bits32.go", - package = "bits", - suffix = "32", - template = ":bits_template", - types = { - "T": "uint32", - }, -) - -go_test( - name = "bits_test", - size = "small", - srcs = ["uint64_test.go"], - library = ":bits", -) diff --git a/pkg/bits/bits32.go b/pkg/bits/bits32.go new file mode 100644 index 000000000..28134a9e7 --- /dev/null +++ b/pkg/bits/bits32.go @@ -0,0 +1,33 @@ +package bits + +// IsOn returns true if *all* bits set in 'bits' are set in 'mask'. +func IsOn32(mask, bits uint32) bool { + return mask&bits == bits +} + +// IsAnyOn returns true if *any* bit set in 'bits' is set in 'mask'. +func IsAnyOn32(mask, bits uint32) bool { + return mask&bits != 0 +} + +// Mask returns a T with all of the given bits set. +func Mask32(is ...int) uint32 { + ret := uint32(0) + for _, i := range is { + ret |= MaskOf32(i) + } + return ret +} + +// MaskOf is like Mask, but sets only a single bit (more efficiently). +func MaskOf32(i int) uint32 { + return uint32(1) << uint32(i) +} + +// IsPowerOfTwo returns true if v is power of 2. +func IsPowerOfTwo32(v uint32) bool { + if v == 0 { + return false + } + return v&(v-1) == 0 +} diff --git a/pkg/bits/bits64.go b/pkg/bits/bits64.go new file mode 100644 index 000000000..73117b19b --- /dev/null +++ b/pkg/bits/bits64.go @@ -0,0 +1,33 @@ +package bits + +// IsOn returns true if *all* bits set in 'bits' are set in 'mask'. +func IsOn64(mask, bits uint64) bool { + return mask&bits == bits +} + +// IsAnyOn returns true if *any* bit set in 'bits' is set in 'mask'. +func IsAnyOn64(mask, bits uint64) bool { + return mask&bits != 0 +} + +// Mask returns a T with all of the given bits set. +func Mask64(is ...int) uint64 { + ret := uint64(0) + for _, i := range is { + ret |= MaskOf64(i) + } + return ret +} + +// MaskOf is like Mask, but sets only a single bit (more efficiently). +func MaskOf64(i int) uint64 { + return uint64(1) << uint64(i) +} + +// IsPowerOfTwo returns true if v is power of 2. +func IsPowerOfTwo64(v uint64) bool { + if v == 0 { + return false + } + return v&(v-1) == 0 +} diff --git a/pkg/bits/bits_state_autogen.go b/pkg/bits/bits_state_autogen.go new file mode 100644 index 000000000..436c111bd --- /dev/null +++ b/pkg/bits/bits_state_autogen.go @@ -0,0 +1,8 @@ +// automatically generated by stateify. + +//go:build (amd64 || arm64) && !amd64 && !arm64 +// +build amd64 arm64 +// +build !amd64 +// +build !arm64 + +package bits diff --git a/pkg/bits/bits_template.go b/pkg/bits/bits_template.go deleted file mode 100644 index 998645388..000000000 --- a/pkg/bits/bits_template.go +++ /dev/null @@ -1,52 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package bits - -// Non-atomic bit operations on a template type T. - -// T is a required type parameter that must be an integral type. -type T uint64 - -// IsOn returns true if *all* bits set in 'bits' are set in 'mask'. -func IsOn(mask, bits T) bool { - return mask&bits == bits -} - -// IsAnyOn returns true if *any* bit set in 'bits' is set in 'mask'. -func IsAnyOn(mask, bits T) bool { - return mask&bits != 0 -} - -// Mask returns a T with all of the given bits set. -func Mask(is ...int) T { - ret := T(0) - for _, i := range is { - ret |= MaskOf(i) - } - return ret -} - -// MaskOf is like Mask, but sets only a single bit (more efficiently). -func MaskOf(i int) T { - return T(1) << T(i) -} - -// IsPowerOfTwo returns true if v is power of 2. -func IsPowerOfTwo(v T) bool { - if v == 0 { - return false - } - return v&(v-1) == 0 -} diff --git a/pkg/bits/uint64_test.go b/pkg/bits/uint64_test.go deleted file mode 100644 index 193d1ebcd..000000000 --- a/pkg/bits/uint64_test.go +++ /dev/null @@ -1,134 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package bits - -import ( - "reflect" - "testing" -) - -func TestTrailingZeros64(t *testing.T) { - for i := 0; i <= 64; i++ { - n := uint64(1) << uint(i) - if got, want := TrailingZeros64(n), i; got != want { - t.Errorf("TrailingZeros64(%#x): got %d, wanted %d", n, got, want) - } - } - - for i := 0; i < 64; i++ { - n := ^uint64(0) << uint(i) - if got, want := TrailingZeros64(n), i; got != want { - t.Errorf("TrailingZeros64(%#x): got %d, wanted %d", n, got, want) - } - } - - for i := 0; i < 64; i++ { - n := ^uint64(0) >> uint(i) - if got, want := TrailingZeros64(n), 0; got != want { - t.Errorf("TrailingZeros64(%#x): got %d, wanted %d", n, got, want) - } - } -} - -func TestMostSignificantOne64(t *testing.T) { - for i := 0; i <= 64; i++ { - n := uint64(1) << uint(i) - if got, want := MostSignificantOne64(n), i; got != want { - t.Errorf("MostSignificantOne64(%#x): got %d, wanted %d", n, got, want) - } - } - - for i := 0; i < 64; i++ { - n := ^uint64(0) >> uint(i) - if got, want := MostSignificantOne64(n), 63-i; got != want { - t.Errorf("MostSignificantOne64(%#x): got %d, wanted %d", n, got, want) - } - } - - for i := 0; i < 64; i++ { - n := ^uint64(0) << uint(i) - if got, want := MostSignificantOne64(n), 63; got != want { - t.Errorf("MostSignificantOne64(%#x): got %d, wanted %d", n, got, want) - } - } -} - -func TestForEachSetBit64(t *testing.T) { - for _, want := range [][]int{ - {}, - {0}, - {1}, - {63}, - {0, 1}, - {1, 3, 5}, - {0, 63}, - } { - n := Mask64(want...) - // "Slice values are deeply equal when ... they are both nil or both - // non-nil ..." - got := make([]int, 0) - ForEachSetBit64(n, func(i int) { - got = append(got, i) - }) - if !reflect.DeepEqual(got, want) { - t.Errorf("ForEachSetBit64(%#x): iterated bits %v, wanted %v", n, got, want) - } - } -} - -func TestIsOn(t *testing.T) { - type spec struct { - mask uint64 - bits uint64 - any bool - all bool - } - for _, s := range []spec{ - {Mask64(0), Mask64(0), true, true}, - {Mask64(63), Mask64(63), true, true}, - {Mask64(0), Mask64(1), false, false}, - {Mask64(0), Mask64(0, 1), true, false}, - - {Mask64(1, 63), Mask64(1), true, true}, - {Mask64(1, 63), Mask64(1, 63), true, true}, - {Mask64(1, 63), Mask64(0, 1, 63), true, false}, - {Mask64(1, 63), Mask64(0, 62), false, false}, - } { - if ok := IsAnyOn64(s.mask, s.bits); ok != s.any { - t.Errorf("IsAnyOn(%#x, %#x) = %v, wanted: %v", s.mask, s.bits, ok, s.any) - } - if ok := IsOn64(s.mask, s.bits); ok != s.all { - t.Errorf("IsOn(%#x, %#x) = %v, wanted: %v", s.mask, s.bits, ok, s.all) - } - } -} - -func TestIsPowerOfTwo(t *testing.T) { - for _, tc := range []struct { - v uint64 - want bool - }{ - {v: 0, want: false}, - {v: 1, want: true}, - {v: 2, want: true}, - {v: 3, want: false}, - {v: 4, want: true}, - {v: 5, want: false}, - } { - if got := IsPowerOfTwo64(tc.v); got != tc.want { - t.Errorf("IsPowerOfTwo(%d) = %t, want: %t", tc.v, got, tc.want) - } - } -} diff --git a/pkg/bpf/BUILD b/pkg/bpf/BUILD deleted file mode 100644 index c17390522..000000000 --- a/pkg/bpf/BUILD +++ /dev/null @@ -1,32 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "bpf", - srcs = [ - "bpf.go", - "decoder.go", - "input_bytes.go", - "interpreter.go", - "program_builder.go", - ], - visibility = ["//visibility:public"], - deps = ["//pkg/abi/linux"], -) - -go_test( - name = "bpf_test", - size = "small", - srcs = [ - "decoder_test.go", - "interpreter_test.go", - "program_builder_test.go", - ], - library = ":bpf", - deps = [ - "//pkg/abi/linux", - "//pkg/hostarch", - "//pkg/marshal", - ], -) diff --git a/pkg/bpf/bpf_state_autogen.go b/pkg/bpf/bpf_state_autogen.go new file mode 100644 index 000000000..3a020fe29 --- /dev/null +++ b/pkg/bpf/bpf_state_autogen.go @@ -0,0 +1,36 @@ +// automatically generated by stateify. + +package bpf + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (p *Program) StateTypeName() string { + return "pkg/bpf.Program" +} + +func (p *Program) StateFields() []string { + return []string{ + "instructions", + } +} + +func (p *Program) beforeSave() {} + +// +checklocksignore +func (p *Program) StateSave(stateSinkObject state.Sink) { + p.beforeSave() + stateSinkObject.Save(0, &p.instructions) +} + +func (p *Program) afterLoad() {} + +// +checklocksignore +func (p *Program) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &p.instructions) +} + +func init() { + state.Register((*Program)(nil)) +} diff --git a/pkg/bpf/decoder_test.go b/pkg/bpf/decoder_test.go deleted file mode 100644 index bb971ce21..000000000 --- a/pkg/bpf/decoder_test.go +++ /dev/null @@ -1,146 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package bpf - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/abi/linux" -) - -func TestDecode(t *testing.T) { - for _, test := range []struct { - filter linux.BPFInstruction - expected string - fail bool - }{ - {filter: Stmt(Ld+Imm, 10), expected: "A <- 10"}, - {filter: Stmt(Ld+Abs+W, 10), expected: "A <- P[10:4]"}, - {filter: Stmt(Ld+Ind+H, 10), expected: "A <- P[X+10:2]"}, - {filter: Stmt(Ld+Ind+B, 10), expected: "A <- P[X+10:1]"}, - {filter: Stmt(Ld+Mem, 10), expected: "A <- M[10]"}, - {filter: Stmt(Ld+Len, 0), expected: "A <- len"}, - {filter: Stmt(Ldx+Imm, 10), expected: "X <- 10"}, - {filter: Stmt(Ldx+Mem, 10), expected: "X <- M[10]"}, - {filter: Stmt(Ldx+Len, 0), expected: "X <- len"}, - {filter: Stmt(Ldx+Msh, 10), expected: "X <- 4*(P[10:1]&0xf)"}, - {filter: Stmt(St, 10), expected: "M[10] <- A"}, - {filter: Stmt(Stx, 10), expected: "M[10] <- X"}, - {filter: Stmt(Alu+Add+K, 10), expected: "A <- A + 10"}, - {filter: Stmt(Alu+Sub+K, 10), expected: "A <- A - 10"}, - {filter: Stmt(Alu+Mul+K, 10), expected: "A <- A * 10"}, - {filter: Stmt(Alu+Div+K, 10), expected: "A <- A / 10"}, - {filter: Stmt(Alu+Or+K, 10), expected: "A <- A | 10"}, - {filter: Stmt(Alu+And+K, 10), expected: "A <- A & 10"}, - {filter: Stmt(Alu+Lsh+K, 10), expected: "A <- A << 10"}, - {filter: Stmt(Alu+Rsh+K, 10), expected: "A <- A >> 10"}, - {filter: Stmt(Alu+Mod+K, 10), expected: "A <- A % 10"}, - {filter: Stmt(Alu+Xor+K, 10), expected: "A <- A ^ 10"}, - {filter: Stmt(Alu+Add+X, 0), expected: "A <- A + X"}, - {filter: Stmt(Alu+Sub+X, 0), expected: "A <- A - X"}, - {filter: Stmt(Alu+Mul+X, 0), expected: "A <- A * X"}, - {filter: Stmt(Alu+Div+X, 0), expected: "A <- A / X"}, - {filter: Stmt(Alu+Or+X, 0), expected: "A <- A | X"}, - {filter: Stmt(Alu+And+X, 0), expected: "A <- A & X"}, - {filter: Stmt(Alu+Lsh+X, 0), expected: "A <- A << X"}, - {filter: Stmt(Alu+Rsh+X, 0), expected: "A <- A >> X"}, - {filter: Stmt(Alu+Mod+X, 0), expected: "A <- A % X"}, - {filter: Stmt(Alu+Xor+X, 0), expected: "A <- A ^ X"}, - {filter: Stmt(Alu+Neg, 0), expected: "A <- -A"}, - {filter: Stmt(Jmp+Ja, 10), expected: "pc += 10"}, - {filter: Jump(Jmp+Jeq+K, 10, 2, 5), expected: "pc += (A == 10) ? 2 : 5"}, - {filter: Jump(Jmp+Jgt+K, 10, 2, 5), expected: "pc += (A > 10) ? 2 : 5"}, - {filter: Jump(Jmp+Jge+K, 10, 2, 5), expected: "pc += (A >= 10) ? 2 : 5"}, - {filter: Jump(Jmp+Jset+K, 10, 2, 5), expected: "pc += (A & 10) ? 2 : 5"}, - {filter: Jump(Jmp+Jeq+X, 0, 2, 5), expected: "pc += (A == X) ? 2 : 5"}, - {filter: Jump(Jmp+Jgt+X, 0, 2, 5), expected: "pc += (A > X) ? 2 : 5"}, - {filter: Jump(Jmp+Jge+X, 0, 2, 5), expected: "pc += (A >= X) ? 2 : 5"}, - {filter: Jump(Jmp+Jset+X, 0, 2, 5), expected: "pc += (A & X) ? 2 : 5"}, - {filter: Stmt(Ret+K, 10), expected: "ret 10"}, - {filter: Stmt(Ret+A, 0), expected: "ret A"}, - {filter: Stmt(Misc+Tax, 0), expected: "X <- A"}, - {filter: Stmt(Misc+Txa, 0), expected: "A <- X"}, - {filter: Stmt(Ld+Ind+Msh, 0), fail: true}, - } { - got, err := Decode(test.filter) - if test.fail { - if err == nil { - t.Errorf("Decode(%v) failed, expected: 'error', got: %q", test.filter, got) - continue - } - } else { - if err != nil { - t.Errorf("Decode(%v) failed for test %q, error: %q", test.filter, test.expected, err) - continue - } - if got != test.expected { - t.Errorf("Decode(%v) failed, expected: %q, got: %q", test.filter, test.expected, got) - continue - } - } - } -} - -func TestDecodeInstructions(t *testing.T) { - for _, test := range []struct { - name string - program []linux.BPFInstruction - expected string - fail bool - }{ - {name: "basic with jump indexes", - program: []linux.BPFInstruction{ - Stmt(Ld+Abs+W, 10), - Stmt(Ldx+Mem, 10), - Stmt(St, 10), - Stmt(Stx, 10), - Stmt(Alu+Add+K, 10), - Stmt(Jmp+Ja, 10), - Jump(Jmp+Jeq+K, 10, 2, 5), - Jump(Jmp+Jset+X, 0, 0, 5), - Stmt(Misc+Tax, 0), - }, - expected: "0: A <- P[10:4]\n" + - "1: X <- M[10]\n" + - "2: M[10] <- A\n" + - "3: M[10] <- X\n" + - "4: A <- A + 10\n" + - "5: pc += 10 [16]\n" + - "6: pc += (A == 10) ? 2 [9] : 5 [12]\n" + - "7: pc += (A & X) ? 0 [8] : 5 [13]\n" + - "8: X <- A\n", - }, - {name: "invalid instruction", - program: []linux.BPFInstruction{Stmt(Ld+Abs+W, 10), Stmt(Ld+Len+Mem, 0)}, - fail: true}, - } { - got, err := DecodeInstructions(test.program) - if test.fail { - if err == nil { - t.Errorf("%s: Decode(...) failed, expected: 'error', got: %q", test.name, got) - continue - } - } else { - if err != nil { - t.Errorf("%s: Decode failed: %v", test.name, err) - continue - } - if got != test.expected { - t.Errorf("%s: Decode(...) failed, expected: %q, got: %q", test.name, test.expected, got) - continue - } - } - } -} diff --git a/pkg/bpf/interpreter_test.go b/pkg/bpf/interpreter_test.go deleted file mode 100644 index f64a2dc50..000000000 --- a/pkg/bpf/interpreter_test.go +++ /dev/null @@ -1,799 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package bpf - -import ( - "encoding/binary" - "testing" - - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/hostarch" - "gvisor.dev/gvisor/pkg/marshal" -) - -func TestCompilationErrors(t *testing.T) { - for _, test := range []struct { - // desc is the test's description. - desc string - - // insns is the BPF instructions to be compiled. - insns []linux.BPFInstruction - - // expectedErr is the expected compilation error. - expectedErr error - }{ - { - desc: "Instructions must not be nil", - expectedErr: Error{InvalidInstructionCount, 0}, - }, - { - desc: "Instructions must not be empty", - insns: []linux.BPFInstruction{}, - expectedErr: Error{InvalidInstructionCount, 0}, - }, - { - desc: "A program must end with a return", - insns: make([]linux.BPFInstruction, MaxInstructions), - expectedErr: Error{InvalidEndOfProgram, MaxInstructions - 1}, - }, - { - desc: "A program must have MaxInstructions or fewer instructions", - insns: append(make([]linux.BPFInstruction, MaxInstructions), Stmt(Ret|K, 0)), - expectedErr: Error{InvalidInstructionCount, MaxInstructions + 1}, - }, - { - desc: "A load from an invalid M register is a compilation error", - insns: []linux.BPFInstruction{ - Stmt(Ld|Mem|W, ScratchMemRegisters), // A = M[16] - Stmt(Ret|K, 0), // return 0 - }, - expectedErr: Error{InvalidRegister, 0}, - }, - { - desc: "A store to an invalid M register is a compilation error", - insns: []linux.BPFInstruction{ - Stmt(St, ScratchMemRegisters), // M[16] = A - Stmt(Ret|K, 0), // return 0 - }, - expectedErr: Error{InvalidRegister, 0}, - }, - { - desc: "Division by literal zero is a compilation error", - insns: []linux.BPFInstruction{ - Stmt(Alu|Div|K, 0), // A /= 0 - Stmt(Ret|K, 0), // return 0 - }, - expectedErr: Error{DivisionByZero, 0}, - }, - { - desc: "An unconditional jump outside of the program is a compilation error", - insns: []linux.BPFInstruction{ - Jump(Jmp|Ja, 1, 0, 0), // jmp nextpc+1 - Stmt(Ret|K, 0), // return 0 - }, - expectedErr: Error{InvalidJumpTarget, 0}, - }, - { - desc: "A conditional jump outside of the program in the true case is a compilation error", - insns: []linux.BPFInstruction{ - Jump(Jmp|Jeq|K, 0, 1, 0), // if (A == K) jmp nextpc+1 - Stmt(Ret|K, 0), // return 0 - }, - expectedErr: Error{InvalidJumpTarget, 0}, - }, - { - desc: "A conditional jump outside of the program in the false case is a compilation error", - insns: []linux.BPFInstruction{ - Jump(Jmp|Jeq|K, 0, 0, 1), // if (A != K) jmp nextpc+1 - Stmt(Ret|K, 0), // return 0 - }, - expectedErr: Error{InvalidJumpTarget, 0}, - }, - } { - _, err := Compile(test.insns) - if err != test.expectedErr { - t.Errorf("%s: expected error %q, got error %q", test.desc, test.expectedErr, err) - } - } -} - -func TestExecErrors(t *testing.T) { - for _, test := range []struct { - // desc is the test's description. - desc string - - // insns is the BPF instructions to be executed. - insns []linux.BPFInstruction - - // expectedErr is the expected execution error. - expectedErr error - }{ - { - desc: "An out-of-bounds load of input data is an execution error", - insns: []linux.BPFInstruction{ - Stmt(Ld|Abs|B, 0), // A = input[0] - Stmt(Ret|K, 0), // return 0 - }, - expectedErr: Error{InvalidLoad, 0}, - }, - { - desc: "Division by zero at runtime is an execution error", - insns: []linux.BPFInstruction{ - Stmt(Alu|Div|X, 0), // A /= X - Stmt(Ret|K, 0), // return 0 - }, - expectedErr: Error{DivisionByZero, 0}, - }, - { - desc: "Modulo zero at runtime is an execution error", - insns: []linux.BPFInstruction{ - Stmt(Alu|Mod|X, 0), // A %= X - Stmt(Ret|K, 0), // return 0 - }, - expectedErr: Error{DivisionByZero, 0}, - }, - } { - p, err := Compile(test.insns) - if err != nil { - t.Errorf("%s: unexpected compilation error: %v", test.desc, err) - continue - } - ret, err := Exec(p, InputBytes{nil, binary.BigEndian}) - if err != test.expectedErr { - t.Errorf("%s: expected execution error %q, got (%d, %v)", test.desc, test.expectedErr, ret, err) - } - } -} - -func TestValidInstructions(t *testing.T) { - for _, test := range []struct { - // desc is the test's description. - desc string - - // insns is the BPF instructions to be compiled. - insns []linux.BPFInstruction - - // input is the input data. Note that input will be read as big-endian. - input []byte - - // expectedRet is the expected return value of the BPF program. - expectedRet uint32 - }{ - { - desc: "Return of immediate", - insns: []linux.BPFInstruction{ - Stmt(Ret|K, 42), // return 42 - }, - expectedRet: 42, - }, - { - desc: "Load of immediate into A", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 42), // A = 42 - Stmt(Ret|A, 0), // return A - }, - expectedRet: 42, - }, - { - desc: "Load of immediate into X and copying of X into A", - insns: []linux.BPFInstruction{ - Stmt(Ldx|Imm|W, 42), // X = 42 - Stmt(Misc|Tax, 0), // A = X - Stmt(Ret|A, 0), // return A - }, - expectedRet: 42, - }, - { - desc: "Copying of A into X and back", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 42), // A = 42 - Stmt(Misc|Txa, 0), // X = A - Stmt(Ld|Imm|W, 0), // A = 0 - Stmt(Misc|Tax, 0), // A = X - Stmt(Ret|A, 0), // return A - }, - expectedRet: 42, - }, - { - desc: "Load of 32-bit input by absolute offset into A", - insns: []linux.BPFInstruction{ - Stmt(Ld|Abs|W, 1), // A = input[1..4] - Stmt(Ret|A, 0), // return A - }, - input: []byte{0x00, 0x11, 0x22, 0x33, 0x44}, - expectedRet: 0x11223344, - }, - { - desc: "Load of 16-bit input by absolute offset into A", - insns: []linux.BPFInstruction{ - Stmt(Ld|Abs|H, 1), // A = input[1..2] - Stmt(Ret|A, 0), // return A - }, - input: []byte{0x00, 0x11, 0x22}, - expectedRet: 0x1122, - }, - { - desc: "Load of 8-bit input by absolute offset into A", - insns: []linux.BPFInstruction{ - Stmt(Ld|Abs|B, 1), // A = input[1] - Stmt(Ret|A, 0), // return A - }, - input: []byte{0x00, 0x11}, - expectedRet: 0x11, - }, - { - desc: "Load of 32-bit input by relative offset into A", - insns: []linux.BPFInstruction{ - Stmt(Ldx|Imm|W, 1), // X = 1 - Stmt(Ld|Ind|W, 1), // A = input[X+1..X+4] - Stmt(Ret|A, 0), // return A - }, - input: []byte{0x00, 0x11, 0x22, 0x33, 0x44, 0x55}, - expectedRet: 0x22334455, - }, - { - desc: "Load of 16-bit input by relative offset into A", - insns: []linux.BPFInstruction{ - Stmt(Ldx|Imm|W, 1), // X = 1 - Stmt(Ld|Ind|H, 1), // A = input[X+1..X+2] - Stmt(Ret|A, 0), // return A - }, - input: []byte{0x00, 0x11, 0x22, 0x33}, - expectedRet: 0x2233, - }, - { - desc: "Load of 8-bit input by relative offset into A", - insns: []linux.BPFInstruction{ - Stmt(Ldx|Imm|W, 1), // X = 1 - Stmt(Ld|Ind|B, 1), // A = input[X+1] - Stmt(Ret|A, 0), // return A - }, - input: []byte{0x00, 0x11, 0x22}, - expectedRet: 0x22, - }, - { - desc: "Load/store between A and scratch memory", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 42), // A = 42 - Stmt(St, 2), // M[2] = A - Stmt(Ld|Imm|W, 0), // A = 0 - Stmt(Ld|Mem|W, 2), // A = M[2] - Stmt(Ret|A, 0), // return A - }, - expectedRet: 42, - }, - { - desc: "Load/store between X and scratch memory", - insns: []linux.BPFInstruction{ - Stmt(Ldx|Imm|W, 42), // X = 42 - Stmt(Stx, 3), // M[3] = X - Stmt(Ldx|Imm|W, 0), // X = 0 - Stmt(Ldx|Mem|W, 3), // X = M[3] - Stmt(Misc|Tax, 0), // A = X - Stmt(Ret|A, 0), // return A - }, - expectedRet: 42, - }, - { - desc: "Load of input length into A", - insns: []linux.BPFInstruction{ - Stmt(Ld|Len|W, 0), // A = len(input) - Stmt(Ret|A, 0), // return A - }, - input: []byte{1, 2, 3}, - expectedRet: 3, - }, - { - desc: "Load of input length into X", - insns: []linux.BPFInstruction{ - Stmt(Ldx|Len|W, 0), // X = len(input) - Stmt(Misc|Tax, 0), // A = X - Stmt(Ret|A, 0), // return A - }, - input: []byte{1, 2, 3}, - expectedRet: 3, - }, - { - desc: "Load of MSH (?) into X", - insns: []linux.BPFInstruction{ - Stmt(Ldx|Msh|B, 0), // X = 4*(input[0]&0xf) - Stmt(Misc|Tax, 0), // A = X - Stmt(Ret|A, 0), // return A - }, - input: []byte{0xf1}, - expectedRet: 4, - }, - { - desc: "Addition of immediate", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 10), // A = 10 - Stmt(Alu|Add|K, 20), // A += 20 - Stmt(Ret|A, 0), // return A - }, - expectedRet: 30, - }, - { - desc: "Addition of X", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 10), // A = 10 - Stmt(Ldx|Imm|W, 20), // X = 20 - Stmt(Alu|Add|X, 0), // A += X - Stmt(Ret|A, 0), // return A - }, - expectedRet: 30, - }, - { - desc: "Subtraction of immediate", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 30), // A = 30 - Stmt(Alu|Sub|K, 20), // A -= 20 - Stmt(Ret|A, 0), // return A - }, - expectedRet: 10, - }, - { - desc: "Subtraction of X", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 30), // A = 30 - Stmt(Ldx|Imm|W, 20), // X = 20 - Stmt(Alu|Sub|X, 0), // A -= X - Stmt(Ret|A, 0), // return A - }, - expectedRet: 10, - }, - { - desc: "Multiplication of immediate", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 2), // A = 2 - Stmt(Alu|Mul|K, 3), // A *= 3 - Stmt(Ret|A, 0), // return A - }, - expectedRet: 6, - }, - { - desc: "Multiplication of X", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 2), // A = 2 - Stmt(Ldx|Imm|W, 3), // X = 3 - Stmt(Alu|Mul|X, 0), // A *= X - Stmt(Ret|A, 0), // return A - }, - expectedRet: 6, - }, - { - desc: "Division by immediate", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 6), // A = 6 - Stmt(Alu|Div|K, 3), // A /= 3 - Stmt(Ret|A, 0), // return A - }, - expectedRet: 2, - }, - { - desc: "Division by X", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 6), // A = 6 - Stmt(Ldx|Imm|W, 3), // X = 3 - Stmt(Alu|Div|X, 0), // A /= X - Stmt(Ret|A, 0), // return A - }, - expectedRet: 2, - }, - { - desc: "Modulo immediate", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 17), // A = 17 - Stmt(Alu|Mod|K, 7), // A %= 7 - Stmt(Ret|A, 0), // return A - }, - expectedRet: 3, - }, - { - desc: "Modulo X", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 17), // A = 17 - Stmt(Ldx|Imm|W, 7), // X = 7 - Stmt(Alu|Mod|X, 0), // A %= X - Stmt(Ret|A, 0), // return A - }, - expectedRet: 3, - }, - { - desc: "Arithmetic negation", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 1), // A = 1 - Stmt(Alu|Neg, 0), // A = -A - Stmt(Ret|A, 0), // return A - }, - expectedRet: 0xffffffff, - }, - { - desc: "Bitwise OR with immediate", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 0xff00aa55), // A = 0xff00aa55 - Stmt(Alu|Or|K, 0xff0055aa), // A |= 0xff0055aa - Stmt(Ret|A, 0), // return A - }, - expectedRet: 0xff00ffff, - }, - { - desc: "Bitwise OR with X", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 0xff00aa55), // A = 0xff00aa55 - Stmt(Ldx|Imm|W, 0xff0055aa), // X = 0xff0055aa - Stmt(Alu|Or|X, 0), // A |= X - Stmt(Ret|A, 0), // return A - }, - expectedRet: 0xff00ffff, - }, - { - desc: "Bitwise AND with immediate", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 0xff00aa55), // A = 0xff00aa55 - Stmt(Alu|And|K, 0xff0055aa), // A &= 0xff0055aa - Stmt(Ret|A, 0), // return A - }, - expectedRet: 0xff000000, - }, - { - desc: "Bitwise AND with X", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 0xff00aa55), // A = 0xff00aa55 - Stmt(Ldx|Imm|W, 0xff0055aa), // X = 0xff0055aa - Stmt(Alu|And|X, 0), // A &= X - Stmt(Ret|A, 0), // return A - }, - expectedRet: 0xff000000, - }, - { - desc: "Bitwise XOR with immediate", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 0xff00aa55), // A = 0xff00aa55 - Stmt(Alu|Xor|K, 0xff0055aa), // A ^= 0xff0055aa - Stmt(Ret|A, 0), // return A - }, - expectedRet: 0x0000ffff, - }, - { - desc: "Bitwise XOR with X", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 0xff00aa55), // A = 0xff00aa55 - Stmt(Ldx|Imm|W, 0xff0055aa), // X = 0xff0055aa - Stmt(Alu|Xor|X, 0), // A ^= X - Stmt(Ret|A, 0), // return A - }, - expectedRet: 0x0000ffff, - }, - { - desc: "Left shift by immediate", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 1), // A = 1 - Stmt(Alu|Lsh|K, 5), // A <<= 5 - Stmt(Ret|A, 0), // return A - }, - expectedRet: 32, - }, - { - desc: "Left shift by X", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 1), // A = 1 - Stmt(Ldx|Imm|W, 5), // X = 5 - Stmt(Alu|Lsh|X, 0), // A <<= X - Stmt(Ret|A, 0), // return A - }, - expectedRet: 32, - }, - { - desc: "Right shift by immediate", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 0xffffffff), // A = 0xffffffff - Stmt(Alu|Rsh|K, 31), // A >>= 31 - Stmt(Ret|A, 0), // return A - }, - expectedRet: 1, - }, - { - desc: "Right shift by X", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 0xffffffff), // A = 0xffffffff - Stmt(Ldx|Imm|W, 31), // X = 31 - Stmt(Alu|Rsh|X, 0), // A >>= X - Stmt(Ret|A, 0), // return A - }, - expectedRet: 1, - }, - { - desc: "Unconditional jump", - insns: []linux.BPFInstruction{ - Jump(Jmp|Ja, 1, 0, 0), // jmp nextpc+1 - Stmt(Ret|K, 0), // return 0 - Stmt(Ret|K, 1), // return 1 - }, - expectedRet: 1, - }, - { - desc: "Jump when A == immediate", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 42), // A = 42 - Jump(Jmp|Jeq|K, 42, 1, 2), // if (A == 42) jmp nextpc+1 else jmp nextpc+2 - Stmt(Ret|K, 0), // return 0 - Stmt(Ret|K, 1), // return 1 - Stmt(Ret|K, 2), // return 2 - }, - expectedRet: 1, - }, - { - desc: "Jump when A != immediate", - insns: []linux.BPFInstruction{ - Jump(Jmp|Jeq|K, 42, 1, 2), // if (A == 42) jmp nextpc+1 else jmp nextpc+2 - Stmt(Ret|K, 0), // return 0 - Stmt(Ret|K, 1), // return 1 - Stmt(Ret|K, 2), // return 2 - }, - expectedRet: 2, - }, - { - desc: "Jump when A == X", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 42), // A = 42 - Stmt(Ldx|Imm|W, 42), // X = 42 - Jump(Jmp|Jeq|X, 0, 1, 2), // if (A == X) jmp nextpc+1 else jmp nextpc+2 - Stmt(Ret|K, 0), // return 0 - Stmt(Ret|K, 1), // return 1 - Stmt(Ret|K, 2), // return 2 - }, - expectedRet: 1, - }, - { - desc: "Jump when A != X", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 42), // A = 42 - Jump(Jmp|Jeq|X, 0, 1, 2), // if (A == X) jmp nextpc+1 else jmp nextpc+2 - Stmt(Ret|K, 0), // return 0 - Stmt(Ret|K, 1), // return 1 - Stmt(Ret|K, 2), // return 2 - }, - expectedRet: 2, - }, - { - desc: "Jump when A > immediate", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 10), // A = 10 - Jump(Jmp|Jgt|K, 9, 1, 2), // if (A > 9) jmp nextpc+1 else jmp nextpc+2 - Stmt(Ret|K, 0), // return 0 - Stmt(Ret|K, 1), // return 1 - Stmt(Ret|K, 2), // return 2 - }, - expectedRet: 1, - }, - { - desc: "Jump when A <= immediate", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 10), // A = 10 - Jump(Jmp|Jgt|K, 10, 1, 2), // if (A > 10) jmp nextpc+1 else jmp nextpc+2 - Stmt(Ret|K, 0), // return 0 - Stmt(Ret|K, 1), // return 1 - Stmt(Ret|K, 2), // return 2 - }, - expectedRet: 2, - }, - { - desc: "Jump when A > X", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 10), // A = 10 - Stmt(Ldx|Imm|W, 9), // X = 9 - Jump(Jmp|Jgt|X, 0, 1, 2), // if (A > X) jmp nextpc+1 else jmp nextpc+2 - Stmt(Ret|K, 0), // return 0 - Stmt(Ret|K, 1), // return 1 - Stmt(Ret|K, 2), // return 2 - }, - expectedRet: 1, - }, - { - desc: "Jump when A <= X", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 10), // A = 10 - Stmt(Ldx|Imm|W, 10), // X = 10 - Jump(Jmp|Jgt|X, 0, 1, 2), // if (A > X) jmp nextpc+1 else jmp nextpc+2 - Stmt(Ret|K, 0), // return 0 - Stmt(Ret|K, 1), // return 1 - Stmt(Ret|K, 2), // return 2 - }, - expectedRet: 2, - }, - { - desc: "Jump when A >= immediate", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 10), // A = 10 - Jump(Jmp|Jge|K, 10, 1, 2), // if (A >= 10) jmp nextpc+1 else jmp nextpc+2 - Stmt(Ret|K, 0), // return 0 - Stmt(Ret|K, 1), // return 1 - Stmt(Ret|K, 2), // return 2 - }, - expectedRet: 1, - }, - { - desc: "Jump when A < immediate", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 10), // A = 10 - Jump(Jmp|Jge|K, 11, 1, 2), // if (A >= 11) jmp nextpc+1 else jmp nextpc+2 - Stmt(Ret|K, 0), // return 0 - Stmt(Ret|K, 1), // return 1 - Stmt(Ret|K, 2), // return 2 - }, - expectedRet: 2, - }, - { - desc: "Jump when A >= X", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 10), // A = 10 - Stmt(Ldx|Imm|W, 10), // X = 10 - Jump(Jmp|Jge|X, 0, 1, 2), // if (A >= X) jmp nextpc+1 else jmp nextpc+2 - Stmt(Ret|K, 0), // return 0 - Stmt(Ret|K, 1), // return 1 - Stmt(Ret|K, 2), // return 2 - }, - expectedRet: 1, - }, - { - desc: "Jump when A < X", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 10), // A = 10 - Stmt(Ldx|Imm|W, 11), // X = 11 - Jump(Jmp|Jge|X, 0, 1, 2), // if (A >= X) jmp nextpc+1 else jmp nextpc+2 - Stmt(Ret|K, 0), // return 0 - Stmt(Ret|K, 1), // return 1 - Stmt(Ret|K, 2), // return 2 - }, - expectedRet: 2, - }, - { - desc: "Jump when A & immediate != 0", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 0xff), // A = 0xff - Jump(Jmp|Jset|K, 0x101, 1, 2), // if (A & 0x101) jmp nextpc+1 else jmp nextpc+2 - Stmt(Ret|K, 0), // return 0 - Stmt(Ret|K, 1), // return 1 - Stmt(Ret|K, 2), // return 2 - }, - expectedRet: 1, - }, - { - desc: "Jump when A & immediate == 0", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 0xfe), // A = 0xfe - Jump(Jmp|Jset|K, 0x101, 1, 2), // if (A & 0x101) jmp nextpc+1 else jmp nextpc+2 - Stmt(Ret|K, 0), // return 0 - Stmt(Ret|K, 1), // return 1 - Stmt(Ret|K, 2), // return 2 - }, - expectedRet: 2, - }, - { - desc: "Jump when A & X != 0", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 0xff), // A = 0xff - Stmt(Ldx|Imm|W, 0x101), // X = 0x101 - Jump(Jmp|Jset|X, 0, 1, 2), // if (A & X) jmp nextpc+1 else jmp nextpc+2 - Stmt(Ret|K, 0), // return 0 - Stmt(Ret|K, 1), // return 1 - Stmt(Ret|K, 2), // return 2 - }, - expectedRet: 1, - }, - { - desc: "Jump when A & X == 0", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 0xfe), // A = 0xfe - Stmt(Ldx|Imm|W, 0x101), // X = 0x101 - Jump(Jmp|Jset|X, 0, 1, 2), // if (A & X) jmp nextpc+1 else jmp nextpc+2 - Stmt(Ret|K, 0), // return 0 - Stmt(Ret|K, 1), // return 1 - Stmt(Ret|K, 2), // return 2 - }, - expectedRet: 2, - }, - } { - p, err := Compile(test.insns) - if err != nil { - t.Errorf("%s: unexpected compilation error: %v", test.desc, err) - continue - } - ret, err := Exec(p, InputBytes{test.input, binary.BigEndian}) - if err != nil { - t.Errorf("%s: expected return value of %d, got execution error: %v", test.desc, test.expectedRet, err) - continue - } - if ret != test.expectedRet { - t.Errorf("%s: expected return value of %d, got value %d", test.desc, test.expectedRet, ret) - } - } -} - -func TestSimpleFilter(t *testing.T) { - // Seccomp filter example given in Linux's - // Documentation/networking/filter.txt, translated to bytecode using the - // Linux kernel tree's tools/net/bpf_asm. - filter := []linux.BPFInstruction{ - {0x20, 0, 0, 0x00000004}, // ld [4] /* offsetof(struct seccomp_data, arch) */ - {0x15, 0, 11, 0xc000003e}, // jne #0xc000003e, bad /* AUDIT_ARCH_X86_64 */ - {0x20, 0, 0, 0000000000}, // ld [0] /* offsetof(struct seccomp_data, nr) */ - {0x15, 10, 0, 0x0000000f}, // jeq #15, good /* __NR_rt_sigreturn */ - {0x15, 9, 0, 0x000000e7}, // jeq #231, good /* __NR_exit_group */ - {0x15, 8, 0, 0x0000003c}, // jeq #60, good /* __NR_exit */ - {0x15, 7, 0, 0000000000}, // jeq #0, good /* __NR_read */ - {0x15, 6, 0, 0x00000001}, // jeq #1, good /* __NR_write */ - {0x15, 5, 0, 0x00000005}, // jeq #5, good /* __NR_fstat */ - {0x15, 4, 0, 0x00000009}, // jeq #9, good /* __NR_mmap */ - {0x15, 3, 0, 0x0000000e}, // jeq #14, good /* __NR_rt_sigprocmask */ - {0x15, 2, 0, 0x0000000d}, // jeq #13, good /* __NR_rt_sigaction */ - {0x15, 1, 0, 0x00000023}, // jeq #35, good /* __NR_nanosleep */ - {0x06, 0, 0, 0000000000}, // bad: ret #0 /* SECCOMP_RET_KILL */ - {0x06, 0, 0, 0x7fff0000}, // good: ret #0x7fff0000 /* SECCOMP_RET_ALLOW */ - } - p, err := Compile(filter) - if err != nil { - t.Fatalf("Unexpected compilation error: %v", err) - } - - for _, test := range []struct { - // desc is the test's description. - desc string - - // SeccompData is the input data. - data linux.SeccompData - - // expectedRet is the expected return value of the BPF program. - expectedRet uint32 - }{ - { - desc: "Invalid arch is rejected", - data: linux.SeccompData{Nr: 1 /* x86 exit */, Arch: 0x40000003 /* AUDIT_ARCH_I386 */}, - expectedRet: 0, - }, - { - desc: "Disallowed syscall is rejected", - data: linux.SeccompData{Nr: 105 /* __NR_setuid */, Arch: 0xc000003e}, - expectedRet: 0, - }, - { - desc: "Allowed syscall is indeed allowed", - data: linux.SeccompData{Nr: 231 /* __NR_exit_group */, Arch: 0xc000003e}, - expectedRet: 0x7fff0000, - }, - } { - ret, err := Exec(p, dataAsInput(&test.data)) - if err != nil { - t.Errorf("%s: expected return value of %d, got execution error: %v", test.desc, test.expectedRet, err) - continue - } - if ret != test.expectedRet { - t.Errorf("%s: expected return value of %d, got value %d", test.desc, test.expectedRet, ret) - } - } -} - -// seccompData is equivalent to struct seccomp_data. -type seccompData struct { - nr uint32 - arch uint32 - instructionPointer uint64 - args [6]uint64 -} - -// asInput converts a seccompData to a bpf.Input. -func dataAsInput(data *linux.SeccompData) Input { - return InputBytes{marshal.Marshal(data), hostarch.ByteOrder} -} diff --git a/pkg/bpf/program_builder_test.go b/pkg/bpf/program_builder_test.go deleted file mode 100644 index 37f684f25..000000000 --- a/pkg/bpf/program_builder_test.go +++ /dev/null @@ -1,185 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package bpf - -import ( - "fmt" - "testing" - - "gvisor.dev/gvisor/pkg/abi/linux" -) - -func validate(p *ProgramBuilder, expected []linux.BPFInstruction) error { - instructions, err := p.Instructions() - if err != nil { - return fmt.Errorf("Instructions() failed: %v", err) - } - got, err := DecodeInstructions(instructions) - if err != nil { - return fmt.Errorf("DecodeInstructions('instructions') failed: %v", err) - } - expectedDecoded, err := DecodeInstructions(expected) - if err != nil { - return fmt.Errorf("DecodeInstructions('expected') failed: %v", err) - } - if got != expectedDecoded { - return fmt.Errorf("DecodeInstructions() failed, expected: %q, got: %q", expectedDecoded, got) - } - return nil -} - -func TestProgramBuilderSimple(t *testing.T) { - p := NewProgramBuilder() - p.AddStmt(Ld+Abs+W, 10) - p.AddJump(Jmp+Ja, 10, 0, 0) - - expected := []linux.BPFInstruction{ - Stmt(Ld+Abs+W, 10), - Jump(Jmp+Ja, 10, 0, 0), - } - - if err := validate(p, expected); err != nil { - t.Errorf("Validate() failed: %v", err) - } -} - -func TestProgramBuilderLabels(t *testing.T) { - p := NewProgramBuilder() - p.AddJumpTrueLabel(Jmp+Jeq+K, 11, "label_1", 0) - p.AddJumpFalseLabel(Jmp+Jeq+K, 12, 0, "label_2") - p.AddJumpLabels(Jmp+Jeq+K, 13, "label_3", "label_4") - if err := p.AddLabel("label_1"); err != nil { - t.Errorf("AddLabel(label_1) failed: %v", err) - } - p.AddStmt(Ld+Abs+W, 1) - if err := p.AddLabel("label_3"); err != nil { - t.Errorf("AddLabel(label_3) failed: %v", err) - } - p.AddJumpLabels(Jmp+Jeq+K, 14, "label_4", "label_5") - if err := p.AddLabel("label_2"); err != nil { - t.Errorf("AddLabel(label_2) failed: %v", err) - } - p.AddJumpLabels(Jmp+Jeq+K, 15, "label_4", "label_6") - if err := p.AddLabel("label_4"); err != nil { - t.Errorf("AddLabel(label_4) failed: %v", err) - } - p.AddStmt(Ld+Abs+W, 4) - if err := p.AddLabel("label_5"); err != nil { - t.Errorf("AddLabel(label_5) failed: %v", err) - } - if err := p.AddLabel("label_6"); err != nil { - t.Errorf("AddLabel(label_6) failed: %v", err) - } - p.AddStmt(Ld+Abs+W, 5) - - expected := []linux.BPFInstruction{ - Jump(Jmp+Jeq+K, 11, 2, 0), - Jump(Jmp+Jeq+K, 12, 0, 3), - Jump(Jmp+Jeq+K, 13, 1, 3), - Stmt(Ld+Abs+W, 1), - Jump(Jmp+Jeq+K, 14, 1, 2), - Jump(Jmp+Jeq+K, 15, 0, 1), - Stmt(Ld+Abs+W, 4), - Stmt(Ld+Abs+W, 5), - } - if err := validate(p, expected); err != nil { - t.Errorf("Validate() failed: %v", err) - } - // Calling validate()=>p.Instructions() again to make sure - // Instructions can be called multiple times without ruining - // the program. - if err := validate(p, expected); err != nil { - t.Errorf("Validate() failed: %v", err) - } -} - -func TestProgramBuilderMissingErrorTarget(t *testing.T) { - p := NewProgramBuilder() - p.AddJumpTrueLabel(Jmp+Jeq+K, 10, "label_1", 0) - if _, err := p.Instructions(); err == nil { - t.Errorf("Instructions() should have failed") - } -} - -func TestProgramBuilderLabelWithNoInstruction(t *testing.T) { - p := NewProgramBuilder() - p.AddJumpTrueLabel(Jmp+Jeq+K, 10, "label_1", 0) - if err := p.AddLabel("label_1"); err != nil { - t.Errorf("AddLabel(label_1) failed: %v", err) - } - if _, err := p.Instructions(); err == nil { - t.Errorf("Instructions() should have failed") - } -} - -// TestProgramBuilderUnusedLabel tests that adding an unused label doesn't -// cause program generation to fail. -func TestProgramBuilderUnusedLabel(t *testing.T) { - p := NewProgramBuilder() - p.AddStmt(Ld+Abs+W, 10) - p.AddJump(Jmp+Ja, 10, 0, 0) - - expected := []linux.BPFInstruction{ - Stmt(Ld+Abs+W, 10), - Jump(Jmp+Ja, 10, 0, 0), - } - - if err := p.AddLabel("unused"); err != nil { - t.Errorf("AddLabel(unused) should have succeeded") - } - - if err := validate(p, expected); err != nil { - t.Errorf("Validate() failed: %v", err) - } -} - -// TestProgramBuilderBackwardsReference tests that including a backwards -// reference to a label in a program causes a failure. -func TestProgramBuilderBackwardsReference(t *testing.T) { - p := NewProgramBuilder() - if err := p.AddLabel("bw_label"); err != nil { - t.Errorf("failed to add label") - } - p.AddStmt(Ld+Abs+W, 10) - p.AddJumpTrueLabel(Jmp+Jeq+K, 10, "bw_label", 0) - if _, err := p.Instructions(); err == nil { - t.Errorf("Instructions() should have failed") - } -} - -func TestProgramBuilderLabelAddedTwice(t *testing.T) { - p := NewProgramBuilder() - p.AddJumpTrueLabel(Jmp+Jeq+K, 10, "label_1", 0) - if err := p.AddLabel("label_1"); err != nil { - t.Errorf("AddLabel(label_1) failed: %v", err) - } - p.AddStmt(Ld+Abs+W, 0) - if err := p.AddLabel("label_1"); err == nil { - t.Errorf("AddLabel(label_1) failed: %v", err) - } -} - -func TestProgramBuilderJumpBackwards(t *testing.T) { - p := NewProgramBuilder() - p.AddJumpTrueLabel(Jmp+Jeq+K, 10, "label_1", 0) - if err := p.AddLabel("label_1"); err != nil { - t.Errorf("AddLabel(label_1) failed: %v", err) - } - p.AddStmt(Ld+Abs+W, 0) - p.AddJumpTrueLabel(Jmp+Jeq+K, 10, "label_1", 0) - if _, err := p.Instructions(); err == nil { - t.Errorf("Instructions() should have failed") - } -} diff --git a/pkg/buffer/BUILD b/pkg/buffer/BUILD deleted file mode 100644 index 19cd28a32..000000000 --- a/pkg/buffer/BUILD +++ /dev/null @@ -1,46 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -package(licenses = ["notice"]) - -go_template_instance( - name = "buffer_list", - out = "buffer_list.go", - package = "buffer", - prefix = "buffer", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*buffer", - "Linker": "*buffer", - }, -) - -go_library( - name = "buffer", - srcs = [ - "buffer.go", - "buffer_list.go", - "pool.go", - "view.go", - "view_unsafe.go", - ], - visibility = ["//visibility:public"], - deps = [ - "//pkg/context", - "//pkg/log", - ], -) - -go_test( - name = "buffer_test", - size = "small", - srcs = [ - "buffer_test.go", - "pool_test.go", - "view_test.go", - ], - library = ":buffer", - deps = [ - "//pkg/state", - ], -) diff --git a/pkg/buffer/buffer_list.go b/pkg/buffer/buffer_list.go new file mode 100644 index 000000000..6b5bea3fc --- /dev/null +++ b/pkg/buffer/buffer_list.go @@ -0,0 +1,221 @@ +package buffer + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type bufferElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (bufferElementMapper) linkerFor(elem *buffer) *buffer { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type bufferList struct { + head *buffer + tail *buffer +} + +// Reset resets list l to the empty state. +func (l *bufferList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +// +//go:nosplit +func (l *bufferList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +// +//go:nosplit +func (l *bufferList) Front() *buffer { + return l.head +} + +// Back returns the last element of list l or nil. +// +//go:nosplit +func (l *bufferList) Back() *buffer { + return l.tail +} + +// Len returns the number of elements in the list. +// +// NOTE: This is an O(n) operation. +// +//go:nosplit +func (l *bufferList) Len() (count int) { + for e := l.Front(); e != nil; e = (bufferElementMapper{}.linkerFor(e)).Next() { + count++ + } + return count +} + +// PushFront inserts the element e at the front of list l. +// +//go:nosplit +func (l *bufferList) PushFront(e *buffer) { + linker := bufferElementMapper{}.linkerFor(e) + linker.SetNext(l.head) + linker.SetPrev(nil) + if l.head != nil { + bufferElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +// +//go:nosplit +func (l *bufferList) PushBack(e *buffer) { + linker := bufferElementMapper{}.linkerFor(e) + linker.SetNext(nil) + linker.SetPrev(l.tail) + if l.tail != nil { + bufferElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +// +//go:nosplit +func (l *bufferList) PushBackList(m *bufferList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + bufferElementMapper{}.linkerFor(l.tail).SetNext(m.head) + bufferElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +// +//go:nosplit +func (l *bufferList) InsertAfter(b, e *buffer) { + bLinker := bufferElementMapper{}.linkerFor(b) + eLinker := bufferElementMapper{}.linkerFor(e) + + a := bLinker.Next() + + eLinker.SetNext(a) + eLinker.SetPrev(b) + bLinker.SetNext(e) + + if a != nil { + bufferElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +// +//go:nosplit +func (l *bufferList) InsertBefore(a, e *buffer) { + aLinker := bufferElementMapper{}.linkerFor(a) + eLinker := bufferElementMapper{}.linkerFor(e) + + b := aLinker.Prev() + eLinker.SetNext(a) + eLinker.SetPrev(b) + aLinker.SetPrev(e) + + if b != nil { + bufferElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +// +//go:nosplit +func (l *bufferList) Remove(e *buffer) { + linker := bufferElementMapper{}.linkerFor(e) + prev := linker.Prev() + next := linker.Next() + + if prev != nil { + bufferElementMapper{}.linkerFor(prev).SetNext(next) + } else if l.head == e { + l.head = next + } + + if next != nil { + bufferElementMapper{}.linkerFor(next).SetPrev(prev) + } else if l.tail == e { + l.tail = prev + } + + linker.SetNext(nil) + linker.SetPrev(nil) +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type bufferEntry struct { + next *buffer + prev *buffer +} + +// Next returns the entry that follows e in the list. +// +//go:nosplit +func (e *bufferEntry) Next() *buffer { + return e.next +} + +// Prev returns the entry that precedes e in the list. +// +//go:nosplit +func (e *bufferEntry) Prev() *buffer { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +// +//go:nosplit +func (e *bufferEntry) SetNext(elem *buffer) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +// +//go:nosplit +func (e *bufferEntry) SetPrev(elem *buffer) { + e.prev = elem +} diff --git a/pkg/buffer/buffer_state_autogen.go b/pkg/buffer/buffer_state_autogen.go new file mode 100644 index 000000000..aaa72b723 --- /dev/null +++ b/pkg/buffer/buffer_state_autogen.go @@ -0,0 +1,163 @@ +// automatically generated by stateify. + +package buffer + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (b *buffer) StateTypeName() string { + return "pkg/buffer.buffer" +} + +func (b *buffer) StateFields() []string { + return []string{ + "data", + "read", + "write", + "bufferEntry", + } +} + +func (b *buffer) beforeSave() {} + +// +checklocksignore +func (b *buffer) StateSave(stateSinkObject state.Sink) { + b.beforeSave() + stateSinkObject.Save(0, &b.data) + stateSinkObject.Save(1, &b.read) + stateSinkObject.Save(2, &b.write) + stateSinkObject.Save(3, &b.bufferEntry) +} + +func (b *buffer) afterLoad() {} + +// +checklocksignore +func (b *buffer) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &b.data) + stateSourceObject.Load(1, &b.read) + stateSourceObject.Load(2, &b.write) + stateSourceObject.Load(3, &b.bufferEntry) +} + +func (l *bufferList) StateTypeName() string { + return "pkg/buffer.bufferList" +} + +func (l *bufferList) StateFields() []string { + return []string{ + "head", + "tail", + } +} + +func (l *bufferList) beforeSave() {} + +// +checklocksignore +func (l *bufferList) StateSave(stateSinkObject state.Sink) { + l.beforeSave() + stateSinkObject.Save(0, &l.head) + stateSinkObject.Save(1, &l.tail) +} + +func (l *bufferList) afterLoad() {} + +// +checklocksignore +func (l *bufferList) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &l.head) + stateSourceObject.Load(1, &l.tail) +} + +func (e *bufferEntry) StateTypeName() string { + return "pkg/buffer.bufferEntry" +} + +func (e *bufferEntry) StateFields() []string { + return []string{ + "next", + "prev", + } +} + +func (e *bufferEntry) beforeSave() {} + +// +checklocksignore +func (e *bufferEntry) StateSave(stateSinkObject state.Sink) { + e.beforeSave() + stateSinkObject.Save(0, &e.next) + stateSinkObject.Save(1, &e.prev) +} + +func (e *bufferEntry) afterLoad() {} + +// +checklocksignore +func (e *bufferEntry) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &e.next) + stateSourceObject.Load(1, &e.prev) +} + +func (p *pool) StateTypeName() string { + return "pkg/buffer.pool" +} + +func (p *pool) StateFields() []string { + return []string{ + "bufferSize", + "embeddedStorage", + } +} + +func (p *pool) beforeSave() {} + +// +checklocksignore +func (p *pool) StateSave(stateSinkObject state.Sink) { + p.beforeSave() + stateSinkObject.Save(0, &p.bufferSize) + stateSinkObject.Save(1, &p.embeddedStorage) +} + +// +checklocksignore +func (p *pool) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &p.bufferSize) + stateSourceObject.LoadWait(1, &p.embeddedStorage) + stateSourceObject.AfterLoad(p.afterLoad) +} + +func (v *View) StateTypeName() string { + return "pkg/buffer.View" +} + +func (v *View) StateFields() []string { + return []string{ + "data", + "size", + "pool", + } +} + +func (v *View) beforeSave() {} + +// +checklocksignore +func (v *View) StateSave(stateSinkObject state.Sink) { + v.beforeSave() + stateSinkObject.Save(0, &v.data) + stateSinkObject.Save(1, &v.size) + stateSinkObject.Save(2, &v.pool) +} + +func (v *View) afterLoad() {} + +// +checklocksignore +func (v *View) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &v.data) + stateSourceObject.Load(1, &v.size) + stateSourceObject.Load(2, &v.pool) +} + +func init() { + state.Register((*buffer)(nil)) + state.Register((*bufferList)(nil)) + state.Register((*bufferEntry)(nil)) + state.Register((*pool)(nil)) + state.Register((*View)(nil)) +} diff --git a/pkg/buffer/buffer_test.go b/pkg/buffer/buffer_test.go deleted file mode 100644 index 32db841e4..000000000 --- a/pkg/buffer/buffer_test.go +++ /dev/null @@ -1,111 +0,0 @@ -// Copyright 2021 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package buffer - -import ( - "bytes" - "testing" -) - -func TestBufferRemove(t *testing.T) { - sample := []byte("01234567") - - // Success cases - for _, tc := range []struct { - desc string - data []byte - rng Range - want []byte - }{ - { - desc: "empty slice", - }, - { - desc: "empty range", - data: sample, - want: sample, - }, - { - desc: "empty range with positive begin", - data: sample, - rng: Range{begin: 1, end: 1}, - want: sample, - }, - { - desc: "range at beginning", - data: sample, - rng: Range{begin: 0, end: 1}, - want: sample[1:], - }, - { - desc: "range in middle", - data: sample, - rng: Range{begin: 2, end: 4}, - want: []byte("014567"), - }, - { - desc: "range at end", - data: sample, - rng: Range{begin: 7, end: 8}, - want: sample[:7], - }, - { - desc: "range all", - data: sample, - rng: Range{begin: 0, end: 8}, - }, - } { - t.Run(tc.desc, func(t *testing.T) { - var buf buffer - buf.initWithData(tc.data) - if ok := buf.Remove(tc.rng); !ok { - t.Errorf("buf.Remove(%#v) = false, want true", tc.rng) - } else if got := buf.ReadSlice(); !bytes.Equal(got, tc.want) { - t.Errorf("buf.ReadSlice() = %q, want %q", got, tc.want) - } - }) - } - - // Failure cases - for _, tc := range []struct { - desc string - data []byte - rng Range - }{ - { - desc: "begin out-of-range", - data: sample, - rng: Range{begin: -1, end: 4}, - }, - { - desc: "end out-of-range", - data: sample, - rng: Range{begin: 4, end: 9}, - }, - { - desc: "both out-of-range", - data: sample, - rng: Range{begin: -100, end: 100}, - }, - } { - t.Run(tc.desc, func(t *testing.T) { - var buf buffer - buf.initWithData(tc.data) - if ok := buf.Remove(tc.rng); ok { - t.Errorf("buf.Remove(%#v) = true, want false", tc.rng) - } - }) - } -} diff --git a/pkg/buffer/buffer_unsafe_state_autogen.go b/pkg/buffer/buffer_unsafe_state_autogen.go new file mode 100644 index 000000000..5a5c40722 --- /dev/null +++ b/pkg/buffer/buffer_unsafe_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package buffer diff --git a/pkg/buffer/pool_test.go b/pkg/buffer/pool_test.go deleted file mode 100644 index 8584bac89..000000000 --- a/pkg/buffer/pool_test.go +++ /dev/null @@ -1,51 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package buffer - -import ( - "testing" -) - -func TestGetDefaultBufferSize(t *testing.T) { - var p pool - for i := 0; i < embeddedCount*2; i++ { - buf := p.get() - if got, want := len(buf.data), defaultBufferSize; got != want { - t.Errorf("#%d len(buf.data) = %d, want %d", i, got, want) - } - } -} - -func TestGetCustomBufferSize(t *testing.T) { - const size = 100 - - var p pool - p.setBufferSize(size) - for i := 0; i < embeddedCount*2; i++ { - buf := p.get() - if got, want := len(buf.data), size; got != want { - t.Errorf("#%d len(buf.data) = %d, want %d", i, got, want) - } - } -} - -func TestPut(t *testing.T) { - var p pool - buf := p.get() - p.put(buf) - if buf.data != nil { - t.Errorf("buf.data = %x, want nil", buf.data) - } -} diff --git a/pkg/buffer/view_test.go b/pkg/buffer/view_test.go deleted file mode 100644 index 796efa240..000000000 --- a/pkg/buffer/view_test.go +++ /dev/null @@ -1,900 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package buffer - -import ( - "bytes" - "context" - "fmt" - "io" - "reflect" - "strings" - "testing" - - "gvisor.dev/gvisor/pkg/state" -) - -const bufferSize = defaultBufferSize - -func fillAppend(v *View, data []byte) { - v.Append(data) -} - -func fillAppendEnd(v *View, data []byte) { - v.Grow(bufferSize-1, false) - v.Append(data) - v.TrimFront(bufferSize - 1) -} - -func fillWriteFromReader(v *View, data []byte) { - b := bytes.NewBuffer(data) - v.WriteFromReader(b, int64(len(data))) -} - -func fillWriteFromReaderEnd(v *View, data []byte) { - v.Grow(bufferSize-1, false) - b := bytes.NewBuffer(data) - v.WriteFromReader(b, int64(len(data))) - v.TrimFront(bufferSize - 1) -} - -var fillFuncs = map[string]func(*View, []byte){ - "append": fillAppend, - "appendEnd": fillAppendEnd, - "writeFromReader": fillWriteFromReader, - "writeFromReaderEnd": fillWriteFromReaderEnd, -} - -func BenchmarkReadAt(b *testing.B) { - b.ReportAllocs() - var v View - v.Append(make([]byte, 100)) - - buf := make([]byte, 10) - for i := 0; i < b.N; i++ { - v.ReadAt(buf, 0) - } -} - -func BenchmarkWriteRead(b *testing.B) { - b.ReportAllocs() - var v View - sz := 1000 - wbuf := make([]byte, sz) - rbuf := bytes.NewBuffer(make([]byte, sz)) - for i := 0; i < b.N; i++ { - v.Append(wbuf) - rbuf.Reset() - v.ReadToWriter(rbuf, int64(sz)) - } -} - -func testReadAt(t *testing.T, v *View, offset int64, n int, wantStr string, wantErr error) { - t.Helper() - d := make([]byte, n) - n, err := v.ReadAt(d, offset) - if n != len(wantStr) { - t.Errorf("got %d, want %d", n, len(wantStr)) - } - if err != wantErr { - t.Errorf("got err %v, want %v", err, wantErr) - } - if !bytes.Equal(d[:n], []byte(wantStr)) { - t.Errorf("got %q, want %q", string(d[:n]), wantStr) - } -} - -func TestView(t *testing.T) { - testCases := []struct { - name string - input string - output string - op func(*testing.T, *View) - }{ - // Preconditions. - { - name: "truncate-check", - input: "hello", - output: "hello", // Not touched. - op: func(t *testing.T, v *View) { - defer func() { - if r := recover(); r == nil { - t.Errorf("Truncate(-1) did not panic") - } - }() - v.Truncate(-1) - }, - }, - { - name: "grow-check", - input: "hello", - output: "hello", // Not touched. - op: func(t *testing.T, v *View) { - defer func() { - if r := recover(); r == nil { - t.Errorf("Grow(-1) did not panic") - } - }() - v.Grow(-1, false) - }, - }, - { - name: "advance-check", - input: "hello", - output: "", // Consumed. - op: func(t *testing.T, v *View) { - defer func() { - if r := recover(); r == nil { - t.Errorf("advanceRead(Size()+1) did not panic") - } - }() - v.advanceRead(v.Size() + 1) - }, - }, - - // Prepend. - { - name: "prepend", - input: "world", - output: "hello world", - op: func(t *testing.T, v *View) { - v.Prepend([]byte("hello ")) - }, - }, - { - name: "prepend-backfill-full", - input: "hello world", - output: "jello world", - op: func(t *testing.T, v *View) { - v.TrimFront(1) - v.Prepend([]byte("j")) - }, - }, - { - name: "prepend-backfill-under", - input: "hello world", - output: "hola world", - op: func(t *testing.T, v *View) { - v.TrimFront(5) - v.Prepend([]byte("hola")) - }, - }, - { - name: "prepend-backfill-over", - input: "hello world", - output: "smello world", - op: func(t *testing.T, v *View) { - v.TrimFront(1) - v.Prepend([]byte("sm")) - }, - }, - { - name: "prepend-fill", - input: strings.Repeat("1", bufferSize-1), - output: "0" + strings.Repeat("1", bufferSize-1), - op: func(t *testing.T, v *View) { - v.Prepend([]byte("0")) - }, - }, - { - name: "prepend-overflow", - input: strings.Repeat("1", bufferSize), - output: "0" + strings.Repeat("1", bufferSize), - op: func(t *testing.T, v *View) { - v.Prepend([]byte("0")) - }, - }, - { - name: "prepend-multiple-buffers", - input: strings.Repeat("1", bufferSize-1), - output: strings.Repeat("0", bufferSize*3) + strings.Repeat("1", bufferSize-1), - op: func(t *testing.T, v *View) { - v.Prepend([]byte(strings.Repeat("0", bufferSize*3))) - }, - }, - - // Append and write. - { - name: "append", - input: "hello", - output: "hello world", - op: func(t *testing.T, v *View) { - v.Append([]byte(" world")) - }, - }, - { - name: "append-fill", - input: strings.Repeat("1", bufferSize-1), - output: strings.Repeat("1", bufferSize-1) + "0", - op: func(t *testing.T, v *View) { - v.Append([]byte("0")) - }, - }, - { - name: "append-overflow", - input: strings.Repeat("1", bufferSize), - output: strings.Repeat("1", bufferSize) + "0", - op: func(t *testing.T, v *View) { - v.Append([]byte("0")) - }, - }, - { - name: "append-multiple-buffers", - input: strings.Repeat("1", bufferSize-1), - output: strings.Repeat("1", bufferSize-1) + strings.Repeat("0", bufferSize*3), - op: func(t *testing.T, v *View) { - v.Append([]byte(strings.Repeat("0", bufferSize*3))) - }, - }, - - // AppendOwned. - { - name: "append-owned", - input: "hello", - output: "hello world", - op: func(t *testing.T, v *View) { - b := []byte("Xworld") - v.AppendOwned(b) - b[0] = ' ' - }, - }, - - // Truncate. - { - name: "truncate", - input: "hello world", - output: "hello", - op: func(t *testing.T, v *View) { - v.Truncate(5) - }, - }, - { - name: "truncate-noop", - input: "hello world", - output: "hello world", - op: func(t *testing.T, v *View) { - v.Truncate(v.Size() + 1) - }, - }, - { - name: "truncate-multiple-buffers", - input: strings.Repeat("1", bufferSize*2), - output: strings.Repeat("1", bufferSize*2-1), - op: func(t *testing.T, v *View) { - v.Truncate(bufferSize*2 - 1) - }, - }, - { - name: "truncate-multiple-buffers-to-one", - input: strings.Repeat("1", bufferSize*2), - output: "11111", - op: func(t *testing.T, v *View) { - v.Truncate(5) - }, - }, - - // TrimFront. - { - name: "trim", - input: "hello world", - output: "world", - op: func(t *testing.T, v *View) { - v.TrimFront(6) - }, - }, - { - name: "trim-too-large", - input: "hello world", - output: "", - op: func(t *testing.T, v *View) { - v.TrimFront(v.Size() + 1) - }, - }, - { - name: "trim-multiple-buffers", - input: strings.Repeat("1", bufferSize*2), - output: strings.Repeat("1", bufferSize*2-1), - op: func(t *testing.T, v *View) { - v.TrimFront(1) - }, - }, - { - name: "trim-multiple-buffers-to-one-buffer", - input: strings.Repeat("1", bufferSize*2), - output: "1", - op: func(t *testing.T, v *View) { - v.TrimFront(bufferSize*2 - 1) - }, - }, - - // Grow. - { - name: "grow", - input: "hello world", - output: "hello world", - op: func(t *testing.T, v *View) { - v.Grow(1, true) - }, - }, - { - name: "grow-from-zero", - output: strings.Repeat("\x00", 1024), - op: func(t *testing.T, v *View) { - v.Grow(1024, true) - }, - }, - { - name: "grow-from-non-zero", - input: strings.Repeat("1", bufferSize), - output: strings.Repeat("1", bufferSize) + strings.Repeat("\x00", bufferSize), - op: func(t *testing.T, v *View) { - v.Grow(bufferSize*2, true) - }, - }, - - // Copy. - { - name: "copy", - input: "hello", - output: "hello", - op: func(t *testing.T, v *View) { - other := v.Copy() - bs := other.Flatten() - want := []byte("hello") - if !bytes.Equal(bs, want) { - t.Errorf("expected %v, got %v", want, bs) - } - }, - }, - { - name: "copy-large", - input: strings.Repeat("1", bufferSize+1), - output: strings.Repeat("1", bufferSize+1), - op: func(t *testing.T, v *View) { - other := v.Copy() - bs := other.Flatten() - want := []byte(strings.Repeat("1", bufferSize+1)) - if !bytes.Equal(bs, want) { - t.Errorf("expected %v, got %v", want, bs) - } - }, - }, - - // Merge. - { - name: "merge", - input: "hello", - output: "hello world", - op: func(t *testing.T, v *View) { - var other View - other.Append([]byte(" world")) - v.Merge(&other) - if sz := other.Size(); sz != 0 { - t.Errorf("expected 0, got %d", sz) - } - }, - }, - { - name: "merge-large", - input: strings.Repeat("1", bufferSize+1), - output: strings.Repeat("1", bufferSize+1) + strings.Repeat("0", bufferSize+1), - op: func(t *testing.T, v *View) { - var other View - other.Append([]byte(strings.Repeat("0", bufferSize+1))) - v.Merge(&other) - if sz := other.Size(); sz != 0 { - t.Errorf("expected 0, got %d", sz) - } - }, - }, - - // ReadAt. - { - name: "readat", - input: "hello", - output: "hello", - op: func(t *testing.T, v *View) { testReadAt(t, v, 0, 6, "hello", io.EOF) }, - }, - { - name: "readat-long", - input: "hello", - output: "hello", - op: func(t *testing.T, v *View) { testReadAt(t, v, 0, 8, "hello", io.EOF) }, - }, - { - name: "readat-short", - input: "hello", - output: "hello", - op: func(t *testing.T, v *View) { testReadAt(t, v, 0, 3, "hel", nil) }, - }, - { - name: "readat-offset", - input: "hello", - output: "hello", - op: func(t *testing.T, v *View) { testReadAt(t, v, 2, 3, "llo", io.EOF) }, - }, - { - name: "readat-long-offset", - input: "hello", - output: "hello", - op: func(t *testing.T, v *View) { testReadAt(t, v, 2, 8, "llo", io.EOF) }, - }, - { - name: "readat-short-offset", - input: "hello", - output: "hello", - op: func(t *testing.T, v *View) { testReadAt(t, v, 2, 2, "ll", nil) }, - }, - { - name: "readat-skip-all", - input: "hello", - output: "hello", - op: func(t *testing.T, v *View) { testReadAt(t, v, bufferSize+1, 1, "", io.EOF) }, - }, - { - name: "readat-second-buffer", - input: strings.Repeat("0", bufferSize+1) + "12", - output: strings.Repeat("0", bufferSize+1) + "12", - op: func(t *testing.T, v *View) { testReadAt(t, v, bufferSize+1, 1, "1", nil) }, - }, - { - name: "readat-second-buffer-end", - input: strings.Repeat("0", bufferSize+1) + "12", - output: strings.Repeat("0", bufferSize+1) + "12", - op: func(t *testing.T, v *View) { testReadAt(t, v, bufferSize+1, 2, "12", io.EOF) }, - }, - } - - for _, tc := range testCases { - for fillName, fn := range fillFuncs { - t.Run(fillName+"/"+tc.name, func(t *testing.T) { - // Construct & fill the view. - var view View - fn(&view, []byte(tc.input)) - - // Run the operation. - if tc.op != nil { - tc.op(t, &view) - } - - // Flatten and validate. - out := view.Flatten() - if !bytes.Equal([]byte(tc.output), out) { - t.Errorf("expected %q, got %q", tc.output, string(out)) - } - - // Ensure the size is correct. - if len(out) != int(view.Size()) { - t.Errorf("size is wrong: expected %d, got %d", len(out), view.Size()) - } - - // Calculate contents via apply. - var appliedOut []byte - view.Apply(func(b []byte) { - appliedOut = append(appliedOut, b...) - }) - if len(appliedOut) != len(out) { - t.Errorf("expected %d, got %d", len(out), len(appliedOut)) - } - if !bytes.Equal(appliedOut, out) { - t.Errorf("expected %v, got %v", out, appliedOut) - } - - // Calculate contents via ReadToWriter. - var b bytes.Buffer - n, err := view.ReadToWriter(&b, int64(len(out))) - if n != int64(len(out)) { - t.Errorf("expected %d, got %d", len(out), n) - } - if err != nil { - t.Errorf("expected nil, got %v", err) - } - if !bytes.Equal(b.Bytes(), out) { - t.Errorf("expected %v, got %v", out, b.Bytes()) - } - }) - } - } -} - -func TestViewPullUp(t *testing.T) { - for _, tc := range []struct { - desc string - inputs []string - offset int - length int - output string - failed bool - // lengths is the lengths of each buffer node after the pull up. - lengths []int - }{ - { - desc: "whole empty view", - }, - { - desc: "zero pull", - inputs: []string{"hello", " world"}, - lengths: []int{5, 6}, - }, - { - desc: "whole view", - inputs: []string{"hello", " world"}, - offset: 0, - length: 11, - output: "hello world", - lengths: []int{11}, - }, - { - desc: "middle to end aligned", - inputs: []string{"0123", "45678", "9abcd"}, - offset: 4, - length: 10, - output: "456789abcd", - lengths: []int{4, 10}, - }, - { - desc: "middle to end unaligned", - inputs: []string{"0123", "45678", "9abcd"}, - offset: 6, - length: 8, - output: "6789abcd", - lengths: []int{4, 10}, - }, - { - desc: "middle aligned", - inputs: []string{"0123", "45678", "9abcd", "efgh"}, - offset: 6, - length: 5, - output: "6789a", - lengths: []int{4, 10, 4}, - }, - - // Failed cases. - { - desc: "empty view - length too long", - offset: 0, - length: 1, - failed: true, - }, - { - desc: "empty view - offset too large", - offset: 1, - length: 1, - failed: true, - }, - { - desc: "length too long", - inputs: []string{"0123", "45678", "9abcd"}, - offset: 4, - length: 100, - failed: true, - lengths: []int{4, 5, 5}, - }, - { - desc: "offset too large", - inputs: []string{"0123", "45678", "9abcd"}, - offset: 100, - length: 1, - failed: true, - lengths: []int{4, 5, 5}, - }, - } { - t.Run(tc.desc, func(t *testing.T) { - var v View - for _, s := range tc.inputs { - v.AppendOwned([]byte(s)) - } - - got, gotOk := v.PullUp(tc.offset, tc.length) - want, wantOk := []byte(tc.output), !tc.failed - if gotOk != wantOk || !bytes.Equal(got, want) { - t.Errorf("v.PullUp(%d, %d) = %q, %t; %q, %t", tc.offset, tc.length, got, gotOk, want, wantOk) - } - - var gotLengths []int - for buf := v.data.Front(); buf != nil; buf = buf.Next() { - gotLengths = append(gotLengths, buf.ReadSize()) - } - if !reflect.DeepEqual(gotLengths, tc.lengths) { - t.Errorf("lengths = %v; want %v", gotLengths, tc.lengths) - } - }) - } -} - -func TestViewRemove(t *testing.T) { - // Success cases - for _, tc := range []struct { - desc string - // before is the contents for each buffer node initially. - before []string - // after is the contents for each buffer node after removal. - after []string - offset int - length int - }{ - { - desc: "empty view", - }, - { - desc: "nothing removed", - before: []string{"hello", " world"}, - after: []string{"hello", " world"}, - }, - { - desc: "whole view", - before: []string{"hello", " world"}, - offset: 0, - length: 11, - }, - { - desc: "beginning to middle aligned", - before: []string{"0123", "45678", "9abcd"}, - after: []string{"9abcd"}, - offset: 0, - length: 9, - }, - { - desc: "beginning to middle unaligned", - before: []string{"0123", "45678", "9abcd"}, - after: []string{"678", "9abcd"}, - offset: 0, - length: 6, - }, - { - desc: "middle to end aligned", - before: []string{"0123", "45678", "9abcd"}, - after: []string{"0123"}, - offset: 4, - length: 10, - }, - { - desc: "middle to end unaligned", - before: []string{"0123", "45678", "9abcd"}, - after: []string{"0123", "45"}, - offset: 6, - length: 8, - }, - { - desc: "middle aligned", - before: []string{"0123", "45678", "9abcd"}, - after: []string{"0123", "9abcd"}, - offset: 4, - length: 5, - }, - { - desc: "middle unaligned", - before: []string{"0123", "45678", "9abcd"}, - after: []string{"0123", "4578", "9abcd"}, - offset: 6, - length: 1, - }, - } { - t.Run(tc.desc, func(t *testing.T) { - var v View - for _, s := range tc.before { - v.AppendOwned([]byte(s)) - } - - if ok := v.Remove(tc.offset, tc.length); !ok { - t.Errorf("v.Remove(%d, %d) = false, want true", tc.offset, tc.length) - } - - var got []string - for buf := v.data.Front(); buf != nil; buf = buf.Next() { - got = append(got, string(buf.ReadSlice())) - } - if !reflect.DeepEqual(got, tc.after) { - t.Errorf("after = %v; want %v", got, tc.after) - } - }) - } - - // Failure cases - for _, tc := range []struct { - desc string - // before is the contents for each buffer node initially. - before []string - offset int - length int - }{ - { - desc: "offset out-of-range", - before: []string{"hello", " world"}, - offset: -1, - length: 3, - }, - { - desc: "length too long", - before: []string{"hello", " world"}, - offset: 0, - length: 12, - }, - { - desc: "length too long with positive offset", - before: []string{"hello", " world"}, - offset: 3, - length: 9, - }, - { - desc: "length negative", - before: []string{"hello", " world"}, - offset: 0, - length: -1, - }, - } { - t.Run(tc.desc, func(t *testing.T) { - var v View - for _, s := range tc.before { - v.AppendOwned([]byte(s)) - } - if ok := v.Remove(tc.offset, tc.length); ok { - t.Errorf("v.Remove(%d, %d) = true, want false", tc.offset, tc.length) - } - }) - } -} - -func TestViewSubApply(t *testing.T) { - var v View - v.AppendOwned([]byte("0123")) - v.AppendOwned([]byte("45678")) - v.AppendOwned([]byte("9abcd")) - - data := []byte("0123456789abcd") - - for i := 0; i <= len(data); i++ { - for j := i; j <= len(data); j++ { - t.Run(fmt.Sprintf("SubApply(%d,%d)", i, j), func(t *testing.T) { - var got []byte - v.SubApply(i, j-i, func(b []byte) { - got = append(got, b...) - }) - if want := data[i:j]; !bytes.Equal(got, want) { - t.Errorf("got = %q; want %q", got, want) - } - }) - } - } -} - -func doSaveAndLoad(t *testing.T, toSave, toLoad *View) { - t.Helper() - var buf bytes.Buffer - ctx := context.Background() - if _, err := state.Save(ctx, &buf, toSave); err != nil { - t.Fatal("state.Save:", err) - } - if _, err := state.Load(ctx, bytes.NewReader(buf.Bytes()), toLoad); err != nil { - t.Fatal("state.Load:", err) - } -} - -func TestSaveRestoreViewEmpty(t *testing.T) { - var toSave View - var v View - doSaveAndLoad(t, &toSave, &v) - - if got := v.pool.avail; got != nil { - t.Errorf("pool is not in zero state: v.pool.avail = %v, want nil", got) - } - if got := v.Flatten(); len(got) != 0 { - t.Errorf("v.Flatten() = %x, want []", got) - } -} - -func TestSaveRestoreView(t *testing.T) { - // Create data that fits 2.5 slots. - data := bytes.Join([][]byte{ - bytes.Repeat([]byte{1, 2}, defaultBufferSize), - bytes.Repeat([]byte{3}, defaultBufferSize/2), - }, nil) - - var toSave View - toSave.Append(data) - - var v View - doSaveAndLoad(t, &toSave, &v) - - // Next available slot at index 3; 0-2 slot are used. - i := 3 - if got, want := &v.pool.avail[0], &v.pool.embeddedStorage[i]; got != want { - t.Errorf("next available buffer points to %p, want %p (&v.pool.embeddedStorage[%d])", got, want, i) - } - if got := v.Flatten(); !bytes.Equal(got, data) { - t.Errorf("v.Flatten() = %x, want %x", got, data) - } -} - -func TestRangeIntersect(t *testing.T) { - for _, tc := range []struct { - desc string - x, y, want Range - }{ - { - desc: "empty intersects empty", - }, - { - desc: "empty intersection", - x: Range{end: 10}, - y: Range{begin: 10, end: 20}, - }, - { - desc: "some intersection", - x: Range{begin: 5, end: 20}, - y: Range{end: 10}, - want: Range{begin: 5, end: 10}, - }, - } { - t.Run(tc.desc, func(t *testing.T) { - if got := tc.x.Intersect(tc.y); got != tc.want { - t.Errorf("(%#v).Intersect(%#v) = %#v; want %#v", tc.x, tc.y, got, tc.want) - } - if got := tc.y.Intersect(tc.x); got != tc.want { - t.Errorf("(%#v).Intersect(%#v) = %#v; want %#v", tc.y, tc.x, got, tc.want) - } - }) - } -} - -func TestRangeOffset(t *testing.T) { - for _, tc := range []struct { - input Range - offset int - output Range - }{ - { - input: Range{}, - offset: 0, - output: Range{}, - }, - { - input: Range{}, - offset: -1, - output: Range{begin: -1, end: -1}, - }, - { - input: Range{begin: 10, end: 20}, - offset: -1, - output: Range{begin: 9, end: 19}, - }, - { - input: Range{begin: 10, end: 20}, - offset: 2, - output: Range{begin: 12, end: 22}, - }, - } { - if got := tc.input.Offset(tc.offset); got != tc.output { - t.Errorf("(%#v).Offset(%d) = %#v, want %#v", tc.input, tc.offset, got, tc.output) - } - } -} - -func TestRangeLen(t *testing.T) { - for _, tc := range []struct { - r Range - want int - }{ - {r: Range{}, want: 0}, - {r: Range{begin: 1, end: 1}, want: 0}, - {r: Range{begin: -1, end: -1}, want: 0}, - {r: Range{end: 10}, want: 10}, - {r: Range{begin: 5, end: 10}, want: 5}, - } { - if got := tc.r.Len(); got != tc.want { - t.Errorf("(%#v).Len() = %d, want %d", tc.r, got, tc.want) - } - } -} diff --git a/pkg/cleanup/BUILD b/pkg/cleanup/BUILD deleted file mode 100644 index 5c34b9872..000000000 --- a/pkg/cleanup/BUILD +++ /dev/null @@ -1,17 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "cleanup", - srcs = ["cleanup.go"], - visibility = ["//:sandbox"], - deps = [ - ], -) - -go_test( - name = "cleanup_test", - srcs = ["cleanup_test.go"], - library = ":cleanup", -) diff --git a/pkg/cleanup/cleanup_state_autogen.go b/pkg/cleanup/cleanup_state_autogen.go new file mode 100644 index 000000000..8373268fa --- /dev/null +++ b/pkg/cleanup/cleanup_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package cleanup diff --git a/pkg/cleanup/cleanup_test.go b/pkg/cleanup/cleanup_test.go deleted file mode 100644 index ab3d9ed95..000000000 --- a/pkg/cleanup/cleanup_test.go +++ /dev/null @@ -1,66 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package cleanup - -import "testing" - -func testCleanupHelper(clean, cleanAdd *bool, release bool) func() { - cu := Make(func() { - *clean = true - }) - cu.Add(func() { - *cleanAdd = true - }) - defer cu.Clean() - if release { - return cu.Release() - } - return nil -} - -func TestCleanup(t *testing.T) { - clean := false - cleanAdd := false - testCleanupHelper(&clean, &cleanAdd, false) - if !clean { - t.Fatalf("cleanup function was not called.") - } - if !cleanAdd { - t.Fatalf("added cleanup function was not called.") - } -} - -func TestRelease(t *testing.T) { - clean := false - cleanAdd := false - cleaner := testCleanupHelper(&clean, &cleanAdd, true) - - // Check that clean was not called after release. - if clean { - t.Fatalf("cleanup function was called.") - } - if cleanAdd { - t.Fatalf("added cleanup function was called.") - } - - // Call the cleaner function and check that both cleanup functions are called. - cleaner() - if !clean { - t.Fatalf("cleanup function was not called.") - } - if !cleanAdd { - t.Fatalf("added cleanup function was not called.") - } -} diff --git a/pkg/compressio/BUILD b/pkg/compressio/BUILD deleted file mode 100644 index 70018cf18..000000000 --- a/pkg/compressio/BUILD +++ /dev/null @@ -1,17 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "compressio", - srcs = ["compressio.go"], - visibility = ["//:sandbox"], - deps = ["//pkg/sync"], -) - -go_test( - name = "compressio_test", - size = "medium", - srcs = ["compressio_test.go"], - library = ":compressio", -) diff --git a/pkg/compressio/compressio_state_autogen.go b/pkg/compressio/compressio_state_autogen.go new file mode 100644 index 000000000..c47e0dd17 --- /dev/null +++ b/pkg/compressio/compressio_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package compressio diff --git a/pkg/compressio/compressio_test.go b/pkg/compressio/compressio_test.go deleted file mode 100644 index 86dc47e44..000000000 --- a/pkg/compressio/compressio_test.go +++ /dev/null @@ -1,290 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package compressio - -import ( - "bytes" - "compress/flate" - "encoding/base64" - "fmt" - "io" - "math/rand" - "runtime" - "testing" - "time" -) - -type harness interface { - Errorf(format string, v ...interface{}) - Fatalf(format string, v ...interface{}) - Logf(format string, v ...interface{}) -} - -func initTest(t harness, size int) []byte { - // Set number of processes to number of CPUs. - runtime.GOMAXPROCS(runtime.NumCPU()) - - // Construct synthetic data. We do this by encoding random data with - // base64. This gives a high level of entropy, but still quite a bit of - // structure, to give reasonable compression ratios (~75%). - var buf bytes.Buffer - bufW := base64.NewEncoder(base64.RawStdEncoding, &buf) - bufR := rand.New(rand.NewSource(0)) - if _, err := io.CopyN(bufW, bufR, int64(size)); err != nil { - t.Fatalf("unable to seed random data: %v", err) - } - return buf.Bytes() -} - -type testOpts struct { - Name string - Data []byte - NewWriter func(*bytes.Buffer) (io.Writer, error) - NewReader func(*bytes.Buffer) (io.Reader, error) - PreCompress func() - PostCompress func() - PreDecompress func() - PostDecompress func() - CompressIters int - DecompressIters int - CorruptData bool -} - -func doTest(t harness, opts testOpts) { - // Compress. - var compressed bytes.Buffer - compressionStartTime := time.Now() - if opts.PreCompress != nil { - opts.PreCompress() - } - if opts.CompressIters <= 0 { - opts.CompressIters = 1 - } - for i := 0; i < opts.CompressIters; i++ { - compressed.Reset() - w, err := opts.NewWriter(&compressed) - if err != nil { - t.Errorf("%s: NewWriter got err %v, expected nil", opts.Name, err) - } - if _, err := io.Copy(w, bytes.NewBuffer(opts.Data)); err != nil { - t.Errorf("%s: compress got err %v, expected nil", opts.Name, err) - return - } - closer, ok := w.(io.Closer) - if ok { - if err := closer.Close(); err != nil { - t.Errorf("%s: got err %v, expected nil", opts.Name, err) - return - } - } - } - if opts.PostCompress != nil { - opts.PostCompress() - } - compressionTime := time.Since(compressionStartTime) - compressionRatio := float32(compressed.Len()) / float32(len(opts.Data)) - - // Decompress. - var decompressed bytes.Buffer - decompressionStartTime := time.Now() - if opts.PreDecompress != nil { - opts.PreDecompress() - } - if opts.DecompressIters <= 0 { - opts.DecompressIters = 1 - } - if opts.CorruptData { - b := compressed.Bytes() - b[rand.Intn(len(b))]++ - } - for i := 0; i < opts.DecompressIters; i++ { - decompressed.Reset() - r, err := opts.NewReader(bytes.NewBuffer(compressed.Bytes())) - if err != nil { - if opts.CorruptData { - continue - } - t.Errorf("%s: NewReader got err %v, expected nil", opts.Name, err) - return - } - if _, err := io.Copy(&decompressed, r); (err != nil) != opts.CorruptData { - t.Errorf("%s: decompress got err %v unexpectly", opts.Name, err) - return - } - } - if opts.PostDecompress != nil { - opts.PostDecompress() - } - decompressionTime := time.Since(decompressionStartTime) - - if opts.CorruptData { - return - } - - // Verify. - if decompressed.Len() != len(opts.Data) { - t.Errorf("%s: got %d bytes, expected %d", opts.Name, decompressed.Len(), len(opts.Data)) - } - if !bytes.Equal(opts.Data, decompressed.Bytes()) { - t.Errorf("%s: got mismatch, expected match", opts.Name) - if len(opts.Data) < 32 { // Don't flood the logs. - t.Errorf("got %v, expected %v", decompressed.Bytes(), opts.Data) - } - } - - t.Logf("%s: compression time %v, ratio %2.2f, decompression time %v", - opts.Name, compressionTime, compressionRatio, decompressionTime) -} - -var hashKey = []byte("01234567890123456789012345678901") - -func TestCompress(t *testing.T) { - rand.Seed(time.Now().Unix()) - - var ( - data = initTest(t, 10*1024*1024) - data0 = data[:0] - data1 = data[:1] - data2 = data[:11] - data3 = data[:16] - data4 = data[:] - ) - - for _, data := range [][]byte{data0, data1, data2, data3, data4} { - for _, blockSize := range []uint32{1, 4, 1024, 4 * 1024, 16 * 1024} { - // Skip annoying tests; they just take too long. - if blockSize <= 16 && len(data) > 16 { - continue - } - - for _, key := range [][]byte{nil, hashKey} { - for _, corruptData := range []bool{false, true} { - if key == nil && corruptData { - // No need to test corrupt data - // case when not doing hashing. - continue - } - // Do the compress test. - doTest(t, testOpts{ - Name: fmt.Sprintf("len(data)=%d, blockSize=%d, key=%s, corruptData=%v", len(data), blockSize, string(key), corruptData), - Data: data, - NewWriter: func(b *bytes.Buffer) (io.Writer, error) { - return NewWriter(b, key, blockSize, flate.BestSpeed) - }, - NewReader: func(b *bytes.Buffer) (io.Reader, error) { - return NewReader(b, key) - }, - CorruptData: corruptData, - }) - } - } - } - - // Do the vanilla test. - doTest(t, testOpts{ - Name: fmt.Sprintf("len(data)=%d, vanilla flate", len(data)), - Data: data, - NewWriter: func(b *bytes.Buffer) (io.Writer, error) { - return flate.NewWriter(b, flate.BestSpeed) - }, - NewReader: func(b *bytes.Buffer) (io.Reader, error) { - return flate.NewReader(b), nil - }, - }) - } -} - -const ( - benchDataSize = 600 * 1024 * 1024 -) - -func benchmark(b *testing.B, compress bool, hash bool, blockSize uint32) { - b.StopTimer() - b.SetBytes(benchDataSize) - data := initTest(b, benchDataSize) - compIters := b.N - decompIters := b.N - if compress { - decompIters = 0 - } else { - compIters = 0 - } - key := hashKey - if !hash { - key = nil - } - doTest(b, testOpts{ - Name: fmt.Sprintf("compress=%t, hash=%t, len(data)=%d, blockSize=%d", compress, hash, len(data), blockSize), - Data: data, - PreCompress: b.StartTimer, - PostCompress: b.StopTimer, - NewWriter: func(b *bytes.Buffer) (io.Writer, error) { - return NewWriter(b, key, blockSize, flate.BestSpeed) - }, - NewReader: func(b *bytes.Buffer) (io.Reader, error) { - return NewReader(b, key) - }, - CompressIters: compIters, - DecompressIters: decompIters, - }) -} - -func BenchmarkCompressNoHash64K(b *testing.B) { - benchmark(b, true, false, 64*1024) -} - -func BenchmarkCompressHash64K(b *testing.B) { - benchmark(b, true, true, 64*1024) -} - -func BenchmarkDecompressNoHash64K(b *testing.B) { - benchmark(b, false, false, 64*1024) -} - -func BenchmarkDecompressHash64K(b *testing.B) { - benchmark(b, false, true, 64*1024) -} - -func BenchmarkCompressNoHash1M(b *testing.B) { - benchmark(b, true, false, 1024*1024) -} - -func BenchmarkCompressHash1M(b *testing.B) { - benchmark(b, true, true, 1024*1024) -} - -func BenchmarkDecompressNoHash1M(b *testing.B) { - benchmark(b, false, false, 1024*1024) -} - -func BenchmarkDecompressHash1M(b *testing.B) { - benchmark(b, false, true, 1024*1024) -} - -func BenchmarkCompressNoHash16M(b *testing.B) { - benchmark(b, true, false, 16*1024*1024) -} - -func BenchmarkCompressHash16M(b *testing.B) { - benchmark(b, true, true, 16*1024*1024) -} - -func BenchmarkDecompressNoHash16M(b *testing.B) { - benchmark(b, false, false, 16*1024*1024) -} - -func BenchmarkDecompressHash16M(b *testing.B) { - benchmark(b, false, true, 16*1024*1024) -} diff --git a/pkg/context/BUILD b/pkg/context/BUILD deleted file mode 100644 index f33e23bf7..000000000 --- a/pkg/context/BUILD +++ /dev/null @@ -1,12 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "context", - srcs = ["context.go"], - visibility = ["//:sandbox"], - deps = [ - "//pkg/log", - ], -) diff --git a/pkg/context/context_state_autogen.go b/pkg/context/context_state_autogen.go new file mode 100644 index 000000000..fdc3c9fbb --- /dev/null +++ b/pkg/context/context_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package context diff --git a/pkg/control/client/BUILD b/pkg/control/client/BUILD deleted file mode 100644 index 1b9e10ee7..000000000 --- a/pkg/control/client/BUILD +++ /dev/null @@ -1,15 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "client", - srcs = [ - "client.go", - ], - visibility = ["//:sandbox"], - deps = [ - "//pkg/unet", - "//pkg/urpc", - ], -) diff --git a/pkg/control/client/client_state_autogen.go b/pkg/control/client/client_state_autogen.go new file mode 100644 index 000000000..9872f1107 --- /dev/null +++ b/pkg/control/client/client_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package client diff --git a/pkg/control/server/BUILD b/pkg/control/server/BUILD deleted file mode 100644 index 002d2ef44..000000000 --- a/pkg/control/server/BUILD +++ /dev/null @@ -1,15 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "server", - srcs = ["server.go"], - visibility = ["//:sandbox"], - deps = [ - "//pkg/log", - "//pkg/sync", - "//pkg/unet", - "//pkg/urpc", - ], -) diff --git a/pkg/control/server/server_state_autogen.go b/pkg/control/server/server_state_autogen.go new file mode 100644 index 000000000..c236b8da5 --- /dev/null +++ b/pkg/control/server/server_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package server diff --git a/pkg/coverage/BUILD b/pkg/coverage/BUILD deleted file mode 100644 index ace5895f8..000000000 --- a/pkg/coverage/BUILD +++ /dev/null @@ -1,14 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "coverage", - srcs = ["coverage.go"], - visibility = ["//:sandbox"], - deps = [ - "//pkg/hostarch", - "//pkg/sync", - "@io_bazel_rules_go//go/tools/coverdata", - ], -) diff --git a/pkg/coverage/coverage_state_autogen.go b/pkg/coverage/coverage_state_autogen.go new file mode 100644 index 000000000..8ed6850af --- /dev/null +++ b/pkg/coverage/coverage_state_autogen.go @@ -0,0 +1,6 @@ +// automatically generated by stateify. + +//go:build go1.1 +// +build go1.1 + +package coverage diff --git a/pkg/cpuid/BUILD b/pkg/cpuid/BUILD deleted file mode 100644 index c0c864902..000000000 --- a/pkg/cpuid/BUILD +++ /dev/null @@ -1,36 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "cpuid", - srcs = [ - "cpu_amd64.s", - "cpuid.go", - "cpuid_arm64.go", - "cpuid_x86.go", - ], - visibility = ["//:sandbox"], - deps = ["//pkg/log"], -) - -go_test( - name = "cpuid_test", - size = "small", - srcs = [ - "cpuid_arm64_test.go", - "cpuid_x86_test.go", - ], - library = ":cpuid", -) - -go_test( - name = "cpuid_parse_test", - size = "small", - srcs = [ - "cpuid_parse_x86_test.go", - ], - library = ":cpuid", - tags = ["manual"], - deps = ["@org_golang_x_sys//unix:go_default_library"], -) diff --git a/pkg/cpuid/cpuid_arm64_state_autogen.go b/pkg/cpuid/cpuid_arm64_state_autogen.go new file mode 100644 index 000000000..eb91e6d97 --- /dev/null +++ b/pkg/cpuid/cpuid_arm64_state_autogen.go @@ -0,0 +1,54 @@ +// automatically generated by stateify. + +//go:build arm64 +// +build arm64 + +package cpuid + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (fs *FeatureSet) StateTypeName() string { + return "pkg/cpuid.FeatureSet" +} + +func (fs *FeatureSet) StateFields() []string { + return []string{ + "Set", + "CPUImplementer", + "CPUArchitecture", + "CPUVariant", + "CPUPartnum", + "CPURevision", + } +} + +func (fs *FeatureSet) beforeSave() {} + +// +checklocksignore +func (fs *FeatureSet) StateSave(stateSinkObject state.Sink) { + fs.beforeSave() + stateSinkObject.Save(0, &fs.Set) + stateSinkObject.Save(1, &fs.CPUImplementer) + stateSinkObject.Save(2, &fs.CPUArchitecture) + stateSinkObject.Save(3, &fs.CPUVariant) + stateSinkObject.Save(4, &fs.CPUPartnum) + stateSinkObject.Save(5, &fs.CPURevision) +} + +func (fs *FeatureSet) afterLoad() {} + +// +checklocksignore +func (fs *FeatureSet) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &fs.Set) + stateSourceObject.Load(1, &fs.CPUImplementer) + stateSourceObject.Load(2, &fs.CPUArchitecture) + stateSourceObject.Load(3, &fs.CPUVariant) + stateSourceObject.Load(4, &fs.CPUPartnum) + stateSourceObject.Load(5, &fs.CPURevision) +} + +func init() { + state.Register((*FeatureSet)(nil)) +} diff --git a/pkg/cpuid/cpuid_arm64_test.go b/pkg/cpuid/cpuid_arm64_test.go deleted file mode 100644 index 16b1c064a..000000000 --- a/pkg/cpuid/cpuid_arm64_test.go +++ /dev/null @@ -1,56 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//go:build arm64 -// +build arm64 - -package cpuid - -import ( - "testing" -) - -var justFP = &FeatureSet{ - Set: map[Feature]bool{ - ARM64FeatureFP: true, - }} - -func TestHostFeatureSet(t *testing.T) { - hostFeatures := HostFeatureSet() - if len(hostFeatures.Set) == 0 { - t.Errorf("Got invalid feature set %v from HostFeatureSet()", hostFeatures) - } -} - -func TestHasFeature(t *testing.T) { - if !justFP.HasFeature(ARM64FeatureFP) { - t.Errorf("HasFeature failed, %v should contain %v", justFP, ARM64FeatureFP) - } - - if justFP.HasFeature(ARM64FeatureSM3) { - t.Errorf("HasFeature failed, %v should not contain %v", justFP, ARM64FeatureSM3) - } -} - -func TestFeatureFromString(t *testing.T) { - f, ok := FeatureFromString("asimd") - if f != ARM64FeatureASIMD || !ok { - t.Errorf("got %v want asimd", f) - } - - f, ok = FeatureFromString("bad") - if ok { - t.Errorf("got %v want nothing", f) - } -} diff --git a/pkg/cpuid/cpuid_parse_x86_test.go b/pkg/cpuid/cpuid_parse_x86_test.go deleted file mode 100644 index 36dd20552..000000000 --- a/pkg/cpuid/cpuid_parse_x86_test.go +++ /dev/null @@ -1,146 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//go:build 386 || amd64 -// +build 386 amd64 - -package cpuid - -import ( - "fmt" - "io/ioutil" - "regexp" - "strconv" - "strings" - "testing" - - "golang.org/x/sys/unix" -) - -func kernelVersion() (int, int, error) { - var u unix.Utsname - if err := unix.Uname(&u); err != nil { - return 0, 0, err - } - - var sb strings.Builder - for _, b := range u.Release { - if b == 0 { - break - } - sb.WriteByte(byte(b)) - } - - s := strings.Split(sb.String(), ".") - if len(s) < 2 { - return 0, 0, fmt.Errorf("kernel release missing major and minor component: %s", sb.String()) - } - - major, err := strconv.Atoi(s[0]) - if err != nil { - return 0, 0, fmt.Errorf("error parsing major version %q in %q: %w", s[0], sb.String(), err) - } - - minor, err := strconv.Atoi(s[1]) - if err != nil { - return 0, 0, fmt.Errorf("error parsing minor version %q in %q: %w", s[1], sb.String(), err) - } - - return major, minor, nil -} - -// TestHostFeatureFlags tests that all features detected by HostFeatureSet are -// on the host. -// -// It does *not* verify that all features reported by the host are detected by -// HostFeatureSet. -// -// i.e., test that HostFeatureSet is a subset of the host features. -func TestHostFeatureFlags(t *testing.T) { - cpuinfoBytes, _ := ioutil.ReadFile("/proc/cpuinfo") - cpuinfo := string(cpuinfoBytes) - t.Logf("Host cpu info:\n%s", cpuinfo) - - major, minor, err := kernelVersion() - if err != nil { - t.Fatalf("Unable to parse kernel version: %v", err) - } - - re := regexp.MustCompile(`(?m)^flags\s+: (.*)$`) - m := re.FindStringSubmatch(cpuinfo) - if len(m) != 2 { - t.Fatalf("Unable to extract flags from %q", cpuinfo) - } - - cpuinfoFlags := make(map[string]struct{}) - for _, f := range strings.Split(m[1], " ") { - cpuinfoFlags[f] = struct{}{} - } - - fs := HostFeatureSet() - - // All features have a string and appear in host cpuinfo. - for f := range fs.Set { - name := f.flagString(false) - if name == "" { - t.Errorf("Non-parsable feature: %v", f) - } - - // Special cases not consistently visible. We don't mind if - // they are exposed in earlier versions. - switch { - // Block 0. - case f == X86FeatureSDBG && (major < 4 || major == 4 && minor < 3): - // SDBG only exposed in - // b1c599b8ff80ea79b9f8277a3f9f36a7b0cfedce (4.3). - continue - // Block 2. - case f == X86FeatureRDT && (major < 4 || major == 4 && minor < 10): - // RDT only exposed in - // 4ab1586488cb56ed8728e54c4157cc38646874d9 (4.10). - continue - // Block 3. - case f == X86FeatureAVX512VBMI && (major < 4 || major == 4 && minor < 10): - // AVX512VBMI only exposed in - // a8d9df5a509a232a959e4ef2e281f7ecd77810d6 (4.10). - continue - case f == X86FeatureUMIP && (major < 4 || major == 4 && minor < 15): - // UMIP only exposed in - // 3522c2a6a4f341058b8291326a945e2a2d2aaf55 (4.15). - continue - case f == X86FeaturePKU && (major < 4 || major == 4 && minor < 9): - // PKU only exposed in - // dfb4a70f20c5b3880da56ee4c9484bdb4e8f1e65 (4.9). - continue - // Block 4. - case f == X86FeatureXSAVES && (major < 4 || major == 4 && minor < 8): - // XSAVES only exposed in - // b8be15d588060a03569ac85dc4a0247460988f5b (4.8). - continue - // Block 5. - case f == X86FeaturePERFCTR_LLC && (major < 4 || major == 4 && minor < 14): - // PERFCTR_LLC renamed in - // 910448bbed066ab1082b510eef1ae61bb792d854 (4.14). - continue - } - - hidden := f.flagString(true) == "" - _, ok := cpuinfoFlags[name] - if hidden && ok { - t.Errorf("Unexpectedly hidden flag: %v", f) - } else if !hidden && !ok { - t.Errorf("Non-native flag: %v", f) - } - } -} diff --git a/pkg/cpuid/cpuid_state_autogen.go b/pkg/cpuid/cpuid_state_autogen.go new file mode 100644 index 000000000..86206a6bf --- /dev/null +++ b/pkg/cpuid/cpuid_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package cpuid diff --git a/pkg/cpuid/cpuid_x86_state_autogen.go b/pkg/cpuid/cpuid_x86_state_autogen.go new file mode 100644 index 000000000..ec639ae38 --- /dev/null +++ b/pkg/cpuid/cpuid_x86_state_autogen.go @@ -0,0 +1,116 @@ +// automatically generated by stateify. + +//go:build 386 || amd64 +// +build 386 amd64 + +package cpuid + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (c *Cache) StateTypeName() string { + return "pkg/cpuid.Cache" +} + +func (c *Cache) StateFields() []string { + return []string{ + "Level", + "Type", + "FullyAssociative", + "Partitions", + "Ways", + "Sets", + "InvalidateHierarchical", + "Inclusive", + "DirectMapped", + } +} + +func (c *Cache) beforeSave() {} + +// +checklocksignore +func (c *Cache) StateSave(stateSinkObject state.Sink) { + c.beforeSave() + stateSinkObject.Save(0, &c.Level) + stateSinkObject.Save(1, &c.Type) + stateSinkObject.Save(2, &c.FullyAssociative) + stateSinkObject.Save(3, &c.Partitions) + stateSinkObject.Save(4, &c.Ways) + stateSinkObject.Save(5, &c.Sets) + stateSinkObject.Save(6, &c.InvalidateHierarchical) + stateSinkObject.Save(7, &c.Inclusive) + stateSinkObject.Save(8, &c.DirectMapped) +} + +func (c *Cache) afterLoad() {} + +// +checklocksignore +func (c *Cache) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &c.Level) + stateSourceObject.Load(1, &c.Type) + stateSourceObject.Load(2, &c.FullyAssociative) + stateSourceObject.Load(3, &c.Partitions) + stateSourceObject.Load(4, &c.Ways) + stateSourceObject.Load(5, &c.Sets) + stateSourceObject.Load(6, &c.InvalidateHierarchical) + stateSourceObject.Load(7, &c.Inclusive) + stateSourceObject.Load(8, &c.DirectMapped) +} + +func (fs *FeatureSet) StateTypeName() string { + return "pkg/cpuid.FeatureSet" +} + +func (fs *FeatureSet) StateFields() []string { + return []string{ + "Set", + "VendorID", + "ExtendedFamily", + "ExtendedModel", + "ProcessorType", + "Family", + "Model", + "SteppingID", + "Caches", + "CacheLine", + } +} + +func (fs *FeatureSet) beforeSave() {} + +// +checklocksignore +func (fs *FeatureSet) StateSave(stateSinkObject state.Sink) { + fs.beforeSave() + stateSinkObject.Save(0, &fs.Set) + stateSinkObject.Save(1, &fs.VendorID) + stateSinkObject.Save(2, &fs.ExtendedFamily) + stateSinkObject.Save(3, &fs.ExtendedModel) + stateSinkObject.Save(4, &fs.ProcessorType) + stateSinkObject.Save(5, &fs.Family) + stateSinkObject.Save(6, &fs.Model) + stateSinkObject.Save(7, &fs.SteppingID) + stateSinkObject.Save(8, &fs.Caches) + stateSinkObject.Save(9, &fs.CacheLine) +} + +func (fs *FeatureSet) afterLoad() {} + +// +checklocksignore +func (fs *FeatureSet) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &fs.Set) + stateSourceObject.Load(1, &fs.VendorID) + stateSourceObject.Load(2, &fs.ExtendedFamily) + stateSourceObject.Load(3, &fs.ExtendedModel) + stateSourceObject.Load(4, &fs.ProcessorType) + stateSourceObject.Load(5, &fs.Family) + stateSourceObject.Load(6, &fs.Model) + stateSourceObject.Load(7, &fs.SteppingID) + stateSourceObject.Load(8, &fs.Caches) + stateSourceObject.Load(9, &fs.CacheLine) +} + +func init() { + state.Register((*Cache)(nil)) + state.Register((*FeatureSet)(nil)) +} diff --git a/pkg/cpuid/cpuid_x86_test.go b/pkg/cpuid/cpuid_x86_test.go deleted file mode 100644 index 92a2d9f81..000000000 --- a/pkg/cpuid/cpuid_x86_test.go +++ /dev/null @@ -1,244 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//go:build 386 || amd64 -// +build 386 amd64 - -package cpuid - -import ( - "testing" -) - -// These are the default values of various FeatureSet fields. -const ( - defaultVendorID = "GenuineIntel" - - // These processor signature defaults are derived from the values - // listed in Intel Application Note 485 for i7/Xeon processors. - defaultExtFamily uint8 = 0 - defaultExtModel uint8 = 1 - defaultType uint8 = 0 - defaultFamily uint8 = 0x06 - defaultModel uint8 = 0x0a - defaultSteppingID uint8 = 0 -) - -// newEmptyFeatureSet creates a new FeatureSet with a sensible default model and no features. -func newEmptyFeatureSet() *FeatureSet { - return &FeatureSet{ - Set: make(map[Feature]bool), - VendorID: defaultVendorID, - ExtendedFamily: defaultExtFamily, - ExtendedModel: defaultExtModel, - ProcessorType: defaultType, - Family: defaultFamily, - Model: defaultModel, - SteppingID: defaultSteppingID, - } -} - -var justFPU = &FeatureSet{ - Set: map[Feature]bool{ - X86FeatureFPU: true, - }} - -var justFPUandPAE = &FeatureSet{ - Set: map[Feature]bool{ - X86FeatureFPU: true, - X86FeaturePAE: true, - }} - -func TestSubtract(t *testing.T) { - if diff := justFPU.Subtract(justFPUandPAE); diff != nil { - t.Errorf("Got %v is not subset of %v, want diff (%v) to be nil", justFPU, justFPUandPAE, diff) - } - - if justFPUandPAE.Subtract(justFPU) == nil { - t.Errorf("Got %v is a subset of %v, want diff to be nil", justFPU, justFPUandPAE) - } -} - -// TODO(b/73346484): Run this test on a very old platform, and make sure more -// bits are enabled than just FPU and PAE. This test currently may not detect -// if HostFeatureSet gives back junk bits. -func TestHostFeatureSet(t *testing.T) { - hostFeatures := HostFeatureSet() - if justFPUandPAE.Subtract(hostFeatures) != nil { - t.Errorf("Got invalid feature set %v from HostFeatureSet()", hostFeatures) - } -} - -func TestHasFeature(t *testing.T) { - if !justFPU.HasFeature(X86FeatureFPU) { - t.Errorf("HasFeature failed, %v should contain %v", justFPU, X86FeatureFPU) - } - - if justFPU.HasFeature(X86FeatureAVX) { - t.Errorf("HasFeature failed, %v should not contain %v", justFPU, X86FeatureAVX) - } -} - -// Note: these tests are aware of and abuse internal details of FeatureSets. -// Users of FeatureSets should not depend on this. -func TestAdd(t *testing.T) { - // Test a basic insertion into the FeatureSet. - testFeatures := newEmptyFeatureSet() - testFeatures.Add(X86FeatureCLFSH) - if len(testFeatures.Set) != 1 { - t.Errorf("Got length %v want 1", len(testFeatures.Set)) - } - - if !testFeatures.HasFeature(X86FeatureCLFSH) { - t.Errorf("Add failed, got %v want set with %v", testFeatures, X86FeatureCLFSH) - } - - // Test that duplicates are ignored. - testFeatures.Add(X86FeatureCLFSH) - if len(testFeatures.Set) != 1 { - t.Errorf("Got length %v, want 1", len(testFeatures.Set)) - } -} - -func TestRemove(t *testing.T) { - // Try removing the last feature. - testFeatures := newEmptyFeatureSet() - testFeatures.Add(X86FeatureFPU) - testFeatures.Add(X86FeaturePAE) - testFeatures.Remove(X86FeaturePAE) - if !testFeatures.HasFeature(X86FeatureFPU) || len(testFeatures.Set) != 1 || testFeatures.HasFeature(X86FeaturePAE) { - t.Errorf("Remove failed, got %v want %v", testFeatures, justFPU) - } - - // Try removing a feature not in the set. - testFeatures.Remove(X86FeatureRDRAND) - if !testFeatures.HasFeature(X86FeatureFPU) || len(testFeatures.Set) != 1 { - t.Errorf("Remove failed, got %v want %v", testFeatures, justFPU) - } -} - -func TestFeatureFromString(t *testing.T) { - f, ok := FeatureFromString("avx") - if f != X86FeatureAVX || !ok { - t.Errorf("got %v want avx", f) - } - - f, ok = FeatureFromString("bad") - if ok { - t.Errorf("got %v want nothing", f) - } -} - -// This tests function 0 (eax=0), which returns the vendor ID and highest cpuid -// function reported to be available. -func TestEmulateIDVendorAndLength(t *testing.T) { - testFeatures := newEmptyFeatureSet() - - ax, bx, cx, dx := testFeatures.EmulateID(0, 0) - wantEax := uint32(0xd) // Highest supported cpuid function. - - // These magical constants are the characters of "GenuineIntel". - // See Intel AN485 for a reference on why they are laid out like this. - wantEbx := uint32(0x756e6547) - wantEcx := uint32(0x6c65746e) - wantEdx := uint32(0x49656e69) - if wantEax != ax { - t.Errorf("highest function failed, got %x want %x", ax, wantEax) - } - - if wantEbx != bx || wantEcx != cx || wantEdx != dx { - t.Errorf("vendor string emulation failed, bx:cx:dx, got %x:%x:%x want %x:%x:%x", bx, cx, dx, wantEbx, wantEcx, wantEdx) - } -} - -func TestEmulateIDBasicFeatures(t *testing.T) { - // Make a minimal test feature set. - testFeatures := newEmptyFeatureSet() - testFeatures.Add(X86FeatureCLFSH) - testFeatures.Add(X86FeatureAVX) - testFeatures.CacheLine = 64 - - ax, bx, cx, dx := testFeatures.EmulateID(1, 0) - ECXAVXBit := uint32(1 << uint(X86FeatureAVX)) - EDXCLFlushBit := uint32(1 << uint(X86FeatureCLFSH-32)) // We adjust by 32 since it's in block 1. - - if EDXCLFlushBit&dx == 0 || dx&^EDXCLFlushBit != 0 { - t.Errorf("EmulateID failed, got feature bits %x want %x", dx, testFeatures.blockMask(1)) - } - - if ECXAVXBit&cx == 0 || cx&^ECXAVXBit != 0 { - t.Errorf("EmulateID failed, got feature bits %x want %x", cx, testFeatures.blockMask(0)) - } - - // Default signature bits, based on values for i7/Xeon. - // See Intel AN485 for information on stepping/model bits. - defaultSignature := uint32(0x000106a0) - if defaultSignature != ax { - t.Errorf("EmulateID stepping emulation failed, got %x want %x", ax, defaultSignature) - } - - clflushSizeInfo := uint32(8 << 8) - if clflushSizeInfo != bx { - t.Errorf("EmulateID bx emulation failed, got %x want %x", bx, clflushSizeInfo) - } -} - -func TestEmulateIDExtendedFeatures(t *testing.T) { - // Make a minimal test feature set, one bit in each extended feature word. - testFeatures := newEmptyFeatureSet() - testFeatures.Add(X86FeatureSMEP) - testFeatures.Add(X86FeatureAVX512VBMI) - - ax, bx, cx, dx := testFeatures.EmulateID(7, 0) - EBXSMEPBit := uint32(1 << uint(X86FeatureSMEP-2*32)) // Adjust by 2*32 since SMEP is a block 2 feature. - ECXAVXBit := uint32(1 << uint(X86FeatureAVX512VBMI-3*32)) // We adjust by 3*32 since it's a block 3 feature. - - // Test that the desired bit is set and no other bits are set. - if EBXSMEPBit&bx == 0 || bx&^EBXSMEPBit != 0 { - t.Errorf("extended feature emulation failed, got feature bits %x want %x", bx, testFeatures.blockMask(2)) - } - - if ECXAVXBit&cx == 0 || cx&^ECXAVXBit != 0 { - t.Errorf("extended feature emulation failed, got feature bits %x want %x", cx, testFeatures.blockMask(3)) - } - - if ax != 0 || dx != 0 { - t.Errorf("extended feature emulation failed, ax:dx, got %x:%x want 0:0", ax, dx) - } - - // Check that no subleaves other than 0 do anything. - ax, bx, cx, dx = testFeatures.EmulateID(7, 1) - if ax != 0 || bx != 0 || cx != 0 || dx != 0 { - t.Errorf("extended feature emulation failed, got %x:%x:%x:%x want 0:0", ax, bx, cx, dx) - } - -} - -// Checks that the expected extended features are available via cpuid functions -// 0x80000000 and up. -func TestEmulateIDExtended(t *testing.T) { - testFeatures := newEmptyFeatureSet() - testFeatures.Add(X86FeatureSYSCALL) - EDXSYSCALLBit := uint32(1 << uint(X86FeatureSYSCALL-6*32)) // Adjust by 6*32 since SYSCALL is a block 6 feature. - - ax, bx, cx, dx := testFeatures.EmulateID(0x80000000, 0) - if ax != 0x80000001 || bx != 0 || cx != 0 || dx != 0 { - t.Errorf("EmulateID extended emulation failed, ax:bx:cx:dx, got %x:%x:%x:%x want 0x80000001:0:0:0", ax, bx, cx, dx) - } - - _, _, _, dx = testFeatures.EmulateID(0x80000001, 0) - if EDXSYSCALLBit&dx == 0 || dx&^EDXSYSCALLBit != 0 { - t.Errorf("extended feature emulation failed, got feature bits %x want %x", dx, testFeatures.blockMask(6)) - } -} diff --git a/pkg/crypto/BUILD b/pkg/crypto/BUILD deleted file mode 100644 index 08fa772ca..000000000 --- a/pkg/crypto/BUILD +++ /dev/null @@ -1,12 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "crypto", - srcs = [ - "crypto.go", - "crypto_stdlib.go", - ], - visibility = ["//:sandbox"], -) diff --git a/pkg/crypto/crypto.go b/pkg/crypto/crypto.go deleted file mode 100644 index b26b55d37..000000000 --- a/pkg/crypto/crypto.go +++ /dev/null @@ -1,16 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package crypto wraps crypto primitives. -package crypto diff --git a/pkg/crypto/crypto_stdlib.go b/pkg/crypto/crypto_stdlib.go deleted file mode 100644 index 69e867386..000000000 --- a/pkg/crypto/crypto_stdlib.go +++ /dev/null @@ -1,35 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//go:build go1.1 -// +build go1.1 - -package crypto - -import ( - "crypto/ecdsa" - "crypto/sha512" - "math/big" -) - -// EcdsaVerify verifies the signature in r, s of hash using ECDSA and the -// public key, pub. Its return value records whether the signature is valid. -func EcdsaVerify(pub *ecdsa.PublicKey, hash []byte, r, s *big.Int) (bool, error) { - return ecdsa.Verify(pub, hash, r, s), nil -} - -// SumSha384 returns the SHA384 checksum of the data. -func SumSha384(data []byte) ([sha512.Size384]byte, error) { - return sha512.Sum384(data), nil -} diff --git a/pkg/errors/BUILD b/pkg/errors/BUILD deleted file mode 100644 index 36ea5da0c..000000000 --- a/pkg/errors/BUILD +++ /dev/null @@ -1,10 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "errors", - srcs = ["errors.go"], - visibility = ["//:sandbox"], - deps = ["//pkg/abi/linux/errno"], -) diff --git a/pkg/errors/errors_state_autogen.go b/pkg/errors/errors_state_autogen.go new file mode 100644 index 000000000..fd4538438 --- /dev/null +++ b/pkg/errors/errors_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package errors diff --git a/pkg/errors/linuxerr/BUILD b/pkg/errors/linuxerr/BUILD deleted file mode 100644 index e73b0e28a..000000000 --- a/pkg/errors/linuxerr/BUILD +++ /dev/null @@ -1,28 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "linuxerr", - srcs = [ - "internal.go", - "linuxerr.go", - ], - visibility = ["//visibility:public"], - deps = [ - "//pkg/abi/linux/errno", - "//pkg/errors", - "@org_golang_x_sys//unix:go_default_library", - ], -) - -go_test( - name = "linuxerr_test", - srcs = ["linuxerr_test.go"], - deps = [ - ":linuxerr", - "//pkg/abi/linux/errno", - "//pkg/errors", - "@org_golang_x_sys//unix:go_default_library", - ], -) diff --git a/pkg/errors/linuxerr/linuxerr_state_autogen.go b/pkg/errors/linuxerr/linuxerr_state_autogen.go new file mode 100644 index 000000000..8106a9d9f --- /dev/null +++ b/pkg/errors/linuxerr/linuxerr_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package linuxerr diff --git a/pkg/errors/linuxerr/linuxerr_test.go b/pkg/errors/linuxerr/linuxerr_test.go deleted file mode 100644 index df7cd1c5a..000000000 --- a/pkg/errors/linuxerr/linuxerr_test.go +++ /dev/null @@ -1,268 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package linuxerr_test - -import ( - "errors" - "fmt" - "io" - "io/fs" - "syscall" - "testing" - - "golang.org/x/sys/unix" - "gvisor.dev/gvisor/pkg/abi/linux/errno" - gErrors "gvisor.dev/gvisor/pkg/errors" - "gvisor.dev/gvisor/pkg/errors/linuxerr" -) - -var globalError error - -func BenchmarkAssignUnix(b *testing.B) { - for i := b.N; i > 0; i-- { - globalError = unix.EINVAL - } -} - -func BenchmarkAssignLinuxerr(b *testing.B) { - for i := b.N; i > 0; i-- { - globalError = linuxerr.EINVAL - } -} - -func BenchmarkCompareUnix(b *testing.B) { - globalError = unix.EAGAIN - j := 0 - for i := b.N; i > 0; i-- { - if globalError == unix.EINVAL { - j++ - } - } -} - -func BenchmarkCompareLinuxerr(b *testing.B) { - globalError = linuxerr.E2BIG - j := 0 - for i := b.N; i > 0; i-- { - if globalError == linuxerr.EINVAL { - j++ - } - } -} - -func BenchmarkSwitchUnix(b *testing.B) { - globalError = unix.EPERM - j := 0 - for i := b.N; i > 0; i-- { - switch globalError { - case unix.EINVAL: - j++ - case unix.EINTR: - j += 2 - case unix.EAGAIN: - j += 3 - } - } -} - -func BenchmarkSwitchLinuxerr(b *testing.B) { - globalError = linuxerr.EPERM - j := 0 - for i := b.N; i > 0; i-- { - switch globalError { - case linuxerr.EINVAL: - j++ - case linuxerr.EINTR: - j += 2 - case linuxerr.EAGAIN: - j += 3 - } - } -} - -func BenchmarkReturnUnix(b *testing.B) { - var localError error - f := func() error { - return unix.EINVAL - } - for i := b.N; i > 0; i-- { - localError = f() - } - if localError != nil { - return - } -} - -func BenchmarkReturnLinuxerr(b *testing.B) { - var localError error - f := func() error { - return linuxerr.EINVAL - } - for i := b.N; i > 0; i-- { - localError = f() - } - if localError != nil { - return - } -} - -func BenchmarkConvertUnixLinuxerr(b *testing.B) { - var localError error - for i := b.N; i > 0; i-- { - localError = linuxerr.ErrorFromErrno(errno.Errno(unix.EINVAL)) - } - if localError != nil { - return - } -} - -func BenchmarkConvertUnixLinuxerrZero(b *testing.B) { - var localError error - for i := b.N; i > 0; i-- { - localError = linuxerr.ErrorFromErrno(errno.Errno(0)) - } - if localError != nil { - return - } -} - -type translationTestTable struct { - errIn error - expectedBool bool - expectedTranslation *gErrors.Error -} - -func TestErrorTranslation(t *testing.T) { - testTable := []translationTestTable{ - { - errIn: linuxerr.ENOENT, - }, - { - errIn: unix.ENOENT, - }, - { - errIn: linuxerr.ErrInterrupted, - expectedBool: true, - expectedTranslation: linuxerr.EINTR, - }, - { - errIn: linuxerr.ERESTART_RESTARTBLOCK, - }, - { - errIn: errors.New("some new error"), - }, - } - for _, tt := range testTable { - t.Run(fmt.Sprintf("err: %v %T", tt.errIn, tt.errIn), func(t *testing.T) { - err, ok := linuxerr.TranslateError(tt.errIn) - if (!tt.expectedBool && err != nil) || (tt.expectedBool != ok) { - t.Fatalf("%v => %v %v expected %v err: nil", tt.errIn, err, ok, tt.expectedBool) - } else if err != tt.expectedTranslation { - t.Fatalf("%v => %v expected %v", tt.errIn, err, tt.expectedTranslation) - } - }) - } -} - -func TestSyscallErrnoToErrors(t *testing.T) { - for _, tc := range []struct { - errno syscall.Errno - err *gErrors.Error - }{ - {errno: syscall.EACCES, err: linuxerr.EACCES}, - {errno: syscall.EAGAIN, err: linuxerr.EAGAIN}, - {errno: syscall.EBADF, err: linuxerr.EBADF}, - {errno: syscall.EBUSY, err: linuxerr.EBUSY}, - {errno: syscall.EDOM, err: linuxerr.EDOM}, - {errno: syscall.EEXIST, err: linuxerr.EEXIST}, - {errno: syscall.EFAULT, err: linuxerr.EFAULT}, - {errno: syscall.EFBIG, err: linuxerr.EFBIG}, - {errno: syscall.EINTR, err: linuxerr.EINTR}, - {errno: syscall.EINVAL, err: linuxerr.EINVAL}, - {errno: syscall.EIO, err: linuxerr.EIO}, - {errno: syscall.ENOTDIR, err: linuxerr.ENOTDIR}, - {errno: syscall.ENOTTY, err: linuxerr.ENOTTY}, - {errno: syscall.EPERM, err: linuxerr.EPERM}, - {errno: syscall.EPIPE, err: linuxerr.EPIPE}, - {errno: syscall.ESPIPE, err: linuxerr.ESPIPE}, - {errno: syscall.EWOULDBLOCK, err: linuxerr.EAGAIN}, - } { - t.Run(tc.errno.Error(), func(t *testing.T) { - e := linuxerr.ErrorFromErrno(errno.Errno(tc.errno)) - if e != tc.err { - t.Fatalf("Mismatch errors: want: %+v (%d) got: %+v %d", tc.err, tc.err.Errno(), e, e.Errno()) - } - }) - } -} - -// TestEqualsMethod tests that the Equals method correctly compares syerror, -// unix.Errno and linuxerr. -// TODO (b/34162363): Remove this. -func TestEqualsMethod(t *testing.T) { - for _, tc := range []struct { - name string - linuxErr []*gErrors.Error - err []error - equal bool - }{ - { - name: "compare nil", - linuxErr: []*gErrors.Error{nil, linuxerr.NOERROR}, - err: []error{nil, linuxerr.NOERROR, unix.Errno(0)}, - equal: true, - }, - { - name: "linuxerr nil error not", - linuxErr: []*gErrors.Error{nil, linuxerr.NOERROR}, - err: []error{unix.Errno(1), linuxerr.EPERM, linuxerr.EACCES}, - equal: false, - }, - { - name: "linuxerr not nil error nil", - linuxErr: []*gErrors.Error{linuxerr.ENOENT}, - err: []error{nil, unix.Errno(0), linuxerr.NOERROR}, - equal: false, - }, - { - name: "equal errors", - linuxErr: []*gErrors.Error{linuxerr.ESRCH}, - err: []error{linuxerr.ESRCH, linuxerr.ESRCH, unix.Errno(linuxerr.ESRCH.Errno())}, - equal: true, - }, - { - name: "unequal errors", - linuxErr: []*gErrors.Error{linuxerr.ENOENT}, - err: []error{linuxerr.ESRCH, linuxerr.ESRCH, unix.Errno(linuxerr.ESRCH.Errno())}, - equal: false, - }, - { - name: "other error", - linuxErr: []*gErrors.Error{nil, linuxerr.NOERROR, linuxerr.E2BIG, linuxerr.EINVAL}, - err: []error{fs.ErrInvalid, io.EOF}, - equal: false, - }, - } { - t.Run(tc.name, func(t *testing.T) { - for _, le := range tc.linuxErr { - for _, e := range tc.err { - if linuxerr.Equals(le, e) != tc.equal { - t.Fatalf("Expected %t from Equals method for linuxerr: %s %T and error: %s %T", tc.equal, le, le, e, e) - } - } - } - }) - } -} diff --git a/pkg/eventchannel/BUILD b/pkg/eventchannel/BUILD deleted file mode 100644 index 56399232a..000000000 --- a/pkg/eventchannel/BUILD +++ /dev/null @@ -1,41 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test", "proto_library") - -package(licenses = ["notice"]) - -go_library( - name = "eventchannel", - srcs = [ - "event.go", - "event_any.go", - "processor.go", - "rate.go", - ], - visibility = ["//:sandbox"], - deps = [ - ":eventchannel_go_proto", - "//pkg/errors/linuxerr", - "//pkg/log", - "//pkg/sync", - "//pkg/unet", - "@org_golang_google_protobuf//encoding/prototext:go_default_library", - "@org_golang_google_protobuf//proto:go_default_library", - "@org_golang_google_protobuf//types/known/anypb:go_default_library", - "@org_golang_x_time//rate:go_default_library", - ], -) - -proto_library( - name = "eventchannel", - srcs = ["event.proto"], - visibility = ["//:sandbox"], -) - -go_test( - name = "eventchannel_test", - srcs = ["event_test.go"], - library = ":eventchannel", - deps = [ - "//pkg/sync", - "@org_golang_google_protobuf//proto:go_default_library", - ], -) diff --git a/pkg/eventchannel/event.proto b/pkg/eventchannel/event.proto deleted file mode 100644 index 4b24ac47c..000000000 --- a/pkg/eventchannel/event.proto +++ /dev/null @@ -1,27 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -syntax = "proto3"; - -package gvisor; - -// DebugEvent encapsulates any other event protobuf in text format. This is -// useful because clients reading events emitted this way do not need to link -// the event protobufs to display them in a human-readable format. -message DebugEvent { - // Name of the inner message. - string name = 1; - // Text representation of the inner message content. - string text = 2; -} diff --git a/pkg/eventchannel/event_test.go b/pkg/eventchannel/event_test.go deleted file mode 100644 index 0dd408f76..000000000 --- a/pkg/eventchannel/event_test.go +++ /dev/null @@ -1,146 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package eventchannel - -import ( - "fmt" - "testing" - "time" - - "google.golang.org/protobuf/proto" - "gvisor.dev/gvisor/pkg/sync" -) - -// testEmitter is an emitter that can be used in tests. It records all events -// emitted, and whether it has been closed. -type testEmitter struct { - // mu protects all fields below. - mu sync.Mutex - - // events contains all emitted events. - events []proto.Message - - // closed records whether Close() was called. - closed bool -} - -// Emit implements Emitter.Emit. -func (te *testEmitter) Emit(msg proto.Message) (bool, error) { - te.mu.Lock() - defer te.mu.Unlock() - te.events = append(te.events, msg) - return false, nil -} - -// Close implements Emitter.Close. -func (te *testEmitter) Close() error { - te.mu.Lock() - defer te.mu.Unlock() - if te.closed { - return fmt.Errorf("closed called twice") - } - te.closed = true - return nil -} - -// testMessage implements proto.Message for testing. -type testMessage struct { - proto.Message - - // name is the name of the message, used by tests to compare messages. - name string -} - -func TestMultiEmitter(t *testing.T) { - // Create three testEmitters, tied together in a multiEmitter. - me := &multiEmitter{} - var emitters []*testEmitter - for i := 0; i < 3; i++ { - te := &testEmitter{} - emitters = append(emitters, te) - me.AddEmitter(te) - } - - // Emit three messages to multiEmitter. - names := []string{"foo", "bar", "baz"} - for _, name := range names { - m := testMessage{name: name} - if _, err := me.Emit(m); err != nil { - t.Fatalf("me.Emit(%v) failed: %v", m, err) - } - } - - // All three emitters should have all three events. - for _, te := range emitters { - if got, want := len(te.events), len(names); got != want { - t.Fatalf("emitter got %d events, want %d", got, want) - } - for i, name := range names { - if got := te.events[i].(testMessage).name; got != name { - t.Errorf("emitter got message with name %q, want %q", got, name) - } - } - } - - // Close multiEmitter. - if err := me.Close(); err != nil { - t.Fatalf("me.Close() failed: %v", err) - } - - // All testEmitters should be closed. - for _, te := range emitters { - if !te.closed { - t.Errorf("te.closed got false, want true") - } - } -} - -func TestRateLimitedEmitter(t *testing.T) { - // Create a RateLimittedEmitter that wraps a testEmitter. - te := &testEmitter{} - max := float64(5) // events per second - burst := 10 // events - rle := RateLimitedEmitterFrom(te, max, burst) - - // Send 50 messages in one shot. - for i := 0; i < 50; i++ { - if _, err := rle.Emit(testMessage{}); err != nil { - t.Fatalf("rle.Emit failed: %v", err) - } - } - - // We should have received only 10 messages. - if got, want := len(te.events), 10; got != want { - t.Errorf("got %d events, want %d", got, want) - } - - // Sleep for a second and then send another 50. - time.Sleep(1 * time.Second) - for i := 0; i < 50; i++ { - if _, err := rle.Emit(testMessage{}); err != nil { - t.Fatalf("rle.Emit failed: %v", err) - } - } - - // We should have at least 5 more message, plus maybe a few more if the - // test ran slowly. - got, wantAtLeast, wantAtMost := len(te.events), 15, 20 - if got < wantAtLeast { - t.Errorf("got %d events, want at least %d", got, wantAtLeast) - } - if got > wantAtMost { - t.Errorf("got %d events, want at most %d", got, wantAtMost) - } -} diff --git a/pkg/eventchannel/eventchannel_go_proto/event.pb.go b/pkg/eventchannel/eventchannel_go_proto/event.pb.go new file mode 100644 index 000000000..1d0812479 --- /dev/null +++ b/pkg/eventchannel/eventchannel_go_proto/event.pb.go @@ -0,0 +1,156 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.25.0 +// protoc v3.13.0 +// source: pkg/eventchannel/event.proto + +package gvisor + +import ( + proto "github.com/golang/protobuf/proto" + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +// This is a compile-time assertion that a sufficiently up-to-date version +// of the legacy proto package is being used. +const _ = proto.ProtoPackageIsVersion4 + +type DebugEvent struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"` + Text string `protobuf:"bytes,2,opt,name=text,proto3" json:"text,omitempty"` +} + +func (x *DebugEvent) Reset() { + *x = DebugEvent{} + if protoimpl.UnsafeEnabled { + mi := &file_pkg_eventchannel_event_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *DebugEvent) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*DebugEvent) ProtoMessage() {} + +func (x *DebugEvent) ProtoReflect() protoreflect.Message { + mi := &file_pkg_eventchannel_event_proto_msgTypes[0] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use DebugEvent.ProtoReflect.Descriptor instead. +func (*DebugEvent) Descriptor() ([]byte, []int) { + return file_pkg_eventchannel_event_proto_rawDescGZIP(), []int{0} +} + +func (x *DebugEvent) GetName() string { + if x != nil { + return x.Name + } + return "" +} + +func (x *DebugEvent) GetText() string { + if x != nil { + return x.Text + } + return "" +} + +var File_pkg_eventchannel_event_proto protoreflect.FileDescriptor + +var file_pkg_eventchannel_event_proto_rawDesc = []byte{ + 0x0a, 0x1c, 0x70, 0x6b, 0x67, 0x2f, 0x65, 0x76, 0x65, 0x6e, 0x74, 0x63, 0x68, 0x61, 0x6e, 0x6e, + 0x65, 0x6c, 0x2f, 0x65, 0x76, 0x65, 0x6e, 0x74, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x06, + 0x67, 0x76, 0x69, 0x73, 0x6f, 0x72, 0x22, 0x34, 0x0a, 0x0a, 0x44, 0x65, 0x62, 0x75, 0x67, 0x45, + 0x76, 0x65, 0x6e, 0x74, 0x12, 0x12, 0x0a, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x74, 0x65, 0x78, 0x74, + 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x74, 0x65, 0x78, 0x74, 0x62, 0x06, 0x70, 0x72, + 0x6f, 0x74, 0x6f, 0x33, +} + +var ( + file_pkg_eventchannel_event_proto_rawDescOnce sync.Once + file_pkg_eventchannel_event_proto_rawDescData = file_pkg_eventchannel_event_proto_rawDesc +) + +func file_pkg_eventchannel_event_proto_rawDescGZIP() []byte { + file_pkg_eventchannel_event_proto_rawDescOnce.Do(func() { + file_pkg_eventchannel_event_proto_rawDescData = protoimpl.X.CompressGZIP(file_pkg_eventchannel_event_proto_rawDescData) + }) + return file_pkg_eventchannel_event_proto_rawDescData +} + +var file_pkg_eventchannel_event_proto_msgTypes = make([]protoimpl.MessageInfo, 1) +var file_pkg_eventchannel_event_proto_goTypes = []interface{}{ + (*DebugEvent)(nil), // 0: gvisor.DebugEvent +} +var file_pkg_eventchannel_event_proto_depIdxs = []int32{ + 0, // [0:0] is the sub-list for method output_type + 0, // [0:0] is the sub-list for method input_type + 0, // [0:0] is the sub-list for extension type_name + 0, // [0:0] is the sub-list for extension extendee + 0, // [0:0] is the sub-list for field type_name +} + +func init() { file_pkg_eventchannel_event_proto_init() } +func file_pkg_eventchannel_event_proto_init() { + if File_pkg_eventchannel_event_proto != nil { + return + } + if !protoimpl.UnsafeEnabled { + file_pkg_eventchannel_event_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*DebugEvent); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: file_pkg_eventchannel_event_proto_rawDesc, + NumEnums: 0, + NumMessages: 1, + NumExtensions: 0, + NumServices: 0, + }, + GoTypes: file_pkg_eventchannel_event_proto_goTypes, + DependencyIndexes: file_pkg_eventchannel_event_proto_depIdxs, + MessageInfos: file_pkg_eventchannel_event_proto_msgTypes, + }.Build() + File_pkg_eventchannel_event_proto = out.File + file_pkg_eventchannel_event_proto_rawDesc = nil + file_pkg_eventchannel_event_proto_goTypes = nil + file_pkg_eventchannel_event_proto_depIdxs = nil +} diff --git a/pkg/eventchannel/eventchannel_state_autogen.go b/pkg/eventchannel/eventchannel_state_autogen.go new file mode 100644 index 000000000..d6ace927f --- /dev/null +++ b/pkg/eventchannel/eventchannel_state_autogen.go @@ -0,0 +1,6 @@ +// automatically generated by stateify. + +//go:build go1.1 +// +build go1.1 + +package eventchannel diff --git a/pkg/fd/BUILD b/pkg/fd/BUILD deleted file mode 100644 index c7abd9655..000000000 --- a/pkg/fd/BUILD +++ /dev/null @@ -1,18 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "fd", - srcs = ["fd.go"], - visibility = ["//visibility:public"], - deps = ["@org_golang_x_sys//unix:go_default_library"], -) - -go_test( - name = "fd_test", - size = "small", - srcs = ["fd_test.go"], - library = ":fd", - deps = ["@org_golang_x_sys//unix:go_default_library"], -) diff --git a/pkg/fd/fd_state_autogen.go b/pkg/fd/fd_state_autogen.go new file mode 100644 index 000000000..5ad412976 --- /dev/null +++ b/pkg/fd/fd_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package fd diff --git a/pkg/fd/fd_test.go b/pkg/fd/fd_test.go deleted file mode 100644 index da71364d4..000000000 --- a/pkg/fd/fd_test.go +++ /dev/null @@ -1,137 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package fd - -import ( - "math" - "os" - "testing" - - "golang.org/x/sys/unix" -) - -func TestSetNegOne(t *testing.T) { - type entry struct { - name string - file *FD - fn func() error - } - var tests []entry - - fd, err := unix.Socket(unix.AF_UNIX, unix.SOCK_STREAM|unix.SOCK_CLOEXEC, 0) - if err != nil { - t.Fatal("unix.Socket:", err) - } - f1 := New(fd) - tests = append(tests, entry{ - "Release", - f1, - func() error { - return unix.Close(f1.Release()) - }, - }) - - fd, err = unix.Socket(unix.AF_UNIX, unix.SOCK_STREAM|unix.SOCK_CLOEXEC, 0) - if err != nil { - t.Fatal("unix.Socket:", err) - } - f2 := New(fd) - tests = append(tests, entry{ - "Close", - f2, - f2.Close, - }) - - for _, test := range tests { - if err := test.fn(); err != nil { - t.Errorf("%s: %v", test.name, err) - continue - } - if fd := test.file.FD(); fd != -1 { - t.Errorf("%s: got FD() = %d, want = -1", test.name, fd) - } - } -} - -func TestStartsNegOne(t *testing.T) { - type entry struct { - name string - file *FD - } - - tests := []entry{ - {"-1", New(-1)}, - {"-2", New(-2)}, - {"MinInt32", New(math.MinInt32)}, - {"MinInt64", New(math.MinInt64)}, - } - - for _, test := range tests { - if fd := test.file.FD(); fd != -1 { - t.Errorf("%s: got FD() = %d, want = -1", test.name, fd) - } - } -} - -func TestFileDotFile(t *testing.T) { - fd, err := unix.Socket(unix.AF_UNIX, unix.SOCK_STREAM|unix.SOCK_CLOEXEC, 0) - if err != nil { - t.Fatal("unix.Socket:", err) - } - - f := New(fd) - of, err := f.File() - if err != nil { - t.Fatalf("File got err %v want nil", err) - } - - if ofd, nfd := int(of.Fd()), f.FD(); ofd == nfd || ofd == -1 { - // Try not to double close the FD. - f.Release() - - t.Fatalf("got %#v.File().Fd() = %d, want new FD", f, ofd) - } - - f.Close() - of.Close() -} - -func TestFileDotFileError(t *testing.T) { - f := &FD{ReadWriter{-2}} - - if of, err := f.File(); err == nil { - t.Errorf("File %v got nil err want non-nil", of) - of.Close() - } -} - -func TestNewFromFile(t *testing.T) { - f, err := NewFromFile(os.Stdin) - if err != nil { - t.Fatalf("NewFromFile got err %v want nil", err) - } - if nfd, ofd := f.FD(), int(os.Stdin.Fd()); nfd == -1 || nfd == ofd { - t.Errorf("got FD() = %d, want = new FD (old FD was %d)", nfd, ofd) - } - f.Close() -} - -func TestNewFromFileError(t *testing.T) { - f, err := NewFromFile(nil) - if err == nil { - t.Errorf("NewFromFile got %v with nil err want non-nil", f) - f.Close() - } -} diff --git a/pkg/fdchannel/BUILD b/pkg/fdchannel/BUILD deleted file mode 100644 index d240132e5..000000000 --- a/pkg/fdchannel/BUILD +++ /dev/null @@ -1,24 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -licenses(["notice"]) - -go_library( - name = "fdchannel", - srcs = ["fdchannel_unsafe.go"], - visibility = ["//visibility:public"], - deps = [ - "//pkg/gohacks", - "@org_golang_x_sys//unix:go_default_library", - ], -) - -go_test( - name = "fdchannel_test", - size = "small", - srcs = ["fdchannel_test.go"], - library = ":fdchannel", - deps = [ - "//pkg/sync", - "@org_golang_x_sys//unix:go_default_library", - ], -) diff --git a/pkg/fdchannel/fdchannel_test.go b/pkg/fdchannel/fdchannel_test.go deleted file mode 100644 index 9616e30c5..000000000 --- a/pkg/fdchannel/fdchannel_test.go +++ /dev/null @@ -1,133 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package fdchannel - -import ( - "io/ioutil" - "os" - "syscall" - "testing" - "time" - - "golang.org/x/sys/unix" - "gvisor.dev/gvisor/pkg/sync" -) - -func TestSendRecvFD(t *testing.T) { - sendFile, err := ioutil.TempFile("", "fdchannel_test_") - if err != nil { - t.Fatalf("failed to create temporary file: %v", err) - } - defer sendFile.Close() - - chanFDs, err := NewConnectedSockets() - if err != nil { - t.Fatalf("failed to create fdchannel sockets: %v", err) - } - sendEP := NewEndpoint(chanFDs[0]) - defer sendEP.Destroy() - recvEP := NewEndpoint(chanFDs[1]) - defer recvEP.Destroy() - - recvFD, err := recvEP.RecvFDNonblock() - if err != unix.EAGAIN && err != unix.EWOULDBLOCK { - t.Errorf("RecvFDNonblock before SendFD: got (%d, %v), wanted (<unspecified>, EAGAIN or EWOULDBLOCK", recvFD, err) - } - - if err := sendEP.SendFD(int(sendFile.Fd())); err != nil { - t.Fatalf("SendFD failed: %v", err) - } - recvFD, err = recvEP.RecvFD() - if err != nil { - t.Fatalf("RecvFD failed: %v", err) - } - recvFile := os.NewFile(uintptr(recvFD), "received file") - defer recvFile.Close() - - sendInfo, err := sendFile.Stat() - if err != nil { - t.Fatalf("failed to stat sent file: %v", err) - } - sendInfoSys := sendInfo.Sys() - sendStat, ok := sendInfoSys.(*syscall.Stat_t) - if !ok { - t.Fatalf("sent file's FileInfo is backed by unknown type %T", sendInfoSys) - } - - recvInfo, err := recvFile.Stat() - if err != nil { - t.Fatalf("failed to stat received file: %v", err) - } - recvInfoSys := recvInfo.Sys() - recvStat, ok := recvInfoSys.(*syscall.Stat_t) - if !ok { - t.Fatalf("received file's FileInfo is backed by unknown type %T", recvInfoSys) - } - - if sendStat.Dev != recvStat.Dev || sendStat.Ino != recvStat.Ino { - t.Errorf("sent file (dev=%d, ino=%d) does not match received file (dev=%d, ino=%d)", sendStat.Dev, sendStat.Ino, recvStat.Dev, recvStat.Ino) - } -} - -func TestShutdownThenRecvFD(t *testing.T) { - sendFile, err := ioutil.TempFile("", "fdchannel_test_") - if err != nil { - t.Fatalf("failed to create temporary file: %v", err) - } - defer sendFile.Close() - - chanFDs, err := NewConnectedSockets() - if err != nil { - t.Fatalf("failed to create fdchannel sockets: %v", err) - } - sendEP := NewEndpoint(chanFDs[0]) - defer sendEP.Destroy() - recvEP := NewEndpoint(chanFDs[1]) - defer recvEP.Destroy() - - recvEP.Shutdown() - if _, err := recvEP.RecvFD(); err == nil { - t.Error("RecvFD succeeded unexpectedly") - } -} - -func TestRecvFDThenShutdown(t *testing.T) { - sendFile, err := ioutil.TempFile("", "fdchannel_test_") - if err != nil { - t.Fatalf("failed to create temporary file: %v", err) - } - defer sendFile.Close() - - chanFDs, err := NewConnectedSockets() - if err != nil { - t.Fatalf("failed to create fdchannel sockets: %v", err) - } - sendEP := NewEndpoint(chanFDs[0]) - defer sendEP.Destroy() - recvEP := NewEndpoint(chanFDs[1]) - defer recvEP.Destroy() - - var receiverWG sync.WaitGroup - receiverWG.Add(1) - go func() { - defer receiverWG.Done() - if _, err := recvEP.RecvFD(); err == nil { - t.Error("RecvFD succeeded unexpectedly") - } - }() - defer receiverWG.Wait() - time.Sleep(time.Second) // to ensure recvEP.RecvFD() has blocked - recvEP.Shutdown() -} diff --git a/pkg/fdchannel/fdchannel_unsafe_state_autogen.go b/pkg/fdchannel/fdchannel_unsafe_state_autogen.go new file mode 100644 index 000000000..6ad5b7b43 --- /dev/null +++ b/pkg/fdchannel/fdchannel_unsafe_state_autogen.go @@ -0,0 +1,6 @@ +// automatically generated by stateify. + +//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris +// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris + +package fdchannel diff --git a/pkg/fdnotifier/BUILD b/pkg/fdnotifier/BUILD deleted file mode 100644 index 235dcc490..000000000 --- a/pkg/fdnotifier/BUILD +++ /dev/null @@ -1,17 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "fdnotifier", - srcs = [ - "fdnotifier.go", - "poll_unsafe.go", - ], - visibility = ["//:sandbox"], - deps = [ - "//pkg/sync", - "//pkg/waiter", - "@org_golang_x_sys//unix:go_default_library", - ], -) diff --git a/pkg/fdnotifier/fdnotifier_state_autogen.go b/pkg/fdnotifier/fdnotifier_state_autogen.go new file mode 100644 index 000000000..70dfa86d3 --- /dev/null +++ b/pkg/fdnotifier/fdnotifier_state_autogen.go @@ -0,0 +1,6 @@ +// automatically generated by stateify. + +//go:build linux +// +build linux + +package fdnotifier diff --git a/pkg/fdnotifier/fdnotifier_unsafe_state_autogen.go b/pkg/fdnotifier/fdnotifier_unsafe_state_autogen.go new file mode 100644 index 000000000..70dfa86d3 --- /dev/null +++ b/pkg/fdnotifier/fdnotifier_unsafe_state_autogen.go @@ -0,0 +1,6 @@ +// automatically generated by stateify. + +//go:build linux +// +build linux + +package fdnotifier diff --git a/pkg/flipcall/BUILD b/pkg/flipcall/BUILD deleted file mode 100644 index c810c7946..000000000 --- a/pkg/flipcall/BUILD +++ /dev/null @@ -1,34 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -licenses(["notice"]) - -go_library( - name = "flipcall", - srcs = [ - "ctrl_futex.go", - "flipcall.go", - "flipcall_unsafe.go", - "futex_linux.go", - "io.go", - "packet_window.go", - ], - visibility = ["//visibility:public"], - deps = [ - "//pkg/abi/linux", - "//pkg/log", - "//pkg/memutil", - "//pkg/sync", - "@org_golang_x_sys//unix:go_default_library", - ], -) - -go_test( - name = "flipcall_test", - size = "small", - srcs = [ - "flipcall_example_test.go", - "flipcall_test.go", - ], - library = ":flipcall", - deps = ["//pkg/sync"], -) diff --git a/pkg/flipcall/flipcall_example_test.go b/pkg/flipcall/flipcall_example_test.go deleted file mode 100644 index 2e28a149a..000000000 --- a/pkg/flipcall/flipcall_example_test.go +++ /dev/null @@ -1,113 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package flipcall - -import ( - "bytes" - "fmt" - - "gvisor.dev/gvisor/pkg/sync" -) - -func Example() { - const ( - reqPrefix = "request " - respPrefix = "response " - count = 3 - maxMessageLen = len(respPrefix) + 1 // 1 digit - ) - - pwa, err := NewPacketWindowAllocator() - if err != nil { - panic(err) - } - defer pwa.Destroy() - pwd, err := pwa.Allocate(PacketWindowLengthForDataCap(uint32(maxMessageLen))) - if err != nil { - panic(err) - } - var clientEP Endpoint - if err := clientEP.Init(ClientSide, pwd); err != nil { - panic(err) - } - defer clientEP.Destroy() - var serverEP Endpoint - if err := serverEP.Init(ServerSide, pwd); err != nil { - panic(err) - } - defer serverEP.Destroy() - - var serverRun sync.WaitGroup - serverRun.Add(1) - go func() { - defer serverRun.Done() - i := 0 - var buf bytes.Buffer - // wait for first request - n, err := serverEP.RecvFirst() - if err != nil { - return - } - for { - // read request - buf.Reset() - buf.Write(serverEP.Data()[:n]) - fmt.Println(buf.String()) - // write response - buf.Reset() - fmt.Fprintf(&buf, "%s%d", respPrefix, i) - copy(serverEP.Data(), buf.Bytes()) - // send response and wait for next request - n, err = serverEP.SendRecv(uint32(buf.Len())) - if err != nil { - return - } - i++ - } - }() - defer func() { - serverEP.Shutdown() - serverRun.Wait() - }() - - // establish connection as client - if err := clientEP.Connect(); err != nil { - panic(err) - } - var buf bytes.Buffer - for i := 0; i < count; i++ { - // write request - buf.Reset() - fmt.Fprintf(&buf, "%s%d", reqPrefix, i) - copy(clientEP.Data(), buf.Bytes()) - // send request and wait for response - n, err := clientEP.SendRecv(uint32(buf.Len())) - if err != nil { - panic(err) - } - // read response - buf.Reset() - buf.Write(clientEP.Data()[:n]) - fmt.Println(buf.String()) - } - - // Output: - // request 0 - // response 0 - // request 1 - // response 1 - // request 2 - // response 2 -} diff --git a/pkg/flipcall/flipcall_linux_state_autogen.go b/pkg/flipcall/flipcall_linux_state_autogen.go new file mode 100644 index 000000000..7c9d0e9db --- /dev/null +++ b/pkg/flipcall/flipcall_linux_state_autogen.go @@ -0,0 +1,6 @@ +// automatically generated by stateify. + +//go:build linux +// +build linux + +package flipcall diff --git a/pkg/flipcall/flipcall_state_autogen.go b/pkg/flipcall/flipcall_state_autogen.go new file mode 100644 index 000000000..0557f3660 --- /dev/null +++ b/pkg/flipcall/flipcall_state_autogen.go @@ -0,0 +1,6 @@ +// automatically generated by stateify. + +//go:build go1.1 +// +build go1.1 + +package flipcall diff --git a/pkg/flipcall/flipcall_test.go b/pkg/flipcall/flipcall_test.go deleted file mode 100644 index 33fd55a44..000000000 --- a/pkg/flipcall/flipcall_test.go +++ /dev/null @@ -1,405 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package flipcall - -import ( - "runtime" - "testing" - "time" - - "gvisor.dev/gvisor/pkg/sync" -) - -var testPacketWindowSize = pageSize - -type testConnection struct { - pwa PacketWindowAllocator - clientEP Endpoint - serverEP Endpoint -} - -func newTestConnectionWithOptions(tb testing.TB, clientOpts, serverOpts []EndpointOption) *testConnection { - c := &testConnection{} - if err := c.pwa.Init(); err != nil { - tb.Fatalf("failed to create PacketWindowAllocator: %v", err) - } - pwd, err := c.pwa.Allocate(testPacketWindowSize) - if err != nil { - c.pwa.Destroy() - tb.Fatalf("PacketWindowAllocator.Allocate() failed: %v", err) - } - if err := c.clientEP.Init(ClientSide, pwd, clientOpts...); err != nil { - c.pwa.Destroy() - tb.Fatalf("failed to create client Endpoint: %v", err) - } - if err := c.serverEP.Init(ServerSide, pwd, serverOpts...); err != nil { - c.pwa.Destroy() - c.clientEP.Destroy() - tb.Fatalf("failed to create server Endpoint: %v", err) - } - return c -} - -func newTestConnection(tb testing.TB) *testConnection { - return newTestConnectionWithOptions(tb, nil, nil) -} - -func (c *testConnection) destroy() { - c.pwa.Destroy() - c.clientEP.Destroy() - c.serverEP.Destroy() -} - -func testSendRecv(t *testing.T, c *testConnection) { - // This shared variable is used to confirm that synchronization between - // flipcall endpoints is visible to the Go race detector. - state := 0 - var serverRun sync.WaitGroup - serverRun.Add(1) - go func() { - defer serverRun.Done() - t.Logf("server Endpoint waiting for packet 1") - if _, err := c.serverEP.RecvFirst(); err != nil { - t.Errorf("server Endpoint.RecvFirst() failed: %v", err) - return - } - state++ - if state != 2 { - t.Errorf("shared state counter: got %d, wanted 2", state) - } - t.Logf("server Endpoint got packet 1, sending packet 2 and waiting for packet 3") - if _, err := c.serverEP.SendRecv(0); err != nil { - t.Errorf("server Endpoint.SendRecv() failed: %v", err) - return - } - state++ - if state != 4 { - t.Errorf("shared state counter: got %d, wanted 4", state) - } - t.Logf("server Endpoint got packet 3") - }() - defer func() { - // Ensure that the server goroutine is cleaned up before - // c.serverEP.Destroy(), even if the test fails. - c.serverEP.Shutdown() - serverRun.Wait() - }() - - t.Logf("client Endpoint establishing connection") - if err := c.clientEP.Connect(); err != nil { - t.Fatalf("client Endpoint.Connect() failed: %v", err) - } - state++ - if state != 1 { - t.Errorf("shared state counter: got %d, wanted 1", state) - } - t.Logf("client Endpoint sending packet 1 and waiting for packet 2") - if _, err := c.clientEP.SendRecv(0); err != nil { - t.Fatalf("client Endpoint.SendRecv() failed: %v", err) - } - state++ - if state != 3 { - t.Errorf("shared state counter: got %d, wanted 3", state) - } - t.Logf("client Endpoint got packet 2, sending packet 3") - if err := c.clientEP.SendLast(0); err != nil { - t.Fatalf("client Endpoint.SendLast() failed: %v", err) - } - t.Logf("waiting for server goroutine to complete") - serverRun.Wait() -} - -func TestSendRecv(t *testing.T) { - c := newTestConnection(t) - defer c.destroy() - testSendRecv(t, c) -} - -func testShutdownBeforeConnect(t *testing.T, c *testConnection, remoteShutdown bool) { - if remoteShutdown { - c.serverEP.Shutdown() - } else { - c.clientEP.Shutdown() - } - if err := c.clientEP.Connect(); err == nil { - t.Errorf("client Endpoint.Connect() succeeded unexpectedly") - } -} - -func TestShutdownBeforeConnectLocal(t *testing.T) { - c := newTestConnection(t) - defer c.destroy() - testShutdownBeforeConnect(t, c, false) -} - -func TestShutdownBeforeConnectRemote(t *testing.T) { - c := newTestConnection(t) - defer c.destroy() - testShutdownBeforeConnect(t, c, true) -} - -func testShutdownDuringConnect(t *testing.T, c *testConnection, remoteShutdown bool) { - var clientRun sync.WaitGroup - clientRun.Add(1) - go func() { - defer clientRun.Done() - if err := c.clientEP.Connect(); err == nil { - t.Errorf("client Endpoint.Connect() succeeded unexpectedly") - } - }() - time.Sleep(time.Second) // to allow c.clientEP.Connect() to block - if remoteShutdown { - c.serverEP.Shutdown() - } else { - c.clientEP.Shutdown() - } - clientRun.Wait() -} - -func TestShutdownDuringConnectLocal(t *testing.T) { - c := newTestConnection(t) - defer c.destroy() - testShutdownDuringConnect(t, c, false) -} - -func TestShutdownDuringConnectRemote(t *testing.T) { - c := newTestConnection(t) - defer c.destroy() - testShutdownDuringConnect(t, c, true) -} - -func testShutdownBeforeRecvFirst(t *testing.T, c *testConnection, remoteShutdown bool) { - if remoteShutdown { - c.clientEP.Shutdown() - } else { - c.serverEP.Shutdown() - } - if _, err := c.serverEP.RecvFirst(); err == nil { - t.Errorf("server Endpoint.RecvFirst() succeeded unexpectedly") - } -} - -func TestShutdownBeforeRecvFirstLocal(t *testing.T) { - c := newTestConnection(t) - defer c.destroy() - testShutdownBeforeRecvFirst(t, c, false) -} - -func TestShutdownBeforeRecvFirstRemote(t *testing.T) { - c := newTestConnection(t) - defer c.destroy() - testShutdownBeforeRecvFirst(t, c, true) -} - -func testShutdownDuringRecvFirstBeforeConnect(t *testing.T, c *testConnection, remoteShutdown bool) { - var serverRun sync.WaitGroup - serverRun.Add(1) - go func() { - defer serverRun.Done() - if _, err := c.serverEP.RecvFirst(); err == nil { - t.Errorf("server Endpoint.RecvFirst() succeeded unexpectedly") - } - }() - time.Sleep(time.Second) // to allow c.serverEP.RecvFirst() to block - if remoteShutdown { - c.clientEP.Shutdown() - } else { - c.serverEP.Shutdown() - } - serverRun.Wait() -} - -func TestShutdownDuringRecvFirstBeforeConnectLocal(t *testing.T) { - c := newTestConnection(t) - defer c.destroy() - testShutdownDuringRecvFirstBeforeConnect(t, c, false) -} - -func TestShutdownDuringRecvFirstBeforeConnectRemote(t *testing.T) { - c := newTestConnection(t) - defer c.destroy() - testShutdownDuringRecvFirstBeforeConnect(t, c, true) -} - -func testShutdownDuringRecvFirstAfterConnect(t *testing.T, c *testConnection, remoteShutdown bool) { - var serverRun sync.WaitGroup - serverRun.Add(1) - go func() { - defer serverRun.Done() - if _, err := c.serverEP.RecvFirst(); err == nil { - t.Errorf("server Endpoint.RecvFirst() succeeded unexpectedly") - } - }() - defer func() { - // Ensure that the server goroutine is cleaned up before - // c.serverEP.Destroy(), even if the test fails. - c.serverEP.Shutdown() - serverRun.Wait() - }() - if err := c.clientEP.Connect(); err != nil { - t.Fatalf("client Endpoint.Connect() failed: %v", err) - } - if remoteShutdown { - c.clientEP.Shutdown() - } else { - c.serverEP.Shutdown() - } - serverRun.Wait() -} - -func TestShutdownDuringRecvFirstAfterConnectLocal(t *testing.T) { - c := newTestConnection(t) - defer c.destroy() - testShutdownDuringRecvFirstAfterConnect(t, c, false) -} - -func TestShutdownDuringRecvFirstAfterConnectRemote(t *testing.T) { - c := newTestConnection(t) - defer c.destroy() - testShutdownDuringRecvFirstAfterConnect(t, c, true) -} - -func testShutdownDuringClientSendRecv(t *testing.T, c *testConnection, remoteShutdown bool) { - var serverRun sync.WaitGroup - serverRun.Add(1) - go func() { - defer serverRun.Done() - if _, err := c.serverEP.RecvFirst(); err != nil { - t.Errorf("server Endpoint.RecvFirst() failed: %v", err) - } - // At this point, the client must be blocked in c.clientEP.SendRecv(). - if remoteShutdown { - c.serverEP.Shutdown() - } else { - c.clientEP.Shutdown() - } - }() - defer func() { - // Ensure that the server goroutine is cleaned up before - // c.serverEP.Destroy(), even if the test fails. - c.serverEP.Shutdown() - serverRun.Wait() - }() - if err := c.clientEP.Connect(); err != nil { - t.Fatalf("client Endpoint.Connect() failed: %v", err) - } - if _, err := c.clientEP.SendRecv(0); err == nil { - t.Errorf("client Endpoint.SendRecv() succeeded unexpectedly") - } -} - -func TestShutdownDuringClientSendRecvLocal(t *testing.T) { - c := newTestConnection(t) - defer c.destroy() - testShutdownDuringClientSendRecv(t, c, false) -} - -func TestShutdownDuringClientSendRecvRemote(t *testing.T) { - c := newTestConnection(t) - defer c.destroy() - testShutdownDuringClientSendRecv(t, c, true) -} - -func testShutdownDuringServerSendRecv(t *testing.T, c *testConnection, remoteShutdown bool) { - var serverRun sync.WaitGroup - serverRun.Add(1) - go func() { - defer serverRun.Done() - if _, err := c.serverEP.RecvFirst(); err != nil { - t.Errorf("server Endpoint.RecvFirst() failed: %v", err) - return - } - if _, err := c.serverEP.SendRecv(0); err == nil { - t.Errorf("server Endpoint.SendRecv() succeeded unexpectedly") - } - }() - defer func() { - // Ensure that the server goroutine is cleaned up before - // c.serverEP.Destroy(), even if the test fails. - c.serverEP.Shutdown() - serverRun.Wait() - }() - if err := c.clientEP.Connect(); err != nil { - t.Fatalf("client Endpoint.Connect() failed: %v", err) - } - if _, err := c.clientEP.SendRecv(0); err != nil { - t.Fatalf("client Endpoint.SendRecv() failed: %v", err) - } - time.Sleep(time.Second) // to allow serverEP.SendRecv() to block - if remoteShutdown { - c.clientEP.Shutdown() - } else { - c.serverEP.Shutdown() - } - serverRun.Wait() -} - -func TestShutdownDuringServerSendRecvLocal(t *testing.T) { - c := newTestConnection(t) - defer c.destroy() - testShutdownDuringServerSendRecv(t, c, false) -} - -func TestShutdownDuringServerSendRecvRemote(t *testing.T) { - c := newTestConnection(t) - defer c.destroy() - testShutdownDuringServerSendRecv(t, c, true) -} - -func benchmarkSendRecv(b *testing.B, c *testConnection) { - var serverRun sync.WaitGroup - serverRun.Add(1) - go func() { - defer serverRun.Done() - if b.N == 0 { - return - } - if _, err := c.serverEP.RecvFirst(); err != nil { - b.Errorf("server Endpoint.RecvFirst() failed: %v", err) - return - } - for i := 1; i < b.N; i++ { - if _, err := c.serverEP.SendRecv(0); err != nil { - b.Errorf("server Endpoint.SendRecv() failed: %v", err) - return - } - } - if err := c.serverEP.SendLast(0); err != nil { - b.Errorf("server Endpoint.SendLast() failed: %v", err) - } - }() - defer func() { - c.serverEP.Shutdown() - serverRun.Wait() - }() - - if err := c.clientEP.Connect(); err != nil { - b.Fatalf("client Endpoint.Connect() failed: %v", err) - } - runtime.GC() - b.ResetTimer() - for i := 0; i < b.N; i++ { - if _, err := c.clientEP.SendRecv(0); err != nil { - b.Fatalf("client Endpoint.SendRecv() failed: %v", err) - } - } - b.StopTimer() -} - -func BenchmarkSendRecv(b *testing.B) { - c := newTestConnection(b) - defer c.destroy() - benchmarkSendRecv(b, c) -} diff --git a/pkg/flipcall/flipcall_unsafe_state_autogen.go b/pkg/flipcall/flipcall_unsafe_state_autogen.go new file mode 100644 index 000000000..0e03c2a65 --- /dev/null +++ b/pkg/flipcall/flipcall_unsafe_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package flipcall diff --git a/pkg/fspath/BUILD b/pkg/fspath/BUILD deleted file mode 100644 index 67dd1e225..000000000 --- a/pkg/fspath/BUILD +++ /dev/null @@ -1,26 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(default_visibility = ["//visibility:public"]) - -licenses(["notice"]) - -go_library( - name = "fspath", - srcs = [ - "builder.go", - "fspath.go", - ], - deps = [ - "//pkg/gohacks", - ], -) - -go_test( - name = "fspath_test", - size = "small", - srcs = [ - "builder_test.go", - "fspath_test.go", - ], - library = ":fspath", -) diff --git a/pkg/fspath/builder_test.go b/pkg/fspath/builder_test.go deleted file mode 100644 index 22f890273..000000000 --- a/pkg/fspath/builder_test.go +++ /dev/null @@ -1,58 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package fspath - -import ( - "testing" -) - -func TestBuilder(t *testing.T) { - type testCase struct { - pcs []string // path components in reverse order - after string - want string - } - tests := []testCase{ - { - // Empty case. - }, - { - pcs: []string{"foo"}, - want: "foo", - }, - { - pcs: []string{"foo", "bar", "baz"}, - want: "baz/bar/foo", - }, - { - pcs: []string{"foo", "bar"}, - after: " (deleted)", - want: "bar/foo (deleted)", - }, - } - - for _, test := range tests { - t.Run(test.want, func(t *testing.T) { - var b Builder - for _, pc := range test.pcs { - b.PrependComponent(pc) - } - b.AppendString(test.after) - if got := b.String(); got != test.want { - t.Errorf("got %q, wanted %q", got, test.want) - } - }) - } -} diff --git a/pkg/fspath/fspath_state_autogen.go b/pkg/fspath/fspath_state_autogen.go new file mode 100644 index 000000000..6ceea8003 --- /dev/null +++ b/pkg/fspath/fspath_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package fspath diff --git a/pkg/fspath/fspath_test.go b/pkg/fspath/fspath_test.go deleted file mode 100644 index d5e9a549a..000000000 --- a/pkg/fspath/fspath_test.go +++ /dev/null @@ -1,134 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package fspath - -import ( - "reflect" - "strings" - "testing" -) - -func TestParseIteratorPartialPathnames(t *testing.T) { - path := Parse("/foo//bar///baz////") - // Parse strips leading slashes, and records their presence as - // Path.Absolute. - if !path.Absolute { - t.Errorf("Path.Absolute: got false, wanted true") - } - // Parse strips trailing slashes, and records their presence as Path.Dir. - if !path.Dir { - t.Errorf("Path.Dir: got false, wanted true") - } - // The first Iterator.partialPathname is the input pathname, with leading - // and trailing slashes stripped. - it := path.Begin - if want := "foo//bar///baz"; it.partialPathname != want { - t.Errorf("first Iterator.partialPathname: got %q, wanted %q", it.partialPathname, want) - } - // Successive Iterator.partialPathnames remove the leading path component - // and following slashes, until we run out of path components and get a - // terminal Iterator. - it = it.Next() - if want := "bar///baz"; it.partialPathname != want { - t.Errorf("second Iterator.partialPathname: got %q, wanted %q", it.partialPathname, want) - } - it = it.Next() - if want := "baz"; it.partialPathname != want { - t.Errorf("third Iterator.partialPathname: got %q, wanted %q", it.partialPathname, want) - } - it = it.Next() - if want := ""; it.partialPathname != want { - t.Errorf("fourth Iterator.partialPathname: got %q, wanted %q", it.partialPathname, want) - } - if it.Ok() { - t.Errorf("fourth Iterator.Ok(): got true, wanted false") - } -} - -func TestParse(t *testing.T) { - type testCase struct { - pathname string - relpath []string - abs bool - dir bool - } - tests := []testCase{ - { - pathname: "", - relpath: []string{}, - abs: false, - dir: false, - }, - { - pathname: "/", - relpath: []string{}, - abs: true, - dir: true, - }, - { - pathname: "//", - relpath: []string{}, - abs: true, - dir: true, - }, - } - for _, sep := range []string{"/", "//"} { - for _, abs := range []bool{false, true} { - for _, dir := range []bool{false, true} { - for _, pcs := range [][]string{ - // single path component - {"foo"}, - // multiple path components, including non-UTF-8 - {".", "foo", "..", "\xe6", "bar"}, - } { - prefix := "" - if abs { - prefix = sep - } - suffix := "" - if dir { - suffix = sep - } - tests = append(tests, testCase{ - pathname: prefix + strings.Join(pcs, sep) + suffix, - relpath: pcs, - abs: abs, - dir: dir, - }) - } - } - } - } - - for _, test := range tests { - t.Run(test.pathname, func(t *testing.T) { - p := Parse(test.pathname) - t.Logf("pathname %q => path %q", test.pathname, p) - if p.Absolute != test.abs { - t.Errorf("path absoluteness: got %v, wanted %v", p.Absolute, test.abs) - } - if p.Dir != test.dir { - t.Errorf("path must resolve to a directory: got %v, wanted %v", p.Dir, test.dir) - } - pcs := []string{} - for pit := p.Begin; pit.Ok(); pit = pit.Next() { - pcs = append(pcs, pit.String()) - } - if !reflect.DeepEqual(pcs, test.relpath) { - t.Errorf("relative path: got %v, wanted %v", pcs, test.relpath) - } - }) - } -} diff --git a/pkg/gohacks/BUILD b/pkg/gohacks/BUILD deleted file mode 100644 index b4e05f922..000000000 --- a/pkg/gohacks/BUILD +++ /dev/null @@ -1,20 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "gohacks", - srcs = [ - "gohacks_unsafe.go", - ], - stateify = False, - visibility = ["//:sandbox"], -) - -go_test( - name = "gohacks_test", - size = "small", - srcs = ["gohacks_test.go"], - library = ":gohacks", - deps = ["@org_golang_x_sys//unix:go_default_library"], -) diff --git a/pkg/gohacks/gohacks_test.go b/pkg/gohacks/gohacks_test.go deleted file mode 100644 index e18c8abc7..000000000 --- a/pkg/gohacks/gohacks_test.go +++ /dev/null @@ -1,97 +0,0 @@ -// Copyright 2021 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package gohacks - -import ( - "io/ioutil" - "math/rand" - "os" - "runtime/debug" - "testing" - - "golang.org/x/sys/unix" -) - -func randBuf(size int) []byte { - b := make([]byte, size) - for i := range b { - b[i] = byte(rand.Intn(256)) - } - return b -} - -// Size of a page in bytes. Cloned from hostarch.PageSize to avoid a circular -// dependency. -const pageSize = 4096 - -func testCopy(dst, src []byte) (panicked bool) { - defer func() { - if r := recover(); r != nil { - panicked = true - } - }() - debug.SetPanicOnFault(true) - copy(dst, src) - return panicked -} - -func TestSegVOnMemmove(t *testing.T) { - // Test that SIGSEGVs received by runtime.memmove when *not* doing - // CopyIn or CopyOut work gets propagated to the runtime. - const bufLen = pageSize - a, err := unix.Mmap(-1, 0, bufLen, unix.PROT_NONE, unix.MAP_ANON|unix.MAP_PRIVATE) - if err != nil { - t.Fatalf("Mmap failed: %v", err) - - } - defer unix.Munmap(a) - b := randBuf(bufLen) - - if !testCopy(b, a) { - t.Fatalf("testCopy didn't panic when it should have") - } - - if !testCopy(a, b) { - t.Fatalf("testCopy didn't panic when it should have") - } -} - -func TestSigbusOnMemmove(t *testing.T) { - // Test that SIGBUS received by runtime.memmove when *not* doing - // CopyIn or CopyOut work gets propagated to the runtime. - const bufLen = pageSize - f, err := ioutil.TempFile("", "sigbus_test") - if err != nil { - t.Fatalf("TempFile failed: %v", err) - } - os.Remove(f.Name()) - defer f.Close() - - a, err := unix.Mmap(int(f.Fd()), 0, bufLen, unix.PROT_READ|unix.PROT_WRITE, unix.MAP_SHARED) - if err != nil { - t.Fatalf("Mmap failed: %v", err) - - } - defer unix.Munmap(a) - b := randBuf(bufLen) - - if !testCopy(b, a) { - t.Fatalf("testCopy didn't panic when it should have") - } - - if !testCopy(a, b) { - t.Fatalf("testCopy didn't panic when it should have") - } -} diff --git a/pkg/goid/BUILD b/pkg/goid/BUILD deleted file mode 100644 index 08832a8ae..000000000 --- a/pkg/goid/BUILD +++ /dev/null @@ -1,23 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "goid", - srcs = [ - "goid.go", - "goid_amd64.s", - "goid_arm64.s", - ], - stateify = False, - visibility = ["//visibility:public"], -) - -go_test( - name = "goid_test", - size = "small", - srcs = [ - "goid_test.go", - ], - library = ":goid", -) diff --git a/pkg/goid/goid_test.go b/pkg/goid/goid_test.go deleted file mode 100644 index 54be11d63..000000000 --- a/pkg/goid/goid_test.go +++ /dev/null @@ -1,81 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package goid - -import ( - "runtime" - "sync" - "testing" - "time" -) - -func TestUniquenessAndConsistency(t *testing.T) { - const ( - numGoroutines = 5000 - - // maxID is not an intrinsic property of goroutine IDs; it is only a - // property of how the Go runtime currently assigns them. Future - // changes to the Go runtime may require that maxID be raised, or that - // assertions regarding it be removed entirely. - maxID = numGoroutines + 1000 - ) - - var ( - goidsMu sync.Mutex - goids = make(map[int64]struct{}) - checkedWG sync.WaitGroup - exitCh = make(chan struct{}) - ) - for i := 0; i < numGoroutines; i++ { - checkedWG.Add(1) - go func() { - id := Get() - if id > maxID { - t.Errorf("observed unexpectedly large goroutine ID %d", id) - } - goidsMu.Lock() - if _, dup := goids[id]; dup { - t.Errorf("observed duplicate goroutine ID %d", id) - } - goids[id] = struct{}{} - goidsMu.Unlock() - checkedWG.Done() - for { - if curID := Get(); curID != id { - t.Errorf("goroutine ID changed from %d to %d", id, curID) - // Don't spam logs by repeating the check; wait quietly for - // the test to finish. - <-exitCh - return - } - // Check if the test is over. - select { - case <-exitCh: - return - default: - } - // Yield to other goroutines, and possibly migrate to another P. - runtime.Gosched() - } - }() - } - // Wait for all goroutines to perform uniqueness checks. - checkedWG.Wait() - // Wait for an additional second to allow goroutines to spin checking for - // ID consistency. - time.Sleep(time.Second) - // Request that all goroutines exit. - close(exitCh) -} diff --git a/pkg/hostarch/BUILD b/pkg/hostarch/BUILD deleted file mode 100644 index 1e8def4d9..000000000 --- a/pkg/hostarch/BUILD +++ /dev/null @@ -1,42 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -package(licenses = ["notice"]) - -go_template_instance( - name = "addr_range", - out = "addr_range.go", - package = "hostarch", - prefix = "Addr", - template = "//pkg/segment:generic_range", - types = { - "T": "Addr", - }, -) - -go_test( - name = "hostarch_test", - size = "small", - srcs = [ - "addr_range_seq_test.go", - ], - library = ":hostarch", -) - -go_library( - name = "hostarch", - srcs = [ - "access_type.go", - "addr.go", - "addr_range.go", - "addr_range_seq_unsafe.go", - "hostarch.go", - "hostarch_arm64.go", - "hostarch_x86.go", - ], - visibility = ["//:sandbox"], - deps = [ - "//pkg/gohacks", - "@org_golang_x_sys//unix:go_default_library", - ], -) diff --git a/pkg/segment/range.go b/pkg/hostarch/addr_range.go index b6fa96e81..8a771d6fc 100644 --- a/pkg/segment/range.go +++ b/pkg/hostarch/addr_range.go @@ -1,59 +1,42 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package segment - -// T is a required type parameter that must be an integral type. -type T uint64 +package hostarch // A Range represents a contiguous range of T. // // +stateify savable -type Range struct { +type AddrRange struct { // Start is the inclusive start of the range. - Start T + Start Addr // End is the exclusive end of the range. - End T + End Addr } // WellFormed returns true if r.Start <= r.End. All other methods on a Range // require that the Range is well-formed. // //go:nosplit -func (r Range) WellFormed() bool { +func (r AddrRange) WellFormed() bool { return r.Start <= r.End } // Length returns the length of the range. // //go:nosplit -func (r Range) Length() T { +func (r AddrRange) Length() Addr { return r.End - r.Start } // Contains returns true if r contains x. // //go:nosplit -func (r Range) Contains(x T) bool { +func (r AddrRange) Contains(x Addr) bool { return r.Start <= x && x < r.End } // Overlaps returns true if r and r2 overlap. // //go:nosplit -func (r Range) Overlaps(r2 Range) bool { +func (r AddrRange) Overlaps(r2 AddrRange) bool { return r.Start < r2.End && r2.Start < r.End } @@ -61,7 +44,7 @@ func (r Range) Overlaps(r2 Range) bool { // contained within r. // //go:nosplit -func (r Range) IsSupersetOf(r2 Range) bool { +func (r AddrRange) IsSupersetOf(r2 AddrRange) bool { return r.Start <= r2.Start && r.End >= r2.End } @@ -70,7 +53,7 @@ func (r Range) IsSupersetOf(r2 Range) bool { // bounds, but for which Length() == 0. // //go:nosplit -func (r Range) Intersect(r2 Range) Range { +func (r AddrRange) Intersect(r2 AddrRange) AddrRange { if r.Start < r2.Start { r.Start = r2.Start } @@ -88,6 +71,6 @@ func (r Range) Intersect(r2 Range) Range { // non-zero length. // //go:nosplit -func (r Range) CanSplitAt(x T) bool { +func (r AddrRange) CanSplitAt(x Addr) bool { return r.Contains(x) && r.Start < x } diff --git a/pkg/hostarch/addr_range_seq_test.go b/pkg/hostarch/addr_range_seq_test.go deleted file mode 100644 index 5726dfd19..000000000 --- a/pkg/hostarch/addr_range_seq_test.go +++ /dev/null @@ -1,197 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package hostarch - -import ( - "testing" -) - -var addrRangeSeqTests = []struct { - desc string - ranges []AddrRange -}{ - { - desc: "Empty sequence", - }, - { - desc: "Single empty AddrRange", - ranges: []AddrRange{ - {0x10, 0x10}, - }, - }, - { - desc: "Single non-empty AddrRange of length 1", - ranges: []AddrRange{ - {0x10, 0x11}, - }, - }, - { - desc: "Single non-empty AddrRange of length 2", - ranges: []AddrRange{ - {0x10, 0x12}, - }, - }, - { - desc: "Multiple non-empty AddrRanges", - ranges: []AddrRange{ - {0x10, 0x11}, - {0x20, 0x22}, - }, - }, - { - desc: "Multiple AddrRanges including empty AddrRanges", - ranges: []AddrRange{ - {0x10, 0x10}, - {0x20, 0x20}, - {0x30, 0x33}, - {0x40, 0x44}, - {0x50, 0x50}, - {0x60, 0x60}, - {0x70, 0x77}, - {0x80, 0x88}, - {0x90, 0x90}, - {0xa0, 0xa0}, - }, - }, -} - -func testAddrRangeSeqEqualityWithTailIteration(t *testing.T, ars AddrRangeSeq, wantRanges []AddrRange) { - var wantLen int64 - for _, ar := range wantRanges { - wantLen += int64(ar.Length()) - } - - var i int - for !ars.IsEmpty() { - if gotLen := ars.NumBytes(); gotLen != wantLen { - t.Errorf("Iteration %d: %v.NumBytes(): got %d, wanted %d", i, ars, gotLen, wantLen) - } - if gotN, wantN := ars.NumRanges(), len(wantRanges)-i; gotN != wantN { - t.Errorf("Iteration %d: %v.NumRanges(): got %d, wanted %d", i, ars, gotN, wantN) - } - got := ars.Head() - if i >= len(wantRanges) { - t.Errorf("Iteration %d: %v.Head(): got %s, wanted <end of sequence>", i, ars, got) - } else if want := wantRanges[i]; got != want { - t.Errorf("Iteration %d: %v.Head(): got %s, wanted %s", i, ars, got, want) - } - ars = ars.Tail() - wantLen -= int64(got.Length()) - i++ - } - if gotLen := ars.NumBytes(); gotLen != 0 || wantLen != 0 { - t.Errorf("Iteration %d: %v.NumBytes(): got %d, wanted %d (which should be 0)", i, ars, gotLen, wantLen) - } - if gotN := ars.NumRanges(); gotN != 0 { - t.Errorf("Iteration %d: %v.NumRanges(): got %d, wanted 0", i, ars, gotN) - } -} - -func TestAddrRangeSeqTailIteration(t *testing.T) { - for _, test := range addrRangeSeqTests { - t.Run(test.desc, func(t *testing.T) { - testAddrRangeSeqEqualityWithTailIteration(t, AddrRangeSeqFromSlice(test.ranges), test.ranges) - }) - } -} - -func TestAddrRangeSeqDropFirstEmpty(t *testing.T) { - var ars AddrRangeSeq - if got, want := ars.DropFirst(1), ars; got != want { - t.Errorf("%v.DropFirst(1): got %v, wanted %v", ars, got, want) - } -} - -func TestAddrRangeSeqDropSingleByteIteration(t *testing.T) { - // Tests AddrRangeSeq iteration using Head/DropFirst, simulating - // I/O-per-AddrRange. - for _, test := range addrRangeSeqTests { - t.Run(test.desc, func(t *testing.T) { - // Figure out what AddrRanges we expect to see. - var wantLen int64 - var wantRanges []AddrRange - for _, ar := range test.ranges { - wantLen += int64(ar.Length()) - wantRanges = append(wantRanges, ar) - if ar.Length() == 0 { - // We "do" 0 bytes of I/O and then call DropFirst(0), - // advancing to the next AddrRange. - continue - } - // Otherwise we "do" 1 byte of I/O and then call DropFirst(1), - // advancing the AddrRange by 1 byte, or to the next AddrRange - // if this one is exhausted. - for ar.Start++; ar.Length() != 0; ar.Start++ { - wantRanges = append(wantRanges, ar) - } - } - t.Logf("Expected AddrRanges: %s (%d bytes)", wantRanges, wantLen) - - ars := AddrRangeSeqFromSlice(test.ranges) - var i int - for !ars.IsEmpty() { - if gotLen := ars.NumBytes(); gotLen != wantLen { - t.Errorf("Iteration %d: %v.NumBytes(): got %d, wanted %d", i, ars, gotLen, wantLen) - } - got := ars.Head() - if i >= len(wantRanges) { - t.Errorf("Iteration %d: %v.Head(): got %s, wanted <end of sequence>", i, ars, got) - } else if want := wantRanges[i]; got != want { - t.Errorf("Iteration %d: %v.Head(): got %s, wanted %s", i, ars, got, want) - } - if got.Length() == 0 { - ars = ars.DropFirst(0) - } else { - ars = ars.DropFirst(1) - wantLen-- - } - i++ - } - if gotLen := ars.NumBytes(); gotLen != 0 || wantLen != 0 { - t.Errorf("Iteration %d: %v.NumBytes(): got %d, wanted %d (which should be 0)", i, ars, gotLen, wantLen) - } - }) - } -} - -func TestAddrRangeSeqTakeFirstEmpty(t *testing.T) { - var ars AddrRangeSeq - if got, want := ars.TakeFirst(1), ars; got != want { - t.Errorf("%v.TakeFirst(1): got %v, wanted %v", ars, got, want) - } -} - -func TestAddrRangeSeqTakeFirst(t *testing.T) { - ranges := []AddrRange{ - {0x10, 0x11}, - {0x20, 0x22}, - {0x30, 0x30}, - {0x40, 0x44}, - {0x50, 0x55}, - {0x60, 0x60}, - {0x70, 0x77}, - } - ars := AddrRangeSeqFromSlice(ranges).TakeFirst(5) - want := []AddrRange{ - {0x10, 0x11}, // +1 byte (total 1 byte), not truncated - {0x20, 0x22}, // +2 bytes (total 3 bytes), not truncated - {0x30, 0x30}, // +0 bytes (total 3 bytes), no change - {0x40, 0x42}, // +2 bytes (total 5 bytes), partially truncated - {0x50, 0x50}, // +0 bytes (total 5 bytes), fully truncated - {0x60, 0x60}, // +0 bytes (total 5 bytes), "fully truncated" (no change) - {0x70, 0x70}, // +0 bytes (total 5 bytes), fully truncated - } - testAddrRangeSeqEqualityWithTailIteration(t, ars, want) -} diff --git a/pkg/hostarch/hostarch_arm64_state_autogen.go b/pkg/hostarch/hostarch_arm64_state_autogen.go new file mode 100644 index 000000000..444542b30 --- /dev/null +++ b/pkg/hostarch/hostarch_arm64_state_autogen.go @@ -0,0 +1,6 @@ +// automatically generated by stateify. + +//go:build arm64 +// +build arm64 + +package hostarch diff --git a/pkg/hostarch/hostarch_state_autogen.go b/pkg/hostarch/hostarch_state_autogen.go new file mode 100644 index 000000000..32939e06b --- /dev/null +++ b/pkg/hostarch/hostarch_state_autogen.go @@ -0,0 +1,80 @@ +// automatically generated by stateify. + +package hostarch + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (a *AccessType) StateTypeName() string { + return "pkg/hostarch.AccessType" +} + +func (a *AccessType) StateFields() []string { + return []string{ + "Read", + "Write", + "Execute", + } +} + +func (a *AccessType) beforeSave() {} + +// +checklocksignore +func (a *AccessType) StateSave(stateSinkObject state.Sink) { + a.beforeSave() + stateSinkObject.Save(0, &a.Read) + stateSinkObject.Save(1, &a.Write) + stateSinkObject.Save(2, &a.Execute) +} + +func (a *AccessType) afterLoad() {} + +// +checklocksignore +func (a *AccessType) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &a.Read) + stateSourceObject.Load(1, &a.Write) + stateSourceObject.Load(2, &a.Execute) +} + +func (v *Addr) StateTypeName() string { + return "pkg/hostarch.Addr" +} + +func (v *Addr) StateFields() []string { + return nil +} + +func (r *AddrRange) StateTypeName() string { + return "pkg/hostarch.AddrRange" +} + +func (r *AddrRange) StateFields() []string { + return []string{ + "Start", + "End", + } +} + +func (r *AddrRange) beforeSave() {} + +// +checklocksignore +func (r *AddrRange) StateSave(stateSinkObject state.Sink) { + r.beforeSave() + stateSinkObject.Save(0, &r.Start) + stateSinkObject.Save(1, &r.End) +} + +func (r *AddrRange) afterLoad() {} + +// +checklocksignore +func (r *AddrRange) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &r.Start) + stateSourceObject.Load(1, &r.End) +} + +func init() { + state.Register((*AccessType)(nil)) + state.Register((*Addr)(nil)) + state.Register((*AddrRange)(nil)) +} diff --git a/pkg/hostarch/hostarch_unsafe_state_autogen.go b/pkg/hostarch/hostarch_unsafe_state_autogen.go new file mode 100644 index 000000000..419f24913 --- /dev/null +++ b/pkg/hostarch/hostarch_unsafe_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package hostarch diff --git a/pkg/hostarch/hostarch_x86_state_autogen.go b/pkg/hostarch/hostarch_x86_state_autogen.go new file mode 100644 index 000000000..239b6ba6c --- /dev/null +++ b/pkg/hostarch/hostarch_x86_state_autogen.go @@ -0,0 +1,6 @@ +// automatically generated by stateify. + +//go:build amd64 || 386 +// +build amd64 386 + +package hostarch diff --git a/pkg/ilist/BUILD b/pkg/ilist/BUILD deleted file mode 100644 index 3f6eb07df..000000000 --- a/pkg/ilist/BUILD +++ /dev/null @@ -1,56 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") -load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance") - -package(licenses = ["notice"]) - -go_library( - name = "ilist", - srcs = [ - "interface_list.go", - ], - visibility = ["//visibility:public"], -) - -go_template_instance( - name = "interface_list", - out = "interface_list.go", - package = "ilist", - template = ":generic_list", - types = {}, -) - -# This list is used for benchmarking. -go_template_instance( - name = "test_list", - out = "test_list.go", - package = "ilist", - prefix = "direct", - template = ":generic_list", - types = { - "Element": "*direct", - "Linker": "*direct", - }, -) - -go_test( - name = "list_test", - size = "small", - srcs = [ - "list_test.go", - "test_list.go", - ], - library = ":ilist", -) - -go_template( - name = "generic_list", - srcs = [ - "list.go", - ], - opt_types = [ - "Element", - "ElementMapper", - "Linker", - ], - visibility = ["//visibility:public"], -) diff --git a/pkg/ilist/ilist_state_autogen.go b/pkg/ilist/ilist_state_autogen.go new file mode 100644 index 000000000..d329cb299 --- /dev/null +++ b/pkg/ilist/ilist_state_autogen.go @@ -0,0 +1,68 @@ +// automatically generated by stateify. + +package ilist + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (l *List) StateTypeName() string { + return "pkg/ilist.List" +} + +func (l *List) StateFields() []string { + return []string{ + "head", + "tail", + } +} + +func (l *List) beforeSave() {} + +// +checklocksignore +func (l *List) StateSave(stateSinkObject state.Sink) { + l.beforeSave() + stateSinkObject.Save(0, &l.head) + stateSinkObject.Save(1, &l.tail) +} + +func (l *List) afterLoad() {} + +// +checklocksignore +func (l *List) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &l.head) + stateSourceObject.Load(1, &l.tail) +} + +func (e *Entry) StateTypeName() string { + return "pkg/ilist.Entry" +} + +func (e *Entry) StateFields() []string { + return []string{ + "next", + "prev", + } +} + +func (e *Entry) beforeSave() {} + +// +checklocksignore +func (e *Entry) StateSave(stateSinkObject state.Sink) { + e.beforeSave() + stateSinkObject.Save(0, &e.next) + stateSinkObject.Save(1, &e.prev) +} + +func (e *Entry) afterLoad() {} + +// +checklocksignore +func (e *Entry) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &e.next) + stateSourceObject.Load(1, &e.prev) +} + +func init() { + state.Register((*List)(nil)) + state.Register((*Entry)(nil)) +} diff --git a/pkg/ilist/list.go b/pkg/ilist/interface_list.go index 557051d18..4d6062d7f 100644 --- a/pkg/ilist/list.go +++ b/pkg/ilist/interface_list.go @@ -1,18 +1,3 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package ilist provides the implementation of intrusive linked lists. package ilist // Linker is the interface that objects must implement if they want to be added diff --git a/pkg/ilist/list_test.go b/pkg/ilist/list_test.go deleted file mode 100644 index 3f9abfb56..000000000 --- a/pkg/ilist/list_test.go +++ /dev/null @@ -1,240 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ilist - -import ( - "testing" -) - -type testEntry struct { - Entry - value int -} - -type direct struct { - directEntry - value int -} - -func verifyEquality(t *testing.T, entries []testEntry, l *List) { - t.Helper() - - i := 0 - for it := l.Front(); it != nil; it = it.Next() { - e := it.(*testEntry) - if e != &entries[i] { - t.Errorf("Wrong entry at index %d", i) - return - } - i++ - } - - if i != len(entries) { - t.Errorf("Wrong number of entries; want = %d, got = %d", len(entries), i) - return - } - - i = 0 - for it := l.Back(); it != nil; it = it.Prev() { - e := it.(*testEntry) - if e != &entries[len(entries)-1-i] { - t.Errorf("Wrong entry at index %d", i) - return - } - i++ - } - - if i != len(entries) { - t.Errorf("Wrong number of entries; want = %d, got = %d", len(entries), i) - return - } -} - -func TestZeroEmpty(t *testing.T) { - var l List - if l.Front() != nil { - t.Error("Front is non-nil") - } - if l.Back() != nil { - t.Error("Back is non-nil") - } -} - -func TestPushBack(t *testing.T) { - var l List - - // Test single entry insertion. - var entry testEntry - l.PushBack(&entry) - - e := l.Front().(*testEntry) - if e != &entry { - t.Error("Wrong entry returned") - } - - // Test inserting 100 entries. - l.Reset() - var entries [100]testEntry - for i := range entries { - l.PushBack(&entries[i]) - } - - verifyEquality(t, entries[:], &l) -} - -func TestPushFront(t *testing.T) { - var l List - - // Test single entry insertion. - var entry testEntry - l.PushFront(&entry) - - e := l.Front().(*testEntry) - if e != &entry { - t.Error("Wrong entry returned") - } - - // Test inserting 100 entries. - l.Reset() - var entries [100]testEntry - for i := range entries { - l.PushFront(&entries[len(entries)-1-i]) - } - - verifyEquality(t, entries[:], &l) -} - -func TestRemove(t *testing.T) { - // Remove entry from single-element list. - var l List - var entry testEntry - l.PushBack(&entry) - l.Remove(&entry) - if l.Front() != nil { - t.Error("List is empty") - } - - var entries [100]testEntry - - // Remove single element from lists of lengths 2 to 101. - for n := 1; n <= len(entries); n++ { - for extra := 0; extra <= n; extra++ { - l.Reset() - for i := 0; i < n; i++ { - if extra == i { - l.PushBack(&entry) - } - l.PushBack(&entries[i]) - } - if extra == n { - l.PushBack(&entry) - } - - l.Remove(&entry) - verifyEquality(t, entries[:n], &l) - } - } -} - -func TestReset(t *testing.T) { - var l List - - // Resetting list of one element. - l.PushBack(&testEntry{}) - if l.Front() == nil { - t.Error("List is empty") - } - - l.Reset() - if l.Front() != nil { - t.Error("List is not empty") - } - - // Resetting list of 10 elements. - for i := 0; i < 10; i++ { - l.PushBack(&testEntry{}) - } - - if l.Front() == nil { - t.Error("List is empty") - } - - l.Reset() - if l.Front() != nil { - t.Error("List is not empty") - } - - // Resetting empty list. - l.Reset() - if l.Front() != nil { - t.Error("List is not empty") - } -} - -func BenchmarkIterateForward(b *testing.B) { - var l List - for i := 0; i < 1000000; i++ { - l.PushBack(&testEntry{value: i}) - } - - for i := b.N; i > 0; i-- { - tmp := 0 - for e := l.Front(); e != nil; e = e.Next() { - tmp += e.(*testEntry).value - } - } -} - -func BenchmarkIterateBackward(b *testing.B) { - var l List - for i := 0; i < 1000000; i++ { - l.PushBack(&testEntry{value: i}) - } - - for i := b.N; i > 0; i-- { - tmp := 0 - for e := l.Back(); e != nil; e = e.Prev() { - tmp += e.(*testEntry).value - } - } -} - -func BenchmarkDirectIterateForward(b *testing.B) { - var l directList - for i := 0; i < 1000000; i++ { - l.PushBack(&direct{value: i}) - } - - for i := b.N; i > 0; i-- { - tmp := 0 - for e := l.Front(); e != nil; e = e.Next() { - tmp += e.value - } - } -} - -func BenchmarkDirectIterateBackward(b *testing.B) { - var l directList - for i := 0; i < 1000000; i++ { - l.PushBack(&direct{value: i}) - } - - for i := b.N; i > 0; i-- { - tmp := 0 - for e := l.Back(); e != nil; e = e.Prev() { - tmp += e.value - } - } -} diff --git a/pkg/linewriter/BUILD b/pkg/linewriter/BUILD deleted file mode 100644 index f84d03700..000000000 --- a/pkg/linewriter/BUILD +++ /dev/null @@ -1,18 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "linewriter", - srcs = ["linewriter.go"], - marshal = False, - stateify = False, - visibility = ["//visibility:public"], - deps = ["//pkg/sync"], -) - -go_test( - name = "linewriter_test", - srcs = ["linewriter_test.go"], - library = ":linewriter", -) diff --git a/pkg/linewriter/linewriter_test.go b/pkg/linewriter/linewriter_test.go deleted file mode 100644 index 96dc7e6e0..000000000 --- a/pkg/linewriter/linewriter_test.go +++ /dev/null @@ -1,81 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package linewriter - -import ( - "bytes" - "testing" -) - -func TestWriter(t *testing.T) { - testCases := []struct { - input []string - want []string - }{ - { - input: []string{"1\n", "2\n"}, - want: []string{"1", "2"}, - }, - { - input: []string{"1\n", "\n", "2\n"}, - want: []string{"1", "", "2"}, - }, - { - input: []string{"1\n2\n", "3\n"}, - want: []string{"1", "2", "3"}, - }, - { - input: []string{"1", "2\n"}, - want: []string{"12"}, - }, - { - // Data with no newline yet is omitted. - input: []string{"1\n", "2\n", "3"}, - want: []string{"1", "2"}, - }, - } - - for _, c := range testCases { - var lines [][]byte - - w := NewWriter(func(p []byte) { - // We must not retain p, so we must make a copy. - b := make([]byte, len(p)) - copy(b, p) - - lines = append(lines, b) - }) - - for _, in := range c.input { - n, err := w.Write([]byte(in)) - if err != nil { - t.Errorf("Write(%q) err got %v want nil (case %+v)", in, err, c) - } - if n != len(in) { - t.Errorf("Write(%q) b got %d want %d (case %+v)", in, n, len(in), c) - } - } - - if len(lines) != len(c.want) { - t.Errorf("len(lines) got %d want %d (case %+v)", len(lines), len(c.want), c) - } - - for i := range lines { - if !bytes.Equal(lines[i], []byte(c.want[i])) { - t.Errorf("item %d got %q want %q (case %+v)", i, lines[i], c.want[i], c) - } - } - } -} diff --git a/pkg/lisafs/README.md b/pkg/lisafs/README.md deleted file mode 100644 index 51d0d40e5..000000000 --- a/pkg/lisafs/README.md +++ /dev/null @@ -1,363 +0,0 @@ -# Replacing 9P - -## Background - -The Linux filesystem model consists of the following key aspects (modulo mounts, -which are outside the scope of this discussion): - -- A `struct inode` represents a "filesystem object", such as a directory or a - regular file. "Filesystem object" is most precisely defined by the practical - properties of an inode, such as an immutable type (regular file, directory, - symbolic link, etc.) and its independence from the path originally used to - obtain it. - -- A `struct dentry` represents a node in a filesystem tree. Semantically, each - dentry is immutably associated with an inode representing the filesystem - object at that position. (Linux implements optimizations involving reuse of - unreferenced dentries, which allows their associated inodes to change, but - this is outside the scope of this discussion.) - -- A `struct file` represents an open file description (hereafter FD) and is - needed to perform I/O. Each FD is immutably associated with the dentry - through which it was opened. - -The current gVisor virtual filesystem implementation (hereafter VFS1) closely -imitates the Linux design: - -- `struct inode` => `fs.Inode` - -- `struct dentry` => `fs.Dirent` - -- `struct file` => `fs.File` - -gVisor accesses most external filesystems through a variant of the 9P2000.L -protocol, including extensions for performance (`walkgetattr`) and for features -not supported by vanilla 9P2000.L (`flushf`, `lconnect`). The 9P protocol family -is inode-based; 9P fids represent a file (equivalently "file system object"), -and the protocol is structured around alternatively obtaining fids to represent -files (with `walk` and, in gVisor, `walkgetattr`) and performing operations on -those fids. - -In the sections below, a **shared** filesystem is a filesystem that is *mutably* -accessible by multiple concurrent clients, such that a **non-shared** filesystem -is a filesystem that is either read-only or accessible by only a single client. - -## Problems - -### Serialization of Path Component RPCs - -Broadly speaking, VFS1 traverses each path component in a pathname, alternating -between verifying that each traversed dentry represents an inode that represents -a searchable directory and moving to the next dentry in the path. - -In the context of a remote filesystem, the structure of this traversal means -that - modulo caching - a path involving N components requires at least N-1 -*sequential* RPCs to obtain metadata for intermediate directories, incurring -significant latency. (In vanilla 9P2000.L, 2(N-1) RPCs are required: N-1 `walk` -and N-1 `getattr`. We added the `walkgetattr` RPC to reduce this overhead.) On -non-shared filesystems, this overhead is primarily significant during -application startup; caching mitigates much of this overhead at steady state. On -shared filesystems, where correct caching requires revalidation (requiring RPCs -for each revalidated directory anyway), this overhead is consistently ruinous. - -### Inefficient RPCs - -9P is not exceptionally economical with RPCs in general. In addition to the -issue described above: - -- Opening an existing file in 9P involves at least 2 RPCs: `walk` to produce - an unopened fid representing the file, and `lopen` to open the fid. - -- Creating a file also involves at least 2 RPCs: `walk` to produce an unopened - fid representing the parent directory, and `lcreate` to create the file and - convert the fid to an open fid representing the created file. In practice, - both the Linux and gVisor 9P clients expect to have an unopened fid for the - created file (necessitating an additional `walk`), as well as attributes for - the created file (necessitating an additional `getattr`), for a total of 4 - RPCs. (In a shared filesystem, where whether a file already exists can - change between RPCs, a correct implementation of `open(O_CREAT)` would have - to alternate between these two paths (plus `clunk`ing the temporary fid - between alternations, since the nature of the `fid` differs between the two - paths). Neither Linux nor gVisor implement the required alternation, so - `open(O_CREAT)` without `O_EXCL` can spuriously fail with `EEXIST` on both.) - -- Closing (`clunk`ing) a fid requires an RPC. VFS1 issues this RPC - asynchronously in an attempt to reduce critical path latency, but scheduling - overhead makes this not clearly advantageous in practice. - -- `read` and `readdir` can return partial reads without a way to indicate EOF, - necessitating an additional final read to detect EOF. - -- Operations that affect filesystem state do not consistently return updated - filesystem state. In gVisor, the client implementation attempts to handle - this by tracking what it thinks updated state "should" be; this is complex, - and especially brittle for timestamps (which are often not arbitrarily - settable). In Linux, the client implemtation invalidates cached metadata - whenever it performs such an operation, and reloads it when a dentry - corresponding to an inode with no valid cached metadata is revalidated; this - is simple, but necessitates an additional `getattr`. - -### Dentry/Inode Ambiguity - -As noted above, 9P's documentation tends to imply that unopened fids represent -an inode. In practice, most filesystem APIs present very limited interfaces for -working with inodes at best, such that the interpretation of unopened fids -varies: - -- Linux's 9P client associates unopened fids with (dentry, uid) pairs. When - caching is enabled, it also associates each inode with the first fid opened - writably that references that inode, in order to support page cache - writeback. - -- gVisor's 9P client associates unopened fids with inodes, and also caches - opened fids in inodes in a manner similar to Linux. - -- The runsc fsgofer associates unopened fids with both "dentries" (host - filesystem paths) and "inodes" (host file descriptors); which is used - depends on the operation invoked on the fid. - -For non-shared filesystems, this confusion has resulted in correctness issues -that are (in gVisor) currently handled by a number of coarse-grained locks that -serialize renames with all other filesystem operations. For shared filesystems, -this means inconsistent behavior in the presence of concurrent mutation. - -## Design - -Almost all Linux filesystem syscalls describe filesystem resources in one of two -ways: - -- Path-based: A filesystem position is described by a combination of a - starting position and a sequence of path components relative to that - position, where the starting position is one of: - - - The VFS root (defined by mount namespace and chroot), for absolute paths - - - The VFS position of an existing FD, for relative paths passed to `*at` - syscalls (e.g. `statat`) - - - The current working directory, for relative paths passed to non-`*at` - syscalls and `*at` syscalls with `AT_FDCWD` - -- File-description-based: A filesystem object is described by an existing FD, - passed to a `f*` syscall (e.g. `fstat`). - -Many of our issues with 9P arise from its (and VFS') interposition of a model -based on inodes between the filesystem syscall API and filesystem -implementations. We propose to replace 9P with a protocol that does not feature -inodes at all, and instead closely follows the filesystem syscall API by -featuring only path-based and FD-based operations, with minimal deviations as -necessary to ameliorate deficiencies in the syscall interface (see below). This -approach addresses the issues described above: - -- Even on shared filesystems, most application filesystem syscalls are - translated to a single RPC (possibly excepting special cases described - below), which is a logical lower bound. - -- The behavior of application syscalls on shared filesystems is - straightforwardly predictable: path-based syscalls are translated to - path-based RPCs, which will re-lookup the file at that path, and FD-based - syscalls are translated to FD-based RPCs, which use an existing open file - without performing another lookup. (This is at least true on gofers that - proxy the host local filesystem; other filesystems that lack support for - e.g. certain operations on FDs may have different behavior, but this - divergence is at least still predictable and inherent to the underlying - filesystem implementation.) - -Note that this approach is only feasible in gVisor's next-generation virtual -filesystem (VFS2), which does not assume the existence of inodes and allows the -remote filesystem client to translate whole path-based syscalls into RPCs. Thus -one of the unavoidable tradeoffs associated with such a protocol vs. 9P is the -inability to construct a Linux client that is performance-competitive with -gVisor. - -### File Permissions - -Many filesystem operations are side-effectual, such that file permissions must -be checked before such operations take effect. The simplest approach to file -permission checking is for the sentry to obtain permissions from the remote -filesystem, then apply permission checks in the sentry before performing the -application-requested operation. However, this requires an additional RPC per -application syscall (which can't be mitigated by caching on shared filesystems). -Alternatively, we may delegate file permission checking to gofers. In general, -file permission checks depend on the following properties of the accessor: - -- Filesystem UID/GID - -- Supplementary GIDs - -- Effective capabilities in the accessor's user namespace (i.e. the accessor's - effective capability set) - -- All UIDs and GIDs mapped in the accessor's user namespace (which determine - if the accessor's capabilities apply to accessed files) - -We may choose to delay implementation of file permission checking delegation, -although this is potentially costly since it doubles the number of required RPCs -for most operations on shared filesystems. We may also consider compromise -options, such as only delegating file permission checks for accessors in the -root user namespace. - -### Symbolic Links - -gVisor usually interprets symbolic link targets in its VFS rather than on the -filesystem containing the symbolic link; thus e.g. a symlink to -"/proc/self/maps" on a remote filesystem resolves to said file in the sentry's -procfs rather than the host's. This implies that: - -- Remote filesystem servers that proxy filesystems supporting symlinks must - check if each path component is a symlink during path traversal. - -- Absolute symlinks require that the sentry restart the operation at its - contextual VFS root (which is task-specific and may not be on a remote - filesystem at all), so if a remote filesystem server encounters an absolute - symlink during path traversal on behalf of a path-based operation, it must - terminate path traversal and return the symlink target. - -- Relative symlinks begin target resolution in the parent directory of the - symlink, so in theory most relative symlinks can be handled automatically - during the path traversal that encounters the symlink, provided that said - traversal is supplied with the number of remaining symlinks before `ELOOP`. - However, the new path traversed by the symlink target may cross VFS mount - boundaries, such that it's only safe for remote filesystem servers to - speculatively follow relative symlinks for side-effect-free operations such - as `stat` (where the sentry can simply ignore results that are inapplicable - due to crossing mount boundaries). We may choose to delay implementation of - this feature, at the cost of an additional RPC per relative symlink (note - that even if the symlink target crosses a mount boundary, the sentry will - need to `stat` the path to the mount boundary to confirm that each traversed - component is an accessible directory); until it is implemented, relative - symlinks may be handled like absolute symlinks, by terminating path - traversal and returning the symlink target. - -The possibility of symlinks (and the possibility of a compromised sentry) means -that the sentry may issue RPCs with paths that, in the absence of symlinks, -would traverse beyond the root of the remote filesystem. For example, the sentry -may issue an RPC with a path like "/foo/../..", on the premise that if "/foo" is -a symlink then the resulting path may be elsewhere on the remote filesystem. To -handle this, path traversal must also track its current depth below the remote -filesystem root, and terminate path traversal if it would ascend beyond this -point. - -### Path Traversal - -Since path-based VFS operations will translate to path-based RPCs, filesystem -servers will need to handle path traversal. From the perspective of a given -filesystem implementation in the server, there are two basic approaches to path -traversal: - -- Inode-walk: For each path component, obtain a handle to the underlying - filesystem object (e.g. with `open(O_PATH)`), check if that object is a - symlink (as described above) and that that object is accessible by the - caller (e.g. with `fstat()`), then continue to the next path component (e.g. - with `openat()`). This ensures that the checked filesystem object is the one - used to obtain the next object in the traversal, which is intuitively - appealing. However, while this approach works for host local filesystems, it - requires features that are not widely supported by other filesystems. - -- Path-walk: For each path component, use a path-based operation to determine - if the filesystem object currently referred to by that path component is a - symlink / is accessible. This is highly portable, but suffers from quadratic - behavior (at the level of the underlying filesystem implementation, the - first path component will be traversed a number of times equal to the number - of path components in the path). - -The implementation should support either option by delegating path traversal to -filesystem implementations within the server (like VFS and the remote filesystem -protocol itself), as inode-walking is still safe, efficient, amenable to FD -caching, and implementable on non-shared host local filesystems (a sufficiently -common case as to be worth considering in the design). - -Both approaches are susceptible to race conditions that may permit sandboxed -filesystem escapes: - -- Under inode-walk, a malicious application may cause a directory to be moved - (with `rename`) during path traversal, such that the filesystem - implementation incorrectly determines whether subsequent inodes are located - in paths that should be visible to sandboxed applications. - -- Under path-walk, a malicious application may cause a non-symlink file to be - replaced with a symlink during path traversal, such that following path - operations will incorrectly follow the symlink. - -Both race conditions can, to some extent, be mitigated in filesystem server -implementations by synchronizing path traversal with the hazardous operations in -question. However, shared filesystems are frequently used to share data between -sandboxed and unsandboxed applications in a controlled way, and in some cases a -malicious sandboxed application may be able to take advantage of a hazardous -filesystem operation performed by an unsandboxed application. In some cases, -filesystem features may be available to ensure safety even in such cases (e.g. -[the new openat2() syscall](https://man7.org/linux/man-pages/man2/openat2.2.html)), -but it is not clear how to solve this problem in general. (Note that this issue -is not specific to our design; rather, it is a fundamental limitation of -filesystem sandboxing.) - -### Filesystem Multiplexing - -A given sentry may need to access multiple distinct remote filesystems (e.g. -different volumes for a given container). In many cases, there is no advantage -to serving these filesystems from distinct filesystem servers, or accessing them -through distinct connections (factors such as maximum RPC concurrency should be -based on available host resources). Therefore, the protocol should support -multiplexing of distinct filesystem trees within a single session. 9P supports -this by allowing multiple calls to the `attach` RPC to produce fids representing -distinct filesystem trees, but this is somewhat clunky; we propose a much -simpler mechanism wherein each message that conveys a path also conveys a -numeric filesystem ID that identifies a filesystem tree. - -## Alternatives Considered - -### Additional Extensions to 9P - -There are at least three conceptual aspects to 9P: - -- Wire format: messages with a 4-byte little-endian size prefix, strings with - a 2-byte little-endian size prefix, etc. Whether the wire format is worth - retaining is unclear; in particular, it's unclear that the 9P wire format - has a significant advantage over protobufs, which are substantially easier - to extend. Note that the official Go protobuf implementation is widely known - to suffer from a significant number of performance deficiencies, so if we - choose to switch to protobuf, we may need to use an alternative toolchain - such as `gogo/protobuf` (which is also widely used in the Go ecosystem, e.g. - by Kubernetes). - -- Filesystem model: fids, qids, etc. Discarding this is one of the motivations - for this proposal. - -- RPCs: Twalk, Tlopen, etc. In addition to previously-described - inefficiencies, most of these are dependent on the filesystem model and - therefore must be discarded. - -### FUSE - -The FUSE (Filesystem in Userspace) protocol is frequently used to provide -arbitrary userspace filesystem implementations to a host Linux kernel. -Unfortunately, FUSE is also inode-based, and therefore doesn't address any of -the problems we have with 9P. - -### virtio-fs - -virtio-fs is an ongoing project aimed at improving Linux VM filesystem -performance when accessing Linux host filesystems (vs. virtio-9p). In brief, it -is based on: - -- Using a FUSE client in the guest that communicates over virtio with a FUSE - server in the host. - -- Using DAX to map the host page cache into the guest. - -- Using a file metadata table in shared memory to avoid VM exits for metadata - updates. - -None of these improvements seem applicable to gVisor: - -- As explained above, FUSE is still inode-based, so it is still susceptible to - most of the problems we have with 9P. - -- Our use of host file descriptors already allows us to leverage the host page - cache for file contents. - -- Our need for shared filesystem coherence is usually based on a user - requirement that an out-of-sandbox filesystem mutation is guaranteed to be - visible by all subsequent observations from within the sandbox, or vice - versa; it's not clear that this can be guaranteed without a synchronous - signaling mechanism like an RPC. diff --git a/pkg/log/BUILD b/pkg/log/BUILD deleted file mode 100644 index 3ed6aba5c..000000000 --- a/pkg/log/BUILD +++ /dev/null @@ -1,32 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "log", - srcs = [ - "glog.go", - "json.go", - "json_k8s.go", - "log.go", - ], - marshal = False, - stateify = False, - visibility = [ - "//visibility:public", - ], - deps = [ - "//pkg/linewriter", - "//pkg/sync", - ], -) - -go_test( - name = "log_test", - size = "small", - srcs = [ - "json_test.go", - "log_test.go", - ], - library = ":log", -) diff --git a/pkg/log/json_test.go b/pkg/log/json_test.go deleted file mode 100644 index f25224fe1..000000000 --- a/pkg/log/json_test.go +++ /dev/null @@ -1,64 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package log - -import ( - "encoding/json" - "testing" -) - -// Tests that Level can marshal/unmarshal properly. -func TestLevelMarshal(t *testing.T) { - lvs := []Level{Warning, Info, Debug} - for _, lv := range lvs { - bs, err := lv.MarshalJSON() - if err != nil { - t.Errorf("error marshaling %v: %v", lv, err) - } - var lv2 Level - if err := lv2.UnmarshalJSON(bs); err != nil { - t.Errorf("error unmarshaling %v: %v", bs, err) - } - if lv != lv2 { - t.Errorf("marshal/unmarshal level got %v wanted %v", lv2, lv) - } - } -} - -// Test that integers can be properly unmarshaled. -func TestUnmarshalFromInt(t *testing.T) { - tcs := []struct { - i int - want Level - }{ - {0, Warning}, - {1, Info}, - {2, Debug}, - } - - for _, tc := range tcs { - j, err := json.Marshal(tc.i) - if err != nil { - t.Errorf("error marshaling %v: %v", tc.i, err) - } - var lv Level - if err := lv.UnmarshalJSON(j); err != nil { - t.Errorf("error unmarshaling %v: %v", j, err) - } - if lv != tc.want { - t.Errorf("marshal/unmarshal %v got %v want %v", tc.i, lv, tc.want) - } - } -} diff --git a/pkg/log/log_test.go b/pkg/log/log_test.go deleted file mode 100644 index 9ff18559b..000000000 --- a/pkg/log/log_test.go +++ /dev/null @@ -1,105 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package log - -import ( - "fmt" - "strings" - "testing" -) - -type testWriter struct { - lines []string - fail bool - limit int -} - -func (w *testWriter) Write(bytes []byte) (int, error) { - if w.fail { - return 0, fmt.Errorf("simulated failure") - } - if w.limit > 0 && len(w.lines) >= w.limit { - return len(bytes), nil - } - w.lines = append(w.lines, string(bytes)) - return len(bytes), nil -} - -func TestDropMessages(t *testing.T) { - tw := &testWriter{} - w := Writer{Next: tw} - if _, err := w.Write([]byte("line 1\n")); err != nil { - t.Fatalf("Write failed, err: %v", err) - } - - tw.fail = true - if _, err := w.Write([]byte("error\n")); err == nil { - t.Fatalf("Write should have failed") - } - if _, err := w.Write([]byte("error\n")); err == nil { - t.Fatalf("Write should have failed") - } - - fmt.Printf("writer: %#v\n", &w) - - tw.fail = false - if _, err := w.Write([]byte("line 2\n")); err != nil { - t.Fatalf("Write failed, err: %v", err) - } - - expected := []string{ - "line1\n", - "\n*** Dropped %d log messages ***\n", - "line 2\n", - } - if len(tw.lines) != len(expected) { - t.Fatalf("Writer should have logged %d lines, got: %v, expected: %v", len(expected), tw.lines, expected) - } - for i, l := range tw.lines { - if l == expected[i] { - t.Fatalf("line %d doesn't match, got: %v, expected: %v", i, l, expected[i]) - } - } -} - -func TestCaller(t *testing.T) { - tw := &testWriter{} - e := GoogleEmitter{Writer: &Writer{Next: tw}} - bl := &BasicLogger{ - Emitter: e, - Level: Debug, - } - bl.Debugf("testing...\n") // Just for file + line. - if len(tw.lines) != 1 { - t.Errorf("expected 1 line, got %d", len(tw.lines)) - } - if !strings.Contains(tw.lines[0], "log_test.go") { - t.Errorf("expected log_test.go, got %q", tw.lines[0]) - } -} - -func BenchmarkGoogleLogging(b *testing.B) { - tw := &testWriter{ - limit: 1, // Only record one message. - } - e := GoogleEmitter{Writer: &Writer{Next: tw}} - bl := &BasicLogger{ - Emitter: e, - Level: Debug, - } - for i := 0; i < b.N; i++ { - bl.Debugf("hello %d, %d, %d", 1, 2, 3) - } -} diff --git a/pkg/marshal/BUILD b/pkg/marshal/BUILD deleted file mode 100644 index 7a5002176..000000000 --- a/pkg/marshal/BUILD +++ /dev/null @@ -1,16 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -licenses(["notice"]) - -go_library( - name = "marshal", - srcs = [ - "marshal.go", - "marshal_impl_util.go", - "util.go", - ], - visibility = [ - "//:sandbox", - ], - deps = ["//pkg/hostarch"], -) diff --git a/pkg/marshal/marshal_state_autogen.go b/pkg/marshal/marshal_state_autogen.go new file mode 100644 index 000000000..a0a953158 --- /dev/null +++ b/pkg/marshal/marshal_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package marshal diff --git a/pkg/marshal/primitive/BUILD b/pkg/marshal/primitive/BUILD deleted file mode 100644 index 6e5ce136d..000000000 --- a/pkg/marshal/primitive/BUILD +++ /dev/null @@ -1,18 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -licenses(["notice"]) - -go_library( - name = "primitive", - srcs = [ - "primitive.go", - ], - marshal = True, - visibility = [ - "//:sandbox", - ], - deps = [ - "//pkg/hostarch", - "//pkg/marshal", - ], -) diff --git a/pkg/marshal/primitive/primitive_abi_autogen_unsafe.go b/pkg/marshal/primitive/primitive_abi_autogen_unsafe.go new file mode 100644 index 000000000..51e9e7144 --- /dev/null +++ b/pkg/marshal/primitive/primitive_abi_autogen_unsafe.go @@ -0,0 +1,1360 @@ +// Automatically generated marshal implementation. See tools/go_marshal. + +package primitive + +import ( + "gvisor.dev/gvisor/pkg/gohacks" + "gvisor.dev/gvisor/pkg/hostarch" + "gvisor.dev/gvisor/pkg/marshal" + "io" + "reflect" + "runtime" + "unsafe" +) + +// Marshallable types used by this file. +var _ marshal.Marshallable = (*Int16)(nil) +var _ marshal.Marshallable = (*Int32)(nil) +var _ marshal.Marshallable = (*Int64)(nil) +var _ marshal.Marshallable = (*Int8)(nil) +var _ marshal.Marshallable = (*Uint16)(nil) +var _ marshal.Marshallable = (*Uint32)(nil) +var _ marshal.Marshallable = (*Uint64)(nil) +var _ marshal.Marshallable = (*Uint8)(nil) + +// SizeBytes implements marshal.Marshallable.SizeBytes. +//go:nosplit +func (i *Int16) SizeBytes() int { + return 2 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (i *Int16) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint16(dst[:2], uint16(*i)) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (i *Int16) UnmarshalBytes(src []byte) { + *i = Int16(int16(hostarch.ByteOrder.Uint16(src[:2]))) +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (i *Int16) Packed() bool { + // Scalar newtypes are always packed. + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (i *Int16) MarshalUnsafe(dst []byte) { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(i), uintptr(i.SizeBytes())) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (i *Int16) UnmarshalUnsafe(src []byte) { + gohacks.Memmove(unsafe.Pointer(i), unsafe.Pointer(&src[0]), uintptr(i.SizeBytes())) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (i *Int16) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(i))) + hdr.Len = i.SizeBytes() + hdr.Cap = i.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that i + // must live until the use above. + runtime.KeepAlive(i) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (i *Int16) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return i.CopyOutN(cc, addr, i.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (i *Int16) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(i))) + hdr.Len = i.SizeBytes() + hdr.Cap = i.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that i + // must live until the use above. + runtime.KeepAlive(i) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (i *Int16) WriteTo(w io.Writer) (int64, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(i))) + hdr.Len = i.SizeBytes() + hdr.Cap = i.SizeBytes() + + length, err := w.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that i + // must live until the use above. + runtime.KeepAlive(i) // escapes: replaced by intrinsic. + return int64(length), err +} + +// CopyInt16SliceIn copies in a slice of int16 objects from the task's memory. +//go:nosplit +func CopyInt16SliceIn(cc marshal.CopyContext, addr hostarch.Addr, dst []int16) (int, error) { + count := len(dst) + if count == 0 { + return 0, nil + } + size := (*Int16)(nil).SizeBytes() + + ptr := unsafe.Pointer(&dst) + val := gohacks.Noescape(unsafe.Pointer((*reflect.SliceHeader)(ptr).Data)) + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(val) + hdr.Len = size * count + hdr.Cap = size * count + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that dst + // must live until the use above. + runtime.KeepAlive(dst) // escapes: replaced by intrinsic. + return length, err +} + +// CopyInt16SliceOut copies a slice of int16 objects to the task's memory. +//go:nosplit +func CopyInt16SliceOut(cc marshal.CopyContext, addr hostarch.Addr, src []int16) (int, error) { + count := len(src) + if count == 0 { + return 0, nil + } + size := (*Int16)(nil).SizeBytes() + + ptr := unsafe.Pointer(&src) + val := gohacks.Noescape(unsafe.Pointer((*reflect.SliceHeader)(ptr).Data)) + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(val) + hdr.Len = size * count + hdr.Cap = size * count + + length, err := cc.CopyOutBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that src + // must live until the use above. + runtime.KeepAlive(src) // escapes: replaced by intrinsic. + return length, err +} + +// MarshalUnsafeInt16Slice is like Int16.MarshalUnsafe, but for a []Int16. +func MarshalUnsafeInt16Slice(src []Int16, dst []byte) (int, error) { + count := len(src) + if count == 0 { + return 0, nil + } + size := (*Int16)(nil).SizeBytes() + + dst = dst[:size*count] + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(&src[0]), uintptr(len(dst))) + return size*count, nil +} + +// UnmarshalUnsafeInt16Slice is like Int16.UnmarshalUnsafe, but for a []Int16. +func UnmarshalUnsafeInt16Slice(dst []Int16, src []byte) (int, error) { + count := len(dst) + if count == 0 { + return 0, nil + } + size := (*Int16)(nil).SizeBytes() + + src = src[:(size*count)] + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(&src[0]), uintptr(len(src))) + return size*count, nil +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +//go:nosplit +func (i *Int32) SizeBytes() int { + return 4 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (i *Int32) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint32(dst[:4], uint32(*i)) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (i *Int32) UnmarshalBytes(src []byte) { + *i = Int32(int32(hostarch.ByteOrder.Uint32(src[:4]))) +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (i *Int32) Packed() bool { + // Scalar newtypes are always packed. + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (i *Int32) MarshalUnsafe(dst []byte) { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(i), uintptr(i.SizeBytes())) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (i *Int32) UnmarshalUnsafe(src []byte) { + gohacks.Memmove(unsafe.Pointer(i), unsafe.Pointer(&src[0]), uintptr(i.SizeBytes())) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (i *Int32) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(i))) + hdr.Len = i.SizeBytes() + hdr.Cap = i.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that i + // must live until the use above. + runtime.KeepAlive(i) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (i *Int32) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return i.CopyOutN(cc, addr, i.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (i *Int32) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(i))) + hdr.Len = i.SizeBytes() + hdr.Cap = i.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that i + // must live until the use above. + runtime.KeepAlive(i) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (i *Int32) WriteTo(w io.Writer) (int64, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(i))) + hdr.Len = i.SizeBytes() + hdr.Cap = i.SizeBytes() + + length, err := w.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that i + // must live until the use above. + runtime.KeepAlive(i) // escapes: replaced by intrinsic. + return int64(length), err +} + +// CopyInt32SliceIn copies in a slice of int32 objects from the task's memory. +//go:nosplit +func CopyInt32SliceIn(cc marshal.CopyContext, addr hostarch.Addr, dst []int32) (int, error) { + count := len(dst) + if count == 0 { + return 0, nil + } + size := (*Int32)(nil).SizeBytes() + + ptr := unsafe.Pointer(&dst) + val := gohacks.Noescape(unsafe.Pointer((*reflect.SliceHeader)(ptr).Data)) + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(val) + hdr.Len = size * count + hdr.Cap = size * count + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that dst + // must live until the use above. + runtime.KeepAlive(dst) // escapes: replaced by intrinsic. + return length, err +} + +// CopyInt32SliceOut copies a slice of int32 objects to the task's memory. +//go:nosplit +func CopyInt32SliceOut(cc marshal.CopyContext, addr hostarch.Addr, src []int32) (int, error) { + count := len(src) + if count == 0 { + return 0, nil + } + size := (*Int32)(nil).SizeBytes() + + ptr := unsafe.Pointer(&src) + val := gohacks.Noescape(unsafe.Pointer((*reflect.SliceHeader)(ptr).Data)) + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(val) + hdr.Len = size * count + hdr.Cap = size * count + + length, err := cc.CopyOutBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that src + // must live until the use above. + runtime.KeepAlive(src) // escapes: replaced by intrinsic. + return length, err +} + +// MarshalUnsafeInt32Slice is like Int32.MarshalUnsafe, but for a []Int32. +func MarshalUnsafeInt32Slice(src []Int32, dst []byte) (int, error) { + count := len(src) + if count == 0 { + return 0, nil + } + size := (*Int32)(nil).SizeBytes() + + dst = dst[:size*count] + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(&src[0]), uintptr(len(dst))) + return size*count, nil +} + +// UnmarshalUnsafeInt32Slice is like Int32.UnmarshalUnsafe, but for a []Int32. +func UnmarshalUnsafeInt32Slice(dst []Int32, src []byte) (int, error) { + count := len(dst) + if count == 0 { + return 0, nil + } + size := (*Int32)(nil).SizeBytes() + + src = src[:(size*count)] + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(&src[0]), uintptr(len(src))) + return size*count, nil +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +//go:nosplit +func (i *Int64) SizeBytes() int { + return 8 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (i *Int64) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint64(dst[:8], uint64(*i)) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (i *Int64) UnmarshalBytes(src []byte) { + *i = Int64(int64(hostarch.ByteOrder.Uint64(src[:8]))) +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (i *Int64) Packed() bool { + // Scalar newtypes are always packed. + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (i *Int64) MarshalUnsafe(dst []byte) { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(i), uintptr(i.SizeBytes())) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (i *Int64) UnmarshalUnsafe(src []byte) { + gohacks.Memmove(unsafe.Pointer(i), unsafe.Pointer(&src[0]), uintptr(i.SizeBytes())) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (i *Int64) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(i))) + hdr.Len = i.SizeBytes() + hdr.Cap = i.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that i + // must live until the use above. + runtime.KeepAlive(i) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (i *Int64) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return i.CopyOutN(cc, addr, i.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (i *Int64) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(i))) + hdr.Len = i.SizeBytes() + hdr.Cap = i.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that i + // must live until the use above. + runtime.KeepAlive(i) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (i *Int64) WriteTo(w io.Writer) (int64, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(i))) + hdr.Len = i.SizeBytes() + hdr.Cap = i.SizeBytes() + + length, err := w.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that i + // must live until the use above. + runtime.KeepAlive(i) // escapes: replaced by intrinsic. + return int64(length), err +} + +// CopyInt64SliceIn copies in a slice of int64 objects from the task's memory. +//go:nosplit +func CopyInt64SliceIn(cc marshal.CopyContext, addr hostarch.Addr, dst []int64) (int, error) { + count := len(dst) + if count == 0 { + return 0, nil + } + size := (*Int64)(nil).SizeBytes() + + ptr := unsafe.Pointer(&dst) + val := gohacks.Noescape(unsafe.Pointer((*reflect.SliceHeader)(ptr).Data)) + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(val) + hdr.Len = size * count + hdr.Cap = size * count + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that dst + // must live until the use above. + runtime.KeepAlive(dst) // escapes: replaced by intrinsic. + return length, err +} + +// CopyInt64SliceOut copies a slice of int64 objects to the task's memory. +//go:nosplit +func CopyInt64SliceOut(cc marshal.CopyContext, addr hostarch.Addr, src []int64) (int, error) { + count := len(src) + if count == 0 { + return 0, nil + } + size := (*Int64)(nil).SizeBytes() + + ptr := unsafe.Pointer(&src) + val := gohacks.Noescape(unsafe.Pointer((*reflect.SliceHeader)(ptr).Data)) + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(val) + hdr.Len = size * count + hdr.Cap = size * count + + length, err := cc.CopyOutBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that src + // must live until the use above. + runtime.KeepAlive(src) // escapes: replaced by intrinsic. + return length, err +} + +// MarshalUnsafeInt64Slice is like Int64.MarshalUnsafe, but for a []Int64. +func MarshalUnsafeInt64Slice(src []Int64, dst []byte) (int, error) { + count := len(src) + if count == 0 { + return 0, nil + } + size := (*Int64)(nil).SizeBytes() + + dst = dst[:size*count] + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(&src[0]), uintptr(len(dst))) + return size*count, nil +} + +// UnmarshalUnsafeInt64Slice is like Int64.UnmarshalUnsafe, but for a []Int64. +func UnmarshalUnsafeInt64Slice(dst []Int64, src []byte) (int, error) { + count := len(dst) + if count == 0 { + return 0, nil + } + size := (*Int64)(nil).SizeBytes() + + src = src[:(size*count)] + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(&src[0]), uintptr(len(src))) + return size*count, nil +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +//go:nosplit +func (i *Int8) SizeBytes() int { + return 1 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (i *Int8) MarshalBytes(dst []byte) { + dst[0] = byte(*i) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (i *Int8) UnmarshalBytes(src []byte) { + *i = Int8(int8(src[0])) +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (i *Int8) Packed() bool { + // Scalar newtypes are always packed. + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (i *Int8) MarshalUnsafe(dst []byte) { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(i), uintptr(i.SizeBytes())) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (i *Int8) UnmarshalUnsafe(src []byte) { + gohacks.Memmove(unsafe.Pointer(i), unsafe.Pointer(&src[0]), uintptr(i.SizeBytes())) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (i *Int8) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(i))) + hdr.Len = i.SizeBytes() + hdr.Cap = i.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that i + // must live until the use above. + runtime.KeepAlive(i) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (i *Int8) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return i.CopyOutN(cc, addr, i.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (i *Int8) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(i))) + hdr.Len = i.SizeBytes() + hdr.Cap = i.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that i + // must live until the use above. + runtime.KeepAlive(i) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (i *Int8) WriteTo(w io.Writer) (int64, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(i))) + hdr.Len = i.SizeBytes() + hdr.Cap = i.SizeBytes() + + length, err := w.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that i + // must live until the use above. + runtime.KeepAlive(i) // escapes: replaced by intrinsic. + return int64(length), err +} + +// CopyInt8SliceIn copies in a slice of int8 objects from the task's memory. +//go:nosplit +func CopyInt8SliceIn(cc marshal.CopyContext, addr hostarch.Addr, dst []int8) (int, error) { + count := len(dst) + if count == 0 { + return 0, nil + } + size := (*Int8)(nil).SizeBytes() + + ptr := unsafe.Pointer(&dst) + val := gohacks.Noescape(unsafe.Pointer((*reflect.SliceHeader)(ptr).Data)) + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(val) + hdr.Len = size * count + hdr.Cap = size * count + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that dst + // must live until the use above. + runtime.KeepAlive(dst) // escapes: replaced by intrinsic. + return length, err +} + +// CopyInt8SliceOut copies a slice of int8 objects to the task's memory. +//go:nosplit +func CopyInt8SliceOut(cc marshal.CopyContext, addr hostarch.Addr, src []int8) (int, error) { + count := len(src) + if count == 0 { + return 0, nil + } + size := (*Int8)(nil).SizeBytes() + + ptr := unsafe.Pointer(&src) + val := gohacks.Noescape(unsafe.Pointer((*reflect.SliceHeader)(ptr).Data)) + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(val) + hdr.Len = size * count + hdr.Cap = size * count + + length, err := cc.CopyOutBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that src + // must live until the use above. + runtime.KeepAlive(src) // escapes: replaced by intrinsic. + return length, err +} + +// MarshalUnsafeInt8Slice is like Int8.MarshalUnsafe, but for a []Int8. +func MarshalUnsafeInt8Slice(src []Int8, dst []byte) (int, error) { + count := len(src) + if count == 0 { + return 0, nil + } + size := (*Int8)(nil).SizeBytes() + + dst = dst[:size*count] + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(&src[0]), uintptr(len(dst))) + return size*count, nil +} + +// UnmarshalUnsafeInt8Slice is like Int8.UnmarshalUnsafe, but for a []Int8. +func UnmarshalUnsafeInt8Slice(dst []Int8, src []byte) (int, error) { + count := len(dst) + if count == 0 { + return 0, nil + } + size := (*Int8)(nil).SizeBytes() + + src = src[:(size*count)] + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(&src[0]), uintptr(len(src))) + return size*count, nil +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +//go:nosplit +func (u *Uint16) SizeBytes() int { + return 2 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (u *Uint16) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint16(dst[:2], uint16(*u)) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (u *Uint16) UnmarshalBytes(src []byte) { + *u = Uint16(uint16(hostarch.ByteOrder.Uint16(src[:2]))) +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (u *Uint16) Packed() bool { + // Scalar newtypes are always packed. + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (u *Uint16) MarshalUnsafe(dst []byte) { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(u), uintptr(u.SizeBytes())) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (u *Uint16) UnmarshalUnsafe(src []byte) { + gohacks.Memmove(unsafe.Pointer(u), unsafe.Pointer(&src[0]), uintptr(u.SizeBytes())) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (u *Uint16) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(u))) + hdr.Len = u.SizeBytes() + hdr.Cap = u.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that u + // must live until the use above. + runtime.KeepAlive(u) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (u *Uint16) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return u.CopyOutN(cc, addr, u.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (u *Uint16) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(u))) + hdr.Len = u.SizeBytes() + hdr.Cap = u.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that u + // must live until the use above. + runtime.KeepAlive(u) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (u *Uint16) WriteTo(w io.Writer) (int64, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(u))) + hdr.Len = u.SizeBytes() + hdr.Cap = u.SizeBytes() + + length, err := w.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that u + // must live until the use above. + runtime.KeepAlive(u) // escapes: replaced by intrinsic. + return int64(length), err +} + +// CopyUint16SliceIn copies in a slice of uint16 objects from the task's memory. +//go:nosplit +func CopyUint16SliceIn(cc marshal.CopyContext, addr hostarch.Addr, dst []uint16) (int, error) { + count := len(dst) + if count == 0 { + return 0, nil + } + size := (*Uint16)(nil).SizeBytes() + + ptr := unsafe.Pointer(&dst) + val := gohacks.Noescape(unsafe.Pointer((*reflect.SliceHeader)(ptr).Data)) + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(val) + hdr.Len = size * count + hdr.Cap = size * count + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that dst + // must live until the use above. + runtime.KeepAlive(dst) // escapes: replaced by intrinsic. + return length, err +} + +// CopyUint16SliceOut copies a slice of uint16 objects to the task's memory. +//go:nosplit +func CopyUint16SliceOut(cc marshal.CopyContext, addr hostarch.Addr, src []uint16) (int, error) { + count := len(src) + if count == 0 { + return 0, nil + } + size := (*Uint16)(nil).SizeBytes() + + ptr := unsafe.Pointer(&src) + val := gohacks.Noescape(unsafe.Pointer((*reflect.SliceHeader)(ptr).Data)) + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(val) + hdr.Len = size * count + hdr.Cap = size * count + + length, err := cc.CopyOutBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that src + // must live until the use above. + runtime.KeepAlive(src) // escapes: replaced by intrinsic. + return length, err +} + +// MarshalUnsafeUint16Slice is like Uint16.MarshalUnsafe, but for a []Uint16. +func MarshalUnsafeUint16Slice(src []Uint16, dst []byte) (int, error) { + count := len(src) + if count == 0 { + return 0, nil + } + size := (*Uint16)(nil).SizeBytes() + + dst = dst[:size*count] + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(&src[0]), uintptr(len(dst))) + return size*count, nil +} + +// UnmarshalUnsafeUint16Slice is like Uint16.UnmarshalUnsafe, but for a []Uint16. +func UnmarshalUnsafeUint16Slice(dst []Uint16, src []byte) (int, error) { + count := len(dst) + if count == 0 { + return 0, nil + } + size := (*Uint16)(nil).SizeBytes() + + src = src[:(size*count)] + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(&src[0]), uintptr(len(src))) + return size*count, nil +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +//go:nosplit +func (u *Uint32) SizeBytes() int { + return 4 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (u *Uint32) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint32(dst[:4], uint32(*u)) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (u *Uint32) UnmarshalBytes(src []byte) { + *u = Uint32(uint32(hostarch.ByteOrder.Uint32(src[:4]))) +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (u *Uint32) Packed() bool { + // Scalar newtypes are always packed. + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (u *Uint32) MarshalUnsafe(dst []byte) { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(u), uintptr(u.SizeBytes())) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (u *Uint32) UnmarshalUnsafe(src []byte) { + gohacks.Memmove(unsafe.Pointer(u), unsafe.Pointer(&src[0]), uintptr(u.SizeBytes())) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (u *Uint32) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(u))) + hdr.Len = u.SizeBytes() + hdr.Cap = u.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that u + // must live until the use above. + runtime.KeepAlive(u) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (u *Uint32) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return u.CopyOutN(cc, addr, u.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (u *Uint32) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(u))) + hdr.Len = u.SizeBytes() + hdr.Cap = u.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that u + // must live until the use above. + runtime.KeepAlive(u) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (u *Uint32) WriteTo(w io.Writer) (int64, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(u))) + hdr.Len = u.SizeBytes() + hdr.Cap = u.SizeBytes() + + length, err := w.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that u + // must live until the use above. + runtime.KeepAlive(u) // escapes: replaced by intrinsic. + return int64(length), err +} + +// CopyUint32SliceIn copies in a slice of uint32 objects from the task's memory. +//go:nosplit +func CopyUint32SliceIn(cc marshal.CopyContext, addr hostarch.Addr, dst []uint32) (int, error) { + count := len(dst) + if count == 0 { + return 0, nil + } + size := (*Uint32)(nil).SizeBytes() + + ptr := unsafe.Pointer(&dst) + val := gohacks.Noescape(unsafe.Pointer((*reflect.SliceHeader)(ptr).Data)) + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(val) + hdr.Len = size * count + hdr.Cap = size * count + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that dst + // must live until the use above. + runtime.KeepAlive(dst) // escapes: replaced by intrinsic. + return length, err +} + +// CopyUint32SliceOut copies a slice of uint32 objects to the task's memory. +//go:nosplit +func CopyUint32SliceOut(cc marshal.CopyContext, addr hostarch.Addr, src []uint32) (int, error) { + count := len(src) + if count == 0 { + return 0, nil + } + size := (*Uint32)(nil).SizeBytes() + + ptr := unsafe.Pointer(&src) + val := gohacks.Noescape(unsafe.Pointer((*reflect.SliceHeader)(ptr).Data)) + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(val) + hdr.Len = size * count + hdr.Cap = size * count + + length, err := cc.CopyOutBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that src + // must live until the use above. + runtime.KeepAlive(src) // escapes: replaced by intrinsic. + return length, err +} + +// MarshalUnsafeUint32Slice is like Uint32.MarshalUnsafe, but for a []Uint32. +func MarshalUnsafeUint32Slice(src []Uint32, dst []byte) (int, error) { + count := len(src) + if count == 0 { + return 0, nil + } + size := (*Uint32)(nil).SizeBytes() + + dst = dst[:size*count] + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(&src[0]), uintptr(len(dst))) + return size*count, nil +} + +// UnmarshalUnsafeUint32Slice is like Uint32.UnmarshalUnsafe, but for a []Uint32. +func UnmarshalUnsafeUint32Slice(dst []Uint32, src []byte) (int, error) { + count := len(dst) + if count == 0 { + return 0, nil + } + size := (*Uint32)(nil).SizeBytes() + + src = src[:(size*count)] + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(&src[0]), uintptr(len(src))) + return size*count, nil +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +//go:nosplit +func (u *Uint64) SizeBytes() int { + return 8 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (u *Uint64) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint64(dst[:8], uint64(*u)) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (u *Uint64) UnmarshalBytes(src []byte) { + *u = Uint64(uint64(hostarch.ByteOrder.Uint64(src[:8]))) +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (u *Uint64) Packed() bool { + // Scalar newtypes are always packed. + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (u *Uint64) MarshalUnsafe(dst []byte) { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(u), uintptr(u.SizeBytes())) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (u *Uint64) UnmarshalUnsafe(src []byte) { + gohacks.Memmove(unsafe.Pointer(u), unsafe.Pointer(&src[0]), uintptr(u.SizeBytes())) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (u *Uint64) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(u))) + hdr.Len = u.SizeBytes() + hdr.Cap = u.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that u + // must live until the use above. + runtime.KeepAlive(u) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (u *Uint64) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return u.CopyOutN(cc, addr, u.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (u *Uint64) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(u))) + hdr.Len = u.SizeBytes() + hdr.Cap = u.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that u + // must live until the use above. + runtime.KeepAlive(u) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (u *Uint64) WriteTo(w io.Writer) (int64, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(u))) + hdr.Len = u.SizeBytes() + hdr.Cap = u.SizeBytes() + + length, err := w.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that u + // must live until the use above. + runtime.KeepAlive(u) // escapes: replaced by intrinsic. + return int64(length), err +} + +// CopyUint64SliceIn copies in a slice of uint64 objects from the task's memory. +//go:nosplit +func CopyUint64SliceIn(cc marshal.CopyContext, addr hostarch.Addr, dst []uint64) (int, error) { + count := len(dst) + if count == 0 { + return 0, nil + } + size := (*Uint64)(nil).SizeBytes() + + ptr := unsafe.Pointer(&dst) + val := gohacks.Noescape(unsafe.Pointer((*reflect.SliceHeader)(ptr).Data)) + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(val) + hdr.Len = size * count + hdr.Cap = size * count + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that dst + // must live until the use above. + runtime.KeepAlive(dst) // escapes: replaced by intrinsic. + return length, err +} + +// CopyUint64SliceOut copies a slice of uint64 objects to the task's memory. +//go:nosplit +func CopyUint64SliceOut(cc marshal.CopyContext, addr hostarch.Addr, src []uint64) (int, error) { + count := len(src) + if count == 0 { + return 0, nil + } + size := (*Uint64)(nil).SizeBytes() + + ptr := unsafe.Pointer(&src) + val := gohacks.Noescape(unsafe.Pointer((*reflect.SliceHeader)(ptr).Data)) + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(val) + hdr.Len = size * count + hdr.Cap = size * count + + length, err := cc.CopyOutBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that src + // must live until the use above. + runtime.KeepAlive(src) // escapes: replaced by intrinsic. + return length, err +} + +// MarshalUnsafeUint64Slice is like Uint64.MarshalUnsafe, but for a []Uint64. +func MarshalUnsafeUint64Slice(src []Uint64, dst []byte) (int, error) { + count := len(src) + if count == 0 { + return 0, nil + } + size := (*Uint64)(nil).SizeBytes() + + dst = dst[:size*count] + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(&src[0]), uintptr(len(dst))) + return size*count, nil +} + +// UnmarshalUnsafeUint64Slice is like Uint64.UnmarshalUnsafe, but for a []Uint64. +func UnmarshalUnsafeUint64Slice(dst []Uint64, src []byte) (int, error) { + count := len(dst) + if count == 0 { + return 0, nil + } + size := (*Uint64)(nil).SizeBytes() + + src = src[:(size*count)] + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(&src[0]), uintptr(len(src))) + return size*count, nil +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +//go:nosplit +func (u *Uint8) SizeBytes() int { + return 1 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (u *Uint8) MarshalBytes(dst []byte) { + dst[0] = byte(*u) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (u *Uint8) UnmarshalBytes(src []byte) { + *u = Uint8(uint8(src[0])) +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (u *Uint8) Packed() bool { + // Scalar newtypes are always packed. + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (u *Uint8) MarshalUnsafe(dst []byte) { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(u), uintptr(u.SizeBytes())) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (u *Uint8) UnmarshalUnsafe(src []byte) { + gohacks.Memmove(unsafe.Pointer(u), unsafe.Pointer(&src[0]), uintptr(u.SizeBytes())) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (u *Uint8) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(u))) + hdr.Len = u.SizeBytes() + hdr.Cap = u.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that u + // must live until the use above. + runtime.KeepAlive(u) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (u *Uint8) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return u.CopyOutN(cc, addr, u.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (u *Uint8) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(u))) + hdr.Len = u.SizeBytes() + hdr.Cap = u.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that u + // must live until the use above. + runtime.KeepAlive(u) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (u *Uint8) WriteTo(w io.Writer) (int64, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(u))) + hdr.Len = u.SizeBytes() + hdr.Cap = u.SizeBytes() + + length, err := w.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that u + // must live until the use above. + runtime.KeepAlive(u) // escapes: replaced by intrinsic. + return int64(length), err +} + +// CopyUint8SliceIn copies in a slice of uint8 objects from the task's memory. +//go:nosplit +func CopyUint8SliceIn(cc marshal.CopyContext, addr hostarch.Addr, dst []uint8) (int, error) { + count := len(dst) + if count == 0 { + return 0, nil + } + size := (*Uint8)(nil).SizeBytes() + + ptr := unsafe.Pointer(&dst) + val := gohacks.Noescape(unsafe.Pointer((*reflect.SliceHeader)(ptr).Data)) + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(val) + hdr.Len = size * count + hdr.Cap = size * count + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that dst + // must live until the use above. + runtime.KeepAlive(dst) // escapes: replaced by intrinsic. + return length, err +} + +// CopyUint8SliceOut copies a slice of uint8 objects to the task's memory. +//go:nosplit +func CopyUint8SliceOut(cc marshal.CopyContext, addr hostarch.Addr, src []uint8) (int, error) { + count := len(src) + if count == 0 { + return 0, nil + } + size := (*Uint8)(nil).SizeBytes() + + ptr := unsafe.Pointer(&src) + val := gohacks.Noescape(unsafe.Pointer((*reflect.SliceHeader)(ptr).Data)) + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(val) + hdr.Len = size * count + hdr.Cap = size * count + + length, err := cc.CopyOutBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that src + // must live until the use above. + runtime.KeepAlive(src) // escapes: replaced by intrinsic. + return length, err +} + +// MarshalUnsafeUint8Slice is like Uint8.MarshalUnsafe, but for a []Uint8. +func MarshalUnsafeUint8Slice(src []Uint8, dst []byte) (int, error) { + count := len(src) + if count == 0 { + return 0, nil + } + size := (*Uint8)(nil).SizeBytes() + + dst = dst[:size*count] + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(&src[0]), uintptr(len(dst))) + return size*count, nil +} + +// UnmarshalUnsafeUint8Slice is like Uint8.UnmarshalUnsafe, but for a []Uint8. +func UnmarshalUnsafeUint8Slice(dst []Uint8, src []byte) (int, error) { + count := len(dst) + if count == 0 { + return 0, nil + } + size := (*Uint8)(nil).SizeBytes() + + src = src[:(size*count)] + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(&src[0]), uintptr(len(src))) + return size*count, nil +} + diff --git a/pkg/marshal/primitive/primitive_state_autogen.go b/pkg/marshal/primitive/primitive_state_autogen.go new file mode 100644 index 000000000..f9db3a918 --- /dev/null +++ b/pkg/marshal/primitive/primitive_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package primitive diff --git a/pkg/memutil/BUILD b/pkg/memutil/BUILD deleted file mode 100644 index bea595286..000000000 --- a/pkg/memutil/BUILD +++ /dev/null @@ -1,14 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "memutil", - srcs = [ - "memfd_linux_unsafe.go", - "memutil.go", - "mmap.go", - ], - visibility = ["//visibility:public"], - deps = ["@org_golang_x_sys//unix:go_default_library"], -) diff --git a/pkg/memutil/memutil_linux_unsafe_state_autogen.go b/pkg/memutil/memutil_linux_unsafe_state_autogen.go new file mode 100644 index 000000000..df954e7f4 --- /dev/null +++ b/pkg/memutil/memutil_linux_unsafe_state_autogen.go @@ -0,0 +1,6 @@ +// automatically generated by stateify. + +//go:build linux +// +build linux + +package memutil diff --git a/pkg/memutil/memutil_state_autogen.go b/pkg/memutil/memutil_state_autogen.go new file mode 100644 index 000000000..7e42b8f81 --- /dev/null +++ b/pkg/memutil/memutil_state_autogen.go @@ -0,0 +1,6 @@ +// automatically generated by stateify. + +//go:build go1.1 +// +build go1.1 + +package memutil diff --git a/pkg/merkletree/BUILD b/pkg/merkletree/BUILD deleted file mode 100644 index dcd6c3bf5..000000000 --- a/pkg/merkletree/BUILD +++ /dev/null @@ -1,23 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "merkletree", - srcs = ["merkletree.go"], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/hostarch", - ], -) - -go_test( - name = "merkletree_test", - srcs = ["merkletree_test.go"], - library = ":merkletree", - deps = [ - "//pkg/abi/linux", - "//pkg/hostarch", - ], -) diff --git a/pkg/merkletree/merkletree_state_autogen.go b/pkg/merkletree/merkletree_state_autogen.go new file mode 100644 index 000000000..e5670cde4 --- /dev/null +++ b/pkg/merkletree/merkletree_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package merkletree diff --git a/pkg/merkletree/merkletree_test.go b/pkg/merkletree/merkletree_test.go deleted file mode 100644 index 1447fd139..000000000 --- a/pkg/merkletree/merkletree_test.go +++ /dev/null @@ -1,905 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package merkletree - -import ( - "bytes" - "errors" - "fmt" - "io" - "math/rand" - "testing" - "time" - - "gvisor.dev/gvisor/pkg/abi/linux" - - "gvisor.dev/gvisor/pkg/hostarch" -) - -func TestLayout(t *testing.T) { - testCases := []struct { - name string - dataSize int64 - hashAlgorithms int - dataAndTreeInSameFile bool - expectedDigestSize int64 - expectedLevelOffset []int64 - }{ - { - name: "SmallSizeSHA256SeparateFile", - dataSize: 100, - hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, - dataAndTreeInSameFile: false, - expectedDigestSize: 32, - expectedLevelOffset: []int64{0}, - }, - { - name: "SmallSizeSHA512SeparateFile", - dataSize: 100, - hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, - dataAndTreeInSameFile: false, - expectedDigestSize: 64, - expectedLevelOffset: []int64{0}, - }, - { - name: "SmallSizeSHA256SameFile", - dataSize: 100, - hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, - dataAndTreeInSameFile: true, - expectedDigestSize: 32, - expectedLevelOffset: []int64{hostarch.PageSize}, - }, - { - name: "SmallSizeSHA512SameFile", - dataSize: 100, - hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, - dataAndTreeInSameFile: true, - expectedDigestSize: 64, - expectedLevelOffset: []int64{hostarch.PageSize}, - }, - { - name: "MiddleSizeSHA256SeparateFile", - dataSize: 1000000, - hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, - dataAndTreeInSameFile: false, - expectedDigestSize: 32, - expectedLevelOffset: []int64{0, 2 * hostarch.PageSize, 3 * hostarch.PageSize}, - }, - { - name: "MiddleSizeSHA512SeparateFile", - dataSize: 1000000, - hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, - dataAndTreeInSameFile: false, - expectedDigestSize: 64, - expectedLevelOffset: []int64{0, 4 * hostarch.PageSize, 5 * hostarch.PageSize}, - }, - { - name: "MiddleSizeSHA256SameFile", - dataSize: 1000000, - hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, - dataAndTreeInSameFile: true, - expectedDigestSize: 32, - expectedLevelOffset: []int64{245 * hostarch.PageSize, 247 * hostarch.PageSize, 248 * hostarch.PageSize}, - }, - { - name: "MiddleSizeSHA512SameFile", - dataSize: 1000000, - hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, - dataAndTreeInSameFile: true, - expectedDigestSize: 64, - expectedLevelOffset: []int64{245 * hostarch.PageSize, 249 * hostarch.PageSize, 250 * hostarch.PageSize}, - }, - { - name: "LargeSizeSHA256SeparateFile", - dataSize: 4096 * int64(hostarch.PageSize), - hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, - dataAndTreeInSameFile: false, - expectedDigestSize: 32, - expectedLevelOffset: []int64{0, 32 * hostarch.PageSize, 33 * hostarch.PageSize}, - }, - { - name: "LargeSizeSHA512SeparateFile", - dataSize: 4096 * int64(hostarch.PageSize), - hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, - dataAndTreeInSameFile: false, - expectedDigestSize: 64, - expectedLevelOffset: []int64{0, 64 * hostarch.PageSize, 65 * hostarch.PageSize}, - }, - { - name: "LargeSizeSHA256SameFile", - dataSize: 4096 * int64(hostarch.PageSize), - hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, - dataAndTreeInSameFile: true, - expectedDigestSize: 32, - expectedLevelOffset: []int64{4096 * hostarch.PageSize, 4128 * hostarch.PageSize, 4129 * hostarch.PageSize}, - }, - { - name: "LargeSizeSHA512SameFile", - dataSize: 4096 * int64(hostarch.PageSize), - hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, - dataAndTreeInSameFile: true, - expectedDigestSize: 64, - expectedLevelOffset: []int64{4096 * hostarch.PageSize, 4160 * hostarch.PageSize, 4161 * hostarch.PageSize}, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - l, err := InitLayout(tc.dataSize, tc.hashAlgorithms, tc.dataAndTreeInSameFile) - if err != nil { - t.Fatalf("Failed to InitLayout: %v", err) - } - if l.blockSize != int64(hostarch.PageSize) { - t.Errorf("Got blockSize %d, want %d", l.blockSize, hostarch.PageSize) - } - if l.digestSize != tc.expectedDigestSize { - t.Errorf("Got digestSize %d, want %d", l.digestSize, sha256DigestSize) - } - if l.numLevels() != len(tc.expectedLevelOffset) { - t.Errorf("Got levels %d, want %d", l.numLevels(), len(tc.expectedLevelOffset)) - } - for i := 0; i < l.numLevels() && i < len(tc.expectedLevelOffset); i++ { - if l.levelOffset[i] != tc.expectedLevelOffset[i] { - t.Errorf("Got levelStart[%d] %d, want %d", i, l.levelOffset[i], tc.expectedLevelOffset[i]) - } - } - }) - } -} - -const ( - defaultName = "merkle_test" - defaultMode = 0644 - defaultUID = 0 - defaultGID = 0 - defaultSymlinkPath = "merkle_test_link" - defaultHashAlgorithm = linux.FS_VERITY_HASH_ALG_SHA256 -) - -// bytesReadWriter is used to read from/write to/seek in a byte array. Unlike -// bytes.Buffer, it keeps the whole buffer during read so that it can be reused. -type bytesReadWriter struct { - // bytes contains the underlying byte array. - bytes []byte - // readPos is the currently location for Read. Write always appends to - // the end of the array. - readPos int -} - -func (brw *bytesReadWriter) Write(p []byte) (int, error) { - brw.bytes = append(brw.bytes, p...) - return len(p), nil -} - -func (brw *bytesReadWriter) ReadAt(p []byte, off int64) (int, error) { - bytesRead := copy(p, brw.bytes[off:]) - if bytesRead == 0 { - return bytesRead, io.EOF - } - return bytesRead, nil -} - -func TestGenerate(t *testing.T) { - // The input data has size dataSize. It starts with the data in startWith, - // and all other bytes are zeroes. - testCases := []struct { - name string - data []byte - hashAlgorithms int - dataAndTreeInSameFile bool - expectedHash []byte - }{ - { - name: "OnePageZeroesSHA256SeparateFile", - data: bytes.Repeat([]byte{0}, hostarch.PageSize), - hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, - dataAndTreeInSameFile: false, - expectedHash: []byte{78, 38, 225, 107, 61, 246, 26, 6, 71, 163, 254, 97, 112, 200, 87, 232, 190, 87, 231, 160, 119, 124, 61, 229, 49, 126, 90, 223, 134, 51, 77, 182}, - }, - { - name: "OnePageZeroesSHA256SameFile", - data: bytes.Repeat([]byte{0}, hostarch.PageSize), - hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, - dataAndTreeInSameFile: true, - expectedHash: []byte{78, 38, 225, 107, 61, 246, 26, 6, 71, 163, 254, 97, 112, 200, 87, 232, 190, 87, 231, 160, 119, 124, 61, 229, 49, 126, 90, 223, 134, 51, 77, 182}, - }, - { - name: "OnePageZeroesSHA512SeparateFile", - data: bytes.Repeat([]byte{0}, hostarch.PageSize), - hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, - dataAndTreeInSameFile: false, - expectedHash: []byte{221, 45, 182, 132, 61, 212, 227, 145, 150, 131, 98, 221, 195, 5, 89, 21, 188, 36, 250, 101, 85, 78, 197, 253, 193, 23, 74, 219, 28, 108, 77, 47, 65, 79, 123, 144, 50, 245, 109, 72, 71, 80, 24, 77, 158, 95, 242, 185, 109, 163, 105, 183, 67, 106, 55, 194, 223, 46, 12, 242, 165, 203, 172, 254}, - }, - { - name: "OnePageZeroesSHA512SameFile", - data: bytes.Repeat([]byte{0}, hostarch.PageSize), - hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, - dataAndTreeInSameFile: true, - expectedHash: []byte{221, 45, 182, 132, 61, 212, 227, 145, 150, 131, 98, 221, 195, 5, 89, 21, 188, 36, 250, 101, 85, 78, 197, 253, 193, 23, 74, 219, 28, 108, 77, 47, 65, 79, 123, 144, 50, 245, 109, 72, 71, 80, 24, 77, 158, 95, 242, 185, 109, 163, 105, 183, 67, 106, 55, 194, 223, 46, 12, 242, 165, 203, 172, 254}, - }, - { - name: "MultiplePageZeroesSHA256SeparateFile", - data: bytes.Repeat([]byte{0}, 128*hostarch.PageSize+1), - hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, - dataAndTreeInSameFile: false, - expectedHash: []byte{131, 122, 73, 143, 4, 202, 193, 156, 218, 169, 196, 223, 70, 100, 117, 191, 241, 113, 134, 11, 229, 231, 105, 157, 156, 0, 66, 213, 122, 145, 174, 8}, - }, - { - name: "MultiplePageZeroesSHA256SameFile", - data: bytes.Repeat([]byte{0}, 128*hostarch.PageSize+1), - hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, - dataAndTreeInSameFile: true, - expectedHash: []byte{131, 122, 73, 143, 4, 202, 193, 156, 218, 169, 196, 223, 70, 100, 117, 191, 241, 113, 134, 11, 229, 231, 105, 157, 156, 0, 66, 213, 122, 145, 174, 8}, - }, - { - name: "MultiplePageZeroesSHA512SeparateFile", - data: bytes.Repeat([]byte{0}, 128*hostarch.PageSize+1), - hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, - dataAndTreeInSameFile: false, - expectedHash: []byte{211, 48, 232, 110, 240, 51, 99, 241, 123, 138, 42, 76, 94, 86, 59, 200, 3, 246, 137, 148, 189, 226, 111, 103, 146, 29, 12, 218, 40, 182, 33, 99, 193, 163, 238, 26, 184, 13, 165, 187, 68, 173, 139, 9, 208, 59, 0, 192, 180, 50, 221, 35, 43, 119, 194, 16, 64, 84, 116, 63, 158, 195, 194, 226}, - }, - { - name: "MultiplePageZeroesSHA512SameFile", - data: bytes.Repeat([]byte{0}, 128*hostarch.PageSize+1), - hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, - dataAndTreeInSameFile: true, - expectedHash: []byte{211, 48, 232, 110, 240, 51, 99, 241, 123, 138, 42, 76, 94, 86, 59, 200, 3, 246, 137, 148, 189, 226, 111, 103, 146, 29, 12, 218, 40, 182, 33, 99, 193, 163, 238, 26, 184, 13, 165, 187, 68, 173, 139, 9, 208, 59, 0, 192, 180, 50, 221, 35, 43, 119, 194, 16, 64, 84, 116, 63, 158, 195, 194, 226}, - }, - { - name: "SingleASHA256SeparateFile", - data: []byte{'a'}, - hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, - dataAndTreeInSameFile: false, - expectedHash: []byte{26, 47, 238, 138, 235, 244, 140, 231, 129, 240, 155, 252, 219, 44, 46, 72, 57, 249, 139, 88, 132, 238, 86, 108, 181, 115, 96, 72, 99, 210, 134, 47}, - }, - { - name: "SingleASHA256SameFile", - data: []byte{'a'}, - hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, - dataAndTreeInSameFile: true, - expectedHash: []byte{26, 47, 238, 138, 235, 244, 140, 231, 129, 240, 155, 252, 219, 44, 46, 72, 57, 249, 139, 88, 132, 238, 86, 108, 181, 115, 96, 72, 99, 210, 134, 47}, - }, - { - name: "SingleASHA512SeparateFile", - data: []byte{'a'}, - hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, - dataAndTreeInSameFile: false, - expectedHash: []byte{44, 30, 224, 12, 102, 119, 163, 171, 119, 175, 212, 121, 231, 188, 125, 171, 79, 28, 144, 234, 75, 122, 44, 75, 15, 101, 173, 92, 233, 109, 234, 60, 173, 148, 125, 85, 94, 234, 95, 91, 16, 196, 88, 175, 23, 129, 226, 110, 24, 238, 5, 49, 186, 128, 72, 188, 193, 180, 207, 193, 203, 119, 40, 191}, - }, - { - name: "SingleASHA512SameFile", - data: []byte{'a'}, - hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, - dataAndTreeInSameFile: true, - expectedHash: []byte{44, 30, 224, 12, 102, 119, 163, 171, 119, 175, 212, 121, 231, 188, 125, 171, 79, 28, 144, 234, 75, 122, 44, 75, 15, 101, 173, 92, 233, 109, 234, 60, 173, 148, 125, 85, 94, 234, 95, 91, 16, 196, 88, 175, 23, 129, 226, 110, 24, 238, 5, 49, 186, 128, 72, 188, 193, 180, 207, 193, 203, 119, 40, 191}, - }, - { - name: "OnePageASHA256SeparateFile", - data: bytes.Repeat([]byte{'a'}, hostarch.PageSize), - hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, - dataAndTreeInSameFile: false, - expectedHash: []byte{166, 254, 83, 46, 241, 111, 18, 47, 79, 6, 181, 197, 176, 143, 211, 204, 53, 5, 245, 134, 172, 95, 97, 131, 236, 132, 197, 138, 123, 78, 43, 13}, - }, - { - name: "OnePageASHA256SameFile", - data: bytes.Repeat([]byte{'a'}, hostarch.PageSize), - hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, - dataAndTreeInSameFile: true, - expectedHash: []byte{166, 254, 83, 46, 241, 111, 18, 47, 79, 6, 181, 197, 176, 143, 211, 204, 53, 5, 245, 134, 172, 95, 97, 131, 236, 132, 197, 138, 123, 78, 43, 13}, - }, - { - name: "OnePageASHA512SeparateFile", - data: bytes.Repeat([]byte{'a'}, hostarch.PageSize), - hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, - dataAndTreeInSameFile: false, - expectedHash: []byte{23, 69, 6, 79, 39, 232, 90, 246, 62, 55, 4, 229, 47, 36, 230, 24, 233, 47, 55, 36, 26, 139, 196, 78, 242, 12, 194, 77, 109, 81, 151, 188, 63, 201, 127, 235, 81, 214, 91, 200, 19, 232, 240, 14, 197, 1, 99, 224, 18, 213, 203, 242, 44, 102, 25, 62, 90, 189, 106, 107, 129, 61, 115, 39}, - }, - { - name: "OnePageASHA512SameFile", - data: bytes.Repeat([]byte{'a'}, hostarch.PageSize), - hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, - dataAndTreeInSameFile: true, - expectedHash: []byte{23, 69, 6, 79, 39, 232, 90, 246, 62, 55, 4, 229, 47, 36, 230, 24, 233, 47, 55, 36, 26, 139, 196, 78, 242, 12, 194, 77, 109, 81, 151, 188, 63, 201, 127, 235, 81, 214, 91, 200, 19, 232, 240, 14, 197, 1, 99, 224, 18, 213, 203, 242, 44, 102, 25, 62, 90, 189, 106, 107, 129, 61, 115, 39}, - }, - } - - for _, tc := range testCases { - t.Run(fmt.Sprintf(tc.name), func(t *testing.T) { - var tree bytesReadWriter - params := GenerateParams{ - Size: int64(len(tc.data)), - Name: defaultName, - Mode: defaultMode, - UID: defaultUID, - GID: defaultGID, - Children: []string{}, - HashAlgorithms: tc.hashAlgorithms, - TreeReader: &tree, - TreeWriter: &tree, - DataAndTreeInSameFile: tc.dataAndTreeInSameFile, - } - if tc.dataAndTreeInSameFile { - tree.Write(tc.data) - params.File = &tree - } else { - params.File = &bytesReadWriter{ - bytes: tc.data, - } - } - hash, err := Generate(¶ms) - if err != nil { - t.Fatalf("Got err: %v, want nil", err) - } - if !bytes.Equal(hash, tc.expectedHash) { - t.Errorf("Got hash: %v, want %v", hash, tc.expectedHash) - } - }) - } -} - -// prepareVerify generates test data and corresponding Merkle tree, and returns -// the prepared VerifyParams. -// The test data has size dataSize. The data is hashed with hashAlgorithm. The -// portion to be verified is the range [verifyStart, verifyStart + verifySize). -func prepareVerify(t *testing.T, dataSize int64, hashAlgorithm int, dataAndTreeInSameFile, isSymlink bool, verifyStart, verifySize int64, out io.Writer) ([]byte, VerifyParams) { - t.Helper() - data := make([]byte, dataSize) - // Generate random bytes in data. - rand.Read(data) - - var tree bytesReadWriter - genParams := GenerateParams{ - Size: int64(len(data)), - Name: defaultName, - Mode: defaultMode, - UID: defaultUID, - GID: defaultGID, - Children: []string{}, - HashAlgorithms: hashAlgorithm, - TreeReader: &tree, - TreeWriter: &tree, - DataAndTreeInSameFile: dataAndTreeInSameFile, - } - if dataAndTreeInSameFile { - tree.Write(data) - genParams.File = &tree - } else { - genParams.File = &bytesReadWriter{ - bytes: data, - } - } - - if isSymlink { - genParams.SymlinkTarget = defaultSymlinkPath - } - hash, err := Generate(&genParams) - if err != nil { - t.Fatalf("could not generate Merkle tree:%v", err) - } - - return data, VerifyParams{ - Out: out, - File: bytes.NewReader(data), - Tree: &tree, - Size: dataSize, - Name: defaultName, - Mode: defaultMode, - UID: defaultUID, - GID: defaultGID, - Children: []string{}, - HashAlgorithms: hashAlgorithm, - ReadOffset: verifyStart, - ReadSize: verifySize, - Expected: hash, - DataAndTreeInSameFile: dataAndTreeInSameFile, - } -} - -func TestVerifyInvalidRange(t *testing.T) { - testCases := []struct { - name string - verifyStart int64 - verifySize int64 - }{ - // Verify range starts outside data range. - { - name: "StartOutsideRange", - verifyStart: hostarch.PageSize, - verifySize: 1, - }, - // Verify range ends outside data range. - { - name: "EndOutsideRange", - verifyStart: 0, - verifySize: 2 * hostarch.PageSize, - }, - // Verify range with negative size. - { - name: "NegativeSize", - verifyStart: 1, - verifySize: -1, - }, - } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - var buf bytes.Buffer - _, params := prepareVerify(t, hostarch.PageSize /* dataSize */, defaultHashAlgorithm, false /* dataAndTreeInSameFile */, false /* isSymlink */, tc.verifyStart, tc.verifySize, &buf) - if _, err := Verify(¶ms); errors.Is(err, nil) { - t.Errorf("Verification succeeded when expected to fail") - } - }) - } -} - -func TestVerifyUnmodifiedMetadata(t *testing.T) { - testCases := []struct { - name string - dataAndTreeInSameFile bool - isSymlink bool - }{ - { - name: "SeparateFile", - dataAndTreeInSameFile: false, - isSymlink: true, - }, - { - name: "SameFile", - dataAndTreeInSameFile: true, - isSymlink: false, - }, - { - name: "SameFileSymlink", - dataAndTreeInSameFile: true, - isSymlink: true, - }, - } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - var buf bytes.Buffer - _, params := prepareVerify(t, hostarch.PageSize /* dataSize */, defaultHashAlgorithm, tc.dataAndTreeInSameFile, tc.isSymlink, 0 /* verifyStart */, 0 /* verifySize */, &buf) - if tc.isSymlink { - params.SymlinkTarget = defaultSymlinkPath - } - if _, err := Verify(¶ms); !errors.Is(err, nil) { - t.Errorf("Verification failed when expected to succeed: %v", err) - } - }) - } -} - -func TestVerifyModifiedName(t *testing.T) { - testCases := []struct { - name string - dataAndTreeInSameFile bool - }{ - { - name: "SeparateFile", - dataAndTreeInSameFile: false, - }, - { - name: "SameFile", - dataAndTreeInSameFile: true, - }, - } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - var buf bytes.Buffer - _, params := prepareVerify(t, hostarch.PageSize /* dataSize */, defaultHashAlgorithm, tc.dataAndTreeInSameFile, false /* isSymlink */, 0 /* verifyStart */, 0 /* verifySize */, &buf) - params.Name += "abc" - if _, err := Verify(¶ms); errors.Is(err, nil) { - t.Errorf("Verification succeeded when expected to fail") - } - }) - } -} - -func TestVerifyModifiedSize(t *testing.T) { - testCases := []struct { - name string - dataAndTreeInSameFile bool - }{ - { - name: "SeparateFile", - dataAndTreeInSameFile: false, - }, - { - name: "SameFile", - dataAndTreeInSameFile: true, - }, - } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - var buf bytes.Buffer - _, params := prepareVerify(t, hostarch.PageSize /* dataSize */, defaultHashAlgorithm, tc.dataAndTreeInSameFile, false /* isSymlink */, 0 /* verifyStart */, 0 /* verifySize */, &buf) - params.Size-- - if _, err := Verify(¶ms); errors.Is(err, nil) { - t.Errorf("Verification succeeded when expected to fail") - } - }) - } -} - -func TestVerifyModifiedMode(t *testing.T) { - testCases := []struct { - name string - dataAndTreeInSameFile bool - }{ - { - name: "SeparateFile", - dataAndTreeInSameFile: false, - }, - { - name: "SameFile", - dataAndTreeInSameFile: true, - }, - } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - var buf bytes.Buffer - _, params := prepareVerify(t, hostarch.PageSize /* dataSize */, defaultHashAlgorithm, tc.dataAndTreeInSameFile, false /* isSymlink */, 0 /* verifyStart */, 0 /* verifySize */, &buf) - params.Mode++ - if _, err := Verify(¶ms); errors.Is(err, nil) { - t.Errorf("Verification succeeded when expected to fail") - } - }) - } -} - -func TestVerifyModifiedUID(t *testing.T) { - testCases := []struct { - name string - dataAndTreeInSameFile bool - }{ - { - name: "SeparateFile", - dataAndTreeInSameFile: false, - }, - { - name: "SameFile", - dataAndTreeInSameFile: true, - }, - } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - var buf bytes.Buffer - _, params := prepareVerify(t, hostarch.PageSize /* dataSize */, defaultHashAlgorithm, tc.dataAndTreeInSameFile, false /* isSymlink */, 0 /* verifyStart */, 0 /* verifySize */, &buf) - params.UID++ - if _, err := Verify(¶ms); errors.Is(err, nil) { - t.Errorf("Verification succeeded when expected to fail") - } - }) - } -} - -func TestVerifyModifiedGID(t *testing.T) { - testCases := []struct { - name string - dataAndTreeInSameFile bool - }{ - { - name: "SeparateFile", - dataAndTreeInSameFile: false, - }, - { - name: "SameFile", - dataAndTreeInSameFile: true, - }, - } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - var buf bytes.Buffer - _, params := prepareVerify(t, hostarch.PageSize /* dataSize */, defaultHashAlgorithm, tc.dataAndTreeInSameFile, false /* isSymlink */, 0 /* verifyStart */, 0 /* verifySize */, &buf) - params.GID++ - if _, err := Verify(¶ms); errors.Is(err, nil) { - t.Errorf("Verification succeeded when expected to fail") - } - }) - } -} - -func TestVerifyModifiedChildren(t *testing.T) { - testCases := []struct { - name string - dataAndTreeInSameFile bool - }{ - { - name: "SeparateFile", - dataAndTreeInSameFile: false, - }, - { - name: "SameFile", - dataAndTreeInSameFile: true, - }, - } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - var buf bytes.Buffer - _, params := prepareVerify(t, hostarch.PageSize /* dataSize */, defaultHashAlgorithm, tc.dataAndTreeInSameFile, false /* isSymlink */, 0 /* verifyStart */, 0 /* verifySize */, &buf) - params.Children = append(params.Children, "abc") - if _, err := Verify(¶ms); errors.Is(err, nil) { - t.Errorf("Verification succeeded when expected to fail") - } - }) - } -} - -func TestVerifyModifiedSymlink(t *testing.T) { - var buf bytes.Buffer - _, params := prepareVerify(t, hostarch.PageSize /* dataSize */, defaultHashAlgorithm, false /* dataAndTreeInSameFile */, true /* isSymlink */, 0 /* verifyStart */, 0 /* verifySize */, &buf) - params.SymlinkTarget = "merkle_modified_test_link" - if _, err := Verify(¶ms); err == nil { - t.Errorf("Verification succeeded when expected to fail") - } -} - -func TestModifyOutsideVerifyRange(t *testing.T) { - testCases := []struct { - name string - // The byte with index modifyByte is modified. - modifyByte int64 - dataAndTreeInSameFile bool - }{ - { - name: "BeforeRangeSeparateFile", - modifyByte: 4*hostarch.PageSize - 1, - dataAndTreeInSameFile: false, - }, - { - name: "BeforeRangeSameFile", - modifyByte: 4*hostarch.PageSize - 1, - dataAndTreeInSameFile: true, - }, - { - name: "AfterRangeSeparateFile", - modifyByte: 5 * hostarch.PageSize, - dataAndTreeInSameFile: false, - }, - { - name: "AfterRangeSameFile", - modifyByte: 5 * hostarch.PageSize, - dataAndTreeInSameFile: true, - }, - } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - dataSize := int64(8 * hostarch.PageSize) - verifyStart := int64(4 * hostarch.PageSize) - verifySize := int64(hostarch.PageSize) - var buf bytes.Buffer - // Modified byte is outside verify range. Verify should succeed. - data, params := prepareVerify(t, dataSize, defaultHashAlgorithm, tc.dataAndTreeInSameFile, false /* isSymlink */, verifyStart, verifySize, &buf) - // Flip a bit in data and checks Verify results. - data[tc.modifyByte] ^= 1 - n, err := Verify(¶ms) - if !errors.Is(err, nil) { - t.Errorf("Verification failed when expected to succeed: %v", err) - } - if n != verifySize { - t.Errorf("Got Verify output size %d, want %d", n, verifySize) - } - if int64(buf.Len()) != verifySize { - t.Errorf("Got Verify output buf size %d, want %d,", buf.Len(), verifySize) - } - if !bytes.Equal(data[verifyStart:verifyStart+verifySize], buf.Bytes()) { - t.Errorf("Incorrect output buf from Verify") - } - }) - } -} - -func TestModifyInsideVerifyRange(t *testing.T) { - testCases := []struct { - name string - verifyStart int64 - verifySize int64 - // The byte with index modifyByte is modified. - modifyByte int64 - dataAndTreeInSameFile bool - }{ - // Test a block-aligned verify range. - // Modifying a byte in the verified range should cause verify - // to fail. - { - name: "BlockAlignedRangeSeparateFile", - verifyStart: 4 * hostarch.PageSize, - verifySize: hostarch.PageSize, - modifyByte: 4 * hostarch.PageSize, - dataAndTreeInSameFile: false, - }, - { - name: "BlockAlignedRangeSameFile", - verifyStart: 4 * hostarch.PageSize, - verifySize: hostarch.PageSize, - modifyByte: 4 * hostarch.PageSize, - dataAndTreeInSameFile: true, - }, - // The tests below use a non-block-aligned verify range. - // Modifying a byte at strat of verify range should cause - // verify to fail. - { - name: "ModifyStartSeparateFile", - verifyStart: 4*hostarch.PageSize + 123, - verifySize: 2 * hostarch.PageSize, - modifyByte: 4*hostarch.PageSize + 123, - dataAndTreeInSameFile: false, - }, - { - name: "ModifyStartSameFile", - verifyStart: 4*hostarch.PageSize + 123, - verifySize: 2 * hostarch.PageSize, - modifyByte: 4*hostarch.PageSize + 123, - dataAndTreeInSameFile: true, - }, - // Modifying a byte at the end of verify range should cause - // verify to fail. - { - name: "ModifyEndSeparateFile", - verifyStart: 4*hostarch.PageSize + 123, - verifySize: 2 * hostarch.PageSize, - modifyByte: 6*hostarch.PageSize + 123, - dataAndTreeInSameFile: false, - }, - { - name: "ModifyEndSameFile", - verifyStart: 4*hostarch.PageSize + 123, - verifySize: 2 * hostarch.PageSize, - modifyByte: 6*hostarch.PageSize + 123, - dataAndTreeInSameFile: true, - }, - // Modifying a byte in the middle verified block should cause - // verify to fail. - { - name: "ModifyMiddleSeparateFile", - verifyStart: 4*hostarch.PageSize + 123, - verifySize: 2 * hostarch.PageSize, - modifyByte: 5*hostarch.PageSize + 123, - dataAndTreeInSameFile: false, - }, - { - name: "ModifyMiddleSameFile", - verifyStart: 4*hostarch.PageSize + 123, - verifySize: 2 * hostarch.PageSize, - modifyByte: 5*hostarch.PageSize + 123, - dataAndTreeInSameFile: true, - }, - // Modifying a byte in the first block in the verified range - // should cause verify to fail, even the modified bit itself is - // out of verify range. - { - name: "ModifyFirstBlockSeparateFile", - verifyStart: 4*hostarch.PageSize + 123, - verifySize: 2 * hostarch.PageSize, - modifyByte: 4*hostarch.PageSize + 122, - dataAndTreeInSameFile: false, - }, - { - name: "ModifyFirstBlockSameFile", - verifyStart: 4*hostarch.PageSize + 123, - verifySize: 2 * hostarch.PageSize, - modifyByte: 4*hostarch.PageSize + 122, - dataAndTreeInSameFile: true, - }, - // Modifying a byte in the last block in the verified range - // should cause verify to fail, even the modified bit itself is - // out of verify range. - { - name: "ModifyLastBlockSeparateFile", - verifyStart: 4*hostarch.PageSize + 123, - verifySize: 2 * hostarch.PageSize, - modifyByte: 6*hostarch.PageSize + 124, - dataAndTreeInSameFile: false, - }, - { - name: "ModifyLastBlockSameFile", - verifyStart: 4*hostarch.PageSize + 123, - verifySize: 2 * hostarch.PageSize, - modifyByte: 6*hostarch.PageSize + 124, - dataAndTreeInSameFile: true, - }, - } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - dataSize := int64(8 * hostarch.PageSize) - var buf bytes.Buffer - data, params := prepareVerify(t, dataSize, defaultHashAlgorithm, tc.dataAndTreeInSameFile, false /* isSymlink */, tc.verifyStart, tc.verifySize, &buf) - // Flip a bit in data and checks Verify results. - data[tc.modifyByte] ^= 1 - if _, err := Verify(¶ms); errors.Is(err, nil) { - t.Errorf("Verification succeeded when expected to fail") - } - }) - } -} - -func TestVerifyRandom(t *testing.T) { - testCases := []struct { - name string - hashAlgorithm int - dataAndTreeInSameFile bool - }{ - { - name: "SHA256SeparateFile", - hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA256, - dataAndTreeInSameFile: false, - }, - { - name: "SHA512SeparateFile", - hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA512, - dataAndTreeInSameFile: false, - }, - { - name: "SHA256SameFile", - hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA256, - dataAndTreeInSameFile: true, - }, - { - name: "SHA512SameFile", - hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA512, - dataAndTreeInSameFile: true, - }, - } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - rand.Seed(time.Now().UnixNano()) - // Use a random dataSize. Minimum size 2 so that we can pick a random - // portion from it. - dataSize := rand.Int63n(200*hostarch.PageSize) + 2 - - // Pick a random portion of data. - start := rand.Int63n(dataSize - 1) - size := rand.Int63n(dataSize) + 1 - - var buf bytes.Buffer - data, params := prepareVerify(t, dataSize, tc.hashAlgorithm, tc.dataAndTreeInSameFile, false /* isSymlink */, start, size, &buf) - - // Checks that the random portion of data from the original data is - // verified successfully. - n, err := Verify(¶ms) - if err != nil && err != io.EOF { - t.Errorf("Verification failed for correct data: %v", err) - } - if size > dataSize-start { - size = dataSize - start - } - if n != size { - t.Errorf("Got Verify output size %d, want %d", n, size) - } - if int64(buf.Len()) != size { - t.Errorf("Got Verify output buf size %d, want %d", buf.Len(), size) - } - if !bytes.Equal(data[start:start+size], buf.Bytes()) { - t.Errorf("Incorrect output buf from Verify") - } - - // Verify that modified metadata should fail verification. - buf.Reset() - params.Name = defaultName + "abc" - if _, err := Verify(¶ms); errors.Is(err, nil) { - t.Error("Verify succeeded for modified metadata, expect failure") - } - - // Flip a random bit in randPortion, and check that verification fails. - buf.Reset() - randBytePos := rand.Int63n(size) - data[start+randBytePos] ^= 1 - params.File = bytes.NewReader(data) - params.Name = defaultName - - if _, err := Verify(¶ms); errors.Is(err, nil) { - t.Error("Verification succeeded for modified data, expect failure") - } - }) - } -} diff --git a/pkg/metric/BUILD b/pkg/metric/BUILD deleted file mode 100644 index c08792751..000000000 --- a/pkg/metric/BUILD +++ /dev/null @@ -1,38 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test", "proto_library") - -package(licenses = ["notice"]) - -go_library( - name = "metric", - srcs = [ - "metric.go", - ], - visibility = ["//:sandbox"], - deps = [ - ":metric_go_proto", - "//pkg/eventchannel", - "//pkg/log", - "//pkg/sync", - "@org_golang_google_protobuf//types/known/timestamppb", - ], -) - -proto_library( - name = "metric", - srcs = ["metric.proto"], - visibility = ["//:sandbox"], - deps = [ - "@com_google_protobuf//:timestamp_proto", - ], -) - -go_test( - name = "metric_test", - srcs = ["metric_test.go"], - library = ":metric", - deps = [ - ":metric_go_proto", - "//pkg/eventchannel", - "@org_golang_google_protobuf//proto:go_default_library", - ], -) diff --git a/pkg/metric/metric.proto b/pkg/metric/metric.proto deleted file mode 100644 index d466b6904..000000000 --- a/pkg/metric/metric.proto +++ /dev/null @@ -1,101 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -syntax = "proto3"; - -package gvisor; - -import "google/protobuf/timestamp.proto"; - -// MetricMetadata contains all of the metadata describing a single metric. -message MetricMetadata { - // name is the unique name of the metric, usually in a "directory" format - // (e.g., /foo/count). - string name = 1; - - // description is a human-readable description of the metric. - string description = 2; - - // cumulative indicates that this metric is never decremented. - bool cumulative = 3; - - // sync indicates that values from the final metric event should be - // synchronized to the backing monitoring system at exit. - // - // If sync is false, values are only sent to the monitoring system - // periodically. There is no guarantee that values will ever be received by - // the monitoring system. - bool sync = 4; - - enum Type { TYPE_UINT64 = 0; } - - // type is the type of the metric value. - Type type = 5; - - enum Units { - UNITS_NONE = 0; - UNITS_NANOSECONDS = 1; - } - - // units is the units of the metric value. - Units units = 6; - - message Field { - string field_name = 1; - repeated string allowed_values = 2; - } - - // fields contains the metric fields. Currently a metric can have at most - // one field. - repeated Field fields = 7; -} - -// MetricRegistration contains the metadata for all metrics that will be in -// future MetricUpdates. -message MetricRegistration { - repeated MetricMetadata metrics = 1; - repeated string stages = 2; -} - -// MetricValue the value of a metric at a single point in time. -message MetricValue { - // name is the unique name of the metric, as in MetricMetadata. - string name = 1; - - // value is the value of the metric at a single point in time. The field set - // depends on the type of the metric. - oneof value { - uint64 uint64_value = 2; - } - - repeated string field_values = 4; -} - -// StageTiming represents a new stage that's been reached by the Sentry. -message StageTiming { - string stage = 1; - google.protobuf.Timestamp started = 2; - google.protobuf.Timestamp ended = 3; -} - -// MetricUpdate contains new values for multiple distinct metrics. -// -// Metrics whose values have not changed are not included. -message MetricUpdate { - repeated MetricValue metrics = 1; - // Timing information of initialization stages reached since last update. - // The first MetricUpdate will include multiple entries, since metric - // initialization happens relatively late in the Sentry startup process. - repeated StageTiming stage_timing = 2; -} diff --git a/pkg/metric/metric_go_proto/metric.pb.go b/pkg/metric/metric_go_proto/metric.pb.go new file mode 100644 index 000000000..7d327e3a0 --- /dev/null +++ b/pkg/metric/metric_go_proto/metric.pb.go @@ -0,0 +1,732 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.25.0 +// protoc v3.13.0 +// source: pkg/metric/metric.proto + +package gvisor + +import ( + proto "github.com/golang/protobuf/proto" + timestamp "github.com/golang/protobuf/ptypes/timestamp" + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +// This is a compile-time assertion that a sufficiently up-to-date version +// of the legacy proto package is being used. +const _ = proto.ProtoPackageIsVersion4 + +type MetricMetadata_Type int32 + +const ( + MetricMetadata_TYPE_UINT64 MetricMetadata_Type = 0 +) + +// Enum value maps for MetricMetadata_Type. +var ( + MetricMetadata_Type_name = map[int32]string{ + 0: "TYPE_UINT64", + } + MetricMetadata_Type_value = map[string]int32{ + "TYPE_UINT64": 0, + } +) + +func (x MetricMetadata_Type) Enum() *MetricMetadata_Type { + p := new(MetricMetadata_Type) + *p = x + return p +} + +func (x MetricMetadata_Type) String() string { + return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x)) +} + +func (MetricMetadata_Type) Descriptor() protoreflect.EnumDescriptor { + return file_pkg_metric_metric_proto_enumTypes[0].Descriptor() +} + +func (MetricMetadata_Type) Type() protoreflect.EnumType { + return &file_pkg_metric_metric_proto_enumTypes[0] +} + +func (x MetricMetadata_Type) Number() protoreflect.EnumNumber { + return protoreflect.EnumNumber(x) +} + +// Deprecated: Use MetricMetadata_Type.Descriptor instead. +func (MetricMetadata_Type) EnumDescriptor() ([]byte, []int) { + return file_pkg_metric_metric_proto_rawDescGZIP(), []int{0, 0} +} + +type MetricMetadata_Units int32 + +const ( + MetricMetadata_UNITS_NONE MetricMetadata_Units = 0 + MetricMetadata_UNITS_NANOSECONDS MetricMetadata_Units = 1 +) + +// Enum value maps for MetricMetadata_Units. +var ( + MetricMetadata_Units_name = map[int32]string{ + 0: "UNITS_NONE", + 1: "UNITS_NANOSECONDS", + } + MetricMetadata_Units_value = map[string]int32{ + "UNITS_NONE": 0, + "UNITS_NANOSECONDS": 1, + } +) + +func (x MetricMetadata_Units) Enum() *MetricMetadata_Units { + p := new(MetricMetadata_Units) + *p = x + return p +} + +func (x MetricMetadata_Units) String() string { + return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x)) +} + +func (MetricMetadata_Units) Descriptor() protoreflect.EnumDescriptor { + return file_pkg_metric_metric_proto_enumTypes[1].Descriptor() +} + +func (MetricMetadata_Units) Type() protoreflect.EnumType { + return &file_pkg_metric_metric_proto_enumTypes[1] +} + +func (x MetricMetadata_Units) Number() protoreflect.EnumNumber { + return protoreflect.EnumNumber(x) +} + +// Deprecated: Use MetricMetadata_Units.Descriptor instead. +func (MetricMetadata_Units) EnumDescriptor() ([]byte, []int) { + return file_pkg_metric_metric_proto_rawDescGZIP(), []int{0, 1} +} + +type MetricMetadata struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"` + Description string `protobuf:"bytes,2,opt,name=description,proto3" json:"description,omitempty"` + Cumulative bool `protobuf:"varint,3,opt,name=cumulative,proto3" json:"cumulative,omitempty"` + Sync bool `protobuf:"varint,4,opt,name=sync,proto3" json:"sync,omitempty"` + Type MetricMetadata_Type `protobuf:"varint,5,opt,name=type,proto3,enum=gvisor.MetricMetadata_Type" json:"type,omitempty"` + Units MetricMetadata_Units `protobuf:"varint,6,opt,name=units,proto3,enum=gvisor.MetricMetadata_Units" json:"units,omitempty"` + Fields []*MetricMetadata_Field `protobuf:"bytes,7,rep,name=fields,proto3" json:"fields,omitempty"` +} + +func (x *MetricMetadata) Reset() { + *x = MetricMetadata{} + if protoimpl.UnsafeEnabled { + mi := &file_pkg_metric_metric_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *MetricMetadata) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*MetricMetadata) ProtoMessage() {} + +func (x *MetricMetadata) ProtoReflect() protoreflect.Message { + mi := &file_pkg_metric_metric_proto_msgTypes[0] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use MetricMetadata.ProtoReflect.Descriptor instead. +func (*MetricMetadata) Descriptor() ([]byte, []int) { + return file_pkg_metric_metric_proto_rawDescGZIP(), []int{0} +} + +func (x *MetricMetadata) GetName() string { + if x != nil { + return x.Name + } + return "" +} + +func (x *MetricMetadata) GetDescription() string { + if x != nil { + return x.Description + } + return "" +} + +func (x *MetricMetadata) GetCumulative() bool { + if x != nil { + return x.Cumulative + } + return false +} + +func (x *MetricMetadata) GetSync() bool { + if x != nil { + return x.Sync + } + return false +} + +func (x *MetricMetadata) GetType() MetricMetadata_Type { + if x != nil { + return x.Type + } + return MetricMetadata_TYPE_UINT64 +} + +func (x *MetricMetadata) GetUnits() MetricMetadata_Units { + if x != nil { + return x.Units + } + return MetricMetadata_UNITS_NONE +} + +func (x *MetricMetadata) GetFields() []*MetricMetadata_Field { + if x != nil { + return x.Fields + } + return nil +} + +type MetricRegistration struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Metrics []*MetricMetadata `protobuf:"bytes,1,rep,name=metrics,proto3" json:"metrics,omitempty"` + Stages []string `protobuf:"bytes,2,rep,name=stages,proto3" json:"stages,omitempty"` +} + +func (x *MetricRegistration) Reset() { + *x = MetricRegistration{} + if protoimpl.UnsafeEnabled { + mi := &file_pkg_metric_metric_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *MetricRegistration) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*MetricRegistration) ProtoMessage() {} + +func (x *MetricRegistration) ProtoReflect() protoreflect.Message { + mi := &file_pkg_metric_metric_proto_msgTypes[1] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use MetricRegistration.ProtoReflect.Descriptor instead. +func (*MetricRegistration) Descriptor() ([]byte, []int) { + return file_pkg_metric_metric_proto_rawDescGZIP(), []int{1} +} + +func (x *MetricRegistration) GetMetrics() []*MetricMetadata { + if x != nil { + return x.Metrics + } + return nil +} + +func (x *MetricRegistration) GetStages() []string { + if x != nil { + return x.Stages + } + return nil +} + +type MetricValue struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"` + // Types that are assignable to Value: + // *MetricValue_Uint64Value + Value isMetricValue_Value `protobuf_oneof:"value"` + FieldValues []string `protobuf:"bytes,4,rep,name=field_values,json=fieldValues,proto3" json:"field_values,omitempty"` +} + +func (x *MetricValue) Reset() { + *x = MetricValue{} + if protoimpl.UnsafeEnabled { + mi := &file_pkg_metric_metric_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *MetricValue) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*MetricValue) ProtoMessage() {} + +func (x *MetricValue) ProtoReflect() protoreflect.Message { + mi := &file_pkg_metric_metric_proto_msgTypes[2] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use MetricValue.ProtoReflect.Descriptor instead. +func (*MetricValue) Descriptor() ([]byte, []int) { + return file_pkg_metric_metric_proto_rawDescGZIP(), []int{2} +} + +func (x *MetricValue) GetName() string { + if x != nil { + return x.Name + } + return "" +} + +func (m *MetricValue) GetValue() isMetricValue_Value { + if m != nil { + return m.Value + } + return nil +} + +func (x *MetricValue) GetUint64Value() uint64 { + if x, ok := x.GetValue().(*MetricValue_Uint64Value); ok { + return x.Uint64Value + } + return 0 +} + +func (x *MetricValue) GetFieldValues() []string { + if x != nil { + return x.FieldValues + } + return nil +} + +type isMetricValue_Value interface { + isMetricValue_Value() +} + +type MetricValue_Uint64Value struct { + Uint64Value uint64 `protobuf:"varint,2,opt,name=uint64_value,json=uint64Value,proto3,oneof"` +} + +func (*MetricValue_Uint64Value) isMetricValue_Value() {} + +type StageTiming struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Stage string `protobuf:"bytes,1,opt,name=stage,proto3" json:"stage,omitempty"` + Started *timestamp.Timestamp `protobuf:"bytes,2,opt,name=started,proto3" json:"started,omitempty"` + Ended *timestamp.Timestamp `protobuf:"bytes,3,opt,name=ended,proto3" json:"ended,omitempty"` +} + +func (x *StageTiming) Reset() { + *x = StageTiming{} + if protoimpl.UnsafeEnabled { + mi := &file_pkg_metric_metric_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *StageTiming) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*StageTiming) ProtoMessage() {} + +func (x *StageTiming) ProtoReflect() protoreflect.Message { + mi := &file_pkg_metric_metric_proto_msgTypes[3] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use StageTiming.ProtoReflect.Descriptor instead. +func (*StageTiming) Descriptor() ([]byte, []int) { + return file_pkg_metric_metric_proto_rawDescGZIP(), []int{3} +} + +func (x *StageTiming) GetStage() string { + if x != nil { + return x.Stage + } + return "" +} + +func (x *StageTiming) GetStarted() *timestamp.Timestamp { + if x != nil { + return x.Started + } + return nil +} + +func (x *StageTiming) GetEnded() *timestamp.Timestamp { + if x != nil { + return x.Ended + } + return nil +} + +type MetricUpdate struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Metrics []*MetricValue `protobuf:"bytes,1,rep,name=metrics,proto3" json:"metrics,omitempty"` + StageTiming []*StageTiming `protobuf:"bytes,2,rep,name=stage_timing,json=stageTiming,proto3" json:"stage_timing,omitempty"` +} + +func (x *MetricUpdate) Reset() { + *x = MetricUpdate{} + if protoimpl.UnsafeEnabled { + mi := &file_pkg_metric_metric_proto_msgTypes[4] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *MetricUpdate) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*MetricUpdate) ProtoMessage() {} + +func (x *MetricUpdate) ProtoReflect() protoreflect.Message { + mi := &file_pkg_metric_metric_proto_msgTypes[4] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use MetricUpdate.ProtoReflect.Descriptor instead. +func (*MetricUpdate) Descriptor() ([]byte, []int) { + return file_pkg_metric_metric_proto_rawDescGZIP(), []int{4} +} + +func (x *MetricUpdate) GetMetrics() []*MetricValue { + if x != nil { + return x.Metrics + } + return nil +} + +func (x *MetricUpdate) GetStageTiming() []*StageTiming { + if x != nil { + return x.StageTiming + } + return nil +} + +type MetricMetadata_Field struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + FieldName string `protobuf:"bytes,1,opt,name=field_name,json=fieldName,proto3" json:"field_name,omitempty"` + AllowedValues []string `protobuf:"bytes,2,rep,name=allowed_values,json=allowedValues,proto3" json:"allowed_values,omitempty"` +} + +func (x *MetricMetadata_Field) Reset() { + *x = MetricMetadata_Field{} + if protoimpl.UnsafeEnabled { + mi := &file_pkg_metric_metric_proto_msgTypes[5] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *MetricMetadata_Field) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*MetricMetadata_Field) ProtoMessage() {} + +func (x *MetricMetadata_Field) ProtoReflect() protoreflect.Message { + mi := &file_pkg_metric_metric_proto_msgTypes[5] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use MetricMetadata_Field.ProtoReflect.Descriptor instead. +func (*MetricMetadata_Field) Descriptor() ([]byte, []int) { + return file_pkg_metric_metric_proto_rawDescGZIP(), []int{0, 0} +} + +func (x *MetricMetadata_Field) GetFieldName() string { + if x != nil { + return x.FieldName + } + return "" +} + +func (x *MetricMetadata_Field) GetAllowedValues() []string { + if x != nil { + return x.AllowedValues + } + return nil +} + +var File_pkg_metric_metric_proto protoreflect.FileDescriptor + +var file_pkg_metric_metric_proto_rawDesc = []byte{ + 0x0a, 0x17, 0x70, 0x6b, 0x67, 0x2f, 0x6d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x2f, 0x6d, 0x65, 0x74, + 0x72, 0x69, 0x63, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x06, 0x67, 0x76, 0x69, 0x73, 0x6f, + 0x72, 0x1a, 0x1f, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, + 0x75, 0x66, 0x2f, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x2e, 0x70, 0x72, 0x6f, + 0x74, 0x6f, 0x22, 0xad, 0x03, 0x0a, 0x0e, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x4d, 0x65, 0x74, + 0x61, 0x64, 0x61, 0x74, 0x61, 0x12, 0x12, 0x0a, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x12, 0x20, 0x0a, 0x0b, 0x64, 0x65, 0x73, + 0x63, 0x72, 0x69, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, + 0x64, 0x65, 0x73, 0x63, 0x72, 0x69, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x1e, 0x0a, 0x0a, 0x63, + 0x75, 0x6d, 0x75, 0x6c, 0x61, 0x74, 0x69, 0x76, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, + 0x0a, 0x63, 0x75, 0x6d, 0x75, 0x6c, 0x61, 0x74, 0x69, 0x76, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x73, + 0x79, 0x6e, 0x63, 0x18, 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, 0x04, 0x73, 0x79, 0x6e, 0x63, 0x12, + 0x2f, 0x0a, 0x04, 0x74, 0x79, 0x70, 0x65, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x1b, 0x2e, + 0x67, 0x76, 0x69, 0x73, 0x6f, 0x72, 0x2e, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x4d, 0x65, 0x74, + 0x61, 0x64, 0x61, 0x74, 0x61, 0x2e, 0x54, 0x79, 0x70, 0x65, 0x52, 0x04, 0x74, 0x79, 0x70, 0x65, + 0x12, 0x32, 0x0a, 0x05, 0x75, 0x6e, 0x69, 0x74, 0x73, 0x18, 0x06, 0x20, 0x01, 0x28, 0x0e, 0x32, + 0x1c, 0x2e, 0x67, 0x76, 0x69, 0x73, 0x6f, 0x72, 0x2e, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x4d, + 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x2e, 0x55, 0x6e, 0x69, 0x74, 0x73, 0x52, 0x05, 0x75, + 0x6e, 0x69, 0x74, 0x73, 0x12, 0x34, 0x0a, 0x06, 0x66, 0x69, 0x65, 0x6c, 0x64, 0x73, 0x18, 0x07, + 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1c, 0x2e, 0x67, 0x76, 0x69, 0x73, 0x6f, 0x72, 0x2e, 0x4d, 0x65, + 0x74, 0x72, 0x69, 0x63, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x2e, 0x46, 0x69, 0x65, + 0x6c, 0x64, 0x52, 0x06, 0x66, 0x69, 0x65, 0x6c, 0x64, 0x73, 0x1a, 0x4d, 0x0a, 0x05, 0x46, 0x69, + 0x65, 0x6c, 0x64, 0x12, 0x1d, 0x0a, 0x0a, 0x66, 0x69, 0x65, 0x6c, 0x64, 0x5f, 0x6e, 0x61, 0x6d, + 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x66, 0x69, 0x65, 0x6c, 0x64, 0x4e, 0x61, + 0x6d, 0x65, 0x12, 0x25, 0x0a, 0x0e, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x5f, 0x76, 0x61, + 0x6c, 0x75, 0x65, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0d, 0x61, 0x6c, 0x6c, 0x6f, + 0x77, 0x65, 0x64, 0x56, 0x61, 0x6c, 0x75, 0x65, 0x73, 0x22, 0x17, 0x0a, 0x04, 0x54, 0x79, 0x70, + 0x65, 0x12, 0x0f, 0x0a, 0x0b, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x55, 0x49, 0x4e, 0x54, 0x36, 0x34, + 0x10, 0x00, 0x22, 0x2e, 0x0a, 0x05, 0x55, 0x6e, 0x69, 0x74, 0x73, 0x12, 0x0e, 0x0a, 0x0a, 0x55, + 0x4e, 0x49, 0x54, 0x53, 0x5f, 0x4e, 0x4f, 0x4e, 0x45, 0x10, 0x00, 0x12, 0x15, 0x0a, 0x11, 0x55, + 0x4e, 0x49, 0x54, 0x53, 0x5f, 0x4e, 0x41, 0x4e, 0x4f, 0x53, 0x45, 0x43, 0x4f, 0x4e, 0x44, 0x53, + 0x10, 0x01, 0x22, 0x5e, 0x0a, 0x12, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x52, 0x65, 0x67, 0x69, + 0x73, 0x74, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x30, 0x0a, 0x07, 0x6d, 0x65, 0x74, 0x72, + 0x69, 0x63, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x67, 0x76, 0x69, 0x73, + 0x6f, 0x72, 0x2e, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, + 0x61, 0x52, 0x07, 0x6d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x73, 0x12, 0x16, 0x0a, 0x06, 0x73, 0x74, + 0x61, 0x67, 0x65, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x09, 0x52, 0x06, 0x73, 0x74, 0x61, 0x67, + 0x65, 0x73, 0x22, 0x72, 0x0a, 0x0b, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x56, 0x61, 0x6c, 0x75, + 0x65, 0x12, 0x12, 0x0a, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x12, 0x23, 0x0a, 0x0c, 0x75, 0x69, 0x6e, 0x74, 0x36, 0x34, 0x5f, + 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x04, 0x48, 0x00, 0x52, 0x0b, 0x75, + 0x69, 0x6e, 0x74, 0x36, 0x34, 0x56, 0x61, 0x6c, 0x75, 0x65, 0x12, 0x21, 0x0a, 0x0c, 0x66, 0x69, + 0x65, 0x6c, 0x64, 0x5f, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x73, 0x18, 0x04, 0x20, 0x03, 0x28, 0x09, + 0x52, 0x0b, 0x66, 0x69, 0x65, 0x6c, 0x64, 0x56, 0x61, 0x6c, 0x75, 0x65, 0x73, 0x42, 0x07, 0x0a, + 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x22, 0x8b, 0x01, 0x0a, 0x0b, 0x53, 0x74, 0x61, 0x67, 0x65, + 0x54, 0x69, 0x6d, 0x69, 0x6e, 0x67, 0x12, 0x14, 0x0a, 0x05, 0x73, 0x74, 0x61, 0x67, 0x65, 0x18, + 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x73, 0x74, 0x61, 0x67, 0x65, 0x12, 0x34, 0x0a, 0x07, + 0x73, 0x74, 0x61, 0x72, 0x74, 0x65, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, + 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, + 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x07, 0x73, 0x74, 0x61, 0x72, 0x74, + 0x65, 0x64, 0x12, 0x30, 0x0a, 0x05, 0x65, 0x6e, 0x64, 0x65, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, + 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, + 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x05, 0x65, + 0x6e, 0x64, 0x65, 0x64, 0x22, 0x75, 0x0a, 0x0c, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x55, 0x70, + 0x64, 0x61, 0x74, 0x65, 0x12, 0x2d, 0x0a, 0x07, 0x6d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x73, 0x18, + 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x13, 0x2e, 0x67, 0x76, 0x69, 0x73, 0x6f, 0x72, 0x2e, 0x4d, + 0x65, 0x74, 0x72, 0x69, 0x63, 0x56, 0x61, 0x6c, 0x75, 0x65, 0x52, 0x07, 0x6d, 0x65, 0x74, 0x72, + 0x69, 0x63, 0x73, 0x12, 0x36, 0x0a, 0x0c, 0x73, 0x74, 0x61, 0x67, 0x65, 0x5f, 0x74, 0x69, 0x6d, + 0x69, 0x6e, 0x67, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x13, 0x2e, 0x67, 0x76, 0x69, 0x73, + 0x6f, 0x72, 0x2e, 0x53, 0x74, 0x61, 0x67, 0x65, 0x54, 0x69, 0x6d, 0x69, 0x6e, 0x67, 0x52, 0x0b, + 0x73, 0x74, 0x61, 0x67, 0x65, 0x54, 0x69, 0x6d, 0x69, 0x6e, 0x67, 0x62, 0x06, 0x70, 0x72, 0x6f, + 0x74, 0x6f, 0x33, +} + +var ( + file_pkg_metric_metric_proto_rawDescOnce sync.Once + file_pkg_metric_metric_proto_rawDescData = file_pkg_metric_metric_proto_rawDesc +) + +func file_pkg_metric_metric_proto_rawDescGZIP() []byte { + file_pkg_metric_metric_proto_rawDescOnce.Do(func() { + file_pkg_metric_metric_proto_rawDescData = protoimpl.X.CompressGZIP(file_pkg_metric_metric_proto_rawDescData) + }) + return file_pkg_metric_metric_proto_rawDescData +} + +var file_pkg_metric_metric_proto_enumTypes = make([]protoimpl.EnumInfo, 2) +var file_pkg_metric_metric_proto_msgTypes = make([]protoimpl.MessageInfo, 6) +var file_pkg_metric_metric_proto_goTypes = []interface{}{ + (MetricMetadata_Type)(0), // 0: gvisor.MetricMetadata.Type + (MetricMetadata_Units)(0), // 1: gvisor.MetricMetadata.Units + (*MetricMetadata)(nil), // 2: gvisor.MetricMetadata + (*MetricRegistration)(nil), // 3: gvisor.MetricRegistration + (*MetricValue)(nil), // 4: gvisor.MetricValue + (*StageTiming)(nil), // 5: gvisor.StageTiming + (*MetricUpdate)(nil), // 6: gvisor.MetricUpdate + (*MetricMetadata_Field)(nil), // 7: gvisor.MetricMetadata.Field + (*timestamp.Timestamp)(nil), // 8: google.protobuf.Timestamp +} +var file_pkg_metric_metric_proto_depIdxs = []int32{ + 0, // 0: gvisor.MetricMetadata.type:type_name -> gvisor.MetricMetadata.Type + 1, // 1: gvisor.MetricMetadata.units:type_name -> gvisor.MetricMetadata.Units + 7, // 2: gvisor.MetricMetadata.fields:type_name -> gvisor.MetricMetadata.Field + 2, // 3: gvisor.MetricRegistration.metrics:type_name -> gvisor.MetricMetadata + 8, // 4: gvisor.StageTiming.started:type_name -> google.protobuf.Timestamp + 8, // 5: gvisor.StageTiming.ended:type_name -> google.protobuf.Timestamp + 4, // 6: gvisor.MetricUpdate.metrics:type_name -> gvisor.MetricValue + 5, // 7: gvisor.MetricUpdate.stage_timing:type_name -> gvisor.StageTiming + 8, // [8:8] is the sub-list for method output_type + 8, // [8:8] is the sub-list for method input_type + 8, // [8:8] is the sub-list for extension type_name + 8, // [8:8] is the sub-list for extension extendee + 0, // [0:8] is the sub-list for field type_name +} + +func init() { file_pkg_metric_metric_proto_init() } +func file_pkg_metric_metric_proto_init() { + if File_pkg_metric_metric_proto != nil { + return + } + if !protoimpl.UnsafeEnabled { + file_pkg_metric_metric_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*MetricMetadata); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_pkg_metric_metric_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*MetricRegistration); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_pkg_metric_metric_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*MetricValue); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_pkg_metric_metric_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*StageTiming); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_pkg_metric_metric_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*MetricUpdate); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_pkg_metric_metric_proto_msgTypes[5].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*MetricMetadata_Field); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + } + file_pkg_metric_metric_proto_msgTypes[2].OneofWrappers = []interface{}{ + (*MetricValue_Uint64Value)(nil), + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: file_pkg_metric_metric_proto_rawDesc, + NumEnums: 2, + NumMessages: 6, + NumExtensions: 0, + NumServices: 0, + }, + GoTypes: file_pkg_metric_metric_proto_goTypes, + DependencyIndexes: file_pkg_metric_metric_proto_depIdxs, + EnumInfos: file_pkg_metric_metric_proto_enumTypes, + MessageInfos: file_pkg_metric_metric_proto_msgTypes, + }.Build() + File_pkg_metric_metric_proto = out.File + file_pkg_metric_metric_proto_rawDesc = nil + file_pkg_metric_metric_proto_goTypes = nil + file_pkg_metric_metric_proto_depIdxs = nil +} diff --git a/pkg/metric/metric_state_autogen.go b/pkg/metric/metric_state_autogen.go new file mode 100644 index 000000000..36e5ed81b --- /dev/null +++ b/pkg/metric/metric_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package metric diff --git a/pkg/metric/metric_test.go b/pkg/metric/metric_test.go deleted file mode 100644 index 0654bdf07..000000000 --- a/pkg/metric/metric_test.go +++ /dev/null @@ -1,499 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package metric - -import ( - "testing" - "time" - - "google.golang.org/protobuf/proto" - "gvisor.dev/gvisor/pkg/eventchannel" - pb "gvisor.dev/gvisor/pkg/metric/metric_go_proto" -) - -// sliceEmitter implements eventchannel.Emitter by appending all messages to a -// slice. -type sliceEmitter []proto.Message - -// Emit implements eventchannel.Emitter.Emit. -func (s *sliceEmitter) Emit(msg proto.Message) (bool, error) { - *s = append(*s, msg) - return false, nil -} - -// Emit implements eventchannel.Emitter.Close. -func (s *sliceEmitter) Close() error { - return nil -} - -// Reset clears all events in s. -func (s *sliceEmitter) Reset() { - *s = nil -} - -// emitter is the eventchannel.Emitter used for all tests. Package eventchannel -// doesn't allow removing Emitters, so we must use one global emitter for all -// test cases. -var emitter sliceEmitter - -func init() { - reset() - - eventchannel.AddEmitter(&emitter) -} - -// reset clears all global state in the metric package. -func reset() { - initialized = false - allMetrics = makeMetricSet() - emitter.Reset() -} - -const ( - fooDescription = "Foo!" - barDescription = "Bar Baz" - counterDescription = "Counter" -) - -func TestInitialize(t *testing.T) { - defer reset() - - _, err := NewUint64Metric("/foo", false, pb.MetricMetadata_UNITS_NONE, fooDescription) - if err != nil { - t.Fatalf("NewUint64Metric got err %v want nil", err) - } - - _, err = NewUint64Metric("/bar", true, pb.MetricMetadata_UNITS_NANOSECONDS, barDescription) - if err != nil { - t.Fatalf("NewUint64Metric got err %v want nil", err) - } - - if err := Initialize(); err != nil { - t.Fatalf("Initialize(): %s", err) - } - - if len(emitter) != 1 { - t.Fatalf("Initialize emitted %d events want 1", len(emitter)) - } - - mr, ok := emitter[0].(*pb.MetricRegistration) - if !ok { - t.Fatalf("emitter %v got %T want pb.MetricRegistration", emitter[0], emitter[0]) - } - - if len(mr.Metrics) != 2 { - t.Errorf("MetricRegistration got %d metrics want 2", len(mr.Metrics)) - } - - foundFoo := false - foundBar := false - for _, m := range mr.Metrics { - if m.Type != pb.MetricMetadata_TYPE_UINT64 { - t.Errorf("Metadata %+v Type got %v want pb.MetricMetadata_TYPE_UINT64", m, m.Type) - } - if !m.Cumulative { - t.Errorf("Metadata %+v Cumulative got false want true", m) - } - - switch m.Name { - case "/foo": - foundFoo = true - if m.Description != fooDescription { - t.Errorf("/foo %+v Description got %q want %q", m, m.Description, fooDescription) - } - if m.Sync { - t.Errorf("/foo %+v Sync got true want false", m) - } - if m.Units != pb.MetricMetadata_UNITS_NONE { - t.Errorf("/foo %+v Units got %v want %v", m, m.Units, pb.MetricMetadata_UNITS_NONE) - } - case "/bar": - foundBar = true - if m.Description != barDescription { - t.Errorf("/bar %+v Description got %q want %q", m, m.Description, barDescription) - } - if !m.Sync { - t.Errorf("/bar %+v Sync got true want false", m) - } - if m.Units != pb.MetricMetadata_UNITS_NANOSECONDS { - t.Errorf("/bar %+v Units got %v want %v", m, m.Units, pb.MetricMetadata_UNITS_NANOSECONDS) - } - } - } - - if !foundFoo { - t.Errorf("/foo not found: %+v", emitter) - } - if !foundBar { - t.Errorf("/bar not found: %+v", emitter) - } -} - -func TestDisable(t *testing.T) { - defer reset() - - _, err := NewUint64Metric("/foo", false, pb.MetricMetadata_UNITS_NONE, fooDescription) - if err != nil { - t.Fatalf("NewUint64Metric got err %v want nil", err) - } - - _, err = NewUint64Metric("/bar", true, pb.MetricMetadata_UNITS_NONE, barDescription) - if err != nil { - t.Fatalf("NewUint64Metric got err %v want nil", err) - } - - if err := Disable(); err != nil { - t.Fatalf("Disable(): %s", err) - } - - if len(emitter) != 1 { - t.Fatalf("Initialize emitted %d events want 1", len(emitter)) - } - - mr, ok := emitter[0].(*pb.MetricRegistration) - if !ok { - t.Fatalf("emitter %v got %T want pb.MetricRegistration", emitter[0], emitter[0]) - } - - if len(mr.Metrics) != 0 { - t.Errorf("MetricRegistration got %d metrics want 0", len(mr.Metrics)) - } -} - -func TestEmitMetricUpdate(t *testing.T) { - defer reset() - - foo, err := NewUint64Metric("/foo", false, pb.MetricMetadata_UNITS_NONE, fooDescription) - if err != nil { - t.Fatalf("NewUint64Metric got err %v want nil", err) - } - - _, err = NewUint64Metric("/bar", true, pb.MetricMetadata_UNITS_NONE, barDescription) - if err != nil { - t.Fatalf("NewUint64Metric got err %v want nil", err) - } - - if err := Initialize(); err != nil { - t.Fatalf("Initialize(): %s", err) - } - - // Don't care about the registration metrics. - emitter.Reset() - EmitMetricUpdate() - - if len(emitter) != 1 { - t.Fatalf("EmitMetricUpdate emitted %d events want 1", len(emitter)) - } - - update, ok := emitter[0].(*pb.MetricUpdate) - if !ok { - t.Fatalf("emitter %v got %T want pb.MetricUpdate", emitter[0], emitter[0]) - } - - if len(update.Metrics) != 2 { - t.Errorf("MetricUpdate got %d metrics want 2", len(update.Metrics)) - } - - // Both are included for their initial values. - foundFoo := false - foundBar := false - for _, m := range update.Metrics { - switch m.Name { - case "/foo": - foundFoo = true - case "/bar": - foundBar = true - } - uv, ok := m.Value.(*pb.MetricValue_Uint64Value) - if !ok { - t.Errorf("%+v: value %v got %T want pb.MetricValue_Uint64Value", m, m.Value, m.Value) - continue - } - if uv.Uint64Value != 0 { - t.Errorf("%v: Value got %v want 0", m, uv.Uint64Value) - } - } - - if !foundFoo { - t.Errorf("/foo not found: %+v", emitter) - } - if !foundBar { - t.Errorf("/bar not found: %+v", emitter) - } - - // Increment foo. Only it is included in the next update. - foo.Increment() - - emitter.Reset() - EmitMetricUpdate() - - if len(emitter) != 1 { - t.Fatalf("EmitMetricUpdate emitted %d events want 1", len(emitter)) - } - - update, ok = emitter[0].(*pb.MetricUpdate) - if !ok { - t.Fatalf("emitter %v got %T want pb.MetricUpdate", emitter[0], emitter[0]) - } - - if len(update.Metrics) != 1 { - t.Errorf("MetricUpdate got %d metrics want 1", len(update.Metrics)) - } - - m := update.Metrics[0] - - if m.Name != "/foo" { - t.Errorf("Metric %+v name got %q want '/foo'", m, m.Name) - } - - uv, ok := m.Value.(*pb.MetricValue_Uint64Value) - if !ok { - t.Errorf("%+v: value %v got %T want pb.MetricValue_Uint64Value", m, m.Value, m.Value) - } - if uv.Uint64Value != 1 { - t.Errorf("%v: Value got %v want 1", m, uv.Uint64Value) - } -} - -func TestEmitMetricUpdateWithFields(t *testing.T) { - defer reset() - - field := Field{ - name: "weirdness_type", - allowedValues: []string{"weird1", "weird2"}} - - counter, err := NewUint64Metric("/weirdness", false, pb.MetricMetadata_UNITS_NONE, counterDescription, field) - if err != nil { - t.Fatalf("NewUint64Metric got err %v want nil", err) - } - - if err := Initialize(); err != nil { - t.Fatalf("Initialize(): %s", err) - } - - // Don't care about the registration metrics. - emitter.Reset() - EmitMetricUpdate() - - // For metrics with fields, we do not emit data unless the value is - // incremented. - if len(emitter) != 0 { - t.Fatalf("EmitMetricUpdate emitted %d events want 0", len(emitter)) - } - - counter.IncrementBy(4, "weird1") - counter.Increment("weird2") - - emitter.Reset() - EmitMetricUpdate() - - if len(emitter) != 1 { - t.Fatalf("EmitMetricUpdate emitted %d events want 1", len(emitter)) - } - - update, ok := emitter[0].(*pb.MetricUpdate) - if !ok { - t.Fatalf("emitter %v got %T want pb.MetricUpdate", emitter[0], emitter[0]) - } - - if len(update.Metrics) != 2 { - t.Errorf("MetricUpdate got %d metrics want 2", len(update.Metrics)) - } - - foundWeird1 := false - foundWeird2 := false - for i := 0; i < len(update.Metrics); i++ { - m := update.Metrics[i] - - if m.Name != "/weirdness" { - t.Errorf("Metric %+v name got %q want '/weirdness'", m, m.Name) - } - if len(m.FieldValues) != 1 { - t.Errorf("MetricUpdate got %d fields want 1", len(m.FieldValues)) - } - - switch m.FieldValues[0] { - case "weird1": - uv, ok := m.Value.(*pb.MetricValue_Uint64Value) - if !ok { - t.Errorf("%+v: value %v got %T want pb.MetricValue_Uint64Value", m, m.Value, m.Value) - } - if uv.Uint64Value != 4 { - t.Errorf("%v: Value got %v want 4", m, uv.Uint64Value) - } - foundWeird1 = true - case "weird2": - uv, ok := m.Value.(*pb.MetricValue_Uint64Value) - if !ok { - t.Errorf("%+v: value %v got %T want pb.MetricValue_Uint64Value", m, m.Value, m.Value) - } - if uv.Uint64Value != 1 { - t.Errorf("%v: Value got %v want 1", m, uv.Uint64Value) - } - foundWeird2 = true - } - } - - if !foundWeird1 { - t.Errorf("Field value weird1 not found: %+v", emitter) - } - if !foundWeird2 { - t.Errorf("Field value weird2 not found: %+v", emitter) - } -} - -func TestMetricUpdateStageTiming(t *testing.T) { - defer reset() - - expectedTimings := map[InitStage]struct{ min, max time.Duration }{} - measureStage := func(stage InitStage, body func()) { - stageStarted := time.Now() - endStage := StartStage(stage) - bodyStarted := time.Now() - body() - bodyEnded := time.Now() - endStage() - stageEnded := time.Now() - - expectedTimings[stage] = struct{ min, max time.Duration }{ - min: bodyEnded.Sub(bodyStarted), - max: stageEnded.Sub(stageStarted), - } - } - checkStage := func(got *pb.StageTiming, want InitStage) { - if InitStage(got.GetStage()) != want { - t.Errorf("%v: got stage %q expected %q", got, got.GetStage(), want) - } - timingBounds, found := expectedTimings[want] - if !found { - t.Fatalf("invalid init stage name %q", want) - } - started := got.Started.AsTime() - ended := got.Ended.AsTime() - duration := ended.Sub(started) - if duration < timingBounds.min { - t.Errorf("stage %v: lasted %v, expected at least %v", want, duration, timingBounds.min) - } else if duration > timingBounds.max { - t.Errorf("stage %v: lasted %v, expected no more than %v", want, duration, timingBounds.max) - } - } - - // Test that it's legit to go through stages before metric registration. - measureStage("before_first_update_1", func() { - time.Sleep(100 * time.Millisecond) - }) - measureStage("before_first_update_2", func() { - time.Sleep(100 * time.Millisecond) - }) - - fooMetric, err := NewUint64Metric("/foo", false, pb.MetricMetadata_UNITS_NONE, fooDescription) - if err != nil { - t.Fatalf("Cannot register /foo: %v", err) - } - emitter.Reset() - Initialize() - EmitMetricUpdate() - - // We should have gotten the metric registration and the first MetricUpdate. - if len(emitter) != 2 { - t.Fatalf("emitter has %d messages (%v), expected %d", len(emitter), emitter, 2) - } - - if registration, ok := emitter[0].(*pb.MetricRegistration); !ok { - t.Errorf("first message is not MetricRegistration: %T / %v", emitter[0], emitter[0]) - } else if len(registration.Stages) != len(allStages) { - t.Errorf("MetricRegistration has %d stages (%v), expected %d (%v)", len(registration.Stages), registration.Stages, len(allStages), allStages) - } else { - for i := 0; i < len(allStages); i++ { - if InitStage(registration.Stages[i]) != allStages[i] { - t.Errorf("MetricRegistration.Stages[%d]: got %q want %q", i, registration.Stages[i], allStages[i]) - } - } - } - - if firstUpdate, ok := emitter[1].(*pb.MetricUpdate); !ok { - t.Errorf("second message is not MetricUpdate: %T / %v", emitter[1], emitter[1]) - } else if len(firstUpdate.StageTiming) != 2 { - t.Errorf("MetricUpdate has %d stage timings (%v), expected %d", len(firstUpdate.StageTiming), firstUpdate.StageTiming, 2) - } else { - checkStage(firstUpdate.StageTiming[0], "before_first_update_1") - checkStage(firstUpdate.StageTiming[1], "before_first_update_2") - } - - // Ensure re-emitting doesn't cause another event to be sent. - emitter.Reset() - EmitMetricUpdate() - if len(emitter) != 0 { - t.Fatalf("EmitMetricUpdate emitted %d events want %d", len(emitter), 0) - } - - // Generate monitoring data, we should get an event with no stages. - fooMetric.Increment() - emitter.Reset() - EmitMetricUpdate() - if len(emitter) != 1 { - t.Fatalf("EmitMetricUpdate emitted %d events want %d", len(emitter), 1) - } else if update, ok := emitter[0].(*pb.MetricUpdate); !ok { - t.Errorf("message is not MetricUpdate: %T / %v", emitter[1], emitter[1]) - } else if len(update.StageTiming) != 0 { - t.Errorf("unexpected stage timing information: %v", update.StageTiming) - } - - // Now generate new stages. - measureStage("foo_stage_1", func() { - time.Sleep(100 * time.Millisecond) - }) - measureStage("foo_stage_2", func() { - time.Sleep(100 * time.Millisecond) - }) - emitter.Reset() - EmitMetricUpdate() - if len(emitter) != 1 { - t.Fatalf("EmitMetricUpdate emitted %d events want %d", len(emitter), 1) - } else if update, ok := emitter[0].(*pb.MetricUpdate); !ok { - t.Errorf("message is not MetricUpdate: %T / %v", emitter[1], emitter[1]) - } else if len(update.Metrics) != 0 { - t.Errorf("MetricUpdate has %d metric value changes (%v), expected %d", len(update.Metrics), update.Metrics, 0) - } else if len(update.StageTiming) != 2 { - t.Errorf("MetricUpdate has %d stages (%v), expected %d", len(update.StageTiming), update.StageTiming, 2) - } else { - checkStage(update.StageTiming[0], "foo_stage_1") - checkStage(update.StageTiming[1], "foo_stage_2") - } - - // Now try generating data for both metrics and stages. - fooMetric.Increment() - measureStage("last_stage_1", func() { - time.Sleep(100 * time.Millisecond) - }) - measureStage("last_stage_2", func() { - time.Sleep(100 * time.Millisecond) - }) - fooMetric.Increment() - emitter.Reset() - EmitMetricUpdate() - if len(emitter) != 1 { - t.Fatalf("EmitMetricUpdate emitted %d events want %d", len(emitter), 1) - } else if update, ok := emitter[0].(*pb.MetricUpdate); !ok { - t.Errorf("message is not MetricUpdate: %T / %v", emitter[1], emitter[1]) - } else if len(update.Metrics) != 1 { - t.Errorf("MetricUpdate has %d metric value changes (%v), expected %d", len(update.Metrics), update.Metrics, 1) - } else if len(update.StageTiming) != 2 { - t.Errorf("MetricUpdate has %d stages (%v), expected %d", len(update.StageTiming), update.StageTiming, 2) - } else { - checkStage(update.StageTiming[0], "last_stage_1") - checkStage(update.StageTiming[1], "last_stage_2") - } -} diff --git a/pkg/p9/BUILD b/pkg/p9/BUILD deleted file mode 100644 index 2b22b2203..000000000 --- a/pkg/p9/BUILD +++ /dev/null @@ -1,56 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package( - default_visibility = ["//visibility:public"], - licenses = ["notice"], -) - -go_library( - name = "p9", - srcs = [ - "buffer.go", - "client.go", - "client_file.go", - "file.go", - "handlers.go", - "messages.go", - "p9.go", - "path_tree.go", - "server.go", - "transport.go", - "transport_flipcall.go", - "version.go", - ], - deps = [ - "//pkg/abi/linux/errno", - "//pkg/errors", - "//pkg/errors/linuxerr", - "//pkg/fd", - "//pkg/fdchannel", - "//pkg/flipcall", - "//pkg/log", - "//pkg/pool", - "//pkg/sync", - "//pkg/unet", - "@org_golang_x_sys//unix:go_default_library", - ], -) - -go_test( - name = "p9_test", - size = "small", - srcs = [ - "buffer_test.go", - "client_test.go", - "messages_test.go", - "p9_test.go", - "transport_test.go", - "version_test.go", - ], - library = ":p9", - deps = [ - "//pkg/fd", - "//pkg/unet", - "@org_golang_x_sys//unix:go_default_library", - ], -) diff --git a/pkg/p9/buffer_test.go b/pkg/p9/buffer_test.go deleted file mode 100644 index a9c75f86b..000000000 --- a/pkg/p9/buffer_test.go +++ /dev/null @@ -1,31 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package p9 - -import ( - "testing" -) - -func TestBufferOverrun(t *testing.T) { - buf := &buffer{ - // This header indicates that a large string should follow, but - // it is only two bytes. Reading a string should cause an - // overrun. - data: []byte{0x0, 0x16}, - } - if s := buf.ReadString(); s != "" { - t.Errorf("overrun read got %s, want empty", s) - } -} diff --git a/pkg/p9/client_test.go b/pkg/p9/client_test.go deleted file mode 100644 index 24e0dd7e8..000000000 --- a/pkg/p9/client_test.go +++ /dev/null @@ -1,111 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package p9 - -import ( - "testing" - - "golang.org/x/sys/unix" - "gvisor.dev/gvisor/pkg/unet" -) - -// TestVersion tests the version negotiation. -func TestVersion(t *testing.T) { - // First, create a new server and connection. - serverSocket, clientSocket, err := unet.SocketPair(false) - if err != nil { - t.Fatalf("socketpair got err %v expected nil", err) - } - defer clientSocket.Close() - - // Create a new server and client. - s := NewServer(nil) - go s.Handle(serverSocket) - - // NewClient does a Tversion exchange, so this is our test for success. - c, err := NewClient(clientSocket, DefaultMessageSize, HighestVersionString()) - if err != nil { - t.Fatalf("got %v, expected nil", err) - } - - // Check a bogus version string. - if err := c.sendRecv(&Tversion{Version: "notokay", MSize: DefaultMessageSize}, &Rversion{}); err != unix.EINVAL { - t.Errorf("got %v expected %v", err, unix.EINVAL) - } - - // Check a bogus version number. - if err := c.sendRecv(&Tversion{Version: "9P1000.L", MSize: DefaultMessageSize}, &Rversion{}); err != unix.EINVAL { - t.Errorf("got %v expected %v", err, unix.EINVAL) - } - - // Check a too high version number. - if err := c.sendRecv(&Tversion{Version: versionString(highestSupportedVersion + 1), MSize: DefaultMessageSize}, &Rversion{}); err != unix.EAGAIN { - t.Errorf("got %v expected %v", err, unix.EAGAIN) - } - - // Check an invalid MSize. - if err := c.sendRecv(&Tversion{Version: versionString(highestSupportedVersion), MSize: 0}, &Rversion{}); err != unix.EINVAL { - t.Errorf("got %v expected %v", err, unix.EINVAL) - } -} - -func benchmarkSendRecv(b *testing.B, fn func(c *Client) func(message, message) error) { - b.ReportAllocs() - - // See above. - serverSocket, clientSocket, err := unet.SocketPair(false) - if err != nil { - b.Fatalf("socketpair got err %v expected nil", err) - } - defer clientSocket.Close() - - // See above. - s := NewServer(nil) - go s.Handle(serverSocket) - - // See above. - c, err := NewClient(clientSocket, DefaultMessageSize, HighestVersionString()) - if err != nil { - b.Fatalf("got %v, expected nil", err) - } - - // Initialize messages. - sendRecv := fn(c) - tversion := &Tversion{ - Version: versionString(highestSupportedVersion), - MSize: DefaultMessageSize, - } - rversion := new(Rversion) - - // Run in a loop. - for i := 0; i < b.N; i++ { - if err := sendRecv(tversion, rversion); err != nil { - b.Fatalf("got unexpected err: %v", err) - } - } -} - -func BenchmarkSendRecvLegacy(b *testing.B) { - benchmarkSendRecv(b, func(c *Client) func(message, message) error { - return func(t message, r message) error { - _, err := c.sendRecvLegacy(t, r) - return err - } - }) -} - -func BenchmarkSendRecvChannel(b *testing.B) { - benchmarkSendRecv(b, func(c *Client) func(message, message) error { return c.sendRecvChannel }) -} diff --git a/pkg/p9/messages_test.go b/pkg/p9/messages_test.go deleted file mode 100644 index bfeb6c236..000000000 --- a/pkg/p9/messages_test.go +++ /dev/null @@ -1,507 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package p9 - -import ( - "fmt" - "reflect" - "testing" -) - -func TestEncodeDecode(t *testing.T) { - objs := []encoder{ - &QID{ - Type: 1, - Version: 2, - Path: 3, - }, - &FSStat{ - Type: 1, - BlockSize: 2, - Blocks: 3, - BlocksFree: 4, - BlocksAvailable: 5, - Files: 6, - FilesFree: 7, - FSID: 8, - NameLength: 9, - }, - &AttrMask{ - Mode: true, - NLink: true, - UID: true, - GID: true, - RDev: true, - ATime: true, - MTime: true, - CTime: true, - INo: true, - Size: true, - Blocks: true, - BTime: true, - Gen: true, - DataVersion: true, - }, - &Attr{ - Mode: Exec, - UID: 2, - GID: 3, - NLink: 4, - RDev: 5, - Size: 6, - BlockSize: 7, - Blocks: 8, - ATimeSeconds: 9, - ATimeNanoSeconds: 10, - MTimeSeconds: 11, - MTimeNanoSeconds: 12, - CTimeSeconds: 13, - CTimeNanoSeconds: 14, - BTimeSeconds: 15, - BTimeNanoSeconds: 16, - Gen: 17, - DataVersion: 18, - }, - &SetAttrMask{ - Permissions: true, - UID: true, - GID: true, - Size: true, - ATime: true, - MTime: true, - CTime: true, - ATimeNotSystemTime: true, - MTimeNotSystemTime: true, - }, - &SetAttr{ - Permissions: 1, - UID: 2, - GID: 3, - Size: 4, - ATimeSeconds: 5, - ATimeNanoSeconds: 6, - MTimeSeconds: 7, - MTimeNanoSeconds: 8, - }, - &Dirent{ - QID: QID{Type: 1}, - Offset: 2, - Type: 3, - Name: "a", - }, - &Rlerror{ - Error: 1, - }, - &Tstatfs{ - FID: 1, - }, - &Rstatfs{ - FSStat: FSStat{Type: 1}, - }, - &Tlopen{ - FID: 1, - Flags: WriteOnly, - }, - &Rlopen{ - QID: QID{Type: 1}, - IoUnit: 2, - }, - &Tlconnect{ - FID: 1, - }, - &Rlconnect{}, - &Tlcreate{ - FID: 1, - Name: "a", - OpenFlags: 2, - Permissions: 3, - GID: 4, - }, - &Rlcreate{ - Rlopen{QID: QID{Type: 1}}, - }, - &Tsymlink{ - Directory: 1, - Name: "a", - Target: "b", - GID: 2, - }, - &Rsymlink{ - QID: QID{Type: 1}, - }, - &Tmknod{ - Directory: 1, - Name: "a", - Mode: 2, - Major: 3, - Minor: 4, - GID: 5, - }, - &Rmknod{ - QID: QID{Type: 1}, - }, - &Trename{ - FID: 1, - Directory: 2, - Name: "a", - }, - &Rrename{}, - &Treadlink{ - FID: 1, - }, - &Rreadlink{ - Target: "a", - }, - &Tgetattr{ - FID: 1, - AttrMask: AttrMask{Mode: true}, - }, - &Rgetattr{ - Valid: AttrMask{Mode: true}, - QID: QID{Type: 1}, - Attr: Attr{Mode: Write}, - }, - &Tsetattr{ - FID: 1, - Valid: SetAttrMask{Permissions: true}, - SetAttr: SetAttr{Permissions: Write}, - }, - &Rsetattr{}, - &Txattrwalk{ - FID: 1, - NewFID: 2, - Name: "a", - }, - &Rxattrwalk{ - Size: 1, - }, - &Txattrcreate{ - FID: 1, - Name: "a", - AttrSize: 2, - Flags: 3, - }, - &Rxattrcreate{}, - &Tgetxattr{ - FID: 1, - Name: "abc", - Size: 2, - }, - &Rgetxattr{ - Value: "xyz", - }, - &Tsetxattr{ - FID: 1, - Name: "abc", - Value: "xyz", - Flags: 2, - }, - &Rsetxattr{}, - &Treaddir{ - Directory: 1, - Offset: 2, - Count: 3, - }, - &Rreaddir{ - // Count must be sufficient to encode a dirent. - Count: 0x1a, - Entries: []Dirent{{QID: QID{Type: 2}}}, - }, - &Tfsync{ - FID: 1, - }, - &Rfsync{}, - &Tlink{ - Directory: 1, - Target: 2, - Name: "a", - }, - &Rlink{}, - &Tmkdir{ - Directory: 1, - Name: "a", - Permissions: 2, - GID: 3, - }, - &Rmkdir{ - QID: QID{Type: 1}, - }, - &Trenameat{ - OldDirectory: 1, - OldName: "a", - NewDirectory: 2, - NewName: "b", - }, - &Rrenameat{}, - &Tunlinkat{ - Directory: 1, - Name: "a", - Flags: 2, - }, - &Runlinkat{}, - &Tversion{ - MSize: 1, - Version: "a", - }, - &Rversion{ - MSize: 1, - Version: "a", - }, - &Tauth{ - AuthenticationFID: 1, - UserName: "a", - AttachName: "b", - UID: 2, - }, - &Rauth{ - QID: QID{Type: 1}, - }, - &Tattach{ - FID: 1, - Auth: Tauth{AuthenticationFID: 2}, - }, - &Rattach{ - QID: QID{Type: 1}, - }, - &Tflush{ - OldTag: 1, - }, - &Rflush{}, - &Twalk{ - FID: 1, - NewFID: 2, - Names: []string{"a"}, - }, - &Rwalk{ - QIDs: []QID{{Type: 1}}, - }, - &Tread{ - FID: 1, - Offset: 2, - Count: 3, - }, - &Rread{ - Data: []byte{'a'}, - }, - &Twrite{ - FID: 1, - Offset: 2, - Data: []byte{'a'}, - }, - &Rwrite{ - Count: 1, - }, - &Tclunk{ - FID: 1, - }, - &Rclunk{}, - &Tremove{ - FID: 1, - }, - &Rremove{}, - &Tflushf{ - FID: 1, - }, - &Rflushf{}, - &Twalkgetattr{ - FID: 1, - NewFID: 2, - Names: []string{"a"}, - }, - &Rwalkgetattr{ - QIDs: []QID{{Type: 1}}, - Valid: AttrMask{Mode: true}, - Attr: Attr{Mode: Write}, - }, - &Tucreate{ - Tlcreate: Tlcreate{ - FID: 1, - Name: "a", - OpenFlags: 2, - Permissions: 3, - GID: 4, - }, - UID: 5, - }, - &Rucreate{ - Rlcreate{Rlopen{QID: QID{Type: 1}}}, - }, - &Tumkdir{ - Tmkdir: Tmkdir{ - Directory: 1, - Name: "a", - Permissions: 2, - GID: 3, - }, - UID: 4, - }, - &Rumkdir{ - Rmkdir{QID: QID{Type: 1}}, - }, - &Tusymlink{ - Tsymlink: Tsymlink{ - Directory: 1, - Name: "a", - Target: "b", - GID: 2, - }, - UID: 3, - }, - &Rusymlink{ - Rsymlink{QID: QID{Type: 1}}, - }, - &Tumknod{ - Tmknod: Tmknod{ - Directory: 1, - Name: "a", - Mode: 2, - Major: 3, - Minor: 4, - GID: 5, - }, - UID: 6, - }, - &Rumknod{ - Rmknod{QID: QID{Type: 1}}, - }, - &Tsetattrclunk{ - FID: 1, - Valid: SetAttrMask{ - Permissions: true, - UID: true, - GID: true, - Size: true, - ATime: true, - MTime: true, - CTime: true, - ATimeNotSystemTime: true, - MTimeNotSystemTime: true, - }, - SetAttr: SetAttr{ - Permissions: 1, - UID: 2, - GID: 3, - Size: 4, - ATimeSeconds: 5, - ATimeNanoSeconds: 6, - MTimeSeconds: 7, - MTimeNanoSeconds: 8, - }, - }, - } - - for _, enc := range objs { - // Encode the original. - data := make([]byte, initialBufferLength) - buf := buffer{data: data[:0]} - enc.encode(&buf) - - // Create a new object, same as the first. - enc2 := reflect.New(reflect.ValueOf(enc).Elem().Type()).Interface().(encoder) - buf2 := buffer{data: buf.data} - - // To be fair, we need to add any payloads (directly). - if pl, ok := enc.(payloader); ok { - enc2.(payloader).SetPayload(pl.Payload()) - } - - // And any file payloads (directly). - if fl, ok := enc.(filer); ok { - enc2.(filer).SetFilePayload(fl.FilePayload()) - } - - // Mark sure it was okay. - enc2.decode(&buf2) - if buf2.isOverrun() { - t.Errorf("object %#v->%#v got overrun on decode", enc, enc2) - continue - } - - // Check that they are equal. - if !reflect.DeepEqual(enc, enc2) { - t.Errorf("object %#v and %#v differ", enc, enc2) - continue - } - } -} - -func TestMessageStrings(t *testing.T) { - for typ := range msgRegistry.factories { - entry := &msgRegistry.factories[typ] - if entry.create != nil { - name := fmt.Sprintf("%+v", typ) - t.Run(name, func(t *testing.T) { - defer func() { // Ensure no panic. - if r := recover(); r != nil { - t.Errorf("printing %s failed: %v", name, r) - } - }() - m := entry.create() - _ = fmt.Sprintf("%v", m) - err := ErrInvalidMsgType{MsgType(typ)} - _ = err.Error() - }) - } - } -} - -func TestRegisterDuplicate(t *testing.T) { - defer func() { - if r := recover(); r == nil { - // We expect a panic. - t.FailNow() - } - }() - - // Register a duplicate. - msgRegistry.register(MsgRlerror, func() message { return &Rlerror{} }) -} - -func TestMsgCache(t *testing.T) { - // Cache starts empty. - if got, want := len(msgRegistry.factories[MsgRlerror].cache), 0; got != want { - t.Errorf("Wrong cache size, got: %d, want: %d", got, want) - } - - // Message can be created with an empty cache. - msg, err := msgRegistry.get(0, MsgRlerror) - if err != nil { - t.Errorf("msgRegistry.get(): %v", err) - } - if got, want := len(msgRegistry.factories[MsgRlerror].cache), 0; got != want { - t.Errorf("Wrong cache size, got: %d, want: %d", got, want) - } - - // Check that message is added to the cache when returned. - msgRegistry.put(msg) - if got, want := len(msgRegistry.factories[MsgRlerror].cache), 1; got != want { - t.Errorf("Wrong cache size, got: %d, want: %d", got, want) - } - - // Check that returned message is reused. - if got, err := msgRegistry.get(0, MsgRlerror); err != nil { - t.Errorf("msgRegistry.get(): %v", err) - } else if msg != got { - t.Errorf("Message not reused, got: %d, want: %d", got, msg) - } - - // Check that cache doesn't grow beyond max size. - for i := 0; i < maxCacheSize+1; i++ { - msgRegistry.put(&Rlerror{}) - } - if got, want := len(msgRegistry.factories[MsgRlerror].cache), maxCacheSize; got != want { - t.Errorf("Wrong cache size, got: %d, want: %d", got, want) - } -} diff --git a/pkg/p9/p9_state_autogen.go b/pkg/p9/p9_state_autogen.go new file mode 100644 index 000000000..bc9b1bd57 --- /dev/null +++ b/pkg/p9/p9_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package p9 diff --git a/pkg/p9/p9_test.go b/pkg/p9/p9_test.go deleted file mode 100644 index 8dda6cc64..000000000 --- a/pkg/p9/p9_test.go +++ /dev/null @@ -1,188 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package p9 - -import ( - "os" - "testing" -) - -func TestFileModeHelpers(t *testing.T) { - fns := map[FileMode]struct { - // name identifies the file mode. - name string - - // function is the function that should return true given the - // right FileMode. - function func(m FileMode) bool - }{ - ModeRegular: { - name: "regular", - function: FileMode.IsRegular, - }, - ModeDirectory: { - name: "directory", - function: FileMode.IsDir, - }, - ModeNamedPipe: { - name: "named pipe", - function: FileMode.IsNamedPipe, - }, - ModeCharacterDevice: { - name: "character device", - function: FileMode.IsCharacterDevice, - }, - ModeBlockDevice: { - name: "block device", - function: FileMode.IsBlockDevice, - }, - ModeSymlink: { - name: "symlink", - function: FileMode.IsSymlink, - }, - ModeSocket: { - name: "socket", - function: FileMode.IsSocket, - }, - } - for mode, info := range fns { - // Make sure the mode doesn't identify as anything but itself. - for testMode, testfns := range fns { - if mode != testMode && testfns.function(mode) { - t.Errorf("Mode %s returned true when asked if it was mode %s", info.name, testfns.name) - } - } - - // Make sure mode identifies as itself. - if !info.function(mode) { - t.Errorf("Mode %s returned false when asked if it was itself", info.name) - } - } -} - -func TestFileModeToQID(t *testing.T) { - for _, test := range []struct { - // name identifies the test. - name string - - // mode is the FileMode we start out with. - mode FileMode - - // want is the corresponding QIDType we expect. - want QIDType - }{ - { - name: "Directories are of type directory", - mode: ModeDirectory, - want: TypeDir, - }, - { - name: "Sockets are append-only files", - mode: ModeSocket, - want: TypeAppendOnly, - }, - { - name: "Named pipes are append-only files", - mode: ModeNamedPipe, - want: TypeAppendOnly, - }, - { - name: "Character devices are append-only files", - mode: ModeCharacterDevice, - want: TypeAppendOnly, - }, - { - name: "Symlinks are of type symlink", - mode: ModeSymlink, - want: TypeSymlink, - }, - { - name: "Regular files are of type regular", - mode: ModeRegular, - want: TypeRegular, - }, - { - name: "Block devices are regular files", - mode: ModeBlockDevice, - want: TypeRegular, - }, - } { - if qidType := test.mode.QIDType(); qidType != test.want { - t.Errorf("ModeToQID test %s failed: got %o, wanted %o", test.name, qidType, test.want) - } - } -} - -func TestP9ModeConverters(t *testing.T) { - for _, m := range []FileMode{ - ModeRegular, - ModeDirectory, - ModeCharacterDevice, - ModeBlockDevice, - ModeSocket, - ModeSymlink, - ModeNamedPipe, - } { - if mb := ModeFromOS(m.OSMode()); mb != m { - t.Errorf("Converting %o to OS.FileMode gives %o and is converted back as %o", m, m.OSMode(), mb) - } - } -} - -func TestOSModeConverters(t *testing.T) { - // Modes that can be converted back and forth. - for _, m := range []os.FileMode{ - 0, // Regular file. - os.ModeDir, - os.ModeCharDevice | os.ModeDevice, - os.ModeDevice, - os.ModeSocket, - os.ModeSymlink, - os.ModeNamedPipe, - } { - if mb := ModeFromOS(m).OSMode(); mb != m { - t.Errorf("Converting %o to p9.FileMode gives %o and is converted back as %o", m, ModeFromOS(m), mb) - } - } - - // Modes that will be converted to a regular file since p9 cannot - // express these. - for _, m := range []os.FileMode{ - os.ModeAppend, - os.ModeExclusive, - os.ModeTemporary, - } { - if p9Mode := ModeFromOS(m); p9Mode != ModeRegular { - t.Errorf("Converting %o to p9.FileMode should have given ModeRegular, but yielded %o", m, p9Mode) - } - } -} - -func TestAttrMaskContains(t *testing.T) { - req := AttrMask{Mode: true, Size: true} - have := AttrMask{} - if have.Contains(req) { - t.Fatalf("AttrMask %v should not be a superset of %v", have, req) - } - have.Mode = true - if have.Contains(req) { - t.Fatalf("AttrMask %v should not be a superset of %v", have, req) - } - have.Size = true - have.MTime = true - if !have.Contains(req) { - t.Fatalf("AttrMask %v should be a superset of %v", have, req) - } -} diff --git a/pkg/p9/p9test/BUILD b/pkg/p9/p9test/BUILD deleted file mode 100644 index 9c1ada0cb..000000000 --- a/pkg/p9/p9test/BUILD +++ /dev/null @@ -1,90 +0,0 @@ -load("//tools:defs.bzl", "go_binary", "go_library", "go_test") - -package(licenses = ["notice"]) - -alias( - name = "mockgen", - actual = "@com_github_golang_mock//mockgen:mockgen", -) - -MOCK_SRC_PACKAGE = "gvisor.dev/gvisor/pkg/p9" - -# mockgen_reflect is a source file that contains mock generation code that -# imports the p9 package and generates a specification via reflection. The -# usual generation path must be split into two distinct parts because the full -# source tree is not available to all build targets. Only declared depencies -# are available (and even then, not the Go source files). -genrule( - name = "mockgen_reflect", - testonly = 1, - outs = ["mockgen_reflect.go"], - cmd = ( - "$(location :mockgen) " + - "-package p9test " + - "-prog_only " + MOCK_SRC_PACKAGE + " " + - "Attacher,File > $@" - ), - tools = [":mockgen"], -) - -# mockgen_exec is the binary that includes the above reflection generator. -# Running this binary will emit an encoded version of the p9 Attacher and File -# structures. This is consumed by the mocks genrule, below. -go_binary( - name = "mockgen_exec", - testonly = 1, - srcs = ["mockgen_reflect.go"], - deps = [ - "//pkg/p9", - "@com_github_golang_mock//mockgen/model:go_default_library", - ], -) - -# mocks consumes the encoded output above, and generates the full source for a -# set of mocks. These are included directly in the p9test library. -genrule( - name = "mocks", - testonly = 1, - outs = ["mocks.go"], - cmd = ( - "$(location :mockgen) " + - "-package p9test " + - "-exec_only $(location :mockgen_exec) " + MOCK_SRC_PACKAGE + " File > $@" - ), - tools = [ - ":mockgen", - ":mockgen_exec", - ], -) - -go_library( - name = "p9test", - srcs = [ - "mocks.go", - "p9test.go", - ], - visibility = ["//:sandbox"], - deps = [ - "//pkg/fd", - "//pkg/log", - "//pkg/p9", - "//pkg/sync", - "//pkg/unet", - "@com_github_golang_mock//gomock:go_default_library", - "@org_golang_x_sys//unix:go_default_library", - ], -) - -go_test( - name = "client_test", - size = "medium", - srcs = ["client_test.go"], - library = ":p9test", - deps = [ - "//pkg/fd", - "//pkg/p9", - "//pkg/sync", - "@com_github_golang_mock//gomock:go_default_library", - "@org_golang_x_sys//unix:go_default_library", - ], -) diff --git a/pkg/p9/p9test/client_test.go b/pkg/p9/p9test/client_test.go deleted file mode 100644 index bb77e8e5f..000000000 --- a/pkg/p9/p9test/client_test.go +++ /dev/null @@ -1,2252 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package p9test - -import ( - "bytes" - "fmt" - "io" - "math/rand" - "os" - "reflect" - "strings" - "testing" - "time" - - "github.com/golang/mock/gomock" - "golang.org/x/sys/unix" - "gvisor.dev/gvisor/pkg/fd" - "gvisor.dev/gvisor/pkg/p9" - "gvisor.dev/gvisor/pkg/sync" -) - -func TestPanic(t *testing.T) { - h, c := NewHarness(t) - defer h.Finish() - - // Create a new root. - d := h.NewDirectory(nil)(nil) - defer d.Close() // Needed manually. - h.Attacher.EXPECT().Attach().Return(d, nil).Do(func() { - // Panic here, and ensure that we get back EFAULT. - panic("handler") - }) - - // Attach to the client. - if _, err := c.Attach("/"); err != unix.EFAULT { - t.Fatalf("got attach err %v, want EFAULT", err) - } -} - -func TestAttachNoLeak(t *testing.T) { - h, c := NewHarness(t) - defer h.Finish() - - // Create a new root. - d := h.NewDirectory(nil)(nil) - h.Attacher.EXPECT().Attach().Return(d, nil).Times(1) - - // Attach to the client. - f, err := c.Attach("/") - if err != nil { - t.Fatalf("got attach err %v, want nil", err) - } - - // Don't close the file. This should be closed automatically when the - // client disconnects. The mock asserts that everything is closed - // exactly once. This statement just removes the unused variable error. - _ = f -} - -func TestBadAttach(t *testing.T) { - h, c := NewHarness(t) - defer h.Finish() - - // Return an error on attach. - h.Attacher.EXPECT().Attach().Return(nil, unix.EINVAL).Times(1) - - // Attach to the client. - if _, err := c.Attach("/"); err != unix.EINVAL { - t.Fatalf("got attach err %v, want unix.EINVAL", err) - } -} - -func TestWalkAttach(t *testing.T) { - h, c := NewHarness(t) - defer h.Finish() - - // Create a new root. - d := h.NewDirectory(map[string]Generator{ - "a": h.NewDirectory(map[string]Generator{ - "b": h.NewFile(), - }), - })(nil) - h.Attacher.EXPECT().Attach().Return(d, nil).Times(1) - - // Attach to the client as a non-root, and ensure that the walk above - // occurs as expected. We should get back b, and all references should - // be dropped when the file is closed. - f, err := c.Attach("/a/b") - if err != nil { - t.Fatalf("got attach err %v, want nil", err) - } - defer f.Close() - - // Check that's a regular file. - if _, _, attr, err := f.GetAttr(p9.AttrMaskAll()); err != nil { - t.Errorf("got err %v, want nil", err) - } else if !attr.Mode.IsRegular() { - t.Errorf("got mode %v, want regular file", err) - } -} - -// newTypeMap returns a new type map dictionary. -func newTypeMap(h *Harness) map[string]Generator { - return map[string]Generator{ - "directory": h.NewDirectory(map[string]Generator{}), - "file": h.NewFile(), - "symlink": h.NewSymlink(), - "block-device": h.NewBlockDevice(), - "character-device": h.NewCharacterDevice(), - "named-pipe": h.NewNamedPipe(), - "socket": h.NewSocket(), - } -} - -// newRoot returns a new root filesystem. -// -// This is set up in a deterministic way for testing most operations. -// -// The represented file system looks like: -// - file -// - symlink -// - directory -// ... -// + one -// - file -// - symlink -// - directory -// ... -// + two -// - file -// - symlink -// - directory -// ... -// + three -// - file -// - symlink -// - directory -// ... -func newRoot(h *Harness, c *p9.Client) (*Mock, p9.File) { - root := newTypeMap(h) - one := newTypeMap(h) - two := newTypeMap(h) - three := newTypeMap(h) - one["two"] = h.NewDirectory(two) // Will be nested in one. - root["one"] = h.NewDirectory(one) // Top level. - root["three"] = h.NewDirectory(three) // Alternate top-level. - - // Create a new root. - rootBackend := h.NewDirectory(root)(nil) - h.Attacher.EXPECT().Attach().Return(rootBackend, nil) - - // Attach to the client. - r, err := c.Attach("/") - if err != nil { - h.t.Fatalf("got attach err %v, want nil", err) - } - - return rootBackend, r -} - -func allInvalidNames(from string) []string { - return []string{ - from + "/other", - from + "/..", - from + "/.", - from + "/", - "other/" + from, - "/" + from, - "./" + from, - "../" + from, - ".", - "..", - "/", - "", - } -} - -func TestWalkInvalid(t *testing.T) { - h, c := NewHarness(t) - defer h.Finish() - - _, root := newRoot(h, c) - defer root.Close() - - // Run relevant tests. - for name := range newTypeMap(h) { - // These are all the various ways that one might attempt to - // construct compound paths. They should all be rejected, as - // any compound that contains a / is not allowed, as well as - // the singular paths of '.' and '..'. - if _, _, err := root.Walk([]string{".", name}); err != unix.EINVAL { - t.Errorf("Walk through . %s wanted EINVAL, got %v", name, err) - } - if _, _, err := root.Walk([]string{"..", name}); err != unix.EINVAL { - t.Errorf("Walk through . %s wanted EINVAL, got %v", name, err) - } - if _, _, err := root.Walk([]string{name, "."}); err != unix.EINVAL { - t.Errorf("Walk through %s . wanted EINVAL, got %v", name, err) - } - if _, _, err := root.Walk([]string{name, ".."}); err != unix.EINVAL { - t.Errorf("Walk through %s .. wanted EINVAL, got %v", name, err) - } - for _, invalidName := range allInvalidNames(name) { - if _, _, err := root.Walk([]string{invalidName}); err != unix.EINVAL { - t.Errorf("Walk through %s wanted EINVAL, got %v", invalidName, err) - } - } - wantErr := unix.EINVAL - if name == "directory" { - // We can attempt a walk through a directory. However, - // we should never see a file named "other", so we - // expect this to return ENOENT. - wantErr = unix.ENOENT - } - if _, _, err := root.Walk([]string{name, "other"}); err != wantErr { - t.Errorf("Walk through %s/other wanted %v, got %v", name, wantErr, err) - } - - // Do a successful walk. - _, f, err := root.Walk([]string{name}) - if err != nil { - t.Errorf("Walk to %s wanted nil, got %v", name, err) - } - defer f.Close() - local := h.Pop(f) - - // Check that the file matches. - _, localMask, localAttr, localErr := local.GetAttr(p9.AttrMaskAll()) - if _, mask, attr, err := f.GetAttr(p9.AttrMaskAll()); mask != localMask || attr != localAttr || err != localErr { - t.Errorf("GetAttr got (%v, %v, %v), wanted (%v, %v, %v)", - mask, attr, err, localMask, localAttr, localErr) - } - - // Ensure we can't walk backwards. - if _, _, err := f.Walk([]string{"."}); err != unix.EINVAL { - t.Errorf("Walk through %s/. wanted EINVAL, got %v", name, err) - } - if _, _, err := f.Walk([]string{".."}); err != unix.EINVAL { - t.Errorf("Walk through %s/.. wanted EINVAL, got %v", name, err) - } - } -} - -// fileGenerator is a function to generate files via walk or create. -// -// Examples are: -// - walkHelper -// - walkAndOpenHelper -// - createHelper -type fileGenerator func(*Harness, string, p9.File) (*Mock, *Mock, p9.File) - -// walkHelper walks to the given file. -// -// The backends of the parent and walked file are returned, as well as the -// walked client file. -func walkHelper(h *Harness, name string, dir p9.File) (parentBackend *Mock, walkedBackend *Mock, walked p9.File) { - _, parent, err := dir.Walk(nil) - if err != nil { - h.t.Fatalf("Walk(nil) got err %v, want nil", err) - } - defer parent.Close() - parentBackend = h.Pop(parent) - - _, walked, err = parent.Walk([]string{name}) - if err != nil { - h.t.Fatalf("Walk(%s) got err %v, want nil", name, err) - } - walkedBackend = h.Pop(walked) - - return parentBackend, walkedBackend, walked -} - -// walkAndOpenHelper additionally opens the walked file, if possible. -func walkAndOpenHelper(h *Harness, name string, dir p9.File) (*Mock, *Mock, p9.File) { - parentBackend, walkedBackend, walked := walkHelper(h, name, dir) - if p9.CanOpen(walkedBackend.Attr.Mode) { - // Open for all file types that we can. We stick to a read-only - // open here because directories may not be opened otherwise. - walkedBackend.EXPECT().Open(p9.ReadOnly).Times(1) - if _, _, _, err := walked.Open(p9.ReadOnly); err != nil { - h.t.Errorf("got open err %v, want nil", err) - } - } else { - // ... or assert an error for others. - if _, _, _, err := walked.Open(p9.ReadOnly); err != unix.EINVAL { - h.t.Errorf("got open err %v, want EINVAL", err) - } - } - return parentBackend, walkedBackend, walked -} - -// createHelper creates the given file and returns the parent directory, -// created file and client file, which must be closed when done. -func createHelper(h *Harness, name string, dir p9.File) (*Mock, *Mock, p9.File) { - // Clone the directory first, since Create replaces the existing file. - // We change the type after calling create. - _, dirThenFile, err := dir.Walk(nil) - if err != nil { - h.t.Fatalf("got walk err %v, want nil", err) - } - - // Create a new server-side file. On the server-side, the a new file is - // returned from a create call. The client will reuse the same file, - // but we still expect the normal chain of closes. This complicates - // things a bit because the "parent" will always chain to the cloned - // dir above. - dirBackend := h.Pop(dirThenFile) // New backend directory. - newFile := h.NewFile()(dirBackend) // New file with backend parent. - dirBackend.EXPECT().Create(name, gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, newFile, newFile.QID, uint32(0), nil) - - // Create via the client. - _, dirThenFile, _, _, err = dirThenFile.Create(name, p9.ReadOnly, 0, 0, 0) - if err != nil { - h.t.Fatalf("got create err %v, want nil", err) - } - - // Ensure subsequent walks succeed. - dirBackend.AddChild(name, h.NewFile()) - return dirBackend, newFile, dirThenFile -} - -// deprecatedRemover allows us to access the deprecated Remove operation within -// the p9.File client object. -type deprecatedRemover interface { - Remove() error -} - -// checkDeleted asserts that relevant methods fail for an unlinked file. -// -// This function will close the file at the end. -func checkDeleted(h *Harness, file p9.File) { - defer file.Close() // See doc. - - if _, _, _, err := file.Open(p9.ReadOnly); err != unix.EINVAL { - h.t.Errorf("open while deleted, got %v, want EINVAL", err) - } - if _, _, _, _, err := file.Create("created", p9.ReadOnly, 0, 0, 0); err != unix.EINVAL { - h.t.Errorf("create while deleted, got %v, want EINVAL", err) - } - if _, err := file.Symlink("old", "new", 0, 0); err != unix.EINVAL { - h.t.Errorf("symlink while deleted, got %v, want EINVAL", err) - } - // N.B. This link is technically invalid, but if a call to link is - // actually made in the backend then the mock will panic. - if err := file.Link(file, "new"); err != unix.EINVAL { - h.t.Errorf("link while deleted, got %v, want EINVAL", err) - } - if err := file.RenameAt("src", file, "dst"); err != unix.EINVAL { - h.t.Errorf("renameAt while deleted, got %v, want EINVAL", err) - } - if err := file.UnlinkAt("file", 0); err != unix.EINVAL { - h.t.Errorf("unlinkAt while deleted, got %v, want EINVAL", err) - } - if err := file.Rename(file, "dst"); err != unix.EINVAL { - h.t.Errorf("rename while deleted, got %v, want EINVAL", err) - } - if _, err := file.Readlink(); err != unix.EINVAL { - h.t.Errorf("readlink while deleted, got %v, want EINVAL", err) - } - if _, err := file.Mkdir("dir", p9.ModeDirectory, 0, 0); err != unix.EINVAL { - h.t.Errorf("mkdir while deleted, got %v, want EINVAL", err) - } - if _, err := file.Mknod("dir", p9.ModeDirectory, 0, 0, 0, 0); err != unix.EINVAL { - h.t.Errorf("mknod while deleted, got %v, want EINVAL", err) - } - if _, err := file.Readdir(0, 1); err != unix.EINVAL { - h.t.Errorf("readdir while deleted, got %v, want EINVAL", err) - } - if _, err := file.Connect(p9.ConnectFlags(0)); err != unix.EINVAL { - h.t.Errorf("connect while deleted, got %v, want EINVAL", err) - } - - // The remove method is technically deprecated, but we want to ensure - // that it still checks for deleted appropriately. We must first clone - // the file because remove is equivalent to close. - _, newFile, err := file.Walk(nil) - if err == unix.EBUSY { - // We can't walk from here because this reference is open - // already. Okay, we will also have unopened cases through - // TestUnlink, just skip the remove operation for now. - return - } else if err != nil { - h.t.Fatalf("clone failed, got %v, want nil", err) - } - if err := newFile.(deprecatedRemover).Remove(); err != unix.EINVAL { - h.t.Errorf("remove while deleted, got %v, want EINVAL", err) - } -} - -// deleter is a function to remove a file. -type deleter func(parent p9.File, name string) error - -// unlinkAt is a deleter. -func unlinkAt(parent p9.File, name string) error { - // Call unlink. Note that a filesystem may normally impose additional - // constaints on unlinkat success, such as ensuring that a directory is - // empty, requiring AT_REMOVEDIR in flags to remove a directory, etc. - // None of that is required internally (entire trees can be marked - // deleted when this operation succeeds), so the mock will succeed. - return parent.UnlinkAt(name, 0) -} - -// remove is a deleter. -func remove(parent p9.File, name string) error { - // See notes above re: remove. - _, newFile, err := parent.Walk([]string{name}) - if err != nil { - // Should not be expected. - return err - } - - // Do the actual remove. - if err := newFile.(deprecatedRemover).Remove(); err != nil { - return err - } - - // Ensure that the remove closed the file. - if err := newFile.(deprecatedRemover).Remove(); err != unix.EBADF { - return unix.EBADF // Propagate this code. - } - - return nil -} - -// unlinkHelper unlinks the noted path, and ensures that all relevant -// operations on that path, acquired from multiple paths, start failing. -func unlinkHelper(h *Harness, root p9.File, targetNames []string, targetGen fileGenerator, deleteFn deleter) { - // name is the file to be unlinked. - name := targetNames[len(targetNames)-1] - - // Walk to the directory containing the target. - _, parent, err := root.Walk(targetNames[:len(targetNames)-1]) - if err != nil { - h.t.Fatalf("got walk err %v, want nil", err) - } - defer parent.Close() - parentBackend := h.Pop(parent) - - // Walk to or generate the target file. - _, _, target := targetGen(h, name, parent) - defer checkDeleted(h, target) - - // Walk to a second reference. - _, second, err := parent.Walk([]string{name}) - if err != nil { - h.t.Fatalf("got walk err %v, want nil", err) - } - defer checkDeleted(h, second) - - // Walk to a third reference, from the start. - _, third, err := root.Walk(targetNames) - if err != nil { - h.t.Fatalf("got walk err %v, want nil", err) - } - defer checkDeleted(h, third) - - // This will be translated in the backend to an unlinkat. - parentBackend.EXPECT().UnlinkAt(name, uint32(0)).Return(nil) - - // Actually perform the deletion. - if err := deleteFn(parent, name); err != nil { - h.t.Fatalf("got delete err %v, want nil", err) - } -} - -func unlinkTest(t *testing.T, targetNames []string, targetGen fileGenerator) { - t.Run(fmt.Sprintf("unlinkAt(%s)", strings.Join(targetNames, "/")), func(t *testing.T) { - h, c := NewHarness(t) - defer h.Finish() - - _, root := newRoot(h, c) - defer root.Close() - - unlinkHelper(h, root, targetNames, targetGen, unlinkAt) - }) - t.Run(fmt.Sprintf("remove(%s)", strings.Join(targetNames, "/")), func(t *testing.T) { - h, c := NewHarness(t) - defer h.Finish() - - _, root := newRoot(h, c) - defer root.Close() - - unlinkHelper(h, root, targetNames, targetGen, remove) - }) -} - -func TestUnlink(t *testing.T) { - // Unlink all files. - for name := range newTypeMap(nil) { - unlinkTest(t, []string{name}, walkHelper) - unlinkTest(t, []string{name}, walkAndOpenHelper) - unlinkTest(t, []string{"one", name}, walkHelper) - unlinkTest(t, []string{"one", name}, walkAndOpenHelper) - unlinkTest(t, []string{"one", "two", name}, walkHelper) - unlinkTest(t, []string{"one", "two", name}, walkAndOpenHelper) - } - - // Unlink a directory. - unlinkTest(t, []string{"one"}, walkHelper) - unlinkTest(t, []string{"one"}, walkAndOpenHelper) - unlinkTest(t, []string{"one", "two"}, walkHelper) - unlinkTest(t, []string{"one", "two"}, walkAndOpenHelper) - - // Unlink created files. - unlinkTest(t, []string{"created"}, createHelper) - unlinkTest(t, []string{"one", "created"}, createHelper) - unlinkTest(t, []string{"one", "two", "created"}, createHelper) -} - -func TestUnlinkAtInvalid(t *testing.T) { - h, c := NewHarness(t) - defer h.Finish() - - _, root := newRoot(h, c) - defer root.Close() - - for name := range newTypeMap(nil) { - for _, invalidName := range allInvalidNames(name) { - if err := root.UnlinkAt(invalidName, 0); err != unix.EINVAL { - t.Errorf("got %v for name %q, want EINVAL", err, invalidName) - } - } - } -} - -// expectRenamed asserts an ordered sequence of rename calls, based on all the -// elements in elements being the source, and the first element therein -// changing to dstName, parented at dstParent. -func expectRenamed(file *Mock, elements []string, dstParent *Mock, dstName string) *gomock.Call { - if len(elements) > 0 { - // Recurse to the parent, if necessary. - call := expectRenamed(file.parent, elements[:len(elements)-1], dstParent, dstName) - - // Recursive case: this element is unchanged, but should have - // it's hook called after the parent. - return file.EXPECT().Renamed(file.parent, elements[len(elements)-1]).Do(func(p p9.File, _ string) { - file.parent = p.(*Mock) - }).After(call) - } - - // Base case: this is the changed element. - return file.EXPECT().Renamed(dstParent, dstName).Do(func(p p9.File, name string) { - file.parent = p.(*Mock) - }) -} - -// renamer is a rename function. -type renamer func(h *Harness, srcParent, dstParent p9.File, origName, newName string, selfRename bool) error - -// renameAt is a renamer. -func renameAt(_ *Harness, srcParent, dstParent p9.File, srcName, dstName string, selfRename bool) error { - return srcParent.RenameAt(srcName, dstParent, dstName) -} - -// rename is a renamer. -func rename(h *Harness, srcParent, dstParent p9.File, srcName, dstName string, selfRename bool) error { - _, f, err := srcParent.Walk([]string{srcName}) - if err != nil { - return err - } - defer f.Close() - if !selfRename { - backend := h.Pop(f) - backend.EXPECT().Renamed(gomock.Any(), dstName).Do(func(p p9.File, name string) { - backend.parent = p.(*Mock) // Required for close ordering. - }) - } - return f.Rename(dstParent, dstName) -} - -// renameHelper executes a rename, and asserts that all relevant elements -// receive expected notifications. If overwriting a file, this includes -// ensuring that the target has been appropriately marked as unlinked. -func renameHelper(h *Harness, root p9.File, srcNames []string, dstNames []string, target fileGenerator, renameFn renamer) { - // Walk to the directory containing the target. - srcQID, targetParent, err := root.Walk(srcNames[:len(srcNames)-1]) - if err != nil { - h.t.Fatalf("got walk err %v, want nil", err) - } - defer targetParent.Close() - targetParentBackend := h.Pop(targetParent) - - // Walk to or generate the target file. - _, targetBackend, src := target(h, srcNames[len(srcNames)-1], targetParent) - defer src.Close() - - // Walk to a second reference. - _, second, err := targetParent.Walk([]string{srcNames[len(srcNames)-1]}) - if err != nil { - h.t.Fatalf("got walk err %v, want nil", err) - } - defer second.Close() - secondBackend := h.Pop(second) - - // Walk to a third reference, from the start. - _, third, err := root.Walk(srcNames) - if err != nil { - h.t.Fatalf("got walk err %v, want nil", err) - } - defer third.Close() - thirdBackend := h.Pop(third) - - // Find the common suffix to identify the rename parent. - var ( - renameDestPath []string - renameSrcPath []string - selfRename bool - ) - for i := 1; i <= len(srcNames) && i <= len(dstNames); i++ { - if srcNames[len(srcNames)-i] != dstNames[len(dstNames)-i] { - // Take the full prefix of dstNames up until this - // point, including the first mismatched name. The - // first mismatch must be the renamed entry. - renameDestPath = dstNames[:len(dstNames)-i+1] - renameSrcPath = srcNames[:len(srcNames)-i+1] - - // Does the renameDestPath fully contain the - // renameSrcPath here? If yes, then this is a mismatch. - // We can't rename the src to some subpath of itself. - if len(renameDestPath) > len(renameSrcPath) && - reflect.DeepEqual(renameDestPath[:len(renameSrcPath)], renameSrcPath) { - renameDestPath = nil - renameSrcPath = nil - continue - } - break - } - } - if len(renameSrcPath) == 0 || len(renameDestPath) == 0 { - // This must be a rename to self, or a tricky look-alike. This - // happens iff we fail to find a suitable divergence in the two - // paths. It's a true self move if the path length is the same. - renameDestPath = dstNames - renameSrcPath = srcNames - selfRename = len(srcNames) == len(dstNames) - } - - // Walk to the source parent. - _, srcParent, err := root.Walk(renameSrcPath[:len(renameSrcPath)-1]) - if err != nil { - h.t.Fatalf("got walk err %v, want nil", err) - } - defer srcParent.Close() - srcParentBackend := h.Pop(srcParent) - - // Walk to the destination parent. - _, dstParent, err := root.Walk(renameDestPath[:len(renameDestPath)-1]) - if err != nil { - h.t.Fatalf("got walk err %v, want nil", err) - } - defer dstParent.Close() - dstParentBackend := h.Pop(dstParent) - - // expectedErr is the result of the rename operation. - var expectedErr error - - // Walk to the target file, if one exists. - dstQID, dst, err := root.Walk(renameDestPath) - if err == nil { - if !selfRename && srcQID[0].Type == dstQID[0].Type { - // If there is a destination file, and is it of the - // same type as the source file, then we expect the - // rename to succeed. We expect the destination file to - // be deleted, so we run a deletion test on it in this - // case. - defer checkDeleted(h, dst) - } else { - // If the type is different than the destination, then - // we expect the rename to fail. We expect that this - // is returned. - // - // If the file being renamed to itself, this is - // technically allowed and a no-op, but all the - // triggers will fire. - if !selfRename { - expectedErr = unix.EINVAL - } - dst.Close() - } - } - dstName := renameDestPath[len(renameDestPath)-1] // Renamed element. - srcName := renameSrcPath[len(renameSrcPath)-1] // Renamed element. - if expectedErr == nil && !selfRename { - // Expect all to be renamed appropriately. Note that if this is - // a final file being renamed, then we expect the file to be - // called with the new parent. If not, then we expect the - // rename hook to be called, but the parent will remain - // unchanged. - elements := srcNames[len(renameSrcPath):] - expectRenamed(targetBackend, elements, dstParentBackend, dstName) - expectRenamed(secondBackend, elements, dstParentBackend, dstName) - expectRenamed(thirdBackend, elements, dstParentBackend, dstName) - - // The target parent has also been opened, and may be moved - // directly or indirectly. - if len(elements) > 1 { - expectRenamed(targetParentBackend, elements[:len(elements)-1], dstParentBackend, dstName) - } - } - - // Expect the rename if it's not the same file. Note that like unlink, - // renames are always translated to the at variant in the backend. - if !selfRename { - srcParentBackend.EXPECT().RenameAt(srcName, dstParentBackend, dstName).Return(expectedErr) - } - - // Perform the actual rename; everything has been lined up. - if err := renameFn(h, srcParent, dstParent, srcName, dstName, selfRename); err != expectedErr { - h.t.Fatalf("got rename err %v, want %v", err, expectedErr) - } -} - -func renameTest(t *testing.T, srcNames []string, dstNames []string, target fileGenerator) { - t.Run(fmt.Sprintf("renameAt(%s->%s)", strings.Join(srcNames, "/"), strings.Join(dstNames, "/")), func(t *testing.T) { - h, c := NewHarness(t) - defer h.Finish() - - _, root := newRoot(h, c) - defer root.Close() - - renameHelper(h, root, srcNames, dstNames, target, renameAt) - }) - t.Run(fmt.Sprintf("rename(%s->%s)", strings.Join(srcNames, "/"), strings.Join(dstNames, "/")), func(t *testing.T) { - h, c := NewHarness(t) - defer h.Finish() - - _, root := newRoot(h, c) - defer root.Close() - - renameHelper(h, root, srcNames, dstNames, target, rename) - }) -} - -func TestRename(t *testing.T) { - // In-directory rename, simple case. - for name := range newTypeMap(nil) { - // Within the root. - renameTest(t, []string{name}, []string{"renamed"}, walkHelper) - renameTest(t, []string{name}, []string{"renamed"}, walkAndOpenHelper) - - // Within a subdirectory. - renameTest(t, []string{"one", name}, []string{"one", "renamed"}, walkHelper) - renameTest(t, []string{"one", name}, []string{"one", "renamed"}, walkAndOpenHelper) - } - - // ... with created files. - renameTest(t, []string{"created"}, []string{"renamed"}, createHelper) - renameTest(t, []string{"one", "created"}, []string{"one", "renamed"}, createHelper) - - // Across directories. - for name := range newTypeMap(nil) { - // Down one level. - renameTest(t, []string{"one", name}, []string{"one", "two", "renamed"}, walkHelper) - renameTest(t, []string{"one", name}, []string{"one", "two", "renamed"}, walkAndOpenHelper) - - // Up one level. - renameTest(t, []string{"one", "two", name}, []string{"one", "renamed"}, walkHelper) - renameTest(t, []string{"one", "two", name}, []string{"one", "renamed"}, walkAndOpenHelper) - - // Across at the same level. - renameTest(t, []string{"one", name}, []string{"three", "renamed"}, walkHelper) - renameTest(t, []string{"one", name}, []string{"three", "renamed"}, walkAndOpenHelper) - } - - // ... with created files. - renameTest(t, []string{"one", "created"}, []string{"one", "two", "renamed"}, createHelper) - renameTest(t, []string{"one", "two", "created"}, []string{"one", "renamed"}, createHelper) - renameTest(t, []string{"one", "created"}, []string{"three", "renamed"}, createHelper) - - // Renaming parents. - for name := range newTypeMap(nil) { - // Rename a parent. - renameTest(t, []string{"one", name}, []string{"renamed", name}, walkHelper) - renameTest(t, []string{"one", name}, []string{"renamed", name}, walkAndOpenHelper) - - // Rename a super parent. - renameTest(t, []string{"one", "two", name}, []string{"renamed", name}, walkHelper) - renameTest(t, []string{"one", "two", name}, []string{"renamed", name}, walkAndOpenHelper) - } - - // ... with created files. - renameTest(t, []string{"one", "created"}, []string{"renamed", "created"}, createHelper) - renameTest(t, []string{"one", "two", "created"}, []string{"renamed", "created"}, createHelper) - - // Over existing files, including itself. - for name := range newTypeMap(nil) { - for other := range newTypeMap(nil) { - // Overwrite the noted file (may be itself). - renameTest(t, []string{"one", name}, []string{"one", other}, walkHelper) - renameTest(t, []string{"one", name}, []string{"one", other}, walkAndOpenHelper) - - // Overwrite other files in another directory. - renameTest(t, []string{"one", name}, []string{"one", "two", other}, walkHelper) - renameTest(t, []string{"one", name}, []string{"one", "two", other}, walkAndOpenHelper) - } - - // Overwrite by moving the parent. - renameTest(t, []string{"three", name}, []string{"one", name}, walkHelper) - renameTest(t, []string{"three", name}, []string{"one", name}, walkAndOpenHelper) - - // Create over the types. - renameTest(t, []string{"one", "created"}, []string{"one", name}, createHelper) - renameTest(t, []string{"one", "created"}, []string{"one", "two", name}, createHelper) - renameTest(t, []string{"three", "created"}, []string{"one", name}, createHelper) - } -} - -func TestRenameInvalid(t *testing.T) { - h, c := NewHarness(t) - defer h.Finish() - - _, root := newRoot(h, c) - defer root.Close() - - for name := range newTypeMap(nil) { - for _, invalidName := range allInvalidNames(name) { - if err := root.Rename(root, invalidName); err != unix.EINVAL { - t.Errorf("got %v for name %q, want EINVAL", err, invalidName) - } - } - } -} - -func TestRenameAtInvalid(t *testing.T) { - h, c := NewHarness(t) - defer h.Finish() - - _, root := newRoot(h, c) - defer root.Close() - - for name := range newTypeMap(nil) { - for _, invalidName := range allInvalidNames(name) { - if err := root.RenameAt(invalidName, root, "okay"); err != unix.EINVAL { - t.Errorf("got %v for name %q, want EINVAL", err, invalidName) - } - if err := root.RenameAt("okay", root, invalidName); err != unix.EINVAL { - t.Errorf("got %v for name %q, want EINVAL", err, invalidName) - } - } - } -} - -// TestRenameSecondOrder tests that indirect rename targets continue to receive -// Renamed calls after a rename of its renamed parent. i.e., -// -// 1. Create /one/file -// 2. Create /directory -// 3. Rename /one -> /directory/one -// 4. Rename /directory -> /three/foo -// 5. file from (1) should still receive Renamed. -// -// This is a regression test for b/135219260. -func TestRenameSecondOrder(t *testing.T) { - h, c := NewHarness(t) - defer h.Finish() - - rootBackend, root := newRoot(h, c) - defer root.Close() - - // Walk to /one. - _, oneBackend, oneFile := walkHelper(h, "one", root) - defer oneFile.Close() - - // Walk to and generate /one/file. - // - // walkHelper re-walks to oneFile, so we need the second backend, - // which will also receive Renamed calls. - oneSecondBackend, fileBackend, fileFile := walkHelper(h, "file", oneFile) - defer fileFile.Close() - - // Walk to and generate /directory. - _, directoryBackend, directoryFile := walkHelper(h, "directory", root) - defer directoryFile.Close() - - // Rename /one to /directory/one. - rootBackend.EXPECT().RenameAt("one", directoryBackend, "one").Return(nil) - expectRenamed(oneBackend, []string{}, directoryBackend, "one") - expectRenamed(oneSecondBackend, []string{}, directoryBackend, "one") - expectRenamed(fileBackend, []string{}, oneBackend, "file") - if err := renameAt(h, root, directoryFile, "one", "one", false); err != nil { - h.t.Fatalf("got rename err %v, want nil", err) - } - - // Walk to /three. - _, threeBackend, threeFile := walkHelper(h, "three", root) - defer threeFile.Close() - - // Rename /directory to /three/foo. - rootBackend.EXPECT().RenameAt("directory", threeBackend, "foo").Return(nil) - expectRenamed(directoryBackend, []string{}, threeBackend, "foo") - expectRenamed(oneBackend, []string{}, directoryBackend, "one") - expectRenamed(oneSecondBackend, []string{}, directoryBackend, "one") - expectRenamed(fileBackend, []string{}, oneBackend, "file") - if err := renameAt(h, root, threeFile, "directory", "foo", false); err != nil { - h.t.Fatalf("got rename err %v, want nil", err) - } -} - -func TestReadlink(t *testing.T) { - for name := range newTypeMap(nil) { - t.Run(name, func(t *testing.T) { - h, c := NewHarness(t) - defer h.Finish() - - _, root := newRoot(h, c) - defer root.Close() - - // Walk to the file normally. - _, f, err := root.Walk([]string{name}) - if err != nil { - t.Fatalf("walk failed: got %v, wanted nil", err) - } - defer f.Close() - backend := h.Pop(f) - - const symlinkTarget = "symlink-target" - - if backend.Attr.Mode.IsSymlink() { - // This should only go through on symlinks. - backend.EXPECT().Readlink().Return(symlinkTarget, nil) - } - - // Attempt a Readlink operation. - target, err := f.Readlink() - if err != nil && err != unix.EINVAL { - t.Errorf("readlink got %v, wanted EINVAL", err) - } else if err == nil && target != symlinkTarget { - t.Errorf("readlink got %v, wanted %v", target, symlinkTarget) - } - }) - } -} - -// fdTest is a wrapper around operations that may send file descriptors. This -// asserts that the file descriptors are working as intended. -func fdTest(t *testing.T, sendFn func(*fd.FD) *fd.FD) { - // Create a pipe that we can read from. - r, w, err := os.Pipe() - if err != nil { - t.Fatalf("unable to create pipe: %v", err) - } - defer r.Close() - defer w.Close() - - // Attempt to send the write end. - wFD, err := fd.NewFromFile(w) - if err != nil { - t.Fatalf("unable to convert file: %v", err) - } - defer wFD.Close() // This is a copy. - - // Send wFD and receive newFD. - newFD := sendFn(wFD) - defer newFD.Close() - - // Attempt to write. - const message = "hello" - if _, err := newFD.Write([]byte(message)); err != nil { - t.Fatalf("write got %v, wanted nil", err) - } - - // Should see the message on our end. - buffer := []byte(message) - if _, err := io.ReadFull(r, buffer); err != nil { - t.Fatalf("read got %v, wanted nil", err) - } - if string(buffer) != message { - t.Errorf("got message %v, wanted %v", string(buffer), message) - } -} - -func TestConnect(t *testing.T) { - for name := range newTypeMap(nil) { - t.Run(name, func(t *testing.T) { - h, c := NewHarness(t) - defer h.Finish() - - _, root := newRoot(h, c) - defer root.Close() - - // Walk to the file normally. - _, backend, f := walkHelper(h, name, root) - defer f.Close() - - // Catch all the non-socket cases. - if !backend.Attr.Mode.IsSocket() { - // This has been set up to fail if Connect is called. - if _, err := f.Connect(p9.ConnectFlags(0)); err != unix.EINVAL { - t.Errorf("connect got %v, wanted EINVAL", err) - } - return - } - - // Ensure the fd exchange works. - fdTest(t, func(send *fd.FD) *fd.FD { - backend.EXPECT().Connect(p9.ConnectFlags(0)).Return(send, nil) - recv, err := backend.Connect(p9.ConnectFlags(0)) - if err != nil { - t.Fatalf("connect got %v, wanted nil", err) - } - return recv - }) - }) - } -} - -func TestReaddir(t *testing.T) { - for name := range newTypeMap(nil) { - t.Run(name, func(t *testing.T) { - h, c := NewHarness(t) - defer h.Finish() - - _, root := newRoot(h, c) - defer root.Close() - - // Walk to the file normally. - _, backend, f := walkHelper(h, name, root) - defer f.Close() - - // Catch all the non-directory cases. - if !backend.Attr.Mode.IsDir() { - // This has also been set up to fail if Readdir is called. - if _, err := f.Readdir(0, 1); err != unix.EINVAL { - t.Errorf("readdir got %v, wanted EINVAL", err) - } - return - } - - // Ensure that readdir works for directories. - if _, err := f.Readdir(0, 1); err != unix.EINVAL { - t.Errorf("readdir got %v, wanted EINVAL", err) - } - if _, _, _, err := f.Open(p9.ReadWrite); err != unix.EISDIR { - t.Errorf("readdir got %v, wanted EISDIR", err) - } - if _, _, _, err := f.Open(p9.WriteOnly); err != unix.EISDIR { - t.Errorf("readdir got %v, wanted EISDIR", err) - } - backend.EXPECT().Open(p9.ReadOnly).Times(1) - if _, _, _, err := f.Open(p9.ReadOnly); err != nil { - t.Errorf("readdir got %v, wanted nil", err) - } - backend.EXPECT().Readdir(uint64(0), uint32(1)).Times(1) - if _, err := f.Readdir(0, 1); err != nil { - t.Errorf("readdir got %v, wanted nil", err) - } - }) - } -} - -func TestOpen(t *testing.T) { - type openTest struct { - name string - flags p9.OpenFlags - err error - match func(p9.FileMode) bool - } - - cases := []openTest{ - { - name: "not-openable-read-only", - flags: p9.ReadOnly, - err: unix.EINVAL, - match: func(mode p9.FileMode) bool { return !p9.CanOpen(mode) }, - }, - { - name: "not-openable-write-only", - flags: p9.WriteOnly, - err: unix.EINVAL, - match: func(mode p9.FileMode) bool { return !p9.CanOpen(mode) }, - }, - { - name: "not-openable-read-write", - flags: p9.ReadWrite, - err: unix.EINVAL, - match: func(mode p9.FileMode) bool { return !p9.CanOpen(mode) }, - }, - { - name: "directory-read-only", - flags: p9.ReadOnly, - err: nil, - match: func(mode p9.FileMode) bool { return mode.IsDir() }, - }, - { - name: "directory-read-write", - flags: p9.ReadWrite, - err: unix.EISDIR, - match: func(mode p9.FileMode) bool { return mode.IsDir() }, - }, - { - name: "directory-write-only", - flags: p9.WriteOnly, - err: unix.EISDIR, - match: func(mode p9.FileMode) bool { return mode.IsDir() }, - }, - { - name: "read-only", - flags: p9.ReadOnly, - err: nil, - match: func(mode p9.FileMode) bool { return p9.CanOpen(mode) }, - }, - { - name: "write-only", - flags: p9.WriteOnly, - err: nil, - match: func(mode p9.FileMode) bool { return p9.CanOpen(mode) && !mode.IsDir() }, - }, - { - name: "read-write", - flags: p9.ReadWrite, - err: nil, - match: func(mode p9.FileMode) bool { return p9.CanOpen(mode) && !mode.IsDir() }, - }, - { - name: "directory-read-only-truncate", - flags: p9.ReadOnly | p9.OpenTruncate, - err: unix.EISDIR, - match: func(mode p9.FileMode) bool { return mode.IsDir() }, - }, - { - name: "read-only-truncate", - flags: p9.ReadOnly | p9.OpenTruncate, - err: nil, - match: func(mode p9.FileMode) bool { return p9.CanOpen(mode) && !mode.IsDir() }, - }, - { - name: "write-only-truncate", - flags: p9.WriteOnly | p9.OpenTruncate, - err: nil, - match: func(mode p9.FileMode) bool { return p9.CanOpen(mode) && !mode.IsDir() }, - }, - { - name: "read-write-truncate", - flags: p9.ReadWrite | p9.OpenTruncate, - err: nil, - match: func(mode p9.FileMode) bool { return p9.CanOpen(mode) && !mode.IsDir() }, - }, - } - - // Open(flags OpenFlags) (*fd.FD, QID, uint32, error) - // - only works on Regular, NamedPipe, BLockDevice, CharacterDevice - // - returning a file works as expected - for name := range newTypeMap(nil) { - for _, tc := range cases { - t.Run(fmt.Sprintf("%s-%s", tc.name, name), func(t *testing.T) { - h, c := NewHarness(t) - defer h.Finish() - - _, root := newRoot(h, c) - defer root.Close() - - // Walk to the file normally. - _, backend, f := walkHelper(h, name, root) - defer f.Close() - - // Does this match the case? - if !tc.match(backend.Attr.Mode) { - t.SkipNow() - } - - // Ensure open-required operations fail. - if _, err := f.ReadAt([]byte("hello"), 0); err != unix.EINVAL { - t.Errorf("readAt got %v, wanted EINVAL", err) - } - if _, err := f.WriteAt(make([]byte, 6), 0); err != unix.EINVAL { - t.Errorf("writeAt got %v, wanted EINVAL", err) - } - if err := f.FSync(); err != unix.EINVAL { - t.Errorf("fsync got %v, wanted EINVAL", err) - } - if _, err := f.Readdir(0, 1); err != unix.EINVAL { - t.Errorf("readdir got %v, wanted EINVAL", err) - } - - // Attempt the given open. - if tc.err != nil { - // We expect an error, just test and return. - if _, _, _, err := f.Open(tc.flags); err != tc.err { - t.Fatalf("open with flags %v got %v, want %v", tc.flags, err, tc.err) - } - return - } - - // Run an FD test, since we expect success. - fdTest(t, func(send *fd.FD) *fd.FD { - backend.EXPECT().Open(tc.flags).Return(send, p9.QID{}, uint32(0), nil).Times(1) - recv, _, _, err := f.Open(tc.flags) - if err != tc.err { - t.Fatalf("open with flags %v got %v, want %v", tc.flags, err, tc.err) - } - return recv - }) - - // If the open was successful, attempt another one. - if _, _, _, err := f.Open(tc.flags); err != unix.EINVAL { - t.Errorf("second open with flags %v got %v, want EINVAL", tc.flags, err) - } - - // Ensure that all illegal operations fail. - if _, _, err := f.Walk(nil); err != unix.EINVAL && err != unix.EBUSY { - t.Errorf("walk got %v, wanted EINVAL or EBUSY", err) - } - if _, _, _, _, err := f.WalkGetAttr(nil); err != unix.EINVAL && err != unix.EBUSY { - t.Errorf("walkgetattr got %v, wanted EINVAL or EBUSY", err) - } - }) - } - } -} - -func TestClose(t *testing.T) { - type closeTest struct { - name string - closeFn func(backend *Mock, f p9.File) error - } - - cases := []closeTest{ - { - name: "close", - closeFn: func(_ *Mock, f p9.File) error { - return f.Close() - }, - }, - { - name: "remove", - closeFn: func(backend *Mock, f p9.File) error { - // Allow the rename call in the parent, automatically translated. - backend.parent.EXPECT().UnlinkAt(gomock.Any(), gomock.Any()).Times(1) - return f.(deprecatedRemover).Remove() - }, - }, - { - name: "setAttrClose", - closeFn: func(backend *Mock, f p9.File) error { - valid := p9.SetAttrMask{ATime: true} - attr := p9.SetAttr{ATimeSeconds: 1, ATimeNanoSeconds: 2} - backend.EXPECT().SetAttr(valid, attr).Times(1) - return f.SetAttrClose(valid, attr) - }, - }, - } - - for name := range newTypeMap(nil) { - for _, tc := range cases { - t.Run(fmt.Sprintf("%s(%s)", tc.name, name), func(t *testing.T) { - h, c := NewHarness(t) - defer h.Finish() - - _, root := newRoot(h, c) - defer root.Close() - - // Walk to the file normally. - _, backend, f := walkHelper(h, name, root) - - // Close via the prescribed method. - if err := tc.closeFn(backend, f); err != nil { - t.Fatalf("closeFn failed: %v", err) - } - - // Everything should fail with EBADF. - if _, _, err := f.Walk(nil); err != unix.EBADF { - t.Errorf("walk got %v, wanted EBADF", err) - } - if _, err := f.StatFS(); err != unix.EBADF { - t.Errorf("statfs got %v, wanted EBADF", err) - } - if _, _, _, err := f.GetAttr(p9.AttrMaskAll()); err != unix.EBADF { - t.Errorf("getattr got %v, wanted EBADF", err) - } - if err := f.SetAttr(p9.SetAttrMask{}, p9.SetAttr{}); err != unix.EBADF { - t.Errorf("setattrk got %v, wanted EBADF", err) - } - if err := f.Rename(root, "new-name"); err != unix.EBADF { - t.Errorf("rename got %v, wanted EBADF", err) - } - if err := f.Close(); err != unix.EBADF { - t.Errorf("close got %v, wanted EBADF", err) - } - if _, _, _, err := f.Open(p9.ReadOnly); err != unix.EBADF { - t.Errorf("open got %v, wanted EBADF", err) - } - if _, err := f.ReadAt([]byte("hello"), 0); err != unix.EBADF { - t.Errorf("readAt got %v, wanted EBADF", err) - } - if _, err := f.WriteAt(make([]byte, 6), 0); err != unix.EBADF { - t.Errorf("writeAt got %v, wanted EBADF", err) - } - if err := f.FSync(); err != unix.EBADF { - t.Errorf("fsync got %v, wanted EBADF", err) - } - if _, _, _, _, err := f.Create("new-file", p9.ReadWrite, 0, 0, 0); err != unix.EBADF { - t.Errorf("create got %v, wanted EBADF", err) - } - if _, err := f.Mkdir("new-directory", 0, 0, 0); err != unix.EBADF { - t.Errorf("mkdir got %v, wanted EBADF", err) - } - if _, err := f.Symlink("old-name", "new-name", 0, 0); err != unix.EBADF { - t.Errorf("symlink got %v, wanted EBADF", err) - } - if err := f.Link(root, "new-name"); err != unix.EBADF { - t.Errorf("link got %v, wanted EBADF", err) - } - if _, err := f.Mknod("new-block-device", 0, 0, 0, 0, 0); err != unix.EBADF { - t.Errorf("mknod got %v, wanted EBADF", err) - } - if err := f.RenameAt("old-name", root, "new-name"); err != unix.EBADF { - t.Errorf("renameAt got %v, wanted EBADF", err) - } - if err := f.UnlinkAt("name", 0); err != unix.EBADF { - t.Errorf("unlinkAt got %v, wanted EBADF", err) - } - if _, err := f.Readdir(0, 1); err != unix.EBADF { - t.Errorf("readdir got %v, wanted EBADF", err) - } - if _, err := f.Readlink(); err != unix.EBADF { - t.Errorf("readlink got %v, wanted EBADF", err) - } - if err := f.Flush(); err != unix.EBADF { - t.Errorf("flush got %v, wanted EBADF", err) - } - if _, _, _, _, err := f.WalkGetAttr(nil); err != unix.EBADF { - t.Errorf("walkgetattr got %v, wanted EBADF", err) - } - if _, err := f.Connect(p9.ConnectFlags(0)); err != unix.EBADF { - t.Errorf("connect got %v, wanted EBADF", err) - } - }) - } - } -} - -// onlyWorksOnOpenThings is a helper test method for operations that should -// only work on files that have been explicitly opened. -func onlyWorksOnOpenThings(h *Harness, t *testing.T, name string, root p9.File, mode p9.OpenFlags, expectedErr error, fn func(backend *Mock, f p9.File, shouldSucceed bool) error) { - // Walk to the file normally. - _, backend, f := walkHelper(h, name, root) - defer f.Close() - - // Does it work before opening? - if err := fn(backend, f, false); err != unix.EINVAL { - t.Errorf("operation got %v, wanted EINVAL", err) - } - - // Is this openable? - if !p9.CanOpen(backend.Attr.Mode) { - return // Nothing to do. - } - - // If this is a directory, we can't handle writing. - if backend.Attr.Mode.IsDir() && (mode == p9.ReadWrite || mode == p9.WriteOnly) { - return // Skip. - } - - // Open the file. - backend.EXPECT().Open(mode) - if _, _, _, err := f.Open(mode); err != nil { - t.Fatalf("open got %v, wanted nil", err) - } - - // Attempt the operation. - if err := fn(backend, f, expectedErr == nil); err != expectedErr { - t.Fatalf("operation got %v, wanted %v", err, expectedErr) - } -} - -func TestRead(t *testing.T) { - type readTest struct { - name string - mode p9.OpenFlags - err error - } - - cases := []readTest{ - { - name: "read-only", - mode: p9.ReadOnly, - err: nil, - }, - { - name: "read-write", - mode: p9.ReadWrite, - err: nil, - }, - { - name: "write-only", - mode: p9.WriteOnly, - err: unix.EPERM, - }, - } - - for name := range newTypeMap(nil) { - for _, tc := range cases { - t.Run(fmt.Sprintf("%s-%s", tc.name, name), func(t *testing.T) { - h, c := NewHarness(t) - defer h.Finish() - - _, root := newRoot(h, c) - defer root.Close() - - const message = "hello" - - onlyWorksOnOpenThings(h, t, name, root, tc.mode, tc.err, func(backend *Mock, f p9.File, shouldSucceed bool) error { - if !shouldSucceed { - _, err := f.ReadAt([]byte(message), 0) - return err - } - - // Prepare for the call to readAt in the backend. - backend.EXPECT().ReadAt(gomock.Any(), uint64(0)).Do(func(p []byte, offset uint64) { - copy(p, message) - }).Return(len(message), nil) - - // Make the client call. - p := make([]byte, 2*len(message)) // Double size. - n, err := f.ReadAt(p, 0) - - // Sanity check result. - if err != nil { - return err - } - if n != len(message) { - t.Fatalf("message length incorrect, got %d, want %d", n, len(message)) - } - if !bytes.Equal(p[:n], []byte(message)) { - t.Fatalf("message incorrect, got %v, want %v", p, []byte(message)) - } - return nil // Success. - }) - }) - } - } -} - -func TestWrite(t *testing.T) { - type writeTest struct { - name string - mode p9.OpenFlags - err error - } - - cases := []writeTest{ - { - name: "read-only", - mode: p9.ReadOnly, - err: unix.EPERM, - }, - { - name: "read-write", - mode: p9.ReadWrite, - err: nil, - }, - { - name: "write-only", - mode: p9.WriteOnly, - err: nil, - }, - } - - for name := range newTypeMap(nil) { - for _, tc := range cases { - t.Run(fmt.Sprintf("%s-%s", tc.name, name), func(t *testing.T) { - h, c := NewHarness(t) - defer h.Finish() - - _, root := newRoot(h, c) - defer root.Close() - - const message = "hello" - - onlyWorksOnOpenThings(h, t, name, root, tc.mode, tc.err, func(backend *Mock, f p9.File, shouldSucceed bool) error { - if !shouldSucceed { - _, err := f.WriteAt([]byte(message), 0) - return err - } - - // Prepare for the call to readAt in the backend. - var output []byte // Saved by Do below. - backend.EXPECT().WriteAt(gomock.Any(), uint64(0)).Do(func(p []byte, offset uint64) { - output = p - }).Return(len(message), nil) - - // Make the client call. - n, err := f.WriteAt([]byte(message), 0) - - // Sanity check result. - if err != nil { - return err - } - if n != len(message) { - t.Fatalf("message length incorrect, got %d, want %d", n, len(message)) - } - if !bytes.Equal(output, []byte(message)) { - t.Fatalf("message incorrect, got %v, want %v", output, []byte(message)) - } - return nil // Success. - }) - }) - } - } -} - -func TestFSync(t *testing.T) { - for name := range newTypeMap(nil) { - for _, mode := range []p9.OpenFlags{p9.ReadOnly, p9.WriteOnly, p9.ReadWrite} { - t.Run(fmt.Sprintf("%s-%s", mode, name), func(t *testing.T) { - h, c := NewHarness(t) - defer h.Finish() - - _, root := newRoot(h, c) - defer root.Close() - - onlyWorksOnOpenThings(h, t, name, root, mode, nil, func(backend *Mock, f p9.File, shouldSucceed bool) error { - if shouldSucceed { - backend.EXPECT().FSync().Times(1) - } - return f.FSync() - }) - }) - } - } -} - -func TestFlush(t *testing.T) { - for name := range newTypeMap(nil) { - t.Run(name, func(t *testing.T) { - h, c := NewHarness(t) - defer h.Finish() - - _, root := newRoot(h, c) - defer root.Close() - - _, backend, f := walkHelper(h, name, root) - defer f.Close() - - backend.EXPECT().Flush() - f.Flush() - }) - } -} - -// onlyWorksOnDirectories is a helper test method for operations that should -// only work on unopened directories, such as create, mkdir and symlink. -func onlyWorksOnDirectories(h *Harness, t *testing.T, name string, root p9.File, fn func(backend *Mock, f p9.File, shouldSucceed bool) error) { - // Walk to the file normally. - _, backend, f := walkHelper(h, name, root) - defer f.Close() - - // Only directories support mknod. - if !backend.Attr.Mode.IsDir() { - if err := fn(backend, f, false); err != unix.EINVAL { - t.Errorf("operation got %v, wanted EINVAL", err) - } - return // Nothing else to do. - } - - // Should succeed. - if err := fn(backend, f, true); err != nil { - t.Fatalf("operation got %v, wanted nil", err) - } - - // Open the directory. - backend.EXPECT().Open(p9.ReadOnly).Times(1) - if _, _, _, err := f.Open(p9.ReadOnly); err != nil { - t.Fatalf("open got %v, wanted nil", err) - } - - // Should not work again. - if err := fn(backend, f, false); err != unix.EINVAL { - t.Fatalf("operation got %v, wanted EINVAL", err) - } -} - -func TestCreate(t *testing.T) { - for name := range newTypeMap(nil) { - t.Run(name, func(t *testing.T) { - h, c := NewHarness(t) - defer h.Finish() - - _, root := newRoot(h, c) - defer root.Close() - - onlyWorksOnDirectories(h, t, name, root, func(backend *Mock, f p9.File, shouldSucceed bool) error { - if !shouldSucceed { - _, _, _, _, err := f.Create("new-file", p9.ReadWrite, 0, 1, 2) - return err - } - - // If the create is going to succeed, then we - // need to create a new backend file, and we - // clone to ensure that we don't close the - // original. - _, newF, err := f.Walk(nil) - if err != nil { - t.Fatalf("clone got %v, wanted nil", err) - } - defer newF.Close() - newBackend := h.Pop(newF) - - // Run a regular FD test to validate that path. - fdTest(t, func(send *fd.FD) *fd.FD { - // Return the send FD on success. - newFile := h.NewFile()(backend) // New file with the parent backend. - newBackend.EXPECT().Create("new-file", p9.ReadWrite, p9.FileMode(0), p9.UID(1), p9.GID(2)).Return(send, newFile, p9.QID{}, uint32(0), nil) - - // Receive the fd back. - recv, _, _, _, err := newF.Create("new-file", p9.ReadWrite, 0, 1, 2) - if err != nil { - t.Fatalf("create got %v, wanted nil", err) - } - return recv - }) - - // The above will fail via normal test flow, so - // we can assume that it passed. - return nil - }) - }) - } -} - -func TestCreateInvalid(t *testing.T) { - h, c := NewHarness(t) - defer h.Finish() - - _, root := newRoot(h, c) - defer root.Close() - - for name := range newTypeMap(nil) { - for _, invalidName := range allInvalidNames(name) { - if _, _, _, _, err := root.Create(invalidName, p9.ReadWrite, 0, 0, 0); err != unix.EINVAL { - t.Errorf("got %v for name %q, want EINVAL", err, invalidName) - } - } - } -} - -func TestMkdir(t *testing.T) { - for name := range newTypeMap(nil) { - t.Run(name, func(t *testing.T) { - h, c := NewHarness(t) - defer h.Finish() - - _, root := newRoot(h, c) - defer root.Close() - - onlyWorksOnDirectories(h, t, name, root, func(backend *Mock, f p9.File, shouldSucceed bool) error { - if shouldSucceed { - backend.EXPECT().Mkdir("new-directory", p9.FileMode(0), p9.UID(1), p9.GID(2)) - } - _, err := f.Mkdir("new-directory", 0, 1, 2) - return err - }) - }) - } -} - -func TestMkdirInvalid(t *testing.T) { - h, c := NewHarness(t) - defer h.Finish() - - _, root := newRoot(h, c) - defer root.Close() - - for name := range newTypeMap(nil) { - for _, invalidName := range allInvalidNames(name) { - if _, err := root.Mkdir(invalidName, 0, 0, 0); err != unix.EINVAL { - t.Errorf("got %v for name %q, want EINVAL", err, invalidName) - } - } - } -} - -func TestSymlink(t *testing.T) { - for name := range newTypeMap(nil) { - t.Run(name, func(t *testing.T) { - h, c := NewHarness(t) - defer h.Finish() - - _, root := newRoot(h, c) - defer root.Close() - - onlyWorksOnDirectories(h, t, name, root, func(backend *Mock, f p9.File, shouldSucceed bool) error { - if shouldSucceed { - backend.EXPECT().Symlink("old-name", "new-name", p9.UID(1), p9.GID(2)) - } - _, err := f.Symlink("old-name", "new-name", 1, 2) - return err - }) - }) - } -} - -func TestSyminkInvalid(t *testing.T) { - h, c := NewHarness(t) - defer h.Finish() - - _, root := newRoot(h, c) - defer root.Close() - - for name := range newTypeMap(nil) { - for _, invalidName := range allInvalidNames(name) { - // We need only test for invalid names in the new name, - // the target can be an arbitrary string and we don't - // need to sanity check it. - if _, err := root.Symlink("old-name", invalidName, 0, 0); err != unix.EINVAL { - t.Errorf("got %v for name %q, want EINVAL", err, invalidName) - } - } - } -} - -func TestLink(t *testing.T) { - for name := range newTypeMap(nil) { - t.Run(name, func(t *testing.T) { - h, c := NewHarness(t) - defer h.Finish() - - _, root := newRoot(h, c) - defer root.Close() - - onlyWorksOnDirectories(h, t, name, root, func(backend *Mock, f p9.File, shouldSucceed bool) error { - if shouldSucceed { - backend.EXPECT().Link(gomock.Any(), "new-link") - } - return f.Link(f, "new-link") - }) - }) - } -} - -func TestLinkInvalid(t *testing.T) { - h, c := NewHarness(t) - defer h.Finish() - - _, root := newRoot(h, c) - defer root.Close() - - for name := range newTypeMap(nil) { - for _, invalidName := range allInvalidNames(name) { - if err := root.Link(root, invalidName); err != unix.EINVAL { - t.Errorf("got %v for name %q, want EINVAL", err, invalidName) - } - } - } -} - -func TestMknod(t *testing.T) { - for name := range newTypeMap(nil) { - t.Run(name, func(t *testing.T) { - h, c := NewHarness(t) - defer h.Finish() - - _, root := newRoot(h, c) - defer root.Close() - - onlyWorksOnDirectories(h, t, name, root, func(backend *Mock, f p9.File, shouldSucceed bool) error { - if shouldSucceed { - backend.EXPECT().Mknod("new-block-device", p9.FileMode(0), uint32(1), uint32(2), p9.UID(3), p9.GID(4)).Times(1) - } - _, err := f.Mknod("new-block-device", 0, 1, 2, 3, 4) - return err - }) - }) - } -} - -// concurrentFn is a specification of a concurrent operation. This is used to -// drive the concurrency tests below. -type concurrentFn struct { - name string - match func(p9.FileMode) bool - op func(h *Harness, backend *Mock, f p9.File, callback func()) -} - -func concurrentTest(t *testing.T, name string, fn1, fn2 concurrentFn, sameDir, expectedOkay bool) { - var ( - names1 []string - names2 []string - ) - if sameDir { - // Use the same file one directory up. - names1, names2 = []string{"one", name}, []string{"one", name} - } else { - // For different directories, just use siblings. - names1, names2 = []string{"one", name}, []string{"three", name} - } - - t.Run(fmt.Sprintf("%s(%v)+%s(%v)", fn1.name, names1, fn2.name, names2), func(t *testing.T) { - h, c := NewHarness(t) - defer h.Finish() - - _, root := newRoot(h, c) - defer root.Close() - - // Walk to both files as given. - _, f1, err := root.Walk(names1) - if err != nil { - t.Fatalf("error walking, got %v, want nil", err) - } - defer f1.Close() - b1 := h.Pop(f1) - _, f2, err := root.Walk(names2) - if err != nil { - t.Fatalf("error walking, got %v, want nil", err) - } - defer f2.Close() - b2 := h.Pop(f2) - - // Are these a good match for the current test case? - if !fn1.match(b1.Attr.Mode) { - t.SkipNow() - } - if !fn2.match(b2.Attr.Mode) { - t.SkipNow() - } - - // Construct our "concurrency creator". - in1 := make(chan struct{}, 1) - in2 := make(chan struct{}, 1) - var top sync.WaitGroup - var fns sync.WaitGroup - defer top.Wait() - top.Add(2) // Accounting for below. - defer fns.Done() - fns.Add(1) // See line above; released before top.Wait. - go func() { - defer top.Done() - fn1.op(h, b1, f1, func() { - in1 <- struct{}{} - fns.Wait() - }) - }() - go func() { - defer top.Done() - fn2.op(h, b2, f2, func() { - in2 <- struct{}{} - fns.Wait() - }) - }() - - // Compute a reasonable timeout. If we expect the operation to hang, - // give it 10 milliseconds before we assert that it's fine. After all, - // there will be a lot of these tests. If we don't expect it to hang, - // give it a full minute, since the machine could be slow. - timeout := 10 * time.Millisecond - if expectedOkay { - timeout = 1 * time.Minute - } - - // Read the first channel. - var second chan struct{} - select { - case <-in1: - second = in2 - case <-in2: - second = in1 - } - - // Catch concurrency. - select { - case <-second: - // We finished successful. Is this good? Depends on the - // expected result. - if !expectedOkay { - t.Errorf("%q and %q proceeded concurrently!", fn1.name, fn2.name) - } - case <-time.After(timeout): - // Great, things did not proceed concurrently. Is that what we - // expected? - if expectedOkay { - t.Errorf("%q and %q hung concurrently!", fn1.name, fn2.name) - } - } - }) -} - -func randomFileName() string { - return fmt.Sprintf("%x", rand.Int63()) -} - -func TestConcurrency(t *testing.T) { - readExclusive := []concurrentFn{ - { - // N.B. We can't explicitly check WalkGetAttr behavior, - // but we rely on the fact that the internal code paths - // are the same. - name: "walk", - match: func(mode p9.FileMode) bool { return mode.IsDir() }, - op: func(h *Harness, backend *Mock, f p9.File, callback func()) { - // See the documentation of WalkCallback. - // Because walk is actually implemented by the - // mock, we need a special place for this - // callback. - // - // Note that a clone actually locks the parent - // node. So we walk from this node to test - // concurrent operations appropriately. - backend.WalkCallback = func() error { - callback() - return nil - } - f.Walk([]string{randomFileName()}) // Won't exist. - }, - }, - { - name: "fsync", - match: func(mode p9.FileMode) bool { return p9.CanOpen(mode) }, - op: func(h *Harness, backend *Mock, f p9.File, callback func()) { - backend.EXPECT().Open(gomock.Any()) - backend.EXPECT().FSync().Do(func() { - callback() - }) - f.Open(p9.ReadOnly) // Required. - f.FSync() - }, - }, - { - name: "readdir", - match: func(mode p9.FileMode) bool { return mode.IsDir() }, - op: func(h *Harness, backend *Mock, f p9.File, callback func()) { - backend.EXPECT().Open(gomock.Any()) - backend.EXPECT().Readdir(gomock.Any(), gomock.Any()).Do(func(uint64, uint32) { - callback() - }) - f.Open(p9.ReadOnly) // Required. - f.Readdir(0, 1) - }, - }, - { - name: "readlink", - match: func(mode p9.FileMode) bool { return mode.IsSymlink() }, - op: func(h *Harness, backend *Mock, f p9.File, callback func()) { - backend.EXPECT().Readlink().Do(func() { - callback() - }) - f.Readlink() - }, - }, - { - name: "connect", - match: func(mode p9.FileMode) bool { return mode.IsSocket() }, - op: func(h *Harness, backend *Mock, f p9.File, callback func()) { - backend.EXPECT().Connect(gomock.Any()).Do(func(p9.ConnectFlags) { - callback() - }) - f.Connect(0) - }, - }, - { - name: "open", - match: func(mode p9.FileMode) bool { return p9.CanOpen(mode) }, - op: func(h *Harness, backend *Mock, f p9.File, callback func()) { - backend.EXPECT().Open(gomock.Any()).Do(func(p9.OpenFlags) { - callback() - }) - f.Open(p9.ReadOnly) - }, - }, - { - name: "flush", - match: func(mode p9.FileMode) bool { return true }, - op: func(h *Harness, backend *Mock, f p9.File, callback func()) { - backend.EXPECT().Flush().Do(func() { - callback() - }) - f.Flush() - }, - }, - } - writeExclusive := []concurrentFn{ - { - // N.B. We can't really check getattr. But this is an - // extremely low-risk function, it seems likely that - // this check is paranoid anyways. - name: "setattr", - match: func(mode p9.FileMode) bool { return true }, - op: func(h *Harness, backend *Mock, f p9.File, callback func()) { - backend.EXPECT().SetAttr(gomock.Any(), gomock.Any()).Do(func(p9.SetAttrMask, p9.SetAttr) { - callback() - }) - f.SetAttr(p9.SetAttrMask{}, p9.SetAttr{}) - }, - }, - { - name: "unlinkAt", - match: func(mode p9.FileMode) bool { return mode.IsDir() }, - op: func(h *Harness, backend *Mock, f p9.File, callback func()) { - backend.EXPECT().UnlinkAt(gomock.Any(), gomock.Any()).Do(func(string, uint32) { - callback() - }) - f.UnlinkAt(randomFileName(), 0) - }, - }, - { - name: "mknod", - match: func(mode p9.FileMode) bool { return mode.IsDir() }, - op: func(h *Harness, backend *Mock, f p9.File, callback func()) { - backend.EXPECT().Mknod(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Do(func(string, p9.FileMode, uint32, uint32, p9.UID, p9.GID) { - callback() - }) - f.Mknod(randomFileName(), 0, 0, 0, 0, 0) - }, - }, - { - name: "link", - match: func(mode p9.FileMode) bool { return mode.IsDir() }, - op: func(h *Harness, backend *Mock, f p9.File, callback func()) { - backend.EXPECT().Link(gomock.Any(), gomock.Any()).Do(func(p9.File, string) { - callback() - }) - f.Link(f, randomFileName()) - }, - }, - { - name: "symlink", - match: func(mode p9.FileMode) bool { return mode.IsDir() }, - op: func(h *Harness, backend *Mock, f p9.File, callback func()) { - backend.EXPECT().Symlink(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Do(func(string, string, p9.UID, p9.GID) { - callback() - }) - f.Symlink(randomFileName(), randomFileName(), 0, 0) - }, - }, - { - name: "mkdir", - match: func(mode p9.FileMode) bool { return mode.IsDir() }, - op: func(h *Harness, backend *Mock, f p9.File, callback func()) { - backend.EXPECT().Mkdir(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Do(func(string, p9.FileMode, p9.UID, p9.GID) { - callback() - }) - f.Mkdir(randomFileName(), 0, 0, 0) - }, - }, - { - name: "create", - match: func(mode p9.FileMode) bool { return mode.IsDir() }, - op: func(h *Harness, backend *Mock, f p9.File, callback func()) { - // Return an error for the creation operation, as this is the simplest. - backend.EXPECT().Create(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil, p9.QID{}, uint32(0), unix.EINVAL).Do(func(string, p9.OpenFlags, p9.FileMode, p9.UID, p9.GID) { - callback() - }) - f.Create(randomFileName(), p9.ReadOnly, 0, 0, 0) - }, - }, - } - globalExclusive := []concurrentFn{ - { - name: "remove", - match: func(mode p9.FileMode) bool { return mode.IsDir() }, - op: func(h *Harness, backend *Mock, f p9.File, callback func()) { - // Remove operates on a locked parent. So we - // add a child, walk to it and call remove. - // Note that because this operation can operate - // concurrently with itself, we need to - // generate a random file name. - randomFile := randomFileName() - backend.AddChild(randomFile, h.NewFile()) - defer backend.RemoveChild(randomFile) - _, file, err := f.Walk([]string{randomFile}) - if err != nil { - h.t.Fatalf("walk got %v, want nil", err) - } - - // Remove is automatically translated to the parent. - backend.EXPECT().UnlinkAt(gomock.Any(), gomock.Any()).Do(func(string, uint32) { - callback() - }) - - // Remove is also a close. - file.(deprecatedRemover).Remove() - }, - }, - { - name: "rename", - match: func(mode p9.FileMode) bool { return mode.IsDir() }, - op: func(h *Harness, backend *Mock, f p9.File, callback func()) { - // Similarly to remove, because we need to - // operate on a child, we allow a walk. - randomFile := randomFileName() - backend.AddChild(randomFile, h.NewFile()) - defer backend.RemoveChild(randomFile) - _, file, err := f.Walk([]string{randomFile}) - if err != nil { - h.t.Fatalf("walk got %v, want nil", err) - } - defer file.Close() - fileBackend := h.Pop(file) - - // Rename is automatically translated to the parent. - backend.EXPECT().RenameAt(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(string, p9.File, string) { - callback() - }) - - // Attempt the rename. - fileBackend.EXPECT().Renamed(gomock.Any(), gomock.Any()) - file.Rename(f, randomFileName()) - }, - }, - { - name: "renameAt", - match: func(mode p9.FileMode) bool { return mode.IsDir() }, - op: func(h *Harness, backend *Mock, f p9.File, callback func()) { - backend.EXPECT().RenameAt(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(string, p9.File, string) { - callback() - }) - - // Attempt the rename. There are no active fids - // with this name, so we don't need to expect - // Renamed hooks on anything. - f.RenameAt(randomFileName(), f, randomFileName()) - }, - }, - } - - for _, fn1 := range readExclusive { - for _, fn2 := range readExclusive { - for name := range newTypeMap(nil) { - // Everything should be able to proceed in parallel. - concurrentTest(t, name, fn1, fn2, true, true) - concurrentTest(t, name, fn1, fn2, false, true) - } - } - } - - for _, fn1 := range append(readExclusive, writeExclusive...) { - for _, fn2 := range writeExclusive { - for name := range newTypeMap(nil) { - // Only cross-directory functions should proceed in parallel. - concurrentTest(t, name, fn1, fn2, true, false) - concurrentTest(t, name, fn1, fn2, false, true) - } - } - } - - for _, fn1 := range append(append(readExclusive, writeExclusive...), globalExclusive...) { - for _, fn2 := range globalExclusive { - for name := range newTypeMap(nil) { - // Nothing should be able to run in parallel. - concurrentTest(t, name, fn1, fn2, true, false) - concurrentTest(t, name, fn1, fn2, false, false) - } - } - } -} - -func TestReadWriteConcurrent(t *testing.T) { - h, c := NewHarness(t) - defer h.Finish() - - _, root := newRoot(h, c) - defer root.Close() - - const ( - instances = 10 - iterations = 10000 - dataSize = 1024 - ) - var ( - dataSets [instances][dataSize]byte - backends [instances]*Mock - files [instances]p9.File - ) - - // Walk to the file normally. - for i := 0; i < instances; i++ { - _, backends[i], files[i] = walkHelper(h, "file", root) - defer files[i].Close() - } - - // Open the files. - for i := 0; i < instances; i++ { - backends[i].EXPECT().Open(p9.ReadWrite) - if _, _, _, err := files[i].Open(p9.ReadWrite); err != nil { - t.Fatalf("open got %v, wanted nil", err) - } - } - - // Initialize random data for each instance. - for i := 0; i < instances; i++ { - if _, err := rand.Read(dataSets[i][:]); err != nil { - t.Fatalf("error initializing dataSet#%d, got %v", i, err) - } - } - - // Define our random read/write mechanism. - randRead := func(h *Harness, backend *Mock, f p9.File, data, test []byte) { - // Prepare the backend. - backend.EXPECT().ReadAt(gomock.Any(), uint64(0)).Do(func(p []byte, offset uint64) { - if n := copy(p, data); n != len(data) { - // Note that we have to assert the result here, as the Return statement - // below cannot be dynamic: it will be bound before this call is made. - h.t.Errorf("wanted length %d, got %d", len(data), n) - } - }).Return(len(data), nil) - - // Execute the read. - if n, err := f.ReadAt(test, 0); n != len(test) || err != nil { - t.Errorf("failed read: wanted (%d, nil), got (%d, %v)", len(test), n, err) - return // No sense doing check below. - } - if !bytes.Equal(test, data) { - t.Errorf("data integrity failed during read") // Not as expected. - } - } - randWrite := func(h *Harness, backend *Mock, f p9.File, data []byte) { - // Prepare the backend. - backend.EXPECT().WriteAt(gomock.Any(), uint64(0)).Do(func(p []byte, offset uint64) { - if !bytes.Equal(p, data) { - h.t.Errorf("data integrity failed during write") // Not as expected. - } - }).Return(len(data), nil) - - // Execute the write. - if n, err := f.WriteAt(data, 0); n != len(data) || err != nil { - t.Errorf("failed read: wanted (%d, nil), got (%d, %v)", len(data), n, err) - } - } - randReadWrite := func(n int, h *Harness, backend *Mock, f p9.File, data []byte) { - test := make([]byte, len(data)) - for i := 0; i < n; i++ { - if rand.Intn(2) == 0 { - randRead(h, backend, f, data, test) - } else { - randWrite(h, backend, f, data) - } - } - } - - // Start reading and writing. - var wg sync.WaitGroup - for i := 0; i < instances; i++ { - wg.Add(1) - go func(i int) { - defer wg.Done() - randReadWrite(iterations, h, backends[i], files[i], dataSets[i][:]) - }(i) - } - wg.Wait() -} diff --git a/pkg/p9/p9test/p9test.go b/pkg/p9/p9test/p9test.go deleted file mode 100644 index fd5ac3dbe..000000000 --- a/pkg/p9/p9test/p9test.go +++ /dev/null @@ -1,329 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package p9test provides standard mocks for p9. -package p9test - -import ( - "fmt" - "sync/atomic" - "testing" - - "github.com/golang/mock/gomock" - "golang.org/x/sys/unix" - "gvisor.dev/gvisor/pkg/p9" - "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/unet" -) - -// Harness is an attacher mock. -type Harness struct { - t *testing.T - mockCtrl *gomock.Controller - Attacher *MockAttacher - wg sync.WaitGroup - clientSocket *unet.Socket - mu sync.Mutex - created []*Mock -} - -// globalPath is a QID.Path Generator. -var globalPath uint64 - -// MakePath returns a globally unique path. -func MakePath() uint64 { - return atomic.AddUint64(&globalPath, 1) -} - -// Generator is a function that generates a new file. -type Generator func(parent *Mock) *Mock - -// Mock is a common mock element. -type Mock struct { - p9.DefaultWalkGetAttr - *MockFile - parent *Mock - closed bool - harness *Harness - QID p9.QID - Attr p9.Attr - children map[string]Generator - - // WalkCallback is a special function that will be called from within - // the walk context. This is needed for the concurrent tests within - // this package. - WalkCallback func() error -} - -// globalMu protects the children maps in all mocks. Note that this is not a -// particularly elegant solution, but because the test has walks from the root -// through to final nodes, we must share maps below, and it's easiest to simply -// protect against concurrent access globally. -var globalMu sync.RWMutex - -// AddChild adds a new child to the Mock. -func (m *Mock) AddChild(name string, generator Generator) { - globalMu.Lock() - defer globalMu.Unlock() - m.children[name] = generator -} - -// RemoveChild removes the child with the given name. -func (m *Mock) RemoveChild(name string) { - globalMu.Lock() - defer globalMu.Unlock() - delete(m.children, name) -} - -// Matches implements gomock.Matcher.Matches. -func (m *Mock) Matches(x interface{}) bool { - if om, ok := x.(*Mock); ok { - return m.QID.Path == om.QID.Path - } - return false -} - -// String implements gomock.Matcher.String. -func (m *Mock) String() string { - return fmt.Sprintf("Mock{Mode: 0x%x, QID.Path: %d}", m.Attr.Mode, m.QID.Path) -} - -// GetAttr returns the current attributes. -func (m *Mock) GetAttr(mask p9.AttrMask) (p9.QID, p9.AttrMask, p9.Attr, error) { - return m.QID, p9.AttrMaskAll(), m.Attr, nil -} - -// Walk supports clone and walking in directories. -func (m *Mock) Walk(names []string) ([]p9.QID, p9.File, error) { - if m.WalkCallback != nil { - if err := m.WalkCallback(); err != nil { - return nil, nil, err - } - } - if len(names) == 0 { - // Clone the file appropriately. - nm := m.harness.NewMock(m.parent, m.QID.Path, m.Attr) - nm.children = m.children // Inherit children. - return []p9.QID{nm.QID}, nm, nil - } else if len(names) != 1 { - m.harness.t.Fail() // Should not happen. - return nil, nil, unix.EINVAL - } - - if m.Attr.Mode.IsDir() { - globalMu.RLock() - defer globalMu.RUnlock() - if fn, ok := m.children[names[0]]; ok { - // Generate the child. - nm := fn(m) - return []p9.QID{nm.QID}, nm, nil - } - // No child found. - return nil, nil, unix.ENOENT - } - - // Call the underlying mock. - return m.MockFile.Walk(names) -} - -// WalkGetAttr calls the default implementation; this is a client-side optimization. -func (m *Mock) WalkGetAttr(names []string) ([]p9.QID, p9.File, p9.AttrMask, p9.Attr, error) { - return m.DefaultWalkGetAttr.WalkGetAttr(names) -} - -// Pop pops off the most recently created Mock and assert that this mock -// represents the same file passed in. If nil is passed in, no check is -// performed. -// -// Precondition: there must be at least one Mock or this will panic. -func (h *Harness) Pop(clientFile p9.File) *Mock { - h.mu.Lock() - defer h.mu.Unlock() - - if clientFile == nil { - // If no clientFile is provided, then we always return the last - // created file. The caller can safely use this as long as - // there is no concurrency. - m := h.created[len(h.created)-1] - h.created = h.created[:len(h.created)-1] - return m - } - - qid, _, _, err := clientFile.GetAttr(p9.AttrMaskAll()) - if err != nil { - // We do not expect this to happen. - panic(fmt.Sprintf("err during Pop: %v", err)) - } - - // Find the relevant file in our created list. We must scan the last - // from back to front to ensure that we favor the most recently - // generated file. - for i := len(h.created) - 1; i >= 0; i-- { - m := h.created[i] - if qid.Path == m.QID.Path { - // Copy and truncate. - copy(h.created[i:], h.created[i+1:]) - h.created = h.created[:len(h.created)-1] - return m - } - } - - // Unable to find relevant file. - panic(fmt.Sprintf("unable to locate file with QID %+v", qid.Path)) -} - -// NewMock returns a new base file. -func (h *Harness) NewMock(parent *Mock, path uint64, attr p9.Attr) *Mock { - m := &Mock{ - MockFile: NewMockFile(h.mockCtrl), - parent: parent, - harness: h, - QID: p9.QID{ - Type: p9.QIDType((attr.Mode & p9.FileModeMask) >> 12), - Path: path, - }, - Attr: attr, - } - - // Always ensure Close is after the parent's close. Note that this - // can't be done via a straight-forward After call, because the parent - // might change after initial creation. We ensure that this is true at - // close time. - m.EXPECT().Close().Return(nil).Times(1).Do(func() { - if m.parent != nil && m.parent.closed { - h.t.FailNow() - } - // Note that this should not be racy, as this operation should - // be protected by the Times(1) above first. - m.closed = true - }) - - // Remember what was created. - h.mu.Lock() - defer h.mu.Unlock() - h.created = append(h.created, m) - - return m -} - -// NewFile returns a new file mock. -// -// Note that ReadAt and WriteAt must be mocked separately. -func (h *Harness) NewFile() Generator { - return func(parent *Mock) *Mock { - return h.NewMock(parent, MakePath(), p9.Attr{Mode: p9.ModeRegular}) - } -} - -// NewDirectory returns a new mock directory. -// -// Note that Mkdir, Link, Mknod, RenameAt, UnlinkAt and Readdir must be mocked -// separately. Walk is provided and children may be manipulated via AddChild -// and RemoveChild. After calling Walk remotely, one can use Pop to find the -// corresponding backend mock on the server side. -func (h *Harness) NewDirectory(contents map[string]Generator) Generator { - return func(parent *Mock) *Mock { - m := h.NewMock(parent, MakePath(), p9.Attr{Mode: p9.ModeDirectory}) - m.children = contents // Save contents. - return m - } -} - -// NewSymlink returns a new mock directory. -// -// Note that Readlink must be mocked separately. -func (h *Harness) NewSymlink() Generator { - return func(parent *Mock) *Mock { - return h.NewMock(parent, MakePath(), p9.Attr{Mode: p9.ModeSymlink}) - } -} - -// NewBlockDevice returns a new mock block device. -func (h *Harness) NewBlockDevice() Generator { - return func(parent *Mock) *Mock { - return h.NewMock(parent, MakePath(), p9.Attr{Mode: p9.ModeBlockDevice}) - } -} - -// NewCharacterDevice returns a new mock character device. -func (h *Harness) NewCharacterDevice() Generator { - return func(parent *Mock) *Mock { - return h.NewMock(parent, MakePath(), p9.Attr{Mode: p9.ModeCharacterDevice}) - } -} - -// NewNamedPipe returns a new mock named pipe. -func (h *Harness) NewNamedPipe() Generator { - return func(parent *Mock) *Mock { - return h.NewMock(parent, MakePath(), p9.Attr{Mode: p9.ModeNamedPipe}) - } -} - -// NewSocket returns a new mock socket. -func (h *Harness) NewSocket() Generator { - return func(parent *Mock) *Mock { - return h.NewMock(parent, MakePath(), p9.Attr{Mode: p9.ModeSocket}) - } -} - -// Finish completes all checks and shuts down the server. -func (h *Harness) Finish() { - h.clientSocket.Shutdown() - h.wg.Wait() - h.mockCtrl.Finish() -} - -// NewHarness creates and returns a new test server. -// -// It should always be used as: -// -// h, c := NewHarness(t) -// defer h.Finish() -// -func NewHarness(t *testing.T) (*Harness, *p9.Client) { - // Create the mock. - mockCtrl := gomock.NewController(t) - h := &Harness{ - t: t, - mockCtrl: mockCtrl, - Attacher: NewMockAttacher(mockCtrl), - } - - // Make socket pair. - serverSocket, clientSocket, err := unet.SocketPair(false) - if err != nil { - t.Fatalf("socketpair got err %v wanted nil", err) - } - - // Start the server, synchronized on exit. - server := p9.NewServer(h.Attacher) - h.wg.Add(1) - go func() { - defer h.wg.Done() - server.Handle(serverSocket) - }() - - // Create the client. - client, err := p9.NewClient(clientSocket, p9.DefaultMessageSize, p9.HighestVersionString()) - if err != nil { - serverSocket.Close() - clientSocket.Close() - t.Fatalf("new client got %v, expected nil", err) - return nil, nil // Never hit. - } - - // Capture the client socket. - h.clientSocket = clientSocket - return h, client -} diff --git a/pkg/p9/transport_test.go b/pkg/p9/transport_test.go deleted file mode 100644 index a29f06ddb..000000000 --- a/pkg/p9/transport_test.go +++ /dev/null @@ -1,233 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package p9 - -import ( - "io/ioutil" - "os" - "testing" - - "gvisor.dev/gvisor/pkg/fd" - "gvisor.dev/gvisor/pkg/unet" -) - -const ( - MsgTypeBadEncode = iota + 252 - MsgTypeBadDecode - MsgTypeUnregistered -) - -func TestSendRecv(t *testing.T) { - server, client, err := unet.SocketPair(false) - if err != nil { - t.Fatalf("socketpair got err %v expected nil", err) - } - defer server.Close() - defer client.Close() - - if err := send(client, Tag(1), &Tlopen{}); err != nil { - t.Fatalf("send got err %v expected nil", err) - } - - tag, m, err := recv(server, maximumLength, msgRegistry.get) - if err != nil { - t.Fatalf("recv got err %v expected nil", err) - } - if tag != Tag(1) { - t.Fatalf("got tag %v expected 1", tag) - } - if _, ok := m.(*Tlopen); !ok { - t.Fatalf("got message %v expected *Tlopen", m) - } -} - -// badDecode overruns on decode. -type badDecode struct{} - -func (*badDecode) decode(b *buffer) { b.markOverrun() } -func (*badDecode) encode(b *buffer) {} -func (*badDecode) Type() MsgType { return MsgTypeBadDecode } -func (*badDecode) String() string { return "badDecode{}" } - -func TestRecvOverrun(t *testing.T) { - server, client, err := unet.SocketPair(false) - if err != nil { - t.Fatalf("socketpair got err %v expected nil", err) - } - defer server.Close() - defer client.Close() - - if err := send(client, Tag(1), &badDecode{}); err != nil { - t.Fatalf("send got err %v expected nil", err) - } - - if _, _, err := recv(server, maximumLength, msgRegistry.get); err == nil { - t.Fatalf("recv got err %v expected ErrSocket{ErrNoValidMessage}", err) - } -} - -// unregistered is not registered on decode. -type unregistered struct{} - -func (*unregistered) decode(b *buffer) {} -func (*unregistered) encode(b *buffer) {} -func (*unregistered) Type() MsgType { return MsgTypeUnregistered } -func (*unregistered) String() string { return "unregistered{}" } - -func TestRecvInvalidType(t *testing.T) { - server, client, err := unet.SocketPair(false) - if err != nil { - t.Fatalf("socketpair got err %v expected nil", err) - } - defer server.Close() - defer client.Close() - - if err := send(client, Tag(1), &unregistered{}); err != nil { - t.Fatalf("send got err %v expected nil", err) - } - - _, _, err = recv(server, maximumLength, msgRegistry.get) - if _, ok := err.(*ErrInvalidMsgType); !ok { - t.Fatalf("recv got err %v expected ErrInvalidMsgType", err) - } -} - -func TestSendRecvWithFile(t *testing.T) { - server, client, err := unet.SocketPair(false) - if err != nil { - t.Fatalf("socketpair got err %v expected nil", err) - } - defer server.Close() - defer client.Close() - - // Create a tempfile. - osf, err := ioutil.TempFile("", "p9") - if err != nil { - t.Fatalf("tempfile got err %v expected nil", err) - } - os.Remove(osf.Name()) - f, err := fd.NewFromFile(osf) - osf.Close() - if err != nil { - t.Fatalf("unable to create file: %v", err) - } - - rlopen := &Rlopen{} - rlopen.SetFilePayload(f) - if err := send(client, Tag(1), rlopen); err != nil { - t.Fatalf("send got err %v expected nil", err) - } - - // Enable withFile. - tag, m, err := recv(server, maximumLength, msgRegistry.get) - if err != nil { - t.Fatalf("recv got err %v expected nil", err) - } - if tag != Tag(1) { - t.Fatalf("got tag %v expected 1", tag) - } - rlopen, ok := m.(*Rlopen) - if !ok { - t.Fatalf("got m %v expected *Rlopen", m) - } - if rlopen.File == nil { - t.Fatalf("got nil file expected non-nil") - } -} - -func TestRecvClosed(t *testing.T) { - server, client, err := unet.SocketPair(false) - if err != nil { - t.Fatalf("socketpair got err %v expected nil", err) - } - defer server.Close() - client.Close() - - _, _, err = recv(server, maximumLength, msgRegistry.get) - if err == nil { - t.Fatalf("got err nil expected non-nil") - } - if _, ok := err.(ErrSocket); !ok { - t.Fatalf("got err %v expected ErrSocket", err) - } -} - -func TestSendClosed(t *testing.T) { - server, client, err := unet.SocketPair(false) - if err != nil { - t.Fatalf("socketpair got err %v expected nil", err) - } - server.Close() - defer client.Close() - - err = send(client, Tag(1), &Tlopen{}) - if err == nil { - t.Fatalf("send got err nil expected non-nil") - } - if _, ok := err.(ErrSocket); !ok { - t.Fatalf("got err %v expected ErrSocket", err) - } -} - -func BenchmarkSendRecv(b *testing.B) { - b.ReportAllocs() - - server, client, err := unet.SocketPair(false) - if err != nil { - b.Fatalf("socketpair got err %v expected nil", err) - } - defer server.Close() - defer client.Close() - - // Exchange Rflush messages since these contain no data and therefore incur - // no additional marshaling overhead. - go func() { - for i := 0; i < b.N; i++ { - tag, m, err := recv(server, maximumLength, msgRegistry.get) - if err != nil { - b.Errorf("recv got err %v expected nil", err) - } - if tag != Tag(1) { - b.Errorf("got tag %v expected 1", tag) - } - if _, ok := m.(*Rflush); !ok { - b.Errorf("got message %T expected *Rflush", m) - } - if err := send(server, Tag(2), &Rflush{}); err != nil { - b.Errorf("send got err %v expected nil", err) - } - } - }() - b.ResetTimer() - for i := 0; i < b.N; i++ { - if err := send(client, Tag(1), &Rflush{}); err != nil { - b.Errorf("send got err %v expected nil", err) - } - tag, m, err := recv(client, maximumLength, msgRegistry.get) - if err != nil { - b.Errorf("recv got err %v expected nil", err) - } - if tag != Tag(2) { - b.Errorf("got tag %v expected 2", tag) - } - if _, ok := m.(*Rflush); !ok { - b.Errorf("got message %v expected *Rflush", m) - } - } -} - -func init() { - msgRegistry.register(MsgTypeBadDecode, func() message { return &badDecode{} }) -} diff --git a/pkg/p9/version_test.go b/pkg/p9/version_test.go deleted file mode 100644 index 291e8580e..000000000 --- a/pkg/p9/version_test.go +++ /dev/null @@ -1,145 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package p9 - -import ( - "testing" -) - -func TestVersionNumberEquivalent(t *testing.T) { - for i := uint32(0); i < 1024; i++ { - str := versionString(i) - version, ok := parseVersion(str) - if !ok { - t.Errorf("#%d: parseVersion(%q) failed, want success", i, str) - continue - } - if i != version { - t.Errorf("#%d: got version %d, want %d", i, i, version) - } - } -} - -func TestVersionStringEquivalent(t *testing.T) { - // There is one case where the version is not equivalent on purpose, - // that is 9P2000.L.Google.0. It is not equivalent because versionString - // must always return the more generic 9P2000.L for legacy servers that - // check for it. See net/9p/client.c. - str := "9P2000.L.Google.0" - version, ok := parseVersion(str) - if !ok { - t.Errorf("parseVersion(%q) failed, want success", str) - } - if got := versionString(version); got != "9P2000.L" { - t.Errorf("versionString(%d) got %q, want %q", version, got, "9P2000.L") - } - - for _, test := range []struct { - versionString string - }{ - { - versionString: "9P2000.L", - }, - { - versionString: "9P2000.L.Google.1", - }, - { - versionString: "9P2000.L.Google.347823894", - }, - } { - version, ok := parseVersion(test.versionString) - if !ok { - t.Errorf("parseVersion(%q) failed, want success", test.versionString) - continue - } - if got := versionString(version); got != test.versionString { - t.Errorf("versionString(%d) got %q, want %q", version, got, test.versionString) - } - } -} - -func TestParseVersion(t *testing.T) { - for _, test := range []struct { - versionString string - expectSuccess bool - expectedVersion uint32 - }{ - { - versionString: "9P", - expectSuccess: false, - }, - { - versionString: "9P.L", - expectSuccess: false, - }, - { - versionString: "9P200.L", - expectSuccess: false, - }, - { - versionString: "9P2000", - expectSuccess: false, - }, - { - versionString: "9P2000.L.Google.-1", - expectSuccess: false, - }, - { - versionString: "9P2000.L.Google.", - expectSuccess: false, - }, - { - versionString: "9P2000.L.Google.3546343826724305832", - expectSuccess: false, - }, - { - versionString: "9P2001.L", - expectSuccess: false, - }, - { - versionString: "9P2000.L", - expectSuccess: true, - expectedVersion: 0, - }, - { - versionString: "9P2000.L.Google.0", - expectSuccess: true, - expectedVersion: 0, - }, - { - versionString: "9P2000.L.Google.1", - expectSuccess: true, - expectedVersion: 1, - }, - } { - version, ok := parseVersion(test.versionString) - if ok != test.expectSuccess { - t.Errorf("parseVersion(%q) got (_, %v), want (_, %v)", test.versionString, ok, test.expectSuccess) - continue - } - if !test.expectSuccess { - continue - } - if version != test.expectedVersion { - t.Errorf("parseVersion(%q) got (%d, _), want (%d, _)", test.versionString, version, test.expectedVersion) - } - } -} - -func BenchmarkParseVersion(b *testing.B) { - for n := 0; n < b.N; n++ { - parseVersion("9P2000.L.Google.1") - } -} diff --git a/pkg/pool/BUILD b/pkg/pool/BUILD deleted file mode 100644 index 7b1c6b75b..000000000 --- a/pkg/pool/BUILD +++ /dev/null @@ -1,25 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package( - default_visibility = ["//visibility:public"], - licenses = ["notice"], -) - -go_library( - name = "pool", - srcs = [ - "pool.go", - ], - deps = [ - "//pkg/sync", - ], -) - -go_test( - name = "pool_test", - size = "small", - srcs = [ - "pool_test.go", - ], - library = ":pool", -) diff --git a/pkg/pool/pool_state_autogen.go b/pkg/pool/pool_state_autogen.go new file mode 100644 index 000000000..1f4164c00 --- /dev/null +++ b/pkg/pool/pool_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package pool diff --git a/pkg/pool/pool_test.go b/pkg/pool/pool_test.go deleted file mode 100644 index d928439c1..000000000 --- a/pkg/pool/pool_test.go +++ /dev/null @@ -1,64 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package pool - -import ( - "testing" -) - -func TestPoolUnique(t *testing.T) { - p := Pool{Start: 1, Limit: 3} - got := make(map[uint64]bool) - - for { - n, ok := p.Get() - if !ok { - break - } - - // Check unique. - if _, ok := got[n]; ok { - t.Errorf("pool spit out %v multiple times", n) - } - - // Record. - got[n] = true - } -} - -func TestExausted(t *testing.T) { - p := Pool{Start: 1, Limit: 500} - for i := 0; i < 499; i++ { - _, ok := p.Get() - if !ok { - t.Fatalf("pool exhausted before 499 items") - } - } - - _, ok := p.Get() - if ok { - t.Errorf("pool not exhausted when it should be") - } -} - -func TestPoolRecycle(t *testing.T) { - p := Pool{Start: 1, Limit: 500} - n1, _ := p.Get() - p.Put(n1) - n2, _ := p.Get() - if n1 != n2 { - t.Errorf("pool not recycling items") - } -} diff --git a/pkg/procid/BUILD b/pkg/procid/BUILD deleted file mode 100644 index 2838b5aca..000000000 --- a/pkg/procid/BUILD +++ /dev/null @@ -1,40 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "procid", - srcs = [ - "procid.go", - "procid_amd64.s", - "procid_arm64.s", - ], - visibility = ["//visibility:public"], -) - -go_test( - name = "procid_test", - size = "small", - srcs = [ - "procid_test.go", - ], - library = ":procid", - deps = [ - "//pkg/sync", - "@org_golang_x_sys//unix:go_default_library", - ], -) - -go_test( - name = "procid_net_test", - size = "small", - srcs = [ - "procid_net_test.go", - "procid_test.go", - ], - library = ":procid", - deps = [ - "//pkg/sync", - "@org_golang_x_sys//unix:go_default_library", - ], -) diff --git a/pkg/procid/procid_net_test.go b/pkg/procid/procid_net_test.go deleted file mode 100644 index b628e2285..000000000 --- a/pkg/procid/procid_net_test.go +++ /dev/null @@ -1,21 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package procid - -// This file is just to force the inclusion of the "net" package, which will -// make the test binary a cgo one. -import ( - _ "net" -) diff --git a/pkg/procid/procid_state_autogen.go b/pkg/procid/procid_state_autogen.go new file mode 100644 index 000000000..be58e2904 --- /dev/null +++ b/pkg/procid/procid_state_autogen.go @@ -0,0 +1,6 @@ +// automatically generated by stateify. + +//go:build go1.1 +// +build go1.1 + +package procid diff --git a/pkg/procid/procid_test.go b/pkg/procid/procid_test.go deleted file mode 100644 index a08110b35..000000000 --- a/pkg/procid/procid_test.go +++ /dev/null @@ -1,86 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package procid - -import ( - "os" - "runtime" - "testing" - - "golang.org/x/sys/unix" - "gvisor.dev/gvisor/pkg/sync" -) - -// runOnMain is used to send functions to run on the main (initial) thread. -var runOnMain = make(chan func(), 10) - -func checkProcid(t *testing.T, start *sync.WaitGroup, done *sync.WaitGroup) { - defer done.Done() - - runtime.LockOSThread() - defer runtime.UnlockOSThread() - - start.Done() - start.Wait() - - procID := Current() - tid := unix.Gettid() - - if procID != uint64(tid) { - t.Logf("Bad procid: expected %v, got %v", tid, procID) - t.Fail() - } -} - -func TestProcidInitialized(t *testing.T) { - var start sync.WaitGroup - var done sync.WaitGroup - - count := 100 - start.Add(count + 1) - done.Add(count + 1) - - // Run the check on the main thread. - // - // When cgo is not included, the only case when procid isn't initialized - // is in the main (initial) thread, so we have to test this case - // specifically. - runOnMain <- func() { - checkProcid(t, &start, &done) - } - - // Run the check on a number of different threads. - for i := 0; i < count; i++ { - go checkProcid(t, &start, &done) - } - - done.Wait() -} - -func TestMain(m *testing.M) { - // Make sure we remain at the main (initial) thread. - runtime.LockOSThread() - - // Start running tests in a different goroutine. - go func() { - os.Exit(m.Run()) - }() - - // Execute any functions that have been sent for execution in the main - // thread. - for f := range runOnMain { - f() - } -} diff --git a/pkg/rand/BUILD b/pkg/rand/BUILD deleted file mode 100644 index 80b8ceb02..000000000 --- a/pkg/rand/BUILD +++ /dev/null @@ -1,16 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "rand", - srcs = [ - "rand.go", - "rand_linux.go", - ], - visibility = ["//:sandbox"], - deps = [ - "//pkg/sync", - "@org_golang_x_sys//unix:go_default_library", - ], -) diff --git a/pkg/rand/rand_linux_state_autogen.go b/pkg/rand/rand_linux_state_autogen.go new file mode 100644 index 000000000..f727c9314 --- /dev/null +++ b/pkg/rand/rand_linux_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package rand diff --git a/pkg/rand/rand_state_autogen.go b/pkg/rand/rand_state_autogen.go new file mode 100644 index 000000000..4320837d6 --- /dev/null +++ b/pkg/rand/rand_state_autogen.go @@ -0,0 +1,6 @@ +// automatically generated by stateify. + +//go:build !linux +// +build !linux + +package rand diff --git a/pkg/refs/BUILD b/pkg/refs/BUILD deleted file mode 100644 index 9888cce9c..000000000 --- a/pkg/refs/BUILD +++ /dev/null @@ -1,42 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -package(licenses = ["notice"]) - -go_template_instance( - name = "weak_ref_list", - out = "weak_ref_list.go", - package = "refs", - prefix = "weakRef", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*WeakRef", - "Linker": "*WeakRef", - }, -) - -go_library( - name = "refs", - srcs = [ - "refcounter.go", - "refcounter_state.go", - "weak_ref_list.go", - ], - visibility = ["//:sandbox"], - deps = [ - "//pkg/context", - "//pkg/log", - "//pkg/sync", - ], -) - -go_test( - name = "refs_test", - size = "small", - srcs = ["refcounter_test.go"], - library = ":refs", - deps = [ - "//pkg/context", - "//pkg/sync", - ], -) diff --git a/pkg/refs/refcounter_test.go b/pkg/refs/refcounter_test.go deleted file mode 100644 index 6d0dd1018..000000000 --- a/pkg/refs/refcounter_test.go +++ /dev/null @@ -1,179 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package refs - -import ( - "reflect" - "testing" - - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/sync" -) - -type testCounter struct { - AtomicRefCount - - // mu protects the boolean below. - mu sync.Mutex - - // destroyed indicates whether this was destroyed. - destroyed bool -} - -func (t *testCounter) DecRef(ctx context.Context) { - t.AtomicRefCount.DecRefWithDestructor(ctx, t.destroy) -} - -func (t *testCounter) destroy(context.Context) { - t.mu.Lock() - defer t.mu.Unlock() - t.destroyed = true -} - -func (t *testCounter) IsDestroyed() bool { - t.mu.Lock() - defer t.mu.Unlock() - return t.destroyed -} - -func newTestCounter() *testCounter { - return &testCounter{destroyed: false} -} - -func TestOneRef(t *testing.T) { - tc := newTestCounter() - tc.DecRef(context.Background()) - - if !tc.IsDestroyed() { - t.Errorf("object should have been destroyed") - } -} - -func TestTwoRefs(t *testing.T) { - tc := newTestCounter() - tc.IncRef() - ctx := context.Background() - tc.DecRef(ctx) - tc.DecRef(ctx) - - if !tc.IsDestroyed() { - t.Errorf("object should have been destroyed") - } -} - -func TestMultiRefs(t *testing.T) { - tc := newTestCounter() - tc.IncRef() - ctx := context.Background() - tc.DecRef(ctx) - - tc.IncRef() - tc.DecRef(ctx) - - tc.DecRef(ctx) - - if !tc.IsDestroyed() { - t.Errorf("object should have been destroyed") - } -} - -func TestWeakRef(t *testing.T) { - tc := newTestCounter() - w := NewWeakRef(tc, nil) - ctx := context.Background() - - // Try resolving. - if x := w.Get(); x == nil { - t.Errorf("weak reference didn't resolve: expected %v, got nil", tc) - } else { - x.DecRef(ctx) - } - - // Try resolving again. - if x := w.Get(); x == nil { - t.Errorf("weak reference didn't resolve: expected %v, got nil", tc) - } else { - x.DecRef(ctx) - } - - // Shouldn't be destroyed yet. (Can't continue if this fails.) - if tc.IsDestroyed() { - t.Fatalf("original object destroyed earlier than expected") - } - - // Drop the original reference. - tc.DecRef(ctx) - - // Assert destroyed. - if !tc.IsDestroyed() { - t.Errorf("original object not destroyed as expected") - } - - // Shouldn't be anything. - if x := w.Get(); x != nil { - t.Errorf("weak reference resolved: expected nil, got %v", x) - } -} - -func TestWeakRefDrop(t *testing.T) { - tc := newTestCounter() - w := NewWeakRef(tc, nil) - ctx := context.Background() - w.Drop(ctx) - - // Just assert the list is empty. - if !tc.weakRefs.Empty() { - t.Errorf("weak reference not dropped") - } - - // Drop the original reference. - tc.DecRef(ctx) -} - -type testWeakRefUser struct { - weakRefGone func() -} - -func (u *testWeakRefUser) WeakRefGone(ctx context.Context) { - u.weakRefGone() -} - -func TestCallback(t *testing.T) { - called := false - tc := newTestCounter() - var w *WeakRef - w = NewWeakRef(tc, &testWeakRefUser{func() { - called = true - - // Check that the weak ref has been zapped. - rc := w.obj.Load().(RefCounter) - if v := reflect.ValueOf(rc); v != reflect.Zero(v.Type()) { - t.Fatalf("Callback called with non-nil ptr") - } - - // Check that we're not holding the mutex by acquiring and - // releasing it. - tc.mu.Lock() - tc.mu.Unlock() - }}) - - // Drop the original reference, this must trigger the callback. - ctx := context.Background() - tc.DecRef(ctx) - - if !called { - t.Fatalf("Callback not called") - } -} diff --git a/pkg/refs/refs_state_autogen.go b/pkg/refs/refs_state_autogen.go new file mode 100644 index 000000000..86a102119 --- /dev/null +++ b/pkg/refs/refs_state_autogen.go @@ -0,0 +1,157 @@ +// automatically generated by stateify. + +package refs + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (w *WeakRef) StateTypeName() string { + return "pkg/refs.WeakRef" +} + +func (w *WeakRef) StateFields() []string { + return []string{ + "obj", + "user", + } +} + +func (w *WeakRef) beforeSave() {} + +// +checklocksignore +func (w *WeakRef) StateSave(stateSinkObject state.Sink) { + w.beforeSave() + var objValue savedReference + objValue = w.saveObj() + stateSinkObject.SaveValue(0, objValue) + stateSinkObject.Save(1, &w.user) +} + +func (w *WeakRef) afterLoad() {} + +// +checklocksignore +func (w *WeakRef) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(1, &w.user) + stateSourceObject.LoadValue(0, new(savedReference), func(y interface{}) { w.loadObj(y.(savedReference)) }) +} + +func (r *AtomicRefCount) StateTypeName() string { + return "pkg/refs.AtomicRefCount" +} + +func (r *AtomicRefCount) StateFields() []string { + return []string{ + "refCount", + "name", + "stack", + } +} + +func (r *AtomicRefCount) beforeSave() {} + +// +checklocksignore +func (r *AtomicRefCount) StateSave(stateSinkObject state.Sink) { + r.beforeSave() + stateSinkObject.Save(0, &r.refCount) + stateSinkObject.Save(1, &r.name) + stateSinkObject.Save(2, &r.stack) +} + +func (r *AtomicRefCount) afterLoad() {} + +// +checklocksignore +func (r *AtomicRefCount) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &r.refCount) + stateSourceObject.Load(1, &r.name) + stateSourceObject.Load(2, &r.stack) +} + +func (s *savedReference) StateTypeName() string { + return "pkg/refs.savedReference" +} + +func (s *savedReference) StateFields() []string { + return []string{ + "obj", + } +} + +func (s *savedReference) beforeSave() {} + +// +checklocksignore +func (s *savedReference) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + stateSinkObject.Save(0, &s.obj) +} + +func (s *savedReference) afterLoad() {} + +// +checklocksignore +func (s *savedReference) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.obj) +} + +func (l *weakRefList) StateTypeName() string { + return "pkg/refs.weakRefList" +} + +func (l *weakRefList) StateFields() []string { + return []string{ + "head", + "tail", + } +} + +func (l *weakRefList) beforeSave() {} + +// +checklocksignore +func (l *weakRefList) StateSave(stateSinkObject state.Sink) { + l.beforeSave() + stateSinkObject.Save(0, &l.head) + stateSinkObject.Save(1, &l.tail) +} + +func (l *weakRefList) afterLoad() {} + +// +checklocksignore +func (l *weakRefList) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &l.head) + stateSourceObject.Load(1, &l.tail) +} + +func (e *weakRefEntry) StateTypeName() string { + return "pkg/refs.weakRefEntry" +} + +func (e *weakRefEntry) StateFields() []string { + return []string{ + "next", + "prev", + } +} + +func (e *weakRefEntry) beforeSave() {} + +// +checklocksignore +func (e *weakRefEntry) StateSave(stateSinkObject state.Sink) { + e.beforeSave() + stateSinkObject.Save(0, &e.next) + stateSinkObject.Save(1, &e.prev) +} + +func (e *weakRefEntry) afterLoad() {} + +// +checklocksignore +func (e *weakRefEntry) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &e.next) + stateSourceObject.Load(1, &e.prev) +} + +func init() { + state.Register((*WeakRef)(nil)) + state.Register((*AtomicRefCount)(nil)) + state.Register((*savedReference)(nil)) + state.Register((*weakRefList)(nil)) + state.Register((*weakRefEntry)(nil)) +} diff --git a/pkg/refs/weak_ref_list.go b/pkg/refs/weak_ref_list.go new file mode 100644 index 000000000..920e89eb3 --- /dev/null +++ b/pkg/refs/weak_ref_list.go @@ -0,0 +1,221 @@ +package refs + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type weakRefElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (weakRefElementMapper) linkerFor(elem *WeakRef) *WeakRef { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type weakRefList struct { + head *WeakRef + tail *WeakRef +} + +// Reset resets list l to the empty state. +func (l *weakRefList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +// +//go:nosplit +func (l *weakRefList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +// +//go:nosplit +func (l *weakRefList) Front() *WeakRef { + return l.head +} + +// Back returns the last element of list l or nil. +// +//go:nosplit +func (l *weakRefList) Back() *WeakRef { + return l.tail +} + +// Len returns the number of elements in the list. +// +// NOTE: This is an O(n) operation. +// +//go:nosplit +func (l *weakRefList) Len() (count int) { + for e := l.Front(); e != nil; e = (weakRefElementMapper{}.linkerFor(e)).Next() { + count++ + } + return count +} + +// PushFront inserts the element e at the front of list l. +// +//go:nosplit +func (l *weakRefList) PushFront(e *WeakRef) { + linker := weakRefElementMapper{}.linkerFor(e) + linker.SetNext(l.head) + linker.SetPrev(nil) + if l.head != nil { + weakRefElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +// +//go:nosplit +func (l *weakRefList) PushBack(e *WeakRef) { + linker := weakRefElementMapper{}.linkerFor(e) + linker.SetNext(nil) + linker.SetPrev(l.tail) + if l.tail != nil { + weakRefElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +// +//go:nosplit +func (l *weakRefList) PushBackList(m *weakRefList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + weakRefElementMapper{}.linkerFor(l.tail).SetNext(m.head) + weakRefElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +// +//go:nosplit +func (l *weakRefList) InsertAfter(b, e *WeakRef) { + bLinker := weakRefElementMapper{}.linkerFor(b) + eLinker := weakRefElementMapper{}.linkerFor(e) + + a := bLinker.Next() + + eLinker.SetNext(a) + eLinker.SetPrev(b) + bLinker.SetNext(e) + + if a != nil { + weakRefElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +// +//go:nosplit +func (l *weakRefList) InsertBefore(a, e *WeakRef) { + aLinker := weakRefElementMapper{}.linkerFor(a) + eLinker := weakRefElementMapper{}.linkerFor(e) + + b := aLinker.Prev() + eLinker.SetNext(a) + eLinker.SetPrev(b) + aLinker.SetPrev(e) + + if b != nil { + weakRefElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +// +//go:nosplit +func (l *weakRefList) Remove(e *WeakRef) { + linker := weakRefElementMapper{}.linkerFor(e) + prev := linker.Prev() + next := linker.Next() + + if prev != nil { + weakRefElementMapper{}.linkerFor(prev).SetNext(next) + } else if l.head == e { + l.head = next + } + + if next != nil { + weakRefElementMapper{}.linkerFor(next).SetPrev(prev) + } else if l.tail == e { + l.tail = prev + } + + linker.SetNext(nil) + linker.SetPrev(nil) +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type weakRefEntry struct { + next *WeakRef + prev *WeakRef +} + +// Next returns the entry that follows e in the list. +// +//go:nosplit +func (e *weakRefEntry) Next() *WeakRef { + return e.next +} + +// Prev returns the entry that precedes e in the list. +// +//go:nosplit +func (e *weakRefEntry) Prev() *WeakRef { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +// +//go:nosplit +func (e *weakRefEntry) SetNext(elem *WeakRef) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +// +//go:nosplit +func (e *weakRefEntry) SetPrev(elem *WeakRef) { + e.prev = elem +} diff --git a/pkg/refsvfs2/BUILD b/pkg/refsvfs2/BUILD deleted file mode 100644 index 7c1a8c792..000000000 --- a/pkg/refsvfs2/BUILD +++ /dev/null @@ -1,39 +0,0 @@ -# TODO(gvisor.dev/issue/1624): rename this directory/package to "refs" once VFS1 -# is gone and the current refs package can be deleted. -load("//tools:defs.bzl", "go_library") -load("//tools/go_generics:defs.bzl", "go_template") - -package(licenses = ["notice"]) - -go_template( - name = "refs_template", - srcs = [ - "refs_template.go", - ], - opt_consts = [ - "enableLogging", - ], - types = [ - "T", - ], - visibility = ["//:sandbox"], - deps = [ - "//pkg/log", - "//pkg/refs", - ], -) - -go_library( - name = "refsvfs2", - srcs = [ - "refs.go", - "refs_map.go", - ], - visibility = ["//:sandbox"], - deps = [ - "//pkg/context", - "//pkg/log", - "//pkg/refs", - "//pkg/sync", - ], -) diff --git a/pkg/refsvfs2/README.md b/pkg/refsvfs2/README.md deleted file mode 100644 index eca53c282..000000000 --- a/pkg/refsvfs2/README.md +++ /dev/null @@ -1,66 +0,0 @@ -# Reference Counting - -Go does not offer a reliable way to couple custom resource management with -object lifetime. As a result, we need to manually implement reference counting -for many objects in gVisor to make sure that resources are acquired and released -appropriately. For example, the filesystem has many reference-counted objects -(file descriptions, dentries, inodes, etc.), and it is important that each -object persists while anything holds a reference on it and is destroyed once all -references are dropped. - -We provide a template in `refs_template.go` that can be applied to most objects -in need of reference counting. It contains a simple `Refs` struct that can be -incremented and decremented, and once the reference count reaches zero, a -destructor can be called. Note that there are some objects (e.g. `gofer.dentry`, -`overlay.dentry`) that should not immediately be destroyed upon reaching zero -references; in these cases, this template cannot be applied. - -# Reference Checking - -Unfortunately, manually keeping track of reference counts is extremely error -prone, and improper accounting can lead to production bugs that are very -difficult to root cause. - -We have several ways of discovering reference count errors in gVisor. Any -attempt to increment/decrement a `Refs` struct with a count of zero will trigger -a sentry panic, since the object should have been destroyed and become -unreachable. This allows us to identify missing increments or extra decrements, -which cause the reference count to be lower than it should be: the count will -reach zero earlier than expected, and the next increment/decrement--which should -be valid--will result in a panic. - -It is trickier to identify extra increments and missing decrements, which cause -the reference count to be higher than expected (i.e. a “reference leak”). -Reference leaks prevent resources from being released properly and can translate -to various issues that are tricky to diagnose, such as memory leaks. The -following section discusses how we implement leak checking. - -## Leak Checking - -When leak checking is enabled, reference-counted objects are added to a global -map when constructed and removed when destroyed. Near the very end of sandbox -execution, once no reference-counted objects should still be reachable, we -report everything left in the map as having leaked. Leak-checking objects -implement the `CheckedObject` interface, which allows us to print informative -warnings for each of the leaked objects. - -Leak checking is provided by `refs_template`, but objects that do not use the -template will also need to implement `CheckedObject` and be manually -registered/unregistered from the map in order to be checked. - -Note that leak checking affects performance and memory usage, so it should only -be enabled in testing environments. - -## Debugging - -Even with the checks described above, it can be difficult to track down the -exact source of a reference counting error. The error may occur far before it is -discovered (for instance, a missing `IncRef` may not be discovered until a -future `DecRef` makes the count negative). To aid in debugging, `refs_template` -provides the `enableLogging` option to log every `IncRef`, `DecRef`, and leak -check registration/unregistration, along with the object address and a call -stack. This allows us to search a log for all of the changes to a particular -object's reference count, which makes it much easier to identify the absent or -extraneous operation(s). The reference-counted objects that do not use -`refs_template` also provide logging, and others defined in the future should do -so as well. diff --git a/pkg/refsvfs2/refsvfs2_state_autogen.go b/pkg/refsvfs2/refsvfs2_state_autogen.go new file mode 100644 index 000000000..ca5fbb104 --- /dev/null +++ b/pkg/refsvfs2/refsvfs2_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package refsvfs2 diff --git a/pkg/ring0/BUILD b/pkg/ring0/BUILD deleted file mode 100644 index fda6ba601..000000000 --- a/pkg/ring0/BUILD +++ /dev/null @@ -1,86 +0,0 @@ -load("//tools:defs.bzl", "arch_genrule", "go_library") -load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance") - -package(licenses = ["notice"]) - -go_template( - name = "defs_amd64", - srcs = [ - "defs.go", - "defs_amd64.go", - "offsets_amd64.go", - "x86.go", - ], - visibility = [":__subpackages__"], -) - -go_template( - name = "defs_arm64", - srcs = [ - "aarch64.go", - "defs.go", - "defs_arm64.go", - "offsets_arm64.go", - ], - visibility = [":__subpackages__"], -) - -go_template_instance( - name = "defs_impl_amd64", - out = "defs_impl_amd64.go", - package = "ring0", - template = ":defs_amd64", -) - -go_template_instance( - name = "defs_impl_arm64", - out = "defs_impl_arm64.go", - package = "ring0", - template = ":defs_arm64", -) - -arch_genrule( - name = "entry_impl_amd64", - srcs = ["entry_amd64.s"], - outs = ["entry_impl_amd64.s"], - cmd = "(echo -e '// build +amd64\\n' && QEMU $(location //pkg/ring0/gen_offsets) && cat $(location entry_amd64.s)) > $@", - tools = ["//pkg/ring0/gen_offsets"], -) - -arch_genrule( - name = "entry_impl_arm64", - srcs = ["entry_arm64.s"], - outs = ["entry_impl_arm64.s"], - cmd = "(echo -e '// build +arm64\\n' && QEMU $(location //pkg/ring0/gen_offsets) && cat $(location entry_arm64.s)) > $@", - tools = ["//pkg/ring0/gen_offsets"], -) - -go_library( - name = "ring0", - srcs = [ - "defs_impl_amd64.go", - "defs_impl_arm64.go", - "entry_amd64.go", - "entry_arm64.go", - "entry_impl_amd64.s", - "entry_impl_arm64.s", - "kernel.go", - "kernel_amd64.go", - "kernel_arm64.go", - "kernel_unsafe.go", - "lib_amd64.go", - "lib_amd64.s", - "lib_arm64.go", - "lib_arm64.s", - "ring0.go", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/cpuid", - "//pkg/hostarch", - "//pkg/ring0/pagetables", - "//pkg/safecopy", - "//pkg/sentry/arch", - "//pkg/sentry/arch/fpu", - ], -) diff --git a/pkg/ring0/aarch64.go b/pkg/ring0/aarch64.go deleted file mode 100644 index 96c884844..000000000 --- a/pkg/ring0/aarch64.go +++ /dev/null @@ -1,123 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//go:build arm64 -// +build arm64 - -package ring0 - -// Useful bits. -const ( - _PGD_PGT_BASE = 0x1000 - _PGD_PGT_SIZE = 0x1000 - _PUD_PGT_BASE = 0x2000 - _PUD_PGT_SIZE = 0x1000 - _PMD_PGT_BASE = 0x3000 - _PMD_PGT_SIZE = 0x4000 - _PTE_PGT_BASE = 0x7000 - _PTE_PGT_SIZE = 0x1000 -) - -const ( - // DAIF bits:debug, sError, IRQ, FIQ. - _PSR_D_BIT = 0x00000200 - _PSR_A_BIT = 0x00000100 - _PSR_I_BIT = 0x00000080 - _PSR_F_BIT = 0x00000040 - _PSR_DAIF_SHIFT = 6 - _PSR_DAIF_MASK = 0xf << _PSR_DAIF_SHIFT - - // PSR bits. - _PSR_MODE_EL0t = 0x00000000 - _PSR_MODE_EL1t = 0x00000004 - _PSR_MODE_EL1h = 0x00000005 - _PSR_MODE_MASK = 0x0000000f - - PsrFlagsClear = _PSR_MODE_MASK | _PSR_DAIF_MASK - PsrModeMask = _PSR_MODE_MASK - - // KernelFlagsSet should always be set in the kernel. - KernelFlagsSet = _PSR_MODE_EL1h | _PSR_D_BIT | _PSR_A_BIT | _PSR_I_BIT | _PSR_F_BIT - - // UserFlagsSet are always set in userspace. - UserFlagsSet = _PSR_MODE_EL0t -) - -// Vector is an exception vector. -type Vector uintptr - -// Exception vectors. -const ( - El1InvSync = iota - El1InvIrq - El1InvFiq - El1InvError - - El1Sync - El1Irq - El1Fiq - El1Err - - El0Sync - El0Irq - El0Fiq - El0Err - - El0InvSync - El0InvIrq - El0InvFiq - El0InvErr - - El1SyncDa - El1SyncIa - El1SyncSpPc - El1SyncUndef - El1SyncDbg - El1SyncInv - - El0SyncSVC - El0SyncDa - El0SyncIa - El0SyncFpsimdAcc - El0SyncSveAcc - El0SyncFpsimdExc - El0SyncSys - El0SyncSpPc - El0SyncUndef - El0SyncDbg - El0SyncWfx - El0SyncInv - - El0ErrNMI - El0ErrBounce - - _NR_INTERRUPTS -) - -// System call vectors. -const ( - Syscall Vector = El0SyncSVC - PageFault Vector = El0SyncDa - VirtualizationException Vector = El0ErrBounce -) - -// VirtualAddressBits returns the number bits available for virtual addresses. -func VirtualAddressBits() uint32 { - return 48 -} - -// PhysicalAddressBits returns the number of bits available for physical addresses. -func PhysicalAddressBits() uint32 { - return 40 -} diff --git a/pkg/ring0/defs.go b/pkg/ring0/defs.go deleted file mode 100644 index b6e2012e8..000000000 --- a/pkg/ring0/defs.go +++ /dev/null @@ -1,113 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ring0 - -import ( - "gvisor.dev/gvisor/pkg/ring0/pagetables" - "gvisor.dev/gvisor/pkg/sentry/arch" - "gvisor.dev/gvisor/pkg/sentry/arch/fpu" -) - -// Kernel is a global kernel object. -// -// This contains global state, shared by multiple CPUs. -type Kernel struct { - // PageTables are the kernel pagetables; this must be provided. - PageTables *pagetables.PageTables - - KernelArchState -} - -// Hooks are hooks for kernel functions. -type Hooks interface { - // KernelSyscall is called for kernel system calls. - // - // Return from this call will restore registers and return to the kernel: the - // registers must be modified directly. - // - // If this function is not provided, a kernel exception results in halt. - // - // This must be go:nosplit, as this will be on the interrupt stack. - // Closures are permitted, as the pointer to the closure frame is not - // passed on the stack. - KernelSyscall() - - // KernelException handles an exception during kernel execution. - // - // Return from this call will restore registers and return to the kernel: the - // registers must be modified directly. - // - // If this function is not provided, a kernel exception results in halt. - // - // This must be go:nosplit, as this will be on the interrupt stack. - // Closures are permitted, as the pointer to the closure frame is not - // passed on the stack. - KernelException(Vector) -} - -// CPU is the per-CPU struct. -type CPU struct { - // self is a self reference. - // - // This is always guaranteed to be at offset zero. - self *CPU - - // kernel is reference to the kernel that this CPU was initialized - // with. This reference is kept for garbage collection purposes: CPU - // registers may refer to objects within the Kernel object that cannot - // be safely freed. - kernel *Kernel - - // CPUArchState is architecture-specific state. - CPUArchState - - // registers is a set of registers; these may be used on kernel system - // calls and exceptions via the Registers function. - registers arch.Registers - - // hooks are kernel hooks. - hooks Hooks -} - -// Registers returns a modifiable-copy of the kernel registers. -// -// This is explicitly safe to call during KernelException and KernelSyscall. -// -//go:nosplit -func (c *CPU) Registers() *arch.Registers { - return &c.registers -} - -// SwitchOpts are passed to the Switch function. -type SwitchOpts struct { - // Registers are the user register state. - Registers *arch.Registers - - // FloatingPointState is a byte pointer where floating point state is - // saved and restored. - FloatingPointState *fpu.State - - // PageTables are the application page tables. - PageTables *pagetables.PageTables - - // Flush indicates that a TLB flush should be forced on switch. - Flush bool - - // FullRestore indicates that an iret-based restore should be used. - FullRestore bool - - // SwitchArchOpts are architecture-specific options. - SwitchArchOpts -} diff --git a/pkg/ring0/defs_amd64.go b/pkg/ring0/defs_amd64.go deleted file mode 100644 index 24f6e4cde..000000000 --- a/pkg/ring0/defs_amd64.go +++ /dev/null @@ -1,162 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//go:build amd64 -// +build amd64 - -package ring0 - -import ( - "gvisor.dev/gvisor/pkg/hostarch" -) - -var ( - // UserspaceSize is the total size of userspace. - UserspaceSize = uintptr(1) << (VirtualAddressBits() - 1) - - // MaximumUserAddress is the largest possible user address. - MaximumUserAddress = (UserspaceSize - 1) & ^uintptr(hostarch.PageSize-1) - - // KernelStartAddress is the starting kernel address. - KernelStartAddress = ^uintptr(0) - (UserspaceSize - 1) -) - -// Segment indices and Selectors. -const ( - // Index into GDT array. - _ = iota // Null descriptor first. - _ // Reserved (Linux is kernel 32). - segKcode // Kernel code (64-bit). - segKdata // Kernel data. - segUcode32 // User code (32-bit). - segUdata // User data. - segUcode64 // User code (64-bit). - segTss // Task segment descriptor. - segTssHi // Upper bits for TSS. - segLast // Last segment (terminal, not included). -) - -// Selectors. -const ( - Kcode Selector = segKcode << 3 - Kdata Selector = segKdata << 3 - Ucode32 Selector = (segUcode32 << 3) | 3 - Udata Selector = (segUdata << 3) | 3 - Ucode64 Selector = (segUcode64 << 3) | 3 - Tss Selector = segTss << 3 -) - -// Standard segments. -var ( - UserCodeSegment32 SegmentDescriptor - UserDataSegment SegmentDescriptor - UserCodeSegment64 SegmentDescriptor - KernelCodeSegment SegmentDescriptor - KernelDataSegment SegmentDescriptor -) - -// KernelArchState contains architecture-specific state. -type KernelArchState struct { - // cpuEntries is array of kernelEntry for all cpus. - cpuEntries []kernelEntry - - // globalIDT is our set of interrupt gates. - globalIDT *idt64 -} - -// kernelEntry contains minimal CPU-specific arch state -// that can be mapped at the upper of the address space. -// Malicious APP might steal info from it via CPU bugs. -type kernelEntry struct { - // stack is the stack used for interrupts on this CPU. - stack [256]byte - - // scratch space for temporary usage. - scratch0 uint64 - - // stackTop is the top of the stack. - stackTop uint64 - - // cpuSelf is back reference to CPU. - cpuSelf *CPU - - // kernelCR3 is the cr3 used for sentry kernel. - kernelCR3 uintptr - - // gdt is the CPU's descriptor table. - gdt descriptorTable - - // tss is the CPU's task state. - tss TaskState64 -} - -// CPUArchState contains CPU-specific arch state. -type CPUArchState struct { - // errorCode is the error code from the last exception. - errorCode uintptr - - // errorType indicates the type of error code here, it is always set - // along with the errorCode value above. - // - // It will either by 1, which indicates a user error, or 0 indicating a - // kernel error. If the error code below returns false (kernel error), - // then it cannot provide relevant information about the last - // exception. - errorType uintptr - - *kernelEntry -} - -// ErrorCode returns the last error code. -// -// The returned boolean indicates whether the error code corresponds to the -// last user error or not. If it does not, then fault information must be -// ignored. This is generally the result of a kernel fault while servicing a -// user fault. -// -//go:nosplit -func (c *CPU) ErrorCode() (value uintptr, user bool) { - return c.errorCode, c.errorType != 0 -} - -// ClearErrorCode resets the error code. -// -//go:nosplit -func (c *CPU) ClearErrorCode() { - c.errorCode = 0 // No code. - c.errorType = 1 // User mode. -} - -// SwitchArchOpts are embedded in SwitchOpts. -type SwitchArchOpts struct { - // UserPCID indicates that the application PCID to be used on switch, - // assuming that PCIDs are supported. - // - // Per pagetables_x86.go, a zero PCID implies a flush. - UserPCID uint16 - - // KernelPCID indicates that the kernel PCID to be used on return, - // assuming that PCIDs are supported. - // - // Per pagetables_x86.go, a zero PCID implies a flush. - KernelPCID uint16 -} - -func init() { - KernelCodeSegment.setCode64(0, 0, 0) - KernelDataSegment.setData(0, 0xffffffff, 0) - UserCodeSegment32.setCode64(0, 0, 3) - UserDataSegment.setData(0, 0xffffffff, 3) - UserCodeSegment64.setCode64(0, 0, 3) -} diff --git a/pkg/ring0/defs_arm64.go b/pkg/ring0/defs_arm64.go deleted file mode 100644 index 3e212516f..000000000 --- a/pkg/ring0/defs_arm64.go +++ /dev/null @@ -1,142 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//go:build arm64 -// +build arm64 - -package ring0 - -import ( - "gvisor.dev/gvisor/pkg/hostarch" -) - -var ( - // UserspaceSize is the total size of userspace. - UserspaceSize = uintptr(1) << (VirtualAddressBits()) - - // MaximumUserAddress is the largest possible user address. - MaximumUserAddress = (UserspaceSize - 1) & ^uintptr(hostarch.PageSize-1) - - // KernelStartAddress is the starting kernel address. - KernelStartAddress = ^uintptr(0) - (UserspaceSize - 1) -) - -// KernelArchState contains architecture-specific state. -type KernelArchState struct { -} - -// CPUArchState contains CPU-specific arch state. -type CPUArchState struct { - // stack is the stack used for interrupts on this CPU. - stack [128]byte - - // errorCode is the error code from the last exception. - errorCode uintptr - - // errorType indicates the type of error code here, it is always set - // along with the errorCode value above. - // - // It will either by 1, which indicates a user error, or 0 indicating a - // kernel error. If the error code below returns false (kernel error), - // then it cannot provide relevant information about the last - // exception. - errorType uintptr - - // faultAddr is the value of far_el1. - faultAddr uintptr - - // el0Fp is the address of application's fpstate. - el0Fp uintptr - - // ttbr0Kvm is the value of ttbr0_el1 for sentry. - ttbr0Kvm uintptr - - // ttbr0App is the value of ttbr0_el1 for applicaton. - ttbr0App uintptr - - // exception vector. - vecCode Vector - - // application context pointer. - appAddr uintptr - - // lazyVFP is the value of cpacr_el1. - lazyVFP uintptr - - // appASID is the asid value of guest application. - appASID uintptr -} - -// ErrorCode returns the last error code. -// -// The returned boolean indicates whether the error code corresponds to the -// last user error or not. If it does not, then fault information must be -// ignored. This is generally the result of a kernel fault while servicing a -// user fault. -// -//go:nosplit -func (c *CPU) ErrorCode() (value uintptr, user bool) { - return c.errorCode, c.errorType != 0 -} - -// ClearErrorCode resets the error code. -// -//go:nosplit -func (c *CPU) ClearErrorCode() { - c.errorCode = 0 // No code. - c.errorType = 1 // User mode. -} - -//go:nosplit -func (c *CPU) GetFaultAddr() (value uintptr) { - return c.faultAddr -} - -//go:nosplit -func (c *CPU) SetTtbr0Kvm(value uintptr) { - c.ttbr0Kvm = value -} - -//go:nosplit -func (c *CPU) SetTtbr0App(value uintptr) { - c.ttbr0App = value -} - -//go:nosplit -func (c *CPU) GetVector() (value Vector) { - return c.vecCode -} - -//go:nosplit -func (c *CPU) SetAppAddr(value uintptr) { - c.appAddr = value -} - -// GetLazyVFP returns the value of cpacr_el1. -//go:nosplit -func (c *CPU) GetLazyVFP() (value uintptr) { - return c.lazyVFP -} - -// SwitchArchOpts are embedded in SwitchOpts. -type SwitchArchOpts struct { - // UserASID indicates that the application ASID to be used on switch, - UserASID uint16 - - // KernelASID indicates that the kernel ASID to be used on return, - KernelASID uint16 -} - -func init() { -} diff --git a/pkg/ring0/defs_impl_amd64.go b/pkg/ring0/defs_impl_amd64.go new file mode 100644 index 000000000..d22b41549 --- /dev/null +++ b/pkg/ring0/defs_impl_amd64.go @@ -0,0 +1,600 @@ +//go:build amd64 && amd64 && (386 || amd64) +// +build amd64 +// +build amd64 +// +build 386 amd64 + +package ring0 + +import ( + "fmt" + "gvisor.dev/gvisor/pkg/cpuid" + "gvisor.dev/gvisor/pkg/hostarch" + "gvisor.dev/gvisor/pkg/ring0/pagetables" + "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/arch/fpu" + "io" + "reflect" +) + +// Kernel is a global kernel object. +// +// This contains global state, shared by multiple CPUs. +type Kernel struct { + // PageTables are the kernel pagetables; this must be provided. + PageTables *pagetables.PageTables + + KernelArchState +} + +// Hooks are hooks for kernel functions. +type Hooks interface { + // KernelSyscall is called for kernel system calls. + // + // Return from this call will restore registers and return to the kernel: the + // registers must be modified directly. + // + // If this function is not provided, a kernel exception results in halt. + // + // This must be go:nosplit, as this will be on the interrupt stack. + // Closures are permitted, as the pointer to the closure frame is not + // passed on the stack. + KernelSyscall() + + // KernelException handles an exception during kernel execution. + // + // Return from this call will restore registers and return to the kernel: the + // registers must be modified directly. + // + // If this function is not provided, a kernel exception results in halt. + // + // This must be go:nosplit, as this will be on the interrupt stack. + // Closures are permitted, as the pointer to the closure frame is not + // passed on the stack. + KernelException(Vector) +} + +// CPU is the per-CPU struct. +type CPU struct { + // self is a self reference. + // + // This is always guaranteed to be at offset zero. + self *CPU + + // kernel is reference to the kernel that this CPU was initialized + // with. This reference is kept for garbage collection purposes: CPU + // registers may refer to objects within the Kernel object that cannot + // be safely freed. + kernel *Kernel + + // CPUArchState is architecture-specific state. + CPUArchState + + // registers is a set of registers; these may be used on kernel system + // calls and exceptions via the Registers function. + registers arch.Registers + + // hooks are kernel hooks. + hooks Hooks +} + +// Registers returns a modifiable-copy of the kernel registers. +// +// This is explicitly safe to call during KernelException and KernelSyscall. +// +//go:nosplit +func (c *CPU) Registers() *arch.Registers { + return &c.registers +} + +// SwitchOpts are passed to the Switch function. +type SwitchOpts struct { + // Registers are the user register state. + Registers *arch.Registers + + // FloatingPointState is a byte pointer where floating point state is + // saved and restored. + FloatingPointState *fpu.State + + // PageTables are the application page tables. + PageTables *pagetables.PageTables + + // Flush indicates that a TLB flush should be forced on switch. + Flush bool + + // FullRestore indicates that an iret-based restore should be used. + FullRestore bool + + // SwitchArchOpts are architecture-specific options. + SwitchArchOpts +} + +var ( + // UserspaceSize is the total size of userspace. + UserspaceSize = uintptr(1) << (VirtualAddressBits() - 1) + + // MaximumUserAddress is the largest possible user address. + MaximumUserAddress = (UserspaceSize - 1) & ^uintptr(hostarch.PageSize-1) + + // KernelStartAddress is the starting kernel address. + KernelStartAddress = ^uintptr(0) - (UserspaceSize - 1) +) + +// Segment indices and Selectors. +const ( + // Index into GDT array. + _ = iota // Null descriptor first. + _ // Reserved (Linux is kernel 32). + segKcode // Kernel code (64-bit). + segKdata // Kernel data. + segUcode32 // User code (32-bit). + segUdata // User data. + segUcode64 // User code (64-bit). + segTss // Task segment descriptor. + segTssHi // Upper bits for TSS. + segLast // Last segment (terminal, not included). +) + +// Selectors. +const ( + Kcode Selector = segKcode << 3 + Kdata Selector = segKdata << 3 + Ucode32 Selector = (segUcode32 << 3) | 3 + Udata Selector = (segUdata << 3) | 3 + Ucode64 Selector = (segUcode64 << 3) | 3 + Tss Selector = segTss << 3 +) + +// Standard segments. +var ( + UserCodeSegment32 SegmentDescriptor + UserDataSegment SegmentDescriptor + UserCodeSegment64 SegmentDescriptor + KernelCodeSegment SegmentDescriptor + KernelDataSegment SegmentDescriptor +) + +// KernelArchState contains architecture-specific state. +type KernelArchState struct { + // cpuEntries is array of kernelEntry for all cpus. + cpuEntries []kernelEntry + + // globalIDT is our set of interrupt gates. + globalIDT *idt64 +} + +// kernelEntry contains minimal CPU-specific arch state +// that can be mapped at the upper of the address space. +// Malicious APP might steal info from it via CPU bugs. +type kernelEntry struct { + // stack is the stack used for interrupts on this CPU. + stack [256]byte + + // scratch space for temporary usage. + scratch0 uint64 + + // stackTop is the top of the stack. + stackTop uint64 + + // cpuSelf is back reference to CPU. + cpuSelf *CPU + + // kernelCR3 is the cr3 used for sentry kernel. + kernelCR3 uintptr + + // gdt is the CPU's descriptor table. + gdt descriptorTable + + // tss is the CPU's task state. + tss TaskState64 +} + +// CPUArchState contains CPU-specific arch state. +type CPUArchState struct { + // errorCode is the error code from the last exception. + errorCode uintptr + + // errorType indicates the type of error code here, it is always set + // along with the errorCode value above. + // + // It will either by 1, which indicates a user error, or 0 indicating a + // kernel error. If the error code below returns false (kernel error), + // then it cannot provide relevant information about the last + // exception. + errorType uintptr + + *kernelEntry +} + +// ErrorCode returns the last error code. +// +// The returned boolean indicates whether the error code corresponds to the +// last user error or not. If it does not, then fault information must be +// ignored. This is generally the result of a kernel fault while servicing a +// user fault. +// +//go:nosplit +func (c *CPU) ErrorCode() (value uintptr, user bool) { + return c.errorCode, c.errorType != 0 +} + +// ClearErrorCode resets the error code. +// +//go:nosplit +func (c *CPU) ClearErrorCode() { + c.errorCode = 0 + c.errorType = 1 +} + +// SwitchArchOpts are embedded in SwitchOpts. +type SwitchArchOpts struct { + // UserPCID indicates that the application PCID to be used on switch, + // assuming that PCIDs are supported. + // + // Per pagetables_x86.go, a zero PCID implies a flush. + UserPCID uint16 + + // KernelPCID indicates that the kernel PCID to be used on return, + // assuming that PCIDs are supported. + // + // Per pagetables_x86.go, a zero PCID implies a flush. + KernelPCID uint16 +} + +func init() { + KernelCodeSegment.setCode64(0, 0, 0) + KernelDataSegment.setData(0, 0xffffffff, 0) + UserCodeSegment32.setCode64(0, 0, 3) + UserDataSegment.setData(0, 0xffffffff, 3) + UserCodeSegment64.setCode64(0, 0, 3) +} + +// Emit prints architecture-specific offsets. +func Emit(w io.Writer) { + fmt.Fprintf(w, "// Automatically generated, do not edit.\n") + + c := &CPU{} + fmt.Fprintf(w, "\n// CPU offsets.\n") + fmt.Fprintf(w, "#define CPU_REGISTERS 0x%02x\n", reflect.ValueOf(&c.registers).Pointer()-reflect.ValueOf(c).Pointer()) + fmt.Fprintf(w, "#define CPU_ERROR_CODE 0x%02x\n", reflect.ValueOf(&c.errorCode).Pointer()-reflect.ValueOf(c).Pointer()) + fmt.Fprintf(w, "#define CPU_ERROR_TYPE 0x%02x\n", reflect.ValueOf(&c.errorType).Pointer()-reflect.ValueOf(c).Pointer()) + fmt.Fprintf(w, "#define CPU_ENTRY 0x%02x\n", reflect.ValueOf(&c.kernelEntry).Pointer()-reflect.ValueOf(c).Pointer()) + + e := &kernelEntry{} + fmt.Fprintf(w, "\n// CPU entry offsets.\n") + fmt.Fprintf(w, "#define ENTRY_SCRATCH0 0x%02x\n", reflect.ValueOf(&e.scratch0).Pointer()-reflect.ValueOf(e).Pointer()) + fmt.Fprintf(w, "#define ENTRY_STACK_TOP 0x%02x\n", reflect.ValueOf(&e.stackTop).Pointer()-reflect.ValueOf(e).Pointer()) + fmt.Fprintf(w, "#define ENTRY_CPU_SELF 0x%02x\n", reflect.ValueOf(&e.cpuSelf).Pointer()-reflect.ValueOf(e).Pointer()) + fmt.Fprintf(w, "#define ENTRY_KERNEL_CR3 0x%02x\n", reflect.ValueOf(&e.kernelCR3).Pointer()-reflect.ValueOf(e).Pointer()) + + fmt.Fprintf(w, "\n// Bits.\n") + fmt.Fprintf(w, "#define _RFLAGS_IF 0x%02x\n", _RFLAGS_IF) + fmt.Fprintf(w, "#define _RFLAGS_IOPL0 0x%02x\n", _RFLAGS_IOPL0) + fmt.Fprintf(w, "#define _KERNEL_FLAGS 0x%02x\n", KernelFlagsSet) + + fmt.Fprintf(w, "\n// Vectors.\n") + fmt.Fprintf(w, "#define DivideByZero 0x%02x\n", DivideByZero) + fmt.Fprintf(w, "#define Debug 0x%02x\n", Debug) + fmt.Fprintf(w, "#define NMI 0x%02x\n", NMI) + fmt.Fprintf(w, "#define Breakpoint 0x%02x\n", Breakpoint) + fmt.Fprintf(w, "#define Overflow 0x%02x\n", Overflow) + fmt.Fprintf(w, "#define BoundRangeExceeded 0x%02x\n", BoundRangeExceeded) + fmt.Fprintf(w, "#define InvalidOpcode 0x%02x\n", InvalidOpcode) + fmt.Fprintf(w, "#define DeviceNotAvailable 0x%02x\n", DeviceNotAvailable) + fmt.Fprintf(w, "#define DoubleFault 0x%02x\n", DoubleFault) + fmt.Fprintf(w, "#define CoprocessorSegmentOverrun 0x%02x\n", CoprocessorSegmentOverrun) + fmt.Fprintf(w, "#define InvalidTSS 0x%02x\n", InvalidTSS) + fmt.Fprintf(w, "#define SegmentNotPresent 0x%02x\n", SegmentNotPresent) + fmt.Fprintf(w, "#define StackSegmentFault 0x%02x\n", StackSegmentFault) + fmt.Fprintf(w, "#define GeneralProtectionFault 0x%02x\n", GeneralProtectionFault) + fmt.Fprintf(w, "#define PageFault 0x%02x\n", PageFault) + fmt.Fprintf(w, "#define X87FloatingPointException 0x%02x\n", X87FloatingPointException) + fmt.Fprintf(w, "#define AlignmentCheck 0x%02x\n", AlignmentCheck) + fmt.Fprintf(w, "#define MachineCheck 0x%02x\n", MachineCheck) + fmt.Fprintf(w, "#define SIMDFloatingPointException 0x%02x\n", SIMDFloatingPointException) + fmt.Fprintf(w, "#define VirtualizationException 0x%02x\n", VirtualizationException) + fmt.Fprintf(w, "#define SecurityException 0x%02x\n", SecurityException) + fmt.Fprintf(w, "#define SyscallInt80 0x%02x\n", SyscallInt80) + fmt.Fprintf(w, "#define Syscall 0x%02x\n", Syscall) + + p := &arch.Registers{} + fmt.Fprintf(w, "\n// Ptrace registers.\n") + fmt.Fprintf(w, "#define PTRACE_R15 0x%02x\n", reflect.ValueOf(&p.R15).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R14 0x%02x\n", reflect.ValueOf(&p.R14).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R13 0x%02x\n", reflect.ValueOf(&p.R13).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R12 0x%02x\n", reflect.ValueOf(&p.R12).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_RBP 0x%02x\n", reflect.ValueOf(&p.Rbp).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_RBX 0x%02x\n", reflect.ValueOf(&p.Rbx).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R11 0x%02x\n", reflect.ValueOf(&p.R11).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R10 0x%02x\n", reflect.ValueOf(&p.R10).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R9 0x%02x\n", reflect.ValueOf(&p.R9).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R8 0x%02x\n", reflect.ValueOf(&p.R8).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_RAX 0x%02x\n", reflect.ValueOf(&p.Rax).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_RCX 0x%02x\n", reflect.ValueOf(&p.Rcx).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_RDX 0x%02x\n", reflect.ValueOf(&p.Rdx).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_RSI 0x%02x\n", reflect.ValueOf(&p.Rsi).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_RDI 0x%02x\n", reflect.ValueOf(&p.Rdi).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_ORIGRAX 0x%02x\n", reflect.ValueOf(&p.Orig_rax).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_RIP 0x%02x\n", reflect.ValueOf(&p.Rip).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_CS 0x%02x\n", reflect.ValueOf(&p.Cs).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_FLAGS 0x%02x\n", reflect.ValueOf(&p.Eflags).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_RSP 0x%02x\n", reflect.ValueOf(&p.Rsp).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_SS 0x%02x\n", reflect.ValueOf(&p.Ss).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_FS_BASE 0x%02x\n", reflect.ValueOf(&p.Fs_base).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_GS_BASE 0x%02x\n", reflect.ValueOf(&p.Gs_base).Pointer()-reflect.ValueOf(p).Pointer()) +} + +// Useful bits. +const ( + _CR0_PE = 1 << 0 + _CR0_ET = 1 << 4 + _CR0_NE = 1 << 5 + _CR0_AM = 1 << 18 + _CR0_PG = 1 << 31 + + _CR4_PSE = 1 << 4 + _CR4_PAE = 1 << 5 + _CR4_PGE = 1 << 7 + _CR4_OSFXSR = 1 << 9 + _CR4_OSXMMEXCPT = 1 << 10 + _CR4_FSGSBASE = 1 << 16 + _CR4_PCIDE = 1 << 17 + _CR4_OSXSAVE = 1 << 18 + _CR4_SMEP = 1 << 20 + + _RFLAGS_AC = 1 << 18 + _RFLAGS_NT = 1 << 14 + _RFLAGS_IOPL0 = 1 << 12 + _RFLAGS_IOPL1 = 1 << 13 + _RFLAGS_IOPL = _RFLAGS_IOPL0 | _RFLAGS_IOPL1 + _RFLAGS_DF = 1 << 10 + _RFLAGS_IF = 1 << 9 + _RFLAGS_STEP = 1 << 8 + _RFLAGS_RESERVED = 1 << 1 + + _EFER_SCE = 0x001 + _EFER_LME = 0x100 + _EFER_LMA = 0x400 + _EFER_NX = 0x800 + + _MSR_STAR = 0xc0000081 + _MSR_LSTAR = 0xc0000082 + _MSR_CSTAR = 0xc0000083 + _MSR_SYSCALL_MASK = 0xc0000084 + _MSR_PLATFORM_INFO = 0xce + _MSR_MISC_FEATURES = 0x140 + + _PLATFORM_INFO_CPUID_FAULT = 1 << 31 + + _MISC_FEATURE_CPUID_TRAP = 0x1 +) + +const ( + // KernelFlagsSet should always be set in the kernel. + KernelFlagsSet = _RFLAGS_RESERVED + + // UserFlagsSet are always set in userspace. + // + // _RFLAGS_IOPL is a set of two bits and it shows the I/O privilege + // level. The Current Privilege Level (CPL) of the task must be less + // than or equal to the IOPL in order for the task or program to access + // I/O ports. + // + // Here, _RFLAGS_IOPL0 is used only to determine whether the task is + // running in the kernel or userspace mode. In the user mode, the CPL is + // always 3 and it doesn't matter what IOPL is set if it is bellow CPL. + // + // We need to have one bit which will be always different in user and + // kernel modes. And we have to remember that even though we have + // KernelFlagsClear, we still can see some of these flags in the kernel + // mode. This can happen when the goruntime switches on a goroutine + // which has been saved in the host mode. On restore, the popf + // instruction is used to restore flags and this means that all flags + // what the goroutine has in the host mode will be restored in the + // kernel mode. + // + // _RFLAGS_IOPL0 is never set in host and kernel modes and we always set + // it in the user mode. So if this flag is set, the task is running in + // the user mode and if it isn't set, the task is running in the kernel + // mode. + UserFlagsSet = _RFLAGS_RESERVED | _RFLAGS_IF | _RFLAGS_IOPL0 + + // KernelFlagsClear should always be clear in the kernel. + KernelFlagsClear = _RFLAGS_STEP | _RFLAGS_IF | _RFLAGS_IOPL | _RFLAGS_AC | _RFLAGS_NT + + // UserFlagsClear are always cleared in userspace. + UserFlagsClear = _RFLAGS_NT | _RFLAGS_IOPL1 +) + +// IsKernelFlags returns true if rflags coresponds to the kernel mode. +// +// go:nosplit +func IsKernelFlags(rflags uint64) bool { + return rflags&_RFLAGS_IOPL0 == 0 +} + +// Vector is an exception vector. +type Vector uintptr + +// Exception vectors. +const ( + DivideByZero Vector = iota + Debug + NMI + Breakpoint + Overflow + BoundRangeExceeded + InvalidOpcode + DeviceNotAvailable + DoubleFault + CoprocessorSegmentOverrun + InvalidTSS + SegmentNotPresent + StackSegmentFault + GeneralProtectionFault + PageFault + _ + X87FloatingPointException + AlignmentCheck + MachineCheck + SIMDFloatingPointException + VirtualizationException + SecurityException = 0x1e + SyscallInt80 = 0x80 + _NR_INTERRUPTS = 0x100 +) + +// System call vectors. +const ( + Syscall Vector = _NR_INTERRUPTS +) + +// VirtualAddressBits returns the number bits available for virtual addresses. +// +// Note that sign-extension semantics apply to the highest order bit. +// +// FIXME(b/69382326): This should use the cpuid passed to Init. +func VirtualAddressBits() uint32 { + ax, _, _, _ := cpuid.HostID(0x80000008, 0) + return (ax >> 8) & 0xff +} + +// PhysicalAddressBits returns the number of bits available for physical addresses. +// +// FIXME(b/69382326): This should use the cpuid passed to Init. +func PhysicalAddressBits() uint32 { + ax, _, _, _ := cpuid.HostID(0x80000008, 0) + return ax & 0xff +} + +// Selector is a segment Selector. +type Selector uint16 + +// SegmentDescriptor is a segment descriptor. +type SegmentDescriptor struct { + bits [2]uint32 +} + +// descriptorTable is a collection of descriptors. +type descriptorTable [32]SegmentDescriptor + +// SegmentDescriptorFlags are typed flags within a descriptor. +type SegmentDescriptorFlags uint32 + +// SegmentDescriptorFlag declarations. +const ( + SegmentDescriptorAccess SegmentDescriptorFlags = 1 << 8 // Access bit (always set). + SegmentDescriptorWrite = 1 << 9 // Write permission. + SegmentDescriptorExpandDown = 1 << 10 // Grows down, not used. + SegmentDescriptorExecute = 1 << 11 // Execute permission. + SegmentDescriptorSystem = 1 << 12 // Zero => system, 1 => user code/data. + SegmentDescriptorPresent = 1 << 15 // Present. + SegmentDescriptorAVL = 1 << 20 // Available. + SegmentDescriptorLong = 1 << 21 // Long mode. + SegmentDescriptorDB = 1 << 22 // 16 or 32-bit. + SegmentDescriptorG = 1 << 23 // Granularity: page or byte. +) + +// Base returns the descriptor's base linear address. +func (d *SegmentDescriptor) Base() uint32 { + return d.bits[1]&0xFF000000 | (d.bits[1]&0x000000FF)<<16 | d.bits[0]>>16 +} + +// Limit returns the descriptor size. +func (d *SegmentDescriptor) Limit() uint32 { + l := d.bits[0]&0xFFFF | d.bits[1]&0xF0000 + if d.bits[1]&uint32(SegmentDescriptorG) != 0 { + l <<= 12 + l |= 0xFFF + } + return l +} + +// Flags returns descriptor flags. +func (d *SegmentDescriptor) Flags() SegmentDescriptorFlags { + return SegmentDescriptorFlags(d.bits[1] & 0x00F09F00) +} + +// DPL returns the descriptor privilege level. +func (d *SegmentDescriptor) DPL() int { + return int((d.bits[1] >> 13) & 3) +} + +func (d *SegmentDescriptor) setNull() { + d.bits[0] = 0 + d.bits[1] = 0 +} + +func (d *SegmentDescriptor) set(base, limit uint32, dpl int, flags SegmentDescriptorFlags) { + flags |= SegmentDescriptorPresent + if limit>>12 != 0 { + limit >>= 12 + flags |= SegmentDescriptorG + } + d.bits[0] = base<<16 | limit&0xFFFF + d.bits[1] = base&0xFF000000 | (base>>16)&0xFF | limit&0x000F0000 | uint32(flags) | uint32(dpl)<<13 +} + +func (d *SegmentDescriptor) setCode32(base, limit uint32, dpl int) { + d.set(base, limit, dpl, + SegmentDescriptorDB| + SegmentDescriptorExecute| + SegmentDescriptorSystem) +} + +func (d *SegmentDescriptor) setCode64(base, limit uint32, dpl int) { + d.set(base, limit, dpl, + SegmentDescriptorG| + SegmentDescriptorLong| + SegmentDescriptorExecute| + SegmentDescriptorSystem) +} + +func (d *SegmentDescriptor) setData(base, limit uint32, dpl int) { + d.set(base, limit, dpl, + SegmentDescriptorWrite| + SegmentDescriptorSystem) +} + +// setHi is only used for the TSS segment, which is magically 64-bits. +func (d *SegmentDescriptor) setHi(base uint32) { + d.bits[0] = base + d.bits[1] = 0 +} + +// Gate64 is a 64-bit task, trap, or interrupt gate. +type Gate64 struct { + bits [4]uint32 +} + +// idt64 is a 64-bit interrupt descriptor table. +type idt64 [_NR_INTERRUPTS]Gate64 + +func (g *Gate64) setInterrupt(cs Selector, rip uint64, dpl int, ist int) { + g.bits[0] = uint32(cs)<<16 | uint32(rip)&0xFFFF + g.bits[1] = uint32(rip)&0xFFFF0000 | SegmentDescriptorPresent | uint32(dpl)<<13 | 14<<8 | uint32(ist)&0x7 + g.bits[2] = uint32(rip >> 32) +} + +func (g *Gate64) setTrap(cs Selector, rip uint64, dpl int, ist int) { + g.setInterrupt(cs, rip, dpl, ist) + g.bits[1] |= 1 << 8 +} + +// TaskState64 is a 64-bit task state structure. +type TaskState64 struct { + _ uint32 + rsp0Lo, rsp0Hi uint32 + rsp1Lo, rsp1Hi uint32 + rsp2Lo, rsp2Hi uint32 + _ [2]uint32 + ist1Lo, ist1Hi uint32 + ist2Lo, ist2Hi uint32 + ist3Lo, ist3Hi uint32 + ist4Lo, ist4Hi uint32 + ist5Lo, ist5Hi uint32 + ist6Lo, ist6Hi uint32 + ist7Lo, ist7Hi uint32 + _ [2]uint32 + _ uint16 + ioPerm uint16 +} diff --git a/pkg/ring0/offsets_arm64.go b/pkg/ring0/defs_impl_arm64.go index 60b2c4074..c3c543c88 100644 --- a/pkg/ring0/offsets_arm64.go +++ b/pkg/ring0/defs_impl_arm64.go @@ -1,30 +1,335 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//go:build arm64 -// +build arm64 +//go:build arm64 && arm64 && arm64 +// +build arm64,arm64,arm64 package ring0 import ( "fmt" + "gvisor.dev/gvisor/pkg/hostarch" + "gvisor.dev/gvisor/pkg/ring0/pagetables" + "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/arch/fpu" "io" "reflect" +) - "gvisor.dev/gvisor/pkg/sentry/arch" +// Useful bits. +const ( + _PGD_PGT_BASE = 0x1000 + _PGD_PGT_SIZE = 0x1000 + _PUD_PGT_BASE = 0x2000 + _PUD_PGT_SIZE = 0x1000 + _PMD_PGT_BASE = 0x3000 + _PMD_PGT_SIZE = 0x4000 + _PTE_PGT_BASE = 0x7000 + _PTE_PGT_SIZE = 0x1000 +) + +const ( + // DAIF bits:debug, sError, IRQ, FIQ. + _PSR_D_BIT = 0x00000200 + _PSR_A_BIT = 0x00000100 + _PSR_I_BIT = 0x00000080 + _PSR_F_BIT = 0x00000040 + _PSR_DAIF_SHIFT = 6 + _PSR_DAIF_MASK = 0xf << _PSR_DAIF_SHIFT + + // PSR bits. + _PSR_MODE_EL0t = 0x00000000 + _PSR_MODE_EL1t = 0x00000004 + _PSR_MODE_EL1h = 0x00000005 + _PSR_MODE_MASK = 0x0000000f + + PsrFlagsClear = _PSR_MODE_MASK | _PSR_DAIF_MASK + PsrModeMask = _PSR_MODE_MASK + + // KernelFlagsSet should always be set in the kernel. + KernelFlagsSet = _PSR_MODE_EL1h | _PSR_D_BIT | _PSR_A_BIT | _PSR_I_BIT | _PSR_F_BIT + + // UserFlagsSet are always set in userspace. + UserFlagsSet = _PSR_MODE_EL0t +) + +// Vector is an exception vector. +type Vector uintptr + +// Exception vectors. +const ( + El1InvSync = iota + El1InvIrq + El1InvFiq + El1InvError + + El1Sync + El1Irq + El1Fiq + El1Err + + El0Sync + El0Irq + El0Fiq + El0Err + + El0InvSync + El0InvIrq + El0InvFiq + El0InvErr + + El1SyncDa + El1SyncIa + El1SyncSpPc + El1SyncUndef + El1SyncDbg + El1SyncInv + + El0SyncSVC + El0SyncDa + El0SyncIa + El0SyncFpsimdAcc + El0SyncSveAcc + El0SyncFpsimdExc + El0SyncSys + El0SyncSpPc + El0SyncUndef + El0SyncDbg + El0SyncWfx + El0SyncInv + + El0ErrNMI + El0ErrBounce + + _NR_INTERRUPTS +) + +// System call vectors. +const ( + Syscall Vector = El0SyncSVC + PageFault Vector = El0SyncDa + VirtualizationException Vector = El0ErrBounce ) +// VirtualAddressBits returns the number bits available for virtual addresses. +func VirtualAddressBits() uint32 { + return 48 +} + +// PhysicalAddressBits returns the number of bits available for physical addresses. +func PhysicalAddressBits() uint32 { + return 40 +} + +// Kernel is a global kernel object. +// +// This contains global state, shared by multiple CPUs. +type Kernel struct { + // PageTables are the kernel pagetables; this must be provided. + PageTables *pagetables.PageTables + + KernelArchState +} + +// Hooks are hooks for kernel functions. +type Hooks interface { + // KernelSyscall is called for kernel system calls. + // + // Return from this call will restore registers and return to the kernel: the + // registers must be modified directly. + // + // If this function is not provided, a kernel exception results in halt. + // + // This must be go:nosplit, as this will be on the interrupt stack. + // Closures are permitted, as the pointer to the closure frame is not + // passed on the stack. + KernelSyscall() + + // KernelException handles an exception during kernel execution. + // + // Return from this call will restore registers and return to the kernel: the + // registers must be modified directly. + // + // If this function is not provided, a kernel exception results in halt. + // + // This must be go:nosplit, as this will be on the interrupt stack. + // Closures are permitted, as the pointer to the closure frame is not + // passed on the stack. + KernelException(Vector) +} + +// CPU is the per-CPU struct. +type CPU struct { + // self is a self reference. + // + // This is always guaranteed to be at offset zero. + self *CPU + + // kernel is reference to the kernel that this CPU was initialized + // with. This reference is kept for garbage collection purposes: CPU + // registers may refer to objects within the Kernel object that cannot + // be safely freed. + kernel *Kernel + + // CPUArchState is architecture-specific state. + CPUArchState + + // registers is a set of registers; these may be used on kernel system + // calls and exceptions via the Registers function. + registers arch.Registers + + // hooks are kernel hooks. + hooks Hooks +} + +// Registers returns a modifiable-copy of the kernel registers. +// +// This is explicitly safe to call during KernelException and KernelSyscall. +// +//go:nosplit +func (c *CPU) Registers() *arch.Registers { + return &c.registers +} + +// SwitchOpts are passed to the Switch function. +type SwitchOpts struct { + // Registers are the user register state. + Registers *arch.Registers + + // FloatingPointState is a byte pointer where floating point state is + // saved and restored. + FloatingPointState *fpu.State + + // PageTables are the application page tables. + PageTables *pagetables.PageTables + + // Flush indicates that a TLB flush should be forced on switch. + Flush bool + + // FullRestore indicates that an iret-based restore should be used. + FullRestore bool + + // SwitchArchOpts are architecture-specific options. + SwitchArchOpts +} + +var ( + // UserspaceSize is the total size of userspace. + UserspaceSize = uintptr(1) << (VirtualAddressBits()) + + // MaximumUserAddress is the largest possible user address. + MaximumUserAddress = (UserspaceSize - 1) & ^uintptr(hostarch.PageSize-1) + + // KernelStartAddress is the starting kernel address. + KernelStartAddress = ^uintptr(0) - (UserspaceSize - 1) +) + +// KernelArchState contains architecture-specific state. +type KernelArchState struct { +} + +// CPUArchState contains CPU-specific arch state. +type CPUArchState struct { + // stack is the stack used for interrupts on this CPU. + stack [128]byte + + // errorCode is the error code from the last exception. + errorCode uintptr + + // errorType indicates the type of error code here, it is always set + // along with the errorCode value above. + // + // It will either by 1, which indicates a user error, or 0 indicating a + // kernel error. If the error code below returns false (kernel error), + // then it cannot provide relevant information about the last + // exception. + errorType uintptr + + // faultAddr is the value of far_el1. + faultAddr uintptr + + // el0Fp is the address of application's fpstate. + el0Fp uintptr + + // ttbr0Kvm is the value of ttbr0_el1 for sentry. + ttbr0Kvm uintptr + + // ttbr0App is the value of ttbr0_el1 for applicaton. + ttbr0App uintptr + + // exception vector. + vecCode Vector + + // application context pointer. + appAddr uintptr + + // lazyVFP is the value of cpacr_el1. + lazyVFP uintptr + + // appASID is the asid value of guest application. + appASID uintptr +} + +// ErrorCode returns the last error code. +// +// The returned boolean indicates whether the error code corresponds to the +// last user error or not. If it does not, then fault information must be +// ignored. This is generally the result of a kernel fault while servicing a +// user fault. +// +//go:nosplit +func (c *CPU) ErrorCode() (value uintptr, user bool) { + return c.errorCode, c.errorType != 0 +} + +// ClearErrorCode resets the error code. +// +//go:nosplit +func (c *CPU) ClearErrorCode() { + c.errorCode = 0 + c.errorType = 1 +} + +//go:nosplit +func (c *CPU) GetFaultAddr() (value uintptr) { + return c.faultAddr +} + +//go:nosplit +func (c *CPU) SetTtbr0Kvm(value uintptr) { + c.ttbr0Kvm = value +} + +//go:nosplit +func (c *CPU) SetTtbr0App(value uintptr) { + c.ttbr0App = value +} + +//go:nosplit +func (c *CPU) GetVector() (value Vector) { + return c.vecCode +} + +//go:nosplit +func (c *CPU) SetAppAddr(value uintptr) { + c.appAddr = value +} + +// GetLazyVFP returns the value of cpacr_el1. +//go:nosplit +func (c *CPU) GetLazyVFP() (value uintptr) { + return c.lazyVFP +} + +// SwitchArchOpts are embedded in SwitchOpts. +type SwitchArchOpts struct { + // UserASID indicates that the application ASID to be used on switch, + UserASID uint16 + + // KernelASID indicates that the kernel ASID to be used on return, + KernelASID uint16 +} + +func init() { +} + // Emit prints architecture-specific offsets. func Emit(w io.Writer) { fmt.Fprintf(w, "// Automatically generated, do not edit.\n") diff --git a/pkg/ring0/entry_amd64.s b/pkg/ring0/entry_impl_amd64.s index 520bd9f57..1d0262a18 100644 --- a/pkg/ring0/entry_amd64.s +++ b/pkg/ring0/entry_impl_amd64.s @@ -1,3 +1,73 @@ +// build +amd64 + +// Automatically generated, do not edit. + +// CPU offsets. +#define CPU_REGISTERS 0x28 +#define CPU_ERROR_CODE 0x10 +#define CPU_ERROR_TYPE 0x18 +#define CPU_ENTRY 0x20 + +// CPU entry offsets. +#define ENTRY_SCRATCH0 0x100 +#define ENTRY_STACK_TOP 0x108 +#define ENTRY_CPU_SELF 0x110 +#define ENTRY_KERNEL_CR3 0x118 + +// Bits. +#define _RFLAGS_IF 0x200 +#define _RFLAGS_IOPL0 0x1000 +#define _KERNEL_FLAGS 0x02 + +// Vectors. +#define DivideByZero 0x00 +#define Debug 0x01 +#define NMI 0x02 +#define Breakpoint 0x03 +#define Overflow 0x04 +#define BoundRangeExceeded 0x05 +#define InvalidOpcode 0x06 +#define DeviceNotAvailable 0x07 +#define DoubleFault 0x08 +#define CoprocessorSegmentOverrun 0x09 +#define InvalidTSS 0x0a +#define SegmentNotPresent 0x0b +#define StackSegmentFault 0x0c +#define GeneralProtectionFault 0x0d +#define PageFault 0x0e +#define X87FloatingPointException 0x10 +#define AlignmentCheck 0x11 +#define MachineCheck 0x12 +#define SIMDFloatingPointException 0x13 +#define VirtualizationException 0x14 +#define SecurityException 0x1e +#define SyscallInt80 0x80 +#define Syscall 0x100 + +// Ptrace registers. +#define PTRACE_R15 0x00 +#define PTRACE_R14 0x08 +#define PTRACE_R13 0x10 +#define PTRACE_R12 0x18 +#define PTRACE_RBP 0x20 +#define PTRACE_RBX 0x28 +#define PTRACE_R11 0x30 +#define PTRACE_R10 0x38 +#define PTRACE_R9 0x40 +#define PTRACE_R8 0x48 +#define PTRACE_RAX 0x50 +#define PTRACE_RCX 0x58 +#define PTRACE_RDX 0x60 +#define PTRACE_RSI 0x68 +#define PTRACE_RDI 0x70 +#define PTRACE_ORIGRAX 0x78 +#define PTRACE_RIP 0x80 +#define PTRACE_CS 0x88 +#define PTRACE_FLAGS 0x90 +#define PTRACE_RSP 0x98 +#define PTRACE_SS 0xa0 +#define PTRACE_FS_BASE 0xa8 +#define PTRACE_GS_BASE 0xb0 // Copyright 2018 The gVisor Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/pkg/ring0/entry_arm64.s b/pkg/ring0/entry_impl_arm64.s index f801b8e11..9cc09524a 100644 --- a/pkg/ring0/entry_arm64.s +++ b/pkg/ring0/entry_impl_arm64.s @@ -1,3 +1,93 @@ +// build +arm64 + +// Automatically generated, do not edit. + +// CPU offsets. +#define CPU_SELF 0x00 +#define CPU_REGISTERS 0xe0 +#define CPU_STACK_TOP 0x90 +#define CPU_ERROR_CODE 0x90 +#define CPU_ERROR_TYPE 0x98 +#define CPU_FAULT_ADDR 0xa0 +#define CPU_FPSTATE_EL0 0xa8 +#define CPU_TTBR0_KVM 0xb0 +#define CPU_TTBR0_APP 0xb8 +#define CPU_VECTOR_CODE 0xc0 +#define CPU_APP_ADDR 0xc8 +#define CPU_LAZY_VFP 0xd0 +#define CPU_APP_ASID 0xd8 + +// Bits. +#define _KERNEL_FLAGS 0x3c5 + +// Vectors. +#define El1Sync 0x04 +#define El1Irq 0x05 +#define El1Fiq 0x06 +#define El1Err 0x07 +#define El0Sync 0x08 +#define El0Irq 0x09 +#define El0Fiq 0x0a +#define El0Err 0x0b +#define El1SyncDa 0x10 +#define El1SyncIa 0x11 +#define El1SyncSpPc 0x12 +#define El1SyncUndef 0x13 +#define El1SyncDbg 0x14 +#define El1SyncInv 0x15 +#define El0SyncSVC 0x16 +#define El0SyncDa 0x17 +#define El0SyncIa 0x18 +#define El0SyncFpsimdAcc 0x19 +#define El0SyncSveAcc 0x1a +#define El0SyncFpsimdExc 0x1b +#define El0SyncSys 0x1c +#define El0SyncSpPc 0x1d +#define El0SyncUndef 0x1e +#define El0SyncDbg 0x1f +#define El0SyncWfx 0x20 +#define El0SyncInv 0x21 +#define El0ErrNMI 0x22 +#define PageFault 0x17 +#define Syscall 0x16 +#define VirtualizationException 0x23 + +// Ptrace registers. +#define PTRACE_R0 0x00 +#define PTRACE_R1 0x08 +#define PTRACE_R2 0x10 +#define PTRACE_R3 0x18 +#define PTRACE_R4 0x20 +#define PTRACE_R5 0x28 +#define PTRACE_R6 0x30 +#define PTRACE_R7 0x38 +#define PTRACE_R8 0x40 +#define PTRACE_R9 0x48 +#define PTRACE_R10 0x50 +#define PTRACE_R11 0x58 +#define PTRACE_R12 0x60 +#define PTRACE_R13 0x68 +#define PTRACE_R14 0x70 +#define PTRACE_R15 0x78 +#define PTRACE_R16 0x80 +#define PTRACE_R17 0x88 +#define PTRACE_R18 0x90 +#define PTRACE_R19 0x98 +#define PTRACE_R20 0xa0 +#define PTRACE_R21 0xa8 +#define PTRACE_R22 0xb0 +#define PTRACE_R23 0xb8 +#define PTRACE_R24 0xc0 +#define PTRACE_R25 0xc8 +#define PTRACE_R26 0xd0 +#define PTRACE_R27 0xd8 +#define PTRACE_R28 0xe0 +#define PTRACE_R29 0xe8 +#define PTRACE_R30 0xf0 +#define PTRACE_SP 0xf8 +#define PTRACE_PC 0x100 +#define PTRACE_PSTATE 0x108 +#define PTRACE_TLS 0x110 // Copyright 2019 The gVisor Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/pkg/ring0/gen_offsets/BUILD b/pkg/ring0/gen_offsets/BUILD deleted file mode 100644 index 9ea8f9a4f..000000000 --- a/pkg/ring0/gen_offsets/BUILD +++ /dev/null @@ -1,41 +0,0 @@ -load("//tools:defs.bzl", "go_binary") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -package(licenses = ["notice"]) - -go_template_instance( - name = "defs_impl_arm64", - out = "defs_impl_arm64.go", - package = "main", - template = "//pkg/ring0:defs_arm64", -) - -go_template_instance( - name = "defs_impl_amd64", - out = "defs_impl_amd64.go", - package = "main", - template = "//pkg/ring0:defs_amd64", -) - -go_binary( - name = "gen_offsets", - srcs = [ - "defs_impl_amd64.go", - "defs_impl_arm64.go", - "main.go", - ], - # Use the libc malloc to avoid any extra dependencies. This is required to - # pass the sentry deps test. - system_malloc = True, - visibility = [ - "//pkg/ring0:__pkg__", - "//pkg/sentry/platform/kvm:__pkg__", - ], - deps = [ - "//pkg/cpuid", - "//pkg/hostarch", - "//pkg/ring0/pagetables", - "//pkg/sentry/arch", - "//pkg/sentry/arch/fpu", - ], -) diff --git a/pkg/ring0/gen_offsets/main.go b/pkg/ring0/gen_offsets/main.go deleted file mode 100644 index a4927da2f..000000000 --- a/pkg/ring0/gen_offsets/main.go +++ /dev/null @@ -1,24 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Binary gen_offsets is a helper for generating offset headers. -package main - -import ( - "os" -) - -func main() { - Emit(os.Stdout) -} diff --git a/pkg/ring0/offsets_amd64.go b/pkg/ring0/offsets_amd64.go deleted file mode 100644 index 75f6218b3..000000000 --- a/pkg/ring0/offsets_amd64.go +++ /dev/null @@ -1,101 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//go:build amd64 -// +build amd64 - -package ring0 - -import ( - "fmt" - "io" - "reflect" - - "gvisor.dev/gvisor/pkg/sentry/arch" -) - -// Emit prints architecture-specific offsets. -func Emit(w io.Writer) { - fmt.Fprintf(w, "// Automatically generated, do not edit.\n") - - c := &CPU{} - fmt.Fprintf(w, "\n// CPU offsets.\n") - fmt.Fprintf(w, "#define CPU_REGISTERS 0x%02x\n", reflect.ValueOf(&c.registers).Pointer()-reflect.ValueOf(c).Pointer()) - fmt.Fprintf(w, "#define CPU_ERROR_CODE 0x%02x\n", reflect.ValueOf(&c.errorCode).Pointer()-reflect.ValueOf(c).Pointer()) - fmt.Fprintf(w, "#define CPU_ERROR_TYPE 0x%02x\n", reflect.ValueOf(&c.errorType).Pointer()-reflect.ValueOf(c).Pointer()) - fmt.Fprintf(w, "#define CPU_ENTRY 0x%02x\n", reflect.ValueOf(&c.kernelEntry).Pointer()-reflect.ValueOf(c).Pointer()) - - e := &kernelEntry{} - fmt.Fprintf(w, "\n// CPU entry offsets.\n") - fmt.Fprintf(w, "#define ENTRY_SCRATCH0 0x%02x\n", reflect.ValueOf(&e.scratch0).Pointer()-reflect.ValueOf(e).Pointer()) - fmt.Fprintf(w, "#define ENTRY_STACK_TOP 0x%02x\n", reflect.ValueOf(&e.stackTop).Pointer()-reflect.ValueOf(e).Pointer()) - fmt.Fprintf(w, "#define ENTRY_CPU_SELF 0x%02x\n", reflect.ValueOf(&e.cpuSelf).Pointer()-reflect.ValueOf(e).Pointer()) - fmt.Fprintf(w, "#define ENTRY_KERNEL_CR3 0x%02x\n", reflect.ValueOf(&e.kernelCR3).Pointer()-reflect.ValueOf(e).Pointer()) - - fmt.Fprintf(w, "\n// Bits.\n") - fmt.Fprintf(w, "#define _RFLAGS_IF 0x%02x\n", _RFLAGS_IF) - fmt.Fprintf(w, "#define _RFLAGS_IOPL0 0x%02x\n", _RFLAGS_IOPL0) - fmt.Fprintf(w, "#define _KERNEL_FLAGS 0x%02x\n", KernelFlagsSet) - - fmt.Fprintf(w, "\n// Vectors.\n") - fmt.Fprintf(w, "#define DivideByZero 0x%02x\n", DivideByZero) - fmt.Fprintf(w, "#define Debug 0x%02x\n", Debug) - fmt.Fprintf(w, "#define NMI 0x%02x\n", NMI) - fmt.Fprintf(w, "#define Breakpoint 0x%02x\n", Breakpoint) - fmt.Fprintf(w, "#define Overflow 0x%02x\n", Overflow) - fmt.Fprintf(w, "#define BoundRangeExceeded 0x%02x\n", BoundRangeExceeded) - fmt.Fprintf(w, "#define InvalidOpcode 0x%02x\n", InvalidOpcode) - fmt.Fprintf(w, "#define DeviceNotAvailable 0x%02x\n", DeviceNotAvailable) - fmt.Fprintf(w, "#define DoubleFault 0x%02x\n", DoubleFault) - fmt.Fprintf(w, "#define CoprocessorSegmentOverrun 0x%02x\n", CoprocessorSegmentOverrun) - fmt.Fprintf(w, "#define InvalidTSS 0x%02x\n", InvalidTSS) - fmt.Fprintf(w, "#define SegmentNotPresent 0x%02x\n", SegmentNotPresent) - fmt.Fprintf(w, "#define StackSegmentFault 0x%02x\n", StackSegmentFault) - fmt.Fprintf(w, "#define GeneralProtectionFault 0x%02x\n", GeneralProtectionFault) - fmt.Fprintf(w, "#define PageFault 0x%02x\n", PageFault) - fmt.Fprintf(w, "#define X87FloatingPointException 0x%02x\n", X87FloatingPointException) - fmt.Fprintf(w, "#define AlignmentCheck 0x%02x\n", AlignmentCheck) - fmt.Fprintf(w, "#define MachineCheck 0x%02x\n", MachineCheck) - fmt.Fprintf(w, "#define SIMDFloatingPointException 0x%02x\n", SIMDFloatingPointException) - fmt.Fprintf(w, "#define VirtualizationException 0x%02x\n", VirtualizationException) - fmt.Fprintf(w, "#define SecurityException 0x%02x\n", SecurityException) - fmt.Fprintf(w, "#define SyscallInt80 0x%02x\n", SyscallInt80) - fmt.Fprintf(w, "#define Syscall 0x%02x\n", Syscall) - - p := &arch.Registers{} - fmt.Fprintf(w, "\n// Ptrace registers.\n") - fmt.Fprintf(w, "#define PTRACE_R15 0x%02x\n", reflect.ValueOf(&p.R15).Pointer()-reflect.ValueOf(p).Pointer()) - fmt.Fprintf(w, "#define PTRACE_R14 0x%02x\n", reflect.ValueOf(&p.R14).Pointer()-reflect.ValueOf(p).Pointer()) - fmt.Fprintf(w, "#define PTRACE_R13 0x%02x\n", reflect.ValueOf(&p.R13).Pointer()-reflect.ValueOf(p).Pointer()) - fmt.Fprintf(w, "#define PTRACE_R12 0x%02x\n", reflect.ValueOf(&p.R12).Pointer()-reflect.ValueOf(p).Pointer()) - fmt.Fprintf(w, "#define PTRACE_RBP 0x%02x\n", reflect.ValueOf(&p.Rbp).Pointer()-reflect.ValueOf(p).Pointer()) - fmt.Fprintf(w, "#define PTRACE_RBX 0x%02x\n", reflect.ValueOf(&p.Rbx).Pointer()-reflect.ValueOf(p).Pointer()) - fmt.Fprintf(w, "#define PTRACE_R11 0x%02x\n", reflect.ValueOf(&p.R11).Pointer()-reflect.ValueOf(p).Pointer()) - fmt.Fprintf(w, "#define PTRACE_R10 0x%02x\n", reflect.ValueOf(&p.R10).Pointer()-reflect.ValueOf(p).Pointer()) - fmt.Fprintf(w, "#define PTRACE_R9 0x%02x\n", reflect.ValueOf(&p.R9).Pointer()-reflect.ValueOf(p).Pointer()) - fmt.Fprintf(w, "#define PTRACE_R8 0x%02x\n", reflect.ValueOf(&p.R8).Pointer()-reflect.ValueOf(p).Pointer()) - fmt.Fprintf(w, "#define PTRACE_RAX 0x%02x\n", reflect.ValueOf(&p.Rax).Pointer()-reflect.ValueOf(p).Pointer()) - fmt.Fprintf(w, "#define PTRACE_RCX 0x%02x\n", reflect.ValueOf(&p.Rcx).Pointer()-reflect.ValueOf(p).Pointer()) - fmt.Fprintf(w, "#define PTRACE_RDX 0x%02x\n", reflect.ValueOf(&p.Rdx).Pointer()-reflect.ValueOf(p).Pointer()) - fmt.Fprintf(w, "#define PTRACE_RSI 0x%02x\n", reflect.ValueOf(&p.Rsi).Pointer()-reflect.ValueOf(p).Pointer()) - fmt.Fprintf(w, "#define PTRACE_RDI 0x%02x\n", reflect.ValueOf(&p.Rdi).Pointer()-reflect.ValueOf(p).Pointer()) - fmt.Fprintf(w, "#define PTRACE_ORIGRAX 0x%02x\n", reflect.ValueOf(&p.Orig_rax).Pointer()-reflect.ValueOf(p).Pointer()) - fmt.Fprintf(w, "#define PTRACE_RIP 0x%02x\n", reflect.ValueOf(&p.Rip).Pointer()-reflect.ValueOf(p).Pointer()) - fmt.Fprintf(w, "#define PTRACE_CS 0x%02x\n", reflect.ValueOf(&p.Cs).Pointer()-reflect.ValueOf(p).Pointer()) - fmt.Fprintf(w, "#define PTRACE_FLAGS 0x%02x\n", reflect.ValueOf(&p.Eflags).Pointer()-reflect.ValueOf(p).Pointer()) - fmt.Fprintf(w, "#define PTRACE_RSP 0x%02x\n", reflect.ValueOf(&p.Rsp).Pointer()-reflect.ValueOf(p).Pointer()) - fmt.Fprintf(w, "#define PTRACE_SS 0x%02x\n", reflect.ValueOf(&p.Ss).Pointer()-reflect.ValueOf(p).Pointer()) - fmt.Fprintf(w, "#define PTRACE_FS_BASE 0x%02x\n", reflect.ValueOf(&p.Fs_base).Pointer()-reflect.ValueOf(p).Pointer()) - fmt.Fprintf(w, "#define PTRACE_GS_BASE 0x%02x\n", reflect.ValueOf(&p.Gs_base).Pointer()-reflect.ValueOf(p).Pointer()) -} diff --git a/pkg/ring0/pagetables/BUILD b/pkg/ring0/pagetables/BUILD deleted file mode 100644 index f855f4d42..000000000 --- a/pkg/ring0/pagetables/BUILD +++ /dev/null @@ -1,88 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") -load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance") - -package(licenses = ["notice"]) - -[ - # These files are tagged with relevant build architectures. We can always - # build all the input files, which will be included only in the relevant - # architecture builds. - go_template( - name = "generic_walker_%s" % arch, - srcs = [ - "walker_generic.go", - "walker_%s.go" % arch, - ], - opt_types = [ - "Visitor", - ], - visibility = [":__pkg__"], - ) - for arch in ("amd64", "arm64") -] - -[ - # See above. - go_template_instance( - name = "walker_%s_%s" % (op, arch), - out = "walker_%s_%s.go" % (op, arch), - package = "pagetables", - prefix = op, - template = ":generic_walker_%s" % arch, - types = { - "Visitor": "%sVisitor" % op, - }, - ) - for op in ("map", "unmap", "lookup", "empty", "check") - for arch in ("amd64", "arm64") -] - -go_library( - name = "pagetables", - srcs = [ - "allocator.go", - "allocator_unsafe.go", - "pagetables.go", - "pagetables_aarch64.go", - "pagetables_amd64.go", - "pagetables_arm64.go", - "pagetables_x86.go", - "pcids.go", - "pcids_aarch64.go", - "pcids_aarch64.s", - "pcids_x86.go", - "walker_amd64.go", - "walker_arm64.go", - "walker_generic.go", - ":walker_empty_amd64", - ":walker_empty_arm64", - ":walker_lookup_amd64", - ":walker_lookup_arm64", - ":walker_map_amd64", - ":walker_map_arm64", - ":walker_unmap_amd64", - ":walker_unmap_arm64", - ], - visibility = [ - "//pkg/ring0:__subpackages__", - "//pkg/sentry/platform/kvm:__subpackages__", - ], - deps = [ - "//pkg/hostarch", - "//pkg/sync", - ], -) - -go_test( - name = "pagetables_test", - size = "small", - srcs = [ - "pagetables_amd64_test.go", - "pagetables_arm64_test.go", - "pagetables_test.go", - ":walker_check_amd64", - ":walker_check_arm64", - ], - library = ":pagetables", - deps = ["//pkg/hostarch"], -) diff --git a/pkg/ring0/pagetables/pagetables_aarch64_state_autogen.go b/pkg/ring0/pagetables/pagetables_aarch64_state_autogen.go new file mode 100644 index 000000000..e395f6e5e --- /dev/null +++ b/pkg/ring0/pagetables/pagetables_aarch64_state_autogen.go @@ -0,0 +1,6 @@ +// automatically generated by stateify. + +//go:build arm64 && arm64 +// +build arm64,arm64 + +package pagetables diff --git a/pkg/ring0/pagetables/pagetables_amd64_state_autogen.go b/pkg/ring0/pagetables/pagetables_amd64_state_autogen.go new file mode 100644 index 000000000..50f7f3fb1 --- /dev/null +++ b/pkg/ring0/pagetables/pagetables_amd64_state_autogen.go @@ -0,0 +1,6 @@ +// automatically generated by stateify. + +//go:build amd64 +// +build amd64 + +package pagetables diff --git a/pkg/ring0/pagetables/pagetables_amd64_test.go b/pkg/ring0/pagetables/pagetables_amd64_test.go deleted file mode 100644 index c27b3b10a..000000000 --- a/pkg/ring0/pagetables/pagetables_amd64_test.go +++ /dev/null @@ -1,76 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//go:build amd64 -// +build amd64 - -package pagetables - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/hostarch" -) - -func Test2MAnd4K(t *testing.T) { - pt := New(NewRuntimeAllocator()) - - // Map a small page and a huge page. - pt.Map(0x400000, pteSize, MapOpts{AccessType: hostarch.ReadWrite}, pteSize*42) - pt.Map(0x00007f0000000000, pmdSize, MapOpts{AccessType: hostarch.Read}, pmdSize*47) - - checkMappings(t, pt, []mapping{ - {0x400000, pteSize, pteSize * 42, MapOpts{AccessType: hostarch.ReadWrite}}, - {0x00007f0000000000, pmdSize, pmdSize * 47, MapOpts{AccessType: hostarch.Read}}, - }) -} - -func Test1GAnd4K(t *testing.T) { - pt := New(NewRuntimeAllocator()) - - // Map a small page and a super page. - pt.Map(0x400000, pteSize, MapOpts{AccessType: hostarch.ReadWrite}, pteSize*42) - pt.Map(0x00007f0000000000, pudSize, MapOpts{AccessType: hostarch.Read}, pudSize*47) - - checkMappings(t, pt, []mapping{ - {0x400000, pteSize, pteSize * 42, MapOpts{AccessType: hostarch.ReadWrite}}, - {0x00007f0000000000, pudSize, pudSize * 47, MapOpts{AccessType: hostarch.Read}}, - }) -} - -func TestSplit1GPage(t *testing.T) { - pt := New(NewRuntimeAllocator()) - - // Map a super page and knock out the middle. - pt.Map(0x00007f0000000000, pudSize, MapOpts{AccessType: hostarch.Read}, pudSize*42) - pt.Unmap(hostarch.Addr(0x00007f0000000000+pteSize), pudSize-(2*pteSize)) - - checkMappings(t, pt, []mapping{ - {0x00007f0000000000, pteSize, pudSize * 42, MapOpts{AccessType: hostarch.Read}}, - {0x00007f0000000000 + pudSize - pteSize, pteSize, pudSize*42 + pudSize - pteSize, MapOpts{AccessType: hostarch.Read}}, - }) -} - -func TestSplit2MPage(t *testing.T) { - pt := New(NewRuntimeAllocator()) - - // Map a huge page and knock out the middle. - pt.Map(0x00007f0000000000, pmdSize, MapOpts{AccessType: hostarch.Read}, pmdSize*42) - pt.Unmap(hostarch.Addr(0x00007f0000000000+pteSize), pmdSize-(2*pteSize)) - - checkMappings(t, pt, []mapping{ - {0x00007f0000000000, pteSize, pmdSize * 42, MapOpts{AccessType: hostarch.Read}}, - {0x00007f0000000000 + pmdSize - pteSize, pteSize, pmdSize*42 + pmdSize - pteSize, MapOpts{AccessType: hostarch.Read}}, - }) -} diff --git a/pkg/ring0/pagetables/pagetables_arm64_state_autogen.go b/pkg/ring0/pagetables/pagetables_arm64_state_autogen.go new file mode 100644 index 000000000..83bdb8f3c --- /dev/null +++ b/pkg/ring0/pagetables/pagetables_arm64_state_autogen.go @@ -0,0 +1,6 @@ +// automatically generated by stateify. + +//go:build arm64 +// +build arm64 + +package pagetables diff --git a/pkg/ring0/pagetables/pagetables_arm64_test.go b/pkg/ring0/pagetables/pagetables_arm64_test.go deleted file mode 100644 index 1c919ec7d..000000000 --- a/pkg/ring0/pagetables/pagetables_arm64_test.go +++ /dev/null @@ -1,81 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//go:build arm64 -// +build arm64 - -package pagetables - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/hostarch" -) - -func Test2MAnd4K(t *testing.T) { - pt := New(NewRuntimeAllocator()) - - // Map a small page and a huge page. - pt.Map(0x400000, pteSize, MapOpts{AccessType: hostarch.ReadWrite, User: true}, pteSize*42) - pt.Map(0x0000ff0000000000, pmdSize, MapOpts{AccessType: hostarch.Read, User: true}, pmdSize*47) - - pt.Map(0xffff000000400000, pteSize, MapOpts{AccessType: hostarch.ReadWrite, User: false}, pteSize*42) - pt.Map(0xffffff0000000000, pmdSize, MapOpts{AccessType: hostarch.Read, User: false}, pmdSize*47) - - checkMappings(t, pt, []mapping{ - {0x400000, pteSize, pteSize * 42, MapOpts{AccessType: hostarch.ReadWrite, User: true}}, - {0x0000ff0000000000, pmdSize, pmdSize * 47, MapOpts{AccessType: hostarch.Read, User: true}}, - {0xffff000000400000, pteSize, pteSize * 42, MapOpts{AccessType: hostarch.ReadWrite, User: false}}, - {0xffffff0000000000, pmdSize, pmdSize * 47, MapOpts{AccessType: hostarch.Read, User: false}}, - }) -} - -func Test1GAnd4K(t *testing.T) { - pt := New(NewRuntimeAllocator()) - - // Map a small page and a super page. - pt.Map(0x400000, pteSize, MapOpts{AccessType: hostarch.ReadWrite, User: true}, pteSize*42) - pt.Map(0x0000ff0000000000, pudSize, MapOpts{AccessType: hostarch.Read, User: true}, pudSize*47) - - checkMappings(t, pt, []mapping{ - {0x400000, pteSize, pteSize * 42, MapOpts{AccessType: hostarch.ReadWrite, User: true}}, - {0x0000ff0000000000, pudSize, pudSize * 47, MapOpts{AccessType: hostarch.Read, User: true}}, - }) -} - -func TestSplit1GPage(t *testing.T) { - pt := New(NewRuntimeAllocator()) - - // Map a super page and knock out the middle. - pt.Map(0x0000ff0000000000, pudSize, MapOpts{AccessType: hostarch.Read, User: true}, pudSize*42) - pt.Unmap(hostarch.Addr(0x0000ff0000000000+pteSize), pudSize-(2*pteSize)) - - checkMappings(t, pt, []mapping{ - {0x0000ff0000000000, pteSize, pudSize * 42, MapOpts{AccessType: hostarch.Read, User: true}}, - {0x0000ff0000000000 + pudSize - pteSize, pteSize, pudSize*42 + pudSize - pteSize, MapOpts{AccessType: hostarch.Read, User: true}}, - }) -} - -func TestSplit2MPage(t *testing.T) { - pt := New(NewRuntimeAllocator()) - - // Map a huge page and knock out the middle. - pt.Map(0x0000ff0000000000, pmdSize, MapOpts{AccessType: hostarch.Read, User: true}, pmdSize*42) - pt.Unmap(hostarch.Addr(0x0000ff0000000000+pteSize), pmdSize-(2*pteSize)) - - checkMappings(t, pt, []mapping{ - {0x0000ff0000000000, pteSize, pmdSize * 42, MapOpts{AccessType: hostarch.Read, User: true}}, - {0x0000ff0000000000 + pmdSize - pteSize, pteSize, pmdSize*42 + pmdSize - pteSize, MapOpts{AccessType: hostarch.Read, User: true}}, - }) -} diff --git a/pkg/ring0/pagetables/pagetables_state_autogen.go b/pkg/ring0/pagetables/pagetables_state_autogen.go new file mode 100644 index 000000000..4c4540603 --- /dev/null +++ b/pkg/ring0/pagetables/pagetables_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package pagetables diff --git a/pkg/ring0/pagetables/pagetables_test.go b/pkg/ring0/pagetables/pagetables_test.go deleted file mode 100644 index df93dcb6a..000000000 --- a/pkg/ring0/pagetables/pagetables_test.go +++ /dev/null @@ -1,156 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package pagetables - -import ( - "gvisor.dev/gvisor/pkg/hostarch" - "testing" -) - -type mapping struct { - start uintptr - length uintptr - addr uintptr - opts MapOpts -} - -type checkVisitor struct { - expected []mapping // Input. - current int // Temporary. - found []mapping // Output. - failed string // Output. -} - -func (v *checkVisitor) visit(start uintptr, pte *PTE, align uintptr) bool { - v.found = append(v.found, mapping{ - start: start, - length: align + 1, - addr: pte.Address(), - opts: pte.Opts(), - }) - if v.failed != "" { - // Don't keep looking for errors. - return false - } - - if v.current >= len(v.expected) { - v.failed = "more mappings than expected" - } else if v.expected[v.current].start != start { - v.failed = "start didn't match expected" - } else if v.expected[v.current].length != (align + 1) { - v.failed = "end didn't match expected" - } else if v.expected[v.current].addr != pte.Address() { - v.failed = "address didn't match expected" - } else if v.expected[v.current].opts != pte.Opts() { - v.failed = "opts didn't match" - } - v.current++ - return true -} - -func (*checkVisitor) requiresAlloc() bool { return false } - -func (*checkVisitor) requiresSplit() bool { return false } - -func checkMappings(t *testing.T, pt *PageTables, m []mapping) { - // Iterate over all the mappings. - w := checkWalker{ - pageTables: pt, - visitor: checkVisitor{ - expected: m, - }, - } - w.iterateRange(0, ^uintptr(0)) - - // Were we expected additional mappings? - if w.visitor.failed == "" && w.visitor.current != len(w.visitor.expected) { - w.visitor.failed = "insufficient mappings found" - } - - // Emit a meaningful error message on failure. - if w.visitor.failed != "" { - t.Errorf("%s; got %#v, wanted %#v", w.visitor.failed, w.visitor.found, w.visitor.expected) - } -} - -func TestUnmap(t *testing.T) { - pt := New(NewRuntimeAllocator()) - - // Map and unmap one entry. - pt.Map(0x400000, pteSize, MapOpts{AccessType: hostarch.ReadWrite}, pteSize*42) - pt.Unmap(0x400000, pteSize) - - checkMappings(t, pt, nil) -} - -func TestReadOnly(t *testing.T) { - pt := New(NewRuntimeAllocator()) - - // Map one entry. - pt.Map(0x400000, pteSize, MapOpts{AccessType: hostarch.Read}, pteSize*42) - - checkMappings(t, pt, []mapping{ - {0x400000, pteSize, pteSize * 42, MapOpts{AccessType: hostarch.Read}}, - }) -} - -func TestReadWrite(t *testing.T) { - pt := New(NewRuntimeAllocator()) - - // Map one entry. - pt.Map(0x400000, pteSize, MapOpts{AccessType: hostarch.ReadWrite}, pteSize*42) - - checkMappings(t, pt, []mapping{ - {0x400000, pteSize, pteSize * 42, MapOpts{AccessType: hostarch.ReadWrite}}, - }) -} - -func TestSerialEntries(t *testing.T) { - pt := New(NewRuntimeAllocator()) - - // Map two sequential entries. - pt.Map(0x400000, pteSize, MapOpts{AccessType: hostarch.ReadWrite}, pteSize*42) - pt.Map(0x401000, pteSize, MapOpts{AccessType: hostarch.ReadWrite}, pteSize*47) - - checkMappings(t, pt, []mapping{ - {0x400000, pteSize, pteSize * 42, MapOpts{AccessType: hostarch.ReadWrite}}, - {0x401000, pteSize, pteSize * 47, MapOpts{AccessType: hostarch.ReadWrite}}, - }) -} - -func TestSpanningEntries(t *testing.T) { - pt := New(NewRuntimeAllocator()) - - // Span a pgd with two pages. - pt.Map(0x00007efffffff000, 2*pteSize, MapOpts{AccessType: hostarch.Read}, pteSize*42) - - checkMappings(t, pt, []mapping{ - {0x00007efffffff000, pteSize, pteSize * 42, MapOpts{AccessType: hostarch.Read}}, - {0x00007f0000000000, pteSize, pteSize * 43, MapOpts{AccessType: hostarch.Read}}, - }) -} - -func TestSparseEntries(t *testing.T) { - pt := New(NewRuntimeAllocator()) - - // Map two entries in different pgds. - pt.Map(0x400000, pteSize, MapOpts{AccessType: hostarch.ReadWrite}, pteSize*42) - pt.Map(0x00007f0000000000, pteSize, MapOpts{AccessType: hostarch.Read}, pteSize*47) - - checkMappings(t, pt, []mapping{ - {0x400000, pteSize, pteSize * 42, MapOpts{AccessType: hostarch.ReadWrite}}, - {0x00007f0000000000, pteSize, pteSize * 47, MapOpts{AccessType: hostarch.Read}}, - }) -} diff --git a/pkg/ring0/pagetables/pagetables_unsafe_state_autogen.go b/pkg/ring0/pagetables/pagetables_unsafe_state_autogen.go new file mode 100644 index 000000000..4c4540603 --- /dev/null +++ b/pkg/ring0/pagetables/pagetables_unsafe_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package pagetables diff --git a/pkg/ring0/pagetables/pagetables_x86_state_autogen.go b/pkg/ring0/pagetables/pagetables_x86_state_autogen.go new file mode 100644 index 000000000..36d1fad72 --- /dev/null +++ b/pkg/ring0/pagetables/pagetables_x86_state_autogen.go @@ -0,0 +1,7 @@ +// automatically generated by stateify. + +//go:build (386 || amd64) && (i386 || amd64) +// +build 386 amd64 +// +build i386 amd64 + +package pagetables diff --git a/pkg/ring0/pagetables/walker_empty_amd64.go b/pkg/ring0/pagetables/walker_empty_amd64.go new file mode 100644 index 000000000..5a6ac02eb --- /dev/null +++ b/pkg/ring0/pagetables/walker_empty_amd64.go @@ -0,0 +1,266 @@ +//go:build amd64 +// +build amd64 + +package pagetables + +// iterateRangeCanonical walks a canonical range. +// +//go:nosplit +func (w *emptyWalker) iterateRangeCanonical(start, end uintptr) bool { + for pgdIndex := uint16((start & pgdMask) >> pgdShift); start < end && pgdIndex < entriesPerPage; pgdIndex++ { + var ( + pgdEntry = &w.pageTables.root[pgdIndex] + pudEntries *PTEs + ) + if !pgdEntry.Valid() { + if !w.visitor.requiresAlloc() { + + start = emptynext(start, pgdSize) + continue + } + + pudEntries = w.pageTables.Allocator.NewPTEs() + pgdEntry.setPageTable(w.pageTables, pudEntries) + } else { + pudEntries = w.pageTables.Allocator.LookupPTEs(pgdEntry.Address()) + } + + clearPUDEntries := uint16(0) + + for pudIndex := uint16((start & pudMask) >> pudShift); start < end && pudIndex < entriesPerPage; pudIndex++ { + var ( + pudEntry = &pudEntries[pudIndex] + pmdEntries *PTEs + ) + if !pudEntry.Valid() { + if !w.visitor.requiresAlloc() { + + clearPUDEntries++ + start = emptynext(start, pudSize) + continue + } + + if start&(pudSize-1) == 0 && end-start >= pudSize { + pudEntry.SetSuper() + if !w.visitor.visit(uintptr(start&^(pudSize-1)), pudEntry, pudSize-1) { + return false + } + if pudEntry.Valid() { + start = emptynext(start, pudSize) + continue + } + } + + pmdEntries = w.pageTables.Allocator.NewPTEs() + pudEntry.setPageTable(w.pageTables, pmdEntries) + + } else if pudEntry.IsSuper() { + + if w.visitor.requiresSplit() && (start&(pudSize-1) != 0 || end < emptynext(start, pudSize)) { + + pmdEntries = w.pageTables.Allocator.NewPTEs() + for index := uint16(0); index < entriesPerPage; index++ { + pmdEntries[index].SetSuper() + pmdEntries[index].Set( + pudEntry.Address()+(pmdSize*uintptr(index)), + pudEntry.Opts()) + } + pudEntry.setPageTable(w.pageTables, pmdEntries) + } else { + + if !w.visitor.visit(uintptr(start&^(pudSize-1)), pudEntry, pudSize-1) { + return false + } + + if !pudEntry.Valid() { + clearPUDEntries++ + } + + start = emptynext(start, pudSize) + continue + } + } else { + pmdEntries = w.pageTables.Allocator.LookupPTEs(pudEntry.Address()) + } + + clearPMDEntries := uint16(0) + + for pmdIndex := uint16((start & pmdMask) >> pmdShift); start < end && pmdIndex < entriesPerPage; pmdIndex++ { + var ( + pmdEntry = &pmdEntries[pmdIndex] + pteEntries *PTEs + ) + if !pmdEntry.Valid() { + if !w.visitor.requiresAlloc() { + + clearPMDEntries++ + start = emptynext(start, pmdSize) + continue + } + + if start&(pmdSize-1) == 0 && end-start >= pmdSize { + pmdEntry.SetSuper() + if !w.visitor.visit(uintptr(start&^(pmdSize-1)), pmdEntry, pmdSize-1) { + return false + } + if pmdEntry.Valid() { + start = emptynext(start, pmdSize) + continue + } + } + + pteEntries = w.pageTables.Allocator.NewPTEs() + pmdEntry.setPageTable(w.pageTables, pteEntries) + + } else if pmdEntry.IsSuper() { + + if w.visitor.requiresSplit() && (start&(pmdSize-1) != 0 || end < emptynext(start, pmdSize)) { + + pteEntries = w.pageTables.Allocator.NewPTEs() + for index := uint16(0); index < entriesPerPage; index++ { + pteEntries[index].Set( + pmdEntry.Address()+(pteSize*uintptr(index)), + pmdEntry.Opts()) + } + pmdEntry.setPageTable(w.pageTables, pteEntries) + } else { + + if !w.visitor.visit(uintptr(start&^(pmdSize-1)), pmdEntry, pmdSize-1) { + return false + } + + if !pmdEntry.Valid() { + clearPMDEntries++ + } + + start = emptynext(start, pmdSize) + continue + } + } else { + pteEntries = w.pageTables.Allocator.LookupPTEs(pmdEntry.Address()) + } + + clearPTEEntries := uint16(0) + + for pteIndex := uint16((start & pteMask) >> pteShift); start < end && pteIndex < entriesPerPage; pteIndex++ { + var ( + pteEntry = &pteEntries[pteIndex] + ) + if !pteEntry.Valid() && !w.visitor.requiresAlloc() { + clearPTEEntries++ + start += pteSize + continue + } + + if !w.visitor.visit(uintptr(start&^(pteSize-1)), pteEntry, pteSize-1) { + return false + } + if !pteEntry.Valid() && !w.visitor.requiresAlloc() { + clearPTEEntries++ + } + + start += pteSize + continue + } + + if clearPTEEntries == entriesPerPage { + pmdEntry.Clear() + w.pageTables.Allocator.FreePTEs(pteEntries) + clearPMDEntries++ + } + } + + if clearPMDEntries == entriesPerPage { + pudEntry.Clear() + w.pageTables.Allocator.FreePTEs(pmdEntries) + clearPUDEntries++ + } + } + + if clearPUDEntries == entriesPerPage { + pgdEntry.Clear() + w.pageTables.Allocator.FreePTEs(pudEntries) + } + } + return true +} + +// Walker walks page tables. +type emptyWalker struct { + // pageTables are the tables to walk. + pageTables *PageTables + + // Visitor is the set of arguments. + visitor emptyVisitor +} + +// iterateRange iterates over all appropriate levels of page tables for the given range. +// +// If requiresAlloc is true, then Set _must_ be called on all given PTEs. The +// exception is super pages. If a valid super page (huge or jumbo) cannot be +// installed, then the walk will continue to individual entries. +// +// This algorithm will attempt to maximize the use of super/sect pages whenever +// possible. Whether a super page is provided will be clear through the range +// provided in the callback. +// +// Note that if requiresAlloc is true, then no gaps will be present. However, +// if alloc is not set, then the iteration will likely be full of gaps. +// +// Note that this function should generally be avoided in favor of Map, Unmap, +// etc. when not necessary. +// +// Precondition: start must be page-aligned. +// Precondition: start must be less than end. +// Precondition: If requiresAlloc is true, then start and end should not span +// non-canonical ranges. If they do, a panic will result. +// +//go:nosplit +func (w *emptyWalker) iterateRange(start, end uintptr) { + if start%pteSize != 0 { + panic("unaligned start") + } + if end < start { + panic("start > end") + } + if start < lowerTop { + if end <= lowerTop { + w.iterateRangeCanonical(start, end) + } else if end > lowerTop && end <= upperBottom { + if w.visitor.requiresAlloc() { + panic("alloc spans non-canonical range") + } + w.iterateRangeCanonical(start, lowerTop) + } else { + if w.visitor.requiresAlloc() { + panic("alloc spans non-canonical range") + } + if !w.iterateRangeCanonical(start, lowerTop) { + return + } + w.iterateRangeCanonical(upperBottom, end) + } + } else if start < upperBottom { + if end <= upperBottom { + if w.visitor.requiresAlloc() { + panic("alloc spans non-canonical range") + } + } else { + if w.visitor.requiresAlloc() { + panic("alloc spans non-canonical range") + } + w.iterateRangeCanonical(upperBottom, end) + } + } else { + w.iterateRangeCanonical(start, end) + } +} + +// next returns the next address quantized by the given size. +// +//go:nosplit +func emptynext(start uintptr, size uintptr) uintptr { + start &= ^(size - 1) + start += size + return start +} diff --git a/pkg/ring0/pagetables/walker_empty_arm64.go b/pkg/ring0/pagetables/walker_empty_arm64.go new file mode 100644 index 000000000..3ab92f0ec --- /dev/null +++ b/pkg/ring0/pagetables/walker_empty_arm64.go @@ -0,0 +1,276 @@ +//go:build arm64 +// +build arm64 + +package pagetables + +// iterateRangeCanonical walks a canonical range. +// +//go:nosplit +func (w *emptyWalker) iterateRangeCanonical(start, end uintptr) bool { + pgdEntryIndex := w.pageTables.root + if start >= upperBottom { + pgdEntryIndex = w.pageTables.archPageTables.root + } + + for pgdIndex := (uint16((start & pgdMask) >> pgdShift)); start < end && pgdIndex < entriesPerPage; pgdIndex++ { + var ( + pgdEntry = &pgdEntryIndex[pgdIndex] + pudEntries *PTEs + ) + if !pgdEntry.Valid() { + if !w.visitor.requiresAlloc() { + + start = emptynext(start, pgdSize) + continue + } + + pudEntries = w.pageTables.Allocator.NewPTEs() + pgdEntry.setPageTable(w.pageTables, pudEntries) + } else { + pudEntries = w.pageTables.Allocator.LookupPTEs(pgdEntry.Address()) + } + + clearPUDEntries := uint16(0) + + for pudIndex := uint16((start & pudMask) >> pudShift); start < end && pudIndex < entriesPerPage; pudIndex++ { + var ( + pudEntry = &pudEntries[pudIndex] + pmdEntries *PTEs + ) + if !pudEntry.Valid() { + if !w.visitor.requiresAlloc() { + + clearPUDEntries++ + start = emptynext(start, pudSize) + continue + } + + if start&(pudSize-1) == 0 && end-start >= pudSize { + pudEntry.SetSect() + if !w.visitor.visit(uintptr(start), pudEntry, pudSize-1) { + return false + } + if pudEntry.Valid() { + start = emptynext(start, pudSize) + continue + } + } + + pmdEntries = w.pageTables.Allocator.NewPTEs() + pudEntry.setPageTable(w.pageTables, pmdEntries) + + } else if pudEntry.IsSect() { + + if w.visitor.requiresSplit() && (start&(pudSize-1) != 0 || end < emptynext(start, pudSize)) { + + pmdEntries = w.pageTables.Allocator.NewPTEs() + for index := uint16(0); index < entriesPerPage; index++ { + pmdEntries[index].SetSect() + pmdEntries[index].Set( + pudEntry.Address()+(pmdSize*uintptr(index)), + pudEntry.Opts()) + } + pudEntry.setPageTable(w.pageTables, pmdEntries) + } else { + + if !w.visitor.visit(uintptr(start), pudEntry, pudSize-1) { + return false + } + + if !pudEntry.Valid() { + clearPUDEntries++ + } + + start = emptynext(start, pudSize) + continue + } + + } else { + pmdEntries = w.pageTables.Allocator.LookupPTEs(pudEntry.Address()) + } + + clearPMDEntries := uint16(0) + + for pmdIndex := uint16((start & pmdMask) >> pmdShift); start < end && pmdIndex < entriesPerPage; pmdIndex++ { + var ( + pmdEntry = &pmdEntries[pmdIndex] + pteEntries *PTEs + ) + if !pmdEntry.Valid() { + if !w.visitor.requiresAlloc() { + + clearPMDEntries++ + start = emptynext(start, pmdSize) + continue + } + + if start&(pmdSize-1) == 0 && end-start >= pmdSize { + pmdEntry.SetSect() + if !w.visitor.visit(uintptr(start), pmdEntry, pmdSize-1) { + return false + } + if pmdEntry.Valid() { + start = emptynext(start, pmdSize) + continue + } + } + + pteEntries = w.pageTables.Allocator.NewPTEs() + pmdEntry.setPageTable(w.pageTables, pteEntries) + + } else if pmdEntry.IsSect() { + + if w.visitor.requiresSplit() && (start&(pmdSize-1) != 0 || end < emptynext(start, pmdSize)) { + + pteEntries = w.pageTables.Allocator.NewPTEs() + for index := uint16(0); index < entriesPerPage; index++ { + pteEntries[index].Set( + pmdEntry.Address()+(pteSize*uintptr(index)), + pmdEntry.Opts()) + } + pmdEntry.setPageTable(w.pageTables, pteEntries) + } else { + + if !w.visitor.visit(uintptr(start), pmdEntry, pmdSize-1) { + return false + } + + if !pmdEntry.Valid() { + clearPMDEntries++ + } + + start = emptynext(start, pmdSize) + continue + } + + } else { + pteEntries = w.pageTables.Allocator.LookupPTEs(pmdEntry.Address()) + } + + clearPTEEntries := uint16(0) + + for pteIndex := uint16((start & pteMask) >> pteShift); start < end && pteIndex < entriesPerPage; pteIndex++ { + var ( + pteEntry = &pteEntries[pteIndex] + ) + if !pteEntry.Valid() && !w.visitor.requiresAlloc() { + clearPTEEntries++ + start += pteSize + continue + } + + if !w.visitor.visit(uintptr(start), pteEntry, pteSize-1) { + return false + } + if !pteEntry.Valid() { + if w.visitor.requiresAlloc() { + panic("PTE not set after iteration with requiresAlloc!") + } + clearPTEEntries++ + } + + start += pteSize + continue + } + + if clearPTEEntries == entriesPerPage { + pmdEntry.Clear() + w.pageTables.Allocator.FreePTEs(pteEntries) + clearPMDEntries++ + } + } + + if clearPMDEntries == entriesPerPage { + pudEntry.Clear() + w.pageTables.Allocator.FreePTEs(pmdEntries) + clearPUDEntries++ + } + } + + if clearPUDEntries == entriesPerPage { + pgdEntry.Clear() + w.pageTables.Allocator.FreePTEs(pudEntries) + } + } + return true +} + +// Walker walks page tables. +type emptyWalker struct { + // pageTables are the tables to walk. + pageTables *PageTables + + // Visitor is the set of arguments. + visitor emptyVisitor +} + +// iterateRange iterates over all appropriate levels of page tables for the given range. +// +// If requiresAlloc is true, then Set _must_ be called on all given PTEs. The +// exception is super pages. If a valid super page (huge or jumbo) cannot be +// installed, then the walk will continue to individual entries. +// +// This algorithm will attempt to maximize the use of super/sect pages whenever +// possible. Whether a super page is provided will be clear through the range +// provided in the callback. +// +// Note that if requiresAlloc is true, then no gaps will be present. However, +// if alloc is not set, then the iteration will likely be full of gaps. +// +// Note that this function should generally be avoided in favor of Map, Unmap, +// etc. when not necessary. +// +// Precondition: start must be page-aligned. +// Precondition: start must be less than end. +// Precondition: If requiresAlloc is true, then start and end should not span +// non-canonical ranges. If they do, a panic will result. +// +//go:nosplit +func (w *emptyWalker) iterateRange(start, end uintptr) { + if start%pteSize != 0 { + panic("unaligned start") + } + if end < start { + panic("start > end") + } + if start < lowerTop { + if end <= lowerTop { + w.iterateRangeCanonical(start, end) + } else if end > lowerTop && end <= upperBottom { + if w.visitor.requiresAlloc() { + panic("alloc spans non-canonical range") + } + w.iterateRangeCanonical(start, lowerTop) + } else { + if w.visitor.requiresAlloc() { + panic("alloc spans non-canonical range") + } + if !w.iterateRangeCanonical(start, lowerTop) { + return + } + w.iterateRangeCanonical(upperBottom, end) + } + } else if start < upperBottom { + if end <= upperBottom { + if w.visitor.requiresAlloc() { + panic("alloc spans non-canonical range") + } + } else { + if w.visitor.requiresAlloc() { + panic("alloc spans non-canonical range") + } + w.iterateRangeCanonical(upperBottom, end) + } + } else { + w.iterateRangeCanonical(start, end) + } +} + +// next returns the next address quantized by the given size. +// +//go:nosplit +func emptynext(start uintptr, size uintptr) uintptr { + start &= ^(size - 1) + start += size + return start +} diff --git a/pkg/ring0/pagetables/walker_lookup_amd64.go b/pkg/ring0/pagetables/walker_lookup_amd64.go new file mode 100644 index 000000000..ad0cb8bb4 --- /dev/null +++ b/pkg/ring0/pagetables/walker_lookup_amd64.go @@ -0,0 +1,266 @@ +//go:build amd64 +// +build amd64 + +package pagetables + +// iterateRangeCanonical walks a canonical range. +// +//go:nosplit +func (w *lookupWalker) iterateRangeCanonical(start, end uintptr) bool { + for pgdIndex := uint16((start & pgdMask) >> pgdShift); start < end && pgdIndex < entriesPerPage; pgdIndex++ { + var ( + pgdEntry = &w.pageTables.root[pgdIndex] + pudEntries *PTEs + ) + if !pgdEntry.Valid() { + if !w.visitor.requiresAlloc() { + + start = lookupnext(start, pgdSize) + continue + } + + pudEntries = w.pageTables.Allocator.NewPTEs() + pgdEntry.setPageTable(w.pageTables, pudEntries) + } else { + pudEntries = w.pageTables.Allocator.LookupPTEs(pgdEntry.Address()) + } + + clearPUDEntries := uint16(0) + + for pudIndex := uint16((start & pudMask) >> pudShift); start < end && pudIndex < entriesPerPage; pudIndex++ { + var ( + pudEntry = &pudEntries[pudIndex] + pmdEntries *PTEs + ) + if !pudEntry.Valid() { + if !w.visitor.requiresAlloc() { + + clearPUDEntries++ + start = lookupnext(start, pudSize) + continue + } + + if start&(pudSize-1) == 0 && end-start >= pudSize { + pudEntry.SetSuper() + if !w.visitor.visit(uintptr(start&^(pudSize-1)), pudEntry, pudSize-1) { + return false + } + if pudEntry.Valid() { + start = lookupnext(start, pudSize) + continue + } + } + + pmdEntries = w.pageTables.Allocator.NewPTEs() + pudEntry.setPageTable(w.pageTables, pmdEntries) + + } else if pudEntry.IsSuper() { + + if w.visitor.requiresSplit() && (start&(pudSize-1) != 0 || end < lookupnext(start, pudSize)) { + + pmdEntries = w.pageTables.Allocator.NewPTEs() + for index := uint16(0); index < entriesPerPage; index++ { + pmdEntries[index].SetSuper() + pmdEntries[index].Set( + pudEntry.Address()+(pmdSize*uintptr(index)), + pudEntry.Opts()) + } + pudEntry.setPageTable(w.pageTables, pmdEntries) + } else { + + if !w.visitor.visit(uintptr(start&^(pudSize-1)), pudEntry, pudSize-1) { + return false + } + + if !pudEntry.Valid() { + clearPUDEntries++ + } + + start = lookupnext(start, pudSize) + continue + } + } else { + pmdEntries = w.pageTables.Allocator.LookupPTEs(pudEntry.Address()) + } + + clearPMDEntries := uint16(0) + + for pmdIndex := uint16((start & pmdMask) >> pmdShift); start < end && pmdIndex < entriesPerPage; pmdIndex++ { + var ( + pmdEntry = &pmdEntries[pmdIndex] + pteEntries *PTEs + ) + if !pmdEntry.Valid() { + if !w.visitor.requiresAlloc() { + + clearPMDEntries++ + start = lookupnext(start, pmdSize) + continue + } + + if start&(pmdSize-1) == 0 && end-start >= pmdSize { + pmdEntry.SetSuper() + if !w.visitor.visit(uintptr(start&^(pmdSize-1)), pmdEntry, pmdSize-1) { + return false + } + if pmdEntry.Valid() { + start = lookupnext(start, pmdSize) + continue + } + } + + pteEntries = w.pageTables.Allocator.NewPTEs() + pmdEntry.setPageTable(w.pageTables, pteEntries) + + } else if pmdEntry.IsSuper() { + + if w.visitor.requiresSplit() && (start&(pmdSize-1) != 0 || end < lookupnext(start, pmdSize)) { + + pteEntries = w.pageTables.Allocator.NewPTEs() + for index := uint16(0); index < entriesPerPage; index++ { + pteEntries[index].Set( + pmdEntry.Address()+(pteSize*uintptr(index)), + pmdEntry.Opts()) + } + pmdEntry.setPageTable(w.pageTables, pteEntries) + } else { + + if !w.visitor.visit(uintptr(start&^(pmdSize-1)), pmdEntry, pmdSize-1) { + return false + } + + if !pmdEntry.Valid() { + clearPMDEntries++ + } + + start = lookupnext(start, pmdSize) + continue + } + } else { + pteEntries = w.pageTables.Allocator.LookupPTEs(pmdEntry.Address()) + } + + clearPTEEntries := uint16(0) + + for pteIndex := uint16((start & pteMask) >> pteShift); start < end && pteIndex < entriesPerPage; pteIndex++ { + var ( + pteEntry = &pteEntries[pteIndex] + ) + if !pteEntry.Valid() && !w.visitor.requiresAlloc() { + clearPTEEntries++ + start += pteSize + continue + } + + if !w.visitor.visit(uintptr(start&^(pteSize-1)), pteEntry, pteSize-1) { + return false + } + if !pteEntry.Valid() && !w.visitor.requiresAlloc() { + clearPTEEntries++ + } + + start += pteSize + continue + } + + if clearPTEEntries == entriesPerPage { + pmdEntry.Clear() + w.pageTables.Allocator.FreePTEs(pteEntries) + clearPMDEntries++ + } + } + + if clearPMDEntries == entriesPerPage { + pudEntry.Clear() + w.pageTables.Allocator.FreePTEs(pmdEntries) + clearPUDEntries++ + } + } + + if clearPUDEntries == entriesPerPage { + pgdEntry.Clear() + w.pageTables.Allocator.FreePTEs(pudEntries) + } + } + return true +} + +// Walker walks page tables. +type lookupWalker struct { + // pageTables are the tables to walk. + pageTables *PageTables + + // Visitor is the set of arguments. + visitor lookupVisitor +} + +// iterateRange iterates over all appropriate levels of page tables for the given range. +// +// If requiresAlloc is true, then Set _must_ be called on all given PTEs. The +// exception is super pages. If a valid super page (huge or jumbo) cannot be +// installed, then the walk will continue to individual entries. +// +// This algorithm will attempt to maximize the use of super/sect pages whenever +// possible. Whether a super page is provided will be clear through the range +// provided in the callback. +// +// Note that if requiresAlloc is true, then no gaps will be present. However, +// if alloc is not set, then the iteration will likely be full of gaps. +// +// Note that this function should generally be avoided in favor of Map, Unmap, +// etc. when not necessary. +// +// Precondition: start must be page-aligned. +// Precondition: start must be less than end. +// Precondition: If requiresAlloc is true, then start and end should not span +// non-canonical ranges. If they do, a panic will result. +// +//go:nosplit +func (w *lookupWalker) iterateRange(start, end uintptr) { + if start%pteSize != 0 { + panic("unaligned start") + } + if end < start { + panic("start > end") + } + if start < lowerTop { + if end <= lowerTop { + w.iterateRangeCanonical(start, end) + } else if end > lowerTop && end <= upperBottom { + if w.visitor.requiresAlloc() { + panic("alloc spans non-canonical range") + } + w.iterateRangeCanonical(start, lowerTop) + } else { + if w.visitor.requiresAlloc() { + panic("alloc spans non-canonical range") + } + if !w.iterateRangeCanonical(start, lowerTop) { + return + } + w.iterateRangeCanonical(upperBottom, end) + } + } else if start < upperBottom { + if end <= upperBottom { + if w.visitor.requiresAlloc() { + panic("alloc spans non-canonical range") + } + } else { + if w.visitor.requiresAlloc() { + panic("alloc spans non-canonical range") + } + w.iterateRangeCanonical(upperBottom, end) + } + } else { + w.iterateRangeCanonical(start, end) + } +} + +// next returns the next address quantized by the given size. +// +//go:nosplit +func lookupnext(start uintptr, size uintptr) uintptr { + start &= ^(size - 1) + start += size + return start +} diff --git a/pkg/ring0/pagetables/walker_lookup_arm64.go b/pkg/ring0/pagetables/walker_lookup_arm64.go new file mode 100644 index 000000000..2b989b3f1 --- /dev/null +++ b/pkg/ring0/pagetables/walker_lookup_arm64.go @@ -0,0 +1,276 @@ +//go:build arm64 +// +build arm64 + +package pagetables + +// iterateRangeCanonical walks a canonical range. +// +//go:nosplit +func (w *lookupWalker) iterateRangeCanonical(start, end uintptr) bool { + pgdEntryIndex := w.pageTables.root + if start >= upperBottom { + pgdEntryIndex = w.pageTables.archPageTables.root + } + + for pgdIndex := (uint16((start & pgdMask) >> pgdShift)); start < end && pgdIndex < entriesPerPage; pgdIndex++ { + var ( + pgdEntry = &pgdEntryIndex[pgdIndex] + pudEntries *PTEs + ) + if !pgdEntry.Valid() { + if !w.visitor.requiresAlloc() { + + start = lookupnext(start, pgdSize) + continue + } + + pudEntries = w.pageTables.Allocator.NewPTEs() + pgdEntry.setPageTable(w.pageTables, pudEntries) + } else { + pudEntries = w.pageTables.Allocator.LookupPTEs(pgdEntry.Address()) + } + + clearPUDEntries := uint16(0) + + for pudIndex := uint16((start & pudMask) >> pudShift); start < end && pudIndex < entriesPerPage; pudIndex++ { + var ( + pudEntry = &pudEntries[pudIndex] + pmdEntries *PTEs + ) + if !pudEntry.Valid() { + if !w.visitor.requiresAlloc() { + + clearPUDEntries++ + start = lookupnext(start, pudSize) + continue + } + + if start&(pudSize-1) == 0 && end-start >= pudSize { + pudEntry.SetSect() + if !w.visitor.visit(uintptr(start), pudEntry, pudSize-1) { + return false + } + if pudEntry.Valid() { + start = lookupnext(start, pudSize) + continue + } + } + + pmdEntries = w.pageTables.Allocator.NewPTEs() + pudEntry.setPageTable(w.pageTables, pmdEntries) + + } else if pudEntry.IsSect() { + + if w.visitor.requiresSplit() && (start&(pudSize-1) != 0 || end < lookupnext(start, pudSize)) { + + pmdEntries = w.pageTables.Allocator.NewPTEs() + for index := uint16(0); index < entriesPerPage; index++ { + pmdEntries[index].SetSect() + pmdEntries[index].Set( + pudEntry.Address()+(pmdSize*uintptr(index)), + pudEntry.Opts()) + } + pudEntry.setPageTable(w.pageTables, pmdEntries) + } else { + + if !w.visitor.visit(uintptr(start), pudEntry, pudSize-1) { + return false + } + + if !pudEntry.Valid() { + clearPUDEntries++ + } + + start = lookupnext(start, pudSize) + continue + } + + } else { + pmdEntries = w.pageTables.Allocator.LookupPTEs(pudEntry.Address()) + } + + clearPMDEntries := uint16(0) + + for pmdIndex := uint16((start & pmdMask) >> pmdShift); start < end && pmdIndex < entriesPerPage; pmdIndex++ { + var ( + pmdEntry = &pmdEntries[pmdIndex] + pteEntries *PTEs + ) + if !pmdEntry.Valid() { + if !w.visitor.requiresAlloc() { + + clearPMDEntries++ + start = lookupnext(start, pmdSize) + continue + } + + if start&(pmdSize-1) == 0 && end-start >= pmdSize { + pmdEntry.SetSect() + if !w.visitor.visit(uintptr(start), pmdEntry, pmdSize-1) { + return false + } + if pmdEntry.Valid() { + start = lookupnext(start, pmdSize) + continue + } + } + + pteEntries = w.pageTables.Allocator.NewPTEs() + pmdEntry.setPageTable(w.pageTables, pteEntries) + + } else if pmdEntry.IsSect() { + + if w.visitor.requiresSplit() && (start&(pmdSize-1) != 0 || end < lookupnext(start, pmdSize)) { + + pteEntries = w.pageTables.Allocator.NewPTEs() + for index := uint16(0); index < entriesPerPage; index++ { + pteEntries[index].Set( + pmdEntry.Address()+(pteSize*uintptr(index)), + pmdEntry.Opts()) + } + pmdEntry.setPageTable(w.pageTables, pteEntries) + } else { + + if !w.visitor.visit(uintptr(start), pmdEntry, pmdSize-1) { + return false + } + + if !pmdEntry.Valid() { + clearPMDEntries++ + } + + start = lookupnext(start, pmdSize) + continue + } + + } else { + pteEntries = w.pageTables.Allocator.LookupPTEs(pmdEntry.Address()) + } + + clearPTEEntries := uint16(0) + + for pteIndex := uint16((start & pteMask) >> pteShift); start < end && pteIndex < entriesPerPage; pteIndex++ { + var ( + pteEntry = &pteEntries[pteIndex] + ) + if !pteEntry.Valid() && !w.visitor.requiresAlloc() { + clearPTEEntries++ + start += pteSize + continue + } + + if !w.visitor.visit(uintptr(start), pteEntry, pteSize-1) { + return false + } + if !pteEntry.Valid() { + if w.visitor.requiresAlloc() { + panic("PTE not set after iteration with requiresAlloc!") + } + clearPTEEntries++ + } + + start += pteSize + continue + } + + if clearPTEEntries == entriesPerPage { + pmdEntry.Clear() + w.pageTables.Allocator.FreePTEs(pteEntries) + clearPMDEntries++ + } + } + + if clearPMDEntries == entriesPerPage { + pudEntry.Clear() + w.pageTables.Allocator.FreePTEs(pmdEntries) + clearPUDEntries++ + } + } + + if clearPUDEntries == entriesPerPage { + pgdEntry.Clear() + w.pageTables.Allocator.FreePTEs(pudEntries) + } + } + return true +} + +// Walker walks page tables. +type lookupWalker struct { + // pageTables are the tables to walk. + pageTables *PageTables + + // Visitor is the set of arguments. + visitor lookupVisitor +} + +// iterateRange iterates over all appropriate levels of page tables for the given range. +// +// If requiresAlloc is true, then Set _must_ be called on all given PTEs. The +// exception is super pages. If a valid super page (huge or jumbo) cannot be +// installed, then the walk will continue to individual entries. +// +// This algorithm will attempt to maximize the use of super/sect pages whenever +// possible. Whether a super page is provided will be clear through the range +// provided in the callback. +// +// Note that if requiresAlloc is true, then no gaps will be present. However, +// if alloc is not set, then the iteration will likely be full of gaps. +// +// Note that this function should generally be avoided in favor of Map, Unmap, +// etc. when not necessary. +// +// Precondition: start must be page-aligned. +// Precondition: start must be less than end. +// Precondition: If requiresAlloc is true, then start and end should not span +// non-canonical ranges. If they do, a panic will result. +// +//go:nosplit +func (w *lookupWalker) iterateRange(start, end uintptr) { + if start%pteSize != 0 { + panic("unaligned start") + } + if end < start { + panic("start > end") + } + if start < lowerTop { + if end <= lowerTop { + w.iterateRangeCanonical(start, end) + } else if end > lowerTop && end <= upperBottom { + if w.visitor.requiresAlloc() { + panic("alloc spans non-canonical range") + } + w.iterateRangeCanonical(start, lowerTop) + } else { + if w.visitor.requiresAlloc() { + panic("alloc spans non-canonical range") + } + if !w.iterateRangeCanonical(start, lowerTop) { + return + } + w.iterateRangeCanonical(upperBottom, end) + } + } else if start < upperBottom { + if end <= upperBottom { + if w.visitor.requiresAlloc() { + panic("alloc spans non-canonical range") + } + } else { + if w.visitor.requiresAlloc() { + panic("alloc spans non-canonical range") + } + w.iterateRangeCanonical(upperBottom, end) + } + } else { + w.iterateRangeCanonical(start, end) + } +} + +// next returns the next address quantized by the given size. +// +//go:nosplit +func lookupnext(start uintptr, size uintptr) uintptr { + start &= ^(size - 1) + start += size + return start +} diff --git a/pkg/ring0/pagetables/walker_map_amd64.go b/pkg/ring0/pagetables/walker_map_amd64.go new file mode 100644 index 000000000..4e0b959a5 --- /dev/null +++ b/pkg/ring0/pagetables/walker_map_amd64.go @@ -0,0 +1,266 @@ +//go:build amd64 +// +build amd64 + +package pagetables + +// iterateRangeCanonical walks a canonical range. +// +//go:nosplit +func (w *mapWalker) iterateRangeCanonical(start, end uintptr) bool { + for pgdIndex := uint16((start & pgdMask) >> pgdShift); start < end && pgdIndex < entriesPerPage; pgdIndex++ { + var ( + pgdEntry = &w.pageTables.root[pgdIndex] + pudEntries *PTEs + ) + if !pgdEntry.Valid() { + if !w.visitor.requiresAlloc() { + + start = mapnext(start, pgdSize) + continue + } + + pudEntries = w.pageTables.Allocator.NewPTEs() + pgdEntry.setPageTable(w.pageTables, pudEntries) + } else { + pudEntries = w.pageTables.Allocator.LookupPTEs(pgdEntry.Address()) + } + + clearPUDEntries := uint16(0) + + for pudIndex := uint16((start & pudMask) >> pudShift); start < end && pudIndex < entriesPerPage; pudIndex++ { + var ( + pudEntry = &pudEntries[pudIndex] + pmdEntries *PTEs + ) + if !pudEntry.Valid() { + if !w.visitor.requiresAlloc() { + + clearPUDEntries++ + start = mapnext(start, pudSize) + continue + } + + if start&(pudSize-1) == 0 && end-start >= pudSize { + pudEntry.SetSuper() + if !w.visitor.visit(uintptr(start&^(pudSize-1)), pudEntry, pudSize-1) { + return false + } + if pudEntry.Valid() { + start = mapnext(start, pudSize) + continue + } + } + + pmdEntries = w.pageTables.Allocator.NewPTEs() + pudEntry.setPageTable(w.pageTables, pmdEntries) + + } else if pudEntry.IsSuper() { + + if w.visitor.requiresSplit() && (start&(pudSize-1) != 0 || end < mapnext(start, pudSize)) { + + pmdEntries = w.pageTables.Allocator.NewPTEs() + for index := uint16(0); index < entriesPerPage; index++ { + pmdEntries[index].SetSuper() + pmdEntries[index].Set( + pudEntry.Address()+(pmdSize*uintptr(index)), + pudEntry.Opts()) + } + pudEntry.setPageTable(w.pageTables, pmdEntries) + } else { + + if !w.visitor.visit(uintptr(start&^(pudSize-1)), pudEntry, pudSize-1) { + return false + } + + if !pudEntry.Valid() { + clearPUDEntries++ + } + + start = mapnext(start, pudSize) + continue + } + } else { + pmdEntries = w.pageTables.Allocator.LookupPTEs(pudEntry.Address()) + } + + clearPMDEntries := uint16(0) + + for pmdIndex := uint16((start & pmdMask) >> pmdShift); start < end && pmdIndex < entriesPerPage; pmdIndex++ { + var ( + pmdEntry = &pmdEntries[pmdIndex] + pteEntries *PTEs + ) + if !pmdEntry.Valid() { + if !w.visitor.requiresAlloc() { + + clearPMDEntries++ + start = mapnext(start, pmdSize) + continue + } + + if start&(pmdSize-1) == 0 && end-start >= pmdSize { + pmdEntry.SetSuper() + if !w.visitor.visit(uintptr(start&^(pmdSize-1)), pmdEntry, pmdSize-1) { + return false + } + if pmdEntry.Valid() { + start = mapnext(start, pmdSize) + continue + } + } + + pteEntries = w.pageTables.Allocator.NewPTEs() + pmdEntry.setPageTable(w.pageTables, pteEntries) + + } else if pmdEntry.IsSuper() { + + if w.visitor.requiresSplit() && (start&(pmdSize-1) != 0 || end < mapnext(start, pmdSize)) { + + pteEntries = w.pageTables.Allocator.NewPTEs() + for index := uint16(0); index < entriesPerPage; index++ { + pteEntries[index].Set( + pmdEntry.Address()+(pteSize*uintptr(index)), + pmdEntry.Opts()) + } + pmdEntry.setPageTable(w.pageTables, pteEntries) + } else { + + if !w.visitor.visit(uintptr(start&^(pmdSize-1)), pmdEntry, pmdSize-1) { + return false + } + + if !pmdEntry.Valid() { + clearPMDEntries++ + } + + start = mapnext(start, pmdSize) + continue + } + } else { + pteEntries = w.pageTables.Allocator.LookupPTEs(pmdEntry.Address()) + } + + clearPTEEntries := uint16(0) + + for pteIndex := uint16((start & pteMask) >> pteShift); start < end && pteIndex < entriesPerPage; pteIndex++ { + var ( + pteEntry = &pteEntries[pteIndex] + ) + if !pteEntry.Valid() && !w.visitor.requiresAlloc() { + clearPTEEntries++ + start += pteSize + continue + } + + if !w.visitor.visit(uintptr(start&^(pteSize-1)), pteEntry, pteSize-1) { + return false + } + if !pteEntry.Valid() && !w.visitor.requiresAlloc() { + clearPTEEntries++ + } + + start += pteSize + continue + } + + if clearPTEEntries == entriesPerPage { + pmdEntry.Clear() + w.pageTables.Allocator.FreePTEs(pteEntries) + clearPMDEntries++ + } + } + + if clearPMDEntries == entriesPerPage { + pudEntry.Clear() + w.pageTables.Allocator.FreePTEs(pmdEntries) + clearPUDEntries++ + } + } + + if clearPUDEntries == entriesPerPage { + pgdEntry.Clear() + w.pageTables.Allocator.FreePTEs(pudEntries) + } + } + return true +} + +// Walker walks page tables. +type mapWalker struct { + // pageTables are the tables to walk. + pageTables *PageTables + + // Visitor is the set of arguments. + visitor mapVisitor +} + +// iterateRange iterates over all appropriate levels of page tables for the given range. +// +// If requiresAlloc is true, then Set _must_ be called on all given PTEs. The +// exception is super pages. If a valid super page (huge or jumbo) cannot be +// installed, then the walk will continue to individual entries. +// +// This algorithm will attempt to maximize the use of super/sect pages whenever +// possible. Whether a super page is provided will be clear through the range +// provided in the callback. +// +// Note that if requiresAlloc is true, then no gaps will be present. However, +// if alloc is not set, then the iteration will likely be full of gaps. +// +// Note that this function should generally be avoided in favor of Map, Unmap, +// etc. when not necessary. +// +// Precondition: start must be page-aligned. +// Precondition: start must be less than end. +// Precondition: If requiresAlloc is true, then start and end should not span +// non-canonical ranges. If they do, a panic will result. +// +//go:nosplit +func (w *mapWalker) iterateRange(start, end uintptr) { + if start%pteSize != 0 { + panic("unaligned start") + } + if end < start { + panic("start > end") + } + if start < lowerTop { + if end <= lowerTop { + w.iterateRangeCanonical(start, end) + } else if end > lowerTop && end <= upperBottom { + if w.visitor.requiresAlloc() { + panic("alloc spans non-canonical range") + } + w.iterateRangeCanonical(start, lowerTop) + } else { + if w.visitor.requiresAlloc() { + panic("alloc spans non-canonical range") + } + if !w.iterateRangeCanonical(start, lowerTop) { + return + } + w.iterateRangeCanonical(upperBottom, end) + } + } else if start < upperBottom { + if end <= upperBottom { + if w.visitor.requiresAlloc() { + panic("alloc spans non-canonical range") + } + } else { + if w.visitor.requiresAlloc() { + panic("alloc spans non-canonical range") + } + w.iterateRangeCanonical(upperBottom, end) + } + } else { + w.iterateRangeCanonical(start, end) + } +} + +// next returns the next address quantized by the given size. +// +//go:nosplit +func mapnext(start uintptr, size uintptr) uintptr { + start &= ^(size - 1) + start += size + return start +} diff --git a/pkg/ring0/pagetables/walker_map_arm64.go b/pkg/ring0/pagetables/walker_map_arm64.go new file mode 100644 index 000000000..6d70b7091 --- /dev/null +++ b/pkg/ring0/pagetables/walker_map_arm64.go @@ -0,0 +1,276 @@ +//go:build arm64 +// +build arm64 + +package pagetables + +// iterateRangeCanonical walks a canonical range. +// +//go:nosplit +func (w *mapWalker) iterateRangeCanonical(start, end uintptr) bool { + pgdEntryIndex := w.pageTables.root + if start >= upperBottom { + pgdEntryIndex = w.pageTables.archPageTables.root + } + + for pgdIndex := (uint16((start & pgdMask) >> pgdShift)); start < end && pgdIndex < entriesPerPage; pgdIndex++ { + var ( + pgdEntry = &pgdEntryIndex[pgdIndex] + pudEntries *PTEs + ) + if !pgdEntry.Valid() { + if !w.visitor.requiresAlloc() { + + start = mapnext(start, pgdSize) + continue + } + + pudEntries = w.pageTables.Allocator.NewPTEs() + pgdEntry.setPageTable(w.pageTables, pudEntries) + } else { + pudEntries = w.pageTables.Allocator.LookupPTEs(pgdEntry.Address()) + } + + clearPUDEntries := uint16(0) + + for pudIndex := uint16((start & pudMask) >> pudShift); start < end && pudIndex < entriesPerPage; pudIndex++ { + var ( + pudEntry = &pudEntries[pudIndex] + pmdEntries *PTEs + ) + if !pudEntry.Valid() { + if !w.visitor.requiresAlloc() { + + clearPUDEntries++ + start = mapnext(start, pudSize) + continue + } + + if start&(pudSize-1) == 0 && end-start >= pudSize { + pudEntry.SetSect() + if !w.visitor.visit(uintptr(start), pudEntry, pudSize-1) { + return false + } + if pudEntry.Valid() { + start = mapnext(start, pudSize) + continue + } + } + + pmdEntries = w.pageTables.Allocator.NewPTEs() + pudEntry.setPageTable(w.pageTables, pmdEntries) + + } else if pudEntry.IsSect() { + + if w.visitor.requiresSplit() && (start&(pudSize-1) != 0 || end < mapnext(start, pudSize)) { + + pmdEntries = w.pageTables.Allocator.NewPTEs() + for index := uint16(0); index < entriesPerPage; index++ { + pmdEntries[index].SetSect() + pmdEntries[index].Set( + pudEntry.Address()+(pmdSize*uintptr(index)), + pudEntry.Opts()) + } + pudEntry.setPageTable(w.pageTables, pmdEntries) + } else { + + if !w.visitor.visit(uintptr(start), pudEntry, pudSize-1) { + return false + } + + if !pudEntry.Valid() { + clearPUDEntries++ + } + + start = mapnext(start, pudSize) + continue + } + + } else { + pmdEntries = w.pageTables.Allocator.LookupPTEs(pudEntry.Address()) + } + + clearPMDEntries := uint16(0) + + for pmdIndex := uint16((start & pmdMask) >> pmdShift); start < end && pmdIndex < entriesPerPage; pmdIndex++ { + var ( + pmdEntry = &pmdEntries[pmdIndex] + pteEntries *PTEs + ) + if !pmdEntry.Valid() { + if !w.visitor.requiresAlloc() { + + clearPMDEntries++ + start = mapnext(start, pmdSize) + continue + } + + if start&(pmdSize-1) == 0 && end-start >= pmdSize { + pmdEntry.SetSect() + if !w.visitor.visit(uintptr(start), pmdEntry, pmdSize-1) { + return false + } + if pmdEntry.Valid() { + start = mapnext(start, pmdSize) + continue + } + } + + pteEntries = w.pageTables.Allocator.NewPTEs() + pmdEntry.setPageTable(w.pageTables, pteEntries) + + } else if pmdEntry.IsSect() { + + if w.visitor.requiresSplit() && (start&(pmdSize-1) != 0 || end < mapnext(start, pmdSize)) { + + pteEntries = w.pageTables.Allocator.NewPTEs() + for index := uint16(0); index < entriesPerPage; index++ { + pteEntries[index].Set( + pmdEntry.Address()+(pteSize*uintptr(index)), + pmdEntry.Opts()) + } + pmdEntry.setPageTable(w.pageTables, pteEntries) + } else { + + if !w.visitor.visit(uintptr(start), pmdEntry, pmdSize-1) { + return false + } + + if !pmdEntry.Valid() { + clearPMDEntries++ + } + + start = mapnext(start, pmdSize) + continue + } + + } else { + pteEntries = w.pageTables.Allocator.LookupPTEs(pmdEntry.Address()) + } + + clearPTEEntries := uint16(0) + + for pteIndex := uint16((start & pteMask) >> pteShift); start < end && pteIndex < entriesPerPage; pteIndex++ { + var ( + pteEntry = &pteEntries[pteIndex] + ) + if !pteEntry.Valid() && !w.visitor.requiresAlloc() { + clearPTEEntries++ + start += pteSize + continue + } + + if !w.visitor.visit(uintptr(start), pteEntry, pteSize-1) { + return false + } + if !pteEntry.Valid() { + if w.visitor.requiresAlloc() { + panic("PTE not set after iteration with requiresAlloc!") + } + clearPTEEntries++ + } + + start += pteSize + continue + } + + if clearPTEEntries == entriesPerPage { + pmdEntry.Clear() + w.pageTables.Allocator.FreePTEs(pteEntries) + clearPMDEntries++ + } + } + + if clearPMDEntries == entriesPerPage { + pudEntry.Clear() + w.pageTables.Allocator.FreePTEs(pmdEntries) + clearPUDEntries++ + } + } + + if clearPUDEntries == entriesPerPage { + pgdEntry.Clear() + w.pageTables.Allocator.FreePTEs(pudEntries) + } + } + return true +} + +// Walker walks page tables. +type mapWalker struct { + // pageTables are the tables to walk. + pageTables *PageTables + + // Visitor is the set of arguments. + visitor mapVisitor +} + +// iterateRange iterates over all appropriate levels of page tables for the given range. +// +// If requiresAlloc is true, then Set _must_ be called on all given PTEs. The +// exception is super pages. If a valid super page (huge or jumbo) cannot be +// installed, then the walk will continue to individual entries. +// +// This algorithm will attempt to maximize the use of super/sect pages whenever +// possible. Whether a super page is provided will be clear through the range +// provided in the callback. +// +// Note that if requiresAlloc is true, then no gaps will be present. However, +// if alloc is not set, then the iteration will likely be full of gaps. +// +// Note that this function should generally be avoided in favor of Map, Unmap, +// etc. when not necessary. +// +// Precondition: start must be page-aligned. +// Precondition: start must be less than end. +// Precondition: If requiresAlloc is true, then start and end should not span +// non-canonical ranges. If they do, a panic will result. +// +//go:nosplit +func (w *mapWalker) iterateRange(start, end uintptr) { + if start%pteSize != 0 { + panic("unaligned start") + } + if end < start { + panic("start > end") + } + if start < lowerTop { + if end <= lowerTop { + w.iterateRangeCanonical(start, end) + } else if end > lowerTop && end <= upperBottom { + if w.visitor.requiresAlloc() { + panic("alloc spans non-canonical range") + } + w.iterateRangeCanonical(start, lowerTop) + } else { + if w.visitor.requiresAlloc() { + panic("alloc spans non-canonical range") + } + if !w.iterateRangeCanonical(start, lowerTop) { + return + } + w.iterateRangeCanonical(upperBottom, end) + } + } else if start < upperBottom { + if end <= upperBottom { + if w.visitor.requiresAlloc() { + panic("alloc spans non-canonical range") + } + } else { + if w.visitor.requiresAlloc() { + panic("alloc spans non-canonical range") + } + w.iterateRangeCanonical(upperBottom, end) + } + } else { + w.iterateRangeCanonical(start, end) + } +} + +// next returns the next address quantized by the given size. +// +//go:nosplit +func mapnext(start uintptr, size uintptr) uintptr { + start &= ^(size - 1) + start += size + return start +} diff --git a/pkg/ring0/pagetables/walker_unmap_amd64.go b/pkg/ring0/pagetables/walker_unmap_amd64.go new file mode 100644 index 000000000..25d3a1710 --- /dev/null +++ b/pkg/ring0/pagetables/walker_unmap_amd64.go @@ -0,0 +1,266 @@ +//go:build amd64 +// +build amd64 + +package pagetables + +// iterateRangeCanonical walks a canonical range. +// +//go:nosplit +func (w *unmapWalker) iterateRangeCanonical(start, end uintptr) bool { + for pgdIndex := uint16((start & pgdMask) >> pgdShift); start < end && pgdIndex < entriesPerPage; pgdIndex++ { + var ( + pgdEntry = &w.pageTables.root[pgdIndex] + pudEntries *PTEs + ) + if !pgdEntry.Valid() { + if !w.visitor.requiresAlloc() { + + start = unmapnext(start, pgdSize) + continue + } + + pudEntries = w.pageTables.Allocator.NewPTEs() + pgdEntry.setPageTable(w.pageTables, pudEntries) + } else { + pudEntries = w.pageTables.Allocator.LookupPTEs(pgdEntry.Address()) + } + + clearPUDEntries := uint16(0) + + for pudIndex := uint16((start & pudMask) >> pudShift); start < end && pudIndex < entriesPerPage; pudIndex++ { + var ( + pudEntry = &pudEntries[pudIndex] + pmdEntries *PTEs + ) + if !pudEntry.Valid() { + if !w.visitor.requiresAlloc() { + + clearPUDEntries++ + start = unmapnext(start, pudSize) + continue + } + + if start&(pudSize-1) == 0 && end-start >= pudSize { + pudEntry.SetSuper() + if !w.visitor.visit(uintptr(start&^(pudSize-1)), pudEntry, pudSize-1) { + return false + } + if pudEntry.Valid() { + start = unmapnext(start, pudSize) + continue + } + } + + pmdEntries = w.pageTables.Allocator.NewPTEs() + pudEntry.setPageTable(w.pageTables, pmdEntries) + + } else if pudEntry.IsSuper() { + + if w.visitor.requiresSplit() && (start&(pudSize-1) != 0 || end < unmapnext(start, pudSize)) { + + pmdEntries = w.pageTables.Allocator.NewPTEs() + for index := uint16(0); index < entriesPerPage; index++ { + pmdEntries[index].SetSuper() + pmdEntries[index].Set( + pudEntry.Address()+(pmdSize*uintptr(index)), + pudEntry.Opts()) + } + pudEntry.setPageTable(w.pageTables, pmdEntries) + } else { + + if !w.visitor.visit(uintptr(start&^(pudSize-1)), pudEntry, pudSize-1) { + return false + } + + if !pudEntry.Valid() { + clearPUDEntries++ + } + + start = unmapnext(start, pudSize) + continue + } + } else { + pmdEntries = w.pageTables.Allocator.LookupPTEs(pudEntry.Address()) + } + + clearPMDEntries := uint16(0) + + for pmdIndex := uint16((start & pmdMask) >> pmdShift); start < end && pmdIndex < entriesPerPage; pmdIndex++ { + var ( + pmdEntry = &pmdEntries[pmdIndex] + pteEntries *PTEs + ) + if !pmdEntry.Valid() { + if !w.visitor.requiresAlloc() { + + clearPMDEntries++ + start = unmapnext(start, pmdSize) + continue + } + + if start&(pmdSize-1) == 0 && end-start >= pmdSize { + pmdEntry.SetSuper() + if !w.visitor.visit(uintptr(start&^(pmdSize-1)), pmdEntry, pmdSize-1) { + return false + } + if pmdEntry.Valid() { + start = unmapnext(start, pmdSize) + continue + } + } + + pteEntries = w.pageTables.Allocator.NewPTEs() + pmdEntry.setPageTable(w.pageTables, pteEntries) + + } else if pmdEntry.IsSuper() { + + if w.visitor.requiresSplit() && (start&(pmdSize-1) != 0 || end < unmapnext(start, pmdSize)) { + + pteEntries = w.pageTables.Allocator.NewPTEs() + for index := uint16(0); index < entriesPerPage; index++ { + pteEntries[index].Set( + pmdEntry.Address()+(pteSize*uintptr(index)), + pmdEntry.Opts()) + } + pmdEntry.setPageTable(w.pageTables, pteEntries) + } else { + + if !w.visitor.visit(uintptr(start&^(pmdSize-1)), pmdEntry, pmdSize-1) { + return false + } + + if !pmdEntry.Valid() { + clearPMDEntries++ + } + + start = unmapnext(start, pmdSize) + continue + } + } else { + pteEntries = w.pageTables.Allocator.LookupPTEs(pmdEntry.Address()) + } + + clearPTEEntries := uint16(0) + + for pteIndex := uint16((start & pteMask) >> pteShift); start < end && pteIndex < entriesPerPage; pteIndex++ { + var ( + pteEntry = &pteEntries[pteIndex] + ) + if !pteEntry.Valid() && !w.visitor.requiresAlloc() { + clearPTEEntries++ + start += pteSize + continue + } + + if !w.visitor.visit(uintptr(start&^(pteSize-1)), pteEntry, pteSize-1) { + return false + } + if !pteEntry.Valid() && !w.visitor.requiresAlloc() { + clearPTEEntries++ + } + + start += pteSize + continue + } + + if clearPTEEntries == entriesPerPage { + pmdEntry.Clear() + w.pageTables.Allocator.FreePTEs(pteEntries) + clearPMDEntries++ + } + } + + if clearPMDEntries == entriesPerPage { + pudEntry.Clear() + w.pageTables.Allocator.FreePTEs(pmdEntries) + clearPUDEntries++ + } + } + + if clearPUDEntries == entriesPerPage { + pgdEntry.Clear() + w.pageTables.Allocator.FreePTEs(pudEntries) + } + } + return true +} + +// Walker walks page tables. +type unmapWalker struct { + // pageTables are the tables to walk. + pageTables *PageTables + + // Visitor is the set of arguments. + visitor unmapVisitor +} + +// iterateRange iterates over all appropriate levels of page tables for the given range. +// +// If requiresAlloc is true, then Set _must_ be called on all given PTEs. The +// exception is super pages. If a valid super page (huge or jumbo) cannot be +// installed, then the walk will continue to individual entries. +// +// This algorithm will attempt to maximize the use of super/sect pages whenever +// possible. Whether a super page is provided will be clear through the range +// provided in the callback. +// +// Note that if requiresAlloc is true, then no gaps will be present. However, +// if alloc is not set, then the iteration will likely be full of gaps. +// +// Note that this function should generally be avoided in favor of Map, Unmap, +// etc. when not necessary. +// +// Precondition: start must be page-aligned. +// Precondition: start must be less than end. +// Precondition: If requiresAlloc is true, then start and end should not span +// non-canonical ranges. If they do, a panic will result. +// +//go:nosplit +func (w *unmapWalker) iterateRange(start, end uintptr) { + if start%pteSize != 0 { + panic("unaligned start") + } + if end < start { + panic("start > end") + } + if start < lowerTop { + if end <= lowerTop { + w.iterateRangeCanonical(start, end) + } else if end > lowerTop && end <= upperBottom { + if w.visitor.requiresAlloc() { + panic("alloc spans non-canonical range") + } + w.iterateRangeCanonical(start, lowerTop) + } else { + if w.visitor.requiresAlloc() { + panic("alloc spans non-canonical range") + } + if !w.iterateRangeCanonical(start, lowerTop) { + return + } + w.iterateRangeCanonical(upperBottom, end) + } + } else if start < upperBottom { + if end <= upperBottom { + if w.visitor.requiresAlloc() { + panic("alloc spans non-canonical range") + } + } else { + if w.visitor.requiresAlloc() { + panic("alloc spans non-canonical range") + } + w.iterateRangeCanonical(upperBottom, end) + } + } else { + w.iterateRangeCanonical(start, end) + } +} + +// next returns the next address quantized by the given size. +// +//go:nosplit +func unmapnext(start uintptr, size uintptr) uintptr { + start &= ^(size - 1) + start += size + return start +} diff --git a/pkg/ring0/pagetables/walker_unmap_arm64.go b/pkg/ring0/pagetables/walker_unmap_arm64.go new file mode 100644 index 000000000..7ad732e39 --- /dev/null +++ b/pkg/ring0/pagetables/walker_unmap_arm64.go @@ -0,0 +1,276 @@ +//go:build arm64 +// +build arm64 + +package pagetables + +// iterateRangeCanonical walks a canonical range. +// +//go:nosplit +func (w *unmapWalker) iterateRangeCanonical(start, end uintptr) bool { + pgdEntryIndex := w.pageTables.root + if start >= upperBottom { + pgdEntryIndex = w.pageTables.archPageTables.root + } + + for pgdIndex := (uint16((start & pgdMask) >> pgdShift)); start < end && pgdIndex < entriesPerPage; pgdIndex++ { + var ( + pgdEntry = &pgdEntryIndex[pgdIndex] + pudEntries *PTEs + ) + if !pgdEntry.Valid() { + if !w.visitor.requiresAlloc() { + + start = unmapnext(start, pgdSize) + continue + } + + pudEntries = w.pageTables.Allocator.NewPTEs() + pgdEntry.setPageTable(w.pageTables, pudEntries) + } else { + pudEntries = w.pageTables.Allocator.LookupPTEs(pgdEntry.Address()) + } + + clearPUDEntries := uint16(0) + + for pudIndex := uint16((start & pudMask) >> pudShift); start < end && pudIndex < entriesPerPage; pudIndex++ { + var ( + pudEntry = &pudEntries[pudIndex] + pmdEntries *PTEs + ) + if !pudEntry.Valid() { + if !w.visitor.requiresAlloc() { + + clearPUDEntries++ + start = unmapnext(start, pudSize) + continue + } + + if start&(pudSize-1) == 0 && end-start >= pudSize { + pudEntry.SetSect() + if !w.visitor.visit(uintptr(start), pudEntry, pudSize-1) { + return false + } + if pudEntry.Valid() { + start = unmapnext(start, pudSize) + continue + } + } + + pmdEntries = w.pageTables.Allocator.NewPTEs() + pudEntry.setPageTable(w.pageTables, pmdEntries) + + } else if pudEntry.IsSect() { + + if w.visitor.requiresSplit() && (start&(pudSize-1) != 0 || end < unmapnext(start, pudSize)) { + + pmdEntries = w.pageTables.Allocator.NewPTEs() + for index := uint16(0); index < entriesPerPage; index++ { + pmdEntries[index].SetSect() + pmdEntries[index].Set( + pudEntry.Address()+(pmdSize*uintptr(index)), + pudEntry.Opts()) + } + pudEntry.setPageTable(w.pageTables, pmdEntries) + } else { + + if !w.visitor.visit(uintptr(start), pudEntry, pudSize-1) { + return false + } + + if !pudEntry.Valid() { + clearPUDEntries++ + } + + start = unmapnext(start, pudSize) + continue + } + + } else { + pmdEntries = w.pageTables.Allocator.LookupPTEs(pudEntry.Address()) + } + + clearPMDEntries := uint16(0) + + for pmdIndex := uint16((start & pmdMask) >> pmdShift); start < end && pmdIndex < entriesPerPage; pmdIndex++ { + var ( + pmdEntry = &pmdEntries[pmdIndex] + pteEntries *PTEs + ) + if !pmdEntry.Valid() { + if !w.visitor.requiresAlloc() { + + clearPMDEntries++ + start = unmapnext(start, pmdSize) + continue + } + + if start&(pmdSize-1) == 0 && end-start >= pmdSize { + pmdEntry.SetSect() + if !w.visitor.visit(uintptr(start), pmdEntry, pmdSize-1) { + return false + } + if pmdEntry.Valid() { + start = unmapnext(start, pmdSize) + continue + } + } + + pteEntries = w.pageTables.Allocator.NewPTEs() + pmdEntry.setPageTable(w.pageTables, pteEntries) + + } else if pmdEntry.IsSect() { + + if w.visitor.requiresSplit() && (start&(pmdSize-1) != 0 || end < unmapnext(start, pmdSize)) { + + pteEntries = w.pageTables.Allocator.NewPTEs() + for index := uint16(0); index < entriesPerPage; index++ { + pteEntries[index].Set( + pmdEntry.Address()+(pteSize*uintptr(index)), + pmdEntry.Opts()) + } + pmdEntry.setPageTable(w.pageTables, pteEntries) + } else { + + if !w.visitor.visit(uintptr(start), pmdEntry, pmdSize-1) { + return false + } + + if !pmdEntry.Valid() { + clearPMDEntries++ + } + + start = unmapnext(start, pmdSize) + continue + } + + } else { + pteEntries = w.pageTables.Allocator.LookupPTEs(pmdEntry.Address()) + } + + clearPTEEntries := uint16(0) + + for pteIndex := uint16((start & pteMask) >> pteShift); start < end && pteIndex < entriesPerPage; pteIndex++ { + var ( + pteEntry = &pteEntries[pteIndex] + ) + if !pteEntry.Valid() && !w.visitor.requiresAlloc() { + clearPTEEntries++ + start += pteSize + continue + } + + if !w.visitor.visit(uintptr(start), pteEntry, pteSize-1) { + return false + } + if !pteEntry.Valid() { + if w.visitor.requiresAlloc() { + panic("PTE not set after iteration with requiresAlloc!") + } + clearPTEEntries++ + } + + start += pteSize + continue + } + + if clearPTEEntries == entriesPerPage { + pmdEntry.Clear() + w.pageTables.Allocator.FreePTEs(pteEntries) + clearPMDEntries++ + } + } + + if clearPMDEntries == entriesPerPage { + pudEntry.Clear() + w.pageTables.Allocator.FreePTEs(pmdEntries) + clearPUDEntries++ + } + } + + if clearPUDEntries == entriesPerPage { + pgdEntry.Clear() + w.pageTables.Allocator.FreePTEs(pudEntries) + } + } + return true +} + +// Walker walks page tables. +type unmapWalker struct { + // pageTables are the tables to walk. + pageTables *PageTables + + // Visitor is the set of arguments. + visitor unmapVisitor +} + +// iterateRange iterates over all appropriate levels of page tables for the given range. +// +// If requiresAlloc is true, then Set _must_ be called on all given PTEs. The +// exception is super pages. If a valid super page (huge or jumbo) cannot be +// installed, then the walk will continue to individual entries. +// +// This algorithm will attempt to maximize the use of super/sect pages whenever +// possible. Whether a super page is provided will be clear through the range +// provided in the callback. +// +// Note that if requiresAlloc is true, then no gaps will be present. However, +// if alloc is not set, then the iteration will likely be full of gaps. +// +// Note that this function should generally be avoided in favor of Map, Unmap, +// etc. when not necessary. +// +// Precondition: start must be page-aligned. +// Precondition: start must be less than end. +// Precondition: If requiresAlloc is true, then start and end should not span +// non-canonical ranges. If they do, a panic will result. +// +//go:nosplit +func (w *unmapWalker) iterateRange(start, end uintptr) { + if start%pteSize != 0 { + panic("unaligned start") + } + if end < start { + panic("start > end") + } + if start < lowerTop { + if end <= lowerTop { + w.iterateRangeCanonical(start, end) + } else if end > lowerTop && end <= upperBottom { + if w.visitor.requiresAlloc() { + panic("alloc spans non-canonical range") + } + w.iterateRangeCanonical(start, lowerTop) + } else { + if w.visitor.requiresAlloc() { + panic("alloc spans non-canonical range") + } + if !w.iterateRangeCanonical(start, lowerTop) { + return + } + w.iterateRangeCanonical(upperBottom, end) + } + } else if start < upperBottom { + if end <= upperBottom { + if w.visitor.requiresAlloc() { + panic("alloc spans non-canonical range") + } + } else { + if w.visitor.requiresAlloc() { + panic("alloc spans non-canonical range") + } + w.iterateRangeCanonical(upperBottom, end) + } + } else { + w.iterateRangeCanonical(start, end) + } +} + +// next returns the next address quantized by the given size. +// +//go:nosplit +func unmapnext(start uintptr, size uintptr) uintptr { + start &= ^(size - 1) + start += size + return start +} diff --git a/pkg/ring0/ring0_amd64_state_autogen.go b/pkg/ring0/ring0_amd64_state_autogen.go new file mode 100644 index 000000000..2bd773254 --- /dev/null +++ b/pkg/ring0/ring0_amd64_state_autogen.go @@ -0,0 +1,6 @@ +// automatically generated by stateify. + +//go:build amd64 && amd64 && amd64 +// +build amd64,amd64,amd64 + +package ring0 diff --git a/pkg/ring0/ring0_arm64_state_autogen.go b/pkg/ring0/ring0_arm64_state_autogen.go new file mode 100644 index 000000000..d87b068ec --- /dev/null +++ b/pkg/ring0/ring0_arm64_state_autogen.go @@ -0,0 +1,6 @@ +// automatically generated by stateify. + +//go:build arm64 && arm64 && arm64 +// +build arm64,arm64,arm64 + +package ring0 diff --git a/pkg/ring0/ring0_impl_amd64_state_autogen.go b/pkg/ring0/ring0_impl_amd64_state_autogen.go new file mode 100644 index 000000000..81a3dd14f --- /dev/null +++ b/pkg/ring0/ring0_impl_amd64_state_autogen.go @@ -0,0 +1,8 @@ +// automatically generated by stateify. + +//go:build amd64 && amd64 && (386 || amd64) +// +build amd64 +// +build amd64 +// +build 386 amd64 + +package ring0 diff --git a/pkg/ring0/ring0_impl_arm64_state_autogen.go b/pkg/ring0/ring0_impl_arm64_state_autogen.go new file mode 100644 index 000000000..d87b068ec --- /dev/null +++ b/pkg/ring0/ring0_impl_arm64_state_autogen.go @@ -0,0 +1,6 @@ +// automatically generated by stateify. + +//go:build arm64 && arm64 && arm64 +// +build arm64,arm64,arm64 + +package ring0 diff --git a/pkg/ring0/ring0_state_autogen.go b/pkg/ring0/ring0_state_autogen.go new file mode 100644 index 000000000..327aba163 --- /dev/null +++ b/pkg/ring0/ring0_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package ring0 diff --git a/pkg/ring0/ring0_unsafe_state_autogen.go b/pkg/ring0/ring0_unsafe_state_autogen.go new file mode 100644 index 000000000..327aba163 --- /dev/null +++ b/pkg/ring0/ring0_unsafe_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package ring0 diff --git a/pkg/ring0/x86.go b/pkg/ring0/x86.go deleted file mode 100644 index 9a98703da..000000000 --- a/pkg/ring0/x86.go +++ /dev/null @@ -1,298 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//go:build 386 || amd64 -// +build 386 amd64 - -package ring0 - -import ( - "gvisor.dev/gvisor/pkg/cpuid" -) - -// Useful bits. -const ( - _CR0_PE = 1 << 0 - _CR0_ET = 1 << 4 - _CR0_NE = 1 << 5 - _CR0_AM = 1 << 18 - _CR0_PG = 1 << 31 - - _CR4_PSE = 1 << 4 - _CR4_PAE = 1 << 5 - _CR4_PGE = 1 << 7 - _CR4_OSFXSR = 1 << 9 - _CR4_OSXMMEXCPT = 1 << 10 - _CR4_FSGSBASE = 1 << 16 - _CR4_PCIDE = 1 << 17 - _CR4_OSXSAVE = 1 << 18 - _CR4_SMEP = 1 << 20 - - _RFLAGS_AC = 1 << 18 - _RFLAGS_NT = 1 << 14 - _RFLAGS_IOPL0 = 1 << 12 - _RFLAGS_IOPL1 = 1 << 13 - _RFLAGS_IOPL = _RFLAGS_IOPL0 | _RFLAGS_IOPL1 - _RFLAGS_DF = 1 << 10 - _RFLAGS_IF = 1 << 9 - _RFLAGS_STEP = 1 << 8 - _RFLAGS_RESERVED = 1 << 1 - - _EFER_SCE = 0x001 - _EFER_LME = 0x100 - _EFER_LMA = 0x400 - _EFER_NX = 0x800 - - _MSR_STAR = 0xc0000081 - _MSR_LSTAR = 0xc0000082 - _MSR_CSTAR = 0xc0000083 - _MSR_SYSCALL_MASK = 0xc0000084 - _MSR_PLATFORM_INFO = 0xce - _MSR_MISC_FEATURES = 0x140 - - _PLATFORM_INFO_CPUID_FAULT = 1 << 31 - - _MISC_FEATURE_CPUID_TRAP = 0x1 -) - -const ( - // KernelFlagsSet should always be set in the kernel. - KernelFlagsSet = _RFLAGS_RESERVED - - // UserFlagsSet are always set in userspace. - // - // _RFLAGS_IOPL is a set of two bits and it shows the I/O privilege - // level. The Current Privilege Level (CPL) of the task must be less - // than or equal to the IOPL in order for the task or program to access - // I/O ports. - // - // Here, _RFLAGS_IOPL0 is used only to determine whether the task is - // running in the kernel or userspace mode. In the user mode, the CPL is - // always 3 and it doesn't matter what IOPL is set if it is bellow CPL. - // - // We need to have one bit which will be always different in user and - // kernel modes. And we have to remember that even though we have - // KernelFlagsClear, we still can see some of these flags in the kernel - // mode. This can happen when the goruntime switches on a goroutine - // which has been saved in the host mode. On restore, the popf - // instruction is used to restore flags and this means that all flags - // what the goroutine has in the host mode will be restored in the - // kernel mode. - // - // _RFLAGS_IOPL0 is never set in host and kernel modes and we always set - // it in the user mode. So if this flag is set, the task is running in - // the user mode and if it isn't set, the task is running in the kernel - // mode. - UserFlagsSet = _RFLAGS_RESERVED | _RFLAGS_IF | _RFLAGS_IOPL0 - - // KernelFlagsClear should always be clear in the kernel. - KernelFlagsClear = _RFLAGS_STEP | _RFLAGS_IF | _RFLAGS_IOPL | _RFLAGS_AC | _RFLAGS_NT - - // UserFlagsClear are always cleared in userspace. - UserFlagsClear = _RFLAGS_NT | _RFLAGS_IOPL1 -) - -// IsKernelFlags returns true if rflags coresponds to the kernel mode. -// -// go:nosplit -func IsKernelFlags(rflags uint64) bool { - return rflags&_RFLAGS_IOPL0 == 0 -} - -// Vector is an exception vector. -type Vector uintptr - -// Exception vectors. -const ( - DivideByZero Vector = iota - Debug - NMI - Breakpoint - Overflow - BoundRangeExceeded - InvalidOpcode - DeviceNotAvailable - DoubleFault - CoprocessorSegmentOverrun - InvalidTSS - SegmentNotPresent - StackSegmentFault - GeneralProtectionFault - PageFault - _ - X87FloatingPointException - AlignmentCheck - MachineCheck - SIMDFloatingPointException - VirtualizationException - SecurityException = 0x1e - SyscallInt80 = 0x80 - _NR_INTERRUPTS = 0x100 -) - -// System call vectors. -const ( - Syscall Vector = _NR_INTERRUPTS -) - -// VirtualAddressBits returns the number bits available for virtual addresses. -// -// Note that sign-extension semantics apply to the highest order bit. -// -// FIXME(b/69382326): This should use the cpuid passed to Init. -func VirtualAddressBits() uint32 { - ax, _, _, _ := cpuid.HostID(0x80000008, 0) - return (ax >> 8) & 0xff -} - -// PhysicalAddressBits returns the number of bits available for physical addresses. -// -// FIXME(b/69382326): This should use the cpuid passed to Init. -func PhysicalAddressBits() uint32 { - ax, _, _, _ := cpuid.HostID(0x80000008, 0) - return ax & 0xff -} - -// Selector is a segment Selector. -type Selector uint16 - -// SegmentDescriptor is a segment descriptor. -type SegmentDescriptor struct { - bits [2]uint32 -} - -// descriptorTable is a collection of descriptors. -type descriptorTable [32]SegmentDescriptor - -// SegmentDescriptorFlags are typed flags within a descriptor. -type SegmentDescriptorFlags uint32 - -// SegmentDescriptorFlag declarations. -const ( - SegmentDescriptorAccess SegmentDescriptorFlags = 1 << 8 // Access bit (always set). - SegmentDescriptorWrite = 1 << 9 // Write permission. - SegmentDescriptorExpandDown = 1 << 10 // Grows down, not used. - SegmentDescriptorExecute = 1 << 11 // Execute permission. - SegmentDescriptorSystem = 1 << 12 // Zero => system, 1 => user code/data. - SegmentDescriptorPresent = 1 << 15 // Present. - SegmentDescriptorAVL = 1 << 20 // Available. - SegmentDescriptorLong = 1 << 21 // Long mode. - SegmentDescriptorDB = 1 << 22 // 16 or 32-bit. - SegmentDescriptorG = 1 << 23 // Granularity: page or byte. -) - -// Base returns the descriptor's base linear address. -func (d *SegmentDescriptor) Base() uint32 { - return d.bits[1]&0xFF000000 | (d.bits[1]&0x000000FF)<<16 | d.bits[0]>>16 -} - -// Limit returns the descriptor size. -func (d *SegmentDescriptor) Limit() uint32 { - l := d.bits[0]&0xFFFF | d.bits[1]&0xF0000 - if d.bits[1]&uint32(SegmentDescriptorG) != 0 { - l <<= 12 - l |= 0xFFF - } - return l -} - -// Flags returns descriptor flags. -func (d *SegmentDescriptor) Flags() SegmentDescriptorFlags { - return SegmentDescriptorFlags(d.bits[1] & 0x00F09F00) -} - -// DPL returns the descriptor privilege level. -func (d *SegmentDescriptor) DPL() int { - return int((d.bits[1] >> 13) & 3) -} - -func (d *SegmentDescriptor) setNull() { - d.bits[0] = 0 - d.bits[1] = 0 -} - -func (d *SegmentDescriptor) set(base, limit uint32, dpl int, flags SegmentDescriptorFlags) { - flags |= SegmentDescriptorPresent - if limit>>12 != 0 { - limit >>= 12 - flags |= SegmentDescriptorG - } - d.bits[0] = base<<16 | limit&0xFFFF - d.bits[1] = base&0xFF000000 | (base>>16)&0xFF | limit&0x000F0000 | uint32(flags) | uint32(dpl)<<13 -} - -func (d *SegmentDescriptor) setCode32(base, limit uint32, dpl int) { - d.set(base, limit, dpl, - SegmentDescriptorDB| - SegmentDescriptorExecute| - SegmentDescriptorSystem) -} - -func (d *SegmentDescriptor) setCode64(base, limit uint32, dpl int) { - d.set(base, limit, dpl, - SegmentDescriptorG| - SegmentDescriptorLong| - SegmentDescriptorExecute| - SegmentDescriptorSystem) -} - -func (d *SegmentDescriptor) setData(base, limit uint32, dpl int) { - d.set(base, limit, dpl, - SegmentDescriptorWrite| - SegmentDescriptorSystem) -} - -// setHi is only used for the TSS segment, which is magically 64-bits. -func (d *SegmentDescriptor) setHi(base uint32) { - d.bits[0] = base - d.bits[1] = 0 -} - -// Gate64 is a 64-bit task, trap, or interrupt gate. -type Gate64 struct { - bits [4]uint32 -} - -// idt64 is a 64-bit interrupt descriptor table. -type idt64 [_NR_INTERRUPTS]Gate64 - -func (g *Gate64) setInterrupt(cs Selector, rip uint64, dpl int, ist int) { - g.bits[0] = uint32(cs)<<16 | uint32(rip)&0xFFFF - g.bits[1] = uint32(rip)&0xFFFF0000 | SegmentDescriptorPresent | uint32(dpl)<<13 | 14<<8 | uint32(ist)&0x7 - g.bits[2] = uint32(rip >> 32) -} - -func (g *Gate64) setTrap(cs Selector, rip uint64, dpl int, ist int) { - g.setInterrupt(cs, rip, dpl, ist) - g.bits[1] |= 1 << 8 -} - -// TaskState64 is a 64-bit task state structure. -type TaskState64 struct { - _ uint32 - rsp0Lo, rsp0Hi uint32 - rsp1Lo, rsp1Hi uint32 - rsp2Lo, rsp2Hi uint32 - _ [2]uint32 - ist1Lo, ist1Hi uint32 - ist2Lo, ist2Hi uint32 - ist3Lo, ist3Hi uint32 - ist4Lo, ist4Hi uint32 - ist5Lo, ist5Hi uint32 - ist6Lo, ist6Hi uint32 - ist7Lo, ist7Hi uint32 - _ [2]uint32 - _ uint16 - ioPerm uint16 -} diff --git a/pkg/safecopy/BUILD b/pkg/safecopy/BUILD deleted file mode 100644 index 0a045fc8e..000000000 --- a/pkg/safecopy/BUILD +++ /dev/null @@ -1,35 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "safecopy", - srcs = [ - "atomic_amd64.s", - "atomic_arm64.s", - "memclr_amd64.s", - "memclr_arm64.s", - "memcpy_amd64.s", - "memcpy_arm64.s", - "safecopy.go", - "safecopy_unsafe.go", - "sighandler_amd64.s", - "sighandler_arm64.s", - ], - visibility = ["//:sandbox"], - deps = [ - "//pkg/abi/linux", - "//pkg/errors", - "//pkg/errors/linuxerr", - "@org_golang_x_sys//unix:go_default_library", - ], -) - -go_test( - name = "safecopy_test", - srcs = [ - "safecopy_test.go", - ], - library = ":safecopy", - deps = ["@org_golang_x_sys//unix:go_default_library"], -) diff --git a/pkg/safecopy/LICENSE b/pkg/safecopy/LICENSE deleted file mode 100644 index 6a66aea5e..000000000 --- a/pkg/safecopy/LICENSE +++ /dev/null @@ -1,27 +0,0 @@ -Copyright (c) 2009 The Go Authors. All rights reserved. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are -met: - - * Redistributions of source code must retain the above copyright -notice, this list of conditions and the following disclaimer. - * Redistributions in binary form must reproduce the above -copyright notice, this list of conditions and the following disclaimer -in the documentation and/or other materials provided with the -distribution. - * Neither the name of Google Inc. nor the names of its -contributors may be used to endorse or promote products derived from -this software without specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/pkg/safecopy/safecopy_state_autogen.go b/pkg/safecopy/safecopy_state_autogen.go new file mode 100644 index 000000000..791eef959 --- /dev/null +++ b/pkg/safecopy/safecopy_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package safecopy diff --git a/pkg/safecopy/safecopy_test.go b/pkg/safecopy/safecopy_test.go deleted file mode 100644 index 55743e69c..000000000 --- a/pkg/safecopy/safecopy_test.go +++ /dev/null @@ -1,568 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package safecopy - -import ( - "bytes" - "fmt" - "io/ioutil" - "math/rand" - "testing" - "unsafe" - - "golang.org/x/sys/unix" -) - -// Size of a page in bytes. Cloned from hostarch.PageSize to avoid a circular -// dependency. -const pageSize = 4096 - -func initRandom(b []byte) { - for i := range b { - b[i] = byte(rand.Intn(256)) - } -} - -func randBuf(size int) []byte { - b := make([]byte, size) - initRandom(b) - return b -} - -func TestCopyInSuccess(t *testing.T) { - // Test that CopyIn does not return an error when all pages are accessible. - const bufLen = 8192 - a := randBuf(bufLen) - b := make([]byte, bufLen) - - n, err := CopyIn(b, unsafe.Pointer(&a[0])) - if n != bufLen { - t.Errorf("Unexpected copy length, got %v, want %v", n, bufLen) - } - if err != nil { - t.Errorf("Unexpected error: %v", err) - } - if !bytes.Equal(a, b) { - t.Errorf("Buffers are not equal when they should be: %v %v", a, b) - } -} - -func TestCopyOutSuccess(t *testing.T) { - // Test that CopyOut does not return an error when all pages are - // accessible. - const bufLen = 8192 - a := randBuf(bufLen) - b := make([]byte, bufLen) - - n, err := CopyOut(unsafe.Pointer(&b[0]), a) - if n != bufLen { - t.Errorf("Unexpected copy length, got %v, want %v", n, bufLen) - } - if err != nil { - t.Errorf("Unexpected error: %v", err) - } - if !bytes.Equal(a, b) { - t.Errorf("Buffers are not equal when they should be: %v %v", a, b) - } -} - -func TestCopySuccess(t *testing.T) { - // Test that Copy does not return an error when all pages are accessible. - const bufLen = 8192 - a := randBuf(bufLen) - b := make([]byte, bufLen) - - n, err := Copy(unsafe.Pointer(&b[0]), unsafe.Pointer(&a[0]), bufLen) - if n != bufLen { - t.Errorf("Unexpected copy length, got %v, want %v", n, bufLen) - } - if err != nil { - t.Errorf("Unexpected error: %v", err) - } - if !bytes.Equal(a, b) { - t.Errorf("Buffers are not equal when they should be: %v %v", a, b) - } -} - -func TestZeroOutSuccess(t *testing.T) { - // Test that ZeroOut does not return an error when all pages are - // accessible. - const bufLen = 8192 - a := make([]byte, bufLen) - b := randBuf(bufLen) - - n, err := ZeroOut(unsafe.Pointer(&b[0]), bufLen) - if n != bufLen { - t.Errorf("Unexpected copy length, got %v, want %v", n, bufLen) - } - if err != nil { - t.Errorf("Unexpected error: %v", err) - } - if !bytes.Equal(a, b) { - t.Errorf("Buffers are not equal when they should be: %v %v", a, b) - } -} - -func TestSwapUint32Success(t *testing.T) { - // Test that SwapUint32 does not return an error when the page is - // accessible. - before := uint32(rand.Int31()) - after := uint32(rand.Int31()) - val := before - - old, err := SwapUint32(unsafe.Pointer(&val), after) - if err != nil { - t.Errorf("Unexpected error: %v", err) - } - if old != before { - t.Errorf("Unexpected old value: got %v, want %v", old, before) - } - if val != after { - t.Errorf("Unexpected new value: got %v, want %v", val, after) - } -} - -func TestSwapUint32AlignmentError(t *testing.T) { - // Test that SwapUint32 returns an AlignmentError when passed an unaligned - // address. - data := make([]byte, 8) // 2 * sizeof(uint32). - alignedIndex := uintptr(0) - if offset := uintptr(unsafe.Pointer(&data[0])) % 4; offset != 0 { - alignedIndex = 4 - offset - } - ptr := unsafe.Pointer(&data[alignedIndex+1]) - want := AlignmentError{Addr: uintptr(ptr), Alignment: 4} - if _, err := SwapUint32(ptr, 1); err != want { - t.Errorf("Unexpected error: got %v, want %v", err, want) - } -} - -func TestSwapUint64Success(t *testing.T) { - // Test that SwapUint64 does not return an error when the page is - // accessible. - before := uint64(rand.Int63()) - after := uint64(rand.Int63()) - // "The first word in ... an allocated struct or slice can be relied upon - // to be 64-bit aligned." - sync/atomic docs - data := new(struct{ val uint64 }) - data.val = before - - old, err := SwapUint64(unsafe.Pointer(&data.val), after) - if err != nil { - t.Errorf("Unexpected error: %v", err) - } - if old != before { - t.Errorf("Unexpected old value: got %v, want %v", old, before) - } - if data.val != after { - t.Errorf("Unexpected new value: got %v, want %v", data.val, after) - } -} - -func TestSwapUint64AlignmentError(t *testing.T) { - // Test that SwapUint64 returns an AlignmentError when passed an unaligned - // address. - data := make([]byte, 16) // 2 * sizeof(uint64). - alignedIndex := uintptr(0) - if offset := uintptr(unsafe.Pointer(&data[0])) % 8; offset != 0 { - alignedIndex = 8 - offset - } - ptr := unsafe.Pointer(&data[alignedIndex+1]) - want := AlignmentError{Addr: uintptr(ptr), Alignment: 8} - if _, err := SwapUint64(ptr, 1); err != want { - t.Errorf("Unexpected error: got %v, want %v", err, want) - } -} - -func TestCompareAndSwapUint32Success(t *testing.T) { - // Test that CompareAndSwapUint32 does not return an error when the page is - // accessible. - before := uint32(rand.Int31()) - after := uint32(rand.Int31()) - val := before - - old, err := CompareAndSwapUint32(unsafe.Pointer(&val), before, after) - if err != nil { - t.Errorf("Unexpected error: %v", err) - } - if old != before { - t.Errorf("Unexpected old value: got %v, want %v", old, before) - } - if val != after { - t.Errorf("Unexpected new value: got %v, want %v", val, after) - } -} - -func TestCompareAndSwapUint32AlignmentError(t *testing.T) { - // Test that CompareAndSwapUint32 returns an AlignmentError when passed an - // unaligned address. - data := make([]byte, 8) // 2 * sizeof(uint32). - alignedIndex := uintptr(0) - if offset := uintptr(unsafe.Pointer(&data[0])) % 4; offset != 0 { - alignedIndex = 4 - offset - } - ptr := unsafe.Pointer(&data[alignedIndex+1]) - want := AlignmentError{Addr: uintptr(ptr), Alignment: 4} - if _, err := CompareAndSwapUint32(ptr, 0, 1); err != want { - t.Errorf("Unexpected error: got %v, want %v", err, want) - } -} - -// withSegvErrorTestMapping calls fn with a two-page mapping. The first page -// contains random data, and the second page generates SIGSEGV when accessed. -func withSegvErrorTestMapping(t *testing.T, fn func(m []byte)) { - mapping, err := unix.Mmap(-1, 0, 2*pageSize, unix.PROT_READ|unix.PROT_WRITE, unix.MAP_ANONYMOUS|unix.MAP_PRIVATE) - if err != nil { - t.Fatalf("Mmap failed: %v", err) - } - defer unix.Munmap(mapping) - if err := unix.Mprotect(mapping[pageSize:], unix.PROT_NONE); err != nil { - t.Fatalf("Mprotect failed: %v", err) - } - initRandom(mapping[:pageSize]) - - fn(mapping) -} - -// withBusErrorTestMapping calls fn with a two-page mapping. The first page -// contains random data, and the second page generates SIGBUS when accessed. -func withBusErrorTestMapping(t *testing.T, fn func(m []byte)) { - f, err := ioutil.TempFile("", "sigbus_test") - if err != nil { - t.Fatalf("TempFile failed: %v", err) - } - defer f.Close() - if err := f.Truncate(pageSize); err != nil { - t.Fatalf("Truncate failed: %v", err) - } - mapping, err := unix.Mmap(int(f.Fd()), 0, 2*pageSize, unix.PROT_READ|unix.PROT_WRITE, unix.MAP_SHARED) - if err != nil { - t.Fatalf("Mmap failed: %v", err) - } - defer unix.Munmap(mapping) - initRandom(mapping[:pageSize]) - - fn(mapping) -} - -func TestCopyInSegvError(t *testing.T) { - // Test that CopyIn returns a SegvError when reaching a page that signals - // SIGSEGV. - for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ { - t.Run(fmt.Sprintf("starting copy %d bytes before SIGSEGV", bytesBeforeFault), func(t *testing.T) { - withSegvErrorTestMapping(t, func(mapping []byte) { - secondPage := uintptr(unsafe.Pointer(&mapping[pageSize])) - src := unsafe.Pointer(&mapping[pageSize-bytesBeforeFault]) - dst := randBuf(pageSize) - n, err := CopyIn(dst, src) - if n != bytesBeforeFault { - t.Errorf("Unexpected copy length: got %v, want %v", n, bytesBeforeFault) - } - if want := (SegvError{secondPage}); err != want { - t.Errorf("Unexpected error: got %v, want %v", err, want) - } - if got, want := dst[:bytesBeforeFault], mapping[pageSize-bytesBeforeFault:pageSize]; !bytes.Equal(got, want) { - t.Errorf("Buffers are not equal when they should be: %v %v", got, want) - } - }) - }) - } -} - -func TestCopyInBusError(t *testing.T) { - // Test that CopyIn returns a BusError when reaching a page that signals - // SIGBUS. - for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ { - t.Run(fmt.Sprintf("starting copy %d bytes before SIGBUS", bytesBeforeFault), func(t *testing.T) { - withBusErrorTestMapping(t, func(mapping []byte) { - secondPage := uintptr(unsafe.Pointer(&mapping[pageSize])) - src := unsafe.Pointer(&mapping[pageSize-bytesBeforeFault]) - dst := randBuf(pageSize) - n, err := CopyIn(dst, src) - if n != bytesBeforeFault { - t.Errorf("Unexpected copy length: got %v, want %v", n, bytesBeforeFault) - } - if want := (BusError{secondPage}); err != want { - t.Errorf("Unexpected error: got %v, want %v", err, want) - } - if got, want := dst[:bytesBeforeFault], mapping[pageSize-bytesBeforeFault:pageSize]; !bytes.Equal(got, want) { - t.Errorf("Buffers are not equal when they should be: %v %v", got, want) - } - }) - }) - } -} - -func TestCopyOutSegvError(t *testing.T) { - // Test that CopyOut returns a SegvError when reaching a page that signals - // SIGSEGV. - for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ { - t.Run(fmt.Sprintf("starting copy %d bytes before SIGSEGV", bytesBeforeFault), func(t *testing.T) { - withSegvErrorTestMapping(t, func(mapping []byte) { - secondPage := uintptr(unsafe.Pointer(&mapping[pageSize])) - dst := unsafe.Pointer(&mapping[pageSize-bytesBeforeFault]) - src := randBuf(pageSize) - n, err := CopyOut(dst, src) - if n != bytesBeforeFault { - t.Errorf("Unexpected copy length: got %v, want %v", n, bytesBeforeFault) - } - if want := (SegvError{secondPage}); err != want { - t.Errorf("Unexpected error: got %v, want %v", err, want) - } - if got, want := mapping[pageSize-bytesBeforeFault:pageSize], src[:bytesBeforeFault]; !bytes.Equal(got, want) { - t.Errorf("Buffers are not equal when they should be: %v %v", got, want) - } - }) - }) - } -} - -func TestCopyOutBusError(t *testing.T) { - // Test that CopyOut returns a BusError when reaching a page that signals - // SIGBUS. - for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ { - t.Run(fmt.Sprintf("starting copy %d bytes before SIGSEGV", bytesBeforeFault), func(t *testing.T) { - withBusErrorTestMapping(t, func(mapping []byte) { - secondPage := uintptr(unsafe.Pointer(&mapping[pageSize])) - dst := unsafe.Pointer(&mapping[pageSize-bytesBeforeFault]) - src := randBuf(pageSize) - n, err := CopyOut(dst, src) - if n != bytesBeforeFault { - t.Errorf("Unexpected copy length: got %v, want %v", n, bytesBeforeFault) - } - if want := (BusError{secondPage}); err != want { - t.Errorf("Unexpected error: got %v, want %v", err, want) - } - if got, want := mapping[pageSize-bytesBeforeFault:pageSize], src[:bytesBeforeFault]; !bytes.Equal(got, want) { - t.Errorf("Buffers are not equal when they should be: %v %v", got, want) - } - }) - }) - } -} - -func TestCopySourceSegvError(t *testing.T) { - // Test that Copy returns a SegvError when copying from a page that signals - // SIGSEGV. - for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ { - t.Run(fmt.Sprintf("starting copy %d bytes before SIGSEGV", bytesBeforeFault), func(t *testing.T) { - withSegvErrorTestMapping(t, func(mapping []byte) { - secondPage := uintptr(unsafe.Pointer(&mapping[pageSize])) - src := unsafe.Pointer(&mapping[pageSize-bytesBeforeFault]) - dst := randBuf(pageSize) - n, err := Copy(unsafe.Pointer(&dst[0]), src, pageSize) - if n != uintptr(bytesBeforeFault) { - t.Errorf("Unexpected copy length: got %v, want %v", n, bytesBeforeFault) - } - if want := (SegvError{secondPage}); err != want { - t.Errorf("Unexpected error: got %v, want %v", err, want) - } - if got, want := dst[:bytesBeforeFault], mapping[pageSize-bytesBeforeFault:pageSize]; !bytes.Equal(got, want) { - t.Errorf("Buffers are not equal when they should be: %v %v", got, want) - } - }) - }) - } -} - -func TestCopySourceBusError(t *testing.T) { - // Test that Copy returns a BusError when copying from a page that signals - // SIGBUS. - for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ { - t.Run(fmt.Sprintf("starting copy %d bytes before SIGBUS", bytesBeforeFault), func(t *testing.T) { - withBusErrorTestMapping(t, func(mapping []byte) { - secondPage := uintptr(unsafe.Pointer(&mapping[pageSize])) - src := unsafe.Pointer(&mapping[pageSize-bytesBeforeFault]) - dst := randBuf(pageSize) - n, err := Copy(unsafe.Pointer(&dst[0]), src, pageSize) - if n != uintptr(bytesBeforeFault) { - t.Errorf("Unexpected copy length: got %v, want %v", n, bytesBeforeFault) - } - if want := (BusError{secondPage}); err != want { - t.Errorf("Unexpected error: got %v, want %v", err, want) - } - if got, want := dst[:bytesBeforeFault], mapping[pageSize-bytesBeforeFault:pageSize]; !bytes.Equal(got, want) { - t.Errorf("Buffers are not equal when they should be: %v %v", got, want) - } - }) - }) - } -} - -func TestCopyDestinationSegvError(t *testing.T) { - // Test that Copy returns a SegvError when copying to a page that signals - // SIGSEGV. - for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ { - t.Run(fmt.Sprintf("starting copy %d bytes before SIGSEGV", bytesBeforeFault), func(t *testing.T) { - withSegvErrorTestMapping(t, func(mapping []byte) { - secondPage := uintptr(unsafe.Pointer(&mapping[pageSize])) - dst := unsafe.Pointer(&mapping[pageSize-bytesBeforeFault]) - src := randBuf(pageSize) - n, err := Copy(dst, unsafe.Pointer(&src[0]), pageSize) - if n != uintptr(bytesBeforeFault) { - t.Errorf("Unexpected copy length: got %v, want %v", n, bytesBeforeFault) - } - if want := (SegvError{secondPage}); err != want { - t.Errorf("Unexpected error: got %v, want %v", err, want) - } - if got, want := mapping[pageSize-bytesBeforeFault:pageSize], src[:bytesBeforeFault]; !bytes.Equal(got, want) { - t.Errorf("Buffers are not equal when they should be: %v %v", got, want) - } - }) - }) - } -} - -func TestCopyDestinationBusError(t *testing.T) { - // Test that Copy returns a BusError when copying to a page that signals - // SIGBUS. - for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ { - t.Run(fmt.Sprintf("starting copy %d bytes before SIGBUS", bytesBeforeFault), func(t *testing.T) { - withBusErrorTestMapping(t, func(mapping []byte) { - secondPage := uintptr(unsafe.Pointer(&mapping[pageSize])) - dst := unsafe.Pointer(&mapping[pageSize-bytesBeforeFault]) - src := randBuf(pageSize) - n, err := Copy(dst, unsafe.Pointer(&src[0]), pageSize) - if n != uintptr(bytesBeforeFault) { - t.Errorf("Unexpected copy length: got %v, want %v", n, bytesBeforeFault) - } - if want := (BusError{secondPage}); err != want { - t.Errorf("Unexpected error: got %v, want %v", err, want) - } - if got, want := mapping[pageSize-bytesBeforeFault:pageSize], src[:bytesBeforeFault]; !bytes.Equal(got, want) { - t.Errorf("Buffers are not equal when they should be: %v %v", got, want) - } - }) - }) - } -} - -func TestZeroOutSegvError(t *testing.T) { - // Test that ZeroOut returns a SegvError when reaching a page that signals - // SIGSEGV. - for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ { - t.Run(fmt.Sprintf("starting write %d bytes before SIGSEGV", bytesBeforeFault), func(t *testing.T) { - withSegvErrorTestMapping(t, func(mapping []byte) { - secondPage := uintptr(unsafe.Pointer(&mapping[pageSize])) - dst := unsafe.Pointer(&mapping[pageSize-bytesBeforeFault]) - n, err := ZeroOut(dst, pageSize) - if n != uintptr(bytesBeforeFault) { - t.Errorf("Unexpected write length: got %v, want %v", n, bytesBeforeFault) - } - if want := (SegvError{secondPage}); err != want { - t.Errorf("Unexpected error: got %v, want %v", err, want) - } - if got, want := mapping[pageSize-bytesBeforeFault:pageSize], make([]byte, bytesBeforeFault); !bytes.Equal(got, want) { - t.Errorf("Non-zero bytes in written part of mapping: %v", got) - } - }) - }) - } -} - -func TestZeroOutBusError(t *testing.T) { - // Test that ZeroOut returns a BusError when reaching a page that signals - // SIGBUS. - for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ { - t.Run(fmt.Sprintf("starting write %d bytes before SIGBUS", bytesBeforeFault), func(t *testing.T) { - withBusErrorTestMapping(t, func(mapping []byte) { - secondPage := uintptr(unsafe.Pointer(&mapping[pageSize])) - dst := unsafe.Pointer(&mapping[pageSize-bytesBeforeFault]) - n, err := ZeroOut(dst, pageSize) - if n != uintptr(bytesBeforeFault) { - t.Errorf("Unexpected write length: got %v, want %v", n, bytesBeforeFault) - } - if want := (BusError{secondPage}); err != want { - t.Errorf("Unexpected error: got %v, want %v", err, want) - } - if got, want := mapping[pageSize-bytesBeforeFault:pageSize], make([]byte, bytesBeforeFault); !bytes.Equal(got, want) { - t.Errorf("Non-zero bytes in written part of mapping: %v", got) - } - }) - }) - } -} - -func TestSwapUint32SegvError(t *testing.T) { - // Test that SwapUint32 returns a SegvError when reaching a page that - // signals SIGSEGV. - withSegvErrorTestMapping(t, func(mapping []byte) { - secondPage := uintptr(unsafe.Pointer(&mapping[pageSize])) - _, err := SwapUint32(unsafe.Pointer(secondPage), 1) - if want := (SegvError{secondPage}); err != want { - t.Errorf("Unexpected error: got %v, want %v", err, want) - } - }) -} - -func TestSwapUint32BusError(t *testing.T) { - // Test that SwapUint32 returns a BusError when reaching a page that - // signals SIGBUS. - withBusErrorTestMapping(t, func(mapping []byte) { - secondPage := uintptr(unsafe.Pointer(&mapping[pageSize])) - _, err := SwapUint32(unsafe.Pointer(secondPage), 1) - if want := (BusError{secondPage}); err != want { - t.Errorf("Unexpected error: got %v, want %v", err, want) - } - }) -} - -func TestSwapUint64SegvError(t *testing.T) { - // Test that SwapUint64 returns a SegvError when reaching a page that - // signals SIGSEGV. - withSegvErrorTestMapping(t, func(mapping []byte) { - secondPage := uintptr(unsafe.Pointer(&mapping[pageSize])) - _, err := SwapUint64(unsafe.Pointer(secondPage), 1) - if want := (SegvError{secondPage}); err != want { - t.Errorf("Unexpected error: got %v, want %v", err, want) - } - }) -} - -func TestSwapUint64BusError(t *testing.T) { - // Test that SwapUint64 returns a BusError when reaching a page that - // signals SIGBUS. - withBusErrorTestMapping(t, func(mapping []byte) { - secondPage := uintptr(unsafe.Pointer(&mapping[pageSize])) - _, err := SwapUint64(unsafe.Pointer(secondPage), 1) - if want := (BusError{secondPage}); err != want { - t.Errorf("Unexpected error: got %v, want %v", err, want) - } - }) -} - -func TestCompareAndSwapUint32SegvError(t *testing.T) { - // Test that CompareAndSwapUint32 returns a SegvError when reaching a page - // that signals SIGSEGV. - withSegvErrorTestMapping(t, func(mapping []byte) { - secondPage := uintptr(unsafe.Pointer(&mapping[pageSize])) - _, err := CompareAndSwapUint32(unsafe.Pointer(secondPage), 0, 1) - if want := (SegvError{secondPage}); err != want { - t.Errorf("Unexpected error: got %v, want %v", err, want) - } - }) -} - -func TestCompareAndSwapUint32BusError(t *testing.T) { - // Test that CompareAndSwapUint32 returns a BusError when reaching a page - // that signals SIGBUS. - withBusErrorTestMapping(t, func(mapping []byte) { - secondPage := uintptr(unsafe.Pointer(&mapping[pageSize])) - _, err := CompareAndSwapUint32(unsafe.Pointer(secondPage), 0, 1) - if want := (BusError{secondPage}); err != want { - t.Errorf("Unexpected error: got %v, want %v", err, want) - } - }) -} diff --git a/pkg/safecopy/safecopy_unsafe_state_autogen.go b/pkg/safecopy/safecopy_unsafe_state_autogen.go new file mode 100644 index 000000000..791eef959 --- /dev/null +++ b/pkg/safecopy/safecopy_unsafe_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package safecopy diff --git a/pkg/safemem/BUILD b/pkg/safemem/BUILD deleted file mode 100644 index 2c7cc8769..000000000 --- a/pkg/safemem/BUILD +++ /dev/null @@ -1,30 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "safemem", - srcs = [ - "block_unsafe.go", - "io.go", - "safemem.go", - "seq_unsafe.go", - ], - visibility = ["//:sandbox"], - deps = [ - "//pkg/gohacks", - "//pkg/safecopy", - "//pkg/sync", - "@org_golang_x_sys//unix:go_default_library", - ], -) - -go_test( - name = "safemem_test", - size = "small", - srcs = [ - "io_test.go", - "seq_test.go", - ], - library = ":safemem", -) diff --git a/pkg/safemem/io_test.go b/pkg/safemem/io_test.go deleted file mode 100644 index 629741bee..000000000 --- a/pkg/safemem/io_test.go +++ /dev/null @@ -1,199 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package safemem - -import ( - "bytes" - "io" - "testing" -) - -func makeBlocks(slices ...[]byte) []Block { - blocks := make([]Block, 0, len(slices)) - for _, s := range slices { - blocks = append(blocks, BlockFromSafeSlice(s)) - } - return blocks -} - -func TestFromIOReaderFullRead(t *testing.T) { - r := FromIOReader{bytes.NewBufferString("foobar")} - dsts := makeBlocks(make([]byte, 3), make([]byte, 3)) - n, err := r.ReadToBlocks(BlockSeqFromSlice(dsts)) - if wantN := uint64(6); n != wantN || err != nil { - t.Errorf("ReadToBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN) - } - for i, want := range [][]byte{[]byte("foo"), []byte("bar")} { - if got := dsts[i].ToSlice(); !bytes.Equal(got, want) { - t.Errorf("dsts[%d]: got %q, wanted %q", i, got, want) - } - } -} - -type eofHidingReader struct { - Reader io.Reader -} - -func (r eofHidingReader) Read(dst []byte) (int, error) { - n, err := r.Reader.Read(dst) - if err == io.EOF { - return n, nil - } - return n, err -} - -func TestFromIOReaderPartialRead(t *testing.T) { - r := FromIOReader{eofHidingReader{bytes.NewBufferString("foob")}} - dsts := makeBlocks(make([]byte, 3), make([]byte, 3)) - n, err := r.ReadToBlocks(BlockSeqFromSlice(dsts)) - // FromIOReader should stop after the eofHidingReader returns (1, nil) - // for a 3-byte read. - if wantN := uint64(4); n != wantN || err != nil { - t.Errorf("ReadToBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN) - } - for i, want := range [][]byte{[]byte("foo"), []byte("b\x00\x00")} { - if got := dsts[i].ToSlice(); !bytes.Equal(got, want) { - t.Errorf("dsts[%d]: got %q, wanted %q", i, got, want) - } - } -} - -type singleByteReader struct { - Reader io.Reader -} - -func (r singleByteReader) Read(dst []byte) (int, error) { - if len(dst) == 0 { - return r.Reader.Read(dst) - } - return r.Reader.Read(dst[:1]) -} - -func TestSingleByteReader(t *testing.T) { - r := FromIOReader{singleByteReader{bytes.NewBufferString("foobar")}} - dsts := makeBlocks(make([]byte, 3), make([]byte, 3)) - n, err := r.ReadToBlocks(BlockSeqFromSlice(dsts)) - // FromIOReader should stop after the singleByteReader returns (1, nil) - // for a 3-byte read. - if wantN := uint64(1); n != wantN || err != nil { - t.Errorf("ReadToBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN) - } - for i, want := range [][]byte{[]byte("f\x00\x00"), []byte("\x00\x00\x00")} { - if got := dsts[i].ToSlice(); !bytes.Equal(got, want) { - t.Errorf("dsts[%d]: got %q, wanted %q", i, got, want) - } - } -} - -func TestReadFullToBlocks(t *testing.T) { - r := FromIOReader{singleByteReader{bytes.NewBufferString("foobar")}} - dsts := makeBlocks(make([]byte, 3), make([]byte, 3)) - n, err := ReadFullToBlocks(r, BlockSeqFromSlice(dsts)) - // ReadFullToBlocks should call into FromIOReader => singleByteReader - // repeatedly until dsts is exhausted. - if wantN := uint64(6); n != wantN || err != nil { - t.Errorf("ReadFullToBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN) - } - for i, want := range [][]byte{[]byte("foo"), []byte("bar")} { - if got := dsts[i].ToSlice(); !bytes.Equal(got, want) { - t.Errorf("dsts[%d]: got %q, wanted %q", i, got, want) - } - } -} - -func TestFromIOWriterFullWrite(t *testing.T) { - srcs := makeBlocks([]byte("foo"), []byte("bar")) - var dst bytes.Buffer - w := FromIOWriter{&dst} - n, err := w.WriteFromBlocks(BlockSeqFromSlice(srcs)) - if wantN := uint64(6); n != wantN || err != nil { - t.Errorf("WriteFromBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN) - } - if got, want := dst.Bytes(), []byte("foobar"); !bytes.Equal(got, want) { - t.Errorf("dst: got %q, wanted %q", got, want) - } -} - -type limitedWriter struct { - Writer io.Writer - Done int - Limit int -} - -func (w *limitedWriter) Write(src []byte) (int, error) { - count := len(src) - if count > (w.Limit - w.Done) { - count = w.Limit - w.Done - } - n, err := w.Writer.Write(src[:count]) - w.Done += n - return n, err -} - -func TestFromIOWriterPartialWrite(t *testing.T) { - srcs := makeBlocks([]byte("foo"), []byte("bar")) - var dst bytes.Buffer - w := FromIOWriter{&limitedWriter{&dst, 0, 4}} - n, err := w.WriteFromBlocks(BlockSeqFromSlice(srcs)) - // FromIOWriter should stop after the limitedWriter returns (1, nil) for a - // 3-byte write. - if wantN := uint64(4); n != wantN || err != nil { - t.Errorf("WriteFromBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN) - } - if got, want := dst.Bytes(), []byte("foob"); !bytes.Equal(got, want) { - t.Errorf("dst: got %q, wanted %q", got, want) - } -} - -type singleByteWriter struct { - Writer io.Writer -} - -func (w singleByteWriter) Write(src []byte) (int, error) { - if len(src) == 0 { - return w.Writer.Write(src) - } - return w.Writer.Write(src[:1]) -} - -func TestSingleByteWriter(t *testing.T) { - srcs := makeBlocks([]byte("foo"), []byte("bar")) - var dst bytes.Buffer - w := FromIOWriter{singleByteWriter{&dst}} - n, err := w.WriteFromBlocks(BlockSeqFromSlice(srcs)) - // FromIOWriter should stop after the singleByteWriter returns (1, nil) - // for a 3-byte write. - if wantN := uint64(1); n != wantN || err != nil { - t.Errorf("WriteFromBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN) - } - if got, want := dst.Bytes(), []byte("f"); !bytes.Equal(got, want) { - t.Errorf("dst: got %q, wanted %q", got, want) - } -} - -func TestWriteFullToBlocks(t *testing.T) { - srcs := makeBlocks([]byte("foo"), []byte("bar")) - var dst bytes.Buffer - w := FromIOWriter{singleByteWriter{&dst}} - n, err := WriteFullFromBlocks(w, BlockSeqFromSlice(srcs)) - // WriteFullToBlocks should call into FromIOWriter => singleByteWriter - // repeatedly until srcs is exhausted. - if wantN := uint64(6); n != wantN || err != nil { - t.Errorf("WriteFullFromBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN) - } - if got, want := dst.Bytes(), []byte("foobar"); !bytes.Equal(got, want) { - t.Errorf("dst: got %q, wanted %q", got, want) - } -} diff --git a/pkg/safemem/safemem_state_autogen.go b/pkg/safemem/safemem_state_autogen.go new file mode 100644 index 000000000..66d53f22d --- /dev/null +++ b/pkg/safemem/safemem_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package safemem diff --git a/pkg/safemem/safemem_unsafe_state_autogen.go b/pkg/safemem/safemem_unsafe_state_autogen.go new file mode 100644 index 000000000..66d53f22d --- /dev/null +++ b/pkg/safemem/safemem_unsafe_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package safemem diff --git a/pkg/safemem/seq_test.go b/pkg/safemem/seq_test.go deleted file mode 100644 index de34005e9..000000000 --- a/pkg/safemem/seq_test.go +++ /dev/null @@ -1,217 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package safemem - -import ( - "bytes" - "reflect" - "testing" -) - -func TestBlockSeqOfEmptyBlock(t *testing.T) { - bs := BlockSeqOf(Block{}) - if !bs.IsEmpty() { - t.Errorf("BlockSeqOf(Block{}).IsEmpty(): got false, wanted true; BlockSeq is %v", bs) - } -} - -func TestBlockSeqOfNonemptyBlock(t *testing.T) { - b := BlockFromSafeSlice(make([]byte, 1)) - bs := BlockSeqOf(b) - if bs.IsEmpty() { - t.Fatalf("BlockSeqOf(non-empty Block).IsEmpty(): got true, wanted false; BlockSeq is %v", bs) - } - if head := bs.Head(); head != b { - t.Fatalf("BlockSeqOf(non-empty Block).Head(): got %v, wanted %v", head, b) - } - if tail := bs.Tail(); !tail.IsEmpty() { - t.Fatalf("BlockSeqOf(non-empty Block).Tail().IsEmpty(): got false, wanted true: tail is %v", tail) - } -} - -type blockSeqTest struct { - desc string - - pieces []string - haveOffset bool - offset uint64 - haveLimit bool - limit uint64 - - want string -} - -func (t blockSeqTest) NonEmptyByteSlices() [][]byte { - // t is a value, so we can mutate it freely. - slices := make([][]byte, 0, len(t.pieces)) - for _, str := range t.pieces { - if t.haveOffset { - strOff := t.offset - if strOff > uint64(len(str)) { - strOff = uint64(len(str)) - } - str = str[strOff:] - t.offset -= strOff - } - if t.haveLimit { - strLim := t.limit - if strLim > uint64(len(str)) { - strLim = uint64(len(str)) - } - str = str[:strLim] - t.limit -= strLim - } - if len(str) != 0 { - slices = append(slices, []byte(str)) - } - } - return slices -} - -func (t blockSeqTest) BlockSeq() BlockSeq { - blocks := make([]Block, 0, len(t.pieces)) - for _, str := range t.pieces { - blocks = append(blocks, BlockFromSafeSlice([]byte(str))) - } - bs := BlockSeqFromSlice(blocks) - if t.haveOffset { - bs = bs.DropFirst64(t.offset) - } - if t.haveLimit { - bs = bs.TakeFirst64(t.limit) - } - return bs -} - -var blockSeqTests = []blockSeqTest{ - { - desc: "Empty sequence", - }, - { - desc: "Sequence of length 1", - pieces: []string{"foobar"}, - want: "foobar", - }, - { - desc: "Sequence of length 2", - pieces: []string{"foo", "bar"}, - want: "foobar", - }, - { - desc: "Empty Blocks", - pieces: []string{"", "foo", "", "", "bar", ""}, - want: "foobar", - }, - { - desc: "Sequence with non-zero offset", - pieces: []string{"foo", "bar"}, - haveOffset: true, - offset: 2, - want: "obar", - }, - { - desc: "Sequence with non-maximal limit", - pieces: []string{"foo", "bar"}, - haveLimit: true, - limit: 5, - want: "fooba", - }, - { - desc: "Sequence with offset and limit", - pieces: []string{"foo", "bar"}, - haveOffset: true, - offset: 2, - haveLimit: true, - limit: 3, - want: "oba", - }, -} - -func TestBlockSeqNumBytes(t *testing.T) { - for _, test := range blockSeqTests { - t.Run(test.desc, func(t *testing.T) { - if got, want := test.BlockSeq().NumBytes(), uint64(len(test.want)); got != want { - t.Errorf("NumBytes: got %d, wanted %d", got, want) - } - }) - } -} - -func TestBlockSeqIterBlocks(t *testing.T) { - // Tests BlockSeq iteration using Head/Tail. - for _, test := range blockSeqTests { - t.Run(test.desc, func(t *testing.T) { - srcs := test.BlockSeq() - // "Note that a non-nil empty slice and a nil slice ... are not - // deeply equal." - reflect - slices := make([][]byte, 0, 0) - for !srcs.IsEmpty() { - src := srcs.Head() - slices = append(slices, src.ToSlice()) - nextSrcs := srcs.Tail() - if got, want := nextSrcs.NumBytes(), srcs.NumBytes()-uint64(src.Len()); got != want { - t.Fatalf("%v.Tail(): got %v (%d bytes), wanted %d bytes", srcs, nextSrcs, got, want) - } - srcs = nextSrcs - } - if wantSlices := test.NonEmptyByteSlices(); !reflect.DeepEqual(slices, wantSlices) { - t.Errorf("Accumulated slices: got %v, wanted %v", slices, wantSlices) - } - }) - } -} - -func TestBlockSeqIterBytes(t *testing.T) { - // Tests BlockSeq iteration using Head/DropFirst. - for _, test := range blockSeqTests { - t.Run(test.desc, func(t *testing.T) { - srcs := test.BlockSeq() - var dst bytes.Buffer - for !srcs.IsEmpty() { - src := srcs.Head() - var b [1]byte - n, err := Copy(BlockFromSafeSlice(b[:]), src) - if n != 1 || err != nil { - t.Fatalf("Copy: got (%v, %v), wanted (1, nil)", n, err) - } - dst.WriteByte(b[0]) - nextSrcs := srcs.DropFirst(1) - if got, want := nextSrcs.NumBytes(), srcs.NumBytes()-1; got != want { - t.Fatalf("%v.DropFirst(1): got %v (%d bytes), wanted %d bytes", srcs, nextSrcs, got, want) - } - srcs = nextSrcs - } - if got := string(dst.Bytes()); got != test.want { - t.Errorf("Copied string: got %q, wanted %q", got, test.want) - } - }) - } -} - -func TestBlockSeqDropBeyondLimit(t *testing.T) { - blocks := []Block{BlockFromSafeSlice([]byte("123")), BlockFromSafeSlice([]byte("4"))} - bs := BlockSeqFromSlice(blocks) - if got, want := bs.NumBytes(), uint64(4); got != want { - t.Errorf("%v.NumBytes(): got %d, wanted %d", bs, got, want) - } - bs = bs.TakeFirst(1) - if got, want := bs.NumBytes(), uint64(1); got != want { - t.Errorf("%v.NumBytes(): got %d, wanted %d", bs, got, want) - } - bs = bs.DropFirst(2) - if got, want := bs.NumBytes(), uint64(0); got != want { - t.Errorf("%v.NumBytes(): got %d, wanted %d", bs, got, want) - } -} diff --git a/pkg/seccomp/BUILD b/pkg/seccomp/BUILD deleted file mode 100644 index 169fc0ac3..000000000 --- a/pkg/seccomp/BUILD +++ /dev/null @@ -1,59 +0,0 @@ -load("//tools:defs.bzl", "go_binary", "go_embed_data", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_binary( - name = "victim", - testonly = 1, - srcs = [ - "seccomp_test_victim.go", - "seccomp_test_victim_amd64.go", - "seccomp_test_victim_arm64.go", - ], - nogo = False, - deps = [ - ":seccomp", - "@org_golang_x_sys//unix:go_default_library", - ], -) - -go_embed_data( - name = "victim_data", - testonly = 1, - src = "victim", - package = "seccomp", - var = "victimData", -) - -go_library( - name = "seccomp", - srcs = [ - "seccomp.go", - "seccomp_amd64.go", - "seccomp_arm64.go", - "seccomp_rules.go", - "seccomp_unsafe.go", - ], - visibility = ["//:sandbox"], - deps = [ - "//pkg/abi/linux", - "//pkg/bpf", - "//pkg/log", - "@org_golang_x_sys//unix:go_default_library", - ], -) - -go_test( - name = "seccomp_test", - size = "small", - srcs = [ - "seccomp_test.go", - ":victim_data", - ], - library = ":seccomp", - deps = [ - "//pkg/abi/linux", - "//pkg/bpf", - "//pkg/hostarch", - ], -) diff --git a/pkg/seccomp/seccomp_amd64_state_autogen.go b/pkg/seccomp/seccomp_amd64_state_autogen.go new file mode 100644 index 000000000..6c95fd6fb --- /dev/null +++ b/pkg/seccomp/seccomp_amd64_state_autogen.go @@ -0,0 +1,6 @@ +// automatically generated by stateify. + +//go:build amd64 +// +build amd64 + +package seccomp diff --git a/pkg/seccomp/seccomp_arm64_state_autogen.go b/pkg/seccomp/seccomp_arm64_state_autogen.go new file mode 100644 index 000000000..e6f1039cd --- /dev/null +++ b/pkg/seccomp/seccomp_arm64_state_autogen.go @@ -0,0 +1,6 @@ +// automatically generated by stateify. + +//go:build arm64 +// +build arm64 + +package seccomp diff --git a/pkg/seccomp/seccomp_state_autogen.go b/pkg/seccomp/seccomp_state_autogen.go new file mode 100644 index 000000000..e16b5d7c2 --- /dev/null +++ b/pkg/seccomp/seccomp_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package seccomp diff --git a/pkg/seccomp/seccomp_test.go b/pkg/seccomp/seccomp_test.go deleted file mode 100644 index 68feddf31..000000000 --- a/pkg/seccomp/seccomp_test.go +++ /dev/null @@ -1,1059 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package seccomp - -import ( - "bytes" - "fmt" - "io" - "io/ioutil" - "math" - "math/rand" - "os" - "os/exec" - "strings" - "testing" - "time" - - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/bpf" - "gvisor.dev/gvisor/pkg/hostarch" -) - -// newVictim makes a victim binary. -func newVictim() (string, error) { - f, err := ioutil.TempFile("", "victim") - if err != nil { - return "", err - } - defer f.Close() - path := f.Name() - if _, err := io.Copy(f, bytes.NewBuffer(victimData)); err != nil { - os.Remove(path) - return "", err - } - if err := os.Chmod(path, 0755); err != nil { - os.Remove(path) - return "", err - } - return path, nil -} - -// dataAsInput converts a linux.SeccompData to a bpf.Input. -func dataAsInput(d *linux.SeccompData) bpf.Input { - buf := make([]byte, d.SizeBytes()) - d.MarshalUnsafe(buf) - return bpf.InputBytes{ - Data: buf, - Order: hostarch.ByteOrder, - } -} - -func TestBasic(t *testing.T) { - type spec struct { - // desc is the test's description. - desc string - - // data is the input data. - data linux.SeccompData - - // want is the expected return value of the BPF program. - want linux.BPFAction - } - - for _, test := range []struct { - name string - ruleSets []RuleSet - defaultAction linux.BPFAction - badArchAction linux.BPFAction - specs []spec - }{ - { - name: "Single syscall", - ruleSets: []RuleSet{ - { - Rules: SyscallRules{1: {}}, - Action: linux.SECCOMP_RET_ALLOW, - }, - }, - defaultAction: linux.SECCOMP_RET_TRAP, - badArchAction: linux.SECCOMP_RET_KILL_THREAD, - specs: []spec{ - { - desc: "syscall allowed", - data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH}, - want: linux.SECCOMP_RET_ALLOW, - }, - { - desc: "syscall disallowed", - data: linux.SeccompData{Nr: 2, Arch: LINUX_AUDIT_ARCH}, - want: linux.SECCOMP_RET_TRAP, - }, - }, - }, - { - name: "Multiple rulesets", - ruleSets: []RuleSet{ - { - Rules: SyscallRules{ - 1: []Rule{ - { - EqualTo(0x1), - }, - }, - }, - Action: linux.SECCOMP_RET_ALLOW, - }, - { - Rules: SyscallRules{ - 1: {}, - 2: {}, - }, - Action: linux.SECCOMP_RET_TRAP, - }, - }, - defaultAction: linux.SECCOMP_RET_KILL_THREAD, - badArchAction: linux.SECCOMP_RET_KILL_THREAD, - specs: []spec{ - { - desc: "allowed (1a)", - data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x1}}, - want: linux.SECCOMP_RET_ALLOW, - }, - { - desc: "allowed (1b)", - data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH}, - want: linux.SECCOMP_RET_TRAP, - }, - { - desc: "syscall 1 matched 2nd rule", - data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH}, - want: linux.SECCOMP_RET_TRAP, - }, - { - desc: "no match", - data: linux.SeccompData{Nr: 0, Arch: LINUX_AUDIT_ARCH}, - want: linux.SECCOMP_RET_KILL_THREAD, - }, - }, - }, - { - name: "Multiple syscalls", - ruleSets: []RuleSet{ - { - Rules: SyscallRules{ - 1: {}, - 3: {}, - 5: {}, - }, - Action: linux.SECCOMP_RET_ALLOW, - }, - }, - defaultAction: linux.SECCOMP_RET_TRAP, - badArchAction: linux.SECCOMP_RET_KILL_THREAD, - specs: []spec{ - { - desc: "allowed (1)", - data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH}, - want: linux.SECCOMP_RET_ALLOW, - }, - { - desc: "allowed (3)", - data: linux.SeccompData{Nr: 3, Arch: LINUX_AUDIT_ARCH}, - want: linux.SECCOMP_RET_ALLOW, - }, - { - desc: "allowed (5)", - data: linux.SeccompData{Nr: 5, Arch: LINUX_AUDIT_ARCH}, - want: linux.SECCOMP_RET_ALLOW, - }, - { - desc: "disallowed (0)", - data: linux.SeccompData{Nr: 0, Arch: LINUX_AUDIT_ARCH}, - want: linux.SECCOMP_RET_TRAP, - }, - { - desc: "disallowed (2)", - data: linux.SeccompData{Nr: 2, Arch: LINUX_AUDIT_ARCH}, - want: linux.SECCOMP_RET_TRAP, - }, - { - desc: "disallowed (4)", - data: linux.SeccompData{Nr: 4, Arch: LINUX_AUDIT_ARCH}, - want: linux.SECCOMP_RET_TRAP, - }, - { - desc: "disallowed (6)", - data: linux.SeccompData{Nr: 6, Arch: LINUX_AUDIT_ARCH}, - want: linux.SECCOMP_RET_TRAP, - }, - { - desc: "disallowed (100)", - data: linux.SeccompData{Nr: 100, Arch: LINUX_AUDIT_ARCH}, - want: linux.SECCOMP_RET_TRAP, - }, - }, - }, - { - name: "Wrong architecture", - ruleSets: []RuleSet{ - { - Rules: SyscallRules{ - 1: {}, - }, - Action: linux.SECCOMP_RET_ALLOW, - }, - }, - defaultAction: linux.SECCOMP_RET_TRAP, - badArchAction: linux.SECCOMP_RET_KILL_THREAD, - specs: []spec{ - { - desc: "arch (123)", - data: linux.SeccompData{Nr: 1, Arch: 123}, - want: linux.SECCOMP_RET_KILL_THREAD, - }, - }, - }, - { - name: "Syscall disallowed", - ruleSets: []RuleSet{ - { - Rules: SyscallRules{ - 1: {}, - }, - Action: linux.SECCOMP_RET_ALLOW, - }, - }, - defaultAction: linux.SECCOMP_RET_TRAP, - badArchAction: linux.SECCOMP_RET_KILL_THREAD, - specs: []spec{ - { - desc: "action trap", - data: linux.SeccompData{Nr: 2, Arch: LINUX_AUDIT_ARCH}, - want: linux.SECCOMP_RET_TRAP, - }, - }, - }, - { - name: "Syscall arguments", - ruleSets: []RuleSet{ - { - Rules: SyscallRules{ - 1: []Rule{ - { - MatchAny{}, - EqualTo(0xf), - }, - }, - }, - Action: linux.SECCOMP_RET_ALLOW, - }, - }, - defaultAction: linux.SECCOMP_RET_TRAP, - badArchAction: linux.SECCOMP_RET_KILL_THREAD, - specs: []spec{ - { - desc: "allowed", - data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0xf, 0xf}}, - want: linux.SECCOMP_RET_ALLOW, - }, - { - desc: "disallowed", - data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0xf, 0xe}}, - want: linux.SECCOMP_RET_TRAP, - }, - }, - }, - { - name: "Multiple arguments", - ruleSets: []RuleSet{ - { - Rules: SyscallRules{ - 1: []Rule{ - { - EqualTo(0xf), - }, - { - EqualTo(0xe), - }, - }, - }, - Action: linux.SECCOMP_RET_ALLOW, - }, - }, - defaultAction: linux.SECCOMP_RET_TRAP, - badArchAction: linux.SECCOMP_RET_KILL_THREAD, - specs: []spec{ - { - desc: "match first rule", - data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0xf}}, - want: linux.SECCOMP_RET_ALLOW, - }, - { - desc: "match 2nd rule", - data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0xe}}, - want: linux.SECCOMP_RET_ALLOW, - }, - }, - }, - { - name: "EqualTo", - ruleSets: []RuleSet{ - { - Rules: SyscallRules{ - 1: []Rule{ - { - EqualTo(0), - EqualTo(math.MaxUint64 - 1), - EqualTo(math.MaxUint32), - }, - }, - }, - Action: linux.SECCOMP_RET_ALLOW, - }, - }, - defaultAction: linux.SECCOMP_RET_TRAP, - badArchAction: linux.SECCOMP_RET_KILL_THREAD, - specs: []spec{ - { - desc: "argument allowed (all match)", - data: linux.SeccompData{ - Nr: 1, - Arch: LINUX_AUDIT_ARCH, - Args: [6]uint64{0, math.MaxUint64 - 1, math.MaxUint32}, - }, - want: linux.SECCOMP_RET_ALLOW, - }, - { - desc: "argument disallowed (one mismatch)", - data: linux.SeccompData{ - Nr: 1, - Arch: LINUX_AUDIT_ARCH, - Args: [6]uint64{0, math.MaxUint64, math.MaxUint32}, - }, - want: linux.SECCOMP_RET_TRAP, - }, - { - desc: "argument disallowed (multiple mismatch)", - data: linux.SeccompData{ - Nr: 1, - Arch: LINUX_AUDIT_ARCH, - Args: [6]uint64{0, math.MaxUint64, math.MaxUint32 - 1}, - }, - want: linux.SECCOMP_RET_TRAP, - }, - }, - }, - { - name: "NotEqual", - ruleSets: []RuleSet{ - { - Rules: SyscallRules{ - 1: []Rule{ - { - NotEqual(0x7aabbccdd), - NotEqual(math.MaxUint64 - 1), - NotEqual(math.MaxUint32), - }, - }, - }, - Action: linux.SECCOMP_RET_ALLOW, - }, - }, - defaultAction: linux.SECCOMP_RET_TRAP, - badArchAction: linux.SECCOMP_RET_KILL_THREAD, - specs: []spec{ - { - desc: "arg allowed", - data: linux.SeccompData{ - Nr: 1, - Arch: LINUX_AUDIT_ARCH, - Args: [6]uint64{0, math.MaxUint64, math.MaxUint32 - 1}, - }, - want: linux.SECCOMP_RET_ALLOW, - }, - { - desc: "arg disallowed (one equal)", - data: linux.SeccompData{ - Nr: 1, - Arch: LINUX_AUDIT_ARCH, - Args: [6]uint64{0x7aabbccdd, math.MaxUint64, math.MaxUint32 - 1}, - }, - want: linux.SECCOMP_RET_TRAP, - }, - { - desc: "arg disallowed (all equal)", - data: linux.SeccompData{ - Nr: 1, - Arch: LINUX_AUDIT_ARCH, - Args: [6]uint64{0x7aabbccdd, math.MaxUint64 - 1, math.MaxUint32}, - }, - want: linux.SECCOMP_RET_TRAP, - }, - }, - }, - { - name: "GreaterThan", - ruleSets: []RuleSet{ - { - Rules: SyscallRules{ - 1: []Rule{ - { - // 4294967298 - // Both upper 32 bits and lower 32 bits are non-zero. - // 00000000000000000000000000000010 - // 00000000000000000000000000000010 - GreaterThan(0x00000002_00000002), - }, - }, - }, - Action: linux.SECCOMP_RET_ALLOW, - }, - }, - defaultAction: linux.SECCOMP_RET_TRAP, - badArchAction: linux.SECCOMP_RET_KILL_THREAD, - specs: []spec{ - { - desc: "high 32bits greater", - data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000003_00000002}}, - want: linux.SECCOMP_RET_ALLOW, - }, - { - desc: "high 32bits equal, low 32bits greater", - data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000002_00000003}}, - want: linux.SECCOMP_RET_ALLOW, - }, - { - desc: "high 32bits equal, low 32bits equal", - data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000002_00000002}}, - want: linux.SECCOMP_RET_TRAP, - }, - { - desc: "high 32bits equal, low 32bits less", - data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000002_00000001}}, - want: linux.SECCOMP_RET_TRAP, - }, - { - desc: "high 32bits less", - data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000001_00000003}}, - want: linux.SECCOMP_RET_TRAP, - }, - }, - }, - { - name: "GreaterThan (multi)", - ruleSets: []RuleSet{ - { - Rules: SyscallRules{ - 1: []Rule{ - { - GreaterThan(0xf), - GreaterThan(0xabcd000d), - }, - }, - }, - Action: linux.SECCOMP_RET_ALLOW, - }, - }, - defaultAction: linux.SECCOMP_RET_TRAP, - badArchAction: linux.SECCOMP_RET_KILL_THREAD, - specs: []spec{ - { - desc: "arg allowed", - data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x10, 0xffffffff}}, - want: linux.SECCOMP_RET_ALLOW, - }, - { - desc: "arg disallowed (first arg equal)", - data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0xf, 0xffffffff}}, - want: linux.SECCOMP_RET_TRAP, - }, - { - desc: "arg disallowed (first arg smaller)", - data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x0, 0xffffffff}}, - want: linux.SECCOMP_RET_TRAP, - }, - { - desc: "arg disallowed (second arg equal)", - data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x10, 0xabcd000d}}, - want: linux.SECCOMP_RET_TRAP, - }, - { - desc: "arg disallowed (second arg smaller)", - data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x10, 0xa000ffff}}, - want: linux.SECCOMP_RET_TRAP, - }, - }, - }, - { - name: "GreaterThanOrEqual", - ruleSets: []RuleSet{ - { - Rules: SyscallRules{ - 1: []Rule{ - { - // 4294967298 - // Both upper 32 bits and lower 32 bits are non-zero. - // 00000000000000000000000000000010 - // 00000000000000000000000000000010 - GreaterThanOrEqual(0x00000002_00000002), - }, - }, - }, - Action: linux.SECCOMP_RET_ALLOW, - }, - }, - defaultAction: linux.SECCOMP_RET_TRAP, - badArchAction: linux.SECCOMP_RET_KILL_THREAD, - specs: []spec{ - { - desc: "high 32bits greater", - data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000003_00000002}}, - want: linux.SECCOMP_RET_ALLOW, - }, - { - desc: "high 32bits equal, low 32bits greater", - data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000002_00000003}}, - want: linux.SECCOMP_RET_ALLOW, - }, - { - desc: "high 32bits equal, low 32bits equal", - data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000002_00000002}}, - want: linux.SECCOMP_RET_ALLOW, - }, - { - desc: "high 32bits equal, low 32bits less", - data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000002_00000001}}, - want: linux.SECCOMP_RET_TRAP, - }, - { - desc: "high 32bits less", - data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000001_00000002}}, - want: linux.SECCOMP_RET_TRAP, - }, - }, - }, - { - name: "GreaterThanOrEqual (multi)", - ruleSets: []RuleSet{ - { - Rules: SyscallRules{ - 1: []Rule{ - { - GreaterThanOrEqual(0xf), - GreaterThanOrEqual(0xabcd000d), - }, - }, - }, - Action: linux.SECCOMP_RET_ALLOW, - }, - }, - defaultAction: linux.SECCOMP_RET_TRAP, - badArchAction: linux.SECCOMP_RET_KILL_THREAD, - specs: []spec{ - { - desc: "arg allowed (both greater)", - data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x10, 0xffffffff}}, - want: linux.SECCOMP_RET_ALLOW, - }, - { - desc: "arg allowed (first arg equal)", - data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0xf, 0xffffffff}}, - want: linux.SECCOMP_RET_ALLOW, - }, - { - desc: "arg disallowed (first arg smaller)", - data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x0, 0xffffffff}}, - want: linux.SECCOMP_RET_TRAP, - }, - { - desc: "arg allowed (second arg equal)", - data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x10, 0xabcd000d}}, - want: linux.SECCOMP_RET_ALLOW, - }, - { - desc: "arg disallowed (second arg smaller)", - data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x10, 0xa000ffff}}, - want: linux.SECCOMP_RET_TRAP, - }, - { - desc: "arg disallowed (both arg smaller)", - data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x0, 0xa000ffff}}, - want: linux.SECCOMP_RET_TRAP, - }, - }, - }, - { - name: "LessThan", - ruleSets: []RuleSet{ - { - Rules: SyscallRules{ - 1: []Rule{ - { - // 4294967298 - // Both upper 32 bits and lower 32 bits are non-zero. - // 00000000000000000000000000000010 - // 00000000000000000000000000000010 - LessThan(0x00000002_00000002), - }, - }, - }, - Action: linux.SECCOMP_RET_ALLOW, - }, - }, - defaultAction: linux.SECCOMP_RET_TRAP, - badArchAction: linux.SECCOMP_RET_KILL_THREAD, - specs: []spec{ - { - desc: "high 32bits greater", - data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000003_00000002}}, - want: linux.SECCOMP_RET_TRAP, - }, - { - desc: "high 32bits equal, low 32bits greater", - data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000002_00000003}}, - want: linux.SECCOMP_RET_TRAP, - }, - { - desc: "high 32bits equal, low 32bits equal", - data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000002_00000002}}, - want: linux.SECCOMP_RET_TRAP, - }, - { - desc: "high 32bits equal, low 32bits less", - data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000002_00000001}}, - want: linux.SECCOMP_RET_ALLOW, - }, - { - desc: "high 32bits less", - data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000001_00000002}}, - want: linux.SECCOMP_RET_ALLOW, - }, - }, - }, - { - name: "LessThan (multi)", - ruleSets: []RuleSet{ - { - Rules: SyscallRules{ - 1: []Rule{ - { - LessThan(0x1), - LessThan(0xabcd000d), - }, - }, - }, - Action: linux.SECCOMP_RET_ALLOW, - }, - }, - defaultAction: linux.SECCOMP_RET_TRAP, - badArchAction: linux.SECCOMP_RET_KILL_THREAD, - specs: []spec{ - { - desc: "arg allowed", - data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x0, 0x0}}, - want: linux.SECCOMP_RET_ALLOW, - }, - { - desc: "arg disallowed (first arg equal)", - data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x1, 0x0}}, - want: linux.SECCOMP_RET_TRAP, - }, - { - desc: "arg disallowed (first arg greater)", - data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x2, 0x0}}, - want: linux.SECCOMP_RET_TRAP, - }, - { - desc: "arg disallowed (second arg equal)", - data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x0, 0xabcd000d}}, - want: linux.SECCOMP_RET_TRAP, - }, - { - desc: "arg disallowed (second arg greater)", - data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x0, 0xffffffff}}, - want: linux.SECCOMP_RET_TRAP, - }, - { - desc: "arg disallowed (both arg greater)", - data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x2, 0xffffffff}}, - want: linux.SECCOMP_RET_TRAP, - }, - }, - }, - { - name: "LessThanOrEqual", - ruleSets: []RuleSet{ - { - Rules: SyscallRules{ - 1: []Rule{ - { - // 4294967298 - // Both upper 32 bits and lower 32 bits are non-zero. - // 00000000000000000000000000000010 - // 00000000000000000000000000000010 - LessThanOrEqual(0x00000002_00000002), - }, - }, - }, - Action: linux.SECCOMP_RET_ALLOW, - }, - }, - defaultAction: linux.SECCOMP_RET_TRAP, - badArchAction: linux.SECCOMP_RET_KILL_THREAD, - specs: []spec{ - { - desc: "high 32bits greater", - data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000003_00000002}}, - want: linux.SECCOMP_RET_TRAP, - }, - { - desc: "high 32bits equal, low 32bits greater", - data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000002_00000003}}, - want: linux.SECCOMP_RET_TRAP, - }, - { - desc: "high 32bits equal, low 32bits equal", - data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000002_00000002}}, - want: linux.SECCOMP_RET_ALLOW, - }, - { - desc: "high 32bits equal, low 32bits less", - data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000002_00000001}}, - want: linux.SECCOMP_RET_ALLOW, - }, - { - desc: "high 32bits less", - data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000001_00000002}}, - want: linux.SECCOMP_RET_ALLOW, - }, - }, - }, - - { - name: "LessThanOrEqual (multi)", - ruleSets: []RuleSet{ - { - Rules: SyscallRules{ - 1: []Rule{ - { - LessThanOrEqual(0x1), - LessThanOrEqual(0xabcd000d), - }, - }, - }, - Action: linux.SECCOMP_RET_ALLOW, - }, - }, - defaultAction: linux.SECCOMP_RET_TRAP, - badArchAction: linux.SECCOMP_RET_KILL_THREAD, - specs: []spec{ - { - desc: "arg allowed", - data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x0, 0x0}}, - want: linux.SECCOMP_RET_ALLOW, - }, - { - desc: "arg allowed (first arg equal)", - data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x1, 0x0}}, - want: linux.SECCOMP_RET_ALLOW, - }, - { - desc: "arg disallowed (first arg greater)", - data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x2, 0x0}}, - want: linux.SECCOMP_RET_TRAP, - }, - { - desc: "arg allowed (second arg equal)", - data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x0, 0xabcd000d}}, - want: linux.SECCOMP_RET_ALLOW, - }, - { - desc: "arg disallowed (second arg greater)", - data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x0, 0xffffffff}}, - want: linux.SECCOMP_RET_TRAP, - }, - { - desc: "arg disallowed (both arg greater)", - data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x2, 0xffffffff}}, - want: linux.SECCOMP_RET_TRAP, - }, - }, - }, - { - name: "MaskedEqual", - ruleSets: []RuleSet{ - { - Rules: SyscallRules{ - 1: []Rule{ - { - // x & 00000001 00000011 (0x103) == 00000000 00000001 (0x1) - // Input x must have lowest order bit set and - // must *not* have 8th or second lowest order bit set. - MaskedEqual(0x103, 0x1), - }, - }, - }, - Action: linux.SECCOMP_RET_ALLOW, - }, - }, - defaultAction: linux.SECCOMP_RET_TRAP, - badArchAction: linux.SECCOMP_RET_KILL_THREAD, - specs: []spec{ - { - desc: "arg allowed (low order mandatory bit)", - data: linux.SeccompData{ - Nr: 1, - Arch: LINUX_AUDIT_ARCH, - // 00000000 00000000 00000000 00000001 - Args: [6]uint64{0x1}, - }, - want: linux.SECCOMP_RET_ALLOW, - }, - { - desc: "arg allowed (low order optional bit)", - data: linux.SeccompData{ - Nr: 1, - Arch: LINUX_AUDIT_ARCH, - // 00000000 00000000 00000000 00000101 - Args: [6]uint64{0x5}, - }, - want: linux.SECCOMP_RET_ALLOW, - }, - { - desc: "arg disallowed (lowest order bit not set)", - data: linux.SeccompData{ - Nr: 1, - Arch: LINUX_AUDIT_ARCH, - // 00000000 00000000 00000000 00000010 - Args: [6]uint64{0x2}, - }, - want: linux.SECCOMP_RET_TRAP, - }, - { - desc: "arg disallowed (second lowest order bit set)", - data: linux.SeccompData{ - Nr: 1, - Arch: LINUX_AUDIT_ARCH, - // 00000000 00000000 00000000 00000011 - Args: [6]uint64{0x3}, - }, - want: linux.SECCOMP_RET_TRAP, - }, - { - desc: "arg disallowed (8th bit set)", - data: linux.SeccompData{ - Nr: 1, - Arch: LINUX_AUDIT_ARCH, - // 00000000 00000000 00000001 00000000 - Args: [6]uint64{0x100}, - }, - want: linux.SECCOMP_RET_TRAP, - }, - }, - }, - { - name: "Instruction Pointer", - ruleSets: []RuleSet{ - { - Rules: SyscallRules{ - 1: []Rule{ - { - RuleIP: EqualTo(0x7aabbccdd), - }, - }, - }, - Action: linux.SECCOMP_RET_ALLOW, - }, - }, - defaultAction: linux.SECCOMP_RET_TRAP, - badArchAction: linux.SECCOMP_RET_KILL_THREAD, - specs: []spec{ - { - desc: "allowed", - data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{}, InstructionPointer: 0x7aabbccdd}, - want: linux.SECCOMP_RET_ALLOW, - }, - { - desc: "disallowed", - data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{}, InstructionPointer: 0x711223344}, - want: linux.SECCOMP_RET_TRAP, - }, - }, - }, - } { - t.Run(test.name, func(t *testing.T) { - instrs, err := BuildProgram(test.ruleSets, test.defaultAction, test.badArchAction) - if err != nil { - t.Fatalf("BuildProgram() got error: %v", err) - } - p, err := bpf.Compile(instrs) - if err != nil { - t.Fatalf("bpf.Compile() got error: %v", err) - } - for _, spec := range test.specs { - got, err := bpf.Exec(p, dataAsInput(&spec.data)) - if err != nil { - t.Fatalf("%s: bpf.Exec() got error: %v", spec.desc, err) - } - if got != uint32(spec.want) { - // Include a decoded version of the program in output for debugging purposes. - decoded, _ := bpf.DecodeInstructions(instrs) - t.Fatalf("%s: got: %d, want: %d\nBPF Program\n%s", spec.desc, got, spec.want, decoded) - } - } - }) - } -} - -// TestRandom tests that randomly generated rules are encoded correctly. -func TestRandom(t *testing.T) { - rand.Seed(time.Now().UnixNano()) - size := rand.Intn(50) + 1 - syscallRules := make(map[uintptr][]Rule) - for len(syscallRules) < size { - n := uintptr(rand.Intn(200)) - if _, ok := syscallRules[n]; !ok { - syscallRules[n] = []Rule{} - } - } - - t.Logf("Testing filters: %v", syscallRules) - instrs, err := BuildProgram([]RuleSet{ - { - Rules: syscallRules, - Action: linux.SECCOMP_RET_ALLOW, - }, - }, linux.SECCOMP_RET_TRAP, linux.SECCOMP_RET_KILL_THREAD) - if err != nil { - t.Fatalf("buildProgram() got error: %v", err) - } - p, err := bpf.Compile(instrs) - if err != nil { - t.Fatalf("bpf.Compile() got error: %v", err) - } - for i := uint32(0); i < 200; i++ { - data := linux.SeccompData{Nr: int32(i), Arch: LINUX_AUDIT_ARCH} - got, err := bpf.Exec(p, dataAsInput(&data)) - if err != nil { - t.Errorf("bpf.Exec() got error: %v, for syscall %d", err, i) - continue - } - want := linux.SECCOMP_RET_TRAP - if _, ok := syscallRules[uintptr(i)]; ok { - want = linux.SECCOMP_RET_ALLOW - } - if got != uint32(want) { - t.Errorf("bpf.Exec() = %d, want: %d, for syscall %d", got, want, i) - } - } -} - -// TestReadDeal checks that a process dies when it trips over the filter and -// that it doesn't die when the filter is not triggered. -func TestRealDeal(t *testing.T) { - for _, test := range []struct { - die bool - want string - }{ - {die: true, want: "bad system call"}, - {die: false, want: "Syscall was allowed!!!"}, - } { - victim, err := newVictim() - if err != nil { - t.Fatalf("unable to get victim: %v", err) - } - defer os.Remove(victim) - dieFlag := fmt.Sprintf("-die=%v", test.die) - cmd := exec.Command(victim, dieFlag) - - out, err := cmd.CombinedOutput() - if test.die { - if err == nil { - t.Errorf("victim was not killed as expected, output: %s", out) - continue - } - // Depending on kernel version, either RET_TRAP or RET_KILL_PROCESS is - // used. RET_TRAP dumps reason for exit in output, while RET_KILL_PROCESS - // returns SIGSYS as exit status. - if !strings.Contains(string(out), test.want) && - !strings.Contains(err.Error(), test.want) { - t.Errorf("Victim error is wrong, got: %v, err: %v, want: %v", string(out), err, test.want) - continue - } - } else { - if err != nil { - t.Errorf("victim failed to execute, err: %v", err) - continue - } - if !strings.Contains(string(out), test.want) { - t.Errorf("Victim output is wrong, got: %v, want: %v", string(out), test.want) - continue - } - } - } -} - -// TestMerge ensures that empty rules are not erased when rules are merged. -func TestMerge(t *testing.T) { - for _, tst := range []struct { - name string - main []Rule - merge []Rule - want []Rule - }{ - { - name: "empty both", - main: nil, - merge: nil, - want: []Rule{{}, {}}, - }, - { - name: "empty main", - main: nil, - merge: []Rule{{}}, - want: []Rule{{}, {}}, - }, - { - name: "empty merge", - main: []Rule{{}}, - merge: nil, - want: []Rule{{}, {}}, - }, - } { - t.Run(tst.name, func(t *testing.T) { - mainRules := SyscallRules{1: tst.main} - mergeRules := SyscallRules{1: tst.merge} - mainRules.Merge(mergeRules) - if got, want := len(mainRules[1]), len(tst.want); got != want { - t.Errorf("wrong length, got: %d, want: %d", got, want) - } - for i, r := range mainRules[1] { - if r != tst.want[i] { - t.Errorf("result, got: %v, want: %v", r, tst.want[i]) - } - } - }) - } -} - -// TestAddRule ensures that empty rules are not erased when rules are added. -func TestAddRule(t *testing.T) { - rules := SyscallRules{1: {}} - rules.AddRule(1, Rule{}) - if got, want := len(rules[1]), 2; got != want { - t.Errorf("len(rules[1]), got: %d, want: %d", got, want) - } -} diff --git a/pkg/seccomp/seccomp_test_victim.go b/pkg/seccomp/seccomp_test_victim.go deleted file mode 100644 index a96b1e327..000000000 --- a/pkg/seccomp/seccomp_test_victim.go +++ /dev/null @@ -1,116 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Test binary used to test that seccomp filters are properly constructed and -// indeed kill the process on violation. -package main - -import ( - "flag" - "fmt" - "os" - - "golang.org/x/sys/unix" - "gvisor.dev/gvisor/pkg/seccomp" -) - -func main() { - dieFlag := flag.Bool("die", false, "trips over the filter if true") - flag.Parse() - - syscalls := seccomp.SyscallRules{ - unix.SYS_ACCEPT: {}, - unix.SYS_BIND: {}, - unix.SYS_BRK: {}, - unix.SYS_CLOCK_GETTIME: {}, - unix.SYS_CLONE: {}, - unix.SYS_CLOSE: {}, - unix.SYS_DUP: {}, - unix.SYS_DUP3: {}, - unix.SYS_EPOLL_CREATE1: {}, - unix.SYS_EPOLL_CTL: {}, - unix.SYS_EPOLL_PWAIT: {}, - unix.SYS_EXIT: {}, - unix.SYS_EXIT_GROUP: {}, - unix.SYS_FALLOCATE: {}, - unix.SYS_FCHMOD: {}, - unix.SYS_FCNTL: {}, - unix.SYS_FSTAT: {}, - unix.SYS_FSYNC: {}, - unix.SYS_FTRUNCATE: {}, - unix.SYS_FUTEX: {}, - unix.SYS_GETDENTS64: {}, - unix.SYS_GETPEERNAME: {}, - unix.SYS_GETPID: {}, - unix.SYS_GETSOCKNAME: {}, - unix.SYS_GETSOCKOPT: {}, - unix.SYS_GETTID: {}, - unix.SYS_GETTIMEOFDAY: {}, - unix.SYS_LISTEN: {}, - unix.SYS_LSEEK: {}, - unix.SYS_MADVISE: {}, - unix.SYS_MINCORE: {}, - unix.SYS_MMAP: {}, - unix.SYS_MPROTECT: {}, - unix.SYS_MUNLOCK: {}, - unix.SYS_MUNMAP: {}, - unix.SYS_NANOSLEEP: {}, - unix.SYS_PPOLL: {}, - unix.SYS_PREAD64: {}, - unix.SYS_PSELECT6: {}, - unix.SYS_PWRITE64: {}, - unix.SYS_READ: {}, - unix.SYS_READLINKAT: {}, - unix.SYS_READV: {}, - unix.SYS_RECVMSG: {}, - unix.SYS_RENAMEAT: {}, - unix.SYS_RESTART_SYSCALL: {}, - unix.SYS_RT_SIGACTION: {}, - unix.SYS_RT_SIGPROCMASK: {}, - unix.SYS_RT_SIGRETURN: {}, - unix.SYS_SCHED_YIELD: {}, - unix.SYS_SENDMSG: {}, - unix.SYS_SETITIMER: {}, - unix.SYS_SET_ROBUST_LIST: {}, - unix.SYS_SETSOCKOPT: {}, - unix.SYS_SHUTDOWN: {}, - unix.SYS_SIGALTSTACK: {}, - unix.SYS_SOCKET: {}, - unix.SYS_SYNC_FILE_RANGE: {}, - unix.SYS_TGKILL: {}, - unix.SYS_UTIMENSAT: {}, - unix.SYS_WRITE: {}, - unix.SYS_WRITEV: {}, - } - - arch_syscalls(syscalls) - - die := *dieFlag - if !die { - syscalls[unix.SYS_OPENAT] = []seccomp.Rule{ - { - seccomp.EqualTo(10), - }, - } - } - - if err := seccomp.Install(syscalls); err != nil { - fmt.Printf("Failed to install seccomp: %v", err) - os.Exit(1) - } - fmt.Printf("Filters installed\n") - - unix.RawSyscall(unix.SYS_OPENAT, 10, 0, 0) - fmt.Printf("Syscall was allowed!!!\n") -} diff --git a/pkg/seccomp/seccomp_test_victim_amd64.go b/pkg/seccomp/seccomp_test_victim_amd64.go deleted file mode 100644 index 5c1ecc301..000000000 --- a/pkg/seccomp/seccomp_test_victim_amd64.go +++ /dev/null @@ -1,33 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Test binary used to test that seccomp filters are properly constructed and -// indeed kill the process on violation. - -//go:build amd64 -// +build amd64 - -package main - -import ( - "golang.org/x/sys/unix" - "gvisor.dev/gvisor/pkg/seccomp" -) - -func arch_syscalls(syscalls seccomp.SyscallRules) { - syscalls[unix.SYS_ARCH_PRCTL] = []seccomp.Rule{} - syscalls[unix.SYS_EPOLL_WAIT] = []seccomp.Rule{} - syscalls[unix.SYS_NEWFSTATAT] = []seccomp.Rule{} - syscalls[unix.SYS_OPEN] = []seccomp.Rule{} -} diff --git a/pkg/seccomp/seccomp_test_victim_arm64.go b/pkg/seccomp/seccomp_test_victim_arm64.go deleted file mode 100644 index 9647e2758..000000000 --- a/pkg/seccomp/seccomp_test_victim_arm64.go +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Test binary used to test that seccomp filters are properly constructed and -// indeed kill the process on violation. - -//go:build arm64 -// +build arm64 - -package main - -import ( - "golang.org/x/sys/unix" - "gvisor.dev/gvisor/pkg/seccomp" -) - -func arch_syscalls(syscalls seccomp.SyscallRules) { - syscalls[unix.SYS_FSTATAT] = []seccomp.Rule{} -} diff --git a/pkg/seccomp/seccomp_unsafe_state_autogen.go b/pkg/seccomp/seccomp_unsafe_state_autogen.go new file mode 100644 index 000000000..e16b5d7c2 --- /dev/null +++ b/pkg/seccomp/seccomp_unsafe_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package seccomp diff --git a/pkg/secio/BUILD b/pkg/secio/BUILD deleted file mode 100644 index 60f63c7a6..000000000 --- a/pkg/secio/BUILD +++ /dev/null @@ -1,19 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "secio", - srcs = [ - "full_reader.go", - "secio.go", - ], - visibility = ["//pkg/sentry:internal"], -) - -go_test( - name = "secio_test", - size = "small", - srcs = ["secio_test.go"], - library = ":secio", -) diff --git a/pkg/secio/secio_state_autogen.go b/pkg/secio/secio_state_autogen.go new file mode 100644 index 000000000..372ac4b92 --- /dev/null +++ b/pkg/secio/secio_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package secio diff --git a/pkg/secio/secio_test.go b/pkg/secio/secio_test.go deleted file mode 100644 index d1d905187..000000000 --- a/pkg/secio/secio_test.go +++ /dev/null @@ -1,126 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package secio - -import ( - "bytes" - "errors" - "io" - "io/ioutil" - "math" - "testing" -) - -var errEndOfBuffer = errors.New("write beyond end of buffer") - -// buffer resembles bytes.Buffer, but implements io.ReaderAt and io.WriterAt. -// Reads beyond the end of the buffer return io.EOF. Writes beyond the end of -// the buffer return errEndOfBuffer. -type buffer struct { - Bytes []byte -} - -// ReadAt implements io.ReaderAt.ReadAt. -func (b *buffer) ReadAt(dst []byte, off int64) (int, error) { - if off >= int64(len(b.Bytes)) { - return 0, io.EOF - } - n := copy(dst, b.Bytes[off:]) - if n < len(dst) { - return n, io.EOF - } - return n, nil -} - -// WriteAt implements io.WriterAt.WriteAt. -func (b *buffer) WriteAt(src []byte, off int64) (int, error) { - if off >= int64(len(b.Bytes)) { - return 0, errEndOfBuffer - } - n := copy(b.Bytes[off:], src) - if n < len(src) { - return n, errEndOfBuffer - } - return n, nil -} - -func newBufferString(s string) *buffer { - return &buffer{[]byte(s)} -} - -func TestOffsetReader(t *testing.T) { - buf := newBufferString("foobar") - r := NewOffsetReader(buf, 3) - dst, err := ioutil.ReadAll(r) - if want := []byte("bar"); !bytes.Equal(dst, want) || err != nil { - t.Errorf("ReadAll: got (%q, %v), wanted (%q, nil)", dst, err, want) - } -} - -func TestSectionReader(t *testing.T) { - buf := newBufferString("foobarbaz") - r := NewSectionReader(buf, 3, 3) - dst, err := ioutil.ReadAll(r) - if want, wantErr := []byte("bar"), ErrReachedLimit; !bytes.Equal(dst, want) || err != wantErr { - t.Errorf("ReadAll: got (%q, %v), wanted (%q, %v)", dst, err, want, wantErr) - } -} - -func TestSectionReaderLimitOverflow(t *testing.T) { - // SectionReader behaves like OffsetReader when limit overflows int64. - buf := newBufferString("foobar") - r := NewSectionReader(buf, 3, math.MaxInt64) - dst, err := ioutil.ReadAll(r) - if want := []byte("bar"); !bytes.Equal(dst, want) || err != nil { - t.Errorf("ReadAll: got (%q, %v), wanted (%q, nil)", dst, err, want) - } -} - -func TestOffsetWriter(t *testing.T) { - buf := newBufferString("ABCDEF") - w := NewOffsetWriter(buf, 3) - n, err := w.Write([]byte("foobar")) - if wantN, wantErr := 3, errEndOfBuffer; n != wantN || err != wantErr { - t.Errorf("WriteAt: got (%v, %v), wanted (%v, %v)", n, err, wantN, wantErr) - } - if got, want := buf.Bytes, []byte("ABCfoo"); !bytes.Equal(got, want) { - t.Errorf("buf.Bytes: got %q, wanted %q", got, want) - } -} - -func TestSectionWriter(t *testing.T) { - buf := newBufferString("ABCDEFGHI") - w := NewSectionWriter(buf, 3, 3) - n, err := w.Write([]byte("foobar")) - if wantN, wantErr := 3, ErrReachedLimit; n != wantN || err != wantErr { - t.Errorf("WriteAt: got (%v, %v), wanted (%v, %v)", n, err, wantN, wantErr) - } - if got, want := buf.Bytes, []byte("ABCfooGHI"); !bytes.Equal(got, want) { - t.Errorf("buf.Bytes: got %q, wanted %q", got, want) - } -} - -func TestSectionWriterLimitOverflow(t *testing.T) { - // SectionWriter behaves like OffsetWriter when limit overflows int64. - buf := newBufferString("ABCDEF") - w := NewSectionWriter(buf, 3, math.MaxInt64) - n, err := w.Write([]byte("foobar")) - if wantN, wantErr := 3, errEndOfBuffer; n != wantN || err != wantErr { - t.Errorf("WriteAt: got (%v, %v), wanted (%v, %v)", n, err, wantN, wantErr) - } - if got, want := buf.Bytes, []byte("ABCfoo"); !bytes.Equal(got, want) { - t.Errorf("buf.Bytes: got %q, wanted %q", got, want) - } -} diff --git a/pkg/segment/BUILD b/pkg/segment/BUILD deleted file mode 100644 index f57ccc170..000000000 --- a/pkg/segment/BUILD +++ /dev/null @@ -1,33 +0,0 @@ -load("//tools/go_generics:defs.bzl", "go_template") - -package( - default_visibility = ["//:sandbox"], - licenses = ["notice"], -) - -go_template( - name = "generic_range", - srcs = ["range.go"], - types = [ - "T", - ], -) - -go_template( - name = "generic_set", - srcs = [ - "set.go", - "set_state.go", - ], - opt_consts = [ - "minDegree", - # trackGaps must either be 0 or 1. - "trackGaps", - ], - types = [ - "Key", - "Range", - "Value", - "Functions", - ], -) diff --git a/pkg/segment/set_state.go b/pkg/segment/set_state.go deleted file mode 100644 index 76de92591..000000000 --- a/pkg/segment/set_state.go +++ /dev/null @@ -1,25 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package segment - -func (s *Set) saveRoot() *SegmentDataSlices { - return s.ExportSortedSlices() -} - -func (s *Set) loadRoot(sds *SegmentDataSlices) { - if err := s.ImportSortedSlices(sds); err != nil { - panic(err) - } -} diff --git a/pkg/segment/test/BUILD b/pkg/segment/test/BUILD deleted file mode 100644 index 131bf09b9..000000000 --- a/pkg/segment/test/BUILD +++ /dev/null @@ -1,68 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -package( - default_visibility = ["//visibility:private"], - licenses = ["notice"], -) - -go_template_instance( - name = "int_range", - out = "int_range.go", - package = "segment", - template = "//pkg/segment:generic_range", - types = { - "T": "int", - }, -) - -go_template_instance( - name = "int_set", - out = "int_set.go", - package = "segment", - template = "//pkg/segment:generic_set", - types = { - "Key": "int", - "Range": "Range", - "Value": "int", - "Functions": "setFunctions", - }, -) - -go_template_instance( - name = "gap_set", - out = "gap_set.go", - consts = { - "trackGaps": "1", - }, - package = "segment", - prefix = "gap", - template = "//pkg/segment:generic_set", - types = { - "Key": "int", - "Range": "Range", - "Value": "int", - "Functions": "gapSetFunctions", - }, -) - -go_library( - name = "segment", - testonly = 1, - srcs = [ - "gap_set.go", - "int_range.go", - "int_set.go", - "set_functions.go", - ], - deps = [ - "//pkg/state", - ], -) - -go_test( - name = "segment_test", - size = "small", - srcs = ["segment_test.go"], - library = ":segment", -) diff --git a/pkg/segment/test/segment_test.go b/pkg/segment/test/segment_test.go deleted file mode 100644 index 85fa19096..000000000 --- a/pkg/segment/test/segment_test.go +++ /dev/null @@ -1,865 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package segment - -import ( - "fmt" - "math/rand" - "reflect" - "testing" -) - -const ( - // testSize is the baseline number of elements inserted into sets under - // test, and is chosen to be large enough to ensure interesting amounts of - // tree rebalancing. - // - // Note that because checkSet is called between each insertion/removal in - // some tests that use it, tests may be quadratic in testSize. - testSize = 8000 - - // valueOffset is the difference between the value and start of test - // segments. - valueOffset = 100000 - - // intervalLength is the interval used by random gap tests. - intervalLength = 10 -) - -func shuffle(xs []int) { - rand.Shuffle(len(xs), func(i, j int) { xs[i], xs[j] = xs[j], xs[i] }) -} - -func randIntervalPermutation(size int) []int { - p := make([]int, size) - for i := range p { - p[i] = intervalLength * i - } - shuffle(p) - return p -} - -// validate can be passed to Check. -func validate(nr int, r Range, v int) error { - if got, want := v, r.Start+valueOffset; got != want { - return fmt.Errorf("segment %d has key %d, value %d (expected %d)", nr, r.Start, got, want) - } - return nil -} - -// checkSetMaxGap returns an error if maxGap inside all nodes of s is not well -// maintained. -func checkSetMaxGap(s *gapSet) error { - n := s.root - return checkNodeMaxGap(&n) -} - -// checkNodeMaxGap returns an error if maxGap inside the subtree rooted by n is -// not well maintained. -func checkNodeMaxGap(n *gapnode) error { - var max int - if !n.hasChildren { - max = n.calculateMaxGapLeaf() - } else { - for i := 0; i <= n.nrSegments; i++ { - child := n.children[i] - if err := checkNodeMaxGap(child); err != nil { - return err - } - if temp := child.maxGap.Get(); i == 0 || temp > max { - max = temp - } - } - } - if max != n.maxGap.Get() { - return fmt.Errorf("maxGap wrong in node\n%vexpected: %d got: %d", n, max, n.maxGap) - } - return nil -} - -func TestAddRandom(t *testing.T) { - var s Set - order := rand.Perm(testSize) - var nrInsertions int - for i, j := range order { - if !s.AddWithoutMerging(Range{j, j + 1}, j+valueOffset) { - t.Errorf("Iteration %d: failed to insert segment with key %d", i, j) - break - } - nrInsertions++ - if err := s.segmentTestCheck(nrInsertions, validate); err != nil { - t.Errorf("Iteration %d: %v", i, err) - break - } - } - if got, want := s.countSegments(), nrInsertions; got != want { - t.Errorf("Wrong final number of segments: got %d, wanted %d", got, want) - } - if t.Failed() { - t.Logf("Insertion order: %v", order[:nrInsertions]) - t.Logf("Set contents:\n%v", &s) - } -} - -func TestRemoveRandom(t *testing.T) { - var s Set - for i := 0; i < testSize; i++ { - if !s.AddWithoutMerging(Range{i, i + 1}, i+valueOffset) { - t.Fatalf("Failed to insert segment %d", i) - } - } - order := rand.Perm(testSize) - var nrRemovals int - for i, j := range order { - seg := s.FindSegment(j) - if !seg.Ok() { - t.Errorf("Iteration %d: failed to find segment with key %d", i, j) - break - } - s.Remove(seg) - nrRemovals++ - if err := s.segmentTestCheck(testSize-nrRemovals, validate); err != nil { - t.Errorf("Iteration %d: %v", i, err) - break - } - } - if got, want := s.countSegments(), testSize-nrRemovals; got != want { - t.Errorf("Wrong final number of segments: got %d, wanted %d", got, want) - } - if t.Failed() { - t.Logf("Removal order: %v", order[:nrRemovals]) - t.Logf("Set contents:\n%v", &s) - t.FailNow() - } -} - -func TestMaxGapAddRandom(t *testing.T) { - var s gapSet - order := rand.Perm(testSize) - var nrInsertions int - for i, j := range order { - if !s.AddWithoutMerging(Range{j, j + 1}, j+valueOffset) { - t.Errorf("Iteration %d: failed to insert segment with key %d", i, j) - break - } - nrInsertions++ - if err := s.segmentTestCheck(nrInsertions, validate); err != nil { - t.Errorf("Iteration %d: %v", i, err) - break - } - if err := checkSetMaxGap(&s); err != nil { - t.Errorf("When inserting %d: %v", j, err) - break - } - } - if got, want := s.countSegments(), nrInsertions; got != want { - t.Errorf("Wrong final number of segments: got %d, wanted %d", got, want) - } - if t.Failed() { - t.Logf("Insertion order: %v", order[:nrInsertions]) - t.Logf("Set contents:\n%v", &s) - } -} - -func TestMaxGapAddRandomWithRandomInterval(t *testing.T) { - var s gapSet - order := randIntervalPermutation(testSize) - var nrInsertions int - for i, j := range order { - if !s.AddWithoutMerging(Range{j, j + rand.Intn(intervalLength-1) + 1}, j+valueOffset) { - t.Errorf("Iteration %d: failed to insert segment with key %d", i, j) - break - } - nrInsertions++ - if err := s.segmentTestCheck(nrInsertions, validate); err != nil { - t.Errorf("Iteration %d: %v", i, err) - break - } - if err := checkSetMaxGap(&s); err != nil { - t.Errorf("When inserting %d: %v", j, err) - break - } - } - if got, want := s.countSegments(), nrInsertions; got != want { - t.Errorf("Wrong final number of segments: got %d, wanted %d", got, want) - } - if t.Failed() { - t.Logf("Insertion order: %v", order[:nrInsertions]) - t.Logf("Set contents:\n%v", &s) - } -} - -func TestMaxGapAddRandomWithMerge(t *testing.T) { - var s gapSet - order := randIntervalPermutation(testSize) - nrInsertions := 1 - for i, j := range order { - if !s.Add(Range{j, j + intervalLength}, j+valueOffset) { - t.Errorf("Iteration %d: failed to insert segment with key %d", i, j) - break - } - if err := checkSetMaxGap(&s); err != nil { - t.Errorf("When inserting %d: %v", j, err) - break - } - } - if got, want := s.countSegments(), nrInsertions; got != want { - t.Errorf("Wrong final number of segments: got %d, wanted %d", got, want) - } - if t.Failed() { - t.Logf("Insertion order: %v", order) - t.Logf("Set contents:\n%v", &s) - } -} - -func TestMaxGapRemoveRandom(t *testing.T) { - var s gapSet - for i := 0; i < testSize; i++ { - if !s.AddWithoutMerging(Range{i, i + 1}, i+valueOffset) { - t.Fatalf("Failed to insert segment %d", i) - } - } - order := rand.Perm(testSize) - var nrRemovals int - for i, j := range order { - seg := s.FindSegment(j) - if !seg.Ok() { - t.Errorf("Iteration %d: failed to find segment with key %d", i, j) - break - } - temprange := seg.Range() - s.Remove(seg) - nrRemovals++ - if err := s.segmentTestCheck(testSize-nrRemovals, validate); err != nil { - t.Errorf("Iteration %d: %v", i, err) - break - } - if err := checkSetMaxGap(&s); err != nil { - t.Errorf("When removing %v: %v", temprange, err) - break - } - } - if got, want := s.countSegments(), testSize-nrRemovals; got != want { - t.Errorf("Wrong final number of segments: got %d, wanted %d", got, want) - } - if t.Failed() { - t.Logf("Removal order: %v", order[:nrRemovals]) - t.Logf("Set contents:\n%v", &s) - t.FailNow() - } -} - -func TestMaxGapRemoveHalfRandom(t *testing.T) { - var s gapSet - for i := 0; i < testSize; i++ { - if !s.AddWithoutMerging(Range{intervalLength * i, intervalLength*i + rand.Intn(intervalLength-1) + 1}, intervalLength*i+valueOffset) { - t.Fatalf("Failed to insert segment %d", i) - } - } - order := randIntervalPermutation(testSize) - order = order[:testSize/2] - var nrRemovals int - for i, j := range order { - seg := s.FindSegment(j) - if !seg.Ok() { - t.Errorf("Iteration %d: failed to find segment with key %d", i, j) - break - } - temprange := seg.Range() - s.Remove(seg) - nrRemovals++ - if err := s.segmentTestCheck(testSize-nrRemovals, validate); err != nil { - t.Errorf("Iteration %d: %v", i, err) - break - } - if err := checkSetMaxGap(&s); err != nil { - t.Errorf("When removing %v: %v", temprange, err) - break - } - } - if got, want := s.countSegments(), testSize-nrRemovals; got != want { - t.Errorf("Wrong final number of segments: got %d, wanted %d", got, want) - } - if t.Failed() { - t.Logf("Removal order: %v", order[:nrRemovals]) - t.Logf("Set contents:\n%v", &s) - t.FailNow() - } -} - -func TestMaxGapAddRandomRemoveRandomHalfWithMerge(t *testing.T) { - var s gapSet - order := randIntervalPermutation(testSize * 2) - order = order[:testSize] - for i, j := range order { - if !s.Add(Range{j, j + intervalLength}, j+valueOffset) { - t.Errorf("Iteration %d: failed to insert segment with key %d", i, j) - break - } - if err := checkSetMaxGap(&s); err != nil { - t.Errorf("When inserting %d: %v", j, err) - break - } - } - shuffle(order) - var nrRemovals int - for _, j := range order { - seg := s.FindSegment(j) - if !seg.Ok() { - continue - } - temprange := seg.Range() - s.Remove(seg) - nrRemovals++ - if err := checkSetMaxGap(&s); err != nil { - t.Errorf("When removing %v: %v", temprange, err) - break - } - } - if t.Failed() { - t.Logf("Removal order: %v", order[:nrRemovals]) - t.Logf("Set contents:\n%v", &s) - t.FailNow() - } -} - -func TestNextLargeEnoughGap(t *testing.T) { - var s gapSet - order := randIntervalPermutation(testSize * 2) - order = order[:testSize] - for i, j := range order { - if !s.Add(Range{j, j + rand.Intn(intervalLength-1) + 1}, j+valueOffset) { - t.Errorf("Iteration %d: failed to insert segment with key %d", i, j) - break - } - if err := checkSetMaxGap(&s); err != nil { - t.Errorf("When inserting %d: %v", j, err) - break - } - } - shuffle(order) - order = order[:testSize/2] - for _, j := range order { - seg := s.FindSegment(j) - if !seg.Ok() { - continue - } - temprange := seg.Range() - s.Remove(seg) - if err := checkSetMaxGap(&s); err != nil { - t.Errorf("When removing %v: %v", temprange, err) - break - } - } - minSize := 7 - var gapArr1 []int - for gap := s.LowerBoundGap(0).NextLargeEnoughGap(minSize); gap.Ok(); gap = gap.NextLargeEnoughGap(minSize) { - if gap.Range().Length() < minSize { - t.Errorf("NextLargeEnoughGap wrong, gap %v has length %d, wanted %d", gap.Range(), gap.Range().Length(), minSize) - } else { - gapArr1 = append(gapArr1, gap.Range().Start) - } - } - var gapArr2 []int - for gap := s.LowerBoundGap(0).NextGap(); gap.Ok(); gap = gap.NextGap() { - if gap.Range().Length() >= minSize { - gapArr2 = append(gapArr2, gap.Range().Start) - } - } - - if !reflect.DeepEqual(gapArr2, gapArr1) { - t.Errorf("Search result not correct, got: %v, wanted: %v", gapArr1, gapArr2) - } - if t.Failed() { - t.Logf("Set contents:\n%v", &s) - t.FailNow() - } -} - -func TestPrevLargeEnoughGap(t *testing.T) { - var s gapSet - order := randIntervalPermutation(testSize * 2) - order = order[:testSize] - for i, j := range order { - if !s.Add(Range{j, j + rand.Intn(intervalLength-1) + 1}, j+valueOffset) { - t.Errorf("Iteration %d: failed to insert segment with key %d", i, j) - break - } - if err := checkSetMaxGap(&s); err != nil { - t.Errorf("When inserting %d: %v", j, err) - break - } - } - end := s.LastSegment().End() - shuffle(order) - order = order[:testSize/2] - for _, j := range order { - seg := s.FindSegment(j) - if !seg.Ok() { - continue - } - temprange := seg.Range() - s.Remove(seg) - if err := checkSetMaxGap(&s); err != nil { - t.Errorf("When removing %v: %v", temprange, err) - break - } - } - minSize := 7 - var gapArr1 []int - for gap := s.UpperBoundGap(end + intervalLength).PrevLargeEnoughGap(minSize); gap.Ok(); gap = gap.PrevLargeEnoughGap(minSize) { - if gap.Range().Length() < minSize { - t.Errorf("PrevLargeEnoughGap wrong, gap length %d, wanted %d", gap.Range().Length(), minSize) - } else { - gapArr1 = append(gapArr1, gap.Range().Start) - } - } - var gapArr2 []int - for gap := s.UpperBoundGap(end + intervalLength).PrevGap(); gap.Ok(); gap = gap.PrevGap() { - if gap.Range().Length() >= minSize { - gapArr2 = append(gapArr2, gap.Range().Start) - } - } - if !reflect.DeepEqual(gapArr2, gapArr1) { - t.Errorf("Search result not correct, got: %v, wanted: %v", gapArr1, gapArr2) - } - if t.Failed() { - t.Logf("Set contents:\n%v", &s) - t.FailNow() - } -} - -func TestAddSequentialAdjacent(t *testing.T) { - var s Set - var nrInsertions int - for i := 0; i < testSize; i++ { - if !s.AddWithoutMerging(Range{i, i + 1}, i+valueOffset) { - t.Fatalf("Failed to insert segment %d", i) - } - nrInsertions++ - if err := s.segmentTestCheck(nrInsertions, validate); err != nil { - t.Errorf("Iteration %d: %v", i, err) - break - } - } - if got, want := s.countSegments(), nrInsertions; got != want { - t.Errorf("Wrong final number of segments: got %d, wanted %d", got, want) - } - if t.Failed() { - t.Logf("Set contents:\n%v", &s) - } - - first := s.FirstSegment() - gotSeg, gotGap := first.PrevNonEmpty() - if wantGap := s.FirstGap(); gotSeg.Ok() || gotGap != wantGap { - t.Errorf("FirstSegment().PrevNonEmpty(): got (%v, %v), wanted (<terminal iterator>, %v)", gotSeg, gotGap, wantGap) - } - gotSeg, gotGap = first.NextNonEmpty() - if wantSeg := first.NextSegment(); gotSeg != wantSeg || gotGap.Ok() { - t.Errorf("FirstSegment().NextNonEmpty(): got (%v, %v), wanted (%v, <terminal iterator>)", gotSeg, gotGap, wantSeg) - } - - last := s.LastSegment() - gotSeg, gotGap = last.PrevNonEmpty() - if wantSeg := last.PrevSegment(); gotSeg != wantSeg || gotGap.Ok() { - t.Errorf("LastSegment().PrevNonEmpty(): got (%v, %v), wanted (%v, <terminal iterator>)", gotSeg, gotGap, wantSeg) - } - gotSeg, gotGap = last.NextNonEmpty() - if wantGap := s.LastGap(); gotSeg.Ok() || gotGap != wantGap { - t.Errorf("LastSegment().NextNonEmpty(): got (%v, %v), wanted (<terminal iterator>, %v)", gotSeg, gotGap, wantGap) - } - - for seg := first.NextSegment(); seg != last; seg = seg.NextSegment() { - gotSeg, gotGap = seg.PrevNonEmpty() - if wantSeg := seg.PrevSegment(); gotSeg != wantSeg || gotGap.Ok() { - t.Errorf("%v.PrevNonEmpty(): got (%v, %v), wanted (%v, <terminal iterator>)", seg, gotSeg, gotGap, wantSeg) - } - gotSeg, gotGap = seg.NextNonEmpty() - if wantSeg := seg.NextSegment(); gotSeg != wantSeg || gotGap.Ok() { - t.Errorf("%v.NextNonEmpty(): got (%v, %v), wanted (%v, <terminal iterator>)", seg, gotSeg, gotGap, wantSeg) - } - } -} - -func TestAddSequentialNonAdjacent(t *testing.T) { - var s Set - var nrInsertions int - for i := 0; i < testSize; i++ { - // The range here differs from TestAddSequentialAdjacent so that - // consecutive segments are not adjacent. - if !s.AddWithoutMerging(Range{2 * i, 2*i + 1}, 2*i+valueOffset) { - t.Fatalf("Failed to insert segment %d", i) - } - nrInsertions++ - if err := s.segmentTestCheck(nrInsertions, validate); err != nil { - t.Errorf("Iteration %d: %v", i, err) - break - } - } - if got, want := s.countSegments(), nrInsertions; got != want { - t.Errorf("Wrong final number of segments: got %d, wanted %d", got, want) - } - if t.Failed() { - t.Logf("Set contents:\n%v", &s) - } - - for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { - gotSeg, gotGap := seg.PrevNonEmpty() - if wantGap := seg.PrevGap(); gotSeg.Ok() || gotGap != wantGap { - t.Errorf("%v.PrevNonEmpty(): got (%v, %v), wanted (<terminal iterator>, %v)", seg, gotSeg, gotGap, wantGap) - } - gotSeg, gotGap = seg.NextNonEmpty() - if wantGap := seg.NextGap(); gotSeg.Ok() || gotGap != wantGap { - t.Errorf("%v.NextNonEmpty(): got (%v, %v), wanted (<terminal iterator>, %v)", seg, gotSeg, gotGap, wantGap) - } - } -} - -func TestMergeSplit(t *testing.T) { - tests := []struct { - name string - initial []Range - split bool - splitAddr int - final []Range - }{ - { - name: "Add merges after existing segment", - initial: []Range{{1000, 1100}, {1100, 1200}}, - final: []Range{{1000, 1200}}, - }, - { - name: "Add merges before existing segment", - initial: []Range{{1100, 1200}, {1000, 1100}}, - final: []Range{{1000, 1200}}, - }, - { - name: "Add merges between existing segments", - initial: []Range{{1000, 1100}, {1200, 1300}, {1100, 1200}}, - final: []Range{{1000, 1300}}, - }, - { - name: "SplitAt does nothing at a free address", - initial: []Range{{100, 200}}, - split: true, - splitAddr: 300, - final: []Range{{100, 200}}, - }, - { - name: "SplitAt does nothing at the beginning of a segment", - initial: []Range{{100, 200}}, - split: true, - splitAddr: 100, - final: []Range{{100, 200}}, - }, - { - name: "SplitAt does nothing at the end of a segment", - initial: []Range{{100, 200}}, - split: true, - splitAddr: 200, - final: []Range{{100, 200}}, - }, - { - name: "SplitAt splits in the middle of a segment", - initial: []Range{{100, 200}}, - split: true, - splitAddr: 150, - final: []Range{{100, 150}, {150, 200}}, - }, - } -Tests: - for _, test := range tests { - var s Set - for _, r := range test.initial { - if !s.Add(r, 0) { - t.Errorf("%s: Add(%v) failed; set contents:\n%v", test.name, r, &s) - continue Tests - } - } - if test.split { - s.SplitAt(test.splitAddr) - } - var i int - for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { - if i > len(test.final) { - t.Errorf("%s: Incorrect number of segments: got %d, wanted %d; set contents:\n%v", test.name, s.countSegments(), len(test.final), &s) - continue Tests - } - if got, want := seg.Range(), test.final[i]; got != want { - t.Errorf("%s: Segment %d mismatch: got %v, wanted %v; set contents:\n%v", test.name, i, got, want, &s) - continue Tests - } - i++ - } - if i < len(test.final) { - t.Errorf("%s: Incorrect number of segments: got %d, wanted %d; set contents:\n%v", test.name, i, len(test.final), &s) - } - } -} - -func TestIsolate(t *testing.T) { - tests := []struct { - name string - initial Range - bounds Range - final []Range - }{ - { - name: "Isolate does not split a segment that falls inside bounds", - initial: Range{100, 200}, - bounds: Range{100, 200}, - final: []Range{{100, 200}}, - }, - { - name: "Isolate splits at beginning of segment", - initial: Range{50, 200}, - bounds: Range{100, 200}, - final: []Range{{50, 100}, {100, 200}}, - }, - { - name: "Isolate splits at end of segment", - initial: Range{100, 250}, - bounds: Range{100, 200}, - final: []Range{{100, 200}, {200, 250}}, - }, - { - name: "Isolate splits at beginning and end of segment", - initial: Range{50, 250}, - bounds: Range{100, 200}, - final: []Range{{50, 100}, {100, 200}, {200, 250}}, - }, - } -Tests: - for _, test := range tests { - var s Set - seg := s.Insert(s.FirstGap(), test.initial, 0) - seg = s.Isolate(seg, test.bounds) - if !test.bounds.IsSupersetOf(seg.Range()) { - t.Errorf("%s: Isolated segment %v lies outside bounds %v; set contents:\n%v", test.name, seg.Range(), test.bounds, &s) - } - var i int - for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { - if i > len(test.final) { - t.Errorf("%s: Incorrect number of segments: got %d, wanted %d; set contents:\n%v", test.name, s.countSegments(), len(test.final), &s) - continue Tests - } - if got, want := seg.Range(), test.final[i]; got != want { - t.Errorf("%s: Segment %d mismatch: got %v, wanted %v; set contents:\n%v", test.name, i, got, want, &s) - continue Tests - } - i++ - } - if i < len(test.final) { - t.Errorf("%s: Incorrect number of segments: got %d, wanted %d; set contents:\n%v", test.name, i, len(test.final), &s) - } - } -} - -func benchmarkAddSequential(b *testing.B, size int) { - for n := 0; n < b.N; n++ { - var s Set - for i := 0; i < size; i++ { - if !s.AddWithoutMerging(Range{i, i + 1}, i) { - b.Fatalf("Failed to insert segment %d", i) - } - } - } -} - -func benchmarkAddRandom(b *testing.B, size int) { - order := rand.Perm(size) - - b.ResetTimer() - for n := 0; n < b.N; n++ { - var s Set - for _, i := range order { - if !s.AddWithoutMerging(Range{i, i + 1}, i) { - b.Fatalf("Failed to insert segment %d", i) - } - } - } -} - -func benchmarkFindSequential(b *testing.B, size int) { - var s Set - for i := 0; i < size; i++ { - if !s.AddWithoutMerging(Range{i, i + 1}, i) { - b.Fatalf("Failed to insert segment %d", i) - } - } - - b.ResetTimer() - for n := 0; n < b.N; n++ { - for i := 0; i < size; i++ { - if seg := s.FindSegment(i); !seg.Ok() { - b.Fatalf("Failed to find segment %d", i) - } - } - } -} - -func benchmarkFindRandom(b *testing.B, size int) { - var s Set - for i := 0; i < size; i++ { - if !s.AddWithoutMerging(Range{i, i + 1}, i) { - b.Fatalf("Failed to insert segment %d", i) - } - } - order := rand.Perm(size) - - b.ResetTimer() - for n := 0; n < b.N; n++ { - for _, i := range order { - if si := s.FindSegment(i); !si.Ok() { - b.Fatalf("Failed to find segment %d", i) - } - } - } -} - -func benchmarkIteration(b *testing.B, size int) { - var s Set - for i := 0; i < size; i++ { - if !s.AddWithoutMerging(Range{i, i + 1}, i) { - b.Fatalf("Failed to insert segment %d", i) - } - } - - b.ResetTimer() - var count uint64 - for n := 0; n < b.N; n++ { - for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { - count++ - } - } - if got, want := count, uint64(size)*uint64(b.N); got != want { - b.Fatalf("Iterated wrong number of segments: got %d, wanted %d", got, want) - } -} - -func benchmarkAddFindRemoveSequential(b *testing.B, size int) { - for n := 0; n < b.N; n++ { - var s Set - for i := 0; i < size; i++ { - if !s.AddWithoutMerging(Range{i, i + 1}, i) { - b.Fatalf("Failed to insert segment %d", i) - } - } - for i := 0; i < size; i++ { - seg := s.FindSegment(i) - if !seg.Ok() { - b.Fatalf("Failed to find segment %d", i) - } - s.Remove(seg) - } - if !s.IsEmpty() { - b.Fatalf("Set not empty after all removals:\n%v", &s) - } - } -} - -func benchmarkAddFindRemoveRandom(b *testing.B, size int) { - order := rand.Perm(size) - - b.ResetTimer() - for n := 0; n < b.N; n++ { - var s Set - for _, i := range order { - if !s.AddWithoutMerging(Range{i, i + 1}, i) { - b.Fatalf("Failed to insert segment %d", i) - } - } - for _, i := range order { - seg := s.FindSegment(i) - if !seg.Ok() { - b.Fatalf("Failed to find segment %d", i) - } - s.Remove(seg) - } - if !s.IsEmpty() { - b.Fatalf("Set not empty after all removals:\n%v", &s) - } - } -} - -// Although we don't generally expect our segment sets to get this big, they're -// useful for emulating the effect of cache pressure. -var testSizes = []struct { - desc string - size int -}{ - {"64", 1 << 6}, - {"256", 1 << 8}, - {"1K", 1 << 10}, - {"4K", 1 << 12}, - {"16K", 1 << 14}, - {"64K", 1 << 16}, -} - -func BenchmarkAddSequential(b *testing.B) { - for _, test := range testSizes { - b.Run(test.desc, func(b *testing.B) { - benchmarkAddSequential(b, test.size) - }) - } -} - -func BenchmarkAddRandom(b *testing.B) { - for _, test := range testSizes { - b.Run(test.desc, func(b *testing.B) { - benchmarkAddRandom(b, test.size) - }) - } -} - -func BenchmarkFindSequential(b *testing.B) { - for _, test := range testSizes { - b.Run(test.desc, func(b *testing.B) { - benchmarkFindSequential(b, test.size) - }) - } -} - -func BenchmarkFindRandom(b *testing.B) { - for _, test := range testSizes { - b.Run(test.desc, func(b *testing.B) { - benchmarkFindRandom(b, test.size) - }) - } -} - -func BenchmarkIteration(b *testing.B) { - for _, test := range testSizes { - b.Run(test.desc, func(b *testing.B) { - benchmarkIteration(b, test.size) - }) - } -} - -func BenchmarkAddFindRemoveSequential(b *testing.B) { - for _, test := range testSizes { - b.Run(test.desc, func(b *testing.B) { - benchmarkAddFindRemoveSequential(b, test.size) - }) - } -} - -func BenchmarkAddFindRemoveRandom(b *testing.B) { - for _, test := range testSizes { - b.Run(test.desc, func(b *testing.B) { - benchmarkAddFindRemoveRandom(b, test.size) - }) - } -} diff --git a/pkg/segment/test/set_functions.go b/pkg/segment/test/set_functions.go deleted file mode 100644 index 652c010da..000000000 --- a/pkg/segment/test/set_functions.go +++ /dev/null @@ -1,55 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package segment is a test package. -package segment - -type setFunctions struct{} - -// MinKey returns the minimum key for the set. -func (s setFunctions) MinKey() int { - return -s.MaxKey() - 1 -} - -// MaxKey returns the maximum key for the set. -func (setFunctions) MaxKey() int { - return int(^uint(0) >> 1) -} - -func (setFunctions) ClearValue(*int) {} - -func (setFunctions) Merge(_ Range, val1 int, _ Range, _ int) (int, bool) { - return val1, true -} - -func (setFunctions) Split(_ Range, val int, _ int) (int, int) { - return val, val -} - -type gapSetFunctions struct { - setFunctions -} - -// MinKey is adjusted to make sure no add overflow would happen in test cases. -// e.g. A gap with range {MinInt32, 2} would cause overflow in Range().Length(). -// -// Normally Keys should be unsigned to avoid these issues. -func (s gapSetFunctions) MinKey() int { - return s.setFunctions.MinKey() / 2 -} - -// MaxKey returns the maximum key for the set. -func (s gapSetFunctions) MaxKey() int { - return s.setFunctions.MaxKey() / 2 -} diff --git a/pkg/sentry/BUILD b/pkg/sentry/BUILD deleted file mode 100644 index e759dc36f..000000000 --- a/pkg/sentry/BUILD +++ /dev/null @@ -1,14 +0,0 @@ -package(licenses = ["notice"]) - -# The "internal" package_group should be used as much as possible by packages -# that should remain Sentry-internal (i.e. not be exposed directly to command -# line tooling or APIs). -package_group( - name = "internal", - packages = [ - "//pkg/sentry/...", - "//runsc/...", - # Code generated by go_marshal relies on go_marshal libraries. - "//tools/go_marshal/...", - ], -) diff --git a/pkg/sentry/arch/BUILD b/pkg/sentry/arch/BUILD deleted file mode 100644 index e0dbc436d..000000000 --- a/pkg/sentry/arch/BUILD +++ /dev/null @@ -1,47 +0,0 @@ -load("//tools:defs.bzl", "go_library", "proto_library") - -package(licenses = ["notice"]) - -go_library( - name = "arch", - srcs = [ - "aligned.go", - "arch.go", - "arch_aarch64.go", - "arch_amd64.go", - "arch_arm64.go", - "arch_state_x86.go", - "arch_x86.go", - "arch_x86_impl.go", - "auxv.go", - "signal_amd64.go", - "signal_arm64.go", - "stack.go", - "stack_unsafe.go", - "syscalls_amd64.go", - "syscalls_arm64.go", - ], - marshal = True, - visibility = ["//:sandbox"], - deps = [ - ":registers_go_proto", - "//pkg/abi/linux", - "//pkg/context", - "//pkg/cpuid", - "//pkg/errors/linuxerr", - "//pkg/hostarch", - "//pkg/log", - "//pkg/marshal", - "//pkg/marshal/primitive", - "//pkg/sentry/arch/fpu", - "//pkg/sentry/limits", - "//pkg/usermem", - "@org_golang_x_sys//unix:go_default_library", - ], -) - -proto_library( - name = "registers", - srcs = ["registers.proto"], - visibility = ["//visibility:public"], -) diff --git a/pkg/sentry/arch/arch_aarch64_abi_autogen_unsafe.go b/pkg/sentry/arch/arch_aarch64_abi_autogen_unsafe.go new file mode 100644 index 000000000..0a1eae844 --- /dev/null +++ b/pkg/sentry/arch/arch_aarch64_abi_autogen_unsafe.go @@ -0,0 +1,16 @@ +// Automatically generated marshal implementation. See tools/go_marshal. + +// If there are issues with build constraint aggregation, see +// tools/go_marshal/gomarshal/generator.go:writeHeader(). The constraints here +// come from the input set of files used to generate this file. This input set +// is filtered based on pre-defined file suffixes related to build constraints, +// see tools/defs.bzl:calculate_sets(). + +//go:build arm64 +// +build arm64 + +package arch + +import ( +) + diff --git a/pkg/sentry/arch/arch_aarch64_state_autogen.go b/pkg/sentry/arch/arch_aarch64_state_autogen.go new file mode 100644 index 000000000..6faaa9f08 --- /dev/null +++ b/pkg/sentry/arch/arch_aarch64_state_autogen.go @@ -0,0 +1,77 @@ +// automatically generated by stateify. + +//go:build arm64 +// +build arm64 + +package arch + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (r *Registers) StateTypeName() string { + return "pkg/sentry/arch.Registers" +} + +func (r *Registers) StateFields() []string { + return []string{ + "PtraceRegs", + "TPIDR_EL0", + } +} + +func (r *Registers) beforeSave() {} + +// +checklocksignore +func (r *Registers) StateSave(stateSinkObject state.Sink) { + r.beforeSave() + stateSinkObject.Save(0, &r.PtraceRegs) + stateSinkObject.Save(1, &r.TPIDR_EL0) +} + +func (r *Registers) afterLoad() {} + +// +checklocksignore +func (r *Registers) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &r.PtraceRegs) + stateSourceObject.Load(1, &r.TPIDR_EL0) +} + +func (s *State) StateTypeName() string { + return "pkg/sentry/arch.State" +} + +func (s *State) StateFields() []string { + return []string{ + "Regs", + "fpState", + "FeatureSet", + "OrigR0", + } +} + +func (s *State) beforeSave() {} + +// +checklocksignore +func (s *State) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + stateSinkObject.Save(0, &s.Regs) + stateSinkObject.Save(1, &s.fpState) + stateSinkObject.Save(2, &s.FeatureSet) + stateSinkObject.Save(3, &s.OrigR0) +} + +func (s *State) afterLoad() {} + +// +checklocksignore +func (s *State) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.Regs) + stateSourceObject.LoadWait(1, &s.fpState) + stateSourceObject.Load(2, &s.FeatureSet) + stateSourceObject.Load(3, &s.OrigR0) +} + +func init() { + state.Register((*Registers)(nil)) + state.Register((*State)(nil)) +} diff --git a/pkg/sentry/arch/arch_abi_autogen_unsafe.go b/pkg/sentry/arch/arch_abi_autogen_unsafe.go new file mode 100644 index 000000000..2167ed9c9 --- /dev/null +++ b/pkg/sentry/arch/arch_abi_autogen_unsafe.go @@ -0,0 +1,7 @@ +// Automatically generated marshal implementation. See tools/go_marshal. + +package arch + +import ( +) + diff --git a/pkg/sentry/arch/arch_amd64_abi_autogen_unsafe.go b/pkg/sentry/arch/arch_amd64_abi_autogen_unsafe.go new file mode 100644 index 000000000..dd850b1b8 --- /dev/null +++ b/pkg/sentry/arch/arch_amd64_abi_autogen_unsafe.go @@ -0,0 +1,411 @@ +// Automatically generated marshal implementation. See tools/go_marshal. + +// If there are issues with build constraint aggregation, see +// tools/go_marshal/gomarshal/generator.go:writeHeader(). The constraints here +// come from the input set of files used to generate this file. This input set +// is filtered based on pre-defined file suffixes related to build constraints, +// see tools/defs.bzl:calculate_sets(). + +//go:build amd64 && amd64 && amd64 +// +build amd64,amd64,amd64 + +package arch + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/gohacks" + "gvisor.dev/gvisor/pkg/hostarch" + "gvisor.dev/gvisor/pkg/marshal" + "io" + "reflect" + "runtime" + "unsafe" +) + +// Marshallable types used by this file. +var _ marshal.Marshallable = (*SignalContext64)(nil) +var _ marshal.Marshallable = (*UContext64)(nil) +var _ marshal.Marshallable = (*linux.SignalSet)(nil) +var _ marshal.Marshallable = (*linux.SignalStack)(nil) + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (s *SignalContext64) SizeBytes() int { + return 184 + + (*linux.SignalSet)(nil).SizeBytes() + + 8*8 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (s *SignalContext64) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint64(dst[:8], uint64(s.R8)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(s.R9)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(s.R10)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(s.R11)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(s.R12)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(s.R13)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(s.R14)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(s.R15)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(s.Rdi)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(s.Rsi)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(s.Rbp)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(s.Rbx)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(s.Rdx)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(s.Rax)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(s.Rcx)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(s.Rsp)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(s.Rip)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(s.Eflags)) + dst = dst[8:] + hostarch.ByteOrder.PutUint16(dst[:2], uint16(s.Cs)) + dst = dst[2:] + hostarch.ByteOrder.PutUint16(dst[:2], uint16(s.Gs)) + dst = dst[2:] + hostarch.ByteOrder.PutUint16(dst[:2], uint16(s.Fs)) + dst = dst[2:] + hostarch.ByteOrder.PutUint16(dst[:2], uint16(s.Ss)) + dst = dst[2:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(s.Err)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(s.Trapno)) + dst = dst[8:] + s.Oldmask.MarshalBytes(dst[:s.Oldmask.SizeBytes()]) + dst = dst[s.Oldmask.SizeBytes():] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(s.Cr2)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(s.Fpstate)) + dst = dst[8:] + for idx := 0; idx < 8; idx++ { + hostarch.ByteOrder.PutUint64(dst[:8], uint64(s.Reserved[idx])) + dst = dst[8:] + } +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (s *SignalContext64) UnmarshalBytes(src []byte) { + s.R8 = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.R9 = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.R10 = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.R11 = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.R12 = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.R13 = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.R14 = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.R15 = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.Rdi = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.Rsi = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.Rbp = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.Rbx = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.Rdx = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.Rax = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.Rcx = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.Rsp = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.Rip = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.Eflags = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.Cs = uint16(hostarch.ByteOrder.Uint16(src[:2])) + src = src[2:] + s.Gs = uint16(hostarch.ByteOrder.Uint16(src[:2])) + src = src[2:] + s.Fs = uint16(hostarch.ByteOrder.Uint16(src[:2])) + src = src[2:] + s.Ss = uint16(hostarch.ByteOrder.Uint16(src[:2])) + src = src[2:] + s.Err = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.Trapno = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.Oldmask.UnmarshalBytes(src[:s.Oldmask.SizeBytes()]) + src = src[s.Oldmask.SizeBytes():] + s.Cr2 = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.Fpstate = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + for idx := 0; idx < 8; idx++ { + s.Reserved[idx] = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + } +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (s *SignalContext64) Packed() bool { + return s.Oldmask.Packed() +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (s *SignalContext64) MarshalUnsafe(dst []byte) { + if s.Oldmask.Packed() { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(s), uintptr(s.SizeBytes())) + } else { + // Type SignalContext64 doesn't have a packed layout in memory, fallback to MarshalBytes. + s.MarshalBytes(dst) + } +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (s *SignalContext64) UnmarshalUnsafe(src []byte) { + if s.Oldmask.Packed() { + gohacks.Memmove(unsafe.Pointer(s), unsafe.Pointer(&src[0]), uintptr(s.SizeBytes())) + } else { + // Type SignalContext64 doesn't have a packed layout in memory, fallback to UnmarshalBytes. + s.UnmarshalBytes(src) + } +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (s *SignalContext64) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + if !s.Oldmask.Packed() { + // Type SignalContext64 doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := cc.CopyScratchBuffer(s.SizeBytes()) // escapes: okay. + s.MarshalBytes(buf) // escapes: fallback. + return cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(s))) + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until the use above. + runtime.KeepAlive(s) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (s *SignalContext64) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return s.CopyOutN(cc, addr, s.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (s *SignalContext64) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + if !s.Oldmask.Packed() { + // Type SignalContext64 doesn't have a packed layout in memory, fall back to UnmarshalBytes. + buf := cc.CopyScratchBuffer(s.SizeBytes()) // escapes: okay. + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Unmarshal unconditionally. If we had a short copy-in, this results in a + // partially unmarshalled struct. + s.UnmarshalBytes(buf) // escapes: fallback. + return length, err + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(s))) + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until the use above. + runtime.KeepAlive(s) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (s *SignalContext64) WriteTo(writer io.Writer) (int64, error) { + if !s.Oldmask.Packed() { + // Type SignalContext64 doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := make([]byte, s.SizeBytes()) + s.MarshalBytes(buf) + length, err := writer.Write(buf) + return int64(length), err + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(s))) + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until the use above. + runtime.KeepAlive(s) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (u *UContext64) SizeBytes() int { + return 16 + + (*linux.SignalStack)(nil).SizeBytes() + + (*SignalContext64)(nil).SizeBytes() + + (*linux.SignalSet)(nil).SizeBytes() +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (u *UContext64) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint64(dst[:8], uint64(u.Flags)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(u.Link)) + dst = dst[8:] + u.Stack.MarshalBytes(dst[:u.Stack.SizeBytes()]) + dst = dst[u.Stack.SizeBytes():] + u.MContext.MarshalBytes(dst[:u.MContext.SizeBytes()]) + dst = dst[u.MContext.SizeBytes():] + u.Sigset.MarshalBytes(dst[:u.Sigset.SizeBytes()]) + dst = dst[u.Sigset.SizeBytes():] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (u *UContext64) UnmarshalBytes(src []byte) { + u.Flags = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + u.Link = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + u.Stack.UnmarshalBytes(src[:u.Stack.SizeBytes()]) + src = src[u.Stack.SizeBytes():] + u.MContext.UnmarshalBytes(src[:u.MContext.SizeBytes()]) + src = src[u.MContext.SizeBytes():] + u.Sigset.UnmarshalBytes(src[:u.Sigset.SizeBytes()]) + src = src[u.Sigset.SizeBytes():] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (u *UContext64) Packed() bool { + return u.MContext.Packed() && u.Sigset.Packed() && u.Stack.Packed() +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (u *UContext64) MarshalUnsafe(dst []byte) { + if u.MContext.Packed() && u.Sigset.Packed() && u.Stack.Packed() { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(u), uintptr(u.SizeBytes())) + } else { + // Type UContext64 doesn't have a packed layout in memory, fallback to MarshalBytes. + u.MarshalBytes(dst) + } +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (u *UContext64) UnmarshalUnsafe(src []byte) { + if u.MContext.Packed() && u.Sigset.Packed() && u.Stack.Packed() { + gohacks.Memmove(unsafe.Pointer(u), unsafe.Pointer(&src[0]), uintptr(u.SizeBytes())) + } else { + // Type UContext64 doesn't have a packed layout in memory, fallback to UnmarshalBytes. + u.UnmarshalBytes(src) + } +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (u *UContext64) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + if !u.MContext.Packed() && u.Sigset.Packed() && u.Stack.Packed() { + // Type UContext64 doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := cc.CopyScratchBuffer(u.SizeBytes()) // escapes: okay. + u.MarshalBytes(buf) // escapes: fallback. + return cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(u))) + hdr.Len = u.SizeBytes() + hdr.Cap = u.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that u + // must live until the use above. + runtime.KeepAlive(u) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (u *UContext64) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return u.CopyOutN(cc, addr, u.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (u *UContext64) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + if !u.MContext.Packed() && u.Sigset.Packed() && u.Stack.Packed() { + // Type UContext64 doesn't have a packed layout in memory, fall back to UnmarshalBytes. + buf := cc.CopyScratchBuffer(u.SizeBytes()) // escapes: okay. + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Unmarshal unconditionally. If we had a short copy-in, this results in a + // partially unmarshalled struct. + u.UnmarshalBytes(buf) // escapes: fallback. + return length, err + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(u))) + hdr.Len = u.SizeBytes() + hdr.Cap = u.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that u + // must live until the use above. + runtime.KeepAlive(u) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (u *UContext64) WriteTo(writer io.Writer) (int64, error) { + if !u.MContext.Packed() && u.Sigset.Packed() && u.Stack.Packed() { + // Type UContext64 doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := make([]byte, u.SizeBytes()) + u.MarshalBytes(buf) + length, err := writer.Write(buf) + return int64(length), err + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(u))) + hdr.Len = u.SizeBytes() + hdr.Cap = u.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that u + // must live until the use above. + runtime.KeepAlive(u) // escapes: replaced by intrinsic. + return int64(length), err +} + diff --git a/pkg/sentry/arch/arch_amd64_state_autogen.go b/pkg/sentry/arch/arch_amd64_state_autogen.go new file mode 100644 index 000000000..501572f2c --- /dev/null +++ b/pkg/sentry/arch/arch_amd64_state_autogen.go @@ -0,0 +1,42 @@ +// automatically generated by stateify. + +//go:build amd64 && amd64 && amd64 +// +build amd64,amd64,amd64 + +package arch + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (c *context64) StateTypeName() string { + return "pkg/sentry/arch.context64" +} + +func (c *context64) StateFields() []string { + return []string{ + "State", + "sigFPState", + } +} + +func (c *context64) beforeSave() {} + +// +checklocksignore +func (c *context64) StateSave(stateSinkObject state.Sink) { + c.beforeSave() + stateSinkObject.Save(0, &c.State) + stateSinkObject.Save(1, &c.sigFPState) +} + +func (c *context64) afterLoad() {} + +// +checklocksignore +func (c *context64) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &c.State) + stateSourceObject.Load(1, &c.sigFPState) +} + +func init() { + state.Register((*context64)(nil)) +} diff --git a/pkg/sentry/arch/arch_arm64_abi_autogen_unsafe.go b/pkg/sentry/arch/arch_arm64_abi_autogen_unsafe.go new file mode 100644 index 000000000..52ca0b185 --- /dev/null +++ b/pkg/sentry/arch/arch_arm64_abi_autogen_unsafe.go @@ -0,0 +1,587 @@ +// Automatically generated marshal implementation. See tools/go_marshal. + +// If there are issues with build constraint aggregation, see +// tools/go_marshal/gomarshal/generator.go:writeHeader(). The constraints here +// come from the input set of files used to generate this file. This input set +// is filtered based on pre-defined file suffixes related to build constraints, +// see tools/defs.bzl:calculate_sets(). + +//go:build arm64 && arm64 && arm64 +// +build arm64,arm64,arm64 + +package arch + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/gohacks" + "gvisor.dev/gvisor/pkg/hostarch" + "gvisor.dev/gvisor/pkg/marshal" + "io" + "reflect" + "runtime" + "unsafe" +) + +// Marshallable types used by this file. +var _ marshal.Marshallable = (*FpsimdContext)(nil) +var _ marshal.Marshallable = (*SignalContext64)(nil) +var _ marshal.Marshallable = (*UContext64)(nil) +var _ marshal.Marshallable = (*aarch64Ctx)(nil) +var _ marshal.Marshallable = (*linux.SignalSet)(nil) +var _ marshal.Marshallable = (*linux.SignalStack)(nil) + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (f *FpsimdContext) SizeBytes() int { + return 8 + + (*aarch64Ctx)(nil).SizeBytes() + + 8*64 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (f *FpsimdContext) MarshalBytes(dst []byte) { + f.Head.MarshalBytes(dst[:f.Head.SizeBytes()]) + dst = dst[f.Head.SizeBytes():] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(f.Fpsr)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(f.Fpcr)) + dst = dst[4:] + for idx := 0; idx < 64; idx++ { + hostarch.ByteOrder.PutUint64(dst[:8], uint64(f.Vregs[idx])) + dst = dst[8:] + } +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (f *FpsimdContext) UnmarshalBytes(src []byte) { + f.Head.UnmarshalBytes(src[:f.Head.SizeBytes()]) + src = src[f.Head.SizeBytes():] + f.Fpsr = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + f.Fpcr = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + for idx := 0; idx < 64; idx++ { + f.Vregs[idx] = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + } +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (f *FpsimdContext) Packed() bool { + return f.Head.Packed() +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (f *FpsimdContext) MarshalUnsafe(dst []byte) { + if f.Head.Packed() { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(f), uintptr(f.SizeBytes())) + } else { + // Type FpsimdContext doesn't have a packed layout in memory, fallback to MarshalBytes. + f.MarshalBytes(dst) + } +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (f *FpsimdContext) UnmarshalUnsafe(src []byte) { + if f.Head.Packed() { + gohacks.Memmove(unsafe.Pointer(f), unsafe.Pointer(&src[0]), uintptr(f.SizeBytes())) + } else { + // Type FpsimdContext doesn't have a packed layout in memory, fallback to UnmarshalBytes. + f.UnmarshalBytes(src) + } +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (f *FpsimdContext) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + if !f.Head.Packed() { + // Type FpsimdContext doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := cc.CopyScratchBuffer(f.SizeBytes()) // escapes: okay. + f.MarshalBytes(buf) // escapes: fallback. + return cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(f))) + hdr.Len = f.SizeBytes() + hdr.Cap = f.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that f + // must live until the use above. + runtime.KeepAlive(f) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (f *FpsimdContext) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return f.CopyOutN(cc, addr, f.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (f *FpsimdContext) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + if !f.Head.Packed() { + // Type FpsimdContext doesn't have a packed layout in memory, fall back to UnmarshalBytes. + buf := cc.CopyScratchBuffer(f.SizeBytes()) // escapes: okay. + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Unmarshal unconditionally. If we had a short copy-in, this results in a + // partially unmarshalled struct. + f.UnmarshalBytes(buf) // escapes: fallback. + return length, err + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(f))) + hdr.Len = f.SizeBytes() + hdr.Cap = f.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that f + // must live until the use above. + runtime.KeepAlive(f) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (f *FpsimdContext) WriteTo(writer io.Writer) (int64, error) { + if !f.Head.Packed() { + // Type FpsimdContext doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := make([]byte, f.SizeBytes()) + f.MarshalBytes(buf) + length, err := writer.Write(buf) + return int64(length), err + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(f))) + hdr.Len = f.SizeBytes() + hdr.Cap = f.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that f + // must live until the use above. + runtime.KeepAlive(f) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (s *SignalContext64) SizeBytes() int { + return 32 + + 8*31 + + 1*8 + + (*FpsimdContext)(nil).SizeBytes() +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (s *SignalContext64) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint64(dst[:8], uint64(s.FaultAddr)) + dst = dst[8:] + for idx := 0; idx < 31; idx++ { + hostarch.ByteOrder.PutUint64(dst[:8], uint64(s.Regs[idx])) + dst = dst[8:] + } + hostarch.ByteOrder.PutUint64(dst[:8], uint64(s.Sp)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(s.Pc)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(s.Pstate)) + dst = dst[8:] + for idx := 0; idx < 8; idx++ { + dst[0] = byte(s._pad[idx]) + dst = dst[1:] + } + s.Fpsimd64.MarshalBytes(dst[:s.Fpsimd64.SizeBytes()]) + dst = dst[s.Fpsimd64.SizeBytes():] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (s *SignalContext64) UnmarshalBytes(src []byte) { + s.FaultAddr = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + for idx := 0; idx < 31; idx++ { + s.Regs[idx] = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + } + s.Sp = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.Pc = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.Pstate = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + for idx := 0; idx < 8; idx++ { + s._pad[idx] = src[0] + src = src[1:] + } + s.Fpsimd64.UnmarshalBytes(src[:s.Fpsimd64.SizeBytes()]) + src = src[s.Fpsimd64.SizeBytes():] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (s *SignalContext64) Packed() bool { + return s.Fpsimd64.Packed() +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (s *SignalContext64) MarshalUnsafe(dst []byte) { + if s.Fpsimd64.Packed() { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(s), uintptr(s.SizeBytes())) + } else { + // Type SignalContext64 doesn't have a packed layout in memory, fallback to MarshalBytes. + s.MarshalBytes(dst) + } +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (s *SignalContext64) UnmarshalUnsafe(src []byte) { + if s.Fpsimd64.Packed() { + gohacks.Memmove(unsafe.Pointer(s), unsafe.Pointer(&src[0]), uintptr(s.SizeBytes())) + } else { + // Type SignalContext64 doesn't have a packed layout in memory, fallback to UnmarshalBytes. + s.UnmarshalBytes(src) + } +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (s *SignalContext64) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + if !s.Fpsimd64.Packed() { + // Type SignalContext64 doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := cc.CopyScratchBuffer(s.SizeBytes()) // escapes: okay. + s.MarshalBytes(buf) // escapes: fallback. + return cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(s))) + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until the use above. + runtime.KeepAlive(s) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (s *SignalContext64) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return s.CopyOutN(cc, addr, s.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (s *SignalContext64) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + if !s.Fpsimd64.Packed() { + // Type SignalContext64 doesn't have a packed layout in memory, fall back to UnmarshalBytes. + buf := cc.CopyScratchBuffer(s.SizeBytes()) // escapes: okay. + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Unmarshal unconditionally. If we had a short copy-in, this results in a + // partially unmarshalled struct. + s.UnmarshalBytes(buf) // escapes: fallback. + return length, err + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(s))) + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until the use above. + runtime.KeepAlive(s) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (s *SignalContext64) WriteTo(writer io.Writer) (int64, error) { + if !s.Fpsimd64.Packed() { + // Type SignalContext64 doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := make([]byte, s.SizeBytes()) + s.MarshalBytes(buf) + length, err := writer.Write(buf) + return int64(length), err + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(s))) + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until the use above. + runtime.KeepAlive(s) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (u *UContext64) SizeBytes() int { + return 16 + + (*linux.SignalStack)(nil).SizeBytes() + + (*linux.SignalSet)(nil).SizeBytes() + + 1*120 + + 1*8 + + (*SignalContext64)(nil).SizeBytes() +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (u *UContext64) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint64(dst[:8], uint64(u.Flags)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(u.Link)) + dst = dst[8:] + u.Stack.MarshalBytes(dst[:u.Stack.SizeBytes()]) + dst = dst[u.Stack.SizeBytes():] + u.Sigset.MarshalBytes(dst[:u.Sigset.SizeBytes()]) + dst = dst[u.Sigset.SizeBytes():] + for idx := 0; idx < 120; idx++ { + dst[0] = byte(u._pad[idx]) + dst = dst[1:] + } + for idx := 0; idx < 8; idx++ { + dst[0] = byte(u._pad2[idx]) + dst = dst[1:] + } + u.MContext.MarshalBytes(dst[:u.MContext.SizeBytes()]) + dst = dst[u.MContext.SizeBytes():] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (u *UContext64) UnmarshalBytes(src []byte) { + u.Flags = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + u.Link = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + u.Stack.UnmarshalBytes(src[:u.Stack.SizeBytes()]) + src = src[u.Stack.SizeBytes():] + u.Sigset.UnmarshalBytes(src[:u.Sigset.SizeBytes()]) + src = src[u.Sigset.SizeBytes():] + for idx := 0; idx < 120; idx++ { + u._pad[idx] = src[0] + src = src[1:] + } + for idx := 0; idx < 8; idx++ { + u._pad2[idx] = src[0] + src = src[1:] + } + u.MContext.UnmarshalBytes(src[:u.MContext.SizeBytes()]) + src = src[u.MContext.SizeBytes():] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (u *UContext64) Packed() bool { + return u.MContext.Packed() && u.Sigset.Packed() && u.Stack.Packed() +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (u *UContext64) MarshalUnsafe(dst []byte) { + if u.MContext.Packed() && u.Sigset.Packed() && u.Stack.Packed() { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(u), uintptr(u.SizeBytes())) + } else { + // Type UContext64 doesn't have a packed layout in memory, fallback to MarshalBytes. + u.MarshalBytes(dst) + } +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (u *UContext64) UnmarshalUnsafe(src []byte) { + if u.MContext.Packed() && u.Sigset.Packed() && u.Stack.Packed() { + gohacks.Memmove(unsafe.Pointer(u), unsafe.Pointer(&src[0]), uintptr(u.SizeBytes())) + } else { + // Type UContext64 doesn't have a packed layout in memory, fallback to UnmarshalBytes. + u.UnmarshalBytes(src) + } +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (u *UContext64) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + if !u.MContext.Packed() && u.Sigset.Packed() && u.Stack.Packed() { + // Type UContext64 doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := cc.CopyScratchBuffer(u.SizeBytes()) // escapes: okay. + u.MarshalBytes(buf) // escapes: fallback. + return cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(u))) + hdr.Len = u.SizeBytes() + hdr.Cap = u.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that u + // must live until the use above. + runtime.KeepAlive(u) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (u *UContext64) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return u.CopyOutN(cc, addr, u.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (u *UContext64) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + if !u.MContext.Packed() && u.Sigset.Packed() && u.Stack.Packed() { + // Type UContext64 doesn't have a packed layout in memory, fall back to UnmarshalBytes. + buf := cc.CopyScratchBuffer(u.SizeBytes()) // escapes: okay. + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Unmarshal unconditionally. If we had a short copy-in, this results in a + // partially unmarshalled struct. + u.UnmarshalBytes(buf) // escapes: fallback. + return length, err + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(u))) + hdr.Len = u.SizeBytes() + hdr.Cap = u.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that u + // must live until the use above. + runtime.KeepAlive(u) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (u *UContext64) WriteTo(writer io.Writer) (int64, error) { + if !u.MContext.Packed() && u.Sigset.Packed() && u.Stack.Packed() { + // Type UContext64 doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := make([]byte, u.SizeBytes()) + u.MarshalBytes(buf) + length, err := writer.Write(buf) + return int64(length), err + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(u))) + hdr.Len = u.SizeBytes() + hdr.Cap = u.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that u + // must live until the use above. + runtime.KeepAlive(u) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (a *aarch64Ctx) SizeBytes() int { + return 8 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (a *aarch64Ctx) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint32(dst[:4], uint32(a.Magic)) + dst = dst[4:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(a.Size)) + dst = dst[4:] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (a *aarch64Ctx) UnmarshalBytes(src []byte) { + a.Magic = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + a.Size = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (a *aarch64Ctx) Packed() bool { + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (a *aarch64Ctx) MarshalUnsafe(dst []byte) { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(a), uintptr(a.SizeBytes())) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (a *aarch64Ctx) UnmarshalUnsafe(src []byte) { + gohacks.Memmove(unsafe.Pointer(a), unsafe.Pointer(&src[0]), uintptr(a.SizeBytes())) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (a *aarch64Ctx) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(a))) + hdr.Len = a.SizeBytes() + hdr.Cap = a.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that a + // must live until the use above. + runtime.KeepAlive(a) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (a *aarch64Ctx) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return a.CopyOutN(cc, addr, a.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (a *aarch64Ctx) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(a))) + hdr.Len = a.SizeBytes() + hdr.Cap = a.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that a + // must live until the use above. + runtime.KeepAlive(a) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (a *aarch64Ctx) WriteTo(writer io.Writer) (int64, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(a))) + hdr.Len = a.SizeBytes() + hdr.Cap = a.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that a + // must live until the use above. + runtime.KeepAlive(a) // escapes: replaced by intrinsic. + return int64(length), err +} + diff --git a/pkg/sentry/arch/arch_arm64_state_autogen.go b/pkg/sentry/arch/arch_arm64_state_autogen.go new file mode 100644 index 000000000..45bfbfca1 --- /dev/null +++ b/pkg/sentry/arch/arch_arm64_state_autogen.go @@ -0,0 +1,42 @@ +// automatically generated by stateify. + +//go:build arm64 && arm64 && arm64 +// +build arm64,arm64,arm64 + +package arch + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (c *context64) StateTypeName() string { + return "pkg/sentry/arch.context64" +} + +func (c *context64) StateFields() []string { + return []string{ + "State", + "sigFPState", + } +} + +func (c *context64) beforeSave() {} + +// +checklocksignore +func (c *context64) StateSave(stateSinkObject state.Sink) { + c.beforeSave() + stateSinkObject.Save(0, &c.State) + stateSinkObject.Save(1, &c.sigFPState) +} + +func (c *context64) afterLoad() {} + +// +checklocksignore +func (c *context64) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &c.State) + stateSourceObject.Load(1, &c.sigFPState) +} + +func init() { + state.Register((*context64)(nil)) +} diff --git a/pkg/sentry/arch/arch_state_autogen.go b/pkg/sentry/arch/arch_state_autogen.go new file mode 100644 index 000000000..8a151f95c --- /dev/null +++ b/pkg/sentry/arch/arch_state_autogen.go @@ -0,0 +1,80 @@ +// automatically generated by stateify. + +package arch + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (m *MmapLayout) StateTypeName() string { + return "pkg/sentry/arch.MmapLayout" +} + +func (m *MmapLayout) StateFields() []string { + return []string{ + "MinAddr", + "MaxAddr", + "BottomUpBase", + "TopDownBase", + "DefaultDirection", + "MaxStackRand", + } +} + +func (m *MmapLayout) beforeSave() {} + +// +checklocksignore +func (m *MmapLayout) StateSave(stateSinkObject state.Sink) { + m.beforeSave() + stateSinkObject.Save(0, &m.MinAddr) + stateSinkObject.Save(1, &m.MaxAddr) + stateSinkObject.Save(2, &m.BottomUpBase) + stateSinkObject.Save(3, &m.TopDownBase) + stateSinkObject.Save(4, &m.DefaultDirection) + stateSinkObject.Save(5, &m.MaxStackRand) +} + +func (m *MmapLayout) afterLoad() {} + +// +checklocksignore +func (m *MmapLayout) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &m.MinAddr) + stateSourceObject.Load(1, &m.MaxAddr) + stateSourceObject.Load(2, &m.BottomUpBase) + stateSourceObject.Load(3, &m.TopDownBase) + stateSourceObject.Load(4, &m.DefaultDirection) + stateSourceObject.Load(5, &m.MaxStackRand) +} + +func (a *AuxEntry) StateTypeName() string { + return "pkg/sentry/arch.AuxEntry" +} + +func (a *AuxEntry) StateFields() []string { + return []string{ + "Key", + "Value", + } +} + +func (a *AuxEntry) beforeSave() {} + +// +checklocksignore +func (a *AuxEntry) StateSave(stateSinkObject state.Sink) { + a.beforeSave() + stateSinkObject.Save(0, &a.Key) + stateSinkObject.Save(1, &a.Value) +} + +func (a *AuxEntry) afterLoad() {} + +// +checklocksignore +func (a *AuxEntry) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &a.Key) + stateSourceObject.Load(1, &a.Value) +} + +func init() { + state.Register((*MmapLayout)(nil)) + state.Register((*AuxEntry)(nil)) +} diff --git a/pkg/sentry/arch/arch_unsafe_abi_autogen_unsafe.go b/pkg/sentry/arch/arch_unsafe_abi_autogen_unsafe.go new file mode 100644 index 000000000..2167ed9c9 --- /dev/null +++ b/pkg/sentry/arch/arch_unsafe_abi_autogen_unsafe.go @@ -0,0 +1,7 @@ +// Automatically generated marshal implementation. See tools/go_marshal. + +package arch + +import ( +) + diff --git a/pkg/sentry/arch/arch_unsafe_state_autogen.go b/pkg/sentry/arch/arch_unsafe_state_autogen.go new file mode 100644 index 000000000..9a45071b9 --- /dev/null +++ b/pkg/sentry/arch/arch_unsafe_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package arch diff --git a/pkg/sentry/arch/arch_x86_abi_autogen_unsafe.go b/pkg/sentry/arch/arch_x86_abi_autogen_unsafe.go new file mode 100644 index 000000000..97a463c9f --- /dev/null +++ b/pkg/sentry/arch/arch_x86_abi_autogen_unsafe.go @@ -0,0 +1,17 @@ +// Automatically generated marshal implementation. See tools/go_marshal. + +// If there are issues with build constraint aggregation, see +// tools/go_marshal/gomarshal/generator.go:writeHeader(). The constraints here +// come from the input set of files used to generate this file. This input set +// is filtered based on pre-defined file suffixes related to build constraints, +// see tools/defs.bzl:calculate_sets(). + +//go:build (amd64 || 386) && (amd64 || 386) +// +build amd64 386 +// +build amd64 386 + +package arch + +import ( +) + diff --git a/pkg/sentry/arch/arch_x86_impl_abi_autogen_unsafe.go b/pkg/sentry/arch/arch_x86_impl_abi_autogen_unsafe.go new file mode 100644 index 000000000..2b2c59d95 --- /dev/null +++ b/pkg/sentry/arch/arch_x86_impl_abi_autogen_unsafe.go @@ -0,0 +1,17 @@ +// Automatically generated marshal implementation. See tools/go_marshal. + +// If there are issues with build constraint aggregation, see +// tools/go_marshal/gomarshal/generator.go:writeHeader(). The constraints here +// come from the input set of files used to generate this file. This input set +// is filtered based on pre-defined file suffixes related to build constraints, +// see tools/defs.bzl:calculate_sets(). + +//go:build (amd64 || 386) && go1.1 +// +build amd64 386 +// +build go1.1 + +package arch + +import ( +) + diff --git a/pkg/sentry/arch/arch_x86_impl_state_autogen.go b/pkg/sentry/arch/arch_x86_impl_state_autogen.go new file mode 100644 index 000000000..edcc2bee5 --- /dev/null +++ b/pkg/sentry/arch/arch_x86_impl_state_autogen.go @@ -0,0 +1,45 @@ +// automatically generated by stateify. + +//go:build (amd64 || 386) && go1.1 +// +build amd64 386 +// +build go1.1 + +package arch + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (s *State) StateTypeName() string { + return "pkg/sentry/arch.State" +} + +func (s *State) StateFields() []string { + return []string{ + "Regs", + "fpState", + "FeatureSet", + } +} + +func (s *State) beforeSave() {} + +// +checklocksignore +func (s *State) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + stateSinkObject.Save(0, &s.Regs) + stateSinkObject.Save(1, &s.fpState) + stateSinkObject.Save(2, &s.FeatureSet) +} + +// +checklocksignore +func (s *State) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.Regs) + stateSourceObject.LoadWait(1, &s.fpState) + stateSourceObject.Load(2, &s.FeatureSet) + stateSourceObject.AfterLoad(s.afterLoad) +} + +func init() { + state.Register((*State)(nil)) +} diff --git a/pkg/sentry/arch/arch_x86_state_autogen.go b/pkg/sentry/arch/arch_x86_state_autogen.go new file mode 100644 index 000000000..7785fb645 --- /dev/null +++ b/pkg/sentry/arch/arch_x86_state_autogen.go @@ -0,0 +1,40 @@ +// automatically generated by stateify. + +//go:build (amd64 || 386) && (amd64 || 386) +// +build amd64 386 +// +build amd64 386 + +package arch + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (r *Registers) StateTypeName() string { + return "pkg/sentry/arch.Registers" +} + +func (r *Registers) StateFields() []string { + return []string{ + "PtraceRegs", + } +} + +func (r *Registers) beforeSave() {} + +// +checklocksignore +func (r *Registers) StateSave(stateSinkObject state.Sink) { + r.beforeSave() + stateSinkObject.Save(0, &r.PtraceRegs) +} + +func (r *Registers) afterLoad() {} + +// +checklocksignore +func (r *Registers) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &r.PtraceRegs) +} + +func init() { + state.Register((*Registers)(nil)) +} diff --git a/pkg/sentry/arch/fpu/BUILD b/pkg/sentry/arch/fpu/BUILD deleted file mode 100644 index 6cdd21b1b..000000000 --- a/pkg/sentry/arch/fpu/BUILD +++ /dev/null @@ -1,21 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "fpu", - srcs = [ - "fpu.go", - "fpu_amd64.go", - "fpu_amd64.s", - "fpu_arm64.go", - ], - visibility = ["//:sandbox"], - deps = [ - "//pkg/cpuid", - "//pkg/errors/linuxerr", - "//pkg/hostarch", - "//pkg/sync", - "@org_golang_x_sys//unix:go_default_library", - ], -) diff --git a/pkg/sentry/arch/fpu/fpu_amd64_state_autogen.go b/pkg/sentry/arch/fpu/fpu_amd64_state_autogen.go new file mode 100644 index 000000000..d31975f85 --- /dev/null +++ b/pkg/sentry/arch/fpu/fpu_amd64_state_autogen.go @@ -0,0 +1,6 @@ +// automatically generated by stateify. + +//go:build amd64 || i386 +// +build amd64 i386 + +package fpu diff --git a/pkg/sentry/arch/fpu/fpu_arm64_state_autogen.go b/pkg/sentry/arch/fpu/fpu_arm64_state_autogen.go new file mode 100644 index 000000000..7b28aba81 --- /dev/null +++ b/pkg/sentry/arch/fpu/fpu_arm64_state_autogen.go @@ -0,0 +1,6 @@ +// automatically generated by stateify. + +//go:build arm64 +// +build arm64 + +package fpu diff --git a/pkg/sentry/arch/fpu/fpu_state_autogen.go b/pkg/sentry/arch/fpu/fpu_state_autogen.go new file mode 100644 index 000000000..7cadc6b8f --- /dev/null +++ b/pkg/sentry/arch/fpu/fpu_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package fpu diff --git a/pkg/sentry/arch/registers.proto b/pkg/sentry/arch/registers.proto deleted file mode 100644 index 2727ba08a..000000000 --- a/pkg/sentry/arch/registers.proto +++ /dev/null @@ -1,93 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -syntax = "proto3"; - -package gvisor; - -message AMD64Registers { - uint64 rax = 1; - uint64 rbx = 2; - uint64 rcx = 3; - uint64 rdx = 4; - uint64 rsi = 5; - uint64 rdi = 6; - uint64 rsp = 7; - uint64 rbp = 8; - - uint64 r8 = 9; - uint64 r9 = 10; - uint64 r10 = 11; - uint64 r11 = 12; - uint64 r12 = 13; - uint64 r13 = 14; - uint64 r14 = 15; - uint64 r15 = 16; - - uint64 rip = 17; - uint64 rflags = 18; - uint64 orig_rax = 19; - uint64 cs = 20; - uint64 ds = 21; - uint64 es = 22; - uint64 fs = 23; - uint64 gs = 24; - uint64 ss = 25; - uint64 fs_base = 26; - uint64 gs_base = 27; -} - -message ARM64Registers { - uint64 r0 = 1; - uint64 r1 = 2; - uint64 r2 = 3; - uint64 r3 = 4; - uint64 r4 = 5; - uint64 r5 = 6; - uint64 r6 = 7; - uint64 r7 = 8; - uint64 r8 = 9; - uint64 r9 = 10; - uint64 r10 = 11; - uint64 r11 = 12; - uint64 r12 = 13; - uint64 r13 = 14; - uint64 r14 = 15; - uint64 r15 = 16; - uint64 r16 = 17; - uint64 r17 = 18; - uint64 r18 = 19; - uint64 r19 = 20; - uint64 r20 = 21; - uint64 r21 = 22; - uint64 r22 = 23; - uint64 r23 = 24; - uint64 r24 = 25; - uint64 r25 = 26; - uint64 r26 = 27; - uint64 r27 = 28; - uint64 r28 = 29; - uint64 r29 = 30; - uint64 r30 = 31; - uint64 sp = 32; - uint64 pc = 33; - uint64 pstate = 34; - uint64 tls = 35; -} -message Registers { - oneof arch { - AMD64Registers amd64 = 1; - ARM64Registers arm64 = 2; - } -} diff --git a/pkg/sentry/arch/registers_go_proto/registers.pb.go b/pkg/sentry/arch/registers_go_proto/registers.pb.go new file mode 100644 index 000000000..0e73a10b8 --- /dev/null +++ b/pkg/sentry/arch/registers_go_proto/registers.pb.go @@ -0,0 +1,863 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.25.0 +// protoc v3.13.0 +// source: pkg/sentry/arch/registers.proto + +package gvisor + +import ( + proto "github.com/golang/protobuf/proto" + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +// This is a compile-time assertion that a sufficiently up-to-date version +// of the legacy proto package is being used. +const _ = proto.ProtoPackageIsVersion4 + +type AMD64Registers struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Rax uint64 `protobuf:"varint,1,opt,name=rax,proto3" json:"rax,omitempty"` + Rbx uint64 `protobuf:"varint,2,opt,name=rbx,proto3" json:"rbx,omitempty"` + Rcx uint64 `protobuf:"varint,3,opt,name=rcx,proto3" json:"rcx,omitempty"` + Rdx uint64 `protobuf:"varint,4,opt,name=rdx,proto3" json:"rdx,omitempty"` + Rsi uint64 `protobuf:"varint,5,opt,name=rsi,proto3" json:"rsi,omitempty"` + Rdi uint64 `protobuf:"varint,6,opt,name=rdi,proto3" json:"rdi,omitempty"` + Rsp uint64 `protobuf:"varint,7,opt,name=rsp,proto3" json:"rsp,omitempty"` + Rbp uint64 `protobuf:"varint,8,opt,name=rbp,proto3" json:"rbp,omitempty"` + R8 uint64 `protobuf:"varint,9,opt,name=r8,proto3" json:"r8,omitempty"` + R9 uint64 `protobuf:"varint,10,opt,name=r9,proto3" json:"r9,omitempty"` + R10 uint64 `protobuf:"varint,11,opt,name=r10,proto3" json:"r10,omitempty"` + R11 uint64 `protobuf:"varint,12,opt,name=r11,proto3" json:"r11,omitempty"` + R12 uint64 `protobuf:"varint,13,opt,name=r12,proto3" json:"r12,omitempty"` + R13 uint64 `protobuf:"varint,14,opt,name=r13,proto3" json:"r13,omitempty"` + R14 uint64 `protobuf:"varint,15,opt,name=r14,proto3" json:"r14,omitempty"` + R15 uint64 `protobuf:"varint,16,opt,name=r15,proto3" json:"r15,omitempty"` + Rip uint64 `protobuf:"varint,17,opt,name=rip,proto3" json:"rip,omitempty"` + Rflags uint64 `protobuf:"varint,18,opt,name=rflags,proto3" json:"rflags,omitempty"` + OrigRax uint64 `protobuf:"varint,19,opt,name=orig_rax,json=origRax,proto3" json:"orig_rax,omitempty"` + Cs uint64 `protobuf:"varint,20,opt,name=cs,proto3" json:"cs,omitempty"` + Ds uint64 `protobuf:"varint,21,opt,name=ds,proto3" json:"ds,omitempty"` + Es uint64 `protobuf:"varint,22,opt,name=es,proto3" json:"es,omitempty"` + Fs uint64 `protobuf:"varint,23,opt,name=fs,proto3" json:"fs,omitempty"` + Gs uint64 `protobuf:"varint,24,opt,name=gs,proto3" json:"gs,omitempty"` + Ss uint64 `protobuf:"varint,25,opt,name=ss,proto3" json:"ss,omitempty"` + FsBase uint64 `protobuf:"varint,26,opt,name=fs_base,json=fsBase,proto3" json:"fs_base,omitempty"` + GsBase uint64 `protobuf:"varint,27,opt,name=gs_base,json=gsBase,proto3" json:"gs_base,omitempty"` +} + +func (x *AMD64Registers) Reset() { + *x = AMD64Registers{} + if protoimpl.UnsafeEnabled { + mi := &file_pkg_sentry_arch_registers_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *AMD64Registers) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*AMD64Registers) ProtoMessage() {} + +func (x *AMD64Registers) ProtoReflect() protoreflect.Message { + mi := &file_pkg_sentry_arch_registers_proto_msgTypes[0] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use AMD64Registers.ProtoReflect.Descriptor instead. +func (*AMD64Registers) Descriptor() ([]byte, []int) { + return file_pkg_sentry_arch_registers_proto_rawDescGZIP(), []int{0} +} + +func (x *AMD64Registers) GetRax() uint64 { + if x != nil { + return x.Rax + } + return 0 +} + +func (x *AMD64Registers) GetRbx() uint64 { + if x != nil { + return x.Rbx + } + return 0 +} + +func (x *AMD64Registers) GetRcx() uint64 { + if x != nil { + return x.Rcx + } + return 0 +} + +func (x *AMD64Registers) GetRdx() uint64 { + if x != nil { + return x.Rdx + } + return 0 +} + +func (x *AMD64Registers) GetRsi() uint64 { + if x != nil { + return x.Rsi + } + return 0 +} + +func (x *AMD64Registers) GetRdi() uint64 { + if x != nil { + return x.Rdi + } + return 0 +} + +func (x *AMD64Registers) GetRsp() uint64 { + if x != nil { + return x.Rsp + } + return 0 +} + +func (x *AMD64Registers) GetRbp() uint64 { + if x != nil { + return x.Rbp + } + return 0 +} + +func (x *AMD64Registers) GetR8() uint64 { + if x != nil { + return x.R8 + } + return 0 +} + +func (x *AMD64Registers) GetR9() uint64 { + if x != nil { + return x.R9 + } + return 0 +} + +func (x *AMD64Registers) GetR10() uint64 { + if x != nil { + return x.R10 + } + return 0 +} + +func (x *AMD64Registers) GetR11() uint64 { + if x != nil { + return x.R11 + } + return 0 +} + +func (x *AMD64Registers) GetR12() uint64 { + if x != nil { + return x.R12 + } + return 0 +} + +func (x *AMD64Registers) GetR13() uint64 { + if x != nil { + return x.R13 + } + return 0 +} + +func (x *AMD64Registers) GetR14() uint64 { + if x != nil { + return x.R14 + } + return 0 +} + +func (x *AMD64Registers) GetR15() uint64 { + if x != nil { + return x.R15 + } + return 0 +} + +func (x *AMD64Registers) GetRip() uint64 { + if x != nil { + return x.Rip + } + return 0 +} + +func (x *AMD64Registers) GetRflags() uint64 { + if x != nil { + return x.Rflags + } + return 0 +} + +func (x *AMD64Registers) GetOrigRax() uint64 { + if x != nil { + return x.OrigRax + } + return 0 +} + +func (x *AMD64Registers) GetCs() uint64 { + if x != nil { + return x.Cs + } + return 0 +} + +func (x *AMD64Registers) GetDs() uint64 { + if x != nil { + return x.Ds + } + return 0 +} + +func (x *AMD64Registers) GetEs() uint64 { + if x != nil { + return x.Es + } + return 0 +} + +func (x *AMD64Registers) GetFs() uint64 { + if x != nil { + return x.Fs + } + return 0 +} + +func (x *AMD64Registers) GetGs() uint64 { + if x != nil { + return x.Gs + } + return 0 +} + +func (x *AMD64Registers) GetSs() uint64 { + if x != nil { + return x.Ss + } + return 0 +} + +func (x *AMD64Registers) GetFsBase() uint64 { + if x != nil { + return x.FsBase + } + return 0 +} + +func (x *AMD64Registers) GetGsBase() uint64 { + if x != nil { + return x.GsBase + } + return 0 +} + +type ARM64Registers struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + R0 uint64 `protobuf:"varint,1,opt,name=r0,proto3" json:"r0,omitempty"` + R1 uint64 `protobuf:"varint,2,opt,name=r1,proto3" json:"r1,omitempty"` + R2 uint64 `protobuf:"varint,3,opt,name=r2,proto3" json:"r2,omitempty"` + R3 uint64 `protobuf:"varint,4,opt,name=r3,proto3" json:"r3,omitempty"` + R4 uint64 `protobuf:"varint,5,opt,name=r4,proto3" json:"r4,omitempty"` + R5 uint64 `protobuf:"varint,6,opt,name=r5,proto3" json:"r5,omitempty"` + R6 uint64 `protobuf:"varint,7,opt,name=r6,proto3" json:"r6,omitempty"` + R7 uint64 `protobuf:"varint,8,opt,name=r7,proto3" json:"r7,omitempty"` + R8 uint64 `protobuf:"varint,9,opt,name=r8,proto3" json:"r8,omitempty"` + R9 uint64 `protobuf:"varint,10,opt,name=r9,proto3" json:"r9,omitempty"` + R10 uint64 `protobuf:"varint,11,opt,name=r10,proto3" json:"r10,omitempty"` + R11 uint64 `protobuf:"varint,12,opt,name=r11,proto3" json:"r11,omitempty"` + R12 uint64 `protobuf:"varint,13,opt,name=r12,proto3" json:"r12,omitempty"` + R13 uint64 `protobuf:"varint,14,opt,name=r13,proto3" json:"r13,omitempty"` + R14 uint64 `protobuf:"varint,15,opt,name=r14,proto3" json:"r14,omitempty"` + R15 uint64 `protobuf:"varint,16,opt,name=r15,proto3" json:"r15,omitempty"` + R16 uint64 `protobuf:"varint,17,opt,name=r16,proto3" json:"r16,omitempty"` + R17 uint64 `protobuf:"varint,18,opt,name=r17,proto3" json:"r17,omitempty"` + R18 uint64 `protobuf:"varint,19,opt,name=r18,proto3" json:"r18,omitempty"` + R19 uint64 `protobuf:"varint,20,opt,name=r19,proto3" json:"r19,omitempty"` + R20 uint64 `protobuf:"varint,21,opt,name=r20,proto3" json:"r20,omitempty"` + R21 uint64 `protobuf:"varint,22,opt,name=r21,proto3" json:"r21,omitempty"` + R22 uint64 `protobuf:"varint,23,opt,name=r22,proto3" json:"r22,omitempty"` + R23 uint64 `protobuf:"varint,24,opt,name=r23,proto3" json:"r23,omitempty"` + R24 uint64 `protobuf:"varint,25,opt,name=r24,proto3" json:"r24,omitempty"` + R25 uint64 `protobuf:"varint,26,opt,name=r25,proto3" json:"r25,omitempty"` + R26 uint64 `protobuf:"varint,27,opt,name=r26,proto3" json:"r26,omitempty"` + R27 uint64 `protobuf:"varint,28,opt,name=r27,proto3" json:"r27,omitempty"` + R28 uint64 `protobuf:"varint,29,opt,name=r28,proto3" json:"r28,omitempty"` + R29 uint64 `protobuf:"varint,30,opt,name=r29,proto3" json:"r29,omitempty"` + R30 uint64 `protobuf:"varint,31,opt,name=r30,proto3" json:"r30,omitempty"` + Sp uint64 `protobuf:"varint,32,opt,name=sp,proto3" json:"sp,omitempty"` + Pc uint64 `protobuf:"varint,33,opt,name=pc,proto3" json:"pc,omitempty"` + Pstate uint64 `protobuf:"varint,34,opt,name=pstate,proto3" json:"pstate,omitempty"` + Tls uint64 `protobuf:"varint,35,opt,name=tls,proto3" json:"tls,omitempty"` +} + +func (x *ARM64Registers) Reset() { + *x = ARM64Registers{} + if protoimpl.UnsafeEnabled { + mi := &file_pkg_sentry_arch_registers_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *ARM64Registers) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ARM64Registers) ProtoMessage() {} + +func (x *ARM64Registers) ProtoReflect() protoreflect.Message { + mi := &file_pkg_sentry_arch_registers_proto_msgTypes[1] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ARM64Registers.ProtoReflect.Descriptor instead. +func (*ARM64Registers) Descriptor() ([]byte, []int) { + return file_pkg_sentry_arch_registers_proto_rawDescGZIP(), []int{1} +} + +func (x *ARM64Registers) GetR0() uint64 { + if x != nil { + return x.R0 + } + return 0 +} + +func (x *ARM64Registers) GetR1() uint64 { + if x != nil { + return x.R1 + } + return 0 +} + +func (x *ARM64Registers) GetR2() uint64 { + if x != nil { + return x.R2 + } + return 0 +} + +func (x *ARM64Registers) GetR3() uint64 { + if x != nil { + return x.R3 + } + return 0 +} + +func (x *ARM64Registers) GetR4() uint64 { + if x != nil { + return x.R4 + } + return 0 +} + +func (x *ARM64Registers) GetR5() uint64 { + if x != nil { + return x.R5 + } + return 0 +} + +func (x *ARM64Registers) GetR6() uint64 { + if x != nil { + return x.R6 + } + return 0 +} + +func (x *ARM64Registers) GetR7() uint64 { + if x != nil { + return x.R7 + } + return 0 +} + +func (x *ARM64Registers) GetR8() uint64 { + if x != nil { + return x.R8 + } + return 0 +} + +func (x *ARM64Registers) GetR9() uint64 { + if x != nil { + return x.R9 + } + return 0 +} + +func (x *ARM64Registers) GetR10() uint64 { + if x != nil { + return x.R10 + } + return 0 +} + +func (x *ARM64Registers) GetR11() uint64 { + if x != nil { + return x.R11 + } + return 0 +} + +func (x *ARM64Registers) GetR12() uint64 { + if x != nil { + return x.R12 + } + return 0 +} + +func (x *ARM64Registers) GetR13() uint64 { + if x != nil { + return x.R13 + } + return 0 +} + +func (x *ARM64Registers) GetR14() uint64 { + if x != nil { + return x.R14 + } + return 0 +} + +func (x *ARM64Registers) GetR15() uint64 { + if x != nil { + return x.R15 + } + return 0 +} + +func (x *ARM64Registers) GetR16() uint64 { + if x != nil { + return x.R16 + } + return 0 +} + +func (x *ARM64Registers) GetR17() uint64 { + if x != nil { + return x.R17 + } + return 0 +} + +func (x *ARM64Registers) GetR18() uint64 { + if x != nil { + return x.R18 + } + return 0 +} + +func (x *ARM64Registers) GetR19() uint64 { + if x != nil { + return x.R19 + } + return 0 +} + +func (x *ARM64Registers) GetR20() uint64 { + if x != nil { + return x.R20 + } + return 0 +} + +func (x *ARM64Registers) GetR21() uint64 { + if x != nil { + return x.R21 + } + return 0 +} + +func (x *ARM64Registers) GetR22() uint64 { + if x != nil { + return x.R22 + } + return 0 +} + +func (x *ARM64Registers) GetR23() uint64 { + if x != nil { + return x.R23 + } + return 0 +} + +func (x *ARM64Registers) GetR24() uint64 { + if x != nil { + return x.R24 + } + return 0 +} + +func (x *ARM64Registers) GetR25() uint64 { + if x != nil { + return x.R25 + } + return 0 +} + +func (x *ARM64Registers) GetR26() uint64 { + if x != nil { + return x.R26 + } + return 0 +} + +func (x *ARM64Registers) GetR27() uint64 { + if x != nil { + return x.R27 + } + return 0 +} + +func (x *ARM64Registers) GetR28() uint64 { + if x != nil { + return x.R28 + } + return 0 +} + +func (x *ARM64Registers) GetR29() uint64 { + if x != nil { + return x.R29 + } + return 0 +} + +func (x *ARM64Registers) GetR30() uint64 { + if x != nil { + return x.R30 + } + return 0 +} + +func (x *ARM64Registers) GetSp() uint64 { + if x != nil { + return x.Sp + } + return 0 +} + +func (x *ARM64Registers) GetPc() uint64 { + if x != nil { + return x.Pc + } + return 0 +} + +func (x *ARM64Registers) GetPstate() uint64 { + if x != nil { + return x.Pstate + } + return 0 +} + +func (x *ARM64Registers) GetTls() uint64 { + if x != nil { + return x.Tls + } + return 0 +} + +type Registers struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + // Types that are assignable to Arch: + // *Registers_Amd64 + // *Registers_Arm64 + Arch isRegisters_Arch `protobuf_oneof:"arch"` +} + +func (x *Registers) Reset() { + *x = Registers{} + if protoimpl.UnsafeEnabled { + mi := &file_pkg_sentry_arch_registers_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *Registers) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Registers) ProtoMessage() {} + +func (x *Registers) ProtoReflect() protoreflect.Message { + mi := &file_pkg_sentry_arch_registers_proto_msgTypes[2] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Registers.ProtoReflect.Descriptor instead. +func (*Registers) Descriptor() ([]byte, []int) { + return file_pkg_sentry_arch_registers_proto_rawDescGZIP(), []int{2} +} + +func (m *Registers) GetArch() isRegisters_Arch { + if m != nil { + return m.Arch + } + return nil +} + +func (x *Registers) GetAmd64() *AMD64Registers { + if x, ok := x.GetArch().(*Registers_Amd64); ok { + return x.Amd64 + } + return nil +} + +func (x *Registers) GetArm64() *ARM64Registers { + if x, ok := x.GetArch().(*Registers_Arm64); ok { + return x.Arm64 + } + return nil +} + +type isRegisters_Arch interface { + isRegisters_Arch() +} + +type Registers_Amd64 struct { + Amd64 *AMD64Registers `protobuf:"bytes,1,opt,name=amd64,proto3,oneof"` +} + +type Registers_Arm64 struct { + Arm64 *ARM64Registers `protobuf:"bytes,2,opt,name=arm64,proto3,oneof"` +} + +func (*Registers_Amd64) isRegisters_Arch() {} + +func (*Registers_Arm64) isRegisters_Arch() {} + +var File_pkg_sentry_arch_registers_proto protoreflect.FileDescriptor + +var file_pkg_sentry_arch_registers_proto_rawDesc = []byte{ + 0x0a, 0x1f, 0x70, 0x6b, 0x67, 0x2f, 0x73, 0x65, 0x6e, 0x74, 0x72, 0x79, 0x2f, 0x61, 0x72, 0x63, + 0x68, 0x2f, 0x72, 0x65, 0x67, 0x69, 0x73, 0x74, 0x65, 0x72, 0x73, 0x2e, 0x70, 0x72, 0x6f, 0x74, + 0x6f, 0x12, 0x06, 0x67, 0x76, 0x69, 0x73, 0x6f, 0x72, 0x22, 0x83, 0x04, 0x0a, 0x0e, 0x41, 0x4d, + 0x44, 0x36, 0x34, 0x52, 0x65, 0x67, 0x69, 0x73, 0x74, 0x65, 0x72, 0x73, 0x12, 0x10, 0x0a, 0x03, + 0x72, 0x61, 0x78, 0x18, 0x01, 0x20, 0x01, 0x28, 0x04, 0x52, 0x03, 0x72, 0x61, 0x78, 0x12, 0x10, + 0x0a, 0x03, 0x72, 0x62, 0x78, 0x18, 0x02, 0x20, 0x01, 0x28, 0x04, 0x52, 0x03, 0x72, 0x62, 0x78, + 0x12, 0x10, 0x0a, 0x03, 0x72, 0x63, 0x78, 0x18, 0x03, 0x20, 0x01, 0x28, 0x04, 0x52, 0x03, 0x72, + 0x63, 0x78, 0x12, 0x10, 0x0a, 0x03, 0x72, 0x64, 0x78, 0x18, 0x04, 0x20, 0x01, 0x28, 0x04, 0x52, + 0x03, 0x72, 0x64, 0x78, 0x12, 0x10, 0x0a, 0x03, 0x72, 0x73, 0x69, 0x18, 0x05, 0x20, 0x01, 0x28, + 0x04, 0x52, 0x03, 0x72, 0x73, 0x69, 0x12, 0x10, 0x0a, 0x03, 0x72, 0x64, 0x69, 0x18, 0x06, 0x20, + 0x01, 0x28, 0x04, 0x52, 0x03, 0x72, 0x64, 0x69, 0x12, 0x10, 0x0a, 0x03, 0x72, 0x73, 0x70, 0x18, + 0x07, 0x20, 0x01, 0x28, 0x04, 0x52, 0x03, 0x72, 0x73, 0x70, 0x12, 0x10, 0x0a, 0x03, 0x72, 0x62, + 0x70, 0x18, 0x08, 0x20, 0x01, 0x28, 0x04, 0x52, 0x03, 0x72, 0x62, 0x70, 0x12, 0x0e, 0x0a, 0x02, + 0x72, 0x38, 0x18, 0x09, 0x20, 0x01, 0x28, 0x04, 0x52, 0x02, 0x72, 0x38, 0x12, 0x0e, 0x0a, 0x02, + 0x72, 0x39, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x04, 0x52, 0x02, 0x72, 0x39, 0x12, 0x10, 0x0a, 0x03, + 0x72, 0x31, 0x30, 0x18, 0x0b, 0x20, 0x01, 0x28, 0x04, 0x52, 0x03, 0x72, 0x31, 0x30, 0x12, 0x10, + 0x0a, 0x03, 0x72, 0x31, 0x31, 0x18, 0x0c, 0x20, 0x01, 0x28, 0x04, 0x52, 0x03, 0x72, 0x31, 0x31, + 0x12, 0x10, 0x0a, 0x03, 0x72, 0x31, 0x32, 0x18, 0x0d, 0x20, 0x01, 0x28, 0x04, 0x52, 0x03, 0x72, + 0x31, 0x32, 0x12, 0x10, 0x0a, 0x03, 0x72, 0x31, 0x33, 0x18, 0x0e, 0x20, 0x01, 0x28, 0x04, 0x52, + 0x03, 0x72, 0x31, 0x33, 0x12, 0x10, 0x0a, 0x03, 0x72, 0x31, 0x34, 0x18, 0x0f, 0x20, 0x01, 0x28, + 0x04, 0x52, 0x03, 0x72, 0x31, 0x34, 0x12, 0x10, 0x0a, 0x03, 0x72, 0x31, 0x35, 0x18, 0x10, 0x20, + 0x01, 0x28, 0x04, 0x52, 0x03, 0x72, 0x31, 0x35, 0x12, 0x10, 0x0a, 0x03, 0x72, 0x69, 0x70, 0x18, + 0x11, 0x20, 0x01, 0x28, 0x04, 0x52, 0x03, 0x72, 0x69, 0x70, 0x12, 0x16, 0x0a, 0x06, 0x72, 0x66, + 0x6c, 0x61, 0x67, 0x73, 0x18, 0x12, 0x20, 0x01, 0x28, 0x04, 0x52, 0x06, 0x72, 0x66, 0x6c, 0x61, + 0x67, 0x73, 0x12, 0x19, 0x0a, 0x08, 0x6f, 0x72, 0x69, 0x67, 0x5f, 0x72, 0x61, 0x78, 0x18, 0x13, + 0x20, 0x01, 0x28, 0x04, 0x52, 0x07, 0x6f, 0x72, 0x69, 0x67, 0x52, 0x61, 0x78, 0x12, 0x0e, 0x0a, + 0x02, 0x63, 0x73, 0x18, 0x14, 0x20, 0x01, 0x28, 0x04, 0x52, 0x02, 0x63, 0x73, 0x12, 0x0e, 0x0a, + 0x02, 0x64, 0x73, 0x18, 0x15, 0x20, 0x01, 0x28, 0x04, 0x52, 0x02, 0x64, 0x73, 0x12, 0x0e, 0x0a, + 0x02, 0x65, 0x73, 0x18, 0x16, 0x20, 0x01, 0x28, 0x04, 0x52, 0x02, 0x65, 0x73, 0x12, 0x0e, 0x0a, + 0x02, 0x66, 0x73, 0x18, 0x17, 0x20, 0x01, 0x28, 0x04, 0x52, 0x02, 0x66, 0x73, 0x12, 0x0e, 0x0a, + 0x02, 0x67, 0x73, 0x18, 0x18, 0x20, 0x01, 0x28, 0x04, 0x52, 0x02, 0x67, 0x73, 0x12, 0x0e, 0x0a, + 0x02, 0x73, 0x73, 0x18, 0x19, 0x20, 0x01, 0x28, 0x04, 0x52, 0x02, 0x73, 0x73, 0x12, 0x17, 0x0a, + 0x07, 0x66, 0x73, 0x5f, 0x62, 0x61, 0x73, 0x65, 0x18, 0x1a, 0x20, 0x01, 0x28, 0x04, 0x52, 0x06, + 0x66, 0x73, 0x42, 0x61, 0x73, 0x65, 0x12, 0x17, 0x0a, 0x07, 0x67, 0x73, 0x5f, 0x62, 0x61, 0x73, + 0x65, 0x18, 0x1b, 0x20, 0x01, 0x28, 0x04, 0x52, 0x06, 0x67, 0x73, 0x42, 0x61, 0x73, 0x65, 0x22, + 0xf4, 0x04, 0x0a, 0x0e, 0x41, 0x52, 0x4d, 0x36, 0x34, 0x52, 0x65, 0x67, 0x69, 0x73, 0x74, 0x65, + 0x72, 0x73, 0x12, 0x0e, 0x0a, 0x02, 0x72, 0x30, 0x18, 0x01, 0x20, 0x01, 0x28, 0x04, 0x52, 0x02, + 0x72, 0x30, 0x12, 0x0e, 0x0a, 0x02, 0x72, 0x31, 0x18, 0x02, 0x20, 0x01, 0x28, 0x04, 0x52, 0x02, + 0x72, 0x31, 0x12, 0x0e, 0x0a, 0x02, 0x72, 0x32, 0x18, 0x03, 0x20, 0x01, 0x28, 0x04, 0x52, 0x02, + 0x72, 0x32, 0x12, 0x0e, 0x0a, 0x02, 0x72, 0x33, 0x18, 0x04, 0x20, 0x01, 0x28, 0x04, 0x52, 0x02, + 0x72, 0x33, 0x12, 0x0e, 0x0a, 0x02, 0x72, 0x34, 0x18, 0x05, 0x20, 0x01, 0x28, 0x04, 0x52, 0x02, + 0x72, 0x34, 0x12, 0x0e, 0x0a, 0x02, 0x72, 0x35, 0x18, 0x06, 0x20, 0x01, 0x28, 0x04, 0x52, 0x02, + 0x72, 0x35, 0x12, 0x0e, 0x0a, 0x02, 0x72, 0x36, 0x18, 0x07, 0x20, 0x01, 0x28, 0x04, 0x52, 0x02, + 0x72, 0x36, 0x12, 0x0e, 0x0a, 0x02, 0x72, 0x37, 0x18, 0x08, 0x20, 0x01, 0x28, 0x04, 0x52, 0x02, + 0x72, 0x37, 0x12, 0x0e, 0x0a, 0x02, 0x72, 0x38, 0x18, 0x09, 0x20, 0x01, 0x28, 0x04, 0x52, 0x02, + 0x72, 0x38, 0x12, 0x0e, 0x0a, 0x02, 0x72, 0x39, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x04, 0x52, 0x02, + 0x72, 0x39, 0x12, 0x10, 0x0a, 0x03, 0x72, 0x31, 0x30, 0x18, 0x0b, 0x20, 0x01, 0x28, 0x04, 0x52, + 0x03, 0x72, 0x31, 0x30, 0x12, 0x10, 0x0a, 0x03, 0x72, 0x31, 0x31, 0x18, 0x0c, 0x20, 0x01, 0x28, + 0x04, 0x52, 0x03, 0x72, 0x31, 0x31, 0x12, 0x10, 0x0a, 0x03, 0x72, 0x31, 0x32, 0x18, 0x0d, 0x20, + 0x01, 0x28, 0x04, 0x52, 0x03, 0x72, 0x31, 0x32, 0x12, 0x10, 0x0a, 0x03, 0x72, 0x31, 0x33, 0x18, + 0x0e, 0x20, 0x01, 0x28, 0x04, 0x52, 0x03, 0x72, 0x31, 0x33, 0x12, 0x10, 0x0a, 0x03, 0x72, 0x31, + 0x34, 0x18, 0x0f, 0x20, 0x01, 0x28, 0x04, 0x52, 0x03, 0x72, 0x31, 0x34, 0x12, 0x10, 0x0a, 0x03, + 0x72, 0x31, 0x35, 0x18, 0x10, 0x20, 0x01, 0x28, 0x04, 0x52, 0x03, 0x72, 0x31, 0x35, 0x12, 0x10, + 0x0a, 0x03, 0x72, 0x31, 0x36, 0x18, 0x11, 0x20, 0x01, 0x28, 0x04, 0x52, 0x03, 0x72, 0x31, 0x36, + 0x12, 0x10, 0x0a, 0x03, 0x72, 0x31, 0x37, 0x18, 0x12, 0x20, 0x01, 0x28, 0x04, 0x52, 0x03, 0x72, + 0x31, 0x37, 0x12, 0x10, 0x0a, 0x03, 0x72, 0x31, 0x38, 0x18, 0x13, 0x20, 0x01, 0x28, 0x04, 0x52, + 0x03, 0x72, 0x31, 0x38, 0x12, 0x10, 0x0a, 0x03, 0x72, 0x31, 0x39, 0x18, 0x14, 0x20, 0x01, 0x28, + 0x04, 0x52, 0x03, 0x72, 0x31, 0x39, 0x12, 0x10, 0x0a, 0x03, 0x72, 0x32, 0x30, 0x18, 0x15, 0x20, + 0x01, 0x28, 0x04, 0x52, 0x03, 0x72, 0x32, 0x30, 0x12, 0x10, 0x0a, 0x03, 0x72, 0x32, 0x31, 0x18, + 0x16, 0x20, 0x01, 0x28, 0x04, 0x52, 0x03, 0x72, 0x32, 0x31, 0x12, 0x10, 0x0a, 0x03, 0x72, 0x32, + 0x32, 0x18, 0x17, 0x20, 0x01, 0x28, 0x04, 0x52, 0x03, 0x72, 0x32, 0x32, 0x12, 0x10, 0x0a, 0x03, + 0x72, 0x32, 0x33, 0x18, 0x18, 0x20, 0x01, 0x28, 0x04, 0x52, 0x03, 0x72, 0x32, 0x33, 0x12, 0x10, + 0x0a, 0x03, 0x72, 0x32, 0x34, 0x18, 0x19, 0x20, 0x01, 0x28, 0x04, 0x52, 0x03, 0x72, 0x32, 0x34, + 0x12, 0x10, 0x0a, 0x03, 0x72, 0x32, 0x35, 0x18, 0x1a, 0x20, 0x01, 0x28, 0x04, 0x52, 0x03, 0x72, + 0x32, 0x35, 0x12, 0x10, 0x0a, 0x03, 0x72, 0x32, 0x36, 0x18, 0x1b, 0x20, 0x01, 0x28, 0x04, 0x52, + 0x03, 0x72, 0x32, 0x36, 0x12, 0x10, 0x0a, 0x03, 0x72, 0x32, 0x37, 0x18, 0x1c, 0x20, 0x01, 0x28, + 0x04, 0x52, 0x03, 0x72, 0x32, 0x37, 0x12, 0x10, 0x0a, 0x03, 0x72, 0x32, 0x38, 0x18, 0x1d, 0x20, + 0x01, 0x28, 0x04, 0x52, 0x03, 0x72, 0x32, 0x38, 0x12, 0x10, 0x0a, 0x03, 0x72, 0x32, 0x39, 0x18, + 0x1e, 0x20, 0x01, 0x28, 0x04, 0x52, 0x03, 0x72, 0x32, 0x39, 0x12, 0x10, 0x0a, 0x03, 0x72, 0x33, + 0x30, 0x18, 0x1f, 0x20, 0x01, 0x28, 0x04, 0x52, 0x03, 0x72, 0x33, 0x30, 0x12, 0x0e, 0x0a, 0x02, + 0x73, 0x70, 0x18, 0x20, 0x20, 0x01, 0x28, 0x04, 0x52, 0x02, 0x73, 0x70, 0x12, 0x0e, 0x0a, 0x02, + 0x70, 0x63, 0x18, 0x21, 0x20, 0x01, 0x28, 0x04, 0x52, 0x02, 0x70, 0x63, 0x12, 0x16, 0x0a, 0x06, + 0x70, 0x73, 0x74, 0x61, 0x74, 0x65, 0x18, 0x22, 0x20, 0x01, 0x28, 0x04, 0x52, 0x06, 0x70, 0x73, + 0x74, 0x61, 0x74, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x74, 0x6c, 0x73, 0x18, 0x23, 0x20, 0x01, 0x28, + 0x04, 0x52, 0x03, 0x74, 0x6c, 0x73, 0x22, 0x73, 0x0a, 0x09, 0x52, 0x65, 0x67, 0x69, 0x73, 0x74, + 0x65, 0x72, 0x73, 0x12, 0x2e, 0x0a, 0x05, 0x61, 0x6d, 0x64, 0x36, 0x34, 0x18, 0x01, 0x20, 0x01, + 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x67, 0x76, 0x69, 0x73, 0x6f, 0x72, 0x2e, 0x41, 0x4d, 0x44, 0x36, + 0x34, 0x52, 0x65, 0x67, 0x69, 0x73, 0x74, 0x65, 0x72, 0x73, 0x48, 0x00, 0x52, 0x05, 0x61, 0x6d, + 0x64, 0x36, 0x34, 0x12, 0x2e, 0x0a, 0x05, 0x61, 0x72, 0x6d, 0x36, 0x34, 0x18, 0x02, 0x20, 0x01, + 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x67, 0x76, 0x69, 0x73, 0x6f, 0x72, 0x2e, 0x41, 0x52, 0x4d, 0x36, + 0x34, 0x52, 0x65, 0x67, 0x69, 0x73, 0x74, 0x65, 0x72, 0x73, 0x48, 0x00, 0x52, 0x05, 0x61, 0x72, + 0x6d, 0x36, 0x34, 0x42, 0x06, 0x0a, 0x04, 0x61, 0x72, 0x63, 0x68, 0x62, 0x06, 0x70, 0x72, 0x6f, + 0x74, 0x6f, 0x33, +} + +var ( + file_pkg_sentry_arch_registers_proto_rawDescOnce sync.Once + file_pkg_sentry_arch_registers_proto_rawDescData = file_pkg_sentry_arch_registers_proto_rawDesc +) + +func file_pkg_sentry_arch_registers_proto_rawDescGZIP() []byte { + file_pkg_sentry_arch_registers_proto_rawDescOnce.Do(func() { + file_pkg_sentry_arch_registers_proto_rawDescData = protoimpl.X.CompressGZIP(file_pkg_sentry_arch_registers_proto_rawDescData) + }) + return file_pkg_sentry_arch_registers_proto_rawDescData +} + +var file_pkg_sentry_arch_registers_proto_msgTypes = make([]protoimpl.MessageInfo, 3) +var file_pkg_sentry_arch_registers_proto_goTypes = []interface{}{ + (*AMD64Registers)(nil), // 0: gvisor.AMD64Registers + (*ARM64Registers)(nil), // 1: gvisor.ARM64Registers + (*Registers)(nil), // 2: gvisor.Registers +} +var file_pkg_sentry_arch_registers_proto_depIdxs = []int32{ + 0, // 0: gvisor.Registers.amd64:type_name -> gvisor.AMD64Registers + 1, // 1: gvisor.Registers.arm64:type_name -> gvisor.ARM64Registers + 2, // [2:2] is the sub-list for method output_type + 2, // [2:2] is the sub-list for method input_type + 2, // [2:2] is the sub-list for extension type_name + 2, // [2:2] is the sub-list for extension extendee + 0, // [0:2] is the sub-list for field type_name +} + +func init() { file_pkg_sentry_arch_registers_proto_init() } +func file_pkg_sentry_arch_registers_proto_init() { + if File_pkg_sentry_arch_registers_proto != nil { + return + } + if !protoimpl.UnsafeEnabled { + file_pkg_sentry_arch_registers_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*AMD64Registers); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_pkg_sentry_arch_registers_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*ARM64Registers); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_pkg_sentry_arch_registers_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*Registers); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + } + file_pkg_sentry_arch_registers_proto_msgTypes[2].OneofWrappers = []interface{}{ + (*Registers_Amd64)(nil), + (*Registers_Arm64)(nil), + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: file_pkg_sentry_arch_registers_proto_rawDesc, + NumEnums: 0, + NumMessages: 3, + NumExtensions: 0, + NumServices: 0, + }, + GoTypes: file_pkg_sentry_arch_registers_proto_goTypes, + DependencyIndexes: file_pkg_sentry_arch_registers_proto_depIdxs, + MessageInfos: file_pkg_sentry_arch_registers_proto_msgTypes, + }.Build() + File_pkg_sentry_arch_registers_proto = out.File + file_pkg_sentry_arch_registers_proto_rawDesc = nil + file_pkg_sentry_arch_registers_proto_goTypes = nil + file_pkg_sentry_arch_registers_proto_depIdxs = nil +} diff --git a/pkg/sentry/contexttest/BUILD b/pkg/sentry/contexttest/BUILD deleted file mode 100644 index 6f4c86684..000000000 --- a/pkg/sentry/contexttest/BUILD +++ /dev/null @@ -1,21 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "contexttest", - testonly = 1, - srcs = ["contexttest.go"], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/context", - "//pkg/memutil", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/kernel/time", - "//pkg/sentry/limits", - "//pkg/sentry/pgalloc", - "//pkg/sentry/platform", - "//pkg/sentry/platform/ptrace", - "//pkg/sentry/uniqueid", - ], -) diff --git a/pkg/sentry/contexttest/contexttest.go b/pkg/sentry/contexttest/contexttest.go deleted file mode 100644 index dfd195a23..000000000 --- a/pkg/sentry/contexttest/contexttest.go +++ /dev/null @@ -1,168 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package contexttest builds a test context.Context. -package contexttest - -import ( - "os" - "sync/atomic" - "testing" - "time" - - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/memutil" - "gvisor.dev/gvisor/pkg/sentry/kernel/auth" - ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" - "gvisor.dev/gvisor/pkg/sentry/limits" - "gvisor.dev/gvisor/pkg/sentry/pgalloc" - "gvisor.dev/gvisor/pkg/sentry/platform" - "gvisor.dev/gvisor/pkg/sentry/platform/ptrace" - "gvisor.dev/gvisor/pkg/sentry/uniqueid" -) - -// Context returns a Context that may be used in tests. Uses ptrace as the -// platform.Platform. -// -// Note that some filesystems may require a minimal kernel for testing, which -// this test context does not provide. For such tests, see kernel/contexttest. -func Context(tb testing.TB) context.Context { - const memfileName = "contexttest-memory" - memfd, err := memutil.CreateMemFD(memfileName, 0) - if err != nil { - tb.Fatalf("error creating application memory file: %v", err) - } - memfile := os.NewFile(uintptr(memfd), memfileName) - mf, err := pgalloc.NewMemoryFile(memfile, pgalloc.MemoryFileOpts{}) - if err != nil { - memfile.Close() - tb.Fatalf("error creating pgalloc.MemoryFile: %v", err) - } - p, err := ptrace.New() - if err != nil { - tb.Fatal(err) - } - // Test usage of context.Background is fine. - return &TestContext{ - Context: context.Background(), - l: limits.NewLimitSet(), - mf: mf, - platform: p, - creds: auth.NewAnonymousCredentials(), - otherValues: make(map[interface{}]interface{}), - } -} - -// TestContext represents a context with minimal functionality suitable for -// running tests. -type TestContext struct { - context.Context - l *limits.LimitSet - mf *pgalloc.MemoryFile - platform platform.Platform - creds *auth.Credentials - otherValues map[interface{}]interface{} -} - -// globalUniqueID tracks incremental unique identifiers for tests. -var globalUniqueID uint64 - -// globalUniqueIDProvider implements unix.UniqueIDProvider. -type globalUniqueIDProvider struct{} - -// UniqueID implements unix.UniqueIDProvider.UniqueID. -func (*globalUniqueIDProvider) UniqueID() uint64 { - return atomic.AddUint64(&globalUniqueID, 1) -} - -// lastInotifyCookie is a monotonically increasing counter for generating unique -// inotify cookies. Must be accessed using atomic ops. -var lastInotifyCookie uint32 - -// hostClock implements ktime.Clock. -type hostClock struct { - ktime.WallRateClock - ktime.NoClockEvents -} - -// Now implements ktime.Clock.Now. -func (*hostClock) Now() ktime.Time { - return ktime.FromNanoseconds(time.Now().UnixNano()) -} - -// RegisterValue registers additional values with this test context. Useful for -// providing values from external packages that contexttest can't depend on. -func (t *TestContext) RegisterValue(key, value interface{}) { - t.otherValues[key] = value -} - -// Value implements context.Context. -func (t *TestContext) Value(key interface{}) interface{} { - switch key { - case auth.CtxCredentials: - return t.creds - case limits.CtxLimits: - return t.l - case pgalloc.CtxMemoryFile: - return t.mf - case pgalloc.CtxMemoryFileProvider: - return t - case platform.CtxPlatform: - return t.platform - case uniqueid.CtxGlobalUniqueID: - return (*globalUniqueIDProvider).UniqueID(nil) - case uniqueid.CtxGlobalUniqueIDProvider: - return &globalUniqueIDProvider{} - case uniqueid.CtxInotifyCookie: - return atomic.AddUint32(&lastInotifyCookie, 1) - case ktime.CtxRealtimeClock: - return &hostClock{} - default: - if val, ok := t.otherValues[key]; ok { - return val - } - return t.Context.Value(key) - } -} - -// MemoryFile implements pgalloc.MemoryFileProvider.MemoryFile. -func (t *TestContext) MemoryFile() *pgalloc.MemoryFile { - return t.mf -} - -// RootContext returns a Context that may be used in tests that need root -// credentials. Uses ptrace as the platform.Platform. -func RootContext(tb testing.TB) context.Context { - return auth.ContextWithCredentials(Context(tb), auth.NewRootCredentials(auth.NewRootUserNamespace())) -} - -// WithLimitSet returns a copy of ctx carrying l. -func WithLimitSet(ctx context.Context, l *limits.LimitSet) context.Context { - return limitContext{ctx, l} -} - -type limitContext struct { - context.Context - l *limits.LimitSet -} - -// Value implements context.Context. -func (lc limitContext) Value(key interface{}) interface{} { - switch key { - case limits.CtxLimits: - return lc.l - default: - return lc.Context.Value(key) - } -} diff --git a/pkg/sentry/control/BUILD b/pkg/sentry/control/BUILD deleted file mode 100644 index a4934a565..000000000 --- a/pkg/sentry/control/BUILD +++ /dev/null @@ -1,59 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "control", - srcs = [ - "control.go", - "events.go", - "fs.go", - "lifecycle.go", - "logging.go", - "pprof.go", - "proc.go", - "state.go", - "usage.go", - ], - visibility = [ - "//:sandbox", - ], - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/eventchannel", - "//pkg/fd", - "//pkg/log", - "//pkg/sentry/fdimport", - "//pkg/sentry/fs", - "//pkg/sentry/fs/host", - "//pkg/sentry/fs/user", - "//pkg/sentry/fsimpl/host", - "//pkg/sentry/kernel", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/kernel/time", - "//pkg/sentry/limits", - "//pkg/sentry/state", - "//pkg/sentry/strace", - "//pkg/sentry/usage", - "//pkg/sentry/vfs", - "//pkg/sentry/watchdog", - "//pkg/sync", - "//pkg/tcpip/link/sniffer", - "//pkg/urpc", - "//pkg/usermem", - "@org_golang_x_sys//unix:go_default_library", - ], -) - -go_test( - name = "control_test", - size = "small", - srcs = ["proc_test.go"], - library = ":control", - deps = [ - "//pkg/log", - "//pkg/sentry/kernel/time", - "//pkg/sentry/usage", - ], -) diff --git a/pkg/sentry/control/control_state_autogen.go b/pkg/sentry/control/control_state_autogen.go new file mode 100644 index 000000000..bd5797221 --- /dev/null +++ b/pkg/sentry/control/control_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package control diff --git a/pkg/sentry/control/proc_test.go b/pkg/sentry/control/proc_test.go deleted file mode 100644 index 0a88459b2..000000000 --- a/pkg/sentry/control/proc_test.go +++ /dev/null @@ -1,166 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package control - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/log" - ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" - "gvisor.dev/gvisor/pkg/sentry/usage" -) - -func init() { - log.SetLevel(log.Debug) -} - -// Tests that ProcessData.Table() prints with the correct format. -func TestProcessListTable(t *testing.T) { - testCases := []struct { - pl []*Process - expected string - }{ - { - pl: []*Process{}, - expected: "UID PID PPID C TTY STIME TIME CMD", - }, - { - pl: []*Process{ - { - UID: 0, - PID: 0, - PPID: 0, - C: 0, - TTY: "?", - STime: "0", - Time: "0", - Cmd: "zero", - }, - { - UID: 1, - PID: 1, - PPID: 1, - C: 1, - TTY: "pts/4", - STime: "1", - Time: "1", - Cmd: "one", - }, - }, - expected: `UID PID PPID C TTY STIME TIME CMD -0 0 0 0 ? 0 0 zero -1 1 1 1 pts/4 1 1 one`, - }, - } - - for _, tc := range testCases { - output := ProcessListToTable(tc.pl) - - if tc.expected != output { - t.Errorf("PrintTable(%v): got:\n%s\nwant:\n%s", tc.pl, output, tc.expected) - } - } -} - -func TestProcessListJSON(t *testing.T) { - testCases := []struct { - pl []*Process - expected string - }{ - { - pl: []*Process{}, - expected: "[]", - }, - { - pl: []*Process{ - { - UID: 0, - PID: 0, - PPID: 0, - C: 0, - STime: "0", - Time: "0", - Cmd: "zero", - }, - { - UID: 1, - PID: 1, - PPID: 1, - C: 1, - STime: "1", - Time: "1", - Cmd: "one", - }, - }, - expected: "[0,1]", - }, - } - - for _, tc := range testCases { - output, err := PrintPIDsJSON(tc.pl) - if err != nil { - t.Errorf("failed to generate JSON: %v", err) - } - - if tc.expected != output { - t.Errorf("PrintJSON(%v): got:\n%s\nwant:\n%s", tc.pl, output, tc.expected) - } - } -} - -func TestPercentCPU(t *testing.T) { - testCases := []struct { - stats usage.CPUStats - startTime ktime.Time - now ktime.Time - expected int32 - }{ - { - // Verify that 100% use is capped at 99. - stats: usage.CPUStats{UserTime: 1e9, SysTime: 1e9}, - startTime: ktime.FromNanoseconds(7e9), - now: ktime.FromNanoseconds(9e9), - expected: 99, - }, - { - // Verify that if usage > lifetime, we get at most 99% - // usage. - stats: usage.CPUStats{UserTime: 2e9, SysTime: 2e9}, - startTime: ktime.FromNanoseconds(7e9), - now: ktime.FromNanoseconds(9e9), - expected: 99, - }, - { - // Verify that 50% usage is reported correctly. - stats: usage.CPUStats{UserTime: 1e9, SysTime: 1e9}, - startTime: ktime.FromNanoseconds(12e9), - now: ktime.FromNanoseconds(16e9), - expected: 50, - }, - { - // Verify that 0% usage is reported correctly. - stats: usage.CPUStats{UserTime: 0, SysTime: 0}, - startTime: ktime.FromNanoseconds(12e9), - now: ktime.FromNanoseconds(14e9), - expected: 0, - }, - } - - for _, tc := range testCases { - if pcpu := percentCPU(tc.stats, tc.startTime, tc.now); pcpu != tc.expected { - t.Errorf("percentCPU(%v, %v, %v): got %d, want %d", tc.stats, tc.startTime, tc.now, pcpu, tc.expected) - } - } -} diff --git a/pkg/sentry/device/BUILD b/pkg/sentry/device/BUILD deleted file mode 100644 index e403cbd8b..000000000 --- a/pkg/sentry/device/BUILD +++ /dev/null @@ -1,20 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "device", - srcs = ["device.go"], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/sync", - ], -) - -go_test( - name = "device_test", - size = "small", - srcs = ["device_test.go"], - library = ":device", -) diff --git a/pkg/sentry/device/device_state_autogen.go b/pkg/sentry/device/device_state_autogen.go new file mode 100644 index 000000000..b68325291 --- /dev/null +++ b/pkg/sentry/device/device_state_autogen.go @@ -0,0 +1,97 @@ +// automatically generated by stateify. + +package device + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (r *Registry) StateTypeName() string { + return "pkg/sentry/device.Registry" +} + +func (r *Registry) StateFields() []string { + return []string{ + "lastAnonDeviceMinor", + "devices", + } +} + +func (r *Registry) beforeSave() {} + +// +checklocksignore +func (r *Registry) StateSave(stateSinkObject state.Sink) { + r.beforeSave() + stateSinkObject.Save(0, &r.lastAnonDeviceMinor) + stateSinkObject.Save(1, &r.devices) +} + +func (r *Registry) afterLoad() {} + +// +checklocksignore +func (r *Registry) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &r.lastAnonDeviceMinor) + stateSourceObject.Load(1, &r.devices) +} + +func (i *ID) StateTypeName() string { + return "pkg/sentry/device.ID" +} + +func (i *ID) StateFields() []string { + return []string{ + "Major", + "Minor", + } +} + +func (i *ID) beforeSave() {} + +// +checklocksignore +func (i *ID) StateSave(stateSinkObject state.Sink) { + i.beforeSave() + stateSinkObject.Save(0, &i.Major) + stateSinkObject.Save(1, &i.Minor) +} + +func (i *ID) afterLoad() {} + +// +checklocksignore +func (i *ID) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &i.Major) + stateSourceObject.Load(1, &i.Minor) +} + +func (d *Device) StateTypeName() string { + return "pkg/sentry/device.Device" +} + +func (d *Device) StateFields() []string { + return []string{ + "ID", + "last", + } +} + +func (d *Device) beforeSave() {} + +// +checklocksignore +func (d *Device) StateSave(stateSinkObject state.Sink) { + d.beforeSave() + stateSinkObject.Save(0, &d.ID) + stateSinkObject.Save(1, &d.last) +} + +func (d *Device) afterLoad() {} + +// +checklocksignore +func (d *Device) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &d.ID) + stateSourceObject.Load(1, &d.last) +} + +func init() { + state.Register((*Registry)(nil)) + state.Register((*ID)(nil)) + state.Register((*Device)(nil)) +} diff --git a/pkg/sentry/device/device_test.go b/pkg/sentry/device/device_test.go deleted file mode 100644 index e3f51ce4f..000000000 --- a/pkg/sentry/device/device_test.go +++ /dev/null @@ -1,59 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package device - -import ( - "testing" -) - -func TestMultiDevice(t *testing.T) { - device := &MultiDevice{} - - // Check that Load fails to install virtual inodes that are - // uninitialized. - if device.Load(MultiDeviceKey{}, 0) { - t.Fatalf("got load of invalid virtual inode 0, want unsuccessful") - } - - inode := device.Map(MultiDeviceKey{}) - - // Assert that the same raw device and inode map to - // a consistent virtual inode. - if i := device.Map(MultiDeviceKey{}); i != inode { - t.Fatalf("got inode %d, want %d in %s", i, inode, device) - } - - // Assert that a new inode or new device does not conflict. - if i := device.Map(MultiDeviceKey{Device: 0, Inode: 1}); i == inode { - t.Fatalf("got reused inode %d, want new distinct inode in %s", i, device) - } - last := device.Map(MultiDeviceKey{Device: 1, Inode: 0}) - if last == inode { - t.Fatalf("got reused inode %d, want new distinct inode in %s", last, device) - } - - // Virtual is the virtual inode we want to load. - virtual := last + 1 - - // Assert that we can load a virtual inode at a new place. - if !device.Load(MultiDeviceKey{Device: 0, Inode: 2}, virtual) { - t.Fatalf("got load of virtual inode %d failed, want success in %s", virtual, device) - } - - // Assert that the next inode skips over the loaded one. - if i := device.Map(MultiDeviceKey{Device: 0, Inode: 3}); i != virtual+1 { - t.Fatalf("got inode %d, want %d in %s", i, virtual+1, device) - } -} diff --git a/pkg/sentry/devices/memdev/BUILD b/pkg/sentry/devices/memdev/BUILD deleted file mode 100644 index 66b9ed523..000000000 --- a/pkg/sentry/devices/memdev/BUILD +++ /dev/null @@ -1,29 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -licenses(["notice"]) - -go_library( - name = "memdev", - srcs = [ - "full.go", - "memdev.go", - "null.go", - "random.go", - "zero.go", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/errors/linuxerr", - "//pkg/rand", - "//pkg/safemem", - "//pkg/sentry/fsimpl/devtmpfs", - "//pkg/sentry/fsimpl/tmpfs", - "//pkg/sentry/kernel", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/memmap", - "//pkg/sentry/vfs", - "//pkg/usermem", - ], -) diff --git a/pkg/sentry/devices/memdev/memdev_state_autogen.go b/pkg/sentry/devices/memdev/memdev_state_autogen.go new file mode 100644 index 000000000..a8faafbbf --- /dev/null +++ b/pkg/sentry/devices/memdev/memdev_state_autogen.go @@ -0,0 +1,241 @@ +// automatically generated by stateify. + +package memdev + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (f *fullDevice) StateTypeName() string { + return "pkg/sentry/devices/memdev.fullDevice" +} + +func (f *fullDevice) StateFields() []string { + return []string{} +} + +func (f *fullDevice) beforeSave() {} + +// +checklocksignore +func (f *fullDevice) StateSave(stateSinkObject state.Sink) { + f.beforeSave() +} + +func (f *fullDevice) afterLoad() {} + +// +checklocksignore +func (f *fullDevice) StateLoad(stateSourceObject state.Source) { +} + +func (fd *fullFD) StateTypeName() string { + return "pkg/sentry/devices/memdev.fullFD" +} + +func (fd *fullFD) StateFields() []string { + return []string{ + "vfsfd", + "FileDescriptionDefaultImpl", + "DentryMetadataFileDescriptionImpl", + "NoLockFD", + } +} + +func (fd *fullFD) beforeSave() {} + +// +checklocksignore +func (fd *fullFD) StateSave(stateSinkObject state.Sink) { + fd.beforeSave() + stateSinkObject.Save(0, &fd.vfsfd) + stateSinkObject.Save(1, &fd.FileDescriptionDefaultImpl) + stateSinkObject.Save(2, &fd.DentryMetadataFileDescriptionImpl) + stateSinkObject.Save(3, &fd.NoLockFD) +} + +func (fd *fullFD) afterLoad() {} + +// +checklocksignore +func (fd *fullFD) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &fd.vfsfd) + stateSourceObject.Load(1, &fd.FileDescriptionDefaultImpl) + stateSourceObject.Load(2, &fd.DentryMetadataFileDescriptionImpl) + stateSourceObject.Load(3, &fd.NoLockFD) +} + +func (n *nullDevice) StateTypeName() string { + return "pkg/sentry/devices/memdev.nullDevice" +} + +func (n *nullDevice) StateFields() []string { + return []string{} +} + +func (n *nullDevice) beforeSave() {} + +// +checklocksignore +func (n *nullDevice) StateSave(stateSinkObject state.Sink) { + n.beforeSave() +} + +func (n *nullDevice) afterLoad() {} + +// +checklocksignore +func (n *nullDevice) StateLoad(stateSourceObject state.Source) { +} + +func (fd *nullFD) StateTypeName() string { + return "pkg/sentry/devices/memdev.nullFD" +} + +func (fd *nullFD) StateFields() []string { + return []string{ + "vfsfd", + "FileDescriptionDefaultImpl", + "DentryMetadataFileDescriptionImpl", + "NoLockFD", + } +} + +func (fd *nullFD) beforeSave() {} + +// +checklocksignore +func (fd *nullFD) StateSave(stateSinkObject state.Sink) { + fd.beforeSave() + stateSinkObject.Save(0, &fd.vfsfd) + stateSinkObject.Save(1, &fd.FileDescriptionDefaultImpl) + stateSinkObject.Save(2, &fd.DentryMetadataFileDescriptionImpl) + stateSinkObject.Save(3, &fd.NoLockFD) +} + +func (fd *nullFD) afterLoad() {} + +// +checklocksignore +func (fd *nullFD) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &fd.vfsfd) + stateSourceObject.Load(1, &fd.FileDescriptionDefaultImpl) + stateSourceObject.Load(2, &fd.DentryMetadataFileDescriptionImpl) + stateSourceObject.Load(3, &fd.NoLockFD) +} + +func (r *randomDevice) StateTypeName() string { + return "pkg/sentry/devices/memdev.randomDevice" +} + +func (r *randomDevice) StateFields() []string { + return []string{} +} + +func (r *randomDevice) beforeSave() {} + +// +checklocksignore +func (r *randomDevice) StateSave(stateSinkObject state.Sink) { + r.beforeSave() +} + +func (r *randomDevice) afterLoad() {} + +// +checklocksignore +func (r *randomDevice) StateLoad(stateSourceObject state.Source) { +} + +func (fd *randomFD) StateTypeName() string { + return "pkg/sentry/devices/memdev.randomFD" +} + +func (fd *randomFD) StateFields() []string { + return []string{ + "vfsfd", + "FileDescriptionDefaultImpl", + "DentryMetadataFileDescriptionImpl", + "NoLockFD", + "off", + } +} + +func (fd *randomFD) beforeSave() {} + +// +checklocksignore +func (fd *randomFD) StateSave(stateSinkObject state.Sink) { + fd.beforeSave() + stateSinkObject.Save(0, &fd.vfsfd) + stateSinkObject.Save(1, &fd.FileDescriptionDefaultImpl) + stateSinkObject.Save(2, &fd.DentryMetadataFileDescriptionImpl) + stateSinkObject.Save(3, &fd.NoLockFD) + stateSinkObject.Save(4, &fd.off) +} + +func (fd *randomFD) afterLoad() {} + +// +checklocksignore +func (fd *randomFD) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &fd.vfsfd) + stateSourceObject.Load(1, &fd.FileDescriptionDefaultImpl) + stateSourceObject.Load(2, &fd.DentryMetadataFileDescriptionImpl) + stateSourceObject.Load(3, &fd.NoLockFD) + stateSourceObject.Load(4, &fd.off) +} + +func (z *zeroDevice) StateTypeName() string { + return "pkg/sentry/devices/memdev.zeroDevice" +} + +func (z *zeroDevice) StateFields() []string { + return []string{} +} + +func (z *zeroDevice) beforeSave() {} + +// +checklocksignore +func (z *zeroDevice) StateSave(stateSinkObject state.Sink) { + z.beforeSave() +} + +func (z *zeroDevice) afterLoad() {} + +// +checklocksignore +func (z *zeroDevice) StateLoad(stateSourceObject state.Source) { +} + +func (fd *zeroFD) StateTypeName() string { + return "pkg/sentry/devices/memdev.zeroFD" +} + +func (fd *zeroFD) StateFields() []string { + return []string{ + "vfsfd", + "FileDescriptionDefaultImpl", + "DentryMetadataFileDescriptionImpl", + "NoLockFD", + } +} + +func (fd *zeroFD) beforeSave() {} + +// +checklocksignore +func (fd *zeroFD) StateSave(stateSinkObject state.Sink) { + fd.beforeSave() + stateSinkObject.Save(0, &fd.vfsfd) + stateSinkObject.Save(1, &fd.FileDescriptionDefaultImpl) + stateSinkObject.Save(2, &fd.DentryMetadataFileDescriptionImpl) + stateSinkObject.Save(3, &fd.NoLockFD) +} + +func (fd *zeroFD) afterLoad() {} + +// +checklocksignore +func (fd *zeroFD) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &fd.vfsfd) + stateSourceObject.Load(1, &fd.FileDescriptionDefaultImpl) + stateSourceObject.Load(2, &fd.DentryMetadataFileDescriptionImpl) + stateSourceObject.Load(3, &fd.NoLockFD) +} + +func init() { + state.Register((*fullDevice)(nil)) + state.Register((*fullFD)(nil)) + state.Register((*nullDevice)(nil)) + state.Register((*nullFD)(nil)) + state.Register((*randomDevice)(nil)) + state.Register((*randomFD)(nil)) + state.Register((*zeroDevice)(nil)) + state.Register((*zeroFD)(nil)) +} diff --git a/pkg/sentry/devices/quotedev/BUILD b/pkg/sentry/devices/quotedev/BUILD deleted file mode 100644 index ee946610a..000000000 --- a/pkg/sentry/devices/quotedev/BUILD +++ /dev/null @@ -1,16 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -licenses(["notice"]) - -go_library( - name = "quotedev", - srcs = ["quotedev.go"], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/errors/linuxerr", - "//pkg/sentry/fsimpl/devtmpfs", - "//pkg/sentry/vfs", - ], -) diff --git a/pkg/sentry/devices/quotedev/quotedev.go b/pkg/sentry/devices/quotedev/quotedev.go deleted file mode 100644 index 140856a4a..000000000 --- a/pkg/sentry/devices/quotedev/quotedev.go +++ /dev/null @@ -1,52 +0,0 @@ -// Copyright 2021 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package quotedev implements a vfs.Device for /dev/gvisor_quote. -package quotedev - -import ( - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/errors/linuxerr" - "gvisor.dev/gvisor/pkg/sentry/fsimpl/devtmpfs" - "gvisor.dev/gvisor/pkg/sentry/vfs" -) - -const ( - quoteDevMinor = 0 -) - -// quoteDevice implements vfs.Device for /dev/gvisor_quote -// -// +stateify savable -type quoteDevice struct{} - -// Open implements vfs.Device.Open. -// TODO(b/157161182): Add support for attestation ioctls. -func (quoteDevice) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { - return nil, linuxerr.EIO -} - -// Register registers all devices implemented by this package in vfsObj. -func Register(vfsObj *vfs.VirtualFilesystem) error { - return vfsObj.RegisterDevice(vfs.CharDevice, linux.UNNAMED_MAJOR, quoteDevMinor, quoteDevice{}, &vfs.RegisterDeviceOptions{ - GroupName: "gvisor_quote", - }) -} - -// CreateDevtmpfsFiles creates device special files in dev representing all -// devices implemented by this package. -func CreateDevtmpfsFiles(ctx context.Context, dev *devtmpfs.Accessor) error { - return dev.CreateDeviceFile(ctx, "gvisor_quote", vfs.CharDevice, linux.UNNAMED_MAJOR, quoteDevMinor, 0666 /* mode */) -} diff --git a/pkg/sentry/devices/ttydev/BUILD b/pkg/sentry/devices/ttydev/BUILD deleted file mode 100644 index ab4cd0b33..000000000 --- a/pkg/sentry/devices/ttydev/BUILD +++ /dev/null @@ -1,16 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -licenses(["notice"]) - -go_library( - name = "ttydev", - srcs = ["ttydev.go"], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/errors/linuxerr", - "//pkg/sentry/fsimpl/devtmpfs", - "//pkg/sentry/vfs", - ], -) diff --git a/pkg/sentry/devices/ttydev/ttydev_state_autogen.go b/pkg/sentry/devices/ttydev/ttydev_state_autogen.go new file mode 100644 index 000000000..88a1cd0a6 --- /dev/null +++ b/pkg/sentry/devices/ttydev/ttydev_state_autogen.go @@ -0,0 +1,32 @@ +// automatically generated by stateify. + +package ttydev + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (t *ttyDevice) StateTypeName() string { + return "pkg/sentry/devices/ttydev.ttyDevice" +} + +func (t *ttyDevice) StateFields() []string { + return []string{} +} + +func (t *ttyDevice) beforeSave() {} + +// +checklocksignore +func (t *ttyDevice) StateSave(stateSinkObject state.Sink) { + t.beforeSave() +} + +func (t *ttyDevice) afterLoad() {} + +// +checklocksignore +func (t *ttyDevice) StateLoad(stateSourceObject state.Source) { +} + +func init() { + state.Register((*ttyDevice)(nil)) +} diff --git a/pkg/sentry/devices/tundev/BUILD b/pkg/sentry/devices/tundev/BUILD deleted file mode 100644 index 60c971030..000000000 --- a/pkg/sentry/devices/tundev/BUILD +++ /dev/null @@ -1,24 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -licenses(["notice"]) - -go_library( - name = "tundev", - srcs = ["tundev.go"], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/errors/linuxerr", - "//pkg/hostarch", - "//pkg/sentry/arch", - "//pkg/sentry/fsimpl/devtmpfs", - "//pkg/sentry/inet", - "//pkg/sentry/kernel", - "//pkg/sentry/socket/netstack", - "//pkg/sentry/vfs", - "//pkg/tcpip/link/tun", - "//pkg/usermem", - "//pkg/waiter", - ], -) diff --git a/pkg/sentry/devices/tundev/tundev_state_autogen.go b/pkg/sentry/devices/tundev/tundev_state_autogen.go new file mode 100644 index 000000000..ad7c7feb4 --- /dev/null +++ b/pkg/sentry/devices/tundev/tundev_state_autogen.go @@ -0,0 +1,70 @@ +// automatically generated by stateify. + +package tundev + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (t *tunDevice) StateTypeName() string { + return "pkg/sentry/devices/tundev.tunDevice" +} + +func (t *tunDevice) StateFields() []string { + return []string{} +} + +func (t *tunDevice) beforeSave() {} + +// +checklocksignore +func (t *tunDevice) StateSave(stateSinkObject state.Sink) { + t.beforeSave() +} + +func (t *tunDevice) afterLoad() {} + +// +checklocksignore +func (t *tunDevice) StateLoad(stateSourceObject state.Source) { +} + +func (fd *tunFD) StateTypeName() string { + return "pkg/sentry/devices/tundev.tunFD" +} + +func (fd *tunFD) StateFields() []string { + return []string{ + "vfsfd", + "FileDescriptionDefaultImpl", + "DentryMetadataFileDescriptionImpl", + "NoLockFD", + "device", + } +} + +func (fd *tunFD) beforeSave() {} + +// +checklocksignore +func (fd *tunFD) StateSave(stateSinkObject state.Sink) { + fd.beforeSave() + stateSinkObject.Save(0, &fd.vfsfd) + stateSinkObject.Save(1, &fd.FileDescriptionDefaultImpl) + stateSinkObject.Save(2, &fd.DentryMetadataFileDescriptionImpl) + stateSinkObject.Save(3, &fd.NoLockFD) + stateSinkObject.Save(4, &fd.device) +} + +func (fd *tunFD) afterLoad() {} + +// +checklocksignore +func (fd *tunFD) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &fd.vfsfd) + stateSourceObject.Load(1, &fd.FileDescriptionDefaultImpl) + stateSourceObject.Load(2, &fd.DentryMetadataFileDescriptionImpl) + stateSourceObject.Load(3, &fd.NoLockFD) + stateSourceObject.Load(4, &fd.device) +} + +func init() { + state.Register((*tunDevice)(nil)) + state.Register((*tunFD)(nil)) +} diff --git a/pkg/sentry/fdimport/BUILD b/pkg/sentry/fdimport/BUILD deleted file mode 100644 index 563e96e0d..000000000 --- a/pkg/sentry/fdimport/BUILD +++ /dev/null @@ -1,21 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "fdimport", - srcs = [ - "fdimport.go", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/context", - "//pkg/fd", - "//pkg/sentry/fs", - "//pkg/sentry/fs/host", - "//pkg/sentry/fsimpl/host", - "//pkg/sentry/kernel", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/vfs", - ], -) diff --git a/pkg/sentry/fdimport/fdimport_state_autogen.go b/pkg/sentry/fdimport/fdimport_state_autogen.go new file mode 100644 index 000000000..fbb89b276 --- /dev/null +++ b/pkg/sentry/fdimport/fdimport_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package fdimport diff --git a/pkg/sentry/fs/BUILD b/pkg/sentry/fs/BUILD deleted file mode 100644 index 4e573d249..000000000 --- a/pkg/sentry/fs/BUILD +++ /dev/null @@ -1,138 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -package(licenses = ["notice"]) - -go_library( - name = "fs", - srcs = [ - "attr.go", - "context.go", - "copy_up.go", - "dentry.go", - "dirent.go", - "dirent_cache.go", - "dirent_cache_limiter.go", - "dirent_list.go", - "dirent_state.go", - "event_list.go", - "file.go", - "file_operations.go", - "file_overlay.go", - "file_state.go", - "filesystems.go", - "flags.go", - "fs.go", - "inode.go", - "inode_inotify.go", - "inode_operations.go", - "inode_overlay.go", - "inotify.go", - "inotify_event.go", - "inotify_watch.go", - "mock.go", - "mount.go", - "mount_overlay.go", - "mounts.go", - "offset.go", - "overlay.go", - "path.go", - "restore.go", - "save.go", - "seek.go", - "splice.go", - "sync.go", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/amutex", - "//pkg/context", - "//pkg/errors/linuxerr", - "//pkg/hostarch", - "//pkg/log", - "//pkg/p9", - "//pkg/refs", - "//pkg/secio", - "//pkg/sentry/arch", - "//pkg/sentry/device", - "//pkg/sentry/fs/lock", - "//pkg/sentry/fsmetric", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/kernel/time", - "//pkg/sentry/limits", - "//pkg/sentry/memmap", - "//pkg/sentry/platform", - "//pkg/sentry/socket/unix/transport", - "//pkg/sentry/uniqueid", - "//pkg/sentry/usage", - "//pkg/state", - "//pkg/sync", - "//pkg/usermem", - "//pkg/waiter", - "@org_golang_x_sys//unix:go_default_library", - ], -) - -go_template_instance( - name = "dirent_list", - out = "dirent_list.go", - package = "fs", - prefix = "dirent", - template = "//pkg/ilist:generic_list", - types = { - "Linker": "*Dirent", - "Element": "*Dirent", - }, -) - -go_template_instance( - name = "event_list", - out = "event_list.go", - package = "fs", - prefix = "event", - template = "//pkg/ilist:generic_list", - types = { - "Linker": "*Event", - "Element": "*Event", - }, -) - -go_test( - name = "fs_x_test", - size = "small", - srcs = [ - "copy_up_test.go", - "file_overlay_test.go", - "inode_overlay_test.go", - "mounts_test.go", - ], - deps = [ - ":fs", - "//pkg/context", - "//pkg/errors/linuxerr", - "//pkg/sentry/fs/fsutil", - "//pkg/sentry/fs/ramfs", - "//pkg/sentry/fs/tmpfs", - "//pkg/sentry/kernel/contexttest", - "//pkg/sync", - "//pkg/usermem", - ], -) - -go_test( - name = "fs_test", - size = "small", - srcs = [ - "dirent_cache_test.go", - "dirent_refs_test.go", - "mount_test.go", - "path_test.go", - ], - library = ":fs", - deps = [ - "//pkg/context", - "//pkg/sentry/contexttest", - "@org_golang_x_sys//unix:go_default_library", - ], -) diff --git a/pkg/sentry/fs/README.md b/pkg/sentry/fs/README.md deleted file mode 100644 index db4a1b730..000000000 --- a/pkg/sentry/fs/README.md +++ /dev/null @@ -1,229 +0,0 @@ -This package provides an implementation of the Linux virtual filesystem. - -[TOC] - -## Overview - -- An `fs.Dirent` caches an `fs.Inode` in memory at a path in the VFS, giving - the `fs.Inode` a relative position with respect to other `fs.Inode`s. - -- If an `fs.Dirent` is referenced by two file descriptors, then those file - descriptors are coherent with each other: they depend on the same - `fs.Inode`. - -- A mount point is an `fs.Dirent` for which `fs.Dirent.mounted` is true. It - exposes the root of a mounted filesystem. - -- The `fs.Inode` produced by a registered filesystem on mount(2) owns an - `fs.MountedFilesystem` from which other `fs.Inode`s will be looked up. For a - remote filesystem, the `fs.MountedFilesystem` owns the connection to that - remote filesystem. - -- In general: - -``` -fs.Inode <------------------------------ -| | -| | -produced by | -exactly one | -| responsible for the -| virtual identity of -v | -fs.MountedFilesystem ------------------- -``` - -Glossary: - -- VFS: virtual filesystem. - -- inode: a virtual file object holding a cached view of a file on a backing - filesystem (includes metadata and page caches). - -- superblock: the virtual state of a mounted filesystem (e.g. the virtual - inode number set). - -- mount namespace: a view of the mounts under a root (during path traversal, - the VFS makes visible/follows the mount point that is in the current task's - mount namespace). - -## Save and restore - -An application's hard dependencies on filesystem state can be broken down into -two categories: - -- The state necessary to execute a traversal on or view the *virtual* - filesystem hierarchy, regardless of what files an application has open. - -- The state necessary to represent open files. - -The first is always necessary to save and restore. An application may never have -any open file descriptors, but across save and restore it should see a coherent -view of any mount namespace. NOTE(b/63601033): Currently only one "initial" -mount namespace is supported. - -The second is so that system calls across save and restore are coherent with -each other (e.g. so that unintended re-reads or overwrites do not occur). - -Specifically this state is: - -- An `fs.MountManager` containing mount points. - -- A `kernel.FDTable` containing pointers to open files. - -Anything else managed by the VFS that can be easily loaded into memory from a -filesystem is synced back to those filesystems and is not saved. Examples are -pages in page caches used for optimizations (i.e. readahead and writeback), and -directory entries used to accelerate path lookups. - -### Mount points - -Saving and restoring a mount point means saving and restoring: - -- The root of the mounted filesystem. - -- Mount flags, which control how the VFS interacts with the mounted - filesystem. - -- Any relevant metadata about the mounted filesystem. - -- All `fs.Inode`s referenced by the application that reside under the mount - point. - -`fs.MountedFilesystem` is metadata about a filesystem that is mounted. It is -referenced by every `fs.Inode` loaded into memory under the mount point -including the `fs.Inode` of the mount point itself. The `fs.MountedFilesystem` -maps file objects on the filesystem to a virtualized `fs.Inode` number and vice -versa. - -To restore all `fs.Inode`s under a given mount point, each `fs.Inode` leverages -its dependency on an `fs.MountedFilesystem`. Since the `fs.MountedFilesystem` -knows how an `fs.Inode` maps to a file object on a backing filesystem, this -mapping can be trivially consulted by each `fs.Inode` when the `fs.Inode` is -restored. - -In detail, a mount point is saved in two steps: - -- First, after the kernel is paused but before state.Save, we walk all mount - namespaces and install a mapping from `fs.Inode` numbers to file paths - relative to the root of the mounted filesystem in each - `fs.MountedFilesystem`. This is subsequently called the set of `fs.Inode` - mappings. - -- Second, during state.Save, each `fs.MountedFilesystem` decides whether to - save the set of `fs.Inode` mappings. In-memory filesystems, like tmpfs, have - no need to save a set of `fs.Inode` mappings, since the `fs.Inode`s can be - entirely encoded in state file. Each `fs.MountedFilesystem` also optionally - saves the device name from when the filesystem was originally mounted. Each - `fs.Inode` saves its virtual identifier and a reference to a - `fs.MountedFilesystem`. - -A mount point is restored in two steps: - -- First, before state.Load, all mount configurations are stored in a global - `fs.RestoreEnvironment`. This tells us what mount points the user wants to - restore and how to re-establish pointers to backing filesystems. - -- Second, during state.Load, each `fs.MountedFilesystem` optionally searches - for a mount in the `fs.RestoreEnvironment` that matches its saved device - name. The `fs.MountedFilesystem` then reestablishes a pointer to the root of - the mounted filesystem. For example, the mount specification provides the - network connection for a mounted remote filesystem client to communicate - with its remote file server. The `fs.MountedFilesystem` also trivially loads - its set of `fs.Inode` mappings. When an `fs.Inode` is encountered, the - `fs.Inode` loads its virtual identifier and its reference a - `fs.MountedFilesystem`. It uses the `fs.MountedFilesystem` to obtain the - root of the mounted filesystem and the `fs.Inode` mappings to obtain the - relative file path to its data. With these, the `fs.Inode` re-establishes a - pointer to its file object. - -A mount point can trivially restore its `fs.Inode`s in parallel since -`fs.Inode`s have a restore dependency on their `fs.MountedFilesystem` and not on -each other. - -### Open files - -An `fs.File` references the following filesystem objects: - -```go -fs.File -> fs.Dirent -> fs.Inode -> fs.MountedFilesystem -``` - -The `fs.Inode` is restored using its `fs.MountedFilesystem`. The -[Mount points](#mount-points) section above describes how this happens in -detail. The `fs.Dirent` restores its pointer to an `fs.Inode`, pointers to -parent and children `fs.Dirents`, and the basename of the file. - -Otherwise an `fs.File` restores flags, an offset, and a unique identifier (only -used internally). - -It may use the `fs.Inode`, which it indirectly holds a reference on through the -`fs.Dirent`, to reestablish an open file handle on the backing filesystem (e.g. -to continue reading and writing). - -## Overlay - -The overlay implementation in the fs package takes Linux overlayfs as a frame of -reference but corrects for several POSIX consistency errors. - -In Linux overlayfs, the `struct inode` used for reading and writing to the same -file may be different. This is because the `struct inode` is dissociated with -the process of copying up the file from the upper to the lower directory. Since -flock(2) and fcntl(2) locks, inotify(7) watches, page caches, and a file's -identity are all stored directly or indirectly off the `struct inode`, these -properties of the `struct inode` may be stale after the first modification. This -can lead to file locking bugs, missed inotify events, and inconsistent data in -shared memory mappings of files, to name a few problems. - -The fs package maintains a single `fs.Inode` to represent a directory entry in -an overlay and defines operations on this `fs.Inode` which synchronize with the -copy up process. This achieves several things: - -+ File locks, inotify watches, and the identity of the file need not be copied - at all. - -+ Memory mappings of files coordinate with the copy up process so that if a - file in the lower directory is memory mapped, all references to it are - invalidated, forcing the application to re-fault on memory mappings of the - file under the upper directory. - -The `fs.Inode` holds metadata about files in the upper and/or lower directories -via an `fs.overlayEntry`. The `fs.overlayEntry` implements the `fs.Mappable` -interface. It multiplexes between upper and lower directory memory mappings and -stores a copy of memory references so they can be transferred to the upper -directory `fs.Mappable` when the file is copied up. - -The lower filesystem in an overlay may contain another (nested) overlay, but the -upper filesystem may not contain another overlay. In other words, nested -overlays form a tree structure that only allows branching in the lower -filesystem. - -Caching decisions in the overlay are delegated to the upper filesystem, meaning -that the Keep and Revalidate methods on the overlay return the same values as -the upper filesystem. A small wrinkle is that the lower filesystem is not -allowed to return `true` from Revalidate, as the overlay can not reload inodes -from the lower filesystem. A lower filesystem that does return `true` from -Revalidate will trigger a panic. - -The `fs.Inode` also holds a reference to a `fs.MountedFilesystem` that -normalizes across the mounted filesystem state of the upper and lower -directories. - -When a file is copied from the lower to the upper directory, attempts to -interact with the file block until the copy completes. All copying synchronizes -with rename(2). - -## Future Work - -### Overlay - -When a file is copied from a lower directory to an upper directory, several -locks are taken: the global renamuMu and the copyMu of the `fs.Inode` being -copied. This blocks operations on the file, including fault handling of memory -mappings. Performance could be improved by copying files into a temporary -directory that resides on the same filesystem as the upper directory and doing -an atomic rename, holding locks only during the rename operation. - -Additionally files are copied up synchronously. For large files, this causes a -noticeable latency. Performance could be improved by pipelining copies at -non-overlapping file offsets. diff --git a/pkg/sentry/fs/anon/BUILD b/pkg/sentry/fs/anon/BUILD deleted file mode 100644 index 1ce56d79f..000000000 --- a/pkg/sentry/fs/anon/BUILD +++ /dev/null @@ -1,20 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "anon", - srcs = [ - "anon.go", - "device.go", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/hostarch", - "//pkg/sentry/device", - "//pkg/sentry/fs", - "//pkg/sentry/fs/fsutil", - ], -) diff --git a/pkg/sentry/fs/anon/anon_state_autogen.go b/pkg/sentry/fs/anon/anon_state_autogen.go new file mode 100644 index 000000000..b2b1a466e --- /dev/null +++ b/pkg/sentry/fs/anon/anon_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package anon diff --git a/pkg/sentry/fs/copy_up_test.go b/pkg/sentry/fs/copy_up_test.go deleted file mode 100644 index e04784db2..000000000 --- a/pkg/sentry/fs/copy_up_test.go +++ /dev/null @@ -1,183 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package fs_test - -import ( - "bytes" - "crypto/rand" - "fmt" - "io" - "testing" - - "gvisor.dev/gvisor/pkg/sentry/fs" - _ "gvisor.dev/gvisor/pkg/sentry/fs/tmpfs" - "gvisor.dev/gvisor/pkg/sentry/kernel/contexttest" - "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/usermem" -) - -const ( - // origFileSize is the original file size. This many bytes should be - // copied up before the test file is modified. - origFileSize = 4096 - - // truncatedFileSize is the size to truncate all test files. - truncateFileSize = 10 -) - -// TestConcurrentCopyUp is a copy up stress test for an overlay. -// -// It creates a 64-level deep directory tree in the lower filesystem and -// populates the last subdirectory with 64 files containing random content: -// -// /lower -// /sudir0/.../subdir63/ -// /file0 -// ... -// /file63 -// -// The files are truncated concurrently by 4 goroutines per file. -// These goroutines contend with copying up all parent 64 subdirectories -// as well as the final file content. -// -// At the end of the test, we assert that the files respect the new truncated -// size and contain the content we expect. -func TestConcurrentCopyUp(t *testing.T) { - ctx := contexttest.Context(t) - files := makeOverlayTestFiles(t) - - var wg sync.WaitGroup - for _, file := range files { - for i := 0; i < 4; i++ { - wg.Add(1) - go func(o *overlayTestFile) { - if err := o.File.Dirent.Inode.Truncate(ctx, o.File.Dirent, truncateFileSize); err != nil { - t.Errorf("failed to copy up: %v", err) - } - wg.Done() - }(file) - } - } - wg.Wait() - - for _, file := range files { - got := make([]byte, origFileSize) - n, err := file.File.Readv(ctx, usermem.BytesIOSequence(got)) - if int(n) != truncateFileSize { - t.Fatalf("read %d bytes from file, want %d", n, truncateFileSize) - } - if err != nil && err != io.EOF { - t.Fatalf("read got error %v, want nil", err) - } - if !bytes.Equal(got[:n], file.content[:truncateFileSize]) { - t.Fatalf("file content is %v, want %v", got[:n], file.content[:truncateFileSize]) - } - } -} - -type overlayTestFile struct { - File *fs.File - name string - content []byte -} - -func makeOverlayTestFiles(t *testing.T) []*overlayTestFile { - ctx := contexttest.Context(t) - - // Create a lower tmpfs mount. - fsys, _ := fs.FindFilesystem("tmpfs") - lower, err := fsys.Mount(contexttest.Context(t), "", fs.MountSourceFlags{}, "", nil) - if err != nil { - t.Fatalf("failed to mount tmpfs: %v", err) - } - lowerRoot := fs.NewDirent(ctx, lower, "") - - // Make a deep set of subdirectories that everyone shares. - next := lowerRoot - for i := 0; i < 64; i++ { - name := fmt.Sprintf("subdir%d", i) - err := next.CreateDirectory(ctx, lowerRoot, name, fs.FilePermsFromMode(0777)) - if err != nil { - t.Fatalf("failed to create dir %q: %v", name, err) - } - next, err = next.Walk(ctx, lowerRoot, name) - if err != nil { - t.Fatalf("failed to walk to %q: %v", name, err) - } - } - - // Make a bunch of files in the last directory. - var files []*overlayTestFile - for i := 0; i < 64; i++ { - name := fmt.Sprintf("file%d", i) - f, err := next.Create(ctx, next, name, fs.FileFlags{Read: true, Write: true}, fs.FilePermsFromMode(0666)) - if err != nil { - t.Fatalf("failed to create file %q: %v", name, err) - } - defer f.DecRef(ctx) - - relname, _ := f.Dirent.FullName(lowerRoot) - - o := &overlayTestFile{ - name: relname, - content: make([]byte, origFileSize), - } - - if _, err := rand.Read(o.content); err != nil { - t.Fatalf("failed to read from /dev/urandom: %v", err) - } - - if _, err := f.Writev(ctx, usermem.BytesIOSequence(o.content)); err != nil { - t.Fatalf("failed to write content to file %q: %v", name, err) - } - - files = append(files, o) - } - - // Create an empty upper tmpfs mount which we will copy up into. - upper, err := fsys.Mount(ctx, "", fs.MountSourceFlags{}, "", nil) - if err != nil { - t.Fatalf("failed to mount tmpfs: %v", err) - } - - // Construct an overlay root. - overlay, err := fs.NewOverlayRoot(ctx, upper, lower, fs.MountSourceFlags{}) - if err != nil { - t.Fatalf("failed to construct overlay root: %v", err) - } - - // Create a MountNamespace to traverse the file system. - mns, err := fs.NewMountNamespace(ctx, overlay) - if err != nil { - t.Fatalf("failed to construct mount manager: %v", err) - } - - // Walk to all of the files in the overlay, open them readable. - for _, f := range files { - maxTraversals := uint(0) - d, err := mns.FindInode(ctx, mns.Root(), mns.Root(), f.name, &maxTraversals) - if err != nil { - t.Fatalf("failed to find %q: %v", f.name, err) - } - defer d.DecRef(ctx) - - f.File, err = d.Inode.GetFile(ctx, d, fs.FileFlags{Read: true}) - if err != nil { - t.Fatalf("failed to open file %q readable: %v", f.name, err) - } - } - - return files -} diff --git a/pkg/sentry/fs/dev/BUILD b/pkg/sentry/fs/dev/BUILD deleted file mode 100644 index 7baf26b24..000000000 --- a/pkg/sentry/fs/dev/BUILD +++ /dev/null @@ -1,41 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "dev", - srcs = [ - "dev.go", - "device.go", - "fs.go", - "full.go", - "net_tun.go", - "null.go", - "random.go", - "tty.go", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/errors/linuxerr", - "//pkg/hostarch", - "//pkg/rand", - "//pkg/safemem", - "//pkg/sentry/arch", - "//pkg/sentry/device", - "//pkg/sentry/fs", - "//pkg/sentry/fs/fsutil", - "//pkg/sentry/fs/ramfs", - "//pkg/sentry/fs/tmpfs", - "//pkg/sentry/inet", - "//pkg/sentry/kernel", - "//pkg/sentry/memmap", - "//pkg/sentry/mm", - "//pkg/sentry/pgalloc", - "//pkg/sentry/socket/netstack", - "//pkg/tcpip/link/tun", - "//pkg/usermem", - "//pkg/waiter", - ], -) diff --git a/pkg/sentry/fs/dev/dev_state_autogen.go b/pkg/sentry/fs/dev/dev_state_autogen.go new file mode 100644 index 000000000..23945d391 --- /dev/null +++ b/pkg/sentry/fs/dev/dev_state_autogen.go @@ -0,0 +1,324 @@ +// automatically generated by stateify. + +package dev + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (f *filesystem) StateTypeName() string { + return "pkg/sentry/fs/dev.filesystem" +} + +func (f *filesystem) StateFields() []string { + return []string{} +} + +func (f *filesystem) beforeSave() {} + +// +checklocksignore +func (f *filesystem) StateSave(stateSinkObject state.Sink) { + f.beforeSave() +} + +func (f *filesystem) afterLoad() {} + +// +checklocksignore +func (f *filesystem) StateLoad(stateSourceObject state.Source) { +} + +func (f *fullDevice) StateTypeName() string { + return "pkg/sentry/fs/dev.fullDevice" +} + +func (f *fullDevice) StateFields() []string { + return []string{ + "InodeSimpleAttributes", + } +} + +func (f *fullDevice) beforeSave() {} + +// +checklocksignore +func (f *fullDevice) StateSave(stateSinkObject state.Sink) { + f.beforeSave() + stateSinkObject.Save(0, &f.InodeSimpleAttributes) +} + +func (f *fullDevice) afterLoad() {} + +// +checklocksignore +func (f *fullDevice) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &f.InodeSimpleAttributes) +} + +func (f *fullFileOperations) StateTypeName() string { + return "pkg/sentry/fs/dev.fullFileOperations" +} + +func (f *fullFileOperations) StateFields() []string { + return []string{} +} + +func (f *fullFileOperations) beforeSave() {} + +// +checklocksignore +func (f *fullFileOperations) StateSave(stateSinkObject state.Sink) { + f.beforeSave() +} + +func (f *fullFileOperations) afterLoad() {} + +// +checklocksignore +func (f *fullFileOperations) StateLoad(stateSourceObject state.Source) { +} + +func (n *netTunInodeOperations) StateTypeName() string { + return "pkg/sentry/fs/dev.netTunInodeOperations" +} + +func (n *netTunInodeOperations) StateFields() []string { + return []string{ + "InodeSimpleAttributes", + } +} + +func (n *netTunInodeOperations) beforeSave() {} + +// +checklocksignore +func (n *netTunInodeOperations) StateSave(stateSinkObject state.Sink) { + n.beforeSave() + stateSinkObject.Save(0, &n.InodeSimpleAttributes) +} + +func (n *netTunInodeOperations) afterLoad() {} + +// +checklocksignore +func (n *netTunInodeOperations) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &n.InodeSimpleAttributes) +} + +func (n *netTunFileOperations) StateTypeName() string { + return "pkg/sentry/fs/dev.netTunFileOperations" +} + +func (n *netTunFileOperations) StateFields() []string { + return []string{ + "device", + } +} + +func (n *netTunFileOperations) beforeSave() {} + +// +checklocksignore +func (n *netTunFileOperations) StateSave(stateSinkObject state.Sink) { + n.beforeSave() + stateSinkObject.Save(0, &n.device) +} + +func (n *netTunFileOperations) afterLoad() {} + +// +checklocksignore +func (n *netTunFileOperations) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &n.device) +} + +func (n *nullDevice) StateTypeName() string { + return "pkg/sentry/fs/dev.nullDevice" +} + +func (n *nullDevice) StateFields() []string { + return []string{ + "InodeSimpleAttributes", + } +} + +func (n *nullDevice) beforeSave() {} + +// +checklocksignore +func (n *nullDevice) StateSave(stateSinkObject state.Sink) { + n.beforeSave() + stateSinkObject.Save(0, &n.InodeSimpleAttributes) +} + +func (n *nullDevice) afterLoad() {} + +// +checklocksignore +func (n *nullDevice) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &n.InodeSimpleAttributes) +} + +func (n *nullFileOperations) StateTypeName() string { + return "pkg/sentry/fs/dev.nullFileOperations" +} + +func (n *nullFileOperations) StateFields() []string { + return []string{} +} + +func (n *nullFileOperations) beforeSave() {} + +// +checklocksignore +func (n *nullFileOperations) StateSave(stateSinkObject state.Sink) { + n.beforeSave() +} + +func (n *nullFileOperations) afterLoad() {} + +// +checklocksignore +func (n *nullFileOperations) StateLoad(stateSourceObject state.Source) { +} + +func (zd *zeroDevice) StateTypeName() string { + return "pkg/sentry/fs/dev.zeroDevice" +} + +func (zd *zeroDevice) StateFields() []string { + return []string{ + "nullDevice", + } +} + +func (zd *zeroDevice) beforeSave() {} + +// +checklocksignore +func (zd *zeroDevice) StateSave(stateSinkObject state.Sink) { + zd.beforeSave() + stateSinkObject.Save(0, &zd.nullDevice) +} + +func (zd *zeroDevice) afterLoad() {} + +// +checklocksignore +func (zd *zeroDevice) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &zd.nullDevice) +} + +func (z *zeroFileOperations) StateTypeName() string { + return "pkg/sentry/fs/dev.zeroFileOperations" +} + +func (z *zeroFileOperations) StateFields() []string { + return []string{} +} + +func (z *zeroFileOperations) beforeSave() {} + +// +checklocksignore +func (z *zeroFileOperations) StateSave(stateSinkObject state.Sink) { + z.beforeSave() +} + +func (z *zeroFileOperations) afterLoad() {} + +// +checklocksignore +func (z *zeroFileOperations) StateLoad(stateSourceObject state.Source) { +} + +func (r *randomDevice) StateTypeName() string { + return "pkg/sentry/fs/dev.randomDevice" +} + +func (r *randomDevice) StateFields() []string { + return []string{ + "InodeSimpleAttributes", + } +} + +func (r *randomDevice) beforeSave() {} + +// +checklocksignore +func (r *randomDevice) StateSave(stateSinkObject state.Sink) { + r.beforeSave() + stateSinkObject.Save(0, &r.InodeSimpleAttributes) +} + +func (r *randomDevice) afterLoad() {} + +// +checklocksignore +func (r *randomDevice) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &r.InodeSimpleAttributes) +} + +func (r *randomFileOperations) StateTypeName() string { + return "pkg/sentry/fs/dev.randomFileOperations" +} + +func (r *randomFileOperations) StateFields() []string { + return []string{} +} + +func (r *randomFileOperations) beforeSave() {} + +// +checklocksignore +func (r *randomFileOperations) StateSave(stateSinkObject state.Sink) { + r.beforeSave() +} + +func (r *randomFileOperations) afterLoad() {} + +// +checklocksignore +func (r *randomFileOperations) StateLoad(stateSourceObject state.Source) { +} + +func (t *ttyInodeOperations) StateTypeName() string { + return "pkg/sentry/fs/dev.ttyInodeOperations" +} + +func (t *ttyInodeOperations) StateFields() []string { + return []string{ + "InodeSimpleAttributes", + } +} + +func (t *ttyInodeOperations) beforeSave() {} + +// +checklocksignore +func (t *ttyInodeOperations) StateSave(stateSinkObject state.Sink) { + t.beforeSave() + stateSinkObject.Save(0, &t.InodeSimpleAttributes) +} + +func (t *ttyInodeOperations) afterLoad() {} + +// +checklocksignore +func (t *ttyInodeOperations) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &t.InodeSimpleAttributes) +} + +func (t *ttyFileOperations) StateTypeName() string { + return "pkg/sentry/fs/dev.ttyFileOperations" +} + +func (t *ttyFileOperations) StateFields() []string { + return []string{} +} + +func (t *ttyFileOperations) beforeSave() {} + +// +checklocksignore +func (t *ttyFileOperations) StateSave(stateSinkObject state.Sink) { + t.beforeSave() +} + +func (t *ttyFileOperations) afterLoad() {} + +// +checklocksignore +func (t *ttyFileOperations) StateLoad(stateSourceObject state.Source) { +} + +func init() { + state.Register((*filesystem)(nil)) + state.Register((*fullDevice)(nil)) + state.Register((*fullFileOperations)(nil)) + state.Register((*netTunInodeOperations)(nil)) + state.Register((*netTunFileOperations)(nil)) + state.Register((*nullDevice)(nil)) + state.Register((*nullFileOperations)(nil)) + state.Register((*zeroDevice)(nil)) + state.Register((*zeroFileOperations)(nil)) + state.Register((*randomDevice)(nil)) + state.Register((*randomFileOperations)(nil)) + state.Register((*ttyInodeOperations)(nil)) + state.Register((*ttyFileOperations)(nil)) +} diff --git a/pkg/sentry/fs/dirent_cache_test.go b/pkg/sentry/fs/dirent_cache_test.go deleted file mode 100644 index 395c879f5..000000000 --- a/pkg/sentry/fs/dirent_cache_test.go +++ /dev/null @@ -1,247 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package fs - -import ( - "testing" -) - -func TestDirentCache(t *testing.T) { - const maxSize = 5 - - c := NewDirentCache(maxSize) - - // Size starts at 0. - if got, want := c.Size(), uint64(0); got != want { - t.Errorf("c.Size() got %v, want %v", got, want) - } - - // Create a Dirent d. - d := NewNegativeDirent("") - - // c does not contain d. - if got, want := c.contains(d), false; got != want { - t.Errorf("c.contains(d) got %v want %v", got, want) - } - - // Add d to the cache. - c.Add(d) - - // Size is now 1. - if got, want := c.Size(), uint64(1); got != want { - t.Errorf("c.Size() got %v, want %v", got, want) - } - - // c contains d. - if got, want := c.contains(d), true; got != want { - t.Errorf("c.contains(d) got %v want %v", got, want) - } - - // Add maxSize-1 more elements. d should be oldest element. - for i := 0; i < maxSize-1; i++ { - c.Add(NewNegativeDirent("")) - } - - // Size is maxSize. - if got, want := c.Size(), uint64(maxSize); got != want { - t.Errorf("c.Size() got %v, want %v", got, want) - } - - // c contains d. - if got, want := c.contains(d), true; got != want { - t.Errorf("c.contains(d) got %v want %v", got, want) - } - - // "Bump" d to the front by re-adding it. - c.Add(d) - - // Size is maxSize. - if got, want := c.Size(), uint64(maxSize); got != want { - t.Errorf("c.Size() got %v, want %v", got, want) - } - - // c contains d. - if got, want := c.contains(d), true; got != want { - t.Errorf("c.contains(d) got %v want %v", got, want) - } - - // Add maxSize-1 more elements. d should again be oldest element. - for i := 0; i < maxSize-1; i++ { - c.Add(NewNegativeDirent("")) - } - - // Size is maxSize. - if got, want := c.Size(), uint64(maxSize); got != want { - t.Errorf("c.Size() got %v, want %v", got, want) - } - - // c contains d. - if got, want := c.contains(d), true; got != want { - t.Errorf("c.contains(d) got %v want %v", got, want) - } - - // Add one more element, which will bump d from the cache. - c.Add(NewNegativeDirent("")) - - // Size is maxSize. - if got, want := c.Size(), uint64(maxSize); got != want { - t.Errorf("c.Size() got %v, want %v", got, want) - } - - // c does not contain d. - if got, want := c.contains(d), false; got != want { - t.Errorf("c.contains(d) got %v want %v", got, want) - } - - // Invalidating causes size to be 0 and list to be empty. - c.Invalidate() - if got, want := c.Size(), uint64(0); got != want { - t.Errorf("c.Size() got %v, want %v", got, want) - } - if got, want := c.list.Empty(), true; got != want { - t.Errorf("c.list.Empty() got %v, want %v", got, want) - } - - // Fill cache with maxSize dirents. - for i := 0; i < maxSize; i++ { - c.Add(NewNegativeDirent("")) - } -} - -func TestDirentCacheLimiter(t *testing.T) { - const ( - globalMaxSize = 5 - maxSize = 3 - ) - - limit := NewDirentCacheLimiter(globalMaxSize) - c1 := NewDirentCache(maxSize) - c1.limit = limit - c2 := NewDirentCache(maxSize) - c2.limit = limit - - // Create a Dirent d. - d := NewNegativeDirent("") - - // Add d to the cache. - c1.Add(d) - if got, want := c1.Size(), uint64(1); got != want { - t.Errorf("c1.Size() got %v, want %v", got, want) - } - - // Add maxSize-1 more elements. d should be oldest element. - for i := 0; i < maxSize-1; i++ { - c1.Add(NewNegativeDirent("")) - } - if got, want := c1.Size(), uint64(maxSize); got != want { - t.Errorf("c1.Size() got %v, want %v", got, want) - } - - // Check that d is still there. - if got, want := c1.contains(d), true; got != want { - t.Errorf("c1.contains(d) got %v want %v", got, want) - } - - // Fill up the other cache, it will start dropping old entries from the cache - // when the global limit is reached. - for i := 0; i < maxSize; i++ { - c2.Add(NewNegativeDirent("")) - } - - // Check is what's remaining from global max. - if got, want := c2.Size(), globalMaxSize-maxSize; int(got) != want { - t.Errorf("c2.Size() got %v, want %v", got, want) - } - - // Check that d was not dropped. - if got, want := c1.contains(d), true; got != want { - t.Errorf("c1.contains(d) got %v want %v", got, want) - } - - // Add an entry that will eventually be dropped. Check is done later... - drop := NewNegativeDirent("") - c1.Add(drop) - - // Check that d is bumped to front even when global limit is reached. - c1.Add(d) - if got, want := c1.contains(d), true; got != want { - t.Errorf("c1.contains(d) got %v want %v", got, want) - } - - // Add 2 more element and check that: - // - d is still in the list: to verify that d was bumped - // - d2/d3 are in the list: older entries are dropped when global limit is - // reached. - // - drop is not in the list: indeed older elements are dropped. - d2 := NewNegativeDirent("") - c1.Add(d2) - d3 := NewNegativeDirent("") - c1.Add(d3) - if got, want := c1.contains(d), true; got != want { - t.Errorf("c1.contains(d) got %v want %v", got, want) - } - if got, want := c1.contains(d2), true; got != want { - t.Errorf("c1.contains(d2) got %v want %v", got, want) - } - if got, want := c1.contains(d3), true; got != want { - t.Errorf("c1.contains(d3) got %v want %v", got, want) - } - if got, want := c1.contains(drop), false; got != want { - t.Errorf("c1.contains(drop) got %v want %v", got, want) - } - - // Drop all entries from one cache. The other will be allowed to grow. - c1.Invalidate() - c2.Add(NewNegativeDirent("")) - if got, want := c2.Size(), uint64(maxSize); got != want { - t.Errorf("c2.Size() got %v, want %v", got, want) - } -} - -// TestNilDirentCache tests that a nil cache supports all cache operations, but -// treats them as noop. -func TestNilDirentCache(t *testing.T) { - // Create a nil cache. - var c *DirentCache - - // Size is zero. - if got, want := c.Size(), uint64(0); got != want { - t.Errorf("c.Size() got %v, want %v", got, want) - } - - // Call Add. - c.Add(NewNegativeDirent("")) - - // Size is zero. - if got, want := c.Size(), uint64(0); got != want { - t.Errorf("c.Size() got %v, want %v", got, want) - } - - // Call Remove. - c.Remove(NewNegativeDirent("")) - - // Size is zero. - if got, want := c.Size(), uint64(0); got != want { - t.Errorf("c.Size() got %v, want %v", got, want) - } - - // Call Invalidate. - c.Invalidate() - - // Size is zero. - if got, want := c.Size(), uint64(0); got != want { - t.Errorf("c.Size() got %v, want %v", got, want) - } -} diff --git a/pkg/sentry/fs/dirent_list.go b/pkg/sentry/fs/dirent_list.go new file mode 100644 index 000000000..00dd39146 --- /dev/null +++ b/pkg/sentry/fs/dirent_list.go @@ -0,0 +1,221 @@ +package fs + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type direntElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (direntElementMapper) linkerFor(elem *Dirent) *Dirent { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type direntList struct { + head *Dirent + tail *Dirent +} + +// Reset resets list l to the empty state. +func (l *direntList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +// +//go:nosplit +func (l *direntList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +// +//go:nosplit +func (l *direntList) Front() *Dirent { + return l.head +} + +// Back returns the last element of list l or nil. +// +//go:nosplit +func (l *direntList) Back() *Dirent { + return l.tail +} + +// Len returns the number of elements in the list. +// +// NOTE: This is an O(n) operation. +// +//go:nosplit +func (l *direntList) Len() (count int) { + for e := l.Front(); e != nil; e = (direntElementMapper{}.linkerFor(e)).Next() { + count++ + } + return count +} + +// PushFront inserts the element e at the front of list l. +// +//go:nosplit +func (l *direntList) PushFront(e *Dirent) { + linker := direntElementMapper{}.linkerFor(e) + linker.SetNext(l.head) + linker.SetPrev(nil) + if l.head != nil { + direntElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +// +//go:nosplit +func (l *direntList) PushBack(e *Dirent) { + linker := direntElementMapper{}.linkerFor(e) + linker.SetNext(nil) + linker.SetPrev(l.tail) + if l.tail != nil { + direntElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +// +//go:nosplit +func (l *direntList) PushBackList(m *direntList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + direntElementMapper{}.linkerFor(l.tail).SetNext(m.head) + direntElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +// +//go:nosplit +func (l *direntList) InsertAfter(b, e *Dirent) { + bLinker := direntElementMapper{}.linkerFor(b) + eLinker := direntElementMapper{}.linkerFor(e) + + a := bLinker.Next() + + eLinker.SetNext(a) + eLinker.SetPrev(b) + bLinker.SetNext(e) + + if a != nil { + direntElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +// +//go:nosplit +func (l *direntList) InsertBefore(a, e *Dirent) { + aLinker := direntElementMapper{}.linkerFor(a) + eLinker := direntElementMapper{}.linkerFor(e) + + b := aLinker.Prev() + eLinker.SetNext(a) + eLinker.SetPrev(b) + aLinker.SetPrev(e) + + if b != nil { + direntElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +// +//go:nosplit +func (l *direntList) Remove(e *Dirent) { + linker := direntElementMapper{}.linkerFor(e) + prev := linker.Prev() + next := linker.Next() + + if prev != nil { + direntElementMapper{}.linkerFor(prev).SetNext(next) + } else if l.head == e { + l.head = next + } + + if next != nil { + direntElementMapper{}.linkerFor(next).SetPrev(prev) + } else if l.tail == e { + l.tail = prev + } + + linker.SetNext(nil) + linker.SetPrev(nil) +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type direntEntry struct { + next *Dirent + prev *Dirent +} + +// Next returns the entry that follows e in the list. +// +//go:nosplit +func (e *direntEntry) Next() *Dirent { + return e.next +} + +// Prev returns the entry that precedes e in the list. +// +//go:nosplit +func (e *direntEntry) Prev() *Dirent { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +// +//go:nosplit +func (e *direntEntry) SetNext(elem *Dirent) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +// +//go:nosplit +func (e *direntEntry) SetPrev(elem *Dirent) { + e.prev = elem +} diff --git a/pkg/sentry/fs/dirent_refs_test.go b/pkg/sentry/fs/dirent_refs_test.go deleted file mode 100644 index e2b66f357..000000000 --- a/pkg/sentry/fs/dirent_refs_test.go +++ /dev/null @@ -1,418 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package fs - -import ( - "testing" - - "golang.org/x/sys/unix" - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/sentry/contexttest" -) - -func newMockDirInode(ctx context.Context, cache *DirentCache) *Inode { - return NewMockInode(ctx, NewMockMountSource(cache), StableAttr{Type: Directory}) -} - -func TestWalkPositive(t *testing.T) { - // refs == 0 -> one reference. - // refs == -1 -> has been destroyed. - - ctx := contexttest.Context(t) - root := NewDirent(ctx, newMockDirInode(ctx, nil), "root") - - if got := root.ReadRefs(); got != 1 { - t.Fatalf("root has a ref count of %d, want %d", got, 1) - } - - name := "d" - d, err := root.walk(ctx, root, name, false) - if err != nil { - t.Fatalf("root.walk(root, %q) got %v, want nil", name, err) - } - - if got := root.ReadRefs(); got != 2 { - t.Fatalf("root has a ref count of %d, want %d", got, 2) - } - - if got := d.ReadRefs(); got != 1 { - t.Fatalf("child name = %q has a ref count of %d, want %d", d.name, got, 1) - } - - d.DecRef(ctx) - - if got := root.ReadRefs(); got != 1 { - t.Fatalf("root has a ref count of %d, want %d", got, 1) - } - - if got := d.ReadRefs(); got != 0 { - t.Fatalf("child name = %q has a ref count of %d, want %d", d.name, got, 0) - } - - root.flush(ctx) - - if got := len(root.children); got != 0 { - t.Fatalf("root has %d children, want %d", got, 0) - } -} - -func TestWalkNegative(t *testing.T) { - // refs == 0 -> one reference. - // refs == -1 -> has been destroyed. - - ctx := contexttest.Context(t) - root := NewDirent(ctx, NewEmptyDir(ctx, nil), "root") - mn := root.Inode.InodeOperations.(*mockInodeOperationsLookupNegative) - - if got := root.ReadRefs(); got != 1 { - t.Fatalf("root has a ref count of %d, want %d", got, 1) - } - - name := "d" - for i := 0; i < 100; i++ { - _, err := root.walk(ctx, root, name, false) - if err != unix.ENOENT { - t.Fatalf("root.walk(root, %q) got %v, want %v", name, err, unix.ENOENT) - } - } - - if got := root.ReadRefs(); got != 1 { - t.Fatalf("root has a ref count of %d, want %d", got, 1) - } - - if got := len(root.children); got != 1 { - t.Fatalf("root has %d children, want %d", got, 1) - } - - w, ok := root.children[name] - if !ok { - t.Fatalf("root wants child at %q", name) - } - - child := w.Get() - if child == nil { - t.Fatalf("root wants to resolve weak reference") - } - - if !child.(*Dirent).IsNegative() { - t.Fatalf("root found positive child at %q, want negative", name) - } - - if got := child.(*Dirent).ReadRefs(); got != 2 { - t.Fatalf("child has a ref count of %d, want %d", got, 2) - } - - child.DecRef(ctx) - - if got := child.(*Dirent).ReadRefs(); got != 1 { - t.Fatalf("child has a ref count of %d, want %d", got, 1) - } - - if got := len(root.children); got != 1 { - t.Fatalf("root has %d children, want %d", got, 1) - } - - root.DecRef(ctx) - - if got := root.ReadRefs(); got != 0 { - t.Fatalf("root has a ref count of %d, want %d", got, 0) - } - - AsyncBarrier() - - if got := mn.releaseCalled; got != true { - t.Fatalf("root.Close was called %v, want true", got) - } -} - -type mockInodeOperationsLookupNegative struct { - *MockInodeOperations - releaseCalled bool -} - -func NewEmptyDir(ctx context.Context, cache *DirentCache) *Inode { - m := NewMockMountSource(cache) - return NewInode(ctx, &mockInodeOperationsLookupNegative{ - MockInodeOperations: NewMockInodeOperations(ctx), - }, m, StableAttr{Type: Directory}) -} - -func (m *mockInodeOperationsLookupNegative) Lookup(ctx context.Context, dir *Inode, p string) (*Dirent, error) { - return NewNegativeDirent(p), nil -} - -func (m *mockInodeOperationsLookupNegative) Release(context.Context) { - m.releaseCalled = true -} - -func TestHashNegativeToPositive(t *testing.T) { - // refs == 0 -> one reference. - // refs == -1 -> has been destroyed. - - ctx := contexttest.Context(t) - root := NewDirent(ctx, NewEmptyDir(ctx, nil), "root") - - name := "d" - _, err := root.walk(ctx, root, name, false) - if err != unix.ENOENT { - t.Fatalf("root.walk(root, %q) got %v, want %v", name, err, unix.ENOENT) - } - - if got := root.exists(ctx, root, name); got != false { - t.Fatalf("got %q exists, want does not exist", name) - } - - f, err := root.Create(ctx, root, name, FileFlags{}, FilePermissions{}) - if err != nil { - t.Fatalf("root.Create(%q, _), got error %v, want nil", name, err) - } - d := f.Dirent - - if d.IsNegative() { - t.Fatalf("got negative Dirent, want positive") - } - - if got := d.ReadRefs(); got != 1 { - t.Fatalf("child %q has a ref count of %d, want %d", name, got, 1) - } - - if got := root.ReadRefs(); got != 2 { - t.Fatalf("root has a ref count of %d, want %d", got, 2) - } - - if got := len(root.children); got != 1 { - t.Fatalf("got %d children, want %d", got, 1) - } - - w, ok := root.children[name] - if !ok { - t.Fatalf("failed to find weak reference to %q", name) - } - - child := w.Get() - if child == nil { - t.Fatalf("want to resolve weak reference") - } - - if child.(*Dirent) != d { - t.Fatalf("got foreign child") - } -} - -func TestRevalidate(t *testing.T) { - // refs == 0 -> one reference. - // refs == -1 -> has been destroyed. - - for _, test := range []struct { - // desc is the test's description. - desc string - - // Whether to make negative Dirents. - makeNegative bool - }{ - { - desc: "Revalidate negative Dirent", - makeNegative: true, - }, - { - desc: "Revalidate positive Dirent", - makeNegative: false, - }, - } { - t.Run(test.desc, func(t *testing.T) { - ctx := contexttest.Context(t) - root := NewDirent(ctx, NewMockInodeRevalidate(ctx, test.makeNegative), "root") - - name := "d" - d1, err := root.walk(ctx, root, name, false) - if !test.makeNegative && err != nil { - t.Fatalf("root.walk(root, %q) got %v, want nil", name, err) - } - d2, err := root.walk(ctx, root, name, false) - if !test.makeNegative && err != nil { - t.Fatalf("root.walk(root, %q) got %v, want nil", name, err) - } - if !test.makeNegative && d1 == d2 { - t.Fatalf("revalidating walk got same *Dirent, want different") - } - if got := len(root.children); got != 1 { - t.Errorf("revalidating walk got %d children, want %d", got, 1) - } - }) - } -} - -type MockInodeOperationsRevalidate struct { - *MockInodeOperations - makeNegative bool -} - -func NewMockInodeRevalidate(ctx context.Context, makeNegative bool) *Inode { - mn := NewMockInodeOperations(ctx) - m := NewMockMountSource(nil) - m.MountSourceOperations.(*MockMountSourceOps).revalidate = true - return NewInode(ctx, &MockInodeOperationsRevalidate{MockInodeOperations: mn, makeNegative: makeNegative}, m, StableAttr{Type: Directory}) -} - -func (m *MockInodeOperationsRevalidate) Lookup(ctx context.Context, dir *Inode, p string) (*Dirent, error) { - if !m.makeNegative { - return m.MockInodeOperations.Lookup(ctx, dir, p) - } - return NewNegativeDirent(p), nil -} - -func TestCreateExtraRefs(t *testing.T) { - // refs == 0 -> one reference. - // refs == -1 -> has been destroyed. - - ctx := contexttest.Context(t) - for _, test := range []struct { - // desc is the test's description. - desc string - - // root is the Dirent to create from. - root *Dirent - - // expected references on walked Dirent. - refs int64 - }{ - { - desc: "Create caching", - root: NewDirent(ctx, NewEmptyDir(ctx, NewDirentCache(1)), "root"), - refs: 2, - }, - { - desc: "Create not caching", - root: NewDirent(ctx, NewEmptyDir(ctx, nil), "root"), - refs: 1, - }, - } { - t.Run(test.desc, func(t *testing.T) { - name := "d" - f, err := test.root.Create(ctx, test.root, name, FileFlags{}, FilePermissions{}) - if err != nil { - t.Fatalf("root.Create(root, %q) failed: %v", name, err) - } - d := f.Dirent - - if got := d.ReadRefs(); got != test.refs { - t.Errorf("dirent has a ref count of %d, want %d", got, test.refs) - } - }) - } -} - -func TestRemoveExtraRefs(t *testing.T) { - // refs == 0 -> one reference. - // refs == -1 -> has been destroyed. - - ctx := contexttest.Context(t) - for _, test := range []struct { - // desc is the test's description. - desc string - - // root is the Dirent to make and remove from. - root *Dirent - }{ - { - desc: "Remove caching", - root: NewDirent(ctx, NewEmptyDir(ctx, NewDirentCache(1)), "root"), - }, - { - desc: "Remove not caching", - root: NewDirent(ctx, NewEmptyDir(ctx, nil), "root"), - }, - } { - t.Run(test.desc, func(t *testing.T) { - name := "d" - f, err := test.root.Create(ctx, test.root, name, FileFlags{}, FilePermissions{}) - if err != nil { - t.Fatalf("root.Create(%q, _) failed: %v", name, err) - } - d := f.Dirent - - if err := test.root.Remove(contexttest.Context(t), test.root, name, false /* dirPath */); err != nil { - t.Fatalf("root.Remove(root, %q) failed: %v", name, err) - } - - if got := d.ReadRefs(); got != 1 { - t.Fatalf("dirent has a ref count of %d, want %d", got, 1) - } - - d.DecRef(ctx) - - test.root.flush(ctx) - - if got := len(test.root.children); got != 0 { - t.Errorf("root has %d children, want %d", got, 0) - } - }) - } -} - -func TestRenameExtraRefs(t *testing.T) { - // refs == 0 -> one reference. - // refs == -1 -> has been destroyed. - - for _, test := range []struct { - // desc is the test's description. - desc string - - // cache of extra Dirent references, may be nil. - cache *DirentCache - }{ - { - desc: "Rename no caching", - cache: nil, - }, - { - desc: "Rename caching", - cache: NewDirentCache(5), - }, - } { - t.Run(test.desc, func(t *testing.T) { - ctx := contexttest.Context(t) - - dirAttr := StableAttr{Type: Directory} - - oldParent := NewDirent(ctx, NewMockInode(ctx, NewMockMountSource(test.cache), dirAttr), "old_parent") - newParent := NewDirent(ctx, NewMockInode(ctx, NewMockMountSource(test.cache), dirAttr), "new_parent") - - renamed, err := oldParent.Walk(ctx, oldParent, "old_child") - if err != nil { - t.Fatalf("Walk(oldParent, %q) got error %v, want nil", "old_child", err) - } - replaced, err := newParent.Walk(ctx, oldParent, "new_child") - if err != nil { - t.Fatalf("Walk(newParent, %q) got error %v, want nil", "new_child", err) - } - - if err := Rename(contexttest.RootContext(t), oldParent /*root */, oldParent, "old_child", newParent, "new_child"); err != nil { - t.Fatalf("Rename got error %v, want nil", err) - } - - oldParent.flush(ctx) - newParent.flush(ctx) - - // Expect to have only active references. - if got := renamed.ReadRefs(); got != 1 { - t.Errorf("renamed has ref count %d, want only active references %d", got, 1) - } - if got := replaced.ReadRefs(); got != 1 { - t.Errorf("replaced has ref count %d, want only active references %d", got, 1) - } - }) - } -} diff --git a/pkg/sentry/fs/event_list.go b/pkg/sentry/fs/event_list.go new file mode 100644 index 000000000..ceb97b703 --- /dev/null +++ b/pkg/sentry/fs/event_list.go @@ -0,0 +1,221 @@ +package fs + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type eventElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (eventElementMapper) linkerFor(elem *Event) *Event { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type eventList struct { + head *Event + tail *Event +} + +// Reset resets list l to the empty state. +func (l *eventList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +// +//go:nosplit +func (l *eventList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +// +//go:nosplit +func (l *eventList) Front() *Event { + return l.head +} + +// Back returns the last element of list l or nil. +// +//go:nosplit +func (l *eventList) Back() *Event { + return l.tail +} + +// Len returns the number of elements in the list. +// +// NOTE: This is an O(n) operation. +// +//go:nosplit +func (l *eventList) Len() (count int) { + for e := l.Front(); e != nil; e = (eventElementMapper{}.linkerFor(e)).Next() { + count++ + } + return count +} + +// PushFront inserts the element e at the front of list l. +// +//go:nosplit +func (l *eventList) PushFront(e *Event) { + linker := eventElementMapper{}.linkerFor(e) + linker.SetNext(l.head) + linker.SetPrev(nil) + if l.head != nil { + eventElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +// +//go:nosplit +func (l *eventList) PushBack(e *Event) { + linker := eventElementMapper{}.linkerFor(e) + linker.SetNext(nil) + linker.SetPrev(l.tail) + if l.tail != nil { + eventElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +// +//go:nosplit +func (l *eventList) PushBackList(m *eventList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + eventElementMapper{}.linkerFor(l.tail).SetNext(m.head) + eventElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +// +//go:nosplit +func (l *eventList) InsertAfter(b, e *Event) { + bLinker := eventElementMapper{}.linkerFor(b) + eLinker := eventElementMapper{}.linkerFor(e) + + a := bLinker.Next() + + eLinker.SetNext(a) + eLinker.SetPrev(b) + bLinker.SetNext(e) + + if a != nil { + eventElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +// +//go:nosplit +func (l *eventList) InsertBefore(a, e *Event) { + aLinker := eventElementMapper{}.linkerFor(a) + eLinker := eventElementMapper{}.linkerFor(e) + + b := aLinker.Prev() + eLinker.SetNext(a) + eLinker.SetPrev(b) + aLinker.SetPrev(e) + + if b != nil { + eventElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +// +//go:nosplit +func (l *eventList) Remove(e *Event) { + linker := eventElementMapper{}.linkerFor(e) + prev := linker.Prev() + next := linker.Next() + + if prev != nil { + eventElementMapper{}.linkerFor(prev).SetNext(next) + } else if l.head == e { + l.head = next + } + + if next != nil { + eventElementMapper{}.linkerFor(next).SetPrev(prev) + } else if l.tail == e { + l.tail = prev + } + + linker.SetNext(nil) + linker.SetPrev(nil) +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type eventEntry struct { + next *Event + prev *Event +} + +// Next returns the entry that follows e in the list. +// +//go:nosplit +func (e *eventEntry) Next() *Event { + return e.next +} + +// Prev returns the entry that precedes e in the list. +// +//go:nosplit +func (e *eventEntry) Prev() *Event { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +// +//go:nosplit +func (e *eventEntry) SetNext(elem *Event) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +// +//go:nosplit +func (e *eventEntry) SetPrev(elem *Event) { + e.prev = elem +} diff --git a/pkg/sentry/fs/fdpipe/BUILD b/pkg/sentry/fs/fdpipe/BUILD deleted file mode 100644 index 9f1fe5160..000000000 --- a/pkg/sentry/fs/fdpipe/BUILD +++ /dev/null @@ -1,52 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "fdpipe", - srcs = [ - "pipe.go", - "pipe_opener.go", - "pipe_state.go", - ], - imports = ["gvisor.dev/gvisor/pkg/sentry/fs"], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/context", - "//pkg/errors/linuxerr", - "//pkg/fd", - "//pkg/fdnotifier", - "//pkg/log", - "//pkg/safemem", - "//pkg/secio", - "//pkg/sentry/fs", - "//pkg/sentry/fs/fsutil", - "//pkg/sync", - "//pkg/usermem", - "//pkg/waiter", - "@org_golang_x_sys//unix:go_default_library", - ], -) - -go_test( - name = "fdpipe_test", - size = "small", - srcs = [ - "pipe_opener_test.go", - "pipe_test.go", - ], - library = ":fdpipe", - deps = [ - "//pkg/context", - "//pkg/errors", - "//pkg/errors/linuxerr", - "//pkg/fd", - "//pkg/fdnotifier", - "//pkg/hostarch", - "//pkg/sentry/contexttest", - "//pkg/sentry/fs", - "//pkg/usermem", - "@com_github_google_uuid//:go_default_library", - "@org_golang_x_sys//unix:go_default_library", - ], -) diff --git a/pkg/sentry/fs/fdpipe/fdpipe_state_autogen.go b/pkg/sentry/fs/fdpipe/fdpipe_state_autogen.go new file mode 100644 index 000000000..68406ce66 --- /dev/null +++ b/pkg/sentry/fs/fdpipe/fdpipe_state_autogen.go @@ -0,0 +1,42 @@ +// automatically generated by stateify. + +package fdpipe + +import ( + "gvisor.dev/gvisor/pkg/sentry/fs" + "gvisor.dev/gvisor/pkg/state" +) + +func (p *pipeOperations) StateTypeName() string { + return "pkg/sentry/fs/fdpipe.pipeOperations" +} + +func (p *pipeOperations) StateFields() []string { + return []string{ + "flags", + "opener", + "readAheadBuffer", + } +} + +// +checklocksignore +func (p *pipeOperations) StateSave(stateSinkObject state.Sink) { + p.beforeSave() + var flagsValue fs.FileFlags + flagsValue = p.saveFlags() + stateSinkObject.SaveValue(0, flagsValue) + stateSinkObject.Save(1, &p.opener) + stateSinkObject.Save(2, &p.readAheadBuffer) +} + +// +checklocksignore +func (p *pipeOperations) StateLoad(stateSourceObject state.Source) { + stateSourceObject.LoadWait(1, &p.opener) + stateSourceObject.Load(2, &p.readAheadBuffer) + stateSourceObject.LoadValue(0, new(fs.FileFlags), func(y interface{}) { p.loadFlags(y.(fs.FileFlags)) }) + stateSourceObject.AfterLoad(p.afterLoad) +} + +func init() { + state.Register((*pipeOperations)(nil)) +} diff --git a/pkg/sentry/fs/fdpipe/pipe_opener_test.go b/pkg/sentry/fs/fdpipe/pipe_opener_test.go deleted file mode 100644 index e1587288e..000000000 --- a/pkg/sentry/fs/fdpipe/pipe_opener_test.go +++ /dev/null @@ -1,522 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package fdpipe - -import ( - "bytes" - "fmt" - "io" - "os" - "path" - "testing" - "time" - - "github.com/google/uuid" - "golang.org/x/sys/unix" - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/errors/linuxerr" - "gvisor.dev/gvisor/pkg/fd" - "gvisor.dev/gvisor/pkg/sentry/contexttest" - "gvisor.dev/gvisor/pkg/sentry/fs" - "gvisor.dev/gvisor/pkg/usermem" -) - -type hostOpener struct { - name string -} - -func (h *hostOpener) NonBlockingOpen(_ context.Context, p fs.PermMask) (*fd.FD, error) { - var flags int - switch { - case p.Read && p.Write: - flags = unix.O_RDWR - case p.Write: - flags = unix.O_WRONLY - case p.Read: - flags = unix.O_RDONLY - default: - return nil, unix.EINVAL - } - f, err := unix.Open(h.name, flags|unix.O_NONBLOCK, 0666) - if err != nil { - return nil, err - } - return fd.New(f), nil -} - -func pipename() string { - return fmt.Sprintf(path.Join(os.TempDir(), "test-named-pipe-%s"), uuid.New()) -} - -func mkpipe(name string) error { - return unix.Mknod(name, unix.S_IFIFO|0666, 0) -} - -func TestTryOpen(t *testing.T) { - for _, test := range []struct { - // desc is the test's description. - desc string - - // makePipe is true if the test case should create the pipe. - makePipe bool - - // flags are the fs.FileFlags used to open the pipe. - flags fs.FileFlags - - // expectFile is true if a fs.File is expected. - expectFile bool - - // err is the expected error - err error - }{ - { - desc: "FileFlags lacking Read and Write are invalid", - makePipe: false, - flags: fs.FileFlags{}, /* bogus */ - expectFile: false, - err: unix.EINVAL, - }, - { - desc: "NonBlocking Read only error returns immediately", - makePipe: false, /* causes the error */ - flags: fs.FileFlags{Read: true, NonBlocking: true}, - expectFile: false, - err: unix.ENOENT, - }, - { - desc: "NonBlocking Read only success returns immediately", - makePipe: true, - flags: fs.FileFlags{Read: true, NonBlocking: true}, - expectFile: true, - err: nil, - }, - { - desc: "NonBlocking Write only error returns immediately", - makePipe: false, /* causes the error */ - flags: fs.FileFlags{Write: true, NonBlocking: true}, - expectFile: false, - err: unix.ENOENT, - }, - { - desc: "NonBlocking Write only no reader error returns immediately", - makePipe: true, - flags: fs.FileFlags{Write: true, NonBlocking: true}, - expectFile: false, - err: unix.ENXIO, - }, - { - desc: "ReadWrite error returns immediately", - makePipe: false, /* causes the error */ - flags: fs.FileFlags{Read: true, Write: true}, - expectFile: false, - err: unix.ENOENT, - }, - { - desc: "ReadWrite returns immediately", - makePipe: true, - flags: fs.FileFlags{Read: true, Write: true}, - expectFile: true, - err: nil, - }, - { - desc: "Blocking Write only returns open error", - makePipe: false, /* causes the error */ - flags: fs.FileFlags{Write: true}, - expectFile: false, - err: unix.ENOENT, /* from bogus perms */ - }, - { - desc: "Blocking Read only returns open error", - makePipe: false, /* causes the error */ - flags: fs.FileFlags{Read: true}, - expectFile: false, - err: unix.ENOENT, - }, - { - desc: "Blocking Write only returns with linuxerr.ErrWouldBlock", - makePipe: true, - flags: fs.FileFlags{Write: true}, - expectFile: false, - err: linuxerr.ErrWouldBlock, - }, - { - desc: "Blocking Read only returns with linuxerr.ErrWouldBlock", - makePipe: true, - flags: fs.FileFlags{Read: true}, - expectFile: false, - err: linuxerr.ErrWouldBlock, - }, - } { - name := pipename() - if test.makePipe { - // Create the pipe. We do this per-test case to keep tests independent. - if err := mkpipe(name); err != nil { - t.Errorf("%s: failed to make host pipe: %v", test.desc, err) - continue - } - defer unix.Unlink(name) - } - - // Use a host opener to keep things simple. - opener := &hostOpener{name: name} - - pipeOpenState := &pipeOpenState{} - ctx := contexttest.Context(t) - pipeOps, err := pipeOpenState.TryOpen(ctx, opener, test.flags) - if unwrapError(err) != test.err { - t.Errorf("%s: got error %v, want %v", test.desc, err, test.err) - if pipeOps != nil { - // Cleanup the state of the pipe, and remove the fd from the - // fdnotifier. Sadly this needed to maintain the correctness - // of other tests because the fdnotifier is global. - pipeOps.Release(ctx) - } - continue - } - if (pipeOps != nil) != test.expectFile { - t.Errorf("%s: got non-nil file %v, want %v", test.desc, pipeOps != nil, test.expectFile) - } - if pipeOps != nil { - // Same as above. - pipeOps.Release(ctx) - } - } -} - -func TestPipeOpenUnblocksEventually(t *testing.T) { - for _, test := range []struct { - // desc is the test's description. - desc string - - // partnerIsReader is true if the goroutine opening the same pipe as the test case - // should open the pipe read only. Otherwise write only. This also means that the - // test case will open the pipe in the opposite way. - partnerIsReader bool - - // partnerIsBlocking is true if the goroutine opening the same pipe as the test case - // should do so without the O_NONBLOCK flag, otherwise opens the pipe with O_NONBLOCK - // until ENXIO is not returned. - partnerIsBlocking bool - }{ - { - desc: "Blocking Read with blocking writer partner opens eventually", - partnerIsReader: false, - partnerIsBlocking: true, - }, - { - desc: "Blocking Write with blocking reader partner opens eventually", - partnerIsReader: true, - partnerIsBlocking: true, - }, - { - desc: "Blocking Read with non-blocking writer partner opens eventually", - partnerIsReader: false, - partnerIsBlocking: false, - }, - { - desc: "Blocking Write with non-blocking reader partner opens eventually", - partnerIsReader: true, - partnerIsBlocking: false, - }, - } { - // Create the pipe. We do this per-test case to keep tests independent. - name := pipename() - if err := mkpipe(name); err != nil { - t.Errorf("%s: failed to make host pipe: %v", test.desc, err) - continue - } - defer unix.Unlink(name) - - // Spawn the partner. - type fderr struct { - fd int - err error - } - errch := make(chan fderr, 1) - go func() { - var flags int - if test.partnerIsReader { - flags = unix.O_RDONLY - } else { - flags = unix.O_WRONLY - } - if test.partnerIsBlocking { - fd, err := unix.Open(name, flags, 0666) - errch <- fderr{fd: fd, err: err} - } else { - var fd int - err := error(unix.ENXIO) - for err == unix.ENXIO { - fd, err = unix.Open(name, flags|unix.O_NONBLOCK, 0666) - time.Sleep(1 * time.Second) - } - errch <- fderr{fd: fd, err: err} - } - }() - - // Setup file flags for either a read only or write only open. - flags := fs.FileFlags{ - Read: !test.partnerIsReader, - Write: test.partnerIsReader, - } - - // Open the pipe in a blocking way, which should succeed eventually. - opener := &hostOpener{name: name} - ctx := contexttest.Context(t) - pipeOps, err := Open(ctx, opener, flags) - if pipeOps != nil { - // Same as TestTryOpen. - pipeOps.Release(ctx) - } - - // Check that the partner opened the file successfully. - e := <-errch - if e.err != nil { - t.Errorf("%s: partner got error %v, wanted nil", test.desc, e.err) - continue - } - // If so, then close the partner fd to avoid leaking an fd. - unix.Close(e.fd) - - // Check that our blocking open was successful. - if err != nil { - t.Errorf("%s: blocking open got error %v, wanted nil", test.desc, err) - continue - } - if pipeOps == nil { - t.Errorf("%s: blocking open got nil file, wanted non-nil", test.desc) - continue - } - } -} - -func TestCopiedReadAheadBuffer(t *testing.T) { - // Create the pipe. - name := pipename() - if err := mkpipe(name); err != nil { - t.Fatalf("failed to make host pipe: %v", err) - } - defer unix.Unlink(name) - - // We're taking advantage of the fact that pipes opened read only always return - // success, but internally they are not deemed "opened" until we're sure that - // another writer comes along. This means we can open the same pipe write only - // with no problems + write to it, given that opener.Open already tried to open - // the pipe RDONLY and succeeded, which we know happened if TryOpen returns - // linuxerr.ErrwouldBlock. - // - // This simulates the open(RDONLY) <-> open(WRONLY)+write race we care about, but - // does not cause our test to be racy (which would be terrible). - opener := &hostOpener{name: name} - pipeOpenState := &pipeOpenState{} - ctx := contexttest.Context(t) - pipeOps, err := pipeOpenState.TryOpen(ctx, opener, fs.FileFlags{Read: true}) - if pipeOps != nil { - pipeOps.Release(ctx) - t.Fatalf("open(%s, %o) got file, want nil", name, unix.O_RDONLY) - } - if err != linuxerr.ErrWouldBlock { - t.Fatalf("open(%s, %o) got error %v, want %v", name, unix.O_RDONLY, err, linuxerr.ErrWouldBlock) - } - - // Then open the same pipe write only and write some bytes to it. The next - // time we try to open the pipe read only again via the pipeOpenState, we should - // succeed and buffer some of the bytes written. - fd, err := unix.Open(name, unix.O_WRONLY, 0666) - if err != nil { - t.Fatalf("open(%s, %o) got error %v, want nil", name, unix.O_WRONLY, err) - } - defer unix.Close(fd) - - data := []byte("hello") - if n, err := unix.Write(fd, data); n != len(data) || err != nil { - t.Fatalf("write(%v) got (%d, %v), want (%d, nil)", data, n, err, len(data)) - } - - // Try the read again, knowing that it should succeed this time. - pipeOps, err = pipeOpenState.TryOpen(ctx, opener, fs.FileFlags{Read: true}) - if pipeOps == nil { - t.Fatalf("open(%s, %o) got nil file, want not nil", name, unix.O_RDONLY) - } - defer pipeOps.Release(ctx) - - if err != nil { - t.Fatalf("open(%s, %o) got error %v, want nil", name, unix.O_RDONLY, err) - } - - inode := fs.NewMockInode(ctx, fs.NewMockMountSource(nil), fs.StableAttr{ - Type: fs.Pipe, - }) - file := fs.NewFile(ctx, fs.NewDirent(ctx, inode, "pipe"), fs.FileFlags{Read: true}, pipeOps) - - // Check that the file we opened points to a pipe with a non-empty read ahead buffer. - bufsize := len(pipeOps.readAheadBuffer) - if bufsize != 1 { - t.Fatalf("read ahead buffer got %d bytes, want %d", bufsize, 1) - } - - // Now for the final test, try to read everything in, expecting to get back all of - // the bytes that were written at once. Note that in the wild there is no atomic - // read size so expecting to get all bytes from a single writer when there are - // multiple readers is a bad expectation. - buf := make([]byte, len(data)) - ioseq := usermem.BytesIOSequence(buf) - n, err := pipeOps.Read(ctx, file, ioseq, 0) - if err != nil { - t.Fatalf("read request got error %v, want nil", err) - } - if n != int64(len(data)) { - t.Fatalf("read request got %d bytes, want %d", n, len(data)) - } - if !bytes.Equal(buf, data) { - t.Errorf("read request got bytes [%v], want [%v]", buf, data) - } -} - -func TestPipeHangup(t *testing.T) { - for _, test := range []struct { - // desc is the test's description. - desc string - - // flags control how we open our end of the pipe and must be read - // only or write only. They also dicate how a coordinating partner - // fd is opened, which is their inverse (read only -> write only, etc). - flags fs.FileFlags - - // hangupSelf if true causes the test case to close our end of the pipe - // and causes hangup errors to be asserted on our coordinating partner's - // fd. If hangupSelf is false, then our partner's fd is closed and the - // hangup errors are expected on our end of the pipe. - hangupSelf bool - }{ - { - desc: "Read only gets hangup error", - flags: fs.FileFlags{Read: true}, - }, - { - desc: "Write only gets hangup error", - flags: fs.FileFlags{Write: true}, - }, - { - desc: "Read only generates hangup error", - flags: fs.FileFlags{Read: true}, - hangupSelf: true, - }, - { - desc: "Write only generates hangup error", - flags: fs.FileFlags{Write: true}, - hangupSelf: true, - }, - } { - if test.flags.Read == test.flags.Write { - t.Errorf("%s: test requires a single reader or writer", test.desc) - continue - } - - // Create the pipe. We do this per-test case to keep tests independent. - name := pipename() - if err := mkpipe(name); err != nil { - t.Errorf("%s: failed to make host pipe: %v", test.desc, err) - continue - } - defer unix.Unlink(name) - - // Fire off a partner routine which tries to open the same pipe blocking, - // which will synchronize with us. The channel allows us to get back the - // fd once we expect this partner routine to succeed, so we can manifest - // hangup events more directly. - fdchan := make(chan int, 1) - go func() { - // Be explicit about the flags to protect the test from - // misconfiguration. - var flags int - if test.flags.Read { - flags = unix.O_WRONLY - } else { - flags = unix.O_RDONLY - } - fd, err := unix.Open(name, flags, 0666) - if err != nil { - t.Logf("Open(%q, %o, 0666) partner failed: %v", name, flags, err) - } - fdchan <- fd - }() - - // Open our end in a blocking way to ensure that we coordinate. - opener := &hostOpener{name: name} - ctx := contexttest.Context(t) - pipeOps, err := Open(ctx, opener, test.flags) - if err != nil { - t.Errorf("%s: Open got error %v, want nil", test.desc, err) - continue - } - // Don't defer file.DecRef here because that causes the hangup we're - // trying to test for. - - // Expect the partner routine to have coordinated with us and get back - // its open fd. - f := <-fdchan - if f < 0 { - t.Errorf("%s: partner routine got fd %d, want > 0", test.desc, f) - pipeOps.Release(ctx) - continue - } - - if test.hangupSelf { - // Hangup self and assert that our partner got the expected hangup - // error. - pipeOps.Release(ctx) - - if test.flags.Read { - // Partner is writer. - assertWriterHungup(t, test.desc, fd.NewReadWriter(f)) - } else { - // Partner is reader. - assertReaderHungup(t, test.desc, fd.NewReadWriter(f)) - } - } else { - // Hangup our partner and expect us to get the hangup error. - unix.Close(f) - defer pipeOps.Release(ctx) - - if test.flags.Read { - assertReaderHungup(t, test.desc, pipeOps.(*pipeOperations).file) - } else { - assertWriterHungup(t, test.desc, pipeOps.(*pipeOperations).file) - } - } - } -} - -func assertReaderHungup(t *testing.T, desc string, reader io.Reader) bool { - // Drain the pipe completely, it might have crap in it, but expect EOF eventually. - var err error - for err == nil { - _, err = reader.Read(make([]byte, 10)) - } - if err != io.EOF { - t.Errorf("%s: read from self after hangup got error %v, want %v", desc, err, io.EOF) - return false - } - return true -} - -func assertWriterHungup(t *testing.T, desc string, writer io.Writer) bool { - if _, err := writer.Write([]byte("hello")); !linuxerr.Equals(linuxerr.EPIPE, unwrapError(err)) { - t.Errorf("%s: write to self after hangup got error %v, want %v", desc, err, linuxerr.EPIPE) - return false - } - return true -} diff --git a/pkg/sentry/fs/fdpipe/pipe_test.go b/pkg/sentry/fs/fdpipe/pipe_test.go deleted file mode 100644 index 63900e766..000000000 --- a/pkg/sentry/fs/fdpipe/pipe_test.go +++ /dev/null @@ -1,513 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package fdpipe - -import ( - "bytes" - "io" - "os" - "testing" - - "golang.org/x/sys/unix" - "gvisor.dev/gvisor/pkg/errors" - "gvisor.dev/gvisor/pkg/errors/linuxerr" - "gvisor.dev/gvisor/pkg/fd" - "gvisor.dev/gvisor/pkg/fdnotifier" - "gvisor.dev/gvisor/pkg/hostarch" - "gvisor.dev/gvisor/pkg/sentry/contexttest" - "gvisor.dev/gvisor/pkg/sentry/fs" - "gvisor.dev/gvisor/pkg/usermem" -) - -func singlePipeFD() (int, error) { - fds := make([]int, 2) - if err := unix.Pipe(fds); err != nil { - return -1, err - } - unix.Close(fds[1]) - return fds[0], nil -} - -func singleDirFD() (int, error) { - return unix.Open(os.TempDir(), unix.O_RDONLY, 0666) -} - -func mockPipeDirent(t *testing.T) *fs.Dirent { - ctx := contexttest.Context(t) - node := fs.NewMockInodeOperations(ctx) - node.UAttr = fs.UnstableAttr{ - Perms: fs.FilePermissions{ - User: fs.PermMask{Read: true, Write: true}, - }, - } - inode := fs.NewInode(ctx, node, fs.NewMockMountSource(nil), fs.StableAttr{ - Type: fs.Pipe, - BlockSize: hostarch.PageSize, - }) - return fs.NewDirent(ctx, inode, "") -} - -func TestNewPipe(t *testing.T) { - for _, test := range []struct { - // desc is the test's description. - desc string - - // getfd generates the fd to pass to newPipeOperations. - getfd func() (int, error) - - // flags are the fs.FileFlags passed to newPipeOperations. - flags fs.FileFlags - - // readAheadBuffer is the buffer passed to newPipeOperations. - readAheadBuffer []byte - - // err is the expected error. - err error - }{ - { - desc: "Cannot make new pipe from bad fd", - getfd: func() (int, error) { return -1, nil }, - err: unix.EINVAL, - }, - { - desc: "Cannot make new pipe from non-pipe fd", - getfd: singleDirFD, - err: unix.EINVAL, - }, - { - desc: "Can make new pipe from pipe fd", - getfd: singlePipeFD, - flags: fs.FileFlags{Read: true}, - readAheadBuffer: []byte("hello"), - }, - } { - gfd, err := test.getfd() - if err != nil { - t.Errorf("%s: getfd got (%d, %v), want (fd, nil)", test.desc, gfd, err) - continue - } - f := fd.New(gfd) - - ctx := contexttest.Context(t) - p, err := newPipeOperations(ctx, nil, test.flags, f, test.readAheadBuffer) - if p != nil { - // This is necessary to remove the fd from the global fd notifier. - defer p.Release(ctx) - } else { - // If there is no p to DecRef on, because newPipeOperations failed, then the - // file still needs to be closed. - defer f.Close() - } - - if err != test.err { - t.Errorf("%s: got error %v, want %v", test.desc, err, test.err) - continue - } - // Check the state of the pipe given that it was successfully opened. - if err == nil { - if p == nil { - t.Errorf("%s: got nil pipe and nil error, want (pipe, nil)", test.desc) - continue - } - if flags := p.flags; test.flags != flags { - t.Errorf("%s: got file flags %v, want %v", test.desc, flags, test.flags) - continue - } - if len(test.readAheadBuffer) != len(p.readAheadBuffer) { - t.Errorf("%s: got read ahead buffer length %d, want %d", test.desc, len(p.readAheadBuffer), len(test.readAheadBuffer)) - continue - } - fileFlags, _, errno := unix.Syscall(unix.SYS_FCNTL, uintptr(p.file.FD()), unix.F_GETFL, 0) - if errno != 0 { - t.Errorf("%s: failed to get file flags for fd %d, got %v, want 0", test.desc, p.file.FD(), errno) - continue - } - if fileFlags&unix.O_NONBLOCK == 0 { - t.Errorf("%s: pipe is blocking, expected non-blocking", test.desc) - continue - } - if !fdnotifier.HasFD(int32(f.FD())) { - t.Errorf("%s: pipe fd %d is not registered for events", test.desc, f.FD()) - } - } - } -} - -func TestPipeDestruction(t *testing.T) { - fds := make([]int, 2) - if err := unix.Pipe(fds); err != nil { - t.Fatalf("failed to create pipes: got %v, want nil", err) - } - f := fd.New(fds[0]) - - // We don't care about the other end, just use the read end. - unix.Close(fds[1]) - - // Test the read end, but it doesn't really matter which. - ctx := contexttest.Context(t) - p, err := newPipeOperations(ctx, nil, fs.FileFlags{Read: true}, f, nil) - if err != nil { - f.Close() - t.Fatalf("newPipeOperations got error %v, want nil", err) - } - // Drop our only reference, which should trigger the destructor. - p.Release(ctx) - - if fdnotifier.HasFD(int32(fds[0])) { - t.Fatalf("after DecRef fdnotifier has fd %d, want no longer registered", fds[0]) - } - if p.file != nil { - t.Errorf("after DecRef got file, want nil") - } -} - -type Seek struct{} - -type ReadDir struct{} - -type Writev struct { - Src usermem.IOSequence -} - -type Readv struct { - Dst usermem.IOSequence -} - -type Fsync struct{} - -func TestPipeRequest(t *testing.T) { - for _, test := range []struct { - // desc is the test's description. - desc string - - // request to execute. - context interface{} - - // flags determines whether to use the read or write end - // of the pipe, for this test it can only be Read or Write. - flags fs.FileFlags - - // keepOpenPartner if false closes the other end of the pipe, - // otherwise this is delayed until the end of the test. - keepOpenPartner bool - - // expected error - err error - }{ - { - desc: "ReadDir on pipe returns ENOTDIR", - context: &ReadDir{}, - err: linuxerr.ENOTDIR, - }, - { - desc: "Fsync on pipe returns EINVAL", - context: &Fsync{}, - err: linuxerr.EINVAL, - }, - { - desc: "Seek on pipe returns ESPIPE", - context: &Seek{}, - err: linuxerr.ESPIPE, - }, - { - desc: "Readv on pipe from empty buffer returns nil", - context: &Readv{Dst: usermem.BytesIOSequence(nil)}, - flags: fs.FileFlags{Read: true}, - }, - { - desc: "Readv on pipe from non-empty buffer and closed partner returns EOF", - context: &Readv{Dst: usermem.BytesIOSequence(make([]byte, 10))}, - flags: fs.FileFlags{Read: true}, - err: io.EOF, - }, - { - desc: "Readv on pipe from non-empty buffer and open partner returns EWOULDBLOCK", - context: &Readv{Dst: usermem.BytesIOSequence(make([]byte, 10))}, - flags: fs.FileFlags{Read: true}, - keepOpenPartner: true, - err: linuxerr.ErrWouldBlock, - }, - { - desc: "Writev on pipe from empty buffer returns nil", - context: &Writev{Src: usermem.BytesIOSequence(nil)}, - flags: fs.FileFlags{Write: true}, - }, - { - desc: "Writev on pipe from non-empty buffer and closed partner returns EPIPE", - context: &Writev{Src: usermem.BytesIOSequence([]byte("hello"))}, - flags: fs.FileFlags{Write: true}, - err: linuxerr.EPIPE, - }, - { - desc: "Writev on pipe from non-empty buffer and open partner succeeds", - context: &Writev{Src: usermem.BytesIOSequence([]byte("hello"))}, - flags: fs.FileFlags{Write: true}, - keepOpenPartner: true, - }, - } { - if test.flags.Read && test.flags.Write { - panic("both read and write not supported for this test") - } - - fds := make([]int, 2) - if err := unix.Pipe(fds); err != nil { - t.Errorf("%s: failed to create pipes: got %v, want nil", test.desc, err) - continue - } - - // Configure the fd and partner fd based on the file flags. - testFd, partnerFd := fds[0], fds[1] - if test.flags.Write { - testFd, partnerFd = fds[1], fds[0] - } - - // Configure closing the fds. - if test.keepOpenPartner { - defer unix.Close(partnerFd) - } else { - unix.Close(partnerFd) - } - - // Create the pipe. - ctx := contexttest.Context(t) - p, err := newPipeOperations(ctx, nil, test.flags, fd.New(testFd), nil) - if err != nil { - t.Fatalf("%s: newPipeOperations got error %v, want nil", test.desc, err) - } - defer p.Release(ctx) - - inode := fs.NewMockInode(ctx, fs.NewMockMountSource(nil), fs.StableAttr{Type: fs.Pipe}) - file := fs.NewFile(ctx, fs.NewDirent(ctx, inode, "pipe"), fs.FileFlags{Read: true}, p) - - // Issue request via the appropriate function. - switch c := test.context.(type) { - case *Seek: - _, err = p.Seek(ctx, file, 0, 0) - case *ReadDir: - _, err = p.Readdir(ctx, file, nil) - case *Readv: - _, err = p.Read(ctx, file, c.Dst, 0) - case *Writev: - _, err = p.Write(ctx, file, c.Src, 0) - case *Fsync: - err = p.Fsync(ctx, file, 0, fs.FileMaxOffset, fs.SyncAll) - default: - t.Errorf("%s: unknown request type %T", test.desc, test.context) - } - - if linuxErr, ok := test.err.(*errors.Error); ok { - if !linuxerr.Equals(linuxErr, unwrapError(err)) { - t.Errorf("%s: got error %v, want %v", test.desc, err, test.err) - } - } else if test.err != unwrapError(err) { - t.Errorf("%s: got error %v, want %v", test.desc, err, test.err) - } - } -} - -func TestPipeReadAheadBuffer(t *testing.T) { - fds := make([]int, 2) - if err := unix.Pipe(fds); err != nil { - t.Fatalf("failed to create pipes: got %v, want nil", err) - } - rfile := fd.New(fds[0]) - - // Eventually close the write end, which is not wrapped in a pipe object. - defer unix.Close(fds[1]) - - // Write some bytes to this end. - data := []byte("world") - if n, err := unix.Write(fds[1], data); n != len(data) || err != nil { - rfile.Close() - t.Fatalf("write to pipe got (%d, %v), want (%d, nil)", n, err, len(data)) - } - // Close the write end immediately, we don't care about it. - - buffered := []byte("hello ") - ctx := contexttest.Context(t) - p, err := newPipeOperations(ctx, nil, fs.FileFlags{Read: true}, rfile, buffered) - if err != nil { - rfile.Close() - t.Fatalf("newPipeOperations got error %v, want nil", err) - } - defer p.Release(ctx) - - inode := fs.NewMockInode(ctx, fs.NewMockMountSource(nil), fs.StableAttr{ - Type: fs.Pipe, - }) - file := fs.NewFile(ctx, fs.NewDirent(ctx, inode, "pipe"), fs.FileFlags{Read: true}, p) - - // In total we expect to read data + buffered. - total := append(buffered, data...) - - buf := make([]byte, len(total)) - iov := usermem.BytesIOSequence(buf) - n, err := p.Read(contexttest.Context(t), file, iov, 0) - if err != nil { - t.Fatalf("read request got error %v, want nil", err) - } - if n != int64(len(total)) { - t.Fatalf("read request got %d bytes, want %d", n, len(total)) - } - if !bytes.Equal(buf, total) { - t.Errorf("read request got bytes [%v], want [%v]", buf, total) - } -} - -// This is very important for pipes in general because they can return -// EWOULDBLOCK and for those that block they must continue until they have read -// all of the data (and report it as such). -func TestPipeReadsAccumulate(t *testing.T) { - fds := make([]int, 2) - if err := unix.Pipe(fds); err != nil { - t.Fatalf("failed to create pipes: got %v, want nil", err) - } - rfile := fd.New(fds[0]) - - // Eventually close the write end, it doesn't depend on a pipe object. - defer unix.Close(fds[1]) - - // Get a new read only pipe reference. - ctx := contexttest.Context(t) - p, err := newPipeOperations(ctx, nil, fs.FileFlags{Read: true}, rfile, nil) - if err != nil { - rfile.Close() - t.Fatalf("newPipeOperations got error %v, want nil", err) - } - // Don't forget to remove the fd from the fd notifier. Otherwise other tests will - // likely be borked, because it's global :( - defer p.Release(ctx) - - inode := fs.NewMockInode(ctx, fs.NewMockMountSource(nil), fs.StableAttr{ - Type: fs.Pipe, - }) - file := fs.NewFile(ctx, fs.NewDirent(ctx, inode, "pipe"), fs.FileFlags{Read: true}, p) - - // Write some some bytes to the pipe. - data := []byte("some message") - if n, err := unix.Write(fds[1], data); n != len(data) || err != nil { - t.Fatalf("write to pipe got (%d, %v), want (%d, nil)", n, err, len(data)) - } - - // Construct a segment vec that is a bit more than we have written so we - // trigger an EWOULDBLOCK. - wantBytes := len(data) + 1 - readBuffer := make([]byte, wantBytes) - iov := usermem.BytesIOSequence(readBuffer) - n, err := p.Read(ctx, file, iov, 0) - total := n - iov = iov.DropFirst64(n) - if err != linuxerr.ErrWouldBlock { - t.Fatalf("Readv got error %v, want %v", err, linuxerr.ErrWouldBlock) - } - - // Write a few more bytes to allow us to read more/accumulate. - extra := []byte("extra") - if n, err := unix.Write(fds[1], extra); n != len(extra) || err != nil { - t.Fatalf("write to pipe got (%d, %v), want (%d, nil)", n, err, len(extra)) - } - - // This time, using the same request, we should not block. - n, err = p.Read(ctx, file, iov, 0) - total += n - if err != nil { - t.Fatalf("Readv got error %v, want nil", err) - } - - // Assert that the result we got back is cumulative. - if total != int64(wantBytes) { - t.Fatalf("Readv sequence got %d bytes, want %d", total, wantBytes) - } - - if want := append(data, extra[0]); !bytes.Equal(readBuffer, want) { - t.Errorf("Readv sequence got %v, want %v", readBuffer, want) - } -} - -// Same as TestReadsAccumulate. -func TestPipeWritesAccumulate(t *testing.T) { - fds := make([]int, 2) - if err := unix.Pipe(fds); err != nil { - t.Fatalf("failed to create pipes: got %v, want nil", err) - } - wfile := fd.New(fds[1]) - - // Eventually close the read end, it doesn't depend on a pipe object. - defer unix.Close(fds[0]) - - // Get a new write only pipe reference. - ctx := contexttest.Context(t) - p, err := newPipeOperations(ctx, nil, fs.FileFlags{Write: true}, wfile, nil) - if err != nil { - wfile.Close() - t.Fatalf("newPipeOperations got error %v, want nil", err) - } - // Don't forget to remove the fd from the fd notifier. Otherwise other tests - // will likely be borked, because it's global :( - defer p.Release(ctx) - - inode := fs.NewMockInode(ctx, fs.NewMockMountSource(nil), fs.StableAttr{ - Type: fs.Pipe, - }) - file := fs.NewFile(ctx, fs.NewDirent(ctx, inode, "pipe"), fs.FileFlags{Read: true}, p) - - pipeSize, _, errno := unix.Syscall(unix.SYS_FCNTL, uintptr(wfile.FD()), unix.F_GETPIPE_SZ, 0) - if errno != 0 { - t.Fatalf("fcntl(F_GETPIPE_SZ) failed: %v", errno) - } - t.Logf("Pipe buffer size: %d", pipeSize) - - // Construct a segment vec that is larger than the pipe size to trigger an - // EWOULDBLOCK. - wantBytes := int(pipeSize) * 2 - writeBuffer := make([]byte, wantBytes) - for i := 0; i < wantBytes; i++ { - writeBuffer[i] = 'a' - } - iov := usermem.BytesIOSequence(writeBuffer) - n, err := p.Write(ctx, file, iov, 0) - if err != linuxerr.ErrWouldBlock { - t.Fatalf("Writev got error %v, want %v", err, linuxerr.ErrWouldBlock) - } - if n != int64(pipeSize) { - t.Fatalf("Writev partial write, got: %v, want %v", n, pipeSize) - } - total := n - iov = iov.DropFirst64(n) - - // Read the entire pipe buf size to make space for the second half. - readBuffer := make([]byte, n) - if n, err := unix.Read(fds[0], readBuffer); n != len(readBuffer) || err != nil { - t.Fatalf("write to pipe got (%d, %v), want (%d, nil)", n, err, len(readBuffer)) - } - if !bytes.Equal(readBuffer, writeBuffer[:len(readBuffer)]) { - t.Fatalf("wrong data read from pipe, got: %v, want: %v", readBuffer, writeBuffer) - } - - // This time we should not block. - n, err = p.Write(ctx, file, iov, 0) - if err != nil { - t.Fatalf("Writev got error %v, want nil", err) - } - if n != int64(pipeSize) { - t.Fatalf("Writev partial write, got: %v, want %v", n, pipeSize) - } - total += n - - // Assert that the result we got back is cumulative. - if total != int64(wantBytes) { - t.Fatalf("Writev sequence got %d bytes, want %d", total, wantBytes) - } -} diff --git a/pkg/sentry/fs/file_overlay_test.go b/pkg/sentry/fs/file_overlay_test.go deleted file mode 100644 index 1971cc680..000000000 --- a/pkg/sentry/fs/file_overlay_test.go +++ /dev/null @@ -1,192 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package fs_test - -import ( - "reflect" - "testing" - - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/sentry/fs" - "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" - "gvisor.dev/gvisor/pkg/sentry/fs/ramfs" - "gvisor.dev/gvisor/pkg/sentry/kernel/contexttest" -) - -func TestReaddir(t *testing.T) { - ctx := contexttest.Context(t) - ctx = &rootContext{ - Context: ctx, - root: fs.NewDirent(ctx, newTestRamfsDir(ctx, nil, nil), "root"), - } - for _, test := range []struct { - // Test description. - desc string - - // Lookup parameters. - dir *fs.Inode - - // Want from lookup. - err error - names []string - }{ - { - desc: "no upper, lower has entries", - dir: fs.NewTestOverlayDir(ctx, - nil, /* upper */ - newTestRamfsDir(ctx, []dirContent{ - {name: "a"}, - {name: "b"}, - }, nil), /* lower */ - false /* revalidate */), - names: []string{".", "..", "a", "b"}, - }, - { - desc: "upper has entries, no lower", - dir: fs.NewTestOverlayDir(ctx, - newTestRamfsDir(ctx, []dirContent{ - {name: "a"}, - {name: "b"}, - }, nil), /* upper */ - nil, /* lower */ - false /* revalidate */), - names: []string{".", "..", "a", "b"}, - }, - { - desc: "upper and lower, entries combine", - dir: fs.NewTestOverlayDir(ctx, - newTestRamfsDir(ctx, []dirContent{ - {name: "a"}, - }, nil), /* upper */ - newTestRamfsDir(ctx, []dirContent{ - {name: "b"}, - }, nil), /* lower */ - false /* revalidate */), - names: []string{".", "..", "a", "b"}, - }, - { - desc: "upper and lower, entries combine, none are masked", - dir: fs.NewTestOverlayDir(ctx, - newTestRamfsDir(ctx, []dirContent{ - {name: "a"}, - }, []string{"b"}), /* upper */ - newTestRamfsDir(ctx, []dirContent{ - {name: "c"}, - }, nil), /* lower */ - false /* revalidate */), - names: []string{".", "..", "a", "c"}, - }, - { - desc: "upper and lower, entries combine, upper masks some of lower", - dir: fs.NewTestOverlayDir(ctx, - newTestRamfsDir(ctx, []dirContent{ - {name: "a"}, - }, []string{"b"}), /* upper */ - newTestRamfsDir(ctx, []dirContent{ - {name: "b"}, /* will be masked */ - {name: "c"}, - }, nil), /* lower */ - false /* revalidate */), - names: []string{".", "..", "a", "c"}, - }, - } { - t.Run(test.desc, func(t *testing.T) { - openDir, err := test.dir.GetFile(ctx, fs.NewDirent(ctx, test.dir, "stub"), fs.FileFlags{Read: true}) - if err != nil { - t.Fatalf("GetFile got error %v, want nil", err) - } - stubSerializer := &fs.CollectEntriesSerializer{} - err = openDir.Readdir(ctx, stubSerializer) - if err != test.err { - t.Fatalf("Readdir got error %v, want nil", err) - } - if err != nil { - return - } - if !reflect.DeepEqual(stubSerializer.Order, test.names) { - t.Errorf("Readdir got names %v, want %v", stubSerializer.Order, test.names) - } - }) - } -} - -func TestReaddirRevalidation(t *testing.T) { - ctx := contexttest.Context(t) - ctx = &rootContext{ - Context: ctx, - root: fs.NewDirent(ctx, newTestRamfsDir(ctx, nil, nil), "root"), - } - - // Create an overlay with two directories, each with one file. - upper := newTestRamfsDir(ctx, []dirContent{{name: "a"}}, nil) - lower := newTestRamfsDir(ctx, []dirContent{{name: "b"}}, nil) - overlay := fs.NewTestOverlayDir(ctx, upper, lower, true /* revalidate */) - - // Get a handle to the dirent in the upper filesystem so that we can - // modify it without going through the dirent. - upperDir := upper.InodeOperations.(*dir).InodeOperations.(*ramfs.Dir) - - // Check that overlay returns the files from both upper and lower. - openDir, err := overlay.GetFile(ctx, fs.NewDirent(ctx, overlay, "stub"), fs.FileFlags{Read: true}) - if err != nil { - t.Fatalf("GetFile got error %v, want nil", err) - } - ser := &fs.CollectEntriesSerializer{} - if err := openDir.Readdir(ctx, ser); err != nil { - t.Fatalf("Readdir got error %v, want nil", err) - } - got, want := ser.Order, []string{".", "..", "a", "b"} - if !reflect.DeepEqual(got, want) { - t.Errorf("Readdir got names %v, want %v", got, want) - } - - // Remove "a" from the upper and add "c". - if err := upperDir.Remove(ctx, upper, "a"); err != nil { - t.Fatalf("error removing child: %v", err) - } - upperDir.AddChild(ctx, "c", fs.NewInode(ctx, fsutil.NewSimpleFileInode(ctx, fs.RootOwner, fs.FilePermissions{}, 0), - upper.MountSource, fs.StableAttr{Type: fs.RegularFile})) - - // Seek to beginning of the directory and do the readdir again. - if _, err := openDir.Seek(ctx, fs.SeekSet, 0); err != nil { - t.Fatalf("error seeking to beginning of dir: %v", err) - } - ser = &fs.CollectEntriesSerializer{} - if err := openDir.Readdir(ctx, ser); err != nil { - t.Fatalf("Readdir got error %v, want nil", err) - } - - // Readdir should return the updated children. - got, want = ser.Order, []string{".", "..", "b", "c"} - if !reflect.DeepEqual(got, want) { - t.Errorf("Readdir got names %v, want %v", got, want) - } -} - -type rootContext struct { - context.Context - root *fs.Dirent -} - -// Value implements context.Context. -func (r *rootContext) Value(key interface{}) interface{} { - switch key { - case fs.CtxRoot: - r.root.IncRef() - return r.root - default: - return r.Context.Value(key) - } -} diff --git a/pkg/sentry/fs/filetest/BUILD b/pkg/sentry/fs/filetest/BUILD deleted file mode 100644 index a8000e010..000000000 --- a/pkg/sentry/fs/filetest/BUILD +++ /dev/null @@ -1,19 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "filetest", - testonly = 1, - srcs = ["filetest.go"], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/context", - "//pkg/sentry/contexttest", - "//pkg/sentry/fs", - "//pkg/sentry/fs/anon", - "//pkg/sentry/fs/fsutil", - "//pkg/usermem", - "//pkg/waiter", - ], -) diff --git a/pkg/sentry/fs/filetest/filetest.go b/pkg/sentry/fs/filetest/filetest.go deleted file mode 100644 index ec3d3f96c..000000000 --- a/pkg/sentry/fs/filetest/filetest.go +++ /dev/null @@ -1,61 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package filetest provides a test implementation of an fs.File. -package filetest - -import ( - "fmt" - "testing" - - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/sentry/contexttest" - "gvisor.dev/gvisor/pkg/sentry/fs" - "gvisor.dev/gvisor/pkg/sentry/fs/anon" - "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" - "gvisor.dev/gvisor/pkg/usermem" - "gvisor.dev/gvisor/pkg/waiter" -) - -// TestFileOperations is an implementation of the File interface. It provides all -// required methods. -type TestFileOperations struct { - fsutil.FileNoopRelease `state:"nosave"` - fsutil.FilePipeSeek `state:"nosave"` - fsutil.FileNotDirReaddir `state:"nosave"` - fsutil.FileNoFsync `state:"nosave"` - fsutil.FileNoopFlush `state:"nosave"` - fsutil.FileNoMMap `state:"nosave"` - fsutil.FileNoIoctl `state:"nosave"` - fsutil.FileNoSplice `state:"nosave"` - fsutil.FileUseInodeUnstableAttr `state:"nosave"` - waiter.AlwaysReady `state:"nosave"` -} - -// NewTestFile creates and initializes a new test file. -func NewTestFile(tb testing.TB) *fs.File { - ctx := contexttest.Context(tb) - dirent := fs.NewDirent(ctx, anon.NewInode(ctx), "test") - return fs.NewFile(ctx, dirent, fs.FileFlags{}, &TestFileOperations{}) -} - -// Read just fails the request. -func (*TestFileOperations) Read(context.Context, *fs.File, usermem.IOSequence, int64) (int64, error) { - return 0, fmt.Errorf("TestFileOperations.Read not implemented") -} - -// Write just fails the request. -func (*TestFileOperations) Write(context.Context, *fs.File, usermem.IOSequence, int64) (int64, error) { - return 0, fmt.Errorf("TestFileOperations.Write not implemented") -} diff --git a/pkg/sentry/fs/fs_state_autogen.go b/pkg/sentry/fs/fs_state_autogen.go new file mode 100644 index 000000000..62ab34d6f --- /dev/null +++ b/pkg/sentry/fs/fs_state_autogen.go @@ -0,0 +1,1208 @@ +// automatically generated by stateify. + +package fs + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (s *StableAttr) StateTypeName() string { + return "pkg/sentry/fs.StableAttr" +} + +func (s *StableAttr) StateFields() []string { + return []string{ + "Type", + "DeviceID", + "InodeID", + "BlockSize", + "DeviceFileMajor", + "DeviceFileMinor", + } +} + +func (s *StableAttr) beforeSave() {} + +// +checklocksignore +func (s *StableAttr) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + stateSinkObject.Save(0, &s.Type) + stateSinkObject.Save(1, &s.DeviceID) + stateSinkObject.Save(2, &s.InodeID) + stateSinkObject.Save(3, &s.BlockSize) + stateSinkObject.Save(4, &s.DeviceFileMajor) + stateSinkObject.Save(5, &s.DeviceFileMinor) +} + +func (s *StableAttr) afterLoad() {} + +// +checklocksignore +func (s *StableAttr) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.Type) + stateSourceObject.Load(1, &s.DeviceID) + stateSourceObject.Load(2, &s.InodeID) + stateSourceObject.Load(3, &s.BlockSize) + stateSourceObject.Load(4, &s.DeviceFileMajor) + stateSourceObject.Load(5, &s.DeviceFileMinor) +} + +func (ua *UnstableAttr) StateTypeName() string { + return "pkg/sentry/fs.UnstableAttr" +} + +func (ua *UnstableAttr) StateFields() []string { + return []string{ + "Size", + "Usage", + "Perms", + "Owner", + "AccessTime", + "ModificationTime", + "StatusChangeTime", + "Links", + } +} + +func (ua *UnstableAttr) beforeSave() {} + +// +checklocksignore +func (ua *UnstableAttr) StateSave(stateSinkObject state.Sink) { + ua.beforeSave() + stateSinkObject.Save(0, &ua.Size) + stateSinkObject.Save(1, &ua.Usage) + stateSinkObject.Save(2, &ua.Perms) + stateSinkObject.Save(3, &ua.Owner) + stateSinkObject.Save(4, &ua.AccessTime) + stateSinkObject.Save(5, &ua.ModificationTime) + stateSinkObject.Save(6, &ua.StatusChangeTime) + stateSinkObject.Save(7, &ua.Links) +} + +func (ua *UnstableAttr) afterLoad() {} + +// +checklocksignore +func (ua *UnstableAttr) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &ua.Size) + stateSourceObject.Load(1, &ua.Usage) + stateSourceObject.Load(2, &ua.Perms) + stateSourceObject.Load(3, &ua.Owner) + stateSourceObject.Load(4, &ua.AccessTime) + stateSourceObject.Load(5, &ua.ModificationTime) + stateSourceObject.Load(6, &ua.StatusChangeTime) + stateSourceObject.Load(7, &ua.Links) +} + +func (a *AttrMask) StateTypeName() string { + return "pkg/sentry/fs.AttrMask" +} + +func (a *AttrMask) StateFields() []string { + return []string{ + "Type", + "DeviceID", + "InodeID", + "BlockSize", + "Size", + "Usage", + "Perms", + "UID", + "GID", + "AccessTime", + "ModificationTime", + "StatusChangeTime", + "Links", + } +} + +func (a *AttrMask) beforeSave() {} + +// +checklocksignore +func (a *AttrMask) StateSave(stateSinkObject state.Sink) { + a.beforeSave() + stateSinkObject.Save(0, &a.Type) + stateSinkObject.Save(1, &a.DeviceID) + stateSinkObject.Save(2, &a.InodeID) + stateSinkObject.Save(3, &a.BlockSize) + stateSinkObject.Save(4, &a.Size) + stateSinkObject.Save(5, &a.Usage) + stateSinkObject.Save(6, &a.Perms) + stateSinkObject.Save(7, &a.UID) + stateSinkObject.Save(8, &a.GID) + stateSinkObject.Save(9, &a.AccessTime) + stateSinkObject.Save(10, &a.ModificationTime) + stateSinkObject.Save(11, &a.StatusChangeTime) + stateSinkObject.Save(12, &a.Links) +} + +func (a *AttrMask) afterLoad() {} + +// +checklocksignore +func (a *AttrMask) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &a.Type) + stateSourceObject.Load(1, &a.DeviceID) + stateSourceObject.Load(2, &a.InodeID) + stateSourceObject.Load(3, &a.BlockSize) + stateSourceObject.Load(4, &a.Size) + stateSourceObject.Load(5, &a.Usage) + stateSourceObject.Load(6, &a.Perms) + stateSourceObject.Load(7, &a.UID) + stateSourceObject.Load(8, &a.GID) + stateSourceObject.Load(9, &a.AccessTime) + stateSourceObject.Load(10, &a.ModificationTime) + stateSourceObject.Load(11, &a.StatusChangeTime) + stateSourceObject.Load(12, &a.Links) +} + +func (p *PermMask) StateTypeName() string { + return "pkg/sentry/fs.PermMask" +} + +func (p *PermMask) StateFields() []string { + return []string{ + "Read", + "Write", + "Execute", + } +} + +func (p *PermMask) beforeSave() {} + +// +checklocksignore +func (p *PermMask) StateSave(stateSinkObject state.Sink) { + p.beforeSave() + stateSinkObject.Save(0, &p.Read) + stateSinkObject.Save(1, &p.Write) + stateSinkObject.Save(2, &p.Execute) +} + +func (p *PermMask) afterLoad() {} + +// +checklocksignore +func (p *PermMask) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &p.Read) + stateSourceObject.Load(1, &p.Write) + stateSourceObject.Load(2, &p.Execute) +} + +func (f *FilePermissions) StateTypeName() string { + return "pkg/sentry/fs.FilePermissions" +} + +func (f *FilePermissions) StateFields() []string { + return []string{ + "User", + "Group", + "Other", + "Sticky", + "SetUID", + "SetGID", + } +} + +func (f *FilePermissions) beforeSave() {} + +// +checklocksignore +func (f *FilePermissions) StateSave(stateSinkObject state.Sink) { + f.beforeSave() + stateSinkObject.Save(0, &f.User) + stateSinkObject.Save(1, &f.Group) + stateSinkObject.Save(2, &f.Other) + stateSinkObject.Save(3, &f.Sticky) + stateSinkObject.Save(4, &f.SetUID) + stateSinkObject.Save(5, &f.SetGID) +} + +func (f *FilePermissions) afterLoad() {} + +// +checklocksignore +func (f *FilePermissions) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &f.User) + stateSourceObject.Load(1, &f.Group) + stateSourceObject.Load(2, &f.Other) + stateSourceObject.Load(3, &f.Sticky) + stateSourceObject.Load(4, &f.SetUID) + stateSourceObject.Load(5, &f.SetGID) +} + +func (f *FileOwner) StateTypeName() string { + return "pkg/sentry/fs.FileOwner" +} + +func (f *FileOwner) StateFields() []string { + return []string{ + "UID", + "GID", + } +} + +func (f *FileOwner) beforeSave() {} + +// +checklocksignore +func (f *FileOwner) StateSave(stateSinkObject state.Sink) { + f.beforeSave() + stateSinkObject.Save(0, &f.UID) + stateSinkObject.Save(1, &f.GID) +} + +func (f *FileOwner) afterLoad() {} + +// +checklocksignore +func (f *FileOwner) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &f.UID) + stateSourceObject.Load(1, &f.GID) +} + +func (d *DentAttr) StateTypeName() string { + return "pkg/sentry/fs.DentAttr" +} + +func (d *DentAttr) StateFields() []string { + return []string{ + "Type", + "InodeID", + } +} + +func (d *DentAttr) beforeSave() {} + +// +checklocksignore +func (d *DentAttr) StateSave(stateSinkObject state.Sink) { + d.beforeSave() + stateSinkObject.Save(0, &d.Type) + stateSinkObject.Save(1, &d.InodeID) +} + +func (d *DentAttr) afterLoad() {} + +// +checklocksignore +func (d *DentAttr) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &d.Type) + stateSourceObject.Load(1, &d.InodeID) +} + +func (s *SortedDentryMap) StateTypeName() string { + return "pkg/sentry/fs.SortedDentryMap" +} + +func (s *SortedDentryMap) StateFields() []string { + return []string{ + "names", + "entries", + } +} + +func (s *SortedDentryMap) beforeSave() {} + +// +checklocksignore +func (s *SortedDentryMap) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + stateSinkObject.Save(0, &s.names) + stateSinkObject.Save(1, &s.entries) +} + +func (s *SortedDentryMap) afterLoad() {} + +// +checklocksignore +func (s *SortedDentryMap) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.names) + stateSourceObject.Load(1, &s.entries) +} + +func (d *Dirent) StateTypeName() string { + return "pkg/sentry/fs.Dirent" +} + +func (d *Dirent) StateFields() []string { + return []string{ + "AtomicRefCount", + "userVisible", + "Inode", + "name", + "parent", + "deleted", + "mounted", + "children", + } +} + +// +checklocksignore +func (d *Dirent) StateSave(stateSinkObject state.Sink) { + d.beforeSave() + var childrenValue map[string]*Dirent + childrenValue = d.saveChildren() + stateSinkObject.SaveValue(7, childrenValue) + stateSinkObject.Save(0, &d.AtomicRefCount) + stateSinkObject.Save(1, &d.userVisible) + stateSinkObject.Save(2, &d.Inode) + stateSinkObject.Save(3, &d.name) + stateSinkObject.Save(4, &d.parent) + stateSinkObject.Save(5, &d.deleted) + stateSinkObject.Save(6, &d.mounted) +} + +// +checklocksignore +func (d *Dirent) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &d.AtomicRefCount) + stateSourceObject.Load(1, &d.userVisible) + stateSourceObject.Load(2, &d.Inode) + stateSourceObject.Load(3, &d.name) + stateSourceObject.Load(4, &d.parent) + stateSourceObject.Load(5, &d.deleted) + stateSourceObject.Load(6, &d.mounted) + stateSourceObject.LoadValue(7, new(map[string]*Dirent), func(y interface{}) { d.loadChildren(y.(map[string]*Dirent)) }) + stateSourceObject.AfterLoad(d.afterLoad) +} + +func (c *DirentCache) StateTypeName() string { + return "pkg/sentry/fs.DirentCache" +} + +func (c *DirentCache) StateFields() []string { + return []string{ + "maxSize", + "limit", + } +} + +func (c *DirentCache) beforeSave() {} + +// +checklocksignore +func (c *DirentCache) StateSave(stateSinkObject state.Sink) { + c.beforeSave() + if !state.IsZeroValue(&c.currentSize) { + state.Failf("currentSize is %#v, expected zero", &c.currentSize) + } + if !state.IsZeroValue(&c.list) { + state.Failf("list is %#v, expected zero", &c.list) + } + stateSinkObject.Save(0, &c.maxSize) + stateSinkObject.Save(1, &c.limit) +} + +func (c *DirentCache) afterLoad() {} + +// +checklocksignore +func (c *DirentCache) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &c.maxSize) + stateSourceObject.Load(1, &c.limit) +} + +func (d *DirentCacheLimiter) StateTypeName() string { + return "pkg/sentry/fs.DirentCacheLimiter" +} + +func (d *DirentCacheLimiter) StateFields() []string { + return []string{ + "max", + } +} + +func (d *DirentCacheLimiter) beforeSave() {} + +// +checklocksignore +func (d *DirentCacheLimiter) StateSave(stateSinkObject state.Sink) { + d.beforeSave() + if !state.IsZeroValue(&d.count) { + state.Failf("count is %#v, expected zero", &d.count) + } + stateSinkObject.Save(0, &d.max) +} + +func (d *DirentCacheLimiter) afterLoad() {} + +// +checklocksignore +func (d *DirentCacheLimiter) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &d.max) +} + +func (l *direntList) StateTypeName() string { + return "pkg/sentry/fs.direntList" +} + +func (l *direntList) StateFields() []string { + return []string{ + "head", + "tail", + } +} + +func (l *direntList) beforeSave() {} + +// +checklocksignore +func (l *direntList) StateSave(stateSinkObject state.Sink) { + l.beforeSave() + stateSinkObject.Save(0, &l.head) + stateSinkObject.Save(1, &l.tail) +} + +func (l *direntList) afterLoad() {} + +// +checklocksignore +func (l *direntList) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &l.head) + stateSourceObject.Load(1, &l.tail) +} + +func (e *direntEntry) StateTypeName() string { + return "pkg/sentry/fs.direntEntry" +} + +func (e *direntEntry) StateFields() []string { + return []string{ + "next", + "prev", + } +} + +func (e *direntEntry) beforeSave() {} + +// +checklocksignore +func (e *direntEntry) StateSave(stateSinkObject state.Sink) { + e.beforeSave() + stateSinkObject.Save(0, &e.next) + stateSinkObject.Save(1, &e.prev) +} + +func (e *direntEntry) afterLoad() {} + +// +checklocksignore +func (e *direntEntry) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &e.next) + stateSourceObject.Load(1, &e.prev) +} + +func (l *eventList) StateTypeName() string { + return "pkg/sentry/fs.eventList" +} + +func (l *eventList) StateFields() []string { + return []string{ + "head", + "tail", + } +} + +func (l *eventList) beforeSave() {} + +// +checklocksignore +func (l *eventList) StateSave(stateSinkObject state.Sink) { + l.beforeSave() + stateSinkObject.Save(0, &l.head) + stateSinkObject.Save(1, &l.tail) +} + +func (l *eventList) afterLoad() {} + +// +checklocksignore +func (l *eventList) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &l.head) + stateSourceObject.Load(1, &l.tail) +} + +func (e *eventEntry) StateTypeName() string { + return "pkg/sentry/fs.eventEntry" +} + +func (e *eventEntry) StateFields() []string { + return []string{ + "next", + "prev", + } +} + +func (e *eventEntry) beforeSave() {} + +// +checklocksignore +func (e *eventEntry) StateSave(stateSinkObject state.Sink) { + e.beforeSave() + stateSinkObject.Save(0, &e.next) + stateSinkObject.Save(1, &e.prev) +} + +func (e *eventEntry) afterLoad() {} + +// +checklocksignore +func (e *eventEntry) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &e.next) + stateSourceObject.Load(1, &e.prev) +} + +func (f *File) StateTypeName() string { + return "pkg/sentry/fs.File" +} + +func (f *File) StateFields() []string { + return []string{ + "AtomicRefCount", + "UniqueID", + "Dirent", + "flags", + "async", + "FileOperations", + "offset", + } +} + +// +checklocksignore +func (f *File) StateSave(stateSinkObject state.Sink) { + f.beforeSave() + stateSinkObject.Save(0, &f.AtomicRefCount) + stateSinkObject.Save(1, &f.UniqueID) + stateSinkObject.Save(2, &f.Dirent) + stateSinkObject.Save(3, &f.flags) + stateSinkObject.Save(4, &f.async) + stateSinkObject.Save(5, &f.FileOperations) + stateSinkObject.Save(6, &f.offset) +} + +// +checklocksignore +func (f *File) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &f.AtomicRefCount) + stateSourceObject.Load(1, &f.UniqueID) + stateSourceObject.Load(2, &f.Dirent) + stateSourceObject.Load(3, &f.flags) + stateSourceObject.Load(4, &f.async) + stateSourceObject.LoadWait(5, &f.FileOperations) + stateSourceObject.Load(6, &f.offset) + stateSourceObject.AfterLoad(f.afterLoad) +} + +func (f *overlayFileOperations) StateTypeName() string { + return "pkg/sentry/fs.overlayFileOperations" +} + +func (f *overlayFileOperations) StateFields() []string { + return []string{ + "upper", + "lower", + "dirCursor", + } +} + +func (f *overlayFileOperations) beforeSave() {} + +// +checklocksignore +func (f *overlayFileOperations) StateSave(stateSinkObject state.Sink) { + f.beforeSave() + stateSinkObject.Save(0, &f.upper) + stateSinkObject.Save(1, &f.lower) + stateSinkObject.Save(2, &f.dirCursor) +} + +func (f *overlayFileOperations) afterLoad() {} + +// +checklocksignore +func (f *overlayFileOperations) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &f.upper) + stateSourceObject.Load(1, &f.lower) + stateSourceObject.Load(2, &f.dirCursor) +} + +func (omi *overlayMappingIdentity) StateTypeName() string { + return "pkg/sentry/fs.overlayMappingIdentity" +} + +func (omi *overlayMappingIdentity) StateFields() []string { + return []string{ + "AtomicRefCount", + "id", + "overlayFile", + } +} + +func (omi *overlayMappingIdentity) beforeSave() {} + +// +checklocksignore +func (omi *overlayMappingIdentity) StateSave(stateSinkObject state.Sink) { + omi.beforeSave() + stateSinkObject.Save(0, &omi.AtomicRefCount) + stateSinkObject.Save(1, &omi.id) + stateSinkObject.Save(2, &omi.overlayFile) +} + +func (omi *overlayMappingIdentity) afterLoad() {} + +// +checklocksignore +func (omi *overlayMappingIdentity) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &omi.AtomicRefCount) + stateSourceObject.Load(1, &omi.id) + stateSourceObject.Load(2, &omi.overlayFile) +} + +func (m *MountSourceFlags) StateTypeName() string { + return "pkg/sentry/fs.MountSourceFlags" +} + +func (m *MountSourceFlags) StateFields() []string { + return []string{ + "ReadOnly", + "NoAtime", + "ForcePageCache", + "NoExec", + } +} + +func (m *MountSourceFlags) beforeSave() {} + +// +checklocksignore +func (m *MountSourceFlags) StateSave(stateSinkObject state.Sink) { + m.beforeSave() + stateSinkObject.Save(0, &m.ReadOnly) + stateSinkObject.Save(1, &m.NoAtime) + stateSinkObject.Save(2, &m.ForcePageCache) + stateSinkObject.Save(3, &m.NoExec) +} + +func (m *MountSourceFlags) afterLoad() {} + +// +checklocksignore +func (m *MountSourceFlags) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &m.ReadOnly) + stateSourceObject.Load(1, &m.NoAtime) + stateSourceObject.Load(2, &m.ForcePageCache) + stateSourceObject.Load(3, &m.NoExec) +} + +func (f *FileFlags) StateTypeName() string { + return "pkg/sentry/fs.FileFlags" +} + +func (f *FileFlags) StateFields() []string { + return []string{ + "Direct", + "NonBlocking", + "DSync", + "Sync", + "Append", + "Read", + "Write", + "Pread", + "Pwrite", + "Directory", + "Async", + "LargeFile", + "NonSeekable", + "Truncate", + } +} + +func (f *FileFlags) beforeSave() {} + +// +checklocksignore +func (f *FileFlags) StateSave(stateSinkObject state.Sink) { + f.beforeSave() + stateSinkObject.Save(0, &f.Direct) + stateSinkObject.Save(1, &f.NonBlocking) + stateSinkObject.Save(2, &f.DSync) + stateSinkObject.Save(3, &f.Sync) + stateSinkObject.Save(4, &f.Append) + stateSinkObject.Save(5, &f.Read) + stateSinkObject.Save(6, &f.Write) + stateSinkObject.Save(7, &f.Pread) + stateSinkObject.Save(8, &f.Pwrite) + stateSinkObject.Save(9, &f.Directory) + stateSinkObject.Save(10, &f.Async) + stateSinkObject.Save(11, &f.LargeFile) + stateSinkObject.Save(12, &f.NonSeekable) + stateSinkObject.Save(13, &f.Truncate) +} + +func (f *FileFlags) afterLoad() {} + +// +checklocksignore +func (f *FileFlags) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &f.Direct) + stateSourceObject.Load(1, &f.NonBlocking) + stateSourceObject.Load(2, &f.DSync) + stateSourceObject.Load(3, &f.Sync) + stateSourceObject.Load(4, &f.Append) + stateSourceObject.Load(5, &f.Read) + stateSourceObject.Load(6, &f.Write) + stateSourceObject.Load(7, &f.Pread) + stateSourceObject.Load(8, &f.Pwrite) + stateSourceObject.Load(9, &f.Directory) + stateSourceObject.Load(10, &f.Async) + stateSourceObject.Load(11, &f.LargeFile) + stateSourceObject.Load(12, &f.NonSeekable) + stateSourceObject.Load(13, &f.Truncate) +} + +func (i *Inode) StateTypeName() string { + return "pkg/sentry/fs.Inode" +} + +func (i *Inode) StateFields() []string { + return []string{ + "AtomicRefCount", + "InodeOperations", + "StableAttr", + "LockCtx", + "Watches", + "MountSource", + "overlay", + } +} + +func (i *Inode) beforeSave() {} + +// +checklocksignore +func (i *Inode) StateSave(stateSinkObject state.Sink) { + i.beforeSave() + stateSinkObject.Save(0, &i.AtomicRefCount) + stateSinkObject.Save(1, &i.InodeOperations) + stateSinkObject.Save(2, &i.StableAttr) + stateSinkObject.Save(3, &i.LockCtx) + stateSinkObject.Save(4, &i.Watches) + stateSinkObject.Save(5, &i.MountSource) + stateSinkObject.Save(6, &i.overlay) +} + +func (i *Inode) afterLoad() {} + +// +checklocksignore +func (i *Inode) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &i.AtomicRefCount) + stateSourceObject.Load(1, &i.InodeOperations) + stateSourceObject.Load(2, &i.StableAttr) + stateSourceObject.Load(3, &i.LockCtx) + stateSourceObject.Load(4, &i.Watches) + stateSourceObject.Load(5, &i.MountSource) + stateSourceObject.Load(6, &i.overlay) +} + +func (l *LockCtx) StateTypeName() string { + return "pkg/sentry/fs.LockCtx" +} + +func (l *LockCtx) StateFields() []string { + return []string{ + "Posix", + "BSD", + } +} + +func (l *LockCtx) beforeSave() {} + +// +checklocksignore +func (l *LockCtx) StateSave(stateSinkObject state.Sink) { + l.beforeSave() + stateSinkObject.Save(0, &l.Posix) + stateSinkObject.Save(1, &l.BSD) +} + +func (l *LockCtx) afterLoad() {} + +// +checklocksignore +func (l *LockCtx) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &l.Posix) + stateSourceObject.Load(1, &l.BSD) +} + +func (w *Watches) StateTypeName() string { + return "pkg/sentry/fs.Watches" +} + +func (w *Watches) StateFields() []string { + return []string{ + "ws", + "unlinked", + } +} + +func (w *Watches) beforeSave() {} + +// +checklocksignore +func (w *Watches) StateSave(stateSinkObject state.Sink) { + w.beforeSave() + stateSinkObject.Save(0, &w.ws) + stateSinkObject.Save(1, &w.unlinked) +} + +func (w *Watches) afterLoad() {} + +// +checklocksignore +func (w *Watches) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &w.ws) + stateSourceObject.Load(1, &w.unlinked) +} + +func (i *Inotify) StateTypeName() string { + return "pkg/sentry/fs.Inotify" +} + +func (i *Inotify) StateFields() []string { + return []string{ + "id", + "events", + "scratch", + "nextWatch", + "watches", + } +} + +func (i *Inotify) beforeSave() {} + +// +checklocksignore +func (i *Inotify) StateSave(stateSinkObject state.Sink) { + i.beforeSave() + stateSinkObject.Save(0, &i.id) + stateSinkObject.Save(1, &i.events) + stateSinkObject.Save(2, &i.scratch) + stateSinkObject.Save(3, &i.nextWatch) + stateSinkObject.Save(4, &i.watches) +} + +func (i *Inotify) afterLoad() {} + +// +checklocksignore +func (i *Inotify) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &i.id) + stateSourceObject.Load(1, &i.events) + stateSourceObject.Load(2, &i.scratch) + stateSourceObject.Load(3, &i.nextWatch) + stateSourceObject.Load(4, &i.watches) +} + +func (e *Event) StateTypeName() string { + return "pkg/sentry/fs.Event" +} + +func (e *Event) StateFields() []string { + return []string{ + "eventEntry", + "wd", + "mask", + "cookie", + "len", + "name", + } +} + +func (e *Event) beforeSave() {} + +// +checklocksignore +func (e *Event) StateSave(stateSinkObject state.Sink) { + e.beforeSave() + stateSinkObject.Save(0, &e.eventEntry) + stateSinkObject.Save(1, &e.wd) + stateSinkObject.Save(2, &e.mask) + stateSinkObject.Save(3, &e.cookie) + stateSinkObject.Save(4, &e.len) + stateSinkObject.Save(5, &e.name) +} + +func (e *Event) afterLoad() {} + +// +checklocksignore +func (e *Event) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &e.eventEntry) + stateSourceObject.Load(1, &e.wd) + stateSourceObject.Load(2, &e.mask) + stateSourceObject.Load(3, &e.cookie) + stateSourceObject.Load(4, &e.len) + stateSourceObject.Load(5, &e.name) +} + +func (w *Watch) StateTypeName() string { + return "pkg/sentry/fs.Watch" +} + +func (w *Watch) StateFields() []string { + return []string{ + "owner", + "wd", + "target", + "unpinned", + "mask", + "pins", + } +} + +func (w *Watch) beforeSave() {} + +// +checklocksignore +func (w *Watch) StateSave(stateSinkObject state.Sink) { + w.beforeSave() + stateSinkObject.Save(0, &w.owner) + stateSinkObject.Save(1, &w.wd) + stateSinkObject.Save(2, &w.target) + stateSinkObject.Save(3, &w.unpinned) + stateSinkObject.Save(4, &w.mask) + stateSinkObject.Save(5, &w.pins) +} + +func (w *Watch) afterLoad() {} + +// +checklocksignore +func (w *Watch) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &w.owner) + stateSourceObject.Load(1, &w.wd) + stateSourceObject.Load(2, &w.target) + stateSourceObject.Load(3, &w.unpinned) + stateSourceObject.Load(4, &w.mask) + stateSourceObject.Load(5, &w.pins) +} + +func (msrc *MountSource) StateTypeName() string { + return "pkg/sentry/fs.MountSource" +} + +func (msrc *MountSource) StateFields() []string { + return []string{ + "AtomicRefCount", + "MountSourceOperations", + "FilesystemType", + "Flags", + "fscache", + "direntRefs", + } +} + +func (msrc *MountSource) beforeSave() {} + +// +checklocksignore +func (msrc *MountSource) StateSave(stateSinkObject state.Sink) { + msrc.beforeSave() + stateSinkObject.Save(0, &msrc.AtomicRefCount) + stateSinkObject.Save(1, &msrc.MountSourceOperations) + stateSinkObject.Save(2, &msrc.FilesystemType) + stateSinkObject.Save(3, &msrc.Flags) + stateSinkObject.Save(4, &msrc.fscache) + stateSinkObject.Save(5, &msrc.direntRefs) +} + +func (msrc *MountSource) afterLoad() {} + +// +checklocksignore +func (msrc *MountSource) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &msrc.AtomicRefCount) + stateSourceObject.Load(1, &msrc.MountSourceOperations) + stateSourceObject.Load(2, &msrc.FilesystemType) + stateSourceObject.Load(3, &msrc.Flags) + stateSourceObject.Load(4, &msrc.fscache) + stateSourceObject.Load(5, &msrc.direntRefs) +} + +func (smo *SimpleMountSourceOperations) StateTypeName() string { + return "pkg/sentry/fs.SimpleMountSourceOperations" +} + +func (smo *SimpleMountSourceOperations) StateFields() []string { + return []string{ + "keep", + "revalidate", + "cacheReaddir", + } +} + +func (smo *SimpleMountSourceOperations) beforeSave() {} + +// +checklocksignore +func (smo *SimpleMountSourceOperations) StateSave(stateSinkObject state.Sink) { + smo.beforeSave() + stateSinkObject.Save(0, &smo.keep) + stateSinkObject.Save(1, &smo.revalidate) + stateSinkObject.Save(2, &smo.cacheReaddir) +} + +func (smo *SimpleMountSourceOperations) afterLoad() {} + +// +checklocksignore +func (smo *SimpleMountSourceOperations) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &smo.keep) + stateSourceObject.Load(1, &smo.revalidate) + stateSourceObject.Load(2, &smo.cacheReaddir) +} + +func (o *overlayMountSourceOperations) StateTypeName() string { + return "pkg/sentry/fs.overlayMountSourceOperations" +} + +func (o *overlayMountSourceOperations) StateFields() []string { + return []string{ + "upper", + "lower", + } +} + +func (o *overlayMountSourceOperations) beforeSave() {} + +// +checklocksignore +func (o *overlayMountSourceOperations) StateSave(stateSinkObject state.Sink) { + o.beforeSave() + stateSinkObject.Save(0, &o.upper) + stateSinkObject.Save(1, &o.lower) +} + +func (o *overlayMountSourceOperations) afterLoad() {} + +// +checklocksignore +func (o *overlayMountSourceOperations) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &o.upper) + stateSourceObject.Load(1, &o.lower) +} + +func (ofs *overlayFilesystem) StateTypeName() string { + return "pkg/sentry/fs.overlayFilesystem" +} + +func (ofs *overlayFilesystem) StateFields() []string { + return []string{} +} + +func (ofs *overlayFilesystem) beforeSave() {} + +// +checklocksignore +func (ofs *overlayFilesystem) StateSave(stateSinkObject state.Sink) { + ofs.beforeSave() +} + +func (ofs *overlayFilesystem) afterLoad() {} + +// +checklocksignore +func (ofs *overlayFilesystem) StateLoad(stateSourceObject state.Source) { +} + +func (m *Mount) StateTypeName() string { + return "pkg/sentry/fs.Mount" +} + +func (m *Mount) StateFields() []string { + return []string{ + "ID", + "ParentID", + "root", + "previous", + } +} + +func (m *Mount) beforeSave() {} + +// +checklocksignore +func (m *Mount) StateSave(stateSinkObject state.Sink) { + m.beforeSave() + stateSinkObject.Save(0, &m.ID) + stateSinkObject.Save(1, &m.ParentID) + stateSinkObject.Save(2, &m.root) + stateSinkObject.Save(3, &m.previous) +} + +func (m *Mount) afterLoad() {} + +// +checklocksignore +func (m *Mount) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &m.ID) + stateSourceObject.Load(1, &m.ParentID) + stateSourceObject.Load(2, &m.root) + stateSourceObject.Load(3, &m.previous) +} + +func (mns *MountNamespace) StateTypeName() string { + return "pkg/sentry/fs.MountNamespace" +} + +func (mns *MountNamespace) StateFields() []string { + return []string{ + "AtomicRefCount", + "userns", + "root", + "mounts", + "mountID", + } +} + +func (mns *MountNamespace) beforeSave() {} + +// +checklocksignore +func (mns *MountNamespace) StateSave(stateSinkObject state.Sink) { + mns.beforeSave() + stateSinkObject.Save(0, &mns.AtomicRefCount) + stateSinkObject.Save(1, &mns.userns) + stateSinkObject.Save(2, &mns.root) + stateSinkObject.Save(3, &mns.mounts) + stateSinkObject.Save(4, &mns.mountID) +} + +func (mns *MountNamespace) afterLoad() {} + +// +checklocksignore +func (mns *MountNamespace) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &mns.AtomicRefCount) + stateSourceObject.Load(1, &mns.userns) + stateSourceObject.Load(2, &mns.root) + stateSourceObject.Load(3, &mns.mounts) + stateSourceObject.Load(4, &mns.mountID) +} + +func (o *overlayEntry) StateTypeName() string { + return "pkg/sentry/fs.overlayEntry" +} + +func (o *overlayEntry) StateFields() []string { + return []string{ + "lowerExists", + "lower", + "mappings", + "upper", + "dirCache", + } +} + +func (o *overlayEntry) beforeSave() {} + +// +checklocksignore +func (o *overlayEntry) StateSave(stateSinkObject state.Sink) { + o.beforeSave() + stateSinkObject.Save(0, &o.lowerExists) + stateSinkObject.Save(1, &o.lower) + stateSinkObject.Save(2, &o.mappings) + stateSinkObject.Save(3, &o.upper) + stateSinkObject.Save(4, &o.dirCache) +} + +func (o *overlayEntry) afterLoad() {} + +// +checklocksignore +func (o *overlayEntry) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &o.lowerExists) + stateSourceObject.Load(1, &o.lower) + stateSourceObject.Load(2, &o.mappings) + stateSourceObject.Load(3, &o.upper) + stateSourceObject.Load(4, &o.dirCache) +} + +func init() { + state.Register((*StableAttr)(nil)) + state.Register((*UnstableAttr)(nil)) + state.Register((*AttrMask)(nil)) + state.Register((*PermMask)(nil)) + state.Register((*FilePermissions)(nil)) + state.Register((*FileOwner)(nil)) + state.Register((*DentAttr)(nil)) + state.Register((*SortedDentryMap)(nil)) + state.Register((*Dirent)(nil)) + state.Register((*DirentCache)(nil)) + state.Register((*DirentCacheLimiter)(nil)) + state.Register((*direntList)(nil)) + state.Register((*direntEntry)(nil)) + state.Register((*eventList)(nil)) + state.Register((*eventEntry)(nil)) + state.Register((*File)(nil)) + state.Register((*overlayFileOperations)(nil)) + state.Register((*overlayMappingIdentity)(nil)) + state.Register((*MountSourceFlags)(nil)) + state.Register((*FileFlags)(nil)) + state.Register((*Inode)(nil)) + state.Register((*LockCtx)(nil)) + state.Register((*Watches)(nil)) + state.Register((*Inotify)(nil)) + state.Register((*Event)(nil)) + state.Register((*Watch)(nil)) + state.Register((*MountSource)(nil)) + state.Register((*SimpleMountSourceOperations)(nil)) + state.Register((*overlayMountSourceOperations)(nil)) + state.Register((*overlayFilesystem)(nil)) + state.Register((*Mount)(nil)) + state.Register((*MountNamespace)(nil)) + state.Register((*overlayEntry)(nil)) +} diff --git a/pkg/sentry/fs/fsutil/BUILD b/pkg/sentry/fs/fsutil/BUILD deleted file mode 100644 index 1a59800ea..000000000 --- a/pkg/sentry/fs/fsutil/BUILD +++ /dev/null @@ -1,118 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -package(licenses = ["notice"]) - -go_template_instance( - name = "dirty_set_impl", - out = "dirty_set_impl.go", - imports = { - "memmap": "gvisor.dev/gvisor/pkg/sentry/memmap", - }, - package = "fsutil", - prefix = "Dirty", - template = "//pkg/segment:generic_set", - types = { - "Key": "uint64", - "Range": "memmap.MappableRange", - "Value": "DirtyInfo", - "Functions": "dirtySetFunctions", - }, -) - -go_template_instance( - name = "frame_ref_set_impl", - out = "frame_ref_set_impl.go", - imports = { - "memmap": "gvisor.dev/gvisor/pkg/sentry/memmap", - }, - package = "fsutil", - prefix = "FrameRef", - template = "//pkg/segment:generic_set", - types = { - "Key": "uint64", - "Range": "memmap.FileRange", - "Value": "uint64", - "Functions": "FrameRefSetFunctions", - }, -) - -go_template_instance( - name = "file_range_set_impl", - out = "file_range_set_impl.go", - imports = { - "memmap": "gvisor.dev/gvisor/pkg/sentry/memmap", - }, - package = "fsutil", - prefix = "FileRange", - template = "//pkg/segment:generic_set", - types = { - "Key": "uint64", - "Range": "memmap.MappableRange", - "Value": "uint64", - "Functions": "FileRangeSetFunctions", - }, -) - -go_library( - name = "fsutil", - srcs = [ - "dirty_set.go", - "dirty_set_impl.go", - "file.go", - "file_range_set.go", - "file_range_set_impl.go", - "frame_ref_set.go", - "frame_ref_set_impl.go", - "fsutil.go", - "host_file_mapper.go", - "host_file_mapper_state.go", - "host_file_mapper_unsafe.go", - "host_mappable.go", - "inode.go", - "inode_cached.go", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/errors/linuxerr", - "//pkg/hostarch", - "//pkg/log", - "//pkg/safemem", - "//pkg/sentry/arch", - "//pkg/sentry/device", - "//pkg/sentry/fs", - "//pkg/sentry/kernel/time", - "//pkg/sentry/memmap", - "//pkg/sentry/pgalloc", - "//pkg/sentry/socket/unix/transport", - "//pkg/sentry/usage", - "//pkg/state", - "//pkg/sync", - "//pkg/usermem", - "//pkg/waiter", - "@org_golang_x_sys//unix:go_default_library", - ], -) - -go_test( - name = "fsutil_test", - size = "small", - srcs = [ - "dirty_set_test.go", - "inode_cached_test.go", - ], - library = ":fsutil", - deps = [ - "//pkg/context", - "//pkg/errors/linuxerr", - "//pkg/hostarch", - "//pkg/safemem", - "//pkg/sentry/contexttest", - "//pkg/sentry/fs", - "//pkg/sentry/kernel/time", - "//pkg/sentry/memmap", - "//pkg/usermem", - ], -) diff --git a/pkg/sentry/fs/fsutil/README.md b/pkg/sentry/fs/fsutil/README.md deleted file mode 100644 index 8be367334..000000000 --- a/pkg/sentry/fs/fsutil/README.md +++ /dev/null @@ -1,207 +0,0 @@ -This package provides utilities for implementing virtual filesystem objects. - -[TOC] - -## Page cache - -`CachingInodeOperations` implements a page cache for files that cannot use the -host page cache. Normally these are files that store their data in a remote -filesystem. This also applies to files that are accessed on a platform that does -not support directly memory mapping host file descriptors (e.g. the ptrace -platform). - -An `CachingInodeOperations` buffers regions of a single file into memory. It is -owned by an `fs.Inode`, the in-memory representation of a file (all open file -descriptors are backed by an `fs.Inode`). The `fs.Inode` provides operations for -reading memory into an `CachingInodeOperations`, to represent the contents of -the file in-memory, and for writing memory out, to relieve memory pressure on -the kernel and to synchronize in-memory changes to filesystems. - -An `CachingInodeOperations` enables readable and/or writable memory access to -file content. Files can be mapped shared or private, see mmap(2). When a file is -mapped shared, changes to the file via write(2) and truncate(2) are reflected in -the shared memory region. Conversely, when the shared memory region is modified, -changes to the file are visible via read(2). Multiple shared mappings of the -same file are coherent with each other. This is consistent with Linux. - -When a file is mapped private, updates to the mapped memory are not visible to -other memory mappings. Updates to the mapped memory are also not reflected in -the file content as seen by read(2). If the file is changed after a private -mapping is created, for instance by write(2), the change to the file may or may -not be reflected in the private mapping. This is consistent with Linux. - -An `CachingInodeOperations` keeps track of ranges of memory that were modified -(or "dirtied"). When the file is explicitly synced via fsync(2), only the dirty -ranges are written out to the filesystem. Any error returned indicates a failure -to write all dirty memory of an `CachingInodeOperations` to the filesystem. In -this case the filesystem may be in an inconsistent state. The same operation can -be performed on the shared memory itself using msync(2). If neither fsync(2) nor -msync(2) is performed, then the dirty memory is written out in accordance with -the `CachingInodeOperations` eviction strategy (see below) and there is no -guarantee that memory will be written out successfully in full. - -### Memory allocation and eviction - -An `CachingInodeOperations` implements the following allocation and eviction -strategy: - -- Memory is allocated and brought up to date with the contents of a file when - a region of mapped memory is accessed (or "faulted on"). - -- Dirty memory is written out to filesystems when an fsync(2) or msync(2) - operation is performed on a memory mapped file, for all memory mapped files - when saved, and/or when there are no longer any memory mappings of a range - of a file, see munmap(2). As the latter implies, in the absence of a panic - or SIGKILL, dirty memory is written out for all memory mapped files when an - application exits. - -- Memory is freed when there are no longer any memory mappings of a range of a - file (e.g. when an application exits). This behavior is consistent with - Linux for shared memory that has been locked via mlock(2). - -Notably, memory is not allocated for read(2) or write(2) operations. This means -that reads and writes to the file are only accelerated by an -`CachingInodeOperations` if the file being read or written has been memory -mapped *and* if the shared memory has been accessed at the region being read or -written. This diverges from Linux which buffers memory into a page cache on -read(2) proactively (i.e. readahead) and delays writing it out to filesystems on -write(2) (i.e. writeback). The absence of these optimizations is not visible to -applications beyond less than optimal performance when repeatedly reading and/or -writing to same region of a file. See [Future Work](#future-work) for plans to -implement these optimizations. - -Additionally, memory held by `CachingInodeOperationss` is currently unbounded in -size. An `CachingInodeOperations` does not write out dirty memory and free it -under system memory pressure. This can cause pathological memory usage. - -When memory is written back, an `CachingInodeOperations` may write regions of -shared memory that were never modified. This is due to the strategy of -minimizing page faults (see below) and handling only a subset of memory write -faults. In the absence of an application or sentry crash, it is guaranteed that -if a region of shared memory was written to, it is written back to a filesystem. - -### Life of a shared memory mapping - -A file is memory mapped via mmap(2). For example, if `A` is an address, an -application may execute: - -``` -mmap(A, 0x1000, PROT_READ|PROT_WRITE, MAP_SHARED, fd, 0); -``` - -This creates a shared mapping of fd that reflects 4k of the contents of fd -starting at offset 0, accessible at address `A`. This in turn creates a virtual -memory area region ("vma") which indicates that [`A`, `A`+0x1000) is now a valid -address range for this application to access. - -At this point, memory has not been allocated in the file's -`CachingInodeOperations`. It is also the case that the address range [`A`, -`A`+0x1000) has not been mapped on the host on behalf of the application. If the -application then tries to modify 8 bytes of the shared memory: - -``` -char buffer[] = "aaaaaaaa"; -memcpy(A, buffer, 8); -``` - -The host then sends a `SIGSEGV` to the sentry because the address range [`A`, -`A`+8) is not mapped on the host. The `SIGSEGV` indicates that the memory was -accessed writable. The sentry looks up the vma associated with [`A`, `A`+8), -finds the file that was mapped and its `CachingInodeOperations`. It then calls -`CachingInodeOperations.Translate` which allocates memory to back [`A`, `A`+8). -It may choose to allocate more memory (i.e. do "readahead") to minimize -subsequent faults. - -Memory that is allocated comes from a host tmpfs file (see -`pgalloc.MemoryFile`). The host tmpfs file memory is brought up to date with the -contents of the mapped file on its filesystem. The region of the host tmpfs file -that reflects the mapped file is then mapped into the host address space of the -application so that subsequent memory accesses do not repeatedly generate a -`SIGSEGV`. - -The range that was allocated, including any extra memory allocation to minimize -faults, is marked dirty due to the write fault. This overcounts dirty memory if -the extra memory allocated is never modified. - -To make the scenario more interesting, imagine that this application spawns -another process and maps the same file in the exact same way: - -``` -mmap(A, 0x1000, PROT_READ|PROT_WRITE, MAP_SHARED, fd, 0); -``` - -Imagine that this process then tries to modify the file again but with only 4 -bytes: - -``` -char buffer[] = "bbbb"; -memcpy(A, buffer, 4); -``` - -Since the first process has already mapped and accessed the same region of the -file writable, `CachingInodeOperations.Translate` is called but returns the -memory that has already been allocated rather than allocating new memory. The -address range [`A`, `A`+0x1000) reflects the same cached view of the file as the -first process sees. For example, reading 8 bytes from the file from either -process via read(2) starting at offset 0 returns a consistent "bbbbaaaa". - -When this process no longer needs the shared memory, it may do: - -``` -munmap(A, 0x1000); -``` - -At this point, the modified memory cached by the `CachingInodeOperations` is not -written back to the file because it is still in use by the first process that -mapped it. When the first process also does: - -``` -munmap(A, 0x1000); -``` - -Then the last memory mapping of the file at the range [0, 0x1000) is gone. The -file's `CachingInodeOperations` then starts writing back memory marked dirty to -the file on its filesystem. Once writing completes, regardless of whether it was -successful, the `CachingInodeOperations` frees the memory cached at the range -[0, 0x1000). - -Subsequent read(2) or write(2) operations on the file go directly to the -filesystem since there no longer exists memory for it in its -`CachingInodeOperations`. - -## Future Work - -### Page cache - -The sentry does not yet implement the readahead and writeback optimizations for -read(2) and write(2) respectively. To do so, on read(2) and/or write(2) the -sentry must ensure that memory is allocated in a page cache to read or write -into. However, the sentry cannot boundlessly allocate memory. If it did, the -host would eventually OOM-kill the sentry+application process. This means that -the sentry must implement a page cache memory allocation strategy that is -bounded by a global user or container imposed limit. When this limit is -approached, the sentry must decide from which page cache memory should be freed -so that it can allocate more memory. If it makes a poor decision, the sentry may -end up freeing and re-allocating memory to back regions of files that are -frequently used, nullifying the optimization (and in some cases causing worse -performance due to the overhead of memory allocation and general management). -This is a form of "cache thrashing". - -In Linux, much research has been done to select and implement a lightweight but -optimal page cache eviction algorithm. Linux makes use of hardware page bits to -keep track of whether memory has been accessed. The sentry does not have direct -access to hardware. Implementing a similarly lightweight and optimal page cache -eviction algorithm will need to either introduce a kernel interface to obtain -these page bits or find a suitable alternative proxy for access events. - -In Linux, readahead happens by default but is not always ideal. For instance, -for files that are not read sequentially, it would be more ideal to simply read -from only those regions of the file rather than to optimistically cache some -number of bytes ahead of the read (up to 2MB in Linux) if the bytes cached won't -be accessed. Linux implements the fadvise64(2) system call for applications to -specify that a range of a file will not be accessed sequentially. The advice bit -FADV_RANDOM turns off the readahead optimization for the given range in the -given file. However fadvise64 is rarely used by applications so Linux implements -a readahead backoff strategy if reads are not sequential. To ensure that -application performance is not degraded, the sentry must implement a similar -backoff strategy. diff --git a/pkg/sentry/fs/fsutil/dirty_set_impl.go b/pkg/sentry/fs/fsutil/dirty_set_impl.go new file mode 100644 index 000000000..2c6a10fc4 --- /dev/null +++ b/pkg/sentry/fs/fsutil/dirty_set_impl.go @@ -0,0 +1,1647 @@ +package fsutil + +import ( + __generics_imported0 "gvisor.dev/gvisor/pkg/sentry/memmap" +) + +import ( + "bytes" + "fmt" +) + +// trackGaps is an optional parameter. +// +// If trackGaps is 1, the Set will track maximum gap size recursively, +// enabling the GapIterator.{Prev,Next}LargeEnoughGap functions. In this +// case, Key must be an unsigned integer. +// +// trackGaps must be 0 or 1. +const DirtytrackGaps = 0 + +var _ = uint8(DirtytrackGaps << 7) // Will fail if not zero or one. + +// dynamicGap is a type that disappears if trackGaps is 0. +type DirtydynamicGap [DirtytrackGaps]uint64 + +// Get returns the value of the gap. +// +// Precondition: trackGaps must be non-zero. +func (d *DirtydynamicGap) Get() uint64 { + return d[:][0] +} + +// Set sets the value of the gap. +// +// Precondition: trackGaps must be non-zero. +func (d *DirtydynamicGap) Set(v uint64) { + d[:][0] = v +} + +const ( + // minDegree is the minimum degree of an internal node in a Set B-tree. + // + // - Any non-root node has at least minDegree-1 segments. + // + // - Any non-root internal (non-leaf) node has at least minDegree children. + // + // - The root node may have fewer than minDegree-1 segments, but it may + // only have 0 segments if the tree is empty. + // + // Our implementation requires minDegree >= 3. Higher values of minDegree + // usually improve performance, but increase memory usage for small sets. + DirtyminDegree = 3 + + DirtymaxDegree = 2 * DirtyminDegree +) + +// A Set is a mapping of segments with non-overlapping Range keys. The zero +// value for a Set is an empty set. Set values are not safely movable nor +// copyable. Set is thread-compatible. +// +// +stateify savable +type DirtySet struct { + root Dirtynode `state:".(*DirtySegmentDataSlices)"` +} + +// IsEmpty returns true if the set contains no segments. +func (s *DirtySet) IsEmpty() bool { + return s.root.nrSegments == 0 +} + +// IsEmptyRange returns true iff no segments in the set overlap the given +// range. This is semantically equivalent to s.SpanRange(r) == 0, but may be +// more efficient. +func (s *DirtySet) IsEmptyRange(r __generics_imported0.MappableRange) bool { + switch { + case r.Length() < 0: + panic(fmt.Sprintf("invalid range %v", r)) + case r.Length() == 0: + return true + } + _, gap := s.Find(r.Start) + if !gap.Ok() { + return false + } + return r.End <= gap.End() +} + +// Span returns the total size of all segments in the set. +func (s *DirtySet) Span() uint64 { + var sz uint64 + for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + sz += seg.Range().Length() + } + return sz +} + +// SpanRange returns the total size of the intersection of segments in the set +// with the given range. +func (s *DirtySet) SpanRange(r __generics_imported0.MappableRange) uint64 { + switch { + case r.Length() < 0: + panic(fmt.Sprintf("invalid range %v", r)) + case r.Length() == 0: + return 0 + } + var sz uint64 + for seg := s.LowerBoundSegment(r.Start); seg.Ok() && seg.Start() < r.End; seg = seg.NextSegment() { + sz += seg.Range().Intersect(r).Length() + } + return sz +} + +// FirstSegment returns the first segment in the set. If the set is empty, +// FirstSegment returns a terminal iterator. +func (s *DirtySet) FirstSegment() DirtyIterator { + if s.root.nrSegments == 0 { + return DirtyIterator{} + } + return s.root.firstSegment() +} + +// LastSegment returns the last segment in the set. If the set is empty, +// LastSegment returns a terminal iterator. +func (s *DirtySet) LastSegment() DirtyIterator { + if s.root.nrSegments == 0 { + return DirtyIterator{} + } + return s.root.lastSegment() +} + +// FirstGap returns the first gap in the set. +func (s *DirtySet) FirstGap() DirtyGapIterator { + n := &s.root + for n.hasChildren { + n = n.children[0] + } + return DirtyGapIterator{n, 0} +} + +// LastGap returns the last gap in the set. +func (s *DirtySet) LastGap() DirtyGapIterator { + n := &s.root + for n.hasChildren { + n = n.children[n.nrSegments] + } + return DirtyGapIterator{n, n.nrSegments} +} + +// Find returns the segment or gap whose range contains the given key. If a +// segment is found, the returned Iterator is non-terminal and the +// returned GapIterator is terminal. Otherwise, the returned Iterator is +// terminal and the returned GapIterator is non-terminal. +func (s *DirtySet) Find(key uint64) (DirtyIterator, DirtyGapIterator) { + n := &s.root + for { + + lower := 0 + upper := n.nrSegments + for lower < upper { + i := lower + (upper-lower)/2 + if r := n.keys[i]; key < r.End { + if key >= r.Start { + return DirtyIterator{n, i}, DirtyGapIterator{} + } + upper = i + } else { + lower = i + 1 + } + } + i := lower + if !n.hasChildren { + return DirtyIterator{}, DirtyGapIterator{n, i} + } + n = n.children[i] + } +} + +// FindSegment returns the segment whose range contains the given key. If no +// such segment exists, FindSegment returns a terminal iterator. +func (s *DirtySet) FindSegment(key uint64) DirtyIterator { + seg, _ := s.Find(key) + return seg +} + +// LowerBoundSegment returns the segment with the lowest range that contains a +// key greater than or equal to min. If no such segment exists, +// LowerBoundSegment returns a terminal iterator. +func (s *DirtySet) LowerBoundSegment(min uint64) DirtyIterator { + seg, gap := s.Find(min) + if seg.Ok() { + return seg + } + return gap.NextSegment() +} + +// UpperBoundSegment returns the segment with the highest range that contains a +// key less than or equal to max. If no such segment exists, UpperBoundSegment +// returns a terminal iterator. +func (s *DirtySet) UpperBoundSegment(max uint64) DirtyIterator { + seg, gap := s.Find(max) + if seg.Ok() { + return seg + } + return gap.PrevSegment() +} + +// FindGap returns the gap containing the given key. If no such gap exists +// (i.e. the set contains a segment containing that key), FindGap returns a +// terminal iterator. +func (s *DirtySet) FindGap(key uint64) DirtyGapIterator { + _, gap := s.Find(key) + return gap +} + +// LowerBoundGap returns the gap with the lowest range that is greater than or +// equal to min. +func (s *DirtySet) LowerBoundGap(min uint64) DirtyGapIterator { + seg, gap := s.Find(min) + if gap.Ok() { + return gap + } + return seg.NextGap() +} + +// UpperBoundGap returns the gap with the highest range that is less than or +// equal to max. +func (s *DirtySet) UpperBoundGap(max uint64) DirtyGapIterator { + seg, gap := s.Find(max) + if gap.Ok() { + return gap + } + return seg.PrevGap() +} + +// Add inserts the given segment into the set and returns true. If the new +// segment can be merged with adjacent segments, Add will do so. If the new +// segment would overlap an existing segment, Add returns false. If Add +// succeeds, all existing iterators are invalidated. +func (s *DirtySet) Add(r __generics_imported0.MappableRange, val DirtyInfo) bool { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + gap := s.FindGap(r.Start) + if !gap.Ok() { + return false + } + if r.End > gap.End() { + return false + } + s.Insert(gap, r, val) + return true +} + +// AddWithoutMerging inserts the given segment into the set and returns true. +// If it would overlap an existing segment, AddWithoutMerging does nothing and +// returns false. If AddWithoutMerging succeeds, all existing iterators are +// invalidated. +func (s *DirtySet) AddWithoutMerging(r __generics_imported0.MappableRange, val DirtyInfo) bool { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + gap := s.FindGap(r.Start) + if !gap.Ok() { + return false + } + if r.End > gap.End() { + return false + } + s.InsertWithoutMergingUnchecked(gap, r, val) + return true +} + +// Insert inserts the given segment into the given gap. If the new segment can +// be merged with adjacent segments, Insert will do so. Insert returns an +// iterator to the segment containing the inserted value (which may have been +// merged with other values). All existing iterators (including gap, but not +// including the returned iterator) are invalidated. +// +// If the gap cannot accommodate the segment, or if r is invalid, Insert panics. +// +// Insert is semantically equivalent to a InsertWithoutMerging followed by a +// Merge, but may be more efficient. Note that there is no unchecked variant of +// Insert since Insert must retrieve and inspect gap's predecessor and +// successor segments regardless. +func (s *DirtySet) Insert(gap DirtyGapIterator, r __generics_imported0.MappableRange, val DirtyInfo) DirtyIterator { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + prev, next := gap.PrevSegment(), gap.NextSegment() + if prev.Ok() && prev.End() > r.Start { + panic(fmt.Sprintf("new segment %v overlaps predecessor %v", r, prev.Range())) + } + if next.Ok() && next.Start() < r.End { + panic(fmt.Sprintf("new segment %v overlaps successor %v", r, next.Range())) + } + if prev.Ok() && prev.End() == r.Start { + if mval, ok := (dirtySetFunctions{}).Merge(prev.Range(), prev.Value(), r, val); ok { + shrinkMaxGap := DirtytrackGaps != 0 && gap.Range().Length() == gap.node.maxGap.Get() + prev.SetEndUnchecked(r.End) + prev.SetValue(mval) + if shrinkMaxGap { + gap.node.updateMaxGapLeaf() + } + if next.Ok() && next.Start() == r.End { + val = mval + if mval, ok := (dirtySetFunctions{}).Merge(prev.Range(), val, next.Range(), next.Value()); ok { + prev.SetEndUnchecked(next.End()) + prev.SetValue(mval) + return s.Remove(next).PrevSegment() + } + } + return prev + } + } + if next.Ok() && next.Start() == r.End { + if mval, ok := (dirtySetFunctions{}).Merge(r, val, next.Range(), next.Value()); ok { + shrinkMaxGap := DirtytrackGaps != 0 && gap.Range().Length() == gap.node.maxGap.Get() + next.SetStartUnchecked(r.Start) + next.SetValue(mval) + if shrinkMaxGap { + gap.node.updateMaxGapLeaf() + } + return next + } + } + + return s.InsertWithoutMergingUnchecked(gap, r, val) +} + +// InsertWithoutMerging inserts the given segment into the given gap and +// returns an iterator to the inserted segment. All existing iterators +// (including gap, but not including the returned iterator) are invalidated. +// +// If the gap cannot accommodate the segment, or if r is invalid, +// InsertWithoutMerging panics. +func (s *DirtySet) InsertWithoutMerging(gap DirtyGapIterator, r __generics_imported0.MappableRange, val DirtyInfo) DirtyIterator { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + if gr := gap.Range(); !gr.IsSupersetOf(r) { + panic(fmt.Sprintf("cannot insert segment range %v into gap range %v", r, gr)) + } + return s.InsertWithoutMergingUnchecked(gap, r, val) +} + +// InsertWithoutMergingUnchecked inserts the given segment into the given gap +// and returns an iterator to the inserted segment. All existing iterators +// (including gap, but not including the returned iterator) are invalidated. +// +// Preconditions: +// * r.Start >= gap.Start(). +// * r.End <= gap.End(). +func (s *DirtySet) InsertWithoutMergingUnchecked(gap DirtyGapIterator, r __generics_imported0.MappableRange, val DirtyInfo) DirtyIterator { + gap = gap.node.rebalanceBeforeInsert(gap) + splitMaxGap := DirtytrackGaps != 0 && (gap.node.nrSegments == 0 || gap.Range().Length() == gap.node.maxGap.Get()) + copy(gap.node.keys[gap.index+1:], gap.node.keys[gap.index:gap.node.nrSegments]) + copy(gap.node.values[gap.index+1:], gap.node.values[gap.index:gap.node.nrSegments]) + gap.node.keys[gap.index] = r + gap.node.values[gap.index] = val + gap.node.nrSegments++ + if splitMaxGap { + gap.node.updateMaxGapLeaf() + } + return DirtyIterator{gap.node, gap.index} +} + +// Remove removes the given segment and returns an iterator to the vacated gap. +// All existing iterators (including seg, but not including the returned +// iterator) are invalidated. +func (s *DirtySet) Remove(seg DirtyIterator) DirtyGapIterator { + + if seg.node.hasChildren { + + victim := seg.PrevSegment() + + seg.SetRangeUnchecked(victim.Range()) + seg.SetValue(victim.Value()) + + nextAdjacentNode := seg.NextSegment().node + if DirtytrackGaps != 0 { + nextAdjacentNode.updateMaxGapLeaf() + } + return s.Remove(victim).NextGap() + } + copy(seg.node.keys[seg.index:], seg.node.keys[seg.index+1:seg.node.nrSegments]) + copy(seg.node.values[seg.index:], seg.node.values[seg.index+1:seg.node.nrSegments]) + dirtySetFunctions{}.ClearValue(&seg.node.values[seg.node.nrSegments-1]) + seg.node.nrSegments-- + if DirtytrackGaps != 0 { + seg.node.updateMaxGapLeaf() + } + return seg.node.rebalanceAfterRemove(DirtyGapIterator{seg.node, seg.index}) +} + +// RemoveAll removes all segments from the set. All existing iterators are +// invalidated. +func (s *DirtySet) RemoveAll() { + s.root = Dirtynode{} +} + +// RemoveRange removes all segments in the given range. An iterator to the +// newly formed gap is returned, and all existing iterators are invalidated. +func (s *DirtySet) RemoveRange(r __generics_imported0.MappableRange) DirtyGapIterator { + seg, gap := s.Find(r.Start) + if seg.Ok() { + seg = s.Isolate(seg, r) + gap = s.Remove(seg) + } + for seg = gap.NextSegment(); seg.Ok() && seg.Start() < r.End; seg = gap.NextSegment() { + seg = s.Isolate(seg, r) + gap = s.Remove(seg) + } + return gap +} + +// Merge attempts to merge two neighboring segments. If successful, Merge +// returns an iterator to the merged segment, and all existing iterators are +// invalidated. Otherwise, Merge returns a terminal iterator. +// +// If first is not the predecessor of second, Merge panics. +func (s *DirtySet) Merge(first, second DirtyIterator) DirtyIterator { + if first.NextSegment() != second { + panic(fmt.Sprintf("attempt to merge non-neighboring segments %v, %v", first.Range(), second.Range())) + } + return s.MergeUnchecked(first, second) +} + +// MergeUnchecked attempts to merge two neighboring segments. If successful, +// MergeUnchecked returns an iterator to the merged segment, and all existing +// iterators are invalidated. Otherwise, MergeUnchecked returns a terminal +// iterator. +// +// Precondition: first is the predecessor of second: first.NextSegment() == +// second, first == second.PrevSegment(). +func (s *DirtySet) MergeUnchecked(first, second DirtyIterator) DirtyIterator { + if first.End() == second.Start() { + if mval, ok := (dirtySetFunctions{}).Merge(first.Range(), first.Value(), second.Range(), second.Value()); ok { + + first.SetEndUnchecked(second.End()) + first.SetValue(mval) + + return s.Remove(second).PrevSegment() + } + } + return DirtyIterator{} +} + +// MergeAll attempts to merge all adjacent segments in the set. All existing +// iterators are invalidated. +func (s *DirtySet) MergeAll() { + seg := s.FirstSegment() + if !seg.Ok() { + return + } + next := seg.NextSegment() + for next.Ok() { + if mseg := s.MergeUnchecked(seg, next); mseg.Ok() { + seg, next = mseg, mseg.NextSegment() + } else { + seg, next = next, next.NextSegment() + } + } +} + +// MergeRange attempts to merge all adjacent segments that contain a key in the +// specific range. All existing iterators are invalidated. +func (s *DirtySet) MergeRange(r __generics_imported0.MappableRange) { + seg := s.LowerBoundSegment(r.Start) + if !seg.Ok() { + return + } + next := seg.NextSegment() + for next.Ok() && next.Range().Start < r.End { + if mseg := s.MergeUnchecked(seg, next); mseg.Ok() { + seg, next = mseg, mseg.NextSegment() + } else { + seg, next = next, next.NextSegment() + } + } +} + +// MergeAdjacent attempts to merge the segment containing r.Start with its +// predecessor, and the segment containing r.End-1 with its successor. +func (s *DirtySet) MergeAdjacent(r __generics_imported0.MappableRange) { + first := s.FindSegment(r.Start) + if first.Ok() { + if prev := first.PrevSegment(); prev.Ok() { + s.Merge(prev, first) + } + } + last := s.FindSegment(r.End - 1) + if last.Ok() { + if next := last.NextSegment(); next.Ok() { + s.Merge(last, next) + } + } +} + +// Split splits the given segment at the given key and returns iterators to the +// two resulting segments. All existing iterators (including seg, but not +// including the returned iterators) are invalidated. +// +// If the segment cannot be split at split (because split is at the start or +// end of the segment's range, so splitting would produce a segment with zero +// length, or because split falls outside the segment's range altogether), +// Split panics. +func (s *DirtySet) Split(seg DirtyIterator, split uint64) (DirtyIterator, DirtyIterator) { + if !seg.Range().CanSplitAt(split) { + panic(fmt.Sprintf("can't split %v at %v", seg.Range(), split)) + } + return s.SplitUnchecked(seg, split) +} + +// SplitUnchecked splits the given segment at the given key and returns +// iterators to the two resulting segments. All existing iterators (including +// seg, but not including the returned iterators) are invalidated. +// +// Preconditions: seg.Start() < key < seg.End(). +func (s *DirtySet) SplitUnchecked(seg DirtyIterator, split uint64) (DirtyIterator, DirtyIterator) { + val1, val2 := (dirtySetFunctions{}).Split(seg.Range(), seg.Value(), split) + end2 := seg.End() + seg.SetEndUnchecked(split) + seg.SetValue(val1) + seg2 := s.InsertWithoutMergingUnchecked(seg.NextGap(), __generics_imported0.MappableRange{split, end2}, val2) + + return seg2.PrevSegment(), seg2 +} + +// SplitAt splits the segment straddling split, if one exists. SplitAt returns +// true if a segment was split and false otherwise. If SplitAt splits a +// segment, all existing iterators are invalidated. +func (s *DirtySet) SplitAt(split uint64) bool { + if seg := s.FindSegment(split); seg.Ok() && seg.Range().CanSplitAt(split) { + s.SplitUnchecked(seg, split) + return true + } + return false +} + +// Isolate ensures that the given segment's range does not escape r by +// splitting at r.Start and r.End if necessary, and returns an updated iterator +// to the bounded segment. All existing iterators (including seg, but not +// including the returned iterators) are invalidated. +func (s *DirtySet) Isolate(seg DirtyIterator, r __generics_imported0.MappableRange) DirtyIterator { + if seg.Range().CanSplitAt(r.Start) { + _, seg = s.SplitUnchecked(seg, r.Start) + } + if seg.Range().CanSplitAt(r.End) { + seg, _ = s.SplitUnchecked(seg, r.End) + } + return seg +} + +// ApplyContiguous applies a function to a contiguous range of segments, +// splitting if necessary. The function is applied until the first gap is +// encountered, at which point the gap is returned. If the function is applied +// across the entire range, a terminal gap is returned. All existing iterators +// are invalidated. +// +// N.B. The Iterator must not be invalidated by the function. +func (s *DirtySet) ApplyContiguous(r __generics_imported0.MappableRange, fn func(seg DirtyIterator)) DirtyGapIterator { + seg, gap := s.Find(r.Start) + if !seg.Ok() { + return gap + } + for { + seg = s.Isolate(seg, r) + fn(seg) + if seg.End() >= r.End { + return DirtyGapIterator{} + } + gap = seg.NextGap() + if !gap.IsEmpty() { + return gap + } + seg = gap.NextSegment() + if !seg.Ok() { + + return DirtyGapIterator{} + } + } +} + +// +stateify savable +type Dirtynode struct { + // An internal binary tree node looks like: + // + // K + // / \ + // Cl Cr + // + // where all keys in the subtree rooted by Cl (the left subtree) are less + // than K (the key of the parent node), and all keys in the subtree rooted + // by Cr (the right subtree) are greater than K. + // + // An internal B-tree node's indexes work out to look like: + // + // K0 K1 K2 ... Kn-1 + // / \/ \/ \ ... / \ + // C0 C1 C2 C3 ... Cn-1 Cn + // + // where n is nrSegments. + nrSegments int + + // parent is a pointer to this node's parent. If this node is root, parent + // is nil. + parent *Dirtynode + + // parentIndex is the index of this node in parent.children. + parentIndex int + + // Flag for internal nodes that is technically redundant with "children[0] + // != nil", but is stored in the first cache line. "hasChildren" rather + // than "isLeaf" because false must be the correct value for an empty root. + hasChildren bool + + // The longest gap within this node. If the node is a leaf, it's simply the + // maximum gap among all the (nrSegments+1) gaps formed by its nrSegments keys + // including the 0th and nrSegments-th gap possibly shared with its upper-level + // nodes; if it's a non-leaf node, it's the max of all children's maxGap. + maxGap DirtydynamicGap + + // Nodes store keys and values in separate arrays to maximize locality in + // the common case (scanning keys for lookup). + keys [DirtymaxDegree - 1]__generics_imported0.MappableRange + values [DirtymaxDegree - 1]DirtyInfo + children [DirtymaxDegree]*Dirtynode +} + +// firstSegment returns the first segment in the subtree rooted by n. +// +// Preconditions: n.nrSegments != 0. +func (n *Dirtynode) firstSegment() DirtyIterator { + for n.hasChildren { + n = n.children[0] + } + return DirtyIterator{n, 0} +} + +// lastSegment returns the last segment in the subtree rooted by n. +// +// Preconditions: n.nrSegments != 0. +func (n *Dirtynode) lastSegment() DirtyIterator { + for n.hasChildren { + n = n.children[n.nrSegments] + } + return DirtyIterator{n, n.nrSegments - 1} +} + +func (n *Dirtynode) prevSibling() *Dirtynode { + if n.parent == nil || n.parentIndex == 0 { + return nil + } + return n.parent.children[n.parentIndex-1] +} + +func (n *Dirtynode) nextSibling() *Dirtynode { + if n.parent == nil || n.parentIndex == n.parent.nrSegments { + return nil + } + return n.parent.children[n.parentIndex+1] +} + +// rebalanceBeforeInsert splits n and its ancestors if they are full, as +// required for insertion, and returns an updated iterator to the position +// represented by gap. +func (n *Dirtynode) rebalanceBeforeInsert(gap DirtyGapIterator) DirtyGapIterator { + if n.nrSegments < DirtymaxDegree-1 { + return gap + } + if n.parent != nil { + gap = n.parent.rebalanceBeforeInsert(gap) + } + if n.parent == nil { + + left := &Dirtynode{ + nrSegments: DirtyminDegree - 1, + parent: n, + parentIndex: 0, + hasChildren: n.hasChildren, + } + right := &Dirtynode{ + nrSegments: DirtyminDegree - 1, + parent: n, + parentIndex: 1, + hasChildren: n.hasChildren, + } + copy(left.keys[:DirtyminDegree-1], n.keys[:DirtyminDegree-1]) + copy(left.values[:DirtyminDegree-1], n.values[:DirtyminDegree-1]) + copy(right.keys[:DirtyminDegree-1], n.keys[DirtyminDegree:]) + copy(right.values[:DirtyminDegree-1], n.values[DirtyminDegree:]) + n.keys[0], n.values[0] = n.keys[DirtyminDegree-1], n.values[DirtyminDegree-1] + DirtyzeroValueSlice(n.values[1:]) + if n.hasChildren { + copy(left.children[:DirtyminDegree], n.children[:DirtyminDegree]) + copy(right.children[:DirtyminDegree], n.children[DirtyminDegree:]) + DirtyzeroNodeSlice(n.children[2:]) + for i := 0; i < DirtyminDegree; i++ { + left.children[i].parent = left + left.children[i].parentIndex = i + right.children[i].parent = right + right.children[i].parentIndex = i + } + } + n.nrSegments = 1 + n.hasChildren = true + n.children[0] = left + n.children[1] = right + + if DirtytrackGaps != 0 { + left.updateMaxGapLocal() + right.updateMaxGapLocal() + } + if gap.node != n { + return gap + } + if gap.index < DirtyminDegree { + return DirtyGapIterator{left, gap.index} + } + return DirtyGapIterator{right, gap.index - DirtyminDegree} + } + + copy(n.parent.keys[n.parentIndex+1:], n.parent.keys[n.parentIndex:n.parent.nrSegments]) + copy(n.parent.values[n.parentIndex+1:], n.parent.values[n.parentIndex:n.parent.nrSegments]) + n.parent.keys[n.parentIndex], n.parent.values[n.parentIndex] = n.keys[DirtyminDegree-1], n.values[DirtyminDegree-1] + copy(n.parent.children[n.parentIndex+2:], n.parent.children[n.parentIndex+1:n.parent.nrSegments+1]) + for i := n.parentIndex + 2; i < n.parent.nrSegments+2; i++ { + n.parent.children[i].parentIndex = i + } + sibling := &Dirtynode{ + nrSegments: DirtyminDegree - 1, + parent: n.parent, + parentIndex: n.parentIndex + 1, + hasChildren: n.hasChildren, + } + n.parent.children[n.parentIndex+1] = sibling + n.parent.nrSegments++ + copy(sibling.keys[:DirtyminDegree-1], n.keys[DirtyminDegree:]) + copy(sibling.values[:DirtyminDegree-1], n.values[DirtyminDegree:]) + DirtyzeroValueSlice(n.values[DirtyminDegree-1:]) + if n.hasChildren { + copy(sibling.children[:DirtyminDegree], n.children[DirtyminDegree:]) + DirtyzeroNodeSlice(n.children[DirtyminDegree:]) + for i := 0; i < DirtyminDegree; i++ { + sibling.children[i].parent = sibling + sibling.children[i].parentIndex = i + } + } + n.nrSegments = DirtyminDegree - 1 + + if DirtytrackGaps != 0 { + n.updateMaxGapLocal() + sibling.updateMaxGapLocal() + } + + if gap.node != n { + return gap + } + if gap.index < DirtyminDegree { + return gap + } + return DirtyGapIterator{sibling, gap.index - DirtyminDegree} +} + +// rebalanceAfterRemove "unsplits" n and its ancestors if they are deficient +// (contain fewer segments than required by B-tree invariants), as required for +// removal, and returns an updated iterator to the position represented by gap. +// +// Precondition: n is the only node in the tree that may currently violate a +// B-tree invariant. +func (n *Dirtynode) rebalanceAfterRemove(gap DirtyGapIterator) DirtyGapIterator { + for { + if n.nrSegments >= DirtyminDegree-1 { + return gap + } + if n.parent == nil { + + return gap + } + + if sibling := n.prevSibling(); sibling != nil && sibling.nrSegments >= DirtyminDegree { + copy(n.keys[1:], n.keys[:n.nrSegments]) + copy(n.values[1:], n.values[:n.nrSegments]) + n.keys[0] = n.parent.keys[n.parentIndex-1] + n.values[0] = n.parent.values[n.parentIndex-1] + n.parent.keys[n.parentIndex-1] = sibling.keys[sibling.nrSegments-1] + n.parent.values[n.parentIndex-1] = sibling.values[sibling.nrSegments-1] + dirtySetFunctions{}.ClearValue(&sibling.values[sibling.nrSegments-1]) + if n.hasChildren { + copy(n.children[1:], n.children[:n.nrSegments+1]) + n.children[0] = sibling.children[sibling.nrSegments] + sibling.children[sibling.nrSegments] = nil + n.children[0].parent = n + n.children[0].parentIndex = 0 + for i := 1; i < n.nrSegments+2; i++ { + n.children[i].parentIndex = i + } + } + n.nrSegments++ + sibling.nrSegments-- + + if DirtytrackGaps != 0 { + n.updateMaxGapLocal() + sibling.updateMaxGapLocal() + } + if gap.node == sibling && gap.index == sibling.nrSegments { + return DirtyGapIterator{n, 0} + } + if gap.node == n { + return DirtyGapIterator{n, gap.index + 1} + } + return gap + } + if sibling := n.nextSibling(); sibling != nil && sibling.nrSegments >= DirtyminDegree { + n.keys[n.nrSegments] = n.parent.keys[n.parentIndex] + n.values[n.nrSegments] = n.parent.values[n.parentIndex] + n.parent.keys[n.parentIndex] = sibling.keys[0] + n.parent.values[n.parentIndex] = sibling.values[0] + copy(sibling.keys[:sibling.nrSegments-1], sibling.keys[1:]) + copy(sibling.values[:sibling.nrSegments-1], sibling.values[1:]) + dirtySetFunctions{}.ClearValue(&sibling.values[sibling.nrSegments-1]) + if n.hasChildren { + n.children[n.nrSegments+1] = sibling.children[0] + copy(sibling.children[:sibling.nrSegments], sibling.children[1:]) + sibling.children[sibling.nrSegments] = nil + n.children[n.nrSegments+1].parent = n + n.children[n.nrSegments+1].parentIndex = n.nrSegments + 1 + for i := 0; i < sibling.nrSegments; i++ { + sibling.children[i].parentIndex = i + } + } + n.nrSegments++ + sibling.nrSegments-- + + if DirtytrackGaps != 0 { + n.updateMaxGapLocal() + sibling.updateMaxGapLocal() + } + if gap.node == sibling { + if gap.index == 0 { + return DirtyGapIterator{n, n.nrSegments} + } + return DirtyGapIterator{sibling, gap.index - 1} + } + return gap + } + + p := n.parent + if p.nrSegments == 1 { + + left, right := p.children[0], p.children[1] + p.nrSegments = left.nrSegments + right.nrSegments + 1 + p.hasChildren = left.hasChildren + p.keys[left.nrSegments] = p.keys[0] + p.values[left.nrSegments] = p.values[0] + copy(p.keys[:left.nrSegments], left.keys[:left.nrSegments]) + copy(p.values[:left.nrSegments], left.values[:left.nrSegments]) + copy(p.keys[left.nrSegments+1:], right.keys[:right.nrSegments]) + copy(p.values[left.nrSegments+1:], right.values[:right.nrSegments]) + if left.hasChildren { + copy(p.children[:left.nrSegments+1], left.children[:left.nrSegments+1]) + copy(p.children[left.nrSegments+1:], right.children[:right.nrSegments+1]) + for i := 0; i < p.nrSegments+1; i++ { + p.children[i].parent = p + p.children[i].parentIndex = i + } + } else { + p.children[0] = nil + p.children[1] = nil + } + + if gap.node == left { + return DirtyGapIterator{p, gap.index} + } + if gap.node == right { + return DirtyGapIterator{p, gap.index + left.nrSegments + 1} + } + return gap + } + // Merge n and either sibling, along with the segment separating the + // two, into whichever of the two nodes comes first. This is the + // reverse of the non-root splitting case in + // node.rebalanceBeforeInsert. + var left, right *Dirtynode + if n.parentIndex > 0 { + left = n.prevSibling() + right = n + } else { + left = n + right = n.nextSibling() + } + + if gap.node == right { + gap = DirtyGapIterator{left, gap.index + left.nrSegments + 1} + } + left.keys[left.nrSegments] = p.keys[left.parentIndex] + left.values[left.nrSegments] = p.values[left.parentIndex] + copy(left.keys[left.nrSegments+1:], right.keys[:right.nrSegments]) + copy(left.values[left.nrSegments+1:], right.values[:right.nrSegments]) + if left.hasChildren { + copy(left.children[left.nrSegments+1:], right.children[:right.nrSegments+1]) + for i := left.nrSegments + 1; i < left.nrSegments+right.nrSegments+2; i++ { + left.children[i].parent = left + left.children[i].parentIndex = i + } + } + left.nrSegments += right.nrSegments + 1 + copy(p.keys[left.parentIndex:], p.keys[left.parentIndex+1:p.nrSegments]) + copy(p.values[left.parentIndex:], p.values[left.parentIndex+1:p.nrSegments]) + dirtySetFunctions{}.ClearValue(&p.values[p.nrSegments-1]) + copy(p.children[left.parentIndex+1:], p.children[left.parentIndex+2:p.nrSegments+1]) + for i := 0; i < p.nrSegments; i++ { + p.children[i].parentIndex = i + } + p.children[p.nrSegments] = nil + p.nrSegments-- + + if DirtytrackGaps != 0 { + left.updateMaxGapLocal() + } + + n = p + } +} + +// updateMaxGapLeaf updates maxGap bottom-up from the calling leaf until no +// necessary update. +// +// Preconditions: n must be a leaf node, trackGaps must be 1. +func (n *Dirtynode) updateMaxGapLeaf() { + if n.hasChildren { + panic(fmt.Sprintf("updateMaxGapLeaf should always be called on leaf node: %v", n)) + } + max := n.calculateMaxGapLeaf() + if max == n.maxGap.Get() { + + return + } + oldMax := n.maxGap.Get() + n.maxGap.Set(max) + if max > oldMax { + + for p := n.parent; p != nil; p = p.parent { + if p.maxGap.Get() >= max { + + break + } + + p.maxGap.Set(max) + } + return + } + + for p := n.parent; p != nil; p = p.parent { + if p.maxGap.Get() > oldMax { + + break + } + + parentNewMax := p.calculateMaxGapInternal() + if p.maxGap.Get() == parentNewMax { + + break + } + + p.maxGap.Set(parentNewMax) + } +} + +// updateMaxGapLocal updates maxGap of the calling node solely with no +// propagation to ancestor nodes. +// +// Precondition: trackGaps must be 1. +func (n *Dirtynode) updateMaxGapLocal() { + if !n.hasChildren { + + n.maxGap.Set(n.calculateMaxGapLeaf()) + } else { + + n.maxGap.Set(n.calculateMaxGapInternal()) + } +} + +// calculateMaxGapLeaf iterates the gaps within a leaf node and calculate the +// max. +// +// Preconditions: n must be a leaf node. +func (n *Dirtynode) calculateMaxGapLeaf() uint64 { + max := DirtyGapIterator{n, 0}.Range().Length() + for i := 1; i <= n.nrSegments; i++ { + if current := (DirtyGapIterator{n, i}).Range().Length(); current > max { + max = current + } + } + return max +} + +// calculateMaxGapInternal iterates children's maxGap within an internal node n +// and calculate the max. +// +// Preconditions: n must be a non-leaf node. +func (n *Dirtynode) calculateMaxGapInternal() uint64 { + max := n.children[0].maxGap.Get() + for i := 1; i <= n.nrSegments; i++ { + if current := n.children[i].maxGap.Get(); current > max { + max = current + } + } + return max +} + +// searchFirstLargeEnoughGap returns the first gap having at least minSize length +// in the subtree rooted by n. If not found, return a terminal gap iterator. +func (n *Dirtynode) searchFirstLargeEnoughGap(minSize uint64) DirtyGapIterator { + if n.maxGap.Get() < minSize { + return DirtyGapIterator{} + } + if n.hasChildren { + for i := 0; i <= n.nrSegments; i++ { + if largeEnoughGap := n.children[i].searchFirstLargeEnoughGap(minSize); largeEnoughGap.Ok() { + return largeEnoughGap + } + } + } else { + for i := 0; i <= n.nrSegments; i++ { + currentGap := DirtyGapIterator{n, i} + if currentGap.Range().Length() >= minSize { + return currentGap + } + } + } + panic(fmt.Sprintf("invalid maxGap in %v", n)) +} + +// searchLastLargeEnoughGap returns the last gap having at least minSize length +// in the subtree rooted by n. If not found, return a terminal gap iterator. +func (n *Dirtynode) searchLastLargeEnoughGap(minSize uint64) DirtyGapIterator { + if n.maxGap.Get() < minSize { + return DirtyGapIterator{} + } + if n.hasChildren { + for i := n.nrSegments; i >= 0; i-- { + if largeEnoughGap := n.children[i].searchLastLargeEnoughGap(minSize); largeEnoughGap.Ok() { + return largeEnoughGap + } + } + } else { + for i := n.nrSegments; i >= 0; i-- { + currentGap := DirtyGapIterator{n, i} + if currentGap.Range().Length() >= minSize { + return currentGap + } + } + } + panic(fmt.Sprintf("invalid maxGap in %v", n)) +} + +// A Iterator is conceptually one of: +// +// - A pointer to a segment in a set; or +// +// - A terminal iterator, which is a sentinel indicating that the end of +// iteration has been reached. +// +// Iterators are copyable values and are meaningfully equality-comparable. The +// zero value of Iterator is a terminal iterator. +// +// Unless otherwise specified, any mutation of a set invalidates all existing +// iterators into the set. +type DirtyIterator struct { + // node is the node containing the iterated segment. If the iterator is + // terminal, node is nil. + node *Dirtynode + + // index is the index of the segment in node.keys/values. + index int +} + +// Ok returns true if the iterator is not terminal. All other methods are only +// valid for non-terminal iterators. +func (seg DirtyIterator) Ok() bool { + return seg.node != nil +} + +// Range returns the iterated segment's range key. +func (seg DirtyIterator) Range() __generics_imported0.MappableRange { + return seg.node.keys[seg.index] +} + +// Start is equivalent to Range().Start, but should be preferred if only the +// start of the range is needed. +func (seg DirtyIterator) Start() uint64 { + return seg.node.keys[seg.index].Start +} + +// End is equivalent to Range().End, but should be preferred if only the end of +// the range is needed. +func (seg DirtyIterator) End() uint64 { + return seg.node.keys[seg.index].End +} + +// SetRangeUnchecked mutates the iterated segment's range key. This operation +// does not invalidate any iterators. +// +// Preconditions: +// * r.Length() > 0. +// * The new range must not overlap an existing one: +// * If seg.NextSegment().Ok(), then r.end <= seg.NextSegment().Start(). +// * If seg.PrevSegment().Ok(), then r.start >= seg.PrevSegment().End(). +func (seg DirtyIterator) SetRangeUnchecked(r __generics_imported0.MappableRange) { + seg.node.keys[seg.index] = r +} + +// SetRange mutates the iterated segment's range key. If the new range would +// cause the iterated segment to overlap another segment, or if the new range +// is invalid, SetRange panics. This operation does not invalidate any +// iterators. +func (seg DirtyIterator) SetRange(r __generics_imported0.MappableRange) { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + if prev := seg.PrevSegment(); prev.Ok() && r.Start < prev.End() { + panic(fmt.Sprintf("new segment range %v overlaps segment range %v", r, prev.Range())) + } + if next := seg.NextSegment(); next.Ok() && r.End > next.Start() { + panic(fmt.Sprintf("new segment range %v overlaps segment range %v", r, next.Range())) + } + seg.SetRangeUnchecked(r) +} + +// SetStartUnchecked mutates the iterated segment's start. This operation does +// not invalidate any iterators. +// +// Preconditions: The new start must be valid: +// * start < seg.End() +// * If seg.PrevSegment().Ok(), then start >= seg.PrevSegment().End(). +func (seg DirtyIterator) SetStartUnchecked(start uint64) { + seg.node.keys[seg.index].Start = start +} + +// SetStart mutates the iterated segment's start. If the new start value would +// cause the iterated segment to overlap another segment, or would result in an +// invalid range, SetStart panics. This operation does not invalidate any +// iterators. +func (seg DirtyIterator) SetStart(start uint64) { + if start >= seg.End() { + panic(fmt.Sprintf("new start %v would invalidate segment range %v", start, seg.Range())) + } + if prev := seg.PrevSegment(); prev.Ok() && start < prev.End() { + panic(fmt.Sprintf("new start %v would cause segment range %v to overlap segment range %v", start, seg.Range(), prev.Range())) + } + seg.SetStartUnchecked(start) +} + +// SetEndUnchecked mutates the iterated segment's end. This operation does not +// invalidate any iterators. +// +// Preconditions: The new end must be valid: +// * end > seg.Start(). +// * If seg.NextSegment().Ok(), then end <= seg.NextSegment().Start(). +func (seg DirtyIterator) SetEndUnchecked(end uint64) { + seg.node.keys[seg.index].End = end +} + +// SetEnd mutates the iterated segment's end. If the new end value would cause +// the iterated segment to overlap another segment, or would result in an +// invalid range, SetEnd panics. This operation does not invalidate any +// iterators. +func (seg DirtyIterator) SetEnd(end uint64) { + if end <= seg.Start() { + panic(fmt.Sprintf("new end %v would invalidate segment range %v", end, seg.Range())) + } + if next := seg.NextSegment(); next.Ok() && end > next.Start() { + panic(fmt.Sprintf("new end %v would cause segment range %v to overlap segment range %v", end, seg.Range(), next.Range())) + } + seg.SetEndUnchecked(end) +} + +// Value returns a copy of the iterated segment's value. +func (seg DirtyIterator) Value() DirtyInfo { + return seg.node.values[seg.index] +} + +// ValuePtr returns a pointer to the iterated segment's value. The pointer is +// invalidated if the iterator is invalidated. This operation does not +// invalidate any iterators. +func (seg DirtyIterator) ValuePtr() *DirtyInfo { + return &seg.node.values[seg.index] +} + +// SetValue mutates the iterated segment's value. This operation does not +// invalidate any iterators. +func (seg DirtyIterator) SetValue(val DirtyInfo) { + seg.node.values[seg.index] = val +} + +// PrevSegment returns the iterated segment's predecessor. If there is no +// preceding segment, PrevSegment returns a terminal iterator. +func (seg DirtyIterator) PrevSegment() DirtyIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index].lastSegment() + } + if seg.index > 0 { + return DirtyIterator{seg.node, seg.index - 1} + } + if seg.node.parent == nil { + return DirtyIterator{} + } + return DirtysegmentBeforePosition(seg.node.parent, seg.node.parentIndex) +} + +// NextSegment returns the iterated segment's successor. If there is no +// succeeding segment, NextSegment returns a terminal iterator. +func (seg DirtyIterator) NextSegment() DirtyIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index+1].firstSegment() + } + if seg.index < seg.node.nrSegments-1 { + return DirtyIterator{seg.node, seg.index + 1} + } + if seg.node.parent == nil { + return DirtyIterator{} + } + return DirtysegmentAfterPosition(seg.node.parent, seg.node.parentIndex) +} + +// PrevGap returns the gap immediately before the iterated segment. +func (seg DirtyIterator) PrevGap() DirtyGapIterator { + if seg.node.hasChildren { + + return seg.node.children[seg.index].lastSegment().NextGap() + } + return DirtyGapIterator{seg.node, seg.index} +} + +// NextGap returns the gap immediately after the iterated segment. +func (seg DirtyIterator) NextGap() DirtyGapIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index+1].firstSegment().PrevGap() + } + return DirtyGapIterator{seg.node, seg.index + 1} +} + +// PrevNonEmpty returns the iterated segment's predecessor if it is adjacent, +// or the gap before the iterated segment otherwise. If seg.Start() == +// Functions.MinKey(), PrevNonEmpty will return two terminal iterators. +// Otherwise, exactly one of the iterators returned by PrevNonEmpty will be +// non-terminal. +func (seg DirtyIterator) PrevNonEmpty() (DirtyIterator, DirtyGapIterator) { + gap := seg.PrevGap() + if gap.Range().Length() != 0 { + return DirtyIterator{}, gap + } + return gap.PrevSegment(), DirtyGapIterator{} +} + +// NextNonEmpty returns the iterated segment's successor if it is adjacent, or +// the gap after the iterated segment otherwise. If seg.End() == +// Functions.MaxKey(), NextNonEmpty will return two terminal iterators. +// Otherwise, exactly one of the iterators returned by NextNonEmpty will be +// non-terminal. +func (seg DirtyIterator) NextNonEmpty() (DirtyIterator, DirtyGapIterator) { + gap := seg.NextGap() + if gap.Range().Length() != 0 { + return DirtyIterator{}, gap + } + return gap.NextSegment(), DirtyGapIterator{} +} + +// A GapIterator is conceptually one of: +// +// - A pointer to a position between two segments, before the first segment, or +// after the last segment in a set, called a *gap*; or +// +// - A terminal iterator, which is a sentinel indicating that the end of +// iteration has been reached. +// +// Note that the gap between two adjacent segments exists (iterators to it are +// non-terminal), but has a length of zero. GapIterator.IsEmpty returns true +// for such gaps. An empty set contains a single gap, spanning the entire range +// of the set's keys. +// +// GapIterators are copyable values and are meaningfully equality-comparable. +// The zero value of GapIterator is a terminal iterator. +// +// Unless otherwise specified, any mutation of a set invalidates all existing +// iterators into the set. +type DirtyGapIterator struct { + // The representation of a GapIterator is identical to that of an Iterator, + // except that index corresponds to positions between segments in the same + // way as for node.children (see comment for node.nrSegments). + node *Dirtynode + index int +} + +// Ok returns true if the iterator is not terminal. All other methods are only +// valid for non-terminal iterators. +func (gap DirtyGapIterator) Ok() bool { + return gap.node != nil +} + +// Range returns the range spanned by the iterated gap. +func (gap DirtyGapIterator) Range() __generics_imported0.MappableRange { + return __generics_imported0.MappableRange{gap.Start(), gap.End()} +} + +// Start is equivalent to Range().Start, but should be preferred if only the +// start of the range is needed. +func (gap DirtyGapIterator) Start() uint64 { + if ps := gap.PrevSegment(); ps.Ok() { + return ps.End() + } + return dirtySetFunctions{}.MinKey() +} + +// End is equivalent to Range().End, but should be preferred if only the end of +// the range is needed. +func (gap DirtyGapIterator) End() uint64 { + if ns := gap.NextSegment(); ns.Ok() { + return ns.Start() + } + return dirtySetFunctions{}.MaxKey() +} + +// IsEmpty returns true if the iterated gap is empty (that is, the "gap" is +// between two adjacent segments.) +func (gap DirtyGapIterator) IsEmpty() bool { + return gap.Range().Length() == 0 +} + +// PrevSegment returns the segment immediately before the iterated gap. If no +// such segment exists, PrevSegment returns a terminal iterator. +func (gap DirtyGapIterator) PrevSegment() DirtyIterator { + return DirtysegmentBeforePosition(gap.node, gap.index) +} + +// NextSegment returns the segment immediately after the iterated gap. If no +// such segment exists, NextSegment returns a terminal iterator. +func (gap DirtyGapIterator) NextSegment() DirtyIterator { + return DirtysegmentAfterPosition(gap.node, gap.index) +} + +// PrevGap returns the iterated gap's predecessor. If no such gap exists, +// PrevGap returns a terminal iterator. +func (gap DirtyGapIterator) PrevGap() DirtyGapIterator { + seg := gap.PrevSegment() + if !seg.Ok() { + return DirtyGapIterator{} + } + return seg.PrevGap() +} + +// NextGap returns the iterated gap's successor. If no such gap exists, NextGap +// returns a terminal iterator. +func (gap DirtyGapIterator) NextGap() DirtyGapIterator { + seg := gap.NextSegment() + if !seg.Ok() { + return DirtyGapIterator{} + } + return seg.NextGap() +} + +// NextLargeEnoughGap returns the iterated gap's first next gap with larger +// length than minSize. If not found, return a terminal gap iterator (does NOT +// include this gap itself). +// +// Precondition: trackGaps must be 1. +func (gap DirtyGapIterator) NextLargeEnoughGap(minSize uint64) DirtyGapIterator { + if DirtytrackGaps != 1 { + panic("set is not tracking gaps") + } + if gap.node != nil && gap.node.hasChildren && gap.index == gap.node.nrSegments { + + gap.node = gap.NextSegment().node + gap.index = 0 + return gap.nextLargeEnoughGapHelper(minSize) + } + return gap.nextLargeEnoughGapHelper(minSize) +} + +// nextLargeEnoughGapHelper is the helper function used by NextLargeEnoughGap +// to do the real recursions. +// +// Preconditions: gap is NOT the trailing gap of a non-leaf node. +func (gap DirtyGapIterator) nextLargeEnoughGapHelper(minSize uint64) DirtyGapIterator { + + for gap.node != nil && + (gap.node.maxGap.Get() < minSize || (!gap.node.hasChildren && gap.index == gap.node.nrSegments)) { + gap.node, gap.index = gap.node.parent, gap.node.parentIndex + } + + if gap.node == nil { + return DirtyGapIterator{} + } + + gap.index++ + for gap.index <= gap.node.nrSegments { + if gap.node.hasChildren { + if largeEnoughGap := gap.node.children[gap.index].searchFirstLargeEnoughGap(minSize); largeEnoughGap.Ok() { + return largeEnoughGap + } + } else { + if gap.Range().Length() >= minSize { + return gap + } + } + gap.index++ + } + gap.node, gap.index = gap.node.parent, gap.node.parentIndex + if gap.node != nil && gap.index == gap.node.nrSegments { + + gap.node, gap.index = gap.node.parent, gap.node.parentIndex + } + return gap.nextLargeEnoughGapHelper(minSize) +} + +// PrevLargeEnoughGap returns the iterated gap's first prev gap with larger or +// equal length than minSize. If not found, return a terminal gap iterator +// (does NOT include this gap itself). +// +// Precondition: trackGaps must be 1. +func (gap DirtyGapIterator) PrevLargeEnoughGap(minSize uint64) DirtyGapIterator { + if DirtytrackGaps != 1 { + panic("set is not tracking gaps") + } + if gap.node != nil && gap.node.hasChildren && gap.index == 0 { + + gap.node = gap.PrevSegment().node + gap.index = gap.node.nrSegments + return gap.prevLargeEnoughGapHelper(minSize) + } + return gap.prevLargeEnoughGapHelper(minSize) +} + +// prevLargeEnoughGapHelper is the helper function used by PrevLargeEnoughGap +// to do the real recursions. +// +// Preconditions: gap is NOT the first gap of a non-leaf node. +func (gap DirtyGapIterator) prevLargeEnoughGapHelper(minSize uint64) DirtyGapIterator { + + for gap.node != nil && + (gap.node.maxGap.Get() < minSize || (!gap.node.hasChildren && gap.index == 0)) { + gap.node, gap.index = gap.node.parent, gap.node.parentIndex + } + + if gap.node == nil { + return DirtyGapIterator{} + } + + gap.index-- + for gap.index >= 0 { + if gap.node.hasChildren { + if largeEnoughGap := gap.node.children[gap.index].searchLastLargeEnoughGap(minSize); largeEnoughGap.Ok() { + return largeEnoughGap + } + } else { + if gap.Range().Length() >= minSize { + return gap + } + } + gap.index-- + } + gap.node, gap.index = gap.node.parent, gap.node.parentIndex + if gap.node != nil && gap.index == 0 { + + gap.node, gap.index = gap.node.parent, gap.node.parentIndex + } + return gap.prevLargeEnoughGapHelper(minSize) +} + +// segmentBeforePosition returns the predecessor segment of the position given +// by n.children[i], which may or may not contain a child. If no such segment +// exists, segmentBeforePosition returns a terminal iterator. +func DirtysegmentBeforePosition(n *Dirtynode, i int) DirtyIterator { + for i == 0 { + if n.parent == nil { + return DirtyIterator{} + } + n, i = n.parent, n.parentIndex + } + return DirtyIterator{n, i - 1} +} + +// segmentAfterPosition returns the successor segment of the position given by +// n.children[i], which may or may not contain a child. If no such segment +// exists, segmentAfterPosition returns a terminal iterator. +func DirtysegmentAfterPosition(n *Dirtynode, i int) DirtyIterator { + for i == n.nrSegments { + if n.parent == nil { + return DirtyIterator{} + } + n, i = n.parent, n.parentIndex + } + return DirtyIterator{n, i} +} + +func DirtyzeroValueSlice(slice []DirtyInfo) { + + for i := range slice { + dirtySetFunctions{}.ClearValue(&slice[i]) + } +} + +func DirtyzeroNodeSlice(slice []*Dirtynode) { + for i := range slice { + slice[i] = nil + } +} + +// String stringifies a Set for debugging. +func (s *DirtySet) String() string { + return s.root.String() +} + +// String stringifies a node (and all of its children) for debugging. +func (n *Dirtynode) String() string { + var buf bytes.Buffer + n.writeDebugString(&buf, "") + return buf.String() +} + +func (n *Dirtynode) writeDebugString(buf *bytes.Buffer, prefix string) { + if n.hasChildren != (n.nrSegments > 0 && n.children[0] != nil) { + buf.WriteString(prefix) + buf.WriteString(fmt.Sprintf("WARNING: inconsistent value of hasChildren: got %v, want %v\n", n.hasChildren, !n.hasChildren)) + } + for i := 0; i < n.nrSegments; i++ { + if child := n.children[i]; child != nil { + cprefix := fmt.Sprintf("%s- % 3d ", prefix, i) + if child.parent != n || child.parentIndex != i { + buf.WriteString(cprefix) + buf.WriteString(fmt.Sprintf("WARNING: inconsistent linkage to parent: got (%p, %d), want (%p, %d)\n", child.parent, child.parentIndex, n, i)) + } + child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, i)) + } + buf.WriteString(prefix) + if n.hasChildren { + if DirtytrackGaps != 0 { + buf.WriteString(fmt.Sprintf("- % 3d: %v => %v, maxGap: %d\n", i, n.keys[i], n.values[i], n.maxGap.Get())) + } else { + buf.WriteString(fmt.Sprintf("- % 3d: %v => %v\n", i, n.keys[i], n.values[i])) + } + } else { + buf.WriteString(fmt.Sprintf("- % 3d: %v => %v\n", i, n.keys[i], n.values[i])) + } + } + if child := n.children[n.nrSegments]; child != nil { + child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, n.nrSegments)) + } +} + +// SegmentDataSlices represents segments from a set as slices of start, end, and +// values. SegmentDataSlices is primarily used as an intermediate representation +// for save/restore and the layout here is optimized for that. +// +// +stateify savable +type DirtySegmentDataSlices struct { + Start []uint64 + End []uint64 + Values []DirtyInfo +} + +// ExportSortedSlices returns a copy of all segments in the given set, in +// ascending key order. +func (s *DirtySet) ExportSortedSlices() *DirtySegmentDataSlices { + var sds DirtySegmentDataSlices + for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + sds.Start = append(sds.Start, seg.Start()) + sds.End = append(sds.End, seg.End()) + sds.Values = append(sds.Values, seg.Value()) + } + sds.Start = sds.Start[:len(sds.Start):len(sds.Start)] + sds.End = sds.End[:len(sds.End):len(sds.End)] + sds.Values = sds.Values[:len(sds.Values):len(sds.Values)] + return &sds +} + +// ImportSortedSlices initializes the given set from the given slice. +// +// Preconditions: +// * s must be empty. +// * sds must represent a valid set (the segments in sds must have valid +// lengths that do not overlap). +// * The segments in sds must be sorted in ascending key order. +func (s *DirtySet) ImportSortedSlices(sds *DirtySegmentDataSlices) error { + if !s.IsEmpty() { + return fmt.Errorf("cannot import into non-empty set %v", s) + } + gap := s.FirstGap() + for i := range sds.Start { + r := __generics_imported0.MappableRange{sds.Start[i], sds.End[i]} + if !gap.Range().IsSupersetOf(r) { + return fmt.Errorf("segment overlaps a preceding segment or is incorrectly sorted: [%d, %d) => %v", sds.Start[i], sds.End[i], sds.Values[i]) + } + gap = s.InsertWithoutMerging(gap, r, sds.Values[i]).NextGap() + } + return nil +} + +// segmentTestCheck returns an error if s is incorrectly sorted, does not +// contain exactly expectedSegments segments, or contains a segment which +// fails the passed check. +// +// This should be used only for testing, and has been added to this package for +// templating convenience. +func (s *DirtySet) segmentTestCheck(expectedSegments int, segFunc func(int, __generics_imported0.MappableRange, DirtyInfo) error) error { + havePrev := false + prev := uint64(0) + nrSegments := 0 + for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + next := seg.Start() + if havePrev && prev >= next { + return fmt.Errorf("incorrect order: key %d (segment %d) >= key %d (segment %d)", prev, nrSegments-1, next, nrSegments) + } + if segFunc != nil { + if err := segFunc(nrSegments, seg.Range(), seg.Value()); err != nil { + return err + } + } + prev = next + havePrev = true + nrSegments++ + } + if nrSegments != expectedSegments { + return fmt.Errorf("incorrect number of segments: got %d, wanted %d", nrSegments, expectedSegments) + } + return nil +} + +// countSegments counts the number of segments in the set. +// +// Similar to Check, this should only be used for testing. +func (s *DirtySet) countSegments() (segments int) { + for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + segments++ + } + return segments +} +func (s *DirtySet) saveRoot() *DirtySegmentDataSlices { + return s.ExportSortedSlices() +} + +func (s *DirtySet) loadRoot(sds *DirtySegmentDataSlices) { + if err := s.ImportSortedSlices(sds); err != nil { + panic(err) + } +} diff --git a/pkg/sentry/fs/fsutil/dirty_set_test.go b/pkg/sentry/fs/fsutil/dirty_set_test.go deleted file mode 100644 index 48448c97c..000000000 --- a/pkg/sentry/fs/fsutil/dirty_set_test.go +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package fsutil - -import ( - "reflect" - "testing" - - "gvisor.dev/gvisor/pkg/hostarch" - "gvisor.dev/gvisor/pkg/sentry/memmap" -) - -func TestDirtySet(t *testing.T) { - var set DirtySet - set.MarkDirty(memmap.MappableRange{0, 2 * hostarch.PageSize}) - set.KeepDirty(memmap.MappableRange{hostarch.PageSize, 2 * hostarch.PageSize}) - set.MarkClean(memmap.MappableRange{0, 2 * hostarch.PageSize}) - want := &DirtySegmentDataSlices{ - Start: []uint64{hostarch.PageSize}, - End: []uint64{2 * hostarch.PageSize}, - Values: []DirtyInfo{{Keep: true}}, - } - if got := set.ExportSortedSlices(); !reflect.DeepEqual(got, want) { - t.Errorf("set:\n\tgot %v,\n\twant %v", got, want) - } -} diff --git a/pkg/sentry/fs/fsutil/file_range_set_impl.go b/pkg/sentry/fs/fsutil/file_range_set_impl.go new file mode 100644 index 000000000..7568fb790 --- /dev/null +++ b/pkg/sentry/fs/fsutil/file_range_set_impl.go @@ -0,0 +1,1647 @@ +package fsutil + +import ( + __generics_imported0 "gvisor.dev/gvisor/pkg/sentry/memmap" +) + +import ( + "bytes" + "fmt" +) + +// trackGaps is an optional parameter. +// +// If trackGaps is 1, the Set will track maximum gap size recursively, +// enabling the GapIterator.{Prev,Next}LargeEnoughGap functions. In this +// case, Key must be an unsigned integer. +// +// trackGaps must be 0 or 1. +const FileRangetrackGaps = 0 + +var _ = uint8(FileRangetrackGaps << 7) // Will fail if not zero or one. + +// dynamicGap is a type that disappears if trackGaps is 0. +type FileRangedynamicGap [FileRangetrackGaps]uint64 + +// Get returns the value of the gap. +// +// Precondition: trackGaps must be non-zero. +func (d *FileRangedynamicGap) Get() uint64 { + return d[:][0] +} + +// Set sets the value of the gap. +// +// Precondition: trackGaps must be non-zero. +func (d *FileRangedynamicGap) Set(v uint64) { + d[:][0] = v +} + +const ( + // minDegree is the minimum degree of an internal node in a Set B-tree. + // + // - Any non-root node has at least minDegree-1 segments. + // + // - Any non-root internal (non-leaf) node has at least minDegree children. + // + // - The root node may have fewer than minDegree-1 segments, but it may + // only have 0 segments if the tree is empty. + // + // Our implementation requires minDegree >= 3. Higher values of minDegree + // usually improve performance, but increase memory usage for small sets. + FileRangeminDegree = 3 + + FileRangemaxDegree = 2 * FileRangeminDegree +) + +// A Set is a mapping of segments with non-overlapping Range keys. The zero +// value for a Set is an empty set. Set values are not safely movable nor +// copyable. Set is thread-compatible. +// +// +stateify savable +type FileRangeSet struct { + root FileRangenode `state:".(*FileRangeSegmentDataSlices)"` +} + +// IsEmpty returns true if the set contains no segments. +func (s *FileRangeSet) IsEmpty() bool { + return s.root.nrSegments == 0 +} + +// IsEmptyRange returns true iff no segments in the set overlap the given +// range. This is semantically equivalent to s.SpanRange(r) == 0, but may be +// more efficient. +func (s *FileRangeSet) IsEmptyRange(r __generics_imported0.MappableRange) bool { + switch { + case r.Length() < 0: + panic(fmt.Sprintf("invalid range %v", r)) + case r.Length() == 0: + return true + } + _, gap := s.Find(r.Start) + if !gap.Ok() { + return false + } + return r.End <= gap.End() +} + +// Span returns the total size of all segments in the set. +func (s *FileRangeSet) Span() uint64 { + var sz uint64 + for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + sz += seg.Range().Length() + } + return sz +} + +// SpanRange returns the total size of the intersection of segments in the set +// with the given range. +func (s *FileRangeSet) SpanRange(r __generics_imported0.MappableRange) uint64 { + switch { + case r.Length() < 0: + panic(fmt.Sprintf("invalid range %v", r)) + case r.Length() == 0: + return 0 + } + var sz uint64 + for seg := s.LowerBoundSegment(r.Start); seg.Ok() && seg.Start() < r.End; seg = seg.NextSegment() { + sz += seg.Range().Intersect(r).Length() + } + return sz +} + +// FirstSegment returns the first segment in the set. If the set is empty, +// FirstSegment returns a terminal iterator. +func (s *FileRangeSet) FirstSegment() FileRangeIterator { + if s.root.nrSegments == 0 { + return FileRangeIterator{} + } + return s.root.firstSegment() +} + +// LastSegment returns the last segment in the set. If the set is empty, +// LastSegment returns a terminal iterator. +func (s *FileRangeSet) LastSegment() FileRangeIterator { + if s.root.nrSegments == 0 { + return FileRangeIterator{} + } + return s.root.lastSegment() +} + +// FirstGap returns the first gap in the set. +func (s *FileRangeSet) FirstGap() FileRangeGapIterator { + n := &s.root + for n.hasChildren { + n = n.children[0] + } + return FileRangeGapIterator{n, 0} +} + +// LastGap returns the last gap in the set. +func (s *FileRangeSet) LastGap() FileRangeGapIterator { + n := &s.root + for n.hasChildren { + n = n.children[n.nrSegments] + } + return FileRangeGapIterator{n, n.nrSegments} +} + +// Find returns the segment or gap whose range contains the given key. If a +// segment is found, the returned Iterator is non-terminal and the +// returned GapIterator is terminal. Otherwise, the returned Iterator is +// terminal and the returned GapIterator is non-terminal. +func (s *FileRangeSet) Find(key uint64) (FileRangeIterator, FileRangeGapIterator) { + n := &s.root + for { + + lower := 0 + upper := n.nrSegments + for lower < upper { + i := lower + (upper-lower)/2 + if r := n.keys[i]; key < r.End { + if key >= r.Start { + return FileRangeIterator{n, i}, FileRangeGapIterator{} + } + upper = i + } else { + lower = i + 1 + } + } + i := lower + if !n.hasChildren { + return FileRangeIterator{}, FileRangeGapIterator{n, i} + } + n = n.children[i] + } +} + +// FindSegment returns the segment whose range contains the given key. If no +// such segment exists, FindSegment returns a terminal iterator. +func (s *FileRangeSet) FindSegment(key uint64) FileRangeIterator { + seg, _ := s.Find(key) + return seg +} + +// LowerBoundSegment returns the segment with the lowest range that contains a +// key greater than or equal to min. If no such segment exists, +// LowerBoundSegment returns a terminal iterator. +func (s *FileRangeSet) LowerBoundSegment(min uint64) FileRangeIterator { + seg, gap := s.Find(min) + if seg.Ok() { + return seg + } + return gap.NextSegment() +} + +// UpperBoundSegment returns the segment with the highest range that contains a +// key less than or equal to max. If no such segment exists, UpperBoundSegment +// returns a terminal iterator. +func (s *FileRangeSet) UpperBoundSegment(max uint64) FileRangeIterator { + seg, gap := s.Find(max) + if seg.Ok() { + return seg + } + return gap.PrevSegment() +} + +// FindGap returns the gap containing the given key. If no such gap exists +// (i.e. the set contains a segment containing that key), FindGap returns a +// terminal iterator. +func (s *FileRangeSet) FindGap(key uint64) FileRangeGapIterator { + _, gap := s.Find(key) + return gap +} + +// LowerBoundGap returns the gap with the lowest range that is greater than or +// equal to min. +func (s *FileRangeSet) LowerBoundGap(min uint64) FileRangeGapIterator { + seg, gap := s.Find(min) + if gap.Ok() { + return gap + } + return seg.NextGap() +} + +// UpperBoundGap returns the gap with the highest range that is less than or +// equal to max. +func (s *FileRangeSet) UpperBoundGap(max uint64) FileRangeGapIterator { + seg, gap := s.Find(max) + if gap.Ok() { + return gap + } + return seg.PrevGap() +} + +// Add inserts the given segment into the set and returns true. If the new +// segment can be merged with adjacent segments, Add will do so. If the new +// segment would overlap an existing segment, Add returns false. If Add +// succeeds, all existing iterators are invalidated. +func (s *FileRangeSet) Add(r __generics_imported0.MappableRange, val uint64) bool { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + gap := s.FindGap(r.Start) + if !gap.Ok() { + return false + } + if r.End > gap.End() { + return false + } + s.Insert(gap, r, val) + return true +} + +// AddWithoutMerging inserts the given segment into the set and returns true. +// If it would overlap an existing segment, AddWithoutMerging does nothing and +// returns false. If AddWithoutMerging succeeds, all existing iterators are +// invalidated. +func (s *FileRangeSet) AddWithoutMerging(r __generics_imported0.MappableRange, val uint64) bool { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + gap := s.FindGap(r.Start) + if !gap.Ok() { + return false + } + if r.End > gap.End() { + return false + } + s.InsertWithoutMergingUnchecked(gap, r, val) + return true +} + +// Insert inserts the given segment into the given gap. If the new segment can +// be merged with adjacent segments, Insert will do so. Insert returns an +// iterator to the segment containing the inserted value (which may have been +// merged with other values). All existing iterators (including gap, but not +// including the returned iterator) are invalidated. +// +// If the gap cannot accommodate the segment, or if r is invalid, Insert panics. +// +// Insert is semantically equivalent to a InsertWithoutMerging followed by a +// Merge, but may be more efficient. Note that there is no unchecked variant of +// Insert since Insert must retrieve and inspect gap's predecessor and +// successor segments regardless. +func (s *FileRangeSet) Insert(gap FileRangeGapIterator, r __generics_imported0.MappableRange, val uint64) FileRangeIterator { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + prev, next := gap.PrevSegment(), gap.NextSegment() + if prev.Ok() && prev.End() > r.Start { + panic(fmt.Sprintf("new segment %v overlaps predecessor %v", r, prev.Range())) + } + if next.Ok() && next.Start() < r.End { + panic(fmt.Sprintf("new segment %v overlaps successor %v", r, next.Range())) + } + if prev.Ok() && prev.End() == r.Start { + if mval, ok := (FileRangeSetFunctions{}).Merge(prev.Range(), prev.Value(), r, val); ok { + shrinkMaxGap := FileRangetrackGaps != 0 && gap.Range().Length() == gap.node.maxGap.Get() + prev.SetEndUnchecked(r.End) + prev.SetValue(mval) + if shrinkMaxGap { + gap.node.updateMaxGapLeaf() + } + if next.Ok() && next.Start() == r.End { + val = mval + if mval, ok := (FileRangeSetFunctions{}).Merge(prev.Range(), val, next.Range(), next.Value()); ok { + prev.SetEndUnchecked(next.End()) + prev.SetValue(mval) + return s.Remove(next).PrevSegment() + } + } + return prev + } + } + if next.Ok() && next.Start() == r.End { + if mval, ok := (FileRangeSetFunctions{}).Merge(r, val, next.Range(), next.Value()); ok { + shrinkMaxGap := FileRangetrackGaps != 0 && gap.Range().Length() == gap.node.maxGap.Get() + next.SetStartUnchecked(r.Start) + next.SetValue(mval) + if shrinkMaxGap { + gap.node.updateMaxGapLeaf() + } + return next + } + } + + return s.InsertWithoutMergingUnchecked(gap, r, val) +} + +// InsertWithoutMerging inserts the given segment into the given gap and +// returns an iterator to the inserted segment. All existing iterators +// (including gap, but not including the returned iterator) are invalidated. +// +// If the gap cannot accommodate the segment, or if r is invalid, +// InsertWithoutMerging panics. +func (s *FileRangeSet) InsertWithoutMerging(gap FileRangeGapIterator, r __generics_imported0.MappableRange, val uint64) FileRangeIterator { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + if gr := gap.Range(); !gr.IsSupersetOf(r) { + panic(fmt.Sprintf("cannot insert segment range %v into gap range %v", r, gr)) + } + return s.InsertWithoutMergingUnchecked(gap, r, val) +} + +// InsertWithoutMergingUnchecked inserts the given segment into the given gap +// and returns an iterator to the inserted segment. All existing iterators +// (including gap, but not including the returned iterator) are invalidated. +// +// Preconditions: +// * r.Start >= gap.Start(). +// * r.End <= gap.End(). +func (s *FileRangeSet) InsertWithoutMergingUnchecked(gap FileRangeGapIterator, r __generics_imported0.MappableRange, val uint64) FileRangeIterator { + gap = gap.node.rebalanceBeforeInsert(gap) + splitMaxGap := FileRangetrackGaps != 0 && (gap.node.nrSegments == 0 || gap.Range().Length() == gap.node.maxGap.Get()) + copy(gap.node.keys[gap.index+1:], gap.node.keys[gap.index:gap.node.nrSegments]) + copy(gap.node.values[gap.index+1:], gap.node.values[gap.index:gap.node.nrSegments]) + gap.node.keys[gap.index] = r + gap.node.values[gap.index] = val + gap.node.nrSegments++ + if splitMaxGap { + gap.node.updateMaxGapLeaf() + } + return FileRangeIterator{gap.node, gap.index} +} + +// Remove removes the given segment and returns an iterator to the vacated gap. +// All existing iterators (including seg, but not including the returned +// iterator) are invalidated. +func (s *FileRangeSet) Remove(seg FileRangeIterator) FileRangeGapIterator { + + if seg.node.hasChildren { + + victim := seg.PrevSegment() + + seg.SetRangeUnchecked(victim.Range()) + seg.SetValue(victim.Value()) + + nextAdjacentNode := seg.NextSegment().node + if FileRangetrackGaps != 0 { + nextAdjacentNode.updateMaxGapLeaf() + } + return s.Remove(victim).NextGap() + } + copy(seg.node.keys[seg.index:], seg.node.keys[seg.index+1:seg.node.nrSegments]) + copy(seg.node.values[seg.index:], seg.node.values[seg.index+1:seg.node.nrSegments]) + FileRangeSetFunctions{}.ClearValue(&seg.node.values[seg.node.nrSegments-1]) + seg.node.nrSegments-- + if FileRangetrackGaps != 0 { + seg.node.updateMaxGapLeaf() + } + return seg.node.rebalanceAfterRemove(FileRangeGapIterator{seg.node, seg.index}) +} + +// RemoveAll removes all segments from the set. All existing iterators are +// invalidated. +func (s *FileRangeSet) RemoveAll() { + s.root = FileRangenode{} +} + +// RemoveRange removes all segments in the given range. An iterator to the +// newly formed gap is returned, and all existing iterators are invalidated. +func (s *FileRangeSet) RemoveRange(r __generics_imported0.MappableRange) FileRangeGapIterator { + seg, gap := s.Find(r.Start) + if seg.Ok() { + seg = s.Isolate(seg, r) + gap = s.Remove(seg) + } + for seg = gap.NextSegment(); seg.Ok() && seg.Start() < r.End; seg = gap.NextSegment() { + seg = s.Isolate(seg, r) + gap = s.Remove(seg) + } + return gap +} + +// Merge attempts to merge two neighboring segments. If successful, Merge +// returns an iterator to the merged segment, and all existing iterators are +// invalidated. Otherwise, Merge returns a terminal iterator. +// +// If first is not the predecessor of second, Merge panics. +func (s *FileRangeSet) Merge(first, second FileRangeIterator) FileRangeIterator { + if first.NextSegment() != second { + panic(fmt.Sprintf("attempt to merge non-neighboring segments %v, %v", first.Range(), second.Range())) + } + return s.MergeUnchecked(first, second) +} + +// MergeUnchecked attempts to merge two neighboring segments. If successful, +// MergeUnchecked returns an iterator to the merged segment, and all existing +// iterators are invalidated. Otherwise, MergeUnchecked returns a terminal +// iterator. +// +// Precondition: first is the predecessor of second: first.NextSegment() == +// second, first == second.PrevSegment(). +func (s *FileRangeSet) MergeUnchecked(first, second FileRangeIterator) FileRangeIterator { + if first.End() == second.Start() { + if mval, ok := (FileRangeSetFunctions{}).Merge(first.Range(), first.Value(), second.Range(), second.Value()); ok { + + first.SetEndUnchecked(second.End()) + first.SetValue(mval) + + return s.Remove(second).PrevSegment() + } + } + return FileRangeIterator{} +} + +// MergeAll attempts to merge all adjacent segments in the set. All existing +// iterators are invalidated. +func (s *FileRangeSet) MergeAll() { + seg := s.FirstSegment() + if !seg.Ok() { + return + } + next := seg.NextSegment() + for next.Ok() { + if mseg := s.MergeUnchecked(seg, next); mseg.Ok() { + seg, next = mseg, mseg.NextSegment() + } else { + seg, next = next, next.NextSegment() + } + } +} + +// MergeRange attempts to merge all adjacent segments that contain a key in the +// specific range. All existing iterators are invalidated. +func (s *FileRangeSet) MergeRange(r __generics_imported0.MappableRange) { + seg := s.LowerBoundSegment(r.Start) + if !seg.Ok() { + return + } + next := seg.NextSegment() + for next.Ok() && next.Range().Start < r.End { + if mseg := s.MergeUnchecked(seg, next); mseg.Ok() { + seg, next = mseg, mseg.NextSegment() + } else { + seg, next = next, next.NextSegment() + } + } +} + +// MergeAdjacent attempts to merge the segment containing r.Start with its +// predecessor, and the segment containing r.End-1 with its successor. +func (s *FileRangeSet) MergeAdjacent(r __generics_imported0.MappableRange) { + first := s.FindSegment(r.Start) + if first.Ok() { + if prev := first.PrevSegment(); prev.Ok() { + s.Merge(prev, first) + } + } + last := s.FindSegment(r.End - 1) + if last.Ok() { + if next := last.NextSegment(); next.Ok() { + s.Merge(last, next) + } + } +} + +// Split splits the given segment at the given key and returns iterators to the +// two resulting segments. All existing iterators (including seg, but not +// including the returned iterators) are invalidated. +// +// If the segment cannot be split at split (because split is at the start or +// end of the segment's range, so splitting would produce a segment with zero +// length, or because split falls outside the segment's range altogether), +// Split panics. +func (s *FileRangeSet) Split(seg FileRangeIterator, split uint64) (FileRangeIterator, FileRangeIterator) { + if !seg.Range().CanSplitAt(split) { + panic(fmt.Sprintf("can't split %v at %v", seg.Range(), split)) + } + return s.SplitUnchecked(seg, split) +} + +// SplitUnchecked splits the given segment at the given key and returns +// iterators to the two resulting segments. All existing iterators (including +// seg, but not including the returned iterators) are invalidated. +// +// Preconditions: seg.Start() < key < seg.End(). +func (s *FileRangeSet) SplitUnchecked(seg FileRangeIterator, split uint64) (FileRangeIterator, FileRangeIterator) { + val1, val2 := (FileRangeSetFunctions{}).Split(seg.Range(), seg.Value(), split) + end2 := seg.End() + seg.SetEndUnchecked(split) + seg.SetValue(val1) + seg2 := s.InsertWithoutMergingUnchecked(seg.NextGap(), __generics_imported0.MappableRange{split, end2}, val2) + + return seg2.PrevSegment(), seg2 +} + +// SplitAt splits the segment straddling split, if one exists. SplitAt returns +// true if a segment was split and false otherwise. If SplitAt splits a +// segment, all existing iterators are invalidated. +func (s *FileRangeSet) SplitAt(split uint64) bool { + if seg := s.FindSegment(split); seg.Ok() && seg.Range().CanSplitAt(split) { + s.SplitUnchecked(seg, split) + return true + } + return false +} + +// Isolate ensures that the given segment's range does not escape r by +// splitting at r.Start and r.End if necessary, and returns an updated iterator +// to the bounded segment. All existing iterators (including seg, but not +// including the returned iterators) are invalidated. +func (s *FileRangeSet) Isolate(seg FileRangeIterator, r __generics_imported0.MappableRange) FileRangeIterator { + if seg.Range().CanSplitAt(r.Start) { + _, seg = s.SplitUnchecked(seg, r.Start) + } + if seg.Range().CanSplitAt(r.End) { + seg, _ = s.SplitUnchecked(seg, r.End) + } + return seg +} + +// ApplyContiguous applies a function to a contiguous range of segments, +// splitting if necessary. The function is applied until the first gap is +// encountered, at which point the gap is returned. If the function is applied +// across the entire range, a terminal gap is returned. All existing iterators +// are invalidated. +// +// N.B. The Iterator must not be invalidated by the function. +func (s *FileRangeSet) ApplyContiguous(r __generics_imported0.MappableRange, fn func(seg FileRangeIterator)) FileRangeGapIterator { + seg, gap := s.Find(r.Start) + if !seg.Ok() { + return gap + } + for { + seg = s.Isolate(seg, r) + fn(seg) + if seg.End() >= r.End { + return FileRangeGapIterator{} + } + gap = seg.NextGap() + if !gap.IsEmpty() { + return gap + } + seg = gap.NextSegment() + if !seg.Ok() { + + return FileRangeGapIterator{} + } + } +} + +// +stateify savable +type FileRangenode struct { + // An internal binary tree node looks like: + // + // K + // / \ + // Cl Cr + // + // where all keys in the subtree rooted by Cl (the left subtree) are less + // than K (the key of the parent node), and all keys in the subtree rooted + // by Cr (the right subtree) are greater than K. + // + // An internal B-tree node's indexes work out to look like: + // + // K0 K1 K2 ... Kn-1 + // / \/ \/ \ ... / \ + // C0 C1 C2 C3 ... Cn-1 Cn + // + // where n is nrSegments. + nrSegments int + + // parent is a pointer to this node's parent. If this node is root, parent + // is nil. + parent *FileRangenode + + // parentIndex is the index of this node in parent.children. + parentIndex int + + // Flag for internal nodes that is technically redundant with "children[0] + // != nil", but is stored in the first cache line. "hasChildren" rather + // than "isLeaf" because false must be the correct value for an empty root. + hasChildren bool + + // The longest gap within this node. If the node is a leaf, it's simply the + // maximum gap among all the (nrSegments+1) gaps formed by its nrSegments keys + // including the 0th and nrSegments-th gap possibly shared with its upper-level + // nodes; if it's a non-leaf node, it's the max of all children's maxGap. + maxGap FileRangedynamicGap + + // Nodes store keys and values in separate arrays to maximize locality in + // the common case (scanning keys for lookup). + keys [FileRangemaxDegree - 1]__generics_imported0.MappableRange + values [FileRangemaxDegree - 1]uint64 + children [FileRangemaxDegree]*FileRangenode +} + +// firstSegment returns the first segment in the subtree rooted by n. +// +// Preconditions: n.nrSegments != 0. +func (n *FileRangenode) firstSegment() FileRangeIterator { + for n.hasChildren { + n = n.children[0] + } + return FileRangeIterator{n, 0} +} + +// lastSegment returns the last segment in the subtree rooted by n. +// +// Preconditions: n.nrSegments != 0. +func (n *FileRangenode) lastSegment() FileRangeIterator { + for n.hasChildren { + n = n.children[n.nrSegments] + } + return FileRangeIterator{n, n.nrSegments - 1} +} + +func (n *FileRangenode) prevSibling() *FileRangenode { + if n.parent == nil || n.parentIndex == 0 { + return nil + } + return n.parent.children[n.parentIndex-1] +} + +func (n *FileRangenode) nextSibling() *FileRangenode { + if n.parent == nil || n.parentIndex == n.parent.nrSegments { + return nil + } + return n.parent.children[n.parentIndex+1] +} + +// rebalanceBeforeInsert splits n and its ancestors if they are full, as +// required for insertion, and returns an updated iterator to the position +// represented by gap. +func (n *FileRangenode) rebalanceBeforeInsert(gap FileRangeGapIterator) FileRangeGapIterator { + if n.nrSegments < FileRangemaxDegree-1 { + return gap + } + if n.parent != nil { + gap = n.parent.rebalanceBeforeInsert(gap) + } + if n.parent == nil { + + left := &FileRangenode{ + nrSegments: FileRangeminDegree - 1, + parent: n, + parentIndex: 0, + hasChildren: n.hasChildren, + } + right := &FileRangenode{ + nrSegments: FileRangeminDegree - 1, + parent: n, + parentIndex: 1, + hasChildren: n.hasChildren, + } + copy(left.keys[:FileRangeminDegree-1], n.keys[:FileRangeminDegree-1]) + copy(left.values[:FileRangeminDegree-1], n.values[:FileRangeminDegree-1]) + copy(right.keys[:FileRangeminDegree-1], n.keys[FileRangeminDegree:]) + copy(right.values[:FileRangeminDegree-1], n.values[FileRangeminDegree:]) + n.keys[0], n.values[0] = n.keys[FileRangeminDegree-1], n.values[FileRangeminDegree-1] + FileRangezeroValueSlice(n.values[1:]) + if n.hasChildren { + copy(left.children[:FileRangeminDegree], n.children[:FileRangeminDegree]) + copy(right.children[:FileRangeminDegree], n.children[FileRangeminDegree:]) + FileRangezeroNodeSlice(n.children[2:]) + for i := 0; i < FileRangeminDegree; i++ { + left.children[i].parent = left + left.children[i].parentIndex = i + right.children[i].parent = right + right.children[i].parentIndex = i + } + } + n.nrSegments = 1 + n.hasChildren = true + n.children[0] = left + n.children[1] = right + + if FileRangetrackGaps != 0 { + left.updateMaxGapLocal() + right.updateMaxGapLocal() + } + if gap.node != n { + return gap + } + if gap.index < FileRangeminDegree { + return FileRangeGapIterator{left, gap.index} + } + return FileRangeGapIterator{right, gap.index - FileRangeminDegree} + } + + copy(n.parent.keys[n.parentIndex+1:], n.parent.keys[n.parentIndex:n.parent.nrSegments]) + copy(n.parent.values[n.parentIndex+1:], n.parent.values[n.parentIndex:n.parent.nrSegments]) + n.parent.keys[n.parentIndex], n.parent.values[n.parentIndex] = n.keys[FileRangeminDegree-1], n.values[FileRangeminDegree-1] + copy(n.parent.children[n.parentIndex+2:], n.parent.children[n.parentIndex+1:n.parent.nrSegments+1]) + for i := n.parentIndex + 2; i < n.parent.nrSegments+2; i++ { + n.parent.children[i].parentIndex = i + } + sibling := &FileRangenode{ + nrSegments: FileRangeminDegree - 1, + parent: n.parent, + parentIndex: n.parentIndex + 1, + hasChildren: n.hasChildren, + } + n.parent.children[n.parentIndex+1] = sibling + n.parent.nrSegments++ + copy(sibling.keys[:FileRangeminDegree-1], n.keys[FileRangeminDegree:]) + copy(sibling.values[:FileRangeminDegree-1], n.values[FileRangeminDegree:]) + FileRangezeroValueSlice(n.values[FileRangeminDegree-1:]) + if n.hasChildren { + copy(sibling.children[:FileRangeminDegree], n.children[FileRangeminDegree:]) + FileRangezeroNodeSlice(n.children[FileRangeminDegree:]) + for i := 0; i < FileRangeminDegree; i++ { + sibling.children[i].parent = sibling + sibling.children[i].parentIndex = i + } + } + n.nrSegments = FileRangeminDegree - 1 + + if FileRangetrackGaps != 0 { + n.updateMaxGapLocal() + sibling.updateMaxGapLocal() + } + + if gap.node != n { + return gap + } + if gap.index < FileRangeminDegree { + return gap + } + return FileRangeGapIterator{sibling, gap.index - FileRangeminDegree} +} + +// rebalanceAfterRemove "unsplits" n and its ancestors if they are deficient +// (contain fewer segments than required by B-tree invariants), as required for +// removal, and returns an updated iterator to the position represented by gap. +// +// Precondition: n is the only node in the tree that may currently violate a +// B-tree invariant. +func (n *FileRangenode) rebalanceAfterRemove(gap FileRangeGapIterator) FileRangeGapIterator { + for { + if n.nrSegments >= FileRangeminDegree-1 { + return gap + } + if n.parent == nil { + + return gap + } + + if sibling := n.prevSibling(); sibling != nil && sibling.nrSegments >= FileRangeminDegree { + copy(n.keys[1:], n.keys[:n.nrSegments]) + copy(n.values[1:], n.values[:n.nrSegments]) + n.keys[0] = n.parent.keys[n.parentIndex-1] + n.values[0] = n.parent.values[n.parentIndex-1] + n.parent.keys[n.parentIndex-1] = sibling.keys[sibling.nrSegments-1] + n.parent.values[n.parentIndex-1] = sibling.values[sibling.nrSegments-1] + FileRangeSetFunctions{}.ClearValue(&sibling.values[sibling.nrSegments-1]) + if n.hasChildren { + copy(n.children[1:], n.children[:n.nrSegments+1]) + n.children[0] = sibling.children[sibling.nrSegments] + sibling.children[sibling.nrSegments] = nil + n.children[0].parent = n + n.children[0].parentIndex = 0 + for i := 1; i < n.nrSegments+2; i++ { + n.children[i].parentIndex = i + } + } + n.nrSegments++ + sibling.nrSegments-- + + if FileRangetrackGaps != 0 { + n.updateMaxGapLocal() + sibling.updateMaxGapLocal() + } + if gap.node == sibling && gap.index == sibling.nrSegments { + return FileRangeGapIterator{n, 0} + } + if gap.node == n { + return FileRangeGapIterator{n, gap.index + 1} + } + return gap + } + if sibling := n.nextSibling(); sibling != nil && sibling.nrSegments >= FileRangeminDegree { + n.keys[n.nrSegments] = n.parent.keys[n.parentIndex] + n.values[n.nrSegments] = n.parent.values[n.parentIndex] + n.parent.keys[n.parentIndex] = sibling.keys[0] + n.parent.values[n.parentIndex] = sibling.values[0] + copy(sibling.keys[:sibling.nrSegments-1], sibling.keys[1:]) + copy(sibling.values[:sibling.nrSegments-1], sibling.values[1:]) + FileRangeSetFunctions{}.ClearValue(&sibling.values[sibling.nrSegments-1]) + if n.hasChildren { + n.children[n.nrSegments+1] = sibling.children[0] + copy(sibling.children[:sibling.nrSegments], sibling.children[1:]) + sibling.children[sibling.nrSegments] = nil + n.children[n.nrSegments+1].parent = n + n.children[n.nrSegments+1].parentIndex = n.nrSegments + 1 + for i := 0; i < sibling.nrSegments; i++ { + sibling.children[i].parentIndex = i + } + } + n.nrSegments++ + sibling.nrSegments-- + + if FileRangetrackGaps != 0 { + n.updateMaxGapLocal() + sibling.updateMaxGapLocal() + } + if gap.node == sibling { + if gap.index == 0 { + return FileRangeGapIterator{n, n.nrSegments} + } + return FileRangeGapIterator{sibling, gap.index - 1} + } + return gap + } + + p := n.parent + if p.nrSegments == 1 { + + left, right := p.children[0], p.children[1] + p.nrSegments = left.nrSegments + right.nrSegments + 1 + p.hasChildren = left.hasChildren + p.keys[left.nrSegments] = p.keys[0] + p.values[left.nrSegments] = p.values[0] + copy(p.keys[:left.nrSegments], left.keys[:left.nrSegments]) + copy(p.values[:left.nrSegments], left.values[:left.nrSegments]) + copy(p.keys[left.nrSegments+1:], right.keys[:right.nrSegments]) + copy(p.values[left.nrSegments+1:], right.values[:right.nrSegments]) + if left.hasChildren { + copy(p.children[:left.nrSegments+1], left.children[:left.nrSegments+1]) + copy(p.children[left.nrSegments+1:], right.children[:right.nrSegments+1]) + for i := 0; i < p.nrSegments+1; i++ { + p.children[i].parent = p + p.children[i].parentIndex = i + } + } else { + p.children[0] = nil + p.children[1] = nil + } + + if gap.node == left { + return FileRangeGapIterator{p, gap.index} + } + if gap.node == right { + return FileRangeGapIterator{p, gap.index + left.nrSegments + 1} + } + return gap + } + // Merge n and either sibling, along with the segment separating the + // two, into whichever of the two nodes comes first. This is the + // reverse of the non-root splitting case in + // node.rebalanceBeforeInsert. + var left, right *FileRangenode + if n.parentIndex > 0 { + left = n.prevSibling() + right = n + } else { + left = n + right = n.nextSibling() + } + + if gap.node == right { + gap = FileRangeGapIterator{left, gap.index + left.nrSegments + 1} + } + left.keys[left.nrSegments] = p.keys[left.parentIndex] + left.values[left.nrSegments] = p.values[left.parentIndex] + copy(left.keys[left.nrSegments+1:], right.keys[:right.nrSegments]) + copy(left.values[left.nrSegments+1:], right.values[:right.nrSegments]) + if left.hasChildren { + copy(left.children[left.nrSegments+1:], right.children[:right.nrSegments+1]) + for i := left.nrSegments + 1; i < left.nrSegments+right.nrSegments+2; i++ { + left.children[i].parent = left + left.children[i].parentIndex = i + } + } + left.nrSegments += right.nrSegments + 1 + copy(p.keys[left.parentIndex:], p.keys[left.parentIndex+1:p.nrSegments]) + copy(p.values[left.parentIndex:], p.values[left.parentIndex+1:p.nrSegments]) + FileRangeSetFunctions{}.ClearValue(&p.values[p.nrSegments-1]) + copy(p.children[left.parentIndex+1:], p.children[left.parentIndex+2:p.nrSegments+1]) + for i := 0; i < p.nrSegments; i++ { + p.children[i].parentIndex = i + } + p.children[p.nrSegments] = nil + p.nrSegments-- + + if FileRangetrackGaps != 0 { + left.updateMaxGapLocal() + } + + n = p + } +} + +// updateMaxGapLeaf updates maxGap bottom-up from the calling leaf until no +// necessary update. +// +// Preconditions: n must be a leaf node, trackGaps must be 1. +func (n *FileRangenode) updateMaxGapLeaf() { + if n.hasChildren { + panic(fmt.Sprintf("updateMaxGapLeaf should always be called on leaf node: %v", n)) + } + max := n.calculateMaxGapLeaf() + if max == n.maxGap.Get() { + + return + } + oldMax := n.maxGap.Get() + n.maxGap.Set(max) + if max > oldMax { + + for p := n.parent; p != nil; p = p.parent { + if p.maxGap.Get() >= max { + + break + } + + p.maxGap.Set(max) + } + return + } + + for p := n.parent; p != nil; p = p.parent { + if p.maxGap.Get() > oldMax { + + break + } + + parentNewMax := p.calculateMaxGapInternal() + if p.maxGap.Get() == parentNewMax { + + break + } + + p.maxGap.Set(parentNewMax) + } +} + +// updateMaxGapLocal updates maxGap of the calling node solely with no +// propagation to ancestor nodes. +// +// Precondition: trackGaps must be 1. +func (n *FileRangenode) updateMaxGapLocal() { + if !n.hasChildren { + + n.maxGap.Set(n.calculateMaxGapLeaf()) + } else { + + n.maxGap.Set(n.calculateMaxGapInternal()) + } +} + +// calculateMaxGapLeaf iterates the gaps within a leaf node and calculate the +// max. +// +// Preconditions: n must be a leaf node. +func (n *FileRangenode) calculateMaxGapLeaf() uint64 { + max := FileRangeGapIterator{n, 0}.Range().Length() + for i := 1; i <= n.nrSegments; i++ { + if current := (FileRangeGapIterator{n, i}).Range().Length(); current > max { + max = current + } + } + return max +} + +// calculateMaxGapInternal iterates children's maxGap within an internal node n +// and calculate the max. +// +// Preconditions: n must be a non-leaf node. +func (n *FileRangenode) calculateMaxGapInternal() uint64 { + max := n.children[0].maxGap.Get() + for i := 1; i <= n.nrSegments; i++ { + if current := n.children[i].maxGap.Get(); current > max { + max = current + } + } + return max +} + +// searchFirstLargeEnoughGap returns the first gap having at least minSize length +// in the subtree rooted by n. If not found, return a terminal gap iterator. +func (n *FileRangenode) searchFirstLargeEnoughGap(minSize uint64) FileRangeGapIterator { + if n.maxGap.Get() < minSize { + return FileRangeGapIterator{} + } + if n.hasChildren { + for i := 0; i <= n.nrSegments; i++ { + if largeEnoughGap := n.children[i].searchFirstLargeEnoughGap(minSize); largeEnoughGap.Ok() { + return largeEnoughGap + } + } + } else { + for i := 0; i <= n.nrSegments; i++ { + currentGap := FileRangeGapIterator{n, i} + if currentGap.Range().Length() >= minSize { + return currentGap + } + } + } + panic(fmt.Sprintf("invalid maxGap in %v", n)) +} + +// searchLastLargeEnoughGap returns the last gap having at least minSize length +// in the subtree rooted by n. If not found, return a terminal gap iterator. +func (n *FileRangenode) searchLastLargeEnoughGap(minSize uint64) FileRangeGapIterator { + if n.maxGap.Get() < minSize { + return FileRangeGapIterator{} + } + if n.hasChildren { + for i := n.nrSegments; i >= 0; i-- { + if largeEnoughGap := n.children[i].searchLastLargeEnoughGap(minSize); largeEnoughGap.Ok() { + return largeEnoughGap + } + } + } else { + for i := n.nrSegments; i >= 0; i-- { + currentGap := FileRangeGapIterator{n, i} + if currentGap.Range().Length() >= minSize { + return currentGap + } + } + } + panic(fmt.Sprintf("invalid maxGap in %v", n)) +} + +// A Iterator is conceptually one of: +// +// - A pointer to a segment in a set; or +// +// - A terminal iterator, which is a sentinel indicating that the end of +// iteration has been reached. +// +// Iterators are copyable values and are meaningfully equality-comparable. The +// zero value of Iterator is a terminal iterator. +// +// Unless otherwise specified, any mutation of a set invalidates all existing +// iterators into the set. +type FileRangeIterator struct { + // node is the node containing the iterated segment. If the iterator is + // terminal, node is nil. + node *FileRangenode + + // index is the index of the segment in node.keys/values. + index int +} + +// Ok returns true if the iterator is not terminal. All other methods are only +// valid for non-terminal iterators. +func (seg FileRangeIterator) Ok() bool { + return seg.node != nil +} + +// Range returns the iterated segment's range key. +func (seg FileRangeIterator) Range() __generics_imported0.MappableRange { + return seg.node.keys[seg.index] +} + +// Start is equivalent to Range().Start, but should be preferred if only the +// start of the range is needed. +func (seg FileRangeIterator) Start() uint64 { + return seg.node.keys[seg.index].Start +} + +// End is equivalent to Range().End, but should be preferred if only the end of +// the range is needed. +func (seg FileRangeIterator) End() uint64 { + return seg.node.keys[seg.index].End +} + +// SetRangeUnchecked mutates the iterated segment's range key. This operation +// does not invalidate any iterators. +// +// Preconditions: +// * r.Length() > 0. +// * The new range must not overlap an existing one: +// * If seg.NextSegment().Ok(), then r.end <= seg.NextSegment().Start(). +// * If seg.PrevSegment().Ok(), then r.start >= seg.PrevSegment().End(). +func (seg FileRangeIterator) SetRangeUnchecked(r __generics_imported0.MappableRange) { + seg.node.keys[seg.index] = r +} + +// SetRange mutates the iterated segment's range key. If the new range would +// cause the iterated segment to overlap another segment, or if the new range +// is invalid, SetRange panics. This operation does not invalidate any +// iterators. +func (seg FileRangeIterator) SetRange(r __generics_imported0.MappableRange) { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + if prev := seg.PrevSegment(); prev.Ok() && r.Start < prev.End() { + panic(fmt.Sprintf("new segment range %v overlaps segment range %v", r, prev.Range())) + } + if next := seg.NextSegment(); next.Ok() && r.End > next.Start() { + panic(fmt.Sprintf("new segment range %v overlaps segment range %v", r, next.Range())) + } + seg.SetRangeUnchecked(r) +} + +// SetStartUnchecked mutates the iterated segment's start. This operation does +// not invalidate any iterators. +// +// Preconditions: The new start must be valid: +// * start < seg.End() +// * If seg.PrevSegment().Ok(), then start >= seg.PrevSegment().End(). +func (seg FileRangeIterator) SetStartUnchecked(start uint64) { + seg.node.keys[seg.index].Start = start +} + +// SetStart mutates the iterated segment's start. If the new start value would +// cause the iterated segment to overlap another segment, or would result in an +// invalid range, SetStart panics. This operation does not invalidate any +// iterators. +func (seg FileRangeIterator) SetStart(start uint64) { + if start >= seg.End() { + panic(fmt.Sprintf("new start %v would invalidate segment range %v", start, seg.Range())) + } + if prev := seg.PrevSegment(); prev.Ok() && start < prev.End() { + panic(fmt.Sprintf("new start %v would cause segment range %v to overlap segment range %v", start, seg.Range(), prev.Range())) + } + seg.SetStartUnchecked(start) +} + +// SetEndUnchecked mutates the iterated segment's end. This operation does not +// invalidate any iterators. +// +// Preconditions: The new end must be valid: +// * end > seg.Start(). +// * If seg.NextSegment().Ok(), then end <= seg.NextSegment().Start(). +func (seg FileRangeIterator) SetEndUnchecked(end uint64) { + seg.node.keys[seg.index].End = end +} + +// SetEnd mutates the iterated segment's end. If the new end value would cause +// the iterated segment to overlap another segment, or would result in an +// invalid range, SetEnd panics. This operation does not invalidate any +// iterators. +func (seg FileRangeIterator) SetEnd(end uint64) { + if end <= seg.Start() { + panic(fmt.Sprintf("new end %v would invalidate segment range %v", end, seg.Range())) + } + if next := seg.NextSegment(); next.Ok() && end > next.Start() { + panic(fmt.Sprintf("new end %v would cause segment range %v to overlap segment range %v", end, seg.Range(), next.Range())) + } + seg.SetEndUnchecked(end) +} + +// Value returns a copy of the iterated segment's value. +func (seg FileRangeIterator) Value() uint64 { + return seg.node.values[seg.index] +} + +// ValuePtr returns a pointer to the iterated segment's value. The pointer is +// invalidated if the iterator is invalidated. This operation does not +// invalidate any iterators. +func (seg FileRangeIterator) ValuePtr() *uint64 { + return &seg.node.values[seg.index] +} + +// SetValue mutates the iterated segment's value. This operation does not +// invalidate any iterators. +func (seg FileRangeIterator) SetValue(val uint64) { + seg.node.values[seg.index] = val +} + +// PrevSegment returns the iterated segment's predecessor. If there is no +// preceding segment, PrevSegment returns a terminal iterator. +func (seg FileRangeIterator) PrevSegment() FileRangeIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index].lastSegment() + } + if seg.index > 0 { + return FileRangeIterator{seg.node, seg.index - 1} + } + if seg.node.parent == nil { + return FileRangeIterator{} + } + return FileRangesegmentBeforePosition(seg.node.parent, seg.node.parentIndex) +} + +// NextSegment returns the iterated segment's successor. If there is no +// succeeding segment, NextSegment returns a terminal iterator. +func (seg FileRangeIterator) NextSegment() FileRangeIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index+1].firstSegment() + } + if seg.index < seg.node.nrSegments-1 { + return FileRangeIterator{seg.node, seg.index + 1} + } + if seg.node.parent == nil { + return FileRangeIterator{} + } + return FileRangesegmentAfterPosition(seg.node.parent, seg.node.parentIndex) +} + +// PrevGap returns the gap immediately before the iterated segment. +func (seg FileRangeIterator) PrevGap() FileRangeGapIterator { + if seg.node.hasChildren { + + return seg.node.children[seg.index].lastSegment().NextGap() + } + return FileRangeGapIterator{seg.node, seg.index} +} + +// NextGap returns the gap immediately after the iterated segment. +func (seg FileRangeIterator) NextGap() FileRangeGapIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index+1].firstSegment().PrevGap() + } + return FileRangeGapIterator{seg.node, seg.index + 1} +} + +// PrevNonEmpty returns the iterated segment's predecessor if it is adjacent, +// or the gap before the iterated segment otherwise. If seg.Start() == +// Functions.MinKey(), PrevNonEmpty will return two terminal iterators. +// Otherwise, exactly one of the iterators returned by PrevNonEmpty will be +// non-terminal. +func (seg FileRangeIterator) PrevNonEmpty() (FileRangeIterator, FileRangeGapIterator) { + gap := seg.PrevGap() + if gap.Range().Length() != 0 { + return FileRangeIterator{}, gap + } + return gap.PrevSegment(), FileRangeGapIterator{} +} + +// NextNonEmpty returns the iterated segment's successor if it is adjacent, or +// the gap after the iterated segment otherwise. If seg.End() == +// Functions.MaxKey(), NextNonEmpty will return two terminal iterators. +// Otherwise, exactly one of the iterators returned by NextNonEmpty will be +// non-terminal. +func (seg FileRangeIterator) NextNonEmpty() (FileRangeIterator, FileRangeGapIterator) { + gap := seg.NextGap() + if gap.Range().Length() != 0 { + return FileRangeIterator{}, gap + } + return gap.NextSegment(), FileRangeGapIterator{} +} + +// A GapIterator is conceptually one of: +// +// - A pointer to a position between two segments, before the first segment, or +// after the last segment in a set, called a *gap*; or +// +// - A terminal iterator, which is a sentinel indicating that the end of +// iteration has been reached. +// +// Note that the gap between two adjacent segments exists (iterators to it are +// non-terminal), but has a length of zero. GapIterator.IsEmpty returns true +// for such gaps. An empty set contains a single gap, spanning the entire range +// of the set's keys. +// +// GapIterators are copyable values and are meaningfully equality-comparable. +// The zero value of GapIterator is a terminal iterator. +// +// Unless otherwise specified, any mutation of a set invalidates all existing +// iterators into the set. +type FileRangeGapIterator struct { + // The representation of a GapIterator is identical to that of an Iterator, + // except that index corresponds to positions between segments in the same + // way as for node.children (see comment for node.nrSegments). + node *FileRangenode + index int +} + +// Ok returns true if the iterator is not terminal. All other methods are only +// valid for non-terminal iterators. +func (gap FileRangeGapIterator) Ok() bool { + return gap.node != nil +} + +// Range returns the range spanned by the iterated gap. +func (gap FileRangeGapIterator) Range() __generics_imported0.MappableRange { + return __generics_imported0.MappableRange{gap.Start(), gap.End()} +} + +// Start is equivalent to Range().Start, but should be preferred if only the +// start of the range is needed. +func (gap FileRangeGapIterator) Start() uint64 { + if ps := gap.PrevSegment(); ps.Ok() { + return ps.End() + } + return FileRangeSetFunctions{}.MinKey() +} + +// End is equivalent to Range().End, but should be preferred if only the end of +// the range is needed. +func (gap FileRangeGapIterator) End() uint64 { + if ns := gap.NextSegment(); ns.Ok() { + return ns.Start() + } + return FileRangeSetFunctions{}.MaxKey() +} + +// IsEmpty returns true if the iterated gap is empty (that is, the "gap" is +// between two adjacent segments.) +func (gap FileRangeGapIterator) IsEmpty() bool { + return gap.Range().Length() == 0 +} + +// PrevSegment returns the segment immediately before the iterated gap. If no +// such segment exists, PrevSegment returns a terminal iterator. +func (gap FileRangeGapIterator) PrevSegment() FileRangeIterator { + return FileRangesegmentBeforePosition(gap.node, gap.index) +} + +// NextSegment returns the segment immediately after the iterated gap. If no +// such segment exists, NextSegment returns a terminal iterator. +func (gap FileRangeGapIterator) NextSegment() FileRangeIterator { + return FileRangesegmentAfterPosition(gap.node, gap.index) +} + +// PrevGap returns the iterated gap's predecessor. If no such gap exists, +// PrevGap returns a terminal iterator. +func (gap FileRangeGapIterator) PrevGap() FileRangeGapIterator { + seg := gap.PrevSegment() + if !seg.Ok() { + return FileRangeGapIterator{} + } + return seg.PrevGap() +} + +// NextGap returns the iterated gap's successor. If no such gap exists, NextGap +// returns a terminal iterator. +func (gap FileRangeGapIterator) NextGap() FileRangeGapIterator { + seg := gap.NextSegment() + if !seg.Ok() { + return FileRangeGapIterator{} + } + return seg.NextGap() +} + +// NextLargeEnoughGap returns the iterated gap's first next gap with larger +// length than minSize. If not found, return a terminal gap iterator (does NOT +// include this gap itself). +// +// Precondition: trackGaps must be 1. +func (gap FileRangeGapIterator) NextLargeEnoughGap(minSize uint64) FileRangeGapIterator { + if FileRangetrackGaps != 1 { + panic("set is not tracking gaps") + } + if gap.node != nil && gap.node.hasChildren && gap.index == gap.node.nrSegments { + + gap.node = gap.NextSegment().node + gap.index = 0 + return gap.nextLargeEnoughGapHelper(minSize) + } + return gap.nextLargeEnoughGapHelper(minSize) +} + +// nextLargeEnoughGapHelper is the helper function used by NextLargeEnoughGap +// to do the real recursions. +// +// Preconditions: gap is NOT the trailing gap of a non-leaf node. +func (gap FileRangeGapIterator) nextLargeEnoughGapHelper(minSize uint64) FileRangeGapIterator { + + for gap.node != nil && + (gap.node.maxGap.Get() < minSize || (!gap.node.hasChildren && gap.index == gap.node.nrSegments)) { + gap.node, gap.index = gap.node.parent, gap.node.parentIndex + } + + if gap.node == nil { + return FileRangeGapIterator{} + } + + gap.index++ + for gap.index <= gap.node.nrSegments { + if gap.node.hasChildren { + if largeEnoughGap := gap.node.children[gap.index].searchFirstLargeEnoughGap(minSize); largeEnoughGap.Ok() { + return largeEnoughGap + } + } else { + if gap.Range().Length() >= minSize { + return gap + } + } + gap.index++ + } + gap.node, gap.index = gap.node.parent, gap.node.parentIndex + if gap.node != nil && gap.index == gap.node.nrSegments { + + gap.node, gap.index = gap.node.parent, gap.node.parentIndex + } + return gap.nextLargeEnoughGapHelper(minSize) +} + +// PrevLargeEnoughGap returns the iterated gap's first prev gap with larger or +// equal length than minSize. If not found, return a terminal gap iterator +// (does NOT include this gap itself). +// +// Precondition: trackGaps must be 1. +func (gap FileRangeGapIterator) PrevLargeEnoughGap(minSize uint64) FileRangeGapIterator { + if FileRangetrackGaps != 1 { + panic("set is not tracking gaps") + } + if gap.node != nil && gap.node.hasChildren && gap.index == 0 { + + gap.node = gap.PrevSegment().node + gap.index = gap.node.nrSegments + return gap.prevLargeEnoughGapHelper(minSize) + } + return gap.prevLargeEnoughGapHelper(minSize) +} + +// prevLargeEnoughGapHelper is the helper function used by PrevLargeEnoughGap +// to do the real recursions. +// +// Preconditions: gap is NOT the first gap of a non-leaf node. +func (gap FileRangeGapIterator) prevLargeEnoughGapHelper(minSize uint64) FileRangeGapIterator { + + for gap.node != nil && + (gap.node.maxGap.Get() < minSize || (!gap.node.hasChildren && gap.index == 0)) { + gap.node, gap.index = gap.node.parent, gap.node.parentIndex + } + + if gap.node == nil { + return FileRangeGapIterator{} + } + + gap.index-- + for gap.index >= 0 { + if gap.node.hasChildren { + if largeEnoughGap := gap.node.children[gap.index].searchLastLargeEnoughGap(minSize); largeEnoughGap.Ok() { + return largeEnoughGap + } + } else { + if gap.Range().Length() >= minSize { + return gap + } + } + gap.index-- + } + gap.node, gap.index = gap.node.parent, gap.node.parentIndex + if gap.node != nil && gap.index == 0 { + + gap.node, gap.index = gap.node.parent, gap.node.parentIndex + } + return gap.prevLargeEnoughGapHelper(minSize) +} + +// segmentBeforePosition returns the predecessor segment of the position given +// by n.children[i], which may or may not contain a child. If no such segment +// exists, segmentBeforePosition returns a terminal iterator. +func FileRangesegmentBeforePosition(n *FileRangenode, i int) FileRangeIterator { + for i == 0 { + if n.parent == nil { + return FileRangeIterator{} + } + n, i = n.parent, n.parentIndex + } + return FileRangeIterator{n, i - 1} +} + +// segmentAfterPosition returns the successor segment of the position given by +// n.children[i], which may or may not contain a child. If no such segment +// exists, segmentAfterPosition returns a terminal iterator. +func FileRangesegmentAfterPosition(n *FileRangenode, i int) FileRangeIterator { + for i == n.nrSegments { + if n.parent == nil { + return FileRangeIterator{} + } + n, i = n.parent, n.parentIndex + } + return FileRangeIterator{n, i} +} + +func FileRangezeroValueSlice(slice []uint64) { + + for i := range slice { + FileRangeSetFunctions{}.ClearValue(&slice[i]) + } +} + +func FileRangezeroNodeSlice(slice []*FileRangenode) { + for i := range slice { + slice[i] = nil + } +} + +// String stringifies a Set for debugging. +func (s *FileRangeSet) String() string { + return s.root.String() +} + +// String stringifies a node (and all of its children) for debugging. +func (n *FileRangenode) String() string { + var buf bytes.Buffer + n.writeDebugString(&buf, "") + return buf.String() +} + +func (n *FileRangenode) writeDebugString(buf *bytes.Buffer, prefix string) { + if n.hasChildren != (n.nrSegments > 0 && n.children[0] != nil) { + buf.WriteString(prefix) + buf.WriteString(fmt.Sprintf("WARNING: inconsistent value of hasChildren: got %v, want %v\n", n.hasChildren, !n.hasChildren)) + } + for i := 0; i < n.nrSegments; i++ { + if child := n.children[i]; child != nil { + cprefix := fmt.Sprintf("%s- % 3d ", prefix, i) + if child.parent != n || child.parentIndex != i { + buf.WriteString(cprefix) + buf.WriteString(fmt.Sprintf("WARNING: inconsistent linkage to parent: got (%p, %d), want (%p, %d)\n", child.parent, child.parentIndex, n, i)) + } + child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, i)) + } + buf.WriteString(prefix) + if n.hasChildren { + if FileRangetrackGaps != 0 { + buf.WriteString(fmt.Sprintf("- % 3d: %v => %v, maxGap: %d\n", i, n.keys[i], n.values[i], n.maxGap.Get())) + } else { + buf.WriteString(fmt.Sprintf("- % 3d: %v => %v\n", i, n.keys[i], n.values[i])) + } + } else { + buf.WriteString(fmt.Sprintf("- % 3d: %v => %v\n", i, n.keys[i], n.values[i])) + } + } + if child := n.children[n.nrSegments]; child != nil { + child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, n.nrSegments)) + } +} + +// SegmentDataSlices represents segments from a set as slices of start, end, and +// values. SegmentDataSlices is primarily used as an intermediate representation +// for save/restore and the layout here is optimized for that. +// +// +stateify savable +type FileRangeSegmentDataSlices struct { + Start []uint64 + End []uint64 + Values []uint64 +} + +// ExportSortedSlices returns a copy of all segments in the given set, in +// ascending key order. +func (s *FileRangeSet) ExportSortedSlices() *FileRangeSegmentDataSlices { + var sds FileRangeSegmentDataSlices + for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + sds.Start = append(sds.Start, seg.Start()) + sds.End = append(sds.End, seg.End()) + sds.Values = append(sds.Values, seg.Value()) + } + sds.Start = sds.Start[:len(sds.Start):len(sds.Start)] + sds.End = sds.End[:len(sds.End):len(sds.End)] + sds.Values = sds.Values[:len(sds.Values):len(sds.Values)] + return &sds +} + +// ImportSortedSlices initializes the given set from the given slice. +// +// Preconditions: +// * s must be empty. +// * sds must represent a valid set (the segments in sds must have valid +// lengths that do not overlap). +// * The segments in sds must be sorted in ascending key order. +func (s *FileRangeSet) ImportSortedSlices(sds *FileRangeSegmentDataSlices) error { + if !s.IsEmpty() { + return fmt.Errorf("cannot import into non-empty set %v", s) + } + gap := s.FirstGap() + for i := range sds.Start { + r := __generics_imported0.MappableRange{sds.Start[i], sds.End[i]} + if !gap.Range().IsSupersetOf(r) { + return fmt.Errorf("segment overlaps a preceding segment or is incorrectly sorted: [%d, %d) => %v", sds.Start[i], sds.End[i], sds.Values[i]) + } + gap = s.InsertWithoutMerging(gap, r, sds.Values[i]).NextGap() + } + return nil +} + +// segmentTestCheck returns an error if s is incorrectly sorted, does not +// contain exactly expectedSegments segments, or contains a segment which +// fails the passed check. +// +// This should be used only for testing, and has been added to this package for +// templating convenience. +func (s *FileRangeSet) segmentTestCheck(expectedSegments int, segFunc func(int, __generics_imported0.MappableRange, uint64) error) error { + havePrev := false + prev := uint64(0) + nrSegments := 0 + for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + next := seg.Start() + if havePrev && prev >= next { + return fmt.Errorf("incorrect order: key %d (segment %d) >= key %d (segment %d)", prev, nrSegments-1, next, nrSegments) + } + if segFunc != nil { + if err := segFunc(nrSegments, seg.Range(), seg.Value()); err != nil { + return err + } + } + prev = next + havePrev = true + nrSegments++ + } + if nrSegments != expectedSegments { + return fmt.Errorf("incorrect number of segments: got %d, wanted %d", nrSegments, expectedSegments) + } + return nil +} + +// countSegments counts the number of segments in the set. +// +// Similar to Check, this should only be used for testing. +func (s *FileRangeSet) countSegments() (segments int) { + for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + segments++ + } + return segments +} +func (s *FileRangeSet) saveRoot() *FileRangeSegmentDataSlices { + return s.ExportSortedSlices() +} + +func (s *FileRangeSet) loadRoot(sds *FileRangeSegmentDataSlices) { + if err := s.ImportSortedSlices(sds); err != nil { + panic(err) + } +} diff --git a/pkg/sentry/fs/fsutil/frame_ref_set_impl.go b/pkg/sentry/fs/fsutil/frame_ref_set_impl.go new file mode 100644 index 000000000..6657addf4 --- /dev/null +++ b/pkg/sentry/fs/fsutil/frame_ref_set_impl.go @@ -0,0 +1,1647 @@ +package fsutil + +import ( + __generics_imported0 "gvisor.dev/gvisor/pkg/sentry/memmap" +) + +import ( + "bytes" + "fmt" +) + +// trackGaps is an optional parameter. +// +// If trackGaps is 1, the Set will track maximum gap size recursively, +// enabling the GapIterator.{Prev,Next}LargeEnoughGap functions. In this +// case, Key must be an unsigned integer. +// +// trackGaps must be 0 or 1. +const FrameReftrackGaps = 0 + +var _ = uint8(FrameReftrackGaps << 7) // Will fail if not zero or one. + +// dynamicGap is a type that disappears if trackGaps is 0. +type FrameRefdynamicGap [FrameReftrackGaps]uint64 + +// Get returns the value of the gap. +// +// Precondition: trackGaps must be non-zero. +func (d *FrameRefdynamicGap) Get() uint64 { + return d[:][0] +} + +// Set sets the value of the gap. +// +// Precondition: trackGaps must be non-zero. +func (d *FrameRefdynamicGap) Set(v uint64) { + d[:][0] = v +} + +const ( + // minDegree is the minimum degree of an internal node in a Set B-tree. + // + // - Any non-root node has at least minDegree-1 segments. + // + // - Any non-root internal (non-leaf) node has at least minDegree children. + // + // - The root node may have fewer than minDegree-1 segments, but it may + // only have 0 segments if the tree is empty. + // + // Our implementation requires minDegree >= 3. Higher values of minDegree + // usually improve performance, but increase memory usage for small sets. + FrameRefminDegree = 3 + + FrameRefmaxDegree = 2 * FrameRefminDegree +) + +// A Set is a mapping of segments with non-overlapping Range keys. The zero +// value for a Set is an empty set. Set values are not safely movable nor +// copyable. Set is thread-compatible. +// +// +stateify savable +type FrameRefSet struct { + root FrameRefnode `state:".(*FrameRefSegmentDataSlices)"` +} + +// IsEmpty returns true if the set contains no segments. +func (s *FrameRefSet) IsEmpty() bool { + return s.root.nrSegments == 0 +} + +// IsEmptyRange returns true iff no segments in the set overlap the given +// range. This is semantically equivalent to s.SpanRange(r) == 0, but may be +// more efficient. +func (s *FrameRefSet) IsEmptyRange(r __generics_imported0.FileRange) bool { + switch { + case r.Length() < 0: + panic(fmt.Sprintf("invalid range %v", r)) + case r.Length() == 0: + return true + } + _, gap := s.Find(r.Start) + if !gap.Ok() { + return false + } + return r.End <= gap.End() +} + +// Span returns the total size of all segments in the set. +func (s *FrameRefSet) Span() uint64 { + var sz uint64 + for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + sz += seg.Range().Length() + } + return sz +} + +// SpanRange returns the total size of the intersection of segments in the set +// with the given range. +func (s *FrameRefSet) SpanRange(r __generics_imported0.FileRange) uint64 { + switch { + case r.Length() < 0: + panic(fmt.Sprintf("invalid range %v", r)) + case r.Length() == 0: + return 0 + } + var sz uint64 + for seg := s.LowerBoundSegment(r.Start); seg.Ok() && seg.Start() < r.End; seg = seg.NextSegment() { + sz += seg.Range().Intersect(r).Length() + } + return sz +} + +// FirstSegment returns the first segment in the set. If the set is empty, +// FirstSegment returns a terminal iterator. +func (s *FrameRefSet) FirstSegment() FrameRefIterator { + if s.root.nrSegments == 0 { + return FrameRefIterator{} + } + return s.root.firstSegment() +} + +// LastSegment returns the last segment in the set. If the set is empty, +// LastSegment returns a terminal iterator. +func (s *FrameRefSet) LastSegment() FrameRefIterator { + if s.root.nrSegments == 0 { + return FrameRefIterator{} + } + return s.root.lastSegment() +} + +// FirstGap returns the first gap in the set. +func (s *FrameRefSet) FirstGap() FrameRefGapIterator { + n := &s.root + for n.hasChildren { + n = n.children[0] + } + return FrameRefGapIterator{n, 0} +} + +// LastGap returns the last gap in the set. +func (s *FrameRefSet) LastGap() FrameRefGapIterator { + n := &s.root + for n.hasChildren { + n = n.children[n.nrSegments] + } + return FrameRefGapIterator{n, n.nrSegments} +} + +// Find returns the segment or gap whose range contains the given key. If a +// segment is found, the returned Iterator is non-terminal and the +// returned GapIterator is terminal. Otherwise, the returned Iterator is +// terminal and the returned GapIterator is non-terminal. +func (s *FrameRefSet) Find(key uint64) (FrameRefIterator, FrameRefGapIterator) { + n := &s.root + for { + + lower := 0 + upper := n.nrSegments + for lower < upper { + i := lower + (upper-lower)/2 + if r := n.keys[i]; key < r.End { + if key >= r.Start { + return FrameRefIterator{n, i}, FrameRefGapIterator{} + } + upper = i + } else { + lower = i + 1 + } + } + i := lower + if !n.hasChildren { + return FrameRefIterator{}, FrameRefGapIterator{n, i} + } + n = n.children[i] + } +} + +// FindSegment returns the segment whose range contains the given key. If no +// such segment exists, FindSegment returns a terminal iterator. +func (s *FrameRefSet) FindSegment(key uint64) FrameRefIterator { + seg, _ := s.Find(key) + return seg +} + +// LowerBoundSegment returns the segment with the lowest range that contains a +// key greater than or equal to min. If no such segment exists, +// LowerBoundSegment returns a terminal iterator. +func (s *FrameRefSet) LowerBoundSegment(min uint64) FrameRefIterator { + seg, gap := s.Find(min) + if seg.Ok() { + return seg + } + return gap.NextSegment() +} + +// UpperBoundSegment returns the segment with the highest range that contains a +// key less than or equal to max. If no such segment exists, UpperBoundSegment +// returns a terminal iterator. +func (s *FrameRefSet) UpperBoundSegment(max uint64) FrameRefIterator { + seg, gap := s.Find(max) + if seg.Ok() { + return seg + } + return gap.PrevSegment() +} + +// FindGap returns the gap containing the given key. If no such gap exists +// (i.e. the set contains a segment containing that key), FindGap returns a +// terminal iterator. +func (s *FrameRefSet) FindGap(key uint64) FrameRefGapIterator { + _, gap := s.Find(key) + return gap +} + +// LowerBoundGap returns the gap with the lowest range that is greater than or +// equal to min. +func (s *FrameRefSet) LowerBoundGap(min uint64) FrameRefGapIterator { + seg, gap := s.Find(min) + if gap.Ok() { + return gap + } + return seg.NextGap() +} + +// UpperBoundGap returns the gap with the highest range that is less than or +// equal to max. +func (s *FrameRefSet) UpperBoundGap(max uint64) FrameRefGapIterator { + seg, gap := s.Find(max) + if gap.Ok() { + return gap + } + return seg.PrevGap() +} + +// Add inserts the given segment into the set and returns true. If the new +// segment can be merged with adjacent segments, Add will do so. If the new +// segment would overlap an existing segment, Add returns false. If Add +// succeeds, all existing iterators are invalidated. +func (s *FrameRefSet) Add(r __generics_imported0.FileRange, val uint64) bool { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + gap := s.FindGap(r.Start) + if !gap.Ok() { + return false + } + if r.End > gap.End() { + return false + } + s.Insert(gap, r, val) + return true +} + +// AddWithoutMerging inserts the given segment into the set and returns true. +// If it would overlap an existing segment, AddWithoutMerging does nothing and +// returns false. If AddWithoutMerging succeeds, all existing iterators are +// invalidated. +func (s *FrameRefSet) AddWithoutMerging(r __generics_imported0.FileRange, val uint64) bool { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + gap := s.FindGap(r.Start) + if !gap.Ok() { + return false + } + if r.End > gap.End() { + return false + } + s.InsertWithoutMergingUnchecked(gap, r, val) + return true +} + +// Insert inserts the given segment into the given gap. If the new segment can +// be merged with adjacent segments, Insert will do so. Insert returns an +// iterator to the segment containing the inserted value (which may have been +// merged with other values). All existing iterators (including gap, but not +// including the returned iterator) are invalidated. +// +// If the gap cannot accommodate the segment, or if r is invalid, Insert panics. +// +// Insert is semantically equivalent to a InsertWithoutMerging followed by a +// Merge, but may be more efficient. Note that there is no unchecked variant of +// Insert since Insert must retrieve and inspect gap's predecessor and +// successor segments regardless. +func (s *FrameRefSet) Insert(gap FrameRefGapIterator, r __generics_imported0.FileRange, val uint64) FrameRefIterator { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + prev, next := gap.PrevSegment(), gap.NextSegment() + if prev.Ok() && prev.End() > r.Start { + panic(fmt.Sprintf("new segment %v overlaps predecessor %v", r, prev.Range())) + } + if next.Ok() && next.Start() < r.End { + panic(fmt.Sprintf("new segment %v overlaps successor %v", r, next.Range())) + } + if prev.Ok() && prev.End() == r.Start { + if mval, ok := (FrameRefSetFunctions{}).Merge(prev.Range(), prev.Value(), r, val); ok { + shrinkMaxGap := FrameReftrackGaps != 0 && gap.Range().Length() == gap.node.maxGap.Get() + prev.SetEndUnchecked(r.End) + prev.SetValue(mval) + if shrinkMaxGap { + gap.node.updateMaxGapLeaf() + } + if next.Ok() && next.Start() == r.End { + val = mval + if mval, ok := (FrameRefSetFunctions{}).Merge(prev.Range(), val, next.Range(), next.Value()); ok { + prev.SetEndUnchecked(next.End()) + prev.SetValue(mval) + return s.Remove(next).PrevSegment() + } + } + return prev + } + } + if next.Ok() && next.Start() == r.End { + if mval, ok := (FrameRefSetFunctions{}).Merge(r, val, next.Range(), next.Value()); ok { + shrinkMaxGap := FrameReftrackGaps != 0 && gap.Range().Length() == gap.node.maxGap.Get() + next.SetStartUnchecked(r.Start) + next.SetValue(mval) + if shrinkMaxGap { + gap.node.updateMaxGapLeaf() + } + return next + } + } + + return s.InsertWithoutMergingUnchecked(gap, r, val) +} + +// InsertWithoutMerging inserts the given segment into the given gap and +// returns an iterator to the inserted segment. All existing iterators +// (including gap, but not including the returned iterator) are invalidated. +// +// If the gap cannot accommodate the segment, or if r is invalid, +// InsertWithoutMerging panics. +func (s *FrameRefSet) InsertWithoutMerging(gap FrameRefGapIterator, r __generics_imported0.FileRange, val uint64) FrameRefIterator { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + if gr := gap.Range(); !gr.IsSupersetOf(r) { + panic(fmt.Sprintf("cannot insert segment range %v into gap range %v", r, gr)) + } + return s.InsertWithoutMergingUnchecked(gap, r, val) +} + +// InsertWithoutMergingUnchecked inserts the given segment into the given gap +// and returns an iterator to the inserted segment. All existing iterators +// (including gap, but not including the returned iterator) are invalidated. +// +// Preconditions: +// * r.Start >= gap.Start(). +// * r.End <= gap.End(). +func (s *FrameRefSet) InsertWithoutMergingUnchecked(gap FrameRefGapIterator, r __generics_imported0.FileRange, val uint64) FrameRefIterator { + gap = gap.node.rebalanceBeforeInsert(gap) + splitMaxGap := FrameReftrackGaps != 0 && (gap.node.nrSegments == 0 || gap.Range().Length() == gap.node.maxGap.Get()) + copy(gap.node.keys[gap.index+1:], gap.node.keys[gap.index:gap.node.nrSegments]) + copy(gap.node.values[gap.index+1:], gap.node.values[gap.index:gap.node.nrSegments]) + gap.node.keys[gap.index] = r + gap.node.values[gap.index] = val + gap.node.nrSegments++ + if splitMaxGap { + gap.node.updateMaxGapLeaf() + } + return FrameRefIterator{gap.node, gap.index} +} + +// Remove removes the given segment and returns an iterator to the vacated gap. +// All existing iterators (including seg, but not including the returned +// iterator) are invalidated. +func (s *FrameRefSet) Remove(seg FrameRefIterator) FrameRefGapIterator { + + if seg.node.hasChildren { + + victim := seg.PrevSegment() + + seg.SetRangeUnchecked(victim.Range()) + seg.SetValue(victim.Value()) + + nextAdjacentNode := seg.NextSegment().node + if FrameReftrackGaps != 0 { + nextAdjacentNode.updateMaxGapLeaf() + } + return s.Remove(victim).NextGap() + } + copy(seg.node.keys[seg.index:], seg.node.keys[seg.index+1:seg.node.nrSegments]) + copy(seg.node.values[seg.index:], seg.node.values[seg.index+1:seg.node.nrSegments]) + FrameRefSetFunctions{}.ClearValue(&seg.node.values[seg.node.nrSegments-1]) + seg.node.nrSegments-- + if FrameReftrackGaps != 0 { + seg.node.updateMaxGapLeaf() + } + return seg.node.rebalanceAfterRemove(FrameRefGapIterator{seg.node, seg.index}) +} + +// RemoveAll removes all segments from the set. All existing iterators are +// invalidated. +func (s *FrameRefSet) RemoveAll() { + s.root = FrameRefnode{} +} + +// RemoveRange removes all segments in the given range. An iterator to the +// newly formed gap is returned, and all existing iterators are invalidated. +func (s *FrameRefSet) RemoveRange(r __generics_imported0.FileRange) FrameRefGapIterator { + seg, gap := s.Find(r.Start) + if seg.Ok() { + seg = s.Isolate(seg, r) + gap = s.Remove(seg) + } + for seg = gap.NextSegment(); seg.Ok() && seg.Start() < r.End; seg = gap.NextSegment() { + seg = s.Isolate(seg, r) + gap = s.Remove(seg) + } + return gap +} + +// Merge attempts to merge two neighboring segments. If successful, Merge +// returns an iterator to the merged segment, and all existing iterators are +// invalidated. Otherwise, Merge returns a terminal iterator. +// +// If first is not the predecessor of second, Merge panics. +func (s *FrameRefSet) Merge(first, second FrameRefIterator) FrameRefIterator { + if first.NextSegment() != second { + panic(fmt.Sprintf("attempt to merge non-neighboring segments %v, %v", first.Range(), second.Range())) + } + return s.MergeUnchecked(first, second) +} + +// MergeUnchecked attempts to merge two neighboring segments. If successful, +// MergeUnchecked returns an iterator to the merged segment, and all existing +// iterators are invalidated. Otherwise, MergeUnchecked returns a terminal +// iterator. +// +// Precondition: first is the predecessor of second: first.NextSegment() == +// second, first == second.PrevSegment(). +func (s *FrameRefSet) MergeUnchecked(first, second FrameRefIterator) FrameRefIterator { + if first.End() == second.Start() { + if mval, ok := (FrameRefSetFunctions{}).Merge(first.Range(), first.Value(), second.Range(), second.Value()); ok { + + first.SetEndUnchecked(second.End()) + first.SetValue(mval) + + return s.Remove(second).PrevSegment() + } + } + return FrameRefIterator{} +} + +// MergeAll attempts to merge all adjacent segments in the set. All existing +// iterators are invalidated. +func (s *FrameRefSet) MergeAll() { + seg := s.FirstSegment() + if !seg.Ok() { + return + } + next := seg.NextSegment() + for next.Ok() { + if mseg := s.MergeUnchecked(seg, next); mseg.Ok() { + seg, next = mseg, mseg.NextSegment() + } else { + seg, next = next, next.NextSegment() + } + } +} + +// MergeRange attempts to merge all adjacent segments that contain a key in the +// specific range. All existing iterators are invalidated. +func (s *FrameRefSet) MergeRange(r __generics_imported0.FileRange) { + seg := s.LowerBoundSegment(r.Start) + if !seg.Ok() { + return + } + next := seg.NextSegment() + for next.Ok() && next.Range().Start < r.End { + if mseg := s.MergeUnchecked(seg, next); mseg.Ok() { + seg, next = mseg, mseg.NextSegment() + } else { + seg, next = next, next.NextSegment() + } + } +} + +// MergeAdjacent attempts to merge the segment containing r.Start with its +// predecessor, and the segment containing r.End-1 with its successor. +func (s *FrameRefSet) MergeAdjacent(r __generics_imported0.FileRange) { + first := s.FindSegment(r.Start) + if first.Ok() { + if prev := first.PrevSegment(); prev.Ok() { + s.Merge(prev, first) + } + } + last := s.FindSegment(r.End - 1) + if last.Ok() { + if next := last.NextSegment(); next.Ok() { + s.Merge(last, next) + } + } +} + +// Split splits the given segment at the given key and returns iterators to the +// two resulting segments. All existing iterators (including seg, but not +// including the returned iterators) are invalidated. +// +// If the segment cannot be split at split (because split is at the start or +// end of the segment's range, so splitting would produce a segment with zero +// length, or because split falls outside the segment's range altogether), +// Split panics. +func (s *FrameRefSet) Split(seg FrameRefIterator, split uint64) (FrameRefIterator, FrameRefIterator) { + if !seg.Range().CanSplitAt(split) { + panic(fmt.Sprintf("can't split %v at %v", seg.Range(), split)) + } + return s.SplitUnchecked(seg, split) +} + +// SplitUnchecked splits the given segment at the given key and returns +// iterators to the two resulting segments. All existing iterators (including +// seg, but not including the returned iterators) are invalidated. +// +// Preconditions: seg.Start() < key < seg.End(). +func (s *FrameRefSet) SplitUnchecked(seg FrameRefIterator, split uint64) (FrameRefIterator, FrameRefIterator) { + val1, val2 := (FrameRefSetFunctions{}).Split(seg.Range(), seg.Value(), split) + end2 := seg.End() + seg.SetEndUnchecked(split) + seg.SetValue(val1) + seg2 := s.InsertWithoutMergingUnchecked(seg.NextGap(), __generics_imported0.FileRange{split, end2}, val2) + + return seg2.PrevSegment(), seg2 +} + +// SplitAt splits the segment straddling split, if one exists. SplitAt returns +// true if a segment was split and false otherwise. If SplitAt splits a +// segment, all existing iterators are invalidated. +func (s *FrameRefSet) SplitAt(split uint64) bool { + if seg := s.FindSegment(split); seg.Ok() && seg.Range().CanSplitAt(split) { + s.SplitUnchecked(seg, split) + return true + } + return false +} + +// Isolate ensures that the given segment's range does not escape r by +// splitting at r.Start and r.End if necessary, and returns an updated iterator +// to the bounded segment. All existing iterators (including seg, but not +// including the returned iterators) are invalidated. +func (s *FrameRefSet) Isolate(seg FrameRefIterator, r __generics_imported0.FileRange) FrameRefIterator { + if seg.Range().CanSplitAt(r.Start) { + _, seg = s.SplitUnchecked(seg, r.Start) + } + if seg.Range().CanSplitAt(r.End) { + seg, _ = s.SplitUnchecked(seg, r.End) + } + return seg +} + +// ApplyContiguous applies a function to a contiguous range of segments, +// splitting if necessary. The function is applied until the first gap is +// encountered, at which point the gap is returned. If the function is applied +// across the entire range, a terminal gap is returned. All existing iterators +// are invalidated. +// +// N.B. The Iterator must not be invalidated by the function. +func (s *FrameRefSet) ApplyContiguous(r __generics_imported0.FileRange, fn func(seg FrameRefIterator)) FrameRefGapIterator { + seg, gap := s.Find(r.Start) + if !seg.Ok() { + return gap + } + for { + seg = s.Isolate(seg, r) + fn(seg) + if seg.End() >= r.End { + return FrameRefGapIterator{} + } + gap = seg.NextGap() + if !gap.IsEmpty() { + return gap + } + seg = gap.NextSegment() + if !seg.Ok() { + + return FrameRefGapIterator{} + } + } +} + +// +stateify savable +type FrameRefnode struct { + // An internal binary tree node looks like: + // + // K + // / \ + // Cl Cr + // + // where all keys in the subtree rooted by Cl (the left subtree) are less + // than K (the key of the parent node), and all keys in the subtree rooted + // by Cr (the right subtree) are greater than K. + // + // An internal B-tree node's indexes work out to look like: + // + // K0 K1 K2 ... Kn-1 + // / \/ \/ \ ... / \ + // C0 C1 C2 C3 ... Cn-1 Cn + // + // where n is nrSegments. + nrSegments int + + // parent is a pointer to this node's parent. If this node is root, parent + // is nil. + parent *FrameRefnode + + // parentIndex is the index of this node in parent.children. + parentIndex int + + // Flag for internal nodes that is technically redundant with "children[0] + // != nil", but is stored in the first cache line. "hasChildren" rather + // than "isLeaf" because false must be the correct value for an empty root. + hasChildren bool + + // The longest gap within this node. If the node is a leaf, it's simply the + // maximum gap among all the (nrSegments+1) gaps formed by its nrSegments keys + // including the 0th and nrSegments-th gap possibly shared with its upper-level + // nodes; if it's a non-leaf node, it's the max of all children's maxGap. + maxGap FrameRefdynamicGap + + // Nodes store keys and values in separate arrays to maximize locality in + // the common case (scanning keys for lookup). + keys [FrameRefmaxDegree - 1]__generics_imported0.FileRange + values [FrameRefmaxDegree - 1]uint64 + children [FrameRefmaxDegree]*FrameRefnode +} + +// firstSegment returns the first segment in the subtree rooted by n. +// +// Preconditions: n.nrSegments != 0. +func (n *FrameRefnode) firstSegment() FrameRefIterator { + for n.hasChildren { + n = n.children[0] + } + return FrameRefIterator{n, 0} +} + +// lastSegment returns the last segment in the subtree rooted by n. +// +// Preconditions: n.nrSegments != 0. +func (n *FrameRefnode) lastSegment() FrameRefIterator { + for n.hasChildren { + n = n.children[n.nrSegments] + } + return FrameRefIterator{n, n.nrSegments - 1} +} + +func (n *FrameRefnode) prevSibling() *FrameRefnode { + if n.parent == nil || n.parentIndex == 0 { + return nil + } + return n.parent.children[n.parentIndex-1] +} + +func (n *FrameRefnode) nextSibling() *FrameRefnode { + if n.parent == nil || n.parentIndex == n.parent.nrSegments { + return nil + } + return n.parent.children[n.parentIndex+1] +} + +// rebalanceBeforeInsert splits n and its ancestors if they are full, as +// required for insertion, and returns an updated iterator to the position +// represented by gap. +func (n *FrameRefnode) rebalanceBeforeInsert(gap FrameRefGapIterator) FrameRefGapIterator { + if n.nrSegments < FrameRefmaxDegree-1 { + return gap + } + if n.parent != nil { + gap = n.parent.rebalanceBeforeInsert(gap) + } + if n.parent == nil { + + left := &FrameRefnode{ + nrSegments: FrameRefminDegree - 1, + parent: n, + parentIndex: 0, + hasChildren: n.hasChildren, + } + right := &FrameRefnode{ + nrSegments: FrameRefminDegree - 1, + parent: n, + parentIndex: 1, + hasChildren: n.hasChildren, + } + copy(left.keys[:FrameRefminDegree-1], n.keys[:FrameRefminDegree-1]) + copy(left.values[:FrameRefminDegree-1], n.values[:FrameRefminDegree-1]) + copy(right.keys[:FrameRefminDegree-1], n.keys[FrameRefminDegree:]) + copy(right.values[:FrameRefminDegree-1], n.values[FrameRefminDegree:]) + n.keys[0], n.values[0] = n.keys[FrameRefminDegree-1], n.values[FrameRefminDegree-1] + FrameRefzeroValueSlice(n.values[1:]) + if n.hasChildren { + copy(left.children[:FrameRefminDegree], n.children[:FrameRefminDegree]) + copy(right.children[:FrameRefminDegree], n.children[FrameRefminDegree:]) + FrameRefzeroNodeSlice(n.children[2:]) + for i := 0; i < FrameRefminDegree; i++ { + left.children[i].parent = left + left.children[i].parentIndex = i + right.children[i].parent = right + right.children[i].parentIndex = i + } + } + n.nrSegments = 1 + n.hasChildren = true + n.children[0] = left + n.children[1] = right + + if FrameReftrackGaps != 0 { + left.updateMaxGapLocal() + right.updateMaxGapLocal() + } + if gap.node != n { + return gap + } + if gap.index < FrameRefminDegree { + return FrameRefGapIterator{left, gap.index} + } + return FrameRefGapIterator{right, gap.index - FrameRefminDegree} + } + + copy(n.parent.keys[n.parentIndex+1:], n.parent.keys[n.parentIndex:n.parent.nrSegments]) + copy(n.parent.values[n.parentIndex+1:], n.parent.values[n.parentIndex:n.parent.nrSegments]) + n.parent.keys[n.parentIndex], n.parent.values[n.parentIndex] = n.keys[FrameRefminDegree-1], n.values[FrameRefminDegree-1] + copy(n.parent.children[n.parentIndex+2:], n.parent.children[n.parentIndex+1:n.parent.nrSegments+1]) + for i := n.parentIndex + 2; i < n.parent.nrSegments+2; i++ { + n.parent.children[i].parentIndex = i + } + sibling := &FrameRefnode{ + nrSegments: FrameRefminDegree - 1, + parent: n.parent, + parentIndex: n.parentIndex + 1, + hasChildren: n.hasChildren, + } + n.parent.children[n.parentIndex+1] = sibling + n.parent.nrSegments++ + copy(sibling.keys[:FrameRefminDegree-1], n.keys[FrameRefminDegree:]) + copy(sibling.values[:FrameRefminDegree-1], n.values[FrameRefminDegree:]) + FrameRefzeroValueSlice(n.values[FrameRefminDegree-1:]) + if n.hasChildren { + copy(sibling.children[:FrameRefminDegree], n.children[FrameRefminDegree:]) + FrameRefzeroNodeSlice(n.children[FrameRefminDegree:]) + for i := 0; i < FrameRefminDegree; i++ { + sibling.children[i].parent = sibling + sibling.children[i].parentIndex = i + } + } + n.nrSegments = FrameRefminDegree - 1 + + if FrameReftrackGaps != 0 { + n.updateMaxGapLocal() + sibling.updateMaxGapLocal() + } + + if gap.node != n { + return gap + } + if gap.index < FrameRefminDegree { + return gap + } + return FrameRefGapIterator{sibling, gap.index - FrameRefminDegree} +} + +// rebalanceAfterRemove "unsplits" n and its ancestors if they are deficient +// (contain fewer segments than required by B-tree invariants), as required for +// removal, and returns an updated iterator to the position represented by gap. +// +// Precondition: n is the only node in the tree that may currently violate a +// B-tree invariant. +func (n *FrameRefnode) rebalanceAfterRemove(gap FrameRefGapIterator) FrameRefGapIterator { + for { + if n.nrSegments >= FrameRefminDegree-1 { + return gap + } + if n.parent == nil { + + return gap + } + + if sibling := n.prevSibling(); sibling != nil && sibling.nrSegments >= FrameRefminDegree { + copy(n.keys[1:], n.keys[:n.nrSegments]) + copy(n.values[1:], n.values[:n.nrSegments]) + n.keys[0] = n.parent.keys[n.parentIndex-1] + n.values[0] = n.parent.values[n.parentIndex-1] + n.parent.keys[n.parentIndex-1] = sibling.keys[sibling.nrSegments-1] + n.parent.values[n.parentIndex-1] = sibling.values[sibling.nrSegments-1] + FrameRefSetFunctions{}.ClearValue(&sibling.values[sibling.nrSegments-1]) + if n.hasChildren { + copy(n.children[1:], n.children[:n.nrSegments+1]) + n.children[0] = sibling.children[sibling.nrSegments] + sibling.children[sibling.nrSegments] = nil + n.children[0].parent = n + n.children[0].parentIndex = 0 + for i := 1; i < n.nrSegments+2; i++ { + n.children[i].parentIndex = i + } + } + n.nrSegments++ + sibling.nrSegments-- + + if FrameReftrackGaps != 0 { + n.updateMaxGapLocal() + sibling.updateMaxGapLocal() + } + if gap.node == sibling && gap.index == sibling.nrSegments { + return FrameRefGapIterator{n, 0} + } + if gap.node == n { + return FrameRefGapIterator{n, gap.index + 1} + } + return gap + } + if sibling := n.nextSibling(); sibling != nil && sibling.nrSegments >= FrameRefminDegree { + n.keys[n.nrSegments] = n.parent.keys[n.parentIndex] + n.values[n.nrSegments] = n.parent.values[n.parentIndex] + n.parent.keys[n.parentIndex] = sibling.keys[0] + n.parent.values[n.parentIndex] = sibling.values[0] + copy(sibling.keys[:sibling.nrSegments-1], sibling.keys[1:]) + copy(sibling.values[:sibling.nrSegments-1], sibling.values[1:]) + FrameRefSetFunctions{}.ClearValue(&sibling.values[sibling.nrSegments-1]) + if n.hasChildren { + n.children[n.nrSegments+1] = sibling.children[0] + copy(sibling.children[:sibling.nrSegments], sibling.children[1:]) + sibling.children[sibling.nrSegments] = nil + n.children[n.nrSegments+1].parent = n + n.children[n.nrSegments+1].parentIndex = n.nrSegments + 1 + for i := 0; i < sibling.nrSegments; i++ { + sibling.children[i].parentIndex = i + } + } + n.nrSegments++ + sibling.nrSegments-- + + if FrameReftrackGaps != 0 { + n.updateMaxGapLocal() + sibling.updateMaxGapLocal() + } + if gap.node == sibling { + if gap.index == 0 { + return FrameRefGapIterator{n, n.nrSegments} + } + return FrameRefGapIterator{sibling, gap.index - 1} + } + return gap + } + + p := n.parent + if p.nrSegments == 1 { + + left, right := p.children[0], p.children[1] + p.nrSegments = left.nrSegments + right.nrSegments + 1 + p.hasChildren = left.hasChildren + p.keys[left.nrSegments] = p.keys[0] + p.values[left.nrSegments] = p.values[0] + copy(p.keys[:left.nrSegments], left.keys[:left.nrSegments]) + copy(p.values[:left.nrSegments], left.values[:left.nrSegments]) + copy(p.keys[left.nrSegments+1:], right.keys[:right.nrSegments]) + copy(p.values[left.nrSegments+1:], right.values[:right.nrSegments]) + if left.hasChildren { + copy(p.children[:left.nrSegments+1], left.children[:left.nrSegments+1]) + copy(p.children[left.nrSegments+1:], right.children[:right.nrSegments+1]) + for i := 0; i < p.nrSegments+1; i++ { + p.children[i].parent = p + p.children[i].parentIndex = i + } + } else { + p.children[0] = nil + p.children[1] = nil + } + + if gap.node == left { + return FrameRefGapIterator{p, gap.index} + } + if gap.node == right { + return FrameRefGapIterator{p, gap.index + left.nrSegments + 1} + } + return gap + } + // Merge n and either sibling, along with the segment separating the + // two, into whichever of the two nodes comes first. This is the + // reverse of the non-root splitting case in + // node.rebalanceBeforeInsert. + var left, right *FrameRefnode + if n.parentIndex > 0 { + left = n.prevSibling() + right = n + } else { + left = n + right = n.nextSibling() + } + + if gap.node == right { + gap = FrameRefGapIterator{left, gap.index + left.nrSegments + 1} + } + left.keys[left.nrSegments] = p.keys[left.parentIndex] + left.values[left.nrSegments] = p.values[left.parentIndex] + copy(left.keys[left.nrSegments+1:], right.keys[:right.nrSegments]) + copy(left.values[left.nrSegments+1:], right.values[:right.nrSegments]) + if left.hasChildren { + copy(left.children[left.nrSegments+1:], right.children[:right.nrSegments+1]) + for i := left.nrSegments + 1; i < left.nrSegments+right.nrSegments+2; i++ { + left.children[i].parent = left + left.children[i].parentIndex = i + } + } + left.nrSegments += right.nrSegments + 1 + copy(p.keys[left.parentIndex:], p.keys[left.parentIndex+1:p.nrSegments]) + copy(p.values[left.parentIndex:], p.values[left.parentIndex+1:p.nrSegments]) + FrameRefSetFunctions{}.ClearValue(&p.values[p.nrSegments-1]) + copy(p.children[left.parentIndex+1:], p.children[left.parentIndex+2:p.nrSegments+1]) + for i := 0; i < p.nrSegments; i++ { + p.children[i].parentIndex = i + } + p.children[p.nrSegments] = nil + p.nrSegments-- + + if FrameReftrackGaps != 0 { + left.updateMaxGapLocal() + } + + n = p + } +} + +// updateMaxGapLeaf updates maxGap bottom-up from the calling leaf until no +// necessary update. +// +// Preconditions: n must be a leaf node, trackGaps must be 1. +func (n *FrameRefnode) updateMaxGapLeaf() { + if n.hasChildren { + panic(fmt.Sprintf("updateMaxGapLeaf should always be called on leaf node: %v", n)) + } + max := n.calculateMaxGapLeaf() + if max == n.maxGap.Get() { + + return + } + oldMax := n.maxGap.Get() + n.maxGap.Set(max) + if max > oldMax { + + for p := n.parent; p != nil; p = p.parent { + if p.maxGap.Get() >= max { + + break + } + + p.maxGap.Set(max) + } + return + } + + for p := n.parent; p != nil; p = p.parent { + if p.maxGap.Get() > oldMax { + + break + } + + parentNewMax := p.calculateMaxGapInternal() + if p.maxGap.Get() == parentNewMax { + + break + } + + p.maxGap.Set(parentNewMax) + } +} + +// updateMaxGapLocal updates maxGap of the calling node solely with no +// propagation to ancestor nodes. +// +// Precondition: trackGaps must be 1. +func (n *FrameRefnode) updateMaxGapLocal() { + if !n.hasChildren { + + n.maxGap.Set(n.calculateMaxGapLeaf()) + } else { + + n.maxGap.Set(n.calculateMaxGapInternal()) + } +} + +// calculateMaxGapLeaf iterates the gaps within a leaf node and calculate the +// max. +// +// Preconditions: n must be a leaf node. +func (n *FrameRefnode) calculateMaxGapLeaf() uint64 { + max := FrameRefGapIterator{n, 0}.Range().Length() + for i := 1; i <= n.nrSegments; i++ { + if current := (FrameRefGapIterator{n, i}).Range().Length(); current > max { + max = current + } + } + return max +} + +// calculateMaxGapInternal iterates children's maxGap within an internal node n +// and calculate the max. +// +// Preconditions: n must be a non-leaf node. +func (n *FrameRefnode) calculateMaxGapInternal() uint64 { + max := n.children[0].maxGap.Get() + for i := 1; i <= n.nrSegments; i++ { + if current := n.children[i].maxGap.Get(); current > max { + max = current + } + } + return max +} + +// searchFirstLargeEnoughGap returns the first gap having at least minSize length +// in the subtree rooted by n. If not found, return a terminal gap iterator. +func (n *FrameRefnode) searchFirstLargeEnoughGap(minSize uint64) FrameRefGapIterator { + if n.maxGap.Get() < minSize { + return FrameRefGapIterator{} + } + if n.hasChildren { + for i := 0; i <= n.nrSegments; i++ { + if largeEnoughGap := n.children[i].searchFirstLargeEnoughGap(minSize); largeEnoughGap.Ok() { + return largeEnoughGap + } + } + } else { + for i := 0; i <= n.nrSegments; i++ { + currentGap := FrameRefGapIterator{n, i} + if currentGap.Range().Length() >= minSize { + return currentGap + } + } + } + panic(fmt.Sprintf("invalid maxGap in %v", n)) +} + +// searchLastLargeEnoughGap returns the last gap having at least minSize length +// in the subtree rooted by n. If not found, return a terminal gap iterator. +func (n *FrameRefnode) searchLastLargeEnoughGap(minSize uint64) FrameRefGapIterator { + if n.maxGap.Get() < minSize { + return FrameRefGapIterator{} + } + if n.hasChildren { + for i := n.nrSegments; i >= 0; i-- { + if largeEnoughGap := n.children[i].searchLastLargeEnoughGap(minSize); largeEnoughGap.Ok() { + return largeEnoughGap + } + } + } else { + for i := n.nrSegments; i >= 0; i-- { + currentGap := FrameRefGapIterator{n, i} + if currentGap.Range().Length() >= minSize { + return currentGap + } + } + } + panic(fmt.Sprintf("invalid maxGap in %v", n)) +} + +// A Iterator is conceptually one of: +// +// - A pointer to a segment in a set; or +// +// - A terminal iterator, which is a sentinel indicating that the end of +// iteration has been reached. +// +// Iterators are copyable values and are meaningfully equality-comparable. The +// zero value of Iterator is a terminal iterator. +// +// Unless otherwise specified, any mutation of a set invalidates all existing +// iterators into the set. +type FrameRefIterator struct { + // node is the node containing the iterated segment. If the iterator is + // terminal, node is nil. + node *FrameRefnode + + // index is the index of the segment in node.keys/values. + index int +} + +// Ok returns true if the iterator is not terminal. All other methods are only +// valid for non-terminal iterators. +func (seg FrameRefIterator) Ok() bool { + return seg.node != nil +} + +// Range returns the iterated segment's range key. +func (seg FrameRefIterator) Range() __generics_imported0.FileRange { + return seg.node.keys[seg.index] +} + +// Start is equivalent to Range().Start, but should be preferred if only the +// start of the range is needed. +func (seg FrameRefIterator) Start() uint64 { + return seg.node.keys[seg.index].Start +} + +// End is equivalent to Range().End, but should be preferred if only the end of +// the range is needed. +func (seg FrameRefIterator) End() uint64 { + return seg.node.keys[seg.index].End +} + +// SetRangeUnchecked mutates the iterated segment's range key. This operation +// does not invalidate any iterators. +// +// Preconditions: +// * r.Length() > 0. +// * The new range must not overlap an existing one: +// * If seg.NextSegment().Ok(), then r.end <= seg.NextSegment().Start(). +// * If seg.PrevSegment().Ok(), then r.start >= seg.PrevSegment().End(). +func (seg FrameRefIterator) SetRangeUnchecked(r __generics_imported0.FileRange) { + seg.node.keys[seg.index] = r +} + +// SetRange mutates the iterated segment's range key. If the new range would +// cause the iterated segment to overlap another segment, or if the new range +// is invalid, SetRange panics. This operation does not invalidate any +// iterators. +func (seg FrameRefIterator) SetRange(r __generics_imported0.FileRange) { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + if prev := seg.PrevSegment(); prev.Ok() && r.Start < prev.End() { + panic(fmt.Sprintf("new segment range %v overlaps segment range %v", r, prev.Range())) + } + if next := seg.NextSegment(); next.Ok() && r.End > next.Start() { + panic(fmt.Sprintf("new segment range %v overlaps segment range %v", r, next.Range())) + } + seg.SetRangeUnchecked(r) +} + +// SetStartUnchecked mutates the iterated segment's start. This operation does +// not invalidate any iterators. +// +// Preconditions: The new start must be valid: +// * start < seg.End() +// * If seg.PrevSegment().Ok(), then start >= seg.PrevSegment().End(). +func (seg FrameRefIterator) SetStartUnchecked(start uint64) { + seg.node.keys[seg.index].Start = start +} + +// SetStart mutates the iterated segment's start. If the new start value would +// cause the iterated segment to overlap another segment, or would result in an +// invalid range, SetStart panics. This operation does not invalidate any +// iterators. +func (seg FrameRefIterator) SetStart(start uint64) { + if start >= seg.End() { + panic(fmt.Sprintf("new start %v would invalidate segment range %v", start, seg.Range())) + } + if prev := seg.PrevSegment(); prev.Ok() && start < prev.End() { + panic(fmt.Sprintf("new start %v would cause segment range %v to overlap segment range %v", start, seg.Range(), prev.Range())) + } + seg.SetStartUnchecked(start) +} + +// SetEndUnchecked mutates the iterated segment's end. This operation does not +// invalidate any iterators. +// +// Preconditions: The new end must be valid: +// * end > seg.Start(). +// * If seg.NextSegment().Ok(), then end <= seg.NextSegment().Start(). +func (seg FrameRefIterator) SetEndUnchecked(end uint64) { + seg.node.keys[seg.index].End = end +} + +// SetEnd mutates the iterated segment's end. If the new end value would cause +// the iterated segment to overlap another segment, or would result in an +// invalid range, SetEnd panics. This operation does not invalidate any +// iterators. +func (seg FrameRefIterator) SetEnd(end uint64) { + if end <= seg.Start() { + panic(fmt.Sprintf("new end %v would invalidate segment range %v", end, seg.Range())) + } + if next := seg.NextSegment(); next.Ok() && end > next.Start() { + panic(fmt.Sprintf("new end %v would cause segment range %v to overlap segment range %v", end, seg.Range(), next.Range())) + } + seg.SetEndUnchecked(end) +} + +// Value returns a copy of the iterated segment's value. +func (seg FrameRefIterator) Value() uint64 { + return seg.node.values[seg.index] +} + +// ValuePtr returns a pointer to the iterated segment's value. The pointer is +// invalidated if the iterator is invalidated. This operation does not +// invalidate any iterators. +func (seg FrameRefIterator) ValuePtr() *uint64 { + return &seg.node.values[seg.index] +} + +// SetValue mutates the iterated segment's value. This operation does not +// invalidate any iterators. +func (seg FrameRefIterator) SetValue(val uint64) { + seg.node.values[seg.index] = val +} + +// PrevSegment returns the iterated segment's predecessor. If there is no +// preceding segment, PrevSegment returns a terminal iterator. +func (seg FrameRefIterator) PrevSegment() FrameRefIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index].lastSegment() + } + if seg.index > 0 { + return FrameRefIterator{seg.node, seg.index - 1} + } + if seg.node.parent == nil { + return FrameRefIterator{} + } + return FrameRefsegmentBeforePosition(seg.node.parent, seg.node.parentIndex) +} + +// NextSegment returns the iterated segment's successor. If there is no +// succeeding segment, NextSegment returns a terminal iterator. +func (seg FrameRefIterator) NextSegment() FrameRefIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index+1].firstSegment() + } + if seg.index < seg.node.nrSegments-1 { + return FrameRefIterator{seg.node, seg.index + 1} + } + if seg.node.parent == nil { + return FrameRefIterator{} + } + return FrameRefsegmentAfterPosition(seg.node.parent, seg.node.parentIndex) +} + +// PrevGap returns the gap immediately before the iterated segment. +func (seg FrameRefIterator) PrevGap() FrameRefGapIterator { + if seg.node.hasChildren { + + return seg.node.children[seg.index].lastSegment().NextGap() + } + return FrameRefGapIterator{seg.node, seg.index} +} + +// NextGap returns the gap immediately after the iterated segment. +func (seg FrameRefIterator) NextGap() FrameRefGapIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index+1].firstSegment().PrevGap() + } + return FrameRefGapIterator{seg.node, seg.index + 1} +} + +// PrevNonEmpty returns the iterated segment's predecessor if it is adjacent, +// or the gap before the iterated segment otherwise. If seg.Start() == +// Functions.MinKey(), PrevNonEmpty will return two terminal iterators. +// Otherwise, exactly one of the iterators returned by PrevNonEmpty will be +// non-terminal. +func (seg FrameRefIterator) PrevNonEmpty() (FrameRefIterator, FrameRefGapIterator) { + gap := seg.PrevGap() + if gap.Range().Length() != 0 { + return FrameRefIterator{}, gap + } + return gap.PrevSegment(), FrameRefGapIterator{} +} + +// NextNonEmpty returns the iterated segment's successor if it is adjacent, or +// the gap after the iterated segment otherwise. If seg.End() == +// Functions.MaxKey(), NextNonEmpty will return two terminal iterators. +// Otherwise, exactly one of the iterators returned by NextNonEmpty will be +// non-terminal. +func (seg FrameRefIterator) NextNonEmpty() (FrameRefIterator, FrameRefGapIterator) { + gap := seg.NextGap() + if gap.Range().Length() != 0 { + return FrameRefIterator{}, gap + } + return gap.NextSegment(), FrameRefGapIterator{} +} + +// A GapIterator is conceptually one of: +// +// - A pointer to a position between two segments, before the first segment, or +// after the last segment in a set, called a *gap*; or +// +// - A terminal iterator, which is a sentinel indicating that the end of +// iteration has been reached. +// +// Note that the gap between two adjacent segments exists (iterators to it are +// non-terminal), but has a length of zero. GapIterator.IsEmpty returns true +// for such gaps. An empty set contains a single gap, spanning the entire range +// of the set's keys. +// +// GapIterators are copyable values and are meaningfully equality-comparable. +// The zero value of GapIterator is a terminal iterator. +// +// Unless otherwise specified, any mutation of a set invalidates all existing +// iterators into the set. +type FrameRefGapIterator struct { + // The representation of a GapIterator is identical to that of an Iterator, + // except that index corresponds to positions between segments in the same + // way as for node.children (see comment for node.nrSegments). + node *FrameRefnode + index int +} + +// Ok returns true if the iterator is not terminal. All other methods are only +// valid for non-terminal iterators. +func (gap FrameRefGapIterator) Ok() bool { + return gap.node != nil +} + +// Range returns the range spanned by the iterated gap. +func (gap FrameRefGapIterator) Range() __generics_imported0.FileRange { + return __generics_imported0.FileRange{gap.Start(), gap.End()} +} + +// Start is equivalent to Range().Start, but should be preferred if only the +// start of the range is needed. +func (gap FrameRefGapIterator) Start() uint64 { + if ps := gap.PrevSegment(); ps.Ok() { + return ps.End() + } + return FrameRefSetFunctions{}.MinKey() +} + +// End is equivalent to Range().End, but should be preferred if only the end of +// the range is needed. +func (gap FrameRefGapIterator) End() uint64 { + if ns := gap.NextSegment(); ns.Ok() { + return ns.Start() + } + return FrameRefSetFunctions{}.MaxKey() +} + +// IsEmpty returns true if the iterated gap is empty (that is, the "gap" is +// between two adjacent segments.) +func (gap FrameRefGapIterator) IsEmpty() bool { + return gap.Range().Length() == 0 +} + +// PrevSegment returns the segment immediately before the iterated gap. If no +// such segment exists, PrevSegment returns a terminal iterator. +func (gap FrameRefGapIterator) PrevSegment() FrameRefIterator { + return FrameRefsegmentBeforePosition(gap.node, gap.index) +} + +// NextSegment returns the segment immediately after the iterated gap. If no +// such segment exists, NextSegment returns a terminal iterator. +func (gap FrameRefGapIterator) NextSegment() FrameRefIterator { + return FrameRefsegmentAfterPosition(gap.node, gap.index) +} + +// PrevGap returns the iterated gap's predecessor. If no such gap exists, +// PrevGap returns a terminal iterator. +func (gap FrameRefGapIterator) PrevGap() FrameRefGapIterator { + seg := gap.PrevSegment() + if !seg.Ok() { + return FrameRefGapIterator{} + } + return seg.PrevGap() +} + +// NextGap returns the iterated gap's successor. If no such gap exists, NextGap +// returns a terminal iterator. +func (gap FrameRefGapIterator) NextGap() FrameRefGapIterator { + seg := gap.NextSegment() + if !seg.Ok() { + return FrameRefGapIterator{} + } + return seg.NextGap() +} + +// NextLargeEnoughGap returns the iterated gap's first next gap with larger +// length than minSize. If not found, return a terminal gap iterator (does NOT +// include this gap itself). +// +// Precondition: trackGaps must be 1. +func (gap FrameRefGapIterator) NextLargeEnoughGap(minSize uint64) FrameRefGapIterator { + if FrameReftrackGaps != 1 { + panic("set is not tracking gaps") + } + if gap.node != nil && gap.node.hasChildren && gap.index == gap.node.nrSegments { + + gap.node = gap.NextSegment().node + gap.index = 0 + return gap.nextLargeEnoughGapHelper(minSize) + } + return gap.nextLargeEnoughGapHelper(minSize) +} + +// nextLargeEnoughGapHelper is the helper function used by NextLargeEnoughGap +// to do the real recursions. +// +// Preconditions: gap is NOT the trailing gap of a non-leaf node. +func (gap FrameRefGapIterator) nextLargeEnoughGapHelper(minSize uint64) FrameRefGapIterator { + + for gap.node != nil && + (gap.node.maxGap.Get() < minSize || (!gap.node.hasChildren && gap.index == gap.node.nrSegments)) { + gap.node, gap.index = gap.node.parent, gap.node.parentIndex + } + + if gap.node == nil { + return FrameRefGapIterator{} + } + + gap.index++ + for gap.index <= gap.node.nrSegments { + if gap.node.hasChildren { + if largeEnoughGap := gap.node.children[gap.index].searchFirstLargeEnoughGap(minSize); largeEnoughGap.Ok() { + return largeEnoughGap + } + } else { + if gap.Range().Length() >= minSize { + return gap + } + } + gap.index++ + } + gap.node, gap.index = gap.node.parent, gap.node.parentIndex + if gap.node != nil && gap.index == gap.node.nrSegments { + + gap.node, gap.index = gap.node.parent, gap.node.parentIndex + } + return gap.nextLargeEnoughGapHelper(minSize) +} + +// PrevLargeEnoughGap returns the iterated gap's first prev gap with larger or +// equal length than minSize. If not found, return a terminal gap iterator +// (does NOT include this gap itself). +// +// Precondition: trackGaps must be 1. +func (gap FrameRefGapIterator) PrevLargeEnoughGap(minSize uint64) FrameRefGapIterator { + if FrameReftrackGaps != 1 { + panic("set is not tracking gaps") + } + if gap.node != nil && gap.node.hasChildren && gap.index == 0 { + + gap.node = gap.PrevSegment().node + gap.index = gap.node.nrSegments + return gap.prevLargeEnoughGapHelper(minSize) + } + return gap.prevLargeEnoughGapHelper(minSize) +} + +// prevLargeEnoughGapHelper is the helper function used by PrevLargeEnoughGap +// to do the real recursions. +// +// Preconditions: gap is NOT the first gap of a non-leaf node. +func (gap FrameRefGapIterator) prevLargeEnoughGapHelper(minSize uint64) FrameRefGapIterator { + + for gap.node != nil && + (gap.node.maxGap.Get() < minSize || (!gap.node.hasChildren && gap.index == 0)) { + gap.node, gap.index = gap.node.parent, gap.node.parentIndex + } + + if gap.node == nil { + return FrameRefGapIterator{} + } + + gap.index-- + for gap.index >= 0 { + if gap.node.hasChildren { + if largeEnoughGap := gap.node.children[gap.index].searchLastLargeEnoughGap(minSize); largeEnoughGap.Ok() { + return largeEnoughGap + } + } else { + if gap.Range().Length() >= minSize { + return gap + } + } + gap.index-- + } + gap.node, gap.index = gap.node.parent, gap.node.parentIndex + if gap.node != nil && gap.index == 0 { + + gap.node, gap.index = gap.node.parent, gap.node.parentIndex + } + return gap.prevLargeEnoughGapHelper(minSize) +} + +// segmentBeforePosition returns the predecessor segment of the position given +// by n.children[i], which may or may not contain a child. If no such segment +// exists, segmentBeforePosition returns a terminal iterator. +func FrameRefsegmentBeforePosition(n *FrameRefnode, i int) FrameRefIterator { + for i == 0 { + if n.parent == nil { + return FrameRefIterator{} + } + n, i = n.parent, n.parentIndex + } + return FrameRefIterator{n, i - 1} +} + +// segmentAfterPosition returns the successor segment of the position given by +// n.children[i], which may or may not contain a child. If no such segment +// exists, segmentAfterPosition returns a terminal iterator. +func FrameRefsegmentAfterPosition(n *FrameRefnode, i int) FrameRefIterator { + for i == n.nrSegments { + if n.parent == nil { + return FrameRefIterator{} + } + n, i = n.parent, n.parentIndex + } + return FrameRefIterator{n, i} +} + +func FrameRefzeroValueSlice(slice []uint64) { + + for i := range slice { + FrameRefSetFunctions{}.ClearValue(&slice[i]) + } +} + +func FrameRefzeroNodeSlice(slice []*FrameRefnode) { + for i := range slice { + slice[i] = nil + } +} + +// String stringifies a Set for debugging. +func (s *FrameRefSet) String() string { + return s.root.String() +} + +// String stringifies a node (and all of its children) for debugging. +func (n *FrameRefnode) String() string { + var buf bytes.Buffer + n.writeDebugString(&buf, "") + return buf.String() +} + +func (n *FrameRefnode) writeDebugString(buf *bytes.Buffer, prefix string) { + if n.hasChildren != (n.nrSegments > 0 && n.children[0] != nil) { + buf.WriteString(prefix) + buf.WriteString(fmt.Sprintf("WARNING: inconsistent value of hasChildren: got %v, want %v\n", n.hasChildren, !n.hasChildren)) + } + for i := 0; i < n.nrSegments; i++ { + if child := n.children[i]; child != nil { + cprefix := fmt.Sprintf("%s- % 3d ", prefix, i) + if child.parent != n || child.parentIndex != i { + buf.WriteString(cprefix) + buf.WriteString(fmt.Sprintf("WARNING: inconsistent linkage to parent: got (%p, %d), want (%p, %d)\n", child.parent, child.parentIndex, n, i)) + } + child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, i)) + } + buf.WriteString(prefix) + if n.hasChildren { + if FrameReftrackGaps != 0 { + buf.WriteString(fmt.Sprintf("- % 3d: %v => %v, maxGap: %d\n", i, n.keys[i], n.values[i], n.maxGap.Get())) + } else { + buf.WriteString(fmt.Sprintf("- % 3d: %v => %v\n", i, n.keys[i], n.values[i])) + } + } else { + buf.WriteString(fmt.Sprintf("- % 3d: %v => %v\n", i, n.keys[i], n.values[i])) + } + } + if child := n.children[n.nrSegments]; child != nil { + child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, n.nrSegments)) + } +} + +// SegmentDataSlices represents segments from a set as slices of start, end, and +// values. SegmentDataSlices is primarily used as an intermediate representation +// for save/restore and the layout here is optimized for that. +// +// +stateify savable +type FrameRefSegmentDataSlices struct { + Start []uint64 + End []uint64 + Values []uint64 +} + +// ExportSortedSlices returns a copy of all segments in the given set, in +// ascending key order. +func (s *FrameRefSet) ExportSortedSlices() *FrameRefSegmentDataSlices { + var sds FrameRefSegmentDataSlices + for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + sds.Start = append(sds.Start, seg.Start()) + sds.End = append(sds.End, seg.End()) + sds.Values = append(sds.Values, seg.Value()) + } + sds.Start = sds.Start[:len(sds.Start):len(sds.Start)] + sds.End = sds.End[:len(sds.End):len(sds.End)] + sds.Values = sds.Values[:len(sds.Values):len(sds.Values)] + return &sds +} + +// ImportSortedSlices initializes the given set from the given slice. +// +// Preconditions: +// * s must be empty. +// * sds must represent a valid set (the segments in sds must have valid +// lengths that do not overlap). +// * The segments in sds must be sorted in ascending key order. +func (s *FrameRefSet) ImportSortedSlices(sds *FrameRefSegmentDataSlices) error { + if !s.IsEmpty() { + return fmt.Errorf("cannot import into non-empty set %v", s) + } + gap := s.FirstGap() + for i := range sds.Start { + r := __generics_imported0.FileRange{sds.Start[i], sds.End[i]} + if !gap.Range().IsSupersetOf(r) { + return fmt.Errorf("segment overlaps a preceding segment or is incorrectly sorted: [%d, %d) => %v", sds.Start[i], sds.End[i], sds.Values[i]) + } + gap = s.InsertWithoutMerging(gap, r, sds.Values[i]).NextGap() + } + return nil +} + +// segmentTestCheck returns an error if s is incorrectly sorted, does not +// contain exactly expectedSegments segments, or contains a segment which +// fails the passed check. +// +// This should be used only for testing, and has been added to this package for +// templating convenience. +func (s *FrameRefSet) segmentTestCheck(expectedSegments int, segFunc func(int, __generics_imported0.FileRange, uint64) error) error { + havePrev := false + prev := uint64(0) + nrSegments := 0 + for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + next := seg.Start() + if havePrev && prev >= next { + return fmt.Errorf("incorrect order: key %d (segment %d) >= key %d (segment %d)", prev, nrSegments-1, next, nrSegments) + } + if segFunc != nil { + if err := segFunc(nrSegments, seg.Range(), seg.Value()); err != nil { + return err + } + } + prev = next + havePrev = true + nrSegments++ + } + if nrSegments != expectedSegments { + return fmt.Errorf("incorrect number of segments: got %d, wanted %d", nrSegments, expectedSegments) + } + return nil +} + +// countSegments counts the number of segments in the set. +// +// Similar to Check, this should only be used for testing. +func (s *FrameRefSet) countSegments() (segments int) { + for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + segments++ + } + return segments +} +func (s *FrameRefSet) saveRoot() *FrameRefSegmentDataSlices { + return s.ExportSortedSlices() +} + +func (s *FrameRefSet) loadRoot(sds *FrameRefSegmentDataSlices) { + if err := s.ImportSortedSlices(sds); err != nil { + panic(err) + } +} diff --git a/pkg/sentry/fs/fsutil/fsutil_impl_state_autogen.go b/pkg/sentry/fs/fsutil/fsutil_impl_state_autogen.go new file mode 100644 index 000000000..9671c2349 --- /dev/null +++ b/pkg/sentry/fs/fsutil/fsutil_impl_state_autogen.go @@ -0,0 +1,331 @@ +// automatically generated by stateify. + +package fsutil + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (s *DirtySet) StateTypeName() string { + return "pkg/sentry/fs/fsutil.DirtySet" +} + +func (s *DirtySet) StateFields() []string { + return []string{ + "root", + } +} + +func (s *DirtySet) beforeSave() {} + +// +checklocksignore +func (s *DirtySet) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + var rootValue *DirtySegmentDataSlices + rootValue = s.saveRoot() + stateSinkObject.SaveValue(0, rootValue) +} + +func (s *DirtySet) afterLoad() {} + +// +checklocksignore +func (s *DirtySet) StateLoad(stateSourceObject state.Source) { + stateSourceObject.LoadValue(0, new(*DirtySegmentDataSlices), func(y interface{}) { s.loadRoot(y.(*DirtySegmentDataSlices)) }) +} + +func (n *Dirtynode) StateTypeName() string { + return "pkg/sentry/fs/fsutil.Dirtynode" +} + +func (n *Dirtynode) StateFields() []string { + return []string{ + "nrSegments", + "parent", + "parentIndex", + "hasChildren", + "maxGap", + "keys", + "values", + "children", + } +} + +func (n *Dirtynode) beforeSave() {} + +// +checklocksignore +func (n *Dirtynode) StateSave(stateSinkObject state.Sink) { + n.beforeSave() + stateSinkObject.Save(0, &n.nrSegments) + stateSinkObject.Save(1, &n.parent) + stateSinkObject.Save(2, &n.parentIndex) + stateSinkObject.Save(3, &n.hasChildren) + stateSinkObject.Save(4, &n.maxGap) + stateSinkObject.Save(5, &n.keys) + stateSinkObject.Save(6, &n.values) + stateSinkObject.Save(7, &n.children) +} + +func (n *Dirtynode) afterLoad() {} + +// +checklocksignore +func (n *Dirtynode) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &n.nrSegments) + stateSourceObject.Load(1, &n.parent) + stateSourceObject.Load(2, &n.parentIndex) + stateSourceObject.Load(3, &n.hasChildren) + stateSourceObject.Load(4, &n.maxGap) + stateSourceObject.Load(5, &n.keys) + stateSourceObject.Load(6, &n.values) + stateSourceObject.Load(7, &n.children) +} + +func (d *DirtySegmentDataSlices) StateTypeName() string { + return "pkg/sentry/fs/fsutil.DirtySegmentDataSlices" +} + +func (d *DirtySegmentDataSlices) StateFields() []string { + return []string{ + "Start", + "End", + "Values", + } +} + +func (d *DirtySegmentDataSlices) beforeSave() {} + +// +checklocksignore +func (d *DirtySegmentDataSlices) StateSave(stateSinkObject state.Sink) { + d.beforeSave() + stateSinkObject.Save(0, &d.Start) + stateSinkObject.Save(1, &d.End) + stateSinkObject.Save(2, &d.Values) +} + +func (d *DirtySegmentDataSlices) afterLoad() {} + +// +checklocksignore +func (d *DirtySegmentDataSlices) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &d.Start) + stateSourceObject.Load(1, &d.End) + stateSourceObject.Load(2, &d.Values) +} + +func (s *FileRangeSet) StateTypeName() string { + return "pkg/sentry/fs/fsutil.FileRangeSet" +} + +func (s *FileRangeSet) StateFields() []string { + return []string{ + "root", + } +} + +func (s *FileRangeSet) beforeSave() {} + +// +checklocksignore +func (s *FileRangeSet) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + var rootValue *FileRangeSegmentDataSlices + rootValue = s.saveRoot() + stateSinkObject.SaveValue(0, rootValue) +} + +func (s *FileRangeSet) afterLoad() {} + +// +checklocksignore +func (s *FileRangeSet) StateLoad(stateSourceObject state.Source) { + stateSourceObject.LoadValue(0, new(*FileRangeSegmentDataSlices), func(y interface{}) { s.loadRoot(y.(*FileRangeSegmentDataSlices)) }) +} + +func (n *FileRangenode) StateTypeName() string { + return "pkg/sentry/fs/fsutil.FileRangenode" +} + +func (n *FileRangenode) StateFields() []string { + return []string{ + "nrSegments", + "parent", + "parentIndex", + "hasChildren", + "maxGap", + "keys", + "values", + "children", + } +} + +func (n *FileRangenode) beforeSave() {} + +// +checklocksignore +func (n *FileRangenode) StateSave(stateSinkObject state.Sink) { + n.beforeSave() + stateSinkObject.Save(0, &n.nrSegments) + stateSinkObject.Save(1, &n.parent) + stateSinkObject.Save(2, &n.parentIndex) + stateSinkObject.Save(3, &n.hasChildren) + stateSinkObject.Save(4, &n.maxGap) + stateSinkObject.Save(5, &n.keys) + stateSinkObject.Save(6, &n.values) + stateSinkObject.Save(7, &n.children) +} + +func (n *FileRangenode) afterLoad() {} + +// +checklocksignore +func (n *FileRangenode) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &n.nrSegments) + stateSourceObject.Load(1, &n.parent) + stateSourceObject.Load(2, &n.parentIndex) + stateSourceObject.Load(3, &n.hasChildren) + stateSourceObject.Load(4, &n.maxGap) + stateSourceObject.Load(5, &n.keys) + stateSourceObject.Load(6, &n.values) + stateSourceObject.Load(7, &n.children) +} + +func (f *FileRangeSegmentDataSlices) StateTypeName() string { + return "pkg/sentry/fs/fsutil.FileRangeSegmentDataSlices" +} + +func (f *FileRangeSegmentDataSlices) StateFields() []string { + return []string{ + "Start", + "End", + "Values", + } +} + +func (f *FileRangeSegmentDataSlices) beforeSave() {} + +// +checklocksignore +func (f *FileRangeSegmentDataSlices) StateSave(stateSinkObject state.Sink) { + f.beforeSave() + stateSinkObject.Save(0, &f.Start) + stateSinkObject.Save(1, &f.End) + stateSinkObject.Save(2, &f.Values) +} + +func (f *FileRangeSegmentDataSlices) afterLoad() {} + +// +checklocksignore +func (f *FileRangeSegmentDataSlices) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &f.Start) + stateSourceObject.Load(1, &f.End) + stateSourceObject.Load(2, &f.Values) +} + +func (s *FrameRefSet) StateTypeName() string { + return "pkg/sentry/fs/fsutil.FrameRefSet" +} + +func (s *FrameRefSet) StateFields() []string { + return []string{ + "root", + } +} + +func (s *FrameRefSet) beforeSave() {} + +// +checklocksignore +func (s *FrameRefSet) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + var rootValue *FrameRefSegmentDataSlices + rootValue = s.saveRoot() + stateSinkObject.SaveValue(0, rootValue) +} + +func (s *FrameRefSet) afterLoad() {} + +// +checklocksignore +func (s *FrameRefSet) StateLoad(stateSourceObject state.Source) { + stateSourceObject.LoadValue(0, new(*FrameRefSegmentDataSlices), func(y interface{}) { s.loadRoot(y.(*FrameRefSegmentDataSlices)) }) +} + +func (n *FrameRefnode) StateTypeName() string { + return "pkg/sentry/fs/fsutil.FrameRefnode" +} + +func (n *FrameRefnode) StateFields() []string { + return []string{ + "nrSegments", + "parent", + "parentIndex", + "hasChildren", + "maxGap", + "keys", + "values", + "children", + } +} + +func (n *FrameRefnode) beforeSave() {} + +// +checklocksignore +func (n *FrameRefnode) StateSave(stateSinkObject state.Sink) { + n.beforeSave() + stateSinkObject.Save(0, &n.nrSegments) + stateSinkObject.Save(1, &n.parent) + stateSinkObject.Save(2, &n.parentIndex) + stateSinkObject.Save(3, &n.hasChildren) + stateSinkObject.Save(4, &n.maxGap) + stateSinkObject.Save(5, &n.keys) + stateSinkObject.Save(6, &n.values) + stateSinkObject.Save(7, &n.children) +} + +func (n *FrameRefnode) afterLoad() {} + +// +checklocksignore +func (n *FrameRefnode) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &n.nrSegments) + stateSourceObject.Load(1, &n.parent) + stateSourceObject.Load(2, &n.parentIndex) + stateSourceObject.Load(3, &n.hasChildren) + stateSourceObject.Load(4, &n.maxGap) + stateSourceObject.Load(5, &n.keys) + stateSourceObject.Load(6, &n.values) + stateSourceObject.Load(7, &n.children) +} + +func (f *FrameRefSegmentDataSlices) StateTypeName() string { + return "pkg/sentry/fs/fsutil.FrameRefSegmentDataSlices" +} + +func (f *FrameRefSegmentDataSlices) StateFields() []string { + return []string{ + "Start", + "End", + "Values", + } +} + +func (f *FrameRefSegmentDataSlices) beforeSave() {} + +// +checklocksignore +func (f *FrameRefSegmentDataSlices) StateSave(stateSinkObject state.Sink) { + f.beforeSave() + stateSinkObject.Save(0, &f.Start) + stateSinkObject.Save(1, &f.End) + stateSinkObject.Save(2, &f.Values) +} + +func (f *FrameRefSegmentDataSlices) afterLoad() {} + +// +checklocksignore +func (f *FrameRefSegmentDataSlices) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &f.Start) + stateSourceObject.Load(1, &f.End) + stateSourceObject.Load(2, &f.Values) +} + +func init() { + state.Register((*DirtySet)(nil)) + state.Register((*Dirtynode)(nil)) + state.Register((*DirtySegmentDataSlices)(nil)) + state.Register((*FileRangeSet)(nil)) + state.Register((*FileRangenode)(nil)) + state.Register((*FileRangeSegmentDataSlices)(nil)) + state.Register((*FrameRefSet)(nil)) + state.Register((*FrameRefnode)(nil)) + state.Register((*FrameRefSegmentDataSlices)(nil)) +} diff --git a/pkg/sentry/fs/fsutil/fsutil_state_autogen.go b/pkg/sentry/fs/fsutil/fsutil_state_autogen.go new file mode 100644 index 000000000..748eb9272 --- /dev/null +++ b/pkg/sentry/fs/fsutil/fsutil_state_autogen.go @@ -0,0 +1,411 @@ +// automatically generated by stateify. + +package fsutil + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (d *DirtyInfo) StateTypeName() string { + return "pkg/sentry/fs/fsutil.DirtyInfo" +} + +func (d *DirtyInfo) StateFields() []string { + return []string{ + "Keep", + } +} + +func (d *DirtyInfo) beforeSave() {} + +// +checklocksignore +func (d *DirtyInfo) StateSave(stateSinkObject state.Sink) { + d.beforeSave() + stateSinkObject.Save(0, &d.Keep) +} + +func (d *DirtyInfo) afterLoad() {} + +// +checklocksignore +func (d *DirtyInfo) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &d.Keep) +} + +func (sdfo *StaticDirFileOperations) StateTypeName() string { + return "pkg/sentry/fs/fsutil.StaticDirFileOperations" +} + +func (sdfo *StaticDirFileOperations) StateFields() []string { + return []string{ + "dentryMap", + "dirCursor", + } +} + +func (sdfo *StaticDirFileOperations) beforeSave() {} + +// +checklocksignore +func (sdfo *StaticDirFileOperations) StateSave(stateSinkObject state.Sink) { + sdfo.beforeSave() + stateSinkObject.Save(0, &sdfo.dentryMap) + stateSinkObject.Save(1, &sdfo.dirCursor) +} + +func (sdfo *StaticDirFileOperations) afterLoad() {} + +// +checklocksignore +func (sdfo *StaticDirFileOperations) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &sdfo.dentryMap) + stateSourceObject.Load(1, &sdfo.dirCursor) +} + +func (n *NoReadWriteFile) StateTypeName() string { + return "pkg/sentry/fs/fsutil.NoReadWriteFile" +} + +func (n *NoReadWriteFile) StateFields() []string { + return []string{} +} + +func (n *NoReadWriteFile) beforeSave() {} + +// +checklocksignore +func (n *NoReadWriteFile) StateSave(stateSinkObject state.Sink) { + n.beforeSave() +} + +func (n *NoReadWriteFile) afterLoad() {} + +// +checklocksignore +func (n *NoReadWriteFile) StateLoad(stateSourceObject state.Source) { +} + +func (scr *FileStaticContentReader) StateTypeName() string { + return "pkg/sentry/fs/fsutil.FileStaticContentReader" +} + +func (scr *FileStaticContentReader) StateFields() []string { + return []string{ + "content", + } +} + +func (scr *FileStaticContentReader) beforeSave() {} + +// +checklocksignore +func (scr *FileStaticContentReader) StateSave(stateSinkObject state.Sink) { + scr.beforeSave() + stateSinkObject.Save(0, &scr.content) +} + +func (scr *FileStaticContentReader) afterLoad() {} + +// +checklocksignore +func (scr *FileStaticContentReader) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &scr.content) +} + +func (f *HostFileMapper) StateTypeName() string { + return "pkg/sentry/fs/fsutil.HostFileMapper" +} + +func (f *HostFileMapper) StateFields() []string { + return []string{ + "refs", + } +} + +func (f *HostFileMapper) beforeSave() {} + +// +checklocksignore +func (f *HostFileMapper) StateSave(stateSinkObject state.Sink) { + f.beforeSave() + stateSinkObject.Save(0, &f.refs) +} + +// +checklocksignore +func (f *HostFileMapper) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &f.refs) + stateSourceObject.AfterLoad(f.afterLoad) +} + +func (h *HostMappable) StateTypeName() string { + return "pkg/sentry/fs/fsutil.HostMappable" +} + +func (h *HostMappable) StateFields() []string { + return []string{ + "hostFileMapper", + "backingFile", + "mappings", + } +} + +func (h *HostMappable) beforeSave() {} + +// +checklocksignore +func (h *HostMappable) StateSave(stateSinkObject state.Sink) { + h.beforeSave() + stateSinkObject.Save(0, &h.hostFileMapper) + stateSinkObject.Save(1, &h.backingFile) + stateSinkObject.Save(2, &h.mappings) +} + +func (h *HostMappable) afterLoad() {} + +// +checklocksignore +func (h *HostMappable) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &h.hostFileMapper) + stateSourceObject.Load(1, &h.backingFile) + stateSourceObject.Load(2, &h.mappings) +} + +func (s *SimpleFileInode) StateTypeName() string { + return "pkg/sentry/fs/fsutil.SimpleFileInode" +} + +func (s *SimpleFileInode) StateFields() []string { + return []string{ + "InodeSimpleAttributes", + } +} + +func (s *SimpleFileInode) beforeSave() {} + +// +checklocksignore +func (s *SimpleFileInode) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + stateSinkObject.Save(0, &s.InodeSimpleAttributes) +} + +func (s *SimpleFileInode) afterLoad() {} + +// +checklocksignore +func (s *SimpleFileInode) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.InodeSimpleAttributes) +} + +func (n *NoReadWriteFileInode) StateTypeName() string { + return "pkg/sentry/fs/fsutil.NoReadWriteFileInode" +} + +func (n *NoReadWriteFileInode) StateFields() []string { + return []string{ + "InodeSimpleAttributes", + } +} + +func (n *NoReadWriteFileInode) beforeSave() {} + +// +checklocksignore +func (n *NoReadWriteFileInode) StateSave(stateSinkObject state.Sink) { + n.beforeSave() + stateSinkObject.Save(0, &n.InodeSimpleAttributes) +} + +func (n *NoReadWriteFileInode) afterLoad() {} + +// +checklocksignore +func (n *NoReadWriteFileInode) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &n.InodeSimpleAttributes) +} + +func (i *InodeSimpleAttributes) StateTypeName() string { + return "pkg/sentry/fs/fsutil.InodeSimpleAttributes" +} + +func (i *InodeSimpleAttributes) StateFields() []string { + return []string{ + "fsType", + "unstable", + } +} + +func (i *InodeSimpleAttributes) beforeSave() {} + +// +checklocksignore +func (i *InodeSimpleAttributes) StateSave(stateSinkObject state.Sink) { + i.beforeSave() + stateSinkObject.Save(0, &i.fsType) + stateSinkObject.Save(1, &i.unstable) +} + +func (i *InodeSimpleAttributes) afterLoad() {} + +// +checklocksignore +func (i *InodeSimpleAttributes) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &i.fsType) + stateSourceObject.Load(1, &i.unstable) +} + +func (i *InodeSimpleExtendedAttributes) StateTypeName() string { + return "pkg/sentry/fs/fsutil.InodeSimpleExtendedAttributes" +} + +func (i *InodeSimpleExtendedAttributes) StateFields() []string { + return []string{ + "xattrs", + } +} + +func (i *InodeSimpleExtendedAttributes) beforeSave() {} + +// +checklocksignore +func (i *InodeSimpleExtendedAttributes) StateSave(stateSinkObject state.Sink) { + i.beforeSave() + stateSinkObject.Save(0, &i.xattrs) +} + +func (i *InodeSimpleExtendedAttributes) afterLoad() {} + +// +checklocksignore +func (i *InodeSimpleExtendedAttributes) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &i.xattrs) +} + +func (s *staticFile) StateTypeName() string { + return "pkg/sentry/fs/fsutil.staticFile" +} + +func (s *staticFile) StateFields() []string { + return []string{ + "FileStaticContentReader", + } +} + +func (s *staticFile) beforeSave() {} + +// +checklocksignore +func (s *staticFile) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + stateSinkObject.Save(0, &s.FileStaticContentReader) +} + +func (s *staticFile) afterLoad() {} + +// +checklocksignore +func (s *staticFile) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.FileStaticContentReader) +} + +func (i *InodeStaticFileGetter) StateTypeName() string { + return "pkg/sentry/fs/fsutil.InodeStaticFileGetter" +} + +func (i *InodeStaticFileGetter) StateFields() []string { + return []string{ + "Contents", + } +} + +func (i *InodeStaticFileGetter) beforeSave() {} + +// +checklocksignore +func (i *InodeStaticFileGetter) StateSave(stateSinkObject state.Sink) { + i.beforeSave() + stateSinkObject.Save(0, &i.Contents) +} + +func (i *InodeStaticFileGetter) afterLoad() {} + +// +checklocksignore +func (i *InodeStaticFileGetter) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &i.Contents) +} + +func (c *CachingInodeOperations) StateTypeName() string { + return "pkg/sentry/fs/fsutil.CachingInodeOperations" +} + +func (c *CachingInodeOperations) StateFields() []string { + return []string{ + "backingFile", + "mfp", + "opts", + "attr", + "dirtyAttr", + "mappings", + "cache", + "dirty", + "hostFileMapper", + "refs", + } +} + +func (c *CachingInodeOperations) beforeSave() {} + +// +checklocksignore +func (c *CachingInodeOperations) StateSave(stateSinkObject state.Sink) { + c.beforeSave() + stateSinkObject.Save(0, &c.backingFile) + stateSinkObject.Save(1, &c.mfp) + stateSinkObject.Save(2, &c.opts) + stateSinkObject.Save(3, &c.attr) + stateSinkObject.Save(4, &c.dirtyAttr) + stateSinkObject.Save(5, &c.mappings) + stateSinkObject.Save(6, &c.cache) + stateSinkObject.Save(7, &c.dirty) + stateSinkObject.Save(8, &c.hostFileMapper) + stateSinkObject.Save(9, &c.refs) +} + +func (c *CachingInodeOperations) afterLoad() {} + +// +checklocksignore +func (c *CachingInodeOperations) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &c.backingFile) + stateSourceObject.Load(1, &c.mfp) + stateSourceObject.Load(2, &c.opts) + stateSourceObject.Load(3, &c.attr) + stateSourceObject.Load(4, &c.dirtyAttr) + stateSourceObject.Load(5, &c.mappings) + stateSourceObject.Load(6, &c.cache) + stateSourceObject.Load(7, &c.dirty) + stateSourceObject.Load(8, &c.hostFileMapper) + stateSourceObject.Load(9, &c.refs) +} + +func (c *CachingInodeOperationsOptions) StateTypeName() string { + return "pkg/sentry/fs/fsutil.CachingInodeOperationsOptions" +} + +func (c *CachingInodeOperationsOptions) StateFields() []string { + return []string{ + "ForcePageCache", + "LimitHostFDTranslation", + } +} + +func (c *CachingInodeOperationsOptions) beforeSave() {} + +// +checklocksignore +func (c *CachingInodeOperationsOptions) StateSave(stateSinkObject state.Sink) { + c.beforeSave() + stateSinkObject.Save(0, &c.ForcePageCache) + stateSinkObject.Save(1, &c.LimitHostFDTranslation) +} + +func (c *CachingInodeOperationsOptions) afterLoad() {} + +// +checklocksignore +func (c *CachingInodeOperationsOptions) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &c.ForcePageCache) + stateSourceObject.Load(1, &c.LimitHostFDTranslation) +} + +func init() { + state.Register((*DirtyInfo)(nil)) + state.Register((*StaticDirFileOperations)(nil)) + state.Register((*NoReadWriteFile)(nil)) + state.Register((*FileStaticContentReader)(nil)) + state.Register((*HostFileMapper)(nil)) + state.Register((*HostMappable)(nil)) + state.Register((*SimpleFileInode)(nil)) + state.Register((*NoReadWriteFileInode)(nil)) + state.Register((*InodeSimpleAttributes)(nil)) + state.Register((*InodeSimpleExtendedAttributes)(nil)) + state.Register((*staticFile)(nil)) + state.Register((*InodeStaticFileGetter)(nil)) + state.Register((*CachingInodeOperations)(nil)) + state.Register((*CachingInodeOperationsOptions)(nil)) +} diff --git a/pkg/sentry/fs/fsutil/fsutil_unsafe_state_autogen.go b/pkg/sentry/fs/fsutil/fsutil_unsafe_state_autogen.go new file mode 100644 index 000000000..00b0994f6 --- /dev/null +++ b/pkg/sentry/fs/fsutil/fsutil_unsafe_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package fsutil diff --git a/pkg/sentry/fs/fsutil/inode_cached_test.go b/pkg/sentry/fs/fsutil/inode_cached_test.go deleted file mode 100644 index 25e76d9f2..000000000 --- a/pkg/sentry/fs/fsutil/inode_cached_test.go +++ /dev/null @@ -1,390 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package fsutil - -import ( - "bytes" - "io" - "testing" - - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/errors/linuxerr" - "gvisor.dev/gvisor/pkg/hostarch" - "gvisor.dev/gvisor/pkg/safemem" - "gvisor.dev/gvisor/pkg/sentry/contexttest" - "gvisor.dev/gvisor/pkg/sentry/fs" - ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" - "gvisor.dev/gvisor/pkg/sentry/memmap" - "gvisor.dev/gvisor/pkg/usermem" -) - -type noopBackingFile struct{} - -func (noopBackingFile) ReadToBlocksAt(ctx context.Context, dsts safemem.BlockSeq, offset uint64) (uint64, error) { - return dsts.NumBytes(), nil -} - -func (noopBackingFile) WriteFromBlocksAt(ctx context.Context, srcs safemem.BlockSeq, offset uint64) (uint64, error) { - return srcs.NumBytes(), nil -} - -func (noopBackingFile) SetMaskedAttributes(context.Context, fs.AttrMask, fs.UnstableAttr, bool) error { - return nil -} - -func (noopBackingFile) Sync(context.Context) error { - return nil -} - -func (noopBackingFile) FD() int { - return -1 -} - -func (noopBackingFile) Allocate(ctx context.Context, offset int64, length int64) error { - return nil -} - -func TestSetPermissions(t *testing.T) { - ctx := contexttest.Context(t) - - uattr := fs.WithCurrentTime(ctx, fs.UnstableAttr{ - Perms: fs.FilePermsFromMode(0444), - }) - iops := NewCachingInodeOperations(ctx, noopBackingFile{}, uattr, CachingInodeOperationsOptions{}) - defer iops.Release() - - perms := fs.FilePermsFromMode(0777) - if !iops.SetPermissions(ctx, nil, perms) { - t.Fatalf("SetPermissions failed, want success") - } - - // Did permissions change? - if iops.attr.Perms != perms { - t.Fatalf("got perms +%v, want +%v", iops.attr.Perms, perms) - } - - // Did status change time change? - if !iops.dirtyAttr.StatusChangeTime { - t.Fatalf("got status change time not dirty, want dirty") - } - if iops.attr.StatusChangeTime.Equal(uattr.StatusChangeTime) { - t.Fatalf("got status change time unchanged") - } -} - -func TestSetTimestamps(t *testing.T) { - ctx := contexttest.Context(t) - for _, test := range []struct { - desc string - ts fs.TimeSpec - wantChanged fs.AttrMask - }{ - { - desc: "noop", - ts: fs.TimeSpec{ - ATimeOmit: true, - MTimeOmit: true, - }, - wantChanged: fs.AttrMask{}, - }, - { - desc: "access time only", - ts: fs.TimeSpec{ - ATime: ktime.NowFromContext(ctx), - MTimeOmit: true, - }, - wantChanged: fs.AttrMask{ - AccessTime: true, - }, - }, - { - desc: "modification time only", - ts: fs.TimeSpec{ - ATimeOmit: true, - MTime: ktime.NowFromContext(ctx), - }, - wantChanged: fs.AttrMask{ - ModificationTime: true, - }, - }, - { - desc: "access and modification time", - ts: fs.TimeSpec{ - ATime: ktime.NowFromContext(ctx), - MTime: ktime.NowFromContext(ctx), - }, - wantChanged: fs.AttrMask{ - AccessTime: true, - ModificationTime: true, - }, - }, - { - desc: "system time access and modification time", - ts: fs.TimeSpec{ - ATimeSetSystemTime: true, - MTimeSetSystemTime: true, - }, - wantChanged: fs.AttrMask{ - AccessTime: true, - ModificationTime: true, - }, - }, - } { - t.Run(test.desc, func(t *testing.T) { - ctx := contexttest.Context(t) - - epoch := ktime.ZeroTime - uattr := fs.UnstableAttr{ - AccessTime: epoch, - ModificationTime: epoch, - StatusChangeTime: epoch, - } - iops := NewCachingInodeOperations(ctx, noopBackingFile{}, uattr, CachingInodeOperationsOptions{}) - defer iops.Release() - - if err := iops.SetTimestamps(ctx, nil, test.ts); err != nil { - t.Fatalf("SetTimestamps got error %v, want nil", err) - } - if test.wantChanged.AccessTime { - if !iops.attr.AccessTime.After(uattr.AccessTime) { - t.Fatalf("diritied access time did not advance, want %v > %v", iops.attr.AccessTime, uattr.AccessTime) - } - if !iops.dirtyAttr.StatusChangeTime { - t.Fatalf("dirty access time requires dirty status change time") - } - if !iops.attr.StatusChangeTime.After(uattr.StatusChangeTime) { - t.Fatalf("dirtied status change time did not advance") - } - } - if test.wantChanged.ModificationTime { - if !iops.attr.ModificationTime.After(uattr.ModificationTime) { - t.Fatalf("diritied modification time did not advance") - } - if !iops.dirtyAttr.StatusChangeTime { - t.Fatalf("dirty modification time requires dirty status change time") - } - if !iops.attr.StatusChangeTime.After(uattr.StatusChangeTime) { - t.Fatalf("dirtied status change time did not advance") - } - } - }) - } -} - -func TestTruncate(t *testing.T) { - ctx := contexttest.Context(t) - - uattr := fs.UnstableAttr{ - Size: 0, - } - iops := NewCachingInodeOperations(ctx, noopBackingFile{}, uattr, CachingInodeOperationsOptions{}) - defer iops.Release() - - if err := iops.Truncate(ctx, nil, uattr.Size); err != nil { - t.Fatalf("Truncate got error %v, want nil", err) - } - var size int64 = 4096 - if err := iops.Truncate(ctx, nil, size); err != nil { - t.Fatalf("Truncate got error %v, want nil", err) - } - if iops.attr.Size != size { - t.Fatalf("Truncate got %d, want %d", iops.attr.Size, size) - } - if !iops.dirtyAttr.ModificationTime || !iops.dirtyAttr.StatusChangeTime { - t.Fatalf("Truncate did not dirty modification and status change time") - } - if !iops.attr.ModificationTime.After(uattr.ModificationTime) { - t.Fatalf("dirtied modification time did not change") - } - if !iops.attr.StatusChangeTime.After(uattr.StatusChangeTime) { - t.Fatalf("dirtied status change time did not change") - } -} - -type sliceBackingFile struct { - data []byte -} - -func newSliceBackingFile(data []byte) *sliceBackingFile { - return &sliceBackingFile{data} -} - -func (f *sliceBackingFile) ReadToBlocksAt(ctx context.Context, dsts safemem.BlockSeq, offset uint64) (uint64, error) { - r := safemem.BlockSeqReader{safemem.BlockSeqOf(safemem.BlockFromSafeSlice(f.data)).DropFirst64(offset)} - return r.ReadToBlocks(dsts) -} - -func (f *sliceBackingFile) WriteFromBlocksAt(ctx context.Context, srcs safemem.BlockSeq, offset uint64) (uint64, error) { - w := safemem.BlockSeqWriter{safemem.BlockSeqOf(safemem.BlockFromSafeSlice(f.data)).DropFirst64(offset)} - return w.WriteFromBlocks(srcs) -} - -func (*sliceBackingFile) SetMaskedAttributes(context.Context, fs.AttrMask, fs.UnstableAttr, bool) error { - return nil -} - -func (*sliceBackingFile) Sync(context.Context) error { - return nil -} - -func (*sliceBackingFile) FD() int { - return -1 -} - -func (f *sliceBackingFile) Allocate(ctx context.Context, offset int64, length int64) error { - return linuxerr.EOPNOTSUPP -} - -type noopMappingSpace struct{} - -// Invalidate implements memmap.MappingSpace.Invalidate. -func (noopMappingSpace) Invalidate(ar hostarch.AddrRange, opts memmap.InvalidateOpts) { -} - -func anonInode(ctx context.Context) *fs.Inode { - return fs.NewInode(ctx, &SimpleFileInode{ - InodeSimpleAttributes: NewInodeSimpleAttributes(ctx, fs.FileOwnerFromContext(ctx), fs.FilePermissions{ - User: fs.PermMask{Read: true, Write: true}, - }, 0), - }, fs.NewPseudoMountSource(ctx), fs.StableAttr{ - Type: fs.Anonymous, - BlockSize: hostarch.PageSize, - }) -} - -func pagesOf(bs ...byte) []byte { - buf := make([]byte, 0, len(bs)*hostarch.PageSize) - for _, b := range bs { - buf = append(buf, bytes.Repeat([]byte{b}, hostarch.PageSize)...) - } - return buf -} - -func TestRead(t *testing.T) { - ctx := contexttest.Context(t) - - // Construct a 3-page file. - buf := pagesOf('a', 'b', 'c') - file := fs.NewFile(ctx, fs.NewDirent(ctx, anonInode(ctx), "anon"), fs.FileFlags{}, nil) - uattr := fs.UnstableAttr{ - Size: int64(len(buf)), - } - iops := NewCachingInodeOperations(ctx, newSliceBackingFile(buf), uattr, CachingInodeOperationsOptions{}) - defer iops.Release() - - // Expect the cache to be initially empty. - if cached := iops.cache.Span(); cached != 0 { - t.Errorf("Span got %d, want 0", cached) - } - - // Create a memory mapping of the second page (as CachingInodeOperations - // expects to only cache mapped pages), then call Translate to force it to - // be cached. - var ms noopMappingSpace - ar := hostarch.AddrRange{hostarch.PageSize, 2 * hostarch.PageSize} - if err := iops.AddMapping(ctx, ms, ar, hostarch.PageSize, true); err != nil { - t.Fatalf("AddMapping got %v, want nil", err) - } - mr := memmap.MappableRange{hostarch.PageSize, 2 * hostarch.PageSize} - if _, err := iops.Translate(ctx, mr, mr, hostarch.Read); err != nil { - t.Fatalf("Translate got %v, want nil", err) - } - if cached := iops.cache.Span(); cached != hostarch.PageSize { - t.Errorf("SpanRange got %d, want %d", cached, hostarch.PageSize) - } - - // Try to read 4 pages. The first and third pages should be read directly - // from the "file", the second page should be read from the cache, and only - // 3 pages (the size of the file) should be readable. - rbuf := make([]byte, 4*hostarch.PageSize) - dst := usermem.BytesIOSequence(rbuf) - n, err := iops.Read(ctx, file, dst, 0) - if n != 3*hostarch.PageSize || (err != nil && err != io.EOF) { - t.Fatalf("Read got (%d, %v), want (%d, nil or EOF)", n, err, 3*hostarch.PageSize) - } - rbuf = rbuf[:3*hostarch.PageSize] - - // Did we get the bytes we expect? - if !bytes.Equal(rbuf, buf) { - t.Errorf("Read back bytes %v, want %v", rbuf, buf) - } - - // Delete the memory mapping before iops.Release(). The cached page will - // either be evicted by ctx's pgalloc.MemoryFile, or dropped by - // iops.Release(). - iops.RemoveMapping(ctx, ms, ar, hostarch.PageSize, true) -} - -func TestWrite(t *testing.T) { - ctx := contexttest.Context(t) - - // Construct a 4-page file. - buf := pagesOf('a', 'b', 'c', 'd') - orig := append([]byte(nil), buf...) - inode := anonInode(ctx) - uattr := fs.UnstableAttr{ - Size: int64(len(buf)), - } - iops := NewCachingInodeOperations(ctx, newSliceBackingFile(buf), uattr, CachingInodeOperationsOptions{}) - defer iops.Release() - - // Expect the cache to be initially empty. - if cached := iops.cache.Span(); cached != 0 { - t.Errorf("Span got %d, want 0", cached) - } - - // Create a memory mapping of the second and third pages (as - // CachingInodeOperations expects to only cache mapped pages), then call - // Translate to force them to be cached. - var ms noopMappingSpace - ar := hostarch.AddrRange{hostarch.PageSize, 3 * hostarch.PageSize} - if err := iops.AddMapping(ctx, ms, ar, hostarch.PageSize, true); err != nil { - t.Fatalf("AddMapping got %v, want nil", err) - } - defer iops.RemoveMapping(ctx, ms, ar, hostarch.PageSize, true) - mr := memmap.MappableRange{hostarch.PageSize, 3 * hostarch.PageSize} - if _, err := iops.Translate(ctx, mr, mr, hostarch.Read); err != nil { - t.Fatalf("Translate got %v, want nil", err) - } - if cached := iops.cache.Span(); cached != 2*hostarch.PageSize { - t.Errorf("SpanRange got %d, want %d", cached, 2*hostarch.PageSize) - } - - // Write to the first 2 pages. - wbuf := pagesOf('e', 'f') - src := usermem.BytesIOSequence(wbuf) - n, err := iops.Write(ctx, src, 0) - if n != 2*hostarch.PageSize || err != nil { - t.Fatalf("Write got (%d, %v), want (%d, nil)", n, err, 2*hostarch.PageSize) - } - - // The first page should have been written directly, since it was not cached. - want := append([]byte(nil), orig...) - copy(want, pagesOf('e')) - if !bytes.Equal(buf, want) { - t.Errorf("File contents are %v, want %v", buf, want) - } - - // Sync back to the "backing file". - if err := iops.WriteOut(ctx, inode); err != nil { - t.Errorf("Sync got %v, want nil", err) - } - - // Now the second page should have been written as well. - copy(want[hostarch.PageSize:], pagesOf('f')) - if !bytes.Equal(buf, want) { - t.Errorf("File contents are %v, want %v", buf, want) - } -} diff --git a/pkg/sentry/fs/g3doc/.gitignore b/pkg/sentry/fs/g3doc/.gitignore deleted file mode 100644 index 2d19fc766..000000000 --- a/pkg/sentry/fs/g3doc/.gitignore +++ /dev/null @@ -1 +0,0 @@ -*.html diff --git a/pkg/sentry/fs/g3doc/fuse.md b/pkg/sentry/fs/g3doc/fuse.md deleted file mode 100644 index 05e043583..000000000 --- a/pkg/sentry/fs/g3doc/fuse.md +++ /dev/null @@ -1,360 +0,0 @@ -# Foreword - -This document describes an on-going project to support FUSE filesystems within -the sentry. This is intended to become the final documentation for this -subsystem, and is therefore written in the past tense. However FUSE support is -currently incomplete and the document will be updated as things progress. - -# FUSE: Filesystem in Userspace - -The sentry supports dispatching filesystem operations to a FUSE server, allowing -FUSE filesystem to be used with a sandbox. - -## Overview - -FUSE has two main components: - -1. A client kernel driver (canonically `fuse.ko` in Linux), which forwards - filesystem operations (usually initiated by syscalls) to the server. - -2. A server, which is a userspace daemon that implements the actual filesystem. - -The sentry implements the client component, which allows a server daemon running -within the sandbox to implement a filesystem within the sandbox. - -A FUSE filesystem is initialized with `mount(2)`, typically with the help of a -utility like `fusermount(1)`. Various mount options exist for establishing -ownership and access permissions on the filesystem, but the most important mount -option is a file descriptor used to establish communication between the client -and server. - -The FUSE device FD is obtained by opening `/dev/fuse`. During regular operation, -the client and server use the FUSE protocol described in `fuse(4)` to service -filesystem operations. See the "Protocol" section below for more information -about this protocol. The core of the sentry support for FUSE is the client-side -implementation of this protocol. - -## FUSE in the Sentry - -The sentry's FUSE client targets VFS2 and has the following components: - -- An implementation of `/dev/fuse`. - -- A VFS2 filesystem for mapping syscalls to FUSE ops. Since we're targeting - VFS2, one point of contention may be the lack of inodes in VFS2. We can - tentatively implement a kernfs-based filesystem to bridge the gap in APIs. - The kernfs base functionality can serve the role of the Linux inode cache - and, the filesystem can map VFS2 syscalls to kernfs inode operations; see - the `kernfs.Inode` interface. - -The FUSE protocol lends itself well to marshaling with `go_marshal`. The various -request and response packets can be defined in the ABI package and converted to -and from the wire format using `go_marshal`. - -### Design Goals - -- While filesystem performance is always important, the sentry's FUSE support - is primarily concerned with compatibility, with performance as a secondary - concern. - -- Avoiding deadlocks from a hung server daemon. - -- Consider the potential for denial of service from a malicious server daemon. - Protecting itself from userspace is already a design goal for the sentry, - but needs additional consideration for FUSE. Normally, an operating system - doesn't rely on userspace to make progress with filesystem operations. Since - this changes with FUSE, it opens up the possibility of creating a chain of - dependencies controlled by userspace, which could affect an entire sandbox. - For example: a FUSE op can block a syscall, which could be holding a - subsystem lock, which can then block another task goroutine. - -### Milestones - -Below are some broad goals to aim for while implementing FUSE in the sentry. -Many FUSE ops can be grouped into broad categories of functionality, and most -ops can be implemented in parallel. - -#### Minimal client that can mount a trivial FUSE filesystem. - -- Implement `/dev/fuse` - a character device used to establish an FD for - communication between the sentry and the server daemon. - -- Implement basic FUSE ops like `FUSE_INIT`. - -#### Read-only mount with basic file operations - -- Implement the majority of file, directory and file descriptor FUSE ops. For - this milestone, we can skip uncommon or complex operations like mmap, mknod, - file locking, poll, and extended attributes. We can stub these out along - with any ops that modify the filesystem. The exact list of required ops are - to be determined, but the goal is to mount a real filesystem as read-only, - and be able to read contents from the filesystem in the sentry. - -#### Full read-write support - -- Implement the remaining FUSE ops and decide if we can omit rarely used - operations like ioctl. - -### Design Details - -#### Lifecycle for a FUSE Request - -- User invokes a syscall -- Sentry prepares corresponding request - - If FUSE device is available - - Write the request in binary - - If FUSE device is full - - Kernel task blocked until available -- Sentry notifies the readers of fuse device that it's ready for read -- FUSE daemon reads the request and processes it -- Sentry waits until a reply is written to the FUSE device - - but returns directly for async requests -- FUSE daemon writes to the fuse device -- Sentry processes the reply - - For sync requests, unblock blocked kernel task - - For async requests, execute pre-specified callback if any -- Sentry returns the syscall to the user - -#### Channels and Queues for Requests in Different Stages - -`connection.initializedChan` - -- a channel that the requests issued before connection initialization blocks - on. - -`fd.queue` - -- a queue of requests that haven’t been read by the FUSE daemon yet. - -`fd.completions` - -- a map of the requests that have been prepared but not yet received a - response, including the ones on the `fd.queue`. - -`fd.waitQueue` - -- a queue of waiters that is waiting for the fuse device fd to be available, - such as the FUSE daemon. - -`fd.fullQueueCh` - -- a channel that the kernel task will be blocked on when the fd is not - available. - -#### Basic I/O Implementation - -Currently we have implemented basic functionalities of read and write for our -FUSE. We describe the design and ways to improve it here: - -##### Basic FUSE Read - -The vfs2 expects implementations of `vfs.FileDescriptionImpl.Read()` and -`vfs.FileDescriptionImpl.PRead()`. When a syscall is made, it will eventually -reach our implementation of those interface functions located at -`pkg/sentry/fsimpl/fuse/regular_file.go` for regular files. - -After validation checks of the input, sentry sends `FUSE_READ` requests to the -FUSE daemon. The FUSE daemon returns data after the `fuse_out_header` as the -responses. For the first version, we create a copy in kernel memory of those -data. They are represented as a byte slice in the marshalled struct. This -happens as a common process for all the FUSE responses at this moment at -`pkg/sentry/fsimpl/fuse/dev.go:writeLocked()`. We then directly copy from this -intermediate buffer to the input buffer provided by the read syscall. - -There is an extra requirement for FUSE: When mounting the FUSE fs, the mounter -or the FUSE daemon can specify a `max_read` or a `max_pages` parameter. They are -the upperbound of the bytes to read in each `FUSE_READ` request. We implemented -the code to handle the fragmented reads. - -To improve the performance: ideally we should have buffer cache to copy those -data from the responses of FUSE daemon into, as is also the design of several -other existing file system implementations for sentry, instead of a single-use -temporary buffer. Directly mapping the memory of one process to another could -also boost the performance, but to keep them isolated, we did not choose to do -so. - -##### Basic FUSE Write - -The vfs2 invokes implementations of `vfs.FileDescriptionImpl.Write()` and -`vfs.FileDescriptionImpl.PWrite()` on the regular file descriptor of FUSE when a -user makes write(2) and pwrite(2) syscall. - -For valid writes, sentry sends the bytes to write after a `FUSE_WRITE` header -(can be regarded as a request with 2 payloads) to the FUSE daemon. For the first -version, we allocate a buffer inside kernel memory to store the bytes from the -user, and copy directly from that buffer to the memory of FUSE daemon. This -happens at `pkg/sentry/fsimpl/fuse/dev.go:readLocked()` - -The parameters `max_write` and `max_pages` restrict the number of bytes in one -`FUSE_WRITE`. There are code handling fragmented writes in current -implementation. - -To have better performance: the extra copy created to store the bytes to write -can be replaced by the buffer cache as well. - -# Appendix - -## FUSE Protocol - -The FUSE protocol is a request-response protocol. All requests are initiated by -the client. The wire-format for the protocol is raw C structs serialized to -memory. - -All FUSE requests begin with the following request header: - -```c -struct fuse_in_header { - uint32_t len; // Length of the request, including this header. - uint32_t opcode; // Requested operation. - uint64_t unique; // A unique identifier for this request. - uint64_t nodeid; // ID of the filesystem object being operated on. - uint32_t uid; // UID of the requesting process. - uint32_t gid; // GID of the requesting process. - uint32_t pid; // PID of the requesting process. - uint32_t padding; -}; -``` - -The request is then followed by a payload specific to the `opcode`. - -All responses begin with this response header: - -```c -struct fuse_out_header { - uint32_t len; // Length of the response, including this header. - int32_t error; // Status of the request, 0 if success. - uint64_t unique; // The unique identifier from the corresponding request. -}; -``` - -The response payload also depends on the request `opcode`. If `error != 0`, the -response payload must be empty. - -### Operations - -The following is a list of all FUSE operations used in `fuse_in_header.opcode` -as of Linux v4.4, and a brief description of their purpose. These are defined in -`uapi/linux/fuse.h`. Many of these have a corresponding request and response -payload struct; `fuse(4)` has details for some of these. We also note how these -operations map to the sentry virtual filesystem. - -#### FUSE meta-operations - -These operations are specific to FUSE and don't have a corresponding action in a -generic filesystem. - -- `FUSE_INIT`: This operation initializes a new FUSE filesystem, and is the - first message sent by the client after mount. This is used for version and - feature negotiation. This is related to `mount(2)`. -- `FUSE_DESTROY`: Teardown a FUSE filesystem, related to `unmount(2)`. -- `FUSE_INTERRUPT`: Interrupts an in-flight operation, specified by the - `fuse_in_header.unique` value provided in the corresponding request header. - The client can send at most one of these per request, and will enter an - uninterruptible wait for a reply. The server is expected to reply promptly. -- `FUSE_FORGET`: A hint to the server that server should evict the indicate - node from any caches. This is wired up to `(struct - super_operations).evict_inode` in Linux, which is in turned hooked as the - inode cache shrinker which is typically triggered by system memory pressure. -- `FUSE_BATCH_FORGET`: Batch version of `FUSE_FORGET`. - -#### Filesystem Syscalls - -These FUSE ops map directly to an equivalent filesystem syscall, or family of -syscalls. The relevant syscalls have a similar name to the operation, unless -otherwise noted. - -Node creation: - -- `FUSE_MKNOD` -- `FUSE_MKDIR` -- `FUSE_CREATE`: This is equivalent to `open(2)` and `creat(2)`, which - atomically creates and opens a node. - -Node attributes and extended attributes: - -- `FUSE_GETATTR` -- `FUSE_SETATTR` -- `FUSE_SETXATTR` -- `FUSE_GETXATTR` -- `FUSE_LISTXATTR` -- `FUSE_REMOVEXATTR` - -Node link manipulation: - -- `FUSE_READLINK` -- `FUSE_LINK` -- `FUSE_SYMLINK` -- `FUSE_UNLINK` - -Directory operations: - -- `FUSE_RMDIR` -- `FUSE_RENAME` -- `FUSE_RENAME2` -- `FUSE_OPENDIR`: `open(2)` for directories. -- `FUSE_RELEASEDIR`: `close(2)` for directories. -- `FUSE_READDIR` -- `FUSE_READDIRPLUS` -- `FUSE_FSYNCDIR`: `fsync(2)` for directories. -- `FUSE_LOOKUP`: Establishes a unique identifier for a FS node. This is - reminiscent of `VirtualFilesystem.GetDentryAt` in that it resolves a path - component to a node. However the returned identifier is opaque to the - client. The server must remember this mapping, as this is how the client - will reference the node in the future. - -File operations: - -- `FUSE_OPEN`: `open(2)` for files. -- `FUSE_RELEASE`: `close(2)` for files. -- `FUSE_FSYNC` -- `FUSE_FALLOCATE` -- `FUSE_SETUPMAPPING`: Creates a memory map on a file for `mmap(2)`. -- `FUSE_REMOVEMAPPING`: Removes a memory map for `munmap(2)`. - -File locking: - -- `FUSE_GETLK` -- `FUSE_SETLK` -- `FUSE_SETLKW` -- `FUSE_COPY_FILE_RANGE` - -File descriptor operations: - -- `FUSE_IOCTL` -- `FUSE_POLL` -- `FUSE_LSEEK` - -Filesystem operations: - -- `FUSE_STATFS` - -#### Permissions - -- `FUSE_ACCESS` is used to check if a node is accessible, as part of many - syscall implementations. Maps to `vfs.FilesystemImpl.AccessAt` in the - sentry. - -#### I/O Operations - -These ops are used to read and write file pages. They're used to implement both -I/O syscalls like `read(2)`, `write(2)` and `mmap(2)`. - -- `FUSE_READ` -- `FUSE_WRITE` - -#### Miscellaneous - -- `FUSE_FLUSH`: Used by the client to indicate when a file descriptor is - closed. Distinct from `FUSE_FSYNC`, which corresponds to an `fsync(2)` - syscall from the user. Maps to `vfs.FileDescriptorImpl.Release` in the - sentry. -- `FUSE_BMAP`: Old address space API for block defrag. Probably not needed. -- `FUSE_NOTIFY_REPLY`: [TODO: what does this do?] - -# References - -- [fuse(4) Linux manual page](https://www.man7.org/linux/man-pages/man4/fuse.4.html) -- [Linux kernel FUSE documentation](https://www.kernel.org/doc/html/latest/filesystems/fuse.html) -- [The reference implementation of the Linux FUSE (Filesystem in Userspace) - interface](https://github.com/libfuse/libfuse) -- [The kernel interface of FUSE](https://git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git/tree/include/uapi/linux/fuse.h) diff --git a/pkg/sentry/fs/g3doc/inotify.md b/pkg/sentry/fs/g3doc/inotify.md deleted file mode 100644 index 85063d4e6..000000000 --- a/pkg/sentry/fs/g3doc/inotify.md +++ /dev/null @@ -1,122 +0,0 @@ -# Inotify - -Inotify implements the like-named filesystem event notification system for the -sentry, see `inotify(7)`. - -## Architecture - -For the most part, the sentry implementation of inotify mirrors the Linux -architecture. Inotify instances (i.e. the fd returned by inotify_init(2)) are -backed by a pseudo-filesystem. Events are generated from various places in the -sentry, including the [syscall layer][syscall_dir], the [vfs layer][dirent] and -the [process fd table][fd_table]. Watches are stored in inodes and generated -events are queued to the inotify instance owning the watches for delivery to the -user. - -## Objects - -Here is a brief description of the existing and new objects involved in the -sentry inotify mechanism, and how they interact: - -### [`fs.Inotify`][inotify] - -- An inotify instances, created by inotify_init(2)/inotify_init1(2). -- The inotify fd has a `fs.Dirent`, supports filesystem syscalls to read - events. -- Has multiple `fs.Watch`es, with at most one watch per target inode, per - inotify instance. -- Has an instance `id` which is globally unique. This is *not* the fd number - for this instance, since the fd can be duped. This `id` is not externally - visible. - -### [`fs.Watch`][watch] - -- An inotify watch, created/deleted by - inotify_add_watch(2)/inotify_rm_watch(2). -- Owned by an `fs.Inotify` instance, each watch keeps a pointer to the - `owner`. -- Associated with a single `fs.Inode`, which is the watch `target`. While the - watch is active, it indirectly pins `target` to memory. See the "Reference - Model" section for a detailed explanation. -- Filesystem operations on `target` generate `fs.Event`s. - -### [`fs.Event`][event] - -- A simple struct encapsulating all the fields for an inotify event. -- Generated by `fs.Watch`es and forwarded to the watches' `owner`s. -- Serialized to the user during read(2) syscalls on the associated - `fs.Inotify`'s fd. - -### [`fs.Dirent`][dirent] - -- Many inotify events are generated inside dirent methods. Events are - generated in the dirent methods rather than `fs.Inode` methods because some - events carry the name of the subject node, and node names are generally - unavailable in an `fs.Inode`. -- Dirents do not directly contain state for any watches. Instead, they forward - notifications to the underlying `fs.Inode`. - -### [`fs.Inode`][inode] - -- Interacts with inotify through `fs.Watch`es. -- Inodes contain a map of all active `fs.Watch`es on them. -- An `fs.Inotify` instance can have at most one `fs.Watch` per inode. - `fs.Watch`es on an inode are indexed by their `owner`'s `id`. -- All inotify logic is encapsulated in the [`Watches`][inode_watches] struct - in an inode. Logically, `Watches` is the set of inotify watches on the - inode. - -## Reference Model - -The sentry inotify implementation has a complex reference model. An inotify -watch observes a single inode. For efficient lookup, the state for a watch is -stored directly on the target inode. This state needs to be persistent for the -lifetime of watch. Unlike usual filesystem metadata, the watch state has no -"on-disk" representation, so they cannot be reconstructed by the filesystem if -the inode is flushed from memory. This effectively means we need to keep any -inodes with actives watches pinned to memory. - -We can't just hold an extra ref on the inode to pin it to memory because some -filesystems (such as gofer-based filesystems) don't have persistent inodes. In -such a filesystem, if we just pin the inode, nothing prevents the enclosing -dirent from being GCed. Once the dirent is GCed, the pinned inode is -unreachable -- these filesystems generate a new inode by re-reading the node -state on the next walk. Incidentally, hardlinks also don't work on these -filesystems for this reason. - -To prevent the above scenario, when a new watch is added on an inode, we *pin* -the dirent we used to reach the inode. Note that due to hardlinks, this dirent -may not be the only dirent pointing to the inode. Attempting to set an inotify -watch via multiple hardlinks to the same file results in the same watch being -returned for both links. However, for each new dirent we use to reach the same -inode, we add a new pin. We need a new pin for each new dirent used to reach the -inode because we have no guarantees about the deletion order of the different -links to the inode. - -## Lock Ordering - -There are 4 locks related to the inotify implementation: - -- `Inotify.mu`: the inotify instance lock. -- `Inotify.evMu`: the inotify event queue lock. -- `Watch.mu`: the watch lock, used to protect pins. -- `fs.Watches.mu`: the inode watch set mu, used to protect the collection of - watches on the inode. - -The correct lock ordering for inotify code is: - -`Inotify.mu` -> `fs.Watches.mu` -> `Watch.mu` -> `Inotify.evMu`. - -We need a distinct lock for the event queue because by the time a goroutine -attempts to queue a new event, it is already holding `fs.Watches.mu`. If we used -`Inotify.mu` to also protect the event queue, this would violate the above lock -ordering. - -[dirent]: https://github.com/google/gvisor/blob/master/pkg/sentry/fs/dirent.go -[event]: https://github.com/google/gvisor/blob/master/pkg/sentry/fs/inotify_event.go -[fd_table]: https://github.com/google/gvisor/blob/master/pkg/sentry/kernel/fd_table.go -[inode]: https://github.com/google/gvisor/blob/master/pkg/sentry/fs/inode.go -[inode_watches]: https://github.com/google/gvisor/blob/master/pkg/sentry/fs/inode_inotify.go -[inotify]: https://github.com/google/gvisor/blob/master/pkg/sentry/fs/inotify.go -[syscall_dir]: https://github.com/google/gvisor/blob/master/pkg/sentry/syscalls/linux/ -[watch]: https://github.com/google/gvisor/blob/master/pkg/sentry/fs/inotify_watch.go diff --git a/pkg/sentry/fs/gofer/BUILD b/pkg/sentry/fs/gofer/BUILD deleted file mode 100644 index ee2f287d9..000000000 --- a/pkg/sentry/fs/gofer/BUILD +++ /dev/null @@ -1,72 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "gofer", - srcs = [ - "attr.go", - "cache_policy.go", - "context_file.go", - "device.go", - "fifo.go", - "file.go", - "file_state.go", - "fs.go", - "handles.go", - "inode.go", - "inode_state.go", - "path.go", - "session.go", - "session_state.go", - "socket.go", - "util.go", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/errors", - "//pkg/errors/linuxerr", - "//pkg/fd", - "//pkg/hostarch", - "//pkg/log", - "//pkg/metric", - "//pkg/p9", - "//pkg/refs", - "//pkg/safemem", - "//pkg/secio", - "//pkg/sentry/device", - "//pkg/sentry/fs", - "//pkg/sentry/fs/fdpipe", - "//pkg/sentry/fs/fsutil", - "//pkg/sentry/fs/host", - "//pkg/sentry/fsmetric", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/kernel/pipe", - "//pkg/sentry/kernel/time", - "//pkg/sentry/memmap", - "//pkg/sentry/socket/unix/transport", - "//pkg/sync", - "//pkg/syserr", - "//pkg/unet", - "//pkg/usermem", - "//pkg/waiter", - "@org_golang_x_sys//unix:go_default_library", - ], -) - -go_test( - name = "gofer_test", - size = "small", - srcs = ["gofer_test.go"], - library = ":gofer", - deps = [ - "//pkg/context", - "//pkg/errors/linuxerr", - "//pkg/p9", - "//pkg/p9/p9test", - "//pkg/sentry/contexttest", - "//pkg/sentry/fs", - ], -) diff --git a/pkg/sentry/fs/gofer/gofer_state_autogen.go b/pkg/sentry/fs/gofer/gofer_state_autogen.go new file mode 100644 index 000000000..e36638c37 --- /dev/null +++ b/pkg/sentry/fs/gofer/gofer_state_autogen.go @@ -0,0 +1,272 @@ +// automatically generated by stateify. + +package gofer + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (i *fifo) StateTypeName() string { + return "pkg/sentry/fs/gofer.fifo" +} + +func (i *fifo) StateFields() []string { + return []string{ + "InodeOperations", + "fileIops", + } +} + +func (i *fifo) beforeSave() {} + +// +checklocksignore +func (i *fifo) StateSave(stateSinkObject state.Sink) { + i.beforeSave() + stateSinkObject.Save(0, &i.InodeOperations) + stateSinkObject.Save(1, &i.fileIops) +} + +func (i *fifo) afterLoad() {} + +// +checklocksignore +func (i *fifo) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &i.InodeOperations) + stateSourceObject.Load(1, &i.fileIops) +} + +func (f *fileOperations) StateTypeName() string { + return "pkg/sentry/fs/gofer.fileOperations" +} + +func (f *fileOperations) StateFields() []string { + return []string{ + "inodeOperations", + "dirCursor", + "flags", + } +} + +func (f *fileOperations) beforeSave() {} + +// +checklocksignore +func (f *fileOperations) StateSave(stateSinkObject state.Sink) { + f.beforeSave() + stateSinkObject.Save(0, &f.inodeOperations) + stateSinkObject.Save(1, &f.dirCursor) + stateSinkObject.Save(2, &f.flags) +} + +// +checklocksignore +func (f *fileOperations) StateLoad(stateSourceObject state.Source) { + stateSourceObject.LoadWait(0, &f.inodeOperations) + stateSourceObject.Load(1, &f.dirCursor) + stateSourceObject.LoadWait(2, &f.flags) + stateSourceObject.AfterLoad(f.afterLoad) +} + +func (f *filesystem) StateTypeName() string { + return "pkg/sentry/fs/gofer.filesystem" +} + +func (f *filesystem) StateFields() []string { + return []string{} +} + +func (f *filesystem) beforeSave() {} + +// +checklocksignore +func (f *filesystem) StateSave(stateSinkObject state.Sink) { + f.beforeSave() +} + +func (f *filesystem) afterLoad() {} + +// +checklocksignore +func (f *filesystem) StateLoad(stateSourceObject state.Source) { +} + +func (i *inodeOperations) StateTypeName() string { + return "pkg/sentry/fs/gofer.inodeOperations" +} + +func (i *inodeOperations) StateFields() []string { + return []string{ + "fileState", + "cachingInodeOps", + } +} + +func (i *inodeOperations) beforeSave() {} + +// +checklocksignore +func (i *inodeOperations) StateSave(stateSinkObject state.Sink) { + i.beforeSave() + stateSinkObject.Save(0, &i.fileState) + stateSinkObject.Save(1, &i.cachingInodeOps) +} + +func (i *inodeOperations) afterLoad() {} + +// +checklocksignore +func (i *inodeOperations) StateLoad(stateSourceObject state.Source) { + stateSourceObject.LoadWait(0, &i.fileState) + stateSourceObject.Load(1, &i.cachingInodeOps) +} + +func (i *inodeFileState) StateTypeName() string { + return "pkg/sentry/fs/gofer.inodeFileState" +} + +func (i *inodeFileState) StateFields() []string { + return []string{ + "s", + "sattr", + "loading", + "savedUAttr", + "hostMappable", + } +} + +// +checklocksignore +func (i *inodeFileState) StateSave(stateSinkObject state.Sink) { + i.beforeSave() + var loadingValue struct{} + loadingValue = i.saveLoading() + stateSinkObject.SaveValue(2, loadingValue) + stateSinkObject.Save(0, &i.s) + stateSinkObject.Save(1, &i.sattr) + stateSinkObject.Save(3, &i.savedUAttr) + stateSinkObject.Save(4, &i.hostMappable) +} + +// +checklocksignore +func (i *inodeFileState) StateLoad(stateSourceObject state.Source) { + stateSourceObject.LoadWait(0, &i.s) + stateSourceObject.LoadWait(1, &i.sattr) + stateSourceObject.Load(3, &i.savedUAttr) + stateSourceObject.Load(4, &i.hostMappable) + stateSourceObject.LoadValue(2, new(struct{}), func(y interface{}) { i.loadLoading(y.(struct{})) }) + stateSourceObject.AfterLoad(i.afterLoad) +} + +func (l *overrideInfo) StateTypeName() string { + return "pkg/sentry/fs/gofer.overrideInfo" +} + +func (l *overrideInfo) StateFields() []string { + return []string{ + "dirent", + "endpoint", + "inode", + } +} + +func (l *overrideInfo) beforeSave() {} + +// +checklocksignore +func (l *overrideInfo) StateSave(stateSinkObject state.Sink) { + l.beforeSave() + stateSinkObject.Save(0, &l.dirent) + stateSinkObject.Save(1, &l.endpoint) + stateSinkObject.Save(2, &l.inode) +} + +func (l *overrideInfo) afterLoad() {} + +// +checklocksignore +func (l *overrideInfo) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &l.dirent) + stateSourceObject.Load(1, &l.endpoint) + stateSourceObject.Load(2, &l.inode) +} + +func (e *overrideMaps) StateTypeName() string { + return "pkg/sentry/fs/gofer.overrideMaps" +} + +func (e *overrideMaps) StateFields() []string { + return []string{ + "pathMap", + } +} + +func (e *overrideMaps) beforeSave() {} + +// +checklocksignore +func (e *overrideMaps) StateSave(stateSinkObject state.Sink) { + e.beforeSave() + stateSinkObject.Save(0, &e.pathMap) +} + +func (e *overrideMaps) afterLoad() {} + +// +checklocksignore +func (e *overrideMaps) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &e.pathMap) +} + +func (s *session) StateTypeName() string { + return "pkg/sentry/fs/gofer.session" +} + +func (s *session) StateFields() []string { + return []string{ + "AtomicRefCount", + "msize", + "version", + "cachePolicy", + "aname", + "superBlockFlags", + "limitHostFDTranslation", + "overlayfsStaleRead", + "connID", + "inodeMappings", + "mounter", + "overrides", + } +} + +// +checklocksignore +func (s *session) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + stateSinkObject.Save(0, &s.AtomicRefCount) + stateSinkObject.Save(1, &s.msize) + stateSinkObject.Save(2, &s.version) + stateSinkObject.Save(3, &s.cachePolicy) + stateSinkObject.Save(4, &s.aname) + stateSinkObject.Save(5, &s.superBlockFlags) + stateSinkObject.Save(6, &s.limitHostFDTranslation) + stateSinkObject.Save(7, &s.overlayfsStaleRead) + stateSinkObject.Save(8, &s.connID) + stateSinkObject.Save(9, &s.inodeMappings) + stateSinkObject.Save(10, &s.mounter) + stateSinkObject.Save(11, &s.overrides) +} + +// +checklocksignore +func (s *session) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.AtomicRefCount) + stateSourceObject.LoadWait(1, &s.msize) + stateSourceObject.LoadWait(2, &s.version) + stateSourceObject.LoadWait(3, &s.cachePolicy) + stateSourceObject.LoadWait(4, &s.aname) + stateSourceObject.LoadWait(5, &s.superBlockFlags) + stateSourceObject.Load(6, &s.limitHostFDTranslation) + stateSourceObject.Load(7, &s.overlayfsStaleRead) + stateSourceObject.LoadWait(8, &s.connID) + stateSourceObject.LoadWait(9, &s.inodeMappings) + stateSourceObject.LoadWait(10, &s.mounter) + stateSourceObject.LoadWait(11, &s.overrides) + stateSourceObject.AfterLoad(s.afterLoad) +} + +func init() { + state.Register((*fifo)(nil)) + state.Register((*fileOperations)(nil)) + state.Register((*filesystem)(nil)) + state.Register((*inodeOperations)(nil)) + state.Register((*inodeFileState)(nil)) + state.Register((*overrideInfo)(nil)) + state.Register((*overrideMaps)(nil)) + state.Register((*session)(nil)) +} diff --git a/pkg/sentry/fs/gofer/gofer_test.go b/pkg/sentry/fs/gofer/gofer_test.go deleted file mode 100644 index 4924debeb..000000000 --- a/pkg/sentry/fs/gofer/gofer_test.go +++ /dev/null @@ -1,312 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package gofer - -import ( - "fmt" - "testing" - "time" - - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/errors/linuxerr" - "gvisor.dev/gvisor/pkg/p9" - "gvisor.dev/gvisor/pkg/p9/p9test" - "gvisor.dev/gvisor/pkg/sentry/contexttest" - "gvisor.dev/gvisor/pkg/sentry/fs" -) - -// rootTest runs a test with a p9 mock and an fs.InodeOperations created from -// the attached root directory. The root file will be closed and client -// disconnected, but additional files must be closed manually. -func rootTest(t *testing.T, name string, cp cachePolicy, fn func(context.Context, *p9test.Harness, *p9test.Mock, *fs.Inode)) { - t.Run(name, func(t *testing.T) { - h, c := p9test.NewHarness(t) - defer h.Finish() - - // Create a new root. Note that we pass an empty, but non-nil - // map here. This allows tests to extend the root children - // dynamically. - root := h.NewDirectory(map[string]p9test.Generator{})(nil) - - // Return this as the root. - h.Attacher.EXPECT().Attach().Return(root, nil).Times(1) - - // ... and open via the client. - rootFile, err := c.Attach("/") - if err != nil { - t.Fatalf("unable to attach: %v", err) - } - defer rootFile.Close() - - // Wrap an a session. - s := &session{ - mounter: fs.RootOwner, - cachePolicy: cp, - client: c, - } - - // ... and an INode, with only the mode being explicitly valid for now. - ctx := contexttest.Context(t) - sattr, rootInodeOperations := newInodeOperations(ctx, s, contextFile{ - file: rootFile, - }, root.QID, p9.AttrMaskAll(), root.Attr) - m := fs.NewMountSource(ctx, s, &filesystem{}, fs.MountSourceFlags{}) - rootInode := fs.NewInode(ctx, rootInodeOperations, m, sattr) - - // Ensure that the cache is fully invalidated, so that any - // close actions actually take place before the full harness is - // torn down. - defer func() { - m.FlushDirentRefs() - - // Wait for all resources to be released, otherwise the - // operations may fail after we close the rootFile. - fs.AsyncBarrier() - }() - - // Execute the test. - fn(ctx, h, root, rootInode) - }) -} - -func TestLookup(t *testing.T) { - type lookupTest struct { - // Name of the test. - name string - - // Expected return value. - want error - } - - tests := []lookupTest{ - { - name: "mock Walk passes (function succeeds)", - want: nil, - }, - { - name: "mock Walk fails (function fails)", - want: linuxerr.ENOENT, - }, - } - - const file = "file" // The walked target file. - - for _, test := range tests { - rootTest(t, test.name, cacheNone, func(ctx context.Context, h *p9test.Harness, rootFile *p9test.Mock, rootInode *fs.Inode) { - // Setup the appropriate result. - rootFile.WalkCallback = func() error { - return test.want - } - if test.want == nil { - // Set the contents of the root. We expect a - // normal file generator for ppp above. This is - // overriden by setting WalkErr in the mock. - rootFile.AddChild(file, h.NewFile()) - } - - // Call function. - dirent, err := rootInode.Lookup(ctx, file) - - // Unwrap the InodeOperations. - var newInodeOperations fs.InodeOperations - if dirent != nil { - if dirent.IsNegative() { - err = linuxerr.ENOENT - } else { - newInodeOperations = dirent.Inode.InodeOperations - } - } - - // Check return values. - if err != test.want { - t.Logf("err: %v %T", err, err) - t.Errorf("Lookup got err %v, want %v", err, test.want) - } - if err == nil && newInodeOperations == nil { - t.Logf("err: %v %T", err, err) - t.Errorf("Lookup got non-nil err and non-nil node, wanted at least one non-nil") - } - }) - } -} - -func TestRevalidation(t *testing.T) { - type revalidationTest struct { - cachePolicy cachePolicy - - // Whether dirent should be reloaded before any modifications. - preModificationWantReload bool - - // Whether dirent should be reloaded after updating an unstable - // attribute on the remote fs. - postModificationWantReload bool - - // Whether dirent unstable attributes should be updated after - // updating an attribute on the remote fs. - postModificationWantUpdatedAttrs bool - - // Whether dirent should be reloaded after the remote has - // removed the file. - postRemovalWantReload bool - } - - tests := []revalidationTest{ - { - // Policy cacheNone causes Revalidate to always return - // true. - cachePolicy: cacheNone, - preModificationWantReload: true, - postModificationWantReload: true, - postModificationWantUpdatedAttrs: true, - postRemovalWantReload: true, - }, - { - // Policy cacheAll causes Revalidate to always return - // false. - cachePolicy: cacheAll, - preModificationWantReload: false, - postModificationWantReload: false, - postModificationWantUpdatedAttrs: false, - postRemovalWantReload: false, - }, - { - // Policy cacheAllWritethrough causes Revalidate to - // always return false. - cachePolicy: cacheAllWritethrough, - preModificationWantReload: false, - postModificationWantReload: false, - postModificationWantUpdatedAttrs: false, - postRemovalWantReload: false, - }, - { - // Policy cacheRemoteRevalidating causes Revalidate to - // return update cached unstable attrs, and returns - // true only when the remote inode itself has been - // removed or replaced. - cachePolicy: cacheRemoteRevalidating, - preModificationWantReload: false, - postModificationWantReload: false, - postModificationWantUpdatedAttrs: true, - postRemovalWantReload: true, - }, - } - - const file = "file" // The file walked below. - - for _, test := range tests { - name := fmt.Sprintf("cachepolicy=%s", test.cachePolicy) - rootTest(t, name, test.cachePolicy, func(ctx context.Context, h *p9test.Harness, rootFile *p9test.Mock, rootInode *fs.Inode) { - // Wrap in a dirent object. - rootDir := fs.NewDirent(ctx, rootInode, "root") - - // Create a mock file a child of the root. We save when - // this is generated, so that when the time changed, we - // can update the original entry. - var origMocks []*p9test.Mock - rootFile.AddChild(file, func(parent *p9test.Mock) *p9test.Mock { - // Regular a regular file that has a consistent - // path number. This might be used by - // validation so we don't change it. - m := h.NewMock(parent, 0, p9.Attr{ - Mode: p9.ModeRegular, - }) - origMocks = append(origMocks, m) - return m - }) - - // Do the walk. - dirent, err := rootDir.Walk(ctx, rootDir, file) - if err != nil { - t.Fatalf("Lookup failed: %v", err) - } - - // We must release the dirent, of the test will fail - // with a reference leak. This is tracked by p9test. - defer dirent.DecRef(ctx) - - // Walk again. Depending on the cache policy, we may - // get a new dirent. - newDirent, err := rootDir.Walk(ctx, rootDir, file) - if err != nil { - t.Fatalf("Lookup failed: %v", err) - } - if test.preModificationWantReload && dirent == newDirent { - t.Errorf("Lookup with cachePolicy=%s got old dirent %+v, wanted a new dirent", test.cachePolicy, dirent) - } - if !test.preModificationWantReload && dirent != newDirent { - t.Errorf("Lookup with cachePolicy=%s got new dirent %+v, wanted old dirent %+v", test.cachePolicy, newDirent, dirent) - } - newDirent.DecRef(ctx) // See above. - - // Modify the underlying mocked file's modification - // time for the next walk that occurs. - nowSeconds := time.Now().Unix() - rootFile.AddChild(file, func(parent *p9test.Mock) *p9test.Mock { - // Ensure that the path is the same as above, - // but we change only the modification time of - // the file. - return h.NewMock(parent, 0, p9.Attr{ - Mode: p9.ModeRegular, - MTimeSeconds: uint64(nowSeconds), - }) - }) - - // We also modify the original time, so that GetAttr - // behaves as expected for the caching case. - for _, m := range origMocks { - m.Attr.MTimeSeconds = uint64(nowSeconds) - } - - // Walk again. Depending on the cache policy, we may - // get a new dirent. - newDirent, err = rootDir.Walk(ctx, rootDir, file) - if err != nil { - t.Fatalf("Lookup failed: %v", err) - } - if test.postModificationWantReload && dirent == newDirent { - t.Errorf("Lookup with cachePolicy=%s got old dirent, wanted a new dirent", test.cachePolicy) - } - if !test.postModificationWantReload && dirent != newDirent { - t.Errorf("Lookup with cachePolicy=%s got new dirent, wanted old dirent", test.cachePolicy) - } - uattrs, err := newDirent.Inode.UnstableAttr(ctx) - if err != nil { - t.Fatalf("Error getting unstable attrs: %v", err) - } - gotModTimeSeconds := uattrs.ModificationTime.Seconds() - if test.postModificationWantUpdatedAttrs && gotModTimeSeconds != nowSeconds { - t.Fatalf("Lookup with cachePolicy=%s got new modification time %v, wanted %v", test.cachePolicy, gotModTimeSeconds, nowSeconds) - } - newDirent.DecRef(ctx) // See above. - - // Remove the file from the remote fs, subsequent walks - // should now fail to find anything. - rootFile.RemoveChild(file) - - // Walk again. Depending on the cache policy, we may - // get ENOENT. - newDirent, err = rootDir.Walk(ctx, rootDir, file) - if test.postRemovalWantReload && err == nil { - t.Errorf("Lookup with cachePolicy=%s got nil error, wanted ENOENT", test.cachePolicy) - } - if !test.postRemovalWantReload && (err != nil || dirent != newDirent) { - t.Errorf("Lookup with cachePolicy=%s got new dirent and error %v, wanted old dirent and nil error", test.cachePolicy, err) - } - if err == nil { - newDirent.DecRef(ctx) // See above. - } - }) - } -} diff --git a/pkg/sentry/fs/host/BUILD b/pkg/sentry/fs/host/BUILD deleted file mode 100644 index 921612e9c..000000000 --- a/pkg/sentry/fs/host/BUILD +++ /dev/null @@ -1,86 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "host", - srcs = [ - "control.go", - "descriptor.go", - "descriptor_state.go", - "device.go", - "file.go", - "host.go", - "inode.go", - "inode_state.go", - "ioctl_unsafe.go", - "socket.go", - "socket_iovec.go", - "socket_state.go", - "socket_unsafe.go", - "tty.go", - "util.go", - "util_amd64_unsafe.go", - "util_arm64_unsafe.go", - "util_unsafe.go", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/errors/linuxerr", - "//pkg/fd", - "//pkg/fdnotifier", - "//pkg/log", - "//pkg/marshal/primitive", - "//pkg/refs", - "//pkg/safemem", - "//pkg/secio", - "//pkg/sentry/arch", - "//pkg/sentry/device", - "//pkg/sentry/fs", - "//pkg/sentry/fs/fsutil", - "//pkg/sentry/hostfd", - "//pkg/sentry/kernel", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/kernel/time", - "//pkg/sentry/memmap", - "//pkg/sentry/socket/control", - "//pkg/sentry/socket/unix", - "//pkg/sentry/socket/unix/transport", - "//pkg/sentry/unimpl", - "//pkg/sentry/uniqueid", - "//pkg/sync", - "//pkg/syserr", - "//pkg/tcpip", - "//pkg/unet", - "//pkg/usermem", - "//pkg/waiter", - "@org_golang_x_sys//unix:go_default_library", - ], -) - -go_test( - name = "host_test", - size = "small", - srcs = [ - "descriptor_test.go", - "inode_test.go", - "socket_test.go", - "wait_test.go", - ], - library = ":host", - deps = [ - "//pkg/fd", - "//pkg/fdnotifier", - "//pkg/sentry/contexttest", - "//pkg/sentry/kernel/time", - "//pkg/sentry/socket", - "//pkg/sentry/socket/unix/transport", - "//pkg/syserr", - "//pkg/tcpip", - "//pkg/usermem", - "//pkg/waiter", - "@org_golang_x_sys//unix:go_default_library", - ], -) diff --git a/pkg/sentry/fs/host/descriptor_test.go b/pkg/sentry/fs/host/descriptor_test.go deleted file mode 100644 index cb809ab2d..000000000 --- a/pkg/sentry/fs/host/descriptor_test.go +++ /dev/null @@ -1,78 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package host - -import ( - "io/ioutil" - "path/filepath" - "testing" - - "golang.org/x/sys/unix" - "gvisor.dev/gvisor/pkg/fdnotifier" - "gvisor.dev/gvisor/pkg/waiter" -) - -func TestDescriptorRelease(t *testing.T) { - for _, tc := range []struct { - name string - saveable bool - wouldBlock bool - }{ - {name: "all false"}, - {name: "saveable", saveable: true}, - {name: "wouldBlock", wouldBlock: true}, - } { - t.Run(tc.name, func(t *testing.T) { - dir, err := ioutil.TempDir("", "descriptor_test") - if err != nil { - t.Fatal("ioutil.TempDir() failed:", err) - } - - fd, err := unix.Open(filepath.Join(dir, "file"), unix.O_RDWR|unix.O_CREAT, 0666) - if err != nil { - t.Fatal("failed to open temp file:", err) - } - - // FD ownership is transferred to the descritor. - queue := &waiter.Queue{} - d, err := newDescriptor(fd, tc.saveable, tc.wouldBlock, queue) - if err != nil { - unix.Close(fd) - t.Fatalf("newDescriptor(%d, %t, %t, queue) failed, err: %v", fd, tc.saveable, tc.wouldBlock, err) - } - if tc.saveable { - if d.origFD < 0 { - t.Errorf("saveable descriptor must preserve origFD, desc: %+v", d) - } - } - if tc.wouldBlock { - if !fdnotifier.HasFD(int32(d.value)) { - t.Errorf("FD not registered with notifier, desc: %+v", d) - } - } - - oldVal := d.value - d.Release() - if d.value != -1 { - t.Errorf("d.value want: -1, got: %d", d.value) - } - if tc.wouldBlock { - if fdnotifier.HasFD(int32(oldVal)) { - t.Errorf("FD not unregistered with notifier, desc: %+v", d) - } - } - }) - } -} diff --git a/pkg/sentry/fs/host/host_amd64_unsafe_state_autogen.go b/pkg/sentry/fs/host/host_amd64_unsafe_state_autogen.go new file mode 100644 index 000000000..14cecc49f --- /dev/null +++ b/pkg/sentry/fs/host/host_amd64_unsafe_state_autogen.go @@ -0,0 +1,6 @@ +// automatically generated by stateify. + +//go:build amd64 +// +build amd64 + +package host diff --git a/pkg/sentry/fs/host/host_arm64_unsafe_state_autogen.go b/pkg/sentry/fs/host/host_arm64_unsafe_state_autogen.go new file mode 100644 index 000000000..a714e619e --- /dev/null +++ b/pkg/sentry/fs/host/host_arm64_unsafe_state_autogen.go @@ -0,0 +1,6 @@ +// automatically generated by stateify. + +//go:build arm64 +// +build arm64 + +package host diff --git a/pkg/sentry/fs/host/host_state_autogen.go b/pkg/sentry/fs/host/host_state_autogen.go new file mode 100644 index 000000000..b06247919 --- /dev/null +++ b/pkg/sentry/fs/host/host_state_autogen.go @@ -0,0 +1,220 @@ +// automatically generated by stateify. + +package host + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (d *descriptor) StateTypeName() string { + return "pkg/sentry/fs/host.descriptor" +} + +func (d *descriptor) StateFields() []string { + return []string{ + "origFD", + "wouldBlock", + } +} + +// +checklocksignore +func (d *descriptor) StateSave(stateSinkObject state.Sink) { + d.beforeSave() + stateSinkObject.Save(0, &d.origFD) + stateSinkObject.Save(1, &d.wouldBlock) +} + +// +checklocksignore +func (d *descriptor) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &d.origFD) + stateSourceObject.Load(1, &d.wouldBlock) + stateSourceObject.AfterLoad(d.afterLoad) +} + +func (f *fileOperations) StateTypeName() string { + return "pkg/sentry/fs/host.fileOperations" +} + +func (f *fileOperations) StateFields() []string { + return []string{ + "iops", + "dirCursor", + } +} + +func (f *fileOperations) beforeSave() {} + +// +checklocksignore +func (f *fileOperations) StateSave(stateSinkObject state.Sink) { + f.beforeSave() + stateSinkObject.Save(0, &f.iops) + stateSinkObject.Save(1, &f.dirCursor) +} + +func (f *fileOperations) afterLoad() {} + +// +checklocksignore +func (f *fileOperations) StateLoad(stateSourceObject state.Source) { + stateSourceObject.LoadWait(0, &f.iops) + stateSourceObject.Load(1, &f.dirCursor) +} + +func (f *filesystem) StateTypeName() string { + return "pkg/sentry/fs/host.filesystem" +} + +func (f *filesystem) StateFields() []string { + return []string{} +} + +func (f *filesystem) beforeSave() {} + +// +checklocksignore +func (f *filesystem) StateSave(stateSinkObject state.Sink) { + f.beforeSave() +} + +func (f *filesystem) afterLoad() {} + +// +checklocksignore +func (f *filesystem) StateLoad(stateSourceObject state.Source) { +} + +func (i *inodeOperations) StateTypeName() string { + return "pkg/sentry/fs/host.inodeOperations" +} + +func (i *inodeOperations) StateFields() []string { + return []string{ + "fileState", + "cachingInodeOps", + } +} + +func (i *inodeOperations) beforeSave() {} + +// +checklocksignore +func (i *inodeOperations) StateSave(stateSinkObject state.Sink) { + i.beforeSave() + stateSinkObject.Save(0, &i.fileState) + stateSinkObject.Save(1, &i.cachingInodeOps) +} + +func (i *inodeOperations) afterLoad() {} + +// +checklocksignore +func (i *inodeOperations) StateLoad(stateSourceObject state.Source) { + stateSourceObject.LoadWait(0, &i.fileState) + stateSourceObject.Load(1, &i.cachingInodeOps) +} + +func (i *inodeFileState) StateTypeName() string { + return "pkg/sentry/fs/host.inodeFileState" +} + +func (i *inodeFileState) StateFields() []string { + return []string{ + "descriptor", + "sattr", + "savedUAttr", + } +} + +func (i *inodeFileState) beforeSave() {} + +// +checklocksignore +func (i *inodeFileState) StateSave(stateSinkObject state.Sink) { + i.beforeSave() + if !state.IsZeroValue(&i.queue) { + state.Failf("queue is %#v, expected zero", &i.queue) + } + stateSinkObject.Save(0, &i.descriptor) + stateSinkObject.Save(1, &i.sattr) + stateSinkObject.Save(2, &i.savedUAttr) +} + +// +checklocksignore +func (i *inodeFileState) StateLoad(stateSourceObject state.Source) { + stateSourceObject.LoadWait(0, &i.descriptor) + stateSourceObject.LoadWait(1, &i.sattr) + stateSourceObject.Load(2, &i.savedUAttr) + stateSourceObject.AfterLoad(i.afterLoad) +} + +func (c *ConnectedEndpoint) StateTypeName() string { + return "pkg/sentry/fs/host.ConnectedEndpoint" +} + +func (c *ConnectedEndpoint) StateFields() []string { + return []string{ + "ref", + "queue", + "path", + "srfd", + "stype", + } +} + +// +checklocksignore +func (c *ConnectedEndpoint) StateSave(stateSinkObject state.Sink) { + c.beforeSave() + stateSinkObject.Save(0, &c.ref) + stateSinkObject.Save(1, &c.queue) + stateSinkObject.Save(2, &c.path) + stateSinkObject.Save(3, &c.srfd) + stateSinkObject.Save(4, &c.stype) +} + +// +checklocksignore +func (c *ConnectedEndpoint) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &c.ref) + stateSourceObject.Load(1, &c.queue) + stateSourceObject.Load(2, &c.path) + stateSourceObject.LoadWait(3, &c.srfd) + stateSourceObject.Load(4, &c.stype) + stateSourceObject.AfterLoad(c.afterLoad) +} + +func (t *TTYFileOperations) StateTypeName() string { + return "pkg/sentry/fs/host.TTYFileOperations" +} + +func (t *TTYFileOperations) StateFields() []string { + return []string{ + "fileOperations", + "session", + "fgProcessGroup", + "termios", + } +} + +func (t *TTYFileOperations) beforeSave() {} + +// +checklocksignore +func (t *TTYFileOperations) StateSave(stateSinkObject state.Sink) { + t.beforeSave() + stateSinkObject.Save(0, &t.fileOperations) + stateSinkObject.Save(1, &t.session) + stateSinkObject.Save(2, &t.fgProcessGroup) + stateSinkObject.Save(3, &t.termios) +} + +func (t *TTYFileOperations) afterLoad() {} + +// +checklocksignore +func (t *TTYFileOperations) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &t.fileOperations) + stateSourceObject.Load(1, &t.session) + stateSourceObject.Load(2, &t.fgProcessGroup) + stateSourceObject.Load(3, &t.termios) +} + +func init() { + state.Register((*descriptor)(nil)) + state.Register((*fileOperations)(nil)) + state.Register((*filesystem)(nil)) + state.Register((*inodeOperations)(nil)) + state.Register((*inodeFileState)(nil)) + state.Register((*ConnectedEndpoint)(nil)) + state.Register((*TTYFileOperations)(nil)) +} diff --git a/pkg/sentry/fs/host/host_unsafe_state_autogen.go b/pkg/sentry/fs/host/host_unsafe_state_autogen.go new file mode 100644 index 000000000..b2d8c661f --- /dev/null +++ b/pkg/sentry/fs/host/host_unsafe_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package host diff --git a/pkg/sentry/fs/host/inode_test.go b/pkg/sentry/fs/host/inode_test.go deleted file mode 100644 index 11738871b..000000000 --- a/pkg/sentry/fs/host/inode_test.go +++ /dev/null @@ -1,45 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package host - -import ( - "testing" - - "golang.org/x/sys/unix" - "gvisor.dev/gvisor/pkg/sentry/contexttest" -) - -// TestCloseFD verifies fds will be closed. -func TestCloseFD(t *testing.T) { - var p [2]int - if err := unix.Pipe(p[0:]); err != nil { - t.Fatalf("Failed to create pipe %v", err) - } - defer unix.Close(p[0]) - defer unix.Close(p[1]) - - // Use the write-end because we will detect if it's closed on the read end. - ctx := contexttest.Context(t) - file, err := NewFile(ctx, p[1]) - if err != nil { - t.Fatalf("Failed to create File: %v", err) - } - file.DecRef(ctx) - - s := make([]byte, 10) - if c, err := unix.Read(p[0], s); c != 0 || err != nil { - t.Errorf("want 0, nil (EOF) from read end, got %v, %v", c, err) - } -} diff --git a/pkg/sentry/fs/host/socket_test.go b/pkg/sentry/fs/host/socket_test.go deleted file mode 100644 index f7014b6b1..000000000 --- a/pkg/sentry/fs/host/socket_test.go +++ /dev/null @@ -1,252 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package host - -import ( - "reflect" - "testing" - - "golang.org/x/sys/unix" - "gvisor.dev/gvisor/pkg/fd" - "gvisor.dev/gvisor/pkg/fdnotifier" - "gvisor.dev/gvisor/pkg/sentry/contexttest" - ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" - "gvisor.dev/gvisor/pkg/sentry/socket" - "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" - "gvisor.dev/gvisor/pkg/syserr" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/usermem" - "gvisor.dev/gvisor/pkg/waiter" -) - -var ( - // Make sure that ConnectedEndpoint implements transport.ConnectedEndpoint. - _ = transport.ConnectedEndpoint(new(ConnectedEndpoint)) - - // Make sure that ConnectedEndpoint implements transport.Receiver. - _ = transport.Receiver(new(ConnectedEndpoint)) -) - -func getFl(fd int) (uint32, error) { - fl, _, err := unix.RawSyscall(unix.SYS_FCNTL, uintptr(fd), unix.F_GETFL, 0) - if err == 0 { - return uint32(fl), nil - } - return 0, err -} - -func TestSocketIsBlocking(t *testing.T) { - // Using socketpair here because it's already connected. - pair, err := unix.Socketpair(unix.AF_UNIX, unix.SOCK_STREAM, 0) - if err != nil { - t.Fatalf("host socket creation failed: %v", err) - } - - fl, err := getFl(pair[0]) - if err != nil { - t.Fatalf("getFl: fcntl(%v, GETFL) => %v", pair[0], err) - } - if fl&unix.O_NONBLOCK == unix.O_NONBLOCK { - t.Fatalf("Expected socket %v to be blocking", pair[0]) - } - if fl, err = getFl(pair[1]); err != nil { - t.Fatalf("getFl: fcntl(%v, GETFL) => %v", pair[1], err) - } - if fl&unix.O_NONBLOCK == unix.O_NONBLOCK { - t.Fatalf("Expected socket %v to be blocking", pair[1]) - } - ctx := contexttest.Context(t) - sock, err := newSocket(ctx, pair[0], false) - if err != nil { - t.Fatalf("newSocket(%v) failed => %v", pair[0], err) - } - defer sock.DecRef(ctx) - // Test that the socket now is non-blocking. - if fl, err = getFl(pair[0]); err != nil { - t.Fatalf("getFl: fcntl(%v, GETFL) => %v", pair[0], err) - } - if fl&unix.O_NONBLOCK != unix.O_NONBLOCK { - t.Errorf("Expected socket %v to have become non-blocking", pair[0]) - } - if fl, err = getFl(pair[1]); err != nil { - t.Fatalf("getFl: fcntl(%v, GETFL) => %v", pair[1], err) - } - if fl&unix.O_NONBLOCK == unix.O_NONBLOCK { - t.Errorf("Did not expect socket %v to become non-blocking", pair[1]) - } -} - -func TestSocketWritev(t *testing.T) { - // Using socketpair here because it's already connected. - pair, err := unix.Socketpair(unix.AF_UNIX, unix.SOCK_STREAM, 0) - if err != nil { - t.Fatalf("host socket creation failed: %v", err) - } - ctx := contexttest.Context(t) - socket, err := newSocket(ctx, pair[0], false) - if err != nil { - t.Fatalf("newSocket(%v) => %v", pair[0], err) - } - defer socket.DecRef(ctx) - buf := []byte("hello world\n") - n, err := socket.Writev(contexttest.Context(t), usermem.BytesIOSequence(buf)) - if err != nil { - t.Fatalf("socket writev failed: %v", err) - } - - if n != int64(len(buf)) { - t.Fatalf("socket writev wrote incorrect bytes: %d", n) - } -} - -func TestSocketWritevLen0(t *testing.T) { - // Using socketpair here because it's already connected. - pair, err := unix.Socketpair(unix.AF_UNIX, unix.SOCK_STREAM, 0) - if err != nil { - t.Fatalf("host socket creation failed: %v", err) - } - ctx := contexttest.Context(t) - socket, err := newSocket(ctx, pair[0], false) - if err != nil { - t.Fatalf("newSocket(%v) => %v", pair[0], err) - } - defer socket.DecRef(ctx) - n, err := socket.Writev(contexttest.Context(t), usermem.BytesIOSequence(nil)) - if err != nil { - t.Fatalf("socket writev failed: %v", err) - } - - if n != 0 { - t.Fatalf("socket writev wrote incorrect bytes: %d", n) - } -} - -func TestSocketSendMsgLen0(t *testing.T) { - // Using socketpair here because it's already connected. - pair, err := unix.Socketpair(unix.AF_UNIX, unix.SOCK_STREAM, 0) - if err != nil { - t.Fatalf("host socket creation failed: %v", err) - } - ctx := contexttest.Context(t) - sfile, err := newSocket(ctx, pair[0], false) - if err != nil { - t.Fatalf("newSocket(%v) => %v", pair[0], err) - } - defer sfile.DecRef(ctx) - - s := sfile.FileOperations.(socket.Socket) - n, terr := s.SendMsg(nil, usermem.BytesIOSequence(nil), []byte{}, 0, false, ktime.Time{}, socket.ControlMessages{}) - if n != 0 { - t.Fatalf("socket sendmsg() failed: %v wrote: %d", terr, n) - } - - if terr != nil { - t.Fatalf("socket sendmsg() failed: %v", terr) - } -} - -func TestListen(t *testing.T) { - pair, err := unix.Socketpair(unix.AF_UNIX, unix.SOCK_STREAM, 0) - if err != nil { - t.Fatalf("unix.Socket(unix.AF_UNIX, unix.SOCK_STREAM, 0) => %v", err) - } - ctx := contexttest.Context(t) - sfile1, err := newSocket(ctx, pair[0], false) - if err != nil { - t.Fatalf("newSocket(%v) => %v", pair[0], err) - } - defer sfile1.DecRef(ctx) - socket1 := sfile1.FileOperations.(socket.Socket) - - sfile2, err := newSocket(ctx, pair[1], false) - if err != nil { - t.Fatalf("newSocket(%v) => %v", pair[1], err) - } - defer sfile2.DecRef(ctx) - socket2 := sfile2.FileOperations.(socket.Socket) - - // Socketpairs can not be listened to. - if err := socket1.Listen(nil, 64); err != syserr.ErrInvalidEndpointState { - t.Fatalf("socket1.Listen(nil, 64) => %v, want syserr.ErrInvalidEndpointState", err) - } - if err := socket2.Listen(nil, 64); err != syserr.ErrInvalidEndpointState { - t.Fatalf("socket2.Listen(nil, 64) => %v, want syserr.ErrInvalidEndpointState", err) - } - - // Create a Unix socket, do not bind it. - sock, err := unix.Socket(unix.AF_UNIX, unix.SOCK_STREAM, 0) - if err != nil { - t.Fatalf("unix.Socket(unix.AF_UNIX, unix.SOCK_STREAM, 0) => %v", err) - } - sfile3, err := newSocket(ctx, sock, false) - if err != nil { - t.Fatalf("newSocket(%v) => %v", sock, err) - } - defer sfile3.DecRef(ctx) - socket3 := sfile3.FileOperations.(socket.Socket) - - // This socket is not bound so we can't listen on it. - if err := socket3.Listen(nil, 64); err != syserr.ErrInvalidEndpointState { - t.Fatalf("socket3.Listen(nil, 64) => %v, want syserr.ErrInvalidEndpointState", err) - } -} - -func TestPasscred(t *testing.T) { - e := &ConnectedEndpoint{} - if got, want := e.Passcred(), false; got != want { - t.Errorf("Got %#v.Passcred() = %t, want = %t", e, got, want) - } -} - -func TestGetLocalAddress(t *testing.T) { - e := &ConnectedEndpoint{path: "foo"} - want := tcpip.FullAddress{Addr: tcpip.Address("foo")} - if got, err := e.GetLocalAddress(); err != nil || got != want { - t.Errorf("Got %#v.GetLocalAddress() = %#v, %v, want = %#v, %v", e, got, err, want, nil) - } -} - -func TestQueuedSize(t *testing.T) { - e := &ConnectedEndpoint{} - tests := []struct { - name string - f func() int64 - }{ - {"SendQueuedSize", e.SendQueuedSize}, - {"RecvQueuedSize", e.RecvQueuedSize}, - } - - for _, test := range tests { - if got, want := test.f(), int64(-1); got != want { - t.Errorf("Got %#v.%s() = %d, want = %d", e, test.name, got, want) - } - } -} - -func TestRelease(t *testing.T) { - f, err := unix.Socket(unix.AF_UNIX, unix.SOCK_STREAM|unix.SOCK_NONBLOCK|unix.SOCK_CLOEXEC, 0) - if err != nil { - t.Fatal("Creating socket:", err) - } - c := &ConnectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f)} - want := &ConnectedEndpoint{queue: c.queue} - ctx := contexttest.Context(t) - want.ref.DecRef(ctx) - fdnotifier.AddFD(int32(c.file.FD()), nil) - c.Release(ctx) - if !reflect.DeepEqual(c, want) { - t.Errorf("got = %#v, want = %#v", c, want) - } -} diff --git a/pkg/sentry/fs/host/wait_test.go b/pkg/sentry/fs/host/wait_test.go deleted file mode 100644 index bd6188e03..000000000 --- a/pkg/sentry/fs/host/wait_test.go +++ /dev/null @@ -1,69 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package host - -import ( - "testing" - "time" - - "golang.org/x/sys/unix" - "gvisor.dev/gvisor/pkg/sentry/contexttest" - "gvisor.dev/gvisor/pkg/waiter" -) - -func TestWait(t *testing.T) { - var fds [2]int - err := unix.Pipe(fds[:]) - if err != nil { - t.Fatalf("Unable to create pipe: %v", err) - } - - defer unix.Close(fds[1]) - - ctx := contexttest.Context(t) - file, err := NewFile(ctx, fds[0]) - if err != nil { - unix.Close(fds[0]) - t.Fatalf("NewFile failed: %v", err) - } - - defer file.DecRef(ctx) - - r := file.Readiness(waiter.ReadableEvents) - if r != 0 { - t.Fatalf("File is ready for read when it shouldn't be.") - } - - e, ch := waiter.NewChannelEntry(nil) - file.EventRegister(&e, waiter.ReadableEvents) - defer file.EventUnregister(&e) - - // Check that there are no notifications yet. - if len(ch) != 0 { - t.Fatalf("Channel is non-empty") - } - - // Write to the pipe, so it should be writable now. - unix.Write(fds[1], []byte{1}) - - // Check that we get a notification. We need to yield the current thread - // so that the fdnotifier can deliver notifications, so we use a - // 1-second timeout instead of just checking the length of the channel. - select { - case <-ch: - case <-time.After(1 * time.Second): - t.Fatalf("Channel not notified") - } -} diff --git a/pkg/sentry/fs/inode_overlay_test.go b/pkg/sentry/fs/inode_overlay_test.go deleted file mode 100644 index a3800d700..000000000 --- a/pkg/sentry/fs/inode_overlay_test.go +++ /dev/null @@ -1,470 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package fs_test - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/errors/linuxerr" - "gvisor.dev/gvisor/pkg/sentry/fs" - "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" - "gvisor.dev/gvisor/pkg/sentry/fs/ramfs" - "gvisor.dev/gvisor/pkg/sentry/kernel/contexttest" -) - -func TestLookup(t *testing.T) { - ctx := contexttest.Context(t) - for _, test := range []struct { - // Test description. - desc string - - // Lookup parameters. - dir *fs.Inode - name string - - // Want from lookup. - found bool - hasUpper bool - hasLower bool - }{ - { - desc: "no upper, lower has name", - dir: fs.NewTestOverlayDir(ctx, - nil, /* upper */ - newTestRamfsDir(ctx, []dirContent{ - { - name: "a", - dir: false, - }, - }, nil), /* lower */ - false /* revalidate */), - name: "a", - found: true, - hasUpper: false, - hasLower: true, - }, - { - desc: "no lower, upper has name", - dir: fs.NewTestOverlayDir(ctx, - newTestRamfsDir(ctx, []dirContent{ - { - name: "a", - dir: false, - }, - }, nil), /* upper */ - nil, /* lower */ - false /* revalidate */), - name: "a", - found: true, - hasUpper: true, - hasLower: false, - }, - { - desc: "upper and lower, only lower has name", - dir: fs.NewTestOverlayDir(ctx, - newTestRamfsDir(ctx, []dirContent{ - { - name: "b", - dir: false, - }, - }, nil), /* upper */ - newTestRamfsDir(ctx, []dirContent{ - { - name: "a", - dir: false, - }, - }, nil), /* lower */ - false /* revalidate */), - name: "a", - found: true, - hasUpper: false, - hasLower: true, - }, - { - desc: "upper and lower, only upper has name", - dir: fs.NewTestOverlayDir(ctx, - newTestRamfsDir(ctx, []dirContent{ - { - name: "a", - dir: false, - }, - }, nil), /* upper */ - newTestRamfsDir(ctx, []dirContent{ - { - name: "b", - dir: false, - }, - }, nil), /* lower */ - false /* revalidate */), - name: "a", - found: true, - hasUpper: true, - hasLower: false, - }, - { - desc: "upper and lower, both have file", - dir: fs.NewTestOverlayDir(ctx, - newTestRamfsDir(ctx, []dirContent{ - { - name: "a", - dir: false, - }, - }, nil), /* upper */ - newTestRamfsDir(ctx, []dirContent{ - { - name: "a", - dir: false, - }, - }, nil), /* lower */ - false /* revalidate */), - name: "a", - found: true, - hasUpper: true, - hasLower: false, - }, - { - desc: "upper and lower, both have directory", - dir: fs.NewTestOverlayDir(ctx, - newTestRamfsDir(ctx, []dirContent{ - { - name: "a", - dir: true, - }, - }, nil), /* upper */ - newTestRamfsDir(ctx, []dirContent{ - { - name: "a", - dir: true, - }, - }, nil), /* lower */ - false /* revalidate */), - name: "a", - found: true, - hasUpper: true, - hasLower: true, - }, - { - desc: "upper and lower, upper negative masks lower file", - dir: fs.NewTestOverlayDir(ctx, - newTestRamfsDir(ctx, nil, []string{"a"}), /* upper */ - newTestRamfsDir(ctx, []dirContent{ - { - name: "a", - dir: false, - }, - }, nil), /* lower */ - false /* revalidate */), - name: "a", - found: false, - hasUpper: false, - hasLower: false, - }, - { - desc: "upper and lower, upper negative does not mask lower file", - dir: fs.NewTestOverlayDir(ctx, - newTestRamfsDir(ctx, nil, []string{"b"}), /* upper */ - newTestRamfsDir(ctx, []dirContent{ - { - name: "a", - dir: false, - }, - }, nil), /* lower */ - false /* revalidate */), - name: "a", - found: true, - hasUpper: false, - hasLower: true, - }, - } { - t.Run(test.desc, func(t *testing.T) { - dirent, err := test.dir.Lookup(ctx, test.name) - if test.found && (linuxerr.Equals(linuxerr.ENOENT, err) || dirent.IsNegative()) { - t.Fatalf("lookup %q expected to find positive dirent, got dirent %v err %v", test.name, dirent, err) - } - if !test.found { - if !linuxerr.Equals(linuxerr.ENOENT, err) && !dirent.IsNegative() { - t.Errorf("lookup %q expected to return ENOENT or negative dirent, got dirent %v err %v", test.name, dirent, err) - } - // Nothing more to check. - return - } - if hasUpper := dirent.Inode.TestHasUpperFS(); hasUpper != test.hasUpper { - t.Fatalf("lookup got upper filesystem %v, want %v", hasUpper, test.hasUpper) - } - if hasLower := dirent.Inode.TestHasLowerFS(); hasLower != test.hasLower { - t.Errorf("lookup got lower filesystem %v, want %v", hasLower, test.hasLower) - } - }) - } -} - -func TestLookupRevalidation(t *testing.T) { - // File name used in the tests. - fileName := "foofile" - ctx := contexttest.Context(t) - for _, tc := range []struct { - // Test description. - desc string - - // Upper and lower fs for the overlay. - upper *fs.Inode - lower *fs.Inode - - // Whether the upper requires revalidation. - revalidate bool - - // Whether we should get the same dirent on second lookup. - wantSame bool - }{ - { - desc: "file from upper with no revalidation", - upper: newTestRamfsDir(ctx, []dirContent{{name: fileName}}, nil), - lower: newTestRamfsDir(ctx, nil, nil), - revalidate: false, - wantSame: true, - }, - { - desc: "file from upper with revalidation", - upper: newTestRamfsDir(ctx, []dirContent{{name: fileName}}, nil), - lower: newTestRamfsDir(ctx, nil, nil), - revalidate: true, - wantSame: false, - }, - { - desc: "file from lower with no revalidation", - upper: newTestRamfsDir(ctx, nil, nil), - lower: newTestRamfsDir(ctx, []dirContent{{name: fileName}}, nil), - revalidate: false, - wantSame: true, - }, - { - desc: "file from lower with revalidation", - upper: newTestRamfsDir(ctx, nil, nil), - lower: newTestRamfsDir(ctx, []dirContent{{name: fileName}}, nil), - revalidate: true, - // The file does not exist in the upper, so we do not - // need to revalidate it. - wantSame: true, - }, - { - desc: "file from upper and lower with no revalidation", - upper: newTestRamfsDir(ctx, []dirContent{{name: fileName}}, nil), - lower: newTestRamfsDir(ctx, []dirContent{{name: fileName}}, nil), - revalidate: false, - wantSame: true, - }, - { - desc: "file from upper and lower with revalidation", - upper: newTestRamfsDir(ctx, []dirContent{{name: fileName}}, nil), - lower: newTestRamfsDir(ctx, []dirContent{{name: fileName}}, nil), - revalidate: true, - wantSame: false, - }, - } { - t.Run(tc.desc, func(t *testing.T) { - root := fs.NewDirent(ctx, newTestRamfsDir(ctx, nil, nil), "root") - ctx = &rootContext{ - Context: ctx, - root: root, - } - overlay := fs.NewDirent(ctx, fs.NewTestOverlayDir(ctx, tc.upper, tc.lower, tc.revalidate), "overlay") - // Lookup the file twice through the overlay. - first, err := overlay.Walk(ctx, root, fileName) - if err != nil { - t.Fatalf("overlay.Walk(%q) failed: %v", fileName, err) - } - second, err := overlay.Walk(ctx, root, fileName) - if err != nil { - t.Fatalf("overlay.Walk(%q) failed: %v", fileName, err) - } - - if tc.wantSame && first != second { - t.Errorf("dirent lookup got different dirents, wanted same\nfirst=%+v\nsecond=%+v", first, second) - } else if !tc.wantSame && first == second { - t.Errorf("dirent lookup got the same dirent, wanted different: %+v", first) - } - }) - } -} - -func TestCacheFlush(t *testing.T) { - ctx := contexttest.Context(t) - - // Upper and lower each have a file. - upperFileName := "file-from-upper" - lowerFileName := "file-from-lower" - upper := newTestRamfsDir(ctx, []dirContent{{name: upperFileName}}, nil) - lower := newTestRamfsDir(ctx, []dirContent{{name: lowerFileName}}, nil) - - overlay := fs.NewTestOverlayDir(ctx, upper, lower, true /* revalidate */) - - mns, err := fs.NewMountNamespace(ctx, overlay) - if err != nil { - t.Fatalf("NewMountNamespace failed: %v", err) - } - root := mns.Root() - defer root.DecRef(ctx) - - ctx = &rootContext{ - Context: ctx, - root: root, - } - - for _, fileName := range []string{upperFileName, lowerFileName} { - // Walk to the file. - maxTraversals := uint(0) - dirent, err := mns.FindInode(ctx, root, nil, fileName, &maxTraversals) - if err != nil { - t.Fatalf("FindInode(%q) failed: %v", fileName, err) - } - - // Get a file from the dirent. - file, err := dirent.Inode.GetFile(ctx, dirent, fs.FileFlags{Read: true}) - if err != nil { - t.Fatalf("GetFile() failed: %v", err) - } - - // The dirent should have 3 refs, one from us, one from the - // file, and one from the dirent cache. - // dirent cache. - if got, want := dirent.ReadRefs(), 3; int(got) != want { - t.Errorf("dirent.ReadRefs() got %d want %d", got, want) - } - - // Drop the file reference. - file.DecRef(ctx) - - // Dirent should have 2 refs left. - if got, want := dirent.ReadRefs(), 2; int(got) != want { - t.Errorf("dirent.ReadRefs() got %d want %d", got, want) - } - - // Flush the dirent cache. - mns.FlushMountSourceRefs() - - // Dirent should have 1 ref left from the dirent cache. - if got, want := dirent.ReadRefs(), 1; int(got) != want { - t.Errorf("dirent.ReadRefs() got %d want %d", got, want) - } - - // Drop our ref. - dirent.DecRef(ctx) - - // We should be back to zero refs. - if got, want := dirent.ReadRefs(), 0; int(got) != want { - t.Errorf("dirent.ReadRefs() got %d want %d", got, want) - } - } - -} - -type dir struct { - fs.InodeOperations - - // List of negative child names. - negative []string - - // ReaddirCalled records whether Readdir was called on a file - // corresponding to this inode. - ReaddirCalled bool -} - -// GetXattr implements InodeOperations.GetXattr. -func (d *dir) GetXattr(_ context.Context, _ *fs.Inode, name string, _ uint64) (string, error) { - for _, n := range d.negative { - if name == fs.XattrOverlayWhiteout(n) { - return "y", nil - } - } - return "", linuxerr.ENOATTR -} - -// GetFile implements InodeOperations.GetFile. -func (d *dir) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) { - file, err := d.InodeOperations.GetFile(ctx, dirent, flags) - if err != nil { - return nil, err - } - defer file.DecRef(ctx) - // Wrap the file's FileOperations in a dirFile. - fops := &dirFile{ - FileOperations: file.FileOperations, - inode: d, - } - return fs.NewFile(ctx, dirent, flags, fops), nil -} - -type dirContent struct { - name string - dir bool -} - -type dirFile struct { - fs.FileOperations - inode *dir -} - -type inode struct { - fsutil.InodeGenericChecker `state:"nosave"` - fsutil.InodeNoExtendedAttributes `state:"nosave"` - fsutil.InodeNoopRelease `state:"nosave"` - fsutil.InodeNoopWriteOut `state:"nosave"` - fsutil.InodeNotAllocatable `state:"nosave"` - fsutil.InodeNotDirectory `state:"nosave"` - fsutil.InodeNotMappable `state:"nosave"` - fsutil.InodeNotSocket `state:"nosave"` - fsutil.InodeNotSymlink `state:"nosave"` - fsutil.InodeNotTruncatable `state:"nosave"` - fsutil.InodeNotVirtual `state:"nosave"` - - fsutil.InodeSimpleAttributes - fsutil.InodeStaticFileGetter -} - -// Readdir implements fs.FileOperations.Readdir. It sets the ReaddirCalled -// field on the inode. -func (f *dirFile) Readdir(ctx context.Context, file *fs.File, ser fs.DentrySerializer) (int64, error) { - f.inode.ReaddirCalled = true - return f.FileOperations.Readdir(ctx, file, ser) -} - -func newTestRamfsInode(ctx context.Context, msrc *fs.MountSource) *fs.Inode { - inode := fs.NewInode(ctx, &inode{ - InodeStaticFileGetter: fsutil.InodeStaticFileGetter{ - Contents: []byte("foobar"), - }, - }, msrc, fs.StableAttr{Type: fs.RegularFile}) - return inode -} - -func newTestRamfsDir(ctx context.Context, contains []dirContent, negative []string) *fs.Inode { - msrc := fs.NewPseudoMountSource(ctx) - contents := make(map[string]*fs.Inode) - for _, c := range contains { - if c.dir { - contents[c.name] = newTestRamfsDir(ctx, nil, nil) - } else { - contents[c.name] = newTestRamfsInode(ctx, msrc) - } - } - dops := ramfs.NewDir(ctx, contents, fs.RootOwner, fs.FilePermissions{ - User: fs.PermMask{Read: true, Execute: true}, - }) - return fs.NewInode(ctx, &dir{ - InodeOperations: dops, - negative: negative, - }, msrc, fs.StableAttr{Type: fs.Directory}) -} diff --git a/pkg/sentry/fs/lock/BUILD b/pkg/sentry/fs/lock/BUILD deleted file mode 100644 index c09d8463b..000000000 --- a/pkg/sentry/fs/lock/BUILD +++ /dev/null @@ -1,62 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -package(licenses = ["notice"]) - -go_template_instance( - name = "lock_range", - out = "lock_range.go", - package = "lock", - prefix = "Lock", - template = "//pkg/segment:generic_range", - types = { - "T": "uint64", - }, -) - -go_template_instance( - name = "lock_set", - out = "lock_set.go", - consts = { - "minDegree": "3", - }, - package = "lock", - prefix = "Lock", - template = "//pkg/segment:generic_set", - types = { - "Key": "uint64", - "Range": "LockRange", - "Value": "Lock", - "Functions": "lockSetFunctions", - }, -) - -go_library( - name = "lock", - srcs = [ - "lock.go", - "lock_range.go", - "lock_set.go", - "lock_set_functions.go", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/log", - "//pkg/sync", - "//pkg/waiter", - "@org_golang_x_sys//unix:go_default_library", - ], -) - -go_test( - name = "lock_test", - size = "small", - srcs = [ - "lock_range_test.go", - "lock_test.go", - ], - library = ":lock", - deps = ["@org_golang_x_sys//unix:go_default_library"], -) diff --git a/pkg/sentry/fs/lock/lock_range.go b/pkg/sentry/fs/lock/lock_range.go new file mode 100644 index 000000000..13a2cce6e --- /dev/null +++ b/pkg/sentry/fs/lock/lock_range.go @@ -0,0 +1,76 @@ +package lock + +// A Range represents a contiguous range of T. +// +// +stateify savable +type LockRange struct { + // Start is the inclusive start of the range. + Start uint64 + + // End is the exclusive end of the range. + End uint64 +} + +// WellFormed returns true if r.Start <= r.End. All other methods on a Range +// require that the Range is well-formed. +// +//go:nosplit +func (r LockRange) WellFormed() bool { + return r.Start <= r.End +} + +// Length returns the length of the range. +// +//go:nosplit +func (r LockRange) Length() uint64 { + return r.End - r.Start +} + +// Contains returns true if r contains x. +// +//go:nosplit +func (r LockRange) Contains(x uint64) bool { + return r.Start <= x && x < r.End +} + +// Overlaps returns true if r and r2 overlap. +// +//go:nosplit +func (r LockRange) Overlaps(r2 LockRange) bool { + return r.Start < r2.End && r2.Start < r.End +} + +// IsSupersetOf returns true if r is a superset of r2; that is, the range r2 is +// contained within r. +// +//go:nosplit +func (r LockRange) IsSupersetOf(r2 LockRange) bool { + return r.Start <= r2.Start && r.End >= r2.End +} + +// Intersect returns a range consisting of the intersection between r and r2. +// If r and r2 do not overlap, Intersect returns a range with unspecified +// bounds, but for which Length() == 0. +// +//go:nosplit +func (r LockRange) Intersect(r2 LockRange) LockRange { + if r.Start < r2.Start { + r.Start = r2.Start + } + if r.End > r2.End { + r.End = r2.End + } + if r.End < r.Start { + r.End = r.Start + } + return r +} + +// CanSplitAt returns true if it is legal to split a segment spanning the range +// r at x; that is, splitting at x would produce two ranges, both of which have +// non-zero length. +// +//go:nosplit +func (r LockRange) CanSplitAt(x uint64) bool { + return r.Contains(x) && r.Start < x +} diff --git a/pkg/sentry/fs/lock/lock_range_test.go b/pkg/sentry/fs/lock/lock_range_test.go deleted file mode 100644 index 2b6f8630b..000000000 --- a/pkg/sentry/fs/lock/lock_range_test.go +++ /dev/null @@ -1,137 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package lock - -import ( - "testing" - - "golang.org/x/sys/unix" -) - -func TestComputeRange(t *testing.T) { - tests := []struct { - // Description of test. - name string - - // Requested start of the lock range. - start int64 - - // Requested length of the lock range, - // can be negative :( - length int64 - - // Pre-computed file offset based on whence. - // Will be added to start. - offset int64 - - // Expected error. - err error - - // If error is nil, the expected LockRange. - LockRange - }{ - { - name: "offset, start, and length all zero", - LockRange: LockRange{Start: 0, End: LockEOF}, - }, - { - name: "zero offset, zero start, positive length", - start: 0, - length: 4096, - offset: 0, - LockRange: LockRange{Start: 0, End: 4096}, - }, - { - name: "zero offset, negative start", - start: -4096, - offset: 0, - err: unix.EINVAL, - }, - { - name: "large offset, negative start, positive length", - start: -2048, - length: 2048, - offset: 4096, - LockRange: LockRange{Start: 2048, End: 4096}, - }, - { - name: "large offset, negative start, zero length", - start: -2048, - length: 0, - offset: 4096, - LockRange: LockRange{Start: 2048, End: LockEOF}, - }, - { - name: "zero offset, zero start, negative length", - start: 0, - length: -4096, - offset: 0, - err: unix.EINVAL, - }, - { - name: "large offset, zero start, negative length", - start: 0, - length: -4096, - offset: 4096, - LockRange: LockRange{Start: 0, End: 4096}, - }, - { - name: "offset, start, and length equal, length is negative", - start: 1024, - length: -1024, - offset: 1024, - LockRange: LockRange{Start: 1024, End: 2048}, - }, - { - name: "offset, start, and length equal, start is negative", - start: -1024, - length: 1024, - offset: 1024, - LockRange: LockRange{Start: 0, End: 1024}, - }, - { - name: "offset, start, and length equal, offset is negative", - start: 1024, - length: 1024, - offset: -1024, - LockRange: LockRange{Start: 0, End: 1024}, - }, - { - name: "offset, start, and length equal, all negative", - start: -1024, - length: -1024, - offset: -1024, - err: unix.EINVAL, - }, - { - name: "offset, start, and length equal, all positive", - start: 1024, - length: 1024, - offset: 1024, - LockRange: LockRange{Start: 2048, End: 3072}, - }, - } - - for _, test := range tests { - rng, err := ComputeRange(test.start, test.length, test.offset) - if err != test.err { - t.Errorf("%s: lockRange(%d, %d, %d) got error %v, want %v", test.name, test.start, test.length, test.offset, err, test.err) - continue - } - if err == nil && rng != test.LockRange { - t.Errorf("%s: lockRange(%d, %d, %d) got LockRange %v, want %v", test.name, test.start, test.length, test.offset, rng, test.LockRange) - } - } -} diff --git a/pkg/sentry/fs/lock/lock_set.go b/pkg/sentry/fs/lock/lock_set.go new file mode 100644 index 000000000..4bc830883 --- /dev/null +++ b/pkg/sentry/fs/lock/lock_set.go @@ -0,0 +1,1643 @@ +package lock + +import ( + "bytes" + "fmt" +) + +// trackGaps is an optional parameter. +// +// If trackGaps is 1, the Set will track maximum gap size recursively, +// enabling the GapIterator.{Prev,Next}LargeEnoughGap functions. In this +// case, Key must be an unsigned integer. +// +// trackGaps must be 0 or 1. +const LocktrackGaps = 0 + +var _ = uint8(LocktrackGaps << 7) // Will fail if not zero or one. + +// dynamicGap is a type that disappears if trackGaps is 0. +type LockdynamicGap [LocktrackGaps]uint64 + +// Get returns the value of the gap. +// +// Precondition: trackGaps must be non-zero. +func (d *LockdynamicGap) Get() uint64 { + return d[:][0] +} + +// Set sets the value of the gap. +// +// Precondition: trackGaps must be non-zero. +func (d *LockdynamicGap) Set(v uint64) { + d[:][0] = v +} + +const ( + // minDegree is the minimum degree of an internal node in a Set B-tree. + // + // - Any non-root node has at least minDegree-1 segments. + // + // - Any non-root internal (non-leaf) node has at least minDegree children. + // + // - The root node may have fewer than minDegree-1 segments, but it may + // only have 0 segments if the tree is empty. + // + // Our implementation requires minDegree >= 3. Higher values of minDegree + // usually improve performance, but increase memory usage for small sets. + LockminDegree = 3 + + LockmaxDegree = 2 * LockminDegree +) + +// A Set is a mapping of segments with non-overlapping Range keys. The zero +// value for a Set is an empty set. Set values are not safely movable nor +// copyable. Set is thread-compatible. +// +// +stateify savable +type LockSet struct { + root Locknode `state:".(*LockSegmentDataSlices)"` +} + +// IsEmpty returns true if the set contains no segments. +func (s *LockSet) IsEmpty() bool { + return s.root.nrSegments == 0 +} + +// IsEmptyRange returns true iff no segments in the set overlap the given +// range. This is semantically equivalent to s.SpanRange(r) == 0, but may be +// more efficient. +func (s *LockSet) IsEmptyRange(r LockRange) bool { + switch { + case r.Length() < 0: + panic(fmt.Sprintf("invalid range %v", r)) + case r.Length() == 0: + return true + } + _, gap := s.Find(r.Start) + if !gap.Ok() { + return false + } + return r.End <= gap.End() +} + +// Span returns the total size of all segments in the set. +func (s *LockSet) Span() uint64 { + var sz uint64 + for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + sz += seg.Range().Length() + } + return sz +} + +// SpanRange returns the total size of the intersection of segments in the set +// with the given range. +func (s *LockSet) SpanRange(r LockRange) uint64 { + switch { + case r.Length() < 0: + panic(fmt.Sprintf("invalid range %v", r)) + case r.Length() == 0: + return 0 + } + var sz uint64 + for seg := s.LowerBoundSegment(r.Start); seg.Ok() && seg.Start() < r.End; seg = seg.NextSegment() { + sz += seg.Range().Intersect(r).Length() + } + return sz +} + +// FirstSegment returns the first segment in the set. If the set is empty, +// FirstSegment returns a terminal iterator. +func (s *LockSet) FirstSegment() LockIterator { + if s.root.nrSegments == 0 { + return LockIterator{} + } + return s.root.firstSegment() +} + +// LastSegment returns the last segment in the set. If the set is empty, +// LastSegment returns a terminal iterator. +func (s *LockSet) LastSegment() LockIterator { + if s.root.nrSegments == 0 { + return LockIterator{} + } + return s.root.lastSegment() +} + +// FirstGap returns the first gap in the set. +func (s *LockSet) FirstGap() LockGapIterator { + n := &s.root + for n.hasChildren { + n = n.children[0] + } + return LockGapIterator{n, 0} +} + +// LastGap returns the last gap in the set. +func (s *LockSet) LastGap() LockGapIterator { + n := &s.root + for n.hasChildren { + n = n.children[n.nrSegments] + } + return LockGapIterator{n, n.nrSegments} +} + +// Find returns the segment or gap whose range contains the given key. If a +// segment is found, the returned Iterator is non-terminal and the +// returned GapIterator is terminal. Otherwise, the returned Iterator is +// terminal and the returned GapIterator is non-terminal. +func (s *LockSet) Find(key uint64) (LockIterator, LockGapIterator) { + n := &s.root + for { + + lower := 0 + upper := n.nrSegments + for lower < upper { + i := lower + (upper-lower)/2 + if r := n.keys[i]; key < r.End { + if key >= r.Start { + return LockIterator{n, i}, LockGapIterator{} + } + upper = i + } else { + lower = i + 1 + } + } + i := lower + if !n.hasChildren { + return LockIterator{}, LockGapIterator{n, i} + } + n = n.children[i] + } +} + +// FindSegment returns the segment whose range contains the given key. If no +// such segment exists, FindSegment returns a terminal iterator. +func (s *LockSet) FindSegment(key uint64) LockIterator { + seg, _ := s.Find(key) + return seg +} + +// LowerBoundSegment returns the segment with the lowest range that contains a +// key greater than or equal to min. If no such segment exists, +// LowerBoundSegment returns a terminal iterator. +func (s *LockSet) LowerBoundSegment(min uint64) LockIterator { + seg, gap := s.Find(min) + if seg.Ok() { + return seg + } + return gap.NextSegment() +} + +// UpperBoundSegment returns the segment with the highest range that contains a +// key less than or equal to max. If no such segment exists, UpperBoundSegment +// returns a terminal iterator. +func (s *LockSet) UpperBoundSegment(max uint64) LockIterator { + seg, gap := s.Find(max) + if seg.Ok() { + return seg + } + return gap.PrevSegment() +} + +// FindGap returns the gap containing the given key. If no such gap exists +// (i.e. the set contains a segment containing that key), FindGap returns a +// terminal iterator. +func (s *LockSet) FindGap(key uint64) LockGapIterator { + _, gap := s.Find(key) + return gap +} + +// LowerBoundGap returns the gap with the lowest range that is greater than or +// equal to min. +func (s *LockSet) LowerBoundGap(min uint64) LockGapIterator { + seg, gap := s.Find(min) + if gap.Ok() { + return gap + } + return seg.NextGap() +} + +// UpperBoundGap returns the gap with the highest range that is less than or +// equal to max. +func (s *LockSet) UpperBoundGap(max uint64) LockGapIterator { + seg, gap := s.Find(max) + if gap.Ok() { + return gap + } + return seg.PrevGap() +} + +// Add inserts the given segment into the set and returns true. If the new +// segment can be merged with adjacent segments, Add will do so. If the new +// segment would overlap an existing segment, Add returns false. If Add +// succeeds, all existing iterators are invalidated. +func (s *LockSet) Add(r LockRange, val Lock) bool { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + gap := s.FindGap(r.Start) + if !gap.Ok() { + return false + } + if r.End > gap.End() { + return false + } + s.Insert(gap, r, val) + return true +} + +// AddWithoutMerging inserts the given segment into the set and returns true. +// If it would overlap an existing segment, AddWithoutMerging does nothing and +// returns false. If AddWithoutMerging succeeds, all existing iterators are +// invalidated. +func (s *LockSet) AddWithoutMerging(r LockRange, val Lock) bool { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + gap := s.FindGap(r.Start) + if !gap.Ok() { + return false + } + if r.End > gap.End() { + return false + } + s.InsertWithoutMergingUnchecked(gap, r, val) + return true +} + +// Insert inserts the given segment into the given gap. If the new segment can +// be merged with adjacent segments, Insert will do so. Insert returns an +// iterator to the segment containing the inserted value (which may have been +// merged with other values). All existing iterators (including gap, but not +// including the returned iterator) are invalidated. +// +// If the gap cannot accommodate the segment, or if r is invalid, Insert panics. +// +// Insert is semantically equivalent to a InsertWithoutMerging followed by a +// Merge, but may be more efficient. Note that there is no unchecked variant of +// Insert since Insert must retrieve and inspect gap's predecessor and +// successor segments regardless. +func (s *LockSet) Insert(gap LockGapIterator, r LockRange, val Lock) LockIterator { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + prev, next := gap.PrevSegment(), gap.NextSegment() + if prev.Ok() && prev.End() > r.Start { + panic(fmt.Sprintf("new segment %v overlaps predecessor %v", r, prev.Range())) + } + if next.Ok() && next.Start() < r.End { + panic(fmt.Sprintf("new segment %v overlaps successor %v", r, next.Range())) + } + if prev.Ok() && prev.End() == r.Start { + if mval, ok := (lockSetFunctions{}).Merge(prev.Range(), prev.Value(), r, val); ok { + shrinkMaxGap := LocktrackGaps != 0 && gap.Range().Length() == gap.node.maxGap.Get() + prev.SetEndUnchecked(r.End) + prev.SetValue(mval) + if shrinkMaxGap { + gap.node.updateMaxGapLeaf() + } + if next.Ok() && next.Start() == r.End { + val = mval + if mval, ok := (lockSetFunctions{}).Merge(prev.Range(), val, next.Range(), next.Value()); ok { + prev.SetEndUnchecked(next.End()) + prev.SetValue(mval) + return s.Remove(next).PrevSegment() + } + } + return prev + } + } + if next.Ok() && next.Start() == r.End { + if mval, ok := (lockSetFunctions{}).Merge(r, val, next.Range(), next.Value()); ok { + shrinkMaxGap := LocktrackGaps != 0 && gap.Range().Length() == gap.node.maxGap.Get() + next.SetStartUnchecked(r.Start) + next.SetValue(mval) + if shrinkMaxGap { + gap.node.updateMaxGapLeaf() + } + return next + } + } + + return s.InsertWithoutMergingUnchecked(gap, r, val) +} + +// InsertWithoutMerging inserts the given segment into the given gap and +// returns an iterator to the inserted segment. All existing iterators +// (including gap, but not including the returned iterator) are invalidated. +// +// If the gap cannot accommodate the segment, or if r is invalid, +// InsertWithoutMerging panics. +func (s *LockSet) InsertWithoutMerging(gap LockGapIterator, r LockRange, val Lock) LockIterator { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + if gr := gap.Range(); !gr.IsSupersetOf(r) { + panic(fmt.Sprintf("cannot insert segment range %v into gap range %v", r, gr)) + } + return s.InsertWithoutMergingUnchecked(gap, r, val) +} + +// InsertWithoutMergingUnchecked inserts the given segment into the given gap +// and returns an iterator to the inserted segment. All existing iterators +// (including gap, but not including the returned iterator) are invalidated. +// +// Preconditions: +// * r.Start >= gap.Start(). +// * r.End <= gap.End(). +func (s *LockSet) InsertWithoutMergingUnchecked(gap LockGapIterator, r LockRange, val Lock) LockIterator { + gap = gap.node.rebalanceBeforeInsert(gap) + splitMaxGap := LocktrackGaps != 0 && (gap.node.nrSegments == 0 || gap.Range().Length() == gap.node.maxGap.Get()) + copy(gap.node.keys[gap.index+1:], gap.node.keys[gap.index:gap.node.nrSegments]) + copy(gap.node.values[gap.index+1:], gap.node.values[gap.index:gap.node.nrSegments]) + gap.node.keys[gap.index] = r + gap.node.values[gap.index] = val + gap.node.nrSegments++ + if splitMaxGap { + gap.node.updateMaxGapLeaf() + } + return LockIterator{gap.node, gap.index} +} + +// Remove removes the given segment and returns an iterator to the vacated gap. +// All existing iterators (including seg, but not including the returned +// iterator) are invalidated. +func (s *LockSet) Remove(seg LockIterator) LockGapIterator { + + if seg.node.hasChildren { + + victim := seg.PrevSegment() + + seg.SetRangeUnchecked(victim.Range()) + seg.SetValue(victim.Value()) + + nextAdjacentNode := seg.NextSegment().node + if LocktrackGaps != 0 { + nextAdjacentNode.updateMaxGapLeaf() + } + return s.Remove(victim).NextGap() + } + copy(seg.node.keys[seg.index:], seg.node.keys[seg.index+1:seg.node.nrSegments]) + copy(seg.node.values[seg.index:], seg.node.values[seg.index+1:seg.node.nrSegments]) + lockSetFunctions{}.ClearValue(&seg.node.values[seg.node.nrSegments-1]) + seg.node.nrSegments-- + if LocktrackGaps != 0 { + seg.node.updateMaxGapLeaf() + } + return seg.node.rebalanceAfterRemove(LockGapIterator{seg.node, seg.index}) +} + +// RemoveAll removes all segments from the set. All existing iterators are +// invalidated. +func (s *LockSet) RemoveAll() { + s.root = Locknode{} +} + +// RemoveRange removes all segments in the given range. An iterator to the +// newly formed gap is returned, and all existing iterators are invalidated. +func (s *LockSet) RemoveRange(r LockRange) LockGapIterator { + seg, gap := s.Find(r.Start) + if seg.Ok() { + seg = s.Isolate(seg, r) + gap = s.Remove(seg) + } + for seg = gap.NextSegment(); seg.Ok() && seg.Start() < r.End; seg = gap.NextSegment() { + seg = s.Isolate(seg, r) + gap = s.Remove(seg) + } + return gap +} + +// Merge attempts to merge two neighboring segments. If successful, Merge +// returns an iterator to the merged segment, and all existing iterators are +// invalidated. Otherwise, Merge returns a terminal iterator. +// +// If first is not the predecessor of second, Merge panics. +func (s *LockSet) Merge(first, second LockIterator) LockIterator { + if first.NextSegment() != second { + panic(fmt.Sprintf("attempt to merge non-neighboring segments %v, %v", first.Range(), second.Range())) + } + return s.MergeUnchecked(first, second) +} + +// MergeUnchecked attempts to merge two neighboring segments. If successful, +// MergeUnchecked returns an iterator to the merged segment, and all existing +// iterators are invalidated. Otherwise, MergeUnchecked returns a terminal +// iterator. +// +// Precondition: first is the predecessor of second: first.NextSegment() == +// second, first == second.PrevSegment(). +func (s *LockSet) MergeUnchecked(first, second LockIterator) LockIterator { + if first.End() == second.Start() { + if mval, ok := (lockSetFunctions{}).Merge(first.Range(), first.Value(), second.Range(), second.Value()); ok { + + first.SetEndUnchecked(second.End()) + first.SetValue(mval) + + return s.Remove(second).PrevSegment() + } + } + return LockIterator{} +} + +// MergeAll attempts to merge all adjacent segments in the set. All existing +// iterators are invalidated. +func (s *LockSet) MergeAll() { + seg := s.FirstSegment() + if !seg.Ok() { + return + } + next := seg.NextSegment() + for next.Ok() { + if mseg := s.MergeUnchecked(seg, next); mseg.Ok() { + seg, next = mseg, mseg.NextSegment() + } else { + seg, next = next, next.NextSegment() + } + } +} + +// MergeRange attempts to merge all adjacent segments that contain a key in the +// specific range. All existing iterators are invalidated. +func (s *LockSet) MergeRange(r LockRange) { + seg := s.LowerBoundSegment(r.Start) + if !seg.Ok() { + return + } + next := seg.NextSegment() + for next.Ok() && next.Range().Start < r.End { + if mseg := s.MergeUnchecked(seg, next); mseg.Ok() { + seg, next = mseg, mseg.NextSegment() + } else { + seg, next = next, next.NextSegment() + } + } +} + +// MergeAdjacent attempts to merge the segment containing r.Start with its +// predecessor, and the segment containing r.End-1 with its successor. +func (s *LockSet) MergeAdjacent(r LockRange) { + first := s.FindSegment(r.Start) + if first.Ok() { + if prev := first.PrevSegment(); prev.Ok() { + s.Merge(prev, first) + } + } + last := s.FindSegment(r.End - 1) + if last.Ok() { + if next := last.NextSegment(); next.Ok() { + s.Merge(last, next) + } + } +} + +// Split splits the given segment at the given key and returns iterators to the +// two resulting segments. All existing iterators (including seg, but not +// including the returned iterators) are invalidated. +// +// If the segment cannot be split at split (because split is at the start or +// end of the segment's range, so splitting would produce a segment with zero +// length, or because split falls outside the segment's range altogether), +// Split panics. +func (s *LockSet) Split(seg LockIterator, split uint64) (LockIterator, LockIterator) { + if !seg.Range().CanSplitAt(split) { + panic(fmt.Sprintf("can't split %v at %v", seg.Range(), split)) + } + return s.SplitUnchecked(seg, split) +} + +// SplitUnchecked splits the given segment at the given key and returns +// iterators to the two resulting segments. All existing iterators (including +// seg, but not including the returned iterators) are invalidated. +// +// Preconditions: seg.Start() < key < seg.End(). +func (s *LockSet) SplitUnchecked(seg LockIterator, split uint64) (LockIterator, LockIterator) { + val1, val2 := (lockSetFunctions{}).Split(seg.Range(), seg.Value(), split) + end2 := seg.End() + seg.SetEndUnchecked(split) + seg.SetValue(val1) + seg2 := s.InsertWithoutMergingUnchecked(seg.NextGap(), LockRange{split, end2}, val2) + + return seg2.PrevSegment(), seg2 +} + +// SplitAt splits the segment straddling split, if one exists. SplitAt returns +// true if a segment was split and false otherwise. If SplitAt splits a +// segment, all existing iterators are invalidated. +func (s *LockSet) SplitAt(split uint64) bool { + if seg := s.FindSegment(split); seg.Ok() && seg.Range().CanSplitAt(split) { + s.SplitUnchecked(seg, split) + return true + } + return false +} + +// Isolate ensures that the given segment's range does not escape r by +// splitting at r.Start and r.End if necessary, and returns an updated iterator +// to the bounded segment. All existing iterators (including seg, but not +// including the returned iterators) are invalidated. +func (s *LockSet) Isolate(seg LockIterator, r LockRange) LockIterator { + if seg.Range().CanSplitAt(r.Start) { + _, seg = s.SplitUnchecked(seg, r.Start) + } + if seg.Range().CanSplitAt(r.End) { + seg, _ = s.SplitUnchecked(seg, r.End) + } + return seg +} + +// ApplyContiguous applies a function to a contiguous range of segments, +// splitting if necessary. The function is applied until the first gap is +// encountered, at which point the gap is returned. If the function is applied +// across the entire range, a terminal gap is returned. All existing iterators +// are invalidated. +// +// N.B. The Iterator must not be invalidated by the function. +func (s *LockSet) ApplyContiguous(r LockRange, fn func(seg LockIterator)) LockGapIterator { + seg, gap := s.Find(r.Start) + if !seg.Ok() { + return gap + } + for { + seg = s.Isolate(seg, r) + fn(seg) + if seg.End() >= r.End { + return LockGapIterator{} + } + gap = seg.NextGap() + if !gap.IsEmpty() { + return gap + } + seg = gap.NextSegment() + if !seg.Ok() { + + return LockGapIterator{} + } + } +} + +// +stateify savable +type Locknode struct { + // An internal binary tree node looks like: + // + // K + // / \ + // Cl Cr + // + // where all keys in the subtree rooted by Cl (the left subtree) are less + // than K (the key of the parent node), and all keys in the subtree rooted + // by Cr (the right subtree) are greater than K. + // + // An internal B-tree node's indexes work out to look like: + // + // K0 K1 K2 ... Kn-1 + // / \/ \/ \ ... / \ + // C0 C1 C2 C3 ... Cn-1 Cn + // + // where n is nrSegments. + nrSegments int + + // parent is a pointer to this node's parent. If this node is root, parent + // is nil. + parent *Locknode + + // parentIndex is the index of this node in parent.children. + parentIndex int + + // Flag for internal nodes that is technically redundant with "children[0] + // != nil", but is stored in the first cache line. "hasChildren" rather + // than "isLeaf" because false must be the correct value for an empty root. + hasChildren bool + + // The longest gap within this node. If the node is a leaf, it's simply the + // maximum gap among all the (nrSegments+1) gaps formed by its nrSegments keys + // including the 0th and nrSegments-th gap possibly shared with its upper-level + // nodes; if it's a non-leaf node, it's the max of all children's maxGap. + maxGap LockdynamicGap + + // Nodes store keys and values in separate arrays to maximize locality in + // the common case (scanning keys for lookup). + keys [LockmaxDegree - 1]LockRange + values [LockmaxDegree - 1]Lock + children [LockmaxDegree]*Locknode +} + +// firstSegment returns the first segment in the subtree rooted by n. +// +// Preconditions: n.nrSegments != 0. +func (n *Locknode) firstSegment() LockIterator { + for n.hasChildren { + n = n.children[0] + } + return LockIterator{n, 0} +} + +// lastSegment returns the last segment in the subtree rooted by n. +// +// Preconditions: n.nrSegments != 0. +func (n *Locknode) lastSegment() LockIterator { + for n.hasChildren { + n = n.children[n.nrSegments] + } + return LockIterator{n, n.nrSegments - 1} +} + +func (n *Locknode) prevSibling() *Locknode { + if n.parent == nil || n.parentIndex == 0 { + return nil + } + return n.parent.children[n.parentIndex-1] +} + +func (n *Locknode) nextSibling() *Locknode { + if n.parent == nil || n.parentIndex == n.parent.nrSegments { + return nil + } + return n.parent.children[n.parentIndex+1] +} + +// rebalanceBeforeInsert splits n and its ancestors if they are full, as +// required for insertion, and returns an updated iterator to the position +// represented by gap. +func (n *Locknode) rebalanceBeforeInsert(gap LockGapIterator) LockGapIterator { + if n.nrSegments < LockmaxDegree-1 { + return gap + } + if n.parent != nil { + gap = n.parent.rebalanceBeforeInsert(gap) + } + if n.parent == nil { + + left := &Locknode{ + nrSegments: LockminDegree - 1, + parent: n, + parentIndex: 0, + hasChildren: n.hasChildren, + } + right := &Locknode{ + nrSegments: LockminDegree - 1, + parent: n, + parentIndex: 1, + hasChildren: n.hasChildren, + } + copy(left.keys[:LockminDegree-1], n.keys[:LockminDegree-1]) + copy(left.values[:LockminDegree-1], n.values[:LockminDegree-1]) + copy(right.keys[:LockminDegree-1], n.keys[LockminDegree:]) + copy(right.values[:LockminDegree-1], n.values[LockminDegree:]) + n.keys[0], n.values[0] = n.keys[LockminDegree-1], n.values[LockminDegree-1] + LockzeroValueSlice(n.values[1:]) + if n.hasChildren { + copy(left.children[:LockminDegree], n.children[:LockminDegree]) + copy(right.children[:LockminDegree], n.children[LockminDegree:]) + LockzeroNodeSlice(n.children[2:]) + for i := 0; i < LockminDegree; i++ { + left.children[i].parent = left + left.children[i].parentIndex = i + right.children[i].parent = right + right.children[i].parentIndex = i + } + } + n.nrSegments = 1 + n.hasChildren = true + n.children[0] = left + n.children[1] = right + + if LocktrackGaps != 0 { + left.updateMaxGapLocal() + right.updateMaxGapLocal() + } + if gap.node != n { + return gap + } + if gap.index < LockminDegree { + return LockGapIterator{left, gap.index} + } + return LockGapIterator{right, gap.index - LockminDegree} + } + + copy(n.parent.keys[n.parentIndex+1:], n.parent.keys[n.parentIndex:n.parent.nrSegments]) + copy(n.parent.values[n.parentIndex+1:], n.parent.values[n.parentIndex:n.parent.nrSegments]) + n.parent.keys[n.parentIndex], n.parent.values[n.parentIndex] = n.keys[LockminDegree-1], n.values[LockminDegree-1] + copy(n.parent.children[n.parentIndex+2:], n.parent.children[n.parentIndex+1:n.parent.nrSegments+1]) + for i := n.parentIndex + 2; i < n.parent.nrSegments+2; i++ { + n.parent.children[i].parentIndex = i + } + sibling := &Locknode{ + nrSegments: LockminDegree - 1, + parent: n.parent, + parentIndex: n.parentIndex + 1, + hasChildren: n.hasChildren, + } + n.parent.children[n.parentIndex+1] = sibling + n.parent.nrSegments++ + copy(sibling.keys[:LockminDegree-1], n.keys[LockminDegree:]) + copy(sibling.values[:LockminDegree-1], n.values[LockminDegree:]) + LockzeroValueSlice(n.values[LockminDegree-1:]) + if n.hasChildren { + copy(sibling.children[:LockminDegree], n.children[LockminDegree:]) + LockzeroNodeSlice(n.children[LockminDegree:]) + for i := 0; i < LockminDegree; i++ { + sibling.children[i].parent = sibling + sibling.children[i].parentIndex = i + } + } + n.nrSegments = LockminDegree - 1 + + if LocktrackGaps != 0 { + n.updateMaxGapLocal() + sibling.updateMaxGapLocal() + } + + if gap.node != n { + return gap + } + if gap.index < LockminDegree { + return gap + } + return LockGapIterator{sibling, gap.index - LockminDegree} +} + +// rebalanceAfterRemove "unsplits" n and its ancestors if they are deficient +// (contain fewer segments than required by B-tree invariants), as required for +// removal, and returns an updated iterator to the position represented by gap. +// +// Precondition: n is the only node in the tree that may currently violate a +// B-tree invariant. +func (n *Locknode) rebalanceAfterRemove(gap LockGapIterator) LockGapIterator { + for { + if n.nrSegments >= LockminDegree-1 { + return gap + } + if n.parent == nil { + + return gap + } + + if sibling := n.prevSibling(); sibling != nil && sibling.nrSegments >= LockminDegree { + copy(n.keys[1:], n.keys[:n.nrSegments]) + copy(n.values[1:], n.values[:n.nrSegments]) + n.keys[0] = n.parent.keys[n.parentIndex-1] + n.values[0] = n.parent.values[n.parentIndex-1] + n.parent.keys[n.parentIndex-1] = sibling.keys[sibling.nrSegments-1] + n.parent.values[n.parentIndex-1] = sibling.values[sibling.nrSegments-1] + lockSetFunctions{}.ClearValue(&sibling.values[sibling.nrSegments-1]) + if n.hasChildren { + copy(n.children[1:], n.children[:n.nrSegments+1]) + n.children[0] = sibling.children[sibling.nrSegments] + sibling.children[sibling.nrSegments] = nil + n.children[0].parent = n + n.children[0].parentIndex = 0 + for i := 1; i < n.nrSegments+2; i++ { + n.children[i].parentIndex = i + } + } + n.nrSegments++ + sibling.nrSegments-- + + if LocktrackGaps != 0 { + n.updateMaxGapLocal() + sibling.updateMaxGapLocal() + } + if gap.node == sibling && gap.index == sibling.nrSegments { + return LockGapIterator{n, 0} + } + if gap.node == n { + return LockGapIterator{n, gap.index + 1} + } + return gap + } + if sibling := n.nextSibling(); sibling != nil && sibling.nrSegments >= LockminDegree { + n.keys[n.nrSegments] = n.parent.keys[n.parentIndex] + n.values[n.nrSegments] = n.parent.values[n.parentIndex] + n.parent.keys[n.parentIndex] = sibling.keys[0] + n.parent.values[n.parentIndex] = sibling.values[0] + copy(sibling.keys[:sibling.nrSegments-1], sibling.keys[1:]) + copy(sibling.values[:sibling.nrSegments-1], sibling.values[1:]) + lockSetFunctions{}.ClearValue(&sibling.values[sibling.nrSegments-1]) + if n.hasChildren { + n.children[n.nrSegments+1] = sibling.children[0] + copy(sibling.children[:sibling.nrSegments], sibling.children[1:]) + sibling.children[sibling.nrSegments] = nil + n.children[n.nrSegments+1].parent = n + n.children[n.nrSegments+1].parentIndex = n.nrSegments + 1 + for i := 0; i < sibling.nrSegments; i++ { + sibling.children[i].parentIndex = i + } + } + n.nrSegments++ + sibling.nrSegments-- + + if LocktrackGaps != 0 { + n.updateMaxGapLocal() + sibling.updateMaxGapLocal() + } + if gap.node == sibling { + if gap.index == 0 { + return LockGapIterator{n, n.nrSegments} + } + return LockGapIterator{sibling, gap.index - 1} + } + return gap + } + + p := n.parent + if p.nrSegments == 1 { + + left, right := p.children[0], p.children[1] + p.nrSegments = left.nrSegments + right.nrSegments + 1 + p.hasChildren = left.hasChildren + p.keys[left.nrSegments] = p.keys[0] + p.values[left.nrSegments] = p.values[0] + copy(p.keys[:left.nrSegments], left.keys[:left.nrSegments]) + copy(p.values[:left.nrSegments], left.values[:left.nrSegments]) + copy(p.keys[left.nrSegments+1:], right.keys[:right.nrSegments]) + copy(p.values[left.nrSegments+1:], right.values[:right.nrSegments]) + if left.hasChildren { + copy(p.children[:left.nrSegments+1], left.children[:left.nrSegments+1]) + copy(p.children[left.nrSegments+1:], right.children[:right.nrSegments+1]) + for i := 0; i < p.nrSegments+1; i++ { + p.children[i].parent = p + p.children[i].parentIndex = i + } + } else { + p.children[0] = nil + p.children[1] = nil + } + + if gap.node == left { + return LockGapIterator{p, gap.index} + } + if gap.node == right { + return LockGapIterator{p, gap.index + left.nrSegments + 1} + } + return gap + } + // Merge n and either sibling, along with the segment separating the + // two, into whichever of the two nodes comes first. This is the + // reverse of the non-root splitting case in + // node.rebalanceBeforeInsert. + var left, right *Locknode + if n.parentIndex > 0 { + left = n.prevSibling() + right = n + } else { + left = n + right = n.nextSibling() + } + + if gap.node == right { + gap = LockGapIterator{left, gap.index + left.nrSegments + 1} + } + left.keys[left.nrSegments] = p.keys[left.parentIndex] + left.values[left.nrSegments] = p.values[left.parentIndex] + copy(left.keys[left.nrSegments+1:], right.keys[:right.nrSegments]) + copy(left.values[left.nrSegments+1:], right.values[:right.nrSegments]) + if left.hasChildren { + copy(left.children[left.nrSegments+1:], right.children[:right.nrSegments+1]) + for i := left.nrSegments + 1; i < left.nrSegments+right.nrSegments+2; i++ { + left.children[i].parent = left + left.children[i].parentIndex = i + } + } + left.nrSegments += right.nrSegments + 1 + copy(p.keys[left.parentIndex:], p.keys[left.parentIndex+1:p.nrSegments]) + copy(p.values[left.parentIndex:], p.values[left.parentIndex+1:p.nrSegments]) + lockSetFunctions{}.ClearValue(&p.values[p.nrSegments-1]) + copy(p.children[left.parentIndex+1:], p.children[left.parentIndex+2:p.nrSegments+1]) + for i := 0; i < p.nrSegments; i++ { + p.children[i].parentIndex = i + } + p.children[p.nrSegments] = nil + p.nrSegments-- + + if LocktrackGaps != 0 { + left.updateMaxGapLocal() + } + + n = p + } +} + +// updateMaxGapLeaf updates maxGap bottom-up from the calling leaf until no +// necessary update. +// +// Preconditions: n must be a leaf node, trackGaps must be 1. +func (n *Locknode) updateMaxGapLeaf() { + if n.hasChildren { + panic(fmt.Sprintf("updateMaxGapLeaf should always be called on leaf node: %v", n)) + } + max := n.calculateMaxGapLeaf() + if max == n.maxGap.Get() { + + return + } + oldMax := n.maxGap.Get() + n.maxGap.Set(max) + if max > oldMax { + + for p := n.parent; p != nil; p = p.parent { + if p.maxGap.Get() >= max { + + break + } + + p.maxGap.Set(max) + } + return + } + + for p := n.parent; p != nil; p = p.parent { + if p.maxGap.Get() > oldMax { + + break + } + + parentNewMax := p.calculateMaxGapInternal() + if p.maxGap.Get() == parentNewMax { + + break + } + + p.maxGap.Set(parentNewMax) + } +} + +// updateMaxGapLocal updates maxGap of the calling node solely with no +// propagation to ancestor nodes. +// +// Precondition: trackGaps must be 1. +func (n *Locknode) updateMaxGapLocal() { + if !n.hasChildren { + + n.maxGap.Set(n.calculateMaxGapLeaf()) + } else { + + n.maxGap.Set(n.calculateMaxGapInternal()) + } +} + +// calculateMaxGapLeaf iterates the gaps within a leaf node and calculate the +// max. +// +// Preconditions: n must be a leaf node. +func (n *Locknode) calculateMaxGapLeaf() uint64 { + max := LockGapIterator{n, 0}.Range().Length() + for i := 1; i <= n.nrSegments; i++ { + if current := (LockGapIterator{n, i}).Range().Length(); current > max { + max = current + } + } + return max +} + +// calculateMaxGapInternal iterates children's maxGap within an internal node n +// and calculate the max. +// +// Preconditions: n must be a non-leaf node. +func (n *Locknode) calculateMaxGapInternal() uint64 { + max := n.children[0].maxGap.Get() + for i := 1; i <= n.nrSegments; i++ { + if current := n.children[i].maxGap.Get(); current > max { + max = current + } + } + return max +} + +// searchFirstLargeEnoughGap returns the first gap having at least minSize length +// in the subtree rooted by n. If not found, return a terminal gap iterator. +func (n *Locknode) searchFirstLargeEnoughGap(minSize uint64) LockGapIterator { + if n.maxGap.Get() < minSize { + return LockGapIterator{} + } + if n.hasChildren { + for i := 0; i <= n.nrSegments; i++ { + if largeEnoughGap := n.children[i].searchFirstLargeEnoughGap(minSize); largeEnoughGap.Ok() { + return largeEnoughGap + } + } + } else { + for i := 0; i <= n.nrSegments; i++ { + currentGap := LockGapIterator{n, i} + if currentGap.Range().Length() >= minSize { + return currentGap + } + } + } + panic(fmt.Sprintf("invalid maxGap in %v", n)) +} + +// searchLastLargeEnoughGap returns the last gap having at least minSize length +// in the subtree rooted by n. If not found, return a terminal gap iterator. +func (n *Locknode) searchLastLargeEnoughGap(minSize uint64) LockGapIterator { + if n.maxGap.Get() < minSize { + return LockGapIterator{} + } + if n.hasChildren { + for i := n.nrSegments; i >= 0; i-- { + if largeEnoughGap := n.children[i].searchLastLargeEnoughGap(minSize); largeEnoughGap.Ok() { + return largeEnoughGap + } + } + } else { + for i := n.nrSegments; i >= 0; i-- { + currentGap := LockGapIterator{n, i} + if currentGap.Range().Length() >= minSize { + return currentGap + } + } + } + panic(fmt.Sprintf("invalid maxGap in %v", n)) +} + +// A Iterator is conceptually one of: +// +// - A pointer to a segment in a set; or +// +// - A terminal iterator, which is a sentinel indicating that the end of +// iteration has been reached. +// +// Iterators are copyable values and are meaningfully equality-comparable. The +// zero value of Iterator is a terminal iterator. +// +// Unless otherwise specified, any mutation of a set invalidates all existing +// iterators into the set. +type LockIterator struct { + // node is the node containing the iterated segment. If the iterator is + // terminal, node is nil. + node *Locknode + + // index is the index of the segment in node.keys/values. + index int +} + +// Ok returns true if the iterator is not terminal. All other methods are only +// valid for non-terminal iterators. +func (seg LockIterator) Ok() bool { + return seg.node != nil +} + +// Range returns the iterated segment's range key. +func (seg LockIterator) Range() LockRange { + return seg.node.keys[seg.index] +} + +// Start is equivalent to Range().Start, but should be preferred if only the +// start of the range is needed. +func (seg LockIterator) Start() uint64 { + return seg.node.keys[seg.index].Start +} + +// End is equivalent to Range().End, but should be preferred if only the end of +// the range is needed. +func (seg LockIterator) End() uint64 { + return seg.node.keys[seg.index].End +} + +// SetRangeUnchecked mutates the iterated segment's range key. This operation +// does not invalidate any iterators. +// +// Preconditions: +// * r.Length() > 0. +// * The new range must not overlap an existing one: +// * If seg.NextSegment().Ok(), then r.end <= seg.NextSegment().Start(). +// * If seg.PrevSegment().Ok(), then r.start >= seg.PrevSegment().End(). +func (seg LockIterator) SetRangeUnchecked(r LockRange) { + seg.node.keys[seg.index] = r +} + +// SetRange mutates the iterated segment's range key. If the new range would +// cause the iterated segment to overlap another segment, or if the new range +// is invalid, SetRange panics. This operation does not invalidate any +// iterators. +func (seg LockIterator) SetRange(r LockRange) { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + if prev := seg.PrevSegment(); prev.Ok() && r.Start < prev.End() { + panic(fmt.Sprintf("new segment range %v overlaps segment range %v", r, prev.Range())) + } + if next := seg.NextSegment(); next.Ok() && r.End > next.Start() { + panic(fmt.Sprintf("new segment range %v overlaps segment range %v", r, next.Range())) + } + seg.SetRangeUnchecked(r) +} + +// SetStartUnchecked mutates the iterated segment's start. This operation does +// not invalidate any iterators. +// +// Preconditions: The new start must be valid: +// * start < seg.End() +// * If seg.PrevSegment().Ok(), then start >= seg.PrevSegment().End(). +func (seg LockIterator) SetStartUnchecked(start uint64) { + seg.node.keys[seg.index].Start = start +} + +// SetStart mutates the iterated segment's start. If the new start value would +// cause the iterated segment to overlap another segment, or would result in an +// invalid range, SetStart panics. This operation does not invalidate any +// iterators. +func (seg LockIterator) SetStart(start uint64) { + if start >= seg.End() { + panic(fmt.Sprintf("new start %v would invalidate segment range %v", start, seg.Range())) + } + if prev := seg.PrevSegment(); prev.Ok() && start < prev.End() { + panic(fmt.Sprintf("new start %v would cause segment range %v to overlap segment range %v", start, seg.Range(), prev.Range())) + } + seg.SetStartUnchecked(start) +} + +// SetEndUnchecked mutates the iterated segment's end. This operation does not +// invalidate any iterators. +// +// Preconditions: The new end must be valid: +// * end > seg.Start(). +// * If seg.NextSegment().Ok(), then end <= seg.NextSegment().Start(). +func (seg LockIterator) SetEndUnchecked(end uint64) { + seg.node.keys[seg.index].End = end +} + +// SetEnd mutates the iterated segment's end. If the new end value would cause +// the iterated segment to overlap another segment, or would result in an +// invalid range, SetEnd panics. This operation does not invalidate any +// iterators. +func (seg LockIterator) SetEnd(end uint64) { + if end <= seg.Start() { + panic(fmt.Sprintf("new end %v would invalidate segment range %v", end, seg.Range())) + } + if next := seg.NextSegment(); next.Ok() && end > next.Start() { + panic(fmt.Sprintf("new end %v would cause segment range %v to overlap segment range %v", end, seg.Range(), next.Range())) + } + seg.SetEndUnchecked(end) +} + +// Value returns a copy of the iterated segment's value. +func (seg LockIterator) Value() Lock { + return seg.node.values[seg.index] +} + +// ValuePtr returns a pointer to the iterated segment's value. The pointer is +// invalidated if the iterator is invalidated. This operation does not +// invalidate any iterators. +func (seg LockIterator) ValuePtr() *Lock { + return &seg.node.values[seg.index] +} + +// SetValue mutates the iterated segment's value. This operation does not +// invalidate any iterators. +func (seg LockIterator) SetValue(val Lock) { + seg.node.values[seg.index] = val +} + +// PrevSegment returns the iterated segment's predecessor. If there is no +// preceding segment, PrevSegment returns a terminal iterator. +func (seg LockIterator) PrevSegment() LockIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index].lastSegment() + } + if seg.index > 0 { + return LockIterator{seg.node, seg.index - 1} + } + if seg.node.parent == nil { + return LockIterator{} + } + return LocksegmentBeforePosition(seg.node.parent, seg.node.parentIndex) +} + +// NextSegment returns the iterated segment's successor. If there is no +// succeeding segment, NextSegment returns a terminal iterator. +func (seg LockIterator) NextSegment() LockIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index+1].firstSegment() + } + if seg.index < seg.node.nrSegments-1 { + return LockIterator{seg.node, seg.index + 1} + } + if seg.node.parent == nil { + return LockIterator{} + } + return LocksegmentAfterPosition(seg.node.parent, seg.node.parentIndex) +} + +// PrevGap returns the gap immediately before the iterated segment. +func (seg LockIterator) PrevGap() LockGapIterator { + if seg.node.hasChildren { + + return seg.node.children[seg.index].lastSegment().NextGap() + } + return LockGapIterator{seg.node, seg.index} +} + +// NextGap returns the gap immediately after the iterated segment. +func (seg LockIterator) NextGap() LockGapIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index+1].firstSegment().PrevGap() + } + return LockGapIterator{seg.node, seg.index + 1} +} + +// PrevNonEmpty returns the iterated segment's predecessor if it is adjacent, +// or the gap before the iterated segment otherwise. If seg.Start() == +// Functions.MinKey(), PrevNonEmpty will return two terminal iterators. +// Otherwise, exactly one of the iterators returned by PrevNonEmpty will be +// non-terminal. +func (seg LockIterator) PrevNonEmpty() (LockIterator, LockGapIterator) { + gap := seg.PrevGap() + if gap.Range().Length() != 0 { + return LockIterator{}, gap + } + return gap.PrevSegment(), LockGapIterator{} +} + +// NextNonEmpty returns the iterated segment's successor if it is adjacent, or +// the gap after the iterated segment otherwise. If seg.End() == +// Functions.MaxKey(), NextNonEmpty will return two terminal iterators. +// Otherwise, exactly one of the iterators returned by NextNonEmpty will be +// non-terminal. +func (seg LockIterator) NextNonEmpty() (LockIterator, LockGapIterator) { + gap := seg.NextGap() + if gap.Range().Length() != 0 { + return LockIterator{}, gap + } + return gap.NextSegment(), LockGapIterator{} +} + +// A GapIterator is conceptually one of: +// +// - A pointer to a position between two segments, before the first segment, or +// after the last segment in a set, called a *gap*; or +// +// - A terminal iterator, which is a sentinel indicating that the end of +// iteration has been reached. +// +// Note that the gap between two adjacent segments exists (iterators to it are +// non-terminal), but has a length of zero. GapIterator.IsEmpty returns true +// for such gaps. An empty set contains a single gap, spanning the entire range +// of the set's keys. +// +// GapIterators are copyable values and are meaningfully equality-comparable. +// The zero value of GapIterator is a terminal iterator. +// +// Unless otherwise specified, any mutation of a set invalidates all existing +// iterators into the set. +type LockGapIterator struct { + // The representation of a GapIterator is identical to that of an Iterator, + // except that index corresponds to positions between segments in the same + // way as for node.children (see comment for node.nrSegments). + node *Locknode + index int +} + +// Ok returns true if the iterator is not terminal. All other methods are only +// valid for non-terminal iterators. +func (gap LockGapIterator) Ok() bool { + return gap.node != nil +} + +// Range returns the range spanned by the iterated gap. +func (gap LockGapIterator) Range() LockRange { + return LockRange{gap.Start(), gap.End()} +} + +// Start is equivalent to Range().Start, but should be preferred if only the +// start of the range is needed. +func (gap LockGapIterator) Start() uint64 { + if ps := gap.PrevSegment(); ps.Ok() { + return ps.End() + } + return lockSetFunctions{}.MinKey() +} + +// End is equivalent to Range().End, but should be preferred if only the end of +// the range is needed. +func (gap LockGapIterator) End() uint64 { + if ns := gap.NextSegment(); ns.Ok() { + return ns.Start() + } + return lockSetFunctions{}.MaxKey() +} + +// IsEmpty returns true if the iterated gap is empty (that is, the "gap" is +// between two adjacent segments.) +func (gap LockGapIterator) IsEmpty() bool { + return gap.Range().Length() == 0 +} + +// PrevSegment returns the segment immediately before the iterated gap. If no +// such segment exists, PrevSegment returns a terminal iterator. +func (gap LockGapIterator) PrevSegment() LockIterator { + return LocksegmentBeforePosition(gap.node, gap.index) +} + +// NextSegment returns the segment immediately after the iterated gap. If no +// such segment exists, NextSegment returns a terminal iterator. +func (gap LockGapIterator) NextSegment() LockIterator { + return LocksegmentAfterPosition(gap.node, gap.index) +} + +// PrevGap returns the iterated gap's predecessor. If no such gap exists, +// PrevGap returns a terminal iterator. +func (gap LockGapIterator) PrevGap() LockGapIterator { + seg := gap.PrevSegment() + if !seg.Ok() { + return LockGapIterator{} + } + return seg.PrevGap() +} + +// NextGap returns the iterated gap's successor. If no such gap exists, NextGap +// returns a terminal iterator. +func (gap LockGapIterator) NextGap() LockGapIterator { + seg := gap.NextSegment() + if !seg.Ok() { + return LockGapIterator{} + } + return seg.NextGap() +} + +// NextLargeEnoughGap returns the iterated gap's first next gap with larger +// length than minSize. If not found, return a terminal gap iterator (does NOT +// include this gap itself). +// +// Precondition: trackGaps must be 1. +func (gap LockGapIterator) NextLargeEnoughGap(minSize uint64) LockGapIterator { + if LocktrackGaps != 1 { + panic("set is not tracking gaps") + } + if gap.node != nil && gap.node.hasChildren && gap.index == gap.node.nrSegments { + + gap.node = gap.NextSegment().node + gap.index = 0 + return gap.nextLargeEnoughGapHelper(minSize) + } + return gap.nextLargeEnoughGapHelper(minSize) +} + +// nextLargeEnoughGapHelper is the helper function used by NextLargeEnoughGap +// to do the real recursions. +// +// Preconditions: gap is NOT the trailing gap of a non-leaf node. +func (gap LockGapIterator) nextLargeEnoughGapHelper(minSize uint64) LockGapIterator { + + for gap.node != nil && + (gap.node.maxGap.Get() < minSize || (!gap.node.hasChildren && gap.index == gap.node.nrSegments)) { + gap.node, gap.index = gap.node.parent, gap.node.parentIndex + } + + if gap.node == nil { + return LockGapIterator{} + } + + gap.index++ + for gap.index <= gap.node.nrSegments { + if gap.node.hasChildren { + if largeEnoughGap := gap.node.children[gap.index].searchFirstLargeEnoughGap(minSize); largeEnoughGap.Ok() { + return largeEnoughGap + } + } else { + if gap.Range().Length() >= minSize { + return gap + } + } + gap.index++ + } + gap.node, gap.index = gap.node.parent, gap.node.parentIndex + if gap.node != nil && gap.index == gap.node.nrSegments { + + gap.node, gap.index = gap.node.parent, gap.node.parentIndex + } + return gap.nextLargeEnoughGapHelper(minSize) +} + +// PrevLargeEnoughGap returns the iterated gap's first prev gap with larger or +// equal length than minSize. If not found, return a terminal gap iterator +// (does NOT include this gap itself). +// +// Precondition: trackGaps must be 1. +func (gap LockGapIterator) PrevLargeEnoughGap(minSize uint64) LockGapIterator { + if LocktrackGaps != 1 { + panic("set is not tracking gaps") + } + if gap.node != nil && gap.node.hasChildren && gap.index == 0 { + + gap.node = gap.PrevSegment().node + gap.index = gap.node.nrSegments + return gap.prevLargeEnoughGapHelper(minSize) + } + return gap.prevLargeEnoughGapHelper(minSize) +} + +// prevLargeEnoughGapHelper is the helper function used by PrevLargeEnoughGap +// to do the real recursions. +// +// Preconditions: gap is NOT the first gap of a non-leaf node. +func (gap LockGapIterator) prevLargeEnoughGapHelper(minSize uint64) LockGapIterator { + + for gap.node != nil && + (gap.node.maxGap.Get() < minSize || (!gap.node.hasChildren && gap.index == 0)) { + gap.node, gap.index = gap.node.parent, gap.node.parentIndex + } + + if gap.node == nil { + return LockGapIterator{} + } + + gap.index-- + for gap.index >= 0 { + if gap.node.hasChildren { + if largeEnoughGap := gap.node.children[gap.index].searchLastLargeEnoughGap(minSize); largeEnoughGap.Ok() { + return largeEnoughGap + } + } else { + if gap.Range().Length() >= minSize { + return gap + } + } + gap.index-- + } + gap.node, gap.index = gap.node.parent, gap.node.parentIndex + if gap.node != nil && gap.index == 0 { + + gap.node, gap.index = gap.node.parent, gap.node.parentIndex + } + return gap.prevLargeEnoughGapHelper(minSize) +} + +// segmentBeforePosition returns the predecessor segment of the position given +// by n.children[i], which may or may not contain a child. If no such segment +// exists, segmentBeforePosition returns a terminal iterator. +func LocksegmentBeforePosition(n *Locknode, i int) LockIterator { + for i == 0 { + if n.parent == nil { + return LockIterator{} + } + n, i = n.parent, n.parentIndex + } + return LockIterator{n, i - 1} +} + +// segmentAfterPosition returns the successor segment of the position given by +// n.children[i], which may or may not contain a child. If no such segment +// exists, segmentAfterPosition returns a terminal iterator. +func LocksegmentAfterPosition(n *Locknode, i int) LockIterator { + for i == n.nrSegments { + if n.parent == nil { + return LockIterator{} + } + n, i = n.parent, n.parentIndex + } + return LockIterator{n, i} +} + +func LockzeroValueSlice(slice []Lock) { + + for i := range slice { + lockSetFunctions{}.ClearValue(&slice[i]) + } +} + +func LockzeroNodeSlice(slice []*Locknode) { + for i := range slice { + slice[i] = nil + } +} + +// String stringifies a Set for debugging. +func (s *LockSet) String() string { + return s.root.String() +} + +// String stringifies a node (and all of its children) for debugging. +func (n *Locknode) String() string { + var buf bytes.Buffer + n.writeDebugString(&buf, "") + return buf.String() +} + +func (n *Locknode) writeDebugString(buf *bytes.Buffer, prefix string) { + if n.hasChildren != (n.nrSegments > 0 && n.children[0] != nil) { + buf.WriteString(prefix) + buf.WriteString(fmt.Sprintf("WARNING: inconsistent value of hasChildren: got %v, want %v\n", n.hasChildren, !n.hasChildren)) + } + for i := 0; i < n.nrSegments; i++ { + if child := n.children[i]; child != nil { + cprefix := fmt.Sprintf("%s- % 3d ", prefix, i) + if child.parent != n || child.parentIndex != i { + buf.WriteString(cprefix) + buf.WriteString(fmt.Sprintf("WARNING: inconsistent linkage to parent: got (%p, %d), want (%p, %d)\n", child.parent, child.parentIndex, n, i)) + } + child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, i)) + } + buf.WriteString(prefix) + if n.hasChildren { + if LocktrackGaps != 0 { + buf.WriteString(fmt.Sprintf("- % 3d: %v => %v, maxGap: %d\n", i, n.keys[i], n.values[i], n.maxGap.Get())) + } else { + buf.WriteString(fmt.Sprintf("- % 3d: %v => %v\n", i, n.keys[i], n.values[i])) + } + } else { + buf.WriteString(fmt.Sprintf("- % 3d: %v => %v\n", i, n.keys[i], n.values[i])) + } + } + if child := n.children[n.nrSegments]; child != nil { + child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, n.nrSegments)) + } +} + +// SegmentDataSlices represents segments from a set as slices of start, end, and +// values. SegmentDataSlices is primarily used as an intermediate representation +// for save/restore and the layout here is optimized for that. +// +// +stateify savable +type LockSegmentDataSlices struct { + Start []uint64 + End []uint64 + Values []Lock +} + +// ExportSortedSlices returns a copy of all segments in the given set, in +// ascending key order. +func (s *LockSet) ExportSortedSlices() *LockSegmentDataSlices { + var sds LockSegmentDataSlices + for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + sds.Start = append(sds.Start, seg.Start()) + sds.End = append(sds.End, seg.End()) + sds.Values = append(sds.Values, seg.Value()) + } + sds.Start = sds.Start[:len(sds.Start):len(sds.Start)] + sds.End = sds.End[:len(sds.End):len(sds.End)] + sds.Values = sds.Values[:len(sds.Values):len(sds.Values)] + return &sds +} + +// ImportSortedSlices initializes the given set from the given slice. +// +// Preconditions: +// * s must be empty. +// * sds must represent a valid set (the segments in sds must have valid +// lengths that do not overlap). +// * The segments in sds must be sorted in ascending key order. +func (s *LockSet) ImportSortedSlices(sds *LockSegmentDataSlices) error { + if !s.IsEmpty() { + return fmt.Errorf("cannot import into non-empty set %v", s) + } + gap := s.FirstGap() + for i := range sds.Start { + r := LockRange{sds.Start[i], sds.End[i]} + if !gap.Range().IsSupersetOf(r) { + return fmt.Errorf("segment overlaps a preceding segment or is incorrectly sorted: [%d, %d) => %v", sds.Start[i], sds.End[i], sds.Values[i]) + } + gap = s.InsertWithoutMerging(gap, r, sds.Values[i]).NextGap() + } + return nil +} + +// segmentTestCheck returns an error if s is incorrectly sorted, does not +// contain exactly expectedSegments segments, or contains a segment which +// fails the passed check. +// +// This should be used only for testing, and has been added to this package for +// templating convenience. +func (s *LockSet) segmentTestCheck(expectedSegments int, segFunc func(int, LockRange, Lock) error) error { + havePrev := false + prev := uint64(0) + nrSegments := 0 + for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + next := seg.Start() + if havePrev && prev >= next { + return fmt.Errorf("incorrect order: key %d (segment %d) >= key %d (segment %d)", prev, nrSegments-1, next, nrSegments) + } + if segFunc != nil { + if err := segFunc(nrSegments, seg.Range(), seg.Value()); err != nil { + return err + } + } + prev = next + havePrev = true + nrSegments++ + } + if nrSegments != expectedSegments { + return fmt.Errorf("incorrect number of segments: got %d, wanted %d", nrSegments, expectedSegments) + } + return nil +} + +// countSegments counts the number of segments in the set. +// +// Similar to Check, this should only be used for testing. +func (s *LockSet) countSegments() (segments int) { + for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + segments++ + } + return segments +} +func (s *LockSet) saveRoot() *LockSegmentDataSlices { + return s.ExportSortedSlices() +} + +func (s *LockSet) loadRoot(sds *LockSegmentDataSlices) { + if err := s.ImportSortedSlices(sds); err != nil { + panic(err) + } +} diff --git a/pkg/sentry/fs/lock/lock_state_autogen.go b/pkg/sentry/fs/lock/lock_state_autogen.go new file mode 100644 index 000000000..4a9d0a1d8 --- /dev/null +++ b/pkg/sentry/fs/lock/lock_state_autogen.go @@ -0,0 +1,233 @@ +// automatically generated by stateify. + +package lock + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (o *OwnerInfo) StateTypeName() string { + return "pkg/sentry/fs/lock.OwnerInfo" +} + +func (o *OwnerInfo) StateFields() []string { + return []string{ + "PID", + } +} + +func (o *OwnerInfo) beforeSave() {} + +// +checklocksignore +func (o *OwnerInfo) StateSave(stateSinkObject state.Sink) { + o.beforeSave() + stateSinkObject.Save(0, &o.PID) +} + +func (o *OwnerInfo) afterLoad() {} + +// +checklocksignore +func (o *OwnerInfo) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &o.PID) +} + +func (l *Lock) StateTypeName() string { + return "pkg/sentry/fs/lock.Lock" +} + +func (l *Lock) StateFields() []string { + return []string{ + "Readers", + "Writer", + "WriterInfo", + } +} + +func (l *Lock) beforeSave() {} + +// +checklocksignore +func (l *Lock) StateSave(stateSinkObject state.Sink) { + l.beforeSave() + stateSinkObject.Save(0, &l.Readers) + stateSinkObject.Save(1, &l.Writer) + stateSinkObject.Save(2, &l.WriterInfo) +} + +func (l *Lock) afterLoad() {} + +// +checklocksignore +func (l *Lock) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &l.Readers) + stateSourceObject.Load(1, &l.Writer) + stateSourceObject.Load(2, &l.WriterInfo) +} + +func (l *Locks) StateTypeName() string { + return "pkg/sentry/fs/lock.Locks" +} + +func (l *Locks) StateFields() []string { + return []string{ + "locks", + } +} + +func (l *Locks) beforeSave() {} + +// +checklocksignore +func (l *Locks) StateSave(stateSinkObject state.Sink) { + l.beforeSave() + if !state.IsZeroValue(&l.blockedQueue) { + state.Failf("blockedQueue is %#v, expected zero", &l.blockedQueue) + } + stateSinkObject.Save(0, &l.locks) +} + +func (l *Locks) afterLoad() {} + +// +checklocksignore +func (l *Locks) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &l.locks) +} + +func (r *LockRange) StateTypeName() string { + return "pkg/sentry/fs/lock.LockRange" +} + +func (r *LockRange) StateFields() []string { + return []string{ + "Start", + "End", + } +} + +func (r *LockRange) beforeSave() {} + +// +checklocksignore +func (r *LockRange) StateSave(stateSinkObject state.Sink) { + r.beforeSave() + stateSinkObject.Save(0, &r.Start) + stateSinkObject.Save(1, &r.End) +} + +func (r *LockRange) afterLoad() {} + +// +checklocksignore +func (r *LockRange) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &r.Start) + stateSourceObject.Load(1, &r.End) +} + +func (s *LockSet) StateTypeName() string { + return "pkg/sentry/fs/lock.LockSet" +} + +func (s *LockSet) StateFields() []string { + return []string{ + "root", + } +} + +func (s *LockSet) beforeSave() {} + +// +checklocksignore +func (s *LockSet) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + var rootValue *LockSegmentDataSlices + rootValue = s.saveRoot() + stateSinkObject.SaveValue(0, rootValue) +} + +func (s *LockSet) afterLoad() {} + +// +checklocksignore +func (s *LockSet) StateLoad(stateSourceObject state.Source) { + stateSourceObject.LoadValue(0, new(*LockSegmentDataSlices), func(y interface{}) { s.loadRoot(y.(*LockSegmentDataSlices)) }) +} + +func (n *Locknode) StateTypeName() string { + return "pkg/sentry/fs/lock.Locknode" +} + +func (n *Locknode) StateFields() []string { + return []string{ + "nrSegments", + "parent", + "parentIndex", + "hasChildren", + "maxGap", + "keys", + "values", + "children", + } +} + +func (n *Locknode) beforeSave() {} + +// +checklocksignore +func (n *Locknode) StateSave(stateSinkObject state.Sink) { + n.beforeSave() + stateSinkObject.Save(0, &n.nrSegments) + stateSinkObject.Save(1, &n.parent) + stateSinkObject.Save(2, &n.parentIndex) + stateSinkObject.Save(3, &n.hasChildren) + stateSinkObject.Save(4, &n.maxGap) + stateSinkObject.Save(5, &n.keys) + stateSinkObject.Save(6, &n.values) + stateSinkObject.Save(7, &n.children) +} + +func (n *Locknode) afterLoad() {} + +// +checklocksignore +func (n *Locknode) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &n.nrSegments) + stateSourceObject.Load(1, &n.parent) + stateSourceObject.Load(2, &n.parentIndex) + stateSourceObject.Load(3, &n.hasChildren) + stateSourceObject.Load(4, &n.maxGap) + stateSourceObject.Load(5, &n.keys) + stateSourceObject.Load(6, &n.values) + stateSourceObject.Load(7, &n.children) +} + +func (l *LockSegmentDataSlices) StateTypeName() string { + return "pkg/sentry/fs/lock.LockSegmentDataSlices" +} + +func (l *LockSegmentDataSlices) StateFields() []string { + return []string{ + "Start", + "End", + "Values", + } +} + +func (l *LockSegmentDataSlices) beforeSave() {} + +// +checklocksignore +func (l *LockSegmentDataSlices) StateSave(stateSinkObject state.Sink) { + l.beforeSave() + stateSinkObject.Save(0, &l.Start) + stateSinkObject.Save(1, &l.End) + stateSinkObject.Save(2, &l.Values) +} + +func (l *LockSegmentDataSlices) afterLoad() {} + +// +checklocksignore +func (l *LockSegmentDataSlices) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &l.Start) + stateSourceObject.Load(1, &l.End) + stateSourceObject.Load(2, &l.Values) +} + +func init() { + state.Register((*OwnerInfo)(nil)) + state.Register((*Lock)(nil)) + state.Register((*Locks)(nil)) + state.Register((*LockRange)(nil)) + state.Register((*LockSet)(nil)) + state.Register((*Locknode)(nil)) + state.Register((*LockSegmentDataSlices)(nil)) +} diff --git a/pkg/sentry/fs/lock/lock_test.go b/pkg/sentry/fs/lock/lock_test.go deleted file mode 100644 index 9878c04e1..000000000 --- a/pkg/sentry/fs/lock/lock_test.go +++ /dev/null @@ -1,1060 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package lock - -import ( - "reflect" - "testing" -) - -type entry struct { - Lock - LockRange -} - -func equals(e0, e1 []entry) bool { - if len(e0) != len(e1) { - return false - } - for i := range e0 { - for k := range e0[i].Lock.Readers { - if _, ok := e1[i].Lock.Readers[k]; !ok { - return false - } - } - for k := range e1[i].Lock.Readers { - if _, ok := e0[i].Lock.Readers[k]; !ok { - return false - } - } - if !reflect.DeepEqual(e0[i].LockRange, e1[i].LockRange) { - return false - } - if e0[i].Lock.Writer != e1[i].Lock.Writer { - return false - } - } - return true -} - -// fill a LockSet with consecutive region locks. Will panic if -// LockRanges are not consecutive. -func fill(entries []entry) LockSet { - l := LockSet{} - for _, e := range entries { - gap := l.FindGap(e.LockRange.Start) - if !gap.Ok() { - panic("cannot insert into existing segment") - } - l.Insert(gap, e.LockRange, e.Lock) - } - return l -} - -func TestCanLockEmpty(t *testing.T) { - l := LockSet{} - - // Expect to be able to take any locks given that the set is empty. - eof := l.FirstGap().End() - r := LockRange{0, eof} - if !l.canLock(1, ReadLock, r) { - t.Fatalf("canLock type %d for range %v and uid %d got false, want true", ReadLock, r, 1) - } - if !l.canLock(2, ReadLock, r) { - t.Fatalf("canLock type %d for range %v and uid %d got false, want true", ReadLock, r, 2) - } - if !l.canLock(1, WriteLock, r) { - t.Fatalf("canLock type %d for range %v and uid %d got false, want true", WriteLock, r, 1) - } - if !l.canLock(2, WriteLock, r) { - t.Fatalf("canLock type %d for range %v and uid %d got false, want true", WriteLock, r, 2) - } -} - -func TestCanLock(t *testing.T) { - // + -------------- + ---------- + -------------- + --------- + - // | Readers 1 & 2 | Readers 1 | Readers 1 & 3 | Writer 1 | - // + ------------- + ---------- + -------------- + --------- + - // 0 1024 2048 3072 4096 - l := fill([]entry{ - { - Lock: Lock{Readers: map[UniqueID]OwnerInfo{1: OwnerInfo{}, 2: OwnerInfo{}}}, - LockRange: LockRange{0, 1024}, - }, - { - Lock: Lock{Readers: map[UniqueID]OwnerInfo{1: OwnerInfo{}}}, - LockRange: LockRange{1024, 2048}, - }, - { - Lock: Lock{Readers: map[UniqueID]OwnerInfo{1: OwnerInfo{}, 3: OwnerInfo{}}}, - LockRange: LockRange{2048, 3072}, - }, - { - Lock: Lock{Writer: 1}, - LockRange: LockRange{3072, 4096}, - }, - }) - - // Now that we have a mildly interesting layout, try some checks on different - // ranges, uids, and lock types. - // - // Expect to be able to extend the read lock, despite the writer lock, because - // the writer has the same uid as the requested read lock. - r := LockRange{0, 8192} - if !l.canLock(1, ReadLock, r) { - t.Fatalf("canLock type %d for range %v and uid %d got false, want true", ReadLock, r, 1) - } - // Expect to *not* be able to extend the read lock since there is an overlapping - // writer region locked by someone other than the uid. - if l.canLock(2, ReadLock, r) { - t.Fatalf("canLock type %d for range %v and uid %d got true, want false", ReadLock, r, 2) - } - // Expect to be able to extend the read lock if there are only other readers in - // the way. - r = LockRange{64, 3072} - if !l.canLock(2, ReadLock, r) { - t.Fatalf("canLock type %d for range %v and uid %d got false, want true", ReadLock, r, 2) - } - // Expect to be able to set a read lock beyond the range of any existing locks. - r = LockRange{4096, 10240} - if !l.canLock(2, ReadLock, r) { - t.Fatalf("canLock type %d for range %v and uid %d got false, want true", ReadLock, r, 2) - } - - // Expect to not be able to take a write lock with other readers in the way. - r = LockRange{0, 8192} - if l.canLock(1, WriteLock, r) { - t.Fatalf("canLock type %d for range %v and uid %d got true, want false", WriteLock, r, 1) - } - // Expect to be able to extend the write lock for the same uid. - r = LockRange{3072, 8192} - if !l.canLock(1, WriteLock, r) { - t.Fatalf("canLock type %d for range %v and uid %d got false, want true", WriteLock, r, 1) - } - // Expect to not be able to overlap a write lock for two different uids. - if l.canLock(2, WriteLock, r) { - t.Fatalf("canLock type %d for range %v and uid %d got true, want false", WriteLock, r, 2) - } - // Expect to be able to set a write lock that is beyond the range of any - // existing locks. - r = LockRange{8192, 10240} - if !l.canLock(2, WriteLock, r) { - t.Fatalf("canLock type %d for range %v and uid %d got false, want true", WriteLock, r, 2) - } - // Expect to be able to upgrade a read lock (any portion of it). - r = LockRange{1024, 2048} - if !l.canLock(1, WriteLock, r) { - t.Fatalf("canLock type %d for range %v and uid %d got false, want true", WriteLock, r, 1) - } - r = LockRange{1080, 2000} - if !l.canLock(1, WriteLock, r) { - t.Fatalf("canLock type %d for range %v and uid %d got false, want true", WriteLock, r, 1) - } -} - -func TestSetLock(t *testing.T) { - tests := []struct { - // description of test. - name string - - // LockSet entries to pre-fill. - before []entry - - // Description of region to lock: - // - // start is the file offset of the lock. - start uint64 - // end is the end file offset of the lock. - end uint64 - // uid of lock attempter. - uid UniqueID - // lock type requested. - lockType LockType - - // success is true if taking the above - // lock should succeed. - success bool - - // Expected layout of the set after locking - // if success is true. - after []entry - }{ - { - name: "set zero length ReadLock on empty set", - start: 0, - end: 0, - uid: 0, - lockType: ReadLock, - success: true, - }, - { - name: "set zero length WriteLock on empty set", - start: 0, - end: 0, - uid: 0, - lockType: WriteLock, - success: true, - }, - { - name: "set ReadLock on empty set", - start: 0, - end: LockEOF, - uid: 0, - lockType: ReadLock, - success: true, - // + ----------------------------------------- + - // | Readers 0 | - // + ----------------------------------------- + - // 0 max uint64 - after: []entry{ - { - Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}}, - LockRange: LockRange{0, LockEOF}, - }, - }, - }, - { - name: "set WriteLock on empty set", - start: 0, - end: LockEOF, - uid: 0, - lockType: WriteLock, - success: true, - // + ----------------------------------------- + - // | Writer 0 | - // + ----------------------------------------- + - // 0 max uint64 - after: []entry{ - { - Lock: Lock{Writer: 0}, - LockRange: LockRange{0, LockEOF}, - }, - }, - }, - { - name: "set ReadLock on WriteLock same uid", - // + ----------------------------------------- + - // | Writer 0 | - // + ----------------------------------------- + - // 0 max uint64 - before: []entry{ - { - Lock: Lock{Writer: 0}, - LockRange: LockRange{0, LockEOF}, - }, - }, - start: 0, - end: 4096, - uid: 0, - lockType: ReadLock, - success: true, - // + ----------- + --------------------------- + - // | Readers 0 | Writer 0 | - // + ----------- + --------------------------- + - // 0 4096 max uint64 - after: []entry{ - { - Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}}, - LockRange: LockRange{0, 4096}, - }, - { - Lock: Lock{Writer: 0}, - LockRange: LockRange{4096, LockEOF}, - }, - }, - }, - { - name: "set WriteLock on ReadLock same uid", - // + ----------------------------------------- + - // | Readers 0 | - // + ----------------------------------------- + - // 0 max uint64 - before: []entry{ - { - Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}}, - LockRange: LockRange{0, LockEOF}, - }, - }, - start: 0, - end: 4096, - uid: 0, - lockType: WriteLock, - success: true, - // + ----------- + --------------------------- + - // | Writer 0 | Readers 0 | - // + ----------- + --------------------------- + - // 0 4096 max uint64 - after: []entry{ - { - Lock: Lock{Writer: 0}, - LockRange: LockRange{0, 4096}, - }, - { - Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}}, - LockRange: LockRange{4096, LockEOF}, - }, - }, - }, - { - name: "set ReadLock on WriteLock different uid", - // + ----------------------------------------- + - // | Writer 0 | - // + ----------------------------------------- + - // 0 max uint64 - before: []entry{ - { - Lock: Lock{Writer: 0}, - LockRange: LockRange{0, LockEOF}, - }, - }, - start: 0, - end: 4096, - uid: 1, - lockType: ReadLock, - success: false, - }, - { - name: "set WriteLock on ReadLock different uid", - // + ----------------------------------------- + - // | Readers 0 | - // + ----------------------------------------- + - // 0 max uint64 - before: []entry{ - { - Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}}, - LockRange: LockRange{0, LockEOF}, - }, - }, - start: 0, - end: 4096, - uid: 1, - lockType: WriteLock, - success: false, - }, - { - name: "split ReadLock for overlapping lock at start 0", - // + ----------------------------------------- + - // | Readers 0 | - // + ----------------------------------------- + - // 0 max uint64 - before: []entry{ - { - Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}}, - LockRange: LockRange{0, LockEOF}, - }, - }, - start: 0, - end: 4096, - uid: 1, - lockType: ReadLock, - success: true, - // + -------------- + --------------------------- + - // | Readers 0 & 1 | Readers 0 | - // + -------------- + --------------------------- + - // 0 4096 max uint64 - after: []entry{ - { - Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}, 1: OwnerInfo{}}}, - LockRange: LockRange{0, 4096}, - }, - { - Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}}, - LockRange: LockRange{4096, LockEOF}, - }, - }, - }, - { - name: "split ReadLock for overlapping lock at non-zero start", - // + ----------------------------------------- + - // | Readers 0 | - // + ----------------------------------------- + - // 0 max uint64 - before: []entry{ - { - Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}}, - LockRange: LockRange{0, LockEOF}, - }, - }, - start: 4096, - end: 8192, - uid: 1, - lockType: ReadLock, - success: true, - // + ---------- + -------------- + ----------- + - // | Readers 0 | Readers 0 & 1 | Readers 0 | - // + ---------- + -------------- + ----------- + - // 0 4096 8192 max uint64 - after: []entry{ - { - Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}}, - LockRange: LockRange{0, 4096}, - }, - { - Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}, 1: OwnerInfo{}}}, - LockRange: LockRange{4096, 8192}, - }, - { - Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}}, - LockRange: LockRange{8192, LockEOF}, - }, - }, - }, - { - name: "fill front gap with ReadLock", - // + --------- + ---------------------------- + - // | gap | Readers 0 | - // + --------- + ---------------------------- + - // 0 1024 max uint64 - before: []entry{ - { - Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}}, - LockRange: LockRange{1024, LockEOF}, - }, - }, - start: 0, - end: 8192, - uid: 0, - lockType: ReadLock, - success: true, - // + ----------------------------------------- + - // | Readers 0 | - // + ----------------------------------------- + - // 0 max uint64 - after: []entry{ - { - Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}}, - LockRange: LockRange{0, LockEOF}, - }, - }, - }, - { - name: "fill end gap with ReadLock", - // + ---------------------------- + - // | Readers 0 | - // + ---------------------------- + - // 0 4096 - before: []entry{ - { - Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}}, - LockRange: LockRange{0, 4096}, - }, - }, - start: 1024, - end: LockEOF, - uid: 0, - lockType: ReadLock, - success: true, - // Note that this is not merged after lock does a Split. This is - // fine because the two locks will still *behave* as one. In other - // words we can fragment any lock all we want and semantically it - // makes no difference. - // - // + ----------- + --------------------------- + - // | Readers 0 | Readers 0 | - // + ----------- + --------------------------- + - // 0 max uint64 - after: []entry{ - { - Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}}, - LockRange: LockRange{0, 1024}, - }, - { - Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}}, - LockRange: LockRange{1024, LockEOF}, - }, - }, - }, - { - name: "fill gap with ReadLock and split", - // + --------- + ---------------------------- + - // | gap | Readers 0 | - // + --------- + ---------------------------- + - // 0 1024 max uint64 - before: []entry{ - { - Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}}, - LockRange: LockRange{1024, LockEOF}, - }, - }, - start: 0, - end: 4096, - uid: 1, - lockType: ReadLock, - success: true, - // + --------- + ------------- + ------------- + - // | Reader 1 | Readers 0 & 1 | Reader 0 | - // + ----------+ ------------- + ------------- + - // 0 1024 4096 max uint64 - after: []entry{ - { - Lock: Lock{Readers: map[UniqueID]OwnerInfo{1: OwnerInfo{}}}, - LockRange: LockRange{0, 1024}, - }, - { - Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}, 1: OwnerInfo{}}}, - LockRange: LockRange{1024, 4096}, - }, - { - Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}}, - LockRange: LockRange{4096, LockEOF}, - }, - }, - }, - { - name: "upgrade ReadLock to WriteLock for single uid fill gap", - // + ------------- + --------- + --- + ------------- + - // | Readers 0 & 1 | Readers 0 | gap | Readers 0 & 2 | - // + ------------- + --------- + --- + ------------- + - // 0 1024 2048 4096 max uint64 - before: []entry{ - { - Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}, 1: OwnerInfo{}}}, - LockRange: LockRange{0, 1024}, - }, - { - Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}}, - LockRange: LockRange{1024, 2048}, - }, - { - Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}, 2: OwnerInfo{}}}, - LockRange: LockRange{4096, LockEOF}, - }, - }, - start: 1024, - end: 4096, - uid: 0, - lockType: WriteLock, - success: true, - // + ------------- + -------- + ------------- + - // | Readers 0 & 1 | Writer 0 | Readers 0 & 2 | - // + ------------- + -------- + ------------- + - // 0 1024 4096 max uint64 - after: []entry{ - { - Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}, 1: OwnerInfo{}}}, - LockRange: LockRange{0, 1024}, - }, - { - Lock: Lock{Writer: 0}, - LockRange: LockRange{1024, 4096}, - }, - { - Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}, 2: OwnerInfo{}}}, - LockRange: LockRange{4096, LockEOF}, - }, - }, - }, - { - name: "upgrade ReadLock to WriteLock for single uid keep gap", - // + ------------- + --------- + --- + ------------- + - // | Readers 0 & 1 | Readers 0 | gap | Readers 0 & 2 | - // + ------------- + --------- + --- + ------------- + - // 0 1024 2048 4096 max uint64 - before: []entry{ - { - Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}, 1: OwnerInfo{}}}, - LockRange: LockRange{0, 1024}, - }, - { - Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}}, - LockRange: LockRange{1024, 2048}, - }, - { - Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}, 2: OwnerInfo{}}}, - LockRange: LockRange{4096, LockEOF}, - }, - }, - start: 1024, - end: 3072, - uid: 0, - lockType: WriteLock, - success: true, - // + ------------- + -------- + --- + ------------- + - // | Readers 0 & 1 | Writer 0 | gap | Readers 0 & 2 | - // + ------------- + -------- + --- + ------------- + - // 0 1024 3072 4096 max uint64 - after: []entry{ - { - Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}, 1: OwnerInfo{}}}, - LockRange: LockRange{0, 1024}, - }, - { - Lock: Lock{Writer: 0}, - LockRange: LockRange{1024, 3072}, - }, - { - Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}, 2: OwnerInfo{}}}, - LockRange: LockRange{4096, LockEOF}, - }, - }, - }, - { - name: "fail to upgrade ReadLock to WriteLock with conflicting Reader", - // + ------------- + --------- + - // | Readers 0 & 1 | Readers 0 | - // + ------------- + --------- + - // 0 1024 2048 - before: []entry{ - { - Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}, 1: OwnerInfo{}}}, - LockRange: LockRange{0, 1024}, - }, - { - Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}}, - LockRange: LockRange{1024, 2048}, - }, - }, - start: 0, - end: 2048, - uid: 0, - lockType: WriteLock, - success: false, - }, - { - name: "take WriteLock on whole file if all uids are the same", - // + ------------- + --------- + --------- + ---------- + - // | Writer 0 | Readers 0 | Readers 0 | Readers 0 | - // + ------------- + --------- + --------- + ---------- + - // 0 1024 2048 4096 max uint64 - before: []entry{ - { - Lock: Lock{Writer: 0}, - LockRange: LockRange{0, 1024}, - }, - { - Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}}, - LockRange: LockRange{1024, 2048}, - }, - { - Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}}, - LockRange: LockRange{2048, 4096}, - }, - { - Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}}, - LockRange: LockRange{4096, LockEOF}, - }, - }, - start: 0, - end: LockEOF, - uid: 0, - lockType: WriteLock, - success: true, - // We do not manually merge locks. Semantically a fragmented lock - // held by the same uid will behave as one lock so it makes no difference. - // - // + ------------- + ---------------------------- + - // | Writer 0 | Writer 0 | - // + ------------- + ---------------------------- + - // 0 1024 max uint64 - after: []entry{ - { - Lock: Lock{Writer: 0}, - LockRange: LockRange{0, 1024}, - }, - { - Lock: Lock{Writer: 0}, - LockRange: LockRange{1024, LockEOF}, - }, - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - l := fill(test.before) - - r := LockRange{Start: test.start, End: test.end} - success := l.lock(test.uid, 0 /* ownerPID */, test.lockType, r) - var got []entry - for seg := l.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { - got = append(got, entry{ - Lock: seg.Value(), - LockRange: seg.Range(), - }) - } - - if success != test.success { - t.Errorf("setlock(%v, %+v, %d, %d) got success %v, want %v", test.before, r, test.uid, test.lockType, success, test.success) - return - } - - if success { - if !equals(got, test.after) { - t.Errorf("got set %+v, want %+v", got, test.after) - } - } - }) - } -} - -func TestUnlock(t *testing.T) { - tests := []struct { - // description of test. - name string - - // LockSet entries to pre-fill. - before []entry - - // Description of region to unlock: - // - // start is the file start of the lock. - start uint64 - // end is the end file start of the lock. - end uint64 - // uid of lock holder. - uid UniqueID - - // Expected layout of the set after unlocking. - after []entry - }{ - { - name: "unlock zero length on empty set", - start: 0, - end: 0, - uid: 0, - }, - { - name: "unlock on empty set (no-op)", - start: 0, - end: LockEOF, - uid: 0, - }, - { - name: "unlock uid not locked (no-op)", - // + --------------------------- + - // | Readers 1 & 2 | - // + --------------------------- + - // 0 max uint64 - before: []entry{ - { - Lock: Lock{Readers: map[UniqueID]OwnerInfo{1: OwnerInfo{}, 2: OwnerInfo{}}}, - LockRange: LockRange{0, LockEOF}, - }, - }, - start: 1024, - end: 4096, - uid: 0, - // + --------------------------- + - // | Readers 1 & 2 | - // + --------------------------- + - // 0 max uint64 - after: []entry{ - { - Lock: Lock{Readers: map[UniqueID]OwnerInfo{1: OwnerInfo{}, 2: OwnerInfo{}}}, - LockRange: LockRange{0, LockEOF}, - }, - }, - }, - { - name: "unlock ReadLock over entire file", - // + ----------------------------------------- + - // | Readers 0 | - // + ----------------------------------------- + - // 0 max uint64 - before: []entry{ - { - Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}}, - LockRange: LockRange{0, LockEOF}, - }, - }, - start: 0, - end: LockEOF, - uid: 0, - }, - { - name: "unlock WriteLock over entire file", - // + ----------------------------------------- + - // | Writer 0 | - // + ----------------------------------------- + - // 0 max uint64 - before: []entry{ - { - Lock: Lock{Writer: 0}, - LockRange: LockRange{0, LockEOF}, - }, - }, - start: 0, - end: LockEOF, - uid: 0, - }, - { - name: "unlock partial ReadLock (start)", - // + ----------------------------------------- + - // | Readers 0 | - // + ----------------------------------------- + - // 0 max uint64 - before: []entry{ - { - Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}}, - LockRange: LockRange{0, LockEOF}, - }, - }, - start: 0, - end: 4096, - uid: 0, - // + ------ + --------------------------- + - // | gap | Readers 0 | - // +------- + --------------------------- + - // 0 4096 max uint64 - after: []entry{ - { - Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}}, - LockRange: LockRange{4096, LockEOF}, - }, - }, - }, - { - name: "unlock partial WriteLock (start)", - // + ----------------------------------------- + - // | Writer 0 | - // + ----------------------------------------- + - // 0 max uint64 - before: []entry{ - { - Lock: Lock{Writer: 0}, - LockRange: LockRange{0, LockEOF}, - }, - }, - start: 0, - end: 4096, - uid: 0, - // + ------ + --------------------------- + - // | gap | Writer 0 | - // +------- + --------------------------- + - // 0 4096 max uint64 - after: []entry{ - { - Lock: Lock{Writer: 0}, - LockRange: LockRange{4096, LockEOF}, - }, - }, - }, - { - name: "unlock partial ReadLock (end)", - // + ----------------------------------------- + - // | Readers 0 | - // + ----------------------------------------- + - // 0 max uint64 - before: []entry{ - { - Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}}, - LockRange: LockRange{0, LockEOF}, - }, - }, - start: 4096, - end: LockEOF, - uid: 0, - // + --------------------------- + - // | Readers 0 | - // +---------------------------- + - // 0 4096 - after: []entry{ - { - Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}}, - LockRange: LockRange{0, 4096}, - }, - }, - }, - { - name: "unlock partial WriteLock (end)", - // + ----------------------------------------- + - // | Writer 0 | - // + ----------------------------------------- + - // 0 max uint64 - before: []entry{ - { - Lock: Lock{Writer: 0}, - LockRange: LockRange{0, LockEOF}, - }, - }, - start: 4096, - end: LockEOF, - uid: 0, - // + --------------------------- + - // | Writer 0 | - // +---------------------------- + - // 0 4096 - after: []entry{ - { - Lock: Lock{Writer: 0}, - LockRange: LockRange{0, 4096}, - }, - }, - }, - { - name: "unlock for single uid", - // + ------------- + --------- + ------------------- + - // | Readers 0 & 1 | Writer 0 | Readers 0 & 1 & 2 | - // + ------------- + --------- + ------------------- + - // 0 1024 4096 max uint64 - before: []entry{ - { - Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}, 1: OwnerInfo{}}}, - LockRange: LockRange{0, 1024}, - }, - { - Lock: Lock{Writer: 0}, - LockRange: LockRange{1024, 4096}, - }, - { - Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}, 1: OwnerInfo{}, 2: OwnerInfo{}}}, - LockRange: LockRange{4096, LockEOF}, - }, - }, - start: 0, - end: LockEOF, - uid: 0, - // + --------- + --- + --------------- + - // | Readers 1 | gap | Readers 1 & 2 | - // + --------- + --- + --------------- + - // 0 1024 4096 max uint64 - after: []entry{ - { - Lock: Lock{Readers: map[UniqueID]OwnerInfo{1: OwnerInfo{}}}, - LockRange: LockRange{0, 1024}, - }, - { - Lock: Lock{Readers: map[UniqueID]OwnerInfo{1: OwnerInfo{}, 2: OwnerInfo{}}}, - LockRange: LockRange{4096, LockEOF}, - }, - }, - }, - { - name: "unlock subsection locked", - // + ------------------------------- + - // | Readers 0 & 1 & 2 | - // + ------------------------------- + - // 0 max uint64 - before: []entry{ - { - Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}, 1: OwnerInfo{}, 2: OwnerInfo{}}}, - LockRange: LockRange{0, LockEOF}, - }, - }, - start: 1024, - end: 4096, - uid: 0, - // + ----------------- + ------------- + ----------------- + - // | Readers 0 & 1 & 2 | Readers 1 & 2 | Readers 0 & 1 & 2 | - // + ----------------- + ------------- + ----------------- + - // 0 1024 4096 max uint64 - after: []entry{ - { - Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}, 1: OwnerInfo{}, 2: OwnerInfo{}}}, - LockRange: LockRange{0, 1024}, - }, - { - Lock: Lock{Readers: map[UniqueID]OwnerInfo{1: OwnerInfo{}, 2: OwnerInfo{}}}, - LockRange: LockRange{1024, 4096}, - }, - { - Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}, 1: OwnerInfo{}, 2: OwnerInfo{}}}, - LockRange: LockRange{4096, LockEOF}, - }, - }, - }, - { - name: "unlock mid-gap to increase gap", - // + --------- + ----- + ------------------- + - // | Writer 0 | gap | Readers 0 & 1 | - // + --------- + ----- + ------------------- + - // 0 1024 4096 max uint64 - before: []entry{ - { - Lock: Lock{Writer: 0}, - LockRange: LockRange{0, 1024}, - }, - { - Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}, 1: OwnerInfo{}}}, - LockRange: LockRange{4096, LockEOF}, - }, - }, - start: 8, - end: 2048, - uid: 0, - // + --------- + ----- + ------------------- + - // | Writer 0 | gap | Readers 0 & 1 | - // + --------- + ----- + ------------------- + - // 0 8 4096 max uint64 - after: []entry{ - { - Lock: Lock{Writer: 0}, - LockRange: LockRange{0, 8}, - }, - { - Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}, 1: OwnerInfo{}}}, - LockRange: LockRange{4096, LockEOF}, - }, - }, - }, - { - name: "unlock split region on uid mid-gap", - // + --------- + ----- + ------------------- + - // | Writer 0 | gap | Readers 0 & 1 | - // + --------- + ----- + ------------------- + - // 0 1024 4096 max uint64 - before: []entry{ - { - Lock: Lock{Writer: 0}, - LockRange: LockRange{0, 1024}, - }, - { - Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}, 1: OwnerInfo{}}}, - LockRange: LockRange{4096, LockEOF}, - }, - }, - start: 2048, - end: 8192, - uid: 0, - // + --------- + ----- + --------- + ------------- + - // | Writer 0 | gap | Readers 1 | Readers 0 & 1 | - // + --------- + ----- + --------- + ------------- + - // 0 1024 4096 8192 max uint64 - after: []entry{ - { - Lock: Lock{Writer: 0}, - LockRange: LockRange{0, 1024}, - }, - { - Lock: Lock{Readers: map[UniqueID]OwnerInfo{1: OwnerInfo{}}}, - LockRange: LockRange{4096, 8192}, - }, - { - Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}, 1: OwnerInfo{}}}, - LockRange: LockRange{8192, LockEOF}, - }, - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - l := fill(test.before) - - r := LockRange{Start: test.start, End: test.end} - l.unlock(test.uid, r) - var got []entry - for seg := l.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { - got = append(got, entry{ - Lock: seg.Value(), - LockRange: seg.Range(), - }) - } - if !equals(got, test.after) { - t.Errorf("got set %+v, want %+v", got, test.after) - } - }) - } -} diff --git a/pkg/sentry/fs/mount_test.go b/pkg/sentry/fs/mount_test.go deleted file mode 100644 index 6c296f5d0..000000000 --- a/pkg/sentry/fs/mount_test.go +++ /dev/null @@ -1,273 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package fs - -import ( - "fmt" - "testing" - - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/sentry/contexttest" -) - -// cacheReallyContains iterates through the dirent cache to determine whether -// it contains the given dirent. -func cacheReallyContains(cache *DirentCache, d *Dirent) bool { - for i := cache.list.Front(); i != nil; i = i.Next() { - if i == d { - return true - } - } - return false -} - -func mountPathsAre(ctx context.Context, root *Dirent, got []*Mount, want ...string) error { - gotPaths := make(map[string]struct{}, len(got)) - gotStr := make([]string, len(got)) - for i, g := range got { - if groot := g.Root(); groot != nil { - name, _ := groot.FullName(root) - groot.DecRef(ctx) - gotStr[i] = name - gotPaths[name] = struct{}{} - } - } - if len(got) != len(want) { - return fmt.Errorf("mount paths are different, got: %q, want: %q", gotStr, want) - } - for _, w := range want { - if _, ok := gotPaths[w]; !ok { - return fmt.Errorf("no mount with path %q found", w) - } - } - return nil -} - -// TestMountSourceOnlyCachedOnce tests that a Dirent that is mounted over only ends -// up in a single Dirent Cache. NOTE(b/63848693): Having a dirent in multiple -// caches causes major consistency issues. -func TestMountSourceOnlyCachedOnce(t *testing.T) { - ctx := contexttest.Context(t) - - rootCache := NewDirentCache(100) - rootInode := NewMockInode(ctx, NewMockMountSource(rootCache), StableAttr{ - Type: Directory, - }) - mm, err := NewMountNamespace(ctx, rootInode) - if err != nil { - t.Fatalf("NewMountNamespace failed: %v", err) - } - rootDirent := mm.Root() - defer rootDirent.DecRef(ctx) - - // Get a child of the root which we will mount over. Note that the - // MockInodeOperations causes Walk to always succeed. - child, err := rootDirent.Walk(ctx, rootDirent, "child") - if err != nil { - t.Fatalf("failed to walk to child dirent: %v", err) - } - child.maybeExtendReference() // Cache. - - // Ensure that the root cache contains the child. - if !cacheReallyContains(rootCache, child) { - t.Errorf("wanted rootCache to contain child dirent, but it did not") - } - - // Create a new cache and inode, and mount it over child. - submountCache := NewDirentCache(100) - submountInode := NewMockInode(ctx, NewMockMountSource(submountCache), StableAttr{ - Type: Directory, - }) - if err := mm.Mount(ctx, child, submountInode); err != nil { - t.Fatalf("failed to mount over child: %v", err) - } - - // Walk to the child again. - child2, err := rootDirent.Walk(ctx, rootDirent, "child") - if err != nil { - t.Fatalf("failed to walk to child dirent: %v", err) - } - - // Should have a different Dirent than before. - if child == child2 { - t.Fatalf("expected %v not equal to %v, but they are the same", child, child2) - } - - // Neither of the caches should no contain the child. - if cacheReallyContains(rootCache, child) { - t.Errorf("wanted rootCache not to contain child dirent, but it did") - } - if cacheReallyContains(submountCache, child) { - t.Errorf("wanted submountCache not to contain child dirent, but it did") - } -} - -func TestAllMountsUnder(t *testing.T) { - ctx := contexttest.Context(t) - - rootCache := NewDirentCache(100) - rootInode := NewMockInode(ctx, NewMockMountSource(rootCache), StableAttr{ - Type: Directory, - }) - mm, err := NewMountNamespace(ctx, rootInode) - if err != nil { - t.Fatalf("NewMountNamespace failed: %v", err) - } - rootDirent := mm.Root() - defer rootDirent.DecRef(ctx) - - // Add mounts at the following paths: - paths := []string{ - "/foo", - "/foo/bar", - "/foo/bar/baz", - "/foo/qux", - "/waldo", - } - - var maxTraversals uint - for _, p := range paths { - maxTraversals = 0 - d, err := mm.FindLink(ctx, rootDirent, nil, p, &maxTraversals) - if err != nil { - t.Fatalf("could not find path %q in mount manager: %v", p, err) - } - - submountInode := NewMockInode(ctx, NewMockMountSource(nil), StableAttr{ - Type: Directory, - }) - if err := mm.Mount(ctx, d, submountInode); err != nil { - t.Fatalf("could not mount at %q: %v", p, err) - } - d.DecRef(ctx) - } - - // mm root should contain all submounts (and does not include the root mount). - rootMnt := mm.FindMount(rootDirent) - submounts := mm.AllMountsUnder(rootMnt) - allPaths := append(paths, "/") - if err := mountPathsAre(ctx, rootDirent, submounts, allPaths...); err != nil { - t.Error(err) - } - - // Each mount should have a unique ID. - foundIDs := make(map[uint64]struct{}) - for _, m := range submounts { - if _, ok := foundIDs[m.ID]; ok { - t.Errorf("got multiple mounts with id %d", m.ID) - } - foundIDs[m.ID] = struct{}{} - } - - // Root mount should have no parent. - if p := rootMnt.ParentID; p != invalidMountID { - t.Errorf("root.Parent got %v wanted nil", p) - } - - // Check that "foo" mount has 3 children. - maxTraversals = 0 - d, err := mm.FindLink(ctx, rootDirent, nil, "/foo", &maxTraversals) - if err != nil { - t.Fatalf("could not find path %q in mount manager: %v", "/foo", err) - } - defer d.DecRef(ctx) - submounts = mm.AllMountsUnder(mm.FindMount(d)) - if err := mountPathsAre(ctx, rootDirent, submounts, "/foo", "/foo/bar", "/foo/qux", "/foo/bar/baz"); err != nil { - t.Error(err) - } - - // "waldo" mount should have no children. - maxTraversals = 0 - waldo, err := mm.FindLink(ctx, rootDirent, nil, "/waldo", &maxTraversals) - if err != nil { - t.Fatalf("could not find path %q in mount manager: %v", "/waldo", err) - } - defer waldo.DecRef(ctx) - submounts = mm.AllMountsUnder(mm.FindMount(waldo)) - if err := mountPathsAre(ctx, rootDirent, submounts, "/waldo"); err != nil { - t.Error(err) - } -} - -func TestUnmount(t *testing.T) { - ctx := contexttest.Context(t) - - rootCache := NewDirentCache(100) - rootInode := NewMockInode(ctx, NewMockMountSource(rootCache), StableAttr{ - Type: Directory, - }) - mm, err := NewMountNamespace(ctx, rootInode) - if err != nil { - t.Fatalf("NewMountNamespace failed: %v", err) - } - rootDirent := mm.Root() - defer rootDirent.DecRef(ctx) - - // Add mounts at the following paths: - paths := []string{ - "/foo", - "/foo/bar", - "/foo/bar/goo", - "/foo/bar/goo/abc", - "/foo/abc", - "/foo/def", - "/waldo", - "/wally", - } - - var maxTraversals uint - for _, p := range paths { - maxTraversals = 0 - d, err := mm.FindLink(ctx, rootDirent, nil, p, &maxTraversals) - if err != nil { - t.Fatalf("could not find path %q in mount manager: %v", p, err) - } - - submountInode := NewMockInode(ctx, NewMockMountSource(nil), StableAttr{ - Type: Directory, - }) - if err := mm.Mount(ctx, d, submountInode); err != nil { - t.Fatalf("could not mount at %q: %v", p, err) - } - d.DecRef(ctx) - } - - allPaths := make([]string, len(paths)+1) - allPaths[0] = "/" - copy(allPaths[1:], paths) - - rootMnt := mm.FindMount(rootDirent) - for i := len(paths) - 1; i >= 0; i-- { - maxTraversals = 0 - p := paths[i] - d, err := mm.FindLink(ctx, rootDirent, nil, p, &maxTraversals) - if err != nil { - t.Fatalf("could not find path %q in mount manager: %v", p, err) - } - - if err := mm.Unmount(ctx, d, false); err != nil { - t.Fatalf("could not unmount at %q: %v", p, err) - } - d.DecRef(ctx) - - // Remove the path that has been unmounted and the check that the remaining - // mounts are still there. - allPaths = allPaths[:len(allPaths)-1] - submounts := mm.AllMountsUnder(rootMnt) - if err := mountPathsAre(ctx, rootDirent, submounts, allPaths...); err != nil { - t.Error(err) - } - } -} diff --git a/pkg/sentry/fs/mounts_test.go b/pkg/sentry/fs/mounts_test.go deleted file mode 100644 index 975d6cbc9..000000000 --- a/pkg/sentry/fs/mounts_test.go +++ /dev/null @@ -1,105 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package fs_test - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/sentry/fs" - "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" - "gvisor.dev/gvisor/pkg/sentry/fs/ramfs" - "gvisor.dev/gvisor/pkg/sentry/kernel/contexttest" -) - -// Creates a new MountNamespace with filesystem: -// / (root dir) -// |-foo (dir) -// |-bar (file) -func createMountNamespace(ctx context.Context) (*fs.MountNamespace, error) { - perms := fs.FilePermsFromMode(0777) - m := fs.NewPseudoMountSource(ctx) - - barFile := fsutil.NewSimpleFileInode(ctx, fs.RootOwner, perms, 0) - fooDir := ramfs.NewDir(ctx, map[string]*fs.Inode{ - "bar": fs.NewInode(ctx, barFile, m, fs.StableAttr{Type: fs.RegularFile}), - }, fs.RootOwner, perms) - rootDir := ramfs.NewDir(ctx, map[string]*fs.Inode{ - "foo": fs.NewInode(ctx, fooDir, m, fs.StableAttr{Type: fs.Directory}), - }, fs.RootOwner, perms) - - return fs.NewMountNamespace(ctx, fs.NewInode(ctx, rootDir, m, fs.StableAttr{Type: fs.Directory})) -} - -func TestFindLink(t *testing.T) { - ctx := contexttest.Context(t) - mm, err := createMountNamespace(ctx) - if err != nil { - t.Fatalf("createMountNamespace failed: %v", err) - } - - root := mm.Root() - defer root.DecRef(ctx) - foo, err := root.Walk(ctx, root, "foo") - if err != nil { - t.Fatalf("Error walking to foo: %v", err) - } - - // Positive cases. - for _, tc := range []struct { - findPath string - wd *fs.Dirent - wantPath string - }{ - {".", root, "/"}, - {".", foo, "/foo"}, - {"..", foo, "/"}, - {"../../..", foo, "/"}, - {"///foo", foo, "/foo"}, - {"/foo", foo, "/foo"}, - {"/foo/bar", foo, "/foo/bar"}, - {"/foo/.///./bar", foo, "/foo/bar"}, - {"/foo///bar", foo, "/foo/bar"}, - {"/foo/../foo/bar", foo, "/foo/bar"}, - {"foo/bar", root, "/foo/bar"}, - {"foo////bar", root, "/foo/bar"}, - {"bar", foo, "/foo/bar"}, - } { - wdPath, _ := tc.wd.FullName(root) - maxTraversals := uint(0) - if d, err := mm.FindLink(ctx, root, tc.wd, tc.findPath, &maxTraversals); err != nil { - t.Errorf("FindLink(%q, wd=%q) failed: %v", tc.findPath, wdPath, err) - } else if got, _ := d.FullName(root); got != tc.wantPath { - t.Errorf("FindLink(%q, wd=%q) got dirent %q, want %q", tc.findPath, wdPath, got, tc.wantPath) - } - } - - // Negative cases. - for _, tc := range []struct { - findPath string - wd *fs.Dirent - }{ - {"bar", root}, - {"/bar", root}, - {"/foo/../../bar", root}, - {"foo", foo}, - } { - wdPath, _ := tc.wd.FullName(root) - maxTraversals := uint(0) - if _, err := mm.FindLink(ctx, root, tc.wd, tc.findPath, &maxTraversals); err == nil { - t.Errorf("FindLink(%q, wd=%q) did not return error", tc.findPath, wdPath) - } - } -} diff --git a/pkg/sentry/fs/path_test.go b/pkg/sentry/fs/path_test.go deleted file mode 100644 index e6f57ebba..000000000 --- a/pkg/sentry/fs/path_test.go +++ /dev/null @@ -1,289 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package fs - -import ( - "testing" -) - -// TestSplitLast tests variants of path splitting. -func TestSplitLast(t *testing.T) { - cases := []struct { - path string - dir string - file string - }{ - {path: "/", dir: "/", file: "."}, - {path: "/.", dir: "/", file: "."}, - {path: "/./", dir: "/", file: "."}, - {path: "/./.", dir: "/.", file: "."}, - {path: "/././", dir: "/.", file: "."}, - {path: "/./..", dir: "/.", file: ".."}, - {path: "/./../", dir: "/.", file: ".."}, - {path: "/..", dir: "/", file: ".."}, - {path: "/../", dir: "/", file: ".."}, - {path: "/../.", dir: "/..", file: "."}, - {path: "/.././", dir: "/..", file: "."}, - {path: "/../..", dir: "/..", file: ".."}, - {path: "/../../", dir: "/..", file: ".."}, - - {path: "", dir: ".", file: "."}, - {path: ".", dir: ".", file: "."}, - {path: "./", dir: ".", file: "."}, - {path: "./.", dir: ".", file: "."}, - {path: "././", dir: ".", file: "."}, - {path: "./..", dir: ".", file: ".."}, - {path: "./../", dir: ".", file: ".."}, - {path: "..", dir: ".", file: ".."}, - {path: "../", dir: ".", file: ".."}, - {path: "../.", dir: "..", file: "."}, - {path: ".././", dir: "..", file: "."}, - {path: "../..", dir: "..", file: ".."}, - {path: "../../", dir: "..", file: ".."}, - - {path: "/foo", dir: "/", file: "foo"}, - {path: "/foo/", dir: "/", file: "foo"}, - {path: "/foo/.", dir: "/foo", file: "."}, - {path: "/foo/./", dir: "/foo", file: "."}, - {path: "/foo/./.", dir: "/foo/.", file: "."}, - {path: "/foo/./..", dir: "/foo/.", file: ".."}, - {path: "/foo/..", dir: "/foo", file: ".."}, - {path: "/foo/../", dir: "/foo", file: ".."}, - {path: "/foo/../.", dir: "/foo/..", file: "."}, - {path: "/foo/../..", dir: "/foo/..", file: ".."}, - - {path: "/foo/bar", dir: "/foo", file: "bar"}, - {path: "/foo/bar/", dir: "/foo", file: "bar"}, - {path: "/foo/bar/.", dir: "/foo/bar", file: "."}, - {path: "/foo/bar/./", dir: "/foo/bar", file: "."}, - {path: "/foo/bar/./.", dir: "/foo/bar/.", file: "."}, - {path: "/foo/bar/./..", dir: "/foo/bar/.", file: ".."}, - {path: "/foo/bar/..", dir: "/foo/bar", file: ".."}, - {path: "/foo/bar/../", dir: "/foo/bar", file: ".."}, - {path: "/foo/bar/../.", dir: "/foo/bar/..", file: "."}, - {path: "/foo/bar/../..", dir: "/foo/bar/..", file: ".."}, - - {path: "foo", dir: ".", file: "foo"}, - {path: "foo", dir: ".", file: "foo"}, - {path: "foo/", dir: ".", file: "foo"}, - {path: "foo/.", dir: "foo", file: "."}, - {path: "foo/./", dir: "foo", file: "."}, - {path: "foo/./.", dir: "foo/.", file: "."}, - {path: "foo/./..", dir: "foo/.", file: ".."}, - {path: "foo/..", dir: "foo", file: ".."}, - {path: "foo/../", dir: "foo", file: ".."}, - {path: "foo/../.", dir: "foo/..", file: "."}, - {path: "foo/../..", dir: "foo/..", file: ".."}, - {path: "foo/", dir: ".", file: "foo"}, - {path: "foo/.", dir: "foo", file: "."}, - - {path: "foo/bar", dir: "foo", file: "bar"}, - {path: "foo/bar/", dir: "foo", file: "bar"}, - {path: "foo/bar/.", dir: "foo/bar", file: "."}, - {path: "foo/bar/./", dir: "foo/bar", file: "."}, - {path: "foo/bar/./.", dir: "foo/bar/.", file: "."}, - {path: "foo/bar/./..", dir: "foo/bar/.", file: ".."}, - {path: "foo/bar/..", dir: "foo/bar", file: ".."}, - {path: "foo/bar/../", dir: "foo/bar", file: ".."}, - {path: "foo/bar/../.", dir: "foo/bar/..", file: "."}, - {path: "foo/bar/../..", dir: "foo/bar/..", file: ".."}, - {path: "foo/bar/", dir: "foo", file: "bar"}, - {path: "foo/bar/.", dir: "foo/bar", file: "."}, - } - - for _, c := range cases { - dir, file := SplitLast(c.path) - if dir != c.dir || file != c.file { - t.Errorf("SplitLast(%q) got (%q, %q), expected (%q, %q)", c.path, dir, file, c.dir, c.file) - } - } -} - -// TestSplitFirst tests variants of path splitting. -func TestSplitFirst(t *testing.T) { - cases := []struct { - path string - first string - remainder string - }{ - {path: "/", first: "/", remainder: ""}, - {path: "/.", first: "/", remainder: "."}, - {path: "///.", first: "/", remainder: "//."}, - {path: "/.///", first: "/", remainder: "."}, - {path: "/./.", first: "/", remainder: "./."}, - {path: "/././", first: "/", remainder: "./."}, - {path: "/./..", first: "/", remainder: "./.."}, - {path: "/./../", first: "/", remainder: "./.."}, - {path: "/..", first: "/", remainder: ".."}, - {path: "/../", first: "/", remainder: ".."}, - {path: "/../.", first: "/", remainder: "../."}, - {path: "/.././", first: "/", remainder: "../."}, - {path: "/../..", first: "/", remainder: "../.."}, - {path: "/../../", first: "/", remainder: "../.."}, - - {path: "", first: ".", remainder: ""}, - {path: ".", first: ".", remainder: ""}, - {path: "./", first: ".", remainder: ""}, - {path: ".///", first: ".", remainder: ""}, - {path: "./.", first: ".", remainder: "."}, - {path: "././", first: ".", remainder: "."}, - {path: "./..", first: ".", remainder: ".."}, - {path: "./../", first: ".", remainder: ".."}, - {path: "..", first: "..", remainder: ""}, - {path: "../", first: "..", remainder: ""}, - {path: "../.", first: "..", remainder: "."}, - {path: ".././", first: "..", remainder: "."}, - {path: "../..", first: "..", remainder: ".."}, - {path: "../../", first: "..", remainder: ".."}, - - {path: "/foo", first: "/", remainder: "foo"}, - {path: "/foo/", first: "/", remainder: "foo"}, - {path: "/foo///", first: "/", remainder: "foo"}, - {path: "/foo/.", first: "/", remainder: "foo/."}, - {path: "/foo/./", first: "/", remainder: "foo/."}, - {path: "/foo/./.", first: "/", remainder: "foo/./."}, - {path: "/foo/./..", first: "/", remainder: "foo/./.."}, - {path: "/foo/..", first: "/", remainder: "foo/.."}, - {path: "/foo/../", first: "/", remainder: "foo/.."}, - {path: "/foo/../.", first: "/", remainder: "foo/../."}, - {path: "/foo/../..", first: "/", remainder: "foo/../.."}, - - {path: "/foo/bar", first: "/", remainder: "foo/bar"}, - {path: "///foo/bar", first: "/", remainder: "//foo/bar"}, - {path: "/foo///bar", first: "/", remainder: "foo///bar"}, - {path: "/foo/bar/.", first: "/", remainder: "foo/bar/."}, - {path: "/foo/bar/./", first: "/", remainder: "foo/bar/."}, - {path: "/foo/bar/./.", first: "/", remainder: "foo/bar/./."}, - {path: "/foo/bar/./..", first: "/", remainder: "foo/bar/./.."}, - {path: "/foo/bar/..", first: "/", remainder: "foo/bar/.."}, - {path: "/foo/bar/../", first: "/", remainder: "foo/bar/.."}, - {path: "/foo/bar/../.", first: "/", remainder: "foo/bar/../."}, - {path: "/foo/bar/../..", first: "/", remainder: "foo/bar/../.."}, - - {path: "foo", first: "foo", remainder: ""}, - {path: "foo", first: "foo", remainder: ""}, - {path: "foo/", first: "foo", remainder: ""}, - {path: "foo///", first: "foo", remainder: ""}, - {path: "foo/.", first: "foo", remainder: "."}, - {path: "foo/./", first: "foo", remainder: "."}, - {path: "foo/./.", first: "foo", remainder: "./."}, - {path: "foo/./..", first: "foo", remainder: "./.."}, - {path: "foo/..", first: "foo", remainder: ".."}, - {path: "foo/../", first: "foo", remainder: ".."}, - {path: "foo/../.", first: "foo", remainder: "../."}, - {path: "foo/../..", first: "foo", remainder: "../.."}, - {path: "foo/", first: "foo", remainder: ""}, - {path: "foo/.", first: "foo", remainder: "."}, - - {path: "foo/bar", first: "foo", remainder: "bar"}, - {path: "foo///bar", first: "foo", remainder: "bar"}, - {path: "foo/bar/", first: "foo", remainder: "bar"}, - {path: "foo/bar/.", first: "foo", remainder: "bar/."}, - {path: "foo/bar/./", first: "foo", remainder: "bar/."}, - {path: "foo/bar/./.", first: "foo", remainder: "bar/./."}, - {path: "foo/bar/./..", first: "foo", remainder: "bar/./.."}, - {path: "foo/bar/..", first: "foo", remainder: "bar/.."}, - {path: "foo/bar/../", first: "foo", remainder: "bar/.."}, - {path: "foo/bar/../.", first: "foo", remainder: "bar/../."}, - {path: "foo/bar/../..", first: "foo", remainder: "bar/../.."}, - {path: "foo/bar/", first: "foo", remainder: "bar"}, - {path: "foo/bar/.", first: "foo", remainder: "bar/."}, - } - - for _, c := range cases { - first, remainder := SplitFirst(c.path) - if first != c.first || remainder != c.remainder { - t.Errorf("SplitFirst(%q) got (%q, %q), expected (%q, %q)", c.path, first, remainder, c.first, c.remainder) - } - } -} - -// TestIsSubpath tests the IsSubpath method. -func TestIsSubpath(t *testing.T) { - tcs := []struct { - // Two absolute paths. - pathA string - pathB string - - // Whether pathA is a subpath of pathB. - wantIsSubpath bool - - // Relative path from pathA to pathB. Only checked if - // wantIsSubpath is true. - wantRelpath string - }{ - { - pathA: "/foo/bar/baz", - pathB: "/foo", - wantIsSubpath: true, - wantRelpath: "bar/baz", - }, - { - pathA: "/foo", - pathB: "/foo/bar/baz", - wantIsSubpath: false, - }, - { - pathA: "/foo", - pathB: "/foo", - wantIsSubpath: false, - }, - { - pathA: "/foobar", - pathB: "/foo", - wantIsSubpath: false, - }, - { - pathA: "/foo", - pathB: "/foobar", - wantIsSubpath: false, - }, - { - pathA: "/foo", - pathB: "/foobar", - wantIsSubpath: false, - }, - { - pathA: "/", - pathB: "/foo", - wantIsSubpath: false, - }, - { - pathA: "/foo", - pathB: "/", - wantIsSubpath: true, - wantRelpath: "foo", - }, - { - pathA: "/foo/bar/../bar", - pathB: "/foo", - wantIsSubpath: true, - wantRelpath: "bar", - }, - { - pathA: "/foo/bar", - pathB: "/foo/../foo", - wantIsSubpath: true, - wantRelpath: "bar", - }, - } - - for _, tc := range tcs { - gotRelpath, gotIsSubpath := IsSubpath(tc.pathA, tc.pathB) - if gotRelpath != tc.wantRelpath || gotIsSubpath != tc.wantIsSubpath { - t.Errorf("IsSubpath(%q, %q) got %q %t, want %q %t", tc.pathA, tc.pathB, gotRelpath, gotIsSubpath, tc.wantRelpath, tc.wantIsSubpath) - } - } -} diff --git a/pkg/sentry/fs/proc/BUILD b/pkg/sentry/fs/proc/BUILD deleted file mode 100644 index bc75ae505..000000000 --- a/pkg/sentry/fs/proc/BUILD +++ /dev/null @@ -1,74 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "proc", - srcs = [ - "cgroup.go", - "cpuinfo.go", - "exec_args.go", - "fds.go", - "filesystems.go", - "fs.go", - "inode.go", - "loadavg.go", - "meminfo.go", - "mounts.go", - "net.go", - "proc.go", - "stat.go", - "sys.go", - "sys_net.go", - "sys_net_state.go", - "task.go", - "uid_gid_map.go", - "uptime.go", - "version.go", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/errors/linuxerr", - "//pkg/hostarch", - "//pkg/log", - "//pkg/sentry/fs", - "//pkg/sentry/fs/fsutil", - "//pkg/sentry/fs/proc/device", - "//pkg/sentry/fs/proc/seqfile", - "//pkg/sentry/fs/ramfs", - "//pkg/sentry/fsbridge", - "//pkg/sentry/inet", - "//pkg/sentry/kernel", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/kernel/time", - "//pkg/sentry/limits", - "//pkg/sentry/mm", - "//pkg/sentry/socket", - "//pkg/sentry/socket/unix", - "//pkg/sentry/socket/unix/transport", - "//pkg/sentry/usage", - "//pkg/sync", - "//pkg/tcpip/header", - "//pkg/tcpip/network/ipv4", - "//pkg/usermem", - "//pkg/waiter", - ], -) - -go_test( - name = "proc_test", - size = "small", - srcs = [ - "net_test.go", - "sys_net_test.go", - ], - library = ":proc", - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/sentry/inet", - "//pkg/usermem", - ], -) diff --git a/pkg/sentry/fs/proc/README.md b/pkg/sentry/fs/proc/README.md deleted file mode 100644 index 6667a0916..000000000 --- a/pkg/sentry/fs/proc/README.md +++ /dev/null @@ -1,336 +0,0 @@ -This document tracks what is implemented in procfs. Refer to -Documentation/filesystems/proc.txt in the Linux project for information about -procfs generally. - -**NOTE**: This document is not guaranteed to be up to date. If you find an -inconsistency, please file a bug. - -[TOC] - -## Kernel data - -The following files are implemented: - -<!-- mdformat off(don't wrap the table) --> - -| File /proc/ | Content | -| :------------------------ | :---------------------------------------------------- | -| [cpuinfo](#cpuinfo) | Info about the CPU | -| [filesystems](#filesystems) | Supported filesystems | -| [loadavg](#loadavg) | Load average of last 1, 5 & 15 minutes | -| [meminfo](#meminfo) | Overall memory info | -| [stat](#stat) | Overall kernel statistics | -| [sys](#sys) | Change parameters within the kernel | -| [uptime](#uptime) | Wall clock since boot, combined idle time of all cpus | -| [version](#version) | Kernel version | - -<!-- mdformat on --> - -### cpuinfo - -```bash -$ cat /proc/cpuinfo -processor : 0 -vendor_id : GenuineIntel -cpu family : 6 -model : 45 -model name : unknown -stepping : unknown -cpu MHz : 1234.588 -fpu : yes -fpu_exception : yes -cpuid level : 13 -wp : yes -flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic popcnt tsc_deadline_timer aes xsave avx xsaveopt -bogomips : 1234.59 -clflush size : 64 -cache_alignment : 64 -address sizes : 46 bits physical, 48 bits virtual -power management: - -... -``` - -Notable divergences: - -Field name | Notes -:--------------- | :--------------------------------------- -model name | Always unknown -stepping | Always unknown -fpu | Always yes -fpu_exception | Always yes -wp | Always yes -bogomips | Bogus value (matches cpu MHz) -clflush size | Always 64 -cache_alignment | Always 64 -address sizes | Always 46 bits physical, 48 bits virtual -power management | Always blank - -Otherwise fields are derived from the sentry configuration. - -### filesystems - -```bash -$ cat /proc/filesystems -nodev 9p -nodev devpts -nodev devtmpfs -nodev proc -nodev sysfs -nodev tmpfs -``` - -### loadavg - -```bash -$ cat /proc/loadavg -0.00 0.00 0.00 0/0 0 -``` - -Column | Notes -:------------------------------------ | :---------- -CPU.IO utilization in last 1 minute | Always zero -CPU.IO utilization in last 5 minutes | Always zero -CPU.IO utilization in last 10 minutes | Always zero -Num currently running processes | Always zero -Total num processes | Always zero - -TODO(b/62345059): Populate the columns with accurate statistics. - -### meminfo - -```bash -$ cat /proc/meminfo -MemTotal: 2097152 kB -MemFree: 2083540 kB -MemAvailable: 2083540 kB -Buffers: 0 kB -Cached: 4428 kB -SwapCache: 0 kB -Active: 10812 kB -Inactive: 2216 kB -Active(anon): 8600 kB -Inactive(anon): 0 kB -Active(file): 2212 kB -Inactive(file): 2216 kB -Unevictable: 0 kB -Mlocked: 0 kB -SwapTotal: 0 kB -SwapFree: 0 kB -Dirty: 0 kB -Writeback: 0 kB -AnonPages: 8600 kB -Mapped: 4428 kB -Shmem: 0 kB - -``` - -Notable divergences: - -Field name | Notes -:---------------- | :----------------------------------------------------- -Buffers | Always zero, no block devices -SwapCache | Always zero, no swap -Inactive(anon) | Always zero, see SwapCache -Unevictable | Always zero TODO(b/31823263) -Mlocked | Always zero TODO(b/31823263) -SwapTotal | Always zero, no swap -SwapFree | Always zero, no swap -Dirty | Always zero TODO(b/31823263) -Writeback | Always zero TODO(b/31823263) -MemAvailable | Uses the same value as MemFree since there is no swap. -Slab | Missing -SReclaimable | Missing -SUnreclaim | Missing -KernelStack | Missing -PageTables | Missing -NFS_Unstable | Missing -Bounce | Missing -WritebackTmp | Missing -CommitLimit | Missing -Committed_AS | Missing -VmallocTotal | Missing -VmallocUsed | Missing -VmallocChunk | Missing -HardwareCorrupted | Missing -AnonHugePages | Missing -ShmemHugePages | Missing -ShmemPmdMapped | Missing -HugePages_Total | Missing -HugePages_Free | Missing -HugePages_Rsvd | Missing -HugePages_Surp | Missing -Hugepagesize | Missing -DirectMap4k | Missing -DirectMap2M | Missing -DirectMap1G | Missing - -### stat - -```bash -$ cat /proc/stat -cpu 0 0 0 0 0 0 0 0 0 0 -cpu0 0 0 0 0 0 0 0 0 0 0 -cpu1 0 0 0 0 0 0 0 0 0 0 -cpu2 0 0 0 0 0 0 0 0 0 0 -cpu3 0 0 0 0 0 0 0 0 0 0 -cpu4 0 0 0 0 0 0 0 0 0 0 -cpu5 0 0 0 0 0 0 0 0 0 0 -cpu6 0 0 0 0 0 0 0 0 0 0 -cpu7 0 0 0 0 0 0 0 0 0 0 -intr 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 -ctxt 0 -btime 1504040968 -processes 0 -procs_running 0 -procs_blokkcked 0 -softirq 0 0 0 0 0 0 0 0 0 0 0 -``` - -All fields except for `btime` are always zero. - -TODO(b/37226836): Populate with accurate fields. - -### sys - -```bash -$ ls /proc/sys -kernel vm -``` - -Directory | Notes -:-------- | :---------------------------- -abi | Missing -debug | Missing -dev | Missing -fs | Missing -kernel | Contains hostname (only) -net | Missing -user | Missing -vm | Contains mmap_min_addr (only) - -### uptime - -```bash -$ cat /proc/uptime -3204.62 0.00 -``` - -Column | Notes -:------------------------------- | :---------------------------- -Total num seconds system running | Time since procfs was mounted -Number of seconds idle | Always zero - -### version - -```bash -$ cat /proc/version -Linux version 4.4 #1 SMP Sun Jan 10 15:06:54 PST 2016 -``` - -## Process-specific data - -The following files are implemented: - -File /proc/PID | Content -:---------------------- | :--------------------------------------------------- -[auxv](#auxv) | Copy of auxiliary vector for the process -[cmdline](#cmdline) | Command line arguments -[comm](#comm) | Command name associated with the process -[environ](#environ) | Process environment -[exe](#exe) | Symlink to the process's executable -[fd](#fd) | Directory containing links to open file descriptors -[fdinfo](#fdinfo) | Information associated with open file descriptors -[gid_map](#gid_map) | Mappings for group IDs inside the user namespace -[io](#io) | IO statistics -[maps](#maps) | Memory mappings (anon, executables, library files) -[mounts](#mounts) | Mounted filesystems -[mountinfo](#mountinfo) | Information about mounts -[ns](#ns) | Directory containing info about supported namespaces -[stat](#stat) | Process statistics -[statm](#statm) | Process memory statistics -[status](#status) | Process status in human readable format -[task](#task) | Directory containing info about running threads -[uid_map](#uid_map) | Mappings for user IDs inside the user namespace - -### auxv - -TODO - -### cmdline - -TODO - -### comm - -TODO - -### environment - -TODO - -### exe - -TODO - -### fd - -TODO - -### fdinfo - -TODO - -### gid_map - -TODO - -### io - -Only has data for rchar, wchar, syscr, and syscw. - -TODO: add more detail. - -### maps - -TODO - -### mounts - -TODO - -### mountinfo - -TODO - -### ns - -TODO - -### stat - -Only has data for pid, comm, state, ppid, utime, stime, cutime, cstime, -num_threads, and exit_signal. - -TODO: add more detail. - -### statm - -Only has data for vss and rss. - -TODO: add more detail. - -### status - -Contains data for Name, State, Tgid, Pid, Ppid, TracerPid, FDSize, VmSize, -VmRSS, Threads, CapInh, CapPrm, CapEff, CapBnd, Seccomp. - -TODO: add more detail. - -### task - -TODO - -### uid_map - -TODO diff --git a/pkg/sentry/fs/proc/device/BUILD b/pkg/sentry/fs/proc/device/BUILD deleted file mode 100644 index 52c9aa93d..000000000 --- a/pkg/sentry/fs/proc/device/BUILD +++ /dev/null @@ -1,10 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "device", - srcs = ["device.go"], - visibility = ["//pkg/sentry:internal"], - deps = ["//pkg/sentry/device"], -) diff --git a/pkg/sentry/fs/proc/device/device_state_autogen.go b/pkg/sentry/fs/proc/device/device_state_autogen.go new file mode 100644 index 000000000..4a5e3cc88 --- /dev/null +++ b/pkg/sentry/fs/proc/device/device_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package device diff --git a/pkg/sentry/fs/proc/net_test.go b/pkg/sentry/fs/proc/net_test.go deleted file mode 100644 index f18681405..000000000 --- a/pkg/sentry/fs/proc/net_test.go +++ /dev/null @@ -1,74 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package proc - -import ( - "reflect" - "testing" - - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/inet" -) - -func newIPv6TestStack() *inet.TestStack { - s := inet.NewTestStack() - s.SupportsIPv6Flag = true - return s -} - -func TestIfinet6NoAddresses(t *testing.T) { - n := &ifinet6{s: newIPv6TestStack()} - if got := n.contents(); got != nil { - t.Errorf("Got n.contents() = %v, want = %v", got, nil) - } -} - -func TestIfinet6(t *testing.T) { - s := newIPv6TestStack() - s.InterfacesMap[1] = inet.Interface{Name: "eth0"} - s.InterfaceAddrsMap[1] = []inet.InterfaceAddr{ - { - Family: linux.AF_INET6, - PrefixLen: 128, - Addr: []byte("\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f"), - }, - } - s.InterfacesMap[2] = inet.Interface{Name: "eth1"} - s.InterfaceAddrsMap[2] = []inet.InterfaceAddr{ - { - Family: linux.AF_INET6, - PrefixLen: 128, - Addr: []byte("\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f"), - }, - } - want := map[string]struct{}{ - "000102030405060708090a0b0c0d0e0f 01 80 00 00 eth0\n": {}, - "101112131415161718191a1b1c1d1e1f 02 80 00 00 eth1\n": {}, - } - - n := &ifinet6{s: s} - contents := n.contents() - if len(contents) != len(want) { - t.Errorf("Got len(n.contents()) = %d, want = %d", len(contents), len(want)) - } - got := map[string]struct{}{} - for _, l := range contents { - got[l] = struct{}{} - } - - if !reflect.DeepEqual(got, want) { - t.Errorf("Got n.contents() = %v, want = %v", got, want) - } -} diff --git a/pkg/sentry/fs/proc/proc_state_autogen.go b/pkg/sentry/fs/proc/proc_state_autogen.go new file mode 100644 index 000000000..8046c8bf9 --- /dev/null +++ b/pkg/sentry/fs/proc/proc_state_autogen.go @@ -0,0 +1,1835 @@ +// automatically generated by stateify. + +package proc + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (i *execArgInode) StateTypeName() string { + return "pkg/sentry/fs/proc.execArgInode" +} + +func (i *execArgInode) StateFields() []string { + return []string{ + "SimpleFileInode", + "arg", + "t", + } +} + +func (i *execArgInode) beforeSave() {} + +// +checklocksignore +func (i *execArgInode) StateSave(stateSinkObject state.Sink) { + i.beforeSave() + stateSinkObject.Save(0, &i.SimpleFileInode) + stateSinkObject.Save(1, &i.arg) + stateSinkObject.Save(2, &i.t) +} + +func (i *execArgInode) afterLoad() {} + +// +checklocksignore +func (i *execArgInode) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &i.SimpleFileInode) + stateSourceObject.Load(1, &i.arg) + stateSourceObject.Load(2, &i.t) +} + +func (f *execArgFile) StateTypeName() string { + return "pkg/sentry/fs/proc.execArgFile" +} + +func (f *execArgFile) StateFields() []string { + return []string{ + "arg", + "t", + } +} + +func (f *execArgFile) beforeSave() {} + +// +checklocksignore +func (f *execArgFile) StateSave(stateSinkObject state.Sink) { + f.beforeSave() + stateSinkObject.Save(0, &f.arg) + stateSinkObject.Save(1, &f.t) +} + +func (f *execArgFile) afterLoad() {} + +// +checklocksignore +func (f *execArgFile) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &f.arg) + stateSourceObject.Load(1, &f.t) +} + +func (f *fdDir) StateTypeName() string { + return "pkg/sentry/fs/proc.fdDir" +} + +func (f *fdDir) StateFields() []string { + return []string{ + "Dir", + "t", + } +} + +func (f *fdDir) beforeSave() {} + +// +checklocksignore +func (f *fdDir) StateSave(stateSinkObject state.Sink) { + f.beforeSave() + stateSinkObject.Save(0, &f.Dir) + stateSinkObject.Save(1, &f.t) +} + +func (f *fdDir) afterLoad() {} + +// +checklocksignore +func (f *fdDir) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &f.Dir) + stateSourceObject.Load(1, &f.t) +} + +func (f *fdDirFile) StateTypeName() string { + return "pkg/sentry/fs/proc.fdDirFile" +} + +func (f *fdDirFile) StateFields() []string { + return []string{ + "isInfoFile", + "t", + } +} + +func (f *fdDirFile) beforeSave() {} + +// +checklocksignore +func (f *fdDirFile) StateSave(stateSinkObject state.Sink) { + f.beforeSave() + stateSinkObject.Save(0, &f.isInfoFile) + stateSinkObject.Save(1, &f.t) +} + +func (f *fdDirFile) afterLoad() {} + +// +checklocksignore +func (f *fdDirFile) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &f.isInfoFile) + stateSourceObject.Load(1, &f.t) +} + +func (fdid *fdInfoDir) StateTypeName() string { + return "pkg/sentry/fs/proc.fdInfoDir" +} + +func (fdid *fdInfoDir) StateFields() []string { + return []string{ + "Dir", + "t", + } +} + +func (fdid *fdInfoDir) beforeSave() {} + +// +checklocksignore +func (fdid *fdInfoDir) StateSave(stateSinkObject state.Sink) { + fdid.beforeSave() + stateSinkObject.Save(0, &fdid.Dir) + stateSinkObject.Save(1, &fdid.t) +} + +func (fdid *fdInfoDir) afterLoad() {} + +// +checklocksignore +func (fdid *fdInfoDir) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &fdid.Dir) + stateSourceObject.Load(1, &fdid.t) +} + +func (f *filesystemsData) StateTypeName() string { + return "pkg/sentry/fs/proc.filesystemsData" +} + +func (f *filesystemsData) StateFields() []string { + return []string{} +} + +func (f *filesystemsData) beforeSave() {} + +// +checklocksignore +func (f *filesystemsData) StateSave(stateSinkObject state.Sink) { + f.beforeSave() +} + +func (f *filesystemsData) afterLoad() {} + +// +checklocksignore +func (f *filesystemsData) StateLoad(stateSourceObject state.Source) { +} + +func (f *filesystem) StateTypeName() string { + return "pkg/sentry/fs/proc.filesystem" +} + +func (f *filesystem) StateFields() []string { + return []string{} +} + +func (f *filesystem) beforeSave() {} + +// +checklocksignore +func (f *filesystem) StateSave(stateSinkObject state.Sink) { + f.beforeSave() +} + +func (f *filesystem) afterLoad() {} + +// +checklocksignore +func (f *filesystem) StateLoad(stateSourceObject state.Source) { +} + +func (i *taskOwnedInodeOps) StateTypeName() string { + return "pkg/sentry/fs/proc.taskOwnedInodeOps" +} + +func (i *taskOwnedInodeOps) StateFields() []string { + return []string{ + "InodeOperations", + "t", + } +} + +func (i *taskOwnedInodeOps) beforeSave() {} + +// +checklocksignore +func (i *taskOwnedInodeOps) StateSave(stateSinkObject state.Sink) { + i.beforeSave() + stateSinkObject.Save(0, &i.InodeOperations) + stateSinkObject.Save(1, &i.t) +} + +func (i *taskOwnedInodeOps) afterLoad() {} + +// +checklocksignore +func (i *taskOwnedInodeOps) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &i.InodeOperations) + stateSourceObject.Load(1, &i.t) +} + +func (s *staticFileInodeOps) StateTypeName() string { + return "pkg/sentry/fs/proc.staticFileInodeOps" +} + +func (s *staticFileInodeOps) StateFields() []string { + return []string{ + "InodeSimpleAttributes", + "InodeStaticFileGetter", + } +} + +func (s *staticFileInodeOps) beforeSave() {} + +// +checklocksignore +func (s *staticFileInodeOps) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + stateSinkObject.Save(0, &s.InodeSimpleAttributes) + stateSinkObject.Save(1, &s.InodeStaticFileGetter) +} + +func (s *staticFileInodeOps) afterLoad() {} + +// +checklocksignore +func (s *staticFileInodeOps) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.InodeSimpleAttributes) + stateSourceObject.Load(1, &s.InodeStaticFileGetter) +} + +func (d *loadavgData) StateTypeName() string { + return "pkg/sentry/fs/proc.loadavgData" +} + +func (d *loadavgData) StateFields() []string { + return []string{} +} + +func (d *loadavgData) beforeSave() {} + +// +checklocksignore +func (d *loadavgData) StateSave(stateSinkObject state.Sink) { + d.beforeSave() +} + +func (d *loadavgData) afterLoad() {} + +// +checklocksignore +func (d *loadavgData) StateLoad(stateSourceObject state.Source) { +} + +func (d *meminfoData) StateTypeName() string { + return "pkg/sentry/fs/proc.meminfoData" +} + +func (d *meminfoData) StateFields() []string { + return []string{ + "k", + } +} + +func (d *meminfoData) beforeSave() {} + +// +checklocksignore +func (d *meminfoData) StateSave(stateSinkObject state.Sink) { + d.beforeSave() + stateSinkObject.Save(0, &d.k) +} + +func (d *meminfoData) afterLoad() {} + +// +checklocksignore +func (d *meminfoData) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &d.k) +} + +func (mif *mountInfoFile) StateTypeName() string { + return "pkg/sentry/fs/proc.mountInfoFile" +} + +func (mif *mountInfoFile) StateFields() []string { + return []string{ + "t", + } +} + +func (mif *mountInfoFile) beforeSave() {} + +// +checklocksignore +func (mif *mountInfoFile) StateSave(stateSinkObject state.Sink) { + mif.beforeSave() + stateSinkObject.Save(0, &mif.t) +} + +func (mif *mountInfoFile) afterLoad() {} + +// +checklocksignore +func (mif *mountInfoFile) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &mif.t) +} + +func (mf *mountsFile) StateTypeName() string { + return "pkg/sentry/fs/proc.mountsFile" +} + +func (mf *mountsFile) StateFields() []string { + return []string{ + "t", + } +} + +func (mf *mountsFile) beforeSave() {} + +// +checklocksignore +func (mf *mountsFile) StateSave(stateSinkObject state.Sink) { + mf.beforeSave() + stateSinkObject.Save(0, &mf.t) +} + +func (mf *mountsFile) afterLoad() {} + +// +checklocksignore +func (mf *mountsFile) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &mf.t) +} + +func (n *ifinet6) StateTypeName() string { + return "pkg/sentry/fs/proc.ifinet6" +} + +func (n *ifinet6) StateFields() []string { + return []string{ + "s", + } +} + +func (n *ifinet6) beforeSave() {} + +// +checklocksignore +func (n *ifinet6) StateSave(stateSinkObject state.Sink) { + n.beforeSave() + stateSinkObject.Save(0, &n.s) +} + +func (n *ifinet6) afterLoad() {} + +// +checklocksignore +func (n *ifinet6) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &n.s) +} + +func (n *netDev) StateTypeName() string { + return "pkg/sentry/fs/proc.netDev" +} + +func (n *netDev) StateFields() []string { + return []string{ + "s", + } +} + +func (n *netDev) beforeSave() {} + +// +checklocksignore +func (n *netDev) StateSave(stateSinkObject state.Sink) { + n.beforeSave() + stateSinkObject.Save(0, &n.s) +} + +func (n *netDev) afterLoad() {} + +// +checklocksignore +func (n *netDev) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &n.s) +} + +func (n *netSnmp) StateTypeName() string { + return "pkg/sentry/fs/proc.netSnmp" +} + +func (n *netSnmp) StateFields() []string { + return []string{ + "s", + } +} + +func (n *netSnmp) beforeSave() {} + +// +checklocksignore +func (n *netSnmp) StateSave(stateSinkObject state.Sink) { + n.beforeSave() + stateSinkObject.Save(0, &n.s) +} + +func (n *netSnmp) afterLoad() {} + +// +checklocksignore +func (n *netSnmp) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &n.s) +} + +func (n *netRoute) StateTypeName() string { + return "pkg/sentry/fs/proc.netRoute" +} + +func (n *netRoute) StateFields() []string { + return []string{ + "s", + } +} + +func (n *netRoute) beforeSave() {} + +// +checklocksignore +func (n *netRoute) StateSave(stateSinkObject state.Sink) { + n.beforeSave() + stateSinkObject.Save(0, &n.s) +} + +func (n *netRoute) afterLoad() {} + +// +checklocksignore +func (n *netRoute) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &n.s) +} + +func (n *netUnix) StateTypeName() string { + return "pkg/sentry/fs/proc.netUnix" +} + +func (n *netUnix) StateFields() []string { + return []string{ + "k", + } +} + +func (n *netUnix) beforeSave() {} + +// +checklocksignore +func (n *netUnix) StateSave(stateSinkObject state.Sink) { + n.beforeSave() + stateSinkObject.Save(0, &n.k) +} + +func (n *netUnix) afterLoad() {} + +// +checklocksignore +func (n *netUnix) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &n.k) +} + +func (n *netTCP) StateTypeName() string { + return "pkg/sentry/fs/proc.netTCP" +} + +func (n *netTCP) StateFields() []string { + return []string{ + "k", + } +} + +func (n *netTCP) beforeSave() {} + +// +checklocksignore +func (n *netTCP) StateSave(stateSinkObject state.Sink) { + n.beforeSave() + stateSinkObject.Save(0, &n.k) +} + +func (n *netTCP) afterLoad() {} + +// +checklocksignore +func (n *netTCP) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &n.k) +} + +func (n *netTCP6) StateTypeName() string { + return "pkg/sentry/fs/proc.netTCP6" +} + +func (n *netTCP6) StateFields() []string { + return []string{ + "k", + } +} + +func (n *netTCP6) beforeSave() {} + +// +checklocksignore +func (n *netTCP6) StateSave(stateSinkObject state.Sink) { + n.beforeSave() + stateSinkObject.Save(0, &n.k) +} + +func (n *netTCP6) afterLoad() {} + +// +checklocksignore +func (n *netTCP6) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &n.k) +} + +func (n *netUDP) StateTypeName() string { + return "pkg/sentry/fs/proc.netUDP" +} + +func (n *netUDP) StateFields() []string { + return []string{ + "k", + } +} + +func (n *netUDP) beforeSave() {} + +// +checklocksignore +func (n *netUDP) StateSave(stateSinkObject state.Sink) { + n.beforeSave() + stateSinkObject.Save(0, &n.k) +} + +func (n *netUDP) afterLoad() {} + +// +checklocksignore +func (n *netUDP) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &n.k) +} + +func (p *proc) StateTypeName() string { + return "pkg/sentry/fs/proc.proc" +} + +func (p *proc) StateFields() []string { + return []string{ + "Dir", + "k", + "pidns", + "cgroupControllers", + } +} + +func (p *proc) beforeSave() {} + +// +checklocksignore +func (p *proc) StateSave(stateSinkObject state.Sink) { + p.beforeSave() + stateSinkObject.Save(0, &p.Dir) + stateSinkObject.Save(1, &p.k) + stateSinkObject.Save(2, &p.pidns) + stateSinkObject.Save(3, &p.cgroupControllers) +} + +func (p *proc) afterLoad() {} + +// +checklocksignore +func (p *proc) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &p.Dir) + stateSourceObject.Load(1, &p.k) + stateSourceObject.Load(2, &p.pidns) + stateSourceObject.Load(3, &p.cgroupControllers) +} + +func (s *self) StateTypeName() string { + return "pkg/sentry/fs/proc.self" +} + +func (s *self) StateFields() []string { + return []string{ + "Symlink", + "pidns", + } +} + +func (s *self) beforeSave() {} + +// +checklocksignore +func (s *self) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + stateSinkObject.Save(0, &s.Symlink) + stateSinkObject.Save(1, &s.pidns) +} + +func (s *self) afterLoad() {} + +// +checklocksignore +func (s *self) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.Symlink) + stateSourceObject.Load(1, &s.pidns) +} + +func (s *threadSelf) StateTypeName() string { + return "pkg/sentry/fs/proc.threadSelf" +} + +func (s *threadSelf) StateFields() []string { + return []string{ + "Symlink", + "pidns", + } +} + +func (s *threadSelf) beforeSave() {} + +// +checklocksignore +func (s *threadSelf) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + stateSinkObject.Save(0, &s.Symlink) + stateSinkObject.Save(1, &s.pidns) +} + +func (s *threadSelf) afterLoad() {} + +// +checklocksignore +func (s *threadSelf) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.Symlink) + stateSourceObject.Load(1, &s.pidns) +} + +func (rpf *rootProcFile) StateTypeName() string { + return "pkg/sentry/fs/proc.rootProcFile" +} + +func (rpf *rootProcFile) StateFields() []string { + return []string{ + "iops", + } +} + +func (rpf *rootProcFile) beforeSave() {} + +// +checklocksignore +func (rpf *rootProcFile) StateSave(stateSinkObject state.Sink) { + rpf.beforeSave() + stateSinkObject.Save(0, &rpf.iops) +} + +func (rpf *rootProcFile) afterLoad() {} + +// +checklocksignore +func (rpf *rootProcFile) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &rpf.iops) +} + +func (s *statData) StateTypeName() string { + return "pkg/sentry/fs/proc.statData" +} + +func (s *statData) StateFields() []string { + return []string{ + "k", + } +} + +func (s *statData) beforeSave() {} + +// +checklocksignore +func (s *statData) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + stateSinkObject.Save(0, &s.k) +} + +func (s *statData) afterLoad() {} + +// +checklocksignore +func (s *statData) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.k) +} + +func (d *mmapMinAddrData) StateTypeName() string { + return "pkg/sentry/fs/proc.mmapMinAddrData" +} + +func (d *mmapMinAddrData) StateFields() []string { + return []string{ + "k", + } +} + +func (d *mmapMinAddrData) beforeSave() {} + +// +checklocksignore +func (d *mmapMinAddrData) StateSave(stateSinkObject state.Sink) { + d.beforeSave() + stateSinkObject.Save(0, &d.k) +} + +func (d *mmapMinAddrData) afterLoad() {} + +// +checklocksignore +func (d *mmapMinAddrData) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &d.k) +} + +func (o *overcommitMemory) StateTypeName() string { + return "pkg/sentry/fs/proc.overcommitMemory" +} + +func (o *overcommitMemory) StateFields() []string { + return []string{} +} + +func (o *overcommitMemory) beforeSave() {} + +// +checklocksignore +func (o *overcommitMemory) StateSave(stateSinkObject state.Sink) { + o.beforeSave() +} + +func (o *overcommitMemory) afterLoad() {} + +// +checklocksignore +func (o *overcommitMemory) StateLoad(stateSourceObject state.Source) { +} + +func (m *maxMapCount) StateTypeName() string { + return "pkg/sentry/fs/proc.maxMapCount" +} + +func (m *maxMapCount) StateFields() []string { + return []string{} +} + +func (m *maxMapCount) beforeSave() {} + +// +checklocksignore +func (m *maxMapCount) StateSave(stateSinkObject state.Sink) { + m.beforeSave() +} + +func (m *maxMapCount) afterLoad() {} + +// +checklocksignore +func (m *maxMapCount) StateLoad(stateSourceObject state.Source) { +} + +func (h *hostname) StateTypeName() string { + return "pkg/sentry/fs/proc.hostname" +} + +func (h *hostname) StateFields() []string { + return []string{ + "SimpleFileInode", + } +} + +func (h *hostname) beforeSave() {} + +// +checklocksignore +func (h *hostname) StateSave(stateSinkObject state.Sink) { + h.beforeSave() + stateSinkObject.Save(0, &h.SimpleFileInode) +} + +func (h *hostname) afterLoad() {} + +// +checklocksignore +func (h *hostname) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &h.SimpleFileInode) +} + +func (hf *hostnameFile) StateTypeName() string { + return "pkg/sentry/fs/proc.hostnameFile" +} + +func (hf *hostnameFile) StateFields() []string { + return []string{} +} + +func (hf *hostnameFile) beforeSave() {} + +// +checklocksignore +func (hf *hostnameFile) StateSave(stateSinkObject state.Sink) { + hf.beforeSave() +} + +func (hf *hostnameFile) afterLoad() {} + +// +checklocksignore +func (hf *hostnameFile) StateLoad(stateSourceObject state.Source) { +} + +func (t *tcpMemInode) StateTypeName() string { + return "pkg/sentry/fs/proc.tcpMemInode" +} + +func (t *tcpMemInode) StateFields() []string { + return []string{ + "SimpleFileInode", + "dir", + "s", + "size", + } +} + +// +checklocksignore +func (t *tcpMemInode) StateSave(stateSinkObject state.Sink) { + t.beforeSave() + stateSinkObject.Save(0, &t.SimpleFileInode) + stateSinkObject.Save(1, &t.dir) + stateSinkObject.Save(2, &t.s) + stateSinkObject.Save(3, &t.size) +} + +// +checklocksignore +func (t *tcpMemInode) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &t.SimpleFileInode) + stateSourceObject.Load(1, &t.dir) + stateSourceObject.LoadWait(2, &t.s) + stateSourceObject.Load(3, &t.size) + stateSourceObject.AfterLoad(t.afterLoad) +} + +func (f *tcpMemFile) StateTypeName() string { + return "pkg/sentry/fs/proc.tcpMemFile" +} + +func (f *tcpMemFile) StateFields() []string { + return []string{ + "tcpMemInode", + } +} + +func (f *tcpMemFile) beforeSave() {} + +// +checklocksignore +func (f *tcpMemFile) StateSave(stateSinkObject state.Sink) { + f.beforeSave() + stateSinkObject.Save(0, &f.tcpMemInode) +} + +func (f *tcpMemFile) afterLoad() {} + +// +checklocksignore +func (f *tcpMemFile) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &f.tcpMemInode) +} + +func (s *tcpSack) StateTypeName() string { + return "pkg/sentry/fs/proc.tcpSack" +} + +func (s *tcpSack) StateFields() []string { + return []string{ + "SimpleFileInode", + "stack", + "enabled", + } +} + +func (s *tcpSack) beforeSave() {} + +// +checklocksignore +func (s *tcpSack) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + stateSinkObject.Save(0, &s.SimpleFileInode) + stateSinkObject.Save(1, &s.stack) + stateSinkObject.Save(2, &s.enabled) +} + +// +checklocksignore +func (s *tcpSack) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.SimpleFileInode) + stateSourceObject.LoadWait(1, &s.stack) + stateSourceObject.Load(2, &s.enabled) + stateSourceObject.AfterLoad(s.afterLoad) +} + +func (f *tcpSackFile) StateTypeName() string { + return "pkg/sentry/fs/proc.tcpSackFile" +} + +func (f *tcpSackFile) StateFields() []string { + return []string{ + "tcpSack", + "stack", + } +} + +func (f *tcpSackFile) beforeSave() {} + +// +checklocksignore +func (f *tcpSackFile) StateSave(stateSinkObject state.Sink) { + f.beforeSave() + stateSinkObject.Save(0, &f.tcpSack) + stateSinkObject.Save(1, &f.stack) +} + +func (f *tcpSackFile) afterLoad() {} + +// +checklocksignore +func (f *tcpSackFile) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &f.tcpSack) + stateSourceObject.LoadWait(1, &f.stack) +} + +func (r *tcpRecovery) StateTypeName() string { + return "pkg/sentry/fs/proc.tcpRecovery" +} + +func (r *tcpRecovery) StateFields() []string { + return []string{ + "SimpleFileInode", + "stack", + "recovery", + } +} + +func (r *tcpRecovery) beforeSave() {} + +// +checklocksignore +func (r *tcpRecovery) StateSave(stateSinkObject state.Sink) { + r.beforeSave() + stateSinkObject.Save(0, &r.SimpleFileInode) + stateSinkObject.Save(1, &r.stack) + stateSinkObject.Save(2, &r.recovery) +} + +func (r *tcpRecovery) afterLoad() {} + +// +checklocksignore +func (r *tcpRecovery) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &r.SimpleFileInode) + stateSourceObject.LoadWait(1, &r.stack) + stateSourceObject.Load(2, &r.recovery) +} + +func (f *tcpRecoveryFile) StateTypeName() string { + return "pkg/sentry/fs/proc.tcpRecoveryFile" +} + +func (f *tcpRecoveryFile) StateFields() []string { + return []string{ + "tcpRecovery", + "stack", + } +} + +func (f *tcpRecoveryFile) beforeSave() {} + +// +checklocksignore +func (f *tcpRecoveryFile) StateSave(stateSinkObject state.Sink) { + f.beforeSave() + stateSinkObject.Save(0, &f.tcpRecovery) + stateSinkObject.Save(1, &f.stack) +} + +func (f *tcpRecoveryFile) afterLoad() {} + +// +checklocksignore +func (f *tcpRecoveryFile) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &f.tcpRecovery) + stateSourceObject.LoadWait(1, &f.stack) +} + +func (ipf *ipForwarding) StateTypeName() string { + return "pkg/sentry/fs/proc.ipForwarding" +} + +func (ipf *ipForwarding) StateFields() []string { + return []string{ + "SimpleFileInode", + "stack", + "enabled", + } +} + +func (ipf *ipForwarding) beforeSave() {} + +// +checklocksignore +func (ipf *ipForwarding) StateSave(stateSinkObject state.Sink) { + ipf.beforeSave() + stateSinkObject.Save(0, &ipf.SimpleFileInode) + stateSinkObject.Save(1, &ipf.stack) + stateSinkObject.Save(2, &ipf.enabled) +} + +// +checklocksignore +func (ipf *ipForwarding) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &ipf.SimpleFileInode) + stateSourceObject.LoadWait(1, &ipf.stack) + stateSourceObject.Load(2, &ipf.enabled) + stateSourceObject.AfterLoad(ipf.afterLoad) +} + +func (f *ipForwardingFile) StateTypeName() string { + return "pkg/sentry/fs/proc.ipForwardingFile" +} + +func (f *ipForwardingFile) StateFields() []string { + return []string{ + "ipf", + "stack", + } +} + +func (f *ipForwardingFile) beforeSave() {} + +// +checklocksignore +func (f *ipForwardingFile) StateSave(stateSinkObject state.Sink) { + f.beforeSave() + stateSinkObject.Save(0, &f.ipf) + stateSinkObject.Save(1, &f.stack) +} + +func (f *ipForwardingFile) afterLoad() {} + +// +checklocksignore +func (f *ipForwardingFile) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &f.ipf) + stateSourceObject.LoadWait(1, &f.stack) +} + +func (in *portRangeInode) StateTypeName() string { + return "pkg/sentry/fs/proc.portRangeInode" +} + +func (in *portRangeInode) StateFields() []string { + return []string{ + "SimpleFileInode", + "stack", + "start", + "end", + } +} + +func (in *portRangeInode) beforeSave() {} + +// +checklocksignore +func (in *portRangeInode) StateSave(stateSinkObject state.Sink) { + in.beforeSave() + stateSinkObject.Save(0, &in.SimpleFileInode) + stateSinkObject.Save(1, &in.stack) + stateSinkObject.Save(2, &in.start) + stateSinkObject.Save(3, &in.end) +} + +func (in *portRangeInode) afterLoad() {} + +// +checklocksignore +func (in *portRangeInode) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &in.SimpleFileInode) + stateSourceObject.LoadWait(1, &in.stack) + stateSourceObject.Load(2, &in.start) + stateSourceObject.Load(3, &in.end) +} + +func (pf *portRangeFile) StateTypeName() string { + return "pkg/sentry/fs/proc.portRangeFile" +} + +func (pf *portRangeFile) StateFields() []string { + return []string{ + "inode", + } +} + +func (pf *portRangeFile) beforeSave() {} + +// +checklocksignore +func (pf *portRangeFile) StateSave(stateSinkObject state.Sink) { + pf.beforeSave() + stateSinkObject.Save(0, &pf.inode) +} + +func (pf *portRangeFile) afterLoad() {} + +// +checklocksignore +func (pf *portRangeFile) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &pf.inode) +} + +func (t *taskDir) StateTypeName() string { + return "pkg/sentry/fs/proc.taskDir" +} + +func (t *taskDir) StateFields() []string { + return []string{ + "Dir", + "t", + } +} + +func (t *taskDir) beforeSave() {} + +// +checklocksignore +func (t *taskDir) StateSave(stateSinkObject state.Sink) { + t.beforeSave() + stateSinkObject.Save(0, &t.Dir) + stateSinkObject.Save(1, &t.t) +} + +func (t *taskDir) afterLoad() {} + +// +checklocksignore +func (t *taskDir) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &t.Dir) + stateSourceObject.Load(1, &t.t) +} + +func (s *subtasks) StateTypeName() string { + return "pkg/sentry/fs/proc.subtasks" +} + +func (s *subtasks) StateFields() []string { + return []string{ + "Dir", + "t", + "p", + } +} + +func (s *subtasks) beforeSave() {} + +// +checklocksignore +func (s *subtasks) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + stateSinkObject.Save(0, &s.Dir) + stateSinkObject.Save(1, &s.t) + stateSinkObject.Save(2, &s.p) +} + +func (s *subtasks) afterLoad() {} + +// +checklocksignore +func (s *subtasks) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.Dir) + stateSourceObject.Load(1, &s.t) + stateSourceObject.Load(2, &s.p) +} + +func (f *subtasksFile) StateTypeName() string { + return "pkg/sentry/fs/proc.subtasksFile" +} + +func (f *subtasksFile) StateFields() []string { + return []string{ + "t", + "pidns", + } +} + +func (f *subtasksFile) beforeSave() {} + +// +checklocksignore +func (f *subtasksFile) StateSave(stateSinkObject state.Sink) { + f.beforeSave() + stateSinkObject.Save(0, &f.t) + stateSinkObject.Save(1, &f.pidns) +} + +func (f *subtasksFile) afterLoad() {} + +// +checklocksignore +func (f *subtasksFile) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &f.t) + stateSourceObject.Load(1, &f.pidns) +} + +func (e *exe) StateTypeName() string { + return "pkg/sentry/fs/proc.exe" +} + +func (e *exe) StateFields() []string { + return []string{ + "Symlink", + "t", + } +} + +func (e *exe) beforeSave() {} + +// +checklocksignore +func (e *exe) StateSave(stateSinkObject state.Sink) { + e.beforeSave() + stateSinkObject.Save(0, &e.Symlink) + stateSinkObject.Save(1, &e.t) +} + +func (e *exe) afterLoad() {} + +// +checklocksignore +func (e *exe) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &e.Symlink) + stateSourceObject.Load(1, &e.t) +} + +func (e *cwd) StateTypeName() string { + return "pkg/sentry/fs/proc.cwd" +} + +func (e *cwd) StateFields() []string { + return []string{ + "Symlink", + "t", + } +} + +func (e *cwd) beforeSave() {} + +// +checklocksignore +func (e *cwd) StateSave(stateSinkObject state.Sink) { + e.beforeSave() + stateSinkObject.Save(0, &e.Symlink) + stateSinkObject.Save(1, &e.t) +} + +func (e *cwd) afterLoad() {} + +// +checklocksignore +func (e *cwd) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &e.Symlink) + stateSourceObject.Load(1, &e.t) +} + +func (n *namespaceSymlink) StateTypeName() string { + return "pkg/sentry/fs/proc.namespaceSymlink" +} + +func (n *namespaceSymlink) StateFields() []string { + return []string{ + "Symlink", + "t", + } +} + +func (n *namespaceSymlink) beforeSave() {} + +// +checklocksignore +func (n *namespaceSymlink) StateSave(stateSinkObject state.Sink) { + n.beforeSave() + stateSinkObject.Save(0, &n.Symlink) + stateSinkObject.Save(1, &n.t) +} + +func (n *namespaceSymlink) afterLoad() {} + +// +checklocksignore +func (n *namespaceSymlink) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &n.Symlink) + stateSourceObject.Load(1, &n.t) +} + +func (m *memData) StateTypeName() string { + return "pkg/sentry/fs/proc.memData" +} + +func (m *memData) StateFields() []string { + return []string{ + "SimpleFileInode", + "t", + } +} + +func (m *memData) beforeSave() {} + +// +checklocksignore +func (m *memData) StateSave(stateSinkObject state.Sink) { + m.beforeSave() + stateSinkObject.Save(0, &m.SimpleFileInode) + stateSinkObject.Save(1, &m.t) +} + +func (m *memData) afterLoad() {} + +// +checklocksignore +func (m *memData) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &m.SimpleFileInode) + stateSourceObject.Load(1, &m.t) +} + +func (m *memDataFile) StateTypeName() string { + return "pkg/sentry/fs/proc.memDataFile" +} + +func (m *memDataFile) StateFields() []string { + return []string{ + "t", + } +} + +func (m *memDataFile) beforeSave() {} + +// +checklocksignore +func (m *memDataFile) StateSave(stateSinkObject state.Sink) { + m.beforeSave() + stateSinkObject.Save(0, &m.t) +} + +func (m *memDataFile) afterLoad() {} + +// +checklocksignore +func (m *memDataFile) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &m.t) +} + +func (md *mapsData) StateTypeName() string { + return "pkg/sentry/fs/proc.mapsData" +} + +func (md *mapsData) StateFields() []string { + return []string{ + "t", + } +} + +func (md *mapsData) beforeSave() {} + +// +checklocksignore +func (md *mapsData) StateSave(stateSinkObject state.Sink) { + md.beforeSave() + stateSinkObject.Save(0, &md.t) +} + +func (md *mapsData) afterLoad() {} + +// +checklocksignore +func (md *mapsData) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &md.t) +} + +func (sd *smapsData) StateTypeName() string { + return "pkg/sentry/fs/proc.smapsData" +} + +func (sd *smapsData) StateFields() []string { + return []string{ + "t", + } +} + +func (sd *smapsData) beforeSave() {} + +// +checklocksignore +func (sd *smapsData) StateSave(stateSinkObject state.Sink) { + sd.beforeSave() + stateSinkObject.Save(0, &sd.t) +} + +func (sd *smapsData) afterLoad() {} + +// +checklocksignore +func (sd *smapsData) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &sd.t) +} + +func (s *taskStatData) StateTypeName() string { + return "pkg/sentry/fs/proc.taskStatData" +} + +func (s *taskStatData) StateFields() []string { + return []string{ + "t", + "tgstats", + "pidns", + } +} + +func (s *taskStatData) beforeSave() {} + +// +checklocksignore +func (s *taskStatData) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + stateSinkObject.Save(0, &s.t) + stateSinkObject.Save(1, &s.tgstats) + stateSinkObject.Save(2, &s.pidns) +} + +func (s *taskStatData) afterLoad() {} + +// +checklocksignore +func (s *taskStatData) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.t) + stateSourceObject.Load(1, &s.tgstats) + stateSourceObject.Load(2, &s.pidns) +} + +func (s *statmData) StateTypeName() string { + return "pkg/sentry/fs/proc.statmData" +} + +func (s *statmData) StateFields() []string { + return []string{ + "t", + } +} + +func (s *statmData) beforeSave() {} + +// +checklocksignore +func (s *statmData) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + stateSinkObject.Save(0, &s.t) +} + +func (s *statmData) afterLoad() {} + +// +checklocksignore +func (s *statmData) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.t) +} + +func (s *statusData) StateTypeName() string { + return "pkg/sentry/fs/proc.statusData" +} + +func (s *statusData) StateFields() []string { + return []string{ + "t", + "pidns", + } +} + +func (s *statusData) beforeSave() {} + +// +checklocksignore +func (s *statusData) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + stateSinkObject.Save(0, &s.t) + stateSinkObject.Save(1, &s.pidns) +} + +func (s *statusData) afterLoad() {} + +// +checklocksignore +func (s *statusData) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.t) + stateSourceObject.Load(1, &s.pidns) +} + +func (i *ioData) StateTypeName() string { + return "pkg/sentry/fs/proc.ioData" +} + +func (i *ioData) StateFields() []string { + return []string{ + "ioUsage", + } +} + +func (i *ioData) beforeSave() {} + +// +checklocksignore +func (i *ioData) StateSave(stateSinkObject state.Sink) { + i.beforeSave() + stateSinkObject.Save(0, &i.ioUsage) +} + +func (i *ioData) afterLoad() {} + +// +checklocksignore +func (i *ioData) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &i.ioUsage) +} + +func (c *comm) StateTypeName() string { + return "pkg/sentry/fs/proc.comm" +} + +func (c *comm) StateFields() []string { + return []string{ + "SimpleFileInode", + "t", + } +} + +func (c *comm) beforeSave() {} + +// +checklocksignore +func (c *comm) StateSave(stateSinkObject state.Sink) { + c.beforeSave() + stateSinkObject.Save(0, &c.SimpleFileInode) + stateSinkObject.Save(1, &c.t) +} + +func (c *comm) afterLoad() {} + +// +checklocksignore +func (c *comm) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &c.SimpleFileInode) + stateSourceObject.Load(1, &c.t) +} + +func (f *commFile) StateTypeName() string { + return "pkg/sentry/fs/proc.commFile" +} + +func (f *commFile) StateFields() []string { + return []string{ + "t", + } +} + +func (f *commFile) beforeSave() {} + +// +checklocksignore +func (f *commFile) StateSave(stateSinkObject state.Sink) { + f.beforeSave() + stateSinkObject.Save(0, &f.t) +} + +func (f *commFile) afterLoad() {} + +// +checklocksignore +func (f *commFile) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &f.t) +} + +func (a *auxvec) StateTypeName() string { + return "pkg/sentry/fs/proc.auxvec" +} + +func (a *auxvec) StateFields() []string { + return []string{ + "SimpleFileInode", + "t", + } +} + +func (a *auxvec) beforeSave() {} + +// +checklocksignore +func (a *auxvec) StateSave(stateSinkObject state.Sink) { + a.beforeSave() + stateSinkObject.Save(0, &a.SimpleFileInode) + stateSinkObject.Save(1, &a.t) +} + +func (a *auxvec) afterLoad() {} + +// +checklocksignore +func (a *auxvec) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &a.SimpleFileInode) + stateSourceObject.Load(1, &a.t) +} + +func (f *auxvecFile) StateTypeName() string { + return "pkg/sentry/fs/proc.auxvecFile" +} + +func (f *auxvecFile) StateFields() []string { + return []string{ + "t", + } +} + +func (f *auxvecFile) beforeSave() {} + +// +checklocksignore +func (f *auxvecFile) StateSave(stateSinkObject state.Sink) { + f.beforeSave() + stateSinkObject.Save(0, &f.t) +} + +func (f *auxvecFile) afterLoad() {} + +// +checklocksignore +func (f *auxvecFile) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &f.t) +} + +func (o *oomScoreAdj) StateTypeName() string { + return "pkg/sentry/fs/proc.oomScoreAdj" +} + +func (o *oomScoreAdj) StateFields() []string { + return []string{ + "SimpleFileInode", + "t", + } +} + +func (o *oomScoreAdj) beforeSave() {} + +// +checklocksignore +func (o *oomScoreAdj) StateSave(stateSinkObject state.Sink) { + o.beforeSave() + stateSinkObject.Save(0, &o.SimpleFileInode) + stateSinkObject.Save(1, &o.t) +} + +func (o *oomScoreAdj) afterLoad() {} + +// +checklocksignore +func (o *oomScoreAdj) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &o.SimpleFileInode) + stateSourceObject.Load(1, &o.t) +} + +func (f *oomScoreAdjFile) StateTypeName() string { + return "pkg/sentry/fs/proc.oomScoreAdjFile" +} + +func (f *oomScoreAdjFile) StateFields() []string { + return []string{ + "t", + } +} + +func (f *oomScoreAdjFile) beforeSave() {} + +// +checklocksignore +func (f *oomScoreAdjFile) StateSave(stateSinkObject state.Sink) { + f.beforeSave() + stateSinkObject.Save(0, &f.t) +} + +func (f *oomScoreAdjFile) afterLoad() {} + +// +checklocksignore +func (f *oomScoreAdjFile) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &f.t) +} + +func (imio *idMapInodeOperations) StateTypeName() string { + return "pkg/sentry/fs/proc.idMapInodeOperations" +} + +func (imio *idMapInodeOperations) StateFields() []string { + return []string{ + "InodeSimpleAttributes", + "InodeSimpleExtendedAttributes", + "t", + "gids", + } +} + +func (imio *idMapInodeOperations) beforeSave() {} + +// +checklocksignore +func (imio *idMapInodeOperations) StateSave(stateSinkObject state.Sink) { + imio.beforeSave() + stateSinkObject.Save(0, &imio.InodeSimpleAttributes) + stateSinkObject.Save(1, &imio.InodeSimpleExtendedAttributes) + stateSinkObject.Save(2, &imio.t) + stateSinkObject.Save(3, &imio.gids) +} + +func (imio *idMapInodeOperations) afterLoad() {} + +// +checklocksignore +func (imio *idMapInodeOperations) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &imio.InodeSimpleAttributes) + stateSourceObject.Load(1, &imio.InodeSimpleExtendedAttributes) + stateSourceObject.Load(2, &imio.t) + stateSourceObject.Load(3, &imio.gids) +} + +func (imfo *idMapFileOperations) StateTypeName() string { + return "pkg/sentry/fs/proc.idMapFileOperations" +} + +func (imfo *idMapFileOperations) StateFields() []string { + return []string{ + "iops", + } +} + +func (imfo *idMapFileOperations) beforeSave() {} + +// +checklocksignore +func (imfo *idMapFileOperations) StateSave(stateSinkObject state.Sink) { + imfo.beforeSave() + stateSinkObject.Save(0, &imfo.iops) +} + +func (imfo *idMapFileOperations) afterLoad() {} + +// +checklocksignore +func (imfo *idMapFileOperations) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &imfo.iops) +} + +func (u *uptime) StateTypeName() string { + return "pkg/sentry/fs/proc.uptime" +} + +func (u *uptime) StateFields() []string { + return []string{ + "SimpleFileInode", + "startTime", + } +} + +func (u *uptime) beforeSave() {} + +// +checklocksignore +func (u *uptime) StateSave(stateSinkObject state.Sink) { + u.beforeSave() + stateSinkObject.Save(0, &u.SimpleFileInode) + stateSinkObject.Save(1, &u.startTime) +} + +func (u *uptime) afterLoad() {} + +// +checklocksignore +func (u *uptime) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &u.SimpleFileInode) + stateSourceObject.Load(1, &u.startTime) +} + +func (f *uptimeFile) StateTypeName() string { + return "pkg/sentry/fs/proc.uptimeFile" +} + +func (f *uptimeFile) StateFields() []string { + return []string{ + "startTime", + } +} + +func (f *uptimeFile) beforeSave() {} + +// +checklocksignore +func (f *uptimeFile) StateSave(stateSinkObject state.Sink) { + f.beforeSave() + stateSinkObject.Save(0, &f.startTime) +} + +func (f *uptimeFile) afterLoad() {} + +// +checklocksignore +func (f *uptimeFile) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &f.startTime) +} + +func (v *versionData) StateTypeName() string { + return "pkg/sentry/fs/proc.versionData" +} + +func (v *versionData) StateFields() []string { + return []string{ + "k", + } +} + +func (v *versionData) beforeSave() {} + +// +checklocksignore +func (v *versionData) StateSave(stateSinkObject state.Sink) { + v.beforeSave() + stateSinkObject.Save(0, &v.k) +} + +func (v *versionData) afterLoad() {} + +// +checklocksignore +func (v *versionData) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &v.k) +} + +func init() { + state.Register((*execArgInode)(nil)) + state.Register((*execArgFile)(nil)) + state.Register((*fdDir)(nil)) + state.Register((*fdDirFile)(nil)) + state.Register((*fdInfoDir)(nil)) + state.Register((*filesystemsData)(nil)) + state.Register((*filesystem)(nil)) + state.Register((*taskOwnedInodeOps)(nil)) + state.Register((*staticFileInodeOps)(nil)) + state.Register((*loadavgData)(nil)) + state.Register((*meminfoData)(nil)) + state.Register((*mountInfoFile)(nil)) + state.Register((*mountsFile)(nil)) + state.Register((*ifinet6)(nil)) + state.Register((*netDev)(nil)) + state.Register((*netSnmp)(nil)) + state.Register((*netRoute)(nil)) + state.Register((*netUnix)(nil)) + state.Register((*netTCP)(nil)) + state.Register((*netTCP6)(nil)) + state.Register((*netUDP)(nil)) + state.Register((*proc)(nil)) + state.Register((*self)(nil)) + state.Register((*threadSelf)(nil)) + state.Register((*rootProcFile)(nil)) + state.Register((*statData)(nil)) + state.Register((*mmapMinAddrData)(nil)) + state.Register((*overcommitMemory)(nil)) + state.Register((*maxMapCount)(nil)) + state.Register((*hostname)(nil)) + state.Register((*hostnameFile)(nil)) + state.Register((*tcpMemInode)(nil)) + state.Register((*tcpMemFile)(nil)) + state.Register((*tcpSack)(nil)) + state.Register((*tcpSackFile)(nil)) + state.Register((*tcpRecovery)(nil)) + state.Register((*tcpRecoveryFile)(nil)) + state.Register((*ipForwarding)(nil)) + state.Register((*ipForwardingFile)(nil)) + state.Register((*portRangeInode)(nil)) + state.Register((*portRangeFile)(nil)) + state.Register((*taskDir)(nil)) + state.Register((*subtasks)(nil)) + state.Register((*subtasksFile)(nil)) + state.Register((*exe)(nil)) + state.Register((*cwd)(nil)) + state.Register((*namespaceSymlink)(nil)) + state.Register((*memData)(nil)) + state.Register((*memDataFile)(nil)) + state.Register((*mapsData)(nil)) + state.Register((*smapsData)(nil)) + state.Register((*taskStatData)(nil)) + state.Register((*statmData)(nil)) + state.Register((*statusData)(nil)) + state.Register((*ioData)(nil)) + state.Register((*comm)(nil)) + state.Register((*commFile)(nil)) + state.Register((*auxvec)(nil)) + state.Register((*auxvecFile)(nil)) + state.Register((*oomScoreAdj)(nil)) + state.Register((*oomScoreAdjFile)(nil)) + state.Register((*idMapInodeOperations)(nil)) + state.Register((*idMapFileOperations)(nil)) + state.Register((*uptime)(nil)) + state.Register((*uptimeFile)(nil)) + state.Register((*versionData)(nil)) +} diff --git a/pkg/sentry/fs/proc/seqfile/BUILD b/pkg/sentry/fs/proc/seqfile/BUILD deleted file mode 100644 index 90bd32345..000000000 --- a/pkg/sentry/fs/proc/seqfile/BUILD +++ /dev/null @@ -1,36 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "seqfile", - srcs = ["seqfile.go"], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/errors/linuxerr", - "//pkg/hostarch", - "//pkg/sentry/fs", - "//pkg/sentry/fs/fsutil", - "//pkg/sentry/fs/proc/device", - "//pkg/sentry/kernel/time", - "//pkg/sync", - "//pkg/usermem", - "//pkg/waiter", - ], -) - -go_test( - name = "seqfile_test", - size = "small", - srcs = ["seqfile_test.go"], - library = ":seqfile", - deps = [ - "//pkg/context", - "//pkg/sentry/contexttest", - "//pkg/sentry/fs", - "//pkg/sentry/fs/ramfs", - "//pkg/usermem", - ], -) diff --git a/pkg/sentry/fs/proc/seqfile/seqfile_state_autogen.go b/pkg/sentry/fs/proc/seqfile/seqfile_state_autogen.go new file mode 100644 index 000000000..1cd9b313b --- /dev/null +++ b/pkg/sentry/fs/proc/seqfile/seqfile_state_autogen.go @@ -0,0 +1,106 @@ +// automatically generated by stateify. + +package seqfile + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (s *SeqData) StateTypeName() string { + return "pkg/sentry/fs/proc/seqfile.SeqData" +} + +func (s *SeqData) StateFields() []string { + return []string{ + "Buf", + "Handle", + } +} + +func (s *SeqData) beforeSave() {} + +// +checklocksignore +func (s *SeqData) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + stateSinkObject.Save(0, &s.Buf) + stateSinkObject.Save(1, &s.Handle) +} + +func (s *SeqData) afterLoad() {} + +// +checklocksignore +func (s *SeqData) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.Buf) + stateSourceObject.Load(1, &s.Handle) +} + +func (s *SeqFile) StateTypeName() string { + return "pkg/sentry/fs/proc/seqfile.SeqFile" +} + +func (s *SeqFile) StateFields() []string { + return []string{ + "InodeSimpleExtendedAttributes", + "InodeSimpleAttributes", + "SeqSource", + "source", + "generation", + "lastRead", + } +} + +func (s *SeqFile) beforeSave() {} + +// +checklocksignore +func (s *SeqFile) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + stateSinkObject.Save(0, &s.InodeSimpleExtendedAttributes) + stateSinkObject.Save(1, &s.InodeSimpleAttributes) + stateSinkObject.Save(2, &s.SeqSource) + stateSinkObject.Save(3, &s.source) + stateSinkObject.Save(4, &s.generation) + stateSinkObject.Save(5, &s.lastRead) +} + +func (s *SeqFile) afterLoad() {} + +// +checklocksignore +func (s *SeqFile) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.InodeSimpleExtendedAttributes) + stateSourceObject.Load(1, &s.InodeSimpleAttributes) + stateSourceObject.Load(2, &s.SeqSource) + stateSourceObject.Load(3, &s.source) + stateSourceObject.Load(4, &s.generation) + stateSourceObject.Load(5, &s.lastRead) +} + +func (sfo *seqFileOperations) StateTypeName() string { + return "pkg/sentry/fs/proc/seqfile.seqFileOperations" +} + +func (sfo *seqFileOperations) StateFields() []string { + return []string{ + "seqFile", + } +} + +func (sfo *seqFileOperations) beforeSave() {} + +// +checklocksignore +func (sfo *seqFileOperations) StateSave(stateSinkObject state.Sink) { + sfo.beforeSave() + stateSinkObject.Save(0, &sfo.seqFile) +} + +func (sfo *seqFileOperations) afterLoad() {} + +// +checklocksignore +func (sfo *seqFileOperations) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &sfo.seqFile) +} + +func init() { + state.Register((*SeqData)(nil)) + state.Register((*SeqFile)(nil)) + state.Register((*seqFileOperations)(nil)) +} diff --git a/pkg/sentry/fs/proc/seqfile/seqfile_test.go b/pkg/sentry/fs/proc/seqfile/seqfile_test.go deleted file mode 100644 index 98e394569..000000000 --- a/pkg/sentry/fs/proc/seqfile/seqfile_test.go +++ /dev/null @@ -1,279 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package seqfile - -import ( - "bytes" - "fmt" - "io" - "testing" - - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/sentry/contexttest" - "gvisor.dev/gvisor/pkg/sentry/fs" - "gvisor.dev/gvisor/pkg/sentry/fs/ramfs" - "gvisor.dev/gvisor/pkg/usermem" -) - -type seqTest struct { - actual []SeqData - update bool -} - -func (s *seqTest) Init() { - var sq []SeqData - // Create some SeqData. - for i := 0; i < 10; i++ { - var b []byte - for j := 0; j < 10; j++ { - b = append(b, byte(i)) - } - sq = append(sq, SeqData{ - Buf: b, - Handle: &testHandle{i: i}, - }) - } - s.actual = sq -} - -// NeedsUpdate reports whether we need to update the data we've previously read. -func (s *seqTest) NeedsUpdate(int64) bool { - return s.update -} - -// ReadSeqFiledata returns a slice of SeqData which contains elements -// greater than the handle. -func (s *seqTest) ReadSeqFileData(ctx context.Context, handle SeqHandle) ([]SeqData, int64) { - if handle == nil { - return s.actual, 0 - } - h := *handle.(*testHandle) - var ret []SeqData - for _, b := range s.actual { - // We want the next one. - h2 := *b.Handle.(*testHandle) - if h2.i > h.i { - ret = append(ret, b) - } - } - return ret, 0 -} - -// Flatten a slice of slices into one slice. -func flatten(buf ...[]byte) []byte { - var flat []byte - for _, b := range buf { - flat = append(flat, b...) - } - return flat -} - -type testHandle struct { - i int -} - -type testTable struct { - offset int64 - readBufferSize int - expectedData []byte - expectedError error -} - -func runTableTests(ctx context.Context, table []testTable, dirent *fs.Dirent) error { - for _, tt := range table { - file, err := dirent.Inode.InodeOperations.GetFile(ctx, dirent, fs.FileFlags{Read: true}) - if err != nil { - return fmt.Errorf("GetFile returned error: %v", err) - } - - data := make([]byte, tt.readBufferSize) - resultLen, err := file.Preadv(ctx, usermem.BytesIOSequence(data), tt.offset) - if err != tt.expectedError { - return fmt.Errorf("t.Preadv(len: %v, offset: %v) (error) => %v expected %v", tt.readBufferSize, tt.offset, err, tt.expectedError) - } - expectedLen := int64(len(tt.expectedData)) - if resultLen != expectedLen { - // We make this just an error so we wall through and print the data below. - return fmt.Errorf("t.Preadv(len: %v, offset: %v) (size) => %v expected %v", tt.readBufferSize, tt.offset, resultLen, expectedLen) - } - if !bytes.Equal(data[:expectedLen], tt.expectedData) { - return fmt.Errorf("t.Preadv(len: %v, offset: %v) (data) => %v expected %v", tt.readBufferSize, tt.offset, data[:expectedLen], tt.expectedData) - } - } - return nil -} - -func TestSeqFile(t *testing.T) { - testSource := &seqTest{} - testSource.Init() - - // Create a file that can be R/W. - ctx := contexttest.Context(t) - m := fs.NewPseudoMountSource(ctx) - contents := map[string]*fs.Inode{ - "foo": NewSeqFileInode(ctx, testSource, m), - } - root := ramfs.NewDir(ctx, contents, fs.RootOwner, fs.FilePermsFromMode(0777)) - - // How about opening it? - inode := fs.NewInode(ctx, root, m, fs.StableAttr{Type: fs.Directory}) - dirent2, err := root.Lookup(ctx, inode, "foo") - if err != nil { - t.Fatalf("failed to walk to foo for n2: %v", err) - } - n2 := dirent2.Inode.InodeOperations - file2, err := n2.GetFile(ctx, dirent2, fs.FileFlags{Read: true, Write: true}) - if err != nil { - t.Fatalf("GetFile returned error: %v", err) - } - - // Writing? - if _, err := file2.Writev(ctx, usermem.BytesIOSequence([]byte("test"))); err == nil { - t.Fatalf("managed to write to n2: %v", err) - } - - // How about reading? - dirent3, err := root.Lookup(ctx, inode, "foo") - if err != nil { - t.Fatalf("failed to walk to foo: %v", err) - } - n3 := dirent3.Inode.InodeOperations - if n2 != n3 { - t.Error("got n2 != n3, want same") - } - - testSource.update = true - - table := []testTable{ - // Read past the end. - {100, 4, []byte{}, io.EOF}, - {110, 4, []byte{}, io.EOF}, - {200, 4, []byte{}, io.EOF}, - // Read a truncated first line. - {0, 4, testSource.actual[0].Buf[:4], nil}, - // Read the whole first line. - {0, 10, testSource.actual[0].Buf, nil}, - // Read the whole first line + 5 bytes of second line. - {0, 15, flatten(testSource.actual[0].Buf, testSource.actual[1].Buf[:5]), nil}, - // First 4 bytes of the second line. - {10, 4, testSource.actual[1].Buf[:4], nil}, - // Read the two first lines. - {0, 20, flatten(testSource.actual[0].Buf, testSource.actual[1].Buf), nil}, - // Read three lines. - {0, 30, flatten(testSource.actual[0].Buf, testSource.actual[1].Buf, testSource.actual[2].Buf), nil}, - // Read everything, but use a bigger buffer than necessary. - {0, 150, flatten(testSource.actual[0].Buf, testSource.actual[1].Buf, testSource.actual[2].Buf, testSource.actual[3].Buf, testSource.actual[4].Buf, testSource.actual[5].Buf, testSource.actual[6].Buf, testSource.actual[7].Buf, testSource.actual[8].Buf, testSource.actual[9].Buf), nil}, - // Read the last 3 bytes. - {97, 10, testSource.actual[9].Buf[7:], nil}, - } - if err := runTableTests(ctx, table, dirent2); err != nil { - t.Errorf("runTableTest failed with testSource.update = %v : %v", testSource.update, err) - } - - // Disable updates and do it again. - testSource.update = false - if err := runTableTests(ctx, table, dirent2); err != nil { - t.Errorf("runTableTest failed with testSource.update = %v: %v", testSource.update, err) - } -} - -// Test that we behave correctly when the file is updated. -func TestSeqFileFileUpdated(t *testing.T) { - testSource := &seqTest{} - testSource.Init() - testSource.update = true - - // Create a file that can be R/W. - ctx := contexttest.Context(t) - m := fs.NewPseudoMountSource(ctx) - contents := map[string]*fs.Inode{ - "foo": NewSeqFileInode(ctx, testSource, m), - } - root := ramfs.NewDir(ctx, contents, fs.RootOwner, fs.FilePermsFromMode(0777)) - - // How about opening it? - inode := fs.NewInode(ctx, root, m, fs.StableAttr{Type: fs.Directory}) - dirent2, err := root.Lookup(ctx, inode, "foo") - if err != nil { - t.Fatalf("failed to walk to foo for dirent2: %v", err) - } - - table := []testTable{ - {0, 16, flatten(testSource.actual[0].Buf, testSource.actual[1].Buf[:6]), nil}, - } - if err := runTableTests(ctx, table, dirent2); err != nil { - t.Errorf("runTableTest failed: %v", err) - } - // Delete the first entry. - cut := testSource.actual[0].Buf - testSource.actual = testSource.actual[1:] - - table = []testTable{ - // Try reading buffer 0 with an offset. This will not delete the old data. - {1, 5, cut[1:6], nil}, - // Reset our file by reading at offset 0. - {0, 10, testSource.actual[0].Buf, nil}, - {16, 14, flatten(testSource.actual[1].Buf[6:], testSource.actual[2].Buf), nil}, - // Read the same data a second time. - {16, 14, flatten(testSource.actual[1].Buf[6:], testSource.actual[2].Buf), nil}, - // Read the following two lines. - {30, 20, flatten(testSource.actual[3].Buf, testSource.actual[4].Buf), nil}, - } - if err := runTableTests(ctx, table, dirent2); err != nil { - t.Errorf("runTableTest failed after removing first entry: %v", err) - } - - // Add a new duplicate line in the middle (6666...) - after := testSource.actual[5:] - testSource.actual = testSource.actual[:4] - // Note the list must be sorted. - testSource.actual = append(testSource.actual, after[0]) - testSource.actual = append(testSource.actual, after...) - - table = []testTable{ - {50, 20, flatten(testSource.actual[4].Buf, testSource.actual[5].Buf), nil}, - } - if err := runTableTests(ctx, table, dirent2); err != nil { - t.Errorf("runTableTest failed after adding middle entry: %v", err) - } - // This will be used in a later test. - oldTestData := testSource.actual - - // Delete everything. - testSource.actual = testSource.actual[:0] - table = []testTable{ - {20, 20, []byte{}, io.EOF}, - } - if err := runTableTests(ctx, table, dirent2); err != nil { - t.Errorf("runTableTest failed after removing all entries: %v", err) - } - // Restore some of the data. - testSource.actual = oldTestData[:1] - table = []testTable{ - {6, 20, testSource.actual[0].Buf[6:], nil}, - } - if err := runTableTests(ctx, table, dirent2); err != nil { - t.Errorf("runTableTest failed after adding first entry back: %v", err) - } - - // Re-extend the data - testSource.actual = oldTestData - table = []testTable{ - {30, 20, flatten(testSource.actual[3].Buf, testSource.actual[4].Buf), nil}, - } - if err := runTableTests(ctx, table, dirent2); err != nil { - t.Errorf("runTableTest failed after extending testSource: %v", err) - } -} diff --git a/pkg/sentry/fs/proc/sys_net_test.go b/pkg/sentry/fs/proc/sys_net_test.go deleted file mode 100644 index 6ef5738e7..000000000 --- a/pkg/sentry/fs/proc/sys_net_test.go +++ /dev/null @@ -1,198 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package proc - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/sentry/inet" - "gvisor.dev/gvisor/pkg/usermem" -) - -func TestQuerySendBufferSize(t *testing.T) { - ctx := context.Background() - s := inet.NewTestStack() - s.TCPSendBufSize = inet.TCPBufferSize{100, 200, 300} - tmi := &tcpMemInode{s: s, dir: tcpWMem} - tmf := &tcpMemFile{tcpMemInode: tmi} - - buf := make([]byte, 100) - dst := usermem.BytesIOSequence(buf) - n, err := tmf.Read(ctx, nil, dst, 0) - if err != nil { - t.Fatalf("Read failed: %v", err) - } - - if got, want := string(buf[:n]), "100\t200\t300\n"; got != want { - t.Fatalf("Bad string: got %v, want %v", got, want) - } -} - -func TestQueryRecvBufferSize(t *testing.T) { - ctx := context.Background() - s := inet.NewTestStack() - s.TCPRecvBufSize = inet.TCPBufferSize{100, 200, 300} - tmi := &tcpMemInode{s: s, dir: tcpRMem} - tmf := &tcpMemFile{tcpMemInode: tmi} - - buf := make([]byte, 100) - dst := usermem.BytesIOSequence(buf) - n, err := tmf.Read(ctx, nil, dst, 0) - if err != nil { - t.Fatalf("Read failed: %v", err) - } - - if got, want := string(buf[:n]), "100\t200\t300\n"; got != want { - t.Fatalf("Bad string: got %v, want %v", got, want) - } -} - -var cases = []struct { - str string - initial inet.TCPBufferSize - final inet.TCPBufferSize -}{ - { - str: "", - initial: inet.TCPBufferSize{1, 2, 3}, - final: inet.TCPBufferSize{1, 2, 3}, - }, - { - str: "100\n", - initial: inet.TCPBufferSize{1, 100, 200}, - final: inet.TCPBufferSize{100, 100, 200}, - }, - { - str: "100 200 300\n", - initial: inet.TCPBufferSize{1, 2, 3}, - final: inet.TCPBufferSize{100, 200, 300}, - }, -} - -func TestConfigureSendBufferSize(t *testing.T) { - ctx := context.Background() - s := inet.NewTestStack() - for _, c := range cases { - s.TCPSendBufSize = c.initial - tmi := &tcpMemInode{s: s, dir: tcpWMem} - tmf := &tcpMemFile{tcpMemInode: tmi} - - // Write the values. - src := usermem.BytesIOSequence([]byte(c.str)) - if n, err := tmf.Write(ctx, nil, src, 0); n != int64(len(c.str)) || err != nil { - t.Errorf("Write, case = %q: got (%d, %v), wanted (%d, nil)", c.str, n, err, len(c.str)) - } - - // Read the values from the stack and check them. - if s.TCPSendBufSize != c.final { - t.Errorf("TCPSendBufferSize, case = %q: got %v, wanted %v", c.str, s.TCPSendBufSize, c.final) - } - } -} - -func TestConfigureRecvBufferSize(t *testing.T) { - ctx := context.Background() - s := inet.NewTestStack() - for _, c := range cases { - s.TCPRecvBufSize = c.initial - tmi := &tcpMemInode{s: s, dir: tcpRMem} - tmf := &tcpMemFile{tcpMemInode: tmi} - - // Write the values. - src := usermem.BytesIOSequence([]byte(c.str)) - if n, err := tmf.Write(ctx, nil, src, 0); n != int64(len(c.str)) || err != nil { - t.Errorf("Write, case = %q: got (%d, %v), wanted (%d, nil)", c.str, n, err, len(c.str)) - } - - // Read the values from the stack and check them. - if s.TCPRecvBufSize != c.final { - t.Errorf("TCPRecvBufferSize, case = %q: got %v, wanted %v", c.str, s.TCPRecvBufSize, c.final) - } - } -} - -// TestIPForwarding tests the implementation of -// /proc/sys/net/ipv4/ip_forwarding -func TestIPForwarding(t *testing.T) { - ctx := context.Background() - s := inet.NewTestStack() - - var cases = []struct { - comment string - initial bool - str string - final bool - }{ - { - comment: `Forwarding is disabled; write 1 and enable forwarding`, - initial: false, - str: "1", - final: true, - }, - { - comment: `Forwarding is disabled; write 0 and disable forwarding`, - initial: false, - str: "0", - final: false, - }, - { - comment: `Forwarding is enabled; write 1 and enable forwarding`, - initial: true, - str: "1", - final: true, - }, - { - comment: `Forwarding is enabled; write 0 and disable forwarding`, - initial: true, - str: "0", - final: false, - }, - { - comment: `Forwarding is disabled; write 2404 and enable forwarding`, - initial: false, - str: "2404", - final: true, - }, - { - comment: `Forwarding is enabled; write 2404 and enable forwarding`, - initial: true, - str: "2404", - final: true, - }, - } - for _, c := range cases { - t.Run(c.comment, func(t *testing.T) { - s.IPForwarding = c.initial - ipf := &ipForwarding{stack: s} - file := &ipForwardingFile{ - stack: s, - ipf: ipf, - } - - // Write the values. - src := usermem.BytesIOSequence([]byte(c.str)) - if n, err := file.Write(ctx, nil, src, 0); n != int64(len(c.str)) || err != nil { - t.Errorf("file.Write(ctx, nil, %q, 0) = (%d, %v); want (%d, nil)", c.str, n, err, len(c.str)) - } - - // Read the values from the stack and check them. - if got, want := s.IPForwarding, c.final; got != want { - t.Errorf("s.IPForwarding incorrect; got: %v, want: %v", got, want) - } - - }) - } -} diff --git a/pkg/sentry/fs/ramfs/BUILD b/pkg/sentry/fs/ramfs/BUILD deleted file mode 100644 index bfff010c5..000000000 --- a/pkg/sentry/fs/ramfs/BUILD +++ /dev/null @@ -1,38 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "ramfs", - srcs = [ - "dir.go", - "socket.go", - "symlink.go", - "tree.go", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/errors/linuxerr", - "//pkg/hostarch", - "//pkg/sentry/fs", - "//pkg/sentry/fs/anon", - "//pkg/sentry/fs/fsutil", - "//pkg/sentry/socket/unix/transport", - "//pkg/sync", - "//pkg/waiter", - "@org_golang_x_sys//unix:go_default_library", - ], -) - -go_test( - name = "ramfs_test", - size = "small", - srcs = ["tree_test.go"], - library = ":ramfs", - deps = [ - "//pkg/sentry/contexttest", - "//pkg/sentry/fs", - ], -) diff --git a/pkg/sentry/fs/ramfs/ramfs_state_autogen.go b/pkg/sentry/fs/ramfs/ramfs_state_autogen.go new file mode 100644 index 000000000..53c836415 --- /dev/null +++ b/pkg/sentry/fs/ramfs/ramfs_state_autogen.go @@ -0,0 +1,182 @@ +// automatically generated by stateify. + +package ramfs + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (d *Dir) StateTypeName() string { + return "pkg/sentry/fs/ramfs.Dir" +} + +func (d *Dir) StateFields() []string { + return []string{ + "InodeSimpleAttributes", + "InodeSimpleExtendedAttributes", + "children", + "dentryMap", + } +} + +func (d *Dir) beforeSave() {} + +// +checklocksignore +func (d *Dir) StateSave(stateSinkObject state.Sink) { + d.beforeSave() + stateSinkObject.Save(0, &d.InodeSimpleAttributes) + stateSinkObject.Save(1, &d.InodeSimpleExtendedAttributes) + stateSinkObject.Save(2, &d.children) + stateSinkObject.Save(3, &d.dentryMap) +} + +func (d *Dir) afterLoad() {} + +// +checklocksignore +func (d *Dir) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &d.InodeSimpleAttributes) + stateSourceObject.Load(1, &d.InodeSimpleExtendedAttributes) + stateSourceObject.Load(2, &d.children) + stateSourceObject.Load(3, &d.dentryMap) +} + +func (dfo *dirFileOperations) StateTypeName() string { + return "pkg/sentry/fs/ramfs.dirFileOperations" +} + +func (dfo *dirFileOperations) StateFields() []string { + return []string{ + "dirCursor", + "dir", + } +} + +func (dfo *dirFileOperations) beforeSave() {} + +// +checklocksignore +func (dfo *dirFileOperations) StateSave(stateSinkObject state.Sink) { + dfo.beforeSave() + stateSinkObject.Save(0, &dfo.dirCursor) + stateSinkObject.Save(1, &dfo.dir) +} + +func (dfo *dirFileOperations) afterLoad() {} + +// +checklocksignore +func (dfo *dirFileOperations) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &dfo.dirCursor) + stateSourceObject.Load(1, &dfo.dir) +} + +func (s *Socket) StateTypeName() string { + return "pkg/sentry/fs/ramfs.Socket" +} + +func (s *Socket) StateFields() []string { + return []string{ + "InodeSimpleAttributes", + "InodeSimpleExtendedAttributes", + "ep", + } +} + +func (s *Socket) beforeSave() {} + +// +checklocksignore +func (s *Socket) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + stateSinkObject.Save(0, &s.InodeSimpleAttributes) + stateSinkObject.Save(1, &s.InodeSimpleExtendedAttributes) + stateSinkObject.Save(2, &s.ep) +} + +func (s *Socket) afterLoad() {} + +// +checklocksignore +func (s *Socket) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.InodeSimpleAttributes) + stateSourceObject.Load(1, &s.InodeSimpleExtendedAttributes) + stateSourceObject.Load(2, &s.ep) +} + +func (s *socketFileOperations) StateTypeName() string { + return "pkg/sentry/fs/ramfs.socketFileOperations" +} + +func (s *socketFileOperations) StateFields() []string { + return []string{} +} + +func (s *socketFileOperations) beforeSave() {} + +// +checklocksignore +func (s *socketFileOperations) StateSave(stateSinkObject state.Sink) { + s.beforeSave() +} + +func (s *socketFileOperations) afterLoad() {} + +// +checklocksignore +func (s *socketFileOperations) StateLoad(stateSourceObject state.Source) { +} + +func (s *Symlink) StateTypeName() string { + return "pkg/sentry/fs/ramfs.Symlink" +} + +func (s *Symlink) StateFields() []string { + return []string{ + "InodeSimpleAttributes", + "InodeSimpleExtendedAttributes", + "Target", + } +} + +func (s *Symlink) beforeSave() {} + +// +checklocksignore +func (s *Symlink) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + stateSinkObject.Save(0, &s.InodeSimpleAttributes) + stateSinkObject.Save(1, &s.InodeSimpleExtendedAttributes) + stateSinkObject.Save(2, &s.Target) +} + +func (s *Symlink) afterLoad() {} + +// +checklocksignore +func (s *Symlink) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.InodeSimpleAttributes) + stateSourceObject.Load(1, &s.InodeSimpleExtendedAttributes) + stateSourceObject.Load(2, &s.Target) +} + +func (s *symlinkFileOperations) StateTypeName() string { + return "pkg/sentry/fs/ramfs.symlinkFileOperations" +} + +func (s *symlinkFileOperations) StateFields() []string { + return []string{} +} + +func (s *symlinkFileOperations) beforeSave() {} + +// +checklocksignore +func (s *symlinkFileOperations) StateSave(stateSinkObject state.Sink) { + s.beforeSave() +} + +func (s *symlinkFileOperations) afterLoad() {} + +// +checklocksignore +func (s *symlinkFileOperations) StateLoad(stateSourceObject state.Source) { +} + +func init() { + state.Register((*Dir)(nil)) + state.Register((*dirFileOperations)(nil)) + state.Register((*Socket)(nil)) + state.Register((*socketFileOperations)(nil)) + state.Register((*Symlink)(nil)) + state.Register((*symlinkFileOperations)(nil)) +} diff --git a/pkg/sentry/fs/ramfs/tree_test.go b/pkg/sentry/fs/ramfs/tree_test.go deleted file mode 100644 index 3e0d1e07e..000000000 --- a/pkg/sentry/fs/ramfs/tree_test.go +++ /dev/null @@ -1,80 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ramfs - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/sentry/contexttest" - "gvisor.dev/gvisor/pkg/sentry/fs" -) - -func TestMakeDirectoryTree(t *testing.T) { - - for _, test := range []struct { - name string - subdirs []string - }{ - { - name: "abs paths", - subdirs: []string{ - "/tmp", - "/tmp/a/b", - "/tmp/a/c/d", - "/tmp/c", - "/proc", - "/dev/a/b", - "/tmp", - }, - }, - { - name: "rel paths", - subdirs: []string{ - "tmp", - "tmp/a/b", - "tmp/a/c/d", - "tmp/c", - "proc", - "dev/a/b", - "tmp", - }, - }, - } { - ctx := contexttest.Context(t) - mount := fs.NewPseudoMountSource(ctx) - tree, err := MakeDirectoryTree(ctx, mount, test.subdirs) - if err != nil { - t.Errorf("%s: failed to make ramfs tree, got error %v, want nil", test.name, err) - continue - } - - // Expect to be able to find each of the paths. - mm, err := fs.NewMountNamespace(ctx, tree) - if err != nil { - t.Errorf("%s: failed to create mount manager: %v", test.name, err) - continue - } - root := mm.Root() - defer mm.DecRef(ctx) - - for _, p := range test.subdirs { - maxTraversals := uint(0) - if _, err := mm.FindInode(ctx, root, nil, p, &maxTraversals); err != nil { - t.Errorf("%s: failed to find node %s: %v", test.name, p, err) - break - } - } - } -} diff --git a/pkg/sentry/fs/sys/BUILD b/pkg/sentry/fs/sys/BUILD deleted file mode 100644 index fdbc5f912..000000000 --- a/pkg/sentry/fs/sys/BUILD +++ /dev/null @@ -1,24 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "sys", - srcs = [ - "device.go", - "devices.go", - "fs.go", - "sys.go", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/hostarch", - "//pkg/sentry/device", - "//pkg/sentry/fs", - "//pkg/sentry/fs/fsutil", - "//pkg/sentry/fs/ramfs", - "//pkg/sentry/kernel", - ], -) diff --git a/pkg/sentry/fs/sys/sys_state_autogen.go b/pkg/sentry/fs/sys/sys_state_autogen.go new file mode 100644 index 000000000..9c4f22243 --- /dev/null +++ b/pkg/sentry/fs/sys/sys_state_autogen.go @@ -0,0 +1,61 @@ +// automatically generated by stateify. + +package sys + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (c *cpunum) StateTypeName() string { + return "pkg/sentry/fs/sys.cpunum" +} + +func (c *cpunum) StateFields() []string { + return []string{ + "InodeSimpleAttributes", + "InodeStaticFileGetter", + } +} + +func (c *cpunum) beforeSave() {} + +// +checklocksignore +func (c *cpunum) StateSave(stateSinkObject state.Sink) { + c.beforeSave() + stateSinkObject.Save(0, &c.InodeSimpleAttributes) + stateSinkObject.Save(1, &c.InodeStaticFileGetter) +} + +func (c *cpunum) afterLoad() {} + +// +checklocksignore +func (c *cpunum) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &c.InodeSimpleAttributes) + stateSourceObject.Load(1, &c.InodeStaticFileGetter) +} + +func (f *filesystem) StateTypeName() string { + return "pkg/sentry/fs/sys.filesystem" +} + +func (f *filesystem) StateFields() []string { + return []string{} +} + +func (f *filesystem) beforeSave() {} + +// +checklocksignore +func (f *filesystem) StateSave(stateSinkObject state.Sink) { + f.beforeSave() +} + +func (f *filesystem) afterLoad() {} + +// +checklocksignore +func (f *filesystem) StateLoad(stateSourceObject state.Source) { +} + +func init() { + state.Register((*cpunum)(nil)) + state.Register((*filesystem)(nil)) +} diff --git a/pkg/sentry/fs/timerfd/BUILD b/pkg/sentry/fs/timerfd/BUILD deleted file mode 100644 index e61115932..000000000 --- a/pkg/sentry/fs/timerfd/BUILD +++ /dev/null @@ -1,20 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "timerfd", - srcs = ["timerfd.go"], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/context", - "//pkg/errors/linuxerr", - "//pkg/hostarch", - "//pkg/sentry/fs", - "//pkg/sentry/fs/anon", - "//pkg/sentry/fs/fsutil", - "//pkg/sentry/kernel/time", - "//pkg/usermem", - "//pkg/waiter", - ], -) diff --git a/pkg/sentry/fs/timerfd/timerfd_state_autogen.go b/pkg/sentry/fs/timerfd/timerfd_state_autogen.go new file mode 100644 index 000000000..822eecdd6 --- /dev/null +++ b/pkg/sentry/fs/timerfd/timerfd_state_autogen.go @@ -0,0 +1,42 @@ +// automatically generated by stateify. + +package timerfd + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (t *TimerOperations) StateTypeName() string { + return "pkg/sentry/fs/timerfd.TimerOperations" +} + +func (t *TimerOperations) StateFields() []string { + return []string{ + "timer", + "val", + } +} + +func (t *TimerOperations) beforeSave() {} + +// +checklocksignore +func (t *TimerOperations) StateSave(stateSinkObject state.Sink) { + t.beforeSave() + if !state.IsZeroValue(&t.events) { + state.Failf("events is %#v, expected zero", &t.events) + } + stateSinkObject.Save(0, &t.timer) + stateSinkObject.Save(1, &t.val) +} + +func (t *TimerOperations) afterLoad() {} + +// +checklocksignore +func (t *TimerOperations) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &t.timer) + stateSourceObject.Load(1, &t.val) +} + +func init() { + state.Register((*TimerOperations)(nil)) +} diff --git a/pkg/sentry/fs/tmpfs/BUILD b/pkg/sentry/fs/tmpfs/BUILD deleted file mode 100644 index 511fffb43..000000000 --- a/pkg/sentry/fs/tmpfs/BUILD +++ /dev/null @@ -1,52 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "tmpfs", - srcs = [ - "device.go", - "file_regular.go", - "fs.go", - "inode_file.go", - "tmpfs.go", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/errors/linuxerr", - "//pkg/hostarch", - "//pkg/safemem", - "//pkg/sentry/device", - "//pkg/sentry/fs", - "//pkg/sentry/fs/fsutil", - "//pkg/sentry/fs/ramfs", - "//pkg/sentry/fsmetric", - "//pkg/sentry/kernel", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/kernel/pipe", - "//pkg/sentry/kernel/time", - "//pkg/sentry/memmap", - "//pkg/sentry/socket/unix/transport", - "//pkg/sentry/usage", - "//pkg/sync", - "//pkg/usermem", - "//pkg/waiter", - ], -) - -go_test( - name = "tmpfs_test", - size = "small", - srcs = ["file_test.go"], - library = ":tmpfs", - deps = [ - "//pkg/context", - "//pkg/hostarch", - "//pkg/sentry/fs", - "//pkg/sentry/kernel/contexttest", - "//pkg/sentry/usage", - "//pkg/usermem", - ], -) diff --git a/pkg/sentry/fs/tmpfs/file_test.go b/pkg/sentry/fs/tmpfs/file_test.go deleted file mode 100644 index 1718f9372..000000000 --- a/pkg/sentry/fs/tmpfs/file_test.go +++ /dev/null @@ -1,73 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tmpfs - -import ( - "bytes" - "testing" - - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/hostarch" - "gvisor.dev/gvisor/pkg/sentry/fs" - "gvisor.dev/gvisor/pkg/sentry/kernel/contexttest" - "gvisor.dev/gvisor/pkg/sentry/usage" - "gvisor.dev/gvisor/pkg/usermem" -) - -func newFileInode(ctx context.Context) *fs.Inode { - m := fs.NewCachingMountSource(ctx, &Filesystem{}, fs.MountSourceFlags{}) - iops := NewInMemoryFile(ctx, usage.Tmpfs, fs.WithCurrentTime(ctx, fs.UnstableAttr{})) - return fs.NewInode(ctx, iops, m, fs.StableAttr{ - DeviceID: tmpfsDevice.DeviceID(), - InodeID: tmpfsDevice.NextIno(), - BlockSize: hostarch.PageSize, - Type: fs.RegularFile, - }) -} - -func newFile(ctx context.Context) *fs.File { - inode := newFileInode(ctx) - f, _ := inode.GetFile(ctx, fs.NewDirent(ctx, inode, "stub"), fs.FileFlags{Read: true, Write: true}) - return f -} - -// Allocate once, write twice. -func TestGrow(t *testing.T) { - ctx := contexttest.Context(t) - f := newFile(ctx) - defer f.DecRef(ctx) - - abuf := bytes.Repeat([]byte{'a'}, 68) - n, err := f.Pwritev(ctx, usermem.BytesIOSequence(abuf), 0) - if n != int64(len(abuf)) || err != nil { - t.Fatalf("Pwritev got (%d, %v) want (%d, nil)", n, err, len(abuf)) - } - - bbuf := bytes.Repeat([]byte{'b'}, 856) - n, err = f.Pwritev(ctx, usermem.BytesIOSequence(bbuf), 68) - if n != int64(len(bbuf)) || err != nil { - t.Fatalf("Pwritev got (%d, %v) want (%d, nil)", n, err, len(bbuf)) - } - - rbuf := make([]byte, len(abuf)+len(bbuf)) - n, err = f.Preadv(ctx, usermem.BytesIOSequence(rbuf), 0) - if n != int64(len(rbuf)) || err != nil { - t.Fatalf("Preadv got (%d, %v) want (%d, nil)", n, err, len(rbuf)) - } - - if want := append(abuf, bbuf...); !bytes.Equal(rbuf, want) { - t.Fatalf("Read %v, want %v", rbuf, want) - } -} diff --git a/pkg/sentry/fs/tmpfs/tmpfs_state_autogen.go b/pkg/sentry/fs/tmpfs/tmpfs_state_autogen.go new file mode 100644 index 000000000..ab5f75fea --- /dev/null +++ b/pkg/sentry/fs/tmpfs/tmpfs_state_autogen.go @@ -0,0 +1,211 @@ +// automatically generated by stateify. + +package tmpfs + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (r *regularFileOperations) StateTypeName() string { + return "pkg/sentry/fs/tmpfs.regularFileOperations" +} + +func (r *regularFileOperations) StateFields() []string { + return []string{ + "iops", + } +} + +func (r *regularFileOperations) beforeSave() {} + +// +checklocksignore +func (r *regularFileOperations) StateSave(stateSinkObject state.Sink) { + r.beforeSave() + stateSinkObject.Save(0, &r.iops) +} + +func (r *regularFileOperations) afterLoad() {} + +// +checklocksignore +func (r *regularFileOperations) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &r.iops) +} + +func (f *Filesystem) StateTypeName() string { + return "pkg/sentry/fs/tmpfs.Filesystem" +} + +func (f *Filesystem) StateFields() []string { + return []string{} +} + +func (f *Filesystem) beforeSave() {} + +// +checklocksignore +func (f *Filesystem) StateSave(stateSinkObject state.Sink) { + f.beforeSave() +} + +func (f *Filesystem) afterLoad() {} + +// +checklocksignore +func (f *Filesystem) StateLoad(stateSourceObject state.Source) { +} + +func (f *fileInodeOperations) StateTypeName() string { + return "pkg/sentry/fs/tmpfs.fileInodeOperations" +} + +func (f *fileInodeOperations) StateFields() []string { + return []string{ + "InodeSimpleExtendedAttributes", + "kernel", + "memUsage", + "attr", + "mappings", + "writableMappingPages", + "data", + "seals", + } +} + +func (f *fileInodeOperations) beforeSave() {} + +// +checklocksignore +func (f *fileInodeOperations) StateSave(stateSinkObject state.Sink) { + f.beforeSave() + stateSinkObject.Save(0, &f.InodeSimpleExtendedAttributes) + stateSinkObject.Save(1, &f.kernel) + stateSinkObject.Save(2, &f.memUsage) + stateSinkObject.Save(3, &f.attr) + stateSinkObject.Save(4, &f.mappings) + stateSinkObject.Save(5, &f.writableMappingPages) + stateSinkObject.Save(6, &f.data) + stateSinkObject.Save(7, &f.seals) +} + +func (f *fileInodeOperations) afterLoad() {} + +// +checklocksignore +func (f *fileInodeOperations) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &f.InodeSimpleExtendedAttributes) + stateSourceObject.Load(1, &f.kernel) + stateSourceObject.Load(2, &f.memUsage) + stateSourceObject.Load(3, &f.attr) + stateSourceObject.Load(4, &f.mappings) + stateSourceObject.Load(5, &f.writableMappingPages) + stateSourceObject.Load(6, &f.data) + stateSourceObject.Load(7, &f.seals) +} + +func (d *Dir) StateTypeName() string { + return "pkg/sentry/fs/tmpfs.Dir" +} + +func (d *Dir) StateFields() []string { + return []string{ + "ramfsDir", + "kernel", + } +} + +func (d *Dir) beforeSave() {} + +// +checklocksignore +func (d *Dir) StateSave(stateSinkObject state.Sink) { + d.beforeSave() + stateSinkObject.Save(0, &d.ramfsDir) + stateSinkObject.Save(1, &d.kernel) +} + +// +checklocksignore +func (d *Dir) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &d.ramfsDir) + stateSourceObject.Load(1, &d.kernel) + stateSourceObject.AfterLoad(d.afterLoad) +} + +func (s *Symlink) StateTypeName() string { + return "pkg/sentry/fs/tmpfs.Symlink" +} + +func (s *Symlink) StateFields() []string { + return []string{ + "Symlink", + } +} + +func (s *Symlink) beforeSave() {} + +// +checklocksignore +func (s *Symlink) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + stateSinkObject.Save(0, &s.Symlink) +} + +func (s *Symlink) afterLoad() {} + +// +checklocksignore +func (s *Symlink) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.Symlink) +} + +func (s *Socket) StateTypeName() string { + return "pkg/sentry/fs/tmpfs.Socket" +} + +func (s *Socket) StateFields() []string { + return []string{ + "Socket", + } +} + +func (s *Socket) beforeSave() {} + +// +checklocksignore +func (s *Socket) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + stateSinkObject.Save(0, &s.Socket) +} + +func (s *Socket) afterLoad() {} + +// +checklocksignore +func (s *Socket) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.Socket) +} + +func (f *Fifo) StateTypeName() string { + return "pkg/sentry/fs/tmpfs.Fifo" +} + +func (f *Fifo) StateFields() []string { + return []string{ + "InodeOperations", + } +} + +func (f *Fifo) beforeSave() {} + +// +checklocksignore +func (f *Fifo) StateSave(stateSinkObject state.Sink) { + f.beforeSave() + stateSinkObject.Save(0, &f.InodeOperations) +} + +func (f *Fifo) afterLoad() {} + +// +checklocksignore +func (f *Fifo) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &f.InodeOperations) +} + +func init() { + state.Register((*regularFileOperations)(nil)) + state.Register((*Filesystem)(nil)) + state.Register((*fileInodeOperations)(nil)) + state.Register((*Dir)(nil)) + state.Register((*Symlink)(nil)) + state.Register((*Socket)(nil)) + state.Register((*Fifo)(nil)) +} diff --git a/pkg/sentry/fs/tty/BUILD b/pkg/sentry/fs/tty/BUILD deleted file mode 100644 index 9e9dc06f3..000000000 --- a/pkg/sentry/fs/tty/BUILD +++ /dev/null @@ -1,49 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "tty", - srcs = [ - "dir.go", - "fs.go", - "line_discipline.go", - "master.go", - "queue.go", - "replica.go", - "terminal.go", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/errors/linuxerr", - "//pkg/hostarch", - "//pkg/marshal/primitive", - "//pkg/refs", - "//pkg/safemem", - "//pkg/sentry/arch", - "//pkg/sentry/device", - "//pkg/sentry/fs", - "//pkg/sentry/fs/fsutil", - "//pkg/sentry/kernel", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/socket/unix/transport", - "//pkg/sentry/unimpl", - "//pkg/sync", - "//pkg/usermem", - "//pkg/waiter", - ], -) - -go_test( - name = "tty_test", - size = "small", - srcs = ["tty_test.go"], - library = ":tty", - deps = [ - "//pkg/abi/linux", - "//pkg/sentry/contexttest", - "//pkg/usermem", - ], -) diff --git a/pkg/sentry/fs/tty/tty_state_autogen.go b/pkg/sentry/fs/tty/tty_state_autogen.go new file mode 100644 index 000000000..2fb0a8d27 --- /dev/null +++ b/pkg/sentry/fs/tty/tty_state_autogen.go @@ -0,0 +1,407 @@ +// automatically generated by stateify. + +package tty + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (d *dirInodeOperations) StateTypeName() string { + return "pkg/sentry/fs/tty.dirInodeOperations" +} + +func (d *dirInodeOperations) StateFields() []string { + return []string{ + "InodeSimpleAttributes", + "msrc", + "master", + "replicas", + "dentryMap", + "next", + } +} + +func (d *dirInodeOperations) beforeSave() {} + +// +checklocksignore +func (d *dirInodeOperations) StateSave(stateSinkObject state.Sink) { + d.beforeSave() + stateSinkObject.Save(0, &d.InodeSimpleAttributes) + stateSinkObject.Save(1, &d.msrc) + stateSinkObject.Save(2, &d.master) + stateSinkObject.Save(3, &d.replicas) + stateSinkObject.Save(4, &d.dentryMap) + stateSinkObject.Save(5, &d.next) +} + +func (d *dirInodeOperations) afterLoad() {} + +// +checklocksignore +func (d *dirInodeOperations) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &d.InodeSimpleAttributes) + stateSourceObject.Load(1, &d.msrc) + stateSourceObject.Load(2, &d.master) + stateSourceObject.Load(3, &d.replicas) + stateSourceObject.Load(4, &d.dentryMap) + stateSourceObject.Load(5, &d.next) +} + +func (df *dirFileOperations) StateTypeName() string { + return "pkg/sentry/fs/tty.dirFileOperations" +} + +func (df *dirFileOperations) StateFields() []string { + return []string{ + "di", + "dirCursor", + } +} + +func (df *dirFileOperations) beforeSave() {} + +// +checklocksignore +func (df *dirFileOperations) StateSave(stateSinkObject state.Sink) { + df.beforeSave() + stateSinkObject.Save(0, &df.di) + stateSinkObject.Save(1, &df.dirCursor) +} + +func (df *dirFileOperations) afterLoad() {} + +// +checklocksignore +func (df *dirFileOperations) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &df.di) + stateSourceObject.Load(1, &df.dirCursor) +} + +func (f *filesystem) StateTypeName() string { + return "pkg/sentry/fs/tty.filesystem" +} + +func (f *filesystem) StateFields() []string { + return []string{} +} + +func (f *filesystem) beforeSave() {} + +// +checklocksignore +func (f *filesystem) StateSave(stateSinkObject state.Sink) { + f.beforeSave() +} + +func (f *filesystem) afterLoad() {} + +// +checklocksignore +func (f *filesystem) StateLoad(stateSourceObject state.Source) { +} + +func (s *superOperations) StateTypeName() string { + return "pkg/sentry/fs/tty.superOperations" +} + +func (s *superOperations) StateFields() []string { + return []string{} +} + +func (s *superOperations) beforeSave() {} + +// +checklocksignore +func (s *superOperations) StateSave(stateSinkObject state.Sink) { + s.beforeSave() +} + +func (s *superOperations) afterLoad() {} + +// +checklocksignore +func (s *superOperations) StateLoad(stateSourceObject state.Source) { +} + +func (l *lineDiscipline) StateTypeName() string { + return "pkg/sentry/fs/tty.lineDiscipline" +} + +func (l *lineDiscipline) StateFields() []string { + return []string{ + "size", + "inQueue", + "outQueue", + "termios", + "column", + } +} + +func (l *lineDiscipline) beforeSave() {} + +// +checklocksignore +func (l *lineDiscipline) StateSave(stateSinkObject state.Sink) { + l.beforeSave() + if !state.IsZeroValue(&l.masterWaiter) { + state.Failf("masterWaiter is %#v, expected zero", &l.masterWaiter) + } + if !state.IsZeroValue(&l.replicaWaiter) { + state.Failf("replicaWaiter is %#v, expected zero", &l.replicaWaiter) + } + stateSinkObject.Save(0, &l.size) + stateSinkObject.Save(1, &l.inQueue) + stateSinkObject.Save(2, &l.outQueue) + stateSinkObject.Save(3, &l.termios) + stateSinkObject.Save(4, &l.column) +} + +func (l *lineDiscipline) afterLoad() {} + +// +checklocksignore +func (l *lineDiscipline) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &l.size) + stateSourceObject.Load(1, &l.inQueue) + stateSourceObject.Load(2, &l.outQueue) + stateSourceObject.Load(3, &l.termios) + stateSourceObject.Load(4, &l.column) +} + +func (o *outputQueueTransformer) StateTypeName() string { + return "pkg/sentry/fs/tty.outputQueueTransformer" +} + +func (o *outputQueueTransformer) StateFields() []string { + return []string{} +} + +func (o *outputQueueTransformer) beforeSave() {} + +// +checklocksignore +func (o *outputQueueTransformer) StateSave(stateSinkObject state.Sink) { + o.beforeSave() +} + +func (o *outputQueueTransformer) afterLoad() {} + +// +checklocksignore +func (o *outputQueueTransformer) StateLoad(stateSourceObject state.Source) { +} + +func (i *inputQueueTransformer) StateTypeName() string { + return "pkg/sentry/fs/tty.inputQueueTransformer" +} + +func (i *inputQueueTransformer) StateFields() []string { + return []string{} +} + +func (i *inputQueueTransformer) beforeSave() {} + +// +checklocksignore +func (i *inputQueueTransformer) StateSave(stateSinkObject state.Sink) { + i.beforeSave() +} + +func (i *inputQueueTransformer) afterLoad() {} + +// +checklocksignore +func (i *inputQueueTransformer) StateLoad(stateSourceObject state.Source) { +} + +func (mi *masterInodeOperations) StateTypeName() string { + return "pkg/sentry/fs/tty.masterInodeOperations" +} + +func (mi *masterInodeOperations) StateFields() []string { + return []string{ + "SimpleFileInode", + "d", + } +} + +func (mi *masterInodeOperations) beforeSave() {} + +// +checklocksignore +func (mi *masterInodeOperations) StateSave(stateSinkObject state.Sink) { + mi.beforeSave() + stateSinkObject.Save(0, &mi.SimpleFileInode) + stateSinkObject.Save(1, &mi.d) +} + +func (mi *masterInodeOperations) afterLoad() {} + +// +checklocksignore +func (mi *masterInodeOperations) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &mi.SimpleFileInode) + stateSourceObject.Load(1, &mi.d) +} + +func (mf *masterFileOperations) StateTypeName() string { + return "pkg/sentry/fs/tty.masterFileOperations" +} + +func (mf *masterFileOperations) StateFields() []string { + return []string{ + "d", + "t", + } +} + +func (mf *masterFileOperations) beforeSave() {} + +// +checklocksignore +func (mf *masterFileOperations) StateSave(stateSinkObject state.Sink) { + mf.beforeSave() + stateSinkObject.Save(0, &mf.d) + stateSinkObject.Save(1, &mf.t) +} + +func (mf *masterFileOperations) afterLoad() {} + +// +checklocksignore +func (mf *masterFileOperations) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &mf.d) + stateSourceObject.Load(1, &mf.t) +} + +func (q *queue) StateTypeName() string { + return "pkg/sentry/fs/tty.queue" +} + +func (q *queue) StateFields() []string { + return []string{ + "readBuf", + "waitBuf", + "waitBufLen", + "readable", + "transformer", + } +} + +func (q *queue) beforeSave() {} + +// +checklocksignore +func (q *queue) StateSave(stateSinkObject state.Sink) { + q.beforeSave() + stateSinkObject.Save(0, &q.readBuf) + stateSinkObject.Save(1, &q.waitBuf) + stateSinkObject.Save(2, &q.waitBufLen) + stateSinkObject.Save(3, &q.readable) + stateSinkObject.Save(4, &q.transformer) +} + +func (q *queue) afterLoad() {} + +// +checklocksignore +func (q *queue) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &q.readBuf) + stateSourceObject.Load(1, &q.waitBuf) + stateSourceObject.Load(2, &q.waitBufLen) + stateSourceObject.Load(3, &q.readable) + stateSourceObject.Load(4, &q.transformer) +} + +func (si *replicaInodeOperations) StateTypeName() string { + return "pkg/sentry/fs/tty.replicaInodeOperations" +} + +func (si *replicaInodeOperations) StateFields() []string { + return []string{ + "SimpleFileInode", + "d", + "t", + } +} + +func (si *replicaInodeOperations) beforeSave() {} + +// +checklocksignore +func (si *replicaInodeOperations) StateSave(stateSinkObject state.Sink) { + si.beforeSave() + stateSinkObject.Save(0, &si.SimpleFileInode) + stateSinkObject.Save(1, &si.d) + stateSinkObject.Save(2, &si.t) +} + +func (si *replicaInodeOperations) afterLoad() {} + +// +checklocksignore +func (si *replicaInodeOperations) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &si.SimpleFileInode) + stateSourceObject.Load(1, &si.d) + stateSourceObject.Load(2, &si.t) +} + +func (sf *replicaFileOperations) StateTypeName() string { + return "pkg/sentry/fs/tty.replicaFileOperations" +} + +func (sf *replicaFileOperations) StateFields() []string { + return []string{ + "si", + } +} + +func (sf *replicaFileOperations) beforeSave() {} + +// +checklocksignore +func (sf *replicaFileOperations) StateSave(stateSinkObject state.Sink) { + sf.beforeSave() + stateSinkObject.Save(0, &sf.si) +} + +func (sf *replicaFileOperations) afterLoad() {} + +// +checklocksignore +func (sf *replicaFileOperations) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &sf.si) +} + +func (tm *Terminal) StateTypeName() string { + return "pkg/sentry/fs/tty.Terminal" +} + +func (tm *Terminal) StateFields() []string { + return []string{ + "AtomicRefCount", + "n", + "d", + "ld", + "masterKTTY", + "replicaKTTY", + } +} + +func (tm *Terminal) beforeSave() {} + +// +checklocksignore +func (tm *Terminal) StateSave(stateSinkObject state.Sink) { + tm.beforeSave() + stateSinkObject.Save(0, &tm.AtomicRefCount) + stateSinkObject.Save(1, &tm.n) + stateSinkObject.Save(2, &tm.d) + stateSinkObject.Save(3, &tm.ld) + stateSinkObject.Save(4, &tm.masterKTTY) + stateSinkObject.Save(5, &tm.replicaKTTY) +} + +func (tm *Terminal) afterLoad() {} + +// +checklocksignore +func (tm *Terminal) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &tm.AtomicRefCount) + stateSourceObject.Load(1, &tm.n) + stateSourceObject.Load(2, &tm.d) + stateSourceObject.Load(3, &tm.ld) + stateSourceObject.Load(4, &tm.masterKTTY) + stateSourceObject.Load(5, &tm.replicaKTTY) +} + +func init() { + state.Register((*dirInodeOperations)(nil)) + state.Register((*dirFileOperations)(nil)) + state.Register((*filesystem)(nil)) + state.Register((*superOperations)(nil)) + state.Register((*lineDiscipline)(nil)) + state.Register((*outputQueueTransformer)(nil)) + state.Register((*inputQueueTransformer)(nil)) + state.Register((*masterInodeOperations)(nil)) + state.Register((*masterFileOperations)(nil)) + state.Register((*queue)(nil)) + state.Register((*replicaInodeOperations)(nil)) + state.Register((*replicaFileOperations)(nil)) + state.Register((*Terminal)(nil)) +} diff --git a/pkg/sentry/fs/tty/tty_test.go b/pkg/sentry/fs/tty/tty_test.go deleted file mode 100644 index 49edee83d..000000000 --- a/pkg/sentry/fs/tty/tty_test.go +++ /dev/null @@ -1,56 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tty - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/contexttest" - "gvisor.dev/gvisor/pkg/usermem" -) - -func TestSimpleMasterToReplica(t *testing.T) { - ld := newLineDiscipline(linux.DefaultReplicaTermios) - ctx := contexttest.Context(t) - inBytes := []byte("hello, tty\n") - src := usermem.BytesIOSequence(inBytes) - outBytes := make([]byte, 32) - dst := usermem.BytesIOSequence(outBytes) - - // Write to the input queue. - nw, err := ld.inputQueueWrite(ctx, src) - if err != nil { - t.Fatalf("error writing to input queue: %v", err) - } - if nw != int64(len(inBytes)) { - t.Fatalf("wrote wrong length: got %d, want %d", nw, len(inBytes)) - } - - // Read from the input queue. - nr, err := ld.inputQueueRead(ctx, dst) - if err != nil { - t.Fatalf("error reading from input queue: %v", err) - } - if nr != int64(len(inBytes)) { - t.Fatalf("read wrong length: got %d, want %d", nr, len(inBytes)) - } - - outStr := string(outBytes[:nr]) - inStr := string(inBytes) - if outStr != inStr { - t.Fatalf("written and read strings do not match: got %q, want %q", outStr, inStr) - } -} diff --git a/pkg/sentry/fs/user/BUILD b/pkg/sentry/fs/user/BUILD deleted file mode 100644 index 23b5508fd..000000000 --- a/pkg/sentry/fs/user/BUILD +++ /dev/null @@ -1,40 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "user", - srcs = [ - "path.go", - "user.go", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/errors/linuxerr", - "//pkg/fspath", - "//pkg/log", - "//pkg/sentry/fs", - "//pkg/sentry/kernel", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/vfs", - "//pkg/usermem", - ], -) - -go_test( - name = "user_test", - size = "small", - srcs = ["user_test.go"], - library = ":user", - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/sentry/fs", - "//pkg/sentry/fs/tmpfs", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/kernel/contexttest", - "//pkg/usermem", - ], -) diff --git a/pkg/sentry/fs/user/user_state_autogen.go b/pkg/sentry/fs/user/user_state_autogen.go new file mode 100644 index 000000000..8083e036c --- /dev/null +++ b/pkg/sentry/fs/user/user_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package user diff --git a/pkg/sentry/fs/user/user_test.go b/pkg/sentry/fs/user/user_test.go deleted file mode 100644 index 7f8fa8038..000000000 --- a/pkg/sentry/fs/user/user_test.go +++ /dev/null @@ -1,201 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package user - -import ( - "fmt" - "strings" - "testing" - - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/sentry/fs" - "gvisor.dev/gvisor/pkg/sentry/fs/tmpfs" - "gvisor.dev/gvisor/pkg/sentry/kernel/auth" - "gvisor.dev/gvisor/pkg/sentry/kernel/contexttest" - "gvisor.dev/gvisor/pkg/usermem" -) - -// createEtcPasswd creates /etc/passwd with the given contents and mode. If -// mode is empty, then no file will be created. If mode is not a regular file -// mode, then contents is ignored. -func createEtcPasswd(ctx context.Context, root *fs.Dirent, contents string, mode linux.FileMode) error { - if err := root.CreateDirectory(ctx, root, "etc", fs.FilePermsFromMode(0755)); err != nil { - return err - } - etc, err := root.Walk(ctx, root, "etc") - if err != nil { - return err - } - defer etc.DecRef(ctx) - switch mode.FileType() { - case 0: - // Don't create anything. - return nil - case linux.S_IFREG: - passwd, err := etc.Create(ctx, root, "passwd", fs.FileFlags{Write: true}, fs.FilePermsFromMode(mode)) - if err != nil { - return err - } - defer passwd.DecRef(ctx) - if _, err := passwd.Writev(ctx, usermem.BytesIOSequence([]byte(contents))); err != nil { - return err - } - return nil - case linux.S_IFDIR: - return etc.CreateDirectory(ctx, root, "passwd", fs.FilePermsFromMode(mode)) - case linux.S_IFIFO: - return etc.CreateFifo(ctx, root, "passwd", fs.FilePermsFromMode(mode)) - default: - return fmt.Errorf("unknown file type %x", mode.FileType()) - } -} - -// TestGetExecUserHome tests the getExecUserHome function. -func TestGetExecUserHome(t *testing.T) { - tests := map[string]struct { - uid auth.KUID - passwdContents string - passwdMode linux.FileMode - expected string - }{ - "success": { - uid: 1000, - passwdContents: "adin::1000:1111::/home/adin:/bin/sh", - passwdMode: linux.S_IFREG | 0666, - expected: "/home/adin", - }, - "no_perms": { - uid: 1000, - passwdContents: "adin::1000:1111::/home/adin:/bin/sh", - passwdMode: linux.S_IFREG, - expected: "/", - }, - "no_passwd": { - uid: 1000, - expected: "/", - }, - "directory": { - uid: 1000, - passwdMode: linux.S_IFDIR | 0666, - expected: "/", - }, - // Currently we don't allow named pipes. - "named_pipe": { - uid: 1000, - passwdMode: linux.S_IFIFO | 0666, - expected: "/", - }, - } - - for name, tc := range tests { - t.Run(name, func(t *testing.T) { - ctx := contexttest.Context(t) - msrc := fs.NewPseudoMountSource(ctx) - rootInode, err := tmpfs.NewDir(ctx, nil, fs.RootOwner, fs.FilePermsFromMode(0777), msrc, nil /* parent */) - if err != nil { - t.Fatalf("tmpfs.NewDir failed: %v", err) - } - - mns, err := fs.NewMountNamespace(ctx, rootInode) - if err != nil { - t.Fatalf("NewMountNamespace failed: %v", err) - } - defer mns.DecRef(ctx) - root := mns.Root() - defer root.DecRef(ctx) - ctx = fs.WithRoot(ctx, root) - - if err := createEtcPasswd(ctx, root, tc.passwdContents, tc.passwdMode); err != nil { - t.Fatalf("createEtcPasswd failed: %v", err) - } - - got, err := getExecUserHome(ctx, mns, tc.uid) - if err != nil { - t.Fatalf("failed to get user home: %v", err) - } - - if got != tc.expected { - t.Fatalf("expected %v, got: %v", tc.expected, got) - } - }) - } -} - -// TestFindHomeInPasswd tests the findHomeInPasswd function's passwd file parsing. -func TestFindHomeInPasswd(t *testing.T) { - tests := map[string]struct { - uid uint32 - passwd string - expected string - def string - }{ - "empty": { - uid: 1000, - passwd: "", - expected: "/", - def: "/", - }, - "whitespace": { - uid: 1000, - passwd: " ", - expected: "/", - def: "/", - }, - "full": { - uid: 1000, - passwd: "adin::1000:1111::/home/adin:/bin/sh", - expected: "/home/adin", - def: "/", - }, - // For better or worse, this is how runc works. - "partial": { - uid: 1000, - passwd: "adin::1000:1111:", - expected: "", - def: "/", - }, - "multiple": { - uid: 1001, - passwd: "adin::1000:1111::/home/adin:/bin/sh\nian::1001:1111::/home/ian:/bin/sh", - expected: "/home/ian", - def: "/", - }, - "duplicate": { - uid: 1000, - passwd: "adin::1000:1111::/home/adin:/bin/sh\nian::1000:1111::/home/ian:/bin/sh", - expected: "/home/adin", - def: "/", - }, - "empty_lines": { - uid: 1001, - passwd: "adin::1000:1111::/home/adin:/bin/sh\n\n\nian::1001:1111::/home/ian:/bin/sh", - expected: "/home/ian", - def: "/", - }, - } - - for name, tc := range tests { - t.Run(name, func(t *testing.T) { - got, err := findHomeInPasswd(tc.uid, strings.NewReader(tc.passwd), tc.def) - if err != nil { - t.Fatalf("error parsing passwd: %v", err) - } - if tc.expected != got { - t.Fatalf("expected %v, got: %v", tc.expected, got) - } - }) - } -} diff --git a/pkg/sentry/fsbridge/BUILD b/pkg/sentry/fsbridge/BUILD deleted file mode 100644 index 4631db2bb..000000000 --- a/pkg/sentry/fsbridge/BUILD +++ /dev/null @@ -1,24 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -licenses(["notice"]) - -go_library( - name = "fsbridge", - srcs = [ - "bridge.go", - "fs.go", - "vfs.go", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/errors/linuxerr", - "//pkg/fspath", - "//pkg/sentry/fs", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/memmap", - "//pkg/sentry/vfs", - "//pkg/usermem", - ], -) diff --git a/pkg/sentry/fsbridge/fsbridge_state_autogen.go b/pkg/sentry/fsbridge/fsbridge_state_autogen.go new file mode 100644 index 000000000..b4b240bfa --- /dev/null +++ b/pkg/sentry/fsbridge/fsbridge_state_autogen.go @@ -0,0 +1,126 @@ +// automatically generated by stateify. + +package fsbridge + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (f *fsFile) StateTypeName() string { + return "pkg/sentry/fsbridge.fsFile" +} + +func (f *fsFile) StateFields() []string { + return []string{ + "file", + } +} + +func (f *fsFile) beforeSave() {} + +// +checklocksignore +func (f *fsFile) StateSave(stateSinkObject state.Sink) { + f.beforeSave() + stateSinkObject.Save(0, &f.file) +} + +func (f *fsFile) afterLoad() {} + +// +checklocksignore +func (f *fsFile) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &f.file) +} + +func (l *fsLookup) StateTypeName() string { + return "pkg/sentry/fsbridge.fsLookup" +} + +func (l *fsLookup) StateFields() []string { + return []string{ + "mntns", + "root", + "workingDir", + } +} + +func (l *fsLookup) beforeSave() {} + +// +checklocksignore +func (l *fsLookup) StateSave(stateSinkObject state.Sink) { + l.beforeSave() + stateSinkObject.Save(0, &l.mntns) + stateSinkObject.Save(1, &l.root) + stateSinkObject.Save(2, &l.workingDir) +} + +func (l *fsLookup) afterLoad() {} + +// +checklocksignore +func (l *fsLookup) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &l.mntns) + stateSourceObject.Load(1, &l.root) + stateSourceObject.Load(2, &l.workingDir) +} + +func (f *VFSFile) StateTypeName() string { + return "pkg/sentry/fsbridge.VFSFile" +} + +func (f *VFSFile) StateFields() []string { + return []string{ + "file", + } +} + +func (f *VFSFile) beforeSave() {} + +// +checklocksignore +func (f *VFSFile) StateSave(stateSinkObject state.Sink) { + f.beforeSave() + stateSinkObject.Save(0, &f.file) +} + +func (f *VFSFile) afterLoad() {} + +// +checklocksignore +func (f *VFSFile) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &f.file) +} + +func (l *vfsLookup) StateTypeName() string { + return "pkg/sentry/fsbridge.vfsLookup" +} + +func (l *vfsLookup) StateFields() []string { + return []string{ + "mntns", + "root", + "workingDir", + } +} + +func (l *vfsLookup) beforeSave() {} + +// +checklocksignore +func (l *vfsLookup) StateSave(stateSinkObject state.Sink) { + l.beforeSave() + stateSinkObject.Save(0, &l.mntns) + stateSinkObject.Save(1, &l.root) + stateSinkObject.Save(2, &l.workingDir) +} + +func (l *vfsLookup) afterLoad() {} + +// +checklocksignore +func (l *vfsLookup) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &l.mntns) + stateSourceObject.Load(1, &l.root) + stateSourceObject.Load(2, &l.workingDir) +} + +func init() { + state.Register((*fsFile)(nil)) + state.Register((*fsLookup)(nil)) + state.Register((*VFSFile)(nil)) + state.Register((*vfsLookup)(nil)) +} diff --git a/pkg/sentry/fsimpl/cgroupfs/BUILD b/pkg/sentry/fsimpl/cgroupfs/BUILD deleted file mode 100644 index e5fdcc776..000000000 --- a/pkg/sentry/fsimpl/cgroupfs/BUILD +++ /dev/null @@ -1,49 +0,0 @@ -load("//tools:defs.bzl", "go_library") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -licenses(["notice"]) - -go_template_instance( - name = "dir_refs", - out = "dir_refs.go", - package = "cgroupfs", - prefix = "dir", - template = "//pkg/refsvfs2:refs_template", - types = { - "T": "dir", - }, -) - -go_library( - name = "cgroupfs", - srcs = [ - "base.go", - "cgroupfs.go", - "cpu.go", - "cpuacct.go", - "cpuset.go", - "dir_refs.go", - "job.go", - "memory.go", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/coverage", - "//pkg/errors/linuxerr", - "//pkg/fspath", - "//pkg/log", - "//pkg/refs", - "//pkg/refsvfs2", - "//pkg/sentry/arch", - "//pkg/sentry/fsimpl/kernfs", - "//pkg/sentry/kernel", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/memmap", - "//pkg/sentry/usage", - "//pkg/sentry/vfs", - "//pkg/sync", - "//pkg/usermem", - ], -) diff --git a/pkg/sentry/fsimpl/cgroupfs/cgroupfs_state_autogen.go b/pkg/sentry/fsimpl/cgroupfs/cgroupfs_state_autogen.go new file mode 100644 index 000000000..aa40bb193 --- /dev/null +++ b/pkg/sentry/fsimpl/cgroupfs/cgroupfs_state_autogen.go @@ -0,0 +1,696 @@ +// automatically generated by stateify. + +package cgroupfs + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (c *controllerCommon) StateTypeName() string { + return "pkg/sentry/fsimpl/cgroupfs.controllerCommon" +} + +func (c *controllerCommon) StateFields() []string { + return []string{ + "ty", + "fs", + } +} + +func (c *controllerCommon) beforeSave() {} + +// +checklocksignore +func (c *controllerCommon) StateSave(stateSinkObject state.Sink) { + c.beforeSave() + stateSinkObject.Save(0, &c.ty) + stateSinkObject.Save(1, &c.fs) +} + +func (c *controllerCommon) afterLoad() {} + +// +checklocksignore +func (c *controllerCommon) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &c.ty) + stateSourceObject.Load(1, &c.fs) +} + +func (c *cgroupInode) StateTypeName() string { + return "pkg/sentry/fsimpl/cgroupfs.cgroupInode" +} + +func (c *cgroupInode) StateFields() []string { + return []string{ + "dir", + "ts", + } +} + +func (c *cgroupInode) beforeSave() {} + +// +checklocksignore +func (c *cgroupInode) StateSave(stateSinkObject state.Sink) { + c.beforeSave() + stateSinkObject.Save(0, &c.dir) + stateSinkObject.Save(1, &c.ts) +} + +func (c *cgroupInode) afterLoad() {} + +// +checklocksignore +func (c *cgroupInode) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &c.dir) + stateSourceObject.Load(1, &c.ts) +} + +func (d *cgroupProcsData) StateTypeName() string { + return "pkg/sentry/fsimpl/cgroupfs.cgroupProcsData" +} + +func (d *cgroupProcsData) StateFields() []string { + return []string{ + "cgroupInode", + } +} + +func (d *cgroupProcsData) beforeSave() {} + +// +checklocksignore +func (d *cgroupProcsData) StateSave(stateSinkObject state.Sink) { + d.beforeSave() + stateSinkObject.Save(0, &d.cgroupInode) +} + +func (d *cgroupProcsData) afterLoad() {} + +// +checklocksignore +func (d *cgroupProcsData) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &d.cgroupInode) +} + +func (d *tasksData) StateTypeName() string { + return "pkg/sentry/fsimpl/cgroupfs.tasksData" +} + +func (d *tasksData) StateFields() []string { + return []string{ + "cgroupInode", + } +} + +func (d *tasksData) beforeSave() {} + +// +checklocksignore +func (d *tasksData) StateSave(stateSinkObject state.Sink) { + d.beforeSave() + stateSinkObject.Save(0, &d.cgroupInode) +} + +func (d *tasksData) afterLoad() {} + +// +checklocksignore +func (d *tasksData) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &d.cgroupInode) +} + +func (fsType *FilesystemType) StateTypeName() string { + return "pkg/sentry/fsimpl/cgroupfs.FilesystemType" +} + +func (fsType *FilesystemType) StateFields() []string { + return []string{} +} + +func (fsType *FilesystemType) beforeSave() {} + +// +checklocksignore +func (fsType *FilesystemType) StateSave(stateSinkObject state.Sink) { + fsType.beforeSave() +} + +func (fsType *FilesystemType) afterLoad() {} + +// +checklocksignore +func (fsType *FilesystemType) StateLoad(stateSourceObject state.Source) { +} + +func (i *InternalData) StateTypeName() string { + return "pkg/sentry/fsimpl/cgroupfs.InternalData" +} + +func (i *InternalData) StateFields() []string { + return []string{ + "DefaultControlValues", + "InitialCgroupPath", + } +} + +func (i *InternalData) beforeSave() {} + +// +checklocksignore +func (i *InternalData) StateSave(stateSinkObject state.Sink) { + i.beforeSave() + stateSinkObject.Save(0, &i.DefaultControlValues) + stateSinkObject.Save(1, &i.InitialCgroupPath) +} + +func (i *InternalData) afterLoad() {} + +// +checklocksignore +func (i *InternalData) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &i.DefaultControlValues) + stateSourceObject.Load(1, &i.InitialCgroupPath) +} + +func (fs *filesystem) StateTypeName() string { + return "pkg/sentry/fsimpl/cgroupfs.filesystem" +} + +func (fs *filesystem) StateFields() []string { + return []string{ + "Filesystem", + "devMinor", + "hierarchyID", + "controllers", + "kcontrollers", + "numCgroups", + "root", + "effectiveRoot", + } +} + +func (fs *filesystem) beforeSave() {} + +// +checklocksignore +func (fs *filesystem) StateSave(stateSinkObject state.Sink) { + fs.beforeSave() + stateSinkObject.Save(0, &fs.Filesystem) + stateSinkObject.Save(1, &fs.devMinor) + stateSinkObject.Save(2, &fs.hierarchyID) + stateSinkObject.Save(3, &fs.controllers) + stateSinkObject.Save(4, &fs.kcontrollers) + stateSinkObject.Save(5, &fs.numCgroups) + stateSinkObject.Save(6, &fs.root) + stateSinkObject.Save(7, &fs.effectiveRoot) +} + +func (fs *filesystem) afterLoad() {} + +// +checklocksignore +func (fs *filesystem) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &fs.Filesystem) + stateSourceObject.Load(1, &fs.devMinor) + stateSourceObject.Load(2, &fs.hierarchyID) + stateSourceObject.Load(3, &fs.controllers) + stateSourceObject.Load(4, &fs.kcontrollers) + stateSourceObject.Load(5, &fs.numCgroups) + stateSourceObject.Load(6, &fs.root) + stateSourceObject.Load(7, &fs.effectiveRoot) +} + +func (i *implStatFS) StateTypeName() string { + return "pkg/sentry/fsimpl/cgroupfs.implStatFS" +} + +func (i *implStatFS) StateFields() []string { + return []string{} +} + +func (i *implStatFS) beforeSave() {} + +// +checklocksignore +func (i *implStatFS) StateSave(stateSinkObject state.Sink) { + i.beforeSave() +} + +func (i *implStatFS) afterLoad() {} + +// +checklocksignore +func (i *implStatFS) StateLoad(stateSourceObject state.Source) { +} + +func (d *dir) StateTypeName() string { + return "pkg/sentry/fsimpl/cgroupfs.dir" +} + +func (d *dir) StateFields() []string { + return []string{ + "InodeNoopRefCount", + "InodeAlwaysValid", + "InodeAttrs", + "InodeNotSymlink", + "InodeDirectoryNoNewChildren", + "OrderedChildren", + "implStatFS", + "locks", + "fs", + "cgi", + } +} + +func (d *dir) beforeSave() {} + +// +checklocksignore +func (d *dir) StateSave(stateSinkObject state.Sink) { + d.beforeSave() + stateSinkObject.Save(0, &d.InodeNoopRefCount) + stateSinkObject.Save(1, &d.InodeAlwaysValid) + stateSinkObject.Save(2, &d.InodeAttrs) + stateSinkObject.Save(3, &d.InodeNotSymlink) + stateSinkObject.Save(4, &d.InodeDirectoryNoNewChildren) + stateSinkObject.Save(5, &d.OrderedChildren) + stateSinkObject.Save(6, &d.implStatFS) + stateSinkObject.Save(7, &d.locks) + stateSinkObject.Save(8, &d.fs) + stateSinkObject.Save(9, &d.cgi) +} + +func (d *dir) afterLoad() {} + +// +checklocksignore +func (d *dir) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &d.InodeNoopRefCount) + stateSourceObject.Load(1, &d.InodeAlwaysValid) + stateSourceObject.Load(2, &d.InodeAttrs) + stateSourceObject.Load(3, &d.InodeNotSymlink) + stateSourceObject.Load(4, &d.InodeDirectoryNoNewChildren) + stateSourceObject.Load(5, &d.OrderedChildren) + stateSourceObject.Load(6, &d.implStatFS) + stateSourceObject.Load(7, &d.locks) + stateSourceObject.Load(8, &d.fs) + stateSourceObject.Load(9, &d.cgi) +} + +func (c *controllerFile) StateTypeName() string { + return "pkg/sentry/fsimpl/cgroupfs.controllerFile" +} + +func (c *controllerFile) StateFields() []string { + return []string{ + "DynamicBytesFile", + } +} + +func (c *controllerFile) beforeSave() {} + +// +checklocksignore +func (c *controllerFile) StateSave(stateSinkObject state.Sink) { + c.beforeSave() + stateSinkObject.Save(0, &c.DynamicBytesFile) +} + +func (c *controllerFile) afterLoad() {} + +// +checklocksignore +func (c *controllerFile) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &c.DynamicBytesFile) +} + +func (s *staticControllerFile) StateTypeName() string { + return "pkg/sentry/fsimpl/cgroupfs.staticControllerFile" +} + +func (s *staticControllerFile) StateFields() []string { + return []string{ + "DynamicBytesFile", + "StaticData", + } +} + +func (s *staticControllerFile) beforeSave() {} + +// +checklocksignore +func (s *staticControllerFile) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + stateSinkObject.Save(0, &s.DynamicBytesFile) + stateSinkObject.Save(1, &s.StaticData) +} + +func (s *staticControllerFile) afterLoad() {} + +// +checklocksignore +func (s *staticControllerFile) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.DynamicBytesFile) + stateSourceObject.Load(1, &s.StaticData) +} + +func (c *cpuController) StateTypeName() string { + return "pkg/sentry/fsimpl/cgroupfs.cpuController" +} + +func (c *cpuController) StateFields() []string { + return []string{ + "controllerCommon", + "cfsPeriod", + "cfsQuota", + "shares", + } +} + +func (c *cpuController) beforeSave() {} + +// +checklocksignore +func (c *cpuController) StateSave(stateSinkObject state.Sink) { + c.beforeSave() + stateSinkObject.Save(0, &c.controllerCommon) + stateSinkObject.Save(1, &c.cfsPeriod) + stateSinkObject.Save(2, &c.cfsQuota) + stateSinkObject.Save(3, &c.shares) +} + +func (c *cpuController) afterLoad() {} + +// +checklocksignore +func (c *cpuController) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &c.controllerCommon) + stateSourceObject.Load(1, &c.cfsPeriod) + stateSourceObject.Load(2, &c.cfsQuota) + stateSourceObject.Load(3, &c.shares) +} + +func (c *cpuacctController) StateTypeName() string { + return "pkg/sentry/fsimpl/cgroupfs.cpuacctController" +} + +func (c *cpuacctController) StateFields() []string { + return []string{ + "controllerCommon", + } +} + +func (c *cpuacctController) beforeSave() {} + +// +checklocksignore +func (c *cpuacctController) StateSave(stateSinkObject state.Sink) { + c.beforeSave() + stateSinkObject.Save(0, &c.controllerCommon) +} + +func (c *cpuacctController) afterLoad() {} + +// +checklocksignore +func (c *cpuacctController) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &c.controllerCommon) +} + +func (c *cpuacctCgroup) StateTypeName() string { + return "pkg/sentry/fsimpl/cgroupfs.cpuacctCgroup" +} + +func (c *cpuacctCgroup) StateFields() []string { + return []string{ + "cgroupInode", + } +} + +func (c *cpuacctCgroup) beforeSave() {} + +// +checklocksignore +func (c *cpuacctCgroup) StateSave(stateSinkObject state.Sink) { + c.beforeSave() + stateSinkObject.Save(0, &c.cgroupInode) +} + +func (c *cpuacctCgroup) afterLoad() {} + +// +checklocksignore +func (c *cpuacctCgroup) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &c.cgroupInode) +} + +func (d *cpuacctStatData) StateTypeName() string { + return "pkg/sentry/fsimpl/cgroupfs.cpuacctStatData" +} + +func (d *cpuacctStatData) StateFields() []string { + return []string{ + "cpuacctCgroup", + } +} + +func (d *cpuacctStatData) beforeSave() {} + +// +checklocksignore +func (d *cpuacctStatData) StateSave(stateSinkObject state.Sink) { + d.beforeSave() + stateSinkObject.Save(0, &d.cpuacctCgroup) +} + +func (d *cpuacctStatData) afterLoad() {} + +// +checklocksignore +func (d *cpuacctStatData) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &d.cpuacctCgroup) +} + +func (d *cpuacctUsageData) StateTypeName() string { + return "pkg/sentry/fsimpl/cgroupfs.cpuacctUsageData" +} + +func (d *cpuacctUsageData) StateFields() []string { + return []string{ + "cpuacctCgroup", + } +} + +func (d *cpuacctUsageData) beforeSave() {} + +// +checklocksignore +func (d *cpuacctUsageData) StateSave(stateSinkObject state.Sink) { + d.beforeSave() + stateSinkObject.Save(0, &d.cpuacctCgroup) +} + +func (d *cpuacctUsageData) afterLoad() {} + +// +checklocksignore +func (d *cpuacctUsageData) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &d.cpuacctCgroup) +} + +func (d *cpuacctUsageUserData) StateTypeName() string { + return "pkg/sentry/fsimpl/cgroupfs.cpuacctUsageUserData" +} + +func (d *cpuacctUsageUserData) StateFields() []string { + return []string{ + "cpuacctCgroup", + } +} + +func (d *cpuacctUsageUserData) beforeSave() {} + +// +checklocksignore +func (d *cpuacctUsageUserData) StateSave(stateSinkObject state.Sink) { + d.beforeSave() + stateSinkObject.Save(0, &d.cpuacctCgroup) +} + +func (d *cpuacctUsageUserData) afterLoad() {} + +// +checklocksignore +func (d *cpuacctUsageUserData) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &d.cpuacctCgroup) +} + +func (d *cpuacctUsageSysData) StateTypeName() string { + return "pkg/sentry/fsimpl/cgroupfs.cpuacctUsageSysData" +} + +func (d *cpuacctUsageSysData) StateFields() []string { + return []string{ + "cpuacctCgroup", + } +} + +func (d *cpuacctUsageSysData) beforeSave() {} + +// +checklocksignore +func (d *cpuacctUsageSysData) StateSave(stateSinkObject state.Sink) { + d.beforeSave() + stateSinkObject.Save(0, &d.cpuacctCgroup) +} + +func (d *cpuacctUsageSysData) afterLoad() {} + +// +checklocksignore +func (d *cpuacctUsageSysData) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &d.cpuacctCgroup) +} + +func (c *cpusetController) StateTypeName() string { + return "pkg/sentry/fsimpl/cgroupfs.cpusetController" +} + +func (c *cpusetController) StateFields() []string { + return []string{ + "controllerCommon", + } +} + +func (c *cpusetController) beforeSave() {} + +// +checklocksignore +func (c *cpusetController) StateSave(stateSinkObject state.Sink) { + c.beforeSave() + stateSinkObject.Save(0, &c.controllerCommon) +} + +func (c *cpusetController) afterLoad() {} + +// +checklocksignore +func (c *cpusetController) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &c.controllerCommon) +} + +func (r *dirRefs) StateTypeName() string { + return "pkg/sentry/fsimpl/cgroupfs.dirRefs" +} + +func (r *dirRefs) StateFields() []string { + return []string{ + "refCount", + } +} + +func (r *dirRefs) beforeSave() {} + +// +checklocksignore +func (r *dirRefs) StateSave(stateSinkObject state.Sink) { + r.beforeSave() + stateSinkObject.Save(0, &r.refCount) +} + +// +checklocksignore +func (r *dirRefs) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &r.refCount) + stateSourceObject.AfterLoad(r.afterLoad) +} + +func (c *jobController) StateTypeName() string { + return "pkg/sentry/fsimpl/cgroupfs.jobController" +} + +func (c *jobController) StateFields() []string { + return []string{ + "controllerCommon", + "id", + } +} + +func (c *jobController) beforeSave() {} + +// +checklocksignore +func (c *jobController) StateSave(stateSinkObject state.Sink) { + c.beforeSave() + stateSinkObject.Save(0, &c.controllerCommon) + stateSinkObject.Save(1, &c.id) +} + +func (c *jobController) afterLoad() {} + +// +checklocksignore +func (c *jobController) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &c.controllerCommon) + stateSourceObject.Load(1, &c.id) +} + +func (d *jobIDData) StateTypeName() string { + return "pkg/sentry/fsimpl/cgroupfs.jobIDData" +} + +func (d *jobIDData) StateFields() []string { + return []string{ + "c", + } +} + +func (d *jobIDData) beforeSave() {} + +// +checklocksignore +func (d *jobIDData) StateSave(stateSinkObject state.Sink) { + d.beforeSave() + stateSinkObject.Save(0, &d.c) +} + +func (d *jobIDData) afterLoad() {} + +// +checklocksignore +func (d *jobIDData) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &d.c) +} + +func (c *memoryController) StateTypeName() string { + return "pkg/sentry/fsimpl/cgroupfs.memoryController" +} + +func (c *memoryController) StateFields() []string { + return []string{ + "controllerCommon", + "limitBytes", + } +} + +func (c *memoryController) beforeSave() {} + +// +checklocksignore +func (c *memoryController) StateSave(stateSinkObject state.Sink) { + c.beforeSave() + stateSinkObject.Save(0, &c.controllerCommon) + stateSinkObject.Save(1, &c.limitBytes) +} + +func (c *memoryController) afterLoad() {} + +// +checklocksignore +func (c *memoryController) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &c.controllerCommon) + stateSourceObject.Load(1, &c.limitBytes) +} + +func (d *memoryUsageInBytesData) StateTypeName() string { + return "pkg/sentry/fsimpl/cgroupfs.memoryUsageInBytesData" +} + +func (d *memoryUsageInBytesData) StateFields() []string { + return []string{} +} + +func (d *memoryUsageInBytesData) beforeSave() {} + +// +checklocksignore +func (d *memoryUsageInBytesData) StateSave(stateSinkObject state.Sink) { + d.beforeSave() +} + +func (d *memoryUsageInBytesData) afterLoad() {} + +// +checklocksignore +func (d *memoryUsageInBytesData) StateLoad(stateSourceObject state.Source) { +} + +func init() { + state.Register((*controllerCommon)(nil)) + state.Register((*cgroupInode)(nil)) + state.Register((*cgroupProcsData)(nil)) + state.Register((*tasksData)(nil)) + state.Register((*FilesystemType)(nil)) + state.Register((*InternalData)(nil)) + state.Register((*filesystem)(nil)) + state.Register((*implStatFS)(nil)) + state.Register((*dir)(nil)) + state.Register((*controllerFile)(nil)) + state.Register((*staticControllerFile)(nil)) + state.Register((*cpuController)(nil)) + state.Register((*cpuacctController)(nil)) + state.Register((*cpuacctCgroup)(nil)) + state.Register((*cpuacctStatData)(nil)) + state.Register((*cpuacctUsageData)(nil)) + state.Register((*cpuacctUsageUserData)(nil)) + state.Register((*cpuacctUsageSysData)(nil)) + state.Register((*cpusetController)(nil)) + state.Register((*dirRefs)(nil)) + state.Register((*jobController)(nil)) + state.Register((*jobIDData)(nil)) + state.Register((*memoryController)(nil)) + state.Register((*memoryUsageInBytesData)(nil)) +} diff --git a/pkg/refsvfs2/refs_template.go b/pkg/sentry/fsimpl/cgroupfs/dir_refs.go index 55b0a60a1..c29f0c9ae 100644 --- a/pkg/refsvfs2/refs_template.go +++ b/pkg/sentry/fsimpl/cgroupfs/dir_refs.go @@ -1,20 +1,4 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package refs_template defines a template that can be used by reference -// counted objects. The template comes with leak checking capabilities. -package refs_template +package cgroupfs import ( "fmt" @@ -27,15 +11,11 @@ import ( // stack traces). This is false by default and should only be set to true for // debugging purposes, as it can generate an extremely large amount of output // and drastically degrade performance. -const enableLogging = false - -// T is the type of the reference counted object. It is only used to customize -// debug output when leak checking. -type T interface{} +const direnableLogging = false // obj is used to customize logging. Note that we use a pointer to T so that // we do not copy the entire object when passed as a format parameter. -var obj *T +var dirobj *dir // Refs implements refs.RefCounter. It keeps a reference count using atomic // operations and calls the destructor when the count reaches zero. @@ -49,7 +29,7 @@ var obj *T // interfaces manually. // // +stateify savable -type Refs struct { +type dirRefs struct { // refCount is composed of two fields: // // [32-bit speculative references]:[32-bit real references] @@ -62,38 +42,38 @@ type Refs struct { // InitRefs initializes r with one reference and, if enabled, activates leak // checking. -func (r *Refs) InitRefs() { +func (r *dirRefs) InitRefs() { atomic.StoreInt64(&r.refCount, 1) refsvfs2.Register(r) } // RefType implements refsvfs2.CheckedObject.RefType. -func (r *Refs) RefType() string { - return fmt.Sprintf("%T", obj)[1:] +func (r *dirRefs) RefType() string { + return fmt.Sprintf("%T", dirobj)[1:] } // LeakMessage implements refsvfs2.CheckedObject.LeakMessage. -func (r *Refs) LeakMessage() string { +func (r *dirRefs) LeakMessage() string { return fmt.Sprintf("[%s %p] reference count of %d instead of 0", r.RefType(), r, r.ReadRefs()) } // LogRefs implements refsvfs2.CheckedObject.LogRefs. -func (r *Refs) LogRefs() bool { - return enableLogging +func (r *dirRefs) LogRefs() bool { + return direnableLogging } // ReadRefs returns the current number of references. The returned count is // inherently racy and is unsafe to use without external synchronization. -func (r *Refs) ReadRefs() int64 { +func (r *dirRefs) ReadRefs() int64 { return atomic.LoadInt64(&r.refCount) } // IncRef implements refs.RefCounter.IncRef. // //go:nosplit -func (r *Refs) IncRef() { +func (r *dirRefs) IncRef() { v := atomic.AddInt64(&r.refCount, 1) - if enableLogging { + if direnableLogging { refsvfs2.LogIncRef(r, v) } if v <= 1 { @@ -108,17 +88,16 @@ func (r *Refs) IncRef() { // other TryIncRef calls from genuine references held. // //go:nosplit -func (r *Refs) TryIncRef() bool { +func (r *dirRefs) TryIncRef() bool { const speculativeRef = 1 << 32 if v := atomic.AddInt64(&r.refCount, speculativeRef); int32(v) == 0 { - // This object has already been freed. + atomic.AddInt64(&r.refCount, -speculativeRef) return false } - // Turn into a real reference. v := atomic.AddInt64(&r.refCount, -speculativeRef+1) - if enableLogging { + if direnableLogging { refsvfs2.LogTryIncRef(r, v) } return true @@ -136,9 +115,9 @@ func (r *Refs) TryIncRef() bool { // A: TryIncRef [transform speculative to real] // //go:nosplit -func (r *Refs) DecRef(destroy func()) { +func (r *dirRefs) DecRef(destroy func()) { v := atomic.AddInt64(&r.refCount, -1) - if enableLogging { + if direnableLogging { refsvfs2.LogDecRef(r, v) } switch { @@ -147,14 +126,14 @@ func (r *Refs) DecRef(destroy func()) { case v == 0: refsvfs2.Unregister(r) - // Call the destructor. + if destroy != nil { destroy() } } } -func (r *Refs) afterLoad() { +func (r *dirRefs) afterLoad() { if r.ReadRefs() > 0 { refsvfs2.Register(r) } diff --git a/pkg/sentry/fsimpl/devpts/BUILD b/pkg/sentry/fsimpl/devpts/BUILD deleted file mode 100644 index e0b879339..000000000 --- a/pkg/sentry/fsimpl/devpts/BUILD +++ /dev/null @@ -1,64 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -licenses(["notice"]) - -go_template_instance( - name = "root_inode_refs", - out = "root_inode_refs.go", - package = "devpts", - prefix = "rootInode", - template = "//pkg/refsvfs2:refs_template", - types = { - "T": "rootInode", - }, -) - -go_library( - name = "devpts", - srcs = [ - "devpts.go", - "line_discipline.go", - "master.go", - "queue.go", - "replica.go", - "root_inode_refs.go", - "terminal.go", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/errors/linuxerr", - "//pkg/log", - "//pkg/marshal", - "//pkg/marshal/primitive", - "//pkg/refs", - "//pkg/refsvfs2", - "//pkg/safemem", - "//pkg/sentry/arch", - "//pkg/sentry/fs", - "//pkg/sentry/fs/lock", - "//pkg/sentry/fsimpl/kernfs", - "//pkg/sentry/kernel", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/unimpl", - "//pkg/sentry/vfs", - "//pkg/sync", - "//pkg/usermem", - "//pkg/waiter", - ], -) - -go_test( - name = "devpts_test", - size = "small", - srcs = ["devpts_test.go"], - library = ":devpts", - deps = [ - "//pkg/abi/linux", - "//pkg/sentry/contexttest", - "//pkg/usermem", - "//pkg/waiter", - ], -) diff --git a/pkg/sentry/fsimpl/devpts/devpts_state_autogen.go b/pkg/sentry/fsimpl/devpts/devpts_state_autogen.go new file mode 100644 index 000000000..78e685c7e --- /dev/null +++ b/pkg/sentry/fsimpl/devpts/devpts_state_autogen.go @@ -0,0 +1,502 @@ +// automatically generated by stateify. + +package devpts + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (fstype *FilesystemType) StateTypeName() string { + return "pkg/sentry/fsimpl/devpts.FilesystemType" +} + +func (fstype *FilesystemType) StateFields() []string { + return []string{ + "initErr", + "fs", + "root", + } +} + +func (fstype *FilesystemType) beforeSave() {} + +// +checklocksignore +func (fstype *FilesystemType) StateSave(stateSinkObject state.Sink) { + fstype.beforeSave() + stateSinkObject.Save(0, &fstype.initErr) + stateSinkObject.Save(1, &fstype.fs) + stateSinkObject.Save(2, &fstype.root) +} + +func (fstype *FilesystemType) afterLoad() {} + +// +checklocksignore +func (fstype *FilesystemType) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &fstype.initErr) + stateSourceObject.Load(1, &fstype.fs) + stateSourceObject.Load(2, &fstype.root) +} + +func (fs *filesystem) StateTypeName() string { + return "pkg/sentry/fsimpl/devpts.filesystem" +} + +func (fs *filesystem) StateFields() []string { + return []string{ + "Filesystem", + "devMinor", + } +} + +func (fs *filesystem) beforeSave() {} + +// +checklocksignore +func (fs *filesystem) StateSave(stateSinkObject state.Sink) { + fs.beforeSave() + stateSinkObject.Save(0, &fs.Filesystem) + stateSinkObject.Save(1, &fs.devMinor) +} + +func (fs *filesystem) afterLoad() {} + +// +checklocksignore +func (fs *filesystem) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &fs.Filesystem) + stateSourceObject.Load(1, &fs.devMinor) +} + +func (i *rootInode) StateTypeName() string { + return "pkg/sentry/fsimpl/devpts.rootInode" +} + +func (i *rootInode) StateFields() []string { + return []string{ + "implStatFS", + "InodeAlwaysValid", + "InodeAttrs", + "InodeDirectoryNoNewChildren", + "InodeNotSymlink", + "InodeTemporary", + "OrderedChildren", + "rootInodeRefs", + "locks", + "master", + "replicas", + "nextIdx", + } +} + +func (i *rootInode) beforeSave() {} + +// +checklocksignore +func (i *rootInode) StateSave(stateSinkObject state.Sink) { + i.beforeSave() + stateSinkObject.Save(0, &i.implStatFS) + stateSinkObject.Save(1, &i.InodeAlwaysValid) + stateSinkObject.Save(2, &i.InodeAttrs) + stateSinkObject.Save(3, &i.InodeDirectoryNoNewChildren) + stateSinkObject.Save(4, &i.InodeNotSymlink) + stateSinkObject.Save(5, &i.InodeTemporary) + stateSinkObject.Save(6, &i.OrderedChildren) + stateSinkObject.Save(7, &i.rootInodeRefs) + stateSinkObject.Save(8, &i.locks) + stateSinkObject.Save(9, &i.master) + stateSinkObject.Save(10, &i.replicas) + stateSinkObject.Save(11, &i.nextIdx) +} + +func (i *rootInode) afterLoad() {} + +// +checklocksignore +func (i *rootInode) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &i.implStatFS) + stateSourceObject.Load(1, &i.InodeAlwaysValid) + stateSourceObject.Load(2, &i.InodeAttrs) + stateSourceObject.Load(3, &i.InodeDirectoryNoNewChildren) + stateSourceObject.Load(4, &i.InodeNotSymlink) + stateSourceObject.Load(5, &i.InodeTemporary) + stateSourceObject.Load(6, &i.OrderedChildren) + stateSourceObject.Load(7, &i.rootInodeRefs) + stateSourceObject.Load(8, &i.locks) + stateSourceObject.Load(9, &i.master) + stateSourceObject.Load(10, &i.replicas) + stateSourceObject.Load(11, &i.nextIdx) +} + +func (i *implStatFS) StateTypeName() string { + return "pkg/sentry/fsimpl/devpts.implStatFS" +} + +func (i *implStatFS) StateFields() []string { + return []string{} +} + +func (i *implStatFS) beforeSave() {} + +// +checklocksignore +func (i *implStatFS) StateSave(stateSinkObject state.Sink) { + i.beforeSave() +} + +func (i *implStatFS) afterLoad() {} + +// +checklocksignore +func (i *implStatFS) StateLoad(stateSourceObject state.Source) { +} + +func (l *lineDiscipline) StateTypeName() string { + return "pkg/sentry/fsimpl/devpts.lineDiscipline" +} + +func (l *lineDiscipline) StateFields() []string { + return []string{ + "size", + "inQueue", + "outQueue", + "termios", + "column", + "masterWaiter", + "replicaWaiter", + } +} + +func (l *lineDiscipline) beforeSave() {} + +// +checklocksignore +func (l *lineDiscipline) StateSave(stateSinkObject state.Sink) { + l.beforeSave() + stateSinkObject.Save(0, &l.size) + stateSinkObject.Save(1, &l.inQueue) + stateSinkObject.Save(2, &l.outQueue) + stateSinkObject.Save(3, &l.termios) + stateSinkObject.Save(4, &l.column) + stateSinkObject.Save(5, &l.masterWaiter) + stateSinkObject.Save(6, &l.replicaWaiter) +} + +func (l *lineDiscipline) afterLoad() {} + +// +checklocksignore +func (l *lineDiscipline) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &l.size) + stateSourceObject.Load(1, &l.inQueue) + stateSourceObject.Load(2, &l.outQueue) + stateSourceObject.Load(3, &l.termios) + stateSourceObject.Load(4, &l.column) + stateSourceObject.Load(5, &l.masterWaiter) + stateSourceObject.Load(6, &l.replicaWaiter) +} + +func (o *outputQueueTransformer) StateTypeName() string { + return "pkg/sentry/fsimpl/devpts.outputQueueTransformer" +} + +func (o *outputQueueTransformer) StateFields() []string { + return []string{} +} + +func (o *outputQueueTransformer) beforeSave() {} + +// +checklocksignore +func (o *outputQueueTransformer) StateSave(stateSinkObject state.Sink) { + o.beforeSave() +} + +func (o *outputQueueTransformer) afterLoad() {} + +// +checklocksignore +func (o *outputQueueTransformer) StateLoad(stateSourceObject state.Source) { +} + +func (i *inputQueueTransformer) StateTypeName() string { + return "pkg/sentry/fsimpl/devpts.inputQueueTransformer" +} + +func (i *inputQueueTransformer) StateFields() []string { + return []string{} +} + +func (i *inputQueueTransformer) beforeSave() {} + +// +checklocksignore +func (i *inputQueueTransformer) StateSave(stateSinkObject state.Sink) { + i.beforeSave() +} + +func (i *inputQueueTransformer) afterLoad() {} + +// +checklocksignore +func (i *inputQueueTransformer) StateLoad(stateSourceObject state.Source) { +} + +func (mi *masterInode) StateTypeName() string { + return "pkg/sentry/fsimpl/devpts.masterInode" +} + +func (mi *masterInode) StateFields() []string { + return []string{ + "implStatFS", + "InodeAttrs", + "InodeNoopRefCount", + "InodeNotDirectory", + "InodeNotSymlink", + "locks", + "root", + } +} + +func (mi *masterInode) beforeSave() {} + +// +checklocksignore +func (mi *masterInode) StateSave(stateSinkObject state.Sink) { + mi.beforeSave() + stateSinkObject.Save(0, &mi.implStatFS) + stateSinkObject.Save(1, &mi.InodeAttrs) + stateSinkObject.Save(2, &mi.InodeNoopRefCount) + stateSinkObject.Save(3, &mi.InodeNotDirectory) + stateSinkObject.Save(4, &mi.InodeNotSymlink) + stateSinkObject.Save(5, &mi.locks) + stateSinkObject.Save(6, &mi.root) +} + +func (mi *masterInode) afterLoad() {} + +// +checklocksignore +func (mi *masterInode) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &mi.implStatFS) + stateSourceObject.Load(1, &mi.InodeAttrs) + stateSourceObject.Load(2, &mi.InodeNoopRefCount) + stateSourceObject.Load(3, &mi.InodeNotDirectory) + stateSourceObject.Load(4, &mi.InodeNotSymlink) + stateSourceObject.Load(5, &mi.locks) + stateSourceObject.Load(6, &mi.root) +} + +func (mfd *masterFileDescription) StateTypeName() string { + return "pkg/sentry/fsimpl/devpts.masterFileDescription" +} + +func (mfd *masterFileDescription) StateFields() []string { + return []string{ + "vfsfd", + "FileDescriptionDefaultImpl", + "LockFD", + "inode", + "t", + } +} + +func (mfd *masterFileDescription) beforeSave() {} + +// +checklocksignore +func (mfd *masterFileDescription) StateSave(stateSinkObject state.Sink) { + mfd.beforeSave() + stateSinkObject.Save(0, &mfd.vfsfd) + stateSinkObject.Save(1, &mfd.FileDescriptionDefaultImpl) + stateSinkObject.Save(2, &mfd.LockFD) + stateSinkObject.Save(3, &mfd.inode) + stateSinkObject.Save(4, &mfd.t) +} + +func (mfd *masterFileDescription) afterLoad() {} + +// +checklocksignore +func (mfd *masterFileDescription) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &mfd.vfsfd) + stateSourceObject.Load(1, &mfd.FileDescriptionDefaultImpl) + stateSourceObject.Load(2, &mfd.LockFD) + stateSourceObject.Load(3, &mfd.inode) + stateSourceObject.Load(4, &mfd.t) +} + +func (q *queue) StateTypeName() string { + return "pkg/sentry/fsimpl/devpts.queue" +} + +func (q *queue) StateFields() []string { + return []string{ + "readBuf", + "waitBuf", + "waitBufLen", + "readable", + "transformer", + } +} + +func (q *queue) beforeSave() {} + +// +checklocksignore +func (q *queue) StateSave(stateSinkObject state.Sink) { + q.beforeSave() + stateSinkObject.Save(0, &q.readBuf) + stateSinkObject.Save(1, &q.waitBuf) + stateSinkObject.Save(2, &q.waitBufLen) + stateSinkObject.Save(3, &q.readable) + stateSinkObject.Save(4, &q.transformer) +} + +func (q *queue) afterLoad() {} + +// +checklocksignore +func (q *queue) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &q.readBuf) + stateSourceObject.Load(1, &q.waitBuf) + stateSourceObject.Load(2, &q.waitBufLen) + stateSourceObject.Load(3, &q.readable) + stateSourceObject.Load(4, &q.transformer) +} + +func (ri *replicaInode) StateTypeName() string { + return "pkg/sentry/fsimpl/devpts.replicaInode" +} + +func (ri *replicaInode) StateFields() []string { + return []string{ + "implStatFS", + "InodeAttrs", + "InodeNoopRefCount", + "InodeNotDirectory", + "InodeNotSymlink", + "locks", + "root", + "t", + } +} + +func (ri *replicaInode) beforeSave() {} + +// +checklocksignore +func (ri *replicaInode) StateSave(stateSinkObject state.Sink) { + ri.beforeSave() + stateSinkObject.Save(0, &ri.implStatFS) + stateSinkObject.Save(1, &ri.InodeAttrs) + stateSinkObject.Save(2, &ri.InodeNoopRefCount) + stateSinkObject.Save(3, &ri.InodeNotDirectory) + stateSinkObject.Save(4, &ri.InodeNotSymlink) + stateSinkObject.Save(5, &ri.locks) + stateSinkObject.Save(6, &ri.root) + stateSinkObject.Save(7, &ri.t) +} + +func (ri *replicaInode) afterLoad() {} + +// +checklocksignore +func (ri *replicaInode) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &ri.implStatFS) + stateSourceObject.Load(1, &ri.InodeAttrs) + stateSourceObject.Load(2, &ri.InodeNoopRefCount) + stateSourceObject.Load(3, &ri.InodeNotDirectory) + stateSourceObject.Load(4, &ri.InodeNotSymlink) + stateSourceObject.Load(5, &ri.locks) + stateSourceObject.Load(6, &ri.root) + stateSourceObject.Load(7, &ri.t) +} + +func (rfd *replicaFileDescription) StateTypeName() string { + return "pkg/sentry/fsimpl/devpts.replicaFileDescription" +} + +func (rfd *replicaFileDescription) StateFields() []string { + return []string{ + "vfsfd", + "FileDescriptionDefaultImpl", + "LockFD", + "inode", + } +} + +func (rfd *replicaFileDescription) beforeSave() {} + +// +checklocksignore +func (rfd *replicaFileDescription) StateSave(stateSinkObject state.Sink) { + rfd.beforeSave() + stateSinkObject.Save(0, &rfd.vfsfd) + stateSinkObject.Save(1, &rfd.FileDescriptionDefaultImpl) + stateSinkObject.Save(2, &rfd.LockFD) + stateSinkObject.Save(3, &rfd.inode) +} + +func (rfd *replicaFileDescription) afterLoad() {} + +// +checklocksignore +func (rfd *replicaFileDescription) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &rfd.vfsfd) + stateSourceObject.Load(1, &rfd.FileDescriptionDefaultImpl) + stateSourceObject.Load(2, &rfd.LockFD) + stateSourceObject.Load(3, &rfd.inode) +} + +func (r *rootInodeRefs) StateTypeName() string { + return "pkg/sentry/fsimpl/devpts.rootInodeRefs" +} + +func (r *rootInodeRefs) StateFields() []string { + return []string{ + "refCount", + } +} + +func (r *rootInodeRefs) beforeSave() {} + +// +checklocksignore +func (r *rootInodeRefs) StateSave(stateSinkObject state.Sink) { + r.beforeSave() + stateSinkObject.Save(0, &r.refCount) +} + +// +checklocksignore +func (r *rootInodeRefs) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &r.refCount) + stateSourceObject.AfterLoad(r.afterLoad) +} + +func (tm *Terminal) StateTypeName() string { + return "pkg/sentry/fsimpl/devpts.Terminal" +} + +func (tm *Terminal) StateFields() []string { + return []string{ + "n", + "ld", + "masterKTTY", + "replicaKTTY", + } +} + +func (tm *Terminal) beforeSave() {} + +// +checklocksignore +func (tm *Terminal) StateSave(stateSinkObject state.Sink) { + tm.beforeSave() + stateSinkObject.Save(0, &tm.n) + stateSinkObject.Save(1, &tm.ld) + stateSinkObject.Save(2, &tm.masterKTTY) + stateSinkObject.Save(3, &tm.replicaKTTY) +} + +func (tm *Terminal) afterLoad() {} + +// +checklocksignore +func (tm *Terminal) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &tm.n) + stateSourceObject.Load(1, &tm.ld) + stateSourceObject.Load(2, &tm.masterKTTY) + stateSourceObject.Load(3, &tm.replicaKTTY) +} + +func init() { + state.Register((*FilesystemType)(nil)) + state.Register((*filesystem)(nil)) + state.Register((*rootInode)(nil)) + state.Register((*implStatFS)(nil)) + state.Register((*lineDiscipline)(nil)) + state.Register((*outputQueueTransformer)(nil)) + state.Register((*inputQueueTransformer)(nil)) + state.Register((*masterInode)(nil)) + state.Register((*masterFileDescription)(nil)) + state.Register((*queue)(nil)) + state.Register((*replicaInode)(nil)) + state.Register((*replicaFileDescription)(nil)) + state.Register((*rootInodeRefs)(nil)) + state.Register((*Terminal)(nil)) +} diff --git a/pkg/sentry/fsimpl/devpts/devpts_test.go b/pkg/sentry/fsimpl/devpts/devpts_test.go deleted file mode 100644 index 1ef07d702..000000000 --- a/pkg/sentry/fsimpl/devpts/devpts_test.go +++ /dev/null @@ -1,90 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package devpts - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/contexttest" - "gvisor.dev/gvisor/pkg/usermem" - "gvisor.dev/gvisor/pkg/waiter" -) - -func TestSimpleMasterToReplica(t *testing.T) { - ld := newLineDiscipline(linux.DefaultReplicaTermios) - ctx := contexttest.Context(t) - inBytes := []byte("hello, tty\n") - src := usermem.BytesIOSequence(inBytes) - outBytes := make([]byte, 32) - dst := usermem.BytesIOSequence(outBytes) - - // Write to the input queue. - nw, err := ld.inputQueueWrite(ctx, src) - if err != nil { - t.Fatalf("error writing to input queue: %v", err) - } - if nw != int64(len(inBytes)) { - t.Fatalf("wrote wrong length: got %d, want %d", nw, len(inBytes)) - } - - // Read from the input queue. - nr, err := ld.inputQueueRead(ctx, dst) - if err != nil { - t.Fatalf("error reading from input queue: %v", err) - } - if nr != int64(len(inBytes)) { - t.Fatalf("read wrong length: got %d, want %d", nr, len(inBytes)) - } - - outStr := string(outBytes[:nr]) - inStr := string(inBytes) - if outStr != inStr { - t.Fatalf("written and read strings do not match: got %q, want %q", outStr, inStr) - } -} - -type callback func(*waiter.Entry, waiter.EventMask) - -func (cb callback) Callback(entry *waiter.Entry, mask waiter.EventMask) { - cb(entry, mask) -} - -func TestEchoDeadlock(t *testing.T) { - ctx := contexttest.Context(t) - termios := linux.DefaultReplicaTermios - termios.LocalFlags |= linux.ECHO - ld := newLineDiscipline(termios) - outBytes := make([]byte, 32) - dst := usermem.BytesIOSequence(outBytes) - entry := &waiter.Entry{Callback: callback(func(*waiter.Entry, waiter.EventMask) { - ld.inputQueueRead(ctx, dst) - })} - ld.masterWaiter.EventRegister(entry, waiter.ReadableEvents) - defer ld.masterWaiter.EventUnregister(entry) - inBytes := []byte("hello, tty\n") - n, err := ld.inputQueueWrite(ctx, usermem.BytesIOSequence(inBytes)) - if err != nil { - t.Fatalf("inputQueueWrite: %v", err) - } - if int(n) != len(inBytes) { - t.Fatalf("read wrong length: got %d, want %d", n, len(inBytes)) - } - outStr := string(outBytes[:n]) - inStr := string(inBytes) - if outStr != inStr { - t.Fatalf("written and read strings do not match: got %q, want %q", outStr, inStr) - } -} diff --git a/pkg/sentry/fsimpl/devpts/root_inode_refs.go b/pkg/sentry/fsimpl/devpts/root_inode_refs.go new file mode 100644 index 000000000..e53739a90 --- /dev/null +++ b/pkg/sentry/fsimpl/devpts/root_inode_refs.go @@ -0,0 +1,140 @@ +package devpts + +import ( + "fmt" + "sync/atomic" + + "gvisor.dev/gvisor/pkg/refsvfs2" +) + +// enableLogging indicates whether reference-related events should be logged (with +// stack traces). This is false by default and should only be set to true for +// debugging purposes, as it can generate an extremely large amount of output +// and drastically degrade performance. +const rootInodeenableLogging = false + +// obj is used to customize logging. Note that we use a pointer to T so that +// we do not copy the entire object when passed as a format parameter. +var rootInodeobj *rootInode + +// Refs implements refs.RefCounter. It keeps a reference count using atomic +// operations and calls the destructor when the count reaches zero. +// +// NOTE: Do not introduce additional fields to the Refs struct. It is used by +// many filesystem objects, and we want to keep it as small as possible (i.e., +// the same size as using an int64 directly) to avoid taking up extra cache +// space. In general, this template should not be extended at the cost of +// performance. If it does not offer enough flexibility for a particular object +// (example: b/187877947), we should implement the RefCounter/CheckedObject +// interfaces manually. +// +// +stateify savable +type rootInodeRefs struct { + // refCount is composed of two fields: + // + // [32-bit speculative references]:[32-bit real references] + // + // Speculative references are used for TryIncRef, to avoid a CompareAndSwap + // loop. See IncRef, DecRef and TryIncRef for details of how these fields are + // used. + refCount int64 +} + +// InitRefs initializes r with one reference and, if enabled, activates leak +// checking. +func (r *rootInodeRefs) InitRefs() { + atomic.StoreInt64(&r.refCount, 1) + refsvfs2.Register(r) +} + +// RefType implements refsvfs2.CheckedObject.RefType. +func (r *rootInodeRefs) RefType() string { + return fmt.Sprintf("%T", rootInodeobj)[1:] +} + +// LeakMessage implements refsvfs2.CheckedObject.LeakMessage. +func (r *rootInodeRefs) LeakMessage() string { + return fmt.Sprintf("[%s %p] reference count of %d instead of 0", r.RefType(), r, r.ReadRefs()) +} + +// LogRefs implements refsvfs2.CheckedObject.LogRefs. +func (r *rootInodeRefs) LogRefs() bool { + return rootInodeenableLogging +} + +// ReadRefs returns the current number of references. The returned count is +// inherently racy and is unsafe to use without external synchronization. +func (r *rootInodeRefs) ReadRefs() int64 { + return atomic.LoadInt64(&r.refCount) +} + +// IncRef implements refs.RefCounter.IncRef. +// +//go:nosplit +func (r *rootInodeRefs) IncRef() { + v := atomic.AddInt64(&r.refCount, 1) + if rootInodeenableLogging { + refsvfs2.LogIncRef(r, v) + } + if v <= 1 { + panic(fmt.Sprintf("Incrementing non-positive count %p on %s", r, r.RefType())) + } +} + +// TryIncRef implements refs.TryRefCounter.TryIncRef. +// +// To do this safely without a loop, a speculative reference is first acquired +// on the object. This allows multiple concurrent TryIncRef calls to distinguish +// other TryIncRef calls from genuine references held. +// +//go:nosplit +func (r *rootInodeRefs) TryIncRef() bool { + const speculativeRef = 1 << 32 + if v := atomic.AddInt64(&r.refCount, speculativeRef); int32(v) == 0 { + + atomic.AddInt64(&r.refCount, -speculativeRef) + return false + } + + v := atomic.AddInt64(&r.refCount, -speculativeRef+1) + if rootInodeenableLogging { + refsvfs2.LogTryIncRef(r, v) + } + return true +} + +// DecRef implements refs.RefCounter.DecRef. +// +// Note that speculative references are counted here. Since they were added +// prior to real references reaching zero, they will successfully convert to +// real references. In other words, we see speculative references only in the +// following case: +// +// A: TryIncRef [speculative increase => sees non-negative references] +// B: DecRef [real decrease] +// A: TryIncRef [transform speculative to real] +// +//go:nosplit +func (r *rootInodeRefs) DecRef(destroy func()) { + v := atomic.AddInt64(&r.refCount, -1) + if rootInodeenableLogging { + refsvfs2.LogDecRef(r, v) + } + switch { + case v < 0: + panic(fmt.Sprintf("Decrementing non-positive ref count %p, owned by %s", r, r.RefType())) + + case v == 0: + refsvfs2.Unregister(r) + + if destroy != nil { + destroy() + } + } +} + +func (r *rootInodeRefs) afterLoad() { + if r.ReadRefs() > 0 { + refsvfs2.Register(r) + } +} diff --git a/pkg/sentry/fsimpl/devtmpfs/BUILD b/pkg/sentry/fsimpl/devtmpfs/BUILD deleted file mode 100644 index e49a04c1b..000000000 --- a/pkg/sentry/fsimpl/devtmpfs/BUILD +++ /dev/null @@ -1,37 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -licenses(["notice"]) - -go_library( - name = "devtmpfs", - srcs = [ - "devtmpfs.go", - "save_restore.go", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/fspath", - "//pkg/sentry/fsimpl/tmpfs", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/vfs", - "//pkg/sync", - ], -) - -go_test( - name = "devtmpfs_test", - size = "small", - srcs = ["devtmpfs_test.go"], - library = ":devtmpfs", - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/fspath", - "//pkg/sentry/contexttest", - "//pkg/sentry/fsimpl/tmpfs", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/vfs", - ], -) diff --git a/pkg/sentry/fsimpl/devtmpfs/devtmpfs_state_autogen.go b/pkg/sentry/fsimpl/devtmpfs/devtmpfs_state_autogen.go new file mode 100644 index 000000000..900c7d8fe --- /dev/null +++ b/pkg/sentry/fsimpl/devtmpfs/devtmpfs_state_autogen.go @@ -0,0 +1,41 @@ +// automatically generated by stateify. + +package devtmpfs + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (fst *FilesystemType) StateTypeName() string { + return "pkg/sentry/fsimpl/devtmpfs.FilesystemType" +} + +func (fst *FilesystemType) StateFields() []string { + return []string{ + "initErr", + "fs", + "root", + } +} + +func (fst *FilesystemType) beforeSave() {} + +// +checklocksignore +func (fst *FilesystemType) StateSave(stateSinkObject state.Sink) { + fst.beforeSave() + stateSinkObject.Save(0, &fst.initErr) + stateSinkObject.Save(1, &fst.fs) + stateSinkObject.Save(2, &fst.root) +} + +// +checklocksignore +func (fst *FilesystemType) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &fst.initErr) + stateSourceObject.Load(1, &fst.fs) + stateSourceObject.Load(2, &fst.root) + stateSourceObject.AfterLoad(fst.afterLoad) +} + +func init() { + state.Register((*FilesystemType)(nil)) +} diff --git a/pkg/sentry/fsimpl/devtmpfs/devtmpfs_test.go b/pkg/sentry/fsimpl/devtmpfs/devtmpfs_test.go deleted file mode 100644 index e058eda7a..000000000 --- a/pkg/sentry/fsimpl/devtmpfs/devtmpfs_test.go +++ /dev/null @@ -1,230 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package devtmpfs - -import ( - "path" - "testing" - - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/fspath" - "gvisor.dev/gvisor/pkg/sentry/contexttest" - "gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs" - "gvisor.dev/gvisor/pkg/sentry/kernel/auth" - "gvisor.dev/gvisor/pkg/sentry/vfs" -) - -const devPath = "/dev" - -func setupDevtmpfs(t *testing.T) (context.Context, *auth.Credentials, *vfs.VirtualFilesystem, vfs.VirtualDentry, func()) { - t.Helper() - - ctx := contexttest.Context(t) - creds := auth.CredentialsFromContext(ctx) - vfsObj := &vfs.VirtualFilesystem{} - if err := vfsObj.Init(ctx); err != nil { - t.Fatalf("VFS init: %v", err) - } - // Register tmpfs just so that we can have a root filesystem that isn't - // devtmpfs. - vfsObj.MustRegisterFilesystemType("tmpfs", tmpfs.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{ - AllowUserMount: true, - }) - vfsObj.MustRegisterFilesystemType("devtmpfs", &FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{ - AllowUserMount: true, - }) - - // Create a test mount namespace with devtmpfs mounted at "/dev". - mntns, err := vfsObj.NewMountNamespace(ctx, creds, "tmpfs" /* source */, "tmpfs" /* fsTypeName */, &vfs.MountOptions{}) - if err != nil { - t.Fatalf("failed to create tmpfs root mount: %v", err) - } - root := mntns.Root() - root.IncRef() - devpop := vfs.PathOperation{ - Root: root, - Start: root, - Path: fspath.Parse(devPath), - } - if err := vfsObj.MkdirAt(ctx, creds, &devpop, &vfs.MkdirOptions{ - Mode: 0755, - }); err != nil { - t.Fatalf("failed to create mount point: %v", err) - } - if _, err := vfsObj.MountAt(ctx, creds, "devtmpfs" /* source */, &devpop, "devtmpfs" /* fsTypeName */, &vfs.MountOptions{}); err != nil { - t.Fatalf("failed to mount devtmpfs: %v", err) - } - - return ctx, creds, vfsObj, root, func() { - root.DecRef(ctx) - mntns.DecRef(ctx) - } -} - -func TestUserspaceInit(t *testing.T) { - ctx, creds, vfsObj, root, cleanup := setupDevtmpfs(t) - defer cleanup() - - a, err := NewAccessor(ctx, vfsObj, creds, "devtmpfs") - if err != nil { - t.Fatalf("failed to create devtmpfs.Accessor: %v", err) - } - defer a.Release(ctx) - - // Create "userspace-initialized" files using a devtmpfs.Accessor. - if err := a.UserspaceInit(ctx); err != nil { - t.Fatalf("failed to userspace-initialize devtmpfs: %v", err) - } - - // Created files should be visible in the test mount namespace. - links := []struct { - source string - target string - }{ - { - source: "fd", - target: "/proc/self/fd", - }, - { - source: "stdin", - target: "/proc/self/fd/0", - }, - { - source: "stdout", - target: "/proc/self/fd/1", - }, - { - source: "stderr", - target: "/proc/self/fd/2", - }, - { - source: "ptmx", - target: "pts/ptmx", - }, - } - - for _, link := range links { - abspath := path.Join(devPath, link.source) - if gotTarget, err := vfsObj.ReadlinkAt(ctx, creds, &vfs.PathOperation{ - Root: root, - Start: root, - Path: fspath.Parse(abspath), - }); err != nil || gotTarget != link.target { - t.Errorf("readlink(%q): got (%q, %v), wanted (%q, nil)", abspath, gotTarget, err, link.target) - } - } - - dirs := []string{"shm", "pts"} - for _, dir := range dirs { - abspath := path.Join(devPath, dir) - statx, err := vfsObj.StatAt(ctx, creds, &vfs.PathOperation{ - Root: root, - Start: root, - Path: fspath.Parse(abspath), - }, &vfs.StatOptions{ - Mask: linux.STATX_MODE, - }) - if err != nil { - t.Errorf("stat(%q): got error %v ", abspath, err) - continue - } - if want := uint16(0755) | linux.S_IFDIR; statx.Mode != want { - t.Errorf("stat(%q): got mode %x, want %x", abspath, statx.Mode, want) - } - } -} - -func TestCreateDeviceFile(t *testing.T) { - ctx, creds, vfsObj, root, cleanup := setupDevtmpfs(t) - defer cleanup() - - a, err := NewAccessor(ctx, vfsObj, creds, "devtmpfs") - if err != nil { - t.Fatalf("failed to create devtmpfs.Accessor: %v", err) - } - defer a.Release(ctx) - - devFiles := []struct { - path string - kind vfs.DeviceKind - major uint32 - minor uint32 - perms uint16 - }{ - { - path: "dummy", - kind: vfs.CharDevice, - major: 12, - minor: 34, - perms: 0600, - }, - { - path: "foo/bar", - kind: vfs.BlockDevice, - major: 13, - minor: 35, - perms: 0660, - }, - { - path: "foo/baz", - kind: vfs.CharDevice, - major: 12, - minor: 40, - perms: 0666, - }, - { - path: "a/b/c/d/e", - kind: vfs.BlockDevice, - major: 12, - minor: 34, - perms: 0600, - }, - } - - for _, f := range devFiles { - if err := a.CreateDeviceFile(ctx, f.path, f.kind, f.major, f.minor, f.perms); err != nil { - t.Fatalf("failed to create device file: %v", err) - } - // The device special file should be visible in the test mount namespace. - abspath := path.Join(devPath, f.path) - stat, err := vfsObj.StatAt(ctx, creds, &vfs.PathOperation{ - Root: root, - Start: root, - Path: fspath.Parse(abspath), - }, &vfs.StatOptions{ - Mask: linux.STATX_TYPE | linux.STATX_MODE, - }) - if err != nil { - t.Fatalf("failed to stat device file at %q: %v", abspath, err) - } - if stat.RdevMajor != f.major { - t.Errorf("major device number: got %v, wanted %v", stat.RdevMajor, f.major) - } - if stat.RdevMinor != f.minor { - t.Errorf("minor device number: got %v, wanted %v", stat.RdevMinor, f.minor) - } - wantMode := f.perms - switch f.kind { - case vfs.CharDevice: - wantMode |= linux.S_IFCHR - case vfs.BlockDevice: - wantMode |= linux.S_IFBLK - } - if stat.Mode != wantMode { - t.Errorf("device file mode: got %v, wanted %v", stat.Mode, wantMode) - } - } -} diff --git a/pkg/sentry/fsimpl/eventfd/BUILD b/pkg/sentry/fsimpl/eventfd/BUILD deleted file mode 100644 index 1cb049a29..000000000 --- a/pkg/sentry/fsimpl/eventfd/BUILD +++ /dev/null @@ -1,35 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -licenses(["notice"]) - -go_library( - name = "eventfd", - srcs = ["eventfd.go"], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/errors/linuxerr", - "//pkg/fdnotifier", - "//pkg/hostarch", - "//pkg/log", - "//pkg/sentry/vfs", - "//pkg/usermem", - "//pkg/waiter", - "@org_golang_x_sys//unix:go_default_library", - ], -) - -go_test( - name = "eventfd_test", - size = "small", - srcs = ["eventfd_test.go"], - library = ":eventfd", - deps = [ - "//pkg/abi/linux", - "//pkg/sentry/contexttest", - "//pkg/sentry/vfs", - "//pkg/usermem", - "//pkg/waiter", - ], -) diff --git a/pkg/sentry/fsimpl/eventfd/eventfd_state_autogen.go b/pkg/sentry/fsimpl/eventfd/eventfd_state_autogen.go new file mode 100644 index 000000000..93de7f32e --- /dev/null +++ b/pkg/sentry/fsimpl/eventfd/eventfd_state_autogen.go @@ -0,0 +1,57 @@ +// automatically generated by stateify. + +package eventfd + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (efd *EventFileDescription) StateTypeName() string { + return "pkg/sentry/fsimpl/eventfd.EventFileDescription" +} + +func (efd *EventFileDescription) StateFields() []string { + return []string{ + "vfsfd", + "FileDescriptionDefaultImpl", + "DentryMetadataFileDescriptionImpl", + "NoLockFD", + "queue", + "val", + "semMode", + "hostfd", + } +} + +func (efd *EventFileDescription) beforeSave() {} + +// +checklocksignore +func (efd *EventFileDescription) StateSave(stateSinkObject state.Sink) { + efd.beforeSave() + stateSinkObject.Save(0, &efd.vfsfd) + stateSinkObject.Save(1, &efd.FileDescriptionDefaultImpl) + stateSinkObject.Save(2, &efd.DentryMetadataFileDescriptionImpl) + stateSinkObject.Save(3, &efd.NoLockFD) + stateSinkObject.Save(4, &efd.queue) + stateSinkObject.Save(5, &efd.val) + stateSinkObject.Save(6, &efd.semMode) + stateSinkObject.Save(7, &efd.hostfd) +} + +func (efd *EventFileDescription) afterLoad() {} + +// +checklocksignore +func (efd *EventFileDescription) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &efd.vfsfd) + stateSourceObject.Load(1, &efd.FileDescriptionDefaultImpl) + stateSourceObject.Load(2, &efd.DentryMetadataFileDescriptionImpl) + stateSourceObject.Load(3, &efd.NoLockFD) + stateSourceObject.Load(4, &efd.queue) + stateSourceObject.Load(5, &efd.val) + stateSourceObject.Load(6, &efd.semMode) + stateSourceObject.Load(7, &efd.hostfd) +} + +func init() { + state.Register((*EventFileDescription)(nil)) +} diff --git a/pkg/sentry/fsimpl/eventfd/eventfd_test.go b/pkg/sentry/fsimpl/eventfd/eventfd_test.go deleted file mode 100644 index 85718f813..000000000 --- a/pkg/sentry/fsimpl/eventfd/eventfd_test.go +++ /dev/null @@ -1,97 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package eventfd - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/contexttest" - "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/usermem" - "gvisor.dev/gvisor/pkg/waiter" -) - -func TestEventFD(t *testing.T) { - initVals := []uint64{ - 0, - // Using a non-zero initial value verifies that writing to an - // eventfd signals when the eventfd's counter was already - // non-zero. - 343, - } - - for _, initVal := range initVals { - ctx := contexttest.Context(t) - vfsObj := &vfs.VirtualFilesystem{} - if err := vfsObj.Init(ctx); err != nil { - t.Fatalf("VFS init: %v", err) - } - - // Make a new eventfd that is writable. - eventfd, err := New(ctx, vfsObj, initVal, false, linux.O_RDWR) - if err != nil { - t.Fatalf("New() failed: %v", err) - } - defer eventfd.DecRef(ctx) - - // Register a callback for a write event. - w, ch := waiter.NewChannelEntry(nil) - eventfd.EventRegister(&w, waiter.ReadableEvents) - defer eventfd.EventUnregister(&w) - - data := []byte("00000124") - // Create and submit a write request. - n, err := eventfd.Write(ctx, usermem.BytesIOSequence(data), vfs.WriteOptions{}) - if err != nil { - t.Fatal(err) - } - if n != 8 { - t.Errorf("eventfd.write wrote %d bytes, not full int64", n) - } - - // Check if the callback fired due to the write event. - select { - case <-ch: - default: - t.Errorf("Didn't get notified of EventIn after write") - } - } -} - -func TestEventFDStat(t *testing.T) { - ctx := contexttest.Context(t) - vfsObj := &vfs.VirtualFilesystem{} - if err := vfsObj.Init(ctx); err != nil { - t.Fatalf("VFS init: %v", err) - } - - // Make a new eventfd that is writable. - eventfd, err := New(ctx, vfsObj, 0, false, linux.O_RDWR) - if err != nil { - t.Fatalf("New() failed: %v", err) - } - defer eventfd.DecRef(ctx) - - statx, err := eventfd.Stat(ctx, vfs.StatOptions{ - Mask: linux.STATX_BASIC_STATS, - }) - if err != nil { - t.Fatalf("eventfd.Stat failed: %v", err) - } - if statx.Size != 0 { - t.Errorf("eventfd size should be 0") - } -} diff --git a/pkg/sentry/fsimpl/ext/BUILD b/pkg/sentry/fsimpl/ext/BUILD deleted file mode 100644 index e69de29bb..000000000 --- a/pkg/sentry/fsimpl/ext/BUILD +++ /dev/null diff --git a/pkg/sentry/fsimpl/fuse/BUILD b/pkg/sentry/fsimpl/fuse/BUILD deleted file mode 100644 index 05c4fbeb2..000000000 --- a/pkg/sentry/fsimpl/fuse/BUILD +++ /dev/null @@ -1,90 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -licenses(["notice"]) - -go_template_instance( - name = "request_list", - out = "request_list.go", - package = "fuse", - prefix = "request", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*Request", - "Linker": "*Request", - }, -) - -go_template_instance( - name = "inode_refs", - out = "inode_refs.go", - package = "fuse", - prefix = "inode", - template = "//pkg/refsvfs2:refs_template", - types = { - "T": "inode", - }, -) - -go_library( - name = "fuse", - srcs = [ - "connection.go", - "connection_control.go", - "dev.go", - "directory.go", - "file.go", - "fusefs.go", - "inode_refs.go", - "read_write.go", - "register.go", - "regular_file.go", - "request_list.go", - "request_response.go", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/errors/linuxerr", - "//pkg/hostarch", - "//pkg/log", - "//pkg/marshal", - "//pkg/refs", - "//pkg/refsvfs2", - "//pkg/safemem", - "//pkg/sentry/fsimpl/devtmpfs", - "//pkg/sentry/fsimpl/kernfs", - "//pkg/sentry/kernel", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/vfs", - "//pkg/sync", - "//pkg/usermem", - "//pkg/waiter", - "@org_golang_x_sys//unix:go_default_library", - ], -) - -go_test( - name = "fuse_test", - size = "small", - srcs = [ - "connection_test.go", - "dev_test.go", - "utils_test.go", - ], - library = ":fuse", - deps = [ - "//pkg/abi/linux", - "//pkg/errors/linuxerr", - "//pkg/hostarch", - "//pkg/marshal", - "//pkg/sentry/fsimpl/testutil", - "//pkg/sentry/kernel", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/vfs", - "//pkg/usermem", - "//pkg/waiter", - "@org_golang_x_sys//unix:go_default_library", - ], -) diff --git a/pkg/sentry/fsimpl/fuse/connection_test.go b/pkg/sentry/fsimpl/fuse/connection_test.go deleted file mode 100644 index 1fddd858e..000000000 --- a/pkg/sentry/fsimpl/fuse/connection_test.go +++ /dev/null @@ -1,111 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package fuse - -import ( - "math/rand" - "testing" - - "golang.org/x/sys/unix" - "gvisor.dev/gvisor/pkg/errors/linuxerr" - "gvisor.dev/gvisor/pkg/sentry/kernel" - "gvisor.dev/gvisor/pkg/sentry/kernel/auth" -) - -// TestConnectionInitBlock tests if initialization -// correctly blocks and unblocks the connection. -// Since it's unfeasible to test kernelTask.Block() in unit test, -// the code in Call() are not tested here. -func TestConnectionInitBlock(t *testing.T) { - s := setup(t) - defer s.Destroy() - - k := kernel.KernelFromContext(s.Ctx) - - conn, _, err := newTestConnection(s, k, maxActiveRequestsDefault) - if err != nil { - t.Fatalf("newTestConnection: %v", err) - } - - select { - case <-conn.initializedChan: - t.Fatalf("initializedChan should be blocking before SetInitialized") - default: - } - - conn.SetInitialized() - - select { - case <-conn.initializedChan: - default: - t.Fatalf("initializedChan should not be blocking after SetInitialized") - } -} - -func TestConnectionAbort(t *testing.T) { - s := setup(t) - defer s.Destroy() - - k := kernel.KernelFromContext(s.Ctx) - creds := auth.CredentialsFromContext(s.Ctx) - task := kernel.TaskFromContext(s.Ctx) - - const numRequests uint64 = 256 - - conn, _, err := newTestConnection(s, k, numRequests) - if err != nil { - t.Fatalf("newTestConnection: %v", err) - } - - testObj := &testPayload{ - data: rand.Uint32(), - } - - var futNormal []*futureResponse - - for i := 0; i < int(numRequests); i++ { - req := conn.NewRequest(creds, uint32(i), uint64(i), 0, testObj) - fut, err := conn.callFutureLocked(task, req) - if err != nil { - t.Fatalf("callFutureLocked failed: %v", err) - } - futNormal = append(futNormal, fut) - } - - conn.Abort(s.Ctx) - - // Abort should unblock the initialization channel. - // Note: no test requests are actually blocked on `conn.initializedChan`. - select { - case <-conn.initializedChan: - default: - t.Fatalf("initializedChan should not be blocking after SetInitialized") - } - - // Abort will return ECONNABORTED error to unblocked requests. - for _, fut := range futNormal { - if fut.getResponse().hdr.Error != -int32(unix.ECONNABORTED) { - t.Fatalf("Incorrect error code received for aborted connection: %v", fut.getResponse().hdr.Error) - } - } - - // After abort, Call() should return directly with ENOTCONN. - req := conn.NewRequest(creds, 0, 0, 0, testObj) - _, err = conn.Call(task, req) - if !linuxerr.Equals(linuxerr.ENOTCONN, err) { - t.Fatalf("Incorrect error code received for Call() after connection aborted") - } - -} diff --git a/pkg/sentry/fsimpl/fuse/dev_test.go b/pkg/sentry/fsimpl/fuse/dev_test.go deleted file mode 100644 index 8951b5ba8..000000000 --- a/pkg/sentry/fsimpl/fuse/dev_test.go +++ /dev/null @@ -1,320 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package fuse - -import ( - "fmt" - "math/rand" - "testing" - - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/errors/linuxerr" - "gvisor.dev/gvisor/pkg/sentry/fsimpl/testutil" - "gvisor.dev/gvisor/pkg/sentry/kernel" - "gvisor.dev/gvisor/pkg/sentry/kernel/auth" - "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/usermem" - "gvisor.dev/gvisor/pkg/waiter" -) - -// echoTestOpcode is the Opcode used during testing. The server used in tests -// will simply echo the payload back with the appropriate headers. -const echoTestOpcode linux.FUSEOpcode = 1000 - -// TestFUSECommunication tests that the communication layer between the Sentry and the -// FUSE server daemon works as expected. -func TestFUSECommunication(t *testing.T) { - s := setup(t) - defer s.Destroy() - - k := kernel.KernelFromContext(s.Ctx) - creds := auth.CredentialsFromContext(s.Ctx) - - // Create test cases with different number of concurrent clients and servers. - testCases := []struct { - Name string - NumClients int - NumServers int - MaxActiveRequests uint64 - }{ - { - Name: "SingleClientSingleServer", - NumClients: 1, - NumServers: 1, - MaxActiveRequests: maxActiveRequestsDefault, - }, - { - Name: "SingleClientMultipleServers", - NumClients: 1, - NumServers: 10, - MaxActiveRequests: maxActiveRequestsDefault, - }, - { - Name: "MultipleClientsSingleServer", - NumClients: 10, - NumServers: 1, - MaxActiveRequests: maxActiveRequestsDefault, - }, - { - Name: "MultipleClientsMultipleServers", - NumClients: 10, - NumServers: 10, - MaxActiveRequests: maxActiveRequestsDefault, - }, - { - Name: "RequestCapacityFull", - NumClients: 10, - NumServers: 1, - MaxActiveRequests: 1, - }, - { - Name: "RequestCapacityContinuouslyFull", - NumClients: 100, - NumServers: 2, - MaxActiveRequests: 2, - }, - } - - for _, testCase := range testCases { - t.Run(testCase.Name, func(t *testing.T) { - conn, fd, err := newTestConnection(s, k, testCase.MaxActiveRequests) - if err != nil { - t.Fatalf("newTestConnection: %v", err) - } - - clientsDone := make([]chan struct{}, testCase.NumClients) - serversDone := make([]chan struct{}, testCase.NumServers) - serversKill := make([]chan struct{}, testCase.NumServers) - - // FUSE clients. - for i := 0; i < testCase.NumClients; i++ { - clientsDone[i] = make(chan struct{}) - go func(i int) { - fuseClientRun(t, s, k, conn, creds, uint32(i), uint64(i), clientsDone[i]) - }(i) - } - - // FUSE servers. - for j := 0; j < testCase.NumServers; j++ { - serversDone[j] = make(chan struct{}) - serversKill[j] = make(chan struct{}, 1) // The kill command shouldn't block. - go func(j int) { - fuseServerRun(t, s, k, fd, serversDone[j], serversKill[j]) - }(j) - } - - // Tear down. - // - // Make sure all the clients are done. - for i := 0; i < testCase.NumClients; i++ { - <-clientsDone[i] - } - - // Kill any server that is potentially waiting. - for j := 0; j < testCase.NumServers; j++ { - serversKill[j] <- struct{}{} - } - - // Make sure all the servers are done. - for j := 0; j < testCase.NumServers; j++ { - <-serversDone[j] - } - }) - } -} - -// CallTest makes a request to the server and blocks the invoking -// goroutine until a server responds with a response. Doesn't block -// a kernel.Task. Analogous to Connection.Call but used for testing. -func CallTest(conn *connection, t *kernel.Task, r *Request, i uint32) (*Response, error) { - conn.fd.mu.Lock() - - // Wait until we're certain that a new request can be processed. - for conn.fd.numActiveRequests == conn.fd.fs.opts.maxActiveRequests { - conn.fd.mu.Unlock() - select { - case <-conn.fd.fullQueueCh: - } - conn.fd.mu.Lock() - } - - fut, err := conn.callFutureLocked(t, r) // No task given. - conn.fd.mu.Unlock() - - if err != nil { - return nil, err - } - - // Resolve the response. - // - // Block without a task. - select { - case <-fut.ch: - } - - // A response is ready. Resolve and return it. - return fut.getResponse(), nil -} - -// ReadTest is analogous to vfs.FileDescription.Read and reads from the FUSE -// device. However, it does so by - not blocking the task that is calling - and -// instead just waits on a channel. The behaviour is essentially the same as -// DeviceFD.Read except it guarantees that the task is not blocked. -func ReadTest(serverTask *kernel.Task, fd *vfs.FileDescription, inIOseq usermem.IOSequence, killServer chan struct{}) (int64, bool, error) { - var err error - var n, total int64 - - dev := fd.Impl().(*DeviceFD) - - // Register for notifications. - w, ch := waiter.NewChannelEntry(nil) - dev.EventRegister(&w, waiter.ReadableEvents) - for { - // Issue the request and break out if it completes with anything other than - // "would block". - n, err = dev.Read(serverTask, inIOseq, vfs.ReadOptions{}) - total += n - if err != linuxerr.ErrWouldBlock { - break - } - - // Wait for a notification that we should retry. - // Emulate the blocking for when no requests are available - select { - case <-ch: - case <-killServer: - // Server killed by the main program. - return 0, true, nil - } - } - - dev.EventUnregister(&w) - return total, false, err -} - -// fuseClientRun emulates all the actions of a normal FUSE request. It creates -// a header, a payload, calls the server, waits for the response, and processes -// the response. -func fuseClientRun(t *testing.T, s *testutil.System, k *kernel.Kernel, conn *connection, creds *auth.Credentials, pid uint32, inode uint64, clientDone chan struct{}) { - defer func() { clientDone <- struct{}{} }() - - tc := k.NewThreadGroup(nil, k.RootPIDNamespace(), kernel.NewSignalHandlers(), linux.SIGCHLD, k.GlobalInit().Limits()) - clientTask, err := testutil.CreateTask(s.Ctx, fmt.Sprintf("fuse-client-%v", pid), tc, s.MntNs, s.Root, s.Root) - if err != nil { - t.Fatal(err) - } - testObj := &testPayload{ - data: rand.Uint32(), - } - - req := conn.NewRequest(creds, pid, inode, echoTestOpcode, testObj) - - // Queue up a request. - // Analogous to Call except it doesn't block on the task. - resp, err := CallTest(conn, clientTask, req, pid) - if err != nil { - t.Fatalf("CallTaskNonBlock failed: %v", err) - } - - if err = resp.Error(); err != nil { - t.Fatalf("Server responded with an error: %v", err) - } - - var respTestPayload testPayload - if err := resp.UnmarshalPayload(&respTestPayload); err != nil { - t.Fatalf("Unmarshalling payload error: %v", err) - } - - if resp.hdr.Unique != req.hdr.Unique { - t.Fatalf("got response for another request. Expected response for req %v but got response for req %v", - req.hdr.Unique, resp.hdr.Unique) - } - - if respTestPayload.data != testObj.data { - t.Fatalf("read incorrect data. Data expected: %v, but got %v", testObj.data, respTestPayload.data) - } - -} - -// fuseServerRun creates a task and emulates all the actions of a simple FUSE server -// that simply reads a request and echos the same struct back as a response using the -// appropriate headers. -func fuseServerRun(t *testing.T, s *testutil.System, k *kernel.Kernel, fd *vfs.FileDescription, serverDone, killServer chan struct{}) { - defer func() { serverDone <- struct{}{} }() - - // Create the tasks that the server will be using. - tc := k.NewThreadGroup(nil, k.RootPIDNamespace(), kernel.NewSignalHandlers(), linux.SIGCHLD, k.GlobalInit().Limits()) - var readPayload testPayload - - serverTask, err := testutil.CreateTask(s.Ctx, "fuse-server", tc, s.MntNs, s.Root, s.Root) - if err != nil { - t.Fatal(err) - } - - // Read the request. - for { - inHdrLen := uint32((*linux.FUSEHeaderIn)(nil).SizeBytes()) - payloadLen := uint32(readPayload.SizeBytes()) - - // The raed buffer must meet some certain size criteria. - buffSize := inHdrLen + payloadLen - if buffSize < linux.FUSE_MIN_READ_BUFFER { - buffSize = linux.FUSE_MIN_READ_BUFFER - } - inBuf := make([]byte, buffSize) - inIOseq := usermem.BytesIOSequence(inBuf) - - n, serverKilled, err := ReadTest(serverTask, fd, inIOseq, killServer) - if err != nil { - t.Fatalf("Read failed :%v", err) - } - - // Server should shut down. No new requests are going to be made. - if serverKilled { - break - } - - if n <= 0 { - t.Fatalf("Read read no bytes") - } - - var readFUSEHeaderIn linux.FUSEHeaderIn - readFUSEHeaderIn.UnmarshalUnsafe(inBuf[:inHdrLen]) - readPayload.UnmarshalUnsafe(inBuf[inHdrLen : inHdrLen+payloadLen]) - - if readFUSEHeaderIn.Opcode != echoTestOpcode { - t.Fatalf("read incorrect data. Header: %v, Payload: %v", readFUSEHeaderIn, readPayload) - } - - // Write the response. - outHdrLen := uint32((*linux.FUSEHeaderOut)(nil).SizeBytes()) - outBuf := make([]byte, outHdrLen+payloadLen) - outHeader := linux.FUSEHeaderOut{ - Len: outHdrLen + payloadLen, - Error: 0, - Unique: readFUSEHeaderIn.Unique, - } - - // Echo the payload back. - outHeader.MarshalUnsafe(outBuf[:outHdrLen]) - readPayload.MarshalUnsafe(outBuf[outHdrLen:]) - outIOseq := usermem.BytesIOSequence(outBuf) - - _, err = fd.Write(s.Ctx, outIOseq, vfs.WriteOptions{}) - if err != nil { - t.Fatalf("Write failed :%v", err) - } - } -} diff --git a/pkg/sentry/fsimpl/fuse/fuse_state_autogen.go b/pkg/sentry/fsimpl/fuse/fuse_state_autogen.go new file mode 100644 index 000000000..711a6c425 --- /dev/null +++ b/pkg/sentry/fsimpl/fuse/fuse_state_autogen.go @@ -0,0 +1,562 @@ +// automatically generated by stateify. + +package fuse + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (conn *connection) StateTypeName() string { + return "pkg/sentry/fsimpl/fuse.connection" +} + +func (conn *connection) StateFields() []string { + return []string{ + "fd", + "attributeVersion", + "initialized", + "initializedChan", + "connected", + "connInitError", + "connInitSuccess", + "aborted", + "numWaiting", + "asyncNum", + "asyncCongestionThreshold", + "asyncNumMax", + "maxRead", + "maxWrite", + "maxPages", + "minor", + "atomicOTrunc", + "asyncRead", + "writebackCache", + "bigWrites", + "dontMask", + "noOpen", + } +} + +func (conn *connection) beforeSave() {} + +// +checklocksignore +func (conn *connection) StateSave(stateSinkObject state.Sink) { + conn.beforeSave() + var initializedChanValue bool + initializedChanValue = conn.saveInitializedChan() + stateSinkObject.SaveValue(3, initializedChanValue) + stateSinkObject.Save(0, &conn.fd) + stateSinkObject.Save(1, &conn.attributeVersion) + stateSinkObject.Save(2, &conn.initialized) + stateSinkObject.Save(4, &conn.connected) + stateSinkObject.Save(5, &conn.connInitError) + stateSinkObject.Save(6, &conn.connInitSuccess) + stateSinkObject.Save(7, &conn.aborted) + stateSinkObject.Save(8, &conn.numWaiting) + stateSinkObject.Save(9, &conn.asyncNum) + stateSinkObject.Save(10, &conn.asyncCongestionThreshold) + stateSinkObject.Save(11, &conn.asyncNumMax) + stateSinkObject.Save(12, &conn.maxRead) + stateSinkObject.Save(13, &conn.maxWrite) + stateSinkObject.Save(14, &conn.maxPages) + stateSinkObject.Save(15, &conn.minor) + stateSinkObject.Save(16, &conn.atomicOTrunc) + stateSinkObject.Save(17, &conn.asyncRead) + stateSinkObject.Save(18, &conn.writebackCache) + stateSinkObject.Save(19, &conn.bigWrites) + stateSinkObject.Save(20, &conn.dontMask) + stateSinkObject.Save(21, &conn.noOpen) +} + +func (conn *connection) afterLoad() {} + +// +checklocksignore +func (conn *connection) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &conn.fd) + stateSourceObject.Load(1, &conn.attributeVersion) + stateSourceObject.Load(2, &conn.initialized) + stateSourceObject.Load(4, &conn.connected) + stateSourceObject.Load(5, &conn.connInitError) + stateSourceObject.Load(6, &conn.connInitSuccess) + stateSourceObject.Load(7, &conn.aborted) + stateSourceObject.Load(8, &conn.numWaiting) + stateSourceObject.Load(9, &conn.asyncNum) + stateSourceObject.Load(10, &conn.asyncCongestionThreshold) + stateSourceObject.Load(11, &conn.asyncNumMax) + stateSourceObject.Load(12, &conn.maxRead) + stateSourceObject.Load(13, &conn.maxWrite) + stateSourceObject.Load(14, &conn.maxPages) + stateSourceObject.Load(15, &conn.minor) + stateSourceObject.Load(16, &conn.atomicOTrunc) + stateSourceObject.Load(17, &conn.asyncRead) + stateSourceObject.Load(18, &conn.writebackCache) + stateSourceObject.Load(19, &conn.bigWrites) + stateSourceObject.Load(20, &conn.dontMask) + stateSourceObject.Load(21, &conn.noOpen) + stateSourceObject.LoadValue(3, new(bool), func(y interface{}) { conn.loadInitializedChan(y.(bool)) }) +} + +func (f *fuseDevice) StateTypeName() string { + return "pkg/sentry/fsimpl/fuse.fuseDevice" +} + +func (f *fuseDevice) StateFields() []string { + return []string{} +} + +func (f *fuseDevice) beforeSave() {} + +// +checklocksignore +func (f *fuseDevice) StateSave(stateSinkObject state.Sink) { + f.beforeSave() +} + +func (f *fuseDevice) afterLoad() {} + +// +checklocksignore +func (f *fuseDevice) StateLoad(stateSourceObject state.Source) { +} + +func (fd *DeviceFD) StateTypeName() string { + return "pkg/sentry/fsimpl/fuse.DeviceFD" +} + +func (fd *DeviceFD) StateFields() []string { + return []string{ + "vfsfd", + "FileDescriptionDefaultImpl", + "DentryMetadataFileDescriptionImpl", + "NoLockFD", + "nextOpID", + "queue", + "numActiveRequests", + "completions", + "writeCursor", + "writeBuf", + "writeCursorFR", + "waitQueue", + "fullQueueCh", + "fs", + } +} + +func (fd *DeviceFD) beforeSave() {} + +// +checklocksignore +func (fd *DeviceFD) StateSave(stateSinkObject state.Sink) { + fd.beforeSave() + var fullQueueChValue int + fullQueueChValue = fd.saveFullQueueCh() + stateSinkObject.SaveValue(12, fullQueueChValue) + stateSinkObject.Save(0, &fd.vfsfd) + stateSinkObject.Save(1, &fd.FileDescriptionDefaultImpl) + stateSinkObject.Save(2, &fd.DentryMetadataFileDescriptionImpl) + stateSinkObject.Save(3, &fd.NoLockFD) + stateSinkObject.Save(4, &fd.nextOpID) + stateSinkObject.Save(5, &fd.queue) + stateSinkObject.Save(6, &fd.numActiveRequests) + stateSinkObject.Save(7, &fd.completions) + stateSinkObject.Save(8, &fd.writeCursor) + stateSinkObject.Save(9, &fd.writeBuf) + stateSinkObject.Save(10, &fd.writeCursorFR) + stateSinkObject.Save(11, &fd.waitQueue) + stateSinkObject.Save(13, &fd.fs) +} + +func (fd *DeviceFD) afterLoad() {} + +// +checklocksignore +func (fd *DeviceFD) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &fd.vfsfd) + stateSourceObject.Load(1, &fd.FileDescriptionDefaultImpl) + stateSourceObject.Load(2, &fd.DentryMetadataFileDescriptionImpl) + stateSourceObject.Load(3, &fd.NoLockFD) + stateSourceObject.Load(4, &fd.nextOpID) + stateSourceObject.Load(5, &fd.queue) + stateSourceObject.Load(6, &fd.numActiveRequests) + stateSourceObject.Load(7, &fd.completions) + stateSourceObject.Load(8, &fd.writeCursor) + stateSourceObject.Load(9, &fd.writeBuf) + stateSourceObject.Load(10, &fd.writeCursorFR) + stateSourceObject.Load(11, &fd.waitQueue) + stateSourceObject.Load(13, &fd.fs) + stateSourceObject.LoadValue(12, new(int), func(y interface{}) { fd.loadFullQueueCh(y.(int)) }) +} + +func (fsType *FilesystemType) StateTypeName() string { + return "pkg/sentry/fsimpl/fuse.FilesystemType" +} + +func (fsType *FilesystemType) StateFields() []string { + return []string{} +} + +func (fsType *FilesystemType) beforeSave() {} + +// +checklocksignore +func (fsType *FilesystemType) StateSave(stateSinkObject state.Sink) { + fsType.beforeSave() +} + +func (fsType *FilesystemType) afterLoad() {} + +// +checklocksignore +func (fsType *FilesystemType) StateLoad(stateSourceObject state.Source) { +} + +func (f *filesystemOptions) StateTypeName() string { + return "pkg/sentry/fsimpl/fuse.filesystemOptions" +} + +func (f *filesystemOptions) StateFields() []string { + return []string{ + "mopts", + "uid", + "gid", + "rootMode", + "maxActiveRequests", + "maxRead", + "defaultPermissions", + "allowOther", + } +} + +func (f *filesystemOptions) beforeSave() {} + +// +checklocksignore +func (f *filesystemOptions) StateSave(stateSinkObject state.Sink) { + f.beforeSave() + stateSinkObject.Save(0, &f.mopts) + stateSinkObject.Save(1, &f.uid) + stateSinkObject.Save(2, &f.gid) + stateSinkObject.Save(3, &f.rootMode) + stateSinkObject.Save(4, &f.maxActiveRequests) + stateSinkObject.Save(5, &f.maxRead) + stateSinkObject.Save(6, &f.defaultPermissions) + stateSinkObject.Save(7, &f.allowOther) +} + +func (f *filesystemOptions) afterLoad() {} + +// +checklocksignore +func (f *filesystemOptions) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &f.mopts) + stateSourceObject.Load(1, &f.uid) + stateSourceObject.Load(2, &f.gid) + stateSourceObject.Load(3, &f.rootMode) + stateSourceObject.Load(4, &f.maxActiveRequests) + stateSourceObject.Load(5, &f.maxRead) + stateSourceObject.Load(6, &f.defaultPermissions) + stateSourceObject.Load(7, &f.allowOther) +} + +func (fs *filesystem) StateTypeName() string { + return "pkg/sentry/fsimpl/fuse.filesystem" +} + +func (fs *filesystem) StateFields() []string { + return []string{ + "Filesystem", + "devMinor", + "conn", + "opts", + "umounted", + } +} + +func (fs *filesystem) beforeSave() {} + +// +checklocksignore +func (fs *filesystem) StateSave(stateSinkObject state.Sink) { + fs.beforeSave() + stateSinkObject.Save(0, &fs.Filesystem) + stateSinkObject.Save(1, &fs.devMinor) + stateSinkObject.Save(2, &fs.conn) + stateSinkObject.Save(3, &fs.opts) + stateSinkObject.Save(4, &fs.umounted) +} + +func (fs *filesystem) afterLoad() {} + +// +checklocksignore +func (fs *filesystem) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &fs.Filesystem) + stateSourceObject.Load(1, &fs.devMinor) + stateSourceObject.Load(2, &fs.conn) + stateSourceObject.Load(3, &fs.opts) + stateSourceObject.Load(4, &fs.umounted) +} + +func (i *inode) StateTypeName() string { + return "pkg/sentry/fsimpl/fuse.inode" +} + +func (i *inode) StateFields() []string { + return []string{ + "inodeRefs", + "InodeAlwaysValid", + "InodeAttrs", + "InodeDirectoryNoNewChildren", + "InodeNotSymlink", + "OrderedChildren", + "fs", + "metadataMu", + "nodeID", + "locks", + "size", + "attributeVersion", + "attributeTime", + "version", + "link", + } +} + +func (i *inode) beforeSave() {} + +// +checklocksignore +func (i *inode) StateSave(stateSinkObject state.Sink) { + i.beforeSave() + stateSinkObject.Save(0, &i.inodeRefs) + stateSinkObject.Save(1, &i.InodeAlwaysValid) + stateSinkObject.Save(2, &i.InodeAttrs) + stateSinkObject.Save(3, &i.InodeDirectoryNoNewChildren) + stateSinkObject.Save(4, &i.InodeNotSymlink) + stateSinkObject.Save(5, &i.OrderedChildren) + stateSinkObject.Save(6, &i.fs) + stateSinkObject.Save(7, &i.metadataMu) + stateSinkObject.Save(8, &i.nodeID) + stateSinkObject.Save(9, &i.locks) + stateSinkObject.Save(10, &i.size) + stateSinkObject.Save(11, &i.attributeVersion) + stateSinkObject.Save(12, &i.attributeTime) + stateSinkObject.Save(13, &i.version) + stateSinkObject.Save(14, &i.link) +} + +func (i *inode) afterLoad() {} + +// +checklocksignore +func (i *inode) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &i.inodeRefs) + stateSourceObject.Load(1, &i.InodeAlwaysValid) + stateSourceObject.Load(2, &i.InodeAttrs) + stateSourceObject.Load(3, &i.InodeDirectoryNoNewChildren) + stateSourceObject.Load(4, &i.InodeNotSymlink) + stateSourceObject.Load(5, &i.OrderedChildren) + stateSourceObject.Load(6, &i.fs) + stateSourceObject.Load(7, &i.metadataMu) + stateSourceObject.Load(8, &i.nodeID) + stateSourceObject.Load(9, &i.locks) + stateSourceObject.Load(10, &i.size) + stateSourceObject.Load(11, &i.attributeVersion) + stateSourceObject.Load(12, &i.attributeTime) + stateSourceObject.Load(13, &i.version) + stateSourceObject.Load(14, &i.link) +} + +func (r *inodeRefs) StateTypeName() string { + return "pkg/sentry/fsimpl/fuse.inodeRefs" +} + +func (r *inodeRefs) StateFields() []string { + return []string{ + "refCount", + } +} + +func (r *inodeRefs) beforeSave() {} + +// +checklocksignore +func (r *inodeRefs) StateSave(stateSinkObject state.Sink) { + r.beforeSave() + stateSinkObject.Save(0, &r.refCount) +} + +// +checklocksignore +func (r *inodeRefs) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &r.refCount) + stateSourceObject.AfterLoad(r.afterLoad) +} + +func (l *requestList) StateTypeName() string { + return "pkg/sentry/fsimpl/fuse.requestList" +} + +func (l *requestList) StateFields() []string { + return []string{ + "head", + "tail", + } +} + +func (l *requestList) beforeSave() {} + +// +checklocksignore +func (l *requestList) StateSave(stateSinkObject state.Sink) { + l.beforeSave() + stateSinkObject.Save(0, &l.head) + stateSinkObject.Save(1, &l.tail) +} + +func (l *requestList) afterLoad() {} + +// +checklocksignore +func (l *requestList) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &l.head) + stateSourceObject.Load(1, &l.tail) +} + +func (e *requestEntry) StateTypeName() string { + return "pkg/sentry/fsimpl/fuse.requestEntry" +} + +func (e *requestEntry) StateFields() []string { + return []string{ + "next", + "prev", + } +} + +func (e *requestEntry) beforeSave() {} + +// +checklocksignore +func (e *requestEntry) StateSave(stateSinkObject state.Sink) { + e.beforeSave() + stateSinkObject.Save(0, &e.next) + stateSinkObject.Save(1, &e.prev) +} + +func (e *requestEntry) afterLoad() {} + +// +checklocksignore +func (e *requestEntry) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &e.next) + stateSourceObject.Load(1, &e.prev) +} + +func (r *Request) StateTypeName() string { + return "pkg/sentry/fsimpl/fuse.Request" +} + +func (r *Request) StateFields() []string { + return []string{ + "requestEntry", + "id", + "hdr", + "data", + "payload", + "async", + "noReply", + } +} + +func (r *Request) beforeSave() {} + +// +checklocksignore +func (r *Request) StateSave(stateSinkObject state.Sink) { + r.beforeSave() + stateSinkObject.Save(0, &r.requestEntry) + stateSinkObject.Save(1, &r.id) + stateSinkObject.Save(2, &r.hdr) + stateSinkObject.Save(3, &r.data) + stateSinkObject.Save(4, &r.payload) + stateSinkObject.Save(5, &r.async) + stateSinkObject.Save(6, &r.noReply) +} + +func (r *Request) afterLoad() {} + +// +checklocksignore +func (r *Request) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &r.requestEntry) + stateSourceObject.Load(1, &r.id) + stateSourceObject.Load(2, &r.hdr) + stateSourceObject.Load(3, &r.data) + stateSourceObject.Load(4, &r.payload) + stateSourceObject.Load(5, &r.async) + stateSourceObject.Load(6, &r.noReply) +} + +func (f *futureResponse) StateTypeName() string { + return "pkg/sentry/fsimpl/fuse.futureResponse" +} + +func (f *futureResponse) StateFields() []string { + return []string{ + "opcode", + "ch", + "hdr", + "data", + "async", + } +} + +func (f *futureResponse) beforeSave() {} + +// +checklocksignore +func (f *futureResponse) StateSave(stateSinkObject state.Sink) { + f.beforeSave() + stateSinkObject.Save(0, &f.opcode) + stateSinkObject.Save(1, &f.ch) + stateSinkObject.Save(2, &f.hdr) + stateSinkObject.Save(3, &f.data) + stateSinkObject.Save(4, &f.async) +} + +func (f *futureResponse) afterLoad() {} + +// +checklocksignore +func (f *futureResponse) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &f.opcode) + stateSourceObject.Load(1, &f.ch) + stateSourceObject.Load(2, &f.hdr) + stateSourceObject.Load(3, &f.data) + stateSourceObject.Load(4, &f.async) +} + +func (r *Response) StateTypeName() string { + return "pkg/sentry/fsimpl/fuse.Response" +} + +func (r *Response) StateFields() []string { + return []string{ + "opcode", + "hdr", + "data", + } +} + +func (r *Response) beforeSave() {} + +// +checklocksignore +func (r *Response) StateSave(stateSinkObject state.Sink) { + r.beforeSave() + stateSinkObject.Save(0, &r.opcode) + stateSinkObject.Save(1, &r.hdr) + stateSinkObject.Save(2, &r.data) +} + +func (r *Response) afterLoad() {} + +// +checklocksignore +func (r *Response) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &r.opcode) + stateSourceObject.Load(1, &r.hdr) + stateSourceObject.Load(2, &r.data) +} + +func init() { + state.Register((*connection)(nil)) + state.Register((*fuseDevice)(nil)) + state.Register((*DeviceFD)(nil)) + state.Register((*FilesystemType)(nil)) + state.Register((*filesystemOptions)(nil)) + state.Register((*filesystem)(nil)) + state.Register((*inode)(nil)) + state.Register((*inodeRefs)(nil)) + state.Register((*requestList)(nil)) + state.Register((*requestEntry)(nil)) + state.Register((*Request)(nil)) + state.Register((*futureResponse)(nil)) + state.Register((*Response)(nil)) +} diff --git a/pkg/sentry/fsimpl/fuse/inode_refs.go b/pkg/sentry/fsimpl/fuse/inode_refs.go new file mode 100644 index 000000000..74489cf5e --- /dev/null +++ b/pkg/sentry/fsimpl/fuse/inode_refs.go @@ -0,0 +1,140 @@ +package fuse + +import ( + "fmt" + "sync/atomic" + + "gvisor.dev/gvisor/pkg/refsvfs2" +) + +// enableLogging indicates whether reference-related events should be logged (with +// stack traces). This is false by default and should only be set to true for +// debugging purposes, as it can generate an extremely large amount of output +// and drastically degrade performance. +const inodeenableLogging = false + +// obj is used to customize logging. Note that we use a pointer to T so that +// we do not copy the entire object when passed as a format parameter. +var inodeobj *inode + +// Refs implements refs.RefCounter. It keeps a reference count using atomic +// operations and calls the destructor when the count reaches zero. +// +// NOTE: Do not introduce additional fields to the Refs struct. It is used by +// many filesystem objects, and we want to keep it as small as possible (i.e., +// the same size as using an int64 directly) to avoid taking up extra cache +// space. In general, this template should not be extended at the cost of +// performance. If it does not offer enough flexibility for a particular object +// (example: b/187877947), we should implement the RefCounter/CheckedObject +// interfaces manually. +// +// +stateify savable +type inodeRefs struct { + // refCount is composed of two fields: + // + // [32-bit speculative references]:[32-bit real references] + // + // Speculative references are used for TryIncRef, to avoid a CompareAndSwap + // loop. See IncRef, DecRef and TryIncRef for details of how these fields are + // used. + refCount int64 +} + +// InitRefs initializes r with one reference and, if enabled, activates leak +// checking. +func (r *inodeRefs) InitRefs() { + atomic.StoreInt64(&r.refCount, 1) + refsvfs2.Register(r) +} + +// RefType implements refsvfs2.CheckedObject.RefType. +func (r *inodeRefs) RefType() string { + return fmt.Sprintf("%T", inodeobj)[1:] +} + +// LeakMessage implements refsvfs2.CheckedObject.LeakMessage. +func (r *inodeRefs) LeakMessage() string { + return fmt.Sprintf("[%s %p] reference count of %d instead of 0", r.RefType(), r, r.ReadRefs()) +} + +// LogRefs implements refsvfs2.CheckedObject.LogRefs. +func (r *inodeRefs) LogRefs() bool { + return inodeenableLogging +} + +// ReadRefs returns the current number of references. The returned count is +// inherently racy and is unsafe to use without external synchronization. +func (r *inodeRefs) ReadRefs() int64 { + return atomic.LoadInt64(&r.refCount) +} + +// IncRef implements refs.RefCounter.IncRef. +// +//go:nosplit +func (r *inodeRefs) IncRef() { + v := atomic.AddInt64(&r.refCount, 1) + if inodeenableLogging { + refsvfs2.LogIncRef(r, v) + } + if v <= 1 { + panic(fmt.Sprintf("Incrementing non-positive count %p on %s", r, r.RefType())) + } +} + +// TryIncRef implements refs.TryRefCounter.TryIncRef. +// +// To do this safely without a loop, a speculative reference is first acquired +// on the object. This allows multiple concurrent TryIncRef calls to distinguish +// other TryIncRef calls from genuine references held. +// +//go:nosplit +func (r *inodeRefs) TryIncRef() bool { + const speculativeRef = 1 << 32 + if v := atomic.AddInt64(&r.refCount, speculativeRef); int32(v) == 0 { + + atomic.AddInt64(&r.refCount, -speculativeRef) + return false + } + + v := atomic.AddInt64(&r.refCount, -speculativeRef+1) + if inodeenableLogging { + refsvfs2.LogTryIncRef(r, v) + } + return true +} + +// DecRef implements refs.RefCounter.DecRef. +// +// Note that speculative references are counted here. Since they were added +// prior to real references reaching zero, they will successfully convert to +// real references. In other words, we see speculative references only in the +// following case: +// +// A: TryIncRef [speculative increase => sees non-negative references] +// B: DecRef [real decrease] +// A: TryIncRef [transform speculative to real] +// +//go:nosplit +func (r *inodeRefs) DecRef(destroy func()) { + v := atomic.AddInt64(&r.refCount, -1) + if inodeenableLogging { + refsvfs2.LogDecRef(r, v) + } + switch { + case v < 0: + panic(fmt.Sprintf("Decrementing non-positive ref count %p, owned by %s", r, r.RefType())) + + case v == 0: + refsvfs2.Unregister(r) + + if destroy != nil { + destroy() + } + } +} + +func (r *inodeRefs) afterLoad() { + if r.ReadRefs() > 0 { + refsvfs2.Register(r) + } +} diff --git a/pkg/sentry/fsimpl/fuse/request_list.go b/pkg/sentry/fsimpl/fuse/request_list.go new file mode 100644 index 000000000..060ac4a3f --- /dev/null +++ b/pkg/sentry/fsimpl/fuse/request_list.go @@ -0,0 +1,221 @@ +package fuse + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type requestElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (requestElementMapper) linkerFor(elem *Request) *Request { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type requestList struct { + head *Request + tail *Request +} + +// Reset resets list l to the empty state. +func (l *requestList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +// +//go:nosplit +func (l *requestList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +// +//go:nosplit +func (l *requestList) Front() *Request { + return l.head +} + +// Back returns the last element of list l or nil. +// +//go:nosplit +func (l *requestList) Back() *Request { + return l.tail +} + +// Len returns the number of elements in the list. +// +// NOTE: This is an O(n) operation. +// +//go:nosplit +func (l *requestList) Len() (count int) { + for e := l.Front(); e != nil; e = (requestElementMapper{}.linkerFor(e)).Next() { + count++ + } + return count +} + +// PushFront inserts the element e at the front of list l. +// +//go:nosplit +func (l *requestList) PushFront(e *Request) { + linker := requestElementMapper{}.linkerFor(e) + linker.SetNext(l.head) + linker.SetPrev(nil) + if l.head != nil { + requestElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +// +//go:nosplit +func (l *requestList) PushBack(e *Request) { + linker := requestElementMapper{}.linkerFor(e) + linker.SetNext(nil) + linker.SetPrev(l.tail) + if l.tail != nil { + requestElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +// +//go:nosplit +func (l *requestList) PushBackList(m *requestList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + requestElementMapper{}.linkerFor(l.tail).SetNext(m.head) + requestElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +// +//go:nosplit +func (l *requestList) InsertAfter(b, e *Request) { + bLinker := requestElementMapper{}.linkerFor(b) + eLinker := requestElementMapper{}.linkerFor(e) + + a := bLinker.Next() + + eLinker.SetNext(a) + eLinker.SetPrev(b) + bLinker.SetNext(e) + + if a != nil { + requestElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +// +//go:nosplit +func (l *requestList) InsertBefore(a, e *Request) { + aLinker := requestElementMapper{}.linkerFor(a) + eLinker := requestElementMapper{}.linkerFor(e) + + b := aLinker.Prev() + eLinker.SetNext(a) + eLinker.SetPrev(b) + aLinker.SetPrev(e) + + if b != nil { + requestElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +// +//go:nosplit +func (l *requestList) Remove(e *Request) { + linker := requestElementMapper{}.linkerFor(e) + prev := linker.Prev() + next := linker.Next() + + if prev != nil { + requestElementMapper{}.linkerFor(prev).SetNext(next) + } else if l.head == e { + l.head = next + } + + if next != nil { + requestElementMapper{}.linkerFor(next).SetPrev(prev) + } else if l.tail == e { + l.tail = prev + } + + linker.SetNext(nil) + linker.SetPrev(nil) +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type requestEntry struct { + next *Request + prev *Request +} + +// Next returns the entry that follows e in the list. +// +//go:nosplit +func (e *requestEntry) Next() *Request { + return e.next +} + +// Prev returns the entry that precedes e in the list. +// +//go:nosplit +func (e *requestEntry) Prev() *Request { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +// +//go:nosplit +func (e *requestEntry) SetNext(elem *Request) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +// +//go:nosplit +func (e *requestEntry) SetPrev(elem *Request) { + e.prev = elem +} diff --git a/pkg/sentry/fsimpl/fuse/utils_test.go b/pkg/sentry/fsimpl/fuse/utils_test.go deleted file mode 100644 index b0bab0066..000000000 --- a/pkg/sentry/fsimpl/fuse/utils_test.go +++ /dev/null @@ -1,127 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package fuse - -import ( - "io" - "testing" - - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/marshal" - "gvisor.dev/gvisor/pkg/sentry/fsimpl/testutil" - "gvisor.dev/gvisor/pkg/sentry/kernel" - "gvisor.dev/gvisor/pkg/sentry/kernel/auth" - "gvisor.dev/gvisor/pkg/sentry/vfs" - - "gvisor.dev/gvisor/pkg/hostarch" -) - -func setup(t *testing.T) *testutil.System { - k, err := testutil.Boot() - if err != nil { - t.Fatalf("Error creating kernel: %v", err) - } - - ctx := k.SupervisorContext() - creds := auth.CredentialsFromContext(ctx) - - k.VFS().MustRegisterFilesystemType(Name, &FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{ - AllowUserList: true, - AllowUserMount: true, - }) - - mntns, err := k.VFS().NewMountNamespace(ctx, creds, "", "tmpfs", &vfs.MountOptions{}) - if err != nil { - t.Fatalf("NewMountNamespace(): %v", err) - } - - return testutil.NewSystem(ctx, t, k.VFS(), mntns) -} - -// newTestConnection creates a fuse connection that the sentry can communicate with -// and the FD for the server to communicate with. -func newTestConnection(system *testutil.System, k *kernel.Kernel, maxActiveRequests uint64) (*connection, *vfs.FileDescription, error) { - fuseDev := &DeviceFD{} - - vd := system.VFS.NewAnonVirtualDentry("fuse") - defer vd.DecRef(system.Ctx) - if err := fuseDev.vfsfd.Init(fuseDev, linux.O_RDWR, vd.Mount(), vd.Dentry(), &vfs.FileDescriptionOptions{}); err != nil { - return nil, nil, err - } - - fsopts := filesystemOptions{ - maxActiveRequests: maxActiveRequests, - } - fs, err := newFUSEFilesystem(system.Ctx, system.VFS, &FilesystemType{}, fuseDev, 0, &fsopts) - if err != nil { - return nil, nil, err - } - return fs.conn, &fuseDev.vfsfd, nil -} - -type testPayload struct { - marshal.StubMarshallable - data uint32 -} - -// SizeBytes implements marshal.Marshallable.SizeBytes. -func (t *testPayload) SizeBytes() int { - return 4 -} - -// MarshalBytes implements marshal.Marshallable.MarshalBytes. -func (t *testPayload) MarshalBytes(dst []byte) { - hostarch.ByteOrder.PutUint32(dst[:4], t.data) -} - -// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. -func (t *testPayload) UnmarshalBytes(src []byte) { - *t = testPayload{data: hostarch.ByteOrder.Uint32(src[:4])} -} - -// Packed implements marshal.Marshallable.Packed. -func (t *testPayload) Packed() bool { - return true -} - -// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. -func (t *testPayload) MarshalUnsafe(dst []byte) { - t.MarshalBytes(dst) -} - -// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. -func (t *testPayload) UnmarshalUnsafe(src []byte) { - t.UnmarshalBytes(src) -} - -// CopyOutN implements marshal.Marshallable.CopyOutN. -func (t *testPayload) CopyOutN(task marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { - panic("not implemented") -} - -// CopyOut implements marshal.Marshallable.CopyOut. -func (t *testPayload) CopyOut(task marshal.CopyContext, addr hostarch.Addr) (int, error) { - panic("not implemented") -} - -// CopyIn implements marshal.Marshallable.CopyIn. -func (t *testPayload) CopyIn(task marshal.CopyContext, addr hostarch.Addr) (int, error) { - panic("not implemented") -} - -// WriteTo implements io.WriterTo.WriteTo. -func (t *testPayload) WriteTo(w io.Writer) (int64, error) { - panic("not implemented") -} diff --git a/pkg/sentry/fsimpl/gofer/BUILD b/pkg/sentry/fsimpl/gofer/BUILD deleted file mode 100644 index 4244f2cf5..000000000 --- a/pkg/sentry/fsimpl/gofer/BUILD +++ /dev/null @@ -1,98 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -licenses(["notice"]) - -go_template_instance( - name = "dentry_list", - out = "dentry_list.go", - package = "gofer", - prefix = "dentry", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*dentry", - "Linker": "*dentry", - }, -) - -go_template_instance( - name = "fstree", - out = "fstree.go", - package = "gofer", - prefix = "generic", - template = "//pkg/sentry/vfs/genericfstree:generic_fstree", - types = { - "Dentry": "dentry", - }, -) - -go_library( - name = "gofer", - srcs = [ - "dentry_list.go", - "directory.go", - "filesystem.go", - "fstree.go", - "gofer.go", - "handle.go", - "host_named_pipe.go", - "p9file.go", - "regular_file.go", - "revalidate.go", - "save_restore.go", - "socket.go", - "special_file.go", - "symlink.go", - "time.go", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/errors/linuxerr", - "//pkg/fd", - "//pkg/fdnotifier", - "//pkg/fspath", - "//pkg/hostarch", - "//pkg/log", - "//pkg/metric", - "//pkg/p9", - "//pkg/refs", - "//pkg/refsvfs2", - "//pkg/safemem", - "//pkg/sentry/fs/fsutil", - "//pkg/sentry/fs/lock", - "//pkg/sentry/fsimpl/host", - "//pkg/sentry/fsmetric", - "//pkg/sentry/hostfd", - "//pkg/sentry/kernel", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/kernel/pipe", - "//pkg/sentry/kernel/time", - "//pkg/sentry/memmap", - "//pkg/sentry/pgalloc", - "//pkg/sentry/platform", - "//pkg/sentry/socket/control", - "//pkg/sentry/socket/unix", - "//pkg/sentry/socket/unix/transport", - "//pkg/sentry/usage", - "//pkg/sentry/vfs", - "//pkg/sync", - "//pkg/syserr", - "//pkg/unet", - "//pkg/usermem", - "//pkg/waiter", - "@org_golang_x_sys//unix:go_default_library", - ], -) - -go_test( - name = "gofer_test", - srcs = ["gofer_test.go"], - library = ":gofer", - deps = [ - "//pkg/p9", - "//pkg/sentry/contexttest", - "//pkg/sentry/pgalloc", - ], -) diff --git a/pkg/sentry/fsimpl/gofer/dentry_list.go b/pkg/sentry/fsimpl/gofer/dentry_list.go new file mode 100644 index 000000000..2e43b8e02 --- /dev/null +++ b/pkg/sentry/fsimpl/gofer/dentry_list.go @@ -0,0 +1,221 @@ +package gofer + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type dentryElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (dentryElementMapper) linkerFor(elem *dentry) *dentry { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type dentryList struct { + head *dentry + tail *dentry +} + +// Reset resets list l to the empty state. +func (l *dentryList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +// +//go:nosplit +func (l *dentryList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +// +//go:nosplit +func (l *dentryList) Front() *dentry { + return l.head +} + +// Back returns the last element of list l or nil. +// +//go:nosplit +func (l *dentryList) Back() *dentry { + return l.tail +} + +// Len returns the number of elements in the list. +// +// NOTE: This is an O(n) operation. +// +//go:nosplit +func (l *dentryList) Len() (count int) { + for e := l.Front(); e != nil; e = (dentryElementMapper{}.linkerFor(e)).Next() { + count++ + } + return count +} + +// PushFront inserts the element e at the front of list l. +// +//go:nosplit +func (l *dentryList) PushFront(e *dentry) { + linker := dentryElementMapper{}.linkerFor(e) + linker.SetNext(l.head) + linker.SetPrev(nil) + if l.head != nil { + dentryElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +// +//go:nosplit +func (l *dentryList) PushBack(e *dentry) { + linker := dentryElementMapper{}.linkerFor(e) + linker.SetNext(nil) + linker.SetPrev(l.tail) + if l.tail != nil { + dentryElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +// +//go:nosplit +func (l *dentryList) PushBackList(m *dentryList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + dentryElementMapper{}.linkerFor(l.tail).SetNext(m.head) + dentryElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +// +//go:nosplit +func (l *dentryList) InsertAfter(b, e *dentry) { + bLinker := dentryElementMapper{}.linkerFor(b) + eLinker := dentryElementMapper{}.linkerFor(e) + + a := bLinker.Next() + + eLinker.SetNext(a) + eLinker.SetPrev(b) + bLinker.SetNext(e) + + if a != nil { + dentryElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +// +//go:nosplit +func (l *dentryList) InsertBefore(a, e *dentry) { + aLinker := dentryElementMapper{}.linkerFor(a) + eLinker := dentryElementMapper{}.linkerFor(e) + + b := aLinker.Prev() + eLinker.SetNext(a) + eLinker.SetPrev(b) + aLinker.SetPrev(e) + + if b != nil { + dentryElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +// +//go:nosplit +func (l *dentryList) Remove(e *dentry) { + linker := dentryElementMapper{}.linkerFor(e) + prev := linker.Prev() + next := linker.Next() + + if prev != nil { + dentryElementMapper{}.linkerFor(prev).SetNext(next) + } else if l.head == e { + l.head = next + } + + if next != nil { + dentryElementMapper{}.linkerFor(next).SetPrev(prev) + } else if l.tail == e { + l.tail = prev + } + + linker.SetNext(nil) + linker.SetPrev(nil) +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type dentryEntry struct { + next *dentry + prev *dentry +} + +// Next returns the entry that follows e in the list. +// +//go:nosplit +func (e *dentryEntry) Next() *dentry { + return e.next +} + +// Prev returns the entry that precedes e in the list. +// +//go:nosplit +func (e *dentryEntry) Prev() *dentry { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +// +//go:nosplit +func (e *dentryEntry) SetNext(elem *dentry) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +// +//go:nosplit +func (e *dentryEntry) SetPrev(elem *dentry) { + e.prev = elem +} diff --git a/pkg/sentry/fsimpl/gofer/fstree.go b/pkg/sentry/fsimpl/gofer/fstree.go new file mode 100644 index 000000000..6e43d4a4b --- /dev/null +++ b/pkg/sentry/fsimpl/gofer/fstree.go @@ -0,0 +1,55 @@ +package gofer + +import ( + "gvisor.dev/gvisor/pkg/fspath" + "gvisor.dev/gvisor/pkg/sentry/vfs" +) + +// IsAncestorDentry returns true if d is an ancestor of d2; that is, d is +// either d2's parent or an ancestor of d2's parent. +func genericIsAncestorDentry(d, d2 *dentry) bool { + for d2 != nil { + if d2.parent == d { + return true + } + if d2.parent == d2 { + return false + } + d2 = d2.parent + } + return false +} + +// ParentOrSelf returns d.parent. If d.parent is nil, ParentOrSelf returns d. +func genericParentOrSelf(d *dentry) *dentry { + if d.parent != nil { + return d.parent + } + return d +} + +// PrependPath is a generic implementation of FilesystemImpl.PrependPath(). +func genericPrependPath(vfsroot vfs.VirtualDentry, mnt *vfs.Mount, d *dentry, b *fspath.Builder) error { + for { + if mnt == vfsroot.Mount() && &d.vfsd == vfsroot.Dentry() { + return vfs.PrependPathAtVFSRootError{} + } + if mnt != nil && &d.vfsd == mnt.Root() { + return nil + } + if d.parent == nil { + return vfs.PrependPathAtNonMountRootError{} + } + b.PrependComponent(d.name) + d = d.parent + } +} + +// DebugPathname returns a pathname to d relative to its filesystem root. +// DebugPathname does not correspond to any Linux function; it's used to +// generate dentry pathnames for debugging. +func genericDebugPathname(d *dentry) string { + var b fspath.Builder + _ = genericPrependPath(vfs.VirtualDentry{}, nil, d, &b) + return b.String() +} diff --git a/pkg/sentry/fsimpl/gofer/gofer_state_autogen.go b/pkg/sentry/fsimpl/gofer/gofer_state_autogen.go new file mode 100644 index 000000000..11635f249 --- /dev/null +++ b/pkg/sentry/fsimpl/gofer/gofer_state_autogen.go @@ -0,0 +1,608 @@ +// automatically generated by stateify. + +package gofer + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (l *dentryList) StateTypeName() string { + return "pkg/sentry/fsimpl/gofer.dentryList" +} + +func (l *dentryList) StateFields() []string { + return []string{ + "head", + "tail", + } +} + +func (l *dentryList) beforeSave() {} + +// +checklocksignore +func (l *dentryList) StateSave(stateSinkObject state.Sink) { + l.beforeSave() + stateSinkObject.Save(0, &l.head) + stateSinkObject.Save(1, &l.tail) +} + +func (l *dentryList) afterLoad() {} + +// +checklocksignore +func (l *dentryList) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &l.head) + stateSourceObject.Load(1, &l.tail) +} + +func (e *dentryEntry) StateTypeName() string { + return "pkg/sentry/fsimpl/gofer.dentryEntry" +} + +func (e *dentryEntry) StateFields() []string { + return []string{ + "next", + "prev", + } +} + +func (e *dentryEntry) beforeSave() {} + +// +checklocksignore +func (e *dentryEntry) StateSave(stateSinkObject state.Sink) { + e.beforeSave() + stateSinkObject.Save(0, &e.next) + stateSinkObject.Save(1, &e.prev) +} + +func (e *dentryEntry) afterLoad() {} + +// +checklocksignore +func (e *dentryEntry) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &e.next) + stateSourceObject.Load(1, &e.prev) +} + +func (fd *directoryFD) StateTypeName() string { + return "pkg/sentry/fsimpl/gofer.directoryFD" +} + +func (fd *directoryFD) StateFields() []string { + return []string{ + "fileDescription", + "DirectoryFileDescriptionDefaultImpl", + "off", + "dirents", + } +} + +func (fd *directoryFD) beforeSave() {} + +// +checklocksignore +func (fd *directoryFD) StateSave(stateSinkObject state.Sink) { + fd.beforeSave() + stateSinkObject.Save(0, &fd.fileDescription) + stateSinkObject.Save(1, &fd.DirectoryFileDescriptionDefaultImpl) + stateSinkObject.Save(2, &fd.off) + stateSinkObject.Save(3, &fd.dirents) +} + +func (fd *directoryFD) afterLoad() {} + +// +checklocksignore +func (fd *directoryFD) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &fd.fileDescription) + stateSourceObject.Load(1, &fd.DirectoryFileDescriptionDefaultImpl) + stateSourceObject.Load(2, &fd.off) + stateSourceObject.Load(3, &fd.dirents) +} + +func (fstype *FilesystemType) StateTypeName() string { + return "pkg/sentry/fsimpl/gofer.FilesystemType" +} + +func (fstype *FilesystemType) StateFields() []string { + return []string{} +} + +func (fstype *FilesystemType) beforeSave() {} + +// +checklocksignore +func (fstype *FilesystemType) StateSave(stateSinkObject state.Sink) { + fstype.beforeSave() +} + +func (fstype *FilesystemType) afterLoad() {} + +// +checklocksignore +func (fstype *FilesystemType) StateLoad(stateSourceObject state.Source) { +} + +func (fs *filesystem) StateTypeName() string { + return "pkg/sentry/fsimpl/gofer.filesystem" +} + +func (fs *filesystem) StateFields() []string { + return []string{ + "vfsfs", + "mfp", + "opts", + "iopts", + "clock", + "devMinor", + "root", + "cachedDentries", + "cachedDentriesLen", + "syncableDentries", + "specialFileFDs", + "lastIno", + "savedDentryRW", + "released", + } +} + +func (fs *filesystem) beforeSave() {} + +// +checklocksignore +func (fs *filesystem) StateSave(stateSinkObject state.Sink) { + fs.beforeSave() + stateSinkObject.Save(0, &fs.vfsfs) + stateSinkObject.Save(1, &fs.mfp) + stateSinkObject.Save(2, &fs.opts) + stateSinkObject.Save(3, &fs.iopts) + stateSinkObject.Save(4, &fs.clock) + stateSinkObject.Save(5, &fs.devMinor) + stateSinkObject.Save(6, &fs.root) + stateSinkObject.Save(7, &fs.cachedDentries) + stateSinkObject.Save(8, &fs.cachedDentriesLen) + stateSinkObject.Save(9, &fs.syncableDentries) + stateSinkObject.Save(10, &fs.specialFileFDs) + stateSinkObject.Save(11, &fs.lastIno) + stateSinkObject.Save(12, &fs.savedDentryRW) + stateSinkObject.Save(13, &fs.released) +} + +func (fs *filesystem) afterLoad() {} + +// +checklocksignore +func (fs *filesystem) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &fs.vfsfs) + stateSourceObject.Load(1, &fs.mfp) + stateSourceObject.Load(2, &fs.opts) + stateSourceObject.Load(3, &fs.iopts) + stateSourceObject.Load(4, &fs.clock) + stateSourceObject.Load(5, &fs.devMinor) + stateSourceObject.Load(6, &fs.root) + stateSourceObject.Load(7, &fs.cachedDentries) + stateSourceObject.Load(8, &fs.cachedDentriesLen) + stateSourceObject.Load(9, &fs.syncableDentries) + stateSourceObject.Load(10, &fs.specialFileFDs) + stateSourceObject.Load(11, &fs.lastIno) + stateSourceObject.Load(12, &fs.savedDentryRW) + stateSourceObject.Load(13, &fs.released) +} + +func (f *filesystemOptions) StateTypeName() string { + return "pkg/sentry/fsimpl/gofer.filesystemOptions" +} + +func (f *filesystemOptions) StateFields() []string { + return []string{ + "fd", + "aname", + "interop", + "dfltuid", + "dfltgid", + "msize", + "version", + "maxCachedDentries", + "forcePageCache", + "limitHostFDTranslation", + "overlayfsStaleRead", + "regularFilesUseSpecialFileFD", + } +} + +func (f *filesystemOptions) beforeSave() {} + +// +checklocksignore +func (f *filesystemOptions) StateSave(stateSinkObject state.Sink) { + f.beforeSave() + stateSinkObject.Save(0, &f.fd) + stateSinkObject.Save(1, &f.aname) + stateSinkObject.Save(2, &f.interop) + stateSinkObject.Save(3, &f.dfltuid) + stateSinkObject.Save(4, &f.dfltgid) + stateSinkObject.Save(5, &f.msize) + stateSinkObject.Save(6, &f.version) + stateSinkObject.Save(7, &f.maxCachedDentries) + stateSinkObject.Save(8, &f.forcePageCache) + stateSinkObject.Save(9, &f.limitHostFDTranslation) + stateSinkObject.Save(10, &f.overlayfsStaleRead) + stateSinkObject.Save(11, &f.regularFilesUseSpecialFileFD) +} + +func (f *filesystemOptions) afterLoad() {} + +// +checklocksignore +func (f *filesystemOptions) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &f.fd) + stateSourceObject.Load(1, &f.aname) + stateSourceObject.Load(2, &f.interop) + stateSourceObject.Load(3, &f.dfltuid) + stateSourceObject.Load(4, &f.dfltgid) + stateSourceObject.Load(5, &f.msize) + stateSourceObject.Load(6, &f.version) + stateSourceObject.Load(7, &f.maxCachedDentries) + stateSourceObject.Load(8, &f.forcePageCache) + stateSourceObject.Load(9, &f.limitHostFDTranslation) + stateSourceObject.Load(10, &f.overlayfsStaleRead) + stateSourceObject.Load(11, &f.regularFilesUseSpecialFileFD) +} + +func (i *InteropMode) StateTypeName() string { + return "pkg/sentry/fsimpl/gofer.InteropMode" +} + +func (i *InteropMode) StateFields() []string { + return nil +} + +func (i *InternalFilesystemOptions) StateTypeName() string { + return "pkg/sentry/fsimpl/gofer.InternalFilesystemOptions" +} + +func (i *InternalFilesystemOptions) StateFields() []string { + return []string{ + "UniqueID", + "LeakConnection", + "OpenSocketsByConnecting", + } +} + +func (i *InternalFilesystemOptions) beforeSave() {} + +// +checklocksignore +func (i *InternalFilesystemOptions) StateSave(stateSinkObject state.Sink) { + i.beforeSave() + stateSinkObject.Save(0, &i.UniqueID) + stateSinkObject.Save(1, &i.LeakConnection) + stateSinkObject.Save(2, &i.OpenSocketsByConnecting) +} + +func (i *InternalFilesystemOptions) afterLoad() {} + +// +checklocksignore +func (i *InternalFilesystemOptions) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &i.UniqueID) + stateSourceObject.Load(1, &i.LeakConnection) + stateSourceObject.Load(2, &i.OpenSocketsByConnecting) +} + +func (d *dentry) StateTypeName() string { + return "pkg/sentry/fsimpl/gofer.dentry" +} + +func (d *dentry) StateFields() []string { + return []string{ + "vfsd", + "refs", + "fs", + "parent", + "name", + "qidPath", + "deleted", + "cached", + "dentryEntry", + "children", + "syntheticChildren", + "dirents", + "ino", + "mode", + "uid", + "gid", + "blockSize", + "atime", + "mtime", + "ctime", + "btime", + "size", + "atimeDirty", + "mtimeDirty", + "nlink", + "mappings", + "cache", + "dirty", + "pf", + "haveTarget", + "target", + "endpoint", + "pipe", + "locks", + "watches", + } +} + +// +checklocksignore +func (d *dentry) StateSave(stateSinkObject state.Sink) { + d.beforeSave() + stateSinkObject.Save(0, &d.vfsd) + stateSinkObject.Save(1, &d.refs) + stateSinkObject.Save(2, &d.fs) + stateSinkObject.Save(3, &d.parent) + stateSinkObject.Save(4, &d.name) + stateSinkObject.Save(5, &d.qidPath) + stateSinkObject.Save(6, &d.deleted) + stateSinkObject.Save(7, &d.cached) + stateSinkObject.Save(8, &d.dentryEntry) + stateSinkObject.Save(9, &d.children) + stateSinkObject.Save(10, &d.syntheticChildren) + stateSinkObject.Save(11, &d.dirents) + stateSinkObject.Save(12, &d.ino) + stateSinkObject.Save(13, &d.mode) + stateSinkObject.Save(14, &d.uid) + stateSinkObject.Save(15, &d.gid) + stateSinkObject.Save(16, &d.blockSize) + stateSinkObject.Save(17, &d.atime) + stateSinkObject.Save(18, &d.mtime) + stateSinkObject.Save(19, &d.ctime) + stateSinkObject.Save(20, &d.btime) + stateSinkObject.Save(21, &d.size) + stateSinkObject.Save(22, &d.atimeDirty) + stateSinkObject.Save(23, &d.mtimeDirty) + stateSinkObject.Save(24, &d.nlink) + stateSinkObject.Save(25, &d.mappings) + stateSinkObject.Save(26, &d.cache) + stateSinkObject.Save(27, &d.dirty) + stateSinkObject.Save(28, &d.pf) + stateSinkObject.Save(29, &d.haveTarget) + stateSinkObject.Save(30, &d.target) + stateSinkObject.Save(31, &d.endpoint) + stateSinkObject.Save(32, &d.pipe) + stateSinkObject.Save(33, &d.locks) + stateSinkObject.Save(34, &d.watches) +} + +// +checklocksignore +func (d *dentry) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &d.vfsd) + stateSourceObject.Load(1, &d.refs) + stateSourceObject.Load(2, &d.fs) + stateSourceObject.Load(3, &d.parent) + stateSourceObject.Load(4, &d.name) + stateSourceObject.Load(5, &d.qidPath) + stateSourceObject.Load(6, &d.deleted) + stateSourceObject.Load(7, &d.cached) + stateSourceObject.Load(8, &d.dentryEntry) + stateSourceObject.Load(9, &d.children) + stateSourceObject.Load(10, &d.syntheticChildren) + stateSourceObject.Load(11, &d.dirents) + stateSourceObject.Load(12, &d.ino) + stateSourceObject.Load(13, &d.mode) + stateSourceObject.Load(14, &d.uid) + stateSourceObject.Load(15, &d.gid) + stateSourceObject.Load(16, &d.blockSize) + stateSourceObject.Load(17, &d.atime) + stateSourceObject.Load(18, &d.mtime) + stateSourceObject.Load(19, &d.ctime) + stateSourceObject.Load(20, &d.btime) + stateSourceObject.Load(21, &d.size) + stateSourceObject.Load(22, &d.atimeDirty) + stateSourceObject.Load(23, &d.mtimeDirty) + stateSourceObject.Load(24, &d.nlink) + stateSourceObject.Load(25, &d.mappings) + stateSourceObject.Load(26, &d.cache) + stateSourceObject.Load(27, &d.dirty) + stateSourceObject.Load(28, &d.pf) + stateSourceObject.Load(29, &d.haveTarget) + stateSourceObject.Load(30, &d.target) + stateSourceObject.Load(31, &d.endpoint) + stateSourceObject.Load(32, &d.pipe) + stateSourceObject.Load(33, &d.locks) + stateSourceObject.Load(34, &d.watches) + stateSourceObject.AfterLoad(d.afterLoad) +} + +func (fd *fileDescription) StateTypeName() string { + return "pkg/sentry/fsimpl/gofer.fileDescription" +} + +func (fd *fileDescription) StateFields() []string { + return []string{ + "vfsfd", + "FileDescriptionDefaultImpl", + "LockFD", + } +} + +func (fd *fileDescription) beforeSave() {} + +// +checklocksignore +func (fd *fileDescription) StateSave(stateSinkObject state.Sink) { + fd.beforeSave() + stateSinkObject.Save(0, &fd.vfsfd) + stateSinkObject.Save(1, &fd.FileDescriptionDefaultImpl) + stateSinkObject.Save(2, &fd.LockFD) +} + +func (fd *fileDescription) afterLoad() {} + +// +checklocksignore +func (fd *fileDescription) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &fd.vfsfd) + stateSourceObject.Load(1, &fd.FileDescriptionDefaultImpl) + stateSourceObject.Load(2, &fd.LockFD) +} + +func (fd *regularFileFD) StateTypeName() string { + return "pkg/sentry/fsimpl/gofer.regularFileFD" +} + +func (fd *regularFileFD) StateFields() []string { + return []string{ + "fileDescription", + "off", + } +} + +func (fd *regularFileFD) beforeSave() {} + +// +checklocksignore +func (fd *regularFileFD) StateSave(stateSinkObject state.Sink) { + fd.beforeSave() + stateSinkObject.Save(0, &fd.fileDescription) + stateSinkObject.Save(1, &fd.off) +} + +func (fd *regularFileFD) afterLoad() {} + +// +checklocksignore +func (fd *regularFileFD) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &fd.fileDescription) + stateSourceObject.Load(1, &fd.off) +} + +func (d *dentryPlatformFile) StateTypeName() string { + return "pkg/sentry/fsimpl/gofer.dentryPlatformFile" +} + +func (d *dentryPlatformFile) StateFields() []string { + return []string{ + "dentry", + "fdRefs", + "hostFileMapper", + } +} + +func (d *dentryPlatformFile) beforeSave() {} + +// +checklocksignore +func (d *dentryPlatformFile) StateSave(stateSinkObject state.Sink) { + d.beforeSave() + stateSinkObject.Save(0, &d.dentry) + stateSinkObject.Save(1, &d.fdRefs) + stateSinkObject.Save(2, &d.hostFileMapper) +} + +// +checklocksignore +func (d *dentryPlatformFile) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &d.dentry) + stateSourceObject.Load(1, &d.fdRefs) + stateSourceObject.Load(2, &d.hostFileMapper) + stateSourceObject.AfterLoad(d.afterLoad) +} + +func (s *savedDentryRW) StateTypeName() string { + return "pkg/sentry/fsimpl/gofer.savedDentryRW" +} + +func (s *savedDentryRW) StateFields() []string { + return []string{ + "read", + "write", + } +} + +func (s *savedDentryRW) beforeSave() {} + +// +checklocksignore +func (s *savedDentryRW) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + stateSinkObject.Save(0, &s.read) + stateSinkObject.Save(1, &s.write) +} + +func (s *savedDentryRW) afterLoad() {} + +// +checklocksignore +func (s *savedDentryRW) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.read) + stateSourceObject.Load(1, &s.write) +} + +func (e *endpoint) StateTypeName() string { + return "pkg/sentry/fsimpl/gofer.endpoint" +} + +func (e *endpoint) StateFields() []string { + return []string{ + "dentry", + "path", + } +} + +func (e *endpoint) beforeSave() {} + +// +checklocksignore +func (e *endpoint) StateSave(stateSinkObject state.Sink) { + e.beforeSave() + stateSinkObject.Save(0, &e.dentry) + stateSinkObject.Save(1, &e.path) +} + +func (e *endpoint) afterLoad() {} + +// +checklocksignore +func (e *endpoint) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &e.dentry) + stateSourceObject.Load(1, &e.path) +} + +func (fd *specialFileFD) StateTypeName() string { + return "pkg/sentry/fsimpl/gofer.specialFileFD" +} + +func (fd *specialFileFD) StateFields() []string { + return []string{ + "fileDescription", + "isRegularFile", + "seekable", + "queue", + "off", + "haveBuf", + "buf", + } +} + +func (fd *specialFileFD) beforeSave() {} + +// +checklocksignore +func (fd *specialFileFD) StateSave(stateSinkObject state.Sink) { + fd.beforeSave() + stateSinkObject.Save(0, &fd.fileDescription) + stateSinkObject.Save(1, &fd.isRegularFile) + stateSinkObject.Save(2, &fd.seekable) + stateSinkObject.Save(3, &fd.queue) + stateSinkObject.Save(4, &fd.off) + stateSinkObject.Save(5, &fd.haveBuf) + stateSinkObject.Save(6, &fd.buf) +} + +// +checklocksignore +func (fd *specialFileFD) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &fd.fileDescription) + stateSourceObject.Load(1, &fd.isRegularFile) + stateSourceObject.Load(2, &fd.seekable) + stateSourceObject.Load(3, &fd.queue) + stateSourceObject.Load(4, &fd.off) + stateSourceObject.Load(5, &fd.haveBuf) + stateSourceObject.Load(6, &fd.buf) + stateSourceObject.AfterLoad(fd.afterLoad) +} + +func init() { + state.Register((*dentryList)(nil)) + state.Register((*dentryEntry)(nil)) + state.Register((*directoryFD)(nil)) + state.Register((*FilesystemType)(nil)) + state.Register((*filesystem)(nil)) + state.Register((*filesystemOptions)(nil)) + state.Register((*InteropMode)(nil)) + state.Register((*InternalFilesystemOptions)(nil)) + state.Register((*dentry)(nil)) + state.Register((*fileDescription)(nil)) + state.Register((*regularFileFD)(nil)) + state.Register((*dentryPlatformFile)(nil)) + state.Register((*savedDentryRW)(nil)) + state.Register((*endpoint)(nil)) + state.Register((*specialFileFD)(nil)) +} diff --git a/pkg/sentry/fsimpl/gofer/gofer_test.go b/pkg/sentry/fsimpl/gofer/gofer_test.go deleted file mode 100644 index 806392d50..000000000 --- a/pkg/sentry/fsimpl/gofer/gofer_test.go +++ /dev/null @@ -1,68 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package gofer - -import ( - "sync/atomic" - "testing" - - "gvisor.dev/gvisor/pkg/p9" - "gvisor.dev/gvisor/pkg/sentry/contexttest" - "gvisor.dev/gvisor/pkg/sentry/pgalloc" -) - -func TestDestroyIdempotent(t *testing.T) { - ctx := contexttest.Context(t) - fs := filesystem{ - mfp: pgalloc.MemoryFileProviderFromContext(ctx), - opts: filesystemOptions{ - // Test relies on no dentry being held in the cache. - maxCachedDentries: 0, - }, - syncableDentries: make(map[*dentry]struct{}), - inoByQIDPath: make(map[uint64]uint64), - } - - attr := &p9.Attr{ - Mode: p9.ModeRegular, - } - mask := p9.AttrMask{ - Mode: true, - Size: true, - } - parent, err := fs.newDentry(ctx, p9file{}, p9.QID{}, mask, attr) - if err != nil { - t.Fatalf("fs.newDentry(): %v", err) - } - - child, err := fs.newDentry(ctx, p9file{}, p9.QID{}, mask, attr) - if err != nil { - t.Fatalf("fs.newDentry(): %v", err) - } - parent.cacheNewChildLocked(child, "child") - - fs.renameMu.Lock() - defer fs.renameMu.Unlock() - child.checkCachingLocked(ctx, true /* renameMuWriteLocked */) - if got := atomic.LoadInt64(&child.refs); got != -1 { - t.Fatalf("child.refs=%d, want: -1", got) - } - // Parent will also be destroyed when child reference is removed. - if got := atomic.LoadInt64(&parent.refs); got != -1 { - t.Fatalf("parent.refs=%d, want: -1", got) - } - child.checkCachingLocked(ctx, true /* renameMuWriteLocked */) - child.checkCachingLocked(ctx, true /* renameMuWriteLocked */) -} diff --git a/pkg/sentry/fsimpl/host/BUILD b/pkg/sentry/fsimpl/host/BUILD deleted file mode 100644 index 180a35583..000000000 --- a/pkg/sentry/fsimpl/host/BUILD +++ /dev/null @@ -1,79 +0,0 @@ -load("//tools:defs.bzl", "go_library") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -licenses(["notice"]) - -go_template_instance( - name = "inode_refs", - out = "inode_refs.go", - package = "host", - prefix = "inode", - template = "//pkg/refsvfs2:refs_template", - types = { - "T": "inode", - }, -) - -go_template_instance( - name = "connected_endpoint_refs", - out = "connected_endpoint_refs.go", - package = "host", - prefix = "ConnectedEndpoint", - template = "//pkg/refsvfs2:refs_template", - types = { - "T": "ConnectedEndpoint", - }, -) - -go_library( - name = "host", - srcs = [ - "connected_endpoint_refs.go", - "control.go", - "host.go", - "inode_refs.go", - "ioctl_unsafe.go", - "save_restore.go", - "socket.go", - "socket_iovec.go", - "socket_unsafe.go", - "tty.go", - "util.go", - "util_unsafe.go", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/errors/linuxerr", - "//pkg/fdnotifier", - "//pkg/fspath", - "//pkg/hostarch", - "//pkg/log", - "//pkg/marshal/primitive", - "//pkg/refs", - "//pkg/refsvfs2", - "//pkg/safemem", - "//pkg/sentry/arch", - "//pkg/sentry/fs/fsutil", - "//pkg/sentry/fs/lock", - "//pkg/sentry/fsimpl/kernfs", - "//pkg/sentry/hostfd", - "//pkg/sentry/kernel", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/memmap", - "//pkg/sentry/socket/control", - "//pkg/sentry/socket/unix", - "//pkg/sentry/socket/unix/transport", - "//pkg/sentry/unimpl", - "//pkg/sentry/uniqueid", - "//pkg/sentry/vfs", - "//pkg/sync", - "//pkg/syserr", - "//pkg/tcpip", - "//pkg/unet", - "//pkg/usermem", - "//pkg/waiter", - "@org_golang_x_sys//unix:go_default_library", - ], -) diff --git a/pkg/sentry/fsimpl/host/connected_endpoint_refs.go b/pkg/sentry/fsimpl/host/connected_endpoint_refs.go new file mode 100644 index 000000000..c0a87f656 --- /dev/null +++ b/pkg/sentry/fsimpl/host/connected_endpoint_refs.go @@ -0,0 +1,140 @@ +package host + +import ( + "fmt" + "sync/atomic" + + "gvisor.dev/gvisor/pkg/refsvfs2" +) + +// enableLogging indicates whether reference-related events should be logged (with +// stack traces). This is false by default and should only be set to true for +// debugging purposes, as it can generate an extremely large amount of output +// and drastically degrade performance. +const ConnectedEndpointenableLogging = false + +// obj is used to customize logging. Note that we use a pointer to T so that +// we do not copy the entire object when passed as a format parameter. +var ConnectedEndpointobj *ConnectedEndpoint + +// Refs implements refs.RefCounter. It keeps a reference count using atomic +// operations and calls the destructor when the count reaches zero. +// +// NOTE: Do not introduce additional fields to the Refs struct. It is used by +// many filesystem objects, and we want to keep it as small as possible (i.e., +// the same size as using an int64 directly) to avoid taking up extra cache +// space. In general, this template should not be extended at the cost of +// performance. If it does not offer enough flexibility for a particular object +// (example: b/187877947), we should implement the RefCounter/CheckedObject +// interfaces manually. +// +// +stateify savable +type ConnectedEndpointRefs struct { + // refCount is composed of two fields: + // + // [32-bit speculative references]:[32-bit real references] + // + // Speculative references are used for TryIncRef, to avoid a CompareAndSwap + // loop. See IncRef, DecRef and TryIncRef for details of how these fields are + // used. + refCount int64 +} + +// InitRefs initializes r with one reference and, if enabled, activates leak +// checking. +func (r *ConnectedEndpointRefs) InitRefs() { + atomic.StoreInt64(&r.refCount, 1) + refsvfs2.Register(r) +} + +// RefType implements refsvfs2.CheckedObject.RefType. +func (r *ConnectedEndpointRefs) RefType() string { + return fmt.Sprintf("%T", ConnectedEndpointobj)[1:] +} + +// LeakMessage implements refsvfs2.CheckedObject.LeakMessage. +func (r *ConnectedEndpointRefs) LeakMessage() string { + return fmt.Sprintf("[%s %p] reference count of %d instead of 0", r.RefType(), r, r.ReadRefs()) +} + +// LogRefs implements refsvfs2.CheckedObject.LogRefs. +func (r *ConnectedEndpointRefs) LogRefs() bool { + return ConnectedEndpointenableLogging +} + +// ReadRefs returns the current number of references. The returned count is +// inherently racy and is unsafe to use without external synchronization. +func (r *ConnectedEndpointRefs) ReadRefs() int64 { + return atomic.LoadInt64(&r.refCount) +} + +// IncRef implements refs.RefCounter.IncRef. +// +//go:nosplit +func (r *ConnectedEndpointRefs) IncRef() { + v := atomic.AddInt64(&r.refCount, 1) + if ConnectedEndpointenableLogging { + refsvfs2.LogIncRef(r, v) + } + if v <= 1 { + panic(fmt.Sprintf("Incrementing non-positive count %p on %s", r, r.RefType())) + } +} + +// TryIncRef implements refs.TryRefCounter.TryIncRef. +// +// To do this safely without a loop, a speculative reference is first acquired +// on the object. This allows multiple concurrent TryIncRef calls to distinguish +// other TryIncRef calls from genuine references held. +// +//go:nosplit +func (r *ConnectedEndpointRefs) TryIncRef() bool { + const speculativeRef = 1 << 32 + if v := atomic.AddInt64(&r.refCount, speculativeRef); int32(v) == 0 { + + atomic.AddInt64(&r.refCount, -speculativeRef) + return false + } + + v := atomic.AddInt64(&r.refCount, -speculativeRef+1) + if ConnectedEndpointenableLogging { + refsvfs2.LogTryIncRef(r, v) + } + return true +} + +// DecRef implements refs.RefCounter.DecRef. +// +// Note that speculative references are counted here. Since they were added +// prior to real references reaching zero, they will successfully convert to +// real references. In other words, we see speculative references only in the +// following case: +// +// A: TryIncRef [speculative increase => sees non-negative references] +// B: DecRef [real decrease] +// A: TryIncRef [transform speculative to real] +// +//go:nosplit +func (r *ConnectedEndpointRefs) DecRef(destroy func()) { + v := atomic.AddInt64(&r.refCount, -1) + if ConnectedEndpointenableLogging { + refsvfs2.LogDecRef(r, v) + } + switch { + case v < 0: + panic(fmt.Sprintf("Decrementing non-positive ref count %p, owned by %s", r, r.RefType())) + + case v == 0: + refsvfs2.Unregister(r) + + if destroy != nil { + destroy() + } + } +} + +func (r *ConnectedEndpointRefs) afterLoad() { + if r.ReadRefs() > 0 { + refsvfs2.Register(r) + } +} diff --git a/pkg/sentry/fsimpl/host/host_state_autogen.go b/pkg/sentry/fsimpl/host/host_state_autogen.go new file mode 100644 index 000000000..607474165 --- /dev/null +++ b/pkg/sentry/fsimpl/host/host_state_autogen.go @@ -0,0 +1,327 @@ +// automatically generated by stateify. + +package host + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (r *ConnectedEndpointRefs) StateTypeName() string { + return "pkg/sentry/fsimpl/host.ConnectedEndpointRefs" +} + +func (r *ConnectedEndpointRefs) StateFields() []string { + return []string{ + "refCount", + } +} + +func (r *ConnectedEndpointRefs) beforeSave() {} + +// +checklocksignore +func (r *ConnectedEndpointRefs) StateSave(stateSinkObject state.Sink) { + r.beforeSave() + stateSinkObject.Save(0, &r.refCount) +} + +// +checklocksignore +func (r *ConnectedEndpointRefs) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &r.refCount) + stateSourceObject.AfterLoad(r.afterLoad) +} + +func (v *virtualOwner) StateTypeName() string { + return "pkg/sentry/fsimpl/host.virtualOwner" +} + +func (v *virtualOwner) StateFields() []string { + return []string{ + "enabled", + "uid", + "gid", + "mode", + } +} + +func (v *virtualOwner) beforeSave() {} + +// +checklocksignore +func (v *virtualOwner) StateSave(stateSinkObject state.Sink) { + v.beforeSave() + stateSinkObject.Save(0, &v.enabled) + stateSinkObject.Save(1, &v.uid) + stateSinkObject.Save(2, &v.gid) + stateSinkObject.Save(3, &v.mode) +} + +func (v *virtualOwner) afterLoad() {} + +// +checklocksignore +func (v *virtualOwner) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &v.enabled) + stateSourceObject.Load(1, &v.uid) + stateSourceObject.Load(2, &v.gid) + stateSourceObject.Load(3, &v.mode) +} + +func (i *inode) StateTypeName() string { + return "pkg/sentry/fsimpl/host.inode" +} + +func (i *inode) StateFields() []string { + return []string{ + "InodeNoStatFS", + "InodeNotDirectory", + "InodeNotSymlink", + "CachedMappable", + "InodeTemporary", + "locks", + "inodeRefs", + "hostFD", + "ino", + "ftype", + "mayBlock", + "seekable", + "isTTY", + "savable", + "queue", + "virtualOwner", + "haveBuf", + "buf", + } +} + +// +checklocksignore +func (i *inode) StateSave(stateSinkObject state.Sink) { + i.beforeSave() + stateSinkObject.Save(0, &i.InodeNoStatFS) + stateSinkObject.Save(1, &i.InodeNotDirectory) + stateSinkObject.Save(2, &i.InodeNotSymlink) + stateSinkObject.Save(3, &i.CachedMappable) + stateSinkObject.Save(4, &i.InodeTemporary) + stateSinkObject.Save(5, &i.locks) + stateSinkObject.Save(6, &i.inodeRefs) + stateSinkObject.Save(7, &i.hostFD) + stateSinkObject.Save(8, &i.ino) + stateSinkObject.Save(9, &i.ftype) + stateSinkObject.Save(10, &i.mayBlock) + stateSinkObject.Save(11, &i.seekable) + stateSinkObject.Save(12, &i.isTTY) + stateSinkObject.Save(13, &i.savable) + stateSinkObject.Save(14, &i.queue) + stateSinkObject.Save(15, &i.virtualOwner) + stateSinkObject.Save(16, &i.haveBuf) + stateSinkObject.Save(17, &i.buf) +} + +// +checklocksignore +func (i *inode) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &i.InodeNoStatFS) + stateSourceObject.Load(1, &i.InodeNotDirectory) + stateSourceObject.Load(2, &i.InodeNotSymlink) + stateSourceObject.Load(3, &i.CachedMappable) + stateSourceObject.Load(4, &i.InodeTemporary) + stateSourceObject.Load(5, &i.locks) + stateSourceObject.Load(6, &i.inodeRefs) + stateSourceObject.Load(7, &i.hostFD) + stateSourceObject.Load(8, &i.ino) + stateSourceObject.Load(9, &i.ftype) + stateSourceObject.Load(10, &i.mayBlock) + stateSourceObject.Load(11, &i.seekable) + stateSourceObject.Load(12, &i.isTTY) + stateSourceObject.Load(13, &i.savable) + stateSourceObject.Load(14, &i.queue) + stateSourceObject.Load(15, &i.virtualOwner) + stateSourceObject.Load(16, &i.haveBuf) + stateSourceObject.Load(17, &i.buf) + stateSourceObject.AfterLoad(i.afterLoad) +} + +func (f *filesystemType) StateTypeName() string { + return "pkg/sentry/fsimpl/host.filesystemType" +} + +func (f *filesystemType) StateFields() []string { + return []string{} +} + +func (f *filesystemType) beforeSave() {} + +// +checklocksignore +func (f *filesystemType) StateSave(stateSinkObject state.Sink) { + f.beforeSave() +} + +func (f *filesystemType) afterLoad() {} + +// +checklocksignore +func (f *filesystemType) StateLoad(stateSourceObject state.Source) { +} + +func (fs *filesystem) StateTypeName() string { + return "pkg/sentry/fsimpl/host.filesystem" +} + +func (fs *filesystem) StateFields() []string { + return []string{ + "Filesystem", + "devMinor", + } +} + +func (fs *filesystem) beforeSave() {} + +// +checklocksignore +func (fs *filesystem) StateSave(stateSinkObject state.Sink) { + fs.beforeSave() + stateSinkObject.Save(0, &fs.Filesystem) + stateSinkObject.Save(1, &fs.devMinor) +} + +func (fs *filesystem) afterLoad() {} + +// +checklocksignore +func (fs *filesystem) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &fs.Filesystem) + stateSourceObject.Load(1, &fs.devMinor) +} + +func (f *fileDescription) StateTypeName() string { + return "pkg/sentry/fsimpl/host.fileDescription" +} + +func (f *fileDescription) StateFields() []string { + return []string{ + "vfsfd", + "FileDescriptionDefaultImpl", + "LockFD", + "inode", + "offset", + } +} + +func (f *fileDescription) beforeSave() {} + +// +checklocksignore +func (f *fileDescription) StateSave(stateSinkObject state.Sink) { + f.beforeSave() + stateSinkObject.Save(0, &f.vfsfd) + stateSinkObject.Save(1, &f.FileDescriptionDefaultImpl) + stateSinkObject.Save(2, &f.LockFD) + stateSinkObject.Save(3, &f.inode) + stateSinkObject.Save(4, &f.offset) +} + +func (f *fileDescription) afterLoad() {} + +// +checklocksignore +func (f *fileDescription) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &f.vfsfd) + stateSourceObject.Load(1, &f.FileDescriptionDefaultImpl) + stateSourceObject.Load(2, &f.LockFD) + stateSourceObject.Load(3, &f.inode) + stateSourceObject.Load(4, &f.offset) +} + +func (r *inodeRefs) StateTypeName() string { + return "pkg/sentry/fsimpl/host.inodeRefs" +} + +func (r *inodeRefs) StateFields() []string { + return []string{ + "refCount", + } +} + +func (r *inodeRefs) beforeSave() {} + +// +checklocksignore +func (r *inodeRefs) StateSave(stateSinkObject state.Sink) { + r.beforeSave() + stateSinkObject.Save(0, &r.refCount) +} + +// +checklocksignore +func (r *inodeRefs) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &r.refCount) + stateSourceObject.AfterLoad(r.afterLoad) +} + +func (c *ConnectedEndpoint) StateTypeName() string { + return "pkg/sentry/fsimpl/host.ConnectedEndpoint" +} + +func (c *ConnectedEndpoint) StateFields() []string { + return []string{ + "ConnectedEndpointRefs", + "fd", + "addr", + "stype", + } +} + +func (c *ConnectedEndpoint) beforeSave() {} + +// +checklocksignore +func (c *ConnectedEndpoint) StateSave(stateSinkObject state.Sink) { + c.beforeSave() + stateSinkObject.Save(0, &c.ConnectedEndpointRefs) + stateSinkObject.Save(1, &c.fd) + stateSinkObject.Save(2, &c.addr) + stateSinkObject.Save(3, &c.stype) +} + +// +checklocksignore +func (c *ConnectedEndpoint) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &c.ConnectedEndpointRefs) + stateSourceObject.Load(1, &c.fd) + stateSourceObject.Load(2, &c.addr) + stateSourceObject.Load(3, &c.stype) + stateSourceObject.AfterLoad(c.afterLoad) +} + +func (t *TTYFileDescription) StateTypeName() string { + return "pkg/sentry/fsimpl/host.TTYFileDescription" +} + +func (t *TTYFileDescription) StateFields() []string { + return []string{ + "fileDescription", + "session", + "fgProcessGroup", + "termios", + } +} + +func (t *TTYFileDescription) beforeSave() {} + +// +checklocksignore +func (t *TTYFileDescription) StateSave(stateSinkObject state.Sink) { + t.beforeSave() + stateSinkObject.Save(0, &t.fileDescription) + stateSinkObject.Save(1, &t.session) + stateSinkObject.Save(2, &t.fgProcessGroup) + stateSinkObject.Save(3, &t.termios) +} + +func (t *TTYFileDescription) afterLoad() {} + +// +checklocksignore +func (t *TTYFileDescription) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &t.fileDescription) + stateSourceObject.Load(1, &t.session) + stateSourceObject.Load(2, &t.fgProcessGroup) + stateSourceObject.Load(3, &t.termios) +} + +func init() { + state.Register((*ConnectedEndpointRefs)(nil)) + state.Register((*virtualOwner)(nil)) + state.Register((*inode)(nil)) + state.Register((*filesystemType)(nil)) + state.Register((*filesystem)(nil)) + state.Register((*fileDescription)(nil)) + state.Register((*inodeRefs)(nil)) + state.Register((*ConnectedEndpoint)(nil)) + state.Register((*TTYFileDescription)(nil)) +} diff --git a/pkg/sentry/fsimpl/host/host_unsafe_state_autogen.go b/pkg/sentry/fsimpl/host/host_unsafe_state_autogen.go new file mode 100644 index 000000000..b2d8c661f --- /dev/null +++ b/pkg/sentry/fsimpl/host/host_unsafe_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package host diff --git a/pkg/sentry/fsimpl/host/inode_refs.go b/pkg/sentry/fsimpl/host/inode_refs.go new file mode 100644 index 000000000..112f39850 --- /dev/null +++ b/pkg/sentry/fsimpl/host/inode_refs.go @@ -0,0 +1,140 @@ +package host + +import ( + "fmt" + "sync/atomic" + + "gvisor.dev/gvisor/pkg/refsvfs2" +) + +// enableLogging indicates whether reference-related events should be logged (with +// stack traces). This is false by default and should only be set to true for +// debugging purposes, as it can generate an extremely large amount of output +// and drastically degrade performance. +const inodeenableLogging = false + +// obj is used to customize logging. Note that we use a pointer to T so that +// we do not copy the entire object when passed as a format parameter. +var inodeobj *inode + +// Refs implements refs.RefCounter. It keeps a reference count using atomic +// operations and calls the destructor when the count reaches zero. +// +// NOTE: Do not introduce additional fields to the Refs struct. It is used by +// many filesystem objects, and we want to keep it as small as possible (i.e., +// the same size as using an int64 directly) to avoid taking up extra cache +// space. In general, this template should not be extended at the cost of +// performance. If it does not offer enough flexibility for a particular object +// (example: b/187877947), we should implement the RefCounter/CheckedObject +// interfaces manually. +// +// +stateify savable +type inodeRefs struct { + // refCount is composed of two fields: + // + // [32-bit speculative references]:[32-bit real references] + // + // Speculative references are used for TryIncRef, to avoid a CompareAndSwap + // loop. See IncRef, DecRef and TryIncRef for details of how these fields are + // used. + refCount int64 +} + +// InitRefs initializes r with one reference and, if enabled, activates leak +// checking. +func (r *inodeRefs) InitRefs() { + atomic.StoreInt64(&r.refCount, 1) + refsvfs2.Register(r) +} + +// RefType implements refsvfs2.CheckedObject.RefType. +func (r *inodeRefs) RefType() string { + return fmt.Sprintf("%T", inodeobj)[1:] +} + +// LeakMessage implements refsvfs2.CheckedObject.LeakMessage. +func (r *inodeRefs) LeakMessage() string { + return fmt.Sprintf("[%s %p] reference count of %d instead of 0", r.RefType(), r, r.ReadRefs()) +} + +// LogRefs implements refsvfs2.CheckedObject.LogRefs. +func (r *inodeRefs) LogRefs() bool { + return inodeenableLogging +} + +// ReadRefs returns the current number of references. The returned count is +// inherently racy and is unsafe to use without external synchronization. +func (r *inodeRefs) ReadRefs() int64 { + return atomic.LoadInt64(&r.refCount) +} + +// IncRef implements refs.RefCounter.IncRef. +// +//go:nosplit +func (r *inodeRefs) IncRef() { + v := atomic.AddInt64(&r.refCount, 1) + if inodeenableLogging { + refsvfs2.LogIncRef(r, v) + } + if v <= 1 { + panic(fmt.Sprintf("Incrementing non-positive count %p on %s", r, r.RefType())) + } +} + +// TryIncRef implements refs.TryRefCounter.TryIncRef. +// +// To do this safely without a loop, a speculative reference is first acquired +// on the object. This allows multiple concurrent TryIncRef calls to distinguish +// other TryIncRef calls from genuine references held. +// +//go:nosplit +func (r *inodeRefs) TryIncRef() bool { + const speculativeRef = 1 << 32 + if v := atomic.AddInt64(&r.refCount, speculativeRef); int32(v) == 0 { + + atomic.AddInt64(&r.refCount, -speculativeRef) + return false + } + + v := atomic.AddInt64(&r.refCount, -speculativeRef+1) + if inodeenableLogging { + refsvfs2.LogTryIncRef(r, v) + } + return true +} + +// DecRef implements refs.RefCounter.DecRef. +// +// Note that speculative references are counted here. Since they were added +// prior to real references reaching zero, they will successfully convert to +// real references. In other words, we see speculative references only in the +// following case: +// +// A: TryIncRef [speculative increase => sees non-negative references] +// B: DecRef [real decrease] +// A: TryIncRef [transform speculative to real] +// +//go:nosplit +func (r *inodeRefs) DecRef(destroy func()) { + v := atomic.AddInt64(&r.refCount, -1) + if inodeenableLogging { + refsvfs2.LogDecRef(r, v) + } + switch { + case v < 0: + panic(fmt.Sprintf("Decrementing non-positive ref count %p, owned by %s", r, r.RefType())) + + case v == 0: + refsvfs2.Unregister(r) + + if destroy != nil { + destroy() + } + } +} + +func (r *inodeRefs) afterLoad() { + if r.ReadRefs() > 0 { + refsvfs2.Register(r) + } +} diff --git a/pkg/sentry/fsimpl/kernfs/BUILD b/pkg/sentry/fsimpl/kernfs/BUILD deleted file mode 100644 index 4b577ea43..000000000 --- a/pkg/sentry/fsimpl/kernfs/BUILD +++ /dev/null @@ -1,150 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -licenses(["notice"]) - -go_template_instance( - name = "dentry_list", - out = "dentry_list.go", - package = "kernfs", - prefix = "dentry", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*Dentry", - "Linker": "*Dentry", - }, -) - -go_template_instance( - name = "fstree", - out = "fstree.go", - package = "kernfs", - prefix = "generic", - template = "//pkg/sentry/vfs/genericfstree:generic_fstree", - types = { - "Dentry": "Dentry", - }, -) - -go_template_instance( - name = "slot_list", - out = "slot_list.go", - package = "kernfs", - prefix = "slot", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*slot", - "Linker": "*slot", - }, -) - -go_template_instance( - name = "static_directory_refs", - out = "static_directory_refs.go", - package = "kernfs", - prefix = "StaticDirectory", - template = "//pkg/refsvfs2:refs_template", - types = { - "T": "StaticDirectory", - }, -) - -go_template_instance( - name = "dir_refs", - out = "dir_refs.go", - package = "kernfs_test", - prefix = "dir", - template = "//pkg/refsvfs2:refs_template", - types = { - "T": "dir", - }, -) - -go_template_instance( - name = "readonly_dir_refs", - out = "readonly_dir_refs.go", - package = "kernfs_test", - prefix = "readonlyDir", - template = "//pkg/refsvfs2:refs_template", - types = { - "T": "readonlyDir", - }, -) - -go_template_instance( - name = "synthetic_directory_refs", - out = "synthetic_directory_refs.go", - package = "kernfs", - prefix = "syntheticDirectory", - template = "//pkg/refsvfs2:refs_template", - types = { - "T": "syntheticDirectory", - }, -) - -go_library( - name = "kernfs", - srcs = [ - "dentry_list.go", - "dynamic_bytes_file.go", - "fd_impl_util.go", - "filesystem.go", - "fstree.go", - "inode_impl_util.go", - "kernfs.go", - "mmap_util.go", - "save_restore.go", - "slot_list.go", - "static_directory_refs.go", - "symlink.go", - "synthetic_directory.go", - "synthetic_directory_refs.go", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/errors/linuxerr", - "//pkg/fspath", - "//pkg/hostarch", - "//pkg/log", - "//pkg/refs", - "//pkg/refsvfs2", - "//pkg/safemem", - "//pkg/sentry/fs/fsutil", - "//pkg/sentry/fs/lock", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/kernel/time", - "//pkg/sentry/memmap", - "//pkg/sentry/socket/unix/transport", - "//pkg/sentry/vfs", - "//pkg/sync", - "//pkg/usermem", - ], -) - -go_test( - name = "kernfs_test", - size = "small", - srcs = [ - "dir_refs.go", - "kernfs_test.go", - "readonly_dir_refs.go", - ], - deps = [ - ":kernfs", - "//pkg/abi/linux", - "//pkg/context", - "//pkg/errors/linuxerr", - "//pkg/fspath", - "//pkg/log", - "//pkg/refs", - "//pkg/refsvfs2", - "//pkg/sentry/contexttest", - "//pkg/sentry/fsimpl/testutil", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/vfs", - "//pkg/usermem", - "@com_github_google_go_cmp//cmp:go_default_library", - ], -) diff --git a/pkg/sentry/fsimpl/kernfs/dentry_list.go b/pkg/sentry/fsimpl/kernfs/dentry_list.go new file mode 100644 index 000000000..e73cde1f1 --- /dev/null +++ b/pkg/sentry/fsimpl/kernfs/dentry_list.go @@ -0,0 +1,221 @@ +package kernfs + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type dentryElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (dentryElementMapper) linkerFor(elem *Dentry) *Dentry { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type dentryList struct { + head *Dentry + tail *Dentry +} + +// Reset resets list l to the empty state. +func (l *dentryList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +// +//go:nosplit +func (l *dentryList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +// +//go:nosplit +func (l *dentryList) Front() *Dentry { + return l.head +} + +// Back returns the last element of list l or nil. +// +//go:nosplit +func (l *dentryList) Back() *Dentry { + return l.tail +} + +// Len returns the number of elements in the list. +// +// NOTE: This is an O(n) operation. +// +//go:nosplit +func (l *dentryList) Len() (count int) { + for e := l.Front(); e != nil; e = (dentryElementMapper{}.linkerFor(e)).Next() { + count++ + } + return count +} + +// PushFront inserts the element e at the front of list l. +// +//go:nosplit +func (l *dentryList) PushFront(e *Dentry) { + linker := dentryElementMapper{}.linkerFor(e) + linker.SetNext(l.head) + linker.SetPrev(nil) + if l.head != nil { + dentryElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +// +//go:nosplit +func (l *dentryList) PushBack(e *Dentry) { + linker := dentryElementMapper{}.linkerFor(e) + linker.SetNext(nil) + linker.SetPrev(l.tail) + if l.tail != nil { + dentryElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +// +//go:nosplit +func (l *dentryList) PushBackList(m *dentryList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + dentryElementMapper{}.linkerFor(l.tail).SetNext(m.head) + dentryElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +// +//go:nosplit +func (l *dentryList) InsertAfter(b, e *Dentry) { + bLinker := dentryElementMapper{}.linkerFor(b) + eLinker := dentryElementMapper{}.linkerFor(e) + + a := bLinker.Next() + + eLinker.SetNext(a) + eLinker.SetPrev(b) + bLinker.SetNext(e) + + if a != nil { + dentryElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +// +//go:nosplit +func (l *dentryList) InsertBefore(a, e *Dentry) { + aLinker := dentryElementMapper{}.linkerFor(a) + eLinker := dentryElementMapper{}.linkerFor(e) + + b := aLinker.Prev() + eLinker.SetNext(a) + eLinker.SetPrev(b) + aLinker.SetPrev(e) + + if b != nil { + dentryElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +// +//go:nosplit +func (l *dentryList) Remove(e *Dentry) { + linker := dentryElementMapper{}.linkerFor(e) + prev := linker.Prev() + next := linker.Next() + + if prev != nil { + dentryElementMapper{}.linkerFor(prev).SetNext(next) + } else if l.head == e { + l.head = next + } + + if next != nil { + dentryElementMapper{}.linkerFor(next).SetPrev(prev) + } else if l.tail == e { + l.tail = prev + } + + linker.SetNext(nil) + linker.SetPrev(nil) +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type dentryEntry struct { + next *Dentry + prev *Dentry +} + +// Next returns the entry that follows e in the list. +// +//go:nosplit +func (e *dentryEntry) Next() *Dentry { + return e.next +} + +// Prev returns the entry that precedes e in the list. +// +//go:nosplit +func (e *dentryEntry) Prev() *Dentry { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +// +//go:nosplit +func (e *dentryEntry) SetNext(elem *Dentry) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +// +//go:nosplit +func (e *dentryEntry) SetPrev(elem *Dentry) { + e.prev = elem +} diff --git a/pkg/sentry/fsimpl/kernfs/fstree.go b/pkg/sentry/fsimpl/kernfs/fstree.go new file mode 100644 index 000000000..9dc52dabc --- /dev/null +++ b/pkg/sentry/fsimpl/kernfs/fstree.go @@ -0,0 +1,55 @@ +package kernfs + +import ( + "gvisor.dev/gvisor/pkg/fspath" + "gvisor.dev/gvisor/pkg/sentry/vfs" +) + +// IsAncestorDentry returns true if d is an ancestor of d2; that is, d is +// either d2's parent or an ancestor of d2's parent. +func genericIsAncestorDentry(d, d2 *Dentry) bool { + for d2 != nil { + if d2.parent == d { + return true + } + if d2.parent == d2 { + return false + } + d2 = d2.parent + } + return false +} + +// ParentOrSelf returns d.parent. If d.parent is nil, ParentOrSelf returns d. +func genericParentOrSelf(d *Dentry) *Dentry { + if d.parent != nil { + return d.parent + } + return d +} + +// PrependPath is a generic implementation of FilesystemImpl.PrependPath(). +func genericPrependPath(vfsroot vfs.VirtualDentry, mnt *vfs.Mount, d *Dentry, b *fspath.Builder) error { + for { + if mnt == vfsroot.Mount() && &d.vfsd == vfsroot.Dentry() { + return vfs.PrependPathAtVFSRootError{} + } + if mnt != nil && &d.vfsd == mnt.Root() { + return nil + } + if d.parent == nil { + return vfs.PrependPathAtNonMountRootError{} + } + b.PrependComponent(d.name) + d = d.parent + } +} + +// DebugPathname returns a pathname to d relative to its filesystem root. +// DebugPathname does not correspond to any Linux function; it's used to +// generate dentry pathnames for debugging. +func genericDebugPathname(d *Dentry) string { + var b fspath.Builder + _ = genericPrependPath(vfs.VirtualDentry{}, nil, d, &b) + return b.String() +} diff --git a/pkg/sentry/fsimpl/kernfs/kernfs_state_autogen.go b/pkg/sentry/fsimpl/kernfs/kernfs_state_autogen.go new file mode 100644 index 000000000..f8add23f8 --- /dev/null +++ b/pkg/sentry/fsimpl/kernfs/kernfs_state_autogen.go @@ -0,0 +1,965 @@ +// automatically generated by stateify. + +package kernfs + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (l *dentryList) StateTypeName() string { + return "pkg/sentry/fsimpl/kernfs.dentryList" +} + +func (l *dentryList) StateFields() []string { + return []string{ + "head", + "tail", + } +} + +func (l *dentryList) beforeSave() {} + +// +checklocksignore +func (l *dentryList) StateSave(stateSinkObject state.Sink) { + l.beforeSave() + stateSinkObject.Save(0, &l.head) + stateSinkObject.Save(1, &l.tail) +} + +func (l *dentryList) afterLoad() {} + +// +checklocksignore +func (l *dentryList) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &l.head) + stateSourceObject.Load(1, &l.tail) +} + +func (e *dentryEntry) StateTypeName() string { + return "pkg/sentry/fsimpl/kernfs.dentryEntry" +} + +func (e *dentryEntry) StateFields() []string { + return []string{ + "next", + "prev", + } +} + +func (e *dentryEntry) beforeSave() {} + +// +checklocksignore +func (e *dentryEntry) StateSave(stateSinkObject state.Sink) { + e.beforeSave() + stateSinkObject.Save(0, &e.next) + stateSinkObject.Save(1, &e.prev) +} + +func (e *dentryEntry) afterLoad() {} + +// +checklocksignore +func (e *dentryEntry) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &e.next) + stateSourceObject.Load(1, &e.prev) +} + +func (f *DynamicBytesFile) StateTypeName() string { + return "pkg/sentry/fsimpl/kernfs.DynamicBytesFile" +} + +func (f *DynamicBytesFile) StateFields() []string { + return []string{ + "InodeAttrs", + "InodeNoStatFS", + "InodeNoopRefCount", + "InodeNotDirectory", + "InodeNotSymlink", + "locks", + "data", + } +} + +func (f *DynamicBytesFile) beforeSave() {} + +// +checklocksignore +func (f *DynamicBytesFile) StateSave(stateSinkObject state.Sink) { + f.beforeSave() + stateSinkObject.Save(0, &f.InodeAttrs) + stateSinkObject.Save(1, &f.InodeNoStatFS) + stateSinkObject.Save(2, &f.InodeNoopRefCount) + stateSinkObject.Save(3, &f.InodeNotDirectory) + stateSinkObject.Save(4, &f.InodeNotSymlink) + stateSinkObject.Save(5, &f.locks) + stateSinkObject.Save(6, &f.data) +} + +func (f *DynamicBytesFile) afterLoad() {} + +// +checklocksignore +func (f *DynamicBytesFile) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &f.InodeAttrs) + stateSourceObject.Load(1, &f.InodeNoStatFS) + stateSourceObject.Load(2, &f.InodeNoopRefCount) + stateSourceObject.Load(3, &f.InodeNotDirectory) + stateSourceObject.Load(4, &f.InodeNotSymlink) + stateSourceObject.Load(5, &f.locks) + stateSourceObject.Load(6, &f.data) +} + +func (fd *DynamicBytesFD) StateTypeName() string { + return "pkg/sentry/fsimpl/kernfs.DynamicBytesFD" +} + +func (fd *DynamicBytesFD) StateFields() []string { + return []string{ + "FileDescriptionDefaultImpl", + "DynamicBytesFileDescriptionImpl", + "LockFD", + "vfsfd", + "inode", + } +} + +func (fd *DynamicBytesFD) beforeSave() {} + +// +checklocksignore +func (fd *DynamicBytesFD) StateSave(stateSinkObject state.Sink) { + fd.beforeSave() + stateSinkObject.Save(0, &fd.FileDescriptionDefaultImpl) + stateSinkObject.Save(1, &fd.DynamicBytesFileDescriptionImpl) + stateSinkObject.Save(2, &fd.LockFD) + stateSinkObject.Save(3, &fd.vfsfd) + stateSinkObject.Save(4, &fd.inode) +} + +func (fd *DynamicBytesFD) afterLoad() {} + +// +checklocksignore +func (fd *DynamicBytesFD) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &fd.FileDescriptionDefaultImpl) + stateSourceObject.Load(1, &fd.DynamicBytesFileDescriptionImpl) + stateSourceObject.Load(2, &fd.LockFD) + stateSourceObject.Load(3, &fd.vfsfd) + stateSourceObject.Load(4, &fd.inode) +} + +func (s *SeekEndConfig) StateTypeName() string { + return "pkg/sentry/fsimpl/kernfs.SeekEndConfig" +} + +func (s *SeekEndConfig) StateFields() []string { + return nil +} + +func (g *GenericDirectoryFDOptions) StateTypeName() string { + return "pkg/sentry/fsimpl/kernfs.GenericDirectoryFDOptions" +} + +func (g *GenericDirectoryFDOptions) StateFields() []string { + return []string{ + "SeekEnd", + } +} + +func (g *GenericDirectoryFDOptions) beforeSave() {} + +// +checklocksignore +func (g *GenericDirectoryFDOptions) StateSave(stateSinkObject state.Sink) { + g.beforeSave() + stateSinkObject.Save(0, &g.SeekEnd) +} + +func (g *GenericDirectoryFDOptions) afterLoad() {} + +// +checklocksignore +func (g *GenericDirectoryFDOptions) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &g.SeekEnd) +} + +func (fd *GenericDirectoryFD) StateTypeName() string { + return "pkg/sentry/fsimpl/kernfs.GenericDirectoryFD" +} + +func (fd *GenericDirectoryFD) StateFields() []string { + return []string{ + "FileDescriptionDefaultImpl", + "DirectoryFileDescriptionDefaultImpl", + "LockFD", + "seekEnd", + "vfsfd", + "children", + "off", + } +} + +func (fd *GenericDirectoryFD) beforeSave() {} + +// +checklocksignore +func (fd *GenericDirectoryFD) StateSave(stateSinkObject state.Sink) { + fd.beforeSave() + stateSinkObject.Save(0, &fd.FileDescriptionDefaultImpl) + stateSinkObject.Save(1, &fd.DirectoryFileDescriptionDefaultImpl) + stateSinkObject.Save(2, &fd.LockFD) + stateSinkObject.Save(3, &fd.seekEnd) + stateSinkObject.Save(4, &fd.vfsfd) + stateSinkObject.Save(5, &fd.children) + stateSinkObject.Save(6, &fd.off) +} + +func (fd *GenericDirectoryFD) afterLoad() {} + +// +checklocksignore +func (fd *GenericDirectoryFD) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &fd.FileDescriptionDefaultImpl) + stateSourceObject.Load(1, &fd.DirectoryFileDescriptionDefaultImpl) + stateSourceObject.Load(2, &fd.LockFD) + stateSourceObject.Load(3, &fd.seekEnd) + stateSourceObject.Load(4, &fd.vfsfd) + stateSourceObject.Load(5, &fd.children) + stateSourceObject.Load(6, &fd.off) +} + +func (i *InodeNoopRefCount) StateTypeName() string { + return "pkg/sentry/fsimpl/kernfs.InodeNoopRefCount" +} + +func (i *InodeNoopRefCount) StateFields() []string { + return []string{ + "InodeTemporary", + } +} + +func (i *InodeNoopRefCount) beforeSave() {} + +// +checklocksignore +func (i *InodeNoopRefCount) StateSave(stateSinkObject state.Sink) { + i.beforeSave() + stateSinkObject.Save(0, &i.InodeTemporary) +} + +func (i *InodeNoopRefCount) afterLoad() {} + +// +checklocksignore +func (i *InodeNoopRefCount) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &i.InodeTemporary) +} + +func (i *InodeDirectoryNoNewChildren) StateTypeName() string { + return "pkg/sentry/fsimpl/kernfs.InodeDirectoryNoNewChildren" +} + +func (i *InodeDirectoryNoNewChildren) StateFields() []string { + return []string{} +} + +func (i *InodeDirectoryNoNewChildren) beforeSave() {} + +// +checklocksignore +func (i *InodeDirectoryNoNewChildren) StateSave(stateSinkObject state.Sink) { + i.beforeSave() +} + +func (i *InodeDirectoryNoNewChildren) afterLoad() {} + +// +checklocksignore +func (i *InodeDirectoryNoNewChildren) StateLoad(stateSourceObject state.Source) { +} + +func (i *InodeNotDirectory) StateTypeName() string { + return "pkg/sentry/fsimpl/kernfs.InodeNotDirectory" +} + +func (i *InodeNotDirectory) StateFields() []string { + return []string{ + "InodeAlwaysValid", + } +} + +func (i *InodeNotDirectory) beforeSave() {} + +// +checklocksignore +func (i *InodeNotDirectory) StateSave(stateSinkObject state.Sink) { + i.beforeSave() + stateSinkObject.Save(0, &i.InodeAlwaysValid) +} + +func (i *InodeNotDirectory) afterLoad() {} + +// +checklocksignore +func (i *InodeNotDirectory) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &i.InodeAlwaysValid) +} + +func (i *InodeNotSymlink) StateTypeName() string { + return "pkg/sentry/fsimpl/kernfs.InodeNotSymlink" +} + +func (i *InodeNotSymlink) StateFields() []string { + return []string{} +} + +func (i *InodeNotSymlink) beforeSave() {} + +// +checklocksignore +func (i *InodeNotSymlink) StateSave(stateSinkObject state.Sink) { + i.beforeSave() +} + +func (i *InodeNotSymlink) afterLoad() {} + +// +checklocksignore +func (i *InodeNotSymlink) StateLoad(stateSourceObject state.Source) { +} + +func (a *InodeAttrs) StateTypeName() string { + return "pkg/sentry/fsimpl/kernfs.InodeAttrs" +} + +func (a *InodeAttrs) StateFields() []string { + return []string{ + "devMajor", + "devMinor", + "ino", + "mode", + "uid", + "gid", + "nlink", + "blockSize", + "atime", + "mtime", + "ctime", + } +} + +func (a *InodeAttrs) beforeSave() {} + +// +checklocksignore +func (a *InodeAttrs) StateSave(stateSinkObject state.Sink) { + a.beforeSave() + stateSinkObject.Save(0, &a.devMajor) + stateSinkObject.Save(1, &a.devMinor) + stateSinkObject.Save(2, &a.ino) + stateSinkObject.Save(3, &a.mode) + stateSinkObject.Save(4, &a.uid) + stateSinkObject.Save(5, &a.gid) + stateSinkObject.Save(6, &a.nlink) + stateSinkObject.Save(7, &a.blockSize) + stateSinkObject.Save(8, &a.atime) + stateSinkObject.Save(9, &a.mtime) + stateSinkObject.Save(10, &a.ctime) +} + +func (a *InodeAttrs) afterLoad() {} + +// +checklocksignore +func (a *InodeAttrs) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &a.devMajor) + stateSourceObject.Load(1, &a.devMinor) + stateSourceObject.Load(2, &a.ino) + stateSourceObject.Load(3, &a.mode) + stateSourceObject.Load(4, &a.uid) + stateSourceObject.Load(5, &a.gid) + stateSourceObject.Load(6, &a.nlink) + stateSourceObject.Load(7, &a.blockSize) + stateSourceObject.Load(8, &a.atime) + stateSourceObject.Load(9, &a.mtime) + stateSourceObject.Load(10, &a.ctime) +} + +func (s *slot) StateTypeName() string { + return "pkg/sentry/fsimpl/kernfs.slot" +} + +func (s *slot) StateFields() []string { + return []string{ + "name", + "inode", + "static", + "slotEntry", + } +} + +func (s *slot) beforeSave() {} + +// +checklocksignore +func (s *slot) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + stateSinkObject.Save(0, &s.name) + stateSinkObject.Save(1, &s.inode) + stateSinkObject.Save(2, &s.static) + stateSinkObject.Save(3, &s.slotEntry) +} + +func (s *slot) afterLoad() {} + +// +checklocksignore +func (s *slot) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.name) + stateSourceObject.Load(1, &s.inode) + stateSourceObject.Load(2, &s.static) + stateSourceObject.Load(3, &s.slotEntry) +} + +func (o *OrderedChildrenOptions) StateTypeName() string { + return "pkg/sentry/fsimpl/kernfs.OrderedChildrenOptions" +} + +func (o *OrderedChildrenOptions) StateFields() []string { + return []string{ + "Writable", + } +} + +func (o *OrderedChildrenOptions) beforeSave() {} + +// +checklocksignore +func (o *OrderedChildrenOptions) StateSave(stateSinkObject state.Sink) { + o.beforeSave() + stateSinkObject.Save(0, &o.Writable) +} + +func (o *OrderedChildrenOptions) afterLoad() {} + +// +checklocksignore +func (o *OrderedChildrenOptions) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &o.Writable) +} + +func (o *OrderedChildren) StateTypeName() string { + return "pkg/sentry/fsimpl/kernfs.OrderedChildren" +} + +func (o *OrderedChildren) StateFields() []string { + return []string{ + "writable", + "order", + "set", + } +} + +func (o *OrderedChildren) beforeSave() {} + +// +checklocksignore +func (o *OrderedChildren) StateSave(stateSinkObject state.Sink) { + o.beforeSave() + stateSinkObject.Save(0, &o.writable) + stateSinkObject.Save(1, &o.order) + stateSinkObject.Save(2, &o.set) +} + +func (o *OrderedChildren) afterLoad() {} + +// +checklocksignore +func (o *OrderedChildren) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &o.writable) + stateSourceObject.Load(1, &o.order) + stateSourceObject.Load(2, &o.set) +} + +func (i *InodeSymlink) StateTypeName() string { + return "pkg/sentry/fsimpl/kernfs.InodeSymlink" +} + +func (i *InodeSymlink) StateFields() []string { + return []string{ + "InodeNotDirectory", + } +} + +func (i *InodeSymlink) beforeSave() {} + +// +checklocksignore +func (i *InodeSymlink) StateSave(stateSinkObject state.Sink) { + i.beforeSave() + stateSinkObject.Save(0, &i.InodeNotDirectory) +} + +func (i *InodeSymlink) afterLoad() {} + +// +checklocksignore +func (i *InodeSymlink) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &i.InodeNotDirectory) +} + +func (s *StaticDirectory) StateTypeName() string { + return "pkg/sentry/fsimpl/kernfs.StaticDirectory" +} + +func (s *StaticDirectory) StateFields() []string { + return []string{ + "InodeAlwaysValid", + "InodeAttrs", + "InodeDirectoryNoNewChildren", + "InodeNoStatFS", + "InodeNotSymlink", + "InodeTemporary", + "OrderedChildren", + "StaticDirectoryRefs", + "locks", + "fdOpts", + } +} + +func (s *StaticDirectory) beforeSave() {} + +// +checklocksignore +func (s *StaticDirectory) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + stateSinkObject.Save(0, &s.InodeAlwaysValid) + stateSinkObject.Save(1, &s.InodeAttrs) + stateSinkObject.Save(2, &s.InodeDirectoryNoNewChildren) + stateSinkObject.Save(3, &s.InodeNoStatFS) + stateSinkObject.Save(4, &s.InodeNotSymlink) + stateSinkObject.Save(5, &s.InodeTemporary) + stateSinkObject.Save(6, &s.OrderedChildren) + stateSinkObject.Save(7, &s.StaticDirectoryRefs) + stateSinkObject.Save(8, &s.locks) + stateSinkObject.Save(9, &s.fdOpts) +} + +func (s *StaticDirectory) afterLoad() {} + +// +checklocksignore +func (s *StaticDirectory) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.InodeAlwaysValid) + stateSourceObject.Load(1, &s.InodeAttrs) + stateSourceObject.Load(2, &s.InodeDirectoryNoNewChildren) + stateSourceObject.Load(3, &s.InodeNoStatFS) + stateSourceObject.Load(4, &s.InodeNotSymlink) + stateSourceObject.Load(5, &s.InodeTemporary) + stateSourceObject.Load(6, &s.OrderedChildren) + stateSourceObject.Load(7, &s.StaticDirectoryRefs) + stateSourceObject.Load(8, &s.locks) + stateSourceObject.Load(9, &s.fdOpts) +} + +func (i *InodeAlwaysValid) StateTypeName() string { + return "pkg/sentry/fsimpl/kernfs.InodeAlwaysValid" +} + +func (i *InodeAlwaysValid) StateFields() []string { + return []string{} +} + +func (i *InodeAlwaysValid) beforeSave() {} + +// +checklocksignore +func (i *InodeAlwaysValid) StateSave(stateSinkObject state.Sink) { + i.beforeSave() +} + +func (i *InodeAlwaysValid) afterLoad() {} + +// +checklocksignore +func (i *InodeAlwaysValid) StateLoad(stateSourceObject state.Source) { +} + +func (i *InodeTemporary) StateTypeName() string { + return "pkg/sentry/fsimpl/kernfs.InodeTemporary" +} + +func (i *InodeTemporary) StateFields() []string { + return []string{} +} + +func (i *InodeTemporary) beforeSave() {} + +// +checklocksignore +func (i *InodeTemporary) StateSave(stateSinkObject state.Sink) { + i.beforeSave() +} + +func (i *InodeTemporary) afterLoad() {} + +// +checklocksignore +func (i *InodeTemporary) StateLoad(stateSourceObject state.Source) { +} + +func (i *InodeNoStatFS) StateTypeName() string { + return "pkg/sentry/fsimpl/kernfs.InodeNoStatFS" +} + +func (i *InodeNoStatFS) StateFields() []string { + return []string{} +} + +func (i *InodeNoStatFS) beforeSave() {} + +// +checklocksignore +func (i *InodeNoStatFS) StateSave(stateSinkObject state.Sink) { + i.beforeSave() +} + +func (i *InodeNoStatFS) afterLoad() {} + +// +checklocksignore +func (i *InodeNoStatFS) StateLoad(stateSourceObject state.Source) { +} + +func (fs *Filesystem) StateTypeName() string { + return "pkg/sentry/fsimpl/kernfs.Filesystem" +} + +func (fs *Filesystem) StateFields() []string { + return []string{ + "vfsfs", + "deferredDecRefs", + "nextInoMinusOne", + "cachedDentries", + "cachedDentriesLen", + "MaxCachedDentries", + "root", + } +} + +func (fs *Filesystem) beforeSave() {} + +// +checklocksignore +func (fs *Filesystem) StateSave(stateSinkObject state.Sink) { + fs.beforeSave() + stateSinkObject.Save(0, &fs.vfsfs) + stateSinkObject.Save(1, &fs.deferredDecRefs) + stateSinkObject.Save(2, &fs.nextInoMinusOne) + stateSinkObject.Save(3, &fs.cachedDentries) + stateSinkObject.Save(4, &fs.cachedDentriesLen) + stateSinkObject.Save(5, &fs.MaxCachedDentries) + stateSinkObject.Save(6, &fs.root) +} + +func (fs *Filesystem) afterLoad() {} + +// +checklocksignore +func (fs *Filesystem) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &fs.vfsfs) + stateSourceObject.Load(1, &fs.deferredDecRefs) + stateSourceObject.Load(2, &fs.nextInoMinusOne) + stateSourceObject.Load(3, &fs.cachedDentries) + stateSourceObject.Load(4, &fs.cachedDentriesLen) + stateSourceObject.Load(5, &fs.MaxCachedDentries) + stateSourceObject.Load(6, &fs.root) +} + +func (d *Dentry) StateTypeName() string { + return "pkg/sentry/fsimpl/kernfs.Dentry" +} + +func (d *Dentry) StateFields() []string { + return []string{ + "vfsd", + "refs", + "fs", + "flags", + "parent", + "name", + "cached", + "dentryEntry", + "children", + "inode", + } +} + +func (d *Dentry) beforeSave() {} + +// +checklocksignore +func (d *Dentry) StateSave(stateSinkObject state.Sink) { + d.beforeSave() + stateSinkObject.Save(0, &d.vfsd) + stateSinkObject.Save(1, &d.refs) + stateSinkObject.Save(2, &d.fs) + stateSinkObject.Save(3, &d.flags) + stateSinkObject.Save(4, &d.parent) + stateSinkObject.Save(5, &d.name) + stateSinkObject.Save(6, &d.cached) + stateSinkObject.Save(7, &d.dentryEntry) + stateSinkObject.Save(8, &d.children) + stateSinkObject.Save(9, &d.inode) +} + +// +checklocksignore +func (d *Dentry) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &d.vfsd) + stateSourceObject.Load(1, &d.refs) + stateSourceObject.Load(2, &d.fs) + stateSourceObject.Load(3, &d.flags) + stateSourceObject.Load(4, &d.parent) + stateSourceObject.Load(5, &d.name) + stateSourceObject.Load(6, &d.cached) + stateSourceObject.Load(7, &d.dentryEntry) + stateSourceObject.Load(8, &d.children) + stateSourceObject.Load(9, &d.inode) + stateSourceObject.AfterLoad(d.afterLoad) +} + +func (i *inodePlatformFile) StateTypeName() string { + return "pkg/sentry/fsimpl/kernfs.inodePlatformFile" +} + +func (i *inodePlatformFile) StateFields() []string { + return []string{ + "hostFD", + "fdRefs", + "fileMapper", + } +} + +func (i *inodePlatformFile) beforeSave() {} + +// +checklocksignore +func (i *inodePlatformFile) StateSave(stateSinkObject state.Sink) { + i.beforeSave() + stateSinkObject.Save(0, &i.hostFD) + stateSinkObject.Save(1, &i.fdRefs) + stateSinkObject.Save(2, &i.fileMapper) +} + +// +checklocksignore +func (i *inodePlatformFile) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &i.hostFD) + stateSourceObject.Load(1, &i.fdRefs) + stateSourceObject.Load(2, &i.fileMapper) + stateSourceObject.AfterLoad(i.afterLoad) +} + +func (i *CachedMappable) StateTypeName() string { + return "pkg/sentry/fsimpl/kernfs.CachedMappable" +} + +func (i *CachedMappable) StateFields() []string { + return []string{ + "mappings", + "pf", + } +} + +func (i *CachedMappable) beforeSave() {} + +// +checklocksignore +func (i *CachedMappable) StateSave(stateSinkObject state.Sink) { + i.beforeSave() + stateSinkObject.Save(0, &i.mappings) + stateSinkObject.Save(1, &i.pf) +} + +func (i *CachedMappable) afterLoad() {} + +// +checklocksignore +func (i *CachedMappable) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &i.mappings) + stateSourceObject.Load(1, &i.pf) +} + +func (l *slotList) StateTypeName() string { + return "pkg/sentry/fsimpl/kernfs.slotList" +} + +func (l *slotList) StateFields() []string { + return []string{ + "head", + "tail", + } +} + +func (l *slotList) beforeSave() {} + +// +checklocksignore +func (l *slotList) StateSave(stateSinkObject state.Sink) { + l.beforeSave() + stateSinkObject.Save(0, &l.head) + stateSinkObject.Save(1, &l.tail) +} + +func (l *slotList) afterLoad() {} + +// +checklocksignore +func (l *slotList) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &l.head) + stateSourceObject.Load(1, &l.tail) +} + +func (e *slotEntry) StateTypeName() string { + return "pkg/sentry/fsimpl/kernfs.slotEntry" +} + +func (e *slotEntry) StateFields() []string { + return []string{ + "next", + "prev", + } +} + +func (e *slotEntry) beforeSave() {} + +// +checklocksignore +func (e *slotEntry) StateSave(stateSinkObject state.Sink) { + e.beforeSave() + stateSinkObject.Save(0, &e.next) + stateSinkObject.Save(1, &e.prev) +} + +func (e *slotEntry) afterLoad() {} + +// +checklocksignore +func (e *slotEntry) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &e.next) + stateSourceObject.Load(1, &e.prev) +} + +func (r *StaticDirectoryRefs) StateTypeName() string { + return "pkg/sentry/fsimpl/kernfs.StaticDirectoryRefs" +} + +func (r *StaticDirectoryRefs) StateFields() []string { + return []string{ + "refCount", + } +} + +func (r *StaticDirectoryRefs) beforeSave() {} + +// +checklocksignore +func (r *StaticDirectoryRefs) StateSave(stateSinkObject state.Sink) { + r.beforeSave() + stateSinkObject.Save(0, &r.refCount) +} + +// +checklocksignore +func (r *StaticDirectoryRefs) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &r.refCount) + stateSourceObject.AfterLoad(r.afterLoad) +} + +func (s *StaticSymlink) StateTypeName() string { + return "pkg/sentry/fsimpl/kernfs.StaticSymlink" +} + +func (s *StaticSymlink) StateFields() []string { + return []string{ + "InodeAttrs", + "InodeNoopRefCount", + "InodeSymlink", + "InodeNoStatFS", + "target", + } +} + +func (s *StaticSymlink) beforeSave() {} + +// +checklocksignore +func (s *StaticSymlink) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + stateSinkObject.Save(0, &s.InodeAttrs) + stateSinkObject.Save(1, &s.InodeNoopRefCount) + stateSinkObject.Save(2, &s.InodeSymlink) + stateSinkObject.Save(3, &s.InodeNoStatFS) + stateSinkObject.Save(4, &s.target) +} + +func (s *StaticSymlink) afterLoad() {} + +// +checklocksignore +func (s *StaticSymlink) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.InodeAttrs) + stateSourceObject.Load(1, &s.InodeNoopRefCount) + stateSourceObject.Load(2, &s.InodeSymlink) + stateSourceObject.Load(3, &s.InodeNoStatFS) + stateSourceObject.Load(4, &s.target) +} + +func (dir *syntheticDirectory) StateTypeName() string { + return "pkg/sentry/fsimpl/kernfs.syntheticDirectory" +} + +func (dir *syntheticDirectory) StateFields() []string { + return []string{ + "InodeAlwaysValid", + "InodeAttrs", + "InodeNoStatFS", + "InodeNotSymlink", + "OrderedChildren", + "syntheticDirectoryRefs", + "locks", + } +} + +func (dir *syntheticDirectory) beforeSave() {} + +// +checklocksignore +func (dir *syntheticDirectory) StateSave(stateSinkObject state.Sink) { + dir.beforeSave() + stateSinkObject.Save(0, &dir.InodeAlwaysValid) + stateSinkObject.Save(1, &dir.InodeAttrs) + stateSinkObject.Save(2, &dir.InodeNoStatFS) + stateSinkObject.Save(3, &dir.InodeNotSymlink) + stateSinkObject.Save(4, &dir.OrderedChildren) + stateSinkObject.Save(5, &dir.syntheticDirectoryRefs) + stateSinkObject.Save(6, &dir.locks) +} + +func (dir *syntheticDirectory) afterLoad() {} + +// +checklocksignore +func (dir *syntheticDirectory) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &dir.InodeAlwaysValid) + stateSourceObject.Load(1, &dir.InodeAttrs) + stateSourceObject.Load(2, &dir.InodeNoStatFS) + stateSourceObject.Load(3, &dir.InodeNotSymlink) + stateSourceObject.Load(4, &dir.OrderedChildren) + stateSourceObject.Load(5, &dir.syntheticDirectoryRefs) + stateSourceObject.Load(6, &dir.locks) +} + +func (r *syntheticDirectoryRefs) StateTypeName() string { + return "pkg/sentry/fsimpl/kernfs.syntheticDirectoryRefs" +} + +func (r *syntheticDirectoryRefs) StateFields() []string { + return []string{ + "refCount", + } +} + +func (r *syntheticDirectoryRefs) beforeSave() {} + +// +checklocksignore +func (r *syntheticDirectoryRefs) StateSave(stateSinkObject state.Sink) { + r.beforeSave() + stateSinkObject.Save(0, &r.refCount) +} + +// +checklocksignore +func (r *syntheticDirectoryRefs) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &r.refCount) + stateSourceObject.AfterLoad(r.afterLoad) +} + +func init() { + state.Register((*dentryList)(nil)) + state.Register((*dentryEntry)(nil)) + state.Register((*DynamicBytesFile)(nil)) + state.Register((*DynamicBytesFD)(nil)) + state.Register((*SeekEndConfig)(nil)) + state.Register((*GenericDirectoryFDOptions)(nil)) + state.Register((*GenericDirectoryFD)(nil)) + state.Register((*InodeNoopRefCount)(nil)) + state.Register((*InodeDirectoryNoNewChildren)(nil)) + state.Register((*InodeNotDirectory)(nil)) + state.Register((*InodeNotSymlink)(nil)) + state.Register((*InodeAttrs)(nil)) + state.Register((*slot)(nil)) + state.Register((*OrderedChildrenOptions)(nil)) + state.Register((*OrderedChildren)(nil)) + state.Register((*InodeSymlink)(nil)) + state.Register((*StaticDirectory)(nil)) + state.Register((*InodeAlwaysValid)(nil)) + state.Register((*InodeTemporary)(nil)) + state.Register((*InodeNoStatFS)(nil)) + state.Register((*Filesystem)(nil)) + state.Register((*Dentry)(nil)) + state.Register((*inodePlatformFile)(nil)) + state.Register((*CachedMappable)(nil)) + state.Register((*slotList)(nil)) + state.Register((*slotEntry)(nil)) + state.Register((*StaticDirectoryRefs)(nil)) + state.Register((*StaticSymlink)(nil)) + state.Register((*syntheticDirectory)(nil)) + state.Register((*syntheticDirectoryRefs)(nil)) +} diff --git a/pkg/sentry/fsimpl/kernfs/kernfs_test.go b/pkg/sentry/fsimpl/kernfs/kernfs_test.go deleted file mode 100644 index a2aba9321..000000000 --- a/pkg/sentry/fsimpl/kernfs/kernfs_test.go +++ /dev/null @@ -1,409 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package kernfs_test - -import ( - "bytes" - "fmt" - "testing" - - "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/errors/linuxerr" - "gvisor.dev/gvisor/pkg/fspath" - "gvisor.dev/gvisor/pkg/sentry/contexttest" - "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" - "gvisor.dev/gvisor/pkg/sentry/fsimpl/testutil" - "gvisor.dev/gvisor/pkg/sentry/kernel/auth" - "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/usermem" -) - -const defaultMode linux.FileMode = 01777 -const staticFileContent = "This is sample content for a static test file." - -// RootDentryFn is a generator function for creating the root dentry of a test -// filesystem. See newTestSystem. -type RootDentryFn func(context.Context, *auth.Credentials, *filesystem) kernfs.Inode - -// newTestSystem sets up a minimal environment for running a test, including an -// instance of a test filesystem. Tests can control the contents of the -// filesystem by providing an appropriate rootFn, which should return a -// pre-populated root dentry. -func newTestSystem(t *testing.T, rootFn RootDentryFn) *testutil.System { - ctx := contexttest.Context(t) - creds := auth.CredentialsFromContext(ctx) - v := &vfs.VirtualFilesystem{} - if err := v.Init(ctx); err != nil { - t.Fatalf("VFS init: %v", err) - } - v.MustRegisterFilesystemType("testfs", &fsType{rootFn: rootFn}, &vfs.RegisterFilesystemTypeOptions{ - AllowUserMount: true, - }) - mns, err := v.NewMountNamespace(ctx, creds, "", "testfs", &vfs.MountOptions{}) - if err != nil { - t.Fatalf("Failed to create testfs root mount: %v", err) - } - return testutil.NewSystem(ctx, t, v, mns) -} - -type fsType struct { - rootFn RootDentryFn -} - -type filesystem struct { - kernfs.Filesystem -} - -// MountOptions implements vfs.FilesystemImpl.MountOptions. -func (fs *filesystem) MountOptions() string { - return "" -} - -type file struct { - kernfs.DynamicBytesFile - content string -} - -func (fs *filesystem) newFile(ctx context.Context, creds *auth.Credentials, content string) kernfs.Inode { - f := &file{} - f.content = content - f.DynamicBytesFile.Init(ctx, creds, 0 /* devMajor */, 0 /* devMinor */, fs.NextIno(), f, 0777) - return f -} - -func (f *file) Generate(ctx context.Context, buf *bytes.Buffer) error { - fmt.Fprintf(buf, "%s", f.content) - return nil -} - -type attrs struct { - kernfs.InodeAttrs -} - -func (*attrs) SetStat(context.Context, *vfs.Filesystem, *auth.Credentials, vfs.SetStatOptions) error { - return linuxerr.EPERM -} - -type readonlyDir struct { - readonlyDirRefs - attrs - kernfs.InodeAlwaysValid - kernfs.InodeDirectoryNoNewChildren - kernfs.InodeNoStatFS - kernfs.InodeNotSymlink - kernfs.InodeTemporary - kernfs.OrderedChildren - - locks vfs.FileLocks -} - -func (fs *filesystem) newReadonlyDir(ctx context.Context, creds *auth.Credentials, mode linux.FileMode, contents map[string]kernfs.Inode) kernfs.Inode { - dir := &readonlyDir{} - dir.attrs.Init(ctx, creds, 0 /* devMajor */, 0 /* devMinor */, fs.NextIno(), linux.ModeDirectory|mode) - dir.OrderedChildren.Init(kernfs.OrderedChildrenOptions{}) - dir.InitRefs() - dir.IncLinks(dir.OrderedChildren.Populate(contents)) - return dir -} - -func (d *readonlyDir) Open(ctx context.Context, rp *vfs.ResolvingPath, kd *kernfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { - fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), kd, &d.OrderedChildren, &d.locks, &opts, kernfs.GenericDirectoryFDOptions{ - SeekEnd: kernfs.SeekEndStaticEntries, - }) - if err != nil { - return nil, err - } - return fd.VFSFileDescription(), nil -} - -func (d *readonlyDir) DecRef(ctx context.Context) { - d.readonlyDirRefs.DecRef(func() { d.Destroy(ctx) }) -} - -type dir struct { - dirRefs - attrs - kernfs.InodeAlwaysValid - kernfs.InodeNotSymlink - kernfs.InodeNoStatFS - kernfs.InodeTemporary - kernfs.OrderedChildren - - locks vfs.FileLocks - - fs *filesystem -} - -func (fs *filesystem) newDir(ctx context.Context, creds *auth.Credentials, mode linux.FileMode, contents map[string]kernfs.Inode) kernfs.Inode { - dir := &dir{} - dir.fs = fs - dir.attrs.Init(ctx, creds, 0 /* devMajor */, 0 /* devMinor */, fs.NextIno(), linux.ModeDirectory|mode) - dir.OrderedChildren.Init(kernfs.OrderedChildrenOptions{Writable: true}) - dir.InitRefs() - - dir.IncLinks(dir.OrderedChildren.Populate(contents)) - return dir -} - -func (d *dir) Open(ctx context.Context, rp *vfs.ResolvingPath, kd *kernfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { - fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), kd, &d.OrderedChildren, &d.locks, &opts, kernfs.GenericDirectoryFDOptions{ - SeekEnd: kernfs.SeekEndStaticEntries, - }) - if err != nil { - return nil, err - } - return fd.VFSFileDescription(), nil -} - -func (d *dir) DecRef(ctx context.Context) { - d.dirRefs.DecRef(func() { d.Destroy(ctx) }) -} - -func (d *dir) NewDir(ctx context.Context, name string, opts vfs.MkdirOptions) (kernfs.Inode, error) { - creds := auth.CredentialsFromContext(ctx) - dir := d.fs.newDir(ctx, creds, opts.Mode, nil) - if err := d.OrderedChildren.Insert(name, dir); err != nil { - dir.DecRef(ctx) - return nil, err - } - d.TouchCMtime(ctx) - d.IncLinks(1) - return dir, nil -} - -func (d *dir) NewFile(ctx context.Context, name string, opts vfs.OpenOptions) (kernfs.Inode, error) { - creds := auth.CredentialsFromContext(ctx) - f := d.fs.newFile(ctx, creds, "") - if err := d.OrderedChildren.Insert(name, f); err != nil { - f.DecRef(ctx) - return nil, err - } - d.TouchCMtime(ctx) - return f, nil -} - -func (*dir) NewLink(context.Context, string, kernfs.Inode) (kernfs.Inode, error) { - return nil, linuxerr.EPERM -} - -func (*dir) NewSymlink(context.Context, string, string) (kernfs.Inode, error) { - return nil, linuxerr.EPERM -} - -func (*dir) NewNode(context.Context, string, vfs.MknodOptions) (kernfs.Inode, error) { - return nil, linuxerr.EPERM -} - -func (fsType) Name() string { - return "kernfs" -} - -func (fsType) Release(ctx context.Context) {} - -func (fst fsType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opt vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) { - fs := &filesystem{} - fs.VFSFilesystem().Init(vfsObj, &fst, fs) - root := fst.rootFn(ctx, creds, fs) - var d kernfs.Dentry - d.Init(&fs.Filesystem, root) - return fs.VFSFilesystem(), d.VFSDentry(), nil -} - -// -------------------- Remainder of the file are test cases -------------------- - -func TestBasic(t *testing.T) { - sys := newTestSystem(t, func(ctx context.Context, creds *auth.Credentials, fs *filesystem) kernfs.Inode { - return fs.newReadonlyDir(ctx, creds, 0755, map[string]kernfs.Inode{ - "file1": fs.newFile(ctx, creds, staticFileContent), - }) - }) - defer sys.Destroy() - sys.GetDentryOrDie(sys.PathOpAtRoot("file1")).DecRef(sys.Ctx) -} - -func TestMkdirGetDentry(t *testing.T) { - sys := newTestSystem(t, func(ctx context.Context, creds *auth.Credentials, fs *filesystem) kernfs.Inode { - return fs.newReadonlyDir(ctx, creds, 0755, map[string]kernfs.Inode{ - "dir1": fs.newDir(ctx, creds, 0755, nil), - }) - }) - defer sys.Destroy() - - pop := sys.PathOpAtRoot("dir1/a new directory") - if err := sys.VFS.MkdirAt(sys.Ctx, sys.Creds, pop, &vfs.MkdirOptions{Mode: 0755}); err != nil { - t.Fatalf("MkdirAt for PathOperation %+v failed: %v", pop, err) - } - sys.GetDentryOrDie(pop).DecRef(sys.Ctx) -} - -func TestReadStaticFile(t *testing.T) { - sys := newTestSystem(t, func(ctx context.Context, creds *auth.Credentials, fs *filesystem) kernfs.Inode { - return fs.newReadonlyDir(ctx, creds, 0755, map[string]kernfs.Inode{ - "file1": fs.newFile(ctx, creds, staticFileContent), - }) - }) - defer sys.Destroy() - - pop := sys.PathOpAtRoot("file1") - fd, err := sys.VFS.OpenAt(sys.Ctx, sys.Creds, pop, &vfs.OpenOptions{ - Flags: linux.O_RDONLY, - }) - if err != nil { - t.Fatalf("OpenAt for PathOperation %+v failed: %v", pop, err) - } - defer fd.DecRef(sys.Ctx) - - content, err := sys.ReadToEnd(fd) - if err != nil { - t.Fatalf("Read failed: %v", err) - } - if diff := cmp.Diff(staticFileContent, content); diff != "" { - t.Fatalf("Read returned unexpected data:\n--- want\n+++ got\n%v", diff) - } -} - -func TestCreateNewFileInStaticDir(t *testing.T) { - sys := newTestSystem(t, func(ctx context.Context, creds *auth.Credentials, fs *filesystem) kernfs.Inode { - return fs.newReadonlyDir(ctx, creds, 0755, map[string]kernfs.Inode{ - "dir1": fs.newDir(ctx, creds, 0755, nil), - }) - }) - defer sys.Destroy() - - pop := sys.PathOpAtRoot("dir1/newfile") - opts := &vfs.OpenOptions{Flags: linux.O_CREAT | linux.O_EXCL, Mode: defaultMode} - fd, err := sys.VFS.OpenAt(sys.Ctx, sys.Creds, pop, opts) - if err != nil { - t.Fatalf("OpenAt(pop:%+v, opts:%+v) failed: %v", pop, opts, err) - } - - // Close the file. The file should persist. - fd.DecRef(sys.Ctx) - - fd, err = sys.VFS.OpenAt(sys.Ctx, sys.Creds, pop, &vfs.OpenOptions{ - Flags: linux.O_RDONLY, - }) - if err != nil { - t.Fatalf("OpenAt(pop:%+v) = %+v failed: %v", pop, fd, err) - } - fd.DecRef(sys.Ctx) -} - -func TestDirFDReadWrite(t *testing.T) { - sys := newTestSystem(t, func(ctx context.Context, creds *auth.Credentials, fs *filesystem) kernfs.Inode { - return fs.newReadonlyDir(ctx, creds, 0755, nil) - }) - defer sys.Destroy() - - pop := sys.PathOpAtRoot("/") - fd, err := sys.VFS.OpenAt(sys.Ctx, sys.Creds, pop, &vfs.OpenOptions{ - Flags: linux.O_RDONLY, - }) - if err != nil { - t.Fatalf("OpenAt for PathOperation %+v failed: %v", pop, err) - } - defer fd.DecRef(sys.Ctx) - - // Read/Write should fail for directory FDs. - if _, err := fd.Read(sys.Ctx, usermem.BytesIOSequence([]byte{}), vfs.ReadOptions{}); !linuxerr.Equals(linuxerr.EISDIR, err) { - t.Fatalf("Read for directory FD failed with unexpected error: %v", err) - } - if _, err := fd.Write(sys.Ctx, usermem.BytesIOSequence([]byte{}), vfs.WriteOptions{}); !linuxerr.Equals(linuxerr.EBADF, err) { - t.Fatalf("Write for directory FD failed with unexpected error: %v", err) - } -} - -func TestDirFDIterDirents(t *testing.T) { - sys := newTestSystem(t, func(ctx context.Context, creds *auth.Credentials, fs *filesystem) kernfs.Inode { - return fs.newReadonlyDir(ctx, creds, 0755, map[string]kernfs.Inode{ - // Fill root with nodes backed by various inode implementations. - "dir1": fs.newReadonlyDir(ctx, creds, 0755, nil), - "dir2": fs.newDir(ctx, creds, 0755, map[string]kernfs.Inode{ - "dir3": fs.newDir(ctx, creds, 0755, nil), - }), - "file1": fs.newFile(ctx, creds, staticFileContent), - }) - }) - defer sys.Destroy() - - pop := sys.PathOpAtRoot("/") - sys.AssertAllDirentTypes(sys.ListDirents(pop), map[string]testutil.DirentType{ - "dir1": linux.DT_DIR, - "dir2": linux.DT_DIR, - "file1": linux.DT_REG, - }) -} - -func TestDirWalkDentryTree(t *testing.T) { - sys := newTestSystem(t, func(ctx context.Context, creds *auth.Credentials, fs *filesystem) kernfs.Inode { - return fs.newDir(ctx, creds, 0755, map[string]kernfs.Inode{ - "dir1": fs.newDir(ctx, creds, 0755, nil), - "dir2": fs.newDir(ctx, creds, 0755, map[string]kernfs.Inode{ - "file1": fs.newFile(ctx, creds, staticFileContent), - "dir3": fs.newDir(ctx, creds, 0755, nil), - }), - }) - }) - defer sys.Destroy() - - testWalk := func(from *kernfs.Dentry, getDentryPath, walkPath string, expectedErr error) { - var d *kernfs.Dentry - if getDentryPath != "" { - pop := sys.PathOpAtRoot(getDentryPath) - vd := sys.GetDentryOrDie(pop) - defer vd.DecRef(sys.Ctx) - d = vd.Dentry().Impl().(*kernfs.Dentry) - } - - match, err := from.WalkDentryTree(sys.Ctx, sys.VFS, fspath.Parse(walkPath)) - if err == nil { - defer match.DecRef(sys.Ctx) - } - - if err != expectedErr { - t.Fatalf("WalkDentryTree from %q to %q (with expected error: %v) unexpected error, want: %v, got: %v", from.FSLocalPath(), walkPath, expectedErr, expectedErr, err) - } - if expectedErr != nil { - return - } - - if d != match { - t.Fatalf("WalkDentryTree from %q to %q (with expected error: %v) found unexpected dentry; want: %v, got: %v", from.FSLocalPath(), walkPath, expectedErr, d, match) - } - } - - rootD := sys.Root.Dentry().Impl().(*kernfs.Dentry) - - testWalk(rootD, "dir1", "/dir1", nil) - testWalk(rootD, "", "/dir-non-existent", linuxerr.ENOENT) - testWalk(rootD, "", "/dir1/child-non-existent", linuxerr.ENOENT) - testWalk(rootD, "", "/dir2/inner-non-existent/dir3", linuxerr.ENOENT) - - testWalk(rootD, "dir2/dir3", "/dir2/../dir2/dir3", nil) - testWalk(rootD, "dir2/dir3", "/dir2/././dir3", nil) - testWalk(rootD, "dir2/dir3", "/dir2/././dir3/.././dir3", nil) - - pop := sys.PathOpAtRoot("dir2") - dir2VD := sys.GetDentryOrDie(pop) - defer dir2VD.DecRef(sys.Ctx) - dir2D := dir2VD.Dentry().Impl().(*kernfs.Dentry) - - testWalk(dir2D, "dir2/dir3", "/dir3", nil) - testWalk(dir2D, "dir2/dir3", "/../../../dir3", nil) - testWalk(dir2D, "dir2/file1", "/file1", nil) - testWalk(dir2D, "dir2/file1", "file1", nil) -} diff --git a/pkg/sentry/fsimpl/kernfs/slot_list.go b/pkg/sentry/fsimpl/kernfs/slot_list.go new file mode 100644 index 000000000..181fe7c8f --- /dev/null +++ b/pkg/sentry/fsimpl/kernfs/slot_list.go @@ -0,0 +1,221 @@ +package kernfs + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type slotElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (slotElementMapper) linkerFor(elem *slot) *slot { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type slotList struct { + head *slot + tail *slot +} + +// Reset resets list l to the empty state. +func (l *slotList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +// +//go:nosplit +func (l *slotList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +// +//go:nosplit +func (l *slotList) Front() *slot { + return l.head +} + +// Back returns the last element of list l or nil. +// +//go:nosplit +func (l *slotList) Back() *slot { + return l.tail +} + +// Len returns the number of elements in the list. +// +// NOTE: This is an O(n) operation. +// +//go:nosplit +func (l *slotList) Len() (count int) { + for e := l.Front(); e != nil; e = (slotElementMapper{}.linkerFor(e)).Next() { + count++ + } + return count +} + +// PushFront inserts the element e at the front of list l. +// +//go:nosplit +func (l *slotList) PushFront(e *slot) { + linker := slotElementMapper{}.linkerFor(e) + linker.SetNext(l.head) + linker.SetPrev(nil) + if l.head != nil { + slotElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +// +//go:nosplit +func (l *slotList) PushBack(e *slot) { + linker := slotElementMapper{}.linkerFor(e) + linker.SetNext(nil) + linker.SetPrev(l.tail) + if l.tail != nil { + slotElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +// +//go:nosplit +func (l *slotList) PushBackList(m *slotList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + slotElementMapper{}.linkerFor(l.tail).SetNext(m.head) + slotElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +// +//go:nosplit +func (l *slotList) InsertAfter(b, e *slot) { + bLinker := slotElementMapper{}.linkerFor(b) + eLinker := slotElementMapper{}.linkerFor(e) + + a := bLinker.Next() + + eLinker.SetNext(a) + eLinker.SetPrev(b) + bLinker.SetNext(e) + + if a != nil { + slotElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +// +//go:nosplit +func (l *slotList) InsertBefore(a, e *slot) { + aLinker := slotElementMapper{}.linkerFor(a) + eLinker := slotElementMapper{}.linkerFor(e) + + b := aLinker.Prev() + eLinker.SetNext(a) + eLinker.SetPrev(b) + aLinker.SetPrev(e) + + if b != nil { + slotElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +// +//go:nosplit +func (l *slotList) Remove(e *slot) { + linker := slotElementMapper{}.linkerFor(e) + prev := linker.Prev() + next := linker.Next() + + if prev != nil { + slotElementMapper{}.linkerFor(prev).SetNext(next) + } else if l.head == e { + l.head = next + } + + if next != nil { + slotElementMapper{}.linkerFor(next).SetPrev(prev) + } else if l.tail == e { + l.tail = prev + } + + linker.SetNext(nil) + linker.SetPrev(nil) +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type slotEntry struct { + next *slot + prev *slot +} + +// Next returns the entry that follows e in the list. +// +//go:nosplit +func (e *slotEntry) Next() *slot { + return e.next +} + +// Prev returns the entry that precedes e in the list. +// +//go:nosplit +func (e *slotEntry) Prev() *slot { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +// +//go:nosplit +func (e *slotEntry) SetNext(elem *slot) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +// +//go:nosplit +func (e *slotEntry) SetPrev(elem *slot) { + e.prev = elem +} diff --git a/pkg/sentry/fsimpl/kernfs/static_directory_refs.go b/pkg/sentry/fsimpl/kernfs/static_directory_refs.go new file mode 100644 index 000000000..69534a2d2 --- /dev/null +++ b/pkg/sentry/fsimpl/kernfs/static_directory_refs.go @@ -0,0 +1,140 @@ +package kernfs + +import ( + "fmt" + "sync/atomic" + + "gvisor.dev/gvisor/pkg/refsvfs2" +) + +// enableLogging indicates whether reference-related events should be logged (with +// stack traces). This is false by default and should only be set to true for +// debugging purposes, as it can generate an extremely large amount of output +// and drastically degrade performance. +const StaticDirectoryenableLogging = false + +// obj is used to customize logging. Note that we use a pointer to T so that +// we do not copy the entire object when passed as a format parameter. +var StaticDirectoryobj *StaticDirectory + +// Refs implements refs.RefCounter. It keeps a reference count using atomic +// operations and calls the destructor when the count reaches zero. +// +// NOTE: Do not introduce additional fields to the Refs struct. It is used by +// many filesystem objects, and we want to keep it as small as possible (i.e., +// the same size as using an int64 directly) to avoid taking up extra cache +// space. In general, this template should not be extended at the cost of +// performance. If it does not offer enough flexibility for a particular object +// (example: b/187877947), we should implement the RefCounter/CheckedObject +// interfaces manually. +// +// +stateify savable +type StaticDirectoryRefs struct { + // refCount is composed of two fields: + // + // [32-bit speculative references]:[32-bit real references] + // + // Speculative references are used for TryIncRef, to avoid a CompareAndSwap + // loop. See IncRef, DecRef and TryIncRef for details of how these fields are + // used. + refCount int64 +} + +// InitRefs initializes r with one reference and, if enabled, activates leak +// checking. +func (r *StaticDirectoryRefs) InitRefs() { + atomic.StoreInt64(&r.refCount, 1) + refsvfs2.Register(r) +} + +// RefType implements refsvfs2.CheckedObject.RefType. +func (r *StaticDirectoryRefs) RefType() string { + return fmt.Sprintf("%T", StaticDirectoryobj)[1:] +} + +// LeakMessage implements refsvfs2.CheckedObject.LeakMessage. +func (r *StaticDirectoryRefs) LeakMessage() string { + return fmt.Sprintf("[%s %p] reference count of %d instead of 0", r.RefType(), r, r.ReadRefs()) +} + +// LogRefs implements refsvfs2.CheckedObject.LogRefs. +func (r *StaticDirectoryRefs) LogRefs() bool { + return StaticDirectoryenableLogging +} + +// ReadRefs returns the current number of references. The returned count is +// inherently racy and is unsafe to use without external synchronization. +func (r *StaticDirectoryRefs) ReadRefs() int64 { + return atomic.LoadInt64(&r.refCount) +} + +// IncRef implements refs.RefCounter.IncRef. +// +//go:nosplit +func (r *StaticDirectoryRefs) IncRef() { + v := atomic.AddInt64(&r.refCount, 1) + if StaticDirectoryenableLogging { + refsvfs2.LogIncRef(r, v) + } + if v <= 1 { + panic(fmt.Sprintf("Incrementing non-positive count %p on %s", r, r.RefType())) + } +} + +// TryIncRef implements refs.TryRefCounter.TryIncRef. +// +// To do this safely without a loop, a speculative reference is first acquired +// on the object. This allows multiple concurrent TryIncRef calls to distinguish +// other TryIncRef calls from genuine references held. +// +//go:nosplit +func (r *StaticDirectoryRefs) TryIncRef() bool { + const speculativeRef = 1 << 32 + if v := atomic.AddInt64(&r.refCount, speculativeRef); int32(v) == 0 { + + atomic.AddInt64(&r.refCount, -speculativeRef) + return false + } + + v := atomic.AddInt64(&r.refCount, -speculativeRef+1) + if StaticDirectoryenableLogging { + refsvfs2.LogTryIncRef(r, v) + } + return true +} + +// DecRef implements refs.RefCounter.DecRef. +// +// Note that speculative references are counted here. Since they were added +// prior to real references reaching zero, they will successfully convert to +// real references. In other words, we see speculative references only in the +// following case: +// +// A: TryIncRef [speculative increase => sees non-negative references] +// B: DecRef [real decrease] +// A: TryIncRef [transform speculative to real] +// +//go:nosplit +func (r *StaticDirectoryRefs) DecRef(destroy func()) { + v := atomic.AddInt64(&r.refCount, -1) + if StaticDirectoryenableLogging { + refsvfs2.LogDecRef(r, v) + } + switch { + case v < 0: + panic(fmt.Sprintf("Decrementing non-positive ref count %p, owned by %s", r, r.RefType())) + + case v == 0: + refsvfs2.Unregister(r) + + if destroy != nil { + destroy() + } + } +} + +func (r *StaticDirectoryRefs) afterLoad() { + if r.ReadRefs() > 0 { + refsvfs2.Register(r) + } +} diff --git a/pkg/sentry/fsimpl/kernfs/synthetic_directory_refs.go b/pkg/sentry/fsimpl/kernfs/synthetic_directory_refs.go new file mode 100644 index 000000000..3c5fdf15e --- /dev/null +++ b/pkg/sentry/fsimpl/kernfs/synthetic_directory_refs.go @@ -0,0 +1,140 @@ +package kernfs + +import ( + "fmt" + "sync/atomic" + + "gvisor.dev/gvisor/pkg/refsvfs2" +) + +// enableLogging indicates whether reference-related events should be logged (with +// stack traces). This is false by default and should only be set to true for +// debugging purposes, as it can generate an extremely large amount of output +// and drastically degrade performance. +const syntheticDirectoryenableLogging = false + +// obj is used to customize logging. Note that we use a pointer to T so that +// we do not copy the entire object when passed as a format parameter. +var syntheticDirectoryobj *syntheticDirectory + +// Refs implements refs.RefCounter. It keeps a reference count using atomic +// operations and calls the destructor when the count reaches zero. +// +// NOTE: Do not introduce additional fields to the Refs struct. It is used by +// many filesystem objects, and we want to keep it as small as possible (i.e., +// the same size as using an int64 directly) to avoid taking up extra cache +// space. In general, this template should not be extended at the cost of +// performance. If it does not offer enough flexibility for a particular object +// (example: b/187877947), we should implement the RefCounter/CheckedObject +// interfaces manually. +// +// +stateify savable +type syntheticDirectoryRefs struct { + // refCount is composed of two fields: + // + // [32-bit speculative references]:[32-bit real references] + // + // Speculative references are used for TryIncRef, to avoid a CompareAndSwap + // loop. See IncRef, DecRef and TryIncRef for details of how these fields are + // used. + refCount int64 +} + +// InitRefs initializes r with one reference and, if enabled, activates leak +// checking. +func (r *syntheticDirectoryRefs) InitRefs() { + atomic.StoreInt64(&r.refCount, 1) + refsvfs2.Register(r) +} + +// RefType implements refsvfs2.CheckedObject.RefType. +func (r *syntheticDirectoryRefs) RefType() string { + return fmt.Sprintf("%T", syntheticDirectoryobj)[1:] +} + +// LeakMessage implements refsvfs2.CheckedObject.LeakMessage. +func (r *syntheticDirectoryRefs) LeakMessage() string { + return fmt.Sprintf("[%s %p] reference count of %d instead of 0", r.RefType(), r, r.ReadRefs()) +} + +// LogRefs implements refsvfs2.CheckedObject.LogRefs. +func (r *syntheticDirectoryRefs) LogRefs() bool { + return syntheticDirectoryenableLogging +} + +// ReadRefs returns the current number of references. The returned count is +// inherently racy and is unsafe to use without external synchronization. +func (r *syntheticDirectoryRefs) ReadRefs() int64 { + return atomic.LoadInt64(&r.refCount) +} + +// IncRef implements refs.RefCounter.IncRef. +// +//go:nosplit +func (r *syntheticDirectoryRefs) IncRef() { + v := atomic.AddInt64(&r.refCount, 1) + if syntheticDirectoryenableLogging { + refsvfs2.LogIncRef(r, v) + } + if v <= 1 { + panic(fmt.Sprintf("Incrementing non-positive count %p on %s", r, r.RefType())) + } +} + +// TryIncRef implements refs.TryRefCounter.TryIncRef. +// +// To do this safely without a loop, a speculative reference is first acquired +// on the object. This allows multiple concurrent TryIncRef calls to distinguish +// other TryIncRef calls from genuine references held. +// +//go:nosplit +func (r *syntheticDirectoryRefs) TryIncRef() bool { + const speculativeRef = 1 << 32 + if v := atomic.AddInt64(&r.refCount, speculativeRef); int32(v) == 0 { + + atomic.AddInt64(&r.refCount, -speculativeRef) + return false + } + + v := atomic.AddInt64(&r.refCount, -speculativeRef+1) + if syntheticDirectoryenableLogging { + refsvfs2.LogTryIncRef(r, v) + } + return true +} + +// DecRef implements refs.RefCounter.DecRef. +// +// Note that speculative references are counted here. Since they were added +// prior to real references reaching zero, they will successfully convert to +// real references. In other words, we see speculative references only in the +// following case: +// +// A: TryIncRef [speculative increase => sees non-negative references] +// B: DecRef [real decrease] +// A: TryIncRef [transform speculative to real] +// +//go:nosplit +func (r *syntheticDirectoryRefs) DecRef(destroy func()) { + v := atomic.AddInt64(&r.refCount, -1) + if syntheticDirectoryenableLogging { + refsvfs2.LogDecRef(r, v) + } + switch { + case v < 0: + panic(fmt.Sprintf("Decrementing non-positive ref count %p, owned by %s", r, r.RefType())) + + case v == 0: + refsvfs2.Unregister(r) + + if destroy != nil { + destroy() + } + } +} + +func (r *syntheticDirectoryRefs) afterLoad() { + if r.ReadRefs() > 0 { + refsvfs2.Register(r) + } +} diff --git a/pkg/sentry/fsimpl/overlay/BUILD b/pkg/sentry/fsimpl/overlay/BUILD deleted file mode 100644 index d16dfef9b..000000000 --- a/pkg/sentry/fsimpl/overlay/BUILD +++ /dev/null @@ -1,48 +0,0 @@ -load("//tools:defs.bzl", "go_library") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -licenses(["notice"]) - -go_template_instance( - name = "fstree", - out = "fstree.go", - package = "overlay", - prefix = "generic", - template = "//pkg/sentry/vfs/genericfstree:generic_fstree", - types = { - "Dentry": "dentry", - }, -) - -go_library( - name = "overlay", - srcs = [ - "copy_up.go", - "directory.go", - "filesystem.go", - "fstree.go", - "overlay.go", - "regular_file.go", - "save_restore.go", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/errors/linuxerr", - "//pkg/fspath", - "//pkg/hostarch", - "//pkg/log", - "//pkg/refs", - "//pkg/refsvfs2", - "//pkg/sentry/arch", - "//pkg/sentry/fs/lock", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/memmap", - "//pkg/sentry/socket/unix/transport", - "//pkg/sentry/vfs", - "//pkg/sync", - "//pkg/usermem", - "//pkg/waiter", - ], -) diff --git a/pkg/sentry/fsimpl/overlay/fstree.go b/pkg/sentry/fsimpl/overlay/fstree.go new file mode 100644 index 000000000..c3eb062ed --- /dev/null +++ b/pkg/sentry/fsimpl/overlay/fstree.go @@ -0,0 +1,55 @@ +package overlay + +import ( + "gvisor.dev/gvisor/pkg/fspath" + "gvisor.dev/gvisor/pkg/sentry/vfs" +) + +// IsAncestorDentry returns true if d is an ancestor of d2; that is, d is +// either d2's parent or an ancestor of d2's parent. +func genericIsAncestorDentry(d, d2 *dentry) bool { + for d2 != nil { + if d2.parent == d { + return true + } + if d2.parent == d2 { + return false + } + d2 = d2.parent + } + return false +} + +// ParentOrSelf returns d.parent. If d.parent is nil, ParentOrSelf returns d. +func genericParentOrSelf(d *dentry) *dentry { + if d.parent != nil { + return d.parent + } + return d +} + +// PrependPath is a generic implementation of FilesystemImpl.PrependPath(). +func genericPrependPath(vfsroot vfs.VirtualDentry, mnt *vfs.Mount, d *dentry, b *fspath.Builder) error { + for { + if mnt == vfsroot.Mount() && &d.vfsd == vfsroot.Dentry() { + return vfs.PrependPathAtVFSRootError{} + } + if mnt != nil && &d.vfsd == mnt.Root() { + return nil + } + if d.parent == nil { + return vfs.PrependPathAtNonMountRootError{} + } + b.PrependComponent(d.name) + d = d.parent + } +} + +// DebugPathname returns a pathname to d relative to its filesystem root. +// DebugPathname does not correspond to any Linux function; it's used to +// generate dentry pathnames for debugging. +func genericDebugPathname(d *dentry) string { + var b fspath.Builder + _ = genericPrependPath(vfs.VirtualDentry{}, nil, d, &b) + return b.String() +} diff --git a/pkg/sentry/fsimpl/overlay/overlay_state_autogen.go b/pkg/sentry/fsimpl/overlay/overlay_state_autogen.go new file mode 100644 index 000000000..923a2e71a --- /dev/null +++ b/pkg/sentry/fsimpl/overlay/overlay_state_autogen.go @@ -0,0 +1,321 @@ +// automatically generated by stateify. + +package overlay + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (fd *directoryFD) StateTypeName() string { + return "pkg/sentry/fsimpl/overlay.directoryFD" +} + +func (fd *directoryFD) StateFields() []string { + return []string{ + "fileDescription", + "DirectoryFileDescriptionDefaultImpl", + "DentryMetadataFileDescriptionImpl", + "off", + "dirents", + } +} + +func (fd *directoryFD) beforeSave() {} + +// +checklocksignore +func (fd *directoryFD) StateSave(stateSinkObject state.Sink) { + fd.beforeSave() + stateSinkObject.Save(0, &fd.fileDescription) + stateSinkObject.Save(1, &fd.DirectoryFileDescriptionDefaultImpl) + stateSinkObject.Save(2, &fd.DentryMetadataFileDescriptionImpl) + stateSinkObject.Save(3, &fd.off) + stateSinkObject.Save(4, &fd.dirents) +} + +func (fd *directoryFD) afterLoad() {} + +// +checklocksignore +func (fd *directoryFD) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &fd.fileDescription) + stateSourceObject.Load(1, &fd.DirectoryFileDescriptionDefaultImpl) + stateSourceObject.Load(2, &fd.DentryMetadataFileDescriptionImpl) + stateSourceObject.Load(3, &fd.off) + stateSourceObject.Load(4, &fd.dirents) +} + +func (fstype *FilesystemType) StateTypeName() string { + return "pkg/sentry/fsimpl/overlay.FilesystemType" +} + +func (fstype *FilesystemType) StateFields() []string { + return []string{} +} + +func (fstype *FilesystemType) beforeSave() {} + +// +checklocksignore +func (fstype *FilesystemType) StateSave(stateSinkObject state.Sink) { + fstype.beforeSave() +} + +func (fstype *FilesystemType) afterLoad() {} + +// +checklocksignore +func (fstype *FilesystemType) StateLoad(stateSourceObject state.Source) { +} + +func (f *FilesystemOptions) StateTypeName() string { + return "pkg/sentry/fsimpl/overlay.FilesystemOptions" +} + +func (f *FilesystemOptions) StateFields() []string { + return []string{ + "UpperRoot", + "LowerRoots", + } +} + +func (f *FilesystemOptions) beforeSave() {} + +// +checklocksignore +func (f *FilesystemOptions) StateSave(stateSinkObject state.Sink) { + f.beforeSave() + stateSinkObject.Save(0, &f.UpperRoot) + stateSinkObject.Save(1, &f.LowerRoots) +} + +func (f *FilesystemOptions) afterLoad() {} + +// +checklocksignore +func (f *FilesystemOptions) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &f.UpperRoot) + stateSourceObject.Load(1, &f.LowerRoots) +} + +func (fs *filesystem) StateTypeName() string { + return "pkg/sentry/fsimpl/overlay.filesystem" +} + +func (fs *filesystem) StateFields() []string { + return []string{ + "vfsfs", + "opts", + "creds", + "privateDevMinors", + } +} + +func (fs *filesystem) beforeSave() {} + +// +checklocksignore +func (fs *filesystem) StateSave(stateSinkObject state.Sink) { + fs.beforeSave() + stateSinkObject.Save(0, &fs.vfsfs) + stateSinkObject.Save(1, &fs.opts) + stateSinkObject.Save(2, &fs.creds) + stateSinkObject.Save(3, &fs.privateDevMinors) +} + +func (fs *filesystem) afterLoad() {} + +// +checklocksignore +func (fs *filesystem) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &fs.vfsfs) + stateSourceObject.Load(1, &fs.opts) + stateSourceObject.Load(2, &fs.creds) + stateSourceObject.Load(3, &fs.privateDevMinors) +} + +func (l *layerDevNumber) StateTypeName() string { + return "pkg/sentry/fsimpl/overlay.layerDevNumber" +} + +func (l *layerDevNumber) StateFields() []string { + return []string{ + "major", + "minor", + } +} + +func (l *layerDevNumber) beforeSave() {} + +// +checklocksignore +func (l *layerDevNumber) StateSave(stateSinkObject state.Sink) { + l.beforeSave() + stateSinkObject.Save(0, &l.major) + stateSinkObject.Save(1, &l.minor) +} + +func (l *layerDevNumber) afterLoad() {} + +// +checklocksignore +func (l *layerDevNumber) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &l.major) + stateSourceObject.Load(1, &l.minor) +} + +func (d *dentry) StateTypeName() string { + return "pkg/sentry/fsimpl/overlay.dentry" +} + +func (d *dentry) StateFields() []string { + return []string{ + "vfsd", + "refs", + "fs", + "mode", + "uid", + "gid", + "copiedUp", + "parent", + "name", + "children", + "dirents", + "upperVD", + "lowerVDs", + "inlineLowerVDs", + "devMajor", + "devMinor", + "ino", + "lowerMappings", + "wrappedMappable", + "isMappable", + "locks", + "watches", + } +} + +func (d *dentry) beforeSave() {} + +// +checklocksignore +func (d *dentry) StateSave(stateSinkObject state.Sink) { + d.beforeSave() + stateSinkObject.Save(0, &d.vfsd) + stateSinkObject.Save(1, &d.refs) + stateSinkObject.Save(2, &d.fs) + stateSinkObject.Save(3, &d.mode) + stateSinkObject.Save(4, &d.uid) + stateSinkObject.Save(5, &d.gid) + stateSinkObject.Save(6, &d.copiedUp) + stateSinkObject.Save(7, &d.parent) + stateSinkObject.Save(8, &d.name) + stateSinkObject.Save(9, &d.children) + stateSinkObject.Save(10, &d.dirents) + stateSinkObject.Save(11, &d.upperVD) + stateSinkObject.Save(12, &d.lowerVDs) + stateSinkObject.Save(13, &d.inlineLowerVDs) + stateSinkObject.Save(14, &d.devMajor) + stateSinkObject.Save(15, &d.devMinor) + stateSinkObject.Save(16, &d.ino) + stateSinkObject.Save(17, &d.lowerMappings) + stateSinkObject.Save(18, &d.wrappedMappable) + stateSinkObject.Save(19, &d.isMappable) + stateSinkObject.Save(20, &d.locks) + stateSinkObject.Save(21, &d.watches) +} + +// +checklocksignore +func (d *dentry) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &d.vfsd) + stateSourceObject.Load(1, &d.refs) + stateSourceObject.Load(2, &d.fs) + stateSourceObject.Load(3, &d.mode) + stateSourceObject.Load(4, &d.uid) + stateSourceObject.Load(5, &d.gid) + stateSourceObject.Load(6, &d.copiedUp) + stateSourceObject.Load(7, &d.parent) + stateSourceObject.Load(8, &d.name) + stateSourceObject.Load(9, &d.children) + stateSourceObject.Load(10, &d.dirents) + stateSourceObject.Load(11, &d.upperVD) + stateSourceObject.Load(12, &d.lowerVDs) + stateSourceObject.Load(13, &d.inlineLowerVDs) + stateSourceObject.Load(14, &d.devMajor) + stateSourceObject.Load(15, &d.devMinor) + stateSourceObject.Load(16, &d.ino) + stateSourceObject.Load(17, &d.lowerMappings) + stateSourceObject.Load(18, &d.wrappedMappable) + stateSourceObject.Load(19, &d.isMappable) + stateSourceObject.Load(20, &d.locks) + stateSourceObject.Load(21, &d.watches) + stateSourceObject.AfterLoad(d.afterLoad) +} + +func (fd *fileDescription) StateTypeName() string { + return "pkg/sentry/fsimpl/overlay.fileDescription" +} + +func (fd *fileDescription) StateFields() []string { + return []string{ + "vfsfd", + "FileDescriptionDefaultImpl", + "LockFD", + } +} + +func (fd *fileDescription) beforeSave() {} + +// +checklocksignore +func (fd *fileDescription) StateSave(stateSinkObject state.Sink) { + fd.beforeSave() + stateSinkObject.Save(0, &fd.vfsfd) + stateSinkObject.Save(1, &fd.FileDescriptionDefaultImpl) + stateSinkObject.Save(2, &fd.LockFD) +} + +func (fd *fileDescription) afterLoad() {} + +// +checklocksignore +func (fd *fileDescription) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &fd.vfsfd) + stateSourceObject.Load(1, &fd.FileDescriptionDefaultImpl) + stateSourceObject.Load(2, &fd.LockFD) +} + +func (fd *regularFileFD) StateTypeName() string { + return "pkg/sentry/fsimpl/overlay.regularFileFD" +} + +func (fd *regularFileFD) StateFields() []string { + return []string{ + "fileDescription", + "copiedUp", + "cachedFD", + "cachedFlags", + "lowerWaiters", + } +} + +func (fd *regularFileFD) beforeSave() {} + +// +checklocksignore +func (fd *regularFileFD) StateSave(stateSinkObject state.Sink) { + fd.beforeSave() + stateSinkObject.Save(0, &fd.fileDescription) + stateSinkObject.Save(1, &fd.copiedUp) + stateSinkObject.Save(2, &fd.cachedFD) + stateSinkObject.Save(3, &fd.cachedFlags) + stateSinkObject.Save(4, &fd.lowerWaiters) +} + +func (fd *regularFileFD) afterLoad() {} + +// +checklocksignore +func (fd *regularFileFD) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &fd.fileDescription) + stateSourceObject.Load(1, &fd.copiedUp) + stateSourceObject.Load(2, &fd.cachedFD) + stateSourceObject.Load(3, &fd.cachedFlags) + stateSourceObject.Load(4, &fd.lowerWaiters) +} + +func init() { + state.Register((*directoryFD)(nil)) + state.Register((*FilesystemType)(nil)) + state.Register((*FilesystemOptions)(nil)) + state.Register((*filesystem)(nil)) + state.Register((*layerDevNumber)(nil)) + state.Register((*dentry)(nil)) + state.Register((*fileDescription)(nil)) + state.Register((*regularFileFD)(nil)) +} diff --git a/pkg/sentry/fsimpl/pipefs/BUILD b/pkg/sentry/fsimpl/pipefs/BUILD deleted file mode 100644 index a50510031..000000000 --- a/pkg/sentry/fsimpl/pipefs/BUILD +++ /dev/null @@ -1,21 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -licenses(["notice"]) - -go_library( - name = "pipefs", - srcs = ["pipefs.go"], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/errors/linuxerr", - "//pkg/fspath", - "//pkg/hostarch", - "//pkg/sentry/fsimpl/kernfs", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/kernel/pipe", - "//pkg/sentry/kernel/time", - "//pkg/sentry/vfs", - ], -) diff --git a/pkg/sentry/fsimpl/pipefs/pipefs_state_autogen.go b/pkg/sentry/fsimpl/pipefs/pipefs_state_autogen.go new file mode 100644 index 000000000..5f9e117c3 --- /dev/null +++ b/pkg/sentry/fsimpl/pipefs/pipefs_state_autogen.go @@ -0,0 +1,111 @@ +// automatically generated by stateify. + +package pipefs + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (f *filesystemType) StateTypeName() string { + return "pkg/sentry/fsimpl/pipefs.filesystemType" +} + +func (f *filesystemType) StateFields() []string { + return []string{} +} + +func (f *filesystemType) beforeSave() {} + +// +checklocksignore +func (f *filesystemType) StateSave(stateSinkObject state.Sink) { + f.beforeSave() +} + +func (f *filesystemType) afterLoad() {} + +// +checklocksignore +func (f *filesystemType) StateLoad(stateSourceObject state.Source) { +} + +func (fs *filesystem) StateTypeName() string { + return "pkg/sentry/fsimpl/pipefs.filesystem" +} + +func (fs *filesystem) StateFields() []string { + return []string{ + "Filesystem", + "devMinor", + } +} + +func (fs *filesystem) beforeSave() {} + +// +checklocksignore +func (fs *filesystem) StateSave(stateSinkObject state.Sink) { + fs.beforeSave() + stateSinkObject.Save(0, &fs.Filesystem) + stateSinkObject.Save(1, &fs.devMinor) +} + +func (fs *filesystem) afterLoad() {} + +// +checklocksignore +func (fs *filesystem) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &fs.Filesystem) + stateSourceObject.Load(1, &fs.devMinor) +} + +func (i *inode) StateTypeName() string { + return "pkg/sentry/fsimpl/pipefs.inode" +} + +func (i *inode) StateFields() []string { + return []string{ + "InodeNotDirectory", + "InodeNotSymlink", + "InodeNoopRefCount", + "locks", + "pipe", + "ino", + "uid", + "gid", + "ctime", + } +} + +func (i *inode) beforeSave() {} + +// +checklocksignore +func (i *inode) StateSave(stateSinkObject state.Sink) { + i.beforeSave() + stateSinkObject.Save(0, &i.InodeNotDirectory) + stateSinkObject.Save(1, &i.InodeNotSymlink) + stateSinkObject.Save(2, &i.InodeNoopRefCount) + stateSinkObject.Save(3, &i.locks) + stateSinkObject.Save(4, &i.pipe) + stateSinkObject.Save(5, &i.ino) + stateSinkObject.Save(6, &i.uid) + stateSinkObject.Save(7, &i.gid) + stateSinkObject.Save(8, &i.ctime) +} + +func (i *inode) afterLoad() {} + +// +checklocksignore +func (i *inode) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &i.InodeNotDirectory) + stateSourceObject.Load(1, &i.InodeNotSymlink) + stateSourceObject.Load(2, &i.InodeNoopRefCount) + stateSourceObject.Load(3, &i.locks) + stateSourceObject.Load(4, &i.pipe) + stateSourceObject.Load(5, &i.ino) + stateSourceObject.Load(6, &i.uid) + stateSourceObject.Load(7, &i.gid) + stateSourceObject.Load(8, &i.ctime) +} + +func init() { + state.Register((*filesystemType)(nil)) + state.Register((*filesystem)(nil)) + state.Register((*inode)(nil)) +} diff --git a/pkg/sentry/fsimpl/proc/BUILD b/pkg/sentry/fsimpl/proc/BUILD deleted file mode 100644 index 95cfbdc42..000000000 --- a/pkg/sentry/fsimpl/proc/BUILD +++ /dev/null @@ -1,133 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -licenses(["notice"]) - -go_template_instance( - name = "fd_dir_inode_refs", - out = "fd_dir_inode_refs.go", - package = "proc", - prefix = "fdDirInode", - template = "//pkg/refsvfs2:refs_template", - types = { - "T": "fdDirInode", - }, -) - -go_template_instance( - name = "fd_info_dir_inode_refs", - out = "fd_info_dir_inode_refs.go", - package = "proc", - prefix = "fdInfoDirInode", - template = "//pkg/refsvfs2:refs_template", - types = { - "T": "fdInfoDirInode", - }, -) - -go_template_instance( - name = "subtasks_inode_refs", - out = "subtasks_inode_refs.go", - package = "proc", - prefix = "subtasksInode", - template = "//pkg/refsvfs2:refs_template", - types = { - "T": "subtasksInode", - }, -) - -go_template_instance( - name = "task_inode_refs", - out = "task_inode_refs.go", - package = "proc", - prefix = "taskInode", - template = "//pkg/refsvfs2:refs_template", - types = { - "T": "taskInode", - }, -) - -go_template_instance( - name = "tasks_inode_refs", - out = "tasks_inode_refs.go", - package = "proc", - prefix = "tasksInode", - template = "//pkg/refsvfs2:refs_template", - types = { - "T": "tasksInode", - }, -) - -go_library( - name = "proc", - srcs = [ - "fd_dir_inode_refs.go", - "fd_info_dir_inode_refs.go", - "filesystem.go", - "subtasks.go", - "subtasks_inode_refs.go", - "task.go", - "task_fds.go", - "task_files.go", - "task_inode_refs.go", - "task_net.go", - "tasks.go", - "tasks_files.go", - "tasks_inode_refs.go", - "tasks_sys.go", - "yama.go", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/errors/linuxerr", - "//pkg/hostarch", - "//pkg/log", - "//pkg/refs", - "//pkg/refsvfs2", - "//pkg/safemem", - "//pkg/sentry/fs/lock", - "//pkg/sentry/fsbridge", - "//pkg/sentry/fsimpl/kernfs", - "//pkg/sentry/inet", - "//pkg/sentry/kernel", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/kernel/time", - "//pkg/sentry/limits", - "//pkg/sentry/mm", - "//pkg/sentry/socket", - "//pkg/sentry/socket/unix", - "//pkg/sentry/socket/unix/transport", - "//pkg/sentry/usage", - "//pkg/sentry/vfs", - "//pkg/sync", - "//pkg/tcpip/header", - "//pkg/tcpip/network/ipv4", - "//pkg/usermem", - ], -) - -go_test( - name = "proc_test", - size = "small", - srcs = [ - "tasks_sys_test.go", - "tasks_test.go", - ], - library = ":proc", - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/errors/linuxerr", - "//pkg/fspath", - "//pkg/sentry/contexttest", - "//pkg/sentry/fsimpl/testutil", - "//pkg/sentry/fsimpl/tmpfs", - "//pkg/sentry/inet", - "//pkg/sentry/kernel", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/vfs", - "//pkg/usermem", - ], -) diff --git a/pkg/sentry/fsimpl/proc/fd_dir_inode_refs.go b/pkg/sentry/fsimpl/proc/fd_dir_inode_refs.go new file mode 100644 index 000000000..61138f055 --- /dev/null +++ b/pkg/sentry/fsimpl/proc/fd_dir_inode_refs.go @@ -0,0 +1,140 @@ +package proc + +import ( + "fmt" + "sync/atomic" + + "gvisor.dev/gvisor/pkg/refsvfs2" +) + +// enableLogging indicates whether reference-related events should be logged (with +// stack traces). This is false by default and should only be set to true for +// debugging purposes, as it can generate an extremely large amount of output +// and drastically degrade performance. +const fdDirInodeenableLogging = false + +// obj is used to customize logging. Note that we use a pointer to T so that +// we do not copy the entire object when passed as a format parameter. +var fdDirInodeobj *fdDirInode + +// Refs implements refs.RefCounter. It keeps a reference count using atomic +// operations and calls the destructor when the count reaches zero. +// +// NOTE: Do not introduce additional fields to the Refs struct. It is used by +// many filesystem objects, and we want to keep it as small as possible (i.e., +// the same size as using an int64 directly) to avoid taking up extra cache +// space. In general, this template should not be extended at the cost of +// performance. If it does not offer enough flexibility for a particular object +// (example: b/187877947), we should implement the RefCounter/CheckedObject +// interfaces manually. +// +// +stateify savable +type fdDirInodeRefs struct { + // refCount is composed of two fields: + // + // [32-bit speculative references]:[32-bit real references] + // + // Speculative references are used for TryIncRef, to avoid a CompareAndSwap + // loop. See IncRef, DecRef and TryIncRef for details of how these fields are + // used. + refCount int64 +} + +// InitRefs initializes r with one reference and, if enabled, activates leak +// checking. +func (r *fdDirInodeRefs) InitRefs() { + atomic.StoreInt64(&r.refCount, 1) + refsvfs2.Register(r) +} + +// RefType implements refsvfs2.CheckedObject.RefType. +func (r *fdDirInodeRefs) RefType() string { + return fmt.Sprintf("%T", fdDirInodeobj)[1:] +} + +// LeakMessage implements refsvfs2.CheckedObject.LeakMessage. +func (r *fdDirInodeRefs) LeakMessage() string { + return fmt.Sprintf("[%s %p] reference count of %d instead of 0", r.RefType(), r, r.ReadRefs()) +} + +// LogRefs implements refsvfs2.CheckedObject.LogRefs. +func (r *fdDirInodeRefs) LogRefs() bool { + return fdDirInodeenableLogging +} + +// ReadRefs returns the current number of references. The returned count is +// inherently racy and is unsafe to use without external synchronization. +func (r *fdDirInodeRefs) ReadRefs() int64 { + return atomic.LoadInt64(&r.refCount) +} + +// IncRef implements refs.RefCounter.IncRef. +// +//go:nosplit +func (r *fdDirInodeRefs) IncRef() { + v := atomic.AddInt64(&r.refCount, 1) + if fdDirInodeenableLogging { + refsvfs2.LogIncRef(r, v) + } + if v <= 1 { + panic(fmt.Sprintf("Incrementing non-positive count %p on %s", r, r.RefType())) + } +} + +// TryIncRef implements refs.TryRefCounter.TryIncRef. +// +// To do this safely without a loop, a speculative reference is first acquired +// on the object. This allows multiple concurrent TryIncRef calls to distinguish +// other TryIncRef calls from genuine references held. +// +//go:nosplit +func (r *fdDirInodeRefs) TryIncRef() bool { + const speculativeRef = 1 << 32 + if v := atomic.AddInt64(&r.refCount, speculativeRef); int32(v) == 0 { + + atomic.AddInt64(&r.refCount, -speculativeRef) + return false + } + + v := atomic.AddInt64(&r.refCount, -speculativeRef+1) + if fdDirInodeenableLogging { + refsvfs2.LogTryIncRef(r, v) + } + return true +} + +// DecRef implements refs.RefCounter.DecRef. +// +// Note that speculative references are counted here. Since they were added +// prior to real references reaching zero, they will successfully convert to +// real references. In other words, we see speculative references only in the +// following case: +// +// A: TryIncRef [speculative increase => sees non-negative references] +// B: DecRef [real decrease] +// A: TryIncRef [transform speculative to real] +// +//go:nosplit +func (r *fdDirInodeRefs) DecRef(destroy func()) { + v := atomic.AddInt64(&r.refCount, -1) + if fdDirInodeenableLogging { + refsvfs2.LogDecRef(r, v) + } + switch { + case v < 0: + panic(fmt.Sprintf("Decrementing non-positive ref count %p, owned by %s", r, r.RefType())) + + case v == 0: + refsvfs2.Unregister(r) + + if destroy != nil { + destroy() + } + } +} + +func (r *fdDirInodeRefs) afterLoad() { + if r.ReadRefs() > 0 { + refsvfs2.Register(r) + } +} diff --git a/pkg/sentry/fsimpl/proc/fd_info_dir_inode_refs.go b/pkg/sentry/fsimpl/proc/fd_info_dir_inode_refs.go new file mode 100644 index 000000000..53fb0910a --- /dev/null +++ b/pkg/sentry/fsimpl/proc/fd_info_dir_inode_refs.go @@ -0,0 +1,140 @@ +package proc + +import ( + "fmt" + "sync/atomic" + + "gvisor.dev/gvisor/pkg/refsvfs2" +) + +// enableLogging indicates whether reference-related events should be logged (with +// stack traces). This is false by default and should only be set to true for +// debugging purposes, as it can generate an extremely large amount of output +// and drastically degrade performance. +const fdInfoDirInodeenableLogging = false + +// obj is used to customize logging. Note that we use a pointer to T so that +// we do not copy the entire object when passed as a format parameter. +var fdInfoDirInodeobj *fdInfoDirInode + +// Refs implements refs.RefCounter. It keeps a reference count using atomic +// operations and calls the destructor when the count reaches zero. +// +// NOTE: Do not introduce additional fields to the Refs struct. It is used by +// many filesystem objects, and we want to keep it as small as possible (i.e., +// the same size as using an int64 directly) to avoid taking up extra cache +// space. In general, this template should not be extended at the cost of +// performance. If it does not offer enough flexibility for a particular object +// (example: b/187877947), we should implement the RefCounter/CheckedObject +// interfaces manually. +// +// +stateify savable +type fdInfoDirInodeRefs struct { + // refCount is composed of two fields: + // + // [32-bit speculative references]:[32-bit real references] + // + // Speculative references are used for TryIncRef, to avoid a CompareAndSwap + // loop. See IncRef, DecRef and TryIncRef for details of how these fields are + // used. + refCount int64 +} + +// InitRefs initializes r with one reference and, if enabled, activates leak +// checking. +func (r *fdInfoDirInodeRefs) InitRefs() { + atomic.StoreInt64(&r.refCount, 1) + refsvfs2.Register(r) +} + +// RefType implements refsvfs2.CheckedObject.RefType. +func (r *fdInfoDirInodeRefs) RefType() string { + return fmt.Sprintf("%T", fdInfoDirInodeobj)[1:] +} + +// LeakMessage implements refsvfs2.CheckedObject.LeakMessage. +func (r *fdInfoDirInodeRefs) LeakMessage() string { + return fmt.Sprintf("[%s %p] reference count of %d instead of 0", r.RefType(), r, r.ReadRefs()) +} + +// LogRefs implements refsvfs2.CheckedObject.LogRefs. +func (r *fdInfoDirInodeRefs) LogRefs() bool { + return fdInfoDirInodeenableLogging +} + +// ReadRefs returns the current number of references. The returned count is +// inherently racy and is unsafe to use without external synchronization. +func (r *fdInfoDirInodeRefs) ReadRefs() int64 { + return atomic.LoadInt64(&r.refCount) +} + +// IncRef implements refs.RefCounter.IncRef. +// +//go:nosplit +func (r *fdInfoDirInodeRefs) IncRef() { + v := atomic.AddInt64(&r.refCount, 1) + if fdInfoDirInodeenableLogging { + refsvfs2.LogIncRef(r, v) + } + if v <= 1 { + panic(fmt.Sprintf("Incrementing non-positive count %p on %s", r, r.RefType())) + } +} + +// TryIncRef implements refs.TryRefCounter.TryIncRef. +// +// To do this safely without a loop, a speculative reference is first acquired +// on the object. This allows multiple concurrent TryIncRef calls to distinguish +// other TryIncRef calls from genuine references held. +// +//go:nosplit +func (r *fdInfoDirInodeRefs) TryIncRef() bool { + const speculativeRef = 1 << 32 + if v := atomic.AddInt64(&r.refCount, speculativeRef); int32(v) == 0 { + + atomic.AddInt64(&r.refCount, -speculativeRef) + return false + } + + v := atomic.AddInt64(&r.refCount, -speculativeRef+1) + if fdInfoDirInodeenableLogging { + refsvfs2.LogTryIncRef(r, v) + } + return true +} + +// DecRef implements refs.RefCounter.DecRef. +// +// Note that speculative references are counted here. Since they were added +// prior to real references reaching zero, they will successfully convert to +// real references. In other words, we see speculative references only in the +// following case: +// +// A: TryIncRef [speculative increase => sees non-negative references] +// B: DecRef [real decrease] +// A: TryIncRef [transform speculative to real] +// +//go:nosplit +func (r *fdInfoDirInodeRefs) DecRef(destroy func()) { + v := atomic.AddInt64(&r.refCount, -1) + if fdInfoDirInodeenableLogging { + refsvfs2.LogDecRef(r, v) + } + switch { + case v < 0: + panic(fmt.Sprintf("Decrementing non-positive ref count %p, owned by %s", r, r.RefType())) + + case v == 0: + refsvfs2.Unregister(r) + + if destroy != nil { + destroy() + } + } +} + +func (r *fdInfoDirInodeRefs) afterLoad() { + if r.ReadRefs() > 0 { + refsvfs2.Register(r) + } +} diff --git a/pkg/sentry/fsimpl/proc/proc_state_autogen.go b/pkg/sentry/fsimpl/proc/proc_state_autogen.go new file mode 100644 index 000000000..e32e7671c --- /dev/null +++ b/pkg/sentry/fsimpl/proc/proc_state_autogen.go @@ -0,0 +1,2454 @@ +// automatically generated by stateify. + +package proc + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (r *fdDirInodeRefs) StateTypeName() string { + return "pkg/sentry/fsimpl/proc.fdDirInodeRefs" +} + +func (r *fdDirInodeRefs) StateFields() []string { + return []string{ + "refCount", + } +} + +func (r *fdDirInodeRefs) beforeSave() {} + +// +checklocksignore +func (r *fdDirInodeRefs) StateSave(stateSinkObject state.Sink) { + r.beforeSave() + stateSinkObject.Save(0, &r.refCount) +} + +// +checklocksignore +func (r *fdDirInodeRefs) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &r.refCount) + stateSourceObject.AfterLoad(r.afterLoad) +} + +func (r *fdInfoDirInodeRefs) StateTypeName() string { + return "pkg/sentry/fsimpl/proc.fdInfoDirInodeRefs" +} + +func (r *fdInfoDirInodeRefs) StateFields() []string { + return []string{ + "refCount", + } +} + +func (r *fdInfoDirInodeRefs) beforeSave() {} + +// +checklocksignore +func (r *fdInfoDirInodeRefs) StateSave(stateSinkObject state.Sink) { + r.beforeSave() + stateSinkObject.Save(0, &r.refCount) +} + +// +checklocksignore +func (r *fdInfoDirInodeRefs) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &r.refCount) + stateSourceObject.AfterLoad(r.afterLoad) +} + +func (ft *FilesystemType) StateTypeName() string { + return "pkg/sentry/fsimpl/proc.FilesystemType" +} + +func (ft *FilesystemType) StateFields() []string { + return []string{} +} + +func (ft *FilesystemType) beforeSave() {} + +// +checklocksignore +func (ft *FilesystemType) StateSave(stateSinkObject state.Sink) { + ft.beforeSave() +} + +func (ft *FilesystemType) afterLoad() {} + +// +checklocksignore +func (ft *FilesystemType) StateLoad(stateSourceObject state.Source) { +} + +func (fs *filesystem) StateTypeName() string { + return "pkg/sentry/fsimpl/proc.filesystem" +} + +func (fs *filesystem) StateFields() []string { + return []string{ + "Filesystem", + "devMinor", + } +} + +func (fs *filesystem) beforeSave() {} + +// +checklocksignore +func (fs *filesystem) StateSave(stateSinkObject state.Sink) { + fs.beforeSave() + stateSinkObject.Save(0, &fs.Filesystem) + stateSinkObject.Save(1, &fs.devMinor) +} + +func (fs *filesystem) afterLoad() {} + +// +checklocksignore +func (fs *filesystem) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &fs.Filesystem) + stateSourceObject.Load(1, &fs.devMinor) +} + +func (s *staticFile) StateTypeName() string { + return "pkg/sentry/fsimpl/proc.staticFile" +} + +func (s *staticFile) StateFields() []string { + return []string{ + "DynamicBytesFile", + "StaticData", + } +} + +func (s *staticFile) beforeSave() {} + +// +checklocksignore +func (s *staticFile) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + stateSinkObject.Save(0, &s.DynamicBytesFile) + stateSinkObject.Save(1, &s.StaticData) +} + +func (s *staticFile) afterLoad() {} + +// +checklocksignore +func (s *staticFile) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.DynamicBytesFile) + stateSourceObject.Load(1, &s.StaticData) +} + +func (i *InternalData) StateTypeName() string { + return "pkg/sentry/fsimpl/proc.InternalData" +} + +func (i *InternalData) StateFields() []string { + return []string{ + "Cgroups", + } +} + +func (i *InternalData) beforeSave() {} + +// +checklocksignore +func (i *InternalData) StateSave(stateSinkObject state.Sink) { + i.beforeSave() + stateSinkObject.Save(0, &i.Cgroups) +} + +func (i *InternalData) afterLoad() {} + +// +checklocksignore +func (i *InternalData) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &i.Cgroups) +} + +func (i *implStatFS) StateTypeName() string { + return "pkg/sentry/fsimpl/proc.implStatFS" +} + +func (i *implStatFS) StateFields() []string { + return []string{} +} + +func (i *implStatFS) beforeSave() {} + +// +checklocksignore +func (i *implStatFS) StateSave(stateSinkObject state.Sink) { + i.beforeSave() +} + +func (i *implStatFS) afterLoad() {} + +// +checklocksignore +func (i *implStatFS) StateLoad(stateSourceObject state.Source) { +} + +func (i *subtasksInode) StateTypeName() string { + return "pkg/sentry/fsimpl/proc.subtasksInode" +} + +func (i *subtasksInode) StateFields() []string { + return []string{ + "implStatFS", + "InodeAlwaysValid", + "InodeAttrs", + "InodeDirectoryNoNewChildren", + "InodeNotSymlink", + "InodeTemporary", + "OrderedChildren", + "subtasksInodeRefs", + "locks", + "fs", + "task", + "pidns", + "cgroupControllers", + } +} + +func (i *subtasksInode) beforeSave() {} + +// +checklocksignore +func (i *subtasksInode) StateSave(stateSinkObject state.Sink) { + i.beforeSave() + stateSinkObject.Save(0, &i.implStatFS) + stateSinkObject.Save(1, &i.InodeAlwaysValid) + stateSinkObject.Save(2, &i.InodeAttrs) + stateSinkObject.Save(3, &i.InodeDirectoryNoNewChildren) + stateSinkObject.Save(4, &i.InodeNotSymlink) + stateSinkObject.Save(5, &i.InodeTemporary) + stateSinkObject.Save(6, &i.OrderedChildren) + stateSinkObject.Save(7, &i.subtasksInodeRefs) + stateSinkObject.Save(8, &i.locks) + stateSinkObject.Save(9, &i.fs) + stateSinkObject.Save(10, &i.task) + stateSinkObject.Save(11, &i.pidns) + stateSinkObject.Save(12, &i.cgroupControllers) +} + +func (i *subtasksInode) afterLoad() {} + +// +checklocksignore +func (i *subtasksInode) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &i.implStatFS) + stateSourceObject.Load(1, &i.InodeAlwaysValid) + stateSourceObject.Load(2, &i.InodeAttrs) + stateSourceObject.Load(3, &i.InodeDirectoryNoNewChildren) + stateSourceObject.Load(4, &i.InodeNotSymlink) + stateSourceObject.Load(5, &i.InodeTemporary) + stateSourceObject.Load(6, &i.OrderedChildren) + stateSourceObject.Load(7, &i.subtasksInodeRefs) + stateSourceObject.Load(8, &i.locks) + stateSourceObject.Load(9, &i.fs) + stateSourceObject.Load(10, &i.task) + stateSourceObject.Load(11, &i.pidns) + stateSourceObject.Load(12, &i.cgroupControllers) +} + +func (fd *subtasksFD) StateTypeName() string { + return "pkg/sentry/fsimpl/proc.subtasksFD" +} + +func (fd *subtasksFD) StateFields() []string { + return []string{ + "GenericDirectoryFD", + "task", + } +} + +func (fd *subtasksFD) beforeSave() {} + +// +checklocksignore +func (fd *subtasksFD) StateSave(stateSinkObject state.Sink) { + fd.beforeSave() + stateSinkObject.Save(0, &fd.GenericDirectoryFD) + stateSinkObject.Save(1, &fd.task) +} + +func (fd *subtasksFD) afterLoad() {} + +// +checklocksignore +func (fd *subtasksFD) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &fd.GenericDirectoryFD) + stateSourceObject.Load(1, &fd.task) +} + +func (r *subtasksInodeRefs) StateTypeName() string { + return "pkg/sentry/fsimpl/proc.subtasksInodeRefs" +} + +func (r *subtasksInodeRefs) StateFields() []string { + return []string{ + "refCount", + } +} + +func (r *subtasksInodeRefs) beforeSave() {} + +// +checklocksignore +func (r *subtasksInodeRefs) StateSave(stateSinkObject state.Sink) { + r.beforeSave() + stateSinkObject.Save(0, &r.refCount) +} + +// +checklocksignore +func (r *subtasksInodeRefs) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &r.refCount) + stateSourceObject.AfterLoad(r.afterLoad) +} + +func (i *taskInode) StateTypeName() string { + return "pkg/sentry/fsimpl/proc.taskInode" +} + +func (i *taskInode) StateFields() []string { + return []string{ + "implStatFS", + "InodeAttrs", + "InodeDirectoryNoNewChildren", + "InodeNotSymlink", + "InodeTemporary", + "OrderedChildren", + "taskInodeRefs", + "locks", + "task", + } +} + +func (i *taskInode) beforeSave() {} + +// +checklocksignore +func (i *taskInode) StateSave(stateSinkObject state.Sink) { + i.beforeSave() + stateSinkObject.Save(0, &i.implStatFS) + stateSinkObject.Save(1, &i.InodeAttrs) + stateSinkObject.Save(2, &i.InodeDirectoryNoNewChildren) + stateSinkObject.Save(3, &i.InodeNotSymlink) + stateSinkObject.Save(4, &i.InodeTemporary) + stateSinkObject.Save(5, &i.OrderedChildren) + stateSinkObject.Save(6, &i.taskInodeRefs) + stateSinkObject.Save(7, &i.locks) + stateSinkObject.Save(8, &i.task) +} + +func (i *taskInode) afterLoad() {} + +// +checklocksignore +func (i *taskInode) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &i.implStatFS) + stateSourceObject.Load(1, &i.InodeAttrs) + stateSourceObject.Load(2, &i.InodeDirectoryNoNewChildren) + stateSourceObject.Load(3, &i.InodeNotSymlink) + stateSourceObject.Load(4, &i.InodeTemporary) + stateSourceObject.Load(5, &i.OrderedChildren) + stateSourceObject.Load(6, &i.taskInodeRefs) + stateSourceObject.Load(7, &i.locks) + stateSourceObject.Load(8, &i.task) +} + +func (i *taskOwnedInode) StateTypeName() string { + return "pkg/sentry/fsimpl/proc.taskOwnedInode" +} + +func (i *taskOwnedInode) StateFields() []string { + return []string{ + "Inode", + "owner", + } +} + +func (i *taskOwnedInode) beforeSave() {} + +// +checklocksignore +func (i *taskOwnedInode) StateSave(stateSinkObject state.Sink) { + i.beforeSave() + stateSinkObject.Save(0, &i.Inode) + stateSinkObject.Save(1, &i.owner) +} + +func (i *taskOwnedInode) afterLoad() {} + +// +checklocksignore +func (i *taskOwnedInode) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &i.Inode) + stateSourceObject.Load(1, &i.owner) +} + +func (i *fdDir) StateTypeName() string { + return "pkg/sentry/fsimpl/proc.fdDir" +} + +func (i *fdDir) StateFields() []string { + return []string{ + "locks", + "fs", + "task", + "produceSymlink", + } +} + +func (i *fdDir) beforeSave() {} + +// +checklocksignore +func (i *fdDir) StateSave(stateSinkObject state.Sink) { + i.beforeSave() + stateSinkObject.Save(0, &i.locks) + stateSinkObject.Save(1, &i.fs) + stateSinkObject.Save(2, &i.task) + stateSinkObject.Save(3, &i.produceSymlink) +} + +func (i *fdDir) afterLoad() {} + +// +checklocksignore +func (i *fdDir) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &i.locks) + stateSourceObject.Load(1, &i.fs) + stateSourceObject.Load(2, &i.task) + stateSourceObject.Load(3, &i.produceSymlink) +} + +func (i *fdDirInode) StateTypeName() string { + return "pkg/sentry/fsimpl/proc.fdDirInode" +} + +func (i *fdDirInode) StateFields() []string { + return []string{ + "fdDir", + "fdDirInodeRefs", + "implStatFS", + "InodeAlwaysValid", + "InodeAttrs", + "InodeDirectoryNoNewChildren", + "InodeNotSymlink", + "InodeTemporary", + "OrderedChildren", + } +} + +func (i *fdDirInode) beforeSave() {} + +// +checklocksignore +func (i *fdDirInode) StateSave(stateSinkObject state.Sink) { + i.beforeSave() + stateSinkObject.Save(0, &i.fdDir) + stateSinkObject.Save(1, &i.fdDirInodeRefs) + stateSinkObject.Save(2, &i.implStatFS) + stateSinkObject.Save(3, &i.InodeAlwaysValid) + stateSinkObject.Save(4, &i.InodeAttrs) + stateSinkObject.Save(5, &i.InodeDirectoryNoNewChildren) + stateSinkObject.Save(6, &i.InodeNotSymlink) + stateSinkObject.Save(7, &i.InodeTemporary) + stateSinkObject.Save(8, &i.OrderedChildren) +} + +func (i *fdDirInode) afterLoad() {} + +// +checklocksignore +func (i *fdDirInode) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &i.fdDir) + stateSourceObject.Load(1, &i.fdDirInodeRefs) + stateSourceObject.Load(2, &i.implStatFS) + stateSourceObject.Load(3, &i.InodeAlwaysValid) + stateSourceObject.Load(4, &i.InodeAttrs) + stateSourceObject.Load(5, &i.InodeDirectoryNoNewChildren) + stateSourceObject.Load(6, &i.InodeNotSymlink) + stateSourceObject.Load(7, &i.InodeTemporary) + stateSourceObject.Load(8, &i.OrderedChildren) +} + +func (s *fdSymlink) StateTypeName() string { + return "pkg/sentry/fsimpl/proc.fdSymlink" +} + +func (s *fdSymlink) StateFields() []string { + return []string{ + "implStatFS", + "InodeAttrs", + "InodeNoopRefCount", + "InodeSymlink", + "fs", + "task", + "fd", + } +} + +func (s *fdSymlink) beforeSave() {} + +// +checklocksignore +func (s *fdSymlink) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + stateSinkObject.Save(0, &s.implStatFS) + stateSinkObject.Save(1, &s.InodeAttrs) + stateSinkObject.Save(2, &s.InodeNoopRefCount) + stateSinkObject.Save(3, &s.InodeSymlink) + stateSinkObject.Save(4, &s.fs) + stateSinkObject.Save(5, &s.task) + stateSinkObject.Save(6, &s.fd) +} + +func (s *fdSymlink) afterLoad() {} + +// +checklocksignore +func (s *fdSymlink) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.implStatFS) + stateSourceObject.Load(1, &s.InodeAttrs) + stateSourceObject.Load(2, &s.InodeNoopRefCount) + stateSourceObject.Load(3, &s.InodeSymlink) + stateSourceObject.Load(4, &s.fs) + stateSourceObject.Load(5, &s.task) + stateSourceObject.Load(6, &s.fd) +} + +func (i *fdInfoDirInode) StateTypeName() string { + return "pkg/sentry/fsimpl/proc.fdInfoDirInode" +} + +func (i *fdInfoDirInode) StateFields() []string { + return []string{ + "fdDir", + "fdInfoDirInodeRefs", + "implStatFS", + "InodeAlwaysValid", + "InodeAttrs", + "InodeDirectoryNoNewChildren", + "InodeNotSymlink", + "InodeTemporary", + "OrderedChildren", + } +} + +func (i *fdInfoDirInode) beforeSave() {} + +// +checklocksignore +func (i *fdInfoDirInode) StateSave(stateSinkObject state.Sink) { + i.beforeSave() + stateSinkObject.Save(0, &i.fdDir) + stateSinkObject.Save(1, &i.fdInfoDirInodeRefs) + stateSinkObject.Save(2, &i.implStatFS) + stateSinkObject.Save(3, &i.InodeAlwaysValid) + stateSinkObject.Save(4, &i.InodeAttrs) + stateSinkObject.Save(5, &i.InodeDirectoryNoNewChildren) + stateSinkObject.Save(6, &i.InodeNotSymlink) + stateSinkObject.Save(7, &i.InodeTemporary) + stateSinkObject.Save(8, &i.OrderedChildren) +} + +func (i *fdInfoDirInode) afterLoad() {} + +// +checklocksignore +func (i *fdInfoDirInode) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &i.fdDir) + stateSourceObject.Load(1, &i.fdInfoDirInodeRefs) + stateSourceObject.Load(2, &i.implStatFS) + stateSourceObject.Load(3, &i.InodeAlwaysValid) + stateSourceObject.Load(4, &i.InodeAttrs) + stateSourceObject.Load(5, &i.InodeDirectoryNoNewChildren) + stateSourceObject.Load(6, &i.InodeNotSymlink) + stateSourceObject.Load(7, &i.InodeTemporary) + stateSourceObject.Load(8, &i.OrderedChildren) +} + +func (d *fdInfoData) StateTypeName() string { + return "pkg/sentry/fsimpl/proc.fdInfoData" +} + +func (d *fdInfoData) StateFields() []string { + return []string{ + "DynamicBytesFile", + "fs", + "task", + "fd", + } +} + +func (d *fdInfoData) beforeSave() {} + +// +checklocksignore +func (d *fdInfoData) StateSave(stateSinkObject state.Sink) { + d.beforeSave() + stateSinkObject.Save(0, &d.DynamicBytesFile) + stateSinkObject.Save(1, &d.fs) + stateSinkObject.Save(2, &d.task) + stateSinkObject.Save(3, &d.fd) +} + +func (d *fdInfoData) afterLoad() {} + +// +checklocksignore +func (d *fdInfoData) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &d.DynamicBytesFile) + stateSourceObject.Load(1, &d.fs) + stateSourceObject.Load(2, &d.task) + stateSourceObject.Load(3, &d.fd) +} + +func (d *auxvData) StateTypeName() string { + return "pkg/sentry/fsimpl/proc.auxvData" +} + +func (d *auxvData) StateFields() []string { + return []string{ + "DynamicBytesFile", + "task", + } +} + +func (d *auxvData) beforeSave() {} + +// +checklocksignore +func (d *auxvData) StateSave(stateSinkObject state.Sink) { + d.beforeSave() + stateSinkObject.Save(0, &d.DynamicBytesFile) + stateSinkObject.Save(1, &d.task) +} + +func (d *auxvData) afterLoad() {} + +// +checklocksignore +func (d *auxvData) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &d.DynamicBytesFile) + stateSourceObject.Load(1, &d.task) +} + +func (d *cmdlineData) StateTypeName() string { + return "pkg/sentry/fsimpl/proc.cmdlineData" +} + +func (d *cmdlineData) StateFields() []string { + return []string{ + "DynamicBytesFile", + "task", + "arg", + } +} + +func (d *cmdlineData) beforeSave() {} + +// +checklocksignore +func (d *cmdlineData) StateSave(stateSinkObject state.Sink) { + d.beforeSave() + stateSinkObject.Save(0, &d.DynamicBytesFile) + stateSinkObject.Save(1, &d.task) + stateSinkObject.Save(2, &d.arg) +} + +func (d *cmdlineData) afterLoad() {} + +// +checklocksignore +func (d *cmdlineData) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &d.DynamicBytesFile) + stateSourceObject.Load(1, &d.task) + stateSourceObject.Load(2, &d.arg) +} + +func (i *commInode) StateTypeName() string { + return "pkg/sentry/fsimpl/proc.commInode" +} + +func (i *commInode) StateFields() []string { + return []string{ + "DynamicBytesFile", + "task", + } +} + +func (i *commInode) beforeSave() {} + +// +checklocksignore +func (i *commInode) StateSave(stateSinkObject state.Sink) { + i.beforeSave() + stateSinkObject.Save(0, &i.DynamicBytesFile) + stateSinkObject.Save(1, &i.task) +} + +func (i *commInode) afterLoad() {} + +// +checklocksignore +func (i *commInode) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &i.DynamicBytesFile) + stateSourceObject.Load(1, &i.task) +} + +func (d *commData) StateTypeName() string { + return "pkg/sentry/fsimpl/proc.commData" +} + +func (d *commData) StateFields() []string { + return []string{ + "DynamicBytesFile", + "task", + } +} + +func (d *commData) beforeSave() {} + +// +checklocksignore +func (d *commData) StateSave(stateSinkObject state.Sink) { + d.beforeSave() + stateSinkObject.Save(0, &d.DynamicBytesFile) + stateSinkObject.Save(1, &d.task) +} + +func (d *commData) afterLoad() {} + +// +checklocksignore +func (d *commData) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &d.DynamicBytesFile) + stateSourceObject.Load(1, &d.task) +} + +func (d *idMapData) StateTypeName() string { + return "pkg/sentry/fsimpl/proc.idMapData" +} + +func (d *idMapData) StateFields() []string { + return []string{ + "DynamicBytesFile", + "task", + "gids", + } +} + +func (d *idMapData) beforeSave() {} + +// +checklocksignore +func (d *idMapData) StateSave(stateSinkObject state.Sink) { + d.beforeSave() + stateSinkObject.Save(0, &d.DynamicBytesFile) + stateSinkObject.Save(1, &d.task) + stateSinkObject.Save(2, &d.gids) +} + +func (d *idMapData) afterLoad() {} + +// +checklocksignore +func (d *idMapData) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &d.DynamicBytesFile) + stateSourceObject.Load(1, &d.task) + stateSourceObject.Load(2, &d.gids) +} + +func (f *memInode) StateTypeName() string { + return "pkg/sentry/fsimpl/proc.memInode" +} + +func (f *memInode) StateFields() []string { + return []string{ + "InodeAttrs", + "InodeNoStatFS", + "InodeNoopRefCount", + "InodeNotDirectory", + "InodeNotSymlink", + "task", + "locks", + } +} + +func (f *memInode) beforeSave() {} + +// +checklocksignore +func (f *memInode) StateSave(stateSinkObject state.Sink) { + f.beforeSave() + stateSinkObject.Save(0, &f.InodeAttrs) + stateSinkObject.Save(1, &f.InodeNoStatFS) + stateSinkObject.Save(2, &f.InodeNoopRefCount) + stateSinkObject.Save(3, &f.InodeNotDirectory) + stateSinkObject.Save(4, &f.InodeNotSymlink) + stateSinkObject.Save(5, &f.task) + stateSinkObject.Save(6, &f.locks) +} + +func (f *memInode) afterLoad() {} + +// +checklocksignore +func (f *memInode) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &f.InodeAttrs) + stateSourceObject.Load(1, &f.InodeNoStatFS) + stateSourceObject.Load(2, &f.InodeNoopRefCount) + stateSourceObject.Load(3, &f.InodeNotDirectory) + stateSourceObject.Load(4, &f.InodeNotSymlink) + stateSourceObject.Load(5, &f.task) + stateSourceObject.Load(6, &f.locks) +} + +func (fd *memFD) StateTypeName() string { + return "pkg/sentry/fsimpl/proc.memFD" +} + +func (fd *memFD) StateFields() []string { + return []string{ + "vfsfd", + "FileDescriptionDefaultImpl", + "LockFD", + "inode", + "offset", + } +} + +func (fd *memFD) beforeSave() {} + +// +checklocksignore +func (fd *memFD) StateSave(stateSinkObject state.Sink) { + fd.beforeSave() + stateSinkObject.Save(0, &fd.vfsfd) + stateSinkObject.Save(1, &fd.FileDescriptionDefaultImpl) + stateSinkObject.Save(2, &fd.LockFD) + stateSinkObject.Save(3, &fd.inode) + stateSinkObject.Save(4, &fd.offset) +} + +func (fd *memFD) afterLoad() {} + +// +checklocksignore +func (fd *memFD) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &fd.vfsfd) + stateSourceObject.Load(1, &fd.FileDescriptionDefaultImpl) + stateSourceObject.Load(2, &fd.LockFD) + stateSourceObject.Load(3, &fd.inode) + stateSourceObject.Load(4, &fd.offset) +} + +func (d *mapsData) StateTypeName() string { + return "pkg/sentry/fsimpl/proc.mapsData" +} + +func (d *mapsData) StateFields() []string { + return []string{ + "DynamicBytesFile", + "task", + } +} + +func (d *mapsData) beforeSave() {} + +// +checklocksignore +func (d *mapsData) StateSave(stateSinkObject state.Sink) { + d.beforeSave() + stateSinkObject.Save(0, &d.DynamicBytesFile) + stateSinkObject.Save(1, &d.task) +} + +func (d *mapsData) afterLoad() {} + +// +checklocksignore +func (d *mapsData) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &d.DynamicBytesFile) + stateSourceObject.Load(1, &d.task) +} + +func (d *smapsData) StateTypeName() string { + return "pkg/sentry/fsimpl/proc.smapsData" +} + +func (d *smapsData) StateFields() []string { + return []string{ + "DynamicBytesFile", + "task", + } +} + +func (d *smapsData) beforeSave() {} + +// +checklocksignore +func (d *smapsData) StateSave(stateSinkObject state.Sink) { + d.beforeSave() + stateSinkObject.Save(0, &d.DynamicBytesFile) + stateSinkObject.Save(1, &d.task) +} + +func (d *smapsData) afterLoad() {} + +// +checklocksignore +func (d *smapsData) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &d.DynamicBytesFile) + stateSourceObject.Load(1, &d.task) +} + +func (s *taskStatData) StateTypeName() string { + return "pkg/sentry/fsimpl/proc.taskStatData" +} + +func (s *taskStatData) StateFields() []string { + return []string{ + "DynamicBytesFile", + "task", + "tgstats", + "pidns", + } +} + +func (s *taskStatData) beforeSave() {} + +// +checklocksignore +func (s *taskStatData) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + stateSinkObject.Save(0, &s.DynamicBytesFile) + stateSinkObject.Save(1, &s.task) + stateSinkObject.Save(2, &s.tgstats) + stateSinkObject.Save(3, &s.pidns) +} + +func (s *taskStatData) afterLoad() {} + +// +checklocksignore +func (s *taskStatData) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.DynamicBytesFile) + stateSourceObject.Load(1, &s.task) + stateSourceObject.Load(2, &s.tgstats) + stateSourceObject.Load(3, &s.pidns) +} + +func (s *statmData) StateTypeName() string { + return "pkg/sentry/fsimpl/proc.statmData" +} + +func (s *statmData) StateFields() []string { + return []string{ + "DynamicBytesFile", + "task", + } +} + +func (s *statmData) beforeSave() {} + +// +checklocksignore +func (s *statmData) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + stateSinkObject.Save(0, &s.DynamicBytesFile) + stateSinkObject.Save(1, &s.task) +} + +func (s *statmData) afterLoad() {} + +// +checklocksignore +func (s *statmData) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.DynamicBytesFile) + stateSourceObject.Load(1, &s.task) +} + +func (s *statusInode) StateTypeName() string { + return "pkg/sentry/fsimpl/proc.statusInode" +} + +func (s *statusInode) StateFields() []string { + return []string{ + "InodeAttrs", + "InodeNoStatFS", + "InodeNoopRefCount", + "InodeNotDirectory", + "InodeNotSymlink", + "task", + "pidns", + "locks", + } +} + +func (s *statusInode) beforeSave() {} + +// +checklocksignore +func (s *statusInode) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + stateSinkObject.Save(0, &s.InodeAttrs) + stateSinkObject.Save(1, &s.InodeNoStatFS) + stateSinkObject.Save(2, &s.InodeNoopRefCount) + stateSinkObject.Save(3, &s.InodeNotDirectory) + stateSinkObject.Save(4, &s.InodeNotSymlink) + stateSinkObject.Save(5, &s.task) + stateSinkObject.Save(6, &s.pidns) + stateSinkObject.Save(7, &s.locks) +} + +func (s *statusInode) afterLoad() {} + +// +checklocksignore +func (s *statusInode) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.InodeAttrs) + stateSourceObject.Load(1, &s.InodeNoStatFS) + stateSourceObject.Load(2, &s.InodeNoopRefCount) + stateSourceObject.Load(3, &s.InodeNotDirectory) + stateSourceObject.Load(4, &s.InodeNotSymlink) + stateSourceObject.Load(5, &s.task) + stateSourceObject.Load(6, &s.pidns) + stateSourceObject.Load(7, &s.locks) +} + +func (s *statusFD) StateTypeName() string { + return "pkg/sentry/fsimpl/proc.statusFD" +} + +func (s *statusFD) StateFields() []string { + return []string{ + "statusFDLowerBase", + "DynamicBytesFileDescriptionImpl", + "LockFD", + "vfsfd", + "inode", + "task", + "pidns", + "userns", + } +} + +func (s *statusFD) beforeSave() {} + +// +checklocksignore +func (s *statusFD) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + stateSinkObject.Save(0, &s.statusFDLowerBase) + stateSinkObject.Save(1, &s.DynamicBytesFileDescriptionImpl) + stateSinkObject.Save(2, &s.LockFD) + stateSinkObject.Save(3, &s.vfsfd) + stateSinkObject.Save(4, &s.inode) + stateSinkObject.Save(5, &s.task) + stateSinkObject.Save(6, &s.pidns) + stateSinkObject.Save(7, &s.userns) +} + +func (s *statusFD) afterLoad() {} + +// +checklocksignore +func (s *statusFD) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.statusFDLowerBase) + stateSourceObject.Load(1, &s.DynamicBytesFileDescriptionImpl) + stateSourceObject.Load(2, &s.LockFD) + stateSourceObject.Load(3, &s.vfsfd) + stateSourceObject.Load(4, &s.inode) + stateSourceObject.Load(5, &s.task) + stateSourceObject.Load(6, &s.pidns) + stateSourceObject.Load(7, &s.userns) +} + +func (s *statusFDLowerBase) StateTypeName() string { + return "pkg/sentry/fsimpl/proc.statusFDLowerBase" +} + +func (s *statusFDLowerBase) StateFields() []string { + return []string{ + "FileDescriptionDefaultImpl", + } +} + +func (s *statusFDLowerBase) beforeSave() {} + +// +checklocksignore +func (s *statusFDLowerBase) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + stateSinkObject.Save(0, &s.FileDescriptionDefaultImpl) +} + +func (s *statusFDLowerBase) afterLoad() {} + +// +checklocksignore +func (s *statusFDLowerBase) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.FileDescriptionDefaultImpl) +} + +func (i *ioData) StateTypeName() string { + return "pkg/sentry/fsimpl/proc.ioData" +} + +func (i *ioData) StateFields() []string { + return []string{ + "DynamicBytesFile", + "ioUsage", + } +} + +func (i *ioData) beforeSave() {} + +// +checklocksignore +func (i *ioData) StateSave(stateSinkObject state.Sink) { + i.beforeSave() + stateSinkObject.Save(0, &i.DynamicBytesFile) + stateSinkObject.Save(1, &i.ioUsage) +} + +func (i *ioData) afterLoad() {} + +// +checklocksignore +func (i *ioData) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &i.DynamicBytesFile) + stateSourceObject.Load(1, &i.ioUsage) +} + +func (o *oomScoreAdj) StateTypeName() string { + return "pkg/sentry/fsimpl/proc.oomScoreAdj" +} + +func (o *oomScoreAdj) StateFields() []string { + return []string{ + "DynamicBytesFile", + "task", + } +} + +func (o *oomScoreAdj) beforeSave() {} + +// +checklocksignore +func (o *oomScoreAdj) StateSave(stateSinkObject state.Sink) { + o.beforeSave() + stateSinkObject.Save(0, &o.DynamicBytesFile) + stateSinkObject.Save(1, &o.task) +} + +func (o *oomScoreAdj) afterLoad() {} + +// +checklocksignore +func (o *oomScoreAdj) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &o.DynamicBytesFile) + stateSourceObject.Load(1, &o.task) +} + +func (s *exeSymlink) StateTypeName() string { + return "pkg/sentry/fsimpl/proc.exeSymlink" +} + +func (s *exeSymlink) StateFields() []string { + return []string{ + "implStatFS", + "InodeAttrs", + "InodeNoopRefCount", + "InodeSymlink", + "fs", + "task", + } +} + +func (s *exeSymlink) beforeSave() {} + +// +checklocksignore +func (s *exeSymlink) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + stateSinkObject.Save(0, &s.implStatFS) + stateSinkObject.Save(1, &s.InodeAttrs) + stateSinkObject.Save(2, &s.InodeNoopRefCount) + stateSinkObject.Save(3, &s.InodeSymlink) + stateSinkObject.Save(4, &s.fs) + stateSinkObject.Save(5, &s.task) +} + +func (s *exeSymlink) afterLoad() {} + +// +checklocksignore +func (s *exeSymlink) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.implStatFS) + stateSourceObject.Load(1, &s.InodeAttrs) + stateSourceObject.Load(2, &s.InodeNoopRefCount) + stateSourceObject.Load(3, &s.InodeSymlink) + stateSourceObject.Load(4, &s.fs) + stateSourceObject.Load(5, &s.task) +} + +func (s *cwdSymlink) StateTypeName() string { + return "pkg/sentry/fsimpl/proc.cwdSymlink" +} + +func (s *cwdSymlink) StateFields() []string { + return []string{ + "implStatFS", + "InodeAttrs", + "InodeNoopRefCount", + "InodeSymlink", + "fs", + "task", + } +} + +func (s *cwdSymlink) beforeSave() {} + +// +checklocksignore +func (s *cwdSymlink) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + stateSinkObject.Save(0, &s.implStatFS) + stateSinkObject.Save(1, &s.InodeAttrs) + stateSinkObject.Save(2, &s.InodeNoopRefCount) + stateSinkObject.Save(3, &s.InodeSymlink) + stateSinkObject.Save(4, &s.fs) + stateSinkObject.Save(5, &s.task) +} + +func (s *cwdSymlink) afterLoad() {} + +// +checklocksignore +func (s *cwdSymlink) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.implStatFS) + stateSourceObject.Load(1, &s.InodeAttrs) + stateSourceObject.Load(2, &s.InodeNoopRefCount) + stateSourceObject.Load(3, &s.InodeSymlink) + stateSourceObject.Load(4, &s.fs) + stateSourceObject.Load(5, &s.task) +} + +func (i *mountInfoData) StateTypeName() string { + return "pkg/sentry/fsimpl/proc.mountInfoData" +} + +func (i *mountInfoData) StateFields() []string { + return []string{ + "DynamicBytesFile", + "fs", + "task", + } +} + +func (i *mountInfoData) beforeSave() {} + +// +checklocksignore +func (i *mountInfoData) StateSave(stateSinkObject state.Sink) { + i.beforeSave() + stateSinkObject.Save(0, &i.DynamicBytesFile) + stateSinkObject.Save(1, &i.fs) + stateSinkObject.Save(2, &i.task) +} + +func (i *mountInfoData) afterLoad() {} + +// +checklocksignore +func (i *mountInfoData) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &i.DynamicBytesFile) + stateSourceObject.Load(1, &i.fs) + stateSourceObject.Load(2, &i.task) +} + +func (i *mountsData) StateTypeName() string { + return "pkg/sentry/fsimpl/proc.mountsData" +} + +func (i *mountsData) StateFields() []string { + return []string{ + "DynamicBytesFile", + "fs", + "task", + } +} + +func (i *mountsData) beforeSave() {} + +// +checklocksignore +func (i *mountsData) StateSave(stateSinkObject state.Sink) { + i.beforeSave() + stateSinkObject.Save(0, &i.DynamicBytesFile) + stateSinkObject.Save(1, &i.fs) + stateSinkObject.Save(2, &i.task) +} + +func (i *mountsData) afterLoad() {} + +// +checklocksignore +func (i *mountsData) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &i.DynamicBytesFile) + stateSourceObject.Load(1, &i.fs) + stateSourceObject.Load(2, &i.task) +} + +func (s *namespaceSymlink) StateTypeName() string { + return "pkg/sentry/fsimpl/proc.namespaceSymlink" +} + +func (s *namespaceSymlink) StateFields() []string { + return []string{ + "StaticSymlink", + "task", + } +} + +func (s *namespaceSymlink) beforeSave() {} + +// +checklocksignore +func (s *namespaceSymlink) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + stateSinkObject.Save(0, &s.StaticSymlink) + stateSinkObject.Save(1, &s.task) +} + +func (s *namespaceSymlink) afterLoad() {} + +// +checklocksignore +func (s *namespaceSymlink) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.StaticSymlink) + stateSourceObject.Load(1, &s.task) +} + +func (i *namespaceInode) StateTypeName() string { + return "pkg/sentry/fsimpl/proc.namespaceInode" +} + +func (i *namespaceInode) StateFields() []string { + return []string{ + "implStatFS", + "InodeAttrs", + "InodeNoopRefCount", + "InodeNotDirectory", + "InodeNotSymlink", + "locks", + } +} + +func (i *namespaceInode) beforeSave() {} + +// +checklocksignore +func (i *namespaceInode) StateSave(stateSinkObject state.Sink) { + i.beforeSave() + stateSinkObject.Save(0, &i.implStatFS) + stateSinkObject.Save(1, &i.InodeAttrs) + stateSinkObject.Save(2, &i.InodeNoopRefCount) + stateSinkObject.Save(3, &i.InodeNotDirectory) + stateSinkObject.Save(4, &i.InodeNotSymlink) + stateSinkObject.Save(5, &i.locks) +} + +func (i *namespaceInode) afterLoad() {} + +// +checklocksignore +func (i *namespaceInode) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &i.implStatFS) + stateSourceObject.Load(1, &i.InodeAttrs) + stateSourceObject.Load(2, &i.InodeNoopRefCount) + stateSourceObject.Load(3, &i.InodeNotDirectory) + stateSourceObject.Load(4, &i.InodeNotSymlink) + stateSourceObject.Load(5, &i.locks) +} + +func (fd *namespaceFD) StateTypeName() string { + return "pkg/sentry/fsimpl/proc.namespaceFD" +} + +func (fd *namespaceFD) StateFields() []string { + return []string{ + "FileDescriptionDefaultImpl", + "LockFD", + "vfsfd", + "inode", + } +} + +func (fd *namespaceFD) beforeSave() {} + +// +checklocksignore +func (fd *namespaceFD) StateSave(stateSinkObject state.Sink) { + fd.beforeSave() + stateSinkObject.Save(0, &fd.FileDescriptionDefaultImpl) + stateSinkObject.Save(1, &fd.LockFD) + stateSinkObject.Save(2, &fd.vfsfd) + stateSinkObject.Save(3, &fd.inode) +} + +func (fd *namespaceFD) afterLoad() {} + +// +checklocksignore +func (fd *namespaceFD) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &fd.FileDescriptionDefaultImpl) + stateSourceObject.Load(1, &fd.LockFD) + stateSourceObject.Load(2, &fd.vfsfd) + stateSourceObject.Load(3, &fd.inode) +} + +func (d *taskCgroupData) StateTypeName() string { + return "pkg/sentry/fsimpl/proc.taskCgroupData" +} + +func (d *taskCgroupData) StateFields() []string { + return []string{ + "dynamicBytesFileSetAttr", + "task", + } +} + +func (d *taskCgroupData) beforeSave() {} + +// +checklocksignore +func (d *taskCgroupData) StateSave(stateSinkObject state.Sink) { + d.beforeSave() + stateSinkObject.Save(0, &d.dynamicBytesFileSetAttr) + stateSinkObject.Save(1, &d.task) +} + +func (d *taskCgroupData) afterLoad() {} + +// +checklocksignore +func (d *taskCgroupData) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &d.dynamicBytesFileSetAttr) + stateSourceObject.Load(1, &d.task) +} + +func (r *taskInodeRefs) StateTypeName() string { + return "pkg/sentry/fsimpl/proc.taskInodeRefs" +} + +func (r *taskInodeRefs) StateFields() []string { + return []string{ + "refCount", + } +} + +func (r *taskInodeRefs) beforeSave() {} + +// +checklocksignore +func (r *taskInodeRefs) StateSave(stateSinkObject state.Sink) { + r.beforeSave() + stateSinkObject.Save(0, &r.refCount) +} + +// +checklocksignore +func (r *taskInodeRefs) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &r.refCount) + stateSourceObject.AfterLoad(r.afterLoad) +} + +func (n *ifinet6) StateTypeName() string { + return "pkg/sentry/fsimpl/proc.ifinet6" +} + +func (n *ifinet6) StateFields() []string { + return []string{ + "DynamicBytesFile", + "stack", + } +} + +func (n *ifinet6) beforeSave() {} + +// +checklocksignore +func (n *ifinet6) StateSave(stateSinkObject state.Sink) { + n.beforeSave() + stateSinkObject.Save(0, &n.DynamicBytesFile) + stateSinkObject.Save(1, &n.stack) +} + +func (n *ifinet6) afterLoad() {} + +// +checklocksignore +func (n *ifinet6) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &n.DynamicBytesFile) + stateSourceObject.Load(1, &n.stack) +} + +func (n *netDevData) StateTypeName() string { + return "pkg/sentry/fsimpl/proc.netDevData" +} + +func (n *netDevData) StateFields() []string { + return []string{ + "DynamicBytesFile", + "stack", + } +} + +func (n *netDevData) beforeSave() {} + +// +checklocksignore +func (n *netDevData) StateSave(stateSinkObject state.Sink) { + n.beforeSave() + stateSinkObject.Save(0, &n.DynamicBytesFile) + stateSinkObject.Save(1, &n.stack) +} + +func (n *netDevData) afterLoad() {} + +// +checklocksignore +func (n *netDevData) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &n.DynamicBytesFile) + stateSourceObject.Load(1, &n.stack) +} + +func (n *netUnixData) StateTypeName() string { + return "pkg/sentry/fsimpl/proc.netUnixData" +} + +func (n *netUnixData) StateFields() []string { + return []string{ + "DynamicBytesFile", + "kernel", + } +} + +func (n *netUnixData) beforeSave() {} + +// +checklocksignore +func (n *netUnixData) StateSave(stateSinkObject state.Sink) { + n.beforeSave() + stateSinkObject.Save(0, &n.DynamicBytesFile) + stateSinkObject.Save(1, &n.kernel) +} + +func (n *netUnixData) afterLoad() {} + +// +checklocksignore +func (n *netUnixData) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &n.DynamicBytesFile) + stateSourceObject.Load(1, &n.kernel) +} + +func (d *netTCPData) StateTypeName() string { + return "pkg/sentry/fsimpl/proc.netTCPData" +} + +func (d *netTCPData) StateFields() []string { + return []string{ + "DynamicBytesFile", + "kernel", + } +} + +func (d *netTCPData) beforeSave() {} + +// +checklocksignore +func (d *netTCPData) StateSave(stateSinkObject state.Sink) { + d.beforeSave() + stateSinkObject.Save(0, &d.DynamicBytesFile) + stateSinkObject.Save(1, &d.kernel) +} + +func (d *netTCPData) afterLoad() {} + +// +checklocksignore +func (d *netTCPData) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &d.DynamicBytesFile) + stateSourceObject.Load(1, &d.kernel) +} + +func (d *netTCP6Data) StateTypeName() string { + return "pkg/sentry/fsimpl/proc.netTCP6Data" +} + +func (d *netTCP6Data) StateFields() []string { + return []string{ + "DynamicBytesFile", + "kernel", + } +} + +func (d *netTCP6Data) beforeSave() {} + +// +checklocksignore +func (d *netTCP6Data) StateSave(stateSinkObject state.Sink) { + d.beforeSave() + stateSinkObject.Save(0, &d.DynamicBytesFile) + stateSinkObject.Save(1, &d.kernel) +} + +func (d *netTCP6Data) afterLoad() {} + +// +checklocksignore +func (d *netTCP6Data) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &d.DynamicBytesFile) + stateSourceObject.Load(1, &d.kernel) +} + +func (d *netUDPData) StateTypeName() string { + return "pkg/sentry/fsimpl/proc.netUDPData" +} + +func (d *netUDPData) StateFields() []string { + return []string{ + "DynamicBytesFile", + "kernel", + } +} + +func (d *netUDPData) beforeSave() {} + +// +checklocksignore +func (d *netUDPData) StateSave(stateSinkObject state.Sink) { + d.beforeSave() + stateSinkObject.Save(0, &d.DynamicBytesFile) + stateSinkObject.Save(1, &d.kernel) +} + +func (d *netUDPData) afterLoad() {} + +// +checklocksignore +func (d *netUDPData) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &d.DynamicBytesFile) + stateSourceObject.Load(1, &d.kernel) +} + +func (d *netSnmpData) StateTypeName() string { + return "pkg/sentry/fsimpl/proc.netSnmpData" +} + +func (d *netSnmpData) StateFields() []string { + return []string{ + "DynamicBytesFile", + "stack", + } +} + +func (d *netSnmpData) beforeSave() {} + +// +checklocksignore +func (d *netSnmpData) StateSave(stateSinkObject state.Sink) { + d.beforeSave() + stateSinkObject.Save(0, &d.DynamicBytesFile) + stateSinkObject.Save(1, &d.stack) +} + +func (d *netSnmpData) afterLoad() {} + +// +checklocksignore +func (d *netSnmpData) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &d.DynamicBytesFile) + stateSourceObject.Load(1, &d.stack) +} + +func (s *snmpLine) StateTypeName() string { + return "pkg/sentry/fsimpl/proc.snmpLine" +} + +func (s *snmpLine) StateFields() []string { + return []string{ + "prefix", + "header", + } +} + +func (s *snmpLine) beforeSave() {} + +// +checklocksignore +func (s *snmpLine) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + stateSinkObject.Save(0, &s.prefix) + stateSinkObject.Save(1, &s.header) +} + +func (s *snmpLine) afterLoad() {} + +// +checklocksignore +func (s *snmpLine) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.prefix) + stateSourceObject.Load(1, &s.header) +} + +func (d *netRouteData) StateTypeName() string { + return "pkg/sentry/fsimpl/proc.netRouteData" +} + +func (d *netRouteData) StateFields() []string { + return []string{ + "DynamicBytesFile", + "stack", + } +} + +func (d *netRouteData) beforeSave() {} + +// +checklocksignore +func (d *netRouteData) StateSave(stateSinkObject state.Sink) { + d.beforeSave() + stateSinkObject.Save(0, &d.DynamicBytesFile) + stateSinkObject.Save(1, &d.stack) +} + +func (d *netRouteData) afterLoad() {} + +// +checklocksignore +func (d *netRouteData) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &d.DynamicBytesFile) + stateSourceObject.Load(1, &d.stack) +} + +func (d *netStatData) StateTypeName() string { + return "pkg/sentry/fsimpl/proc.netStatData" +} + +func (d *netStatData) StateFields() []string { + return []string{ + "DynamicBytesFile", + "stack", + } +} + +func (d *netStatData) beforeSave() {} + +// +checklocksignore +func (d *netStatData) StateSave(stateSinkObject state.Sink) { + d.beforeSave() + stateSinkObject.Save(0, &d.DynamicBytesFile) + stateSinkObject.Save(1, &d.stack) +} + +func (d *netStatData) afterLoad() {} + +// +checklocksignore +func (d *netStatData) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &d.DynamicBytesFile) + stateSourceObject.Load(1, &d.stack) +} + +func (i *tasksInode) StateTypeName() string { + return "pkg/sentry/fsimpl/proc.tasksInode" +} + +func (i *tasksInode) StateFields() []string { + return []string{ + "implStatFS", + "InodeAlwaysValid", + "InodeAttrs", + "InodeDirectoryNoNewChildren", + "InodeNotSymlink", + "InodeTemporary", + "OrderedChildren", + "tasksInodeRefs", + "locks", + "fs", + "pidns", + "fakeCgroupControllers", + } +} + +func (i *tasksInode) beforeSave() {} + +// +checklocksignore +func (i *tasksInode) StateSave(stateSinkObject state.Sink) { + i.beforeSave() + stateSinkObject.Save(0, &i.implStatFS) + stateSinkObject.Save(1, &i.InodeAlwaysValid) + stateSinkObject.Save(2, &i.InodeAttrs) + stateSinkObject.Save(3, &i.InodeDirectoryNoNewChildren) + stateSinkObject.Save(4, &i.InodeNotSymlink) + stateSinkObject.Save(5, &i.InodeTemporary) + stateSinkObject.Save(6, &i.OrderedChildren) + stateSinkObject.Save(7, &i.tasksInodeRefs) + stateSinkObject.Save(8, &i.locks) + stateSinkObject.Save(9, &i.fs) + stateSinkObject.Save(10, &i.pidns) + stateSinkObject.Save(11, &i.fakeCgroupControllers) +} + +func (i *tasksInode) afterLoad() {} + +// +checklocksignore +func (i *tasksInode) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &i.implStatFS) + stateSourceObject.Load(1, &i.InodeAlwaysValid) + stateSourceObject.Load(2, &i.InodeAttrs) + stateSourceObject.Load(3, &i.InodeDirectoryNoNewChildren) + stateSourceObject.Load(4, &i.InodeNotSymlink) + stateSourceObject.Load(5, &i.InodeTemporary) + stateSourceObject.Load(6, &i.OrderedChildren) + stateSourceObject.Load(7, &i.tasksInodeRefs) + stateSourceObject.Load(8, &i.locks) + stateSourceObject.Load(9, &i.fs) + stateSourceObject.Load(10, &i.pidns) + stateSourceObject.Load(11, &i.fakeCgroupControllers) +} + +func (s *staticFileSetStat) StateTypeName() string { + return "pkg/sentry/fsimpl/proc.staticFileSetStat" +} + +func (s *staticFileSetStat) StateFields() []string { + return []string{ + "dynamicBytesFileSetAttr", + "StaticData", + } +} + +func (s *staticFileSetStat) beforeSave() {} + +// +checklocksignore +func (s *staticFileSetStat) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + stateSinkObject.Save(0, &s.dynamicBytesFileSetAttr) + stateSinkObject.Save(1, &s.StaticData) +} + +func (s *staticFileSetStat) afterLoad() {} + +// +checklocksignore +func (s *staticFileSetStat) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.dynamicBytesFileSetAttr) + stateSourceObject.Load(1, &s.StaticData) +} + +func (s *selfSymlink) StateTypeName() string { + return "pkg/sentry/fsimpl/proc.selfSymlink" +} + +func (s *selfSymlink) StateFields() []string { + return []string{ + "implStatFS", + "InodeAttrs", + "InodeNoopRefCount", + "InodeSymlink", + "pidns", + } +} + +func (s *selfSymlink) beforeSave() {} + +// +checklocksignore +func (s *selfSymlink) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + stateSinkObject.Save(0, &s.implStatFS) + stateSinkObject.Save(1, &s.InodeAttrs) + stateSinkObject.Save(2, &s.InodeNoopRefCount) + stateSinkObject.Save(3, &s.InodeSymlink) + stateSinkObject.Save(4, &s.pidns) +} + +func (s *selfSymlink) afterLoad() {} + +// +checklocksignore +func (s *selfSymlink) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.implStatFS) + stateSourceObject.Load(1, &s.InodeAttrs) + stateSourceObject.Load(2, &s.InodeNoopRefCount) + stateSourceObject.Load(3, &s.InodeSymlink) + stateSourceObject.Load(4, &s.pidns) +} + +func (s *threadSelfSymlink) StateTypeName() string { + return "pkg/sentry/fsimpl/proc.threadSelfSymlink" +} + +func (s *threadSelfSymlink) StateFields() []string { + return []string{ + "implStatFS", + "InodeAttrs", + "InodeNoopRefCount", + "InodeSymlink", + "pidns", + } +} + +func (s *threadSelfSymlink) beforeSave() {} + +// +checklocksignore +func (s *threadSelfSymlink) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + stateSinkObject.Save(0, &s.implStatFS) + stateSinkObject.Save(1, &s.InodeAttrs) + stateSinkObject.Save(2, &s.InodeNoopRefCount) + stateSinkObject.Save(3, &s.InodeSymlink) + stateSinkObject.Save(4, &s.pidns) +} + +func (s *threadSelfSymlink) afterLoad() {} + +// +checklocksignore +func (s *threadSelfSymlink) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.implStatFS) + stateSourceObject.Load(1, &s.InodeAttrs) + stateSourceObject.Load(2, &s.InodeNoopRefCount) + stateSourceObject.Load(3, &s.InodeSymlink) + stateSourceObject.Load(4, &s.pidns) +} + +func (d *dynamicBytesFileSetAttr) StateTypeName() string { + return "pkg/sentry/fsimpl/proc.dynamicBytesFileSetAttr" +} + +func (d *dynamicBytesFileSetAttr) StateFields() []string { + return []string{ + "DynamicBytesFile", + } +} + +func (d *dynamicBytesFileSetAttr) beforeSave() {} + +// +checklocksignore +func (d *dynamicBytesFileSetAttr) StateSave(stateSinkObject state.Sink) { + d.beforeSave() + stateSinkObject.Save(0, &d.DynamicBytesFile) +} + +func (d *dynamicBytesFileSetAttr) afterLoad() {} + +// +checklocksignore +func (d *dynamicBytesFileSetAttr) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &d.DynamicBytesFile) +} + +func (c *cpuStats) StateTypeName() string { + return "pkg/sentry/fsimpl/proc.cpuStats" +} + +func (c *cpuStats) StateFields() []string { + return []string{ + "user", + "nice", + "system", + "idle", + "ioWait", + "irq", + "softirq", + "steal", + "guest", + "guestNice", + } +} + +func (c *cpuStats) beforeSave() {} + +// +checklocksignore +func (c *cpuStats) StateSave(stateSinkObject state.Sink) { + c.beforeSave() + stateSinkObject.Save(0, &c.user) + stateSinkObject.Save(1, &c.nice) + stateSinkObject.Save(2, &c.system) + stateSinkObject.Save(3, &c.idle) + stateSinkObject.Save(4, &c.ioWait) + stateSinkObject.Save(5, &c.irq) + stateSinkObject.Save(6, &c.softirq) + stateSinkObject.Save(7, &c.steal) + stateSinkObject.Save(8, &c.guest) + stateSinkObject.Save(9, &c.guestNice) +} + +func (c *cpuStats) afterLoad() {} + +// +checklocksignore +func (c *cpuStats) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &c.user) + stateSourceObject.Load(1, &c.nice) + stateSourceObject.Load(2, &c.system) + stateSourceObject.Load(3, &c.idle) + stateSourceObject.Load(4, &c.ioWait) + stateSourceObject.Load(5, &c.irq) + stateSourceObject.Load(6, &c.softirq) + stateSourceObject.Load(7, &c.steal) + stateSourceObject.Load(8, &c.guest) + stateSourceObject.Load(9, &c.guestNice) +} + +func (s *statData) StateTypeName() string { + return "pkg/sentry/fsimpl/proc.statData" +} + +func (s *statData) StateFields() []string { + return []string{ + "dynamicBytesFileSetAttr", + } +} + +func (s *statData) beforeSave() {} + +// +checklocksignore +func (s *statData) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + stateSinkObject.Save(0, &s.dynamicBytesFileSetAttr) +} + +func (s *statData) afterLoad() {} + +// +checklocksignore +func (s *statData) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.dynamicBytesFileSetAttr) +} + +func (l *loadavgData) StateTypeName() string { + return "pkg/sentry/fsimpl/proc.loadavgData" +} + +func (l *loadavgData) StateFields() []string { + return []string{ + "dynamicBytesFileSetAttr", + } +} + +func (l *loadavgData) beforeSave() {} + +// +checklocksignore +func (l *loadavgData) StateSave(stateSinkObject state.Sink) { + l.beforeSave() + stateSinkObject.Save(0, &l.dynamicBytesFileSetAttr) +} + +func (l *loadavgData) afterLoad() {} + +// +checklocksignore +func (l *loadavgData) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &l.dynamicBytesFileSetAttr) +} + +func (m *meminfoData) StateTypeName() string { + return "pkg/sentry/fsimpl/proc.meminfoData" +} + +func (m *meminfoData) StateFields() []string { + return []string{ + "dynamicBytesFileSetAttr", + } +} + +func (m *meminfoData) beforeSave() {} + +// +checklocksignore +func (m *meminfoData) StateSave(stateSinkObject state.Sink) { + m.beforeSave() + stateSinkObject.Save(0, &m.dynamicBytesFileSetAttr) +} + +func (m *meminfoData) afterLoad() {} + +// +checklocksignore +func (m *meminfoData) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &m.dynamicBytesFileSetAttr) +} + +func (u *uptimeData) StateTypeName() string { + return "pkg/sentry/fsimpl/proc.uptimeData" +} + +func (u *uptimeData) StateFields() []string { + return []string{ + "dynamicBytesFileSetAttr", + } +} + +func (u *uptimeData) beforeSave() {} + +// +checklocksignore +func (u *uptimeData) StateSave(stateSinkObject state.Sink) { + u.beforeSave() + stateSinkObject.Save(0, &u.dynamicBytesFileSetAttr) +} + +func (u *uptimeData) afterLoad() {} + +// +checklocksignore +func (u *uptimeData) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &u.dynamicBytesFileSetAttr) +} + +func (v *versionData) StateTypeName() string { + return "pkg/sentry/fsimpl/proc.versionData" +} + +func (v *versionData) StateFields() []string { + return []string{ + "dynamicBytesFileSetAttr", + } +} + +func (v *versionData) beforeSave() {} + +// +checklocksignore +func (v *versionData) StateSave(stateSinkObject state.Sink) { + v.beforeSave() + stateSinkObject.Save(0, &v.dynamicBytesFileSetAttr) +} + +func (v *versionData) afterLoad() {} + +// +checklocksignore +func (v *versionData) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &v.dynamicBytesFileSetAttr) +} + +func (d *filesystemsData) StateTypeName() string { + return "pkg/sentry/fsimpl/proc.filesystemsData" +} + +func (d *filesystemsData) StateFields() []string { + return []string{ + "DynamicBytesFile", + } +} + +func (d *filesystemsData) beforeSave() {} + +// +checklocksignore +func (d *filesystemsData) StateSave(stateSinkObject state.Sink) { + d.beforeSave() + stateSinkObject.Save(0, &d.DynamicBytesFile) +} + +func (d *filesystemsData) afterLoad() {} + +// +checklocksignore +func (d *filesystemsData) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &d.DynamicBytesFile) +} + +func (c *cgroupsData) StateTypeName() string { + return "pkg/sentry/fsimpl/proc.cgroupsData" +} + +func (c *cgroupsData) StateFields() []string { + return []string{ + "dynamicBytesFileSetAttr", + } +} + +func (c *cgroupsData) beforeSave() {} + +// +checklocksignore +func (c *cgroupsData) StateSave(stateSinkObject state.Sink) { + c.beforeSave() + stateSinkObject.Save(0, &c.dynamicBytesFileSetAttr) +} + +func (c *cgroupsData) afterLoad() {} + +// +checklocksignore +func (c *cgroupsData) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &c.dynamicBytesFileSetAttr) +} + +func (c *cmdLineData) StateTypeName() string { + return "pkg/sentry/fsimpl/proc.cmdLineData" +} + +func (c *cmdLineData) StateFields() []string { + return []string{ + "dynamicBytesFileSetAttr", + } +} + +func (c *cmdLineData) beforeSave() {} + +// +checklocksignore +func (c *cmdLineData) StateSave(stateSinkObject state.Sink) { + c.beforeSave() + stateSinkObject.Save(0, &c.dynamicBytesFileSetAttr) +} + +func (c *cmdLineData) afterLoad() {} + +// +checklocksignore +func (c *cmdLineData) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &c.dynamicBytesFileSetAttr) +} + +func (r *tasksInodeRefs) StateTypeName() string { + return "pkg/sentry/fsimpl/proc.tasksInodeRefs" +} + +func (r *tasksInodeRefs) StateFields() []string { + return []string{ + "refCount", + } +} + +func (r *tasksInodeRefs) beforeSave() {} + +// +checklocksignore +func (r *tasksInodeRefs) StateSave(stateSinkObject state.Sink) { + r.beforeSave() + stateSinkObject.Save(0, &r.refCount) +} + +// +checklocksignore +func (r *tasksInodeRefs) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &r.refCount) + stateSourceObject.AfterLoad(r.afterLoad) +} + +func (t *tcpMemDir) StateTypeName() string { + return "pkg/sentry/fsimpl/proc.tcpMemDir" +} + +func (t *tcpMemDir) StateFields() []string { + return nil +} + +func (d *mmapMinAddrData) StateTypeName() string { + return "pkg/sentry/fsimpl/proc.mmapMinAddrData" +} + +func (d *mmapMinAddrData) StateFields() []string { + return []string{ + "DynamicBytesFile", + "k", + } +} + +func (d *mmapMinAddrData) beforeSave() {} + +// +checklocksignore +func (d *mmapMinAddrData) StateSave(stateSinkObject state.Sink) { + d.beforeSave() + stateSinkObject.Save(0, &d.DynamicBytesFile) + stateSinkObject.Save(1, &d.k) +} + +func (d *mmapMinAddrData) afterLoad() {} + +// +checklocksignore +func (d *mmapMinAddrData) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &d.DynamicBytesFile) + stateSourceObject.Load(1, &d.k) +} + +func (h *hostnameData) StateTypeName() string { + return "pkg/sentry/fsimpl/proc.hostnameData" +} + +func (h *hostnameData) StateFields() []string { + return []string{ + "DynamicBytesFile", + } +} + +func (h *hostnameData) beforeSave() {} + +// +checklocksignore +func (h *hostnameData) StateSave(stateSinkObject state.Sink) { + h.beforeSave() + stateSinkObject.Save(0, &h.DynamicBytesFile) +} + +func (h *hostnameData) afterLoad() {} + +// +checklocksignore +func (h *hostnameData) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &h.DynamicBytesFile) +} + +func (d *tcpSackData) StateTypeName() string { + return "pkg/sentry/fsimpl/proc.tcpSackData" +} + +func (d *tcpSackData) StateFields() []string { + return []string{ + "DynamicBytesFile", + "stack", + "enabled", + } +} + +func (d *tcpSackData) beforeSave() {} + +// +checklocksignore +func (d *tcpSackData) StateSave(stateSinkObject state.Sink) { + d.beforeSave() + stateSinkObject.Save(0, &d.DynamicBytesFile) + stateSinkObject.Save(1, &d.stack) + stateSinkObject.Save(2, &d.enabled) +} + +func (d *tcpSackData) afterLoad() {} + +// +checklocksignore +func (d *tcpSackData) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &d.DynamicBytesFile) + stateSourceObject.LoadWait(1, &d.stack) + stateSourceObject.Load(2, &d.enabled) +} + +func (d *tcpRecoveryData) StateTypeName() string { + return "pkg/sentry/fsimpl/proc.tcpRecoveryData" +} + +func (d *tcpRecoveryData) StateFields() []string { + return []string{ + "DynamicBytesFile", + "stack", + } +} + +func (d *tcpRecoveryData) beforeSave() {} + +// +checklocksignore +func (d *tcpRecoveryData) StateSave(stateSinkObject state.Sink) { + d.beforeSave() + stateSinkObject.Save(0, &d.DynamicBytesFile) + stateSinkObject.Save(1, &d.stack) +} + +func (d *tcpRecoveryData) afterLoad() {} + +// +checklocksignore +func (d *tcpRecoveryData) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &d.DynamicBytesFile) + stateSourceObject.LoadWait(1, &d.stack) +} + +func (d *tcpMemData) StateTypeName() string { + return "pkg/sentry/fsimpl/proc.tcpMemData" +} + +func (d *tcpMemData) StateFields() []string { + return []string{ + "DynamicBytesFile", + "dir", + "stack", + } +} + +func (d *tcpMemData) beforeSave() {} + +// +checklocksignore +func (d *tcpMemData) StateSave(stateSinkObject state.Sink) { + d.beforeSave() + stateSinkObject.Save(0, &d.DynamicBytesFile) + stateSinkObject.Save(1, &d.dir) + stateSinkObject.Save(2, &d.stack) +} + +func (d *tcpMemData) afterLoad() {} + +// +checklocksignore +func (d *tcpMemData) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &d.DynamicBytesFile) + stateSourceObject.Load(1, &d.dir) + stateSourceObject.LoadWait(2, &d.stack) +} + +func (ipf *ipForwarding) StateTypeName() string { + return "pkg/sentry/fsimpl/proc.ipForwarding" +} + +func (ipf *ipForwarding) StateFields() []string { + return []string{ + "DynamicBytesFile", + "stack", + "enabled", + } +} + +func (ipf *ipForwarding) beforeSave() {} + +// +checklocksignore +func (ipf *ipForwarding) StateSave(stateSinkObject state.Sink) { + ipf.beforeSave() + stateSinkObject.Save(0, &ipf.DynamicBytesFile) + stateSinkObject.Save(1, &ipf.stack) + stateSinkObject.Save(2, &ipf.enabled) +} + +func (ipf *ipForwarding) afterLoad() {} + +// +checklocksignore +func (ipf *ipForwarding) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &ipf.DynamicBytesFile) + stateSourceObject.LoadWait(1, &ipf.stack) + stateSourceObject.Load(2, &ipf.enabled) +} + +func (pr *portRange) StateTypeName() string { + return "pkg/sentry/fsimpl/proc.portRange" +} + +func (pr *portRange) StateFields() []string { + return []string{ + "DynamicBytesFile", + "stack", + "start", + "end", + } +} + +func (pr *portRange) beforeSave() {} + +// +checklocksignore +func (pr *portRange) StateSave(stateSinkObject state.Sink) { + pr.beforeSave() + stateSinkObject.Save(0, &pr.DynamicBytesFile) + stateSinkObject.Save(1, &pr.stack) + stateSinkObject.Save(2, &pr.start) + stateSinkObject.Save(3, &pr.end) +} + +func (pr *portRange) afterLoad() {} + +// +checklocksignore +func (pr *portRange) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &pr.DynamicBytesFile) + stateSourceObject.LoadWait(1, &pr.stack) + stateSourceObject.Load(2, &pr.start) + stateSourceObject.Load(3, &pr.end) +} + +func (s *yamaPtraceScope) StateTypeName() string { + return "pkg/sentry/fsimpl/proc.yamaPtraceScope" +} + +func (s *yamaPtraceScope) StateFields() []string { + return []string{ + "DynamicBytesFile", + "level", + } +} + +func (s *yamaPtraceScope) beforeSave() {} + +// +checklocksignore +func (s *yamaPtraceScope) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + stateSinkObject.Save(0, &s.DynamicBytesFile) + stateSinkObject.Save(1, &s.level) +} + +func (s *yamaPtraceScope) afterLoad() {} + +// +checklocksignore +func (s *yamaPtraceScope) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.DynamicBytesFile) + stateSourceObject.Load(1, &s.level) +} + +func init() { + state.Register((*fdDirInodeRefs)(nil)) + state.Register((*fdInfoDirInodeRefs)(nil)) + state.Register((*FilesystemType)(nil)) + state.Register((*filesystem)(nil)) + state.Register((*staticFile)(nil)) + state.Register((*InternalData)(nil)) + state.Register((*implStatFS)(nil)) + state.Register((*subtasksInode)(nil)) + state.Register((*subtasksFD)(nil)) + state.Register((*subtasksInodeRefs)(nil)) + state.Register((*taskInode)(nil)) + state.Register((*taskOwnedInode)(nil)) + state.Register((*fdDir)(nil)) + state.Register((*fdDirInode)(nil)) + state.Register((*fdSymlink)(nil)) + state.Register((*fdInfoDirInode)(nil)) + state.Register((*fdInfoData)(nil)) + state.Register((*auxvData)(nil)) + state.Register((*cmdlineData)(nil)) + state.Register((*commInode)(nil)) + state.Register((*commData)(nil)) + state.Register((*idMapData)(nil)) + state.Register((*memInode)(nil)) + state.Register((*memFD)(nil)) + state.Register((*mapsData)(nil)) + state.Register((*smapsData)(nil)) + state.Register((*taskStatData)(nil)) + state.Register((*statmData)(nil)) + state.Register((*statusInode)(nil)) + state.Register((*statusFD)(nil)) + state.Register((*statusFDLowerBase)(nil)) + state.Register((*ioData)(nil)) + state.Register((*oomScoreAdj)(nil)) + state.Register((*exeSymlink)(nil)) + state.Register((*cwdSymlink)(nil)) + state.Register((*mountInfoData)(nil)) + state.Register((*mountsData)(nil)) + state.Register((*namespaceSymlink)(nil)) + state.Register((*namespaceInode)(nil)) + state.Register((*namespaceFD)(nil)) + state.Register((*taskCgroupData)(nil)) + state.Register((*taskInodeRefs)(nil)) + state.Register((*ifinet6)(nil)) + state.Register((*netDevData)(nil)) + state.Register((*netUnixData)(nil)) + state.Register((*netTCPData)(nil)) + state.Register((*netTCP6Data)(nil)) + state.Register((*netUDPData)(nil)) + state.Register((*netSnmpData)(nil)) + state.Register((*snmpLine)(nil)) + state.Register((*netRouteData)(nil)) + state.Register((*netStatData)(nil)) + state.Register((*tasksInode)(nil)) + state.Register((*staticFileSetStat)(nil)) + state.Register((*selfSymlink)(nil)) + state.Register((*threadSelfSymlink)(nil)) + state.Register((*dynamicBytesFileSetAttr)(nil)) + state.Register((*cpuStats)(nil)) + state.Register((*statData)(nil)) + state.Register((*loadavgData)(nil)) + state.Register((*meminfoData)(nil)) + state.Register((*uptimeData)(nil)) + state.Register((*versionData)(nil)) + state.Register((*filesystemsData)(nil)) + state.Register((*cgroupsData)(nil)) + state.Register((*cmdLineData)(nil)) + state.Register((*tasksInodeRefs)(nil)) + state.Register((*tcpMemDir)(nil)) + state.Register((*mmapMinAddrData)(nil)) + state.Register((*hostnameData)(nil)) + state.Register((*tcpSackData)(nil)) + state.Register((*tcpRecoveryData)(nil)) + state.Register((*tcpMemData)(nil)) + state.Register((*ipForwarding)(nil)) + state.Register((*portRange)(nil)) + state.Register((*yamaPtraceScope)(nil)) +} diff --git a/pkg/sentry/fsimpl/proc/subtasks_inode_refs.go b/pkg/sentry/fsimpl/proc/subtasks_inode_refs.go new file mode 100644 index 000000000..bd4998cbc --- /dev/null +++ b/pkg/sentry/fsimpl/proc/subtasks_inode_refs.go @@ -0,0 +1,140 @@ +package proc + +import ( + "fmt" + "sync/atomic" + + "gvisor.dev/gvisor/pkg/refsvfs2" +) + +// enableLogging indicates whether reference-related events should be logged (with +// stack traces). This is false by default and should only be set to true for +// debugging purposes, as it can generate an extremely large amount of output +// and drastically degrade performance. +const subtasksInodeenableLogging = false + +// obj is used to customize logging. Note that we use a pointer to T so that +// we do not copy the entire object when passed as a format parameter. +var subtasksInodeobj *subtasksInode + +// Refs implements refs.RefCounter. It keeps a reference count using atomic +// operations and calls the destructor when the count reaches zero. +// +// NOTE: Do not introduce additional fields to the Refs struct. It is used by +// many filesystem objects, and we want to keep it as small as possible (i.e., +// the same size as using an int64 directly) to avoid taking up extra cache +// space. In general, this template should not be extended at the cost of +// performance. If it does not offer enough flexibility for a particular object +// (example: b/187877947), we should implement the RefCounter/CheckedObject +// interfaces manually. +// +// +stateify savable +type subtasksInodeRefs struct { + // refCount is composed of two fields: + // + // [32-bit speculative references]:[32-bit real references] + // + // Speculative references are used for TryIncRef, to avoid a CompareAndSwap + // loop. See IncRef, DecRef and TryIncRef for details of how these fields are + // used. + refCount int64 +} + +// InitRefs initializes r with one reference and, if enabled, activates leak +// checking. +func (r *subtasksInodeRefs) InitRefs() { + atomic.StoreInt64(&r.refCount, 1) + refsvfs2.Register(r) +} + +// RefType implements refsvfs2.CheckedObject.RefType. +func (r *subtasksInodeRefs) RefType() string { + return fmt.Sprintf("%T", subtasksInodeobj)[1:] +} + +// LeakMessage implements refsvfs2.CheckedObject.LeakMessage. +func (r *subtasksInodeRefs) LeakMessage() string { + return fmt.Sprintf("[%s %p] reference count of %d instead of 0", r.RefType(), r, r.ReadRefs()) +} + +// LogRefs implements refsvfs2.CheckedObject.LogRefs. +func (r *subtasksInodeRefs) LogRefs() bool { + return subtasksInodeenableLogging +} + +// ReadRefs returns the current number of references. The returned count is +// inherently racy and is unsafe to use without external synchronization. +func (r *subtasksInodeRefs) ReadRefs() int64 { + return atomic.LoadInt64(&r.refCount) +} + +// IncRef implements refs.RefCounter.IncRef. +// +//go:nosplit +func (r *subtasksInodeRefs) IncRef() { + v := atomic.AddInt64(&r.refCount, 1) + if subtasksInodeenableLogging { + refsvfs2.LogIncRef(r, v) + } + if v <= 1 { + panic(fmt.Sprintf("Incrementing non-positive count %p on %s", r, r.RefType())) + } +} + +// TryIncRef implements refs.TryRefCounter.TryIncRef. +// +// To do this safely without a loop, a speculative reference is first acquired +// on the object. This allows multiple concurrent TryIncRef calls to distinguish +// other TryIncRef calls from genuine references held. +// +//go:nosplit +func (r *subtasksInodeRefs) TryIncRef() bool { + const speculativeRef = 1 << 32 + if v := atomic.AddInt64(&r.refCount, speculativeRef); int32(v) == 0 { + + atomic.AddInt64(&r.refCount, -speculativeRef) + return false + } + + v := atomic.AddInt64(&r.refCount, -speculativeRef+1) + if subtasksInodeenableLogging { + refsvfs2.LogTryIncRef(r, v) + } + return true +} + +// DecRef implements refs.RefCounter.DecRef. +// +// Note that speculative references are counted here. Since they were added +// prior to real references reaching zero, they will successfully convert to +// real references. In other words, we see speculative references only in the +// following case: +// +// A: TryIncRef [speculative increase => sees non-negative references] +// B: DecRef [real decrease] +// A: TryIncRef [transform speculative to real] +// +//go:nosplit +func (r *subtasksInodeRefs) DecRef(destroy func()) { + v := atomic.AddInt64(&r.refCount, -1) + if subtasksInodeenableLogging { + refsvfs2.LogDecRef(r, v) + } + switch { + case v < 0: + panic(fmt.Sprintf("Decrementing non-positive ref count %p, owned by %s", r, r.RefType())) + + case v == 0: + refsvfs2.Unregister(r) + + if destroy != nil { + destroy() + } + } +} + +func (r *subtasksInodeRefs) afterLoad() { + if r.ReadRefs() > 0 { + refsvfs2.Register(r) + } +} diff --git a/pkg/sentry/fsimpl/proc/task_inode_refs.go b/pkg/sentry/fsimpl/proc/task_inode_refs.go new file mode 100644 index 000000000..82c63213a --- /dev/null +++ b/pkg/sentry/fsimpl/proc/task_inode_refs.go @@ -0,0 +1,140 @@ +package proc + +import ( + "fmt" + "sync/atomic" + + "gvisor.dev/gvisor/pkg/refsvfs2" +) + +// enableLogging indicates whether reference-related events should be logged (with +// stack traces). This is false by default and should only be set to true for +// debugging purposes, as it can generate an extremely large amount of output +// and drastically degrade performance. +const taskInodeenableLogging = false + +// obj is used to customize logging. Note that we use a pointer to T so that +// we do not copy the entire object when passed as a format parameter. +var taskInodeobj *taskInode + +// Refs implements refs.RefCounter. It keeps a reference count using atomic +// operations and calls the destructor when the count reaches zero. +// +// NOTE: Do not introduce additional fields to the Refs struct. It is used by +// many filesystem objects, and we want to keep it as small as possible (i.e., +// the same size as using an int64 directly) to avoid taking up extra cache +// space. In general, this template should not be extended at the cost of +// performance. If it does not offer enough flexibility for a particular object +// (example: b/187877947), we should implement the RefCounter/CheckedObject +// interfaces manually. +// +// +stateify savable +type taskInodeRefs struct { + // refCount is composed of two fields: + // + // [32-bit speculative references]:[32-bit real references] + // + // Speculative references are used for TryIncRef, to avoid a CompareAndSwap + // loop. See IncRef, DecRef and TryIncRef for details of how these fields are + // used. + refCount int64 +} + +// InitRefs initializes r with one reference and, if enabled, activates leak +// checking. +func (r *taskInodeRefs) InitRefs() { + atomic.StoreInt64(&r.refCount, 1) + refsvfs2.Register(r) +} + +// RefType implements refsvfs2.CheckedObject.RefType. +func (r *taskInodeRefs) RefType() string { + return fmt.Sprintf("%T", taskInodeobj)[1:] +} + +// LeakMessage implements refsvfs2.CheckedObject.LeakMessage. +func (r *taskInodeRefs) LeakMessage() string { + return fmt.Sprintf("[%s %p] reference count of %d instead of 0", r.RefType(), r, r.ReadRefs()) +} + +// LogRefs implements refsvfs2.CheckedObject.LogRefs. +func (r *taskInodeRefs) LogRefs() bool { + return taskInodeenableLogging +} + +// ReadRefs returns the current number of references. The returned count is +// inherently racy and is unsafe to use without external synchronization. +func (r *taskInodeRefs) ReadRefs() int64 { + return atomic.LoadInt64(&r.refCount) +} + +// IncRef implements refs.RefCounter.IncRef. +// +//go:nosplit +func (r *taskInodeRefs) IncRef() { + v := atomic.AddInt64(&r.refCount, 1) + if taskInodeenableLogging { + refsvfs2.LogIncRef(r, v) + } + if v <= 1 { + panic(fmt.Sprintf("Incrementing non-positive count %p on %s", r, r.RefType())) + } +} + +// TryIncRef implements refs.TryRefCounter.TryIncRef. +// +// To do this safely without a loop, a speculative reference is first acquired +// on the object. This allows multiple concurrent TryIncRef calls to distinguish +// other TryIncRef calls from genuine references held. +// +//go:nosplit +func (r *taskInodeRefs) TryIncRef() bool { + const speculativeRef = 1 << 32 + if v := atomic.AddInt64(&r.refCount, speculativeRef); int32(v) == 0 { + + atomic.AddInt64(&r.refCount, -speculativeRef) + return false + } + + v := atomic.AddInt64(&r.refCount, -speculativeRef+1) + if taskInodeenableLogging { + refsvfs2.LogTryIncRef(r, v) + } + return true +} + +// DecRef implements refs.RefCounter.DecRef. +// +// Note that speculative references are counted here. Since they were added +// prior to real references reaching zero, they will successfully convert to +// real references. In other words, we see speculative references only in the +// following case: +// +// A: TryIncRef [speculative increase => sees non-negative references] +// B: DecRef [real decrease] +// A: TryIncRef [transform speculative to real] +// +//go:nosplit +func (r *taskInodeRefs) DecRef(destroy func()) { + v := atomic.AddInt64(&r.refCount, -1) + if taskInodeenableLogging { + refsvfs2.LogDecRef(r, v) + } + switch { + case v < 0: + panic(fmt.Sprintf("Decrementing non-positive ref count %p, owned by %s", r, r.RefType())) + + case v == 0: + refsvfs2.Unregister(r) + + if destroy != nil { + destroy() + } + } +} + +func (r *taskInodeRefs) afterLoad() { + if r.ReadRefs() > 0 { + refsvfs2.Register(r) + } +} diff --git a/pkg/sentry/fsimpl/proc/tasks_inode_refs.go b/pkg/sentry/fsimpl/proc/tasks_inode_refs.go new file mode 100644 index 000000000..73adc5278 --- /dev/null +++ b/pkg/sentry/fsimpl/proc/tasks_inode_refs.go @@ -0,0 +1,140 @@ +package proc + +import ( + "fmt" + "sync/atomic" + + "gvisor.dev/gvisor/pkg/refsvfs2" +) + +// enableLogging indicates whether reference-related events should be logged (with +// stack traces). This is false by default and should only be set to true for +// debugging purposes, as it can generate an extremely large amount of output +// and drastically degrade performance. +const tasksInodeenableLogging = false + +// obj is used to customize logging. Note that we use a pointer to T so that +// we do not copy the entire object when passed as a format parameter. +var tasksInodeobj *tasksInode + +// Refs implements refs.RefCounter. It keeps a reference count using atomic +// operations and calls the destructor when the count reaches zero. +// +// NOTE: Do not introduce additional fields to the Refs struct. It is used by +// many filesystem objects, and we want to keep it as small as possible (i.e., +// the same size as using an int64 directly) to avoid taking up extra cache +// space. In general, this template should not be extended at the cost of +// performance. If it does not offer enough flexibility for a particular object +// (example: b/187877947), we should implement the RefCounter/CheckedObject +// interfaces manually. +// +// +stateify savable +type tasksInodeRefs struct { + // refCount is composed of two fields: + // + // [32-bit speculative references]:[32-bit real references] + // + // Speculative references are used for TryIncRef, to avoid a CompareAndSwap + // loop. See IncRef, DecRef and TryIncRef for details of how these fields are + // used. + refCount int64 +} + +// InitRefs initializes r with one reference and, if enabled, activates leak +// checking. +func (r *tasksInodeRefs) InitRefs() { + atomic.StoreInt64(&r.refCount, 1) + refsvfs2.Register(r) +} + +// RefType implements refsvfs2.CheckedObject.RefType. +func (r *tasksInodeRefs) RefType() string { + return fmt.Sprintf("%T", tasksInodeobj)[1:] +} + +// LeakMessage implements refsvfs2.CheckedObject.LeakMessage. +func (r *tasksInodeRefs) LeakMessage() string { + return fmt.Sprintf("[%s %p] reference count of %d instead of 0", r.RefType(), r, r.ReadRefs()) +} + +// LogRefs implements refsvfs2.CheckedObject.LogRefs. +func (r *tasksInodeRefs) LogRefs() bool { + return tasksInodeenableLogging +} + +// ReadRefs returns the current number of references. The returned count is +// inherently racy and is unsafe to use without external synchronization. +func (r *tasksInodeRefs) ReadRefs() int64 { + return atomic.LoadInt64(&r.refCount) +} + +// IncRef implements refs.RefCounter.IncRef. +// +//go:nosplit +func (r *tasksInodeRefs) IncRef() { + v := atomic.AddInt64(&r.refCount, 1) + if tasksInodeenableLogging { + refsvfs2.LogIncRef(r, v) + } + if v <= 1 { + panic(fmt.Sprintf("Incrementing non-positive count %p on %s", r, r.RefType())) + } +} + +// TryIncRef implements refs.TryRefCounter.TryIncRef. +// +// To do this safely without a loop, a speculative reference is first acquired +// on the object. This allows multiple concurrent TryIncRef calls to distinguish +// other TryIncRef calls from genuine references held. +// +//go:nosplit +func (r *tasksInodeRefs) TryIncRef() bool { + const speculativeRef = 1 << 32 + if v := atomic.AddInt64(&r.refCount, speculativeRef); int32(v) == 0 { + + atomic.AddInt64(&r.refCount, -speculativeRef) + return false + } + + v := atomic.AddInt64(&r.refCount, -speculativeRef+1) + if tasksInodeenableLogging { + refsvfs2.LogTryIncRef(r, v) + } + return true +} + +// DecRef implements refs.RefCounter.DecRef. +// +// Note that speculative references are counted here. Since they were added +// prior to real references reaching zero, they will successfully convert to +// real references. In other words, we see speculative references only in the +// following case: +// +// A: TryIncRef [speculative increase => sees non-negative references] +// B: DecRef [real decrease] +// A: TryIncRef [transform speculative to real] +// +//go:nosplit +func (r *tasksInodeRefs) DecRef(destroy func()) { + v := atomic.AddInt64(&r.refCount, -1) + if tasksInodeenableLogging { + refsvfs2.LogDecRef(r, v) + } + switch { + case v < 0: + panic(fmt.Sprintf("Decrementing non-positive ref count %p, owned by %s", r, r.RefType())) + + case v == 0: + refsvfs2.Unregister(r) + + if destroy != nil { + destroy() + } + } +} + +func (r *tasksInodeRefs) afterLoad() { + if r.ReadRefs() > 0 { + refsvfs2.Register(r) + } +} diff --git a/pkg/sentry/fsimpl/proc/tasks_sys_test.go b/pkg/sentry/fsimpl/proc/tasks_sys_test.go deleted file mode 100644 index 19b012f7d..000000000 --- a/pkg/sentry/fsimpl/proc/tasks_sys_test.go +++ /dev/null @@ -1,149 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package proc - -import ( - "bytes" - "reflect" - "testing" - - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/sentry/contexttest" - "gvisor.dev/gvisor/pkg/sentry/inet" - "gvisor.dev/gvisor/pkg/usermem" -) - -func newIPv6TestStack() *inet.TestStack { - s := inet.NewTestStack() - s.SupportsIPv6Flag = true - return s -} - -func TestIfinet6NoAddresses(t *testing.T) { - n := &ifinet6{stack: newIPv6TestStack()} - var buf bytes.Buffer - n.Generate(contexttest.Context(t), &buf) - if buf.Len() > 0 { - t.Errorf("n.Generate() generated = %v, want = %v", buf.Bytes(), []byte{}) - } -} - -func TestIfinet6(t *testing.T) { - s := newIPv6TestStack() - s.InterfacesMap[1] = inet.Interface{Name: "eth0"} - s.InterfaceAddrsMap[1] = []inet.InterfaceAddr{ - { - Family: linux.AF_INET6, - PrefixLen: 128, - Addr: []byte("\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f"), - }, - } - s.InterfacesMap[2] = inet.Interface{Name: "eth1"} - s.InterfaceAddrsMap[2] = []inet.InterfaceAddr{ - { - Family: linux.AF_INET6, - PrefixLen: 128, - Addr: []byte("\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f"), - }, - } - want := map[string]struct{}{ - "000102030405060708090a0b0c0d0e0f 01 80 00 00 eth0\n": {}, - "101112131415161718191a1b1c1d1e1f 02 80 00 00 eth1\n": {}, - } - - n := &ifinet6{stack: s} - contents := n.contents() - if len(contents) != len(want) { - t.Errorf("Got len(n.contents()) = %d, want = %d", len(contents), len(want)) - } - got := map[string]struct{}{} - for _, l := range contents { - got[l] = struct{}{} - } - - if !reflect.DeepEqual(got, want) { - t.Errorf("Got n.contents() = %v, want = %v", got, want) - } -} - -// TestIPForwarding tests the implementation of -// /proc/sys/net/ipv4/ip_forwarding -func TestConfigureIPForwarding(t *testing.T) { - ctx := context.Background() - s := inet.NewTestStack() - - var cases = []struct { - comment string - initial bool - str string - final bool - }{ - { - comment: `Forwarding is disabled; write 1 and enable forwarding`, - initial: false, - str: "1", - final: true, - }, - { - comment: `Forwarding is disabled; write 0 and disable forwarding`, - initial: false, - str: "0", - final: false, - }, - { - comment: `Forwarding is enabled; write 1 and enable forwarding`, - initial: true, - str: "1", - final: true, - }, - { - comment: `Forwarding is enabled; write 0 and disable forwarding`, - initial: true, - str: "0", - final: false, - }, - { - comment: `Forwarding is disabled; write 2404 and enable forwarding`, - initial: false, - str: "2404", - final: true, - }, - { - comment: `Forwarding is enabled; write 2404 and enable forwarding`, - initial: true, - str: "2404", - final: true, - }, - } - for _, c := range cases { - t.Run(c.comment, func(t *testing.T) { - s.IPForwarding = c.initial - - file := &ipForwarding{stack: s, enabled: c.initial} - - // Write the values. - src := usermem.BytesIOSequence([]byte(c.str)) - if n, err := file.Write(ctx, src, 0); n != int64(len(c.str)) || err != nil { - t.Errorf("file.Write(ctx, nil, %q, 0) = (%d, %v); want (%d, nil)", c.str, n, err, len(c.str)) - } - - // Read the values from the stack and check them. - if got, want := s.IPForwarding, c.final; got != want { - t.Errorf("s.IPForwarding incorrect; got: %v, want: %v", got, want) - } - }) - } -} diff --git a/pkg/sentry/fsimpl/proc/tasks_test.go b/pkg/sentry/fsimpl/proc/tasks_test.go deleted file mode 100644 index 14f806c3c..000000000 --- a/pkg/sentry/fsimpl/proc/tasks_test.go +++ /dev/null @@ -1,511 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package proc - -import ( - "fmt" - "math" - "path" - "strconv" - "testing" - - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/errors/linuxerr" - "gvisor.dev/gvisor/pkg/fspath" - "gvisor.dev/gvisor/pkg/sentry/fsimpl/testutil" - "gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs" - "gvisor.dev/gvisor/pkg/sentry/kernel" - "gvisor.dev/gvisor/pkg/sentry/kernel/auth" - "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/usermem" -) - -var ( - // Next offset 256 by convention. Adds 1 for the next offset. - selfLink = vfs.Dirent{Type: linux.DT_LNK, NextOff: 256 + 0 + 1} - threadSelfLink = vfs.Dirent{Type: linux.DT_LNK, NextOff: 256 + 1 + 1} - - // /proc/[pid] next offset starts at 256+2 (files above), then adds the - // PID, and adds 1 for the next offset. - proc1 = vfs.Dirent{Type: linux.DT_DIR, NextOff: 258 + 1 + 1} - proc2 = vfs.Dirent{Type: linux.DT_DIR, NextOff: 258 + 2 + 1} - proc3 = vfs.Dirent{Type: linux.DT_DIR, NextOff: 258 + 3 + 1} -) - -var ( - tasksStaticFiles = map[string]testutil.DirentType{ - "cmdline": linux.DT_REG, - "cpuinfo": linux.DT_REG, - "filesystems": linux.DT_REG, - "loadavg": linux.DT_REG, - "meminfo": linux.DT_REG, - "mounts": linux.DT_LNK, - "net": linux.DT_LNK, - "self": linux.DT_LNK, - "stat": linux.DT_REG, - "sys": linux.DT_DIR, - "thread-self": linux.DT_LNK, - "uptime": linux.DT_REG, - "version": linux.DT_REG, - } - tasksStaticFilesNextOffs = map[string]int64{ - "self": selfLink.NextOff, - "thread-self": threadSelfLink.NextOff, - } - taskStaticFiles = map[string]testutil.DirentType{ - "auxv": linux.DT_REG, - "cgroup": linux.DT_REG, - "cwd": linux.DT_LNK, - "cmdline": linux.DT_REG, - "comm": linux.DT_REG, - "environ": linux.DT_REG, - "exe": linux.DT_LNK, - "fd": linux.DT_DIR, - "fdinfo": linux.DT_DIR, - "gid_map": linux.DT_REG, - "io": linux.DT_REG, - "maps": linux.DT_REG, - "mem": linux.DT_REG, - "mountinfo": linux.DT_REG, - "mounts": linux.DT_REG, - "net": linux.DT_DIR, - "ns": linux.DT_DIR, - "oom_score": linux.DT_REG, - "oom_score_adj": linux.DT_REG, - "smaps": linux.DT_REG, - "stat": linux.DT_REG, - "statm": linux.DT_REG, - "status": linux.DT_REG, - "task": linux.DT_DIR, - "uid_map": linux.DT_REG, - } -) - -func setup(t *testing.T) *testutil.System { - k, err := testutil.Boot() - if err != nil { - t.Fatalf("Error creating kernel: %v", err) - } - - ctx := k.SupervisorContext() - creds := auth.CredentialsFromContext(ctx) - - k.VFS().MustRegisterFilesystemType(Name, &FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{ - AllowUserMount: true, - }) - - mntns, err := k.VFS().NewMountNamespace(ctx, creds, "", tmpfs.Name, &vfs.MountOptions{}) - if err != nil { - t.Fatalf("NewMountNamespace(): %v", err) - } - root := mntns.Root() - root.IncRef() - defer root.DecRef(ctx) - pop := &vfs.PathOperation{ - Root: root, - Start: root, - Path: fspath.Parse("/proc"), - } - if err := k.VFS().MkdirAt(ctx, creds, pop, &vfs.MkdirOptions{Mode: 0777}); err != nil { - t.Fatalf("MkDir(/proc): %v", err) - } - - pop = &vfs.PathOperation{ - Root: root, - Start: root, - Path: fspath.Parse("/proc"), - } - mntOpts := &vfs.MountOptions{ - GetFilesystemOptions: vfs.GetFilesystemOptions{ - InternalData: &InternalData{ - Cgroups: map[string]string{ - "cpuset": "/foo/cpuset", - "memory": "/foo/memory", - }, - }, - }, - } - if _, err := k.VFS().MountAt(ctx, creds, "", pop, Name, mntOpts); err != nil { - t.Fatalf("MountAt(/proc): %v", err) - } - return testutil.NewSystem(ctx, t, k.VFS(), mntns) -} - -func TestTasksEmpty(t *testing.T) { - s := setup(t) - defer s.Destroy() - - collector := s.ListDirents(s.PathOpAtRoot("/proc")) - s.AssertAllDirentTypes(collector, tasksStaticFiles) - s.AssertDirentOffsets(collector, tasksStaticFilesNextOffs) -} - -func TestTasks(t *testing.T) { - s := setup(t) - defer s.Destroy() - - expectedDirents := make(map[string]testutil.DirentType) - for n, d := range tasksStaticFiles { - expectedDirents[n] = d - } - - k := kernel.KernelFromContext(s.Ctx) - var tasks []*kernel.Task - for i := 0; i < 5; i++ { - tc := k.NewThreadGroup(nil, k.RootPIDNamespace(), kernel.NewSignalHandlers(), linux.SIGCHLD, k.GlobalInit().Limits()) - task, err := testutil.CreateTask(s.Ctx, fmt.Sprintf("name-%d", i), tc, s.MntNs, s.Root, s.Root) - if err != nil { - t.Fatalf("CreateTask(): %v", err) - } - tasks = append(tasks, task) - expectedDirents[fmt.Sprintf("%d", i+1)] = linux.DT_DIR - } - - collector := s.ListDirents(s.PathOpAtRoot("/proc")) - s.AssertAllDirentTypes(collector, expectedDirents) - s.AssertDirentOffsets(collector, tasksStaticFilesNextOffs) - - lastPid := 0 - dirents := collector.OrderedDirents() - doneSkippingNonTaskDirs := false - for _, d := range dirents { - pid, err := strconv.Atoi(d.Name) - if err != nil { - if !doneSkippingNonTaskDirs { - // We haven't gotten to the task dirs yet. - continue - } - t.Fatalf("Invalid process directory %q", d.Name) - } - doneSkippingNonTaskDirs = true - if lastPid > pid { - t.Errorf("pids not in order: %v", dirents) - } - found := false - for _, t := range tasks { - if k.TaskSet().Root.IDOfTask(t) == kernel.ThreadID(pid) { - found = true - } - } - if !found { - t.Errorf("Additional task ID %d listed: %v", pid, tasks) - } - // Next offset starts at 256+2 ('self' and 'thread-self'), then adds the - // PID, and adds 1 for the next offset. - if want := int64(256 + 2 + pid + 1); d.NextOff != want { - t.Errorf("Wrong dirent offset want: %d got: %d: %+v", want, d.NextOff, d) - } - } - if !doneSkippingNonTaskDirs { - t.Fatalf("Never found any process directories.") - } - - // Test lookup. - for _, path := range []string{"/proc/1", "/proc/2"} { - fd, err := s.VFS.OpenAt( - s.Ctx, - s.Creds, - s.PathOpAtRoot(path), - &vfs.OpenOptions{}, - ) - if err != nil { - t.Fatalf("vfsfs.OpenAt(%q) failed: %v", path, err) - } - defer fd.DecRef(s.Ctx) - buf := make([]byte, 1) - bufIOSeq := usermem.BytesIOSequence(buf) - if _, err := fd.Read(s.Ctx, bufIOSeq, vfs.ReadOptions{}); !linuxerr.Equals(linuxerr.EISDIR, err) { - t.Errorf("wrong error reading directory: %v", err) - } - } - - if _, err := s.VFS.OpenAt( - s.Ctx, - s.Creds, - s.PathOpAtRoot("/proc/9999"), - &vfs.OpenOptions{}, - ); !linuxerr.Equals(linuxerr.ENOENT, err) { - t.Fatalf("wrong error from vfsfs.OpenAt(/proc/9999): %v", err) - } -} - -func TestTasksOffset(t *testing.T) { - s := setup(t) - defer s.Destroy() - - k := kernel.KernelFromContext(s.Ctx) - for i := 0; i < 3; i++ { - tc := k.NewThreadGroup(nil, k.RootPIDNamespace(), kernel.NewSignalHandlers(), linux.SIGCHLD, k.GlobalInit().Limits()) - if _, err := testutil.CreateTask(s.Ctx, fmt.Sprintf("name-%d", i), tc, s.MntNs, s.Root, s.Root); err != nil { - t.Fatalf("CreateTask(): %v", err) - } - } - - for _, tc := range []struct { - name string - offset int64 - wants map[string]vfs.Dirent - }{ - { - name: "small offset", - offset: 100, - wants: map[string]vfs.Dirent{ - "self": selfLink, - "thread-self": threadSelfLink, - "1": proc1, - "2": proc2, - "3": proc3, - }, - }, - { - name: "offset at start", - offset: 256, - wants: map[string]vfs.Dirent{ - "self": selfLink, - "thread-self": threadSelfLink, - "1": proc1, - "2": proc2, - "3": proc3, - }, - }, - { - name: "skip /proc/self", - offset: 257, - wants: map[string]vfs.Dirent{ - "thread-self": threadSelfLink, - "1": proc1, - "2": proc2, - "3": proc3, - }, - }, - { - name: "skip symlinks", - offset: 258, - wants: map[string]vfs.Dirent{ - "1": proc1, - "2": proc2, - "3": proc3, - }, - }, - { - name: "skip first process", - offset: 260, - wants: map[string]vfs.Dirent{ - "2": proc2, - "3": proc3, - }, - }, - { - name: "last process", - offset: 261, - wants: map[string]vfs.Dirent{ - "3": proc3, - }, - }, - { - name: "after last", - offset: 262, - wants: nil, - }, - { - name: "TaskLimit+1", - offset: kernel.TasksLimit + 1, - wants: nil, - }, - { - name: "max", - offset: math.MaxInt64, - wants: nil, - }, - } { - t.Run(tc.name, func(t *testing.T) { - s := s.WithSubtest(t) - fd, err := s.VFS.OpenAt( - s.Ctx, - s.Creds, - s.PathOpAtRoot("/proc"), - &vfs.OpenOptions{}, - ) - if err != nil { - t.Fatalf("vfsfs.OpenAt(/) failed: %v", err) - } - defer fd.DecRef(s.Ctx) - if _, err := fd.Seek(s.Ctx, tc.offset, linux.SEEK_SET); err != nil { - t.Fatalf("Seek(%d, SEEK_SET): %v", tc.offset, err) - } - - var collector testutil.DirentCollector - if err := fd.IterDirents(s.Ctx, &collector); err != nil { - t.Fatalf("IterDirent(): %v", err) - } - - expectedTypes := make(map[string]testutil.DirentType) - expectedOffsets := make(map[string]int64) - for name, want := range tc.wants { - expectedTypes[name] = want.Type - if want.NextOff != 0 { - expectedOffsets[name] = want.NextOff - } - } - - collector.SkipDotsChecks(true) // We seek()ed past the dots. - s.AssertAllDirentTypes(&collector, expectedTypes) - s.AssertDirentOffsets(&collector, expectedOffsets) - }) - } -} - -func TestTask(t *testing.T) { - s := setup(t) - defer s.Destroy() - - k := kernel.KernelFromContext(s.Ctx) - tc := k.NewThreadGroup(nil, k.RootPIDNamespace(), kernel.NewSignalHandlers(), linux.SIGCHLD, k.GlobalInit().Limits()) - _, err := testutil.CreateTask(s.Ctx, "name", tc, s.MntNs, s.Root, s.Root) - if err != nil { - t.Fatalf("CreateTask(): %v", err) - } - - collector := s.ListDirents(s.PathOpAtRoot("/proc/1")) - s.AssertAllDirentTypes(collector, taskStaticFiles) -} - -func TestProcSelf(t *testing.T) { - s := setup(t) - defer s.Destroy() - - k := kernel.KernelFromContext(s.Ctx) - tc := k.NewThreadGroup(nil, k.RootPIDNamespace(), kernel.NewSignalHandlers(), linux.SIGCHLD, k.GlobalInit().Limits()) - task, err := testutil.CreateTask(s.Ctx, "name", tc, s.MntNs, s.Root, s.Root) - if err != nil { - t.Fatalf("CreateTask(): %v", err) - } - - collector := s.WithTemporaryContext(task.AsyncContext()).ListDirents(&vfs.PathOperation{ - Root: s.Root, - Start: s.Root, - Path: fspath.Parse("/proc/self/"), - FollowFinalSymlink: true, - }) - s.AssertAllDirentTypes(collector, taskStaticFiles) -} - -func iterateDir(ctx context.Context, t *testing.T, s *testutil.System, fd *vfs.FileDescription) { - t.Logf("Iterating: %s", fd.MappedName(ctx)) - - var collector testutil.DirentCollector - if err := fd.IterDirents(ctx, &collector); err != nil { - t.Fatalf("IterDirents(): %v", err) - } - if err := collector.Contains(".", linux.DT_DIR); err != nil { - t.Error(err.Error()) - } - if err := collector.Contains("..", linux.DT_DIR); err != nil { - t.Error(err.Error()) - } - - for _, d := range collector.Dirents() { - if d.Name == "." || d.Name == ".." { - continue - } - absPath := path.Join(fd.MappedName(ctx), d.Name) - if d.Type == linux.DT_LNK { - link, err := s.VFS.ReadlinkAt( - ctx, - auth.CredentialsFromContext(ctx), - &vfs.PathOperation{Root: s.Root, Start: s.Root, Path: fspath.Parse(absPath)}, - ) - if err != nil { - t.Errorf("vfsfs.ReadlinkAt(%v) failed: %v", absPath, err) - } else { - t.Logf("Skipping symlink: %s => %s", absPath, link) - } - continue - } - - t.Logf("Opening: %s", absPath) - child, err := s.VFS.OpenAt( - ctx, - auth.CredentialsFromContext(ctx), - &vfs.PathOperation{Root: s.Root, Start: s.Root, Path: fspath.Parse(absPath)}, - &vfs.OpenOptions{}, - ) - if err != nil { - t.Errorf("vfsfs.OpenAt(%v) failed: %v", absPath, err) - continue - } - defer child.DecRef(ctx) - stat, err := child.Stat(ctx, vfs.StatOptions{}) - if err != nil { - t.Errorf("Stat(%v) failed: %v", absPath, err) - } - if got := linux.FileMode(stat.Mode).DirentType(); got != d.Type { - t.Errorf("wrong file mode, stat: %v, dirent: %v", got, d.Type) - } - if d.Type == linux.DT_DIR { - // Found another dir, let's do it again! - iterateDir(ctx, t, s, child) - } - } -} - -// TestTree iterates all directories and stats every file. -func TestTree(t *testing.T) { - s := setup(t) - defer s.Destroy() - - k := kernel.KernelFromContext(s.Ctx) - - pop := &vfs.PathOperation{ - Root: s.Root, - Start: s.Root, - Path: fspath.Parse("test-file"), - } - opts := &vfs.OpenOptions{ - Flags: linux.O_RDONLY | linux.O_CREAT, - Mode: 0777, - } - file, err := s.VFS.OpenAt(s.Ctx, s.Creds, pop, opts) - if err != nil { - t.Fatalf("failed to create test file: %v", err) - } - defer file.DecRef(s.Ctx) - - var tasks []*kernel.Task - for i := 0; i < 5; i++ { - tc := k.NewThreadGroup(nil, k.RootPIDNamespace(), kernel.NewSignalHandlers(), linux.SIGCHLD, k.GlobalInit().Limits()) - task, err := testutil.CreateTask(s.Ctx, fmt.Sprintf("name-%d", i), tc, s.MntNs, s.Root, s.Root) - if err != nil { - t.Fatalf("CreateTask(): %v", err) - } - // Add file to populate /proc/[pid]/fd and fdinfo directories. - task.FDTable().NewFDVFS2(task.AsyncContext(), 0, file, kernel.FDFlags{}) - tasks = append(tasks, task) - } - - ctx := tasks[0].AsyncContext() - fd, err := s.VFS.OpenAt( - ctx, - auth.CredentialsFromContext(s.Ctx), - &vfs.PathOperation{Root: s.Root, Start: s.Root, Path: fspath.Parse("/proc")}, - &vfs.OpenOptions{}, - ) - if err != nil { - t.Fatalf("vfsfs.OpenAt(/proc) failed: %v", err) - } - iterateDir(ctx, t, s, fd) - fd.DecRef(ctx) -} diff --git a/pkg/sentry/fsimpl/signalfd/BUILD b/pkg/sentry/fsimpl/signalfd/BUILD deleted file mode 100644 index 403c6f254..000000000 --- a/pkg/sentry/fsimpl/signalfd/BUILD +++ /dev/null @@ -1,19 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "signalfd", - srcs = ["signalfd.go"], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/errors/linuxerr", - "//pkg/sentry/kernel", - "//pkg/sentry/vfs", - "//pkg/sync", - "//pkg/usermem", - "//pkg/waiter", - ], -) diff --git a/pkg/sentry/fsimpl/signalfd/signalfd_state_autogen.go b/pkg/sentry/fsimpl/signalfd/signalfd_state_autogen.go new file mode 100644 index 000000000..3bf27c6a6 --- /dev/null +++ b/pkg/sentry/fsimpl/signalfd/signalfd_state_autogen.go @@ -0,0 +1,51 @@ +// automatically generated by stateify. + +package signalfd + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (sfd *SignalFileDescription) StateTypeName() string { + return "pkg/sentry/fsimpl/signalfd.SignalFileDescription" +} + +func (sfd *SignalFileDescription) StateFields() []string { + return []string{ + "vfsfd", + "FileDescriptionDefaultImpl", + "DentryMetadataFileDescriptionImpl", + "NoLockFD", + "target", + "mask", + } +} + +func (sfd *SignalFileDescription) beforeSave() {} + +// +checklocksignore +func (sfd *SignalFileDescription) StateSave(stateSinkObject state.Sink) { + sfd.beforeSave() + stateSinkObject.Save(0, &sfd.vfsfd) + stateSinkObject.Save(1, &sfd.FileDescriptionDefaultImpl) + stateSinkObject.Save(2, &sfd.DentryMetadataFileDescriptionImpl) + stateSinkObject.Save(3, &sfd.NoLockFD) + stateSinkObject.Save(4, &sfd.target) + stateSinkObject.Save(5, &sfd.mask) +} + +func (sfd *SignalFileDescription) afterLoad() {} + +// +checklocksignore +func (sfd *SignalFileDescription) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &sfd.vfsfd) + stateSourceObject.Load(1, &sfd.FileDescriptionDefaultImpl) + stateSourceObject.Load(2, &sfd.DentryMetadataFileDescriptionImpl) + stateSourceObject.Load(3, &sfd.NoLockFD) + stateSourceObject.Load(4, &sfd.target) + stateSourceObject.Load(5, &sfd.mask) +} + +func init() { + state.Register((*SignalFileDescription)(nil)) +} diff --git a/pkg/sentry/fsimpl/sockfs/BUILD b/pkg/sentry/fsimpl/sockfs/BUILD deleted file mode 100644 index 9defca936..000000000 --- a/pkg/sentry/fsimpl/sockfs/BUILD +++ /dev/null @@ -1,18 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -licenses(["notice"]) - -go_library( - name = "sockfs", - srcs = ["sockfs.go"], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/errors/linuxerr", - "//pkg/fspath", - "//pkg/sentry/fsimpl/kernfs", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/vfs", - ], -) diff --git a/pkg/sentry/fsimpl/sockfs/sockfs_state_autogen.go b/pkg/sentry/fsimpl/sockfs/sockfs_state_autogen.go new file mode 100644 index 000000000..cf6eddef2 --- /dev/null +++ b/pkg/sentry/fsimpl/sockfs/sockfs_state_autogen.go @@ -0,0 +1,96 @@ +// automatically generated by stateify. + +package sockfs + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (fsType *filesystemType) StateTypeName() string { + return "pkg/sentry/fsimpl/sockfs.filesystemType" +} + +func (fsType *filesystemType) StateFields() []string { + return []string{} +} + +func (fsType *filesystemType) beforeSave() {} + +// +checklocksignore +func (fsType *filesystemType) StateSave(stateSinkObject state.Sink) { + fsType.beforeSave() +} + +func (fsType *filesystemType) afterLoad() {} + +// +checklocksignore +func (fsType *filesystemType) StateLoad(stateSourceObject state.Source) { +} + +func (fs *filesystem) StateTypeName() string { + return "pkg/sentry/fsimpl/sockfs.filesystem" +} + +func (fs *filesystem) StateFields() []string { + return []string{ + "Filesystem", + "devMinor", + } +} + +func (fs *filesystem) beforeSave() {} + +// +checklocksignore +func (fs *filesystem) StateSave(stateSinkObject state.Sink) { + fs.beforeSave() + stateSinkObject.Save(0, &fs.Filesystem) + stateSinkObject.Save(1, &fs.devMinor) +} + +func (fs *filesystem) afterLoad() {} + +// +checklocksignore +func (fs *filesystem) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &fs.Filesystem) + stateSourceObject.Load(1, &fs.devMinor) +} + +func (i *inode) StateTypeName() string { + return "pkg/sentry/fsimpl/sockfs.inode" +} + +func (i *inode) StateFields() []string { + return []string{ + "InodeAttrs", + "InodeNoopRefCount", + "InodeNotDirectory", + "InodeNotSymlink", + } +} + +func (i *inode) beforeSave() {} + +// +checklocksignore +func (i *inode) StateSave(stateSinkObject state.Sink) { + i.beforeSave() + stateSinkObject.Save(0, &i.InodeAttrs) + stateSinkObject.Save(1, &i.InodeNoopRefCount) + stateSinkObject.Save(2, &i.InodeNotDirectory) + stateSinkObject.Save(3, &i.InodeNotSymlink) +} + +func (i *inode) afterLoad() {} + +// +checklocksignore +func (i *inode) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &i.InodeAttrs) + stateSourceObject.Load(1, &i.InodeNoopRefCount) + stateSourceObject.Load(2, &i.InodeNotDirectory) + stateSourceObject.Load(3, &i.InodeNotSymlink) +} + +func init() { + state.Register((*filesystemType)(nil)) + state.Register((*filesystem)(nil)) + state.Register((*inode)(nil)) +} diff --git a/pkg/sentry/fsimpl/sys/BUILD b/pkg/sentry/fsimpl/sys/BUILD deleted file mode 100644 index ab21f028e..000000000 --- a/pkg/sentry/fsimpl/sys/BUILD +++ /dev/null @@ -1,55 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -licenses(["notice"]) - -go_template_instance( - name = "dir_refs", - out = "dir_refs.go", - package = "sys", - prefix = "dir", - template = "//pkg/refsvfs2:refs_template", - types = { - "T": "dir", - }, -) - -go_library( - name = "sys", - srcs = [ - "dir_refs.go", - "kcov.go", - "sys.go", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/coverage", - "//pkg/errors/linuxerr", - "//pkg/log", - "//pkg/refs", - "//pkg/refsvfs2", - "//pkg/sentry/arch", - "//pkg/sentry/fsimpl/kernfs", - "//pkg/sentry/kernel", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/memmap", - "//pkg/sentry/vfs", - "//pkg/usermem", - ], -) - -go_test( - name = "sys_test", - srcs = ["sys_test.go"], - deps = [ - ":sys", - "//pkg/abi/linux", - "//pkg/sentry/fsimpl/testutil", - "//pkg/sentry/kernel", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/vfs", - "@com_github_google_go_cmp//cmp:go_default_library", - ], -) diff --git a/pkg/sentry/fsimpl/sys/dir_refs.go b/pkg/sentry/fsimpl/sys/dir_refs.go new file mode 100644 index 000000000..17bc43d2e --- /dev/null +++ b/pkg/sentry/fsimpl/sys/dir_refs.go @@ -0,0 +1,140 @@ +package sys + +import ( + "fmt" + "sync/atomic" + + "gvisor.dev/gvisor/pkg/refsvfs2" +) + +// enableLogging indicates whether reference-related events should be logged (with +// stack traces). This is false by default and should only be set to true for +// debugging purposes, as it can generate an extremely large amount of output +// and drastically degrade performance. +const direnableLogging = false + +// obj is used to customize logging. Note that we use a pointer to T so that +// we do not copy the entire object when passed as a format parameter. +var dirobj *dir + +// Refs implements refs.RefCounter. It keeps a reference count using atomic +// operations and calls the destructor when the count reaches zero. +// +// NOTE: Do not introduce additional fields to the Refs struct. It is used by +// many filesystem objects, and we want to keep it as small as possible (i.e., +// the same size as using an int64 directly) to avoid taking up extra cache +// space. In general, this template should not be extended at the cost of +// performance. If it does not offer enough flexibility for a particular object +// (example: b/187877947), we should implement the RefCounter/CheckedObject +// interfaces manually. +// +// +stateify savable +type dirRefs struct { + // refCount is composed of two fields: + // + // [32-bit speculative references]:[32-bit real references] + // + // Speculative references are used for TryIncRef, to avoid a CompareAndSwap + // loop. See IncRef, DecRef and TryIncRef for details of how these fields are + // used. + refCount int64 +} + +// InitRefs initializes r with one reference and, if enabled, activates leak +// checking. +func (r *dirRefs) InitRefs() { + atomic.StoreInt64(&r.refCount, 1) + refsvfs2.Register(r) +} + +// RefType implements refsvfs2.CheckedObject.RefType. +func (r *dirRefs) RefType() string { + return fmt.Sprintf("%T", dirobj)[1:] +} + +// LeakMessage implements refsvfs2.CheckedObject.LeakMessage. +func (r *dirRefs) LeakMessage() string { + return fmt.Sprintf("[%s %p] reference count of %d instead of 0", r.RefType(), r, r.ReadRefs()) +} + +// LogRefs implements refsvfs2.CheckedObject.LogRefs. +func (r *dirRefs) LogRefs() bool { + return direnableLogging +} + +// ReadRefs returns the current number of references. The returned count is +// inherently racy and is unsafe to use without external synchronization. +func (r *dirRefs) ReadRefs() int64 { + return atomic.LoadInt64(&r.refCount) +} + +// IncRef implements refs.RefCounter.IncRef. +// +//go:nosplit +func (r *dirRefs) IncRef() { + v := atomic.AddInt64(&r.refCount, 1) + if direnableLogging { + refsvfs2.LogIncRef(r, v) + } + if v <= 1 { + panic(fmt.Sprintf("Incrementing non-positive count %p on %s", r, r.RefType())) + } +} + +// TryIncRef implements refs.TryRefCounter.TryIncRef. +// +// To do this safely without a loop, a speculative reference is first acquired +// on the object. This allows multiple concurrent TryIncRef calls to distinguish +// other TryIncRef calls from genuine references held. +// +//go:nosplit +func (r *dirRefs) TryIncRef() bool { + const speculativeRef = 1 << 32 + if v := atomic.AddInt64(&r.refCount, speculativeRef); int32(v) == 0 { + + atomic.AddInt64(&r.refCount, -speculativeRef) + return false + } + + v := atomic.AddInt64(&r.refCount, -speculativeRef+1) + if direnableLogging { + refsvfs2.LogTryIncRef(r, v) + } + return true +} + +// DecRef implements refs.RefCounter.DecRef. +// +// Note that speculative references are counted here. Since they were added +// prior to real references reaching zero, they will successfully convert to +// real references. In other words, we see speculative references only in the +// following case: +// +// A: TryIncRef [speculative increase => sees non-negative references] +// B: DecRef [real decrease] +// A: TryIncRef [transform speculative to real] +// +//go:nosplit +func (r *dirRefs) DecRef(destroy func()) { + v := atomic.AddInt64(&r.refCount, -1) + if direnableLogging { + refsvfs2.LogDecRef(r, v) + } + switch { + case v < 0: + panic(fmt.Sprintf("Decrementing non-positive ref count %p, owned by %s", r, r.RefType())) + + case v == 0: + refsvfs2.Unregister(r) + + if destroy != nil { + destroy() + } + } +} + +func (r *dirRefs) afterLoad() { + if r.ReadRefs() > 0 { + refsvfs2.Register(r) + } +} diff --git a/pkg/sentry/fsimpl/sys/sys_state_autogen.go b/pkg/sentry/fsimpl/sys/sys_state_autogen.go new file mode 100644 index 000000000..c5adf7db3 --- /dev/null +++ b/pkg/sentry/fsimpl/sys/sys_state_autogen.go @@ -0,0 +1,263 @@ +// automatically generated by stateify. + +package sys + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (r *dirRefs) StateTypeName() string { + return "pkg/sentry/fsimpl/sys.dirRefs" +} + +func (r *dirRefs) StateFields() []string { + return []string{ + "refCount", + } +} + +func (r *dirRefs) beforeSave() {} + +// +checklocksignore +func (r *dirRefs) StateSave(stateSinkObject state.Sink) { + r.beforeSave() + stateSinkObject.Save(0, &r.refCount) +} + +// +checklocksignore +func (r *dirRefs) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &r.refCount) + stateSourceObject.AfterLoad(r.afterLoad) +} + +func (i *kcovInode) StateTypeName() string { + return "pkg/sentry/fsimpl/sys.kcovInode" +} + +func (i *kcovInode) StateFields() []string { + return []string{ + "InodeAttrs", + "InodeNoopRefCount", + "InodeNotDirectory", + "InodeNotSymlink", + "implStatFS", + } +} + +func (i *kcovInode) beforeSave() {} + +// +checklocksignore +func (i *kcovInode) StateSave(stateSinkObject state.Sink) { + i.beforeSave() + stateSinkObject.Save(0, &i.InodeAttrs) + stateSinkObject.Save(1, &i.InodeNoopRefCount) + stateSinkObject.Save(2, &i.InodeNotDirectory) + stateSinkObject.Save(3, &i.InodeNotSymlink) + stateSinkObject.Save(4, &i.implStatFS) +} + +func (i *kcovInode) afterLoad() {} + +// +checklocksignore +func (i *kcovInode) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &i.InodeAttrs) + stateSourceObject.Load(1, &i.InodeNoopRefCount) + stateSourceObject.Load(2, &i.InodeNotDirectory) + stateSourceObject.Load(3, &i.InodeNotSymlink) + stateSourceObject.Load(4, &i.implStatFS) +} + +func (fd *kcovFD) StateTypeName() string { + return "pkg/sentry/fsimpl/sys.kcovFD" +} + +func (fd *kcovFD) StateFields() []string { + return []string{ + "FileDescriptionDefaultImpl", + "NoLockFD", + "vfsfd", + "inode", + "kcov", + } +} + +func (fd *kcovFD) beforeSave() {} + +// +checklocksignore +func (fd *kcovFD) StateSave(stateSinkObject state.Sink) { + fd.beforeSave() + stateSinkObject.Save(0, &fd.FileDescriptionDefaultImpl) + stateSinkObject.Save(1, &fd.NoLockFD) + stateSinkObject.Save(2, &fd.vfsfd) + stateSinkObject.Save(3, &fd.inode) + stateSinkObject.Save(4, &fd.kcov) +} + +func (fd *kcovFD) afterLoad() {} + +// +checklocksignore +func (fd *kcovFD) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &fd.FileDescriptionDefaultImpl) + stateSourceObject.Load(1, &fd.NoLockFD) + stateSourceObject.Load(2, &fd.vfsfd) + stateSourceObject.Load(3, &fd.inode) + stateSourceObject.Load(4, &fd.kcov) +} + +func (fsType *FilesystemType) StateTypeName() string { + return "pkg/sentry/fsimpl/sys.FilesystemType" +} + +func (fsType *FilesystemType) StateFields() []string { + return []string{} +} + +func (fsType *FilesystemType) beforeSave() {} + +// +checklocksignore +func (fsType *FilesystemType) StateSave(stateSinkObject state.Sink) { + fsType.beforeSave() +} + +func (fsType *FilesystemType) afterLoad() {} + +// +checklocksignore +func (fsType *FilesystemType) StateLoad(stateSourceObject state.Source) { +} + +func (fs *filesystem) StateTypeName() string { + return "pkg/sentry/fsimpl/sys.filesystem" +} + +func (fs *filesystem) StateFields() []string { + return []string{ + "Filesystem", + "devMinor", + } +} + +func (fs *filesystem) beforeSave() {} + +// +checklocksignore +func (fs *filesystem) StateSave(stateSinkObject state.Sink) { + fs.beforeSave() + stateSinkObject.Save(0, &fs.Filesystem) + stateSinkObject.Save(1, &fs.devMinor) +} + +func (fs *filesystem) afterLoad() {} + +// +checklocksignore +func (fs *filesystem) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &fs.Filesystem) + stateSourceObject.Load(1, &fs.devMinor) +} + +func (d *dir) StateTypeName() string { + return "pkg/sentry/fsimpl/sys.dir" +} + +func (d *dir) StateFields() []string { + return []string{ + "dirRefs", + "InodeAlwaysValid", + "InodeAttrs", + "InodeNotSymlink", + "InodeDirectoryNoNewChildren", + "InodeTemporary", + "OrderedChildren", + "locks", + } +} + +func (d *dir) beforeSave() {} + +// +checklocksignore +func (d *dir) StateSave(stateSinkObject state.Sink) { + d.beforeSave() + stateSinkObject.Save(0, &d.dirRefs) + stateSinkObject.Save(1, &d.InodeAlwaysValid) + stateSinkObject.Save(2, &d.InodeAttrs) + stateSinkObject.Save(3, &d.InodeNotSymlink) + stateSinkObject.Save(4, &d.InodeDirectoryNoNewChildren) + stateSinkObject.Save(5, &d.InodeTemporary) + stateSinkObject.Save(6, &d.OrderedChildren) + stateSinkObject.Save(7, &d.locks) +} + +func (d *dir) afterLoad() {} + +// +checklocksignore +func (d *dir) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &d.dirRefs) + stateSourceObject.Load(1, &d.InodeAlwaysValid) + stateSourceObject.Load(2, &d.InodeAttrs) + stateSourceObject.Load(3, &d.InodeNotSymlink) + stateSourceObject.Load(4, &d.InodeDirectoryNoNewChildren) + stateSourceObject.Load(5, &d.InodeTemporary) + stateSourceObject.Load(6, &d.OrderedChildren) + stateSourceObject.Load(7, &d.locks) +} + +func (c *cpuFile) StateTypeName() string { + return "pkg/sentry/fsimpl/sys.cpuFile" +} + +func (c *cpuFile) StateFields() []string { + return []string{ + "implStatFS", + "DynamicBytesFile", + "maxCores", + } +} + +func (c *cpuFile) beforeSave() {} + +// +checklocksignore +func (c *cpuFile) StateSave(stateSinkObject state.Sink) { + c.beforeSave() + stateSinkObject.Save(0, &c.implStatFS) + stateSinkObject.Save(1, &c.DynamicBytesFile) + stateSinkObject.Save(2, &c.maxCores) +} + +func (c *cpuFile) afterLoad() {} + +// +checklocksignore +func (c *cpuFile) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &c.implStatFS) + stateSourceObject.Load(1, &c.DynamicBytesFile) + stateSourceObject.Load(2, &c.maxCores) +} + +func (i *implStatFS) StateTypeName() string { + return "pkg/sentry/fsimpl/sys.implStatFS" +} + +func (i *implStatFS) StateFields() []string { + return []string{} +} + +func (i *implStatFS) beforeSave() {} + +// +checklocksignore +func (i *implStatFS) StateSave(stateSinkObject state.Sink) { + i.beforeSave() +} + +func (i *implStatFS) afterLoad() {} + +// +checklocksignore +func (i *implStatFS) StateLoad(stateSourceObject state.Source) { +} + +func init() { + state.Register((*dirRefs)(nil)) + state.Register((*kcovInode)(nil)) + state.Register((*kcovFD)(nil)) + state.Register((*FilesystemType)(nil)) + state.Register((*filesystem)(nil)) + state.Register((*dir)(nil)) + state.Register((*cpuFile)(nil)) + state.Register((*implStatFS)(nil)) +} diff --git a/pkg/sentry/fsimpl/sys/sys_test.go b/pkg/sentry/fsimpl/sys/sys_test.go deleted file mode 100644 index 0a0d914cc..000000000 --- a/pkg/sentry/fsimpl/sys/sys_test.go +++ /dev/null @@ -1,89 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package sys_test - -import ( - "fmt" - "testing" - - "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/fsimpl/sys" - "gvisor.dev/gvisor/pkg/sentry/fsimpl/testutil" - "gvisor.dev/gvisor/pkg/sentry/kernel" - "gvisor.dev/gvisor/pkg/sentry/kernel/auth" - "gvisor.dev/gvisor/pkg/sentry/vfs" -) - -func newTestSystem(t *testing.T) *testutil.System { - k, err := testutil.Boot() - if err != nil { - t.Fatalf("Failed to create test kernel: %v", err) - } - ctx := k.SupervisorContext() - creds := auth.CredentialsFromContext(ctx) - k.VFS().MustRegisterFilesystemType(sys.Name, sys.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{ - AllowUserMount: true, - }) - - mns, err := k.VFS().NewMountNamespace(ctx, creds, "", sys.Name, &vfs.MountOptions{}) - if err != nil { - t.Fatalf("Failed to create new mount namespace: %v", err) - } - return testutil.NewSystem(ctx, t, k.VFS(), mns) -} - -func TestReadCPUFile(t *testing.T) { - s := newTestSystem(t) - defer s.Destroy() - k := kernel.KernelFromContext(s.Ctx) - maxCPUCores := k.ApplicationCores() - - expected := fmt.Sprintf("0-%d\n", maxCPUCores-1) - - for _, fname := range []string{"online", "possible", "present"} { - pop := s.PathOpAtRoot(fmt.Sprintf("devices/system/cpu/%s", fname)) - fd, err := s.VFS.OpenAt(s.Ctx, s.Creds, pop, &vfs.OpenOptions{}) - if err != nil { - t.Fatalf("OpenAt(pop:%+v) = %+v failed: %v", pop, fd, err) - } - defer fd.DecRef(s.Ctx) - content, err := s.ReadToEnd(fd) - if err != nil { - t.Fatalf("Read failed: %v", err) - } - if diff := cmp.Diff(expected, content); diff != "" { - t.Fatalf("Read returned unexpected data:\n--- want\n+++ got\n%v", diff) - } - } -} - -func TestSysRootContainsExpectedEntries(t *testing.T) { - s := newTestSystem(t) - defer s.Destroy() - pop := s.PathOpAtRoot("/") - s.AssertAllDirentTypes(s.ListDirents(pop), map[string]testutil.DirentType{ - "block": linux.DT_DIR, - "bus": linux.DT_DIR, - "class": linux.DT_DIR, - "dev": linux.DT_DIR, - "devices": linux.DT_DIR, - "firmware": linux.DT_DIR, - "fs": linux.DT_DIR, - "kernel": linux.DT_DIR, - "module": linux.DT_DIR, - "power": linux.DT_DIR, - }) -} diff --git a/pkg/sentry/fsimpl/testutil/BUILD b/pkg/sentry/fsimpl/testutil/BUILD deleted file mode 100644 index b3f9d1010..000000000 --- a/pkg/sentry/fsimpl/testutil/BUILD +++ /dev/null @@ -1,38 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -licenses(["notice"]) - -go_library( - name = "testutil", - testonly = 1, - srcs = [ - "kernel.go", - "testutil.go", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/cpuid", - "//pkg/fspath", - "//pkg/hostarch", - "//pkg/memutil", - "//pkg/sentry/fsbridge", - "//pkg/sentry/fsimpl/tmpfs", - "//pkg/sentry/kernel", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/kernel/sched", - "//pkg/sentry/limits", - "//pkg/sentry/loader", - "//pkg/sentry/mm", - "//pkg/sentry/pgalloc", - "//pkg/sentry/platform", - "//pkg/sentry/platform/kvm", - "//pkg/sentry/platform/ptrace", - "//pkg/sentry/time", - "//pkg/sentry/vfs", - "//pkg/sync", - "//pkg/usermem", - "@com_github_google_go_cmp//cmp:go_default_library", - ], -) diff --git a/pkg/sentry/fsimpl/testutil/kernel.go b/pkg/sentry/fsimpl/testutil/kernel.go deleted file mode 100644 index 473b41cff..000000000 --- a/pkg/sentry/fsimpl/testutil/kernel.go +++ /dev/null @@ -1,182 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package testutil - -import ( - "flag" - "fmt" - "os" - "runtime" - - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/cpuid" - "gvisor.dev/gvisor/pkg/fspath" - "gvisor.dev/gvisor/pkg/memutil" - "gvisor.dev/gvisor/pkg/sentry/fsbridge" - "gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs" - "gvisor.dev/gvisor/pkg/sentry/kernel" - "gvisor.dev/gvisor/pkg/sentry/kernel/auth" - "gvisor.dev/gvisor/pkg/sentry/kernel/sched" - "gvisor.dev/gvisor/pkg/sentry/limits" - "gvisor.dev/gvisor/pkg/sentry/loader" - "gvisor.dev/gvisor/pkg/sentry/mm" - "gvisor.dev/gvisor/pkg/sentry/pgalloc" - "gvisor.dev/gvisor/pkg/sentry/platform" - "gvisor.dev/gvisor/pkg/sentry/time" - "gvisor.dev/gvisor/pkg/sentry/vfs" - - // Platforms are plugable. - _ "gvisor.dev/gvisor/pkg/sentry/platform/kvm" - _ "gvisor.dev/gvisor/pkg/sentry/platform/ptrace" -) - -var ( - platformFlag = flag.String("platform", "ptrace", "specify which platform to use") -) - -// Boot initializes a new bare bones kernel for test. -func Boot() (*kernel.Kernel, error) { - platformCtr, err := platform.Lookup(*platformFlag) - if err != nil { - return nil, fmt.Errorf("platform not found: %v", err) - } - deviceFile, err := platformCtr.OpenDevice() - if err != nil { - return nil, fmt.Errorf("creating platform: %v", err) - } - plat, err := platformCtr.New(deviceFile) - if err != nil { - return nil, fmt.Errorf("creating platform: %v", err) - } - - kernel.VFS2Enabled = true - k := &kernel.Kernel{ - Platform: plat, - } - - mf, err := createMemoryFile() - if err != nil { - return nil, err - } - k.SetMemoryFile(mf) - - // Pass k as the platform since it is savable, unlike the actual platform. - vdso, err := loader.PrepareVDSO(k) - if err != nil { - return nil, fmt.Errorf("creating vdso: %v", err) - } - - // Create timekeeper. - tk := kernel.NewTimekeeper(k, vdso.ParamPage.FileRange()) - tk.SetClocks(time.NewCalibratedClocks()) - - creds := auth.NewRootCredentials(auth.NewRootUserNamespace()) - - // Initiate the Kernel object, which is required by the Context passed - // to createVFS in order to mount (among other things) procfs. - if err = k.Init(kernel.InitKernelArgs{ - ApplicationCores: uint(runtime.GOMAXPROCS(-1)), - FeatureSet: cpuid.HostFeatureSet(), - Timekeeper: tk, - RootUserNamespace: creds.UserNamespace, - Vdso: vdso, - RootUTSNamespace: kernel.NewUTSNamespace("hostname", "domain", creds.UserNamespace), - RootIPCNamespace: kernel.NewIPCNamespace(creds.UserNamespace), - RootAbstractSocketNamespace: kernel.NewAbstractSocketNamespace(), - PIDNamespace: kernel.NewRootPIDNamespace(creds.UserNamespace), - }); err != nil { - return nil, fmt.Errorf("initializing kernel: %v", err) - } - - k.VFS().MustRegisterFilesystemType(tmpfs.Name, &tmpfs.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{ - AllowUserMount: true, - AllowUserList: true, - }) - - ls, err := limits.NewLinuxLimitSet() - if err != nil { - return nil, err - } - tg := k.NewThreadGroup(nil, k.RootPIDNamespace(), kernel.NewSignalHandlers(), linux.SIGCHLD, ls) - k.TestOnlySetGlobalInit(tg) - - return k, nil -} - -// CreateTask creates a new bare bones task for tests. -func CreateTask(ctx context.Context, name string, tc *kernel.ThreadGroup, mntns *vfs.MountNamespace, root, cwd vfs.VirtualDentry) (*kernel.Task, error) { - k := kernel.KernelFromContext(ctx) - if k == nil { - return nil, fmt.Errorf("cannot find kernel from context") - } - - exe, err := newFakeExecutable(ctx, k.VFS(), auth.CredentialsFromContext(ctx), root) - if err != nil { - return nil, err - } - m := mm.NewMemoryManager(k, k, k.SleepForAddressSpaceActivation) - m.SetExecutable(ctx, fsbridge.NewVFSFile(exe)) - - config := &kernel.TaskConfig{ - Kernel: k, - ThreadGroup: tc, - TaskImage: &kernel.TaskImage{Name: name, MemoryManager: m}, - Credentials: auth.CredentialsFromContext(ctx), - NetworkNamespace: k.RootNetworkNamespace(), - AllowedCPUMask: sched.NewFullCPUSet(k.ApplicationCores()), - UTSNamespace: kernel.UTSNamespaceFromContext(ctx), - IPCNamespace: kernel.IPCNamespaceFromContext(ctx), - AbstractSocketNamespace: kernel.NewAbstractSocketNamespace(), - MountNamespaceVFS2: mntns, - FSContext: kernel.NewFSContextVFS2(root, cwd, 0022), - FDTable: k.NewFDTable(), - } - t, err := k.TaskSet().NewTask(ctx, config) - if err != nil { - config.ThreadGroup.Release(ctx) - return nil, err - } - return t, nil -} - -func newFakeExecutable(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, root vfs.VirtualDentry) (*vfs.FileDescription, error) { - const name = "executable" - pop := &vfs.PathOperation{ - Root: root, - Start: root, - Path: fspath.Parse(name), - } - opts := &vfs.OpenOptions{ - Flags: linux.O_RDONLY | linux.O_CREAT, - Mode: 0777, - } - return vfsObj.OpenAt(ctx, creds, pop, opts) -} - -func createMemoryFile() (*pgalloc.MemoryFile, error) { - const memfileName = "test-memory" - memfd, err := memutil.CreateMemFD(memfileName, 0) - if err != nil { - return nil, fmt.Errorf("error creating memfd: %v", err) - } - memfile := os.NewFile(uintptr(memfd), memfileName) - mf, err := pgalloc.NewMemoryFile(memfile, pgalloc.MemoryFileOpts{}) - if err != nil { - _ = memfile.Close() - return nil, fmt.Errorf("error creating pgalloc.MemoryFile: %v", err) - } - return mf, nil -} diff --git a/pkg/sentry/fsimpl/testutil/testutil.go b/pkg/sentry/fsimpl/testutil/testutil.go deleted file mode 100644 index 59e6f9c92..000000000 --- a/pkg/sentry/fsimpl/testutil/testutil.go +++ /dev/null @@ -1,288 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package testutil provides common test utilities for kernfs-based -// filesystems. -package testutil - -import ( - "fmt" - "io" - "strings" - "testing" - - "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/fspath" - "gvisor.dev/gvisor/pkg/sentry/kernel/auth" - "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/usermem" - - "gvisor.dev/gvisor/pkg/hostarch" -) - -// System represents the context for a single test. -// -// Test systems must be explicitly destroyed with System.Destroy. -type System struct { - t *testing.T - Ctx context.Context - Creds *auth.Credentials - VFS *vfs.VirtualFilesystem - Root vfs.VirtualDentry - MntNs *vfs.MountNamespace -} - -// NewSystem constructs a System. -// -// Precondition: Caller must hold a reference on mns, whose ownership -// is transferred to the new System. -func NewSystem(ctx context.Context, t *testing.T, v *vfs.VirtualFilesystem, mns *vfs.MountNamespace) *System { - root := mns.Root() - root.IncRef() - s := &System{ - t: t, - Ctx: ctx, - Creds: auth.CredentialsFromContext(ctx), - VFS: v, - MntNs: mns, - Root: root, - } - return s -} - -// WithSubtest creates a temporary test system with a new test harness, -// referencing all other resources from the original system. This is useful when -// a system is reused for multiple subtests, and the T needs to change for each -// case. Note that this is safe when test cases run in parallel, as all -// resources referenced by the system are immutable, or handle interior -// mutations in a thread-safe manner. -// -// The returned system must not outlive the original and should not be destroyed -// via System.Destroy. -func (s *System) WithSubtest(t *testing.T) *System { - return &System{ - t: t, - Ctx: s.Ctx, - Creds: s.Creds, - VFS: s.VFS, - MntNs: s.MntNs, - Root: s.Root, - } -} - -// WithTemporaryContext constructs a temporary test system with a new context -// ctx. The temporary system borrows all resources and references from the -// original system. The returned temporary system must not outlive the original -// system, and should not be destroyed via System.Destroy. -func (s *System) WithTemporaryContext(ctx context.Context) *System { - return &System{ - t: s.t, - Ctx: ctx, - Creds: s.Creds, - VFS: s.VFS, - MntNs: s.MntNs, - Root: s.Root, - } -} - -// Destroy release resources associated with a test system. -func (s *System) Destroy() { - s.Root.DecRef(s.Ctx) - s.MntNs.DecRef(s.Ctx) // Reference on MntNs passed to NewSystem. -} - -// ReadToEnd reads the contents of fd until EOF to a string. -func (s *System) ReadToEnd(fd *vfs.FileDescription) (string, error) { - buf := make([]byte, hostarch.PageSize) - bufIOSeq := usermem.BytesIOSequence(buf) - opts := vfs.ReadOptions{} - - var content strings.Builder - for { - n, err := fd.Read(s.Ctx, bufIOSeq, opts) - if n == 0 || err != nil { - if err == io.EOF { - err = nil - } - return content.String(), err - } - content.Write(buf[:n]) - } -} - -// PathOpAtRoot constructs a PathOperation with the given path from -// the root of the filesystem. -func (s *System) PathOpAtRoot(path string) *vfs.PathOperation { - return &vfs.PathOperation{ - Root: s.Root, - Start: s.Root, - Path: fspath.Parse(path), - } -} - -// GetDentryOrDie attempts to resolve a dentry referred to by the -// provided path operation. If unsuccessful, the test fails. -func (s *System) GetDentryOrDie(pop *vfs.PathOperation) vfs.VirtualDentry { - vd, err := s.VFS.GetDentryAt(s.Ctx, s.Creds, pop, &vfs.GetDentryOptions{}) - if err != nil { - s.t.Fatalf("GetDentryAt(pop:%+v) failed: %v", pop, err) - } - return vd -} - -// DirentType is an alias for values for linux_dirent64.d_type. -type DirentType = uint8 - -// ListDirents lists the Dirents for a directory at pop. -func (s *System) ListDirents(pop *vfs.PathOperation) *DirentCollector { - fd, err := s.VFS.OpenAt(s.Ctx, s.Creds, pop, &vfs.OpenOptions{Flags: linux.O_RDONLY}) - if err != nil { - s.t.Fatalf("OpenAt for PathOperation %+v failed: %v", pop, err) - } - defer fd.DecRef(s.Ctx) - - collector := &DirentCollector{} - if err := fd.IterDirents(s.Ctx, collector); err != nil { - s.t.Fatalf("IterDirent failed: %v", err) - } - return collector -} - -// AssertAllDirentTypes verifies that the set of dirents in collector contains -// exactly the specified set of expected entries. AssertAllDirentTypes respects -// collector.skipDots, and implicitly checks for "." and ".." accordingly. -func (s *System) AssertAllDirentTypes(collector *DirentCollector, expected map[string]DirentType) { - if expected == nil { - expected = make(map[string]DirentType) - } - // Also implicitly check for "." and "..", if enabled. - if !collector.skipDots { - expected["."] = linux.DT_DIR - expected[".."] = linux.DT_DIR - } - - dentryTypes := make(map[string]DirentType) - collector.mu.Lock() - for _, dirent := range collector.dirents { - dentryTypes[dirent.Name] = dirent.Type - } - collector.mu.Unlock() - if diff := cmp.Diff(expected, dentryTypes); diff != "" { - s.t.Fatalf("IterDirent had unexpected results:\n--- want\n+++ got\n%v", diff) - } -} - -// AssertDirentOffsets verifies that collector contains at least the entries -// specified in expected, with the given NextOff field. Entries specified in -// expected but missing from collector result in failure. Extra entries in -// collector are ignored. AssertDirentOffsets respects collector.skipDots, and -// implicitly checks for "." and ".." accordingly. -func (s *System) AssertDirentOffsets(collector *DirentCollector, expected map[string]int64) { - // Also implicitly check for "." and "..", if enabled. - if !collector.skipDots { - expected["."] = 1 - expected[".."] = 2 - } - - dentryNextOffs := make(map[string]int64) - collector.mu.Lock() - for _, dirent := range collector.dirents { - // Ignore extra entries in dentries that are not in expected. - if _, ok := expected[dirent.Name]; ok { - dentryNextOffs[dirent.Name] = dirent.NextOff - } - } - collector.mu.Unlock() - if diff := cmp.Diff(expected, dentryNextOffs); diff != "" { - s.t.Fatalf("IterDirent had unexpected results:\n--- want\n+++ got\n%v", diff) - } -} - -// DirentCollector provides an implementation for vfs.IterDirentsCallback for -// testing. It simply iterates to the end of a given directory FD and collects -// all dirents emitted by the callback. -type DirentCollector struct { - mu sync.Mutex - order []*vfs.Dirent - dirents map[string]*vfs.Dirent - // When the collector is used in various Assert* functions, should "." and - // ".." be implicitly checked? - skipDots bool -} - -// SkipDotsChecks enables or disables the implicit checks on "." and ".." when -// the collector is used in various Assert* functions. Note that "." and ".." -// are still collected if passed to d.Handle, so the caller should only disable -// the checks when they aren't expected. -func (d *DirentCollector) SkipDotsChecks(value bool) { - d.skipDots = value -} - -// Handle implements vfs.IterDirentsCallback.Handle. -func (d *DirentCollector) Handle(dirent vfs.Dirent) error { - d.mu.Lock() - if d.dirents == nil { - d.dirents = make(map[string]*vfs.Dirent) - } - d.order = append(d.order, &dirent) - d.dirents[dirent.Name] = &dirent - d.mu.Unlock() - return nil -} - -// Count returns the number of dirents currently in the collector. -func (d *DirentCollector) Count() int { - d.mu.Lock() - defer d.mu.Unlock() - return len(d.dirents) -} - -// Contains checks whether the collector has a dirent with the given name and -// type. -func (d *DirentCollector) Contains(name string, typ uint8) error { - d.mu.Lock() - defer d.mu.Unlock() - dirent, ok := d.dirents[name] - if !ok { - return fmt.Errorf("no dirent named %q found", name) - } - if dirent.Type != typ { - return fmt.Errorf("dirent named %q found, but was expecting type %s, got: %+v", name, linux.DirentType.Parse(uint64(typ)), dirent) - } - return nil -} - -// Dirents returns all dirents discovered by this collector. -func (d *DirentCollector) Dirents() map[string]*vfs.Dirent { - d.mu.Lock() - dirents := make(map[string]*vfs.Dirent) - for n, d := range d.dirents { - dirents[n] = d - } - d.mu.Unlock() - return dirents -} - -// OrderedDirents returns an ordered list of dirents as discovered by this -// collector. -func (d *DirentCollector) OrderedDirents() []*vfs.Dirent { - d.mu.Lock() - dirents := make([]*vfs.Dirent, len(d.order)) - copy(dirents, d.order) - d.mu.Unlock() - return dirents -} diff --git a/pkg/sentry/fsimpl/timerfd/BUILD b/pkg/sentry/fsimpl/timerfd/BUILD deleted file mode 100644 index 2b83d7d9a..000000000 --- a/pkg/sentry/fsimpl/timerfd/BUILD +++ /dev/null @@ -1,18 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -licenses(["notice"]) - -go_library( - name = "timerfd", - srcs = ["timerfd.go"], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/context", - "//pkg/errors/linuxerr", - "//pkg/hostarch", - "//pkg/sentry/kernel/time", - "//pkg/sentry/vfs", - "//pkg/usermem", - "//pkg/waiter", - ], -) diff --git a/pkg/sentry/fsimpl/timerfd/timerfd_state_autogen.go b/pkg/sentry/fsimpl/timerfd/timerfd_state_autogen.go new file mode 100644 index 000000000..12970f25c --- /dev/null +++ b/pkg/sentry/fsimpl/timerfd/timerfd_state_autogen.go @@ -0,0 +1,54 @@ +// automatically generated by stateify. + +package timerfd + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (tfd *TimerFileDescription) StateTypeName() string { + return "pkg/sentry/fsimpl/timerfd.TimerFileDescription" +} + +func (tfd *TimerFileDescription) StateFields() []string { + return []string{ + "vfsfd", + "FileDescriptionDefaultImpl", + "DentryMetadataFileDescriptionImpl", + "NoLockFD", + "events", + "timer", + "val", + } +} + +func (tfd *TimerFileDescription) beforeSave() {} + +// +checklocksignore +func (tfd *TimerFileDescription) StateSave(stateSinkObject state.Sink) { + tfd.beforeSave() + stateSinkObject.Save(0, &tfd.vfsfd) + stateSinkObject.Save(1, &tfd.FileDescriptionDefaultImpl) + stateSinkObject.Save(2, &tfd.DentryMetadataFileDescriptionImpl) + stateSinkObject.Save(3, &tfd.NoLockFD) + stateSinkObject.Save(4, &tfd.events) + stateSinkObject.Save(5, &tfd.timer) + stateSinkObject.Save(6, &tfd.val) +} + +func (tfd *TimerFileDescription) afterLoad() {} + +// +checklocksignore +func (tfd *TimerFileDescription) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &tfd.vfsfd) + stateSourceObject.Load(1, &tfd.FileDescriptionDefaultImpl) + stateSourceObject.Load(2, &tfd.DentryMetadataFileDescriptionImpl) + stateSourceObject.Load(3, &tfd.NoLockFD) + stateSourceObject.Load(4, &tfd.events) + stateSourceObject.Load(5, &tfd.timer) + stateSourceObject.Load(6, &tfd.val) +} + +func init() { + state.Register((*TimerFileDescription)(nil)) +} diff --git a/pkg/sentry/fsimpl/tmpfs/BUILD b/pkg/sentry/fsimpl/tmpfs/BUILD deleted file mode 100644 index 94486bb63..000000000 --- a/pkg/sentry/fsimpl/tmpfs/BUILD +++ /dev/null @@ -1,129 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -licenses(["notice"]) - -go_template_instance( - name = "dentry_list", - out = "dentry_list.go", - package = "tmpfs", - prefix = "dentry", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*dentry", - "Linker": "*dentry", - }, -) - -go_template_instance( - name = "fstree", - out = "fstree.go", - package = "tmpfs", - prefix = "generic", - template = "//pkg/sentry/vfs/genericfstree:generic_fstree", - types = { - "Dentry": "dentry", - }, -) - -go_template_instance( - name = "inode_refs", - out = "inode_refs.go", - package = "tmpfs", - prefix = "inode", - template = "//pkg/refsvfs2:refs_template", - types = { - "T": "inode", - }, -) - -go_library( - name = "tmpfs", - srcs = [ - "dentry_list.go", - "device_file.go", - "directory.go", - "filesystem.go", - "fstree.go", - "inode_refs.go", - "named_pipe.go", - "regular_file.go", - "save_restore.go", - "socket_file.go", - "symlink.go", - "tmpfs.go", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/amutex", - "//pkg/context", - "//pkg/errors/linuxerr", - "//pkg/fspath", - "//pkg/hostarch", - "//pkg/log", - "//pkg/refs", - "//pkg/refsvfs2", - "//pkg/safemem", - "//pkg/sentry/arch", - "//pkg/sentry/fs", - "//pkg/sentry/fs/fsutil", - "//pkg/sentry/fs/lock", - "//pkg/sentry/fsmetric", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/kernel/pipe", - "//pkg/sentry/kernel/time", - "//pkg/sentry/memmap", - "//pkg/sentry/pgalloc", - "//pkg/sentry/platform", - "//pkg/sentry/socket/unix/transport", - "//pkg/sentry/uniqueid", - "//pkg/sentry/usage", - "//pkg/sentry/vfs", - "//pkg/sentry/vfs/memxattr", - "//pkg/sync", - "//pkg/usermem", - ], -) - -go_test( - name = "benchmark_test", - size = "small", - srcs = ["benchmark_test.go"], - deps = [ - ":tmpfs", - "//pkg/abi/linux", - "//pkg/context", - "//pkg/errors/linuxerr", - "//pkg/fspath", - "//pkg/refs", - "//pkg/sentry/contexttest", - "//pkg/sentry/fs", - "//pkg/sentry/fs/tmpfs", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/vfs", - ], -) - -go_test( - name = "tmpfs_test", - size = "small", - srcs = [ - "pipe_test.go", - "regular_file_test.go", - "stat_test.go", - "tmpfs_test.go", - ], - library = ":tmpfs", - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/errors/linuxerr", - "//pkg/fspath", - "//pkg/sentry/contexttest", - "//pkg/sentry/fs/lock", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/vfs", - "//pkg/usermem", - ], -) diff --git a/pkg/sentry/fsimpl/tmpfs/benchmark_test.go b/pkg/sentry/fsimpl/tmpfs/benchmark_test.go deleted file mode 100644 index 2c29343c1..000000000 --- a/pkg/sentry/fsimpl/tmpfs/benchmark_test.go +++ /dev/null @@ -1,488 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package benchmark_test - -import ( - "fmt" - "runtime" - "strings" - "testing" - - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/errors/linuxerr" - "gvisor.dev/gvisor/pkg/fspath" - "gvisor.dev/gvisor/pkg/refs" - "gvisor.dev/gvisor/pkg/sentry/contexttest" - "gvisor.dev/gvisor/pkg/sentry/fs" - _ "gvisor.dev/gvisor/pkg/sentry/fs/tmpfs" - "gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs" - "gvisor.dev/gvisor/pkg/sentry/kernel/auth" - "gvisor.dev/gvisor/pkg/sentry/vfs" -) - -// Differences from stat_benchmark: -// -// - Syscall interception, CopyInPath, copyOutStat, and overlayfs overheads are -// not included. -// -// - *MountStat benchmarks use a tmpfs root mount and a tmpfs submount at /tmp. -// Non-MountStat benchmarks use a tmpfs root mount and no submounts. -// stat_benchmark uses a varying root mount, a tmpfs submount at /tmp, and a -// subdirectory /tmp/<top_dir> (assuming TEST_TMPDIR == "/tmp"). Thus -// stat_benchmark at depth 1 does a comparable amount of work to *MountStat -// benchmarks at depth 2, and non-MountStat benchmarks at depth 3. -var depths = []int{1, 2, 3, 8, 64, 100} - -const ( - mountPointName = "tmp" - filename = "gvisor_test_temp_0_1557494568" -) - -// This is copied from syscalls/linux/sys_file.go, with the dependency on -// kernel.Task stripped out. -func fileOpOn(ctx context.Context, mntns *fs.MountNamespace, root, wd *fs.Dirent, dirFD int32, path string, resolve bool, fn func(root *fs.Dirent, d *fs.Dirent) error) error { - var ( - d *fs.Dirent // The file. - rel *fs.Dirent // The relative directory for search (if required.) - err error - ) - - // Extract the working directory (maybe). - if len(path) > 0 && path[0] == '/' { - // Absolute path; rel can be nil. - } else if dirFD == linux.AT_FDCWD { - // Need to reference the working directory. - rel = wd - } else { - // Need to extract the given FD. - return linuxerr.EBADF - } - - // Lookup the node. - remainingTraversals := uint(linux.MaxSymlinkTraversals) - if resolve { - d, err = mntns.FindInode(ctx, root, rel, path, &remainingTraversals) - } else { - d, err = mntns.FindLink(ctx, root, rel, path, &remainingTraversals) - } - if err != nil { - return err - } - - err = fn(root, d) - d.DecRef(ctx) - return err -} - -func BenchmarkVFS1TmpfsStat(b *testing.B) { - for _, depth := range depths { - b.Run(fmt.Sprintf("%d", depth), func(b *testing.B) { - ctx := contexttest.Context(b) - - // Create VFS. - tmpfsFS, ok := fs.FindFilesystem("tmpfs") - if !ok { - b.Fatalf("failed to find tmpfs filesystem type") - } - rootInode, err := tmpfsFS.Mount(ctx, "tmpfs", fs.MountSourceFlags{}, "", nil) - if err != nil { - b.Fatalf("failed to create tmpfs root mount: %v", err) - } - mntns, err := fs.NewMountNamespace(ctx, rootInode) - if err != nil { - b.Fatalf("failed to create mount namespace: %v", err) - } - defer mntns.DecRef(ctx) - - var filePathBuilder strings.Builder - filePathBuilder.WriteByte('/') - - // Create nested directories with given depth. - root := mntns.Root() - defer root.DecRef(ctx) - d := root - d.IncRef() - defer d.DecRef(ctx) - for i := depth; i > 0; i-- { - name := fmt.Sprintf("%d", i) - if err := d.Inode.CreateDirectory(ctx, d, name, fs.FilePermsFromMode(0755)); err != nil { - b.Fatalf("failed to create directory %q: %v", name, err) - } - next, err := d.Walk(ctx, root, name) - if err != nil { - b.Fatalf("failed to walk to directory %q: %v", name, err) - } - d.DecRef(ctx) - d = next - filePathBuilder.WriteString(name) - filePathBuilder.WriteByte('/') - } - - // Create the file that will be stat'd. - file, err := d.Inode.Create(ctx, d, filename, fs.FileFlags{Read: true, Write: true}, fs.FilePermsFromMode(0644)) - if err != nil { - b.Fatalf("failed to create file %q: %v", filename, err) - } - file.DecRef(ctx) - filePathBuilder.WriteString(filename) - filePath := filePathBuilder.String() - - dirPath := false - runtime.GC() - b.ResetTimer() - for i := 0; i < b.N; i++ { - err := fileOpOn(ctx, mntns, root, root, linux.AT_FDCWD, filePath, true /* resolve */, func(root *fs.Dirent, d *fs.Dirent) error { - if dirPath && !fs.IsDir(d.Inode.StableAttr) { - return linuxerr.ENOTDIR - } - uattr, err := d.Inode.UnstableAttr(ctx) - if err != nil { - return err - } - // Sanity check. - if uattr.Perms.User.Execute { - b.Fatalf("got wrong permissions (%0o)", uattr.Perms.LinuxMode()) - } - return nil - }) - if err != nil { - b.Fatalf("stat(%q) failed: %v", filePath, err) - } - } - // Don't include deferred cleanup in benchmark time. - b.StopTimer() - }) - } -} - -func BenchmarkVFS2TmpfsStat(b *testing.B) { - for _, depth := range depths { - b.Run(fmt.Sprintf("%d", depth), func(b *testing.B) { - ctx := contexttest.Context(b) - creds := auth.CredentialsFromContext(ctx) - - // Create VFS. - vfsObj := vfs.VirtualFilesystem{} - if err := vfsObj.Init(ctx); err != nil { - b.Fatalf("VFS init: %v", err) - } - vfsObj.MustRegisterFilesystemType("tmpfs", tmpfs.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{ - AllowUserMount: true, - }) - mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "tmpfs", &vfs.MountOptions{}) - if err != nil { - b.Fatalf("failed to create tmpfs root mount: %v", err) - } - defer mntns.DecRef(ctx) - - var filePathBuilder strings.Builder - filePathBuilder.WriteByte('/') - - // Create nested directories with given depth. - root := mntns.Root() - root.IncRef() - defer root.DecRef(ctx) - vd := root - vd.IncRef() - for i := depth; i > 0; i-- { - name := fmt.Sprintf("%d", i) - pop := vfs.PathOperation{ - Root: root, - Start: vd, - Path: fspath.Parse(name), - } - if err := vfsObj.MkdirAt(ctx, creds, &pop, &vfs.MkdirOptions{ - Mode: 0755, - }); err != nil { - b.Fatalf("failed to create directory %q: %v", name, err) - } - nextVD, err := vfsObj.GetDentryAt(ctx, creds, &pop, &vfs.GetDentryOptions{}) - if err != nil { - b.Fatalf("failed to walk to directory %q: %v", name, err) - } - vd.DecRef(ctx) - vd = nextVD - filePathBuilder.WriteString(name) - filePathBuilder.WriteByte('/') - } - - // Create the file that will be stat'd. - fd, err := vfsObj.OpenAt(ctx, creds, &vfs.PathOperation{ - Root: root, - Start: vd, - Path: fspath.Parse(filename), - FollowFinalSymlink: true, - }, &vfs.OpenOptions{ - Flags: linux.O_RDWR | linux.O_CREAT | linux.O_EXCL, - Mode: 0644, - }) - vd.DecRef(ctx) - vd = vfs.VirtualDentry{} - if err != nil { - b.Fatalf("failed to create file %q: %v", filename, err) - } - defer fd.DecRef(ctx) - filePathBuilder.WriteString(filename) - filePath := filePathBuilder.String() - - runtime.GC() - b.ResetTimer() - for i := 0; i < b.N; i++ { - stat, err := vfsObj.StatAt(ctx, creds, &vfs.PathOperation{ - Root: root, - Start: root, - Path: fspath.Parse(filePath), - FollowFinalSymlink: true, - }, &vfs.StatOptions{}) - if err != nil { - b.Fatalf("stat(%q) failed: %v", filePath, err) - } - // Sanity check. - if stat.Mode&^linux.S_IFMT != 0644 { - b.Fatalf("got wrong permissions (%0o)", stat.Mode) - } - } - // Don't include deferred cleanup in benchmark time. - b.StopTimer() - }) - } -} - -func BenchmarkVFS1TmpfsMountStat(b *testing.B) { - for _, depth := range depths { - b.Run(fmt.Sprintf("%d", depth), func(b *testing.B) { - ctx := contexttest.Context(b) - - // Create VFS. - tmpfsFS, ok := fs.FindFilesystem("tmpfs") - if !ok { - b.Fatalf("failed to find tmpfs filesystem type") - } - rootInode, err := tmpfsFS.Mount(ctx, "tmpfs", fs.MountSourceFlags{}, "", nil) - if err != nil { - b.Fatalf("failed to create tmpfs root mount: %v", err) - } - mntns, err := fs.NewMountNamespace(ctx, rootInode) - if err != nil { - b.Fatalf("failed to create mount namespace: %v", err) - } - defer mntns.DecRef(ctx) - - var filePathBuilder strings.Builder - filePathBuilder.WriteByte('/') - - // Create and mount the submount. - root := mntns.Root() - defer root.DecRef(ctx) - if err := root.Inode.CreateDirectory(ctx, root, mountPointName, fs.FilePermsFromMode(0755)); err != nil { - b.Fatalf("failed to create mount point: %v", err) - } - mountPoint, err := root.Walk(ctx, root, mountPointName) - if err != nil { - b.Fatalf("failed to walk to mount point: %v", err) - } - defer mountPoint.DecRef(ctx) - submountInode, err := tmpfsFS.Mount(ctx, "tmpfs", fs.MountSourceFlags{}, "", nil) - if err != nil { - b.Fatalf("failed to create tmpfs submount: %v", err) - } - if err := mntns.Mount(ctx, mountPoint, submountInode); err != nil { - b.Fatalf("failed to mount tmpfs submount: %v", err) - } - filePathBuilder.WriteString(mountPointName) - filePathBuilder.WriteByte('/') - - // Create nested directories with given depth. - d, err := root.Walk(ctx, root, mountPointName) - if err != nil { - b.Fatalf("failed to walk to mount root: %v", err) - } - defer d.DecRef(ctx) - for i := depth; i > 0; i-- { - name := fmt.Sprintf("%d", i) - if err := d.Inode.CreateDirectory(ctx, d, name, fs.FilePermsFromMode(0755)); err != nil { - b.Fatalf("failed to create directory %q: %v", name, err) - } - next, err := d.Walk(ctx, root, name) - if err != nil { - b.Fatalf("failed to walk to directory %q: %v", name, err) - } - d.DecRef(ctx) - d = next - filePathBuilder.WriteString(name) - filePathBuilder.WriteByte('/') - } - - // Create the file that will be stat'd. - file, err := d.Inode.Create(ctx, d, filename, fs.FileFlags{Read: true, Write: true}, fs.FilePermsFromMode(0644)) - if err != nil { - b.Fatalf("failed to create file %q: %v", filename, err) - } - file.DecRef(ctx) - filePathBuilder.WriteString(filename) - filePath := filePathBuilder.String() - - dirPath := false - runtime.GC() - b.ResetTimer() - for i := 0; i < b.N; i++ { - err := fileOpOn(ctx, mntns, root, root, linux.AT_FDCWD, filePath, true /* resolve */, func(root *fs.Dirent, d *fs.Dirent) error { - if dirPath && !fs.IsDir(d.Inode.StableAttr) { - return linuxerr.ENOTDIR - } - uattr, err := d.Inode.UnstableAttr(ctx) - if err != nil { - return err - } - // Sanity check. - if uattr.Perms.User.Execute { - b.Fatalf("got wrong permissions (%0o)", uattr.Perms.LinuxMode()) - } - return nil - }) - if err != nil { - b.Fatalf("stat(%q) failed: %v", filePath, err) - } - } - // Don't include deferred cleanup in benchmark time. - b.StopTimer() - }) - } -} - -func BenchmarkVFS2TmpfsMountStat(b *testing.B) { - for _, depth := range depths { - b.Run(fmt.Sprintf("%d", depth), func(b *testing.B) { - ctx := contexttest.Context(b) - creds := auth.CredentialsFromContext(ctx) - - // Create VFS. - vfsObj := vfs.VirtualFilesystem{} - if err := vfsObj.Init(ctx); err != nil { - b.Fatalf("VFS init: %v", err) - } - vfsObj.MustRegisterFilesystemType("tmpfs", tmpfs.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{ - AllowUserMount: true, - }) - mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "tmpfs", &vfs.MountOptions{}) - if err != nil { - b.Fatalf("failed to create tmpfs root mount: %v", err) - } - defer mntns.DecRef(ctx) - - var filePathBuilder strings.Builder - filePathBuilder.WriteByte('/') - - // Create the mount point. - root := mntns.Root() - root.IncRef() - defer root.DecRef(ctx) - pop := vfs.PathOperation{ - Root: root, - Start: root, - Path: fspath.Parse(mountPointName), - } - if err := vfsObj.MkdirAt(ctx, creds, &pop, &vfs.MkdirOptions{ - Mode: 0755, - }); err != nil { - b.Fatalf("failed to create mount point: %v", err) - } - // Save the mount point for later use. - mountPoint, err := vfsObj.GetDentryAt(ctx, creds, &pop, &vfs.GetDentryOptions{}) - if err != nil { - b.Fatalf("failed to walk to mount point: %v", err) - } - defer mountPoint.DecRef(ctx) - // Create and mount the submount. - if _, err := vfsObj.MountAt(ctx, creds, "", &pop, "tmpfs", &vfs.MountOptions{}); err != nil { - b.Fatalf("failed to mount tmpfs submount: %v", err) - } - filePathBuilder.WriteString(mountPointName) - filePathBuilder.WriteByte('/') - - // Create nested directories with given depth. - vd, err := vfsObj.GetDentryAt(ctx, creds, &pop, &vfs.GetDentryOptions{}) - if err != nil { - b.Fatalf("failed to walk to mount root: %v", err) - } - for i := depth; i > 0; i-- { - name := fmt.Sprintf("%d", i) - pop := vfs.PathOperation{ - Root: root, - Start: vd, - Path: fspath.Parse(name), - } - if err := vfsObj.MkdirAt(ctx, creds, &pop, &vfs.MkdirOptions{ - Mode: 0755, - }); err != nil { - b.Fatalf("failed to create directory %q: %v", name, err) - } - nextVD, err := vfsObj.GetDentryAt(ctx, creds, &pop, &vfs.GetDentryOptions{}) - if err != nil { - b.Fatalf("failed to walk to directory %q: %v", name, err) - } - vd.DecRef(ctx) - vd = nextVD - filePathBuilder.WriteString(name) - filePathBuilder.WriteByte('/') - } - - // Create the file that will be stat'd. - fd, err := vfsObj.OpenAt(ctx, creds, &vfs.PathOperation{ - Root: root, - Start: vd, - Path: fspath.Parse(filename), - FollowFinalSymlink: true, - }, &vfs.OpenOptions{ - Flags: linux.O_RDWR | linux.O_CREAT | linux.O_EXCL, - Mode: 0644, - }) - vd.DecRef(ctx) - if err != nil { - b.Fatalf("failed to create file %q: %v", filename, err) - } - fd.DecRef(ctx) - filePathBuilder.WriteString(filename) - filePath := filePathBuilder.String() - - runtime.GC() - b.ResetTimer() - for i := 0; i < b.N; i++ { - stat, err := vfsObj.StatAt(ctx, creds, &vfs.PathOperation{ - Root: root, - Start: root, - Path: fspath.Parse(filePath), - FollowFinalSymlink: true, - }, &vfs.StatOptions{}) - if err != nil { - b.Fatalf("stat(%q) failed: %v", filePath, err) - } - // Sanity check. - if stat.Mode&^linux.S_IFMT != 0644 { - b.Fatalf("got wrong permissions (%0o)", stat.Mode) - } - } - // Don't include deferred cleanup in benchmark time. - b.StopTimer() - }) - } -} - -func init() { - // Turn off reference leak checking for a fair comparison between vfs1 and - // vfs2. - refs.SetLeakMode(refs.NoLeakChecking) -} diff --git a/pkg/sentry/fsimpl/tmpfs/dentry_list.go b/pkg/sentry/fsimpl/tmpfs/dentry_list.go new file mode 100644 index 000000000..b95dd7101 --- /dev/null +++ b/pkg/sentry/fsimpl/tmpfs/dentry_list.go @@ -0,0 +1,221 @@ +package tmpfs + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type dentryElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (dentryElementMapper) linkerFor(elem *dentry) *dentry { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type dentryList struct { + head *dentry + tail *dentry +} + +// Reset resets list l to the empty state. +func (l *dentryList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +// +//go:nosplit +func (l *dentryList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +// +//go:nosplit +func (l *dentryList) Front() *dentry { + return l.head +} + +// Back returns the last element of list l or nil. +// +//go:nosplit +func (l *dentryList) Back() *dentry { + return l.tail +} + +// Len returns the number of elements in the list. +// +// NOTE: This is an O(n) operation. +// +//go:nosplit +func (l *dentryList) Len() (count int) { + for e := l.Front(); e != nil; e = (dentryElementMapper{}.linkerFor(e)).Next() { + count++ + } + return count +} + +// PushFront inserts the element e at the front of list l. +// +//go:nosplit +func (l *dentryList) PushFront(e *dentry) { + linker := dentryElementMapper{}.linkerFor(e) + linker.SetNext(l.head) + linker.SetPrev(nil) + if l.head != nil { + dentryElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +// +//go:nosplit +func (l *dentryList) PushBack(e *dentry) { + linker := dentryElementMapper{}.linkerFor(e) + linker.SetNext(nil) + linker.SetPrev(l.tail) + if l.tail != nil { + dentryElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +// +//go:nosplit +func (l *dentryList) PushBackList(m *dentryList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + dentryElementMapper{}.linkerFor(l.tail).SetNext(m.head) + dentryElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +// +//go:nosplit +func (l *dentryList) InsertAfter(b, e *dentry) { + bLinker := dentryElementMapper{}.linkerFor(b) + eLinker := dentryElementMapper{}.linkerFor(e) + + a := bLinker.Next() + + eLinker.SetNext(a) + eLinker.SetPrev(b) + bLinker.SetNext(e) + + if a != nil { + dentryElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +// +//go:nosplit +func (l *dentryList) InsertBefore(a, e *dentry) { + aLinker := dentryElementMapper{}.linkerFor(a) + eLinker := dentryElementMapper{}.linkerFor(e) + + b := aLinker.Prev() + eLinker.SetNext(a) + eLinker.SetPrev(b) + aLinker.SetPrev(e) + + if b != nil { + dentryElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +// +//go:nosplit +func (l *dentryList) Remove(e *dentry) { + linker := dentryElementMapper{}.linkerFor(e) + prev := linker.Prev() + next := linker.Next() + + if prev != nil { + dentryElementMapper{}.linkerFor(prev).SetNext(next) + } else if l.head == e { + l.head = next + } + + if next != nil { + dentryElementMapper{}.linkerFor(next).SetPrev(prev) + } else if l.tail == e { + l.tail = prev + } + + linker.SetNext(nil) + linker.SetPrev(nil) +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type dentryEntry struct { + next *dentry + prev *dentry +} + +// Next returns the entry that follows e in the list. +// +//go:nosplit +func (e *dentryEntry) Next() *dentry { + return e.next +} + +// Prev returns the entry that precedes e in the list. +// +//go:nosplit +func (e *dentryEntry) Prev() *dentry { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +// +//go:nosplit +func (e *dentryEntry) SetNext(elem *dentry) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +// +//go:nosplit +func (e *dentryEntry) SetPrev(elem *dentry) { + e.prev = elem +} diff --git a/pkg/sentry/fsimpl/tmpfs/fstree.go b/pkg/sentry/fsimpl/tmpfs/fstree.go new file mode 100644 index 000000000..d46351488 --- /dev/null +++ b/pkg/sentry/fsimpl/tmpfs/fstree.go @@ -0,0 +1,55 @@ +package tmpfs + +import ( + "gvisor.dev/gvisor/pkg/fspath" + "gvisor.dev/gvisor/pkg/sentry/vfs" +) + +// IsAncestorDentry returns true if d is an ancestor of d2; that is, d is +// either d2's parent or an ancestor of d2's parent. +func genericIsAncestorDentry(d, d2 *dentry) bool { + for d2 != nil { + if d2.parent == d { + return true + } + if d2.parent == d2 { + return false + } + d2 = d2.parent + } + return false +} + +// ParentOrSelf returns d.parent. If d.parent is nil, ParentOrSelf returns d. +func genericParentOrSelf(d *dentry) *dentry { + if d.parent != nil { + return d.parent + } + return d +} + +// PrependPath is a generic implementation of FilesystemImpl.PrependPath(). +func genericPrependPath(vfsroot vfs.VirtualDentry, mnt *vfs.Mount, d *dentry, b *fspath.Builder) error { + for { + if mnt == vfsroot.Mount() && &d.vfsd == vfsroot.Dentry() { + return vfs.PrependPathAtVFSRootError{} + } + if mnt != nil && &d.vfsd == mnt.Root() { + return nil + } + if d.parent == nil { + return vfs.PrependPathAtNonMountRootError{} + } + b.PrependComponent(d.name) + d = d.parent + } +} + +// DebugPathname returns a pathname to d relative to its filesystem root. +// DebugPathname does not correspond to any Linux function; it's used to +// generate dentry pathnames for debugging. +func genericDebugPathname(d *dentry) string { + var b fspath.Builder + _ = genericPrependPath(vfs.VirtualDentry{}, nil, d, &b) + return b.String() +} diff --git a/pkg/sentry/fsimpl/tmpfs/inode_refs.go b/pkg/sentry/fsimpl/tmpfs/inode_refs.go new file mode 100644 index 000000000..f0f032e0c --- /dev/null +++ b/pkg/sentry/fsimpl/tmpfs/inode_refs.go @@ -0,0 +1,140 @@ +package tmpfs + +import ( + "fmt" + "sync/atomic" + + "gvisor.dev/gvisor/pkg/refsvfs2" +) + +// enableLogging indicates whether reference-related events should be logged (with +// stack traces). This is false by default and should only be set to true for +// debugging purposes, as it can generate an extremely large amount of output +// and drastically degrade performance. +const inodeenableLogging = false + +// obj is used to customize logging. Note that we use a pointer to T so that +// we do not copy the entire object when passed as a format parameter. +var inodeobj *inode + +// Refs implements refs.RefCounter. It keeps a reference count using atomic +// operations and calls the destructor when the count reaches zero. +// +// NOTE: Do not introduce additional fields to the Refs struct. It is used by +// many filesystem objects, and we want to keep it as small as possible (i.e., +// the same size as using an int64 directly) to avoid taking up extra cache +// space. In general, this template should not be extended at the cost of +// performance. If it does not offer enough flexibility for a particular object +// (example: b/187877947), we should implement the RefCounter/CheckedObject +// interfaces manually. +// +// +stateify savable +type inodeRefs struct { + // refCount is composed of two fields: + // + // [32-bit speculative references]:[32-bit real references] + // + // Speculative references are used for TryIncRef, to avoid a CompareAndSwap + // loop. See IncRef, DecRef and TryIncRef for details of how these fields are + // used. + refCount int64 +} + +// InitRefs initializes r with one reference and, if enabled, activates leak +// checking. +func (r *inodeRefs) InitRefs() { + atomic.StoreInt64(&r.refCount, 1) + refsvfs2.Register(r) +} + +// RefType implements refsvfs2.CheckedObject.RefType. +func (r *inodeRefs) RefType() string { + return fmt.Sprintf("%T", inodeobj)[1:] +} + +// LeakMessage implements refsvfs2.CheckedObject.LeakMessage. +func (r *inodeRefs) LeakMessage() string { + return fmt.Sprintf("[%s %p] reference count of %d instead of 0", r.RefType(), r, r.ReadRefs()) +} + +// LogRefs implements refsvfs2.CheckedObject.LogRefs. +func (r *inodeRefs) LogRefs() bool { + return inodeenableLogging +} + +// ReadRefs returns the current number of references. The returned count is +// inherently racy and is unsafe to use without external synchronization. +func (r *inodeRefs) ReadRefs() int64 { + return atomic.LoadInt64(&r.refCount) +} + +// IncRef implements refs.RefCounter.IncRef. +// +//go:nosplit +func (r *inodeRefs) IncRef() { + v := atomic.AddInt64(&r.refCount, 1) + if inodeenableLogging { + refsvfs2.LogIncRef(r, v) + } + if v <= 1 { + panic(fmt.Sprintf("Incrementing non-positive count %p on %s", r, r.RefType())) + } +} + +// TryIncRef implements refs.TryRefCounter.TryIncRef. +// +// To do this safely without a loop, a speculative reference is first acquired +// on the object. This allows multiple concurrent TryIncRef calls to distinguish +// other TryIncRef calls from genuine references held. +// +//go:nosplit +func (r *inodeRefs) TryIncRef() bool { + const speculativeRef = 1 << 32 + if v := atomic.AddInt64(&r.refCount, speculativeRef); int32(v) == 0 { + + atomic.AddInt64(&r.refCount, -speculativeRef) + return false + } + + v := atomic.AddInt64(&r.refCount, -speculativeRef+1) + if inodeenableLogging { + refsvfs2.LogTryIncRef(r, v) + } + return true +} + +// DecRef implements refs.RefCounter.DecRef. +// +// Note that speculative references are counted here. Since they were added +// prior to real references reaching zero, they will successfully convert to +// real references. In other words, we see speculative references only in the +// following case: +// +// A: TryIncRef [speculative increase => sees non-negative references] +// B: DecRef [real decrease] +// A: TryIncRef [transform speculative to real] +// +//go:nosplit +func (r *inodeRefs) DecRef(destroy func()) { + v := atomic.AddInt64(&r.refCount, -1) + if inodeenableLogging { + refsvfs2.LogDecRef(r, v) + } + switch { + case v < 0: + panic(fmt.Sprintf("Decrementing non-positive ref count %p, owned by %s", r, r.RefType())) + + case v == 0: + refsvfs2.Unregister(r) + + if destroy != nil { + destroy() + } + } +} + +func (r *inodeRefs) afterLoad() { + if r.ReadRefs() > 0 { + refsvfs2.Register(r) + } +} diff --git a/pkg/sentry/fsimpl/tmpfs/pipe_test.go b/pkg/sentry/fsimpl/tmpfs/pipe_test.go deleted file mode 100644 index 99afd9817..000000000 --- a/pkg/sentry/fsimpl/tmpfs/pipe_test.go +++ /dev/null @@ -1,239 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tmpfs - -import ( - "bytes" - "testing" - - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/errors/linuxerr" - "gvisor.dev/gvisor/pkg/fspath" - "gvisor.dev/gvisor/pkg/sentry/contexttest" - "gvisor.dev/gvisor/pkg/sentry/kernel/auth" - "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/usermem" -) - -const fileName = "mypipe" - -func TestSeparateFDs(t *testing.T) { - ctx, creds, vfsObj, root := setup(t) - defer root.DecRef(ctx) - - // Open the read side. This is done in a concurrently because opening - // One end the pipe blocks until the other end is opened. - pop := vfs.PathOperation{ - Root: root, - Start: root, - Path: fspath.Parse(fileName), - FollowFinalSymlink: true, - } - rfdchan := make(chan *vfs.FileDescription) - go func() { - openOpts := vfs.OpenOptions{Flags: linux.O_RDONLY} - rfd, _ := vfsObj.OpenAt(ctx, creds, &pop, &openOpts) - rfdchan <- rfd - }() - - // Open the write side. - openOpts := vfs.OpenOptions{Flags: linux.O_WRONLY} - wfd, err := vfsObj.OpenAt(ctx, creds, &pop, &openOpts) - if err != nil { - t.Fatalf("failed to open pipe for writing %q: %v", fileName, err) - } - defer wfd.DecRef(ctx) - - rfd, ok := <-rfdchan - if !ok { - t.Fatalf("failed to open pipe for reading %q", fileName) - } - defer rfd.DecRef(ctx) - - const msg = "vamos azul" - checkEmpty(ctx, t, rfd) - checkWrite(ctx, t, wfd, msg) - checkRead(ctx, t, rfd, msg) -} - -func TestNonblockingRead(t *testing.T) { - ctx, creds, vfsObj, root := setup(t) - defer root.DecRef(ctx) - - // Open the read side as nonblocking. - pop := vfs.PathOperation{ - Root: root, - Start: root, - Path: fspath.Parse(fileName), - FollowFinalSymlink: true, - } - openOpts := vfs.OpenOptions{Flags: linux.O_RDONLY | linux.O_NONBLOCK} - rfd, err := vfsObj.OpenAt(ctx, creds, &pop, &openOpts) - if err != nil { - t.Fatalf("failed to open pipe for reading %q: %v", fileName, err) - } - defer rfd.DecRef(ctx) - - // Open the write side. - openOpts = vfs.OpenOptions{Flags: linux.O_WRONLY} - wfd, err := vfsObj.OpenAt(ctx, creds, &pop, &openOpts) - if err != nil { - t.Fatalf("failed to open pipe for writing %q: %v", fileName, err) - } - defer wfd.DecRef(ctx) - - const msg = "geh blau" - checkEmpty(ctx, t, rfd) - checkWrite(ctx, t, wfd, msg) - checkRead(ctx, t, rfd, msg) -} - -func TestNonblockingWriteError(t *testing.T) { - ctx, creds, vfsObj, root := setup(t) - defer root.DecRef(ctx) - - // Open the write side as nonblocking, which should return ENXIO. - pop := vfs.PathOperation{ - Root: root, - Start: root, - Path: fspath.Parse(fileName), - FollowFinalSymlink: true, - } - openOpts := vfs.OpenOptions{Flags: linux.O_WRONLY | linux.O_NONBLOCK} - _, err := vfsObj.OpenAt(ctx, creds, &pop, &openOpts) - if !linuxerr.Equals(linuxerr.ENXIO, err) { - t.Fatalf("expected ENXIO, but got error: %v", err) - } -} - -func TestSingleFD(t *testing.T) { - ctx, creds, vfsObj, root := setup(t) - defer root.DecRef(ctx) - - // Open the pipe as readable and writable. - pop := vfs.PathOperation{ - Root: root, - Start: root, - Path: fspath.Parse(fileName), - FollowFinalSymlink: true, - } - openOpts := vfs.OpenOptions{Flags: linux.O_RDWR} - fd, err := vfsObj.OpenAt(ctx, creds, &pop, &openOpts) - if err != nil { - t.Fatalf("failed to open pipe for writing %q: %v", fileName, err) - } - defer fd.DecRef(ctx) - - const msg = "forza blu" - checkEmpty(ctx, t, fd) - checkWrite(ctx, t, fd, msg) - checkRead(ctx, t, fd, msg) -} - -// setup creates a VFS with a pipe in the root directory at path fileName. The -// returned VirtualDentry must be DecRef()'d be the caller. It calls t.Fatal -// upon failure. -func setup(t *testing.T) (context.Context, *auth.Credentials, *vfs.VirtualFilesystem, vfs.VirtualDentry) { - ctx := contexttest.Context(t) - creds := auth.CredentialsFromContext(ctx) - - // Create VFS. - vfsObj := &vfs.VirtualFilesystem{} - if err := vfsObj.Init(ctx); err != nil { - t.Fatalf("VFS init: %v", err) - } - vfsObj.MustRegisterFilesystemType("tmpfs", FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{ - AllowUserMount: true, - }) - mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "tmpfs", &vfs.MountOptions{}) - if err != nil { - t.Fatalf("failed to create tmpfs root mount: %v", err) - } - - // Create the pipe. - root := mntns.Root() - root.IncRef() - pop := vfs.PathOperation{ - Root: root, - Start: root, - Path: fspath.Parse(fileName), - } - mknodOpts := vfs.MknodOptions{Mode: linux.ModeNamedPipe | 0644} - if err := vfsObj.MknodAt(ctx, creds, &pop, &mknodOpts); err != nil { - t.Fatalf("failed to create file %q: %v", fileName, err) - } - - // Sanity check: the file pipe exists and has the correct mode. - stat, err := vfsObj.StatAt(ctx, creds, &vfs.PathOperation{ - Root: root, - Start: root, - Path: fspath.Parse(fileName), - FollowFinalSymlink: true, - }, &vfs.StatOptions{}) - if err != nil { - t.Fatalf("stat(%q) failed: %v", fileName, err) - } - if stat.Mode&^linux.S_IFMT != 0644 { - t.Errorf("got wrong permissions (%0o)", stat.Mode) - } - if stat.Mode&linux.S_IFMT != linux.ModeNamedPipe { - t.Errorf("got wrong file type (%0o)", stat.Mode) - } - - return ctx, creds, vfsObj, root -} - -// checkEmpty calls t.Fatal if the pipe in fd is not empty. -func checkEmpty(ctx context.Context, t *testing.T, fd *vfs.FileDescription) { - readData := make([]byte, 1) - dst := usermem.BytesIOSequence(readData) - bytesRead, err := fd.Read(ctx, dst, vfs.ReadOptions{}) - if err != linuxerr.ErrWouldBlock { - t.Fatalf("expected ErrWouldBlock reading from empty pipe %q, but got: %v", fileName, err) - } - if bytesRead != 0 { - t.Fatalf("expected to read 0 bytes, but got %d", bytesRead) - } -} - -// checkWrite calls t.Fatal if it fails to write all of msg to fd. -func checkWrite(ctx context.Context, t *testing.T, fd *vfs.FileDescription, msg string) { - writeData := []byte(msg) - src := usermem.BytesIOSequence(writeData) - bytesWritten, err := fd.Write(ctx, src, vfs.WriteOptions{}) - if err != nil { - t.Fatalf("error writing to pipe %q: %v", fileName, err) - } - if bytesWritten != int64(len(writeData)) { - t.Fatalf("expected to write %d bytes, but wrote %d", len(writeData), bytesWritten) - } -} - -// checkRead calls t.Fatal if it fails to read msg from fd. -func checkRead(ctx context.Context, t *testing.T, fd *vfs.FileDescription, msg string) { - readData := make([]byte, len(msg)) - dst := usermem.BytesIOSequence(readData) - bytesRead, err := fd.Read(ctx, dst, vfs.ReadOptions{}) - if err != nil { - t.Fatalf("error reading from pipe %q: %v", fileName, err) - } - if bytesRead != int64(len(msg)) { - t.Fatalf("expected to read %d bytes, but got %d", len(msg), bytesRead) - } - if !bytes.Equal(readData, []byte(msg)) { - t.Fatalf("expected to read %q from pipe, but got %q", msg, string(readData)) - } -} diff --git a/pkg/sentry/fsimpl/tmpfs/regular_file_test.go b/pkg/sentry/fsimpl/tmpfs/regular_file_test.go deleted file mode 100644 index cb7711b39..000000000 --- a/pkg/sentry/fsimpl/tmpfs/regular_file_test.go +++ /dev/null @@ -1,349 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tmpfs - -import ( - "bytes" - "fmt" - "io" - "testing" - - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/errors/linuxerr" - "gvisor.dev/gvisor/pkg/sentry/contexttest" - "gvisor.dev/gvisor/pkg/sentry/fs/lock" - "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/usermem" -) - -// Test that we can write some data to a file and read it back.` -func TestSimpleWriteRead(t *testing.T) { - ctx := contexttest.Context(t) - fd, cleanup, err := newFileFD(ctx, 0644) - if err != nil { - t.Fatal(err) - } - defer cleanup() - - // Write. - data := []byte("foobarbaz") - n, err := fd.Write(ctx, usermem.BytesIOSequence(data), vfs.WriteOptions{}) - if err != nil { - t.Fatalf("fd.Write failed: %v", err) - } - if n != int64(len(data)) { - t.Errorf("fd.Write got short write length %d, want %d", n, len(data)) - } - if got, want := fd.Impl().(*regularFileFD).off, int64(len(data)); got != want { - t.Errorf("fd.Write left offset at %d, want %d", got, want) - } - - // Seek back to beginning. - if _, err := fd.Seek(ctx, 0, linux.SEEK_SET); err != nil { - t.Fatalf("fd.Seek failed: %v", err) - } - if got, want := fd.Impl().(*regularFileFD).off, int64(0); got != want { - t.Errorf("fd.Seek(0) left offset at %d, want %d", got, want) - } - - // Read. - buf := make([]byte, len(data)) - n, err = fd.Read(ctx, usermem.BytesIOSequence(buf), vfs.ReadOptions{}) - if err != nil && err != io.EOF { - t.Fatalf("fd.Read failed: %v", err) - } - if n != int64(len(data)) { - t.Errorf("fd.Read got short read length %d, want %d", n, len(data)) - } - if got, want := string(buf), string(data); got != want { - t.Errorf("Read got %q want %s", got, want) - } - if got, want := fd.Impl().(*regularFileFD).off, int64(len(data)); got != want { - t.Errorf("fd.Write left offset at %d, want %d", got, want) - } -} - -func TestPWrite(t *testing.T) { - ctx := contexttest.Context(t) - fd, cleanup, err := newFileFD(ctx, 0644) - if err != nil { - t.Fatal(err) - } - defer cleanup() - - // Fill file with 1k 'a's. - data := bytes.Repeat([]byte{'a'}, 1000) - n, err := fd.Write(ctx, usermem.BytesIOSequence(data), vfs.WriteOptions{}) - if err != nil { - t.Fatalf("fd.Write failed: %v", err) - } - if n != int64(len(data)) { - t.Errorf("fd.Write got short write length %d, want %d", n, len(data)) - } - - // Write "gVisor is awesome" at various offsets. - buf := []byte("gVisor is awesome") - offsets := []int{0, 1, 2, 10, 20, 50, 100, len(data) - 100, len(data) - 1, len(data), len(data) + 1} - for _, offset := range offsets { - name := fmt.Sprintf("PWrite offset=%d", offset) - t.Run(name, func(t *testing.T) { - n, err := fd.PWrite(ctx, usermem.BytesIOSequence(buf), int64(offset), vfs.WriteOptions{}) - if err != nil { - t.Errorf("fd.PWrite got err %v want nil", err) - } - if n != int64(len(buf)) { - t.Errorf("fd.PWrite got %d bytes want %d", n, len(buf)) - } - - // Update data to reflect expected file contents. - if len(data) < offset+len(buf) { - data = append(data, make([]byte, (offset+len(buf))-len(data))...) - } - copy(data[offset:], buf) - - // Read the whole file and compare with data. - readBuf := make([]byte, len(data)) - n, err = fd.PRead(ctx, usermem.BytesIOSequence(readBuf), 0, vfs.ReadOptions{}) - if err != nil { - t.Fatalf("fd.PRead failed: %v", err) - } - if n != int64(len(data)) { - t.Errorf("fd.PRead got short read length %d, want %d", n, len(data)) - } - if got, want := string(readBuf), string(data); got != want { - t.Errorf("PRead got %q want %s", got, want) - } - - }) - } -} - -func TestLocks(t *testing.T) { - ctx := contexttest.Context(t) - fd, cleanup, err := newFileFD(ctx, 0644) - if err != nil { - t.Fatal(err) - } - defer cleanup() - - uid1 := 123 - uid2 := 456 - if err := fd.Impl().LockBSD(ctx, uid1, 0 /* ownerPID */, lock.ReadLock, nil); err != nil { - t.Fatalf("fd.Impl().LockBSD failed: err = %v", err) - } - if err := fd.Impl().LockBSD(ctx, uid2, 0 /* ownerPID */, lock.ReadLock, nil); err != nil { - t.Fatalf("fd.Impl().LockBSD failed: err = %v", err) - } - if got, want := fd.Impl().LockBSD(ctx, uid2, 0 /* ownerPID */, lock.WriteLock, nil), linuxerr.ErrWouldBlock; got != want { - t.Fatalf("fd.Impl().LockBSD failed: got = %v, want = %v", got, want) - } - if err := fd.Impl().UnlockBSD(ctx, uid1); err != nil { - t.Fatalf("fd.Impl().UnlockBSD failed: err = %v", err) - } - if err := fd.Impl().LockBSD(ctx, uid2, 0 /* ownerPID */, lock.WriteLock, nil); err != nil { - t.Fatalf("fd.Impl().LockBSD failed: err = %v", err) - } - - if err := fd.Impl().LockPOSIX(ctx, uid1, 0 /* ownerPID */, lock.ReadLock, lock.LockRange{Start: 0, End: 1}, nil); err != nil { - t.Fatalf("fd.Impl().LockPOSIX failed: err = %v", err) - } - if err := fd.Impl().LockPOSIX(ctx, uid2, 0 /* ownerPID */, lock.ReadLock, lock.LockRange{Start: 1, End: 2}, nil); err != nil { - t.Fatalf("fd.Impl().LockPOSIX failed: err = %v", err) - } - if err := fd.Impl().LockPOSIX(ctx, uid1, 0 /* ownerPID */, lock.WriteLock, lock.LockRange{Start: 0, End: 1}, nil); err != nil { - t.Fatalf("fd.Impl().LockPOSIX failed: err = %v", err) - } - if got, want := fd.Impl().LockPOSIX(ctx, uid2, 0 /* ownerPID */, lock.ReadLock, lock.LockRange{Start: 0, End: 1}, nil), linuxerr.ErrWouldBlock; got != want { - t.Fatalf("fd.Impl().LockPOSIX failed: got = %v, want = %v", got, want) - } - if err := fd.Impl().UnlockPOSIX(ctx, uid1, lock.LockRange{Start: 0, End: 1}); err != nil { - t.Fatalf("fd.Impl().UnlockPOSIX failed: err = %v", err) - } -} - -func TestPRead(t *testing.T) { - ctx := contexttest.Context(t) - fd, cleanup, err := newFileFD(ctx, 0644) - if err != nil { - t.Fatal(err) - } - defer cleanup() - - // Write 100 sequences of 'gVisor is awesome'. - data := bytes.Repeat([]byte("gVisor is awsome"), 100) - n, err := fd.Write(ctx, usermem.BytesIOSequence(data), vfs.WriteOptions{}) - if err != nil { - t.Fatalf("fd.Write failed: %v", err) - } - if n != int64(len(data)) { - t.Errorf("fd.Write got short write length %d, want %d", n, len(data)) - } - - // Read various sizes from various offsets. - sizes := []int{0, 1, 2, 10, 20, 50, 100, 1000} - offsets := []int{0, 1, 2, 10, 20, 50, 100, 1000, len(data) - 100, len(data) - 1, len(data), len(data) + 1} - - for _, size := range sizes { - for _, offset := range offsets { - name := fmt.Sprintf("PRead offset=%d size=%d", offset, size) - t.Run(name, func(t *testing.T) { - var ( - wantRead []byte - wantErr error - ) - if offset < len(data) { - wantRead = data[offset:] - } else if size > 0 { - wantErr = io.EOF - } - if offset+size < len(data) { - wantRead = wantRead[:size] - } - buf := make([]byte, size) - n, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), int64(offset), vfs.ReadOptions{}) - if err != wantErr { - t.Errorf("fd.PRead got err %v want %v", err, wantErr) - } - if n != int64(len(wantRead)) { - t.Errorf("fd.PRead got %d bytes want %d", n, len(wantRead)) - } - if got := string(buf[:n]); got != string(wantRead) { - t.Errorf("fd.PRead got %q want %q", got, string(wantRead)) - } - }) - } - } -} - -func TestTruncate(t *testing.T) { - ctx := contexttest.Context(t) - fd, cleanup, err := newFileFD(ctx, 0644) - if err != nil { - t.Fatal(err) - } - defer cleanup() - - // Fill the file with some data. - data := bytes.Repeat([]byte("gVisor is awsome"), 100) - written, err := fd.Write(ctx, usermem.BytesIOSequence(data), vfs.WriteOptions{}) - if err != nil { - t.Fatalf("fd.Write failed: %v", err) - } - - // Size should be same as written. - sizeStatOpts := vfs.StatOptions{Mask: linux.STATX_SIZE} - stat, err := fd.Stat(ctx, sizeStatOpts) - if err != nil { - t.Fatalf("fd.Stat failed: %v", err) - } - if got, want := int64(stat.Size), written; got != want { - t.Errorf("fd.Stat got size %d, want %d", got, want) - } - - // Truncate down. - newSize := uint64(10) - if err := fd.SetStat(ctx, vfs.SetStatOptions{ - Stat: linux.Statx{ - Mask: linux.STATX_SIZE, - Size: newSize, - }, - }); err != nil { - t.Errorf("fd.Truncate failed: %v", err) - } - // Size should be updated. - statAfterTruncateDown, err := fd.Stat(ctx, sizeStatOpts) - if err != nil { - t.Fatalf("fd.Stat failed: %v", err) - } - if got, want := statAfterTruncateDown.Size, newSize; got != want { - t.Errorf("fd.Stat got size %d, want %d", got, want) - } - // We should only read newSize worth of data. - buf := make([]byte, 1000) - if n, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), 0, vfs.ReadOptions{}); err != nil && err != io.EOF { - t.Fatalf("fd.PRead failed: %v", err) - } else if uint64(n) != newSize { - t.Errorf("fd.PRead got size %d, want %d", n, newSize) - } - // Mtime and Ctime should be bumped. - if got := statAfterTruncateDown.Mtime.ToNsec(); got <= stat.Mtime.ToNsec() { - t.Errorf("fd.Stat got Mtime %v, want > %v", got, stat.Mtime) - } - if got := statAfterTruncateDown.Ctime.ToNsec(); got <= stat.Ctime.ToNsec() { - t.Errorf("fd.Stat got Ctime %v, want > %v", got, stat.Ctime) - } - - // Truncate up. - newSize = 100 - if err := fd.SetStat(ctx, vfs.SetStatOptions{ - Stat: linux.Statx{ - Mask: linux.STATX_SIZE, - Size: newSize, - }, - }); err != nil { - t.Errorf("fd.Truncate failed: %v", err) - } - // Size should be updated. - statAfterTruncateUp, err := fd.Stat(ctx, sizeStatOpts) - if err != nil { - t.Fatalf("fd.Stat failed: %v", err) - } - if got, want := statAfterTruncateUp.Size, newSize; got != want { - t.Errorf("fd.Stat got size %d, want %d", got, want) - } - // We should read newSize worth of data. - buf = make([]byte, 1000) - if n, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), 0, vfs.ReadOptions{}); err != nil && err != io.EOF { - t.Fatalf("fd.PRead failed: %v", err) - } else if uint64(n) != newSize { - t.Errorf("fd.PRead got size %d, want %d", n, newSize) - } - // Bytes should be null after 10, since we previously truncated to 10. - for i := uint64(10); i < newSize; i++ { - if buf[i] != 0 { - t.Errorf("fd.PRead got byte %d=%x, want 0", i, buf[i]) - break - } - } - // Mtime and Ctime should be bumped. - if got := statAfterTruncateUp.Mtime.ToNsec(); got <= statAfterTruncateDown.Mtime.ToNsec() { - t.Errorf("fd.Stat got Mtime %v, want > %v", got, statAfterTruncateDown.Mtime) - } - if got := statAfterTruncateUp.Ctime.ToNsec(); got <= statAfterTruncateDown.Ctime.ToNsec() { - t.Errorf("fd.Stat got Ctime %v, want > %v", got, stat.Ctime) - } - - // Truncate to the current size. - newSize = statAfterTruncateUp.Size - if err := fd.SetStat(ctx, vfs.SetStatOptions{ - Stat: linux.Statx{ - Mask: linux.STATX_SIZE, - Size: newSize, - }, - }); err != nil { - t.Errorf("fd.Truncate failed: %v", err) - } - statAfterTruncateNoop, err := fd.Stat(ctx, sizeStatOpts) - if err != nil { - t.Fatalf("fd.Stat failed: %v", err) - } - // Mtime and Ctime should not be bumped, since operation is a noop. - if got := statAfterTruncateNoop.Mtime.ToNsec(); got != statAfterTruncateUp.Mtime.ToNsec() { - t.Errorf("fd.Stat got Mtime %v, want %v", got, statAfterTruncateUp.Mtime) - } - if got := statAfterTruncateNoop.Ctime.ToNsec(); got != statAfterTruncateUp.Ctime.ToNsec() { - t.Errorf("fd.Stat got Ctime %v, want %v", got, statAfterTruncateUp.Ctime) - } -} diff --git a/pkg/sentry/fsimpl/tmpfs/stat_test.go b/pkg/sentry/fsimpl/tmpfs/stat_test.go deleted file mode 100644 index f7ee4aab2..000000000 --- a/pkg/sentry/fsimpl/tmpfs/stat_test.go +++ /dev/null @@ -1,236 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tmpfs - -import ( - "fmt" - "testing" - - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/contexttest" - "gvisor.dev/gvisor/pkg/sentry/kernel/auth" - "gvisor.dev/gvisor/pkg/sentry/vfs" -) - -func TestStatAfterCreate(t *testing.T) { - ctx := contexttest.Context(t) - mode := linux.FileMode(0644) - - // Run with different file types. - for _, typ := range []string{"file", "dir", "pipe"} { - t.Run(fmt.Sprintf("type=%q", typ), func(t *testing.T) { - var ( - fd *vfs.FileDescription - cleanup func() - err error - ) - switch typ { - case "file": - fd, cleanup, err = newFileFD(ctx, mode) - case "dir": - fd, cleanup, err = newDirFD(ctx, mode) - case "pipe": - fd, cleanup, err = newPipeFD(ctx, mode) - default: - panic(fmt.Sprintf("unknown typ %q", typ)) - } - if err != nil { - t.Fatal(err) - } - defer cleanup() - - got, err := fd.Stat(ctx, vfs.StatOptions{}) - if err != nil { - t.Fatalf("Stat failed: %v", err) - } - - // Atime, Ctime, Mtime should all be current time (non-zero). - atime, ctime, mtime := got.Atime.ToNsec(), got.Ctime.ToNsec(), got.Mtime.ToNsec() - if atime != ctime || ctime != mtime { - t.Errorf("got atime=%d ctime=%d mtime=%d, wanted equal values", atime, ctime, mtime) - } - if atime == 0 { - t.Errorf("got atime=%d, want non-zero", atime) - } - - // Btime should be 0, as it is not set by tmpfs. - if btime := got.Btime.ToNsec(); btime != 0 { - t.Errorf("got btime %d, want 0", got.Btime.ToNsec()) - } - - // Size should be 0 (except for directories, which make up a size - // of 20 per entry, including the "." and ".." entries present in - // otherwise-empty directories). - wantSize := uint64(0) - if typ == "dir" { - wantSize = 40 - } - if got.Size != wantSize { - t.Errorf("got size %d, want %d", got.Size, wantSize) - } - - // Nlink should be 1 for files, 2 for dirs. - wantNlink := uint32(1) - if typ == "dir" { - wantNlink = 2 - } - if got.Nlink != wantNlink { - t.Errorf("got nlink %d, want %d", got.Nlink, wantNlink) - } - - // UID and GID are set from context creds. - creds := auth.CredentialsFromContext(ctx) - if got.UID != uint32(creds.EffectiveKUID) { - t.Errorf("got uid %d, want %d", got.UID, uint32(creds.EffectiveKUID)) - } - if got.GID != uint32(creds.EffectiveKGID) { - t.Errorf("got gid %d, want %d", got.GID, uint32(creds.EffectiveKGID)) - } - - // Mode. - wantMode := uint16(mode) - switch typ { - case "file": - wantMode |= linux.S_IFREG - case "dir": - wantMode |= linux.S_IFDIR - case "pipe": - wantMode |= linux.S_IFIFO - default: - panic(fmt.Sprintf("unknown typ %q", typ)) - } - - if got.Mode != wantMode { - t.Errorf("got mode %x, want %x", got.Mode, wantMode) - } - - // Ino. - if got.Ino == 0 { - t.Errorf("got ino %d, want not 0", got.Ino) - } - }) - } -} - -func TestSetStatAtime(t *testing.T) { - ctx := contexttest.Context(t) - fd, cleanup, err := newFileFD(ctx, 0644) - if err != nil { - t.Fatal(err) - } - defer cleanup() - - allStatOptions := vfs.StatOptions{Mask: linux.STATX_ALL} - - // Get initial stat. - initialStat, err := fd.Stat(ctx, allStatOptions) - if err != nil { - t.Fatalf("Stat failed: %v", err) - } - - // Set atime, but without the mask. - if err := fd.SetStat(ctx, vfs.SetStatOptions{Stat: linux.Statx{ - Mask: 0, - Atime: linux.NsecToStatxTimestamp(100), - }}); err != nil { - t.Errorf("SetStat atime without mask failed: %v", err) - } - // Atime should be unchanged. - if gotStat, err := fd.Stat(ctx, allStatOptions); err != nil { - t.Errorf("Stat got error: %v", err) - } else if gotStat.Atime != initialStat.Atime { - t.Errorf("Stat got atime %d, want %d", gotStat.Atime, initialStat.Atime) - } - - // Set atime, this time included in the mask. - setStat := linux.Statx{ - Mask: linux.STATX_ATIME, - Atime: linux.NsecToStatxTimestamp(100), - } - if err := fd.SetStat(ctx, vfs.SetStatOptions{Stat: setStat}); err != nil { - t.Errorf("SetStat atime with mask failed: %v", err) - } - if gotStat, err := fd.Stat(ctx, allStatOptions); err != nil { - t.Errorf("Stat got error: %v", err) - } else if gotStat.Atime != setStat.Atime { - t.Errorf("Stat got atime %d, want %d", gotStat.Atime, setStat.Atime) - } -} - -func TestSetStat(t *testing.T) { - ctx := contexttest.Context(t) - mode := linux.FileMode(0644) - - // Run with different file types. - for _, typ := range []string{"file", "dir", "pipe"} { - t.Run(fmt.Sprintf("type=%q", typ), func(t *testing.T) { - var ( - fd *vfs.FileDescription - cleanup func() - err error - ) - switch typ { - case "file": - fd, cleanup, err = newFileFD(ctx, mode) - case "dir": - fd, cleanup, err = newDirFD(ctx, mode) - case "pipe": - fd, cleanup, err = newPipeFD(ctx, mode) - default: - panic(fmt.Sprintf("unknown typ %q", typ)) - } - if err != nil { - t.Fatal(err) - } - defer cleanup() - - allStatOptions := vfs.StatOptions{Mask: linux.STATX_ALL} - - // Get initial stat. - initialStat, err := fd.Stat(ctx, allStatOptions) - if err != nil { - t.Fatalf("Stat failed: %v", err) - } - - // Set atime, but without the mask. - if err := fd.SetStat(ctx, vfs.SetStatOptions{Stat: linux.Statx{ - Mask: 0, - Atime: linux.NsecToStatxTimestamp(100), - }}); err != nil { - t.Errorf("SetStat atime without mask failed: %v", err) - } - // Atime should be unchanged. - if gotStat, err := fd.Stat(ctx, allStatOptions); err != nil { - t.Errorf("Stat got error: %v", err) - } else if gotStat.Atime != initialStat.Atime { - t.Errorf("Stat got atime %d, want %d", gotStat.Atime, initialStat.Atime) - } - - // Set atime, this time included in the mask. - setStat := linux.Statx{ - Mask: linux.STATX_ATIME, - Atime: linux.NsecToStatxTimestamp(100), - } - if err := fd.SetStat(ctx, vfs.SetStatOptions{Stat: setStat}); err != nil { - t.Errorf("SetStat atime with mask failed: %v", err) - } - if gotStat, err := fd.Stat(ctx, allStatOptions); err != nil { - t.Errorf("Stat got error: %v", err) - } else if gotStat.Atime != setStat.Atime { - t.Errorf("Stat got atime %d, want %d", gotStat.Atime, setStat.Atime) - } - }) - } -} diff --git a/pkg/sentry/fsimpl/tmpfs/tmpfs_state_autogen.go b/pkg/sentry/fsimpl/tmpfs/tmpfs_state_autogen.go new file mode 100644 index 000000000..6d9f09eef --- /dev/null +++ b/pkg/sentry/fsimpl/tmpfs/tmpfs_state_autogen.go @@ -0,0 +1,593 @@ +// automatically generated by stateify. + +package tmpfs + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (l *dentryList) StateTypeName() string { + return "pkg/sentry/fsimpl/tmpfs.dentryList" +} + +func (l *dentryList) StateFields() []string { + return []string{ + "head", + "tail", + } +} + +func (l *dentryList) beforeSave() {} + +// +checklocksignore +func (l *dentryList) StateSave(stateSinkObject state.Sink) { + l.beforeSave() + stateSinkObject.Save(0, &l.head) + stateSinkObject.Save(1, &l.tail) +} + +func (l *dentryList) afterLoad() {} + +// +checklocksignore +func (l *dentryList) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &l.head) + stateSourceObject.Load(1, &l.tail) +} + +func (e *dentryEntry) StateTypeName() string { + return "pkg/sentry/fsimpl/tmpfs.dentryEntry" +} + +func (e *dentryEntry) StateFields() []string { + return []string{ + "next", + "prev", + } +} + +func (e *dentryEntry) beforeSave() {} + +// +checklocksignore +func (e *dentryEntry) StateSave(stateSinkObject state.Sink) { + e.beforeSave() + stateSinkObject.Save(0, &e.next) + stateSinkObject.Save(1, &e.prev) +} + +func (e *dentryEntry) afterLoad() {} + +// +checklocksignore +func (e *dentryEntry) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &e.next) + stateSourceObject.Load(1, &e.prev) +} + +func (d *deviceFile) StateTypeName() string { + return "pkg/sentry/fsimpl/tmpfs.deviceFile" +} + +func (d *deviceFile) StateFields() []string { + return []string{ + "inode", + "kind", + "major", + "minor", + } +} + +func (d *deviceFile) beforeSave() {} + +// +checklocksignore +func (d *deviceFile) StateSave(stateSinkObject state.Sink) { + d.beforeSave() + stateSinkObject.Save(0, &d.inode) + stateSinkObject.Save(1, &d.kind) + stateSinkObject.Save(2, &d.major) + stateSinkObject.Save(3, &d.minor) +} + +func (d *deviceFile) afterLoad() {} + +// +checklocksignore +func (d *deviceFile) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &d.inode) + stateSourceObject.Load(1, &d.kind) + stateSourceObject.Load(2, &d.major) + stateSourceObject.Load(3, &d.minor) +} + +func (dir *directory) StateTypeName() string { + return "pkg/sentry/fsimpl/tmpfs.directory" +} + +func (dir *directory) StateFields() []string { + return []string{ + "dentry", + "inode", + "childMap", + "numChildren", + "childList", + } +} + +func (dir *directory) beforeSave() {} + +// +checklocksignore +func (dir *directory) StateSave(stateSinkObject state.Sink) { + dir.beforeSave() + stateSinkObject.Save(0, &dir.dentry) + stateSinkObject.Save(1, &dir.inode) + stateSinkObject.Save(2, &dir.childMap) + stateSinkObject.Save(3, &dir.numChildren) + stateSinkObject.Save(4, &dir.childList) +} + +func (dir *directory) afterLoad() {} + +// +checklocksignore +func (dir *directory) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &dir.dentry) + stateSourceObject.Load(1, &dir.inode) + stateSourceObject.Load(2, &dir.childMap) + stateSourceObject.Load(3, &dir.numChildren) + stateSourceObject.Load(4, &dir.childList) +} + +func (fd *directoryFD) StateTypeName() string { + return "pkg/sentry/fsimpl/tmpfs.directoryFD" +} + +func (fd *directoryFD) StateFields() []string { + return []string{ + "fileDescription", + "DirectoryFileDescriptionDefaultImpl", + "iter", + "off", + } +} + +func (fd *directoryFD) beforeSave() {} + +// +checklocksignore +func (fd *directoryFD) StateSave(stateSinkObject state.Sink) { + fd.beforeSave() + stateSinkObject.Save(0, &fd.fileDescription) + stateSinkObject.Save(1, &fd.DirectoryFileDescriptionDefaultImpl) + stateSinkObject.Save(2, &fd.iter) + stateSinkObject.Save(3, &fd.off) +} + +func (fd *directoryFD) afterLoad() {} + +// +checklocksignore +func (fd *directoryFD) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &fd.fileDescription) + stateSourceObject.Load(1, &fd.DirectoryFileDescriptionDefaultImpl) + stateSourceObject.Load(2, &fd.iter) + stateSourceObject.Load(3, &fd.off) +} + +func (r *inodeRefs) StateTypeName() string { + return "pkg/sentry/fsimpl/tmpfs.inodeRefs" +} + +func (r *inodeRefs) StateFields() []string { + return []string{ + "refCount", + } +} + +func (r *inodeRefs) beforeSave() {} + +// +checklocksignore +func (r *inodeRefs) StateSave(stateSinkObject state.Sink) { + r.beforeSave() + stateSinkObject.Save(0, &r.refCount) +} + +// +checklocksignore +func (r *inodeRefs) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &r.refCount) + stateSourceObject.AfterLoad(r.afterLoad) +} + +func (n *namedPipe) StateTypeName() string { + return "pkg/sentry/fsimpl/tmpfs.namedPipe" +} + +func (n *namedPipe) StateFields() []string { + return []string{ + "inode", + "pipe", + } +} + +func (n *namedPipe) beforeSave() {} + +// +checklocksignore +func (n *namedPipe) StateSave(stateSinkObject state.Sink) { + n.beforeSave() + stateSinkObject.Save(0, &n.inode) + stateSinkObject.Save(1, &n.pipe) +} + +func (n *namedPipe) afterLoad() {} + +// +checklocksignore +func (n *namedPipe) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &n.inode) + stateSourceObject.Load(1, &n.pipe) +} + +func (rf *regularFile) StateTypeName() string { + return "pkg/sentry/fsimpl/tmpfs.regularFile" +} + +func (rf *regularFile) StateFields() []string { + return []string{ + "inode", + "memoryUsageKind", + "mappings", + "writableMappingPages", + "data", + "seals", + "size", + } +} + +func (rf *regularFile) beforeSave() {} + +// +checklocksignore +func (rf *regularFile) StateSave(stateSinkObject state.Sink) { + rf.beforeSave() + stateSinkObject.Save(0, &rf.inode) + stateSinkObject.Save(1, &rf.memoryUsageKind) + stateSinkObject.Save(2, &rf.mappings) + stateSinkObject.Save(3, &rf.writableMappingPages) + stateSinkObject.Save(4, &rf.data) + stateSinkObject.Save(5, &rf.seals) + stateSinkObject.Save(6, &rf.size) +} + +// +checklocksignore +func (rf *regularFile) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &rf.inode) + stateSourceObject.Load(1, &rf.memoryUsageKind) + stateSourceObject.Load(2, &rf.mappings) + stateSourceObject.Load(3, &rf.writableMappingPages) + stateSourceObject.Load(4, &rf.data) + stateSourceObject.Load(5, &rf.seals) + stateSourceObject.Load(6, &rf.size) + stateSourceObject.AfterLoad(rf.afterLoad) +} + +func (fd *regularFileFD) StateTypeName() string { + return "pkg/sentry/fsimpl/tmpfs.regularFileFD" +} + +func (fd *regularFileFD) StateFields() []string { + return []string{ + "fileDescription", + "off", + } +} + +func (fd *regularFileFD) beforeSave() {} + +// +checklocksignore +func (fd *regularFileFD) StateSave(stateSinkObject state.Sink) { + fd.beforeSave() + stateSinkObject.Save(0, &fd.fileDescription) + stateSinkObject.Save(1, &fd.off) +} + +func (fd *regularFileFD) afterLoad() {} + +// +checklocksignore +func (fd *regularFileFD) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &fd.fileDescription) + stateSourceObject.Load(1, &fd.off) +} + +func (s *socketFile) StateTypeName() string { + return "pkg/sentry/fsimpl/tmpfs.socketFile" +} + +func (s *socketFile) StateFields() []string { + return []string{ + "inode", + "ep", + } +} + +func (s *socketFile) beforeSave() {} + +// +checklocksignore +func (s *socketFile) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + stateSinkObject.Save(0, &s.inode) + stateSinkObject.Save(1, &s.ep) +} + +func (s *socketFile) afterLoad() {} + +// +checklocksignore +func (s *socketFile) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.inode) + stateSourceObject.Load(1, &s.ep) +} + +func (s *symlink) StateTypeName() string { + return "pkg/sentry/fsimpl/tmpfs.symlink" +} + +func (s *symlink) StateFields() []string { + return []string{ + "inode", + "target", + } +} + +func (s *symlink) beforeSave() {} + +// +checklocksignore +func (s *symlink) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + stateSinkObject.Save(0, &s.inode) + stateSinkObject.Save(1, &s.target) +} + +func (s *symlink) afterLoad() {} + +// +checklocksignore +func (s *symlink) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.inode) + stateSourceObject.Load(1, &s.target) +} + +func (fstype *FilesystemType) StateTypeName() string { + return "pkg/sentry/fsimpl/tmpfs.FilesystemType" +} + +func (fstype *FilesystemType) StateFields() []string { + return []string{} +} + +func (fstype *FilesystemType) beforeSave() {} + +// +checklocksignore +func (fstype *FilesystemType) StateSave(stateSinkObject state.Sink) { + fstype.beforeSave() +} + +func (fstype *FilesystemType) afterLoad() {} + +// +checklocksignore +func (fstype *FilesystemType) StateLoad(stateSourceObject state.Source) { +} + +func (fs *filesystem) StateTypeName() string { + return "pkg/sentry/fsimpl/tmpfs.filesystem" +} + +func (fs *filesystem) StateFields() []string { + return []string{ + "vfsfs", + "mfp", + "clock", + "devMinor", + "mopts", + "nextInoMinusOne", + "root", + } +} + +func (fs *filesystem) beforeSave() {} + +// +checklocksignore +func (fs *filesystem) StateSave(stateSinkObject state.Sink) { + fs.beforeSave() + stateSinkObject.Save(0, &fs.vfsfs) + stateSinkObject.Save(1, &fs.mfp) + stateSinkObject.Save(2, &fs.clock) + stateSinkObject.Save(3, &fs.devMinor) + stateSinkObject.Save(4, &fs.mopts) + stateSinkObject.Save(5, &fs.nextInoMinusOne) + stateSinkObject.Save(6, &fs.root) +} + +func (fs *filesystem) afterLoad() {} + +// +checklocksignore +func (fs *filesystem) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &fs.vfsfs) + stateSourceObject.Load(1, &fs.mfp) + stateSourceObject.Load(2, &fs.clock) + stateSourceObject.Load(3, &fs.devMinor) + stateSourceObject.Load(4, &fs.mopts) + stateSourceObject.Load(5, &fs.nextInoMinusOne) + stateSourceObject.Load(6, &fs.root) +} + +func (f *FilesystemOpts) StateTypeName() string { + return "pkg/sentry/fsimpl/tmpfs.FilesystemOpts" +} + +func (f *FilesystemOpts) StateFields() []string { + return []string{ + "RootFileType", + "RootSymlinkTarget", + "FilesystemType", + } +} + +func (f *FilesystemOpts) beforeSave() {} + +// +checklocksignore +func (f *FilesystemOpts) StateSave(stateSinkObject state.Sink) { + f.beforeSave() + stateSinkObject.Save(0, &f.RootFileType) + stateSinkObject.Save(1, &f.RootSymlinkTarget) + stateSinkObject.Save(2, &f.FilesystemType) +} + +func (f *FilesystemOpts) afterLoad() {} + +// +checklocksignore +func (f *FilesystemOpts) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &f.RootFileType) + stateSourceObject.Load(1, &f.RootSymlinkTarget) + stateSourceObject.Load(2, &f.FilesystemType) +} + +func (d *dentry) StateTypeName() string { + return "pkg/sentry/fsimpl/tmpfs.dentry" +} + +func (d *dentry) StateFields() []string { + return []string{ + "vfsd", + "parent", + "name", + "dentryEntry", + "inode", + } +} + +func (d *dentry) beforeSave() {} + +// +checklocksignore +func (d *dentry) StateSave(stateSinkObject state.Sink) { + d.beforeSave() + stateSinkObject.Save(0, &d.vfsd) + stateSinkObject.Save(1, &d.parent) + stateSinkObject.Save(2, &d.name) + stateSinkObject.Save(3, &d.dentryEntry) + stateSinkObject.Save(4, &d.inode) +} + +func (d *dentry) afterLoad() {} + +// +checklocksignore +func (d *dentry) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &d.vfsd) + stateSourceObject.Load(1, &d.parent) + stateSourceObject.Load(2, &d.name) + stateSourceObject.Load(3, &d.dentryEntry) + stateSourceObject.Load(4, &d.inode) +} + +func (i *inode) StateTypeName() string { + return "pkg/sentry/fsimpl/tmpfs.inode" +} + +func (i *inode) StateFields() []string { + return []string{ + "fs", + "refs", + "xattrs", + "mode", + "nlink", + "uid", + "gid", + "ino", + "atime", + "ctime", + "mtime", + "locks", + "watches", + "impl", + } +} + +func (i *inode) beforeSave() {} + +// +checklocksignore +func (i *inode) StateSave(stateSinkObject state.Sink) { + i.beforeSave() + stateSinkObject.Save(0, &i.fs) + stateSinkObject.Save(1, &i.refs) + stateSinkObject.Save(2, &i.xattrs) + stateSinkObject.Save(3, &i.mode) + stateSinkObject.Save(4, &i.nlink) + stateSinkObject.Save(5, &i.uid) + stateSinkObject.Save(6, &i.gid) + stateSinkObject.Save(7, &i.ino) + stateSinkObject.Save(8, &i.atime) + stateSinkObject.Save(9, &i.ctime) + stateSinkObject.Save(10, &i.mtime) + stateSinkObject.Save(11, &i.locks) + stateSinkObject.Save(12, &i.watches) + stateSinkObject.Save(13, &i.impl) +} + +func (i *inode) afterLoad() {} + +// +checklocksignore +func (i *inode) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &i.fs) + stateSourceObject.Load(1, &i.refs) + stateSourceObject.Load(2, &i.xattrs) + stateSourceObject.Load(3, &i.mode) + stateSourceObject.Load(4, &i.nlink) + stateSourceObject.Load(5, &i.uid) + stateSourceObject.Load(6, &i.gid) + stateSourceObject.Load(7, &i.ino) + stateSourceObject.Load(8, &i.atime) + stateSourceObject.Load(9, &i.ctime) + stateSourceObject.Load(10, &i.mtime) + stateSourceObject.Load(11, &i.locks) + stateSourceObject.Load(12, &i.watches) + stateSourceObject.Load(13, &i.impl) +} + +func (fd *fileDescription) StateTypeName() string { + return "pkg/sentry/fsimpl/tmpfs.fileDescription" +} + +func (fd *fileDescription) StateFields() []string { + return []string{ + "vfsfd", + "FileDescriptionDefaultImpl", + "LockFD", + } +} + +func (fd *fileDescription) beforeSave() {} + +// +checklocksignore +func (fd *fileDescription) StateSave(stateSinkObject state.Sink) { + fd.beforeSave() + stateSinkObject.Save(0, &fd.vfsfd) + stateSinkObject.Save(1, &fd.FileDescriptionDefaultImpl) + stateSinkObject.Save(2, &fd.LockFD) +} + +func (fd *fileDescription) afterLoad() {} + +// +checklocksignore +func (fd *fileDescription) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &fd.vfsfd) + stateSourceObject.Load(1, &fd.FileDescriptionDefaultImpl) + stateSourceObject.Load(2, &fd.LockFD) +} + +func init() { + state.Register((*dentryList)(nil)) + state.Register((*dentryEntry)(nil)) + state.Register((*deviceFile)(nil)) + state.Register((*directory)(nil)) + state.Register((*directoryFD)(nil)) + state.Register((*inodeRefs)(nil)) + state.Register((*namedPipe)(nil)) + state.Register((*regularFile)(nil)) + state.Register((*regularFileFD)(nil)) + state.Register((*socketFile)(nil)) + state.Register((*symlink)(nil)) + state.Register((*FilesystemType)(nil)) + state.Register((*filesystem)(nil)) + state.Register((*FilesystemOpts)(nil)) + state.Register((*dentry)(nil)) + state.Register((*inode)(nil)) + state.Register((*fileDescription)(nil)) +} diff --git a/pkg/sentry/fsimpl/tmpfs/tmpfs_test.go b/pkg/sentry/fsimpl/tmpfs/tmpfs_test.go deleted file mode 100644 index fc5323abc..000000000 --- a/pkg/sentry/fsimpl/tmpfs/tmpfs_test.go +++ /dev/null @@ -1,157 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tmpfs - -import ( - "fmt" - "sync/atomic" - - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/fspath" - "gvisor.dev/gvisor/pkg/sentry/kernel/auth" - "gvisor.dev/gvisor/pkg/sentry/vfs" -) - -// nextFileID is used to generate unique file names. -var nextFileID int64 - -// newTmpfsRoot creates a new tmpfs mount, and returns the root. If the error -// is not nil, then cleanup should be called when the root is no longer needed. -func newTmpfsRoot(ctx context.Context) (*vfs.VirtualFilesystem, vfs.VirtualDentry, func(), error) { - creds := auth.CredentialsFromContext(ctx) - - vfsObj := &vfs.VirtualFilesystem{} - if err := vfsObj.Init(ctx); err != nil { - return nil, vfs.VirtualDentry{}, nil, fmt.Errorf("VFS init: %v", err) - } - - vfsObj.MustRegisterFilesystemType("tmpfs", FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{ - AllowUserMount: true, - }) - mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "tmpfs", &vfs.MountOptions{}) - if err != nil { - return nil, vfs.VirtualDentry{}, nil, fmt.Errorf("failed to create tmpfs root mount: %v", err) - } - root := mntns.Root() - root.IncRef() - return vfsObj, root, func() { - root.DecRef(ctx) - mntns.DecRef(ctx) - }, nil -} - -// newFileFD creates a new file in a new tmpfs mount, and returns the FD. If -// the returned err is not nil, then cleanup should be called when the FD is no -// longer needed. -func newFileFD(ctx context.Context, mode linux.FileMode) (*vfs.FileDescription, func(), error) { - creds := auth.CredentialsFromContext(ctx) - vfsObj, root, cleanup, err := newTmpfsRoot(ctx) - if err != nil { - return nil, nil, err - } - - filename := fmt.Sprintf("tmpfs-test-file-%d", atomic.AddInt64(&nextFileID, 1)) - - // Create the file that will be write/read. - fd, err := vfsObj.OpenAt(ctx, creds, &vfs.PathOperation{ - Root: root, - Start: root, - Path: fspath.Parse(filename), - }, &vfs.OpenOptions{ - Flags: linux.O_RDWR | linux.O_CREAT | linux.O_EXCL, - Mode: linux.ModeRegular | mode, - }) - if err != nil { - cleanup() - return nil, nil, fmt.Errorf("failed to create file %q: %v", filename, err) - } - - return fd, cleanup, nil -} - -// newDirFD is like newFileFD, but for directories. -func newDirFD(ctx context.Context, mode linux.FileMode) (*vfs.FileDescription, func(), error) { - creds := auth.CredentialsFromContext(ctx) - vfsObj, root, cleanup, err := newTmpfsRoot(ctx) - if err != nil { - return nil, nil, err - } - - dirname := fmt.Sprintf("tmpfs-test-dir-%d", atomic.AddInt64(&nextFileID, 1)) - - // Create the dir. - if err := vfsObj.MkdirAt(ctx, creds, &vfs.PathOperation{ - Root: root, - Start: root, - Path: fspath.Parse(dirname), - }, &vfs.MkdirOptions{ - Mode: linux.ModeDirectory | mode, - }); err != nil { - cleanup() - return nil, nil, fmt.Errorf("failed to create directory %q: %v", dirname, err) - } - - // Open the dir and return it. - fd, err := vfsObj.OpenAt(ctx, creds, &vfs.PathOperation{ - Root: root, - Start: root, - Path: fspath.Parse(dirname), - }, &vfs.OpenOptions{ - Flags: linux.O_RDONLY | linux.O_DIRECTORY, - }) - if err != nil { - cleanup() - return nil, nil, fmt.Errorf("failed to open directory %q: %v", dirname, err) - } - - return fd, cleanup, nil -} - -// newPipeFD is like newFileFD, but for pipes. -func newPipeFD(ctx context.Context, mode linux.FileMode) (*vfs.FileDescription, func(), error) { - creds := auth.CredentialsFromContext(ctx) - vfsObj, root, cleanup, err := newTmpfsRoot(ctx) - if err != nil { - return nil, nil, err - } - - name := fmt.Sprintf("tmpfs-test-%d", atomic.AddInt64(&nextFileID, 1)) - - if err := vfsObj.MknodAt(ctx, creds, &vfs.PathOperation{ - Root: root, - Start: root, - Path: fspath.Parse(name), - }, &vfs.MknodOptions{ - Mode: linux.ModeNamedPipe | mode, - }); err != nil { - cleanup() - return nil, nil, fmt.Errorf("failed to create pipe %q: %v", name, err) - } - - fd, err := vfsObj.OpenAt(ctx, creds, &vfs.PathOperation{ - Root: root, - Start: root, - Path: fspath.Parse(name), - }, &vfs.OpenOptions{ - Flags: linux.O_RDWR, - }) - if err != nil { - cleanup() - return nil, nil, fmt.Errorf("failed to open pipe %q: %v", name, err) - } - - return fd, cleanup, nil -} diff --git a/pkg/sentry/fsimpl/verity/BUILD b/pkg/sentry/fsimpl/verity/BUILD deleted file mode 100644 index 5955948f0..000000000 --- a/pkg/sentry/fsimpl/verity/BUILD +++ /dev/null @@ -1,54 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -licenses(["notice"]) - -go_library( - name = "verity", - srcs = [ - "filesystem.go", - "save_restore.go", - "verity.go", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/errors/linuxerr", - "//pkg/fspath", - "//pkg/hostarch", - "//pkg/marshal/primitive", - "//pkg/merkletree", - "//pkg/refsvfs2", - "//pkg/safemem", - "//pkg/sentry/arch", - "//pkg/sentry/fs/lock", - "//pkg/sentry/kernel", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/memmap", - "//pkg/sentry/socket/unix/transport", - "//pkg/sentry/vfs", - "//pkg/sync", - "//pkg/usermem", - ], -) - -go_test( - name = "verity_test", - srcs = [ - "verity_test.go", - ], - library = ":verity", - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/errors/linuxerr", - "//pkg/fspath", - "//pkg/sentry/arch", - "//pkg/sentry/fsimpl/testutil", - "//pkg/sentry/fsimpl/tmpfs", - "//pkg/sentry/kernel", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/vfs", - "//pkg/usermem", - ], -) diff --git a/pkg/sentry/fsimpl/verity/verity_state_autogen.go b/pkg/sentry/fsimpl/verity/verity_state_autogen.go new file mode 100644 index 000000000..dba01a68e --- /dev/null +++ b/pkg/sentry/fsimpl/verity/verity_state_autogen.go @@ -0,0 +1,240 @@ +// automatically generated by stateify. + +package verity + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (fstype *FilesystemType) StateTypeName() string { + return "pkg/sentry/fsimpl/verity.FilesystemType" +} + +func (fstype *FilesystemType) StateFields() []string { + return []string{} +} + +func (fstype *FilesystemType) beforeSave() {} + +// +checklocksignore +func (fstype *FilesystemType) StateSave(stateSinkObject state.Sink) { + fstype.beforeSave() +} + +func (fstype *FilesystemType) afterLoad() {} + +// +checklocksignore +func (fstype *FilesystemType) StateLoad(stateSourceObject state.Source) { +} + +func (fs *filesystem) StateTypeName() string { + return "pkg/sentry/fsimpl/verity.filesystem" +} + +func (fs *filesystem) StateFields() []string { + return []string{ + "vfsfs", + "creds", + "allowRuntimeEnable", + "lowerMount", + "rootDentry", + "alg", + "action", + "opts", + } +} + +func (fs *filesystem) beforeSave() {} + +// +checklocksignore +func (fs *filesystem) StateSave(stateSinkObject state.Sink) { + fs.beforeSave() + stateSinkObject.Save(0, &fs.vfsfs) + stateSinkObject.Save(1, &fs.creds) + stateSinkObject.Save(2, &fs.allowRuntimeEnable) + stateSinkObject.Save(3, &fs.lowerMount) + stateSinkObject.Save(4, &fs.rootDentry) + stateSinkObject.Save(5, &fs.alg) + stateSinkObject.Save(6, &fs.action) + stateSinkObject.Save(7, &fs.opts) +} + +func (fs *filesystem) afterLoad() {} + +// +checklocksignore +func (fs *filesystem) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &fs.vfsfs) + stateSourceObject.Load(1, &fs.creds) + stateSourceObject.Load(2, &fs.allowRuntimeEnable) + stateSourceObject.Load(3, &fs.lowerMount) + stateSourceObject.Load(4, &fs.rootDentry) + stateSourceObject.Load(5, &fs.alg) + stateSourceObject.Load(6, &fs.action) + stateSourceObject.Load(7, &fs.opts) +} + +func (i *InternalFilesystemOptions) StateTypeName() string { + return "pkg/sentry/fsimpl/verity.InternalFilesystemOptions" +} + +func (i *InternalFilesystemOptions) StateFields() []string { + return []string{ + "LowerName", + "Alg", + "AllowRuntimeEnable", + "LowerGetFSOptions", + "Action", + } +} + +func (i *InternalFilesystemOptions) beforeSave() {} + +// +checklocksignore +func (i *InternalFilesystemOptions) StateSave(stateSinkObject state.Sink) { + i.beforeSave() + stateSinkObject.Save(0, &i.LowerName) + stateSinkObject.Save(1, &i.Alg) + stateSinkObject.Save(2, &i.AllowRuntimeEnable) + stateSinkObject.Save(3, &i.LowerGetFSOptions) + stateSinkObject.Save(4, &i.Action) +} + +func (i *InternalFilesystemOptions) afterLoad() {} + +// +checklocksignore +func (i *InternalFilesystemOptions) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &i.LowerName) + stateSourceObject.Load(1, &i.Alg) + stateSourceObject.Load(2, &i.AllowRuntimeEnable) + stateSourceObject.Load(3, &i.LowerGetFSOptions) + stateSourceObject.Load(4, &i.Action) +} + +func (d *dentry) StateTypeName() string { + return "pkg/sentry/fsimpl/verity.dentry" +} + +func (d *dentry) StateFields() []string { + return []string{ + "vfsd", + "refs", + "fs", + "mode", + "uid", + "gid", + "size", + "parent", + "name", + "children", + "childrenNames", + "childrenList", + "lowerVD", + "lowerMerkleVD", + "symlinkTarget", + "hash", + } +} + +func (d *dentry) beforeSave() {} + +// +checklocksignore +func (d *dentry) StateSave(stateSinkObject state.Sink) { + d.beforeSave() + stateSinkObject.Save(0, &d.vfsd) + stateSinkObject.Save(1, &d.refs) + stateSinkObject.Save(2, &d.fs) + stateSinkObject.Save(3, &d.mode) + stateSinkObject.Save(4, &d.uid) + stateSinkObject.Save(5, &d.gid) + stateSinkObject.Save(6, &d.size) + stateSinkObject.Save(7, &d.parent) + stateSinkObject.Save(8, &d.name) + stateSinkObject.Save(9, &d.children) + stateSinkObject.Save(10, &d.childrenNames) + stateSinkObject.Save(11, &d.childrenList) + stateSinkObject.Save(12, &d.lowerVD) + stateSinkObject.Save(13, &d.lowerMerkleVD) + stateSinkObject.Save(14, &d.symlinkTarget) + stateSinkObject.Save(15, &d.hash) +} + +// +checklocksignore +func (d *dentry) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &d.vfsd) + stateSourceObject.Load(1, &d.refs) + stateSourceObject.Load(2, &d.fs) + stateSourceObject.Load(3, &d.mode) + stateSourceObject.Load(4, &d.uid) + stateSourceObject.Load(5, &d.gid) + stateSourceObject.Load(6, &d.size) + stateSourceObject.Load(7, &d.parent) + stateSourceObject.Load(8, &d.name) + stateSourceObject.Load(9, &d.children) + stateSourceObject.Load(10, &d.childrenNames) + stateSourceObject.Load(11, &d.childrenList) + stateSourceObject.Load(12, &d.lowerVD) + stateSourceObject.Load(13, &d.lowerMerkleVD) + stateSourceObject.Load(14, &d.symlinkTarget) + stateSourceObject.Load(15, &d.hash) + stateSourceObject.AfterLoad(d.afterLoad) +} + +func (fd *fileDescription) StateTypeName() string { + return "pkg/sentry/fsimpl/verity.fileDescription" +} + +func (fd *fileDescription) StateFields() []string { + return []string{ + "vfsfd", + "FileDescriptionDefaultImpl", + "d", + "isDir", + "lowerFD", + "lowerMappable", + "merkleReader", + "merkleWriter", + "parentMerkleWriter", + "off", + } +} + +func (fd *fileDescription) beforeSave() {} + +// +checklocksignore +func (fd *fileDescription) StateSave(stateSinkObject state.Sink) { + fd.beforeSave() + stateSinkObject.Save(0, &fd.vfsfd) + stateSinkObject.Save(1, &fd.FileDescriptionDefaultImpl) + stateSinkObject.Save(2, &fd.d) + stateSinkObject.Save(3, &fd.isDir) + stateSinkObject.Save(4, &fd.lowerFD) + stateSinkObject.Save(5, &fd.lowerMappable) + stateSinkObject.Save(6, &fd.merkleReader) + stateSinkObject.Save(7, &fd.merkleWriter) + stateSinkObject.Save(8, &fd.parentMerkleWriter) + stateSinkObject.Save(9, &fd.off) +} + +func (fd *fileDescription) afterLoad() {} + +// +checklocksignore +func (fd *fileDescription) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &fd.vfsfd) + stateSourceObject.Load(1, &fd.FileDescriptionDefaultImpl) + stateSourceObject.Load(2, &fd.d) + stateSourceObject.Load(3, &fd.isDir) + stateSourceObject.Load(4, &fd.lowerFD) + stateSourceObject.Load(5, &fd.lowerMappable) + stateSourceObject.Load(6, &fd.merkleReader) + stateSourceObject.Load(7, &fd.merkleWriter) + stateSourceObject.Load(8, &fd.parentMerkleWriter) + stateSourceObject.Load(9, &fd.off) +} + +func init() { + state.Register((*FilesystemType)(nil)) + state.Register((*filesystem)(nil)) + state.Register((*InternalFilesystemOptions)(nil)) + state.Register((*dentry)(nil)) + state.Register((*fileDescription)(nil)) +} diff --git a/pkg/sentry/fsimpl/verity/verity_test.go b/pkg/sentry/fsimpl/verity/verity_test.go deleted file mode 100644 index af041bd50..000000000 --- a/pkg/sentry/fsimpl/verity/verity_test.go +++ /dev/null @@ -1,1211 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package verity - -import ( - "fmt" - "io" - "math/rand" - "strconv" - "testing" - "time" - - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/errors/linuxerr" - "gvisor.dev/gvisor/pkg/fspath" - "gvisor.dev/gvisor/pkg/sentry/arch" - "gvisor.dev/gvisor/pkg/sentry/fsimpl/testutil" - "gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs" - "gvisor.dev/gvisor/pkg/sentry/kernel" - "gvisor.dev/gvisor/pkg/sentry/kernel/auth" - "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/usermem" -) - -const ( - // rootMerkleFilename is the name of the root Merkle tree file. - rootMerkleFilename = "root.verity" - // maxDataSize is the maximum data size of a test file. - maxDataSize = 100000 -) - -var hashAlgs = []HashAlgorithm{SHA256, SHA512} - -func dentryFromVD(t *testing.T, vd vfs.VirtualDentry) *dentry { - t.Helper() - d, ok := vd.Dentry().Impl().(*dentry) - if !ok { - t.Fatalf("can't assert %T as a *dentry", vd) - } - return d -} - -// dentryFromFD returns the dentry corresponding to fd. -func dentryFromFD(t *testing.T, fd *vfs.FileDescription) *dentry { - t.Helper() - f, ok := fd.Impl().(*fileDescription) - if !ok { - t.Fatalf("can't assert %T as a *fileDescription", fd) - } - return f.d -} - -// newVerityRoot creates a new verity mount, and returns the root. The -// underlying file system is tmpfs. If the error is not nil, then cleanup -// should be called when the root is no longer needed. -func newVerityRoot(t *testing.T, hashAlg HashAlgorithm) (*vfs.VirtualFilesystem, vfs.VirtualDentry, context.Context, error) { - t.Helper() - k, err := testutil.Boot() - if err != nil { - t.Fatalf("testutil.Boot: %v", err) - } - - ctx := k.SupervisorContext() - - rand.Seed(time.Now().UnixNano()) - vfsObj := &vfs.VirtualFilesystem{} - if err := vfsObj.Init(ctx); err != nil { - return nil, vfs.VirtualDentry{}, nil, fmt.Errorf("VFS init: %v", err) - } - - vfsObj.MustRegisterFilesystemType("verity", FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{ - AllowUserMount: true, - }) - - vfsObj.MustRegisterFilesystemType("tmpfs", tmpfs.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{ - AllowUserMount: true, - }) - - data := "root_name=" + rootMerkleFilename - mntns, err := vfsObj.NewMountNamespace(ctx, auth.CredentialsFromContext(ctx), "", "verity", &vfs.MountOptions{ - GetFilesystemOptions: vfs.GetFilesystemOptions{ - Data: data, - InternalData: InternalFilesystemOptions{ - LowerName: "tmpfs", - Alg: hashAlg, - AllowRuntimeEnable: true, - Action: ErrorOnViolation, - }, - }, - }) - if err != nil { - return nil, vfs.VirtualDentry{}, nil, fmt.Errorf("NewMountNamespace: %v", err) - } - root := mntns.Root() - root.IncRef() - - // Use lowerRoot in the task as we modify the lower file system - // directly in many tests. - lowerRoot := root.Dentry().Impl().(*dentry).lowerVD - tc := k.NewThreadGroup(nil, k.RootPIDNamespace(), kernel.NewSignalHandlers(), linux.SIGCHLD, k.GlobalInit().Limits()) - task, err := testutil.CreateTask(ctx, "name", tc, mntns, lowerRoot, lowerRoot) - if err != nil { - t.Fatalf("testutil.CreateTask: %v", err) - } - - t.Cleanup(func() { - root.DecRef(ctx) - mntns.DecRef(ctx) - }) - return vfsObj, root, task.AsyncContext(), nil -} - -// openVerityAt opens a verity file. -// -// TODO(chongc): release reference from opening the file when done. -func openVerityAt(ctx context.Context, vfsObj *vfs.VirtualFilesystem, vd vfs.VirtualDentry, path string, flags uint32, mode linux.FileMode) (*vfs.FileDescription, error) { - return vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ - Root: vd, - Start: vd, - Path: fspath.Parse(path), - }, &vfs.OpenOptions{ - Flags: flags, - Mode: mode, - }) -} - -// openLowerAt opens the file in the underlying file system. -// -// TODO(chongc): release reference from opening the file when done. -func (d *dentry) openLowerAt(ctx context.Context, vfsObj *vfs.VirtualFilesystem, path string, flags uint32, mode linux.FileMode) (*vfs.FileDescription, error) { - return vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ - Root: d.lowerVD, - Start: d.lowerVD, - Path: fspath.Parse(path), - }, &vfs.OpenOptions{ - Flags: flags, - Mode: mode, - }) -} - -// openLowerMerkleAt opens the Merkle file in the underlying file system. -// -// TODO(chongc): release reference from opening the file when done. -func (d *dentry) openLowerMerkleAt(ctx context.Context, vfsObj *vfs.VirtualFilesystem, flags uint32, mode linux.FileMode) (*vfs.FileDescription, error) { - return vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ - Root: d.lowerMerkleVD, - Start: d.lowerMerkleVD, - }, &vfs.OpenOptions{ - Flags: flags, - Mode: mode, - }) -} - -// mkdirLowerAt creates a directory in the underlying file system. -func (d *dentry) mkdirLowerAt(ctx context.Context, vfsObj *vfs.VirtualFilesystem, path string, mode linux.FileMode) error { - return vfsObj.MkdirAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ - Root: d.lowerVD, - Start: d.lowerVD, - Path: fspath.Parse(path), - }, &vfs.MkdirOptions{ - Mode: mode, - }) -} - -// unlinkLowerAt deletes the file in the underlying file system. -func (d *dentry) unlinkLowerAt(ctx context.Context, vfsObj *vfs.VirtualFilesystem, path string) error { - return vfsObj.UnlinkAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ - Root: d.lowerVD, - Start: d.lowerVD, - Path: fspath.Parse(path), - }) -} - -// unlinkLowerMerkleAt deletes the Merkle file in the underlying file system. -func (d *dentry) unlinkLowerMerkleAt(ctx context.Context, vfsObj *vfs.VirtualFilesystem, path string) error { - return vfsObj.UnlinkAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ - Root: d.lowerVD, - Start: d.lowerVD, - Path: fspath.Parse(merklePrefix + path), - }) -} - -// renameLowerAt renames file name to newName in the underlying file system. -func (d *dentry) renameLowerAt(ctx context.Context, vfsObj *vfs.VirtualFilesystem, name string, newName string) error { - return vfsObj.RenameAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ - Root: d.lowerVD, - Start: d.lowerVD, - Path: fspath.Parse(name), - }, &vfs.PathOperation{ - Root: d.lowerVD, - Start: d.lowerVD, - Path: fspath.Parse(newName), - }, &vfs.RenameOptions{}) -} - -// renameLowerMerkleAt renames Merkle file name to newName in the underlying -// file system. -func (d *dentry) renameLowerMerkleAt(ctx context.Context, vfsObj *vfs.VirtualFilesystem, name string, newName string) error { - return vfsObj.RenameAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ - Root: d.lowerVD, - Start: d.lowerVD, - Path: fspath.Parse(merklePrefix + name), - }, &vfs.PathOperation{ - Root: d.lowerVD, - Start: d.lowerVD, - Path: fspath.Parse(merklePrefix + newName), - }, &vfs.RenameOptions{}) -} - -// symlinkLowerAt creates a symbolic link at symlink referring to the given target -// in the underlying filesystem. -func (d *dentry) symlinkLowerAt(ctx context.Context, vfsObj *vfs.VirtualFilesystem, target, symlink string) error { - return vfsObj.SymlinkAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ - Root: d.lowerVD, - Start: d.lowerVD, - Path: fspath.Parse(symlink), - }, target) -} - -// newFileFD creates a new file in the verity mount, and returns the FD. The FD -// points to a file that has random data generated. -func newFileFD(ctx context.Context, t *testing.T, vfsObj *vfs.VirtualFilesystem, root vfs.VirtualDentry, filePath string, mode linux.FileMode) (*vfs.FileDescription, int, error) { - // Create the file in the underlying file system. - lowerFD, err := dentryFromVD(t, root).openLowerAt(ctx, vfsObj, filePath, linux.O_RDWR|linux.O_CREAT|linux.O_EXCL, linux.ModeRegular|mode) - if err != nil { - return nil, 0, err - } - - // Generate random data to be written to the file. - dataSize := rand.Intn(maxDataSize) + 1 - data := make([]byte, dataSize) - rand.Read(data) - - // Write directly to the underlying FD, since verity FD is read-only. - n, err := lowerFD.Write(ctx, usermem.BytesIOSequence(data), vfs.WriteOptions{}) - if err != nil { - return nil, 0, err - } - - if n != int64(len(data)) { - return nil, 0, fmt.Errorf("lowerFD.Write got write length %d, want %d", n, len(data)) - } - - lowerFD.DecRef(ctx) - - // Now open the verity file descriptor. - fd, err := openVerityAt(ctx, vfsObj, root, filePath, linux.O_RDONLY, mode) - return fd, dataSize, err -} - -// newDirFD creates a new directory in the verity mount, and returns the FD. -func newDirFD(ctx context.Context, t *testing.T, vfsObj *vfs.VirtualFilesystem, root vfs.VirtualDentry, dirPath string, mode linux.FileMode) (*vfs.FileDescription, error) { - // Create the directory in the underlying file system. - if err := dentryFromVD(t, root).mkdirLowerAt(ctx, vfsObj, dirPath, linux.ModeRegular|mode); err != nil { - return nil, err - } - if _, err := dentryFromVD(t, root).openLowerAt(ctx, vfsObj, dirPath, linux.O_RDONLY|linux.O_DIRECTORY, linux.ModeRegular|mode); err != nil { - return nil, err - } - return openVerityAt(ctx, vfsObj, root, dirPath, linux.O_RDONLY|linux.O_DIRECTORY, mode) -} - -// newEmptyFileFD creates a new empty file in the verity mount, and returns the FD. -func newEmptyFileFD(ctx context.Context, t *testing.T, vfsObj *vfs.VirtualFilesystem, root vfs.VirtualDentry, filePath string, mode linux.FileMode) (*vfs.FileDescription, error) { - // Create the file in the underlying file system. - _, err := dentryFromVD(t, root).openLowerAt(ctx, vfsObj, filePath, linux.O_RDWR|linux.O_CREAT|linux.O_EXCL, linux.ModeRegular|mode) - if err != nil { - return nil, err - } - // Now open the verity file descriptor. - fd, err := openVerityAt(ctx, vfsObj, root, filePath, linux.O_RDONLY, mode) - return fd, err -} - -// flipRandomBit randomly flips a bit in the file represented by fd. -func flipRandomBit(ctx context.Context, fd *vfs.FileDescription, size int) error { - randomPos := int64(rand.Intn(size)) - byteToModify := make([]byte, 1) - if _, err := fd.PRead(ctx, usermem.BytesIOSequence(byteToModify), randomPos, vfs.ReadOptions{}); err != nil { - return fmt.Errorf("lowerFD.PRead: %v", err) - } - byteToModify[0] ^= 1 - if _, err := fd.PWrite(ctx, usermem.BytesIOSequence(byteToModify), randomPos, vfs.WriteOptions{}); err != nil { - return fmt.Errorf("lowerFD.PWrite: %v", err) - } - return nil -} - -func enableVerity(ctx context.Context, t *testing.T, fd *vfs.FileDescription) { - t.Helper() - var args arch.SyscallArguments - args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} - if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { - t.Fatalf("enable verity: %v", err) - } -} - -// TestOpen ensures that when a file is created, the corresponding Merkle tree -// file and the root Merkle tree file exist. -func TestOpen(t *testing.T) { - for _, alg := range hashAlgs { - vfsObj, root, ctx, err := newVerityRoot(t, alg) - if err != nil { - t.Fatalf("newVerityRoot: %v", err) - } - - filename := "verity-test-file" - fd, _, err := newFileFD(ctx, t, vfsObj, root, filename, 0644) - if err != nil { - t.Fatalf("newFileFD: %v", err) - } - - // Ensure that the corresponding Merkle tree file is created. - if _, err = dentryFromFD(t, fd).openLowerMerkleAt(ctx, vfsObj, linux.O_RDONLY, linux.ModeRegular); err != nil { - t.Errorf("OpenAt Merkle tree file %s: %v", merklePrefix+filename, err) - } - - // Ensure the root merkle tree file is created. - if _, err = dentryFromVD(t, root).openLowerMerkleAt(ctx, vfsObj, linux.O_RDONLY, linux.ModeRegular); err != nil { - t.Errorf("OpenAt root Merkle tree file %s: %v", merklePrefix+rootMerkleFilename, err) - } - } -} - -// TestPReadUnmodifiedFileSucceeds ensures that pread from an untouched verity -// file succeeds after enabling verity for it. -func TestPReadUnmodifiedFileSucceeds(t *testing.T) { - for _, alg := range hashAlgs { - vfsObj, root, ctx, err := newVerityRoot(t, alg) - if err != nil { - t.Fatalf("newVerityRoot: %v", err) - } - - filename := "verity-test-file" - fd, size, err := newFileFD(ctx, t, vfsObj, root, filename, 0644) - if err != nil { - t.Fatalf("newFileFD: %v", err) - } - - // Enable verity on the file and confirm a normal read succeeds. - enableVerity(ctx, t, fd) - - buf := make([]byte, size) - n, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), 0 /* offset */, vfs.ReadOptions{}) - if err != nil && err != io.EOF { - t.Fatalf("fd.PRead: %v", err) - } - - if n != int64(size) { - t.Errorf("fd.PRead got read length %d, want %d", n, size) - } - } -} - -// TestReadUnmodifiedFileSucceeds ensures that read from an untouched verity -// file succeeds after enabling verity for it. -func TestReadUnmodifiedFileSucceeds(t *testing.T) { - for _, alg := range hashAlgs { - vfsObj, root, ctx, err := newVerityRoot(t, alg) - if err != nil { - t.Fatalf("newVerityRoot: %v", err) - } - - filename := "verity-test-file" - fd, size, err := newFileFD(ctx, t, vfsObj, root, filename, 0644) - if err != nil { - t.Fatalf("newFileFD: %v", err) - } - - // Enable verity on the file and confirm a normal read succeeds. - enableVerity(ctx, t, fd) - - buf := make([]byte, size) - n, err := fd.Read(ctx, usermem.BytesIOSequence(buf), vfs.ReadOptions{}) - if err != nil && err != io.EOF { - t.Fatalf("fd.Read: %v", err) - } - - if n != int64(size) { - t.Errorf("fd.PRead got read length %d, want %d", n, size) - } - } -} - -// TestReadUnmodifiedEmptyFileSucceeds ensures that read from an untouched empty verity -// file succeeds after enabling verity for it. -func TestReadUnmodifiedEmptyFileSucceeds(t *testing.T) { - for _, alg := range hashAlgs { - vfsObj, root, ctx, err := newVerityRoot(t, alg) - if err != nil { - t.Fatalf("newVerityRoot: %v", err) - } - - filename := "verity-test-empty-file" - fd, err := newEmptyFileFD(ctx, t, vfsObj, root, filename, 0644) - if err != nil { - t.Fatalf("newEmptyFileFD: %v", err) - } - - // Enable verity on the file and confirm a normal read succeeds. - enableVerity(ctx, t, fd) - - var buf []byte - n, err := fd.Read(ctx, usermem.BytesIOSequence(buf), vfs.ReadOptions{}) - if err != nil && err != io.EOF { - t.Fatalf("fd.Read: %v", err) - } - - if n != 0 { - t.Errorf("fd.Read got read length %d, expected 0", n) - } - } -} - -// TestReopenUnmodifiedFileSucceeds ensures that reopen an untouched verity file -// succeeds after enabling verity for it. -func TestReopenUnmodifiedFileSucceeds(t *testing.T) { - for _, alg := range hashAlgs { - vfsObj, root, ctx, err := newVerityRoot(t, alg) - if err != nil { - t.Fatalf("newVerityRoot: %v", err) - } - - filename := "verity-test-file" - fd, _, err := newFileFD(ctx, t, vfsObj, root, filename, 0644) - if err != nil { - t.Fatalf("newFileFD: %v", err) - } - - // Enable verity on the file and confirms a normal read succeeds. - enableVerity(ctx, t, fd) - - // Ensure reopening the verity enabled file succeeds. - if _, err = openVerityAt(ctx, vfsObj, root, filename, linux.O_RDONLY, linux.ModeRegular); err != nil { - t.Errorf("reopen enabled file failed: %v", err) - } - } -} - -// TestOpenNonexistentFile ensures that opening a nonexistent file does not -// trigger verification failure, even if the parent directory is verified. -func TestOpenNonexistentFile(t *testing.T) { - vfsObj, root, ctx, err := newVerityRoot(t, SHA256) - if err != nil { - t.Fatalf("newVerityRoot: %v", err) - } - - filename := "verity-test-file" - fd, _, err := newFileFD(ctx, t, vfsObj, root, filename, 0644) - if err != nil { - t.Fatalf("newFileFD: %v", err) - } - - // Enable verity on the file and confirms a normal read succeeds. - enableVerity(ctx, t, fd) - - // Enable verity on the parent directory. - parentFD, err := openVerityAt(ctx, vfsObj, root, "", linux.O_RDONLY, linux.ModeRegular) - if err != nil { - t.Fatalf("OpenAt: %v", err) - } - enableVerity(ctx, t, parentFD) - - // Ensure open an unexpected file in the parent directory fails with - // ENOENT rather than verification failure. - if _, err = openVerityAt(ctx, vfsObj, root, filename+"abc", linux.O_RDONLY, linux.ModeRegular); !linuxerr.Equals(linuxerr.ENOENT, err) { - t.Errorf("OpenAt unexpected error: %v", err) - } -} - -// TestPReadModifiedFileFails ensures that read from a modified verity file -// fails. -func TestPReadModifiedFileFails(t *testing.T) { - for _, alg := range hashAlgs { - vfsObj, root, ctx, err := newVerityRoot(t, alg) - if err != nil { - t.Fatalf("newVerityRoot: %v", err) - } - - filename := "verity-test-file" - fd, size, err := newFileFD(ctx, t, vfsObj, root, filename, 0644) - if err != nil { - t.Fatalf("newFileFD: %v", err) - } - - // Enable verity on the file. - enableVerity(ctx, t, fd) - - // Open a new lowerFD that's read/writable. - lowerFD, err := dentryFromFD(t, fd).openLowerAt(ctx, vfsObj, "", linux.O_RDWR, linux.ModeRegular) - if err != nil { - t.Fatalf("OpenAt: %v", err) - } - - if err := flipRandomBit(ctx, lowerFD, size); err != nil { - t.Fatalf("flipRandomBit: %v", err) - } - - // Confirm that read from the modified file fails. - buf := make([]byte, size) - if _, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), 0 /* offset */, vfs.ReadOptions{}); err == nil { - t.Fatalf("fd.PRead succeeded, expected failure") - } - } -} - -// TestReadModifiedFileFails ensures that read from a modified verity file -// fails. -func TestReadModifiedFileFails(t *testing.T) { - for _, alg := range hashAlgs { - vfsObj, root, ctx, err := newVerityRoot(t, alg) - if err != nil { - t.Fatalf("newVerityRoot: %v", err) - } - - filename := "verity-test-file" - fd, size, err := newFileFD(ctx, t, vfsObj, root, filename, 0644) - if err != nil { - t.Fatalf("newFileFD: %v", err) - } - - // Enable verity on the file. - enableVerity(ctx, t, fd) - - // Open a new lowerFD that's read/writable. - lowerFD, err := dentryFromFD(t, fd).openLowerAt(ctx, vfsObj, "", linux.O_RDWR, linux.ModeRegular) - if err != nil { - t.Fatalf("OpenAt: %v", err) - } - - if err := flipRandomBit(ctx, lowerFD, size); err != nil { - t.Fatalf("flipRandomBit: %v", err) - } - - // Confirm that read from the modified file fails. - buf := make([]byte, size) - if _, err := fd.Read(ctx, usermem.BytesIOSequence(buf), vfs.ReadOptions{}); err == nil { - t.Fatalf("fd.Read succeeded, expected failure") - } - } -} - -// TestModifiedMerkleFails ensures that read from a verity file fails if the -// corresponding Merkle tree file is modified. -func TestModifiedMerkleFails(t *testing.T) { - for _, alg := range hashAlgs { - vfsObj, root, ctx, err := newVerityRoot(t, alg) - if err != nil { - t.Fatalf("newVerityRoot: %v", err) - } - - filename := "verity-test-file" - fd, size, err := newFileFD(ctx, t, vfsObj, root, filename, 0644) - if err != nil { - t.Fatalf("newFileFD: %v", err) - } - - // Enable verity on the file. - enableVerity(ctx, t, fd) - - // Open a new lowerMerkleFD that's read/writable. - lowerMerkleFD, err := dentryFromFD(t, fd).openLowerMerkleAt(ctx, vfsObj, linux.O_RDWR, linux.ModeRegular) - if err != nil { - t.Fatalf("OpenAt: %v", err) - } - - // Flip a random bit in the Merkle tree file. - stat, err := lowerMerkleFD.Stat(ctx, vfs.StatOptions{}) - if err != nil { - t.Errorf("lowerMerkleFD.Stat: %v", err) - } - - if err := flipRandomBit(ctx, lowerMerkleFD, int(stat.Size)); err != nil { - t.Fatalf("flipRandomBit: %v", err) - } - - // Confirm that read from a file with modified Merkle tree fails. - buf := make([]byte, size) - if _, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), 0 /* offset */, vfs.ReadOptions{}); err == nil { - t.Fatalf("fd.PRead succeeded with modified Merkle file") - } - } -} - -// TestModifiedParentMerkleFails ensures that open a verity enabled file in a -// verity enabled directory fails if the hashes related to the target file in -// the parent Merkle tree file is modified. -func TestModifiedParentMerkleFails(t *testing.T) { - for _, alg := range hashAlgs { - vfsObj, root, ctx, err := newVerityRoot(t, alg) - if err != nil { - t.Fatalf("newVerityRoot: %v", err) - } - - filename := "verity-test-file" - fd, _, err := newFileFD(ctx, t, vfsObj, root, filename, 0644) - if err != nil { - t.Fatalf("newFileFD: %v", err) - } - - // Enable verity on the file. - enableVerity(ctx, t, fd) - - // Enable verity on the parent directory. - parentFD, err := openVerityAt(ctx, vfsObj, root, "", linux.O_RDONLY, linux.ModeRegular) - if err != nil { - t.Fatalf("OpenAt: %v", err) - } - enableVerity(ctx, t, parentFD) - - // Open a new lowerMerkleFD that's read/writable. - parentLowerMerkleFD, err := dentryFromFD(t, fd).parent.openLowerMerkleAt(ctx, vfsObj, linux.O_RDWR, linux.ModeRegular) - if err != nil { - t.Fatalf("OpenAt: %v", err) - } - - // Flip a random bit in the parent Merkle tree file. - // This parent directory contains only one child, so any random - // modification in the parent Merkle tree should cause verification - // failure when opening the child file. - sizeString, err := parentLowerMerkleFD.GetXattr(ctx, &vfs.GetXattrOptions{ - Name: childrenOffsetXattr, - Size: sizeOfStringInt32, - }) - if err != nil { - t.Fatalf("parentLowerMerkleFD.GetXattr: %v", err) - } - parentMerkleSize, err := strconv.Atoi(sizeString) - if err != nil { - t.Fatalf("Failed convert size to int: %v", err) - } - if err := flipRandomBit(ctx, parentLowerMerkleFD, parentMerkleSize); err != nil { - t.Fatalf("flipRandomBit: %v", err) - } - - parentLowerMerkleFD.DecRef(ctx) - - // Ensure reopening the verity enabled file fails. - if _, err = openVerityAt(ctx, vfsObj, root, filename, linux.O_RDONLY, linux.ModeRegular); err == nil { - t.Errorf("OpenAt file with modified parent Merkle succeeded") - } - } -} - -// TestUnmodifiedStatSucceeds ensures that stat of an untouched verity file -// succeeds after enabling verity for it. -func TestUnmodifiedStatSucceeds(t *testing.T) { - for _, alg := range hashAlgs { - vfsObj, root, ctx, err := newVerityRoot(t, alg) - if err != nil { - t.Fatalf("newVerityRoot: %v", err) - } - - filename := "verity-test-file" - fd, _, err := newFileFD(ctx, t, vfsObj, root, filename, 0644) - if err != nil { - t.Fatalf("newFileFD: %v", err) - } - - // Enable verity on the file and confirm that stat succeeds. - enableVerity(ctx, t, fd) - if _, err := fd.Stat(ctx, vfs.StatOptions{}); err != nil { - t.Errorf("fd.Stat: %v", err) - } - } -} - -// TestModifiedStatFails checks that getting stat for a file with modified stat -// should fail. -func TestModifiedStatFails(t *testing.T) { - for _, alg := range hashAlgs { - vfsObj, root, ctx, err := newVerityRoot(t, alg) - if err != nil { - t.Fatalf("newVerityRoot: %v", err) - } - - filename := "verity-test-file" - fd, _, err := newFileFD(ctx, t, vfsObj, root, filename, 0644) - if err != nil { - t.Fatalf("newFileFD: %v", err) - } - - // Enable verity on the file. - enableVerity(ctx, t, fd) - - lowerFD := fd.Impl().(*fileDescription).lowerFD - // Change the stat of the underlying file, and check that stat fails. - if err := lowerFD.SetStat(ctx, vfs.SetStatOptions{ - Stat: linux.Statx{ - Mask: uint32(linux.STATX_MODE), - Mode: 0777, - }, - }); err != nil { - t.Fatalf("lowerFD.SetStat: %v", err) - } - - if _, err := fd.Stat(ctx, vfs.StatOptions{}); err == nil { - t.Errorf("fd.Stat succeeded when it should fail") - } - } -} - -// TestOpenDeletedFileFails ensures that opening a deleted verity enabled file -// and/or the corresponding Merkle tree file fails with the verity error. -func TestOpenDeletedFileFails(t *testing.T) { - testCases := []struct { - name string - // The original file is removed if changeFile is true. - changeFile bool - // The Merkle tree file is removed if changeMerkleFile is true. - changeMerkleFile bool - }{ - { - name: "FileOnly", - changeFile: true, - changeMerkleFile: false, - }, - { - name: "MerkleOnly", - changeFile: false, - changeMerkleFile: true, - }, - { - name: "FileAndMerkle", - changeFile: true, - changeMerkleFile: true, - }, - } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - vfsObj, root, ctx, err := newVerityRoot(t, SHA256) - if err != nil { - t.Fatalf("newVerityRoot: %v", err) - } - - filename := "verity-test-file" - fd, _, err := newFileFD(ctx, t, vfsObj, root, filename, 0644) - if err != nil { - t.Fatalf("newFileFD: %v", err) - } - - // Enable verity on the file. - enableVerity(ctx, t, fd) - - if tc.changeFile { - if err := dentryFromVD(t, root).unlinkLowerAt(ctx, vfsObj, filename); err != nil { - t.Fatalf("UnlinkAt: %v", err) - } - } - if tc.changeMerkleFile { - if err := dentryFromVD(t, root).unlinkLowerMerkleAt(ctx, vfsObj, filename); err != nil { - t.Fatalf("UnlinkAt: %v", err) - } - } - - // Ensure reopening the verity enabled file fails. - if _, err = openVerityAt(ctx, vfsObj, root, filename, linux.O_RDONLY, linux.ModeRegular); !linuxerr.Equals(linuxerr.EIO, err) { - t.Errorf("got OpenAt error: %v, expected EIO", err) - } - }) - } -} - -// TestOpenRenamedFileFails ensures that opening a renamed verity enabled file -// and/or the corresponding Merkle tree file fails with the verity error. -func TestOpenRenamedFileFails(t *testing.T) { - testCases := []struct { - name string - // The original file is renamed if changeFile is true. - changeFile bool - // The Merkle tree file is renamed if changeMerkleFile is true. - changeMerkleFile bool - }{ - { - name: "FileOnly", - changeFile: true, - changeMerkleFile: false, - }, - { - name: "MerkleOnly", - changeFile: false, - changeMerkleFile: true, - }, - { - name: "FileAndMerkle", - changeFile: true, - changeMerkleFile: true, - }, - } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - vfsObj, root, ctx, err := newVerityRoot(t, SHA256) - if err != nil { - t.Fatalf("newVerityRoot: %v", err) - } - - filename := "verity-test-file" - fd, _, err := newFileFD(ctx, t, vfsObj, root, filename, 0644) - if err != nil { - t.Fatalf("newFileFD: %v", err) - } - - // Enable verity on the file. - enableVerity(ctx, t, fd) - - newFilename := "renamed-test-file" - if tc.changeFile { - if err := dentryFromVD(t, root).renameLowerAt(ctx, vfsObj, filename, newFilename); err != nil { - t.Fatalf("RenameAt: %v", err) - } - } - if tc.changeMerkleFile { - if err := dentryFromVD(t, root).renameLowerMerkleAt(ctx, vfsObj, filename, newFilename); err != nil { - t.Fatalf("UnlinkAt: %v", err) - } - } - - // Ensure reopening the verity enabled file fails. - if _, err = openVerityAt(ctx, vfsObj, root, filename, linux.O_RDONLY, linux.ModeRegular); !linuxerr.Equals(linuxerr.EIO, err) { - t.Errorf("got OpenAt error: %v, expected EIO", err) - } - }) - } -} - -// TestUnmodifiedSymlinkFileReadSucceeds ensures that readlink() for an -// unmodified verity enabled symlink succeeds. -func TestUnmodifiedSymlinkFileReadSucceeds(t *testing.T) { - testCases := []struct { - name string - // The symlink target is a directory. - hasDirectoryTarget bool - // The symlink target is a directory and contains a regular file which will be - // used to test walking a symlink. - testWalk bool - }{ - { - name: "RegularFileTarget", - hasDirectoryTarget: false, - testWalk: false, - }, - { - name: "DirectoryTarget", - hasDirectoryTarget: true, - testWalk: false, - }, - { - name: "RegularFileInSymlinkDirectory", - hasDirectoryTarget: true, - testWalk: true, - }, - } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - if tc.testWalk && !tc.hasDirectoryTarget { - t.Fatalf("Invalid test case: hasDirectoryTarget can't be false when testing symlink walk") - } - - vfsObj, root, ctx, err := newVerityRoot(t, SHA256) - if err != nil { - t.Fatalf("newVerityRoot: %v", err) - } - - var target string - if tc.hasDirectoryTarget { - target = "verity-test-dir" - if _, err := newDirFD(ctx, t, vfsObj, root, target, 0644); err != nil { - t.Fatalf("newDirFD: %v", err) - } - } else { - target = "verity-test-file" - if _, _, err := newFileFD(ctx, t, vfsObj, root, target, 0644); err != nil { - t.Fatalf("newFileFD: %v", err) - } - } - - if tc.testWalk { - fileInTargetDirectory := target + "/" + "verity-test-file" - if _, _, err := newFileFD(ctx, t, vfsObj, root, fileInTargetDirectory, 0644); err != nil { - t.Fatalf("newFileFD: %v", err) - } - } - - symlink := "verity-test-symlink" - if err := dentryFromVD(t, root).symlinkLowerAt(ctx, vfsObj, target, symlink); err != nil { - t.Fatalf("SymlinkAt: %v", err) - } - - fd, err := openVerityAt(ctx, vfsObj, root, symlink, linux.O_NOFOLLOW, linux.ModeRegular) - - if err != nil { - t.Fatalf("openVerityAt symlink: %v", err) - } - - enableVerity(ctx, t, fd) - - if _, err := vfsObj.ReadlinkAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ - Root: root, - Start: root, - Path: fspath.Parse(symlink), - }); err != nil { - t.Fatalf("ReadlinkAt: %v", err) - } - - if tc.testWalk { - fileInSymlinkDirectory := symlink + "/verity-test-file" - // Ensure opening the verity enabled file in the symlink directory succeeds. - if _, err := openVerityAt(ctx, vfsObj, root, fileInSymlinkDirectory, linux.O_RDONLY, linux.ModeRegular); err != nil { - t.Errorf("open enabled file failed: %v", err) - } - } - }) - } -} - -// TestDeletedSymlinkFileReadFails ensures that reading value of a deleted verity enabled -// symlink fails. -func TestDeletedSymlinkFileReadFails(t *testing.T) { - testCases := []struct { - name string - // The original symlink is unlinked if deleteLink is true. - deleteLink bool - // The Merkle tree file is renamed if deleteMerkleFile is true. - deleteMerkleFile bool - // The symlink target is a directory. - hasDirectoryTarget bool - // The symlink target is a directory and contains a regular file which will be - // used to test walking a symlink. - testWalk bool - }{ - { - name: "DeleteLinkRegularFile", - deleteLink: true, - deleteMerkleFile: false, - hasDirectoryTarget: false, - testWalk: false, - }, - { - name: "DeleteMerkleRegFile", - deleteLink: false, - deleteMerkleFile: true, - hasDirectoryTarget: false, - testWalk: false, - }, - { - name: "DeleteLinkAndMerkleRegFile", - deleteLink: true, - deleteMerkleFile: true, - hasDirectoryTarget: false, - testWalk: false, - }, - { - name: "DeleteLinkDirectory", - deleteLink: true, - deleteMerkleFile: false, - hasDirectoryTarget: true, - testWalk: false, - }, - { - name: "DeleteMerkleDirectory", - deleteLink: false, - deleteMerkleFile: true, - hasDirectoryTarget: true, - testWalk: false, - }, - { - name: "DeleteLinkAndMerkleDirectory", - deleteLink: true, - deleteMerkleFile: true, - hasDirectoryTarget: true, - testWalk: false, - }, - { - name: "DeleteLinkDirectoryWalk", - deleteLink: true, - deleteMerkleFile: false, - hasDirectoryTarget: true, - testWalk: true, - }, - { - name: "DeleteMerkleDirectoryWalk", - deleteLink: false, - deleteMerkleFile: true, - hasDirectoryTarget: true, - testWalk: true, - }, - { - name: "DeleteLinkAndMerkleDirectoryWalk", - deleteLink: true, - deleteMerkleFile: true, - hasDirectoryTarget: true, - testWalk: true, - }, - } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - if tc.testWalk && !tc.hasDirectoryTarget { - t.Fatalf("Invalid test case: hasDirectoryTarget can't be false when testing symlink walk") - } - - vfsObj, root, ctx, err := newVerityRoot(t, SHA256) - if err != nil { - t.Fatalf("newVerityRoot: %v", err) - } - - var target string - if tc.hasDirectoryTarget { - target = "verity-test-dir" - if _, err := newDirFD(ctx, t, vfsObj, root, target, 0644); err != nil { - t.Fatalf("newDirFD: %v", err) - } - } else { - target = "verity-test-file" - if _, _, err := newFileFD(ctx, t, vfsObj, root, target, 0644); err != nil { - t.Fatalf("newFileFD: %v", err) - } - } - - symlink := "verity-test-symlink" - if err := dentryFromVD(t, root).symlinkLowerAt(ctx, vfsObj, target, symlink); err != nil { - t.Fatalf("SymlinkAt: %v", err) - } - - fd, err := openVerityAt(ctx, vfsObj, root, symlink, linux.O_NOFOLLOW, linux.ModeRegular) - - if err != nil { - t.Fatalf("openVerityAt symlink: %v", err) - } - - if tc.testWalk { - fileInTargetDirectory := target + "/" + "verity-test-file" - if _, _, err := newFileFD(ctx, t, vfsObj, root, fileInTargetDirectory, 0644); err != nil { - t.Fatalf("newFileFD: %v", err) - } - } - - enableVerity(ctx, t, fd) - - if tc.deleteLink { - if err := dentryFromVD(t, root).unlinkLowerAt(ctx, vfsObj, symlink); err != nil { - t.Fatalf("UnlinkAt: %v", err) - } - } - if tc.deleteMerkleFile { - if err := dentryFromVD(t, root).unlinkLowerMerkleAt(ctx, vfsObj, symlink); err != nil { - t.Fatalf("UnlinkAt: %v", err) - } - } - if _, err := vfsObj.ReadlinkAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ - Root: root, - Start: root, - Path: fspath.Parse(symlink), - }); !linuxerr.Equals(linuxerr.EIO, err) { - t.Fatalf("ReadlinkAt succeeded with modified symlink: %v", err) - } - - if tc.testWalk { - fileInSymlinkDirectory := symlink + "/verity-test-file" - // Ensure opening the verity enabled file in the symlink directory fails. - if _, err := openVerityAt(ctx, vfsObj, root, fileInSymlinkDirectory, linux.O_RDONLY, linux.ModeRegular); !linuxerr.Equals(linuxerr.EIO, err) { - t.Errorf("Open succeeded with modified symlink: %v", err) - } - } - }) - } -} - -// TestModifiedSymlinkFileReadFails ensures that reading value of a modified verity enabled -// symlink fails. -func TestModifiedSymlinkFileReadFails(t *testing.T) { - testCases := []struct { - name string - // The symlink target is a directory. - hasDirectoryTarget bool - // The symlink target is a directory and contains a regular file which will be - // used to test walking a symlink. - testWalk bool - }{ - { - name: "RegularFileTarget", - hasDirectoryTarget: false, - testWalk: false, - }, - { - name: "DirectoryTarget", - hasDirectoryTarget: true, - testWalk: false, - }, - { - name: "RegularFileInSymlinkDirectory", - hasDirectoryTarget: true, - testWalk: true, - }, - } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - if tc.testWalk && !tc.hasDirectoryTarget { - t.Fatalf("Invalid test case: hasDirectoryTarget can't be false when testing symlink walk") - } - - vfsObj, root, ctx, err := newVerityRoot(t, SHA256) - if err != nil { - t.Fatalf("newVerityRoot: %v", err) - } - - var target string - if tc.hasDirectoryTarget { - target = "verity-test-dir" - if _, err := newDirFD(ctx, t, vfsObj, root, target, 0644); err != nil { - t.Fatalf("newDirFD: %v", err) - } - } else { - target = "verity-test-file" - if _, _, err := newFileFD(ctx, t, vfsObj, root, target, 0644); err != nil { - t.Fatalf("newFileFD: %v", err) - } - } - - // Create symlink which points to target file. - symlink := "verity-test-symlink" - if err := dentryFromVD(t, root).symlinkLowerAt(ctx, vfsObj, target, symlink); err != nil { - t.Fatalf("SymlinkAt: %v", err) - } - - // Open symlink file to get the fd for ioctl in new step. - fd, err := openVerityAt(ctx, vfsObj, root, symlink, linux.O_NOFOLLOW, linux.ModeRegular) - if err != nil { - t.Fatalf("OpenAt symlink: %v", err) - } - - if tc.testWalk { - fileInTargetDirectory := target + "/" + "verity-test-file" - if _, _, err := newFileFD(ctx, t, vfsObj, root, fileInTargetDirectory, 0644); err != nil { - t.Fatalf("newFileFD: %v", err) - } - } - - enableVerity(ctx, t, fd) - - var newTarget string - if tc.hasDirectoryTarget { - newTarget = "verity-test-dir-new" - if _, err := newDirFD(ctx, t, vfsObj, root, newTarget, 0644); err != nil { - t.Fatalf("newDirFD: %v", err) - } - } else { - newTarget = "verity-test-file-new" - if _, _, err := newFileFD(ctx, t, vfsObj, root, newTarget, 0644); err != nil { - t.Fatalf("newFileFD: %v", err) - } - } - - // Unlink symlink->target. - if err := dentryFromVD(t, root).unlinkLowerAt(ctx, vfsObj, symlink); err != nil { - t.Fatalf("UnlinkAt: %v", err) - } - - // Link symlink->newTarget. - if err := dentryFromVD(t, root).symlinkLowerAt(ctx, vfsObj, newTarget, symlink); err != nil { - t.Fatalf("SymlinkAt: %v", err) - } - - // Freshen lower dentry for symlink. - symlinkVD, err := vfsObj.GetDentryAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ - Root: root, - Start: root, - Path: fspath.Parse(symlink), - }, &vfs.GetDentryOptions{}) - if err != nil { - t.Fatalf("Failed to get symlink dentry: %v", err) - } - symlinkDentry := dentryFromVD(t, symlinkVD) - - symlinkLowerVD, err := dentryFromVD(t, root).getLowerAt(ctx, vfsObj, symlink) - if err != nil { - t.Fatalf("Failed to get symlink lower dentry: %v", err) - } - symlinkDentry.lowerVD = symlinkLowerVD - - // Verify ReadlinkAt() fails. - if _, err := vfsObj.ReadlinkAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ - Root: root, - Start: root, - Path: fspath.Parse(symlink), - }); !linuxerr.Equals(linuxerr.EIO, err) { - t.Fatalf("ReadlinkAt succeeded with modified symlink: %v", err) - } - - if tc.testWalk { - fileInSymlinkDirectory := symlink + "/verity-test-file" - // Ensure opening the verity enabled file in the symlink directory fails. - if _, err := openVerityAt(ctx, vfsObj, root, fileInSymlinkDirectory, linux.O_RDONLY, linux.ModeRegular); !linuxerr.Equals(linuxerr.EIO, err) { - t.Errorf("Open succeeded with modified symlink: %v", err) - } - } - }) - } -} diff --git a/pkg/sentry/fsmetric/BUILD b/pkg/sentry/fsmetric/BUILD deleted file mode 100644 index 4e86fbdd8..000000000 --- a/pkg/sentry/fsmetric/BUILD +++ /dev/null @@ -1,10 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -licenses(["notice"]) - -go_library( - name = "fsmetric", - srcs = ["fsmetric.go"], - visibility = ["//pkg/sentry:internal"], - deps = ["//pkg/metric"], -) diff --git a/pkg/sentry/fsmetric/fsmetric_state_autogen.go b/pkg/sentry/fsmetric/fsmetric_state_autogen.go new file mode 100644 index 000000000..61975bbb4 --- /dev/null +++ b/pkg/sentry/fsmetric/fsmetric_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package fsmetric diff --git a/pkg/sentry/hostcpu/BUILD b/pkg/sentry/hostcpu/BUILD deleted file mode 100644 index e6933aa70..000000000 --- a/pkg/sentry/hostcpu/BUILD +++ /dev/null @@ -1,20 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "hostcpu", - srcs = [ - "getcpu_amd64.s", - "getcpu_arm64.s", - "hostcpu.go", - ], - visibility = ["//:sandbox"], -) - -go_test( - name = "hostcpu_test", - size = "small", - srcs = ["hostcpu_test.go"], - library = ":hostcpu", -) diff --git a/pkg/sentry/hostcpu/hostcpu_state_autogen.go b/pkg/sentry/hostcpu/hostcpu_state_autogen.go new file mode 100644 index 000000000..97d33d8bf --- /dev/null +++ b/pkg/sentry/hostcpu/hostcpu_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package hostcpu diff --git a/pkg/sentry/hostcpu/hostcpu_test.go b/pkg/sentry/hostcpu/hostcpu_test.go deleted file mode 100644 index 7d6885c9e..000000000 --- a/pkg/sentry/hostcpu/hostcpu_test.go +++ /dev/null @@ -1,52 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package hostcpu - -import ( - "fmt" - "testing" -) - -func TestMaxValueInLinuxBitmap(t *testing.T) { - for _, test := range []struct { - str string - max uint64 - }{ - {"0", 0}, - {"0\n", 0}, - {"0,2", 2}, - {"0-63", 63}, - {"0-3,8-11", 11}, - } { - t.Run(fmt.Sprintf("%q", test.str), func(t *testing.T) { - max, err := maxValueInLinuxBitmap(test.str) - if err != nil || max != test.max { - t.Errorf("maxValueInLinuxBitmap: got (%d, %v), wanted (%d, nil)", max, err, test.max) - } - }) - } -} - -func TestMaxValueInLinuxBitmapErrors(t *testing.T) { - for _, str := range []string{"", "\n"} { - t.Run(fmt.Sprintf("%q", str), func(t *testing.T) { - max, err := maxValueInLinuxBitmap(str) - if err == nil { - t.Errorf("maxValueInLinuxBitmap: got (%d, nil), wanted (_, error)", max) - } - t.Log(err) - }) - } -} diff --git a/pkg/sentry/hostfd/BUILD b/pkg/sentry/hostfd/BUILD deleted file mode 100644 index db3b0d0a0..000000000 --- a/pkg/sentry/hostfd/BUILD +++ /dev/null @@ -1,19 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -licenses(["notice"]) - -go_library( - name = "hostfd", - srcs = [ - "hostfd.go", - "hostfd_linux.go", - "hostfd_unsafe.go", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/log", - "//pkg/safemem", - "//pkg/sync", - "@org_golang_x_sys//unix:go_default_library", - ], -) diff --git a/pkg/sentry/hostfd/hostfd_linux_state_autogen.go b/pkg/sentry/hostfd/hostfd_linux_state_autogen.go new file mode 100644 index 000000000..24e7ba400 --- /dev/null +++ b/pkg/sentry/hostfd/hostfd_linux_state_autogen.go @@ -0,0 +1,6 @@ +// automatically generated by stateify. + +//go:build go1.1 +// +build go1.1 + +package hostfd diff --git a/pkg/sentry/hostfd/hostfd_state_autogen.go b/pkg/sentry/hostfd/hostfd_state_autogen.go new file mode 100644 index 000000000..9033424d5 --- /dev/null +++ b/pkg/sentry/hostfd/hostfd_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package hostfd diff --git a/pkg/sentry/hostfd/hostfd_unsafe_state_autogen.go b/pkg/sentry/hostfd/hostfd_unsafe_state_autogen.go new file mode 100644 index 000000000..9033424d5 --- /dev/null +++ b/pkg/sentry/hostfd/hostfd_unsafe_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package hostfd diff --git a/pkg/sentry/hostmm/BUILD b/pkg/sentry/hostmm/BUILD deleted file mode 100644 index 66fa1ad40..000000000 --- a/pkg/sentry/hostmm/BUILD +++ /dev/null @@ -1,20 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "hostmm", - srcs = [ - "cgroup.go", - "hostmm.go", - "membarrier.go", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/fd", - "//pkg/hostarch", - "//pkg/log", - "@org_golang_x_sys//unix:go_default_library", - ], -) diff --git a/pkg/sentry/hostmm/hostmm_state_autogen.go b/pkg/sentry/hostmm/hostmm_state_autogen.go new file mode 100644 index 000000000..925c56e14 --- /dev/null +++ b/pkg/sentry/hostmm/hostmm_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package hostmm diff --git a/pkg/sentry/inet/BUILD b/pkg/sentry/inet/BUILD deleted file mode 100644 index 5bba9de0b..000000000 --- a/pkg/sentry/inet/BUILD +++ /dev/null @@ -1,21 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package( - default_visibility = ["//:sandbox"], - licenses = ["notice"], -) - -go_library( - name = "inet", - srcs = [ - "context.go", - "inet.go", - "namespace.go", - "test_stack.go", - ], - deps = [ - "//pkg/context", - "//pkg/tcpip", - "//pkg/tcpip/stack", - ], -) diff --git a/pkg/sentry/inet/inet_state_autogen.go b/pkg/sentry/inet/inet_state_autogen.go new file mode 100644 index 000000000..30af4fb91 --- /dev/null +++ b/pkg/sentry/inet/inet_state_autogen.go @@ -0,0 +1,70 @@ +// automatically generated by stateify. + +package inet + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (t *TCPBufferSize) StateTypeName() string { + return "pkg/sentry/inet.TCPBufferSize" +} + +func (t *TCPBufferSize) StateFields() []string { + return []string{ + "Min", + "Default", + "Max", + } +} + +func (t *TCPBufferSize) beforeSave() {} + +// +checklocksignore +func (t *TCPBufferSize) StateSave(stateSinkObject state.Sink) { + t.beforeSave() + stateSinkObject.Save(0, &t.Min) + stateSinkObject.Save(1, &t.Default) + stateSinkObject.Save(2, &t.Max) +} + +func (t *TCPBufferSize) afterLoad() {} + +// +checklocksignore +func (t *TCPBufferSize) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &t.Min) + stateSourceObject.Load(1, &t.Default) + stateSourceObject.Load(2, &t.Max) +} + +func (n *Namespace) StateTypeName() string { + return "pkg/sentry/inet.Namespace" +} + +func (n *Namespace) StateFields() []string { + return []string{ + "creator", + "isRoot", + } +} + +func (n *Namespace) beforeSave() {} + +// +checklocksignore +func (n *Namespace) StateSave(stateSinkObject state.Sink) { + n.beforeSave() + stateSinkObject.Save(0, &n.creator) + stateSinkObject.Save(1, &n.isRoot) +} + +// +checklocksignore +func (n *Namespace) StateLoad(stateSourceObject state.Source) { + stateSourceObject.LoadWait(0, &n.creator) + stateSourceObject.Load(1, &n.isRoot) + stateSourceObject.AfterLoad(n.afterLoad) +} + +func init() { + state.Register((*TCPBufferSize)(nil)) + state.Register((*Namespace)(nil)) +} diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD deleted file mode 100644 index 816e60329..000000000 --- a/pkg/sentry/kernel/BUILD +++ /dev/null @@ -1,318 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test", "proto_library") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -package(licenses = ["notice"]) - -go_template_instance( - name = "pending_signals_list", - out = "pending_signals_list.go", - package = "kernel", - prefix = "pendingSignal", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*pendingSignal", - "Linker": "*pendingSignal", - }, -) - -go_template_instance( - name = "process_group_list", - out = "process_group_list.go", - package = "kernel", - prefix = "processGroup", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*ProcessGroup", - "Linker": "*ProcessGroup", - }, -) - -go_template_instance( - name = "seqatomic_taskgoroutineschedinfo", - out = "seqatomic_taskgoroutineschedinfo_unsafe.go", - package = "kernel", - suffix = "TaskGoroutineSchedInfo", - template = "//pkg/sync/seqatomic:generic_seqatomic", - types = { - "Value": "TaskGoroutineSchedInfo", - }, -) - -go_template_instance( - name = "session_list", - out = "session_list.go", - package = "kernel", - prefix = "session", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*Session", - "Linker": "*Session", - }, -) - -go_template_instance( - name = "task_list", - out = "task_list.go", - package = "kernel", - prefix = "task", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*Task", - "Linker": "*Task", - }, -) - -go_template_instance( - name = "socket_list", - out = "socket_list.go", - package = "kernel", - prefix = "socket", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*SocketRecordVFS1", - "Linker": "*SocketRecordVFS1", - }, -) - -go_template_instance( - name = "fd_table_refs", - out = "fd_table_refs.go", - package = "kernel", - prefix = "FDTable", - template = "//pkg/refsvfs2:refs_template", - types = { - "T": "FDTable", - }, -) - -go_template_instance( - name = "fs_context_refs", - out = "fs_context_refs.go", - package = "kernel", - prefix = "FSContext", - template = "//pkg/refsvfs2:refs_template", - types = { - "T": "FSContext", - }, -) - -go_template_instance( - name = "ipc_namespace_refs", - out = "ipc_namespace_refs.go", - package = "kernel", - prefix = "IPCNamespace", - template = "//pkg/refsvfs2:refs_template", - types = { - "T": "IPCNamespace", - }, -) - -go_template_instance( - name = "process_group_refs", - out = "process_group_refs.go", - package = "kernel", - prefix = "ProcessGroup", - template = "//pkg/refsvfs2:refs_template", - types = { - "T": "ProcessGroup", - }, -) - -go_template_instance( - name = "session_refs", - out = "session_refs.go", - package = "kernel", - prefix = "Session", - template = "//pkg/refsvfs2:refs_template", - types = { - "T": "Session", - }, -) - -proto_library( - name = "uncaught_signal", - srcs = ["uncaught_signal.proto"], - visibility = ["//visibility:public"], - deps = ["//pkg/sentry/arch:registers_proto"], -) - -go_library( - name = "kernel", - srcs = [ - "abstract_socket_namespace.go", - "aio.go", - "cgroup.go", - "context.go", - "fd_table.go", - "fd_table_refs.go", - "fd_table_unsafe.go", - "fs_context.go", - "fs_context_refs.go", - "ipc_namespace.go", - "ipc_namespace_refs.go", - "kcov.go", - "kcov_unsafe.go", - "kernel.go", - "kernel_opts.go", - "kernel_state.go", - "pending_signals.go", - "pending_signals_list.go", - "pending_signals_state.go", - "posixtimer.go", - "process_group_list.go", - "process_group_refs.go", - "ptrace.go", - "ptrace_amd64.go", - "ptrace_arm64.go", - "rseq.go", - "seccomp.go", - "seqatomic_taskgoroutineschedinfo_unsafe.go", - "session_list.go", - "session_refs.go", - "sessions.go", - "signal.go", - "signal_handlers.go", - "socket_list.go", - "syscalls.go", - "syscalls_state.go", - "syslog.go", - "task.go", - "task_acct.go", - "task_block.go", - "task_cgroup.go", - "task_clone.go", - "task_context.go", - "task_exec.go", - "task_exit.go", - "task_futex.go", - "task_identity.go", - "task_image.go", - "task_list.go", - "task_log.go", - "task_net.go", - "task_run.go", - "task_sched.go", - "task_signals.go", - "task_start.go", - "task_stop.go", - "task_syscall.go", - "task_usermem.go", - "task_work.go", - "thread_group.go", - "threads.go", - "timekeeper.go", - "timekeeper_state.go", - "tty.go", - "uts_namespace.go", - "vdso.go", - "version.go", - ], - imports = [ - "gvisor.dev/gvisor/pkg/bpf", - "gvisor.dev/gvisor/pkg/sentry/device", - "gvisor.dev/gvisor/pkg/tcpip", - ], - marshal = True, - visibility = ["//:sandbox"], - deps = [ - ":uncaught_signal_go_proto", - "//pkg/abi", - "//pkg/abi/linux", - "//pkg/abi/linux/errno", - "//pkg/amutex", - "//pkg/bitmap", - "//pkg/bits", - "//pkg/bpf", - "//pkg/cleanup", - "//pkg/context", - "//pkg/coverage", - "//pkg/cpuid", - "//pkg/errors", - "//pkg/errors/linuxerr", - "//pkg/eventchannel", - "//pkg/fspath", - "//pkg/goid", - "//pkg/hostarch", - "//pkg/log", - "//pkg/marshal", - "//pkg/marshal/primitive", - "//pkg/metric", - "//pkg/refs", - "//pkg/refsvfs2", - "//pkg/safemem", - "//pkg/secio", - "//pkg/sentry/arch", - "//pkg/sentry/device", - "//pkg/sentry/fs", - "//pkg/sentry/fs/lock", - "//pkg/sentry/fs/timerfd", - "//pkg/sentry/fsbridge", - "//pkg/sentry/fsimpl/kernfs", - "//pkg/sentry/fsimpl/pipefs", - "//pkg/sentry/fsimpl/sockfs", - "//pkg/sentry/fsimpl/timerfd", - "//pkg/sentry/fsimpl/tmpfs", - "//pkg/sentry/hostcpu", - "//pkg/sentry/inet", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/kernel/epoll", - "//pkg/sentry/kernel/futex", - "//pkg/sentry/kernel/msgqueue", - "//pkg/sentry/kernel/sched", - "//pkg/sentry/kernel/semaphore", - "//pkg/sentry/kernel/shm", - "//pkg/sentry/kernel/time", - "//pkg/sentry/limits", - "//pkg/sentry/loader", - "//pkg/sentry/memmap", - "//pkg/sentry/mm", - "//pkg/sentry/pgalloc", - "//pkg/sentry/platform", - "//pkg/sentry/socket/netlink/port", - "//pkg/sentry/socket/unix/transport", - "//pkg/sentry/time", - "//pkg/sentry/unimpl", - "//pkg/sentry/unimpl:unimplemented_syscall_go_proto", - "//pkg/sentry/uniqueid", - "//pkg/sentry/usage", - "//pkg/sentry/vfs", - "//pkg/state", - "//pkg/state/statefile", - "//pkg/state/wire", - "//pkg/sync", - "//pkg/syserr", - "//pkg/tcpip", - "//pkg/tcpip/stack", - "//pkg/usermem", - "//pkg/waiter", - "@org_golang_x_sys//unix:go_default_library", - ], -) - -go_test( - name = "kernel_test", - size = "small", - srcs = [ - "fd_table_test.go", - "table_test.go", - "task_test.go", - "timekeeper_test.go", - ], - library = ":kernel", - deps = [ - "//pkg/abi", - "//pkg/context", - "//pkg/errors/linuxerr", - "//pkg/hostarch", - "//pkg/sentry/arch", - "//pkg/sentry/contexttest", - "//pkg/sentry/fs", - "//pkg/sentry/fs/filetest", - "//pkg/sentry/kernel/sched", - "//pkg/sentry/limits", - "//pkg/sentry/pgalloc", - "//pkg/sentry/time", - "//pkg/sentry/usage", - "//pkg/sync", - ], -) diff --git a/pkg/sentry/kernel/README.md b/pkg/sentry/kernel/README.md deleted file mode 100644 index 427311be8..000000000 --- a/pkg/sentry/kernel/README.md +++ /dev/null @@ -1,108 +0,0 @@ -This package contains: - -- A (partial) emulation of the "core Linux kernel", which governs task - execution and scheduling, system call dispatch, and signal handling. See - below for details. - -- The top-level interface for the sentry's Linux kernel emulation in general, - used by the `main` function of all versions of the sentry. This interface - revolves around the `Env` type (defined in `kernel.go`). - -# Background - -In Linux, each schedulable context is referred to interchangeably as a "task" or -"thread". Tasks can be divided into userspace and kernel tasks. In the sentry, -scheduling is managed by the Go runtime, so each schedulable context is a -goroutine; only "userspace" (application) contexts are referred to as tasks, and -represented by Task objects. (From this point forward, "task" refers to the -sentry's notion of a task unless otherwise specified.) - -At a high level, Linux application threads can be thought of as repeating a "run -loop": - -- Some amount of application code is executed in userspace. - -- A trap (explicit syscall invocation, hardware interrupt or exception, etc.) - causes control flow to switch to the kernel. - -- Some amount of kernel code is executed in kernelspace, e.g. to handle the - cause of the trap. - -- The kernel "returns from the trap" into application code. - -Analogously, each task in the sentry is associated with a *task goroutine* that -executes that task's run loop (`Task.run` in `task_run.go`). However, the -sentry's task run loop differs in structure in order to support saving execution -state to, and resuming execution from, checkpoints. - -While in kernelspace, a Linux thread can be descheduled (cease execution) in a -variety of ways: - -- It can yield or be preempted, becoming temporarily descheduled but still - runnable. At present, the sentry delegates scheduling of runnable threads to - the Go runtime. - -- It can exit, becoming permanently descheduled. The sentry's equivalent is - returning from `Task.run`, terminating the task goroutine. - -- It can enter interruptible sleep, a state in which it can be woken by a - caller-defined wakeup or the receipt of a signal. In the sentry, - interruptible sleep (which is ambiguously referred to as *blocking*) is - implemented by making all events that can end blocking (including signal - notifications) communicated via Go channels and using `select` to multiplex - wakeup sources; see `task_block.go`. - -- It can enter uninterruptible sleep, a state in which it can only be woken by - a caller-defined wakeup. Killable sleep is a closely related variant in - which the task can also be woken by SIGKILL. (These definitions also include - Linux's "group-stopped" (`TASK_STOPPED`) and "ptrace-stopped" - (`TASK_TRACED`) states.) - -To maximize compatibility with Linux, sentry checkpointing appears as a spurious -signal-delivery interrupt on all tasks; interrupted system calls return `EINTR` -or are automatically restarted as usual. However, these semantics require that -uninterruptible and killable sleeps do not appear to be interrupted. In other -words, the state of the task, including its progress through the interrupted -operation, must be preserved by checkpointing. For many such sleeps, the wakeup -condition is application-controlled, making it infeasible to wait for the sleep -to end before checkpointing. Instead, we must support checkpointing progress -through sleeping operations. - -# Implementation - -We break the task's control flow graph into *states*, delimited by: - -1. Points where uninterruptible and killable sleeps may occur. For example, - there exists a state boundary between signal dequeueing and signal delivery - because there may be an intervening ptrace signal-delivery-stop. - -2. Points where sleep-induced branches may "rejoin" normal execution. For - example, the syscall exit state exists because it can be reached immediately - following a synchronous syscall, or after a task that is sleeping in - `execve()` or `vfork()` resumes execution. - -3. Points containing large branches. This is strictly for organizational - purposes. For example, the state that processes interrupt-signaled - conditions is kept separate from the main "app" state to reduce the size of - the latter. - -4. `SyscallReinvoke`, which does not correspond to anything in Linux, and - exists solely to serve the autosave feature. - -![dot -Tpng -Goverlap=false -orun_states.png run_states.dot](g3doc/run_states.png "Task control flow graph") - -States before which a stop may occur are represented as implementations of the -`taskRunState` interface named `run(state)`, allowing them to be saved and -restored. States that cannot be immediately preceded by a stop are simply `Task` -methods named `do(state)`. - -Conditions that can require task goroutines to cease execution for unknown -lengths of time are called *stops*. Stops are divided into *internal stops*, -which are stops whose start and end conditions are implemented within the -sentry, and *external stops*, which are stops whose start and end conditions are -not known to the sentry. Hence all uninterruptible and killable sleeps are -internal stops, and the existence of a pending checkpoint operation is an -external stop. Internal stops are reified into instances of the `TaskStop` type, -while external stops are merely counted. The task run loop alternates between -checking for stops and advancing the task's state. This allows checkpointing to -hold tasks in a stopped state while waiting for all tasks in the system to stop. diff --git a/pkg/sentry/kernel/auth/BUILD b/pkg/sentry/kernel/auth/BUILD deleted file mode 100644 index 9aa03f506..000000000 --- a/pkg/sentry/kernel/auth/BUILD +++ /dev/null @@ -1,70 +0,0 @@ -load("//tools:defs.bzl", "go_library") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -package(licenses = ["notice"]) - -go_template_instance( - name = "atomicptr_credentials", - out = "atomicptr_credentials_unsafe.go", - package = "auth", - suffix = "Credentials", - template = "//pkg/sync/atomicptr:generic_atomicptr", - types = { - "Value": "Credentials", - }, -) - -go_template_instance( - name = "id_map_range", - out = "id_map_range.go", - package = "auth", - prefix = "idMap", - template = "//pkg/segment:generic_range", - types = { - "T": "uint32", - }, -) - -go_template_instance( - name = "id_map_set", - out = "id_map_set.go", - consts = { - "minDegree": "3", - }, - package = "auth", - prefix = "idMap", - template = "//pkg/segment:generic_set", - types = { - "Key": "uint32", - "Range": "idMapRange", - "Value": "uint32", - "Functions": "idMapFunctions", - }, -) - -go_library( - name = "auth", - srcs = [ - "atomicptr_credentials_unsafe.go", - "auth.go", - "capability_set.go", - "context.go", - "credentials.go", - "id.go", - "id_map.go", - "id_map_functions.go", - "id_map_range.go", - "id_map_set.go", - "user_namespace.go", - ], - marshal = True, - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/bits", - "//pkg/context", - "//pkg/errors/linuxerr", - "//pkg/log", - "//pkg/sync", - ], -) diff --git a/pkg/sync/atomicptr/generic_atomicptr_unsafe.go b/pkg/sentry/kernel/auth/atomicptr_credentials_unsafe.go index 82b6df18c..4535c958f 100644 --- a/pkg/sync/atomicptr/generic_atomicptr_unsafe.go +++ b/pkg/sentry/kernel/auth/atomicptr_credentials_unsafe.go @@ -1,20 +1,10 @@ -// Copyright 2019 The gVisor Authors. -// -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// Package seqatomic doesn't exist. This file must be instantiated using the -// go_template_instance rule in tools/go_generics/defs.bzl. -package seqatomic +package auth import ( "sync/atomic" "unsafe" ) -// Value is a required type parameter. -type Value struct{} - // An AtomicPtr is a pointer to a value of type Value that can be atomically // loaded and stored. The zero value of an AtomicPtr represents nil. // @@ -23,25 +13,25 @@ type Value struct{} // this case, do `dst.Store(src.Load())` instead. // // +stateify savable -type AtomicPtr struct { - ptr unsafe.Pointer `state:".(*Value)"` +type AtomicPtrCredentials struct { + ptr unsafe.Pointer `state:".(*Credentials)"` } -func (p *AtomicPtr) savePtr() *Value { +func (p *AtomicPtrCredentials) savePtr() *Credentials { return p.Load() } -func (p *AtomicPtr) loadPtr(v *Value) { +func (p *AtomicPtrCredentials) loadPtr(v *Credentials) { p.Store(v) } // Load returns the value set by the most recent Store. It returns nil if there // has been no previous call to Store. -func (p *AtomicPtr) Load() *Value { - return (*Value)(atomic.LoadPointer(&p.ptr)) +func (p *AtomicPtrCredentials) Load() *Credentials { + return (*Credentials)(atomic.LoadPointer(&p.ptr)) } // Store sets the value returned by Load to x. -func (p *AtomicPtr) Store(x *Value) { +func (p *AtomicPtrCredentials) Store(x *Credentials) { atomic.StorePointer(&p.ptr, (unsafe.Pointer)(x)) } diff --git a/pkg/sentry/kernel/auth/auth_abi_autogen_unsafe.go b/pkg/sentry/kernel/auth/auth_abi_autogen_unsafe.go new file mode 100644 index 000000000..393dcea8c --- /dev/null +++ b/pkg/sentry/kernel/auth/auth_abi_autogen_unsafe.go @@ -0,0 +1,274 @@ +// Automatically generated marshal implementation. See tools/go_marshal. + +package auth + +import ( + "gvisor.dev/gvisor/pkg/gohacks" + "gvisor.dev/gvisor/pkg/hostarch" + "gvisor.dev/gvisor/pkg/marshal" + "io" + "reflect" + "runtime" + "unsafe" +) + +// Marshallable types used by this file. +var _ marshal.Marshallable = (*GID)(nil) +var _ marshal.Marshallable = (*UID)(nil) + +// SizeBytes implements marshal.Marshallable.SizeBytes. +//go:nosplit +func (gid *GID) SizeBytes() int { + return 4 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (gid *GID) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint32(dst[:4], uint32(*gid)) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (gid *GID) UnmarshalBytes(src []byte) { + *gid = GID(uint32(hostarch.ByteOrder.Uint32(src[:4]))) +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (gid *GID) Packed() bool { + // Scalar newtypes are always packed. + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (gid *GID) MarshalUnsafe(dst []byte) { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(gid), uintptr(gid.SizeBytes())) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (gid *GID) UnmarshalUnsafe(src []byte) { + gohacks.Memmove(unsafe.Pointer(gid), unsafe.Pointer(&src[0]), uintptr(gid.SizeBytes())) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (gid *GID) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(gid))) + hdr.Len = gid.SizeBytes() + hdr.Cap = gid.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that gid + // must live until the use above. + runtime.KeepAlive(gid) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (gid *GID) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return gid.CopyOutN(cc, addr, gid.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (gid *GID) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(gid))) + hdr.Len = gid.SizeBytes() + hdr.Cap = gid.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that gid + // must live until the use above. + runtime.KeepAlive(gid) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (gid *GID) WriteTo(w io.Writer) (int64, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(gid))) + hdr.Len = gid.SizeBytes() + hdr.Cap = gid.SizeBytes() + + length, err := w.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that gid + // must live until the use above. + runtime.KeepAlive(gid) // escapes: replaced by intrinsic. + return int64(length), err +} + +// CopyGIDSliceIn copies in a slice of GID objects from the task's memory. +//go:nosplit +func CopyGIDSliceIn(cc marshal.CopyContext, addr hostarch.Addr, dst []GID) (int, error) { + count := len(dst) + if count == 0 { + return 0, nil + } + size := (*GID)(nil).SizeBytes() + + ptr := unsafe.Pointer(&dst) + val := gohacks.Noescape(unsafe.Pointer((*reflect.SliceHeader)(ptr).Data)) + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(val) + hdr.Len = size * count + hdr.Cap = size * count + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that dst + // must live until the use above. + runtime.KeepAlive(dst) // escapes: replaced by intrinsic. + return length, err +} + +// CopyGIDSliceOut copies a slice of GID objects to the task's memory. +//go:nosplit +func CopyGIDSliceOut(cc marshal.CopyContext, addr hostarch.Addr, src []GID) (int, error) { + count := len(src) + if count == 0 { + return 0, nil + } + size := (*GID)(nil).SizeBytes() + + ptr := unsafe.Pointer(&src) + val := gohacks.Noescape(unsafe.Pointer((*reflect.SliceHeader)(ptr).Data)) + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(val) + hdr.Len = size * count + hdr.Cap = size * count + + length, err := cc.CopyOutBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that src + // must live until the use above. + runtime.KeepAlive(src) // escapes: replaced by intrinsic. + return length, err +} + +// MarshalUnsafeGIDSlice is like GID.MarshalUnsafe, but for a []GID. +func MarshalUnsafeGIDSlice(src []GID, dst []byte) (int, error) { + count := len(src) + if count == 0 { + return 0, nil + } + size := (*GID)(nil).SizeBytes() + + dst = dst[:size*count] + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(&src[0]), uintptr(len(dst))) + return size*count, nil +} + +// UnmarshalUnsafeGIDSlice is like GID.UnmarshalUnsafe, but for a []GID. +func UnmarshalUnsafeGIDSlice(dst []GID, src []byte) (int, error) { + count := len(dst) + if count == 0 { + return 0, nil + } + size := (*GID)(nil).SizeBytes() + + src = src[:(size*count)] + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(&src[0]), uintptr(len(src))) + return size*count, nil +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +//go:nosplit +func (uid *UID) SizeBytes() int { + return 4 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (uid *UID) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint32(dst[:4], uint32(*uid)) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (uid *UID) UnmarshalBytes(src []byte) { + *uid = UID(uint32(hostarch.ByteOrder.Uint32(src[:4]))) +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (uid *UID) Packed() bool { + // Scalar newtypes are always packed. + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (uid *UID) MarshalUnsafe(dst []byte) { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(uid), uintptr(uid.SizeBytes())) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (uid *UID) UnmarshalUnsafe(src []byte) { + gohacks.Memmove(unsafe.Pointer(uid), unsafe.Pointer(&src[0]), uintptr(uid.SizeBytes())) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (uid *UID) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(uid))) + hdr.Len = uid.SizeBytes() + hdr.Cap = uid.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that uid + // must live until the use above. + runtime.KeepAlive(uid) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (uid *UID) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return uid.CopyOutN(cc, addr, uid.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (uid *UID) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(uid))) + hdr.Len = uid.SizeBytes() + hdr.Cap = uid.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that uid + // must live until the use above. + runtime.KeepAlive(uid) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (uid *UID) WriteTo(w io.Writer) (int64, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(uid))) + hdr.Len = uid.SizeBytes() + hdr.Cap = uid.SizeBytes() + + length, err := w.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that uid + // must live until the use above. + runtime.KeepAlive(uid) // escapes: replaced by intrinsic. + return int64(length), err +} + diff --git a/pkg/sentry/kernel/auth/auth_state_autogen.go b/pkg/sentry/kernel/auth/auth_state_autogen.go new file mode 100644 index 000000000..dea316420 --- /dev/null +++ b/pkg/sentry/kernel/auth/auth_state_autogen.go @@ -0,0 +1,281 @@ +// automatically generated by stateify. + +package auth + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (c *Credentials) StateTypeName() string { + return "pkg/sentry/kernel/auth.Credentials" +} + +func (c *Credentials) StateFields() []string { + return []string{ + "RealKUID", + "EffectiveKUID", + "SavedKUID", + "RealKGID", + "EffectiveKGID", + "SavedKGID", + "ExtraKGIDs", + "PermittedCaps", + "InheritableCaps", + "EffectiveCaps", + "BoundingCaps", + "KeepCaps", + "UserNamespace", + } +} + +func (c *Credentials) beforeSave() {} + +// +checklocksignore +func (c *Credentials) StateSave(stateSinkObject state.Sink) { + c.beforeSave() + stateSinkObject.Save(0, &c.RealKUID) + stateSinkObject.Save(1, &c.EffectiveKUID) + stateSinkObject.Save(2, &c.SavedKUID) + stateSinkObject.Save(3, &c.RealKGID) + stateSinkObject.Save(4, &c.EffectiveKGID) + stateSinkObject.Save(5, &c.SavedKGID) + stateSinkObject.Save(6, &c.ExtraKGIDs) + stateSinkObject.Save(7, &c.PermittedCaps) + stateSinkObject.Save(8, &c.InheritableCaps) + stateSinkObject.Save(9, &c.EffectiveCaps) + stateSinkObject.Save(10, &c.BoundingCaps) + stateSinkObject.Save(11, &c.KeepCaps) + stateSinkObject.Save(12, &c.UserNamespace) +} + +func (c *Credentials) afterLoad() {} + +// +checklocksignore +func (c *Credentials) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &c.RealKUID) + stateSourceObject.Load(1, &c.EffectiveKUID) + stateSourceObject.Load(2, &c.SavedKUID) + stateSourceObject.Load(3, &c.RealKGID) + stateSourceObject.Load(4, &c.EffectiveKGID) + stateSourceObject.Load(5, &c.SavedKGID) + stateSourceObject.Load(6, &c.ExtraKGIDs) + stateSourceObject.Load(7, &c.PermittedCaps) + stateSourceObject.Load(8, &c.InheritableCaps) + stateSourceObject.Load(9, &c.EffectiveCaps) + stateSourceObject.Load(10, &c.BoundingCaps) + stateSourceObject.Load(11, &c.KeepCaps) + stateSourceObject.Load(12, &c.UserNamespace) +} + +func (i *IDMapEntry) StateTypeName() string { + return "pkg/sentry/kernel/auth.IDMapEntry" +} + +func (i *IDMapEntry) StateFields() []string { + return []string{ + "FirstID", + "FirstParentID", + "Length", + } +} + +func (i *IDMapEntry) beforeSave() {} + +// +checklocksignore +func (i *IDMapEntry) StateSave(stateSinkObject state.Sink) { + i.beforeSave() + stateSinkObject.Save(0, &i.FirstID) + stateSinkObject.Save(1, &i.FirstParentID) + stateSinkObject.Save(2, &i.Length) +} + +func (i *IDMapEntry) afterLoad() {} + +// +checklocksignore +func (i *IDMapEntry) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &i.FirstID) + stateSourceObject.Load(1, &i.FirstParentID) + stateSourceObject.Load(2, &i.Length) +} + +func (r *idMapRange) StateTypeName() string { + return "pkg/sentry/kernel/auth.idMapRange" +} + +func (r *idMapRange) StateFields() []string { + return []string{ + "Start", + "End", + } +} + +func (r *idMapRange) beforeSave() {} + +// +checklocksignore +func (r *idMapRange) StateSave(stateSinkObject state.Sink) { + r.beforeSave() + stateSinkObject.Save(0, &r.Start) + stateSinkObject.Save(1, &r.End) +} + +func (r *idMapRange) afterLoad() {} + +// +checklocksignore +func (r *idMapRange) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &r.Start) + stateSourceObject.Load(1, &r.End) +} + +func (s *idMapSet) StateTypeName() string { + return "pkg/sentry/kernel/auth.idMapSet" +} + +func (s *idMapSet) StateFields() []string { + return []string{ + "root", + } +} + +func (s *idMapSet) beforeSave() {} + +// +checklocksignore +func (s *idMapSet) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + var rootValue *idMapSegmentDataSlices + rootValue = s.saveRoot() + stateSinkObject.SaveValue(0, rootValue) +} + +func (s *idMapSet) afterLoad() {} + +// +checklocksignore +func (s *idMapSet) StateLoad(stateSourceObject state.Source) { + stateSourceObject.LoadValue(0, new(*idMapSegmentDataSlices), func(y interface{}) { s.loadRoot(y.(*idMapSegmentDataSlices)) }) +} + +func (n *idMapnode) StateTypeName() string { + return "pkg/sentry/kernel/auth.idMapnode" +} + +func (n *idMapnode) StateFields() []string { + return []string{ + "nrSegments", + "parent", + "parentIndex", + "hasChildren", + "maxGap", + "keys", + "values", + "children", + } +} + +func (n *idMapnode) beforeSave() {} + +// +checklocksignore +func (n *idMapnode) StateSave(stateSinkObject state.Sink) { + n.beforeSave() + stateSinkObject.Save(0, &n.nrSegments) + stateSinkObject.Save(1, &n.parent) + stateSinkObject.Save(2, &n.parentIndex) + stateSinkObject.Save(3, &n.hasChildren) + stateSinkObject.Save(4, &n.maxGap) + stateSinkObject.Save(5, &n.keys) + stateSinkObject.Save(6, &n.values) + stateSinkObject.Save(7, &n.children) +} + +func (n *idMapnode) afterLoad() {} + +// +checklocksignore +func (n *idMapnode) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &n.nrSegments) + stateSourceObject.Load(1, &n.parent) + stateSourceObject.Load(2, &n.parentIndex) + stateSourceObject.Load(3, &n.hasChildren) + stateSourceObject.Load(4, &n.maxGap) + stateSourceObject.Load(5, &n.keys) + stateSourceObject.Load(6, &n.values) + stateSourceObject.Load(7, &n.children) +} + +func (i *idMapSegmentDataSlices) StateTypeName() string { + return "pkg/sentry/kernel/auth.idMapSegmentDataSlices" +} + +func (i *idMapSegmentDataSlices) StateFields() []string { + return []string{ + "Start", + "End", + "Values", + } +} + +func (i *idMapSegmentDataSlices) beforeSave() {} + +// +checklocksignore +func (i *idMapSegmentDataSlices) StateSave(stateSinkObject state.Sink) { + i.beforeSave() + stateSinkObject.Save(0, &i.Start) + stateSinkObject.Save(1, &i.End) + stateSinkObject.Save(2, &i.Values) +} + +func (i *idMapSegmentDataSlices) afterLoad() {} + +// +checklocksignore +func (i *idMapSegmentDataSlices) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &i.Start) + stateSourceObject.Load(1, &i.End) + stateSourceObject.Load(2, &i.Values) +} + +func (ns *UserNamespace) StateTypeName() string { + return "pkg/sentry/kernel/auth.UserNamespace" +} + +func (ns *UserNamespace) StateFields() []string { + return []string{ + "parent", + "owner", + "uidMapFromParent", + "uidMapToParent", + "gidMapFromParent", + "gidMapToParent", + } +} + +func (ns *UserNamespace) beforeSave() {} + +// +checklocksignore +func (ns *UserNamespace) StateSave(stateSinkObject state.Sink) { + ns.beforeSave() + stateSinkObject.Save(0, &ns.parent) + stateSinkObject.Save(1, &ns.owner) + stateSinkObject.Save(2, &ns.uidMapFromParent) + stateSinkObject.Save(3, &ns.uidMapToParent) + stateSinkObject.Save(4, &ns.gidMapFromParent) + stateSinkObject.Save(5, &ns.gidMapToParent) +} + +func (ns *UserNamespace) afterLoad() {} + +// +checklocksignore +func (ns *UserNamespace) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &ns.parent) + stateSourceObject.Load(1, &ns.owner) + stateSourceObject.Load(2, &ns.uidMapFromParent) + stateSourceObject.Load(3, &ns.uidMapToParent) + stateSourceObject.Load(4, &ns.gidMapFromParent) + stateSourceObject.Load(5, &ns.gidMapToParent) +} + +func init() { + state.Register((*Credentials)(nil)) + state.Register((*IDMapEntry)(nil)) + state.Register((*idMapRange)(nil)) + state.Register((*idMapSet)(nil)) + state.Register((*idMapnode)(nil)) + state.Register((*idMapSegmentDataSlices)(nil)) + state.Register((*UserNamespace)(nil)) +} diff --git a/pkg/sentry/kernel/auth/auth_unsafe_abi_autogen_unsafe.go b/pkg/sentry/kernel/auth/auth_unsafe_abi_autogen_unsafe.go new file mode 100644 index 000000000..ebcd3911b --- /dev/null +++ b/pkg/sentry/kernel/auth/auth_unsafe_abi_autogen_unsafe.go @@ -0,0 +1,7 @@ +// Automatically generated marshal implementation. See tools/go_marshal. + +package auth + +import ( +) + diff --git a/pkg/sentry/kernel/auth/auth_unsafe_state_autogen.go b/pkg/sentry/kernel/auth/auth_unsafe_state_autogen.go new file mode 100644 index 000000000..28674a087 --- /dev/null +++ b/pkg/sentry/kernel/auth/auth_unsafe_state_autogen.go @@ -0,0 +1,38 @@ +// automatically generated by stateify. + +package auth + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (p *AtomicPtrCredentials) StateTypeName() string { + return "pkg/sentry/kernel/auth.AtomicPtrCredentials" +} + +func (p *AtomicPtrCredentials) StateFields() []string { + return []string{ + "ptr", + } +} + +func (p *AtomicPtrCredentials) beforeSave() {} + +// +checklocksignore +func (p *AtomicPtrCredentials) StateSave(stateSinkObject state.Sink) { + p.beforeSave() + var ptrValue *Credentials + ptrValue = p.savePtr() + stateSinkObject.SaveValue(0, ptrValue) +} + +func (p *AtomicPtrCredentials) afterLoad() {} + +// +checklocksignore +func (p *AtomicPtrCredentials) StateLoad(stateSourceObject state.Source) { + stateSourceObject.LoadValue(0, new(*Credentials), func(y interface{}) { p.loadPtr(y.(*Credentials)) }) +} + +func init() { + state.Register((*AtomicPtrCredentials)(nil)) +} diff --git a/pkg/sentry/kernel/auth/id_map_range.go b/pkg/sentry/kernel/auth/id_map_range.go new file mode 100644 index 000000000..19d542716 --- /dev/null +++ b/pkg/sentry/kernel/auth/id_map_range.go @@ -0,0 +1,76 @@ +package auth + +// A Range represents a contiguous range of T. +// +// +stateify savable +type idMapRange struct { + // Start is the inclusive start of the range. + Start uint32 + + // End is the exclusive end of the range. + End uint32 +} + +// WellFormed returns true if r.Start <= r.End. All other methods on a Range +// require that the Range is well-formed. +// +//go:nosplit +func (r idMapRange) WellFormed() bool { + return r.Start <= r.End +} + +// Length returns the length of the range. +// +//go:nosplit +func (r idMapRange) Length() uint32 { + return r.End - r.Start +} + +// Contains returns true if r contains x. +// +//go:nosplit +func (r idMapRange) Contains(x uint32) bool { + return r.Start <= x && x < r.End +} + +// Overlaps returns true if r and r2 overlap. +// +//go:nosplit +func (r idMapRange) Overlaps(r2 idMapRange) bool { + return r.Start < r2.End && r2.Start < r.End +} + +// IsSupersetOf returns true if r is a superset of r2; that is, the range r2 is +// contained within r. +// +//go:nosplit +func (r idMapRange) IsSupersetOf(r2 idMapRange) bool { + return r.Start <= r2.Start && r.End >= r2.End +} + +// Intersect returns a range consisting of the intersection between r and r2. +// If r and r2 do not overlap, Intersect returns a range with unspecified +// bounds, but for which Length() == 0. +// +//go:nosplit +func (r idMapRange) Intersect(r2 idMapRange) idMapRange { + if r.Start < r2.Start { + r.Start = r2.Start + } + if r.End > r2.End { + r.End = r2.End + } + if r.End < r.Start { + r.End = r.Start + } + return r +} + +// CanSplitAt returns true if it is legal to split a segment spanning the range +// r at x; that is, splitting at x would produce two ranges, both of which have +// non-zero length. +// +//go:nosplit +func (r idMapRange) CanSplitAt(x uint32) bool { + return r.Contains(x) && r.Start < x +} diff --git a/pkg/sentry/kernel/auth/id_map_set.go b/pkg/sentry/kernel/auth/id_map_set.go new file mode 100644 index 000000000..479753981 --- /dev/null +++ b/pkg/sentry/kernel/auth/id_map_set.go @@ -0,0 +1,1643 @@ +package auth + +import ( + "bytes" + "fmt" +) + +// trackGaps is an optional parameter. +// +// If trackGaps is 1, the Set will track maximum gap size recursively, +// enabling the GapIterator.{Prev,Next}LargeEnoughGap functions. In this +// case, Key must be an unsigned integer. +// +// trackGaps must be 0 or 1. +const idMaptrackGaps = 0 + +var _ = uint8(idMaptrackGaps << 7) // Will fail if not zero or one. + +// dynamicGap is a type that disappears if trackGaps is 0. +type idMapdynamicGap [idMaptrackGaps]uint32 + +// Get returns the value of the gap. +// +// Precondition: trackGaps must be non-zero. +func (d *idMapdynamicGap) Get() uint32 { + return d[:][0] +} + +// Set sets the value of the gap. +// +// Precondition: trackGaps must be non-zero. +func (d *idMapdynamicGap) Set(v uint32) { + d[:][0] = v +} + +const ( + // minDegree is the minimum degree of an internal node in a Set B-tree. + // + // - Any non-root node has at least minDegree-1 segments. + // + // - Any non-root internal (non-leaf) node has at least minDegree children. + // + // - The root node may have fewer than minDegree-1 segments, but it may + // only have 0 segments if the tree is empty. + // + // Our implementation requires minDegree >= 3. Higher values of minDegree + // usually improve performance, but increase memory usage for small sets. + idMapminDegree = 3 + + idMapmaxDegree = 2 * idMapminDegree +) + +// A Set is a mapping of segments with non-overlapping Range keys. The zero +// value for a Set is an empty set. Set values are not safely movable nor +// copyable. Set is thread-compatible. +// +// +stateify savable +type idMapSet struct { + root idMapnode `state:".(*idMapSegmentDataSlices)"` +} + +// IsEmpty returns true if the set contains no segments. +func (s *idMapSet) IsEmpty() bool { + return s.root.nrSegments == 0 +} + +// IsEmptyRange returns true iff no segments in the set overlap the given +// range. This is semantically equivalent to s.SpanRange(r) == 0, but may be +// more efficient. +func (s *idMapSet) IsEmptyRange(r idMapRange) bool { + switch { + case r.Length() < 0: + panic(fmt.Sprintf("invalid range %v", r)) + case r.Length() == 0: + return true + } + _, gap := s.Find(r.Start) + if !gap.Ok() { + return false + } + return r.End <= gap.End() +} + +// Span returns the total size of all segments in the set. +func (s *idMapSet) Span() uint32 { + var sz uint32 + for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + sz += seg.Range().Length() + } + return sz +} + +// SpanRange returns the total size of the intersection of segments in the set +// with the given range. +func (s *idMapSet) SpanRange(r idMapRange) uint32 { + switch { + case r.Length() < 0: + panic(fmt.Sprintf("invalid range %v", r)) + case r.Length() == 0: + return 0 + } + var sz uint32 + for seg := s.LowerBoundSegment(r.Start); seg.Ok() && seg.Start() < r.End; seg = seg.NextSegment() { + sz += seg.Range().Intersect(r).Length() + } + return sz +} + +// FirstSegment returns the first segment in the set. If the set is empty, +// FirstSegment returns a terminal iterator. +func (s *idMapSet) FirstSegment() idMapIterator { + if s.root.nrSegments == 0 { + return idMapIterator{} + } + return s.root.firstSegment() +} + +// LastSegment returns the last segment in the set. If the set is empty, +// LastSegment returns a terminal iterator. +func (s *idMapSet) LastSegment() idMapIterator { + if s.root.nrSegments == 0 { + return idMapIterator{} + } + return s.root.lastSegment() +} + +// FirstGap returns the first gap in the set. +func (s *idMapSet) FirstGap() idMapGapIterator { + n := &s.root + for n.hasChildren { + n = n.children[0] + } + return idMapGapIterator{n, 0} +} + +// LastGap returns the last gap in the set. +func (s *idMapSet) LastGap() idMapGapIterator { + n := &s.root + for n.hasChildren { + n = n.children[n.nrSegments] + } + return idMapGapIterator{n, n.nrSegments} +} + +// Find returns the segment or gap whose range contains the given key. If a +// segment is found, the returned Iterator is non-terminal and the +// returned GapIterator is terminal. Otherwise, the returned Iterator is +// terminal and the returned GapIterator is non-terminal. +func (s *idMapSet) Find(key uint32) (idMapIterator, idMapGapIterator) { + n := &s.root + for { + + lower := 0 + upper := n.nrSegments + for lower < upper { + i := lower + (upper-lower)/2 + if r := n.keys[i]; key < r.End { + if key >= r.Start { + return idMapIterator{n, i}, idMapGapIterator{} + } + upper = i + } else { + lower = i + 1 + } + } + i := lower + if !n.hasChildren { + return idMapIterator{}, idMapGapIterator{n, i} + } + n = n.children[i] + } +} + +// FindSegment returns the segment whose range contains the given key. If no +// such segment exists, FindSegment returns a terminal iterator. +func (s *idMapSet) FindSegment(key uint32) idMapIterator { + seg, _ := s.Find(key) + return seg +} + +// LowerBoundSegment returns the segment with the lowest range that contains a +// key greater than or equal to min. If no such segment exists, +// LowerBoundSegment returns a terminal iterator. +func (s *idMapSet) LowerBoundSegment(min uint32) idMapIterator { + seg, gap := s.Find(min) + if seg.Ok() { + return seg + } + return gap.NextSegment() +} + +// UpperBoundSegment returns the segment with the highest range that contains a +// key less than or equal to max. If no such segment exists, UpperBoundSegment +// returns a terminal iterator. +func (s *idMapSet) UpperBoundSegment(max uint32) idMapIterator { + seg, gap := s.Find(max) + if seg.Ok() { + return seg + } + return gap.PrevSegment() +} + +// FindGap returns the gap containing the given key. If no such gap exists +// (i.e. the set contains a segment containing that key), FindGap returns a +// terminal iterator. +func (s *idMapSet) FindGap(key uint32) idMapGapIterator { + _, gap := s.Find(key) + return gap +} + +// LowerBoundGap returns the gap with the lowest range that is greater than or +// equal to min. +func (s *idMapSet) LowerBoundGap(min uint32) idMapGapIterator { + seg, gap := s.Find(min) + if gap.Ok() { + return gap + } + return seg.NextGap() +} + +// UpperBoundGap returns the gap with the highest range that is less than or +// equal to max. +func (s *idMapSet) UpperBoundGap(max uint32) idMapGapIterator { + seg, gap := s.Find(max) + if gap.Ok() { + return gap + } + return seg.PrevGap() +} + +// Add inserts the given segment into the set and returns true. If the new +// segment can be merged with adjacent segments, Add will do so. If the new +// segment would overlap an existing segment, Add returns false. If Add +// succeeds, all existing iterators are invalidated. +func (s *idMapSet) Add(r idMapRange, val uint32) bool { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + gap := s.FindGap(r.Start) + if !gap.Ok() { + return false + } + if r.End > gap.End() { + return false + } + s.Insert(gap, r, val) + return true +} + +// AddWithoutMerging inserts the given segment into the set and returns true. +// If it would overlap an existing segment, AddWithoutMerging does nothing and +// returns false. If AddWithoutMerging succeeds, all existing iterators are +// invalidated. +func (s *idMapSet) AddWithoutMerging(r idMapRange, val uint32) bool { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + gap := s.FindGap(r.Start) + if !gap.Ok() { + return false + } + if r.End > gap.End() { + return false + } + s.InsertWithoutMergingUnchecked(gap, r, val) + return true +} + +// Insert inserts the given segment into the given gap. If the new segment can +// be merged with adjacent segments, Insert will do so. Insert returns an +// iterator to the segment containing the inserted value (which may have been +// merged with other values). All existing iterators (including gap, but not +// including the returned iterator) are invalidated. +// +// If the gap cannot accommodate the segment, or if r is invalid, Insert panics. +// +// Insert is semantically equivalent to a InsertWithoutMerging followed by a +// Merge, but may be more efficient. Note that there is no unchecked variant of +// Insert since Insert must retrieve and inspect gap's predecessor and +// successor segments regardless. +func (s *idMapSet) Insert(gap idMapGapIterator, r idMapRange, val uint32) idMapIterator { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + prev, next := gap.PrevSegment(), gap.NextSegment() + if prev.Ok() && prev.End() > r.Start { + panic(fmt.Sprintf("new segment %v overlaps predecessor %v", r, prev.Range())) + } + if next.Ok() && next.Start() < r.End { + panic(fmt.Sprintf("new segment %v overlaps successor %v", r, next.Range())) + } + if prev.Ok() && prev.End() == r.Start { + if mval, ok := (idMapFunctions{}).Merge(prev.Range(), prev.Value(), r, val); ok { + shrinkMaxGap := idMaptrackGaps != 0 && gap.Range().Length() == gap.node.maxGap.Get() + prev.SetEndUnchecked(r.End) + prev.SetValue(mval) + if shrinkMaxGap { + gap.node.updateMaxGapLeaf() + } + if next.Ok() && next.Start() == r.End { + val = mval + if mval, ok := (idMapFunctions{}).Merge(prev.Range(), val, next.Range(), next.Value()); ok { + prev.SetEndUnchecked(next.End()) + prev.SetValue(mval) + return s.Remove(next).PrevSegment() + } + } + return prev + } + } + if next.Ok() && next.Start() == r.End { + if mval, ok := (idMapFunctions{}).Merge(r, val, next.Range(), next.Value()); ok { + shrinkMaxGap := idMaptrackGaps != 0 && gap.Range().Length() == gap.node.maxGap.Get() + next.SetStartUnchecked(r.Start) + next.SetValue(mval) + if shrinkMaxGap { + gap.node.updateMaxGapLeaf() + } + return next + } + } + + return s.InsertWithoutMergingUnchecked(gap, r, val) +} + +// InsertWithoutMerging inserts the given segment into the given gap and +// returns an iterator to the inserted segment. All existing iterators +// (including gap, but not including the returned iterator) are invalidated. +// +// If the gap cannot accommodate the segment, or if r is invalid, +// InsertWithoutMerging panics. +func (s *idMapSet) InsertWithoutMerging(gap idMapGapIterator, r idMapRange, val uint32) idMapIterator { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + if gr := gap.Range(); !gr.IsSupersetOf(r) { + panic(fmt.Sprintf("cannot insert segment range %v into gap range %v", r, gr)) + } + return s.InsertWithoutMergingUnchecked(gap, r, val) +} + +// InsertWithoutMergingUnchecked inserts the given segment into the given gap +// and returns an iterator to the inserted segment. All existing iterators +// (including gap, but not including the returned iterator) are invalidated. +// +// Preconditions: +// * r.Start >= gap.Start(). +// * r.End <= gap.End(). +func (s *idMapSet) InsertWithoutMergingUnchecked(gap idMapGapIterator, r idMapRange, val uint32) idMapIterator { + gap = gap.node.rebalanceBeforeInsert(gap) + splitMaxGap := idMaptrackGaps != 0 && (gap.node.nrSegments == 0 || gap.Range().Length() == gap.node.maxGap.Get()) + copy(gap.node.keys[gap.index+1:], gap.node.keys[gap.index:gap.node.nrSegments]) + copy(gap.node.values[gap.index+1:], gap.node.values[gap.index:gap.node.nrSegments]) + gap.node.keys[gap.index] = r + gap.node.values[gap.index] = val + gap.node.nrSegments++ + if splitMaxGap { + gap.node.updateMaxGapLeaf() + } + return idMapIterator{gap.node, gap.index} +} + +// Remove removes the given segment and returns an iterator to the vacated gap. +// All existing iterators (including seg, but not including the returned +// iterator) are invalidated. +func (s *idMapSet) Remove(seg idMapIterator) idMapGapIterator { + + if seg.node.hasChildren { + + victim := seg.PrevSegment() + + seg.SetRangeUnchecked(victim.Range()) + seg.SetValue(victim.Value()) + + nextAdjacentNode := seg.NextSegment().node + if idMaptrackGaps != 0 { + nextAdjacentNode.updateMaxGapLeaf() + } + return s.Remove(victim).NextGap() + } + copy(seg.node.keys[seg.index:], seg.node.keys[seg.index+1:seg.node.nrSegments]) + copy(seg.node.values[seg.index:], seg.node.values[seg.index+1:seg.node.nrSegments]) + idMapFunctions{}.ClearValue(&seg.node.values[seg.node.nrSegments-1]) + seg.node.nrSegments-- + if idMaptrackGaps != 0 { + seg.node.updateMaxGapLeaf() + } + return seg.node.rebalanceAfterRemove(idMapGapIterator{seg.node, seg.index}) +} + +// RemoveAll removes all segments from the set. All existing iterators are +// invalidated. +func (s *idMapSet) RemoveAll() { + s.root = idMapnode{} +} + +// RemoveRange removes all segments in the given range. An iterator to the +// newly formed gap is returned, and all existing iterators are invalidated. +func (s *idMapSet) RemoveRange(r idMapRange) idMapGapIterator { + seg, gap := s.Find(r.Start) + if seg.Ok() { + seg = s.Isolate(seg, r) + gap = s.Remove(seg) + } + for seg = gap.NextSegment(); seg.Ok() && seg.Start() < r.End; seg = gap.NextSegment() { + seg = s.Isolate(seg, r) + gap = s.Remove(seg) + } + return gap +} + +// Merge attempts to merge two neighboring segments. If successful, Merge +// returns an iterator to the merged segment, and all existing iterators are +// invalidated. Otherwise, Merge returns a terminal iterator. +// +// If first is not the predecessor of second, Merge panics. +func (s *idMapSet) Merge(first, second idMapIterator) idMapIterator { + if first.NextSegment() != second { + panic(fmt.Sprintf("attempt to merge non-neighboring segments %v, %v", first.Range(), second.Range())) + } + return s.MergeUnchecked(first, second) +} + +// MergeUnchecked attempts to merge two neighboring segments. If successful, +// MergeUnchecked returns an iterator to the merged segment, and all existing +// iterators are invalidated. Otherwise, MergeUnchecked returns a terminal +// iterator. +// +// Precondition: first is the predecessor of second: first.NextSegment() == +// second, first == second.PrevSegment(). +func (s *idMapSet) MergeUnchecked(first, second idMapIterator) idMapIterator { + if first.End() == second.Start() { + if mval, ok := (idMapFunctions{}).Merge(first.Range(), first.Value(), second.Range(), second.Value()); ok { + + first.SetEndUnchecked(second.End()) + first.SetValue(mval) + + return s.Remove(second).PrevSegment() + } + } + return idMapIterator{} +} + +// MergeAll attempts to merge all adjacent segments in the set. All existing +// iterators are invalidated. +func (s *idMapSet) MergeAll() { + seg := s.FirstSegment() + if !seg.Ok() { + return + } + next := seg.NextSegment() + for next.Ok() { + if mseg := s.MergeUnchecked(seg, next); mseg.Ok() { + seg, next = mseg, mseg.NextSegment() + } else { + seg, next = next, next.NextSegment() + } + } +} + +// MergeRange attempts to merge all adjacent segments that contain a key in the +// specific range. All existing iterators are invalidated. +func (s *idMapSet) MergeRange(r idMapRange) { + seg := s.LowerBoundSegment(r.Start) + if !seg.Ok() { + return + } + next := seg.NextSegment() + for next.Ok() && next.Range().Start < r.End { + if mseg := s.MergeUnchecked(seg, next); mseg.Ok() { + seg, next = mseg, mseg.NextSegment() + } else { + seg, next = next, next.NextSegment() + } + } +} + +// MergeAdjacent attempts to merge the segment containing r.Start with its +// predecessor, and the segment containing r.End-1 with its successor. +func (s *idMapSet) MergeAdjacent(r idMapRange) { + first := s.FindSegment(r.Start) + if first.Ok() { + if prev := first.PrevSegment(); prev.Ok() { + s.Merge(prev, first) + } + } + last := s.FindSegment(r.End - 1) + if last.Ok() { + if next := last.NextSegment(); next.Ok() { + s.Merge(last, next) + } + } +} + +// Split splits the given segment at the given key and returns iterators to the +// two resulting segments. All existing iterators (including seg, but not +// including the returned iterators) are invalidated. +// +// If the segment cannot be split at split (because split is at the start or +// end of the segment's range, so splitting would produce a segment with zero +// length, or because split falls outside the segment's range altogether), +// Split panics. +func (s *idMapSet) Split(seg idMapIterator, split uint32) (idMapIterator, idMapIterator) { + if !seg.Range().CanSplitAt(split) { + panic(fmt.Sprintf("can't split %v at %v", seg.Range(), split)) + } + return s.SplitUnchecked(seg, split) +} + +// SplitUnchecked splits the given segment at the given key and returns +// iterators to the two resulting segments. All existing iterators (including +// seg, but not including the returned iterators) are invalidated. +// +// Preconditions: seg.Start() < key < seg.End(). +func (s *idMapSet) SplitUnchecked(seg idMapIterator, split uint32) (idMapIterator, idMapIterator) { + val1, val2 := (idMapFunctions{}).Split(seg.Range(), seg.Value(), split) + end2 := seg.End() + seg.SetEndUnchecked(split) + seg.SetValue(val1) + seg2 := s.InsertWithoutMergingUnchecked(seg.NextGap(), idMapRange{split, end2}, val2) + + return seg2.PrevSegment(), seg2 +} + +// SplitAt splits the segment straddling split, if one exists. SplitAt returns +// true if a segment was split and false otherwise. If SplitAt splits a +// segment, all existing iterators are invalidated. +func (s *idMapSet) SplitAt(split uint32) bool { + if seg := s.FindSegment(split); seg.Ok() && seg.Range().CanSplitAt(split) { + s.SplitUnchecked(seg, split) + return true + } + return false +} + +// Isolate ensures that the given segment's range does not escape r by +// splitting at r.Start and r.End if necessary, and returns an updated iterator +// to the bounded segment. All existing iterators (including seg, but not +// including the returned iterators) are invalidated. +func (s *idMapSet) Isolate(seg idMapIterator, r idMapRange) idMapIterator { + if seg.Range().CanSplitAt(r.Start) { + _, seg = s.SplitUnchecked(seg, r.Start) + } + if seg.Range().CanSplitAt(r.End) { + seg, _ = s.SplitUnchecked(seg, r.End) + } + return seg +} + +// ApplyContiguous applies a function to a contiguous range of segments, +// splitting if necessary. The function is applied until the first gap is +// encountered, at which point the gap is returned. If the function is applied +// across the entire range, a terminal gap is returned. All existing iterators +// are invalidated. +// +// N.B. The Iterator must not be invalidated by the function. +func (s *idMapSet) ApplyContiguous(r idMapRange, fn func(seg idMapIterator)) idMapGapIterator { + seg, gap := s.Find(r.Start) + if !seg.Ok() { + return gap + } + for { + seg = s.Isolate(seg, r) + fn(seg) + if seg.End() >= r.End { + return idMapGapIterator{} + } + gap = seg.NextGap() + if !gap.IsEmpty() { + return gap + } + seg = gap.NextSegment() + if !seg.Ok() { + + return idMapGapIterator{} + } + } +} + +// +stateify savable +type idMapnode struct { + // An internal binary tree node looks like: + // + // K + // / \ + // Cl Cr + // + // where all keys in the subtree rooted by Cl (the left subtree) are less + // than K (the key of the parent node), and all keys in the subtree rooted + // by Cr (the right subtree) are greater than K. + // + // An internal B-tree node's indexes work out to look like: + // + // K0 K1 K2 ... Kn-1 + // / \/ \/ \ ... / \ + // C0 C1 C2 C3 ... Cn-1 Cn + // + // where n is nrSegments. + nrSegments int + + // parent is a pointer to this node's parent. If this node is root, parent + // is nil. + parent *idMapnode + + // parentIndex is the index of this node in parent.children. + parentIndex int + + // Flag for internal nodes that is technically redundant with "children[0] + // != nil", but is stored in the first cache line. "hasChildren" rather + // than "isLeaf" because false must be the correct value for an empty root. + hasChildren bool + + // The longest gap within this node. If the node is a leaf, it's simply the + // maximum gap among all the (nrSegments+1) gaps formed by its nrSegments keys + // including the 0th and nrSegments-th gap possibly shared with its upper-level + // nodes; if it's a non-leaf node, it's the max of all children's maxGap. + maxGap idMapdynamicGap + + // Nodes store keys and values in separate arrays to maximize locality in + // the common case (scanning keys for lookup). + keys [idMapmaxDegree - 1]idMapRange + values [idMapmaxDegree - 1]uint32 + children [idMapmaxDegree]*idMapnode +} + +// firstSegment returns the first segment in the subtree rooted by n. +// +// Preconditions: n.nrSegments != 0. +func (n *idMapnode) firstSegment() idMapIterator { + for n.hasChildren { + n = n.children[0] + } + return idMapIterator{n, 0} +} + +// lastSegment returns the last segment in the subtree rooted by n. +// +// Preconditions: n.nrSegments != 0. +func (n *idMapnode) lastSegment() idMapIterator { + for n.hasChildren { + n = n.children[n.nrSegments] + } + return idMapIterator{n, n.nrSegments - 1} +} + +func (n *idMapnode) prevSibling() *idMapnode { + if n.parent == nil || n.parentIndex == 0 { + return nil + } + return n.parent.children[n.parentIndex-1] +} + +func (n *idMapnode) nextSibling() *idMapnode { + if n.parent == nil || n.parentIndex == n.parent.nrSegments { + return nil + } + return n.parent.children[n.parentIndex+1] +} + +// rebalanceBeforeInsert splits n and its ancestors if they are full, as +// required for insertion, and returns an updated iterator to the position +// represented by gap. +func (n *idMapnode) rebalanceBeforeInsert(gap idMapGapIterator) idMapGapIterator { + if n.nrSegments < idMapmaxDegree-1 { + return gap + } + if n.parent != nil { + gap = n.parent.rebalanceBeforeInsert(gap) + } + if n.parent == nil { + + left := &idMapnode{ + nrSegments: idMapminDegree - 1, + parent: n, + parentIndex: 0, + hasChildren: n.hasChildren, + } + right := &idMapnode{ + nrSegments: idMapminDegree - 1, + parent: n, + parentIndex: 1, + hasChildren: n.hasChildren, + } + copy(left.keys[:idMapminDegree-1], n.keys[:idMapminDegree-1]) + copy(left.values[:idMapminDegree-1], n.values[:idMapminDegree-1]) + copy(right.keys[:idMapminDegree-1], n.keys[idMapminDegree:]) + copy(right.values[:idMapminDegree-1], n.values[idMapminDegree:]) + n.keys[0], n.values[0] = n.keys[idMapminDegree-1], n.values[idMapminDegree-1] + idMapzeroValueSlice(n.values[1:]) + if n.hasChildren { + copy(left.children[:idMapminDegree], n.children[:idMapminDegree]) + copy(right.children[:idMapminDegree], n.children[idMapminDegree:]) + idMapzeroNodeSlice(n.children[2:]) + for i := 0; i < idMapminDegree; i++ { + left.children[i].parent = left + left.children[i].parentIndex = i + right.children[i].parent = right + right.children[i].parentIndex = i + } + } + n.nrSegments = 1 + n.hasChildren = true + n.children[0] = left + n.children[1] = right + + if idMaptrackGaps != 0 { + left.updateMaxGapLocal() + right.updateMaxGapLocal() + } + if gap.node != n { + return gap + } + if gap.index < idMapminDegree { + return idMapGapIterator{left, gap.index} + } + return idMapGapIterator{right, gap.index - idMapminDegree} + } + + copy(n.parent.keys[n.parentIndex+1:], n.parent.keys[n.parentIndex:n.parent.nrSegments]) + copy(n.parent.values[n.parentIndex+1:], n.parent.values[n.parentIndex:n.parent.nrSegments]) + n.parent.keys[n.parentIndex], n.parent.values[n.parentIndex] = n.keys[idMapminDegree-1], n.values[idMapminDegree-1] + copy(n.parent.children[n.parentIndex+2:], n.parent.children[n.parentIndex+1:n.parent.nrSegments+1]) + for i := n.parentIndex + 2; i < n.parent.nrSegments+2; i++ { + n.parent.children[i].parentIndex = i + } + sibling := &idMapnode{ + nrSegments: idMapminDegree - 1, + parent: n.parent, + parentIndex: n.parentIndex + 1, + hasChildren: n.hasChildren, + } + n.parent.children[n.parentIndex+1] = sibling + n.parent.nrSegments++ + copy(sibling.keys[:idMapminDegree-1], n.keys[idMapminDegree:]) + copy(sibling.values[:idMapminDegree-1], n.values[idMapminDegree:]) + idMapzeroValueSlice(n.values[idMapminDegree-1:]) + if n.hasChildren { + copy(sibling.children[:idMapminDegree], n.children[idMapminDegree:]) + idMapzeroNodeSlice(n.children[idMapminDegree:]) + for i := 0; i < idMapminDegree; i++ { + sibling.children[i].parent = sibling + sibling.children[i].parentIndex = i + } + } + n.nrSegments = idMapminDegree - 1 + + if idMaptrackGaps != 0 { + n.updateMaxGapLocal() + sibling.updateMaxGapLocal() + } + + if gap.node != n { + return gap + } + if gap.index < idMapminDegree { + return gap + } + return idMapGapIterator{sibling, gap.index - idMapminDegree} +} + +// rebalanceAfterRemove "unsplits" n and its ancestors if they are deficient +// (contain fewer segments than required by B-tree invariants), as required for +// removal, and returns an updated iterator to the position represented by gap. +// +// Precondition: n is the only node in the tree that may currently violate a +// B-tree invariant. +func (n *idMapnode) rebalanceAfterRemove(gap idMapGapIterator) idMapGapIterator { + for { + if n.nrSegments >= idMapminDegree-1 { + return gap + } + if n.parent == nil { + + return gap + } + + if sibling := n.prevSibling(); sibling != nil && sibling.nrSegments >= idMapminDegree { + copy(n.keys[1:], n.keys[:n.nrSegments]) + copy(n.values[1:], n.values[:n.nrSegments]) + n.keys[0] = n.parent.keys[n.parentIndex-1] + n.values[0] = n.parent.values[n.parentIndex-1] + n.parent.keys[n.parentIndex-1] = sibling.keys[sibling.nrSegments-1] + n.parent.values[n.parentIndex-1] = sibling.values[sibling.nrSegments-1] + idMapFunctions{}.ClearValue(&sibling.values[sibling.nrSegments-1]) + if n.hasChildren { + copy(n.children[1:], n.children[:n.nrSegments+1]) + n.children[0] = sibling.children[sibling.nrSegments] + sibling.children[sibling.nrSegments] = nil + n.children[0].parent = n + n.children[0].parentIndex = 0 + for i := 1; i < n.nrSegments+2; i++ { + n.children[i].parentIndex = i + } + } + n.nrSegments++ + sibling.nrSegments-- + + if idMaptrackGaps != 0 { + n.updateMaxGapLocal() + sibling.updateMaxGapLocal() + } + if gap.node == sibling && gap.index == sibling.nrSegments { + return idMapGapIterator{n, 0} + } + if gap.node == n { + return idMapGapIterator{n, gap.index + 1} + } + return gap + } + if sibling := n.nextSibling(); sibling != nil && sibling.nrSegments >= idMapminDegree { + n.keys[n.nrSegments] = n.parent.keys[n.parentIndex] + n.values[n.nrSegments] = n.parent.values[n.parentIndex] + n.parent.keys[n.parentIndex] = sibling.keys[0] + n.parent.values[n.parentIndex] = sibling.values[0] + copy(sibling.keys[:sibling.nrSegments-1], sibling.keys[1:]) + copy(sibling.values[:sibling.nrSegments-1], sibling.values[1:]) + idMapFunctions{}.ClearValue(&sibling.values[sibling.nrSegments-1]) + if n.hasChildren { + n.children[n.nrSegments+1] = sibling.children[0] + copy(sibling.children[:sibling.nrSegments], sibling.children[1:]) + sibling.children[sibling.nrSegments] = nil + n.children[n.nrSegments+1].parent = n + n.children[n.nrSegments+1].parentIndex = n.nrSegments + 1 + for i := 0; i < sibling.nrSegments; i++ { + sibling.children[i].parentIndex = i + } + } + n.nrSegments++ + sibling.nrSegments-- + + if idMaptrackGaps != 0 { + n.updateMaxGapLocal() + sibling.updateMaxGapLocal() + } + if gap.node == sibling { + if gap.index == 0 { + return idMapGapIterator{n, n.nrSegments} + } + return idMapGapIterator{sibling, gap.index - 1} + } + return gap + } + + p := n.parent + if p.nrSegments == 1 { + + left, right := p.children[0], p.children[1] + p.nrSegments = left.nrSegments + right.nrSegments + 1 + p.hasChildren = left.hasChildren + p.keys[left.nrSegments] = p.keys[0] + p.values[left.nrSegments] = p.values[0] + copy(p.keys[:left.nrSegments], left.keys[:left.nrSegments]) + copy(p.values[:left.nrSegments], left.values[:left.nrSegments]) + copy(p.keys[left.nrSegments+1:], right.keys[:right.nrSegments]) + copy(p.values[left.nrSegments+1:], right.values[:right.nrSegments]) + if left.hasChildren { + copy(p.children[:left.nrSegments+1], left.children[:left.nrSegments+1]) + copy(p.children[left.nrSegments+1:], right.children[:right.nrSegments+1]) + for i := 0; i < p.nrSegments+1; i++ { + p.children[i].parent = p + p.children[i].parentIndex = i + } + } else { + p.children[0] = nil + p.children[1] = nil + } + + if gap.node == left { + return idMapGapIterator{p, gap.index} + } + if gap.node == right { + return idMapGapIterator{p, gap.index + left.nrSegments + 1} + } + return gap + } + // Merge n and either sibling, along with the segment separating the + // two, into whichever of the two nodes comes first. This is the + // reverse of the non-root splitting case in + // node.rebalanceBeforeInsert. + var left, right *idMapnode + if n.parentIndex > 0 { + left = n.prevSibling() + right = n + } else { + left = n + right = n.nextSibling() + } + + if gap.node == right { + gap = idMapGapIterator{left, gap.index + left.nrSegments + 1} + } + left.keys[left.nrSegments] = p.keys[left.parentIndex] + left.values[left.nrSegments] = p.values[left.parentIndex] + copy(left.keys[left.nrSegments+1:], right.keys[:right.nrSegments]) + copy(left.values[left.nrSegments+1:], right.values[:right.nrSegments]) + if left.hasChildren { + copy(left.children[left.nrSegments+1:], right.children[:right.nrSegments+1]) + for i := left.nrSegments + 1; i < left.nrSegments+right.nrSegments+2; i++ { + left.children[i].parent = left + left.children[i].parentIndex = i + } + } + left.nrSegments += right.nrSegments + 1 + copy(p.keys[left.parentIndex:], p.keys[left.parentIndex+1:p.nrSegments]) + copy(p.values[left.parentIndex:], p.values[left.parentIndex+1:p.nrSegments]) + idMapFunctions{}.ClearValue(&p.values[p.nrSegments-1]) + copy(p.children[left.parentIndex+1:], p.children[left.parentIndex+2:p.nrSegments+1]) + for i := 0; i < p.nrSegments; i++ { + p.children[i].parentIndex = i + } + p.children[p.nrSegments] = nil + p.nrSegments-- + + if idMaptrackGaps != 0 { + left.updateMaxGapLocal() + } + + n = p + } +} + +// updateMaxGapLeaf updates maxGap bottom-up from the calling leaf until no +// necessary update. +// +// Preconditions: n must be a leaf node, trackGaps must be 1. +func (n *idMapnode) updateMaxGapLeaf() { + if n.hasChildren { + panic(fmt.Sprintf("updateMaxGapLeaf should always be called on leaf node: %v", n)) + } + max := n.calculateMaxGapLeaf() + if max == n.maxGap.Get() { + + return + } + oldMax := n.maxGap.Get() + n.maxGap.Set(max) + if max > oldMax { + + for p := n.parent; p != nil; p = p.parent { + if p.maxGap.Get() >= max { + + break + } + + p.maxGap.Set(max) + } + return + } + + for p := n.parent; p != nil; p = p.parent { + if p.maxGap.Get() > oldMax { + + break + } + + parentNewMax := p.calculateMaxGapInternal() + if p.maxGap.Get() == parentNewMax { + + break + } + + p.maxGap.Set(parentNewMax) + } +} + +// updateMaxGapLocal updates maxGap of the calling node solely with no +// propagation to ancestor nodes. +// +// Precondition: trackGaps must be 1. +func (n *idMapnode) updateMaxGapLocal() { + if !n.hasChildren { + + n.maxGap.Set(n.calculateMaxGapLeaf()) + } else { + + n.maxGap.Set(n.calculateMaxGapInternal()) + } +} + +// calculateMaxGapLeaf iterates the gaps within a leaf node and calculate the +// max. +// +// Preconditions: n must be a leaf node. +func (n *idMapnode) calculateMaxGapLeaf() uint32 { + max := idMapGapIterator{n, 0}.Range().Length() + for i := 1; i <= n.nrSegments; i++ { + if current := (idMapGapIterator{n, i}).Range().Length(); current > max { + max = current + } + } + return max +} + +// calculateMaxGapInternal iterates children's maxGap within an internal node n +// and calculate the max. +// +// Preconditions: n must be a non-leaf node. +func (n *idMapnode) calculateMaxGapInternal() uint32 { + max := n.children[0].maxGap.Get() + for i := 1; i <= n.nrSegments; i++ { + if current := n.children[i].maxGap.Get(); current > max { + max = current + } + } + return max +} + +// searchFirstLargeEnoughGap returns the first gap having at least minSize length +// in the subtree rooted by n. If not found, return a terminal gap iterator. +func (n *idMapnode) searchFirstLargeEnoughGap(minSize uint32) idMapGapIterator { + if n.maxGap.Get() < minSize { + return idMapGapIterator{} + } + if n.hasChildren { + for i := 0; i <= n.nrSegments; i++ { + if largeEnoughGap := n.children[i].searchFirstLargeEnoughGap(minSize); largeEnoughGap.Ok() { + return largeEnoughGap + } + } + } else { + for i := 0; i <= n.nrSegments; i++ { + currentGap := idMapGapIterator{n, i} + if currentGap.Range().Length() >= minSize { + return currentGap + } + } + } + panic(fmt.Sprintf("invalid maxGap in %v", n)) +} + +// searchLastLargeEnoughGap returns the last gap having at least minSize length +// in the subtree rooted by n. If not found, return a terminal gap iterator. +func (n *idMapnode) searchLastLargeEnoughGap(minSize uint32) idMapGapIterator { + if n.maxGap.Get() < minSize { + return idMapGapIterator{} + } + if n.hasChildren { + for i := n.nrSegments; i >= 0; i-- { + if largeEnoughGap := n.children[i].searchLastLargeEnoughGap(minSize); largeEnoughGap.Ok() { + return largeEnoughGap + } + } + } else { + for i := n.nrSegments; i >= 0; i-- { + currentGap := idMapGapIterator{n, i} + if currentGap.Range().Length() >= minSize { + return currentGap + } + } + } + panic(fmt.Sprintf("invalid maxGap in %v", n)) +} + +// A Iterator is conceptually one of: +// +// - A pointer to a segment in a set; or +// +// - A terminal iterator, which is a sentinel indicating that the end of +// iteration has been reached. +// +// Iterators are copyable values and are meaningfully equality-comparable. The +// zero value of Iterator is a terminal iterator. +// +// Unless otherwise specified, any mutation of a set invalidates all existing +// iterators into the set. +type idMapIterator struct { + // node is the node containing the iterated segment. If the iterator is + // terminal, node is nil. + node *idMapnode + + // index is the index of the segment in node.keys/values. + index int +} + +// Ok returns true if the iterator is not terminal. All other methods are only +// valid for non-terminal iterators. +func (seg idMapIterator) Ok() bool { + return seg.node != nil +} + +// Range returns the iterated segment's range key. +func (seg idMapIterator) Range() idMapRange { + return seg.node.keys[seg.index] +} + +// Start is equivalent to Range().Start, but should be preferred if only the +// start of the range is needed. +func (seg idMapIterator) Start() uint32 { + return seg.node.keys[seg.index].Start +} + +// End is equivalent to Range().End, but should be preferred if only the end of +// the range is needed. +func (seg idMapIterator) End() uint32 { + return seg.node.keys[seg.index].End +} + +// SetRangeUnchecked mutates the iterated segment's range key. This operation +// does not invalidate any iterators. +// +// Preconditions: +// * r.Length() > 0. +// * The new range must not overlap an existing one: +// * If seg.NextSegment().Ok(), then r.end <= seg.NextSegment().Start(). +// * If seg.PrevSegment().Ok(), then r.start >= seg.PrevSegment().End(). +func (seg idMapIterator) SetRangeUnchecked(r idMapRange) { + seg.node.keys[seg.index] = r +} + +// SetRange mutates the iterated segment's range key. If the new range would +// cause the iterated segment to overlap another segment, or if the new range +// is invalid, SetRange panics. This operation does not invalidate any +// iterators. +func (seg idMapIterator) SetRange(r idMapRange) { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + if prev := seg.PrevSegment(); prev.Ok() && r.Start < prev.End() { + panic(fmt.Sprintf("new segment range %v overlaps segment range %v", r, prev.Range())) + } + if next := seg.NextSegment(); next.Ok() && r.End > next.Start() { + panic(fmt.Sprintf("new segment range %v overlaps segment range %v", r, next.Range())) + } + seg.SetRangeUnchecked(r) +} + +// SetStartUnchecked mutates the iterated segment's start. This operation does +// not invalidate any iterators. +// +// Preconditions: The new start must be valid: +// * start < seg.End() +// * If seg.PrevSegment().Ok(), then start >= seg.PrevSegment().End(). +func (seg idMapIterator) SetStartUnchecked(start uint32) { + seg.node.keys[seg.index].Start = start +} + +// SetStart mutates the iterated segment's start. If the new start value would +// cause the iterated segment to overlap another segment, or would result in an +// invalid range, SetStart panics. This operation does not invalidate any +// iterators. +func (seg idMapIterator) SetStart(start uint32) { + if start >= seg.End() { + panic(fmt.Sprintf("new start %v would invalidate segment range %v", start, seg.Range())) + } + if prev := seg.PrevSegment(); prev.Ok() && start < prev.End() { + panic(fmt.Sprintf("new start %v would cause segment range %v to overlap segment range %v", start, seg.Range(), prev.Range())) + } + seg.SetStartUnchecked(start) +} + +// SetEndUnchecked mutates the iterated segment's end. This operation does not +// invalidate any iterators. +// +// Preconditions: The new end must be valid: +// * end > seg.Start(). +// * If seg.NextSegment().Ok(), then end <= seg.NextSegment().Start(). +func (seg idMapIterator) SetEndUnchecked(end uint32) { + seg.node.keys[seg.index].End = end +} + +// SetEnd mutates the iterated segment's end. If the new end value would cause +// the iterated segment to overlap another segment, or would result in an +// invalid range, SetEnd panics. This operation does not invalidate any +// iterators. +func (seg idMapIterator) SetEnd(end uint32) { + if end <= seg.Start() { + panic(fmt.Sprintf("new end %v would invalidate segment range %v", end, seg.Range())) + } + if next := seg.NextSegment(); next.Ok() && end > next.Start() { + panic(fmt.Sprintf("new end %v would cause segment range %v to overlap segment range %v", end, seg.Range(), next.Range())) + } + seg.SetEndUnchecked(end) +} + +// Value returns a copy of the iterated segment's value. +func (seg idMapIterator) Value() uint32 { + return seg.node.values[seg.index] +} + +// ValuePtr returns a pointer to the iterated segment's value. The pointer is +// invalidated if the iterator is invalidated. This operation does not +// invalidate any iterators. +func (seg idMapIterator) ValuePtr() *uint32 { + return &seg.node.values[seg.index] +} + +// SetValue mutates the iterated segment's value. This operation does not +// invalidate any iterators. +func (seg idMapIterator) SetValue(val uint32) { + seg.node.values[seg.index] = val +} + +// PrevSegment returns the iterated segment's predecessor. If there is no +// preceding segment, PrevSegment returns a terminal iterator. +func (seg idMapIterator) PrevSegment() idMapIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index].lastSegment() + } + if seg.index > 0 { + return idMapIterator{seg.node, seg.index - 1} + } + if seg.node.parent == nil { + return idMapIterator{} + } + return idMapsegmentBeforePosition(seg.node.parent, seg.node.parentIndex) +} + +// NextSegment returns the iterated segment's successor. If there is no +// succeeding segment, NextSegment returns a terminal iterator. +func (seg idMapIterator) NextSegment() idMapIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index+1].firstSegment() + } + if seg.index < seg.node.nrSegments-1 { + return idMapIterator{seg.node, seg.index + 1} + } + if seg.node.parent == nil { + return idMapIterator{} + } + return idMapsegmentAfterPosition(seg.node.parent, seg.node.parentIndex) +} + +// PrevGap returns the gap immediately before the iterated segment. +func (seg idMapIterator) PrevGap() idMapGapIterator { + if seg.node.hasChildren { + + return seg.node.children[seg.index].lastSegment().NextGap() + } + return idMapGapIterator{seg.node, seg.index} +} + +// NextGap returns the gap immediately after the iterated segment. +func (seg idMapIterator) NextGap() idMapGapIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index+1].firstSegment().PrevGap() + } + return idMapGapIterator{seg.node, seg.index + 1} +} + +// PrevNonEmpty returns the iterated segment's predecessor if it is adjacent, +// or the gap before the iterated segment otherwise. If seg.Start() == +// Functions.MinKey(), PrevNonEmpty will return two terminal iterators. +// Otherwise, exactly one of the iterators returned by PrevNonEmpty will be +// non-terminal. +func (seg idMapIterator) PrevNonEmpty() (idMapIterator, idMapGapIterator) { + gap := seg.PrevGap() + if gap.Range().Length() != 0 { + return idMapIterator{}, gap + } + return gap.PrevSegment(), idMapGapIterator{} +} + +// NextNonEmpty returns the iterated segment's successor if it is adjacent, or +// the gap after the iterated segment otherwise. If seg.End() == +// Functions.MaxKey(), NextNonEmpty will return two terminal iterators. +// Otherwise, exactly one of the iterators returned by NextNonEmpty will be +// non-terminal. +func (seg idMapIterator) NextNonEmpty() (idMapIterator, idMapGapIterator) { + gap := seg.NextGap() + if gap.Range().Length() != 0 { + return idMapIterator{}, gap + } + return gap.NextSegment(), idMapGapIterator{} +} + +// A GapIterator is conceptually one of: +// +// - A pointer to a position between two segments, before the first segment, or +// after the last segment in a set, called a *gap*; or +// +// - A terminal iterator, which is a sentinel indicating that the end of +// iteration has been reached. +// +// Note that the gap between two adjacent segments exists (iterators to it are +// non-terminal), but has a length of zero. GapIterator.IsEmpty returns true +// for such gaps. An empty set contains a single gap, spanning the entire range +// of the set's keys. +// +// GapIterators are copyable values and are meaningfully equality-comparable. +// The zero value of GapIterator is a terminal iterator. +// +// Unless otherwise specified, any mutation of a set invalidates all existing +// iterators into the set. +type idMapGapIterator struct { + // The representation of a GapIterator is identical to that of an Iterator, + // except that index corresponds to positions between segments in the same + // way as for node.children (see comment for node.nrSegments). + node *idMapnode + index int +} + +// Ok returns true if the iterator is not terminal. All other methods are only +// valid for non-terminal iterators. +func (gap idMapGapIterator) Ok() bool { + return gap.node != nil +} + +// Range returns the range spanned by the iterated gap. +func (gap idMapGapIterator) Range() idMapRange { + return idMapRange{gap.Start(), gap.End()} +} + +// Start is equivalent to Range().Start, but should be preferred if only the +// start of the range is needed. +func (gap idMapGapIterator) Start() uint32 { + if ps := gap.PrevSegment(); ps.Ok() { + return ps.End() + } + return idMapFunctions{}.MinKey() +} + +// End is equivalent to Range().End, but should be preferred if only the end of +// the range is needed. +func (gap idMapGapIterator) End() uint32 { + if ns := gap.NextSegment(); ns.Ok() { + return ns.Start() + } + return idMapFunctions{}.MaxKey() +} + +// IsEmpty returns true if the iterated gap is empty (that is, the "gap" is +// between two adjacent segments.) +func (gap idMapGapIterator) IsEmpty() bool { + return gap.Range().Length() == 0 +} + +// PrevSegment returns the segment immediately before the iterated gap. If no +// such segment exists, PrevSegment returns a terminal iterator. +func (gap idMapGapIterator) PrevSegment() idMapIterator { + return idMapsegmentBeforePosition(gap.node, gap.index) +} + +// NextSegment returns the segment immediately after the iterated gap. If no +// such segment exists, NextSegment returns a terminal iterator. +func (gap idMapGapIterator) NextSegment() idMapIterator { + return idMapsegmentAfterPosition(gap.node, gap.index) +} + +// PrevGap returns the iterated gap's predecessor. If no such gap exists, +// PrevGap returns a terminal iterator. +func (gap idMapGapIterator) PrevGap() idMapGapIterator { + seg := gap.PrevSegment() + if !seg.Ok() { + return idMapGapIterator{} + } + return seg.PrevGap() +} + +// NextGap returns the iterated gap's successor. If no such gap exists, NextGap +// returns a terminal iterator. +func (gap idMapGapIterator) NextGap() idMapGapIterator { + seg := gap.NextSegment() + if !seg.Ok() { + return idMapGapIterator{} + } + return seg.NextGap() +} + +// NextLargeEnoughGap returns the iterated gap's first next gap with larger +// length than minSize. If not found, return a terminal gap iterator (does NOT +// include this gap itself). +// +// Precondition: trackGaps must be 1. +func (gap idMapGapIterator) NextLargeEnoughGap(minSize uint32) idMapGapIterator { + if idMaptrackGaps != 1 { + panic("set is not tracking gaps") + } + if gap.node != nil && gap.node.hasChildren && gap.index == gap.node.nrSegments { + + gap.node = gap.NextSegment().node + gap.index = 0 + return gap.nextLargeEnoughGapHelper(minSize) + } + return gap.nextLargeEnoughGapHelper(minSize) +} + +// nextLargeEnoughGapHelper is the helper function used by NextLargeEnoughGap +// to do the real recursions. +// +// Preconditions: gap is NOT the trailing gap of a non-leaf node. +func (gap idMapGapIterator) nextLargeEnoughGapHelper(minSize uint32) idMapGapIterator { + + for gap.node != nil && + (gap.node.maxGap.Get() < minSize || (!gap.node.hasChildren && gap.index == gap.node.nrSegments)) { + gap.node, gap.index = gap.node.parent, gap.node.parentIndex + } + + if gap.node == nil { + return idMapGapIterator{} + } + + gap.index++ + for gap.index <= gap.node.nrSegments { + if gap.node.hasChildren { + if largeEnoughGap := gap.node.children[gap.index].searchFirstLargeEnoughGap(minSize); largeEnoughGap.Ok() { + return largeEnoughGap + } + } else { + if gap.Range().Length() >= minSize { + return gap + } + } + gap.index++ + } + gap.node, gap.index = gap.node.parent, gap.node.parentIndex + if gap.node != nil && gap.index == gap.node.nrSegments { + + gap.node, gap.index = gap.node.parent, gap.node.parentIndex + } + return gap.nextLargeEnoughGapHelper(minSize) +} + +// PrevLargeEnoughGap returns the iterated gap's first prev gap with larger or +// equal length than minSize. If not found, return a terminal gap iterator +// (does NOT include this gap itself). +// +// Precondition: trackGaps must be 1. +func (gap idMapGapIterator) PrevLargeEnoughGap(minSize uint32) idMapGapIterator { + if idMaptrackGaps != 1 { + panic("set is not tracking gaps") + } + if gap.node != nil && gap.node.hasChildren && gap.index == 0 { + + gap.node = gap.PrevSegment().node + gap.index = gap.node.nrSegments + return gap.prevLargeEnoughGapHelper(minSize) + } + return gap.prevLargeEnoughGapHelper(minSize) +} + +// prevLargeEnoughGapHelper is the helper function used by PrevLargeEnoughGap +// to do the real recursions. +// +// Preconditions: gap is NOT the first gap of a non-leaf node. +func (gap idMapGapIterator) prevLargeEnoughGapHelper(minSize uint32) idMapGapIterator { + + for gap.node != nil && + (gap.node.maxGap.Get() < minSize || (!gap.node.hasChildren && gap.index == 0)) { + gap.node, gap.index = gap.node.parent, gap.node.parentIndex + } + + if gap.node == nil { + return idMapGapIterator{} + } + + gap.index-- + for gap.index >= 0 { + if gap.node.hasChildren { + if largeEnoughGap := gap.node.children[gap.index].searchLastLargeEnoughGap(minSize); largeEnoughGap.Ok() { + return largeEnoughGap + } + } else { + if gap.Range().Length() >= minSize { + return gap + } + } + gap.index-- + } + gap.node, gap.index = gap.node.parent, gap.node.parentIndex + if gap.node != nil && gap.index == 0 { + + gap.node, gap.index = gap.node.parent, gap.node.parentIndex + } + return gap.prevLargeEnoughGapHelper(minSize) +} + +// segmentBeforePosition returns the predecessor segment of the position given +// by n.children[i], which may or may not contain a child. If no such segment +// exists, segmentBeforePosition returns a terminal iterator. +func idMapsegmentBeforePosition(n *idMapnode, i int) idMapIterator { + for i == 0 { + if n.parent == nil { + return idMapIterator{} + } + n, i = n.parent, n.parentIndex + } + return idMapIterator{n, i - 1} +} + +// segmentAfterPosition returns the successor segment of the position given by +// n.children[i], which may or may not contain a child. If no such segment +// exists, segmentAfterPosition returns a terminal iterator. +func idMapsegmentAfterPosition(n *idMapnode, i int) idMapIterator { + for i == n.nrSegments { + if n.parent == nil { + return idMapIterator{} + } + n, i = n.parent, n.parentIndex + } + return idMapIterator{n, i} +} + +func idMapzeroValueSlice(slice []uint32) { + + for i := range slice { + idMapFunctions{}.ClearValue(&slice[i]) + } +} + +func idMapzeroNodeSlice(slice []*idMapnode) { + for i := range slice { + slice[i] = nil + } +} + +// String stringifies a Set for debugging. +func (s *idMapSet) String() string { + return s.root.String() +} + +// String stringifies a node (and all of its children) for debugging. +func (n *idMapnode) String() string { + var buf bytes.Buffer + n.writeDebugString(&buf, "") + return buf.String() +} + +func (n *idMapnode) writeDebugString(buf *bytes.Buffer, prefix string) { + if n.hasChildren != (n.nrSegments > 0 && n.children[0] != nil) { + buf.WriteString(prefix) + buf.WriteString(fmt.Sprintf("WARNING: inconsistent value of hasChildren: got %v, want %v\n", n.hasChildren, !n.hasChildren)) + } + for i := 0; i < n.nrSegments; i++ { + if child := n.children[i]; child != nil { + cprefix := fmt.Sprintf("%s- % 3d ", prefix, i) + if child.parent != n || child.parentIndex != i { + buf.WriteString(cprefix) + buf.WriteString(fmt.Sprintf("WARNING: inconsistent linkage to parent: got (%p, %d), want (%p, %d)\n", child.parent, child.parentIndex, n, i)) + } + child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, i)) + } + buf.WriteString(prefix) + if n.hasChildren { + if idMaptrackGaps != 0 { + buf.WriteString(fmt.Sprintf("- % 3d: %v => %v, maxGap: %d\n", i, n.keys[i], n.values[i], n.maxGap.Get())) + } else { + buf.WriteString(fmt.Sprintf("- % 3d: %v => %v\n", i, n.keys[i], n.values[i])) + } + } else { + buf.WriteString(fmt.Sprintf("- % 3d: %v => %v\n", i, n.keys[i], n.values[i])) + } + } + if child := n.children[n.nrSegments]; child != nil { + child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, n.nrSegments)) + } +} + +// SegmentDataSlices represents segments from a set as slices of start, end, and +// values. SegmentDataSlices is primarily used as an intermediate representation +// for save/restore and the layout here is optimized for that. +// +// +stateify savable +type idMapSegmentDataSlices struct { + Start []uint32 + End []uint32 + Values []uint32 +} + +// ExportSortedSlices returns a copy of all segments in the given set, in +// ascending key order. +func (s *idMapSet) ExportSortedSlices() *idMapSegmentDataSlices { + var sds idMapSegmentDataSlices + for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + sds.Start = append(sds.Start, seg.Start()) + sds.End = append(sds.End, seg.End()) + sds.Values = append(sds.Values, seg.Value()) + } + sds.Start = sds.Start[:len(sds.Start):len(sds.Start)] + sds.End = sds.End[:len(sds.End):len(sds.End)] + sds.Values = sds.Values[:len(sds.Values):len(sds.Values)] + return &sds +} + +// ImportSortedSlices initializes the given set from the given slice. +// +// Preconditions: +// * s must be empty. +// * sds must represent a valid set (the segments in sds must have valid +// lengths that do not overlap). +// * The segments in sds must be sorted in ascending key order. +func (s *idMapSet) ImportSortedSlices(sds *idMapSegmentDataSlices) error { + if !s.IsEmpty() { + return fmt.Errorf("cannot import into non-empty set %v", s) + } + gap := s.FirstGap() + for i := range sds.Start { + r := idMapRange{sds.Start[i], sds.End[i]} + if !gap.Range().IsSupersetOf(r) { + return fmt.Errorf("segment overlaps a preceding segment or is incorrectly sorted: [%d, %d) => %v", sds.Start[i], sds.End[i], sds.Values[i]) + } + gap = s.InsertWithoutMerging(gap, r, sds.Values[i]).NextGap() + } + return nil +} + +// segmentTestCheck returns an error if s is incorrectly sorted, does not +// contain exactly expectedSegments segments, or contains a segment which +// fails the passed check. +// +// This should be used only for testing, and has been added to this package for +// templating convenience. +func (s *idMapSet) segmentTestCheck(expectedSegments int, segFunc func(int, idMapRange, uint32) error) error { + havePrev := false + prev := uint32(0) + nrSegments := 0 + for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + next := seg.Start() + if havePrev && prev >= next { + return fmt.Errorf("incorrect order: key %d (segment %d) >= key %d (segment %d)", prev, nrSegments-1, next, nrSegments) + } + if segFunc != nil { + if err := segFunc(nrSegments, seg.Range(), seg.Value()); err != nil { + return err + } + } + prev = next + havePrev = true + nrSegments++ + } + if nrSegments != expectedSegments { + return fmt.Errorf("incorrect number of segments: got %d, wanted %d", nrSegments, expectedSegments) + } + return nil +} + +// countSegments counts the number of segments in the set. +// +// Similar to Check, this should only be used for testing. +func (s *idMapSet) countSegments() (segments int) { + for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + segments++ + } + return segments +} +func (s *idMapSet) saveRoot() *idMapSegmentDataSlices { + return s.ExportSortedSlices() +} + +func (s *idMapSet) loadRoot(sds *idMapSegmentDataSlices) { + if err := s.ImportSortedSlices(sds); err != nil { + panic(err) + } +} diff --git a/pkg/sentry/kernel/contexttest/BUILD b/pkg/sentry/kernel/contexttest/BUILD deleted file mode 100644 index 9d26392c0..000000000 --- a/pkg/sentry/kernel/contexttest/BUILD +++ /dev/null @@ -1,17 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "contexttest", - testonly = 1, - srcs = ["contexttest.go"], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/context", - "//pkg/sentry/contexttest", - "//pkg/sentry/kernel", - "//pkg/sentry/pgalloc", - "//pkg/sentry/platform", - ], -) diff --git a/pkg/sentry/kernel/contexttest/contexttest.go b/pkg/sentry/kernel/contexttest/contexttest.go deleted file mode 100644 index 22c340e56..000000000 --- a/pkg/sentry/kernel/contexttest/contexttest.go +++ /dev/null @@ -1,40 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package contexttest provides a test context.Context which includes -// a dummy kernel pointing to a valid platform. -package contexttest - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/sentry/contexttest" - "gvisor.dev/gvisor/pkg/sentry/kernel" - "gvisor.dev/gvisor/pkg/sentry/pgalloc" - "gvisor.dev/gvisor/pkg/sentry/platform" -) - -// Context returns a Context that may be used in tests. Uses ptrace as the -// platform.Platform, and provides a stub kernel that only serves to point to -// the platform. -func Context(tb testing.TB) context.Context { - ctx := contexttest.Context(tb) - k := &kernel.Kernel{ - Platform: platform.FromContext(ctx), - } - k.SetMemoryFile(pgalloc.MemoryFileFromContext(ctx)) - ctx.(*contexttest.TestContext).RegisterValue(kernel.CtxKernel, k) - return ctx -} diff --git a/pkg/sentry/kernel/epoll/BUILD b/pkg/sentry/kernel/epoll/BUILD deleted file mode 100644 index 723a85f64..000000000 --- a/pkg/sentry/kernel/epoll/BUILD +++ /dev/null @@ -1,52 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -package(licenses = ["notice"]) - -go_template_instance( - name = "epoll_list", - out = "epoll_list.go", - package = "epoll", - prefix = "pollEntry", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*pollEntry", - "Linker": "*pollEntry", - }, -) - -go_library( - name = "epoll", - srcs = [ - "epoll.go", - "epoll_list.go", - "epoll_state.go", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/refs", - "//pkg/sentry/fs", - "//pkg/sentry/fs/anon", - "//pkg/sentry/fs/fsutil", - "//pkg/sync", - "//pkg/usermem", - "//pkg/waiter", - "@org_golang_x_sys//unix:go_default_library", - ], -) - -go_test( - name = "epoll_test", - size = "small", - srcs = [ - "epoll_test.go", - ], - library = ":epoll", - deps = [ - "//pkg/sentry/contexttest", - "//pkg/sentry/fs/filetest", - "//pkg/waiter", - ], -) diff --git a/pkg/sentry/kernel/epoll/epoll_list.go b/pkg/sentry/kernel/epoll/epoll_list.go new file mode 100644 index 000000000..b6abe2de9 --- /dev/null +++ b/pkg/sentry/kernel/epoll/epoll_list.go @@ -0,0 +1,221 @@ +package epoll + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type pollEntryElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (pollEntryElementMapper) linkerFor(elem *pollEntry) *pollEntry { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type pollEntryList struct { + head *pollEntry + tail *pollEntry +} + +// Reset resets list l to the empty state. +func (l *pollEntryList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +// +//go:nosplit +func (l *pollEntryList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +// +//go:nosplit +func (l *pollEntryList) Front() *pollEntry { + return l.head +} + +// Back returns the last element of list l or nil. +// +//go:nosplit +func (l *pollEntryList) Back() *pollEntry { + return l.tail +} + +// Len returns the number of elements in the list. +// +// NOTE: This is an O(n) operation. +// +//go:nosplit +func (l *pollEntryList) Len() (count int) { + for e := l.Front(); e != nil; e = (pollEntryElementMapper{}.linkerFor(e)).Next() { + count++ + } + return count +} + +// PushFront inserts the element e at the front of list l. +// +//go:nosplit +func (l *pollEntryList) PushFront(e *pollEntry) { + linker := pollEntryElementMapper{}.linkerFor(e) + linker.SetNext(l.head) + linker.SetPrev(nil) + if l.head != nil { + pollEntryElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +// +//go:nosplit +func (l *pollEntryList) PushBack(e *pollEntry) { + linker := pollEntryElementMapper{}.linkerFor(e) + linker.SetNext(nil) + linker.SetPrev(l.tail) + if l.tail != nil { + pollEntryElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +// +//go:nosplit +func (l *pollEntryList) PushBackList(m *pollEntryList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + pollEntryElementMapper{}.linkerFor(l.tail).SetNext(m.head) + pollEntryElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +// +//go:nosplit +func (l *pollEntryList) InsertAfter(b, e *pollEntry) { + bLinker := pollEntryElementMapper{}.linkerFor(b) + eLinker := pollEntryElementMapper{}.linkerFor(e) + + a := bLinker.Next() + + eLinker.SetNext(a) + eLinker.SetPrev(b) + bLinker.SetNext(e) + + if a != nil { + pollEntryElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +// +//go:nosplit +func (l *pollEntryList) InsertBefore(a, e *pollEntry) { + aLinker := pollEntryElementMapper{}.linkerFor(a) + eLinker := pollEntryElementMapper{}.linkerFor(e) + + b := aLinker.Prev() + eLinker.SetNext(a) + eLinker.SetPrev(b) + aLinker.SetPrev(e) + + if b != nil { + pollEntryElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +// +//go:nosplit +func (l *pollEntryList) Remove(e *pollEntry) { + linker := pollEntryElementMapper{}.linkerFor(e) + prev := linker.Prev() + next := linker.Next() + + if prev != nil { + pollEntryElementMapper{}.linkerFor(prev).SetNext(next) + } else if l.head == e { + l.head = next + } + + if next != nil { + pollEntryElementMapper{}.linkerFor(next).SetPrev(prev) + } else if l.tail == e { + l.tail = prev + } + + linker.SetNext(nil) + linker.SetPrev(nil) +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type pollEntryEntry struct { + next *pollEntry + prev *pollEntry +} + +// Next returns the entry that follows e in the list. +// +//go:nosplit +func (e *pollEntryEntry) Next() *pollEntry { + return e.next +} + +// Prev returns the entry that precedes e in the list. +// +//go:nosplit +func (e *pollEntryEntry) Prev() *pollEntry { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +// +//go:nosplit +func (e *pollEntryEntry) SetNext(elem *pollEntry) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +// +//go:nosplit +func (e *pollEntryEntry) SetPrev(elem *pollEntry) { + e.prev = elem +} diff --git a/pkg/sentry/kernel/epoll/epoll_state_autogen.go b/pkg/sentry/kernel/epoll/epoll_state_autogen.go new file mode 100644 index 000000000..44bd520ac --- /dev/null +++ b/pkg/sentry/kernel/epoll/epoll_state_autogen.go @@ -0,0 +1,192 @@ +// automatically generated by stateify. + +package epoll + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (f *FileIdentifier) StateTypeName() string { + return "pkg/sentry/kernel/epoll.FileIdentifier" +} + +func (f *FileIdentifier) StateFields() []string { + return []string{ + "File", + "Fd", + } +} + +func (f *FileIdentifier) beforeSave() {} + +// +checklocksignore +func (f *FileIdentifier) StateSave(stateSinkObject state.Sink) { + f.beforeSave() + stateSinkObject.Save(0, &f.File) + stateSinkObject.Save(1, &f.Fd) +} + +func (f *FileIdentifier) afterLoad() {} + +// +checklocksignore +func (f *FileIdentifier) StateLoad(stateSourceObject state.Source) { + stateSourceObject.LoadWait(0, &f.File) + stateSourceObject.Load(1, &f.Fd) +} + +func (p *pollEntry) StateTypeName() string { + return "pkg/sentry/kernel/epoll.pollEntry" +} + +func (p *pollEntry) StateFields() []string { + return []string{ + "pollEntryEntry", + "id", + "userData", + "mask", + "flags", + "epoll", + } +} + +func (p *pollEntry) beforeSave() {} + +// +checklocksignore +func (p *pollEntry) StateSave(stateSinkObject state.Sink) { + p.beforeSave() + stateSinkObject.Save(0, &p.pollEntryEntry) + stateSinkObject.Save(1, &p.id) + stateSinkObject.Save(2, &p.userData) + stateSinkObject.Save(3, &p.mask) + stateSinkObject.Save(4, &p.flags) + stateSinkObject.Save(5, &p.epoll) +} + +// +checklocksignore +func (p *pollEntry) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &p.pollEntryEntry) + stateSourceObject.LoadWait(1, &p.id) + stateSourceObject.Load(2, &p.userData) + stateSourceObject.Load(3, &p.mask) + stateSourceObject.Load(4, &p.flags) + stateSourceObject.Load(5, &p.epoll) + stateSourceObject.AfterLoad(p.afterLoad) +} + +func (e *EventPoll) StateTypeName() string { + return "pkg/sentry/kernel/epoll.EventPoll" +} + +func (e *EventPoll) StateFields() []string { + return []string{ + "files", + "readyList", + "waitingList", + "disabledList", + } +} + +func (e *EventPoll) beforeSave() {} + +// +checklocksignore +func (e *EventPoll) StateSave(stateSinkObject state.Sink) { + e.beforeSave() + if !state.IsZeroValue(&e.FilePipeSeek) { + state.Failf("FilePipeSeek is %#v, expected zero", &e.FilePipeSeek) + } + if !state.IsZeroValue(&e.FileNotDirReaddir) { + state.Failf("FileNotDirReaddir is %#v, expected zero", &e.FileNotDirReaddir) + } + if !state.IsZeroValue(&e.FileNoFsync) { + state.Failf("FileNoFsync is %#v, expected zero", &e.FileNoFsync) + } + if !state.IsZeroValue(&e.FileNoopFlush) { + state.Failf("FileNoopFlush is %#v, expected zero", &e.FileNoopFlush) + } + if !state.IsZeroValue(&e.FileNoIoctl) { + state.Failf("FileNoIoctl is %#v, expected zero", &e.FileNoIoctl) + } + if !state.IsZeroValue(&e.FileNoMMap) { + state.Failf("FileNoMMap is %#v, expected zero", &e.FileNoMMap) + } + if !state.IsZeroValue(&e.Queue) { + state.Failf("Queue is %#v, expected zero", &e.Queue) + } + stateSinkObject.Save(0, &e.files) + stateSinkObject.Save(1, &e.readyList) + stateSinkObject.Save(2, &e.waitingList) + stateSinkObject.Save(3, &e.disabledList) +} + +// +checklocksignore +func (e *EventPoll) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &e.files) + stateSourceObject.Load(1, &e.readyList) + stateSourceObject.Load(2, &e.waitingList) + stateSourceObject.Load(3, &e.disabledList) + stateSourceObject.AfterLoad(e.afterLoad) +} + +func (l *pollEntryList) StateTypeName() string { + return "pkg/sentry/kernel/epoll.pollEntryList" +} + +func (l *pollEntryList) StateFields() []string { + return []string{ + "head", + "tail", + } +} + +func (l *pollEntryList) beforeSave() {} + +// +checklocksignore +func (l *pollEntryList) StateSave(stateSinkObject state.Sink) { + l.beforeSave() + stateSinkObject.Save(0, &l.head) + stateSinkObject.Save(1, &l.tail) +} + +func (l *pollEntryList) afterLoad() {} + +// +checklocksignore +func (l *pollEntryList) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &l.head) + stateSourceObject.Load(1, &l.tail) +} + +func (e *pollEntryEntry) StateTypeName() string { + return "pkg/sentry/kernel/epoll.pollEntryEntry" +} + +func (e *pollEntryEntry) StateFields() []string { + return []string{ + "next", + "prev", + } +} + +func (e *pollEntryEntry) beforeSave() {} + +// +checklocksignore +func (e *pollEntryEntry) StateSave(stateSinkObject state.Sink) { + e.beforeSave() + stateSinkObject.Save(0, &e.next) + stateSinkObject.Save(1, &e.prev) +} + +func (e *pollEntryEntry) afterLoad() {} + +// +checklocksignore +func (e *pollEntryEntry) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &e.next) + stateSourceObject.Load(1, &e.prev) +} + +func init() { + state.Register((*FileIdentifier)(nil)) + state.Register((*pollEntry)(nil)) + state.Register((*EventPoll)(nil)) + state.Register((*pollEntryList)(nil)) + state.Register((*pollEntryEntry)(nil)) +} diff --git a/pkg/sentry/kernel/epoll/epoll_test.go b/pkg/sentry/kernel/epoll/epoll_test.go deleted file mode 100644 index 8ef6cb3e7..000000000 --- a/pkg/sentry/kernel/epoll/epoll_test.go +++ /dev/null @@ -1,55 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package epoll - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/sentry/contexttest" - "gvisor.dev/gvisor/pkg/sentry/fs/filetest" - "gvisor.dev/gvisor/pkg/waiter" -) - -func TestFileDestroyed(t *testing.T) { - f := filetest.NewTestFile(t) - id := FileIdentifier{f, 12} - - ctx := contexttest.Context(t) - efile := NewEventPoll(ctx) - e := efile.FileOperations.(*EventPoll) - if err := e.AddEntry(id, 0, waiter.ReadableEvents, [2]int32{}); err != nil { - t.Fatalf("addEntry failed: %v", err) - } - - // Check that we get an event reported twice in a row. - evt := e.ReadEvents(1) - if len(evt) != 1 { - t.Fatalf("Unexpected number of ready events: want %v, got %v", 1, len(evt)) - } - - evt = e.ReadEvents(1) - if len(evt) != 1 { - t.Fatalf("Unexpected number of ready events: want %v, got %v", 1, len(evt)) - } - - // Destroy the file. Check that we get no more events. - f.DecRef(ctx) - - evt = e.ReadEvents(1) - if len(evt) != 0 { - t.Fatalf("Unexpected number of ready events: want %v, got %v", 0, len(evt)) - } - -} diff --git a/pkg/sentry/kernel/eventfd/BUILD b/pkg/sentry/kernel/eventfd/BUILD deleted file mode 100644 index f240a68aa..000000000 --- a/pkg/sentry/kernel/eventfd/BUILD +++ /dev/null @@ -1,35 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "eventfd", - srcs = ["eventfd.go"], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/errors/linuxerr", - "//pkg/fdnotifier", - "//pkg/hostarch", - "//pkg/sentry/fs", - "//pkg/sentry/fs/anon", - "//pkg/sentry/fs/fsutil", - "//pkg/sync", - "//pkg/usermem", - "//pkg/waiter", - "@org_golang_x_sys//unix:go_default_library", - ], -) - -go_test( - name = "eventfd_test", - size = "small", - srcs = ["eventfd_test.go"], - library = ":eventfd", - deps = [ - "//pkg/sentry/contexttest", - "//pkg/usermem", - "//pkg/waiter", - ], -) diff --git a/pkg/sentry/kernel/eventfd/eventfd_state_autogen.go b/pkg/sentry/kernel/eventfd/eventfd_state_autogen.go new file mode 100644 index 000000000..596920c42 --- /dev/null +++ b/pkg/sentry/kernel/eventfd/eventfd_state_autogen.go @@ -0,0 +1,45 @@ +// automatically generated by stateify. + +package eventfd + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (e *EventOperations) StateTypeName() string { + return "pkg/sentry/kernel/eventfd.EventOperations" +} + +func (e *EventOperations) StateFields() []string { + return []string{ + "val", + "semMode", + "hostfd", + } +} + +func (e *EventOperations) beforeSave() {} + +// +checklocksignore +func (e *EventOperations) StateSave(stateSinkObject state.Sink) { + e.beforeSave() + if !state.IsZeroValue(&e.wq) { + state.Failf("wq is %#v, expected zero", &e.wq) + } + stateSinkObject.Save(0, &e.val) + stateSinkObject.Save(1, &e.semMode) + stateSinkObject.Save(2, &e.hostfd) +} + +func (e *EventOperations) afterLoad() {} + +// +checklocksignore +func (e *EventOperations) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &e.val) + stateSourceObject.Load(1, &e.semMode) + stateSourceObject.Load(2, &e.hostfd) +} + +func init() { + state.Register((*EventOperations)(nil)) +} diff --git a/pkg/sentry/kernel/eventfd/eventfd_test.go b/pkg/sentry/kernel/eventfd/eventfd_test.go deleted file mode 100644 index 1b9e60b3a..000000000 --- a/pkg/sentry/kernel/eventfd/eventfd_test.go +++ /dev/null @@ -1,78 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package eventfd - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/sentry/contexttest" - "gvisor.dev/gvisor/pkg/usermem" - "gvisor.dev/gvisor/pkg/waiter" -) - -func TestEventfd(t *testing.T) { - initVals := []uint64{ - 0, - // Using a non-zero initial value verifies that writing to an - // eventfd signals when the eventfd's counter was already - // non-zero. - 343, - } - - for _, initVal := range initVals { - ctx := contexttest.Context(t) - - // Make a new event that is writable. - event := New(ctx, initVal, false) - - // Register a callback for a write event. - w, ch := waiter.NewChannelEntry(nil) - event.EventRegister(&w, waiter.ReadableEvents) - defer event.EventUnregister(&w) - - data := []byte("00000124") - // Create and submit a write request. - n, err := event.Writev(ctx, usermem.BytesIOSequence(data)) - if err != nil { - t.Fatal(err) - } - if n != 8 { - t.Errorf("eventfd.write wrote %d bytes, not full int64", n) - } - - // Check if the callback fired due to the write event. - select { - case <-ch: - default: - t.Errorf("Didn't get notified of EventIn after write") - } - } -} - -func TestEventfdStat(t *testing.T) { - ctx := contexttest.Context(t) - - // Make a new event that is writable. - event := New(ctx, 0, false) - - // Create and submit an stat request. - uattr, err := event.Dirent.Inode.UnstableAttr(ctx) - if err != nil { - t.Fatalf("eventfd stat request failed: %v", err) - } - if uattr.Size != 0 { - t.Fatal("EventFD size should be 0") - } -} diff --git a/pkg/sentry/kernel/fasync/BUILD b/pkg/sentry/kernel/fasync/BUILD deleted file mode 100644 index 6b2dd09da..000000000 --- a/pkg/sentry/kernel/fasync/BUILD +++ /dev/null @@ -1,19 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "fasync", - srcs = ["fasync.go"], - visibility = ["//:sandbox"], - deps = [ - "//pkg/abi/linux", - "//pkg/errors/linuxerr", - "//pkg/sentry/fs", - "//pkg/sentry/kernel", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/vfs", - "//pkg/sync", - "//pkg/waiter", - ], -) diff --git a/pkg/sentry/kernel/fasync/fasync_state_autogen.go b/pkg/sentry/kernel/fasync/fasync_state_autogen.go new file mode 100644 index 000000000..8ec069b92 --- /dev/null +++ b/pkg/sentry/kernel/fasync/fasync_state_autogen.go @@ -0,0 +1,57 @@ +// automatically generated by stateify. + +package fasync + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (a *FileAsync) StateTypeName() string { + return "pkg/sentry/kernel/fasync.FileAsync" +} + +func (a *FileAsync) StateFields() []string { + return []string{ + "e", + "fd", + "requester", + "registered", + "signal", + "recipientPG", + "recipientTG", + "recipientT", + } +} + +func (a *FileAsync) beforeSave() {} + +// +checklocksignore +func (a *FileAsync) StateSave(stateSinkObject state.Sink) { + a.beforeSave() + stateSinkObject.Save(0, &a.e) + stateSinkObject.Save(1, &a.fd) + stateSinkObject.Save(2, &a.requester) + stateSinkObject.Save(3, &a.registered) + stateSinkObject.Save(4, &a.signal) + stateSinkObject.Save(5, &a.recipientPG) + stateSinkObject.Save(6, &a.recipientTG) + stateSinkObject.Save(7, &a.recipientT) +} + +func (a *FileAsync) afterLoad() {} + +// +checklocksignore +func (a *FileAsync) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &a.e) + stateSourceObject.Load(1, &a.fd) + stateSourceObject.Load(2, &a.requester) + stateSourceObject.Load(3, &a.registered) + stateSourceObject.Load(4, &a.signal) + stateSourceObject.Load(5, &a.recipientPG) + stateSourceObject.Load(6, &a.recipientTG) + stateSourceObject.Load(7, &a.recipientT) +} + +func init() { + state.Register((*FileAsync)(nil)) +} diff --git a/pkg/sentry/kernel/fd_table_refs.go b/pkg/sentry/kernel/fd_table_refs.go new file mode 100644 index 000000000..9f3bcc07a --- /dev/null +++ b/pkg/sentry/kernel/fd_table_refs.go @@ -0,0 +1,140 @@ +package kernel + +import ( + "fmt" + "sync/atomic" + + "gvisor.dev/gvisor/pkg/refsvfs2" +) + +// enableLogging indicates whether reference-related events should be logged (with +// stack traces). This is false by default and should only be set to true for +// debugging purposes, as it can generate an extremely large amount of output +// and drastically degrade performance. +const FDTableenableLogging = false + +// obj is used to customize logging. Note that we use a pointer to T so that +// we do not copy the entire object when passed as a format parameter. +var FDTableobj *FDTable + +// Refs implements refs.RefCounter. It keeps a reference count using atomic +// operations and calls the destructor when the count reaches zero. +// +// NOTE: Do not introduce additional fields to the Refs struct. It is used by +// many filesystem objects, and we want to keep it as small as possible (i.e., +// the same size as using an int64 directly) to avoid taking up extra cache +// space. In general, this template should not be extended at the cost of +// performance. If it does not offer enough flexibility for a particular object +// (example: b/187877947), we should implement the RefCounter/CheckedObject +// interfaces manually. +// +// +stateify savable +type FDTableRefs struct { + // refCount is composed of two fields: + // + // [32-bit speculative references]:[32-bit real references] + // + // Speculative references are used for TryIncRef, to avoid a CompareAndSwap + // loop. See IncRef, DecRef and TryIncRef for details of how these fields are + // used. + refCount int64 +} + +// InitRefs initializes r with one reference and, if enabled, activates leak +// checking. +func (r *FDTableRefs) InitRefs() { + atomic.StoreInt64(&r.refCount, 1) + refsvfs2.Register(r) +} + +// RefType implements refsvfs2.CheckedObject.RefType. +func (r *FDTableRefs) RefType() string { + return fmt.Sprintf("%T", FDTableobj)[1:] +} + +// LeakMessage implements refsvfs2.CheckedObject.LeakMessage. +func (r *FDTableRefs) LeakMessage() string { + return fmt.Sprintf("[%s %p] reference count of %d instead of 0", r.RefType(), r, r.ReadRefs()) +} + +// LogRefs implements refsvfs2.CheckedObject.LogRefs. +func (r *FDTableRefs) LogRefs() bool { + return FDTableenableLogging +} + +// ReadRefs returns the current number of references. The returned count is +// inherently racy and is unsafe to use without external synchronization. +func (r *FDTableRefs) ReadRefs() int64 { + return atomic.LoadInt64(&r.refCount) +} + +// IncRef implements refs.RefCounter.IncRef. +// +//go:nosplit +func (r *FDTableRefs) IncRef() { + v := atomic.AddInt64(&r.refCount, 1) + if FDTableenableLogging { + refsvfs2.LogIncRef(r, v) + } + if v <= 1 { + panic(fmt.Sprintf("Incrementing non-positive count %p on %s", r, r.RefType())) + } +} + +// TryIncRef implements refs.TryRefCounter.TryIncRef. +// +// To do this safely without a loop, a speculative reference is first acquired +// on the object. This allows multiple concurrent TryIncRef calls to distinguish +// other TryIncRef calls from genuine references held. +// +//go:nosplit +func (r *FDTableRefs) TryIncRef() bool { + const speculativeRef = 1 << 32 + if v := atomic.AddInt64(&r.refCount, speculativeRef); int32(v) == 0 { + + atomic.AddInt64(&r.refCount, -speculativeRef) + return false + } + + v := atomic.AddInt64(&r.refCount, -speculativeRef+1) + if FDTableenableLogging { + refsvfs2.LogTryIncRef(r, v) + } + return true +} + +// DecRef implements refs.RefCounter.DecRef. +// +// Note that speculative references are counted here. Since they were added +// prior to real references reaching zero, they will successfully convert to +// real references. In other words, we see speculative references only in the +// following case: +// +// A: TryIncRef [speculative increase => sees non-negative references] +// B: DecRef [real decrease] +// A: TryIncRef [transform speculative to real] +// +//go:nosplit +func (r *FDTableRefs) DecRef(destroy func()) { + v := atomic.AddInt64(&r.refCount, -1) + if FDTableenableLogging { + refsvfs2.LogDecRef(r, v) + } + switch { + case v < 0: + panic(fmt.Sprintf("Decrementing non-positive ref count %p, owned by %s", r, r.RefType())) + + case v == 0: + refsvfs2.Unregister(r) + + if destroy != nil { + destroy() + } + } +} + +func (r *FDTableRefs) afterLoad() { + if r.ReadRefs() > 0 { + refsvfs2.Register(r) + } +} diff --git a/pkg/sentry/kernel/fd_table_test.go b/pkg/sentry/kernel/fd_table_test.go deleted file mode 100644 index bf5460083..000000000 --- a/pkg/sentry/kernel/fd_table_test.go +++ /dev/null @@ -1,228 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package kernel - -import ( - "runtime" - "testing" - - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/sentry/contexttest" - "gvisor.dev/gvisor/pkg/sentry/fs" - "gvisor.dev/gvisor/pkg/sentry/fs/filetest" - "gvisor.dev/gvisor/pkg/sentry/limits" - "gvisor.dev/gvisor/pkg/sync" -) - -const ( - // maxFD is the maximum FD to try to create in the map. - // - // This number of open files has been seen in the wild. - maxFD = 2 * 1024 -) - -func runTest(t testing.TB, fn func(ctx context.Context, fdTable *FDTable, file *fs.File, limitSet *limits.LimitSet)) { - t.Helper() // Don't show in stacks. - - // Create the limits and context. - limitSet := limits.NewLimitSet() - limitSet.Set(limits.NumberOfFiles, limits.Limit{maxFD, maxFD}, true) - ctx := contexttest.WithLimitSet(contexttest.Context(t), limitSet) - - // Create a test file.; - file := filetest.NewTestFile(t) - - // Create the table. - fdTable := new(FDTable) - fdTable.init() - - // Run the test. - fn(ctx, fdTable, file, limitSet) -} - -// TestFDTableMany allocates maxFD FDs, i.e. maxes out the FDTable, until there -// is no room, then makes sure that NewFDAt works and also that if we remove -// one and add one that works too. -func TestFDTableMany(t *testing.T) { - runTest(t, func(ctx context.Context, fdTable *FDTable, file *fs.File, _ *limits.LimitSet) { - for i := 0; i < maxFD; i++ { - if _, err := fdTable.NewFDs(ctx, 0, []*fs.File{file}, FDFlags{}); err != nil { - t.Fatalf("Allocated %v FDs but wanted to allocate %v", i, maxFD) - } - } - - if _, err := fdTable.NewFDs(ctx, 0, []*fs.File{file}, FDFlags{}); err == nil { - t.Fatalf("fdTable.NewFDs(0, r) in full map: got nil, wanted error") - } - - if err := fdTable.NewFDAt(ctx, 1, file, FDFlags{}); err != nil { - t.Fatalf("fdTable.NewFDAt(1, r, FDFlags{}): got %v, wanted nil", err) - } - - i := int32(2) - fdTable.Remove(ctx, i) - if fds, err := fdTable.NewFDs(ctx, 0, []*fs.File{file}, FDFlags{}); err != nil || fds[0] != i { - t.Fatalf("Allocated %v FDs but wanted to allocate %v: %v", i, maxFD, err) - } - }) -} - -func TestFDTableOverLimit(t *testing.T) { - runTest(t, func(ctx context.Context, fdTable *FDTable, file *fs.File, _ *limits.LimitSet) { - if _, err := fdTable.NewFDs(ctx, maxFD, []*fs.File{file}, FDFlags{}); err == nil { - t.Fatalf("fdTable.NewFDs(maxFD, f): got nil, wanted error") - } - - if _, err := fdTable.NewFDs(ctx, maxFD-2, []*fs.File{file, file, file}, FDFlags{}); err == nil { - t.Fatalf("fdTable.NewFDs(maxFD-2, {f,f,f}): got nil, wanted error") - } - - if fds, err := fdTable.NewFDs(ctx, maxFD-3, []*fs.File{file, file, file}, FDFlags{}); err != nil { - t.Fatalf("fdTable.NewFDs(maxFD-3, {f,f,f}): got %v, wanted nil", err) - } else { - for _, fd := range fds { - fdTable.Remove(ctx, fd) - } - } - - if fds, err := fdTable.NewFDs(ctx, maxFD-1, []*fs.File{file}, FDFlags{}); err != nil || fds[0] != maxFD-1 { - t.Fatalf("fdTable.NewFDAt(1, r, FDFlags{}): got %v, wanted nil", err) - } - - if fds, err := fdTable.NewFDs(ctx, 0, []*fs.File{file}, FDFlags{}); err != nil { - t.Fatalf("Adding an FD to a resized map: got %v, want nil", err) - } else if len(fds) != 1 || fds[0] != 0 { - t.Fatalf("Added an FD to a resized map: got %v, want {1}", fds) - } - }) -} - -// TestFDTable does a set of simple tests to make sure simple adds, removes, -// GetRefs, and DecRefs work. The ordering is just weird enough that a -// table-driven approach seemed clumsy. -func TestFDTable(t *testing.T) { - runTest(t, func(ctx context.Context, fdTable *FDTable, file *fs.File, limitSet *limits.LimitSet) { - // Cap the limit at one. - limitSet.Set(limits.NumberOfFiles, limits.Limit{1, maxFD}, true) - - if _, err := fdTable.NewFDs(ctx, 0, []*fs.File{file}, FDFlags{}); err != nil { - t.Fatalf("Adding an FD to an empty 1-size map: got %v, want nil", err) - } - - if _, err := fdTable.NewFDs(ctx, 0, []*fs.File{file}, FDFlags{}); err == nil { - t.Fatalf("Adding an FD to a filled 1-size map: got nil, wanted an error") - } - - // Remove the previous limit. - limitSet.Set(limits.NumberOfFiles, limits.Limit{maxFD, maxFD}, true) - - if fds, err := fdTable.NewFDs(ctx, 0, []*fs.File{file}, FDFlags{}); err != nil { - t.Fatalf("Adding an FD to a resized map: got %v, want nil", err) - } else if len(fds) != 1 || fds[0] != 1 { - t.Fatalf("Added an FD to a resized map: got %v, want {1}", fds) - } - - if err := fdTable.NewFDAt(ctx, 1, file, FDFlags{}); err != nil { - t.Fatalf("Replacing FD 1 via fdTable.NewFDAt(1, r, FDFlags{}): got %v, wanted nil", err) - } - - if err := fdTable.NewFDAt(ctx, maxFD+1, file, FDFlags{}); err == nil { - t.Fatalf("Using an FD that was too large via fdTable.NewFDAt(%v, r, FDFlags{}): got nil, wanted an error", maxFD+1) - } - - if ref, _ := fdTable.Get(1); ref == nil { - t.Fatalf("fdTable.Get(1): got nil, wanted %v", file) - } - - if ref, _ := fdTable.Get(2); ref != nil { - t.Fatalf("fdTable.Get(2): got a %v, wanted nil", ref) - } - - ref, _ := fdTable.Remove(ctx, 1) - if ref == nil { - t.Fatalf("fdTable.Remove(1) for an existing FD: failed, want success") - } - ref.DecRef(ctx) - - if ref, _ := fdTable.Remove(ctx, 1); ref != nil { - t.Fatalf("r.Remove(1) for a removed FD: got success, want failure") - } - }) -} - -func TestDescriptorFlags(t *testing.T) { - runTest(t, func(ctx context.Context, fdTable *FDTable, file *fs.File, _ *limits.LimitSet) { - if err := fdTable.NewFDAt(ctx, 2, file, FDFlags{CloseOnExec: true}); err != nil { - t.Fatalf("fdTable.NewFDAt(2, r, FDFlags{}): got %v, wanted nil", err) - } - - newFile, flags := fdTable.Get(2) - if newFile == nil { - t.Fatalf("fdTable.Get(2): got a %v, wanted nil", newFile) - } - - if !flags.CloseOnExec { - t.Fatalf("new File flags %v don't match original %d\n", flags, 0) - } - }) -} - -func BenchmarkFDLookupAndDecRef(b *testing.B) { - b.StopTimer() // Setup. - - runTest(b, func(ctx context.Context, fdTable *FDTable, file *fs.File, _ *limits.LimitSet) { - fds, err := fdTable.NewFDs(ctx, 0, []*fs.File{file, file, file, file, file}, FDFlags{}) - if err != nil { - b.Fatalf("fdTable.NewFDs: got %v, wanted nil", err) - } - - b.StartTimer() // Benchmark. - for i := 0; i < b.N; i++ { - tf, _ := fdTable.Get(fds[i%len(fds)]) - tf.DecRef(ctx) - } - }) -} - -func BenchmarkFDLookupAndDecRefConcurrent(b *testing.B) { - b.StopTimer() // Setup. - - runTest(b, func(ctx context.Context, fdTable *FDTable, file *fs.File, _ *limits.LimitSet) { - fds, err := fdTable.NewFDs(ctx, 0, []*fs.File{file, file, file, file, file}, FDFlags{}) - if err != nil { - b.Fatalf("fdTable.NewFDs: got %v, wanted nil", err) - } - - concurrency := runtime.GOMAXPROCS(0) - if concurrency < 4 { - concurrency = 4 - } - each := b.N / concurrency - - b.StartTimer() // Benchmark. - var wg sync.WaitGroup - for i := 0; i < concurrency; i++ { - wg.Add(1) - go func() { - defer wg.Done() - for i := 0; i < each; i++ { - tf, _ := fdTable.Get(fds[i%len(fds)]) - tf.DecRef(ctx) - } - }() - } - wg.Wait() - }) -} diff --git a/pkg/sentry/kernel/fs_context_refs.go b/pkg/sentry/kernel/fs_context_refs.go new file mode 100644 index 000000000..8cbe70d1f --- /dev/null +++ b/pkg/sentry/kernel/fs_context_refs.go @@ -0,0 +1,140 @@ +package kernel + +import ( + "fmt" + "sync/atomic" + + "gvisor.dev/gvisor/pkg/refsvfs2" +) + +// enableLogging indicates whether reference-related events should be logged (with +// stack traces). This is false by default and should only be set to true for +// debugging purposes, as it can generate an extremely large amount of output +// and drastically degrade performance. +const FSContextenableLogging = false + +// obj is used to customize logging. Note that we use a pointer to T so that +// we do not copy the entire object when passed as a format parameter. +var FSContextobj *FSContext + +// Refs implements refs.RefCounter. It keeps a reference count using atomic +// operations and calls the destructor when the count reaches zero. +// +// NOTE: Do not introduce additional fields to the Refs struct. It is used by +// many filesystem objects, and we want to keep it as small as possible (i.e., +// the same size as using an int64 directly) to avoid taking up extra cache +// space. In general, this template should not be extended at the cost of +// performance. If it does not offer enough flexibility for a particular object +// (example: b/187877947), we should implement the RefCounter/CheckedObject +// interfaces manually. +// +// +stateify savable +type FSContextRefs struct { + // refCount is composed of two fields: + // + // [32-bit speculative references]:[32-bit real references] + // + // Speculative references are used for TryIncRef, to avoid a CompareAndSwap + // loop. See IncRef, DecRef and TryIncRef for details of how these fields are + // used. + refCount int64 +} + +// InitRefs initializes r with one reference and, if enabled, activates leak +// checking. +func (r *FSContextRefs) InitRefs() { + atomic.StoreInt64(&r.refCount, 1) + refsvfs2.Register(r) +} + +// RefType implements refsvfs2.CheckedObject.RefType. +func (r *FSContextRefs) RefType() string { + return fmt.Sprintf("%T", FSContextobj)[1:] +} + +// LeakMessage implements refsvfs2.CheckedObject.LeakMessage. +func (r *FSContextRefs) LeakMessage() string { + return fmt.Sprintf("[%s %p] reference count of %d instead of 0", r.RefType(), r, r.ReadRefs()) +} + +// LogRefs implements refsvfs2.CheckedObject.LogRefs. +func (r *FSContextRefs) LogRefs() bool { + return FSContextenableLogging +} + +// ReadRefs returns the current number of references. The returned count is +// inherently racy and is unsafe to use without external synchronization. +func (r *FSContextRefs) ReadRefs() int64 { + return atomic.LoadInt64(&r.refCount) +} + +// IncRef implements refs.RefCounter.IncRef. +// +//go:nosplit +func (r *FSContextRefs) IncRef() { + v := atomic.AddInt64(&r.refCount, 1) + if FSContextenableLogging { + refsvfs2.LogIncRef(r, v) + } + if v <= 1 { + panic(fmt.Sprintf("Incrementing non-positive count %p on %s", r, r.RefType())) + } +} + +// TryIncRef implements refs.TryRefCounter.TryIncRef. +// +// To do this safely without a loop, a speculative reference is first acquired +// on the object. This allows multiple concurrent TryIncRef calls to distinguish +// other TryIncRef calls from genuine references held. +// +//go:nosplit +func (r *FSContextRefs) TryIncRef() bool { + const speculativeRef = 1 << 32 + if v := atomic.AddInt64(&r.refCount, speculativeRef); int32(v) == 0 { + + atomic.AddInt64(&r.refCount, -speculativeRef) + return false + } + + v := atomic.AddInt64(&r.refCount, -speculativeRef+1) + if FSContextenableLogging { + refsvfs2.LogTryIncRef(r, v) + } + return true +} + +// DecRef implements refs.RefCounter.DecRef. +// +// Note that speculative references are counted here. Since they were added +// prior to real references reaching zero, they will successfully convert to +// real references. In other words, we see speculative references only in the +// following case: +// +// A: TryIncRef [speculative increase => sees non-negative references] +// B: DecRef [real decrease] +// A: TryIncRef [transform speculative to real] +// +//go:nosplit +func (r *FSContextRefs) DecRef(destroy func()) { + v := atomic.AddInt64(&r.refCount, -1) + if FSContextenableLogging { + refsvfs2.LogDecRef(r, v) + } + switch { + case v < 0: + panic(fmt.Sprintf("Decrementing non-positive ref count %p, owned by %s", r, r.RefType())) + + case v == 0: + refsvfs2.Unregister(r) + + if destroy != nil { + destroy() + } + } +} + +func (r *FSContextRefs) afterLoad() { + if r.ReadRefs() > 0 { + refsvfs2.Register(r) + } +} diff --git a/pkg/sentry/kernel/futex/BUILD b/pkg/sentry/kernel/futex/BUILD deleted file mode 100644 index c897e3a5f..000000000 --- a/pkg/sentry/kernel/futex/BUILD +++ /dev/null @@ -1,60 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -package(licenses = ["notice"]) - -go_template_instance( - name = "atomicptr_bucket", - out = "atomicptr_bucket_unsafe.go", - package = "futex", - suffix = "Bucket", - template = "//pkg/sync/atomicptr:generic_atomicptr", - types = { - "Value": "bucket", - }, -) - -go_template_instance( - name = "waiter_list", - out = "waiter_list.go", - package = "futex", - prefix = "waiter", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*Waiter", - "Linker": "*Waiter", - }, -) - -go_library( - name = "futex", - srcs = [ - "atomicptr_bucket_unsafe.go", - "futex.go", - "waiter_list.go", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/errors/linuxerr", - "//pkg/hostarch", - "//pkg/log", - "//pkg/sentry/memmap", - "//pkg/sync", - "//pkg/usermem", - ], -) - -go_test( - name = "futex_test", - size = "small", - srcs = ["futex_test.go"], - library = ":futex", - deps = [ - "//pkg/context", - "//pkg/errors/linuxerr", - "//pkg/hostarch", - "//pkg/sync", - ], -) diff --git a/pkg/sentry/kernel/futex/atomicptr_bucket_unsafe.go b/pkg/sentry/kernel/futex/atomicptr_bucket_unsafe.go new file mode 100644 index 000000000..d3fdf09b0 --- /dev/null +++ b/pkg/sentry/kernel/futex/atomicptr_bucket_unsafe.go @@ -0,0 +1,37 @@ +package futex + +import ( + "sync/atomic" + "unsafe" +) + +// An AtomicPtr is a pointer to a value of type Value that can be atomically +// loaded and stored. The zero value of an AtomicPtr represents nil. +// +// Note that copying AtomicPtr by value performs a non-atomic read of the +// stored pointer, which is unsafe if Store() can be called concurrently; in +// this case, do `dst.Store(src.Load())` instead. +// +// +stateify savable +type AtomicPtrBucket struct { + ptr unsafe.Pointer `state:".(*bucket)"` +} + +func (p *AtomicPtrBucket) savePtr() *bucket { + return p.Load() +} + +func (p *AtomicPtrBucket) loadPtr(v *bucket) { + p.Store(v) +} + +// Load returns the value set by the most recent Store. It returns nil if there +// has been no previous call to Store. +func (p *AtomicPtrBucket) Load() *bucket { + return (*bucket)(atomic.LoadPointer(&p.ptr)) +} + +// Store sets the value returned by Load to x. +func (p *AtomicPtrBucket) Store(x *bucket) { + atomic.StorePointer(&p.ptr, (unsafe.Pointer)(x)) +} diff --git a/pkg/sentry/kernel/futex/futex_state_autogen.go b/pkg/sentry/kernel/futex/futex_state_autogen.go new file mode 100644 index 000000000..6e22b313a --- /dev/null +++ b/pkg/sentry/kernel/futex/futex_state_autogen.go @@ -0,0 +1,122 @@ +// automatically generated by stateify. + +package futex + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (b *bucket) StateTypeName() string { + return "pkg/sentry/kernel/futex.bucket" +} + +func (b *bucket) StateFields() []string { + return []string{} +} + +func (b *bucket) beforeSave() {} + +// +checklocksignore +func (b *bucket) StateSave(stateSinkObject state.Sink) { + b.beforeSave() + if !state.IsZeroValue(&b.waiters) { + state.Failf("waiters is %#v, expected zero", &b.waiters) + } +} + +func (b *bucket) afterLoad() {} + +// +checklocksignore +func (b *bucket) StateLoad(stateSourceObject state.Source) { +} + +func (m *Manager) StateTypeName() string { + return "pkg/sentry/kernel/futex.Manager" +} + +func (m *Manager) StateFields() []string { + return []string{ + "sharedBucket", + } +} + +func (m *Manager) beforeSave() {} + +// +checklocksignore +func (m *Manager) StateSave(stateSinkObject state.Sink) { + m.beforeSave() + if !state.IsZeroValue(&m.privateBuckets) { + state.Failf("privateBuckets is %#v, expected zero", &m.privateBuckets) + } + stateSinkObject.Save(0, &m.sharedBucket) +} + +func (m *Manager) afterLoad() {} + +// +checklocksignore +func (m *Manager) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &m.sharedBucket) +} + +func (l *waiterList) StateTypeName() string { + return "pkg/sentry/kernel/futex.waiterList" +} + +func (l *waiterList) StateFields() []string { + return []string{ + "head", + "tail", + } +} + +func (l *waiterList) beforeSave() {} + +// +checklocksignore +func (l *waiterList) StateSave(stateSinkObject state.Sink) { + l.beforeSave() + stateSinkObject.Save(0, &l.head) + stateSinkObject.Save(1, &l.tail) +} + +func (l *waiterList) afterLoad() {} + +// +checklocksignore +func (l *waiterList) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &l.head) + stateSourceObject.Load(1, &l.tail) +} + +func (e *waiterEntry) StateTypeName() string { + return "pkg/sentry/kernel/futex.waiterEntry" +} + +func (e *waiterEntry) StateFields() []string { + return []string{ + "next", + "prev", + } +} + +func (e *waiterEntry) beforeSave() {} + +// +checklocksignore +func (e *waiterEntry) StateSave(stateSinkObject state.Sink) { + e.beforeSave() + stateSinkObject.Save(0, &e.next) + stateSinkObject.Save(1, &e.prev) +} + +func (e *waiterEntry) afterLoad() {} + +// +checklocksignore +func (e *waiterEntry) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &e.next) + stateSourceObject.Load(1, &e.prev) +} + +func init() { + state.Register((*bucket)(nil)) + state.Register((*Manager)(nil)) + state.Register((*waiterList)(nil)) + state.Register((*waiterEntry)(nil)) +} diff --git a/pkg/sentry/kernel/futex/futex_test.go b/pkg/sentry/kernel/futex/futex_test.go deleted file mode 100644 index 04c136f87..000000000 --- a/pkg/sentry/kernel/futex/futex_test.go +++ /dev/null @@ -1,536 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package futex - -import ( - "math" - "runtime" - "sync/atomic" - "testing" - "unsafe" - - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/errors/linuxerr" - "gvisor.dev/gvisor/pkg/hostarch" - "gvisor.dev/gvisor/pkg/sync" -) - -// testData implements the Target interface, and allows us to -// treat the address passed for futex operations as an index in -// a byte slice for testing simplicity. -type testData struct { - context.Context - data []byte -} - -const sizeofInt32 = 4 - -func newTestData(size uint) testData { - return testData{ - data: make([]byte, size), - } -} - -func (t testData) SwapUint32(addr hostarch.Addr, new uint32) (uint32, error) { - val := atomic.SwapUint32((*uint32)(unsafe.Pointer(&t.data[addr])), new) - return val, nil -} - -func (t testData) CompareAndSwapUint32(addr hostarch.Addr, old, new uint32) (uint32, error) { - if atomic.CompareAndSwapUint32((*uint32)(unsafe.Pointer(&t.data[addr])), old, new) { - return old, nil - } - return atomic.LoadUint32((*uint32)(unsafe.Pointer(&t.data[addr]))), nil -} - -func (t testData) LoadUint32(addr hostarch.Addr) (uint32, error) { - return atomic.LoadUint32((*uint32)(unsafe.Pointer(&t.data[addr]))), nil -} - -func (t testData) GetSharedKey(addr hostarch.Addr) (Key, error) { - return Key{ - Kind: KindSharedMappable, - Offset: uint64(addr), - }, nil -} - -func futexKind(private bool) string { - if private { - return "private" - } - return "shared" -} - -func newPreparedTestWaiter(t *testing.T, m *Manager, ta Target, addr hostarch.Addr, private bool, val uint32, bitmask uint32) *Waiter { - w := NewWaiter() - if err := m.WaitPrepare(w, ta, addr, private, val, bitmask); err != nil { - t.Fatalf("WaitPrepare failed: %v", err) - } - return w -} - -func TestFutexWake(t *testing.T) { - for _, private := range []bool{false, true} { - t.Run(futexKind(private), func(t *testing.T) { - m := NewManager() - d := newTestData(sizeofInt32) - - // Start waiting for wakeup. - w := newPreparedTestWaiter(t, m, d, 0, private, 0, ^uint32(0)) - defer m.WaitComplete(w, d) - - // Perform a wakeup. - if n, err := m.Wake(d, 0, private, ^uint32(0), 1); err != nil || n != 1 { - t.Errorf("Wake: got (%d, %v), wanted (1, nil)", n, err) - } - - // Expect the waiter to have been woken. - if !w.woken() { - t.Error("waiter not woken") - } - }) - } -} - -func TestFutexWakeBitmask(t *testing.T) { - for _, private := range []bool{false, true} { - t.Run(futexKind(private), func(t *testing.T) { - m := NewManager() - d := newTestData(sizeofInt32) - - // Start waiting for wakeup. - w := newPreparedTestWaiter(t, m, d, 0, private, 0, 0x0000ffff) - defer m.WaitComplete(w, d) - - // Perform a wakeup using the wrong bitmask. - if n, err := m.Wake(d, 0, private, 0xffff0000, 1); err != nil || n != 0 { - t.Errorf("Wake with non-matching bitmask: got (%d, %v), wanted (0, nil)", n, err) - } - - // Expect the waiter to still be waiting. - if w.woken() { - t.Error("waiter woken unexpectedly") - } - - // Perform a wakeup using the right bitmask. - if n, err := m.Wake(d, 0, private, 0x00000001, 1); err != nil || n != 1 { - t.Errorf("Wake with matching bitmask: got (%d, %v), wanted (1, nil)", n, err) - } - - // Expect that the waiter was woken. - if !w.woken() { - t.Error("waiter not woken") - } - }) - } -} - -func TestFutexWakeTwo(t *testing.T) { - for _, private := range []bool{false, true} { - t.Run(futexKind(private), func(t *testing.T) { - m := NewManager() - d := newTestData(sizeofInt32) - - // Start three waiters waiting for wakeup. - var ws [3]*Waiter - for i := range ws { - ws[i] = newPreparedTestWaiter(t, m, d, 0, private, 0, ^uint32(0)) - defer m.WaitComplete(ws[i], d) - } - - // Perform two wakeups. - const wakeups = 2 - if n, err := m.Wake(d, 0, private, ^uint32(0), 2); err != nil || n != wakeups { - t.Errorf("Wake: got (%d, %v), wanted (%d, nil)", n, err, wakeups) - } - - // Expect that exactly two waiters were woken. - // We don't get guarantees about exactly which two, - // (although we expect them to be w1 and w2). - awake := 0 - for i := range ws { - if ws[i].woken() { - awake++ - } - } - if awake != wakeups { - t.Errorf("got %d woken waiters, wanted %d", awake, wakeups) - } - }) - } -} - -func TestFutexWakeUnrelated(t *testing.T) { - for _, private := range []bool{false, true} { - t.Run(futexKind(private), func(t *testing.T) { - m := NewManager() - d := newTestData(2 * sizeofInt32) - - // Start two waiters waiting for wakeup on different addresses. - w1 := newPreparedTestWaiter(t, m, d, 0*sizeofInt32, private, 0, ^uint32(0)) - defer m.WaitComplete(w1, d) - w2 := newPreparedTestWaiter(t, m, d, 1*sizeofInt32, private, 0, ^uint32(0)) - defer m.WaitComplete(w2, d) - - // Perform two wakeups on the second address. - if n, err := m.Wake(d, 1*sizeofInt32, private, ^uint32(0), 2); err != nil || n != 1 { - t.Errorf("Wake: got (%d, %v), wanted (1, nil)", n, err) - } - - // Expect that only the second waiter was woken. - if w1.woken() { - t.Error("w1 woken unexpectedly") - } - if !w2.woken() { - t.Error("w2 not woken") - } - }) - } -} - -func TestWakeOpEmpty(t *testing.T) { - for _, private := range []bool{false, true} { - t.Run(futexKind(private), func(t *testing.T) { - m := NewManager() - d := newTestData(2 * sizeofInt32) - - // Perform wakeups with no waiters. - if n, err := m.WakeOp(d, 0, sizeofInt32, private, 10, 10, 0); err != nil || n != 0 { - t.Fatalf("WakeOp: got (%d, %v), wanted (0, nil)", n, err) - } - }) - } -} - -func TestWakeOpFirstNonEmpty(t *testing.T) { - for _, private := range []bool{false, true} { - t.Run(futexKind(private), func(t *testing.T) { - m := NewManager() - d := newTestData(8) - - // Add two waiters on address 0. - w1 := newPreparedTestWaiter(t, m, d, 0, private, 0, ^uint32(0)) - defer m.WaitComplete(w1, d) - w2 := newPreparedTestWaiter(t, m, d, 0, private, 0, ^uint32(0)) - defer m.WaitComplete(w2, d) - - // Perform 10 wakeups on address 0. - if n, err := m.WakeOp(d, 0, sizeofInt32, private, 10, 0, 0); err != nil || n != 2 { - t.Errorf("WakeOp: got (%d, %v), wanted (2, nil)", n, err) - } - - // Expect that both waiters were woken. - if !w1.woken() { - t.Error("w1 not woken") - } - if !w2.woken() { - t.Error("w2 not woken") - } - }) - } -} - -func TestWakeOpSecondNonEmpty(t *testing.T) { - for _, private := range []bool{false, true} { - t.Run(futexKind(private), func(t *testing.T) { - m := NewManager() - d := newTestData(8) - - // Add two waiters on address sizeofInt32. - w1 := newPreparedTestWaiter(t, m, d, sizeofInt32, private, 0, ^uint32(0)) - defer m.WaitComplete(w1, d) - w2 := newPreparedTestWaiter(t, m, d, sizeofInt32, private, 0, ^uint32(0)) - defer m.WaitComplete(w2, d) - - // Perform 10 wakeups on address sizeofInt32 (contingent on - // d.Op(0), which should succeed). - if n, err := m.WakeOp(d, 0, sizeofInt32, private, 0, 10, 0); err != nil || n != 2 { - t.Errorf("WakeOp: got (%d, %v), wanted (2, nil)", n, err) - } - - // Expect that both waiters were woken. - if !w1.woken() { - t.Error("w1 not woken") - } - if !w2.woken() { - t.Error("w2 not woken") - } - }) - } -} - -func TestWakeOpSecondNonEmptyFailingOp(t *testing.T) { - for _, private := range []bool{false, true} { - t.Run(futexKind(private), func(t *testing.T) { - m := NewManager() - d := newTestData(8) - - // Add two waiters on address sizeofInt32. - w1 := newPreparedTestWaiter(t, m, d, sizeofInt32, private, 0, ^uint32(0)) - defer m.WaitComplete(w1, d) - w2 := newPreparedTestWaiter(t, m, d, sizeofInt32, private, 0, ^uint32(0)) - defer m.WaitComplete(w2, d) - - // Perform 10 wakeups on address sizeofInt32 (contingent on - // d.Op(1), which should fail). - if n, err := m.WakeOp(d, 0, sizeofInt32, private, 0, 10, 1); err != nil || n != 0 { - t.Errorf("WakeOp: got (%d, %v), wanted (0, nil)", n, err) - } - - // Expect that neither waiter was woken. - if w1.woken() { - t.Error("w1 woken unexpectedly") - } - if w2.woken() { - t.Error("w2 woken unexpectedly") - } - }) - } -} - -func TestWakeOpAllNonEmpty(t *testing.T) { - for _, private := range []bool{false, true} { - t.Run(futexKind(private), func(t *testing.T) { - m := NewManager() - d := newTestData(8) - - // Add two waiters on address 0. - w1 := newPreparedTestWaiter(t, m, d, 0, private, 0, ^uint32(0)) - defer m.WaitComplete(w1, d) - w2 := newPreparedTestWaiter(t, m, d, 0, private, 0, ^uint32(0)) - defer m.WaitComplete(w2, d) - - // Add two waiters on address sizeofInt32. - w3 := newPreparedTestWaiter(t, m, d, sizeofInt32, private, 0, ^uint32(0)) - defer m.WaitComplete(w3, d) - w4 := newPreparedTestWaiter(t, m, d, sizeofInt32, private, 0, ^uint32(0)) - defer m.WaitComplete(w4, d) - - // Perform 10 wakeups on address 0 (unconditionally), and 10 - // wakeups on address sizeofInt32 (contingent on d.Op(0), which - // should succeed). - if n, err := m.WakeOp(d, 0, sizeofInt32, private, 10, 10, 0); err != nil || n != 4 { - t.Errorf("WakeOp: got (%d, %v), wanted (4, nil)", n, err) - } - - // Expect that all waiters were woken. - if !w1.woken() { - t.Error("w1 not woken") - } - if !w2.woken() { - t.Error("w2 not woken") - } - if !w3.woken() { - t.Error("w3 not woken") - } - if !w4.woken() { - t.Error("w4 not woken") - } - }) - } -} - -func TestWakeOpAllNonEmptyFailingOp(t *testing.T) { - for _, private := range []bool{false, true} { - t.Run(futexKind(private), func(t *testing.T) { - m := NewManager() - d := newTestData(8) - - // Add two waiters on address 0. - w1 := newPreparedTestWaiter(t, m, d, 0, private, 0, ^uint32(0)) - defer m.WaitComplete(w1, d) - w2 := newPreparedTestWaiter(t, m, d, 0, private, 0, ^uint32(0)) - defer m.WaitComplete(w2, d) - - // Add two waiters on address sizeofInt32. - w3 := newPreparedTestWaiter(t, m, d, sizeofInt32, private, 0, ^uint32(0)) - defer m.WaitComplete(w3, d) - w4 := newPreparedTestWaiter(t, m, d, sizeofInt32, private, 0, ^uint32(0)) - defer m.WaitComplete(w4, d) - - // Perform 10 wakeups on address 0 (unconditionally), and 10 - // wakeups on address sizeofInt32 (contingent on d.Op(1), which - // should fail). - if n, err := m.WakeOp(d, 0, sizeofInt32, private, 10, 10, 1); err != nil || n != 2 { - t.Errorf("WakeOp: got (%d, %v), wanted (2, nil)", n, err) - } - - // Expect that only the first two waiters were woken. - if !w1.woken() { - t.Error("w1 not woken") - } - if !w2.woken() { - t.Error("w2 not woken") - } - if w3.woken() { - t.Error("w3 woken unexpectedly") - } - if w4.woken() { - t.Error("w4 woken unexpectedly") - } - }) - } -} - -func TestWakeOpSameAddress(t *testing.T) { - for _, private := range []bool{false, true} { - t.Run(futexKind(private), func(t *testing.T) { - m := NewManager() - d := newTestData(8) - - // Add four waiters on address 0. - var ws [4]*Waiter - for i := range ws { - ws[i] = newPreparedTestWaiter(t, m, d, 0, private, 0, ^uint32(0)) - defer m.WaitComplete(ws[i], d) - } - - // Perform 1 wakeup on address 0 (unconditionally), and 1 wakeup - // on address 0 (contingent on d.Op(0), which should succeed). - const wakeups = 2 - if n, err := m.WakeOp(d, 0, 0, private, 1, 1, 0); err != nil || n != wakeups { - t.Errorf("WakeOp: got (%d, %v), wanted (%d, nil)", n, err, wakeups) - } - - // Expect that exactly two waiters were woken. - awake := 0 - for i := range ws { - if ws[i].woken() { - awake++ - } - } - if awake != wakeups { - t.Errorf("got %d woken waiters, wanted %d", awake, wakeups) - } - }) - } -} - -func TestWakeOpSameAddressFailingOp(t *testing.T) { - for _, private := range []bool{false, true} { - t.Run(futexKind(private), func(t *testing.T) { - m := NewManager() - d := newTestData(8) - - // Add four waiters on address 0. - var ws [4]*Waiter - for i := range ws { - ws[i] = newPreparedTestWaiter(t, m, d, 0, private, 0, ^uint32(0)) - defer m.WaitComplete(ws[i], d) - } - - // Perform 1 wakeup on address 0 (unconditionally), and 1 wakeup - // on address 0 (contingent on d.Op(1), which should fail). - const wakeups = 1 - if n, err := m.WakeOp(d, 0, 0, private, 1, 1, 1); err != nil || n != wakeups { - t.Errorf("WakeOp: got (%d, %v), wanted (%d, nil)", n, err, wakeups) - } - - // Expect that exactly one waiter was woken. - awake := 0 - for i := range ws { - if ws[i].woken() { - awake++ - } - } - if awake != wakeups { - t.Errorf("got %d woken waiters, wanted %d", awake, wakeups) - } - }) - } -} - -const ( - testMutexSize = sizeofInt32 - testMutexLocked uint32 = 1 - testMutexUnlocked uint32 = 0 -) - -// testMutex ties together a testData slice, an address, and a -// futex manager in order to implement the sync.Locker interface. -// Beyond being used as a Locker, this is a simple mechanism for -// changing the underlying values for simpler tests. -type testMutex struct { - a hostarch.Addr - d testData - m *Manager -} - -func newTestMutex(addr hostarch.Addr, d testData, m *Manager) *testMutex { - return &testMutex{a: addr, d: d, m: m} -} - -// Lock acquires the testMutex. -// This may wait for it to be available via the futex manager. -func (t *testMutex) Lock() { - for { - // Attempt to grab the lock. - if atomic.CompareAndSwapUint32( - (*uint32)(unsafe.Pointer(&t.d.data[t.a])), - testMutexUnlocked, - testMutexLocked) { - // Lock held. - return - } - - // Wait for it to be "not locked". - w := NewWaiter() - err := t.m.WaitPrepare(w, t.d, t.a, true, testMutexLocked, ^uint32(0)) - if linuxerr.Equals(linuxerr.EAGAIN, err) { - continue - } - if err != nil { - // Should never happen. - panic("WaitPrepare returned unexpected error: " + err.Error()) - } - <-w.C - t.m.WaitComplete(w, t.d) - } -} - -// Unlock releases the testMutex. -// This will notify any waiters via the futex manager. -func (t *testMutex) Unlock() { - // Unlock. - atomic.StoreUint32((*uint32)(unsafe.Pointer(&t.d.data[t.a])), testMutexUnlocked) - - // Notify all waiters. - t.m.Wake(t.d, t.a, true, ^uint32(0), math.MaxInt32) -} - -// This function was shamelessly stolen from mutex_test.go. -func HammerMutex(l sync.Locker, loops int, cdone chan bool) { - for i := 0; i < loops; i++ { - l.Lock() - runtime.Gosched() - l.Unlock() - } - cdone <- true -} - -func TestMutexStress(t *testing.T) { - m := NewManager() - d := newTestData(testMutexSize) - tm := newTestMutex(0*testMutexSize, d, m) - c := make(chan bool) - - for i := 0; i < 10; i++ { - go HammerMutex(tm, 1000, c) - } - - for i := 0; i < 10; i++ { - <-c - } -} diff --git a/pkg/sentry/kernel/futex/futex_unsafe_state_autogen.go b/pkg/sentry/kernel/futex/futex_unsafe_state_autogen.go new file mode 100644 index 000000000..b06acb209 --- /dev/null +++ b/pkg/sentry/kernel/futex/futex_unsafe_state_autogen.go @@ -0,0 +1,38 @@ +// automatically generated by stateify. + +package futex + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (p *AtomicPtrBucket) StateTypeName() string { + return "pkg/sentry/kernel/futex.AtomicPtrBucket" +} + +func (p *AtomicPtrBucket) StateFields() []string { + return []string{ + "ptr", + } +} + +func (p *AtomicPtrBucket) beforeSave() {} + +// +checklocksignore +func (p *AtomicPtrBucket) StateSave(stateSinkObject state.Sink) { + p.beforeSave() + var ptrValue *bucket + ptrValue = p.savePtr() + stateSinkObject.SaveValue(0, ptrValue) +} + +func (p *AtomicPtrBucket) afterLoad() {} + +// +checklocksignore +func (p *AtomicPtrBucket) StateLoad(stateSourceObject state.Source) { + stateSourceObject.LoadValue(0, new(*bucket), func(y interface{}) { p.loadPtr(y.(*bucket)) }) +} + +func init() { + state.Register((*AtomicPtrBucket)(nil)) +} diff --git a/pkg/sentry/kernel/futex/waiter_list.go b/pkg/sentry/kernel/futex/waiter_list.go new file mode 100644 index 000000000..24968ce4b --- /dev/null +++ b/pkg/sentry/kernel/futex/waiter_list.go @@ -0,0 +1,221 @@ +package futex + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type waiterElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (waiterElementMapper) linkerFor(elem *Waiter) *Waiter { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type waiterList struct { + head *Waiter + tail *Waiter +} + +// Reset resets list l to the empty state. +func (l *waiterList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +// +//go:nosplit +func (l *waiterList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +// +//go:nosplit +func (l *waiterList) Front() *Waiter { + return l.head +} + +// Back returns the last element of list l or nil. +// +//go:nosplit +func (l *waiterList) Back() *Waiter { + return l.tail +} + +// Len returns the number of elements in the list. +// +// NOTE: This is an O(n) operation. +// +//go:nosplit +func (l *waiterList) Len() (count int) { + for e := l.Front(); e != nil; e = (waiterElementMapper{}.linkerFor(e)).Next() { + count++ + } + return count +} + +// PushFront inserts the element e at the front of list l. +// +//go:nosplit +func (l *waiterList) PushFront(e *Waiter) { + linker := waiterElementMapper{}.linkerFor(e) + linker.SetNext(l.head) + linker.SetPrev(nil) + if l.head != nil { + waiterElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +// +//go:nosplit +func (l *waiterList) PushBack(e *Waiter) { + linker := waiterElementMapper{}.linkerFor(e) + linker.SetNext(nil) + linker.SetPrev(l.tail) + if l.tail != nil { + waiterElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +// +//go:nosplit +func (l *waiterList) PushBackList(m *waiterList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + waiterElementMapper{}.linkerFor(l.tail).SetNext(m.head) + waiterElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +// +//go:nosplit +func (l *waiterList) InsertAfter(b, e *Waiter) { + bLinker := waiterElementMapper{}.linkerFor(b) + eLinker := waiterElementMapper{}.linkerFor(e) + + a := bLinker.Next() + + eLinker.SetNext(a) + eLinker.SetPrev(b) + bLinker.SetNext(e) + + if a != nil { + waiterElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +// +//go:nosplit +func (l *waiterList) InsertBefore(a, e *Waiter) { + aLinker := waiterElementMapper{}.linkerFor(a) + eLinker := waiterElementMapper{}.linkerFor(e) + + b := aLinker.Prev() + eLinker.SetNext(a) + eLinker.SetPrev(b) + aLinker.SetPrev(e) + + if b != nil { + waiterElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +// +//go:nosplit +func (l *waiterList) Remove(e *Waiter) { + linker := waiterElementMapper{}.linkerFor(e) + prev := linker.Prev() + next := linker.Next() + + if prev != nil { + waiterElementMapper{}.linkerFor(prev).SetNext(next) + } else if l.head == e { + l.head = next + } + + if next != nil { + waiterElementMapper{}.linkerFor(next).SetPrev(prev) + } else if l.tail == e { + l.tail = prev + } + + linker.SetNext(nil) + linker.SetPrev(nil) +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type waiterEntry struct { + next *Waiter + prev *Waiter +} + +// Next returns the entry that follows e in the list. +// +//go:nosplit +func (e *waiterEntry) Next() *Waiter { + return e.next +} + +// Prev returns the entry that precedes e in the list. +// +//go:nosplit +func (e *waiterEntry) Prev() *Waiter { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +// +//go:nosplit +func (e *waiterEntry) SetNext(elem *Waiter) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +// +//go:nosplit +func (e *waiterEntry) SetPrev(elem *Waiter) { + e.prev = elem +} diff --git a/pkg/sentry/kernel/g3doc/run_states.dot b/pkg/sentry/kernel/g3doc/run_states.dot deleted file mode 100644 index 7861fe1f5..000000000 --- a/pkg/sentry/kernel/g3doc/run_states.dot +++ /dev/null @@ -1,99 +0,0 @@ -digraph { - subgraph { - App; - } - subgraph { - Interrupt; - InterruptAfterSignalDeliveryStop; - } - subgraph { - Syscall; - SyscallAfterPtraceEventSeccomp; - SyscallEnter; - SyscallAfterSyscallEnterStop; - SyscallAfterSysemuStop; - SyscallInvoke; - SyscallAfterPtraceEventClone; - SyscallAfterExecStop; - SyscallAfterVforkStop; - SyscallReinvoke; - SyscallExit; - } - subgraph { - Vsyscall; - VsyscallAfterPtraceEventSeccomp; - VsyscallInvoke; - } - subgraph { - Exit; - ExitMain; // leave thread group, release resources, reparent children, kill PID namespace and wait if TGID 1 - ExitNotify; // signal parent/tracer, become waitable - ExitDone; // represented by t.runState == nil - } - - // Task exit - Exit -> ExitMain; - ExitMain -> ExitNotify; - ExitNotify -> ExitDone; - - // Execution of untrusted application code - App -> App; - - // Interrupts (usually signal delivery) - App -> Interrupt; - Interrupt -> Interrupt; // if other interrupt conditions may still apply - Interrupt -> Exit; // if killed - - // Syscalls - App -> Syscall; - Syscall -> SyscallEnter; - SyscallEnter -> SyscallInvoke; - SyscallInvoke -> SyscallExit; - SyscallExit -> App; - - // exit, exit_group - SyscallInvoke -> Exit; - - // execve - SyscallInvoke -> SyscallAfterExecStop; - SyscallAfterExecStop -> SyscallExit; - SyscallAfterExecStop -> App; // fatal signal pending - - // vfork - SyscallInvoke -> SyscallAfterVforkStop; - SyscallAfterVforkStop -> SyscallExit; - - // Vsyscalls - App -> Vsyscall; - Vsyscall -> VsyscallInvoke; - Vsyscall -> App; // fault while reading return address from stack - VsyscallInvoke -> App; - - // ptrace-specific branches - Interrupt -> InterruptAfterSignalDeliveryStop; - InterruptAfterSignalDeliveryStop -> Interrupt; - SyscallEnter -> SyscallAfterSyscallEnterStop; - SyscallAfterSyscallEnterStop -> SyscallInvoke; - SyscallAfterSyscallEnterStop -> SyscallExit; // skipped by tracer - SyscallAfterSyscallEnterStop -> App; // fatal signal pending - SyscallEnter -> SyscallAfterSysemuStop; - SyscallAfterSysemuStop -> SyscallExit; - SyscallAfterSysemuStop -> App; // fatal signal pending - SyscallInvoke -> SyscallAfterPtraceEventClone; - SyscallAfterPtraceEventClone -> SyscallExit; - SyscallAfterPtraceEventClone -> SyscallAfterVforkStop; - - // seccomp - Syscall -> App; // SECCOMP_RET_TRAP, SECCOMP_RET_ERRNO, SECCOMP_RET_KILL, SECCOMP_RET_TRACE without tracer - Syscall -> SyscallAfterPtraceEventSeccomp; // SECCOMP_RET_TRACE - SyscallAfterPtraceEventSeccomp -> SyscallEnter; - SyscallAfterPtraceEventSeccomp -> SyscallExit; // skipped by tracer - SyscallAfterPtraceEventSeccomp -> App; // fatal signal pending - Vsyscall -> VsyscallAfterPtraceEventSeccomp; - VsyscallAfterPtraceEventSeccomp -> VsyscallInvoke; - VsyscallAfterPtraceEventSeccomp -> App; - - // Autosave - SyscallInvoke -> SyscallReinvoke; - SyscallReinvoke -> SyscallInvoke; -} diff --git a/pkg/sentry/kernel/g3doc/run_states.png b/pkg/sentry/kernel/g3doc/run_states.png Binary files differdeleted file mode 100644 index b63b60f02..000000000 --- a/pkg/sentry/kernel/g3doc/run_states.png +++ /dev/null diff --git a/pkg/sentry/kernel/ipc/BUILD b/pkg/sentry/kernel/ipc/BUILD deleted file mode 100644 index e42a94e15..000000000 --- a/pkg/sentry/kernel/ipc/BUILD +++ /dev/null @@ -1,20 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "ipc", - srcs = [ - "object.go", - "registry.go", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/errors/linuxerr", - "//pkg/log", - "//pkg/sentry/fs", - "//pkg/sentry/kernel/auth", - ], -) diff --git a/pkg/sentry/kernel/ipc/ipc_state_autogen.go b/pkg/sentry/kernel/ipc/ipc_state_autogen.go new file mode 100644 index 000000000..b74f23a21 --- /dev/null +++ b/pkg/sentry/kernel/ipc/ipc_state_autogen.go @@ -0,0 +1,86 @@ +// automatically generated by stateify. + +package ipc + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (o *Object) StateTypeName() string { + return "pkg/sentry/kernel/ipc.Object" +} + +func (o *Object) StateFields() []string { + return []string{ + "UserNS", + "ID", + "Key", + "Creator", + "Owner", + "Perms", + } +} + +func (o *Object) beforeSave() {} + +// +checklocksignore +func (o *Object) StateSave(stateSinkObject state.Sink) { + o.beforeSave() + stateSinkObject.Save(0, &o.UserNS) + stateSinkObject.Save(1, &o.ID) + stateSinkObject.Save(2, &o.Key) + stateSinkObject.Save(3, &o.Creator) + stateSinkObject.Save(4, &o.Owner) + stateSinkObject.Save(5, &o.Perms) +} + +func (o *Object) afterLoad() {} + +// +checklocksignore +func (o *Object) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &o.UserNS) + stateSourceObject.Load(1, &o.ID) + stateSourceObject.Load(2, &o.Key) + stateSourceObject.Load(3, &o.Creator) + stateSourceObject.Load(4, &o.Owner) + stateSourceObject.Load(5, &o.Perms) +} + +func (r *Registry) StateTypeName() string { + return "pkg/sentry/kernel/ipc.Registry" +} + +func (r *Registry) StateFields() []string { + return []string{ + "UserNS", + "objects", + "keysToIDs", + "lastIDUsed", + } +} + +func (r *Registry) beforeSave() {} + +// +checklocksignore +func (r *Registry) StateSave(stateSinkObject state.Sink) { + r.beforeSave() + stateSinkObject.Save(0, &r.UserNS) + stateSinkObject.Save(1, &r.objects) + stateSinkObject.Save(2, &r.keysToIDs) + stateSinkObject.Save(3, &r.lastIDUsed) +} + +func (r *Registry) afterLoad() {} + +// +checklocksignore +func (r *Registry) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &r.UserNS) + stateSourceObject.Load(1, &r.objects) + stateSourceObject.Load(2, &r.keysToIDs) + stateSourceObject.Load(3, &r.lastIDUsed) +} + +func init() { + state.Register((*Object)(nil)) + state.Register((*Registry)(nil)) +} diff --git a/pkg/sentry/kernel/ipc_namespace_refs.go b/pkg/sentry/kernel/ipc_namespace_refs.go new file mode 100644 index 000000000..1a4c31bb0 --- /dev/null +++ b/pkg/sentry/kernel/ipc_namespace_refs.go @@ -0,0 +1,140 @@ +package kernel + +import ( + "fmt" + "sync/atomic" + + "gvisor.dev/gvisor/pkg/refsvfs2" +) + +// enableLogging indicates whether reference-related events should be logged (with +// stack traces). This is false by default and should only be set to true for +// debugging purposes, as it can generate an extremely large amount of output +// and drastically degrade performance. +const IPCNamespaceenableLogging = false + +// obj is used to customize logging. Note that we use a pointer to T so that +// we do not copy the entire object when passed as a format parameter. +var IPCNamespaceobj *IPCNamespace + +// Refs implements refs.RefCounter. It keeps a reference count using atomic +// operations and calls the destructor when the count reaches zero. +// +// NOTE: Do not introduce additional fields to the Refs struct. It is used by +// many filesystem objects, and we want to keep it as small as possible (i.e., +// the same size as using an int64 directly) to avoid taking up extra cache +// space. In general, this template should not be extended at the cost of +// performance. If it does not offer enough flexibility for a particular object +// (example: b/187877947), we should implement the RefCounter/CheckedObject +// interfaces manually. +// +// +stateify savable +type IPCNamespaceRefs struct { + // refCount is composed of two fields: + // + // [32-bit speculative references]:[32-bit real references] + // + // Speculative references are used for TryIncRef, to avoid a CompareAndSwap + // loop. See IncRef, DecRef and TryIncRef for details of how these fields are + // used. + refCount int64 +} + +// InitRefs initializes r with one reference and, if enabled, activates leak +// checking. +func (r *IPCNamespaceRefs) InitRefs() { + atomic.StoreInt64(&r.refCount, 1) + refsvfs2.Register(r) +} + +// RefType implements refsvfs2.CheckedObject.RefType. +func (r *IPCNamespaceRefs) RefType() string { + return fmt.Sprintf("%T", IPCNamespaceobj)[1:] +} + +// LeakMessage implements refsvfs2.CheckedObject.LeakMessage. +func (r *IPCNamespaceRefs) LeakMessage() string { + return fmt.Sprintf("[%s %p] reference count of %d instead of 0", r.RefType(), r, r.ReadRefs()) +} + +// LogRefs implements refsvfs2.CheckedObject.LogRefs. +func (r *IPCNamespaceRefs) LogRefs() bool { + return IPCNamespaceenableLogging +} + +// ReadRefs returns the current number of references. The returned count is +// inherently racy and is unsafe to use without external synchronization. +func (r *IPCNamespaceRefs) ReadRefs() int64 { + return atomic.LoadInt64(&r.refCount) +} + +// IncRef implements refs.RefCounter.IncRef. +// +//go:nosplit +func (r *IPCNamespaceRefs) IncRef() { + v := atomic.AddInt64(&r.refCount, 1) + if IPCNamespaceenableLogging { + refsvfs2.LogIncRef(r, v) + } + if v <= 1 { + panic(fmt.Sprintf("Incrementing non-positive count %p on %s", r, r.RefType())) + } +} + +// TryIncRef implements refs.TryRefCounter.TryIncRef. +// +// To do this safely without a loop, a speculative reference is first acquired +// on the object. This allows multiple concurrent TryIncRef calls to distinguish +// other TryIncRef calls from genuine references held. +// +//go:nosplit +func (r *IPCNamespaceRefs) TryIncRef() bool { + const speculativeRef = 1 << 32 + if v := atomic.AddInt64(&r.refCount, speculativeRef); int32(v) == 0 { + + atomic.AddInt64(&r.refCount, -speculativeRef) + return false + } + + v := atomic.AddInt64(&r.refCount, -speculativeRef+1) + if IPCNamespaceenableLogging { + refsvfs2.LogTryIncRef(r, v) + } + return true +} + +// DecRef implements refs.RefCounter.DecRef. +// +// Note that speculative references are counted here. Since they were added +// prior to real references reaching zero, they will successfully convert to +// real references. In other words, we see speculative references only in the +// following case: +// +// A: TryIncRef [speculative increase => sees non-negative references] +// B: DecRef [real decrease] +// A: TryIncRef [transform speculative to real] +// +//go:nosplit +func (r *IPCNamespaceRefs) DecRef(destroy func()) { + v := atomic.AddInt64(&r.refCount, -1) + if IPCNamespaceenableLogging { + refsvfs2.LogDecRef(r, v) + } + switch { + case v < 0: + panic(fmt.Sprintf("Decrementing non-positive ref count %p, owned by %s", r, r.RefType())) + + case v == 0: + refsvfs2.Unregister(r) + + if destroy != nil { + destroy() + } + } +} + +func (r *IPCNamespaceRefs) afterLoad() { + if r.ReadRefs() > 0 { + refsvfs2.Register(r) + } +} diff --git a/pkg/sentry/kernel/kernel_abi_autogen_unsafe.go b/pkg/sentry/kernel/kernel_abi_autogen_unsafe.go new file mode 100644 index 000000000..83ed0f8ac --- /dev/null +++ b/pkg/sentry/kernel/kernel_abi_autogen_unsafe.go @@ -0,0 +1,224 @@ +// Automatically generated marshal implementation. See tools/go_marshal. + +package kernel + +import ( + "gvisor.dev/gvisor/pkg/gohacks" + "gvisor.dev/gvisor/pkg/hostarch" + "gvisor.dev/gvisor/pkg/marshal" + "io" + "reflect" + "runtime" + "unsafe" +) + +// Marshallable types used by this file. +var _ marshal.Marshallable = (*ThreadID)(nil) +var _ marshal.Marshallable = (*vdsoParams)(nil) + +// SizeBytes implements marshal.Marshallable.SizeBytes. +//go:nosplit +func (tid *ThreadID) SizeBytes() int { + return 4 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (tid *ThreadID) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint32(dst[:4], uint32(*tid)) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (tid *ThreadID) UnmarshalBytes(src []byte) { + *tid = ThreadID(int32(hostarch.ByteOrder.Uint32(src[:4]))) +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (tid *ThreadID) Packed() bool { + // Scalar newtypes are always packed. + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (tid *ThreadID) MarshalUnsafe(dst []byte) { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(tid), uintptr(tid.SizeBytes())) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (tid *ThreadID) UnmarshalUnsafe(src []byte) { + gohacks.Memmove(unsafe.Pointer(tid), unsafe.Pointer(&src[0]), uintptr(tid.SizeBytes())) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (tid *ThreadID) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(tid))) + hdr.Len = tid.SizeBytes() + hdr.Cap = tid.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that tid + // must live until the use above. + runtime.KeepAlive(tid) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (tid *ThreadID) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return tid.CopyOutN(cc, addr, tid.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (tid *ThreadID) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(tid))) + hdr.Len = tid.SizeBytes() + hdr.Cap = tid.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that tid + // must live until the use above. + runtime.KeepAlive(tid) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (tid *ThreadID) WriteTo(w io.Writer) (int64, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(tid))) + hdr.Len = tid.SizeBytes() + hdr.Cap = tid.SizeBytes() + + length, err := w.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that tid + // must live until the use above. + runtime.KeepAlive(tid) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (v *vdsoParams) SizeBytes() int { + return 64 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (v *vdsoParams) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint64(dst[:8], uint64(v.monotonicReady)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(v.monotonicBaseCycles)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(v.monotonicBaseRef)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(v.monotonicFrequency)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(v.realtimeReady)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(v.realtimeBaseCycles)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(v.realtimeBaseRef)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(v.realtimeFrequency)) + dst = dst[8:] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (v *vdsoParams) UnmarshalBytes(src []byte) { + v.monotonicReady = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + v.monotonicBaseCycles = int64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + v.monotonicBaseRef = int64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + v.monotonicFrequency = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + v.realtimeReady = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + v.realtimeBaseCycles = int64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + v.realtimeBaseRef = int64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + v.realtimeFrequency = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (v *vdsoParams) Packed() bool { + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (v *vdsoParams) MarshalUnsafe(dst []byte) { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(v), uintptr(v.SizeBytes())) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (v *vdsoParams) UnmarshalUnsafe(src []byte) { + gohacks.Memmove(unsafe.Pointer(v), unsafe.Pointer(&src[0]), uintptr(v.SizeBytes())) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (v *vdsoParams) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(v))) + hdr.Len = v.SizeBytes() + hdr.Cap = v.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that v + // must live until the use above. + runtime.KeepAlive(v) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (v *vdsoParams) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return v.CopyOutN(cc, addr, v.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (v *vdsoParams) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(v))) + hdr.Len = v.SizeBytes() + hdr.Cap = v.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that v + // must live until the use above. + runtime.KeepAlive(v) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (v *vdsoParams) WriteTo(writer io.Writer) (int64, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(v))) + hdr.Len = v.SizeBytes() + hdr.Cap = v.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that v + // must live until the use above. + runtime.KeepAlive(v) // escapes: replaced by intrinsic. + return int64(length), err +} + diff --git a/pkg/sentry/kernel/kernel_amd64_abi_autogen_unsafe.go b/pkg/sentry/kernel/kernel_amd64_abi_autogen_unsafe.go new file mode 100644 index 000000000..988351acf --- /dev/null +++ b/pkg/sentry/kernel/kernel_amd64_abi_autogen_unsafe.go @@ -0,0 +1,16 @@ +// Automatically generated marshal implementation. See tools/go_marshal. + +// If there are issues with build constraint aggregation, see +// tools/go_marshal/gomarshal/generator.go:writeHeader(). The constraints here +// come from the input set of files used to generate this file. This input set +// is filtered based on pre-defined file suffixes related to build constraints, +// see tools/defs.bzl:calculate_sets(). + +//go:build amd64 +// +build amd64 + +package kernel + +import ( +) + diff --git a/pkg/sentry/kernel/kernel_amd64_state_autogen.go b/pkg/sentry/kernel/kernel_amd64_state_autogen.go new file mode 100644 index 000000000..859999676 --- /dev/null +++ b/pkg/sentry/kernel/kernel_amd64_state_autogen.go @@ -0,0 +1,6 @@ +// automatically generated by stateify. + +//go:build amd64 +// +build amd64 + +package kernel diff --git a/pkg/sentry/kernel/kernel_arm64_abi_autogen_unsafe.go b/pkg/sentry/kernel/kernel_arm64_abi_autogen_unsafe.go new file mode 100644 index 000000000..823d08b19 --- /dev/null +++ b/pkg/sentry/kernel/kernel_arm64_abi_autogen_unsafe.go @@ -0,0 +1,16 @@ +// Automatically generated marshal implementation. See tools/go_marshal. + +// If there are issues with build constraint aggregation, see +// tools/go_marshal/gomarshal/generator.go:writeHeader(). The constraints here +// come from the input set of files used to generate this file. This input set +// is filtered based on pre-defined file suffixes related to build constraints, +// see tools/defs.bzl:calculate_sets(). + +//go:build arm64 +// +build arm64 + +package kernel + +import ( +) + diff --git a/pkg/sentry/kernel/kernel_arm64_state_autogen.go b/pkg/sentry/kernel/kernel_arm64_state_autogen.go new file mode 100644 index 000000000..7340674a6 --- /dev/null +++ b/pkg/sentry/kernel/kernel_arm64_state_autogen.go @@ -0,0 +1,6 @@ +// automatically generated by stateify. + +//go:build arm64 +// +build arm64 + +package kernel diff --git a/pkg/sentry/kernel/kernel_opts_abi_autogen_unsafe.go b/pkg/sentry/kernel/kernel_opts_abi_autogen_unsafe.go new file mode 100644 index 000000000..dcca4b9dd --- /dev/null +++ b/pkg/sentry/kernel/kernel_opts_abi_autogen_unsafe.go @@ -0,0 +1,16 @@ +// Automatically generated marshal implementation. See tools/go_marshal. + +// If there are issues with build constraint aggregation, see +// tools/go_marshal/gomarshal/generator.go:writeHeader(). The constraints here +// come from the input set of files used to generate this file. This input set +// is filtered based on pre-defined file suffixes related to build constraints, +// see tools/defs.bzl:calculate_sets(). + +//go:build go1.1 +// +build go1.1 + +package kernel + +import ( +) + diff --git a/pkg/sentry/kernel/kernel_opts_state_autogen.go b/pkg/sentry/kernel/kernel_opts_state_autogen.go new file mode 100644 index 000000000..d878d9a06 --- /dev/null +++ b/pkg/sentry/kernel/kernel_opts_state_autogen.go @@ -0,0 +1,35 @@ +// automatically generated by stateify. + +//go:build go1.1 +// +build go1.1 + +package kernel + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (s *SpecialOpts) StateTypeName() string { + return "pkg/sentry/kernel.SpecialOpts" +} + +func (s *SpecialOpts) StateFields() []string { + return []string{} +} + +func (s *SpecialOpts) beforeSave() {} + +// +checklocksignore +func (s *SpecialOpts) StateSave(stateSinkObject state.Sink) { + s.beforeSave() +} + +func (s *SpecialOpts) afterLoad() {} + +// +checklocksignore +func (s *SpecialOpts) StateLoad(stateSourceObject state.Source) { +} + +func init() { + state.Register((*SpecialOpts)(nil)) +} diff --git a/pkg/sentry/kernel/kernel_state_autogen.go b/pkg/sentry/kernel/kernel_state_autogen.go new file mode 100644 index 000000000..d48c96425 --- /dev/null +++ b/pkg/sentry/kernel/kernel_state_autogen.go @@ -0,0 +1,2604 @@ +// automatically generated by stateify. + +package kernel + +import ( + "gvisor.dev/gvisor/pkg/bpf" + "gvisor.dev/gvisor/pkg/sentry/device" + "gvisor.dev/gvisor/pkg/state" + "gvisor.dev/gvisor/pkg/tcpip" +) + +func (a *abstractEndpoint) StateTypeName() string { + return "pkg/sentry/kernel.abstractEndpoint" +} + +func (a *abstractEndpoint) StateFields() []string { + return []string{ + "ep", + "socket", + "name", + "ns", + } +} + +func (a *abstractEndpoint) beforeSave() {} + +// +checklocksignore +func (a *abstractEndpoint) StateSave(stateSinkObject state.Sink) { + a.beforeSave() + stateSinkObject.Save(0, &a.ep) + stateSinkObject.Save(1, &a.socket) + stateSinkObject.Save(2, &a.name) + stateSinkObject.Save(3, &a.ns) +} + +func (a *abstractEndpoint) afterLoad() {} + +// +checklocksignore +func (a *abstractEndpoint) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &a.ep) + stateSourceObject.Load(1, &a.socket) + stateSourceObject.Load(2, &a.name) + stateSourceObject.Load(3, &a.ns) +} + +func (a *AbstractSocketNamespace) StateTypeName() string { + return "pkg/sentry/kernel.AbstractSocketNamespace" +} + +func (a *AbstractSocketNamespace) StateFields() []string { + return []string{ + "endpoints", + } +} + +func (a *AbstractSocketNamespace) beforeSave() {} + +// +checklocksignore +func (a *AbstractSocketNamespace) StateSave(stateSinkObject state.Sink) { + a.beforeSave() + stateSinkObject.Save(0, &a.endpoints) +} + +func (a *AbstractSocketNamespace) afterLoad() {} + +// +checklocksignore +func (a *AbstractSocketNamespace) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &a.endpoints) +} + +func (c *Cgroup) StateTypeName() string { + return "pkg/sentry/kernel.Cgroup" +} + +func (c *Cgroup) StateFields() []string { + return []string{ + "Dentry", + "CgroupImpl", + } +} + +func (c *Cgroup) beforeSave() {} + +// +checklocksignore +func (c *Cgroup) StateSave(stateSinkObject state.Sink) { + c.beforeSave() + stateSinkObject.Save(0, &c.Dentry) + stateSinkObject.Save(1, &c.CgroupImpl) +} + +func (c *Cgroup) afterLoad() {} + +// +checklocksignore +func (c *Cgroup) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &c.Dentry) + stateSourceObject.Load(1, &c.CgroupImpl) +} + +func (h *hierarchy) StateTypeName() string { + return "pkg/sentry/kernel.hierarchy" +} + +func (h *hierarchy) StateFields() []string { + return []string{ + "id", + "controllers", + "fs", + } +} + +func (h *hierarchy) beforeSave() {} + +// +checklocksignore +func (h *hierarchy) StateSave(stateSinkObject state.Sink) { + h.beforeSave() + stateSinkObject.Save(0, &h.id) + stateSinkObject.Save(1, &h.controllers) + stateSinkObject.Save(2, &h.fs) +} + +func (h *hierarchy) afterLoad() {} + +// +checklocksignore +func (h *hierarchy) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &h.id) + stateSourceObject.Load(1, &h.controllers) + stateSourceObject.Load(2, &h.fs) +} + +func (r *CgroupRegistry) StateTypeName() string { + return "pkg/sentry/kernel.CgroupRegistry" +} + +func (r *CgroupRegistry) StateFields() []string { + return []string{ + "lastHierarchyID", + "controllers", + "hierarchies", + } +} + +func (r *CgroupRegistry) beforeSave() {} + +// +checklocksignore +func (r *CgroupRegistry) StateSave(stateSinkObject state.Sink) { + r.beforeSave() + stateSinkObject.Save(0, &r.lastHierarchyID) + stateSinkObject.Save(1, &r.controllers) + stateSinkObject.Save(2, &r.hierarchies) +} + +func (r *CgroupRegistry) afterLoad() {} + +// +checklocksignore +func (r *CgroupRegistry) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &r.lastHierarchyID) + stateSourceObject.Load(1, &r.controllers) + stateSourceObject.Load(2, &r.hierarchies) +} + +func (f *FDFlags) StateTypeName() string { + return "pkg/sentry/kernel.FDFlags" +} + +func (f *FDFlags) StateFields() []string { + return []string{ + "CloseOnExec", + } +} + +func (f *FDFlags) beforeSave() {} + +// +checklocksignore +func (f *FDFlags) StateSave(stateSinkObject state.Sink) { + f.beforeSave() + stateSinkObject.Save(0, &f.CloseOnExec) +} + +func (f *FDFlags) afterLoad() {} + +// +checklocksignore +func (f *FDFlags) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &f.CloseOnExec) +} + +func (d *descriptor) StateTypeName() string { + return "pkg/sentry/kernel.descriptor" +} + +func (d *descriptor) StateFields() []string { + return []string{ + "file", + "fileVFS2", + "flags", + } +} + +func (d *descriptor) beforeSave() {} + +// +checklocksignore +func (d *descriptor) StateSave(stateSinkObject state.Sink) { + d.beforeSave() + stateSinkObject.Save(0, &d.file) + stateSinkObject.Save(1, &d.fileVFS2) + stateSinkObject.Save(2, &d.flags) +} + +func (d *descriptor) afterLoad() {} + +// +checklocksignore +func (d *descriptor) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &d.file) + stateSourceObject.Load(1, &d.fileVFS2) + stateSourceObject.Load(2, &d.flags) +} + +func (f *FDTable) StateTypeName() string { + return "pkg/sentry/kernel.FDTable" +} + +func (f *FDTable) StateFields() []string { + return []string{ + "FDTableRefs", + "k", + "descriptorTable", + } +} + +func (f *FDTable) beforeSave() {} + +// +checklocksignore +func (f *FDTable) StateSave(stateSinkObject state.Sink) { + f.beforeSave() + var descriptorTableValue map[int32]descriptor + descriptorTableValue = f.saveDescriptorTable() + stateSinkObject.SaveValue(2, descriptorTableValue) + stateSinkObject.Save(0, &f.FDTableRefs) + stateSinkObject.Save(1, &f.k) +} + +func (f *FDTable) afterLoad() {} + +// +checklocksignore +func (f *FDTable) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &f.FDTableRefs) + stateSourceObject.Load(1, &f.k) + stateSourceObject.LoadValue(2, new(map[int32]descriptor), func(y interface{}) { f.loadDescriptorTable(y.(map[int32]descriptor)) }) +} + +func (r *FDTableRefs) StateTypeName() string { + return "pkg/sentry/kernel.FDTableRefs" +} + +func (r *FDTableRefs) StateFields() []string { + return []string{ + "refCount", + } +} + +func (r *FDTableRefs) beforeSave() {} + +// +checklocksignore +func (r *FDTableRefs) StateSave(stateSinkObject state.Sink) { + r.beforeSave() + stateSinkObject.Save(0, &r.refCount) +} + +// +checklocksignore +func (r *FDTableRefs) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &r.refCount) + stateSourceObject.AfterLoad(r.afterLoad) +} + +func (f *FSContext) StateTypeName() string { + return "pkg/sentry/kernel.FSContext" +} + +func (f *FSContext) StateFields() []string { + return []string{ + "FSContextRefs", + "root", + "rootVFS2", + "cwd", + "cwdVFS2", + "umask", + } +} + +func (f *FSContext) beforeSave() {} + +// +checklocksignore +func (f *FSContext) StateSave(stateSinkObject state.Sink) { + f.beforeSave() + stateSinkObject.Save(0, &f.FSContextRefs) + stateSinkObject.Save(1, &f.root) + stateSinkObject.Save(2, &f.rootVFS2) + stateSinkObject.Save(3, &f.cwd) + stateSinkObject.Save(4, &f.cwdVFS2) + stateSinkObject.Save(5, &f.umask) +} + +func (f *FSContext) afterLoad() {} + +// +checklocksignore +func (f *FSContext) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &f.FSContextRefs) + stateSourceObject.Load(1, &f.root) + stateSourceObject.Load(2, &f.rootVFS2) + stateSourceObject.Load(3, &f.cwd) + stateSourceObject.Load(4, &f.cwdVFS2) + stateSourceObject.Load(5, &f.umask) +} + +func (r *FSContextRefs) StateTypeName() string { + return "pkg/sentry/kernel.FSContextRefs" +} + +func (r *FSContextRefs) StateFields() []string { + return []string{ + "refCount", + } +} + +func (r *FSContextRefs) beforeSave() {} + +// +checklocksignore +func (r *FSContextRefs) StateSave(stateSinkObject state.Sink) { + r.beforeSave() + stateSinkObject.Save(0, &r.refCount) +} + +// +checklocksignore +func (r *FSContextRefs) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &r.refCount) + stateSourceObject.AfterLoad(r.afterLoad) +} + +func (i *IPCNamespace) StateTypeName() string { + return "pkg/sentry/kernel.IPCNamespace" +} + +func (i *IPCNamespace) StateFields() []string { + return []string{ + "IPCNamespaceRefs", + "userNS", + "queues", + "semaphores", + "shms", + } +} + +func (i *IPCNamespace) beforeSave() {} + +// +checklocksignore +func (i *IPCNamespace) StateSave(stateSinkObject state.Sink) { + i.beforeSave() + stateSinkObject.Save(0, &i.IPCNamespaceRefs) + stateSinkObject.Save(1, &i.userNS) + stateSinkObject.Save(2, &i.queues) + stateSinkObject.Save(3, &i.semaphores) + stateSinkObject.Save(4, &i.shms) +} + +func (i *IPCNamespace) afterLoad() {} + +// +checklocksignore +func (i *IPCNamespace) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &i.IPCNamespaceRefs) + stateSourceObject.Load(1, &i.userNS) + stateSourceObject.Load(2, &i.queues) + stateSourceObject.Load(3, &i.semaphores) + stateSourceObject.Load(4, &i.shms) +} + +func (r *IPCNamespaceRefs) StateTypeName() string { + return "pkg/sentry/kernel.IPCNamespaceRefs" +} + +func (r *IPCNamespaceRefs) StateFields() []string { + return []string{ + "refCount", + } +} + +func (r *IPCNamespaceRefs) beforeSave() {} + +// +checklocksignore +func (r *IPCNamespaceRefs) StateSave(stateSinkObject state.Sink) { + r.beforeSave() + stateSinkObject.Save(0, &r.refCount) +} + +// +checklocksignore +func (r *IPCNamespaceRefs) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &r.refCount) + stateSourceObject.AfterLoad(r.afterLoad) +} + +func (k *Kernel) StateTypeName() string { + return "pkg/sentry/kernel.Kernel" +} + +func (k *Kernel) StateFields() []string { + return []string{ + "featureSet", + "timekeeper", + "tasks", + "rootUserNamespace", + "rootNetworkNamespace", + "applicationCores", + "useHostCores", + "extraAuxv", + "vdso", + "rootUTSNamespace", + "rootIPCNamespace", + "rootAbstractSocketNamespace", + "futexes", + "globalInit", + "syslog", + "runningTasks", + "cpuClock", + "cpuClockTickerDisabled", + "cpuClockTickerSetting", + "uniqueID", + "nextInotifyCookie", + "netlinkPorts", + "danglingEndpoints", + "sockets", + "socketsVFS2", + "nextSocketRecord", + "deviceRegistry", + "DirentCacheLimiter", + "SpecialOpts", + "vfs", + "hostMount", + "pipeMount", + "shmMount", + "socketMount", + "SleepForAddressSpaceActivation", + "ptraceExceptions", + "YAMAPtraceScope", + "cgroupRegistry", + } +} + +func (k *Kernel) beforeSave() {} + +// +checklocksignore +func (k *Kernel) StateSave(stateSinkObject state.Sink) { + k.beforeSave() + var danglingEndpointsValue []tcpip.Endpoint + danglingEndpointsValue = k.saveDanglingEndpoints() + stateSinkObject.SaveValue(22, danglingEndpointsValue) + var deviceRegistryValue *device.Registry + deviceRegistryValue = k.saveDeviceRegistry() + stateSinkObject.SaveValue(26, deviceRegistryValue) + stateSinkObject.Save(0, &k.featureSet) + stateSinkObject.Save(1, &k.timekeeper) + stateSinkObject.Save(2, &k.tasks) + stateSinkObject.Save(3, &k.rootUserNamespace) + stateSinkObject.Save(4, &k.rootNetworkNamespace) + stateSinkObject.Save(5, &k.applicationCores) + stateSinkObject.Save(6, &k.useHostCores) + stateSinkObject.Save(7, &k.extraAuxv) + stateSinkObject.Save(8, &k.vdso) + stateSinkObject.Save(9, &k.rootUTSNamespace) + stateSinkObject.Save(10, &k.rootIPCNamespace) + stateSinkObject.Save(11, &k.rootAbstractSocketNamespace) + stateSinkObject.Save(12, &k.futexes) + stateSinkObject.Save(13, &k.globalInit) + stateSinkObject.Save(14, &k.syslog) + stateSinkObject.Save(15, &k.runningTasks) + stateSinkObject.Save(16, &k.cpuClock) + stateSinkObject.Save(17, &k.cpuClockTickerDisabled) + stateSinkObject.Save(18, &k.cpuClockTickerSetting) + stateSinkObject.Save(19, &k.uniqueID) + stateSinkObject.Save(20, &k.nextInotifyCookie) + stateSinkObject.Save(21, &k.netlinkPorts) + stateSinkObject.Save(23, &k.sockets) + stateSinkObject.Save(24, &k.socketsVFS2) + stateSinkObject.Save(25, &k.nextSocketRecord) + stateSinkObject.Save(27, &k.DirentCacheLimiter) + stateSinkObject.Save(28, &k.SpecialOpts) + stateSinkObject.Save(29, &k.vfs) + stateSinkObject.Save(30, &k.hostMount) + stateSinkObject.Save(31, &k.pipeMount) + stateSinkObject.Save(32, &k.shmMount) + stateSinkObject.Save(33, &k.socketMount) + stateSinkObject.Save(34, &k.SleepForAddressSpaceActivation) + stateSinkObject.Save(35, &k.ptraceExceptions) + stateSinkObject.Save(36, &k.YAMAPtraceScope) + stateSinkObject.Save(37, &k.cgroupRegistry) +} + +func (k *Kernel) afterLoad() {} + +// +checklocksignore +func (k *Kernel) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &k.featureSet) + stateSourceObject.Load(1, &k.timekeeper) + stateSourceObject.Load(2, &k.tasks) + stateSourceObject.Load(3, &k.rootUserNamespace) + stateSourceObject.Load(4, &k.rootNetworkNamespace) + stateSourceObject.Load(5, &k.applicationCores) + stateSourceObject.Load(6, &k.useHostCores) + stateSourceObject.Load(7, &k.extraAuxv) + stateSourceObject.Load(8, &k.vdso) + stateSourceObject.Load(9, &k.rootUTSNamespace) + stateSourceObject.Load(10, &k.rootIPCNamespace) + stateSourceObject.Load(11, &k.rootAbstractSocketNamespace) + stateSourceObject.Load(12, &k.futexes) + stateSourceObject.Load(13, &k.globalInit) + stateSourceObject.Load(14, &k.syslog) + stateSourceObject.Load(15, &k.runningTasks) + stateSourceObject.Load(16, &k.cpuClock) + stateSourceObject.Load(17, &k.cpuClockTickerDisabled) + stateSourceObject.Load(18, &k.cpuClockTickerSetting) + stateSourceObject.Load(19, &k.uniqueID) + stateSourceObject.Load(20, &k.nextInotifyCookie) + stateSourceObject.Load(21, &k.netlinkPorts) + stateSourceObject.Load(23, &k.sockets) + stateSourceObject.Load(24, &k.socketsVFS2) + stateSourceObject.Load(25, &k.nextSocketRecord) + stateSourceObject.Load(27, &k.DirentCacheLimiter) + stateSourceObject.Load(28, &k.SpecialOpts) + stateSourceObject.Load(29, &k.vfs) + stateSourceObject.Load(30, &k.hostMount) + stateSourceObject.Load(31, &k.pipeMount) + stateSourceObject.Load(32, &k.shmMount) + stateSourceObject.Load(33, &k.socketMount) + stateSourceObject.Load(34, &k.SleepForAddressSpaceActivation) + stateSourceObject.Load(35, &k.ptraceExceptions) + stateSourceObject.Load(36, &k.YAMAPtraceScope) + stateSourceObject.Load(37, &k.cgroupRegistry) + stateSourceObject.LoadValue(22, new([]tcpip.Endpoint), func(y interface{}) { k.loadDanglingEndpoints(y.([]tcpip.Endpoint)) }) + stateSourceObject.LoadValue(26, new(*device.Registry), func(y interface{}) { k.loadDeviceRegistry(y.(*device.Registry)) }) +} + +func (s *SocketRecord) StateTypeName() string { + return "pkg/sentry/kernel.SocketRecord" +} + +func (s *SocketRecord) StateFields() []string { + return []string{ + "k", + "Sock", + "SockVFS2", + "ID", + } +} + +func (s *SocketRecord) beforeSave() {} + +// +checklocksignore +func (s *SocketRecord) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + stateSinkObject.Save(0, &s.k) + stateSinkObject.Save(1, &s.Sock) + stateSinkObject.Save(2, &s.SockVFS2) + stateSinkObject.Save(3, &s.ID) +} + +func (s *SocketRecord) afterLoad() {} + +// +checklocksignore +func (s *SocketRecord) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.k) + stateSourceObject.Load(1, &s.Sock) + stateSourceObject.Load(2, &s.SockVFS2) + stateSourceObject.Load(3, &s.ID) +} + +func (s *SocketRecordVFS1) StateTypeName() string { + return "pkg/sentry/kernel.SocketRecordVFS1" +} + +func (s *SocketRecordVFS1) StateFields() []string { + return []string{ + "socketEntry", + "SocketRecord", + } +} + +func (s *SocketRecordVFS1) beforeSave() {} + +// +checklocksignore +func (s *SocketRecordVFS1) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + stateSinkObject.Save(0, &s.socketEntry) + stateSinkObject.Save(1, &s.SocketRecord) +} + +func (s *SocketRecordVFS1) afterLoad() {} + +// +checklocksignore +func (s *SocketRecordVFS1) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.socketEntry) + stateSourceObject.Load(1, &s.SocketRecord) +} + +func (p *pendingSignals) StateTypeName() string { + return "pkg/sentry/kernel.pendingSignals" +} + +func (p *pendingSignals) StateFields() []string { + return []string{ + "signals", + } +} + +func (p *pendingSignals) beforeSave() {} + +// +checklocksignore +func (p *pendingSignals) StateSave(stateSinkObject state.Sink) { + p.beforeSave() + var signalsValue []savedPendingSignal + signalsValue = p.saveSignals() + stateSinkObject.SaveValue(0, signalsValue) +} + +func (p *pendingSignals) afterLoad() {} + +// +checklocksignore +func (p *pendingSignals) StateLoad(stateSourceObject state.Source) { + stateSourceObject.LoadValue(0, new([]savedPendingSignal), func(y interface{}) { p.loadSignals(y.([]savedPendingSignal)) }) +} + +func (p *pendingSignalQueue) StateTypeName() string { + return "pkg/sentry/kernel.pendingSignalQueue" +} + +func (p *pendingSignalQueue) StateFields() []string { + return []string{ + "pendingSignalList", + "length", + } +} + +func (p *pendingSignalQueue) beforeSave() {} + +// +checklocksignore +func (p *pendingSignalQueue) StateSave(stateSinkObject state.Sink) { + p.beforeSave() + stateSinkObject.Save(0, &p.pendingSignalList) + stateSinkObject.Save(1, &p.length) +} + +func (p *pendingSignalQueue) afterLoad() {} + +// +checklocksignore +func (p *pendingSignalQueue) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &p.pendingSignalList) + stateSourceObject.Load(1, &p.length) +} + +func (p *pendingSignal) StateTypeName() string { + return "pkg/sentry/kernel.pendingSignal" +} + +func (p *pendingSignal) StateFields() []string { + return []string{ + "pendingSignalEntry", + "SignalInfo", + "timer", + } +} + +func (p *pendingSignal) beforeSave() {} + +// +checklocksignore +func (p *pendingSignal) StateSave(stateSinkObject state.Sink) { + p.beforeSave() + stateSinkObject.Save(0, &p.pendingSignalEntry) + stateSinkObject.Save(1, &p.SignalInfo) + stateSinkObject.Save(2, &p.timer) +} + +func (p *pendingSignal) afterLoad() {} + +// +checklocksignore +func (p *pendingSignal) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &p.pendingSignalEntry) + stateSourceObject.Load(1, &p.SignalInfo) + stateSourceObject.Load(2, &p.timer) +} + +func (l *pendingSignalList) StateTypeName() string { + return "pkg/sentry/kernel.pendingSignalList" +} + +func (l *pendingSignalList) StateFields() []string { + return []string{ + "head", + "tail", + } +} + +func (l *pendingSignalList) beforeSave() {} + +// +checklocksignore +func (l *pendingSignalList) StateSave(stateSinkObject state.Sink) { + l.beforeSave() + stateSinkObject.Save(0, &l.head) + stateSinkObject.Save(1, &l.tail) +} + +func (l *pendingSignalList) afterLoad() {} + +// +checklocksignore +func (l *pendingSignalList) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &l.head) + stateSourceObject.Load(1, &l.tail) +} + +func (e *pendingSignalEntry) StateTypeName() string { + return "pkg/sentry/kernel.pendingSignalEntry" +} + +func (e *pendingSignalEntry) StateFields() []string { + return []string{ + "next", + "prev", + } +} + +func (e *pendingSignalEntry) beforeSave() {} + +// +checklocksignore +func (e *pendingSignalEntry) StateSave(stateSinkObject state.Sink) { + e.beforeSave() + stateSinkObject.Save(0, &e.next) + stateSinkObject.Save(1, &e.prev) +} + +func (e *pendingSignalEntry) afterLoad() {} + +// +checklocksignore +func (e *pendingSignalEntry) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &e.next) + stateSourceObject.Load(1, &e.prev) +} + +func (s *savedPendingSignal) StateTypeName() string { + return "pkg/sentry/kernel.savedPendingSignal" +} + +func (s *savedPendingSignal) StateFields() []string { + return []string{ + "si", + "timer", + } +} + +func (s *savedPendingSignal) beforeSave() {} + +// +checklocksignore +func (s *savedPendingSignal) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + stateSinkObject.Save(0, &s.si) + stateSinkObject.Save(1, &s.timer) +} + +func (s *savedPendingSignal) afterLoad() {} + +// +checklocksignore +func (s *savedPendingSignal) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.si) + stateSourceObject.Load(1, &s.timer) +} + +func (it *IntervalTimer) StateTypeName() string { + return "pkg/sentry/kernel.IntervalTimer" +} + +func (it *IntervalTimer) StateFields() []string { + return []string{ + "timer", + "target", + "signo", + "id", + "sigval", + "group", + "sigpending", + "sigorphan", + "overrunCur", + "overrunLast", + } +} + +func (it *IntervalTimer) beforeSave() {} + +// +checklocksignore +func (it *IntervalTimer) StateSave(stateSinkObject state.Sink) { + it.beforeSave() + stateSinkObject.Save(0, &it.timer) + stateSinkObject.Save(1, &it.target) + stateSinkObject.Save(2, &it.signo) + stateSinkObject.Save(3, &it.id) + stateSinkObject.Save(4, &it.sigval) + stateSinkObject.Save(5, &it.group) + stateSinkObject.Save(6, &it.sigpending) + stateSinkObject.Save(7, &it.sigorphan) + stateSinkObject.Save(8, &it.overrunCur) + stateSinkObject.Save(9, &it.overrunLast) +} + +func (it *IntervalTimer) afterLoad() {} + +// +checklocksignore +func (it *IntervalTimer) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &it.timer) + stateSourceObject.Load(1, &it.target) + stateSourceObject.Load(2, &it.signo) + stateSourceObject.Load(3, &it.id) + stateSourceObject.Load(4, &it.sigval) + stateSourceObject.Load(5, &it.group) + stateSourceObject.Load(6, &it.sigpending) + stateSourceObject.Load(7, &it.sigorphan) + stateSourceObject.Load(8, &it.overrunCur) + stateSourceObject.Load(9, &it.overrunLast) +} + +func (l *processGroupList) StateTypeName() string { + return "pkg/sentry/kernel.processGroupList" +} + +func (l *processGroupList) StateFields() []string { + return []string{ + "head", + "tail", + } +} + +func (l *processGroupList) beforeSave() {} + +// +checklocksignore +func (l *processGroupList) StateSave(stateSinkObject state.Sink) { + l.beforeSave() + stateSinkObject.Save(0, &l.head) + stateSinkObject.Save(1, &l.tail) +} + +func (l *processGroupList) afterLoad() {} + +// +checklocksignore +func (l *processGroupList) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &l.head) + stateSourceObject.Load(1, &l.tail) +} + +func (e *processGroupEntry) StateTypeName() string { + return "pkg/sentry/kernel.processGroupEntry" +} + +func (e *processGroupEntry) StateFields() []string { + return []string{ + "next", + "prev", + } +} + +func (e *processGroupEntry) beforeSave() {} + +// +checklocksignore +func (e *processGroupEntry) StateSave(stateSinkObject state.Sink) { + e.beforeSave() + stateSinkObject.Save(0, &e.next) + stateSinkObject.Save(1, &e.prev) +} + +func (e *processGroupEntry) afterLoad() {} + +// +checklocksignore +func (e *processGroupEntry) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &e.next) + stateSourceObject.Load(1, &e.prev) +} + +func (r *ProcessGroupRefs) StateTypeName() string { + return "pkg/sentry/kernel.ProcessGroupRefs" +} + +func (r *ProcessGroupRefs) StateFields() []string { + return []string{ + "refCount", + } +} + +func (r *ProcessGroupRefs) beforeSave() {} + +// +checklocksignore +func (r *ProcessGroupRefs) StateSave(stateSinkObject state.Sink) { + r.beforeSave() + stateSinkObject.Save(0, &r.refCount) +} + +// +checklocksignore +func (r *ProcessGroupRefs) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &r.refCount) + stateSourceObject.AfterLoad(r.afterLoad) +} + +func (p *ptraceOptions) StateTypeName() string { + return "pkg/sentry/kernel.ptraceOptions" +} + +func (p *ptraceOptions) StateFields() []string { + return []string{ + "ExitKill", + "SysGood", + "TraceClone", + "TraceExec", + "TraceExit", + "TraceFork", + "TraceSeccomp", + "TraceVfork", + "TraceVforkDone", + } +} + +func (p *ptraceOptions) beforeSave() {} + +// +checklocksignore +func (p *ptraceOptions) StateSave(stateSinkObject state.Sink) { + p.beforeSave() + stateSinkObject.Save(0, &p.ExitKill) + stateSinkObject.Save(1, &p.SysGood) + stateSinkObject.Save(2, &p.TraceClone) + stateSinkObject.Save(3, &p.TraceExec) + stateSinkObject.Save(4, &p.TraceExit) + stateSinkObject.Save(5, &p.TraceFork) + stateSinkObject.Save(6, &p.TraceSeccomp) + stateSinkObject.Save(7, &p.TraceVfork) + stateSinkObject.Save(8, &p.TraceVforkDone) +} + +func (p *ptraceOptions) afterLoad() {} + +// +checklocksignore +func (p *ptraceOptions) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &p.ExitKill) + stateSourceObject.Load(1, &p.SysGood) + stateSourceObject.Load(2, &p.TraceClone) + stateSourceObject.Load(3, &p.TraceExec) + stateSourceObject.Load(4, &p.TraceExit) + stateSourceObject.Load(5, &p.TraceFork) + stateSourceObject.Load(6, &p.TraceSeccomp) + stateSourceObject.Load(7, &p.TraceVfork) + stateSourceObject.Load(8, &p.TraceVforkDone) +} + +func (s *ptraceStop) StateTypeName() string { + return "pkg/sentry/kernel.ptraceStop" +} + +func (s *ptraceStop) StateFields() []string { + return []string{ + "frozen", + "listen", + } +} + +func (s *ptraceStop) beforeSave() {} + +// +checklocksignore +func (s *ptraceStop) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + stateSinkObject.Save(0, &s.frozen) + stateSinkObject.Save(1, &s.listen) +} + +func (s *ptraceStop) afterLoad() {} + +// +checklocksignore +func (s *ptraceStop) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.frozen) + stateSourceObject.Load(1, &s.listen) +} + +func (o *OldRSeqCriticalRegion) StateTypeName() string { + return "pkg/sentry/kernel.OldRSeqCriticalRegion" +} + +func (o *OldRSeqCriticalRegion) StateFields() []string { + return []string{ + "CriticalSection", + "Restart", + } +} + +func (o *OldRSeqCriticalRegion) beforeSave() {} + +// +checklocksignore +func (o *OldRSeqCriticalRegion) StateSave(stateSinkObject state.Sink) { + o.beforeSave() + stateSinkObject.Save(0, &o.CriticalSection) + stateSinkObject.Save(1, &o.Restart) +} + +func (o *OldRSeqCriticalRegion) afterLoad() {} + +// +checklocksignore +func (o *OldRSeqCriticalRegion) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &o.CriticalSection) + stateSourceObject.Load(1, &o.Restart) +} + +func (l *sessionList) StateTypeName() string { + return "pkg/sentry/kernel.sessionList" +} + +func (l *sessionList) StateFields() []string { + return []string{ + "head", + "tail", + } +} + +func (l *sessionList) beforeSave() {} + +// +checklocksignore +func (l *sessionList) StateSave(stateSinkObject state.Sink) { + l.beforeSave() + stateSinkObject.Save(0, &l.head) + stateSinkObject.Save(1, &l.tail) +} + +func (l *sessionList) afterLoad() {} + +// +checklocksignore +func (l *sessionList) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &l.head) + stateSourceObject.Load(1, &l.tail) +} + +func (e *sessionEntry) StateTypeName() string { + return "pkg/sentry/kernel.sessionEntry" +} + +func (e *sessionEntry) StateFields() []string { + return []string{ + "next", + "prev", + } +} + +func (e *sessionEntry) beforeSave() {} + +// +checklocksignore +func (e *sessionEntry) StateSave(stateSinkObject state.Sink) { + e.beforeSave() + stateSinkObject.Save(0, &e.next) + stateSinkObject.Save(1, &e.prev) +} + +func (e *sessionEntry) afterLoad() {} + +// +checklocksignore +func (e *sessionEntry) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &e.next) + stateSourceObject.Load(1, &e.prev) +} + +func (r *SessionRefs) StateTypeName() string { + return "pkg/sentry/kernel.SessionRefs" +} + +func (r *SessionRefs) StateFields() []string { + return []string{ + "refCount", + } +} + +func (r *SessionRefs) beforeSave() {} + +// +checklocksignore +func (r *SessionRefs) StateSave(stateSinkObject state.Sink) { + r.beforeSave() + stateSinkObject.Save(0, &r.refCount) +} + +// +checklocksignore +func (r *SessionRefs) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &r.refCount) + stateSourceObject.AfterLoad(r.afterLoad) +} + +func (s *Session) StateTypeName() string { + return "pkg/sentry/kernel.Session" +} + +func (s *Session) StateFields() []string { + return []string{ + "SessionRefs", + "leader", + "id", + "foreground", + "processGroups", + "sessionEntry", + } +} + +func (s *Session) beforeSave() {} + +// +checklocksignore +func (s *Session) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + stateSinkObject.Save(0, &s.SessionRefs) + stateSinkObject.Save(1, &s.leader) + stateSinkObject.Save(2, &s.id) + stateSinkObject.Save(3, &s.foreground) + stateSinkObject.Save(4, &s.processGroups) + stateSinkObject.Save(5, &s.sessionEntry) +} + +func (s *Session) afterLoad() {} + +// +checklocksignore +func (s *Session) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.SessionRefs) + stateSourceObject.Load(1, &s.leader) + stateSourceObject.Load(2, &s.id) + stateSourceObject.Load(3, &s.foreground) + stateSourceObject.Load(4, &s.processGroups) + stateSourceObject.Load(5, &s.sessionEntry) +} + +func (pg *ProcessGroup) StateTypeName() string { + return "pkg/sentry/kernel.ProcessGroup" +} + +func (pg *ProcessGroup) StateFields() []string { + return []string{ + "refs", + "originator", + "id", + "session", + "ancestors", + "processGroupEntry", + } +} + +func (pg *ProcessGroup) beforeSave() {} + +// +checklocksignore +func (pg *ProcessGroup) StateSave(stateSinkObject state.Sink) { + pg.beforeSave() + stateSinkObject.Save(0, &pg.refs) + stateSinkObject.Save(1, &pg.originator) + stateSinkObject.Save(2, &pg.id) + stateSinkObject.Save(3, &pg.session) + stateSinkObject.Save(4, &pg.ancestors) + stateSinkObject.Save(5, &pg.processGroupEntry) +} + +func (pg *ProcessGroup) afterLoad() {} + +// +checklocksignore +func (pg *ProcessGroup) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &pg.refs) + stateSourceObject.Load(1, &pg.originator) + stateSourceObject.Load(2, &pg.id) + stateSourceObject.Load(3, &pg.session) + stateSourceObject.Load(4, &pg.ancestors) + stateSourceObject.Load(5, &pg.processGroupEntry) +} + +func (sh *SignalHandlers) StateTypeName() string { + return "pkg/sentry/kernel.SignalHandlers" +} + +func (sh *SignalHandlers) StateFields() []string { + return []string{ + "actions", + } +} + +func (sh *SignalHandlers) beforeSave() {} + +// +checklocksignore +func (sh *SignalHandlers) StateSave(stateSinkObject state.Sink) { + sh.beforeSave() + stateSinkObject.Save(0, &sh.actions) +} + +func (sh *SignalHandlers) afterLoad() {} + +// +checklocksignore +func (sh *SignalHandlers) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &sh.actions) +} + +func (l *socketList) StateTypeName() string { + return "pkg/sentry/kernel.socketList" +} + +func (l *socketList) StateFields() []string { + return []string{ + "head", + "tail", + } +} + +func (l *socketList) beforeSave() {} + +// +checklocksignore +func (l *socketList) StateSave(stateSinkObject state.Sink) { + l.beforeSave() + stateSinkObject.Save(0, &l.head) + stateSinkObject.Save(1, &l.tail) +} + +func (l *socketList) afterLoad() {} + +// +checklocksignore +func (l *socketList) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &l.head) + stateSourceObject.Load(1, &l.tail) +} + +func (e *socketEntry) StateTypeName() string { + return "pkg/sentry/kernel.socketEntry" +} + +func (e *socketEntry) StateFields() []string { + return []string{ + "next", + "prev", + } +} + +func (e *socketEntry) beforeSave() {} + +// +checklocksignore +func (e *socketEntry) StateSave(stateSinkObject state.Sink) { + e.beforeSave() + stateSinkObject.Save(0, &e.next) + stateSinkObject.Save(1, &e.prev) +} + +func (e *socketEntry) afterLoad() {} + +// +checklocksignore +func (e *socketEntry) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &e.next) + stateSourceObject.Load(1, &e.prev) +} + +func (s *syscallTableInfo) StateTypeName() string { + return "pkg/sentry/kernel.syscallTableInfo" +} + +func (s *syscallTableInfo) StateFields() []string { + return []string{ + "OS", + "Arch", + } +} + +func (s *syscallTableInfo) beforeSave() {} + +// +checklocksignore +func (s *syscallTableInfo) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + stateSinkObject.Save(0, &s.OS) + stateSinkObject.Save(1, &s.Arch) +} + +func (s *syscallTableInfo) afterLoad() {} + +// +checklocksignore +func (s *syscallTableInfo) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.OS) + stateSourceObject.Load(1, &s.Arch) +} + +func (s *syslog) StateTypeName() string { + return "pkg/sentry/kernel.syslog" +} + +func (s *syslog) StateFields() []string { + return []string{ + "msg", + } +} + +func (s *syslog) beforeSave() {} + +// +checklocksignore +func (s *syslog) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + stateSinkObject.Save(0, &s.msg) +} + +func (s *syslog) afterLoad() {} + +// +checklocksignore +func (s *syslog) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.msg) +} + +func (t *Task) StateTypeName() string { + return "pkg/sentry/kernel.Task" +} + +func (t *Task) StateFields() []string { + return []string{ + "taskNode", + "runState", + "taskWorkCount", + "taskWork", + "haveSyscallReturn", + "gosched", + "yieldCount", + "pendingSignals", + "signalMask", + "realSignalMask", + "haveSavedSignalMask", + "savedSignalMask", + "signalStack", + "groupStopPending", + "groupStopAcknowledged", + "trapStopPending", + "trapNotifyPending", + "stop", + "exitStatus", + "syscallRestartBlock", + "k", + "containerID", + "image", + "fsContext", + "fdTable", + "vforkParent", + "exitState", + "exitTracerNotified", + "exitTracerAcked", + "exitParentNotified", + "exitParentAcked", + "ptraceTracer", + "ptraceTracees", + "ptraceSeized", + "ptraceOpts", + "ptraceSyscallMode", + "ptraceSinglestep", + "ptraceCode", + "ptraceSiginfo", + "ptraceEventMsg", + "ptraceYAMAExceptionAdded", + "ioUsage", + "creds", + "utsns", + "ipcns", + "abstractSockets", + "mountNamespaceVFS2", + "parentDeathSignal", + "syscallFilters", + "cleartid", + "allowedCPUMask", + "cpu", + "niceness", + "numaPolicy", + "numaNodeMask", + "netns", + "rseqCPU", + "oldRSeqCPUAddr", + "rseqAddr", + "rseqSignature", + "robustList", + "startTime", + "kcov", + "cgroups", + } +} + +func (t *Task) beforeSave() {} + +// +checklocksignore +func (t *Task) StateSave(stateSinkObject state.Sink) { + t.beforeSave() + if !state.IsZeroValue(&t.signalQueue) { + state.Failf("signalQueue is %#v, expected zero", &t.signalQueue) + } + var ptraceTracerValue *Task + ptraceTracerValue = t.savePtraceTracer() + stateSinkObject.SaveValue(31, ptraceTracerValue) + var syscallFiltersValue []bpf.Program + syscallFiltersValue = t.saveSyscallFilters() + stateSinkObject.SaveValue(48, syscallFiltersValue) + stateSinkObject.Save(0, &t.taskNode) + stateSinkObject.Save(1, &t.runState) + stateSinkObject.Save(2, &t.taskWorkCount) + stateSinkObject.Save(3, &t.taskWork) + stateSinkObject.Save(4, &t.haveSyscallReturn) + stateSinkObject.Save(5, &t.gosched) + stateSinkObject.Save(6, &t.yieldCount) + stateSinkObject.Save(7, &t.pendingSignals) + stateSinkObject.Save(8, &t.signalMask) + stateSinkObject.Save(9, &t.realSignalMask) + stateSinkObject.Save(10, &t.haveSavedSignalMask) + stateSinkObject.Save(11, &t.savedSignalMask) + stateSinkObject.Save(12, &t.signalStack) + stateSinkObject.Save(13, &t.groupStopPending) + stateSinkObject.Save(14, &t.groupStopAcknowledged) + stateSinkObject.Save(15, &t.trapStopPending) + stateSinkObject.Save(16, &t.trapNotifyPending) + stateSinkObject.Save(17, &t.stop) + stateSinkObject.Save(18, &t.exitStatus) + stateSinkObject.Save(19, &t.syscallRestartBlock) + stateSinkObject.Save(20, &t.k) + stateSinkObject.Save(21, &t.containerID) + stateSinkObject.Save(22, &t.image) + stateSinkObject.Save(23, &t.fsContext) + stateSinkObject.Save(24, &t.fdTable) + stateSinkObject.Save(25, &t.vforkParent) + stateSinkObject.Save(26, &t.exitState) + stateSinkObject.Save(27, &t.exitTracerNotified) + stateSinkObject.Save(28, &t.exitTracerAcked) + stateSinkObject.Save(29, &t.exitParentNotified) + stateSinkObject.Save(30, &t.exitParentAcked) + stateSinkObject.Save(32, &t.ptraceTracees) + stateSinkObject.Save(33, &t.ptraceSeized) + stateSinkObject.Save(34, &t.ptraceOpts) + stateSinkObject.Save(35, &t.ptraceSyscallMode) + stateSinkObject.Save(36, &t.ptraceSinglestep) + stateSinkObject.Save(37, &t.ptraceCode) + stateSinkObject.Save(38, &t.ptraceSiginfo) + stateSinkObject.Save(39, &t.ptraceEventMsg) + stateSinkObject.Save(40, &t.ptraceYAMAExceptionAdded) + stateSinkObject.Save(41, &t.ioUsage) + stateSinkObject.Save(42, &t.creds) + stateSinkObject.Save(43, &t.utsns) + stateSinkObject.Save(44, &t.ipcns) + stateSinkObject.Save(45, &t.abstractSockets) + stateSinkObject.Save(46, &t.mountNamespaceVFS2) + stateSinkObject.Save(47, &t.parentDeathSignal) + stateSinkObject.Save(49, &t.cleartid) + stateSinkObject.Save(50, &t.allowedCPUMask) + stateSinkObject.Save(51, &t.cpu) + stateSinkObject.Save(52, &t.niceness) + stateSinkObject.Save(53, &t.numaPolicy) + stateSinkObject.Save(54, &t.numaNodeMask) + stateSinkObject.Save(55, &t.netns) + stateSinkObject.Save(56, &t.rseqCPU) + stateSinkObject.Save(57, &t.oldRSeqCPUAddr) + stateSinkObject.Save(58, &t.rseqAddr) + stateSinkObject.Save(59, &t.rseqSignature) + stateSinkObject.Save(60, &t.robustList) + stateSinkObject.Save(61, &t.startTime) + stateSinkObject.Save(62, &t.kcov) + stateSinkObject.Save(63, &t.cgroups) +} + +// +checklocksignore +func (t *Task) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &t.taskNode) + stateSourceObject.Load(1, &t.runState) + stateSourceObject.Load(2, &t.taskWorkCount) + stateSourceObject.Load(3, &t.taskWork) + stateSourceObject.Load(4, &t.haveSyscallReturn) + stateSourceObject.Load(5, &t.gosched) + stateSourceObject.Load(6, &t.yieldCount) + stateSourceObject.Load(7, &t.pendingSignals) + stateSourceObject.Load(8, &t.signalMask) + stateSourceObject.Load(9, &t.realSignalMask) + stateSourceObject.Load(10, &t.haveSavedSignalMask) + stateSourceObject.Load(11, &t.savedSignalMask) + stateSourceObject.Load(12, &t.signalStack) + stateSourceObject.Load(13, &t.groupStopPending) + stateSourceObject.Load(14, &t.groupStopAcknowledged) + stateSourceObject.Load(15, &t.trapStopPending) + stateSourceObject.Load(16, &t.trapNotifyPending) + stateSourceObject.Load(17, &t.stop) + stateSourceObject.Load(18, &t.exitStatus) + stateSourceObject.Load(19, &t.syscallRestartBlock) + stateSourceObject.Load(20, &t.k) + stateSourceObject.Load(21, &t.containerID) + stateSourceObject.Load(22, &t.image) + stateSourceObject.Load(23, &t.fsContext) + stateSourceObject.Load(24, &t.fdTable) + stateSourceObject.Load(25, &t.vforkParent) + stateSourceObject.Load(26, &t.exitState) + stateSourceObject.Load(27, &t.exitTracerNotified) + stateSourceObject.Load(28, &t.exitTracerAcked) + stateSourceObject.Load(29, &t.exitParentNotified) + stateSourceObject.Load(30, &t.exitParentAcked) + stateSourceObject.Load(32, &t.ptraceTracees) + stateSourceObject.Load(33, &t.ptraceSeized) + stateSourceObject.Load(34, &t.ptraceOpts) + stateSourceObject.Load(35, &t.ptraceSyscallMode) + stateSourceObject.Load(36, &t.ptraceSinglestep) + stateSourceObject.Load(37, &t.ptraceCode) + stateSourceObject.Load(38, &t.ptraceSiginfo) + stateSourceObject.Load(39, &t.ptraceEventMsg) + stateSourceObject.Load(40, &t.ptraceYAMAExceptionAdded) + stateSourceObject.Load(41, &t.ioUsage) + stateSourceObject.Load(42, &t.creds) + stateSourceObject.Load(43, &t.utsns) + stateSourceObject.Load(44, &t.ipcns) + stateSourceObject.Load(45, &t.abstractSockets) + stateSourceObject.Load(46, &t.mountNamespaceVFS2) + stateSourceObject.Load(47, &t.parentDeathSignal) + stateSourceObject.Load(49, &t.cleartid) + stateSourceObject.Load(50, &t.allowedCPUMask) + stateSourceObject.Load(51, &t.cpu) + stateSourceObject.Load(52, &t.niceness) + stateSourceObject.Load(53, &t.numaPolicy) + stateSourceObject.Load(54, &t.numaNodeMask) + stateSourceObject.Load(55, &t.netns) + stateSourceObject.Load(56, &t.rseqCPU) + stateSourceObject.Load(57, &t.oldRSeqCPUAddr) + stateSourceObject.Load(58, &t.rseqAddr) + stateSourceObject.Load(59, &t.rseqSignature) + stateSourceObject.Load(60, &t.robustList) + stateSourceObject.Load(61, &t.startTime) + stateSourceObject.Load(62, &t.kcov) + stateSourceObject.Load(63, &t.cgroups) + stateSourceObject.LoadValue(31, new(*Task), func(y interface{}) { t.loadPtraceTracer(y.(*Task)) }) + stateSourceObject.LoadValue(48, new([]bpf.Program), func(y interface{}) { t.loadSyscallFilters(y.([]bpf.Program)) }) + stateSourceObject.AfterLoad(t.afterLoad) +} + +func (r *runSyscallAfterPtraceEventClone) StateTypeName() string { + return "pkg/sentry/kernel.runSyscallAfterPtraceEventClone" +} + +func (r *runSyscallAfterPtraceEventClone) StateFields() []string { + return []string{ + "vforkChild", + "vforkChildTID", + } +} + +func (r *runSyscallAfterPtraceEventClone) beforeSave() {} + +// +checklocksignore +func (r *runSyscallAfterPtraceEventClone) StateSave(stateSinkObject state.Sink) { + r.beforeSave() + stateSinkObject.Save(0, &r.vforkChild) + stateSinkObject.Save(1, &r.vforkChildTID) +} + +func (r *runSyscallAfterPtraceEventClone) afterLoad() {} + +// +checklocksignore +func (r *runSyscallAfterPtraceEventClone) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &r.vforkChild) + stateSourceObject.Load(1, &r.vforkChildTID) +} + +func (r *runSyscallAfterVforkStop) StateTypeName() string { + return "pkg/sentry/kernel.runSyscallAfterVforkStop" +} + +func (r *runSyscallAfterVforkStop) StateFields() []string { + return []string{ + "childTID", + } +} + +func (r *runSyscallAfterVforkStop) beforeSave() {} + +// +checklocksignore +func (r *runSyscallAfterVforkStop) StateSave(stateSinkObject state.Sink) { + r.beforeSave() + stateSinkObject.Save(0, &r.childTID) +} + +func (r *runSyscallAfterVforkStop) afterLoad() {} + +// +checklocksignore +func (r *runSyscallAfterVforkStop) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &r.childTID) +} + +func (v *vforkStop) StateTypeName() string { + return "pkg/sentry/kernel.vforkStop" +} + +func (v *vforkStop) StateFields() []string { + return []string{} +} + +func (v *vforkStop) beforeSave() {} + +// +checklocksignore +func (v *vforkStop) StateSave(stateSinkObject state.Sink) { + v.beforeSave() +} + +func (v *vforkStop) afterLoad() {} + +// +checklocksignore +func (v *vforkStop) StateLoad(stateSourceObject state.Source) { +} + +func (e *execStop) StateTypeName() string { + return "pkg/sentry/kernel.execStop" +} + +func (e *execStop) StateFields() []string { + return []string{} +} + +func (e *execStop) beforeSave() {} + +// +checklocksignore +func (e *execStop) StateSave(stateSinkObject state.Sink) { + e.beforeSave() +} + +func (e *execStop) afterLoad() {} + +// +checklocksignore +func (e *execStop) StateLoad(stateSourceObject state.Source) { +} + +func (r *runSyscallAfterExecStop) StateTypeName() string { + return "pkg/sentry/kernel.runSyscallAfterExecStop" +} + +func (r *runSyscallAfterExecStop) StateFields() []string { + return []string{ + "image", + } +} + +func (r *runSyscallAfterExecStop) beforeSave() {} + +// +checklocksignore +func (r *runSyscallAfterExecStop) StateSave(stateSinkObject state.Sink) { + r.beforeSave() + stateSinkObject.Save(0, &r.image) +} + +func (r *runSyscallAfterExecStop) afterLoad() {} + +// +checklocksignore +func (r *runSyscallAfterExecStop) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &r.image) +} + +func (r *runExit) StateTypeName() string { + return "pkg/sentry/kernel.runExit" +} + +func (r *runExit) StateFields() []string { + return []string{} +} + +func (r *runExit) beforeSave() {} + +// +checklocksignore +func (r *runExit) StateSave(stateSinkObject state.Sink) { + r.beforeSave() +} + +func (r *runExit) afterLoad() {} + +// +checklocksignore +func (r *runExit) StateLoad(stateSourceObject state.Source) { +} + +func (r *runExitMain) StateTypeName() string { + return "pkg/sentry/kernel.runExitMain" +} + +func (r *runExitMain) StateFields() []string { + return []string{} +} + +func (r *runExitMain) beforeSave() {} + +// +checklocksignore +func (r *runExitMain) StateSave(stateSinkObject state.Sink) { + r.beforeSave() +} + +func (r *runExitMain) afterLoad() {} + +// +checklocksignore +func (r *runExitMain) StateLoad(stateSourceObject state.Source) { +} + +func (r *runExitNotify) StateTypeName() string { + return "pkg/sentry/kernel.runExitNotify" +} + +func (r *runExitNotify) StateFields() []string { + return []string{} +} + +func (r *runExitNotify) beforeSave() {} + +// +checklocksignore +func (r *runExitNotify) StateSave(stateSinkObject state.Sink) { + r.beforeSave() +} + +func (r *runExitNotify) afterLoad() {} + +// +checklocksignore +func (r *runExitNotify) StateLoad(stateSourceObject state.Source) { +} + +func (image *TaskImage) StateTypeName() string { + return "pkg/sentry/kernel.TaskImage" +} + +func (image *TaskImage) StateFields() []string { + return []string{ + "Name", + "Arch", + "MemoryManager", + "fu", + "st", + } +} + +func (image *TaskImage) beforeSave() {} + +// +checklocksignore +func (image *TaskImage) StateSave(stateSinkObject state.Sink) { + image.beforeSave() + var stValue syscallTableInfo + stValue = image.saveSt() + stateSinkObject.SaveValue(4, stValue) + stateSinkObject.Save(0, &image.Name) + stateSinkObject.Save(1, &image.Arch) + stateSinkObject.Save(2, &image.MemoryManager) + stateSinkObject.Save(3, &image.fu) +} + +func (image *TaskImage) afterLoad() {} + +// +checklocksignore +func (image *TaskImage) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &image.Name) + stateSourceObject.Load(1, &image.Arch) + stateSourceObject.Load(2, &image.MemoryManager) + stateSourceObject.Load(3, &image.fu) + stateSourceObject.LoadValue(4, new(syscallTableInfo), func(y interface{}) { image.loadSt(y.(syscallTableInfo)) }) +} + +func (l *taskList) StateTypeName() string { + return "pkg/sentry/kernel.taskList" +} + +func (l *taskList) StateFields() []string { + return []string{ + "head", + "tail", + } +} + +func (l *taskList) beforeSave() {} + +// +checklocksignore +func (l *taskList) StateSave(stateSinkObject state.Sink) { + l.beforeSave() + stateSinkObject.Save(0, &l.head) + stateSinkObject.Save(1, &l.tail) +} + +func (l *taskList) afterLoad() {} + +// +checklocksignore +func (l *taskList) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &l.head) + stateSourceObject.Load(1, &l.tail) +} + +func (e *taskEntry) StateTypeName() string { + return "pkg/sentry/kernel.taskEntry" +} + +func (e *taskEntry) StateFields() []string { + return []string{ + "next", + "prev", + } +} + +func (e *taskEntry) beforeSave() {} + +// +checklocksignore +func (e *taskEntry) StateSave(stateSinkObject state.Sink) { + e.beforeSave() + stateSinkObject.Save(0, &e.next) + stateSinkObject.Save(1, &e.prev) +} + +func (e *taskEntry) afterLoad() {} + +// +checklocksignore +func (e *taskEntry) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &e.next) + stateSourceObject.Load(1, &e.prev) +} + +func (app *runApp) StateTypeName() string { + return "pkg/sentry/kernel.runApp" +} + +func (app *runApp) StateFields() []string { + return []string{} +} + +func (app *runApp) beforeSave() {} + +// +checklocksignore +func (app *runApp) StateSave(stateSinkObject state.Sink) { + app.beforeSave() +} + +func (app *runApp) afterLoad() {} + +// +checklocksignore +func (app *runApp) StateLoad(stateSourceObject state.Source) { +} + +func (ts *TaskGoroutineSchedInfo) StateTypeName() string { + return "pkg/sentry/kernel.TaskGoroutineSchedInfo" +} + +func (ts *TaskGoroutineSchedInfo) StateFields() []string { + return []string{ + "Timestamp", + "State", + "UserTicks", + "SysTicks", + } +} + +func (ts *TaskGoroutineSchedInfo) beforeSave() {} + +// +checklocksignore +func (ts *TaskGoroutineSchedInfo) StateSave(stateSinkObject state.Sink) { + ts.beforeSave() + stateSinkObject.Save(0, &ts.Timestamp) + stateSinkObject.Save(1, &ts.State) + stateSinkObject.Save(2, &ts.UserTicks) + stateSinkObject.Save(3, &ts.SysTicks) +} + +func (ts *TaskGoroutineSchedInfo) afterLoad() {} + +// +checklocksignore +func (ts *TaskGoroutineSchedInfo) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &ts.Timestamp) + stateSourceObject.Load(1, &ts.State) + stateSourceObject.Load(2, &ts.UserTicks) + stateSourceObject.Load(3, &ts.SysTicks) +} + +func (tc *taskClock) StateTypeName() string { + return "pkg/sentry/kernel.taskClock" +} + +func (tc *taskClock) StateFields() []string { + return []string{ + "t", + "includeSys", + } +} + +func (tc *taskClock) beforeSave() {} + +// +checklocksignore +func (tc *taskClock) StateSave(stateSinkObject state.Sink) { + tc.beforeSave() + stateSinkObject.Save(0, &tc.t) + stateSinkObject.Save(1, &tc.includeSys) +} + +func (tc *taskClock) afterLoad() {} + +// +checklocksignore +func (tc *taskClock) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &tc.t) + stateSourceObject.Load(1, &tc.includeSys) +} + +func (tgc *tgClock) StateTypeName() string { + return "pkg/sentry/kernel.tgClock" +} + +func (tgc *tgClock) StateFields() []string { + return []string{ + "tg", + "includeSys", + } +} + +func (tgc *tgClock) beforeSave() {} + +// +checklocksignore +func (tgc *tgClock) StateSave(stateSinkObject state.Sink) { + tgc.beforeSave() + stateSinkObject.Save(0, &tgc.tg) + stateSinkObject.Save(1, &tgc.includeSys) +} + +func (tgc *tgClock) afterLoad() {} + +// +checklocksignore +func (tgc *tgClock) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &tgc.tg) + stateSourceObject.Load(1, &tgc.includeSys) +} + +func (g *groupStop) StateTypeName() string { + return "pkg/sentry/kernel.groupStop" +} + +func (g *groupStop) StateFields() []string { + return []string{} +} + +func (g *groupStop) beforeSave() {} + +// +checklocksignore +func (g *groupStop) StateSave(stateSinkObject state.Sink) { + g.beforeSave() +} + +func (g *groupStop) afterLoad() {} + +// +checklocksignore +func (g *groupStop) StateLoad(stateSourceObject state.Source) { +} + +func (r *runInterrupt) StateTypeName() string { + return "pkg/sentry/kernel.runInterrupt" +} + +func (r *runInterrupt) StateFields() []string { + return []string{} +} + +func (r *runInterrupt) beforeSave() {} + +// +checklocksignore +func (r *runInterrupt) StateSave(stateSinkObject state.Sink) { + r.beforeSave() +} + +func (r *runInterrupt) afterLoad() {} + +// +checklocksignore +func (r *runInterrupt) StateLoad(stateSourceObject state.Source) { +} + +func (r *runInterruptAfterSignalDeliveryStop) StateTypeName() string { + return "pkg/sentry/kernel.runInterruptAfterSignalDeliveryStop" +} + +func (r *runInterruptAfterSignalDeliveryStop) StateFields() []string { + return []string{} +} + +func (r *runInterruptAfterSignalDeliveryStop) beforeSave() {} + +// +checklocksignore +func (r *runInterruptAfterSignalDeliveryStop) StateSave(stateSinkObject state.Sink) { + r.beforeSave() +} + +func (r *runInterruptAfterSignalDeliveryStop) afterLoad() {} + +// +checklocksignore +func (r *runInterruptAfterSignalDeliveryStop) StateLoad(stateSourceObject state.Source) { +} + +func (r *runSyscallAfterSyscallEnterStop) StateTypeName() string { + return "pkg/sentry/kernel.runSyscallAfterSyscallEnterStop" +} + +func (r *runSyscallAfterSyscallEnterStop) StateFields() []string { + return []string{} +} + +func (r *runSyscallAfterSyscallEnterStop) beforeSave() {} + +// +checklocksignore +func (r *runSyscallAfterSyscallEnterStop) StateSave(stateSinkObject state.Sink) { + r.beforeSave() +} + +func (r *runSyscallAfterSyscallEnterStop) afterLoad() {} + +// +checklocksignore +func (r *runSyscallAfterSyscallEnterStop) StateLoad(stateSourceObject state.Source) { +} + +func (r *runSyscallAfterSysemuStop) StateTypeName() string { + return "pkg/sentry/kernel.runSyscallAfterSysemuStop" +} + +func (r *runSyscallAfterSysemuStop) StateFields() []string { + return []string{} +} + +func (r *runSyscallAfterSysemuStop) beforeSave() {} + +// +checklocksignore +func (r *runSyscallAfterSysemuStop) StateSave(stateSinkObject state.Sink) { + r.beforeSave() +} + +func (r *runSyscallAfterSysemuStop) afterLoad() {} + +// +checklocksignore +func (r *runSyscallAfterSysemuStop) StateLoad(stateSourceObject state.Source) { +} + +func (r *runSyscallReinvoke) StateTypeName() string { + return "pkg/sentry/kernel.runSyscallReinvoke" +} + +func (r *runSyscallReinvoke) StateFields() []string { + return []string{} +} + +func (r *runSyscallReinvoke) beforeSave() {} + +// +checklocksignore +func (r *runSyscallReinvoke) StateSave(stateSinkObject state.Sink) { + r.beforeSave() +} + +func (r *runSyscallReinvoke) afterLoad() {} + +// +checklocksignore +func (r *runSyscallReinvoke) StateLoad(stateSourceObject state.Source) { +} + +func (r *runSyscallExit) StateTypeName() string { + return "pkg/sentry/kernel.runSyscallExit" +} + +func (r *runSyscallExit) StateFields() []string { + return []string{} +} + +func (r *runSyscallExit) beforeSave() {} + +// +checklocksignore +func (r *runSyscallExit) StateSave(stateSinkObject state.Sink) { + r.beforeSave() +} + +func (r *runSyscallExit) afterLoad() {} + +// +checklocksignore +func (r *runSyscallExit) StateLoad(stateSourceObject state.Source) { +} + +func (tg *ThreadGroup) StateTypeName() string { + return "pkg/sentry/kernel.ThreadGroup" +} + +func (tg *ThreadGroup) StateFields() []string { + return []string{ + "threadGroupNode", + "signalHandlers", + "pendingSignals", + "groupStopDequeued", + "groupStopSignal", + "groupStopPendingCount", + "groupStopComplete", + "groupStopWaitable", + "groupContNotify", + "groupContInterrupted", + "groupContWaitable", + "exiting", + "exitStatus", + "terminationSignal", + "itimerRealTimer", + "itimerVirtSetting", + "itimerProfSetting", + "rlimitCPUSoftSetting", + "cpuTimersEnabled", + "timers", + "nextTimerID", + "exitedCPUStats", + "childCPUStats", + "ioUsage", + "maxRSS", + "childMaxRSS", + "limits", + "processGroup", + "execed", + "oldRSeqCritical", + "mounts", + "tty", + "oomScoreAdj", + } +} + +func (tg *ThreadGroup) beforeSave() {} + +// +checklocksignore +func (tg *ThreadGroup) StateSave(stateSinkObject state.Sink) { + tg.beforeSave() + var oldRSeqCriticalValue *OldRSeqCriticalRegion + oldRSeqCriticalValue = tg.saveOldRSeqCritical() + stateSinkObject.SaveValue(29, oldRSeqCriticalValue) + stateSinkObject.Save(0, &tg.threadGroupNode) + stateSinkObject.Save(1, &tg.signalHandlers) + stateSinkObject.Save(2, &tg.pendingSignals) + stateSinkObject.Save(3, &tg.groupStopDequeued) + stateSinkObject.Save(4, &tg.groupStopSignal) + stateSinkObject.Save(5, &tg.groupStopPendingCount) + stateSinkObject.Save(6, &tg.groupStopComplete) + stateSinkObject.Save(7, &tg.groupStopWaitable) + stateSinkObject.Save(8, &tg.groupContNotify) + stateSinkObject.Save(9, &tg.groupContInterrupted) + stateSinkObject.Save(10, &tg.groupContWaitable) + stateSinkObject.Save(11, &tg.exiting) + stateSinkObject.Save(12, &tg.exitStatus) + stateSinkObject.Save(13, &tg.terminationSignal) + stateSinkObject.Save(14, &tg.itimerRealTimer) + stateSinkObject.Save(15, &tg.itimerVirtSetting) + stateSinkObject.Save(16, &tg.itimerProfSetting) + stateSinkObject.Save(17, &tg.rlimitCPUSoftSetting) + stateSinkObject.Save(18, &tg.cpuTimersEnabled) + stateSinkObject.Save(19, &tg.timers) + stateSinkObject.Save(20, &tg.nextTimerID) + stateSinkObject.Save(21, &tg.exitedCPUStats) + stateSinkObject.Save(22, &tg.childCPUStats) + stateSinkObject.Save(23, &tg.ioUsage) + stateSinkObject.Save(24, &tg.maxRSS) + stateSinkObject.Save(25, &tg.childMaxRSS) + stateSinkObject.Save(26, &tg.limits) + stateSinkObject.Save(27, &tg.processGroup) + stateSinkObject.Save(28, &tg.execed) + stateSinkObject.Save(30, &tg.mounts) + stateSinkObject.Save(31, &tg.tty) + stateSinkObject.Save(32, &tg.oomScoreAdj) +} + +func (tg *ThreadGroup) afterLoad() {} + +// +checklocksignore +func (tg *ThreadGroup) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &tg.threadGroupNode) + stateSourceObject.Load(1, &tg.signalHandlers) + stateSourceObject.Load(2, &tg.pendingSignals) + stateSourceObject.Load(3, &tg.groupStopDequeued) + stateSourceObject.Load(4, &tg.groupStopSignal) + stateSourceObject.Load(5, &tg.groupStopPendingCount) + stateSourceObject.Load(6, &tg.groupStopComplete) + stateSourceObject.Load(7, &tg.groupStopWaitable) + stateSourceObject.Load(8, &tg.groupContNotify) + stateSourceObject.Load(9, &tg.groupContInterrupted) + stateSourceObject.Load(10, &tg.groupContWaitable) + stateSourceObject.Load(11, &tg.exiting) + stateSourceObject.Load(12, &tg.exitStatus) + stateSourceObject.Load(13, &tg.terminationSignal) + stateSourceObject.Load(14, &tg.itimerRealTimer) + stateSourceObject.Load(15, &tg.itimerVirtSetting) + stateSourceObject.Load(16, &tg.itimerProfSetting) + stateSourceObject.Load(17, &tg.rlimitCPUSoftSetting) + stateSourceObject.Load(18, &tg.cpuTimersEnabled) + stateSourceObject.Load(19, &tg.timers) + stateSourceObject.Load(20, &tg.nextTimerID) + stateSourceObject.Load(21, &tg.exitedCPUStats) + stateSourceObject.Load(22, &tg.childCPUStats) + stateSourceObject.Load(23, &tg.ioUsage) + stateSourceObject.Load(24, &tg.maxRSS) + stateSourceObject.Load(25, &tg.childMaxRSS) + stateSourceObject.Load(26, &tg.limits) + stateSourceObject.Load(27, &tg.processGroup) + stateSourceObject.Load(28, &tg.execed) + stateSourceObject.Load(30, &tg.mounts) + stateSourceObject.Load(31, &tg.tty) + stateSourceObject.Load(32, &tg.oomScoreAdj) + stateSourceObject.LoadValue(29, new(*OldRSeqCriticalRegion), func(y interface{}) { tg.loadOldRSeqCritical(y.(*OldRSeqCriticalRegion)) }) +} + +func (l *itimerRealListener) StateTypeName() string { + return "pkg/sentry/kernel.itimerRealListener" +} + +func (l *itimerRealListener) StateFields() []string { + return []string{ + "tg", + } +} + +func (l *itimerRealListener) beforeSave() {} + +// +checklocksignore +func (l *itimerRealListener) StateSave(stateSinkObject state.Sink) { + l.beforeSave() + stateSinkObject.Save(0, &l.tg) +} + +func (l *itimerRealListener) afterLoad() {} + +// +checklocksignore +func (l *itimerRealListener) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &l.tg) +} + +func (ts *TaskSet) StateTypeName() string { + return "pkg/sentry/kernel.TaskSet" +} + +func (ts *TaskSet) StateFields() []string { + return []string{ + "Root", + "sessions", + } +} + +func (ts *TaskSet) beforeSave() {} + +// +checklocksignore +func (ts *TaskSet) StateSave(stateSinkObject state.Sink) { + ts.beforeSave() + stateSinkObject.Save(0, &ts.Root) + stateSinkObject.Save(1, &ts.sessions) +} + +func (ts *TaskSet) afterLoad() {} + +// +checklocksignore +func (ts *TaskSet) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &ts.Root) + stateSourceObject.Load(1, &ts.sessions) +} + +func (ns *PIDNamespace) StateTypeName() string { + return "pkg/sentry/kernel.PIDNamespace" +} + +func (ns *PIDNamespace) StateFields() []string { + return []string{ + "owner", + "parent", + "userns", + "last", + "tasks", + "tids", + "tgids", + "sessions", + "sids", + "processGroups", + "pgids", + "exiting", + } +} + +func (ns *PIDNamespace) beforeSave() {} + +// +checklocksignore +func (ns *PIDNamespace) StateSave(stateSinkObject state.Sink) { + ns.beforeSave() + stateSinkObject.Save(0, &ns.owner) + stateSinkObject.Save(1, &ns.parent) + stateSinkObject.Save(2, &ns.userns) + stateSinkObject.Save(3, &ns.last) + stateSinkObject.Save(4, &ns.tasks) + stateSinkObject.Save(5, &ns.tids) + stateSinkObject.Save(6, &ns.tgids) + stateSinkObject.Save(7, &ns.sessions) + stateSinkObject.Save(8, &ns.sids) + stateSinkObject.Save(9, &ns.processGroups) + stateSinkObject.Save(10, &ns.pgids) + stateSinkObject.Save(11, &ns.exiting) +} + +func (ns *PIDNamespace) afterLoad() {} + +// +checklocksignore +func (ns *PIDNamespace) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &ns.owner) + stateSourceObject.Load(1, &ns.parent) + stateSourceObject.Load(2, &ns.userns) + stateSourceObject.Load(3, &ns.last) + stateSourceObject.Load(4, &ns.tasks) + stateSourceObject.Load(5, &ns.tids) + stateSourceObject.Load(6, &ns.tgids) + stateSourceObject.Load(7, &ns.sessions) + stateSourceObject.Load(8, &ns.sids) + stateSourceObject.Load(9, &ns.processGroups) + stateSourceObject.Load(10, &ns.pgids) + stateSourceObject.Load(11, &ns.exiting) +} + +func (t *threadGroupNode) StateTypeName() string { + return "pkg/sentry/kernel.threadGroupNode" +} + +func (t *threadGroupNode) StateFields() []string { + return []string{ + "pidns", + "leader", + "execing", + "tasks", + "tasksCount", + "liveTasks", + "activeTasks", + } +} + +func (t *threadGroupNode) beforeSave() {} + +// +checklocksignore +func (t *threadGroupNode) StateSave(stateSinkObject state.Sink) { + t.beforeSave() + stateSinkObject.Save(0, &t.pidns) + stateSinkObject.Save(1, &t.leader) + stateSinkObject.Save(2, &t.execing) + stateSinkObject.Save(3, &t.tasks) + stateSinkObject.Save(4, &t.tasksCount) + stateSinkObject.Save(5, &t.liveTasks) + stateSinkObject.Save(6, &t.activeTasks) +} + +func (t *threadGroupNode) afterLoad() {} + +// +checklocksignore +func (t *threadGroupNode) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &t.pidns) + stateSourceObject.Load(1, &t.leader) + stateSourceObject.Load(2, &t.execing) + stateSourceObject.Load(3, &t.tasks) + stateSourceObject.Load(4, &t.tasksCount) + stateSourceObject.Load(5, &t.liveTasks) + stateSourceObject.Load(6, &t.activeTasks) +} + +func (t *taskNode) StateTypeName() string { + return "pkg/sentry/kernel.taskNode" +} + +func (t *taskNode) StateFields() []string { + return []string{ + "tg", + "taskEntry", + "parent", + "children", + "childPIDNamespace", + } +} + +func (t *taskNode) beforeSave() {} + +// +checklocksignore +func (t *taskNode) StateSave(stateSinkObject state.Sink) { + t.beforeSave() + stateSinkObject.Save(0, &t.tg) + stateSinkObject.Save(1, &t.taskEntry) + stateSinkObject.Save(2, &t.parent) + stateSinkObject.Save(3, &t.children) + stateSinkObject.Save(4, &t.childPIDNamespace) +} + +func (t *taskNode) afterLoad() {} + +// +checklocksignore +func (t *taskNode) StateLoad(stateSourceObject state.Source) { + stateSourceObject.LoadWait(0, &t.tg) + stateSourceObject.Load(1, &t.taskEntry) + stateSourceObject.Load(2, &t.parent) + stateSourceObject.Load(3, &t.children) + stateSourceObject.Load(4, &t.childPIDNamespace) +} + +func (t *Timekeeper) StateTypeName() string { + return "pkg/sentry/kernel.Timekeeper" +} + +func (t *Timekeeper) StateFields() []string { + return []string{ + "realtimeClock", + "monotonicClock", + "bootTime", + "saveMonotonic", + "saveRealtime", + "params", + } +} + +// +checklocksignore +func (t *Timekeeper) StateSave(stateSinkObject state.Sink) { + t.beforeSave() + stateSinkObject.Save(0, &t.realtimeClock) + stateSinkObject.Save(1, &t.monotonicClock) + stateSinkObject.Save(2, &t.bootTime) + stateSinkObject.Save(3, &t.saveMonotonic) + stateSinkObject.Save(4, &t.saveRealtime) + stateSinkObject.Save(5, &t.params) +} + +// +checklocksignore +func (t *Timekeeper) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &t.realtimeClock) + stateSourceObject.Load(1, &t.monotonicClock) + stateSourceObject.Load(2, &t.bootTime) + stateSourceObject.Load(3, &t.saveMonotonic) + stateSourceObject.Load(4, &t.saveRealtime) + stateSourceObject.Load(5, &t.params) + stateSourceObject.AfterLoad(t.afterLoad) +} + +func (tc *timekeeperClock) StateTypeName() string { + return "pkg/sentry/kernel.timekeeperClock" +} + +func (tc *timekeeperClock) StateFields() []string { + return []string{ + "tk", + "c", + } +} + +func (tc *timekeeperClock) beforeSave() {} + +// +checklocksignore +func (tc *timekeeperClock) StateSave(stateSinkObject state.Sink) { + tc.beforeSave() + stateSinkObject.Save(0, &tc.tk) + stateSinkObject.Save(1, &tc.c) +} + +func (tc *timekeeperClock) afterLoad() {} + +// +checklocksignore +func (tc *timekeeperClock) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &tc.tk) + stateSourceObject.Load(1, &tc.c) +} + +func (t *TTY) StateTypeName() string { + return "pkg/sentry/kernel.TTY" +} + +func (t *TTY) StateFields() []string { + return []string{ + "Index", + "tg", + } +} + +func (t *TTY) beforeSave() {} + +// +checklocksignore +func (t *TTY) StateSave(stateSinkObject state.Sink) { + t.beforeSave() + stateSinkObject.Save(0, &t.Index) + stateSinkObject.Save(1, &t.tg) +} + +func (t *TTY) afterLoad() {} + +// +checklocksignore +func (t *TTY) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &t.Index) + stateSourceObject.Load(1, &t.tg) +} + +func (u *UTSNamespace) StateTypeName() string { + return "pkg/sentry/kernel.UTSNamespace" +} + +func (u *UTSNamespace) StateFields() []string { + return []string{ + "hostName", + "domainName", + "userns", + } +} + +func (u *UTSNamespace) beforeSave() {} + +// +checklocksignore +func (u *UTSNamespace) StateSave(stateSinkObject state.Sink) { + u.beforeSave() + stateSinkObject.Save(0, &u.hostName) + stateSinkObject.Save(1, &u.domainName) + stateSinkObject.Save(2, &u.userns) +} + +func (u *UTSNamespace) afterLoad() {} + +// +checklocksignore +func (u *UTSNamespace) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &u.hostName) + stateSourceObject.Load(1, &u.domainName) + stateSourceObject.Load(2, &u.userns) +} + +func (v *VDSOParamPage) StateTypeName() string { + return "pkg/sentry/kernel.VDSOParamPage" +} + +func (v *VDSOParamPage) StateFields() []string { + return []string{ + "mfp", + "fr", + "seq", + "copyScratchBuffer", + } +} + +func (v *VDSOParamPage) beforeSave() {} + +// +checklocksignore +func (v *VDSOParamPage) StateSave(stateSinkObject state.Sink) { + v.beforeSave() + stateSinkObject.Save(0, &v.mfp) + stateSinkObject.Save(1, &v.fr) + stateSinkObject.Save(2, &v.seq) + stateSinkObject.Save(3, &v.copyScratchBuffer) +} + +func (v *VDSOParamPage) afterLoad() {} + +// +checklocksignore +func (v *VDSOParamPage) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &v.mfp) + stateSourceObject.Load(1, &v.fr) + stateSourceObject.Load(2, &v.seq) + stateSourceObject.Load(3, &v.copyScratchBuffer) +} + +func init() { + state.Register((*abstractEndpoint)(nil)) + state.Register((*AbstractSocketNamespace)(nil)) + state.Register((*Cgroup)(nil)) + state.Register((*hierarchy)(nil)) + state.Register((*CgroupRegistry)(nil)) + state.Register((*FDFlags)(nil)) + state.Register((*descriptor)(nil)) + state.Register((*FDTable)(nil)) + state.Register((*FDTableRefs)(nil)) + state.Register((*FSContext)(nil)) + state.Register((*FSContextRefs)(nil)) + state.Register((*IPCNamespace)(nil)) + state.Register((*IPCNamespaceRefs)(nil)) + state.Register((*Kernel)(nil)) + state.Register((*SocketRecord)(nil)) + state.Register((*SocketRecordVFS1)(nil)) + state.Register((*pendingSignals)(nil)) + state.Register((*pendingSignalQueue)(nil)) + state.Register((*pendingSignal)(nil)) + state.Register((*pendingSignalList)(nil)) + state.Register((*pendingSignalEntry)(nil)) + state.Register((*savedPendingSignal)(nil)) + state.Register((*IntervalTimer)(nil)) + state.Register((*processGroupList)(nil)) + state.Register((*processGroupEntry)(nil)) + state.Register((*ProcessGroupRefs)(nil)) + state.Register((*ptraceOptions)(nil)) + state.Register((*ptraceStop)(nil)) + state.Register((*OldRSeqCriticalRegion)(nil)) + state.Register((*sessionList)(nil)) + state.Register((*sessionEntry)(nil)) + state.Register((*SessionRefs)(nil)) + state.Register((*Session)(nil)) + state.Register((*ProcessGroup)(nil)) + state.Register((*SignalHandlers)(nil)) + state.Register((*socketList)(nil)) + state.Register((*socketEntry)(nil)) + state.Register((*syscallTableInfo)(nil)) + state.Register((*syslog)(nil)) + state.Register((*Task)(nil)) + state.Register((*runSyscallAfterPtraceEventClone)(nil)) + state.Register((*runSyscallAfterVforkStop)(nil)) + state.Register((*vforkStop)(nil)) + state.Register((*execStop)(nil)) + state.Register((*runSyscallAfterExecStop)(nil)) + state.Register((*runExit)(nil)) + state.Register((*runExitMain)(nil)) + state.Register((*runExitNotify)(nil)) + state.Register((*TaskImage)(nil)) + state.Register((*taskList)(nil)) + state.Register((*taskEntry)(nil)) + state.Register((*runApp)(nil)) + state.Register((*TaskGoroutineSchedInfo)(nil)) + state.Register((*taskClock)(nil)) + state.Register((*tgClock)(nil)) + state.Register((*groupStop)(nil)) + state.Register((*runInterrupt)(nil)) + state.Register((*runInterruptAfterSignalDeliveryStop)(nil)) + state.Register((*runSyscallAfterSyscallEnterStop)(nil)) + state.Register((*runSyscallAfterSysemuStop)(nil)) + state.Register((*runSyscallReinvoke)(nil)) + state.Register((*runSyscallExit)(nil)) + state.Register((*ThreadGroup)(nil)) + state.Register((*itimerRealListener)(nil)) + state.Register((*TaskSet)(nil)) + state.Register((*PIDNamespace)(nil)) + state.Register((*threadGroupNode)(nil)) + state.Register((*taskNode)(nil)) + state.Register((*Timekeeper)(nil)) + state.Register((*timekeeperClock)(nil)) + state.Register((*TTY)(nil)) + state.Register((*UTSNamespace)(nil)) + state.Register((*VDSOParamPage)(nil)) +} diff --git a/pkg/sentry/kernel/kernel_unsafe_abi_autogen_unsafe.go b/pkg/sentry/kernel/kernel_unsafe_abi_autogen_unsafe.go new file mode 100644 index 000000000..5d810c89c --- /dev/null +++ b/pkg/sentry/kernel/kernel_unsafe_abi_autogen_unsafe.go @@ -0,0 +1,7 @@ +// Automatically generated marshal implementation. See tools/go_marshal. + +package kernel + +import ( +) + diff --git a/pkg/sentry/kernel/kernel_unsafe_state_autogen.go b/pkg/sentry/kernel/kernel_unsafe_state_autogen.go new file mode 100644 index 000000000..12130bf74 --- /dev/null +++ b/pkg/sentry/kernel/kernel_unsafe_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package kernel diff --git a/pkg/sentry/kernel/memevent/BUILD b/pkg/sentry/kernel/memevent/BUILD deleted file mode 100644 index 4486848d2..000000000 --- a/pkg/sentry/kernel/memevent/BUILD +++ /dev/null @@ -1,24 +0,0 @@ -load("//tools:defs.bzl", "go_library", "proto_library") - -package(licenses = ["notice"]) - -go_library( - name = "memevent", - srcs = ["memory_events.go"], - visibility = ["//:sandbox"], - deps = [ - ":memory_events_go_proto", - "//pkg/eventchannel", - "//pkg/log", - "//pkg/metric", - "//pkg/sentry/kernel", - "//pkg/sentry/usage", - "//pkg/sync", - ], -) - -proto_library( - name = "memory_events", - srcs = ["memory_events.proto"], - visibility = ["//visibility:public"], -) diff --git a/pkg/sentry/kernel/memevent/memevent_state_autogen.go b/pkg/sentry/kernel/memevent/memevent_state_autogen.go new file mode 100644 index 000000000..4a1679fa9 --- /dev/null +++ b/pkg/sentry/kernel/memevent/memevent_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package memevent diff --git a/pkg/sentry/kernel/memevent/memory_events.proto b/pkg/sentry/kernel/memevent/memory_events.proto deleted file mode 100644 index bf8029ff5..000000000 --- a/pkg/sentry/kernel/memevent/memory_events.proto +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -syntax = "proto3"; - -package gvisor; - -// MemoryUsageEvent describes the memory usage of the sandbox at a single -// instant in time. These messages are emitted periodically on the eventchannel. -message MemoryUsageEvent { - // The total memory usage of the sandboxed application in bytes, calculated - // using the 'fast' method. - uint64 total = 1; - - // Memory used to back memory-mapped regions for files in the application, in - // bytes. This corresponds to the usage.MemoryKind.Mapped memory type. - uint64 mapped = 2; -} diff --git a/pkg/sentry/kernel/memevent/memory_events_go_proto/memory_events.pb.go b/pkg/sentry/kernel/memevent/memory_events_go_proto/memory_events.pb.go new file mode 100644 index 000000000..ed1b8a8ca --- /dev/null +++ b/pkg/sentry/kernel/memevent/memory_events_go_proto/memory_events.pb.go @@ -0,0 +1,158 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.25.0 +// protoc v3.13.0 +// source: pkg/sentry/kernel/memevent/memory_events.proto + +package gvisor + +import ( + proto "github.com/golang/protobuf/proto" + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +// This is a compile-time assertion that a sufficiently up-to-date version +// of the legacy proto package is being used. +const _ = proto.ProtoPackageIsVersion4 + +type MemoryUsageEvent struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Total uint64 `protobuf:"varint,1,opt,name=total,proto3" json:"total,omitempty"` + Mapped uint64 `protobuf:"varint,2,opt,name=mapped,proto3" json:"mapped,omitempty"` +} + +func (x *MemoryUsageEvent) Reset() { + *x = MemoryUsageEvent{} + if protoimpl.UnsafeEnabled { + mi := &file_pkg_sentry_kernel_memevent_memory_events_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *MemoryUsageEvent) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*MemoryUsageEvent) ProtoMessage() {} + +func (x *MemoryUsageEvent) ProtoReflect() protoreflect.Message { + mi := &file_pkg_sentry_kernel_memevent_memory_events_proto_msgTypes[0] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use MemoryUsageEvent.ProtoReflect.Descriptor instead. +func (*MemoryUsageEvent) Descriptor() ([]byte, []int) { + return file_pkg_sentry_kernel_memevent_memory_events_proto_rawDescGZIP(), []int{0} +} + +func (x *MemoryUsageEvent) GetTotal() uint64 { + if x != nil { + return x.Total + } + return 0 +} + +func (x *MemoryUsageEvent) GetMapped() uint64 { + if x != nil { + return x.Mapped + } + return 0 +} + +var File_pkg_sentry_kernel_memevent_memory_events_proto protoreflect.FileDescriptor + +var file_pkg_sentry_kernel_memevent_memory_events_proto_rawDesc = []byte{ + 0x0a, 0x2e, 0x70, 0x6b, 0x67, 0x2f, 0x73, 0x65, 0x6e, 0x74, 0x72, 0x79, 0x2f, 0x6b, 0x65, 0x72, + 0x6e, 0x65, 0x6c, 0x2f, 0x6d, 0x65, 0x6d, 0x65, 0x76, 0x65, 0x6e, 0x74, 0x2f, 0x6d, 0x65, 0x6d, + 0x6f, 0x72, 0x79, 0x5f, 0x65, 0x76, 0x65, 0x6e, 0x74, 0x73, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, + 0x12, 0x06, 0x67, 0x76, 0x69, 0x73, 0x6f, 0x72, 0x22, 0x40, 0x0a, 0x10, 0x4d, 0x65, 0x6d, 0x6f, + 0x72, 0x79, 0x55, 0x73, 0x61, 0x67, 0x65, 0x45, 0x76, 0x65, 0x6e, 0x74, 0x12, 0x14, 0x0a, 0x05, + 0x74, 0x6f, 0x74, 0x61, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x04, 0x52, 0x05, 0x74, 0x6f, 0x74, + 0x61, 0x6c, 0x12, 0x16, 0x0a, 0x06, 0x6d, 0x61, 0x70, 0x70, 0x65, 0x64, 0x18, 0x02, 0x20, 0x01, + 0x28, 0x04, 0x52, 0x06, 0x6d, 0x61, 0x70, 0x70, 0x65, 0x64, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, + 0x6f, 0x33, +} + +var ( + file_pkg_sentry_kernel_memevent_memory_events_proto_rawDescOnce sync.Once + file_pkg_sentry_kernel_memevent_memory_events_proto_rawDescData = file_pkg_sentry_kernel_memevent_memory_events_proto_rawDesc +) + +func file_pkg_sentry_kernel_memevent_memory_events_proto_rawDescGZIP() []byte { + file_pkg_sentry_kernel_memevent_memory_events_proto_rawDescOnce.Do(func() { + file_pkg_sentry_kernel_memevent_memory_events_proto_rawDescData = protoimpl.X.CompressGZIP(file_pkg_sentry_kernel_memevent_memory_events_proto_rawDescData) + }) + return file_pkg_sentry_kernel_memevent_memory_events_proto_rawDescData +} + +var file_pkg_sentry_kernel_memevent_memory_events_proto_msgTypes = make([]protoimpl.MessageInfo, 1) +var file_pkg_sentry_kernel_memevent_memory_events_proto_goTypes = []interface{}{ + (*MemoryUsageEvent)(nil), // 0: gvisor.MemoryUsageEvent +} +var file_pkg_sentry_kernel_memevent_memory_events_proto_depIdxs = []int32{ + 0, // [0:0] is the sub-list for method output_type + 0, // [0:0] is the sub-list for method input_type + 0, // [0:0] is the sub-list for extension type_name + 0, // [0:0] is the sub-list for extension extendee + 0, // [0:0] is the sub-list for field type_name +} + +func init() { file_pkg_sentry_kernel_memevent_memory_events_proto_init() } +func file_pkg_sentry_kernel_memevent_memory_events_proto_init() { + if File_pkg_sentry_kernel_memevent_memory_events_proto != nil { + return + } + if !protoimpl.UnsafeEnabled { + file_pkg_sentry_kernel_memevent_memory_events_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*MemoryUsageEvent); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: file_pkg_sentry_kernel_memevent_memory_events_proto_rawDesc, + NumEnums: 0, + NumMessages: 1, + NumExtensions: 0, + NumServices: 0, + }, + GoTypes: file_pkg_sentry_kernel_memevent_memory_events_proto_goTypes, + DependencyIndexes: file_pkg_sentry_kernel_memevent_memory_events_proto_depIdxs, + MessageInfos: file_pkg_sentry_kernel_memevent_memory_events_proto_msgTypes, + }.Build() + File_pkg_sentry_kernel_memevent_memory_events_proto = out.File + file_pkg_sentry_kernel_memevent_memory_events_proto_rawDesc = nil + file_pkg_sentry_kernel_memevent_memory_events_proto_goTypes = nil + file_pkg_sentry_kernel_memevent_memory_events_proto_depIdxs = nil +} diff --git a/pkg/sentry/kernel/msgqueue/BUILD b/pkg/sentry/kernel/msgqueue/BUILD deleted file mode 100644 index 5ec11e1f6..000000000 --- a/pkg/sentry/kernel/msgqueue/BUILD +++ /dev/null @@ -1,36 +0,0 @@ -load("//tools:defs.bzl", "go_library") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -package(licenses = ["notice"]) - -go_template_instance( - name = "message_list", - out = "message_list.go", - package = "msgqueue", - prefix = "msg", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*Message", - "Linker": "*Message", - }, -) - -go_library( - name = "msgqueue", - srcs = [ - "message_list.go", - "msgqueue.go", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/errors/linuxerr", - "//pkg/sentry/fs", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/kernel/ipc", - "//pkg/sentry/kernel/time", - "//pkg/sync", - "//pkg/waiter", - ], -) diff --git a/pkg/sentry/kernel/msgqueue/message_list.go b/pkg/sentry/kernel/msgqueue/message_list.go new file mode 100644 index 000000000..f2f2292e7 --- /dev/null +++ b/pkg/sentry/kernel/msgqueue/message_list.go @@ -0,0 +1,221 @@ +package msgqueue + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type msgElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (msgElementMapper) linkerFor(elem *Message) *Message { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type msgList struct { + head *Message + tail *Message +} + +// Reset resets list l to the empty state. +func (l *msgList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +// +//go:nosplit +func (l *msgList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +// +//go:nosplit +func (l *msgList) Front() *Message { + return l.head +} + +// Back returns the last element of list l or nil. +// +//go:nosplit +func (l *msgList) Back() *Message { + return l.tail +} + +// Len returns the number of elements in the list. +// +// NOTE: This is an O(n) operation. +// +//go:nosplit +func (l *msgList) Len() (count int) { + for e := l.Front(); e != nil; e = (msgElementMapper{}.linkerFor(e)).Next() { + count++ + } + return count +} + +// PushFront inserts the element e at the front of list l. +// +//go:nosplit +func (l *msgList) PushFront(e *Message) { + linker := msgElementMapper{}.linkerFor(e) + linker.SetNext(l.head) + linker.SetPrev(nil) + if l.head != nil { + msgElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +// +//go:nosplit +func (l *msgList) PushBack(e *Message) { + linker := msgElementMapper{}.linkerFor(e) + linker.SetNext(nil) + linker.SetPrev(l.tail) + if l.tail != nil { + msgElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +// +//go:nosplit +func (l *msgList) PushBackList(m *msgList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + msgElementMapper{}.linkerFor(l.tail).SetNext(m.head) + msgElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +// +//go:nosplit +func (l *msgList) InsertAfter(b, e *Message) { + bLinker := msgElementMapper{}.linkerFor(b) + eLinker := msgElementMapper{}.linkerFor(e) + + a := bLinker.Next() + + eLinker.SetNext(a) + eLinker.SetPrev(b) + bLinker.SetNext(e) + + if a != nil { + msgElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +// +//go:nosplit +func (l *msgList) InsertBefore(a, e *Message) { + aLinker := msgElementMapper{}.linkerFor(a) + eLinker := msgElementMapper{}.linkerFor(e) + + b := aLinker.Prev() + eLinker.SetNext(a) + eLinker.SetPrev(b) + aLinker.SetPrev(e) + + if b != nil { + msgElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +// +//go:nosplit +func (l *msgList) Remove(e *Message) { + linker := msgElementMapper{}.linkerFor(e) + prev := linker.Prev() + next := linker.Next() + + if prev != nil { + msgElementMapper{}.linkerFor(prev).SetNext(next) + } else if l.head == e { + l.head = next + } + + if next != nil { + msgElementMapper{}.linkerFor(next).SetPrev(prev) + } else if l.tail == e { + l.tail = prev + } + + linker.SetNext(nil) + linker.SetPrev(nil) +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type msgEntry struct { + next *Message + prev *Message +} + +// Next returns the entry that follows e in the list. +// +//go:nosplit +func (e *msgEntry) Next() *Message { + return e.next +} + +// Prev returns the entry that precedes e in the list. +// +//go:nosplit +func (e *msgEntry) Prev() *Message { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +// +//go:nosplit +func (e *msgEntry) SetNext(elem *Message) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +// +//go:nosplit +func (e *msgEntry) SetPrev(elem *Message) { + e.prev = elem +} diff --git a/pkg/sentry/kernel/msgqueue/msgqueue_state_autogen.go b/pkg/sentry/kernel/msgqueue/msgqueue_state_autogen.go new file mode 100644 index 000000000..373e01cff --- /dev/null +++ b/pkg/sentry/kernel/msgqueue/msgqueue_state_autogen.go @@ -0,0 +1,194 @@ +// automatically generated by stateify. + +package msgqueue + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (l *msgList) StateTypeName() string { + return "pkg/sentry/kernel/msgqueue.msgList" +} + +func (l *msgList) StateFields() []string { + return []string{ + "head", + "tail", + } +} + +func (l *msgList) beforeSave() {} + +// +checklocksignore +func (l *msgList) StateSave(stateSinkObject state.Sink) { + l.beforeSave() + stateSinkObject.Save(0, &l.head) + stateSinkObject.Save(1, &l.tail) +} + +func (l *msgList) afterLoad() {} + +// +checklocksignore +func (l *msgList) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &l.head) + stateSourceObject.Load(1, &l.tail) +} + +func (e *msgEntry) StateTypeName() string { + return "pkg/sentry/kernel/msgqueue.msgEntry" +} + +func (e *msgEntry) StateFields() []string { + return []string{ + "next", + "prev", + } +} + +func (e *msgEntry) beforeSave() {} + +// +checklocksignore +func (e *msgEntry) StateSave(stateSinkObject state.Sink) { + e.beforeSave() + stateSinkObject.Save(0, &e.next) + stateSinkObject.Save(1, &e.prev) +} + +func (e *msgEntry) afterLoad() {} + +// +checklocksignore +func (e *msgEntry) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &e.next) + stateSourceObject.Load(1, &e.prev) +} + +func (r *Registry) StateTypeName() string { + return "pkg/sentry/kernel/msgqueue.Registry" +} + +func (r *Registry) StateFields() []string { + return []string{ + "reg", + } +} + +func (r *Registry) beforeSave() {} + +// +checklocksignore +func (r *Registry) StateSave(stateSinkObject state.Sink) { + r.beforeSave() + stateSinkObject.Save(0, &r.reg) +} + +func (r *Registry) afterLoad() {} + +// +checklocksignore +func (r *Registry) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &r.reg) +} + +func (q *Queue) StateTypeName() string { + return "pkg/sentry/kernel/msgqueue.Queue" +} + +func (q *Queue) StateFields() []string { + return []string{ + "registry", + "dead", + "obj", + "senders", + "receivers", + "messages", + "sendTime", + "receiveTime", + "changeTime", + "byteCount", + "messageCount", + "maxBytes", + "sendPID", + "receivePID", + } +} + +func (q *Queue) beforeSave() {} + +// +checklocksignore +func (q *Queue) StateSave(stateSinkObject state.Sink) { + q.beforeSave() + stateSinkObject.Save(0, &q.registry) + stateSinkObject.Save(1, &q.dead) + stateSinkObject.Save(2, &q.obj) + stateSinkObject.Save(3, &q.senders) + stateSinkObject.Save(4, &q.receivers) + stateSinkObject.Save(5, &q.messages) + stateSinkObject.Save(6, &q.sendTime) + stateSinkObject.Save(7, &q.receiveTime) + stateSinkObject.Save(8, &q.changeTime) + stateSinkObject.Save(9, &q.byteCount) + stateSinkObject.Save(10, &q.messageCount) + stateSinkObject.Save(11, &q.maxBytes) + stateSinkObject.Save(12, &q.sendPID) + stateSinkObject.Save(13, &q.receivePID) +} + +func (q *Queue) afterLoad() {} + +// +checklocksignore +func (q *Queue) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &q.registry) + stateSourceObject.Load(1, &q.dead) + stateSourceObject.Load(2, &q.obj) + stateSourceObject.Load(3, &q.senders) + stateSourceObject.Load(4, &q.receivers) + stateSourceObject.Load(5, &q.messages) + stateSourceObject.Load(6, &q.sendTime) + stateSourceObject.Load(7, &q.receiveTime) + stateSourceObject.Load(8, &q.changeTime) + stateSourceObject.Load(9, &q.byteCount) + stateSourceObject.Load(10, &q.messageCount) + stateSourceObject.Load(11, &q.maxBytes) + stateSourceObject.Load(12, &q.sendPID) + stateSourceObject.Load(13, &q.receivePID) +} + +func (m *Message) StateTypeName() string { + return "pkg/sentry/kernel/msgqueue.Message" +} + +func (m *Message) StateFields() []string { + return []string{ + "msgEntry", + "Type", + "Text", + "Size", + } +} + +func (m *Message) beforeSave() {} + +// +checklocksignore +func (m *Message) StateSave(stateSinkObject state.Sink) { + m.beforeSave() + stateSinkObject.Save(0, &m.msgEntry) + stateSinkObject.Save(1, &m.Type) + stateSinkObject.Save(2, &m.Text) + stateSinkObject.Save(3, &m.Size) +} + +func (m *Message) afterLoad() {} + +// +checklocksignore +func (m *Message) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &m.msgEntry) + stateSourceObject.Load(1, &m.Type) + stateSourceObject.Load(2, &m.Text) + stateSourceObject.Load(3, &m.Size) +} + +func init() { + state.Register((*msgList)(nil)) + state.Register((*msgEntry)(nil)) + state.Register((*Registry)(nil)) + state.Register((*Queue)(nil)) + state.Register((*Message)(nil)) +} diff --git a/pkg/sentry/kernel/pending_signals_list.go b/pkg/sentry/kernel/pending_signals_list.go new file mode 100644 index 000000000..d5b483b3e --- /dev/null +++ b/pkg/sentry/kernel/pending_signals_list.go @@ -0,0 +1,221 @@ +package kernel + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type pendingSignalElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (pendingSignalElementMapper) linkerFor(elem *pendingSignal) *pendingSignal { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type pendingSignalList struct { + head *pendingSignal + tail *pendingSignal +} + +// Reset resets list l to the empty state. +func (l *pendingSignalList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +// +//go:nosplit +func (l *pendingSignalList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +// +//go:nosplit +func (l *pendingSignalList) Front() *pendingSignal { + return l.head +} + +// Back returns the last element of list l or nil. +// +//go:nosplit +func (l *pendingSignalList) Back() *pendingSignal { + return l.tail +} + +// Len returns the number of elements in the list. +// +// NOTE: This is an O(n) operation. +// +//go:nosplit +func (l *pendingSignalList) Len() (count int) { + for e := l.Front(); e != nil; e = (pendingSignalElementMapper{}.linkerFor(e)).Next() { + count++ + } + return count +} + +// PushFront inserts the element e at the front of list l. +// +//go:nosplit +func (l *pendingSignalList) PushFront(e *pendingSignal) { + linker := pendingSignalElementMapper{}.linkerFor(e) + linker.SetNext(l.head) + linker.SetPrev(nil) + if l.head != nil { + pendingSignalElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +// +//go:nosplit +func (l *pendingSignalList) PushBack(e *pendingSignal) { + linker := pendingSignalElementMapper{}.linkerFor(e) + linker.SetNext(nil) + linker.SetPrev(l.tail) + if l.tail != nil { + pendingSignalElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +// +//go:nosplit +func (l *pendingSignalList) PushBackList(m *pendingSignalList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + pendingSignalElementMapper{}.linkerFor(l.tail).SetNext(m.head) + pendingSignalElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +// +//go:nosplit +func (l *pendingSignalList) InsertAfter(b, e *pendingSignal) { + bLinker := pendingSignalElementMapper{}.linkerFor(b) + eLinker := pendingSignalElementMapper{}.linkerFor(e) + + a := bLinker.Next() + + eLinker.SetNext(a) + eLinker.SetPrev(b) + bLinker.SetNext(e) + + if a != nil { + pendingSignalElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +// +//go:nosplit +func (l *pendingSignalList) InsertBefore(a, e *pendingSignal) { + aLinker := pendingSignalElementMapper{}.linkerFor(a) + eLinker := pendingSignalElementMapper{}.linkerFor(e) + + b := aLinker.Prev() + eLinker.SetNext(a) + eLinker.SetPrev(b) + aLinker.SetPrev(e) + + if b != nil { + pendingSignalElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +// +//go:nosplit +func (l *pendingSignalList) Remove(e *pendingSignal) { + linker := pendingSignalElementMapper{}.linkerFor(e) + prev := linker.Prev() + next := linker.Next() + + if prev != nil { + pendingSignalElementMapper{}.linkerFor(prev).SetNext(next) + } else if l.head == e { + l.head = next + } + + if next != nil { + pendingSignalElementMapper{}.linkerFor(next).SetPrev(prev) + } else if l.tail == e { + l.tail = prev + } + + linker.SetNext(nil) + linker.SetPrev(nil) +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type pendingSignalEntry struct { + next *pendingSignal + prev *pendingSignal +} + +// Next returns the entry that follows e in the list. +// +//go:nosplit +func (e *pendingSignalEntry) Next() *pendingSignal { + return e.next +} + +// Prev returns the entry that precedes e in the list. +// +//go:nosplit +func (e *pendingSignalEntry) Prev() *pendingSignal { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +// +//go:nosplit +func (e *pendingSignalEntry) SetNext(elem *pendingSignal) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +// +//go:nosplit +func (e *pendingSignalEntry) SetPrev(elem *pendingSignal) { + e.prev = elem +} diff --git a/pkg/sentry/kernel/pipe/BUILD b/pkg/sentry/kernel/pipe/BUILD deleted file mode 100644 index 5b2bac783..000000000 --- a/pkg/sentry/kernel/pipe/BUILD +++ /dev/null @@ -1,56 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "pipe", - srcs = [ - "device.go", - "node.go", - "pipe.go", - "pipe_unsafe.go", - "pipe_util.go", - "reader.go", - "reader_writer.go", - "save_restore.go", - "vfs.go", - "writer.go", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/amutex", - "//pkg/context", - "//pkg/errors/linuxerr", - "//pkg/hostarch", - "//pkg/marshal/primitive", - "//pkg/safemem", - "//pkg/sentry/arch", - "//pkg/sentry/device", - "//pkg/sentry/fs", - "//pkg/sentry/fs/fsutil", - "//pkg/sentry/vfs", - "//pkg/sync", - "//pkg/usermem", - "//pkg/waiter", - "@org_golang_x_sys//unix:go_default_library", - ], -) - -go_test( - name = "pipe_test", - size = "small", - srcs = [ - "node_test.go", - "pipe_test.go", - ], - library = ":pipe", - deps = [ - "//pkg/context", - "//pkg/errors/linuxerr", - "//pkg/sentry/contexttest", - "//pkg/sentry/fs", - "//pkg/usermem", - "//pkg/waiter", - ], -) diff --git a/pkg/sentry/kernel/pipe/node_test.go b/pkg/sentry/kernel/pipe/node_test.go deleted file mode 100644 index 31bd7910a..000000000 --- a/pkg/sentry/kernel/pipe/node_test.go +++ /dev/null @@ -1,320 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package pipe - -import ( - "testing" - "time" - - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/errors/linuxerr" - "gvisor.dev/gvisor/pkg/sentry/contexttest" - "gvisor.dev/gvisor/pkg/sentry/fs" -) - -type sleeper struct { - context.Context - ch chan struct{} -} - -func newSleeperContext(t *testing.T) context.Context { - return &sleeper{ - Context: contexttest.Context(t), - ch: make(chan struct{}), - } -} - -func (s *sleeper) SleepStart() <-chan struct{} { - return s.ch -} - -func (s *sleeper) SleepFinish(bool) { -} - -func (s *sleeper) Cancel() { - s.ch <- struct{}{} -} - -func (s *sleeper) Interrupted() bool { - return len(s.ch) != 0 -} - -type openResult struct { - *fs.File - error -} - -var perms fs.FilePermissions = fs.FilePermissions{ - User: fs.PermMask{Read: true, Write: true}, -} - -func testOpenOrDie(ctx context.Context, t *testing.T, n fs.InodeOperations, flags fs.FileFlags, doneChan chan<- struct{}) (*fs.File, error) { - inode := fs.NewMockInode(ctx, fs.NewMockMountSource(nil), fs.StableAttr{Type: fs.Pipe}) - d := fs.NewDirent(ctx, inode, "pipe") - file, err := n.GetFile(ctx, d, flags) - if err != nil { - t.Errorf("open with flags %+v failed: %v", flags, err) - return nil, err - } - if doneChan != nil { - doneChan <- struct{}{} - } - return file, err -} - -func testOpen(ctx context.Context, t *testing.T, n fs.InodeOperations, flags fs.FileFlags, resChan chan<- openResult) (*fs.File, error) { - inode := fs.NewMockInode(ctx, fs.NewMockMountSource(nil), fs.StableAttr{Type: fs.Pipe}) - d := fs.NewDirent(ctx, inode, "pipe") - file, err := n.GetFile(ctx, d, flags) - if resChan != nil { - resChan <- openResult{file, err} - } - return file, err -} - -func newNamedPipe(t *testing.T) *Pipe { - return NewPipe(true, DefaultPipeSize) -} - -func newAnonPipe(t *testing.T) *Pipe { - return NewPipe(false, DefaultPipeSize) -} - -// assertRecvBlocks ensures that a recv attempt on c blocks for at least -// blockDuration. This is useful for checking that a goroutine that is supposed -// to be executing a blocking operation is actually blocking. -func assertRecvBlocks(t *testing.T, c <-chan struct{}, blockDuration time.Duration, failMsg string) { - select { - case <-c: - t.Fatalf(failMsg) - case <-time.After(blockDuration): - // Ok, blocked for the required duration. - } -} - -func TestReadOpenBlocksForWriteOpen(t *testing.T) { - ctx := newSleeperContext(t) - f := NewInodeOperations(ctx, perms, newNamedPipe(t)) - - rDone := make(chan struct{}) - go testOpenOrDie(ctx, t, f, fs.FileFlags{Read: true}, rDone) - - // Verify that the open for read is blocking. - assertRecvBlocks(t, rDone, time.Millisecond*100, - "open for read not blocking with no writers") - - wDone := make(chan struct{}) - go testOpenOrDie(ctx, t, f, fs.FileFlags{Write: true}, wDone) - - <-wDone - <-rDone -} - -func TestWriteOpenBlocksForReadOpen(t *testing.T) { - ctx := newSleeperContext(t) - f := NewInodeOperations(ctx, perms, newNamedPipe(t)) - - wDone := make(chan struct{}) - go testOpenOrDie(ctx, t, f, fs.FileFlags{Write: true}, wDone) - - // Verify that the open for write is blocking - assertRecvBlocks(t, wDone, time.Millisecond*100, - "open for write not blocking with no readers") - - rDone := make(chan struct{}) - go testOpenOrDie(ctx, t, f, fs.FileFlags{Read: true}, rDone) - - <-rDone - <-wDone -} - -func TestMultipleWriteOpenDoesntCountAsReadOpen(t *testing.T) { - ctx := newSleeperContext(t) - f := NewInodeOperations(ctx, perms, newNamedPipe(t)) - - rDone1 := make(chan struct{}) - rDone2 := make(chan struct{}) - go testOpenOrDie(ctx, t, f, fs.FileFlags{Read: true}, rDone1) - go testOpenOrDie(ctx, t, f, fs.FileFlags{Read: true}, rDone2) - - assertRecvBlocks(t, rDone1, time.Millisecond*100, - "open for read didn't block with no writers") - assertRecvBlocks(t, rDone2, time.Millisecond*100, - "open for read didn't block with no writers") - - wDone := make(chan struct{}) - go testOpenOrDie(ctx, t, f, fs.FileFlags{Write: true}, wDone) - - <-wDone - <-rDone2 - <-rDone1 -} - -func TestClosedReaderBlocksWriteOpen(t *testing.T) { - ctx := newSleeperContext(t) - f := NewInodeOperations(ctx, perms, newNamedPipe(t)) - - rFile, _ := testOpenOrDie(ctx, t, f, fs.FileFlags{Read: true, NonBlocking: true}, nil) - rFile.DecRef(ctx) - - wDone := make(chan struct{}) - // This open for write should block because the reader is now gone. - go testOpenOrDie(ctx, t, f, fs.FileFlags{Write: true}, wDone) - assertRecvBlocks(t, wDone, time.Millisecond*100, - "open for write didn't block with no concurrent readers") - - // Open for read again. This should unblock the open for write. - rDone := make(chan struct{}) - go testOpenOrDie(ctx, t, f, fs.FileFlags{Read: true}, rDone) - - <-rDone - <-wDone -} - -func TestReadWriteOpenNeverBlocks(t *testing.T) { - ctx := newSleeperContext(t) - f := NewInodeOperations(ctx, perms, newNamedPipe(t)) - - rwDone := make(chan struct{}) - // Open for read-write never wait for a reader or writer, even if the - // nonblocking flag is not set. - go testOpenOrDie(ctx, t, f, fs.FileFlags{Read: true, Write: true, NonBlocking: false}, rwDone) - <-rwDone -} - -func TestReadWriteOpenUnblocksReadOpen(t *testing.T) { - ctx := newSleeperContext(t) - f := NewInodeOperations(ctx, perms, newNamedPipe(t)) - - rDone := make(chan struct{}) - go testOpenOrDie(ctx, t, f, fs.FileFlags{Read: true}, rDone) - - rwDone := make(chan struct{}) - go testOpenOrDie(ctx, t, f, fs.FileFlags{Read: true, Write: true}, rwDone) - - <-rwDone - <-rDone -} - -func TestReadWriteOpenUnblocksWriteOpen(t *testing.T) { - ctx := newSleeperContext(t) - f := NewInodeOperations(ctx, perms, newNamedPipe(t)) - - wDone := make(chan struct{}) - go testOpenOrDie(ctx, t, f, fs.FileFlags{Write: true}, wDone) - - rwDone := make(chan struct{}) - go testOpenOrDie(ctx, t, f, fs.FileFlags{Read: true, Write: true}, rwDone) - - <-rwDone - <-wDone -} - -func TestBlockedOpenIsCancellable(t *testing.T) { - ctx := newSleeperContext(t) - f := NewInodeOperations(ctx, perms, newNamedPipe(t)) - - done := make(chan openResult) - go testOpen(ctx, t, f, fs.FileFlags{Read: true}, done) - select { - case <-done: - t.Fatalf("open for read didn't block with no writers") - case <-time.After(time.Millisecond * 100): - // Ok. - } - - ctx.(*sleeper).Cancel() - // If the cancel on the sleeper didn't work, the open for read would never - // return. - res := <-done - if res.error != linuxerr.ErrInterrupted { - t.Fatalf("Cancellation didn't cause GetFile to return fs.ErrInterrupted, got %v.", - res.error) - } -} - -func TestNonblockingReadOpenFileNoWriters(t *testing.T) { - ctx := newSleeperContext(t) - f := NewInodeOperations(ctx, perms, newNamedPipe(t)) - - if _, err := testOpen(ctx, t, f, fs.FileFlags{Read: true, NonBlocking: true}, nil); err != nil { - t.Fatalf("Nonblocking open for read failed with error %v.", err) - } -} - -func TestNonblockingWriteOpenFileNoReaders(t *testing.T) { - ctx := newSleeperContext(t) - f := NewInodeOperations(ctx, perms, newNamedPipe(t)) - - if _, err := testOpen(ctx, t, f, fs.FileFlags{Write: true, NonBlocking: true}, nil); !linuxerr.Equals(linuxerr.ENXIO, err) { - t.Fatalf("Nonblocking open for write failed unexpected error %v.", err) - } -} - -func TestNonBlockingReadOpenWithWriter(t *testing.T) { - ctx := newSleeperContext(t) - f := NewInodeOperations(ctx, perms, newNamedPipe(t)) - - wDone := make(chan struct{}) - go testOpenOrDie(ctx, t, f, fs.FileFlags{Write: true}, wDone) - - // Open for write blocks since there are no readers yet. - assertRecvBlocks(t, wDone, time.Millisecond*100, - "Open for write didn't block with no reader.") - - if _, err := testOpen(ctx, t, f, fs.FileFlags{Read: true, NonBlocking: true}, nil); err != nil { - t.Fatalf("Nonblocking open for read failed with error %v.", err) - } - - // Open for write should now be unblocked. - <-wDone -} - -func TestNonBlockingWriteOpenWithReader(t *testing.T) { - ctx := newSleeperContext(t) - f := NewInodeOperations(ctx, perms, newNamedPipe(t)) - - rDone := make(chan struct{}) - go testOpenOrDie(ctx, t, f, fs.FileFlags{Read: true}, rDone) - - // Open for write blocked, since no reader yet. - assertRecvBlocks(t, rDone, time.Millisecond*100, - "Open for reader didn't block with no writer.") - - if _, err := testOpen(ctx, t, f, fs.FileFlags{Write: true, NonBlocking: true}, nil); err != nil { - t.Fatalf("Nonblocking open for write failed with error %v.", err) - } - - // Open for write should now be unblocked. - <-rDone -} - -func TestAnonReadOpen(t *testing.T) { - ctx := newSleeperContext(t) - f := NewInodeOperations(ctx, perms, newAnonPipe(t)) - - if _, err := testOpen(ctx, t, f, fs.FileFlags{Read: true}, nil); err != nil { - t.Fatalf("open anon pipe for read failed: %v", err) - } -} - -func TestAnonWriteOpen(t *testing.T) { - ctx := newSleeperContext(t) - f := NewInodeOperations(ctx, perms, newAnonPipe(t)) - - if _, err := testOpen(ctx, t, f, fs.FileFlags{Write: true}, nil); err != nil { - t.Fatalf("open anon pipe for write failed: %v", err) - } -} diff --git a/pkg/sentry/kernel/pipe/pipe_state_autogen.go b/pkg/sentry/kernel/pipe/pipe_state_autogen.go new file mode 100644 index 000000000..57451e951 --- /dev/null +++ b/pkg/sentry/kernel/pipe/pipe_state_autogen.go @@ -0,0 +1,227 @@ +// automatically generated by stateify. + +package pipe + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (i *inodeOperations) StateTypeName() string { + return "pkg/sentry/kernel/pipe.inodeOperations" +} + +func (i *inodeOperations) StateFields() []string { + return []string{ + "InodeSimpleAttributes", + "p", + } +} + +func (i *inodeOperations) beforeSave() {} + +// +checklocksignore +func (i *inodeOperations) StateSave(stateSinkObject state.Sink) { + i.beforeSave() + stateSinkObject.Save(0, &i.InodeSimpleAttributes) + stateSinkObject.Save(1, &i.p) +} + +func (i *inodeOperations) afterLoad() {} + +// +checklocksignore +func (i *inodeOperations) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &i.InodeSimpleAttributes) + stateSourceObject.Load(1, &i.p) +} + +func (p *Pipe) StateTypeName() string { + return "pkg/sentry/kernel/pipe.Pipe" +} + +func (p *Pipe) StateFields() []string { + return []string{ + "isNamed", + "readers", + "writers", + "buf", + "off", + "size", + "max", + "hadWriter", + } +} + +func (p *Pipe) beforeSave() {} + +// +checklocksignore +func (p *Pipe) StateSave(stateSinkObject state.Sink) { + p.beforeSave() + stateSinkObject.Save(0, &p.isNamed) + stateSinkObject.Save(1, &p.readers) + stateSinkObject.Save(2, &p.writers) + stateSinkObject.Save(3, &p.buf) + stateSinkObject.Save(4, &p.off) + stateSinkObject.Save(5, &p.size) + stateSinkObject.Save(6, &p.max) + stateSinkObject.Save(7, &p.hadWriter) +} + +// +checklocksignore +func (p *Pipe) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &p.isNamed) + stateSourceObject.Load(1, &p.readers) + stateSourceObject.Load(2, &p.writers) + stateSourceObject.Load(3, &p.buf) + stateSourceObject.Load(4, &p.off) + stateSourceObject.Load(5, &p.size) + stateSourceObject.Load(6, &p.max) + stateSourceObject.Load(7, &p.hadWriter) + stateSourceObject.AfterLoad(p.afterLoad) +} + +func (r *Reader) StateTypeName() string { + return "pkg/sentry/kernel/pipe.Reader" +} + +func (r *Reader) StateFields() []string { + return []string{ + "ReaderWriter", + } +} + +func (r *Reader) beforeSave() {} + +// +checklocksignore +func (r *Reader) StateSave(stateSinkObject state.Sink) { + r.beforeSave() + stateSinkObject.Save(0, &r.ReaderWriter) +} + +func (r *Reader) afterLoad() {} + +// +checklocksignore +func (r *Reader) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &r.ReaderWriter) +} + +func (rw *ReaderWriter) StateTypeName() string { + return "pkg/sentry/kernel/pipe.ReaderWriter" +} + +func (rw *ReaderWriter) StateFields() []string { + return []string{ + "Pipe", + } +} + +func (rw *ReaderWriter) beforeSave() {} + +// +checklocksignore +func (rw *ReaderWriter) StateSave(stateSinkObject state.Sink) { + rw.beforeSave() + stateSinkObject.Save(0, &rw.Pipe) +} + +func (rw *ReaderWriter) afterLoad() {} + +// +checklocksignore +func (rw *ReaderWriter) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &rw.Pipe) +} + +func (vp *VFSPipe) StateTypeName() string { + return "pkg/sentry/kernel/pipe.VFSPipe" +} + +func (vp *VFSPipe) StateFields() []string { + return []string{ + "pipe", + } +} + +func (vp *VFSPipe) beforeSave() {} + +// +checklocksignore +func (vp *VFSPipe) StateSave(stateSinkObject state.Sink) { + vp.beforeSave() + stateSinkObject.Save(0, &vp.pipe) +} + +func (vp *VFSPipe) afterLoad() {} + +// +checklocksignore +func (vp *VFSPipe) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &vp.pipe) +} + +func (fd *VFSPipeFD) StateTypeName() string { + return "pkg/sentry/kernel/pipe.VFSPipeFD" +} + +func (fd *VFSPipeFD) StateFields() []string { + return []string{ + "vfsfd", + "FileDescriptionDefaultImpl", + "DentryMetadataFileDescriptionImpl", + "LockFD", + "pipe", + } +} + +func (fd *VFSPipeFD) beforeSave() {} + +// +checklocksignore +func (fd *VFSPipeFD) StateSave(stateSinkObject state.Sink) { + fd.beforeSave() + stateSinkObject.Save(0, &fd.vfsfd) + stateSinkObject.Save(1, &fd.FileDescriptionDefaultImpl) + stateSinkObject.Save(2, &fd.DentryMetadataFileDescriptionImpl) + stateSinkObject.Save(3, &fd.LockFD) + stateSinkObject.Save(4, &fd.pipe) +} + +func (fd *VFSPipeFD) afterLoad() {} + +// +checklocksignore +func (fd *VFSPipeFD) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &fd.vfsfd) + stateSourceObject.Load(1, &fd.FileDescriptionDefaultImpl) + stateSourceObject.Load(2, &fd.DentryMetadataFileDescriptionImpl) + stateSourceObject.Load(3, &fd.LockFD) + stateSourceObject.Load(4, &fd.pipe) +} + +func (w *Writer) StateTypeName() string { + return "pkg/sentry/kernel/pipe.Writer" +} + +func (w *Writer) StateFields() []string { + return []string{ + "ReaderWriter", + } +} + +func (w *Writer) beforeSave() {} + +// +checklocksignore +func (w *Writer) StateSave(stateSinkObject state.Sink) { + w.beforeSave() + stateSinkObject.Save(0, &w.ReaderWriter) +} + +func (w *Writer) afterLoad() {} + +// +checklocksignore +func (w *Writer) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &w.ReaderWriter) +} + +func init() { + state.Register((*inodeOperations)(nil)) + state.Register((*Pipe)(nil)) + state.Register((*Reader)(nil)) + state.Register((*ReaderWriter)(nil)) + state.Register((*VFSPipe)(nil)) + state.Register((*VFSPipeFD)(nil)) + state.Register((*Writer)(nil)) +} diff --git a/pkg/sentry/kernel/pipe/pipe_test.go b/pkg/sentry/kernel/pipe/pipe_test.go deleted file mode 100644 index aa3ab305d..000000000 --- a/pkg/sentry/kernel/pipe/pipe_test.go +++ /dev/null @@ -1,140 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package pipe - -import ( - "bytes" - "testing" - - "gvisor.dev/gvisor/pkg/errors/linuxerr" - "gvisor.dev/gvisor/pkg/sentry/contexttest" - "gvisor.dev/gvisor/pkg/usermem" - "gvisor.dev/gvisor/pkg/waiter" -) - -func TestPipeRW(t *testing.T) { - ctx := contexttest.Context(t) - r, w := NewConnectedPipe(ctx, 65536) - defer r.DecRef(ctx) - defer w.DecRef(ctx) - - msg := []byte("here's some bytes") - wantN := int64(len(msg)) - n, err := w.Writev(ctx, usermem.BytesIOSequence(msg)) - if n != wantN || err != nil { - t.Fatalf("Writev: got (%d, %v), wanted (%d, nil)", n, err, wantN) - } - - buf := make([]byte, len(msg)) - n, err = r.Readv(ctx, usermem.BytesIOSequence(buf)) - if n != wantN || err != nil || !bytes.Equal(buf, msg) { - t.Fatalf("Readv: got (%d, %v) %q, wanted (%d, nil) %q", n, err, buf, wantN, msg) - } -} - -func TestPipeReadBlock(t *testing.T) { - ctx := contexttest.Context(t) - r, w := NewConnectedPipe(ctx, 65536) - defer r.DecRef(ctx) - defer w.DecRef(ctx) - - n, err := r.Readv(ctx, usermem.BytesIOSequence(make([]byte, 1))) - if n != 0 || err != linuxerr.ErrWouldBlock { - t.Fatalf("Readv: got (%d, %v), wanted (0, %v)", n, err, linuxerr.ErrWouldBlock) - } -} - -func TestPipeWriteBlock(t *testing.T) { - const atomicIOBytes = 2 - const capacity = MinimumPipeSize - - ctx := contexttest.Context(t) - r, w := NewConnectedPipe(ctx, capacity) - defer r.DecRef(ctx) - defer w.DecRef(ctx) - - msg := make([]byte, capacity+1) - n, err := w.Writev(ctx, usermem.BytesIOSequence(msg)) - if wantN, wantErr := int64(capacity), linuxerr.ErrWouldBlock; n != wantN || err != wantErr { - t.Fatalf("Writev: got (%d, %v), wanted (%d, %v)", n, err, wantN, wantErr) - } -} - -func TestPipeWriteUntilEnd(t *testing.T) { - const atomicIOBytes = 2 - - ctx := contexttest.Context(t) - r, w := NewConnectedPipe(ctx, atomicIOBytes) - defer r.DecRef(ctx) - defer w.DecRef(ctx) - - msg := []byte("here's some bytes") - - wDone := make(chan struct{}, 0) - rDone := make(chan struct{}, 0) - defer func() { - // Signal the reader to stop and wait until it does so. - close(wDone) - <-rDone - }() - - go func() { - defer close(rDone) - // Read from r until done is closed. - ctx := contexttest.Context(t) - buf := make([]byte, len(msg)+1) - dst := usermem.BytesIOSequence(buf) - e, ch := waiter.NewChannelEntry(nil) - r.EventRegister(&e, waiter.ReadableEvents) - defer r.EventUnregister(&e) - for { - n, err := r.Readv(ctx, dst) - dst = dst.DropFirst64(n) - if err == linuxerr.ErrWouldBlock { - select { - case <-ch: - continue - case <-wDone: - // We expect to have 1 byte left in dst since len(buf) == - // len(msg)+1. - if dst.NumBytes() != 1 || !bytes.Equal(buf[:len(msg)], msg) { - t.Errorf("Reader: got %q (%d bytes remaining), wanted %q", buf, dst.NumBytes(), msg) - } - return - } - } - if err != nil { - t.Errorf("Readv: got unexpected error %v", err) - return - } - } - }() - - src := usermem.BytesIOSequence(msg) - e, ch := waiter.NewChannelEntry(nil) - w.EventRegister(&e, waiter.WritableEvents) - defer w.EventUnregister(&e) - for src.NumBytes() != 0 { - n, err := w.Writev(ctx, src) - src = src.DropFirst64(n) - if err == linuxerr.ErrWouldBlock { - <-ch - continue - } - if err != nil { - t.Fatalf("Writev: got (%d, %v)", n, err) - } - } -} diff --git a/pkg/sentry/kernel/pipe/pipe_unsafe_state_autogen.go b/pkg/sentry/kernel/pipe/pipe_unsafe_state_autogen.go new file mode 100644 index 000000000..d3b40feb4 --- /dev/null +++ b/pkg/sentry/kernel/pipe/pipe_unsafe_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package pipe diff --git a/pkg/sentry/kernel/process_group_list.go b/pkg/sentry/kernel/process_group_list.go new file mode 100644 index 000000000..9c493504d --- /dev/null +++ b/pkg/sentry/kernel/process_group_list.go @@ -0,0 +1,221 @@ +package kernel + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type processGroupElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (processGroupElementMapper) linkerFor(elem *ProcessGroup) *ProcessGroup { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type processGroupList struct { + head *ProcessGroup + tail *ProcessGroup +} + +// Reset resets list l to the empty state. +func (l *processGroupList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +// +//go:nosplit +func (l *processGroupList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +// +//go:nosplit +func (l *processGroupList) Front() *ProcessGroup { + return l.head +} + +// Back returns the last element of list l or nil. +// +//go:nosplit +func (l *processGroupList) Back() *ProcessGroup { + return l.tail +} + +// Len returns the number of elements in the list. +// +// NOTE: This is an O(n) operation. +// +//go:nosplit +func (l *processGroupList) Len() (count int) { + for e := l.Front(); e != nil; e = (processGroupElementMapper{}.linkerFor(e)).Next() { + count++ + } + return count +} + +// PushFront inserts the element e at the front of list l. +// +//go:nosplit +func (l *processGroupList) PushFront(e *ProcessGroup) { + linker := processGroupElementMapper{}.linkerFor(e) + linker.SetNext(l.head) + linker.SetPrev(nil) + if l.head != nil { + processGroupElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +// +//go:nosplit +func (l *processGroupList) PushBack(e *ProcessGroup) { + linker := processGroupElementMapper{}.linkerFor(e) + linker.SetNext(nil) + linker.SetPrev(l.tail) + if l.tail != nil { + processGroupElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +// +//go:nosplit +func (l *processGroupList) PushBackList(m *processGroupList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + processGroupElementMapper{}.linkerFor(l.tail).SetNext(m.head) + processGroupElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +// +//go:nosplit +func (l *processGroupList) InsertAfter(b, e *ProcessGroup) { + bLinker := processGroupElementMapper{}.linkerFor(b) + eLinker := processGroupElementMapper{}.linkerFor(e) + + a := bLinker.Next() + + eLinker.SetNext(a) + eLinker.SetPrev(b) + bLinker.SetNext(e) + + if a != nil { + processGroupElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +// +//go:nosplit +func (l *processGroupList) InsertBefore(a, e *ProcessGroup) { + aLinker := processGroupElementMapper{}.linkerFor(a) + eLinker := processGroupElementMapper{}.linkerFor(e) + + b := aLinker.Prev() + eLinker.SetNext(a) + eLinker.SetPrev(b) + aLinker.SetPrev(e) + + if b != nil { + processGroupElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +// +//go:nosplit +func (l *processGroupList) Remove(e *ProcessGroup) { + linker := processGroupElementMapper{}.linkerFor(e) + prev := linker.Prev() + next := linker.Next() + + if prev != nil { + processGroupElementMapper{}.linkerFor(prev).SetNext(next) + } else if l.head == e { + l.head = next + } + + if next != nil { + processGroupElementMapper{}.linkerFor(next).SetPrev(prev) + } else if l.tail == e { + l.tail = prev + } + + linker.SetNext(nil) + linker.SetPrev(nil) +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type processGroupEntry struct { + next *ProcessGroup + prev *ProcessGroup +} + +// Next returns the entry that follows e in the list. +// +//go:nosplit +func (e *processGroupEntry) Next() *ProcessGroup { + return e.next +} + +// Prev returns the entry that precedes e in the list. +// +//go:nosplit +func (e *processGroupEntry) Prev() *ProcessGroup { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +// +//go:nosplit +func (e *processGroupEntry) SetNext(elem *ProcessGroup) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +// +//go:nosplit +func (e *processGroupEntry) SetPrev(elem *ProcessGroup) { + e.prev = elem +} diff --git a/pkg/sentry/kernel/process_group_refs.go b/pkg/sentry/kernel/process_group_refs.go new file mode 100644 index 000000000..cfd73315f --- /dev/null +++ b/pkg/sentry/kernel/process_group_refs.go @@ -0,0 +1,140 @@ +package kernel + +import ( + "fmt" + "sync/atomic" + + "gvisor.dev/gvisor/pkg/refsvfs2" +) + +// enableLogging indicates whether reference-related events should be logged (with +// stack traces). This is false by default and should only be set to true for +// debugging purposes, as it can generate an extremely large amount of output +// and drastically degrade performance. +const ProcessGroupenableLogging = false + +// obj is used to customize logging. Note that we use a pointer to T so that +// we do not copy the entire object when passed as a format parameter. +var ProcessGroupobj *ProcessGroup + +// Refs implements refs.RefCounter. It keeps a reference count using atomic +// operations and calls the destructor when the count reaches zero. +// +// NOTE: Do not introduce additional fields to the Refs struct. It is used by +// many filesystem objects, and we want to keep it as small as possible (i.e., +// the same size as using an int64 directly) to avoid taking up extra cache +// space. In general, this template should not be extended at the cost of +// performance. If it does not offer enough flexibility for a particular object +// (example: b/187877947), we should implement the RefCounter/CheckedObject +// interfaces manually. +// +// +stateify savable +type ProcessGroupRefs struct { + // refCount is composed of two fields: + // + // [32-bit speculative references]:[32-bit real references] + // + // Speculative references are used for TryIncRef, to avoid a CompareAndSwap + // loop. See IncRef, DecRef and TryIncRef for details of how these fields are + // used. + refCount int64 +} + +// InitRefs initializes r with one reference and, if enabled, activates leak +// checking. +func (r *ProcessGroupRefs) InitRefs() { + atomic.StoreInt64(&r.refCount, 1) + refsvfs2.Register(r) +} + +// RefType implements refsvfs2.CheckedObject.RefType. +func (r *ProcessGroupRefs) RefType() string { + return fmt.Sprintf("%T", ProcessGroupobj)[1:] +} + +// LeakMessage implements refsvfs2.CheckedObject.LeakMessage. +func (r *ProcessGroupRefs) LeakMessage() string { + return fmt.Sprintf("[%s %p] reference count of %d instead of 0", r.RefType(), r, r.ReadRefs()) +} + +// LogRefs implements refsvfs2.CheckedObject.LogRefs. +func (r *ProcessGroupRefs) LogRefs() bool { + return ProcessGroupenableLogging +} + +// ReadRefs returns the current number of references. The returned count is +// inherently racy and is unsafe to use without external synchronization. +func (r *ProcessGroupRefs) ReadRefs() int64 { + return atomic.LoadInt64(&r.refCount) +} + +// IncRef implements refs.RefCounter.IncRef. +// +//go:nosplit +func (r *ProcessGroupRefs) IncRef() { + v := atomic.AddInt64(&r.refCount, 1) + if ProcessGroupenableLogging { + refsvfs2.LogIncRef(r, v) + } + if v <= 1 { + panic(fmt.Sprintf("Incrementing non-positive count %p on %s", r, r.RefType())) + } +} + +// TryIncRef implements refs.TryRefCounter.TryIncRef. +// +// To do this safely without a loop, a speculative reference is first acquired +// on the object. This allows multiple concurrent TryIncRef calls to distinguish +// other TryIncRef calls from genuine references held. +// +//go:nosplit +func (r *ProcessGroupRefs) TryIncRef() bool { + const speculativeRef = 1 << 32 + if v := atomic.AddInt64(&r.refCount, speculativeRef); int32(v) == 0 { + + atomic.AddInt64(&r.refCount, -speculativeRef) + return false + } + + v := atomic.AddInt64(&r.refCount, -speculativeRef+1) + if ProcessGroupenableLogging { + refsvfs2.LogTryIncRef(r, v) + } + return true +} + +// DecRef implements refs.RefCounter.DecRef. +// +// Note that speculative references are counted here. Since they were added +// prior to real references reaching zero, they will successfully convert to +// real references. In other words, we see speculative references only in the +// following case: +// +// A: TryIncRef [speculative increase => sees non-negative references] +// B: DecRef [real decrease] +// A: TryIncRef [transform speculative to real] +// +//go:nosplit +func (r *ProcessGroupRefs) DecRef(destroy func()) { + v := atomic.AddInt64(&r.refCount, -1) + if ProcessGroupenableLogging { + refsvfs2.LogDecRef(r, v) + } + switch { + case v < 0: + panic(fmt.Sprintf("Decrementing non-positive ref count %p, owned by %s", r, r.RefType())) + + case v == 0: + refsvfs2.Unregister(r) + + if destroy != nil { + destroy() + } + } +} + +func (r *ProcessGroupRefs) afterLoad() { + if r.ReadRefs() > 0 { + refsvfs2.Register(r) + } +} diff --git a/pkg/sentry/kernel/sched/BUILD b/pkg/sentry/kernel/sched/BUILD deleted file mode 100644 index 1b82e087b..000000000 --- a/pkg/sentry/kernel/sched/BUILD +++ /dev/null @@ -1,19 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "sched", - srcs = [ - "cpuset.go", - "sched.go", - ], - visibility = ["//pkg/sentry:internal"], -) - -go_test( - name = "sched_test", - size = "small", - srcs = ["cpuset_test.go"], - library = ":sched", -) diff --git a/pkg/sentry/kernel/sched/cpuset_test.go b/pkg/sentry/kernel/sched/cpuset_test.go deleted file mode 100644 index 3af9f1197..000000000 --- a/pkg/sentry/kernel/sched/cpuset_test.go +++ /dev/null @@ -1,44 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package sched - -import ( - "testing" -) - -func TestNumCPUs(t *testing.T) { - for i := uint(0); i < 1024; i++ { - c := NewCPUSet(i) - for j := uint(0); j < i; j++ { - c.Set(j) - } - n := c.NumCPUs() - if n != i { - t.Errorf("got wrong number of cpus %d, want %d", n, i) - } - } -} - -func TestClearAbove(t *testing.T) { - const n = 1024 - c := NewFullCPUSet(n) - for i := uint(0); i < n; i++ { - cpu := n - i - c.ClearAbove(cpu) - if got := c.NumCPUs(); got != cpu { - t.Errorf("iteration %d: got %d cpus, wanted %d", i, got, cpu) - } - } -} diff --git a/pkg/sentry/kernel/sched/sched_state_autogen.go b/pkg/sentry/kernel/sched/sched_state_autogen.go new file mode 100644 index 000000000..9705ca79d --- /dev/null +++ b/pkg/sentry/kernel/sched/sched_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package sched diff --git a/pkg/sentry/kernel/semaphore/BUILD b/pkg/sentry/kernel/semaphore/BUILD deleted file mode 100644 index 6aa74219e..000000000 --- a/pkg/sentry/kernel/semaphore/BUILD +++ /dev/null @@ -1,50 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -package(licenses = ["notice"]) - -go_template_instance( - name = "waiter_list", - out = "waiter_list.go", - package = "semaphore", - prefix = "waiter", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*waiter", - "Linker": "*waiter", - }, -) - -go_library( - name = "semaphore", - srcs = [ - "semaphore.go", - "waiter_list.go", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/errors/linuxerr", - "//pkg/sentry/fs", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/kernel/ipc", - "//pkg/sentry/kernel/time", - "//pkg/sync", - ], -) - -go_test( - name = "semaphore_test", - size = "small", - srcs = ["semaphore_test.go"], - library = ":semaphore", - deps = [ - "//pkg/abi/linux", # keep - "//pkg/context", # keep - "//pkg/errors/linuxerr", #keep - "//pkg/sentry/contexttest", # keep - "//pkg/sentry/kernel/auth", # keep - "//pkg/sentry/kernel/ipc", # keep - ], -) diff --git a/pkg/sentry/kernel/semaphore/semaphore_state_autogen.go b/pkg/sentry/kernel/semaphore/semaphore_state_autogen.go new file mode 100644 index 000000000..7ea96b30d --- /dev/null +++ b/pkg/sentry/kernel/semaphore/semaphore_state_autogen.go @@ -0,0 +1,202 @@ +// automatically generated by stateify. + +package semaphore + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (r *Registry) StateTypeName() string { + return "pkg/sentry/kernel/semaphore.Registry" +} + +func (r *Registry) StateFields() []string { + return []string{ + "reg", + "indexes", + } +} + +func (r *Registry) beforeSave() {} + +// +checklocksignore +func (r *Registry) StateSave(stateSinkObject state.Sink) { + r.beforeSave() + stateSinkObject.Save(0, &r.reg) + stateSinkObject.Save(1, &r.indexes) +} + +func (r *Registry) afterLoad() {} + +// +checklocksignore +func (r *Registry) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &r.reg) + stateSourceObject.Load(1, &r.indexes) +} + +func (s *Set) StateTypeName() string { + return "pkg/sentry/kernel/semaphore.Set" +} + +func (s *Set) StateFields() []string { + return []string{ + "registry", + "obj", + "opTime", + "changeTime", + "sems", + "dead", + } +} + +func (s *Set) beforeSave() {} + +// +checklocksignore +func (s *Set) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + stateSinkObject.Save(0, &s.registry) + stateSinkObject.Save(1, &s.obj) + stateSinkObject.Save(2, &s.opTime) + stateSinkObject.Save(3, &s.changeTime) + stateSinkObject.Save(4, &s.sems) + stateSinkObject.Save(5, &s.dead) +} + +func (s *Set) afterLoad() {} + +// +checklocksignore +func (s *Set) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.registry) + stateSourceObject.Load(1, &s.obj) + stateSourceObject.Load(2, &s.opTime) + stateSourceObject.Load(3, &s.changeTime) + stateSourceObject.Load(4, &s.sems) + stateSourceObject.Load(5, &s.dead) +} + +func (s *sem) StateTypeName() string { + return "pkg/sentry/kernel/semaphore.sem" +} + +func (s *sem) StateFields() []string { + return []string{ + "value", + "pid", + } +} + +func (s *sem) beforeSave() {} + +// +checklocksignore +func (s *sem) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + if !state.IsZeroValue(&s.waiters) { + state.Failf("waiters is %#v, expected zero", &s.waiters) + } + stateSinkObject.Save(0, &s.value) + stateSinkObject.Save(1, &s.pid) +} + +func (s *sem) afterLoad() {} + +// +checklocksignore +func (s *sem) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.value) + stateSourceObject.Load(1, &s.pid) +} + +func (w *waiter) StateTypeName() string { + return "pkg/sentry/kernel/semaphore.waiter" +} + +func (w *waiter) StateFields() []string { + return []string{ + "waiterEntry", + "value", + "ch", + } +} + +func (w *waiter) beforeSave() {} + +// +checklocksignore +func (w *waiter) StateSave(stateSinkObject state.Sink) { + w.beforeSave() + stateSinkObject.Save(0, &w.waiterEntry) + stateSinkObject.Save(1, &w.value) + stateSinkObject.Save(2, &w.ch) +} + +func (w *waiter) afterLoad() {} + +// +checklocksignore +func (w *waiter) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &w.waiterEntry) + stateSourceObject.Load(1, &w.value) + stateSourceObject.Load(2, &w.ch) +} + +func (l *waiterList) StateTypeName() string { + return "pkg/sentry/kernel/semaphore.waiterList" +} + +func (l *waiterList) StateFields() []string { + return []string{ + "head", + "tail", + } +} + +func (l *waiterList) beforeSave() {} + +// +checklocksignore +func (l *waiterList) StateSave(stateSinkObject state.Sink) { + l.beforeSave() + stateSinkObject.Save(0, &l.head) + stateSinkObject.Save(1, &l.tail) +} + +func (l *waiterList) afterLoad() {} + +// +checklocksignore +func (l *waiterList) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &l.head) + stateSourceObject.Load(1, &l.tail) +} + +func (e *waiterEntry) StateTypeName() string { + return "pkg/sentry/kernel/semaphore.waiterEntry" +} + +func (e *waiterEntry) StateFields() []string { + return []string{ + "next", + "prev", + } +} + +func (e *waiterEntry) beforeSave() {} + +// +checklocksignore +func (e *waiterEntry) StateSave(stateSinkObject state.Sink) { + e.beforeSave() + stateSinkObject.Save(0, &e.next) + stateSinkObject.Save(1, &e.prev) +} + +func (e *waiterEntry) afterLoad() {} + +// +checklocksignore +func (e *waiterEntry) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &e.next) + stateSourceObject.Load(1, &e.prev) +} + +func init() { + state.Register((*Registry)(nil)) + state.Register((*Set)(nil)) + state.Register((*sem)(nil)) + state.Register((*waiter)(nil)) + state.Register((*waiterList)(nil)) + state.Register((*waiterEntry)(nil)) +} diff --git a/pkg/sentry/kernel/semaphore/semaphore_test.go b/pkg/sentry/kernel/semaphore/semaphore_test.go deleted file mode 100644 index 59ac92ef1..000000000 --- a/pkg/sentry/kernel/semaphore/semaphore_test.go +++ /dev/null @@ -1,174 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package semaphore - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/errors/linuxerr" - "gvisor.dev/gvisor/pkg/sentry/contexttest" - "gvisor.dev/gvisor/pkg/sentry/kernel/auth" - "gvisor.dev/gvisor/pkg/sentry/kernel/ipc" -) - -func executeOps(ctx context.Context, t *testing.T, set *Set, ops []linux.Sembuf, block bool) chan struct{} { - ch, _, err := set.executeOps(ctx, ops, 123) - if err != nil { - t.Fatalf("ExecuteOps(ops) failed, err: %v, ops: %+v", err, ops) - } - if block { - if ch == nil { - t.Fatalf("ExecuteOps(ops) got: nil, expected: !nil, ops: %+v", ops) - } - if signalled(ch) { - t.Fatalf("ExecuteOps(ops) channel should not have been signalled, ops: %+v", ops) - } - } else { - if ch != nil { - t.Fatalf("ExecuteOps(ops) got: %v, expected: nil, ops: %+v", ch, ops) - } - } - return ch -} - -func signalled(ch chan struct{}) bool { - select { - case <-ch: - return true - default: - return false - } -} - -func TestBasic(t *testing.T) { - ctx := contexttest.Context(t) - set := &Set{obj: &ipc.Object{ID: 123}, sems: make([]sem, 1)} - ops := []linux.Sembuf{ - {SemOp: 1}, - } - executeOps(ctx, t, set, ops, false) - - ops[0].SemOp = -1 - executeOps(ctx, t, set, ops, false) - - ops[0].SemOp = -1 - ch1 := executeOps(ctx, t, set, ops, true) - - ops[0].SemOp = 1 - executeOps(ctx, t, set, ops, false) - if !signalled(ch1) { - t.Fatalf("ExecuteOps(ops) channel should not have been signalled, ops: %+v", ops) - } -} - -func TestWaitForZero(t *testing.T) { - ctx := contexttest.Context(t) - set := &Set{obj: &ipc.Object{ID: 123}, sems: make([]sem, 1)} - ops := []linux.Sembuf{ - {SemOp: 0}, - } - executeOps(ctx, t, set, ops, false) - - ops[0].SemOp = -2 - ch1 := executeOps(ctx, t, set, ops, true) - - ops[0].SemOp = 0 - executeOps(ctx, t, set, ops, false) - - ops[0].SemOp = 1 - executeOps(ctx, t, set, ops, false) - - ops[0].SemOp = 0 - chZero1 := executeOps(ctx, t, set, ops, true) - - ops[0].SemOp = 0 - chZero2 := executeOps(ctx, t, set, ops, true) - - ops[0].SemOp = 1 - executeOps(ctx, t, set, ops, false) - if !signalled(ch1) { - t.Fatalf("ExecuteOps(ops) channel should have been signalled, ops: %+v, set: %+v", ops, set) - } - - ops[0].SemOp = -2 - executeOps(ctx, t, set, ops, false) - if !signalled(chZero1) { - t.Fatalf("ExecuteOps(ops) channel zero 1 should have been signalled, ops: %+v, set: %+v", ops, set) - } - if !signalled(chZero2) { - t.Fatalf("ExecuteOps(ops) channel zero 2 should have been signalled, ops: %+v, set: %+v", ops, set) - } -} - -func TestNoWait(t *testing.T) { - ctx := contexttest.Context(t) - set := &Set{obj: &ipc.Object{ID: 123}, sems: make([]sem, 1)} - ops := []linux.Sembuf{ - {SemOp: 1}, - } - executeOps(ctx, t, set, ops, false) - - ops[0].SemOp = -2 - ops[0].SemFlg = linux.IPC_NOWAIT - if _, _, err := set.executeOps(ctx, ops, 123); err != linuxerr.ErrWouldBlock { - t.Fatalf("ExecuteOps(ops) wrong result, got: %v, expected: %v", err, linuxerr.ErrWouldBlock) - } - - ops[0].SemOp = 0 - ops[0].SemFlg = linux.IPC_NOWAIT - if _, _, err := set.executeOps(ctx, ops, 123); err != linuxerr.ErrWouldBlock { - t.Fatalf("ExecuteOps(ops) wrong result, got: %v, expected: %v", err, linuxerr.ErrWouldBlock) - } -} - -func TestUnregister(t *testing.T) { - ctx := contexttest.Context(t) - r := NewRegistry(auth.NewRootUserNamespace()) - set, err := r.FindOrCreate(ctx, 123, 2, linux.FileMode(0x600), true, true, true) - - if err != nil { - t.Fatalf("FindOrCreate() failed, err: %v", err) - } - if got := r.FindByID(set.obj.ID); got.obj.ID != set.obj.ID { - t.Fatalf("FindById(%d) failed, got: %+v, expected: %+v", set.obj.ID, got, set) - } - - ops := []linux.Sembuf{ - {SemOp: -1}, - } - chs := make([]chan struct{}, 0, 5) - for i := 0; i < 5; i++ { - ch := executeOps(ctx, t, set, ops, true) - chs = append(chs, ch) - } - - creds := auth.CredentialsFromContext(ctx) - if err := r.Remove(set.obj.ID, creds); err != nil { - t.Fatalf("Remove(%d) failed, err: %v", set.obj.ID, err) - } - if !set.dead { - t.Fatalf("set is not dead: %+v", set) - } - if got := r.FindByID(set.obj.ID); got != nil { - t.Fatalf("FindById(%d) failed, got: %+v, expected: nil", set.obj.ID, got) - } - for i, ch := range chs { - if !signalled(ch) { - t.Fatalf("channel %d should have been signalled", i) - } - } -} diff --git a/pkg/sentry/kernel/semaphore/waiter_list.go b/pkg/sentry/kernel/semaphore/waiter_list.go new file mode 100644 index 000000000..51586d43f --- /dev/null +++ b/pkg/sentry/kernel/semaphore/waiter_list.go @@ -0,0 +1,221 @@ +package semaphore + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type waiterElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (waiterElementMapper) linkerFor(elem *waiter) *waiter { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type waiterList struct { + head *waiter + tail *waiter +} + +// Reset resets list l to the empty state. +func (l *waiterList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +// +//go:nosplit +func (l *waiterList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +// +//go:nosplit +func (l *waiterList) Front() *waiter { + return l.head +} + +// Back returns the last element of list l or nil. +// +//go:nosplit +func (l *waiterList) Back() *waiter { + return l.tail +} + +// Len returns the number of elements in the list. +// +// NOTE: This is an O(n) operation. +// +//go:nosplit +func (l *waiterList) Len() (count int) { + for e := l.Front(); e != nil; e = (waiterElementMapper{}.linkerFor(e)).Next() { + count++ + } + return count +} + +// PushFront inserts the element e at the front of list l. +// +//go:nosplit +func (l *waiterList) PushFront(e *waiter) { + linker := waiterElementMapper{}.linkerFor(e) + linker.SetNext(l.head) + linker.SetPrev(nil) + if l.head != nil { + waiterElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +// +//go:nosplit +func (l *waiterList) PushBack(e *waiter) { + linker := waiterElementMapper{}.linkerFor(e) + linker.SetNext(nil) + linker.SetPrev(l.tail) + if l.tail != nil { + waiterElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +// +//go:nosplit +func (l *waiterList) PushBackList(m *waiterList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + waiterElementMapper{}.linkerFor(l.tail).SetNext(m.head) + waiterElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +// +//go:nosplit +func (l *waiterList) InsertAfter(b, e *waiter) { + bLinker := waiterElementMapper{}.linkerFor(b) + eLinker := waiterElementMapper{}.linkerFor(e) + + a := bLinker.Next() + + eLinker.SetNext(a) + eLinker.SetPrev(b) + bLinker.SetNext(e) + + if a != nil { + waiterElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +// +//go:nosplit +func (l *waiterList) InsertBefore(a, e *waiter) { + aLinker := waiterElementMapper{}.linkerFor(a) + eLinker := waiterElementMapper{}.linkerFor(e) + + b := aLinker.Prev() + eLinker.SetNext(a) + eLinker.SetPrev(b) + aLinker.SetPrev(e) + + if b != nil { + waiterElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +// +//go:nosplit +func (l *waiterList) Remove(e *waiter) { + linker := waiterElementMapper{}.linkerFor(e) + prev := linker.Prev() + next := linker.Next() + + if prev != nil { + waiterElementMapper{}.linkerFor(prev).SetNext(next) + } else if l.head == e { + l.head = next + } + + if next != nil { + waiterElementMapper{}.linkerFor(next).SetPrev(prev) + } else if l.tail == e { + l.tail = prev + } + + linker.SetNext(nil) + linker.SetPrev(nil) +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type waiterEntry struct { + next *waiter + prev *waiter +} + +// Next returns the entry that follows e in the list. +// +//go:nosplit +func (e *waiterEntry) Next() *waiter { + return e.next +} + +// Prev returns the entry that precedes e in the list. +// +//go:nosplit +func (e *waiterEntry) Prev() *waiter { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +// +//go:nosplit +func (e *waiterEntry) SetNext(elem *waiter) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +// +//go:nosplit +func (e *waiterEntry) SetPrev(elem *waiter) { + e.prev = elem +} diff --git a/pkg/sentry/kernel/seqatomic_taskgoroutineschedinfo_unsafe.go b/pkg/sentry/kernel/seqatomic_taskgoroutineschedinfo_unsafe.go new file mode 100644 index 000000000..d45cc0a0f --- /dev/null +++ b/pkg/sentry/kernel/seqatomic_taskgoroutineschedinfo_unsafe.go @@ -0,0 +1,38 @@ +package kernel + +import ( + "unsafe" + + "gvisor.dev/gvisor/pkg/gohacks" + "gvisor.dev/gvisor/pkg/sync" +) + +// SeqAtomicLoad returns a copy of *ptr, ensuring that the read does not race +// with any writer critical sections in seq. +// +//go:nosplit +func SeqAtomicLoadTaskGoroutineSchedInfo(seq *sync.SeqCount, ptr *TaskGoroutineSchedInfo) TaskGoroutineSchedInfo { + for { + if val, ok := SeqAtomicTryLoadTaskGoroutineSchedInfo(seq, seq.BeginRead(), ptr); ok { + return val + } + } +} + +// SeqAtomicTryLoad returns a copy of *ptr while in a reader critical section +// in seq initiated by a call to seq.BeginRead() that returned epoch. If the +// read would race with a writer critical section, SeqAtomicTryLoad returns +// (unspecified, false). +// +//go:nosplit +func SeqAtomicTryLoadTaskGoroutineSchedInfo(seq *sync.SeqCount, epoch sync.SeqCountEpoch, ptr *TaskGoroutineSchedInfo) (val TaskGoroutineSchedInfo, ok bool) { + if sync.RaceEnabled { + + gohacks.Memmove(unsafe.Pointer(&val), unsafe.Pointer(ptr), unsafe.Sizeof(val)) + } else { + + val = *ptr + } + ok = seq.ReadOk(epoch) + return +} diff --git a/pkg/sentry/kernel/session_list.go b/pkg/sentry/kernel/session_list.go new file mode 100644 index 000000000..f53dea401 --- /dev/null +++ b/pkg/sentry/kernel/session_list.go @@ -0,0 +1,221 @@ +package kernel + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type sessionElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (sessionElementMapper) linkerFor(elem *Session) *Session { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type sessionList struct { + head *Session + tail *Session +} + +// Reset resets list l to the empty state. +func (l *sessionList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +// +//go:nosplit +func (l *sessionList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +// +//go:nosplit +func (l *sessionList) Front() *Session { + return l.head +} + +// Back returns the last element of list l or nil. +// +//go:nosplit +func (l *sessionList) Back() *Session { + return l.tail +} + +// Len returns the number of elements in the list. +// +// NOTE: This is an O(n) operation. +// +//go:nosplit +func (l *sessionList) Len() (count int) { + for e := l.Front(); e != nil; e = (sessionElementMapper{}.linkerFor(e)).Next() { + count++ + } + return count +} + +// PushFront inserts the element e at the front of list l. +// +//go:nosplit +func (l *sessionList) PushFront(e *Session) { + linker := sessionElementMapper{}.linkerFor(e) + linker.SetNext(l.head) + linker.SetPrev(nil) + if l.head != nil { + sessionElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +// +//go:nosplit +func (l *sessionList) PushBack(e *Session) { + linker := sessionElementMapper{}.linkerFor(e) + linker.SetNext(nil) + linker.SetPrev(l.tail) + if l.tail != nil { + sessionElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +// +//go:nosplit +func (l *sessionList) PushBackList(m *sessionList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + sessionElementMapper{}.linkerFor(l.tail).SetNext(m.head) + sessionElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +// +//go:nosplit +func (l *sessionList) InsertAfter(b, e *Session) { + bLinker := sessionElementMapper{}.linkerFor(b) + eLinker := sessionElementMapper{}.linkerFor(e) + + a := bLinker.Next() + + eLinker.SetNext(a) + eLinker.SetPrev(b) + bLinker.SetNext(e) + + if a != nil { + sessionElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +// +//go:nosplit +func (l *sessionList) InsertBefore(a, e *Session) { + aLinker := sessionElementMapper{}.linkerFor(a) + eLinker := sessionElementMapper{}.linkerFor(e) + + b := aLinker.Prev() + eLinker.SetNext(a) + eLinker.SetPrev(b) + aLinker.SetPrev(e) + + if b != nil { + sessionElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +// +//go:nosplit +func (l *sessionList) Remove(e *Session) { + linker := sessionElementMapper{}.linkerFor(e) + prev := linker.Prev() + next := linker.Next() + + if prev != nil { + sessionElementMapper{}.linkerFor(prev).SetNext(next) + } else if l.head == e { + l.head = next + } + + if next != nil { + sessionElementMapper{}.linkerFor(next).SetPrev(prev) + } else if l.tail == e { + l.tail = prev + } + + linker.SetNext(nil) + linker.SetPrev(nil) +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type sessionEntry struct { + next *Session + prev *Session +} + +// Next returns the entry that follows e in the list. +// +//go:nosplit +func (e *sessionEntry) Next() *Session { + return e.next +} + +// Prev returns the entry that precedes e in the list. +// +//go:nosplit +func (e *sessionEntry) Prev() *Session { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +// +//go:nosplit +func (e *sessionEntry) SetNext(elem *Session) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +// +//go:nosplit +func (e *sessionEntry) SetPrev(elem *Session) { + e.prev = elem +} diff --git a/pkg/sentry/kernel/session_refs.go b/pkg/sentry/kernel/session_refs.go new file mode 100644 index 000000000..94d6380fa --- /dev/null +++ b/pkg/sentry/kernel/session_refs.go @@ -0,0 +1,140 @@ +package kernel + +import ( + "fmt" + "sync/atomic" + + "gvisor.dev/gvisor/pkg/refsvfs2" +) + +// enableLogging indicates whether reference-related events should be logged (with +// stack traces). This is false by default and should only be set to true for +// debugging purposes, as it can generate an extremely large amount of output +// and drastically degrade performance. +const SessionenableLogging = false + +// obj is used to customize logging. Note that we use a pointer to T so that +// we do not copy the entire object when passed as a format parameter. +var Sessionobj *Session + +// Refs implements refs.RefCounter. It keeps a reference count using atomic +// operations and calls the destructor when the count reaches zero. +// +// NOTE: Do not introduce additional fields to the Refs struct. It is used by +// many filesystem objects, and we want to keep it as small as possible (i.e., +// the same size as using an int64 directly) to avoid taking up extra cache +// space. In general, this template should not be extended at the cost of +// performance. If it does not offer enough flexibility for a particular object +// (example: b/187877947), we should implement the RefCounter/CheckedObject +// interfaces manually. +// +// +stateify savable +type SessionRefs struct { + // refCount is composed of two fields: + // + // [32-bit speculative references]:[32-bit real references] + // + // Speculative references are used for TryIncRef, to avoid a CompareAndSwap + // loop. See IncRef, DecRef and TryIncRef for details of how these fields are + // used. + refCount int64 +} + +// InitRefs initializes r with one reference and, if enabled, activates leak +// checking. +func (r *SessionRefs) InitRefs() { + atomic.StoreInt64(&r.refCount, 1) + refsvfs2.Register(r) +} + +// RefType implements refsvfs2.CheckedObject.RefType. +func (r *SessionRefs) RefType() string { + return fmt.Sprintf("%T", Sessionobj)[1:] +} + +// LeakMessage implements refsvfs2.CheckedObject.LeakMessage. +func (r *SessionRefs) LeakMessage() string { + return fmt.Sprintf("[%s %p] reference count of %d instead of 0", r.RefType(), r, r.ReadRefs()) +} + +// LogRefs implements refsvfs2.CheckedObject.LogRefs. +func (r *SessionRefs) LogRefs() bool { + return SessionenableLogging +} + +// ReadRefs returns the current number of references. The returned count is +// inherently racy and is unsafe to use without external synchronization. +func (r *SessionRefs) ReadRefs() int64 { + return atomic.LoadInt64(&r.refCount) +} + +// IncRef implements refs.RefCounter.IncRef. +// +//go:nosplit +func (r *SessionRefs) IncRef() { + v := atomic.AddInt64(&r.refCount, 1) + if SessionenableLogging { + refsvfs2.LogIncRef(r, v) + } + if v <= 1 { + panic(fmt.Sprintf("Incrementing non-positive count %p on %s", r, r.RefType())) + } +} + +// TryIncRef implements refs.TryRefCounter.TryIncRef. +// +// To do this safely without a loop, a speculative reference is first acquired +// on the object. This allows multiple concurrent TryIncRef calls to distinguish +// other TryIncRef calls from genuine references held. +// +//go:nosplit +func (r *SessionRefs) TryIncRef() bool { + const speculativeRef = 1 << 32 + if v := atomic.AddInt64(&r.refCount, speculativeRef); int32(v) == 0 { + + atomic.AddInt64(&r.refCount, -speculativeRef) + return false + } + + v := atomic.AddInt64(&r.refCount, -speculativeRef+1) + if SessionenableLogging { + refsvfs2.LogTryIncRef(r, v) + } + return true +} + +// DecRef implements refs.RefCounter.DecRef. +// +// Note that speculative references are counted here. Since they were added +// prior to real references reaching zero, they will successfully convert to +// real references. In other words, we see speculative references only in the +// following case: +// +// A: TryIncRef [speculative increase => sees non-negative references] +// B: DecRef [real decrease] +// A: TryIncRef [transform speculative to real] +// +//go:nosplit +func (r *SessionRefs) DecRef(destroy func()) { + v := atomic.AddInt64(&r.refCount, -1) + if SessionenableLogging { + refsvfs2.LogDecRef(r, v) + } + switch { + case v < 0: + panic(fmt.Sprintf("Decrementing non-positive ref count %p, owned by %s", r, r.RefType())) + + case v == 0: + refsvfs2.Unregister(r) + + if destroy != nil { + destroy() + } + } +} + +func (r *SessionRefs) afterLoad() { + if r.ReadRefs() > 0 { + refsvfs2.Register(r) + } +} diff --git a/pkg/sentry/kernel/shm/BUILD b/pkg/sentry/kernel/shm/BUILD deleted file mode 100644 index 2547957ba..000000000 --- a/pkg/sentry/kernel/shm/BUILD +++ /dev/null @@ -1,47 +0,0 @@ -load("//tools:defs.bzl", "go_library") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -package(licenses = ["notice"]) - -go_template_instance( - name = "shm_refs", - out = "shm_refs.go", - consts = { - "enableLogging": "true", - }, - package = "shm", - prefix = "Shm", - template = "//pkg/refsvfs2:refs_template", - types = { - "T": "Shm", - }, -) - -go_library( - name = "shm", - srcs = [ - "device.go", - "shm.go", - "shm_refs.go", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/errors/linuxerr", - "//pkg/hostarch", - "//pkg/log", - "//pkg/refs", - "//pkg/refsvfs2", - "//pkg/sentry/device", - "//pkg/sentry/fs", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/kernel/ipc", - "//pkg/sentry/kernel/time", - "//pkg/sentry/memmap", - "//pkg/sentry/pgalloc", - "//pkg/sentry/usage", - "//pkg/sync", - "//pkg/usermem", - ], -) diff --git a/pkg/sentry/kernel/shm/shm_refs.go b/pkg/sentry/kernel/shm/shm_refs.go new file mode 100644 index 000000000..e6eed69ef --- /dev/null +++ b/pkg/sentry/kernel/shm/shm_refs.go @@ -0,0 +1,140 @@ +package shm + +import ( + "fmt" + "sync/atomic" + + "gvisor.dev/gvisor/pkg/refsvfs2" +) + +// enableLogging indicates whether reference-related events should be logged (with +// stack traces). This is false by default and should only be set to true for +// debugging purposes, as it can generate an extremely large amount of output +// and drastically degrade performance. +const ShmenableLogging = true + +// obj is used to customize logging. Note that we use a pointer to T so that +// we do not copy the entire object when passed as a format parameter. +var Shmobj *Shm + +// Refs implements refs.RefCounter. It keeps a reference count using atomic +// operations and calls the destructor when the count reaches zero. +// +// NOTE: Do not introduce additional fields to the Refs struct. It is used by +// many filesystem objects, and we want to keep it as small as possible (i.e., +// the same size as using an int64 directly) to avoid taking up extra cache +// space. In general, this template should not be extended at the cost of +// performance. If it does not offer enough flexibility for a particular object +// (example: b/187877947), we should implement the RefCounter/CheckedObject +// interfaces manually. +// +// +stateify savable +type ShmRefs struct { + // refCount is composed of two fields: + // + // [32-bit speculative references]:[32-bit real references] + // + // Speculative references are used for TryIncRef, to avoid a CompareAndSwap + // loop. See IncRef, DecRef and TryIncRef for details of how these fields are + // used. + refCount int64 +} + +// InitRefs initializes r with one reference and, if enabled, activates leak +// checking. +func (r *ShmRefs) InitRefs() { + atomic.StoreInt64(&r.refCount, 1) + refsvfs2.Register(r) +} + +// RefType implements refsvfs2.CheckedObject.RefType. +func (r *ShmRefs) RefType() string { + return fmt.Sprintf("%T", Shmobj)[1:] +} + +// LeakMessage implements refsvfs2.CheckedObject.LeakMessage. +func (r *ShmRefs) LeakMessage() string { + return fmt.Sprintf("[%s %p] reference count of %d instead of 0", r.RefType(), r, r.ReadRefs()) +} + +// LogRefs implements refsvfs2.CheckedObject.LogRefs. +func (r *ShmRefs) LogRefs() bool { + return ShmenableLogging +} + +// ReadRefs returns the current number of references. The returned count is +// inherently racy and is unsafe to use without external synchronization. +func (r *ShmRefs) ReadRefs() int64 { + return atomic.LoadInt64(&r.refCount) +} + +// IncRef implements refs.RefCounter.IncRef. +// +//go:nosplit +func (r *ShmRefs) IncRef() { + v := atomic.AddInt64(&r.refCount, 1) + if ShmenableLogging { + refsvfs2.LogIncRef(r, v) + } + if v <= 1 { + panic(fmt.Sprintf("Incrementing non-positive count %p on %s", r, r.RefType())) + } +} + +// TryIncRef implements refs.TryRefCounter.TryIncRef. +// +// To do this safely without a loop, a speculative reference is first acquired +// on the object. This allows multiple concurrent TryIncRef calls to distinguish +// other TryIncRef calls from genuine references held. +// +//go:nosplit +func (r *ShmRefs) TryIncRef() bool { + const speculativeRef = 1 << 32 + if v := atomic.AddInt64(&r.refCount, speculativeRef); int32(v) == 0 { + + atomic.AddInt64(&r.refCount, -speculativeRef) + return false + } + + v := atomic.AddInt64(&r.refCount, -speculativeRef+1) + if ShmenableLogging { + refsvfs2.LogTryIncRef(r, v) + } + return true +} + +// DecRef implements refs.RefCounter.DecRef. +// +// Note that speculative references are counted here. Since they were added +// prior to real references reaching zero, they will successfully convert to +// real references. In other words, we see speculative references only in the +// following case: +// +// A: TryIncRef [speculative increase => sees non-negative references] +// B: DecRef [real decrease] +// A: TryIncRef [transform speculative to real] +// +//go:nosplit +func (r *ShmRefs) DecRef(destroy func()) { + v := atomic.AddInt64(&r.refCount, -1) + if ShmenableLogging { + refsvfs2.LogDecRef(r, v) + } + switch { + case v < 0: + panic(fmt.Sprintf("Decrementing non-positive ref count %p, owned by %s", r, r.RefType())) + + case v == 0: + refsvfs2.Unregister(r) + + if destroy != nil { + destroy() + } + } +} + +func (r *ShmRefs) afterLoad() { + if r.ReadRefs() > 0 { + refsvfs2.Register(r) + } +} diff --git a/pkg/sentry/kernel/shm/shm_state_autogen.go b/pkg/sentry/kernel/shm/shm_state_autogen.go new file mode 100644 index 000000000..9dde9122d --- /dev/null +++ b/pkg/sentry/kernel/shm/shm_state_autogen.go @@ -0,0 +1,129 @@ +// automatically generated by stateify. + +package shm + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (r *Registry) StateTypeName() string { + return "pkg/sentry/kernel/shm.Registry" +} + +func (r *Registry) StateFields() []string { + return []string{ + "userNS", + "reg", + "totalPages", + } +} + +func (r *Registry) beforeSave() {} + +// +checklocksignore +func (r *Registry) StateSave(stateSinkObject state.Sink) { + r.beforeSave() + stateSinkObject.Save(0, &r.userNS) + stateSinkObject.Save(1, &r.reg) + stateSinkObject.Save(2, &r.totalPages) +} + +func (r *Registry) afterLoad() {} + +// +checklocksignore +func (r *Registry) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &r.userNS) + stateSourceObject.Load(1, &r.reg) + stateSourceObject.Load(2, &r.totalPages) +} + +func (s *Shm) StateTypeName() string { + return "pkg/sentry/kernel/shm.Shm" +} + +func (s *Shm) StateFields() []string { + return []string{ + "ShmRefs", + "mfp", + "registry", + "size", + "effectiveSize", + "fr", + "obj", + "attachTime", + "detachTime", + "changeTime", + "creatorPID", + "lastAttachDetachPID", + "pendingDestruction", + } +} + +func (s *Shm) beforeSave() {} + +// +checklocksignore +func (s *Shm) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + stateSinkObject.Save(0, &s.ShmRefs) + stateSinkObject.Save(1, &s.mfp) + stateSinkObject.Save(2, &s.registry) + stateSinkObject.Save(3, &s.size) + stateSinkObject.Save(4, &s.effectiveSize) + stateSinkObject.Save(5, &s.fr) + stateSinkObject.Save(6, &s.obj) + stateSinkObject.Save(7, &s.attachTime) + stateSinkObject.Save(8, &s.detachTime) + stateSinkObject.Save(9, &s.changeTime) + stateSinkObject.Save(10, &s.creatorPID) + stateSinkObject.Save(11, &s.lastAttachDetachPID) + stateSinkObject.Save(12, &s.pendingDestruction) +} + +func (s *Shm) afterLoad() {} + +// +checklocksignore +func (s *Shm) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.ShmRefs) + stateSourceObject.Load(1, &s.mfp) + stateSourceObject.Load(2, &s.registry) + stateSourceObject.Load(3, &s.size) + stateSourceObject.Load(4, &s.effectiveSize) + stateSourceObject.Load(5, &s.fr) + stateSourceObject.Load(6, &s.obj) + stateSourceObject.Load(7, &s.attachTime) + stateSourceObject.Load(8, &s.detachTime) + stateSourceObject.Load(9, &s.changeTime) + stateSourceObject.Load(10, &s.creatorPID) + stateSourceObject.Load(11, &s.lastAttachDetachPID) + stateSourceObject.Load(12, &s.pendingDestruction) +} + +func (r *ShmRefs) StateTypeName() string { + return "pkg/sentry/kernel/shm.ShmRefs" +} + +func (r *ShmRefs) StateFields() []string { + return []string{ + "refCount", + } +} + +func (r *ShmRefs) beforeSave() {} + +// +checklocksignore +func (r *ShmRefs) StateSave(stateSinkObject state.Sink) { + r.beforeSave() + stateSinkObject.Save(0, &r.refCount) +} + +// +checklocksignore +func (r *ShmRefs) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &r.refCount) + stateSourceObject.AfterLoad(r.afterLoad) +} + +func init() { + state.Register((*Registry)(nil)) + state.Register((*Shm)(nil)) + state.Register((*ShmRefs)(nil)) +} diff --git a/pkg/sentry/kernel/signalfd/BUILD b/pkg/sentry/kernel/signalfd/BUILD deleted file mode 100644 index 4180ca28e..000000000 --- a/pkg/sentry/kernel/signalfd/BUILD +++ /dev/null @@ -1,21 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -licenses(["notice"]) - -go_library( - name = "signalfd", - srcs = ["signalfd.go"], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/errors/linuxerr", - "//pkg/sentry/fs", - "//pkg/sentry/fs/anon", - "//pkg/sentry/fs/fsutil", - "//pkg/sentry/kernel", - "//pkg/sync", - "//pkg/usermem", - "//pkg/waiter", - ], -) diff --git a/pkg/sentry/kernel/signalfd/signalfd_state_autogen.go b/pkg/sentry/kernel/signalfd/signalfd_state_autogen.go new file mode 100644 index 000000000..ad2622f27 --- /dev/null +++ b/pkg/sentry/kernel/signalfd/signalfd_state_autogen.go @@ -0,0 +1,39 @@ +// automatically generated by stateify. + +package signalfd + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (s *SignalOperations) StateTypeName() string { + return "pkg/sentry/kernel/signalfd.SignalOperations" +} + +func (s *SignalOperations) StateFields() []string { + return []string{ + "target", + "mask", + } +} + +func (s *SignalOperations) beforeSave() {} + +// +checklocksignore +func (s *SignalOperations) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + stateSinkObject.Save(0, &s.target) + stateSinkObject.Save(1, &s.mask) +} + +func (s *SignalOperations) afterLoad() {} + +// +checklocksignore +func (s *SignalOperations) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.target) + stateSourceObject.Load(1, &s.mask) +} + +func init() { + state.Register((*SignalOperations)(nil)) +} diff --git a/pkg/sentry/kernel/socket_list.go b/pkg/sentry/kernel/socket_list.go new file mode 100644 index 000000000..2b2f5055e --- /dev/null +++ b/pkg/sentry/kernel/socket_list.go @@ -0,0 +1,221 @@ +package kernel + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type socketElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (socketElementMapper) linkerFor(elem *SocketRecordVFS1) *SocketRecordVFS1 { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type socketList struct { + head *SocketRecordVFS1 + tail *SocketRecordVFS1 +} + +// Reset resets list l to the empty state. +func (l *socketList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +// +//go:nosplit +func (l *socketList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +// +//go:nosplit +func (l *socketList) Front() *SocketRecordVFS1 { + return l.head +} + +// Back returns the last element of list l or nil. +// +//go:nosplit +func (l *socketList) Back() *SocketRecordVFS1 { + return l.tail +} + +// Len returns the number of elements in the list. +// +// NOTE: This is an O(n) operation. +// +//go:nosplit +func (l *socketList) Len() (count int) { + for e := l.Front(); e != nil; e = (socketElementMapper{}.linkerFor(e)).Next() { + count++ + } + return count +} + +// PushFront inserts the element e at the front of list l. +// +//go:nosplit +func (l *socketList) PushFront(e *SocketRecordVFS1) { + linker := socketElementMapper{}.linkerFor(e) + linker.SetNext(l.head) + linker.SetPrev(nil) + if l.head != nil { + socketElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +// +//go:nosplit +func (l *socketList) PushBack(e *SocketRecordVFS1) { + linker := socketElementMapper{}.linkerFor(e) + linker.SetNext(nil) + linker.SetPrev(l.tail) + if l.tail != nil { + socketElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +// +//go:nosplit +func (l *socketList) PushBackList(m *socketList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + socketElementMapper{}.linkerFor(l.tail).SetNext(m.head) + socketElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +// +//go:nosplit +func (l *socketList) InsertAfter(b, e *SocketRecordVFS1) { + bLinker := socketElementMapper{}.linkerFor(b) + eLinker := socketElementMapper{}.linkerFor(e) + + a := bLinker.Next() + + eLinker.SetNext(a) + eLinker.SetPrev(b) + bLinker.SetNext(e) + + if a != nil { + socketElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +// +//go:nosplit +func (l *socketList) InsertBefore(a, e *SocketRecordVFS1) { + aLinker := socketElementMapper{}.linkerFor(a) + eLinker := socketElementMapper{}.linkerFor(e) + + b := aLinker.Prev() + eLinker.SetNext(a) + eLinker.SetPrev(b) + aLinker.SetPrev(e) + + if b != nil { + socketElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +// +//go:nosplit +func (l *socketList) Remove(e *SocketRecordVFS1) { + linker := socketElementMapper{}.linkerFor(e) + prev := linker.Prev() + next := linker.Next() + + if prev != nil { + socketElementMapper{}.linkerFor(prev).SetNext(next) + } else if l.head == e { + l.head = next + } + + if next != nil { + socketElementMapper{}.linkerFor(next).SetPrev(prev) + } else if l.tail == e { + l.tail = prev + } + + linker.SetNext(nil) + linker.SetPrev(nil) +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type socketEntry struct { + next *SocketRecordVFS1 + prev *SocketRecordVFS1 +} + +// Next returns the entry that follows e in the list. +// +//go:nosplit +func (e *socketEntry) Next() *SocketRecordVFS1 { + return e.next +} + +// Prev returns the entry that precedes e in the list. +// +//go:nosplit +func (e *socketEntry) Prev() *SocketRecordVFS1 { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +// +//go:nosplit +func (e *socketEntry) SetNext(elem *SocketRecordVFS1) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +// +//go:nosplit +func (e *socketEntry) SetPrev(elem *SocketRecordVFS1) { + e.prev = elem +} diff --git a/pkg/sentry/kernel/table_test.go b/pkg/sentry/kernel/table_test.go deleted file mode 100644 index 32cf47e05..000000000 --- a/pkg/sentry/kernel/table_test.go +++ /dev/null @@ -1,110 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package kernel - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/abi" - "gvisor.dev/gvisor/pkg/sentry/arch" -) - -const ( - maxTestSyscall = 1000 -) - -func createSyscallTable() *SyscallTable { - m := make(map[uintptr]Syscall) - for i := uintptr(0); i <= maxTestSyscall; i++ { - j := i - m[i] = Syscall{ - Fn: func(*Task, arch.SyscallArguments) (uintptr, *SyscallControl, error) { - return j, nil, nil - }, - } - } - - s := &SyscallTable{ - OS: abi.Linux, - Arch: arch.AMD64, - Table: m, - } - - RegisterSyscallTable(s) - return s -} - -func TestTable(t *testing.T) { - table := createSyscallTable() - defer func() { - // Cleanup registered tables to keep tests separate. - allSyscallTables = []*SyscallTable{} - }() - - // Go through all functions and check that they return the right value. - for i := uintptr(0); i < maxTestSyscall; i++ { - fn := table.Lookup(i) - if fn == nil { - t.Errorf("Syscall %v is set to nil", i) - continue - } - - v, _, _ := fn(nil, arch.SyscallArguments{}) - if v != i { - t.Errorf("Wrong return value for syscall %v: expected %v, got %v", i, i, v) - } - } - - // Check that values outside the range return nil. - for i := uintptr(maxTestSyscall + 1); i < maxTestSyscall+100; i++ { - fn := table.Lookup(i) - if fn != nil { - t.Errorf("Syscall %v is not nil: %v", i, fn) - continue - } - } -} - -func BenchmarkTableLookup(b *testing.B) { - table := createSyscallTable() - - b.ResetTimer() - - j := uintptr(0) - for i := 0; i < b.N; i++ { - table.Lookup(j) - j = (j + 1) % 310 - } - - b.StopTimer() - // Cleanup registered tables to keep tests separate. - allSyscallTables = []*SyscallTable{} -} - -func BenchmarkTableMapLookup(b *testing.B) { - table := createSyscallTable() - - b.ResetTimer() - - j := uintptr(0) - for i := 0; i < b.N; i++ { - table.mapLookup(j) - j = (j + 1) % 310 - } - - b.StopTimer() - // Cleanup registered tables to keep tests separate. - allSyscallTables = []*SyscallTable{} -} diff --git a/pkg/sentry/kernel/task_list.go b/pkg/sentry/kernel/task_list.go new file mode 100644 index 000000000..66df9ac45 --- /dev/null +++ b/pkg/sentry/kernel/task_list.go @@ -0,0 +1,221 @@ +package kernel + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type taskElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (taskElementMapper) linkerFor(elem *Task) *Task { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type taskList struct { + head *Task + tail *Task +} + +// Reset resets list l to the empty state. +func (l *taskList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +// +//go:nosplit +func (l *taskList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +// +//go:nosplit +func (l *taskList) Front() *Task { + return l.head +} + +// Back returns the last element of list l or nil. +// +//go:nosplit +func (l *taskList) Back() *Task { + return l.tail +} + +// Len returns the number of elements in the list. +// +// NOTE: This is an O(n) operation. +// +//go:nosplit +func (l *taskList) Len() (count int) { + for e := l.Front(); e != nil; e = (taskElementMapper{}.linkerFor(e)).Next() { + count++ + } + return count +} + +// PushFront inserts the element e at the front of list l. +// +//go:nosplit +func (l *taskList) PushFront(e *Task) { + linker := taskElementMapper{}.linkerFor(e) + linker.SetNext(l.head) + linker.SetPrev(nil) + if l.head != nil { + taskElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +// +//go:nosplit +func (l *taskList) PushBack(e *Task) { + linker := taskElementMapper{}.linkerFor(e) + linker.SetNext(nil) + linker.SetPrev(l.tail) + if l.tail != nil { + taskElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +// +//go:nosplit +func (l *taskList) PushBackList(m *taskList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + taskElementMapper{}.linkerFor(l.tail).SetNext(m.head) + taskElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +// +//go:nosplit +func (l *taskList) InsertAfter(b, e *Task) { + bLinker := taskElementMapper{}.linkerFor(b) + eLinker := taskElementMapper{}.linkerFor(e) + + a := bLinker.Next() + + eLinker.SetNext(a) + eLinker.SetPrev(b) + bLinker.SetNext(e) + + if a != nil { + taskElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +// +//go:nosplit +func (l *taskList) InsertBefore(a, e *Task) { + aLinker := taskElementMapper{}.linkerFor(a) + eLinker := taskElementMapper{}.linkerFor(e) + + b := aLinker.Prev() + eLinker.SetNext(a) + eLinker.SetPrev(b) + aLinker.SetPrev(e) + + if b != nil { + taskElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +// +//go:nosplit +func (l *taskList) Remove(e *Task) { + linker := taskElementMapper{}.linkerFor(e) + prev := linker.Prev() + next := linker.Next() + + if prev != nil { + taskElementMapper{}.linkerFor(prev).SetNext(next) + } else if l.head == e { + l.head = next + } + + if next != nil { + taskElementMapper{}.linkerFor(next).SetPrev(prev) + } else if l.tail == e { + l.tail = prev + } + + linker.SetNext(nil) + linker.SetPrev(nil) +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type taskEntry struct { + next *Task + prev *Task +} + +// Next returns the entry that follows e in the list. +// +//go:nosplit +func (e *taskEntry) Next() *Task { + return e.next +} + +// Prev returns the entry that precedes e in the list. +// +//go:nosplit +func (e *taskEntry) Prev() *Task { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +// +//go:nosplit +func (e *taskEntry) SetNext(elem *Task) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +// +//go:nosplit +func (e *taskEntry) SetPrev(elem *Task) { + e.prev = elem +} diff --git a/pkg/sentry/kernel/task_test.go b/pkg/sentry/kernel/task_test.go deleted file mode 100644 index cfcde9a7a..000000000 --- a/pkg/sentry/kernel/task_test.go +++ /dev/null @@ -1,69 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package kernel - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/sentry/kernel/sched" -) - -func TestTaskCPU(t *testing.T) { - for _, test := range []struct { - mask sched.CPUSet - tid ThreadID - cpu int32 - }{ - { - mask: []byte{0xff}, - tid: 1, - cpu: 0, - }, - { - mask: []byte{0xff}, - tid: 10, - cpu: 1, - }, - { - // more than 8 cpus. - mask: []byte{0xff, 0xff}, - tid: 10, - cpu: 9, - }, - { - // missing the first cpu. - mask: []byte{0xfe}, - tid: 1, - cpu: 1, - }, - { - mask: []byte{0xfe}, - tid: 10, - cpu: 3, - }, - { - // missing the fifth cpu. - mask: []byte{0xef}, - tid: 10, - cpu: 2, - }, - } { - assigned := assignCPU(test.mask, test.tid) - if test.cpu != assigned { - t.Errorf("assignCPU(%v, %v) got %v, want %v", test.mask, test.tid, assigned, test.cpu) - } - } - -} diff --git a/pkg/sentry/kernel/time/BUILD b/pkg/sentry/kernel/time/BUILD deleted file mode 100644 index e293d9a0f..000000000 --- a/pkg/sentry/kernel/time/BUILD +++ /dev/null @@ -1,20 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "time", - srcs = [ - "context.go", - "tcpip.go", - "time.go", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/errors/linuxerr", - "//pkg/sync", - "//pkg/waiter", - ], -) diff --git a/pkg/sentry/kernel/time/time_state_autogen.go b/pkg/sentry/kernel/time/time_state_autogen.go new file mode 100644 index 000000000..855751953 --- /dev/null +++ b/pkg/sentry/kernel/time/time_state_autogen.go @@ -0,0 +1,103 @@ +// automatically generated by stateify. + +package time + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (t *Time) StateTypeName() string { + return "pkg/sentry/kernel/time.Time" +} + +func (t *Time) StateFields() []string { + return []string{ + "ns", + } +} + +func (t *Time) beforeSave() {} + +// +checklocksignore +func (t *Time) StateSave(stateSinkObject state.Sink) { + t.beforeSave() + stateSinkObject.Save(0, &t.ns) +} + +func (t *Time) afterLoad() {} + +// +checklocksignore +func (t *Time) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &t.ns) +} + +func (s *Setting) StateTypeName() string { + return "pkg/sentry/kernel/time.Setting" +} + +func (s *Setting) StateFields() []string { + return []string{ + "Enabled", + "Next", + "Period", + } +} + +func (s *Setting) beforeSave() {} + +// +checklocksignore +func (s *Setting) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + stateSinkObject.Save(0, &s.Enabled) + stateSinkObject.Save(1, &s.Next) + stateSinkObject.Save(2, &s.Period) +} + +func (s *Setting) afterLoad() {} + +// +checklocksignore +func (s *Setting) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.Enabled) + stateSourceObject.Load(1, &s.Next) + stateSourceObject.Load(2, &s.Period) +} + +func (t *Timer) StateTypeName() string { + return "pkg/sentry/kernel/time.Timer" +} + +func (t *Timer) StateFields() []string { + return []string{ + "clock", + "listener", + "setting", + "paused", + } +} + +func (t *Timer) beforeSave() {} + +// +checklocksignore +func (t *Timer) StateSave(stateSinkObject state.Sink) { + t.beforeSave() + stateSinkObject.Save(0, &t.clock) + stateSinkObject.Save(1, &t.listener) + stateSinkObject.Save(2, &t.setting) + stateSinkObject.Save(3, &t.paused) +} + +func (t *Timer) afterLoad() {} + +// +checklocksignore +func (t *Timer) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &t.clock) + stateSourceObject.Load(1, &t.listener) + stateSourceObject.Load(2, &t.setting) + stateSourceObject.Load(3, &t.paused) +} + +func init() { + state.Register((*Time)(nil)) + state.Register((*Setting)(nil)) + state.Register((*Timer)(nil)) +} diff --git a/pkg/sentry/kernel/timekeeper_test.go b/pkg/sentry/kernel/timekeeper_test.go deleted file mode 100644 index b6039505a..000000000 --- a/pkg/sentry/kernel/timekeeper_test.go +++ /dev/null @@ -1,156 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package kernel - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/errors/linuxerr" - "gvisor.dev/gvisor/pkg/hostarch" - "gvisor.dev/gvisor/pkg/sentry/contexttest" - "gvisor.dev/gvisor/pkg/sentry/pgalloc" - sentrytime "gvisor.dev/gvisor/pkg/sentry/time" - "gvisor.dev/gvisor/pkg/sentry/usage" -) - -// mockClocks is a sentrytime.Clocks that simply returns the times in the -// struct. -type mockClocks struct { - monotonic int64 - realtime int64 -} - -// Update implements sentrytime.Clocks.Update. It does nothing. -func (*mockClocks) Update() (monotonicParams sentrytime.Parameters, monotonicOk bool, realtimeParam sentrytime.Parameters, realtimeOk bool) { - return -} - -// Update implements sentrytime.Clocks.GetTime. -func (c *mockClocks) GetTime(id sentrytime.ClockID) (int64, error) { - switch id { - case sentrytime.Monotonic: - return c.monotonic, nil - case sentrytime.Realtime: - return c.realtime, nil - default: - return 0, linuxerr.EINVAL - } -} - -// stateTestClocklessTimekeeper returns a test Timekeeper which has not had -// SetClocks called. -func stateTestClocklessTimekeeper(tb testing.TB) *Timekeeper { - ctx := contexttest.Context(tb) - mfp := pgalloc.MemoryFileProviderFromContext(ctx) - fr, err := mfp.MemoryFile().Allocate(hostarch.PageSize, usage.Anonymous) - if err != nil { - tb.Fatalf("failed to allocate memory: %v", err) - } - return &Timekeeper{ - params: NewVDSOParamPage(mfp, fr), - } -} - -func stateTestTimekeeper(tb testing.TB) *Timekeeper { - t := stateTestClocklessTimekeeper(tb) - t.SetClocks(sentrytime.NewCalibratedClocks()) - return t -} - -// TestTimekeeperMonotonicZero tests that monotonic time starts at zero. -func TestTimekeeperMonotonicZero(t *testing.T) { - c := &mockClocks{ - monotonic: 100000, - } - - tk := stateTestClocklessTimekeeper(t) - tk.SetClocks(c) - defer tk.Destroy() - - now, err := tk.GetTime(sentrytime.Monotonic) - if err != nil { - t.Errorf("GetTime err got %v want nil", err) - } - if now != 0 { - t.Errorf("GetTime got %d want 0", now) - } - - c.monotonic += 10 - - now, err = tk.GetTime(sentrytime.Monotonic) - if err != nil { - t.Errorf("GetTime err got %v want nil", err) - } - if now != 10 { - t.Errorf("GetTime got %d want 10", now) - } -} - -// TestTimekeeperMonotonicJumpForward tests that monotonic time jumps forward -// after restore. -func TestTimekeeperMonotonicForward(t *testing.T) { - c := &mockClocks{ - monotonic: 900000, - realtime: 600000, - } - - tk := stateTestClocklessTimekeeper(t) - tk.restored = make(chan struct{}) - tk.saveMonotonic = 100000 - tk.saveRealtime = 400000 - tk.SetClocks(c) - defer tk.Destroy() - - // The monotonic clock should jump ahead by 200000 to 300000. - // - // The new system monotonic time (900000) is irrelevant to what the app - // sees. - now, err := tk.GetTime(sentrytime.Monotonic) - if err != nil { - t.Errorf("GetTime err got %v want nil", err) - } - if now != 300000 { - t.Errorf("GetTime got %d want 300000", now) - } -} - -// TestTimekeeperMonotonicJumpBackwards tests that monotonic time does not jump -// backwards when realtime goes backwards. -func TestTimekeeperMonotonicJumpBackwards(t *testing.T) { - c := &mockClocks{ - monotonic: 900000, - realtime: 400000, - } - - tk := stateTestClocklessTimekeeper(t) - tk.restored = make(chan struct{}) - tk.saveMonotonic = 100000 - tk.saveRealtime = 600000 - tk.SetClocks(c) - defer tk.Destroy() - - // The monotonic clock should remain at 100000. - // - // The new system monotonic time (900000) is irrelevant to what the app - // sees and we don't want to jump the monotonic clock backwards like - // realtime did. - now, err := tk.GetTime(sentrytime.Monotonic) - if err != nil { - t.Errorf("GetTime err got %v want nil", err) - } - if now != 100000 { - t.Errorf("GetTime got %d want 100000", now) - } -} diff --git a/pkg/sentry/kernel/uncaught_signal.proto b/pkg/sentry/kernel/uncaught_signal.proto deleted file mode 100644 index 0bdb062cb..000000000 --- a/pkg/sentry/kernel/uncaught_signal.proto +++ /dev/null @@ -1,37 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -syntax = "proto3"; - -package gvisor; - -import "pkg/sentry/arch/registers.proto"; - -message UncaughtSignal { - // Thread ID. - int32 tid = 1; - - // Process ID. - int32 pid = 2; - - // Registers at the time of the fault or signal. - Registers registers = 3; - - // Signal number. - int32 signal_number = 4; - - // The memory location which caused the fault (set if applicable, 0 - // otherwise). This will be set for SIGILL, SIGFPE, SIGSEGV, and SIGBUS. - uint64 fault_addr = 5; -} diff --git a/pkg/sentry/kernel/uncaught_signal_go_proto/uncaught_signal.pb.go b/pkg/sentry/kernel/uncaught_signal_go_proto/uncaught_signal.pb.go new file mode 100644 index 000000000..54955f061 --- /dev/null +++ b/pkg/sentry/kernel/uncaught_signal_go_proto/uncaught_signal.pb.go @@ -0,0 +1,193 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.25.0 +// protoc v3.13.0 +// source: pkg/sentry/kernel/uncaught_signal.proto + +package gvisor + +import ( + proto "github.com/golang/protobuf/proto" + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + registers_go_proto "gvisor.dev/gvisor/pkg/sentry/arch/registers_go_proto" + reflect "reflect" + sync "sync" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +// This is a compile-time assertion that a sufficiently up-to-date version +// of the legacy proto package is being used. +const _ = proto.ProtoPackageIsVersion4 + +type UncaughtSignal struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Tid int32 `protobuf:"varint,1,opt,name=tid,proto3" json:"tid,omitempty"` + Pid int32 `protobuf:"varint,2,opt,name=pid,proto3" json:"pid,omitempty"` + Registers *registers_go_proto.Registers `protobuf:"bytes,3,opt,name=registers,proto3" json:"registers,omitempty"` + SignalNumber int32 `protobuf:"varint,4,opt,name=signal_number,json=signalNumber,proto3" json:"signal_number,omitempty"` + FaultAddr uint64 `protobuf:"varint,5,opt,name=fault_addr,json=faultAddr,proto3" json:"fault_addr,omitempty"` +} + +func (x *UncaughtSignal) Reset() { + *x = UncaughtSignal{} + if protoimpl.UnsafeEnabled { + mi := &file_pkg_sentry_kernel_uncaught_signal_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *UncaughtSignal) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*UncaughtSignal) ProtoMessage() {} + +func (x *UncaughtSignal) ProtoReflect() protoreflect.Message { + mi := &file_pkg_sentry_kernel_uncaught_signal_proto_msgTypes[0] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use UncaughtSignal.ProtoReflect.Descriptor instead. +func (*UncaughtSignal) Descriptor() ([]byte, []int) { + return file_pkg_sentry_kernel_uncaught_signal_proto_rawDescGZIP(), []int{0} +} + +func (x *UncaughtSignal) GetTid() int32 { + if x != nil { + return x.Tid + } + return 0 +} + +func (x *UncaughtSignal) GetPid() int32 { + if x != nil { + return x.Pid + } + return 0 +} + +func (x *UncaughtSignal) GetRegisters() *registers_go_proto.Registers { + if x != nil { + return x.Registers + } + return nil +} + +func (x *UncaughtSignal) GetSignalNumber() int32 { + if x != nil { + return x.SignalNumber + } + return 0 +} + +func (x *UncaughtSignal) GetFaultAddr() uint64 { + if x != nil { + return x.FaultAddr + } + return 0 +} + +var File_pkg_sentry_kernel_uncaught_signal_proto protoreflect.FileDescriptor + +var file_pkg_sentry_kernel_uncaught_signal_proto_rawDesc = []byte{ + 0x0a, 0x27, 0x70, 0x6b, 0x67, 0x2f, 0x73, 0x65, 0x6e, 0x74, 0x72, 0x79, 0x2f, 0x6b, 0x65, 0x72, + 0x6e, 0x65, 0x6c, 0x2f, 0x75, 0x6e, 0x63, 0x61, 0x75, 0x67, 0x68, 0x74, 0x5f, 0x73, 0x69, 0x67, + 0x6e, 0x61, 0x6c, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x06, 0x67, 0x76, 0x69, 0x73, 0x6f, + 0x72, 0x1a, 0x1f, 0x70, 0x6b, 0x67, 0x2f, 0x73, 0x65, 0x6e, 0x74, 0x72, 0x79, 0x2f, 0x61, 0x72, + 0x63, 0x68, 0x2f, 0x72, 0x65, 0x67, 0x69, 0x73, 0x74, 0x65, 0x72, 0x73, 0x2e, 0x70, 0x72, 0x6f, + 0x74, 0x6f, 0x22, 0xa9, 0x01, 0x0a, 0x0e, 0x55, 0x6e, 0x63, 0x61, 0x75, 0x67, 0x68, 0x74, 0x53, + 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x12, 0x10, 0x0a, 0x03, 0x74, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, + 0x28, 0x05, 0x52, 0x03, 0x74, 0x69, 0x64, 0x12, 0x10, 0x0a, 0x03, 0x70, 0x69, 0x64, 0x18, 0x02, + 0x20, 0x01, 0x28, 0x05, 0x52, 0x03, 0x70, 0x69, 0x64, 0x12, 0x2f, 0x0a, 0x09, 0x72, 0x65, 0x67, + 0x69, 0x73, 0x74, 0x65, 0x72, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x11, 0x2e, 0x67, + 0x76, 0x69, 0x73, 0x6f, 0x72, 0x2e, 0x52, 0x65, 0x67, 0x69, 0x73, 0x74, 0x65, 0x72, 0x73, 0x52, + 0x09, 0x72, 0x65, 0x67, 0x69, 0x73, 0x74, 0x65, 0x72, 0x73, 0x12, 0x23, 0x0a, 0x0d, 0x73, 0x69, + 0x67, 0x6e, 0x61, 0x6c, 0x5f, 0x6e, 0x75, 0x6d, 0x62, 0x65, 0x72, 0x18, 0x04, 0x20, 0x01, 0x28, + 0x05, 0x52, 0x0c, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x4e, 0x75, 0x6d, 0x62, 0x65, 0x72, 0x12, + 0x1d, 0x0a, 0x0a, 0x66, 0x61, 0x75, 0x6c, 0x74, 0x5f, 0x61, 0x64, 0x64, 0x72, 0x18, 0x05, 0x20, + 0x01, 0x28, 0x04, 0x52, 0x09, 0x66, 0x61, 0x75, 0x6c, 0x74, 0x41, 0x64, 0x64, 0x72, 0x62, 0x06, + 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, +} + +var ( + file_pkg_sentry_kernel_uncaught_signal_proto_rawDescOnce sync.Once + file_pkg_sentry_kernel_uncaught_signal_proto_rawDescData = file_pkg_sentry_kernel_uncaught_signal_proto_rawDesc +) + +func file_pkg_sentry_kernel_uncaught_signal_proto_rawDescGZIP() []byte { + file_pkg_sentry_kernel_uncaught_signal_proto_rawDescOnce.Do(func() { + file_pkg_sentry_kernel_uncaught_signal_proto_rawDescData = protoimpl.X.CompressGZIP(file_pkg_sentry_kernel_uncaught_signal_proto_rawDescData) + }) + return file_pkg_sentry_kernel_uncaught_signal_proto_rawDescData +} + +var file_pkg_sentry_kernel_uncaught_signal_proto_msgTypes = make([]protoimpl.MessageInfo, 1) +var file_pkg_sentry_kernel_uncaught_signal_proto_goTypes = []interface{}{ + (*UncaughtSignal)(nil), // 0: gvisor.UncaughtSignal + (*registers_go_proto.Registers)(nil), // 1: gvisor.Registers +} +var file_pkg_sentry_kernel_uncaught_signal_proto_depIdxs = []int32{ + 1, // 0: gvisor.UncaughtSignal.registers:type_name -> gvisor.Registers + 1, // [1:1] is the sub-list for method output_type + 1, // [1:1] is the sub-list for method input_type + 1, // [1:1] is the sub-list for extension type_name + 1, // [1:1] is the sub-list for extension extendee + 0, // [0:1] is the sub-list for field type_name +} + +func init() { file_pkg_sentry_kernel_uncaught_signal_proto_init() } +func file_pkg_sentry_kernel_uncaught_signal_proto_init() { + if File_pkg_sentry_kernel_uncaught_signal_proto != nil { + return + } + if !protoimpl.UnsafeEnabled { + file_pkg_sentry_kernel_uncaught_signal_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*UncaughtSignal); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: file_pkg_sentry_kernel_uncaught_signal_proto_rawDesc, + NumEnums: 0, + NumMessages: 1, + NumExtensions: 0, + NumServices: 0, + }, + GoTypes: file_pkg_sentry_kernel_uncaught_signal_proto_goTypes, + DependencyIndexes: file_pkg_sentry_kernel_uncaught_signal_proto_depIdxs, + MessageInfos: file_pkg_sentry_kernel_uncaught_signal_proto_msgTypes, + }.Build() + File_pkg_sentry_kernel_uncaught_signal_proto = out.File + file_pkg_sentry_kernel_uncaught_signal_proto_rawDesc = nil + file_pkg_sentry_kernel_uncaught_signal_proto_goTypes = nil + file_pkg_sentry_kernel_uncaught_signal_proto_depIdxs = nil +} diff --git a/pkg/sentry/limits/BUILD b/pkg/sentry/limits/BUILD deleted file mode 100644 index 21b0d1595..000000000 --- a/pkg/sentry/limits/BUILD +++ /dev/null @@ -1,29 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "limits", - srcs = [ - "context.go", - "limits.go", - "linux.go", - ], - visibility = ["//:sandbox"], - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/sync", - "@org_golang_x_sys//unix:go_default_library", - ], -) - -go_test( - name = "limits_test", - size = "small", - srcs = [ - "limits_test.go", - ], - library = ":limits", - deps = ["@org_golang_x_sys//unix:go_default_library"], -) diff --git a/pkg/sentry/limits/limits_state_autogen.go b/pkg/sentry/limits/limits_state_autogen.go new file mode 100644 index 000000000..038124f76 --- /dev/null +++ b/pkg/sentry/limits/limits_state_autogen.go @@ -0,0 +1,65 @@ +// automatically generated by stateify. + +package limits + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (l *Limit) StateTypeName() string { + return "pkg/sentry/limits.Limit" +} + +func (l *Limit) StateFields() []string { + return []string{ + "Cur", + "Max", + } +} + +func (l *Limit) beforeSave() {} + +// +checklocksignore +func (l *Limit) StateSave(stateSinkObject state.Sink) { + l.beforeSave() + stateSinkObject.Save(0, &l.Cur) + stateSinkObject.Save(1, &l.Max) +} + +func (l *Limit) afterLoad() {} + +// +checklocksignore +func (l *Limit) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &l.Cur) + stateSourceObject.Load(1, &l.Max) +} + +func (l *LimitSet) StateTypeName() string { + return "pkg/sentry/limits.LimitSet" +} + +func (l *LimitSet) StateFields() []string { + return []string{ + "data", + } +} + +func (l *LimitSet) beforeSave() {} + +// +checklocksignore +func (l *LimitSet) StateSave(stateSinkObject state.Sink) { + l.beforeSave() + stateSinkObject.Save(0, &l.data) +} + +func (l *LimitSet) afterLoad() {} + +// +checklocksignore +func (l *LimitSet) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &l.data) +} + +func init() { + state.Register((*Limit)(nil)) + state.Register((*LimitSet)(nil)) +} diff --git a/pkg/sentry/limits/limits_test.go b/pkg/sentry/limits/limits_test.go deleted file mode 100644 index 0ee877be4..000000000 --- a/pkg/sentry/limits/limits_test.go +++ /dev/null @@ -1,44 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package limits - -import ( - "testing" - - "golang.org/x/sys/unix" -) - -func TestSet(t *testing.T) { - testCases := []struct { - limit Limit - privileged bool - expectedErr error - }{ - {limit: Limit{Cur: 50, Max: 50}, privileged: false, expectedErr: nil}, - {limit: Limit{Cur: 20, Max: 50}, privileged: false, expectedErr: nil}, - {limit: Limit{Cur: 20, Max: 60}, privileged: false, expectedErr: unix.EPERM}, - {limit: Limit{Cur: 60, Max: 50}, privileged: false, expectedErr: unix.EINVAL}, - {limit: Limit{Cur: 11, Max: 10}, privileged: false, expectedErr: unix.EINVAL}, - {limit: Limit{Cur: 20, Max: 60}, privileged: true, expectedErr: nil}, - } - - ls := NewLimitSet() - for _, tc := range testCases { - if _, err := ls.Set(1, tc.limit, tc.privileged); err != tc.expectedErr { - t.Fatalf("Tried to set Limit to %+v and privilege %t: got %v, wanted %v", tc.limit, tc.privileged, err, tc.expectedErr) - } - } - -} diff --git a/pkg/sentry/loader/BUILD b/pkg/sentry/loader/BUILD deleted file mode 100644 index 560a0f33c..000000000 --- a/pkg/sentry/loader/BUILD +++ /dev/null @@ -1,42 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "loader", - srcs = [ - "elf.go", - "interpreter.go", - "loader.go", - "vdso.go", - "vdso_state.go", - ], - marshal = True, - marshal_debug = True, - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi", - "//pkg/abi/linux", - "//pkg/abi/linux/errno", - "//pkg/context", - "//pkg/cpuid", - "//pkg/errors/linuxerr", - "//pkg/hostarch", - "//pkg/log", - "//pkg/rand", - "//pkg/safemem", - "//pkg/sentry/arch", - "//pkg/sentry/fsbridge", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/limits", - "//pkg/sentry/loader/vdsodata", - "//pkg/sentry/memmap", - "//pkg/sentry/mm", - "//pkg/sentry/pgalloc", - "//pkg/sentry/uniqueid", - "//pkg/sentry/usage", - "//pkg/sentry/vfs", - "//pkg/syserr", - "//pkg/usermem", - ], -) diff --git a/pkg/sentry/loader/loader_abi_autogen_unsafe.go b/pkg/sentry/loader/loader_abi_autogen_unsafe.go new file mode 100644 index 000000000..618cb60fc --- /dev/null +++ b/pkg/sentry/loader/loader_abi_autogen_unsafe.go @@ -0,0 +1,7 @@ +// Automatically generated marshal implementation. See tools/go_marshal. + +package loader + +import ( +) + diff --git a/pkg/sentry/loader/loader_state_autogen.go b/pkg/sentry/loader/loader_state_autogen.go new file mode 100644 index 000000000..84f3fbb19 --- /dev/null +++ b/pkg/sentry/loader/loader_state_autogen.go @@ -0,0 +1,97 @@ +// automatically generated by stateify. + +package loader + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (v *VDSO) StateTypeName() string { + return "pkg/sentry/loader.VDSO" +} + +func (v *VDSO) StateFields() []string { + return []string{ + "ParamPage", + "vdso", + "os", + "arch", + "phdrs", + } +} + +func (v *VDSO) beforeSave() {} + +// +checklocksignore +func (v *VDSO) StateSave(stateSinkObject state.Sink) { + v.beforeSave() + var phdrsValue []elfProgHeader + phdrsValue = v.savePhdrs() + stateSinkObject.SaveValue(4, phdrsValue) + stateSinkObject.Save(0, &v.ParamPage) + stateSinkObject.Save(1, &v.vdso) + stateSinkObject.Save(2, &v.os) + stateSinkObject.Save(3, &v.arch) +} + +func (v *VDSO) afterLoad() {} + +// +checklocksignore +func (v *VDSO) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &v.ParamPage) + stateSourceObject.Load(1, &v.vdso) + stateSourceObject.Load(2, &v.os) + stateSourceObject.Load(3, &v.arch) + stateSourceObject.LoadValue(4, new([]elfProgHeader), func(y interface{}) { v.loadPhdrs(y.([]elfProgHeader)) }) +} + +func (e *elfProgHeader) StateTypeName() string { + return "pkg/sentry/loader.elfProgHeader" +} + +func (e *elfProgHeader) StateFields() []string { + return []string{ + "Type", + "Flags", + "Off", + "Vaddr", + "Paddr", + "Filesz", + "Memsz", + "Align", + } +} + +func (e *elfProgHeader) beforeSave() {} + +// +checklocksignore +func (e *elfProgHeader) StateSave(stateSinkObject state.Sink) { + e.beforeSave() + stateSinkObject.Save(0, &e.Type) + stateSinkObject.Save(1, &e.Flags) + stateSinkObject.Save(2, &e.Off) + stateSinkObject.Save(3, &e.Vaddr) + stateSinkObject.Save(4, &e.Paddr) + stateSinkObject.Save(5, &e.Filesz) + stateSinkObject.Save(6, &e.Memsz) + stateSinkObject.Save(7, &e.Align) +} + +func (e *elfProgHeader) afterLoad() {} + +// +checklocksignore +func (e *elfProgHeader) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &e.Type) + stateSourceObject.Load(1, &e.Flags) + stateSourceObject.Load(2, &e.Off) + stateSourceObject.Load(3, &e.Vaddr) + stateSourceObject.Load(4, &e.Paddr) + stateSourceObject.Load(5, &e.Filesz) + stateSourceObject.Load(6, &e.Memsz) + stateSourceObject.Load(7, &e.Align) +} + +func init() { + state.Register((*VDSO)(nil)) + state.Register((*elfProgHeader)(nil)) +} diff --git a/pkg/sentry/loader/vdsodata/BUILD b/pkg/sentry/loader/vdsodata/BUILD deleted file mode 100644 index 119199f97..000000000 --- a/pkg/sentry/loader/vdsodata/BUILD +++ /dev/null @@ -1,38 +0,0 @@ -load("//tools:defs.bzl", "go_add_tags", "go_embed_data", "go_library") - -package(licenses = ["notice"]) - -go_embed_data( - name = "vdso_bin", - src = "//vdso:vdso.so", - package = "vdsodata", - var = "Binary", -) - -[ - # Generate multiple tagged files. Note that the contents of all files - # will be the same (i.e. vdso_arm64.go will contain the amd64 vdso), but - # the build tags will ensure only one is selected. When we generate the - # "Go" branch, we select all archiecture files from the relevant build. - # This is a hack around some limitations for "out" being a configurable - # attribute and selects for srcs. See also tools/go_branch.sh. - go_add_tags( - name = "vdso_%s" % arch, - src = ":vdso_bin", - out = "vdso_%s.go" % arch, - go_tags = [arch], - ) - for arch in ("amd64", "arm64") -] - -go_library( - name = "vdsodata", - srcs = [ - "vdsodata.go", - ":vdso_amd64", - ":vdso_arm64", - ], - marshal = False, - stateify = False, - visibility = ["//pkg/sentry:internal"], -) diff --git a/pkg/sentry/loader/vdsodata/vdso_amd64.go b/pkg/sentry/loader/vdsodata/vdso_amd64.go new file mode 100644 index 000000000..ff53489f8 --- /dev/null +++ b/pkg/sentry/loader/vdsodata/vdso_amd64.go @@ -0,0 +1,7 @@ +// +build amd64 + +// Generated by go_embed_data for //pkg/sentry/loader/vdsodata:vdso_bin. DO NOT EDIT. + +package vdsodata + +var Binary = []byte("ELF\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00>\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00@\x00\x00\x00\x00\x00\x00\x00X\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00@\x008\x00\x00@\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00p\xff\xff\xff\xff\xff\x00\x00p\xff\xff\xff\xff\xff\xae\x00\x00\x00\x00\x00\x00\xae\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xe0\x00\x00\x00\x00\x00\x00\xe0p\xff\xff\xff\xff\xff\xe0p\xff\xff\xff\xff\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00T\x00\x00\x00\x00\x00\x00Tp\xff\xff\xff\xff\xffTp\xff\xff\xff\xff\xff`\x00\x00\x00\x00\x00\x00\x00`\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00P\xe5td\x00\x00\x00\xb4\x00\x00\x00\x00\x00\x00\xb4p\xff\xff\xff\xff\xff\xb4p\xff\xff\xff\xff\xffD\x00\x00\x00\x00\x00\x00\x00D\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\n\x00\x00\x00\x00\x00\x00 \x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\"\x00\x00 p\xff\xff\xff\xff\xff\x00\x00\x00\x00\x00\x00\x00\"\x00\x00\x00\x00\x00 p\xff\xff\xff\xff\xff\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\"\x00\x00\xf0p\xff\xff\xff\xff\xff&\x00\x00\x00\x00\x00\x00\x005\x00\x00\x00\x00\x00\xf0p\xff\xff\xff\xff\xff&\x00\x00\x00\x00\x00\x00\x00A\x00\x00\x00\"\x00\x00\x80p\xff\xff\xff\xff\xffb\x00\x00\x00\x00\x00\x00\x00N\x00\x00\x00\x00\x00\x80p\xff\xff\xff\xff\xffb\x00\x00\x00\x00\x00\x00\x00b\x00\x00\x00\"\x00\x00@p\xff\xff\xff\xff\xff8\x00\x00\x00\x00\x00\x00\x00p\x00\x00\x00\x00\x00@p\xff\xff\xff\xff\xff8\x00\x00\x00\x00\x00\x00\x00\x85\x00\x00\x00\x00\x000p\xff\xff\xff\xff\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xf1\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00linux-vdso.so.1\x00LINUX_2.6\x00getcpu\x00__vdso_getcpu\x00time\x00__vdso_time\x00gettimeofday\x00__vdso_gettimeofday\x00clock_gettime\x00__vdso_clock_gettime\x00__kernel_rt_sigreturn\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xa1\xbf\xee
\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xf6u\xae\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00GNU\x00\x00\x00\xc0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00 \x00\x00\x00\x00\x00\x00GNU\x00gold 1.16\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00GNU\x00g\xf3E\xee\xe6C\n\x9a<\x8b\xa5L\xddU\x9fN;@\x00\x00\x00\x00\x00\x00|\x00\x00\\\x00\x00\x00\x8c\x00\x00t\x00\x00\x00\xcc\x00\x00\x8c\x00\x00\x00<
\x00\x00\xb4\x00\x00\x00l
\x00\x00\xd4\x00\x00\x00|
\x00\x00\xec\x00\x00\x00<\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00zR\x00x\x90\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x004\x00\x00\x00\x00\x008\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00$\x00\x00\x00L\x00\x00\x008\x00\x00b\x00\x00\x00\x00E\x86D\x83D0RAA\x00\x00\x00t\x00\x00\x00\x80\x00\x00&\x00\x00\x00\x00E\x83G XA\x00\x00\x00\x00\x94\x00\x00\x00\x90\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xac\x00\x00\x00\x88\x00\x00\xbb\x00\x00\x00\x00E\x83~\nmJ\x00\x00\x00\xcc\x00\x00\x00(
\x00\x00\xbe\x00\x00\x00\x00E\x83~\nmM\x00\x00\x00\x00\x00\x00\x00`p\xff\xff\xff\xff\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00hp\xff\xff\xff\xff\xff\n\x00\x00\x00\x00\x00\x00\x00\x9b\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00 p\xff\xff\xff\xff\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xf0\xff\xffo\x00\x00\x00\x00p\xff\xff\xff\xff\xff\xfc\xff\xffo\x00\x00\x00\x00p\xff\xff\xff\xff\xff\xfd\xff\xffo\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xf3\xfa\xb8\x00\x00\x00\xc3@\x00\xf3\xfa\x83\xfft\x83\xfft\x85\xfft\xb8\xe4\x00\x00\x00\xc3fD\x00\x00H\x89\xf7\xe9\x88\x00\x00\x84\x00\x00\x00\x00\x00H\x89\xf7\xe9\xb8\x00\x00\x00\x84\x00\x00\x00\x00\x00\xf3\xfaUH\x89\xf5SH\x83\xecH\x85\xfft:H\x89\xfbH\x89\xe7\xe8\x93\x00\x00\x00\x85\xc0u:H\x8b$H\x8bL$H\xba\xcf\xf7S㥛\xc4 H\x89H\x89\xc8H\xc1\xf9?H\xf7\xeaH\xc1\xfaH)\xcaH\x89S1\xc0H\x85\xedtH\xc7E\x00\x00\x00\x00\x00H\x83\xc4[]\xc3ff.\x84\x00\x00\x00\x00\x00\x00\xf3\xfaSH\x89\xfbH\x83\xecH\x89\xe7\xe8,\x00\x00\x00H\x8b$H\x85\xdbtH\x89H\x83\xc4[\xc3f.\x84\x00\x00\x00\x00\x00\xf3\xfa\xb85\x00\x00H\x98Ð\x90\xf3\xfaSH\x89\xfeH\x8d
\xc1\xde\xff\xffH\x8b9\x83\xe7\xfeL\x8bQ(L\x8bA8H\x8bY0L\x8bY@Lc\xcf\xae\xe81H\x8b9L9\xcfu\xddM\x85\xd2tv\x89\xc0H\xc1\xe2 H \xc21\xc0H9\xd3+H\xb8\x00\x00\x00\x00\x00ʚ;H\x89\xd11\xd2I\xf7\xf3H)\xd9H\x89\xcfH\xc1\xff?H\xaf\xf8H\xf7\xe1H\xfaH\xac\xd0 H\xb9SZ\x9b\xa0/\xb8D\x00I\xc0[L\x89\xc2H\xc1\xea H\x89\xd0H\xf7\xe11\xc0H\xc1\xeaH\x89Hi\xd2\x00ʚ;I)\xd0L\x89F\xc3\x84\x00\x00\x00\x00\x00\xb8\xe4\x00\x00\x001\xff[\xc3D\x00\x00\xf3\xfaSH\x89\xfeH\x8d
\xde\xff\xffH\x8b9\x83\xe7\xfeL\x8bQL\x8bAH\x8bYL\x8bY Lc\xcf\xae\xe81H\x8b9L9\xcfu\xddM\x85\xd2tv\x89\xc0H\xc1\xe2 H \xc21\xc0H9\xd3+H\xb8\x00\x00\x00\x00\x00ʚ;H\x89\xd11\xd2I\xf7\xf3H)\xd9H\x89\xcfH\xc1\xff?H\xaf\xf8H\xf7\xe1H\xfaH\xac\xd0 H\xb9SZ\x9b\xa0/\xb8D\x00I\xc0[L\x89\xc2H\xc1\xea H\x89\xd0H\xf7\xe11\xc0H\xc1\xeaH\x89Hi\xd2\x00ʚ;I)\xd0L\x89F\xc3\x84\x00\x00\x00\x00\x00\xb8\xe4\x00\x00\x00\xbf\x00\x00\x00[\xc3\x00GCC: (Ubuntu 9.3.0-17ubuntu1~20.04) 9.3.0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xf1\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00 \x00\x00\x00\x00\xf1\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000p\xff\xff\xff\xff\xff\xbb\x00\x00\x00\x00\x00\x00\x009\x00\x00\x00\x00\x00\xf0p\xff\xff\xff\xff\xff\xbe\x00\x00\x00\x00\x00\x00\x00]\x00\x00\x00 \x00\xe0p\xff\xff\xff\xff\xff\x00\x00\x00\x00\x00\x00f\x00\x00\x00\x00\x00\xf1\xff\x00\x00p\xff\xff\xff\xff\xff\x00\x00\x00\x00\x00\x00\x00\x00s\x00\x00\x00\x00\x00\xf1\xff\x00\xf0o\xff\xff\xff\xff\xff\x00\x00\x00\x00\x00\x00\x00\x00{\x00\x00\x00\x00\xf1\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x85\x00\x00\x00\"\x00\x00 p\xff\xff\xff\xff\xff\x00\x00\x00\x00\x00\x00\x00\x8c\x00\x00\x00\x00\x00 p\xff\xff\xff\xff\xff\x00\x00\x00\x00\x00\x00\x00\x9a\x00\x00\x00\"\x00\x00\xf0p\xff\xff\xff\xff\xff&\x00\x00\x00\x00\x00\x00\x00\x9f\x00\x00\x00\x00\x00\xf0p\xff\xff\xff\xff\xff&\x00\x00\x00\x00\x00\x00\x00\xab\x00\x00\x00\"\x00\x00\x80p\xff\xff\xff\xff\xffb\x00\x00\x00\x00\x00\x00\x00\xb8\x00\x00\x00\x00\x00\x80p\xff\xff\xff\xff\xffb\x00\x00\x00\x00\x00\x00\x00\xcc\x00\x00\x00\"\x00\x00@p\xff\xff\xff\xff\xff8\x00\x00\x00\x00\x00\x00\x00\xda\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xf0\x00\x00\x00\x00\x00@p\xff\xff\xff\xff\xff8\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000p\xff\xff\xff\xff\xff\x00\x00\x00\x00\x00\x00\x00\x00vdso.cc\x00vdso_time.cc\x00_ZN4vdso13ClockRealtimeEP8timespec\x00_ZN4vdso14ClockMonotonicEP8timespec\x00_DYNAMIC\x00VDSO_PRELINK\x00_params\x00LINUX_2.6\x00getcpu\x00__vdso_getcpu\x00time\x00__vdso_time\x00gettimeofday\x00__vdso_gettimeofday\x00clock_gettime\x00_GLOBAL_OFFSET_TABLE_\x00__vdso_clock_gettime\x00__kernel_rt_sigreturn\x00\x00.text\x00.comment\x00.bss\x00.dynstr\x00.eh_frame_hdr\x00.gnu.version\x00.dynsym\x00.hash\x00.note\x00.eh_frame\x00.gnu.version_d\x00.dynamic\x00.shstrtab\x00.strtab\x00.symtab\x00.data\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00@\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00 p\xff\xff\xff\xff\xff \x00\x00\x00\x00\x00\x00@\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x008\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00`p\xff\xff\xff\xff\xff`\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00hp\xff\xff\xff\xff\xffh\x00\x00\x00\x00\x00\x00\x9b\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00+\x00\x00\x00\xff\xff\xffo\x00\x00\x00\x00\x00\x00\x00p\xff\xff\xff\xff\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00V\x00\x00\x00\xfd\xff\xffo\x00\x00\x00\x00\x00\x00\x00p\xff\xff\xff\xff\xff\x00\x00\x00\x00\x00\x008\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00F\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00Tp\xff\xff\xff\xff\xffT\x00\x00\x00\x00\x00\x00`\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00p\x00\x00\x00\x00\x00\x00\x00\xb4p\xff\xff\xff\xff\xff\xb4\x00\x00\x00\x00\x00\x00D\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00L\x00\x00\x00\x00\x00p\x00\x00\x00\x00\x00\x00\x00\xf8p\xff\xff\xff\xff\xff\xf8\x00\x00\x00\x00\x00\x00\xe8\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xe0p\xff\xff\xff\xff\xff\xe0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x88\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xf0p\xff\xff\xff\xff\xff\xf0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000p\xff\xff\xff\xff\xff0\x00\x00\x00\x00\x00\x00~\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xaep\xff\xff\xff\xff\xff\xae\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xae\x00\x00\x00\x00\x00\x00+\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xe0\x00\x00\x00\x00\x00\x00\xc8\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00x\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xa8\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00n\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xc3\x00\x00\x00\x00\x00\x00\x8e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00") diff --git a/pkg/sentry/loader/vdsodata/vdso_arm64.go b/pkg/sentry/loader/vdsodata/vdso_arm64.go new file mode 100644 index 000000000..67de78969 --- /dev/null +++ b/pkg/sentry/loader/vdsodata/vdso_arm64.go @@ -0,0 +1,7 @@ +// +build arm64 + +// Generated by go_embed_data for //pkg/sentry/loader/vdsodata:vdso_bin. DO NOT EDIT. + +package vdsodata + +var Binary = []byte("ELF\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xb7\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00@\x00\x00\x00\x00\x00\x00\x00\x90\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00@\x008\x00\x00@\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00p\xff\xff\xff\xff\xff\x00\x00p\xff\xff\xff\xff\xff\x84\x00\x00\x00\x00\x00\x00\x84\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xd0\x00\x00\x00\x00\x00\x00\xd0p\xff\xff\xff\xff\xff\xd0p\xff\xff\xff\xff\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x9c\x00\x00\x00\x00\x00\x00\x9cp\xff\xff\xff\xff\xff\x9cp\xff\xff\xff\xff\xff@\x00\x00\x00\x00\x00\x00\x00@\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00P\xe5td\x00\x00\x00\xdc\x00\x00\x00\x00\x00\x00\xdcp\xff\xff\xff\xff\xff\xdcp\xff\xff\xff\xff\xff<\x00\x00\x00\x00\x00\x00\x00<\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xd8p\xff\xff\xff\xff\xff<\x00\x00\x00\x00\x00\x00\x004\x00\x00\x00\x00\x00xp\xff\xff\xff\xff\xff`\x00\x00\x00\x00\x00\x00\x00J\x00\x00\x00\x00\x00@p\xff\xff\xff\xff\xff4\x00\x00\x00\x00\x00\x00\x00a\x00\x00\x00\x00\x000p\xff\xff\xff\xff\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xf1\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00linux-vdso.so.1\x00LINUX_2.6.39\x00__kernel_clock_getres\x00__kernel_gettimeofday\x00__kernel_clock_gettime\x00__kernel_rt_sigreturn\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xa1\xbf\xee
\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x89\xcb_\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00 \x00\x00\x00\x00\x00\x00GNU\x00gold 1.16\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00GNU\x004\x93\xb8\xcc$\xaaU\x92\xcd\\\xc0o\xbeb)*\xe1T\x90\x9f;8\x00\x00\x00\x00\x00\x00T
\x00\x00T\x00\x00\x00d
\x00\x00l\x00\x00\x00\x9c
\x00\x00\x84\x00\x00\x00\xfc
\x00\x00\xac\x00\x00\x00<\x00\x00\xc4\x00\x00\x00\xf4\x00\x00\xdc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00zR\x00x\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xf8\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x004\x00\x00\x00\xf0\x00\x004\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00$\x00\x00\x00L\x00\x00\x00
\x00\x00`\x00\x00\x00\x00A0\x9d\x9eB\x93\x94T\xde\xdd\xd3\xd4\x00\x00\x00\x00\x00\x00\x00\x00t\x00\x00\x00H
\x00\x00<\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x8c\x00\x00\x00p
\x00\x00\xb4\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xa4\x00\x00\x00\x00\x00\xb4\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00Pp\xff\xff\xff\xff\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xe0p\xff\xff\xff\xff\xff\n\x00\x00\x00\x00\x00\x00\x00w\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00 p\xff\xff\xff\xff\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xf0\xff\xffo\x00\x00\x00\x00Xp\xff\xff\xff\xff\xff\xfc\xff\xffo\x00\x00\x00\x00dp\xff\xff\xff\xff\xff\xfd\xff\xffo\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00h\x80\xd2\x00\x00\xd4\xc0_\xd6 \xd5\xe3\xaa\x00q\xe0\x00\x00T\x00q\xa0\x00\x00T\xc0\x00\x004(\x80\xd2\x00\x00\xd4\xc0_\xd6\xe0\xaaZ\x00\x00\xe0\xaa*\x00\x00 \xd5\xfd{\xbd\xa9\xfd\x00\x91\xf3S\xa9\xf4\xaa\xc0\x00\xb4\xf3\x00\xaa\xe0\x83\x00\x91!\x00\x00\x94\xa0\x005\xe2B\xa9\xe0\xf9\x9e\xd2`j\xbc\xf2\xa0t\xd3\xf2\x80\xe4\xf2 |@\x9b\x00\xfcG\x93\x00\xfc\x81\xcbb\x00\xa9\x00\x00\x80RT\x00\x00\xb4\x9f\x00\xf9\xf3SA\xa9\xfd{è\xc0_\xd6\x00q\xac\x00\x00T\xc0\x00\xf86H\x80\xd2\x00\x00\xd4\xc0_\xd6\x00q\x81\xff\xffT\xa1\x00\x00\xb4\"\x00\x80\xd2\x00\x00\x80R?\x00\xa9\xc0_\xd6\x00\x00\x80R\xc0_\xd6\xd5 B\xf7\xfeC\x00@\xf9\xe1\x00\xaa\xbf9\xd5cxF\x9cB\xa9e|@\x93D\xa0C\xa9@\xe0;տ9\xd5C\x00@\xf9\x00\xeb\xe1\xfe\xffT\x86\x00\xb4\xff\x00\x00\xeb\x00\x80\xd2L\x00T@\xd9\xd2\x00\x00\xcbCs\xe7\xf2\xfc\x93cȚ|Û\xa2\x9b\x00|\x9bB\x80\xc0\x93\x84\x00\x8bbJ\x8b\xd2b\xb4\xf2@\x99҃\xfcI\xd3\xe2\xd7\xf2\x82\xe0\xf2Es\xa7\xf2\x00\x00\x80Rc|b\xfcK\xd3\"\x00\x00\xf9B\x90\x9b\"\x00\xf9\xc0_\xd6\x00\x00\x80R(\x80\xd2\x00\x00\xd4\xc0_\xd6 Ղ\xf1\xfeC\x00@\xf9\xe1\x00\xaa\xbf9\xd5cxF\x9c@\xa9e|@\x93D\xa0A\xa9@\xe0;տ9\xd5C\x00@\xf9\x00\xeb\xe1\xfe\xffT\x86\x00\xb4\xff\x00\x00\xeb\x00\x80\xd2L\x00T@\xd9\xd2\x00\x00\xcbCs\xe7\xf2\xfc\x93cȚ|Û\xa2\x9b\x00|\x9bB\x80\xc0\x93\x84\x00\x8bbJ\x8b\xd2b\xb4\xf2@\x99҃\xfcI\xd3\xe2\xd7\xf2\x82\xe0\xf2Es\xa7\xf2\x00\x00\x80Rc|b\xfcK\xd3\"\x00\x00\xf9B\x90\x9b\"\x00\xf9\xc0_\xd6 \x00\x80R(\x80\xd2\x00\x00\xd4\xc0_\xd6\x00GCC: (Ubuntu 9.3.0-17ubuntu1~20.04) 9.3.0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xf1\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00 \x00\x00\x00\x00\x00\x000p\xff\xff\xff\xff\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xf1\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00 \x00\x00\x00\x00\x00\x00p\xff\xff\xff\xff\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00p\xff\xff\xff\xff\xff\xb4\x00\x00\x00\x00\x00\x00\x00<\x00\x00\x00\x00\x00\xd0p\xff\xff\xff\xff\xff\xb4\x00\x00\x00\x00\x00\x00\x00`\x00\x00\x00 \x00\xd0p\xff\xff\xff\xff\xff\x00\x00\x00\x00\x00\x00i\x00\x00\x00\x00\x00\xf1\xff\x00\x00p\xff\xff\xff\xff\xff\x00\x00\x00\x00\x00\x00\x00\x00v\x00\x00\x00\x00\x00\xf1\xff\x00\xf0o\xff\xff\xff\xff\xff\x00\x00\x00\x00\x00\x00\x00\x00~\x00\x00\x00\x00\xf1\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x8b\x00\x00\x00\x00\x00\xd8p\xff\xff\xff\xff\xff<\x00\x00\x00\x00\x00\x00\x00\xa1\x00\x00\x00\x00\x00xp\xff\xff\xff\xff\xff`\x00\x00\x00\x00\x00\x00\x00\xb7\x00\x00\x00\x00\x00@p\xff\xff\xff\xff\xff4\x00\x00\x00\x00\x00\x00\x00\xce\x00\x00\x00\x00\x000p\xff\xff\xff\xff\xff\x00\x00\x00\x00\x00\x00\x00\x00vdso.cc\x00$x\x00vdso_time.cc\x00_ZN4vdso13ClockRealtimeEP8timespec\x00_ZN4vdso14ClockMonotonicEP8timespec\x00_DYNAMIC\x00VDSO_PRELINK\x00_params\x00LINUX_2.6.39\x00__kernel_clock_getres\x00__kernel_gettimeofday\x00__kernel_clock_gettime\x00__kernel_rt_sigreturn\x00\x00.text\x00.comment\x00.bss\x00.dynstr\x00.eh_frame_hdr\x00.gnu.version\x00.dynsym\x00.hash\x00.note\x00.eh_frame\x00.gnu.version_d\x00.dynamic\x00.shstrtab\x00.strtab\x00.symtab\x00.data\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00@\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00 p\xff\xff\xff\xff\xff \x00\x00\x00\x00\x00\x00,\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x008\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00Pp\xff\xff\xff\xff\xffP\x00\x00\x00\x00\x00\x00\x90\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xe0p\xff\xff\xff\xff\xff\xe0\x00\x00\x00\x00\x00\x00w\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00+\x00\x00\x00\xff\xff\xffo\x00\x00\x00\x00\x00\x00\x00Xp\xff\xff\xff\xff\xffX\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00V\x00\x00\x00\xfd\xff\xffo\x00\x00\x00\x00\x00\x00\x00dp\xff\xff\xff\xff\xffd\x00\x00\x00\x00\x00\x008\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00F\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x9cp\xff\xff\xff\xff\xff\x9c\x00\x00\x00\x00\x00\x00@\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xdcp\xff\xff\xff\xff\xff\xdc\x00\x00\x00\x00\x00\x00<\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00L\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00p\xff\xff\xff\xff\xff\x00\x00\x00\x00\x00\x00\xb8\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xd0p\xff\xff\xff\xff\xff\xd0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x88\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xe0p\xff\xff\xff\xff\xff\xe0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000p\xff\xff\xff\xff\xff0\x00\x00\x00\x00\x00\x00T\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x84p\xff\xff\xff\xff\xff\x84\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x84\x00\x00\x00\x00\x00\x00+\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xb0\x00\x00\x00\x00\x00\x00h\x00\x00\x00\x00\x00\x00\x00\x00\x00\n\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00x\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xe4\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00n\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xfc\x00\x00\x00\x00\x00\x00\x8e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00") diff --git a/pkg/sentry/memmap/BUILD b/pkg/sentry/memmap/BUILD deleted file mode 100644 index a89bfa680..000000000 --- a/pkg/sentry/memmap/BUILD +++ /dev/null @@ -1,67 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -package(licenses = ["notice"]) - -go_template_instance( - name = "mappable_range", - out = "mappable_range.go", - package = "memmap", - prefix = "Mappable", - template = "//pkg/segment:generic_range", - types = { - "T": "uint64", - }, -) - -go_template_instance( - name = "mapping_set_impl", - out = "mapping_set_impl.go", - package = "memmap", - prefix = "Mapping", - template = "//pkg/segment:generic_set", - types = { - "Key": "uint64", - "Range": "MappableRange", - "Value": "MappingsOfRange", - "Functions": "mappingSetFunctions", - }, -) - -go_template_instance( - name = "file_range", - out = "file_range.go", - package = "memmap", - prefix = "File", - template = "//pkg/segment:generic_range", - types = { - "T": "uint64", - }, -) - -go_library( - name = "memmap", - srcs = [ - "file_range.go", - "mappable_range.go", - "mapping_set.go", - "mapping_set_impl.go", - "memmap.go", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/context", - "//pkg/hostarch", - "//pkg/log", - "//pkg/safemem", - "//pkg/usermem", - ], -) - -go_test( - name = "memmap_test", - size = "small", - srcs = ["mapping_set_test.go"], - library = ":memmap", - deps = ["//pkg/hostarch"], -) diff --git a/pkg/sentry/memmap/file_range.go b/pkg/sentry/memmap/file_range.go new file mode 100644 index 000000000..6f0b4bde4 --- /dev/null +++ b/pkg/sentry/memmap/file_range.go @@ -0,0 +1,76 @@ +package memmap + +// A Range represents a contiguous range of T. +// +// +stateify savable +type FileRange struct { + // Start is the inclusive start of the range. + Start uint64 + + // End is the exclusive end of the range. + End uint64 +} + +// WellFormed returns true if r.Start <= r.End. All other methods on a Range +// require that the Range is well-formed. +// +//go:nosplit +func (r FileRange) WellFormed() bool { + return r.Start <= r.End +} + +// Length returns the length of the range. +// +//go:nosplit +func (r FileRange) Length() uint64 { + return r.End - r.Start +} + +// Contains returns true if r contains x. +// +//go:nosplit +func (r FileRange) Contains(x uint64) bool { + return r.Start <= x && x < r.End +} + +// Overlaps returns true if r and r2 overlap. +// +//go:nosplit +func (r FileRange) Overlaps(r2 FileRange) bool { + return r.Start < r2.End && r2.Start < r.End +} + +// IsSupersetOf returns true if r is a superset of r2; that is, the range r2 is +// contained within r. +// +//go:nosplit +func (r FileRange) IsSupersetOf(r2 FileRange) bool { + return r.Start <= r2.Start && r.End >= r2.End +} + +// Intersect returns a range consisting of the intersection between r and r2. +// If r and r2 do not overlap, Intersect returns a range with unspecified +// bounds, but for which Length() == 0. +// +//go:nosplit +func (r FileRange) Intersect(r2 FileRange) FileRange { + if r.Start < r2.Start { + r.Start = r2.Start + } + if r.End > r2.End { + r.End = r2.End + } + if r.End < r.Start { + r.End = r.Start + } + return r +} + +// CanSplitAt returns true if it is legal to split a segment spanning the range +// r at x; that is, splitting at x would produce two ranges, both of which have +// non-zero length. +// +//go:nosplit +func (r FileRange) CanSplitAt(x uint64) bool { + return r.Contains(x) && r.Start < x +} diff --git a/pkg/sentry/memmap/mappable_range.go b/pkg/sentry/memmap/mappable_range.go new file mode 100644 index 000000000..7b7312cb6 --- /dev/null +++ b/pkg/sentry/memmap/mappable_range.go @@ -0,0 +1,76 @@ +package memmap + +// A Range represents a contiguous range of T. +// +// +stateify savable +type MappableRange struct { + // Start is the inclusive start of the range. + Start uint64 + + // End is the exclusive end of the range. + End uint64 +} + +// WellFormed returns true if r.Start <= r.End. All other methods on a Range +// require that the Range is well-formed. +// +//go:nosplit +func (r MappableRange) WellFormed() bool { + return r.Start <= r.End +} + +// Length returns the length of the range. +// +//go:nosplit +func (r MappableRange) Length() uint64 { + return r.End - r.Start +} + +// Contains returns true if r contains x. +// +//go:nosplit +func (r MappableRange) Contains(x uint64) bool { + return r.Start <= x && x < r.End +} + +// Overlaps returns true if r and r2 overlap. +// +//go:nosplit +func (r MappableRange) Overlaps(r2 MappableRange) bool { + return r.Start < r2.End && r2.Start < r.End +} + +// IsSupersetOf returns true if r is a superset of r2; that is, the range r2 is +// contained within r. +// +//go:nosplit +func (r MappableRange) IsSupersetOf(r2 MappableRange) bool { + return r.Start <= r2.Start && r.End >= r2.End +} + +// Intersect returns a range consisting of the intersection between r and r2. +// If r and r2 do not overlap, Intersect returns a range with unspecified +// bounds, but for which Length() == 0. +// +//go:nosplit +func (r MappableRange) Intersect(r2 MappableRange) MappableRange { + if r.Start < r2.Start { + r.Start = r2.Start + } + if r.End > r2.End { + r.End = r2.End + } + if r.End < r.Start { + r.End = r.Start + } + return r +} + +// CanSplitAt returns true if it is legal to split a segment spanning the range +// r at x; that is, splitting at x would produce two ranges, both of which have +// non-zero length. +// +//go:nosplit +func (r MappableRange) CanSplitAt(x uint64) bool { + return r.Contains(x) && r.Start < x +} diff --git a/pkg/sentry/memmap/mapping_set_impl.go b/pkg/sentry/memmap/mapping_set_impl.go new file mode 100644 index 000000000..c32df9259 --- /dev/null +++ b/pkg/sentry/memmap/mapping_set_impl.go @@ -0,0 +1,1643 @@ +package memmap + +import ( + "bytes" + "fmt" +) + +// trackGaps is an optional parameter. +// +// If trackGaps is 1, the Set will track maximum gap size recursively, +// enabling the GapIterator.{Prev,Next}LargeEnoughGap functions. In this +// case, Key must be an unsigned integer. +// +// trackGaps must be 0 or 1. +const MappingtrackGaps = 0 + +var _ = uint8(MappingtrackGaps << 7) // Will fail if not zero or one. + +// dynamicGap is a type that disappears if trackGaps is 0. +type MappingdynamicGap [MappingtrackGaps]uint64 + +// Get returns the value of the gap. +// +// Precondition: trackGaps must be non-zero. +func (d *MappingdynamicGap) Get() uint64 { + return d[:][0] +} + +// Set sets the value of the gap. +// +// Precondition: trackGaps must be non-zero. +func (d *MappingdynamicGap) Set(v uint64) { + d[:][0] = v +} + +const ( + // minDegree is the minimum degree of an internal node in a Set B-tree. + // + // - Any non-root node has at least minDegree-1 segments. + // + // - Any non-root internal (non-leaf) node has at least minDegree children. + // + // - The root node may have fewer than minDegree-1 segments, but it may + // only have 0 segments if the tree is empty. + // + // Our implementation requires minDegree >= 3. Higher values of minDegree + // usually improve performance, but increase memory usage for small sets. + MappingminDegree = 3 + + MappingmaxDegree = 2 * MappingminDegree +) + +// A Set is a mapping of segments with non-overlapping Range keys. The zero +// value for a Set is an empty set. Set values are not safely movable nor +// copyable. Set is thread-compatible. +// +// +stateify savable +type MappingSet struct { + root Mappingnode `state:".(*MappingSegmentDataSlices)"` +} + +// IsEmpty returns true if the set contains no segments. +func (s *MappingSet) IsEmpty() bool { + return s.root.nrSegments == 0 +} + +// IsEmptyRange returns true iff no segments in the set overlap the given +// range. This is semantically equivalent to s.SpanRange(r) == 0, but may be +// more efficient. +func (s *MappingSet) IsEmptyRange(r MappableRange) bool { + switch { + case r.Length() < 0: + panic(fmt.Sprintf("invalid range %v", r)) + case r.Length() == 0: + return true + } + _, gap := s.Find(r.Start) + if !gap.Ok() { + return false + } + return r.End <= gap.End() +} + +// Span returns the total size of all segments in the set. +func (s *MappingSet) Span() uint64 { + var sz uint64 + for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + sz += seg.Range().Length() + } + return sz +} + +// SpanRange returns the total size of the intersection of segments in the set +// with the given range. +func (s *MappingSet) SpanRange(r MappableRange) uint64 { + switch { + case r.Length() < 0: + panic(fmt.Sprintf("invalid range %v", r)) + case r.Length() == 0: + return 0 + } + var sz uint64 + for seg := s.LowerBoundSegment(r.Start); seg.Ok() && seg.Start() < r.End; seg = seg.NextSegment() { + sz += seg.Range().Intersect(r).Length() + } + return sz +} + +// FirstSegment returns the first segment in the set. If the set is empty, +// FirstSegment returns a terminal iterator. +func (s *MappingSet) FirstSegment() MappingIterator { + if s.root.nrSegments == 0 { + return MappingIterator{} + } + return s.root.firstSegment() +} + +// LastSegment returns the last segment in the set. If the set is empty, +// LastSegment returns a terminal iterator. +func (s *MappingSet) LastSegment() MappingIterator { + if s.root.nrSegments == 0 { + return MappingIterator{} + } + return s.root.lastSegment() +} + +// FirstGap returns the first gap in the set. +func (s *MappingSet) FirstGap() MappingGapIterator { + n := &s.root + for n.hasChildren { + n = n.children[0] + } + return MappingGapIterator{n, 0} +} + +// LastGap returns the last gap in the set. +func (s *MappingSet) LastGap() MappingGapIterator { + n := &s.root + for n.hasChildren { + n = n.children[n.nrSegments] + } + return MappingGapIterator{n, n.nrSegments} +} + +// Find returns the segment or gap whose range contains the given key. If a +// segment is found, the returned Iterator is non-terminal and the +// returned GapIterator is terminal. Otherwise, the returned Iterator is +// terminal and the returned GapIterator is non-terminal. +func (s *MappingSet) Find(key uint64) (MappingIterator, MappingGapIterator) { + n := &s.root + for { + + lower := 0 + upper := n.nrSegments + for lower < upper { + i := lower + (upper-lower)/2 + if r := n.keys[i]; key < r.End { + if key >= r.Start { + return MappingIterator{n, i}, MappingGapIterator{} + } + upper = i + } else { + lower = i + 1 + } + } + i := lower + if !n.hasChildren { + return MappingIterator{}, MappingGapIterator{n, i} + } + n = n.children[i] + } +} + +// FindSegment returns the segment whose range contains the given key. If no +// such segment exists, FindSegment returns a terminal iterator. +func (s *MappingSet) FindSegment(key uint64) MappingIterator { + seg, _ := s.Find(key) + return seg +} + +// LowerBoundSegment returns the segment with the lowest range that contains a +// key greater than or equal to min. If no such segment exists, +// LowerBoundSegment returns a terminal iterator. +func (s *MappingSet) LowerBoundSegment(min uint64) MappingIterator { + seg, gap := s.Find(min) + if seg.Ok() { + return seg + } + return gap.NextSegment() +} + +// UpperBoundSegment returns the segment with the highest range that contains a +// key less than or equal to max. If no such segment exists, UpperBoundSegment +// returns a terminal iterator. +func (s *MappingSet) UpperBoundSegment(max uint64) MappingIterator { + seg, gap := s.Find(max) + if seg.Ok() { + return seg + } + return gap.PrevSegment() +} + +// FindGap returns the gap containing the given key. If no such gap exists +// (i.e. the set contains a segment containing that key), FindGap returns a +// terminal iterator. +func (s *MappingSet) FindGap(key uint64) MappingGapIterator { + _, gap := s.Find(key) + return gap +} + +// LowerBoundGap returns the gap with the lowest range that is greater than or +// equal to min. +func (s *MappingSet) LowerBoundGap(min uint64) MappingGapIterator { + seg, gap := s.Find(min) + if gap.Ok() { + return gap + } + return seg.NextGap() +} + +// UpperBoundGap returns the gap with the highest range that is less than or +// equal to max. +func (s *MappingSet) UpperBoundGap(max uint64) MappingGapIterator { + seg, gap := s.Find(max) + if gap.Ok() { + return gap + } + return seg.PrevGap() +} + +// Add inserts the given segment into the set and returns true. If the new +// segment can be merged with adjacent segments, Add will do so. If the new +// segment would overlap an existing segment, Add returns false. If Add +// succeeds, all existing iterators are invalidated. +func (s *MappingSet) Add(r MappableRange, val MappingsOfRange) bool { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + gap := s.FindGap(r.Start) + if !gap.Ok() { + return false + } + if r.End > gap.End() { + return false + } + s.Insert(gap, r, val) + return true +} + +// AddWithoutMerging inserts the given segment into the set and returns true. +// If it would overlap an existing segment, AddWithoutMerging does nothing and +// returns false. If AddWithoutMerging succeeds, all existing iterators are +// invalidated. +func (s *MappingSet) AddWithoutMerging(r MappableRange, val MappingsOfRange) bool { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + gap := s.FindGap(r.Start) + if !gap.Ok() { + return false + } + if r.End > gap.End() { + return false + } + s.InsertWithoutMergingUnchecked(gap, r, val) + return true +} + +// Insert inserts the given segment into the given gap. If the new segment can +// be merged with adjacent segments, Insert will do so. Insert returns an +// iterator to the segment containing the inserted value (which may have been +// merged with other values). All existing iterators (including gap, but not +// including the returned iterator) are invalidated. +// +// If the gap cannot accommodate the segment, or if r is invalid, Insert panics. +// +// Insert is semantically equivalent to a InsertWithoutMerging followed by a +// Merge, but may be more efficient. Note that there is no unchecked variant of +// Insert since Insert must retrieve and inspect gap's predecessor and +// successor segments regardless. +func (s *MappingSet) Insert(gap MappingGapIterator, r MappableRange, val MappingsOfRange) MappingIterator { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + prev, next := gap.PrevSegment(), gap.NextSegment() + if prev.Ok() && prev.End() > r.Start { + panic(fmt.Sprintf("new segment %v overlaps predecessor %v", r, prev.Range())) + } + if next.Ok() && next.Start() < r.End { + panic(fmt.Sprintf("new segment %v overlaps successor %v", r, next.Range())) + } + if prev.Ok() && prev.End() == r.Start { + if mval, ok := (mappingSetFunctions{}).Merge(prev.Range(), prev.Value(), r, val); ok { + shrinkMaxGap := MappingtrackGaps != 0 && gap.Range().Length() == gap.node.maxGap.Get() + prev.SetEndUnchecked(r.End) + prev.SetValue(mval) + if shrinkMaxGap { + gap.node.updateMaxGapLeaf() + } + if next.Ok() && next.Start() == r.End { + val = mval + if mval, ok := (mappingSetFunctions{}).Merge(prev.Range(), val, next.Range(), next.Value()); ok { + prev.SetEndUnchecked(next.End()) + prev.SetValue(mval) + return s.Remove(next).PrevSegment() + } + } + return prev + } + } + if next.Ok() && next.Start() == r.End { + if mval, ok := (mappingSetFunctions{}).Merge(r, val, next.Range(), next.Value()); ok { + shrinkMaxGap := MappingtrackGaps != 0 && gap.Range().Length() == gap.node.maxGap.Get() + next.SetStartUnchecked(r.Start) + next.SetValue(mval) + if shrinkMaxGap { + gap.node.updateMaxGapLeaf() + } + return next + } + } + + return s.InsertWithoutMergingUnchecked(gap, r, val) +} + +// InsertWithoutMerging inserts the given segment into the given gap and +// returns an iterator to the inserted segment. All existing iterators +// (including gap, but not including the returned iterator) are invalidated. +// +// If the gap cannot accommodate the segment, or if r is invalid, +// InsertWithoutMerging panics. +func (s *MappingSet) InsertWithoutMerging(gap MappingGapIterator, r MappableRange, val MappingsOfRange) MappingIterator { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + if gr := gap.Range(); !gr.IsSupersetOf(r) { + panic(fmt.Sprintf("cannot insert segment range %v into gap range %v", r, gr)) + } + return s.InsertWithoutMergingUnchecked(gap, r, val) +} + +// InsertWithoutMergingUnchecked inserts the given segment into the given gap +// and returns an iterator to the inserted segment. All existing iterators +// (including gap, but not including the returned iterator) are invalidated. +// +// Preconditions: +// * r.Start >= gap.Start(). +// * r.End <= gap.End(). +func (s *MappingSet) InsertWithoutMergingUnchecked(gap MappingGapIterator, r MappableRange, val MappingsOfRange) MappingIterator { + gap = gap.node.rebalanceBeforeInsert(gap) + splitMaxGap := MappingtrackGaps != 0 && (gap.node.nrSegments == 0 || gap.Range().Length() == gap.node.maxGap.Get()) + copy(gap.node.keys[gap.index+1:], gap.node.keys[gap.index:gap.node.nrSegments]) + copy(gap.node.values[gap.index+1:], gap.node.values[gap.index:gap.node.nrSegments]) + gap.node.keys[gap.index] = r + gap.node.values[gap.index] = val + gap.node.nrSegments++ + if splitMaxGap { + gap.node.updateMaxGapLeaf() + } + return MappingIterator{gap.node, gap.index} +} + +// Remove removes the given segment and returns an iterator to the vacated gap. +// All existing iterators (including seg, but not including the returned +// iterator) are invalidated. +func (s *MappingSet) Remove(seg MappingIterator) MappingGapIterator { + + if seg.node.hasChildren { + + victim := seg.PrevSegment() + + seg.SetRangeUnchecked(victim.Range()) + seg.SetValue(victim.Value()) + + nextAdjacentNode := seg.NextSegment().node + if MappingtrackGaps != 0 { + nextAdjacentNode.updateMaxGapLeaf() + } + return s.Remove(victim).NextGap() + } + copy(seg.node.keys[seg.index:], seg.node.keys[seg.index+1:seg.node.nrSegments]) + copy(seg.node.values[seg.index:], seg.node.values[seg.index+1:seg.node.nrSegments]) + mappingSetFunctions{}.ClearValue(&seg.node.values[seg.node.nrSegments-1]) + seg.node.nrSegments-- + if MappingtrackGaps != 0 { + seg.node.updateMaxGapLeaf() + } + return seg.node.rebalanceAfterRemove(MappingGapIterator{seg.node, seg.index}) +} + +// RemoveAll removes all segments from the set. All existing iterators are +// invalidated. +func (s *MappingSet) RemoveAll() { + s.root = Mappingnode{} +} + +// RemoveRange removes all segments in the given range. An iterator to the +// newly formed gap is returned, and all existing iterators are invalidated. +func (s *MappingSet) RemoveRange(r MappableRange) MappingGapIterator { + seg, gap := s.Find(r.Start) + if seg.Ok() { + seg = s.Isolate(seg, r) + gap = s.Remove(seg) + } + for seg = gap.NextSegment(); seg.Ok() && seg.Start() < r.End; seg = gap.NextSegment() { + seg = s.Isolate(seg, r) + gap = s.Remove(seg) + } + return gap +} + +// Merge attempts to merge two neighboring segments. If successful, Merge +// returns an iterator to the merged segment, and all existing iterators are +// invalidated. Otherwise, Merge returns a terminal iterator. +// +// If first is not the predecessor of second, Merge panics. +func (s *MappingSet) Merge(first, second MappingIterator) MappingIterator { + if first.NextSegment() != second { + panic(fmt.Sprintf("attempt to merge non-neighboring segments %v, %v", first.Range(), second.Range())) + } + return s.MergeUnchecked(first, second) +} + +// MergeUnchecked attempts to merge two neighboring segments. If successful, +// MergeUnchecked returns an iterator to the merged segment, and all existing +// iterators are invalidated. Otherwise, MergeUnchecked returns a terminal +// iterator. +// +// Precondition: first is the predecessor of second: first.NextSegment() == +// second, first == second.PrevSegment(). +func (s *MappingSet) MergeUnchecked(first, second MappingIterator) MappingIterator { + if first.End() == second.Start() { + if mval, ok := (mappingSetFunctions{}).Merge(first.Range(), first.Value(), second.Range(), second.Value()); ok { + + first.SetEndUnchecked(second.End()) + first.SetValue(mval) + + return s.Remove(second).PrevSegment() + } + } + return MappingIterator{} +} + +// MergeAll attempts to merge all adjacent segments in the set. All existing +// iterators are invalidated. +func (s *MappingSet) MergeAll() { + seg := s.FirstSegment() + if !seg.Ok() { + return + } + next := seg.NextSegment() + for next.Ok() { + if mseg := s.MergeUnchecked(seg, next); mseg.Ok() { + seg, next = mseg, mseg.NextSegment() + } else { + seg, next = next, next.NextSegment() + } + } +} + +// MergeRange attempts to merge all adjacent segments that contain a key in the +// specific range. All existing iterators are invalidated. +func (s *MappingSet) MergeRange(r MappableRange) { + seg := s.LowerBoundSegment(r.Start) + if !seg.Ok() { + return + } + next := seg.NextSegment() + for next.Ok() && next.Range().Start < r.End { + if mseg := s.MergeUnchecked(seg, next); mseg.Ok() { + seg, next = mseg, mseg.NextSegment() + } else { + seg, next = next, next.NextSegment() + } + } +} + +// MergeAdjacent attempts to merge the segment containing r.Start with its +// predecessor, and the segment containing r.End-1 with its successor. +func (s *MappingSet) MergeAdjacent(r MappableRange) { + first := s.FindSegment(r.Start) + if first.Ok() { + if prev := first.PrevSegment(); prev.Ok() { + s.Merge(prev, first) + } + } + last := s.FindSegment(r.End - 1) + if last.Ok() { + if next := last.NextSegment(); next.Ok() { + s.Merge(last, next) + } + } +} + +// Split splits the given segment at the given key and returns iterators to the +// two resulting segments. All existing iterators (including seg, but not +// including the returned iterators) are invalidated. +// +// If the segment cannot be split at split (because split is at the start or +// end of the segment's range, so splitting would produce a segment with zero +// length, or because split falls outside the segment's range altogether), +// Split panics. +func (s *MappingSet) Split(seg MappingIterator, split uint64) (MappingIterator, MappingIterator) { + if !seg.Range().CanSplitAt(split) { + panic(fmt.Sprintf("can't split %v at %v", seg.Range(), split)) + } + return s.SplitUnchecked(seg, split) +} + +// SplitUnchecked splits the given segment at the given key and returns +// iterators to the two resulting segments. All existing iterators (including +// seg, but not including the returned iterators) are invalidated. +// +// Preconditions: seg.Start() < key < seg.End(). +func (s *MappingSet) SplitUnchecked(seg MappingIterator, split uint64) (MappingIterator, MappingIterator) { + val1, val2 := (mappingSetFunctions{}).Split(seg.Range(), seg.Value(), split) + end2 := seg.End() + seg.SetEndUnchecked(split) + seg.SetValue(val1) + seg2 := s.InsertWithoutMergingUnchecked(seg.NextGap(), MappableRange{split, end2}, val2) + + return seg2.PrevSegment(), seg2 +} + +// SplitAt splits the segment straddling split, if one exists. SplitAt returns +// true if a segment was split and false otherwise. If SplitAt splits a +// segment, all existing iterators are invalidated. +func (s *MappingSet) SplitAt(split uint64) bool { + if seg := s.FindSegment(split); seg.Ok() && seg.Range().CanSplitAt(split) { + s.SplitUnchecked(seg, split) + return true + } + return false +} + +// Isolate ensures that the given segment's range does not escape r by +// splitting at r.Start and r.End if necessary, and returns an updated iterator +// to the bounded segment. All existing iterators (including seg, but not +// including the returned iterators) are invalidated. +func (s *MappingSet) Isolate(seg MappingIterator, r MappableRange) MappingIterator { + if seg.Range().CanSplitAt(r.Start) { + _, seg = s.SplitUnchecked(seg, r.Start) + } + if seg.Range().CanSplitAt(r.End) { + seg, _ = s.SplitUnchecked(seg, r.End) + } + return seg +} + +// ApplyContiguous applies a function to a contiguous range of segments, +// splitting if necessary. The function is applied until the first gap is +// encountered, at which point the gap is returned. If the function is applied +// across the entire range, a terminal gap is returned. All existing iterators +// are invalidated. +// +// N.B. The Iterator must not be invalidated by the function. +func (s *MappingSet) ApplyContiguous(r MappableRange, fn func(seg MappingIterator)) MappingGapIterator { + seg, gap := s.Find(r.Start) + if !seg.Ok() { + return gap + } + for { + seg = s.Isolate(seg, r) + fn(seg) + if seg.End() >= r.End { + return MappingGapIterator{} + } + gap = seg.NextGap() + if !gap.IsEmpty() { + return gap + } + seg = gap.NextSegment() + if !seg.Ok() { + + return MappingGapIterator{} + } + } +} + +// +stateify savable +type Mappingnode struct { + // An internal binary tree node looks like: + // + // K + // / \ + // Cl Cr + // + // where all keys in the subtree rooted by Cl (the left subtree) are less + // than K (the key of the parent node), and all keys in the subtree rooted + // by Cr (the right subtree) are greater than K. + // + // An internal B-tree node's indexes work out to look like: + // + // K0 K1 K2 ... Kn-1 + // / \/ \/ \ ... / \ + // C0 C1 C2 C3 ... Cn-1 Cn + // + // where n is nrSegments. + nrSegments int + + // parent is a pointer to this node's parent. If this node is root, parent + // is nil. + parent *Mappingnode + + // parentIndex is the index of this node in parent.children. + parentIndex int + + // Flag for internal nodes that is technically redundant with "children[0] + // != nil", but is stored in the first cache line. "hasChildren" rather + // than "isLeaf" because false must be the correct value for an empty root. + hasChildren bool + + // The longest gap within this node. If the node is a leaf, it's simply the + // maximum gap among all the (nrSegments+1) gaps formed by its nrSegments keys + // including the 0th and nrSegments-th gap possibly shared with its upper-level + // nodes; if it's a non-leaf node, it's the max of all children's maxGap. + maxGap MappingdynamicGap + + // Nodes store keys and values in separate arrays to maximize locality in + // the common case (scanning keys for lookup). + keys [MappingmaxDegree - 1]MappableRange + values [MappingmaxDegree - 1]MappingsOfRange + children [MappingmaxDegree]*Mappingnode +} + +// firstSegment returns the first segment in the subtree rooted by n. +// +// Preconditions: n.nrSegments != 0. +func (n *Mappingnode) firstSegment() MappingIterator { + for n.hasChildren { + n = n.children[0] + } + return MappingIterator{n, 0} +} + +// lastSegment returns the last segment in the subtree rooted by n. +// +// Preconditions: n.nrSegments != 0. +func (n *Mappingnode) lastSegment() MappingIterator { + for n.hasChildren { + n = n.children[n.nrSegments] + } + return MappingIterator{n, n.nrSegments - 1} +} + +func (n *Mappingnode) prevSibling() *Mappingnode { + if n.parent == nil || n.parentIndex == 0 { + return nil + } + return n.parent.children[n.parentIndex-1] +} + +func (n *Mappingnode) nextSibling() *Mappingnode { + if n.parent == nil || n.parentIndex == n.parent.nrSegments { + return nil + } + return n.parent.children[n.parentIndex+1] +} + +// rebalanceBeforeInsert splits n and its ancestors if they are full, as +// required for insertion, and returns an updated iterator to the position +// represented by gap. +func (n *Mappingnode) rebalanceBeforeInsert(gap MappingGapIterator) MappingGapIterator { + if n.nrSegments < MappingmaxDegree-1 { + return gap + } + if n.parent != nil { + gap = n.parent.rebalanceBeforeInsert(gap) + } + if n.parent == nil { + + left := &Mappingnode{ + nrSegments: MappingminDegree - 1, + parent: n, + parentIndex: 0, + hasChildren: n.hasChildren, + } + right := &Mappingnode{ + nrSegments: MappingminDegree - 1, + parent: n, + parentIndex: 1, + hasChildren: n.hasChildren, + } + copy(left.keys[:MappingminDegree-1], n.keys[:MappingminDegree-1]) + copy(left.values[:MappingminDegree-1], n.values[:MappingminDegree-1]) + copy(right.keys[:MappingminDegree-1], n.keys[MappingminDegree:]) + copy(right.values[:MappingminDegree-1], n.values[MappingminDegree:]) + n.keys[0], n.values[0] = n.keys[MappingminDegree-1], n.values[MappingminDegree-1] + MappingzeroValueSlice(n.values[1:]) + if n.hasChildren { + copy(left.children[:MappingminDegree], n.children[:MappingminDegree]) + copy(right.children[:MappingminDegree], n.children[MappingminDegree:]) + MappingzeroNodeSlice(n.children[2:]) + for i := 0; i < MappingminDegree; i++ { + left.children[i].parent = left + left.children[i].parentIndex = i + right.children[i].parent = right + right.children[i].parentIndex = i + } + } + n.nrSegments = 1 + n.hasChildren = true + n.children[0] = left + n.children[1] = right + + if MappingtrackGaps != 0 { + left.updateMaxGapLocal() + right.updateMaxGapLocal() + } + if gap.node != n { + return gap + } + if gap.index < MappingminDegree { + return MappingGapIterator{left, gap.index} + } + return MappingGapIterator{right, gap.index - MappingminDegree} + } + + copy(n.parent.keys[n.parentIndex+1:], n.parent.keys[n.parentIndex:n.parent.nrSegments]) + copy(n.parent.values[n.parentIndex+1:], n.parent.values[n.parentIndex:n.parent.nrSegments]) + n.parent.keys[n.parentIndex], n.parent.values[n.parentIndex] = n.keys[MappingminDegree-1], n.values[MappingminDegree-1] + copy(n.parent.children[n.parentIndex+2:], n.parent.children[n.parentIndex+1:n.parent.nrSegments+1]) + for i := n.parentIndex + 2; i < n.parent.nrSegments+2; i++ { + n.parent.children[i].parentIndex = i + } + sibling := &Mappingnode{ + nrSegments: MappingminDegree - 1, + parent: n.parent, + parentIndex: n.parentIndex + 1, + hasChildren: n.hasChildren, + } + n.parent.children[n.parentIndex+1] = sibling + n.parent.nrSegments++ + copy(sibling.keys[:MappingminDegree-1], n.keys[MappingminDegree:]) + copy(sibling.values[:MappingminDegree-1], n.values[MappingminDegree:]) + MappingzeroValueSlice(n.values[MappingminDegree-1:]) + if n.hasChildren { + copy(sibling.children[:MappingminDegree], n.children[MappingminDegree:]) + MappingzeroNodeSlice(n.children[MappingminDegree:]) + for i := 0; i < MappingminDegree; i++ { + sibling.children[i].parent = sibling + sibling.children[i].parentIndex = i + } + } + n.nrSegments = MappingminDegree - 1 + + if MappingtrackGaps != 0 { + n.updateMaxGapLocal() + sibling.updateMaxGapLocal() + } + + if gap.node != n { + return gap + } + if gap.index < MappingminDegree { + return gap + } + return MappingGapIterator{sibling, gap.index - MappingminDegree} +} + +// rebalanceAfterRemove "unsplits" n and its ancestors if they are deficient +// (contain fewer segments than required by B-tree invariants), as required for +// removal, and returns an updated iterator to the position represented by gap. +// +// Precondition: n is the only node in the tree that may currently violate a +// B-tree invariant. +func (n *Mappingnode) rebalanceAfterRemove(gap MappingGapIterator) MappingGapIterator { + for { + if n.nrSegments >= MappingminDegree-1 { + return gap + } + if n.parent == nil { + + return gap + } + + if sibling := n.prevSibling(); sibling != nil && sibling.nrSegments >= MappingminDegree { + copy(n.keys[1:], n.keys[:n.nrSegments]) + copy(n.values[1:], n.values[:n.nrSegments]) + n.keys[0] = n.parent.keys[n.parentIndex-1] + n.values[0] = n.parent.values[n.parentIndex-1] + n.parent.keys[n.parentIndex-1] = sibling.keys[sibling.nrSegments-1] + n.parent.values[n.parentIndex-1] = sibling.values[sibling.nrSegments-1] + mappingSetFunctions{}.ClearValue(&sibling.values[sibling.nrSegments-1]) + if n.hasChildren { + copy(n.children[1:], n.children[:n.nrSegments+1]) + n.children[0] = sibling.children[sibling.nrSegments] + sibling.children[sibling.nrSegments] = nil + n.children[0].parent = n + n.children[0].parentIndex = 0 + for i := 1; i < n.nrSegments+2; i++ { + n.children[i].parentIndex = i + } + } + n.nrSegments++ + sibling.nrSegments-- + + if MappingtrackGaps != 0 { + n.updateMaxGapLocal() + sibling.updateMaxGapLocal() + } + if gap.node == sibling && gap.index == sibling.nrSegments { + return MappingGapIterator{n, 0} + } + if gap.node == n { + return MappingGapIterator{n, gap.index + 1} + } + return gap + } + if sibling := n.nextSibling(); sibling != nil && sibling.nrSegments >= MappingminDegree { + n.keys[n.nrSegments] = n.parent.keys[n.parentIndex] + n.values[n.nrSegments] = n.parent.values[n.parentIndex] + n.parent.keys[n.parentIndex] = sibling.keys[0] + n.parent.values[n.parentIndex] = sibling.values[0] + copy(sibling.keys[:sibling.nrSegments-1], sibling.keys[1:]) + copy(sibling.values[:sibling.nrSegments-1], sibling.values[1:]) + mappingSetFunctions{}.ClearValue(&sibling.values[sibling.nrSegments-1]) + if n.hasChildren { + n.children[n.nrSegments+1] = sibling.children[0] + copy(sibling.children[:sibling.nrSegments], sibling.children[1:]) + sibling.children[sibling.nrSegments] = nil + n.children[n.nrSegments+1].parent = n + n.children[n.nrSegments+1].parentIndex = n.nrSegments + 1 + for i := 0; i < sibling.nrSegments; i++ { + sibling.children[i].parentIndex = i + } + } + n.nrSegments++ + sibling.nrSegments-- + + if MappingtrackGaps != 0 { + n.updateMaxGapLocal() + sibling.updateMaxGapLocal() + } + if gap.node == sibling { + if gap.index == 0 { + return MappingGapIterator{n, n.nrSegments} + } + return MappingGapIterator{sibling, gap.index - 1} + } + return gap + } + + p := n.parent + if p.nrSegments == 1 { + + left, right := p.children[0], p.children[1] + p.nrSegments = left.nrSegments + right.nrSegments + 1 + p.hasChildren = left.hasChildren + p.keys[left.nrSegments] = p.keys[0] + p.values[left.nrSegments] = p.values[0] + copy(p.keys[:left.nrSegments], left.keys[:left.nrSegments]) + copy(p.values[:left.nrSegments], left.values[:left.nrSegments]) + copy(p.keys[left.nrSegments+1:], right.keys[:right.nrSegments]) + copy(p.values[left.nrSegments+1:], right.values[:right.nrSegments]) + if left.hasChildren { + copy(p.children[:left.nrSegments+1], left.children[:left.nrSegments+1]) + copy(p.children[left.nrSegments+1:], right.children[:right.nrSegments+1]) + for i := 0; i < p.nrSegments+1; i++ { + p.children[i].parent = p + p.children[i].parentIndex = i + } + } else { + p.children[0] = nil + p.children[1] = nil + } + + if gap.node == left { + return MappingGapIterator{p, gap.index} + } + if gap.node == right { + return MappingGapIterator{p, gap.index + left.nrSegments + 1} + } + return gap + } + // Merge n and either sibling, along with the segment separating the + // two, into whichever of the two nodes comes first. This is the + // reverse of the non-root splitting case in + // node.rebalanceBeforeInsert. + var left, right *Mappingnode + if n.parentIndex > 0 { + left = n.prevSibling() + right = n + } else { + left = n + right = n.nextSibling() + } + + if gap.node == right { + gap = MappingGapIterator{left, gap.index + left.nrSegments + 1} + } + left.keys[left.nrSegments] = p.keys[left.parentIndex] + left.values[left.nrSegments] = p.values[left.parentIndex] + copy(left.keys[left.nrSegments+1:], right.keys[:right.nrSegments]) + copy(left.values[left.nrSegments+1:], right.values[:right.nrSegments]) + if left.hasChildren { + copy(left.children[left.nrSegments+1:], right.children[:right.nrSegments+1]) + for i := left.nrSegments + 1; i < left.nrSegments+right.nrSegments+2; i++ { + left.children[i].parent = left + left.children[i].parentIndex = i + } + } + left.nrSegments += right.nrSegments + 1 + copy(p.keys[left.parentIndex:], p.keys[left.parentIndex+1:p.nrSegments]) + copy(p.values[left.parentIndex:], p.values[left.parentIndex+1:p.nrSegments]) + mappingSetFunctions{}.ClearValue(&p.values[p.nrSegments-1]) + copy(p.children[left.parentIndex+1:], p.children[left.parentIndex+2:p.nrSegments+1]) + for i := 0; i < p.nrSegments; i++ { + p.children[i].parentIndex = i + } + p.children[p.nrSegments] = nil + p.nrSegments-- + + if MappingtrackGaps != 0 { + left.updateMaxGapLocal() + } + + n = p + } +} + +// updateMaxGapLeaf updates maxGap bottom-up from the calling leaf until no +// necessary update. +// +// Preconditions: n must be a leaf node, trackGaps must be 1. +func (n *Mappingnode) updateMaxGapLeaf() { + if n.hasChildren { + panic(fmt.Sprintf("updateMaxGapLeaf should always be called on leaf node: %v", n)) + } + max := n.calculateMaxGapLeaf() + if max == n.maxGap.Get() { + + return + } + oldMax := n.maxGap.Get() + n.maxGap.Set(max) + if max > oldMax { + + for p := n.parent; p != nil; p = p.parent { + if p.maxGap.Get() >= max { + + break + } + + p.maxGap.Set(max) + } + return + } + + for p := n.parent; p != nil; p = p.parent { + if p.maxGap.Get() > oldMax { + + break + } + + parentNewMax := p.calculateMaxGapInternal() + if p.maxGap.Get() == parentNewMax { + + break + } + + p.maxGap.Set(parentNewMax) + } +} + +// updateMaxGapLocal updates maxGap of the calling node solely with no +// propagation to ancestor nodes. +// +// Precondition: trackGaps must be 1. +func (n *Mappingnode) updateMaxGapLocal() { + if !n.hasChildren { + + n.maxGap.Set(n.calculateMaxGapLeaf()) + } else { + + n.maxGap.Set(n.calculateMaxGapInternal()) + } +} + +// calculateMaxGapLeaf iterates the gaps within a leaf node and calculate the +// max. +// +// Preconditions: n must be a leaf node. +func (n *Mappingnode) calculateMaxGapLeaf() uint64 { + max := MappingGapIterator{n, 0}.Range().Length() + for i := 1; i <= n.nrSegments; i++ { + if current := (MappingGapIterator{n, i}).Range().Length(); current > max { + max = current + } + } + return max +} + +// calculateMaxGapInternal iterates children's maxGap within an internal node n +// and calculate the max. +// +// Preconditions: n must be a non-leaf node. +func (n *Mappingnode) calculateMaxGapInternal() uint64 { + max := n.children[0].maxGap.Get() + for i := 1; i <= n.nrSegments; i++ { + if current := n.children[i].maxGap.Get(); current > max { + max = current + } + } + return max +} + +// searchFirstLargeEnoughGap returns the first gap having at least minSize length +// in the subtree rooted by n. If not found, return a terminal gap iterator. +func (n *Mappingnode) searchFirstLargeEnoughGap(minSize uint64) MappingGapIterator { + if n.maxGap.Get() < minSize { + return MappingGapIterator{} + } + if n.hasChildren { + for i := 0; i <= n.nrSegments; i++ { + if largeEnoughGap := n.children[i].searchFirstLargeEnoughGap(minSize); largeEnoughGap.Ok() { + return largeEnoughGap + } + } + } else { + for i := 0; i <= n.nrSegments; i++ { + currentGap := MappingGapIterator{n, i} + if currentGap.Range().Length() >= minSize { + return currentGap + } + } + } + panic(fmt.Sprintf("invalid maxGap in %v", n)) +} + +// searchLastLargeEnoughGap returns the last gap having at least minSize length +// in the subtree rooted by n. If not found, return a terminal gap iterator. +func (n *Mappingnode) searchLastLargeEnoughGap(minSize uint64) MappingGapIterator { + if n.maxGap.Get() < minSize { + return MappingGapIterator{} + } + if n.hasChildren { + for i := n.nrSegments; i >= 0; i-- { + if largeEnoughGap := n.children[i].searchLastLargeEnoughGap(minSize); largeEnoughGap.Ok() { + return largeEnoughGap + } + } + } else { + for i := n.nrSegments; i >= 0; i-- { + currentGap := MappingGapIterator{n, i} + if currentGap.Range().Length() >= minSize { + return currentGap + } + } + } + panic(fmt.Sprintf("invalid maxGap in %v", n)) +} + +// A Iterator is conceptually one of: +// +// - A pointer to a segment in a set; or +// +// - A terminal iterator, which is a sentinel indicating that the end of +// iteration has been reached. +// +// Iterators are copyable values and are meaningfully equality-comparable. The +// zero value of Iterator is a terminal iterator. +// +// Unless otherwise specified, any mutation of a set invalidates all existing +// iterators into the set. +type MappingIterator struct { + // node is the node containing the iterated segment. If the iterator is + // terminal, node is nil. + node *Mappingnode + + // index is the index of the segment in node.keys/values. + index int +} + +// Ok returns true if the iterator is not terminal. All other methods are only +// valid for non-terminal iterators. +func (seg MappingIterator) Ok() bool { + return seg.node != nil +} + +// Range returns the iterated segment's range key. +func (seg MappingIterator) Range() MappableRange { + return seg.node.keys[seg.index] +} + +// Start is equivalent to Range().Start, but should be preferred if only the +// start of the range is needed. +func (seg MappingIterator) Start() uint64 { + return seg.node.keys[seg.index].Start +} + +// End is equivalent to Range().End, but should be preferred if only the end of +// the range is needed. +func (seg MappingIterator) End() uint64 { + return seg.node.keys[seg.index].End +} + +// SetRangeUnchecked mutates the iterated segment's range key. This operation +// does not invalidate any iterators. +// +// Preconditions: +// * r.Length() > 0. +// * The new range must not overlap an existing one: +// * If seg.NextSegment().Ok(), then r.end <= seg.NextSegment().Start(). +// * If seg.PrevSegment().Ok(), then r.start >= seg.PrevSegment().End(). +func (seg MappingIterator) SetRangeUnchecked(r MappableRange) { + seg.node.keys[seg.index] = r +} + +// SetRange mutates the iterated segment's range key. If the new range would +// cause the iterated segment to overlap another segment, or if the new range +// is invalid, SetRange panics. This operation does not invalidate any +// iterators. +func (seg MappingIterator) SetRange(r MappableRange) { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + if prev := seg.PrevSegment(); prev.Ok() && r.Start < prev.End() { + panic(fmt.Sprintf("new segment range %v overlaps segment range %v", r, prev.Range())) + } + if next := seg.NextSegment(); next.Ok() && r.End > next.Start() { + panic(fmt.Sprintf("new segment range %v overlaps segment range %v", r, next.Range())) + } + seg.SetRangeUnchecked(r) +} + +// SetStartUnchecked mutates the iterated segment's start. This operation does +// not invalidate any iterators. +// +// Preconditions: The new start must be valid: +// * start < seg.End() +// * If seg.PrevSegment().Ok(), then start >= seg.PrevSegment().End(). +func (seg MappingIterator) SetStartUnchecked(start uint64) { + seg.node.keys[seg.index].Start = start +} + +// SetStart mutates the iterated segment's start. If the new start value would +// cause the iterated segment to overlap another segment, or would result in an +// invalid range, SetStart panics. This operation does not invalidate any +// iterators. +func (seg MappingIterator) SetStart(start uint64) { + if start >= seg.End() { + panic(fmt.Sprintf("new start %v would invalidate segment range %v", start, seg.Range())) + } + if prev := seg.PrevSegment(); prev.Ok() && start < prev.End() { + panic(fmt.Sprintf("new start %v would cause segment range %v to overlap segment range %v", start, seg.Range(), prev.Range())) + } + seg.SetStartUnchecked(start) +} + +// SetEndUnchecked mutates the iterated segment's end. This operation does not +// invalidate any iterators. +// +// Preconditions: The new end must be valid: +// * end > seg.Start(). +// * If seg.NextSegment().Ok(), then end <= seg.NextSegment().Start(). +func (seg MappingIterator) SetEndUnchecked(end uint64) { + seg.node.keys[seg.index].End = end +} + +// SetEnd mutates the iterated segment's end. If the new end value would cause +// the iterated segment to overlap another segment, or would result in an +// invalid range, SetEnd panics. This operation does not invalidate any +// iterators. +func (seg MappingIterator) SetEnd(end uint64) { + if end <= seg.Start() { + panic(fmt.Sprintf("new end %v would invalidate segment range %v", end, seg.Range())) + } + if next := seg.NextSegment(); next.Ok() && end > next.Start() { + panic(fmt.Sprintf("new end %v would cause segment range %v to overlap segment range %v", end, seg.Range(), next.Range())) + } + seg.SetEndUnchecked(end) +} + +// Value returns a copy of the iterated segment's value. +func (seg MappingIterator) Value() MappingsOfRange { + return seg.node.values[seg.index] +} + +// ValuePtr returns a pointer to the iterated segment's value. The pointer is +// invalidated if the iterator is invalidated. This operation does not +// invalidate any iterators. +func (seg MappingIterator) ValuePtr() *MappingsOfRange { + return &seg.node.values[seg.index] +} + +// SetValue mutates the iterated segment's value. This operation does not +// invalidate any iterators. +func (seg MappingIterator) SetValue(val MappingsOfRange) { + seg.node.values[seg.index] = val +} + +// PrevSegment returns the iterated segment's predecessor. If there is no +// preceding segment, PrevSegment returns a terminal iterator. +func (seg MappingIterator) PrevSegment() MappingIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index].lastSegment() + } + if seg.index > 0 { + return MappingIterator{seg.node, seg.index - 1} + } + if seg.node.parent == nil { + return MappingIterator{} + } + return MappingsegmentBeforePosition(seg.node.parent, seg.node.parentIndex) +} + +// NextSegment returns the iterated segment's successor. If there is no +// succeeding segment, NextSegment returns a terminal iterator. +func (seg MappingIterator) NextSegment() MappingIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index+1].firstSegment() + } + if seg.index < seg.node.nrSegments-1 { + return MappingIterator{seg.node, seg.index + 1} + } + if seg.node.parent == nil { + return MappingIterator{} + } + return MappingsegmentAfterPosition(seg.node.parent, seg.node.parentIndex) +} + +// PrevGap returns the gap immediately before the iterated segment. +func (seg MappingIterator) PrevGap() MappingGapIterator { + if seg.node.hasChildren { + + return seg.node.children[seg.index].lastSegment().NextGap() + } + return MappingGapIterator{seg.node, seg.index} +} + +// NextGap returns the gap immediately after the iterated segment. +func (seg MappingIterator) NextGap() MappingGapIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index+1].firstSegment().PrevGap() + } + return MappingGapIterator{seg.node, seg.index + 1} +} + +// PrevNonEmpty returns the iterated segment's predecessor if it is adjacent, +// or the gap before the iterated segment otherwise. If seg.Start() == +// Functions.MinKey(), PrevNonEmpty will return two terminal iterators. +// Otherwise, exactly one of the iterators returned by PrevNonEmpty will be +// non-terminal. +func (seg MappingIterator) PrevNonEmpty() (MappingIterator, MappingGapIterator) { + gap := seg.PrevGap() + if gap.Range().Length() != 0 { + return MappingIterator{}, gap + } + return gap.PrevSegment(), MappingGapIterator{} +} + +// NextNonEmpty returns the iterated segment's successor if it is adjacent, or +// the gap after the iterated segment otherwise. If seg.End() == +// Functions.MaxKey(), NextNonEmpty will return two terminal iterators. +// Otherwise, exactly one of the iterators returned by NextNonEmpty will be +// non-terminal. +func (seg MappingIterator) NextNonEmpty() (MappingIterator, MappingGapIterator) { + gap := seg.NextGap() + if gap.Range().Length() != 0 { + return MappingIterator{}, gap + } + return gap.NextSegment(), MappingGapIterator{} +} + +// A GapIterator is conceptually one of: +// +// - A pointer to a position between two segments, before the first segment, or +// after the last segment in a set, called a *gap*; or +// +// - A terminal iterator, which is a sentinel indicating that the end of +// iteration has been reached. +// +// Note that the gap between two adjacent segments exists (iterators to it are +// non-terminal), but has a length of zero. GapIterator.IsEmpty returns true +// for such gaps. An empty set contains a single gap, spanning the entire range +// of the set's keys. +// +// GapIterators are copyable values and are meaningfully equality-comparable. +// The zero value of GapIterator is a terminal iterator. +// +// Unless otherwise specified, any mutation of a set invalidates all existing +// iterators into the set. +type MappingGapIterator struct { + // The representation of a GapIterator is identical to that of an Iterator, + // except that index corresponds to positions between segments in the same + // way as for node.children (see comment for node.nrSegments). + node *Mappingnode + index int +} + +// Ok returns true if the iterator is not terminal. All other methods are only +// valid for non-terminal iterators. +func (gap MappingGapIterator) Ok() bool { + return gap.node != nil +} + +// Range returns the range spanned by the iterated gap. +func (gap MappingGapIterator) Range() MappableRange { + return MappableRange{gap.Start(), gap.End()} +} + +// Start is equivalent to Range().Start, but should be preferred if only the +// start of the range is needed. +func (gap MappingGapIterator) Start() uint64 { + if ps := gap.PrevSegment(); ps.Ok() { + return ps.End() + } + return mappingSetFunctions{}.MinKey() +} + +// End is equivalent to Range().End, but should be preferred if only the end of +// the range is needed. +func (gap MappingGapIterator) End() uint64 { + if ns := gap.NextSegment(); ns.Ok() { + return ns.Start() + } + return mappingSetFunctions{}.MaxKey() +} + +// IsEmpty returns true if the iterated gap is empty (that is, the "gap" is +// between two adjacent segments.) +func (gap MappingGapIterator) IsEmpty() bool { + return gap.Range().Length() == 0 +} + +// PrevSegment returns the segment immediately before the iterated gap. If no +// such segment exists, PrevSegment returns a terminal iterator. +func (gap MappingGapIterator) PrevSegment() MappingIterator { + return MappingsegmentBeforePosition(gap.node, gap.index) +} + +// NextSegment returns the segment immediately after the iterated gap. If no +// such segment exists, NextSegment returns a terminal iterator. +func (gap MappingGapIterator) NextSegment() MappingIterator { + return MappingsegmentAfterPosition(gap.node, gap.index) +} + +// PrevGap returns the iterated gap's predecessor. If no such gap exists, +// PrevGap returns a terminal iterator. +func (gap MappingGapIterator) PrevGap() MappingGapIterator { + seg := gap.PrevSegment() + if !seg.Ok() { + return MappingGapIterator{} + } + return seg.PrevGap() +} + +// NextGap returns the iterated gap's successor. If no such gap exists, NextGap +// returns a terminal iterator. +func (gap MappingGapIterator) NextGap() MappingGapIterator { + seg := gap.NextSegment() + if !seg.Ok() { + return MappingGapIterator{} + } + return seg.NextGap() +} + +// NextLargeEnoughGap returns the iterated gap's first next gap with larger +// length than minSize. If not found, return a terminal gap iterator (does NOT +// include this gap itself). +// +// Precondition: trackGaps must be 1. +func (gap MappingGapIterator) NextLargeEnoughGap(minSize uint64) MappingGapIterator { + if MappingtrackGaps != 1 { + panic("set is not tracking gaps") + } + if gap.node != nil && gap.node.hasChildren && gap.index == gap.node.nrSegments { + + gap.node = gap.NextSegment().node + gap.index = 0 + return gap.nextLargeEnoughGapHelper(minSize) + } + return gap.nextLargeEnoughGapHelper(minSize) +} + +// nextLargeEnoughGapHelper is the helper function used by NextLargeEnoughGap +// to do the real recursions. +// +// Preconditions: gap is NOT the trailing gap of a non-leaf node. +func (gap MappingGapIterator) nextLargeEnoughGapHelper(minSize uint64) MappingGapIterator { + + for gap.node != nil && + (gap.node.maxGap.Get() < minSize || (!gap.node.hasChildren && gap.index == gap.node.nrSegments)) { + gap.node, gap.index = gap.node.parent, gap.node.parentIndex + } + + if gap.node == nil { + return MappingGapIterator{} + } + + gap.index++ + for gap.index <= gap.node.nrSegments { + if gap.node.hasChildren { + if largeEnoughGap := gap.node.children[gap.index].searchFirstLargeEnoughGap(minSize); largeEnoughGap.Ok() { + return largeEnoughGap + } + } else { + if gap.Range().Length() >= minSize { + return gap + } + } + gap.index++ + } + gap.node, gap.index = gap.node.parent, gap.node.parentIndex + if gap.node != nil && gap.index == gap.node.nrSegments { + + gap.node, gap.index = gap.node.parent, gap.node.parentIndex + } + return gap.nextLargeEnoughGapHelper(minSize) +} + +// PrevLargeEnoughGap returns the iterated gap's first prev gap with larger or +// equal length than minSize. If not found, return a terminal gap iterator +// (does NOT include this gap itself). +// +// Precondition: trackGaps must be 1. +func (gap MappingGapIterator) PrevLargeEnoughGap(minSize uint64) MappingGapIterator { + if MappingtrackGaps != 1 { + panic("set is not tracking gaps") + } + if gap.node != nil && gap.node.hasChildren && gap.index == 0 { + + gap.node = gap.PrevSegment().node + gap.index = gap.node.nrSegments + return gap.prevLargeEnoughGapHelper(minSize) + } + return gap.prevLargeEnoughGapHelper(minSize) +} + +// prevLargeEnoughGapHelper is the helper function used by PrevLargeEnoughGap +// to do the real recursions. +// +// Preconditions: gap is NOT the first gap of a non-leaf node. +func (gap MappingGapIterator) prevLargeEnoughGapHelper(minSize uint64) MappingGapIterator { + + for gap.node != nil && + (gap.node.maxGap.Get() < minSize || (!gap.node.hasChildren && gap.index == 0)) { + gap.node, gap.index = gap.node.parent, gap.node.parentIndex + } + + if gap.node == nil { + return MappingGapIterator{} + } + + gap.index-- + for gap.index >= 0 { + if gap.node.hasChildren { + if largeEnoughGap := gap.node.children[gap.index].searchLastLargeEnoughGap(minSize); largeEnoughGap.Ok() { + return largeEnoughGap + } + } else { + if gap.Range().Length() >= minSize { + return gap + } + } + gap.index-- + } + gap.node, gap.index = gap.node.parent, gap.node.parentIndex + if gap.node != nil && gap.index == 0 { + + gap.node, gap.index = gap.node.parent, gap.node.parentIndex + } + return gap.prevLargeEnoughGapHelper(minSize) +} + +// segmentBeforePosition returns the predecessor segment of the position given +// by n.children[i], which may or may not contain a child. If no such segment +// exists, segmentBeforePosition returns a terminal iterator. +func MappingsegmentBeforePosition(n *Mappingnode, i int) MappingIterator { + for i == 0 { + if n.parent == nil { + return MappingIterator{} + } + n, i = n.parent, n.parentIndex + } + return MappingIterator{n, i - 1} +} + +// segmentAfterPosition returns the successor segment of the position given by +// n.children[i], which may or may not contain a child. If no such segment +// exists, segmentAfterPosition returns a terminal iterator. +func MappingsegmentAfterPosition(n *Mappingnode, i int) MappingIterator { + for i == n.nrSegments { + if n.parent == nil { + return MappingIterator{} + } + n, i = n.parent, n.parentIndex + } + return MappingIterator{n, i} +} + +func MappingzeroValueSlice(slice []MappingsOfRange) { + + for i := range slice { + mappingSetFunctions{}.ClearValue(&slice[i]) + } +} + +func MappingzeroNodeSlice(slice []*Mappingnode) { + for i := range slice { + slice[i] = nil + } +} + +// String stringifies a Set for debugging. +func (s *MappingSet) String() string { + return s.root.String() +} + +// String stringifies a node (and all of its children) for debugging. +func (n *Mappingnode) String() string { + var buf bytes.Buffer + n.writeDebugString(&buf, "") + return buf.String() +} + +func (n *Mappingnode) writeDebugString(buf *bytes.Buffer, prefix string) { + if n.hasChildren != (n.nrSegments > 0 && n.children[0] != nil) { + buf.WriteString(prefix) + buf.WriteString(fmt.Sprintf("WARNING: inconsistent value of hasChildren: got %v, want %v\n", n.hasChildren, !n.hasChildren)) + } + for i := 0; i < n.nrSegments; i++ { + if child := n.children[i]; child != nil { + cprefix := fmt.Sprintf("%s- % 3d ", prefix, i) + if child.parent != n || child.parentIndex != i { + buf.WriteString(cprefix) + buf.WriteString(fmt.Sprintf("WARNING: inconsistent linkage to parent: got (%p, %d), want (%p, %d)\n", child.parent, child.parentIndex, n, i)) + } + child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, i)) + } + buf.WriteString(prefix) + if n.hasChildren { + if MappingtrackGaps != 0 { + buf.WriteString(fmt.Sprintf("- % 3d: %v => %v, maxGap: %d\n", i, n.keys[i], n.values[i], n.maxGap.Get())) + } else { + buf.WriteString(fmt.Sprintf("- % 3d: %v => %v\n", i, n.keys[i], n.values[i])) + } + } else { + buf.WriteString(fmt.Sprintf("- % 3d: %v => %v\n", i, n.keys[i], n.values[i])) + } + } + if child := n.children[n.nrSegments]; child != nil { + child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, n.nrSegments)) + } +} + +// SegmentDataSlices represents segments from a set as slices of start, end, and +// values. SegmentDataSlices is primarily used as an intermediate representation +// for save/restore and the layout here is optimized for that. +// +// +stateify savable +type MappingSegmentDataSlices struct { + Start []uint64 + End []uint64 + Values []MappingsOfRange +} + +// ExportSortedSlices returns a copy of all segments in the given set, in +// ascending key order. +func (s *MappingSet) ExportSortedSlices() *MappingSegmentDataSlices { + var sds MappingSegmentDataSlices + for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + sds.Start = append(sds.Start, seg.Start()) + sds.End = append(sds.End, seg.End()) + sds.Values = append(sds.Values, seg.Value()) + } + sds.Start = sds.Start[:len(sds.Start):len(sds.Start)] + sds.End = sds.End[:len(sds.End):len(sds.End)] + sds.Values = sds.Values[:len(sds.Values):len(sds.Values)] + return &sds +} + +// ImportSortedSlices initializes the given set from the given slice. +// +// Preconditions: +// * s must be empty. +// * sds must represent a valid set (the segments in sds must have valid +// lengths that do not overlap). +// * The segments in sds must be sorted in ascending key order. +func (s *MappingSet) ImportSortedSlices(sds *MappingSegmentDataSlices) error { + if !s.IsEmpty() { + return fmt.Errorf("cannot import into non-empty set %v", s) + } + gap := s.FirstGap() + for i := range sds.Start { + r := MappableRange{sds.Start[i], sds.End[i]} + if !gap.Range().IsSupersetOf(r) { + return fmt.Errorf("segment overlaps a preceding segment or is incorrectly sorted: [%d, %d) => %v", sds.Start[i], sds.End[i], sds.Values[i]) + } + gap = s.InsertWithoutMerging(gap, r, sds.Values[i]).NextGap() + } + return nil +} + +// segmentTestCheck returns an error if s is incorrectly sorted, does not +// contain exactly expectedSegments segments, or contains a segment which +// fails the passed check. +// +// This should be used only for testing, and has been added to this package for +// templating convenience. +func (s *MappingSet) segmentTestCheck(expectedSegments int, segFunc func(int, MappableRange, MappingsOfRange) error) error { + havePrev := false + prev := uint64(0) + nrSegments := 0 + for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + next := seg.Start() + if havePrev && prev >= next { + return fmt.Errorf("incorrect order: key %d (segment %d) >= key %d (segment %d)", prev, nrSegments-1, next, nrSegments) + } + if segFunc != nil { + if err := segFunc(nrSegments, seg.Range(), seg.Value()); err != nil { + return err + } + } + prev = next + havePrev = true + nrSegments++ + } + if nrSegments != expectedSegments { + return fmt.Errorf("incorrect number of segments: got %d, wanted %d", nrSegments, expectedSegments) + } + return nil +} + +// countSegments counts the number of segments in the set. +// +// Similar to Check, this should only be used for testing. +func (s *MappingSet) countSegments() (segments int) { + for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + segments++ + } + return segments +} +func (s *MappingSet) saveRoot() *MappingSegmentDataSlices { + return s.ExportSortedSlices() +} + +func (s *MappingSet) loadRoot(sds *MappingSegmentDataSlices) { + if err := s.ImportSortedSlices(sds); err != nil { + panic(err) + } +} diff --git a/pkg/sentry/memmap/mapping_set_test.go b/pkg/sentry/memmap/mapping_set_test.go deleted file mode 100644 index 5cb81fde7..000000000 --- a/pkg/sentry/memmap/mapping_set_test.go +++ /dev/null @@ -1,259 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package memmap - -import ( - "gvisor.dev/gvisor/pkg/hostarch" - "reflect" - "testing" -) - -type testMappingSpace struct { - // Ideally we'd store the full ranges that were invalidated, rather - // than individual calls to Invalidate, as they are an implementation - // detail, but this is the simplest way for now. - inv []hostarch.AddrRange -} - -func (n *testMappingSpace) reset() { - n.inv = []hostarch.AddrRange{} -} - -func (n *testMappingSpace) Invalidate(ar hostarch.AddrRange, opts InvalidateOpts) { - n.inv = append(n.inv, ar) -} - -func TestAddRemoveMapping(t *testing.T) { - set := MappingSet{} - ms := &testMappingSpace{} - - mapped := set.AddMapping(ms, hostarch.AddrRange{0x10000, 0x12000}, 0x1000, true) - if got, want := mapped, []MappableRange{{0x1000, 0x3000}}; !reflect.DeepEqual(got, want) { - t.Errorf("AddMapping: got %+v, wanted %+v", got, want) - } - - // Mappings (hostarch.AddrRanges => memmap.MappableRange): - // [0x10000, 0x12000) => [0x1000, 0x3000) - t.Log(&set) - - mapped = set.AddMapping(ms, hostarch.AddrRange{0x20000, 0x21000}, 0x2000, true) - if len(mapped) != 0 { - t.Errorf("AddMapping: got %+v, wanted []", mapped) - } - - // Mappings: - // [0x10000, 0x11000) => [0x1000, 0x2000) - // [0x11000, 0x12000) and [0x20000, 0x21000) => [0x2000, 0x3000) - t.Log(&set) - - mapped = set.AddMapping(ms, hostarch.AddrRange{0x30000, 0x31000}, 0x4000, true) - if got, want := mapped, []MappableRange{{0x4000, 0x5000}}; !reflect.DeepEqual(got, want) { - t.Errorf("AddMapping: got %+v, wanted %+v", got, want) - } - - // Mappings: - // [0x10000, 0x11000) => [0x1000, 0x2000) - // [0x11000, 0x12000) and [0x20000, 0x21000) => [0x2000, 0x3000) - // [0x30000, 0x31000) => [0x4000, 0x5000) - t.Log(&set) - - mapped = set.AddMapping(ms, hostarch.AddrRange{0x12000, 0x15000}, 0x3000, true) - if got, want := mapped, []MappableRange{{0x3000, 0x4000}, {0x5000, 0x6000}}; !reflect.DeepEqual(got, want) { - t.Errorf("AddMapping: got %+v, wanted %+v", got, want) - } - - // Mappings: - // [0x10000, 0x11000) => [0x1000, 0x2000) - // [0x11000, 0x12000) and [0x20000, 0x21000) => [0x2000, 0x3000) - // [0x12000, 0x13000) => [0x3000, 0x4000) - // [0x13000, 0x14000) and [0x30000, 0x31000) => [0x4000, 0x5000) - // [0x14000, 0x15000) => [0x5000, 0x6000) - t.Log(&set) - - unmapped := set.RemoveMapping(ms, hostarch.AddrRange{0x10000, 0x11000}, 0x1000, true) - if got, want := unmapped, []MappableRange{{0x1000, 0x2000}}; !reflect.DeepEqual(got, want) { - t.Errorf("RemoveMapping: got %+v, wanted %+v", got, want) - } - - // Mappings: - // [0x11000, 0x12000) and [0x20000, 0x21000) => [0x2000, 0x3000) - // [0x12000, 0x13000) => [0x3000, 0x4000) - // [0x13000, 0x14000) and [0x30000, 0x31000) => [0x4000, 0x5000) - // [0x14000, 0x15000) => [0x5000, 0x6000) - t.Log(&set) - - unmapped = set.RemoveMapping(ms, hostarch.AddrRange{0x20000, 0x21000}, 0x2000, true) - if len(unmapped) != 0 { - t.Errorf("RemoveMapping: got %+v, wanted []", unmapped) - } - - // Mappings: - // [0x11000, 0x13000) => [0x2000, 0x4000) - // [0x13000, 0x14000) and [0x30000, 0x31000) => [0x4000, 0x5000) - // [0x14000, 0x15000) => [0x5000, 0x6000) - t.Log(&set) - - unmapped = set.RemoveMapping(ms, hostarch.AddrRange{0x11000, 0x15000}, 0x2000, true) - if got, want := unmapped, []MappableRange{{0x2000, 0x4000}, {0x5000, 0x6000}}; !reflect.DeepEqual(got, want) { - t.Errorf("RemoveMapping: got %+v, wanted %+v", got, want) - } - - // Mappings: - // [0x30000, 0x31000) => [0x4000, 0x5000) - t.Log(&set) - - unmapped = set.RemoveMapping(ms, hostarch.AddrRange{0x30000, 0x31000}, 0x4000, true) - if got, want := unmapped, []MappableRange{{0x4000, 0x5000}}; !reflect.DeepEqual(got, want) { - t.Errorf("RemoveMapping: got %+v, wanted %+v", got, want) - } -} - -func TestInvalidateWholeMapping(t *testing.T) { - set := MappingSet{} - ms := &testMappingSpace{} - - set.AddMapping(ms, hostarch.AddrRange{0x10000, 0x11000}, 0, true) - // Mappings: - // [0x10000, 0x11000) => [0, 0x1000) - t.Log(&set) - set.Invalidate(MappableRange{0, 0x1000}, InvalidateOpts{}) - if got, want := ms.inv, []hostarch.AddrRange{{0x10000, 0x11000}}; !reflect.DeepEqual(got, want) { - t.Errorf("Invalidate: got %+v, wanted %+v", got, want) - } -} - -func TestInvalidatePartialMapping(t *testing.T) { - set := MappingSet{} - ms := &testMappingSpace{} - - set.AddMapping(ms, hostarch.AddrRange{0x10000, 0x13000}, 0, true) - // Mappings: - // [0x10000, 0x13000) => [0, 0x3000) - t.Log(&set) - set.Invalidate(MappableRange{0x1000, 0x2000}, InvalidateOpts{}) - if got, want := ms.inv, []hostarch.AddrRange{{0x11000, 0x12000}}; !reflect.DeepEqual(got, want) { - t.Errorf("Invalidate: got %+v, wanted %+v", got, want) - } -} - -func TestInvalidateMultipleMappings(t *testing.T) { - set := MappingSet{} - ms := &testMappingSpace{} - - set.AddMapping(ms, hostarch.AddrRange{0x10000, 0x11000}, 0, true) - set.AddMapping(ms, hostarch.AddrRange{0x20000, 0x21000}, 0x2000, true) - // Mappings: - // [0x10000, 0x11000) => [0, 0x1000) - // [0x12000, 0x13000) => [0x2000, 0x3000) - t.Log(&set) - set.Invalidate(MappableRange{0, 0x3000}, InvalidateOpts{}) - if got, want := ms.inv, []hostarch.AddrRange{{0x10000, 0x11000}, {0x20000, 0x21000}}; !reflect.DeepEqual(got, want) { - t.Errorf("Invalidate: got %+v, wanted %+v", got, want) - } -} - -func TestInvalidateOverlappingMappings(t *testing.T) { - set := MappingSet{} - ms1 := &testMappingSpace{} - ms2 := &testMappingSpace{} - - set.AddMapping(ms1, hostarch.AddrRange{0x10000, 0x12000}, 0, true) - set.AddMapping(ms2, hostarch.AddrRange{0x20000, 0x22000}, 0x1000, true) - // Mappings: - // ms1:[0x10000, 0x12000) => [0, 0x2000) - // ms2:[0x11000, 0x13000) => [0x1000, 0x3000) - t.Log(&set) - set.Invalidate(MappableRange{0x1000, 0x2000}, InvalidateOpts{}) - if got, want := ms1.inv, []hostarch.AddrRange{{0x11000, 0x12000}}; !reflect.DeepEqual(got, want) { - t.Errorf("Invalidate: ms1: got %+v, wanted %+v", got, want) - } - if got, want := ms2.inv, []hostarch.AddrRange{{0x20000, 0x21000}}; !reflect.DeepEqual(got, want) { - t.Errorf("Invalidate: ms1: got %+v, wanted %+v", got, want) - } -} - -func TestMixedWritableMappings(t *testing.T) { - set := MappingSet{} - ms := &testMappingSpace{} - - mapped := set.AddMapping(ms, hostarch.AddrRange{0x10000, 0x12000}, 0x1000, true) - if got, want := mapped, []MappableRange{{0x1000, 0x3000}}; !reflect.DeepEqual(got, want) { - t.Errorf("AddMapping: got %+v, wanted %+v", got, want) - } - - // Mappings: - // [0x10000, 0x12000) writable => [0x1000, 0x3000) - t.Log(&set) - - mapped = set.AddMapping(ms, hostarch.AddrRange{0x20000, 0x22000}, 0x2000, false) - if got, want := mapped, []MappableRange{{0x3000, 0x4000}}; !reflect.DeepEqual(got, want) { - t.Errorf("AddMapping: got %+v, wanted %+v", got, want) - } - - // Mappings: - // [0x10000, 0x11000) writable => [0x1000, 0x2000) - // [0x11000, 0x12000) writable and [0x20000, 0x21000) readonly => [0x2000, 0x3000) - // [0x21000, 0x22000) readonly => [0x3000, 0x4000) - t.Log(&set) - - // Unmap should fail because we specified the readonly map address range, but - // asked to unmap a writable segment. - unmapped := set.RemoveMapping(ms, hostarch.AddrRange{0x20000, 0x21000}, 0x2000, true) - if len(unmapped) != 0 { - t.Errorf("RemoveMapping: got %+v, wanted []", unmapped) - } - - // Readonly mapping removed, but writable mapping still exists in the range, - // so no mappable range fully unmapped. - unmapped = set.RemoveMapping(ms, hostarch.AddrRange{0x20000, 0x21000}, 0x2000, false) - if len(unmapped) != 0 { - t.Errorf("RemoveMapping: got %+v, wanted []", unmapped) - } - - // Mappings: - // [0x10000, 0x12000) writable => [0x1000, 0x3000) - // [0x21000, 0x22000) readonly => [0x3000, 0x4000) - t.Log(&set) - - unmapped = set.RemoveMapping(ms, hostarch.AddrRange{0x11000, 0x12000}, 0x2000, true) - if got, want := unmapped, []MappableRange{{0x2000, 0x3000}}; !reflect.DeepEqual(got, want) { - t.Errorf("RemoveMapping: got %+v, wanted %+v", got, want) - } - - // Mappings: - // [0x10000, 0x12000) writable => [0x1000, 0x3000) - // [0x21000, 0x22000) readonly => [0x3000, 0x4000) - t.Log(&set) - - // Unmap should fail since writable bit doesn't match. - unmapped = set.RemoveMapping(ms, hostarch.AddrRange{0x10000, 0x12000}, 0x1000, false) - if len(unmapped) != 0 { - t.Errorf("RemoveMapping: got %+v, wanted []", unmapped) - } - - unmapped = set.RemoveMapping(ms, hostarch.AddrRange{0x10000, 0x12000}, 0x1000, true) - if got, want := unmapped, []MappableRange{{0x1000, 0x2000}}; !reflect.DeepEqual(got, want) { - t.Errorf("RemoveMapping: got %+v, wanted %+v", got, want) - } - - // Mappings: - // [0x21000, 0x22000) readonly => [0x3000, 0x4000) - t.Log(&set) - - unmapped = set.RemoveMapping(ms, hostarch.AddrRange{0x21000, 0x22000}, 0x3000, false) - if got, want := unmapped, []MappableRange{{0x3000, 0x4000}}; !reflect.DeepEqual(got, want) { - t.Errorf("RemoveMapping: got %+v, wanted %+v", got, want) - } -} diff --git a/pkg/sentry/memmap/memmap_impl_state_autogen.go b/pkg/sentry/memmap/memmap_impl_state_autogen.go new file mode 100644 index 000000000..02c96e66d --- /dev/null +++ b/pkg/sentry/memmap/memmap_impl_state_autogen.go @@ -0,0 +1,117 @@ +// automatically generated by stateify. + +package memmap + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (s *MappingSet) StateTypeName() string { + return "pkg/sentry/memmap.MappingSet" +} + +func (s *MappingSet) StateFields() []string { + return []string{ + "root", + } +} + +func (s *MappingSet) beforeSave() {} + +// +checklocksignore +func (s *MappingSet) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + var rootValue *MappingSegmentDataSlices + rootValue = s.saveRoot() + stateSinkObject.SaveValue(0, rootValue) +} + +func (s *MappingSet) afterLoad() {} + +// +checklocksignore +func (s *MappingSet) StateLoad(stateSourceObject state.Source) { + stateSourceObject.LoadValue(0, new(*MappingSegmentDataSlices), func(y interface{}) { s.loadRoot(y.(*MappingSegmentDataSlices)) }) +} + +func (n *Mappingnode) StateTypeName() string { + return "pkg/sentry/memmap.Mappingnode" +} + +func (n *Mappingnode) StateFields() []string { + return []string{ + "nrSegments", + "parent", + "parentIndex", + "hasChildren", + "maxGap", + "keys", + "values", + "children", + } +} + +func (n *Mappingnode) beforeSave() {} + +// +checklocksignore +func (n *Mappingnode) StateSave(stateSinkObject state.Sink) { + n.beforeSave() + stateSinkObject.Save(0, &n.nrSegments) + stateSinkObject.Save(1, &n.parent) + stateSinkObject.Save(2, &n.parentIndex) + stateSinkObject.Save(3, &n.hasChildren) + stateSinkObject.Save(4, &n.maxGap) + stateSinkObject.Save(5, &n.keys) + stateSinkObject.Save(6, &n.values) + stateSinkObject.Save(7, &n.children) +} + +func (n *Mappingnode) afterLoad() {} + +// +checklocksignore +func (n *Mappingnode) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &n.nrSegments) + stateSourceObject.Load(1, &n.parent) + stateSourceObject.Load(2, &n.parentIndex) + stateSourceObject.Load(3, &n.hasChildren) + stateSourceObject.Load(4, &n.maxGap) + stateSourceObject.Load(5, &n.keys) + stateSourceObject.Load(6, &n.values) + stateSourceObject.Load(7, &n.children) +} + +func (m *MappingSegmentDataSlices) StateTypeName() string { + return "pkg/sentry/memmap.MappingSegmentDataSlices" +} + +func (m *MappingSegmentDataSlices) StateFields() []string { + return []string{ + "Start", + "End", + "Values", + } +} + +func (m *MappingSegmentDataSlices) beforeSave() {} + +// +checklocksignore +func (m *MappingSegmentDataSlices) StateSave(stateSinkObject state.Sink) { + m.beforeSave() + stateSinkObject.Save(0, &m.Start) + stateSinkObject.Save(1, &m.End) + stateSinkObject.Save(2, &m.Values) +} + +func (m *MappingSegmentDataSlices) afterLoad() {} + +// +checklocksignore +func (m *MappingSegmentDataSlices) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &m.Start) + stateSourceObject.Load(1, &m.End) + stateSourceObject.Load(2, &m.Values) +} + +func init() { + state.Register((*MappingSet)(nil)) + state.Register((*Mappingnode)(nil)) + state.Register((*MappingSegmentDataSlices)(nil)) +} diff --git a/pkg/sentry/memmap/memmap_state_autogen.go b/pkg/sentry/memmap/memmap_state_autogen.go new file mode 100644 index 000000000..8624ecb60 --- /dev/null +++ b/pkg/sentry/memmap/memmap_state_autogen.go @@ -0,0 +1,100 @@ +// automatically generated by stateify. + +package memmap + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (fr *FileRange) StateTypeName() string { + return "pkg/sentry/memmap.FileRange" +} + +func (fr *FileRange) StateFields() []string { + return []string{ + "Start", + "End", + } +} + +func (fr *FileRange) beforeSave() {} + +// +checklocksignore +func (fr *FileRange) StateSave(stateSinkObject state.Sink) { + fr.beforeSave() + stateSinkObject.Save(0, &fr.Start) + stateSinkObject.Save(1, &fr.End) +} + +func (fr *FileRange) afterLoad() {} + +// +checklocksignore +func (fr *FileRange) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &fr.Start) + stateSourceObject.Load(1, &fr.End) +} + +func (mr *MappableRange) StateTypeName() string { + return "pkg/sentry/memmap.MappableRange" +} + +func (mr *MappableRange) StateFields() []string { + return []string{ + "Start", + "End", + } +} + +func (mr *MappableRange) beforeSave() {} + +// +checklocksignore +func (mr *MappableRange) StateSave(stateSinkObject state.Sink) { + mr.beforeSave() + stateSinkObject.Save(0, &mr.Start) + stateSinkObject.Save(1, &mr.End) +} + +func (mr *MappableRange) afterLoad() {} + +// +checklocksignore +func (mr *MappableRange) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &mr.Start) + stateSourceObject.Load(1, &mr.End) +} + +func (r *MappingOfRange) StateTypeName() string { + return "pkg/sentry/memmap.MappingOfRange" +} + +func (r *MappingOfRange) StateFields() []string { + return []string{ + "MappingSpace", + "AddrRange", + "Writable", + } +} + +func (r *MappingOfRange) beforeSave() {} + +// +checklocksignore +func (r *MappingOfRange) StateSave(stateSinkObject state.Sink) { + r.beforeSave() + stateSinkObject.Save(0, &r.MappingSpace) + stateSinkObject.Save(1, &r.AddrRange) + stateSinkObject.Save(2, &r.Writable) +} + +func (r *MappingOfRange) afterLoad() {} + +// +checklocksignore +func (r *MappingOfRange) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &r.MappingSpace) + stateSourceObject.Load(1, &r.AddrRange) + stateSourceObject.Load(2, &r.Writable) +} + +func init() { + state.Register((*FileRange)(nil)) + state.Register((*MappableRange)(nil)) + state.Register((*MappingOfRange)(nil)) +} diff --git a/pkg/sentry/mm/BUILD b/pkg/sentry/mm/BUILD deleted file mode 100644 index b7d782b7f..000000000 --- a/pkg/sentry/mm/BUILD +++ /dev/null @@ -1,169 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -package(licenses = ["notice"]) - -go_template_instance( - name = "file_refcount_set", - out = "file_refcount_set.go", - imports = { - "memmap": "gvisor.dev/gvisor/pkg/sentry/memmap", - }, - package = "mm", - prefix = "fileRefcount", - template = "//pkg/segment:generic_set", - types = { - "Key": "uint64", - "Range": "memmap.FileRange", - "Value": "int32", - "Functions": "fileRefcountSetFunctions", - }, -) - -go_template_instance( - name = "vma_set", - out = "vma_set.go", - consts = { - "minDegree": "8", - "trackGaps": "1", - }, - imports = { - "hostarch": "gvisor.dev/gvisor/pkg/hostarch", - }, - package = "mm", - prefix = "vma", - template = "//pkg/segment:generic_set", - types = { - "Key": "hostarch.Addr", - "Range": "hostarch.AddrRange", - "Value": "vma", - "Functions": "vmaSetFunctions", - }, -) - -go_template_instance( - name = "pma_set", - out = "pma_set.go", - consts = { - "minDegree": "8", - }, - imports = { - "hostarch": "gvisor.dev/gvisor/pkg/hostarch", - }, - package = "mm", - prefix = "pma", - template = "//pkg/segment:generic_set", - types = { - "Key": "hostarch.Addr", - "Range": "hostarch.AddrRange", - "Value": "pma", - "Functions": "pmaSetFunctions", - }, -) - -go_template_instance( - name = "io_list", - out = "io_list.go", - package = "mm", - prefix = "io", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*ioResult", - "Linker": "*ioResult", - }, -) - -go_template_instance( - name = "aio_mappable_refs", - out = "aio_mappable_refs.go", - package = "mm", - prefix = "aioMappable", - template = "//pkg/refsvfs2:refs_template", - types = { - "T": "aioMappable", - }, -) - -go_template_instance( - name = "special_mappable_refs", - out = "special_mappable_refs.go", - package = "mm", - prefix = "SpecialMappable", - template = "//pkg/refsvfs2:refs_template", - types = { - "T": "SpecialMappable", - }, -) - -go_library( - name = "mm", - srcs = [ - "address_space.go", - "aio_context.go", - "aio_context_state.go", - "aio_mappable_refs.go", - "debug.go", - "file_refcount_set.go", - "io.go", - "io_list.go", - "lifecycle.go", - "metadata.go", - "mm.go", - "pma.go", - "pma_set.go", - "procfs.go", - "save_restore.go", - "shm.go", - "special_mappable.go", - "special_mappable_refs.go", - "syscalls.go", - "vma.go", - "vma_set.go", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/atomicbitops", - "//pkg/context", - "//pkg/errors/linuxerr", - "//pkg/hostarch", - "//pkg/log", - "//pkg/refs", - "//pkg/refsvfs2", - "//pkg/safecopy", - "//pkg/safemem", - "//pkg/sentry/arch", - "//pkg/sentry/fs/proc/seqfile", - "//pkg/sentry/fsbridge", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/kernel/futex", - "//pkg/sentry/kernel/shm", - "//pkg/sentry/limits", - "//pkg/sentry/memmap", - "//pkg/sentry/pgalloc", - "//pkg/sentry/platform", - "//pkg/sentry/usage", - "//pkg/sync", - "//pkg/tcpip/buffer", - "//pkg/usermem", - ], -) - -go_test( - name = "mm_test", - size = "small", - srcs = ["mm_test.go"], - library = ":mm", - deps = [ - "//pkg/context", - "//pkg/errors/linuxerr", - "//pkg/hostarch", - "//pkg/sentry/arch", - "//pkg/sentry/contexttest", - "//pkg/sentry/limits", - "//pkg/sentry/memmap", - "//pkg/sentry/pgalloc", - "//pkg/sentry/platform", - "//pkg/usermem", - ], -) diff --git a/pkg/sentry/mm/README.md b/pkg/sentry/mm/README.md deleted file mode 100644 index f4d43d927..000000000 --- a/pkg/sentry/mm/README.md +++ /dev/null @@ -1,280 +0,0 @@ -This package provides an emulation of Linux semantics for application virtual -memory mappings. - -For completeness, this document also describes aspects of the memory management -subsystem defined outside this package. - -# Background - -We begin by describing semantics for virtual memory in Linux. - -A virtual address space is defined as a collection of mappings from virtual -addresses to physical memory. However, userspace applications do not configure -mappings to physical memory directly. Instead, applications configure memory -mappings from virtual addresses to offsets into a file using the `mmap` system -call.[^mmap-anon] For example, a call to: - - mmap( - /* addr = */ 0x400000, - /* length = */ 0x1000, - PROT_READ | PROT_WRITE, - MAP_SHARED, - /* fd = */ 3, - /* offset = */ 0); - -creates a mapping of length 0x1000 bytes, starting at virtual address (VA) -0x400000, to offset 0 in the file represented by file descriptor (FD) 3. Within -the Linux kernel, virtual memory mappings are represented by *virtual memory -areas* (VMAs). Supposing that FD 3 represents file /tmp/foo, the state of the -virtual memory subsystem after the `mmap` call may be depicted as: - - VMA: VA:0x400000 -> /tmp/foo:0x0 - -Establishing a virtual memory area does not necessarily establish a mapping to a -physical address, because Linux has not necessarily provisioned physical memory -to store the file's contents. Thus, if the application attempts to read the -contents of VA 0x400000, it may incur a *page fault*, a CPU exception that -forces the kernel to create such a mapping to service the read. - -For a file, doing so consists of several logical phases: - -1. The kernel allocates physical memory to store the contents of the required - part of the file, and copies file contents to the allocated memory. - Supposing that the kernel chooses the physical memory at physical address - (PA) 0x2fb000, the resulting state of the system is: - - VMA: VA:0x400000 -> /tmp/foo:0x0 - Filemap: /tmp/foo:0x0 -> PA:0x2fb000 - - (In Linux the state of the mapping from file offset to physical memory is - stored in `struct address_space`, but to avoid confusion with other notions - of address space we will refer to this system as filemap, named after Linux - kernel source file `mm/filemap.c`.) - -2. The kernel stores the effective mapping from virtual to physical address in - a *page table entry* (PTE) in the application's *page tables*, which are - used by the CPU's virtual memory hardware to perform address translation. - The resulting state of the system is: - - VMA: VA:0x400000 -> /tmp/foo:0x0 - Filemap: /tmp/foo:0x0 -> PA:0x2fb000 - PTE: VA:0x400000 -----------------> PA:0x2fb000 - - The PTE is required for the application to actually use the contents of the - mapped file as virtual memory. However, the PTE is derived from the VMA and - filemap state, both of which are independently mutable, such that mutations - to either will affect the PTE. For example: - - - The application may remove the VMA using the `munmap` system call. This - breaks the mapping from VA:0x400000 to /tmp/foo:0x0, and consequently - the mapping from VA:0x400000 to PA:0x2fb000. However, it does not - necessarily break the mapping from /tmp/foo:0x0 to PA:0x2fb000, so a - future mapping of the same file offset may reuse this physical memory. - - - The application may invalidate the file's contents by passing a length - of 0 to the `ftruncate` system call. This breaks the mapping from - /tmp/foo:0x0 to PA:0x2fb000, and consequently the mapping from - VA:0x400000 to PA:0x2fb000. However, it does not break the mapping from - VA:0x400000 to /tmp/foo:0x0, so future changes to the file's contents - may again be made visible at VA:0x400000 after another page fault - results in the allocation of a new physical address. - - Note that, in order to correctly break the mapping from VA:0x400000 to - PA:0x2fb000 in the latter case, filemap must also store a *reverse mapping* - from /tmp/foo:0x0 to VA:0x400000 so that it can locate and remove the PTE. - -[^mmap-anon]: Memory mappings to non-files are discussed in later sections. - -## Private Mappings - -The preceding example considered VMAs created using the `MAP_SHARED` flag, which -means that PTEs derived from the mapping should always use physical memory that -represents the current state of the mapped file.[^mmap-dev-zero] Applications -can alternatively pass the `MAP_PRIVATE` flag to create a *private mapping*. -Private mappings are *copy-on-write*. - -Suppose that the application instead created a private mapping in the previous -example. In Linux, the state of the system after a read page fault would be: - - VMA: VA:0x400000 -> /tmp/foo:0x0 (private) - Filemap: /tmp/foo:0x0 -> PA:0x2fb000 - PTE: VA:0x400000 -----------------> PA:0x2fb000 (read-only) - -Now suppose the application attempts to write to VA:0x400000. For a shared -mapping, the write would be propagated to PA:0x2fb000, and the kernel would be -responsible for ensuring that the write is later propagated to the mapped file. -For a private mapping, the write incurs another page fault since the PTE is -marked read-only. In response, the kernel allocates physical memory to store the -mapping's *private copy* of the file's contents, copies file contents to the -allocated memory, and changes the PTE to map to the private copy. Supposing that -the kernel chooses the physical memory at physical address (PA) 0x5ea000, the -resulting state of the system is: - - VMA: VA:0x400000 -> /tmp/foo:0x0 (private) - Filemap: /tmp/foo:0x0 -> PA:0x2fb000 - PTE: VA:0x400000 -----------------> PA:0x5ea000 - -Note that the filemap mapping from /tmp/foo:0x0 to PA:0x2fb000 may still exist, -but is now irrelevant to this mapping. - -[^mmap-dev-zero]: Modulo files with special mmap semantics such as `/dev/zero`. - -## Anonymous Mappings - -Instead of passing a file to the `mmap` system call, applications can instead -request an *anonymous* mapping by passing the `MAP_ANONYMOUS` flag. -Semantically, an anonymous mapping is essentially a mapping to an ephemeral file -initially filled with zero bytes. Practically speaking, this is how shared -anonymous mappings are implemented, but private anonymous mappings do not result -in the creation of an ephemeral file; since there would be no way to modify the -contents of the underlying file through a private mapping, all private anonymous -mappings use a single shared page filled with zero bytes until copy-on-write -occurs. - -# Virtual Memory in the Sentry - -The sentry implements application virtual memory atop a host kernel, introducing -an additional level of indirection to the above. - -Consider the same scenario as in the previous section. Since the sentry handles -application system calls, the effect of an application `mmap` system call is to -create a VMA in the sentry (as opposed to the host kernel): - - Sentry VMA: VA:0x400000 -> /tmp/foo:0x0 - -When the application first incurs a page fault on this address, the host kernel -delivers information about the page fault to the sentry in a platform-dependent -manner, and the sentry handles the fault: - -1. The sentry allocates memory to store the contents of the required part of - the file, and copies file contents to the allocated memory. However, since - the sentry is implemented atop a host kernel, it does not configure mappings - to physical memory directly. Instead, mappable "memory" in the sentry is - represented by a host file descriptor and offset, since (as noted in - "Background") this is the memory mapping primitive provided by the host - kernel. In general, memory is allocated from a temporary host file using the - `pgalloc` package. Supposing that the sentry allocates offset 0x3000 from - host file "memory-file", the resulting state is: - - Sentry VMA: VA:0x400000 -> /tmp/foo:0x0 - Sentry filemap: /tmp/foo:0x0 -> host:memory-file:0x3000 - -2. The sentry stores the effective mapping from virtual address to host file in - a host VMA by invoking the `mmap` system call: - - Sentry VMA: VA:0x400000 -> /tmp/foo:0x0 - Sentry filemap: /tmp/foo:0x0 -> host:memory-file:0x3000 - Host VMA: VA:0x400000 -----------------> host:memory-file:0x3000 - -3. The sentry returns control to the application, which immediately incurs the - page fault again.[^mmap-populate] However, since a host VMA now exists for - the faulting virtual address, the host kernel now handles the page fault as - described in "Background": - - Sentry VMA: VA:0x400000 -> /tmp/foo:0x0 - Sentry filemap: /tmp/foo:0x0 -> host:memory-file:0x3000 - Host VMA: VA:0x400000 -----------------> host:memory-file:0x3000 - Host filemap: host:memory-file:0x3000 -> PA:0x2fb000 - Host PTE: VA:0x400000 --------------------------------------------> PA:0x2fb000 - -Thus, from an implementation standpoint, host VMAs serve the same purpose in the -sentry that PTEs do in Linux. As in Linux, sentry VMA and filemap state is -independently mutable, and the desired state of host VMAs is derived from that -state. - -[^mmap-populate]: The sentry could force the host kernel to establish PTEs when - it creates the host VMA by passing the `MAP_POPULATE` flag to - the `mmap` system call, but usually does not. This is because, - to reduce the number of page faults that require handling by - the sentry and (correspondingly) the number of host `mmap` - system calls, the sentry usually creates host VMAs that are - much larger than the single faulting page. - -## Private Mappings - -The sentry implements private mappings consistently with Linux. Before -copy-on-write, the private mapping example given in the Background results in: - - Sentry VMA: VA:0x400000 -> /tmp/foo:0x0 (private) - Sentry filemap: /tmp/foo:0x0 -> host:memory-file:0x3000 - Host VMA: VA:0x400000 -----------------> host:memory-file:0x3000 (read-only) - Host filemap: host:memory-file:0x3000 -> PA:0x2fb000 - Host PTE: VA:0x400000 --------------------------------------------> PA:0x2fb000 (read-only) - -When the application attempts to write to this address, the host kernel delivers -information about the resulting page fault to the sentry. Analogous to Linux, -the sentry allocates memory to store the mapping's private copy of the file's -contents, copies file contents to the allocated memory, and changes the host VMA -to map to the private copy. Supposing that the sentry chooses the offset 0x4000 -in host file `memory-file` to store the private copy, the state of the system -after copy-on-write is: - - Sentry VMA: VA:0x400000 -> /tmp/foo:0x0 (private) - Sentry filemap: /tmp/foo:0x0 -> host:memory-file:0x3000 - Host VMA: VA:0x400000 -----------------> host:memory-file:0x4000 - Host filemap: host:memory-file:0x4000 -> PA:0x5ea000 - Host PTE: VA:0x400000 --------------------------------------------> PA:0x5ea000 - -However, this highlights an important difference between Linux and the sentry. -In Linux, page tables are concrete (architecture-dependent) data structures -owned by the kernel. Conversely, the sentry has the ability to create and -destroy host VMAs using host system calls, but it does not have direct access to -their state. Thus, as written, if the application invokes the `munmap` system -call to remove the sentry VMA, it is non-trivial for the sentry to determine -that it should deallocate `host:memory-file:0x4000`. This implies that the -sentry must retain information about the host VMAs that it has created. - -## Anonymous Mappings - -The sentry implements anonymous mappings consistently with Linux, except that -there is no shared zero page. - -# Implementation Constructs - -In Linux: - -- A virtual address space is represented by `struct mm_struct`. - -- VMAs are represented by `struct vm_area_struct`, stored in `struct - mm_struct::mmap`. - -- Mappings from file offsets to physical memory are stored in `struct - address_space`. - -- Reverse mappings from file offsets to virtual mappings are stored in `struct - address_space::i_mmap`. - -- Physical memory pages are represented by a pointer to `struct page` or an - index called a *page frame number* (PFN), represented by `pfn_t`. - -- PTEs are represented by architecture-dependent type `pte_t`, stored in a - table hierarchy rooted at `struct mm_struct::pgd`. - -In the sentry: - -- A virtual address space is represented by type [`mm.MemoryManager`][mm]. - -- Sentry VMAs are represented by type [`mm.vma`][mm], stored in - `mm.MemoryManager.vmas`. - -- Mappings from sentry file offsets to host file offsets are abstracted - through interface method [`memmap.Mappable.Translate`][memmap]. - -- Reverse mappings from sentry file offsets to virtual mappings are abstracted - through interface methods - [`memmap.Mappable.AddMapping` and `memmap.Mappable.RemoveMapping`][memmap]. - -- Host files that may be mapped into host VMAs are represented by type - [`platform.File`][platform]. - -- Host VMAs are represented in the sentry by type [`mm.pma`][mm] ("platform - mapping area"), stored in `mm.MemoryManager.pmas`. - -- Creation and destruction of host VMAs is abstracted through interface - methods - [`platform.AddressSpace.MapFile` and `platform.AddressSpace.Unmap`][platform]. - -[memmap]: https://github.com/google/gvisor/blob/master/pkg/sentry/memmap/memmap.go -[mm]: https://github.com/google/gvisor/blob/master/pkg/sentry/mm/mm.go -[pgalloc]: https://github.com/google/gvisor/blob/master/pkg/sentry/pgalloc/pgalloc.go -[platform]: https://github.com/google/gvisor/blob/master/pkg/sentry/platform/platform.go diff --git a/pkg/sentry/mm/aio_mappable_refs.go b/pkg/sentry/mm/aio_mappable_refs.go new file mode 100644 index 000000000..6e1bc6739 --- /dev/null +++ b/pkg/sentry/mm/aio_mappable_refs.go @@ -0,0 +1,140 @@ +package mm + +import ( + "fmt" + "sync/atomic" + + "gvisor.dev/gvisor/pkg/refsvfs2" +) + +// enableLogging indicates whether reference-related events should be logged (with +// stack traces). This is false by default and should only be set to true for +// debugging purposes, as it can generate an extremely large amount of output +// and drastically degrade performance. +const aioMappableenableLogging = false + +// obj is used to customize logging. Note that we use a pointer to T so that +// we do not copy the entire object when passed as a format parameter. +var aioMappableobj *aioMappable + +// Refs implements refs.RefCounter. It keeps a reference count using atomic +// operations and calls the destructor when the count reaches zero. +// +// NOTE: Do not introduce additional fields to the Refs struct. It is used by +// many filesystem objects, and we want to keep it as small as possible (i.e., +// the same size as using an int64 directly) to avoid taking up extra cache +// space. In general, this template should not be extended at the cost of +// performance. If it does not offer enough flexibility for a particular object +// (example: b/187877947), we should implement the RefCounter/CheckedObject +// interfaces manually. +// +// +stateify savable +type aioMappableRefs struct { + // refCount is composed of two fields: + // + // [32-bit speculative references]:[32-bit real references] + // + // Speculative references are used for TryIncRef, to avoid a CompareAndSwap + // loop. See IncRef, DecRef and TryIncRef for details of how these fields are + // used. + refCount int64 +} + +// InitRefs initializes r with one reference and, if enabled, activates leak +// checking. +func (r *aioMappableRefs) InitRefs() { + atomic.StoreInt64(&r.refCount, 1) + refsvfs2.Register(r) +} + +// RefType implements refsvfs2.CheckedObject.RefType. +func (r *aioMappableRefs) RefType() string { + return fmt.Sprintf("%T", aioMappableobj)[1:] +} + +// LeakMessage implements refsvfs2.CheckedObject.LeakMessage. +func (r *aioMappableRefs) LeakMessage() string { + return fmt.Sprintf("[%s %p] reference count of %d instead of 0", r.RefType(), r, r.ReadRefs()) +} + +// LogRefs implements refsvfs2.CheckedObject.LogRefs. +func (r *aioMappableRefs) LogRefs() bool { + return aioMappableenableLogging +} + +// ReadRefs returns the current number of references. The returned count is +// inherently racy and is unsafe to use without external synchronization. +func (r *aioMappableRefs) ReadRefs() int64 { + return atomic.LoadInt64(&r.refCount) +} + +// IncRef implements refs.RefCounter.IncRef. +// +//go:nosplit +func (r *aioMappableRefs) IncRef() { + v := atomic.AddInt64(&r.refCount, 1) + if aioMappableenableLogging { + refsvfs2.LogIncRef(r, v) + } + if v <= 1 { + panic(fmt.Sprintf("Incrementing non-positive count %p on %s", r, r.RefType())) + } +} + +// TryIncRef implements refs.TryRefCounter.TryIncRef. +// +// To do this safely without a loop, a speculative reference is first acquired +// on the object. This allows multiple concurrent TryIncRef calls to distinguish +// other TryIncRef calls from genuine references held. +// +//go:nosplit +func (r *aioMappableRefs) TryIncRef() bool { + const speculativeRef = 1 << 32 + if v := atomic.AddInt64(&r.refCount, speculativeRef); int32(v) == 0 { + + atomic.AddInt64(&r.refCount, -speculativeRef) + return false + } + + v := atomic.AddInt64(&r.refCount, -speculativeRef+1) + if aioMappableenableLogging { + refsvfs2.LogTryIncRef(r, v) + } + return true +} + +// DecRef implements refs.RefCounter.DecRef. +// +// Note that speculative references are counted here. Since they were added +// prior to real references reaching zero, they will successfully convert to +// real references. In other words, we see speculative references only in the +// following case: +// +// A: TryIncRef [speculative increase => sees non-negative references] +// B: DecRef [real decrease] +// A: TryIncRef [transform speculative to real] +// +//go:nosplit +func (r *aioMappableRefs) DecRef(destroy func()) { + v := atomic.AddInt64(&r.refCount, -1) + if aioMappableenableLogging { + refsvfs2.LogDecRef(r, v) + } + switch { + case v < 0: + panic(fmt.Sprintf("Decrementing non-positive ref count %p, owned by %s", r, r.RefType())) + + case v == 0: + refsvfs2.Unregister(r) + + if destroy != nil { + destroy() + } + } +} + +func (r *aioMappableRefs) afterLoad() { + if r.ReadRefs() > 0 { + refsvfs2.Register(r) + } +} diff --git a/pkg/sentry/mm/file_refcount_set.go b/pkg/sentry/mm/file_refcount_set.go new file mode 100644 index 000000000..602a137d4 --- /dev/null +++ b/pkg/sentry/mm/file_refcount_set.go @@ -0,0 +1,1647 @@ +package mm + +import ( + __generics_imported0 "gvisor.dev/gvisor/pkg/sentry/memmap" +) + +import ( + "bytes" + "fmt" +) + +// trackGaps is an optional parameter. +// +// If trackGaps is 1, the Set will track maximum gap size recursively, +// enabling the GapIterator.{Prev,Next}LargeEnoughGap functions. In this +// case, Key must be an unsigned integer. +// +// trackGaps must be 0 or 1. +const fileRefcounttrackGaps = 0 + +var _ = uint8(fileRefcounttrackGaps << 7) // Will fail if not zero or one. + +// dynamicGap is a type that disappears if trackGaps is 0. +type fileRefcountdynamicGap [fileRefcounttrackGaps]uint64 + +// Get returns the value of the gap. +// +// Precondition: trackGaps must be non-zero. +func (d *fileRefcountdynamicGap) Get() uint64 { + return d[:][0] +} + +// Set sets the value of the gap. +// +// Precondition: trackGaps must be non-zero. +func (d *fileRefcountdynamicGap) Set(v uint64) { + d[:][0] = v +} + +const ( + // minDegree is the minimum degree of an internal node in a Set B-tree. + // + // - Any non-root node has at least minDegree-1 segments. + // + // - Any non-root internal (non-leaf) node has at least minDegree children. + // + // - The root node may have fewer than minDegree-1 segments, but it may + // only have 0 segments if the tree is empty. + // + // Our implementation requires minDegree >= 3. Higher values of minDegree + // usually improve performance, but increase memory usage for small sets. + fileRefcountminDegree = 3 + + fileRefcountmaxDegree = 2 * fileRefcountminDegree +) + +// A Set is a mapping of segments with non-overlapping Range keys. The zero +// value for a Set is an empty set. Set values are not safely movable nor +// copyable. Set is thread-compatible. +// +// +stateify savable +type fileRefcountSet struct { + root fileRefcountnode `state:".(*fileRefcountSegmentDataSlices)"` +} + +// IsEmpty returns true if the set contains no segments. +func (s *fileRefcountSet) IsEmpty() bool { + return s.root.nrSegments == 0 +} + +// IsEmptyRange returns true iff no segments in the set overlap the given +// range. This is semantically equivalent to s.SpanRange(r) == 0, but may be +// more efficient. +func (s *fileRefcountSet) IsEmptyRange(r __generics_imported0.FileRange) bool { + switch { + case r.Length() < 0: + panic(fmt.Sprintf("invalid range %v", r)) + case r.Length() == 0: + return true + } + _, gap := s.Find(r.Start) + if !gap.Ok() { + return false + } + return r.End <= gap.End() +} + +// Span returns the total size of all segments in the set. +func (s *fileRefcountSet) Span() uint64 { + var sz uint64 + for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + sz += seg.Range().Length() + } + return sz +} + +// SpanRange returns the total size of the intersection of segments in the set +// with the given range. +func (s *fileRefcountSet) SpanRange(r __generics_imported0.FileRange) uint64 { + switch { + case r.Length() < 0: + panic(fmt.Sprintf("invalid range %v", r)) + case r.Length() == 0: + return 0 + } + var sz uint64 + for seg := s.LowerBoundSegment(r.Start); seg.Ok() && seg.Start() < r.End; seg = seg.NextSegment() { + sz += seg.Range().Intersect(r).Length() + } + return sz +} + +// FirstSegment returns the first segment in the set. If the set is empty, +// FirstSegment returns a terminal iterator. +func (s *fileRefcountSet) FirstSegment() fileRefcountIterator { + if s.root.nrSegments == 0 { + return fileRefcountIterator{} + } + return s.root.firstSegment() +} + +// LastSegment returns the last segment in the set. If the set is empty, +// LastSegment returns a terminal iterator. +func (s *fileRefcountSet) LastSegment() fileRefcountIterator { + if s.root.nrSegments == 0 { + return fileRefcountIterator{} + } + return s.root.lastSegment() +} + +// FirstGap returns the first gap in the set. +func (s *fileRefcountSet) FirstGap() fileRefcountGapIterator { + n := &s.root + for n.hasChildren { + n = n.children[0] + } + return fileRefcountGapIterator{n, 0} +} + +// LastGap returns the last gap in the set. +func (s *fileRefcountSet) LastGap() fileRefcountGapIterator { + n := &s.root + for n.hasChildren { + n = n.children[n.nrSegments] + } + return fileRefcountGapIterator{n, n.nrSegments} +} + +// Find returns the segment or gap whose range contains the given key. If a +// segment is found, the returned Iterator is non-terminal and the +// returned GapIterator is terminal. Otherwise, the returned Iterator is +// terminal and the returned GapIterator is non-terminal. +func (s *fileRefcountSet) Find(key uint64) (fileRefcountIterator, fileRefcountGapIterator) { + n := &s.root + for { + + lower := 0 + upper := n.nrSegments + for lower < upper { + i := lower + (upper-lower)/2 + if r := n.keys[i]; key < r.End { + if key >= r.Start { + return fileRefcountIterator{n, i}, fileRefcountGapIterator{} + } + upper = i + } else { + lower = i + 1 + } + } + i := lower + if !n.hasChildren { + return fileRefcountIterator{}, fileRefcountGapIterator{n, i} + } + n = n.children[i] + } +} + +// FindSegment returns the segment whose range contains the given key. If no +// such segment exists, FindSegment returns a terminal iterator. +func (s *fileRefcountSet) FindSegment(key uint64) fileRefcountIterator { + seg, _ := s.Find(key) + return seg +} + +// LowerBoundSegment returns the segment with the lowest range that contains a +// key greater than or equal to min. If no such segment exists, +// LowerBoundSegment returns a terminal iterator. +func (s *fileRefcountSet) LowerBoundSegment(min uint64) fileRefcountIterator { + seg, gap := s.Find(min) + if seg.Ok() { + return seg + } + return gap.NextSegment() +} + +// UpperBoundSegment returns the segment with the highest range that contains a +// key less than or equal to max. If no such segment exists, UpperBoundSegment +// returns a terminal iterator. +func (s *fileRefcountSet) UpperBoundSegment(max uint64) fileRefcountIterator { + seg, gap := s.Find(max) + if seg.Ok() { + return seg + } + return gap.PrevSegment() +} + +// FindGap returns the gap containing the given key. If no such gap exists +// (i.e. the set contains a segment containing that key), FindGap returns a +// terminal iterator. +func (s *fileRefcountSet) FindGap(key uint64) fileRefcountGapIterator { + _, gap := s.Find(key) + return gap +} + +// LowerBoundGap returns the gap with the lowest range that is greater than or +// equal to min. +func (s *fileRefcountSet) LowerBoundGap(min uint64) fileRefcountGapIterator { + seg, gap := s.Find(min) + if gap.Ok() { + return gap + } + return seg.NextGap() +} + +// UpperBoundGap returns the gap with the highest range that is less than or +// equal to max. +func (s *fileRefcountSet) UpperBoundGap(max uint64) fileRefcountGapIterator { + seg, gap := s.Find(max) + if gap.Ok() { + return gap + } + return seg.PrevGap() +} + +// Add inserts the given segment into the set and returns true. If the new +// segment can be merged with adjacent segments, Add will do so. If the new +// segment would overlap an existing segment, Add returns false. If Add +// succeeds, all existing iterators are invalidated. +func (s *fileRefcountSet) Add(r __generics_imported0.FileRange, val int32) bool { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + gap := s.FindGap(r.Start) + if !gap.Ok() { + return false + } + if r.End > gap.End() { + return false + } + s.Insert(gap, r, val) + return true +} + +// AddWithoutMerging inserts the given segment into the set and returns true. +// If it would overlap an existing segment, AddWithoutMerging does nothing and +// returns false. If AddWithoutMerging succeeds, all existing iterators are +// invalidated. +func (s *fileRefcountSet) AddWithoutMerging(r __generics_imported0.FileRange, val int32) bool { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + gap := s.FindGap(r.Start) + if !gap.Ok() { + return false + } + if r.End > gap.End() { + return false + } + s.InsertWithoutMergingUnchecked(gap, r, val) + return true +} + +// Insert inserts the given segment into the given gap. If the new segment can +// be merged with adjacent segments, Insert will do so. Insert returns an +// iterator to the segment containing the inserted value (which may have been +// merged with other values). All existing iterators (including gap, but not +// including the returned iterator) are invalidated. +// +// If the gap cannot accommodate the segment, or if r is invalid, Insert panics. +// +// Insert is semantically equivalent to a InsertWithoutMerging followed by a +// Merge, but may be more efficient. Note that there is no unchecked variant of +// Insert since Insert must retrieve and inspect gap's predecessor and +// successor segments regardless. +func (s *fileRefcountSet) Insert(gap fileRefcountGapIterator, r __generics_imported0.FileRange, val int32) fileRefcountIterator { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + prev, next := gap.PrevSegment(), gap.NextSegment() + if prev.Ok() && prev.End() > r.Start { + panic(fmt.Sprintf("new segment %v overlaps predecessor %v", r, prev.Range())) + } + if next.Ok() && next.Start() < r.End { + panic(fmt.Sprintf("new segment %v overlaps successor %v", r, next.Range())) + } + if prev.Ok() && prev.End() == r.Start { + if mval, ok := (fileRefcountSetFunctions{}).Merge(prev.Range(), prev.Value(), r, val); ok { + shrinkMaxGap := fileRefcounttrackGaps != 0 && gap.Range().Length() == gap.node.maxGap.Get() + prev.SetEndUnchecked(r.End) + prev.SetValue(mval) + if shrinkMaxGap { + gap.node.updateMaxGapLeaf() + } + if next.Ok() && next.Start() == r.End { + val = mval + if mval, ok := (fileRefcountSetFunctions{}).Merge(prev.Range(), val, next.Range(), next.Value()); ok { + prev.SetEndUnchecked(next.End()) + prev.SetValue(mval) + return s.Remove(next).PrevSegment() + } + } + return prev + } + } + if next.Ok() && next.Start() == r.End { + if mval, ok := (fileRefcountSetFunctions{}).Merge(r, val, next.Range(), next.Value()); ok { + shrinkMaxGap := fileRefcounttrackGaps != 0 && gap.Range().Length() == gap.node.maxGap.Get() + next.SetStartUnchecked(r.Start) + next.SetValue(mval) + if shrinkMaxGap { + gap.node.updateMaxGapLeaf() + } + return next + } + } + + return s.InsertWithoutMergingUnchecked(gap, r, val) +} + +// InsertWithoutMerging inserts the given segment into the given gap and +// returns an iterator to the inserted segment. All existing iterators +// (including gap, but not including the returned iterator) are invalidated. +// +// If the gap cannot accommodate the segment, or if r is invalid, +// InsertWithoutMerging panics. +func (s *fileRefcountSet) InsertWithoutMerging(gap fileRefcountGapIterator, r __generics_imported0.FileRange, val int32) fileRefcountIterator { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + if gr := gap.Range(); !gr.IsSupersetOf(r) { + panic(fmt.Sprintf("cannot insert segment range %v into gap range %v", r, gr)) + } + return s.InsertWithoutMergingUnchecked(gap, r, val) +} + +// InsertWithoutMergingUnchecked inserts the given segment into the given gap +// and returns an iterator to the inserted segment. All existing iterators +// (including gap, but not including the returned iterator) are invalidated. +// +// Preconditions: +// * r.Start >= gap.Start(). +// * r.End <= gap.End(). +func (s *fileRefcountSet) InsertWithoutMergingUnchecked(gap fileRefcountGapIterator, r __generics_imported0.FileRange, val int32) fileRefcountIterator { + gap = gap.node.rebalanceBeforeInsert(gap) + splitMaxGap := fileRefcounttrackGaps != 0 && (gap.node.nrSegments == 0 || gap.Range().Length() == gap.node.maxGap.Get()) + copy(gap.node.keys[gap.index+1:], gap.node.keys[gap.index:gap.node.nrSegments]) + copy(gap.node.values[gap.index+1:], gap.node.values[gap.index:gap.node.nrSegments]) + gap.node.keys[gap.index] = r + gap.node.values[gap.index] = val + gap.node.nrSegments++ + if splitMaxGap { + gap.node.updateMaxGapLeaf() + } + return fileRefcountIterator{gap.node, gap.index} +} + +// Remove removes the given segment and returns an iterator to the vacated gap. +// All existing iterators (including seg, but not including the returned +// iterator) are invalidated. +func (s *fileRefcountSet) Remove(seg fileRefcountIterator) fileRefcountGapIterator { + + if seg.node.hasChildren { + + victim := seg.PrevSegment() + + seg.SetRangeUnchecked(victim.Range()) + seg.SetValue(victim.Value()) + + nextAdjacentNode := seg.NextSegment().node + if fileRefcounttrackGaps != 0 { + nextAdjacentNode.updateMaxGapLeaf() + } + return s.Remove(victim).NextGap() + } + copy(seg.node.keys[seg.index:], seg.node.keys[seg.index+1:seg.node.nrSegments]) + copy(seg.node.values[seg.index:], seg.node.values[seg.index+1:seg.node.nrSegments]) + fileRefcountSetFunctions{}.ClearValue(&seg.node.values[seg.node.nrSegments-1]) + seg.node.nrSegments-- + if fileRefcounttrackGaps != 0 { + seg.node.updateMaxGapLeaf() + } + return seg.node.rebalanceAfterRemove(fileRefcountGapIterator{seg.node, seg.index}) +} + +// RemoveAll removes all segments from the set. All existing iterators are +// invalidated. +func (s *fileRefcountSet) RemoveAll() { + s.root = fileRefcountnode{} +} + +// RemoveRange removes all segments in the given range. An iterator to the +// newly formed gap is returned, and all existing iterators are invalidated. +func (s *fileRefcountSet) RemoveRange(r __generics_imported0.FileRange) fileRefcountGapIterator { + seg, gap := s.Find(r.Start) + if seg.Ok() { + seg = s.Isolate(seg, r) + gap = s.Remove(seg) + } + for seg = gap.NextSegment(); seg.Ok() && seg.Start() < r.End; seg = gap.NextSegment() { + seg = s.Isolate(seg, r) + gap = s.Remove(seg) + } + return gap +} + +// Merge attempts to merge two neighboring segments. If successful, Merge +// returns an iterator to the merged segment, and all existing iterators are +// invalidated. Otherwise, Merge returns a terminal iterator. +// +// If first is not the predecessor of second, Merge panics. +func (s *fileRefcountSet) Merge(first, second fileRefcountIterator) fileRefcountIterator { + if first.NextSegment() != second { + panic(fmt.Sprintf("attempt to merge non-neighboring segments %v, %v", first.Range(), second.Range())) + } + return s.MergeUnchecked(first, second) +} + +// MergeUnchecked attempts to merge two neighboring segments. If successful, +// MergeUnchecked returns an iterator to the merged segment, and all existing +// iterators are invalidated. Otherwise, MergeUnchecked returns a terminal +// iterator. +// +// Precondition: first is the predecessor of second: first.NextSegment() == +// second, first == second.PrevSegment(). +func (s *fileRefcountSet) MergeUnchecked(first, second fileRefcountIterator) fileRefcountIterator { + if first.End() == second.Start() { + if mval, ok := (fileRefcountSetFunctions{}).Merge(first.Range(), first.Value(), second.Range(), second.Value()); ok { + + first.SetEndUnchecked(second.End()) + first.SetValue(mval) + + return s.Remove(second).PrevSegment() + } + } + return fileRefcountIterator{} +} + +// MergeAll attempts to merge all adjacent segments in the set. All existing +// iterators are invalidated. +func (s *fileRefcountSet) MergeAll() { + seg := s.FirstSegment() + if !seg.Ok() { + return + } + next := seg.NextSegment() + for next.Ok() { + if mseg := s.MergeUnchecked(seg, next); mseg.Ok() { + seg, next = mseg, mseg.NextSegment() + } else { + seg, next = next, next.NextSegment() + } + } +} + +// MergeRange attempts to merge all adjacent segments that contain a key in the +// specific range. All existing iterators are invalidated. +func (s *fileRefcountSet) MergeRange(r __generics_imported0.FileRange) { + seg := s.LowerBoundSegment(r.Start) + if !seg.Ok() { + return + } + next := seg.NextSegment() + for next.Ok() && next.Range().Start < r.End { + if mseg := s.MergeUnchecked(seg, next); mseg.Ok() { + seg, next = mseg, mseg.NextSegment() + } else { + seg, next = next, next.NextSegment() + } + } +} + +// MergeAdjacent attempts to merge the segment containing r.Start with its +// predecessor, and the segment containing r.End-1 with its successor. +func (s *fileRefcountSet) MergeAdjacent(r __generics_imported0.FileRange) { + first := s.FindSegment(r.Start) + if first.Ok() { + if prev := first.PrevSegment(); prev.Ok() { + s.Merge(prev, first) + } + } + last := s.FindSegment(r.End - 1) + if last.Ok() { + if next := last.NextSegment(); next.Ok() { + s.Merge(last, next) + } + } +} + +// Split splits the given segment at the given key and returns iterators to the +// two resulting segments. All existing iterators (including seg, but not +// including the returned iterators) are invalidated. +// +// If the segment cannot be split at split (because split is at the start or +// end of the segment's range, so splitting would produce a segment with zero +// length, or because split falls outside the segment's range altogether), +// Split panics. +func (s *fileRefcountSet) Split(seg fileRefcountIterator, split uint64) (fileRefcountIterator, fileRefcountIterator) { + if !seg.Range().CanSplitAt(split) { + panic(fmt.Sprintf("can't split %v at %v", seg.Range(), split)) + } + return s.SplitUnchecked(seg, split) +} + +// SplitUnchecked splits the given segment at the given key and returns +// iterators to the two resulting segments. All existing iterators (including +// seg, but not including the returned iterators) are invalidated. +// +// Preconditions: seg.Start() < key < seg.End(). +func (s *fileRefcountSet) SplitUnchecked(seg fileRefcountIterator, split uint64) (fileRefcountIterator, fileRefcountIterator) { + val1, val2 := (fileRefcountSetFunctions{}).Split(seg.Range(), seg.Value(), split) + end2 := seg.End() + seg.SetEndUnchecked(split) + seg.SetValue(val1) + seg2 := s.InsertWithoutMergingUnchecked(seg.NextGap(), __generics_imported0.FileRange{split, end2}, val2) + + return seg2.PrevSegment(), seg2 +} + +// SplitAt splits the segment straddling split, if one exists. SplitAt returns +// true if a segment was split and false otherwise. If SplitAt splits a +// segment, all existing iterators are invalidated. +func (s *fileRefcountSet) SplitAt(split uint64) bool { + if seg := s.FindSegment(split); seg.Ok() && seg.Range().CanSplitAt(split) { + s.SplitUnchecked(seg, split) + return true + } + return false +} + +// Isolate ensures that the given segment's range does not escape r by +// splitting at r.Start and r.End if necessary, and returns an updated iterator +// to the bounded segment. All existing iterators (including seg, but not +// including the returned iterators) are invalidated. +func (s *fileRefcountSet) Isolate(seg fileRefcountIterator, r __generics_imported0.FileRange) fileRefcountIterator { + if seg.Range().CanSplitAt(r.Start) { + _, seg = s.SplitUnchecked(seg, r.Start) + } + if seg.Range().CanSplitAt(r.End) { + seg, _ = s.SplitUnchecked(seg, r.End) + } + return seg +} + +// ApplyContiguous applies a function to a contiguous range of segments, +// splitting if necessary. The function is applied until the first gap is +// encountered, at which point the gap is returned. If the function is applied +// across the entire range, a terminal gap is returned. All existing iterators +// are invalidated. +// +// N.B. The Iterator must not be invalidated by the function. +func (s *fileRefcountSet) ApplyContiguous(r __generics_imported0.FileRange, fn func(seg fileRefcountIterator)) fileRefcountGapIterator { + seg, gap := s.Find(r.Start) + if !seg.Ok() { + return gap + } + for { + seg = s.Isolate(seg, r) + fn(seg) + if seg.End() >= r.End { + return fileRefcountGapIterator{} + } + gap = seg.NextGap() + if !gap.IsEmpty() { + return gap + } + seg = gap.NextSegment() + if !seg.Ok() { + + return fileRefcountGapIterator{} + } + } +} + +// +stateify savable +type fileRefcountnode struct { + // An internal binary tree node looks like: + // + // K + // / \ + // Cl Cr + // + // where all keys in the subtree rooted by Cl (the left subtree) are less + // than K (the key of the parent node), and all keys in the subtree rooted + // by Cr (the right subtree) are greater than K. + // + // An internal B-tree node's indexes work out to look like: + // + // K0 K1 K2 ... Kn-1 + // / \/ \/ \ ... / \ + // C0 C1 C2 C3 ... Cn-1 Cn + // + // where n is nrSegments. + nrSegments int + + // parent is a pointer to this node's parent. If this node is root, parent + // is nil. + parent *fileRefcountnode + + // parentIndex is the index of this node in parent.children. + parentIndex int + + // Flag for internal nodes that is technically redundant with "children[0] + // != nil", but is stored in the first cache line. "hasChildren" rather + // than "isLeaf" because false must be the correct value for an empty root. + hasChildren bool + + // The longest gap within this node. If the node is a leaf, it's simply the + // maximum gap among all the (nrSegments+1) gaps formed by its nrSegments keys + // including the 0th and nrSegments-th gap possibly shared with its upper-level + // nodes; if it's a non-leaf node, it's the max of all children's maxGap. + maxGap fileRefcountdynamicGap + + // Nodes store keys and values in separate arrays to maximize locality in + // the common case (scanning keys for lookup). + keys [fileRefcountmaxDegree - 1]__generics_imported0.FileRange + values [fileRefcountmaxDegree - 1]int32 + children [fileRefcountmaxDegree]*fileRefcountnode +} + +// firstSegment returns the first segment in the subtree rooted by n. +// +// Preconditions: n.nrSegments != 0. +func (n *fileRefcountnode) firstSegment() fileRefcountIterator { + for n.hasChildren { + n = n.children[0] + } + return fileRefcountIterator{n, 0} +} + +// lastSegment returns the last segment in the subtree rooted by n. +// +// Preconditions: n.nrSegments != 0. +func (n *fileRefcountnode) lastSegment() fileRefcountIterator { + for n.hasChildren { + n = n.children[n.nrSegments] + } + return fileRefcountIterator{n, n.nrSegments - 1} +} + +func (n *fileRefcountnode) prevSibling() *fileRefcountnode { + if n.parent == nil || n.parentIndex == 0 { + return nil + } + return n.parent.children[n.parentIndex-1] +} + +func (n *fileRefcountnode) nextSibling() *fileRefcountnode { + if n.parent == nil || n.parentIndex == n.parent.nrSegments { + return nil + } + return n.parent.children[n.parentIndex+1] +} + +// rebalanceBeforeInsert splits n and its ancestors if they are full, as +// required for insertion, and returns an updated iterator to the position +// represented by gap. +func (n *fileRefcountnode) rebalanceBeforeInsert(gap fileRefcountGapIterator) fileRefcountGapIterator { + if n.nrSegments < fileRefcountmaxDegree-1 { + return gap + } + if n.parent != nil { + gap = n.parent.rebalanceBeforeInsert(gap) + } + if n.parent == nil { + + left := &fileRefcountnode{ + nrSegments: fileRefcountminDegree - 1, + parent: n, + parentIndex: 0, + hasChildren: n.hasChildren, + } + right := &fileRefcountnode{ + nrSegments: fileRefcountminDegree - 1, + parent: n, + parentIndex: 1, + hasChildren: n.hasChildren, + } + copy(left.keys[:fileRefcountminDegree-1], n.keys[:fileRefcountminDegree-1]) + copy(left.values[:fileRefcountminDegree-1], n.values[:fileRefcountminDegree-1]) + copy(right.keys[:fileRefcountminDegree-1], n.keys[fileRefcountminDegree:]) + copy(right.values[:fileRefcountminDegree-1], n.values[fileRefcountminDegree:]) + n.keys[0], n.values[0] = n.keys[fileRefcountminDegree-1], n.values[fileRefcountminDegree-1] + fileRefcountzeroValueSlice(n.values[1:]) + if n.hasChildren { + copy(left.children[:fileRefcountminDegree], n.children[:fileRefcountminDegree]) + copy(right.children[:fileRefcountminDegree], n.children[fileRefcountminDegree:]) + fileRefcountzeroNodeSlice(n.children[2:]) + for i := 0; i < fileRefcountminDegree; i++ { + left.children[i].parent = left + left.children[i].parentIndex = i + right.children[i].parent = right + right.children[i].parentIndex = i + } + } + n.nrSegments = 1 + n.hasChildren = true + n.children[0] = left + n.children[1] = right + + if fileRefcounttrackGaps != 0 { + left.updateMaxGapLocal() + right.updateMaxGapLocal() + } + if gap.node != n { + return gap + } + if gap.index < fileRefcountminDegree { + return fileRefcountGapIterator{left, gap.index} + } + return fileRefcountGapIterator{right, gap.index - fileRefcountminDegree} + } + + copy(n.parent.keys[n.parentIndex+1:], n.parent.keys[n.parentIndex:n.parent.nrSegments]) + copy(n.parent.values[n.parentIndex+1:], n.parent.values[n.parentIndex:n.parent.nrSegments]) + n.parent.keys[n.parentIndex], n.parent.values[n.parentIndex] = n.keys[fileRefcountminDegree-1], n.values[fileRefcountminDegree-1] + copy(n.parent.children[n.parentIndex+2:], n.parent.children[n.parentIndex+1:n.parent.nrSegments+1]) + for i := n.parentIndex + 2; i < n.parent.nrSegments+2; i++ { + n.parent.children[i].parentIndex = i + } + sibling := &fileRefcountnode{ + nrSegments: fileRefcountminDegree - 1, + parent: n.parent, + parentIndex: n.parentIndex + 1, + hasChildren: n.hasChildren, + } + n.parent.children[n.parentIndex+1] = sibling + n.parent.nrSegments++ + copy(sibling.keys[:fileRefcountminDegree-1], n.keys[fileRefcountminDegree:]) + copy(sibling.values[:fileRefcountminDegree-1], n.values[fileRefcountminDegree:]) + fileRefcountzeroValueSlice(n.values[fileRefcountminDegree-1:]) + if n.hasChildren { + copy(sibling.children[:fileRefcountminDegree], n.children[fileRefcountminDegree:]) + fileRefcountzeroNodeSlice(n.children[fileRefcountminDegree:]) + for i := 0; i < fileRefcountminDegree; i++ { + sibling.children[i].parent = sibling + sibling.children[i].parentIndex = i + } + } + n.nrSegments = fileRefcountminDegree - 1 + + if fileRefcounttrackGaps != 0 { + n.updateMaxGapLocal() + sibling.updateMaxGapLocal() + } + + if gap.node != n { + return gap + } + if gap.index < fileRefcountminDegree { + return gap + } + return fileRefcountGapIterator{sibling, gap.index - fileRefcountminDegree} +} + +// rebalanceAfterRemove "unsplits" n and its ancestors if they are deficient +// (contain fewer segments than required by B-tree invariants), as required for +// removal, and returns an updated iterator to the position represented by gap. +// +// Precondition: n is the only node in the tree that may currently violate a +// B-tree invariant. +func (n *fileRefcountnode) rebalanceAfterRemove(gap fileRefcountGapIterator) fileRefcountGapIterator { + for { + if n.nrSegments >= fileRefcountminDegree-1 { + return gap + } + if n.parent == nil { + + return gap + } + + if sibling := n.prevSibling(); sibling != nil && sibling.nrSegments >= fileRefcountminDegree { + copy(n.keys[1:], n.keys[:n.nrSegments]) + copy(n.values[1:], n.values[:n.nrSegments]) + n.keys[0] = n.parent.keys[n.parentIndex-1] + n.values[0] = n.parent.values[n.parentIndex-1] + n.parent.keys[n.parentIndex-1] = sibling.keys[sibling.nrSegments-1] + n.parent.values[n.parentIndex-1] = sibling.values[sibling.nrSegments-1] + fileRefcountSetFunctions{}.ClearValue(&sibling.values[sibling.nrSegments-1]) + if n.hasChildren { + copy(n.children[1:], n.children[:n.nrSegments+1]) + n.children[0] = sibling.children[sibling.nrSegments] + sibling.children[sibling.nrSegments] = nil + n.children[0].parent = n + n.children[0].parentIndex = 0 + for i := 1; i < n.nrSegments+2; i++ { + n.children[i].parentIndex = i + } + } + n.nrSegments++ + sibling.nrSegments-- + + if fileRefcounttrackGaps != 0 { + n.updateMaxGapLocal() + sibling.updateMaxGapLocal() + } + if gap.node == sibling && gap.index == sibling.nrSegments { + return fileRefcountGapIterator{n, 0} + } + if gap.node == n { + return fileRefcountGapIterator{n, gap.index + 1} + } + return gap + } + if sibling := n.nextSibling(); sibling != nil && sibling.nrSegments >= fileRefcountminDegree { + n.keys[n.nrSegments] = n.parent.keys[n.parentIndex] + n.values[n.nrSegments] = n.parent.values[n.parentIndex] + n.parent.keys[n.parentIndex] = sibling.keys[0] + n.parent.values[n.parentIndex] = sibling.values[0] + copy(sibling.keys[:sibling.nrSegments-1], sibling.keys[1:]) + copy(sibling.values[:sibling.nrSegments-1], sibling.values[1:]) + fileRefcountSetFunctions{}.ClearValue(&sibling.values[sibling.nrSegments-1]) + if n.hasChildren { + n.children[n.nrSegments+1] = sibling.children[0] + copy(sibling.children[:sibling.nrSegments], sibling.children[1:]) + sibling.children[sibling.nrSegments] = nil + n.children[n.nrSegments+1].parent = n + n.children[n.nrSegments+1].parentIndex = n.nrSegments + 1 + for i := 0; i < sibling.nrSegments; i++ { + sibling.children[i].parentIndex = i + } + } + n.nrSegments++ + sibling.nrSegments-- + + if fileRefcounttrackGaps != 0 { + n.updateMaxGapLocal() + sibling.updateMaxGapLocal() + } + if gap.node == sibling { + if gap.index == 0 { + return fileRefcountGapIterator{n, n.nrSegments} + } + return fileRefcountGapIterator{sibling, gap.index - 1} + } + return gap + } + + p := n.parent + if p.nrSegments == 1 { + + left, right := p.children[0], p.children[1] + p.nrSegments = left.nrSegments + right.nrSegments + 1 + p.hasChildren = left.hasChildren + p.keys[left.nrSegments] = p.keys[0] + p.values[left.nrSegments] = p.values[0] + copy(p.keys[:left.nrSegments], left.keys[:left.nrSegments]) + copy(p.values[:left.nrSegments], left.values[:left.nrSegments]) + copy(p.keys[left.nrSegments+1:], right.keys[:right.nrSegments]) + copy(p.values[left.nrSegments+1:], right.values[:right.nrSegments]) + if left.hasChildren { + copy(p.children[:left.nrSegments+1], left.children[:left.nrSegments+1]) + copy(p.children[left.nrSegments+1:], right.children[:right.nrSegments+1]) + for i := 0; i < p.nrSegments+1; i++ { + p.children[i].parent = p + p.children[i].parentIndex = i + } + } else { + p.children[0] = nil + p.children[1] = nil + } + + if gap.node == left { + return fileRefcountGapIterator{p, gap.index} + } + if gap.node == right { + return fileRefcountGapIterator{p, gap.index + left.nrSegments + 1} + } + return gap + } + // Merge n and either sibling, along with the segment separating the + // two, into whichever of the two nodes comes first. This is the + // reverse of the non-root splitting case in + // node.rebalanceBeforeInsert. + var left, right *fileRefcountnode + if n.parentIndex > 0 { + left = n.prevSibling() + right = n + } else { + left = n + right = n.nextSibling() + } + + if gap.node == right { + gap = fileRefcountGapIterator{left, gap.index + left.nrSegments + 1} + } + left.keys[left.nrSegments] = p.keys[left.parentIndex] + left.values[left.nrSegments] = p.values[left.parentIndex] + copy(left.keys[left.nrSegments+1:], right.keys[:right.nrSegments]) + copy(left.values[left.nrSegments+1:], right.values[:right.nrSegments]) + if left.hasChildren { + copy(left.children[left.nrSegments+1:], right.children[:right.nrSegments+1]) + for i := left.nrSegments + 1; i < left.nrSegments+right.nrSegments+2; i++ { + left.children[i].parent = left + left.children[i].parentIndex = i + } + } + left.nrSegments += right.nrSegments + 1 + copy(p.keys[left.parentIndex:], p.keys[left.parentIndex+1:p.nrSegments]) + copy(p.values[left.parentIndex:], p.values[left.parentIndex+1:p.nrSegments]) + fileRefcountSetFunctions{}.ClearValue(&p.values[p.nrSegments-1]) + copy(p.children[left.parentIndex+1:], p.children[left.parentIndex+2:p.nrSegments+1]) + for i := 0; i < p.nrSegments; i++ { + p.children[i].parentIndex = i + } + p.children[p.nrSegments] = nil + p.nrSegments-- + + if fileRefcounttrackGaps != 0 { + left.updateMaxGapLocal() + } + + n = p + } +} + +// updateMaxGapLeaf updates maxGap bottom-up from the calling leaf until no +// necessary update. +// +// Preconditions: n must be a leaf node, trackGaps must be 1. +func (n *fileRefcountnode) updateMaxGapLeaf() { + if n.hasChildren { + panic(fmt.Sprintf("updateMaxGapLeaf should always be called on leaf node: %v", n)) + } + max := n.calculateMaxGapLeaf() + if max == n.maxGap.Get() { + + return + } + oldMax := n.maxGap.Get() + n.maxGap.Set(max) + if max > oldMax { + + for p := n.parent; p != nil; p = p.parent { + if p.maxGap.Get() >= max { + + break + } + + p.maxGap.Set(max) + } + return + } + + for p := n.parent; p != nil; p = p.parent { + if p.maxGap.Get() > oldMax { + + break + } + + parentNewMax := p.calculateMaxGapInternal() + if p.maxGap.Get() == parentNewMax { + + break + } + + p.maxGap.Set(parentNewMax) + } +} + +// updateMaxGapLocal updates maxGap of the calling node solely with no +// propagation to ancestor nodes. +// +// Precondition: trackGaps must be 1. +func (n *fileRefcountnode) updateMaxGapLocal() { + if !n.hasChildren { + + n.maxGap.Set(n.calculateMaxGapLeaf()) + } else { + + n.maxGap.Set(n.calculateMaxGapInternal()) + } +} + +// calculateMaxGapLeaf iterates the gaps within a leaf node and calculate the +// max. +// +// Preconditions: n must be a leaf node. +func (n *fileRefcountnode) calculateMaxGapLeaf() uint64 { + max := fileRefcountGapIterator{n, 0}.Range().Length() + for i := 1; i <= n.nrSegments; i++ { + if current := (fileRefcountGapIterator{n, i}).Range().Length(); current > max { + max = current + } + } + return max +} + +// calculateMaxGapInternal iterates children's maxGap within an internal node n +// and calculate the max. +// +// Preconditions: n must be a non-leaf node. +func (n *fileRefcountnode) calculateMaxGapInternal() uint64 { + max := n.children[0].maxGap.Get() + for i := 1; i <= n.nrSegments; i++ { + if current := n.children[i].maxGap.Get(); current > max { + max = current + } + } + return max +} + +// searchFirstLargeEnoughGap returns the first gap having at least minSize length +// in the subtree rooted by n. If not found, return a terminal gap iterator. +func (n *fileRefcountnode) searchFirstLargeEnoughGap(minSize uint64) fileRefcountGapIterator { + if n.maxGap.Get() < minSize { + return fileRefcountGapIterator{} + } + if n.hasChildren { + for i := 0; i <= n.nrSegments; i++ { + if largeEnoughGap := n.children[i].searchFirstLargeEnoughGap(minSize); largeEnoughGap.Ok() { + return largeEnoughGap + } + } + } else { + for i := 0; i <= n.nrSegments; i++ { + currentGap := fileRefcountGapIterator{n, i} + if currentGap.Range().Length() >= minSize { + return currentGap + } + } + } + panic(fmt.Sprintf("invalid maxGap in %v", n)) +} + +// searchLastLargeEnoughGap returns the last gap having at least minSize length +// in the subtree rooted by n. If not found, return a terminal gap iterator. +func (n *fileRefcountnode) searchLastLargeEnoughGap(minSize uint64) fileRefcountGapIterator { + if n.maxGap.Get() < minSize { + return fileRefcountGapIterator{} + } + if n.hasChildren { + for i := n.nrSegments; i >= 0; i-- { + if largeEnoughGap := n.children[i].searchLastLargeEnoughGap(minSize); largeEnoughGap.Ok() { + return largeEnoughGap + } + } + } else { + for i := n.nrSegments; i >= 0; i-- { + currentGap := fileRefcountGapIterator{n, i} + if currentGap.Range().Length() >= minSize { + return currentGap + } + } + } + panic(fmt.Sprintf("invalid maxGap in %v", n)) +} + +// A Iterator is conceptually one of: +// +// - A pointer to a segment in a set; or +// +// - A terminal iterator, which is a sentinel indicating that the end of +// iteration has been reached. +// +// Iterators are copyable values and are meaningfully equality-comparable. The +// zero value of Iterator is a terminal iterator. +// +// Unless otherwise specified, any mutation of a set invalidates all existing +// iterators into the set. +type fileRefcountIterator struct { + // node is the node containing the iterated segment. If the iterator is + // terminal, node is nil. + node *fileRefcountnode + + // index is the index of the segment in node.keys/values. + index int +} + +// Ok returns true if the iterator is not terminal. All other methods are only +// valid for non-terminal iterators. +func (seg fileRefcountIterator) Ok() bool { + return seg.node != nil +} + +// Range returns the iterated segment's range key. +func (seg fileRefcountIterator) Range() __generics_imported0.FileRange { + return seg.node.keys[seg.index] +} + +// Start is equivalent to Range().Start, but should be preferred if only the +// start of the range is needed. +func (seg fileRefcountIterator) Start() uint64 { + return seg.node.keys[seg.index].Start +} + +// End is equivalent to Range().End, but should be preferred if only the end of +// the range is needed. +func (seg fileRefcountIterator) End() uint64 { + return seg.node.keys[seg.index].End +} + +// SetRangeUnchecked mutates the iterated segment's range key. This operation +// does not invalidate any iterators. +// +// Preconditions: +// * r.Length() > 0. +// * The new range must not overlap an existing one: +// * If seg.NextSegment().Ok(), then r.end <= seg.NextSegment().Start(). +// * If seg.PrevSegment().Ok(), then r.start >= seg.PrevSegment().End(). +func (seg fileRefcountIterator) SetRangeUnchecked(r __generics_imported0.FileRange) { + seg.node.keys[seg.index] = r +} + +// SetRange mutates the iterated segment's range key. If the new range would +// cause the iterated segment to overlap another segment, or if the new range +// is invalid, SetRange panics. This operation does not invalidate any +// iterators. +func (seg fileRefcountIterator) SetRange(r __generics_imported0.FileRange) { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + if prev := seg.PrevSegment(); prev.Ok() && r.Start < prev.End() { + panic(fmt.Sprintf("new segment range %v overlaps segment range %v", r, prev.Range())) + } + if next := seg.NextSegment(); next.Ok() && r.End > next.Start() { + panic(fmt.Sprintf("new segment range %v overlaps segment range %v", r, next.Range())) + } + seg.SetRangeUnchecked(r) +} + +// SetStartUnchecked mutates the iterated segment's start. This operation does +// not invalidate any iterators. +// +// Preconditions: The new start must be valid: +// * start < seg.End() +// * If seg.PrevSegment().Ok(), then start >= seg.PrevSegment().End(). +func (seg fileRefcountIterator) SetStartUnchecked(start uint64) { + seg.node.keys[seg.index].Start = start +} + +// SetStart mutates the iterated segment's start. If the new start value would +// cause the iterated segment to overlap another segment, or would result in an +// invalid range, SetStart panics. This operation does not invalidate any +// iterators. +func (seg fileRefcountIterator) SetStart(start uint64) { + if start >= seg.End() { + panic(fmt.Sprintf("new start %v would invalidate segment range %v", start, seg.Range())) + } + if prev := seg.PrevSegment(); prev.Ok() && start < prev.End() { + panic(fmt.Sprintf("new start %v would cause segment range %v to overlap segment range %v", start, seg.Range(), prev.Range())) + } + seg.SetStartUnchecked(start) +} + +// SetEndUnchecked mutates the iterated segment's end. This operation does not +// invalidate any iterators. +// +// Preconditions: The new end must be valid: +// * end > seg.Start(). +// * If seg.NextSegment().Ok(), then end <= seg.NextSegment().Start(). +func (seg fileRefcountIterator) SetEndUnchecked(end uint64) { + seg.node.keys[seg.index].End = end +} + +// SetEnd mutates the iterated segment's end. If the new end value would cause +// the iterated segment to overlap another segment, or would result in an +// invalid range, SetEnd panics. This operation does not invalidate any +// iterators. +func (seg fileRefcountIterator) SetEnd(end uint64) { + if end <= seg.Start() { + panic(fmt.Sprintf("new end %v would invalidate segment range %v", end, seg.Range())) + } + if next := seg.NextSegment(); next.Ok() && end > next.Start() { + panic(fmt.Sprintf("new end %v would cause segment range %v to overlap segment range %v", end, seg.Range(), next.Range())) + } + seg.SetEndUnchecked(end) +} + +// Value returns a copy of the iterated segment's value. +func (seg fileRefcountIterator) Value() int32 { + return seg.node.values[seg.index] +} + +// ValuePtr returns a pointer to the iterated segment's value. The pointer is +// invalidated if the iterator is invalidated. This operation does not +// invalidate any iterators. +func (seg fileRefcountIterator) ValuePtr() *int32 { + return &seg.node.values[seg.index] +} + +// SetValue mutates the iterated segment's value. This operation does not +// invalidate any iterators. +func (seg fileRefcountIterator) SetValue(val int32) { + seg.node.values[seg.index] = val +} + +// PrevSegment returns the iterated segment's predecessor. If there is no +// preceding segment, PrevSegment returns a terminal iterator. +func (seg fileRefcountIterator) PrevSegment() fileRefcountIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index].lastSegment() + } + if seg.index > 0 { + return fileRefcountIterator{seg.node, seg.index - 1} + } + if seg.node.parent == nil { + return fileRefcountIterator{} + } + return fileRefcountsegmentBeforePosition(seg.node.parent, seg.node.parentIndex) +} + +// NextSegment returns the iterated segment's successor. If there is no +// succeeding segment, NextSegment returns a terminal iterator. +func (seg fileRefcountIterator) NextSegment() fileRefcountIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index+1].firstSegment() + } + if seg.index < seg.node.nrSegments-1 { + return fileRefcountIterator{seg.node, seg.index + 1} + } + if seg.node.parent == nil { + return fileRefcountIterator{} + } + return fileRefcountsegmentAfterPosition(seg.node.parent, seg.node.parentIndex) +} + +// PrevGap returns the gap immediately before the iterated segment. +func (seg fileRefcountIterator) PrevGap() fileRefcountGapIterator { + if seg.node.hasChildren { + + return seg.node.children[seg.index].lastSegment().NextGap() + } + return fileRefcountGapIterator{seg.node, seg.index} +} + +// NextGap returns the gap immediately after the iterated segment. +func (seg fileRefcountIterator) NextGap() fileRefcountGapIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index+1].firstSegment().PrevGap() + } + return fileRefcountGapIterator{seg.node, seg.index + 1} +} + +// PrevNonEmpty returns the iterated segment's predecessor if it is adjacent, +// or the gap before the iterated segment otherwise. If seg.Start() == +// Functions.MinKey(), PrevNonEmpty will return two terminal iterators. +// Otherwise, exactly one of the iterators returned by PrevNonEmpty will be +// non-terminal. +func (seg fileRefcountIterator) PrevNonEmpty() (fileRefcountIterator, fileRefcountGapIterator) { + gap := seg.PrevGap() + if gap.Range().Length() != 0 { + return fileRefcountIterator{}, gap + } + return gap.PrevSegment(), fileRefcountGapIterator{} +} + +// NextNonEmpty returns the iterated segment's successor if it is adjacent, or +// the gap after the iterated segment otherwise. If seg.End() == +// Functions.MaxKey(), NextNonEmpty will return two terminal iterators. +// Otherwise, exactly one of the iterators returned by NextNonEmpty will be +// non-terminal. +func (seg fileRefcountIterator) NextNonEmpty() (fileRefcountIterator, fileRefcountGapIterator) { + gap := seg.NextGap() + if gap.Range().Length() != 0 { + return fileRefcountIterator{}, gap + } + return gap.NextSegment(), fileRefcountGapIterator{} +} + +// A GapIterator is conceptually one of: +// +// - A pointer to a position between two segments, before the first segment, or +// after the last segment in a set, called a *gap*; or +// +// - A terminal iterator, which is a sentinel indicating that the end of +// iteration has been reached. +// +// Note that the gap between two adjacent segments exists (iterators to it are +// non-terminal), but has a length of zero. GapIterator.IsEmpty returns true +// for such gaps. An empty set contains a single gap, spanning the entire range +// of the set's keys. +// +// GapIterators are copyable values and are meaningfully equality-comparable. +// The zero value of GapIterator is a terminal iterator. +// +// Unless otherwise specified, any mutation of a set invalidates all existing +// iterators into the set. +type fileRefcountGapIterator struct { + // The representation of a GapIterator is identical to that of an Iterator, + // except that index corresponds to positions between segments in the same + // way as for node.children (see comment for node.nrSegments). + node *fileRefcountnode + index int +} + +// Ok returns true if the iterator is not terminal. All other methods are only +// valid for non-terminal iterators. +func (gap fileRefcountGapIterator) Ok() bool { + return gap.node != nil +} + +// Range returns the range spanned by the iterated gap. +func (gap fileRefcountGapIterator) Range() __generics_imported0.FileRange { + return __generics_imported0.FileRange{gap.Start(), gap.End()} +} + +// Start is equivalent to Range().Start, but should be preferred if only the +// start of the range is needed. +func (gap fileRefcountGapIterator) Start() uint64 { + if ps := gap.PrevSegment(); ps.Ok() { + return ps.End() + } + return fileRefcountSetFunctions{}.MinKey() +} + +// End is equivalent to Range().End, but should be preferred if only the end of +// the range is needed. +func (gap fileRefcountGapIterator) End() uint64 { + if ns := gap.NextSegment(); ns.Ok() { + return ns.Start() + } + return fileRefcountSetFunctions{}.MaxKey() +} + +// IsEmpty returns true if the iterated gap is empty (that is, the "gap" is +// between two adjacent segments.) +func (gap fileRefcountGapIterator) IsEmpty() bool { + return gap.Range().Length() == 0 +} + +// PrevSegment returns the segment immediately before the iterated gap. If no +// such segment exists, PrevSegment returns a terminal iterator. +func (gap fileRefcountGapIterator) PrevSegment() fileRefcountIterator { + return fileRefcountsegmentBeforePosition(gap.node, gap.index) +} + +// NextSegment returns the segment immediately after the iterated gap. If no +// such segment exists, NextSegment returns a terminal iterator. +func (gap fileRefcountGapIterator) NextSegment() fileRefcountIterator { + return fileRefcountsegmentAfterPosition(gap.node, gap.index) +} + +// PrevGap returns the iterated gap's predecessor. If no such gap exists, +// PrevGap returns a terminal iterator. +func (gap fileRefcountGapIterator) PrevGap() fileRefcountGapIterator { + seg := gap.PrevSegment() + if !seg.Ok() { + return fileRefcountGapIterator{} + } + return seg.PrevGap() +} + +// NextGap returns the iterated gap's successor. If no such gap exists, NextGap +// returns a terminal iterator. +func (gap fileRefcountGapIterator) NextGap() fileRefcountGapIterator { + seg := gap.NextSegment() + if !seg.Ok() { + return fileRefcountGapIterator{} + } + return seg.NextGap() +} + +// NextLargeEnoughGap returns the iterated gap's first next gap with larger +// length than minSize. If not found, return a terminal gap iterator (does NOT +// include this gap itself). +// +// Precondition: trackGaps must be 1. +func (gap fileRefcountGapIterator) NextLargeEnoughGap(minSize uint64) fileRefcountGapIterator { + if fileRefcounttrackGaps != 1 { + panic("set is not tracking gaps") + } + if gap.node != nil && gap.node.hasChildren && gap.index == gap.node.nrSegments { + + gap.node = gap.NextSegment().node + gap.index = 0 + return gap.nextLargeEnoughGapHelper(minSize) + } + return gap.nextLargeEnoughGapHelper(minSize) +} + +// nextLargeEnoughGapHelper is the helper function used by NextLargeEnoughGap +// to do the real recursions. +// +// Preconditions: gap is NOT the trailing gap of a non-leaf node. +func (gap fileRefcountGapIterator) nextLargeEnoughGapHelper(minSize uint64) fileRefcountGapIterator { + + for gap.node != nil && + (gap.node.maxGap.Get() < minSize || (!gap.node.hasChildren && gap.index == gap.node.nrSegments)) { + gap.node, gap.index = gap.node.parent, gap.node.parentIndex + } + + if gap.node == nil { + return fileRefcountGapIterator{} + } + + gap.index++ + for gap.index <= gap.node.nrSegments { + if gap.node.hasChildren { + if largeEnoughGap := gap.node.children[gap.index].searchFirstLargeEnoughGap(minSize); largeEnoughGap.Ok() { + return largeEnoughGap + } + } else { + if gap.Range().Length() >= minSize { + return gap + } + } + gap.index++ + } + gap.node, gap.index = gap.node.parent, gap.node.parentIndex + if gap.node != nil && gap.index == gap.node.nrSegments { + + gap.node, gap.index = gap.node.parent, gap.node.parentIndex + } + return gap.nextLargeEnoughGapHelper(minSize) +} + +// PrevLargeEnoughGap returns the iterated gap's first prev gap with larger or +// equal length than minSize. If not found, return a terminal gap iterator +// (does NOT include this gap itself). +// +// Precondition: trackGaps must be 1. +func (gap fileRefcountGapIterator) PrevLargeEnoughGap(minSize uint64) fileRefcountGapIterator { + if fileRefcounttrackGaps != 1 { + panic("set is not tracking gaps") + } + if gap.node != nil && gap.node.hasChildren && gap.index == 0 { + + gap.node = gap.PrevSegment().node + gap.index = gap.node.nrSegments + return gap.prevLargeEnoughGapHelper(minSize) + } + return gap.prevLargeEnoughGapHelper(minSize) +} + +// prevLargeEnoughGapHelper is the helper function used by PrevLargeEnoughGap +// to do the real recursions. +// +// Preconditions: gap is NOT the first gap of a non-leaf node. +func (gap fileRefcountGapIterator) prevLargeEnoughGapHelper(minSize uint64) fileRefcountGapIterator { + + for gap.node != nil && + (gap.node.maxGap.Get() < minSize || (!gap.node.hasChildren && gap.index == 0)) { + gap.node, gap.index = gap.node.parent, gap.node.parentIndex + } + + if gap.node == nil { + return fileRefcountGapIterator{} + } + + gap.index-- + for gap.index >= 0 { + if gap.node.hasChildren { + if largeEnoughGap := gap.node.children[gap.index].searchLastLargeEnoughGap(minSize); largeEnoughGap.Ok() { + return largeEnoughGap + } + } else { + if gap.Range().Length() >= minSize { + return gap + } + } + gap.index-- + } + gap.node, gap.index = gap.node.parent, gap.node.parentIndex + if gap.node != nil && gap.index == 0 { + + gap.node, gap.index = gap.node.parent, gap.node.parentIndex + } + return gap.prevLargeEnoughGapHelper(minSize) +} + +// segmentBeforePosition returns the predecessor segment of the position given +// by n.children[i], which may or may not contain a child. If no such segment +// exists, segmentBeforePosition returns a terminal iterator. +func fileRefcountsegmentBeforePosition(n *fileRefcountnode, i int) fileRefcountIterator { + for i == 0 { + if n.parent == nil { + return fileRefcountIterator{} + } + n, i = n.parent, n.parentIndex + } + return fileRefcountIterator{n, i - 1} +} + +// segmentAfterPosition returns the successor segment of the position given by +// n.children[i], which may or may not contain a child. If no such segment +// exists, segmentAfterPosition returns a terminal iterator. +func fileRefcountsegmentAfterPosition(n *fileRefcountnode, i int) fileRefcountIterator { + for i == n.nrSegments { + if n.parent == nil { + return fileRefcountIterator{} + } + n, i = n.parent, n.parentIndex + } + return fileRefcountIterator{n, i} +} + +func fileRefcountzeroValueSlice(slice []int32) { + + for i := range slice { + fileRefcountSetFunctions{}.ClearValue(&slice[i]) + } +} + +func fileRefcountzeroNodeSlice(slice []*fileRefcountnode) { + for i := range slice { + slice[i] = nil + } +} + +// String stringifies a Set for debugging. +func (s *fileRefcountSet) String() string { + return s.root.String() +} + +// String stringifies a node (and all of its children) for debugging. +func (n *fileRefcountnode) String() string { + var buf bytes.Buffer + n.writeDebugString(&buf, "") + return buf.String() +} + +func (n *fileRefcountnode) writeDebugString(buf *bytes.Buffer, prefix string) { + if n.hasChildren != (n.nrSegments > 0 && n.children[0] != nil) { + buf.WriteString(prefix) + buf.WriteString(fmt.Sprintf("WARNING: inconsistent value of hasChildren: got %v, want %v\n", n.hasChildren, !n.hasChildren)) + } + for i := 0; i < n.nrSegments; i++ { + if child := n.children[i]; child != nil { + cprefix := fmt.Sprintf("%s- % 3d ", prefix, i) + if child.parent != n || child.parentIndex != i { + buf.WriteString(cprefix) + buf.WriteString(fmt.Sprintf("WARNING: inconsistent linkage to parent: got (%p, %d), want (%p, %d)\n", child.parent, child.parentIndex, n, i)) + } + child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, i)) + } + buf.WriteString(prefix) + if n.hasChildren { + if fileRefcounttrackGaps != 0 { + buf.WriteString(fmt.Sprintf("- % 3d: %v => %v, maxGap: %d\n", i, n.keys[i], n.values[i], n.maxGap.Get())) + } else { + buf.WriteString(fmt.Sprintf("- % 3d: %v => %v\n", i, n.keys[i], n.values[i])) + } + } else { + buf.WriteString(fmt.Sprintf("- % 3d: %v => %v\n", i, n.keys[i], n.values[i])) + } + } + if child := n.children[n.nrSegments]; child != nil { + child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, n.nrSegments)) + } +} + +// SegmentDataSlices represents segments from a set as slices of start, end, and +// values. SegmentDataSlices is primarily used as an intermediate representation +// for save/restore and the layout here is optimized for that. +// +// +stateify savable +type fileRefcountSegmentDataSlices struct { + Start []uint64 + End []uint64 + Values []int32 +} + +// ExportSortedSlices returns a copy of all segments in the given set, in +// ascending key order. +func (s *fileRefcountSet) ExportSortedSlices() *fileRefcountSegmentDataSlices { + var sds fileRefcountSegmentDataSlices + for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + sds.Start = append(sds.Start, seg.Start()) + sds.End = append(sds.End, seg.End()) + sds.Values = append(sds.Values, seg.Value()) + } + sds.Start = sds.Start[:len(sds.Start):len(sds.Start)] + sds.End = sds.End[:len(sds.End):len(sds.End)] + sds.Values = sds.Values[:len(sds.Values):len(sds.Values)] + return &sds +} + +// ImportSortedSlices initializes the given set from the given slice. +// +// Preconditions: +// * s must be empty. +// * sds must represent a valid set (the segments in sds must have valid +// lengths that do not overlap). +// * The segments in sds must be sorted in ascending key order. +func (s *fileRefcountSet) ImportSortedSlices(sds *fileRefcountSegmentDataSlices) error { + if !s.IsEmpty() { + return fmt.Errorf("cannot import into non-empty set %v", s) + } + gap := s.FirstGap() + for i := range sds.Start { + r := __generics_imported0.FileRange{sds.Start[i], sds.End[i]} + if !gap.Range().IsSupersetOf(r) { + return fmt.Errorf("segment overlaps a preceding segment or is incorrectly sorted: [%d, %d) => %v", sds.Start[i], sds.End[i], sds.Values[i]) + } + gap = s.InsertWithoutMerging(gap, r, sds.Values[i]).NextGap() + } + return nil +} + +// segmentTestCheck returns an error if s is incorrectly sorted, does not +// contain exactly expectedSegments segments, or contains a segment which +// fails the passed check. +// +// This should be used only for testing, and has been added to this package for +// templating convenience. +func (s *fileRefcountSet) segmentTestCheck(expectedSegments int, segFunc func(int, __generics_imported0.FileRange, int32) error) error { + havePrev := false + prev := uint64(0) + nrSegments := 0 + for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + next := seg.Start() + if havePrev && prev >= next { + return fmt.Errorf("incorrect order: key %d (segment %d) >= key %d (segment %d)", prev, nrSegments-1, next, nrSegments) + } + if segFunc != nil { + if err := segFunc(nrSegments, seg.Range(), seg.Value()); err != nil { + return err + } + } + prev = next + havePrev = true + nrSegments++ + } + if nrSegments != expectedSegments { + return fmt.Errorf("incorrect number of segments: got %d, wanted %d", nrSegments, expectedSegments) + } + return nil +} + +// countSegments counts the number of segments in the set. +// +// Similar to Check, this should only be used for testing. +func (s *fileRefcountSet) countSegments() (segments int) { + for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + segments++ + } + return segments +} +func (s *fileRefcountSet) saveRoot() *fileRefcountSegmentDataSlices { + return s.ExportSortedSlices() +} + +func (s *fileRefcountSet) loadRoot(sds *fileRefcountSegmentDataSlices) { + if err := s.ImportSortedSlices(sds); err != nil { + panic(err) + } +} diff --git a/pkg/sentry/mm/io_list.go b/pkg/sentry/mm/io_list.go new file mode 100644 index 000000000..9a54e60be --- /dev/null +++ b/pkg/sentry/mm/io_list.go @@ -0,0 +1,221 @@ +package mm + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type ioElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (ioElementMapper) linkerFor(elem *ioResult) *ioResult { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type ioList struct { + head *ioResult + tail *ioResult +} + +// Reset resets list l to the empty state. +func (l *ioList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +// +//go:nosplit +func (l *ioList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +// +//go:nosplit +func (l *ioList) Front() *ioResult { + return l.head +} + +// Back returns the last element of list l or nil. +// +//go:nosplit +func (l *ioList) Back() *ioResult { + return l.tail +} + +// Len returns the number of elements in the list. +// +// NOTE: This is an O(n) operation. +// +//go:nosplit +func (l *ioList) Len() (count int) { + for e := l.Front(); e != nil; e = (ioElementMapper{}.linkerFor(e)).Next() { + count++ + } + return count +} + +// PushFront inserts the element e at the front of list l. +// +//go:nosplit +func (l *ioList) PushFront(e *ioResult) { + linker := ioElementMapper{}.linkerFor(e) + linker.SetNext(l.head) + linker.SetPrev(nil) + if l.head != nil { + ioElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +// +//go:nosplit +func (l *ioList) PushBack(e *ioResult) { + linker := ioElementMapper{}.linkerFor(e) + linker.SetNext(nil) + linker.SetPrev(l.tail) + if l.tail != nil { + ioElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +// +//go:nosplit +func (l *ioList) PushBackList(m *ioList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + ioElementMapper{}.linkerFor(l.tail).SetNext(m.head) + ioElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +// +//go:nosplit +func (l *ioList) InsertAfter(b, e *ioResult) { + bLinker := ioElementMapper{}.linkerFor(b) + eLinker := ioElementMapper{}.linkerFor(e) + + a := bLinker.Next() + + eLinker.SetNext(a) + eLinker.SetPrev(b) + bLinker.SetNext(e) + + if a != nil { + ioElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +// +//go:nosplit +func (l *ioList) InsertBefore(a, e *ioResult) { + aLinker := ioElementMapper{}.linkerFor(a) + eLinker := ioElementMapper{}.linkerFor(e) + + b := aLinker.Prev() + eLinker.SetNext(a) + eLinker.SetPrev(b) + aLinker.SetPrev(e) + + if b != nil { + ioElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +// +//go:nosplit +func (l *ioList) Remove(e *ioResult) { + linker := ioElementMapper{}.linkerFor(e) + prev := linker.Prev() + next := linker.Next() + + if prev != nil { + ioElementMapper{}.linkerFor(prev).SetNext(next) + } else if l.head == e { + l.head = next + } + + if next != nil { + ioElementMapper{}.linkerFor(next).SetPrev(prev) + } else if l.tail == e { + l.tail = prev + } + + linker.SetNext(nil) + linker.SetPrev(nil) +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type ioEntry struct { + next *ioResult + prev *ioResult +} + +// Next returns the entry that follows e in the list. +// +//go:nosplit +func (e *ioEntry) Next() *ioResult { + return e.next +} + +// Prev returns the entry that precedes e in the list. +// +//go:nosplit +func (e *ioEntry) Prev() *ioResult { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +// +//go:nosplit +func (e *ioEntry) SetNext(elem *ioResult) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +// +//go:nosplit +func (e *ioEntry) SetPrev(elem *ioResult) { + e.prev = elem +} diff --git a/pkg/sentry/mm/mm_state_autogen.go b/pkg/sentry/mm/mm_state_autogen.go new file mode 100644 index 000000000..cf5992fe1 --- /dev/null +++ b/pkg/sentry/mm/mm_state_autogen.go @@ -0,0 +1,812 @@ +// automatically generated by stateify. + +package mm + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (a *aioManager) StateTypeName() string { + return "pkg/sentry/mm.aioManager" +} + +func (a *aioManager) StateFields() []string { + return []string{ + "contexts", + } +} + +func (a *aioManager) beforeSave() {} + +// +checklocksignore +func (a *aioManager) StateSave(stateSinkObject state.Sink) { + a.beforeSave() + stateSinkObject.Save(0, &a.contexts) +} + +func (a *aioManager) afterLoad() {} + +// +checklocksignore +func (a *aioManager) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &a.contexts) +} + +func (i *ioResult) StateTypeName() string { + return "pkg/sentry/mm.ioResult" +} + +func (i *ioResult) StateFields() []string { + return []string{ + "data", + "ioEntry", + } +} + +func (i *ioResult) beforeSave() {} + +// +checklocksignore +func (i *ioResult) StateSave(stateSinkObject state.Sink) { + i.beforeSave() + stateSinkObject.Save(0, &i.data) + stateSinkObject.Save(1, &i.ioEntry) +} + +func (i *ioResult) afterLoad() {} + +// +checklocksignore +func (i *ioResult) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &i.data) + stateSourceObject.Load(1, &i.ioEntry) +} + +func (ctx *AIOContext) StateTypeName() string { + return "pkg/sentry/mm.AIOContext" +} + +func (ctx *AIOContext) StateFields() []string { + return []string{ + "results", + "maxOutstanding", + "outstanding", + } +} + +func (ctx *AIOContext) beforeSave() {} + +// +checklocksignore +func (ctx *AIOContext) StateSave(stateSinkObject state.Sink) { + ctx.beforeSave() + if !state.IsZeroValue(&ctx.dead) { + state.Failf("dead is %#v, expected zero", &ctx.dead) + } + stateSinkObject.Save(0, &ctx.results) + stateSinkObject.Save(1, &ctx.maxOutstanding) + stateSinkObject.Save(2, &ctx.outstanding) +} + +// +checklocksignore +func (ctx *AIOContext) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &ctx.results) + stateSourceObject.Load(1, &ctx.maxOutstanding) + stateSourceObject.Load(2, &ctx.outstanding) + stateSourceObject.AfterLoad(ctx.afterLoad) +} + +func (m *aioMappable) StateTypeName() string { + return "pkg/sentry/mm.aioMappable" +} + +func (m *aioMappable) StateFields() []string { + return []string{ + "aioMappableRefs", + "mfp", + "fr", + } +} + +func (m *aioMappable) beforeSave() {} + +// +checklocksignore +func (m *aioMappable) StateSave(stateSinkObject state.Sink) { + m.beforeSave() + stateSinkObject.Save(0, &m.aioMappableRefs) + stateSinkObject.Save(1, &m.mfp) + stateSinkObject.Save(2, &m.fr) +} + +func (m *aioMappable) afterLoad() {} + +// +checklocksignore +func (m *aioMappable) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &m.aioMappableRefs) + stateSourceObject.Load(1, &m.mfp) + stateSourceObject.Load(2, &m.fr) +} + +func (r *aioMappableRefs) StateTypeName() string { + return "pkg/sentry/mm.aioMappableRefs" +} + +func (r *aioMappableRefs) StateFields() []string { + return []string{ + "refCount", + } +} + +func (r *aioMappableRefs) beforeSave() {} + +// +checklocksignore +func (r *aioMappableRefs) StateSave(stateSinkObject state.Sink) { + r.beforeSave() + stateSinkObject.Save(0, &r.refCount) +} + +// +checklocksignore +func (r *aioMappableRefs) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &r.refCount) + stateSourceObject.AfterLoad(r.afterLoad) +} + +func (s *fileRefcountSet) StateTypeName() string { + return "pkg/sentry/mm.fileRefcountSet" +} + +func (s *fileRefcountSet) StateFields() []string { + return []string{ + "root", + } +} + +func (s *fileRefcountSet) beforeSave() {} + +// +checklocksignore +func (s *fileRefcountSet) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + var rootValue *fileRefcountSegmentDataSlices + rootValue = s.saveRoot() + stateSinkObject.SaveValue(0, rootValue) +} + +func (s *fileRefcountSet) afterLoad() {} + +// +checklocksignore +func (s *fileRefcountSet) StateLoad(stateSourceObject state.Source) { + stateSourceObject.LoadValue(0, new(*fileRefcountSegmentDataSlices), func(y interface{}) { s.loadRoot(y.(*fileRefcountSegmentDataSlices)) }) +} + +func (n *fileRefcountnode) StateTypeName() string { + return "pkg/sentry/mm.fileRefcountnode" +} + +func (n *fileRefcountnode) StateFields() []string { + return []string{ + "nrSegments", + "parent", + "parentIndex", + "hasChildren", + "maxGap", + "keys", + "values", + "children", + } +} + +func (n *fileRefcountnode) beforeSave() {} + +// +checklocksignore +func (n *fileRefcountnode) StateSave(stateSinkObject state.Sink) { + n.beforeSave() + stateSinkObject.Save(0, &n.nrSegments) + stateSinkObject.Save(1, &n.parent) + stateSinkObject.Save(2, &n.parentIndex) + stateSinkObject.Save(3, &n.hasChildren) + stateSinkObject.Save(4, &n.maxGap) + stateSinkObject.Save(5, &n.keys) + stateSinkObject.Save(6, &n.values) + stateSinkObject.Save(7, &n.children) +} + +func (n *fileRefcountnode) afterLoad() {} + +// +checklocksignore +func (n *fileRefcountnode) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &n.nrSegments) + stateSourceObject.Load(1, &n.parent) + stateSourceObject.Load(2, &n.parentIndex) + stateSourceObject.Load(3, &n.hasChildren) + stateSourceObject.Load(4, &n.maxGap) + stateSourceObject.Load(5, &n.keys) + stateSourceObject.Load(6, &n.values) + stateSourceObject.Load(7, &n.children) +} + +func (f *fileRefcountSegmentDataSlices) StateTypeName() string { + return "pkg/sentry/mm.fileRefcountSegmentDataSlices" +} + +func (f *fileRefcountSegmentDataSlices) StateFields() []string { + return []string{ + "Start", + "End", + "Values", + } +} + +func (f *fileRefcountSegmentDataSlices) beforeSave() {} + +// +checklocksignore +func (f *fileRefcountSegmentDataSlices) StateSave(stateSinkObject state.Sink) { + f.beforeSave() + stateSinkObject.Save(0, &f.Start) + stateSinkObject.Save(1, &f.End) + stateSinkObject.Save(2, &f.Values) +} + +func (f *fileRefcountSegmentDataSlices) afterLoad() {} + +// +checklocksignore +func (f *fileRefcountSegmentDataSlices) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &f.Start) + stateSourceObject.Load(1, &f.End) + stateSourceObject.Load(2, &f.Values) +} + +func (l *ioList) StateTypeName() string { + return "pkg/sentry/mm.ioList" +} + +func (l *ioList) StateFields() []string { + return []string{ + "head", + "tail", + } +} + +func (l *ioList) beforeSave() {} + +// +checklocksignore +func (l *ioList) StateSave(stateSinkObject state.Sink) { + l.beforeSave() + stateSinkObject.Save(0, &l.head) + stateSinkObject.Save(1, &l.tail) +} + +func (l *ioList) afterLoad() {} + +// +checklocksignore +func (l *ioList) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &l.head) + stateSourceObject.Load(1, &l.tail) +} + +func (e *ioEntry) StateTypeName() string { + return "pkg/sentry/mm.ioEntry" +} + +func (e *ioEntry) StateFields() []string { + return []string{ + "next", + "prev", + } +} + +func (e *ioEntry) beforeSave() {} + +// +checklocksignore +func (e *ioEntry) StateSave(stateSinkObject state.Sink) { + e.beforeSave() + stateSinkObject.Save(0, &e.next) + stateSinkObject.Save(1, &e.prev) +} + +func (e *ioEntry) afterLoad() {} + +// +checklocksignore +func (e *ioEntry) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &e.next) + stateSourceObject.Load(1, &e.prev) +} + +func (mm *MemoryManager) StateTypeName() string { + return "pkg/sentry/mm.MemoryManager" +} + +func (mm *MemoryManager) StateFields() []string { + return []string{ + "p", + "mfp", + "layout", + "privateRefs", + "users", + "vmas", + "brk", + "usageAS", + "lockedAS", + "dataAS", + "defMLockMode", + "pmas", + "curRSS", + "maxRSS", + "argv", + "envv", + "auxv", + "executable", + "dumpability", + "aioManager", + "sleepForActivation", + "vdsoSigReturnAddr", + "membarrierPrivateEnabled", + "membarrierRSeqEnabled", + } +} + +// +checklocksignore +func (mm *MemoryManager) StateSave(stateSinkObject state.Sink) { + mm.beforeSave() + if !state.IsZeroValue(&mm.active) { + state.Failf("active is %#v, expected zero", &mm.active) + } + if !state.IsZeroValue(&mm.captureInvalidations) { + state.Failf("captureInvalidations is %#v, expected zero", &mm.captureInvalidations) + } + stateSinkObject.Save(0, &mm.p) + stateSinkObject.Save(1, &mm.mfp) + stateSinkObject.Save(2, &mm.layout) + stateSinkObject.Save(3, &mm.privateRefs) + stateSinkObject.Save(4, &mm.users) + stateSinkObject.Save(5, &mm.vmas) + stateSinkObject.Save(6, &mm.brk) + stateSinkObject.Save(7, &mm.usageAS) + stateSinkObject.Save(8, &mm.lockedAS) + stateSinkObject.Save(9, &mm.dataAS) + stateSinkObject.Save(10, &mm.defMLockMode) + stateSinkObject.Save(11, &mm.pmas) + stateSinkObject.Save(12, &mm.curRSS) + stateSinkObject.Save(13, &mm.maxRSS) + stateSinkObject.Save(14, &mm.argv) + stateSinkObject.Save(15, &mm.envv) + stateSinkObject.Save(16, &mm.auxv) + stateSinkObject.Save(17, &mm.executable) + stateSinkObject.Save(18, &mm.dumpability) + stateSinkObject.Save(19, &mm.aioManager) + stateSinkObject.Save(20, &mm.sleepForActivation) + stateSinkObject.Save(21, &mm.vdsoSigReturnAddr) + stateSinkObject.Save(22, &mm.membarrierPrivateEnabled) + stateSinkObject.Save(23, &mm.membarrierRSeqEnabled) +} + +// +checklocksignore +func (mm *MemoryManager) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &mm.p) + stateSourceObject.Load(1, &mm.mfp) + stateSourceObject.Load(2, &mm.layout) + stateSourceObject.Load(3, &mm.privateRefs) + stateSourceObject.Load(4, &mm.users) + stateSourceObject.Load(5, &mm.vmas) + stateSourceObject.Load(6, &mm.brk) + stateSourceObject.Load(7, &mm.usageAS) + stateSourceObject.Load(8, &mm.lockedAS) + stateSourceObject.Load(9, &mm.dataAS) + stateSourceObject.Load(10, &mm.defMLockMode) + stateSourceObject.Load(11, &mm.pmas) + stateSourceObject.Load(12, &mm.curRSS) + stateSourceObject.Load(13, &mm.maxRSS) + stateSourceObject.Load(14, &mm.argv) + stateSourceObject.Load(15, &mm.envv) + stateSourceObject.Load(16, &mm.auxv) + stateSourceObject.Load(17, &mm.executable) + stateSourceObject.Load(18, &mm.dumpability) + stateSourceObject.Load(19, &mm.aioManager) + stateSourceObject.Load(20, &mm.sleepForActivation) + stateSourceObject.Load(21, &mm.vdsoSigReturnAddr) + stateSourceObject.Load(22, &mm.membarrierPrivateEnabled) + stateSourceObject.Load(23, &mm.membarrierRSeqEnabled) + stateSourceObject.AfterLoad(mm.afterLoad) +} + +func (vma *vma) StateTypeName() string { + return "pkg/sentry/mm.vma" +} + +func (vma *vma) StateFields() []string { + return []string{ + "mappable", + "off", + "realPerms", + "dontfork", + "mlockMode", + "numaPolicy", + "numaNodemask", + "id", + "hint", + } +} + +func (vma *vma) beforeSave() {} + +// +checklocksignore +func (vma *vma) StateSave(stateSinkObject state.Sink) { + vma.beforeSave() + var realPermsValue int + realPermsValue = vma.saveRealPerms() + stateSinkObject.SaveValue(2, realPermsValue) + stateSinkObject.Save(0, &vma.mappable) + stateSinkObject.Save(1, &vma.off) + stateSinkObject.Save(3, &vma.dontfork) + stateSinkObject.Save(4, &vma.mlockMode) + stateSinkObject.Save(5, &vma.numaPolicy) + stateSinkObject.Save(6, &vma.numaNodemask) + stateSinkObject.Save(7, &vma.id) + stateSinkObject.Save(8, &vma.hint) +} + +func (vma *vma) afterLoad() {} + +// +checklocksignore +func (vma *vma) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &vma.mappable) + stateSourceObject.Load(1, &vma.off) + stateSourceObject.Load(3, &vma.dontfork) + stateSourceObject.Load(4, &vma.mlockMode) + stateSourceObject.Load(5, &vma.numaPolicy) + stateSourceObject.Load(6, &vma.numaNodemask) + stateSourceObject.Load(7, &vma.id) + stateSourceObject.Load(8, &vma.hint) + stateSourceObject.LoadValue(2, new(int), func(y interface{}) { vma.loadRealPerms(y.(int)) }) +} + +func (p *pma) StateTypeName() string { + return "pkg/sentry/mm.pma" +} + +func (p *pma) StateFields() []string { + return []string{ + "off", + "translatePerms", + "effectivePerms", + "maxPerms", + "needCOW", + "private", + } +} + +func (p *pma) beforeSave() {} + +// +checklocksignore +func (p *pma) StateSave(stateSinkObject state.Sink) { + p.beforeSave() + stateSinkObject.Save(0, &p.off) + stateSinkObject.Save(1, &p.translatePerms) + stateSinkObject.Save(2, &p.effectivePerms) + stateSinkObject.Save(3, &p.maxPerms) + stateSinkObject.Save(4, &p.needCOW) + stateSinkObject.Save(5, &p.private) +} + +func (p *pma) afterLoad() {} + +// +checklocksignore +func (p *pma) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &p.off) + stateSourceObject.Load(1, &p.translatePerms) + stateSourceObject.Load(2, &p.effectivePerms) + stateSourceObject.Load(3, &p.maxPerms) + stateSourceObject.Load(4, &p.needCOW) + stateSourceObject.Load(5, &p.private) +} + +func (p *privateRefs) StateTypeName() string { + return "pkg/sentry/mm.privateRefs" +} + +func (p *privateRefs) StateFields() []string { + return []string{ + "refs", + } +} + +func (p *privateRefs) beforeSave() {} + +// +checklocksignore +func (p *privateRefs) StateSave(stateSinkObject state.Sink) { + p.beforeSave() + stateSinkObject.Save(0, &p.refs) +} + +func (p *privateRefs) afterLoad() {} + +// +checklocksignore +func (p *privateRefs) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &p.refs) +} + +func (s *pmaSet) StateTypeName() string { + return "pkg/sentry/mm.pmaSet" +} + +func (s *pmaSet) StateFields() []string { + return []string{ + "root", + } +} + +func (s *pmaSet) beforeSave() {} + +// +checklocksignore +func (s *pmaSet) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + var rootValue *pmaSegmentDataSlices + rootValue = s.saveRoot() + stateSinkObject.SaveValue(0, rootValue) +} + +func (s *pmaSet) afterLoad() {} + +// +checklocksignore +func (s *pmaSet) StateLoad(stateSourceObject state.Source) { + stateSourceObject.LoadValue(0, new(*pmaSegmentDataSlices), func(y interface{}) { s.loadRoot(y.(*pmaSegmentDataSlices)) }) +} + +func (n *pmanode) StateTypeName() string { + return "pkg/sentry/mm.pmanode" +} + +func (n *pmanode) StateFields() []string { + return []string{ + "nrSegments", + "parent", + "parentIndex", + "hasChildren", + "maxGap", + "keys", + "values", + "children", + } +} + +func (n *pmanode) beforeSave() {} + +// +checklocksignore +func (n *pmanode) StateSave(stateSinkObject state.Sink) { + n.beforeSave() + stateSinkObject.Save(0, &n.nrSegments) + stateSinkObject.Save(1, &n.parent) + stateSinkObject.Save(2, &n.parentIndex) + stateSinkObject.Save(3, &n.hasChildren) + stateSinkObject.Save(4, &n.maxGap) + stateSinkObject.Save(5, &n.keys) + stateSinkObject.Save(6, &n.values) + stateSinkObject.Save(7, &n.children) +} + +func (n *pmanode) afterLoad() {} + +// +checklocksignore +func (n *pmanode) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &n.nrSegments) + stateSourceObject.Load(1, &n.parent) + stateSourceObject.Load(2, &n.parentIndex) + stateSourceObject.Load(3, &n.hasChildren) + stateSourceObject.Load(4, &n.maxGap) + stateSourceObject.Load(5, &n.keys) + stateSourceObject.Load(6, &n.values) + stateSourceObject.Load(7, &n.children) +} + +func (p *pmaSegmentDataSlices) StateTypeName() string { + return "pkg/sentry/mm.pmaSegmentDataSlices" +} + +func (p *pmaSegmentDataSlices) StateFields() []string { + return []string{ + "Start", + "End", + "Values", + } +} + +func (p *pmaSegmentDataSlices) beforeSave() {} + +// +checklocksignore +func (p *pmaSegmentDataSlices) StateSave(stateSinkObject state.Sink) { + p.beforeSave() + stateSinkObject.Save(0, &p.Start) + stateSinkObject.Save(1, &p.End) + stateSinkObject.Save(2, &p.Values) +} + +func (p *pmaSegmentDataSlices) afterLoad() {} + +// +checklocksignore +func (p *pmaSegmentDataSlices) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &p.Start) + stateSourceObject.Load(1, &p.End) + stateSourceObject.Load(2, &p.Values) +} + +func (m *SpecialMappable) StateTypeName() string { + return "pkg/sentry/mm.SpecialMappable" +} + +func (m *SpecialMappable) StateFields() []string { + return []string{ + "SpecialMappableRefs", + "mfp", + "fr", + "name", + } +} + +func (m *SpecialMappable) beforeSave() {} + +// +checklocksignore +func (m *SpecialMappable) StateSave(stateSinkObject state.Sink) { + m.beforeSave() + stateSinkObject.Save(0, &m.SpecialMappableRefs) + stateSinkObject.Save(1, &m.mfp) + stateSinkObject.Save(2, &m.fr) + stateSinkObject.Save(3, &m.name) +} + +func (m *SpecialMappable) afterLoad() {} + +// +checklocksignore +func (m *SpecialMappable) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &m.SpecialMappableRefs) + stateSourceObject.Load(1, &m.mfp) + stateSourceObject.Load(2, &m.fr) + stateSourceObject.Load(3, &m.name) +} + +func (r *SpecialMappableRefs) StateTypeName() string { + return "pkg/sentry/mm.SpecialMappableRefs" +} + +func (r *SpecialMappableRefs) StateFields() []string { + return []string{ + "refCount", + } +} + +func (r *SpecialMappableRefs) beforeSave() {} + +// +checklocksignore +func (r *SpecialMappableRefs) StateSave(stateSinkObject state.Sink) { + r.beforeSave() + stateSinkObject.Save(0, &r.refCount) +} + +// +checklocksignore +func (r *SpecialMappableRefs) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &r.refCount) + stateSourceObject.AfterLoad(r.afterLoad) +} + +func (s *vmaSet) StateTypeName() string { + return "pkg/sentry/mm.vmaSet" +} + +func (s *vmaSet) StateFields() []string { + return []string{ + "root", + } +} + +func (s *vmaSet) beforeSave() {} + +// +checklocksignore +func (s *vmaSet) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + var rootValue *vmaSegmentDataSlices + rootValue = s.saveRoot() + stateSinkObject.SaveValue(0, rootValue) +} + +func (s *vmaSet) afterLoad() {} + +// +checklocksignore +func (s *vmaSet) StateLoad(stateSourceObject state.Source) { + stateSourceObject.LoadValue(0, new(*vmaSegmentDataSlices), func(y interface{}) { s.loadRoot(y.(*vmaSegmentDataSlices)) }) +} + +func (n *vmanode) StateTypeName() string { + return "pkg/sentry/mm.vmanode" +} + +func (n *vmanode) StateFields() []string { + return []string{ + "nrSegments", + "parent", + "parentIndex", + "hasChildren", + "maxGap", + "keys", + "values", + "children", + } +} + +func (n *vmanode) beforeSave() {} + +// +checklocksignore +func (n *vmanode) StateSave(stateSinkObject state.Sink) { + n.beforeSave() + stateSinkObject.Save(0, &n.nrSegments) + stateSinkObject.Save(1, &n.parent) + stateSinkObject.Save(2, &n.parentIndex) + stateSinkObject.Save(3, &n.hasChildren) + stateSinkObject.Save(4, &n.maxGap) + stateSinkObject.Save(5, &n.keys) + stateSinkObject.Save(6, &n.values) + stateSinkObject.Save(7, &n.children) +} + +func (n *vmanode) afterLoad() {} + +// +checklocksignore +func (n *vmanode) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &n.nrSegments) + stateSourceObject.Load(1, &n.parent) + stateSourceObject.Load(2, &n.parentIndex) + stateSourceObject.Load(3, &n.hasChildren) + stateSourceObject.Load(4, &n.maxGap) + stateSourceObject.Load(5, &n.keys) + stateSourceObject.Load(6, &n.values) + stateSourceObject.Load(7, &n.children) +} + +func (v *vmaSegmentDataSlices) StateTypeName() string { + return "pkg/sentry/mm.vmaSegmentDataSlices" +} + +func (v *vmaSegmentDataSlices) StateFields() []string { + return []string{ + "Start", + "End", + "Values", + } +} + +func (v *vmaSegmentDataSlices) beforeSave() {} + +// +checklocksignore +func (v *vmaSegmentDataSlices) StateSave(stateSinkObject state.Sink) { + v.beforeSave() + stateSinkObject.Save(0, &v.Start) + stateSinkObject.Save(1, &v.End) + stateSinkObject.Save(2, &v.Values) +} + +func (v *vmaSegmentDataSlices) afterLoad() {} + +// +checklocksignore +func (v *vmaSegmentDataSlices) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &v.Start) + stateSourceObject.Load(1, &v.End) + stateSourceObject.Load(2, &v.Values) +} + +func init() { + state.Register((*aioManager)(nil)) + state.Register((*ioResult)(nil)) + state.Register((*AIOContext)(nil)) + state.Register((*aioMappable)(nil)) + state.Register((*aioMappableRefs)(nil)) + state.Register((*fileRefcountSet)(nil)) + state.Register((*fileRefcountnode)(nil)) + state.Register((*fileRefcountSegmentDataSlices)(nil)) + state.Register((*ioList)(nil)) + state.Register((*ioEntry)(nil)) + state.Register((*MemoryManager)(nil)) + state.Register((*vma)(nil)) + state.Register((*pma)(nil)) + state.Register((*privateRefs)(nil)) + state.Register((*pmaSet)(nil)) + state.Register((*pmanode)(nil)) + state.Register((*pmaSegmentDataSlices)(nil)) + state.Register((*SpecialMappable)(nil)) + state.Register((*SpecialMappableRefs)(nil)) + state.Register((*vmaSet)(nil)) + state.Register((*vmanode)(nil)) + state.Register((*vmaSegmentDataSlices)(nil)) +} diff --git a/pkg/sentry/mm/mm_test.go b/pkg/sentry/mm/mm_test.go deleted file mode 100644 index 84cb8158d..000000000 --- a/pkg/sentry/mm/mm_test.go +++ /dev/null @@ -1,275 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package mm - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/errors/linuxerr" - "gvisor.dev/gvisor/pkg/hostarch" - "gvisor.dev/gvisor/pkg/sentry/arch" - "gvisor.dev/gvisor/pkg/sentry/contexttest" - "gvisor.dev/gvisor/pkg/sentry/limits" - "gvisor.dev/gvisor/pkg/sentry/memmap" - "gvisor.dev/gvisor/pkg/sentry/pgalloc" - "gvisor.dev/gvisor/pkg/sentry/platform" - "gvisor.dev/gvisor/pkg/usermem" -) - -func testMemoryManager(ctx context.Context) *MemoryManager { - p := platform.FromContext(ctx) - mfp := pgalloc.MemoryFileProviderFromContext(ctx) - mm := NewMemoryManager(p, mfp, false) - mm.layout = arch.MmapLayout{ - MinAddr: p.MinUserAddress(), - MaxAddr: p.MaxUserAddress(), - BottomUpBase: p.MinUserAddress(), - TopDownBase: p.MaxUserAddress(), - } - return mm -} - -func (mm *MemoryManager) realUsageAS() uint64 { - return uint64(mm.vmas.Span()) -} - -func TestUsageASUpdates(t *testing.T) { - ctx := contexttest.Context(t) - mm := testMemoryManager(ctx) - defer mm.DecUsers(ctx) - - addr, err := mm.MMap(ctx, memmap.MMapOpts{ - Length: 2 * hostarch.PageSize, - Private: true, - }) - if err != nil { - t.Fatalf("MMap got err %v want nil", err) - } - realUsage := mm.realUsageAS() - if mm.usageAS != realUsage { - t.Fatalf("usageAS believes %v bytes are mapped; %v bytes are actually mapped", mm.usageAS, realUsage) - } - - mm.MUnmap(ctx, addr, hostarch.PageSize) - realUsage = mm.realUsageAS() - if mm.usageAS != realUsage { - t.Fatalf("usageAS believes %v bytes are mapped; %v bytes are actually mapped", mm.usageAS, realUsage) - } -} - -func (mm *MemoryManager) realDataAS() uint64 { - var sz uint64 - for seg := mm.vmas.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { - vma := seg.Value() - if vma.isPrivateDataLocked() { - sz += uint64(seg.Range().Length()) - } - } - return sz -} - -func TestDataASUpdates(t *testing.T) { - ctx := contexttest.Context(t) - mm := testMemoryManager(ctx) - defer mm.DecUsers(ctx) - - addr, err := mm.MMap(ctx, memmap.MMapOpts{ - Length: 3 * hostarch.PageSize, - Private: true, - Perms: hostarch.Write, - MaxPerms: hostarch.AnyAccess, - }) - if err != nil { - t.Fatalf("MMap got err %v want nil", err) - } - if mm.dataAS == 0 { - t.Fatalf("dataAS is 0, wanted not 0") - } - realDataAS := mm.realDataAS() - if mm.dataAS != realDataAS { - t.Fatalf("dataAS believes %v bytes are mapped; %v bytes are actually mapped", mm.dataAS, realDataAS) - } - - mm.MUnmap(ctx, addr, hostarch.PageSize) - realDataAS = mm.realDataAS() - if mm.dataAS != realDataAS { - t.Fatalf("dataAS believes %v bytes are mapped; %v bytes are actually mapped", mm.dataAS, realDataAS) - } - - mm.MProtect(addr+hostarch.PageSize, hostarch.PageSize, hostarch.Read, false) - realDataAS = mm.realDataAS() - if mm.dataAS != realDataAS { - t.Fatalf("dataAS believes %v bytes are mapped; %v bytes are actually mapped", mm.dataAS, realDataAS) - } - - mm.MRemap(ctx, addr+2*hostarch.PageSize, hostarch.PageSize, 2*hostarch.PageSize, MRemapOpts{ - Move: MRemapMayMove, - }) - realDataAS = mm.realDataAS() - if mm.dataAS != realDataAS { - t.Fatalf("dataAS believes %v bytes are mapped; %v bytes are actually mapped", mm.dataAS, realDataAS) - } -} - -func TestBrkDataLimitUpdates(t *testing.T) { - limitSet := limits.NewLimitSet() - limitSet.Set(limits.Data, limits.Limit{}, true /* privileged */) // zero RLIMIT_DATA - - ctx := contexttest.WithLimitSet(contexttest.Context(t), limitSet) - mm := testMemoryManager(ctx) - defer mm.DecUsers(ctx) - - // Try to extend the brk by one page and expect doing so to fail. - oldBrk, _ := mm.Brk(ctx, 0) - if newBrk, _ := mm.Brk(ctx, oldBrk+hostarch.PageSize); newBrk != oldBrk { - t.Errorf("brk() increased data segment above RLIMIT_DATA (old brk = %#x, new brk = %#x", oldBrk, newBrk) - } -} - -// TestIOAfterUnmap ensures that IO fails after unmap. -func TestIOAfterUnmap(t *testing.T) { - ctx := contexttest.Context(t) - mm := testMemoryManager(ctx) - defer mm.DecUsers(ctx) - - addr, err := mm.MMap(ctx, memmap.MMapOpts{ - Length: hostarch.PageSize, - Private: true, - Perms: hostarch.Read, - MaxPerms: hostarch.AnyAccess, - }) - if err != nil { - t.Fatalf("MMap got err %v want nil", err) - } - - // IO works before munmap. - b := make([]byte, 1) - n, err := mm.CopyIn(ctx, addr, b, usermem.IOOpts{}) - if err != nil { - t.Errorf("CopyIn got err %v want nil", err) - } - if n != 1 { - t.Errorf("CopyIn got %d want 1", n) - } - - err = mm.MUnmap(ctx, addr, hostarch.PageSize) - if err != nil { - t.Fatalf("MUnmap got err %v want nil", err) - } - - n, err = mm.CopyIn(ctx, addr, b, usermem.IOOpts{}) - if !linuxerr.Equals(linuxerr.EFAULT, err) { - t.Errorf("CopyIn got err %v want EFAULT", err) - } - if n != 0 { - t.Errorf("CopyIn got %d want 0", n) - } -} - -// TestIOAfterMProtect tests IO interaction with mprotect permissions. -func TestIOAfterMProtect(t *testing.T) { - ctx := contexttest.Context(t) - mm := testMemoryManager(ctx) - defer mm.DecUsers(ctx) - - addr, err := mm.MMap(ctx, memmap.MMapOpts{ - Length: hostarch.PageSize, - Private: true, - Perms: hostarch.ReadWrite, - MaxPerms: hostarch.AnyAccess, - }) - if err != nil { - t.Fatalf("MMap got err %v want nil", err) - } - - // Writing works before mprotect. - b := make([]byte, 1) - n, err := mm.CopyOut(ctx, addr, b, usermem.IOOpts{}) - if err != nil { - t.Errorf("CopyOut got err %v want nil", err) - } - if n != 1 { - t.Errorf("CopyOut got %d want 1", n) - } - - err = mm.MProtect(addr, hostarch.PageSize, hostarch.Read, false) - if err != nil { - t.Errorf("MProtect got err %v want nil", err) - } - - // Without IgnorePermissions, CopyOut should no longer succeed. - n, err = mm.CopyOut(ctx, addr, b, usermem.IOOpts{}) - if !linuxerr.Equals(linuxerr.EFAULT, err) { - t.Errorf("CopyOut got err %v want EFAULT", err) - } - if n != 0 { - t.Errorf("CopyOut got %d want 0", n) - } - - // With IgnorePermissions, CopyOut should succeed despite mprotect. - n, err = mm.CopyOut(ctx, addr, b, usermem.IOOpts{ - IgnorePermissions: true, - }) - if err != nil { - t.Errorf("CopyOut got err %v want nil", err) - } - if n != 1 { - t.Errorf("CopyOut got %d want 1", n) - } -} - -// TestAIOPrepareAfterDestroy tests that AIOContext should not be able to be -// prepared after destruction. -func TestAIOPrepareAfterDestroy(t *testing.T) { - ctx := contexttest.Context(t) - mm := testMemoryManager(ctx) - defer mm.DecUsers(ctx) - - id, err := mm.NewAIOContext(ctx, 1) - if err != nil { - t.Fatalf("mm.NewAIOContext got err %v want nil", err) - } - aioCtx, ok := mm.LookupAIOContext(ctx, id) - if !ok { - t.Fatalf("AIOContext not found") - } - mm.DestroyAIOContext(ctx, id) - - // Prepare should fail because aioCtx should be destroyed. - if err := aioCtx.Prepare(); !linuxerr.Equals(linuxerr.EINVAL, err) { - t.Errorf("aioCtx.Prepare got err %v want nil", err) - } else if err == nil { - aioCtx.CancelPendingRequest() - } -} - -// TestAIOLookupAfterDestroy tests that AIOContext should not be able to be -// looked up after memory manager is destroyed. -func TestAIOLookupAfterDestroy(t *testing.T) { - ctx := contexttest.Context(t) - mm := testMemoryManager(ctx) - - id, err := mm.NewAIOContext(ctx, 1) - if err != nil { - mm.DecUsers(ctx) - t.Fatalf("mm.NewAIOContext got err %v want nil", err) - } - mm.DecUsers(ctx) // This destroys the AIOContext manager. - - if _, ok := mm.LookupAIOContext(ctx, id); ok { - t.Errorf("AIOContext found even after AIOContext manager is destroyed") - } -} diff --git a/pkg/sentry/mm/pma_set.go b/pkg/sentry/mm/pma_set.go new file mode 100644 index 000000000..14a7c83b9 --- /dev/null +++ b/pkg/sentry/mm/pma_set.go @@ -0,0 +1,1647 @@ +package mm + +import ( + __generics_imported0 "gvisor.dev/gvisor/pkg/hostarch" +) + +import ( + "bytes" + "fmt" +) + +// trackGaps is an optional parameter. +// +// If trackGaps is 1, the Set will track maximum gap size recursively, +// enabling the GapIterator.{Prev,Next}LargeEnoughGap functions. In this +// case, Key must be an unsigned integer. +// +// trackGaps must be 0 or 1. +const pmatrackGaps = 0 + +var _ = uint8(pmatrackGaps << 7) // Will fail if not zero or one. + +// dynamicGap is a type that disappears if trackGaps is 0. +type pmadynamicGap [pmatrackGaps]__generics_imported0.Addr + +// Get returns the value of the gap. +// +// Precondition: trackGaps must be non-zero. +func (d *pmadynamicGap) Get() __generics_imported0.Addr { + return d[:][0] +} + +// Set sets the value of the gap. +// +// Precondition: trackGaps must be non-zero. +func (d *pmadynamicGap) Set(v __generics_imported0.Addr) { + d[:][0] = v +} + +const ( + // minDegree is the minimum degree of an internal node in a Set B-tree. + // + // - Any non-root node has at least minDegree-1 segments. + // + // - Any non-root internal (non-leaf) node has at least minDegree children. + // + // - The root node may have fewer than minDegree-1 segments, but it may + // only have 0 segments if the tree is empty. + // + // Our implementation requires minDegree >= 3. Higher values of minDegree + // usually improve performance, but increase memory usage for small sets. + pmaminDegree = 8 + + pmamaxDegree = 2 * pmaminDegree +) + +// A Set is a mapping of segments with non-overlapping Range keys. The zero +// value for a Set is an empty set. Set values are not safely movable nor +// copyable. Set is thread-compatible. +// +// +stateify savable +type pmaSet struct { + root pmanode `state:".(*pmaSegmentDataSlices)"` +} + +// IsEmpty returns true if the set contains no segments. +func (s *pmaSet) IsEmpty() bool { + return s.root.nrSegments == 0 +} + +// IsEmptyRange returns true iff no segments in the set overlap the given +// range. This is semantically equivalent to s.SpanRange(r) == 0, but may be +// more efficient. +func (s *pmaSet) IsEmptyRange(r __generics_imported0.AddrRange) bool { + switch { + case r.Length() < 0: + panic(fmt.Sprintf("invalid range %v", r)) + case r.Length() == 0: + return true + } + _, gap := s.Find(r.Start) + if !gap.Ok() { + return false + } + return r.End <= gap.End() +} + +// Span returns the total size of all segments in the set. +func (s *pmaSet) Span() __generics_imported0.Addr { + var sz __generics_imported0.Addr + for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + sz += seg.Range().Length() + } + return sz +} + +// SpanRange returns the total size of the intersection of segments in the set +// with the given range. +func (s *pmaSet) SpanRange(r __generics_imported0.AddrRange) __generics_imported0.Addr { + switch { + case r.Length() < 0: + panic(fmt.Sprintf("invalid range %v", r)) + case r.Length() == 0: + return 0 + } + var sz __generics_imported0.Addr + for seg := s.LowerBoundSegment(r.Start); seg.Ok() && seg.Start() < r.End; seg = seg.NextSegment() { + sz += seg.Range().Intersect(r).Length() + } + return sz +} + +// FirstSegment returns the first segment in the set. If the set is empty, +// FirstSegment returns a terminal iterator. +func (s *pmaSet) FirstSegment() pmaIterator { + if s.root.nrSegments == 0 { + return pmaIterator{} + } + return s.root.firstSegment() +} + +// LastSegment returns the last segment in the set. If the set is empty, +// LastSegment returns a terminal iterator. +func (s *pmaSet) LastSegment() pmaIterator { + if s.root.nrSegments == 0 { + return pmaIterator{} + } + return s.root.lastSegment() +} + +// FirstGap returns the first gap in the set. +func (s *pmaSet) FirstGap() pmaGapIterator { + n := &s.root + for n.hasChildren { + n = n.children[0] + } + return pmaGapIterator{n, 0} +} + +// LastGap returns the last gap in the set. +func (s *pmaSet) LastGap() pmaGapIterator { + n := &s.root + for n.hasChildren { + n = n.children[n.nrSegments] + } + return pmaGapIterator{n, n.nrSegments} +} + +// Find returns the segment or gap whose range contains the given key. If a +// segment is found, the returned Iterator is non-terminal and the +// returned GapIterator is terminal. Otherwise, the returned Iterator is +// terminal and the returned GapIterator is non-terminal. +func (s *pmaSet) Find(key __generics_imported0.Addr) (pmaIterator, pmaGapIterator) { + n := &s.root + for { + + lower := 0 + upper := n.nrSegments + for lower < upper { + i := lower + (upper-lower)/2 + if r := n.keys[i]; key < r.End { + if key >= r.Start { + return pmaIterator{n, i}, pmaGapIterator{} + } + upper = i + } else { + lower = i + 1 + } + } + i := lower + if !n.hasChildren { + return pmaIterator{}, pmaGapIterator{n, i} + } + n = n.children[i] + } +} + +// FindSegment returns the segment whose range contains the given key. If no +// such segment exists, FindSegment returns a terminal iterator. +func (s *pmaSet) FindSegment(key __generics_imported0.Addr) pmaIterator { + seg, _ := s.Find(key) + return seg +} + +// LowerBoundSegment returns the segment with the lowest range that contains a +// key greater than or equal to min. If no such segment exists, +// LowerBoundSegment returns a terminal iterator. +func (s *pmaSet) LowerBoundSegment(min __generics_imported0.Addr) pmaIterator { + seg, gap := s.Find(min) + if seg.Ok() { + return seg + } + return gap.NextSegment() +} + +// UpperBoundSegment returns the segment with the highest range that contains a +// key less than or equal to max. If no such segment exists, UpperBoundSegment +// returns a terminal iterator. +func (s *pmaSet) UpperBoundSegment(max __generics_imported0.Addr) pmaIterator { + seg, gap := s.Find(max) + if seg.Ok() { + return seg + } + return gap.PrevSegment() +} + +// FindGap returns the gap containing the given key. If no such gap exists +// (i.e. the set contains a segment containing that key), FindGap returns a +// terminal iterator. +func (s *pmaSet) FindGap(key __generics_imported0.Addr) pmaGapIterator { + _, gap := s.Find(key) + return gap +} + +// LowerBoundGap returns the gap with the lowest range that is greater than or +// equal to min. +func (s *pmaSet) LowerBoundGap(min __generics_imported0.Addr) pmaGapIterator { + seg, gap := s.Find(min) + if gap.Ok() { + return gap + } + return seg.NextGap() +} + +// UpperBoundGap returns the gap with the highest range that is less than or +// equal to max. +func (s *pmaSet) UpperBoundGap(max __generics_imported0.Addr) pmaGapIterator { + seg, gap := s.Find(max) + if gap.Ok() { + return gap + } + return seg.PrevGap() +} + +// Add inserts the given segment into the set and returns true. If the new +// segment can be merged with adjacent segments, Add will do so. If the new +// segment would overlap an existing segment, Add returns false. If Add +// succeeds, all existing iterators are invalidated. +func (s *pmaSet) Add(r __generics_imported0.AddrRange, val pma) bool { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + gap := s.FindGap(r.Start) + if !gap.Ok() { + return false + } + if r.End > gap.End() { + return false + } + s.Insert(gap, r, val) + return true +} + +// AddWithoutMerging inserts the given segment into the set and returns true. +// If it would overlap an existing segment, AddWithoutMerging does nothing and +// returns false. If AddWithoutMerging succeeds, all existing iterators are +// invalidated. +func (s *pmaSet) AddWithoutMerging(r __generics_imported0.AddrRange, val pma) bool { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + gap := s.FindGap(r.Start) + if !gap.Ok() { + return false + } + if r.End > gap.End() { + return false + } + s.InsertWithoutMergingUnchecked(gap, r, val) + return true +} + +// Insert inserts the given segment into the given gap. If the new segment can +// be merged with adjacent segments, Insert will do so. Insert returns an +// iterator to the segment containing the inserted value (which may have been +// merged with other values). All existing iterators (including gap, but not +// including the returned iterator) are invalidated. +// +// If the gap cannot accommodate the segment, or if r is invalid, Insert panics. +// +// Insert is semantically equivalent to a InsertWithoutMerging followed by a +// Merge, but may be more efficient. Note that there is no unchecked variant of +// Insert since Insert must retrieve and inspect gap's predecessor and +// successor segments regardless. +func (s *pmaSet) Insert(gap pmaGapIterator, r __generics_imported0.AddrRange, val pma) pmaIterator { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + prev, next := gap.PrevSegment(), gap.NextSegment() + if prev.Ok() && prev.End() > r.Start { + panic(fmt.Sprintf("new segment %v overlaps predecessor %v", r, prev.Range())) + } + if next.Ok() && next.Start() < r.End { + panic(fmt.Sprintf("new segment %v overlaps successor %v", r, next.Range())) + } + if prev.Ok() && prev.End() == r.Start { + if mval, ok := (pmaSetFunctions{}).Merge(prev.Range(), prev.Value(), r, val); ok { + shrinkMaxGap := pmatrackGaps != 0 && gap.Range().Length() == gap.node.maxGap.Get() + prev.SetEndUnchecked(r.End) + prev.SetValue(mval) + if shrinkMaxGap { + gap.node.updateMaxGapLeaf() + } + if next.Ok() && next.Start() == r.End { + val = mval + if mval, ok := (pmaSetFunctions{}).Merge(prev.Range(), val, next.Range(), next.Value()); ok { + prev.SetEndUnchecked(next.End()) + prev.SetValue(mval) + return s.Remove(next).PrevSegment() + } + } + return prev + } + } + if next.Ok() && next.Start() == r.End { + if mval, ok := (pmaSetFunctions{}).Merge(r, val, next.Range(), next.Value()); ok { + shrinkMaxGap := pmatrackGaps != 0 && gap.Range().Length() == gap.node.maxGap.Get() + next.SetStartUnchecked(r.Start) + next.SetValue(mval) + if shrinkMaxGap { + gap.node.updateMaxGapLeaf() + } + return next + } + } + + return s.InsertWithoutMergingUnchecked(gap, r, val) +} + +// InsertWithoutMerging inserts the given segment into the given gap and +// returns an iterator to the inserted segment. All existing iterators +// (including gap, but not including the returned iterator) are invalidated. +// +// If the gap cannot accommodate the segment, or if r is invalid, +// InsertWithoutMerging panics. +func (s *pmaSet) InsertWithoutMerging(gap pmaGapIterator, r __generics_imported0.AddrRange, val pma) pmaIterator { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + if gr := gap.Range(); !gr.IsSupersetOf(r) { + panic(fmt.Sprintf("cannot insert segment range %v into gap range %v", r, gr)) + } + return s.InsertWithoutMergingUnchecked(gap, r, val) +} + +// InsertWithoutMergingUnchecked inserts the given segment into the given gap +// and returns an iterator to the inserted segment. All existing iterators +// (including gap, but not including the returned iterator) are invalidated. +// +// Preconditions: +// * r.Start >= gap.Start(). +// * r.End <= gap.End(). +func (s *pmaSet) InsertWithoutMergingUnchecked(gap pmaGapIterator, r __generics_imported0.AddrRange, val pma) pmaIterator { + gap = gap.node.rebalanceBeforeInsert(gap) + splitMaxGap := pmatrackGaps != 0 && (gap.node.nrSegments == 0 || gap.Range().Length() == gap.node.maxGap.Get()) + copy(gap.node.keys[gap.index+1:], gap.node.keys[gap.index:gap.node.nrSegments]) + copy(gap.node.values[gap.index+1:], gap.node.values[gap.index:gap.node.nrSegments]) + gap.node.keys[gap.index] = r + gap.node.values[gap.index] = val + gap.node.nrSegments++ + if splitMaxGap { + gap.node.updateMaxGapLeaf() + } + return pmaIterator{gap.node, gap.index} +} + +// Remove removes the given segment and returns an iterator to the vacated gap. +// All existing iterators (including seg, but not including the returned +// iterator) are invalidated. +func (s *pmaSet) Remove(seg pmaIterator) pmaGapIterator { + + if seg.node.hasChildren { + + victim := seg.PrevSegment() + + seg.SetRangeUnchecked(victim.Range()) + seg.SetValue(victim.Value()) + + nextAdjacentNode := seg.NextSegment().node + if pmatrackGaps != 0 { + nextAdjacentNode.updateMaxGapLeaf() + } + return s.Remove(victim).NextGap() + } + copy(seg.node.keys[seg.index:], seg.node.keys[seg.index+1:seg.node.nrSegments]) + copy(seg.node.values[seg.index:], seg.node.values[seg.index+1:seg.node.nrSegments]) + pmaSetFunctions{}.ClearValue(&seg.node.values[seg.node.nrSegments-1]) + seg.node.nrSegments-- + if pmatrackGaps != 0 { + seg.node.updateMaxGapLeaf() + } + return seg.node.rebalanceAfterRemove(pmaGapIterator{seg.node, seg.index}) +} + +// RemoveAll removes all segments from the set. All existing iterators are +// invalidated. +func (s *pmaSet) RemoveAll() { + s.root = pmanode{} +} + +// RemoveRange removes all segments in the given range. An iterator to the +// newly formed gap is returned, and all existing iterators are invalidated. +func (s *pmaSet) RemoveRange(r __generics_imported0.AddrRange) pmaGapIterator { + seg, gap := s.Find(r.Start) + if seg.Ok() { + seg = s.Isolate(seg, r) + gap = s.Remove(seg) + } + for seg = gap.NextSegment(); seg.Ok() && seg.Start() < r.End; seg = gap.NextSegment() { + seg = s.Isolate(seg, r) + gap = s.Remove(seg) + } + return gap +} + +// Merge attempts to merge two neighboring segments. If successful, Merge +// returns an iterator to the merged segment, and all existing iterators are +// invalidated. Otherwise, Merge returns a terminal iterator. +// +// If first is not the predecessor of second, Merge panics. +func (s *pmaSet) Merge(first, second pmaIterator) pmaIterator { + if first.NextSegment() != second { + panic(fmt.Sprintf("attempt to merge non-neighboring segments %v, %v", first.Range(), second.Range())) + } + return s.MergeUnchecked(first, second) +} + +// MergeUnchecked attempts to merge two neighboring segments. If successful, +// MergeUnchecked returns an iterator to the merged segment, and all existing +// iterators are invalidated. Otherwise, MergeUnchecked returns a terminal +// iterator. +// +// Precondition: first is the predecessor of second: first.NextSegment() == +// second, first == second.PrevSegment(). +func (s *pmaSet) MergeUnchecked(first, second pmaIterator) pmaIterator { + if first.End() == second.Start() { + if mval, ok := (pmaSetFunctions{}).Merge(first.Range(), first.Value(), second.Range(), second.Value()); ok { + + first.SetEndUnchecked(second.End()) + first.SetValue(mval) + + return s.Remove(second).PrevSegment() + } + } + return pmaIterator{} +} + +// MergeAll attempts to merge all adjacent segments in the set. All existing +// iterators are invalidated. +func (s *pmaSet) MergeAll() { + seg := s.FirstSegment() + if !seg.Ok() { + return + } + next := seg.NextSegment() + for next.Ok() { + if mseg := s.MergeUnchecked(seg, next); mseg.Ok() { + seg, next = mseg, mseg.NextSegment() + } else { + seg, next = next, next.NextSegment() + } + } +} + +// MergeRange attempts to merge all adjacent segments that contain a key in the +// specific range. All existing iterators are invalidated. +func (s *pmaSet) MergeRange(r __generics_imported0.AddrRange) { + seg := s.LowerBoundSegment(r.Start) + if !seg.Ok() { + return + } + next := seg.NextSegment() + for next.Ok() && next.Range().Start < r.End { + if mseg := s.MergeUnchecked(seg, next); mseg.Ok() { + seg, next = mseg, mseg.NextSegment() + } else { + seg, next = next, next.NextSegment() + } + } +} + +// MergeAdjacent attempts to merge the segment containing r.Start with its +// predecessor, and the segment containing r.End-1 with its successor. +func (s *pmaSet) MergeAdjacent(r __generics_imported0.AddrRange) { + first := s.FindSegment(r.Start) + if first.Ok() { + if prev := first.PrevSegment(); prev.Ok() { + s.Merge(prev, first) + } + } + last := s.FindSegment(r.End - 1) + if last.Ok() { + if next := last.NextSegment(); next.Ok() { + s.Merge(last, next) + } + } +} + +// Split splits the given segment at the given key and returns iterators to the +// two resulting segments. All existing iterators (including seg, but not +// including the returned iterators) are invalidated. +// +// If the segment cannot be split at split (because split is at the start or +// end of the segment's range, so splitting would produce a segment with zero +// length, or because split falls outside the segment's range altogether), +// Split panics. +func (s *pmaSet) Split(seg pmaIterator, split __generics_imported0.Addr) (pmaIterator, pmaIterator) { + if !seg.Range().CanSplitAt(split) { + panic(fmt.Sprintf("can't split %v at %v", seg.Range(), split)) + } + return s.SplitUnchecked(seg, split) +} + +// SplitUnchecked splits the given segment at the given key and returns +// iterators to the two resulting segments. All existing iterators (including +// seg, but not including the returned iterators) are invalidated. +// +// Preconditions: seg.Start() < key < seg.End(). +func (s *pmaSet) SplitUnchecked(seg pmaIterator, split __generics_imported0.Addr) (pmaIterator, pmaIterator) { + val1, val2 := (pmaSetFunctions{}).Split(seg.Range(), seg.Value(), split) + end2 := seg.End() + seg.SetEndUnchecked(split) + seg.SetValue(val1) + seg2 := s.InsertWithoutMergingUnchecked(seg.NextGap(), __generics_imported0.AddrRange{split, end2}, val2) + + return seg2.PrevSegment(), seg2 +} + +// SplitAt splits the segment straddling split, if one exists. SplitAt returns +// true if a segment was split and false otherwise. If SplitAt splits a +// segment, all existing iterators are invalidated. +func (s *pmaSet) SplitAt(split __generics_imported0.Addr) bool { + if seg := s.FindSegment(split); seg.Ok() && seg.Range().CanSplitAt(split) { + s.SplitUnchecked(seg, split) + return true + } + return false +} + +// Isolate ensures that the given segment's range does not escape r by +// splitting at r.Start and r.End if necessary, and returns an updated iterator +// to the bounded segment. All existing iterators (including seg, but not +// including the returned iterators) are invalidated. +func (s *pmaSet) Isolate(seg pmaIterator, r __generics_imported0.AddrRange) pmaIterator { + if seg.Range().CanSplitAt(r.Start) { + _, seg = s.SplitUnchecked(seg, r.Start) + } + if seg.Range().CanSplitAt(r.End) { + seg, _ = s.SplitUnchecked(seg, r.End) + } + return seg +} + +// ApplyContiguous applies a function to a contiguous range of segments, +// splitting if necessary. The function is applied until the first gap is +// encountered, at which point the gap is returned. If the function is applied +// across the entire range, a terminal gap is returned. All existing iterators +// are invalidated. +// +// N.B. The Iterator must not be invalidated by the function. +func (s *pmaSet) ApplyContiguous(r __generics_imported0.AddrRange, fn func(seg pmaIterator)) pmaGapIterator { + seg, gap := s.Find(r.Start) + if !seg.Ok() { + return gap + } + for { + seg = s.Isolate(seg, r) + fn(seg) + if seg.End() >= r.End { + return pmaGapIterator{} + } + gap = seg.NextGap() + if !gap.IsEmpty() { + return gap + } + seg = gap.NextSegment() + if !seg.Ok() { + + return pmaGapIterator{} + } + } +} + +// +stateify savable +type pmanode struct { + // An internal binary tree node looks like: + // + // K + // / \ + // Cl Cr + // + // where all keys in the subtree rooted by Cl (the left subtree) are less + // than K (the key of the parent node), and all keys in the subtree rooted + // by Cr (the right subtree) are greater than K. + // + // An internal B-tree node's indexes work out to look like: + // + // K0 K1 K2 ... Kn-1 + // / \/ \/ \ ... / \ + // C0 C1 C2 C3 ... Cn-1 Cn + // + // where n is nrSegments. + nrSegments int + + // parent is a pointer to this node's parent. If this node is root, parent + // is nil. + parent *pmanode + + // parentIndex is the index of this node in parent.children. + parentIndex int + + // Flag for internal nodes that is technically redundant with "children[0] + // != nil", but is stored in the first cache line. "hasChildren" rather + // than "isLeaf" because false must be the correct value for an empty root. + hasChildren bool + + // The longest gap within this node. If the node is a leaf, it's simply the + // maximum gap among all the (nrSegments+1) gaps formed by its nrSegments keys + // including the 0th and nrSegments-th gap possibly shared with its upper-level + // nodes; if it's a non-leaf node, it's the max of all children's maxGap. + maxGap pmadynamicGap + + // Nodes store keys and values in separate arrays to maximize locality in + // the common case (scanning keys for lookup). + keys [pmamaxDegree - 1]__generics_imported0.AddrRange + values [pmamaxDegree - 1]pma + children [pmamaxDegree]*pmanode +} + +// firstSegment returns the first segment in the subtree rooted by n. +// +// Preconditions: n.nrSegments != 0. +func (n *pmanode) firstSegment() pmaIterator { + for n.hasChildren { + n = n.children[0] + } + return pmaIterator{n, 0} +} + +// lastSegment returns the last segment in the subtree rooted by n. +// +// Preconditions: n.nrSegments != 0. +func (n *pmanode) lastSegment() pmaIterator { + for n.hasChildren { + n = n.children[n.nrSegments] + } + return pmaIterator{n, n.nrSegments - 1} +} + +func (n *pmanode) prevSibling() *pmanode { + if n.parent == nil || n.parentIndex == 0 { + return nil + } + return n.parent.children[n.parentIndex-1] +} + +func (n *pmanode) nextSibling() *pmanode { + if n.parent == nil || n.parentIndex == n.parent.nrSegments { + return nil + } + return n.parent.children[n.parentIndex+1] +} + +// rebalanceBeforeInsert splits n and its ancestors if they are full, as +// required for insertion, and returns an updated iterator to the position +// represented by gap. +func (n *pmanode) rebalanceBeforeInsert(gap pmaGapIterator) pmaGapIterator { + if n.nrSegments < pmamaxDegree-1 { + return gap + } + if n.parent != nil { + gap = n.parent.rebalanceBeforeInsert(gap) + } + if n.parent == nil { + + left := &pmanode{ + nrSegments: pmaminDegree - 1, + parent: n, + parentIndex: 0, + hasChildren: n.hasChildren, + } + right := &pmanode{ + nrSegments: pmaminDegree - 1, + parent: n, + parentIndex: 1, + hasChildren: n.hasChildren, + } + copy(left.keys[:pmaminDegree-1], n.keys[:pmaminDegree-1]) + copy(left.values[:pmaminDegree-1], n.values[:pmaminDegree-1]) + copy(right.keys[:pmaminDegree-1], n.keys[pmaminDegree:]) + copy(right.values[:pmaminDegree-1], n.values[pmaminDegree:]) + n.keys[0], n.values[0] = n.keys[pmaminDegree-1], n.values[pmaminDegree-1] + pmazeroValueSlice(n.values[1:]) + if n.hasChildren { + copy(left.children[:pmaminDegree], n.children[:pmaminDegree]) + copy(right.children[:pmaminDegree], n.children[pmaminDegree:]) + pmazeroNodeSlice(n.children[2:]) + for i := 0; i < pmaminDegree; i++ { + left.children[i].parent = left + left.children[i].parentIndex = i + right.children[i].parent = right + right.children[i].parentIndex = i + } + } + n.nrSegments = 1 + n.hasChildren = true + n.children[0] = left + n.children[1] = right + + if pmatrackGaps != 0 { + left.updateMaxGapLocal() + right.updateMaxGapLocal() + } + if gap.node != n { + return gap + } + if gap.index < pmaminDegree { + return pmaGapIterator{left, gap.index} + } + return pmaGapIterator{right, gap.index - pmaminDegree} + } + + copy(n.parent.keys[n.parentIndex+1:], n.parent.keys[n.parentIndex:n.parent.nrSegments]) + copy(n.parent.values[n.parentIndex+1:], n.parent.values[n.parentIndex:n.parent.nrSegments]) + n.parent.keys[n.parentIndex], n.parent.values[n.parentIndex] = n.keys[pmaminDegree-1], n.values[pmaminDegree-1] + copy(n.parent.children[n.parentIndex+2:], n.parent.children[n.parentIndex+1:n.parent.nrSegments+1]) + for i := n.parentIndex + 2; i < n.parent.nrSegments+2; i++ { + n.parent.children[i].parentIndex = i + } + sibling := &pmanode{ + nrSegments: pmaminDegree - 1, + parent: n.parent, + parentIndex: n.parentIndex + 1, + hasChildren: n.hasChildren, + } + n.parent.children[n.parentIndex+1] = sibling + n.parent.nrSegments++ + copy(sibling.keys[:pmaminDegree-1], n.keys[pmaminDegree:]) + copy(sibling.values[:pmaminDegree-1], n.values[pmaminDegree:]) + pmazeroValueSlice(n.values[pmaminDegree-1:]) + if n.hasChildren { + copy(sibling.children[:pmaminDegree], n.children[pmaminDegree:]) + pmazeroNodeSlice(n.children[pmaminDegree:]) + for i := 0; i < pmaminDegree; i++ { + sibling.children[i].parent = sibling + sibling.children[i].parentIndex = i + } + } + n.nrSegments = pmaminDegree - 1 + + if pmatrackGaps != 0 { + n.updateMaxGapLocal() + sibling.updateMaxGapLocal() + } + + if gap.node != n { + return gap + } + if gap.index < pmaminDegree { + return gap + } + return pmaGapIterator{sibling, gap.index - pmaminDegree} +} + +// rebalanceAfterRemove "unsplits" n and its ancestors if they are deficient +// (contain fewer segments than required by B-tree invariants), as required for +// removal, and returns an updated iterator to the position represented by gap. +// +// Precondition: n is the only node in the tree that may currently violate a +// B-tree invariant. +func (n *pmanode) rebalanceAfterRemove(gap pmaGapIterator) pmaGapIterator { + for { + if n.nrSegments >= pmaminDegree-1 { + return gap + } + if n.parent == nil { + + return gap + } + + if sibling := n.prevSibling(); sibling != nil && sibling.nrSegments >= pmaminDegree { + copy(n.keys[1:], n.keys[:n.nrSegments]) + copy(n.values[1:], n.values[:n.nrSegments]) + n.keys[0] = n.parent.keys[n.parentIndex-1] + n.values[0] = n.parent.values[n.parentIndex-1] + n.parent.keys[n.parentIndex-1] = sibling.keys[sibling.nrSegments-1] + n.parent.values[n.parentIndex-1] = sibling.values[sibling.nrSegments-1] + pmaSetFunctions{}.ClearValue(&sibling.values[sibling.nrSegments-1]) + if n.hasChildren { + copy(n.children[1:], n.children[:n.nrSegments+1]) + n.children[0] = sibling.children[sibling.nrSegments] + sibling.children[sibling.nrSegments] = nil + n.children[0].parent = n + n.children[0].parentIndex = 0 + for i := 1; i < n.nrSegments+2; i++ { + n.children[i].parentIndex = i + } + } + n.nrSegments++ + sibling.nrSegments-- + + if pmatrackGaps != 0 { + n.updateMaxGapLocal() + sibling.updateMaxGapLocal() + } + if gap.node == sibling && gap.index == sibling.nrSegments { + return pmaGapIterator{n, 0} + } + if gap.node == n { + return pmaGapIterator{n, gap.index + 1} + } + return gap + } + if sibling := n.nextSibling(); sibling != nil && sibling.nrSegments >= pmaminDegree { + n.keys[n.nrSegments] = n.parent.keys[n.parentIndex] + n.values[n.nrSegments] = n.parent.values[n.parentIndex] + n.parent.keys[n.parentIndex] = sibling.keys[0] + n.parent.values[n.parentIndex] = sibling.values[0] + copy(sibling.keys[:sibling.nrSegments-1], sibling.keys[1:]) + copy(sibling.values[:sibling.nrSegments-1], sibling.values[1:]) + pmaSetFunctions{}.ClearValue(&sibling.values[sibling.nrSegments-1]) + if n.hasChildren { + n.children[n.nrSegments+1] = sibling.children[0] + copy(sibling.children[:sibling.nrSegments], sibling.children[1:]) + sibling.children[sibling.nrSegments] = nil + n.children[n.nrSegments+1].parent = n + n.children[n.nrSegments+1].parentIndex = n.nrSegments + 1 + for i := 0; i < sibling.nrSegments; i++ { + sibling.children[i].parentIndex = i + } + } + n.nrSegments++ + sibling.nrSegments-- + + if pmatrackGaps != 0 { + n.updateMaxGapLocal() + sibling.updateMaxGapLocal() + } + if gap.node == sibling { + if gap.index == 0 { + return pmaGapIterator{n, n.nrSegments} + } + return pmaGapIterator{sibling, gap.index - 1} + } + return gap + } + + p := n.parent + if p.nrSegments == 1 { + + left, right := p.children[0], p.children[1] + p.nrSegments = left.nrSegments + right.nrSegments + 1 + p.hasChildren = left.hasChildren + p.keys[left.nrSegments] = p.keys[0] + p.values[left.nrSegments] = p.values[0] + copy(p.keys[:left.nrSegments], left.keys[:left.nrSegments]) + copy(p.values[:left.nrSegments], left.values[:left.nrSegments]) + copy(p.keys[left.nrSegments+1:], right.keys[:right.nrSegments]) + copy(p.values[left.nrSegments+1:], right.values[:right.nrSegments]) + if left.hasChildren { + copy(p.children[:left.nrSegments+1], left.children[:left.nrSegments+1]) + copy(p.children[left.nrSegments+1:], right.children[:right.nrSegments+1]) + for i := 0; i < p.nrSegments+1; i++ { + p.children[i].parent = p + p.children[i].parentIndex = i + } + } else { + p.children[0] = nil + p.children[1] = nil + } + + if gap.node == left { + return pmaGapIterator{p, gap.index} + } + if gap.node == right { + return pmaGapIterator{p, gap.index + left.nrSegments + 1} + } + return gap + } + // Merge n and either sibling, along with the segment separating the + // two, into whichever of the two nodes comes first. This is the + // reverse of the non-root splitting case in + // node.rebalanceBeforeInsert. + var left, right *pmanode + if n.parentIndex > 0 { + left = n.prevSibling() + right = n + } else { + left = n + right = n.nextSibling() + } + + if gap.node == right { + gap = pmaGapIterator{left, gap.index + left.nrSegments + 1} + } + left.keys[left.nrSegments] = p.keys[left.parentIndex] + left.values[left.nrSegments] = p.values[left.parentIndex] + copy(left.keys[left.nrSegments+1:], right.keys[:right.nrSegments]) + copy(left.values[left.nrSegments+1:], right.values[:right.nrSegments]) + if left.hasChildren { + copy(left.children[left.nrSegments+1:], right.children[:right.nrSegments+1]) + for i := left.nrSegments + 1; i < left.nrSegments+right.nrSegments+2; i++ { + left.children[i].parent = left + left.children[i].parentIndex = i + } + } + left.nrSegments += right.nrSegments + 1 + copy(p.keys[left.parentIndex:], p.keys[left.parentIndex+1:p.nrSegments]) + copy(p.values[left.parentIndex:], p.values[left.parentIndex+1:p.nrSegments]) + pmaSetFunctions{}.ClearValue(&p.values[p.nrSegments-1]) + copy(p.children[left.parentIndex+1:], p.children[left.parentIndex+2:p.nrSegments+1]) + for i := 0; i < p.nrSegments; i++ { + p.children[i].parentIndex = i + } + p.children[p.nrSegments] = nil + p.nrSegments-- + + if pmatrackGaps != 0 { + left.updateMaxGapLocal() + } + + n = p + } +} + +// updateMaxGapLeaf updates maxGap bottom-up from the calling leaf until no +// necessary update. +// +// Preconditions: n must be a leaf node, trackGaps must be 1. +func (n *pmanode) updateMaxGapLeaf() { + if n.hasChildren { + panic(fmt.Sprintf("updateMaxGapLeaf should always be called on leaf node: %v", n)) + } + max := n.calculateMaxGapLeaf() + if max == n.maxGap.Get() { + + return + } + oldMax := n.maxGap.Get() + n.maxGap.Set(max) + if max > oldMax { + + for p := n.parent; p != nil; p = p.parent { + if p.maxGap.Get() >= max { + + break + } + + p.maxGap.Set(max) + } + return + } + + for p := n.parent; p != nil; p = p.parent { + if p.maxGap.Get() > oldMax { + + break + } + + parentNewMax := p.calculateMaxGapInternal() + if p.maxGap.Get() == parentNewMax { + + break + } + + p.maxGap.Set(parentNewMax) + } +} + +// updateMaxGapLocal updates maxGap of the calling node solely with no +// propagation to ancestor nodes. +// +// Precondition: trackGaps must be 1. +func (n *pmanode) updateMaxGapLocal() { + if !n.hasChildren { + + n.maxGap.Set(n.calculateMaxGapLeaf()) + } else { + + n.maxGap.Set(n.calculateMaxGapInternal()) + } +} + +// calculateMaxGapLeaf iterates the gaps within a leaf node and calculate the +// max. +// +// Preconditions: n must be a leaf node. +func (n *pmanode) calculateMaxGapLeaf() __generics_imported0.Addr { + max := pmaGapIterator{n, 0}.Range().Length() + for i := 1; i <= n.nrSegments; i++ { + if current := (pmaGapIterator{n, i}).Range().Length(); current > max { + max = current + } + } + return max +} + +// calculateMaxGapInternal iterates children's maxGap within an internal node n +// and calculate the max. +// +// Preconditions: n must be a non-leaf node. +func (n *pmanode) calculateMaxGapInternal() __generics_imported0.Addr { + max := n.children[0].maxGap.Get() + for i := 1; i <= n.nrSegments; i++ { + if current := n.children[i].maxGap.Get(); current > max { + max = current + } + } + return max +} + +// searchFirstLargeEnoughGap returns the first gap having at least minSize length +// in the subtree rooted by n. If not found, return a terminal gap iterator. +func (n *pmanode) searchFirstLargeEnoughGap(minSize __generics_imported0.Addr) pmaGapIterator { + if n.maxGap.Get() < minSize { + return pmaGapIterator{} + } + if n.hasChildren { + for i := 0; i <= n.nrSegments; i++ { + if largeEnoughGap := n.children[i].searchFirstLargeEnoughGap(minSize); largeEnoughGap.Ok() { + return largeEnoughGap + } + } + } else { + for i := 0; i <= n.nrSegments; i++ { + currentGap := pmaGapIterator{n, i} + if currentGap.Range().Length() >= minSize { + return currentGap + } + } + } + panic(fmt.Sprintf("invalid maxGap in %v", n)) +} + +// searchLastLargeEnoughGap returns the last gap having at least minSize length +// in the subtree rooted by n. If not found, return a terminal gap iterator. +func (n *pmanode) searchLastLargeEnoughGap(minSize __generics_imported0.Addr) pmaGapIterator { + if n.maxGap.Get() < minSize { + return pmaGapIterator{} + } + if n.hasChildren { + for i := n.nrSegments; i >= 0; i-- { + if largeEnoughGap := n.children[i].searchLastLargeEnoughGap(minSize); largeEnoughGap.Ok() { + return largeEnoughGap + } + } + } else { + for i := n.nrSegments; i >= 0; i-- { + currentGap := pmaGapIterator{n, i} + if currentGap.Range().Length() >= minSize { + return currentGap + } + } + } + panic(fmt.Sprintf("invalid maxGap in %v", n)) +} + +// A Iterator is conceptually one of: +// +// - A pointer to a segment in a set; or +// +// - A terminal iterator, which is a sentinel indicating that the end of +// iteration has been reached. +// +// Iterators are copyable values and are meaningfully equality-comparable. The +// zero value of Iterator is a terminal iterator. +// +// Unless otherwise specified, any mutation of a set invalidates all existing +// iterators into the set. +type pmaIterator struct { + // node is the node containing the iterated segment. If the iterator is + // terminal, node is nil. + node *pmanode + + // index is the index of the segment in node.keys/values. + index int +} + +// Ok returns true if the iterator is not terminal. All other methods are only +// valid for non-terminal iterators. +func (seg pmaIterator) Ok() bool { + return seg.node != nil +} + +// Range returns the iterated segment's range key. +func (seg pmaIterator) Range() __generics_imported0.AddrRange { + return seg.node.keys[seg.index] +} + +// Start is equivalent to Range().Start, but should be preferred if only the +// start of the range is needed. +func (seg pmaIterator) Start() __generics_imported0.Addr { + return seg.node.keys[seg.index].Start +} + +// End is equivalent to Range().End, but should be preferred if only the end of +// the range is needed. +func (seg pmaIterator) End() __generics_imported0.Addr { + return seg.node.keys[seg.index].End +} + +// SetRangeUnchecked mutates the iterated segment's range key. This operation +// does not invalidate any iterators. +// +// Preconditions: +// * r.Length() > 0. +// * The new range must not overlap an existing one: +// * If seg.NextSegment().Ok(), then r.end <= seg.NextSegment().Start(). +// * If seg.PrevSegment().Ok(), then r.start >= seg.PrevSegment().End(). +func (seg pmaIterator) SetRangeUnchecked(r __generics_imported0.AddrRange) { + seg.node.keys[seg.index] = r +} + +// SetRange mutates the iterated segment's range key. If the new range would +// cause the iterated segment to overlap another segment, or if the new range +// is invalid, SetRange panics. This operation does not invalidate any +// iterators. +func (seg pmaIterator) SetRange(r __generics_imported0.AddrRange) { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + if prev := seg.PrevSegment(); prev.Ok() && r.Start < prev.End() { + panic(fmt.Sprintf("new segment range %v overlaps segment range %v", r, prev.Range())) + } + if next := seg.NextSegment(); next.Ok() && r.End > next.Start() { + panic(fmt.Sprintf("new segment range %v overlaps segment range %v", r, next.Range())) + } + seg.SetRangeUnchecked(r) +} + +// SetStartUnchecked mutates the iterated segment's start. This operation does +// not invalidate any iterators. +// +// Preconditions: The new start must be valid: +// * start < seg.End() +// * If seg.PrevSegment().Ok(), then start >= seg.PrevSegment().End(). +func (seg pmaIterator) SetStartUnchecked(start __generics_imported0.Addr) { + seg.node.keys[seg.index].Start = start +} + +// SetStart mutates the iterated segment's start. If the new start value would +// cause the iterated segment to overlap another segment, or would result in an +// invalid range, SetStart panics. This operation does not invalidate any +// iterators. +func (seg pmaIterator) SetStart(start __generics_imported0.Addr) { + if start >= seg.End() { + panic(fmt.Sprintf("new start %v would invalidate segment range %v", start, seg.Range())) + } + if prev := seg.PrevSegment(); prev.Ok() && start < prev.End() { + panic(fmt.Sprintf("new start %v would cause segment range %v to overlap segment range %v", start, seg.Range(), prev.Range())) + } + seg.SetStartUnchecked(start) +} + +// SetEndUnchecked mutates the iterated segment's end. This operation does not +// invalidate any iterators. +// +// Preconditions: The new end must be valid: +// * end > seg.Start(). +// * If seg.NextSegment().Ok(), then end <= seg.NextSegment().Start(). +func (seg pmaIterator) SetEndUnchecked(end __generics_imported0.Addr) { + seg.node.keys[seg.index].End = end +} + +// SetEnd mutates the iterated segment's end. If the new end value would cause +// the iterated segment to overlap another segment, or would result in an +// invalid range, SetEnd panics. This operation does not invalidate any +// iterators. +func (seg pmaIterator) SetEnd(end __generics_imported0.Addr) { + if end <= seg.Start() { + panic(fmt.Sprintf("new end %v would invalidate segment range %v", end, seg.Range())) + } + if next := seg.NextSegment(); next.Ok() && end > next.Start() { + panic(fmt.Sprintf("new end %v would cause segment range %v to overlap segment range %v", end, seg.Range(), next.Range())) + } + seg.SetEndUnchecked(end) +} + +// Value returns a copy of the iterated segment's value. +func (seg pmaIterator) Value() pma { + return seg.node.values[seg.index] +} + +// ValuePtr returns a pointer to the iterated segment's value. The pointer is +// invalidated if the iterator is invalidated. This operation does not +// invalidate any iterators. +func (seg pmaIterator) ValuePtr() *pma { + return &seg.node.values[seg.index] +} + +// SetValue mutates the iterated segment's value. This operation does not +// invalidate any iterators. +func (seg pmaIterator) SetValue(val pma) { + seg.node.values[seg.index] = val +} + +// PrevSegment returns the iterated segment's predecessor. If there is no +// preceding segment, PrevSegment returns a terminal iterator. +func (seg pmaIterator) PrevSegment() pmaIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index].lastSegment() + } + if seg.index > 0 { + return pmaIterator{seg.node, seg.index - 1} + } + if seg.node.parent == nil { + return pmaIterator{} + } + return pmasegmentBeforePosition(seg.node.parent, seg.node.parentIndex) +} + +// NextSegment returns the iterated segment's successor. If there is no +// succeeding segment, NextSegment returns a terminal iterator. +func (seg pmaIterator) NextSegment() pmaIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index+1].firstSegment() + } + if seg.index < seg.node.nrSegments-1 { + return pmaIterator{seg.node, seg.index + 1} + } + if seg.node.parent == nil { + return pmaIterator{} + } + return pmasegmentAfterPosition(seg.node.parent, seg.node.parentIndex) +} + +// PrevGap returns the gap immediately before the iterated segment. +func (seg pmaIterator) PrevGap() pmaGapIterator { + if seg.node.hasChildren { + + return seg.node.children[seg.index].lastSegment().NextGap() + } + return pmaGapIterator{seg.node, seg.index} +} + +// NextGap returns the gap immediately after the iterated segment. +func (seg pmaIterator) NextGap() pmaGapIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index+1].firstSegment().PrevGap() + } + return pmaGapIterator{seg.node, seg.index + 1} +} + +// PrevNonEmpty returns the iterated segment's predecessor if it is adjacent, +// or the gap before the iterated segment otherwise. If seg.Start() == +// Functions.MinKey(), PrevNonEmpty will return two terminal iterators. +// Otherwise, exactly one of the iterators returned by PrevNonEmpty will be +// non-terminal. +func (seg pmaIterator) PrevNonEmpty() (pmaIterator, pmaGapIterator) { + gap := seg.PrevGap() + if gap.Range().Length() != 0 { + return pmaIterator{}, gap + } + return gap.PrevSegment(), pmaGapIterator{} +} + +// NextNonEmpty returns the iterated segment's successor if it is adjacent, or +// the gap after the iterated segment otherwise. If seg.End() == +// Functions.MaxKey(), NextNonEmpty will return two terminal iterators. +// Otherwise, exactly one of the iterators returned by NextNonEmpty will be +// non-terminal. +func (seg pmaIterator) NextNonEmpty() (pmaIterator, pmaGapIterator) { + gap := seg.NextGap() + if gap.Range().Length() != 0 { + return pmaIterator{}, gap + } + return gap.NextSegment(), pmaGapIterator{} +} + +// A GapIterator is conceptually one of: +// +// - A pointer to a position between two segments, before the first segment, or +// after the last segment in a set, called a *gap*; or +// +// - A terminal iterator, which is a sentinel indicating that the end of +// iteration has been reached. +// +// Note that the gap between two adjacent segments exists (iterators to it are +// non-terminal), but has a length of zero. GapIterator.IsEmpty returns true +// for such gaps. An empty set contains a single gap, spanning the entire range +// of the set's keys. +// +// GapIterators are copyable values and are meaningfully equality-comparable. +// The zero value of GapIterator is a terminal iterator. +// +// Unless otherwise specified, any mutation of a set invalidates all existing +// iterators into the set. +type pmaGapIterator struct { + // The representation of a GapIterator is identical to that of an Iterator, + // except that index corresponds to positions between segments in the same + // way as for node.children (see comment for node.nrSegments). + node *pmanode + index int +} + +// Ok returns true if the iterator is not terminal. All other methods are only +// valid for non-terminal iterators. +func (gap pmaGapIterator) Ok() bool { + return gap.node != nil +} + +// Range returns the range spanned by the iterated gap. +func (gap pmaGapIterator) Range() __generics_imported0.AddrRange { + return __generics_imported0.AddrRange{gap.Start(), gap.End()} +} + +// Start is equivalent to Range().Start, but should be preferred if only the +// start of the range is needed. +func (gap pmaGapIterator) Start() __generics_imported0.Addr { + if ps := gap.PrevSegment(); ps.Ok() { + return ps.End() + } + return pmaSetFunctions{}.MinKey() +} + +// End is equivalent to Range().End, but should be preferred if only the end of +// the range is needed. +func (gap pmaGapIterator) End() __generics_imported0.Addr { + if ns := gap.NextSegment(); ns.Ok() { + return ns.Start() + } + return pmaSetFunctions{}.MaxKey() +} + +// IsEmpty returns true if the iterated gap is empty (that is, the "gap" is +// between two adjacent segments.) +func (gap pmaGapIterator) IsEmpty() bool { + return gap.Range().Length() == 0 +} + +// PrevSegment returns the segment immediately before the iterated gap. If no +// such segment exists, PrevSegment returns a terminal iterator. +func (gap pmaGapIterator) PrevSegment() pmaIterator { + return pmasegmentBeforePosition(gap.node, gap.index) +} + +// NextSegment returns the segment immediately after the iterated gap. If no +// such segment exists, NextSegment returns a terminal iterator. +func (gap pmaGapIterator) NextSegment() pmaIterator { + return pmasegmentAfterPosition(gap.node, gap.index) +} + +// PrevGap returns the iterated gap's predecessor. If no such gap exists, +// PrevGap returns a terminal iterator. +func (gap pmaGapIterator) PrevGap() pmaGapIterator { + seg := gap.PrevSegment() + if !seg.Ok() { + return pmaGapIterator{} + } + return seg.PrevGap() +} + +// NextGap returns the iterated gap's successor. If no such gap exists, NextGap +// returns a terminal iterator. +func (gap pmaGapIterator) NextGap() pmaGapIterator { + seg := gap.NextSegment() + if !seg.Ok() { + return pmaGapIterator{} + } + return seg.NextGap() +} + +// NextLargeEnoughGap returns the iterated gap's first next gap with larger +// length than minSize. If not found, return a terminal gap iterator (does NOT +// include this gap itself). +// +// Precondition: trackGaps must be 1. +func (gap pmaGapIterator) NextLargeEnoughGap(minSize __generics_imported0.Addr) pmaGapIterator { + if pmatrackGaps != 1 { + panic("set is not tracking gaps") + } + if gap.node != nil && gap.node.hasChildren && gap.index == gap.node.nrSegments { + + gap.node = gap.NextSegment().node + gap.index = 0 + return gap.nextLargeEnoughGapHelper(minSize) + } + return gap.nextLargeEnoughGapHelper(minSize) +} + +// nextLargeEnoughGapHelper is the helper function used by NextLargeEnoughGap +// to do the real recursions. +// +// Preconditions: gap is NOT the trailing gap of a non-leaf node. +func (gap pmaGapIterator) nextLargeEnoughGapHelper(minSize __generics_imported0.Addr) pmaGapIterator { + + for gap.node != nil && + (gap.node.maxGap.Get() < minSize || (!gap.node.hasChildren && gap.index == gap.node.nrSegments)) { + gap.node, gap.index = gap.node.parent, gap.node.parentIndex + } + + if gap.node == nil { + return pmaGapIterator{} + } + + gap.index++ + for gap.index <= gap.node.nrSegments { + if gap.node.hasChildren { + if largeEnoughGap := gap.node.children[gap.index].searchFirstLargeEnoughGap(minSize); largeEnoughGap.Ok() { + return largeEnoughGap + } + } else { + if gap.Range().Length() >= minSize { + return gap + } + } + gap.index++ + } + gap.node, gap.index = gap.node.parent, gap.node.parentIndex + if gap.node != nil && gap.index == gap.node.nrSegments { + + gap.node, gap.index = gap.node.parent, gap.node.parentIndex + } + return gap.nextLargeEnoughGapHelper(minSize) +} + +// PrevLargeEnoughGap returns the iterated gap's first prev gap with larger or +// equal length than minSize. If not found, return a terminal gap iterator +// (does NOT include this gap itself). +// +// Precondition: trackGaps must be 1. +func (gap pmaGapIterator) PrevLargeEnoughGap(minSize __generics_imported0.Addr) pmaGapIterator { + if pmatrackGaps != 1 { + panic("set is not tracking gaps") + } + if gap.node != nil && gap.node.hasChildren && gap.index == 0 { + + gap.node = gap.PrevSegment().node + gap.index = gap.node.nrSegments + return gap.prevLargeEnoughGapHelper(minSize) + } + return gap.prevLargeEnoughGapHelper(minSize) +} + +// prevLargeEnoughGapHelper is the helper function used by PrevLargeEnoughGap +// to do the real recursions. +// +// Preconditions: gap is NOT the first gap of a non-leaf node. +func (gap pmaGapIterator) prevLargeEnoughGapHelper(minSize __generics_imported0.Addr) pmaGapIterator { + + for gap.node != nil && + (gap.node.maxGap.Get() < minSize || (!gap.node.hasChildren && gap.index == 0)) { + gap.node, gap.index = gap.node.parent, gap.node.parentIndex + } + + if gap.node == nil { + return pmaGapIterator{} + } + + gap.index-- + for gap.index >= 0 { + if gap.node.hasChildren { + if largeEnoughGap := gap.node.children[gap.index].searchLastLargeEnoughGap(minSize); largeEnoughGap.Ok() { + return largeEnoughGap + } + } else { + if gap.Range().Length() >= minSize { + return gap + } + } + gap.index-- + } + gap.node, gap.index = gap.node.parent, gap.node.parentIndex + if gap.node != nil && gap.index == 0 { + + gap.node, gap.index = gap.node.parent, gap.node.parentIndex + } + return gap.prevLargeEnoughGapHelper(minSize) +} + +// segmentBeforePosition returns the predecessor segment of the position given +// by n.children[i], which may or may not contain a child. If no such segment +// exists, segmentBeforePosition returns a terminal iterator. +func pmasegmentBeforePosition(n *pmanode, i int) pmaIterator { + for i == 0 { + if n.parent == nil { + return pmaIterator{} + } + n, i = n.parent, n.parentIndex + } + return pmaIterator{n, i - 1} +} + +// segmentAfterPosition returns the successor segment of the position given by +// n.children[i], which may or may not contain a child. If no such segment +// exists, segmentAfterPosition returns a terminal iterator. +func pmasegmentAfterPosition(n *pmanode, i int) pmaIterator { + for i == n.nrSegments { + if n.parent == nil { + return pmaIterator{} + } + n, i = n.parent, n.parentIndex + } + return pmaIterator{n, i} +} + +func pmazeroValueSlice(slice []pma) { + + for i := range slice { + pmaSetFunctions{}.ClearValue(&slice[i]) + } +} + +func pmazeroNodeSlice(slice []*pmanode) { + for i := range slice { + slice[i] = nil + } +} + +// String stringifies a Set for debugging. +func (s *pmaSet) String() string { + return s.root.String() +} + +// String stringifies a node (and all of its children) for debugging. +func (n *pmanode) String() string { + var buf bytes.Buffer + n.writeDebugString(&buf, "") + return buf.String() +} + +func (n *pmanode) writeDebugString(buf *bytes.Buffer, prefix string) { + if n.hasChildren != (n.nrSegments > 0 && n.children[0] != nil) { + buf.WriteString(prefix) + buf.WriteString(fmt.Sprintf("WARNING: inconsistent value of hasChildren: got %v, want %v\n", n.hasChildren, !n.hasChildren)) + } + for i := 0; i < n.nrSegments; i++ { + if child := n.children[i]; child != nil { + cprefix := fmt.Sprintf("%s- % 3d ", prefix, i) + if child.parent != n || child.parentIndex != i { + buf.WriteString(cprefix) + buf.WriteString(fmt.Sprintf("WARNING: inconsistent linkage to parent: got (%p, %d), want (%p, %d)\n", child.parent, child.parentIndex, n, i)) + } + child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, i)) + } + buf.WriteString(prefix) + if n.hasChildren { + if pmatrackGaps != 0 { + buf.WriteString(fmt.Sprintf("- % 3d: %v => %v, maxGap: %d\n", i, n.keys[i], n.values[i], n.maxGap.Get())) + } else { + buf.WriteString(fmt.Sprintf("- % 3d: %v => %v\n", i, n.keys[i], n.values[i])) + } + } else { + buf.WriteString(fmt.Sprintf("- % 3d: %v => %v\n", i, n.keys[i], n.values[i])) + } + } + if child := n.children[n.nrSegments]; child != nil { + child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, n.nrSegments)) + } +} + +// SegmentDataSlices represents segments from a set as slices of start, end, and +// values. SegmentDataSlices is primarily used as an intermediate representation +// for save/restore and the layout here is optimized for that. +// +// +stateify savable +type pmaSegmentDataSlices struct { + Start []__generics_imported0.Addr + End []__generics_imported0.Addr + Values []pma +} + +// ExportSortedSlices returns a copy of all segments in the given set, in +// ascending key order. +func (s *pmaSet) ExportSortedSlices() *pmaSegmentDataSlices { + var sds pmaSegmentDataSlices + for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + sds.Start = append(sds.Start, seg.Start()) + sds.End = append(sds.End, seg.End()) + sds.Values = append(sds.Values, seg.Value()) + } + sds.Start = sds.Start[:len(sds.Start):len(sds.Start)] + sds.End = sds.End[:len(sds.End):len(sds.End)] + sds.Values = sds.Values[:len(sds.Values):len(sds.Values)] + return &sds +} + +// ImportSortedSlices initializes the given set from the given slice. +// +// Preconditions: +// * s must be empty. +// * sds must represent a valid set (the segments in sds must have valid +// lengths that do not overlap). +// * The segments in sds must be sorted in ascending key order. +func (s *pmaSet) ImportSortedSlices(sds *pmaSegmentDataSlices) error { + if !s.IsEmpty() { + return fmt.Errorf("cannot import into non-empty set %v", s) + } + gap := s.FirstGap() + for i := range sds.Start { + r := __generics_imported0.AddrRange{sds.Start[i], sds.End[i]} + if !gap.Range().IsSupersetOf(r) { + return fmt.Errorf("segment overlaps a preceding segment or is incorrectly sorted: [%d, %d) => %v", sds.Start[i], sds.End[i], sds.Values[i]) + } + gap = s.InsertWithoutMerging(gap, r, sds.Values[i]).NextGap() + } + return nil +} + +// segmentTestCheck returns an error if s is incorrectly sorted, does not +// contain exactly expectedSegments segments, or contains a segment which +// fails the passed check. +// +// This should be used only for testing, and has been added to this package for +// templating convenience. +func (s *pmaSet) segmentTestCheck(expectedSegments int, segFunc func(int, __generics_imported0.AddrRange, pma) error) error { + havePrev := false + prev := __generics_imported0.Addr(0) + nrSegments := 0 + for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + next := seg.Start() + if havePrev && prev >= next { + return fmt.Errorf("incorrect order: key %d (segment %d) >= key %d (segment %d)", prev, nrSegments-1, next, nrSegments) + } + if segFunc != nil { + if err := segFunc(nrSegments, seg.Range(), seg.Value()); err != nil { + return err + } + } + prev = next + havePrev = true + nrSegments++ + } + if nrSegments != expectedSegments { + return fmt.Errorf("incorrect number of segments: got %d, wanted %d", nrSegments, expectedSegments) + } + return nil +} + +// countSegments counts the number of segments in the set. +// +// Similar to Check, this should only be used for testing. +func (s *pmaSet) countSegments() (segments int) { + for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + segments++ + } + return segments +} +func (s *pmaSet) saveRoot() *pmaSegmentDataSlices { + return s.ExportSortedSlices() +} + +func (s *pmaSet) loadRoot(sds *pmaSegmentDataSlices) { + if err := s.ImportSortedSlices(sds); err != nil { + panic(err) + } +} diff --git a/pkg/sentry/mm/special_mappable_refs.go b/pkg/sentry/mm/special_mappable_refs.go new file mode 100644 index 000000000..386a9fa3b --- /dev/null +++ b/pkg/sentry/mm/special_mappable_refs.go @@ -0,0 +1,140 @@ +package mm + +import ( + "fmt" + "sync/atomic" + + "gvisor.dev/gvisor/pkg/refsvfs2" +) + +// enableLogging indicates whether reference-related events should be logged (with +// stack traces). This is false by default and should only be set to true for +// debugging purposes, as it can generate an extremely large amount of output +// and drastically degrade performance. +const SpecialMappableenableLogging = false + +// obj is used to customize logging. Note that we use a pointer to T so that +// we do not copy the entire object when passed as a format parameter. +var SpecialMappableobj *SpecialMappable + +// Refs implements refs.RefCounter. It keeps a reference count using atomic +// operations and calls the destructor when the count reaches zero. +// +// NOTE: Do not introduce additional fields to the Refs struct. It is used by +// many filesystem objects, and we want to keep it as small as possible (i.e., +// the same size as using an int64 directly) to avoid taking up extra cache +// space. In general, this template should not be extended at the cost of +// performance. If it does not offer enough flexibility for a particular object +// (example: b/187877947), we should implement the RefCounter/CheckedObject +// interfaces manually. +// +// +stateify savable +type SpecialMappableRefs struct { + // refCount is composed of two fields: + // + // [32-bit speculative references]:[32-bit real references] + // + // Speculative references are used for TryIncRef, to avoid a CompareAndSwap + // loop. See IncRef, DecRef and TryIncRef for details of how these fields are + // used. + refCount int64 +} + +// InitRefs initializes r with one reference and, if enabled, activates leak +// checking. +func (r *SpecialMappableRefs) InitRefs() { + atomic.StoreInt64(&r.refCount, 1) + refsvfs2.Register(r) +} + +// RefType implements refsvfs2.CheckedObject.RefType. +func (r *SpecialMappableRefs) RefType() string { + return fmt.Sprintf("%T", SpecialMappableobj)[1:] +} + +// LeakMessage implements refsvfs2.CheckedObject.LeakMessage. +func (r *SpecialMappableRefs) LeakMessage() string { + return fmt.Sprintf("[%s %p] reference count of %d instead of 0", r.RefType(), r, r.ReadRefs()) +} + +// LogRefs implements refsvfs2.CheckedObject.LogRefs. +func (r *SpecialMappableRefs) LogRefs() bool { + return SpecialMappableenableLogging +} + +// ReadRefs returns the current number of references. The returned count is +// inherently racy and is unsafe to use without external synchronization. +func (r *SpecialMappableRefs) ReadRefs() int64 { + return atomic.LoadInt64(&r.refCount) +} + +// IncRef implements refs.RefCounter.IncRef. +// +//go:nosplit +func (r *SpecialMappableRefs) IncRef() { + v := atomic.AddInt64(&r.refCount, 1) + if SpecialMappableenableLogging { + refsvfs2.LogIncRef(r, v) + } + if v <= 1 { + panic(fmt.Sprintf("Incrementing non-positive count %p on %s", r, r.RefType())) + } +} + +// TryIncRef implements refs.TryRefCounter.TryIncRef. +// +// To do this safely without a loop, a speculative reference is first acquired +// on the object. This allows multiple concurrent TryIncRef calls to distinguish +// other TryIncRef calls from genuine references held. +// +//go:nosplit +func (r *SpecialMappableRefs) TryIncRef() bool { + const speculativeRef = 1 << 32 + if v := atomic.AddInt64(&r.refCount, speculativeRef); int32(v) == 0 { + + atomic.AddInt64(&r.refCount, -speculativeRef) + return false + } + + v := atomic.AddInt64(&r.refCount, -speculativeRef+1) + if SpecialMappableenableLogging { + refsvfs2.LogTryIncRef(r, v) + } + return true +} + +// DecRef implements refs.RefCounter.DecRef. +// +// Note that speculative references are counted here. Since they were added +// prior to real references reaching zero, they will successfully convert to +// real references. In other words, we see speculative references only in the +// following case: +// +// A: TryIncRef [speculative increase => sees non-negative references] +// B: DecRef [real decrease] +// A: TryIncRef [transform speculative to real] +// +//go:nosplit +func (r *SpecialMappableRefs) DecRef(destroy func()) { + v := atomic.AddInt64(&r.refCount, -1) + if SpecialMappableenableLogging { + refsvfs2.LogDecRef(r, v) + } + switch { + case v < 0: + panic(fmt.Sprintf("Decrementing non-positive ref count %p, owned by %s", r, r.RefType())) + + case v == 0: + refsvfs2.Unregister(r) + + if destroy != nil { + destroy() + } + } +} + +func (r *SpecialMappableRefs) afterLoad() { + if r.ReadRefs() > 0 { + refsvfs2.Register(r) + } +} diff --git a/pkg/sentry/mm/vma_set.go b/pkg/sentry/mm/vma_set.go new file mode 100644 index 000000000..298c86629 --- /dev/null +++ b/pkg/sentry/mm/vma_set.go @@ -0,0 +1,1647 @@ +package mm + +import ( + __generics_imported0 "gvisor.dev/gvisor/pkg/hostarch" +) + +import ( + "bytes" + "fmt" +) + +// trackGaps is an optional parameter. +// +// If trackGaps is 1, the Set will track maximum gap size recursively, +// enabling the GapIterator.{Prev,Next}LargeEnoughGap functions. In this +// case, Key must be an unsigned integer. +// +// trackGaps must be 0 or 1. +const vmatrackGaps = 1 + +var _ = uint8(vmatrackGaps << 7) // Will fail if not zero or one. + +// dynamicGap is a type that disappears if trackGaps is 0. +type vmadynamicGap [vmatrackGaps]__generics_imported0.Addr + +// Get returns the value of the gap. +// +// Precondition: trackGaps must be non-zero. +func (d *vmadynamicGap) Get() __generics_imported0.Addr { + return d[:][0] +} + +// Set sets the value of the gap. +// +// Precondition: trackGaps must be non-zero. +func (d *vmadynamicGap) Set(v __generics_imported0.Addr) { + d[:][0] = v +} + +const ( + // minDegree is the minimum degree of an internal node in a Set B-tree. + // + // - Any non-root node has at least minDegree-1 segments. + // + // - Any non-root internal (non-leaf) node has at least minDegree children. + // + // - The root node may have fewer than minDegree-1 segments, but it may + // only have 0 segments if the tree is empty. + // + // Our implementation requires minDegree >= 3. Higher values of minDegree + // usually improve performance, but increase memory usage for small sets. + vmaminDegree = 8 + + vmamaxDegree = 2 * vmaminDegree +) + +// A Set is a mapping of segments with non-overlapping Range keys. The zero +// value for a Set is an empty set. Set values are not safely movable nor +// copyable. Set is thread-compatible. +// +// +stateify savable +type vmaSet struct { + root vmanode `state:".(*vmaSegmentDataSlices)"` +} + +// IsEmpty returns true if the set contains no segments. +func (s *vmaSet) IsEmpty() bool { + return s.root.nrSegments == 0 +} + +// IsEmptyRange returns true iff no segments in the set overlap the given +// range. This is semantically equivalent to s.SpanRange(r) == 0, but may be +// more efficient. +func (s *vmaSet) IsEmptyRange(r __generics_imported0.AddrRange) bool { + switch { + case r.Length() < 0: + panic(fmt.Sprintf("invalid range %v", r)) + case r.Length() == 0: + return true + } + _, gap := s.Find(r.Start) + if !gap.Ok() { + return false + } + return r.End <= gap.End() +} + +// Span returns the total size of all segments in the set. +func (s *vmaSet) Span() __generics_imported0.Addr { + var sz __generics_imported0.Addr + for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + sz += seg.Range().Length() + } + return sz +} + +// SpanRange returns the total size of the intersection of segments in the set +// with the given range. +func (s *vmaSet) SpanRange(r __generics_imported0.AddrRange) __generics_imported0.Addr { + switch { + case r.Length() < 0: + panic(fmt.Sprintf("invalid range %v", r)) + case r.Length() == 0: + return 0 + } + var sz __generics_imported0.Addr + for seg := s.LowerBoundSegment(r.Start); seg.Ok() && seg.Start() < r.End; seg = seg.NextSegment() { + sz += seg.Range().Intersect(r).Length() + } + return sz +} + +// FirstSegment returns the first segment in the set. If the set is empty, +// FirstSegment returns a terminal iterator. +func (s *vmaSet) FirstSegment() vmaIterator { + if s.root.nrSegments == 0 { + return vmaIterator{} + } + return s.root.firstSegment() +} + +// LastSegment returns the last segment in the set. If the set is empty, +// LastSegment returns a terminal iterator. +func (s *vmaSet) LastSegment() vmaIterator { + if s.root.nrSegments == 0 { + return vmaIterator{} + } + return s.root.lastSegment() +} + +// FirstGap returns the first gap in the set. +func (s *vmaSet) FirstGap() vmaGapIterator { + n := &s.root + for n.hasChildren { + n = n.children[0] + } + return vmaGapIterator{n, 0} +} + +// LastGap returns the last gap in the set. +func (s *vmaSet) LastGap() vmaGapIterator { + n := &s.root + for n.hasChildren { + n = n.children[n.nrSegments] + } + return vmaGapIterator{n, n.nrSegments} +} + +// Find returns the segment or gap whose range contains the given key. If a +// segment is found, the returned Iterator is non-terminal and the +// returned GapIterator is terminal. Otherwise, the returned Iterator is +// terminal and the returned GapIterator is non-terminal. +func (s *vmaSet) Find(key __generics_imported0.Addr) (vmaIterator, vmaGapIterator) { + n := &s.root + for { + + lower := 0 + upper := n.nrSegments + for lower < upper { + i := lower + (upper-lower)/2 + if r := n.keys[i]; key < r.End { + if key >= r.Start { + return vmaIterator{n, i}, vmaGapIterator{} + } + upper = i + } else { + lower = i + 1 + } + } + i := lower + if !n.hasChildren { + return vmaIterator{}, vmaGapIterator{n, i} + } + n = n.children[i] + } +} + +// FindSegment returns the segment whose range contains the given key. If no +// such segment exists, FindSegment returns a terminal iterator. +func (s *vmaSet) FindSegment(key __generics_imported0.Addr) vmaIterator { + seg, _ := s.Find(key) + return seg +} + +// LowerBoundSegment returns the segment with the lowest range that contains a +// key greater than or equal to min. If no such segment exists, +// LowerBoundSegment returns a terminal iterator. +func (s *vmaSet) LowerBoundSegment(min __generics_imported0.Addr) vmaIterator { + seg, gap := s.Find(min) + if seg.Ok() { + return seg + } + return gap.NextSegment() +} + +// UpperBoundSegment returns the segment with the highest range that contains a +// key less than or equal to max. If no such segment exists, UpperBoundSegment +// returns a terminal iterator. +func (s *vmaSet) UpperBoundSegment(max __generics_imported0.Addr) vmaIterator { + seg, gap := s.Find(max) + if seg.Ok() { + return seg + } + return gap.PrevSegment() +} + +// FindGap returns the gap containing the given key. If no such gap exists +// (i.e. the set contains a segment containing that key), FindGap returns a +// terminal iterator. +func (s *vmaSet) FindGap(key __generics_imported0.Addr) vmaGapIterator { + _, gap := s.Find(key) + return gap +} + +// LowerBoundGap returns the gap with the lowest range that is greater than or +// equal to min. +func (s *vmaSet) LowerBoundGap(min __generics_imported0.Addr) vmaGapIterator { + seg, gap := s.Find(min) + if gap.Ok() { + return gap + } + return seg.NextGap() +} + +// UpperBoundGap returns the gap with the highest range that is less than or +// equal to max. +func (s *vmaSet) UpperBoundGap(max __generics_imported0.Addr) vmaGapIterator { + seg, gap := s.Find(max) + if gap.Ok() { + return gap + } + return seg.PrevGap() +} + +// Add inserts the given segment into the set and returns true. If the new +// segment can be merged with adjacent segments, Add will do so. If the new +// segment would overlap an existing segment, Add returns false. If Add +// succeeds, all existing iterators are invalidated. +func (s *vmaSet) Add(r __generics_imported0.AddrRange, val vma) bool { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + gap := s.FindGap(r.Start) + if !gap.Ok() { + return false + } + if r.End > gap.End() { + return false + } + s.Insert(gap, r, val) + return true +} + +// AddWithoutMerging inserts the given segment into the set and returns true. +// If it would overlap an existing segment, AddWithoutMerging does nothing and +// returns false. If AddWithoutMerging succeeds, all existing iterators are +// invalidated. +func (s *vmaSet) AddWithoutMerging(r __generics_imported0.AddrRange, val vma) bool { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + gap := s.FindGap(r.Start) + if !gap.Ok() { + return false + } + if r.End > gap.End() { + return false + } + s.InsertWithoutMergingUnchecked(gap, r, val) + return true +} + +// Insert inserts the given segment into the given gap. If the new segment can +// be merged with adjacent segments, Insert will do so. Insert returns an +// iterator to the segment containing the inserted value (which may have been +// merged with other values). All existing iterators (including gap, but not +// including the returned iterator) are invalidated. +// +// If the gap cannot accommodate the segment, or if r is invalid, Insert panics. +// +// Insert is semantically equivalent to a InsertWithoutMerging followed by a +// Merge, but may be more efficient. Note that there is no unchecked variant of +// Insert since Insert must retrieve and inspect gap's predecessor and +// successor segments regardless. +func (s *vmaSet) Insert(gap vmaGapIterator, r __generics_imported0.AddrRange, val vma) vmaIterator { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + prev, next := gap.PrevSegment(), gap.NextSegment() + if prev.Ok() && prev.End() > r.Start { + panic(fmt.Sprintf("new segment %v overlaps predecessor %v", r, prev.Range())) + } + if next.Ok() && next.Start() < r.End { + panic(fmt.Sprintf("new segment %v overlaps successor %v", r, next.Range())) + } + if prev.Ok() && prev.End() == r.Start { + if mval, ok := (vmaSetFunctions{}).Merge(prev.Range(), prev.Value(), r, val); ok { + shrinkMaxGap := vmatrackGaps != 0 && gap.Range().Length() == gap.node.maxGap.Get() + prev.SetEndUnchecked(r.End) + prev.SetValue(mval) + if shrinkMaxGap { + gap.node.updateMaxGapLeaf() + } + if next.Ok() && next.Start() == r.End { + val = mval + if mval, ok := (vmaSetFunctions{}).Merge(prev.Range(), val, next.Range(), next.Value()); ok { + prev.SetEndUnchecked(next.End()) + prev.SetValue(mval) + return s.Remove(next).PrevSegment() + } + } + return prev + } + } + if next.Ok() && next.Start() == r.End { + if mval, ok := (vmaSetFunctions{}).Merge(r, val, next.Range(), next.Value()); ok { + shrinkMaxGap := vmatrackGaps != 0 && gap.Range().Length() == gap.node.maxGap.Get() + next.SetStartUnchecked(r.Start) + next.SetValue(mval) + if shrinkMaxGap { + gap.node.updateMaxGapLeaf() + } + return next + } + } + + return s.InsertWithoutMergingUnchecked(gap, r, val) +} + +// InsertWithoutMerging inserts the given segment into the given gap and +// returns an iterator to the inserted segment. All existing iterators +// (including gap, but not including the returned iterator) are invalidated. +// +// If the gap cannot accommodate the segment, or if r is invalid, +// InsertWithoutMerging panics. +func (s *vmaSet) InsertWithoutMerging(gap vmaGapIterator, r __generics_imported0.AddrRange, val vma) vmaIterator { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + if gr := gap.Range(); !gr.IsSupersetOf(r) { + panic(fmt.Sprintf("cannot insert segment range %v into gap range %v", r, gr)) + } + return s.InsertWithoutMergingUnchecked(gap, r, val) +} + +// InsertWithoutMergingUnchecked inserts the given segment into the given gap +// and returns an iterator to the inserted segment. All existing iterators +// (including gap, but not including the returned iterator) are invalidated. +// +// Preconditions: +// * r.Start >= gap.Start(). +// * r.End <= gap.End(). +func (s *vmaSet) InsertWithoutMergingUnchecked(gap vmaGapIterator, r __generics_imported0.AddrRange, val vma) vmaIterator { + gap = gap.node.rebalanceBeforeInsert(gap) + splitMaxGap := vmatrackGaps != 0 && (gap.node.nrSegments == 0 || gap.Range().Length() == gap.node.maxGap.Get()) + copy(gap.node.keys[gap.index+1:], gap.node.keys[gap.index:gap.node.nrSegments]) + copy(gap.node.values[gap.index+1:], gap.node.values[gap.index:gap.node.nrSegments]) + gap.node.keys[gap.index] = r + gap.node.values[gap.index] = val + gap.node.nrSegments++ + if splitMaxGap { + gap.node.updateMaxGapLeaf() + } + return vmaIterator{gap.node, gap.index} +} + +// Remove removes the given segment and returns an iterator to the vacated gap. +// All existing iterators (including seg, but not including the returned +// iterator) are invalidated. +func (s *vmaSet) Remove(seg vmaIterator) vmaGapIterator { + + if seg.node.hasChildren { + + victim := seg.PrevSegment() + + seg.SetRangeUnchecked(victim.Range()) + seg.SetValue(victim.Value()) + + nextAdjacentNode := seg.NextSegment().node + if vmatrackGaps != 0 { + nextAdjacentNode.updateMaxGapLeaf() + } + return s.Remove(victim).NextGap() + } + copy(seg.node.keys[seg.index:], seg.node.keys[seg.index+1:seg.node.nrSegments]) + copy(seg.node.values[seg.index:], seg.node.values[seg.index+1:seg.node.nrSegments]) + vmaSetFunctions{}.ClearValue(&seg.node.values[seg.node.nrSegments-1]) + seg.node.nrSegments-- + if vmatrackGaps != 0 { + seg.node.updateMaxGapLeaf() + } + return seg.node.rebalanceAfterRemove(vmaGapIterator{seg.node, seg.index}) +} + +// RemoveAll removes all segments from the set. All existing iterators are +// invalidated. +func (s *vmaSet) RemoveAll() { + s.root = vmanode{} +} + +// RemoveRange removes all segments in the given range. An iterator to the +// newly formed gap is returned, and all existing iterators are invalidated. +func (s *vmaSet) RemoveRange(r __generics_imported0.AddrRange) vmaGapIterator { + seg, gap := s.Find(r.Start) + if seg.Ok() { + seg = s.Isolate(seg, r) + gap = s.Remove(seg) + } + for seg = gap.NextSegment(); seg.Ok() && seg.Start() < r.End; seg = gap.NextSegment() { + seg = s.Isolate(seg, r) + gap = s.Remove(seg) + } + return gap +} + +// Merge attempts to merge two neighboring segments. If successful, Merge +// returns an iterator to the merged segment, and all existing iterators are +// invalidated. Otherwise, Merge returns a terminal iterator. +// +// If first is not the predecessor of second, Merge panics. +func (s *vmaSet) Merge(first, second vmaIterator) vmaIterator { + if first.NextSegment() != second { + panic(fmt.Sprintf("attempt to merge non-neighboring segments %v, %v", first.Range(), second.Range())) + } + return s.MergeUnchecked(first, second) +} + +// MergeUnchecked attempts to merge two neighboring segments. If successful, +// MergeUnchecked returns an iterator to the merged segment, and all existing +// iterators are invalidated. Otherwise, MergeUnchecked returns a terminal +// iterator. +// +// Precondition: first is the predecessor of second: first.NextSegment() == +// second, first == second.PrevSegment(). +func (s *vmaSet) MergeUnchecked(first, second vmaIterator) vmaIterator { + if first.End() == second.Start() { + if mval, ok := (vmaSetFunctions{}).Merge(first.Range(), first.Value(), second.Range(), second.Value()); ok { + + first.SetEndUnchecked(second.End()) + first.SetValue(mval) + + return s.Remove(second).PrevSegment() + } + } + return vmaIterator{} +} + +// MergeAll attempts to merge all adjacent segments in the set. All existing +// iterators are invalidated. +func (s *vmaSet) MergeAll() { + seg := s.FirstSegment() + if !seg.Ok() { + return + } + next := seg.NextSegment() + for next.Ok() { + if mseg := s.MergeUnchecked(seg, next); mseg.Ok() { + seg, next = mseg, mseg.NextSegment() + } else { + seg, next = next, next.NextSegment() + } + } +} + +// MergeRange attempts to merge all adjacent segments that contain a key in the +// specific range. All existing iterators are invalidated. +func (s *vmaSet) MergeRange(r __generics_imported0.AddrRange) { + seg := s.LowerBoundSegment(r.Start) + if !seg.Ok() { + return + } + next := seg.NextSegment() + for next.Ok() && next.Range().Start < r.End { + if mseg := s.MergeUnchecked(seg, next); mseg.Ok() { + seg, next = mseg, mseg.NextSegment() + } else { + seg, next = next, next.NextSegment() + } + } +} + +// MergeAdjacent attempts to merge the segment containing r.Start with its +// predecessor, and the segment containing r.End-1 with its successor. +func (s *vmaSet) MergeAdjacent(r __generics_imported0.AddrRange) { + first := s.FindSegment(r.Start) + if first.Ok() { + if prev := first.PrevSegment(); prev.Ok() { + s.Merge(prev, first) + } + } + last := s.FindSegment(r.End - 1) + if last.Ok() { + if next := last.NextSegment(); next.Ok() { + s.Merge(last, next) + } + } +} + +// Split splits the given segment at the given key and returns iterators to the +// two resulting segments. All existing iterators (including seg, but not +// including the returned iterators) are invalidated. +// +// If the segment cannot be split at split (because split is at the start or +// end of the segment's range, so splitting would produce a segment with zero +// length, or because split falls outside the segment's range altogether), +// Split panics. +func (s *vmaSet) Split(seg vmaIterator, split __generics_imported0.Addr) (vmaIterator, vmaIterator) { + if !seg.Range().CanSplitAt(split) { + panic(fmt.Sprintf("can't split %v at %v", seg.Range(), split)) + } + return s.SplitUnchecked(seg, split) +} + +// SplitUnchecked splits the given segment at the given key and returns +// iterators to the two resulting segments. All existing iterators (including +// seg, but not including the returned iterators) are invalidated. +// +// Preconditions: seg.Start() < key < seg.End(). +func (s *vmaSet) SplitUnchecked(seg vmaIterator, split __generics_imported0.Addr) (vmaIterator, vmaIterator) { + val1, val2 := (vmaSetFunctions{}).Split(seg.Range(), seg.Value(), split) + end2 := seg.End() + seg.SetEndUnchecked(split) + seg.SetValue(val1) + seg2 := s.InsertWithoutMergingUnchecked(seg.NextGap(), __generics_imported0.AddrRange{split, end2}, val2) + + return seg2.PrevSegment(), seg2 +} + +// SplitAt splits the segment straddling split, if one exists. SplitAt returns +// true if a segment was split and false otherwise. If SplitAt splits a +// segment, all existing iterators are invalidated. +func (s *vmaSet) SplitAt(split __generics_imported0.Addr) bool { + if seg := s.FindSegment(split); seg.Ok() && seg.Range().CanSplitAt(split) { + s.SplitUnchecked(seg, split) + return true + } + return false +} + +// Isolate ensures that the given segment's range does not escape r by +// splitting at r.Start and r.End if necessary, and returns an updated iterator +// to the bounded segment. All existing iterators (including seg, but not +// including the returned iterators) are invalidated. +func (s *vmaSet) Isolate(seg vmaIterator, r __generics_imported0.AddrRange) vmaIterator { + if seg.Range().CanSplitAt(r.Start) { + _, seg = s.SplitUnchecked(seg, r.Start) + } + if seg.Range().CanSplitAt(r.End) { + seg, _ = s.SplitUnchecked(seg, r.End) + } + return seg +} + +// ApplyContiguous applies a function to a contiguous range of segments, +// splitting if necessary. The function is applied until the first gap is +// encountered, at which point the gap is returned. If the function is applied +// across the entire range, a terminal gap is returned. All existing iterators +// are invalidated. +// +// N.B. The Iterator must not be invalidated by the function. +func (s *vmaSet) ApplyContiguous(r __generics_imported0.AddrRange, fn func(seg vmaIterator)) vmaGapIterator { + seg, gap := s.Find(r.Start) + if !seg.Ok() { + return gap + } + for { + seg = s.Isolate(seg, r) + fn(seg) + if seg.End() >= r.End { + return vmaGapIterator{} + } + gap = seg.NextGap() + if !gap.IsEmpty() { + return gap + } + seg = gap.NextSegment() + if !seg.Ok() { + + return vmaGapIterator{} + } + } +} + +// +stateify savable +type vmanode struct { + // An internal binary tree node looks like: + // + // K + // / \ + // Cl Cr + // + // where all keys in the subtree rooted by Cl (the left subtree) are less + // than K (the key of the parent node), and all keys in the subtree rooted + // by Cr (the right subtree) are greater than K. + // + // An internal B-tree node's indexes work out to look like: + // + // K0 K1 K2 ... Kn-1 + // / \/ \/ \ ... / \ + // C0 C1 C2 C3 ... Cn-1 Cn + // + // where n is nrSegments. + nrSegments int + + // parent is a pointer to this node's parent. If this node is root, parent + // is nil. + parent *vmanode + + // parentIndex is the index of this node in parent.children. + parentIndex int + + // Flag for internal nodes that is technically redundant with "children[0] + // != nil", but is stored in the first cache line. "hasChildren" rather + // than "isLeaf" because false must be the correct value for an empty root. + hasChildren bool + + // The longest gap within this node. If the node is a leaf, it's simply the + // maximum gap among all the (nrSegments+1) gaps formed by its nrSegments keys + // including the 0th and nrSegments-th gap possibly shared with its upper-level + // nodes; if it's a non-leaf node, it's the max of all children's maxGap. + maxGap vmadynamicGap + + // Nodes store keys and values in separate arrays to maximize locality in + // the common case (scanning keys for lookup). + keys [vmamaxDegree - 1]__generics_imported0.AddrRange + values [vmamaxDegree - 1]vma + children [vmamaxDegree]*vmanode +} + +// firstSegment returns the first segment in the subtree rooted by n. +// +// Preconditions: n.nrSegments != 0. +func (n *vmanode) firstSegment() vmaIterator { + for n.hasChildren { + n = n.children[0] + } + return vmaIterator{n, 0} +} + +// lastSegment returns the last segment in the subtree rooted by n. +// +// Preconditions: n.nrSegments != 0. +func (n *vmanode) lastSegment() vmaIterator { + for n.hasChildren { + n = n.children[n.nrSegments] + } + return vmaIterator{n, n.nrSegments - 1} +} + +func (n *vmanode) prevSibling() *vmanode { + if n.parent == nil || n.parentIndex == 0 { + return nil + } + return n.parent.children[n.parentIndex-1] +} + +func (n *vmanode) nextSibling() *vmanode { + if n.parent == nil || n.parentIndex == n.parent.nrSegments { + return nil + } + return n.parent.children[n.parentIndex+1] +} + +// rebalanceBeforeInsert splits n and its ancestors if they are full, as +// required for insertion, and returns an updated iterator to the position +// represented by gap. +func (n *vmanode) rebalanceBeforeInsert(gap vmaGapIterator) vmaGapIterator { + if n.nrSegments < vmamaxDegree-1 { + return gap + } + if n.parent != nil { + gap = n.parent.rebalanceBeforeInsert(gap) + } + if n.parent == nil { + + left := &vmanode{ + nrSegments: vmaminDegree - 1, + parent: n, + parentIndex: 0, + hasChildren: n.hasChildren, + } + right := &vmanode{ + nrSegments: vmaminDegree - 1, + parent: n, + parentIndex: 1, + hasChildren: n.hasChildren, + } + copy(left.keys[:vmaminDegree-1], n.keys[:vmaminDegree-1]) + copy(left.values[:vmaminDegree-1], n.values[:vmaminDegree-1]) + copy(right.keys[:vmaminDegree-1], n.keys[vmaminDegree:]) + copy(right.values[:vmaminDegree-1], n.values[vmaminDegree:]) + n.keys[0], n.values[0] = n.keys[vmaminDegree-1], n.values[vmaminDegree-1] + vmazeroValueSlice(n.values[1:]) + if n.hasChildren { + copy(left.children[:vmaminDegree], n.children[:vmaminDegree]) + copy(right.children[:vmaminDegree], n.children[vmaminDegree:]) + vmazeroNodeSlice(n.children[2:]) + for i := 0; i < vmaminDegree; i++ { + left.children[i].parent = left + left.children[i].parentIndex = i + right.children[i].parent = right + right.children[i].parentIndex = i + } + } + n.nrSegments = 1 + n.hasChildren = true + n.children[0] = left + n.children[1] = right + + if vmatrackGaps != 0 { + left.updateMaxGapLocal() + right.updateMaxGapLocal() + } + if gap.node != n { + return gap + } + if gap.index < vmaminDegree { + return vmaGapIterator{left, gap.index} + } + return vmaGapIterator{right, gap.index - vmaminDegree} + } + + copy(n.parent.keys[n.parentIndex+1:], n.parent.keys[n.parentIndex:n.parent.nrSegments]) + copy(n.parent.values[n.parentIndex+1:], n.parent.values[n.parentIndex:n.parent.nrSegments]) + n.parent.keys[n.parentIndex], n.parent.values[n.parentIndex] = n.keys[vmaminDegree-1], n.values[vmaminDegree-1] + copy(n.parent.children[n.parentIndex+2:], n.parent.children[n.parentIndex+1:n.parent.nrSegments+1]) + for i := n.parentIndex + 2; i < n.parent.nrSegments+2; i++ { + n.parent.children[i].parentIndex = i + } + sibling := &vmanode{ + nrSegments: vmaminDegree - 1, + parent: n.parent, + parentIndex: n.parentIndex + 1, + hasChildren: n.hasChildren, + } + n.parent.children[n.parentIndex+1] = sibling + n.parent.nrSegments++ + copy(sibling.keys[:vmaminDegree-1], n.keys[vmaminDegree:]) + copy(sibling.values[:vmaminDegree-1], n.values[vmaminDegree:]) + vmazeroValueSlice(n.values[vmaminDegree-1:]) + if n.hasChildren { + copy(sibling.children[:vmaminDegree], n.children[vmaminDegree:]) + vmazeroNodeSlice(n.children[vmaminDegree:]) + for i := 0; i < vmaminDegree; i++ { + sibling.children[i].parent = sibling + sibling.children[i].parentIndex = i + } + } + n.nrSegments = vmaminDegree - 1 + + if vmatrackGaps != 0 { + n.updateMaxGapLocal() + sibling.updateMaxGapLocal() + } + + if gap.node != n { + return gap + } + if gap.index < vmaminDegree { + return gap + } + return vmaGapIterator{sibling, gap.index - vmaminDegree} +} + +// rebalanceAfterRemove "unsplits" n and its ancestors if they are deficient +// (contain fewer segments than required by B-tree invariants), as required for +// removal, and returns an updated iterator to the position represented by gap. +// +// Precondition: n is the only node in the tree that may currently violate a +// B-tree invariant. +func (n *vmanode) rebalanceAfterRemove(gap vmaGapIterator) vmaGapIterator { + for { + if n.nrSegments >= vmaminDegree-1 { + return gap + } + if n.parent == nil { + + return gap + } + + if sibling := n.prevSibling(); sibling != nil && sibling.nrSegments >= vmaminDegree { + copy(n.keys[1:], n.keys[:n.nrSegments]) + copy(n.values[1:], n.values[:n.nrSegments]) + n.keys[0] = n.parent.keys[n.parentIndex-1] + n.values[0] = n.parent.values[n.parentIndex-1] + n.parent.keys[n.parentIndex-1] = sibling.keys[sibling.nrSegments-1] + n.parent.values[n.parentIndex-1] = sibling.values[sibling.nrSegments-1] + vmaSetFunctions{}.ClearValue(&sibling.values[sibling.nrSegments-1]) + if n.hasChildren { + copy(n.children[1:], n.children[:n.nrSegments+1]) + n.children[0] = sibling.children[sibling.nrSegments] + sibling.children[sibling.nrSegments] = nil + n.children[0].parent = n + n.children[0].parentIndex = 0 + for i := 1; i < n.nrSegments+2; i++ { + n.children[i].parentIndex = i + } + } + n.nrSegments++ + sibling.nrSegments-- + + if vmatrackGaps != 0 { + n.updateMaxGapLocal() + sibling.updateMaxGapLocal() + } + if gap.node == sibling && gap.index == sibling.nrSegments { + return vmaGapIterator{n, 0} + } + if gap.node == n { + return vmaGapIterator{n, gap.index + 1} + } + return gap + } + if sibling := n.nextSibling(); sibling != nil && sibling.nrSegments >= vmaminDegree { + n.keys[n.nrSegments] = n.parent.keys[n.parentIndex] + n.values[n.nrSegments] = n.parent.values[n.parentIndex] + n.parent.keys[n.parentIndex] = sibling.keys[0] + n.parent.values[n.parentIndex] = sibling.values[0] + copy(sibling.keys[:sibling.nrSegments-1], sibling.keys[1:]) + copy(sibling.values[:sibling.nrSegments-1], sibling.values[1:]) + vmaSetFunctions{}.ClearValue(&sibling.values[sibling.nrSegments-1]) + if n.hasChildren { + n.children[n.nrSegments+1] = sibling.children[0] + copy(sibling.children[:sibling.nrSegments], sibling.children[1:]) + sibling.children[sibling.nrSegments] = nil + n.children[n.nrSegments+1].parent = n + n.children[n.nrSegments+1].parentIndex = n.nrSegments + 1 + for i := 0; i < sibling.nrSegments; i++ { + sibling.children[i].parentIndex = i + } + } + n.nrSegments++ + sibling.nrSegments-- + + if vmatrackGaps != 0 { + n.updateMaxGapLocal() + sibling.updateMaxGapLocal() + } + if gap.node == sibling { + if gap.index == 0 { + return vmaGapIterator{n, n.nrSegments} + } + return vmaGapIterator{sibling, gap.index - 1} + } + return gap + } + + p := n.parent + if p.nrSegments == 1 { + + left, right := p.children[0], p.children[1] + p.nrSegments = left.nrSegments + right.nrSegments + 1 + p.hasChildren = left.hasChildren + p.keys[left.nrSegments] = p.keys[0] + p.values[left.nrSegments] = p.values[0] + copy(p.keys[:left.nrSegments], left.keys[:left.nrSegments]) + copy(p.values[:left.nrSegments], left.values[:left.nrSegments]) + copy(p.keys[left.nrSegments+1:], right.keys[:right.nrSegments]) + copy(p.values[left.nrSegments+1:], right.values[:right.nrSegments]) + if left.hasChildren { + copy(p.children[:left.nrSegments+1], left.children[:left.nrSegments+1]) + copy(p.children[left.nrSegments+1:], right.children[:right.nrSegments+1]) + for i := 0; i < p.nrSegments+1; i++ { + p.children[i].parent = p + p.children[i].parentIndex = i + } + } else { + p.children[0] = nil + p.children[1] = nil + } + + if gap.node == left { + return vmaGapIterator{p, gap.index} + } + if gap.node == right { + return vmaGapIterator{p, gap.index + left.nrSegments + 1} + } + return gap + } + // Merge n and either sibling, along with the segment separating the + // two, into whichever of the two nodes comes first. This is the + // reverse of the non-root splitting case in + // node.rebalanceBeforeInsert. + var left, right *vmanode + if n.parentIndex > 0 { + left = n.prevSibling() + right = n + } else { + left = n + right = n.nextSibling() + } + + if gap.node == right { + gap = vmaGapIterator{left, gap.index + left.nrSegments + 1} + } + left.keys[left.nrSegments] = p.keys[left.parentIndex] + left.values[left.nrSegments] = p.values[left.parentIndex] + copy(left.keys[left.nrSegments+1:], right.keys[:right.nrSegments]) + copy(left.values[left.nrSegments+1:], right.values[:right.nrSegments]) + if left.hasChildren { + copy(left.children[left.nrSegments+1:], right.children[:right.nrSegments+1]) + for i := left.nrSegments + 1; i < left.nrSegments+right.nrSegments+2; i++ { + left.children[i].parent = left + left.children[i].parentIndex = i + } + } + left.nrSegments += right.nrSegments + 1 + copy(p.keys[left.parentIndex:], p.keys[left.parentIndex+1:p.nrSegments]) + copy(p.values[left.parentIndex:], p.values[left.parentIndex+1:p.nrSegments]) + vmaSetFunctions{}.ClearValue(&p.values[p.nrSegments-1]) + copy(p.children[left.parentIndex+1:], p.children[left.parentIndex+2:p.nrSegments+1]) + for i := 0; i < p.nrSegments; i++ { + p.children[i].parentIndex = i + } + p.children[p.nrSegments] = nil + p.nrSegments-- + + if vmatrackGaps != 0 { + left.updateMaxGapLocal() + } + + n = p + } +} + +// updateMaxGapLeaf updates maxGap bottom-up from the calling leaf until no +// necessary update. +// +// Preconditions: n must be a leaf node, trackGaps must be 1. +func (n *vmanode) updateMaxGapLeaf() { + if n.hasChildren { + panic(fmt.Sprintf("updateMaxGapLeaf should always be called on leaf node: %v", n)) + } + max := n.calculateMaxGapLeaf() + if max == n.maxGap.Get() { + + return + } + oldMax := n.maxGap.Get() + n.maxGap.Set(max) + if max > oldMax { + + for p := n.parent; p != nil; p = p.parent { + if p.maxGap.Get() >= max { + + break + } + + p.maxGap.Set(max) + } + return + } + + for p := n.parent; p != nil; p = p.parent { + if p.maxGap.Get() > oldMax { + + break + } + + parentNewMax := p.calculateMaxGapInternal() + if p.maxGap.Get() == parentNewMax { + + break + } + + p.maxGap.Set(parentNewMax) + } +} + +// updateMaxGapLocal updates maxGap of the calling node solely with no +// propagation to ancestor nodes. +// +// Precondition: trackGaps must be 1. +func (n *vmanode) updateMaxGapLocal() { + if !n.hasChildren { + + n.maxGap.Set(n.calculateMaxGapLeaf()) + } else { + + n.maxGap.Set(n.calculateMaxGapInternal()) + } +} + +// calculateMaxGapLeaf iterates the gaps within a leaf node and calculate the +// max. +// +// Preconditions: n must be a leaf node. +func (n *vmanode) calculateMaxGapLeaf() __generics_imported0.Addr { + max := vmaGapIterator{n, 0}.Range().Length() + for i := 1; i <= n.nrSegments; i++ { + if current := (vmaGapIterator{n, i}).Range().Length(); current > max { + max = current + } + } + return max +} + +// calculateMaxGapInternal iterates children's maxGap within an internal node n +// and calculate the max. +// +// Preconditions: n must be a non-leaf node. +func (n *vmanode) calculateMaxGapInternal() __generics_imported0.Addr { + max := n.children[0].maxGap.Get() + for i := 1; i <= n.nrSegments; i++ { + if current := n.children[i].maxGap.Get(); current > max { + max = current + } + } + return max +} + +// searchFirstLargeEnoughGap returns the first gap having at least minSize length +// in the subtree rooted by n. If not found, return a terminal gap iterator. +func (n *vmanode) searchFirstLargeEnoughGap(minSize __generics_imported0.Addr) vmaGapIterator { + if n.maxGap.Get() < minSize { + return vmaGapIterator{} + } + if n.hasChildren { + for i := 0; i <= n.nrSegments; i++ { + if largeEnoughGap := n.children[i].searchFirstLargeEnoughGap(minSize); largeEnoughGap.Ok() { + return largeEnoughGap + } + } + } else { + for i := 0; i <= n.nrSegments; i++ { + currentGap := vmaGapIterator{n, i} + if currentGap.Range().Length() >= minSize { + return currentGap + } + } + } + panic(fmt.Sprintf("invalid maxGap in %v", n)) +} + +// searchLastLargeEnoughGap returns the last gap having at least minSize length +// in the subtree rooted by n. If not found, return a terminal gap iterator. +func (n *vmanode) searchLastLargeEnoughGap(minSize __generics_imported0.Addr) vmaGapIterator { + if n.maxGap.Get() < minSize { + return vmaGapIterator{} + } + if n.hasChildren { + for i := n.nrSegments; i >= 0; i-- { + if largeEnoughGap := n.children[i].searchLastLargeEnoughGap(minSize); largeEnoughGap.Ok() { + return largeEnoughGap + } + } + } else { + for i := n.nrSegments; i >= 0; i-- { + currentGap := vmaGapIterator{n, i} + if currentGap.Range().Length() >= minSize { + return currentGap + } + } + } + panic(fmt.Sprintf("invalid maxGap in %v", n)) +} + +// A Iterator is conceptually one of: +// +// - A pointer to a segment in a set; or +// +// - A terminal iterator, which is a sentinel indicating that the end of +// iteration has been reached. +// +// Iterators are copyable values and are meaningfully equality-comparable. The +// zero value of Iterator is a terminal iterator. +// +// Unless otherwise specified, any mutation of a set invalidates all existing +// iterators into the set. +type vmaIterator struct { + // node is the node containing the iterated segment. If the iterator is + // terminal, node is nil. + node *vmanode + + // index is the index of the segment in node.keys/values. + index int +} + +// Ok returns true if the iterator is not terminal. All other methods are only +// valid for non-terminal iterators. +func (seg vmaIterator) Ok() bool { + return seg.node != nil +} + +// Range returns the iterated segment's range key. +func (seg vmaIterator) Range() __generics_imported0.AddrRange { + return seg.node.keys[seg.index] +} + +// Start is equivalent to Range().Start, but should be preferred if only the +// start of the range is needed. +func (seg vmaIterator) Start() __generics_imported0.Addr { + return seg.node.keys[seg.index].Start +} + +// End is equivalent to Range().End, but should be preferred if only the end of +// the range is needed. +func (seg vmaIterator) End() __generics_imported0.Addr { + return seg.node.keys[seg.index].End +} + +// SetRangeUnchecked mutates the iterated segment's range key. This operation +// does not invalidate any iterators. +// +// Preconditions: +// * r.Length() > 0. +// * The new range must not overlap an existing one: +// * If seg.NextSegment().Ok(), then r.end <= seg.NextSegment().Start(). +// * If seg.PrevSegment().Ok(), then r.start >= seg.PrevSegment().End(). +func (seg vmaIterator) SetRangeUnchecked(r __generics_imported0.AddrRange) { + seg.node.keys[seg.index] = r +} + +// SetRange mutates the iterated segment's range key. If the new range would +// cause the iterated segment to overlap another segment, or if the new range +// is invalid, SetRange panics. This operation does not invalidate any +// iterators. +func (seg vmaIterator) SetRange(r __generics_imported0.AddrRange) { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + if prev := seg.PrevSegment(); prev.Ok() && r.Start < prev.End() { + panic(fmt.Sprintf("new segment range %v overlaps segment range %v", r, prev.Range())) + } + if next := seg.NextSegment(); next.Ok() && r.End > next.Start() { + panic(fmt.Sprintf("new segment range %v overlaps segment range %v", r, next.Range())) + } + seg.SetRangeUnchecked(r) +} + +// SetStartUnchecked mutates the iterated segment's start. This operation does +// not invalidate any iterators. +// +// Preconditions: The new start must be valid: +// * start < seg.End() +// * If seg.PrevSegment().Ok(), then start >= seg.PrevSegment().End(). +func (seg vmaIterator) SetStartUnchecked(start __generics_imported0.Addr) { + seg.node.keys[seg.index].Start = start +} + +// SetStart mutates the iterated segment's start. If the new start value would +// cause the iterated segment to overlap another segment, or would result in an +// invalid range, SetStart panics. This operation does not invalidate any +// iterators. +func (seg vmaIterator) SetStart(start __generics_imported0.Addr) { + if start >= seg.End() { + panic(fmt.Sprintf("new start %v would invalidate segment range %v", start, seg.Range())) + } + if prev := seg.PrevSegment(); prev.Ok() && start < prev.End() { + panic(fmt.Sprintf("new start %v would cause segment range %v to overlap segment range %v", start, seg.Range(), prev.Range())) + } + seg.SetStartUnchecked(start) +} + +// SetEndUnchecked mutates the iterated segment's end. This operation does not +// invalidate any iterators. +// +// Preconditions: The new end must be valid: +// * end > seg.Start(). +// * If seg.NextSegment().Ok(), then end <= seg.NextSegment().Start(). +func (seg vmaIterator) SetEndUnchecked(end __generics_imported0.Addr) { + seg.node.keys[seg.index].End = end +} + +// SetEnd mutates the iterated segment's end. If the new end value would cause +// the iterated segment to overlap another segment, or would result in an +// invalid range, SetEnd panics. This operation does not invalidate any +// iterators. +func (seg vmaIterator) SetEnd(end __generics_imported0.Addr) { + if end <= seg.Start() { + panic(fmt.Sprintf("new end %v would invalidate segment range %v", end, seg.Range())) + } + if next := seg.NextSegment(); next.Ok() && end > next.Start() { + panic(fmt.Sprintf("new end %v would cause segment range %v to overlap segment range %v", end, seg.Range(), next.Range())) + } + seg.SetEndUnchecked(end) +} + +// Value returns a copy of the iterated segment's value. +func (seg vmaIterator) Value() vma { + return seg.node.values[seg.index] +} + +// ValuePtr returns a pointer to the iterated segment's value. The pointer is +// invalidated if the iterator is invalidated. This operation does not +// invalidate any iterators. +func (seg vmaIterator) ValuePtr() *vma { + return &seg.node.values[seg.index] +} + +// SetValue mutates the iterated segment's value. This operation does not +// invalidate any iterators. +func (seg vmaIterator) SetValue(val vma) { + seg.node.values[seg.index] = val +} + +// PrevSegment returns the iterated segment's predecessor. If there is no +// preceding segment, PrevSegment returns a terminal iterator. +func (seg vmaIterator) PrevSegment() vmaIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index].lastSegment() + } + if seg.index > 0 { + return vmaIterator{seg.node, seg.index - 1} + } + if seg.node.parent == nil { + return vmaIterator{} + } + return vmasegmentBeforePosition(seg.node.parent, seg.node.parentIndex) +} + +// NextSegment returns the iterated segment's successor. If there is no +// succeeding segment, NextSegment returns a terminal iterator. +func (seg vmaIterator) NextSegment() vmaIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index+1].firstSegment() + } + if seg.index < seg.node.nrSegments-1 { + return vmaIterator{seg.node, seg.index + 1} + } + if seg.node.parent == nil { + return vmaIterator{} + } + return vmasegmentAfterPosition(seg.node.parent, seg.node.parentIndex) +} + +// PrevGap returns the gap immediately before the iterated segment. +func (seg vmaIterator) PrevGap() vmaGapIterator { + if seg.node.hasChildren { + + return seg.node.children[seg.index].lastSegment().NextGap() + } + return vmaGapIterator{seg.node, seg.index} +} + +// NextGap returns the gap immediately after the iterated segment. +func (seg vmaIterator) NextGap() vmaGapIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index+1].firstSegment().PrevGap() + } + return vmaGapIterator{seg.node, seg.index + 1} +} + +// PrevNonEmpty returns the iterated segment's predecessor if it is adjacent, +// or the gap before the iterated segment otherwise. If seg.Start() == +// Functions.MinKey(), PrevNonEmpty will return two terminal iterators. +// Otherwise, exactly one of the iterators returned by PrevNonEmpty will be +// non-terminal. +func (seg vmaIterator) PrevNonEmpty() (vmaIterator, vmaGapIterator) { + gap := seg.PrevGap() + if gap.Range().Length() != 0 { + return vmaIterator{}, gap + } + return gap.PrevSegment(), vmaGapIterator{} +} + +// NextNonEmpty returns the iterated segment's successor if it is adjacent, or +// the gap after the iterated segment otherwise. If seg.End() == +// Functions.MaxKey(), NextNonEmpty will return two terminal iterators. +// Otherwise, exactly one of the iterators returned by NextNonEmpty will be +// non-terminal. +func (seg vmaIterator) NextNonEmpty() (vmaIterator, vmaGapIterator) { + gap := seg.NextGap() + if gap.Range().Length() != 0 { + return vmaIterator{}, gap + } + return gap.NextSegment(), vmaGapIterator{} +} + +// A GapIterator is conceptually one of: +// +// - A pointer to a position between two segments, before the first segment, or +// after the last segment in a set, called a *gap*; or +// +// - A terminal iterator, which is a sentinel indicating that the end of +// iteration has been reached. +// +// Note that the gap between two adjacent segments exists (iterators to it are +// non-terminal), but has a length of zero. GapIterator.IsEmpty returns true +// for such gaps. An empty set contains a single gap, spanning the entire range +// of the set's keys. +// +// GapIterators are copyable values and are meaningfully equality-comparable. +// The zero value of GapIterator is a terminal iterator. +// +// Unless otherwise specified, any mutation of a set invalidates all existing +// iterators into the set. +type vmaGapIterator struct { + // The representation of a GapIterator is identical to that of an Iterator, + // except that index corresponds to positions between segments in the same + // way as for node.children (see comment for node.nrSegments). + node *vmanode + index int +} + +// Ok returns true if the iterator is not terminal. All other methods are only +// valid for non-terminal iterators. +func (gap vmaGapIterator) Ok() bool { + return gap.node != nil +} + +// Range returns the range spanned by the iterated gap. +func (gap vmaGapIterator) Range() __generics_imported0.AddrRange { + return __generics_imported0.AddrRange{gap.Start(), gap.End()} +} + +// Start is equivalent to Range().Start, but should be preferred if only the +// start of the range is needed. +func (gap vmaGapIterator) Start() __generics_imported0.Addr { + if ps := gap.PrevSegment(); ps.Ok() { + return ps.End() + } + return vmaSetFunctions{}.MinKey() +} + +// End is equivalent to Range().End, but should be preferred if only the end of +// the range is needed. +func (gap vmaGapIterator) End() __generics_imported0.Addr { + if ns := gap.NextSegment(); ns.Ok() { + return ns.Start() + } + return vmaSetFunctions{}.MaxKey() +} + +// IsEmpty returns true if the iterated gap is empty (that is, the "gap" is +// between two adjacent segments.) +func (gap vmaGapIterator) IsEmpty() bool { + return gap.Range().Length() == 0 +} + +// PrevSegment returns the segment immediately before the iterated gap. If no +// such segment exists, PrevSegment returns a terminal iterator. +func (gap vmaGapIterator) PrevSegment() vmaIterator { + return vmasegmentBeforePosition(gap.node, gap.index) +} + +// NextSegment returns the segment immediately after the iterated gap. If no +// such segment exists, NextSegment returns a terminal iterator. +func (gap vmaGapIterator) NextSegment() vmaIterator { + return vmasegmentAfterPosition(gap.node, gap.index) +} + +// PrevGap returns the iterated gap's predecessor. If no such gap exists, +// PrevGap returns a terminal iterator. +func (gap vmaGapIterator) PrevGap() vmaGapIterator { + seg := gap.PrevSegment() + if !seg.Ok() { + return vmaGapIterator{} + } + return seg.PrevGap() +} + +// NextGap returns the iterated gap's successor. If no such gap exists, NextGap +// returns a terminal iterator. +func (gap vmaGapIterator) NextGap() vmaGapIterator { + seg := gap.NextSegment() + if !seg.Ok() { + return vmaGapIterator{} + } + return seg.NextGap() +} + +// NextLargeEnoughGap returns the iterated gap's first next gap with larger +// length than minSize. If not found, return a terminal gap iterator (does NOT +// include this gap itself). +// +// Precondition: trackGaps must be 1. +func (gap vmaGapIterator) NextLargeEnoughGap(minSize __generics_imported0.Addr) vmaGapIterator { + if vmatrackGaps != 1 { + panic("set is not tracking gaps") + } + if gap.node != nil && gap.node.hasChildren && gap.index == gap.node.nrSegments { + + gap.node = gap.NextSegment().node + gap.index = 0 + return gap.nextLargeEnoughGapHelper(minSize) + } + return gap.nextLargeEnoughGapHelper(minSize) +} + +// nextLargeEnoughGapHelper is the helper function used by NextLargeEnoughGap +// to do the real recursions. +// +// Preconditions: gap is NOT the trailing gap of a non-leaf node. +func (gap vmaGapIterator) nextLargeEnoughGapHelper(minSize __generics_imported0.Addr) vmaGapIterator { + + for gap.node != nil && + (gap.node.maxGap.Get() < minSize || (!gap.node.hasChildren && gap.index == gap.node.nrSegments)) { + gap.node, gap.index = gap.node.parent, gap.node.parentIndex + } + + if gap.node == nil { + return vmaGapIterator{} + } + + gap.index++ + for gap.index <= gap.node.nrSegments { + if gap.node.hasChildren { + if largeEnoughGap := gap.node.children[gap.index].searchFirstLargeEnoughGap(minSize); largeEnoughGap.Ok() { + return largeEnoughGap + } + } else { + if gap.Range().Length() >= minSize { + return gap + } + } + gap.index++ + } + gap.node, gap.index = gap.node.parent, gap.node.parentIndex + if gap.node != nil && gap.index == gap.node.nrSegments { + + gap.node, gap.index = gap.node.parent, gap.node.parentIndex + } + return gap.nextLargeEnoughGapHelper(minSize) +} + +// PrevLargeEnoughGap returns the iterated gap's first prev gap with larger or +// equal length than minSize. If not found, return a terminal gap iterator +// (does NOT include this gap itself). +// +// Precondition: trackGaps must be 1. +func (gap vmaGapIterator) PrevLargeEnoughGap(minSize __generics_imported0.Addr) vmaGapIterator { + if vmatrackGaps != 1 { + panic("set is not tracking gaps") + } + if gap.node != nil && gap.node.hasChildren && gap.index == 0 { + + gap.node = gap.PrevSegment().node + gap.index = gap.node.nrSegments + return gap.prevLargeEnoughGapHelper(minSize) + } + return gap.prevLargeEnoughGapHelper(minSize) +} + +// prevLargeEnoughGapHelper is the helper function used by PrevLargeEnoughGap +// to do the real recursions. +// +// Preconditions: gap is NOT the first gap of a non-leaf node. +func (gap vmaGapIterator) prevLargeEnoughGapHelper(minSize __generics_imported0.Addr) vmaGapIterator { + + for gap.node != nil && + (gap.node.maxGap.Get() < minSize || (!gap.node.hasChildren && gap.index == 0)) { + gap.node, gap.index = gap.node.parent, gap.node.parentIndex + } + + if gap.node == nil { + return vmaGapIterator{} + } + + gap.index-- + for gap.index >= 0 { + if gap.node.hasChildren { + if largeEnoughGap := gap.node.children[gap.index].searchLastLargeEnoughGap(minSize); largeEnoughGap.Ok() { + return largeEnoughGap + } + } else { + if gap.Range().Length() >= minSize { + return gap + } + } + gap.index-- + } + gap.node, gap.index = gap.node.parent, gap.node.parentIndex + if gap.node != nil && gap.index == 0 { + + gap.node, gap.index = gap.node.parent, gap.node.parentIndex + } + return gap.prevLargeEnoughGapHelper(minSize) +} + +// segmentBeforePosition returns the predecessor segment of the position given +// by n.children[i], which may or may not contain a child. If no such segment +// exists, segmentBeforePosition returns a terminal iterator. +func vmasegmentBeforePosition(n *vmanode, i int) vmaIterator { + for i == 0 { + if n.parent == nil { + return vmaIterator{} + } + n, i = n.parent, n.parentIndex + } + return vmaIterator{n, i - 1} +} + +// segmentAfterPosition returns the successor segment of the position given by +// n.children[i], which may or may not contain a child. If no such segment +// exists, segmentAfterPosition returns a terminal iterator. +func vmasegmentAfterPosition(n *vmanode, i int) vmaIterator { + for i == n.nrSegments { + if n.parent == nil { + return vmaIterator{} + } + n, i = n.parent, n.parentIndex + } + return vmaIterator{n, i} +} + +func vmazeroValueSlice(slice []vma) { + + for i := range slice { + vmaSetFunctions{}.ClearValue(&slice[i]) + } +} + +func vmazeroNodeSlice(slice []*vmanode) { + for i := range slice { + slice[i] = nil + } +} + +// String stringifies a Set for debugging. +func (s *vmaSet) String() string { + return s.root.String() +} + +// String stringifies a node (and all of its children) for debugging. +func (n *vmanode) String() string { + var buf bytes.Buffer + n.writeDebugString(&buf, "") + return buf.String() +} + +func (n *vmanode) writeDebugString(buf *bytes.Buffer, prefix string) { + if n.hasChildren != (n.nrSegments > 0 && n.children[0] != nil) { + buf.WriteString(prefix) + buf.WriteString(fmt.Sprintf("WARNING: inconsistent value of hasChildren: got %v, want %v\n", n.hasChildren, !n.hasChildren)) + } + for i := 0; i < n.nrSegments; i++ { + if child := n.children[i]; child != nil { + cprefix := fmt.Sprintf("%s- % 3d ", prefix, i) + if child.parent != n || child.parentIndex != i { + buf.WriteString(cprefix) + buf.WriteString(fmt.Sprintf("WARNING: inconsistent linkage to parent: got (%p, %d), want (%p, %d)\n", child.parent, child.parentIndex, n, i)) + } + child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, i)) + } + buf.WriteString(prefix) + if n.hasChildren { + if vmatrackGaps != 0 { + buf.WriteString(fmt.Sprintf("- % 3d: %v => %v, maxGap: %d\n", i, n.keys[i], n.values[i], n.maxGap.Get())) + } else { + buf.WriteString(fmt.Sprintf("- % 3d: %v => %v\n", i, n.keys[i], n.values[i])) + } + } else { + buf.WriteString(fmt.Sprintf("- % 3d: %v => %v\n", i, n.keys[i], n.values[i])) + } + } + if child := n.children[n.nrSegments]; child != nil { + child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, n.nrSegments)) + } +} + +// SegmentDataSlices represents segments from a set as slices of start, end, and +// values. SegmentDataSlices is primarily used as an intermediate representation +// for save/restore and the layout here is optimized for that. +// +// +stateify savable +type vmaSegmentDataSlices struct { + Start []__generics_imported0.Addr + End []__generics_imported0.Addr + Values []vma +} + +// ExportSortedSlices returns a copy of all segments in the given set, in +// ascending key order. +func (s *vmaSet) ExportSortedSlices() *vmaSegmentDataSlices { + var sds vmaSegmentDataSlices + for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + sds.Start = append(sds.Start, seg.Start()) + sds.End = append(sds.End, seg.End()) + sds.Values = append(sds.Values, seg.Value()) + } + sds.Start = sds.Start[:len(sds.Start):len(sds.Start)] + sds.End = sds.End[:len(sds.End):len(sds.End)] + sds.Values = sds.Values[:len(sds.Values):len(sds.Values)] + return &sds +} + +// ImportSortedSlices initializes the given set from the given slice. +// +// Preconditions: +// * s must be empty. +// * sds must represent a valid set (the segments in sds must have valid +// lengths that do not overlap). +// * The segments in sds must be sorted in ascending key order. +func (s *vmaSet) ImportSortedSlices(sds *vmaSegmentDataSlices) error { + if !s.IsEmpty() { + return fmt.Errorf("cannot import into non-empty set %v", s) + } + gap := s.FirstGap() + for i := range sds.Start { + r := __generics_imported0.AddrRange{sds.Start[i], sds.End[i]} + if !gap.Range().IsSupersetOf(r) { + return fmt.Errorf("segment overlaps a preceding segment or is incorrectly sorted: [%d, %d) => %v", sds.Start[i], sds.End[i], sds.Values[i]) + } + gap = s.InsertWithoutMerging(gap, r, sds.Values[i]).NextGap() + } + return nil +} + +// segmentTestCheck returns an error if s is incorrectly sorted, does not +// contain exactly expectedSegments segments, or contains a segment which +// fails the passed check. +// +// This should be used only for testing, and has been added to this package for +// templating convenience. +func (s *vmaSet) segmentTestCheck(expectedSegments int, segFunc func(int, __generics_imported0.AddrRange, vma) error) error { + havePrev := false + prev := __generics_imported0.Addr(0) + nrSegments := 0 + for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + next := seg.Start() + if havePrev && prev >= next { + return fmt.Errorf("incorrect order: key %d (segment %d) >= key %d (segment %d)", prev, nrSegments-1, next, nrSegments) + } + if segFunc != nil { + if err := segFunc(nrSegments, seg.Range(), seg.Value()); err != nil { + return err + } + } + prev = next + havePrev = true + nrSegments++ + } + if nrSegments != expectedSegments { + return fmt.Errorf("incorrect number of segments: got %d, wanted %d", nrSegments, expectedSegments) + } + return nil +} + +// countSegments counts the number of segments in the set. +// +// Similar to Check, this should only be used for testing. +func (s *vmaSet) countSegments() (segments int) { + for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + segments++ + } + return segments +} +func (s *vmaSet) saveRoot() *vmaSegmentDataSlices { + return s.ExportSortedSlices() +} + +func (s *vmaSet) loadRoot(sds *vmaSegmentDataSlices) { + if err := s.ImportSortedSlices(sds); err != nil { + panic(err) + } +} diff --git a/pkg/sentry/pgalloc/BUILD b/pkg/sentry/pgalloc/BUILD deleted file mode 100644 index 496a9fd97..000000000 --- a/pkg/sentry/pgalloc/BUILD +++ /dev/null @@ -1,111 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -package(licenses = ["notice"]) - -go_template_instance( - name = "evictable_range", - out = "evictable_range.go", - package = "pgalloc", - prefix = "Evictable", - template = "//pkg/segment:generic_range", - types = { - "T": "uint64", - }, -) - -go_template_instance( - name = "evictable_range_set", - out = "evictable_range_set.go", - package = "pgalloc", - prefix = "evictableRange", - template = "//pkg/segment:generic_set", - types = { - "Key": "uint64", - "Range": "EvictableRange", - "Value": "evictableRangeSetValue", - "Functions": "evictableRangeSetFunctions", - }, -) - -go_template_instance( - name = "usage_set", - out = "usage_set.go", - consts = { - "minDegree": "10", - "trackGaps": "1", - }, - imports = { - "memmap": "gvisor.dev/gvisor/pkg/sentry/memmap", - }, - package = "pgalloc", - prefix = "usage", - template = "//pkg/segment:generic_set", - types = { - "Key": "uint64", - "Range": "memmap.FileRange", - "Value": "usageInfo", - "Functions": "usageSetFunctions", - }, -) - -go_template_instance( - name = "reclaim_set", - out = "reclaim_set.go", - consts = { - "minDegree": "10", - }, - imports = { - "memmap": "gvisor.dev/gvisor/pkg/sentry/memmap", - }, - package = "pgalloc", - prefix = "reclaim", - template = "//pkg/segment:generic_set", - types = { - "Key": "uint64", - "Range": "memmap.FileRange", - "Value": "reclaimSetValue", - "Functions": "reclaimSetFunctions", - }, -) - -go_library( - name = "pgalloc", - srcs = [ - "context.go", - "evictable_range.go", - "evictable_range_set.go", - "pgalloc.go", - "pgalloc_unsafe.go", - "reclaim_set.go", - "save_restore.go", - "usage_set.go", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/errors/linuxerr", - "//pkg/hostarch", - "//pkg/log", - "//pkg/memutil", - "//pkg/safemem", - "//pkg/sentry/arch", - "//pkg/sentry/hostmm", - "//pkg/sentry/memmap", - "//pkg/sentry/usage", - "//pkg/state", - "//pkg/state/wire", - "//pkg/sync", - "//pkg/usermem", - "@org_golang_x_sys//unix:go_default_library", - ], -) - -go_test( - name = "pgalloc_test", - size = "small", - srcs = ["pgalloc_test.go"], - library = ":pgalloc", - deps = ["//pkg/hostarch"], -) diff --git a/pkg/sentry/pgalloc/evictable_range.go b/pkg/sentry/pgalloc/evictable_range.go new file mode 100644 index 000000000..79f513ac2 --- /dev/null +++ b/pkg/sentry/pgalloc/evictable_range.go @@ -0,0 +1,76 @@ +package pgalloc + +// A Range represents a contiguous range of T. +// +// +stateify savable +type EvictableRange struct { + // Start is the inclusive start of the range. + Start uint64 + + // End is the exclusive end of the range. + End uint64 +} + +// WellFormed returns true if r.Start <= r.End. All other methods on a Range +// require that the Range is well-formed. +// +//go:nosplit +func (r EvictableRange) WellFormed() bool { + return r.Start <= r.End +} + +// Length returns the length of the range. +// +//go:nosplit +func (r EvictableRange) Length() uint64 { + return r.End - r.Start +} + +// Contains returns true if r contains x. +// +//go:nosplit +func (r EvictableRange) Contains(x uint64) bool { + return r.Start <= x && x < r.End +} + +// Overlaps returns true if r and r2 overlap. +// +//go:nosplit +func (r EvictableRange) Overlaps(r2 EvictableRange) bool { + return r.Start < r2.End && r2.Start < r.End +} + +// IsSupersetOf returns true if r is a superset of r2; that is, the range r2 is +// contained within r. +// +//go:nosplit +func (r EvictableRange) IsSupersetOf(r2 EvictableRange) bool { + return r.Start <= r2.Start && r.End >= r2.End +} + +// Intersect returns a range consisting of the intersection between r and r2. +// If r and r2 do not overlap, Intersect returns a range with unspecified +// bounds, but for which Length() == 0. +// +//go:nosplit +func (r EvictableRange) Intersect(r2 EvictableRange) EvictableRange { + if r.Start < r2.Start { + r.Start = r2.Start + } + if r.End > r2.End { + r.End = r2.End + } + if r.End < r.Start { + r.End = r.Start + } + return r +} + +// CanSplitAt returns true if it is legal to split a segment spanning the range +// r at x; that is, splitting at x would produce two ranges, both of which have +// non-zero length. +// +//go:nosplit +func (r EvictableRange) CanSplitAt(x uint64) bool { + return r.Contains(x) && r.Start < x +} diff --git a/pkg/sentry/pgalloc/evictable_range_set.go b/pkg/sentry/pgalloc/evictable_range_set.go new file mode 100644 index 000000000..c0c712b21 --- /dev/null +++ b/pkg/sentry/pgalloc/evictable_range_set.go @@ -0,0 +1,1643 @@ +package pgalloc + +import ( + "bytes" + "fmt" +) + +// trackGaps is an optional parameter. +// +// If trackGaps is 1, the Set will track maximum gap size recursively, +// enabling the GapIterator.{Prev,Next}LargeEnoughGap functions. In this +// case, Key must be an unsigned integer. +// +// trackGaps must be 0 or 1. +const evictableRangetrackGaps = 0 + +var _ = uint8(evictableRangetrackGaps << 7) // Will fail if not zero or one. + +// dynamicGap is a type that disappears if trackGaps is 0. +type evictableRangedynamicGap [evictableRangetrackGaps]uint64 + +// Get returns the value of the gap. +// +// Precondition: trackGaps must be non-zero. +func (d *evictableRangedynamicGap) Get() uint64 { + return d[:][0] +} + +// Set sets the value of the gap. +// +// Precondition: trackGaps must be non-zero. +func (d *evictableRangedynamicGap) Set(v uint64) { + d[:][0] = v +} + +const ( + // minDegree is the minimum degree of an internal node in a Set B-tree. + // + // - Any non-root node has at least minDegree-1 segments. + // + // - Any non-root internal (non-leaf) node has at least minDegree children. + // + // - The root node may have fewer than minDegree-1 segments, but it may + // only have 0 segments if the tree is empty. + // + // Our implementation requires minDegree >= 3. Higher values of minDegree + // usually improve performance, but increase memory usage for small sets. + evictableRangeminDegree = 3 + + evictableRangemaxDegree = 2 * evictableRangeminDegree +) + +// A Set is a mapping of segments with non-overlapping Range keys. The zero +// value for a Set is an empty set. Set values are not safely movable nor +// copyable. Set is thread-compatible. +// +// +stateify savable +type evictableRangeSet struct { + root evictableRangenode `state:".(*evictableRangeSegmentDataSlices)"` +} + +// IsEmpty returns true if the set contains no segments. +func (s *evictableRangeSet) IsEmpty() bool { + return s.root.nrSegments == 0 +} + +// IsEmptyRange returns true iff no segments in the set overlap the given +// range. This is semantically equivalent to s.SpanRange(r) == 0, but may be +// more efficient. +func (s *evictableRangeSet) IsEmptyRange(r EvictableRange) bool { + switch { + case r.Length() < 0: + panic(fmt.Sprintf("invalid range %v", r)) + case r.Length() == 0: + return true + } + _, gap := s.Find(r.Start) + if !gap.Ok() { + return false + } + return r.End <= gap.End() +} + +// Span returns the total size of all segments in the set. +func (s *evictableRangeSet) Span() uint64 { + var sz uint64 + for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + sz += seg.Range().Length() + } + return sz +} + +// SpanRange returns the total size of the intersection of segments in the set +// with the given range. +func (s *evictableRangeSet) SpanRange(r EvictableRange) uint64 { + switch { + case r.Length() < 0: + panic(fmt.Sprintf("invalid range %v", r)) + case r.Length() == 0: + return 0 + } + var sz uint64 + for seg := s.LowerBoundSegment(r.Start); seg.Ok() && seg.Start() < r.End; seg = seg.NextSegment() { + sz += seg.Range().Intersect(r).Length() + } + return sz +} + +// FirstSegment returns the first segment in the set. If the set is empty, +// FirstSegment returns a terminal iterator. +func (s *evictableRangeSet) FirstSegment() evictableRangeIterator { + if s.root.nrSegments == 0 { + return evictableRangeIterator{} + } + return s.root.firstSegment() +} + +// LastSegment returns the last segment in the set. If the set is empty, +// LastSegment returns a terminal iterator. +func (s *evictableRangeSet) LastSegment() evictableRangeIterator { + if s.root.nrSegments == 0 { + return evictableRangeIterator{} + } + return s.root.lastSegment() +} + +// FirstGap returns the first gap in the set. +func (s *evictableRangeSet) FirstGap() evictableRangeGapIterator { + n := &s.root + for n.hasChildren { + n = n.children[0] + } + return evictableRangeGapIterator{n, 0} +} + +// LastGap returns the last gap in the set. +func (s *evictableRangeSet) LastGap() evictableRangeGapIterator { + n := &s.root + for n.hasChildren { + n = n.children[n.nrSegments] + } + return evictableRangeGapIterator{n, n.nrSegments} +} + +// Find returns the segment or gap whose range contains the given key. If a +// segment is found, the returned Iterator is non-terminal and the +// returned GapIterator is terminal. Otherwise, the returned Iterator is +// terminal and the returned GapIterator is non-terminal. +func (s *evictableRangeSet) Find(key uint64) (evictableRangeIterator, evictableRangeGapIterator) { + n := &s.root + for { + + lower := 0 + upper := n.nrSegments + for lower < upper { + i := lower + (upper-lower)/2 + if r := n.keys[i]; key < r.End { + if key >= r.Start { + return evictableRangeIterator{n, i}, evictableRangeGapIterator{} + } + upper = i + } else { + lower = i + 1 + } + } + i := lower + if !n.hasChildren { + return evictableRangeIterator{}, evictableRangeGapIterator{n, i} + } + n = n.children[i] + } +} + +// FindSegment returns the segment whose range contains the given key. If no +// such segment exists, FindSegment returns a terminal iterator. +func (s *evictableRangeSet) FindSegment(key uint64) evictableRangeIterator { + seg, _ := s.Find(key) + return seg +} + +// LowerBoundSegment returns the segment with the lowest range that contains a +// key greater than or equal to min. If no such segment exists, +// LowerBoundSegment returns a terminal iterator. +func (s *evictableRangeSet) LowerBoundSegment(min uint64) evictableRangeIterator { + seg, gap := s.Find(min) + if seg.Ok() { + return seg + } + return gap.NextSegment() +} + +// UpperBoundSegment returns the segment with the highest range that contains a +// key less than or equal to max. If no such segment exists, UpperBoundSegment +// returns a terminal iterator. +func (s *evictableRangeSet) UpperBoundSegment(max uint64) evictableRangeIterator { + seg, gap := s.Find(max) + if seg.Ok() { + return seg + } + return gap.PrevSegment() +} + +// FindGap returns the gap containing the given key. If no such gap exists +// (i.e. the set contains a segment containing that key), FindGap returns a +// terminal iterator. +func (s *evictableRangeSet) FindGap(key uint64) evictableRangeGapIterator { + _, gap := s.Find(key) + return gap +} + +// LowerBoundGap returns the gap with the lowest range that is greater than or +// equal to min. +func (s *evictableRangeSet) LowerBoundGap(min uint64) evictableRangeGapIterator { + seg, gap := s.Find(min) + if gap.Ok() { + return gap + } + return seg.NextGap() +} + +// UpperBoundGap returns the gap with the highest range that is less than or +// equal to max. +func (s *evictableRangeSet) UpperBoundGap(max uint64) evictableRangeGapIterator { + seg, gap := s.Find(max) + if gap.Ok() { + return gap + } + return seg.PrevGap() +} + +// Add inserts the given segment into the set and returns true. If the new +// segment can be merged with adjacent segments, Add will do so. If the new +// segment would overlap an existing segment, Add returns false. If Add +// succeeds, all existing iterators are invalidated. +func (s *evictableRangeSet) Add(r EvictableRange, val evictableRangeSetValue) bool { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + gap := s.FindGap(r.Start) + if !gap.Ok() { + return false + } + if r.End > gap.End() { + return false + } + s.Insert(gap, r, val) + return true +} + +// AddWithoutMerging inserts the given segment into the set and returns true. +// If it would overlap an existing segment, AddWithoutMerging does nothing and +// returns false. If AddWithoutMerging succeeds, all existing iterators are +// invalidated. +func (s *evictableRangeSet) AddWithoutMerging(r EvictableRange, val evictableRangeSetValue) bool { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + gap := s.FindGap(r.Start) + if !gap.Ok() { + return false + } + if r.End > gap.End() { + return false + } + s.InsertWithoutMergingUnchecked(gap, r, val) + return true +} + +// Insert inserts the given segment into the given gap. If the new segment can +// be merged with adjacent segments, Insert will do so. Insert returns an +// iterator to the segment containing the inserted value (which may have been +// merged with other values). All existing iterators (including gap, but not +// including the returned iterator) are invalidated. +// +// If the gap cannot accommodate the segment, or if r is invalid, Insert panics. +// +// Insert is semantically equivalent to a InsertWithoutMerging followed by a +// Merge, but may be more efficient. Note that there is no unchecked variant of +// Insert since Insert must retrieve and inspect gap's predecessor and +// successor segments regardless. +func (s *evictableRangeSet) Insert(gap evictableRangeGapIterator, r EvictableRange, val evictableRangeSetValue) evictableRangeIterator { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + prev, next := gap.PrevSegment(), gap.NextSegment() + if prev.Ok() && prev.End() > r.Start { + panic(fmt.Sprintf("new segment %v overlaps predecessor %v", r, prev.Range())) + } + if next.Ok() && next.Start() < r.End { + panic(fmt.Sprintf("new segment %v overlaps successor %v", r, next.Range())) + } + if prev.Ok() && prev.End() == r.Start { + if mval, ok := (evictableRangeSetFunctions{}).Merge(prev.Range(), prev.Value(), r, val); ok { + shrinkMaxGap := evictableRangetrackGaps != 0 && gap.Range().Length() == gap.node.maxGap.Get() + prev.SetEndUnchecked(r.End) + prev.SetValue(mval) + if shrinkMaxGap { + gap.node.updateMaxGapLeaf() + } + if next.Ok() && next.Start() == r.End { + val = mval + if mval, ok := (evictableRangeSetFunctions{}).Merge(prev.Range(), val, next.Range(), next.Value()); ok { + prev.SetEndUnchecked(next.End()) + prev.SetValue(mval) + return s.Remove(next).PrevSegment() + } + } + return prev + } + } + if next.Ok() && next.Start() == r.End { + if mval, ok := (evictableRangeSetFunctions{}).Merge(r, val, next.Range(), next.Value()); ok { + shrinkMaxGap := evictableRangetrackGaps != 0 && gap.Range().Length() == gap.node.maxGap.Get() + next.SetStartUnchecked(r.Start) + next.SetValue(mval) + if shrinkMaxGap { + gap.node.updateMaxGapLeaf() + } + return next + } + } + + return s.InsertWithoutMergingUnchecked(gap, r, val) +} + +// InsertWithoutMerging inserts the given segment into the given gap and +// returns an iterator to the inserted segment. All existing iterators +// (including gap, but not including the returned iterator) are invalidated. +// +// If the gap cannot accommodate the segment, or if r is invalid, +// InsertWithoutMerging panics. +func (s *evictableRangeSet) InsertWithoutMerging(gap evictableRangeGapIterator, r EvictableRange, val evictableRangeSetValue) evictableRangeIterator { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + if gr := gap.Range(); !gr.IsSupersetOf(r) { + panic(fmt.Sprintf("cannot insert segment range %v into gap range %v", r, gr)) + } + return s.InsertWithoutMergingUnchecked(gap, r, val) +} + +// InsertWithoutMergingUnchecked inserts the given segment into the given gap +// and returns an iterator to the inserted segment. All existing iterators +// (including gap, but not including the returned iterator) are invalidated. +// +// Preconditions: +// * r.Start >= gap.Start(). +// * r.End <= gap.End(). +func (s *evictableRangeSet) InsertWithoutMergingUnchecked(gap evictableRangeGapIterator, r EvictableRange, val evictableRangeSetValue) evictableRangeIterator { + gap = gap.node.rebalanceBeforeInsert(gap) + splitMaxGap := evictableRangetrackGaps != 0 && (gap.node.nrSegments == 0 || gap.Range().Length() == gap.node.maxGap.Get()) + copy(gap.node.keys[gap.index+1:], gap.node.keys[gap.index:gap.node.nrSegments]) + copy(gap.node.values[gap.index+1:], gap.node.values[gap.index:gap.node.nrSegments]) + gap.node.keys[gap.index] = r + gap.node.values[gap.index] = val + gap.node.nrSegments++ + if splitMaxGap { + gap.node.updateMaxGapLeaf() + } + return evictableRangeIterator{gap.node, gap.index} +} + +// Remove removes the given segment and returns an iterator to the vacated gap. +// All existing iterators (including seg, but not including the returned +// iterator) are invalidated. +func (s *evictableRangeSet) Remove(seg evictableRangeIterator) evictableRangeGapIterator { + + if seg.node.hasChildren { + + victim := seg.PrevSegment() + + seg.SetRangeUnchecked(victim.Range()) + seg.SetValue(victim.Value()) + + nextAdjacentNode := seg.NextSegment().node + if evictableRangetrackGaps != 0 { + nextAdjacentNode.updateMaxGapLeaf() + } + return s.Remove(victim).NextGap() + } + copy(seg.node.keys[seg.index:], seg.node.keys[seg.index+1:seg.node.nrSegments]) + copy(seg.node.values[seg.index:], seg.node.values[seg.index+1:seg.node.nrSegments]) + evictableRangeSetFunctions{}.ClearValue(&seg.node.values[seg.node.nrSegments-1]) + seg.node.nrSegments-- + if evictableRangetrackGaps != 0 { + seg.node.updateMaxGapLeaf() + } + return seg.node.rebalanceAfterRemove(evictableRangeGapIterator{seg.node, seg.index}) +} + +// RemoveAll removes all segments from the set. All existing iterators are +// invalidated. +func (s *evictableRangeSet) RemoveAll() { + s.root = evictableRangenode{} +} + +// RemoveRange removes all segments in the given range. An iterator to the +// newly formed gap is returned, and all existing iterators are invalidated. +func (s *evictableRangeSet) RemoveRange(r EvictableRange) evictableRangeGapIterator { + seg, gap := s.Find(r.Start) + if seg.Ok() { + seg = s.Isolate(seg, r) + gap = s.Remove(seg) + } + for seg = gap.NextSegment(); seg.Ok() && seg.Start() < r.End; seg = gap.NextSegment() { + seg = s.Isolate(seg, r) + gap = s.Remove(seg) + } + return gap +} + +// Merge attempts to merge two neighboring segments. If successful, Merge +// returns an iterator to the merged segment, and all existing iterators are +// invalidated. Otherwise, Merge returns a terminal iterator. +// +// If first is not the predecessor of second, Merge panics. +func (s *evictableRangeSet) Merge(first, second evictableRangeIterator) evictableRangeIterator { + if first.NextSegment() != second { + panic(fmt.Sprintf("attempt to merge non-neighboring segments %v, %v", first.Range(), second.Range())) + } + return s.MergeUnchecked(first, second) +} + +// MergeUnchecked attempts to merge two neighboring segments. If successful, +// MergeUnchecked returns an iterator to the merged segment, and all existing +// iterators are invalidated. Otherwise, MergeUnchecked returns a terminal +// iterator. +// +// Precondition: first is the predecessor of second: first.NextSegment() == +// second, first == second.PrevSegment(). +func (s *evictableRangeSet) MergeUnchecked(first, second evictableRangeIterator) evictableRangeIterator { + if first.End() == second.Start() { + if mval, ok := (evictableRangeSetFunctions{}).Merge(first.Range(), first.Value(), second.Range(), second.Value()); ok { + + first.SetEndUnchecked(second.End()) + first.SetValue(mval) + + return s.Remove(second).PrevSegment() + } + } + return evictableRangeIterator{} +} + +// MergeAll attempts to merge all adjacent segments in the set. All existing +// iterators are invalidated. +func (s *evictableRangeSet) MergeAll() { + seg := s.FirstSegment() + if !seg.Ok() { + return + } + next := seg.NextSegment() + for next.Ok() { + if mseg := s.MergeUnchecked(seg, next); mseg.Ok() { + seg, next = mseg, mseg.NextSegment() + } else { + seg, next = next, next.NextSegment() + } + } +} + +// MergeRange attempts to merge all adjacent segments that contain a key in the +// specific range. All existing iterators are invalidated. +func (s *evictableRangeSet) MergeRange(r EvictableRange) { + seg := s.LowerBoundSegment(r.Start) + if !seg.Ok() { + return + } + next := seg.NextSegment() + for next.Ok() && next.Range().Start < r.End { + if mseg := s.MergeUnchecked(seg, next); mseg.Ok() { + seg, next = mseg, mseg.NextSegment() + } else { + seg, next = next, next.NextSegment() + } + } +} + +// MergeAdjacent attempts to merge the segment containing r.Start with its +// predecessor, and the segment containing r.End-1 with its successor. +func (s *evictableRangeSet) MergeAdjacent(r EvictableRange) { + first := s.FindSegment(r.Start) + if first.Ok() { + if prev := first.PrevSegment(); prev.Ok() { + s.Merge(prev, first) + } + } + last := s.FindSegment(r.End - 1) + if last.Ok() { + if next := last.NextSegment(); next.Ok() { + s.Merge(last, next) + } + } +} + +// Split splits the given segment at the given key and returns iterators to the +// two resulting segments. All existing iterators (including seg, but not +// including the returned iterators) are invalidated. +// +// If the segment cannot be split at split (because split is at the start or +// end of the segment's range, so splitting would produce a segment with zero +// length, or because split falls outside the segment's range altogether), +// Split panics. +func (s *evictableRangeSet) Split(seg evictableRangeIterator, split uint64) (evictableRangeIterator, evictableRangeIterator) { + if !seg.Range().CanSplitAt(split) { + panic(fmt.Sprintf("can't split %v at %v", seg.Range(), split)) + } + return s.SplitUnchecked(seg, split) +} + +// SplitUnchecked splits the given segment at the given key and returns +// iterators to the two resulting segments. All existing iterators (including +// seg, but not including the returned iterators) are invalidated. +// +// Preconditions: seg.Start() < key < seg.End(). +func (s *evictableRangeSet) SplitUnchecked(seg evictableRangeIterator, split uint64) (evictableRangeIterator, evictableRangeIterator) { + val1, val2 := (evictableRangeSetFunctions{}).Split(seg.Range(), seg.Value(), split) + end2 := seg.End() + seg.SetEndUnchecked(split) + seg.SetValue(val1) + seg2 := s.InsertWithoutMergingUnchecked(seg.NextGap(), EvictableRange{split, end2}, val2) + + return seg2.PrevSegment(), seg2 +} + +// SplitAt splits the segment straddling split, if one exists. SplitAt returns +// true if a segment was split and false otherwise. If SplitAt splits a +// segment, all existing iterators are invalidated. +func (s *evictableRangeSet) SplitAt(split uint64) bool { + if seg := s.FindSegment(split); seg.Ok() && seg.Range().CanSplitAt(split) { + s.SplitUnchecked(seg, split) + return true + } + return false +} + +// Isolate ensures that the given segment's range does not escape r by +// splitting at r.Start and r.End if necessary, and returns an updated iterator +// to the bounded segment. All existing iterators (including seg, but not +// including the returned iterators) are invalidated. +func (s *evictableRangeSet) Isolate(seg evictableRangeIterator, r EvictableRange) evictableRangeIterator { + if seg.Range().CanSplitAt(r.Start) { + _, seg = s.SplitUnchecked(seg, r.Start) + } + if seg.Range().CanSplitAt(r.End) { + seg, _ = s.SplitUnchecked(seg, r.End) + } + return seg +} + +// ApplyContiguous applies a function to a contiguous range of segments, +// splitting if necessary. The function is applied until the first gap is +// encountered, at which point the gap is returned. If the function is applied +// across the entire range, a terminal gap is returned. All existing iterators +// are invalidated. +// +// N.B. The Iterator must not be invalidated by the function. +func (s *evictableRangeSet) ApplyContiguous(r EvictableRange, fn func(seg evictableRangeIterator)) evictableRangeGapIterator { + seg, gap := s.Find(r.Start) + if !seg.Ok() { + return gap + } + for { + seg = s.Isolate(seg, r) + fn(seg) + if seg.End() >= r.End { + return evictableRangeGapIterator{} + } + gap = seg.NextGap() + if !gap.IsEmpty() { + return gap + } + seg = gap.NextSegment() + if !seg.Ok() { + + return evictableRangeGapIterator{} + } + } +} + +// +stateify savable +type evictableRangenode struct { + // An internal binary tree node looks like: + // + // K + // / \ + // Cl Cr + // + // where all keys in the subtree rooted by Cl (the left subtree) are less + // than K (the key of the parent node), and all keys in the subtree rooted + // by Cr (the right subtree) are greater than K. + // + // An internal B-tree node's indexes work out to look like: + // + // K0 K1 K2 ... Kn-1 + // / \/ \/ \ ... / \ + // C0 C1 C2 C3 ... Cn-1 Cn + // + // where n is nrSegments. + nrSegments int + + // parent is a pointer to this node's parent. If this node is root, parent + // is nil. + parent *evictableRangenode + + // parentIndex is the index of this node in parent.children. + parentIndex int + + // Flag for internal nodes that is technically redundant with "children[0] + // != nil", but is stored in the first cache line. "hasChildren" rather + // than "isLeaf" because false must be the correct value for an empty root. + hasChildren bool + + // The longest gap within this node. If the node is a leaf, it's simply the + // maximum gap among all the (nrSegments+1) gaps formed by its nrSegments keys + // including the 0th and nrSegments-th gap possibly shared with its upper-level + // nodes; if it's a non-leaf node, it's the max of all children's maxGap. + maxGap evictableRangedynamicGap + + // Nodes store keys and values in separate arrays to maximize locality in + // the common case (scanning keys for lookup). + keys [evictableRangemaxDegree - 1]EvictableRange + values [evictableRangemaxDegree - 1]evictableRangeSetValue + children [evictableRangemaxDegree]*evictableRangenode +} + +// firstSegment returns the first segment in the subtree rooted by n. +// +// Preconditions: n.nrSegments != 0. +func (n *evictableRangenode) firstSegment() evictableRangeIterator { + for n.hasChildren { + n = n.children[0] + } + return evictableRangeIterator{n, 0} +} + +// lastSegment returns the last segment in the subtree rooted by n. +// +// Preconditions: n.nrSegments != 0. +func (n *evictableRangenode) lastSegment() evictableRangeIterator { + for n.hasChildren { + n = n.children[n.nrSegments] + } + return evictableRangeIterator{n, n.nrSegments - 1} +} + +func (n *evictableRangenode) prevSibling() *evictableRangenode { + if n.parent == nil || n.parentIndex == 0 { + return nil + } + return n.parent.children[n.parentIndex-1] +} + +func (n *evictableRangenode) nextSibling() *evictableRangenode { + if n.parent == nil || n.parentIndex == n.parent.nrSegments { + return nil + } + return n.parent.children[n.parentIndex+1] +} + +// rebalanceBeforeInsert splits n and its ancestors if they are full, as +// required for insertion, and returns an updated iterator to the position +// represented by gap. +func (n *evictableRangenode) rebalanceBeforeInsert(gap evictableRangeGapIterator) evictableRangeGapIterator { + if n.nrSegments < evictableRangemaxDegree-1 { + return gap + } + if n.parent != nil { + gap = n.parent.rebalanceBeforeInsert(gap) + } + if n.parent == nil { + + left := &evictableRangenode{ + nrSegments: evictableRangeminDegree - 1, + parent: n, + parentIndex: 0, + hasChildren: n.hasChildren, + } + right := &evictableRangenode{ + nrSegments: evictableRangeminDegree - 1, + parent: n, + parentIndex: 1, + hasChildren: n.hasChildren, + } + copy(left.keys[:evictableRangeminDegree-1], n.keys[:evictableRangeminDegree-1]) + copy(left.values[:evictableRangeminDegree-1], n.values[:evictableRangeminDegree-1]) + copy(right.keys[:evictableRangeminDegree-1], n.keys[evictableRangeminDegree:]) + copy(right.values[:evictableRangeminDegree-1], n.values[evictableRangeminDegree:]) + n.keys[0], n.values[0] = n.keys[evictableRangeminDegree-1], n.values[evictableRangeminDegree-1] + evictableRangezeroValueSlice(n.values[1:]) + if n.hasChildren { + copy(left.children[:evictableRangeminDegree], n.children[:evictableRangeminDegree]) + copy(right.children[:evictableRangeminDegree], n.children[evictableRangeminDegree:]) + evictableRangezeroNodeSlice(n.children[2:]) + for i := 0; i < evictableRangeminDegree; i++ { + left.children[i].parent = left + left.children[i].parentIndex = i + right.children[i].parent = right + right.children[i].parentIndex = i + } + } + n.nrSegments = 1 + n.hasChildren = true + n.children[0] = left + n.children[1] = right + + if evictableRangetrackGaps != 0 { + left.updateMaxGapLocal() + right.updateMaxGapLocal() + } + if gap.node != n { + return gap + } + if gap.index < evictableRangeminDegree { + return evictableRangeGapIterator{left, gap.index} + } + return evictableRangeGapIterator{right, gap.index - evictableRangeminDegree} + } + + copy(n.parent.keys[n.parentIndex+1:], n.parent.keys[n.parentIndex:n.parent.nrSegments]) + copy(n.parent.values[n.parentIndex+1:], n.parent.values[n.parentIndex:n.parent.nrSegments]) + n.parent.keys[n.parentIndex], n.parent.values[n.parentIndex] = n.keys[evictableRangeminDegree-1], n.values[evictableRangeminDegree-1] + copy(n.parent.children[n.parentIndex+2:], n.parent.children[n.parentIndex+1:n.parent.nrSegments+1]) + for i := n.parentIndex + 2; i < n.parent.nrSegments+2; i++ { + n.parent.children[i].parentIndex = i + } + sibling := &evictableRangenode{ + nrSegments: evictableRangeminDegree - 1, + parent: n.parent, + parentIndex: n.parentIndex + 1, + hasChildren: n.hasChildren, + } + n.parent.children[n.parentIndex+1] = sibling + n.parent.nrSegments++ + copy(sibling.keys[:evictableRangeminDegree-1], n.keys[evictableRangeminDegree:]) + copy(sibling.values[:evictableRangeminDegree-1], n.values[evictableRangeminDegree:]) + evictableRangezeroValueSlice(n.values[evictableRangeminDegree-1:]) + if n.hasChildren { + copy(sibling.children[:evictableRangeminDegree], n.children[evictableRangeminDegree:]) + evictableRangezeroNodeSlice(n.children[evictableRangeminDegree:]) + for i := 0; i < evictableRangeminDegree; i++ { + sibling.children[i].parent = sibling + sibling.children[i].parentIndex = i + } + } + n.nrSegments = evictableRangeminDegree - 1 + + if evictableRangetrackGaps != 0 { + n.updateMaxGapLocal() + sibling.updateMaxGapLocal() + } + + if gap.node != n { + return gap + } + if gap.index < evictableRangeminDegree { + return gap + } + return evictableRangeGapIterator{sibling, gap.index - evictableRangeminDegree} +} + +// rebalanceAfterRemove "unsplits" n and its ancestors if they are deficient +// (contain fewer segments than required by B-tree invariants), as required for +// removal, and returns an updated iterator to the position represented by gap. +// +// Precondition: n is the only node in the tree that may currently violate a +// B-tree invariant. +func (n *evictableRangenode) rebalanceAfterRemove(gap evictableRangeGapIterator) evictableRangeGapIterator { + for { + if n.nrSegments >= evictableRangeminDegree-1 { + return gap + } + if n.parent == nil { + + return gap + } + + if sibling := n.prevSibling(); sibling != nil && sibling.nrSegments >= evictableRangeminDegree { + copy(n.keys[1:], n.keys[:n.nrSegments]) + copy(n.values[1:], n.values[:n.nrSegments]) + n.keys[0] = n.parent.keys[n.parentIndex-1] + n.values[0] = n.parent.values[n.parentIndex-1] + n.parent.keys[n.parentIndex-1] = sibling.keys[sibling.nrSegments-1] + n.parent.values[n.parentIndex-1] = sibling.values[sibling.nrSegments-1] + evictableRangeSetFunctions{}.ClearValue(&sibling.values[sibling.nrSegments-1]) + if n.hasChildren { + copy(n.children[1:], n.children[:n.nrSegments+1]) + n.children[0] = sibling.children[sibling.nrSegments] + sibling.children[sibling.nrSegments] = nil + n.children[0].parent = n + n.children[0].parentIndex = 0 + for i := 1; i < n.nrSegments+2; i++ { + n.children[i].parentIndex = i + } + } + n.nrSegments++ + sibling.nrSegments-- + + if evictableRangetrackGaps != 0 { + n.updateMaxGapLocal() + sibling.updateMaxGapLocal() + } + if gap.node == sibling && gap.index == sibling.nrSegments { + return evictableRangeGapIterator{n, 0} + } + if gap.node == n { + return evictableRangeGapIterator{n, gap.index + 1} + } + return gap + } + if sibling := n.nextSibling(); sibling != nil && sibling.nrSegments >= evictableRangeminDegree { + n.keys[n.nrSegments] = n.parent.keys[n.parentIndex] + n.values[n.nrSegments] = n.parent.values[n.parentIndex] + n.parent.keys[n.parentIndex] = sibling.keys[0] + n.parent.values[n.parentIndex] = sibling.values[0] + copy(sibling.keys[:sibling.nrSegments-1], sibling.keys[1:]) + copy(sibling.values[:sibling.nrSegments-1], sibling.values[1:]) + evictableRangeSetFunctions{}.ClearValue(&sibling.values[sibling.nrSegments-1]) + if n.hasChildren { + n.children[n.nrSegments+1] = sibling.children[0] + copy(sibling.children[:sibling.nrSegments], sibling.children[1:]) + sibling.children[sibling.nrSegments] = nil + n.children[n.nrSegments+1].parent = n + n.children[n.nrSegments+1].parentIndex = n.nrSegments + 1 + for i := 0; i < sibling.nrSegments; i++ { + sibling.children[i].parentIndex = i + } + } + n.nrSegments++ + sibling.nrSegments-- + + if evictableRangetrackGaps != 0 { + n.updateMaxGapLocal() + sibling.updateMaxGapLocal() + } + if gap.node == sibling { + if gap.index == 0 { + return evictableRangeGapIterator{n, n.nrSegments} + } + return evictableRangeGapIterator{sibling, gap.index - 1} + } + return gap + } + + p := n.parent + if p.nrSegments == 1 { + + left, right := p.children[0], p.children[1] + p.nrSegments = left.nrSegments + right.nrSegments + 1 + p.hasChildren = left.hasChildren + p.keys[left.nrSegments] = p.keys[0] + p.values[left.nrSegments] = p.values[0] + copy(p.keys[:left.nrSegments], left.keys[:left.nrSegments]) + copy(p.values[:left.nrSegments], left.values[:left.nrSegments]) + copy(p.keys[left.nrSegments+1:], right.keys[:right.nrSegments]) + copy(p.values[left.nrSegments+1:], right.values[:right.nrSegments]) + if left.hasChildren { + copy(p.children[:left.nrSegments+1], left.children[:left.nrSegments+1]) + copy(p.children[left.nrSegments+1:], right.children[:right.nrSegments+1]) + for i := 0; i < p.nrSegments+1; i++ { + p.children[i].parent = p + p.children[i].parentIndex = i + } + } else { + p.children[0] = nil + p.children[1] = nil + } + + if gap.node == left { + return evictableRangeGapIterator{p, gap.index} + } + if gap.node == right { + return evictableRangeGapIterator{p, gap.index + left.nrSegments + 1} + } + return gap + } + // Merge n and either sibling, along with the segment separating the + // two, into whichever of the two nodes comes first. This is the + // reverse of the non-root splitting case in + // node.rebalanceBeforeInsert. + var left, right *evictableRangenode + if n.parentIndex > 0 { + left = n.prevSibling() + right = n + } else { + left = n + right = n.nextSibling() + } + + if gap.node == right { + gap = evictableRangeGapIterator{left, gap.index + left.nrSegments + 1} + } + left.keys[left.nrSegments] = p.keys[left.parentIndex] + left.values[left.nrSegments] = p.values[left.parentIndex] + copy(left.keys[left.nrSegments+1:], right.keys[:right.nrSegments]) + copy(left.values[left.nrSegments+1:], right.values[:right.nrSegments]) + if left.hasChildren { + copy(left.children[left.nrSegments+1:], right.children[:right.nrSegments+1]) + for i := left.nrSegments + 1; i < left.nrSegments+right.nrSegments+2; i++ { + left.children[i].parent = left + left.children[i].parentIndex = i + } + } + left.nrSegments += right.nrSegments + 1 + copy(p.keys[left.parentIndex:], p.keys[left.parentIndex+1:p.nrSegments]) + copy(p.values[left.parentIndex:], p.values[left.parentIndex+1:p.nrSegments]) + evictableRangeSetFunctions{}.ClearValue(&p.values[p.nrSegments-1]) + copy(p.children[left.parentIndex+1:], p.children[left.parentIndex+2:p.nrSegments+1]) + for i := 0; i < p.nrSegments; i++ { + p.children[i].parentIndex = i + } + p.children[p.nrSegments] = nil + p.nrSegments-- + + if evictableRangetrackGaps != 0 { + left.updateMaxGapLocal() + } + + n = p + } +} + +// updateMaxGapLeaf updates maxGap bottom-up from the calling leaf until no +// necessary update. +// +// Preconditions: n must be a leaf node, trackGaps must be 1. +func (n *evictableRangenode) updateMaxGapLeaf() { + if n.hasChildren { + panic(fmt.Sprintf("updateMaxGapLeaf should always be called on leaf node: %v", n)) + } + max := n.calculateMaxGapLeaf() + if max == n.maxGap.Get() { + + return + } + oldMax := n.maxGap.Get() + n.maxGap.Set(max) + if max > oldMax { + + for p := n.parent; p != nil; p = p.parent { + if p.maxGap.Get() >= max { + + break + } + + p.maxGap.Set(max) + } + return + } + + for p := n.parent; p != nil; p = p.parent { + if p.maxGap.Get() > oldMax { + + break + } + + parentNewMax := p.calculateMaxGapInternal() + if p.maxGap.Get() == parentNewMax { + + break + } + + p.maxGap.Set(parentNewMax) + } +} + +// updateMaxGapLocal updates maxGap of the calling node solely with no +// propagation to ancestor nodes. +// +// Precondition: trackGaps must be 1. +func (n *evictableRangenode) updateMaxGapLocal() { + if !n.hasChildren { + + n.maxGap.Set(n.calculateMaxGapLeaf()) + } else { + + n.maxGap.Set(n.calculateMaxGapInternal()) + } +} + +// calculateMaxGapLeaf iterates the gaps within a leaf node and calculate the +// max. +// +// Preconditions: n must be a leaf node. +func (n *evictableRangenode) calculateMaxGapLeaf() uint64 { + max := evictableRangeGapIterator{n, 0}.Range().Length() + for i := 1; i <= n.nrSegments; i++ { + if current := (evictableRangeGapIterator{n, i}).Range().Length(); current > max { + max = current + } + } + return max +} + +// calculateMaxGapInternal iterates children's maxGap within an internal node n +// and calculate the max. +// +// Preconditions: n must be a non-leaf node. +func (n *evictableRangenode) calculateMaxGapInternal() uint64 { + max := n.children[0].maxGap.Get() + for i := 1; i <= n.nrSegments; i++ { + if current := n.children[i].maxGap.Get(); current > max { + max = current + } + } + return max +} + +// searchFirstLargeEnoughGap returns the first gap having at least minSize length +// in the subtree rooted by n. If not found, return a terminal gap iterator. +func (n *evictableRangenode) searchFirstLargeEnoughGap(minSize uint64) evictableRangeGapIterator { + if n.maxGap.Get() < minSize { + return evictableRangeGapIterator{} + } + if n.hasChildren { + for i := 0; i <= n.nrSegments; i++ { + if largeEnoughGap := n.children[i].searchFirstLargeEnoughGap(minSize); largeEnoughGap.Ok() { + return largeEnoughGap + } + } + } else { + for i := 0; i <= n.nrSegments; i++ { + currentGap := evictableRangeGapIterator{n, i} + if currentGap.Range().Length() >= minSize { + return currentGap + } + } + } + panic(fmt.Sprintf("invalid maxGap in %v", n)) +} + +// searchLastLargeEnoughGap returns the last gap having at least minSize length +// in the subtree rooted by n. If not found, return a terminal gap iterator. +func (n *evictableRangenode) searchLastLargeEnoughGap(minSize uint64) evictableRangeGapIterator { + if n.maxGap.Get() < minSize { + return evictableRangeGapIterator{} + } + if n.hasChildren { + for i := n.nrSegments; i >= 0; i-- { + if largeEnoughGap := n.children[i].searchLastLargeEnoughGap(minSize); largeEnoughGap.Ok() { + return largeEnoughGap + } + } + } else { + for i := n.nrSegments; i >= 0; i-- { + currentGap := evictableRangeGapIterator{n, i} + if currentGap.Range().Length() >= minSize { + return currentGap + } + } + } + panic(fmt.Sprintf("invalid maxGap in %v", n)) +} + +// A Iterator is conceptually one of: +// +// - A pointer to a segment in a set; or +// +// - A terminal iterator, which is a sentinel indicating that the end of +// iteration has been reached. +// +// Iterators are copyable values and are meaningfully equality-comparable. The +// zero value of Iterator is a terminal iterator. +// +// Unless otherwise specified, any mutation of a set invalidates all existing +// iterators into the set. +type evictableRangeIterator struct { + // node is the node containing the iterated segment. If the iterator is + // terminal, node is nil. + node *evictableRangenode + + // index is the index of the segment in node.keys/values. + index int +} + +// Ok returns true if the iterator is not terminal. All other methods are only +// valid for non-terminal iterators. +func (seg evictableRangeIterator) Ok() bool { + return seg.node != nil +} + +// Range returns the iterated segment's range key. +func (seg evictableRangeIterator) Range() EvictableRange { + return seg.node.keys[seg.index] +} + +// Start is equivalent to Range().Start, but should be preferred if only the +// start of the range is needed. +func (seg evictableRangeIterator) Start() uint64 { + return seg.node.keys[seg.index].Start +} + +// End is equivalent to Range().End, but should be preferred if only the end of +// the range is needed. +func (seg evictableRangeIterator) End() uint64 { + return seg.node.keys[seg.index].End +} + +// SetRangeUnchecked mutates the iterated segment's range key. This operation +// does not invalidate any iterators. +// +// Preconditions: +// * r.Length() > 0. +// * The new range must not overlap an existing one: +// * If seg.NextSegment().Ok(), then r.end <= seg.NextSegment().Start(). +// * If seg.PrevSegment().Ok(), then r.start >= seg.PrevSegment().End(). +func (seg evictableRangeIterator) SetRangeUnchecked(r EvictableRange) { + seg.node.keys[seg.index] = r +} + +// SetRange mutates the iterated segment's range key. If the new range would +// cause the iterated segment to overlap another segment, or if the new range +// is invalid, SetRange panics. This operation does not invalidate any +// iterators. +func (seg evictableRangeIterator) SetRange(r EvictableRange) { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + if prev := seg.PrevSegment(); prev.Ok() && r.Start < prev.End() { + panic(fmt.Sprintf("new segment range %v overlaps segment range %v", r, prev.Range())) + } + if next := seg.NextSegment(); next.Ok() && r.End > next.Start() { + panic(fmt.Sprintf("new segment range %v overlaps segment range %v", r, next.Range())) + } + seg.SetRangeUnchecked(r) +} + +// SetStartUnchecked mutates the iterated segment's start. This operation does +// not invalidate any iterators. +// +// Preconditions: The new start must be valid: +// * start < seg.End() +// * If seg.PrevSegment().Ok(), then start >= seg.PrevSegment().End(). +func (seg evictableRangeIterator) SetStartUnchecked(start uint64) { + seg.node.keys[seg.index].Start = start +} + +// SetStart mutates the iterated segment's start. If the new start value would +// cause the iterated segment to overlap another segment, or would result in an +// invalid range, SetStart panics. This operation does not invalidate any +// iterators. +func (seg evictableRangeIterator) SetStart(start uint64) { + if start >= seg.End() { + panic(fmt.Sprintf("new start %v would invalidate segment range %v", start, seg.Range())) + } + if prev := seg.PrevSegment(); prev.Ok() && start < prev.End() { + panic(fmt.Sprintf("new start %v would cause segment range %v to overlap segment range %v", start, seg.Range(), prev.Range())) + } + seg.SetStartUnchecked(start) +} + +// SetEndUnchecked mutates the iterated segment's end. This operation does not +// invalidate any iterators. +// +// Preconditions: The new end must be valid: +// * end > seg.Start(). +// * If seg.NextSegment().Ok(), then end <= seg.NextSegment().Start(). +func (seg evictableRangeIterator) SetEndUnchecked(end uint64) { + seg.node.keys[seg.index].End = end +} + +// SetEnd mutates the iterated segment's end. If the new end value would cause +// the iterated segment to overlap another segment, or would result in an +// invalid range, SetEnd panics. This operation does not invalidate any +// iterators. +func (seg evictableRangeIterator) SetEnd(end uint64) { + if end <= seg.Start() { + panic(fmt.Sprintf("new end %v would invalidate segment range %v", end, seg.Range())) + } + if next := seg.NextSegment(); next.Ok() && end > next.Start() { + panic(fmt.Sprintf("new end %v would cause segment range %v to overlap segment range %v", end, seg.Range(), next.Range())) + } + seg.SetEndUnchecked(end) +} + +// Value returns a copy of the iterated segment's value. +func (seg evictableRangeIterator) Value() evictableRangeSetValue { + return seg.node.values[seg.index] +} + +// ValuePtr returns a pointer to the iterated segment's value. The pointer is +// invalidated if the iterator is invalidated. This operation does not +// invalidate any iterators. +func (seg evictableRangeIterator) ValuePtr() *evictableRangeSetValue { + return &seg.node.values[seg.index] +} + +// SetValue mutates the iterated segment's value. This operation does not +// invalidate any iterators. +func (seg evictableRangeIterator) SetValue(val evictableRangeSetValue) { + seg.node.values[seg.index] = val +} + +// PrevSegment returns the iterated segment's predecessor. If there is no +// preceding segment, PrevSegment returns a terminal iterator. +func (seg evictableRangeIterator) PrevSegment() evictableRangeIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index].lastSegment() + } + if seg.index > 0 { + return evictableRangeIterator{seg.node, seg.index - 1} + } + if seg.node.parent == nil { + return evictableRangeIterator{} + } + return evictableRangesegmentBeforePosition(seg.node.parent, seg.node.parentIndex) +} + +// NextSegment returns the iterated segment's successor. If there is no +// succeeding segment, NextSegment returns a terminal iterator. +func (seg evictableRangeIterator) NextSegment() evictableRangeIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index+1].firstSegment() + } + if seg.index < seg.node.nrSegments-1 { + return evictableRangeIterator{seg.node, seg.index + 1} + } + if seg.node.parent == nil { + return evictableRangeIterator{} + } + return evictableRangesegmentAfterPosition(seg.node.parent, seg.node.parentIndex) +} + +// PrevGap returns the gap immediately before the iterated segment. +func (seg evictableRangeIterator) PrevGap() evictableRangeGapIterator { + if seg.node.hasChildren { + + return seg.node.children[seg.index].lastSegment().NextGap() + } + return evictableRangeGapIterator{seg.node, seg.index} +} + +// NextGap returns the gap immediately after the iterated segment. +func (seg evictableRangeIterator) NextGap() evictableRangeGapIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index+1].firstSegment().PrevGap() + } + return evictableRangeGapIterator{seg.node, seg.index + 1} +} + +// PrevNonEmpty returns the iterated segment's predecessor if it is adjacent, +// or the gap before the iterated segment otherwise. If seg.Start() == +// Functions.MinKey(), PrevNonEmpty will return two terminal iterators. +// Otherwise, exactly one of the iterators returned by PrevNonEmpty will be +// non-terminal. +func (seg evictableRangeIterator) PrevNonEmpty() (evictableRangeIterator, evictableRangeGapIterator) { + gap := seg.PrevGap() + if gap.Range().Length() != 0 { + return evictableRangeIterator{}, gap + } + return gap.PrevSegment(), evictableRangeGapIterator{} +} + +// NextNonEmpty returns the iterated segment's successor if it is adjacent, or +// the gap after the iterated segment otherwise. If seg.End() == +// Functions.MaxKey(), NextNonEmpty will return two terminal iterators. +// Otherwise, exactly one of the iterators returned by NextNonEmpty will be +// non-terminal. +func (seg evictableRangeIterator) NextNonEmpty() (evictableRangeIterator, evictableRangeGapIterator) { + gap := seg.NextGap() + if gap.Range().Length() != 0 { + return evictableRangeIterator{}, gap + } + return gap.NextSegment(), evictableRangeGapIterator{} +} + +// A GapIterator is conceptually one of: +// +// - A pointer to a position between two segments, before the first segment, or +// after the last segment in a set, called a *gap*; or +// +// - A terminal iterator, which is a sentinel indicating that the end of +// iteration has been reached. +// +// Note that the gap between two adjacent segments exists (iterators to it are +// non-terminal), but has a length of zero. GapIterator.IsEmpty returns true +// for such gaps. An empty set contains a single gap, spanning the entire range +// of the set's keys. +// +// GapIterators are copyable values and are meaningfully equality-comparable. +// The zero value of GapIterator is a terminal iterator. +// +// Unless otherwise specified, any mutation of a set invalidates all existing +// iterators into the set. +type evictableRangeGapIterator struct { + // The representation of a GapIterator is identical to that of an Iterator, + // except that index corresponds to positions between segments in the same + // way as for node.children (see comment for node.nrSegments). + node *evictableRangenode + index int +} + +// Ok returns true if the iterator is not terminal. All other methods are only +// valid for non-terminal iterators. +func (gap evictableRangeGapIterator) Ok() bool { + return gap.node != nil +} + +// Range returns the range spanned by the iterated gap. +func (gap evictableRangeGapIterator) Range() EvictableRange { + return EvictableRange{gap.Start(), gap.End()} +} + +// Start is equivalent to Range().Start, but should be preferred if only the +// start of the range is needed. +func (gap evictableRangeGapIterator) Start() uint64 { + if ps := gap.PrevSegment(); ps.Ok() { + return ps.End() + } + return evictableRangeSetFunctions{}.MinKey() +} + +// End is equivalent to Range().End, but should be preferred if only the end of +// the range is needed. +func (gap evictableRangeGapIterator) End() uint64 { + if ns := gap.NextSegment(); ns.Ok() { + return ns.Start() + } + return evictableRangeSetFunctions{}.MaxKey() +} + +// IsEmpty returns true if the iterated gap is empty (that is, the "gap" is +// between two adjacent segments.) +func (gap evictableRangeGapIterator) IsEmpty() bool { + return gap.Range().Length() == 0 +} + +// PrevSegment returns the segment immediately before the iterated gap. If no +// such segment exists, PrevSegment returns a terminal iterator. +func (gap evictableRangeGapIterator) PrevSegment() evictableRangeIterator { + return evictableRangesegmentBeforePosition(gap.node, gap.index) +} + +// NextSegment returns the segment immediately after the iterated gap. If no +// such segment exists, NextSegment returns a terminal iterator. +func (gap evictableRangeGapIterator) NextSegment() evictableRangeIterator { + return evictableRangesegmentAfterPosition(gap.node, gap.index) +} + +// PrevGap returns the iterated gap's predecessor. If no such gap exists, +// PrevGap returns a terminal iterator. +func (gap evictableRangeGapIterator) PrevGap() evictableRangeGapIterator { + seg := gap.PrevSegment() + if !seg.Ok() { + return evictableRangeGapIterator{} + } + return seg.PrevGap() +} + +// NextGap returns the iterated gap's successor. If no such gap exists, NextGap +// returns a terminal iterator. +func (gap evictableRangeGapIterator) NextGap() evictableRangeGapIterator { + seg := gap.NextSegment() + if !seg.Ok() { + return evictableRangeGapIterator{} + } + return seg.NextGap() +} + +// NextLargeEnoughGap returns the iterated gap's first next gap with larger +// length than minSize. If not found, return a terminal gap iterator (does NOT +// include this gap itself). +// +// Precondition: trackGaps must be 1. +func (gap evictableRangeGapIterator) NextLargeEnoughGap(minSize uint64) evictableRangeGapIterator { + if evictableRangetrackGaps != 1 { + panic("set is not tracking gaps") + } + if gap.node != nil && gap.node.hasChildren && gap.index == gap.node.nrSegments { + + gap.node = gap.NextSegment().node + gap.index = 0 + return gap.nextLargeEnoughGapHelper(minSize) + } + return gap.nextLargeEnoughGapHelper(minSize) +} + +// nextLargeEnoughGapHelper is the helper function used by NextLargeEnoughGap +// to do the real recursions. +// +// Preconditions: gap is NOT the trailing gap of a non-leaf node. +func (gap evictableRangeGapIterator) nextLargeEnoughGapHelper(minSize uint64) evictableRangeGapIterator { + + for gap.node != nil && + (gap.node.maxGap.Get() < minSize || (!gap.node.hasChildren && gap.index == gap.node.nrSegments)) { + gap.node, gap.index = gap.node.parent, gap.node.parentIndex + } + + if gap.node == nil { + return evictableRangeGapIterator{} + } + + gap.index++ + for gap.index <= gap.node.nrSegments { + if gap.node.hasChildren { + if largeEnoughGap := gap.node.children[gap.index].searchFirstLargeEnoughGap(minSize); largeEnoughGap.Ok() { + return largeEnoughGap + } + } else { + if gap.Range().Length() >= minSize { + return gap + } + } + gap.index++ + } + gap.node, gap.index = gap.node.parent, gap.node.parentIndex + if gap.node != nil && gap.index == gap.node.nrSegments { + + gap.node, gap.index = gap.node.parent, gap.node.parentIndex + } + return gap.nextLargeEnoughGapHelper(minSize) +} + +// PrevLargeEnoughGap returns the iterated gap's first prev gap with larger or +// equal length than minSize. If not found, return a terminal gap iterator +// (does NOT include this gap itself). +// +// Precondition: trackGaps must be 1. +func (gap evictableRangeGapIterator) PrevLargeEnoughGap(minSize uint64) evictableRangeGapIterator { + if evictableRangetrackGaps != 1 { + panic("set is not tracking gaps") + } + if gap.node != nil && gap.node.hasChildren && gap.index == 0 { + + gap.node = gap.PrevSegment().node + gap.index = gap.node.nrSegments + return gap.prevLargeEnoughGapHelper(minSize) + } + return gap.prevLargeEnoughGapHelper(minSize) +} + +// prevLargeEnoughGapHelper is the helper function used by PrevLargeEnoughGap +// to do the real recursions. +// +// Preconditions: gap is NOT the first gap of a non-leaf node. +func (gap evictableRangeGapIterator) prevLargeEnoughGapHelper(minSize uint64) evictableRangeGapIterator { + + for gap.node != nil && + (gap.node.maxGap.Get() < minSize || (!gap.node.hasChildren && gap.index == 0)) { + gap.node, gap.index = gap.node.parent, gap.node.parentIndex + } + + if gap.node == nil { + return evictableRangeGapIterator{} + } + + gap.index-- + for gap.index >= 0 { + if gap.node.hasChildren { + if largeEnoughGap := gap.node.children[gap.index].searchLastLargeEnoughGap(minSize); largeEnoughGap.Ok() { + return largeEnoughGap + } + } else { + if gap.Range().Length() >= minSize { + return gap + } + } + gap.index-- + } + gap.node, gap.index = gap.node.parent, gap.node.parentIndex + if gap.node != nil && gap.index == 0 { + + gap.node, gap.index = gap.node.parent, gap.node.parentIndex + } + return gap.prevLargeEnoughGapHelper(minSize) +} + +// segmentBeforePosition returns the predecessor segment of the position given +// by n.children[i], which may or may not contain a child. If no such segment +// exists, segmentBeforePosition returns a terminal iterator. +func evictableRangesegmentBeforePosition(n *evictableRangenode, i int) evictableRangeIterator { + for i == 0 { + if n.parent == nil { + return evictableRangeIterator{} + } + n, i = n.parent, n.parentIndex + } + return evictableRangeIterator{n, i - 1} +} + +// segmentAfterPosition returns the successor segment of the position given by +// n.children[i], which may or may not contain a child. If no such segment +// exists, segmentAfterPosition returns a terminal iterator. +func evictableRangesegmentAfterPosition(n *evictableRangenode, i int) evictableRangeIterator { + for i == n.nrSegments { + if n.parent == nil { + return evictableRangeIterator{} + } + n, i = n.parent, n.parentIndex + } + return evictableRangeIterator{n, i} +} + +func evictableRangezeroValueSlice(slice []evictableRangeSetValue) { + + for i := range slice { + evictableRangeSetFunctions{}.ClearValue(&slice[i]) + } +} + +func evictableRangezeroNodeSlice(slice []*evictableRangenode) { + for i := range slice { + slice[i] = nil + } +} + +// String stringifies a Set for debugging. +func (s *evictableRangeSet) String() string { + return s.root.String() +} + +// String stringifies a node (and all of its children) for debugging. +func (n *evictableRangenode) String() string { + var buf bytes.Buffer + n.writeDebugString(&buf, "") + return buf.String() +} + +func (n *evictableRangenode) writeDebugString(buf *bytes.Buffer, prefix string) { + if n.hasChildren != (n.nrSegments > 0 && n.children[0] != nil) { + buf.WriteString(prefix) + buf.WriteString(fmt.Sprintf("WARNING: inconsistent value of hasChildren: got %v, want %v\n", n.hasChildren, !n.hasChildren)) + } + for i := 0; i < n.nrSegments; i++ { + if child := n.children[i]; child != nil { + cprefix := fmt.Sprintf("%s- % 3d ", prefix, i) + if child.parent != n || child.parentIndex != i { + buf.WriteString(cprefix) + buf.WriteString(fmt.Sprintf("WARNING: inconsistent linkage to parent: got (%p, %d), want (%p, %d)\n", child.parent, child.parentIndex, n, i)) + } + child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, i)) + } + buf.WriteString(prefix) + if n.hasChildren { + if evictableRangetrackGaps != 0 { + buf.WriteString(fmt.Sprintf("- % 3d: %v => %v, maxGap: %d\n", i, n.keys[i], n.values[i], n.maxGap.Get())) + } else { + buf.WriteString(fmt.Sprintf("- % 3d: %v => %v\n", i, n.keys[i], n.values[i])) + } + } else { + buf.WriteString(fmt.Sprintf("- % 3d: %v => %v\n", i, n.keys[i], n.values[i])) + } + } + if child := n.children[n.nrSegments]; child != nil { + child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, n.nrSegments)) + } +} + +// SegmentDataSlices represents segments from a set as slices of start, end, and +// values. SegmentDataSlices is primarily used as an intermediate representation +// for save/restore and the layout here is optimized for that. +// +// +stateify savable +type evictableRangeSegmentDataSlices struct { + Start []uint64 + End []uint64 + Values []evictableRangeSetValue +} + +// ExportSortedSlices returns a copy of all segments in the given set, in +// ascending key order. +func (s *evictableRangeSet) ExportSortedSlices() *evictableRangeSegmentDataSlices { + var sds evictableRangeSegmentDataSlices + for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + sds.Start = append(sds.Start, seg.Start()) + sds.End = append(sds.End, seg.End()) + sds.Values = append(sds.Values, seg.Value()) + } + sds.Start = sds.Start[:len(sds.Start):len(sds.Start)] + sds.End = sds.End[:len(sds.End):len(sds.End)] + sds.Values = sds.Values[:len(sds.Values):len(sds.Values)] + return &sds +} + +// ImportSortedSlices initializes the given set from the given slice. +// +// Preconditions: +// * s must be empty. +// * sds must represent a valid set (the segments in sds must have valid +// lengths that do not overlap). +// * The segments in sds must be sorted in ascending key order. +func (s *evictableRangeSet) ImportSortedSlices(sds *evictableRangeSegmentDataSlices) error { + if !s.IsEmpty() { + return fmt.Errorf("cannot import into non-empty set %v", s) + } + gap := s.FirstGap() + for i := range sds.Start { + r := EvictableRange{sds.Start[i], sds.End[i]} + if !gap.Range().IsSupersetOf(r) { + return fmt.Errorf("segment overlaps a preceding segment or is incorrectly sorted: [%d, %d) => %v", sds.Start[i], sds.End[i], sds.Values[i]) + } + gap = s.InsertWithoutMerging(gap, r, sds.Values[i]).NextGap() + } + return nil +} + +// segmentTestCheck returns an error if s is incorrectly sorted, does not +// contain exactly expectedSegments segments, or contains a segment which +// fails the passed check. +// +// This should be used only for testing, and has been added to this package for +// templating convenience. +func (s *evictableRangeSet) segmentTestCheck(expectedSegments int, segFunc func(int, EvictableRange, evictableRangeSetValue) error) error { + havePrev := false + prev := uint64(0) + nrSegments := 0 + for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + next := seg.Start() + if havePrev && prev >= next { + return fmt.Errorf("incorrect order: key %d (segment %d) >= key %d (segment %d)", prev, nrSegments-1, next, nrSegments) + } + if segFunc != nil { + if err := segFunc(nrSegments, seg.Range(), seg.Value()); err != nil { + return err + } + } + prev = next + havePrev = true + nrSegments++ + } + if nrSegments != expectedSegments { + return fmt.Errorf("incorrect number of segments: got %d, wanted %d", nrSegments, expectedSegments) + } + return nil +} + +// countSegments counts the number of segments in the set. +// +// Similar to Check, this should only be used for testing. +func (s *evictableRangeSet) countSegments() (segments int) { + for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + segments++ + } + return segments +} +func (s *evictableRangeSet) saveRoot() *evictableRangeSegmentDataSlices { + return s.ExportSortedSlices() +} + +func (s *evictableRangeSet) loadRoot(sds *evictableRangeSegmentDataSlices) { + if err := s.ImportSortedSlices(sds); err != nil { + panic(err) + } +} diff --git a/pkg/sentry/pgalloc/pgalloc_state_autogen.go b/pkg/sentry/pgalloc/pgalloc_state_autogen.go new file mode 100644 index 000000000..e1df238fb --- /dev/null +++ b/pkg/sentry/pgalloc/pgalloc_state_autogen.go @@ -0,0 +1,392 @@ +// automatically generated by stateify. + +package pgalloc + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (r *EvictableRange) StateTypeName() string { + return "pkg/sentry/pgalloc.EvictableRange" +} + +func (r *EvictableRange) StateFields() []string { + return []string{ + "Start", + "End", + } +} + +func (r *EvictableRange) beforeSave() {} + +// +checklocksignore +func (r *EvictableRange) StateSave(stateSinkObject state.Sink) { + r.beforeSave() + stateSinkObject.Save(0, &r.Start) + stateSinkObject.Save(1, &r.End) +} + +func (r *EvictableRange) afterLoad() {} + +// +checklocksignore +func (r *EvictableRange) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &r.Start) + stateSourceObject.Load(1, &r.End) +} + +func (s *evictableRangeSet) StateTypeName() string { + return "pkg/sentry/pgalloc.evictableRangeSet" +} + +func (s *evictableRangeSet) StateFields() []string { + return []string{ + "root", + } +} + +func (s *evictableRangeSet) beforeSave() {} + +// +checklocksignore +func (s *evictableRangeSet) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + var rootValue *evictableRangeSegmentDataSlices + rootValue = s.saveRoot() + stateSinkObject.SaveValue(0, rootValue) +} + +func (s *evictableRangeSet) afterLoad() {} + +// +checklocksignore +func (s *evictableRangeSet) StateLoad(stateSourceObject state.Source) { + stateSourceObject.LoadValue(0, new(*evictableRangeSegmentDataSlices), func(y interface{}) { s.loadRoot(y.(*evictableRangeSegmentDataSlices)) }) +} + +func (n *evictableRangenode) StateTypeName() string { + return "pkg/sentry/pgalloc.evictableRangenode" +} + +func (n *evictableRangenode) StateFields() []string { + return []string{ + "nrSegments", + "parent", + "parentIndex", + "hasChildren", + "maxGap", + "keys", + "values", + "children", + } +} + +func (n *evictableRangenode) beforeSave() {} + +// +checklocksignore +func (n *evictableRangenode) StateSave(stateSinkObject state.Sink) { + n.beforeSave() + stateSinkObject.Save(0, &n.nrSegments) + stateSinkObject.Save(1, &n.parent) + stateSinkObject.Save(2, &n.parentIndex) + stateSinkObject.Save(3, &n.hasChildren) + stateSinkObject.Save(4, &n.maxGap) + stateSinkObject.Save(5, &n.keys) + stateSinkObject.Save(6, &n.values) + stateSinkObject.Save(7, &n.children) +} + +func (n *evictableRangenode) afterLoad() {} + +// +checklocksignore +func (n *evictableRangenode) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &n.nrSegments) + stateSourceObject.Load(1, &n.parent) + stateSourceObject.Load(2, &n.parentIndex) + stateSourceObject.Load(3, &n.hasChildren) + stateSourceObject.Load(4, &n.maxGap) + stateSourceObject.Load(5, &n.keys) + stateSourceObject.Load(6, &n.values) + stateSourceObject.Load(7, &n.children) +} + +func (e *evictableRangeSegmentDataSlices) StateTypeName() string { + return "pkg/sentry/pgalloc.evictableRangeSegmentDataSlices" +} + +func (e *evictableRangeSegmentDataSlices) StateFields() []string { + return []string{ + "Start", + "End", + "Values", + } +} + +func (e *evictableRangeSegmentDataSlices) beforeSave() {} + +// +checklocksignore +func (e *evictableRangeSegmentDataSlices) StateSave(stateSinkObject state.Sink) { + e.beforeSave() + stateSinkObject.Save(0, &e.Start) + stateSinkObject.Save(1, &e.End) + stateSinkObject.Save(2, &e.Values) +} + +func (e *evictableRangeSegmentDataSlices) afterLoad() {} + +// +checklocksignore +func (e *evictableRangeSegmentDataSlices) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &e.Start) + stateSourceObject.Load(1, &e.End) + stateSourceObject.Load(2, &e.Values) +} + +func (u *usageInfo) StateTypeName() string { + return "pkg/sentry/pgalloc.usageInfo" +} + +func (u *usageInfo) StateFields() []string { + return []string{ + "kind", + "knownCommitted", + "refs", + } +} + +func (u *usageInfo) beforeSave() {} + +// +checklocksignore +func (u *usageInfo) StateSave(stateSinkObject state.Sink) { + u.beforeSave() + stateSinkObject.Save(0, &u.kind) + stateSinkObject.Save(1, &u.knownCommitted) + stateSinkObject.Save(2, &u.refs) +} + +func (u *usageInfo) afterLoad() {} + +// +checklocksignore +func (u *usageInfo) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &u.kind) + stateSourceObject.Load(1, &u.knownCommitted) + stateSourceObject.Load(2, &u.refs) +} + +func (s *reclaimSet) StateTypeName() string { + return "pkg/sentry/pgalloc.reclaimSet" +} + +func (s *reclaimSet) StateFields() []string { + return []string{ + "root", + } +} + +func (s *reclaimSet) beforeSave() {} + +// +checklocksignore +func (s *reclaimSet) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + var rootValue *reclaimSegmentDataSlices + rootValue = s.saveRoot() + stateSinkObject.SaveValue(0, rootValue) +} + +func (s *reclaimSet) afterLoad() {} + +// +checklocksignore +func (s *reclaimSet) StateLoad(stateSourceObject state.Source) { + stateSourceObject.LoadValue(0, new(*reclaimSegmentDataSlices), func(y interface{}) { s.loadRoot(y.(*reclaimSegmentDataSlices)) }) +} + +func (n *reclaimnode) StateTypeName() string { + return "pkg/sentry/pgalloc.reclaimnode" +} + +func (n *reclaimnode) StateFields() []string { + return []string{ + "nrSegments", + "parent", + "parentIndex", + "hasChildren", + "maxGap", + "keys", + "values", + "children", + } +} + +func (n *reclaimnode) beforeSave() {} + +// +checklocksignore +func (n *reclaimnode) StateSave(stateSinkObject state.Sink) { + n.beforeSave() + stateSinkObject.Save(0, &n.nrSegments) + stateSinkObject.Save(1, &n.parent) + stateSinkObject.Save(2, &n.parentIndex) + stateSinkObject.Save(3, &n.hasChildren) + stateSinkObject.Save(4, &n.maxGap) + stateSinkObject.Save(5, &n.keys) + stateSinkObject.Save(6, &n.values) + stateSinkObject.Save(7, &n.children) +} + +func (n *reclaimnode) afterLoad() {} + +// +checklocksignore +func (n *reclaimnode) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &n.nrSegments) + stateSourceObject.Load(1, &n.parent) + stateSourceObject.Load(2, &n.parentIndex) + stateSourceObject.Load(3, &n.hasChildren) + stateSourceObject.Load(4, &n.maxGap) + stateSourceObject.Load(5, &n.keys) + stateSourceObject.Load(6, &n.values) + stateSourceObject.Load(7, &n.children) +} + +func (r *reclaimSegmentDataSlices) StateTypeName() string { + return "pkg/sentry/pgalloc.reclaimSegmentDataSlices" +} + +func (r *reclaimSegmentDataSlices) StateFields() []string { + return []string{ + "Start", + "End", + "Values", + } +} + +func (r *reclaimSegmentDataSlices) beforeSave() {} + +// +checklocksignore +func (r *reclaimSegmentDataSlices) StateSave(stateSinkObject state.Sink) { + r.beforeSave() + stateSinkObject.Save(0, &r.Start) + stateSinkObject.Save(1, &r.End) + stateSinkObject.Save(2, &r.Values) +} + +func (r *reclaimSegmentDataSlices) afterLoad() {} + +// +checklocksignore +func (r *reclaimSegmentDataSlices) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &r.Start) + stateSourceObject.Load(1, &r.End) + stateSourceObject.Load(2, &r.Values) +} + +func (s *usageSet) StateTypeName() string { + return "pkg/sentry/pgalloc.usageSet" +} + +func (s *usageSet) StateFields() []string { + return []string{ + "root", + } +} + +func (s *usageSet) beforeSave() {} + +// +checklocksignore +func (s *usageSet) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + var rootValue *usageSegmentDataSlices + rootValue = s.saveRoot() + stateSinkObject.SaveValue(0, rootValue) +} + +func (s *usageSet) afterLoad() {} + +// +checklocksignore +func (s *usageSet) StateLoad(stateSourceObject state.Source) { + stateSourceObject.LoadValue(0, new(*usageSegmentDataSlices), func(y interface{}) { s.loadRoot(y.(*usageSegmentDataSlices)) }) +} + +func (n *usagenode) StateTypeName() string { + return "pkg/sentry/pgalloc.usagenode" +} + +func (n *usagenode) StateFields() []string { + return []string{ + "nrSegments", + "parent", + "parentIndex", + "hasChildren", + "maxGap", + "keys", + "values", + "children", + } +} + +func (n *usagenode) beforeSave() {} + +// +checklocksignore +func (n *usagenode) StateSave(stateSinkObject state.Sink) { + n.beforeSave() + stateSinkObject.Save(0, &n.nrSegments) + stateSinkObject.Save(1, &n.parent) + stateSinkObject.Save(2, &n.parentIndex) + stateSinkObject.Save(3, &n.hasChildren) + stateSinkObject.Save(4, &n.maxGap) + stateSinkObject.Save(5, &n.keys) + stateSinkObject.Save(6, &n.values) + stateSinkObject.Save(7, &n.children) +} + +func (n *usagenode) afterLoad() {} + +// +checklocksignore +func (n *usagenode) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &n.nrSegments) + stateSourceObject.Load(1, &n.parent) + stateSourceObject.Load(2, &n.parentIndex) + stateSourceObject.Load(3, &n.hasChildren) + stateSourceObject.Load(4, &n.maxGap) + stateSourceObject.Load(5, &n.keys) + stateSourceObject.Load(6, &n.values) + stateSourceObject.Load(7, &n.children) +} + +func (u *usageSegmentDataSlices) StateTypeName() string { + return "pkg/sentry/pgalloc.usageSegmentDataSlices" +} + +func (u *usageSegmentDataSlices) StateFields() []string { + return []string{ + "Start", + "End", + "Values", + } +} + +func (u *usageSegmentDataSlices) beforeSave() {} + +// +checklocksignore +func (u *usageSegmentDataSlices) StateSave(stateSinkObject state.Sink) { + u.beforeSave() + stateSinkObject.Save(0, &u.Start) + stateSinkObject.Save(1, &u.End) + stateSinkObject.Save(2, &u.Values) +} + +func (u *usageSegmentDataSlices) afterLoad() {} + +// +checklocksignore +func (u *usageSegmentDataSlices) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &u.Start) + stateSourceObject.Load(1, &u.End) + stateSourceObject.Load(2, &u.Values) +} + +func init() { + state.Register((*EvictableRange)(nil)) + state.Register((*evictableRangeSet)(nil)) + state.Register((*evictableRangenode)(nil)) + state.Register((*evictableRangeSegmentDataSlices)(nil)) + state.Register((*usageInfo)(nil)) + state.Register((*reclaimSet)(nil)) + state.Register((*reclaimnode)(nil)) + state.Register((*reclaimSegmentDataSlices)(nil)) + state.Register((*usageSet)(nil)) + state.Register((*usagenode)(nil)) + state.Register((*usageSegmentDataSlices)(nil)) +} diff --git a/pkg/sentry/pgalloc/pgalloc_test.go b/pkg/sentry/pgalloc/pgalloc_test.go deleted file mode 100644 index 8d2b7eb5e..000000000 --- a/pkg/sentry/pgalloc/pgalloc_test.go +++ /dev/null @@ -1,246 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package pgalloc - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/hostarch" -) - -const ( - page = hostarch.PageSize - hugepage = hostarch.HugePageSize - topPage = (1 << 63) - page -) - -func TestFindUnallocatedRange(t *testing.T) { - for _, test := range []struct { - desc string - usage *usageSegmentDataSlices - fileSize int64 - length uint64 - alignment uint64 - start uint64 - expectFail bool - }{ - { - desc: "Initial allocation succeeds", - usage: &usageSegmentDataSlices{}, - length: page, - alignment: page, - start: chunkSize - page, // Grows by chunkSize, allocate down. - }, - { - desc: "Allocation finds empty space at start of file", - usage: &usageSegmentDataSlices{ - Start: []uint64{page}, - End: []uint64{2 * page}, - Values: []usageInfo{{refs: 1}}, - }, - fileSize: 2 * page, - length: page, - alignment: page, - start: 0, - }, - { - desc: "Allocation finds empty space at end of file", - usage: &usageSegmentDataSlices{ - Start: []uint64{0}, - End: []uint64{page}, - Values: []usageInfo{{refs: 1}}, - }, - fileSize: 2 * page, - length: page, - alignment: page, - start: page, - }, - { - desc: "In-use frames are not allocatable", - usage: &usageSegmentDataSlices{ - Start: []uint64{0, page}, - End: []uint64{page, 2 * page}, - Values: []usageInfo{{refs: 1}, {refs: 2}}, - }, - fileSize: 2 * page, - length: page, - alignment: page, - start: 3 * page, // Double fileSize, allocate top-down. - }, - { - desc: "Reclaimable frames are not allocatable", - usage: &usageSegmentDataSlices{ - Start: []uint64{0, page, 2 * page}, - End: []uint64{page, 2 * page, 3 * page}, - Values: []usageInfo{{refs: 1}, {refs: 0}, {refs: 1}}, - }, - fileSize: 3 * page, - length: page, - alignment: page, - start: 5 * page, // Double fileSize, grow down. - }, - { - desc: "Gaps between in-use frames are allocatable", - usage: &usageSegmentDataSlices{ - Start: []uint64{0, 2 * page}, - End: []uint64{page, 3 * page}, - Values: []usageInfo{{refs: 1}, {refs: 1}}, - }, - fileSize: 3 * page, - length: page, - alignment: page, - start: page, - }, - { - desc: "Inadequately-sized gaps are rejected", - usage: &usageSegmentDataSlices{ - Start: []uint64{0, 2 * page}, - End: []uint64{page, 3 * page}, - Values: []usageInfo{{refs: 1}, {refs: 1}}, - }, - fileSize: 3 * page, - length: 2 * page, - alignment: page, - start: 4 * page, // Double fileSize, grow down. - }, - { - desc: "Alignment is honored at end of file", - usage: &usageSegmentDataSlices{ - Start: []uint64{0, hugepage + page}, - // Hugepage-sized gap here that shouldn't be allocated from - // since it's incorrectly aligned. - End: []uint64{page, hugepage + 2*page}, - Values: []usageInfo{{refs: 1}, {refs: 1}}, - }, - fileSize: hugepage + 2*page, - length: hugepage, - alignment: hugepage, - start: 3 * hugepage, // Double fileSize until alignment is satisfied, grow down. - }, - { - desc: "Alignment is honored before end of file", - usage: &usageSegmentDataSlices{ - Start: []uint64{0, 2*hugepage + page}, - // Page will need to be shifted down from top. - End: []uint64{page, 2*hugepage + 2*page}, - Values: []usageInfo{{refs: 1}, {refs: 1}}, - }, - fileSize: 2*hugepage + 2*page, - length: hugepage, - alignment: hugepage, - start: hugepage, - }, - { - desc: "Allocation doubles file size more than once if necessary", - usage: &usageSegmentDataSlices{}, - fileSize: page, - length: 4 * page, - alignment: page, - start: 0, - }, - { - desc: "Allocations are compact if possible", - usage: &usageSegmentDataSlices{ - Start: []uint64{page, 3 * page}, - End: []uint64{2 * page, 4 * page}, - Values: []usageInfo{{refs: 1}, {refs: 2}}, - }, - fileSize: 4 * page, - length: page, - alignment: page, - start: 2 * page, - }, - { - desc: "Top-down allocation within one gap", - usage: &usageSegmentDataSlices{ - Start: []uint64{page, 4 * page, 7 * page}, - End: []uint64{2 * page, 5 * page, 8 * page}, - Values: []usageInfo{{refs: 1}, {refs: 2}, {refs: 1}}, - }, - fileSize: 8 * page, - length: page, - alignment: page, - start: 6 * page, - }, - { - desc: "Top-down allocation between multiple gaps", - usage: &usageSegmentDataSlices{ - Start: []uint64{page, 3 * page, 5 * page}, - End: []uint64{2 * page, 4 * page, 6 * page}, - Values: []usageInfo{{refs: 1}, {refs: 2}, {refs: 1}}, - }, - fileSize: 6 * page, - length: page, - alignment: page, - start: 4 * page, - }, - { - desc: "Top-down allocation with large top gap", - usage: &usageSegmentDataSlices{ - Start: []uint64{page, 3 * page}, - End: []uint64{2 * page, 4 * page}, - Values: []usageInfo{{refs: 1}, {refs: 2}}, - }, - fileSize: 8 * page, - length: page, - alignment: page, - start: 7 * page, - }, - { - desc: "Gaps found with possible overflow", - usage: &usageSegmentDataSlices{ - Start: []uint64{page, topPage - page}, - End: []uint64{2 * page, topPage}, - Values: []usageInfo{{refs: 1}, {refs: 1}}, - }, - fileSize: topPage, - length: page, - alignment: page, - start: topPage - 2*page, - }, - { - desc: "Overflow detected", - usage: &usageSegmentDataSlices{ - Start: []uint64{page}, - End: []uint64{topPage}, - Values: []usageInfo{{refs: 1}}, - }, - fileSize: topPage, - length: 2 * page, - alignment: page, - expectFail: true, - }, - } { - t.Run(test.desc, func(t *testing.T) { - var usage usageSet - if err := usage.ImportSortedSlices(test.usage); err != nil { - t.Fatalf("Failed to initialize usage from %v: %v", test.usage, err) - } - fr, ok := findAvailableRange(&usage, test.fileSize, test.length, test.alignment) - if !test.expectFail && !ok { - t.Fatalf("findAvailableRange(%v, %x, %x, %x): got %x, false wanted %x, true", test.usage, test.fileSize, test.length, test.alignment, fr.Start, test.start) - } - if test.expectFail && ok { - t.Fatalf("findAvailableRange(%v, %x, %x, %x): got %x, true wanted %x, false", test.usage, test.fileSize, test.length, test.alignment, fr.Start, test.start) - } - if ok && fr.Start != test.start { - t.Errorf("findAvailableRange(%v, %x, %x, %x): got start=%x, wanted %x", test.usage, test.fileSize, test.length, test.alignment, fr.Start, test.start) - } - if ok && fr.End != test.start+test.length { - t.Errorf("findAvailableRange(%v, %x, %x, %x): got end=%x, wanted %x", test.usage, test.fileSize, test.length, test.alignment, fr.End, test.start+test.length) - } - }) - } -} diff --git a/pkg/sentry/pgalloc/pgalloc_unsafe_state_autogen.go b/pkg/sentry/pgalloc/pgalloc_unsafe_state_autogen.go new file mode 100644 index 000000000..87c214008 --- /dev/null +++ b/pkg/sentry/pgalloc/pgalloc_unsafe_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package pgalloc diff --git a/pkg/segment/set.go b/pkg/sentry/pgalloc/reclaim_set.go index fae6c363d..737f38776 100644 --- a/pkg/segment/set.go +++ b/pkg/sentry/pgalloc/reclaim_set.go @@ -1,41 +1,14 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package segment provides tools for working with collections of segments. A -// segment is a key-value mapping, where the key is a non-empty contiguous -// range of values of type Key, and the value is a single value of type Value. -// -// Clients using this package must use the go_template_instance rule in -// tools/go_generics/defs.bzl to create an instantiation of this -// template package, providing types to use in place of Key, Range, Value, and -// Functions. See pkg/segment/test/BUILD for a usage example. -package segment +package pgalloc + +import ( + __generics_imported0 "gvisor.dev/gvisor/pkg/sentry/memmap" +) import ( "bytes" "fmt" ) -// Key is a required type parameter that must be an integral type. -type Key uint64 - -// Range is a required type parameter equivalent to Range<Key>. -type Range interface{} - -// Value is a required type parameter. -type Value interface{} - // trackGaps is an optional parameter. // // If trackGaps is 1, the Set will track maximum gap size recursively, @@ -43,59 +16,27 @@ type Value interface{} // case, Key must be an unsigned integer. // // trackGaps must be 0 or 1. -const trackGaps = 0 +const reclaimtrackGaps = 0 -var _ = uint8(trackGaps << 7) // Will fail if not zero or one. +var _ = uint8(reclaimtrackGaps << 7) // Will fail if not zero or one. // dynamicGap is a type that disappears if trackGaps is 0. -type dynamicGap [trackGaps]Key +type reclaimdynamicGap [reclaimtrackGaps]uint64 // Get returns the value of the gap. // // Precondition: trackGaps must be non-zero. -func (d *dynamicGap) Get() Key { +func (d *reclaimdynamicGap) Get() uint64 { return d[:][0] } // Set sets the value of the gap. // // Precondition: trackGaps must be non-zero. -func (d *dynamicGap) Set(v Key) { +func (d *reclaimdynamicGap) Set(v uint64) { d[:][0] = v } -// Functions is a required type parameter that must be a struct implementing -// the methods defined by Functions. -type Functions interface { - // MinKey returns the minimum allowed key. - MinKey() Key - - // MaxKey returns the maximum allowed key + 1. - MaxKey() Key - - // ClearValue deinitializes the given value. (For example, if Value is a - // pointer or interface type, ClearValue should set it to nil.) - ClearValue(*Value) - - // Merge attempts to merge the values corresponding to two consecutive - // segments. If successful, Merge returns (merged value, true). Otherwise, - // it returns (unspecified, false). - // - // Preconditions: r1.End == r2.Start. - // - // Postconditions: If merging succeeds, val1 and val2 are invalidated. - Merge(r1 Range, val1 Value, r2 Range, val2 Value) (Value, bool) - - // Split splits a segment's value at a key within its range, such that the - // first returned value corresponds to the range [r.Start, split) and the - // second returned value corresponds to the range [split, r.End). - // - // Preconditions: r.Start < split < r.End. - // - // Postconditions: The original value val is invalidated. - Split(r Range, val Value, split Key) (Value, Value) -} - const ( // minDegree is the minimum degree of an internal node in a Set B-tree. // @@ -108,9 +49,9 @@ const ( // // Our implementation requires minDegree >= 3. Higher values of minDegree // usually improve performance, but increase memory usage for small sets. - minDegree = 3 + reclaimminDegree = 10 - maxDegree = 2 * minDegree + reclaimmaxDegree = 2 * reclaimminDegree ) // A Set is a mapping of segments with non-overlapping Range keys. The zero @@ -118,19 +59,19 @@ const ( // copyable. Set is thread-compatible. // // +stateify savable -type Set struct { - root node `state:".(*SegmentDataSlices)"` +type reclaimSet struct { + root reclaimnode `state:".(*reclaimSegmentDataSlices)"` } // IsEmpty returns true if the set contains no segments. -func (s *Set) IsEmpty() bool { +func (s *reclaimSet) IsEmpty() bool { return s.root.nrSegments == 0 } // IsEmptyRange returns true iff no segments in the set overlap the given // range. This is semantically equivalent to s.SpanRange(r) == 0, but may be // more efficient. -func (s *Set) IsEmptyRange(r Range) bool { +func (s *reclaimSet) IsEmptyRange(r __generics_imported0.FileRange) bool { switch { case r.Length() < 0: panic(fmt.Sprintf("invalid range %v", r)) @@ -145,8 +86,8 @@ func (s *Set) IsEmptyRange(r Range) bool { } // Span returns the total size of all segments in the set. -func (s *Set) Span() Key { - var sz Key +func (s *reclaimSet) Span() uint64 { + var sz uint64 for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { sz += seg.Range().Length() } @@ -155,14 +96,14 @@ func (s *Set) Span() Key { // SpanRange returns the total size of the intersection of segments in the set // with the given range. -func (s *Set) SpanRange(r Range) Key { +func (s *reclaimSet) SpanRange(r __generics_imported0.FileRange) uint64 { switch { case r.Length() < 0: panic(fmt.Sprintf("invalid range %v", r)) case r.Length() == 0: return 0 } - var sz Key + var sz uint64 for seg := s.LowerBoundSegment(r.Start); seg.Ok() && seg.Start() < r.End; seg = seg.NextSegment() { sz += seg.Range().Intersect(r).Length() } @@ -171,56 +112,55 @@ func (s *Set) SpanRange(r Range) Key { // FirstSegment returns the first segment in the set. If the set is empty, // FirstSegment returns a terminal iterator. -func (s *Set) FirstSegment() Iterator { +func (s *reclaimSet) FirstSegment() reclaimIterator { if s.root.nrSegments == 0 { - return Iterator{} + return reclaimIterator{} } return s.root.firstSegment() } // LastSegment returns the last segment in the set. If the set is empty, // LastSegment returns a terminal iterator. -func (s *Set) LastSegment() Iterator { +func (s *reclaimSet) LastSegment() reclaimIterator { if s.root.nrSegments == 0 { - return Iterator{} + return reclaimIterator{} } return s.root.lastSegment() } // FirstGap returns the first gap in the set. -func (s *Set) FirstGap() GapIterator { +func (s *reclaimSet) FirstGap() reclaimGapIterator { n := &s.root for n.hasChildren { n = n.children[0] } - return GapIterator{n, 0} + return reclaimGapIterator{n, 0} } // LastGap returns the last gap in the set. -func (s *Set) LastGap() GapIterator { +func (s *reclaimSet) LastGap() reclaimGapIterator { n := &s.root for n.hasChildren { n = n.children[n.nrSegments] } - return GapIterator{n, n.nrSegments} + return reclaimGapIterator{n, n.nrSegments} } // Find returns the segment or gap whose range contains the given key. If a // segment is found, the returned Iterator is non-terminal and the // returned GapIterator is terminal. Otherwise, the returned Iterator is // terminal and the returned GapIterator is non-terminal. -func (s *Set) Find(key Key) (Iterator, GapIterator) { +func (s *reclaimSet) Find(key uint64) (reclaimIterator, reclaimGapIterator) { n := &s.root for { - // Binary search invariant: the correct value of i lies within [lower, - // upper]. + lower := 0 upper := n.nrSegments for lower < upper { i := lower + (upper-lower)/2 if r := n.keys[i]; key < r.End { if key >= r.Start { - return Iterator{n, i}, GapIterator{} + return reclaimIterator{n, i}, reclaimGapIterator{} } upper = i } else { @@ -229,7 +169,7 @@ func (s *Set) Find(key Key) (Iterator, GapIterator) { } i := lower if !n.hasChildren { - return Iterator{}, GapIterator{n, i} + return reclaimIterator{}, reclaimGapIterator{n, i} } n = n.children[i] } @@ -237,7 +177,7 @@ func (s *Set) Find(key Key) (Iterator, GapIterator) { // FindSegment returns the segment whose range contains the given key. If no // such segment exists, FindSegment returns a terminal iterator. -func (s *Set) FindSegment(key Key) Iterator { +func (s *reclaimSet) FindSegment(key uint64) reclaimIterator { seg, _ := s.Find(key) return seg } @@ -245,7 +185,7 @@ func (s *Set) FindSegment(key Key) Iterator { // LowerBoundSegment returns the segment with the lowest range that contains a // key greater than or equal to min. If no such segment exists, // LowerBoundSegment returns a terminal iterator. -func (s *Set) LowerBoundSegment(min Key) Iterator { +func (s *reclaimSet) LowerBoundSegment(min uint64) reclaimIterator { seg, gap := s.Find(min) if seg.Ok() { return seg @@ -256,7 +196,7 @@ func (s *Set) LowerBoundSegment(min Key) Iterator { // UpperBoundSegment returns the segment with the highest range that contains a // key less than or equal to max. If no such segment exists, UpperBoundSegment // returns a terminal iterator. -func (s *Set) UpperBoundSegment(max Key) Iterator { +func (s *reclaimSet) UpperBoundSegment(max uint64) reclaimIterator { seg, gap := s.Find(max) if seg.Ok() { return seg @@ -267,14 +207,14 @@ func (s *Set) UpperBoundSegment(max Key) Iterator { // FindGap returns the gap containing the given key. If no such gap exists // (i.e. the set contains a segment containing that key), FindGap returns a // terminal iterator. -func (s *Set) FindGap(key Key) GapIterator { +func (s *reclaimSet) FindGap(key uint64) reclaimGapIterator { _, gap := s.Find(key) return gap } // LowerBoundGap returns the gap with the lowest range that is greater than or // equal to min. -func (s *Set) LowerBoundGap(min Key) GapIterator { +func (s *reclaimSet) LowerBoundGap(min uint64) reclaimGapIterator { seg, gap := s.Find(min) if gap.Ok() { return gap @@ -284,7 +224,7 @@ func (s *Set) LowerBoundGap(min Key) GapIterator { // UpperBoundGap returns the gap with the highest range that is less than or // equal to max. -func (s *Set) UpperBoundGap(max Key) GapIterator { +func (s *reclaimSet) UpperBoundGap(max uint64) reclaimGapIterator { seg, gap := s.Find(max) if gap.Ok() { return gap @@ -296,7 +236,7 @@ func (s *Set) UpperBoundGap(max Key) GapIterator { // segment can be merged with adjacent segments, Add will do so. If the new // segment would overlap an existing segment, Add returns false. If Add // succeeds, all existing iterators are invalidated. -func (s *Set) Add(r Range, val Value) bool { +func (s *reclaimSet) Add(r __generics_imported0.FileRange, val reclaimSetValue) bool { if r.Length() <= 0 { panic(fmt.Sprintf("invalid segment range %v", r)) } @@ -315,7 +255,7 @@ func (s *Set) Add(r Range, val Value) bool { // If it would overlap an existing segment, AddWithoutMerging does nothing and // returns false. If AddWithoutMerging succeeds, all existing iterators are // invalidated. -func (s *Set) AddWithoutMerging(r Range, val Value) bool { +func (s *reclaimSet) AddWithoutMerging(r __generics_imported0.FileRange, val reclaimSetValue) bool { if r.Length() <= 0 { panic(fmt.Sprintf("invalid segment range %v", r)) } @@ -342,7 +282,7 @@ func (s *Set) AddWithoutMerging(r Range, val Value) bool { // Merge, but may be more efficient. Note that there is no unchecked variant of // Insert since Insert must retrieve and inspect gap's predecessor and // successor segments regardless. -func (s *Set) Insert(gap GapIterator, r Range, val Value) Iterator { +func (s *reclaimSet) Insert(gap reclaimGapIterator, r __generics_imported0.FileRange, val reclaimSetValue) reclaimIterator { if r.Length() <= 0 { panic(fmt.Sprintf("invalid segment range %v", r)) } @@ -354,8 +294,8 @@ func (s *Set) Insert(gap GapIterator, r Range, val Value) Iterator { panic(fmt.Sprintf("new segment %v overlaps successor %v", r, next.Range())) } if prev.Ok() && prev.End() == r.Start { - if mval, ok := (Functions{}).Merge(prev.Range(), prev.Value(), r, val); ok { - shrinkMaxGap := trackGaps != 0 && gap.Range().Length() == gap.node.maxGap.Get() + if mval, ok := (reclaimSetFunctions{}).Merge(prev.Range(), prev.Value(), r, val); ok { + shrinkMaxGap := reclaimtrackGaps != 0 && gap.Range().Length() == gap.node.maxGap.Get() prev.SetEndUnchecked(r.End) prev.SetValue(mval) if shrinkMaxGap { @@ -363,7 +303,7 @@ func (s *Set) Insert(gap GapIterator, r Range, val Value) Iterator { } if next.Ok() && next.Start() == r.End { val = mval - if mval, ok := (Functions{}).Merge(prev.Range(), val, next.Range(), next.Value()); ok { + if mval, ok := (reclaimSetFunctions{}).Merge(prev.Range(), val, next.Range(), next.Value()); ok { prev.SetEndUnchecked(next.End()) prev.SetValue(mval) return s.Remove(next).PrevSegment() @@ -373,8 +313,8 @@ func (s *Set) Insert(gap GapIterator, r Range, val Value) Iterator { } } if next.Ok() && next.Start() == r.End { - if mval, ok := (Functions{}).Merge(r, val, next.Range(), next.Value()); ok { - shrinkMaxGap := trackGaps != 0 && gap.Range().Length() == gap.node.maxGap.Get() + if mval, ok := (reclaimSetFunctions{}).Merge(r, val, next.Range(), next.Value()); ok { + shrinkMaxGap := reclaimtrackGaps != 0 && gap.Range().Length() == gap.node.maxGap.Get() next.SetStartUnchecked(r.Start) next.SetValue(mval) if shrinkMaxGap { @@ -383,7 +323,7 @@ func (s *Set) Insert(gap GapIterator, r Range, val Value) Iterator { return next } } - // InsertWithoutMergingUnchecked will maintain maxGap if necessary. + return s.InsertWithoutMergingUnchecked(gap, r, val) } @@ -393,7 +333,7 @@ func (s *Set) Insert(gap GapIterator, r Range, val Value) Iterator { // // If the gap cannot accommodate the segment, or if r is invalid, // InsertWithoutMerging panics. -func (s *Set) InsertWithoutMerging(gap GapIterator, r Range, val Value) Iterator { +func (s *reclaimSet) InsertWithoutMerging(gap reclaimGapIterator, r __generics_imported0.FileRange, val reclaimSetValue) reclaimIterator { if r.Length() <= 0 { panic(fmt.Sprintf("invalid segment range %v", r)) } @@ -410,9 +350,9 @@ func (s *Set) InsertWithoutMerging(gap GapIterator, r Range, val Value) Iterator // Preconditions: // * r.Start >= gap.Start(). // * r.End <= gap.End(). -func (s *Set) InsertWithoutMergingUnchecked(gap GapIterator, r Range, val Value) Iterator { +func (s *reclaimSet) InsertWithoutMergingUnchecked(gap reclaimGapIterator, r __generics_imported0.FileRange, val reclaimSetValue) reclaimIterator { gap = gap.node.rebalanceBeforeInsert(gap) - splitMaxGap := trackGaps != 0 && (gap.node.nrSegments == 0 || gap.Range().Length() == gap.node.maxGap.Get()) + splitMaxGap := reclaimtrackGaps != 0 && (gap.node.nrSegments == 0 || gap.Range().Length() == gap.node.maxGap.Get()) copy(gap.node.keys[gap.index+1:], gap.node.keys[gap.index:gap.node.nrSegments]) copy(gap.node.values[gap.index+1:], gap.node.values[gap.index:gap.node.nrSegments]) gap.node.keys[gap.index] = r @@ -421,56 +361,46 @@ func (s *Set) InsertWithoutMergingUnchecked(gap GapIterator, r Range, val Value) if splitMaxGap { gap.node.updateMaxGapLeaf() } - return Iterator{gap.node, gap.index} + return reclaimIterator{gap.node, gap.index} } // Remove removes the given segment and returns an iterator to the vacated gap. // All existing iterators (including seg, but not including the returned // iterator) are invalidated. -func (s *Set) Remove(seg Iterator) GapIterator { - // We only want to remove directly from a leaf node. +func (s *reclaimSet) Remove(seg reclaimIterator) reclaimGapIterator { + if seg.node.hasChildren { - // Since seg.node has children, the removed segment must have a - // predecessor (at the end of the rightmost leaf of its left child - // subtree). Move the contents of that predecessor into the removed - // segment's position, and remove that predecessor instead. (We choose - // to steal the predecessor rather than the successor because removing - // from the end of a leaf node doesn't involve any copying unless - // merging is required.) + victim := seg.PrevSegment() - // This must be unchecked since until victim is removed, seg and victim - // overlap. + seg.SetRangeUnchecked(victim.Range()) seg.SetValue(victim.Value()) - // Need to update the nextAdjacentNode's maxGap because the gap in between - // must have been modified by updating seg.Range() to victim.Range(). - // seg.NextSegment() must exist since the last segment can't be in a - // non-leaf node. + nextAdjacentNode := seg.NextSegment().node - if trackGaps != 0 { + if reclaimtrackGaps != 0 { nextAdjacentNode.updateMaxGapLeaf() } return s.Remove(victim).NextGap() } copy(seg.node.keys[seg.index:], seg.node.keys[seg.index+1:seg.node.nrSegments]) copy(seg.node.values[seg.index:], seg.node.values[seg.index+1:seg.node.nrSegments]) - Functions{}.ClearValue(&seg.node.values[seg.node.nrSegments-1]) + reclaimSetFunctions{}.ClearValue(&seg.node.values[seg.node.nrSegments-1]) seg.node.nrSegments-- - if trackGaps != 0 { + if reclaimtrackGaps != 0 { seg.node.updateMaxGapLeaf() } - return seg.node.rebalanceAfterRemove(GapIterator{seg.node, seg.index}) + return seg.node.rebalanceAfterRemove(reclaimGapIterator{seg.node, seg.index}) } // RemoveAll removes all segments from the set. All existing iterators are // invalidated. -func (s *Set) RemoveAll() { - s.root = node{} +func (s *reclaimSet) RemoveAll() { + s.root = reclaimnode{} } // RemoveRange removes all segments in the given range. An iterator to the // newly formed gap is returned, and all existing iterators are invalidated. -func (s *Set) RemoveRange(r Range) GapIterator { +func (s *reclaimSet) RemoveRange(r __generics_imported0.FileRange) reclaimGapIterator { seg, gap := s.Find(r.Start) if seg.Ok() { seg = s.Isolate(seg, r) @@ -488,7 +418,7 @@ func (s *Set) RemoveRange(r Range) GapIterator { // invalidated. Otherwise, Merge returns a terminal iterator. // // If first is not the predecessor of second, Merge panics. -func (s *Set) Merge(first, second Iterator) Iterator { +func (s *reclaimSet) Merge(first, second reclaimIterator) reclaimIterator { if first.NextSegment() != second { panic(fmt.Sprintf("attempt to merge non-neighboring segments %v, %v", first.Range(), second.Range())) } @@ -502,23 +432,22 @@ func (s *Set) Merge(first, second Iterator) Iterator { // // Precondition: first is the predecessor of second: first.NextSegment() == // second, first == second.PrevSegment(). -func (s *Set) MergeUnchecked(first, second Iterator) Iterator { +func (s *reclaimSet) MergeUnchecked(first, second reclaimIterator) reclaimIterator { if first.End() == second.Start() { - if mval, ok := (Functions{}).Merge(first.Range(), first.Value(), second.Range(), second.Value()); ok { - // N.B. This must be unchecked because until s.Remove(second), first - // overlaps second. + if mval, ok := (reclaimSetFunctions{}).Merge(first.Range(), first.Value(), second.Range(), second.Value()); ok { + first.SetEndUnchecked(second.End()) first.SetValue(mval) - // Remove will handle the maxGap update if necessary. + return s.Remove(second).PrevSegment() } } - return Iterator{} + return reclaimIterator{} } // MergeAll attempts to merge all adjacent segments in the set. All existing // iterators are invalidated. -func (s *Set) MergeAll() { +func (s *reclaimSet) MergeAll() { seg := s.FirstSegment() if !seg.Ok() { return @@ -535,7 +464,7 @@ func (s *Set) MergeAll() { // MergeRange attempts to merge all adjacent segments that contain a key in the // specific range. All existing iterators are invalidated. -func (s *Set) MergeRange(r Range) { +func (s *reclaimSet) MergeRange(r __generics_imported0.FileRange) { seg := s.LowerBoundSegment(r.Start) if !seg.Ok() { return @@ -552,7 +481,7 @@ func (s *Set) MergeRange(r Range) { // MergeAdjacent attempts to merge the segment containing r.Start with its // predecessor, and the segment containing r.End-1 with its successor. -func (s *Set) MergeAdjacent(r Range) { +func (s *reclaimSet) MergeAdjacent(r __generics_imported0.FileRange) { first := s.FindSegment(r.Start) if first.Ok() { if prev := first.PrevSegment(); prev.Ok() { @@ -575,7 +504,7 @@ func (s *Set) MergeAdjacent(r Range) { // end of the segment's range, so splitting would produce a segment with zero // length, or because split falls outside the segment's range altogether), // Split panics. -func (s *Set) Split(seg Iterator, split Key) (Iterator, Iterator) { +func (s *reclaimSet) Split(seg reclaimIterator, split uint64) (reclaimIterator, reclaimIterator) { if !seg.Range().CanSplitAt(split) { panic(fmt.Sprintf("can't split %v at %v", seg.Range(), split)) } @@ -587,20 +516,20 @@ func (s *Set) Split(seg Iterator, split Key) (Iterator, Iterator) { // seg, but not including the returned iterators) are invalidated. // // Preconditions: seg.Start() < key < seg.End(). -func (s *Set) SplitUnchecked(seg Iterator, split Key) (Iterator, Iterator) { - val1, val2 := (Functions{}).Split(seg.Range(), seg.Value(), split) +func (s *reclaimSet) SplitUnchecked(seg reclaimIterator, split uint64) (reclaimIterator, reclaimIterator) { + val1, val2 := (reclaimSetFunctions{}).Split(seg.Range(), seg.Value(), split) end2 := seg.End() seg.SetEndUnchecked(split) seg.SetValue(val1) - seg2 := s.InsertWithoutMergingUnchecked(seg.NextGap(), Range{split, end2}, val2) - // seg may now be invalid due to the Insert. + seg2 := s.InsertWithoutMergingUnchecked(seg.NextGap(), __generics_imported0.FileRange{split, end2}, val2) + return seg2.PrevSegment(), seg2 } // SplitAt splits the segment straddling split, if one exists. SplitAt returns // true if a segment was split and false otherwise. If SplitAt splits a // segment, all existing iterators are invalidated. -func (s *Set) SplitAt(split Key) bool { +func (s *reclaimSet) SplitAt(split uint64) bool { if seg := s.FindSegment(split); seg.Ok() && seg.Range().CanSplitAt(split) { s.SplitUnchecked(seg, split) return true @@ -612,7 +541,7 @@ func (s *Set) SplitAt(split Key) bool { // splitting at r.Start and r.End if necessary, and returns an updated iterator // to the bounded segment. All existing iterators (including seg, but not // including the returned iterators) are invalidated. -func (s *Set) Isolate(seg Iterator, r Range) Iterator { +func (s *reclaimSet) Isolate(seg reclaimIterator, r __generics_imported0.FileRange) reclaimIterator { if seg.Range().CanSplitAt(r.Start) { _, seg = s.SplitUnchecked(seg, r.Start) } @@ -629,7 +558,7 @@ func (s *Set) Isolate(seg Iterator, r Range) Iterator { // are invalidated. // // N.B. The Iterator must not be invalidated by the function. -func (s *Set) ApplyContiguous(r Range, fn func(seg Iterator)) GapIterator { +func (s *reclaimSet) ApplyContiguous(r __generics_imported0.FileRange, fn func(seg reclaimIterator)) reclaimGapIterator { seg, gap := s.Find(r.Start) if !seg.Ok() { return gap @@ -638,7 +567,7 @@ func (s *Set) ApplyContiguous(r Range, fn func(seg Iterator)) GapIterator { seg = s.Isolate(seg, r) fn(seg) if seg.End() >= r.End { - return GapIterator{} + return reclaimGapIterator{} } gap = seg.NextGap() if !gap.IsEmpty() { @@ -646,15 +575,14 @@ func (s *Set) ApplyContiguous(r Range, fn func(seg Iterator)) GapIterator { } seg = gap.NextSegment() if !seg.Ok() { - // This implies that the last segment extended all the - // way to the maximum value, since the gap was empty. - return GapIterator{} + + return reclaimGapIterator{} } } } // +stateify savable -type node struct { +type reclaimnode struct { // An internal binary tree node looks like: // // K @@ -676,7 +604,7 @@ type node struct { // parent is a pointer to this node's parent. If this node is root, parent // is nil. - parent *node + parent *reclaimnode // parentIndex is the index of this node in parent.children. parentIndex int @@ -690,43 +618,43 @@ type node struct { // maximum gap among all the (nrSegments+1) gaps formed by its nrSegments keys // including the 0th and nrSegments-th gap possibly shared with its upper-level // nodes; if it's a non-leaf node, it's the max of all children's maxGap. - maxGap dynamicGap + maxGap reclaimdynamicGap // Nodes store keys and values in separate arrays to maximize locality in // the common case (scanning keys for lookup). - keys [maxDegree - 1]Range - values [maxDegree - 1]Value - children [maxDegree]*node + keys [reclaimmaxDegree - 1]__generics_imported0.FileRange + values [reclaimmaxDegree - 1]reclaimSetValue + children [reclaimmaxDegree]*reclaimnode } // firstSegment returns the first segment in the subtree rooted by n. // // Preconditions: n.nrSegments != 0. -func (n *node) firstSegment() Iterator { +func (n *reclaimnode) firstSegment() reclaimIterator { for n.hasChildren { n = n.children[0] } - return Iterator{n, 0} + return reclaimIterator{n, 0} } // lastSegment returns the last segment in the subtree rooted by n. // // Preconditions: n.nrSegments != 0. -func (n *node) lastSegment() Iterator { +func (n *reclaimnode) lastSegment() reclaimIterator { for n.hasChildren { n = n.children[n.nrSegments] } - return Iterator{n, n.nrSegments - 1} + return reclaimIterator{n, n.nrSegments - 1} } -func (n *node) prevSibling() *node { +func (n *reclaimnode) prevSibling() *reclaimnode { if n.parent == nil || n.parentIndex == 0 { return nil } return n.parent.children[n.parentIndex-1] } -func (n *node) nextSibling() *node { +func (n *reclaimnode) nextSibling() *reclaimnode { if n.parent == nil || n.parentIndex == n.parent.nrSegments { return nil } @@ -736,40 +664,38 @@ func (n *node) nextSibling() *node { // rebalanceBeforeInsert splits n and its ancestors if they are full, as // required for insertion, and returns an updated iterator to the position // represented by gap. -func (n *node) rebalanceBeforeInsert(gap GapIterator) GapIterator { - if n.nrSegments < maxDegree-1 { +func (n *reclaimnode) rebalanceBeforeInsert(gap reclaimGapIterator) reclaimGapIterator { + if n.nrSegments < reclaimmaxDegree-1 { return gap } if n.parent != nil { gap = n.parent.rebalanceBeforeInsert(gap) } if n.parent == nil { - // n is root. Move all segments before and after n's median segment - // into new child nodes adjacent to the median segment, which is now - // the only segment in root. - left := &node{ - nrSegments: minDegree - 1, + + left := &reclaimnode{ + nrSegments: reclaimminDegree - 1, parent: n, parentIndex: 0, hasChildren: n.hasChildren, } - right := &node{ - nrSegments: minDegree - 1, + right := &reclaimnode{ + nrSegments: reclaimminDegree - 1, parent: n, parentIndex: 1, hasChildren: n.hasChildren, } - copy(left.keys[:minDegree-1], n.keys[:minDegree-1]) - copy(left.values[:minDegree-1], n.values[:minDegree-1]) - copy(right.keys[:minDegree-1], n.keys[minDegree:]) - copy(right.values[:minDegree-1], n.values[minDegree:]) - n.keys[0], n.values[0] = n.keys[minDegree-1], n.values[minDegree-1] - zeroValueSlice(n.values[1:]) + copy(left.keys[:reclaimminDegree-1], n.keys[:reclaimminDegree-1]) + copy(left.values[:reclaimminDegree-1], n.values[:reclaimminDegree-1]) + copy(right.keys[:reclaimminDegree-1], n.keys[reclaimminDegree:]) + copy(right.values[:reclaimminDegree-1], n.values[reclaimminDegree:]) + n.keys[0], n.values[0] = n.keys[reclaimminDegree-1], n.values[reclaimminDegree-1] + reclaimzeroValueSlice(n.values[1:]) if n.hasChildren { - copy(left.children[:minDegree], n.children[:minDegree]) - copy(right.children[:minDegree], n.children[minDegree:]) - zeroNodeSlice(n.children[2:]) - for i := 0; i < minDegree; i++ { + copy(left.children[:reclaimminDegree], n.children[:reclaimminDegree]) + copy(right.children[:reclaimminDegree], n.children[reclaimminDegree:]) + reclaimzeroNodeSlice(n.children[2:]) + for i := 0; i < reclaimminDegree; i++ { left.children[i].parent = left left.children[i].parentIndex = i right.children[i].parent = right @@ -780,66 +706,60 @@ func (n *node) rebalanceBeforeInsert(gap GapIterator) GapIterator { n.hasChildren = true n.children[0] = left n.children[1] = right - // In this case, n's maxGap won't violated as it's still the root, - // but the left and right children should be updated locally as they - // are newly split from n. - if trackGaps != 0 { + + if reclaimtrackGaps != 0 { left.updateMaxGapLocal() right.updateMaxGapLocal() } if gap.node != n { return gap } - if gap.index < minDegree { - return GapIterator{left, gap.index} + if gap.index < reclaimminDegree { + return reclaimGapIterator{left, gap.index} } - return GapIterator{right, gap.index - minDegree} + return reclaimGapIterator{right, gap.index - reclaimminDegree} } - // n is non-root. Move n's median segment into its parent node (which can't - // be full because we've already invoked n.parent.rebalanceBeforeInsert) - // and move all segments after n's median into a new sibling node (the - // median segment's right child subtree). + copy(n.parent.keys[n.parentIndex+1:], n.parent.keys[n.parentIndex:n.parent.nrSegments]) copy(n.parent.values[n.parentIndex+1:], n.parent.values[n.parentIndex:n.parent.nrSegments]) - n.parent.keys[n.parentIndex], n.parent.values[n.parentIndex] = n.keys[minDegree-1], n.values[minDegree-1] + n.parent.keys[n.parentIndex], n.parent.values[n.parentIndex] = n.keys[reclaimminDegree-1], n.values[reclaimminDegree-1] copy(n.parent.children[n.parentIndex+2:], n.parent.children[n.parentIndex+1:n.parent.nrSegments+1]) for i := n.parentIndex + 2; i < n.parent.nrSegments+2; i++ { n.parent.children[i].parentIndex = i } - sibling := &node{ - nrSegments: minDegree - 1, + sibling := &reclaimnode{ + nrSegments: reclaimminDegree - 1, parent: n.parent, parentIndex: n.parentIndex + 1, hasChildren: n.hasChildren, } n.parent.children[n.parentIndex+1] = sibling n.parent.nrSegments++ - copy(sibling.keys[:minDegree-1], n.keys[minDegree:]) - copy(sibling.values[:minDegree-1], n.values[minDegree:]) - zeroValueSlice(n.values[minDegree-1:]) + copy(sibling.keys[:reclaimminDegree-1], n.keys[reclaimminDegree:]) + copy(sibling.values[:reclaimminDegree-1], n.values[reclaimminDegree:]) + reclaimzeroValueSlice(n.values[reclaimminDegree-1:]) if n.hasChildren { - copy(sibling.children[:minDegree], n.children[minDegree:]) - zeroNodeSlice(n.children[minDegree:]) - for i := 0; i < minDegree; i++ { + copy(sibling.children[:reclaimminDegree], n.children[reclaimminDegree:]) + reclaimzeroNodeSlice(n.children[reclaimminDegree:]) + for i := 0; i < reclaimminDegree; i++ { sibling.children[i].parent = sibling sibling.children[i].parentIndex = i } } - n.nrSegments = minDegree - 1 - // MaxGap of n's parent is not violated because the segments within is not changed. - // n and its sibling's maxGap need to be updated locally as they are two new nodes split from old n. - if trackGaps != 0 { + n.nrSegments = reclaimminDegree - 1 + + if reclaimtrackGaps != 0 { n.updateMaxGapLocal() sibling.updateMaxGapLocal() } - // gap.node can't be n.parent because gaps are always in leaf nodes. + if gap.node != n { return gap } - if gap.index < minDegree { + if gap.index < reclaimminDegree { return gap } - return GapIterator{sibling, gap.index - minDegree} + return reclaimGapIterator{sibling, gap.index - reclaimminDegree} } // rebalanceAfterRemove "unsplits" n and its ancestors if they are deficient @@ -848,41 +768,24 @@ func (n *node) rebalanceBeforeInsert(gap GapIterator) GapIterator { // // Precondition: n is the only node in the tree that may currently violate a // B-tree invariant. -func (n *node) rebalanceAfterRemove(gap GapIterator) GapIterator { +func (n *reclaimnode) rebalanceAfterRemove(gap reclaimGapIterator) reclaimGapIterator { for { - if n.nrSegments >= minDegree-1 { + if n.nrSegments >= reclaimminDegree-1 { return gap } if n.parent == nil { - // Root is allowed to be deficient. + return gap } - // There's one other thing we can do before resorting to unsplitting. - // If either sibling node has at least minDegree segments, rotate that - // sibling's closest segment through the segment in the parent that - // separates us. That is, given: - // - // ... D ... - // / \ - // ... B C] [E ... - // - // where the node containing E is deficient, end up with: - // - // ... C ... - // / \ - // ... B] [D E ... - // - // As in Set.Remove, prefer rotating from the end of the sibling to the - // left: by precondition, n.node has fewer segments (to memcpy) than - // the sibling does. - if sibling := n.prevSibling(); sibling != nil && sibling.nrSegments >= minDegree { + + if sibling := n.prevSibling(); sibling != nil && sibling.nrSegments >= reclaimminDegree { copy(n.keys[1:], n.keys[:n.nrSegments]) copy(n.values[1:], n.values[:n.nrSegments]) n.keys[0] = n.parent.keys[n.parentIndex-1] n.values[0] = n.parent.values[n.parentIndex-1] n.parent.keys[n.parentIndex-1] = sibling.keys[sibling.nrSegments-1] n.parent.values[n.parentIndex-1] = sibling.values[sibling.nrSegments-1] - Functions{}.ClearValue(&sibling.values[sibling.nrSegments-1]) + reclaimSetFunctions{}.ClearValue(&sibling.values[sibling.nrSegments-1]) if n.hasChildren { copy(n.children[1:], n.children[:n.nrSegments+1]) n.children[0] = sibling.children[sibling.nrSegments] @@ -895,28 +798,27 @@ func (n *node) rebalanceAfterRemove(gap GapIterator) GapIterator { } n.nrSegments++ sibling.nrSegments-- - // n's parent's maxGap does not need to be updated as its content is unmodified. - // n and its sibling must be updated with (new) maxGap because of the shift of keys. - if trackGaps != 0 { + + if reclaimtrackGaps != 0 { n.updateMaxGapLocal() sibling.updateMaxGapLocal() } if gap.node == sibling && gap.index == sibling.nrSegments { - return GapIterator{n, 0} + return reclaimGapIterator{n, 0} } if gap.node == n { - return GapIterator{n, gap.index + 1} + return reclaimGapIterator{n, gap.index + 1} } return gap } - if sibling := n.nextSibling(); sibling != nil && sibling.nrSegments >= minDegree { + if sibling := n.nextSibling(); sibling != nil && sibling.nrSegments >= reclaimminDegree { n.keys[n.nrSegments] = n.parent.keys[n.parentIndex] n.values[n.nrSegments] = n.parent.values[n.parentIndex] n.parent.keys[n.parentIndex] = sibling.keys[0] n.parent.values[n.parentIndex] = sibling.values[0] copy(sibling.keys[:sibling.nrSegments-1], sibling.keys[1:]) copy(sibling.values[:sibling.nrSegments-1], sibling.values[1:]) - Functions{}.ClearValue(&sibling.values[sibling.nrSegments-1]) + reclaimSetFunctions{}.ClearValue(&sibling.values[sibling.nrSegments-1]) if n.hasChildren { n.children[n.nrSegments+1] = sibling.children[0] copy(sibling.children[:sibling.nrSegments], sibling.children[1:]) @@ -929,29 +831,23 @@ func (n *node) rebalanceAfterRemove(gap GapIterator) GapIterator { } n.nrSegments++ sibling.nrSegments-- - // n's parent's maxGap does not need to be updated as its content is unmodified. - // n and its sibling must be updated with (new) maxGap because of the shift of keys. - if trackGaps != 0 { + + if reclaimtrackGaps != 0 { n.updateMaxGapLocal() sibling.updateMaxGapLocal() } if gap.node == sibling { if gap.index == 0 { - return GapIterator{n, n.nrSegments} + return reclaimGapIterator{n, n.nrSegments} } - return GapIterator{sibling, gap.index - 1} + return reclaimGapIterator{sibling, gap.index - 1} } return gap } - // Otherwise, we must unsplit. + p := n.parent if p.nrSegments == 1 { - // Merge all segments in both n and its sibling back into n.parent. - // This is the reverse of the root splitting case in - // node.rebalanceBeforeInsert. (Because we require minDegree >= 3, - // only root can have 1 segment in this path, so this reduces the - // height of the tree by 1, without violating the constraint that - // all leaf nodes remain at the same depth.) + left, right := p.children[0], p.children[1] p.nrSegments = left.nrSegments + right.nrSegments + 1 p.hasChildren = left.hasChildren @@ -972,12 +868,12 @@ func (n *node) rebalanceAfterRemove(gap GapIterator) GapIterator { p.children[0] = nil p.children[1] = nil } - // No need to update maxGap of p as its content is not changed. + if gap.node == left { - return GapIterator{p, gap.index} + return reclaimGapIterator{p, gap.index} } if gap.node == right { - return GapIterator{p, gap.index + left.nrSegments + 1} + return reclaimGapIterator{p, gap.index + left.nrSegments + 1} } return gap } @@ -985,7 +881,7 @@ func (n *node) rebalanceAfterRemove(gap GapIterator) GapIterator { // two, into whichever of the two nodes comes first. This is the // reverse of the non-root splitting case in // node.rebalanceBeforeInsert. - var left, right *node + var left, right *reclaimnode if n.parentIndex > 0 { left = n.prevSibling() right = n @@ -993,10 +889,9 @@ func (n *node) rebalanceAfterRemove(gap GapIterator) GapIterator { left = n right = n.nextSibling() } - // Fix up gap first since we need the old left.nrSegments, which - // merging will change. + if gap.node == right { - gap = GapIterator{left, gap.index + left.nrSegments + 1} + gap = reclaimGapIterator{left, gap.index + left.nrSegments + 1} } left.keys[left.nrSegments] = p.keys[left.parentIndex] left.values[left.nrSegments] = p.values[left.parentIndex] @@ -1012,19 +907,18 @@ func (n *node) rebalanceAfterRemove(gap GapIterator) GapIterator { left.nrSegments += right.nrSegments + 1 copy(p.keys[left.parentIndex:], p.keys[left.parentIndex+1:p.nrSegments]) copy(p.values[left.parentIndex:], p.values[left.parentIndex+1:p.nrSegments]) - Functions{}.ClearValue(&p.values[p.nrSegments-1]) + reclaimSetFunctions{}.ClearValue(&p.values[p.nrSegments-1]) copy(p.children[left.parentIndex+1:], p.children[left.parentIndex+2:p.nrSegments+1]) for i := 0; i < p.nrSegments; i++ { p.children[i].parentIndex = i } p.children[p.nrSegments] = nil p.nrSegments-- - // Update maxGap of left locally, no need to change p and right because - // p's contents is not changed and right is already invalid. - if trackGaps != 0 { + + if reclaimtrackGaps != 0 { left.updateMaxGapLocal() } - // This process robs p of one segment, so recurse into rebalancing p. + n = p } } @@ -1033,46 +927,42 @@ func (n *node) rebalanceAfterRemove(gap GapIterator) GapIterator { // necessary update. // // Preconditions: n must be a leaf node, trackGaps must be 1. -func (n *node) updateMaxGapLeaf() { +func (n *reclaimnode) updateMaxGapLeaf() { if n.hasChildren { panic(fmt.Sprintf("updateMaxGapLeaf should always be called on leaf node: %v", n)) } max := n.calculateMaxGapLeaf() if max == n.maxGap.Get() { - // If new max equals the old maxGap, no update is needed. + return } oldMax := n.maxGap.Get() n.maxGap.Set(max) if max > oldMax { - // Grow ancestor maxGaps. + for p := n.parent; p != nil; p = p.parent { if p.maxGap.Get() >= max { - // p and its ancestors already contain an equal or larger gap. + break } - // Only if new maxGap is larger than parent's - // old maxGap, propagate this update to parent. + p.maxGap.Set(max) } return } - // Shrink ancestor maxGaps. + for p := n.parent; p != nil; p = p.parent { if p.maxGap.Get() > oldMax { - // p and its ancestors still contain a larger gap. + break } - // If new max is smaller than the old maxGap, and this gap used - // to be the maxGap of its parent, iterate parent's children - // and calculate parent's new maxGap.(It's probable that parent - // has two children with the old maxGap, but we need to check it anyway.) + parentNewMax := p.calculateMaxGapInternal() if p.maxGap.Get() == parentNewMax { - // p and its ancestors still contain a gap of at least equal size. + break } - // If p's new maxGap differs from the old one, propagate this update. + p.maxGap.Set(parentNewMax) } } @@ -1081,12 +971,12 @@ func (n *node) updateMaxGapLeaf() { // propagation to ancestor nodes. // // Precondition: trackGaps must be 1. -func (n *node) updateMaxGapLocal() { +func (n *reclaimnode) updateMaxGapLocal() { if !n.hasChildren { - // Leaf node iterates its gaps. + n.maxGap.Set(n.calculateMaxGapLeaf()) } else { - // Non-leaf node iterates its children. + n.maxGap.Set(n.calculateMaxGapInternal()) } } @@ -1095,10 +985,10 @@ func (n *node) updateMaxGapLocal() { // max. // // Preconditions: n must be a leaf node. -func (n *node) calculateMaxGapLeaf() Key { - max := GapIterator{n, 0}.Range().Length() +func (n *reclaimnode) calculateMaxGapLeaf() uint64 { + max := reclaimGapIterator{n, 0}.Range().Length() for i := 1; i <= n.nrSegments; i++ { - if current := (GapIterator{n, i}).Range().Length(); current > max { + if current := (reclaimGapIterator{n, i}).Range().Length(); current > max { max = current } } @@ -1109,7 +999,7 @@ func (n *node) calculateMaxGapLeaf() Key { // and calculate the max. // // Preconditions: n must be a non-leaf node. -func (n *node) calculateMaxGapInternal() Key { +func (n *reclaimnode) calculateMaxGapInternal() uint64 { max := n.children[0].maxGap.Get() for i := 1; i <= n.nrSegments; i++ { if current := n.children[i].maxGap.Get(); current > max { @@ -1121,9 +1011,9 @@ func (n *node) calculateMaxGapInternal() Key { // searchFirstLargeEnoughGap returns the first gap having at least minSize length // in the subtree rooted by n. If not found, return a terminal gap iterator. -func (n *node) searchFirstLargeEnoughGap(minSize Key) GapIterator { +func (n *reclaimnode) searchFirstLargeEnoughGap(minSize uint64) reclaimGapIterator { if n.maxGap.Get() < minSize { - return GapIterator{} + return reclaimGapIterator{} } if n.hasChildren { for i := 0; i <= n.nrSegments; i++ { @@ -1133,7 +1023,7 @@ func (n *node) searchFirstLargeEnoughGap(minSize Key) GapIterator { } } else { for i := 0; i <= n.nrSegments; i++ { - currentGap := GapIterator{n, i} + currentGap := reclaimGapIterator{n, i} if currentGap.Range().Length() >= minSize { return currentGap } @@ -1144,9 +1034,9 @@ func (n *node) searchFirstLargeEnoughGap(minSize Key) GapIterator { // searchLastLargeEnoughGap returns the last gap having at least minSize length // in the subtree rooted by n. If not found, return a terminal gap iterator. -func (n *node) searchLastLargeEnoughGap(minSize Key) GapIterator { +func (n *reclaimnode) searchLastLargeEnoughGap(minSize uint64) reclaimGapIterator { if n.maxGap.Get() < minSize { - return GapIterator{} + return reclaimGapIterator{} } if n.hasChildren { for i := n.nrSegments; i >= 0; i-- { @@ -1156,7 +1046,7 @@ func (n *node) searchLastLargeEnoughGap(minSize Key) GapIterator { } } else { for i := n.nrSegments; i >= 0; i-- { - currentGap := GapIterator{n, i} + currentGap := reclaimGapIterator{n, i} if currentGap.Range().Length() >= minSize { return currentGap } @@ -1177,10 +1067,10 @@ func (n *node) searchLastLargeEnoughGap(minSize Key) GapIterator { // // Unless otherwise specified, any mutation of a set invalidates all existing // iterators into the set. -type Iterator struct { +type reclaimIterator struct { // node is the node containing the iterated segment. If the iterator is // terminal, node is nil. - node *node + node *reclaimnode // index is the index of the segment in node.keys/values. index int @@ -1188,24 +1078,24 @@ type Iterator struct { // Ok returns true if the iterator is not terminal. All other methods are only // valid for non-terminal iterators. -func (seg Iterator) Ok() bool { +func (seg reclaimIterator) Ok() bool { return seg.node != nil } // Range returns the iterated segment's range key. -func (seg Iterator) Range() Range { +func (seg reclaimIterator) Range() __generics_imported0.FileRange { return seg.node.keys[seg.index] } // Start is equivalent to Range().Start, but should be preferred if only the // start of the range is needed. -func (seg Iterator) Start() Key { +func (seg reclaimIterator) Start() uint64 { return seg.node.keys[seg.index].Start } // End is equivalent to Range().End, but should be preferred if only the end of // the range is needed. -func (seg Iterator) End() Key { +func (seg reclaimIterator) End() uint64 { return seg.node.keys[seg.index].End } @@ -1217,7 +1107,7 @@ func (seg Iterator) End() Key { // * The new range must not overlap an existing one: // * If seg.NextSegment().Ok(), then r.end <= seg.NextSegment().Start(). // * If seg.PrevSegment().Ok(), then r.start >= seg.PrevSegment().End(). -func (seg Iterator) SetRangeUnchecked(r Range) { +func (seg reclaimIterator) SetRangeUnchecked(r __generics_imported0.FileRange) { seg.node.keys[seg.index] = r } @@ -1225,7 +1115,7 @@ func (seg Iterator) SetRangeUnchecked(r Range) { // cause the iterated segment to overlap another segment, or if the new range // is invalid, SetRange panics. This operation does not invalidate any // iterators. -func (seg Iterator) SetRange(r Range) { +func (seg reclaimIterator) SetRange(r __generics_imported0.FileRange) { if r.Length() <= 0 { panic(fmt.Sprintf("invalid segment range %v", r)) } @@ -1244,7 +1134,7 @@ func (seg Iterator) SetRange(r Range) { // Preconditions: The new start must be valid: // * start < seg.End() // * If seg.PrevSegment().Ok(), then start >= seg.PrevSegment().End(). -func (seg Iterator) SetStartUnchecked(start Key) { +func (seg reclaimIterator) SetStartUnchecked(start uint64) { seg.node.keys[seg.index].Start = start } @@ -1252,7 +1142,7 @@ func (seg Iterator) SetStartUnchecked(start Key) { // cause the iterated segment to overlap another segment, or would result in an // invalid range, SetStart panics. This operation does not invalidate any // iterators. -func (seg Iterator) SetStart(start Key) { +func (seg reclaimIterator) SetStart(start uint64) { if start >= seg.End() { panic(fmt.Sprintf("new start %v would invalidate segment range %v", start, seg.Range())) } @@ -1268,7 +1158,7 @@ func (seg Iterator) SetStart(start Key) { // Preconditions: The new end must be valid: // * end > seg.Start(). // * If seg.NextSegment().Ok(), then end <= seg.NextSegment().Start(). -func (seg Iterator) SetEndUnchecked(end Key) { +func (seg reclaimIterator) SetEndUnchecked(end uint64) { seg.node.keys[seg.index].End = end } @@ -1276,7 +1166,7 @@ func (seg Iterator) SetEndUnchecked(end Key) { // the iterated segment to overlap another segment, or would result in an // invalid range, SetEnd panics. This operation does not invalidate any // iterators. -func (seg Iterator) SetEnd(end Key) { +func (seg reclaimIterator) SetEnd(end uint64) { if end <= seg.Start() { panic(fmt.Sprintf("new end %v would invalidate segment range %v", end, seg.Range())) } @@ -1287,69 +1177,68 @@ func (seg Iterator) SetEnd(end Key) { } // Value returns a copy of the iterated segment's value. -func (seg Iterator) Value() Value { +func (seg reclaimIterator) Value() reclaimSetValue { return seg.node.values[seg.index] } // ValuePtr returns a pointer to the iterated segment's value. The pointer is // invalidated if the iterator is invalidated. This operation does not // invalidate any iterators. -func (seg Iterator) ValuePtr() *Value { +func (seg reclaimIterator) ValuePtr() *reclaimSetValue { return &seg.node.values[seg.index] } // SetValue mutates the iterated segment's value. This operation does not // invalidate any iterators. -func (seg Iterator) SetValue(val Value) { +func (seg reclaimIterator) SetValue(val reclaimSetValue) { seg.node.values[seg.index] = val } // PrevSegment returns the iterated segment's predecessor. If there is no // preceding segment, PrevSegment returns a terminal iterator. -func (seg Iterator) PrevSegment() Iterator { +func (seg reclaimIterator) PrevSegment() reclaimIterator { if seg.node.hasChildren { return seg.node.children[seg.index].lastSegment() } if seg.index > 0 { - return Iterator{seg.node, seg.index - 1} + return reclaimIterator{seg.node, seg.index - 1} } if seg.node.parent == nil { - return Iterator{} + return reclaimIterator{} } - return segmentBeforePosition(seg.node.parent, seg.node.parentIndex) + return reclaimsegmentBeforePosition(seg.node.parent, seg.node.parentIndex) } // NextSegment returns the iterated segment's successor. If there is no // succeeding segment, NextSegment returns a terminal iterator. -func (seg Iterator) NextSegment() Iterator { +func (seg reclaimIterator) NextSegment() reclaimIterator { if seg.node.hasChildren { return seg.node.children[seg.index+1].firstSegment() } if seg.index < seg.node.nrSegments-1 { - return Iterator{seg.node, seg.index + 1} + return reclaimIterator{seg.node, seg.index + 1} } if seg.node.parent == nil { - return Iterator{} + return reclaimIterator{} } - return segmentAfterPosition(seg.node.parent, seg.node.parentIndex) + return reclaimsegmentAfterPosition(seg.node.parent, seg.node.parentIndex) } // PrevGap returns the gap immediately before the iterated segment. -func (seg Iterator) PrevGap() GapIterator { +func (seg reclaimIterator) PrevGap() reclaimGapIterator { if seg.node.hasChildren { - // Note that this isn't recursive because the last segment in a subtree - // must be in a leaf node. + return seg.node.children[seg.index].lastSegment().NextGap() } - return GapIterator{seg.node, seg.index} + return reclaimGapIterator{seg.node, seg.index} } // NextGap returns the gap immediately after the iterated segment. -func (seg Iterator) NextGap() GapIterator { +func (seg reclaimIterator) NextGap() reclaimGapIterator { if seg.node.hasChildren { return seg.node.children[seg.index+1].firstSegment().PrevGap() } - return GapIterator{seg.node, seg.index + 1} + return reclaimGapIterator{seg.node, seg.index + 1} } // PrevNonEmpty returns the iterated segment's predecessor if it is adjacent, @@ -1357,12 +1246,12 @@ func (seg Iterator) NextGap() GapIterator { // Functions.MinKey(), PrevNonEmpty will return two terminal iterators. // Otherwise, exactly one of the iterators returned by PrevNonEmpty will be // non-terminal. -func (seg Iterator) PrevNonEmpty() (Iterator, GapIterator) { +func (seg reclaimIterator) PrevNonEmpty() (reclaimIterator, reclaimGapIterator) { gap := seg.PrevGap() if gap.Range().Length() != 0 { - return Iterator{}, gap + return reclaimIterator{}, gap } - return gap.PrevSegment(), GapIterator{} + return gap.PrevSegment(), reclaimGapIterator{} } // NextNonEmpty returns the iterated segment's successor if it is adjacent, or @@ -1370,12 +1259,12 @@ func (seg Iterator) PrevNonEmpty() (Iterator, GapIterator) { // Functions.MaxKey(), NextNonEmpty will return two terminal iterators. // Otherwise, exactly one of the iterators returned by NextNonEmpty will be // non-terminal. -func (seg Iterator) NextNonEmpty() (Iterator, GapIterator) { +func (seg reclaimIterator) NextNonEmpty() (reclaimIterator, reclaimGapIterator) { gap := seg.NextGap() if gap.Range().Length() != 0 { - return Iterator{}, gap + return reclaimIterator{}, gap } - return gap.NextSegment(), GapIterator{} + return gap.NextSegment(), reclaimGapIterator{} } // A GapIterator is conceptually one of: @@ -1396,77 +1285,77 @@ func (seg Iterator) NextNonEmpty() (Iterator, GapIterator) { // // Unless otherwise specified, any mutation of a set invalidates all existing // iterators into the set. -type GapIterator struct { +type reclaimGapIterator struct { // The representation of a GapIterator is identical to that of an Iterator, // except that index corresponds to positions between segments in the same // way as for node.children (see comment for node.nrSegments). - node *node + node *reclaimnode index int } // Ok returns true if the iterator is not terminal. All other methods are only // valid for non-terminal iterators. -func (gap GapIterator) Ok() bool { +func (gap reclaimGapIterator) Ok() bool { return gap.node != nil } // Range returns the range spanned by the iterated gap. -func (gap GapIterator) Range() Range { - return Range{gap.Start(), gap.End()} +func (gap reclaimGapIterator) Range() __generics_imported0.FileRange { + return __generics_imported0.FileRange{gap.Start(), gap.End()} } // Start is equivalent to Range().Start, but should be preferred if only the // start of the range is needed. -func (gap GapIterator) Start() Key { +func (gap reclaimGapIterator) Start() uint64 { if ps := gap.PrevSegment(); ps.Ok() { return ps.End() } - return Functions{}.MinKey() + return reclaimSetFunctions{}.MinKey() } // End is equivalent to Range().End, but should be preferred if only the end of // the range is needed. -func (gap GapIterator) End() Key { +func (gap reclaimGapIterator) End() uint64 { if ns := gap.NextSegment(); ns.Ok() { return ns.Start() } - return Functions{}.MaxKey() + return reclaimSetFunctions{}.MaxKey() } // IsEmpty returns true if the iterated gap is empty (that is, the "gap" is // between two adjacent segments.) -func (gap GapIterator) IsEmpty() bool { +func (gap reclaimGapIterator) IsEmpty() bool { return gap.Range().Length() == 0 } // PrevSegment returns the segment immediately before the iterated gap. If no // such segment exists, PrevSegment returns a terminal iterator. -func (gap GapIterator) PrevSegment() Iterator { - return segmentBeforePosition(gap.node, gap.index) +func (gap reclaimGapIterator) PrevSegment() reclaimIterator { + return reclaimsegmentBeforePosition(gap.node, gap.index) } // NextSegment returns the segment immediately after the iterated gap. If no // such segment exists, NextSegment returns a terminal iterator. -func (gap GapIterator) NextSegment() Iterator { - return segmentAfterPosition(gap.node, gap.index) +func (gap reclaimGapIterator) NextSegment() reclaimIterator { + return reclaimsegmentAfterPosition(gap.node, gap.index) } // PrevGap returns the iterated gap's predecessor. If no such gap exists, // PrevGap returns a terminal iterator. -func (gap GapIterator) PrevGap() GapIterator { +func (gap reclaimGapIterator) PrevGap() reclaimGapIterator { seg := gap.PrevSegment() if !seg.Ok() { - return GapIterator{} + return reclaimGapIterator{} } return seg.PrevGap() } // NextGap returns the iterated gap's successor. If no such gap exists, NextGap // returns a terminal iterator. -func (gap GapIterator) NextGap() GapIterator { +func (gap reclaimGapIterator) NextGap() reclaimGapIterator { seg := gap.NextSegment() if !seg.Ok() { - return GapIterator{} + return reclaimGapIterator{} } return seg.NextGap() } @@ -1476,13 +1365,12 @@ func (gap GapIterator) NextGap() GapIterator { // include this gap itself). // // Precondition: trackGaps must be 1. -func (gap GapIterator) NextLargeEnoughGap(minSize Key) GapIterator { - if trackGaps != 1 { +func (gap reclaimGapIterator) NextLargeEnoughGap(minSize uint64) reclaimGapIterator { + if reclaimtrackGaps != 1 { panic("set is not tracking gaps") } if gap.node != nil && gap.node.hasChildren && gap.index == gap.node.nrSegments { - // If gap is the trailing gap of an non-leaf node, - // translate it to the equivalent gap on leaf level. + gap.node = gap.NextSegment().node gap.index = 0 return gap.nextLargeEnoughGapHelper(minSize) @@ -1494,19 +1382,17 @@ func (gap GapIterator) NextLargeEnoughGap(minSize Key) GapIterator { // to do the real recursions. // // Preconditions: gap is NOT the trailing gap of a non-leaf node. -func (gap GapIterator) nextLargeEnoughGapHelper(minSize Key) GapIterator { - // Crawl up the tree if no large enough gap in current node or the - // current gap is the trailing one on leaf level. +func (gap reclaimGapIterator) nextLargeEnoughGapHelper(minSize uint64) reclaimGapIterator { + for gap.node != nil && (gap.node.maxGap.Get() < minSize || (!gap.node.hasChildren && gap.index == gap.node.nrSegments)) { gap.node, gap.index = gap.node.parent, gap.node.parentIndex } - // If no large enough gap throughout the whole set, return a terminal - // gap iterator. + if gap.node == nil { - return GapIterator{} + return reclaimGapIterator{} } - // Iterate subsequent gaps. + gap.index++ for gap.index <= gap.node.nrSegments { if gap.node.hasChildren { @@ -1522,8 +1408,7 @@ func (gap GapIterator) nextLargeEnoughGapHelper(minSize Key) GapIterator { } gap.node, gap.index = gap.node.parent, gap.node.parentIndex if gap.node != nil && gap.index == gap.node.nrSegments { - // If gap is the trailing gap of a non-leaf node, crawl up to - // parent again and do recursion. + gap.node, gap.index = gap.node.parent, gap.node.parentIndex } return gap.nextLargeEnoughGapHelper(minSize) @@ -1534,13 +1419,12 @@ func (gap GapIterator) nextLargeEnoughGapHelper(minSize Key) GapIterator { // (does NOT include this gap itself). // // Precondition: trackGaps must be 1. -func (gap GapIterator) PrevLargeEnoughGap(minSize Key) GapIterator { - if trackGaps != 1 { +func (gap reclaimGapIterator) PrevLargeEnoughGap(minSize uint64) reclaimGapIterator { + if reclaimtrackGaps != 1 { panic("set is not tracking gaps") } if gap.node != nil && gap.node.hasChildren && gap.index == 0 { - // If gap is the first gap of an non-leaf node, - // translate it to the equivalent gap on leaf level. + gap.node = gap.PrevSegment().node gap.index = gap.node.nrSegments return gap.prevLargeEnoughGapHelper(minSize) @@ -1552,19 +1436,17 @@ func (gap GapIterator) PrevLargeEnoughGap(minSize Key) GapIterator { // to do the real recursions. // // Preconditions: gap is NOT the first gap of a non-leaf node. -func (gap GapIterator) prevLargeEnoughGapHelper(minSize Key) GapIterator { - // Crawl up the tree if no large enough gap in current node or the - // current gap is the first one on leaf level. +func (gap reclaimGapIterator) prevLargeEnoughGapHelper(minSize uint64) reclaimGapIterator { + for gap.node != nil && (gap.node.maxGap.Get() < minSize || (!gap.node.hasChildren && gap.index == 0)) { gap.node, gap.index = gap.node.parent, gap.node.parentIndex } - // If no large enough gap throughout the whole set, return a terminal - // gap iterator. + if gap.node == nil { - return GapIterator{} + return reclaimGapIterator{} } - // Iterate previous gaps. + gap.index-- for gap.index >= 0 { if gap.node.hasChildren { @@ -1580,8 +1462,7 @@ func (gap GapIterator) prevLargeEnoughGapHelper(minSize Key) GapIterator { } gap.node, gap.index = gap.node.parent, gap.node.parentIndex if gap.node != nil && gap.index == 0 { - // If gap is the first gap of a non-leaf node, crawl up to - // parent again and do recursion. + gap.node, gap.index = gap.node.parent, gap.node.parentIndex } return gap.prevLargeEnoughGapHelper(minSize) @@ -1590,56 +1471,55 @@ func (gap GapIterator) prevLargeEnoughGapHelper(minSize Key) GapIterator { // segmentBeforePosition returns the predecessor segment of the position given // by n.children[i], which may or may not contain a child. If no such segment // exists, segmentBeforePosition returns a terminal iterator. -func segmentBeforePosition(n *node, i int) Iterator { +func reclaimsegmentBeforePosition(n *reclaimnode, i int) reclaimIterator { for i == 0 { if n.parent == nil { - return Iterator{} + return reclaimIterator{} } n, i = n.parent, n.parentIndex } - return Iterator{n, i - 1} + return reclaimIterator{n, i - 1} } // segmentAfterPosition returns the successor segment of the position given by // n.children[i], which may or may not contain a child. If no such segment // exists, segmentAfterPosition returns a terminal iterator. -func segmentAfterPosition(n *node, i int) Iterator { +func reclaimsegmentAfterPosition(n *reclaimnode, i int) reclaimIterator { for i == n.nrSegments { if n.parent == nil { - return Iterator{} + return reclaimIterator{} } n, i = n.parent, n.parentIndex } - return Iterator{n, i} + return reclaimIterator{n, i} } -func zeroValueSlice(slice []Value) { - // TODO(jamieliu): check if Go is actually smart enough to optimize a - // ClearValue that assigns nil to a memset here. +func reclaimzeroValueSlice(slice []reclaimSetValue) { + for i := range slice { - Functions{}.ClearValue(&slice[i]) + reclaimSetFunctions{}.ClearValue(&slice[i]) } } -func zeroNodeSlice(slice []*node) { +func reclaimzeroNodeSlice(slice []*reclaimnode) { for i := range slice { slice[i] = nil } } // String stringifies a Set for debugging. -func (s *Set) String() string { +func (s *reclaimSet) String() string { return s.root.String() } // String stringifies a node (and all of its children) for debugging. -func (n *node) String() string { +func (n *reclaimnode) String() string { var buf bytes.Buffer n.writeDebugString(&buf, "") return buf.String() } -func (n *node) writeDebugString(buf *bytes.Buffer, prefix string) { +func (n *reclaimnode) writeDebugString(buf *bytes.Buffer, prefix string) { if n.hasChildren != (n.nrSegments > 0 && n.children[0] != nil) { buf.WriteString(prefix) buf.WriteString(fmt.Sprintf("WARNING: inconsistent value of hasChildren: got %v, want %v\n", n.hasChildren, !n.hasChildren)) @@ -1655,7 +1535,7 @@ func (n *node) writeDebugString(buf *bytes.Buffer, prefix string) { } buf.WriteString(prefix) if n.hasChildren { - if trackGaps != 0 { + if reclaimtrackGaps != 0 { buf.WriteString(fmt.Sprintf("- % 3d: %v => %v, maxGap: %d\n", i, n.keys[i], n.values[i], n.maxGap.Get())) } else { buf.WriteString(fmt.Sprintf("- % 3d: %v => %v\n", i, n.keys[i], n.values[i])) @@ -1674,16 +1554,16 @@ func (n *node) writeDebugString(buf *bytes.Buffer, prefix string) { // for save/restore and the layout here is optimized for that. // // +stateify savable -type SegmentDataSlices struct { - Start []Key - End []Key - Values []Value +type reclaimSegmentDataSlices struct { + Start []uint64 + End []uint64 + Values []reclaimSetValue } // ExportSortedSlices returns a copy of all segments in the given set, in // ascending key order. -func (s *Set) ExportSortedSlices() *SegmentDataSlices { - var sds SegmentDataSlices +func (s *reclaimSet) ExportSortedSlices() *reclaimSegmentDataSlices { + var sds reclaimSegmentDataSlices for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { sds.Start = append(sds.Start, seg.Start()) sds.End = append(sds.End, seg.End()) @@ -1702,13 +1582,13 @@ func (s *Set) ExportSortedSlices() *SegmentDataSlices { // * sds must represent a valid set (the segments in sds must have valid // lengths that do not overlap). // * The segments in sds must be sorted in ascending key order. -func (s *Set) ImportSortedSlices(sds *SegmentDataSlices) error { +func (s *reclaimSet) ImportSortedSlices(sds *reclaimSegmentDataSlices) error { if !s.IsEmpty() { return fmt.Errorf("cannot import into non-empty set %v", s) } gap := s.FirstGap() for i := range sds.Start { - r := Range{sds.Start[i], sds.End[i]} + r := __generics_imported0.FileRange{sds.Start[i], sds.End[i]} if !gap.Range().IsSupersetOf(r) { return fmt.Errorf("segment overlaps a preceding segment or is incorrectly sorted: [%d, %d) => %v", sds.Start[i], sds.End[i], sds.Values[i]) } @@ -1723,9 +1603,9 @@ func (s *Set) ImportSortedSlices(sds *SegmentDataSlices) error { // // This should be used only for testing, and has been added to this package for // templating convenience. -func (s *Set) segmentTestCheck(expectedSegments int, segFunc func(int, Range, Value) error) error { +func (s *reclaimSet) segmentTestCheck(expectedSegments int, segFunc func(int, __generics_imported0.FileRange, reclaimSetValue) error) error { havePrev := false - prev := Key(0) + prev := uint64(0) nrSegments := 0 for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { next := seg.Start() @@ -1750,9 +1630,18 @@ func (s *Set) segmentTestCheck(expectedSegments int, segFunc func(int, Range, Va // countSegments counts the number of segments in the set. // // Similar to Check, this should only be used for testing. -func (s *Set) countSegments() (segments int) { +func (s *reclaimSet) countSegments() (segments int) { for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { segments++ } return segments } +func (s *reclaimSet) saveRoot() *reclaimSegmentDataSlices { + return s.ExportSortedSlices() +} + +func (s *reclaimSet) loadRoot(sds *reclaimSegmentDataSlices) { + if err := s.ImportSortedSlices(sds); err != nil { + panic(err) + } +} diff --git a/pkg/sentry/pgalloc/usage_set.go b/pkg/sentry/pgalloc/usage_set.go new file mode 100644 index 000000000..8d96e817a --- /dev/null +++ b/pkg/sentry/pgalloc/usage_set.go @@ -0,0 +1,1647 @@ +package pgalloc + +import ( + __generics_imported0 "gvisor.dev/gvisor/pkg/sentry/memmap" +) + +import ( + "bytes" + "fmt" +) + +// trackGaps is an optional parameter. +// +// If trackGaps is 1, the Set will track maximum gap size recursively, +// enabling the GapIterator.{Prev,Next}LargeEnoughGap functions. In this +// case, Key must be an unsigned integer. +// +// trackGaps must be 0 or 1. +const usagetrackGaps = 1 + +var _ = uint8(usagetrackGaps << 7) // Will fail if not zero or one. + +// dynamicGap is a type that disappears if trackGaps is 0. +type usagedynamicGap [usagetrackGaps]uint64 + +// Get returns the value of the gap. +// +// Precondition: trackGaps must be non-zero. +func (d *usagedynamicGap) Get() uint64 { + return d[:][0] +} + +// Set sets the value of the gap. +// +// Precondition: trackGaps must be non-zero. +func (d *usagedynamicGap) Set(v uint64) { + d[:][0] = v +} + +const ( + // minDegree is the minimum degree of an internal node in a Set B-tree. + // + // - Any non-root node has at least minDegree-1 segments. + // + // - Any non-root internal (non-leaf) node has at least minDegree children. + // + // - The root node may have fewer than minDegree-1 segments, but it may + // only have 0 segments if the tree is empty. + // + // Our implementation requires minDegree >= 3. Higher values of minDegree + // usually improve performance, but increase memory usage for small sets. + usageminDegree = 10 + + usagemaxDegree = 2 * usageminDegree +) + +// A Set is a mapping of segments with non-overlapping Range keys. The zero +// value for a Set is an empty set. Set values are not safely movable nor +// copyable. Set is thread-compatible. +// +// +stateify savable +type usageSet struct { + root usagenode `state:".(*usageSegmentDataSlices)"` +} + +// IsEmpty returns true if the set contains no segments. +func (s *usageSet) IsEmpty() bool { + return s.root.nrSegments == 0 +} + +// IsEmptyRange returns true iff no segments in the set overlap the given +// range. This is semantically equivalent to s.SpanRange(r) == 0, but may be +// more efficient. +func (s *usageSet) IsEmptyRange(r __generics_imported0.FileRange) bool { + switch { + case r.Length() < 0: + panic(fmt.Sprintf("invalid range %v", r)) + case r.Length() == 0: + return true + } + _, gap := s.Find(r.Start) + if !gap.Ok() { + return false + } + return r.End <= gap.End() +} + +// Span returns the total size of all segments in the set. +func (s *usageSet) Span() uint64 { + var sz uint64 + for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + sz += seg.Range().Length() + } + return sz +} + +// SpanRange returns the total size of the intersection of segments in the set +// with the given range. +func (s *usageSet) SpanRange(r __generics_imported0.FileRange) uint64 { + switch { + case r.Length() < 0: + panic(fmt.Sprintf("invalid range %v", r)) + case r.Length() == 0: + return 0 + } + var sz uint64 + for seg := s.LowerBoundSegment(r.Start); seg.Ok() && seg.Start() < r.End; seg = seg.NextSegment() { + sz += seg.Range().Intersect(r).Length() + } + return sz +} + +// FirstSegment returns the first segment in the set. If the set is empty, +// FirstSegment returns a terminal iterator. +func (s *usageSet) FirstSegment() usageIterator { + if s.root.nrSegments == 0 { + return usageIterator{} + } + return s.root.firstSegment() +} + +// LastSegment returns the last segment in the set. If the set is empty, +// LastSegment returns a terminal iterator. +func (s *usageSet) LastSegment() usageIterator { + if s.root.nrSegments == 0 { + return usageIterator{} + } + return s.root.lastSegment() +} + +// FirstGap returns the first gap in the set. +func (s *usageSet) FirstGap() usageGapIterator { + n := &s.root + for n.hasChildren { + n = n.children[0] + } + return usageGapIterator{n, 0} +} + +// LastGap returns the last gap in the set. +func (s *usageSet) LastGap() usageGapIterator { + n := &s.root + for n.hasChildren { + n = n.children[n.nrSegments] + } + return usageGapIterator{n, n.nrSegments} +} + +// Find returns the segment or gap whose range contains the given key. If a +// segment is found, the returned Iterator is non-terminal and the +// returned GapIterator is terminal. Otherwise, the returned Iterator is +// terminal and the returned GapIterator is non-terminal. +func (s *usageSet) Find(key uint64) (usageIterator, usageGapIterator) { + n := &s.root + for { + + lower := 0 + upper := n.nrSegments + for lower < upper { + i := lower + (upper-lower)/2 + if r := n.keys[i]; key < r.End { + if key >= r.Start { + return usageIterator{n, i}, usageGapIterator{} + } + upper = i + } else { + lower = i + 1 + } + } + i := lower + if !n.hasChildren { + return usageIterator{}, usageGapIterator{n, i} + } + n = n.children[i] + } +} + +// FindSegment returns the segment whose range contains the given key. If no +// such segment exists, FindSegment returns a terminal iterator. +func (s *usageSet) FindSegment(key uint64) usageIterator { + seg, _ := s.Find(key) + return seg +} + +// LowerBoundSegment returns the segment with the lowest range that contains a +// key greater than or equal to min. If no such segment exists, +// LowerBoundSegment returns a terminal iterator. +func (s *usageSet) LowerBoundSegment(min uint64) usageIterator { + seg, gap := s.Find(min) + if seg.Ok() { + return seg + } + return gap.NextSegment() +} + +// UpperBoundSegment returns the segment with the highest range that contains a +// key less than or equal to max. If no such segment exists, UpperBoundSegment +// returns a terminal iterator. +func (s *usageSet) UpperBoundSegment(max uint64) usageIterator { + seg, gap := s.Find(max) + if seg.Ok() { + return seg + } + return gap.PrevSegment() +} + +// FindGap returns the gap containing the given key. If no such gap exists +// (i.e. the set contains a segment containing that key), FindGap returns a +// terminal iterator. +func (s *usageSet) FindGap(key uint64) usageGapIterator { + _, gap := s.Find(key) + return gap +} + +// LowerBoundGap returns the gap with the lowest range that is greater than or +// equal to min. +func (s *usageSet) LowerBoundGap(min uint64) usageGapIterator { + seg, gap := s.Find(min) + if gap.Ok() { + return gap + } + return seg.NextGap() +} + +// UpperBoundGap returns the gap with the highest range that is less than or +// equal to max. +func (s *usageSet) UpperBoundGap(max uint64) usageGapIterator { + seg, gap := s.Find(max) + if gap.Ok() { + return gap + } + return seg.PrevGap() +} + +// Add inserts the given segment into the set and returns true. If the new +// segment can be merged with adjacent segments, Add will do so. If the new +// segment would overlap an existing segment, Add returns false. If Add +// succeeds, all existing iterators are invalidated. +func (s *usageSet) Add(r __generics_imported0.FileRange, val usageInfo) bool { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + gap := s.FindGap(r.Start) + if !gap.Ok() { + return false + } + if r.End > gap.End() { + return false + } + s.Insert(gap, r, val) + return true +} + +// AddWithoutMerging inserts the given segment into the set and returns true. +// If it would overlap an existing segment, AddWithoutMerging does nothing and +// returns false. If AddWithoutMerging succeeds, all existing iterators are +// invalidated. +func (s *usageSet) AddWithoutMerging(r __generics_imported0.FileRange, val usageInfo) bool { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + gap := s.FindGap(r.Start) + if !gap.Ok() { + return false + } + if r.End > gap.End() { + return false + } + s.InsertWithoutMergingUnchecked(gap, r, val) + return true +} + +// Insert inserts the given segment into the given gap. If the new segment can +// be merged with adjacent segments, Insert will do so. Insert returns an +// iterator to the segment containing the inserted value (which may have been +// merged with other values). All existing iterators (including gap, but not +// including the returned iterator) are invalidated. +// +// If the gap cannot accommodate the segment, or if r is invalid, Insert panics. +// +// Insert is semantically equivalent to a InsertWithoutMerging followed by a +// Merge, but may be more efficient. Note that there is no unchecked variant of +// Insert since Insert must retrieve and inspect gap's predecessor and +// successor segments regardless. +func (s *usageSet) Insert(gap usageGapIterator, r __generics_imported0.FileRange, val usageInfo) usageIterator { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + prev, next := gap.PrevSegment(), gap.NextSegment() + if prev.Ok() && prev.End() > r.Start { + panic(fmt.Sprintf("new segment %v overlaps predecessor %v", r, prev.Range())) + } + if next.Ok() && next.Start() < r.End { + panic(fmt.Sprintf("new segment %v overlaps successor %v", r, next.Range())) + } + if prev.Ok() && prev.End() == r.Start { + if mval, ok := (usageSetFunctions{}).Merge(prev.Range(), prev.Value(), r, val); ok { + shrinkMaxGap := usagetrackGaps != 0 && gap.Range().Length() == gap.node.maxGap.Get() + prev.SetEndUnchecked(r.End) + prev.SetValue(mval) + if shrinkMaxGap { + gap.node.updateMaxGapLeaf() + } + if next.Ok() && next.Start() == r.End { + val = mval + if mval, ok := (usageSetFunctions{}).Merge(prev.Range(), val, next.Range(), next.Value()); ok { + prev.SetEndUnchecked(next.End()) + prev.SetValue(mval) + return s.Remove(next).PrevSegment() + } + } + return prev + } + } + if next.Ok() && next.Start() == r.End { + if mval, ok := (usageSetFunctions{}).Merge(r, val, next.Range(), next.Value()); ok { + shrinkMaxGap := usagetrackGaps != 0 && gap.Range().Length() == gap.node.maxGap.Get() + next.SetStartUnchecked(r.Start) + next.SetValue(mval) + if shrinkMaxGap { + gap.node.updateMaxGapLeaf() + } + return next + } + } + + return s.InsertWithoutMergingUnchecked(gap, r, val) +} + +// InsertWithoutMerging inserts the given segment into the given gap and +// returns an iterator to the inserted segment. All existing iterators +// (including gap, but not including the returned iterator) are invalidated. +// +// If the gap cannot accommodate the segment, or if r is invalid, +// InsertWithoutMerging panics. +func (s *usageSet) InsertWithoutMerging(gap usageGapIterator, r __generics_imported0.FileRange, val usageInfo) usageIterator { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + if gr := gap.Range(); !gr.IsSupersetOf(r) { + panic(fmt.Sprintf("cannot insert segment range %v into gap range %v", r, gr)) + } + return s.InsertWithoutMergingUnchecked(gap, r, val) +} + +// InsertWithoutMergingUnchecked inserts the given segment into the given gap +// and returns an iterator to the inserted segment. All existing iterators +// (including gap, but not including the returned iterator) are invalidated. +// +// Preconditions: +// * r.Start >= gap.Start(). +// * r.End <= gap.End(). +func (s *usageSet) InsertWithoutMergingUnchecked(gap usageGapIterator, r __generics_imported0.FileRange, val usageInfo) usageIterator { + gap = gap.node.rebalanceBeforeInsert(gap) + splitMaxGap := usagetrackGaps != 0 && (gap.node.nrSegments == 0 || gap.Range().Length() == gap.node.maxGap.Get()) + copy(gap.node.keys[gap.index+1:], gap.node.keys[gap.index:gap.node.nrSegments]) + copy(gap.node.values[gap.index+1:], gap.node.values[gap.index:gap.node.nrSegments]) + gap.node.keys[gap.index] = r + gap.node.values[gap.index] = val + gap.node.nrSegments++ + if splitMaxGap { + gap.node.updateMaxGapLeaf() + } + return usageIterator{gap.node, gap.index} +} + +// Remove removes the given segment and returns an iterator to the vacated gap. +// All existing iterators (including seg, but not including the returned +// iterator) are invalidated. +func (s *usageSet) Remove(seg usageIterator) usageGapIterator { + + if seg.node.hasChildren { + + victim := seg.PrevSegment() + + seg.SetRangeUnchecked(victim.Range()) + seg.SetValue(victim.Value()) + + nextAdjacentNode := seg.NextSegment().node + if usagetrackGaps != 0 { + nextAdjacentNode.updateMaxGapLeaf() + } + return s.Remove(victim).NextGap() + } + copy(seg.node.keys[seg.index:], seg.node.keys[seg.index+1:seg.node.nrSegments]) + copy(seg.node.values[seg.index:], seg.node.values[seg.index+1:seg.node.nrSegments]) + usageSetFunctions{}.ClearValue(&seg.node.values[seg.node.nrSegments-1]) + seg.node.nrSegments-- + if usagetrackGaps != 0 { + seg.node.updateMaxGapLeaf() + } + return seg.node.rebalanceAfterRemove(usageGapIterator{seg.node, seg.index}) +} + +// RemoveAll removes all segments from the set. All existing iterators are +// invalidated. +func (s *usageSet) RemoveAll() { + s.root = usagenode{} +} + +// RemoveRange removes all segments in the given range. An iterator to the +// newly formed gap is returned, and all existing iterators are invalidated. +func (s *usageSet) RemoveRange(r __generics_imported0.FileRange) usageGapIterator { + seg, gap := s.Find(r.Start) + if seg.Ok() { + seg = s.Isolate(seg, r) + gap = s.Remove(seg) + } + for seg = gap.NextSegment(); seg.Ok() && seg.Start() < r.End; seg = gap.NextSegment() { + seg = s.Isolate(seg, r) + gap = s.Remove(seg) + } + return gap +} + +// Merge attempts to merge two neighboring segments. If successful, Merge +// returns an iterator to the merged segment, and all existing iterators are +// invalidated. Otherwise, Merge returns a terminal iterator. +// +// If first is not the predecessor of second, Merge panics. +func (s *usageSet) Merge(first, second usageIterator) usageIterator { + if first.NextSegment() != second { + panic(fmt.Sprintf("attempt to merge non-neighboring segments %v, %v", first.Range(), second.Range())) + } + return s.MergeUnchecked(first, second) +} + +// MergeUnchecked attempts to merge two neighboring segments. If successful, +// MergeUnchecked returns an iterator to the merged segment, and all existing +// iterators are invalidated. Otherwise, MergeUnchecked returns a terminal +// iterator. +// +// Precondition: first is the predecessor of second: first.NextSegment() == +// second, first == second.PrevSegment(). +func (s *usageSet) MergeUnchecked(first, second usageIterator) usageIterator { + if first.End() == second.Start() { + if mval, ok := (usageSetFunctions{}).Merge(first.Range(), first.Value(), second.Range(), second.Value()); ok { + + first.SetEndUnchecked(second.End()) + first.SetValue(mval) + + return s.Remove(second).PrevSegment() + } + } + return usageIterator{} +} + +// MergeAll attempts to merge all adjacent segments in the set. All existing +// iterators are invalidated. +func (s *usageSet) MergeAll() { + seg := s.FirstSegment() + if !seg.Ok() { + return + } + next := seg.NextSegment() + for next.Ok() { + if mseg := s.MergeUnchecked(seg, next); mseg.Ok() { + seg, next = mseg, mseg.NextSegment() + } else { + seg, next = next, next.NextSegment() + } + } +} + +// MergeRange attempts to merge all adjacent segments that contain a key in the +// specific range. All existing iterators are invalidated. +func (s *usageSet) MergeRange(r __generics_imported0.FileRange) { + seg := s.LowerBoundSegment(r.Start) + if !seg.Ok() { + return + } + next := seg.NextSegment() + for next.Ok() && next.Range().Start < r.End { + if mseg := s.MergeUnchecked(seg, next); mseg.Ok() { + seg, next = mseg, mseg.NextSegment() + } else { + seg, next = next, next.NextSegment() + } + } +} + +// MergeAdjacent attempts to merge the segment containing r.Start with its +// predecessor, and the segment containing r.End-1 with its successor. +func (s *usageSet) MergeAdjacent(r __generics_imported0.FileRange) { + first := s.FindSegment(r.Start) + if first.Ok() { + if prev := first.PrevSegment(); prev.Ok() { + s.Merge(prev, first) + } + } + last := s.FindSegment(r.End - 1) + if last.Ok() { + if next := last.NextSegment(); next.Ok() { + s.Merge(last, next) + } + } +} + +// Split splits the given segment at the given key and returns iterators to the +// two resulting segments. All existing iterators (including seg, but not +// including the returned iterators) are invalidated. +// +// If the segment cannot be split at split (because split is at the start or +// end of the segment's range, so splitting would produce a segment with zero +// length, or because split falls outside the segment's range altogether), +// Split panics. +func (s *usageSet) Split(seg usageIterator, split uint64) (usageIterator, usageIterator) { + if !seg.Range().CanSplitAt(split) { + panic(fmt.Sprintf("can't split %v at %v", seg.Range(), split)) + } + return s.SplitUnchecked(seg, split) +} + +// SplitUnchecked splits the given segment at the given key and returns +// iterators to the two resulting segments. All existing iterators (including +// seg, but not including the returned iterators) are invalidated. +// +// Preconditions: seg.Start() < key < seg.End(). +func (s *usageSet) SplitUnchecked(seg usageIterator, split uint64) (usageIterator, usageIterator) { + val1, val2 := (usageSetFunctions{}).Split(seg.Range(), seg.Value(), split) + end2 := seg.End() + seg.SetEndUnchecked(split) + seg.SetValue(val1) + seg2 := s.InsertWithoutMergingUnchecked(seg.NextGap(), __generics_imported0.FileRange{split, end2}, val2) + + return seg2.PrevSegment(), seg2 +} + +// SplitAt splits the segment straddling split, if one exists. SplitAt returns +// true if a segment was split and false otherwise. If SplitAt splits a +// segment, all existing iterators are invalidated. +func (s *usageSet) SplitAt(split uint64) bool { + if seg := s.FindSegment(split); seg.Ok() && seg.Range().CanSplitAt(split) { + s.SplitUnchecked(seg, split) + return true + } + return false +} + +// Isolate ensures that the given segment's range does not escape r by +// splitting at r.Start and r.End if necessary, and returns an updated iterator +// to the bounded segment. All existing iterators (including seg, but not +// including the returned iterators) are invalidated. +func (s *usageSet) Isolate(seg usageIterator, r __generics_imported0.FileRange) usageIterator { + if seg.Range().CanSplitAt(r.Start) { + _, seg = s.SplitUnchecked(seg, r.Start) + } + if seg.Range().CanSplitAt(r.End) { + seg, _ = s.SplitUnchecked(seg, r.End) + } + return seg +} + +// ApplyContiguous applies a function to a contiguous range of segments, +// splitting if necessary. The function is applied until the first gap is +// encountered, at which point the gap is returned. If the function is applied +// across the entire range, a terminal gap is returned. All existing iterators +// are invalidated. +// +// N.B. The Iterator must not be invalidated by the function. +func (s *usageSet) ApplyContiguous(r __generics_imported0.FileRange, fn func(seg usageIterator)) usageGapIterator { + seg, gap := s.Find(r.Start) + if !seg.Ok() { + return gap + } + for { + seg = s.Isolate(seg, r) + fn(seg) + if seg.End() >= r.End { + return usageGapIterator{} + } + gap = seg.NextGap() + if !gap.IsEmpty() { + return gap + } + seg = gap.NextSegment() + if !seg.Ok() { + + return usageGapIterator{} + } + } +} + +// +stateify savable +type usagenode struct { + // An internal binary tree node looks like: + // + // K + // / \ + // Cl Cr + // + // where all keys in the subtree rooted by Cl (the left subtree) are less + // than K (the key of the parent node), and all keys in the subtree rooted + // by Cr (the right subtree) are greater than K. + // + // An internal B-tree node's indexes work out to look like: + // + // K0 K1 K2 ... Kn-1 + // / \/ \/ \ ... / \ + // C0 C1 C2 C3 ... Cn-1 Cn + // + // where n is nrSegments. + nrSegments int + + // parent is a pointer to this node's parent. If this node is root, parent + // is nil. + parent *usagenode + + // parentIndex is the index of this node in parent.children. + parentIndex int + + // Flag for internal nodes that is technically redundant with "children[0] + // != nil", but is stored in the first cache line. "hasChildren" rather + // than "isLeaf" because false must be the correct value for an empty root. + hasChildren bool + + // The longest gap within this node. If the node is a leaf, it's simply the + // maximum gap among all the (nrSegments+1) gaps formed by its nrSegments keys + // including the 0th and nrSegments-th gap possibly shared with its upper-level + // nodes; if it's a non-leaf node, it's the max of all children's maxGap. + maxGap usagedynamicGap + + // Nodes store keys and values in separate arrays to maximize locality in + // the common case (scanning keys for lookup). + keys [usagemaxDegree - 1]__generics_imported0.FileRange + values [usagemaxDegree - 1]usageInfo + children [usagemaxDegree]*usagenode +} + +// firstSegment returns the first segment in the subtree rooted by n. +// +// Preconditions: n.nrSegments != 0. +func (n *usagenode) firstSegment() usageIterator { + for n.hasChildren { + n = n.children[0] + } + return usageIterator{n, 0} +} + +// lastSegment returns the last segment in the subtree rooted by n. +// +// Preconditions: n.nrSegments != 0. +func (n *usagenode) lastSegment() usageIterator { + for n.hasChildren { + n = n.children[n.nrSegments] + } + return usageIterator{n, n.nrSegments - 1} +} + +func (n *usagenode) prevSibling() *usagenode { + if n.parent == nil || n.parentIndex == 0 { + return nil + } + return n.parent.children[n.parentIndex-1] +} + +func (n *usagenode) nextSibling() *usagenode { + if n.parent == nil || n.parentIndex == n.parent.nrSegments { + return nil + } + return n.parent.children[n.parentIndex+1] +} + +// rebalanceBeforeInsert splits n and its ancestors if they are full, as +// required for insertion, and returns an updated iterator to the position +// represented by gap. +func (n *usagenode) rebalanceBeforeInsert(gap usageGapIterator) usageGapIterator { + if n.nrSegments < usagemaxDegree-1 { + return gap + } + if n.parent != nil { + gap = n.parent.rebalanceBeforeInsert(gap) + } + if n.parent == nil { + + left := &usagenode{ + nrSegments: usageminDegree - 1, + parent: n, + parentIndex: 0, + hasChildren: n.hasChildren, + } + right := &usagenode{ + nrSegments: usageminDegree - 1, + parent: n, + parentIndex: 1, + hasChildren: n.hasChildren, + } + copy(left.keys[:usageminDegree-1], n.keys[:usageminDegree-1]) + copy(left.values[:usageminDegree-1], n.values[:usageminDegree-1]) + copy(right.keys[:usageminDegree-1], n.keys[usageminDegree:]) + copy(right.values[:usageminDegree-1], n.values[usageminDegree:]) + n.keys[0], n.values[0] = n.keys[usageminDegree-1], n.values[usageminDegree-1] + usagezeroValueSlice(n.values[1:]) + if n.hasChildren { + copy(left.children[:usageminDegree], n.children[:usageminDegree]) + copy(right.children[:usageminDegree], n.children[usageminDegree:]) + usagezeroNodeSlice(n.children[2:]) + for i := 0; i < usageminDegree; i++ { + left.children[i].parent = left + left.children[i].parentIndex = i + right.children[i].parent = right + right.children[i].parentIndex = i + } + } + n.nrSegments = 1 + n.hasChildren = true + n.children[0] = left + n.children[1] = right + + if usagetrackGaps != 0 { + left.updateMaxGapLocal() + right.updateMaxGapLocal() + } + if gap.node != n { + return gap + } + if gap.index < usageminDegree { + return usageGapIterator{left, gap.index} + } + return usageGapIterator{right, gap.index - usageminDegree} + } + + copy(n.parent.keys[n.parentIndex+1:], n.parent.keys[n.parentIndex:n.parent.nrSegments]) + copy(n.parent.values[n.parentIndex+1:], n.parent.values[n.parentIndex:n.parent.nrSegments]) + n.parent.keys[n.parentIndex], n.parent.values[n.parentIndex] = n.keys[usageminDegree-1], n.values[usageminDegree-1] + copy(n.parent.children[n.parentIndex+2:], n.parent.children[n.parentIndex+1:n.parent.nrSegments+1]) + for i := n.parentIndex + 2; i < n.parent.nrSegments+2; i++ { + n.parent.children[i].parentIndex = i + } + sibling := &usagenode{ + nrSegments: usageminDegree - 1, + parent: n.parent, + parentIndex: n.parentIndex + 1, + hasChildren: n.hasChildren, + } + n.parent.children[n.parentIndex+1] = sibling + n.parent.nrSegments++ + copy(sibling.keys[:usageminDegree-1], n.keys[usageminDegree:]) + copy(sibling.values[:usageminDegree-1], n.values[usageminDegree:]) + usagezeroValueSlice(n.values[usageminDegree-1:]) + if n.hasChildren { + copy(sibling.children[:usageminDegree], n.children[usageminDegree:]) + usagezeroNodeSlice(n.children[usageminDegree:]) + for i := 0; i < usageminDegree; i++ { + sibling.children[i].parent = sibling + sibling.children[i].parentIndex = i + } + } + n.nrSegments = usageminDegree - 1 + + if usagetrackGaps != 0 { + n.updateMaxGapLocal() + sibling.updateMaxGapLocal() + } + + if gap.node != n { + return gap + } + if gap.index < usageminDegree { + return gap + } + return usageGapIterator{sibling, gap.index - usageminDegree} +} + +// rebalanceAfterRemove "unsplits" n and its ancestors if they are deficient +// (contain fewer segments than required by B-tree invariants), as required for +// removal, and returns an updated iterator to the position represented by gap. +// +// Precondition: n is the only node in the tree that may currently violate a +// B-tree invariant. +func (n *usagenode) rebalanceAfterRemove(gap usageGapIterator) usageGapIterator { + for { + if n.nrSegments >= usageminDegree-1 { + return gap + } + if n.parent == nil { + + return gap + } + + if sibling := n.prevSibling(); sibling != nil && sibling.nrSegments >= usageminDegree { + copy(n.keys[1:], n.keys[:n.nrSegments]) + copy(n.values[1:], n.values[:n.nrSegments]) + n.keys[0] = n.parent.keys[n.parentIndex-1] + n.values[0] = n.parent.values[n.parentIndex-1] + n.parent.keys[n.parentIndex-1] = sibling.keys[sibling.nrSegments-1] + n.parent.values[n.parentIndex-1] = sibling.values[sibling.nrSegments-1] + usageSetFunctions{}.ClearValue(&sibling.values[sibling.nrSegments-1]) + if n.hasChildren { + copy(n.children[1:], n.children[:n.nrSegments+1]) + n.children[0] = sibling.children[sibling.nrSegments] + sibling.children[sibling.nrSegments] = nil + n.children[0].parent = n + n.children[0].parentIndex = 0 + for i := 1; i < n.nrSegments+2; i++ { + n.children[i].parentIndex = i + } + } + n.nrSegments++ + sibling.nrSegments-- + + if usagetrackGaps != 0 { + n.updateMaxGapLocal() + sibling.updateMaxGapLocal() + } + if gap.node == sibling && gap.index == sibling.nrSegments { + return usageGapIterator{n, 0} + } + if gap.node == n { + return usageGapIterator{n, gap.index + 1} + } + return gap + } + if sibling := n.nextSibling(); sibling != nil && sibling.nrSegments >= usageminDegree { + n.keys[n.nrSegments] = n.parent.keys[n.parentIndex] + n.values[n.nrSegments] = n.parent.values[n.parentIndex] + n.parent.keys[n.parentIndex] = sibling.keys[0] + n.parent.values[n.parentIndex] = sibling.values[0] + copy(sibling.keys[:sibling.nrSegments-1], sibling.keys[1:]) + copy(sibling.values[:sibling.nrSegments-1], sibling.values[1:]) + usageSetFunctions{}.ClearValue(&sibling.values[sibling.nrSegments-1]) + if n.hasChildren { + n.children[n.nrSegments+1] = sibling.children[0] + copy(sibling.children[:sibling.nrSegments], sibling.children[1:]) + sibling.children[sibling.nrSegments] = nil + n.children[n.nrSegments+1].parent = n + n.children[n.nrSegments+1].parentIndex = n.nrSegments + 1 + for i := 0; i < sibling.nrSegments; i++ { + sibling.children[i].parentIndex = i + } + } + n.nrSegments++ + sibling.nrSegments-- + + if usagetrackGaps != 0 { + n.updateMaxGapLocal() + sibling.updateMaxGapLocal() + } + if gap.node == sibling { + if gap.index == 0 { + return usageGapIterator{n, n.nrSegments} + } + return usageGapIterator{sibling, gap.index - 1} + } + return gap + } + + p := n.parent + if p.nrSegments == 1 { + + left, right := p.children[0], p.children[1] + p.nrSegments = left.nrSegments + right.nrSegments + 1 + p.hasChildren = left.hasChildren + p.keys[left.nrSegments] = p.keys[0] + p.values[left.nrSegments] = p.values[0] + copy(p.keys[:left.nrSegments], left.keys[:left.nrSegments]) + copy(p.values[:left.nrSegments], left.values[:left.nrSegments]) + copy(p.keys[left.nrSegments+1:], right.keys[:right.nrSegments]) + copy(p.values[left.nrSegments+1:], right.values[:right.nrSegments]) + if left.hasChildren { + copy(p.children[:left.nrSegments+1], left.children[:left.nrSegments+1]) + copy(p.children[left.nrSegments+1:], right.children[:right.nrSegments+1]) + for i := 0; i < p.nrSegments+1; i++ { + p.children[i].parent = p + p.children[i].parentIndex = i + } + } else { + p.children[0] = nil + p.children[1] = nil + } + + if gap.node == left { + return usageGapIterator{p, gap.index} + } + if gap.node == right { + return usageGapIterator{p, gap.index + left.nrSegments + 1} + } + return gap + } + // Merge n and either sibling, along with the segment separating the + // two, into whichever of the two nodes comes first. This is the + // reverse of the non-root splitting case in + // node.rebalanceBeforeInsert. + var left, right *usagenode + if n.parentIndex > 0 { + left = n.prevSibling() + right = n + } else { + left = n + right = n.nextSibling() + } + + if gap.node == right { + gap = usageGapIterator{left, gap.index + left.nrSegments + 1} + } + left.keys[left.nrSegments] = p.keys[left.parentIndex] + left.values[left.nrSegments] = p.values[left.parentIndex] + copy(left.keys[left.nrSegments+1:], right.keys[:right.nrSegments]) + copy(left.values[left.nrSegments+1:], right.values[:right.nrSegments]) + if left.hasChildren { + copy(left.children[left.nrSegments+1:], right.children[:right.nrSegments+1]) + for i := left.nrSegments + 1; i < left.nrSegments+right.nrSegments+2; i++ { + left.children[i].parent = left + left.children[i].parentIndex = i + } + } + left.nrSegments += right.nrSegments + 1 + copy(p.keys[left.parentIndex:], p.keys[left.parentIndex+1:p.nrSegments]) + copy(p.values[left.parentIndex:], p.values[left.parentIndex+1:p.nrSegments]) + usageSetFunctions{}.ClearValue(&p.values[p.nrSegments-1]) + copy(p.children[left.parentIndex+1:], p.children[left.parentIndex+2:p.nrSegments+1]) + for i := 0; i < p.nrSegments; i++ { + p.children[i].parentIndex = i + } + p.children[p.nrSegments] = nil + p.nrSegments-- + + if usagetrackGaps != 0 { + left.updateMaxGapLocal() + } + + n = p + } +} + +// updateMaxGapLeaf updates maxGap bottom-up from the calling leaf until no +// necessary update. +// +// Preconditions: n must be a leaf node, trackGaps must be 1. +func (n *usagenode) updateMaxGapLeaf() { + if n.hasChildren { + panic(fmt.Sprintf("updateMaxGapLeaf should always be called on leaf node: %v", n)) + } + max := n.calculateMaxGapLeaf() + if max == n.maxGap.Get() { + + return + } + oldMax := n.maxGap.Get() + n.maxGap.Set(max) + if max > oldMax { + + for p := n.parent; p != nil; p = p.parent { + if p.maxGap.Get() >= max { + + break + } + + p.maxGap.Set(max) + } + return + } + + for p := n.parent; p != nil; p = p.parent { + if p.maxGap.Get() > oldMax { + + break + } + + parentNewMax := p.calculateMaxGapInternal() + if p.maxGap.Get() == parentNewMax { + + break + } + + p.maxGap.Set(parentNewMax) + } +} + +// updateMaxGapLocal updates maxGap of the calling node solely with no +// propagation to ancestor nodes. +// +// Precondition: trackGaps must be 1. +func (n *usagenode) updateMaxGapLocal() { + if !n.hasChildren { + + n.maxGap.Set(n.calculateMaxGapLeaf()) + } else { + + n.maxGap.Set(n.calculateMaxGapInternal()) + } +} + +// calculateMaxGapLeaf iterates the gaps within a leaf node and calculate the +// max. +// +// Preconditions: n must be a leaf node. +func (n *usagenode) calculateMaxGapLeaf() uint64 { + max := usageGapIterator{n, 0}.Range().Length() + for i := 1; i <= n.nrSegments; i++ { + if current := (usageGapIterator{n, i}).Range().Length(); current > max { + max = current + } + } + return max +} + +// calculateMaxGapInternal iterates children's maxGap within an internal node n +// and calculate the max. +// +// Preconditions: n must be a non-leaf node. +func (n *usagenode) calculateMaxGapInternal() uint64 { + max := n.children[0].maxGap.Get() + for i := 1; i <= n.nrSegments; i++ { + if current := n.children[i].maxGap.Get(); current > max { + max = current + } + } + return max +} + +// searchFirstLargeEnoughGap returns the first gap having at least minSize length +// in the subtree rooted by n. If not found, return a terminal gap iterator. +func (n *usagenode) searchFirstLargeEnoughGap(minSize uint64) usageGapIterator { + if n.maxGap.Get() < minSize { + return usageGapIterator{} + } + if n.hasChildren { + for i := 0; i <= n.nrSegments; i++ { + if largeEnoughGap := n.children[i].searchFirstLargeEnoughGap(minSize); largeEnoughGap.Ok() { + return largeEnoughGap + } + } + } else { + for i := 0; i <= n.nrSegments; i++ { + currentGap := usageGapIterator{n, i} + if currentGap.Range().Length() >= minSize { + return currentGap + } + } + } + panic(fmt.Sprintf("invalid maxGap in %v", n)) +} + +// searchLastLargeEnoughGap returns the last gap having at least minSize length +// in the subtree rooted by n. If not found, return a terminal gap iterator. +func (n *usagenode) searchLastLargeEnoughGap(minSize uint64) usageGapIterator { + if n.maxGap.Get() < minSize { + return usageGapIterator{} + } + if n.hasChildren { + for i := n.nrSegments; i >= 0; i-- { + if largeEnoughGap := n.children[i].searchLastLargeEnoughGap(minSize); largeEnoughGap.Ok() { + return largeEnoughGap + } + } + } else { + for i := n.nrSegments; i >= 0; i-- { + currentGap := usageGapIterator{n, i} + if currentGap.Range().Length() >= minSize { + return currentGap + } + } + } + panic(fmt.Sprintf("invalid maxGap in %v", n)) +} + +// A Iterator is conceptually one of: +// +// - A pointer to a segment in a set; or +// +// - A terminal iterator, which is a sentinel indicating that the end of +// iteration has been reached. +// +// Iterators are copyable values and are meaningfully equality-comparable. The +// zero value of Iterator is a terminal iterator. +// +// Unless otherwise specified, any mutation of a set invalidates all existing +// iterators into the set. +type usageIterator struct { + // node is the node containing the iterated segment. If the iterator is + // terminal, node is nil. + node *usagenode + + // index is the index of the segment in node.keys/values. + index int +} + +// Ok returns true if the iterator is not terminal. All other methods are only +// valid for non-terminal iterators. +func (seg usageIterator) Ok() bool { + return seg.node != nil +} + +// Range returns the iterated segment's range key. +func (seg usageIterator) Range() __generics_imported0.FileRange { + return seg.node.keys[seg.index] +} + +// Start is equivalent to Range().Start, but should be preferred if only the +// start of the range is needed. +func (seg usageIterator) Start() uint64 { + return seg.node.keys[seg.index].Start +} + +// End is equivalent to Range().End, but should be preferred if only the end of +// the range is needed. +func (seg usageIterator) End() uint64 { + return seg.node.keys[seg.index].End +} + +// SetRangeUnchecked mutates the iterated segment's range key. This operation +// does not invalidate any iterators. +// +// Preconditions: +// * r.Length() > 0. +// * The new range must not overlap an existing one: +// * If seg.NextSegment().Ok(), then r.end <= seg.NextSegment().Start(). +// * If seg.PrevSegment().Ok(), then r.start >= seg.PrevSegment().End(). +func (seg usageIterator) SetRangeUnchecked(r __generics_imported0.FileRange) { + seg.node.keys[seg.index] = r +} + +// SetRange mutates the iterated segment's range key. If the new range would +// cause the iterated segment to overlap another segment, or if the new range +// is invalid, SetRange panics. This operation does not invalidate any +// iterators. +func (seg usageIterator) SetRange(r __generics_imported0.FileRange) { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + if prev := seg.PrevSegment(); prev.Ok() && r.Start < prev.End() { + panic(fmt.Sprintf("new segment range %v overlaps segment range %v", r, prev.Range())) + } + if next := seg.NextSegment(); next.Ok() && r.End > next.Start() { + panic(fmt.Sprintf("new segment range %v overlaps segment range %v", r, next.Range())) + } + seg.SetRangeUnchecked(r) +} + +// SetStartUnchecked mutates the iterated segment's start. This operation does +// not invalidate any iterators. +// +// Preconditions: The new start must be valid: +// * start < seg.End() +// * If seg.PrevSegment().Ok(), then start >= seg.PrevSegment().End(). +func (seg usageIterator) SetStartUnchecked(start uint64) { + seg.node.keys[seg.index].Start = start +} + +// SetStart mutates the iterated segment's start. If the new start value would +// cause the iterated segment to overlap another segment, or would result in an +// invalid range, SetStart panics. This operation does not invalidate any +// iterators. +func (seg usageIterator) SetStart(start uint64) { + if start >= seg.End() { + panic(fmt.Sprintf("new start %v would invalidate segment range %v", start, seg.Range())) + } + if prev := seg.PrevSegment(); prev.Ok() && start < prev.End() { + panic(fmt.Sprintf("new start %v would cause segment range %v to overlap segment range %v", start, seg.Range(), prev.Range())) + } + seg.SetStartUnchecked(start) +} + +// SetEndUnchecked mutates the iterated segment's end. This operation does not +// invalidate any iterators. +// +// Preconditions: The new end must be valid: +// * end > seg.Start(). +// * If seg.NextSegment().Ok(), then end <= seg.NextSegment().Start(). +func (seg usageIterator) SetEndUnchecked(end uint64) { + seg.node.keys[seg.index].End = end +} + +// SetEnd mutates the iterated segment's end. If the new end value would cause +// the iterated segment to overlap another segment, or would result in an +// invalid range, SetEnd panics. This operation does not invalidate any +// iterators. +func (seg usageIterator) SetEnd(end uint64) { + if end <= seg.Start() { + panic(fmt.Sprintf("new end %v would invalidate segment range %v", end, seg.Range())) + } + if next := seg.NextSegment(); next.Ok() && end > next.Start() { + panic(fmt.Sprintf("new end %v would cause segment range %v to overlap segment range %v", end, seg.Range(), next.Range())) + } + seg.SetEndUnchecked(end) +} + +// Value returns a copy of the iterated segment's value. +func (seg usageIterator) Value() usageInfo { + return seg.node.values[seg.index] +} + +// ValuePtr returns a pointer to the iterated segment's value. The pointer is +// invalidated if the iterator is invalidated. This operation does not +// invalidate any iterators. +func (seg usageIterator) ValuePtr() *usageInfo { + return &seg.node.values[seg.index] +} + +// SetValue mutates the iterated segment's value. This operation does not +// invalidate any iterators. +func (seg usageIterator) SetValue(val usageInfo) { + seg.node.values[seg.index] = val +} + +// PrevSegment returns the iterated segment's predecessor. If there is no +// preceding segment, PrevSegment returns a terminal iterator. +func (seg usageIterator) PrevSegment() usageIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index].lastSegment() + } + if seg.index > 0 { + return usageIterator{seg.node, seg.index - 1} + } + if seg.node.parent == nil { + return usageIterator{} + } + return usagesegmentBeforePosition(seg.node.parent, seg.node.parentIndex) +} + +// NextSegment returns the iterated segment's successor. If there is no +// succeeding segment, NextSegment returns a terminal iterator. +func (seg usageIterator) NextSegment() usageIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index+1].firstSegment() + } + if seg.index < seg.node.nrSegments-1 { + return usageIterator{seg.node, seg.index + 1} + } + if seg.node.parent == nil { + return usageIterator{} + } + return usagesegmentAfterPosition(seg.node.parent, seg.node.parentIndex) +} + +// PrevGap returns the gap immediately before the iterated segment. +func (seg usageIterator) PrevGap() usageGapIterator { + if seg.node.hasChildren { + + return seg.node.children[seg.index].lastSegment().NextGap() + } + return usageGapIterator{seg.node, seg.index} +} + +// NextGap returns the gap immediately after the iterated segment. +func (seg usageIterator) NextGap() usageGapIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index+1].firstSegment().PrevGap() + } + return usageGapIterator{seg.node, seg.index + 1} +} + +// PrevNonEmpty returns the iterated segment's predecessor if it is adjacent, +// or the gap before the iterated segment otherwise. If seg.Start() == +// Functions.MinKey(), PrevNonEmpty will return two terminal iterators. +// Otherwise, exactly one of the iterators returned by PrevNonEmpty will be +// non-terminal. +func (seg usageIterator) PrevNonEmpty() (usageIterator, usageGapIterator) { + gap := seg.PrevGap() + if gap.Range().Length() != 0 { + return usageIterator{}, gap + } + return gap.PrevSegment(), usageGapIterator{} +} + +// NextNonEmpty returns the iterated segment's successor if it is adjacent, or +// the gap after the iterated segment otherwise. If seg.End() == +// Functions.MaxKey(), NextNonEmpty will return two terminal iterators. +// Otherwise, exactly one of the iterators returned by NextNonEmpty will be +// non-terminal. +func (seg usageIterator) NextNonEmpty() (usageIterator, usageGapIterator) { + gap := seg.NextGap() + if gap.Range().Length() != 0 { + return usageIterator{}, gap + } + return gap.NextSegment(), usageGapIterator{} +} + +// A GapIterator is conceptually one of: +// +// - A pointer to a position between two segments, before the first segment, or +// after the last segment in a set, called a *gap*; or +// +// - A terminal iterator, which is a sentinel indicating that the end of +// iteration has been reached. +// +// Note that the gap between two adjacent segments exists (iterators to it are +// non-terminal), but has a length of zero. GapIterator.IsEmpty returns true +// for such gaps. An empty set contains a single gap, spanning the entire range +// of the set's keys. +// +// GapIterators are copyable values and are meaningfully equality-comparable. +// The zero value of GapIterator is a terminal iterator. +// +// Unless otherwise specified, any mutation of a set invalidates all existing +// iterators into the set. +type usageGapIterator struct { + // The representation of a GapIterator is identical to that of an Iterator, + // except that index corresponds to positions between segments in the same + // way as for node.children (see comment for node.nrSegments). + node *usagenode + index int +} + +// Ok returns true if the iterator is not terminal. All other methods are only +// valid for non-terminal iterators. +func (gap usageGapIterator) Ok() bool { + return gap.node != nil +} + +// Range returns the range spanned by the iterated gap. +func (gap usageGapIterator) Range() __generics_imported0.FileRange { + return __generics_imported0.FileRange{gap.Start(), gap.End()} +} + +// Start is equivalent to Range().Start, but should be preferred if only the +// start of the range is needed. +func (gap usageGapIterator) Start() uint64 { + if ps := gap.PrevSegment(); ps.Ok() { + return ps.End() + } + return usageSetFunctions{}.MinKey() +} + +// End is equivalent to Range().End, but should be preferred if only the end of +// the range is needed. +func (gap usageGapIterator) End() uint64 { + if ns := gap.NextSegment(); ns.Ok() { + return ns.Start() + } + return usageSetFunctions{}.MaxKey() +} + +// IsEmpty returns true if the iterated gap is empty (that is, the "gap" is +// between two adjacent segments.) +func (gap usageGapIterator) IsEmpty() bool { + return gap.Range().Length() == 0 +} + +// PrevSegment returns the segment immediately before the iterated gap. If no +// such segment exists, PrevSegment returns a terminal iterator. +func (gap usageGapIterator) PrevSegment() usageIterator { + return usagesegmentBeforePosition(gap.node, gap.index) +} + +// NextSegment returns the segment immediately after the iterated gap. If no +// such segment exists, NextSegment returns a terminal iterator. +func (gap usageGapIterator) NextSegment() usageIterator { + return usagesegmentAfterPosition(gap.node, gap.index) +} + +// PrevGap returns the iterated gap's predecessor. If no such gap exists, +// PrevGap returns a terminal iterator. +func (gap usageGapIterator) PrevGap() usageGapIterator { + seg := gap.PrevSegment() + if !seg.Ok() { + return usageGapIterator{} + } + return seg.PrevGap() +} + +// NextGap returns the iterated gap's successor. If no such gap exists, NextGap +// returns a terminal iterator. +func (gap usageGapIterator) NextGap() usageGapIterator { + seg := gap.NextSegment() + if !seg.Ok() { + return usageGapIterator{} + } + return seg.NextGap() +} + +// NextLargeEnoughGap returns the iterated gap's first next gap with larger +// length than minSize. If not found, return a terminal gap iterator (does NOT +// include this gap itself). +// +// Precondition: trackGaps must be 1. +func (gap usageGapIterator) NextLargeEnoughGap(minSize uint64) usageGapIterator { + if usagetrackGaps != 1 { + panic("set is not tracking gaps") + } + if gap.node != nil && gap.node.hasChildren && gap.index == gap.node.nrSegments { + + gap.node = gap.NextSegment().node + gap.index = 0 + return gap.nextLargeEnoughGapHelper(minSize) + } + return gap.nextLargeEnoughGapHelper(minSize) +} + +// nextLargeEnoughGapHelper is the helper function used by NextLargeEnoughGap +// to do the real recursions. +// +// Preconditions: gap is NOT the trailing gap of a non-leaf node. +func (gap usageGapIterator) nextLargeEnoughGapHelper(minSize uint64) usageGapIterator { + + for gap.node != nil && + (gap.node.maxGap.Get() < minSize || (!gap.node.hasChildren && gap.index == gap.node.nrSegments)) { + gap.node, gap.index = gap.node.parent, gap.node.parentIndex + } + + if gap.node == nil { + return usageGapIterator{} + } + + gap.index++ + for gap.index <= gap.node.nrSegments { + if gap.node.hasChildren { + if largeEnoughGap := gap.node.children[gap.index].searchFirstLargeEnoughGap(minSize); largeEnoughGap.Ok() { + return largeEnoughGap + } + } else { + if gap.Range().Length() >= minSize { + return gap + } + } + gap.index++ + } + gap.node, gap.index = gap.node.parent, gap.node.parentIndex + if gap.node != nil && gap.index == gap.node.nrSegments { + + gap.node, gap.index = gap.node.parent, gap.node.parentIndex + } + return gap.nextLargeEnoughGapHelper(minSize) +} + +// PrevLargeEnoughGap returns the iterated gap's first prev gap with larger or +// equal length than minSize. If not found, return a terminal gap iterator +// (does NOT include this gap itself). +// +// Precondition: trackGaps must be 1. +func (gap usageGapIterator) PrevLargeEnoughGap(minSize uint64) usageGapIterator { + if usagetrackGaps != 1 { + panic("set is not tracking gaps") + } + if gap.node != nil && gap.node.hasChildren && gap.index == 0 { + + gap.node = gap.PrevSegment().node + gap.index = gap.node.nrSegments + return gap.prevLargeEnoughGapHelper(minSize) + } + return gap.prevLargeEnoughGapHelper(minSize) +} + +// prevLargeEnoughGapHelper is the helper function used by PrevLargeEnoughGap +// to do the real recursions. +// +// Preconditions: gap is NOT the first gap of a non-leaf node. +func (gap usageGapIterator) prevLargeEnoughGapHelper(minSize uint64) usageGapIterator { + + for gap.node != nil && + (gap.node.maxGap.Get() < minSize || (!gap.node.hasChildren && gap.index == 0)) { + gap.node, gap.index = gap.node.parent, gap.node.parentIndex + } + + if gap.node == nil { + return usageGapIterator{} + } + + gap.index-- + for gap.index >= 0 { + if gap.node.hasChildren { + if largeEnoughGap := gap.node.children[gap.index].searchLastLargeEnoughGap(minSize); largeEnoughGap.Ok() { + return largeEnoughGap + } + } else { + if gap.Range().Length() >= minSize { + return gap + } + } + gap.index-- + } + gap.node, gap.index = gap.node.parent, gap.node.parentIndex + if gap.node != nil && gap.index == 0 { + + gap.node, gap.index = gap.node.parent, gap.node.parentIndex + } + return gap.prevLargeEnoughGapHelper(minSize) +} + +// segmentBeforePosition returns the predecessor segment of the position given +// by n.children[i], which may or may not contain a child. If no such segment +// exists, segmentBeforePosition returns a terminal iterator. +func usagesegmentBeforePosition(n *usagenode, i int) usageIterator { + for i == 0 { + if n.parent == nil { + return usageIterator{} + } + n, i = n.parent, n.parentIndex + } + return usageIterator{n, i - 1} +} + +// segmentAfterPosition returns the successor segment of the position given by +// n.children[i], which may or may not contain a child. If no such segment +// exists, segmentAfterPosition returns a terminal iterator. +func usagesegmentAfterPosition(n *usagenode, i int) usageIterator { + for i == n.nrSegments { + if n.parent == nil { + return usageIterator{} + } + n, i = n.parent, n.parentIndex + } + return usageIterator{n, i} +} + +func usagezeroValueSlice(slice []usageInfo) { + + for i := range slice { + usageSetFunctions{}.ClearValue(&slice[i]) + } +} + +func usagezeroNodeSlice(slice []*usagenode) { + for i := range slice { + slice[i] = nil + } +} + +// String stringifies a Set for debugging. +func (s *usageSet) String() string { + return s.root.String() +} + +// String stringifies a node (and all of its children) for debugging. +func (n *usagenode) String() string { + var buf bytes.Buffer + n.writeDebugString(&buf, "") + return buf.String() +} + +func (n *usagenode) writeDebugString(buf *bytes.Buffer, prefix string) { + if n.hasChildren != (n.nrSegments > 0 && n.children[0] != nil) { + buf.WriteString(prefix) + buf.WriteString(fmt.Sprintf("WARNING: inconsistent value of hasChildren: got %v, want %v\n", n.hasChildren, !n.hasChildren)) + } + for i := 0; i < n.nrSegments; i++ { + if child := n.children[i]; child != nil { + cprefix := fmt.Sprintf("%s- % 3d ", prefix, i) + if child.parent != n || child.parentIndex != i { + buf.WriteString(cprefix) + buf.WriteString(fmt.Sprintf("WARNING: inconsistent linkage to parent: got (%p, %d), want (%p, %d)\n", child.parent, child.parentIndex, n, i)) + } + child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, i)) + } + buf.WriteString(prefix) + if n.hasChildren { + if usagetrackGaps != 0 { + buf.WriteString(fmt.Sprintf("- % 3d: %v => %v, maxGap: %d\n", i, n.keys[i], n.values[i], n.maxGap.Get())) + } else { + buf.WriteString(fmt.Sprintf("- % 3d: %v => %v\n", i, n.keys[i], n.values[i])) + } + } else { + buf.WriteString(fmt.Sprintf("- % 3d: %v => %v\n", i, n.keys[i], n.values[i])) + } + } + if child := n.children[n.nrSegments]; child != nil { + child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, n.nrSegments)) + } +} + +// SegmentDataSlices represents segments from a set as slices of start, end, and +// values. SegmentDataSlices is primarily used as an intermediate representation +// for save/restore and the layout here is optimized for that. +// +// +stateify savable +type usageSegmentDataSlices struct { + Start []uint64 + End []uint64 + Values []usageInfo +} + +// ExportSortedSlices returns a copy of all segments in the given set, in +// ascending key order. +func (s *usageSet) ExportSortedSlices() *usageSegmentDataSlices { + var sds usageSegmentDataSlices + for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + sds.Start = append(sds.Start, seg.Start()) + sds.End = append(sds.End, seg.End()) + sds.Values = append(sds.Values, seg.Value()) + } + sds.Start = sds.Start[:len(sds.Start):len(sds.Start)] + sds.End = sds.End[:len(sds.End):len(sds.End)] + sds.Values = sds.Values[:len(sds.Values):len(sds.Values)] + return &sds +} + +// ImportSortedSlices initializes the given set from the given slice. +// +// Preconditions: +// * s must be empty. +// * sds must represent a valid set (the segments in sds must have valid +// lengths that do not overlap). +// * The segments in sds must be sorted in ascending key order. +func (s *usageSet) ImportSortedSlices(sds *usageSegmentDataSlices) error { + if !s.IsEmpty() { + return fmt.Errorf("cannot import into non-empty set %v", s) + } + gap := s.FirstGap() + for i := range sds.Start { + r := __generics_imported0.FileRange{sds.Start[i], sds.End[i]} + if !gap.Range().IsSupersetOf(r) { + return fmt.Errorf("segment overlaps a preceding segment or is incorrectly sorted: [%d, %d) => %v", sds.Start[i], sds.End[i], sds.Values[i]) + } + gap = s.InsertWithoutMerging(gap, r, sds.Values[i]).NextGap() + } + return nil +} + +// segmentTestCheck returns an error if s is incorrectly sorted, does not +// contain exactly expectedSegments segments, or contains a segment which +// fails the passed check. +// +// This should be used only for testing, and has been added to this package for +// templating convenience. +func (s *usageSet) segmentTestCheck(expectedSegments int, segFunc func(int, __generics_imported0.FileRange, usageInfo) error) error { + havePrev := false + prev := uint64(0) + nrSegments := 0 + for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + next := seg.Start() + if havePrev && prev >= next { + return fmt.Errorf("incorrect order: key %d (segment %d) >= key %d (segment %d)", prev, nrSegments-1, next, nrSegments) + } + if segFunc != nil { + if err := segFunc(nrSegments, seg.Range(), seg.Value()); err != nil { + return err + } + } + prev = next + havePrev = true + nrSegments++ + } + if nrSegments != expectedSegments { + return fmt.Errorf("incorrect number of segments: got %d, wanted %d", nrSegments, expectedSegments) + } + return nil +} + +// countSegments counts the number of segments in the set. +// +// Similar to Check, this should only be used for testing. +func (s *usageSet) countSegments() (segments int) { + for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + segments++ + } + return segments +} +func (s *usageSet) saveRoot() *usageSegmentDataSlices { + return s.ExportSortedSlices() +} + +func (s *usageSet) loadRoot(sds *usageSegmentDataSlices) { + if err := s.ImportSortedSlices(sds); err != nil { + panic(err) + } +} diff --git a/pkg/sentry/platform/BUILD b/pkg/sentry/platform/BUILD deleted file mode 100644 index 7125657b3..000000000 --- a/pkg/sentry/platform/BUILD +++ /dev/null @@ -1,23 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "platform", - srcs = [ - "context.go", - "mmap_min_addr.go", - "platform.go", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/hostarch", - "//pkg/seccomp", - "//pkg/sentry/arch", - "//pkg/sentry/hostmm", - "//pkg/sentry/memmap", - "//pkg/usermem", - ], -) diff --git a/pkg/sentry/platform/interrupt/BUILD b/pkg/sentry/platform/interrupt/BUILD deleted file mode 100644 index 83b385f14..000000000 --- a/pkg/sentry/platform/interrupt/BUILD +++ /dev/null @@ -1,19 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "interrupt", - srcs = [ - "interrupt.go", - ], - visibility = ["//pkg/sentry:internal"], - deps = ["//pkg/sync"], -) - -go_test( - name = "interrupt_test", - size = "small", - srcs = ["interrupt_test.go"], - library = ":interrupt", -) diff --git a/pkg/sentry/platform/interrupt/interrupt_state_autogen.go b/pkg/sentry/platform/interrupt/interrupt_state_autogen.go new file mode 100644 index 000000000..1336e5f01 --- /dev/null +++ b/pkg/sentry/platform/interrupt/interrupt_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package interrupt diff --git a/pkg/sentry/platform/interrupt/interrupt_test.go b/pkg/sentry/platform/interrupt/interrupt_test.go deleted file mode 100644 index 0ecdf6e7a..000000000 --- a/pkg/sentry/platform/interrupt/interrupt_test.go +++ /dev/null @@ -1,99 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package interrupt - -import ( - "testing" -) - -type countingReceiver struct { - interrupts int -} - -// NotifyInterrupt implements Receiver.NotifyInterrupt. -func (r *countingReceiver) NotifyInterrupt() { - r.interrupts++ -} - -func TestSingleInterruptBeforeEnable(t *testing.T) { - var ( - f Forwarder - r countingReceiver - ) - f.NotifyInterrupt() - // The interrupt should cause the first Enable to fail. - if f.Enable(&r) { - f.Disable() - t.Fatalf("Enable: got true, wanted false") - } - // The failing Enable "acknowledges" the interrupt, allowing future Enables - // to succeed. - if !f.Enable(&r) { - t.Fatalf("Enable: got false, wanted true") - } - f.Disable() -} - -func TestMultipleInterruptsBeforeEnable(t *testing.T) { - var ( - f Forwarder - r countingReceiver - ) - f.NotifyInterrupt() - f.NotifyInterrupt() - // The interrupts should cause the first Enable to fail. - if f.Enable(&r) { - f.Disable() - t.Fatalf("Enable: got true, wanted false") - } - // Interrupts are deduplicated while the Forwarder is disabled, so the - // failing Enable "acknowledges" all interrupts, allowing future Enables to - // succeed. - if !f.Enable(&r) { - t.Fatalf("Enable: got false, wanted true") - } - f.Disable() -} - -func TestSingleInterruptAfterEnable(t *testing.T) { - var ( - f Forwarder - r countingReceiver - ) - if !f.Enable(&r) { - t.Fatalf("Enable: got false, wanted true") - } - defer f.Disable() - f.NotifyInterrupt() - if r.interrupts != 1 { - t.Errorf("interrupts: got %d, wanted 1", r.interrupts) - } -} - -func TestMultipleInterruptsAfterEnable(t *testing.T) { - var ( - f Forwarder - r countingReceiver - ) - if !f.Enable(&r) { - t.Fatalf("Enable: got false, wanted true") - } - defer f.Disable() - f.NotifyInterrupt() - f.NotifyInterrupt() - if r.interrupts != 2 { - t.Errorf("interrupts: got %d, wanted 2", r.interrupts) - } -} diff --git a/pkg/sentry/platform/kvm/BUILD b/pkg/sentry/platform/kvm/BUILD deleted file mode 100644 index 8a490b3de..000000000 --- a/pkg/sentry/platform/kvm/BUILD +++ /dev/null @@ -1,101 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "kvm", - srcs = [ - "address_space.go", - "address_space_amd64.go", - "address_space_arm64.go", - "bluepill.go", - "bluepill_allocator.go", - "bluepill_amd64.go", - "bluepill_amd64_unsafe.go", - "bluepill_arm64.go", - "bluepill_arm64.s", - "bluepill_arm64_unsafe.go", - "bluepill_fault.go", - "bluepill_impl_amd64.s", - "bluepill_unsafe.go", - "context.go", - "filters_amd64.go", - "filters_arm64.go", - "kvm.go", - "kvm_amd64.go", - "kvm_amd64_unsafe.go", - "kvm_arm64.go", - "kvm_arm64_unsafe.go", - "kvm_const.go", - "kvm_const_arm64.go", - "machine.go", - "machine_amd64.go", - "machine_amd64_unsafe.go", - "machine_arm64.go", - "machine_arm64_unsafe.go", - "machine_unsafe.go", - "physical_map.go", - "physical_map_amd64.go", - "physical_map_arm64.go", - "virtual_map.go", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/atomicbitops", - "//pkg/context", - "//pkg/cpuid", - "//pkg/hostarch", - "//pkg/log", - "//pkg/procid", - "//pkg/ring0", - "//pkg/ring0/pagetables", - "//pkg/safecopy", - "//pkg/seccomp", - "//pkg/sentry/arch", - "//pkg/sentry/arch/fpu", - "//pkg/sentry/memmap", - "//pkg/sentry/platform", - "//pkg/sentry/platform/interrupt", - "//pkg/sentry/time", - "//pkg/sync", - "@org_golang_x_sys//unix:go_default_library", - ], -) - -go_test( - name = "kvm_test", - srcs = [ - "kvm_amd64_test.go", - "kvm_amd64_test.s", - "kvm_arm64_test.go", - "kvm_test.go", - "virtual_map_test.go", - ], - library = ":kvm", - tags = [ - "manual", - "nogotsan", - "requires-kvm", - ], - deps = [ - "//pkg/abi/linux", - "//pkg/hostarch", - "//pkg/ring0", - "//pkg/ring0/pagetables", - "//pkg/sentry/arch", - "//pkg/sentry/arch/fpu", - "//pkg/sentry/platform", - "//pkg/sentry/platform/kvm/testutil", - "//pkg/sentry/time", - "@org_golang_x_sys//unix:go_default_library", - ], -) - -genrule( - name = "bluepill_impl_amd64", - srcs = ["bluepill_amd64.s"], - outs = ["bluepill_impl_amd64.s"], - cmd = "(echo -e '// build +amd64\\n' && $(location //pkg/ring0/gen_offsets) && cat $(SRCS)) > $@", - tools = ["//pkg/ring0/gen_offsets"], -) diff --git a/pkg/sentry/platform/kvm/bluepill_amd64.s b/pkg/sentry/platform/kvm/bluepill_impl_amd64.s index c2a1dca11..99f254342 100644 --- a/pkg/sentry/platform/kvm/bluepill_amd64.s +++ b/pkg/sentry/platform/kvm/bluepill_impl_amd64.s @@ -1,3 +1,73 @@ +// build +amd64 + +// Automatically generated, do not edit. + +// CPU offsets. +#define CPU_REGISTERS 0x28 +#define CPU_ERROR_CODE 0x10 +#define CPU_ERROR_TYPE 0x18 +#define CPU_ENTRY 0x20 + +// CPU entry offsets. +#define ENTRY_SCRATCH0 0x100 +#define ENTRY_STACK_TOP 0x108 +#define ENTRY_CPU_SELF 0x110 +#define ENTRY_KERNEL_CR3 0x118 + +// Bits. +#define _RFLAGS_IF 0x200 +#define _RFLAGS_IOPL0 0x1000 +#define _KERNEL_FLAGS 0x02 + +// Vectors. +#define DivideByZero 0x00 +#define Debug 0x01 +#define NMI 0x02 +#define Breakpoint 0x03 +#define Overflow 0x04 +#define BoundRangeExceeded 0x05 +#define InvalidOpcode 0x06 +#define DeviceNotAvailable 0x07 +#define DoubleFault 0x08 +#define CoprocessorSegmentOverrun 0x09 +#define InvalidTSS 0x0a +#define SegmentNotPresent 0x0b +#define StackSegmentFault 0x0c +#define GeneralProtectionFault 0x0d +#define PageFault 0x0e +#define X87FloatingPointException 0x10 +#define AlignmentCheck 0x11 +#define MachineCheck 0x12 +#define SIMDFloatingPointException 0x13 +#define VirtualizationException 0x14 +#define SecurityException 0x1e +#define SyscallInt80 0x80 +#define Syscall 0x100 + +// Ptrace registers. +#define PTRACE_R15 0x00 +#define PTRACE_R14 0x08 +#define PTRACE_R13 0x10 +#define PTRACE_R12 0x18 +#define PTRACE_RBP 0x20 +#define PTRACE_RBX 0x28 +#define PTRACE_R11 0x30 +#define PTRACE_R10 0x38 +#define PTRACE_R9 0x40 +#define PTRACE_R8 0x48 +#define PTRACE_RAX 0x50 +#define PTRACE_RCX 0x58 +#define PTRACE_RDX 0x60 +#define PTRACE_RSI 0x68 +#define PTRACE_RDI 0x70 +#define PTRACE_ORIGRAX 0x78 +#define PTRACE_RIP 0x80 +#define PTRACE_CS 0x88 +#define PTRACE_FLAGS 0x90 +#define PTRACE_RSP 0x98 +#define PTRACE_SS 0xa0 +#define PTRACE_FS_BASE 0xa8 +#define PTRACE_GS_BASE 0xb0 // Copyright 2018 The gVisor Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/pkg/sentry/platform/kvm/kvm_amd64_state_autogen.go b/pkg/sentry/platform/kvm/kvm_amd64_state_autogen.go new file mode 100644 index 000000000..1bc956971 --- /dev/null +++ b/pkg/sentry/platform/kvm/kvm_amd64_state_autogen.go @@ -0,0 +1,6 @@ +// automatically generated by stateify. + +//go:build amd64 && amd64 && amd64 +// +build amd64,amd64,amd64 + +package kvm diff --git a/pkg/sentry/platform/kvm/kvm_amd64_test.go b/pkg/sentry/platform/kvm/kvm_amd64_test.go deleted file mode 100644 index c3fbbdc75..000000000 --- a/pkg/sentry/platform/kvm/kvm_amd64_test.go +++ /dev/null @@ -1,90 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//go:build amd64 -// +build amd64 - -package kvm - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/ring0" - "gvisor.dev/gvisor/pkg/ring0/pagetables" - "gvisor.dev/gvisor/pkg/sentry/arch" - "gvisor.dev/gvisor/pkg/sentry/platform" - "gvisor.dev/gvisor/pkg/sentry/platform/kvm/testutil" -) - -func TestSegments(t *testing.T) { - applicationTest(t, true, testutil.AddrOfTwiddleSegments(), func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool { - testutil.SetTestSegments(regs) - for { - var si linux.SignalInfo - if _, err := c.SwitchToUser(ring0.SwitchOpts{ - Registers: regs, - FloatingPointState: &dummyFPState, - PageTables: pt, - FullRestore: true, - }, &si); err == platform.ErrContextInterrupt { - continue // Retry. - } else if err != nil { - t.Errorf("application segment check with full restore got unexpected error: %v", err) - } - if err := testutil.CheckTestSegments(regs); err != nil { - t.Errorf("application segment check with full restore failed: %v", err) - } - break // Done. - } - return false - }) -} - -// stmxcsr reads the MXCSR control and status register. -func stmxcsr(addr *uint32) - -func TestMXCSR(t *testing.T) { - applicationTest(t, true, testutil.AddrOfSyscallLoop(), func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool { - var si linux.SignalInfo - switchOpts := ring0.SwitchOpts{ - Registers: regs, - FloatingPointState: &dummyFPState, - PageTables: pt, - FullRestore: true, - } - - const mxcsrControllMask = uint32(0x1f80) - mxcsrBefore := uint32(0) - mxcsrAfter := uint32(0) - stmxcsr(&mxcsrBefore) - if mxcsrBefore == 0 { - // goruntime sets mxcsr to 0x1f80 and it never changes - // the control configuration. - panic("mxcsr is zero") - } - switchOpts.FloatingPointState.SetMXCSR(0) - if _, err := c.SwitchToUser( - switchOpts, &si); err == platform.ErrContextInterrupt { - return true // Retry. - } else if err != nil { - t.Errorf("application syscall failed: %v", err) - } - stmxcsr(&mxcsrAfter) - if mxcsrAfter&mxcsrControllMask != mxcsrBefore&mxcsrControllMask { - t.Errorf("mxcsr = %x (expected %x)", mxcsrBefore, mxcsrAfter) - } - return false - }) -} diff --git a/pkg/sentry/platform/kvm/kvm_amd64_test.s b/pkg/sentry/platform/kvm/kvm_amd64_test.s deleted file mode 100644 index 8e9079867..000000000 --- a/pkg/sentry/platform/kvm/kvm_amd64_test.s +++ /dev/null @@ -1,21 +0,0 @@ -// Copyright 2021 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "textflag.h" - -// stmxcsr reads the MXCSR control and status register. -TEXT ·stmxcsr(SB),NOSPLIT,$0-8 - MOVQ addr+0(FP), SI - STMXCSR (SI) - RET diff --git a/pkg/sentry/platform/kvm/kvm_amd64_unsafe_state_autogen.go b/pkg/sentry/platform/kvm/kvm_amd64_unsafe_state_autogen.go new file mode 100644 index 000000000..1bc956971 --- /dev/null +++ b/pkg/sentry/platform/kvm/kvm_amd64_unsafe_state_autogen.go @@ -0,0 +1,6 @@ +// automatically generated by stateify. + +//go:build amd64 && amd64 && amd64 +// +build amd64,amd64,amd64 + +package kvm diff --git a/pkg/sentry/platform/kvm/kvm_arm64_state_autogen.go b/pkg/sentry/platform/kvm/kvm_arm64_state_autogen.go new file mode 100644 index 000000000..fc675a238 --- /dev/null +++ b/pkg/sentry/platform/kvm/kvm_arm64_state_autogen.go @@ -0,0 +1,6 @@ +// automatically generated by stateify. + +//go:build arm64 && arm64 && arm64 +// +build arm64,arm64,arm64 + +package kvm diff --git a/pkg/sentry/platform/kvm/kvm_arm64_test.go b/pkg/sentry/platform/kvm/kvm_arm64_test.go deleted file mode 100644 index b53e354da..000000000 --- a/pkg/sentry/platform/kvm/kvm_arm64_test.go +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//go:build arm64 -// +build arm64 - -package kvm - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/sentry/platform/kvm/testutil" -) - -func TestKernelTLS(t *testing.T) { - bluepillTest(t, func(c *vCPU) { - if !testutil.TLSWorks() { - t.Errorf("tls does not work, and it should!") - } - }) -} diff --git a/pkg/sentry/platform/kvm/kvm_arm64_unsafe_state_autogen.go b/pkg/sentry/platform/kvm/kvm_arm64_unsafe_state_autogen.go new file mode 100644 index 000000000..fc675a238 --- /dev/null +++ b/pkg/sentry/platform/kvm/kvm_arm64_unsafe_state_autogen.go @@ -0,0 +1,6 @@ +// automatically generated by stateify. + +//go:build arm64 && arm64 && arm64 +// +build arm64,arm64,arm64 + +package kvm diff --git a/pkg/sentry/platform/kvm/kvm_state_autogen.go b/pkg/sentry/platform/kvm/kvm_state_autogen.go new file mode 100644 index 000000000..8d85b96d0 --- /dev/null +++ b/pkg/sentry/platform/kvm/kvm_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package kvm diff --git a/pkg/sentry/platform/kvm/kvm_test.go b/pkg/sentry/platform/kvm/kvm_test.go deleted file mode 100644 index 3a30286e2..000000000 --- a/pkg/sentry/platform/kvm/kvm_test.go +++ /dev/null @@ -1,529 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package kvm - -import ( - "math/rand" - "reflect" - "sync/atomic" - "testing" - "time" - - "golang.org/x/sys/unix" - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/hostarch" - "gvisor.dev/gvisor/pkg/ring0" - "gvisor.dev/gvisor/pkg/ring0/pagetables" - "gvisor.dev/gvisor/pkg/sentry/arch" - "gvisor.dev/gvisor/pkg/sentry/arch/fpu" - "gvisor.dev/gvisor/pkg/sentry/platform" - "gvisor.dev/gvisor/pkg/sentry/platform/kvm/testutil" - ktime "gvisor.dev/gvisor/pkg/sentry/time" -) - -var dummyFPState = fpu.NewState() - -type testHarness interface { - Errorf(format string, args ...interface{}) - Fatalf(format string, args ...interface{}) -} - -func kvmTest(t testHarness, setup func(*KVM), fn func(*vCPU) bool) { - // Create the machine. - deviceFile, err := OpenDevice() - if err != nil { - t.Fatalf("error opening device file: %v", err) - } - k, err := New(deviceFile) - if err != nil { - t.Fatalf("error creating KVM instance: %v", err) - } - defer k.machine.Destroy() - - // Call additional setup. - if setup != nil { - setup(k) - } - - var c *vCPU // For recovery. - defer func() { - redpill() - if c != nil { - k.machine.Put(c) - } - }() - for { - c = k.machine.Get() - if !fn(c) { - break - } - - // We put the vCPU here and clear the value so that the - // deferred recovery will not re-put it above. - k.machine.Put(c) - c = nil - } -} - -func bluepillTest(t testHarness, fn func(*vCPU)) { - kvmTest(t, nil, func(c *vCPU) bool { - bluepill(c) - fn(c) - return false - }) -} - -func TestKernelSyscall(t *testing.T) { - bluepillTest(t, func(c *vCPU) { - redpill() // Leave guest mode. - if got := atomic.LoadUint32(&c.state); got != vCPUUser { - t.Errorf("vCPU not in ready state: got %v", got) - } - }) -} - -func hostFault() { - defer func() { - recover() - }() - var foo *int - *foo = 0 -} - -func TestKernelFault(t *testing.T) { - hostFault() // Ensure recovery works. - bluepillTest(t, func(c *vCPU) { - hostFault() - if got := atomic.LoadUint32(&c.state); got != vCPUUser { - t.Errorf("vCPU not in ready state: got %v", got) - } - }) -} - -func TestKernelFloatingPoint(t *testing.T) { - bluepillTest(t, func(c *vCPU) { - if !testutil.FloatingPointWorks() { - t.Errorf("floating point does not work, and it should!") - } - }) -} - -func applicationTest(t testHarness, useHostMappings bool, targetFn uintptr, fn func(*vCPU, *arch.Registers, *pagetables.PageTables) bool) { - // Initialize registers & page tables. - var ( - regs arch.Registers - pt *pagetables.PageTables - ) - testutil.SetTestTarget(®s, targetFn) - - kvmTest(t, func(k *KVM) { - // Create new page tables. - as, _, err := k.NewAddressSpace(nil /* invalidator */) - if err != nil { - t.Fatalf("can't create new address space: %v", err) - } - pt = as.(*addressSpace).pageTables - - if useHostMappings { - // Apply the physical mappings to these page tables. - // (This is normally dangerous, since they point to - // physical pages that may not exist. This shouldn't be - // done for regular user code, but is fine for test - // purposes.) - applyPhysicalRegions(func(pr physicalRegion) bool { - pt.Map(hostarch.Addr(pr.virtual), pr.length, pagetables.MapOpts{ - AccessType: hostarch.AnyAccess, - User: true, - }, pr.physical) - return true // Keep iterating. - }) - } - }, func(c *vCPU) bool { - // Invoke the function with the extra data. - return fn(c, ®s, pt) - }) -} - -func TestApplicationSyscall(t *testing.T) { - applicationTest(t, true, testutil.AddrOfSyscallLoop(), func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool { - var si linux.SignalInfo - if _, err := c.SwitchToUser(ring0.SwitchOpts{ - Registers: regs, - FloatingPointState: &dummyFPState, - PageTables: pt, - FullRestore: true, - }, &si); err == platform.ErrContextInterrupt { - return true // Retry. - } else if err != nil { - t.Errorf("application syscall with full restore failed: %v", err) - } - return false - }) - applicationTest(t, true, testutil.AddrOfSyscallLoop(), func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool { - var si linux.SignalInfo - if _, err := c.SwitchToUser(ring0.SwitchOpts{ - Registers: regs, - FloatingPointState: &dummyFPState, - PageTables: pt, - }, &si); err == platform.ErrContextInterrupt { - return true // Retry. - } else if err != nil { - t.Errorf("application syscall with partial restore failed: %v", err) - } - return false - }) -} - -func TestApplicationFault(t *testing.T) { - applicationTest(t, true, testutil.AddrOfTouch(), func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool { - testutil.SetTouchTarget(regs, nil) // Cause fault. - var si linux.SignalInfo - if _, err := c.SwitchToUser(ring0.SwitchOpts{ - Registers: regs, - FloatingPointState: &dummyFPState, - PageTables: pt, - FullRestore: true, - }, &si); err == platform.ErrContextInterrupt { - return true // Retry. - } else if err != platform.ErrContextSignal || si.Signo != int32(unix.SIGSEGV) { - t.Errorf("application fault with full restore got (%v, %v), expected (%v, SIGSEGV)", err, si, platform.ErrContextSignal) - } - return false - }) - applicationTest(t, true, testutil.AddrOfTouch(), func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool { - testutil.SetTouchTarget(regs, nil) // Cause fault. - var si linux.SignalInfo - if _, err := c.SwitchToUser(ring0.SwitchOpts{ - Registers: regs, - FloatingPointState: &dummyFPState, - PageTables: pt, - }, &si); err == platform.ErrContextInterrupt { - return true // Retry. - } else if err != platform.ErrContextSignal || si.Signo != int32(unix.SIGSEGV) { - t.Errorf("application fault with partial restore got (%v, %v), expected (%v, SIGSEGV)", err, si, platform.ErrContextSignal) - } - return false - }) -} - -func TestRegistersSyscall(t *testing.T) { - applicationTest(t, true, testutil.AddrOfTwiddleRegsSyscall(), func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool { - testutil.SetTestRegs(regs) // Fill values for all registers. - for { - var si linux.SignalInfo - if _, err := c.SwitchToUser(ring0.SwitchOpts{ - Registers: regs, - FloatingPointState: &dummyFPState, - PageTables: pt, - }, &si); err == platform.ErrContextInterrupt { - continue // Retry. - } else if err != nil { - t.Errorf("application register check with partial restore got unexpected error: %v", err) - } - if err := testutil.CheckTestRegs(regs, false); err != nil { - t.Errorf("application register check with partial restore failed: %v", err) - } - break // Done. - } - return false - }) -} - -func TestRegistersFault(t *testing.T) { - applicationTest(t, true, testutil.AddrOfTwiddleRegsFault(), func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool { - testutil.SetTestRegs(regs) // Fill values for all registers. - for { - var si linux.SignalInfo - if _, err := c.SwitchToUser(ring0.SwitchOpts{ - Registers: regs, - FloatingPointState: &dummyFPState, - PageTables: pt, - FullRestore: true, - }, &si); err == platform.ErrContextInterrupt { - continue // Retry. - } else if err != platform.ErrContextSignal || si.Signo != int32(unix.SIGSEGV) { - t.Errorf("application register check with full restore got unexpected error: %v", err) - } - if err := testutil.CheckTestRegs(regs, true); err != nil { - t.Errorf("application register check with full restore failed: %v", err) - } - break // Done. - } - return false - }) -} - -func TestBounce(t *testing.T) { - applicationTest(t, true, testutil.AddrOfSpinLoop(), func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool { - go func() { - time.Sleep(time.Millisecond) - c.BounceToKernel() - }() - var si linux.SignalInfo - if _, err := c.SwitchToUser(ring0.SwitchOpts{ - Registers: regs, - FloatingPointState: &dummyFPState, - PageTables: pt, - }, &si); err != platform.ErrContextInterrupt { - t.Errorf("application partial restore: got %v, wanted %v", err, platform.ErrContextInterrupt) - } - return false - }) - applicationTest(t, true, testutil.AddrOfSpinLoop(), func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool { - go func() { - time.Sleep(time.Millisecond) - c.BounceToKernel() - }() - var si linux.SignalInfo - if _, err := c.SwitchToUser(ring0.SwitchOpts{ - Registers: regs, - FloatingPointState: &dummyFPState, - PageTables: pt, - FullRestore: true, - }, &si); err != platform.ErrContextInterrupt { - t.Errorf("application full restore: got %v, wanted %v", err, platform.ErrContextInterrupt) - } - return false - }) -} - -func TestBounceStress(t *testing.T) { - applicationTest(t, true, testutil.AddrOfSpinLoop(), func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool { - randomSleep := func() { - // O(hundreds of microseconds) is appropriate to ensure - // different overlaps and different schedules. - if n := rand.Intn(1000); n > 100 { - time.Sleep(time.Duration(n) * time.Microsecond) - } - } - for i := 0; i < 1000; i++ { - // Start an asynchronously executing goroutine that - // calls Bounce at pseudo-random point in time. - // This should wind up calling Bounce when the - // kernel is in various stages of the switch. - go func() { - randomSleep() - c.BounceToKernel() - }() - randomSleep() - var si linux.SignalInfo - if _, err := c.SwitchToUser(ring0.SwitchOpts{ - Registers: regs, - FloatingPointState: &dummyFPState, - PageTables: pt, - }, &si); err != platform.ErrContextInterrupt { - t.Errorf("application partial restore: got %v, wanted %v", err, platform.ErrContextInterrupt) - } - c.unlock() - randomSleep() - c.lock() - } - return false - }) -} - -func TestInvalidate(t *testing.T) { - var data uintptr // Used below. - applicationTest(t, true, testutil.AddrOfTouch(), func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool { - testutil.SetTouchTarget(regs, &data) // Read legitimate value. - for { - var si linux.SignalInfo - if _, err := c.SwitchToUser(ring0.SwitchOpts{ - Registers: regs, - FloatingPointState: &dummyFPState, - PageTables: pt, - }, &si); err == platform.ErrContextInterrupt { - continue // Retry. - } else if err != nil { - t.Errorf("application partial restore: got %v, wanted nil", err) - } - break // Done. - } - // Unmap the page containing data & invalidate. - pt.Unmap(hostarch.Addr(reflect.ValueOf(&data).Pointer() & ^uintptr(hostarch.PageSize-1)), hostarch.PageSize) - for { - var si linux.SignalInfo - if _, err := c.SwitchToUser(ring0.SwitchOpts{ - Registers: regs, - FloatingPointState: &dummyFPState, - PageTables: pt, - Flush: true, - }, &si); err == platform.ErrContextInterrupt { - continue // Retry. - } else if err != platform.ErrContextSignal { - t.Errorf("application partial restore: got %v, wanted %v", err, platform.ErrContextSignal) - } - break // Success. - } - return false - }) -} - -// IsFault returns true iff the given signal represents a fault. -func IsFault(err error, si *linux.SignalInfo) bool { - return err == platform.ErrContextSignal && si.Signo == int32(unix.SIGSEGV) -} - -func TestEmptyAddressSpace(t *testing.T) { - applicationTest(t, false, testutil.AddrOfSyscallLoop(), func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool { - var si linux.SignalInfo - if _, err := c.SwitchToUser(ring0.SwitchOpts{ - Registers: regs, - FloatingPointState: &dummyFPState, - PageTables: pt, - }, &si); err == platform.ErrContextInterrupt { - return true // Retry. - } else if !IsFault(err, &si) { - t.Errorf("first fault with partial restore failed got %v", err) - t.Logf("registers: %#v", ®s) - } - return false - }) - applicationTest(t, false, testutil.AddrOfSyscallLoop(), func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool { - var si linux.SignalInfo - if _, err := c.SwitchToUser(ring0.SwitchOpts{ - Registers: regs, - FloatingPointState: &dummyFPState, - PageTables: pt, - FullRestore: true, - }, &si); err == platform.ErrContextInterrupt { - return true // Retry. - } else if !IsFault(err, &si) { - t.Errorf("first fault with full restore failed got %v", err) - t.Logf("registers: %#v", ®s) - } - return false - }) -} - -func TestWrongVCPU(t *testing.T) { - kvmTest(t, nil, func(c1 *vCPU) bool { - kvmTest(t, nil, func(c2 *vCPU) bool { - // Basic test, one then the other. - bluepill(c1) - bluepill(c2) - if c1.guestExits == 0 { - // Check: vCPU1 will exit due to redpill() in bluepill(c2). - // Don't allow the test to proceed if this fails. - t.Fatalf("wrong vCPU#1 exits: vCPU1=%+v,vCPU2=%+v", c1, c2) - } - - // Alternate vCPUs; we expect to need to trigger the - // wrong vCPU path on each switch. - for i := 0; i < 100; i++ { - bluepill(c1) - bluepill(c2) - } - if count := c1.guestExits; count < 90 { - t.Errorf("wrong vCPU#1 exits: vCPU1=%+v,vCPU2=%+v", c1, c2) - } - if count := c2.guestExits; count < 90 { - t.Errorf("wrong vCPU#2 exits: vCPU1=%+v,vCPU2=%+v", c1, c2) - } - return false - }) - return false - }) - kvmTest(t, nil, func(c1 *vCPU) bool { - kvmTest(t, nil, func(c2 *vCPU) bool { - bluepill(c1) - bluepill(c2) - return false - }) - return false - }) -} - -func TestRdtsc(t *testing.T) { - var i int // Iteration count. - kvmTest(t, nil, func(c *vCPU) bool { - start := ktime.Rdtsc() - bluepill(c) - guest := ktime.Rdtsc() - redpill() - end := ktime.Rdtsc() - if start > guest || guest > end { - t.Errorf("inconsistent time: start=%d, guest=%d, end=%d", start, guest, end) - } - i++ - return i < 100 - }) -} - -func BenchmarkApplicationSyscall(b *testing.B) { - var ( - i int // Iteration includes machine.Get() / machine.Put(). - a int // Count for ErrContextInterrupt. - ) - applicationTest(b, true, testutil.AddrOfSyscallLoop(), func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool { - var si linux.SignalInfo - if _, err := c.SwitchToUser(ring0.SwitchOpts{ - Registers: regs, - FloatingPointState: &dummyFPState, - PageTables: pt, - }, &si); err == platform.ErrContextInterrupt { - a++ - return true // Ignore. - } else if err != nil { - b.Fatalf("benchmark failed: %v", err) - } - i++ - return i < b.N - }) - if a != 0 { - b.Logf("ErrContextInterrupt occurred %d times (in %d iterations).", a, a+i) - } -} - -func BenchmarkKernelSyscall(b *testing.B) { - // Note that the target passed here is irrelevant, we never execute SwitchToUser. - applicationTest(b, true, testutil.AddrOfGetpid(), func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool { - // iteration does not include machine.Get() / machine.Put(). - for i := 0; i < b.N; i++ { - testutil.Getpid() - } - return false - }) -} - -func BenchmarkWorldSwitchToUserRoundtrip(b *testing.B) { - // see BenchmarkApplicationSyscall. - var ( - i int - a int - ) - applicationTest(b, true, testutil.AddrOfSyscallLoop(), func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool { - var si linux.SignalInfo - if _, err := c.SwitchToUser(ring0.SwitchOpts{ - Registers: regs, - FloatingPointState: &dummyFPState, - PageTables: pt, - }, &si); err == platform.ErrContextInterrupt { - a++ - return true // Ignore. - } else if err != nil { - b.Fatalf("benchmark failed: %v", err) - } - // This will intentionally cause the world switch. By executing - // a host syscall here, we force the transition between guest - // and host mode. - testutil.Getpid() - i++ - return i < b.N - }) - if a != 0 { - b.Logf("ErrContextInterrupt occurred %d times (in %d iterations).", a, a+i) - } -} diff --git a/pkg/sentry/platform/kvm/kvm_unsafe_state_autogen.go b/pkg/sentry/platform/kvm/kvm_unsafe_state_autogen.go new file mode 100644 index 000000000..ca1360c0c --- /dev/null +++ b/pkg/sentry/platform/kvm/kvm_unsafe_state_autogen.go @@ -0,0 +1,6 @@ +// automatically generated by stateify. + +//go:build go1.12 && go1.12 +// +build go1.12,go1.12 + +package kvm diff --git a/pkg/sentry/platform/kvm/testutil/BUILD b/pkg/sentry/platform/kvm/testutil/BUILD deleted file mode 100644 index f7feb8683..000000000 --- a/pkg/sentry/platform/kvm/testutil/BUILD +++ /dev/null @@ -1,17 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "testutil", - testonly = 1, - srcs = [ - "testutil.go", - "testutil_amd64.go", - "testutil_amd64.s", - "testutil_arm64.go", - "testutil_arm64.s", - ], - visibility = ["//pkg/sentry/platform/kvm:__pkg__"], - deps = ["//pkg/sentry/arch"], -) diff --git a/pkg/sentry/platform/kvm/testutil/testutil.go b/pkg/sentry/platform/kvm/testutil/testutil.go deleted file mode 100644 index d8c273796..000000000 --- a/pkg/sentry/platform/kvm/testutil/testutil.go +++ /dev/null @@ -1,90 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package testutil provides common assembly stubs for testing. -package testutil - -import ( - "fmt" - "strings" -) - -// Getpid executes a trivial system call. -func Getpid() - -// AddrOfGetpid returns the address of Getpid. -// -// In Go 1.17+, Go references to assembly functions resolve to an ABIInternal -// wrapper function rather than the function itself. We must reference from -// assembly to get the ABI0 (i.e., primary) address. -func AddrOfGetpid() uintptr - -// AddrOfTouch returns the address of a function that touches the value in the -// first register. -func AddrOfTouch() uintptr -func touch() - -// AddrOfSyscallLoop returns the address of a function that executes a syscall -// and loops. -func AddrOfSyscallLoop() uintptr -func syscallLoop() - -// AddrOfSpinLoop returns the address of a function that spins on the CPU. -func AddrOfSpinLoop() uintptr -func spinLoop() - -// AddrOfHaltLoop returns the address of a function that immediately halts and -// loops. -func AddrOfHaltLoop() uintptr -func haltLoop() - -// AddrOfTwiddleRegsFault returns the address of a function that twiddles -// registers then faults. -func AddrOfTwiddleRegsFault() uintptr -func twiddleRegsFault() - -// AddrOfTwiddleRegsSyscall returns the address of a function that twiddles -// registers then executes a syscall. -func AddrOfTwiddleRegsSyscall() uintptr -func twiddleRegsSyscall() - -// FloatingPointWorks is a floating point test. -// -// It returns true or false. -func FloatingPointWorks() bool - -// RegisterMismatchError is used for checking registers. -type RegisterMismatchError []string - -// Error returns a human-readable error. -func (r RegisterMismatchError) Error() string { - return strings.Join([]string(r), ";") -} - -// addRegisterMisatch allows simple chaining of register mismatches. -func addRegisterMismatch(err error, reg string, got, expected interface{}) error { - errStr := fmt.Sprintf("%s got %08x, expected %08x", reg, got, expected) - switch r := err.(type) { - case nil: - // Return a new register mismatch. - return RegisterMismatchError{errStr} - case RegisterMismatchError: - // Append the error. - r = append(r, errStr) - return r - default: - // Leave as is. - return err - } -} diff --git a/pkg/sentry/platform/kvm/testutil/testutil_amd64.go b/pkg/sentry/platform/kvm/testutil/testutil_amd64.go deleted file mode 100644 index 98c52b2f5..000000000 --- a/pkg/sentry/platform/kvm/testutil/testutil_amd64.go +++ /dev/null @@ -1,142 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//go:build amd64 -// +build amd64 - -package testutil - -import ( - "reflect" - - "gvisor.dev/gvisor/pkg/sentry/arch" -) - -// AddrOfTwiddleSegments return the address of a function that reads segments -// into known registers. -func AddrOfTwiddleSegments() uintptr -func twiddleSegments() - -// SetTestTarget sets the rip appropriately. -func SetTestTarget(regs *arch.Registers, fn uintptr) { - regs.Rip = uint64(fn) -} - -// SetTouchTarget sets rax appropriately. -func SetTouchTarget(regs *arch.Registers, target *uintptr) { - if target != nil { - regs.Rax = uint64(reflect.ValueOf(target).Pointer()) - } else { - regs.Rax = 0 - } -} - -// RewindSyscall rewinds a syscall RIP. -func RewindSyscall(regs *arch.Registers) { - regs.Rip -= 2 -} - -// SetTestRegs initializes registers to known values. -func SetTestRegs(regs *arch.Registers) { - regs.R15 = 0x15 - regs.R14 = 0x14 - regs.R13 = 0x13 - regs.R12 = 0x12 - regs.Rbp = 0xb9 - regs.Rbx = 0xb4 - regs.R11 = 0x11 - regs.R10 = 0x10 - regs.R9 = 0x09 - regs.R8 = 0x08 - regs.Rax = 0x44 - regs.Rcx = 0xc4 - regs.Rdx = 0xd4 - regs.Rsi = 0x51 - regs.Rdi = 0xd1 - regs.Rsp = 0x59 -} - -// CheckTestRegs checks that registers were twiddled per TwiddleRegs. -func CheckTestRegs(regs *arch.Registers, full bool) (err error) { - if need := ^uint64(0x15); regs.R15 != need { - err = addRegisterMismatch(err, "R15", regs.R15, need) - } - if need := ^uint64(0x14); regs.R14 != need { - err = addRegisterMismatch(err, "R14", regs.R14, need) - } - if need := ^uint64(0x13); regs.R13 != need { - err = addRegisterMismatch(err, "R13", regs.R13, need) - } - if need := ^uint64(0x12); regs.R12 != need { - err = addRegisterMismatch(err, "R12", regs.R12, need) - } - if need := ^uint64(0xb9); regs.Rbp != need { - err = addRegisterMismatch(err, "Rbp", regs.Rbp, need) - } - if need := ^uint64(0xb4); regs.Rbx != need { - err = addRegisterMismatch(err, "Rbx", regs.Rbx, need) - } - if need := ^uint64(0x10); regs.R10 != need { - err = addRegisterMismatch(err, "R10", regs.R10, need) - } - if need := ^uint64(0x09); regs.R9 != need { - err = addRegisterMismatch(err, "R9", regs.R9, need) - } - if need := ^uint64(0x08); regs.R8 != need { - err = addRegisterMismatch(err, "R8", regs.R8, need) - } - if need := ^uint64(0x44); regs.Rax != need { - err = addRegisterMismatch(err, "Rax", regs.Rax, need) - } - if need := ^uint64(0xd4); regs.Rdx != need { - err = addRegisterMismatch(err, "Rdx", regs.Rdx, need) - } - if need := ^uint64(0x51); regs.Rsi != need { - err = addRegisterMismatch(err, "Rsi", regs.Rsi, need) - } - if need := ^uint64(0xd1); regs.Rdi != need { - err = addRegisterMismatch(err, "Rdi", regs.Rdi, need) - } - if need := ^uint64(0x59); regs.Rsp != need { - err = addRegisterMismatch(err, "Rsp", regs.Rsp, need) - } - // Rcx & R11 are ignored if !full is set. - if need := ^uint64(0x11); full && regs.R11 != need { - err = addRegisterMismatch(err, "R11", regs.R11, need) - } - if need := ^uint64(0xc4); full && regs.Rcx != need { - err = addRegisterMismatch(err, "Rcx", regs.Rcx, need) - } - return -} - -var fsData uint64 = 0x55 -var gsData uint64 = 0x85 - -// SetTestSegments initializes segments to known values. -func SetTestSegments(regs *arch.Registers) { - regs.Fs_base = uint64(reflect.ValueOf(&fsData).Pointer()) - regs.Gs_base = uint64(reflect.ValueOf(&gsData).Pointer()) -} - -// CheckTestSegments checks that registers were twiddled per TwiddleSegments. -func CheckTestSegments(regs *arch.Registers) (err error) { - if regs.Rax != fsData { - err = addRegisterMismatch(err, "Rax", regs.Rax, fsData) - } - if regs.Rbx != gsData { - err = addRegisterMismatch(err, "Rbx", regs.Rcx, gsData) - } - return -} diff --git a/pkg/sentry/platform/kvm/testutil/testutil_amd64.s b/pkg/sentry/platform/kvm/testutil/testutil_amd64.s deleted file mode 100644 index 65e7c05ea..000000000 --- a/pkg/sentry/platform/kvm/testutil/testutil_amd64.s +++ /dev/null @@ -1,135 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// +build amd64 - -// test_util_amd64.s provides AMD64 test functions. - -#include "funcdata.h" -#include "textflag.h" - -TEXT ·Getpid(SB),NOSPLIT,$0 - NO_LOCAL_POINTERS - MOVQ $39, AX // getpid - SYSCALL - RET - -// func AddrOfGetpid() uintptr -TEXT ·AddrOfGetpid(SB), $0-8 - MOVQ $·Getpid(SB), AX - MOVQ AX, ret+0(FP) - RET - -TEXT ·touch(SB),NOSPLIT,$0 -start: - MOVQ 0(AX), BX // deref AX - MOVQ $39, AX // getpid - SYSCALL - JMP start - -// func AddrOfTouch() uintptr -TEXT ·AddrOfTouch(SB), $0-8 - MOVQ $·touch(SB), AX - MOVQ AX, ret+0(FP) - RET - -TEXT ·syscallLoop(SB),NOSPLIT,$0 -start: - SYSCALL - JMP start - -// func AddrOfSyscallLoop() uintptr -TEXT ·AddrOfSyscallLoop(SB), $0-8 - MOVQ $·syscallLoop(SB), AX - MOVQ AX, ret+0(FP) - RET - -TEXT ·spinLoop(SB),NOSPLIT,$0 -start: - JMP start - -// func AddrOfSpinLoop() uintptr -TEXT ·AddrOfSpinLoop(SB), $0-8 - MOVQ $·spinLoop(SB), AX - MOVQ AX, ret+0(FP) - RET - -TEXT ·FloatingPointWorks(SB),NOSPLIT,$0-8 - NO_LOCAL_POINTERS - MOVQ $1, AX - MOVQ AX, X0 - MOVQ $39, AX // getpid - SYSCALL - MOVQ X0, AX - CMPQ AX, $1 - SETEQ ret+0(FP) - RET - -#define TWIDDLE_REGS() \ - NOTQ R15; \ - NOTQ R14; \ - NOTQ R13; \ - NOTQ R12; \ - NOTQ BP; \ - NOTQ BX; \ - NOTQ R11; \ - NOTQ R10; \ - NOTQ R9; \ - NOTQ R8; \ - NOTQ AX; \ - NOTQ CX; \ - NOTQ DX; \ - NOTQ SI; \ - NOTQ DI; \ - NOTQ SP; - -TEXT ·twiddleRegsSyscall(SB),NOSPLIT,$0 - TWIDDLE_REGS() - SYSCALL - RET // never reached - -// func AddrOfTwiddleRegsSyscall() uintptr -TEXT ·AddrOfTwiddleRegsSyscall(SB), $0-8 - MOVQ $·twiddleRegsSyscall(SB), AX - MOVQ AX, ret+0(FP) - RET - -TEXT ·twiddleRegsFault(SB),NOSPLIT,$0 - TWIDDLE_REGS() - JMP AX // must fault - RET // never reached - -// func AddrOfTwiddleRegsFault() uintptr -TEXT ·AddrOfTwiddleRegsFault(SB), $0-8 - MOVQ $·twiddleRegsFault(SB), AX - MOVQ AX, ret+0(FP) - RET - -#define READ_FS() BYTE $0x64; BYTE $0x48; BYTE $0x8b; BYTE $0x00; -#define READ_GS() BYTE $0x65; BYTE $0x48; BYTE $0x8b; BYTE $0x00; - -TEXT ·twiddleSegments(SB),NOSPLIT,$0 - MOVQ $0x0, AX - READ_GS() - MOVQ AX, BX - MOVQ $0x0, AX - READ_FS() - SYSCALL - RET // never reached - -// func AddrOfTwiddleSegments() uintptr -TEXT ·AddrOfTwiddleSegments(SB), $0-8 - MOVQ $·twiddleSegments(SB), AX - MOVQ AX, ret+0(FP) - RET diff --git a/pkg/sentry/platform/kvm/testutil/testutil_arm64.go b/pkg/sentry/platform/kvm/testutil/testutil_arm64.go deleted file mode 100644 index 6d0ba8252..000000000 --- a/pkg/sentry/platform/kvm/testutil/testutil_arm64.go +++ /dev/null @@ -1,70 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//go:build arm64 -// +build arm64 - -package testutil - -import ( - "fmt" - "reflect" - - "gvisor.dev/gvisor/pkg/sentry/arch" -) - -// TLSWorks is a tls test. -// -// It returns true or false. -func TLSWorks() bool - -// SetTestTarget sets the rip appropriately. -func SetTestTarget(regs *arch.Registers, fn func()) { - regs.Pc = uint64(reflect.ValueOf(fn).Pointer()) -} - -// SetTouchTarget sets rax appropriately. -func SetTouchTarget(regs *arch.Registers, target *uintptr) { - if target != nil { - regs.Regs[8] = uint64(reflect.ValueOf(target).Pointer()) - } else { - regs.Regs[8] = 0 - } -} - -// RewindSyscall rewinds a syscall RIP. -func RewindSyscall(regs *arch.Registers) { - regs.Pc -= 4 -} - -// SetTestRegs initializes registers to known values. -func SetTestRegs(regs *arch.Registers) { - for i := 0; i <= 30; i++ { - regs.Regs[i] = uint64(i) + 1 - } -} - -// CheckTestRegs checks that registers were twiddled per TwiddleRegs. -func CheckTestRegs(regs *arch.Registers, full bool) (err error) { - for i := 0; i <= 30; i++ { - if need := ^uint64(i + 1); regs.Regs[i] != need { - err = addRegisterMismatch(err, fmt.Sprintf("R%d", i), regs.Regs[i], need) - } - } - // Check tls. - if need := ^uint64(11); regs.TPIDR_EL0 != need { - err = addRegisterMismatch(err, "tpdir_el0", regs.TPIDR_EL0, need) - } - return -} diff --git a/pkg/sentry/platform/kvm/testutil/testutil_arm64.s b/pkg/sentry/platform/kvm/testutil/testutil_arm64.s deleted file mode 100644 index 7348c29a5..000000000 --- a/pkg/sentry/platform/kvm/testutil/testutil_arm64.s +++ /dev/null @@ -1,134 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// +build arm64 - -// test_util_arm64.s provides ARM64 test functions. - -#include "funcdata.h" -#include "textflag.h" - -#define SYS_GETPID 172 - -// This function simulates the getpid syscall. -TEXT ·Getpid(SB),NOSPLIT,$0 - NO_LOCAL_POINTERS - MOVD $SYS_GETPID, R8 - SVC - RET - -TEXT ·Touch(SB),NOSPLIT,$0 -start: - MOVD 0(R8), R1 - MOVD $SYS_GETPID, R8 // getpid - SVC - B start - -TEXT ·HaltLoop(SB),NOSPLIT,$0 -start: - HLT - B start - -// This function simulates a loop of syscall. -TEXT ·SyscallLoop(SB),NOSPLIT,$0 -start: - SVC - B start - -TEXT ·SpinLoop(SB),NOSPLIT,$0 -start: - B start - -TEXT ·TLSWorks(SB),NOSPLIT,$0-8 - NO_LOCAL_POINTERS - MOVD $0x6789, R5 - MSR R5, TPIDR_EL0 - MOVD $SYS_GETPID, R8 // getpid - SVC - MRS TPIDR_EL0, R6 - CMP R5, R6 - BNE isNaN - MOVD $1, R0 - MOVD R0, ret+0(FP) - RET -isNaN: - MOVD $0, ret+0(FP) - RET - -TEXT ·FloatingPointWorks(SB),NOSPLIT,$0-8 - NO_LOCAL_POINTERS - // gc will touch fpsimd, so we should test it. - // such as in <runtime.deductSweepCredit>. - FMOVD $(9.9), F0 - MOVD $SYS_GETPID, R8 // getpid - SVC - FMOVD $(9.9), F1 - FCMPD F0, F1 - BNE isNaN - MOVD $1, R0 - MOVD R0, ret+0(FP) - RET -isNaN: - MOVD $0, ret+0(FP) - RET - -// MVN: bitwise logical NOT -// This case simulates an application that modified R0-R30. -#define TWIDDLE_REGS() \ - MVN R0, R0; \ - MVN R1, R1; \ - MVN R2, R2; \ - MVN R3, R3; \ - MVN R4, R4; \ - MVN R5, R5; \ - MVN R6, R6; \ - MVN R7, R7; \ - MVN R8, R8; \ - MVN R9, R9; \ - MVN R10, R10; \ - MVN R11, R11; \ - MVN R12, R12; \ - MVN R13, R13; \ - MVN R14, R14; \ - MVN R15, R15; \ - MVN R16, R16; \ - MVN R17, R17; \ - MVN R18_PLATFORM, R18_PLATFORM; \ - MVN R19, R19; \ - MVN R20, R20; \ - MVN R21, R21; \ - MVN R22, R22; \ - MVN R23, R23; \ - MVN R24, R24; \ - MVN R25, R25; \ - MVN R26, R26; \ - MVN R27, R27; \ - MVN g, g; \ - MVN R29, R29; \ - MVN R30, R30; - -TEXT ·TwiddleRegsSyscall(SB),NOSPLIT,$0 - TWIDDLE_REGS() - MSR R10, TPIDR_EL0 - // Trapped in el0_svc. - SVC - RET // never reached - -TEXT ·TwiddleRegsFault(SB),NOSPLIT,$0 - TWIDDLE_REGS() - MSR R10, TPIDR_EL0 - // Trapped in el0_ia. - // Branch to Register branches unconditionally to an address in <Rn>. - JMP (R6) // <=> br x6, must fault - RET // never reached diff --git a/pkg/sentry/platform/kvm/virtual_map_test.go b/pkg/sentry/platform/kvm/virtual_map_test.go deleted file mode 100644 index 1f4a774f3..000000000 --- a/pkg/sentry/platform/kvm/virtual_map_test.go +++ /dev/null @@ -1,93 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package kvm - -import ( - "testing" - - "golang.org/x/sys/unix" - "gvisor.dev/gvisor/pkg/hostarch" -) - -type checker struct { - ok bool - accessType hostarch.AccessType -} - -func (c *checker) Containing(addr uintptr) func(virtualRegion) { - c.ok = false // Reset for below calls. - return func(vr virtualRegion) { - if vr.virtual <= addr && addr < vr.virtual+vr.length { - c.ok = true - c.accessType = vr.accessType - } - } -} - -func TestParseMaps(t *testing.T) { - c := new(checker) - - // Simple test. - if err := applyVirtualRegions(c.Containing(0)); err != nil { - t.Fatalf("unexpected error: %v", err) - } - - // MMap a new page. - addr, _, errno := unix.RawSyscall6( - unix.SYS_MMAP, 0, hostarch.PageSize, - unix.PROT_READ|unix.PROT_WRITE, - unix.MAP_ANONYMOUS|unix.MAP_PRIVATE, 0, 0) - if errno != 0 { - t.Fatalf("unexpected map error: %v", errno) - } - - // Re-parse maps. - if err := applyVirtualRegions(c.Containing(addr)); err != nil { - unix.RawSyscall(unix.SYS_MUNMAP, addr, hostarch.PageSize, 0) - t.Fatalf("unexpected error: %v", err) - } - - // Assert that it now does contain the region. - if !c.ok { - unix.RawSyscall(unix.SYS_MUNMAP, addr, hostarch.PageSize, 0) - t.Fatalf("updated map does not contain 0x%08x, expected true", addr) - } - - // Map the region as PROT_NONE. - newAddr, _, errno := unix.RawSyscall6( - unix.SYS_MMAP, addr, hostarch.PageSize, - unix.PROT_NONE, - unix.MAP_ANONYMOUS|unix.MAP_FIXED|unix.MAP_PRIVATE, 0, 0) - if errno != 0 { - t.Fatalf("unexpected map error: %v", errno) - } - if newAddr != addr { - t.Fatalf("unable to remap address: got 0x%08x, wanted 0x%08x", newAddr, addr) - } - - // Re-parse maps. - if err := applyVirtualRegions(c.Containing(addr)); err != nil { - t.Fatalf("unexpected error: %v", err) - } - if !c.ok { - t.Fatalf("final map does not contain 0x%08x, expected true", addr) - } - if c.accessType.Read || c.accessType.Write || c.accessType.Execute { - t.Fatalf("final map has incorrect permissions for 0x%08x", addr) - } - - // Unmap the region. - unix.RawSyscall(unix.SYS_MUNMAP, addr, hostarch.PageSize, 0) -} diff --git a/pkg/sentry/platform/platform_state_autogen.go b/pkg/sentry/platform/platform_state_autogen.go new file mode 100644 index 000000000..7a4d04589 --- /dev/null +++ b/pkg/sentry/platform/platform_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package platform diff --git a/pkg/sentry/platform/ptrace/BUILD b/pkg/sentry/platform/ptrace/BUILD deleted file mode 100644 index d101f2f53..000000000 --- a/pkg/sentry/platform/ptrace/BUILD +++ /dev/null @@ -1,41 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "ptrace", - srcs = [ - "filters.go", - "ptrace.go", - "ptrace_amd64.go", - "ptrace_arm64.go", - "ptrace_arm64_unsafe.go", - "ptrace_unsafe.go", - "stub_amd64.s", - "stub_arm64.s", - "stub_unsafe.go", - "subprocess.go", - "subprocess_amd64.go", - "subprocess_arm64.go", - "subprocess_linux.go", - "subprocess_linux_unsafe.go", - "subprocess_unsafe.go", - ], - visibility = ["//:sandbox"], - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/hostarch", - "//pkg/log", - "//pkg/procid", - "//pkg/safecopy", - "//pkg/seccomp", - "//pkg/sentry/arch", - "//pkg/sentry/arch/fpu", - "//pkg/sentry/memmap", - "//pkg/sentry/platform", - "//pkg/sentry/platform/interrupt", - "//pkg/sync", - "@org_golang_x_sys//unix:go_default_library", - ], -) diff --git a/pkg/sentry/platform/ptrace/ptrace_amd64_state_autogen.go b/pkg/sentry/platform/ptrace/ptrace_amd64_state_autogen.go new file mode 100644 index 000000000..9a5228154 --- /dev/null +++ b/pkg/sentry/platform/ptrace/ptrace_amd64_state_autogen.go @@ -0,0 +1,6 @@ +// automatically generated by stateify. + +//go:build amd64 +// +build amd64 + +package ptrace diff --git a/pkg/sentry/platform/ptrace/ptrace_arm64_state_autogen.go b/pkg/sentry/platform/ptrace/ptrace_arm64_state_autogen.go new file mode 100644 index 000000000..477523841 --- /dev/null +++ b/pkg/sentry/platform/ptrace/ptrace_arm64_state_autogen.go @@ -0,0 +1,6 @@ +// automatically generated by stateify. + +//go:build arm64 +// +build arm64 + +package ptrace diff --git a/pkg/sentry/platform/ptrace/ptrace_arm64_unsafe_state_autogen.go b/pkg/sentry/platform/ptrace/ptrace_arm64_unsafe_state_autogen.go new file mode 100644 index 000000000..477523841 --- /dev/null +++ b/pkg/sentry/platform/ptrace/ptrace_arm64_unsafe_state_autogen.go @@ -0,0 +1,6 @@ +// automatically generated by stateify. + +//go:build arm64 +// +build arm64 + +package ptrace diff --git a/pkg/sentry/platform/ptrace/ptrace_linux_state_autogen.go b/pkg/sentry/platform/ptrace/ptrace_linux_state_autogen.go new file mode 100644 index 000000000..fcfee0226 --- /dev/null +++ b/pkg/sentry/platform/ptrace/ptrace_linux_state_autogen.go @@ -0,0 +1,6 @@ +// automatically generated by stateify. + +//go:build linux +// +build linux + +package ptrace diff --git a/pkg/sentry/platform/ptrace/ptrace_linux_unsafe_state_autogen.go b/pkg/sentry/platform/ptrace/ptrace_linux_unsafe_state_autogen.go new file mode 100644 index 000000000..bf3fccfcb --- /dev/null +++ b/pkg/sentry/platform/ptrace/ptrace_linux_unsafe_state_autogen.go @@ -0,0 +1,7 @@ +// automatically generated by stateify. + +//go:build linux && (amd64 || arm64) +// +build linux +// +build amd64 arm64 + +package ptrace diff --git a/pkg/sentry/platform/ptrace/ptrace_state_autogen.go b/pkg/sentry/platform/ptrace/ptrace_state_autogen.go new file mode 100644 index 000000000..1bf0526f9 --- /dev/null +++ b/pkg/sentry/platform/ptrace/ptrace_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package ptrace diff --git a/pkg/sentry/platform/ptrace/ptrace_unsafe_state_autogen.go b/pkg/sentry/platform/ptrace/ptrace_unsafe_state_autogen.go new file mode 100644 index 000000000..61bdd8136 --- /dev/null +++ b/pkg/sentry/platform/ptrace/ptrace_unsafe_state_autogen.go @@ -0,0 +1,6 @@ +// automatically generated by stateify. + +//go:build go1.12 +// +build go1.12 + +package ptrace diff --git a/pkg/sentry/sighandling/BUILD b/pkg/sentry/sighandling/BUILD deleted file mode 100644 index 1790d57c9..000000000 --- a/pkg/sentry/sighandling/BUILD +++ /dev/null @@ -1,16 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "sighandling", - srcs = [ - "sighandling.go", - "sighandling_unsafe.go", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "@org_golang_x_sys//unix:go_default_library", - ], -) diff --git a/pkg/sentry/sighandling/sighandling_state_autogen.go b/pkg/sentry/sighandling/sighandling_state_autogen.go new file mode 100644 index 000000000..da9d96382 --- /dev/null +++ b/pkg/sentry/sighandling/sighandling_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package sighandling diff --git a/pkg/sentry/sighandling/sighandling_unsafe_state_autogen.go b/pkg/sentry/sighandling/sighandling_unsafe_state_autogen.go new file mode 100644 index 000000000..da9d96382 --- /dev/null +++ b/pkg/sentry/sighandling/sighandling_unsafe_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package sighandling diff --git a/pkg/sentry/socket/BUILD b/pkg/sentry/socket/BUILD deleted file mode 100644 index 7ee89a735..000000000 --- a/pkg/sentry/socket/BUILD +++ /dev/null @@ -1,27 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "socket", - srcs = ["socket.go"], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/hostarch", - "//pkg/marshal", - "//pkg/sentry/device", - "//pkg/sentry/fs", - "//pkg/sentry/fs/fsutil", - "//pkg/sentry/kernel", - "//pkg/sentry/kernel/time", - "//pkg/sentry/socket/unix/transport", - "//pkg/sentry/vfs", - "//pkg/syserr", - "//pkg/tcpip", - "//pkg/tcpip/header", - "//pkg/usermem", - "@org_golang_x_sys//unix:go_default_library", - ], -) diff --git a/pkg/sentry/socket/control/BUILD b/pkg/sentry/socket/control/BUILD deleted file mode 100644 index b2fc84181..000000000 --- a/pkg/sentry/socket/control/BUILD +++ /dev/null @@ -1,44 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "control", - srcs = [ - "control.go", - "control_vfs2.go", - ], - imports = [ - "gvisor.dev/gvisor/pkg/sentry/fs", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/bits", - "//pkg/context", - "//pkg/errors/linuxerr", - "//pkg/hostarch", - "//pkg/marshal", - "//pkg/marshal/primitive", - "//pkg/sentry/fs", - "//pkg/sentry/kernel", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/socket", - "//pkg/sentry/socket/unix/transport", - "//pkg/sentry/vfs", - ], -) - -go_test( - name = "control_test", - size = "small", - srcs = ["control_test.go"], - library = ":control", - deps = [ - "//pkg/abi/linux", - "//pkg/binary", - "//pkg/hostarch", - "//pkg/sentry/socket", - "@com_github_google_go_cmp//cmp:go_default_library", - ], -) diff --git a/pkg/sentry/socket/control/control_state_autogen.go b/pkg/sentry/socket/control/control_state_autogen.go new file mode 100644 index 000000000..412025601 --- /dev/null +++ b/pkg/sentry/socket/control/control_state_autogen.go @@ -0,0 +1,60 @@ +// automatically generated by stateify. + +package control + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (fs *RightsFiles) StateTypeName() string { + return "pkg/sentry/socket/control.RightsFiles" +} + +func (fs *RightsFiles) StateFields() []string { + return nil +} + +func (c *scmCredentials) StateTypeName() string { + return "pkg/sentry/socket/control.scmCredentials" +} + +func (c *scmCredentials) StateFields() []string { + return []string{ + "t", + "kuid", + "kgid", + } +} + +func (c *scmCredentials) beforeSave() {} + +// +checklocksignore +func (c *scmCredentials) StateSave(stateSinkObject state.Sink) { + c.beforeSave() + stateSinkObject.Save(0, &c.t) + stateSinkObject.Save(1, &c.kuid) + stateSinkObject.Save(2, &c.kgid) +} + +func (c *scmCredentials) afterLoad() {} + +// +checklocksignore +func (c *scmCredentials) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &c.t) + stateSourceObject.Load(1, &c.kuid) + stateSourceObject.Load(2, &c.kgid) +} + +func (fs *RightsFilesVFS2) StateTypeName() string { + return "pkg/sentry/socket/control.RightsFilesVFS2" +} + +func (fs *RightsFilesVFS2) StateFields() []string { + return nil +} + +func init() { + state.Register((*RightsFiles)(nil)) + state.Register((*scmCredentials)(nil)) + state.Register((*RightsFilesVFS2)(nil)) +} diff --git a/pkg/sentry/socket/control/control_test.go b/pkg/sentry/socket/control/control_test.go deleted file mode 100644 index 7e28a0cef..000000000 --- a/pkg/sentry/socket/control/control_test.go +++ /dev/null @@ -1,59 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package control provides internal representations of socket control -// messages. -package control - -import ( - "testing" - - "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" - "gvisor.dev/gvisor/pkg/hostarch" - "gvisor.dev/gvisor/pkg/sentry/socket" -) - -func TestParse(t *testing.T) { - // Craft the control message to parse. - length := linux.SizeOfControlMessageHeader + linux.SizeOfTimeval - hdr := linux.ControlMessageHeader{ - Length: uint64(length), - Level: linux.SOL_SOCKET, - Type: linux.SO_TIMESTAMP, - } - buf := make([]byte, 0, length) - buf = binary.Marshal(buf, hostarch.ByteOrder, &hdr) - ts := linux.Timeval{ - Sec: 2401, - Usec: 343, - } - buf = binary.Marshal(buf, hostarch.ByteOrder, &ts) - - cmsg, err := Parse(nil, nil, buf, 8 /* width */) - if err != nil { - t.Fatalf("Parse(_, _, %+v, _): %v", cmsg, err) - } - - want := socket.ControlMessages{ - IP: socket.IPControlMessages{ - HasTimestamp: true, - Timestamp: ts.ToNsecCapped(), - }, - } - if diff := cmp.Diff(want, cmsg); diff != "" { - t.Errorf("unexpected message parsed, (-want, +got):\n%s", diff) - } -} diff --git a/pkg/sentry/socket/hostinet/BUILD b/pkg/sentry/socket/hostinet/BUILD deleted file mode 100644 index 4ea89f9d0..000000000 --- a/pkg/sentry/socket/hostinet/BUILD +++ /dev/null @@ -1,47 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "hostinet", - srcs = [ - "device.go", - "hostinet.go", - "save_restore.go", - "socket.go", - "socket_unsafe.go", - "socket_vfs2.go", - "sockopt_impl.go", - "stack.go", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/errors/linuxerr", - "//pkg/fdnotifier", - "//pkg/hostarch", - "//pkg/log", - "//pkg/marshal", - "//pkg/marshal/primitive", - "//pkg/safemem", - "//pkg/sentry/arch", - "//pkg/sentry/device", - "//pkg/sentry/fs", - "//pkg/sentry/fs/fsutil", - "//pkg/sentry/fsimpl/sockfs", - "//pkg/sentry/hostfd", - "//pkg/sentry/inet", - "//pkg/sentry/kernel", - "//pkg/sentry/kernel/time", - "//pkg/sentry/socket", - "//pkg/sentry/socket/control", - "//pkg/sentry/vfs", - "//pkg/syserr", - "//pkg/tcpip", - "//pkg/tcpip/stack", - "//pkg/usermem", - "//pkg/waiter", - "@org_golang_x_sys//unix:go_default_library", - ], -) diff --git a/pkg/sentry/socket/hostinet/hostinet_impl_state_autogen.go b/pkg/sentry/socket/hostinet/hostinet_impl_state_autogen.go new file mode 100644 index 000000000..e52b4a6d6 --- /dev/null +++ b/pkg/sentry/socket/hostinet/hostinet_impl_state_autogen.go @@ -0,0 +1,6 @@ +// automatically generated by stateify. + +//go:build go1.1 +// +build go1.1 + +package hostinet diff --git a/pkg/sentry/socket/hostinet/hostinet_state_autogen.go b/pkg/sentry/socket/hostinet/hostinet_state_autogen.go new file mode 100644 index 000000000..519c65339 --- /dev/null +++ b/pkg/sentry/socket/hostinet/hostinet_state_autogen.go @@ -0,0 +1,89 @@ +// automatically generated by stateify. + +package hostinet + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (s *socketOpsCommon) StateTypeName() string { + return "pkg/sentry/socket/hostinet.socketOpsCommon" +} + +func (s *socketOpsCommon) StateFields() []string { + return []string{ + "SendReceiveTimeout", + "family", + "stype", + "protocol", + "queue", + "fd", + } +} + +func (s *socketOpsCommon) beforeSave() {} + +// +checklocksignore +func (s *socketOpsCommon) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + stateSinkObject.Save(0, &s.SendReceiveTimeout) + stateSinkObject.Save(1, &s.family) + stateSinkObject.Save(2, &s.stype) + stateSinkObject.Save(3, &s.protocol) + stateSinkObject.Save(4, &s.queue) + stateSinkObject.Save(5, &s.fd) +} + +func (s *socketOpsCommon) afterLoad() {} + +// +checklocksignore +func (s *socketOpsCommon) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.SendReceiveTimeout) + stateSourceObject.Load(1, &s.family) + stateSourceObject.Load(2, &s.stype) + stateSourceObject.Load(3, &s.protocol) + stateSourceObject.Load(4, &s.queue) + stateSourceObject.Load(5, &s.fd) +} + +func (s *socketVFS2) StateTypeName() string { + return "pkg/sentry/socket/hostinet.socketVFS2" +} + +func (s *socketVFS2) StateFields() []string { + return []string{ + "vfsfd", + "FileDescriptionDefaultImpl", + "LockFD", + "DentryMetadataFileDescriptionImpl", + "socketOpsCommon", + } +} + +func (s *socketVFS2) beforeSave() {} + +// +checklocksignore +func (s *socketVFS2) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + stateSinkObject.Save(0, &s.vfsfd) + stateSinkObject.Save(1, &s.FileDescriptionDefaultImpl) + stateSinkObject.Save(2, &s.LockFD) + stateSinkObject.Save(3, &s.DentryMetadataFileDescriptionImpl) + stateSinkObject.Save(4, &s.socketOpsCommon) +} + +func (s *socketVFS2) afterLoad() {} + +// +checklocksignore +func (s *socketVFS2) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.vfsfd) + stateSourceObject.Load(1, &s.FileDescriptionDefaultImpl) + stateSourceObject.Load(2, &s.LockFD) + stateSourceObject.Load(3, &s.DentryMetadataFileDescriptionImpl) + stateSourceObject.Load(4, &s.socketOpsCommon) +} + +func init() { + state.Register((*socketOpsCommon)(nil)) + state.Register((*socketVFS2)(nil)) +} diff --git a/pkg/sentry/socket/hostinet/hostinet_unsafe_state_autogen.go b/pkg/sentry/socket/hostinet/hostinet_unsafe_state_autogen.go new file mode 100644 index 000000000..b0a59ba93 --- /dev/null +++ b/pkg/sentry/socket/hostinet/hostinet_unsafe_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package hostinet diff --git a/pkg/sentry/socket/netfilter/BUILD b/pkg/sentry/socket/netfilter/BUILD deleted file mode 100644 index 608474fa1..000000000 --- a/pkg/sentry/socket/netfilter/BUILD +++ /dev/null @@ -1,34 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "netfilter", - srcs = [ - "extensions.go", - "ipv4.go", - "ipv6.go", - "netfilter.go", - "owner_matcher.go", - "targets.go", - "tcp_matcher.go", - "udp_matcher.go", - ], - marshal = True, - # This target depends on netstack and should only be used by epsocket, - # which is allowed to depend on netstack. - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/bits", - "//pkg/hostarch", - "//pkg/log", - "//pkg/marshal", - "//pkg/sentry/kernel", - "//pkg/sentry/kernel/auth", - "//pkg/syserr", - "//pkg/tcpip", - "//pkg/tcpip/header", - "//pkg/tcpip/stack", - ], -) diff --git a/pkg/sentry/socket/netfilter/netfilter_abi_autogen_unsafe.go b/pkg/sentry/socket/netfilter/netfilter_abi_autogen_unsafe.go new file mode 100644 index 000000000..786b62935 --- /dev/null +++ b/pkg/sentry/socket/netfilter/netfilter_abi_autogen_unsafe.go @@ -0,0 +1,150 @@ +// Automatically generated marshal implementation. See tools/go_marshal. + +package netfilter + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/gohacks" + "gvisor.dev/gvisor/pkg/hostarch" + "gvisor.dev/gvisor/pkg/marshal" + "io" + "reflect" + "runtime" + "unsafe" +) + +// Marshallable types used by this file. +var _ marshal.Marshallable = (*linux.NFNATRange)(nil) +var _ marshal.Marshallable = (*linux.XTEntryTarget)(nil) +var _ marshal.Marshallable = (*nfNATTarget)(nil) + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (n *nfNATTarget) SizeBytes() int { + return 0 + + (*linux.XTEntryTarget)(nil).SizeBytes() + + (*linux.NFNATRange)(nil).SizeBytes() +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (n *nfNATTarget) MarshalBytes(dst []byte) { + n.Target.MarshalBytes(dst[:n.Target.SizeBytes()]) + dst = dst[n.Target.SizeBytes():] + n.Range.MarshalBytes(dst[:n.Range.SizeBytes()]) + dst = dst[n.Range.SizeBytes():] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (n *nfNATTarget) UnmarshalBytes(src []byte) { + n.Target.UnmarshalBytes(src[:n.Target.SizeBytes()]) + src = src[n.Target.SizeBytes():] + n.Range.UnmarshalBytes(src[:n.Range.SizeBytes()]) + src = src[n.Range.SizeBytes():] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (n *nfNATTarget) Packed() bool { + return n.Range.Packed() && n.Target.Packed() +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (n *nfNATTarget) MarshalUnsafe(dst []byte) { + if n.Range.Packed() && n.Target.Packed() { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(n), uintptr(n.SizeBytes())) + } else { + // Type nfNATTarget doesn't have a packed layout in memory, fallback to MarshalBytes. + n.MarshalBytes(dst) + } +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (n *nfNATTarget) UnmarshalUnsafe(src []byte) { + if n.Range.Packed() && n.Target.Packed() { + gohacks.Memmove(unsafe.Pointer(n), unsafe.Pointer(&src[0]), uintptr(n.SizeBytes())) + } else { + // Type nfNATTarget doesn't have a packed layout in memory, fallback to UnmarshalBytes. + n.UnmarshalBytes(src) + } +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (n *nfNATTarget) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + if !n.Range.Packed() && n.Target.Packed() { + // Type nfNATTarget doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := cc.CopyScratchBuffer(n.SizeBytes()) // escapes: okay. + n.MarshalBytes(buf) // escapes: fallback. + return cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(n))) + hdr.Len = n.SizeBytes() + hdr.Cap = n.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that n + // must live until the use above. + runtime.KeepAlive(n) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (n *nfNATTarget) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return n.CopyOutN(cc, addr, n.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (n *nfNATTarget) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + if !n.Range.Packed() && n.Target.Packed() { + // Type nfNATTarget doesn't have a packed layout in memory, fall back to UnmarshalBytes. + buf := cc.CopyScratchBuffer(n.SizeBytes()) // escapes: okay. + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Unmarshal unconditionally. If we had a short copy-in, this results in a + // partially unmarshalled struct. + n.UnmarshalBytes(buf) // escapes: fallback. + return length, err + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(n))) + hdr.Len = n.SizeBytes() + hdr.Cap = n.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that n + // must live until the use above. + runtime.KeepAlive(n) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (n *nfNATTarget) WriteTo(writer io.Writer) (int64, error) { + if !n.Range.Packed() && n.Target.Packed() { + // Type nfNATTarget doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := make([]byte, n.SizeBytes()) + n.MarshalBytes(buf) + length, err := writer.Write(buf) + return int64(length), err + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(n))) + hdr.Len = n.SizeBytes() + hdr.Cap = n.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that n + // must live until the use above. + runtime.KeepAlive(n) // escapes: replaced by intrinsic. + return int64(length), err +} + diff --git a/pkg/sentry/socket/netfilter/netfilter_state_autogen.go b/pkg/sentry/socket/netfilter/netfilter_state_autogen.go new file mode 100644 index 000000000..6e95d89a4 --- /dev/null +++ b/pkg/sentry/socket/netfilter/netfilter_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package netfilter diff --git a/pkg/sentry/socket/netlink/BUILD b/pkg/sentry/socket/netlink/BUILD deleted file mode 100644 index 9710a15ee..000000000 --- a/pkg/sentry/socket/netlink/BUILD +++ /dev/null @@ -1,57 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "netlink", - srcs = [ - "message.go", - "provider.go", - "provider_vfs2.go", - "socket.go", - "socket_vfs2.go", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/abi/linux/errno", - "//pkg/bits", - "//pkg/context", - "//pkg/errors/linuxerr", - "//pkg/hostarch", - "//pkg/marshal", - "//pkg/marshal/primitive", - "//pkg/sentry/arch", - "//pkg/sentry/device", - "//pkg/sentry/fs", - "//pkg/sentry/fs/fsutil", - "//pkg/sentry/fsimpl/sockfs", - "//pkg/sentry/kernel", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/kernel/time", - "//pkg/sentry/socket", - "//pkg/sentry/socket/netlink/port", - "//pkg/sentry/socket/unix", - "//pkg/sentry/socket/unix/transport", - "//pkg/sentry/vfs", - "//pkg/sync", - "//pkg/syserr", - "//pkg/tcpip", - "//pkg/usermem", - "//pkg/waiter", - ], -) - -go_test( - name = "netlink_test", - size = "small", - srcs = [ - "message_test.go", - ], - deps = [ - ":netlink", - "//pkg/abi/linux", - "//pkg/marshal", - "//pkg/marshal/primitive", - ], -) diff --git a/pkg/sentry/socket/netlink/message_test.go b/pkg/sentry/socket/netlink/message_test.go deleted file mode 100644 index 968968469..000000000 --- a/pkg/sentry/socket/netlink/message_test.go +++ /dev/null @@ -1,330 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package message_test - -import ( - "bytes" - "reflect" - "testing" - - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/marshal" - "gvisor.dev/gvisor/pkg/marshal/primitive" - "gvisor.dev/gvisor/pkg/sentry/socket/netlink" -) - -type dummyNetlinkMsg struct { - marshal.StubMarshallable - Foo uint16 -} - -func (*dummyNetlinkMsg) SizeBytes() int { - return 2 -} - -func (m *dummyNetlinkMsg) MarshalUnsafe(dst []byte) { - p := primitive.Uint16(m.Foo) - p.MarshalUnsafe(dst) -} - -func (m *dummyNetlinkMsg) UnmarshalUnsafe(src []byte) { - var p primitive.Uint16 - p.UnmarshalUnsafe(src) - m.Foo = uint16(p) -} - -func TestParseMessage(t *testing.T) { - tests := []struct { - desc string - input []byte - - header linux.NetlinkMessageHeader - dataMsg *dummyNetlinkMsg - restLen int - ok bool - }{ - { - desc: "valid", - input: []byte{ - 0x14, 0x00, 0x00, 0x00, // Length - 0x01, 0x00, // Type - 0x02, 0x00, // Flags - 0x03, 0x00, 0x00, 0x00, // Seq - 0x04, 0x00, 0x00, 0x00, // PortID - 0x30, 0x31, 0x00, 0x00, // Data message with 2 bytes padding - }, - header: linux.NetlinkMessageHeader{ - Length: 20, - Type: 1, - Flags: 2, - Seq: 3, - PortID: 4, - }, - dataMsg: &dummyNetlinkMsg{ - Foo: 0x3130, - }, - restLen: 0, - ok: true, - }, - { - desc: "valid with next message", - input: []byte{ - 0x14, 0x00, 0x00, 0x00, // Length - 0x01, 0x00, // Type - 0x02, 0x00, // Flags - 0x03, 0x00, 0x00, 0x00, // Seq - 0x04, 0x00, 0x00, 0x00, // PortID - 0x30, 0x31, 0x00, 0x00, // Data message with 2 bytes padding - 0xFF, // Next message (rest) - }, - header: linux.NetlinkMessageHeader{ - Length: 20, - Type: 1, - Flags: 2, - Seq: 3, - PortID: 4, - }, - dataMsg: &dummyNetlinkMsg{ - Foo: 0x3130, - }, - restLen: 1, - ok: true, - }, - { - desc: "valid for last message without padding", - input: []byte{ - 0x12, 0x00, 0x00, 0x00, // Length - 0x01, 0x00, // Type - 0x02, 0x00, // Flags - 0x03, 0x00, 0x00, 0x00, // Seq - 0x04, 0x00, 0x00, 0x00, // PortID - 0x30, 0x31, // Data message - }, - header: linux.NetlinkMessageHeader{ - Length: 18, - Type: 1, - Flags: 2, - Seq: 3, - PortID: 4, - }, - dataMsg: &dummyNetlinkMsg{ - Foo: 0x3130, - }, - restLen: 0, - ok: true, - }, - { - desc: "valid for last message not to be aligned", - input: []byte{ - 0x13, 0x00, 0x00, 0x00, // Length - 0x01, 0x00, // Type - 0x02, 0x00, // Flags - 0x03, 0x00, 0x00, 0x00, // Seq - 0x04, 0x00, 0x00, 0x00, // PortID - 0x30, 0x31, // Data message - 0x00, // Excessive 1 byte permitted at end - }, - header: linux.NetlinkMessageHeader{ - Length: 19, - Type: 1, - Flags: 2, - Seq: 3, - PortID: 4, - }, - dataMsg: &dummyNetlinkMsg{ - Foo: 0x3130, - }, - restLen: 0, - ok: true, - }, - { - desc: "header.Length too short", - input: []byte{ - 0x04, 0x00, 0x00, 0x00, // Length - 0x01, 0x00, // Type - 0x02, 0x00, // Flags - 0x03, 0x00, 0x00, 0x00, // Seq - 0x04, 0x00, 0x00, 0x00, // PortID - 0x30, 0x31, 0x00, 0x00, // Data message with 2 bytes padding - }, - ok: false, - }, - { - desc: "header.Length too long", - input: []byte{ - 0xFF, 0xFF, 0x00, 0x00, // Length - 0x01, 0x00, // Type - 0x02, 0x00, // Flags - 0x03, 0x00, 0x00, 0x00, // Seq - 0x04, 0x00, 0x00, 0x00, // PortID - 0x30, 0x31, 0x00, 0x00, // Data message with 2 bytes padding - }, - ok: false, - }, - { - desc: "header incomplete", - input: []byte{ - 0x04, 0x00, 0x00, 0x00, // Length - }, - ok: false, - }, - { - desc: "empty message", - input: []byte{}, - ok: false, - }, - } - for _, test := range tests { - msg, rest, ok := netlink.ParseMessage(test.input) - if ok != test.ok { - t.Errorf("%v: got ok = %v, want = %v", test.desc, ok, test.ok) - continue - } - if !test.ok { - continue - } - if !reflect.DeepEqual(msg.Header(), test.header) { - t.Errorf("%v: got hdr = %+v, want = %+v", test.desc, msg.Header(), test.header) - } - - dataMsg := &dummyNetlinkMsg{} - _, dataOk := msg.GetData(dataMsg) - if !dataOk { - t.Errorf("%v: GetData.ok = %v, want = true", test.desc, dataOk) - } else if !reflect.DeepEqual(dataMsg, test.dataMsg) { - t.Errorf("%v: GetData.msg = %+v, want = %+v", test.desc, dataMsg, test.dataMsg) - } - - if got, want := rest, test.input[len(test.input)-test.restLen:]; !bytes.Equal(got, want) { - t.Errorf("%v: got rest = %v, want = %v", test.desc, got, want) - } - } -} - -func TestAttrView(t *testing.T) { - tests := []struct { - desc string - input []byte - - // Outputs for ParseFirst. - hdr linux.NetlinkAttrHeader - value []byte - restLen int - ok bool - - // Outputs for Empty. - isEmpty bool - }{ - { - desc: "valid", - input: []byte{ - 0x06, 0x00, // Length - 0x01, 0x00, // Type - 0x30, 0x31, 0x00, 0x00, // Data with 2 bytes padding - }, - hdr: linux.NetlinkAttrHeader{ - Length: 6, - Type: 1, - }, - value: []byte{0x30, 0x31}, - restLen: 0, - ok: true, - isEmpty: false, - }, - { - desc: "at alignment", - input: []byte{ - 0x08, 0x00, // Length - 0x01, 0x00, // Type - 0x30, 0x31, 0x32, 0x33, // Data - }, - hdr: linux.NetlinkAttrHeader{ - Length: 8, - Type: 1, - }, - value: []byte{0x30, 0x31, 0x32, 0x33}, - restLen: 0, - ok: true, - isEmpty: false, - }, - { - desc: "at alignment with rest data", - input: []byte{ - 0x08, 0x00, // Length - 0x01, 0x00, // Type - 0x30, 0x31, 0x32, 0x33, // Data - 0xFF, 0xFE, // Rest data - }, - hdr: linux.NetlinkAttrHeader{ - Length: 8, - Type: 1, - }, - value: []byte{0x30, 0x31, 0x32, 0x33}, - restLen: 2, - ok: true, - isEmpty: false, - }, - { - desc: "hdr.Length too long", - input: []byte{ - 0xFF, 0x00, // Length - 0x01, 0x00, // Type - 0x30, 0x31, 0x32, 0x33, // Data - }, - ok: false, - isEmpty: false, - }, - { - desc: "hdr.Length too short", - input: []byte{ - 0x01, 0x00, // Length - 0x01, 0x00, // Type - 0x30, 0x31, 0x32, 0x33, // Data - }, - ok: false, - isEmpty: false, - }, - { - desc: "empty", - input: []byte{}, - ok: false, - isEmpty: true, - }, - } - for _, test := range tests { - attrs := netlink.AttrsView(test.input) - - // Test ParseFirst(). - hdr, value, rest, ok := attrs.ParseFirst() - if ok != test.ok { - t.Errorf("%v: got ok = %v, want = %v", test.desc, ok, test.ok) - } else if test.ok { - if !reflect.DeepEqual(hdr, test.hdr) { - t.Errorf("%v: got hdr = %+v, want = %+v", test.desc, hdr, test.hdr) - } - if !bytes.Equal(value, test.value) { - t.Errorf("%v: got value = %v, want = %v", test.desc, value, test.value) - } - if wantRest := test.input[len(test.input)-test.restLen:]; !bytes.Equal(rest, wantRest) { - t.Errorf("%v: got rest = %v, want = %v", test.desc, rest, wantRest) - } - } - - // Test Empty(). - if got, want := attrs.Empty(), test.isEmpty; got != want { - t.Errorf("%v: got empty = %v, want = %v", test.desc, got, want) - } - } -} diff --git a/pkg/sentry/socket/netlink/netlink_state_autogen.go b/pkg/sentry/socket/netlink/netlink_state_autogen.go new file mode 100644 index 000000000..29c549880 --- /dev/null +++ b/pkg/sentry/socket/netlink/netlink_state_autogen.go @@ -0,0 +1,149 @@ +// automatically generated by stateify. + +package netlink + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (s *Socket) StateTypeName() string { + return "pkg/sentry/socket/netlink.Socket" +} + +func (s *Socket) StateFields() []string { + return []string{ + "socketOpsCommon", + } +} + +func (s *Socket) beforeSave() {} + +// +checklocksignore +func (s *Socket) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + stateSinkObject.Save(0, &s.socketOpsCommon) +} + +func (s *Socket) afterLoad() {} + +// +checklocksignore +func (s *Socket) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.socketOpsCommon) +} + +func (s *socketOpsCommon) StateTypeName() string { + return "pkg/sentry/socket/netlink.socketOpsCommon" +} + +func (s *socketOpsCommon) StateFields() []string { + return []string{ + "SendReceiveTimeout", + "ports", + "protocol", + "skType", + "ep", + "connection", + "bound", + "portID", + "sendBufferSize", + "filter", + } +} + +func (s *socketOpsCommon) beforeSave() {} + +// +checklocksignore +func (s *socketOpsCommon) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + stateSinkObject.Save(0, &s.SendReceiveTimeout) + stateSinkObject.Save(1, &s.ports) + stateSinkObject.Save(2, &s.protocol) + stateSinkObject.Save(3, &s.skType) + stateSinkObject.Save(4, &s.ep) + stateSinkObject.Save(5, &s.connection) + stateSinkObject.Save(6, &s.bound) + stateSinkObject.Save(7, &s.portID) + stateSinkObject.Save(8, &s.sendBufferSize) + stateSinkObject.Save(9, &s.filter) +} + +func (s *socketOpsCommon) afterLoad() {} + +// +checklocksignore +func (s *socketOpsCommon) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.SendReceiveTimeout) + stateSourceObject.Load(1, &s.ports) + stateSourceObject.Load(2, &s.protocol) + stateSourceObject.Load(3, &s.skType) + stateSourceObject.Load(4, &s.ep) + stateSourceObject.Load(5, &s.connection) + stateSourceObject.Load(6, &s.bound) + stateSourceObject.Load(7, &s.portID) + stateSourceObject.Load(8, &s.sendBufferSize) + stateSourceObject.Load(9, &s.filter) +} + +func (k *kernelSCM) StateTypeName() string { + return "pkg/sentry/socket/netlink.kernelSCM" +} + +func (k *kernelSCM) StateFields() []string { + return []string{} +} + +func (k *kernelSCM) beforeSave() {} + +// +checklocksignore +func (k *kernelSCM) StateSave(stateSinkObject state.Sink) { + k.beforeSave() +} + +func (k *kernelSCM) afterLoad() {} + +// +checklocksignore +func (k *kernelSCM) StateLoad(stateSourceObject state.Source) { +} + +func (s *SocketVFS2) StateTypeName() string { + return "pkg/sentry/socket/netlink.SocketVFS2" +} + +func (s *SocketVFS2) StateFields() []string { + return []string{ + "vfsfd", + "FileDescriptionDefaultImpl", + "DentryMetadataFileDescriptionImpl", + "LockFD", + "socketOpsCommon", + } +} + +func (s *SocketVFS2) beforeSave() {} + +// +checklocksignore +func (s *SocketVFS2) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + stateSinkObject.Save(0, &s.vfsfd) + stateSinkObject.Save(1, &s.FileDescriptionDefaultImpl) + stateSinkObject.Save(2, &s.DentryMetadataFileDescriptionImpl) + stateSinkObject.Save(3, &s.LockFD) + stateSinkObject.Save(4, &s.socketOpsCommon) +} + +func (s *SocketVFS2) afterLoad() {} + +// +checklocksignore +func (s *SocketVFS2) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.vfsfd) + stateSourceObject.Load(1, &s.FileDescriptionDefaultImpl) + stateSourceObject.Load(2, &s.DentryMetadataFileDescriptionImpl) + stateSourceObject.Load(3, &s.LockFD) + stateSourceObject.Load(4, &s.socketOpsCommon) +} + +func init() { + state.Register((*Socket)(nil)) + state.Register((*socketOpsCommon)(nil)) + state.Register((*kernelSCM)(nil)) + state.Register((*SocketVFS2)(nil)) +} diff --git a/pkg/sentry/socket/netlink/port/BUILD b/pkg/sentry/socket/netlink/port/BUILD deleted file mode 100644 index 3a22923d8..000000000 --- a/pkg/sentry/socket/netlink/port/BUILD +++ /dev/null @@ -1,16 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "port", - srcs = ["port.go"], - visibility = ["//pkg/sentry:internal"], - deps = ["//pkg/sync"], -) - -go_test( - name = "port_test", - srcs = ["port_test.go"], - library = ":port", -) diff --git a/pkg/sentry/socket/netlink/port/port_state_autogen.go b/pkg/sentry/socket/netlink/port/port_state_autogen.go new file mode 100644 index 000000000..b22471899 --- /dev/null +++ b/pkg/sentry/socket/netlink/port/port_state_autogen.go @@ -0,0 +1,36 @@ +// automatically generated by stateify. + +package port + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (m *Manager) StateTypeName() string { + return "pkg/sentry/socket/netlink/port.Manager" +} + +func (m *Manager) StateFields() []string { + return []string{ + "ports", + } +} + +func (m *Manager) beforeSave() {} + +// +checklocksignore +func (m *Manager) StateSave(stateSinkObject state.Sink) { + m.beforeSave() + stateSinkObject.Save(0, &m.ports) +} + +func (m *Manager) afterLoad() {} + +// +checklocksignore +func (m *Manager) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &m.ports) +} + +func init() { + state.Register((*Manager)(nil)) +} diff --git a/pkg/sentry/socket/netlink/port/port_test.go b/pkg/sentry/socket/netlink/port/port_test.go deleted file mode 100644 index 516f6cd6c..000000000 --- a/pkg/sentry/socket/netlink/port/port_test.go +++ /dev/null @@ -1,82 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package port - -import ( - "testing" -) - -func TestAllocateHint(t *testing.T) { - m := New() - - // We can get the hint port. - p, ok := m.Allocate(0, 1) - if !ok { - t.Errorf("m.Allocate got !ok want ok") - } - if p != 1 { - t.Errorf("m.Allocate(0, 1) got %d want 1", p) - } - - // Hint is taken. - p, ok = m.Allocate(0, 1) - if !ok { - t.Errorf("m.Allocate got !ok want ok") - } - if p == 1 { - t.Errorf("m.Allocate(0, 1) got 1 want anything else") - } - - // Hint is available for a different protocol. - p, ok = m.Allocate(1, 1) - if !ok { - t.Errorf("m.Allocate got !ok want ok") - } - if p != 1 { - t.Errorf("m.Allocate(1, 1) got %d want 1", p) - } - - m.Release(0, 1) - - // Hint is available again after release. - p, ok = m.Allocate(0, 1) - if !ok { - t.Errorf("m.Allocate got !ok want ok") - } - if p != 1 { - t.Errorf("m.Allocate(0, 1) got %d want 1", p) - } -} - -func TestAllocateExhausted(t *testing.T) { - m := New() - - // Fill all ports (0 is already reserved). - for i := int32(1); i < maxPorts; i++ { - p, ok := m.Allocate(0, i) - if !ok { - t.Fatalf("m.Allocate got !ok want ok") - } - if p != i { - t.Fatalf("m.Allocate(0, %d) got %d want %d", i, p, i) - } - } - - // Now no more can be allocated. - p, ok := m.Allocate(0, 1) - if ok { - t.Errorf("m.Allocate got %d, ok want !ok", p) - } -} diff --git a/pkg/sentry/socket/netlink/route/BUILD b/pkg/sentry/socket/netlink/route/BUILD deleted file mode 100644 index c6c04b4e3..000000000 --- a/pkg/sentry/socket/netlink/route/BUILD +++ /dev/null @@ -1,22 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "route", - srcs = [ - "protocol.go", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/marshal/primitive", - "//pkg/sentry/inet", - "//pkg/sentry/kernel", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/socket/netlink", - "//pkg/syserr", - "@org_golang_x_sys//unix:go_default_library", - ], -) diff --git a/pkg/sentry/socket/netlink/route/route_state_autogen.go b/pkg/sentry/socket/netlink/route/route_state_autogen.go new file mode 100644 index 000000000..c4a94ab49 --- /dev/null +++ b/pkg/sentry/socket/netlink/route/route_state_autogen.go @@ -0,0 +1,32 @@ +// automatically generated by stateify. + +package route + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (p *Protocol) StateTypeName() string { + return "pkg/sentry/socket/netlink/route.Protocol" +} + +func (p *Protocol) StateFields() []string { + return []string{} +} + +func (p *Protocol) beforeSave() {} + +// +checklocksignore +func (p *Protocol) StateSave(stateSinkObject state.Sink) { + p.beforeSave() +} + +func (p *Protocol) afterLoad() {} + +// +checklocksignore +func (p *Protocol) StateLoad(stateSourceObject state.Source) { +} + +func init() { + state.Register((*Protocol)(nil)) +} diff --git a/pkg/sentry/socket/netlink/uevent/BUILD b/pkg/sentry/socket/netlink/uevent/BUILD deleted file mode 100644 index b6434923c..000000000 --- a/pkg/sentry/socket/netlink/uevent/BUILD +++ /dev/null @@ -1,16 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "uevent", - srcs = ["protocol.go"], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/sentry/kernel", - "//pkg/sentry/socket/netlink", - "//pkg/syserr", - ], -) diff --git a/pkg/sentry/socket/netlink/uevent/uevent_state_autogen.go b/pkg/sentry/socket/netlink/uevent/uevent_state_autogen.go new file mode 100644 index 000000000..f45d63d9a --- /dev/null +++ b/pkg/sentry/socket/netlink/uevent/uevent_state_autogen.go @@ -0,0 +1,32 @@ +// automatically generated by stateify. + +package uevent + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (p *Protocol) StateTypeName() string { + return "pkg/sentry/socket/netlink/uevent.Protocol" +} + +func (p *Protocol) StateFields() []string { + return []string{} +} + +func (p *Protocol) beforeSave() {} + +// +checklocksignore +func (p *Protocol) StateSave(stateSinkObject state.Sink) { + p.beforeSave() +} + +func (p *Protocol) afterLoad() {} + +// +checklocksignore +func (p *Protocol) StateLoad(stateSourceObject state.Source) { +} + +func init() { + state.Register((*Protocol)(nil)) +} diff --git a/pkg/sentry/socket/netstack/BUILD b/pkg/sentry/socket/netstack/BUILD deleted file mode 100644 index e347442e7..000000000 --- a/pkg/sentry/socket/netstack/BUILD +++ /dev/null @@ -1,57 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "netstack", - srcs = [ - "device.go", - "netstack.go", - "netstack_vfs2.go", - "provider.go", - "provider_vfs2.go", - "save_restore.go", - "stack.go", - "tun.go", - ], - visibility = [ - "//pkg/sentry:internal", - ], - deps = [ - "//pkg/abi/linux", - "//pkg/abi/linux/errno", - "//pkg/context", - "//pkg/errors/linuxerr", - "//pkg/hostarch", - "//pkg/log", - "//pkg/marshal", - "//pkg/marshal/primitive", - "//pkg/metric", - "//pkg/sentry/arch", - "//pkg/sentry/device", - "//pkg/sentry/fs", - "//pkg/sentry/fs/fsutil", - "//pkg/sentry/fsimpl/sockfs", - "//pkg/sentry/inet", - "//pkg/sentry/kernel", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/kernel/time", - "//pkg/sentry/socket", - "//pkg/sentry/socket/netfilter", - "//pkg/sentry/unimpl", - "//pkg/sentry/vfs", - "//pkg/sync", - "//pkg/syserr", - "//pkg/tcpip", - "//pkg/tcpip/header", - "//pkg/tcpip/link/tun", - "//pkg/tcpip/network/ipv4", - "//pkg/tcpip/network/ipv6", - "//pkg/tcpip/stack", - "//pkg/tcpip/transport/tcp", - "//pkg/tcpip/transport/udp", - "//pkg/usermem", - "//pkg/waiter", - "@org_golang_x_sys//unix:go_default_library", - ], -) diff --git a/pkg/sentry/socket/netstack/netstack_state_autogen.go b/pkg/sentry/socket/netstack/netstack_state_autogen.go new file mode 100644 index 000000000..335437f04 --- /dev/null +++ b/pkg/sentry/socket/netstack/netstack_state_autogen.go @@ -0,0 +1,148 @@ +// automatically generated by stateify. + +package netstack + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (s *SocketOperations) StateTypeName() string { + return "pkg/sentry/socket/netstack.SocketOperations" +} + +func (s *SocketOperations) StateFields() []string { + return []string{ + "socketOpsCommon", + } +} + +func (s *SocketOperations) beforeSave() {} + +// +checklocksignore +func (s *SocketOperations) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + stateSinkObject.Save(0, &s.socketOpsCommon) +} + +func (s *SocketOperations) afterLoad() {} + +// +checklocksignore +func (s *SocketOperations) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.socketOpsCommon) +} + +func (s *socketOpsCommon) StateTypeName() string { + return "pkg/sentry/socket/netstack.socketOpsCommon" +} + +func (s *socketOpsCommon) StateFields() []string { + return []string{ + "SendReceiveTimeout", + "Queue", + "family", + "Endpoint", + "skType", + "protocol", + "sockOptTimestamp", + "timestampValid", + "timestampNS", + "sockOptInq", + } +} + +func (s *socketOpsCommon) beforeSave() {} + +// +checklocksignore +func (s *socketOpsCommon) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + stateSinkObject.Save(0, &s.SendReceiveTimeout) + stateSinkObject.Save(1, &s.Queue) + stateSinkObject.Save(2, &s.family) + stateSinkObject.Save(3, &s.Endpoint) + stateSinkObject.Save(4, &s.skType) + stateSinkObject.Save(5, &s.protocol) + stateSinkObject.Save(6, &s.sockOptTimestamp) + stateSinkObject.Save(7, &s.timestampValid) + stateSinkObject.Save(8, &s.timestampNS) + stateSinkObject.Save(9, &s.sockOptInq) +} + +func (s *socketOpsCommon) afterLoad() {} + +// +checklocksignore +func (s *socketOpsCommon) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.SendReceiveTimeout) + stateSourceObject.Load(1, &s.Queue) + stateSourceObject.Load(2, &s.family) + stateSourceObject.Load(3, &s.Endpoint) + stateSourceObject.Load(4, &s.skType) + stateSourceObject.Load(5, &s.protocol) + stateSourceObject.Load(6, &s.sockOptTimestamp) + stateSourceObject.Load(7, &s.timestampValid) + stateSourceObject.Load(8, &s.timestampNS) + stateSourceObject.Load(9, &s.sockOptInq) +} + +func (s *SocketVFS2) StateTypeName() string { + return "pkg/sentry/socket/netstack.SocketVFS2" +} + +func (s *SocketVFS2) StateFields() []string { + return []string{ + "vfsfd", + "FileDescriptionDefaultImpl", + "DentryMetadataFileDescriptionImpl", + "LockFD", + "socketOpsCommon", + } +} + +func (s *SocketVFS2) beforeSave() {} + +// +checklocksignore +func (s *SocketVFS2) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + stateSinkObject.Save(0, &s.vfsfd) + stateSinkObject.Save(1, &s.FileDescriptionDefaultImpl) + stateSinkObject.Save(2, &s.DentryMetadataFileDescriptionImpl) + stateSinkObject.Save(3, &s.LockFD) + stateSinkObject.Save(4, &s.socketOpsCommon) +} + +func (s *SocketVFS2) afterLoad() {} + +// +checklocksignore +func (s *SocketVFS2) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.vfsfd) + stateSourceObject.Load(1, &s.FileDescriptionDefaultImpl) + stateSourceObject.Load(2, &s.DentryMetadataFileDescriptionImpl) + stateSourceObject.Load(3, &s.LockFD) + stateSourceObject.Load(4, &s.socketOpsCommon) +} + +func (s *Stack) StateTypeName() string { + return "pkg/sentry/socket/netstack.Stack" +} + +func (s *Stack) StateFields() []string { + return []string{} +} + +func (s *Stack) beforeSave() {} + +// +checklocksignore +func (s *Stack) StateSave(stateSinkObject state.Sink) { + s.beforeSave() +} + +// +checklocksignore +func (s *Stack) StateLoad(stateSourceObject state.Source) { + stateSourceObject.AfterLoad(s.afterLoad) +} + +func init() { + state.Register((*SocketOperations)(nil)) + state.Register((*socketOpsCommon)(nil)) + state.Register((*SocketVFS2)(nil)) + state.Register((*Stack)(nil)) +} diff --git a/pkg/sentry/socket/socket_state_autogen.go b/pkg/sentry/socket/socket_state_autogen.go new file mode 100644 index 000000000..d050093f6 --- /dev/null +++ b/pkg/sentry/socket/socket_state_autogen.go @@ -0,0 +1,98 @@ +// automatically generated by stateify. + +package socket + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (i *IPControlMessages) StateTypeName() string { + return "pkg/sentry/socket.IPControlMessages" +} + +func (i *IPControlMessages) StateFields() []string { + return []string{ + "HasTimestamp", + "Timestamp", + "HasInq", + "Inq", + "HasTOS", + "TOS", + "HasTClass", + "TClass", + "HasIPPacketInfo", + "PacketInfo", + "OriginalDstAddress", + "SockErr", + } +} + +func (i *IPControlMessages) beforeSave() {} + +// +checklocksignore +func (i *IPControlMessages) StateSave(stateSinkObject state.Sink) { + i.beforeSave() + stateSinkObject.Save(0, &i.HasTimestamp) + stateSinkObject.Save(1, &i.Timestamp) + stateSinkObject.Save(2, &i.HasInq) + stateSinkObject.Save(3, &i.Inq) + stateSinkObject.Save(4, &i.HasTOS) + stateSinkObject.Save(5, &i.TOS) + stateSinkObject.Save(6, &i.HasTClass) + stateSinkObject.Save(7, &i.TClass) + stateSinkObject.Save(8, &i.HasIPPacketInfo) + stateSinkObject.Save(9, &i.PacketInfo) + stateSinkObject.Save(10, &i.OriginalDstAddress) + stateSinkObject.Save(11, &i.SockErr) +} + +func (i *IPControlMessages) afterLoad() {} + +// +checklocksignore +func (i *IPControlMessages) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &i.HasTimestamp) + stateSourceObject.Load(1, &i.Timestamp) + stateSourceObject.Load(2, &i.HasInq) + stateSourceObject.Load(3, &i.Inq) + stateSourceObject.Load(4, &i.HasTOS) + stateSourceObject.Load(5, &i.TOS) + stateSourceObject.Load(6, &i.HasTClass) + stateSourceObject.Load(7, &i.TClass) + stateSourceObject.Load(8, &i.HasIPPacketInfo) + stateSourceObject.Load(9, &i.PacketInfo) + stateSourceObject.Load(10, &i.OriginalDstAddress) + stateSourceObject.Load(11, &i.SockErr) +} + +func (to *SendReceiveTimeout) StateTypeName() string { + return "pkg/sentry/socket.SendReceiveTimeout" +} + +func (to *SendReceiveTimeout) StateFields() []string { + return []string{ + "send", + "recv", + } +} + +func (to *SendReceiveTimeout) beforeSave() {} + +// +checklocksignore +func (to *SendReceiveTimeout) StateSave(stateSinkObject state.Sink) { + to.beforeSave() + stateSinkObject.Save(0, &to.send) + stateSinkObject.Save(1, &to.recv) +} + +func (to *SendReceiveTimeout) afterLoad() {} + +// +checklocksignore +func (to *SendReceiveTimeout) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &to.send) + stateSourceObject.Load(1, &to.recv) +} + +func init() { + state.Register((*IPControlMessages)(nil)) + state.Register((*SendReceiveTimeout)(nil)) +} diff --git a/pkg/sentry/socket/unix/BUILD b/pkg/sentry/socket/unix/BUILD deleted file mode 100644 index 7b546c04d..000000000 --- a/pkg/sentry/socket/unix/BUILD +++ /dev/null @@ -1,70 +0,0 @@ -load("//tools:defs.bzl", "go_library") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -package(licenses = ["notice"]) - -go_template_instance( - name = "socket_refs", - out = "socket_refs.go", - package = "unix", - prefix = "socketOperations", - template = "//pkg/refsvfs2:refs_template", - types = { - "T": "SocketOperations", - }, -) - -go_template_instance( - name = "socket_vfs2_refs", - out = "socket_vfs2_refs.go", - package = "unix", - prefix = "socketVFS2", - template = "//pkg/refsvfs2:refs_template", - types = { - "T": "SocketVFS2", - }, -) - -go_library( - name = "unix", - srcs = [ - "device.go", - "io.go", - "socket_refs.go", - "socket_vfs2_refs.go", - "unix.go", - "unix_vfs2.go", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/errors/linuxerr", - "//pkg/fspath", - "//pkg/hostarch", - "//pkg/log", - "//pkg/marshal", - "//pkg/refs", - "//pkg/refsvfs2", - "//pkg/safemem", - "//pkg/sentry/arch", - "//pkg/sentry/device", - "//pkg/sentry/fs", - "//pkg/sentry/fs/fsutil", - "//pkg/sentry/fs/lock", - "//pkg/sentry/fsimpl/sockfs", - "//pkg/sentry/inet", - "//pkg/sentry/kernel", - "//pkg/sentry/kernel/time", - "//pkg/sentry/socket", - "//pkg/sentry/socket/control", - "//pkg/sentry/socket/netstack", - "//pkg/sentry/socket/unix/transport", - "//pkg/sentry/vfs", - "//pkg/syserr", - "//pkg/tcpip", - "//pkg/usermem", - "//pkg/waiter", - "@org_golang_x_sys//unix:go_default_library", - ], -) diff --git a/pkg/sentry/socket/unix/socket_refs.go b/pkg/sentry/socket/unix/socket_refs.go new file mode 100644 index 000000000..61c6bd17c --- /dev/null +++ b/pkg/sentry/socket/unix/socket_refs.go @@ -0,0 +1,140 @@ +package unix + +import ( + "fmt" + "sync/atomic" + + "gvisor.dev/gvisor/pkg/refsvfs2" +) + +// enableLogging indicates whether reference-related events should be logged (with +// stack traces). This is false by default and should only be set to true for +// debugging purposes, as it can generate an extremely large amount of output +// and drastically degrade performance. +const socketOperationsenableLogging = false + +// obj is used to customize logging. Note that we use a pointer to T so that +// we do not copy the entire object when passed as a format parameter. +var socketOperationsobj *SocketOperations + +// Refs implements refs.RefCounter. It keeps a reference count using atomic +// operations and calls the destructor when the count reaches zero. +// +// NOTE: Do not introduce additional fields to the Refs struct. It is used by +// many filesystem objects, and we want to keep it as small as possible (i.e., +// the same size as using an int64 directly) to avoid taking up extra cache +// space. In general, this template should not be extended at the cost of +// performance. If it does not offer enough flexibility for a particular object +// (example: b/187877947), we should implement the RefCounter/CheckedObject +// interfaces manually. +// +// +stateify savable +type socketOperationsRefs struct { + // refCount is composed of two fields: + // + // [32-bit speculative references]:[32-bit real references] + // + // Speculative references are used for TryIncRef, to avoid a CompareAndSwap + // loop. See IncRef, DecRef and TryIncRef for details of how these fields are + // used. + refCount int64 +} + +// InitRefs initializes r with one reference and, if enabled, activates leak +// checking. +func (r *socketOperationsRefs) InitRefs() { + atomic.StoreInt64(&r.refCount, 1) + refsvfs2.Register(r) +} + +// RefType implements refsvfs2.CheckedObject.RefType. +func (r *socketOperationsRefs) RefType() string { + return fmt.Sprintf("%T", socketOperationsobj)[1:] +} + +// LeakMessage implements refsvfs2.CheckedObject.LeakMessage. +func (r *socketOperationsRefs) LeakMessage() string { + return fmt.Sprintf("[%s %p] reference count of %d instead of 0", r.RefType(), r, r.ReadRefs()) +} + +// LogRefs implements refsvfs2.CheckedObject.LogRefs. +func (r *socketOperationsRefs) LogRefs() bool { + return socketOperationsenableLogging +} + +// ReadRefs returns the current number of references. The returned count is +// inherently racy and is unsafe to use without external synchronization. +func (r *socketOperationsRefs) ReadRefs() int64 { + return atomic.LoadInt64(&r.refCount) +} + +// IncRef implements refs.RefCounter.IncRef. +// +//go:nosplit +func (r *socketOperationsRefs) IncRef() { + v := atomic.AddInt64(&r.refCount, 1) + if socketOperationsenableLogging { + refsvfs2.LogIncRef(r, v) + } + if v <= 1 { + panic(fmt.Sprintf("Incrementing non-positive count %p on %s", r, r.RefType())) + } +} + +// TryIncRef implements refs.TryRefCounter.TryIncRef. +// +// To do this safely without a loop, a speculative reference is first acquired +// on the object. This allows multiple concurrent TryIncRef calls to distinguish +// other TryIncRef calls from genuine references held. +// +//go:nosplit +func (r *socketOperationsRefs) TryIncRef() bool { + const speculativeRef = 1 << 32 + if v := atomic.AddInt64(&r.refCount, speculativeRef); int32(v) == 0 { + + atomic.AddInt64(&r.refCount, -speculativeRef) + return false + } + + v := atomic.AddInt64(&r.refCount, -speculativeRef+1) + if socketOperationsenableLogging { + refsvfs2.LogTryIncRef(r, v) + } + return true +} + +// DecRef implements refs.RefCounter.DecRef. +// +// Note that speculative references are counted here. Since they were added +// prior to real references reaching zero, they will successfully convert to +// real references. In other words, we see speculative references only in the +// following case: +// +// A: TryIncRef [speculative increase => sees non-negative references] +// B: DecRef [real decrease] +// A: TryIncRef [transform speculative to real] +// +//go:nosplit +func (r *socketOperationsRefs) DecRef(destroy func()) { + v := atomic.AddInt64(&r.refCount, -1) + if socketOperationsenableLogging { + refsvfs2.LogDecRef(r, v) + } + switch { + case v < 0: + panic(fmt.Sprintf("Decrementing non-positive ref count %p, owned by %s", r, r.RefType())) + + case v == 0: + refsvfs2.Unregister(r) + + if destroy != nil { + destroy() + } + } +} + +func (r *socketOperationsRefs) afterLoad() { + if r.ReadRefs() > 0 { + refsvfs2.Register(r) + } +} diff --git a/pkg/sentry/socket/unix/socket_vfs2_refs.go b/pkg/sentry/socket/unix/socket_vfs2_refs.go new file mode 100644 index 000000000..f6ef581d8 --- /dev/null +++ b/pkg/sentry/socket/unix/socket_vfs2_refs.go @@ -0,0 +1,140 @@ +package unix + +import ( + "fmt" + "sync/atomic" + + "gvisor.dev/gvisor/pkg/refsvfs2" +) + +// enableLogging indicates whether reference-related events should be logged (with +// stack traces). This is false by default and should only be set to true for +// debugging purposes, as it can generate an extremely large amount of output +// and drastically degrade performance. +const socketVFS2enableLogging = false + +// obj is used to customize logging. Note that we use a pointer to T so that +// we do not copy the entire object when passed as a format parameter. +var socketVFS2obj *SocketVFS2 + +// Refs implements refs.RefCounter. It keeps a reference count using atomic +// operations and calls the destructor when the count reaches zero. +// +// NOTE: Do not introduce additional fields to the Refs struct. It is used by +// many filesystem objects, and we want to keep it as small as possible (i.e., +// the same size as using an int64 directly) to avoid taking up extra cache +// space. In general, this template should not be extended at the cost of +// performance. If it does not offer enough flexibility for a particular object +// (example: b/187877947), we should implement the RefCounter/CheckedObject +// interfaces manually. +// +// +stateify savable +type socketVFS2Refs struct { + // refCount is composed of two fields: + // + // [32-bit speculative references]:[32-bit real references] + // + // Speculative references are used for TryIncRef, to avoid a CompareAndSwap + // loop. See IncRef, DecRef and TryIncRef for details of how these fields are + // used. + refCount int64 +} + +// InitRefs initializes r with one reference and, if enabled, activates leak +// checking. +func (r *socketVFS2Refs) InitRefs() { + atomic.StoreInt64(&r.refCount, 1) + refsvfs2.Register(r) +} + +// RefType implements refsvfs2.CheckedObject.RefType. +func (r *socketVFS2Refs) RefType() string { + return fmt.Sprintf("%T", socketVFS2obj)[1:] +} + +// LeakMessage implements refsvfs2.CheckedObject.LeakMessage. +func (r *socketVFS2Refs) LeakMessage() string { + return fmt.Sprintf("[%s %p] reference count of %d instead of 0", r.RefType(), r, r.ReadRefs()) +} + +// LogRefs implements refsvfs2.CheckedObject.LogRefs. +func (r *socketVFS2Refs) LogRefs() bool { + return socketVFS2enableLogging +} + +// ReadRefs returns the current number of references. The returned count is +// inherently racy and is unsafe to use without external synchronization. +func (r *socketVFS2Refs) ReadRefs() int64 { + return atomic.LoadInt64(&r.refCount) +} + +// IncRef implements refs.RefCounter.IncRef. +// +//go:nosplit +func (r *socketVFS2Refs) IncRef() { + v := atomic.AddInt64(&r.refCount, 1) + if socketVFS2enableLogging { + refsvfs2.LogIncRef(r, v) + } + if v <= 1 { + panic(fmt.Sprintf("Incrementing non-positive count %p on %s", r, r.RefType())) + } +} + +// TryIncRef implements refs.TryRefCounter.TryIncRef. +// +// To do this safely without a loop, a speculative reference is first acquired +// on the object. This allows multiple concurrent TryIncRef calls to distinguish +// other TryIncRef calls from genuine references held. +// +//go:nosplit +func (r *socketVFS2Refs) TryIncRef() bool { + const speculativeRef = 1 << 32 + if v := atomic.AddInt64(&r.refCount, speculativeRef); int32(v) == 0 { + + atomic.AddInt64(&r.refCount, -speculativeRef) + return false + } + + v := atomic.AddInt64(&r.refCount, -speculativeRef+1) + if socketVFS2enableLogging { + refsvfs2.LogTryIncRef(r, v) + } + return true +} + +// DecRef implements refs.RefCounter.DecRef. +// +// Note that speculative references are counted here. Since they were added +// prior to real references reaching zero, they will successfully convert to +// real references. In other words, we see speculative references only in the +// following case: +// +// A: TryIncRef [speculative increase => sees non-negative references] +// B: DecRef [real decrease] +// A: TryIncRef [transform speculative to real] +// +//go:nosplit +func (r *socketVFS2Refs) DecRef(destroy func()) { + v := atomic.AddInt64(&r.refCount, -1) + if socketVFS2enableLogging { + refsvfs2.LogDecRef(r, v) + } + switch { + case v < 0: + panic(fmt.Sprintf("Decrementing non-positive ref count %p, owned by %s", r, r.RefType())) + + case v == 0: + refsvfs2.Unregister(r) + + if destroy != nil { + destroy() + } + } +} + +func (r *socketVFS2Refs) afterLoad() { + if r.ReadRefs() > 0 { + refsvfs2.Register(r) + } +} diff --git a/pkg/sentry/socket/unix/transport/BUILD b/pkg/sentry/socket/unix/transport/BUILD deleted file mode 100644 index 0d11bb251..000000000 --- a/pkg/sentry/socket/unix/transport/BUILD +++ /dev/null @@ -1,56 +0,0 @@ -load("//tools:defs.bzl", "go_library") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -package(licenses = ["notice"]) - -go_template_instance( - name = "transport_message_list", - out = "transport_message_list.go", - package = "transport", - prefix = "message", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*message", - "Linker": "*message", - }, -) - -go_template_instance( - name = "queue_refs", - out = "queue_refs.go", - package = "transport", - prefix = "queue", - template = "//pkg/refsvfs2:refs_template", - types = { - "T": "queue", - }, -) - -go_library( - name = "transport", - srcs = [ - "connectioned.go", - "connectioned_state.go", - "connectionless.go", - "connectionless_state.go", - "queue.go", - "queue_refs.go", - "transport_message_list.go", - "unix.go", - ], - visibility = ["//:sandbox"], - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/ilist", - "//pkg/log", - "//pkg/refs", - "//pkg/refsvfs2", - "//pkg/sentry/inet", - "//pkg/sync", - "//pkg/syserr", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/waiter", - ], -) diff --git a/pkg/sentry/socket/unix/transport/queue_refs.go b/pkg/sentry/socket/unix/transport/queue_refs.go new file mode 100644 index 000000000..9e6f52616 --- /dev/null +++ b/pkg/sentry/socket/unix/transport/queue_refs.go @@ -0,0 +1,140 @@ +package transport + +import ( + "fmt" + "sync/atomic" + + "gvisor.dev/gvisor/pkg/refsvfs2" +) + +// enableLogging indicates whether reference-related events should be logged (with +// stack traces). This is false by default and should only be set to true for +// debugging purposes, as it can generate an extremely large amount of output +// and drastically degrade performance. +const queueenableLogging = false + +// obj is used to customize logging. Note that we use a pointer to T so that +// we do not copy the entire object when passed as a format parameter. +var queueobj *queue + +// Refs implements refs.RefCounter. It keeps a reference count using atomic +// operations and calls the destructor when the count reaches zero. +// +// NOTE: Do not introduce additional fields to the Refs struct. It is used by +// many filesystem objects, and we want to keep it as small as possible (i.e., +// the same size as using an int64 directly) to avoid taking up extra cache +// space. In general, this template should not be extended at the cost of +// performance. If it does not offer enough flexibility for a particular object +// (example: b/187877947), we should implement the RefCounter/CheckedObject +// interfaces manually. +// +// +stateify savable +type queueRefs struct { + // refCount is composed of two fields: + // + // [32-bit speculative references]:[32-bit real references] + // + // Speculative references are used for TryIncRef, to avoid a CompareAndSwap + // loop. See IncRef, DecRef and TryIncRef for details of how these fields are + // used. + refCount int64 +} + +// InitRefs initializes r with one reference and, if enabled, activates leak +// checking. +func (r *queueRefs) InitRefs() { + atomic.StoreInt64(&r.refCount, 1) + refsvfs2.Register(r) +} + +// RefType implements refsvfs2.CheckedObject.RefType. +func (r *queueRefs) RefType() string { + return fmt.Sprintf("%T", queueobj)[1:] +} + +// LeakMessage implements refsvfs2.CheckedObject.LeakMessage. +func (r *queueRefs) LeakMessage() string { + return fmt.Sprintf("[%s %p] reference count of %d instead of 0", r.RefType(), r, r.ReadRefs()) +} + +// LogRefs implements refsvfs2.CheckedObject.LogRefs. +func (r *queueRefs) LogRefs() bool { + return queueenableLogging +} + +// ReadRefs returns the current number of references. The returned count is +// inherently racy and is unsafe to use without external synchronization. +func (r *queueRefs) ReadRefs() int64 { + return atomic.LoadInt64(&r.refCount) +} + +// IncRef implements refs.RefCounter.IncRef. +// +//go:nosplit +func (r *queueRefs) IncRef() { + v := atomic.AddInt64(&r.refCount, 1) + if queueenableLogging { + refsvfs2.LogIncRef(r, v) + } + if v <= 1 { + panic(fmt.Sprintf("Incrementing non-positive count %p on %s", r, r.RefType())) + } +} + +// TryIncRef implements refs.TryRefCounter.TryIncRef. +// +// To do this safely without a loop, a speculative reference is first acquired +// on the object. This allows multiple concurrent TryIncRef calls to distinguish +// other TryIncRef calls from genuine references held. +// +//go:nosplit +func (r *queueRefs) TryIncRef() bool { + const speculativeRef = 1 << 32 + if v := atomic.AddInt64(&r.refCount, speculativeRef); int32(v) == 0 { + + atomic.AddInt64(&r.refCount, -speculativeRef) + return false + } + + v := atomic.AddInt64(&r.refCount, -speculativeRef+1) + if queueenableLogging { + refsvfs2.LogTryIncRef(r, v) + } + return true +} + +// DecRef implements refs.RefCounter.DecRef. +// +// Note that speculative references are counted here. Since they were added +// prior to real references reaching zero, they will successfully convert to +// real references. In other words, we see speculative references only in the +// following case: +// +// A: TryIncRef [speculative increase => sees non-negative references] +// B: DecRef [real decrease] +// A: TryIncRef [transform speculative to real] +// +//go:nosplit +func (r *queueRefs) DecRef(destroy func()) { + v := atomic.AddInt64(&r.refCount, -1) + if queueenableLogging { + refsvfs2.LogDecRef(r, v) + } + switch { + case v < 0: + panic(fmt.Sprintf("Decrementing non-positive ref count %p, owned by %s", r, r.RefType())) + + case v == 0: + refsvfs2.Unregister(r) + + if destroy != nil { + destroy() + } + } +} + +func (r *queueRefs) afterLoad() { + if r.ReadRefs() > 0 { + refsvfs2.Register(r) + } +} diff --git a/pkg/sentry/socket/unix/transport/transport_message_list.go b/pkg/sentry/socket/unix/transport/transport_message_list.go new file mode 100644 index 000000000..945163cc1 --- /dev/null +++ b/pkg/sentry/socket/unix/transport/transport_message_list.go @@ -0,0 +1,221 @@ +package transport + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type messageElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (messageElementMapper) linkerFor(elem *message) *message { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type messageList struct { + head *message + tail *message +} + +// Reset resets list l to the empty state. +func (l *messageList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +// +//go:nosplit +func (l *messageList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +// +//go:nosplit +func (l *messageList) Front() *message { + return l.head +} + +// Back returns the last element of list l or nil. +// +//go:nosplit +func (l *messageList) Back() *message { + return l.tail +} + +// Len returns the number of elements in the list. +// +// NOTE: This is an O(n) operation. +// +//go:nosplit +func (l *messageList) Len() (count int) { + for e := l.Front(); e != nil; e = (messageElementMapper{}.linkerFor(e)).Next() { + count++ + } + return count +} + +// PushFront inserts the element e at the front of list l. +// +//go:nosplit +func (l *messageList) PushFront(e *message) { + linker := messageElementMapper{}.linkerFor(e) + linker.SetNext(l.head) + linker.SetPrev(nil) + if l.head != nil { + messageElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +// +//go:nosplit +func (l *messageList) PushBack(e *message) { + linker := messageElementMapper{}.linkerFor(e) + linker.SetNext(nil) + linker.SetPrev(l.tail) + if l.tail != nil { + messageElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +// +//go:nosplit +func (l *messageList) PushBackList(m *messageList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + messageElementMapper{}.linkerFor(l.tail).SetNext(m.head) + messageElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +// +//go:nosplit +func (l *messageList) InsertAfter(b, e *message) { + bLinker := messageElementMapper{}.linkerFor(b) + eLinker := messageElementMapper{}.linkerFor(e) + + a := bLinker.Next() + + eLinker.SetNext(a) + eLinker.SetPrev(b) + bLinker.SetNext(e) + + if a != nil { + messageElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +// +//go:nosplit +func (l *messageList) InsertBefore(a, e *message) { + aLinker := messageElementMapper{}.linkerFor(a) + eLinker := messageElementMapper{}.linkerFor(e) + + b := aLinker.Prev() + eLinker.SetNext(a) + eLinker.SetPrev(b) + aLinker.SetPrev(e) + + if b != nil { + messageElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +// +//go:nosplit +func (l *messageList) Remove(e *message) { + linker := messageElementMapper{}.linkerFor(e) + prev := linker.Prev() + next := linker.Next() + + if prev != nil { + messageElementMapper{}.linkerFor(prev).SetNext(next) + } else if l.head == e { + l.head = next + } + + if next != nil { + messageElementMapper{}.linkerFor(next).SetPrev(prev) + } else if l.tail == e { + l.tail = prev + } + + linker.SetNext(nil) + linker.SetPrev(nil) +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type messageEntry struct { + next *message + prev *message +} + +// Next returns the entry that follows e in the list. +// +//go:nosplit +func (e *messageEntry) Next() *message { + return e.next +} + +// Prev returns the entry that precedes e in the list. +// +//go:nosplit +func (e *messageEntry) Prev() *message { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +// +//go:nosplit +func (e *messageEntry) SetNext(elem *message) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +// +//go:nosplit +func (e *messageEntry) SetPrev(elem *message) { + e.prev = elem +} diff --git a/pkg/sentry/socket/unix/transport/transport_state_autogen.go b/pkg/sentry/socket/unix/transport/transport_state_autogen.go new file mode 100644 index 000000000..730a89b9c --- /dev/null +++ b/pkg/sentry/socket/unix/transport/transport_state_autogen.go @@ -0,0 +1,399 @@ +// automatically generated by stateify. + +package transport + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (e *connectionedEndpoint) StateTypeName() string { + return "pkg/sentry/socket/unix/transport.connectionedEndpoint" +} + +func (e *connectionedEndpoint) StateFields() []string { + return []string{ + "baseEndpoint", + "id", + "idGenerator", + "stype", + "acceptedChan", + } +} + +func (e *connectionedEndpoint) beforeSave() {} + +// +checklocksignore +func (e *connectionedEndpoint) StateSave(stateSinkObject state.Sink) { + e.beforeSave() + var acceptedChanValue []*connectionedEndpoint + acceptedChanValue = e.saveAcceptedChan() + stateSinkObject.SaveValue(4, acceptedChanValue) + stateSinkObject.Save(0, &e.baseEndpoint) + stateSinkObject.Save(1, &e.id) + stateSinkObject.Save(2, &e.idGenerator) + stateSinkObject.Save(3, &e.stype) +} + +// +checklocksignore +func (e *connectionedEndpoint) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &e.baseEndpoint) + stateSourceObject.Load(1, &e.id) + stateSourceObject.Load(2, &e.idGenerator) + stateSourceObject.Load(3, &e.stype) + stateSourceObject.LoadValue(4, new([]*connectionedEndpoint), func(y interface{}) { e.loadAcceptedChan(y.([]*connectionedEndpoint)) }) + stateSourceObject.AfterLoad(e.afterLoad) +} + +func (e *connectionlessEndpoint) StateTypeName() string { + return "pkg/sentry/socket/unix/transport.connectionlessEndpoint" +} + +func (e *connectionlessEndpoint) StateFields() []string { + return []string{ + "baseEndpoint", + } +} + +func (e *connectionlessEndpoint) beforeSave() {} + +// +checklocksignore +func (e *connectionlessEndpoint) StateSave(stateSinkObject state.Sink) { + e.beforeSave() + stateSinkObject.Save(0, &e.baseEndpoint) +} + +// +checklocksignore +func (e *connectionlessEndpoint) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &e.baseEndpoint) + stateSourceObject.AfterLoad(e.afterLoad) +} + +func (q *queue) StateTypeName() string { + return "pkg/sentry/socket/unix/transport.queue" +} + +func (q *queue) StateFields() []string { + return []string{ + "queueRefs", + "ReaderQueue", + "WriterQueue", + "closed", + "unread", + "used", + "limit", + "dataList", + } +} + +func (q *queue) beforeSave() {} + +// +checklocksignore +func (q *queue) StateSave(stateSinkObject state.Sink) { + q.beforeSave() + stateSinkObject.Save(0, &q.queueRefs) + stateSinkObject.Save(1, &q.ReaderQueue) + stateSinkObject.Save(2, &q.WriterQueue) + stateSinkObject.Save(3, &q.closed) + stateSinkObject.Save(4, &q.unread) + stateSinkObject.Save(5, &q.used) + stateSinkObject.Save(6, &q.limit) + stateSinkObject.Save(7, &q.dataList) +} + +func (q *queue) afterLoad() {} + +// +checklocksignore +func (q *queue) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &q.queueRefs) + stateSourceObject.Load(1, &q.ReaderQueue) + stateSourceObject.Load(2, &q.WriterQueue) + stateSourceObject.Load(3, &q.closed) + stateSourceObject.Load(4, &q.unread) + stateSourceObject.Load(5, &q.used) + stateSourceObject.Load(6, &q.limit) + stateSourceObject.Load(7, &q.dataList) +} + +func (r *queueRefs) StateTypeName() string { + return "pkg/sentry/socket/unix/transport.queueRefs" +} + +func (r *queueRefs) StateFields() []string { + return []string{ + "refCount", + } +} + +func (r *queueRefs) beforeSave() {} + +// +checklocksignore +func (r *queueRefs) StateSave(stateSinkObject state.Sink) { + r.beforeSave() + stateSinkObject.Save(0, &r.refCount) +} + +// +checklocksignore +func (r *queueRefs) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &r.refCount) + stateSourceObject.AfterLoad(r.afterLoad) +} + +func (l *messageList) StateTypeName() string { + return "pkg/sentry/socket/unix/transport.messageList" +} + +func (l *messageList) StateFields() []string { + return []string{ + "head", + "tail", + } +} + +func (l *messageList) beforeSave() {} + +// +checklocksignore +func (l *messageList) StateSave(stateSinkObject state.Sink) { + l.beforeSave() + stateSinkObject.Save(0, &l.head) + stateSinkObject.Save(1, &l.tail) +} + +func (l *messageList) afterLoad() {} + +// +checklocksignore +func (l *messageList) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &l.head) + stateSourceObject.Load(1, &l.tail) +} + +func (e *messageEntry) StateTypeName() string { + return "pkg/sentry/socket/unix/transport.messageEntry" +} + +func (e *messageEntry) StateFields() []string { + return []string{ + "next", + "prev", + } +} + +func (e *messageEntry) beforeSave() {} + +// +checklocksignore +func (e *messageEntry) StateSave(stateSinkObject state.Sink) { + e.beforeSave() + stateSinkObject.Save(0, &e.next) + stateSinkObject.Save(1, &e.prev) +} + +func (e *messageEntry) afterLoad() {} + +// +checklocksignore +func (e *messageEntry) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &e.next) + stateSourceObject.Load(1, &e.prev) +} + +func (c *ControlMessages) StateTypeName() string { + return "pkg/sentry/socket/unix/transport.ControlMessages" +} + +func (c *ControlMessages) StateFields() []string { + return []string{ + "Rights", + "Credentials", + } +} + +func (c *ControlMessages) beforeSave() {} + +// +checklocksignore +func (c *ControlMessages) StateSave(stateSinkObject state.Sink) { + c.beforeSave() + stateSinkObject.Save(0, &c.Rights) + stateSinkObject.Save(1, &c.Credentials) +} + +func (c *ControlMessages) afterLoad() {} + +// +checklocksignore +func (c *ControlMessages) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &c.Rights) + stateSourceObject.Load(1, &c.Credentials) +} + +func (m *message) StateTypeName() string { + return "pkg/sentry/socket/unix/transport.message" +} + +func (m *message) StateFields() []string { + return []string{ + "messageEntry", + "Data", + "Control", + "Address", + } +} + +func (m *message) beforeSave() {} + +// +checklocksignore +func (m *message) StateSave(stateSinkObject state.Sink) { + m.beforeSave() + stateSinkObject.Save(0, &m.messageEntry) + stateSinkObject.Save(1, &m.Data) + stateSinkObject.Save(2, &m.Control) + stateSinkObject.Save(3, &m.Address) +} + +func (m *message) afterLoad() {} + +// +checklocksignore +func (m *message) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &m.messageEntry) + stateSourceObject.Load(1, &m.Data) + stateSourceObject.Load(2, &m.Control) + stateSourceObject.Load(3, &m.Address) +} + +func (q *queueReceiver) StateTypeName() string { + return "pkg/sentry/socket/unix/transport.queueReceiver" +} + +func (q *queueReceiver) StateFields() []string { + return []string{ + "readQueue", + } +} + +func (q *queueReceiver) beforeSave() {} + +// +checklocksignore +func (q *queueReceiver) StateSave(stateSinkObject state.Sink) { + q.beforeSave() + stateSinkObject.Save(0, &q.readQueue) +} + +func (q *queueReceiver) afterLoad() {} + +// +checklocksignore +func (q *queueReceiver) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &q.readQueue) +} + +func (q *streamQueueReceiver) StateTypeName() string { + return "pkg/sentry/socket/unix/transport.streamQueueReceiver" +} + +func (q *streamQueueReceiver) StateFields() []string { + return []string{ + "queueReceiver", + "buffer", + "control", + "addr", + } +} + +func (q *streamQueueReceiver) beforeSave() {} + +// +checklocksignore +func (q *streamQueueReceiver) StateSave(stateSinkObject state.Sink) { + q.beforeSave() + stateSinkObject.Save(0, &q.queueReceiver) + stateSinkObject.Save(1, &q.buffer) + stateSinkObject.Save(2, &q.control) + stateSinkObject.Save(3, &q.addr) +} + +func (q *streamQueueReceiver) afterLoad() {} + +// +checklocksignore +func (q *streamQueueReceiver) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &q.queueReceiver) + stateSourceObject.Load(1, &q.buffer) + stateSourceObject.Load(2, &q.control) + stateSourceObject.Load(3, &q.addr) +} + +func (e *connectedEndpoint) StateTypeName() string { + return "pkg/sentry/socket/unix/transport.connectedEndpoint" +} + +func (e *connectedEndpoint) StateFields() []string { + return []string{ + "endpoint", + "writeQueue", + } +} + +func (e *connectedEndpoint) beforeSave() {} + +// +checklocksignore +func (e *connectedEndpoint) StateSave(stateSinkObject state.Sink) { + e.beforeSave() + stateSinkObject.Save(0, &e.endpoint) + stateSinkObject.Save(1, &e.writeQueue) +} + +func (e *connectedEndpoint) afterLoad() {} + +// +checklocksignore +func (e *connectedEndpoint) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &e.endpoint) + stateSourceObject.Load(1, &e.writeQueue) +} + +func (e *baseEndpoint) StateTypeName() string { + return "pkg/sentry/socket/unix/transport.baseEndpoint" +} + +func (e *baseEndpoint) StateFields() []string { + return []string{ + "Queue", + "DefaultSocketOptionsHandler", + "receiver", + "connected", + "path", + "ops", + } +} + +func (e *baseEndpoint) beforeSave() {} + +// +checklocksignore +func (e *baseEndpoint) StateSave(stateSinkObject state.Sink) { + e.beforeSave() + stateSinkObject.Save(0, &e.Queue) + stateSinkObject.Save(1, &e.DefaultSocketOptionsHandler) + stateSinkObject.Save(2, &e.receiver) + stateSinkObject.Save(3, &e.connected) + stateSinkObject.Save(4, &e.path) + stateSinkObject.Save(5, &e.ops) +} + +func (e *baseEndpoint) afterLoad() {} + +// +checklocksignore +func (e *baseEndpoint) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &e.Queue) + stateSourceObject.Load(1, &e.DefaultSocketOptionsHandler) + stateSourceObject.Load(2, &e.receiver) + stateSourceObject.Load(3, &e.connected) + stateSourceObject.Load(4, &e.path) + stateSourceObject.Load(5, &e.ops) +} + +func init() { + state.Register((*connectionedEndpoint)(nil)) + state.Register((*connectionlessEndpoint)(nil)) + state.Register((*queue)(nil)) + state.Register((*queueRefs)(nil)) + state.Register((*messageList)(nil)) + state.Register((*messageEntry)(nil)) + state.Register((*ControlMessages)(nil)) + state.Register((*message)(nil)) + state.Register((*queueReceiver)(nil)) + state.Register((*streamQueueReceiver)(nil)) + state.Register((*connectedEndpoint)(nil)) + state.Register((*baseEndpoint)(nil)) +} diff --git a/pkg/sentry/socket/unix/unix_state_autogen.go b/pkg/sentry/socket/unix/unix_state_autogen.go new file mode 100644 index 000000000..e6169dfad --- /dev/null +++ b/pkg/sentry/socket/unix/unix_state_autogen.go @@ -0,0 +1,168 @@ +// automatically generated by stateify. + +package unix + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (r *socketOperationsRefs) StateTypeName() string { + return "pkg/sentry/socket/unix.socketOperationsRefs" +} + +func (r *socketOperationsRefs) StateFields() []string { + return []string{ + "refCount", + } +} + +func (r *socketOperationsRefs) beforeSave() {} + +// +checklocksignore +func (r *socketOperationsRefs) StateSave(stateSinkObject state.Sink) { + r.beforeSave() + stateSinkObject.Save(0, &r.refCount) +} + +// +checklocksignore +func (r *socketOperationsRefs) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &r.refCount) + stateSourceObject.AfterLoad(r.afterLoad) +} + +func (r *socketVFS2Refs) StateTypeName() string { + return "pkg/sentry/socket/unix.socketVFS2Refs" +} + +func (r *socketVFS2Refs) StateFields() []string { + return []string{ + "refCount", + } +} + +func (r *socketVFS2Refs) beforeSave() {} + +// +checklocksignore +func (r *socketVFS2Refs) StateSave(stateSinkObject state.Sink) { + r.beforeSave() + stateSinkObject.Save(0, &r.refCount) +} + +// +checklocksignore +func (r *socketVFS2Refs) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &r.refCount) + stateSourceObject.AfterLoad(r.afterLoad) +} + +func (s *SocketOperations) StateTypeName() string { + return "pkg/sentry/socket/unix.SocketOperations" +} + +func (s *SocketOperations) StateFields() []string { + return []string{ + "socketOperationsRefs", + "socketOpsCommon", + } +} + +func (s *SocketOperations) beforeSave() {} + +// +checklocksignore +func (s *SocketOperations) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + stateSinkObject.Save(0, &s.socketOperationsRefs) + stateSinkObject.Save(1, &s.socketOpsCommon) +} + +func (s *SocketOperations) afterLoad() {} + +// +checklocksignore +func (s *SocketOperations) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.socketOperationsRefs) + stateSourceObject.Load(1, &s.socketOpsCommon) +} + +func (s *socketOpsCommon) StateTypeName() string { + return "pkg/sentry/socket/unix.socketOpsCommon" +} + +func (s *socketOpsCommon) StateFields() []string { + return []string{ + "SendReceiveTimeout", + "ep", + "stype", + "abstractName", + "abstractNamespace", + } +} + +func (s *socketOpsCommon) beforeSave() {} + +// +checklocksignore +func (s *socketOpsCommon) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + stateSinkObject.Save(0, &s.SendReceiveTimeout) + stateSinkObject.Save(1, &s.ep) + stateSinkObject.Save(2, &s.stype) + stateSinkObject.Save(3, &s.abstractName) + stateSinkObject.Save(4, &s.abstractNamespace) +} + +func (s *socketOpsCommon) afterLoad() {} + +// +checklocksignore +func (s *socketOpsCommon) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.SendReceiveTimeout) + stateSourceObject.Load(1, &s.ep) + stateSourceObject.Load(2, &s.stype) + stateSourceObject.Load(3, &s.abstractName) + stateSourceObject.Load(4, &s.abstractNamespace) +} + +func (s *SocketVFS2) StateTypeName() string { + return "pkg/sentry/socket/unix.SocketVFS2" +} + +func (s *SocketVFS2) StateFields() []string { + return []string{ + "vfsfd", + "FileDescriptionDefaultImpl", + "DentryMetadataFileDescriptionImpl", + "LockFD", + "socketVFS2Refs", + "socketOpsCommon", + } +} + +func (s *SocketVFS2) beforeSave() {} + +// +checklocksignore +func (s *SocketVFS2) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + stateSinkObject.Save(0, &s.vfsfd) + stateSinkObject.Save(1, &s.FileDescriptionDefaultImpl) + stateSinkObject.Save(2, &s.DentryMetadataFileDescriptionImpl) + stateSinkObject.Save(3, &s.LockFD) + stateSinkObject.Save(4, &s.socketVFS2Refs) + stateSinkObject.Save(5, &s.socketOpsCommon) +} + +func (s *SocketVFS2) afterLoad() {} + +// +checklocksignore +func (s *SocketVFS2) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.vfsfd) + stateSourceObject.Load(1, &s.FileDescriptionDefaultImpl) + stateSourceObject.Load(2, &s.DentryMetadataFileDescriptionImpl) + stateSourceObject.Load(3, &s.LockFD) + stateSourceObject.Load(4, &s.socketVFS2Refs) + stateSourceObject.Load(5, &s.socketOpsCommon) +} + +func init() { + state.Register((*socketOperationsRefs)(nil)) + state.Register((*socketVFS2Refs)(nil)) + state.Register((*SocketOperations)(nil)) + state.Register((*socketOpsCommon)(nil)) + state.Register((*SocketVFS2)(nil)) +} diff --git a/pkg/sentry/state/BUILD b/pkg/sentry/state/BUILD deleted file mode 100644 index 7f02807c5..000000000 --- a/pkg/sentry/state/BUILD +++ /dev/null @@ -1,26 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "state", - srcs = [ - "state.go", - "state_metadata.go", - "state_unsafe.go", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/errors/linuxerr", - "//pkg/log", - "//pkg/sentry/inet", - "//pkg/sentry/kernel", - "//pkg/sentry/time", - "//pkg/sentry/vfs", - "//pkg/sentry/watchdog", - "//pkg/state/statefile", - "@org_golang_x_sys//unix:go_default_library", - ], -) diff --git a/pkg/sentry/state/state_state_autogen.go b/pkg/sentry/state/state_state_autogen.go new file mode 100644 index 000000000..305f03f52 --- /dev/null +++ b/pkg/sentry/state/state_state_autogen.go @@ -0,0 +1,6 @@ +// automatically generated by stateify. + +//go:build go1.1 +// +build go1.1 + +package state diff --git a/pkg/sentry/state/state_unsafe_state_autogen.go b/pkg/sentry/state/state_unsafe_state_autogen.go new file mode 100644 index 000000000..6c2b29632 --- /dev/null +++ b/pkg/sentry/state/state_unsafe_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package state diff --git a/pkg/sentry/strace/BUILD b/pkg/sentry/strace/BUILD deleted file mode 100644 index 369541c7a..000000000 --- a/pkg/sentry/strace/BUILD +++ /dev/null @@ -1,46 +0,0 @@ -load("//tools:defs.bzl", "go_library", "proto_library") - -package(licenses = ["notice"]) - -go_library( - name = "strace", - srcs = [ - "capability.go", - "clone.go", - "epoll.go", - "futex.go", - "linux64_amd64.go", - "linux64_arm64.go", - "mmap.go", - "open.go", - "poll.go", - "ptrace.go", - "select.go", - "signal.go", - "socket.go", - "strace.go", - "syscalls.go", - ], - visibility = ["//:sandbox"], - deps = [ - ":strace_go_proto", - "//pkg/abi", - "//pkg/abi/linux", - "//pkg/bits", - "//pkg/eventchannel", - "//pkg/hostarch", - "//pkg/marshal/primitive", - "//pkg/seccomp", - "//pkg/sentry/arch", - "//pkg/sentry/kernel", - "//pkg/sentry/socket", - "//pkg/sentry/socket/netlink", - "//pkg/sentry/syscalls/linux", - ], -) - -proto_library( - name = "strace", - srcs = ["strace.proto"], - visibility = ["//visibility:public"], -) diff --git a/pkg/sentry/strace/strace.proto b/pkg/sentry/strace/strace.proto deleted file mode 100644 index 906c52c51..000000000 --- a/pkg/sentry/strace/strace.proto +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -syntax = "proto3"; - -package gvisor; - -message Strace { - // Process name that made the syscall. - string process = 1; - - // Syscall function name. - string function = 2; - - // List of syscall arguments formatted as strings. - repeated string args = 3; - - oneof info { - StraceEnter enter = 4; - StraceExit exit = 5; - } -} - -message StraceEnter {} - -message StraceExit { - // Return value formatted as string. - string return = 1; - - // Formatted error string in case syscall failed. - string error = 2; - - // Value of errno upon syscall exit. - int64 err_no = 3; // errno is a macro and gets expanded :-( - - // Time elapsed between syscall enter and exit. - int64 elapsed_ns = 4; -} diff --git a/pkg/sentry/strace/strace_amd64_state_autogen.go b/pkg/sentry/strace/strace_amd64_state_autogen.go new file mode 100644 index 000000000..2a3d1077f --- /dev/null +++ b/pkg/sentry/strace/strace_amd64_state_autogen.go @@ -0,0 +1,6 @@ +// automatically generated by stateify. + +//go:build amd64 +// +build amd64 + +package strace diff --git a/pkg/sentry/strace/strace_arm64_state_autogen.go b/pkg/sentry/strace/strace_arm64_state_autogen.go new file mode 100644 index 000000000..c64c9df57 --- /dev/null +++ b/pkg/sentry/strace/strace_arm64_state_autogen.go @@ -0,0 +1,6 @@ +// automatically generated by stateify. + +//go:build arm64 +// +build arm64 + +package strace diff --git a/pkg/sentry/strace/strace_go_proto/strace.pb.go b/pkg/sentry/strace/strace_go_proto/strace.pb.go new file mode 100644 index 000000000..f3d5c1108 --- /dev/null +++ b/pkg/sentry/strace/strace_go_proto/strace.pb.go @@ -0,0 +1,362 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.25.0 +// protoc v3.13.0 +// source: pkg/sentry/strace/strace.proto + +package gvisor + +import ( + proto "github.com/golang/protobuf/proto" + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +// This is a compile-time assertion that a sufficiently up-to-date version +// of the legacy proto package is being used. +const _ = proto.ProtoPackageIsVersion4 + +type Strace struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Process string `protobuf:"bytes,1,opt,name=process,proto3" json:"process,omitempty"` + Function string `protobuf:"bytes,2,opt,name=function,proto3" json:"function,omitempty"` + Args []string `protobuf:"bytes,3,rep,name=args,proto3" json:"args,omitempty"` + // Types that are assignable to Info: + // *Strace_Enter + // *Strace_Exit + Info isStrace_Info `protobuf_oneof:"info"` +} + +func (x *Strace) Reset() { + *x = Strace{} + if protoimpl.UnsafeEnabled { + mi := &file_pkg_sentry_strace_strace_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *Strace) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Strace) ProtoMessage() {} + +func (x *Strace) ProtoReflect() protoreflect.Message { + mi := &file_pkg_sentry_strace_strace_proto_msgTypes[0] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Strace.ProtoReflect.Descriptor instead. +func (*Strace) Descriptor() ([]byte, []int) { + return file_pkg_sentry_strace_strace_proto_rawDescGZIP(), []int{0} +} + +func (x *Strace) GetProcess() string { + if x != nil { + return x.Process + } + return "" +} + +func (x *Strace) GetFunction() string { + if x != nil { + return x.Function + } + return "" +} + +func (x *Strace) GetArgs() []string { + if x != nil { + return x.Args + } + return nil +} + +func (m *Strace) GetInfo() isStrace_Info { + if m != nil { + return m.Info + } + return nil +} + +func (x *Strace) GetEnter() *StraceEnter { + if x, ok := x.GetInfo().(*Strace_Enter); ok { + return x.Enter + } + return nil +} + +func (x *Strace) GetExit() *StraceExit { + if x, ok := x.GetInfo().(*Strace_Exit); ok { + return x.Exit + } + return nil +} + +type isStrace_Info interface { + isStrace_Info() +} + +type Strace_Enter struct { + Enter *StraceEnter `protobuf:"bytes,4,opt,name=enter,proto3,oneof"` +} + +type Strace_Exit struct { + Exit *StraceExit `protobuf:"bytes,5,opt,name=exit,proto3,oneof"` +} + +func (*Strace_Enter) isStrace_Info() {} + +func (*Strace_Exit) isStrace_Info() {} + +type StraceEnter struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields +} + +func (x *StraceEnter) Reset() { + *x = StraceEnter{} + if protoimpl.UnsafeEnabled { + mi := &file_pkg_sentry_strace_strace_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *StraceEnter) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*StraceEnter) ProtoMessage() {} + +func (x *StraceEnter) ProtoReflect() protoreflect.Message { + mi := &file_pkg_sentry_strace_strace_proto_msgTypes[1] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use StraceEnter.ProtoReflect.Descriptor instead. +func (*StraceEnter) Descriptor() ([]byte, []int) { + return file_pkg_sentry_strace_strace_proto_rawDescGZIP(), []int{1} +} + +type StraceExit struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Return string `protobuf:"bytes,1,opt,name=return,proto3" json:"return,omitempty"` + Error string `protobuf:"bytes,2,opt,name=error,proto3" json:"error,omitempty"` + ErrNo int64 `protobuf:"varint,3,opt,name=err_no,json=errNo,proto3" json:"err_no,omitempty"` + ElapsedNs int64 `protobuf:"varint,4,opt,name=elapsed_ns,json=elapsedNs,proto3" json:"elapsed_ns,omitempty"` +} + +func (x *StraceExit) Reset() { + *x = StraceExit{} + if protoimpl.UnsafeEnabled { + mi := &file_pkg_sentry_strace_strace_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *StraceExit) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*StraceExit) ProtoMessage() {} + +func (x *StraceExit) ProtoReflect() protoreflect.Message { + mi := &file_pkg_sentry_strace_strace_proto_msgTypes[2] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use StraceExit.ProtoReflect.Descriptor instead. +func (*StraceExit) Descriptor() ([]byte, []int) { + return file_pkg_sentry_strace_strace_proto_rawDescGZIP(), []int{2} +} + +func (x *StraceExit) GetReturn() string { + if x != nil { + return x.Return + } + return "" +} + +func (x *StraceExit) GetError() string { + if x != nil { + return x.Error + } + return "" +} + +func (x *StraceExit) GetErrNo() int64 { + if x != nil { + return x.ErrNo + } + return 0 +} + +func (x *StraceExit) GetElapsedNs() int64 { + if x != nil { + return x.ElapsedNs + } + return 0 +} + +var File_pkg_sentry_strace_strace_proto protoreflect.FileDescriptor + +var file_pkg_sentry_strace_strace_proto_rawDesc = []byte{ + 0x0a, 0x1e, 0x70, 0x6b, 0x67, 0x2f, 0x73, 0x65, 0x6e, 0x74, 0x72, 0x79, 0x2f, 0x73, 0x74, 0x72, + 0x61, 0x63, 0x65, 0x2f, 0x73, 0x74, 0x72, 0x61, 0x63, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, + 0x12, 0x06, 0x67, 0x76, 0x69, 0x73, 0x6f, 0x72, 0x22, 0xb1, 0x01, 0x0a, 0x06, 0x53, 0x74, 0x72, + 0x61, 0x63, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x70, 0x72, 0x6f, 0x63, 0x65, 0x73, 0x73, 0x18, 0x01, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x70, 0x72, 0x6f, 0x63, 0x65, 0x73, 0x73, 0x12, 0x1a, 0x0a, + 0x08, 0x66, 0x75, 0x6e, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x08, 0x66, 0x75, 0x6e, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x12, 0x0a, 0x04, 0x61, 0x72, 0x67, + 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x09, 0x52, 0x04, 0x61, 0x72, 0x67, 0x73, 0x12, 0x2b, 0x0a, + 0x05, 0x65, 0x6e, 0x74, 0x65, 0x72, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x13, 0x2e, 0x67, + 0x76, 0x69, 0x73, 0x6f, 0x72, 0x2e, 0x53, 0x74, 0x72, 0x61, 0x63, 0x65, 0x45, 0x6e, 0x74, 0x65, + 0x72, 0x48, 0x00, 0x52, 0x05, 0x65, 0x6e, 0x74, 0x65, 0x72, 0x12, 0x28, 0x0a, 0x04, 0x65, 0x78, + 0x69, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x12, 0x2e, 0x67, 0x76, 0x69, 0x73, 0x6f, + 0x72, 0x2e, 0x53, 0x74, 0x72, 0x61, 0x63, 0x65, 0x45, 0x78, 0x69, 0x74, 0x48, 0x00, 0x52, 0x04, + 0x65, 0x78, 0x69, 0x74, 0x42, 0x06, 0x0a, 0x04, 0x69, 0x6e, 0x66, 0x6f, 0x22, 0x0d, 0x0a, 0x0b, + 0x53, 0x74, 0x72, 0x61, 0x63, 0x65, 0x45, 0x6e, 0x74, 0x65, 0x72, 0x22, 0x70, 0x0a, 0x0a, 0x53, + 0x74, 0x72, 0x61, 0x63, 0x65, 0x45, 0x78, 0x69, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x72, 0x65, 0x74, + 0x75, 0x72, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x72, 0x65, 0x74, 0x75, 0x72, + 0x6e, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x12, 0x15, 0x0a, 0x06, 0x65, 0x72, 0x72, 0x5f, 0x6e, + 0x6f, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x05, 0x65, 0x72, 0x72, 0x4e, 0x6f, 0x12, 0x1d, + 0x0a, 0x0a, 0x65, 0x6c, 0x61, 0x70, 0x73, 0x65, 0x64, 0x5f, 0x6e, 0x73, 0x18, 0x04, 0x20, 0x01, + 0x28, 0x03, 0x52, 0x09, 0x65, 0x6c, 0x61, 0x70, 0x73, 0x65, 0x64, 0x4e, 0x73, 0x62, 0x06, 0x70, + 0x72, 0x6f, 0x74, 0x6f, 0x33, +} + +var ( + file_pkg_sentry_strace_strace_proto_rawDescOnce sync.Once + file_pkg_sentry_strace_strace_proto_rawDescData = file_pkg_sentry_strace_strace_proto_rawDesc +) + +func file_pkg_sentry_strace_strace_proto_rawDescGZIP() []byte { + file_pkg_sentry_strace_strace_proto_rawDescOnce.Do(func() { + file_pkg_sentry_strace_strace_proto_rawDescData = protoimpl.X.CompressGZIP(file_pkg_sentry_strace_strace_proto_rawDescData) + }) + return file_pkg_sentry_strace_strace_proto_rawDescData +} + +var file_pkg_sentry_strace_strace_proto_msgTypes = make([]protoimpl.MessageInfo, 3) +var file_pkg_sentry_strace_strace_proto_goTypes = []interface{}{ + (*Strace)(nil), // 0: gvisor.Strace + (*StraceEnter)(nil), // 1: gvisor.StraceEnter + (*StraceExit)(nil), // 2: gvisor.StraceExit +} +var file_pkg_sentry_strace_strace_proto_depIdxs = []int32{ + 1, // 0: gvisor.Strace.enter:type_name -> gvisor.StraceEnter + 2, // 1: gvisor.Strace.exit:type_name -> gvisor.StraceExit + 2, // [2:2] is the sub-list for method output_type + 2, // [2:2] is the sub-list for method input_type + 2, // [2:2] is the sub-list for extension type_name + 2, // [2:2] is the sub-list for extension extendee + 0, // [0:2] is the sub-list for field type_name +} + +func init() { file_pkg_sentry_strace_strace_proto_init() } +func file_pkg_sentry_strace_strace_proto_init() { + if File_pkg_sentry_strace_strace_proto != nil { + return + } + if !protoimpl.UnsafeEnabled { + file_pkg_sentry_strace_strace_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*Strace); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_pkg_sentry_strace_strace_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*StraceEnter); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_pkg_sentry_strace_strace_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*StraceExit); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + } + file_pkg_sentry_strace_strace_proto_msgTypes[0].OneofWrappers = []interface{}{ + (*Strace_Enter)(nil), + (*Strace_Exit)(nil), + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: file_pkg_sentry_strace_strace_proto_rawDesc, + NumEnums: 0, + NumMessages: 3, + NumExtensions: 0, + NumServices: 0, + }, + GoTypes: file_pkg_sentry_strace_strace_proto_goTypes, + DependencyIndexes: file_pkg_sentry_strace_strace_proto_depIdxs, + MessageInfos: file_pkg_sentry_strace_strace_proto_msgTypes, + }.Build() + File_pkg_sentry_strace_strace_proto = out.File + file_pkg_sentry_strace_strace_proto_rawDesc = nil + file_pkg_sentry_strace_strace_proto_goTypes = nil + file_pkg_sentry_strace_strace_proto_depIdxs = nil +} diff --git a/pkg/sentry/strace/strace_state_autogen.go b/pkg/sentry/strace/strace_state_autogen.go new file mode 100644 index 000000000..33f6a7a54 --- /dev/null +++ b/pkg/sentry/strace/strace_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package strace diff --git a/pkg/sentry/syscalls/BUILD b/pkg/sentry/syscalls/BUILD deleted file mode 100644 index 7a7c80ac6..000000000 --- a/pkg/sentry/syscalls/BUILD +++ /dev/null @@ -1,21 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "syscalls", - srcs = [ - "epoll.go", - "syscalls.go", - ], - visibility = ["//:sandbox"], - deps = [ - "//pkg/abi/linux", - "//pkg/errors/linuxerr", - "//pkg/sentry/arch", - "//pkg/sentry/kernel", - "//pkg/sentry/kernel/epoll", - "//pkg/sentry/kernel/time", - "//pkg/waiter", - ], -) diff --git a/pkg/sentry/syscalls/linux/BUILD b/pkg/sentry/syscalls/linux/BUILD deleted file mode 100644 index 394396cde..000000000 --- a/pkg/sentry/syscalls/linux/BUILD +++ /dev/null @@ -1,111 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "linux", - srcs = [ - "error.go", - "flags.go", - "linux64.go", - "sigset.go", - "sys_aio.go", - "sys_capability.go", - "sys_clone_amd64.go", - "sys_clone_arm64.go", - "sys_epoll.go", - "sys_eventfd.go", - "sys_file.go", - "sys_futex.go", - "sys_getdents.go", - "sys_identity.go", - "sys_inotify.go", - "sys_lseek.go", - "sys_membarrier.go", - "sys_mempolicy.go", - "sys_mmap.go", - "sys_mount.go", - "sys_msgqueue.go", - "sys_pipe.go", - "sys_poll.go", - "sys_prctl.go", - "sys_random.go", - "sys_read.go", - "sys_rlimit.go", - "sys_rseq.go", - "sys_rusage.go", - "sys_sched.go", - "sys_seccomp.go", - "sys_sem.go", - "sys_shm.go", - "sys_signal.go", - "sys_socket.go", - "sys_splice.go", - "sys_stat.go", - "sys_stat_amd64.go", - "sys_stat_arm64.go", - "sys_sync.go", - "sys_sysinfo.go", - "sys_syslog.go", - "sys_thread.go", - "sys_time.go", - "sys_timer.go", - "sys_timerfd.go", - "sys_tls_amd64.go", - "sys_tls_arm64.go", - "sys_utsname.go", - "sys_write.go", - "sys_xattr.go", - "timespec.go", - ], - marshal = True, - visibility = ["//:sandbox"], - deps = [ - "//pkg/abi", - "//pkg/abi/linux", - "//pkg/bpf", - "//pkg/context", - "//pkg/errors/linuxerr", - "//pkg/hostarch", - "//pkg/log", - "//pkg/marshal", - "//pkg/marshal/primitive", - "//pkg/metric", - "//pkg/rand", - "//pkg/safemem", - "//pkg/sentry/arch", - "//pkg/sentry/fs", - "//pkg/sentry/fs/anon", - "//pkg/sentry/fs/lock", - "//pkg/sentry/fs/timerfd", - "//pkg/sentry/fs/tmpfs", - "//pkg/sentry/fsbridge", - "//pkg/sentry/kernel", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/kernel/epoll", - "//pkg/sentry/kernel/eventfd", - "//pkg/sentry/kernel/fasync", - "//pkg/sentry/kernel/ipc", - "//pkg/sentry/kernel/msgqueue", - "//pkg/sentry/kernel/pipe", - "//pkg/sentry/kernel/sched", - "//pkg/sentry/kernel/shm", - "//pkg/sentry/kernel/signalfd", - "//pkg/sentry/kernel/time", - "//pkg/sentry/limits", - "//pkg/sentry/loader", - "//pkg/sentry/memmap", - "//pkg/sentry/mm", - "//pkg/sentry/socket", - "//pkg/sentry/socket/control", - "//pkg/sentry/socket/unix/transport", - "//pkg/sentry/syscalls", - "//pkg/sentry/usage", - "//pkg/sentry/vfs", - "//pkg/sync", - "//pkg/syserr", - "//pkg/usermem", - "//pkg/waiter", - "@org_golang_x_sys//unix:go_default_library", - ], -) diff --git a/pkg/sentry/syscalls/linux/linux_abi_autogen_unsafe.go b/pkg/sentry/syscalls/linux/linux_abi_autogen_unsafe.go new file mode 100644 index 000000000..409927919 --- /dev/null +++ b/pkg/sentry/syscalls/linux/linux_abi_autogen_unsafe.go @@ -0,0 +1,712 @@ +// Automatically generated marshal implementation. See tools/go_marshal. + +package linux + +import ( + "gvisor.dev/gvisor/pkg/gohacks" + "gvisor.dev/gvisor/pkg/hostarch" + "gvisor.dev/gvisor/pkg/marshal" + "io" + "reflect" + "runtime" + "unsafe" +) + +// Marshallable types used by this file. +var _ marshal.Marshallable = (*MessageHeader64)(nil) +var _ marshal.Marshallable = (*SchedParam)(nil) +var _ marshal.Marshallable = (*direntHdr)(nil) +var _ marshal.Marshallable = (*multipleMessageHeader64)(nil) +var _ marshal.Marshallable = (*oldDirentHdr)(nil) +var _ marshal.Marshallable = (*rlimit64)(nil) +var _ marshal.Marshallable = (*userSockFprog)(nil) + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (d *direntHdr) SizeBytes() int { + return 1 + + (*oldDirentHdr)(nil).SizeBytes() +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (d *direntHdr) MarshalBytes(dst []byte) { + d.OldHdr.MarshalBytes(dst[:d.OldHdr.SizeBytes()]) + dst = dst[d.OldHdr.SizeBytes():] + dst[0] = byte(d.Typ) + dst = dst[1:] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (d *direntHdr) UnmarshalBytes(src []byte) { + d.OldHdr.UnmarshalBytes(src[:d.OldHdr.SizeBytes()]) + src = src[d.OldHdr.SizeBytes():] + d.Typ = uint8(src[0]) + src = src[1:] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (d *direntHdr) Packed() bool { + return false +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (d *direntHdr) MarshalUnsafe(dst []byte) { + // Type direntHdr doesn't have a packed layout in memory, fallback to MarshalBytes. + d.MarshalBytes(dst) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (d *direntHdr) UnmarshalUnsafe(src []byte) { + // Type direntHdr doesn't have a packed layout in memory, fallback to UnmarshalBytes. + d.UnmarshalBytes(src) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (d *direntHdr) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Type direntHdr doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := cc.CopyScratchBuffer(d.SizeBytes()) // escapes: okay. + d.MarshalBytes(buf) // escapes: fallback. + return cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (d *direntHdr) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return d.CopyOutN(cc, addr, d.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (d *direntHdr) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Type direntHdr doesn't have a packed layout in memory, fall back to UnmarshalBytes. + buf := cc.CopyScratchBuffer(d.SizeBytes()) // escapes: okay. + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Unmarshal unconditionally. If we had a short copy-in, this results in a + // partially unmarshalled struct. + d.UnmarshalBytes(buf) // escapes: fallback. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (d *direntHdr) WriteTo(writer io.Writer) (int64, error) { + // Type direntHdr doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := make([]byte, d.SizeBytes()) + d.MarshalBytes(buf) + length, err := writer.Write(buf) + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (o *oldDirentHdr) SizeBytes() int { + return 18 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (o *oldDirentHdr) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint64(dst[:8], uint64(o.Ino)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(o.Off)) + dst = dst[8:] + hostarch.ByteOrder.PutUint16(dst[:2], uint16(o.Reclen)) + dst = dst[2:] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (o *oldDirentHdr) UnmarshalBytes(src []byte) { + o.Ino = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + o.Off = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + o.Reclen = uint16(hostarch.ByteOrder.Uint16(src[:2])) + src = src[2:] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (o *oldDirentHdr) Packed() bool { + return false +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (o *oldDirentHdr) MarshalUnsafe(dst []byte) { + // Type oldDirentHdr doesn't have a packed layout in memory, fallback to MarshalBytes. + o.MarshalBytes(dst) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (o *oldDirentHdr) UnmarshalUnsafe(src []byte) { + // Type oldDirentHdr doesn't have a packed layout in memory, fallback to UnmarshalBytes. + o.UnmarshalBytes(src) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (o *oldDirentHdr) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Type oldDirentHdr doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := cc.CopyScratchBuffer(o.SizeBytes()) // escapes: okay. + o.MarshalBytes(buf) // escapes: fallback. + return cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (o *oldDirentHdr) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return o.CopyOutN(cc, addr, o.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (o *oldDirentHdr) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Type oldDirentHdr doesn't have a packed layout in memory, fall back to UnmarshalBytes. + buf := cc.CopyScratchBuffer(o.SizeBytes()) // escapes: okay. + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Unmarshal unconditionally. If we had a short copy-in, this results in a + // partially unmarshalled struct. + o.UnmarshalBytes(buf) // escapes: fallback. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (o *oldDirentHdr) WriteTo(writer io.Writer) (int64, error) { + // Type oldDirentHdr doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := make([]byte, o.SizeBytes()) + o.MarshalBytes(buf) + length, err := writer.Write(buf) + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (r *rlimit64) SizeBytes() int { + return 16 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (r *rlimit64) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint64(dst[:8], uint64(r.Cur)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(r.Max)) + dst = dst[8:] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (r *rlimit64) UnmarshalBytes(src []byte) { + r.Cur = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + r.Max = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (r *rlimit64) Packed() bool { + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (r *rlimit64) MarshalUnsafe(dst []byte) { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(r), uintptr(r.SizeBytes())) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (r *rlimit64) UnmarshalUnsafe(src []byte) { + gohacks.Memmove(unsafe.Pointer(r), unsafe.Pointer(&src[0]), uintptr(r.SizeBytes())) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (r *rlimit64) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(r))) + hdr.Len = r.SizeBytes() + hdr.Cap = r.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that r + // must live until the use above. + runtime.KeepAlive(r) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (r *rlimit64) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return r.CopyOutN(cc, addr, r.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (r *rlimit64) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(r))) + hdr.Len = r.SizeBytes() + hdr.Cap = r.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that r + // must live until the use above. + runtime.KeepAlive(r) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (r *rlimit64) WriteTo(writer io.Writer) (int64, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(r))) + hdr.Len = r.SizeBytes() + hdr.Cap = r.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that r + // must live until the use above. + runtime.KeepAlive(r) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (s *SchedParam) SizeBytes() int { + return 4 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (s *SchedParam) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint32(dst[:4], uint32(s.schedPriority)) + dst = dst[4:] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (s *SchedParam) UnmarshalBytes(src []byte) { + s.schedPriority = int32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (s *SchedParam) Packed() bool { + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (s *SchedParam) MarshalUnsafe(dst []byte) { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(s), uintptr(s.SizeBytes())) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (s *SchedParam) UnmarshalUnsafe(src []byte) { + gohacks.Memmove(unsafe.Pointer(s), unsafe.Pointer(&src[0]), uintptr(s.SizeBytes())) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (s *SchedParam) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(s))) + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until the use above. + runtime.KeepAlive(s) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (s *SchedParam) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return s.CopyOutN(cc, addr, s.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (s *SchedParam) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(s))) + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until the use above. + runtime.KeepAlive(s) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (s *SchedParam) WriteTo(writer io.Writer) (int64, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(s))) + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until the use above. + runtime.KeepAlive(s) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (u *userSockFprog) SizeBytes() int { + return 10 + + 1*6 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (u *userSockFprog) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint16(dst[:2], uint16(u.Len)) + dst = dst[2:] + // Padding: dst[:sizeof(byte)*6] ~= [6]byte{0} + dst = dst[1*(6):] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(u.Filter)) + dst = dst[8:] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (u *userSockFprog) UnmarshalBytes(src []byte) { + u.Len = uint16(hostarch.ByteOrder.Uint16(src[:2])) + src = src[2:] + // Padding: ~ copy([6]byte(u._), src[:sizeof(byte)*6]) + src = src[1*(6):] + u.Filter = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (u *userSockFprog) Packed() bool { + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (u *userSockFprog) MarshalUnsafe(dst []byte) { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(u), uintptr(u.SizeBytes())) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (u *userSockFprog) UnmarshalUnsafe(src []byte) { + gohacks.Memmove(unsafe.Pointer(u), unsafe.Pointer(&src[0]), uintptr(u.SizeBytes())) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (u *userSockFprog) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(u))) + hdr.Len = u.SizeBytes() + hdr.Cap = u.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that u + // must live until the use above. + runtime.KeepAlive(u) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (u *userSockFprog) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return u.CopyOutN(cc, addr, u.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (u *userSockFprog) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(u))) + hdr.Len = u.SizeBytes() + hdr.Cap = u.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that u + // must live until the use above. + runtime.KeepAlive(u) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (u *userSockFprog) WriteTo(writer io.Writer) (int64, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(u))) + hdr.Len = u.SizeBytes() + hdr.Cap = u.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that u + // must live until the use above. + runtime.KeepAlive(u) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (m *MessageHeader64) SizeBytes() int { + return 56 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (m *MessageHeader64) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint64(dst[:8], uint64(m.Name)) + dst = dst[8:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(m.NameLen)) + dst = dst[4:] + // Padding: dst[:sizeof(uint32)] ~= uint32(0) + dst = dst[4:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(m.Iov)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(m.IovLen)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(m.Control)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(m.ControlLen)) + dst = dst[8:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(m.Flags)) + dst = dst[4:] + // Padding: dst[:sizeof(int32)] ~= int32(0) + dst = dst[4:] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (m *MessageHeader64) UnmarshalBytes(src []byte) { + m.Name = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + m.NameLen = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + // Padding: var _ uint32 ~= src[:sizeof(uint32)] + src = src[4:] + m.Iov = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + m.IovLen = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + m.Control = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + m.ControlLen = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + m.Flags = int32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + // Padding: var _ int32 ~= src[:sizeof(int32)] + src = src[4:] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (m *MessageHeader64) Packed() bool { + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (m *MessageHeader64) MarshalUnsafe(dst []byte) { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(m), uintptr(m.SizeBytes())) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (m *MessageHeader64) UnmarshalUnsafe(src []byte) { + gohacks.Memmove(unsafe.Pointer(m), unsafe.Pointer(&src[0]), uintptr(m.SizeBytes())) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (m *MessageHeader64) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(m))) + hdr.Len = m.SizeBytes() + hdr.Cap = m.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that m + // must live until the use above. + runtime.KeepAlive(m) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (m *MessageHeader64) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return m.CopyOutN(cc, addr, m.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (m *MessageHeader64) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(m))) + hdr.Len = m.SizeBytes() + hdr.Cap = m.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that m + // must live until the use above. + runtime.KeepAlive(m) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (m *MessageHeader64) WriteTo(writer io.Writer) (int64, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(m))) + hdr.Len = m.SizeBytes() + hdr.Cap = m.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that m + // must live until the use above. + runtime.KeepAlive(m) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (m *multipleMessageHeader64) SizeBytes() int { + return 8 + + (*MessageHeader64)(nil).SizeBytes() +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (m *multipleMessageHeader64) MarshalBytes(dst []byte) { + m.msgHdr.MarshalBytes(dst[:m.msgHdr.SizeBytes()]) + dst = dst[m.msgHdr.SizeBytes():] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(m.msgLen)) + dst = dst[4:] + // Padding: dst[:sizeof(int32)] ~= int32(0) + dst = dst[4:] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (m *multipleMessageHeader64) UnmarshalBytes(src []byte) { + m.msgHdr.UnmarshalBytes(src[:m.msgHdr.SizeBytes()]) + src = src[m.msgHdr.SizeBytes():] + m.msgLen = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + // Padding: var _ int32 ~= src[:sizeof(int32)] + src = src[4:] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (m *multipleMessageHeader64) Packed() bool { + return m.msgHdr.Packed() +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (m *multipleMessageHeader64) MarshalUnsafe(dst []byte) { + if m.msgHdr.Packed() { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(m), uintptr(m.SizeBytes())) + } else { + // Type multipleMessageHeader64 doesn't have a packed layout in memory, fallback to MarshalBytes. + m.MarshalBytes(dst) + } +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (m *multipleMessageHeader64) UnmarshalUnsafe(src []byte) { + if m.msgHdr.Packed() { + gohacks.Memmove(unsafe.Pointer(m), unsafe.Pointer(&src[0]), uintptr(m.SizeBytes())) + } else { + // Type multipleMessageHeader64 doesn't have a packed layout in memory, fallback to UnmarshalBytes. + m.UnmarshalBytes(src) + } +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (m *multipleMessageHeader64) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + if !m.msgHdr.Packed() { + // Type multipleMessageHeader64 doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := cc.CopyScratchBuffer(m.SizeBytes()) // escapes: okay. + m.MarshalBytes(buf) // escapes: fallback. + return cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(m))) + hdr.Len = m.SizeBytes() + hdr.Cap = m.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that m + // must live until the use above. + runtime.KeepAlive(m) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (m *multipleMessageHeader64) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return m.CopyOutN(cc, addr, m.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (m *multipleMessageHeader64) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + if !m.msgHdr.Packed() { + // Type multipleMessageHeader64 doesn't have a packed layout in memory, fall back to UnmarshalBytes. + buf := cc.CopyScratchBuffer(m.SizeBytes()) // escapes: okay. + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Unmarshal unconditionally. If we had a short copy-in, this results in a + // partially unmarshalled struct. + m.UnmarshalBytes(buf) // escapes: fallback. + return length, err + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(m))) + hdr.Len = m.SizeBytes() + hdr.Cap = m.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that m + // must live until the use above. + runtime.KeepAlive(m) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (m *multipleMessageHeader64) WriteTo(writer io.Writer) (int64, error) { + if !m.msgHdr.Packed() { + // Type multipleMessageHeader64 doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := make([]byte, m.SizeBytes()) + m.MarshalBytes(buf) + length, err := writer.Write(buf) + return int64(length), err + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(m))) + hdr.Len = m.SizeBytes() + hdr.Cap = m.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that m + // must live until the use above. + runtime.KeepAlive(m) // escapes: replaced by intrinsic. + return int64(length), err +} + diff --git a/pkg/sentry/syscalls/linux/linux_amd64_abi_autogen_unsafe.go b/pkg/sentry/syscalls/linux/linux_amd64_abi_autogen_unsafe.go new file mode 100644 index 000000000..84d67731a --- /dev/null +++ b/pkg/sentry/syscalls/linux/linux_amd64_abi_autogen_unsafe.go @@ -0,0 +1,16 @@ +// Automatically generated marshal implementation. See tools/go_marshal. + +// If there are issues with build constraint aggregation, see +// tools/go_marshal/gomarshal/generator.go:writeHeader(). The constraints here +// come from the input set of files used to generate this file. This input set +// is filtered based on pre-defined file suffixes related to build constraints, +// see tools/defs.bzl:calculate_sets(). + +//go:build amd64 && amd64 && amd64 +// +build amd64,amd64,amd64 + +package linux + +import ( +) + diff --git a/pkg/sentry/syscalls/linux/linux_amd64_state_autogen.go b/pkg/sentry/syscalls/linux/linux_amd64_state_autogen.go new file mode 100644 index 000000000..19f7a855e --- /dev/null +++ b/pkg/sentry/syscalls/linux/linux_amd64_state_autogen.go @@ -0,0 +1,6 @@ +// automatically generated by stateify. + +//go:build amd64 && amd64 && amd64 +// +build amd64,amd64,amd64 + +package linux diff --git a/pkg/sentry/syscalls/linux/linux_arm64_abi_autogen_unsafe.go b/pkg/sentry/syscalls/linux/linux_arm64_abi_autogen_unsafe.go new file mode 100644 index 000000000..6bbdf8d79 --- /dev/null +++ b/pkg/sentry/syscalls/linux/linux_arm64_abi_autogen_unsafe.go @@ -0,0 +1,16 @@ +// Automatically generated marshal implementation. See tools/go_marshal. + +// If there are issues with build constraint aggregation, see +// tools/go_marshal/gomarshal/generator.go:writeHeader(). The constraints here +// come from the input set of files used to generate this file. This input set +// is filtered based on pre-defined file suffixes related to build constraints, +// see tools/defs.bzl:calculate_sets(). + +//go:build arm64 && arm64 && arm64 +// +build arm64,arm64,arm64 + +package linux + +import ( +) + diff --git a/pkg/sentry/syscalls/linux/linux_arm64_state_autogen.go b/pkg/sentry/syscalls/linux/linux_arm64_state_autogen.go new file mode 100644 index 000000000..ecc524560 --- /dev/null +++ b/pkg/sentry/syscalls/linux/linux_arm64_state_autogen.go @@ -0,0 +1,6 @@ +// automatically generated by stateify. + +//go:build arm64 && arm64 && arm64 +// +build arm64,arm64,arm64 + +package linux diff --git a/pkg/sentry/syscalls/linux/linux_state_autogen.go b/pkg/sentry/syscalls/linux/linux_state_autogen.go new file mode 100644 index 000000000..cc577e4ac --- /dev/null +++ b/pkg/sentry/syscalls/linux/linux_state_autogen.go @@ -0,0 +1,112 @@ +// automatically generated by stateify. + +package linux + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (f *futexWaitRestartBlock) StateTypeName() string { + return "pkg/sentry/syscalls/linux.futexWaitRestartBlock" +} + +func (f *futexWaitRestartBlock) StateFields() []string { + return []string{ + "duration", + "addr", + "private", + "val", + "mask", + } +} + +func (f *futexWaitRestartBlock) beforeSave() {} + +// +checklocksignore +func (f *futexWaitRestartBlock) StateSave(stateSinkObject state.Sink) { + f.beforeSave() + stateSinkObject.Save(0, &f.duration) + stateSinkObject.Save(1, &f.addr) + stateSinkObject.Save(2, &f.private) + stateSinkObject.Save(3, &f.val) + stateSinkObject.Save(4, &f.mask) +} + +func (f *futexWaitRestartBlock) afterLoad() {} + +// +checklocksignore +func (f *futexWaitRestartBlock) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &f.duration) + stateSourceObject.Load(1, &f.addr) + stateSourceObject.Load(2, &f.private) + stateSourceObject.Load(3, &f.val) + stateSourceObject.Load(4, &f.mask) +} + +func (p *pollRestartBlock) StateTypeName() string { + return "pkg/sentry/syscalls/linux.pollRestartBlock" +} + +func (p *pollRestartBlock) StateFields() []string { + return []string{ + "pfdAddr", + "nfds", + "timeout", + } +} + +func (p *pollRestartBlock) beforeSave() {} + +// +checklocksignore +func (p *pollRestartBlock) StateSave(stateSinkObject state.Sink) { + p.beforeSave() + stateSinkObject.Save(0, &p.pfdAddr) + stateSinkObject.Save(1, &p.nfds) + stateSinkObject.Save(2, &p.timeout) +} + +func (p *pollRestartBlock) afterLoad() {} + +// +checklocksignore +func (p *pollRestartBlock) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &p.pfdAddr) + stateSourceObject.Load(1, &p.nfds) + stateSourceObject.Load(2, &p.timeout) +} + +func (n *clockNanosleepRestartBlock) StateTypeName() string { + return "pkg/sentry/syscalls/linux.clockNanosleepRestartBlock" +} + +func (n *clockNanosleepRestartBlock) StateFields() []string { + return []string{ + "c", + "end", + "rem", + } +} + +func (n *clockNanosleepRestartBlock) beforeSave() {} + +// +checklocksignore +func (n *clockNanosleepRestartBlock) StateSave(stateSinkObject state.Sink) { + n.beforeSave() + stateSinkObject.Save(0, &n.c) + stateSinkObject.Save(1, &n.end) + stateSinkObject.Save(2, &n.rem) +} + +func (n *clockNanosleepRestartBlock) afterLoad() {} + +// +checklocksignore +func (n *clockNanosleepRestartBlock) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &n.c) + stateSourceObject.Load(1, &n.end) + stateSourceObject.Load(2, &n.rem) +} + +func init() { + state.Register((*futexWaitRestartBlock)(nil)) + state.Register((*pollRestartBlock)(nil)) + state.Register((*clockNanosleepRestartBlock)(nil)) +} diff --git a/pkg/sentry/syscalls/linux/vfs2/BUILD b/pkg/sentry/syscalls/linux/vfs2/BUILD deleted file mode 100644 index 1e3bd2a50..000000000 --- a/pkg/sentry/syscalls/linux/vfs2/BUILD +++ /dev/null @@ -1,79 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "vfs2", - srcs = [ - "aio.go", - "epoll.go", - "eventfd.go", - "execve.go", - "fd.go", - "filesystem.go", - "fscontext.go", - "getdents.go", - "inotify.go", - "ioctl.go", - "lock.go", - "memfd.go", - "mmap.go", - "mount.go", - "path.go", - "pipe.go", - "poll.go", - "read_write.go", - "setstat.go", - "signal.go", - "socket.go", - "splice.go", - "stat.go", - "stat_amd64.go", - "stat_arm64.go", - "sync.go", - "timerfd.go", - "vfs2.go", - "xattr.go", - ], - marshal = True, - visibility = ["//:sandbox"], - deps = [ - "//pkg/abi/linux", - "//pkg/bits", - "//pkg/context", - "//pkg/errors/linuxerr", - "//pkg/fspath", - "//pkg/gohacks", - "//pkg/hostarch", - "//pkg/log", - "//pkg/marshal", - "//pkg/marshal/primitive", - "//pkg/sentry/arch", - "//pkg/sentry/fs/lock", - "//pkg/sentry/fsbridge", - "//pkg/sentry/fsimpl/eventfd", - "//pkg/sentry/fsimpl/pipefs", - "//pkg/sentry/fsimpl/signalfd", - "//pkg/sentry/fsimpl/timerfd", - "//pkg/sentry/fsimpl/tmpfs", - "//pkg/sentry/kernel", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/kernel/fasync", - "//pkg/sentry/kernel/pipe", - "//pkg/sentry/kernel/time", - "//pkg/sentry/limits", - "//pkg/sentry/loader", - "//pkg/sentry/memmap", - "//pkg/sentry/mm", - "//pkg/sentry/socket", - "//pkg/sentry/socket/control", - "//pkg/sentry/socket/unix/transport", - "//pkg/sentry/syscalls", - "//pkg/sentry/syscalls/linux", - "//pkg/sentry/vfs", - "//pkg/sync", - "//pkg/syserr", - "//pkg/usermem", - "//pkg/waiter", - ], -) diff --git a/pkg/sentry/syscalls/linux/vfs2/vfs2_abi_autogen_unsafe.go b/pkg/sentry/syscalls/linux/vfs2/vfs2_abi_autogen_unsafe.go new file mode 100644 index 000000000..01b0d465f --- /dev/null +++ b/pkg/sentry/syscalls/linux/vfs2/vfs2_abi_autogen_unsafe.go @@ -0,0 +1,366 @@ +// Automatically generated marshal implementation. See tools/go_marshal. + +package vfs2 + +import ( + "gvisor.dev/gvisor/pkg/gohacks" + "gvisor.dev/gvisor/pkg/hostarch" + "gvisor.dev/gvisor/pkg/marshal" + "io" + "reflect" + "runtime" + "unsafe" +) + +// Marshallable types used by this file. +var _ marshal.Marshallable = (*MessageHeader64)(nil) +var _ marshal.Marshallable = (*multipleMessageHeader64)(nil) +var _ marshal.Marshallable = (*sigSetWithSize)(nil) + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (s *sigSetWithSize) SizeBytes() int { + return 16 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (s *sigSetWithSize) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint64(dst[:8], uint64(s.sigsetAddr)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(s.sizeofSigset)) + dst = dst[8:] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (s *sigSetWithSize) UnmarshalBytes(src []byte) { + s.sigsetAddr = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + s.sizeofSigset = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (s *sigSetWithSize) Packed() bool { + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (s *sigSetWithSize) MarshalUnsafe(dst []byte) { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(s), uintptr(s.SizeBytes())) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (s *sigSetWithSize) UnmarshalUnsafe(src []byte) { + gohacks.Memmove(unsafe.Pointer(s), unsafe.Pointer(&src[0]), uintptr(s.SizeBytes())) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (s *sigSetWithSize) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(s))) + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until the use above. + runtime.KeepAlive(s) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (s *sigSetWithSize) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return s.CopyOutN(cc, addr, s.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (s *sigSetWithSize) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(s))) + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until the use above. + runtime.KeepAlive(s) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (s *sigSetWithSize) WriteTo(writer io.Writer) (int64, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(s))) + hdr.Len = s.SizeBytes() + hdr.Cap = s.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that s + // must live until the use above. + runtime.KeepAlive(s) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (m *MessageHeader64) SizeBytes() int { + return 56 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (m *MessageHeader64) MarshalBytes(dst []byte) { + hostarch.ByteOrder.PutUint64(dst[:8], uint64(m.Name)) + dst = dst[8:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(m.NameLen)) + dst = dst[4:] + // Padding: dst[:sizeof(uint32)] ~= uint32(0) + dst = dst[4:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(m.Iov)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(m.IovLen)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(m.Control)) + dst = dst[8:] + hostarch.ByteOrder.PutUint64(dst[:8], uint64(m.ControlLen)) + dst = dst[8:] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(m.Flags)) + dst = dst[4:] + // Padding: dst[:sizeof(int32)] ~= int32(0) + dst = dst[4:] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (m *MessageHeader64) UnmarshalBytes(src []byte) { + m.Name = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + m.NameLen = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + // Padding: var _ uint32 ~= src[:sizeof(uint32)] + src = src[4:] + m.Iov = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + m.IovLen = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + m.Control = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + m.ControlLen = uint64(hostarch.ByteOrder.Uint64(src[:8])) + src = src[8:] + m.Flags = int32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + // Padding: var _ int32 ~= src[:sizeof(int32)] + src = src[4:] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (m *MessageHeader64) Packed() bool { + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (m *MessageHeader64) MarshalUnsafe(dst []byte) { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(m), uintptr(m.SizeBytes())) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (m *MessageHeader64) UnmarshalUnsafe(src []byte) { + gohacks.Memmove(unsafe.Pointer(m), unsafe.Pointer(&src[0]), uintptr(m.SizeBytes())) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (m *MessageHeader64) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(m))) + hdr.Len = m.SizeBytes() + hdr.Cap = m.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that m + // must live until the use above. + runtime.KeepAlive(m) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (m *MessageHeader64) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return m.CopyOutN(cc, addr, m.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (m *MessageHeader64) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(m))) + hdr.Len = m.SizeBytes() + hdr.Cap = m.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that m + // must live until the use above. + runtime.KeepAlive(m) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (m *MessageHeader64) WriteTo(writer io.Writer) (int64, error) { + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(m))) + hdr.Len = m.SizeBytes() + hdr.Cap = m.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that m + // must live until the use above. + runtime.KeepAlive(m) // escapes: replaced by intrinsic. + return int64(length), err +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (m *multipleMessageHeader64) SizeBytes() int { + return 8 + + (*MessageHeader64)(nil).SizeBytes() +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (m *multipleMessageHeader64) MarshalBytes(dst []byte) { + m.msgHdr.MarshalBytes(dst[:m.msgHdr.SizeBytes()]) + dst = dst[m.msgHdr.SizeBytes():] + hostarch.ByteOrder.PutUint32(dst[:4], uint32(m.msgLen)) + dst = dst[4:] + // Padding: dst[:sizeof(int32)] ~= int32(0) + dst = dst[4:] +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (m *multipleMessageHeader64) UnmarshalBytes(src []byte) { + m.msgHdr.UnmarshalBytes(src[:m.msgHdr.SizeBytes()]) + src = src[m.msgHdr.SizeBytes():] + m.msgLen = uint32(hostarch.ByteOrder.Uint32(src[:4])) + src = src[4:] + // Padding: var _ int32 ~= src[:sizeof(int32)] + src = src[4:] +} + +// Packed implements marshal.Marshallable.Packed. +//go:nosplit +func (m *multipleMessageHeader64) Packed() bool { + return m.msgHdr.Packed() +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (m *multipleMessageHeader64) MarshalUnsafe(dst []byte) { + if m.msgHdr.Packed() { + gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(m), uintptr(m.SizeBytes())) + } else { + // Type multipleMessageHeader64 doesn't have a packed layout in memory, fallback to MarshalBytes. + m.MarshalBytes(dst) + } +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (m *multipleMessageHeader64) UnmarshalUnsafe(src []byte) { + if m.msgHdr.Packed() { + gohacks.Memmove(unsafe.Pointer(m), unsafe.Pointer(&src[0]), uintptr(m.SizeBytes())) + } else { + // Type multipleMessageHeader64 doesn't have a packed layout in memory, fallback to UnmarshalBytes. + m.UnmarshalBytes(src) + } +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +//go:nosplit +func (m *multipleMessageHeader64) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { + if !m.msgHdr.Packed() { + // Type multipleMessageHeader64 doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := cc.CopyScratchBuffer(m.SizeBytes()) // escapes: okay. + m.MarshalBytes(buf) // escapes: fallback. + return cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(m))) + hdr.Len = m.SizeBytes() + hdr.Cap = m.SizeBytes() + + length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that m + // must live until the use above. + runtime.KeepAlive(m) // escapes: replaced by intrinsic. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +//go:nosplit +func (m *multipleMessageHeader64) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + return m.CopyOutN(cc, addr, m.SizeBytes()) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +//go:nosplit +func (m *multipleMessageHeader64) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) { + if !m.msgHdr.Packed() { + // Type multipleMessageHeader64 doesn't have a packed layout in memory, fall back to UnmarshalBytes. + buf := cc.CopyScratchBuffer(m.SizeBytes()) // escapes: okay. + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Unmarshal unconditionally. If we had a short copy-in, this results in a + // partially unmarshalled struct. + m.UnmarshalBytes(buf) // escapes: fallback. + return length, err + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(m))) + hdr.Len = m.SizeBytes() + hdr.Cap = m.SizeBytes() + + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. + // Since we bypassed the compiler's escape analysis, indicate that m + // must live until the use above. + runtime.KeepAlive(m) // escapes: replaced by intrinsic. + return length, err +} + +// WriteTo implements io.WriterTo.WriteTo. +func (m *multipleMessageHeader64) WriteTo(writer io.Writer) (int64, error) { + if !m.msgHdr.Packed() { + // Type multipleMessageHeader64 doesn't have a packed layout in memory, fall back to MarshalBytes. + buf := make([]byte, m.SizeBytes()) + m.MarshalBytes(buf) + length, err := writer.Write(buf) + return int64(length), err + } + + // Construct a slice backed by dst's underlying memory. + var buf []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) + hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(m))) + hdr.Len = m.SizeBytes() + hdr.Cap = m.SizeBytes() + + length, err := writer.Write(buf) + // Since we bypassed the compiler's escape analysis, indicate that m + // must live until the use above. + runtime.KeepAlive(m) // escapes: replaced by intrinsic. + return int64(length), err +} + diff --git a/pkg/sentry/syscalls/linux/vfs2/vfs2_amd64_abi_autogen_unsafe.go b/pkg/sentry/syscalls/linux/vfs2/vfs2_amd64_abi_autogen_unsafe.go new file mode 100644 index 000000000..d5335190c --- /dev/null +++ b/pkg/sentry/syscalls/linux/vfs2/vfs2_amd64_abi_autogen_unsafe.go @@ -0,0 +1,16 @@ +// Automatically generated marshal implementation. See tools/go_marshal. + +// If there are issues with build constraint aggregation, see +// tools/go_marshal/gomarshal/generator.go:writeHeader(). The constraints here +// come from the input set of files used to generate this file. This input set +// is filtered based on pre-defined file suffixes related to build constraints, +// see tools/defs.bzl:calculate_sets(). + +//go:build amd64 +// +build amd64 + +package vfs2 + +import ( +) + diff --git a/pkg/sentry/syscalls/linux/vfs2/vfs2_amd64_state_autogen.go b/pkg/sentry/syscalls/linux/vfs2/vfs2_amd64_state_autogen.go new file mode 100644 index 000000000..467b4258c --- /dev/null +++ b/pkg/sentry/syscalls/linux/vfs2/vfs2_amd64_state_autogen.go @@ -0,0 +1,6 @@ +// automatically generated by stateify. + +//go:build amd64 +// +build amd64 + +package vfs2 diff --git a/pkg/sentry/syscalls/linux/vfs2/vfs2_arm64_abi_autogen_unsafe.go b/pkg/sentry/syscalls/linux/vfs2/vfs2_arm64_abi_autogen_unsafe.go new file mode 100644 index 000000000..07ab07d17 --- /dev/null +++ b/pkg/sentry/syscalls/linux/vfs2/vfs2_arm64_abi_autogen_unsafe.go @@ -0,0 +1,16 @@ +// Automatically generated marshal implementation. See tools/go_marshal. + +// If there are issues with build constraint aggregation, see +// tools/go_marshal/gomarshal/generator.go:writeHeader(). The constraints here +// come from the input set of files used to generate this file. This input set +// is filtered based on pre-defined file suffixes related to build constraints, +// see tools/defs.bzl:calculate_sets(). + +//go:build arm64 +// +build arm64 + +package vfs2 + +import ( +) + diff --git a/pkg/sentry/syscalls/linux/vfs2/vfs2_arm64_state_autogen.go b/pkg/sentry/syscalls/linux/vfs2/vfs2_arm64_state_autogen.go new file mode 100644 index 000000000..2e1a50af5 --- /dev/null +++ b/pkg/sentry/syscalls/linux/vfs2/vfs2_arm64_state_autogen.go @@ -0,0 +1,6 @@ +// automatically generated by stateify. + +//go:build arm64 +// +build arm64 + +package vfs2 diff --git a/pkg/sentry/syscalls/linux/vfs2/vfs2_state_autogen.go b/pkg/sentry/syscalls/linux/vfs2/vfs2_state_autogen.go new file mode 100644 index 000000000..d02c8467f --- /dev/null +++ b/pkg/sentry/syscalls/linux/vfs2/vfs2_state_autogen.go @@ -0,0 +1,42 @@ +// automatically generated by stateify. + +package vfs2 + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (p *pollRestartBlock) StateTypeName() string { + return "pkg/sentry/syscalls/linux/vfs2.pollRestartBlock" +} + +func (p *pollRestartBlock) StateFields() []string { + return []string{ + "pfdAddr", + "nfds", + "timeout", + } +} + +func (p *pollRestartBlock) beforeSave() {} + +// +checklocksignore +func (p *pollRestartBlock) StateSave(stateSinkObject state.Sink) { + p.beforeSave() + stateSinkObject.Save(0, &p.pfdAddr) + stateSinkObject.Save(1, &p.nfds) + stateSinkObject.Save(2, &p.timeout) +} + +func (p *pollRestartBlock) afterLoad() {} + +// +checklocksignore +func (p *pollRestartBlock) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &p.pfdAddr) + stateSourceObject.Load(1, &p.nfds) + stateSourceObject.Load(2, &p.timeout) +} + +func init() { + state.Register((*pollRestartBlock)(nil)) +} diff --git a/pkg/sentry/syscalls/syscalls_state_autogen.go b/pkg/sentry/syscalls/syscalls_state_autogen.go new file mode 100644 index 000000000..b577e39a3 --- /dev/null +++ b/pkg/sentry/syscalls/syscalls_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package syscalls diff --git a/pkg/sentry/time/BUILD b/pkg/sentry/time/BUILD deleted file mode 100644 index c21971322..000000000 --- a/pkg/sentry/time/BUILD +++ /dev/null @@ -1,54 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -package(licenses = ["notice"]) - -go_template_instance( - name = "seqatomic_parameters", - out = "seqatomic_parameters_unsafe.go", - package = "time", - suffix = "Parameters", - template = "//pkg/sync/seqatomic:generic_seqatomic", - types = { - "Value": "Parameters", - }, -) - -go_library( - name = "time", - srcs = [ - "arith_arm64.go", - "calibrated_clock.go", - "clock_id.go", - "clocks.go", - "muldiv_amd64.s", - "muldiv_arm64.s", - "parameters.go", - "sampler.go", - "sampler_amd64.go", - "sampler_arm64.go", - "sampler_unsafe.go", - "seqatomic_parameters_unsafe.go", - "tsc_amd64.s", - "tsc_arm64.s", - ], - visibility = ["//:sandbox"], - deps = [ - "//pkg/errors/linuxerr", - "//pkg/gohacks", - "//pkg/log", - "//pkg/metric", - "//pkg/sync", - "@org_golang_x_sys//unix:go_default_library", - ], -) - -go_test( - name = "time_test", - srcs = [ - "calibrated_clock_test.go", - "parameters_test.go", - "sampler_test.go", - ], - library = ":time", -) diff --git a/pkg/sentry/time/LICENSE b/pkg/sentry/time/LICENSE deleted file mode 100644 index 6a66aea5e..000000000 --- a/pkg/sentry/time/LICENSE +++ /dev/null @@ -1,27 +0,0 @@ -Copyright (c) 2009 The Go Authors. All rights reserved. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are -met: - - * Redistributions of source code must retain the above copyright -notice, this list of conditions and the following disclaimer. - * Redistributions in binary form must reproduce the above -copyright notice, this list of conditions and the following disclaimer -in the documentation and/or other materials provided with the -distribution. - * Neither the name of Google Inc. nor the names of its -contributors may be used to endorse or promote products derived from -this software without specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/pkg/sentry/time/calibrated_clock_test.go b/pkg/sentry/time/calibrated_clock_test.go deleted file mode 100644 index 0a4b1f1bf..000000000 --- a/pkg/sentry/time/calibrated_clock_test.go +++ /dev/null @@ -1,187 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package time - -import ( - "testing" - "time" -) - -// newTestCalibratedClock returns a CalibratedClock that collects samples from -// the given sample list and cycle counts from the given cycle list. -func newTestCalibratedClock(samples []sample, cycles []TSCValue) *CalibratedClock { - return &CalibratedClock{ - ref: newTestSampler(samples, cycles), - } -} - -func TestConstantFrequency(t *testing.T) { - // Perfectly constant frequency. - samples := []sample{ - {before: 100000, after: 100000 + defaultOverheadCycles, ref: 100}, - {before: 200000, after: 200000 + defaultOverheadCycles, ref: 200}, - {before: 300000, after: 300000 + defaultOverheadCycles, ref: 300}, - {before: 400000, after: 400000 + defaultOverheadCycles, ref: 400}, - {before: 500000, after: 500000 + defaultOverheadCycles, ref: 500}, - {before: 600000, after: 600000 + defaultOverheadCycles, ref: 600}, - {before: 700000, after: 700000 + defaultOverheadCycles, ref: 700}, - } - - c := newTestCalibratedClock(samples, nil) - - // Update from all samples. - for range samples { - c.Update() - } - - c.mu.RLock() - if !c.ready { - c.mu.RUnlock() - t.Fatalf("clock not ready") - return // For checklocks consistency. - } - // A bit after the last sample. - now, ok := c.params.ComputeTime(750000) - c.mu.RUnlock() - if !ok { - t.Fatalf("ComputeTime ok got %v want true", ok) - } - - t.Logf("now: %v", now) - - // Time should be between the current sample and where we'd expect the - // next sample. - if now < 700 || now > 800 { - t.Errorf("now got %v want > 700 && < 800", now) - } -} - -func TestErrorCorrection(t *testing.T) { - testCases := []struct { - name string - samples [5]sample - projectedTimeStart int64 - projectedTimeEnd int64 - }{ - // Initial calibration should be ~1MHz for each of these, and - // the reference clock changes in samples[2]. - { - name: "slow-down", - samples: [5]sample{ - {before: 1000000, after: 1000001, ref: ReferenceNS(1 * ApproxUpdateInterval.Nanoseconds())}, - {before: 2000000, after: 2000001, ref: ReferenceNS(2 * ApproxUpdateInterval.Nanoseconds())}, - // Reference clock has slowed down, causing 100ms of error. - {before: 3010000, after: 3010001, ref: ReferenceNS(3 * ApproxUpdateInterval.Nanoseconds())}, - {before: 4020000, after: 4020001, ref: ReferenceNS(4 * ApproxUpdateInterval.Nanoseconds())}, - {before: 5030000, after: 5030001, ref: ReferenceNS(5 * ApproxUpdateInterval.Nanoseconds())}, - }, - projectedTimeStart: 3005 * time.Millisecond.Nanoseconds(), - projectedTimeEnd: 3015 * time.Millisecond.Nanoseconds(), - }, - { - name: "speed-up", - samples: [5]sample{ - {before: 1000000, after: 1000001, ref: ReferenceNS(1 * ApproxUpdateInterval.Nanoseconds())}, - {before: 2000000, after: 2000001, ref: ReferenceNS(2 * ApproxUpdateInterval.Nanoseconds())}, - // Reference clock has sped up, causing 100ms of error. - {before: 2990000, after: 2990001, ref: ReferenceNS(3 * ApproxUpdateInterval.Nanoseconds())}, - {before: 3980000, after: 3980001, ref: ReferenceNS(4 * ApproxUpdateInterval.Nanoseconds())}, - {before: 4970000, after: 4970001, ref: ReferenceNS(5 * ApproxUpdateInterval.Nanoseconds())}, - }, - projectedTimeStart: 2985 * time.Millisecond.Nanoseconds(), - projectedTimeEnd: 2995 * time.Millisecond.Nanoseconds(), - }, - } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - c := newTestCalibratedClock(tc.samples[:], nil) - - // Initial calibration takes two updates. - _, ok := c.Update() - if ok { - t.Fatalf("Update ready too early") - } - - params, ok := c.Update() - if !ok { - t.Fatalf("Update not ready") - } - - // Initial calibration is ~1MHz. - hz := params.Frequency - if hz < 990000 || hz > 1010000 { - t.Fatalf("Frequency got %v want > 990kHz && < 1010kHz", hz) - } - - // Project time at the next update. Given the 1MHz - // calibration, it is expected to be ~3.1s/2.9s, not - // the actual 3s. - // - // N.B. the next update time is the "after" time above. - projected, ok := params.ComputeTime(tc.samples[2].after) - if !ok { - t.Fatalf("ComputeTime ok got %v want true", ok) - } - if projected < tc.projectedTimeStart || projected > tc.projectedTimeEnd { - t.Fatalf("ComputeTime(%v) got %v want > %v && < %v", tc.samples[2].after, projected, tc.projectedTimeStart, tc.projectedTimeEnd) - } - - // Update again to see the changed reference clock. - params, ok = c.Update() - if !ok { - t.Fatalf("Update not ready") - } - - // We now know that TSC = tc.samples[2].after -> 3s, - // but with the previous params indicated that TSC - // tc.samples[2].after -> 3.5s/2.5s. We can't allow the - // clock to go backwards, and having the clock jump - // forwards is undesirable. There should be a smooth - // transition that corrects the clock error over time. - // Check that the clock is continuous at TSC = - // tc.samples[2].after. - newProjected, ok := params.ComputeTime(tc.samples[2].after) - if !ok { - t.Fatalf("ComputeTime ok got %v want true", ok) - } - if newProjected != projected { - t.Errorf("Discontinuous time; ComputeTime(%v) got %v want %v", tc.samples[2].after, newProjected, projected) - } - - // As the reference clock stablizes, ensure that the clock error - // decreases. - initialErr := c.errorNS - t.Logf("initial error: %v ns", initialErr) - - _, ok = c.Update() - if !ok { - t.Fatalf("Update not ready") - } - if c.errorNS.Magnitude() > initialErr.Magnitude() { - t.Errorf("errorNS increased, got %v want |%v| <= |%v|", c.errorNS, c.errorNS, initialErr) - } - - _, ok = c.Update() - if !ok { - t.Fatalf("Update not ready") - } - if c.errorNS.Magnitude() > initialErr.Magnitude() { - t.Errorf("errorNS increased, got %v want |%v| <= |%v|", c.errorNS, c.errorNS, initialErr) - } - - t.Logf("final error: %v ns", c.errorNS) - }) - } -} diff --git a/pkg/sentry/time/parameters_test.go b/pkg/sentry/time/parameters_test.go deleted file mode 100644 index 0ce1257f6..000000000 --- a/pkg/sentry/time/parameters_test.go +++ /dev/null @@ -1,501 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package time - -import ( - "math" - "testing" - "time" -) - -func TestParametersComputeTime(t *testing.T) { - testCases := []struct { - name string - params Parameters - now TSCValue - want int64 - }{ - { - // Now is the same as the base cycles. - name: "base-cycles", - params: Parameters{ - BaseCycles: 10000, - BaseRef: ReferenceNS(5000 * time.Millisecond.Nanoseconds()), - Frequency: 10000, - }, - now: 10000, - want: 5000 * time.Millisecond.Nanoseconds(), - }, - { - // Now is the behind the base cycles. Time is frozen. - name: "backwards", - params: Parameters{ - BaseCycles: 10000, - BaseRef: ReferenceNS(5000 * time.Millisecond.Nanoseconds()), - Frequency: 10000, - }, - now: 9000, - want: 5000 * time.Millisecond.Nanoseconds(), - }, - { - // Now is ahead of the base cycles. - name: "ahead", - params: Parameters{ - BaseCycles: 10000, - BaseRef: ReferenceNS(5000 * time.Millisecond.Nanoseconds()), - Frequency: 10000, - }, - now: 15000, - want: 5500 * time.Millisecond.Nanoseconds(), - }, - } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - got, ok := tc.params.ComputeTime(tc.now) - if !ok { - t.Errorf("ComputeTime ok got %v want true", got) - } - if got != tc.want { - t.Errorf("ComputeTime got %+v want %+v", got, tc.want) - } - }) - } -} - -func TestParametersErrorAdjust(t *testing.T) { - testCases := []struct { - name string - oldParams Parameters - now TSCValue - newParams Parameters - want Parameters - errorNS ReferenceNS - wantErr bool - }{ - { - // newParams are perfectly continuous with oldParams - // and don't need adjustment. - name: "continuous", - oldParams: Parameters{ - BaseCycles: 0, - BaseRef: 0, - Frequency: 10000, - }, - now: 50000, - newParams: Parameters{ - BaseCycles: 50000, - BaseRef: ReferenceNS(5000 * time.Millisecond.Nanoseconds()), - Frequency: 10000, - }, - want: Parameters{ - BaseCycles: 50000, - BaseRef: ReferenceNS(5000 * time.Millisecond.Nanoseconds()), - Frequency: 10000, - }, - }, - { - // Same as "continuous", but with now ahead of - // newParams.BaseCycles. The result is the same as - // there is no error to correct. - name: "continuous-nowdiff", - oldParams: Parameters{ - BaseCycles: 0, - BaseRef: 0, - Frequency: 10000, - }, - now: 60000, - newParams: Parameters{ - BaseCycles: 50000, - BaseRef: ReferenceNS(5000 * time.Millisecond.Nanoseconds()), - Frequency: 10000, - }, - want: Parameters{ - BaseCycles: 50000, - BaseRef: ReferenceNS(5000 * time.Millisecond.Nanoseconds()), - Frequency: 10000, - }, - }, - { - // errorAdjust bails out if the TSC goes backwards. - name: "tsc-backwards", - oldParams: Parameters{ - BaseCycles: 10000, - BaseRef: ReferenceNS(1000 * time.Millisecond.Nanoseconds()), - Frequency: 10000, - }, - now: 9000, - newParams: Parameters{ - BaseCycles: 9000, - BaseRef: ReferenceNS(1100 * time.Millisecond.Nanoseconds()), - Frequency: 10000, - }, - wantErr: true, - }, - { - // errorAdjust bails out if new params are from after now. - name: "params-after-now", - oldParams: Parameters{ - BaseCycles: 10000, - BaseRef: ReferenceNS(1000 * time.Millisecond.Nanoseconds()), - Frequency: 10000, - }, - now: 11000, - newParams: Parameters{ - BaseCycles: 12000, - BaseRef: ReferenceNS(1200 * time.Millisecond.Nanoseconds()), - Frequency: 10000, - }, - wantErr: true, - }, - { - // Host clock sped up. - name: "speed-up", - oldParams: Parameters{ - BaseCycles: 0, - BaseRef: 0, - Frequency: 10000, - }, - now: 45000, - // Host frequency changed to 9000 immediately after - // oldParams was returned. - newParams: Parameters{ - BaseCycles: 45000, - // From oldParams, we think ref = 4.5s at cycles = 45000. - BaseRef: ReferenceNS(5000 * time.Millisecond.Nanoseconds()), - Frequency: 9000, - }, - want: Parameters{ - BaseCycles: 45000, - BaseRef: ReferenceNS(4500 * time.Millisecond.Nanoseconds()), - // We must decrease the new frequency by 50% to - // correct 0.5s of error in 1s - // (ApproxUpdateInterval). - Frequency: 4500, - }, - errorNS: ReferenceNS(-500 * time.Millisecond.Nanoseconds()), - }, - { - // Host clock sped up, with now ahead of newParams. - name: "speed-up-nowdiff", - oldParams: Parameters{ - BaseCycles: 0, - BaseRef: 0, - Frequency: 10000, - }, - now: 50000, - // Host frequency changed to 9000 immediately after - // oldParams was returned. - newParams: Parameters{ - BaseCycles: 45000, - BaseRef: ReferenceNS(5000 * time.Millisecond.Nanoseconds()), - Frequency: 9000, - }, - // nextRef = 6000ms - // nextCycles = 9000 * (6000ms - 5000ms) + 45000 - // nextCycles = 9000 * (1s) + 45000 - // nextCycles = 54000 - // f = (54000 - 50000) / 1s = 4000 - // - // ref = 5000ms - (50000 - 45000) / 4000 - // ref = 3.75s - want: Parameters{ - BaseCycles: 45000, - BaseRef: ReferenceNS(3750 * time.Millisecond.Nanoseconds()), - Frequency: 4000, - }, - // oldNow = 50000 * 10000 = 5s - // newNow = (50000 - 45000) / 9000 + 5s = 5.555s - errorNS: ReferenceNS((5000*time.Millisecond - 5555555555).Nanoseconds()), - }, - { - // Host clock sped up. The new parameters are so far - // ahead that the next update time already passed. - name: "speed-up-uncorrectable-baseref", - oldParams: Parameters{ - BaseCycles: 0, - BaseRef: 0, - Frequency: 10000, - }, - now: 50000, - // Host frequency changed to 5000 immediately after - // oldParams was returned. - newParams: Parameters{ - BaseCycles: 45000, - BaseRef: ReferenceNS(9000 * time.Millisecond.Nanoseconds()), - Frequency: 5000, - }, - // The next update should be at 10s, but newParams - // already passed 6s. Thus it is impossible to correct - // the clock by then. - wantErr: true, - }, - { - // Host clock sped up. The new parameters are moving so - // fast that the next update should be before now. - name: "speed-up-uncorrectable-frequency", - oldParams: Parameters{ - BaseCycles: 0, - BaseRef: 0, - Frequency: 10000, - }, - now: 55000, - // Host frequency changed to 7500 immediately after - // oldParams was returned. - newParams: Parameters{ - BaseCycles: 45000, - BaseRef: ReferenceNS(6000 * time.Millisecond.Nanoseconds()), - Frequency: 7500, - }, - // The next update should be at 6.5s, but newParams are - // so far ahead and fast that they reach 6.5s at cycle - // 48750, which before now! Thus it is impossible to - // correct the clock by then. - wantErr: true, - }, - { - // Host clock slowed down. - name: "slow-down", - oldParams: Parameters{ - BaseCycles: 0, - BaseRef: 0, - Frequency: 10000, - }, - now: 55000, - // Host frequency changed to 11000 immediately after - // oldParams was returned. - newParams: Parameters{ - BaseCycles: 55000, - // From oldParams, we think ref = 5.5s at cycles = 55000. - BaseRef: ReferenceNS(5000 * time.Millisecond.Nanoseconds()), - Frequency: 11000, - }, - want: Parameters{ - BaseCycles: 55000, - BaseRef: ReferenceNS(5500 * time.Millisecond.Nanoseconds()), - // We must increase the new frequency by 50% to - // correct 0.5s of error in 1s - // (ApproxUpdateInterval). - Frequency: 16500, - }, - errorNS: ReferenceNS(500 * time.Millisecond.Nanoseconds()), - }, - { - // Host clock slowed down, with now ahead of newParams. - name: "slow-down-nowdiff", - oldParams: Parameters{ - BaseCycles: 0, - BaseRef: 0, - Frequency: 10000, - }, - now: 60000, - // Host frequency changed to 11000 immediately after - // oldParams was returned. - newParams: Parameters{ - BaseCycles: 55000, - BaseRef: ReferenceNS(5000 * time.Millisecond.Nanoseconds()), - Frequency: 11000, - }, - // nextRef = 7000ms - // nextCycles = 11000 * (7000ms - 5000ms) + 55000 - // nextCycles = 11000 * (2000ms) + 55000 - // nextCycles = 77000 - // f = (77000 - 60000) / 1s = 17000 - // - // ref = 6000ms - (60000 - 55000) / 17000 - // ref = 5705882353ns - want: Parameters{ - BaseCycles: 55000, - BaseRef: ReferenceNS(5705882353), - Frequency: 17000, - }, - // oldNow = 60000 * 10000 = 6s - // newNow = (60000 - 55000) / 11000 + 5s = 5.4545s - errorNS: ReferenceNS((6*time.Second - 5454545454).Nanoseconds()), - }, - { - // Host time went backwards. - name: "time-backwards", - oldParams: Parameters{ - BaseCycles: 50000, - BaseRef: ReferenceNS(5000 * time.Millisecond.Nanoseconds()), - Frequency: 10000, - }, - now: 60000, - newParams: Parameters{ - BaseCycles: 60000, - // From oldParams, we think ref = 6s at cycles = 60000. - BaseRef: ReferenceNS(4000 * time.Millisecond.Nanoseconds()), - Frequency: 10000, - }, - want: Parameters{ - BaseCycles: 60000, - BaseRef: ReferenceNS(6000 * time.Millisecond.Nanoseconds()), - // We must increase the frequency by 200% to - // correct 2s of error in 1s - // (ApproxUpdateInterval). - Frequency: 30000, - }, - errorNS: ReferenceNS(2000 * time.Millisecond.Nanoseconds()), - }, - { - // Host time went backwards, with now ahead of newParams. - name: "time-backwards-nowdiff", - oldParams: Parameters{ - BaseCycles: 50000, - BaseRef: ReferenceNS(5000 * time.Millisecond.Nanoseconds()), - Frequency: 10000, - }, - now: 65000, - // nextRef = 7500ms - // nextCycles = 10000 * (7500ms - 4000ms) + 60000 - // nextCycles = 10000 * (3500ms) + 60000 - // nextCycles = 95000 - // f = (95000 - 65000) / 1s = 30000 - // - // ref = 6500ms - (65000 - 60000) / 30000 - // ref = 6333333333ns - newParams: Parameters{ - BaseCycles: 60000, - BaseRef: ReferenceNS(4000 * time.Millisecond.Nanoseconds()), - Frequency: 10000, - }, - want: Parameters{ - BaseCycles: 60000, - BaseRef: ReferenceNS(6333333334), - Frequency: 30000, - }, - // oldNow = 65000 * 10000 = 6.5s - // newNow = (65000 - 60000) / 10000 + 4s = 4.5s - errorNS: ReferenceNS(2000 * time.Millisecond.Nanoseconds()), - }, - } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - got, errorNS, err := errorAdjust(tc.oldParams, tc.newParams, tc.now) - if err != nil && !tc.wantErr { - t.Errorf("err got %v want nil", err) - } else if err == nil && tc.wantErr { - t.Errorf("err got nil want non-nil") - } - - if got != tc.want { - t.Errorf("Parameters got %+v want %+v", got, tc.want) - } - if errorNS != tc.errorNS { - t.Errorf("errorNS got %v want %v", errorNS, tc.errorNS) - } - }) - } -} - -func testMuldiv(t *testing.T, v uint64) { - for i := uint64(1); i <= 1000000; i++ { - mult := uint64(1000000000) - div := i * mult - res, ok := muldiv64(v, mult, div) - if !ok { - t.Errorf("Result of %v * %v / %v ok got false want true", v, mult, div) - } - if want := v / i; res != want { - t.Errorf("Bad result of %v * %v / %v: got %v, want %v", v, mult, div, res, want) - } - } -} - -func TestMulDiv(t *testing.T) { - testMuldiv(t, math.MaxUint64) - for i := int64(-10); i <= 10; i++ { - testMuldiv(t, uint64(i)) - } -} - -func TestMulDivZero(t *testing.T) { - if r, ok := muldiv64(2, 4, 0); ok { - t.Errorf("muldiv64(2, 4, 0) got %d, ok want !ok", r) - } - - if r, ok := muldiv64(0, 0, 0); ok { - t.Errorf("muldiv64(0, 0, 0) got %d, ok want !ok", r) - } -} - -func TestMulDivOverflow(t *testing.T) { - testCases := []struct { - name string - val uint64 - mult uint64 - div uint64 - ok bool - ret uint64 - }{ - { - name: "2^62", - val: 1 << 63, - mult: 4, - div: 8, - ok: true, - ret: 1 << 62, - }, - { - name: "2^64-1", - val: 0xffffffffffffffff, - mult: 1, - div: 1, - ok: true, - ret: 0xffffffffffffffff, - }, - { - name: "2^64", - val: 1 << 63, - mult: 4, - div: 2, - ok: false, - }, - { - name: "2^125", - val: 1 << 63, - mult: 1 << 63, - div: 2, - ok: false, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - r, ok := muldiv64(tc.val, tc.mult, tc.div) - if ok != tc.ok { - t.Errorf("ok got %v want %v", ok, tc.ok) - } - if tc.ok && r != tc.ret { - t.Errorf("ret got %v want %v", r, tc.ret) - } - }) - } -} - -func BenchmarkMuldiv64(b *testing.B) { - var v uint64 = math.MaxUint64 - for i := uint64(1); i <= 1000000; i++ { - mult := uint64(1000000000) - div := i * mult - res, ok := muldiv64(v, mult, div) - if !ok { - b.Errorf("Result of %v * %v / %v ok got false want true", v, mult, div) - } - if want := v / i; res != want { - b.Errorf("Bad result of %v * %v / %v: got %v, want %v", v, mult, div, res, want) - } - } -} diff --git a/pkg/sentry/time/sampler_test.go b/pkg/sentry/time/sampler_test.go deleted file mode 100644 index 3e70a1134..000000000 --- a/pkg/sentry/time/sampler_test.go +++ /dev/null @@ -1,183 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package time - -import ( - "errors" - "testing" -) - -// errNoSamples is returned when testReferenceClocks runs out of samples. -var errNoSamples = errors.New("no samples available") - -// testReferenceClocks returns a preset list of samples and cycle counts. -type testReferenceClocks struct { - samples []sample - cycles []TSCValue -} - -// Sample implements referenceClocks.Sample, returning the next sample in the list. -func (t *testReferenceClocks) Sample(_ ClockID) (sample, error) { - if len(t.samples) == 0 { - return sample{}, errNoSamples - } - - s := t.samples[0] - if len(t.samples) == 1 { - t.samples = nil - } else { - t.samples = t.samples[1:] - } - - return s, nil -} - -// Cycles implements referenceClocks.Cycles, returning the next TSCValue in the list. -func (t *testReferenceClocks) Cycles() TSCValue { - if len(t.cycles) == 0 { - return 0 - } - - c := t.cycles[0] - if len(t.cycles) == 1 { - t.cycles = nil - } else { - t.cycles = t.cycles[1:] - } - - return c -} - -// newTestSampler returns a sampler that collects samples from -// the given sample list and cycle counts from the given cycle list. -func newTestSampler(samples []sample, cycles []TSCValue) *sampler { - return &sampler{ - clocks: &testReferenceClocks{ - samples: samples, - cycles: cycles, - }, - overhead: defaultOverheadCycles, - } -} - -// generateSamples generates n samples with the given overhead. -func generateSamples(n int, overhead TSCValue) []sample { - samples := []sample{{before: 1000000, after: 1000000 + overhead, ref: 100}} - for i := 0; i < n-1; i++ { - prev := samples[len(samples)-1] - samples = append(samples, sample{ - before: prev.before + 1000000, - after: prev.after + 1000000, - ref: prev.ref + 100, - }) - } - return samples -} - -// TestSample ensures that samples can be collected. -func TestSample(t *testing.T) { - testCases := []struct { - name string - samples []sample - err error - }{ - { - name: "basic", - samples: []sample{ - {before: 100000, after: 100000 + defaultOverheadCycles, ref: 100}, - }, - err: nil, - }, - { - // Sample with backwards TSC ignored. - // referenceClock should retry and get errNoSamples. - name: "backwards-tsc-ignored", - samples: []sample{ - {before: 100000, after: 90000, ref: 100}, - }, - err: errNoSamples, - }, - { - // Sample far above overhead skipped. - // referenceClock should retry and get errNoSamples. - name: "reject-overhead", - samples: []sample{ - {before: 100000, after: 100000 + 5*defaultOverheadCycles, ref: 100}, - }, - err: errNoSamples, - }, - { - // Maximum overhead allowed is bounded. - name: "over-max-overhead", - // Generate a bunch of samples. The reference clock - // needs a while to ramp up its expected overhead. - samples: generateSamples(100, 2*maxOverheadCycles), - err: errOverheadTooHigh, - }, - { - // Overhead at maximum overhead is allowed. - name: "max-overhead", - // Generate a bunch of samples. The reference clock - // needs a while to ramp up its expected overhead. - samples: generateSamples(100, maxOverheadCycles), - err: nil, - }, - } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - s := newTestSampler(tc.samples, nil) - err := s.Sample() - if err != tc.err { - t.Errorf("Sample err got %v want %v", err, tc.err) - } - }) - } -} - -// TestOutliersIgnored tests that referenceClock ignores samples with very high -// overhead. -func TestOutliersIgnored(t *testing.T) { - s := newTestSampler([]sample{ - {before: 100000, after: 100000 + defaultOverheadCycles, ref: 100}, - {before: 200000, after: 200000 + defaultOverheadCycles, ref: 200}, - {before: 300000, after: 300000 + defaultOverheadCycles, ref: 300}, - {before: 400000, after: 400000 + defaultOverheadCycles, ref: 400}, - {before: 500000, after: 500000 + 5*defaultOverheadCycles, ref: 500}, // Ignored - {before: 600000, after: 600000 + defaultOverheadCycles, ref: 600}, - {before: 700000, after: 700000 + defaultOverheadCycles, ref: 700}, - }, nil) - - // Collect 5 samples. - for i := 0; i < 5; i++ { - err := s.Sample() - if err != nil { - t.Fatalf("Unexpected error while sampling: %v", err) - } - } - - oldest, newest, ok := s.Range() - if !ok { - t.Fatalf("Range not ok") - } - - if oldest.ref != 100 { - t.Errorf("oldest.ref got %v want %v", oldest.ref, 100) - } - - // We skipped the high-overhead sample. - if newest.ref != 600 { - t.Errorf("newest.ref got %v want %v", newest.ref, 600) - } -} diff --git a/pkg/sentry/time/seqatomic_parameters_unsafe.go b/pkg/sentry/time/seqatomic_parameters_unsafe.go new file mode 100644 index 000000000..357e476ec --- /dev/null +++ b/pkg/sentry/time/seqatomic_parameters_unsafe.go @@ -0,0 +1,38 @@ +package time + +import ( + "unsafe" + + "gvisor.dev/gvisor/pkg/gohacks" + "gvisor.dev/gvisor/pkg/sync" +) + +// SeqAtomicLoad returns a copy of *ptr, ensuring that the read does not race +// with any writer critical sections in seq. +// +//go:nosplit +func SeqAtomicLoadParameters(seq *sync.SeqCount, ptr *Parameters) Parameters { + for { + if val, ok := SeqAtomicTryLoadParameters(seq, seq.BeginRead(), ptr); ok { + return val + } + } +} + +// SeqAtomicTryLoad returns a copy of *ptr while in a reader critical section +// in seq initiated by a call to seq.BeginRead() that returned epoch. If the +// read would race with a writer critical section, SeqAtomicTryLoad returns +// (unspecified, false). +// +//go:nosplit +func SeqAtomicTryLoadParameters(seq *sync.SeqCount, epoch sync.SeqCountEpoch, ptr *Parameters) (val Parameters, ok bool) { + if sync.RaceEnabled { + + gohacks.Memmove(unsafe.Pointer(&val), unsafe.Pointer(ptr), unsafe.Sizeof(val)) + } else { + + val = *ptr + } + ok = seq.ReadOk(epoch) + return +} diff --git a/pkg/sentry/time/time_amd64_state_autogen.go b/pkg/sentry/time/time_amd64_state_autogen.go new file mode 100644 index 000000000..3306b8b88 --- /dev/null +++ b/pkg/sentry/time/time_amd64_state_autogen.go @@ -0,0 +1,6 @@ +// automatically generated by stateify. + +//go:build amd64 +// +build amd64 + +package time diff --git a/pkg/sentry/time/time_arm64_state_autogen.go b/pkg/sentry/time/time_arm64_state_autogen.go new file mode 100644 index 000000000..233d240c4 --- /dev/null +++ b/pkg/sentry/time/time_arm64_state_autogen.go @@ -0,0 +1,6 @@ +// automatically generated by stateify. + +//go:build arm64 +// +build arm64 + +package time diff --git a/pkg/sentry/time/time_state_autogen.go b/pkg/sentry/time/time_state_autogen.go new file mode 100644 index 000000000..2adc9c9e0 --- /dev/null +++ b/pkg/sentry/time/time_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package time diff --git a/pkg/sentry/time/time_unsafe_state_autogen.go b/pkg/sentry/time/time_unsafe_state_autogen.go new file mode 100644 index 000000000..2adc9c9e0 --- /dev/null +++ b/pkg/sentry/time/time_unsafe_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package time diff --git a/pkg/sentry/unimpl/BUILD b/pkg/sentry/unimpl/BUILD deleted file mode 100644 index 5d4aa3a63..000000000 --- a/pkg/sentry/unimpl/BUILD +++ /dev/null @@ -1,20 +0,0 @@ -load("//tools:defs.bzl", "go_library", "proto_library") - -package(licenses = ["notice"]) - -proto_library( - name = "unimplemented_syscall", - srcs = ["unimplemented_syscall.proto"], - visibility = ["//visibility:public"], - deps = ["//pkg/sentry/arch:registers_proto"], -) - -go_library( - name = "unimpl", - srcs = ["events.go"], - visibility = ["//:sandbox"], - deps = [ - "//pkg/context", - "//pkg/log", - ], -) diff --git a/pkg/sentry/unimpl/unimpl_state_autogen.go b/pkg/sentry/unimpl/unimpl_state_autogen.go new file mode 100644 index 000000000..b37d16f87 --- /dev/null +++ b/pkg/sentry/unimpl/unimpl_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package unimpl diff --git a/pkg/sentry/unimpl/unimplemented_syscall.proto b/pkg/sentry/unimpl/unimplemented_syscall.proto deleted file mode 100644 index 0d7a94be7..000000000 --- a/pkg/sentry/unimpl/unimplemented_syscall.proto +++ /dev/null @@ -1,27 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -syntax = "proto3"; - -package gvisor; - -import "pkg/sentry/arch/registers.proto"; - -message UnimplementedSyscall { - // Task ID. - int32 tid = 1; - - // Registers at the time of the call. - Registers registers = 2; -} diff --git a/pkg/sentry/unimpl/unimplemented_syscall_go_proto/unimplemented_syscall.pb.go b/pkg/sentry/unimpl/unimplemented_syscall_go_proto/unimplemented_syscall.pb.go new file mode 100644 index 000000000..aa85c330a --- /dev/null +++ b/pkg/sentry/unimpl/unimplemented_syscall_go_proto/unimplemented_syscall.pb.go @@ -0,0 +1,164 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.25.0 +// protoc v3.13.0 +// source: pkg/sentry/unimpl/unimplemented_syscall.proto + +package gvisor + +import ( + proto "github.com/golang/protobuf/proto" + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + registers_go_proto "gvisor.dev/gvisor/pkg/sentry/arch/registers_go_proto" + reflect "reflect" + sync "sync" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +// This is a compile-time assertion that a sufficiently up-to-date version +// of the legacy proto package is being used. +const _ = proto.ProtoPackageIsVersion4 + +type UnimplementedSyscall struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Tid int32 `protobuf:"varint,1,opt,name=tid,proto3" json:"tid,omitempty"` + Registers *registers_go_proto.Registers `protobuf:"bytes,2,opt,name=registers,proto3" json:"registers,omitempty"` +} + +func (x *UnimplementedSyscall) Reset() { + *x = UnimplementedSyscall{} + if protoimpl.UnsafeEnabled { + mi := &file_pkg_sentry_unimpl_unimplemented_syscall_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *UnimplementedSyscall) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*UnimplementedSyscall) ProtoMessage() {} + +func (x *UnimplementedSyscall) ProtoReflect() protoreflect.Message { + mi := &file_pkg_sentry_unimpl_unimplemented_syscall_proto_msgTypes[0] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use UnimplementedSyscall.ProtoReflect.Descriptor instead. +func (*UnimplementedSyscall) Descriptor() ([]byte, []int) { + return file_pkg_sentry_unimpl_unimplemented_syscall_proto_rawDescGZIP(), []int{0} +} + +func (x *UnimplementedSyscall) GetTid() int32 { + if x != nil { + return x.Tid + } + return 0 +} + +func (x *UnimplementedSyscall) GetRegisters() *registers_go_proto.Registers { + if x != nil { + return x.Registers + } + return nil +} + +var File_pkg_sentry_unimpl_unimplemented_syscall_proto protoreflect.FileDescriptor + +var file_pkg_sentry_unimpl_unimplemented_syscall_proto_rawDesc = []byte{ + 0x0a, 0x2d, 0x70, 0x6b, 0x67, 0x2f, 0x73, 0x65, 0x6e, 0x74, 0x72, 0x79, 0x2f, 0x75, 0x6e, 0x69, + 0x6d, 0x70, 0x6c, 0x2f, 0x75, 0x6e, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x65, + 0x64, 0x5f, 0x73, 0x79, 0x73, 0x63, 0x61, 0x6c, 0x6c, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, + 0x06, 0x67, 0x76, 0x69, 0x73, 0x6f, 0x72, 0x1a, 0x1f, 0x70, 0x6b, 0x67, 0x2f, 0x73, 0x65, 0x6e, + 0x74, 0x72, 0x79, 0x2f, 0x61, 0x72, 0x63, 0x68, 0x2f, 0x72, 0x65, 0x67, 0x69, 0x73, 0x74, 0x65, + 0x72, 0x73, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x59, 0x0a, 0x14, 0x55, 0x6e, 0x69, 0x6d, + 0x70, 0x6c, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x65, 0x64, 0x53, 0x79, 0x73, 0x63, 0x61, 0x6c, 0x6c, + 0x12, 0x10, 0x0a, 0x03, 0x74, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x52, 0x03, 0x74, + 0x69, 0x64, 0x12, 0x2f, 0x0a, 0x09, 0x72, 0x65, 0x67, 0x69, 0x73, 0x74, 0x65, 0x72, 0x73, 0x18, + 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x11, 0x2e, 0x67, 0x76, 0x69, 0x73, 0x6f, 0x72, 0x2e, 0x52, + 0x65, 0x67, 0x69, 0x73, 0x74, 0x65, 0x72, 0x73, 0x52, 0x09, 0x72, 0x65, 0x67, 0x69, 0x73, 0x74, + 0x65, 0x72, 0x73, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, +} + +var ( + file_pkg_sentry_unimpl_unimplemented_syscall_proto_rawDescOnce sync.Once + file_pkg_sentry_unimpl_unimplemented_syscall_proto_rawDescData = file_pkg_sentry_unimpl_unimplemented_syscall_proto_rawDesc +) + +func file_pkg_sentry_unimpl_unimplemented_syscall_proto_rawDescGZIP() []byte { + file_pkg_sentry_unimpl_unimplemented_syscall_proto_rawDescOnce.Do(func() { + file_pkg_sentry_unimpl_unimplemented_syscall_proto_rawDescData = protoimpl.X.CompressGZIP(file_pkg_sentry_unimpl_unimplemented_syscall_proto_rawDescData) + }) + return file_pkg_sentry_unimpl_unimplemented_syscall_proto_rawDescData +} + +var file_pkg_sentry_unimpl_unimplemented_syscall_proto_msgTypes = make([]protoimpl.MessageInfo, 1) +var file_pkg_sentry_unimpl_unimplemented_syscall_proto_goTypes = []interface{}{ + (*UnimplementedSyscall)(nil), // 0: gvisor.UnimplementedSyscall + (*registers_go_proto.Registers)(nil), // 1: gvisor.Registers +} +var file_pkg_sentry_unimpl_unimplemented_syscall_proto_depIdxs = []int32{ + 1, // 0: gvisor.UnimplementedSyscall.registers:type_name -> gvisor.Registers + 1, // [1:1] is the sub-list for method output_type + 1, // [1:1] is the sub-list for method input_type + 1, // [1:1] is the sub-list for extension type_name + 1, // [1:1] is the sub-list for extension extendee + 0, // [0:1] is the sub-list for field type_name +} + +func init() { file_pkg_sentry_unimpl_unimplemented_syscall_proto_init() } +func file_pkg_sentry_unimpl_unimplemented_syscall_proto_init() { + if File_pkg_sentry_unimpl_unimplemented_syscall_proto != nil { + return + } + if !protoimpl.UnsafeEnabled { + file_pkg_sentry_unimpl_unimplemented_syscall_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*UnimplementedSyscall); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: file_pkg_sentry_unimpl_unimplemented_syscall_proto_rawDesc, + NumEnums: 0, + NumMessages: 1, + NumExtensions: 0, + NumServices: 0, + }, + GoTypes: file_pkg_sentry_unimpl_unimplemented_syscall_proto_goTypes, + DependencyIndexes: file_pkg_sentry_unimpl_unimplemented_syscall_proto_depIdxs, + MessageInfos: file_pkg_sentry_unimpl_unimplemented_syscall_proto_msgTypes, + }.Build() + File_pkg_sentry_unimpl_unimplemented_syscall_proto = out.File + file_pkg_sentry_unimpl_unimplemented_syscall_proto_rawDesc = nil + file_pkg_sentry_unimpl_unimplemented_syscall_proto_goTypes = nil + file_pkg_sentry_unimpl_unimplemented_syscall_proto_depIdxs = nil +} diff --git a/pkg/sentry/uniqueid/BUILD b/pkg/sentry/uniqueid/BUILD deleted file mode 100644 index 7467e6398..000000000 --- a/pkg/sentry/uniqueid/BUILD +++ /dev/null @@ -1,13 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "uniqueid", - srcs = ["context.go"], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/context", - "//pkg/sentry/socket/unix/transport", - ], -) diff --git a/pkg/sentry/uniqueid/uniqueid_state_autogen.go b/pkg/sentry/uniqueid/uniqueid_state_autogen.go new file mode 100644 index 000000000..1890fdf46 --- /dev/null +++ b/pkg/sentry/uniqueid/uniqueid_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package uniqueid diff --git a/pkg/sentry/usage/BUILD b/pkg/sentry/usage/BUILD deleted file mode 100644 index 8e2b3ed79..000000000 --- a/pkg/sentry/usage/BUILD +++ /dev/null @@ -1,23 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "usage", - srcs = [ - "cpu.go", - "io.go", - "memory.go", - "memory_unsafe.go", - "usage.go", - ], - visibility = [ - "//:sandbox", - ], - deps = [ - "//pkg/bits", - "//pkg/memutil", - "//pkg/sync", - "@org_golang_x_sys//unix:go_default_library", - ], -) diff --git a/pkg/sentry/usage/usage_state_autogen.go b/pkg/sentry/usage/usage_state_autogen.go new file mode 100644 index 000000000..a0077a6ca --- /dev/null +++ b/pkg/sentry/usage/usage_state_autogen.go @@ -0,0 +1,86 @@ +// automatically generated by stateify. + +package usage + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (s *CPUStats) StateTypeName() string { + return "pkg/sentry/usage.CPUStats" +} + +func (s *CPUStats) StateFields() []string { + return []string{ + "UserTime", + "SysTime", + "VoluntarySwitches", + } +} + +func (s *CPUStats) beforeSave() {} + +// +checklocksignore +func (s *CPUStats) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + stateSinkObject.Save(0, &s.UserTime) + stateSinkObject.Save(1, &s.SysTime) + stateSinkObject.Save(2, &s.VoluntarySwitches) +} + +func (s *CPUStats) afterLoad() {} + +// +checklocksignore +func (s *CPUStats) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.UserTime) + stateSourceObject.Load(1, &s.SysTime) + stateSourceObject.Load(2, &s.VoluntarySwitches) +} + +func (i *IO) StateTypeName() string { + return "pkg/sentry/usage.IO" +} + +func (i *IO) StateFields() []string { + return []string{ + "CharsRead", + "CharsWritten", + "ReadSyscalls", + "WriteSyscalls", + "BytesRead", + "BytesWritten", + "BytesWriteCancelled", + } +} + +func (i *IO) beforeSave() {} + +// +checklocksignore +func (i *IO) StateSave(stateSinkObject state.Sink) { + i.beforeSave() + stateSinkObject.Save(0, &i.CharsRead) + stateSinkObject.Save(1, &i.CharsWritten) + stateSinkObject.Save(2, &i.ReadSyscalls) + stateSinkObject.Save(3, &i.WriteSyscalls) + stateSinkObject.Save(4, &i.BytesRead) + stateSinkObject.Save(5, &i.BytesWritten) + stateSinkObject.Save(6, &i.BytesWriteCancelled) +} + +func (i *IO) afterLoad() {} + +// +checklocksignore +func (i *IO) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &i.CharsRead) + stateSourceObject.Load(1, &i.CharsWritten) + stateSourceObject.Load(2, &i.ReadSyscalls) + stateSourceObject.Load(3, &i.WriteSyscalls) + stateSourceObject.Load(4, &i.BytesRead) + stateSourceObject.Load(5, &i.BytesWritten) + stateSourceObject.Load(6, &i.BytesWriteCancelled) +} + +func init() { + state.Register((*CPUStats)(nil)) + state.Register((*IO)(nil)) +} diff --git a/pkg/sentry/usage/usage_unsafe_state_autogen.go b/pkg/sentry/usage/usage_unsafe_state_autogen.go new file mode 100644 index 000000000..5d9e86efa --- /dev/null +++ b/pkg/sentry/usage/usage_unsafe_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package usage diff --git a/pkg/sentry/vfs/BUILD b/pkg/sentry/vfs/BUILD deleted file mode 100644 index 914574543..000000000 --- a/pkg/sentry/vfs/BUILD +++ /dev/null @@ -1,141 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -licenses(["notice"]) - -go_template_instance( - name = "epoll_interest_list", - out = "epoll_interest_list.go", - package = "vfs", - prefix = "epollInterest", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*epollInterest", - "Linker": "*epollInterest", - }, -) - -go_template_instance( - name = "event_list", - out = "event_list.go", - package = "vfs", - prefix = "event", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*Event", - "Linker": "*Event", - }, -) - -go_template_instance( - name = "file_description_refs", - out = "file_description_refs.go", - package = "vfs", - prefix = "FileDescription", - template = "//pkg/refsvfs2:refs_template", - types = { - "T": "FileDescription", - }, -) - -go_template_instance( - name = "mount_namespace_refs", - out = "mount_namespace_refs.go", - package = "vfs", - prefix = "MountNamespace", - template = "//pkg/refsvfs2:refs_template", - types = { - "T": "MountNamespace", - }, -) - -go_template_instance( - name = "filesystem_refs", - out = "filesystem_refs.go", - package = "vfs", - prefix = "Filesystem", - template = "//pkg/refsvfs2:refs_template", - types = { - "T": "Filesystem", - }, -) - -go_library( - name = "vfs", - srcs = [ - "anonfs.go", - "context.go", - "debug.go", - "dentry.go", - "device.go", - "epoll.go", - "epoll_interest_list.go", - "event_list.go", - "file_description.go", - "file_description_impl_util.go", - "file_description_refs.go", - "filesystem.go", - "filesystem_impl_util.go", - "filesystem_refs.go", - "filesystem_type.go", - "inotify.go", - "lock.go", - "mount.go", - "mount_namespace_refs.go", - "mount_unsafe.go", - "opath.go", - "options.go", - "pathname.go", - "permissions.go", - "resolving_path.go", - "save_restore.go", - "vfs.go", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/errors/linuxerr", - "//pkg/fd", - "//pkg/fdnotifier", - "//pkg/fspath", - "//pkg/gohacks", - "//pkg/hostarch", - "//pkg/log", - "//pkg/refs", - "//pkg/refsvfs2", - "//pkg/safemem", - "//pkg/sentry/arch", - "//pkg/sentry/fs", - "//pkg/sentry/fs/lock", - "//pkg/sentry/fsmetric", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/kernel/time", - "//pkg/sentry/limits", - "//pkg/sentry/memmap", - "//pkg/sentry/socket/unix/transport", - "//pkg/sentry/uniqueid", - "//pkg/sync", - "//pkg/usermem", - "//pkg/waiter", - "@org_golang_x_sys//unix:go_default_library", - ], -) - -go_test( - name = "vfs_test", - size = "small", - srcs = [ - "file_description_impl_util_test.go", - "mount_test.go", - ], - library = ":vfs", - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/errors/linuxerr", - "//pkg/sentry/contexttest", - "//pkg/sync", - "//pkg/usermem", - ], -) diff --git a/pkg/sentry/vfs/README.md b/pkg/sentry/vfs/README.md deleted file mode 100644 index 5aad31b78..000000000 --- a/pkg/sentry/vfs/README.md +++ /dev/null @@ -1,186 +0,0 @@ -# The gVisor Virtual Filesystem - -THIS PACKAGE IS CURRENTLY EXPERIMENTAL AND NOT READY OR ENABLED FOR PRODUCTION -USE. For the filesystem implementation currently used by gVisor, see the `fs` -package. - -## Implementation Notes - -### Reference Counting - -Filesystem, Dentry, Mount, MountNamespace, and FileDescription are all -reference-counted. Mount and MountNamespace are exclusively VFS-managed; when -their reference count reaches zero, VFS releases their resources. Filesystem and -FileDescription management is shared between VFS and filesystem implementations; -when their reference count reaches zero, VFS notifies the implementation by -calling `FilesystemImpl.Release()` or `FileDescriptionImpl.Release()` -respectively and then releases VFS-owned resources. Dentries are exclusively -managed by filesystem implementations; reference count changes are abstracted -through DentryImpl, which should release resources when reference count reaches -zero. - -Filesystem references are held by: - -- Mount: Each referenced Mount holds a reference on the mounted Filesystem. - -Dentry references are held by: - -- FileDescription: Each referenced FileDescription holds a reference on the - Dentry through which it was opened, via `FileDescription.vd.dentry`. - -- Mount: Each referenced Mount holds a reference on its mount point and on the - mounted filesystem root. The mount point is mutable (`mount(MS_MOVE)`). - -Mount references are held by: - -- FileDescription: Each referenced FileDescription holds a reference on the - Mount on which it was opened, via `FileDescription.vd.mount`. - -- Mount: Each referenced Mount holds a reference on its parent, which is the - mount containing its mount point. - -- VirtualFilesystem: A reference is held on each Mount that has been connected - to a mount point, but not yet umounted. - -MountNamespace and FileDescription references are held by users of VFS. The -expectation is that each `kernel.Task` holds a reference on its corresponding -MountNamespace, and each file descriptor holds a reference on its represented -FileDescription. - -Notes: - -- Dentries do not hold a reference on their owning Filesystem. Instead, all - uses of a Dentry occur in the context of a Mount, which holds a reference on - the relevant Filesystem (see e.g. the VirtualDentry type). As a corollary, - when releasing references on both a Dentry and its corresponding Mount, the - Dentry's reference must be released first (because releasing the Mount's - reference may release the last reference on the Filesystem, whose state may - be required to release the Dentry reference). - -### The Inheritance Pattern - -Filesystem, Dentry, and FileDescription are all concepts featuring both state -that must be shared between VFS and filesystem implementations, and operations -that are implementation-defined. To facilitate this, each of these three -concepts follows the same pattern, shown below for Dentry: - -```go -// Dentry represents a node in a filesystem tree. -type Dentry struct { - // VFS-required dentry state. - parent *Dentry - // ... - - // impl is the DentryImpl associated with this Dentry. impl is immutable. - // This should be the last field in Dentry. - impl DentryImpl -} - -// Init must be called before first use of d. -func (d *Dentry) Init(impl DentryImpl) { - d.impl = impl -} - -// Impl returns the DentryImpl associated with d. -func (d *Dentry) Impl() DentryImpl { - return d.impl -} - -// DentryImpl contains implementation-specific details of a Dentry. -// Implementations of DentryImpl should contain their associated Dentry by -// value as their first field. -type DentryImpl interface { - // VFS-required implementation-defined dentry operations. - IncRef() - // ... -} -``` - -This construction, which is essentially a type-safe analogue to Linux's -`container_of` pattern, has the following properties: - -- VFS works almost exclusively with pointers to Dentry rather than DentryImpl - interface objects, such as in the type of `Dentry.parent`. This avoids - interface method calls (which are somewhat expensive to perform, and defeat - inlining and escape analysis), reduces the size of VFS types (since an - interface object is two pointers in size), and allows pointers to be loaded - and stored atomically using `sync/atomic`. Implementation-defined behavior - is accessed via `Dentry.impl` when required. - -- Filesystem implementations can access the implementation-defined state - associated with objects of VFS types by type-asserting or type-switching - (e.g. `Dentry.Impl().(*myDentry)`). Type assertions to a concrete type - require only an equality comparison of the interface object's type pointer - to a static constant, and are consequently very fast. - -- Filesystem implementations can access the VFS state associated with objects - of implementation-defined types directly. - -- VFS and implementation-defined state for a given type occupy the same - object, minimizing memory allocations and maximizing memory locality. `impl` - is the last field in `Dentry`, and `Dentry` is the first field in - `DentryImpl` implementations, for similar reasons: this tends to cause - fetching of the `Dentry.impl` interface object to also fetch `DentryImpl` - fields, either because they are in the same cache line or via next-line - prefetching. - -## Future Work - -- Most `mount(2)` features, and unmounting, are incomplete. - -- VFS1 filesystems are not directly compatible with VFS2. It may be possible - to implement shims that implement `vfs.FilesystemImpl` for - `fs.MountNamespace`, `vfs.DentryImpl` for `fs.Dirent`, and - `vfs.FileDescriptionImpl` for `fs.File`, which may be adequate for - filesystems that are not performance-critical (e.g. sysfs); however, it is - not clear that this will be less effort than simply porting the filesystems - in question. Practically speaking, the following filesystems will probably - need to be ported or made compatible through a shim to evaluate filesystem - performance on realistic workloads: - - - devfs/procfs/sysfs, which will realistically be necessary to execute - most applications. (Note that procfs and sysfs do not support hard - links, so they do not require the complexity of separate inode objects. - Also note that Linux's /dev is actually a variant of tmpfs called - devtmpfs.) - - - tmpfs. This should be relatively straightforward: copy/paste memfs, - store regular file contents in pgalloc-allocated memory instead of - `[]byte`, and add support for file timestamps. (In fact, it probably - makes more sense to convert memfs to tmpfs and not keep the former.) - - - A remote filesystem, either lisafs (if it is ready by the time that - other benchmarking prerequisites are) or v9fs (aka 9P, aka gofers). - - - epoll files. - - Filesystems that will need to be ported before switching to VFS2, but can - probably be skipped for early testing: - - - overlayfs, which is needed for (at least) synthetic mount points. - - - Support for host ttys. - - - timerfd files. - - Filesystems that can be probably dropped: - - - ashmem, which is far too incomplete to use. - - - binder, which is similarly far too incomplete to use. - -- Save/restore. For instance, it is unclear if the current implementation of - the `state` package supports the inheritance pattern described above. - -- Many features that were previously implemented by VFS must now be - implemented by individual filesystems (though, in most cases, this should - consist of calls to hooks or libraries provided by `vfs` or other packages). - This includes, but is not necessarily limited to: - - - Block and character device special files - - - Inotify - - - File locking - - - `O_ASYNC` diff --git a/pkg/sentry/vfs/epoll_interest_list.go b/pkg/sentry/vfs/epoll_interest_list.go new file mode 100644 index 000000000..e75ea361b --- /dev/null +++ b/pkg/sentry/vfs/epoll_interest_list.go @@ -0,0 +1,221 @@ +package vfs + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type epollInterestElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (epollInterestElementMapper) linkerFor(elem *epollInterest) *epollInterest { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type epollInterestList struct { + head *epollInterest + tail *epollInterest +} + +// Reset resets list l to the empty state. +func (l *epollInterestList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +// +//go:nosplit +func (l *epollInterestList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +// +//go:nosplit +func (l *epollInterestList) Front() *epollInterest { + return l.head +} + +// Back returns the last element of list l or nil. +// +//go:nosplit +func (l *epollInterestList) Back() *epollInterest { + return l.tail +} + +// Len returns the number of elements in the list. +// +// NOTE: This is an O(n) operation. +// +//go:nosplit +func (l *epollInterestList) Len() (count int) { + for e := l.Front(); e != nil; e = (epollInterestElementMapper{}.linkerFor(e)).Next() { + count++ + } + return count +} + +// PushFront inserts the element e at the front of list l. +// +//go:nosplit +func (l *epollInterestList) PushFront(e *epollInterest) { + linker := epollInterestElementMapper{}.linkerFor(e) + linker.SetNext(l.head) + linker.SetPrev(nil) + if l.head != nil { + epollInterestElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +// +//go:nosplit +func (l *epollInterestList) PushBack(e *epollInterest) { + linker := epollInterestElementMapper{}.linkerFor(e) + linker.SetNext(nil) + linker.SetPrev(l.tail) + if l.tail != nil { + epollInterestElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +// +//go:nosplit +func (l *epollInterestList) PushBackList(m *epollInterestList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + epollInterestElementMapper{}.linkerFor(l.tail).SetNext(m.head) + epollInterestElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +// +//go:nosplit +func (l *epollInterestList) InsertAfter(b, e *epollInterest) { + bLinker := epollInterestElementMapper{}.linkerFor(b) + eLinker := epollInterestElementMapper{}.linkerFor(e) + + a := bLinker.Next() + + eLinker.SetNext(a) + eLinker.SetPrev(b) + bLinker.SetNext(e) + + if a != nil { + epollInterestElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +// +//go:nosplit +func (l *epollInterestList) InsertBefore(a, e *epollInterest) { + aLinker := epollInterestElementMapper{}.linkerFor(a) + eLinker := epollInterestElementMapper{}.linkerFor(e) + + b := aLinker.Prev() + eLinker.SetNext(a) + eLinker.SetPrev(b) + aLinker.SetPrev(e) + + if b != nil { + epollInterestElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +// +//go:nosplit +func (l *epollInterestList) Remove(e *epollInterest) { + linker := epollInterestElementMapper{}.linkerFor(e) + prev := linker.Prev() + next := linker.Next() + + if prev != nil { + epollInterestElementMapper{}.linkerFor(prev).SetNext(next) + } else if l.head == e { + l.head = next + } + + if next != nil { + epollInterestElementMapper{}.linkerFor(next).SetPrev(prev) + } else if l.tail == e { + l.tail = prev + } + + linker.SetNext(nil) + linker.SetPrev(nil) +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type epollInterestEntry struct { + next *epollInterest + prev *epollInterest +} + +// Next returns the entry that follows e in the list. +// +//go:nosplit +func (e *epollInterestEntry) Next() *epollInterest { + return e.next +} + +// Prev returns the entry that precedes e in the list. +// +//go:nosplit +func (e *epollInterestEntry) Prev() *epollInterest { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +// +//go:nosplit +func (e *epollInterestEntry) SetNext(elem *epollInterest) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +// +//go:nosplit +func (e *epollInterestEntry) SetPrev(elem *epollInterest) { + e.prev = elem +} diff --git a/pkg/sentry/vfs/event_list.go b/pkg/sentry/vfs/event_list.go new file mode 100644 index 000000000..c0946b585 --- /dev/null +++ b/pkg/sentry/vfs/event_list.go @@ -0,0 +1,221 @@ +package vfs + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type eventElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (eventElementMapper) linkerFor(elem *Event) *Event { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type eventList struct { + head *Event + tail *Event +} + +// Reset resets list l to the empty state. +func (l *eventList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +// +//go:nosplit +func (l *eventList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +// +//go:nosplit +func (l *eventList) Front() *Event { + return l.head +} + +// Back returns the last element of list l or nil. +// +//go:nosplit +func (l *eventList) Back() *Event { + return l.tail +} + +// Len returns the number of elements in the list. +// +// NOTE: This is an O(n) operation. +// +//go:nosplit +func (l *eventList) Len() (count int) { + for e := l.Front(); e != nil; e = (eventElementMapper{}.linkerFor(e)).Next() { + count++ + } + return count +} + +// PushFront inserts the element e at the front of list l. +// +//go:nosplit +func (l *eventList) PushFront(e *Event) { + linker := eventElementMapper{}.linkerFor(e) + linker.SetNext(l.head) + linker.SetPrev(nil) + if l.head != nil { + eventElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +// +//go:nosplit +func (l *eventList) PushBack(e *Event) { + linker := eventElementMapper{}.linkerFor(e) + linker.SetNext(nil) + linker.SetPrev(l.tail) + if l.tail != nil { + eventElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +// +//go:nosplit +func (l *eventList) PushBackList(m *eventList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + eventElementMapper{}.linkerFor(l.tail).SetNext(m.head) + eventElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +// +//go:nosplit +func (l *eventList) InsertAfter(b, e *Event) { + bLinker := eventElementMapper{}.linkerFor(b) + eLinker := eventElementMapper{}.linkerFor(e) + + a := bLinker.Next() + + eLinker.SetNext(a) + eLinker.SetPrev(b) + bLinker.SetNext(e) + + if a != nil { + eventElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +// +//go:nosplit +func (l *eventList) InsertBefore(a, e *Event) { + aLinker := eventElementMapper{}.linkerFor(a) + eLinker := eventElementMapper{}.linkerFor(e) + + b := aLinker.Prev() + eLinker.SetNext(a) + eLinker.SetPrev(b) + aLinker.SetPrev(e) + + if b != nil { + eventElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +// +//go:nosplit +func (l *eventList) Remove(e *Event) { + linker := eventElementMapper{}.linkerFor(e) + prev := linker.Prev() + next := linker.Next() + + if prev != nil { + eventElementMapper{}.linkerFor(prev).SetNext(next) + } else if l.head == e { + l.head = next + } + + if next != nil { + eventElementMapper{}.linkerFor(next).SetPrev(prev) + } else if l.tail == e { + l.tail = prev + } + + linker.SetNext(nil) + linker.SetPrev(nil) +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type eventEntry struct { + next *Event + prev *Event +} + +// Next returns the entry that follows e in the list. +// +//go:nosplit +func (e *eventEntry) Next() *Event { + return e.next +} + +// Prev returns the entry that precedes e in the list. +// +//go:nosplit +func (e *eventEntry) Prev() *Event { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +// +//go:nosplit +func (e *eventEntry) SetNext(elem *Event) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +// +//go:nosplit +func (e *eventEntry) SetPrev(elem *Event) { + e.prev = elem +} diff --git a/pkg/sentry/vfs/file_description_impl_util_test.go b/pkg/sentry/vfs/file_description_impl_util_test.go deleted file mode 100644 index e34a8c11b..000000000 --- a/pkg/sentry/vfs/file_description_impl_util_test.go +++ /dev/null @@ -1,224 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package vfs - -import ( - "bytes" - "fmt" - "io" - "sync/atomic" - "testing" - - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/errors/linuxerr" - "gvisor.dev/gvisor/pkg/sentry/contexttest" - "gvisor.dev/gvisor/pkg/usermem" -) - -// fileDescription is the common fd struct which a filesystem implementation -// embeds in all of its file description implementations as required. -type fileDescription struct { - vfsfd FileDescription - FileDescriptionDefaultImpl - NoLockFD -} - -// genCount contains the number of times its DynamicBytesSource.Generate() -// implementation has been called. -type genCount struct { - count uint64 // accessed using atomic memory ops -} - -// Generate implements DynamicBytesSource.Generate. -func (g *genCount) Generate(ctx context.Context, buf *bytes.Buffer) error { - fmt.Fprintf(buf, "%d", atomic.AddUint64(&g.count, 1)) - return nil -} - -type storeData struct { - data string -} - -var _ WritableDynamicBytesSource = (*storeData)(nil) - -// Generate implements DynamicBytesSource. -func (d *storeData) Generate(ctx context.Context, buf *bytes.Buffer) error { - buf.WriteString(d.data) - return nil -} - -// Generate implements WritableDynamicBytesSource. -func (d *storeData) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) { - buf := make([]byte, src.NumBytes()) - n, err := src.CopyIn(ctx, buf) - if err != nil { - return 0, err - } - - d.data = string(buf[:n]) - return 0, nil -} - -// testFD is a read-only FileDescriptionImpl representing a regular file. -type testFD struct { - fileDescription - DynamicBytesFileDescriptionImpl - - data DynamicBytesSource -} - -func newTestFD(ctx context.Context, vfsObj *VirtualFilesystem, statusFlags uint32, data DynamicBytesSource) *FileDescription { - vd := vfsObj.NewAnonVirtualDentry("genCountFD") - defer vd.DecRef(ctx) - var fd testFD - fd.vfsfd.Init(&fd, statusFlags, vd.Mount(), vd.Dentry(), &FileDescriptionOptions{}) - fd.DynamicBytesFileDescriptionImpl.SetDataSource(data) - return &fd.vfsfd -} - -// Release implements FileDescriptionImpl.Release. -func (fd *testFD) Release(context.Context) { -} - -// SetStatusFlags implements FileDescriptionImpl.SetStatusFlags. -// Stat implements FileDescriptionImpl.Stat. -func (fd *testFD) Stat(ctx context.Context, opts StatOptions) (linux.Statx, error) { - // Note that Statx.Mask == 0 in the return value. - return linux.Statx{}, nil -} - -// SetStat implements FileDescriptionImpl.SetStat. -func (fd *testFD) SetStat(ctx context.Context, opts SetStatOptions) error { - return linuxerr.EPERM -} - -func TestGenCountFD(t *testing.T) { - ctx := contexttest.Context(t) - - vfsObj := &VirtualFilesystem{} - if err := vfsObj.Init(ctx); err != nil { - t.Fatalf("VFS init: %v", err) - } - fd := newTestFD(ctx, vfsObj, linux.O_RDWR, &genCount{}) - defer fd.DecRef(ctx) - - // The first read causes Generate to be called to fill the FD's buffer. - buf := make([]byte, 2) - ioseq := usermem.BytesIOSequence(buf) - n, err := fd.Read(ctx, ioseq, ReadOptions{}) - if n != 1 || (err != nil && err != io.EOF) { - t.Fatalf("first Read: got (%d, %v), wanted (1, nil or EOF)", n, err) - } - if want := byte('1'); buf[0] != want { - t.Errorf("first Read: got byte %c, wanted %c", buf[0], want) - } - - // A second read without seeking is still at EOF. - n, err = fd.Read(ctx, ioseq, ReadOptions{}) - if n != 0 || err != io.EOF { - t.Fatalf("second Read: got (%d, %v), wanted (0, EOF)", n, err) - } - - // Seeking to the beginning of the file causes it to be regenerated. - n, err = fd.Seek(ctx, 0, linux.SEEK_SET) - if n != 0 || err != nil { - t.Fatalf("Seek: got (%d, %v), wanted (0, nil)", n, err) - } - n, err = fd.Read(ctx, ioseq, ReadOptions{}) - if n != 1 || (err != nil && err != io.EOF) { - t.Fatalf("Read after Seek: got (%d, %v), wanted (1, nil or EOF)", n, err) - } - if want := byte('2'); buf[0] != want { - t.Errorf("Read after Seek: got byte %c, wanted %c", buf[0], want) - } - - // PRead at the beginning of the file also causes it to be regenerated. - n, err = fd.PRead(ctx, ioseq, 0, ReadOptions{}) - if n != 1 || (err != nil && err != io.EOF) { - t.Fatalf("PRead: got (%d, %v), wanted (1, nil or EOF)", n, err) - } - if want := byte('3'); buf[0] != want { - t.Errorf("PRead: got byte %c, wanted %c", buf[0], want) - } - - // Write and PWrite fails. - if _, err := fd.Write(ctx, ioseq, WriteOptions{}); !linuxerr.Equals(linuxerr.EIO, err) { - t.Errorf("Write: got err %v, wanted %v", err, linuxerr.EIO) - } - if _, err := fd.PWrite(ctx, ioseq, 0, WriteOptions{}); !linuxerr.Equals(linuxerr.EIO, err) { - t.Errorf("Write: got err %v, wanted %v", err, linuxerr.EIO) - } -} - -func TestWritable(t *testing.T) { - ctx := contexttest.Context(t) - - vfsObj := &VirtualFilesystem{} - if err := vfsObj.Init(ctx); err != nil { - t.Fatalf("VFS init: %v", err) - } - fd := newTestFD(ctx, vfsObj, linux.O_RDWR, &storeData{data: "init"}) - defer fd.DecRef(ctx) - - buf := make([]byte, 10) - ioseq := usermem.BytesIOSequence(buf) - if n, err := fd.Read(ctx, ioseq, ReadOptions{}); n != 4 && err != io.EOF { - t.Fatalf("Read: got (%v, %v), wanted (4, EOF)", n, err) - } - if want := "init"; want == string(buf) { - t.Fatalf("Read: got %v, wanted %v", string(buf), want) - } - - // Test PWrite. - want := "write" - writeIOSeq := usermem.BytesIOSequence([]byte(want)) - if n, err := fd.PWrite(ctx, writeIOSeq, 0, WriteOptions{}); int(n) != len(want) && err != nil { - t.Errorf("PWrite: got err (%v, %v), wanted (%v, nil)", n, err, len(want)) - } - if n, err := fd.PRead(ctx, ioseq, 0, ReadOptions{}); int(n) != len(want) && err != io.EOF { - t.Fatalf("PRead: got (%v, %v), wanted (%v, EOF)", n, err, len(want)) - } - if want == string(buf) { - t.Fatalf("PRead: got %v, wanted %v", string(buf), want) - } - - // Test Seek to 0 followed by Write. - want = "write2" - writeIOSeq = usermem.BytesIOSequence([]byte(want)) - if n, err := fd.Seek(ctx, 0, linux.SEEK_SET); n != 0 && err != nil { - t.Errorf("Seek: got err (%v, %v), wanted (0, nil)", n, err) - } - if n, err := fd.Write(ctx, writeIOSeq, WriteOptions{}); int(n) != len(want) && err != nil { - t.Errorf("Write: got err (%v, %v), wanted (%v, nil)", n, err, len(want)) - } - if n, err := fd.PRead(ctx, ioseq, 0, ReadOptions{}); int(n) != len(want) && err != io.EOF { - t.Fatalf("PRead: got (%v, %v), wanted (%v, EOF)", n, err, len(want)) - } - if want == string(buf) { - t.Fatalf("PRead: got %v, wanted %v", string(buf), want) - } - - // Test failure if offset != 0. - if n, err := fd.Seek(ctx, 1, linux.SEEK_SET); n != 0 && err != nil { - t.Errorf("Seek: got err (%v, %v), wanted (0, nil)", n, err) - } - if n, err := fd.Write(ctx, writeIOSeq, WriteOptions{}); n != 0 && !linuxerr.Equals(linuxerr.EINVAL, err) { - t.Errorf("Write: got err (%v, %v), wanted (0, EINVAL)", n, err) - } - if n, err := fd.PWrite(ctx, writeIOSeq, 2, WriteOptions{}); n != 0 && !linuxerr.Equals(linuxerr.EINVAL, err) { - t.Errorf("PWrite: got err (%v, %v), wanted (0, EINVAL)", n, err) - } -} diff --git a/pkg/sentry/vfs/file_description_refs.go b/pkg/sentry/vfs/file_description_refs.go new file mode 100644 index 000000000..9e6b7bd40 --- /dev/null +++ b/pkg/sentry/vfs/file_description_refs.go @@ -0,0 +1,140 @@ +package vfs + +import ( + "fmt" + "sync/atomic" + + "gvisor.dev/gvisor/pkg/refsvfs2" +) + +// enableLogging indicates whether reference-related events should be logged (with +// stack traces). This is false by default and should only be set to true for +// debugging purposes, as it can generate an extremely large amount of output +// and drastically degrade performance. +const FileDescriptionenableLogging = false + +// obj is used to customize logging. Note that we use a pointer to T so that +// we do not copy the entire object when passed as a format parameter. +var FileDescriptionobj *FileDescription + +// Refs implements refs.RefCounter. It keeps a reference count using atomic +// operations and calls the destructor when the count reaches zero. +// +// NOTE: Do not introduce additional fields to the Refs struct. It is used by +// many filesystem objects, and we want to keep it as small as possible (i.e., +// the same size as using an int64 directly) to avoid taking up extra cache +// space. In general, this template should not be extended at the cost of +// performance. If it does not offer enough flexibility for a particular object +// (example: b/187877947), we should implement the RefCounter/CheckedObject +// interfaces manually. +// +// +stateify savable +type FileDescriptionRefs struct { + // refCount is composed of two fields: + // + // [32-bit speculative references]:[32-bit real references] + // + // Speculative references are used for TryIncRef, to avoid a CompareAndSwap + // loop. See IncRef, DecRef and TryIncRef for details of how these fields are + // used. + refCount int64 +} + +// InitRefs initializes r with one reference and, if enabled, activates leak +// checking. +func (r *FileDescriptionRefs) InitRefs() { + atomic.StoreInt64(&r.refCount, 1) + refsvfs2.Register(r) +} + +// RefType implements refsvfs2.CheckedObject.RefType. +func (r *FileDescriptionRefs) RefType() string { + return fmt.Sprintf("%T", FileDescriptionobj)[1:] +} + +// LeakMessage implements refsvfs2.CheckedObject.LeakMessage. +func (r *FileDescriptionRefs) LeakMessage() string { + return fmt.Sprintf("[%s %p] reference count of %d instead of 0", r.RefType(), r, r.ReadRefs()) +} + +// LogRefs implements refsvfs2.CheckedObject.LogRefs. +func (r *FileDescriptionRefs) LogRefs() bool { + return FileDescriptionenableLogging +} + +// ReadRefs returns the current number of references. The returned count is +// inherently racy and is unsafe to use without external synchronization. +func (r *FileDescriptionRefs) ReadRefs() int64 { + return atomic.LoadInt64(&r.refCount) +} + +// IncRef implements refs.RefCounter.IncRef. +// +//go:nosplit +func (r *FileDescriptionRefs) IncRef() { + v := atomic.AddInt64(&r.refCount, 1) + if FileDescriptionenableLogging { + refsvfs2.LogIncRef(r, v) + } + if v <= 1 { + panic(fmt.Sprintf("Incrementing non-positive count %p on %s", r, r.RefType())) + } +} + +// TryIncRef implements refs.TryRefCounter.TryIncRef. +// +// To do this safely without a loop, a speculative reference is first acquired +// on the object. This allows multiple concurrent TryIncRef calls to distinguish +// other TryIncRef calls from genuine references held. +// +//go:nosplit +func (r *FileDescriptionRefs) TryIncRef() bool { + const speculativeRef = 1 << 32 + if v := atomic.AddInt64(&r.refCount, speculativeRef); int32(v) == 0 { + + atomic.AddInt64(&r.refCount, -speculativeRef) + return false + } + + v := atomic.AddInt64(&r.refCount, -speculativeRef+1) + if FileDescriptionenableLogging { + refsvfs2.LogTryIncRef(r, v) + } + return true +} + +// DecRef implements refs.RefCounter.DecRef. +// +// Note that speculative references are counted here. Since they were added +// prior to real references reaching zero, they will successfully convert to +// real references. In other words, we see speculative references only in the +// following case: +// +// A: TryIncRef [speculative increase => sees non-negative references] +// B: DecRef [real decrease] +// A: TryIncRef [transform speculative to real] +// +//go:nosplit +func (r *FileDescriptionRefs) DecRef(destroy func()) { + v := atomic.AddInt64(&r.refCount, -1) + if FileDescriptionenableLogging { + refsvfs2.LogDecRef(r, v) + } + switch { + case v < 0: + panic(fmt.Sprintf("Decrementing non-positive ref count %p, owned by %s", r, r.RefType())) + + case v == 0: + refsvfs2.Unregister(r) + + if destroy != nil { + destroy() + } + } +} + +func (r *FileDescriptionRefs) afterLoad() { + if r.ReadRefs() > 0 { + refsvfs2.Register(r) + } +} diff --git a/pkg/sentry/vfs/filesystem_refs.go b/pkg/sentry/vfs/filesystem_refs.go new file mode 100644 index 000000000..fc47919d0 --- /dev/null +++ b/pkg/sentry/vfs/filesystem_refs.go @@ -0,0 +1,140 @@ +package vfs + +import ( + "fmt" + "sync/atomic" + + "gvisor.dev/gvisor/pkg/refsvfs2" +) + +// enableLogging indicates whether reference-related events should be logged (with +// stack traces). This is false by default and should only be set to true for +// debugging purposes, as it can generate an extremely large amount of output +// and drastically degrade performance. +const FilesystemenableLogging = false + +// obj is used to customize logging. Note that we use a pointer to T so that +// we do not copy the entire object when passed as a format parameter. +var Filesystemobj *Filesystem + +// Refs implements refs.RefCounter. It keeps a reference count using atomic +// operations and calls the destructor when the count reaches zero. +// +// NOTE: Do not introduce additional fields to the Refs struct. It is used by +// many filesystem objects, and we want to keep it as small as possible (i.e., +// the same size as using an int64 directly) to avoid taking up extra cache +// space. In general, this template should not be extended at the cost of +// performance. If it does not offer enough flexibility for a particular object +// (example: b/187877947), we should implement the RefCounter/CheckedObject +// interfaces manually. +// +// +stateify savable +type FilesystemRefs struct { + // refCount is composed of two fields: + // + // [32-bit speculative references]:[32-bit real references] + // + // Speculative references are used for TryIncRef, to avoid a CompareAndSwap + // loop. See IncRef, DecRef and TryIncRef for details of how these fields are + // used. + refCount int64 +} + +// InitRefs initializes r with one reference and, if enabled, activates leak +// checking. +func (r *FilesystemRefs) InitRefs() { + atomic.StoreInt64(&r.refCount, 1) + refsvfs2.Register(r) +} + +// RefType implements refsvfs2.CheckedObject.RefType. +func (r *FilesystemRefs) RefType() string { + return fmt.Sprintf("%T", Filesystemobj)[1:] +} + +// LeakMessage implements refsvfs2.CheckedObject.LeakMessage. +func (r *FilesystemRefs) LeakMessage() string { + return fmt.Sprintf("[%s %p] reference count of %d instead of 0", r.RefType(), r, r.ReadRefs()) +} + +// LogRefs implements refsvfs2.CheckedObject.LogRefs. +func (r *FilesystemRefs) LogRefs() bool { + return FilesystemenableLogging +} + +// ReadRefs returns the current number of references. The returned count is +// inherently racy and is unsafe to use without external synchronization. +func (r *FilesystemRefs) ReadRefs() int64 { + return atomic.LoadInt64(&r.refCount) +} + +// IncRef implements refs.RefCounter.IncRef. +// +//go:nosplit +func (r *FilesystemRefs) IncRef() { + v := atomic.AddInt64(&r.refCount, 1) + if FilesystemenableLogging { + refsvfs2.LogIncRef(r, v) + } + if v <= 1 { + panic(fmt.Sprintf("Incrementing non-positive count %p on %s", r, r.RefType())) + } +} + +// TryIncRef implements refs.TryRefCounter.TryIncRef. +// +// To do this safely without a loop, a speculative reference is first acquired +// on the object. This allows multiple concurrent TryIncRef calls to distinguish +// other TryIncRef calls from genuine references held. +// +//go:nosplit +func (r *FilesystemRefs) TryIncRef() bool { + const speculativeRef = 1 << 32 + if v := atomic.AddInt64(&r.refCount, speculativeRef); int32(v) == 0 { + + atomic.AddInt64(&r.refCount, -speculativeRef) + return false + } + + v := atomic.AddInt64(&r.refCount, -speculativeRef+1) + if FilesystemenableLogging { + refsvfs2.LogTryIncRef(r, v) + } + return true +} + +// DecRef implements refs.RefCounter.DecRef. +// +// Note that speculative references are counted here. Since they were added +// prior to real references reaching zero, they will successfully convert to +// real references. In other words, we see speculative references only in the +// following case: +// +// A: TryIncRef [speculative increase => sees non-negative references] +// B: DecRef [real decrease] +// A: TryIncRef [transform speculative to real] +// +//go:nosplit +func (r *FilesystemRefs) DecRef(destroy func()) { + v := atomic.AddInt64(&r.refCount, -1) + if FilesystemenableLogging { + refsvfs2.LogDecRef(r, v) + } + switch { + case v < 0: + panic(fmt.Sprintf("Decrementing non-positive ref count %p, owned by %s", r, r.RefType())) + + case v == 0: + refsvfs2.Unregister(r) + + if destroy != nil { + destroy() + } + } +} + +func (r *FilesystemRefs) afterLoad() { + if r.ReadRefs() > 0 { + refsvfs2.Register(r) + } +} diff --git a/pkg/sentry/vfs/g3doc/inotify.md b/pkg/sentry/vfs/g3doc/inotify.md deleted file mode 100644 index 833db213f..000000000 --- a/pkg/sentry/vfs/g3doc/inotify.md +++ /dev/null @@ -1,210 +0,0 @@ -# Inotify - -Inotify is a mechanism for monitoring filesystem events in Linux--see -inotify(7). An inotify instance can be used to monitor files and directories for -modifications, creation/deletion, etc. The inotify API consists of system calls -that create inotify instances (inotify_init/inotify_init1) and add/remove -watches on files to an instance (inotify_add_watch/inotify_rm_watch). Events are -generated from various places in the sentry, including the syscall layer, the -vfs layer, the process fd table, and within each filesystem implementation. This -document outlines the implementation details of inotify in VFS2. - -## Inotify Objects - -Inotify data structures are implemented in the vfs package. - -### vfs.Inotify - -Inotify instances are represented by vfs.Inotify objects, which implement -vfs.FileDescriptionImpl. As in Linux, inotify fds are backed by a -pseudo-filesystem (anonfs). Each inotify instance receives events from a set of -vfs.Watch objects, which can be modified with inotify_add_watch(2) and -inotify_rm_watch(2). An application can retrieve events by reading the inotify -fd. - -### vfs.Watches - -The set of all watches held on a single file (i.e., the watch target) is stored -in vfs.Watches. Each watch will belong to a different inotify instance (an -instance can only have one watch on any watch target). The watches are stored in -a map indexed by their vfs.Inotify owner’s id. Hard links and file descriptions -to a single file will all share the same vfs.Watches (with the exception of the -gofer filesystem, described in a later section). Activity on the target causes -its vfs.Watches to generate notifications on its watches’ inotify instances. - -### vfs.Watch - -A single watch, owned by one inotify instance and applied to one watch target. -Both the vfs.Inotify owner and vfs.Watches on the target will hold a vfs.Watch, -which leads to some complicated locking behavior (see Lock Ordering). Whenever a -watch is notified of an event on its target, it will queue events to its inotify -instance for delivery to the user. - -### vfs.Event - -vfs.Event is a simple struct encapsulating all the fields for an inotify event. -It is generated by vfs.Watches and forwarded to the watches' owners. It is -serialized to the user during read(2) syscalls on the associated fs.Inotify's -fd. - -## Lock Ordering - -There are three locks related to the inotify implementation: - -Inotify.mu: the inotify instance lock. Inotify.evMu: the inotify event queue -lock. Watches.mu: the watch set lock, used to protect the collection of watches -on a target. - -The correct lock ordering for inotify code is: - -Inotify.mu -> Watches.mu -> Inotify.evMu. - -Note that we use a distinct lock to protect the inotify event queue. If we -simply used Inotify.mu, we could simultaneously have locks being acquired in the -order of Inotify.mu -> Watches.mu and Watches.mu -> Inotify.mu, which would -cause deadlocks. For instance, adding a watch to an inotify instance would -require locking Inotify.mu, and then adding the same watch to the target would -cause Watches.mu to be held. At the same time, generating an event on the target -would require Watches.mu to be held before iterating through each watch, and -then notifying the owner of each watch would cause Inotify.mu to be held. - -See the vfs package comment to understand how inotify locks fit into the overall -ordering of filesystem locks. - -## Watch Targets in Different Filesystem Implementations - -In Linux, watches reside on inodes at the virtual filesystem layer. As a result, -all hard links and file descriptions on a single file will all share the same -watch set. In VFS2, there is no common inode structure across filesystem types -(some may not even have inodes), so we have to plumb inotify support through -each specific filesystem implementation. Some of the technical considerations -are outlined below. - -### Tmpfs - -For filesystems with inodes, like tmpfs, the design is quite similar to that of -Linux, where watches reside on the inode. - -### Pseudo-filesystems - -Technically, because inotify is implemented at the vfs layer in Linux, -pseudo-filesystems on top of kernfs support inotify passively. However, watches -can only track explicit filesystem operations like read/write, open/close, -mknod, etc., so watches on a target like /proc/self/fd will not generate events -every time a new fd is added or removed. As of this writing, we leave inotify -unimplemented in kernfs and anonfs; it does not seem particularly useful. - -### Gofer Filesystem (fsimpl/gofer) - -The gofer filesystem has several traits that make it difficult to support -inotify: - -* **There are no inodes.** A file is represented as a dentry that holds an - unopened p9 file (and possibly an open FID), through which the Sentry - interacts with the gofer. - * *Solution:* Because there is no inode structure stored in the sandbox, - inotify watches must be held on the dentry. For the purposes of inotify, - we assume that every dentry corresponds to a unique inode, which may - cause unexpected behavior in the presence of hard links, where multiple - dentries should share the same set of watches. Indeed, it is impossible - for us to be absolutely sure whether dentries correspond to the same - file or not, due to the following point: -* **The Sentry cannot always be aware of hard links on the remote - filesystem.** There is no way for us to confirm whether two files on the - remote filesystem are actually links to the same inode. QIDs and inodes are - not always 1:1. The assumption that dentries and inodes are 1:1 is - inevitably broken if there are remote hard links that we cannot detect. - * *Solution:* this is an issue with gofer fs in general, not only inotify, - and we will have to live with it. -* **Dentries can be cached, and then evicted.** Dentry lifetime does not - correspond to file lifetime. Because gofer fs is not entirely in-memory, the - absence of a dentry does not mean that the corresponding file does not - exist, nor does a dentry reaching zero references mean that the - corresponding file no longer exists. When a dentry reaches zero references, - it will be cached, in case the file at that path is needed again in the - future. However, the dentry may be evicted from the cache, which will cause - a new dentry to be created next time the same file path is used. The - existing watches will be lost. - * *Solution:* When a dentry reaches zero references, do not cache it if it - has any watches, so we can avoid eviction/destruction. Note that if the - dentry was deleted or invalidated (d.vfsd.IsDead()), we should still - destroy it along with its watches. Additionally, when a dentry’s last - watch is removed, we cache it if it also has zero references. This way, - the dentry can eventually be evicted from memory if it is no longer - needed. -* **Dentries can be invalidated.** Another issue with dentry lifetime is that - the remote file at the file path represented may change from underneath the - dentry. In this case, the next time that the dentry is used, it will be - invalidated and a new dentry will replace it. In this case, it is not clear - what should be done with the watches on the old dentry. - * *Solution:* Silently destroy the watches when invalidation occurs. We - have no way of knowing exactly what happened, when it happens. Inotify - instances on NFS files in Linux probably behave in a similar fashion, - since inotify is implemented at the vfs layer and is not aware of the - complexities of remote file systems. - * An alternative would be to issue some kind of event upon invalidation, - e.g. a delete event, but this has several issues: - * We cannot discern whether the remote file was invalidated because it was - moved, deleted, etc. This information is crucial, because these cases - should result in different events. Furthermore, the watches should only - be destroyed if the file has been deleted. - * Moreover, the mechanism for detecting whether the underlying file has - changed is to check whether a new QID is given by the gofer. This may - result in false positives, e.g. suppose that the server closed and - re-opened the same file, which may result in a new QID. - * Finally, the time of the event may be completely different from the time - of the file modification, since a dentry is not immediately notified - when the underlying file has changed. It would be quite unexpected to - receive the notification when invalidation was triggered, i.e. the next - time the file was accessed within the sandbox, because then the - read/write/etc. operation on the file would not result in the expected - event. - * Another point in favor of the first solution: inotify in Linux can - already be lossy on local filesystems (one of the sacrifices made so - that filesystem performance isn’t killed), and it is lossy on NFS for - similar reasons to gofer fs. Therefore, it is better for inotify to be - silent than to emit incorrect notifications. -* **There may be external users of the remote filesystem.** We can only track - operations performed on the file within the sandbox. This is sufficient - under InteropModeExclusive, but whenever there are external users, the set - of actions we are aware of is incomplete. - * *Solution:* We could either return an error or just issue a warning when - inotify is used without InteropModeExclusive. Although faulty, VFS1 - allows it when the filesystem is shared, and Linux does the same for - remote filesystems (as mentioned above, inotify sits at the vfs level). - -## Dentry Interface - -For events that must be generated above the vfs layer, we provide the following -DentryImpl methods to allow interactions with targets on any FilesystemImpl: - -* **InotifyWithParent()** generates events on the dentry’s watches as well as - its parent’s. -* **Watches()** retrieves the watch set of the target represented by the - dentry. This is used to access and modify watches on a target. -* **OnZeroWatches()** performs cleanup tasks after the last watch is removed - from a dentry. This is needed by gofer fs, which must allow a watched dentry - to be cached once it has no more watches. Most implementations can just do - nothing. Note that OnZeroWatches() must be called after all inotify locks - are released to preserve lock ordering, since it may acquire - FilesystemImpl-specific locks. - -## IN_EXCL_UNLINK - -There are several options that can be set for a watch, specified as part of the -mask in inotify_add_watch(2). In particular, IN_EXCL_UNLINK requires some -additional support in each filesystem. - -A watch with IN_EXCL_UNLINK will not generate events for its target if it -corresponds to a path that was unlinked. For instance, if an fd is opened on -“foo/bar” and “foo/bar” is subsequently unlinked, any reads/writes/etc. on the -fd will be ignored by watches on “foo” or “foo/bar” with IN_EXCL_UNLINK. This -requires each DentryImpl to keep track of whether it has been unlinked, in order -to determine whether events should be sent to watches with IN_EXCL_UNLINK. - -## IN_ONESHOT - -One-shot watches expire after generating a single event. When an event occurs, -all one-shot watches on the target that successfully generated an event are -removed. Lock ordering can cause the management of one-shot watches to be quite -expensive; see Watches.Notify() for more information. diff --git a/pkg/sentry/vfs/genericfstree/BUILD b/pkg/sentry/vfs/genericfstree/BUILD deleted file mode 100644 index d8fd92677..000000000 --- a/pkg/sentry/vfs/genericfstree/BUILD +++ /dev/null @@ -1,16 +0,0 @@ -load("//tools/go_generics:defs.bzl", "go_template") - -package( - default_visibility = ["//:sandbox"], - licenses = ["notice"], -) - -go_template( - name = "generic_fstree", - srcs = [ - "genericfstree.go", - ], - types = [ - "Dentry", - ], -) diff --git a/pkg/sentry/vfs/genericfstree/genericfstree.go b/pkg/sentry/vfs/genericfstree/genericfstree.go deleted file mode 100644 index ba6e6ed49..000000000 --- a/pkg/sentry/vfs/genericfstree/genericfstree.go +++ /dev/null @@ -1,92 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package genericfstree provides tools for implementing vfs.FilesystemImpls -// where a single statically-determined lock or set of locks is sufficient to -// ensure that a Dentry's name and parent are contextually immutable. -// -// Clients using this package must use the go_template_instance rule in -// tools/go_generics/defs.bzl to create an instantiation of this template -// package, providing types to use in place of Dentry. -package genericfstree - -import ( - "gvisor.dev/gvisor/pkg/fspath" - "gvisor.dev/gvisor/pkg/sentry/vfs" -) - -// Dentry is a required type parameter that is a struct with the given fields. -// -// +stateify savable -type Dentry struct { - // vfsd is the embedded vfs.Dentry corresponding to this vfs.DentryImpl. - vfsd vfs.Dentry - - // parent is the parent of this Dentry in the filesystem's tree. If this - // Dentry is a filesystem root, parent is nil. - parent *Dentry - - // name is the name of this Dentry in its parent. If this Dentry is a - // filesystem root, name is unspecified. - name string -} - -// IsAncestorDentry returns true if d is an ancestor of d2; that is, d is -// either d2's parent or an ancestor of d2's parent. -func IsAncestorDentry(d, d2 *Dentry) bool { - for d2 != nil { // Stop at root, where d2.parent == nil. - if d2.parent == d { - return true - } - if d2.parent == d2 { - return false - } - d2 = d2.parent - } - return false -} - -// ParentOrSelf returns d.parent. If d.parent is nil, ParentOrSelf returns d. -func ParentOrSelf(d *Dentry) *Dentry { - if d.parent != nil { - return d.parent - } - return d -} - -// PrependPath is a generic implementation of FilesystemImpl.PrependPath(). -func PrependPath(vfsroot vfs.VirtualDentry, mnt *vfs.Mount, d *Dentry, b *fspath.Builder) error { - for { - if mnt == vfsroot.Mount() && &d.vfsd == vfsroot.Dentry() { - return vfs.PrependPathAtVFSRootError{} - } - if mnt != nil && &d.vfsd == mnt.Root() { - return nil - } - if d.parent == nil { - return vfs.PrependPathAtNonMountRootError{} - } - b.PrependComponent(d.name) - d = d.parent - } -} - -// DebugPathname returns a pathname to d relative to its filesystem root. -// DebugPathname does not correspond to any Linux function; it's used to -// generate dentry pathnames for debugging. -func DebugPathname(d *Dentry) string { - var b fspath.Builder - _ = PrependPath(vfs.VirtualDentry{}, nil, d, &b) - return b.String() -} diff --git a/pkg/sentry/vfs/memxattr/BUILD b/pkg/sentry/vfs/memxattr/BUILD deleted file mode 100644 index 444ab42b9..000000000 --- a/pkg/sentry/vfs/memxattr/BUILD +++ /dev/null @@ -1,16 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "memxattr", - srcs = ["xattr.go"], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/errors/linuxerr", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/vfs", - "//pkg/sync", - ], -) diff --git a/pkg/sentry/vfs/memxattr/memxattr_state_autogen.go b/pkg/sentry/vfs/memxattr/memxattr_state_autogen.go new file mode 100644 index 000000000..f75223725 --- /dev/null +++ b/pkg/sentry/vfs/memxattr/memxattr_state_autogen.go @@ -0,0 +1,36 @@ +// automatically generated by stateify. + +package memxattr + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *SimpleExtendedAttributes) StateTypeName() string { + return "pkg/sentry/vfs/memxattr.SimpleExtendedAttributes" +} + +func (x *SimpleExtendedAttributes) StateFields() []string { + return []string{ + "xattrs", + } +} + +func (x *SimpleExtendedAttributes) beforeSave() {} + +// +checklocksignore +func (x *SimpleExtendedAttributes) StateSave(stateSinkObject state.Sink) { + x.beforeSave() + stateSinkObject.Save(0, &x.xattrs) +} + +func (x *SimpleExtendedAttributes) afterLoad() {} + +// +checklocksignore +func (x *SimpleExtendedAttributes) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &x.xattrs) +} + +func init() { + state.Register((*SimpleExtendedAttributes)(nil)) +} diff --git a/pkg/sentry/vfs/mount_namespace_refs.go b/pkg/sentry/vfs/mount_namespace_refs.go new file mode 100644 index 000000000..176733505 --- /dev/null +++ b/pkg/sentry/vfs/mount_namespace_refs.go @@ -0,0 +1,140 @@ +package vfs + +import ( + "fmt" + "sync/atomic" + + "gvisor.dev/gvisor/pkg/refsvfs2" +) + +// enableLogging indicates whether reference-related events should be logged (with +// stack traces). This is false by default and should only be set to true for +// debugging purposes, as it can generate an extremely large amount of output +// and drastically degrade performance. +const MountNamespaceenableLogging = false + +// obj is used to customize logging. Note that we use a pointer to T so that +// we do not copy the entire object when passed as a format parameter. +var MountNamespaceobj *MountNamespace + +// Refs implements refs.RefCounter. It keeps a reference count using atomic +// operations and calls the destructor when the count reaches zero. +// +// NOTE: Do not introduce additional fields to the Refs struct. It is used by +// many filesystem objects, and we want to keep it as small as possible (i.e., +// the same size as using an int64 directly) to avoid taking up extra cache +// space. In general, this template should not be extended at the cost of +// performance. If it does not offer enough flexibility for a particular object +// (example: b/187877947), we should implement the RefCounter/CheckedObject +// interfaces manually. +// +// +stateify savable +type MountNamespaceRefs struct { + // refCount is composed of two fields: + // + // [32-bit speculative references]:[32-bit real references] + // + // Speculative references are used for TryIncRef, to avoid a CompareAndSwap + // loop. See IncRef, DecRef and TryIncRef for details of how these fields are + // used. + refCount int64 +} + +// InitRefs initializes r with one reference and, if enabled, activates leak +// checking. +func (r *MountNamespaceRefs) InitRefs() { + atomic.StoreInt64(&r.refCount, 1) + refsvfs2.Register(r) +} + +// RefType implements refsvfs2.CheckedObject.RefType. +func (r *MountNamespaceRefs) RefType() string { + return fmt.Sprintf("%T", MountNamespaceobj)[1:] +} + +// LeakMessage implements refsvfs2.CheckedObject.LeakMessage. +func (r *MountNamespaceRefs) LeakMessage() string { + return fmt.Sprintf("[%s %p] reference count of %d instead of 0", r.RefType(), r, r.ReadRefs()) +} + +// LogRefs implements refsvfs2.CheckedObject.LogRefs. +func (r *MountNamespaceRefs) LogRefs() bool { + return MountNamespaceenableLogging +} + +// ReadRefs returns the current number of references. The returned count is +// inherently racy and is unsafe to use without external synchronization. +func (r *MountNamespaceRefs) ReadRefs() int64 { + return atomic.LoadInt64(&r.refCount) +} + +// IncRef implements refs.RefCounter.IncRef. +// +//go:nosplit +func (r *MountNamespaceRefs) IncRef() { + v := atomic.AddInt64(&r.refCount, 1) + if MountNamespaceenableLogging { + refsvfs2.LogIncRef(r, v) + } + if v <= 1 { + panic(fmt.Sprintf("Incrementing non-positive count %p on %s", r, r.RefType())) + } +} + +// TryIncRef implements refs.TryRefCounter.TryIncRef. +// +// To do this safely without a loop, a speculative reference is first acquired +// on the object. This allows multiple concurrent TryIncRef calls to distinguish +// other TryIncRef calls from genuine references held. +// +//go:nosplit +func (r *MountNamespaceRefs) TryIncRef() bool { + const speculativeRef = 1 << 32 + if v := atomic.AddInt64(&r.refCount, speculativeRef); int32(v) == 0 { + + atomic.AddInt64(&r.refCount, -speculativeRef) + return false + } + + v := atomic.AddInt64(&r.refCount, -speculativeRef+1) + if MountNamespaceenableLogging { + refsvfs2.LogTryIncRef(r, v) + } + return true +} + +// DecRef implements refs.RefCounter.DecRef. +// +// Note that speculative references are counted here. Since they were added +// prior to real references reaching zero, they will successfully convert to +// real references. In other words, we see speculative references only in the +// following case: +// +// A: TryIncRef [speculative increase => sees non-negative references] +// B: DecRef [real decrease] +// A: TryIncRef [transform speculative to real] +// +//go:nosplit +func (r *MountNamespaceRefs) DecRef(destroy func()) { + v := atomic.AddInt64(&r.refCount, -1) + if MountNamespaceenableLogging { + refsvfs2.LogDecRef(r, v) + } + switch { + case v < 0: + panic(fmt.Sprintf("Decrementing non-positive ref count %p, owned by %s", r, r.RefType())) + + case v == 0: + refsvfs2.Unregister(r) + + if destroy != nil { + destroy() + } + } +} + +func (r *MountNamespaceRefs) afterLoad() { + if r.ReadRefs() > 0 { + refsvfs2.Register(r) + } +} diff --git a/pkg/sentry/vfs/mount_test.go b/pkg/sentry/vfs/mount_test.go deleted file mode 100644 index cb882a983..000000000 --- a/pkg/sentry/vfs/mount_test.go +++ /dev/null @@ -1,467 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package vfs - -import ( - "fmt" - "runtime" - "testing" - - "gvisor.dev/gvisor/pkg/sync" -) - -func TestMountTableLookupEmpty(t *testing.T) { - var mt mountTable - mt.Init() - - parent := &Mount{} - point := &Dentry{} - if m := mt.Lookup(parent, point); m != nil { - t.Errorf("Empty mountTable lookup: got %p, wanted nil", m) - } -} - -func TestMountTableInsertLookup(t *testing.T) { - var mt mountTable - mt.Init() - - mount := &Mount{} - mount.setKey(VirtualDentry{&Mount{}, &Dentry{}}) - mt.Insert(mount) - - if m := mt.Lookup(mount.parent(), mount.point()); m != mount { - t.Errorf("mountTable positive lookup: got %p, wanted %p", m, mount) - } - - otherParent := &Mount{} - if m := mt.Lookup(otherParent, mount.point()); m != nil { - t.Errorf("mountTable lookup with wrong mount parent: got %p, wanted nil", m) - } - otherPoint := &Dentry{} - if m := mt.Lookup(mount.parent(), otherPoint); m != nil { - t.Errorf("mountTable lookup with wrong mount point: got %p, wanted nil", m) - } -} - -// TODO(gvisor.dev/issue/1035): concurrent lookup/insertion/removal. - -// must be powers of 2 -var benchNumMounts = []int{1 << 2, 1 << 5, 1 << 8} - -// For all of the following: -// -// - BenchmarkMountTableFoo tests usage pattern "Foo" for mountTable. -// -// - BenchmarkMountMapFoo tests usage pattern "Foo" for a -// sync.RWMutex-protected map. (Mutator benchmarks do not use a RWMutex, since -// mountTable also requires external synchronization between mutators.) -// -// - BenchmarkMountSyncMapFoo tests usage pattern "Foo" for a sync.Map. -// -// ParallelLookup is by far the most common and performance-sensitive operation -// for this application. NegativeLookup is also important, but less so (only -// relevant with multiple mount namespaces and significant differences in -// mounts between them). Insertion and removal are benchmarked for -// completeness. -const enableComparativeBenchmarks = false - -func newBenchMount() *Mount { - mount := &Mount{} - mount.loadKey(VirtualDentry{&Mount{}, &Dentry{}}) - return mount -} - -func BenchmarkMountTableParallelLookup(b *testing.B) { - for numG, maxG := 1, runtime.GOMAXPROCS(0); numG >= 0 && numG <= maxG; numG *= 2 { - for _, numMounts := range benchNumMounts { - desc := fmt.Sprintf("%dx%d", numG, numMounts) - b.Run(desc, func(b *testing.B) { - var mt mountTable - mt.Init() - keys := make([]VirtualDentry, 0, numMounts) - for i := 0; i < numMounts; i++ { - mount := newBenchMount() - mt.Insert(mount) - keys = append(keys, mount.saveKey()) - } - - var ready sync.WaitGroup - begin := make(chan struct{}) - var end sync.WaitGroup - for g := 0; g < numG; g++ { - ready.Add(1) - end.Add(1) - go func() { - defer end.Done() - ready.Done() - <-begin - for i := 0; i < b.N; i++ { - k := keys[i&(numMounts-1)] - m := mt.Lookup(k.mount, k.dentry) - if m == nil { - b.Errorf("Lookup failed") - return - } - if parent := m.parent(); parent != k.mount { - b.Errorf("Lookup returned mount with parent %p, wanted %p", parent, k.mount) - return - } - if point := m.point(); point != k.dentry { - b.Errorf("Lookup returned mount with point %p, wanted %p", point, k.dentry) - return - } - } - }() - } - - ready.Wait() - b.ResetTimer() - close(begin) - end.Wait() - }) - } - } -} - -func BenchmarkMountMapParallelLookup(b *testing.B) { - if !enableComparativeBenchmarks { - b.Skipf("comparative benchmarks are disabled") - } - - for numG, maxG := 1, runtime.GOMAXPROCS(0); numG >= 0 && numG <= maxG; numG *= 2 { - for _, numMounts := range benchNumMounts { - desc := fmt.Sprintf("%dx%d", numG, numMounts) - b.Run(desc, func(b *testing.B) { - var mu sync.RWMutex - ms := make(map[VirtualDentry]*Mount) - keys := make([]VirtualDentry, 0, numMounts) - for i := 0; i < numMounts; i++ { - mount := newBenchMount() - key := mount.saveKey() - ms[key] = mount - keys = append(keys, key) - } - - var ready sync.WaitGroup - begin := make(chan struct{}) - var end sync.WaitGroup - for g := 0; g < numG; g++ { - ready.Add(1) - end.Add(1) - go func() { - defer end.Done() - ready.Done() - <-begin - for i := 0; i < b.N; i++ { - k := keys[i&(numMounts-1)] - mu.RLock() - m := ms[k] - mu.RUnlock() - if m == nil { - b.Errorf("Lookup failed") - return - } - if parent := m.parent(); parent != k.mount { - b.Errorf("Lookup returned mount with parent %p, wanted %p", parent, k.mount) - return - } - if point := m.point(); point != k.dentry { - b.Errorf("Lookup returned mount with point %p, wanted %p", point, k.dentry) - return - } - } - }() - } - - ready.Wait() - b.ResetTimer() - close(begin) - end.Wait() - }) - } - } -} - -func BenchmarkMountSyncMapParallelLookup(b *testing.B) { - if !enableComparativeBenchmarks { - b.Skipf("comparative benchmarks are disabled") - } - - for numG, maxG := 1, runtime.GOMAXPROCS(0); numG >= 0 && numG <= maxG; numG *= 2 { - for _, numMounts := range benchNumMounts { - desc := fmt.Sprintf("%dx%d", numG, numMounts) - b.Run(desc, func(b *testing.B) { - var ms sync.Map - keys := make([]VirtualDentry, 0, numMounts) - for i := 0; i < numMounts; i++ { - mount := newBenchMount() - key := mount.getKey() - ms.Store(key, mount) - keys = append(keys, key) - } - - var ready sync.WaitGroup - begin := make(chan struct{}) - var end sync.WaitGroup - for g := 0; g < numG; g++ { - ready.Add(1) - end.Add(1) - go func() { - defer end.Done() - ready.Done() - <-begin - for i := 0; i < b.N; i++ { - k := keys[i&(numMounts-1)] - mi, ok := ms.Load(k) - if !ok { - b.Errorf("Lookup failed") - return - } - m := mi.(*Mount) - if parent := m.parent(); parent != k.mount { - b.Errorf("Lookup returned mount with parent %p, wanted %p", parent, k.mount) - return - } - if point := m.point(); point != k.dentry { - b.Errorf("Lookup returned mount with point %p, wanted %p", point, k.dentry) - return - } - } - }() - } - - ready.Wait() - b.ResetTimer() - close(begin) - end.Wait() - }) - } - } -} - -func BenchmarkMountTableNegativeLookup(b *testing.B) { - for _, numMounts := range benchNumMounts { - desc := fmt.Sprintf("%d", numMounts) - b.Run(desc, func(b *testing.B) { - var mt mountTable - mt.Init() - for i := 0; i < numMounts; i++ { - mt.Insert(newBenchMount()) - } - negkeys := make([]VirtualDentry, 0, numMounts) - for i := 0; i < numMounts; i++ { - negkeys = append(negkeys, VirtualDentry{ - mount: &Mount{}, - dentry: &Dentry{}, - }) - } - - b.ResetTimer() - for i := 0; i < b.N; i++ { - k := negkeys[i&(numMounts-1)] - m := mt.Lookup(k.mount, k.dentry) - if m != nil { - b.Fatalf("Lookup got %p, wanted nil", m) - } - } - }) - } -} - -func BenchmarkMountMapNegativeLookup(b *testing.B) { - if !enableComparativeBenchmarks { - b.Skipf("comparative benchmarks are disabled") - } - - for _, numMounts := range benchNumMounts { - desc := fmt.Sprintf("%d", numMounts) - b.Run(desc, func(b *testing.B) { - var mu sync.RWMutex - ms := make(map[VirtualDentry]*Mount) - for i := 0; i < numMounts; i++ { - mount := newBenchMount() - ms[mount.getKey()] = mount - } - negkeys := make([]VirtualDentry, 0, numMounts) - for i := 0; i < numMounts; i++ { - negkeys = append(negkeys, VirtualDentry{ - mount: &Mount{}, - dentry: &Dentry{}, - }) - } - - b.ResetTimer() - for i := 0; i < b.N; i++ { - k := negkeys[i&(numMounts-1)] - mu.RLock() - m := ms[k] - mu.RUnlock() - if m != nil { - b.Fatalf("Lookup got %p, wanted nil", m) - } - } - }) - } -} - -func BenchmarkMountSyncMapNegativeLookup(b *testing.B) { - if !enableComparativeBenchmarks { - b.Skipf("comparative benchmarks are disabled") - } - - for _, numMounts := range benchNumMounts { - desc := fmt.Sprintf("%d", numMounts) - b.Run(desc, func(b *testing.B) { - var ms sync.Map - for i := 0; i < numMounts; i++ { - mount := newBenchMount() - ms.Store(mount.saveKey(), mount) - } - negkeys := make([]VirtualDentry, 0, numMounts) - for i := 0; i < numMounts; i++ { - negkeys = append(negkeys, VirtualDentry{ - mount: &Mount{}, - dentry: &Dentry{}, - }) - } - - b.ResetTimer() - for i := 0; i < b.N; i++ { - k := negkeys[i&(numMounts-1)] - m, _ := ms.Load(k) - if m != nil { - b.Fatalf("Lookup got %p, wanted nil", m) - } - } - }) - } -} - -func BenchmarkMountTableInsert(b *testing.B) { - // Preallocate Mounts so that allocation time isn't included in the - // benchmark. - mounts := make([]*Mount, 0, b.N) - for i := 0; i < b.N; i++ { - mounts = append(mounts, newBenchMount()) - } - - var mt mountTable - mt.Init() - b.ResetTimer() - for i := range mounts { - mt.Insert(mounts[i]) - } -} - -func BenchmarkMountMapInsert(b *testing.B) { - if !enableComparativeBenchmarks { - b.Skipf("comparative benchmarks are disabled") - } - - // Preallocate Mounts so that allocation time isn't included in the - // benchmark. - mounts := make([]*Mount, 0, b.N) - for i := 0; i < b.N; i++ { - mounts = append(mounts, newBenchMount()) - } - - ms := make(map[VirtualDentry]*Mount) - b.ResetTimer() - for i := range mounts { - mount := mounts[i] - ms[mount.saveKey()] = mount - } -} - -func BenchmarkMountSyncMapInsert(b *testing.B) { - if !enableComparativeBenchmarks { - b.Skipf("comparative benchmarks are disabled") - } - - // Preallocate Mounts so that allocation time isn't included in the - // benchmark. - mounts := make([]*Mount, 0, b.N) - for i := 0; i < b.N; i++ { - mounts = append(mounts, newBenchMount()) - } - - var ms sync.Map - b.ResetTimer() - for i := range mounts { - mount := mounts[i] - ms.Store(mount.saveKey(), mount) - } -} - -func BenchmarkMountTableRemove(b *testing.B) { - mounts := make([]*Mount, 0, b.N) - for i := 0; i < b.N; i++ { - mounts = append(mounts, newBenchMount()) - } - var mt mountTable - mt.Init() - for i := range mounts { - mt.Insert(mounts[i]) - } - - b.ResetTimer() - for i := range mounts { - mt.Remove(mounts[i]) - } -} - -func BenchmarkMountMapRemove(b *testing.B) { - if !enableComparativeBenchmarks { - b.Skipf("comparative benchmarks are disabled") - } - - mounts := make([]*Mount, 0, b.N) - for i := 0; i < b.N; i++ { - mounts = append(mounts, newBenchMount()) - } - ms := make(map[VirtualDentry]*Mount) - for i := range mounts { - mount := mounts[i] - ms[mount.saveKey()] = mount - } - - b.ResetTimer() - for i := range mounts { - mount := mounts[i] - delete(ms, mount.saveKey()) - } -} - -func BenchmarkMountSyncMapRemove(b *testing.B) { - if !enableComparativeBenchmarks { - b.Skipf("comparative benchmarks are disabled") - } - - mounts := make([]*Mount, 0, b.N) - for i := 0; i < b.N; i++ { - mounts = append(mounts, newBenchMount()) - } - var ms sync.Map - for i := range mounts { - mount := mounts[i] - ms.Store(mount.saveKey(), mount) - } - - b.ResetTimer() - for i := range mounts { - mount := mounts[i] - ms.Delete(mount.saveKey()) - } -} diff --git a/pkg/sentry/vfs/vfs_state_autogen.go b/pkg/sentry/vfs/vfs_state_autogen.go new file mode 100644 index 000000000..a460f4055 --- /dev/null +++ b/pkg/sentry/vfs/vfs_state_autogen.go @@ -0,0 +1,2055 @@ +// automatically generated by stateify. + +package vfs + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (a *anonFilesystemType) StateTypeName() string { + return "pkg/sentry/vfs.anonFilesystemType" +} + +func (a *anonFilesystemType) StateFields() []string { + return []string{} +} + +func (a *anonFilesystemType) beforeSave() {} + +// +checklocksignore +func (a *anonFilesystemType) StateSave(stateSinkObject state.Sink) { + a.beforeSave() +} + +func (a *anonFilesystemType) afterLoad() {} + +// +checklocksignore +func (a *anonFilesystemType) StateLoad(stateSourceObject state.Source) { +} + +func (fs *anonFilesystem) StateTypeName() string { + return "pkg/sentry/vfs.anonFilesystem" +} + +func (fs *anonFilesystem) StateFields() []string { + return []string{ + "vfsfs", + "devMinor", + } +} + +func (fs *anonFilesystem) beforeSave() {} + +// +checklocksignore +func (fs *anonFilesystem) StateSave(stateSinkObject state.Sink) { + fs.beforeSave() + stateSinkObject.Save(0, &fs.vfsfs) + stateSinkObject.Save(1, &fs.devMinor) +} + +func (fs *anonFilesystem) afterLoad() {} + +// +checklocksignore +func (fs *anonFilesystem) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &fs.vfsfs) + stateSourceObject.Load(1, &fs.devMinor) +} + +func (d *anonDentry) StateTypeName() string { + return "pkg/sentry/vfs.anonDentry" +} + +func (d *anonDentry) StateFields() []string { + return []string{ + "vfsd", + "name", + } +} + +func (d *anonDentry) beforeSave() {} + +// +checklocksignore +func (d *anonDentry) StateSave(stateSinkObject state.Sink) { + d.beforeSave() + stateSinkObject.Save(0, &d.vfsd) + stateSinkObject.Save(1, &d.name) +} + +func (d *anonDentry) afterLoad() {} + +// +checklocksignore +func (d *anonDentry) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &d.vfsd) + stateSourceObject.Load(1, &d.name) +} + +func (d *Dentry) StateTypeName() string { + return "pkg/sentry/vfs.Dentry" +} + +func (d *Dentry) StateFields() []string { + return []string{ + "dead", + "mounts", + "impl", + } +} + +func (d *Dentry) beforeSave() {} + +// +checklocksignore +func (d *Dentry) StateSave(stateSinkObject state.Sink) { + d.beforeSave() + stateSinkObject.Save(0, &d.dead) + stateSinkObject.Save(1, &d.mounts) + stateSinkObject.Save(2, &d.impl) +} + +func (d *Dentry) afterLoad() {} + +// +checklocksignore +func (d *Dentry) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &d.dead) + stateSourceObject.Load(1, &d.mounts) + stateSourceObject.Load(2, &d.impl) +} + +func (kind *DeviceKind) StateTypeName() string { + return "pkg/sentry/vfs.DeviceKind" +} + +func (kind *DeviceKind) StateFields() []string { + return nil +} + +func (d *devTuple) StateTypeName() string { + return "pkg/sentry/vfs.devTuple" +} + +func (d *devTuple) StateFields() []string { + return []string{ + "kind", + "major", + "minor", + } +} + +func (d *devTuple) beforeSave() {} + +// +checklocksignore +func (d *devTuple) StateSave(stateSinkObject state.Sink) { + d.beforeSave() + stateSinkObject.Save(0, &d.kind) + stateSinkObject.Save(1, &d.major) + stateSinkObject.Save(2, &d.minor) +} + +func (d *devTuple) afterLoad() {} + +// +checklocksignore +func (d *devTuple) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &d.kind) + stateSourceObject.Load(1, &d.major) + stateSourceObject.Load(2, &d.minor) +} + +func (r *registeredDevice) StateTypeName() string { + return "pkg/sentry/vfs.registeredDevice" +} + +func (r *registeredDevice) StateFields() []string { + return []string{ + "dev", + "opts", + } +} + +func (r *registeredDevice) beforeSave() {} + +// +checklocksignore +func (r *registeredDevice) StateSave(stateSinkObject state.Sink) { + r.beforeSave() + stateSinkObject.Save(0, &r.dev) + stateSinkObject.Save(1, &r.opts) +} + +func (r *registeredDevice) afterLoad() {} + +// +checklocksignore +func (r *registeredDevice) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &r.dev) + stateSourceObject.Load(1, &r.opts) +} + +func (r *RegisterDeviceOptions) StateTypeName() string { + return "pkg/sentry/vfs.RegisterDeviceOptions" +} + +func (r *RegisterDeviceOptions) StateFields() []string { + return []string{ + "GroupName", + } +} + +func (r *RegisterDeviceOptions) beforeSave() {} + +// +checklocksignore +func (r *RegisterDeviceOptions) StateSave(stateSinkObject state.Sink) { + r.beforeSave() + stateSinkObject.Save(0, &r.GroupName) +} + +func (r *RegisterDeviceOptions) afterLoad() {} + +// +checklocksignore +func (r *RegisterDeviceOptions) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &r.GroupName) +} + +func (ep *EpollInstance) StateTypeName() string { + return "pkg/sentry/vfs.EpollInstance" +} + +func (ep *EpollInstance) StateFields() []string { + return []string{ + "vfsfd", + "FileDescriptionDefaultImpl", + "DentryMetadataFileDescriptionImpl", + "NoLockFD", + "q", + "interest", + "ready", + } +} + +func (ep *EpollInstance) beforeSave() {} + +// +checklocksignore +func (ep *EpollInstance) StateSave(stateSinkObject state.Sink) { + ep.beforeSave() + stateSinkObject.Save(0, &ep.vfsfd) + stateSinkObject.Save(1, &ep.FileDescriptionDefaultImpl) + stateSinkObject.Save(2, &ep.DentryMetadataFileDescriptionImpl) + stateSinkObject.Save(3, &ep.NoLockFD) + stateSinkObject.Save(4, &ep.q) + stateSinkObject.Save(5, &ep.interest) + stateSinkObject.Save(6, &ep.ready) +} + +func (ep *EpollInstance) afterLoad() {} + +// +checklocksignore +func (ep *EpollInstance) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &ep.vfsfd) + stateSourceObject.Load(1, &ep.FileDescriptionDefaultImpl) + stateSourceObject.Load(2, &ep.DentryMetadataFileDescriptionImpl) + stateSourceObject.Load(3, &ep.NoLockFD) + stateSourceObject.Load(4, &ep.q) + stateSourceObject.Load(5, &ep.interest) + stateSourceObject.Load(6, &ep.ready) +} + +func (e *epollInterestKey) StateTypeName() string { + return "pkg/sentry/vfs.epollInterestKey" +} + +func (e *epollInterestKey) StateFields() []string { + return []string{ + "file", + "num", + } +} + +func (e *epollInterestKey) beforeSave() {} + +// +checklocksignore +func (e *epollInterestKey) StateSave(stateSinkObject state.Sink) { + e.beforeSave() + stateSinkObject.Save(0, &e.file) + stateSinkObject.Save(1, &e.num) +} + +func (e *epollInterestKey) afterLoad() {} + +// +checklocksignore +func (e *epollInterestKey) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &e.file) + stateSourceObject.Load(1, &e.num) +} + +func (epi *epollInterest) StateTypeName() string { + return "pkg/sentry/vfs.epollInterest" +} + +func (epi *epollInterest) StateFields() []string { + return []string{ + "epoll", + "key", + "waiter", + "mask", + "ready", + "epollInterestEntry", + "userData", + } +} + +func (epi *epollInterest) beforeSave() {} + +// +checklocksignore +func (epi *epollInterest) StateSave(stateSinkObject state.Sink) { + epi.beforeSave() + stateSinkObject.Save(0, &epi.epoll) + stateSinkObject.Save(1, &epi.key) + stateSinkObject.Save(2, &epi.waiter) + stateSinkObject.Save(3, &epi.mask) + stateSinkObject.Save(4, &epi.ready) + stateSinkObject.Save(5, &epi.epollInterestEntry) + stateSinkObject.Save(6, &epi.userData) +} + +// +checklocksignore +func (epi *epollInterest) StateLoad(stateSourceObject state.Source) { + stateSourceObject.LoadWait(0, &epi.epoll) + stateSourceObject.Load(1, &epi.key) + stateSourceObject.Load(2, &epi.waiter) + stateSourceObject.Load(3, &epi.mask) + stateSourceObject.Load(4, &epi.ready) + stateSourceObject.Load(5, &epi.epollInterestEntry) + stateSourceObject.Load(6, &epi.userData) + stateSourceObject.AfterLoad(epi.afterLoad) +} + +func (l *epollInterestList) StateTypeName() string { + return "pkg/sentry/vfs.epollInterestList" +} + +func (l *epollInterestList) StateFields() []string { + return []string{ + "head", + "tail", + } +} + +func (l *epollInterestList) beforeSave() {} + +// +checklocksignore +func (l *epollInterestList) StateSave(stateSinkObject state.Sink) { + l.beforeSave() + stateSinkObject.Save(0, &l.head) + stateSinkObject.Save(1, &l.tail) +} + +func (l *epollInterestList) afterLoad() {} + +// +checklocksignore +func (l *epollInterestList) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &l.head) + stateSourceObject.Load(1, &l.tail) +} + +func (e *epollInterestEntry) StateTypeName() string { + return "pkg/sentry/vfs.epollInterestEntry" +} + +func (e *epollInterestEntry) StateFields() []string { + return []string{ + "next", + "prev", + } +} + +func (e *epollInterestEntry) beforeSave() {} + +// +checklocksignore +func (e *epollInterestEntry) StateSave(stateSinkObject state.Sink) { + e.beforeSave() + stateSinkObject.Save(0, &e.next) + stateSinkObject.Save(1, &e.prev) +} + +func (e *epollInterestEntry) afterLoad() {} + +// +checklocksignore +func (e *epollInterestEntry) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &e.next) + stateSourceObject.Load(1, &e.prev) +} + +func (l *eventList) StateTypeName() string { + return "pkg/sentry/vfs.eventList" +} + +func (l *eventList) StateFields() []string { + return []string{ + "head", + "tail", + } +} + +func (l *eventList) beforeSave() {} + +// +checklocksignore +func (l *eventList) StateSave(stateSinkObject state.Sink) { + l.beforeSave() + stateSinkObject.Save(0, &l.head) + stateSinkObject.Save(1, &l.tail) +} + +func (l *eventList) afterLoad() {} + +// +checklocksignore +func (l *eventList) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &l.head) + stateSourceObject.Load(1, &l.tail) +} + +func (e *eventEntry) StateTypeName() string { + return "pkg/sentry/vfs.eventEntry" +} + +func (e *eventEntry) StateFields() []string { + return []string{ + "next", + "prev", + } +} + +func (e *eventEntry) beforeSave() {} + +// +checklocksignore +func (e *eventEntry) StateSave(stateSinkObject state.Sink) { + e.beforeSave() + stateSinkObject.Save(0, &e.next) + stateSinkObject.Save(1, &e.prev) +} + +func (e *eventEntry) afterLoad() {} + +// +checklocksignore +func (e *eventEntry) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &e.next) + stateSourceObject.Load(1, &e.prev) +} + +func (fd *FileDescription) StateTypeName() string { + return "pkg/sentry/vfs.FileDescription" +} + +func (fd *FileDescription) StateFields() []string { + return []string{ + "FileDescriptionRefs", + "statusFlags", + "asyncHandler", + "epolls", + "vd", + "opts", + "readable", + "writable", + "usedLockBSD", + "impl", + } +} + +// +checklocksignore +func (fd *FileDescription) StateSave(stateSinkObject state.Sink) { + fd.beforeSave() + stateSinkObject.Save(0, &fd.FileDescriptionRefs) + stateSinkObject.Save(1, &fd.statusFlags) + stateSinkObject.Save(2, &fd.asyncHandler) + stateSinkObject.Save(3, &fd.epolls) + stateSinkObject.Save(4, &fd.vd) + stateSinkObject.Save(5, &fd.opts) + stateSinkObject.Save(6, &fd.readable) + stateSinkObject.Save(7, &fd.writable) + stateSinkObject.Save(8, &fd.usedLockBSD) + stateSinkObject.Save(9, &fd.impl) +} + +// +checklocksignore +func (fd *FileDescription) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &fd.FileDescriptionRefs) + stateSourceObject.Load(1, &fd.statusFlags) + stateSourceObject.Load(2, &fd.asyncHandler) + stateSourceObject.Load(3, &fd.epolls) + stateSourceObject.Load(4, &fd.vd) + stateSourceObject.Load(5, &fd.opts) + stateSourceObject.Load(6, &fd.readable) + stateSourceObject.Load(7, &fd.writable) + stateSourceObject.Load(8, &fd.usedLockBSD) + stateSourceObject.Load(9, &fd.impl) + stateSourceObject.AfterLoad(fd.afterLoad) +} + +func (f *FileDescriptionOptions) StateTypeName() string { + return "pkg/sentry/vfs.FileDescriptionOptions" +} + +func (f *FileDescriptionOptions) StateFields() []string { + return []string{ + "AllowDirectIO", + "DenyPRead", + "DenyPWrite", + "UseDentryMetadata", + } +} + +func (f *FileDescriptionOptions) beforeSave() {} + +// +checklocksignore +func (f *FileDescriptionOptions) StateSave(stateSinkObject state.Sink) { + f.beforeSave() + stateSinkObject.Save(0, &f.AllowDirectIO) + stateSinkObject.Save(1, &f.DenyPRead) + stateSinkObject.Save(2, &f.DenyPWrite) + stateSinkObject.Save(3, &f.UseDentryMetadata) +} + +func (f *FileDescriptionOptions) afterLoad() {} + +// +checklocksignore +func (f *FileDescriptionOptions) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &f.AllowDirectIO) + stateSourceObject.Load(1, &f.DenyPRead) + stateSourceObject.Load(2, &f.DenyPWrite) + stateSourceObject.Load(3, &f.UseDentryMetadata) +} + +func (d *Dirent) StateTypeName() string { + return "pkg/sentry/vfs.Dirent" +} + +func (d *Dirent) StateFields() []string { + return []string{ + "Name", + "Type", + "Ino", + "NextOff", + } +} + +func (d *Dirent) beforeSave() {} + +// +checklocksignore +func (d *Dirent) StateSave(stateSinkObject state.Sink) { + d.beforeSave() + stateSinkObject.Save(0, &d.Name) + stateSinkObject.Save(1, &d.Type) + stateSinkObject.Save(2, &d.Ino) + stateSinkObject.Save(3, &d.NextOff) +} + +func (d *Dirent) afterLoad() {} + +// +checklocksignore +func (d *Dirent) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &d.Name) + stateSourceObject.Load(1, &d.Type) + stateSourceObject.Load(2, &d.Ino) + stateSourceObject.Load(3, &d.NextOff) +} + +func (f *FileDescriptionDefaultImpl) StateTypeName() string { + return "pkg/sentry/vfs.FileDescriptionDefaultImpl" +} + +func (f *FileDescriptionDefaultImpl) StateFields() []string { + return []string{} +} + +func (f *FileDescriptionDefaultImpl) beforeSave() {} + +// +checklocksignore +func (f *FileDescriptionDefaultImpl) StateSave(stateSinkObject state.Sink) { + f.beforeSave() +} + +func (f *FileDescriptionDefaultImpl) afterLoad() {} + +// +checklocksignore +func (f *FileDescriptionDefaultImpl) StateLoad(stateSourceObject state.Source) { +} + +func (d *DirectoryFileDescriptionDefaultImpl) StateTypeName() string { + return "pkg/sentry/vfs.DirectoryFileDescriptionDefaultImpl" +} + +func (d *DirectoryFileDescriptionDefaultImpl) StateFields() []string { + return []string{} +} + +func (d *DirectoryFileDescriptionDefaultImpl) beforeSave() {} + +// +checklocksignore +func (d *DirectoryFileDescriptionDefaultImpl) StateSave(stateSinkObject state.Sink) { + d.beforeSave() +} + +func (d *DirectoryFileDescriptionDefaultImpl) afterLoad() {} + +// +checklocksignore +func (d *DirectoryFileDescriptionDefaultImpl) StateLoad(stateSourceObject state.Source) { +} + +func (d *DentryMetadataFileDescriptionImpl) StateTypeName() string { + return "pkg/sentry/vfs.DentryMetadataFileDescriptionImpl" +} + +func (d *DentryMetadataFileDescriptionImpl) StateFields() []string { + return []string{} +} + +func (d *DentryMetadataFileDescriptionImpl) beforeSave() {} + +// +checklocksignore +func (d *DentryMetadataFileDescriptionImpl) StateSave(stateSinkObject state.Sink) { + d.beforeSave() +} + +func (d *DentryMetadataFileDescriptionImpl) afterLoad() {} + +// +checklocksignore +func (d *DentryMetadataFileDescriptionImpl) StateLoad(stateSourceObject state.Source) { +} + +func (s *StaticData) StateTypeName() string { + return "pkg/sentry/vfs.StaticData" +} + +func (s *StaticData) StateFields() []string { + return []string{ + "Data", + } +} + +func (s *StaticData) beforeSave() {} + +// +checklocksignore +func (s *StaticData) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + stateSinkObject.Save(0, &s.Data) +} + +func (s *StaticData) afterLoad() {} + +// +checklocksignore +func (s *StaticData) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.Data) +} + +func (fd *DynamicBytesFileDescriptionImpl) StateTypeName() string { + return "pkg/sentry/vfs.DynamicBytesFileDescriptionImpl" +} + +func (fd *DynamicBytesFileDescriptionImpl) StateFields() []string { + return []string{ + "data", + "buf", + "off", + "lastRead", + } +} + +func (fd *DynamicBytesFileDescriptionImpl) beforeSave() {} + +// +checklocksignore +func (fd *DynamicBytesFileDescriptionImpl) StateSave(stateSinkObject state.Sink) { + fd.beforeSave() + var bufValue []byte + bufValue = fd.saveBuf() + stateSinkObject.SaveValue(1, bufValue) + stateSinkObject.Save(0, &fd.data) + stateSinkObject.Save(2, &fd.off) + stateSinkObject.Save(3, &fd.lastRead) +} + +func (fd *DynamicBytesFileDescriptionImpl) afterLoad() {} + +// +checklocksignore +func (fd *DynamicBytesFileDescriptionImpl) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &fd.data) + stateSourceObject.Load(2, &fd.off) + stateSourceObject.Load(3, &fd.lastRead) + stateSourceObject.LoadValue(1, new([]byte), func(y interface{}) { fd.loadBuf(y.([]byte)) }) +} + +func (fd *LockFD) StateTypeName() string { + return "pkg/sentry/vfs.LockFD" +} + +func (fd *LockFD) StateFields() []string { + return []string{ + "locks", + } +} + +func (fd *LockFD) beforeSave() {} + +// +checklocksignore +func (fd *LockFD) StateSave(stateSinkObject state.Sink) { + fd.beforeSave() + stateSinkObject.Save(0, &fd.locks) +} + +func (fd *LockFD) afterLoad() {} + +// +checklocksignore +func (fd *LockFD) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &fd.locks) +} + +func (n *NoLockFD) StateTypeName() string { + return "pkg/sentry/vfs.NoLockFD" +} + +func (n *NoLockFD) StateFields() []string { + return []string{} +} + +func (n *NoLockFD) beforeSave() {} + +// +checklocksignore +func (n *NoLockFD) StateSave(stateSinkObject state.Sink) { + n.beforeSave() +} + +func (n *NoLockFD) afterLoad() {} + +// +checklocksignore +func (n *NoLockFD) StateLoad(stateSourceObject state.Source) { +} + +func (b *BadLockFD) StateTypeName() string { + return "pkg/sentry/vfs.BadLockFD" +} + +func (b *BadLockFD) StateFields() []string { + return []string{} +} + +func (b *BadLockFD) beforeSave() {} + +// +checklocksignore +func (b *BadLockFD) StateSave(stateSinkObject state.Sink) { + b.beforeSave() +} + +func (b *BadLockFD) afterLoad() {} + +// +checklocksignore +func (b *BadLockFD) StateLoad(stateSourceObject state.Source) { +} + +func (r *FileDescriptionRefs) StateTypeName() string { + return "pkg/sentry/vfs.FileDescriptionRefs" +} + +func (r *FileDescriptionRefs) StateFields() []string { + return []string{ + "refCount", + } +} + +func (r *FileDescriptionRefs) beforeSave() {} + +// +checklocksignore +func (r *FileDescriptionRefs) StateSave(stateSinkObject state.Sink) { + r.beforeSave() + stateSinkObject.Save(0, &r.refCount) +} + +// +checklocksignore +func (r *FileDescriptionRefs) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &r.refCount) + stateSourceObject.AfterLoad(r.afterLoad) +} + +func (fs *Filesystem) StateTypeName() string { + return "pkg/sentry/vfs.Filesystem" +} + +func (fs *Filesystem) StateFields() []string { + return []string{ + "FilesystemRefs", + "vfs", + "fsType", + "impl", + } +} + +func (fs *Filesystem) beforeSave() {} + +// +checklocksignore +func (fs *Filesystem) StateSave(stateSinkObject state.Sink) { + fs.beforeSave() + stateSinkObject.Save(0, &fs.FilesystemRefs) + stateSinkObject.Save(1, &fs.vfs) + stateSinkObject.Save(2, &fs.fsType) + stateSinkObject.Save(3, &fs.impl) +} + +func (fs *Filesystem) afterLoad() {} + +// +checklocksignore +func (fs *Filesystem) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &fs.FilesystemRefs) + stateSourceObject.Load(1, &fs.vfs) + stateSourceObject.Load(2, &fs.fsType) + stateSourceObject.Load(3, &fs.impl) +} + +func (p *PrependPathAtVFSRootError) StateTypeName() string { + return "pkg/sentry/vfs.PrependPathAtVFSRootError" +} + +func (p *PrependPathAtVFSRootError) StateFields() []string { + return []string{} +} + +func (p *PrependPathAtVFSRootError) beforeSave() {} + +// +checklocksignore +func (p *PrependPathAtVFSRootError) StateSave(stateSinkObject state.Sink) { + p.beforeSave() +} + +func (p *PrependPathAtVFSRootError) afterLoad() {} + +// +checklocksignore +func (p *PrependPathAtVFSRootError) StateLoad(stateSourceObject state.Source) { +} + +func (p *PrependPathAtNonMountRootError) StateTypeName() string { + return "pkg/sentry/vfs.PrependPathAtNonMountRootError" +} + +func (p *PrependPathAtNonMountRootError) StateFields() []string { + return []string{} +} + +func (p *PrependPathAtNonMountRootError) beforeSave() {} + +// +checklocksignore +func (p *PrependPathAtNonMountRootError) StateSave(stateSinkObject state.Sink) { + p.beforeSave() +} + +func (p *PrependPathAtNonMountRootError) afterLoad() {} + +// +checklocksignore +func (p *PrependPathAtNonMountRootError) StateLoad(stateSourceObject state.Source) { +} + +func (p *PrependPathSyntheticError) StateTypeName() string { + return "pkg/sentry/vfs.PrependPathSyntheticError" +} + +func (p *PrependPathSyntheticError) StateFields() []string { + return []string{} +} + +func (p *PrependPathSyntheticError) beforeSave() {} + +// +checklocksignore +func (p *PrependPathSyntheticError) StateSave(stateSinkObject state.Sink) { + p.beforeSave() +} + +func (p *PrependPathSyntheticError) afterLoad() {} + +// +checklocksignore +func (p *PrependPathSyntheticError) StateLoad(stateSourceObject state.Source) { +} + +func (r *FilesystemRefs) StateTypeName() string { + return "pkg/sentry/vfs.FilesystemRefs" +} + +func (r *FilesystemRefs) StateFields() []string { + return []string{ + "refCount", + } +} + +func (r *FilesystemRefs) beforeSave() {} + +// +checklocksignore +func (r *FilesystemRefs) StateSave(stateSinkObject state.Sink) { + r.beforeSave() + stateSinkObject.Save(0, &r.refCount) +} + +// +checklocksignore +func (r *FilesystemRefs) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &r.refCount) + stateSourceObject.AfterLoad(r.afterLoad) +} + +func (r *registeredFilesystemType) StateTypeName() string { + return "pkg/sentry/vfs.registeredFilesystemType" +} + +func (r *registeredFilesystemType) StateFields() []string { + return []string{ + "fsType", + "opts", + } +} + +func (r *registeredFilesystemType) beforeSave() {} + +// +checklocksignore +func (r *registeredFilesystemType) StateSave(stateSinkObject state.Sink) { + r.beforeSave() + stateSinkObject.Save(0, &r.fsType) + stateSinkObject.Save(1, &r.opts) +} + +func (r *registeredFilesystemType) afterLoad() {} + +// +checklocksignore +func (r *registeredFilesystemType) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &r.fsType) + stateSourceObject.Load(1, &r.opts) +} + +func (r *RegisterFilesystemTypeOptions) StateTypeName() string { + return "pkg/sentry/vfs.RegisterFilesystemTypeOptions" +} + +func (r *RegisterFilesystemTypeOptions) StateFields() []string { + return []string{ + "AllowUserMount", + "AllowUserList", + "RequiresDevice", + } +} + +func (r *RegisterFilesystemTypeOptions) beforeSave() {} + +// +checklocksignore +func (r *RegisterFilesystemTypeOptions) StateSave(stateSinkObject state.Sink) { + r.beforeSave() + stateSinkObject.Save(0, &r.AllowUserMount) + stateSinkObject.Save(1, &r.AllowUserList) + stateSinkObject.Save(2, &r.RequiresDevice) +} + +func (r *RegisterFilesystemTypeOptions) afterLoad() {} + +// +checklocksignore +func (r *RegisterFilesystemTypeOptions) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &r.AllowUserMount) + stateSourceObject.Load(1, &r.AllowUserList) + stateSourceObject.Load(2, &r.RequiresDevice) +} + +func (e *EventType) StateTypeName() string { + return "pkg/sentry/vfs.EventType" +} + +func (e *EventType) StateFields() []string { + return nil +} + +func (i *Inotify) StateTypeName() string { + return "pkg/sentry/vfs.Inotify" +} + +func (i *Inotify) StateFields() []string { + return []string{ + "vfsfd", + "FileDescriptionDefaultImpl", + "DentryMetadataFileDescriptionImpl", + "NoLockFD", + "id", + "queue", + "events", + "scratch", + "nextWatchMinusOne", + "watches", + } +} + +func (i *Inotify) beforeSave() {} + +// +checklocksignore +func (i *Inotify) StateSave(stateSinkObject state.Sink) { + i.beforeSave() + stateSinkObject.Save(0, &i.vfsfd) + stateSinkObject.Save(1, &i.FileDescriptionDefaultImpl) + stateSinkObject.Save(2, &i.DentryMetadataFileDescriptionImpl) + stateSinkObject.Save(3, &i.NoLockFD) + stateSinkObject.Save(4, &i.id) + stateSinkObject.Save(5, &i.queue) + stateSinkObject.Save(6, &i.events) + stateSinkObject.Save(7, &i.scratch) + stateSinkObject.Save(8, &i.nextWatchMinusOne) + stateSinkObject.Save(9, &i.watches) +} + +func (i *Inotify) afterLoad() {} + +// +checklocksignore +func (i *Inotify) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &i.vfsfd) + stateSourceObject.Load(1, &i.FileDescriptionDefaultImpl) + stateSourceObject.Load(2, &i.DentryMetadataFileDescriptionImpl) + stateSourceObject.Load(3, &i.NoLockFD) + stateSourceObject.Load(4, &i.id) + stateSourceObject.Load(5, &i.queue) + stateSourceObject.Load(6, &i.events) + stateSourceObject.Load(7, &i.scratch) + stateSourceObject.Load(8, &i.nextWatchMinusOne) + stateSourceObject.Load(9, &i.watches) +} + +func (w *Watches) StateTypeName() string { + return "pkg/sentry/vfs.Watches" +} + +func (w *Watches) StateFields() []string { + return []string{ + "ws", + } +} + +func (w *Watches) beforeSave() {} + +// +checklocksignore +func (w *Watches) StateSave(stateSinkObject state.Sink) { + w.beforeSave() + stateSinkObject.Save(0, &w.ws) +} + +func (w *Watches) afterLoad() {} + +// +checklocksignore +func (w *Watches) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &w.ws) +} + +func (w *Watch) StateTypeName() string { + return "pkg/sentry/vfs.Watch" +} + +func (w *Watch) StateFields() []string { + return []string{ + "owner", + "wd", + "target", + "mask", + "expired", + } +} + +func (w *Watch) beforeSave() {} + +// +checklocksignore +func (w *Watch) StateSave(stateSinkObject state.Sink) { + w.beforeSave() + stateSinkObject.Save(0, &w.owner) + stateSinkObject.Save(1, &w.wd) + stateSinkObject.Save(2, &w.target) + stateSinkObject.Save(3, &w.mask) + stateSinkObject.Save(4, &w.expired) +} + +func (w *Watch) afterLoad() {} + +// +checklocksignore +func (w *Watch) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &w.owner) + stateSourceObject.Load(1, &w.wd) + stateSourceObject.Load(2, &w.target) + stateSourceObject.Load(3, &w.mask) + stateSourceObject.Load(4, &w.expired) +} + +func (e *Event) StateTypeName() string { + return "pkg/sentry/vfs.Event" +} + +func (e *Event) StateFields() []string { + return []string{ + "eventEntry", + "wd", + "mask", + "cookie", + "len", + "name", + } +} + +func (e *Event) beforeSave() {} + +// +checklocksignore +func (e *Event) StateSave(stateSinkObject state.Sink) { + e.beforeSave() + stateSinkObject.Save(0, &e.eventEntry) + stateSinkObject.Save(1, &e.wd) + stateSinkObject.Save(2, &e.mask) + stateSinkObject.Save(3, &e.cookie) + stateSinkObject.Save(4, &e.len) + stateSinkObject.Save(5, &e.name) +} + +func (e *Event) afterLoad() {} + +// +checklocksignore +func (e *Event) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &e.eventEntry) + stateSourceObject.Load(1, &e.wd) + stateSourceObject.Load(2, &e.mask) + stateSourceObject.Load(3, &e.cookie) + stateSourceObject.Load(4, &e.len) + stateSourceObject.Load(5, &e.name) +} + +func (fl *FileLocks) StateTypeName() string { + return "pkg/sentry/vfs.FileLocks" +} + +func (fl *FileLocks) StateFields() []string { + return []string{ + "bsd", + "posix", + } +} + +func (fl *FileLocks) beforeSave() {} + +// +checklocksignore +func (fl *FileLocks) StateSave(stateSinkObject state.Sink) { + fl.beforeSave() + stateSinkObject.Save(0, &fl.bsd) + stateSinkObject.Save(1, &fl.posix) +} + +func (fl *FileLocks) afterLoad() {} + +// +checklocksignore +func (fl *FileLocks) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &fl.bsd) + stateSourceObject.Load(1, &fl.posix) +} + +func (mnt *Mount) StateTypeName() string { + return "pkg/sentry/vfs.Mount" +} + +func (mnt *Mount) StateFields() []string { + return []string{ + "vfs", + "fs", + "root", + "ID", + "Flags", + "key", + "ns", + "refs", + "children", + "umounted", + "writers", + } +} + +func (mnt *Mount) beforeSave() {} + +// +checklocksignore +func (mnt *Mount) StateSave(stateSinkObject state.Sink) { + mnt.beforeSave() + var keyValue VirtualDentry + keyValue = mnt.saveKey() + stateSinkObject.SaveValue(5, keyValue) + stateSinkObject.Save(0, &mnt.vfs) + stateSinkObject.Save(1, &mnt.fs) + stateSinkObject.Save(2, &mnt.root) + stateSinkObject.Save(3, &mnt.ID) + stateSinkObject.Save(4, &mnt.Flags) + stateSinkObject.Save(6, &mnt.ns) + stateSinkObject.Save(7, &mnt.refs) + stateSinkObject.Save(8, &mnt.children) + stateSinkObject.Save(9, &mnt.umounted) + stateSinkObject.Save(10, &mnt.writers) +} + +// +checklocksignore +func (mnt *Mount) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &mnt.vfs) + stateSourceObject.Load(1, &mnt.fs) + stateSourceObject.Load(2, &mnt.root) + stateSourceObject.Load(3, &mnt.ID) + stateSourceObject.Load(4, &mnt.Flags) + stateSourceObject.Load(6, &mnt.ns) + stateSourceObject.Load(7, &mnt.refs) + stateSourceObject.Load(8, &mnt.children) + stateSourceObject.Load(9, &mnt.umounted) + stateSourceObject.Load(10, &mnt.writers) + stateSourceObject.LoadValue(5, new(VirtualDentry), func(y interface{}) { mnt.loadKey(y.(VirtualDentry)) }) + stateSourceObject.AfterLoad(mnt.afterLoad) +} + +func (mntns *MountNamespace) StateTypeName() string { + return "pkg/sentry/vfs.MountNamespace" +} + +func (mntns *MountNamespace) StateFields() []string { + return []string{ + "MountNamespaceRefs", + "Owner", + "root", + "mountpoints", + } +} + +func (mntns *MountNamespace) beforeSave() {} + +// +checklocksignore +func (mntns *MountNamespace) StateSave(stateSinkObject state.Sink) { + mntns.beforeSave() + stateSinkObject.Save(0, &mntns.MountNamespaceRefs) + stateSinkObject.Save(1, &mntns.Owner) + stateSinkObject.Save(2, &mntns.root) + stateSinkObject.Save(3, &mntns.mountpoints) +} + +func (mntns *MountNamespace) afterLoad() {} + +// +checklocksignore +func (mntns *MountNamespace) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &mntns.MountNamespaceRefs) + stateSourceObject.Load(1, &mntns.Owner) + stateSourceObject.Load(2, &mntns.root) + stateSourceObject.Load(3, &mntns.mountpoints) +} + +func (u *umountRecursiveOptions) StateTypeName() string { + return "pkg/sentry/vfs.umountRecursiveOptions" +} + +func (u *umountRecursiveOptions) StateFields() []string { + return []string{ + "eager", + "disconnectHierarchy", + } +} + +func (u *umountRecursiveOptions) beforeSave() {} + +// +checklocksignore +func (u *umountRecursiveOptions) StateSave(stateSinkObject state.Sink) { + u.beforeSave() + stateSinkObject.Save(0, &u.eager) + stateSinkObject.Save(1, &u.disconnectHierarchy) +} + +func (u *umountRecursiveOptions) afterLoad() {} + +// +checklocksignore +func (u *umountRecursiveOptions) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &u.eager) + stateSourceObject.Load(1, &u.disconnectHierarchy) +} + +func (r *MountNamespaceRefs) StateTypeName() string { + return "pkg/sentry/vfs.MountNamespaceRefs" +} + +func (r *MountNamespaceRefs) StateFields() []string { + return []string{ + "refCount", + } +} + +func (r *MountNamespaceRefs) beforeSave() {} + +// +checklocksignore +func (r *MountNamespaceRefs) StateSave(stateSinkObject state.Sink) { + r.beforeSave() + stateSinkObject.Save(0, &r.refCount) +} + +// +checklocksignore +func (r *MountNamespaceRefs) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &r.refCount) + stateSourceObject.AfterLoad(r.afterLoad) +} + +func (fd *opathFD) StateTypeName() string { + return "pkg/sentry/vfs.opathFD" +} + +func (fd *opathFD) StateFields() []string { + return []string{ + "vfsfd", + "FileDescriptionDefaultImpl", + "BadLockFD", + } +} + +func (fd *opathFD) beforeSave() {} + +// +checklocksignore +func (fd *opathFD) StateSave(stateSinkObject state.Sink) { + fd.beforeSave() + stateSinkObject.Save(0, &fd.vfsfd) + stateSinkObject.Save(1, &fd.FileDescriptionDefaultImpl) + stateSinkObject.Save(2, &fd.BadLockFD) +} + +func (fd *opathFD) afterLoad() {} + +// +checklocksignore +func (fd *opathFD) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &fd.vfsfd) + stateSourceObject.Load(1, &fd.FileDescriptionDefaultImpl) + stateSourceObject.Load(2, &fd.BadLockFD) +} + +func (g *GetDentryOptions) StateTypeName() string { + return "pkg/sentry/vfs.GetDentryOptions" +} + +func (g *GetDentryOptions) StateFields() []string { + return []string{ + "CheckSearchable", + } +} + +func (g *GetDentryOptions) beforeSave() {} + +// +checklocksignore +func (g *GetDentryOptions) StateSave(stateSinkObject state.Sink) { + g.beforeSave() + stateSinkObject.Save(0, &g.CheckSearchable) +} + +func (g *GetDentryOptions) afterLoad() {} + +// +checklocksignore +func (g *GetDentryOptions) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &g.CheckSearchable) +} + +func (m *MkdirOptions) StateTypeName() string { + return "pkg/sentry/vfs.MkdirOptions" +} + +func (m *MkdirOptions) StateFields() []string { + return []string{ + "Mode", + "ForSyntheticMountpoint", + } +} + +func (m *MkdirOptions) beforeSave() {} + +// +checklocksignore +func (m *MkdirOptions) StateSave(stateSinkObject state.Sink) { + m.beforeSave() + stateSinkObject.Save(0, &m.Mode) + stateSinkObject.Save(1, &m.ForSyntheticMountpoint) +} + +func (m *MkdirOptions) afterLoad() {} + +// +checklocksignore +func (m *MkdirOptions) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &m.Mode) + stateSourceObject.Load(1, &m.ForSyntheticMountpoint) +} + +func (m *MknodOptions) StateTypeName() string { + return "pkg/sentry/vfs.MknodOptions" +} + +func (m *MknodOptions) StateFields() []string { + return []string{ + "Mode", + "DevMajor", + "DevMinor", + "Endpoint", + } +} + +func (m *MknodOptions) beforeSave() {} + +// +checklocksignore +func (m *MknodOptions) StateSave(stateSinkObject state.Sink) { + m.beforeSave() + stateSinkObject.Save(0, &m.Mode) + stateSinkObject.Save(1, &m.DevMajor) + stateSinkObject.Save(2, &m.DevMinor) + stateSinkObject.Save(3, &m.Endpoint) +} + +func (m *MknodOptions) afterLoad() {} + +// +checklocksignore +func (m *MknodOptions) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &m.Mode) + stateSourceObject.Load(1, &m.DevMajor) + stateSourceObject.Load(2, &m.DevMinor) + stateSourceObject.Load(3, &m.Endpoint) +} + +func (m *MountFlags) StateTypeName() string { + return "pkg/sentry/vfs.MountFlags" +} + +func (m *MountFlags) StateFields() []string { + return []string{ + "NoExec", + "NoATime", + "NoDev", + "NoSUID", + } +} + +func (m *MountFlags) beforeSave() {} + +// +checklocksignore +func (m *MountFlags) StateSave(stateSinkObject state.Sink) { + m.beforeSave() + stateSinkObject.Save(0, &m.NoExec) + stateSinkObject.Save(1, &m.NoATime) + stateSinkObject.Save(2, &m.NoDev) + stateSinkObject.Save(3, &m.NoSUID) +} + +func (m *MountFlags) afterLoad() {} + +// +checklocksignore +func (m *MountFlags) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &m.NoExec) + stateSourceObject.Load(1, &m.NoATime) + stateSourceObject.Load(2, &m.NoDev) + stateSourceObject.Load(3, &m.NoSUID) +} + +func (m *MountOptions) StateTypeName() string { + return "pkg/sentry/vfs.MountOptions" +} + +func (m *MountOptions) StateFields() []string { + return []string{ + "Flags", + "ReadOnly", + "GetFilesystemOptions", + "InternalMount", + } +} + +func (m *MountOptions) beforeSave() {} + +// +checklocksignore +func (m *MountOptions) StateSave(stateSinkObject state.Sink) { + m.beforeSave() + stateSinkObject.Save(0, &m.Flags) + stateSinkObject.Save(1, &m.ReadOnly) + stateSinkObject.Save(2, &m.GetFilesystemOptions) + stateSinkObject.Save(3, &m.InternalMount) +} + +func (m *MountOptions) afterLoad() {} + +// +checklocksignore +func (m *MountOptions) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &m.Flags) + stateSourceObject.Load(1, &m.ReadOnly) + stateSourceObject.Load(2, &m.GetFilesystemOptions) + stateSourceObject.Load(3, &m.InternalMount) +} + +func (o *OpenOptions) StateTypeName() string { + return "pkg/sentry/vfs.OpenOptions" +} + +func (o *OpenOptions) StateFields() []string { + return []string{ + "Flags", + "Mode", + "FileExec", + } +} + +func (o *OpenOptions) beforeSave() {} + +// +checklocksignore +func (o *OpenOptions) StateSave(stateSinkObject state.Sink) { + o.beforeSave() + stateSinkObject.Save(0, &o.Flags) + stateSinkObject.Save(1, &o.Mode) + stateSinkObject.Save(2, &o.FileExec) +} + +func (o *OpenOptions) afterLoad() {} + +// +checklocksignore +func (o *OpenOptions) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &o.Flags) + stateSourceObject.Load(1, &o.Mode) + stateSourceObject.Load(2, &o.FileExec) +} + +func (r *ReadOptions) StateTypeName() string { + return "pkg/sentry/vfs.ReadOptions" +} + +func (r *ReadOptions) StateFields() []string { + return []string{ + "Flags", + } +} + +func (r *ReadOptions) beforeSave() {} + +// +checklocksignore +func (r *ReadOptions) StateSave(stateSinkObject state.Sink) { + r.beforeSave() + stateSinkObject.Save(0, &r.Flags) +} + +func (r *ReadOptions) afterLoad() {} + +// +checklocksignore +func (r *ReadOptions) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &r.Flags) +} + +func (r *RenameOptions) StateTypeName() string { + return "pkg/sentry/vfs.RenameOptions" +} + +func (r *RenameOptions) StateFields() []string { + return []string{ + "Flags", + "MustBeDir", + } +} + +func (r *RenameOptions) beforeSave() {} + +// +checklocksignore +func (r *RenameOptions) StateSave(stateSinkObject state.Sink) { + r.beforeSave() + stateSinkObject.Save(0, &r.Flags) + stateSinkObject.Save(1, &r.MustBeDir) +} + +func (r *RenameOptions) afterLoad() {} + +// +checklocksignore +func (r *RenameOptions) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &r.Flags) + stateSourceObject.Load(1, &r.MustBeDir) +} + +func (s *SetStatOptions) StateTypeName() string { + return "pkg/sentry/vfs.SetStatOptions" +} + +func (s *SetStatOptions) StateFields() []string { + return []string{ + "Stat", + "NeedWritePerm", + } +} + +func (s *SetStatOptions) beforeSave() {} + +// +checklocksignore +func (s *SetStatOptions) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + stateSinkObject.Save(0, &s.Stat) + stateSinkObject.Save(1, &s.NeedWritePerm) +} + +func (s *SetStatOptions) afterLoad() {} + +// +checklocksignore +func (s *SetStatOptions) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.Stat) + stateSourceObject.Load(1, &s.NeedWritePerm) +} + +func (b *BoundEndpointOptions) StateTypeName() string { + return "pkg/sentry/vfs.BoundEndpointOptions" +} + +func (b *BoundEndpointOptions) StateFields() []string { + return []string{ + "Addr", + } +} + +func (b *BoundEndpointOptions) beforeSave() {} + +// +checklocksignore +func (b *BoundEndpointOptions) StateSave(stateSinkObject state.Sink) { + b.beforeSave() + stateSinkObject.Save(0, &b.Addr) +} + +func (b *BoundEndpointOptions) afterLoad() {} + +// +checklocksignore +func (b *BoundEndpointOptions) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &b.Addr) +} + +func (g *GetXattrOptions) StateTypeName() string { + return "pkg/sentry/vfs.GetXattrOptions" +} + +func (g *GetXattrOptions) StateFields() []string { + return []string{ + "Name", + "Size", + } +} + +func (g *GetXattrOptions) beforeSave() {} + +// +checklocksignore +func (g *GetXattrOptions) StateSave(stateSinkObject state.Sink) { + g.beforeSave() + stateSinkObject.Save(0, &g.Name) + stateSinkObject.Save(1, &g.Size) +} + +func (g *GetXattrOptions) afterLoad() {} + +// +checklocksignore +func (g *GetXattrOptions) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &g.Name) + stateSourceObject.Load(1, &g.Size) +} + +func (s *SetXattrOptions) StateTypeName() string { + return "pkg/sentry/vfs.SetXattrOptions" +} + +func (s *SetXattrOptions) StateFields() []string { + return []string{ + "Name", + "Value", + "Flags", + } +} + +func (s *SetXattrOptions) beforeSave() {} + +// +checklocksignore +func (s *SetXattrOptions) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + stateSinkObject.Save(0, &s.Name) + stateSinkObject.Save(1, &s.Value) + stateSinkObject.Save(2, &s.Flags) +} + +func (s *SetXattrOptions) afterLoad() {} + +// +checklocksignore +func (s *SetXattrOptions) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.Name) + stateSourceObject.Load(1, &s.Value) + stateSourceObject.Load(2, &s.Flags) +} + +func (s *StatOptions) StateTypeName() string { + return "pkg/sentry/vfs.StatOptions" +} + +func (s *StatOptions) StateFields() []string { + return []string{ + "Mask", + "Sync", + } +} + +func (s *StatOptions) beforeSave() {} + +// +checklocksignore +func (s *StatOptions) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + stateSinkObject.Save(0, &s.Mask) + stateSinkObject.Save(1, &s.Sync) +} + +func (s *StatOptions) afterLoad() {} + +// +checklocksignore +func (s *StatOptions) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.Mask) + stateSourceObject.Load(1, &s.Sync) +} + +func (u *UmountOptions) StateTypeName() string { + return "pkg/sentry/vfs.UmountOptions" +} + +func (u *UmountOptions) StateFields() []string { + return []string{ + "Flags", + } +} + +func (u *UmountOptions) beforeSave() {} + +// +checklocksignore +func (u *UmountOptions) StateSave(stateSinkObject state.Sink) { + u.beforeSave() + stateSinkObject.Save(0, &u.Flags) +} + +func (u *UmountOptions) afterLoad() {} + +// +checklocksignore +func (u *UmountOptions) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &u.Flags) +} + +func (w *WriteOptions) StateTypeName() string { + return "pkg/sentry/vfs.WriteOptions" +} + +func (w *WriteOptions) StateFields() []string { + return []string{ + "Flags", + } +} + +func (w *WriteOptions) beforeSave() {} + +// +checklocksignore +func (w *WriteOptions) StateSave(stateSinkObject state.Sink) { + w.beforeSave() + stateSinkObject.Save(0, &w.Flags) +} + +func (w *WriteOptions) afterLoad() {} + +// +checklocksignore +func (w *WriteOptions) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &w.Flags) +} + +func (a *AccessTypes) StateTypeName() string { + return "pkg/sentry/vfs.AccessTypes" +} + +func (a *AccessTypes) StateFields() []string { + return nil +} + +func (rp *ResolvingPath) StateTypeName() string { + return "pkg/sentry/vfs.ResolvingPath" +} + +func (rp *ResolvingPath) StateFields() []string { + return []string{ + "vfs", + "root", + "mount", + "start", + "pit", + "flags", + "mustBeDir", + "symlinks", + "curPart", + "creds", + "nextMount", + "nextStart", + "absSymlinkTarget", + "parts", + } +} + +func (rp *ResolvingPath) beforeSave() {} + +// +checklocksignore +func (rp *ResolvingPath) StateSave(stateSinkObject state.Sink) { + rp.beforeSave() + stateSinkObject.Save(0, &rp.vfs) + stateSinkObject.Save(1, &rp.root) + stateSinkObject.Save(2, &rp.mount) + stateSinkObject.Save(3, &rp.start) + stateSinkObject.Save(4, &rp.pit) + stateSinkObject.Save(5, &rp.flags) + stateSinkObject.Save(6, &rp.mustBeDir) + stateSinkObject.Save(7, &rp.symlinks) + stateSinkObject.Save(8, &rp.curPart) + stateSinkObject.Save(9, &rp.creds) + stateSinkObject.Save(10, &rp.nextMount) + stateSinkObject.Save(11, &rp.nextStart) + stateSinkObject.Save(12, &rp.absSymlinkTarget) + stateSinkObject.Save(13, &rp.parts) +} + +func (rp *ResolvingPath) afterLoad() {} + +// +checklocksignore +func (rp *ResolvingPath) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &rp.vfs) + stateSourceObject.Load(1, &rp.root) + stateSourceObject.Load(2, &rp.mount) + stateSourceObject.Load(3, &rp.start) + stateSourceObject.Load(4, &rp.pit) + stateSourceObject.Load(5, &rp.flags) + stateSourceObject.Load(6, &rp.mustBeDir) + stateSourceObject.Load(7, &rp.symlinks) + stateSourceObject.Load(8, &rp.curPart) + stateSourceObject.Load(9, &rp.creds) + stateSourceObject.Load(10, &rp.nextMount) + stateSourceObject.Load(11, &rp.nextStart) + stateSourceObject.Load(12, &rp.absSymlinkTarget) + stateSourceObject.Load(13, &rp.parts) +} + +func (r *resolveMountRootOrJumpError) StateTypeName() string { + return "pkg/sentry/vfs.resolveMountRootOrJumpError" +} + +func (r *resolveMountRootOrJumpError) StateFields() []string { + return []string{} +} + +func (r *resolveMountRootOrJumpError) beforeSave() {} + +// +checklocksignore +func (r *resolveMountRootOrJumpError) StateSave(stateSinkObject state.Sink) { + r.beforeSave() +} + +func (r *resolveMountRootOrJumpError) afterLoad() {} + +// +checklocksignore +func (r *resolveMountRootOrJumpError) StateLoad(stateSourceObject state.Source) { +} + +func (r *resolveMountPointError) StateTypeName() string { + return "pkg/sentry/vfs.resolveMountPointError" +} + +func (r *resolveMountPointError) StateFields() []string { + return []string{} +} + +func (r *resolveMountPointError) beforeSave() {} + +// +checklocksignore +func (r *resolveMountPointError) StateSave(stateSinkObject state.Sink) { + r.beforeSave() +} + +func (r *resolveMountPointError) afterLoad() {} + +// +checklocksignore +func (r *resolveMountPointError) StateLoad(stateSourceObject state.Source) { +} + +func (r *resolveAbsSymlinkError) StateTypeName() string { + return "pkg/sentry/vfs.resolveAbsSymlinkError" +} + +func (r *resolveAbsSymlinkError) StateFields() []string { + return []string{} +} + +func (r *resolveAbsSymlinkError) beforeSave() {} + +// +checklocksignore +func (r *resolveAbsSymlinkError) StateSave(stateSinkObject state.Sink) { + r.beforeSave() +} + +func (r *resolveAbsSymlinkError) afterLoad() {} + +// +checklocksignore +func (r *resolveAbsSymlinkError) StateLoad(stateSourceObject state.Source) { +} + +func (vfs *VirtualFilesystem) StateTypeName() string { + return "pkg/sentry/vfs.VirtualFilesystem" +} + +func (vfs *VirtualFilesystem) StateFields() []string { + return []string{ + "mounts", + "mountpoints", + "lastMountID", + "anonMount", + "devices", + "anonBlockDevMinorNext", + "anonBlockDevMinor", + "fsTypes", + "filesystems", + } +} + +func (vfs *VirtualFilesystem) beforeSave() {} + +// +checklocksignore +func (vfs *VirtualFilesystem) StateSave(stateSinkObject state.Sink) { + vfs.beforeSave() + var mountsValue []*Mount + mountsValue = vfs.saveMounts() + stateSinkObject.SaveValue(0, mountsValue) + stateSinkObject.Save(1, &vfs.mountpoints) + stateSinkObject.Save(2, &vfs.lastMountID) + stateSinkObject.Save(3, &vfs.anonMount) + stateSinkObject.Save(4, &vfs.devices) + stateSinkObject.Save(5, &vfs.anonBlockDevMinorNext) + stateSinkObject.Save(6, &vfs.anonBlockDevMinor) + stateSinkObject.Save(7, &vfs.fsTypes) + stateSinkObject.Save(8, &vfs.filesystems) +} + +func (vfs *VirtualFilesystem) afterLoad() {} + +// +checklocksignore +func (vfs *VirtualFilesystem) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(1, &vfs.mountpoints) + stateSourceObject.Load(2, &vfs.lastMountID) + stateSourceObject.Load(3, &vfs.anonMount) + stateSourceObject.Load(4, &vfs.devices) + stateSourceObject.Load(5, &vfs.anonBlockDevMinorNext) + stateSourceObject.Load(6, &vfs.anonBlockDevMinor) + stateSourceObject.Load(7, &vfs.fsTypes) + stateSourceObject.Load(8, &vfs.filesystems) + stateSourceObject.LoadValue(0, new([]*Mount), func(y interface{}) { vfs.loadMounts(y.([]*Mount)) }) +} + +func (p *PathOperation) StateTypeName() string { + return "pkg/sentry/vfs.PathOperation" +} + +func (p *PathOperation) StateFields() []string { + return []string{ + "Root", + "Start", + "Path", + "FollowFinalSymlink", + } +} + +func (p *PathOperation) beforeSave() {} + +// +checklocksignore +func (p *PathOperation) StateSave(stateSinkObject state.Sink) { + p.beforeSave() + stateSinkObject.Save(0, &p.Root) + stateSinkObject.Save(1, &p.Start) + stateSinkObject.Save(2, &p.Path) + stateSinkObject.Save(3, &p.FollowFinalSymlink) +} + +func (p *PathOperation) afterLoad() {} + +// +checklocksignore +func (p *PathOperation) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &p.Root) + stateSourceObject.Load(1, &p.Start) + stateSourceObject.Load(2, &p.Path) + stateSourceObject.Load(3, &p.FollowFinalSymlink) +} + +func (vd *VirtualDentry) StateTypeName() string { + return "pkg/sentry/vfs.VirtualDentry" +} + +func (vd *VirtualDentry) StateFields() []string { + return []string{ + "mount", + "dentry", + } +} + +func (vd *VirtualDentry) beforeSave() {} + +// +checklocksignore +func (vd *VirtualDentry) StateSave(stateSinkObject state.Sink) { + vd.beforeSave() + stateSinkObject.Save(0, &vd.mount) + stateSinkObject.Save(1, &vd.dentry) +} + +func (vd *VirtualDentry) afterLoad() {} + +// +checklocksignore +func (vd *VirtualDentry) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &vd.mount) + stateSourceObject.Load(1, &vd.dentry) +} + +func init() { + state.Register((*anonFilesystemType)(nil)) + state.Register((*anonFilesystem)(nil)) + state.Register((*anonDentry)(nil)) + state.Register((*Dentry)(nil)) + state.Register((*DeviceKind)(nil)) + state.Register((*devTuple)(nil)) + state.Register((*registeredDevice)(nil)) + state.Register((*RegisterDeviceOptions)(nil)) + state.Register((*EpollInstance)(nil)) + state.Register((*epollInterestKey)(nil)) + state.Register((*epollInterest)(nil)) + state.Register((*epollInterestList)(nil)) + state.Register((*epollInterestEntry)(nil)) + state.Register((*eventList)(nil)) + state.Register((*eventEntry)(nil)) + state.Register((*FileDescription)(nil)) + state.Register((*FileDescriptionOptions)(nil)) + state.Register((*Dirent)(nil)) + state.Register((*FileDescriptionDefaultImpl)(nil)) + state.Register((*DirectoryFileDescriptionDefaultImpl)(nil)) + state.Register((*DentryMetadataFileDescriptionImpl)(nil)) + state.Register((*StaticData)(nil)) + state.Register((*DynamicBytesFileDescriptionImpl)(nil)) + state.Register((*LockFD)(nil)) + state.Register((*NoLockFD)(nil)) + state.Register((*BadLockFD)(nil)) + state.Register((*FileDescriptionRefs)(nil)) + state.Register((*Filesystem)(nil)) + state.Register((*PrependPathAtVFSRootError)(nil)) + state.Register((*PrependPathAtNonMountRootError)(nil)) + state.Register((*PrependPathSyntheticError)(nil)) + state.Register((*FilesystemRefs)(nil)) + state.Register((*registeredFilesystemType)(nil)) + state.Register((*RegisterFilesystemTypeOptions)(nil)) + state.Register((*EventType)(nil)) + state.Register((*Inotify)(nil)) + state.Register((*Watches)(nil)) + state.Register((*Watch)(nil)) + state.Register((*Event)(nil)) + state.Register((*FileLocks)(nil)) + state.Register((*Mount)(nil)) + state.Register((*MountNamespace)(nil)) + state.Register((*umountRecursiveOptions)(nil)) + state.Register((*MountNamespaceRefs)(nil)) + state.Register((*opathFD)(nil)) + state.Register((*GetDentryOptions)(nil)) + state.Register((*MkdirOptions)(nil)) + state.Register((*MknodOptions)(nil)) + state.Register((*MountFlags)(nil)) + state.Register((*MountOptions)(nil)) + state.Register((*OpenOptions)(nil)) + state.Register((*ReadOptions)(nil)) + state.Register((*RenameOptions)(nil)) + state.Register((*SetStatOptions)(nil)) + state.Register((*BoundEndpointOptions)(nil)) + state.Register((*GetXattrOptions)(nil)) + state.Register((*SetXattrOptions)(nil)) + state.Register((*StatOptions)(nil)) + state.Register((*UmountOptions)(nil)) + state.Register((*WriteOptions)(nil)) + state.Register((*AccessTypes)(nil)) + state.Register((*ResolvingPath)(nil)) + state.Register((*resolveMountRootOrJumpError)(nil)) + state.Register((*resolveMountPointError)(nil)) + state.Register((*resolveAbsSymlinkError)(nil)) + state.Register((*VirtualFilesystem)(nil)) + state.Register((*PathOperation)(nil)) + state.Register((*VirtualDentry)(nil)) +} diff --git a/pkg/sentry/vfs/vfs_unsafe_state_autogen.go b/pkg/sentry/vfs/vfs_unsafe_state_autogen.go new file mode 100644 index 000000000..20f06c953 --- /dev/null +++ b/pkg/sentry/vfs/vfs_unsafe_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package vfs diff --git a/pkg/sentry/watchdog/BUILD b/pkg/sentry/watchdog/BUILD deleted file mode 100644 index 1c5a1c9b6..000000000 --- a/pkg/sentry/watchdog/BUILD +++ /dev/null @@ -1,17 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "watchdog", - srcs = ["watchdog.go"], - visibility = ["//:sandbox"], - deps = [ - "//pkg/abi/linux", - "//pkg/log", - "//pkg/metric", - "//pkg/sentry/kernel", - "//pkg/sentry/kernel/time", - "//pkg/sync", - ], -) diff --git a/pkg/sentry/watchdog/watchdog_state_autogen.go b/pkg/sentry/watchdog/watchdog_state_autogen.go new file mode 100644 index 000000000..bce0200e7 --- /dev/null +++ b/pkg/sentry/watchdog/watchdog_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package watchdog diff --git a/pkg/shim/BUILD b/pkg/shim/BUILD deleted file mode 100644 index 367765209..000000000 --- a/pkg/shim/BUILD +++ /dev/null @@ -1,60 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "shim", - srcs = [ - "api.go", - "debug.go", - "epoll.go", - "options.go", - "service.go", - "service_linux.go", - "state.go", - ], - visibility = ["//shim:__subpackages__"], - deps = [ - "//pkg/cleanup", - "//pkg/shim/proc", - "//pkg/shim/runsc", - "//pkg/shim/runtimeoptions", - "//pkg/shim/utils", - "//runsc/specutils", - "@com_github_burntsushi_toml//:go_default_library", - "@com_github_containerd_cgroups//:go_default_library", - "@com_github_containerd_cgroups//stats/v1:go_default_library", - "@com_github_containerd_console//:go_default_library", - "@com_github_containerd_containerd//api/events:go_default_library", - "@com_github_containerd_containerd//api/types/task:go_default_library", - "@com_github_containerd_containerd//errdefs:go_default_library", - "@com_github_containerd_containerd//events:go_default_library", - "@com_github_containerd_containerd//log:go_default_library", - "@com_github_containerd_containerd//mount:go_default_library", - "@com_github_containerd_containerd//namespaces:go_default_library", - "@com_github_containerd_containerd//pkg/process:go_default_library", - "@com_github_containerd_containerd//pkg/stdio:go_default_library", - "@com_github_containerd_containerd//runtime:go_default_library", - "@com_github_containerd_containerd//runtime/linux/runctypes:go_default_library", - "@com_github_containerd_containerd//runtime/v2/shim:go_default_library", - "@com_github_containerd_containerd//runtime/v2/task:go_default_library", - "@com_github_containerd_containerd//sys/reaper:go_default_library", - "@com_github_containerd_fifo//:go_default_library", - "@com_github_containerd_typeurl//:go_default_library", - "@com_github_gogo_protobuf//types:go_default_library", - "@com_github_opencontainers_runtime_spec//specs-go:go_default_library", - "@com_github_sirupsen_logrus//:go_default_library", - "@org_golang_x_sys//unix:go_default_library", - ], -) - -go_test( - name = "shim_test", - size = "small", - srcs = ["service_test.go"], - library = ":shim", - deps = [ - "//pkg/shim/utils", - "@com_github_opencontainers_runtime_spec//specs-go:go_default_library", - ], -) diff --git a/pkg/shim/proc/BUILD b/pkg/shim/proc/BUILD deleted file mode 100644 index c8527a6d9..000000000 --- a/pkg/shim/proc/BUILD +++ /dev/null @@ -1,38 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "proc", - srcs = [ - "deleted_state.go", - "exec.go", - "exec_state.go", - "init.go", - "init_state.go", - "io.go", - "proc.go", - "types.go", - "utils.go", - ], - visibility = [ - "//pkg/shim:__subpackages__", - "//shim:__subpackages__", - ], - deps = [ - "//pkg/cleanup", - "//pkg/shim/runsc", - "//pkg/shim/utils", - "@com_github_containerd_console//:go_default_library", - "@com_github_containerd_containerd//errdefs:go_default_library", - "@com_github_containerd_containerd//log:go_default_library", - "@com_github_containerd_containerd//mount:go_default_library", - "@com_github_containerd_containerd//pkg/process:go_default_library", - "@com_github_containerd_containerd//pkg/stdio:go_default_library", - "@com_github_containerd_fifo//:go_default_library", - "@com_github_containerd_go_runc//:go_default_library", - "@com_github_gogo_protobuf//types:go_default_library", - "@com_github_opencontainers_runtime_spec//specs-go:go_default_library", - "@org_golang_x_sys//unix:go_default_library", - ], -) diff --git a/pkg/shim/proc/proc_state_autogen.go b/pkg/shim/proc/proc_state_autogen.go new file mode 100644 index 000000000..210252d9d --- /dev/null +++ b/pkg/shim/proc/proc_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package proc diff --git a/pkg/shim/runsc/BUILD b/pkg/shim/runsc/BUILD deleted file mode 100644 index f2fc813e6..000000000 --- a/pkg/shim/runsc/BUILD +++ /dev/null @@ -1,18 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "runsc", - srcs = [ - "runsc.go", - "utils.go", - ], - visibility = ["//:sandbox"], - deps = [ - "@com_github_containerd_containerd//log:go_default_library", - "@com_github_containerd_go_runc//:go_default_library", - "@com_github_opencontainers_runtime_spec//specs-go:go_default_library", - "@org_golang_x_sys//unix:go_default_library", - ], -) diff --git a/pkg/shim/runsc/runsc_state_autogen.go b/pkg/shim/runsc/runsc_state_autogen.go new file mode 100644 index 000000000..ee470594f --- /dev/null +++ b/pkg/shim/runsc/runsc_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package runsc diff --git a/pkg/shim/runtimeoptions/BUILD b/pkg/shim/runtimeoptions/BUILD deleted file mode 100644 index 029be7c09..000000000 --- a/pkg/shim/runtimeoptions/BUILD +++ /dev/null @@ -1,32 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test", "proto_library") - -package(licenses = ["notice"]) - -proto_library( - name = "api", - srcs = [ - "runtimeoptions.proto", - ], -) - -go_library( - name = "runtimeoptions", - srcs = [ - "runtimeoptions.go", - "runtimeoptions_cri.go", - ], - visibility = ["//pkg/shim:__pkg__"], - deps = ["@com_github_gogo_protobuf//proto:go_default_library"], -) - -go_test( - name = "runtimeoptions_test", - size = "small", - srcs = ["runtimeoptions_test.go"], - library = ":runtimeoptions", - deps = [ - "@com_github_containerd_containerd//runtime/v1/shim/v1:go_default_library", - "@com_github_containerd_typeurl//:go_default_library", - "@com_github_gogo_protobuf//proto:go_default_library", - ], -) diff --git a/pkg/shim/runtimeoptions/runtimeoptions.proto b/pkg/shim/runtimeoptions/runtimeoptions.proto deleted file mode 100644 index 057032e34..000000000 --- a/pkg/shim/runtimeoptions/runtimeoptions.proto +++ /dev/null @@ -1,25 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -syntax = "proto3"; - -package cri.runtimeoptions.v1; - -// This is a version of the runtimeoptions CRI API that is vendored. -// -// Importing the full CRI package is a nightmare. -message Options { - string type_url = 1; - string config_path = 2; -} diff --git a/pkg/shim/runtimeoptions/runtimeoptions_state_autogen.go b/pkg/shim/runtimeoptions/runtimeoptions_state_autogen.go new file mode 100644 index 000000000..73e0a158d --- /dev/null +++ b/pkg/shim/runtimeoptions/runtimeoptions_state_autogen.go @@ -0,0 +1,6 @@ +// automatically generated by stateify. + +//go:build go1.1 +// +build go1.1 + +package runtimeoptions diff --git a/pkg/shim/runtimeoptions/runtimeoptions_test.go b/pkg/shim/runtimeoptions/runtimeoptions_test.go deleted file mode 100644 index c59a2400e..000000000 --- a/pkg/shim/runtimeoptions/runtimeoptions_test.go +++ /dev/null @@ -1,57 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package runtimeoptions - -import ( - "bytes" - "testing" - - shim "github.com/containerd/containerd/runtime/v1/shim/v1" - "github.com/containerd/typeurl" - "github.com/gogo/protobuf/proto" -) - -func TestCreateTaskRequest(t *testing.T) { - // Serialize the top-level message. - const encodedText = `options: < - type_url: "cri.runtimeoptions.v1.Options" - value: "\n\010type_url\022\013config_path" ->` - got := &shim.CreateTaskRequest{} // Should have raw options. - if err := proto.UnmarshalText(encodedText, got); err != nil { - t.Fatalf("unable to unmarshal text: %v", err) - } - var textBuffer bytes.Buffer - if err := proto.MarshalText(&textBuffer, got); err != nil { - t.Errorf("unable to marshal text: %v", err) - } - t.Logf("got: %s", string(textBuffer.Bytes())) - - // Check the options. - wantOptions := &Options{} - wantOptions.TypeUrl = "type_url" - wantOptions.ConfigPath = "config_path" - gotMessage, err := typeurl.UnmarshalAny(got.Options) - if err != nil { - t.Fatalf("unable to unmarshal any: %v", err) - } - gotOptions, ok := gotMessage.(*Options) - if !ok { - t.Fatalf("got %v, want %v", gotMessage, wantOptions) - } - if !proto.Equal(gotOptions, wantOptions) { - t.Fatalf("got %v, want %v", gotOptions, wantOptions) - } -} diff --git a/pkg/shim/service_test.go b/pkg/shim/service_test.go deleted file mode 100644 index 2d9f07e02..000000000 --- a/pkg/shim/service_test.go +++ /dev/null @@ -1,121 +0,0 @@ -// Copyright 2021 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package shim - -import ( - "testing" - - specs "github.com/opencontainers/runtime-spec/specs-go" - "gvisor.dev/gvisor/pkg/shim/utils" -) - -func TestCgroupPath(t *testing.T) { - for _, tc := range []struct { - name string - path string - want string - }{ - { - name: "simple", - path: "foo/pod123/container", - want: "foo/pod123", - }, - { - name: "absolute", - path: "/foo/pod123/container", - want: "/foo/pod123", - }, - { - name: "no-container", - path: "foo/pod123", - want: "foo/pod123", - }, - { - name: "no-container-absolute", - path: "/foo/pod123", - want: "/foo/pod123", - }, - { - name: "double-pod", - path: "/foo/podium/pod123/container", - want: "/foo/podium/pod123", - }, - { - name: "start-pod", - path: "pod123/container", - want: "pod123", - }, - { - name: "start-pod-absolute", - path: "/pod123/container", - want: "/pod123", - }, - { - name: "slashes", - path: "///foo/////pod123//////container", - want: "/foo/pod123", - }, - { - name: "no-pod", - path: "/foo/nopod123/container", - want: "/foo/nopod123/container", - }, - } { - t.Run(tc.name, func(t *testing.T) { - spec := specs.Spec{ - Linux: &specs.Linux{ - CgroupsPath: tc.path, - }, - } - updated := updateCgroup(&spec) - if spec.Linux.CgroupsPath != tc.want { - t.Errorf("updateCgroup(%q), want: %q, got: %q", tc.path, tc.want, spec.Linux.CgroupsPath) - } - if shouldUpdate := tc.path != tc.want; shouldUpdate != updated { - t.Errorf("updateCgroup(%q)=%v, want: %v", tc.path, updated, shouldUpdate) - } - }) - } -} - -// Test cases that cgroup path should not be updated. -func TestCgroupNoUpdate(t *testing.T) { - for _, tc := range []struct { - name string - spec *specs.Spec - }{ - { - name: "empty", - spec: &specs.Spec{}, - }, - { - name: "subcontainer", - spec: &specs.Spec{ - Linux: &specs.Linux{ - CgroupsPath: "foo/pod123/container", - }, - Annotations: map[string]string{ - utils.ContainerTypeAnnotation: utils.ContainerTypeContainer, - }, - }, - }, - } { - t.Run(tc.name, func(t *testing.T) { - if updated := updateCgroup(tc.spec); updated { - t.Errorf("updateCgroup(%+v), got: %v, want: false", tc.spec.Linux, updated) - } - }) - } -} diff --git a/pkg/shim/shim_linux_state_autogen.go b/pkg/shim/shim_linux_state_autogen.go new file mode 100644 index 000000000..c3dd06dfe --- /dev/null +++ b/pkg/shim/shim_linux_state_autogen.go @@ -0,0 +1,6 @@ +// automatically generated by stateify. + +//go:build linux +// +build linux + +package shim diff --git a/pkg/shim/shim_state_autogen.go b/pkg/shim/shim_state_autogen.go new file mode 100644 index 000000000..c3dd06dfe --- /dev/null +++ b/pkg/shim/shim_state_autogen.go @@ -0,0 +1,6 @@ +// automatically generated by stateify. + +//go:build linux +// +build linux + +package shim diff --git a/pkg/shim/utils/BUILD b/pkg/shim/utils/BUILD deleted file mode 100644 index 2eb82f63c..000000000 --- a/pkg/shim/utils/BUILD +++ /dev/null @@ -1,37 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "utils", - srcs = [ - "annotations.go", - "errors.go", - "utils.go", - "volumes.go", - ], - visibility = [ - "//pkg/shim:__subpackages__", - "//shim:__subpackages__", - ], - deps = [ - "@com_github_containerd_containerd//errdefs:go_default_library", - "@com_github_opencontainers_runtime_spec//specs-go:go_default_library", - "@org_golang_google_grpc//codes:go_default_library", - "@org_golang_google_grpc//status:go_default_library", - ], -) - -go_test( - name = "utils_test", - size = "small", - srcs = [ - "errors_test.go", - "volumes_test.go", - ], - library = ":utils", - deps = [ - "@com_github_containerd_containerd//errdefs:go_default_library", - "@com_github_opencontainers_runtime_spec//specs-go:go_default_library", - ], -) diff --git a/pkg/shim/utils/errors_test.go b/pkg/shim/utils/errors_test.go deleted file mode 100644 index 0a8fe34c8..000000000 --- a/pkg/shim/utils/errors_test.go +++ /dev/null @@ -1,50 +0,0 @@ -// Copyright 2021 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package utils - -import ( - "fmt" - "testing" - - "github.com/containerd/containerd/errdefs" -) - -func TestGRPCRoundTripsErrors(t *testing.T) { - for _, tc := range []struct { - name string - err error - test func(err error) bool - }{ - { - name: "passthrough", - err: errdefs.ErrNotFound, - test: errdefs.IsNotFound, - }, - { - name: "wrapped", - err: fmt.Errorf("oh no: %w", errdefs.ErrNotFound), - test: errdefs.IsNotFound, - }, - } { - t.Run(tc.name, func(t *testing.T) { - if err := errdefs.FromGRPC(ErrToGRPC(tc.err)); !tc.test(err) { - t.Errorf("errToGRPC got %+v", err) - } - if err := errdefs.FromGRPC(ErrToGRPCf(tc.err, "testing %s", "123")); !tc.test(err) { - t.Errorf("errToGRPCf got %+v", err) - } - }) - } -} diff --git a/pkg/shim/utils/utils_state_autogen.go b/pkg/shim/utils/utils_state_autogen.go new file mode 100644 index 000000000..dba8bfb1a --- /dev/null +++ b/pkg/shim/utils/utils_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package utils diff --git a/pkg/shim/utils/volumes_test.go b/pkg/shim/utils/volumes_test.go deleted file mode 100644 index 5db43cdf1..000000000 --- a/pkg/shim/utils/volumes_test.go +++ /dev/null @@ -1,330 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package utils - -import ( - "fmt" - "io/ioutil" - "os" - "reflect" - "testing" - - specs "github.com/opencontainers/runtime-spec/specs-go" -) - -func TestUpdateVolumeAnnotations(t *testing.T) { - dir, err := ioutil.TempDir("", "test-update-volume-annotations") - if err != nil { - t.Fatalf("create tempdir: %v", err) - } - defer os.RemoveAll(dir) - kubeletPodsDir = dir - - const ( - testPodUID = "testuid" - testVolumeName = "testvolume" - testLogDirPath = "/var/log/pods/testns_testname_" + testPodUID - testLegacyLogDirPath = "/var/log/pods/" + testPodUID - ) - testVolumePath := fmt.Sprintf("%s/%s/volumes/kubernetes.io~empty-dir/%s", dir, testPodUID, testVolumeName) - - if err := os.MkdirAll(testVolumePath, 0755); err != nil { - t.Fatalf("Create test volume: %v", err) - } - - for _, test := range []struct { - name string - spec *specs.Spec - expected *specs.Spec - expectErr bool - expectUpdate bool - }{ - { - name: "volume annotations for sandbox", - spec: &specs.Spec{ - Annotations: map[string]string{ - sandboxLogDirAnnotation: testLogDirPath, - ContainerTypeAnnotation: containerTypeSandbox, - volumeKeyPrefix + testVolumeName + ".share": "pod", - volumeKeyPrefix + testVolumeName + ".type": "tmpfs", - volumeKeyPrefix + testVolumeName + ".options": "ro", - }, - }, - expected: &specs.Spec{ - Annotations: map[string]string{ - sandboxLogDirAnnotation: testLogDirPath, - ContainerTypeAnnotation: containerTypeSandbox, - volumeKeyPrefix + testVolumeName + ".share": "pod", - volumeKeyPrefix + testVolumeName + ".type": "tmpfs", - volumeKeyPrefix + testVolumeName + ".options": "ro", - volumeKeyPrefix + testVolumeName + ".source": testVolumePath, - }, - }, - expectUpdate: true, - }, - { - name: "volume annotations for sandbox with legacy log path", - spec: &specs.Spec{ - Annotations: map[string]string{ - sandboxLogDirAnnotation: testLegacyLogDirPath, - ContainerTypeAnnotation: containerTypeSandbox, - volumeKeyPrefix + testVolumeName + ".share": "pod", - volumeKeyPrefix + testVolumeName + ".type": "tmpfs", - volumeKeyPrefix + testVolumeName + ".options": "ro", - }, - }, - expected: &specs.Spec{ - Annotations: map[string]string{ - sandboxLogDirAnnotation: testLegacyLogDirPath, - ContainerTypeAnnotation: containerTypeSandbox, - volumeKeyPrefix + testVolumeName + ".share": "pod", - volumeKeyPrefix + testVolumeName + ".type": "tmpfs", - volumeKeyPrefix + testVolumeName + ".options": "ro", - volumeKeyPrefix + testVolumeName + ".source": testVolumePath, - }, - }, - expectUpdate: true, - }, - { - name: "tmpfs: volume annotations for container", - spec: &specs.Spec{ - Mounts: []specs.Mount{ - { - Destination: "/test", - Type: "bind", - Source: testVolumePath, - Options: []string{"ro"}, - }, - { - Destination: "/random", - Type: "bind", - Source: "/random", - Options: []string{"ro"}, - }, - }, - Annotations: map[string]string{ - ContainerTypeAnnotation: ContainerTypeContainer, - volumeKeyPrefix + testVolumeName + ".share": "pod", - volumeKeyPrefix + testVolumeName + ".type": "tmpfs", - volumeKeyPrefix + testVolumeName + ".options": "ro", - }, - }, - expected: &specs.Spec{ - Mounts: []specs.Mount{ - { - Destination: "/test", - Type: "tmpfs", - Source: testVolumePath, - Options: []string{"ro"}, - }, - { - Destination: "/random", - Type: "bind", - Source: "/random", - Options: []string{"ro"}, - }, - }, - Annotations: map[string]string{ - ContainerTypeAnnotation: ContainerTypeContainer, - volumeKeyPrefix + testVolumeName + ".share": "pod", - volumeKeyPrefix + testVolumeName + ".type": "tmpfs", - volumeKeyPrefix + testVolumeName + ".options": "ro", - }, - }, - expectUpdate: true, - }, - { - name: "bind: volume annotations for container", - spec: &specs.Spec{ - Mounts: []specs.Mount{ - { - Destination: "/test", - Type: "bind", - Source: testVolumePath, - Options: []string{"ro"}, - }, - }, - Annotations: map[string]string{ - ContainerTypeAnnotation: ContainerTypeContainer, - volumeKeyPrefix + testVolumeName + ".share": "container", - volumeKeyPrefix + testVolumeName + ".type": "bind", - volumeKeyPrefix + testVolumeName + ".options": "ro", - }, - }, - expected: &specs.Spec{ - Mounts: []specs.Mount{ - { - Destination: "/test", - Type: "bind", - Source: testVolumePath, - Options: []string{"ro"}, - }, - }, - Annotations: map[string]string{ - ContainerTypeAnnotation: ContainerTypeContainer, - volumeKeyPrefix + testVolumeName + ".share": "container", - volumeKeyPrefix + testVolumeName + ".type": "bind", - volumeKeyPrefix + testVolumeName + ".options": "ro", - }, - }, - expectUpdate: true, - }, - { - name: "should not return error without pod log directory", - spec: &specs.Spec{ - Annotations: map[string]string{ - ContainerTypeAnnotation: containerTypeSandbox, - volumeKeyPrefix + testVolumeName + ".share": "pod", - volumeKeyPrefix + testVolumeName + ".type": "tmpfs", - volumeKeyPrefix + testVolumeName + ".options": "ro", - }, - }, - expected: &specs.Spec{ - Annotations: map[string]string{ - ContainerTypeAnnotation: containerTypeSandbox, - volumeKeyPrefix + testVolumeName + ".share": "pod", - volumeKeyPrefix + testVolumeName + ".type": "tmpfs", - volumeKeyPrefix + testVolumeName + ".options": "ro", - }, - }, - }, - { - name: "should return error if volume path does not exist", - spec: &specs.Spec{ - Annotations: map[string]string{ - sandboxLogDirAnnotation: testLogDirPath, - ContainerTypeAnnotation: containerTypeSandbox, - volumeKeyPrefix + "notexist.share": "pod", - volumeKeyPrefix + "notexist.type": "tmpfs", - volumeKeyPrefix + "notexist.options": "ro", - }, - }, - expectErr: true, - }, - { - name: "no volume annotations for sandbox", - spec: &specs.Spec{ - Annotations: map[string]string{ - sandboxLogDirAnnotation: testLogDirPath, - ContainerTypeAnnotation: containerTypeSandbox, - }, - }, - expected: &specs.Spec{ - Annotations: map[string]string{ - sandboxLogDirAnnotation: testLogDirPath, - ContainerTypeAnnotation: containerTypeSandbox, - }, - }, - }, - { - name: "no volume annotations for container", - spec: &specs.Spec{ - Mounts: []specs.Mount{ - { - Destination: "/test", - Type: "bind", - Source: "/test", - Options: []string{"ro"}, - }, - { - Destination: "/random", - Type: "bind", - Source: "/random", - Options: []string{"ro"}, - }, - }, - Annotations: map[string]string{ - ContainerTypeAnnotation: ContainerTypeContainer, - }, - }, - expected: &specs.Spec{ - Mounts: []specs.Mount{ - { - Destination: "/test", - Type: "bind", - Source: "/test", - Options: []string{"ro"}, - }, - { - Destination: "/random", - Type: "bind", - Source: "/random", - Options: []string{"ro"}, - }, - }, - Annotations: map[string]string{ - ContainerTypeAnnotation: ContainerTypeContainer, - }, - }, - }, - { - name: "bind options removed", - spec: &specs.Spec{ - Annotations: map[string]string{ - ContainerTypeAnnotation: ContainerTypeContainer, - volumeKeyPrefix + testVolumeName + ".share": "pod", - volumeKeyPrefix + testVolumeName + ".type": "tmpfs", - volumeKeyPrefix + testVolumeName + ".options": "ro", - volumeKeyPrefix + testVolumeName + ".source": testVolumePath, - }, - Mounts: []specs.Mount{ - { - Destination: "/dst", - Type: "bind", - Source: testVolumePath, - Options: []string{"ro", "bind", "rbind"}, - }, - }, - }, - expected: &specs.Spec{ - Annotations: map[string]string{ - ContainerTypeAnnotation: ContainerTypeContainer, - volumeKeyPrefix + testVolumeName + ".share": "pod", - volumeKeyPrefix + testVolumeName + ".type": "tmpfs", - volumeKeyPrefix + testVolumeName + ".options": "ro", - volumeKeyPrefix + testVolumeName + ".source": testVolumePath, - }, - Mounts: []specs.Mount{ - { - Destination: "/dst", - Type: "tmpfs", - Source: testVolumePath, - Options: []string{"ro"}, - }, - }, - }, - expectUpdate: true, - }, - } { - t.Run(test.name, func(t *testing.T) { - updated, err := UpdateVolumeAnnotations(test.spec) - if test.expectErr { - if err == nil { - t.Fatal("Expected error, but got nil") - } - return - } - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - if !reflect.DeepEqual(test.expected, test.spec) { - t.Fatalf("Expected %+v, got %+v", test.expected, test.spec) - } - if test.expectUpdate != updated { - t.Errorf("Expected %v, got %v", test.expected, updated) - } - }) - } -} diff --git a/pkg/sleep/BUILD b/pkg/sleep/BUILD deleted file mode 100644 index 48bcdd62b..000000000 --- a/pkg/sleep/BUILD +++ /dev/null @@ -1,21 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "sleep", - srcs = [ - "sleep_unsafe.go", - ], - visibility = ["//:sandbox"], - deps = ["//pkg/sync"], -) - -go_test( - name = "sleep_test", - size = "medium", - srcs = [ - "sleep_test.go", - ], - library = ":sleep", -) diff --git a/pkg/sleep/sleep_test.go b/pkg/sleep/sleep_test.go deleted file mode 100644 index 1dd11707d..000000000 --- a/pkg/sleep/sleep_test.go +++ /dev/null @@ -1,567 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package sleep - -import ( - "math/rand" - "runtime" - "testing" - "time" -) - -// ZeroWakerNotAsserted tests that a zero-value waker is in non-asserted state. -func ZeroWakerNotAsserted(t *testing.T) { - var w Waker - if w.IsAsserted() { - t.Fatalf("Zero waker is asserted") - } - - if w.Clear() { - t.Fatalf("Zero waker is asserted") - } -} - -// AssertedWakerAfterAssert tests that a waker properly reports its state as -// asserted once its Assert() method is called. -func AssertedWakerAfterAssert(t *testing.T) { - var w Waker - w.Assert() - if !w.IsAsserted() { - t.Fatalf("Asserted waker is not reported as such") - } - - if !w.Clear() { - t.Fatalf("Asserted waker is not reported as such") - } -} - -// AssertedWakerAfterTwoAsserts tests that a waker properly reports its state as -// asserted once its Assert() method is called twice. -func AssertedWakerAfterTwoAsserts(t *testing.T) { - var w Waker - w.Assert() - w.Assert() - if !w.IsAsserted() { - t.Fatalf("Asserted waker is not reported as such") - } - - if !w.Clear() { - t.Fatalf("Asserted waker is not reported as such") - } -} - -// NotAssertedWakerWithSleeper tests that a waker properly reports its state as -// not asserted after a sleeper is associated with it. -func NotAssertedWakerWithSleeper(t *testing.T) { - var w Waker - var s Sleeper - s.AddWaker(&w, 0) - if w.IsAsserted() { - t.Fatalf("Non-asserted waker is reported as asserted") - } - - if w.Clear() { - t.Fatalf("Non-asserted waker is reported as asserted") - } -} - -// NotAssertedWakerAfterWake tests that a waker properly reports its state as -// not asserted after a previous assert is consumed by a sleeper. That is, tests -// the "edge-triggered" behavior. -func NotAssertedWakerAfterWake(t *testing.T) { - var w Waker - var s Sleeper - s.AddWaker(&w, 0) - w.Assert() - s.Fetch(true) - if w.IsAsserted() { - t.Fatalf("Consumed waker is reported as asserted") - } - - if w.Clear() { - t.Fatalf("Consumed waker is reported as asserted") - } -} - -// AssertedWakerBeforeAdd tests that a waker causes a sleeper to not sleep if -// it's already asserted before being added. -func AssertedWakerBeforeAdd(t *testing.T) { - var w Waker - var s Sleeper - w.Assert() - s.AddWaker(&w, 0) - - if _, ok := s.Fetch(false); !ok { - t.Fatalf("Fetch failed even though asserted waker was added") - } -} - -// ClearedWaker tests that a waker properly reports its state as not asserted -// after it is cleared. -func ClearedWaker(t *testing.T) { - var w Waker - w.Assert() - w.Clear() - if w.IsAsserted() { - t.Fatalf("Cleared waker is reported as asserted") - } - - if w.Clear() { - t.Fatalf("Cleared waker is reported as asserted") - } -} - -// ClearedWakerWithSleeper tests that a waker properly reports its state as -// not asserted when it is cleared while it has a sleeper associated with it. -func ClearedWakerWithSleeper(t *testing.T) { - var w Waker - var s Sleeper - s.AddWaker(&w, 0) - w.Clear() - if w.IsAsserted() { - t.Fatalf("Cleared waker is reported as asserted") - } - - if w.Clear() { - t.Fatalf("Cleared waker is reported as asserted") - } -} - -// ClearedWakerAssertedWithSleeper tests that a waker properly reports its state -// as not asserted when it is cleared while it has a sleeper associated with it -// and has been asserted. -func ClearedWakerAssertedWithSleeper(t *testing.T) { - var w Waker - var s Sleeper - s.AddWaker(&w, 0) - w.Assert() - w.Clear() - if w.IsAsserted() { - t.Fatalf("Cleared waker is reported as asserted") - } - - if w.Clear() { - t.Fatalf("Cleared waker is reported as asserted") - } -} - -// TestBlock tests that a sleeper actually blocks waiting for the waker to -// assert its state. -func TestBlock(t *testing.T) { - var w Waker - var s Sleeper - - s.AddWaker(&w, 0) - - // Assert waker after one second. - before := time.Now() - go func() { - time.Sleep(1 * time.Second) - w.Assert() - }() - - // Fetch the result and make sure it took at least 500ms. - if _, ok := s.Fetch(true); !ok { - t.Fatalf("Fetch failed unexpectedly") - } - if d := time.Now().Sub(before); d < 500*time.Millisecond { - t.Fatalf("Duration was too short: %v", d) - } - - // Check that already-asserted waker completes inline. - w.Assert() - if _, ok := s.Fetch(true); !ok { - t.Fatalf("Fetch failed unexpectedly") - } - - // Check that fetch sleeps if waker had been asserted but was reset - // before Fetch is called. - w.Assert() - w.Clear() - before = time.Now() - go func() { - time.Sleep(1 * time.Second) - w.Assert() - }() - if _, ok := s.Fetch(true); !ok { - t.Fatalf("Fetch failed unexpectedly") - } - if d := time.Now().Sub(before); d < 500*time.Millisecond { - t.Fatalf("Duration was too short: %v", d) - } -} - -// TestNonBlock checks that a sleeper won't block if waker isn't asserted. -func TestNonBlock(t *testing.T) { - var w Waker - var s Sleeper - - // Don't block when there's no waker. - if _, ok := s.Fetch(false); ok { - t.Fatalf("Fetch succeeded when there is no waker") - } - - // Don't block when waker isn't asserted. - s.AddWaker(&w, 0) - if _, ok := s.Fetch(false); ok { - t.Fatalf("Fetch succeeded when waker was not asserted") - } - - // Don't block when waker was asserted, but isn't anymore. - w.Assert() - w.Clear() - if _, ok := s.Fetch(false); ok { - t.Fatalf("Fetch succeeded when waker was not asserted anymore") - } - - // Don't block when waker was consumed by previous Fetch(). - w.Assert() - if _, ok := s.Fetch(false); !ok { - t.Fatalf("Fetch failed even though waker was asserted") - } - - if _, ok := s.Fetch(false); ok { - t.Fatalf("Fetch succeeded when waker had been consumed") - } -} - -// TestMultiple checks that a sleeper can wait for and receives notifications -// from multiple wakers. -func TestMultiple(t *testing.T) { - s := Sleeper{} - w1 := Waker{} - w2 := Waker{} - - s.AddWaker(&w1, 0) - s.AddWaker(&w2, 1) - - w1.Assert() - w2.Assert() - - v, ok := s.Fetch(false) - if !ok { - t.Fatalf("Fetch failed when there are asserted wakers") - } - - if v != 0 && v != 1 { - t.Fatalf("Unexpected waker id: %v", v) - } - - want := 1 - v - v, ok = s.Fetch(false) - if !ok { - t.Fatalf("Fetch failed when there is an asserted waker") - } - - if v != want { - t.Fatalf("Unexpected waker id, got %v, want %v", v, want) - } -} - -// TestDoneFunction tests if calling Done() on a sleeper works properly. -func TestDoneFunction(t *testing.T) { - // Trivial case of no waker. - s := Sleeper{} - s.Done() - - // Cases when the sleeper has n wakers, but none are asserted. - for n := 1; n < 20; n++ { - s := Sleeper{} - w := make([]Waker, n) - for j := 0; j < n; j++ { - s.AddWaker(&w[j], j) - } - s.Done() - } - - // Cases when the sleeper has n wakers, and only the i-th one is - // asserted. - for n := 1; n < 20; n++ { - for i := 0; i < n; i++ { - s := Sleeper{} - w := make([]Waker, n) - for j := 0; j < n; j++ { - s.AddWaker(&w[j], j) - } - w[i].Assert() - s.Done() - } - } - - // Cases when the sleeper has n wakers, and the i-th one is asserted - // and cleared. - for n := 1; n < 20; n++ { - for i := 0; i < n; i++ { - s := Sleeper{} - w := make([]Waker, n) - for j := 0; j < n; j++ { - s.AddWaker(&w[j], j) - } - w[i].Assert() - w[i].Clear() - s.Done() - } - } - - // Cases when the sleeper has n wakers, with a random number of them - // asserted. - for n := 1; n < 20; n++ { - for iters := 0; iters < 1000; iters++ { - s := Sleeper{} - w := make([]Waker, n) - for j := 0; j < n; j++ { - s.AddWaker(&w[j], j) - } - - // Pick the number of asserted elements, then assert - // random wakers. - asserted := rand.Int() % (n + 1) - for j := 0; j < asserted; j++ { - w[rand.Int()%n].Assert() - } - s.Done() - } - } -} - -// TestRace tests that multiple wakers can continuously send wake requests to -// the sleeper. -func TestRace(t *testing.T) { - const wakers = 100 - const wakeRequests = 10000 - - counts := make([]int, wakers) - w := make([]Waker, wakers) - s := Sleeper{} - - // Associate each waker and start goroutines that will assert them. - for i := range w { - s.AddWaker(&w[i], i) - go func(w *Waker) { - n := 0 - for n < wakeRequests { - if !w.IsAsserted() { - w.Assert() - n++ - } else { - runtime.Gosched() - } - } - }(&w[i]) - } - - // Wait for all wake up notifications from all wakers. - for i := 0; i < wakers*wakeRequests; i++ { - v, _ := s.Fetch(true) - counts[v]++ - } - - // Check that we got the right number for each. - for i, v := range counts { - if v != wakeRequests { - t.Errorf("Waker %v only got %v wakes", i, v) - } - } -} - -// TestRaceInOrder tests that multiple wakers can continuously send wake requests to -// the sleeper and that the wakers are retrieved in the order asserted. -func TestRaceInOrder(t *testing.T) { - w := make([]Waker, 10000) - s := Sleeper{} - - // Associate each waker and start goroutines that will assert them. - for i := range w { - s.AddWaker(&w[i], i) - } - go func() { - for i := range w { - w[i].Assert() - } - }() - - // Wait for all wake up notifications from all wakers. - for want := range w { - got, _ := s.Fetch(true) - if got != want { - t.Fatalf("got %d want %d", got, want) - } - } -} - -// BenchmarkSleeperMultiSelect measures how long it takes to fetch a wake up -// from 4 wakers when at least one is already asserted. -func BenchmarkSleeperMultiSelect(b *testing.B) { - const count = 4 - s := Sleeper{} - w := make([]Waker, count) - for i := range w { - s.AddWaker(&w[i], i) - } - - b.ResetTimer() - for i := 0; i < b.N; i++ { - w[count-1].Assert() - s.Fetch(true) - } -} - -// BenchmarkGoMultiSelect measures how long it takes to fetch a zero-length -// struct from one of 4 channels when at least one is ready. -func BenchmarkGoMultiSelect(b *testing.B) { - const count = 4 - ch := make([]chan struct{}, count) - for i := range ch { - ch[i] = make(chan struct{}, 1) - } - - b.ResetTimer() - for i := 0; i < b.N; i++ { - ch[count-1] <- struct{}{} - select { - case <-ch[0]: - case <-ch[1]: - case <-ch[2]: - case <-ch[3]: - } - } -} - -// BenchmarkSleeperSingleSelect measures how long it takes to fetch a wake up -// from one waker that is already asserted. -func BenchmarkSleeperSingleSelect(b *testing.B) { - s := Sleeper{} - w := Waker{} - s.AddWaker(&w, 0) - - b.ResetTimer() - for i := 0; i < b.N; i++ { - w.Assert() - s.Fetch(true) - } -} - -// BenchmarkGoSingleSelect measures how long it takes to fetch a zero-length -// struct from a channel that already has it buffered. -func BenchmarkGoSingleSelect(b *testing.B) { - ch := make(chan struct{}, 1) - - b.ResetTimer() - for i := 0; i < b.N; i++ { - ch <- struct{}{} - <-ch - } -} - -// BenchmarkSleeperAssertNonWaiting measures how long it takes to assert a -// channel that is already asserted. -func BenchmarkSleeperAssertNonWaiting(b *testing.B) { - w := Waker{} - w.Assert() - for i := 0; i < b.N; i++ { - w.Assert() - } - -} - -// BenchmarkGoAssertNonWaiting measures how long it takes to write to a channel -// that has already something written to it. -func BenchmarkGoAssertNonWaiting(b *testing.B) { - ch := make(chan struct{}, 1) - ch <- struct{}{} - for i := 0; i < b.N; i++ { - select { - case ch <- struct{}{}: - default: - } - } -} - -// BenchmarkSleeperWaitOnSingleSelect measures how long it takes to wait on one -// waker channel while another goroutine wakes up the sleeper. This assumes that -// a new goroutine doesn't run immediately (i.e., the creator of a new goroutine -// is allowed to go to sleep before the new goroutine has a chance to run). -func BenchmarkSleeperWaitOnSingleSelect(b *testing.B) { - s := Sleeper{} - w := Waker{} - s.AddWaker(&w, 0) - for i := 0; i < b.N; i++ { - go func() { - w.Assert() - }() - s.Fetch(true) - } - -} - -// BenchmarkGoWaitOnSingleSelect measures how long it takes to wait on one -// channel while another goroutine wakes up the sleeper. This assumes that a new -// goroutine doesn't run immediately (i.e., the creator of a new goroutine is -// allowed to go to sleep before the new goroutine has a chance to run). -func BenchmarkGoWaitOnSingleSelect(b *testing.B) { - ch := make(chan struct{}, 1) - for i := 0; i < b.N; i++ { - go func() { - ch <- struct{}{} - }() - <-ch - } -} - -// BenchmarkSleeperWaitOnMultiSelect measures how long it takes to wait on 4 -// wakers while another goroutine wakes up the sleeper. This assumes that a new -// goroutine doesn't run immediately (i.e., the creator of a new goroutine is -// allowed to go to sleep before the new goroutine has a chance to run). -func BenchmarkSleeperWaitOnMultiSelect(b *testing.B) { - const count = 4 - s := Sleeper{} - w := make([]Waker, count) - for i := range w { - s.AddWaker(&w[i], i) - } - - b.ResetTimer() - for i := 0; i < b.N; i++ { - go func() { - w[count-1].Assert() - }() - s.Fetch(true) - } -} - -// BenchmarkGoWaitOnMultiSelect measures how long it takes to wait on 4 channels -// while another goroutine wakes up the sleeper. This assumes that a new -// goroutine doesn't run immediately (i.e., the creator of a new goroutine is -// allowed to go to sleep before the new goroutine has a chance to run). -func BenchmarkGoWaitOnMultiSelect(b *testing.B) { - const count = 4 - ch := make([]chan struct{}, count) - for i := range ch { - ch[i] = make(chan struct{}, 1) - } - - b.ResetTimer() - for i := 0; i < b.N; i++ { - go func() { - ch[count-1] <- struct{}{} - }() - select { - case <-ch[0]: - case <-ch[1]: - case <-ch[2]: - case <-ch[3]: - } - } -} diff --git a/pkg/sleep/sleep_unsafe_state_autogen.go b/pkg/sleep/sleep_unsafe_state_autogen.go new file mode 100644 index 000000000..bea86dcca --- /dev/null +++ b/pkg/sleep/sleep_unsafe_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package sleep diff --git a/pkg/state/BUILD b/pkg/state/BUILD deleted file mode 100644 index 92c51879b..000000000 --- a/pkg/state/BUILD +++ /dev/null @@ -1,86 +0,0 @@ -load("//tools:defs.bzl", "go_library") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -package(licenses = ["notice"]) - -go_template_instance( - name = "deferred_list", - out = "deferred_list.go", - package = "state", - prefix = "deferred", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*objectEncodeState", - "ElementMapper": "deferredMapper", - "Linker": "*deferredEntry", - }, -) - -go_template_instance( - name = "complete_list", - out = "complete_list.go", - package = "state", - prefix = "complete", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*objectDecodeState", - "Linker": "*objectDecodeState", - }, -) - -go_template_instance( - name = "addr_range", - out = "addr_range.go", - package = "state", - prefix = "addr", - template = "//pkg/segment:generic_range", - types = { - "T": "uintptr", - }, -) - -go_template_instance( - name = "addr_set", - out = "addr_set.go", - consts = { - "minDegree": "10", - }, - imports = { - "reflect": "reflect", - }, - package = "state", - prefix = "addr", - template = "//pkg/segment:generic_set", - types = { - "Key": "uintptr", - "Range": "addrRange", - "Value": "*objectEncodeState", - "Functions": "addrSetFunctions", - }, -) - -go_library( - name = "state", - srcs = [ - "addr_range.go", - "addr_set.go", - "complete_list.go", - "decode.go", - "decode_unsafe.go", - "deferred_list.go", - "encode.go", - "encode_unsafe.go", - "state.go", - "state_norace.go", - "state_race.go", - "stats.go", - "types.go", - ], - marshal = False, - stateify = False, - visibility = ["//:sandbox"], - deps = [ - "//pkg/log", - "//pkg/state/wire", - ], -) diff --git a/pkg/state/README.md b/pkg/state/README.md deleted file mode 100644 index 1aa401193..000000000 --- a/pkg/state/README.md +++ /dev/null @@ -1,158 +0,0 @@ -# State Encoding and Decoding - -The state package implements the encoding and decoding of data structures for -`go_stateify`. This package is designed for use cases other than the standard -encoding packages, e.g. `gob` and `json`. Principally: - -* This package operates on complex object graphs and accurately serializes and - restores all relationships. That is, you can have things like: intrusive - pointers, cycles, and pointer chains of arbitrary depths. These are not - handled appropriately by existing encoders. This is not an implementation - flaw: the formats themselves are not capable of representing these graphs, - as they can only generate directed trees. - -* This package allows installing order-dependent load callbacks and then - resolves that graph at load time, with cycle detection. Similarly, there is - no analogous feature possible in the standard encoders. - -* This package handles the resolution of interfaces, based on a registered - type name. For interface objects type information is saved in the serialized - format. This is generally true for `gob` as well, but it works differently. - -Here's an overview of how encoding and decoding works. - -## Encoding - -Encoding produces a `statefile`, which contains a list of chunks of the form -`(header, payload)`. The payload can either be some raw data, or a series of -encoded wire objects representing some object graph. All encoded objects are -defined in the `wire` subpackage. - -Encoding of an object graph begins with `encodeState.Save`. - -### 1. Memory Map & Encoding - -To discover relationships between potentially interdependent data structures -(for example, a struct may contain pointers to members of other data -structures), the encoder first walks the object graph and constructs a memory -map of the objects in the input graph. As this walk progresses, objects are -queued in the `pending` list and items are placed on the `deferred` list as they -are discovered. No single object will be encoded multiple times, but the -discovered relationships between objects may change as more parts of the overall -object graph are discovered. - -The encoder starts at the root object and recursively visits all reachable -objects, recording the address ranges containing the underlying data for each -object. This is stored as a segment set (`addrSet`), mapping address ranges to -the of the object occupying the range; see `encodeState.values`. Note that there -is special handling for zero-sized types and map objects during this process. - -Additionally, the encoder assigns each object a unique identifier which is used -to indicate relationships between objects in the statefile; see `objectID` in -`encode.go`. - -### 2. Type Serialization - -The enoder will subsequently serialize all information about discovered types, -including field names. These are used during decoding to reconcile these types -with other internally registered types. - -### 3. Object Serialization - -With a full address map, and all objects correctly encoded, all object encodings -are serialized. The assigned `objectID`s aren't explicitly encoded in the -statefile. The order of object messages in the stream determine their IDs. - -### Example - -Given the following data structure definitions: - -```go -type system struct { - o *outer - i *inner -} - -type outer struct { - a int64 - cn *container -} - -type container struct { - n uint64 - elem *inner -} - -type inner struct { - c container - x, y uint64 -} -``` - -Initialized like this: - -```go -o := outer{ - a: 10, - cn: nil, -} -i := inner{ - x: 20, - y: 30, - c: container{}, -} -s := system{ - o: &o, - i: &i, -} - -o.cn = &i.c -o.cn.elem = &i - -``` - -Encoding will produce an object stream like this: - -``` -g0r1 = struct{ - i: g0r3, - o: g0r2, -} -g0r2 = struct{ - a: 10, - cn: g0r3.c, -} -g0r3 = struct{ - c: struct{ - elem: g0r3, - n: 0u, - }, - x: 20u, - y: 30u, -} -``` - -Note how `g0r3.c` is correctly encoded as the underlying `container` object for -`inner.c`, and how the pointer from `outer.cn` points to it, despite `system.i` -being discovered after the pointer to it in `system.o.cn`. Also note that -decoding isn't strictly reliant on the order of encoded object stream, as long -as the relationship between objects are correctly encoded. - -## Decoding - -Decoding reads the statefile and reconstructs the object graph. Decoding begins -in `decodeState.Load`. Decoding is performed in a single pass over the object -stream in the statefile, and a subsequent pass over all deserialized objects is -done to fire off all loading callbacks in the correctly defined order. Note that -introducing cycles is possible here, but these are detected and an error will be -returned. - -Decoding is relatively straight forward. For most primitive values, the decoder -constructs an appropriate object and fills it with the values encoded in the -statefile. Pointers need special handling, as they must point to a value -allocated elsewhere. When values are constructed, the decoder indexes them by -their `objectID`s in `decodeState.objectsByID`. The target of pointers are -resolved by searching for the target in this index by their `objectID`; see -`decodeState.register`. For pointers to values inside another value (fields in a -pointer, elements of an array), the decoder uses the accessor path to walk to -the appropriate location; see `walkChild`. diff --git a/pkg/state/addr_range.go b/pkg/state/addr_range.go new file mode 100644 index 000000000..0b7346e47 --- /dev/null +++ b/pkg/state/addr_range.go @@ -0,0 +1,76 @@ +package state + +// A Range represents a contiguous range of T. +// +// +stateify savable +type addrRange struct { + // Start is the inclusive start of the range. + Start uintptr + + // End is the exclusive end of the range. + End uintptr +} + +// WellFormed returns true if r.Start <= r.End. All other methods on a Range +// require that the Range is well-formed. +// +//go:nosplit +func (r addrRange) WellFormed() bool { + return r.Start <= r.End +} + +// Length returns the length of the range. +// +//go:nosplit +func (r addrRange) Length() uintptr { + return r.End - r.Start +} + +// Contains returns true if r contains x. +// +//go:nosplit +func (r addrRange) Contains(x uintptr) bool { + return r.Start <= x && x < r.End +} + +// Overlaps returns true if r and r2 overlap. +// +//go:nosplit +func (r addrRange) Overlaps(r2 addrRange) bool { + return r.Start < r2.End && r2.Start < r.End +} + +// IsSupersetOf returns true if r is a superset of r2; that is, the range r2 is +// contained within r. +// +//go:nosplit +func (r addrRange) IsSupersetOf(r2 addrRange) bool { + return r.Start <= r2.Start && r.End >= r2.End +} + +// Intersect returns a range consisting of the intersection between r and r2. +// If r and r2 do not overlap, Intersect returns a range with unspecified +// bounds, but for which Length() == 0. +// +//go:nosplit +func (r addrRange) Intersect(r2 addrRange) addrRange { + if r.Start < r2.Start { + r.Start = r2.Start + } + if r.End > r2.End { + r.End = r2.End + } + if r.End < r.Start { + r.End = r.Start + } + return r +} + +// CanSplitAt returns true if it is legal to split a segment spanning the range +// r at x; that is, splitting at x would produce two ranges, both of which have +// non-zero length. +// +//go:nosplit +func (r addrRange) CanSplitAt(x uintptr) bool { + return r.Contains(x) && r.Start < x +} diff --git a/pkg/state/addr_set.go b/pkg/state/addr_set.go new file mode 100644 index 000000000..8abc3dd34 --- /dev/null +++ b/pkg/state/addr_set.go @@ -0,0 +1,1643 @@ +package state + +import ( + "bytes" + "fmt" +) + +// trackGaps is an optional parameter. +// +// If trackGaps is 1, the Set will track maximum gap size recursively, +// enabling the GapIterator.{Prev,Next}LargeEnoughGap functions. In this +// case, Key must be an unsigned integer. +// +// trackGaps must be 0 or 1. +const addrtrackGaps = 0 + +var _ = uint8(addrtrackGaps << 7) // Will fail if not zero or one. + +// dynamicGap is a type that disappears if trackGaps is 0. +type addrdynamicGap [addrtrackGaps]uintptr + +// Get returns the value of the gap. +// +// Precondition: trackGaps must be non-zero. +func (d *addrdynamicGap) Get() uintptr { + return d[:][0] +} + +// Set sets the value of the gap. +// +// Precondition: trackGaps must be non-zero. +func (d *addrdynamicGap) Set(v uintptr) { + d[:][0] = v +} + +const ( + // minDegree is the minimum degree of an internal node in a Set B-tree. + // + // - Any non-root node has at least minDegree-1 segments. + // + // - Any non-root internal (non-leaf) node has at least minDegree children. + // + // - The root node may have fewer than minDegree-1 segments, but it may + // only have 0 segments if the tree is empty. + // + // Our implementation requires minDegree >= 3. Higher values of minDegree + // usually improve performance, but increase memory usage for small sets. + addrminDegree = 10 + + addrmaxDegree = 2 * addrminDegree +) + +// A Set is a mapping of segments with non-overlapping Range keys. The zero +// value for a Set is an empty set. Set values are not safely movable nor +// copyable. Set is thread-compatible. +// +// +stateify savable +type addrSet struct { + root addrnode `state:".(*addrSegmentDataSlices)"` +} + +// IsEmpty returns true if the set contains no segments. +func (s *addrSet) IsEmpty() bool { + return s.root.nrSegments == 0 +} + +// IsEmptyRange returns true iff no segments in the set overlap the given +// range. This is semantically equivalent to s.SpanRange(r) == 0, but may be +// more efficient. +func (s *addrSet) IsEmptyRange(r addrRange) bool { + switch { + case r.Length() < 0: + panic(fmt.Sprintf("invalid range %v", r)) + case r.Length() == 0: + return true + } + _, gap := s.Find(r.Start) + if !gap.Ok() { + return false + } + return r.End <= gap.End() +} + +// Span returns the total size of all segments in the set. +func (s *addrSet) Span() uintptr { + var sz uintptr + for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + sz += seg.Range().Length() + } + return sz +} + +// SpanRange returns the total size of the intersection of segments in the set +// with the given range. +func (s *addrSet) SpanRange(r addrRange) uintptr { + switch { + case r.Length() < 0: + panic(fmt.Sprintf("invalid range %v", r)) + case r.Length() == 0: + return 0 + } + var sz uintptr + for seg := s.LowerBoundSegment(r.Start); seg.Ok() && seg.Start() < r.End; seg = seg.NextSegment() { + sz += seg.Range().Intersect(r).Length() + } + return sz +} + +// FirstSegment returns the first segment in the set. If the set is empty, +// FirstSegment returns a terminal iterator. +func (s *addrSet) FirstSegment() addrIterator { + if s.root.nrSegments == 0 { + return addrIterator{} + } + return s.root.firstSegment() +} + +// LastSegment returns the last segment in the set. If the set is empty, +// LastSegment returns a terminal iterator. +func (s *addrSet) LastSegment() addrIterator { + if s.root.nrSegments == 0 { + return addrIterator{} + } + return s.root.lastSegment() +} + +// FirstGap returns the first gap in the set. +func (s *addrSet) FirstGap() addrGapIterator { + n := &s.root + for n.hasChildren { + n = n.children[0] + } + return addrGapIterator{n, 0} +} + +// LastGap returns the last gap in the set. +func (s *addrSet) LastGap() addrGapIterator { + n := &s.root + for n.hasChildren { + n = n.children[n.nrSegments] + } + return addrGapIterator{n, n.nrSegments} +} + +// Find returns the segment or gap whose range contains the given key. If a +// segment is found, the returned Iterator is non-terminal and the +// returned GapIterator is terminal. Otherwise, the returned Iterator is +// terminal and the returned GapIterator is non-terminal. +func (s *addrSet) Find(key uintptr) (addrIterator, addrGapIterator) { + n := &s.root + for { + + lower := 0 + upper := n.nrSegments + for lower < upper { + i := lower + (upper-lower)/2 + if r := n.keys[i]; key < r.End { + if key >= r.Start { + return addrIterator{n, i}, addrGapIterator{} + } + upper = i + } else { + lower = i + 1 + } + } + i := lower + if !n.hasChildren { + return addrIterator{}, addrGapIterator{n, i} + } + n = n.children[i] + } +} + +// FindSegment returns the segment whose range contains the given key. If no +// such segment exists, FindSegment returns a terminal iterator. +func (s *addrSet) FindSegment(key uintptr) addrIterator { + seg, _ := s.Find(key) + return seg +} + +// LowerBoundSegment returns the segment with the lowest range that contains a +// key greater than or equal to min. If no such segment exists, +// LowerBoundSegment returns a terminal iterator. +func (s *addrSet) LowerBoundSegment(min uintptr) addrIterator { + seg, gap := s.Find(min) + if seg.Ok() { + return seg + } + return gap.NextSegment() +} + +// UpperBoundSegment returns the segment with the highest range that contains a +// key less than or equal to max. If no such segment exists, UpperBoundSegment +// returns a terminal iterator. +func (s *addrSet) UpperBoundSegment(max uintptr) addrIterator { + seg, gap := s.Find(max) + if seg.Ok() { + return seg + } + return gap.PrevSegment() +} + +// FindGap returns the gap containing the given key. If no such gap exists +// (i.e. the set contains a segment containing that key), FindGap returns a +// terminal iterator. +func (s *addrSet) FindGap(key uintptr) addrGapIterator { + _, gap := s.Find(key) + return gap +} + +// LowerBoundGap returns the gap with the lowest range that is greater than or +// equal to min. +func (s *addrSet) LowerBoundGap(min uintptr) addrGapIterator { + seg, gap := s.Find(min) + if gap.Ok() { + return gap + } + return seg.NextGap() +} + +// UpperBoundGap returns the gap with the highest range that is less than or +// equal to max. +func (s *addrSet) UpperBoundGap(max uintptr) addrGapIterator { + seg, gap := s.Find(max) + if gap.Ok() { + return gap + } + return seg.PrevGap() +} + +// Add inserts the given segment into the set and returns true. If the new +// segment can be merged with adjacent segments, Add will do so. If the new +// segment would overlap an existing segment, Add returns false. If Add +// succeeds, all existing iterators are invalidated. +func (s *addrSet) Add(r addrRange, val *objectEncodeState) bool { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + gap := s.FindGap(r.Start) + if !gap.Ok() { + return false + } + if r.End > gap.End() { + return false + } + s.Insert(gap, r, val) + return true +} + +// AddWithoutMerging inserts the given segment into the set and returns true. +// If it would overlap an existing segment, AddWithoutMerging does nothing and +// returns false. If AddWithoutMerging succeeds, all existing iterators are +// invalidated. +func (s *addrSet) AddWithoutMerging(r addrRange, val *objectEncodeState) bool { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + gap := s.FindGap(r.Start) + if !gap.Ok() { + return false + } + if r.End > gap.End() { + return false + } + s.InsertWithoutMergingUnchecked(gap, r, val) + return true +} + +// Insert inserts the given segment into the given gap. If the new segment can +// be merged with adjacent segments, Insert will do so. Insert returns an +// iterator to the segment containing the inserted value (which may have been +// merged with other values). All existing iterators (including gap, but not +// including the returned iterator) are invalidated. +// +// If the gap cannot accommodate the segment, or if r is invalid, Insert panics. +// +// Insert is semantically equivalent to a InsertWithoutMerging followed by a +// Merge, but may be more efficient. Note that there is no unchecked variant of +// Insert since Insert must retrieve and inspect gap's predecessor and +// successor segments regardless. +func (s *addrSet) Insert(gap addrGapIterator, r addrRange, val *objectEncodeState) addrIterator { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + prev, next := gap.PrevSegment(), gap.NextSegment() + if prev.Ok() && prev.End() > r.Start { + panic(fmt.Sprintf("new segment %v overlaps predecessor %v", r, prev.Range())) + } + if next.Ok() && next.Start() < r.End { + panic(fmt.Sprintf("new segment %v overlaps successor %v", r, next.Range())) + } + if prev.Ok() && prev.End() == r.Start { + if mval, ok := (addrSetFunctions{}).Merge(prev.Range(), prev.Value(), r, val); ok { + shrinkMaxGap := addrtrackGaps != 0 && gap.Range().Length() == gap.node.maxGap.Get() + prev.SetEndUnchecked(r.End) + prev.SetValue(mval) + if shrinkMaxGap { + gap.node.updateMaxGapLeaf() + } + if next.Ok() && next.Start() == r.End { + val = mval + if mval, ok := (addrSetFunctions{}).Merge(prev.Range(), val, next.Range(), next.Value()); ok { + prev.SetEndUnchecked(next.End()) + prev.SetValue(mval) + return s.Remove(next).PrevSegment() + } + } + return prev + } + } + if next.Ok() && next.Start() == r.End { + if mval, ok := (addrSetFunctions{}).Merge(r, val, next.Range(), next.Value()); ok { + shrinkMaxGap := addrtrackGaps != 0 && gap.Range().Length() == gap.node.maxGap.Get() + next.SetStartUnchecked(r.Start) + next.SetValue(mval) + if shrinkMaxGap { + gap.node.updateMaxGapLeaf() + } + return next + } + } + + return s.InsertWithoutMergingUnchecked(gap, r, val) +} + +// InsertWithoutMerging inserts the given segment into the given gap and +// returns an iterator to the inserted segment. All existing iterators +// (including gap, but not including the returned iterator) are invalidated. +// +// If the gap cannot accommodate the segment, or if r is invalid, +// InsertWithoutMerging panics. +func (s *addrSet) InsertWithoutMerging(gap addrGapIterator, r addrRange, val *objectEncodeState) addrIterator { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + if gr := gap.Range(); !gr.IsSupersetOf(r) { + panic(fmt.Sprintf("cannot insert segment range %v into gap range %v", r, gr)) + } + return s.InsertWithoutMergingUnchecked(gap, r, val) +} + +// InsertWithoutMergingUnchecked inserts the given segment into the given gap +// and returns an iterator to the inserted segment. All existing iterators +// (including gap, but not including the returned iterator) are invalidated. +// +// Preconditions: +// * r.Start >= gap.Start(). +// * r.End <= gap.End(). +func (s *addrSet) InsertWithoutMergingUnchecked(gap addrGapIterator, r addrRange, val *objectEncodeState) addrIterator { + gap = gap.node.rebalanceBeforeInsert(gap) + splitMaxGap := addrtrackGaps != 0 && (gap.node.nrSegments == 0 || gap.Range().Length() == gap.node.maxGap.Get()) + copy(gap.node.keys[gap.index+1:], gap.node.keys[gap.index:gap.node.nrSegments]) + copy(gap.node.values[gap.index+1:], gap.node.values[gap.index:gap.node.nrSegments]) + gap.node.keys[gap.index] = r + gap.node.values[gap.index] = val + gap.node.nrSegments++ + if splitMaxGap { + gap.node.updateMaxGapLeaf() + } + return addrIterator{gap.node, gap.index} +} + +// Remove removes the given segment and returns an iterator to the vacated gap. +// All existing iterators (including seg, but not including the returned +// iterator) are invalidated. +func (s *addrSet) Remove(seg addrIterator) addrGapIterator { + + if seg.node.hasChildren { + + victim := seg.PrevSegment() + + seg.SetRangeUnchecked(victim.Range()) + seg.SetValue(victim.Value()) + + nextAdjacentNode := seg.NextSegment().node + if addrtrackGaps != 0 { + nextAdjacentNode.updateMaxGapLeaf() + } + return s.Remove(victim).NextGap() + } + copy(seg.node.keys[seg.index:], seg.node.keys[seg.index+1:seg.node.nrSegments]) + copy(seg.node.values[seg.index:], seg.node.values[seg.index+1:seg.node.nrSegments]) + addrSetFunctions{}.ClearValue(&seg.node.values[seg.node.nrSegments-1]) + seg.node.nrSegments-- + if addrtrackGaps != 0 { + seg.node.updateMaxGapLeaf() + } + return seg.node.rebalanceAfterRemove(addrGapIterator{seg.node, seg.index}) +} + +// RemoveAll removes all segments from the set. All existing iterators are +// invalidated. +func (s *addrSet) RemoveAll() { + s.root = addrnode{} +} + +// RemoveRange removes all segments in the given range. An iterator to the +// newly formed gap is returned, and all existing iterators are invalidated. +func (s *addrSet) RemoveRange(r addrRange) addrGapIterator { + seg, gap := s.Find(r.Start) + if seg.Ok() { + seg = s.Isolate(seg, r) + gap = s.Remove(seg) + } + for seg = gap.NextSegment(); seg.Ok() && seg.Start() < r.End; seg = gap.NextSegment() { + seg = s.Isolate(seg, r) + gap = s.Remove(seg) + } + return gap +} + +// Merge attempts to merge two neighboring segments. If successful, Merge +// returns an iterator to the merged segment, and all existing iterators are +// invalidated. Otherwise, Merge returns a terminal iterator. +// +// If first is not the predecessor of second, Merge panics. +func (s *addrSet) Merge(first, second addrIterator) addrIterator { + if first.NextSegment() != second { + panic(fmt.Sprintf("attempt to merge non-neighboring segments %v, %v", first.Range(), second.Range())) + } + return s.MergeUnchecked(first, second) +} + +// MergeUnchecked attempts to merge two neighboring segments. If successful, +// MergeUnchecked returns an iterator to the merged segment, and all existing +// iterators are invalidated. Otherwise, MergeUnchecked returns a terminal +// iterator. +// +// Precondition: first is the predecessor of second: first.NextSegment() == +// second, first == second.PrevSegment(). +func (s *addrSet) MergeUnchecked(first, second addrIterator) addrIterator { + if first.End() == second.Start() { + if mval, ok := (addrSetFunctions{}).Merge(first.Range(), first.Value(), second.Range(), second.Value()); ok { + + first.SetEndUnchecked(second.End()) + first.SetValue(mval) + + return s.Remove(second).PrevSegment() + } + } + return addrIterator{} +} + +// MergeAll attempts to merge all adjacent segments in the set. All existing +// iterators are invalidated. +func (s *addrSet) MergeAll() { + seg := s.FirstSegment() + if !seg.Ok() { + return + } + next := seg.NextSegment() + for next.Ok() { + if mseg := s.MergeUnchecked(seg, next); mseg.Ok() { + seg, next = mseg, mseg.NextSegment() + } else { + seg, next = next, next.NextSegment() + } + } +} + +// MergeRange attempts to merge all adjacent segments that contain a key in the +// specific range. All existing iterators are invalidated. +func (s *addrSet) MergeRange(r addrRange) { + seg := s.LowerBoundSegment(r.Start) + if !seg.Ok() { + return + } + next := seg.NextSegment() + for next.Ok() && next.Range().Start < r.End { + if mseg := s.MergeUnchecked(seg, next); mseg.Ok() { + seg, next = mseg, mseg.NextSegment() + } else { + seg, next = next, next.NextSegment() + } + } +} + +// MergeAdjacent attempts to merge the segment containing r.Start with its +// predecessor, and the segment containing r.End-1 with its successor. +func (s *addrSet) MergeAdjacent(r addrRange) { + first := s.FindSegment(r.Start) + if first.Ok() { + if prev := first.PrevSegment(); prev.Ok() { + s.Merge(prev, first) + } + } + last := s.FindSegment(r.End - 1) + if last.Ok() { + if next := last.NextSegment(); next.Ok() { + s.Merge(last, next) + } + } +} + +// Split splits the given segment at the given key and returns iterators to the +// two resulting segments. All existing iterators (including seg, but not +// including the returned iterators) are invalidated. +// +// If the segment cannot be split at split (because split is at the start or +// end of the segment's range, so splitting would produce a segment with zero +// length, or because split falls outside the segment's range altogether), +// Split panics. +func (s *addrSet) Split(seg addrIterator, split uintptr) (addrIterator, addrIterator) { + if !seg.Range().CanSplitAt(split) { + panic(fmt.Sprintf("can't split %v at %v", seg.Range(), split)) + } + return s.SplitUnchecked(seg, split) +} + +// SplitUnchecked splits the given segment at the given key and returns +// iterators to the two resulting segments. All existing iterators (including +// seg, but not including the returned iterators) are invalidated. +// +// Preconditions: seg.Start() < key < seg.End(). +func (s *addrSet) SplitUnchecked(seg addrIterator, split uintptr) (addrIterator, addrIterator) { + val1, val2 := (addrSetFunctions{}).Split(seg.Range(), seg.Value(), split) + end2 := seg.End() + seg.SetEndUnchecked(split) + seg.SetValue(val1) + seg2 := s.InsertWithoutMergingUnchecked(seg.NextGap(), addrRange{split, end2}, val2) + + return seg2.PrevSegment(), seg2 +} + +// SplitAt splits the segment straddling split, if one exists. SplitAt returns +// true if a segment was split and false otherwise. If SplitAt splits a +// segment, all existing iterators are invalidated. +func (s *addrSet) SplitAt(split uintptr) bool { + if seg := s.FindSegment(split); seg.Ok() && seg.Range().CanSplitAt(split) { + s.SplitUnchecked(seg, split) + return true + } + return false +} + +// Isolate ensures that the given segment's range does not escape r by +// splitting at r.Start and r.End if necessary, and returns an updated iterator +// to the bounded segment. All existing iterators (including seg, but not +// including the returned iterators) are invalidated. +func (s *addrSet) Isolate(seg addrIterator, r addrRange) addrIterator { + if seg.Range().CanSplitAt(r.Start) { + _, seg = s.SplitUnchecked(seg, r.Start) + } + if seg.Range().CanSplitAt(r.End) { + seg, _ = s.SplitUnchecked(seg, r.End) + } + return seg +} + +// ApplyContiguous applies a function to a contiguous range of segments, +// splitting if necessary. The function is applied until the first gap is +// encountered, at which point the gap is returned. If the function is applied +// across the entire range, a terminal gap is returned. All existing iterators +// are invalidated. +// +// N.B. The Iterator must not be invalidated by the function. +func (s *addrSet) ApplyContiguous(r addrRange, fn func(seg addrIterator)) addrGapIterator { + seg, gap := s.Find(r.Start) + if !seg.Ok() { + return gap + } + for { + seg = s.Isolate(seg, r) + fn(seg) + if seg.End() >= r.End { + return addrGapIterator{} + } + gap = seg.NextGap() + if !gap.IsEmpty() { + return gap + } + seg = gap.NextSegment() + if !seg.Ok() { + + return addrGapIterator{} + } + } +} + +// +stateify savable +type addrnode struct { + // An internal binary tree node looks like: + // + // K + // / \ + // Cl Cr + // + // where all keys in the subtree rooted by Cl (the left subtree) are less + // than K (the key of the parent node), and all keys in the subtree rooted + // by Cr (the right subtree) are greater than K. + // + // An internal B-tree node's indexes work out to look like: + // + // K0 K1 K2 ... Kn-1 + // / \/ \/ \ ... / \ + // C0 C1 C2 C3 ... Cn-1 Cn + // + // where n is nrSegments. + nrSegments int + + // parent is a pointer to this node's parent. If this node is root, parent + // is nil. + parent *addrnode + + // parentIndex is the index of this node in parent.children. + parentIndex int + + // Flag for internal nodes that is technically redundant with "children[0] + // != nil", but is stored in the first cache line. "hasChildren" rather + // than "isLeaf" because false must be the correct value for an empty root. + hasChildren bool + + // The longest gap within this node. If the node is a leaf, it's simply the + // maximum gap among all the (nrSegments+1) gaps formed by its nrSegments keys + // including the 0th and nrSegments-th gap possibly shared with its upper-level + // nodes; if it's a non-leaf node, it's the max of all children's maxGap. + maxGap addrdynamicGap + + // Nodes store keys and values in separate arrays to maximize locality in + // the common case (scanning keys for lookup). + keys [addrmaxDegree - 1]addrRange + values [addrmaxDegree - 1]*objectEncodeState + children [addrmaxDegree]*addrnode +} + +// firstSegment returns the first segment in the subtree rooted by n. +// +// Preconditions: n.nrSegments != 0. +func (n *addrnode) firstSegment() addrIterator { + for n.hasChildren { + n = n.children[0] + } + return addrIterator{n, 0} +} + +// lastSegment returns the last segment in the subtree rooted by n. +// +// Preconditions: n.nrSegments != 0. +func (n *addrnode) lastSegment() addrIterator { + for n.hasChildren { + n = n.children[n.nrSegments] + } + return addrIterator{n, n.nrSegments - 1} +} + +func (n *addrnode) prevSibling() *addrnode { + if n.parent == nil || n.parentIndex == 0 { + return nil + } + return n.parent.children[n.parentIndex-1] +} + +func (n *addrnode) nextSibling() *addrnode { + if n.parent == nil || n.parentIndex == n.parent.nrSegments { + return nil + } + return n.parent.children[n.parentIndex+1] +} + +// rebalanceBeforeInsert splits n and its ancestors if they are full, as +// required for insertion, and returns an updated iterator to the position +// represented by gap. +func (n *addrnode) rebalanceBeforeInsert(gap addrGapIterator) addrGapIterator { + if n.nrSegments < addrmaxDegree-1 { + return gap + } + if n.parent != nil { + gap = n.parent.rebalanceBeforeInsert(gap) + } + if n.parent == nil { + + left := &addrnode{ + nrSegments: addrminDegree - 1, + parent: n, + parentIndex: 0, + hasChildren: n.hasChildren, + } + right := &addrnode{ + nrSegments: addrminDegree - 1, + parent: n, + parentIndex: 1, + hasChildren: n.hasChildren, + } + copy(left.keys[:addrminDegree-1], n.keys[:addrminDegree-1]) + copy(left.values[:addrminDegree-1], n.values[:addrminDegree-1]) + copy(right.keys[:addrminDegree-1], n.keys[addrminDegree:]) + copy(right.values[:addrminDegree-1], n.values[addrminDegree:]) + n.keys[0], n.values[0] = n.keys[addrminDegree-1], n.values[addrminDegree-1] + addrzeroValueSlice(n.values[1:]) + if n.hasChildren { + copy(left.children[:addrminDegree], n.children[:addrminDegree]) + copy(right.children[:addrminDegree], n.children[addrminDegree:]) + addrzeroNodeSlice(n.children[2:]) + for i := 0; i < addrminDegree; i++ { + left.children[i].parent = left + left.children[i].parentIndex = i + right.children[i].parent = right + right.children[i].parentIndex = i + } + } + n.nrSegments = 1 + n.hasChildren = true + n.children[0] = left + n.children[1] = right + + if addrtrackGaps != 0 { + left.updateMaxGapLocal() + right.updateMaxGapLocal() + } + if gap.node != n { + return gap + } + if gap.index < addrminDegree { + return addrGapIterator{left, gap.index} + } + return addrGapIterator{right, gap.index - addrminDegree} + } + + copy(n.parent.keys[n.parentIndex+1:], n.parent.keys[n.parentIndex:n.parent.nrSegments]) + copy(n.parent.values[n.parentIndex+1:], n.parent.values[n.parentIndex:n.parent.nrSegments]) + n.parent.keys[n.parentIndex], n.parent.values[n.parentIndex] = n.keys[addrminDegree-1], n.values[addrminDegree-1] + copy(n.parent.children[n.parentIndex+2:], n.parent.children[n.parentIndex+1:n.parent.nrSegments+1]) + for i := n.parentIndex + 2; i < n.parent.nrSegments+2; i++ { + n.parent.children[i].parentIndex = i + } + sibling := &addrnode{ + nrSegments: addrminDegree - 1, + parent: n.parent, + parentIndex: n.parentIndex + 1, + hasChildren: n.hasChildren, + } + n.parent.children[n.parentIndex+1] = sibling + n.parent.nrSegments++ + copy(sibling.keys[:addrminDegree-1], n.keys[addrminDegree:]) + copy(sibling.values[:addrminDegree-1], n.values[addrminDegree:]) + addrzeroValueSlice(n.values[addrminDegree-1:]) + if n.hasChildren { + copy(sibling.children[:addrminDegree], n.children[addrminDegree:]) + addrzeroNodeSlice(n.children[addrminDegree:]) + for i := 0; i < addrminDegree; i++ { + sibling.children[i].parent = sibling + sibling.children[i].parentIndex = i + } + } + n.nrSegments = addrminDegree - 1 + + if addrtrackGaps != 0 { + n.updateMaxGapLocal() + sibling.updateMaxGapLocal() + } + + if gap.node != n { + return gap + } + if gap.index < addrminDegree { + return gap + } + return addrGapIterator{sibling, gap.index - addrminDegree} +} + +// rebalanceAfterRemove "unsplits" n and its ancestors if they are deficient +// (contain fewer segments than required by B-tree invariants), as required for +// removal, and returns an updated iterator to the position represented by gap. +// +// Precondition: n is the only node in the tree that may currently violate a +// B-tree invariant. +func (n *addrnode) rebalanceAfterRemove(gap addrGapIterator) addrGapIterator { + for { + if n.nrSegments >= addrminDegree-1 { + return gap + } + if n.parent == nil { + + return gap + } + + if sibling := n.prevSibling(); sibling != nil && sibling.nrSegments >= addrminDegree { + copy(n.keys[1:], n.keys[:n.nrSegments]) + copy(n.values[1:], n.values[:n.nrSegments]) + n.keys[0] = n.parent.keys[n.parentIndex-1] + n.values[0] = n.parent.values[n.parentIndex-1] + n.parent.keys[n.parentIndex-1] = sibling.keys[sibling.nrSegments-1] + n.parent.values[n.parentIndex-1] = sibling.values[sibling.nrSegments-1] + addrSetFunctions{}.ClearValue(&sibling.values[sibling.nrSegments-1]) + if n.hasChildren { + copy(n.children[1:], n.children[:n.nrSegments+1]) + n.children[0] = sibling.children[sibling.nrSegments] + sibling.children[sibling.nrSegments] = nil + n.children[0].parent = n + n.children[0].parentIndex = 0 + for i := 1; i < n.nrSegments+2; i++ { + n.children[i].parentIndex = i + } + } + n.nrSegments++ + sibling.nrSegments-- + + if addrtrackGaps != 0 { + n.updateMaxGapLocal() + sibling.updateMaxGapLocal() + } + if gap.node == sibling && gap.index == sibling.nrSegments { + return addrGapIterator{n, 0} + } + if gap.node == n { + return addrGapIterator{n, gap.index + 1} + } + return gap + } + if sibling := n.nextSibling(); sibling != nil && sibling.nrSegments >= addrminDegree { + n.keys[n.nrSegments] = n.parent.keys[n.parentIndex] + n.values[n.nrSegments] = n.parent.values[n.parentIndex] + n.parent.keys[n.parentIndex] = sibling.keys[0] + n.parent.values[n.parentIndex] = sibling.values[0] + copy(sibling.keys[:sibling.nrSegments-1], sibling.keys[1:]) + copy(sibling.values[:sibling.nrSegments-1], sibling.values[1:]) + addrSetFunctions{}.ClearValue(&sibling.values[sibling.nrSegments-1]) + if n.hasChildren { + n.children[n.nrSegments+1] = sibling.children[0] + copy(sibling.children[:sibling.nrSegments], sibling.children[1:]) + sibling.children[sibling.nrSegments] = nil + n.children[n.nrSegments+1].parent = n + n.children[n.nrSegments+1].parentIndex = n.nrSegments + 1 + for i := 0; i < sibling.nrSegments; i++ { + sibling.children[i].parentIndex = i + } + } + n.nrSegments++ + sibling.nrSegments-- + + if addrtrackGaps != 0 { + n.updateMaxGapLocal() + sibling.updateMaxGapLocal() + } + if gap.node == sibling { + if gap.index == 0 { + return addrGapIterator{n, n.nrSegments} + } + return addrGapIterator{sibling, gap.index - 1} + } + return gap + } + + p := n.parent + if p.nrSegments == 1 { + + left, right := p.children[0], p.children[1] + p.nrSegments = left.nrSegments + right.nrSegments + 1 + p.hasChildren = left.hasChildren + p.keys[left.nrSegments] = p.keys[0] + p.values[left.nrSegments] = p.values[0] + copy(p.keys[:left.nrSegments], left.keys[:left.nrSegments]) + copy(p.values[:left.nrSegments], left.values[:left.nrSegments]) + copy(p.keys[left.nrSegments+1:], right.keys[:right.nrSegments]) + copy(p.values[left.nrSegments+1:], right.values[:right.nrSegments]) + if left.hasChildren { + copy(p.children[:left.nrSegments+1], left.children[:left.nrSegments+1]) + copy(p.children[left.nrSegments+1:], right.children[:right.nrSegments+1]) + for i := 0; i < p.nrSegments+1; i++ { + p.children[i].parent = p + p.children[i].parentIndex = i + } + } else { + p.children[0] = nil + p.children[1] = nil + } + + if gap.node == left { + return addrGapIterator{p, gap.index} + } + if gap.node == right { + return addrGapIterator{p, gap.index + left.nrSegments + 1} + } + return gap + } + // Merge n and either sibling, along with the segment separating the + // two, into whichever of the two nodes comes first. This is the + // reverse of the non-root splitting case in + // node.rebalanceBeforeInsert. + var left, right *addrnode + if n.parentIndex > 0 { + left = n.prevSibling() + right = n + } else { + left = n + right = n.nextSibling() + } + + if gap.node == right { + gap = addrGapIterator{left, gap.index + left.nrSegments + 1} + } + left.keys[left.nrSegments] = p.keys[left.parentIndex] + left.values[left.nrSegments] = p.values[left.parentIndex] + copy(left.keys[left.nrSegments+1:], right.keys[:right.nrSegments]) + copy(left.values[left.nrSegments+1:], right.values[:right.nrSegments]) + if left.hasChildren { + copy(left.children[left.nrSegments+1:], right.children[:right.nrSegments+1]) + for i := left.nrSegments + 1; i < left.nrSegments+right.nrSegments+2; i++ { + left.children[i].parent = left + left.children[i].parentIndex = i + } + } + left.nrSegments += right.nrSegments + 1 + copy(p.keys[left.parentIndex:], p.keys[left.parentIndex+1:p.nrSegments]) + copy(p.values[left.parentIndex:], p.values[left.parentIndex+1:p.nrSegments]) + addrSetFunctions{}.ClearValue(&p.values[p.nrSegments-1]) + copy(p.children[left.parentIndex+1:], p.children[left.parentIndex+2:p.nrSegments+1]) + for i := 0; i < p.nrSegments; i++ { + p.children[i].parentIndex = i + } + p.children[p.nrSegments] = nil + p.nrSegments-- + + if addrtrackGaps != 0 { + left.updateMaxGapLocal() + } + + n = p + } +} + +// updateMaxGapLeaf updates maxGap bottom-up from the calling leaf until no +// necessary update. +// +// Preconditions: n must be a leaf node, trackGaps must be 1. +func (n *addrnode) updateMaxGapLeaf() { + if n.hasChildren { + panic(fmt.Sprintf("updateMaxGapLeaf should always be called on leaf node: %v", n)) + } + max := n.calculateMaxGapLeaf() + if max == n.maxGap.Get() { + + return + } + oldMax := n.maxGap.Get() + n.maxGap.Set(max) + if max > oldMax { + + for p := n.parent; p != nil; p = p.parent { + if p.maxGap.Get() >= max { + + break + } + + p.maxGap.Set(max) + } + return + } + + for p := n.parent; p != nil; p = p.parent { + if p.maxGap.Get() > oldMax { + + break + } + + parentNewMax := p.calculateMaxGapInternal() + if p.maxGap.Get() == parentNewMax { + + break + } + + p.maxGap.Set(parentNewMax) + } +} + +// updateMaxGapLocal updates maxGap of the calling node solely with no +// propagation to ancestor nodes. +// +// Precondition: trackGaps must be 1. +func (n *addrnode) updateMaxGapLocal() { + if !n.hasChildren { + + n.maxGap.Set(n.calculateMaxGapLeaf()) + } else { + + n.maxGap.Set(n.calculateMaxGapInternal()) + } +} + +// calculateMaxGapLeaf iterates the gaps within a leaf node and calculate the +// max. +// +// Preconditions: n must be a leaf node. +func (n *addrnode) calculateMaxGapLeaf() uintptr { + max := addrGapIterator{n, 0}.Range().Length() + for i := 1; i <= n.nrSegments; i++ { + if current := (addrGapIterator{n, i}).Range().Length(); current > max { + max = current + } + } + return max +} + +// calculateMaxGapInternal iterates children's maxGap within an internal node n +// and calculate the max. +// +// Preconditions: n must be a non-leaf node. +func (n *addrnode) calculateMaxGapInternal() uintptr { + max := n.children[0].maxGap.Get() + for i := 1; i <= n.nrSegments; i++ { + if current := n.children[i].maxGap.Get(); current > max { + max = current + } + } + return max +} + +// searchFirstLargeEnoughGap returns the first gap having at least minSize length +// in the subtree rooted by n. If not found, return a terminal gap iterator. +func (n *addrnode) searchFirstLargeEnoughGap(minSize uintptr) addrGapIterator { + if n.maxGap.Get() < minSize { + return addrGapIterator{} + } + if n.hasChildren { + for i := 0; i <= n.nrSegments; i++ { + if largeEnoughGap := n.children[i].searchFirstLargeEnoughGap(minSize); largeEnoughGap.Ok() { + return largeEnoughGap + } + } + } else { + for i := 0; i <= n.nrSegments; i++ { + currentGap := addrGapIterator{n, i} + if currentGap.Range().Length() >= minSize { + return currentGap + } + } + } + panic(fmt.Sprintf("invalid maxGap in %v", n)) +} + +// searchLastLargeEnoughGap returns the last gap having at least minSize length +// in the subtree rooted by n. If not found, return a terminal gap iterator. +func (n *addrnode) searchLastLargeEnoughGap(minSize uintptr) addrGapIterator { + if n.maxGap.Get() < minSize { + return addrGapIterator{} + } + if n.hasChildren { + for i := n.nrSegments; i >= 0; i-- { + if largeEnoughGap := n.children[i].searchLastLargeEnoughGap(minSize); largeEnoughGap.Ok() { + return largeEnoughGap + } + } + } else { + for i := n.nrSegments; i >= 0; i-- { + currentGap := addrGapIterator{n, i} + if currentGap.Range().Length() >= minSize { + return currentGap + } + } + } + panic(fmt.Sprintf("invalid maxGap in %v", n)) +} + +// A Iterator is conceptually one of: +// +// - A pointer to a segment in a set; or +// +// - A terminal iterator, which is a sentinel indicating that the end of +// iteration has been reached. +// +// Iterators are copyable values and are meaningfully equality-comparable. The +// zero value of Iterator is a terminal iterator. +// +// Unless otherwise specified, any mutation of a set invalidates all existing +// iterators into the set. +type addrIterator struct { + // node is the node containing the iterated segment. If the iterator is + // terminal, node is nil. + node *addrnode + + // index is the index of the segment in node.keys/values. + index int +} + +// Ok returns true if the iterator is not terminal. All other methods are only +// valid for non-terminal iterators. +func (seg addrIterator) Ok() bool { + return seg.node != nil +} + +// Range returns the iterated segment's range key. +func (seg addrIterator) Range() addrRange { + return seg.node.keys[seg.index] +} + +// Start is equivalent to Range().Start, but should be preferred if only the +// start of the range is needed. +func (seg addrIterator) Start() uintptr { + return seg.node.keys[seg.index].Start +} + +// End is equivalent to Range().End, but should be preferred if only the end of +// the range is needed. +func (seg addrIterator) End() uintptr { + return seg.node.keys[seg.index].End +} + +// SetRangeUnchecked mutates the iterated segment's range key. This operation +// does not invalidate any iterators. +// +// Preconditions: +// * r.Length() > 0. +// * The new range must not overlap an existing one: +// * If seg.NextSegment().Ok(), then r.end <= seg.NextSegment().Start(). +// * If seg.PrevSegment().Ok(), then r.start >= seg.PrevSegment().End(). +func (seg addrIterator) SetRangeUnchecked(r addrRange) { + seg.node.keys[seg.index] = r +} + +// SetRange mutates the iterated segment's range key. If the new range would +// cause the iterated segment to overlap another segment, or if the new range +// is invalid, SetRange panics. This operation does not invalidate any +// iterators. +func (seg addrIterator) SetRange(r addrRange) { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + if prev := seg.PrevSegment(); prev.Ok() && r.Start < prev.End() { + panic(fmt.Sprintf("new segment range %v overlaps segment range %v", r, prev.Range())) + } + if next := seg.NextSegment(); next.Ok() && r.End > next.Start() { + panic(fmt.Sprintf("new segment range %v overlaps segment range %v", r, next.Range())) + } + seg.SetRangeUnchecked(r) +} + +// SetStartUnchecked mutates the iterated segment's start. This operation does +// not invalidate any iterators. +// +// Preconditions: The new start must be valid: +// * start < seg.End() +// * If seg.PrevSegment().Ok(), then start >= seg.PrevSegment().End(). +func (seg addrIterator) SetStartUnchecked(start uintptr) { + seg.node.keys[seg.index].Start = start +} + +// SetStart mutates the iterated segment's start. If the new start value would +// cause the iterated segment to overlap another segment, or would result in an +// invalid range, SetStart panics. This operation does not invalidate any +// iterators. +func (seg addrIterator) SetStart(start uintptr) { + if start >= seg.End() { + panic(fmt.Sprintf("new start %v would invalidate segment range %v", start, seg.Range())) + } + if prev := seg.PrevSegment(); prev.Ok() && start < prev.End() { + panic(fmt.Sprintf("new start %v would cause segment range %v to overlap segment range %v", start, seg.Range(), prev.Range())) + } + seg.SetStartUnchecked(start) +} + +// SetEndUnchecked mutates the iterated segment's end. This operation does not +// invalidate any iterators. +// +// Preconditions: The new end must be valid: +// * end > seg.Start(). +// * If seg.NextSegment().Ok(), then end <= seg.NextSegment().Start(). +func (seg addrIterator) SetEndUnchecked(end uintptr) { + seg.node.keys[seg.index].End = end +} + +// SetEnd mutates the iterated segment's end. If the new end value would cause +// the iterated segment to overlap another segment, or would result in an +// invalid range, SetEnd panics. This operation does not invalidate any +// iterators. +func (seg addrIterator) SetEnd(end uintptr) { + if end <= seg.Start() { + panic(fmt.Sprintf("new end %v would invalidate segment range %v", end, seg.Range())) + } + if next := seg.NextSegment(); next.Ok() && end > next.Start() { + panic(fmt.Sprintf("new end %v would cause segment range %v to overlap segment range %v", end, seg.Range(), next.Range())) + } + seg.SetEndUnchecked(end) +} + +// Value returns a copy of the iterated segment's value. +func (seg addrIterator) Value() *objectEncodeState { + return seg.node.values[seg.index] +} + +// ValuePtr returns a pointer to the iterated segment's value. The pointer is +// invalidated if the iterator is invalidated. This operation does not +// invalidate any iterators. +func (seg addrIterator) ValuePtr() **objectEncodeState { + return &seg.node.values[seg.index] +} + +// SetValue mutates the iterated segment's value. This operation does not +// invalidate any iterators. +func (seg addrIterator) SetValue(val *objectEncodeState) { + seg.node.values[seg.index] = val +} + +// PrevSegment returns the iterated segment's predecessor. If there is no +// preceding segment, PrevSegment returns a terminal iterator. +func (seg addrIterator) PrevSegment() addrIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index].lastSegment() + } + if seg.index > 0 { + return addrIterator{seg.node, seg.index - 1} + } + if seg.node.parent == nil { + return addrIterator{} + } + return addrsegmentBeforePosition(seg.node.parent, seg.node.parentIndex) +} + +// NextSegment returns the iterated segment's successor. If there is no +// succeeding segment, NextSegment returns a terminal iterator. +func (seg addrIterator) NextSegment() addrIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index+1].firstSegment() + } + if seg.index < seg.node.nrSegments-1 { + return addrIterator{seg.node, seg.index + 1} + } + if seg.node.parent == nil { + return addrIterator{} + } + return addrsegmentAfterPosition(seg.node.parent, seg.node.parentIndex) +} + +// PrevGap returns the gap immediately before the iterated segment. +func (seg addrIterator) PrevGap() addrGapIterator { + if seg.node.hasChildren { + + return seg.node.children[seg.index].lastSegment().NextGap() + } + return addrGapIterator{seg.node, seg.index} +} + +// NextGap returns the gap immediately after the iterated segment. +func (seg addrIterator) NextGap() addrGapIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index+1].firstSegment().PrevGap() + } + return addrGapIterator{seg.node, seg.index + 1} +} + +// PrevNonEmpty returns the iterated segment's predecessor if it is adjacent, +// or the gap before the iterated segment otherwise. If seg.Start() == +// Functions.MinKey(), PrevNonEmpty will return two terminal iterators. +// Otherwise, exactly one of the iterators returned by PrevNonEmpty will be +// non-terminal. +func (seg addrIterator) PrevNonEmpty() (addrIterator, addrGapIterator) { + gap := seg.PrevGap() + if gap.Range().Length() != 0 { + return addrIterator{}, gap + } + return gap.PrevSegment(), addrGapIterator{} +} + +// NextNonEmpty returns the iterated segment's successor if it is adjacent, or +// the gap after the iterated segment otherwise. If seg.End() == +// Functions.MaxKey(), NextNonEmpty will return two terminal iterators. +// Otherwise, exactly one of the iterators returned by NextNonEmpty will be +// non-terminal. +func (seg addrIterator) NextNonEmpty() (addrIterator, addrGapIterator) { + gap := seg.NextGap() + if gap.Range().Length() != 0 { + return addrIterator{}, gap + } + return gap.NextSegment(), addrGapIterator{} +} + +// A GapIterator is conceptually one of: +// +// - A pointer to a position between two segments, before the first segment, or +// after the last segment in a set, called a *gap*; or +// +// - A terminal iterator, which is a sentinel indicating that the end of +// iteration has been reached. +// +// Note that the gap between two adjacent segments exists (iterators to it are +// non-terminal), but has a length of zero. GapIterator.IsEmpty returns true +// for such gaps. An empty set contains a single gap, spanning the entire range +// of the set's keys. +// +// GapIterators are copyable values and are meaningfully equality-comparable. +// The zero value of GapIterator is a terminal iterator. +// +// Unless otherwise specified, any mutation of a set invalidates all existing +// iterators into the set. +type addrGapIterator struct { + // The representation of a GapIterator is identical to that of an Iterator, + // except that index corresponds to positions between segments in the same + // way as for node.children (see comment for node.nrSegments). + node *addrnode + index int +} + +// Ok returns true if the iterator is not terminal. All other methods are only +// valid for non-terminal iterators. +func (gap addrGapIterator) Ok() bool { + return gap.node != nil +} + +// Range returns the range spanned by the iterated gap. +func (gap addrGapIterator) Range() addrRange { + return addrRange{gap.Start(), gap.End()} +} + +// Start is equivalent to Range().Start, but should be preferred if only the +// start of the range is needed. +func (gap addrGapIterator) Start() uintptr { + if ps := gap.PrevSegment(); ps.Ok() { + return ps.End() + } + return addrSetFunctions{}.MinKey() +} + +// End is equivalent to Range().End, but should be preferred if only the end of +// the range is needed. +func (gap addrGapIterator) End() uintptr { + if ns := gap.NextSegment(); ns.Ok() { + return ns.Start() + } + return addrSetFunctions{}.MaxKey() +} + +// IsEmpty returns true if the iterated gap is empty (that is, the "gap" is +// between two adjacent segments.) +func (gap addrGapIterator) IsEmpty() bool { + return gap.Range().Length() == 0 +} + +// PrevSegment returns the segment immediately before the iterated gap. If no +// such segment exists, PrevSegment returns a terminal iterator. +func (gap addrGapIterator) PrevSegment() addrIterator { + return addrsegmentBeforePosition(gap.node, gap.index) +} + +// NextSegment returns the segment immediately after the iterated gap. If no +// such segment exists, NextSegment returns a terminal iterator. +func (gap addrGapIterator) NextSegment() addrIterator { + return addrsegmentAfterPosition(gap.node, gap.index) +} + +// PrevGap returns the iterated gap's predecessor. If no such gap exists, +// PrevGap returns a terminal iterator. +func (gap addrGapIterator) PrevGap() addrGapIterator { + seg := gap.PrevSegment() + if !seg.Ok() { + return addrGapIterator{} + } + return seg.PrevGap() +} + +// NextGap returns the iterated gap's successor. If no such gap exists, NextGap +// returns a terminal iterator. +func (gap addrGapIterator) NextGap() addrGapIterator { + seg := gap.NextSegment() + if !seg.Ok() { + return addrGapIterator{} + } + return seg.NextGap() +} + +// NextLargeEnoughGap returns the iterated gap's first next gap with larger +// length than minSize. If not found, return a terminal gap iterator (does NOT +// include this gap itself). +// +// Precondition: trackGaps must be 1. +func (gap addrGapIterator) NextLargeEnoughGap(minSize uintptr) addrGapIterator { + if addrtrackGaps != 1 { + panic("set is not tracking gaps") + } + if gap.node != nil && gap.node.hasChildren && gap.index == gap.node.nrSegments { + + gap.node = gap.NextSegment().node + gap.index = 0 + return gap.nextLargeEnoughGapHelper(minSize) + } + return gap.nextLargeEnoughGapHelper(minSize) +} + +// nextLargeEnoughGapHelper is the helper function used by NextLargeEnoughGap +// to do the real recursions. +// +// Preconditions: gap is NOT the trailing gap of a non-leaf node. +func (gap addrGapIterator) nextLargeEnoughGapHelper(minSize uintptr) addrGapIterator { + + for gap.node != nil && + (gap.node.maxGap.Get() < minSize || (!gap.node.hasChildren && gap.index == gap.node.nrSegments)) { + gap.node, gap.index = gap.node.parent, gap.node.parentIndex + } + + if gap.node == nil { + return addrGapIterator{} + } + + gap.index++ + for gap.index <= gap.node.nrSegments { + if gap.node.hasChildren { + if largeEnoughGap := gap.node.children[gap.index].searchFirstLargeEnoughGap(minSize); largeEnoughGap.Ok() { + return largeEnoughGap + } + } else { + if gap.Range().Length() >= minSize { + return gap + } + } + gap.index++ + } + gap.node, gap.index = gap.node.parent, gap.node.parentIndex + if gap.node != nil && gap.index == gap.node.nrSegments { + + gap.node, gap.index = gap.node.parent, gap.node.parentIndex + } + return gap.nextLargeEnoughGapHelper(minSize) +} + +// PrevLargeEnoughGap returns the iterated gap's first prev gap with larger or +// equal length than minSize. If not found, return a terminal gap iterator +// (does NOT include this gap itself). +// +// Precondition: trackGaps must be 1. +func (gap addrGapIterator) PrevLargeEnoughGap(minSize uintptr) addrGapIterator { + if addrtrackGaps != 1 { + panic("set is not tracking gaps") + } + if gap.node != nil && gap.node.hasChildren && gap.index == 0 { + + gap.node = gap.PrevSegment().node + gap.index = gap.node.nrSegments + return gap.prevLargeEnoughGapHelper(minSize) + } + return gap.prevLargeEnoughGapHelper(minSize) +} + +// prevLargeEnoughGapHelper is the helper function used by PrevLargeEnoughGap +// to do the real recursions. +// +// Preconditions: gap is NOT the first gap of a non-leaf node. +func (gap addrGapIterator) prevLargeEnoughGapHelper(minSize uintptr) addrGapIterator { + + for gap.node != nil && + (gap.node.maxGap.Get() < minSize || (!gap.node.hasChildren && gap.index == 0)) { + gap.node, gap.index = gap.node.parent, gap.node.parentIndex + } + + if gap.node == nil { + return addrGapIterator{} + } + + gap.index-- + for gap.index >= 0 { + if gap.node.hasChildren { + if largeEnoughGap := gap.node.children[gap.index].searchLastLargeEnoughGap(minSize); largeEnoughGap.Ok() { + return largeEnoughGap + } + } else { + if gap.Range().Length() >= minSize { + return gap + } + } + gap.index-- + } + gap.node, gap.index = gap.node.parent, gap.node.parentIndex + if gap.node != nil && gap.index == 0 { + + gap.node, gap.index = gap.node.parent, gap.node.parentIndex + } + return gap.prevLargeEnoughGapHelper(minSize) +} + +// segmentBeforePosition returns the predecessor segment of the position given +// by n.children[i], which may or may not contain a child. If no such segment +// exists, segmentBeforePosition returns a terminal iterator. +func addrsegmentBeforePosition(n *addrnode, i int) addrIterator { + for i == 0 { + if n.parent == nil { + return addrIterator{} + } + n, i = n.parent, n.parentIndex + } + return addrIterator{n, i - 1} +} + +// segmentAfterPosition returns the successor segment of the position given by +// n.children[i], which may or may not contain a child. If no such segment +// exists, segmentAfterPosition returns a terminal iterator. +func addrsegmentAfterPosition(n *addrnode, i int) addrIterator { + for i == n.nrSegments { + if n.parent == nil { + return addrIterator{} + } + n, i = n.parent, n.parentIndex + } + return addrIterator{n, i} +} + +func addrzeroValueSlice(slice []*objectEncodeState) { + + for i := range slice { + addrSetFunctions{}.ClearValue(&slice[i]) + } +} + +func addrzeroNodeSlice(slice []*addrnode) { + for i := range slice { + slice[i] = nil + } +} + +// String stringifies a Set for debugging. +func (s *addrSet) String() string { + return s.root.String() +} + +// String stringifies a node (and all of its children) for debugging. +func (n *addrnode) String() string { + var buf bytes.Buffer + n.writeDebugString(&buf, "") + return buf.String() +} + +func (n *addrnode) writeDebugString(buf *bytes.Buffer, prefix string) { + if n.hasChildren != (n.nrSegments > 0 && n.children[0] != nil) { + buf.WriteString(prefix) + buf.WriteString(fmt.Sprintf("WARNING: inconsistent value of hasChildren: got %v, want %v\n", n.hasChildren, !n.hasChildren)) + } + for i := 0; i < n.nrSegments; i++ { + if child := n.children[i]; child != nil { + cprefix := fmt.Sprintf("%s- % 3d ", prefix, i) + if child.parent != n || child.parentIndex != i { + buf.WriteString(cprefix) + buf.WriteString(fmt.Sprintf("WARNING: inconsistent linkage to parent: got (%p, %d), want (%p, %d)\n", child.parent, child.parentIndex, n, i)) + } + child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, i)) + } + buf.WriteString(prefix) + if n.hasChildren { + if addrtrackGaps != 0 { + buf.WriteString(fmt.Sprintf("- % 3d: %v => %v, maxGap: %d\n", i, n.keys[i], n.values[i], n.maxGap.Get())) + } else { + buf.WriteString(fmt.Sprintf("- % 3d: %v => %v\n", i, n.keys[i], n.values[i])) + } + } else { + buf.WriteString(fmt.Sprintf("- % 3d: %v => %v\n", i, n.keys[i], n.values[i])) + } + } + if child := n.children[n.nrSegments]; child != nil { + child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, n.nrSegments)) + } +} + +// SegmentDataSlices represents segments from a set as slices of start, end, and +// values. SegmentDataSlices is primarily used as an intermediate representation +// for save/restore and the layout here is optimized for that. +// +// +stateify savable +type addrSegmentDataSlices struct { + Start []uintptr + End []uintptr + Values []*objectEncodeState +} + +// ExportSortedSlices returns a copy of all segments in the given set, in +// ascending key order. +func (s *addrSet) ExportSortedSlices() *addrSegmentDataSlices { + var sds addrSegmentDataSlices + for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + sds.Start = append(sds.Start, seg.Start()) + sds.End = append(sds.End, seg.End()) + sds.Values = append(sds.Values, seg.Value()) + } + sds.Start = sds.Start[:len(sds.Start):len(sds.Start)] + sds.End = sds.End[:len(sds.End):len(sds.End)] + sds.Values = sds.Values[:len(sds.Values):len(sds.Values)] + return &sds +} + +// ImportSortedSlices initializes the given set from the given slice. +// +// Preconditions: +// * s must be empty. +// * sds must represent a valid set (the segments in sds must have valid +// lengths that do not overlap). +// * The segments in sds must be sorted in ascending key order. +func (s *addrSet) ImportSortedSlices(sds *addrSegmentDataSlices) error { + if !s.IsEmpty() { + return fmt.Errorf("cannot import into non-empty set %v", s) + } + gap := s.FirstGap() + for i := range sds.Start { + r := addrRange{sds.Start[i], sds.End[i]} + if !gap.Range().IsSupersetOf(r) { + return fmt.Errorf("segment overlaps a preceding segment or is incorrectly sorted: [%d, %d) => %v", sds.Start[i], sds.End[i], sds.Values[i]) + } + gap = s.InsertWithoutMerging(gap, r, sds.Values[i]).NextGap() + } + return nil +} + +// segmentTestCheck returns an error if s is incorrectly sorted, does not +// contain exactly expectedSegments segments, or contains a segment which +// fails the passed check. +// +// This should be used only for testing, and has been added to this package for +// templating convenience. +func (s *addrSet) segmentTestCheck(expectedSegments int, segFunc func(int, addrRange, *objectEncodeState) error) error { + havePrev := false + prev := uintptr(0) + nrSegments := 0 + for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + next := seg.Start() + if havePrev && prev >= next { + return fmt.Errorf("incorrect order: key %d (segment %d) >= key %d (segment %d)", prev, nrSegments-1, next, nrSegments) + } + if segFunc != nil { + if err := segFunc(nrSegments, seg.Range(), seg.Value()); err != nil { + return err + } + } + prev = next + havePrev = true + nrSegments++ + } + if nrSegments != expectedSegments { + return fmt.Errorf("incorrect number of segments: got %d, wanted %d", nrSegments, expectedSegments) + } + return nil +} + +// countSegments counts the number of segments in the set. +// +// Similar to Check, this should only be used for testing. +func (s *addrSet) countSegments() (segments int) { + for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + segments++ + } + return segments +} +func (s *addrSet) saveRoot() *addrSegmentDataSlices { + return s.ExportSortedSlices() +} + +func (s *addrSet) loadRoot(sds *addrSegmentDataSlices) { + if err := s.ImportSortedSlices(sds); err != nil { + panic(err) + } +} diff --git a/pkg/state/complete_list.go b/pkg/state/complete_list.go new file mode 100644 index 000000000..4d340a1af --- /dev/null +++ b/pkg/state/complete_list.go @@ -0,0 +1,221 @@ +package state + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type completeElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (completeElementMapper) linkerFor(elem *objectDecodeState) *objectDecodeState { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type completeList struct { + head *objectDecodeState + tail *objectDecodeState +} + +// Reset resets list l to the empty state. +func (l *completeList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +// +//go:nosplit +func (l *completeList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +// +//go:nosplit +func (l *completeList) Front() *objectDecodeState { + return l.head +} + +// Back returns the last element of list l or nil. +// +//go:nosplit +func (l *completeList) Back() *objectDecodeState { + return l.tail +} + +// Len returns the number of elements in the list. +// +// NOTE: This is an O(n) operation. +// +//go:nosplit +func (l *completeList) Len() (count int) { + for e := l.Front(); e != nil; e = (completeElementMapper{}.linkerFor(e)).Next() { + count++ + } + return count +} + +// PushFront inserts the element e at the front of list l. +// +//go:nosplit +func (l *completeList) PushFront(e *objectDecodeState) { + linker := completeElementMapper{}.linkerFor(e) + linker.SetNext(l.head) + linker.SetPrev(nil) + if l.head != nil { + completeElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +// +//go:nosplit +func (l *completeList) PushBack(e *objectDecodeState) { + linker := completeElementMapper{}.linkerFor(e) + linker.SetNext(nil) + linker.SetPrev(l.tail) + if l.tail != nil { + completeElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +// +//go:nosplit +func (l *completeList) PushBackList(m *completeList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + completeElementMapper{}.linkerFor(l.tail).SetNext(m.head) + completeElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +// +//go:nosplit +func (l *completeList) InsertAfter(b, e *objectDecodeState) { + bLinker := completeElementMapper{}.linkerFor(b) + eLinker := completeElementMapper{}.linkerFor(e) + + a := bLinker.Next() + + eLinker.SetNext(a) + eLinker.SetPrev(b) + bLinker.SetNext(e) + + if a != nil { + completeElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +// +//go:nosplit +func (l *completeList) InsertBefore(a, e *objectDecodeState) { + aLinker := completeElementMapper{}.linkerFor(a) + eLinker := completeElementMapper{}.linkerFor(e) + + b := aLinker.Prev() + eLinker.SetNext(a) + eLinker.SetPrev(b) + aLinker.SetPrev(e) + + if b != nil { + completeElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +// +//go:nosplit +func (l *completeList) Remove(e *objectDecodeState) { + linker := completeElementMapper{}.linkerFor(e) + prev := linker.Prev() + next := linker.Next() + + if prev != nil { + completeElementMapper{}.linkerFor(prev).SetNext(next) + } else if l.head == e { + l.head = next + } + + if next != nil { + completeElementMapper{}.linkerFor(next).SetPrev(prev) + } else if l.tail == e { + l.tail = prev + } + + linker.SetNext(nil) + linker.SetPrev(nil) +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type completeEntry struct { + next *objectDecodeState + prev *objectDecodeState +} + +// Next returns the entry that follows e in the list. +// +//go:nosplit +func (e *completeEntry) Next() *objectDecodeState { + return e.next +} + +// Prev returns the entry that precedes e in the list. +// +//go:nosplit +func (e *completeEntry) Prev() *objectDecodeState { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +// +//go:nosplit +func (e *completeEntry) SetNext(elem *objectDecodeState) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +// +//go:nosplit +func (e *completeEntry) SetPrev(elem *objectDecodeState) { + e.prev = elem +} diff --git a/pkg/state/deferred_list.go b/pkg/state/deferred_list.go new file mode 100644 index 000000000..2753ce4a2 --- /dev/null +++ b/pkg/state/deferred_list.go @@ -0,0 +1,206 @@ +package state + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type deferredList struct { + head *objectEncodeState + tail *objectEncodeState +} + +// Reset resets list l to the empty state. +func (l *deferredList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +// +//go:nosplit +func (l *deferredList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +// +//go:nosplit +func (l *deferredList) Front() *objectEncodeState { + return l.head +} + +// Back returns the last element of list l or nil. +// +//go:nosplit +func (l *deferredList) Back() *objectEncodeState { + return l.tail +} + +// Len returns the number of elements in the list. +// +// NOTE: This is an O(n) operation. +// +//go:nosplit +func (l *deferredList) Len() (count int) { + for e := l.Front(); e != nil; e = (deferredMapper{}.linkerFor(e)).Next() { + count++ + } + return count +} + +// PushFront inserts the element e at the front of list l. +// +//go:nosplit +func (l *deferredList) PushFront(e *objectEncodeState) { + linker := deferredMapper{}.linkerFor(e) + linker.SetNext(l.head) + linker.SetPrev(nil) + if l.head != nil { + deferredMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +// +//go:nosplit +func (l *deferredList) PushBack(e *objectEncodeState) { + linker := deferredMapper{}.linkerFor(e) + linker.SetNext(nil) + linker.SetPrev(l.tail) + if l.tail != nil { + deferredMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +// +//go:nosplit +func (l *deferredList) PushBackList(m *deferredList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + deferredMapper{}.linkerFor(l.tail).SetNext(m.head) + deferredMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +// +//go:nosplit +func (l *deferredList) InsertAfter(b, e *objectEncodeState) { + bLinker := deferredMapper{}.linkerFor(b) + eLinker := deferredMapper{}.linkerFor(e) + + a := bLinker.Next() + + eLinker.SetNext(a) + eLinker.SetPrev(b) + bLinker.SetNext(e) + + if a != nil { + deferredMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +// +//go:nosplit +func (l *deferredList) InsertBefore(a, e *objectEncodeState) { + aLinker := deferredMapper{}.linkerFor(a) + eLinker := deferredMapper{}.linkerFor(e) + + b := aLinker.Prev() + eLinker.SetNext(a) + eLinker.SetPrev(b) + aLinker.SetPrev(e) + + if b != nil { + deferredMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +// +//go:nosplit +func (l *deferredList) Remove(e *objectEncodeState) { + linker := deferredMapper{}.linkerFor(e) + prev := linker.Prev() + next := linker.Next() + + if prev != nil { + deferredMapper{}.linkerFor(prev).SetNext(next) + } else if l.head == e { + l.head = next + } + + if next != nil { + deferredMapper{}.linkerFor(next).SetPrev(prev) + } else if l.tail == e { + l.tail = prev + } + + linker.SetNext(nil) + linker.SetPrev(nil) +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type deferredEntry struct { + next *objectEncodeState + prev *objectEncodeState +} + +// Next returns the entry that follows e in the list. +// +//go:nosplit +func (e *deferredEntry) Next() *objectEncodeState { + return e.next +} + +// Prev returns the entry that precedes e in the list. +// +//go:nosplit +func (e *deferredEntry) Prev() *objectEncodeState { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +// +//go:nosplit +func (e *deferredEntry) SetNext(elem *objectEncodeState) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +// +//go:nosplit +func (e *deferredEntry) SetPrev(elem *objectEncodeState) { + e.prev = elem +} diff --git a/pkg/state/pretty/BUILD b/pkg/state/pretty/BUILD deleted file mode 100644 index d053802f7..000000000 --- a/pkg/state/pretty/BUILD +++ /dev/null @@ -1,13 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "pretty", - srcs = ["pretty.go"], - visibility = ["//:sandbox"], - deps = [ - "//pkg/state", - "//pkg/state/wire", - ], -) diff --git a/pkg/state/pretty/pretty_state_autogen.go b/pkg/state/pretty/pretty_state_autogen.go new file mode 100644 index 000000000..e772e34a4 --- /dev/null +++ b/pkg/state/pretty/pretty_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package pretty diff --git a/pkg/state/statefile/BUILD b/pkg/state/statefile/BUILD deleted file mode 100644 index 08d06e37b..000000000 --- a/pkg/state/statefile/BUILD +++ /dev/null @@ -1,21 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "statefile", - srcs = ["statefile.go"], - visibility = ["//:sandbox"], - deps = [ - "//pkg/compressio", - "//pkg/state/wire", - ], -) - -go_test( - name = "statefile_test", - size = "small", - srcs = ["statefile_test.go"], - library = ":statefile", - deps = ["//pkg/compressio"], -) diff --git a/pkg/state/statefile/statefile_state_autogen.go b/pkg/state/statefile/statefile_state_autogen.go new file mode 100644 index 000000000..a2cdaa3f1 --- /dev/null +++ b/pkg/state/statefile/statefile_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package statefile diff --git a/pkg/state/statefile/statefile_test.go b/pkg/state/statefile/statefile_test.go deleted file mode 100644 index 0b470fdec..000000000 --- a/pkg/state/statefile/statefile_test.go +++ /dev/null @@ -1,290 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package statefile - -import ( - "bytes" - crand "crypto/rand" - "encoding/base64" - "io" - "math/rand" - "runtime" - "testing" - "time" - - "gvisor.dev/gvisor/pkg/compressio" -) - -func randomKey() ([]byte, error) { - r := make([]byte, base64.RawStdEncoding.DecodedLen(keySize)) - if _, err := io.ReadFull(crand.Reader, r); err != nil { - return nil, err - } - key := make([]byte, keySize) - base64.RawStdEncoding.Encode(key, r) - return key, nil -} - -type testCase struct { - name string - data []byte - metadata map[string]string -} - -func TestStatefile(t *testing.T) { - rand.Seed(time.Now().Unix()) - - cases := []testCase{ - // Various data sizes. - {"nil", nil, nil}, - {"empty", []byte(""), nil}, - {"some", []byte("_"), nil}, - {"one", []byte("0"), nil}, - {"two", []byte("01"), nil}, - {"three", []byte("012"), nil}, - {"four", []byte("0123"), nil}, - {"five", []byte("01234"), nil}, - {"six", []byte("012356"), nil}, - {"seven", []byte("0123567"), nil}, - {"eight", []byte("01235678"), nil}, - - // Make sure we have one longer than the hash length. - {"longer than hash", []byte("012356asdjflkasjlk3jlk23j4lkjaso0d789f0aujw3lkjlkxsdf78asdful2kj3ljka78"), nil}, - - // Make sure we have one longer than the chunk size. - {"chunks", make([]byte, 3*compressionChunkSize), nil}, - {"large", make([]byte, 30*compressionChunkSize), nil}, - - // Different metadata. - {"one metadata", []byte("data"), map[string]string{"foo": "bar"}}, - {"two metadata", []byte("data"), map[string]string{"foo": "bar", "one": "two"}}, - } - - for _, c := range cases { - // Generate a key. - integrityKey, err := randomKey() - if err != nil { - t.Errorf("can't generate key: got %v, excepted nil", err) - continue - } - - t.Run(c.name, func(t *testing.T) { - for _, key := range [][]byte{nil, integrityKey} { - t.Run("key="+string(key), func(t *testing.T) { - // Encoding happens via a buffer. - var bufEncoded bytes.Buffer - var bufDecoded bytes.Buffer - - // Do all the writing. - w, err := NewWriter(&bufEncoded, key, c.metadata) - if err != nil { - t.Fatalf("error creating writer: got %v, expected nil", err) - } - if _, err := io.Copy(w, bytes.NewBuffer(c.data)); err != nil { - t.Fatalf("error during write: got %v, expected nil", err) - } - - // Finish the sum. - if err := w.Close(); err != nil { - t.Fatalf("error during close: got %v, expected nil", err) - } - - t.Logf("original data: %d bytes, encoded: %d bytes.", - len(c.data), len(bufEncoded.Bytes())) - - // Do all the reading. - r, metadata, err := NewReader(bytes.NewReader(bufEncoded.Bytes()), key) - if err != nil { - t.Fatalf("error creating reader: got %v, expected nil", err) - } - if _, err := io.Copy(&bufDecoded, r); err != nil { - t.Fatalf("error during read: got %v, expected nil", err) - } - - // Check that the data matches. - if !bytes.Equal(c.data, bufDecoded.Bytes()) { - t.Fatalf("data didn't match (%d vs %d bytes)", len(bufDecoded.Bytes()), len(c.data)) - } - - // Check that the metadata matches. - for k, v := range c.metadata { - nv, ok := metadata[k] - if !ok { - t.Fatalf("missing metadata: %s", k) - } - if v != nv { - t.Fatalf("mismatched metdata for %s: got %s, expected %s", k, nv, v) - } - } - - // Change the data and verify that it fails. - if key != nil { - b := append([]byte(nil), bufEncoded.Bytes()...) - b[rand.Intn(len(b))]++ - bufDecoded.Reset() - r, _, err = NewReader(bytes.NewReader(b), key) - if err == nil { - _, err = io.Copy(&bufDecoded, r) - } - if err == nil { - t.Error("got no error: expected error on data corruption") - } - } - - // Change the key and verify that it fails. - newKey := integrityKey - if len(key) > 0 { - newKey = append([]byte{}, key...) - newKey[rand.Intn(len(newKey))]++ - } - bufDecoded.Reset() - r, _, err = NewReader(bytes.NewReader(bufEncoded.Bytes()), newKey) - if err == nil { - _, err = io.Copy(&bufDecoded, r) - } - if err != compressio.ErrHashMismatch { - t.Errorf("got error: %v, expected ErrHashMismatch on key mismatch", err) - } - }) - } - }) - } -} - -const benchmarkDataSize = 100 * 1024 * 1024 - -func benchmark(b *testing.B, size int, write bool, compressible bool) { - b.StopTimer() - b.SetBytes(benchmarkDataSize) - - // Generate source data. - var source []byte - if compressible { - // For compressible data, we use essentially all zeros. - source = make([]byte, benchmarkDataSize) - } else { - // For non-compressible data, we use random base64 data (to - // make it marginally compressible, a ratio of 75%). - var sourceBuf bytes.Buffer - bufW := base64.NewEncoder(base64.RawStdEncoding, &sourceBuf) - bufR := rand.New(rand.NewSource(0)) - if _, err := io.CopyN(bufW, bufR, benchmarkDataSize); err != nil { - b.Fatalf("unable to seed random data: %v", err) - } - source = sourceBuf.Bytes() - } - - // Generate a random key for integrity check. - key, err := randomKey() - if err != nil { - b.Fatalf("error generating key: %v", err) - } - - // Define our benchmark functions. Prior to running the readState - // function here, you must execute the writeState function at least - // once (done below). - var stateBuf bytes.Buffer - writeState := func() { - stateBuf.Reset() - w, err := NewWriter(&stateBuf, key, nil) - if err != nil { - b.Fatalf("error creating writer: %v", err) - } - for done := 0; done < len(source); { - chunk := size // limit size. - if done+chunk > len(source) { - chunk = len(source) - done - } - n, err := w.Write(source[done : done+chunk]) - done += n - if n == 0 && err != nil { - b.Fatalf("error during write: %v", err) - } - } - if err := w.Close(); err != nil { - b.Fatalf("error closing writer: %v", err) - } - } - readState := func() { - tmpBuf := bytes.NewBuffer(stateBuf.Bytes()) - r, _, err := NewReader(tmpBuf, key) - if err != nil { - b.Fatalf("error creating reader: %v", err) - } - for done := 0; done < len(source); { - chunk := size // limit size. - if done+chunk > len(source) { - chunk = len(source) - done - } - n, err := r.Read(source[done : done+chunk]) - done += n - if n == 0 && err != nil { - b.Fatalf("error during read: %v", err) - } - } - } - // Generate the state once without timing to ensure that buffers have - // been appropriately allocated. - writeState() - if write { - b.StartTimer() - for i := 0; i < b.N; i++ { - writeState() - } - b.StopTimer() - } else { - b.StartTimer() - for i := 0; i < b.N; i++ { - readState() - } - b.StopTimer() - } -} - -func BenchmarkWrite4KCompressible(b *testing.B) { - benchmark(b, 4096, true, true) -} - -func BenchmarkWrite4KNoncompressible(b *testing.B) { - benchmark(b, 4096, true, false) -} - -func BenchmarkWrite1MCompressible(b *testing.B) { - benchmark(b, 1024*1024, true, true) -} - -func BenchmarkWrite1MNoncompressible(b *testing.B) { - benchmark(b, 1024*1024, true, false) -} - -func BenchmarkRead4KCompressible(b *testing.B) { - benchmark(b, 4096, false, true) -} - -func BenchmarkRead4KNoncompressible(b *testing.B) { - benchmark(b, 4096, false, false) -} - -func BenchmarkRead1MCompressible(b *testing.B) { - benchmark(b, 1024*1024, false, true) -} - -func BenchmarkRead1MNoncompressible(b *testing.B) { - benchmark(b, 1024*1024, false, false) -} - -func init() { - runtime.GOMAXPROCS(runtime.NumCPU()) -} diff --git a/pkg/state/tests/BUILD b/pkg/state/tests/BUILD deleted file mode 100644 index 9297cafbe..000000000 --- a/pkg/state/tests/BUILD +++ /dev/null @@ -1,43 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "tests", - srcs = [ - "array.go", - "bench.go", - "integer.go", - "load.go", - "map.go", - "register.go", - "struct.go", - "tests.go", - ], - deps = [ - "//pkg/state", - "//pkg/state/pretty", - ], -) - -go_test( - name = "tests_test", - size = "small", - srcs = [ - "array_test.go", - "bench_test.go", - "bool_test.go", - "float_test.go", - "integer_test.go", - "load_test.go", - "map_test.go", - "register_test.go", - "string_test.go", - "struct_test.go", - ], - library = ":tests", - deps = [ - "//pkg/state", - "//pkg/state/wire", - ], -) diff --git a/pkg/state/tests/array.go b/pkg/state/tests/array.go deleted file mode 100644 index 0972a80e7..000000000 --- a/pkg/state/tests/array.go +++ /dev/null @@ -1,35 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tests - -// +stateify savable -type arrayContainer struct { - v [1]interface{} -} - -// +stateify savable -type arrayPtrContainer struct { - v *[1]interface{} -} - -// +stateify savable -type sliceContainer struct { - v []interface{} -} - -// +stateify savable -type slicePtrContainer struct { - v *[]interface{} -} diff --git a/pkg/state/tests/array_test.go b/pkg/state/tests/array_test.go deleted file mode 100644 index a347b2947..000000000 --- a/pkg/state/tests/array_test.go +++ /dev/null @@ -1,134 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tests - -import ( - "reflect" - "testing" -) - -var allArrayPrimitives = []interface{}{ - [1]bool{}, - [1]bool{true}, - [2]bool{false, true}, - [1]int{}, - [1]int{1}, - [2]int{0, 1}, - [1]int8{}, - [1]int8{1}, - [2]int8{0, 1}, - [1]int16{}, - [1]int16{1}, - [2]int16{0, 1}, - [1]int32{}, - [1]int32{1}, - [2]int32{0, 1}, - [1]int64{}, - [1]int64{1}, - [2]int64{0, 1}, - [1]uint{}, - [1]uint{1}, - [2]uint{0, 1}, - [1]uintptr{}, - [1]uintptr{1}, - [2]uintptr{0, 1}, - [1]uint8{}, - [1]uint8{1}, - [2]uint8{0, 1}, - [1]uint16{}, - [1]uint16{1}, - [2]uint16{0, 1}, - [1]uint32{}, - [1]uint32{1}, - [2]uint32{0, 1}, - [1]uint64{}, - [1]uint64{1}, - [2]uint64{0, 1}, - [1]string{}, - [1]string{""}, - [1]string{nonEmptyString}, - [2]string{"", nonEmptyString}, -} - -func TestArrayPrimitives(t *testing.T) { - runTestCases(t, false, "plain", flatten(allArrayPrimitives)) - runTestCases(t, false, "pointers", pointersTo(flatten(allArrayPrimitives))) - runTestCases(t, false, "interfaces", interfacesTo(flatten(allArrayPrimitives))) - runTestCases(t, false, "interfacesToPointers", interfacesTo(pointersTo(flatten(allArrayPrimitives)))) -} - -func TestSlices(t *testing.T) { - var allSlices = flatten( - filter(allArrayPrimitives, func(o interface{}) (interface{}, bool) { - v := reflect.New(reflect.TypeOf(o)).Elem() - v.Set(reflect.ValueOf(o)) - return v.Slice(0, v.Len()).Interface(), true - }), - filter(allArrayPrimitives, func(o interface{}) (interface{}, bool) { - v := reflect.New(reflect.TypeOf(o)).Elem() - v.Set(reflect.ValueOf(o)) - if v.Len() == 0 { - // Return the pure "nil" value for the slice. - return reflect.New(v.Slice(0, 0).Type()).Elem().Interface(), true - } - return v.Slice(1, v.Len()).Interface(), true - }), - filter(allArrayPrimitives, func(o interface{}) (interface{}, bool) { - v := reflect.New(reflect.TypeOf(o)).Elem() - v.Set(reflect.ValueOf(o)) - if v.Len() == 0 { - // Return the zero-valued slice. - return reflect.MakeSlice(v.Slice(0, 0).Type(), 0, 0).Interface(), true - } - return v.Slice(0, v.Len()-1).Interface(), true - }), - ) - runTestCases(t, false, "plain", allSlices) - runTestCases(t, false, "pointers", pointersTo(allSlices)) - runTestCases(t, false, "interfaces", interfacesTo(allSlices)) - runTestCases(t, false, "interfacesToPointers", interfacesTo(pointersTo(allSlices))) -} - -func TestArrayContainers(t *testing.T) { - var ( - emptyArray [1]interface{} - fullArray [1]interface{} - ) - fullArray[0] = &emptyArray - runTestCases(t, false, "", []interface{}{ - arrayContainer{v: emptyArray}, - arrayContainer{v: fullArray}, - arrayPtrContainer{v: nil}, - arrayPtrContainer{v: &emptyArray}, - arrayPtrContainer{v: &fullArray}, - }) -} - -func TestSliceContainers(t *testing.T) { - var ( - nilSlice []interface{} - emptySlice = make([]interface{}, 0) - fullSlice = []interface{}{nil} - ) - runTestCases(t, false, "", []interface{}{ - sliceContainer{v: nilSlice}, - sliceContainer{v: emptySlice}, - sliceContainer{v: fullSlice}, - slicePtrContainer{v: nil}, - slicePtrContainer{v: &nilSlice}, - slicePtrContainer{v: &emptySlice}, - slicePtrContainer{v: &fullSlice}, - }) -} diff --git a/pkg/state/tests/bench.go b/pkg/state/tests/bench.go deleted file mode 100644 index 40869cdfb..000000000 --- a/pkg/state/tests/bench.go +++ /dev/null @@ -1,24 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tests - -// +stateify savable -type benchStruct struct { - B *benchStruct // Must be exported for gob. -} - -func (b *benchStruct) afterLoad() { - // Do nothing, just force scheduling. -} diff --git a/pkg/state/tests/bench_test.go b/pkg/state/tests/bench_test.go deleted file mode 100644 index 7e102c907..000000000 --- a/pkg/state/tests/bench_test.go +++ /dev/null @@ -1,153 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tests - -import ( - "bytes" - "context" - "encoding/gob" - "fmt" - "testing" - - "gvisor.dev/gvisor/pkg/state" - "gvisor.dev/gvisor/pkg/state/wire" -) - -// buildPtrObject builds a benchmark object. -func buildPtrObject(n int) interface{} { - b := new(benchStruct) - for i := 0; i < n; i++ { - b = &benchStruct{B: b} - } - return b -} - -// buildMapObject builds a benchmark object. -func buildMapObject(n int) interface{} { - b := new(benchStruct) - m := make(map[int]*benchStruct) - for i := 0; i < n; i++ { - m[i] = b - } - return &m -} - -// buildSliceObject builds a benchmark object. -func buildSliceObject(n int) interface{} { - b := new(benchStruct) - s := make([]*benchStruct, 0, n) - for i := 0; i < n; i++ { - s = append(s, b) - } - return &s -} - -var allObjects = map[string]struct { - New func(int) interface{} -}{ - "ptr": { - New: buildPtrObject, - }, - "map": { - New: buildMapObject, - }, - "slice": { - New: buildSliceObject, - }, -} - -func buildObjects(n int, fn func(int) interface{}) (iters int, v interface{}) { - // maxSize is the maximum size of an individual object below. For an N - // larger than this, we start to return multiple objects. - const maxSize = 1024 - if n <= maxSize { - return 1, fn(n) - } - iters = (n + maxSize - 1) / maxSize - return iters, fn(maxSize) -} - -// gobSave is a version of save using gob (no stats available). -func gobSave(_ context.Context, w wire.Writer, v interface{}) (_ state.Stats, err error) { - enc := gob.NewEncoder(w) - err = enc.Encode(v) - return -} - -// gobLoad is a version of load using gob (no stats available). -func gobLoad(_ context.Context, r wire.Reader, v interface{}) (_ state.Stats, err error) { - dec := gob.NewDecoder(r) - err = dec.Decode(v) - return -} - -var allAlgos = map[string]struct { - Save func(context.Context, wire.Writer, interface{}) (state.Stats, error) - Load func(context.Context, wire.Reader, interface{}) (state.Stats, error) - MaxPtr int -}{ - "state": { - Save: state.Save, - Load: state.Load, - }, - "gob": { - Save: gobSave, - Load: gobLoad, - }, -} - -func BenchmarkEncoding(b *testing.B) { - for objName, objInfo := range allObjects { - for algoName, algoInfo := range allAlgos { - b.Run(fmt.Sprintf("%s/%s", objName, algoName), func(b *testing.B) { - b.StopTimer() - n, v := buildObjects(b.N, objInfo.New) - b.ReportAllocs() - b.StartTimer() - for i := 0; i < n; i++ { - if _, err := algoInfo.Save(context.Background(), discard{}, v); err != nil { - b.Errorf("save failed: %v", err) - } - } - b.StopTimer() - }) - } - } -} - -func BenchmarkDecoding(b *testing.B) { - for objName, objInfo := range allObjects { - for algoName, algoInfo := range allAlgos { - b.Run(fmt.Sprintf("%s/%s", objName, algoName), func(b *testing.B) { - b.StopTimer() - n, v := buildObjects(b.N, objInfo.New) - buf := new(bytes.Buffer) - if _, err := algoInfo.Save(context.Background(), buf, v); err != nil { - b.Errorf("save failed: %v", err) - } - b.ReportAllocs() - b.StartTimer() - var r bytes.Reader - for i := 0; i < n; i++ { - r.Reset(buf.Bytes()) - if _, err := algoInfo.Load(context.Background(), &r, v); err != nil { - b.Errorf("load failed: %v", err) - } - } - b.StopTimer() - }) - } - } -} diff --git a/pkg/state/tests/bool_test.go b/pkg/state/tests/bool_test.go deleted file mode 100644 index e17cfacf9..000000000 --- a/pkg/state/tests/bool_test.go +++ /dev/null @@ -1,31 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tests - -import ( - "testing" -) - -var allBools = []bool{ - true, - false, -} - -func TestBool(t *testing.T) { - runTestCases(t, false, "plain", flatten(allBools)) - runTestCases(t, false, "pointers", pointersTo(flatten(allBools))) - runTestCases(t, false, "interfaces", interfacesTo(flatten(allBools))) - runTestCases(t, false, "interfacesToPointers", interfacesTo(pointersTo(flatten(allBools)))) -} diff --git a/pkg/state/tests/float_test.go b/pkg/state/tests/float_test.go deleted file mode 100644 index 3e89edd9c..000000000 --- a/pkg/state/tests/float_test.go +++ /dev/null @@ -1,118 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tests - -import ( - "math" - "testing" -) - -var safeFloat32s = []float32{ - float32(0.0), - float32(1.0), - float32(-1.0), - float32(math.Inf(1)), - float32(math.Inf(-1)), -} - -var allFloat32s = append(safeFloat32s, float32(math.NaN())) - -var safeFloat64s = []float64{ - float64(0.0), - float64(1.0), - float64(-1.0), - math.Inf(1), - math.Inf(-1), -} - -var allFloat64s = append(safeFloat64s, math.NaN()) - -func TestFloat(t *testing.T) { - runTestCases(t, false, "plain", flatten( - allFloat32s, - allFloat64s, - )) - // See checkEqual for why NaNs are missing. - runTestCases(t, false, "pointers", pointersTo(flatten( - safeFloat32s, - safeFloat64s, - ))) - runTestCases(t, false, "interfaces", interfacesTo(flatten( - safeFloat32s, - safeFloat64s, - ))) - runTestCases(t, false, "interfacesToPointers", interfacesTo(pointersTo(flatten( - safeFloat32s, - safeFloat64s, - )))) -} - -const onlyDouble float64 = 1.0000000000000002 - -func TestFloatTruncation(t *testing.T) { - runTestCases(t, true, "pass", []interface{}{ - truncatingFloat32{save: onlyDouble}, - }) - runTestCases(t, false, "fail", []interface{}{ - truncatingFloat32{save: 1.0}, - }) -} - -var safeComplex64s = combine(safeFloat32s, safeFloat32s, func(i, j interface{}) interface{} { - return complex(i.(float32), j.(float32)) -}) - -var allComplex64s = combine(allFloat32s, allFloat32s, func(i, j interface{}) interface{} { - return complex(i.(float32), j.(float32)) -}) - -var safeComplex128s = combine(safeFloat64s, safeFloat64s, func(i, j interface{}) interface{} { - return complex(i.(float64), j.(float64)) -}) - -var allComplex128s = combine(allFloat64s, allFloat64s, func(i, j interface{}) interface{} { - return complex(i.(float64), j.(float64)) -}) - -func TestComplex(t *testing.T) { - runTestCases(t, false, "plain", flatten( - allComplex64s, - allComplex128s, - )) - // See TestFloat; same issue. - runTestCases(t, false, "pointers", pointersTo(flatten( - safeComplex64s, - safeComplex128s, - ))) - runTestCases(t, false, "interfacse", interfacesTo(flatten( - safeComplex64s, - safeComplex128s, - ))) - runTestCases(t, false, "interfacesTo", interfacesTo(pointersTo(flatten( - safeComplex64s, - safeComplex128s, - )))) -} - -func TestComplexTruncation(t *testing.T) { - runTestCases(t, true, "pass", []interface{}{ - truncatingComplex64{save: complex(onlyDouble, onlyDouble)}, - truncatingComplex64{save: complex(1.0, onlyDouble)}, - truncatingComplex64{save: complex(onlyDouble, 1.0)}, - }) - runTestCases(t, false, "fail", []interface{}{ - truncatingComplex64{save: complex(1.0, 1.0)}, - }) -} diff --git a/pkg/state/tests/integer.go b/pkg/state/tests/integer.go deleted file mode 100644 index ca403eed1..000000000 --- a/pkg/state/tests/integer.go +++ /dev/null @@ -1,163 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tests - -import ( - "gvisor.dev/gvisor/pkg/state" -) - -// +stateify type -type truncatingUint8 struct { - save uint64 - load uint8 `state:"nosave"` -} - -func (t *truncatingUint8) StateSave(m state.Sink) { - m.Save(0, &t.save) -} - -func (t *truncatingUint8) StateLoad(m state.Source) { - m.Load(0, &t.load) - t.save = uint64(t.load) - t.load = 0 -} - -var _ state.SaverLoader = (*truncatingUint8)(nil) - -// +stateify type -type truncatingUint16 struct { - save uint64 - load uint16 `state:"nosave"` -} - -func (t *truncatingUint16) StateSave(m state.Sink) { - m.Save(0, &t.save) -} - -func (t *truncatingUint16) StateLoad(m state.Source) { - m.Load(0, &t.load) - t.save = uint64(t.load) - t.load = 0 -} - -var _ state.SaverLoader = (*truncatingUint16)(nil) - -// +stateify type -type truncatingUint32 struct { - save uint64 - load uint32 `state:"nosave"` -} - -func (t *truncatingUint32) StateSave(m state.Sink) { - m.Save(0, &t.save) -} - -func (t *truncatingUint32) StateLoad(m state.Source) { - m.Load(0, &t.load) - t.save = uint64(t.load) - t.load = 0 -} - -var _ state.SaverLoader = (*truncatingUint32)(nil) - -// +stateify type -type truncatingInt8 struct { - save int64 - load int8 `state:"nosave"` -} - -func (t *truncatingInt8) StateSave(m state.Sink) { - m.Save(0, &t.save) -} - -func (t *truncatingInt8) StateLoad(m state.Source) { - m.Load(0, &t.load) - t.save = int64(t.load) - t.load = 0 -} - -var _ state.SaverLoader = (*truncatingInt8)(nil) - -// +stateify type -type truncatingInt16 struct { - save int64 - load int16 `state:"nosave"` -} - -func (t *truncatingInt16) StateSave(m state.Sink) { - m.Save(0, &t.save) -} - -func (t *truncatingInt16) StateLoad(m state.Source) { - m.Load(0, &t.load) - t.save = int64(t.load) - t.load = 0 -} - -var _ state.SaverLoader = (*truncatingInt16)(nil) - -// +stateify type -type truncatingInt32 struct { - save int64 - load int32 `state:"nosave"` -} - -func (t *truncatingInt32) StateSave(m state.Sink) { - m.Save(0, &t.save) -} - -func (t *truncatingInt32) StateLoad(m state.Source) { - m.Load(0, &t.load) - t.save = int64(t.load) - t.load = 0 -} - -var _ state.SaverLoader = (*truncatingInt32)(nil) - -// +stateify type -type truncatingFloat32 struct { - save float64 - load float32 `state:"nosave"` -} - -func (t *truncatingFloat32) StateSave(m state.Sink) { - m.Save(0, &t.save) -} - -func (t *truncatingFloat32) StateLoad(m state.Source) { - m.Load(0, &t.load) - t.save = float64(t.load) - t.load = 0 -} - -var _ state.SaverLoader = (*truncatingFloat32)(nil) - -// +stateify type -type truncatingComplex64 struct { - save complex128 - load complex64 `state:"nosave"` -} - -func (t *truncatingComplex64) StateSave(m state.Sink) { - m.Save(0, &t.save) -} - -func (t *truncatingComplex64) StateLoad(m state.Source) { - m.Load(0, &t.load) - t.save = complex128(t.load) - t.load = 0 -} - -var _ state.SaverLoader = (*truncatingComplex64)(nil) diff --git a/pkg/state/tests/integer_test.go b/pkg/state/tests/integer_test.go deleted file mode 100644 index 2b1609af0..000000000 --- a/pkg/state/tests/integer_test.go +++ /dev/null @@ -1,94 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tests - -import ( - "math" - "testing" -) - -var ( - allBasicInts = []int{-1, 0, 1} - allInt8s = []int8{math.MinInt8, -1, 0, 1, math.MaxInt8} - allInt16s = []int16{math.MinInt16, -1, 0, 1, math.MaxInt16} - allInt32s = []int32{math.MinInt32, -1, 0, 1, math.MaxInt32} - allInt64s = []int64{math.MinInt64, -1, 0, 1, math.MaxInt64} - allBasicUints = []uint{0, 1} - allUintptrs = []uintptr{0, 1, ^uintptr(0)} - allUint8s = []uint8{0, 1, math.MaxUint8} - allUint16s = []uint16{0, 1, math.MaxUint16} - allUint32s = []uint32{0, 1, math.MaxUint32} - allUint64s = []uint64{0, 1, math.MaxUint64} -) - -var allInts = flatten( - allBasicInts, - allInt8s, - allInt16s, - allInt32s, - allInt64s, -) - -var allUints = flatten( - allBasicUints, - allUintptrs, - allUint8s, - allUint16s, - allUint32s, - allUint64s, -) - -func TestInt(t *testing.T) { - runTestCases(t, false, "plain", allInts) - runTestCases(t, false, "pointers", pointersTo(allInts)) - runTestCases(t, false, "interfaces", interfacesTo(allInts)) - runTestCases(t, false, "interfacesTo", interfacesTo(pointersTo(allInts))) -} - -func TestIntTruncation(t *testing.T) { - runTestCases(t, true, "pass", []interface{}{ - truncatingInt8{save: math.MinInt8 - 1}, - truncatingInt16{save: math.MinInt16 - 1}, - truncatingInt32{save: math.MinInt32 - 1}, - truncatingInt8{save: math.MaxInt8 + 1}, - truncatingInt16{save: math.MaxInt16 + 1}, - truncatingInt32{save: math.MaxInt32 + 1}, - }) - runTestCases(t, false, "fail", []interface{}{ - truncatingInt8{save: 1}, - truncatingInt16{save: 1}, - truncatingInt32{save: 1}, - }) -} - -func TestUint(t *testing.T) { - runTestCases(t, false, "plain", allUints) - runTestCases(t, false, "pointers", pointersTo(allUints)) - runTestCases(t, false, "interfaces", interfacesTo(allUints)) - runTestCases(t, false, "interfacesTo", interfacesTo(pointersTo(allUints))) -} - -func TestUintTruncation(t *testing.T) { - runTestCases(t, true, "pass", []interface{}{ - truncatingUint8{save: math.MaxUint8 + 1}, - truncatingUint16{save: math.MaxUint16 + 1}, - truncatingUint32{save: math.MaxUint32 + 1}, - }) - runTestCases(t, false, "fail", []interface{}{ - truncatingUint8{save: 1}, - truncatingUint16{save: 1}, - truncatingUint32{save: 1}, - }) -} diff --git a/pkg/state/tests/load.go b/pkg/state/tests/load.go deleted file mode 100644 index a8350c0f3..000000000 --- a/pkg/state/tests/load.go +++ /dev/null @@ -1,61 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tests - -// +stateify savable -type genericContainer struct { - v interface{} -} - -// +stateify savable -type afterLoadStruct struct { - v int `state:"nosave"` -} - -func (a *afterLoadStruct) afterLoad() { - a.v++ -} - -// +stateify savable -type valueLoadStruct struct { - v int `state:".(int64)"` -} - -func (v *valueLoadStruct) saveV() int64 { - return int64(v.v) // Save as int64. -} - -func (v *valueLoadStruct) loadV(value int64) { - v.v = int(value) // Load as int. -} - -// +stateify savable -type cycleStruct struct { - c *cycleStruct -} - -// +stateify savable -type badCycleStruct struct { - b *badCycleStruct `state:"wait"` -} - -func (b *badCycleStruct) afterLoad() { - if b.b != b { - // This is not executable, since AfterLoad requires that the - // object and all dependencies are complete. This should cause - // a deadlock error during load. - panic("badCycleStruct.afterLoad called") - } -} diff --git a/pkg/state/tests/load_test.go b/pkg/state/tests/load_test.go deleted file mode 100644 index 3c73ac391..000000000 --- a/pkg/state/tests/load_test.go +++ /dev/null @@ -1,78 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tests - -import ( - "testing" -) - -func TestLoadHooks(t *testing.T) { - runTestCases(t, false, "load-hooks", []interface{}{ - // Root object being a struct. - afterLoadStruct{v: 1}, - valueLoadStruct{v: 1}, - genericContainer{v: &afterLoadStruct{v: 1}}, - genericContainer{v: &valueLoadStruct{v: 1}}, - sliceContainer{v: []interface{}{&afterLoadStruct{v: 1}}}, - sliceContainer{v: []interface{}{&valueLoadStruct{v: 1}}}, - // Root object being a pointer. - &afterLoadStruct{v: 1}, - &valueLoadStruct{v: 1}, - &genericContainer{v: &afterLoadStruct{v: 1}}, - &genericContainer{v: &valueLoadStruct{v: 1}}, - &sliceContainer{v: []interface{}{&afterLoadStruct{v: 1}}}, - &sliceContainer{v: []interface{}{&valueLoadStruct{v: 1}}}, - &mapContainer{v: map[int]interface{}{0: &afterLoadStruct{v: 1}}}, - &mapContainer{v: map[int]interface{}{0: &valueLoadStruct{v: 1}}}, - }) -} - -func TestCycles(t *testing.T) { - // cs is a single object cycle. - cs := cycleStruct{nil} - cs.c = &cs - - // cs1 and cs2 are in a two object cycle. - cs1 := cycleStruct{nil} - cs2 := cycleStruct{nil} - cs1.c = &cs2 - cs2.c = &cs1 - - runTestCases(t, false, "cycles", []interface{}{ - cs, - cs1, - }) -} - -func TestDeadlock(t *testing.T) { - // bs is a single object cycle. This does not cause deadlock because an - // object cannot wait for itself. - bs := badCycleStruct{nil} - bs.b = &bs - - runTestCases(t, false, "self", []interface{}{ - &bs, - }) - - // bs2 and bs2 are in a deadlocking cycle. - bs1 := badCycleStruct{nil} - bs2 := badCycleStruct{nil} - bs1.b = &bs2 - bs2.b = &bs1 - - runTestCases(t, true, "deadlock", []interface{}{ - &bs1, - }) -} diff --git a/pkg/state/tests/map.go b/pkg/state/tests/map.go deleted file mode 100644 index db4e548f1..000000000 --- a/pkg/state/tests/map.go +++ /dev/null @@ -1,28 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tests - -// +stateify savable -type mapContainer struct { - v map[int]interface{} -} - -// +stateify savable -type mapPtrContainer struct { - v *map[int]interface{} -} - -// +stateify savable -type registeredMapStruct struct{} diff --git a/pkg/state/tests/map_test.go b/pkg/state/tests/map_test.go deleted file mode 100644 index 92bf0fc01..000000000 --- a/pkg/state/tests/map_test.go +++ /dev/null @@ -1,90 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tests - -import ( - "reflect" - "testing" -) - -var allMapPrimitives = []interface{}{ - bool(true), - int(1), - int8(1), - int16(1), - int32(1), - int64(1), - uint(1), - uintptr(1), - uint8(1), - uint16(1), - uint32(1), - uint64(1), - string(""), - registeredMapStruct{}, -} - -var allMapKeys = flatten(allMapPrimitives, pointersTo(allMapPrimitives)) - -var allMapValues = flatten(allMapPrimitives, pointersTo(allMapPrimitives), interfacesTo(allMapPrimitives)) - -var emptyMaps = combine(allMapKeys, allMapValues, func(v1, v2 interface{}) interface{} { - m := reflect.MakeMap(reflect.MapOf(reflect.TypeOf(v1), reflect.TypeOf(v2))) - return m.Interface() -}) - -var fullMaps = combine(allMapKeys, allMapValues, func(v1, v2 interface{}) interface{} { - m := reflect.MakeMap(reflect.MapOf(reflect.TypeOf(v1), reflect.TypeOf(v2))) - m.SetMapIndex(reflect.Zero(reflect.TypeOf(v1)), reflect.Zero(reflect.TypeOf(v2))) - return m.Interface() -}) - -func TestMapAliasing(t *testing.T) { - v := make(map[int]int) - ptrToV := &v - aliases := []map[int]int{v, v} - runTestCases(t, false, "", []interface{}{ptrToV, aliases}) -} - -func TestMapsEmpty(t *testing.T) { - runTestCases(t, false, "plain", emptyMaps) - runTestCases(t, false, "pointers", pointersTo(emptyMaps)) - runTestCases(t, false, "interfaces", interfacesTo(emptyMaps)) - runTestCases(t, false, "interfacesToPointers", interfacesTo(pointersTo(emptyMaps))) -} - -func TestMapsFull(t *testing.T) { - runTestCases(t, false, "plain", fullMaps) - runTestCases(t, false, "pointers", pointersTo(fullMaps)) - runTestCases(t, false, "interfaces", interfacesTo(fullMaps)) - runTestCases(t, false, "interfacesToPointer", interfacesTo(pointersTo(fullMaps))) -} - -func TestMapContainers(t *testing.T) { - var ( - nilMap map[int]interface{} - emptyMap = make(map[int]interface{}) - fullMap = map[int]interface{}{0: nil} - ) - runTestCases(t, false, "", []interface{}{ - mapContainer{v: nilMap}, - mapContainer{v: emptyMap}, - mapContainer{v: fullMap}, - mapPtrContainer{v: nil}, - mapPtrContainer{v: &nilMap}, - mapPtrContainer{v: &emptyMap}, - mapPtrContainer{v: &fullMap}, - }) -} diff --git a/pkg/state/tests/register.go b/pkg/state/tests/register.go deleted file mode 100644 index 074d86315..000000000 --- a/pkg/state/tests/register.go +++ /dev/null @@ -1,21 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tests - -// +stateify savable -type alreadyRegisteredStruct struct{} - -// +stateify savable -type alreadyRegisteredOther int diff --git a/pkg/state/tests/register_test.go b/pkg/state/tests/register_test.go deleted file mode 100644 index 2199d6b01..000000000 --- a/pkg/state/tests/register_test.go +++ /dev/null @@ -1,179 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//go:build race -// +build race - -package tests - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/state" -) - -// faker calls itself whatever is in the name field. -type faker struct { - Name string - Fields []string -} - -func (f *faker) StateTypeName() string { - return f.Name -} - -func (f *faker) StateFields() []string { - return f.Fields -} - -// fakerWithSaverLoader has all it needs. -type fakerWithSaverLoader struct { - faker -} - -func (f *fakerWithSaverLoader) StateSave(m state.Sink) {} - -func (f *fakerWithSaverLoader) StateLoad(m state.Source) {} - -// fakerOther calls itself .. uh, itself? -type fakerOther string - -func (f *fakerOther) StateTypeName() string { - return string(*f) -} - -func (f *fakerOther) StateFields() []string { - return nil -} - -func newFakerOther(name string) *fakerOther { - f := fakerOther(name) - return &f -} - -// fakerOtherBadFields returns non-nil fields. -type fakerOtherBadFields string - -func (f *fakerOtherBadFields) StateTypeName() string { - return string(*f) -} - -func (f *fakerOtherBadFields) StateFields() []string { - return []string{string(*f)} -} - -func newFakerOtherBadFields(name string) *fakerOtherBadFields { - f := fakerOtherBadFields(name) - return &f -} - -// fakerOtherSaverLoader implements SaverLoader methods. -type fakerOtherSaverLoader string - -func (f *fakerOtherSaverLoader) StateTypeName() string { - return string(*f) -} - -func (f *fakerOtherSaverLoader) StateFields() []string { - return nil -} - -func (f *fakerOtherSaverLoader) StateSave(m state.Sink) {} - -func (f *fakerOtherSaverLoader) StateLoad(m state.Source) {} - -func newFakerOtherSaverLoader(name string) *fakerOtherSaverLoader { - f := fakerOtherSaverLoader(name) - return &f -} - -func TestRegisterPrimitives(t *testing.T) { - for _, typeName := range []string{ - "int", - "int8", - "int16", - "int32", - "int64", - "uint", - "uintptr", - "uint8", - "uint16", - "uint32", - "uint64", - "float32", - "float64", - "complex64", - "complex128", - "string", - } { - t.Run("struct/"+typeName, func(t *testing.T) { - defer func() { - if r := recover(); r == nil { - t.Errorf("Registering type %q did not panic", typeName) - } - }() - state.Register(&faker{ - Name: typeName, - }) - }) - t.Run("other/"+typeName, func(t *testing.T) { - defer func() { - if r := recover(); r == nil { - t.Errorf("Registering type %q did not panic", typeName) - } - }() - state.Register(newFakerOther(typeName)) - }) - } -} - -func TestRegisterBad(t *testing.T) { - const ( - goodName = "foo" - firstField = "a" - secondField = "b" - ) - for name, object := range map[string]state.Type{ - "non-struct-with-fields": newFakerOtherBadFields(goodName), - "non-struct-with-saverloader": newFakerOtherSaverLoader(goodName), - "struct-without-saverloader": &faker{Name: goodName}, - "non-struct-duplicate-with-struct": newFakerOther((new(alreadyRegisteredStruct)).StateTypeName()), - "non-struct-duplicate-with-non-struct": newFakerOther((new(alreadyRegisteredOther)).StateTypeName()), - "struct-duplicate-with-struct": &fakerWithSaverLoader{faker{Name: (new(alreadyRegisteredStruct)).StateTypeName()}}, - "struct-duplicate-with-non-struct": &fakerWithSaverLoader{faker{Name: (new(alreadyRegisteredOther)).StateTypeName()}}, - "struct-with-empty-field": &fakerWithSaverLoader{faker{Name: goodName, Fields: []string{""}}}, - "struct-with-empty-field-and-non-empty": &fakerWithSaverLoader{faker{Name: goodName, Fields: []string{firstField, ""}}}, - "struct-with-duplicate-field": &fakerWithSaverLoader{faker{Name: goodName, Fields: []string{firstField, firstField}}}, - "struct-with-duplicate-field-and-non-dup": &fakerWithSaverLoader{faker{Name: goodName, Fields: []string{firstField, secondField, firstField}}}, - } { - t.Run(name, func(t *testing.T) { - defer func() { - if r := recover(); r == nil { - t.Errorf("Registering object %#v did not panic", object) - } - }() - state.Register(object) - }) - - } -} - -func TestRegisterTypeOnlyStruct(t *testing.T) { - defer func() { - if r := recover(); r == nil { - t.Errorf("Register did not panic") - } - }() - state.Register((*typeOnlyEmptyStruct)(nil)) -} diff --git a/pkg/state/tests/string_test.go b/pkg/state/tests/string_test.go deleted file mode 100644 index 44f5a562c..000000000 --- a/pkg/state/tests/string_test.go +++ /dev/null @@ -1,34 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tests - -import ( - "testing" -) - -const nonEmptyString = "hello world" - -var allStrings = []string{ - "", - nonEmptyString, - "\\0", -} - -func TestString(t *testing.T) { - runTestCases(t, false, "plain", flatten(allStrings)) - runTestCases(t, false, "pointers", pointersTo(flatten(allStrings))) - runTestCases(t, false, "interfaces", interfacesTo(flatten(allStrings))) - runTestCases(t, false, "interfacesToPointers", interfacesTo(pointersTo(flatten(allStrings)))) -} diff --git a/pkg/state/tests/struct.go b/pkg/state/tests/struct.go deleted file mode 100644 index 69143d194..000000000 --- a/pkg/state/tests/struct.go +++ /dev/null @@ -1,100 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tests - -type unregisteredEmptyStruct struct{} - -// typeOnlyEmptyStruct just implements the state.Type interface. -type typeOnlyEmptyStruct struct{} - -func (*typeOnlyEmptyStruct) StateTypeName() string { return "registeredEmptyStruct" } - -func (*typeOnlyEmptyStruct) StateFields() []string { return nil } - -// +stateify savable -type savableEmptyStruct struct{} - -// +stateify savable -type emptyStructPointer struct { - nothing *struct{} -} - -// +stateify savable -type outerSame struct { - inner inner -} - -// +stateify savable -type outerFieldFirst struct { - inner inner - v int64 -} - -// +stateify savable -type outerFieldSecond struct { - v int64 - inner inner -} - -// +stateify savable -type outerArray struct { - inner [2]inner -} - -// +stateify savable -type outerSlice struct { - inner []inner -} - -// +stateify savable -type inner struct { - v int64 -} - -// +stateify savable -type outerFieldValue struct { - inner innerFieldValue -} - -// +stateify savable -type innerFieldValue struct { - v int64 `state:".(*savedFieldValue)"` -} - -// +stateify savable -type savedFieldValue struct { - v int64 -} - -func (ifv *innerFieldValue) saveV() *savedFieldValue { - return &savedFieldValue{ifv.v} -} - -func (ifv *innerFieldValue) loadV(sfv *savedFieldValue) { - ifv.v = sfv.v -} - -// +stateify savable -type system struct { - v1 interface{} - v2 interface{} -} - -// +stateify savable -type system3 struct { - v1 interface{} - v2 interface{} - v3 interface{} -} diff --git a/pkg/state/tests/struct_test.go b/pkg/state/tests/struct_test.go deleted file mode 100644 index 9826f1ee9..000000000 --- a/pkg/state/tests/struct_test.go +++ /dev/null @@ -1,100 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tests - -import ( - "math/rand" - "testing" -) - -func TestEmptyStruct(t *testing.T) { - runTestCases(t, false, "plain", []interface{}{ - unregisteredEmptyStruct{}, - typeOnlyEmptyStruct{}, - savableEmptyStruct{}, - }) - runTestCases(t, false, "pointers", pointersTo([]interface{}{ - unregisteredEmptyStruct{}, - typeOnlyEmptyStruct{}, - savableEmptyStruct{}, - })) - runTestCases(t, false, "interfaces-pass", interfacesTo([]interface{}{ - // Only registered types can be dispatched via interfaces. All - // other types should fail, even if it is the empty struct. - savableEmptyStruct{}, - })) - runTestCases(t, true, "interfaces-fail", interfacesTo([]interface{}{ - unregisteredEmptyStruct{}, - typeOnlyEmptyStruct{}, - })) - runTestCases(t, false, "interfacesToPointers-pass", interfacesTo(pointersTo([]interface{}{ - savableEmptyStruct{}, - }))) - runTestCases(t, true, "interfacesToPointers-fail", interfacesTo(pointersTo([]interface{}{ - unregisteredEmptyStruct{}, - typeOnlyEmptyStruct{}, - }))) - - // Ensuring empty struct aliasing works. - es := emptyStructPointer{new(struct{})} - runTestCases(t, false, "empty-struct-pointers", []interface{}{ - emptyStructPointer{}, - es, - []emptyStructPointer{es, es}, // Same pointer. - }) -} - -func TestEmbeddedPointers(t *testing.T) { - // Give each int64 a random value to prevent Go from using - // runtime.staticuint64s, which confounds tests for struct duplication. - magic := func() int64 { - for { - n := rand.Int63() - if n < 0 || n > 255 { - return n - } - } - } - - ofs := outerSame{inner{magic()}} - of1 := outerFieldFirst{inner{magic()}, magic()} - of2 := outerFieldSecond{magic(), inner{magic()}} - oa := outerArray{[2]inner{{magic()}, {magic()}}} - osl := outerSlice{oa.inner[:]} - ofv := outerFieldValue{innerFieldValue{magic()}} - - runTestCases(t, false, "embedded-pointers", []interface{}{ - system{&ofs, &ofs.inner}, - system{&ofs.inner, &ofs}, - system{&of1, &of1.inner}, - system{&of1.inner, &of1}, - system{&of2, &of2.inner}, - system{&of2.inner, &of2}, - system{&oa, &oa.inner[0]}, - system{&oa, &oa.inner[1]}, - system{&oa.inner[0], &oa}, - system{&oa.inner[1], &oa}, - system3{&oa, &oa.inner[0], &oa.inner[1]}, - system3{&oa, &oa.inner[1], &oa.inner[0]}, - system3{&oa.inner[0], &oa, &oa.inner[1]}, - system3{&oa.inner[1], &oa, &oa.inner[0]}, - system3{&oa.inner[0], &oa.inner[1], &oa}, - system3{&oa.inner[1], &oa.inner[0], &oa}, - system{&oa, &osl}, - system{&osl, &oa}, - system{&ofv, &ofv.inner}, - system{&ofv.inner, &ofv}, - }) -} diff --git a/pkg/state/tests/tests.go b/pkg/state/tests/tests.go deleted file mode 100644 index 435a0e9db..000000000 --- a/pkg/state/tests/tests.go +++ /dev/null @@ -1,215 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package tests tests the state packages. -package tests - -import ( - "bytes" - "context" - "fmt" - "math" - "reflect" - "testing" - - "gvisor.dev/gvisor/pkg/state" - "gvisor.dev/gvisor/pkg/state/pretty" -) - -// discard is an implementation of wire.Writer. -type discard struct{} - -// Write implements wire.Writer.Write. -func (discard) Write(p []byte) (int, error) { return len(p), nil } - -// WriteByte implements wire.Writer.WriteByte. -func (discard) WriteByte(byte) error { return nil } - -// checkEqual checks if two objects are equal. -// -// N.B. This only handles one level of dereferences for NaN. Otherwise we -// would need to fork the entire implementation of reflect.DeepEqual. -func checkEqual(root, loadedValue interface{}) bool { - if reflect.DeepEqual(root, loadedValue) { - return true - } - - // NaN is not equal to itself. We handle the case of raw floating point - // primitives here, but don't handle this case nested. - rf32, ok1 := root.(float32) - lf32, ok2 := loadedValue.(float32) - if ok1 && ok2 && math.IsNaN(float64(rf32)) && math.IsNaN(float64(lf32)) { - return true - } - rf64, ok1 := root.(float64) - lf64, ok2 := loadedValue.(float64) - if ok1 && ok2 && math.IsNaN(rf64) && math.IsNaN(lf64) { - return true - } - - // Same real for complex numbers. - rc64, ok1 := root.(complex64) - lc64, ok2 := root.(complex64) - if ok1 && ok2 { - return checkEqual(real(rc64), real(lc64)) && checkEqual(imag(rc64), imag(lc64)) - } - rc128, ok1 := root.(complex128) - lc128, ok2 := root.(complex128) - if ok1 && ok2 { - return checkEqual(real(rc128), real(lc128)) && checkEqual(imag(rc128), imag(lc128)) - } - - return false -} - -// runTestCases runs a test for each object in objects. -func runTestCases(t *testing.T, shouldFail bool, prefix string, objects []interface{}) { - t.Helper() - for i, root := range objects { - t.Run(fmt.Sprintf("%s%d", prefix, i), func(t *testing.T) { - t.Logf("Original object:\n%#v", root) - - // Save the passed object. - saveBuffer := &bytes.Buffer{} - saveObjectPtr := reflect.New(reflect.TypeOf(root)) - saveObjectPtr.Elem().Set(reflect.ValueOf(root)) - saveStats, err := state.Save(context.Background(), saveBuffer, saveObjectPtr.Interface()) - if err != nil { - if shouldFail { - return - } - t.Fatalf("Save failed unexpectedly: %v", err) - } - - // Dump the serialized proto to aid with debugging. - var ppBuf bytes.Buffer - t.Logf("Raw state:\n%v", saveBuffer.Bytes()) - if err := pretty.PrintText(&ppBuf, bytes.NewReader(saveBuffer.Bytes())); err != nil { - // We don't count this as a test failure if we - // have shouldFail set, but we will count as a - // failure if we were not expecting to fail. - if !shouldFail { - t.Errorf("PrettyPrint(html=false) failed unexpected: %v", err) - } - } - if err := pretty.PrintHTML(discard{}, bytes.NewReader(saveBuffer.Bytes())); err != nil { - // See above. - if !shouldFail { - t.Errorf("PrettyPrint(html=true) failed unexpected: %v", err) - } - } - t.Logf("Encoded state:\n%s", ppBuf.String()) - t.Logf("Save stats:\n%s", saveStats.String()) - - // Load a new copy of the object. - loadObjectPtr := reflect.New(reflect.TypeOf(root)) - loadStats, err := state.Load(context.Background(), bytes.NewReader(saveBuffer.Bytes()), loadObjectPtr.Interface()) - if err != nil { - if shouldFail { - return - } - t.Fatalf("Load failed unexpectedly: %v", err) - } - - // Compare the values. - loadedValue := loadObjectPtr.Elem().Interface() - if !checkEqual(root, loadedValue) { - if shouldFail { - return - } - t.Fatalf("Objects differ:\n\toriginal: %#v\n\tloaded: %#v\n", root, loadedValue) - } - - // Everything went okay. Is that good? - if shouldFail { - t.Fatalf("This test was expected to fail, but didn't.") - } - t.Logf("Load stats:\n%s", loadStats.String()) - - // Truncate half the bytes in the byte stream, - // and ensure that we can't restore. Then - // truncate only the final byte and ensure that - // we can't restore. - l := saveBuffer.Len() - halfReader := bytes.NewReader(saveBuffer.Bytes()[:l/2]) - if _, err := state.Load(context.Background(), halfReader, loadObjectPtr.Interface()); err == nil { - t.Errorf("Load with half bytes succeeded unexpectedly.") - } - missingByteReader := bytes.NewReader(saveBuffer.Bytes()[:l-1]) - if _, err := state.Load(context.Background(), missingByteReader, loadObjectPtr.Interface()); err == nil { - t.Errorf("Load with missing byte succeeded unexpectedly.") - } - }) - } -} - -// convert converts the slice to an []interface{}. -func convert(v interface{}) (r []interface{}) { - s := reflect.ValueOf(v) // Must be slice. - for i := 0; i < s.Len(); i++ { - r = append(r, s.Index(i).Interface()) - } - return r -} - -// flatten flattens multiple slices. -func flatten(vs ...interface{}) (r []interface{}) { - for _, v := range vs { - r = append(r, convert(v)...) - } - return r -} - -// filter maps from one slice to another. -func filter(vs interface{}, fn func(interface{}) (interface{}, bool)) (r []interface{}) { - s := reflect.ValueOf(vs) - for i := 0; i < s.Len(); i++ { - v, ok := fn(s.Index(i).Interface()) - if ok { - r = append(r, v) - } - } - return r -} - -// combine combines objects in two slices as specified. -func combine(v1, v2 interface{}, fn func(_, _ interface{}) interface{}) (r []interface{}) { - s1 := reflect.ValueOf(v1) - s2 := reflect.ValueOf(v2) - for i := 0; i < s1.Len(); i++ { - for j := 0; j < s2.Len(); j++ { - // Combine using the given function. - r = append(r, fn(s1.Index(i).Interface(), s2.Index(j).Interface())) - } - } - return r -} - -// pointersTo is a filter function that returns pointers. -func pointersTo(vs interface{}) []interface{} { - return filter(vs, func(o interface{}) (interface{}, bool) { - v := reflect.New(reflect.TypeOf(o)) - v.Elem().Set(reflect.ValueOf(o)) - return v.Interface(), true - }) -} - -// interfacesTo is a filter function that returns interface objects. -func interfacesTo(vs interface{}) []interface{} { - return filter(vs, func(o interface{}) (interface{}, bool) { - var v [1]interface{} - v[0] = o - return v, true - }) -} diff --git a/pkg/state/wire/BUILD b/pkg/state/wire/BUILD deleted file mode 100644 index 311b93dcb..000000000 --- a/pkg/state/wire/BUILD +++ /dev/null @@ -1,12 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "wire", - srcs = ["wire.go"], - marshal = False, - stateify = False, - visibility = ["//:sandbox"], - deps = ["//pkg/gohacks"], -) diff --git a/pkg/sync/BUILD b/pkg/sync/BUILD deleted file mode 100644 index 73791b456..000000000 --- a/pkg/sync/BUILD +++ /dev/null @@ -1,49 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package( - default_visibility = ["//:sandbox"], - licenses = ["notice"], -) - -exports_files(["LICENSE"]) - -go_library( - name = "sync", - srcs = [ - "aliases.go", - "checklocks_off_unsafe.go", - "checklocks_on_unsafe.go", - "gate_unsafe.go", - "goyield_go113_unsafe.go", - "goyield_unsafe.go", - "mutex_unsafe.go", - "nocopy.go", - "norace_unsafe.go", - "race_amd64.s", - "race_arm64.s", - "race_unsafe.go", - "runtime_unsafe.go", - "rwmutex_unsafe.go", - "seqcount.go", - "sync.go", - ], - marshal = False, - stateify = False, - visibility = ["//:sandbox"], - deps = [ - "//pkg/gohacks", - "//pkg/goid", - ], -) - -go_test( - name = "sync_test", - size = "small", - srcs = [ - "gate_test.go", - "mutex_test.go", - "rwmutex_test.go", - "seqcount_test.go", - ], - library = ":sync", -) diff --git a/pkg/sync/LICENSE b/pkg/sync/LICENSE deleted file mode 100644 index 6a66aea5e..000000000 --- a/pkg/sync/LICENSE +++ /dev/null @@ -1,27 +0,0 @@ -Copyright (c) 2009 The Go Authors. All rights reserved. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are -met: - - * Redistributions of source code must retain the above copyright -notice, this list of conditions and the following disclaimer. - * Redistributions in binary form must reproduce the above -copyright notice, this list of conditions and the following disclaimer -in the documentation and/or other materials provided with the -distribution. - * Neither the name of Google Inc. nor the names of its -contributors may be used to endorse or promote products derived from -this software without specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/pkg/sync/README.md b/pkg/sync/README.md deleted file mode 100644 index be1a01f08..000000000 --- a/pkg/sync/README.md +++ /dev/null @@ -1,5 +0,0 @@ -# sync - -This package provides additional synchronization primitives not provided by the -Go stdlib 'sync' package. It is partially derived from the upstream 'sync' -package from go1.10. diff --git a/pkg/sync/atomicptr/BUILD b/pkg/sync/atomicptr/BUILD deleted file mode 100644 index a6a7f01ac..000000000 --- a/pkg/sync/atomicptr/BUILD +++ /dev/null @@ -1,36 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") -load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance") - -package(licenses = ["notice"]) - -go_template( - name = "generic_atomicptr", - srcs = ["generic_atomicptr_unsafe.go"], - types = [ - "Value", - ], - visibility = ["//:sandbox"], -) - -go_template_instance( - name = "atomicptr_int", - out = "atomicptr_int_unsafe.go", - package = "atomicptr", - suffix = "Int", - template = ":generic_atomicptr", - types = { - "Value": "int", - }, -) - -go_library( - name = "atomicptr", - srcs = ["atomicptr_int_unsafe.go"], -) - -go_test( - name = "atomicptr_test", - size = "small", - srcs = ["atomicptr_test.go"], - library = ":atomicptr", -) diff --git a/pkg/sync/atomicptr/atomicptr_test.go b/pkg/sync/atomicptr/atomicptr_test.go deleted file mode 100644 index 8fdc5112e..000000000 --- a/pkg/sync/atomicptr/atomicptr_test.go +++ /dev/null @@ -1,31 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package atomicptr - -import ( - "testing" -) - -func newInt(val int) *int { - return &val -} - -func TestAtomicPtr(t *testing.T) { - var p AtomicPtrInt - if got := p.Load(); got != nil { - t.Errorf("initial value is %p (%v), wanted nil", got, got) - } - want := newInt(42) - p.Store(want) - if got := p.Load(); got != want { - t.Errorf("wrong value: got %p (%v), wanted %p (%v)", got, got, want, want) - } - want = newInt(100) - p.Store(want) - if got := p.Load(); got != want { - t.Errorf("wrong value: got %p (%v), wanted %p (%v)", got, got, want, want) - } -} diff --git a/pkg/sync/atomicptrmap/BUILD b/pkg/sync/atomicptrmap/BUILD deleted file mode 100644 index b0e218c79..000000000 --- a/pkg/sync/atomicptrmap/BUILD +++ /dev/null @@ -1,76 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") -load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance") - -package( - default_visibility = ["//visibility:private"], - licenses = ["notice"], -) - -go_template( - name = "generic_atomicptrmap", - srcs = ["generic_atomicptrmap_unsafe.go"], - opt_consts = [ - "ShardOrder", - ], - opt_types = [ - "Hasher", - ], - types = [ - "Key", - "Value", - ], - deps = [ - "//pkg/gohacks", - "//pkg/sync", - ], -) - -go_template_instance( - name = "test_atomicptrmap", - out = "test_atomicptrmap_unsafe.go", - package = "atomicptrmap", - prefix = "test", - template = ":generic_atomicptrmap", - types = { - "Key": "int64", - "Value": "testValue", - }, -) - -go_template_instance( - name = "test_atomicptrmap_sharded", - out = "test_atomicptrmap_sharded_unsafe.go", - consts = { - "ShardOrder": "4", - }, - package = "atomicptrmap", - prefix = "test", - suffix = "Sharded", - template = ":generic_atomicptrmap", - types = { - "Key": "int64", - "Value": "testValue", - }, -) - -go_library( - name = "atomicptrmap", - testonly = 1, - srcs = [ - "atomicptrmap.go", - "test_atomicptrmap_sharded_unsafe.go", - "test_atomicptrmap_unsafe.go", - ], - deps = [ - "//pkg/gohacks", - "//pkg/sync", - ], -) - -go_test( - name = "atomicptrmap_test", - size = "small", - srcs = ["atomicptrmap_test.go"], - library = ":atomicptrmap", - deps = ["//pkg/sync"], -) diff --git a/pkg/sync/atomicptrmap/atomicptrmap.go b/pkg/sync/atomicptrmap/atomicptrmap.go deleted file mode 100644 index 867821ce9..000000000 --- a/pkg/sync/atomicptrmap/atomicptrmap.go +++ /dev/null @@ -1,20 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package atomicptrmap instantiates generic_atomicptrmap for testing. -package atomicptrmap - -type testValue struct { - val int -} diff --git a/pkg/sync/atomicptrmap/atomicptrmap_test.go b/pkg/sync/atomicptrmap/atomicptrmap_test.go deleted file mode 100644 index 75a9997ef..000000000 --- a/pkg/sync/atomicptrmap/atomicptrmap_test.go +++ /dev/null @@ -1,635 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package atomicptrmap - -import ( - "context" - "fmt" - "math/rand" - "reflect" - "runtime" - "testing" - "time" - - "gvisor.dev/gvisor/pkg/sync" -) - -func TestConsistencyWithGoMap(t *testing.T) { - const maxKey = 16 - var vals [4]*testValue - for i := 1; /* leave vals[0] nil */ i < len(vals); i++ { - vals[i] = new(testValue) - } - var ( - m = make(map[int64]*testValue) - apm testAtomicPtrMap - ) - for i := 0; i < 100000; i++ { - // Apply a random operation to both m and apm and expect them to have - // the same result. Bias toward CompareAndSwap, which has the most - // cases; bias away from Range and RangeRepeatable, which are - // relatively expensive. - switch rand.Intn(10) { - case 0, 1: // Load - key := rand.Int63n(maxKey) - want := m[key] - got := apm.Load(key) - t.Logf("Load(%d) = %p", key, got) - if got != want { - t.Fatalf("got %p, wanted %p", got, want) - } - case 2, 3: // Swap - key := rand.Int63n(maxKey) - val := vals[rand.Intn(len(vals))] - want := m[key] - if val != nil { - m[key] = val - } else { - delete(m, key) - } - got := apm.Swap(key, val) - t.Logf("Swap(%d, %p) = %p", key, val, got) - if got != want { - t.Fatalf("got %p, wanted %p", got, want) - } - case 4, 5, 6, 7: // CompareAndSwap - key := rand.Int63n(maxKey) - oldVal := vals[rand.Intn(len(vals))] - newVal := vals[rand.Intn(len(vals))] - want := m[key] - if want == oldVal { - if newVal != nil { - m[key] = newVal - } else { - delete(m, key) - } - } - got := apm.CompareAndSwap(key, oldVal, newVal) - t.Logf("CompareAndSwap(%d, %p, %p) = %p", key, oldVal, newVal, got) - if got != want { - t.Fatalf("got %p, wanted %p", got, want) - } - case 8: // Range - got := make(map[int64]*testValue) - var ( - haveDup = false - dup int64 - ) - apm.Range(func(key int64, val *testValue) bool { - if _, ok := got[key]; ok && !haveDup { - haveDup = true - dup = key - } - got[key] = val - return true - }) - t.Logf("Range() = %v", got) - if !reflect.DeepEqual(got, m) { - t.Fatalf("got %v, wanted %v", got, m) - } - if haveDup { - t.Fatalf("got duplicate key %d", dup) - } - case 9: // RangeRepeatable - got := make(map[int64]*testValue) - apm.RangeRepeatable(func(key int64, val *testValue) bool { - got[key] = val - return true - }) - t.Logf("RangeRepeatable() = %v", got) - if !reflect.DeepEqual(got, m) { - t.Fatalf("got %v, wanted %v", got, m) - } - } - } -} - -func TestConcurrentHeterogeneous(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - var ( - apm testAtomicPtrMap - wg sync.WaitGroup - ) - defer func() { - cancel() - wg.Wait() - }() - - possibleKeyValuePairs := make(map[int64]map[*testValue]struct{}) - addKeyValuePair := func(key int64, val *testValue) { - values := possibleKeyValuePairs[key] - if values == nil { - values = make(map[*testValue]struct{}) - possibleKeyValuePairs[key] = values - } - values[val] = struct{}{} - } - - const numValuesPerKey = 4 - - // These goroutines use keys not used by any other goroutine. - const numPrivateKeys = 3 - for i := 0; i < numPrivateKeys; i++ { - key := int64(i) - var vals [numValuesPerKey]*testValue - for i := 1; /* leave vals[0] nil */ i < len(vals); i++ { - val := new(testValue) - vals[i] = val - addKeyValuePair(key, val) - } - wg.Add(1) - go func() { - defer wg.Done() - r := rand.New(rand.NewSource(rand.Int63())) - var stored *testValue - for ctx.Err() == nil { - switch r.Intn(4) { - case 0: - got := apm.Load(key) - if got != stored { - t.Errorf("Load(%d): got %p, wanted %p", key, got, stored) - return - } - case 1: - val := vals[r.Intn(len(vals))] - want := stored - stored = val - got := apm.Swap(key, val) - if got != want { - t.Errorf("Swap(%d, %p): got %p, wanted %p", key, val, got, want) - return - } - case 2, 3: - oldVal := vals[r.Intn(len(vals))] - newVal := vals[r.Intn(len(vals))] - want := stored - if stored == oldVal { - stored = newVal - } - got := apm.CompareAndSwap(key, oldVal, newVal) - if got != want { - t.Errorf("CompareAndSwap(%d, %p, %p): got %p, wanted %p", key, oldVal, newVal, got, want) - return - } - } - } - }() - } - - // These goroutines share a small set of keys. - const numSharedKeys = 2 - var ( - sharedKeys [numSharedKeys]int64 - sharedValues = make(map[int64][]*testValue) - sharedValuesSet = make(map[int64]map[*testValue]struct{}) - ) - for i := range sharedKeys { - key := int64(numPrivateKeys + i) - sharedKeys[i] = key - vals := make([]*testValue, numValuesPerKey) - valsSet := make(map[*testValue]struct{}) - for j := range vals { - val := new(testValue) - vals[j] = val - valsSet[val] = struct{}{} - addKeyValuePair(key, val) - } - sharedValues[key] = vals - sharedValuesSet[key] = valsSet - } - randSharedValue := func(r *rand.Rand, key int64) *testValue { - vals := sharedValues[key] - return vals[r.Intn(len(vals))] - } - for i := 0; i < 3; i++ { - wg.Add(1) - go func() { - defer wg.Done() - r := rand.New(rand.NewSource(rand.Int63())) - for ctx.Err() == nil { - keyIndex := r.Intn(len(sharedKeys)) - key := sharedKeys[keyIndex] - var ( - op string - got *testValue - ) - switch r.Intn(4) { - case 0: - op = "Load" - got = apm.Load(key) - case 1: - op = "Swap" - got = apm.Swap(key, randSharedValue(r, key)) - case 2, 3: - op = "CompareAndSwap" - got = apm.CompareAndSwap(key, randSharedValue(r, key), randSharedValue(r, key)) - } - if got != nil { - valsSet := sharedValuesSet[key] - if _, ok := valsSet[got]; !ok { - t.Errorf("%s: got key %d, value %p; expected value in %v", op, key, got, valsSet) - return - } - } - } - }() - } - - // This goroutine repeatedly searches for unused keys. - wg.Add(1) - go func() { - defer wg.Done() - r := rand.New(rand.NewSource(rand.Int63())) - for ctx.Err() == nil { - key := -1 - r.Int63() - if got := apm.Load(key); got != nil { - t.Errorf("Load(%d): got %p, wanted nil", key, got) - } - } - }() - - // This goroutine repeatedly calls RangeRepeatable() and checks that each - // key corresponds to an expected value. - wg.Add(1) - go func() { - defer wg.Done() - abort := false - for !abort && ctx.Err() == nil { - apm.RangeRepeatable(func(key int64, val *testValue) bool { - values, ok := possibleKeyValuePairs[key] - if !ok { - t.Errorf("RangeRepeatable: got invalid key %d", key) - abort = true - return false - } - if _, ok := values[val]; !ok { - t.Errorf("RangeRepeatable: got key %d, value %p; expected one of %v", key, val, values) - abort = true - return false - } - return true - }) - } - }() - - // Finally, the main goroutine spins for the length of the test calling - // Range() and checking that each key that it observes is unique and - // corresponds to an expected value. - seenKeys := make(map[int64]struct{}) - const testDuration = 5 * time.Second - end := time.Now().Add(testDuration) - abort := false - for time.Now().Before(end) { - apm.Range(func(key int64, val *testValue) bool { - values, ok := possibleKeyValuePairs[key] - if !ok { - t.Errorf("Range: got invalid key %d", key) - abort = true - return false - } - if _, ok := values[val]; !ok { - t.Errorf("Range: got key %d, value %p; expected one of %v", key, val, values) - abort = true - return false - } - if _, ok := seenKeys[key]; ok { - t.Errorf("Range: got duplicate key %d", key) - abort = true - return false - } - seenKeys[key] = struct{}{} - return true - }) - if abort { - break - } - for k := range seenKeys { - delete(seenKeys, k) - } - } -} - -type benchmarkableMap interface { - Load(key int64) *testValue - Store(key int64, val *testValue) - LoadOrStore(key int64, val *testValue) (*testValue, bool) - Delete(key int64) -} - -// rwMutexMap implements benchmarkableMap for a RWMutex-protected Go map. -type rwMutexMap struct { - mu sync.RWMutex - m map[int64]*testValue -} - -func (m *rwMutexMap) Load(key int64) *testValue { - m.mu.RLock() - defer m.mu.RUnlock() - return m.m[key] -} - -func (m *rwMutexMap) Store(key int64, val *testValue) { - m.mu.Lock() - defer m.mu.Unlock() - if m.m == nil { - m.m = make(map[int64]*testValue) - } - m.m[key] = val -} - -func (m *rwMutexMap) LoadOrStore(key int64, val *testValue) (*testValue, bool) { - m.mu.Lock() - defer m.mu.Unlock() - if m.m == nil { - m.m = make(map[int64]*testValue) - } - if oldVal, ok := m.m[key]; ok { - return oldVal, true - } - m.m[key] = val - return val, false -} - -func (m *rwMutexMap) Delete(key int64) { - m.mu.Lock() - defer m.mu.Unlock() - delete(m.m, key) -} - -// syncMap implements benchmarkableMap for a sync.Map. -type syncMap struct { - m sync.Map -} - -func (m *syncMap) Load(key int64) *testValue { - val, ok := m.m.Load(key) - if !ok { - return nil - } - return val.(*testValue) -} - -func (m *syncMap) Store(key int64, val *testValue) { - m.m.Store(key, val) -} - -func (m *syncMap) LoadOrStore(key int64, val *testValue) (*testValue, bool) { - actual, loaded := m.m.LoadOrStore(key, val) - return actual.(*testValue), loaded -} - -func (m *syncMap) Delete(key int64) { - m.m.Delete(key) -} - -// benchmarkableAtomicPtrMap implements benchmarkableMap for testAtomicPtrMap. -type benchmarkableAtomicPtrMap struct { - m testAtomicPtrMap -} - -func (m *benchmarkableAtomicPtrMap) Load(key int64) *testValue { - return m.m.Load(key) -} - -func (m *benchmarkableAtomicPtrMap) Store(key int64, val *testValue) { - m.m.Store(key, val) -} - -func (m *benchmarkableAtomicPtrMap) LoadOrStore(key int64, val *testValue) (*testValue, bool) { - if prev := m.m.CompareAndSwap(key, nil, val); prev != nil { - return prev, true - } - return val, false -} - -func (m *benchmarkableAtomicPtrMap) Delete(key int64) { - m.m.Store(key, nil) -} - -// benchmarkableAtomicPtrMapSharded implements benchmarkableMap for testAtomicPtrMapSharded. -type benchmarkableAtomicPtrMapSharded struct { - m testAtomicPtrMapSharded -} - -func (m *benchmarkableAtomicPtrMapSharded) Load(key int64) *testValue { - return m.m.Load(key) -} - -func (m *benchmarkableAtomicPtrMapSharded) Store(key int64, val *testValue) { - m.m.Store(key, val) -} - -func (m *benchmarkableAtomicPtrMapSharded) LoadOrStore(key int64, val *testValue) (*testValue, bool) { - if prev := m.m.CompareAndSwap(key, nil, val); prev != nil { - return prev, true - } - return val, false -} - -func (m *benchmarkableAtomicPtrMapSharded) Delete(key int64) { - m.m.Store(key, nil) -} - -var mapImpls = [...]struct { - name string - ctor func() benchmarkableMap -}{ - { - name: "RWMutexMap", - ctor: func() benchmarkableMap { - return new(rwMutexMap) - }, - }, - { - name: "SyncMap", - ctor: func() benchmarkableMap { - return new(syncMap) - }, - }, - { - name: "AtomicPtrMap", - ctor: func() benchmarkableMap { - return new(benchmarkableAtomicPtrMap) - }, - }, - { - name: "AtomicPtrMapSharded", - ctor: func() benchmarkableMap { - return new(benchmarkableAtomicPtrMapSharded) - }, - }, -} - -func benchmarkStoreDelete(b *testing.B, mapCtor func() benchmarkableMap) { - m := mapCtor() - val := &testValue{} - for i := 0; i < b.N; i++ { - m.Store(int64(i), val) - } - for i := 0; i < b.N; i++ { - m.Delete(int64(i)) - } -} - -func BenchmarkStoreDelete(b *testing.B) { - for _, mapImpl := range mapImpls { - b.Run(mapImpl.name, func(b *testing.B) { - benchmarkStoreDelete(b, mapImpl.ctor) - }) - } -} - -func benchmarkLoadOrStoreDelete(b *testing.B, mapCtor func() benchmarkableMap) { - m := mapCtor() - val := &testValue{} - for i := 0; i < b.N; i++ { - m.LoadOrStore(int64(i), val) - } - for i := 0; i < b.N; i++ { - m.Delete(int64(i)) - } -} - -func BenchmarkLoadOrStoreDelete(b *testing.B) { - for _, mapImpl := range mapImpls { - b.Run(mapImpl.name, func(b *testing.B) { - benchmarkLoadOrStoreDelete(b, mapImpl.ctor) - }) - } -} - -func benchmarkLookupPositive(b *testing.B, mapCtor func() benchmarkableMap) { - m := mapCtor() - val := &testValue{} - for i := 0; i < b.N; i++ { - m.Store(int64(i), val) - } - b.ResetTimer() - for i := 0; i < b.N; i++ { - m.Load(int64(i)) - } -} - -func BenchmarkLookupPositive(b *testing.B) { - for _, mapImpl := range mapImpls { - b.Run(mapImpl.name, func(b *testing.B) { - benchmarkLookupPositive(b, mapImpl.ctor) - }) - } -} - -func benchmarkLookupNegative(b *testing.B, mapCtor func() benchmarkableMap) { - m := mapCtor() - val := &testValue{} - for i := 0; i < b.N; i++ { - m.Store(int64(i), val) - } - b.ResetTimer() - for i := 0; i < b.N; i++ { - m.Load(int64(-1 - i)) - } -} - -func BenchmarkLookupNegative(b *testing.B) { - for _, mapImpl := range mapImpls { - b.Run(mapImpl.name, func(b *testing.B) { - benchmarkLookupNegative(b, mapImpl.ctor) - }) - } -} - -type benchmarkConcurrentOptions struct { - // loadsPerMutationPair is the number of map lookups between each - // insertion/deletion pair. - loadsPerMutationPair int - - // If changeKeys is true, the keys used by each goroutine change between - // iterations of the test. - changeKeys bool -} - -func benchmarkConcurrent(b *testing.B, mapCtor func() benchmarkableMap, opts benchmarkConcurrentOptions) { - var ( - started sync.WaitGroup - workers sync.WaitGroup - ) - started.Add(1) - - m := mapCtor() - val := &testValue{} - // Insert a large number of unused elements into the map so that used - // elements are distributed throughout memory. - for i := 0; i < 10000; i++ { - m.Store(int64(-1-i), val) - } - // n := ceil(b.N / (opts.loadsPerMutationPair + 2)) - n := (b.N + opts.loadsPerMutationPair + 1) / (opts.loadsPerMutationPair + 2) - for i, procs := 0, runtime.GOMAXPROCS(0); i < procs; i++ { - workerID := i - workers.Add(1) - go func() { - defer workers.Done() - started.Wait() - for i := 0; i < n; i++ { - var key int64 - if opts.changeKeys { - key = int64(workerID*n + i) - } else { - key = int64(workerID) - } - m.LoadOrStore(key, val) - for j := 0; j < opts.loadsPerMutationPair; j++ { - m.Load(key) - } - m.Delete(key) - } - }() - } - - b.ResetTimer() - started.Done() - workers.Wait() -} - -func BenchmarkConcurrent(b *testing.B) { - changeKeysChoices := [...]struct { - name string - val bool - }{ - {"FixedKeys", false}, - {"ChangingKeys", true}, - } - writePcts := [...]struct { - name string - loadsPerMutationPair int - }{ - {"1PercentWrites", 198}, - {"10PercentWrites", 18}, - {"50PercentWrites", 2}, - } - for _, changeKeys := range changeKeysChoices { - for _, writePct := range writePcts { - for _, mapImpl := range mapImpls { - name := fmt.Sprintf("%s_%s_%s", changeKeys.name, writePct.name, mapImpl.name) - b.Run(name, func(b *testing.B) { - benchmarkConcurrent(b, mapImpl.ctor, benchmarkConcurrentOptions{ - loadsPerMutationPair: writePct.loadsPerMutationPair, - changeKeys: changeKeys.val, - }) - }) - } - } - } -} diff --git a/pkg/sync/atomicptrmap/generic_atomicptrmap_unsafe.go b/pkg/sync/atomicptrmap/generic_atomicptrmap_unsafe.go deleted file mode 100644 index 3e98cb309..000000000 --- a/pkg/sync/atomicptrmap/generic_atomicptrmap_unsafe.go +++ /dev/null @@ -1,500 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package atomicptrmap doesn't exist. This file must be instantiated using the -// go_template_instance rule in tools/go_generics/defs.bzl. -package atomicptrmap - -import ( - "sync/atomic" - "unsafe" - - "gvisor.dev/gvisor/pkg/gohacks" - "gvisor.dev/gvisor/pkg/sync" -) - -// Key is a required type parameter. -type Key struct{} - -// Value is a required type parameter. -type Value struct{} - -const ( - // ShardOrder is an optional parameter specifying the base-2 log of the - // number of shards per AtomicPtrMap. Higher values of ShardOrder reduce - // unnecessary synchronization between unrelated concurrent operations, - // improving performance for write-heavy workloads, but increase memory - // usage for small maps. - ShardOrder = 0 -) - -// Hasher is an optional type parameter. If Hasher is provided, it must define -// the Init and Hash methods. One Hasher will be shared by all AtomicPtrMaps. -type Hasher struct { - defaultHasher -} - -// defaultHasher is the default Hasher. This indirection exists because -// defaultHasher must exist even if a custom Hasher is provided, to prevent the -// Go compiler from complaining about defaultHasher's unused imports. -type defaultHasher struct { - fn func(unsafe.Pointer, uintptr) uintptr - seed uintptr -} - -// Init initializes the Hasher. -func (h *defaultHasher) Init() { - h.fn = sync.MapKeyHasher(map[Key]*Value(nil)) - h.seed = sync.RandUintptr() -} - -// Hash returns the hash value for the given Key. -func (h *defaultHasher) Hash(key Key) uintptr { - return h.fn(gohacks.Noescape(unsafe.Pointer(&key)), h.seed) -} - -var hasher Hasher - -func init() { - hasher.Init() -} - -// An AtomicPtrMap maps Keys to non-nil pointers to Values. AtomicPtrMap are -// safe for concurrent use from multiple goroutines without additional -// synchronization. -// -// The zero value of AtomicPtrMap is empty (maps all Keys to nil) and ready for -// use. AtomicPtrMaps must not be copied after first use. -// -// sync.Map may be faster than AtomicPtrMap if most operations on the map are -// concurrent writes to a fixed set of keys. AtomicPtrMap is usually faster in -// other circumstances. -type AtomicPtrMap struct { - // AtomicPtrMap is implemented as a hash table with the following - // properties: - // - // * Collisions are resolved with quadratic probing. Of the two major - // alternatives, Robin Hood linear probing makes it difficult for writers - // to execute in parallel, and bucketing is less effective in Go due to - // lack of SIMD. - // - // * The table is optionally divided into shards indexed by hash to further - // reduce unnecessary synchronization. - - shards [1 << ShardOrder]apmShard -} - -func (m *AtomicPtrMap) shard(hash uintptr) *apmShard { - // Go defines right shifts >= width of shifted unsigned operand as 0, so - // this is correct even if ShardOrder is 0 (although nogo complains because - // nogo is dumb). - const indexLSB = unsafe.Sizeof(uintptr(0))*8 - ShardOrder - index := hash >> indexLSB - return (*apmShard)(unsafe.Pointer(uintptr(unsafe.Pointer(&m.shards)) + (index * unsafe.Sizeof(apmShard{})))) -} - -type apmShard struct { - apmShardMutationData - _ [apmShardMutationDataPadding]byte - apmShardLookupData - _ [apmShardLookupDataPadding]byte -} - -type apmShardMutationData struct { - dirtyMu sync.Mutex // serializes slot transitions out of empty - dirty uintptr // # slots with val != nil - count uintptr // # slots with val != nil and val != tombstone() - rehashMu sync.Mutex // serializes rehashing -} - -type apmShardLookupData struct { - seq sync.SeqCount // allows atomic reads of slots+mask - slots unsafe.Pointer // [mask+1]slot or nil; protected by rehashMu/seq - mask uintptr // always (a power of 2) - 1; protected by rehashMu/seq -} - -const ( - cacheLineBytes = 64 - // Cache line padding is enabled if sharding is. - apmEnablePadding = (ShardOrder + 63) >> 6 // 0 if ShardOrder == 0, 1 otherwise - // The -1 and +1 below are required to ensure that if unsafe.Sizeof(T) % - // cacheLineBytes == 0, then padding is 0 (rather than cacheLineBytes). - apmShardMutationDataRequiredPadding = cacheLineBytes - (((unsafe.Sizeof(apmShardMutationData{}) - 1) % cacheLineBytes) + 1) - apmShardMutationDataPadding = apmEnablePadding * apmShardMutationDataRequiredPadding - apmShardLookupDataRequiredPadding = cacheLineBytes - (((unsafe.Sizeof(apmShardLookupData{}) - 1) % cacheLineBytes) + 1) - apmShardLookupDataPadding = apmEnablePadding * apmShardLookupDataRequiredPadding - - // These define fractional thresholds for when apmShard.rehash() is called - // (i.e. the load factor) and when it rehases to a larger table - // respectively. They are chosen such that the rehash threshold = the - // expansion threshold + 1/2, so that when reuse of deleted slots is rare - // or non-existent, rehashing occurs after the insertion of at least 1/2 - // the table's size in new entries, which is acceptably infrequent. - apmRehashThresholdNum = 2 - apmRehashThresholdDen = 3 - apmExpansionThresholdNum = 1 - apmExpansionThresholdDen = 6 -) - -type apmSlot struct { - // slot states are indicated by val: - // - // * Empty: val == nil; key is meaningless. May transition to full or - // evacuated with dirtyMu locked. - // - // * Full: val != nil, tombstone(), or evacuated(); key is immutable. val - // is the Value mapped to key. May transition to deleted or evacuated. - // - // * Deleted: val == tombstone(); key is still immutable. key is mapped to - // no Value. May transition to full or evacuated. - // - // * Evacuated: val == evacuated(); key is immutable. Set by rehashing on - // slots that have already been moved, requiring readers to wait for - // rehashing to complete and use the new table. Terminal state. - // - // Note that once val is non-nil, it cannot become nil again. That is, the - // transition from empty to non-empty is irreversible for a given slot; - // the only way to create more empty slots is by rehashing. - val unsafe.Pointer - key Key -} - -func apmSlotAt(slots unsafe.Pointer, pos uintptr) *apmSlot { - return (*apmSlot)(unsafe.Pointer(uintptr(slots) + pos*unsafe.Sizeof(apmSlot{}))) -} - -var tombstoneObj byte - -func tombstone() unsafe.Pointer { - return unsafe.Pointer(&tombstoneObj) -} - -var evacuatedObj byte - -func evacuated() unsafe.Pointer { - return unsafe.Pointer(&evacuatedObj) -} - -// Load returns the Value stored in m for key. -func (m *AtomicPtrMap) Load(key Key) *Value { - hash := hasher.Hash(key) - shard := m.shard(hash) - -retry: - epoch := shard.seq.BeginRead() - slots := atomic.LoadPointer(&shard.slots) - mask := atomic.LoadUintptr(&shard.mask) - if !shard.seq.ReadOk(epoch) { - goto retry - } - if slots == nil { - return nil - } - - i := hash & mask - inc := uintptr(1) - for { - slot := apmSlotAt(slots, i) - slotVal := atomic.LoadPointer(&slot.val) - if slotVal == nil { - // Empty slot; end of probe sequence. - return nil - } - if slotVal == evacuated() { - // Racing with rehashing. - goto retry - } - if slot.key == key { - if slotVal == tombstone() { - return nil - } - return (*Value)(slotVal) - } - i = (i + inc) & mask - inc++ - } -} - -// Store stores the Value val for key. -func (m *AtomicPtrMap) Store(key Key, val *Value) { - m.maybeCompareAndSwap(key, false, nil, val) -} - -// Swap stores the Value val for key and returns the previously-mapped Value. -func (m *AtomicPtrMap) Swap(key Key, val *Value) *Value { - return m.maybeCompareAndSwap(key, false, nil, val) -} - -// CompareAndSwap checks that the Value stored for key is oldVal; if it is, it -// stores the Value newVal for key. CompareAndSwap returns the previous Value -// stored for key, whether or not it stores newVal. -func (m *AtomicPtrMap) CompareAndSwap(key Key, oldVal, newVal *Value) *Value { - return m.maybeCompareAndSwap(key, true, oldVal, newVal) -} - -func (m *AtomicPtrMap) maybeCompareAndSwap(key Key, compare bool, typedOldVal, typedNewVal *Value) *Value { - hash := hasher.Hash(key) - shard := m.shard(hash) - oldVal := tombstone() - if typedOldVal != nil { - oldVal = unsafe.Pointer(typedOldVal) - } - newVal := tombstone() - if typedNewVal != nil { - newVal = unsafe.Pointer(typedNewVal) - } - -retry: - epoch := shard.seq.BeginRead() - slots := atomic.LoadPointer(&shard.slots) - mask := atomic.LoadUintptr(&shard.mask) - if !shard.seq.ReadOk(epoch) { - goto retry - } - if slots == nil { - if (compare && oldVal != tombstone()) || newVal == tombstone() { - return nil - } - // Need to allocate a table before insertion. - shard.rehash(nil) - goto retry - } - - i := hash & mask - inc := uintptr(1) - for { - slot := apmSlotAt(slots, i) - slotVal := atomic.LoadPointer(&slot.val) - if slotVal == nil { - if (compare && oldVal != tombstone()) || newVal == tombstone() { - return nil - } - // Try to grab this slot for ourselves. - shard.dirtyMu.Lock() - slotVal = atomic.LoadPointer(&slot.val) - if slotVal == nil { - // Check if we need to rehash before dirtying a slot. - if dirty, capacity := shard.dirty+1, mask+1; dirty*apmRehashThresholdDen >= capacity*apmRehashThresholdNum { - shard.dirtyMu.Unlock() - shard.rehash(slots) - goto retry - } - slot.key = key - atomic.StorePointer(&slot.val, newVal) // transitions slot to full - shard.dirty++ - atomic.AddUintptr(&shard.count, 1) - shard.dirtyMu.Unlock() - return nil - } - // Raced with another store; the slot is no longer empty. Continue - // with the new value of slotVal since we may have raced with - // another store of key. - shard.dirtyMu.Unlock() - } - if slotVal == evacuated() { - // Racing with rehashing. - goto retry - } - if slot.key == key { - // We're reusing an existing slot, so rehashing isn't necessary. - for { - if (compare && oldVal != slotVal) || newVal == slotVal { - if slotVal == tombstone() { - return nil - } - return (*Value)(slotVal) - } - if atomic.CompareAndSwapPointer(&slot.val, slotVal, newVal) { - if slotVal == tombstone() { - atomic.AddUintptr(&shard.count, 1) - return nil - } - if newVal == tombstone() { - atomic.AddUintptr(&shard.count, ^uintptr(0) /* -1 */) - } - return (*Value)(slotVal) - } - slotVal = atomic.LoadPointer(&slot.val) - if slotVal == evacuated() { - goto retry - } - } - } - // This produces a triangular number sequence of offsets from the - // initially-probed position. - i = (i + inc) & mask - inc++ - } -} - -// rehash is marked nosplit to avoid preemption during table copying. -//go:nosplit -func (shard *apmShard) rehash(oldSlots unsafe.Pointer) { - shard.rehashMu.Lock() - defer shard.rehashMu.Unlock() - - if shard.slots != oldSlots { - // Raced with another call to rehash(). - return - } - - // Determine the size of the new table. Constraints: - // - // * The size of the table must be a power of two to ensure that every slot - // is visitable by every probe sequence under quadratic probing with - // triangular numbers. - // - // * The size of the table cannot decrease because even if shard.count is - // currently smaller than shard.dirty, concurrent stores that reuse - // existing slots can drive shard.count back up to a maximum of - // shard.dirty. - newSize := uintptr(8) // arbitrary initial size - if oldSlots != nil { - oldSize := shard.mask + 1 - newSize = oldSize - if count := atomic.LoadUintptr(&shard.count) + 1; count*apmExpansionThresholdDen > oldSize*apmExpansionThresholdNum { - newSize *= 2 - } - } - - // Allocate the new table. - newSlotsSlice := make([]apmSlot, newSize) - newSlotsHeader := (*gohacks.SliceHeader)(unsafe.Pointer(&newSlotsSlice)) - newSlots := newSlotsHeader.Data - newMask := newSize - 1 - - // Start a writer critical section now so that racing users of the old - // table that observe evacuated() wait for the new table. (But lock dirtyMu - // first since doing so may block, which we don't want to do during the - // writer critical section.) - shard.dirtyMu.Lock() - shard.seq.BeginWrite() - - if oldSlots != nil { - realCount := uintptr(0) - // Copy old entries to the new table. - oldMask := shard.mask - for i := uintptr(0); i <= oldMask; i++ { - oldSlot := apmSlotAt(oldSlots, i) - val := atomic.SwapPointer(&oldSlot.val, evacuated()) - if val == nil || val == tombstone() { - continue - } - hash := hasher.Hash(oldSlot.key) - j := hash & newMask - inc := uintptr(1) - for { - newSlot := apmSlotAt(newSlots, j) - if newSlot.val == nil { - newSlot.val = val - newSlot.key = oldSlot.key - break - } - j = (j + inc) & newMask - inc++ - } - realCount++ - } - // Update dirty to reflect that tombstones were not copied to the new - // table. Use realCount since a concurrent mutator may not have updated - // shard.count yet. - shard.dirty = realCount - } - - // Switch to the new table. - atomic.StorePointer(&shard.slots, newSlots) - atomic.StoreUintptr(&shard.mask, newMask) - - shard.seq.EndWrite() - shard.dirtyMu.Unlock() -} - -// Range invokes f on each Key-Value pair stored in m. If any call to f returns -// false, Range stops iteration and returns. -// -// Range does not necessarily correspond to any consistent snapshot of the -// Map's contents: no Key will be visited more than once, but if the Value for -// any Key is stored or deleted concurrently, Range may reflect any mapping for -// that Key from any point during the Range call. -// -// f must not call other methods on m. -func (m *AtomicPtrMap) Range(f func(key Key, val *Value) bool) { - for si := 0; si < len(m.shards); si++ { - shard := &m.shards[si] - if !shard.doRange(f) { - return - } - } -} - -func (shard *apmShard) doRange(f func(key Key, val *Value) bool) bool { - // We have to lock rehashMu because if we handled races with rehashing by - // retrying, f could see the same key twice. - shard.rehashMu.Lock() - defer shard.rehashMu.Unlock() - slots := shard.slots - if slots == nil { - return true - } - mask := shard.mask - for i := uintptr(0); i <= mask; i++ { - slot := apmSlotAt(slots, i) - slotVal := atomic.LoadPointer(&slot.val) - if slotVal == nil || slotVal == tombstone() { - continue - } - if !f(slot.key, (*Value)(slotVal)) { - return false - } - } - return true -} - -// RangeRepeatable is like Range, but: -// -// * RangeRepeatable may visit the same Key multiple times in the presence of -// concurrent mutators, possibly passing different Values to f in different -// calls. -// -// * It is safe for f to call other methods on m. -func (m *AtomicPtrMap) RangeRepeatable(f func(key Key, val *Value) bool) { - for si := 0; si < len(m.shards); si++ { - shard := &m.shards[si] - - retry: - epoch := shard.seq.BeginRead() - slots := atomic.LoadPointer(&shard.slots) - mask := atomic.LoadUintptr(&shard.mask) - if !shard.seq.ReadOk(epoch) { - goto retry - } - if slots == nil { - continue - } - - for i := uintptr(0); i <= mask; i++ { - slot := apmSlotAt(slots, i) - slotVal := atomic.LoadPointer(&slot.val) - if slotVal == evacuated() { - goto retry - } - if slotVal == nil || slotVal == tombstone() { - continue - } - if !f(slot.key, (*Value)(slotVal)) { - return - } - } - } -} diff --git a/pkg/sync/gate_test.go b/pkg/sync/gate_test.go deleted file mode 100644 index 82ce02b97..000000000 --- a/pkg/sync/gate_test.go +++ /dev/null @@ -1,129 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package sync - -import ( - "context" - "runtime" - "sync/atomic" - "testing" - "time" -) - -func TestGateBasic(t *testing.T) { - var g Gate - - if !g.Enter() { - t.Fatalf("Enter failed before Close") - } - g.Leave() - - g.Close() - if g.Enter() { - t.Fatalf("Enter succeeded after Close") - } -} - -func TestGateConcurrent(t *testing.T) { - // Each call to testGateConcurrentOnce tests behavior around a single call - // to Gate.Close, so run many short tests to increase the probability of - // flushing out any issues. - totalTime := 5 * time.Second - timePerTest := 20 * time.Millisecond - numTests := int(totalTime / timePerTest) - for i := 0; i < numTests; i++ { - testGateConcurrentOnce(t, timePerTest) - } -} - -func testGateConcurrentOnce(t *testing.T, d time.Duration) { - const numGoroutines = 1000 - - ctx, cancel := context.WithCancel(context.Background()) - var wg WaitGroup - defer func() { - cancel() - wg.Wait() - }() - - var g Gate - closeState := int32(0) // set to 1 before g.Close() and 2 after it returns - - // Start a large number of goroutines that repeatedly attempt to enter the - // gate and get the expected result. - for i := 0; i < numGoroutines; i++ { - wg.Add(1) - go func() { - defer wg.Done() - for ctx.Err() == nil { - closedBeforeEnter := atomic.LoadInt32(&closeState) == 2 - if g.Enter() { - closedBeforeLeave := atomic.LoadInt32(&closeState) == 2 - g.Leave() - if closedBeforeEnter { - t.Errorf("Enter succeeded after Close") - return - } - if closedBeforeLeave { - t.Errorf("Close returned before Leave") - return - } - } else { - if atomic.LoadInt32(&closeState) == 0 { - t.Errorf("Enter failed before Close") - return - } - } - // Go does not preempt busy loops until Go 1.14. - runtime.Gosched() - } - }() - } - - // Allow goroutines to enter the gate successfully for half of the test's - // duration, then close the gate and allow goroutines to fail to enter the - // gate for the remaining half. - time.Sleep(d / 2) - atomic.StoreInt32(&closeState, 1) - g.Close() - atomic.StoreInt32(&closeState, 2) - time.Sleep(d / 2) -} - -func BenchmarkGateEnterLeave(b *testing.B) { - var g Gate - for i := 0; i < b.N; i++ { - g.Enter() - g.Leave() - } -} - -func BenchmarkGateClose(b *testing.B) { - for i := 0; i < b.N; i++ { - var g Gate - g.Close() - } -} - -func BenchmarkGateEnterLeaveAsyncClose(b *testing.B) { - for i := 0; i < b.N; i++ { - var g Gate - g.Enter() - go func() { - g.Leave() - }() - g.Close() - } -} diff --git a/pkg/sync/mutex_test.go b/pkg/sync/mutex_test.go deleted file mode 100644 index 9e4e3f0b2..000000000 --- a/pkg/sync/mutex_test.go +++ /dev/null @@ -1,71 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package sync - -import ( - "sync" - "testing" - "unsafe" -) - -// TestStructSize verifies that syncMutex's size hasn't drifted from the -// standard library's version. -// -// The correctness of this package relies on these remaining in sync. -func TestStructSize(t *testing.T) { - const ( - got = unsafe.Sizeof(syncMutex{}) - want = unsafe.Sizeof(sync.Mutex{}) - ) - if got != want { - t.Errorf("got sizeof(syncMutex) = %d, want = sizeof(sync.Mutex) = %d", got, want) - } -} - -// TestFieldValues verifies that the semantics of syncMutex.state from the -// standard library's implementation. -// -// The correctness of this package relies on these remaining in sync. -func TestFieldValues(t *testing.T) { - var m Mutex - m.Lock() - if got := *m.m.state(); got != mutexLocked { - t.Errorf("got locked sync.Mutex.state = %d, want = %d", got, mutexLocked) - } - m.Unlock() - if got := *m.m.state(); got != mutexUnlocked { - t.Errorf("got unlocked sync.Mutex.state = %d, want = %d", got, mutexUnlocked) - } -} - -func TestDoubleTryLock(t *testing.T) { - var m Mutex - if !m.TryLock() { - t.Fatal("failed to aquire lock") - } - if m.TryLock() { - t.Fatal("unexpectedly succeeded in aquiring locked mutex") - } -} - -func TestTryLockAfterLock(t *testing.T) { - var m Mutex - m.Lock() - if m.TryLock() { - t.Fatal("unexpectedly succeeded in aquiring locked mutex") - } -} - -func TestTryLockUnlock(t *testing.T) { - var m Mutex - if !m.TryLock() { - t.Fatal("failed to aquire lock") - } - m.Unlock() // +checklocksforce - if !m.TryLock() { - t.Fatal("failed to aquire lock after unlock") - } -} diff --git a/pkg/sync/rwmutex_test.go b/pkg/sync/rwmutex_test.go deleted file mode 100644 index 56a88e712..000000000 --- a/pkg/sync/rwmutex_test.go +++ /dev/null @@ -1,205 +0,0 @@ -// Copyright 2009 The Go Authors. All rights reserved. -// Copyright 2019 The gVisor Authors. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// GOMAXPROCS=10 go test - -// Copy/pasted from the standard library's sync/rwmutex_test.go, except for the -// addition of downgradingWriter and the renaming of num_iterations to -// numIterations to shut up Golint. - -package sync - -import ( - "fmt" - "runtime" - "sync/atomic" - "testing" -) - -func parallelReader(m *RWMutex, clocked, cunlock, cdone chan bool) { - m.RLock() - clocked <- true - <-cunlock - m.RUnlock() - cdone <- true -} - -func doTestParallelReaders(numReaders, gomaxprocs int) { - runtime.GOMAXPROCS(gomaxprocs) - var m RWMutex - clocked := make(chan bool) - cunlock := make(chan bool) - cdone := make(chan bool) - for i := 0; i < numReaders; i++ { - go parallelReader(&m, clocked, cunlock, cdone) - } - // Wait for all parallel RLock()s to succeed. - for i := 0; i < numReaders; i++ { - <-clocked - } - for i := 0; i < numReaders; i++ { - cunlock <- true - } - // Wait for the goroutines to finish. - for i := 0; i < numReaders; i++ { - <-cdone - } -} - -func TestParallelReaders(t *testing.T) { - defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(-1)) - doTestParallelReaders(1, 4) - doTestParallelReaders(3, 4) - doTestParallelReaders(4, 2) -} - -func reader(rwm *RWMutex, numIterations int, activity *int32, cdone chan bool) { - for i := 0; i < numIterations; i++ { - rwm.RLock() - n := atomic.AddInt32(activity, 1) - if n < 1 || n >= 10000 { - panic(fmt.Sprintf("wlock(%d)\n", n)) - } - for i := 0; i < 100; i++ { - } - atomic.AddInt32(activity, -1) - rwm.RUnlock() - } - cdone <- true -} - -func writer(rwm *RWMutex, numIterations int, activity *int32, cdone chan bool) { - for i := 0; i < numIterations; i++ { - rwm.Lock() - n := atomic.AddInt32(activity, 10000) - if n != 10000 { - panic(fmt.Sprintf("wlock(%d)\n", n)) - } - for i := 0; i < 100; i++ { - } - atomic.AddInt32(activity, -10000) - rwm.Unlock() - } - cdone <- true -} - -func downgradingWriter(rwm *RWMutex, numIterations int, activity *int32, cdone chan bool) { - for i := 0; i < numIterations; i++ { - rwm.Lock() - n := atomic.AddInt32(activity, 10000) - if n != 10000 { - panic(fmt.Sprintf("wlock(%d)\n", n)) - } - for i := 0; i < 100; i++ { - } - atomic.AddInt32(activity, -10000) - rwm.DowngradeLock() - n = atomic.AddInt32(activity, 1) - if n < 1 || n >= 10000 { - panic(fmt.Sprintf("wlock(%d)\n", n)) - } - for i := 0; i < 100; i++ { - } - atomic.AddInt32(activity, -1) - rwm.RUnlock() - } - cdone <- true -} - -func HammerDowngradableRWMutex(gomaxprocs, numReaders, numIterations int) { - runtime.GOMAXPROCS(gomaxprocs) - // Number of active readers + 10000 * number of active writers. - var activity int32 - var rwm RWMutex - cdone := make(chan bool) - go writer(&rwm, numIterations, &activity, cdone) - go downgradingWriter(&rwm, numIterations, &activity, cdone) - var i int - for i = 0; i < numReaders/2; i++ { - go reader(&rwm, numIterations, &activity, cdone) - } - go writer(&rwm, numIterations, &activity, cdone) - go downgradingWriter(&rwm, numIterations, &activity, cdone) - for ; i < numReaders; i++ { - go reader(&rwm, numIterations, &activity, cdone) - } - // Wait for the 4 writers and all readers to finish. - for i := 0; i < 4+numReaders; i++ { - <-cdone - } -} - -func TestDowngradableRWMutex(t *testing.T) { - defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(-1)) - n := 1000 - if testing.Short() { - n = 5 - } - HammerDowngradableRWMutex(1, 1, n) - HammerDowngradableRWMutex(1, 3, n) - HammerDowngradableRWMutex(1, 10, n) - HammerDowngradableRWMutex(4, 1, n) - HammerDowngradableRWMutex(4, 3, n) - HammerDowngradableRWMutex(4, 10, n) - HammerDowngradableRWMutex(10, 1, n) - HammerDowngradableRWMutex(10, 3, n) - HammerDowngradableRWMutex(10, 10, n) - HammerDowngradableRWMutex(10, 5, n) -} - -func TestRWDoubleTryLock(t *testing.T) { - var rwm RWMutex - if !rwm.TryLock() { - t.Fatal("failed to aquire lock") - } - if rwm.TryLock() { - t.Fatal("unexpectedly succeeded in aquiring locked mutex") - } -} - -func TestRWTryLockAfterLock(t *testing.T) { - var rwm RWMutex - rwm.Lock() - if rwm.TryLock() { - t.Fatal("unexpectedly succeeded in aquiring locked mutex") - } -} - -func TestRWTryLockUnlock(t *testing.T) { - var rwm RWMutex - if !rwm.TryLock() { - t.Fatal("failed to aquire lock") - } - rwm.Unlock() // +checklocksforce - if !rwm.TryLock() { - t.Fatal("failed to aquire lock after unlock") - } -} - -func TestTryRLockAfterLock(t *testing.T) { - var rwm RWMutex - rwm.Lock() - if rwm.TryRLock() { - t.Fatal("unexpectedly succeeded in aquiring locked mutex") - } -} - -func TestTryLockAfterRLock(t *testing.T) { - var rwm RWMutex - rwm.RLock() - if rwm.TryLock() { - t.Fatal("unexpectedly succeeded in aquiring locked mutex") - } -} - -func TestDoubleTryRLock(t *testing.T) { - var rwm RWMutex - if !rwm.TryRLock() { - t.Fatal("failed to aquire lock") - } - if !rwm.TryRLock() { - t.Fatal("failed to read aquire read locked lock") - } -} diff --git a/pkg/sync/seqatomic/BUILD b/pkg/sync/seqatomic/BUILD deleted file mode 100644 index 60f79ab54..000000000 --- a/pkg/sync/seqatomic/BUILD +++ /dev/null @@ -1,45 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") -load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance") - -package(licenses = ["notice"]) - -go_template( - name = "generic_seqatomic", - srcs = ["generic_seqatomic_unsafe.go"], - types = [ - "Value", - ], - visibility = ["//:sandbox"], - deps = [ - ":sync", - "//pkg/gohacks", - ], -) - -go_template_instance( - name = "seqatomic_int", - out = "seqatomic_int_unsafe.go", - package = "seqatomic", - suffix = "Int", - template = ":generic_seqatomic", - types = { - "Value": "int", - }, -) - -go_library( - name = "seqatomic", - srcs = ["seqatomic_int_unsafe.go"], - deps = [ - "//pkg/gohacks", - "//pkg/sync", - ], -) - -go_test( - name = "seqatomic_test", - size = "small", - srcs = ["seqatomic_test.go"], - library = ":seqatomic", - deps = ["//pkg/sync"], -) diff --git a/pkg/sync/seqatomic/generic_seqatomic_unsafe.go b/pkg/sync/seqatomic/generic_seqatomic_unsafe.go deleted file mode 100644 index 9578c9c52..000000000 --- a/pkg/sync/seqatomic/generic_seqatomic_unsafe.go +++ /dev/null @@ -1,50 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// Package seqatomic doesn't exist. This file must be instantiated using the -// go_template_instance rule in tools/go_generics/defs.bzl. -package seqatomic - -import ( - "unsafe" - - "gvisor.dev/gvisor/pkg/gohacks" - "gvisor.dev/gvisor/pkg/sync" -) - -// Value is a required type parameter. -type Value struct{} - -// SeqAtomicLoad returns a copy of *ptr, ensuring that the read does not race -// with any writer critical sections in seq. -// -//go:nosplit -func SeqAtomicLoad(seq *sync.SeqCount, ptr *Value) Value { - for { - if val, ok := SeqAtomicTryLoad(seq, seq.BeginRead(), ptr); ok { - return val - } - } -} - -// SeqAtomicTryLoad returns a copy of *ptr while in a reader critical section -// in seq initiated by a call to seq.BeginRead() that returned epoch. If the -// read would race with a writer critical section, SeqAtomicTryLoad returns -// (unspecified, false). -// -//go:nosplit -func SeqAtomicTryLoad(seq *sync.SeqCount, epoch sync.SeqCountEpoch, ptr *Value) (val Value, ok bool) { - if sync.RaceEnabled { - // runtime.RaceDisable() doesn't actually stop the race detector, so it - // can't help us here. Instead, call runtime.memmove directly, which is - // not instrumented by the race detector. - gohacks.Memmove(unsafe.Pointer(&val), unsafe.Pointer(ptr), unsafe.Sizeof(val)) - } else { - // This is ~40% faster for short reads than going through memmove. - val = *ptr - } - ok = seq.ReadOk(epoch) - return -} diff --git a/pkg/sync/seqatomic/seqatomic_test.go b/pkg/sync/seqatomic/seqatomic_test.go deleted file mode 100644 index 2c4568b07..000000000 --- a/pkg/sync/seqatomic/seqatomic_test.go +++ /dev/null @@ -1,132 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package seqatomic - -import ( - "sync/atomic" - "testing" - "time" - - "gvisor.dev/gvisor/pkg/sync" -) - -func TestSeqAtomicLoadUncontended(t *testing.T) { - var seq sync.SeqCount - const want = 1 - data := want - if got := SeqAtomicLoadInt(&seq, &data); got != want { - t.Errorf("SeqAtomicLoadInt: got %v, wanted %v", got, want) - } -} - -func TestSeqAtomicLoadAfterWrite(t *testing.T) { - var seq sync.SeqCount - var data int - const want = 1 - seq.BeginWrite() - data = want - seq.EndWrite() - if got := SeqAtomicLoadInt(&seq, &data); got != want { - t.Errorf("SeqAtomicLoadInt: got %v, wanted %v", got, want) - } -} - -func TestSeqAtomicLoadDuringWrite(t *testing.T) { - var seq sync.SeqCount - var data int - const want = 1 - seq.BeginWrite() - go func() { - time.Sleep(time.Second) - data = want - seq.EndWrite() - }() - if got := SeqAtomicLoadInt(&seq, &data); got != want { - t.Errorf("SeqAtomicLoadInt: got %v, wanted %v", got, want) - } -} - -func TestSeqAtomicTryLoadUncontended(t *testing.T) { - var seq sync.SeqCount - const want = 1 - data := want - epoch := seq.BeginRead() - if got, ok := SeqAtomicTryLoadInt(&seq, epoch, &data); !ok || got != want { - t.Errorf("SeqAtomicTryLoadInt: got (%v, %v), wanted (%v, true)", got, ok, want) - } -} - -func TestSeqAtomicTryLoadDuringWrite(t *testing.T) { - var seq sync.SeqCount - var data int - epoch := seq.BeginRead() - seq.BeginWrite() - if got, ok := SeqAtomicTryLoadInt(&seq, epoch, &data); ok { - t.Errorf("SeqAtomicTryLoadInt: got (%v, true), wanted (_, false)", got) - } - seq.EndWrite() -} - -func TestSeqAtomicTryLoadAfterWrite(t *testing.T) { - var seq sync.SeqCount - var data int - epoch := seq.BeginRead() - seq.BeginWrite() - seq.EndWrite() - if got, ok := SeqAtomicTryLoadInt(&seq, epoch, &data); ok { - t.Errorf("SeqAtomicTryLoadInt: got (%v, true), wanted (_, false)", got) - } -} - -func BenchmarkSeqAtomicLoadIntUncontended(b *testing.B) { - var seq sync.SeqCount - const want = 42 - data := want - b.RunParallel(func(pb *testing.PB) { - for pb.Next() { - if got := SeqAtomicLoadInt(&seq, &data); got != want { - b.Fatalf("SeqAtomicLoadInt: got %v, wanted %v", got, want) - } - } - }) -} - -func BenchmarkSeqAtomicTryLoadIntUncontended(b *testing.B) { - var seq sync.SeqCount - const want = 42 - data := want - b.RunParallel(func(pb *testing.PB) { - epoch := seq.BeginRead() - for pb.Next() { - if got, ok := SeqAtomicTryLoadInt(&seq, epoch, &data); !ok || got != want { - b.Fatalf("SeqAtomicTryLoadInt: got (%v, %v), wanted (%v, true)", got, ok, want) - } - } - }) -} - -// For comparison: -func BenchmarkAtomicValueLoadIntUncontended(b *testing.B) { - var a atomic.Value - const want = 42 - a.Store(int(want)) - b.RunParallel(func(pb *testing.PB) { - for pb.Next() { - if got := a.Load().(int); got != want { - b.Fatalf("atomic.Value.Load: got %v, wanted %v", got, want) - } - } - }) -} diff --git a/pkg/sync/seqcount_test.go b/pkg/sync/seqcount_test.go deleted file mode 100644 index 3f5592e3e..000000000 --- a/pkg/sync/seqcount_test.go +++ /dev/null @@ -1,100 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package sync - -import ( - "testing" - "time" -) - -func TestSeqCountWriteUncontended(t *testing.T) { - var seq SeqCount - seq.BeginWrite() - seq.EndWrite() -} - -func TestSeqCountReadUncontended(t *testing.T) { - var seq SeqCount - epoch := seq.BeginRead() - if !seq.ReadOk(epoch) { - t.Errorf("ReadOk: got false, wanted true") - } -} - -func TestSeqCountBeginReadAfterWrite(t *testing.T) { - var seq SeqCount - var data int32 - const want = 1 - seq.BeginWrite() - data = want - seq.EndWrite() - epoch := seq.BeginRead() - if data != want { - t.Errorf("Reader: got %v, wanted %v", data, want) - } - if !seq.ReadOk(epoch) { - t.Errorf("ReadOk: got false, wanted true") - } -} - -func TestSeqCountBeginReadDuringWrite(t *testing.T) { - var seq SeqCount - var data int - const want = 1 - seq.BeginWrite() - go func() { - time.Sleep(time.Second) - data = want - seq.EndWrite() - }() - epoch := seq.BeginRead() - if data != want { - t.Errorf("Reader: got %v, wanted %v", data, want) - } - if !seq.ReadOk(epoch) { - t.Errorf("ReadOk: got false, wanted true") - } -} - -func TestSeqCountReadOkAfterWrite(t *testing.T) { - var seq SeqCount - epoch := seq.BeginRead() - seq.BeginWrite() - seq.EndWrite() - if seq.ReadOk(epoch) { - t.Errorf("ReadOk: got true, wanted false") - } -} - -func TestSeqCountReadOkDuringWrite(t *testing.T) { - var seq SeqCount - epoch := seq.BeginRead() - seq.BeginWrite() - if seq.ReadOk(epoch) { - t.Errorf("ReadOk: got true, wanted false") - } - seq.EndWrite() -} - -func BenchmarkSeqCountWriteUncontended(b *testing.B) { - var seq SeqCount - for i := 0; i < b.N; i++ { - seq.BeginWrite() - seq.EndWrite() - } -} - -func BenchmarkSeqCountReadUncontended(b *testing.B) { - var seq SeqCount - b.RunParallel(func(pb *testing.PB) { - for pb.Next() { - epoch := seq.BeginRead() - if !seq.ReadOk(epoch) { - b.Fatalf("ReadOk: got false, wanted true") - } - } - }) -} diff --git a/pkg/syncevent/BUILD b/pkg/syncevent/BUILD deleted file mode 100644 index 42c553308..000000000 --- a/pkg/syncevent/BUILD +++ /dev/null @@ -1,35 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -licenses(["notice"]) - -go_library( - name = "syncevent", - srcs = [ - "broadcaster.go", - "receiver.go", - "source.go", - "syncevent.go", - "waiter_unsafe.go", - ], - visibility = ["//:sandbox"], - deps = [ - "//pkg/atomicbitops", - "//pkg/sync", - ], -) - -go_test( - name = "syncevent_test", - size = "small", - srcs = [ - "broadcaster_test.go", - "syncevent_example_test.go", - "waiter_test.go", - ], - library = ":syncevent", - deps = [ - "//pkg/sleep", - "//pkg/sync", - "//pkg/waiter", - ], -) diff --git a/pkg/syncevent/broadcaster.go b/pkg/syncevent/broadcaster.go deleted file mode 100644 index dabf08895..000000000 --- a/pkg/syncevent/broadcaster.go +++ /dev/null @@ -1,220 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package syncevent - -import ( - "gvisor.dev/gvisor/pkg/sync" -) - -// Broadcaster is an implementation of Source that supports any number of -// subscribed Receivers. -// -// The zero value of Broadcaster is valid and has no subscribed Receivers. -// Broadcaster is not copyable by value. -// -// All Broadcaster methods may be called concurrently from multiple goroutines. -type Broadcaster struct { - // Broadcaster is implemented as a hash table where keys are assigned by - // the Broadcaster and returned as SubscriptionIDs, making it safe to use - // the identity function for hashing. The hash table resolves collisions - // using linear probing and features Robin Hood insertion and backward - // shift deletion in order to support a relatively high load factor - // efficiently, which matters since the cost of Broadcast is linear in the - // size of the table. - - // mu protects the following fields. - mu sync.Mutex - - // Invariants: len(table) is 0 or a power of 2. - table []broadcasterSlot - - // load is the number of entries in table with receiver != nil. - load int - - lastID SubscriptionID -} - -type broadcasterSlot struct { - // Invariants: If receiver == nil, then filter == NoEvents and id == 0. - // Otherwise, id != 0. - receiver *Receiver - filter Set - id SubscriptionID -} - -const ( - broadcasterMinNonZeroTableSize = 2 // must be a power of 2 > 1 - - broadcasterMaxLoadNum = 13 - broadcasterMaxLoadDen = 16 -) - -// SubscribeEvents implements Source.SubscribeEvents. -func (b *Broadcaster) SubscribeEvents(r *Receiver, filter Set) SubscriptionID { - b.mu.Lock() - - // Assign an ID for this subscription. - b.lastID++ - id := b.lastID - - // Expand the table if over the maximum load factor: - // - // load / len(b.table) > broadcasterMaxLoadNum / broadcasterMaxLoadDen - // load * broadcasterMaxLoadDen > broadcasterMaxLoadNum * len(b.table) - b.load++ - if (b.load * broadcasterMaxLoadDen) > (broadcasterMaxLoadNum * len(b.table)) { - // Double the number of slots in the new table. - newlen := broadcasterMinNonZeroTableSize - if len(b.table) != 0 { - newlen = 2 * len(b.table) - } - if newlen <= cap(b.table) { - // Reuse excess capacity in the current table, moving entries not - // already in their first-probed positions to better ones. - newtable := b.table[:newlen] - newmask := uint64(newlen - 1) - for i := range b.table { - if b.table[i].receiver != nil && uint64(b.table[i].id)&newmask != uint64(i) { - entry := b.table[i] - b.table[i] = broadcasterSlot{} - broadcasterTableInsert(newtable, entry.id, entry.receiver, entry.filter) - } - } - b.table = newtable - } else { - newtable := make([]broadcasterSlot, newlen) - // Copy existing entries to the new table. - for i := range b.table { - if b.table[i].receiver != nil { - broadcasterTableInsert(newtable, b.table[i].id, b.table[i].receiver, b.table[i].filter) - } - } - // Switch to the new table. - b.table = newtable - } - } - - broadcasterTableInsert(b.table, id, r, filter) - b.mu.Unlock() - return id -} - -// Preconditions: -// * table must not be full. -// * len(table) is a power of 2. -func broadcasterTableInsert(table []broadcasterSlot, id SubscriptionID, r *Receiver, filter Set) { - entry := broadcasterSlot{ - receiver: r, - filter: filter, - id: id, - } - mask := uint64(len(table) - 1) - i := uint64(id) & mask - disp := uint64(0) - for { - if table[i].receiver == nil { - table[i] = entry - return - } - // If we've been displaced farther from our first-probed slot than the - // element stored in this one, swap elements and switch to inserting - // the replaced one. (This is Robin Hood insertion.) - slotDisp := (i - uint64(table[i].id)) & mask - if disp > slotDisp { - table[i], entry = entry, table[i] - disp = slotDisp - } - i = (i + 1) & mask - disp++ - } -} - -// UnsubscribeEvents implements Source.UnsubscribeEvents. -func (b *Broadcaster) UnsubscribeEvents(id SubscriptionID) { - b.mu.Lock() - - mask := uint64(len(b.table) - 1) - i := uint64(id) & mask - for { - if b.table[i].id == id { - // Found the element to remove. Move all subsequent elements - // backward until we either find an empty slot, or an element that - // is already in its first-probed slot. (This is backward shift - // deletion.) - for { - next := (i + 1) & mask - if b.table[next].receiver == nil { - break - } - if uint64(b.table[next].id)&mask == next { - break - } - b.table[i] = b.table[next] - i = next - } - b.table[i] = broadcasterSlot{} - break - } - i = (i + 1) & mask - } - - // If a table 1/4 of the current size would still be at or under the - // maximum load factor (i.e. the current table size is at least two - // expansions bigger than necessary), halve the size of the table to reduce - // the cost of Broadcast. Since we are concerned with iteration time and - // not memory usage, reuse the existing slice to reduce future allocations - // from table re-expansion. - b.load-- - if len(b.table) > broadcasterMinNonZeroTableSize && (b.load*(4*broadcasterMaxLoadDen)) <= (broadcasterMaxLoadNum*len(b.table)) { - newlen := len(b.table) / 2 - newtable := b.table[:newlen] - for i := newlen; i < len(b.table); i++ { - if b.table[i].receiver != nil { - broadcasterTableInsert(newtable, b.table[i].id, b.table[i].receiver, b.table[i].filter) - b.table[i] = broadcasterSlot{} - } - } - b.table = newtable - } - - b.mu.Unlock() -} - -// Broadcast notifies all Receivers subscribed to the Broadcaster of the subset -// of events to which they subscribed. The order in which Receivers are -// notified is unspecified. -func (b *Broadcaster) Broadcast(events Set) { - b.mu.Lock() - for i := range b.table { - if intersection := events & b.table[i].filter; intersection != 0 { - // We don't need to check if broadcasterSlot.receiver is nil, since - // if it is then broadcasterSlot.filter is 0. - b.table[i].receiver.Notify(intersection) - } - } - b.mu.Unlock() -} - -// FilteredEvents returns the set of events for which Broadcast will notify at -// least one Receiver, i.e. the union of filters for all subscribed Receivers. -func (b *Broadcaster) FilteredEvents() Set { - var es Set - b.mu.Lock() - for i := range b.table { - es |= b.table[i].filter - } - b.mu.Unlock() - return es -} diff --git a/pkg/syncevent/broadcaster_test.go b/pkg/syncevent/broadcaster_test.go deleted file mode 100644 index e88779e23..000000000 --- a/pkg/syncevent/broadcaster_test.go +++ /dev/null @@ -1,376 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package syncevent - -import ( - "fmt" - "math/rand" - "testing" - - "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/waiter" -) - -func TestBroadcasterFilter(t *testing.T) { - const numReceivers = 2 * MaxEvents - - var br Broadcaster - ws := make([]Waiter, numReceivers) - for i := range ws { - ws[i].Init() - br.SubscribeEvents(ws[i].Receiver(), 1<<(i%MaxEvents)) - } - for ev := 0; ev < MaxEvents; ev++ { - br.Broadcast(1 << ev) - for i := range ws { - want := NoEvents - if i%MaxEvents == ev { - want = 1 << ev - } - if got := ws[i].Receiver().PendingAndAckAll(); got != want { - t.Errorf("after Broadcast of event %d: waiter %d has pending event set %#x, wanted %#x", ev, i, got, want) - } - } - } -} - -// TestBroadcasterManySubscriptions tests that subscriptions are not lost by -// table expansion/compaction. -func TestBroadcasterManySubscriptions(t *testing.T) { - const numReceivers = 5000 // arbitrary - - var br Broadcaster - ws := make([]Waiter, numReceivers) - for i := range ws { - ws[i].Init() - } - - ids := make([]SubscriptionID, numReceivers) - for i := 0; i < numReceivers; i++ { - // Subscribe receiver i. - ids[i] = br.SubscribeEvents(ws[i].Receiver(), 1) - // Check that receivers [0, i] are subscribed. - br.Broadcast(1) - for j := 0; j <= i; j++ { - if ws[j].Pending() != 1 { - t.Errorf("receiver %d did not receive an event after subscription of receiver %d", j, i) - } - ws[j].Ack(1) - } - } - - // Generate a random order for unsubscriptions. - unsub := rand.Perm(numReceivers) - for i := 0; i < numReceivers; i++ { - // Unsubscribe receiver unsub[i]. - br.UnsubscribeEvents(ids[unsub[i]]) - // Check that receivers [unsub[0], unsub[i]] are not subscribed, and that - // receivers (unsub[i], unsub[numReceivers]) are still subscribed. - br.Broadcast(1) - for j := 0; j <= i; j++ { - if ws[unsub[j]].Pending() != 0 { - t.Errorf("unsub iteration %d: receiver %d received an event after unsubscription of receiver %d", i, unsub[j], unsub[i]) - } - } - for j := i + 1; j < numReceivers; j++ { - if ws[unsub[j]].Pending() != 1 { - t.Errorf("unsub iteration %d: receiver %d did not receive an event after unsubscription of receiver %d", i, unsub[j], unsub[i]) - } - ws[unsub[j]].Ack(1) - } - } -} - -var ( - receiverCountsNonZero = []int{1, 4, 16, 64} - receiverCountsIncludingZero = append([]int{0}, receiverCountsNonZero...) -) - -// BenchmarkBroadcasterX, BenchmarkMapX, and BenchmarkQueueX benchmark usage -// pattern X (described in terms of Broadcaster) with Broadcaster, a -// Mutex-protected map[*Receiver]Set, and waiter.Queue respectively. - -// BenchmarkXxxSubscribeUnsubscribe measures the cost of a Subscribe/Unsubscribe -// cycle. - -func BenchmarkBroadcasterSubscribeUnsubscribe(b *testing.B) { - var br Broadcaster - var w Waiter - w.Init() - - b.ResetTimer() - for i := 0; i < b.N; i++ { - id := br.SubscribeEvents(w.Receiver(), 1) - br.UnsubscribeEvents(id) - } -} - -func BenchmarkMapSubscribeUnsubscribe(b *testing.B) { - var mu sync.Mutex - m := make(map[*Receiver]Set) - var w Waiter - w.Init() - - b.ResetTimer() - for i := 0; i < b.N; i++ { - mu.Lock() - m[w.Receiver()] = Set(1) - mu.Unlock() - mu.Lock() - delete(m, w.Receiver()) - mu.Unlock() - } -} - -func BenchmarkQueueSubscribeUnsubscribe(b *testing.B) { - var q waiter.Queue - e, _ := waiter.NewChannelEntry(nil) - - b.ResetTimer() - for i := 0; i < b.N; i++ { - q.EventRegister(&e, 1) - q.EventUnregister(&e) - } -} - -// BenchmarkXxxSubscribeUnsubscribeBatch is similar to -// BenchmarkXxxSubscribeUnsubscribe, but subscribes and unsubscribes a large -// number of Receivers at a time in order to measure the amortized overhead of -// table expansion/compaction. (Since waiter.Queue is implemented using a -// linked list, BenchmarkQueueSubscribeUnsubscribe and -// BenchmarkQueueSubscribeUnsubscribeBatch should produce nearly the same -// result.) - -const numBatchReceivers = 1000 - -func BenchmarkBroadcasterSubscribeUnsubscribeBatch(b *testing.B) { - var br Broadcaster - ws := make([]Waiter, numBatchReceivers) - for i := range ws { - ws[i].Init() - } - ids := make([]SubscriptionID, numBatchReceivers) - - // Generate a random order for unsubscriptions. - unsub := rand.Perm(numBatchReceivers) - - b.ResetTimer() - for i := 0; i < b.N/numBatchReceivers; i++ { - for j := 0; j < numBatchReceivers; j++ { - ids[j] = br.SubscribeEvents(ws[j].Receiver(), 1) - } - for j := 0; j < numBatchReceivers; j++ { - br.UnsubscribeEvents(ids[unsub[j]]) - } - } -} - -func BenchmarkMapSubscribeUnsubscribeBatch(b *testing.B) { - var mu sync.Mutex - m := make(map[*Receiver]Set) - ws := make([]Waiter, numBatchReceivers) - for i := range ws { - ws[i].Init() - } - - // Generate a random order for unsubscriptions. - unsub := rand.Perm(numBatchReceivers) - - b.ResetTimer() - for i := 0; i < b.N/numBatchReceivers; i++ { - for j := 0; j < numBatchReceivers; j++ { - mu.Lock() - m[ws[j].Receiver()] = Set(1) - mu.Unlock() - } - for j := 0; j < numBatchReceivers; j++ { - mu.Lock() - delete(m, ws[unsub[j]].Receiver()) - mu.Unlock() - } - } -} - -func BenchmarkQueueSubscribeUnsubscribeBatch(b *testing.B) { - var q waiter.Queue - es := make([]waiter.Entry, numBatchReceivers) - for i := range es { - es[i], _ = waiter.NewChannelEntry(nil) - } - - // Generate a random order for unsubscriptions. - unsub := rand.Perm(numBatchReceivers) - - b.ResetTimer() - for i := 0; i < b.N/numBatchReceivers; i++ { - for j := 0; j < numBatchReceivers; j++ { - q.EventRegister(&es[j], 1) - } - for j := 0; j < numBatchReceivers; j++ { - q.EventUnregister(&es[unsub[j]]) - } - } -} - -// BenchmarkXxxBroadcastRedundant measures how long it takes to Broadcast -// already-pending events to multiple Receivers. - -func BenchmarkBroadcasterBroadcastRedundant(b *testing.B) { - for _, n := range receiverCountsIncludingZero { - b.Run(fmt.Sprintf("%d", n), func(b *testing.B) { - var br Broadcaster - ws := make([]Waiter, n) - for i := range ws { - ws[i].Init() - br.SubscribeEvents(ws[i].Receiver(), 1) - } - br.Broadcast(1) - - b.ResetTimer() - for i := 0; i < b.N; i++ { - br.Broadcast(1) - } - }) - } -} - -func BenchmarkMapBroadcastRedundant(b *testing.B) { - for _, n := range receiverCountsIncludingZero { - b.Run(fmt.Sprintf("%d", n), func(b *testing.B) { - var mu sync.Mutex - m := make(map[*Receiver]Set) - ws := make([]Waiter, n) - for i := range ws { - ws[i].Init() - m[ws[i].Receiver()] = Set(1) - } - mu.Lock() - for r := range m { - r.Notify(1) - } - mu.Unlock() - - b.ResetTimer() - for i := 0; i < b.N; i++ { - mu.Lock() - for r := range m { - r.Notify(1) - } - mu.Unlock() - } - }) - } -} - -func BenchmarkQueueBroadcastRedundant(b *testing.B) { - for _, n := range receiverCountsIncludingZero { - b.Run(fmt.Sprintf("%d", n), func(b *testing.B) { - var q waiter.Queue - for i := 0; i < n; i++ { - e, _ := waiter.NewChannelEntry(nil) - q.EventRegister(&e, 1) - } - q.Notify(1) - - b.ResetTimer() - for i := 0; i < b.N; i++ { - q.Notify(1) - } - }) - } -} - -// BenchmarkXxxBroadcastAck measures how long it takes to Broadcast events to -// multiple Receivers, check that all Receivers have received the event, and -// clear the event from all Receivers. - -func BenchmarkBroadcasterBroadcastAck(b *testing.B) { - for _, n := range receiverCountsNonZero { - b.Run(fmt.Sprintf("%d", n), func(b *testing.B) { - var br Broadcaster - ws := make([]Waiter, n) - for i := range ws { - ws[i].Init() - br.SubscribeEvents(ws[i].Receiver(), 1) - } - - b.ResetTimer() - for i := 0; i < b.N; i++ { - br.Broadcast(1) - for j := range ws { - if got, want := ws[j].Pending(), Set(1); got != want { - b.Fatalf("Receiver.Pending(): got %#x, wanted %#x", got, want) - } - ws[j].Ack(1) - } - } - }) - } -} - -func BenchmarkMapBroadcastAck(b *testing.B) { - for _, n := range receiverCountsNonZero { - b.Run(fmt.Sprintf("%d", n), func(b *testing.B) { - var mu sync.Mutex - m := make(map[*Receiver]Set) - ws := make([]Waiter, n) - for i := range ws { - ws[i].Init() - m[ws[i].Receiver()] = Set(1) - } - - b.ResetTimer() - for i := 0; i < b.N; i++ { - mu.Lock() - for r := range m { - r.Notify(1) - } - mu.Unlock() - for j := range ws { - if got, want := ws[j].Pending(), Set(1); got != want { - b.Fatalf("Receiver.Pending(): got %#x, wanted %#x", got, want) - } - ws[j].Ack(1) - } - } - }) - } -} - -func BenchmarkQueueBroadcastAck(b *testing.B) { - for _, n := range receiverCountsNonZero { - b.Run(fmt.Sprintf("%d", n), func(b *testing.B) { - var q waiter.Queue - chs := make([]chan struct{}, n) - for i := range chs { - e, ch := waiter.NewChannelEntry(nil) - q.EventRegister(&e, 1) - chs[i] = ch - } - - b.ResetTimer() - for i := 0; i < b.N; i++ { - q.Notify(1) - for _, ch := range chs { - select { - case <-ch: - default: - b.Fatalf("channel did not receive event") - } - } - } - }) - } -} diff --git a/pkg/syncevent/receiver.go b/pkg/syncevent/receiver.go deleted file mode 100644 index 5c86e5400..000000000 --- a/pkg/syncevent/receiver.go +++ /dev/null @@ -1,103 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package syncevent - -import ( - "sync/atomic" - - "gvisor.dev/gvisor/pkg/atomicbitops" -) - -// Receiver is an event sink that holds pending events and invokes a callback -// whenever new events become pending. Receiver's methods may be called -// concurrently from multiple goroutines. -// -// Receiver.Init() must be called before first use. -type Receiver struct { - // pending is the set of pending events. pending is accessed using atomic - // memory operations. - pending uint64 - - // cb is notified when new events become pending. cb is immutable after - // Init(). - cb ReceiverCallback -} - -// ReceiverCallback receives callbacks from a Receiver. -type ReceiverCallback interface { - // NotifyPending is called when the corresponding Receiver has new pending - // events. - // - // NotifyPending is called synchronously from Receiver.Notify(), so - // implementations must not take locks that may be held by callers of - // Receiver.Notify(). NotifyPending may be called concurrently from - // multiple goroutines. - NotifyPending() -} - -// Init must be called before first use of r. -func (r *Receiver) Init(cb ReceiverCallback) { - r.cb = cb -} - -// Pending returns the set of pending events. -func (r *Receiver) Pending() Set { - return Set(atomic.LoadUint64(&r.pending)) -} - -// Notify sets the given events as pending. -func (r *Receiver) Notify(es Set) { - p := Set(atomic.LoadUint64(&r.pending)) - // Optimization: Skip the atomic CAS on r.pending if all events are - // already pending. - if p&es == es { - return - } - // When this is uncontended (the common case), CAS is faster than - // atomic-OR because the former is inlined and the latter (which we - // implement in assembly ourselves) is not. - if !atomic.CompareAndSwapUint64(&r.pending, uint64(p), uint64(p|es)) { - // If the CAS fails, fall back to atomic-OR. - atomicbitops.OrUint64(&r.pending, uint64(es)) - } - r.cb.NotifyPending() -} - -// Ack unsets the given events as pending. -func (r *Receiver) Ack(es Set) { - p := Set(atomic.LoadUint64(&r.pending)) - // Optimization: Skip the atomic CAS on r.pending if all events are - // already not pending. - if p&es == 0 { - return - } - // When this is uncontended (the common case), CAS is faster than - // atomic-AND because the former is inlined and the latter (which we - // implement in assembly ourselves) is not. - if !atomic.CompareAndSwapUint64(&r.pending, uint64(p), uint64(p&^es)) { - // If the CAS fails, fall back to atomic-AND. - atomicbitops.AndUint64(&r.pending, ^uint64(es)) - } -} - -// PendingAndAckAll unsets all events as pending and returns the set of -// previously-pending events. -// -// PendingAndAckAll should only be used in preference to a call to Pending -// followed by a conditional call to Ack when the caller expects events to be -// pending (e.g. after a call to ReceiverCallback.NotifyPending()). -func (r *Receiver) PendingAndAckAll() Set { - return Set(atomic.SwapUint64(&r.pending, 0)) -} diff --git a/pkg/syncevent/source.go b/pkg/syncevent/source.go deleted file mode 100644 index d3d0f34c5..000000000 --- a/pkg/syncevent/source.go +++ /dev/null @@ -1,61 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package syncevent - -// Source represents an event source. -type Source interface { - // SubscribeEvents causes the Source to notify the given Receiver of the - // given subset of events. - // - // Preconditions: - // * r != nil. - // * The ReceiverCallback for r must not take locks that are ordered - // prior to the Source; for example, it cannot call any Source - // methods. - SubscribeEvents(r *Receiver, filter Set) SubscriptionID - - // UnsubscribeEvents causes the Source to stop notifying the Receiver - // subscribed by a previous call to SubscribeEvents that returned the given - // SubscriptionID. - // - // Preconditions: UnsubscribeEvents may be called at most once for any - // given SubscriptionID. - UnsubscribeEvents(id SubscriptionID) -} - -// SubscriptionID identifies a call to Source.SubscribeEvents. -type SubscriptionID uint64 - -// UnsubscribeAndAck is a convenience function that unsubscribes r from the -// given events from src and also clears them from r. -func UnsubscribeAndAck(src Source, r *Receiver, filter Set, id SubscriptionID) { - src.UnsubscribeEvents(id) - r.Ack(filter) -} - -// NoopSource implements Source by never sending events to subscribed -// Receivers. -type NoopSource struct{} - -// SubscribeEvents implements Source.SubscribeEvents. -func (NoopSource) SubscribeEvents(*Receiver, Set) SubscriptionID { - return 0 -} - -// UnsubscribeEvents implements Source.UnsubscribeEvents. -func (NoopSource) UnsubscribeEvents(SubscriptionID) { -} - -// See Broadcaster for a non-noop implementations of Source. diff --git a/pkg/syncevent/syncevent.go b/pkg/syncevent/syncevent.go deleted file mode 100644 index 9fb6a06de..000000000 --- a/pkg/syncevent/syncevent.go +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package syncevent provides efficient primitives for goroutine -// synchronization based on event bitmasks. -package syncevent - -// Set is a bitmask where each bit represents a distinct user-defined event. -// The event package does not treat any bits in Set specially. -type Set uint64 - -const ( - // NoEvents is a Set containing no events. - NoEvents = Set(0) - - // AllEvents is a Set containing all possible events. - AllEvents = ^Set(0) - - // MaxEvents is the number of distinct events that can be represented by a Set. - MaxEvents = 64 -) diff --git a/pkg/syncevent/syncevent_example_test.go b/pkg/syncevent/syncevent_example_test.go deleted file mode 100644 index bfb18e2ea..000000000 --- a/pkg/syncevent/syncevent_example_test.go +++ /dev/null @@ -1,108 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package syncevent - -import ( - "fmt" - "sync/atomic" - "time" -) - -func Example_ioReadinessInterrputible() { - const ( - evReady = Set(1 << iota) - evInterrupt - ) - errNotReady := fmt.Errorf("not ready for I/O") - - // State of some I/O object. - var ( - br Broadcaster - ready uint32 - ) - doIO := func() error { - if atomic.LoadUint32(&ready) == 0 { - return errNotReady - } - return nil - } - go func() { - // The I/O object eventually becomes ready for I/O. - time.Sleep(100 * time.Millisecond) - // When it does, it first ensures that future calls to isReady() return - // true, then broadcasts the readiness event to Receivers. - atomic.StoreUint32(&ready, 1) - br.Broadcast(evReady) - }() - - // Each user of the I/O object owns a Waiter. - var w Waiter - w.Init() - // The Waiter may be asynchronously interruptible, e.g. for signal - // handling in the sentry. - go func() { - time.Sleep(200 * time.Millisecond) - w.Receiver().Notify(evInterrupt) - }() - - // To use the I/O object: - // - // Optionally, if the I/O object is likely to be ready, attempt I/O first. - err := doIO() - if err == nil { - // Success, we're done. - return /* nil */ - } - if err != errNotReady { - // Failure, I/O failed for some reason other than readiness. - return /* err */ - } - // Subscribe for readiness events from the I/O object. - id := br.SubscribeEvents(w.Receiver(), evReady) - // When we are finished blocking, unsubscribe from readiness events and - // remove readiness events from the pending event set. - defer UnsubscribeAndAck(&br, w.Receiver(), evReady, id) - for { - // Attempt I/O again. This must be done after the call to SubscribeEvents, - // since the I/O object might have become ready between the previous call - // to doIO and the call to SubscribeEvents. - err = doIO() - if err == nil { - return /* nil */ - } - if err != errNotReady { - return /* err */ - } - // Block until either the I/O object indicates it is ready, or we are - // interrupted. - events := w.Wait() - if events&evInterrupt != 0 { - // In the specific case of sentry signal handling, signal delivery - // is handled by another system, so we aren't responsible for - // acknowledging evInterrupt. - return /* errInterrupted */ - } - // Note that, in a concurrent context, the I/O object might become - // ready and then not ready again. To handle this: - // - // - evReady must be acknowledged before calling doIO() again (rather - // than after), so that if the I/O object becomes ready *again* after - // the call to doIO(), the readiness event is not lost. - // - // - We must loop instead of just calling doIO() once after receiving - // evReady. - w.Ack(evReady) - } -} diff --git a/pkg/syncevent/waiter_test.go b/pkg/syncevent/waiter_test.go deleted file mode 100644 index 3c8cbcdd8..000000000 --- a/pkg/syncevent/waiter_test.go +++ /dev/null @@ -1,414 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package syncevent - -import ( - "sync/atomic" - "testing" - "time" - - "gvisor.dev/gvisor/pkg/sleep" - "gvisor.dev/gvisor/pkg/sync" -) - -func TestWaiterAlreadyPending(t *testing.T) { - var w Waiter - w.Init() - want := Set(1) - w.Notify(want) - if got := w.Wait(); got != want { - t.Errorf("Waiter.Wait: got %#x, wanted %#x", got, want) - } -} - -func TestWaiterAsyncNotify(t *testing.T) { - var w Waiter - w.Init() - want := Set(1) - go func() { - time.Sleep(100 * time.Millisecond) - w.Notify(want) - }() - if got := w.Wait(); got != want { - t.Errorf("Waiter.Wait: got %#x, wanted %#x", got, want) - } -} - -func TestWaiterWaitFor(t *testing.T) { - var w Waiter - w.Init() - evWaited := Set(1) - evOther := Set(2) - w.Notify(evOther) - notifiedEvent := uint32(0) - go func() { - time.Sleep(100 * time.Millisecond) - atomic.StoreUint32(¬ifiedEvent, 1) - w.Notify(evWaited) - }() - if got, want := w.WaitFor(evWaited), evWaited|evOther; got != want { - t.Errorf("Waiter.WaitFor: got %#x, wanted %#x", got, want) - } - if atomic.LoadUint32(¬ifiedEvent) == 0 { - t.Errorf("Waiter.WaitFor returned before goroutine notified waited-for event") - } -} - -func TestWaiterWaitAndAckAll(t *testing.T) { - var w Waiter - w.Init() - w.Notify(AllEvents) - if got := w.WaitAndAckAll(); got != AllEvents { - t.Errorf("Waiter.WaitAndAckAll: got %#x, wanted %#x", got, AllEvents) - } - if got := w.Pending(); got != NoEvents { - t.Errorf("Waiter.WaitAndAckAll did not ack all events: got %#x, wanted 0", got) - } -} - -// BenchmarkWaiterX, BenchmarkSleeperX, and BenchmarkChannelX benchmark usage -// pattern X (described in terms of Waiter) with Waiter, sleep.Sleeper, and -// buffered chan struct{} respectively. When the maximum number of event -// sources is relevant, we use 3 event sources because this is representative -// of the kernel.Task.block() use case: an interrupt source, a timeout source, -// and the actual event source being waited on. - -// Event set used by most benchmarks. -const evBench Set = 1 - -// BenchmarkXxxNotifyRedundant measures how long it takes to notify a Waiter of -// an event that is already pending. - -func BenchmarkWaiterNotifyRedundant(b *testing.B) { - var w Waiter - w.Init() - w.Notify(evBench) - - b.ResetTimer() - for i := 0; i < b.N; i++ { - w.Notify(evBench) - } -} - -func BenchmarkSleeperNotifyRedundant(b *testing.B) { - var s sleep.Sleeper - var w sleep.Waker - s.AddWaker(&w, 0) - w.Assert() - - b.ResetTimer() - for i := 0; i < b.N; i++ { - w.Assert() - } -} - -func BenchmarkChannelNotifyRedundant(b *testing.B) { - ch := make(chan struct{}, 1) - ch <- struct{}{} - - b.ResetTimer() - for i := 0; i < b.N; i++ { - select { - case ch <- struct{}{}: - default: - } - } -} - -// BenchmarkXxxNotifyWaitAck measures how long it takes to notify a Waiter an -// event, return that event using a blocking check, and then unset the event as -// pending. - -func BenchmarkWaiterNotifyWaitAck(b *testing.B) { - var w Waiter - w.Init() - - b.ResetTimer() - for i := 0; i < b.N; i++ { - w.Notify(evBench) - w.Wait() - w.Ack(evBench) - } -} - -func BenchmarkSleeperNotifyWaitAck(b *testing.B) { - var s sleep.Sleeper - var w sleep.Waker - s.AddWaker(&w, 0) - - b.ResetTimer() - for i := 0; i < b.N; i++ { - w.Assert() - s.Fetch(true) - } -} - -func BenchmarkChannelNotifyWaitAck(b *testing.B) { - ch := make(chan struct{}, 1) - - b.ResetTimer() - for i := 0; i < b.N; i++ { - // notify - select { - case ch <- struct{}{}: - default: - } - - // wait + ack - <-ch - } -} - -// BenchmarkSleeperMultiNotifyWaitAck is equivalent to -// BenchmarkSleeperNotifyWaitAck, but also includes allocation of a -// temporary sleep.Waker. This is necessary when multiple goroutines may wait -// for the same event, since each sleep.Waker can wake only a single -// sleep.Sleeper. -// -// The syncevent package does not require a distinct object for each -// waiter-waker relationship, so BenchmarkWaiterNotifyWaitAck and -// BenchmarkWaiterMultiNotifyWaitAck would be identical. The analogous state -// for channels, runtime.sudog, is inescapably runtime-allocated, so -// BenchmarkChannelNotifyWaitAck and BenchmarkChannelMultiNotifyWaitAck would -// also be identical. - -func BenchmarkSleeperMultiNotifyWaitAck(b *testing.B) { - var s sleep.Sleeper - // The sleep package doesn't provide sync.Pool allocation of Wakers; - // we do for a fairer comparison. - wakerPool := sync.Pool{ - New: func() interface{} { - return &sleep.Waker{} - }, - } - - b.ResetTimer() - for i := 0; i < b.N; i++ { - w := wakerPool.Get().(*sleep.Waker) - s.AddWaker(w, 0) - w.Assert() - s.Fetch(true) - s.Done() - wakerPool.Put(w) - } -} - -// BenchmarkXxxTempNotifyWaitAck is equivalent to NotifyWaitAck, but also -// includes allocation of a temporary Waiter. This models the case where a -// goroutine not already associated with a Waiter needs one in order to block. -// -// The analogous state for channels is built into runtime.g, so -// BenchmarkChannelNotifyWaitAck and BenchmarkChannelTempNotifyWaitAck would be -// identical. - -func BenchmarkWaiterTempNotifyWaitAck(b *testing.B) { - b.ResetTimer() - for i := 0; i < b.N; i++ { - w := GetWaiter() - w.Notify(evBench) - w.Wait() - w.Ack(evBench) - PutWaiter(w) - } -} - -func BenchmarkSleeperTempNotifyWaitAck(b *testing.B) { - // The sleep package doesn't provide sync.Pool allocation of Sleepers; - // we do for a fairer comparison. - sleeperPool := sync.Pool{ - New: func() interface{} { - return &sleep.Sleeper{} - }, - } - var w sleep.Waker - - b.ResetTimer() - for i := 0; i < b.N; i++ { - s := sleeperPool.Get().(*sleep.Sleeper) - s.AddWaker(&w, 0) - w.Assert() - s.Fetch(true) - s.Done() - sleeperPool.Put(s) - } -} - -// BenchmarkXxxNotifyWaitMultiAck is equivalent to NotifyWaitAck, but allows -// for multiple event sources. - -func BenchmarkWaiterNotifyWaitMultiAck(b *testing.B) { - var w Waiter - w.Init() - - b.ResetTimer() - for i := 0; i < b.N; i++ { - w.Notify(evBench) - if e := w.Wait(); e != evBench { - b.Fatalf("Wait: got %#x, wanted %#x", e, evBench) - } - w.Ack(evBench) - } -} - -func BenchmarkSleeperNotifyWaitMultiAck(b *testing.B) { - var s sleep.Sleeper - var ws [3]sleep.Waker - for i := range ws { - s.AddWaker(&ws[i], i) - } - - b.ResetTimer() - for i := 0; i < b.N; i++ { - ws[0].Assert() - if id, _ := s.Fetch(true); id != 0 { - b.Fatalf("Fetch: got %d, wanted 0", id) - } - } -} - -func BenchmarkChannelNotifyWaitMultiAck(b *testing.B) { - ch0 := make(chan struct{}, 1) - ch1 := make(chan struct{}, 1) - ch2 := make(chan struct{}, 1) - - b.ResetTimer() - for i := 0; i < b.N; i++ { - // notify - select { - case ch0 <- struct{}{}: - default: - } - - // wait + clear - select { - case <-ch0: - // ok - case <-ch1: - b.Fatalf("received from ch1") - case <-ch2: - b.Fatalf("received from ch2") - } - } -} - -// BenchmarkXxxNotifyAsyncWaitAck measures how long it takes to wait for an -// event while another goroutine signals the event. This assumes that a new -// goroutine doesn't run immediately (i.e. the creator of a new goroutine is -// allowed to go to sleep before the new goroutine has a chance to run). - -func BenchmarkWaiterNotifyAsyncWaitAck(b *testing.B) { - var w Waiter - w.Init() - - b.ResetTimer() - for i := 0; i < b.N; i++ { - go func() { - w.Notify(1) - }() - w.Wait() - w.Ack(evBench) - } -} - -func BenchmarkSleeperNotifyAsyncWaitAck(b *testing.B) { - var s sleep.Sleeper - var w sleep.Waker - s.AddWaker(&w, 0) - - b.ResetTimer() - for i := 0; i < b.N; i++ { - go func() { - w.Assert() - }() - s.Fetch(true) - } -} - -func BenchmarkChannelNotifyAsyncWaitAck(b *testing.B) { - ch := make(chan struct{}, 1) - - b.ResetTimer() - for i := 0; i < b.N; i++ { - go func() { - select { - case ch <- struct{}{}: - default: - } - }() - <-ch - } -} - -// BenchmarkXxxNotifyAsyncWaitMultiAck is equivalent to NotifyAsyncWaitAck, but -// allows for multiple event sources. - -func BenchmarkWaiterNotifyAsyncWaitMultiAck(b *testing.B) { - var w Waiter - w.Init() - - b.ResetTimer() - for i := 0; i < b.N; i++ { - go func() { - w.Notify(evBench) - }() - if e := w.Wait(); e != evBench { - b.Fatalf("Wait: got %#x, wanted %#x", e, evBench) - } - w.Ack(evBench) - } -} - -func BenchmarkSleeperNotifyAsyncWaitMultiAck(b *testing.B) { - var s sleep.Sleeper - var ws [3]sleep.Waker - for i := range ws { - s.AddWaker(&ws[i], i) - } - - b.ResetTimer() - for i := 0; i < b.N; i++ { - go func() { - ws[0].Assert() - }() - if id, _ := s.Fetch(true); id != 0 { - b.Fatalf("Fetch: got %d, expected 0", id) - } - } -} - -func BenchmarkChannelNotifyAsyncWaitMultiAck(b *testing.B) { - ch0 := make(chan struct{}, 1) - ch1 := make(chan struct{}, 1) - ch2 := make(chan struct{}, 1) - - b.ResetTimer() - for i := 0; i < b.N; i++ { - go func() { - select { - case ch0 <- struct{}{}: - default: - } - }() - - select { - case <-ch0: - // ok - case <-ch1: - b.Fatalf("received from ch1") - case <-ch2: - b.Fatalf("received from ch2") - } - } -} diff --git a/pkg/syncevent/waiter_unsafe.go b/pkg/syncevent/waiter_unsafe.go deleted file mode 100644 index b6ed2852d..000000000 --- a/pkg/syncevent/waiter_unsafe.go +++ /dev/null @@ -1,197 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package syncevent - -import ( - "sync/atomic" - "unsafe" - - "gvisor.dev/gvisor/pkg/sync" -) - -// Waiter allows a goroutine to block on pending events received by a Receiver. -// -// Waiter.Init() must be called before first use. -type Waiter struct { - r Receiver - - // g is one of: - // - // - 0: No goroutine is blocking in Wait. - // - // - preparingG: A goroutine is in Wait preparing to sleep, but hasn't yet - // completed waiterUnlock(). Thus the wait can only be interrupted by - // replacing the value of g with 0 (the G may not be in state Gwaiting yet, - // so we can't call goready.) - // - // - Otherwise: g is a pointer to the runtime.g in state Gwaiting for the - // goroutine blocked in Wait, which can only be woken by calling goready. - g uintptr `state:"zerovalue"` -} - -const preparingG = 1 - -// Init must be called before first use of w. -func (w *Waiter) Init() { - w.r.Init(w) -} - -// Receiver returns the Receiver that receives events that unblock calls to -// w.Wait(). -func (w *Waiter) Receiver() *Receiver { - return &w.r -} - -// Pending returns the set of pending events. -func (w *Waiter) Pending() Set { - return w.r.Pending() -} - -// Wait blocks until at least one event is pending, then returns the set of -// pending events. It does not affect the set of pending events; callers must -// call w.Ack() to do so, or use w.WaitAndAck() instead. -// -// Precondition: Only one goroutine may call any Wait* method at a time. -func (w *Waiter) Wait() Set { - return w.WaitFor(AllEvents) -} - -// WaitFor blocks until at least one event in es is pending, then returns the -// set of pending events (including those not in es). It does not affect the -// set of pending events; callers must call w.Ack() to do so. -// -// Precondition: Only one goroutine may call any Wait* method at a time. -func (w *Waiter) WaitFor(es Set) Set { - for { - // Optimization: Skip the atomic store to w.g if an event is already - // pending. - if p := w.r.Pending(); p&es != NoEvents { - return p - } - - // Indicate that we're preparing to go to sleep. - atomic.StoreUintptr(&w.g, preparingG) - - // If an event is pending, abort the sleep. - if p := w.r.Pending(); p&es != NoEvents { - atomic.StoreUintptr(&w.g, 0) - return p - } - - // If w.g is still preparingG (i.e. w.NotifyPending() has not been - // called or has not reached atomic.SwapUintptr()), go to sleep until - // w.NotifyPending() => goready(). - sync.Gopark(waiterCommit, unsafe.Pointer(&w.g), sync.WaitReasonSelect, sync.TraceEvGoBlockSelect, 0) - } -} - -//go:norace -//go:nosplit -func waiterCommit(g uintptr, wg unsafe.Pointer) bool { - // The only way this CAS can fail is if a call to Waiter.NotifyPending() - // has replaced *wg with nil, in which case we should not sleep. - return sync.RaceUncheckedAtomicCompareAndSwapUintptr((*uintptr)(wg), preparingG, g) -} - -// Ack marks the given events as not pending. -func (w *Waiter) Ack(es Set) { - w.r.Ack(es) -} - -// WaitAndAckAll blocks until at least one event is pending, then marks all -// events as not pending and returns the set of previously-pending events. -// -// Precondition: Only one goroutine may call any Wait* method at a time. -func (w *Waiter) WaitAndAckAll() Set { - // Optimization: Skip the atomic store to w.g if an event is already - // pending. Call Pending() first since, in the common case that events are - // not yet pending, this skips an atomic swap on w.r.pending. - if w.r.Pending() != NoEvents { - if p := w.r.PendingAndAckAll(); p != NoEvents { - return p - } - } - - for { - // Indicate that we're preparing to go to sleep. - atomic.StoreUintptr(&w.g, preparingG) - - // If an event is pending, abort the sleep. - if w.r.Pending() != NoEvents { - if p := w.r.PendingAndAckAll(); p != NoEvents { - atomic.StoreUintptr(&w.g, 0) - return p - } - } - - // If w.g is still preparingG (i.e. w.NotifyPending() has not been - // called or has not reached atomic.SwapUintptr()), go to sleep until - // w.NotifyPending() => goready(). - sync.Gopark(waiterCommit, unsafe.Pointer(&w.g), sync.WaitReasonSelect, sync.TraceEvGoBlockSelect, 0) - - // Check for pending events. We call PendingAndAckAll() directly now since - // we only expect to be woken after events become pending. - if p := w.r.PendingAndAckAll(); p != NoEvents { - return p - } - } -} - -// Notify marks the given events as pending, possibly unblocking concurrent -// calls to w.Wait() or w.WaitFor(). -func (w *Waiter) Notify(es Set) { - w.r.Notify(es) -} - -// NotifyPending implements ReceiverCallback.NotifyPending. Users of Waiter -// should not call NotifyPending. -func (w *Waiter) NotifyPending() { - // Optimization: Skip the atomic swap on w.g if there is no sleeping - // goroutine. NotifyPending is called after w.r.Pending() is updated, so - // concurrent and future calls to w.Wait() will observe pending events and - // abort sleeping. - if atomic.LoadUintptr(&w.g) == 0 { - return - } - // Wake a sleeping G, or prevent a G that is preparing to sleep from doing - // so. Swap is needed here to ensure that only one call to NotifyPending - // calls goready. - if g := atomic.SwapUintptr(&w.g, 0); g > preparingG { - sync.Goready(g, 0) - } -} - -var waiterPool = sync.Pool{ - New: func() interface{} { - w := &Waiter{} - w.Init() - return w - }, -} - -// GetWaiter returns an unused Waiter. PutWaiter should be called to release -// the Waiter once it is no longer needed. -// -// Where possible, users should prefer to associate each goroutine that calls -// Waiter.Wait() with a distinct pre-allocated Waiter to avoid allocation of -// Waiters in hot paths. -func GetWaiter() *Waiter { - return waiterPool.Get().(*Waiter) -} - -// PutWaiter releases an unused Waiter previously returned by GetWaiter. -func PutWaiter(w *Waiter) { - waiterPool.Put(w) -} diff --git a/pkg/syserr/BUILD b/pkg/syserr/BUILD deleted file mode 100644 index 1cd5d641d..000000000 --- a/pkg/syserr/BUILD +++ /dev/null @@ -1,20 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "syserr", - srcs = [ - "host_linux.go", - "netstack.go", - "syserr.go", - ], - visibility = ["//visibility:public"], - deps = [ - "//pkg/abi/linux/errno", - "//pkg/errors", - "//pkg/errors/linuxerr", - "//pkg/tcpip", - "@org_golang_x_sys//unix:go_default_library", - ], -) diff --git a/pkg/syserr/syserr_linux_state_autogen.go b/pkg/syserr/syserr_linux_state_autogen.go new file mode 100644 index 000000000..90e796197 --- /dev/null +++ b/pkg/syserr/syserr_linux_state_autogen.go @@ -0,0 +1,6 @@ +// automatically generated by stateify. + +//go:build linux +// +build linux + +package syserr diff --git a/pkg/syserr/syserr_state_autogen.go b/pkg/syserr/syserr_state_autogen.go new file mode 100644 index 000000000..712631a64 --- /dev/null +++ b/pkg/syserr/syserr_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package syserr diff --git a/pkg/tcpip/BUILD b/pkg/tcpip/BUILD deleted file mode 100644 index f00cfd0f5..000000000 --- a/pkg/tcpip/BUILD +++ /dev/null @@ -1,99 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") -load("//tools:deps.bzl", "deps_test") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -package(licenses = ["notice"]) - -go_template_instance( - name = "sock_err_list", - out = "sock_err_list.go", - package = "tcpip", - prefix = "sockError", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*SockError", - "Linker": "*SockError", - }, -) - -go_library( - name = "tcpip", - srcs = [ - "errors.go", - "sock_err_list.go", - "socketops.go", - "stdclock.go", - "stdclock_state.go", - "tcpip.go", - "timer.go", - ], - visibility = ["//visibility:public"], - deps = [ - "//pkg/atomicbitops", - "//pkg/sync", - "//pkg/tcpip/buffer", - "//pkg/waiter", - ], -) - -deps_test( - name = "netstack_deps_test", - allowed = [ - # gVisor deps. - "//pkg/atomicbitops", - "//pkg/buffer", - "//pkg/context", - "//pkg/gohacks", - "//pkg/goid", - "//pkg/ilist", - "//pkg/linewriter", - "//pkg/log", - "//pkg/rand", - "//pkg/sleep", - "//pkg/state", - "//pkg/state/wire", - "//pkg/sync", - "//pkg/waiter", - - # Other deps. - "@com_github_google_btree//:go_default_library", - "@org_golang_x_sys//unix:go_default_library", - "@org_golang_x_time//rate:go_default_library", - ], - allowed_prefixes = [ - "//pkg/tcpip", - "@org_golang_x_sys//internal/unsafeheader", - ], - targets = [ - "//pkg/tcpip", - "//pkg/tcpip/header", - "//pkg/tcpip/link/fdbased", - "//pkg/tcpip/link/loopback", - "//pkg/tcpip/link/packetsocket", - "//pkg/tcpip/link/qdisc/fifo", - "//pkg/tcpip/link/sniffer", - "//pkg/tcpip/network/arp", - "//pkg/tcpip/network/ipv4", - "//pkg/tcpip/network/ipv6", - "//pkg/tcpip/stack", - "//pkg/tcpip/transport/icmp", - "//pkg/tcpip/transport/raw", - "//pkg/tcpip/transport/tcp", - "//pkg/tcpip/transport/udp", - ], -) - -go_test( - name = "tcpip_test", - size = "small", - srcs = ["tcpip_test.go"], - library = ":tcpip", - deps = ["@com_github_google_go_cmp//cmp:go_default_library"], -) - -go_test( - name = "tcpip_x_test", - size = "small", - srcs = ["timer_test.go"], - deps = [":tcpip"], -) diff --git a/pkg/tcpip/adapters/gonet/BUILD b/pkg/tcpip/adapters/gonet/BUILD deleted file mode 100644 index a984f1712..000000000 --- a/pkg/tcpip/adapters/gonet/BUILD +++ /dev/null @@ -1,37 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "gonet", - srcs = ["gonet.go"], - visibility = ["//visibility:public"], - deps = [ - "//pkg/sync", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/stack", - "//pkg/tcpip/transport/tcp", - "//pkg/tcpip/transport/udp", - "//pkg/waiter", - ], -) - -go_test( - name = "gonet_test", - size = "small", - srcs = ["gonet_test.go"], - library = ":gonet", - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/header", - "//pkg/tcpip/link/loopback", - "//pkg/tcpip/network/ipv4", - "//pkg/tcpip/network/ipv6", - "//pkg/tcpip/stack", - "//pkg/tcpip/transport/tcp", - "//pkg/tcpip/transport/udp", - "//pkg/waiter", - "@org_golang_x_net//nettest:go_default_library", - ], -) diff --git a/pkg/tcpip/adapters/gonet/gonet_state_autogen.go b/pkg/tcpip/adapters/gonet/gonet_state_autogen.go new file mode 100644 index 000000000..7a5c5419e --- /dev/null +++ b/pkg/tcpip/adapters/gonet/gonet_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package gonet diff --git a/pkg/tcpip/adapters/gonet/gonet_test.go b/pkg/tcpip/adapters/gonet/gonet_test.go deleted file mode 100644 index 48b24692b..000000000 --- a/pkg/tcpip/adapters/gonet/gonet_test.go +++ /dev/null @@ -1,725 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package gonet - -import ( - "context" - "fmt" - "io" - "net" - "reflect" - "strings" - "testing" - "time" - - "golang.org/x/net/nettest" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/loopback" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" - "gvisor.dev/gvisor/pkg/tcpip/transport/udp" - "gvisor.dev/gvisor/pkg/waiter" -) - -const ( - NICID = 1 -) - -func TestTimeouts(t *testing.T) { - nc := NewTCPConn(nil, nil) - dlfs := []struct { - name string - f func(time.Time) error - }{ - {"SetDeadline", nc.SetDeadline}, - {"SetReadDeadline", nc.SetReadDeadline}, - {"SetWriteDeadline", nc.SetWriteDeadline}, - } - - for _, dlf := range dlfs { - if err := dlf.f(time.Time{}); err != nil { - t.Errorf("got %s(time.Time{}) = %v, want = %v", dlf.name, err, nil) - } - } -} - -func newLoopbackStack() (*stack.Stack, tcpip.Error) { - // Create the stack and add a NIC. - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol}, - }) - - if err := s.CreateNIC(NICID, loopback.New()); err != nil { - return nil, err - } - - // Add default route. - s.SetRouteTable([]tcpip.Route{ - // IPv4 - { - Destination: header.IPv4EmptySubnet, - NIC: NICID, - }, - - // IPv6 - { - Destination: header.IPv6EmptySubnet, - NIC: NICID, - }, - }) - - return s, nil -} - -type testConnection struct { - wq *waiter.Queue - e *waiter.Entry - ch chan struct{} - ep tcpip.Endpoint -} - -func connect(s *stack.Stack, addr tcpip.FullAddress) (*testConnection, tcpip.Error) { - wq := &waiter.Queue{} - ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) - if err != nil { - return nil, err - } - - entry, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&entry, waiter.WritableEvents) - - err = ep.Connect(addr) - if _, ok := err.(*tcpip.ErrConnectStarted); ok { - <-ch - err = ep.LastError() - } - if err != nil { - return nil, err - } - - wq.EventUnregister(&entry) - wq.EventRegister(&entry, waiter.ReadableEvents) - - return &testConnection{wq, &entry, ch, ep}, nil -} - -func (c *testConnection) close() { - c.wq.EventUnregister(c.e) - c.ep.Close() -} - -// TestCloseReader tests that Conn.Close() causes Conn.Read() to unblock. -func TestCloseReader(t *testing.T) { - s, err := newLoopbackStack() - if err != nil { - t.Fatalf("newLoopbackStack() = %v", err) - } - defer func() { - s.Close() - s.Wait() - }() - - addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} - - s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) - - l, e := ListenTCP(s, addr, ipv4.ProtocolNumber) - if e != nil { - t.Fatalf("NewListener() = %v", e) - } - done := make(chan struct{}) - go func() { - defer close(done) - c, err := l.Accept() - if err != nil { - t.Errorf("l.Accept() = %v", err) - // Cannot call Fatalf in goroutine. Just return from the goroutine. - return - } - - // Give c.Read() a chance to block before closing the connection. - time.AfterFunc(time.Millisecond*50, func() { - c.Close() - }) - - buf := make([]byte, 256) - n, err := c.Read(buf) - if n != 0 || err != io.EOF { - t.Errorf("c.Read() = (%d, %v), want (0, EOF)", n, err) - } - }() - sender, err := connect(s, addr) - if err != nil { - t.Fatalf("connect() = %v", err) - } - - select { - case <-done: - case <-time.After(5 * time.Second): - t.Errorf("c.Read() didn't unblock") - } - sender.close() -} - -// TestCloseReaderWithForwarder tests that TCPConn.Close wakes TCPConn.Read when -// using tcp.Forwarder. -func TestCloseReaderWithForwarder(t *testing.T) { - s, err := newLoopbackStack() - if err != nil { - t.Fatalf("newLoopbackStack() = %v", err) - } - defer func() { - s.Close() - s.Wait() - }() - - addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} - s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) - - done := make(chan struct{}) - - fwd := tcp.NewForwarder(s, 30000, 10, func(r *tcp.ForwarderRequest) { - defer close(done) - - var wq waiter.Queue - ep, err := r.CreateEndpoint(&wq) - if err != nil { - t.Fatalf("r.CreateEndpoint() = %v", err) - } - defer ep.Close() - r.Complete(false) - - c := NewTCPConn(&wq, ep) - - // Give c.Read() a chance to block before closing the connection. - time.AfterFunc(time.Millisecond*50, func() { - c.Close() - }) - - buf := make([]byte, 256) - n, e := c.Read(buf) - if n != 0 || e != io.EOF { - t.Errorf("c.Read() = (%d, %v), want (0, EOF)", n, e) - } - }) - s.SetTransportProtocolHandler(tcp.ProtocolNumber, fwd.HandlePacket) - - sender, err := connect(s, addr) - if err != nil { - t.Fatalf("connect() = %v", err) - } - - select { - case <-done: - case <-time.After(5 * time.Second): - t.Errorf("c.Read() didn't unblock") - } - sender.close() -} - -func TestCloseRead(t *testing.T) { - s, terr := newLoopbackStack() - if terr != nil { - t.Fatalf("newLoopbackStack() = %v", terr) - } - defer func() { - s.Close() - s.Wait() - }() - - addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} - s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) - - fwd := tcp.NewForwarder(s, 30000, 10, func(r *tcp.ForwarderRequest) { - var wq waiter.Queue - _, err := r.CreateEndpoint(&wq) - if err != nil { - t.Fatalf("r.CreateEndpoint() = %v", err) - } - // Endpoint will be closed in deferred s.Close (above). - }) - - s.SetTransportProtocolHandler(tcp.ProtocolNumber, fwd.HandlePacket) - - tc, terr := connect(s, addr) - if terr != nil { - t.Fatalf("connect() = %v", terr) - } - c := NewTCPConn(tc.wq, tc.ep) - - if err := c.CloseRead(); err != nil { - t.Errorf("c.CloseRead() = %v", err) - } - - buf := make([]byte, 256) - if n, err := c.Read(buf); err != io.EOF { - t.Errorf("c.Read() = (%d, %v), want (0, io.EOF)", n, err) - } - - if n, err := c.Write([]byte("abc123")); n != 6 || err != nil { - t.Errorf("c.Write() = (%d, %v), want (6, nil)", n, err) - } -} - -func TestCloseWrite(t *testing.T) { - s, terr := newLoopbackStack() - if terr != nil { - t.Fatalf("newLoopbackStack() = %v", terr) - } - defer func() { - s.Close() - s.Wait() - }() - - addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} - s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) - - fwd := tcp.NewForwarder(s, 30000, 10, func(r *tcp.ForwarderRequest) { - var wq waiter.Queue - ep, err := r.CreateEndpoint(&wq) - if err != nil { - t.Fatalf("r.CreateEndpoint() = %v", err) - } - defer ep.Close() - r.Complete(false) - - c := NewTCPConn(&wq, ep) - - n, e := c.Read(make([]byte, 256)) - if n != 0 || e != io.EOF { - t.Errorf("c.Read() = (%d, %v), want (0, io.EOF)", n, e) - } - - if n, e = c.Write([]byte("abc123")); n != 6 || e != nil { - t.Errorf("c.Write() = (%d, %v), want (6, nil)", n, e) - } - }) - - s.SetTransportProtocolHandler(tcp.ProtocolNumber, fwd.HandlePacket) - - tc, terr := connect(s, addr) - if terr != nil { - t.Fatalf("connect() = %v", terr) - } - c := NewTCPConn(tc.wq, tc.ep) - - if err := c.CloseWrite(); err != nil { - t.Errorf("c.CloseWrite() = %v", err) - } - - buf := make([]byte, 256) - n, err := c.Read(buf) - if err != nil || string(buf[:n]) != "abc123" { - t.Fatalf("c.Read() = (%d, %v), want (6, nil)", n, err) - } - - n, err = c.Write([]byte("abc123")) - got, ok := err.(*net.OpError) - want := "endpoint is closed for send" - if n != 0 || !ok || got.Op != "write" || got.Err == nil || !strings.HasSuffix(got.Err.Error(), want) { - t.Errorf("c.Write() = (%d, %v), want (0, OpError(Op: write, Err: %s))", n, err, want) - } -} - -func TestUDPForwarder(t *testing.T) { - s, terr := newLoopbackStack() - if terr != nil { - t.Fatalf("newLoopbackStack() = %v", terr) - } - defer func() { - s.Close() - s.Wait() - }() - - ip1 := tcpip.Address(net.IPv4(169, 254, 10, 1).To4()) - addr1 := tcpip.FullAddress{NICID, ip1, 11211} - s.AddAddress(NICID, ipv4.ProtocolNumber, ip1) - ip2 := tcpip.Address(net.IPv4(169, 254, 10, 2).To4()) - addr2 := tcpip.FullAddress{NICID, ip2, 11311} - s.AddAddress(NICID, ipv4.ProtocolNumber, ip2) - - done := make(chan struct{}) - fwd := udp.NewForwarder(s, func(r *udp.ForwarderRequest) { - defer close(done) - - var wq waiter.Queue - ep, err := r.CreateEndpoint(&wq) - if err != nil { - t.Fatalf("r.CreateEndpoint() = %v", err) - } - defer ep.Close() - - c := NewTCPConn(&wq, ep) - - buf := make([]byte, 256) - n, e := c.Read(buf) - if e != nil { - t.Errorf("c.Read() = %v", e) - } - - if _, e := c.Write(buf[:n]); e != nil { - t.Errorf("c.Write() = %v", e) - } - }) - s.SetTransportProtocolHandler(udp.ProtocolNumber, fwd.HandlePacket) - - c2, err := DialUDP(s, &addr2, nil, ipv4.ProtocolNumber) - if err != nil { - t.Fatal("DialUDP(bind port 5):", err) - } - - sent := "abc123" - sendAddr := fullToUDPAddr(addr1) - if n, err := c2.WriteTo([]byte(sent), sendAddr); err != nil || n != len(sent) { - t.Errorf("c1.WriteTo(%q, %v) = %d, %v, want = %d, %v", sent, sendAddr, n, err, len(sent), nil) - } - - buf := make([]byte, 256) - n, recvAddr, err := c2.ReadFrom(buf) - if err != nil || recvAddr.String() != sendAddr.String() { - t.Errorf("c1.ReadFrom() = %d, %v, %v, want = %d, %v, %v", n, recvAddr, err, len(sent), sendAddr, nil) - } -} - -// TestDeadlineChange tests that changing the deadline affects currently blocked reads. -func TestDeadlineChange(t *testing.T) { - s, err := newLoopbackStack() - if err != nil { - t.Fatalf("newLoopbackStack() = %v", err) - } - defer func() { - s.Close() - s.Wait() - }() - - addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} - - s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) - - l, e := ListenTCP(s, addr, ipv4.ProtocolNumber) - if e != nil { - t.Fatalf("NewListener() = %v", e) - } - done := make(chan struct{}) - go func() { - defer close(done) - c, err := l.Accept() - if err != nil { - t.Errorf("l.Accept() = %v", err) - // Cannot call Fatalf in goroutine. Just return from the goroutine. - return - } - - c.SetDeadline(time.Now().Add(time.Minute)) - // Give c.Read() a chance to block before closing the connection. - time.AfterFunc(time.Millisecond*50, func() { - c.SetDeadline(time.Now().Add(time.Millisecond * 10)) - }) - - buf := make([]byte, 256) - n, err := c.Read(buf) - got, ok := err.(*net.OpError) - want := "i/o timeout" - if n != 0 || !ok || got.Err == nil || got.Err.Error() != want { - t.Errorf("c.Read() = (%d, %v), want (0, OpError(%s))", n, err, want) - } - }() - sender, err := connect(s, addr) - if err != nil { - t.Fatalf("connect() = %v", err) - } - - select { - case <-done: - case <-time.After(time.Millisecond * 500): - t.Errorf("c.Read() didn't unblock") - } - sender.close() -} - -func TestPacketConnTransfer(t *testing.T) { - s, e := newLoopbackStack() - if e != nil { - t.Fatalf("newLoopbackStack() = %v", e) - } - defer func() { - s.Close() - s.Wait() - }() - - ip1 := tcpip.Address(net.IPv4(169, 254, 10, 1).To4()) - addr1 := tcpip.FullAddress{NICID, ip1, 11211} - s.AddAddress(NICID, ipv4.ProtocolNumber, ip1) - ip2 := tcpip.Address(net.IPv4(169, 254, 10, 2).To4()) - addr2 := tcpip.FullAddress{NICID, ip2, 11311} - s.AddAddress(NICID, ipv4.ProtocolNumber, ip2) - - c1, err := DialUDP(s, &addr1, nil, ipv4.ProtocolNumber) - if err != nil { - t.Fatal("DialUDP(bind port 4):", err) - } - c2, err := DialUDP(s, &addr2, nil, ipv4.ProtocolNumber) - if err != nil { - t.Fatal("DialUDP(bind port 5):", err) - } - - c1.SetDeadline(time.Now().Add(time.Second)) - c2.SetDeadline(time.Now().Add(time.Second)) - - sent := "abc123" - sendAddr := fullToUDPAddr(addr2) - if n, err := c1.WriteTo([]byte(sent), sendAddr); err != nil || n != len(sent) { - t.Errorf("got c1.WriteTo(%q, %v) = %d, %v, want = %d, %v", sent, sendAddr, n, err, len(sent), nil) - } - recv := make([]byte, len(sent)) - n, recvAddr, err := c2.ReadFrom(recv) - if err != nil || n != len(recv) { - t.Errorf("got c2.ReadFrom() = %d, %v, want = %d, %v", n, err, len(recv), nil) - } - - if recv := string(recv); recv != sent { - t.Errorf("got recv = %q, want = %q", recv, sent) - } - - if want := fullToUDPAddr(addr1); !reflect.DeepEqual(recvAddr, want) { - t.Errorf("got recvAddr = %v, want = %v", recvAddr, want) - } - - if err := c1.Close(); err != nil { - t.Error("c1.Close():", err) - } - if err := c2.Close(); err != nil { - t.Error("c2.Close():", err) - } -} - -func TestConnectedPacketConnTransfer(t *testing.T) { - s, e := newLoopbackStack() - if e != nil { - t.Fatalf("newLoopbackStack() = %v", e) - } - defer func() { - s.Close() - s.Wait() - }() - - ip := tcpip.Address(net.IPv4(169, 254, 10, 1).To4()) - addr := tcpip.FullAddress{NICID, ip, 11211} - s.AddAddress(NICID, ipv4.ProtocolNumber, ip) - - c1, err := DialUDP(s, &addr, nil, ipv4.ProtocolNumber) - if err != nil { - t.Fatal("DialUDP(bind port 4):", err) - } - c2, err := DialUDP(s, nil, &addr, ipv4.ProtocolNumber) - if err != nil { - t.Fatal("DialUDP(bind port 5):", err) - } - - c1.SetDeadline(time.Now().Add(time.Second)) - c2.SetDeadline(time.Now().Add(time.Second)) - - sent := "abc123" - if n, err := c2.Write([]byte(sent)); err != nil || n != len(sent) { - t.Errorf("got c2.Write(%q) = %d, %v, want = %d, %v", sent, n, err, len(sent), nil) - } - recv := make([]byte, len(sent)) - n, err := c1.Read(recv) - if err != nil || n != len(recv) { - t.Errorf("got c1.Read() = %d, %v, want = %d, %v", n, err, len(recv), nil) - } - - if recv := string(recv); recv != sent { - t.Errorf("got recv = %q, want = %q", recv, sent) - } - - if err := c1.Close(); err != nil { - t.Error("c1.Close():", err) - } - if err := c2.Close(); err != nil { - t.Error("c2.Close():", err) - } -} - -func makePipe() (c1, c2 net.Conn, stop func(), err error) { - s, e := newLoopbackStack() - if e != nil { - return nil, nil, nil, fmt.Errorf("newLoopbackStack() = %v", e) - } - - ip := tcpip.Address(net.IPv4(169, 254, 10, 1).To4()) - addr := tcpip.FullAddress{NICID, ip, 11211} - s.AddAddress(NICID, ipv4.ProtocolNumber, ip) - - l, err := ListenTCP(s, addr, ipv4.ProtocolNumber) - if err != nil { - return nil, nil, nil, fmt.Errorf("NewListener: %v", err) - } - - c1, err = DialTCP(s, addr, ipv4.ProtocolNumber) - if err != nil { - l.Close() - return nil, nil, nil, fmt.Errorf("DialTCP: %v", err) - } - - c2, err = l.Accept() - if err != nil { - l.Close() - c1.Close() - return nil, nil, nil, fmt.Errorf("l.Accept: %v", err) - } - - stop = func() { - c1.Close() - c2.Close() - s.Close() - s.Wait() - } - - if err := l.Close(); err != nil { - stop() - return nil, nil, nil, fmt.Errorf("l.Close(): %v", err) - } - - return c1, c2, stop, nil -} - -func TestTCPConnTransfer(t *testing.T) { - c1, c2, _, err := makePipe() - if err != nil { - t.Fatal(err) - } - defer func() { - if err := c1.Close(); err != nil { - t.Error("c1.Close():", err) - } - if err := c2.Close(); err != nil { - t.Error("c2.Close():", err) - } - }() - - c1.SetDeadline(time.Now().Add(time.Second)) - c2.SetDeadline(time.Now().Add(time.Second)) - - const sent = "abc123" - - tests := []struct { - name string - c1 net.Conn - c2 net.Conn - }{ - {"connected to accepted", c1, c2}, - {"accepted to connected", c2, c1}, - } - - for _, test := range tests { - if n, err := test.c1.Write([]byte(sent)); err != nil || n != len(sent) { - t.Errorf("%s: got test.c1.Write(%q) = %d, %v, want = %d, %v", test.name, sent, n, err, len(sent), nil) - continue - } - - recv := make([]byte, len(sent)) - n, err := test.c2.Read(recv) - if err != nil || n != len(recv) { - t.Errorf("%s: got test.c2.Read() = %d, %v, want = %d, %v", test.name, n, err, len(recv), nil) - continue - } - - if recv := string(recv); recv != sent { - t.Errorf("%s: got recv = %q, want = %q", test.name, recv, sent) - } - } -} - -func TestTCPDialError(t *testing.T) { - s, e := newLoopbackStack() - if e != nil { - t.Fatalf("newLoopbackStack() = %v", e) - } - defer func() { - s.Close() - s.Wait() - }() - - ip := tcpip.Address(net.IPv4(169, 254, 10, 1).To4()) - addr := tcpip.FullAddress{NICID, ip, 11211} - - switch _, err := DialTCP(s, addr, ipv4.ProtocolNumber); err := err.(type) { - case *net.OpError: - if err.Err.Error() != (&tcpip.ErrNoRoute{}).String() { - t.Errorf("got DialTCP() = %s, want = %s", err, &tcpip.ErrNoRoute{}) - } - default: - t.Errorf("got DialTCP(...) = %v, want %s", err, &tcpip.ErrNoRoute{}) - } -} - -func TestDialContextTCPCanceled(t *testing.T) { - s, err := newLoopbackStack() - if err != nil { - t.Fatalf("newLoopbackStack() = %v", err) - } - defer func() { - s.Close() - s.Wait() - }() - - addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} - s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) - - ctx := context.Background() - ctx, cancel := context.WithCancel(ctx) - cancel() - - if _, err := DialContextTCP(ctx, s, addr, ipv4.ProtocolNumber); err != context.Canceled { - t.Errorf("got DialContextTCP(...) = %v, want = %v", err, context.Canceled) - } -} - -func TestDialContextTCPTimeout(t *testing.T) { - s, err := newLoopbackStack() - if err != nil { - t.Fatalf("newLoopbackStack() = %v", err) - } - defer func() { - s.Close() - s.Wait() - }() - - addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} - s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) - - fwd := tcp.NewForwarder(s, 30000, 10, func(r *tcp.ForwarderRequest) { - time.Sleep(time.Second) - r.Complete(true) - }) - s.SetTransportProtocolHandler(tcp.ProtocolNumber, fwd.HandlePacket) - - ctx := context.Background() - ctx, cancel := context.WithDeadline(ctx, time.Now().Add(100*time.Millisecond)) - defer cancel() - - if _, err := DialContextTCP(ctx, s, addr, ipv4.ProtocolNumber); err != context.DeadlineExceeded { - t.Errorf("got DialContextTCP(...) = %v, want = %v", err, context.DeadlineExceeded) - } -} - -func TestNetTest(t *testing.T) { - nettest.TestConn(t, makePipe) -} diff --git a/pkg/tcpip/buffer/BUILD b/pkg/tcpip/buffer/BUILD deleted file mode 100644 index 23aa0ad05..000000000 --- a/pkg/tcpip/buffer/BUILD +++ /dev/null @@ -1,25 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "buffer", - srcs = [ - "prependable.go", - "view.go", - "view_unsafe.go", - ], - visibility = ["//visibility:public"], -) - -go_test( - name = "buffer_x_test", - size = "small", - srcs = [ - "view_test.go", - ], - deps = [ - ":buffer", - "//pkg/tcpip", - ], -) diff --git a/pkg/tcpip/buffer/buffer_state_autogen.go b/pkg/tcpip/buffer/buffer_state_autogen.go new file mode 100644 index 000000000..51bfbff8a --- /dev/null +++ b/pkg/tcpip/buffer/buffer_state_autogen.go @@ -0,0 +1,39 @@ +// automatically generated by stateify. + +package buffer + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (vv *VectorisedView) StateTypeName() string { + return "pkg/tcpip/buffer.VectorisedView" +} + +func (vv *VectorisedView) StateFields() []string { + return []string{ + "views", + "size", + } +} + +func (vv *VectorisedView) beforeSave() {} + +// +checklocksignore +func (vv *VectorisedView) StateSave(stateSinkObject state.Sink) { + vv.beforeSave() + stateSinkObject.Save(0, &vv.views) + stateSinkObject.Save(1, &vv.size) +} + +func (vv *VectorisedView) afterLoad() {} + +// +checklocksignore +func (vv *VectorisedView) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &vv.views) + stateSourceObject.Load(1, &vv.size) +} + +func init() { + state.Register((*VectorisedView)(nil)) +} diff --git a/pkg/tcpip/buffer/buffer_unsafe_state_autogen.go b/pkg/tcpip/buffer/buffer_unsafe_state_autogen.go new file mode 100644 index 000000000..5a5c40722 --- /dev/null +++ b/pkg/tcpip/buffer/buffer_unsafe_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package buffer diff --git a/pkg/tcpip/buffer/view_test.go b/pkg/tcpip/buffer/view_test.go deleted file mode 100644 index d296d9c2b..000000000 --- a/pkg/tcpip/buffer/view_test.go +++ /dev/null @@ -1,629 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package buffer_test contains tests for the buffer.VectorisedView type. -package buffer_test - -import ( - "bytes" - "io" - "reflect" - "testing" - "unsafe" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" -) - -// copy returns a deep-copy of the vectorised view. -func copyVV(vv buffer.VectorisedView) buffer.VectorisedView { - views := make([]buffer.View, 0, len(vv.Views())) - for _, v := range vv.Views() { - views = append(views, append(buffer.View(nil), v...)) - } - return buffer.NewVectorisedView(vv.Size(), views) -} - -// vv is an helper to build buffer.VectorisedView from different strings. -func vv(size int, pieces ...string) buffer.VectorisedView { - views := make([]buffer.View, len(pieces)) - for i, p := range pieces { - views[i] = []byte(p) - } - - return buffer.NewVectorisedView(size, views) -} - -// v returns a buffer.View containing piece. -func v(piece string) buffer.View { - return buffer.View(piece) -} - -var capLengthTestCases = []struct { - comment string - in buffer.VectorisedView - length int - want buffer.VectorisedView -}{ - { - comment: "Simple case", - in: vv(2, "12"), - length: 1, - want: vv(1, "1"), - }, - { - comment: "Case spanning across two Views", - in: vv(4, "123", "4"), - length: 2, - want: vv(2, "12"), - }, - { - comment: "Corner case with negative length", - in: vv(1, "1"), - length: -1, - want: vv(0), - }, - { - comment: "Corner case with length = 0", - in: vv(3, "12", "3"), - length: 0, - want: vv(0), - }, - { - comment: "Corner case with length = size", - in: vv(1, "1"), - length: 1, - want: vv(1, "1"), - }, - { - comment: "Corner case with length > size", - in: vv(1, "1"), - length: 2, - want: vv(1, "1"), - }, -} - -func TestCapLength(t *testing.T) { - for _, c := range capLengthTestCases { - orig := copyVV(c.in) - c.in.CapLength(c.length) - if !reflect.DeepEqual(c.in, c.want) { - t.Errorf("Test \"%s\" failed when calling CapLength(%d) on %v. Got %v. Want %v", - c.comment, c.length, orig, c.in, c.want) - } - } -} - -var trimFrontTestCases = []struct { - comment string - in buffer.VectorisedView - count int - want buffer.VectorisedView -}{ - { - comment: "Simple case", - in: vv(2, "12"), - count: 1, - want: vv(1, "2"), - }, - { - comment: "Case where we trim an entire View", - in: vv(2, "1", "2"), - count: 1, - want: vv(1, "2"), - }, - { - comment: "Case spanning across two Views", - in: vv(3, "1", "23"), - count: 2, - want: vv(1, "3"), - }, - { - comment: "Case with one empty Views", - in: vv(3, "1", "", "23"), - count: 2, - want: vv(1, "3"), - }, - { - comment: "Corner case with negative count", - in: vv(1, "1"), - count: -1, - want: vv(1, "1"), - }, - { - comment: " Corner case with count = 0", - in: vv(1, "1"), - count: 0, - want: vv(1, "1"), - }, - { - comment: "Corner case with count = size", - in: vv(1, "1"), - count: 1, - want: vv(0), - }, - { - comment: "Corner case with count > size", - in: vv(1, "1"), - count: 2, - want: vv(0), - }, -} - -func TestTrimFront(t *testing.T) { - for _, c := range trimFrontTestCases { - orig := copyVV(c.in) - c.in.TrimFront(c.count) - if !reflect.DeepEqual(c.in, c.want) { - t.Errorf("Test \"%s\" failed when calling TrimFront(%d) on %v. Got %v. Want %v", - c.comment, c.count, orig, c.in, c.want) - } - } -} - -var toViewCases = []struct { - comment string - in buffer.VectorisedView - want buffer.View -}{ - { - comment: "Simple case", - in: vv(2, "12"), - want: []byte("12"), - }, - { - comment: "Case with multiple views", - in: vv(2, "1", "2"), - want: []byte("12"), - }, - { - comment: "Empty case", - in: vv(0), - want: []byte(""), - }, -} - -func TestToView(t *testing.T) { - for _, c := range toViewCases { - got := c.in.ToView() - if !reflect.DeepEqual(got, c.want) { - t.Errorf("Test \"%s\" failed when calling ToView() on %v. Got %v. Want %v", - c.comment, c.in, got, c.want) - } - } -} - -var toCloneCases = []struct { - comment string - inView buffer.VectorisedView - inBuffer []buffer.View -}{ - { - comment: "Simple case", - inView: vv(1, "1"), - inBuffer: make([]buffer.View, 1), - }, - { - comment: "Case with multiple views", - inView: vv(2, "1", "2"), - inBuffer: make([]buffer.View, 2), - }, - { - comment: "Case with buffer too small", - inView: vv(2, "1", "2"), - inBuffer: make([]buffer.View, 1), - }, - { - comment: "Case with buffer larger than needed", - inView: vv(1, "1"), - inBuffer: make([]buffer.View, 2), - }, - { - comment: "Case with nil buffer", - inView: vv(1, "1"), - inBuffer: nil, - }, -} - -func TestToClone(t *testing.T) { - for _, c := range toCloneCases { - t.Run(c.comment, func(t *testing.T) { - got := c.inView.Clone(c.inBuffer) - if !reflect.DeepEqual(got, c.inView) { - t.Fatalf("got (%+v).Clone(%+v) = %+v, want = %+v", - c.inView, c.inBuffer, got, c.inView) - } - }) - } -} - -type readToTestCases struct { - comment string - vv buffer.VectorisedView - bytesToRead int - wantBytes string - leftVV buffer.VectorisedView -} - -func createReadToTestCases() []readToTestCases { - return []readToTestCases{ - { - comment: "large VV, short read", - vv: vv(30, "012345678901234567890123456789"), - bytesToRead: 10, - wantBytes: "0123456789", - leftVV: vv(20, "01234567890123456789"), - }, - { - comment: "largeVV, multiple views, short read", - vv: vv(13, "123", "345", "567", "8910"), - bytesToRead: 6, - wantBytes: "123345", - leftVV: vv(7, "567", "8910"), - }, - { - comment: "smallVV (multiple views), large read", - vv: vv(3, "1", "2", "3"), - bytesToRead: 10, - wantBytes: "123", - leftVV: vv(0, ""), - }, - { - comment: "smallVV (single view), large read", - vv: vv(1, "1"), - bytesToRead: 10, - wantBytes: "1", - leftVV: vv(0, ""), - }, - { - comment: "emptyVV, large read", - vv: vv(0, ""), - bytesToRead: 10, - wantBytes: "", - leftVV: vv(0, ""), - }, - } -} - -func TestVVReadToVV(t *testing.T) { - for _, tc := range createReadToTestCases() { - t.Run(tc.comment, func(t *testing.T) { - var readTo buffer.VectorisedView - inSize := tc.vv.Size() - copied := tc.vv.ReadToVV(&readTo, tc.bytesToRead) - if got, want := copied, len(tc.wantBytes); got != want { - t.Errorf("incorrect number of bytes copied returned in ReadToVV got: %d, want: %d, tc: %+v", got, want, tc) - } - if got, want := string(readTo.ToView()), tc.wantBytes; got != want { - t.Errorf("unexpected content in readTo got: %s, want: %s", got, want) - } - if got, want := tc.vv.Size(), inSize-copied; got != want { - t.Errorf("test VV has incorrect size after reading got: %d, want: %d, tc.vv: %+v", got, want, tc.vv) - } - if got, want := string(tc.vv.ToView()), string(tc.leftVV.ToView()); got != want { - t.Errorf("unexpected data left in vv after read got: %+v, want: %+v", got, want) - } - }) - } -} - -func TestVVReadTo(t *testing.T) { - for _, tc := range createReadToTestCases() { - t.Run(tc.comment, func(t *testing.T) { - b := make([]byte, tc.bytesToRead) - dst := tcpip.SliceWriter(b) - origSize := tc.vv.Size() - copied, err := tc.vv.ReadTo(&dst, false /* peek */) - if err != nil && err != io.ErrShortWrite { - t.Errorf("got ReadTo(&dst, false) = (_, %s); want nil or io.ErrShortWrite", err) - } - if got, want := copied, len(tc.wantBytes); got != want { - t.Errorf("got ReadTo(&dst, false) = (%d, _); want %d", got, want) - } - if got, want := string(b[:copied]), tc.wantBytes; got != want { - t.Errorf("got dst = %q, want %q", got, want) - } - if got, want := tc.vv.Size(), origSize-copied; got != want { - t.Errorf("got after-read tc.vv.Size() = %d, want %d", got, want) - } - if got, want := string(tc.vv.ToView()), string(tc.leftVV.ToView()); got != want { - t.Errorf("got after-read data in tc.vv = %q, want %q", got, want) - } - }) - } -} - -func TestVVReadToPeek(t *testing.T) { - for _, tc := range createReadToTestCases() { - t.Run(tc.comment, func(t *testing.T) { - b := make([]byte, tc.bytesToRead) - dst := tcpip.SliceWriter(b) - origSize := tc.vv.Size() - origData := string(tc.vv.ToView()) - copied, err := tc.vv.ReadTo(&dst, true /* peek */) - if err != nil && err != io.ErrShortWrite { - t.Errorf("got ReadTo(&dst, true) = (_, %s); want nil or io.ErrShortWrite", err) - } - if got, want := copied, len(tc.wantBytes); got != want { - t.Errorf("got ReadTo(&dst, true) = (%d, _); want %d", got, want) - } - if got, want := string(b[:copied]), tc.wantBytes; got != want { - t.Errorf("got dst = %q, want %q", got, want) - } - // Expect tc.vv is unchanged. - if got, want := tc.vv.Size(), origSize; got != want { - t.Errorf("got after-read tc.vv.Size() = %d, want %d", got, want) - } - if got, want := string(tc.vv.ToView()), origData; got != want { - t.Errorf("got after-read data in tc.vv = %q, want %q", got, want) - } - }) - } -} - -func TestVVRead(t *testing.T) { - testCases := []struct { - comment string - vv buffer.VectorisedView - bytesToRead int - readBytes string - leftBytes string - wantError bool - }{ - { - comment: "large VV, short read", - vv: vv(30, "012345678901234567890123456789"), - bytesToRead: 10, - readBytes: "0123456789", - leftBytes: "01234567890123456789", - }, - { - comment: "largeVV, multiple buffers, short read", - vv: vv(13, "123", "345", "567", "8910"), - bytesToRead: 6, - readBytes: "123345", - leftBytes: "5678910", - }, - { - comment: "smallVV, large read", - vv: vv(3, "1", "2", "3"), - bytesToRead: 10, - readBytes: "123", - leftBytes: "", - }, - { - comment: "smallVV, large read", - vv: vv(1, "1"), - bytesToRead: 10, - readBytes: "1", - leftBytes: "", - }, - { - comment: "emptyVV, large read", - vv: vv(0, ""), - bytesToRead: 10, - readBytes: "", - wantError: true, - }, - } - - for _, tc := range testCases { - t.Run(tc.comment, func(t *testing.T) { - readTo := buffer.NewView(tc.bytesToRead) - inSize := tc.vv.Size() - copied, err := tc.vv.Read(readTo) - if !tc.wantError && err != nil { - t.Fatalf("unexpected error in tc.vv.Read(..) = %s", err) - } - readTo = readTo[:copied] - if got, want := copied, len(tc.readBytes); got != want { - t.Errorf("incorrect number of bytes copied returned in ReadToVV got: %d, want: %d, tc.vv: %+v", got, want, tc.vv) - } - if got, want := string(readTo), tc.readBytes; got != want { - t.Errorf("unexpected data in readTo got: %s, want: %s", got, want) - } - if got, want := tc.vv.Size(), inSize-copied; got != want { - t.Errorf("test VV has incorrect size after reading got: %d, want: %d, tc.vv: %+v", got, want, tc.vv) - } - if got, want := string(tc.vv.ToView()), tc.leftBytes; got != want { - t.Errorf("vv has incorrect data after Read got: %s, want: %s", got, want) - } - }) - } -} - -var pullUpTestCases = []struct { - comment string - in buffer.VectorisedView - count int - want []byte - result buffer.VectorisedView - ok bool -}{ - { - comment: "simple case", - in: vv(2, "12"), - count: 1, - want: []byte("1"), - result: vv(2, "12"), - ok: true, - }, - { - comment: "entire View", - in: vv(2, "1", "2"), - count: 1, - want: []byte("1"), - result: vv(2, "1", "2"), - ok: true, - }, - { - comment: "spanning across two Views", - in: vv(3, "1", "23"), - count: 2, - want: []byte("12"), - result: vv(3, "12", "3"), - ok: true, - }, - { - comment: "spanning across all Views", - in: vv(5, "1", "23", "45"), - count: 5, - want: []byte("12345"), - result: vv(5, "12345"), - ok: true, - }, - { - comment: "count = 0", - in: vv(1, "1"), - count: 0, - want: []byte{}, - result: vv(1, "1"), - ok: true, - }, - { - comment: "count = size", - in: vv(1, "1"), - count: 1, - want: []byte("1"), - result: vv(1, "1"), - ok: true, - }, - { - comment: "count too large", - in: vv(3, "1", "23"), - count: 4, - want: nil, - result: vv(3, "1", "23"), - ok: false, - }, - { - comment: "empty vv", - in: vv(0, ""), - count: 1, - want: nil, - result: vv(0, ""), - ok: false, - }, - { - comment: "empty vv, count = 0", - in: vv(0, ""), - count: 0, - want: nil, - result: vv(0, ""), - ok: true, - }, - { - comment: "empty views", - in: vv(3, "", "1", "", "23"), - count: 2, - want: []byte("12"), - result: vv(3, "12", "3"), - ok: true, - }, -} - -func TestPullUp(t *testing.T) { - for _, c := range pullUpTestCases { - got, ok := c.in.PullUp(c.count) - - // Is the return value right? - if ok != c.ok { - t.Errorf("Test %q failed when calling PullUp(%d) on %v. Got an ok of %t. Want %t", - c.comment, c.count, c.in, ok, c.ok) - } - if bytes.Compare(got, buffer.View(c.want)) != 0 { - t.Errorf("Test %q failed when calling PullUp(%d) on %v. Got %v. Want %v", - c.comment, c.count, c.in, got, c.want) - } - - // Is the underlying structure right? - if !reflect.DeepEqual(c.in, c.result) { - t.Errorf("Test %q failed when calling PullUp(%d). Got vv with structure %v. Wanted %v", - c.comment, c.count, c.in, c.result) - } - } -} - -func TestToVectorisedView(t *testing.T) { - testCases := []struct { - in buffer.View - want buffer.VectorisedView - }{ - {nil, buffer.VectorisedView{}}, - {buffer.View{}, buffer.VectorisedView{}}, - {buffer.View{'a'}, buffer.NewVectorisedView(1, []buffer.View{{'a'}})}, - } - for _, tc := range testCases { - if got, want := tc.in.ToVectorisedView(), tc.want; !reflect.DeepEqual(got, want) { - t.Errorf("(%v).ToVectorisedView failed got: %+v, want: %+v", tc.in, got, want) - } - } -} - -func TestAppendView(t *testing.T) { - testCases := []struct { - vv buffer.VectorisedView - in buffer.View - want buffer.VectorisedView - }{ - {vv(0), nil, vv(0)}, - {vv(0), v(""), vv(0)}, - {vv(4, "abcd"), nil, vv(4, "abcd")}, - {vv(4, "abcd"), v(""), vv(4, "abcd")}, - {vv(4, "abcd"), v("e"), vv(5, "abcd", "e")}, - } - for _, tc := range testCases { - tc.vv.AppendView(tc.in) - if got, want := tc.vv, tc.want; !reflect.DeepEqual(got, want) { - t.Errorf("(%v).ToVectorisedView failed got: %+v, want: %+v", tc.in, got, want) - } - } -} - -func TestAppendViews(t *testing.T) { - testCases := []struct { - vv buffer.VectorisedView - in []buffer.View - want buffer.VectorisedView - }{ - {vv(0), nil, vv(0)}, - {vv(0), []buffer.View{}, vv(0)}, - {vv(0), []buffer.View{v("")}, vv(0, "")}, - {vv(4, "abcd"), nil, vv(4, "abcd")}, - {vv(4, "abcd"), []buffer.View{}, vv(4, "abcd")}, - {vv(4, "abcd"), []buffer.View{v("")}, vv(4, "abcd", "")}, - {vv(4, "abcd"), []buffer.View{v("")}, vv(4, "abcd", "")}, - {vv(4, "abcd"), []buffer.View{v("e")}, vv(5, "abcd", "e")}, - {vv(4, "abcd"), []buffer.View{v("e"), v("fg")}, vv(7, "abcd", "e", "fg")}, - {vv(4, "abcd"), []buffer.View{v(""), v("fg")}, vv(6, "abcd", "", "fg")}, - } - for _, tc := range testCases { - tc.vv.AppendViews(tc.in) - if got, want := tc.vv, tc.want; !reflect.DeepEqual(got, want) { - t.Errorf("(%v).ToVectorisedView failed got: %+v, want: %+v", tc.in, got, want) - } - } -} - -func TestMemSize(t *testing.T) { - const perViewCap = 128 - views := make([]buffer.View, 2, 32) - views[0] = make(buffer.View, 10, perViewCap) - views[1] = make(buffer.View, 20, perViewCap) - vv := buffer.NewVectorisedView(30, views) - want := int(unsafe.Sizeof(vv)) + cap(views)*int(unsafe.Sizeof(views)) + 2*perViewCap - if got := vv.MemSize(); got != want { - t.Errorf("vv.MemSize() = %d, want %d", got, want) - } -} diff --git a/pkg/tcpip/checker/BUILD b/pkg/tcpip/checker/BUILD deleted file mode 100644 index c984470e6..000000000 --- a/pkg/tcpip/checker/BUILD +++ /dev/null @@ -1,17 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "checker", - testonly = 1, - srcs = ["checker.go"], - visibility = ["//visibility:public"], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/seqnum", - "@com_github_google_go_cmp//cmp:go_default_library", - ], -) diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go deleted file mode 100644 index e0dfe5813..000000000 --- a/pkg/tcpip/checker/checker.go +++ /dev/null @@ -1,1625 +0,0 @@ -// Copyright 2021 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package checker provides helper functions to check networking packets for -// validity. -package checker - -import ( - "encoding/binary" - "reflect" - "testing" - "time" - - "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/seqnum" -) - -// NetworkChecker is a function to check a property of a network packet. -type NetworkChecker func(*testing.T, []header.Network) - -// TransportChecker is a function to check a property of a transport packet. -type TransportChecker func(*testing.T, header.Transport) - -// ControlMessagesChecker is a function to check a property of ancillary data. -type ControlMessagesChecker func(*testing.T, tcpip.ControlMessages) - -// IPv4 checks the validity and properties of the given IPv4 packet. It is -// expected to be used in conjunction with other network checkers for specific -// properties. For example, to check the source and destination address, one -// would call: -// -// checker.IPv4(t, b, checker.SrcAddr(x), checker.DstAddr(y)) -func IPv4(t *testing.T, b []byte, checkers ...NetworkChecker) { - t.Helper() - - ipv4 := header.IPv4(b) - - if !ipv4.IsValid(len(b)) { - t.Fatalf("Not a valid IPv4 packet: %x", ipv4) - } - - if !ipv4.IsChecksumValid() { - t.Errorf("Bad checksum, got = %d", ipv4.Checksum()) - } - - for _, f := range checkers { - f(t, []header.Network{ipv4}) - } - if t.Failed() { - t.FailNow() - } -} - -// IPv6 checks the validity and properties of the given IPv6 packet. The usage -// is similar to IPv4. -func IPv6(t *testing.T, b []byte, checkers ...NetworkChecker) { - t.Helper() - - ipv6 := header.IPv6(b) - if !ipv6.IsValid(len(b)) { - t.Fatalf("Not a valid IPv6 packet: %x", ipv6) - } - - for _, f := range checkers { - f(t, []header.Network{ipv6}) - } - if t.Failed() { - t.FailNow() - } -} - -// SrcAddr creates a checker that checks the source address. -func SrcAddr(addr tcpip.Address) NetworkChecker { - return func(t *testing.T, h []header.Network) { - t.Helper() - - if a := h[0].SourceAddress(); a != addr { - t.Errorf("Bad source address, got %v, want %v", a, addr) - } - } -} - -// DstAddr creates a checker that checks the destination address. -func DstAddr(addr tcpip.Address) NetworkChecker { - return func(t *testing.T, h []header.Network) { - t.Helper() - - if a := h[0].DestinationAddress(); a != addr { - t.Errorf("Bad destination address, got %v, want %v", a, addr) - } - } -} - -// TTL creates a checker that checks the TTL (ipv4) or HopLimit (ipv6). -func TTL(ttl uint8) NetworkChecker { - return func(t *testing.T, h []header.Network) { - t.Helper() - - var v uint8 - switch ip := h[0].(type) { - case header.IPv4: - v = ip.TTL() - case header.IPv6: - v = ip.HopLimit() - case *ipv6HeaderWithExtHdr: - v = ip.HopLimit() - default: - t.Fatalf("unrecognized header type %T for TTL evaluation", ip) - } - if v != ttl { - t.Fatalf("Bad TTL, got = %d, want = %d", v, ttl) - } - } -} - -// IPFullLength creates a checker for the full IP packet length. The -// expected size is checked against both the Total Length in the -// header and the number of bytes received. -func IPFullLength(packetLength uint16) NetworkChecker { - return func(t *testing.T, h []header.Network) { - t.Helper() - - var v uint16 - var l uint16 - switch ip := h[0].(type) { - case header.IPv4: - v = ip.TotalLength() - l = uint16(len(ip)) - case header.IPv6: - v = ip.PayloadLength() + header.IPv6FixedHeaderSize - l = uint16(len(ip)) - default: - t.Fatalf("unexpected network header passed to checker, got = %T, want = header.IPv4 or header.IPv6", ip) - } - if l != packetLength { - t.Errorf("bad packet length, got = %d, want = %d", l, packetLength) - } - if v != packetLength { - t.Errorf("unexpected packet length in header, got = %d, want = %d", v, packetLength) - } - } -} - -// IPv4HeaderLength creates a checker that checks the IPv4 Header length. -func IPv4HeaderLength(headerLength int) NetworkChecker { - return func(t *testing.T, h []header.Network) { - t.Helper() - - switch ip := h[0].(type) { - case header.IPv4: - if hl := ip.HeaderLength(); hl != uint8(headerLength) { - t.Errorf("Bad header length, got = %d, want = %d", hl, headerLength) - } - default: - t.Fatalf("unexpected network header passed to checker, got = %T, want = header.IPv4", ip) - } - } -} - -// PayloadLen creates a checker that checks the payload length. -func PayloadLen(payloadLength int) NetworkChecker { - return func(t *testing.T, h []header.Network) { - t.Helper() - - if l := len(h[0].Payload()); l != payloadLength { - t.Errorf("Bad payload length, got = %d, want = %d", l, payloadLength) - } - } -} - -// IPPayload creates a checker that checks the payload. -func IPPayload(payload []byte) NetworkChecker { - return func(t *testing.T, h []header.Network) { - t.Helper() - - got := h[0].Payload() - - // cmp.Diff does not consider nil slices equal to empty slices, but we do. - if len(got) == 0 && len(payload) == 0 { - return - } - - if diff := cmp.Diff(payload, got); diff != "" { - t.Errorf("payload mismatch (-want +got):\n%s", diff) - } - } -} - -// IPv4Options returns a checker that checks the options in an IPv4 packet. -func IPv4Options(want header.IPv4Options) NetworkChecker { - return func(t *testing.T, h []header.Network) { - t.Helper() - - ip, ok := h[0].(header.IPv4) - if !ok { - t.Fatalf("unexpected network header passed to checker, got = %T, want = header.IPv4", h[0]) - } - options := ip.Options() - // cmp.Diff does not consider nil slices equal to empty slices, but we do. - if len(want) == 0 && len(options) == 0 { - return - } - if diff := cmp.Diff(want, options); diff != "" { - t.Errorf("options mismatch (-want +got):\n%s", diff) - } - } -} - -// IPv4RouterAlert returns a checker that checks that the RouterAlert option is -// set in an IPv4 packet. -func IPv4RouterAlert() NetworkChecker { - return func(t *testing.T, h []header.Network) { - t.Helper() - ip, ok := h[0].(header.IPv4) - if !ok { - t.Fatalf("unexpected network header passed to checker, got = %T, want = header.IPv4", h[0]) - } - iterator := ip.Options().MakeIterator() - for { - opt, done, err := iterator.Next() - if err != nil { - t.Fatalf("error acquiring next IPv4 option at offset %d", err.Pointer) - } - if done { - break - } - if opt.Type() != header.IPv4OptionRouterAlertType { - continue - } - want := [header.IPv4OptionRouterAlertLength]byte{ - byte(header.IPv4OptionRouterAlertType), - header.IPv4OptionRouterAlertLength, - header.IPv4OptionRouterAlertValue, - header.IPv4OptionRouterAlertValue, - } - if diff := cmp.Diff(want[:], opt.Contents()); diff != "" { - t.Errorf("router alert option mismatch (-want +got):\n%s", diff) - } - return - } - t.Errorf("failed to find router alert option in %v", ip.Options()) - } -} - -// FragmentOffset creates a checker that checks the FragmentOffset field. -func FragmentOffset(offset uint16) NetworkChecker { - return func(t *testing.T, h []header.Network) { - t.Helper() - - // We only do this for IPv4 for now. - switch ip := h[0].(type) { - case header.IPv4: - if v := ip.FragmentOffset(); v != offset { - t.Errorf("Bad fragment offset, got = %d, want = %d", v, offset) - } - } - } -} - -// FragmentFlags creates a checker that checks the fragment flags field. -func FragmentFlags(flags uint8) NetworkChecker { - return func(t *testing.T, h []header.Network) { - t.Helper() - - // We only do this for IPv4 for now. - switch ip := h[0].(type) { - case header.IPv4: - if v := ip.Flags(); v != flags { - t.Errorf("Bad fragment offset, got = %d, want = %d", v, flags) - } - } - } -} - -// ReceiveTClass creates a checker that checks the TCLASS field in -// ControlMessages. -func ReceiveTClass(want uint32) ControlMessagesChecker { - return func(t *testing.T, cm tcpip.ControlMessages) { - t.Helper() - if !cm.HasTClass { - t.Errorf("got cm.HasTClass = %t, want = true", cm.HasTClass) - } else if got := cm.TClass; got != want { - t.Errorf("got cm.TClass = %d, want %d", got, want) - } - } -} - -// ReceiveTOS creates a checker that checks the TOS field in ControlMessages. -func ReceiveTOS(want uint8) ControlMessagesChecker { - return func(t *testing.T, cm tcpip.ControlMessages) { - t.Helper() - if !cm.HasTOS { - t.Errorf("got cm.HasTOS = %t, want = true", cm.HasTOS) - } else if got := cm.TOS; got != want { - t.Errorf("got cm.TOS = %d, want %d", got, want) - } - } -} - -// ReceiveIPPacketInfo creates a checker that checks the PacketInfo field in -// ControlMessages. -func ReceiveIPPacketInfo(want tcpip.IPPacketInfo) ControlMessagesChecker { - return func(t *testing.T, cm tcpip.ControlMessages) { - t.Helper() - if !cm.HasIPPacketInfo { - t.Errorf("got cm.HasIPPacketInfo = %t, want = true", cm.HasIPPacketInfo) - } else if diff := cmp.Diff(want, cm.PacketInfo); diff != "" { - t.Errorf("IPPacketInfo mismatch (-want +got):\n%s", diff) - } - } -} - -// ReceiveOriginalDstAddr creates a checker that checks the OriginalDstAddress -// field in ControlMessages. -func ReceiveOriginalDstAddr(want tcpip.FullAddress) ControlMessagesChecker { - return func(t *testing.T, cm tcpip.ControlMessages) { - t.Helper() - if !cm.HasOriginalDstAddress { - t.Errorf("got cm.HasOriginalDstAddress = %t, want = true", cm.HasOriginalDstAddress) - } else if diff := cmp.Diff(want, cm.OriginalDstAddress); diff != "" { - t.Errorf("OriginalDstAddress mismatch (-want +got):\n%s", diff) - } - } -} - -// TOS creates a checker that checks the TOS field. -func TOS(tos uint8, label uint32) NetworkChecker { - return func(t *testing.T, h []header.Network) { - t.Helper() - - if v, l := h[0].TOS(); v != tos || l != label { - t.Errorf("Bad TOS, got = (%d, %d), want = (%d,%d)", v, l, tos, label) - } - } -} - -// Raw creates a checker that checks the bytes of payload. -// The checker always checks the payload of the last network header. -// For instance, in case of IPv6 fragments, the payload that will be checked -// is the one containing the actual data that the packet is carrying, without -// the bytes added by the IPv6 fragmentation. -func Raw(want []byte) NetworkChecker { - return func(t *testing.T, h []header.Network) { - t.Helper() - - if got := h[len(h)-1].Payload(); !reflect.DeepEqual(got, want) { - t.Errorf("Wrong payload, got %v, want %v", got, want) - } - } -} - -// IPv6Fragment creates a checker that validates an IPv6 fragment. -func IPv6Fragment(checkers ...NetworkChecker) NetworkChecker { - return func(t *testing.T, h []header.Network) { - t.Helper() - - if p := h[0].TransportProtocol(); p != header.IPv6FragmentHeader { - t.Errorf("Bad protocol, got = %d, want = %d", p, header.UDPProtocolNumber) - } - - ipv6Frag := header.IPv6Fragment(h[0].Payload()) - if !ipv6Frag.IsValid() { - t.Error("Not a valid IPv6 fragment") - } - - for _, f := range checkers { - f(t, []header.Network{h[0], ipv6Frag}) - } - if t.Failed() { - t.FailNow() - } - } -} - -// TCP creates a checker that checks that the transport protocol is TCP and -// potentially additional transport header fields. -func TCP(checkers ...TransportChecker) NetworkChecker { - return func(t *testing.T, h []header.Network) { - t.Helper() - - first := h[0] - last := h[len(h)-1] - - if p := last.TransportProtocol(); p != header.TCPProtocolNumber { - t.Errorf("Bad protocol, got = %d, want = %d", p, header.TCPProtocolNumber) - } - - tcp := header.TCP(last.Payload()) - payload := tcp.Payload() - payloadChecksum := header.Checksum(payload, 0) - if !tcp.IsChecksumValid(first.SourceAddress(), first.DestinationAddress(), payloadChecksum, uint16(len(payload))) { - t.Errorf("Bad checksum, got = %d", tcp.Checksum()) - } - - // Run the transport checkers. - for _, f := range checkers { - f(t, tcp) - } - if t.Failed() { - t.FailNow() - } - } -} - -// UDP creates a checker that checks that the transport protocol is UDP and -// potentially additional transport header fields. -func UDP(checkers ...TransportChecker) NetworkChecker { - return func(t *testing.T, h []header.Network) { - t.Helper() - - last := h[len(h)-1] - - if p := last.TransportProtocol(); p != header.UDPProtocolNumber { - t.Errorf("Bad protocol, got = %d, want = %d", p, header.UDPProtocolNumber) - } - - udp := header.UDP(last.Payload()) - for _, f := range checkers { - f(t, udp) - } - if t.Failed() { - t.FailNow() - } - } -} - -// SrcPort creates a checker that checks the source port. -func SrcPort(port uint16) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - if p := h.SourcePort(); p != port { - t.Errorf("Bad source port, got = %d, want = %d", p, port) - } - } -} - -// DstPort creates a checker that checks the destination port. -func DstPort(port uint16) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - if p := h.DestinationPort(); p != port { - t.Errorf("Bad destination port, got = %d, want = %d", p, port) - } - } -} - -// NoChecksum creates a checker that checks if the checksum is zero. -func NoChecksum(noChecksum bool) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - udp, ok := h.(header.UDP) - if !ok { - t.Fatalf("UDP header not found in h: %T", h) - } - - if b := udp.Checksum() == 0; b != noChecksum { - t.Errorf("bad checksum state, got %t, want %t", b, noChecksum) - } - } -} - -// TCPSeqNum creates a checker that checks the sequence number. -func TCPSeqNum(seq uint32) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - tcp, ok := h.(header.TCP) - if !ok { - t.Fatalf("TCP header not found in h: %T", h) - } - - if s := tcp.SequenceNumber(); s != seq { - t.Errorf("Bad sequence number, got = %d, want = %d", s, seq) - } - } -} - -// TCPAckNum creates a checker that checks the ack number. -func TCPAckNum(seq uint32) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - tcp, ok := h.(header.TCP) - if !ok { - t.Fatalf("TCP header not found in h: %T", h) - } - - if s := tcp.AckNumber(); s != seq { - t.Errorf("Bad ack number, got = %d, want = %d", s, seq) - } - } -} - -// TCPWindow creates a checker that checks the tcp window. -func TCPWindow(window uint16) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - tcp, ok := h.(header.TCP) - if !ok { - t.Fatalf("TCP header not found in hdr : %T", h) - } - - if w := tcp.WindowSize(); w != window { - t.Errorf("Bad window, got %d, want %d", w, window) - } - } -} - -// TCPWindowGreaterThanEq creates a checker that checks that the TCP window -// is greater than or equal to the provided value. -func TCPWindowGreaterThanEq(window uint16) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - tcp, ok := h.(header.TCP) - if !ok { - t.Fatalf("TCP header not found in h: %T", h) - } - - if w := tcp.WindowSize(); w < window { - t.Errorf("Bad window, got %d, want > %d", w, window) - } - } -} - -// TCPWindowLessThanEq creates a checker that checks that the tcp window -// is less than or equal to the provided value. -func TCPWindowLessThanEq(window uint16) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - tcp, ok := h.(header.TCP) - if !ok { - t.Fatalf("TCP header not found in h: %T", h) - } - - if w := tcp.WindowSize(); w > window { - t.Errorf("Bad window, got %d, want < %d", w, window) - } - } -} - -// TCPFlags creates a checker that checks the tcp flags. -func TCPFlags(flags header.TCPFlags) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - tcp, ok := h.(header.TCP) - if !ok { - t.Fatalf("TCP header not found in h: %T", h) - } - - if got := tcp.Flags(); got != flags { - t.Errorf("got tcp.Flags() = %s, want %s", got, flags) - } - } -} - -// TCPFlagsMatch creates a checker that checks that the tcp flags, masked by the -// given mask, match the supplied flags. -func TCPFlagsMatch(flags, mask header.TCPFlags) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - tcp, ok := h.(header.TCP) - if !ok { - t.Fatalf("TCP header not found in h: %T", h) - } - - if got := tcp.Flags(); (got & mask) != (flags & mask) { - t.Errorf("got tcp.Flags() = %s, want %s, mask %s", got, flags, mask) - } - } -} - -// TCPSynOptions creates a checker that checks the presence of TCP options in -// SYN segments. -// -// If wndscale is negative, the window scale option must not be present. -func TCPSynOptions(wantOpts header.TCPSynOptions) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - tcp, ok := h.(header.TCP) - if !ok { - return - } - opts := tcp.Options() - limit := len(opts) - foundMSS := false - foundWS := false - foundTS := false - foundSACKPermitted := false - tsVal := uint32(0) - tsEcr := uint32(0) - for i := 0; i < limit; { - switch opts[i] { - case header.TCPOptionEOL: - i = limit - case header.TCPOptionNOP: - i++ - case header.TCPOptionMSS: - v := uint16(opts[i+2])<<8 | uint16(opts[i+3]) - if wantOpts.MSS != v { - t.Errorf("Bad MSS, got = %d, want = %d", v, wantOpts.MSS) - } - foundMSS = true - i += 4 - case header.TCPOptionWS: - if wantOpts.WS < 0 { - t.Error("WS present when it shouldn't be") - } - v := int(opts[i+2]) - if v != wantOpts.WS { - t.Errorf("Bad WS, got = %d, want = %d", v, wantOpts.WS) - } - foundWS = true - i += 3 - case header.TCPOptionTS: - if i+9 >= limit { - t.Errorf("TS Option truncated , option is only: %d bytes, want 10", limit-i) - } - if opts[i+1] != 10 { - t.Errorf("Bad length %d for TS option, limit: %d", opts[i+1], limit) - } - tsVal = binary.BigEndian.Uint32(opts[i+2:]) - tsEcr = uint32(0) - if tcp.Flags()&header.TCPFlagAck != 0 { - // If the syn is an SYN-ACK then read - // the tsEcr value as well. - tsEcr = binary.BigEndian.Uint32(opts[i+6:]) - } - foundTS = true - i += 10 - case header.TCPOptionSACKPermitted: - if i+1 >= limit { - t.Errorf("SACKPermitted option truncated, option is only : %d bytes, want 2", limit-i) - } - if opts[i+1] != 2 { - t.Errorf("Bad length %d for SACKPermitted option, limit: %d", opts[i+1], limit) - } - foundSACKPermitted = true - i += 2 - - default: - i += int(opts[i+1]) - } - } - - if !foundMSS { - t.Errorf("MSS option not found. Options: %x", opts) - } - - if !foundWS && wantOpts.WS >= 0 { - t.Errorf("WS option not found. Options: %x", opts) - } - if wantOpts.TS && !foundTS { - t.Errorf("TS option not found. Options: %x", opts) - } - if foundTS && tsVal == 0 { - t.Error("TS option specified but the timestamp value is zero") - } - if foundTS && tsEcr == 0 && wantOpts.TSEcr != 0 { - t.Errorf("TS option specified but TSEcr is incorrect, got = %d, want = %d", tsEcr, wantOpts.TSEcr) - } - if wantOpts.SACKPermitted && !foundSACKPermitted { - t.Errorf("SACKPermitted option not found. Options: %x", opts) - } - } -} - -// TCPTimestampChecker creates a checker that validates that a TCP segment has a -// TCP Timestamp option if wantTS is true, it also compares the wantTSVal and -// wantTSEcr values with those in the TCP segment (if present). -// -// If wantTSVal or wantTSEcr is zero then the corresponding comparison is -// skipped. -func TCPTimestampChecker(wantTS bool, wantTSVal uint32, wantTSEcr uint32) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - tcp, ok := h.(header.TCP) - if !ok { - return - } - opts := tcp.Options() - limit := len(opts) - foundTS := false - tsVal := uint32(0) - tsEcr := uint32(0) - for i := 0; i < limit; { - switch opts[i] { - case header.TCPOptionEOL: - i = limit - case header.TCPOptionNOP: - i++ - case header.TCPOptionTS: - if i+9 >= limit { - t.Errorf("TS option found, but option is truncated, option length: %d, want 10 bytes", limit-i) - } - if opts[i+1] != 10 { - t.Errorf("TS option found, but bad length specified: got = %d, want = 10", opts[i+1]) - } - tsVal = binary.BigEndian.Uint32(opts[i+2:]) - tsEcr = binary.BigEndian.Uint32(opts[i+6:]) - foundTS = true - i += 10 - default: - // We don't recognize this option, just skip over it. - if i+2 > limit { - return - } - l := int(opts[i+1]) - if i < 2 || i+l > limit { - return - } - i += l - } - } - - if wantTS != foundTS { - t.Errorf("TS Option mismatch, got TS= %t, want TS= %t", foundTS, wantTS) - } - if wantTS && wantTSVal != 0 && wantTSVal != tsVal { - t.Errorf("Timestamp value is incorrect, got = %d, want = %d", tsVal, wantTSVal) - } - if wantTS && wantTSEcr != 0 && tsEcr != wantTSEcr { - t.Errorf("Timestamp Echo Reply is incorrect, got = %d, want = %d", tsEcr, wantTSEcr) - } - } -} - -// TCPSACKBlockChecker creates a checker that verifies that the segment does -// contain the specified SACK blocks in the TCP options. -func TCPSACKBlockChecker(sackBlocks []header.SACKBlock) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - tcp, ok := h.(header.TCP) - if !ok { - return - } - var gotSACKBlocks []header.SACKBlock - - opts := tcp.Options() - limit := len(opts) - for i := 0; i < limit; { - switch opts[i] { - case header.TCPOptionEOL: - i = limit - case header.TCPOptionNOP: - i++ - case header.TCPOptionSACK: - if i+2 > limit { - // Malformed SACK block. - t.Errorf("malformed SACK option in options: %v", opts) - } - sackOptionLen := int(opts[i+1]) - if i+sackOptionLen > limit || (sackOptionLen-2)%8 != 0 { - // Malformed SACK block. - t.Errorf("malformed SACK option length in options: %v", opts) - } - numBlocks := sackOptionLen / 8 - for j := 0; j < numBlocks; j++ { - start := binary.BigEndian.Uint32(opts[i+2+j*8:]) - end := binary.BigEndian.Uint32(opts[i+2+j*8+4:]) - gotSACKBlocks = append(gotSACKBlocks, header.SACKBlock{ - Start: seqnum.Value(start), - End: seqnum.Value(end), - }) - } - i += sackOptionLen - default: - // We don't recognize this option, just skip over it. - if i+2 > limit { - break - } - l := int(opts[i+1]) - if l < 2 || i+l > limit { - break - } - i += l - } - } - - if !reflect.DeepEqual(gotSACKBlocks, sackBlocks) { - t.Errorf("SACKBlocks are not equal, got = %v, want = %v", gotSACKBlocks, sackBlocks) - } - } -} - -// Payload creates a checker that checks the payload. -func Payload(want []byte) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - if got := h.Payload(); !reflect.DeepEqual(got, want) { - t.Errorf("Wrong payload, got %v, want %v", got, want) - } - } -} - -// ICMPv4 creates a checker that checks that the transport protocol is ICMPv4 -// and potentially additional ICMPv4 header fields. -func ICMPv4(checkers ...TransportChecker) NetworkChecker { - return func(t *testing.T, h []header.Network) { - t.Helper() - - last := h[len(h)-1] - - if p := last.TransportProtocol(); p != header.ICMPv4ProtocolNumber { - t.Fatalf("Bad protocol, got %d, want %d", p, header.ICMPv4ProtocolNumber) - } - - icmp := header.ICMPv4(last.Payload()) - for _, f := range checkers { - f(t, icmp) - } - if t.Failed() { - t.FailNow() - } - } -} - -// ICMPv4Type creates a checker that checks the ICMPv4 Type field. -func ICMPv4Type(want header.ICMPv4Type) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - icmpv4, ok := h.(header.ICMPv4) - if !ok { - t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h) - } - if got := icmpv4.Type(); got != want { - t.Fatalf("unexpected icmp type, got = %d, want = %d", got, want) - } - } -} - -// ICMPv4Code creates a checker that checks the ICMPv4 Code field. -func ICMPv4Code(want header.ICMPv4Code) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - icmpv4, ok := h.(header.ICMPv4) - if !ok { - t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h) - } - if got := icmpv4.Code(); got != want { - t.Fatalf("unexpected ICMP code, got = %d, want = %d", got, want) - } - } -} - -// ICMPv4Ident creates a checker that checks the ICMPv4 echo Ident. -func ICMPv4Ident(want uint16) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - icmpv4, ok := h.(header.ICMPv4) - if !ok { - t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h) - } - if got := icmpv4.Ident(); got != want { - t.Fatalf("unexpected ICMP ident, got = %d, want = %d", got, want) - } - } -} - -// ICMPv4Seq creates a checker that checks the ICMPv4 echo Sequence. -func ICMPv4Seq(want uint16) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - icmpv4, ok := h.(header.ICMPv4) - if !ok { - t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h) - } - if got := icmpv4.Sequence(); got != want { - t.Fatalf("unexpected ICMP sequence, got = %d, want = %d", got, want) - } - } -} - -// ICMPv4Pointer creates a checker that checks the ICMPv4 Param Problem pointer. -func ICMPv4Pointer(want uint8) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - icmpv4, ok := h.(header.ICMPv4) - if !ok { - t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h) - } - if got := icmpv4.Pointer(); got != want { - t.Fatalf("unexpected ICMP Param Problem pointer, got = %d, want = %d", got, want) - } - } -} - -// ICMPv4Checksum creates a checker that checks the ICMPv4 Checksum. -// This assumes that the payload exactly makes up the rest of the slice. -func ICMPv4Checksum() TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - icmpv4, ok := h.(header.ICMPv4) - if !ok { - t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h) - } - heldChecksum := icmpv4.Checksum() - icmpv4.SetChecksum(0) - newChecksum := ^header.Checksum(icmpv4, 0) - icmpv4.SetChecksum(heldChecksum) - if heldChecksum != newChecksum { - t.Errorf("unexpected ICMP checksum, got = %d, want = %d", heldChecksum, newChecksum) - } - } -} - -// ICMPv4Payload creates a checker that checks the payload in an ICMPv4 packet. -func ICMPv4Payload(want []byte) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - icmpv4, ok := h.(header.ICMPv4) - if !ok { - t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h) - } - payload := icmpv4.Payload() - - // cmp.Diff does not consider nil slices equal to empty slices, but we do. - if len(want) == 0 && len(payload) == 0 { - return - } - - if diff := cmp.Diff(want, payload); diff != "" { - t.Errorf("ICMP payload mismatch (-want +got):\n%s", diff) - } - } -} - -// ICMPv6 creates a checker that checks that the transport protocol is ICMPv6 and -// potentially additional ICMPv6 header fields. -// -// ICMPv6 will validate the checksum field before calling checkers. -func ICMPv6(checkers ...TransportChecker) NetworkChecker { - return func(t *testing.T, h []header.Network) { - t.Helper() - - last := h[len(h)-1] - - if p := last.TransportProtocol(); p != header.ICMPv6ProtocolNumber { - t.Fatalf("Bad protocol, got %d, want %d", p, header.ICMPv6ProtocolNumber) - } - - icmp := header.ICMPv6(last.Payload()) - if got, want := icmp.Checksum(), header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ - Header: icmp, - Src: last.SourceAddress(), - Dst: last.DestinationAddress(), - }); got != want { - t.Fatalf("Bad ICMPv6 checksum; got %d, want %d", got, want) - } - - for _, f := range checkers { - f(t, icmp) - } - if t.Failed() { - t.FailNow() - } - } -} - -// ICMPv6Type creates a checker that checks the ICMPv6 Type field. -func ICMPv6Type(want header.ICMPv6Type) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - icmpv6, ok := h.(header.ICMPv6) - if !ok { - t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv6", h) - } - if got := icmpv6.Type(); got != want { - t.Fatalf("unexpected icmp type, got = %d, want = %d", got, want) - } - } -} - -// ICMPv6Code creates a checker that checks the ICMPv6 Code field. -func ICMPv6Code(want header.ICMPv6Code) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - icmpv6, ok := h.(header.ICMPv6) - if !ok { - t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv6", h) - } - if got := icmpv6.Code(); got != want { - t.Fatalf("unexpected ICMP code, got = %d, want = %d", got, want) - } - } -} - -// ICMPv6TypeSpecific creates a checker that checks the ICMPv6 TypeSpecific -// field. -func ICMPv6TypeSpecific(want uint32) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - icmpv6, ok := h.(header.ICMPv6) - if !ok { - t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv6", h) - } - if got := icmpv6.TypeSpecific(); got != want { - t.Fatalf("unexpected ICMP TypeSpecific, got = %d, want = %d", got, want) - } - } -} - -// ICMPv6Payload creates a checker that checks the payload in an ICMPv6 packet. -func ICMPv6Payload(want []byte) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - icmpv6, ok := h.(header.ICMPv6) - if !ok { - t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv6", h) - } - payload := icmpv6.Payload() - - // cmp.Diff does not consider nil slices equal to empty slices, but we do. - if len(want) == 0 && len(payload) == 0 { - return - } - - if diff := cmp.Diff(want, payload); diff != "" { - t.Errorf("ICMP payload mismatch (-want +got):\n%s", diff) - } - } -} - -// MLD creates a checker that checks that the packet contains a valid MLD -// message for type of mldType, with potentially additional checks specified by -// checkers. -// -// Checkers may assume that a valid ICMPv6 is passed to it containing a valid -// MLD message as far as the size of the message (minSize) is concerned. The -// values within the message are up to checkers to validate. -func MLD(msgType header.ICMPv6Type, minSize int, checkers ...TransportChecker) NetworkChecker { - return func(t *testing.T, h []header.Network) { - t.Helper() - - // Check normal ICMPv6 first. - ICMPv6( - ICMPv6Type(msgType), - ICMPv6Code(0))(t, h) - - last := h[len(h)-1] - - icmp := header.ICMPv6(last.Payload()) - if got := len(icmp.MessageBody()); got < minSize { - t.Fatalf("ICMPv6 MLD (type = %d) payload size of %d is less than the minimum size of %d", msgType, got, minSize) - } - - for _, f := range checkers { - f(t, icmp) - } - if t.Failed() { - t.FailNow() - } - } -} - -// MLDMaxRespDelay creates a checker that checks the Maximum Response Delay -// field of a MLD message. -// -// The returned TransportChecker assumes that a valid ICMPv6 is passed to it -// containing a valid MLD message as far as the size is concerned. -func MLDMaxRespDelay(want time.Duration) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - icmp := h.(header.ICMPv6) - ns := header.MLD(icmp.MessageBody()) - - if got := ns.MaximumResponseDelay(); got != want { - t.Errorf("got %T.MaximumResponseDelay() = %s, want = %s", ns, got, want) - } - } -} - -// MLDMulticastAddress creates a checker that checks the Multicast Address -// field of a MLD message. -// -// The returned TransportChecker assumes that a valid ICMPv6 is passed to it -// containing a valid MLD message as far as the size is concerned. -func MLDMulticastAddress(want tcpip.Address) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - icmp := h.(header.ICMPv6) - ns := header.MLD(icmp.MessageBody()) - - if got := ns.MulticastAddress(); got != want { - t.Errorf("got %T.MulticastAddress() = %s, want = %s", ns, got, want) - } - } -} - -// NDP creates a checker that checks that the packet contains a valid NDP -// message for type of ty, with potentially additional checks specified by -// checkers. -// -// Checkers may assume that a valid ICMPv6 is passed to it containing a valid -// NDP message as far as the size of the message (minSize) is concerned. The -// values within the message are up to checkers to validate. -func NDP(msgType header.ICMPv6Type, minSize int, checkers ...TransportChecker) NetworkChecker { - return func(t *testing.T, h []header.Network) { - t.Helper() - - // Check normal ICMPv6 first. - ICMPv6( - ICMPv6Type(msgType), - ICMPv6Code(0))(t, h) - - last := h[len(h)-1] - - icmp := header.ICMPv6(last.Payload()) - if got := len(icmp.MessageBody()); got < minSize { - t.Fatalf("ICMPv6 NDP (type = %d) payload size of %d is less than the minimum size of %d", msgType, got, minSize) - } - - for _, f := range checkers { - f(t, icmp) - } - if t.Failed() { - t.FailNow() - } - } -} - -// NDPNS creates a checker that checks that the packet contains a valid NDP -// Neighbor Solicitation message (as per the raw wire format), with potentially -// additional checks specified by checkers. -// -// Checkers may assume that a valid ICMPv6 is passed to it containing a valid -// NDPNS message as far as the size of the message is concerned. The values -// within the message are up to checkers to validate. -func NDPNS(checkers ...TransportChecker) NetworkChecker { - return NDP(header.ICMPv6NeighborSolicit, header.NDPNSMinimumSize, checkers...) -} - -// NDPNSTargetAddress creates a checker that checks the Target Address field of -// a header.NDPNeighborSolicit. -// -// The returned TransportChecker assumes that a valid ICMPv6 is passed to it -// containing a valid NDPNS message as far as the size is concerned. -func NDPNSTargetAddress(want tcpip.Address) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - icmp := h.(header.ICMPv6) - ns := header.NDPNeighborSolicit(icmp.MessageBody()) - - if got := ns.TargetAddress(); got != want { - t.Errorf("got %T.TargetAddress() = %s, want = %s", ns, got, want) - } - } -} - -// NDPNA creates a checker that checks that the packet contains a valid NDP -// Neighbor Advertisement message (as per the raw wire format), with potentially -// additional checks specified by checkers. -// -// Checkers may assume that a valid ICMPv6 is passed to it containing a valid -// NDPNA message as far as the size of the message is concerned. The values -// within the message are up to checkers to validate. -func NDPNA(checkers ...TransportChecker) NetworkChecker { - return NDP(header.ICMPv6NeighborAdvert, header.NDPNAMinimumSize, checkers...) -} - -// NDPNATargetAddress creates a checker that checks the Target Address field of -// a header.NDPNeighborAdvert. -// -// The returned TransportChecker assumes that a valid ICMPv6 is passed to it -// containing a valid NDPNA message as far as the size is concerned. -func NDPNATargetAddress(want tcpip.Address) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - icmp := h.(header.ICMPv6) - na := header.NDPNeighborAdvert(icmp.MessageBody()) - - if got := na.TargetAddress(); got != want { - t.Errorf("got %T.TargetAddress() = %s, want = %s", na, got, want) - } - } -} - -// NDPNASolicitedFlag creates a checker that checks the Solicited field of -// a header.NDPNeighborAdvert. -// -// The returned TransportChecker assumes that a valid ICMPv6 is passed to it -// containing a valid NDPNA message as far as the size is concerned. -func NDPNASolicitedFlag(want bool) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - icmp := h.(header.ICMPv6) - na := header.NDPNeighborAdvert(icmp.MessageBody()) - - if got := na.SolicitedFlag(); got != want { - t.Errorf("got %T.SolicitedFlag = %t, want = %t", na, got, want) - } - } -} - -// ndpOptions checks that optsBuf only contains opts. -func ndpOptions(t *testing.T, optsBuf header.NDPOptions, opts []header.NDPOption) { - t.Helper() - - it, err := optsBuf.Iter(true) - if err != nil { - t.Errorf("optsBuf.Iter(true): %s", err) - return - } - - i := 0 - for { - opt, done, err := it.Next() - if err != nil { - // This should never happen as Iter(true) above did not return an error. - t.Fatalf("unexpected error when iterating over NDP options: %s", err) - } - if done { - break - } - - if i >= len(opts) { - t.Errorf("got unexpected option: %s", opt) - continue - } - - switch wantOpt := opts[i].(type) { - case header.NDPSourceLinkLayerAddressOption: - gotOpt, ok := opt.(header.NDPSourceLinkLayerAddressOption) - if !ok { - t.Errorf("got type = %T at index = %d; want = %T", opt, i, wantOpt) - } else if got, want := gotOpt.EthernetAddress(), wantOpt.EthernetAddress(); got != want { - t.Errorf("got EthernetAddress() = %s at index %d, want = %s", got, i, want) - } - case header.NDPTargetLinkLayerAddressOption: - gotOpt, ok := opt.(header.NDPTargetLinkLayerAddressOption) - if !ok { - t.Errorf("got type = %T at index = %d; want = %T", opt, i, wantOpt) - } else if got, want := gotOpt.EthernetAddress(), wantOpt.EthernetAddress(); got != want { - t.Errorf("got EthernetAddress() = %s at index %d, want = %s", got, i, want) - } - case header.NDPNonceOption: - gotOpt, ok := opt.(header.NDPNonceOption) - if !ok { - t.Errorf("got type = %T at index = %d; want = %T", opt, i, wantOpt) - } else if diff := cmp.Diff(wantOpt.Nonce(), gotOpt.Nonce()); diff != "" { - t.Errorf("nonce mismatch (-want +got):\n%s", diff) - } - default: - t.Fatalf("checker not implemented for expected NDP option: %T", wantOpt) - } - - i++ - } - - if missing := opts[i:]; len(missing) > 0 { - t.Errorf("missing options: %s", missing) - } -} - -// NDPNAOptions creates a checker that checks that the packet contains the -// provided NDP options within an NDP Neighbor Solicitation message. -// -// The returned TransportChecker assumes that a valid ICMPv6 is passed to it -// containing a valid NDPNA message as far as the size is concerned. -func NDPNAOptions(opts []header.NDPOption) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - icmp := h.(header.ICMPv6) - na := header.NDPNeighborAdvert(icmp.MessageBody()) - ndpOptions(t, na.Options(), opts) - } -} - -// NDPNSOptions creates a checker that checks that the packet contains the -// provided NDP options within an NDP Neighbor Solicitation message. -// -// The returned TransportChecker assumes that a valid ICMPv6 is passed to it -// containing a valid NDPNS message as far as the size is concerned. -func NDPNSOptions(opts []header.NDPOption) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - icmp := h.(header.ICMPv6) - ns := header.NDPNeighborSolicit(icmp.MessageBody()) - ndpOptions(t, ns.Options(), opts) - } -} - -// NDPRS creates a checker that checks that the packet contains a valid NDP -// Router Solicitation message (as per the raw wire format). -// -// Checkers may assume that a valid ICMPv6 is passed to it containing a valid -// NDPRS as far as the size of the message is concerned. The values within the -// message are up to checkers to validate. -func NDPRS(checkers ...TransportChecker) NetworkChecker { - return NDP(header.ICMPv6RouterSolicit, header.NDPRSMinimumSize, checkers...) -} - -// NDPRSOptions creates a checker that checks that the packet contains the -// provided NDP options within an NDP Router Solicitation message. -// -// The returned TransportChecker assumes that a valid ICMPv6 is passed to it -// containing a valid NDPRS message as far as the size is concerned. -func NDPRSOptions(opts []header.NDPOption) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - icmp := h.(header.ICMPv6) - rs := header.NDPRouterSolicit(icmp.MessageBody()) - ndpOptions(t, rs.Options(), opts) - } -} - -// IGMP checks the validity and properties of the given IGMP packet. It is -// expected to be used in conjunction with other IGMP transport checkers for -// specific properties. -func IGMP(checkers ...TransportChecker) NetworkChecker { - return func(t *testing.T, h []header.Network) { - t.Helper() - - last := h[len(h)-1] - - if p := last.TransportProtocol(); p != header.IGMPProtocolNumber { - t.Fatalf("Bad protocol, got %d, want %d", p, header.IGMPProtocolNumber) - } - - igmp := header.IGMP(last.Payload()) - for _, f := range checkers { - f(t, igmp) - } - if t.Failed() { - t.FailNow() - } - } -} - -// IGMPType creates a checker that checks the IGMP Type field. -func IGMPType(want header.IGMPType) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - igmp, ok := h.(header.IGMP) - if !ok { - t.Fatalf("got transport header = %T, want = header.IGMP", h) - } - if got := igmp.Type(); got != want { - t.Errorf("got igmp.Type() = %d, want = %d", got, want) - } - } -} - -// IGMPMaxRespTime creates a checker that checks the IGMP Max Resp Time field. -func IGMPMaxRespTime(want time.Duration) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - igmp, ok := h.(header.IGMP) - if !ok { - t.Fatalf("got transport header = %T, want = header.IGMP", h) - } - if got := igmp.MaxRespTime(); got != want { - t.Errorf("got igmp.MaxRespTime() = %s, want = %s", got, want) - } - } -} - -// IGMPGroupAddress creates a checker that checks the IGMP Group Address field. -func IGMPGroupAddress(want tcpip.Address) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - igmp, ok := h.(header.IGMP) - if !ok { - t.Fatalf("got transport header = %T, want = header.IGMP", h) - } - if got := igmp.GroupAddress(); got != want { - t.Errorf("got igmp.GroupAddress() = %s, want = %s", got, want) - } - } -} - -// IPv6ExtHdrChecker is a function to check an extension header. -type IPv6ExtHdrChecker func(*testing.T, header.IPv6PayloadHeader) - -// IPv6WithExtHdr is like IPv6 but allows IPv6 packets with extension headers. -func IPv6WithExtHdr(t *testing.T, b []byte, checkers ...NetworkChecker) { - t.Helper() - - ipv6 := header.IPv6(b) - if !ipv6.IsValid(len(b)) { - t.Error("not a valid IPv6 packet") - return - } - - payloadIterator := header.MakeIPv6PayloadIterator( - header.IPv6ExtensionHeaderIdentifier(ipv6.NextHeader()), - buffer.View(ipv6.Payload()).ToVectorisedView(), - ) - - var rawPayloadHeader header.IPv6RawPayloadHeader - for { - h, done, err := payloadIterator.Next() - if err != nil { - t.Errorf("payloadIterator.Next(): %s", err) - return - } - if done { - t.Errorf("got payloadIterator.Next() = (%T, %t, _), want = (_, true, _)", h, done) - return - } - r, ok := h.(header.IPv6RawPayloadHeader) - if ok { - rawPayloadHeader = r - break - } - } - - networkHeader := ipv6HeaderWithExtHdr{ - IPv6: ipv6, - transport: tcpip.TransportProtocolNumber(rawPayloadHeader.Identifier), - payload: rawPayloadHeader.Buf.ToView(), - } - - for _, checker := range checkers { - checker(t, []header.Network{&networkHeader}) - } -} - -// IPv6ExtHdr checks for the presence of extension headers. -// -// All the extension headers in headers will be checked exhaustively in the -// order provided. -func IPv6ExtHdr(headers ...IPv6ExtHdrChecker) NetworkChecker { - return func(t *testing.T, h []header.Network) { - t.Helper() - - extHdrs, ok := h[0].(*ipv6HeaderWithExtHdr) - if !ok { - t.Errorf("got network header = %T, want = *ipv6HeaderWithExtHdr", h[0]) - return - } - - payloadIterator := header.MakeIPv6PayloadIterator( - header.IPv6ExtensionHeaderIdentifier(extHdrs.IPv6.NextHeader()), - buffer.View(extHdrs.IPv6.Payload()).ToVectorisedView(), - ) - - for _, check := range headers { - h, done, err := payloadIterator.Next() - if err != nil { - t.Errorf("payloadIterator.Next(): %s", err) - return - } - if done { - t.Errorf("got payloadIterator.Next() = (%T, %t, _), want = (_, false, _)", h, done) - return - } - check(t, h) - } - // Validate we consumed all headers. - // - // The next one over should be a raw payload and then iterator should - // terminate. - wantDone := false - for { - h, done, err := payloadIterator.Next() - if err != nil { - t.Errorf("payloadIterator.Next(): %s", err) - return - } - if done != wantDone { - t.Errorf("got payloadIterator.Next() = (%T, %t, _), want = (_, %t, _)", h, done, wantDone) - return - } - if done { - break - } - if _, ok := h.(header.IPv6RawPayloadHeader); !ok { - t.Errorf("got payloadIterator.Next() = (%T, _, _), want = (header.IPv6RawPayloadHeader, _, _)", h) - continue - } - wantDone = true - } - } -} - -var _ header.Network = (*ipv6HeaderWithExtHdr)(nil) - -// ipv6HeaderWithExtHdr provides a header.Network implementation that takes -// extension headers into consideration, which is not the case with vanilla -// header.IPv6. -type ipv6HeaderWithExtHdr struct { - header.IPv6 - transport tcpip.TransportProtocolNumber - payload []byte -} - -// TransportProtocol implements header.Network. -func (h *ipv6HeaderWithExtHdr) TransportProtocol() tcpip.TransportProtocolNumber { - return h.transport -} - -// Payload implements header.Network. -func (h *ipv6HeaderWithExtHdr) Payload() []byte { - return h.payload -} - -// IPv6ExtHdrOptionChecker is a function to check an extension header option. -type IPv6ExtHdrOptionChecker func(*testing.T, header.IPv6ExtHdrOption) - -// IPv6HopByHopExtensionHeader checks the extension header is a Hop by Hop -// extension header and validates the containing options with checkers. -// -// checkers must exhaustively contain all the expected options. -func IPv6HopByHopExtensionHeader(checkers ...IPv6ExtHdrOptionChecker) IPv6ExtHdrChecker { - return func(t *testing.T, payloadHeader header.IPv6PayloadHeader) { - t.Helper() - - hbh, ok := payloadHeader.(header.IPv6HopByHopOptionsExtHdr) - if !ok { - t.Errorf("unexpected IPv6 payload header, got = %T, want = header.IPv6HopByHopOptionsExtHdr", payloadHeader) - return - } - optionsIterator := hbh.Iter() - for _, f := range checkers { - opt, done, err := optionsIterator.Next() - if err != nil { - t.Errorf("optionsIterator.Next(): %s", err) - return - } - if done { - t.Errorf("got optionsIterator.Next() = (%T, %t, _), want = (_, false, _)", opt, done) - } - f(t, opt) - } - // Validate all options were consumed. - for { - opt, done, err := optionsIterator.Next() - if err != nil { - t.Errorf("optionsIterator.Next(): %s", err) - return - } - if !done { - t.Errorf("got optionsIterator.Next() = (%T, %t, _), want = (_, true, _)", opt, done) - } - if done { - break - } - } - } -} - -// IPv6RouterAlert validates that an extension header option is the RouterAlert -// option and matches on its value. -func IPv6RouterAlert(want header.IPv6RouterAlertValue) IPv6ExtHdrOptionChecker { - return func(t *testing.T, opt header.IPv6ExtHdrOption) { - routerAlert, ok := opt.(*header.IPv6RouterAlertOption) - if !ok { - t.Errorf("unexpected extension header option, got = %T, want = header.IPv6RouterAlertOption", opt) - return - } - if routerAlert.Value != want { - t.Errorf("got routerAlert.Value = %d, want = %d", routerAlert.Value, want) - } - } -} - -// IPv6UnknownOption validates that an extension header option is the -// unknown header option. -func IPv6UnknownOption() IPv6ExtHdrOptionChecker { - return func(t *testing.T, opt header.IPv6ExtHdrOption) { - _, ok := opt.(*header.IPv6UnknownExtHdrOption) - if !ok { - t.Errorf("got = %T, want = header.IPv6UnknownExtHdrOption", opt) - } - } -} - -// IgnoreCmpPath returns a cmp.Option that ignores listed field paths. -func IgnoreCmpPath(paths ...string) cmp.Option { - ignores := map[string]struct{}{} - for _, path := range paths { - ignores[path] = struct{}{} - } - return cmp.FilterPath(func(path cmp.Path) bool { - _, ok := ignores[path.String()] - return ok - }, cmp.Ignore()) -} diff --git a/pkg/tcpip/faketime/BUILD b/pkg/tcpip/faketime/BUILD deleted file mode 100644 index bb9d44aff..000000000 --- a/pkg/tcpip/faketime/BUILD +++ /dev/null @@ -1,21 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "faketime", - srcs = ["faketime.go"], - visibility = ["//visibility:public"], - deps = ["//pkg/tcpip"], -) - -go_test( - name = "faketime_test", - size = "small", - srcs = [ - "faketime_test.go", - ], - deps = [ - "//pkg/tcpip/faketime", - ], -) diff --git a/pkg/tcpip/faketime/faketime_state_autogen.go b/pkg/tcpip/faketime/faketime_state_autogen.go new file mode 100644 index 000000000..3de72f27d --- /dev/null +++ b/pkg/tcpip/faketime/faketime_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package faketime diff --git a/pkg/tcpip/faketime/faketime_test.go b/pkg/tcpip/faketime/faketime_test.go deleted file mode 100644 index fd2bb470a..000000000 --- a/pkg/tcpip/faketime/faketime_test.go +++ /dev/null @@ -1,95 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package faketime_test - -import ( - "testing" - "time" - - "gvisor.dev/gvisor/pkg/tcpip/faketime" -) - -func TestManualClockAdvance(t *testing.T) { - const timeout = time.Millisecond - clock := faketime.NewManualClock() - start := clock.NowMonotonic() - clock.Advance(timeout) - if got, want := clock.NowMonotonic().Sub(start), timeout; got != want { - t.Errorf("got = %d, want = %d", got, want) - } -} - -func TestManualClockAfterFunc(t *testing.T) { - const ( - timeout1 = time.Millisecond // timeout for counter1 - timeout2 = 2 * time.Millisecond // timeout for counter2 - ) - tests := []struct { - name string - advance time.Duration - wantCounter1 int - wantCounter2 int - }{ - { - name: "before timeout1", - advance: timeout1 - 1, - wantCounter1: 0, - wantCounter2: 0, - }, - { - name: "timeout1", - advance: timeout1, - wantCounter1: 1, - wantCounter2: 0, - }, - { - name: "timeout2", - advance: timeout2, - wantCounter1: 1, - wantCounter2: 1, - }, - { - name: "after timeout2", - advance: timeout2 + 1, - wantCounter1: 1, - wantCounter2: 1, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - clock := faketime.NewManualClock() - counter1 := 0 - counter2 := 0 - clock.AfterFunc(timeout1, func() { - counter1++ - }) - clock.AfterFunc(timeout2, func() { - counter2++ - }) - start := clock.NowMonotonic() - clock.Advance(test.advance) - if got, want := counter1, test.wantCounter1; got != want { - t.Errorf("got counter1 = %d, want = %d", got, want) - } - if got, want := counter2, test.wantCounter2; got != want { - t.Errorf("got counter2 = %d, want = %d", got, want) - } - if got, want := clock.NowMonotonic().Sub(start), test.advance; got != want { - t.Errorf("got elapsed = %d, want = %d", got, want) - } - }) - } -} diff --git a/pkg/tcpip/hash/jenkins/BUILD b/pkg/tcpip/hash/jenkins/BUILD deleted file mode 100644 index ff2719291..000000000 --- a/pkg/tcpip/hash/jenkins/BUILD +++ /dev/null @@ -1,18 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "jenkins", - srcs = ["jenkins.go"], - visibility = ["//visibility:public"], -) - -go_test( - name = "jenkins_test", - size = "small", - srcs = [ - "jenkins_test.go", - ], - library = ":jenkins", -) diff --git a/pkg/tcpip/hash/jenkins/jenkins_state_autogen.go b/pkg/tcpip/hash/jenkins/jenkins_state_autogen.go new file mode 100644 index 000000000..216cc5a2e --- /dev/null +++ b/pkg/tcpip/hash/jenkins/jenkins_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package jenkins diff --git a/pkg/tcpip/hash/jenkins/jenkins_test.go b/pkg/tcpip/hash/jenkins/jenkins_test.go deleted file mode 100644 index 4c78b5808..000000000 --- a/pkg/tcpip/hash/jenkins/jenkins_test.go +++ /dev/null @@ -1,176 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -package jenkins - -import ( - "bytes" - "encoding/binary" - "hash" - "hash/fnv" - "math" - "testing" -) - -func TestGolden32(t *testing.T) { - var golden32 = []struct { - out []byte - in string - }{ - {[]byte{0x00, 0x00, 0x00, 0x00}, ""}, - {[]byte{0xca, 0x2e, 0x94, 0x42}, "a"}, - {[]byte{0x45, 0xe6, 0x1e, 0x58}, "ab"}, - {[]byte{0xed, 0x13, 0x1f, 0x5b}, "abc"}, - } - - hash := New32() - - for _, g := range golden32 { - hash.Reset() - done, error := hash.Write([]byte(g.in)) - if error != nil { - t.Fatalf("write error: %s", error) - } - if done != len(g.in) { - t.Fatalf("wrote only %d out of %d bytes", done, len(g.in)) - } - if actual := hash.Sum(nil); !bytes.Equal(g.out, actual) { - t.Errorf("hash(%q) = 0x%x want 0x%x", g.in, actual, g.out) - } - } -} - -func TestIntegrity32(t *testing.T) { - data := []byte{'1', '2', 3, 4, 5} - - h := New32() - h.Write(data) - sum := h.Sum(nil) - - if size := h.Size(); size != len(sum) { - t.Fatalf("Size()=%d but len(Sum())=%d", size, len(sum)) - } - - if a := h.Sum(nil); !bytes.Equal(sum, a) { - t.Fatalf("first Sum()=0x%x, second Sum()=0x%x", sum, a) - } - - h.Reset() - h.Write(data) - if a := h.Sum(nil); !bytes.Equal(sum, a) { - t.Fatalf("Sum()=0x%x, but after Reset() Sum()=0x%x", sum, a) - } - - h.Reset() - h.Write(data[:2]) - h.Write(data[2:]) - if a := h.Sum(nil); !bytes.Equal(sum, a) { - t.Fatalf("Sum()=0x%x, but with partial writes, Sum()=0x%x", sum, a) - } - - sum32 := h.(hash.Hash32).Sum32() - if sum32 != binary.BigEndian.Uint32(sum) { - t.Fatalf("Sum()=0x%x, but Sum32()=0x%x", sum, sum32) - } -} - -func BenchmarkJenkins32KB(b *testing.B) { - h := New32() - - b.SetBytes(1024) - data := make([]byte, 1024) - for i := range data { - data[i] = byte(i) - } - in := make([]byte, 0, h.Size()) - - b.ResetTimer() - for i := 0; i < b.N; i++ { - h.Reset() - h.Write(data) - h.Sum(in) - } -} - -func BenchmarkFnv32(b *testing.B) { - arr := make([]int64, 1000) - for i := 0; i < b.N; i++ { - var payload [8]byte - binary.BigEndian.PutUint32(payload[:4], uint32(i)) - binary.BigEndian.PutUint32(payload[4:], uint32(i)) - - h := fnv.New32() - h.Write(payload[:]) - idx := int(h.Sum32()) % len(arr) - arr[idx]++ - } - b.StopTimer() - c := 0 - if b.N > 1000000 { - for i := 0; i < len(arr)-1; i++ { - if math.Abs(float64(arr[i]-arr[i+1]))/float64(arr[i]) > float64(0.1) { - if c == 0 { - b.Logf("i %d val[i] %d val[i+1] %d b.N %b\n", i, arr[i], arr[i+1], b.N) - } - c++ - } - } - if c > 0 { - b.Logf("Unbalanced buckets: %d", c) - } - } -} - -func BenchmarkSum32(b *testing.B) { - arr := make([]int64, 1000) - for i := 0; i < b.N; i++ { - var payload [8]byte - binary.BigEndian.PutUint32(payload[:4], uint32(i)) - binary.BigEndian.PutUint32(payload[4:], uint32(i)) - h := Sum32(0) - h.Write(payload[:]) - idx := int(h.Sum32()) % len(arr) - arr[idx]++ - } - b.StopTimer() - if b.N > 1000000 { - for i := 0; i < len(arr)-1; i++ { - if math.Abs(float64(arr[i]-arr[i+1]))/float64(arr[i]) > float64(0.1) { - b.Logf("val[%3d]=%8d\tval[%3d]=%8d\tb.N=%b\n", i, arr[i], i+1, arr[i+1], b.N) - break - } - } - } -} - -func BenchmarkNew32(b *testing.B) { - arr := make([]int64, 1000) - for i := 0; i < b.N; i++ { - var payload [8]byte - binary.BigEndian.PutUint32(payload[:4], uint32(i)) - binary.BigEndian.PutUint32(payload[4:], uint32(i)) - h := New32() - h.Write(payload[:]) - idx := int(h.Sum32()) % len(arr) - arr[idx]++ - } - b.StopTimer() - if b.N > 1000000 { - for i := 0; i < len(arr)-1; i++ { - if math.Abs(float64(arr[i]-arr[i+1]))/float64(arr[i]) > float64(0.1) { - b.Logf("val[%3d]=%8d\tval[%3d]=%8d\tb.N=%b\n", i, arr[i], i+1, arr[i+1], b.N) - break - } - } - } -} diff --git a/pkg/tcpip/header/BUILD b/pkg/tcpip/header/BUILD deleted file mode 100644 index 01240f5d0..000000000 --- a/pkg/tcpip/header/BUILD +++ /dev/null @@ -1,76 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "header", - srcs = [ - "arp.go", - "checksum.go", - "eth.go", - "gue.go", - "icmpv4.go", - "icmpv6.go", - "igmp.go", - "interfaces.go", - "ipv4.go", - "ipv6.go", - "ipv6_extension_headers.go", - "ipv6_fragment.go", - "mld.go", - "ndp_neighbor_advert.go", - "ndp_neighbor_solicit.go", - "ndp_options.go", - "ndp_router_advert.go", - "ndp_router_solicit.go", - "ndpoptionidentifier_string.go", - "tcp.go", - "udp.go", - ], - visibility = ["//visibility:public"], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/seqnum", - "@com_github_google_btree//:go_default_library", - ], -) - -go_test( - name = "header_x_test", - size = "small", - srcs = [ - "checksum_test.go", - "igmp_test.go", - "ipv4_test.go", - "ipv6_test.go", - "ipversion_test.go", - "tcp_test.go", - ], - deps = [ - ":header", - "//pkg/rand", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/testutil", - "@com_github_google_go_cmp//cmp:go_default_library", - ], -) - -go_test( - name = "header_test", - size = "small", - srcs = [ - "eth_test.go", - "ipv6_extension_headers_test.go", - "mld_test.go", - "ndp_test.go", - ], - library = ":header", - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/testutil", - "@com_github_google_go_cmp//cmp:go_default_library", - ], -) diff --git a/pkg/tcpip/header/checksum_test.go b/pkg/tcpip/header/checksum_test.go deleted file mode 100644 index 3445511f4..000000000 --- a/pkg/tcpip/header/checksum_test.go +++ /dev/null @@ -1,461 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package header provides the implementation of the encoding and decoding of -// network protocol headers. -package header_test - -import ( - "bytes" - "fmt" - "math/rand" - "sync" - "testing" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/header" -) - -func TestChecksumer(t *testing.T) { - testCases := []struct { - name string - data [][]byte - want uint16 - }{ - { - name: "empty", - want: 0, - }, - { - name: "OneOddView", - data: [][]byte{ - []byte{1, 9, 0, 5, 4}, - }, - want: 1294, - }, - { - name: "TwoOddViews", - data: [][]byte{ - []byte{1, 9, 0, 5, 4}, - []byte{4, 3, 7, 1, 2, 123}, - }, - want: 33819, - }, - { - name: "OneEvenView", - data: [][]byte{ - []byte{1, 9, 0, 5}, - }, - want: 270, - }, - { - name: "TwoEvenViews", - data: [][]byte{ - buffer.NewViewFromBytes([]byte{98, 1, 9, 0}), - buffer.NewViewFromBytes([]byte{9, 0, 5, 4}), - }, - want: 30981, - }, - { - name: "ThreeViews", - data: [][]byte{ - []byte{77, 11, 33, 0, 55, 44}, - []byte{98, 1, 9, 0, 5, 4}, - []byte{4, 3, 7, 1, 2, 123, 99}, - }, - want: 34236, - }, - } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - var all bytes.Buffer - var c header.Checksumer - for _, b := range tc.data { - c.Add(b) - // Append to the buffer. We will check the checksum as a whole later. - if _, err := all.Write(b); err != nil { - t.Fatalf("all.Write(b) = _, %s; want _, nil", err) - } - } - if got, want := c.Checksum(), tc.want; got != want { - t.Errorf("c.Checksum() = %d, want %d", got, want) - } - if got, want := header.Checksum(all.Bytes(), 0 /* initial */), tc.want; got != want { - t.Errorf("Checksum(flatten tc.data) = %d, want %d", got, want) - } - }) - } -} - -func TestChecksum(t *testing.T) { - var bufSizes = []int{0, 1, 2, 3, 4, 7, 8, 15, 16, 31, 32, 63, 64, 127, 128, 255, 256, 257, 1023, 1024} - type testCase struct { - buf []byte - initial uint16 - csumOrig uint16 - csumNew uint16 - } - testCases := make([]testCase, 100000) - // Ensure same buffer generation for test consistency. - rnd := rand.New(rand.NewSource(42)) - for i := range testCases { - testCases[i].buf = make([]byte, bufSizes[i%len(bufSizes)]) - testCases[i].initial = uint16(rnd.Intn(65536)) - rnd.Read(testCases[i].buf) - } - - for i := range testCases { - testCases[i].csumOrig = header.ChecksumOld(testCases[i].buf, testCases[i].initial) - testCases[i].csumNew = header.Checksum(testCases[i].buf, testCases[i].initial) - if got, want := testCases[i].csumNew, testCases[i].csumOrig; got != want { - t.Fatalf("new checksum for (buf = %x, initial = %d) does not match old got: %d, want: %d", testCases[i].buf, testCases[i].initial, got, want) - } - } -} - -func BenchmarkChecksum(b *testing.B) { - var bufSizes = []int{64, 128, 256, 512, 1024, 1500, 2048, 4096, 8192, 16384, 32767, 32768, 65535, 65536} - - checkSumImpls := []struct { - fn func([]byte, uint16) uint16 - name string - }{ - {header.ChecksumOld, fmt.Sprintf("checksum_old")}, - {header.Checksum, fmt.Sprintf("checksum")}, - } - - for _, csumImpl := range checkSumImpls { - // Ensure same buffer generation for test consistency. - rnd := rand.New(rand.NewSource(42)) - for _, bufSz := range bufSizes { - b.Run(fmt.Sprintf("%s_%d", csumImpl.name, bufSz), func(b *testing.B) { - tc := struct { - buf []byte - initial uint16 - csum uint16 - }{ - buf: make([]byte, bufSz), - initial: uint16(rnd.Intn(65536)), - } - rnd.Read(tc.buf) - b.ResetTimer() - for i := 0; i < b.N; i++ { - tc.csum = csumImpl.fn(tc.buf, tc.initial) - } - }) - } - } -} - -func testICMPChecksum(t *testing.T, headerChecksum func() uint16, icmpChecksum func() uint16, want uint16, pktStr string) { - // icmpChecksum should not do any modifications of the header to - // calculate its checksum. Let's call it from a few go-routines and the - // race detector will trigger a warning if there are any concurrent - // read/write accesses. - - const concurrency = 5 - start := make(chan int) - ready := make(chan bool, concurrency) - var wg sync.WaitGroup - wg.Add(concurrency) - defer wg.Wait() - - for i := 0; i < concurrency; i++ { - go func() { - defer wg.Done() - - ready <- true - <-start - - if got := headerChecksum(); want != got { - t.Errorf("new checksum for %s does not match old got: %x, want: %x", pktStr, got, want) - } - if got := icmpChecksum(); want != got { - t.Errorf("new checksum for %s does not match old got: %x, want: %x", pktStr, got, want) - } - }() - } - for i := 0; i < concurrency; i++ { - <-ready - } - close(start) -} - -func TestICMPv4Checksum(t *testing.T) { - rnd := rand.New(rand.NewSource(42)) - - h := header.ICMPv4(make([]byte, header.ICMPv4MinimumSize)) - if _, err := rnd.Read(h); err != nil { - t.Fatalf("rnd.Read failed: %v", err) - } - h.SetChecksum(0) - - buf := make([]byte, 13) - if _, err := rnd.Read(buf); err != nil { - t.Fatalf("rnd.Read failed: %v", err) - } - vv := buffer.NewVectorisedView(len(buf), []buffer.View{ - buffer.NewViewFromBytes(buf[:5]), - buffer.NewViewFromBytes(buf[5:]), - }) - - want := header.Checksum(vv.ToView(), 0) - want = ^header.Checksum(h, want) - h.SetChecksum(want) - - testICMPChecksum(t, h.Checksum, func() uint16 { - return header.ICMPv4Checksum(h, header.ChecksumVV(vv, 0)) - }, want, fmt.Sprintf("header: {% x} data {% x}", h, vv.ToView())) -} - -func TestICMPv6Checksum(t *testing.T) { - rnd := rand.New(rand.NewSource(42)) - - h := header.ICMPv6(make([]byte, header.ICMPv6MinimumSize)) - if _, err := rnd.Read(h); err != nil { - t.Fatalf("rnd.Read failed: %v", err) - } - h.SetChecksum(0) - - buf := make([]byte, 13) - if _, err := rnd.Read(buf); err != nil { - t.Fatalf("rnd.Read failed: %v", err) - } - vv := buffer.NewVectorisedView(len(buf), []buffer.View{ - buffer.NewViewFromBytes(buf[:7]), - buffer.NewViewFromBytes(buf[7:10]), - buffer.NewViewFromBytes(buf[10:]), - }) - - dst := header.IPv6Loopback - src := header.IPv6Loopback - - want := header.PseudoHeaderChecksum(header.ICMPv6ProtocolNumber, src, dst, uint16(len(h)+vv.Size())) - want = header.Checksum(vv.ToView(), want) - want = ^header.Checksum(h, want) - h.SetChecksum(want) - - testICMPChecksum(t, h.Checksum, func() uint16 { - return header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ - Header: h, - Src: src, - Dst: dst, - PayloadCsum: header.ChecksumVV(vv, 0), - PayloadLen: vv.Size(), - }) - }, want, fmt.Sprintf("header: {% x} data {% x}", h, vv.ToView())) -} - -func randomAddress(size int) tcpip.Address { - s := make([]byte, size) - for i := 0; i < size; i++ { - s[i] = byte(rand.Uint32()) - } - return tcpip.Address(s) -} - -func TestChecksummableNetworkUpdateAddress(t *testing.T) { - tests := []struct { - name string - update func(header.IPv4, tcpip.Address) - }{ - { - name: "SetSourceAddressWithChecksumUpdate", - update: header.IPv4.SetSourceAddressWithChecksumUpdate, - }, - { - name: "SetDestinationAddressWithChecksumUpdate", - update: header.IPv4.SetDestinationAddressWithChecksumUpdate, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - for i := 0; i < 1000; i++ { - var origBytes [header.IPv4MinimumSize]byte - header.IPv4(origBytes[:]).Encode(&header.IPv4Fields{ - TOS: 1, - TotalLength: header.IPv4MinimumSize, - ID: 2, - Flags: 3, - FragmentOffset: 4, - TTL: 5, - Protocol: 6, - Checksum: 0, - SrcAddr: randomAddress(header.IPv4AddressSize), - DstAddr: randomAddress(header.IPv4AddressSize), - }) - - addr := randomAddress(header.IPv4AddressSize) - - bytesCopy := origBytes - h := header.IPv4(bytesCopy[:]) - origXSum := h.CalculateChecksum() - h.SetChecksum(^origXSum) - - test.update(h, addr) - got := ^h.Checksum() - h.SetChecksum(0) - want := h.CalculateChecksum() - if got != want { - t.Errorf("got h.Checksum() = 0x%x, want = 0x%x; originalBytes = 0x%x, new addr = %s", got, want, origBytes, addr) - } - } - }) - } -} - -func TestChecksummableTransportUpdatePort(t *testing.T) { - // The fields in the pseudo header is not tested here so we just use 0. - const pseudoHeaderXSum = 0 - - tests := []struct { - name string - transportHdr func(_, _ uint16) (header.ChecksummableTransport, func(uint16) uint16) - proto tcpip.TransportProtocolNumber - }{ - { - name: "TCP", - transportHdr: func(src, dst uint16) (header.ChecksummableTransport, func(uint16) uint16) { - h := header.TCP(make([]byte, header.TCPMinimumSize)) - h.Encode(&header.TCPFields{ - SrcPort: src, - DstPort: dst, - SeqNum: 1, - AckNum: 2, - DataOffset: header.TCPMinimumSize, - Flags: 3, - WindowSize: 4, - Checksum: 0, - UrgentPointer: 5, - }) - h.SetChecksum(^h.CalculateChecksum(pseudoHeaderXSum)) - return h, h.CalculateChecksum - }, - proto: header.TCPProtocolNumber, - }, - { - name: "UDP", - transportHdr: func(src, dst uint16) (header.ChecksummableTransport, func(uint16) uint16) { - h := header.UDP(make([]byte, header.UDPMinimumSize)) - h.Encode(&header.UDPFields{ - SrcPort: src, - DstPort: dst, - Length: 0, - Checksum: 0, - }) - h.SetChecksum(^h.CalculateChecksum(pseudoHeaderXSum)) - return h, h.CalculateChecksum - }, - proto: header.UDPProtocolNumber, - }, - } - - for i := 0; i < 1000; i++ { - origSrcPort := uint16(rand.Uint32()) - origDstPort := uint16(rand.Uint32()) - newPort := uint16(rand.Uint32()) - - t.Run(fmt.Sprintf("OrigSrcPort=%d,OrigDstPort=%d,NewPort=%d", origSrcPort, origDstPort, newPort), func(*testing.T) { - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - for _, subTest := range []struct { - name string - update func(header.ChecksummableTransport) - }{ - { - name: "Source port", - update: func(h header.ChecksummableTransport) { h.SetSourcePortWithChecksumUpdate(newPort) }, - }, - { - name: "Destination port", - update: func(h header.ChecksummableTransport) { h.SetDestinationPortWithChecksumUpdate(newPort) }, - }, - } { - t.Run(subTest.name, func(t *testing.T) { - h, calcXSum := test.transportHdr(origSrcPort, origDstPort) - subTest.update(h) - // TCP and UDP hold the 1s complement of the fully calculated - // checksum. - got := ^h.Checksum() - h.SetChecksum(0) - - if want := calcXSum(pseudoHeaderXSum); got != want { - h, _ := test.transportHdr(origSrcPort, origDstPort) - t.Errorf("got Checksum() = 0x%x, want = 0x%x; originalBytes = %#v, new port = %d", got, want, h, newPort) - } - }) - } - }) - } - }) - } -} - -func TestChecksummableTransportUpdatePseudoHeaderAddress(t *testing.T) { - const addressSize = 6 - - tests := []struct { - name string - transportHdr func() header.ChecksummableTransport - proto tcpip.TransportProtocolNumber - }{ - { - name: "TCP", - transportHdr: func() header.ChecksummableTransport { return header.TCP(make([]byte, header.TCPMinimumSize)) }, - proto: header.TCPProtocolNumber, - }, - { - name: "UDP", - transportHdr: func() header.ChecksummableTransport { return header.UDP(make([]byte, header.UDPMinimumSize)) }, - proto: header.UDPProtocolNumber, - }, - } - - for i := 0; i < 1000; i++ { - permanent := randomAddress(addressSize) - old := randomAddress(addressSize) - new := randomAddress(addressSize) - - t.Run(fmt.Sprintf("Permanent=%q,Old=%q,New=%q", permanent, old, new), func(t *testing.T) { - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - for _, fullChecksum := range []bool{true, false} { - t.Run(fmt.Sprintf("FullChecksum=%t", fullChecksum), func(t *testing.T) { - initialXSum := header.PseudoHeaderChecksum(test.proto, permanent, old, 0) - if fullChecksum { - // TCP and UDP hold the 1s complement of the fully calculated - // checksum. - initialXSum = ^initialXSum - } - - h := test.transportHdr() - h.SetChecksum(initialXSum) - h.UpdateChecksumPseudoHeaderAddress(old, new, fullChecksum) - - got := h.Checksum() - if fullChecksum { - got = ^got - } - if want := header.PseudoHeaderChecksum(test.proto, permanent, new, 0); got != want { - t.Errorf("got Checksum() = 0x%x, want = 0x%x; h = %#v", got, want, h) - } - }) - } - }) - } - }) - } -} diff --git a/pkg/tcpip/header/eth_test.go b/pkg/tcpip/header/eth_test.go deleted file mode 100644 index bf9ccbf1a..000000000 --- a/pkg/tcpip/header/eth_test.go +++ /dev/null @@ -1,150 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package header - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/testutil" -) - -func TestIsValidUnicastEthernetAddress(t *testing.T) { - tests := []struct { - name string - addr tcpip.LinkAddress - expected bool - }{ - { - "Nil", - tcpip.LinkAddress([]byte(nil)), - false, - }, - { - "Empty", - tcpip.LinkAddress(""), - false, - }, - { - "InvalidLength", - tcpip.LinkAddress("\x01\x02\x03"), - false, - }, - { - "Unspecified", - unspecifiedEthernetAddress, - false, - }, - { - "Multicast", - tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"), - false, - }, - { - "Valid", - tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06"), - true, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - if got := IsValidUnicastEthernetAddress(test.addr); got != test.expected { - t.Fatalf("got IsValidUnicastEthernetAddress = %t, want = %t", got, test.expected) - } - }) - } -} - -func TestIsMulticastEthernetAddress(t *testing.T) { - tests := []struct { - name string - addr tcpip.LinkAddress - expected bool - }{ - { - "Nil", - tcpip.LinkAddress([]byte(nil)), - false, - }, - { - "Empty", - tcpip.LinkAddress(""), - false, - }, - { - "InvalidLength", - tcpip.LinkAddress("\x01\x02\x03"), - false, - }, - { - "Unspecified", - unspecifiedEthernetAddress, - false, - }, - { - "Multicast", - tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"), - true, - }, - { - "Unicast", - tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06"), - false, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - if got := IsMulticastEthernetAddress(test.addr); got != test.expected { - t.Fatalf("got IsMulticastEthernetAddress = %t, want = %t", got, test.expected) - } - }) - } -} - -func TestEthernetAddressFromMulticastIPv4Address(t *testing.T) { - tests := []struct { - name string - addr tcpip.Address - expectedLinkAddr tcpip.LinkAddress - }{ - { - name: "IPv4 Multicast without 24th bit set", - addr: "\xe0\x7e\xdc\xba", - expectedLinkAddr: "\x01\x00\x5e\x7e\xdc\xba", - }, - { - name: "IPv4 Multicast with 24th bit set", - addr: "\xe0\xfe\xdc\xba", - expectedLinkAddr: "\x01\x00\x5e\x7e\xdc\xba", - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - if got := EthernetAddressFromMulticastIPv4Address(test.addr); got != test.expectedLinkAddr { - t.Fatalf("got EthernetAddressFromMulticastIPv4Address(%s) = %s, want = %s", test.addr, got, test.expectedLinkAddr) - } - }) - } -} - -func TestEthernetAddressFromMulticastIPv6Address(t *testing.T) { - addr := testutil.MustParse6("ff02:304:506:708:90a:b0c:d0e:f1a") - if got, want := EthernetAddressFromMulticastIPv6Address(addr), tcpip.LinkAddress("\x33\x33\x0d\x0e\x0f\x1a"); got != want { - t.Fatalf("got EthernetAddressFromMulticastIPv6Address(%s) = %s, want = %s", addr, got, want) - } -} diff --git a/pkg/tcpip/header/header_state_autogen.go b/pkg/tcpip/header/header_state_autogen.go new file mode 100644 index 000000000..d6dd58874 --- /dev/null +++ b/pkg/tcpip/header/header_state_autogen.go @@ -0,0 +1,74 @@ +// automatically generated by stateify. + +package header + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (r *SACKBlock) StateTypeName() string { + return "pkg/tcpip/header.SACKBlock" +} + +func (r *SACKBlock) StateFields() []string { + return []string{ + "Start", + "End", + } +} + +func (r *SACKBlock) beforeSave() {} + +// +checklocksignore +func (r *SACKBlock) StateSave(stateSinkObject state.Sink) { + r.beforeSave() + stateSinkObject.Save(0, &r.Start) + stateSinkObject.Save(1, &r.End) +} + +func (r *SACKBlock) afterLoad() {} + +// +checklocksignore +func (r *SACKBlock) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &r.Start) + stateSourceObject.Load(1, &r.End) +} + +func (t *TCPOptions) StateTypeName() string { + return "pkg/tcpip/header.TCPOptions" +} + +func (t *TCPOptions) StateFields() []string { + return []string{ + "TS", + "TSVal", + "TSEcr", + "SACKBlocks", + } +} + +func (t *TCPOptions) beforeSave() {} + +// +checklocksignore +func (t *TCPOptions) StateSave(stateSinkObject state.Sink) { + t.beforeSave() + stateSinkObject.Save(0, &t.TS) + stateSinkObject.Save(1, &t.TSVal) + stateSinkObject.Save(2, &t.TSEcr) + stateSinkObject.Save(3, &t.SACKBlocks) +} + +func (t *TCPOptions) afterLoad() {} + +// +checklocksignore +func (t *TCPOptions) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &t.TS) + stateSourceObject.Load(1, &t.TSVal) + stateSourceObject.Load(2, &t.TSEcr) + stateSourceObject.Load(3, &t.SACKBlocks) +} + +func init() { + state.Register((*SACKBlock)(nil)) + state.Register((*TCPOptions)(nil)) +} diff --git a/pkg/tcpip/header/igmp_test.go b/pkg/tcpip/header/igmp_test.go deleted file mode 100644 index 575604928..000000000 --- a/pkg/tcpip/header/igmp_test.go +++ /dev/null @@ -1,110 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package header_test - -import ( - "testing" - "time" - - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/testutil" -) - -// TestIGMPHeader tests the functions within header.igmp -func TestIGMPHeader(t *testing.T) { - const maxRespTimeTenthSec = 0xF0 - b := []byte{ - 0x11, // IGMP Type, Membership Query - maxRespTimeTenthSec, // Maximum Response Time - 0xC0, 0xC0, // Checksum - 0x01, 0x02, 0x03, 0x04, // Group Address - } - - igmpHeader := header.IGMP(b) - - if got, want := igmpHeader.Type(), header.IGMPMembershipQuery; got != want { - t.Errorf("got igmpHeader.Type() = %x, want = %x", got, want) - } - - if got, want := igmpHeader.MaxRespTime(), header.DecisecondToDuration(maxRespTimeTenthSec); got != want { - t.Errorf("got igmpHeader.MaxRespTime() = %s, want = %s", got, want) - } - - if got, want := igmpHeader.Checksum(), uint16(0xC0C0); got != want { - t.Errorf("got igmpHeader.Checksum() = %x, want = %x", got, want) - } - - if got, want := igmpHeader.GroupAddress(), testutil.MustParse4("1.2.3.4"); got != want { - t.Errorf("got igmpHeader.GroupAddress() = %s, want = %s", got, want) - } - - igmpType := header.IGMPv2MembershipReport - igmpHeader.SetType(igmpType) - if got := igmpHeader.Type(); got != igmpType { - t.Errorf("got igmpHeader.Type() = %x, want = %x", got, igmpType) - } - if got := header.IGMPType(b[0]); got != igmpType { - t.Errorf("got IGMPtype in backing buffer = %x, want %x", got, igmpType) - } - - respTime := byte(0x02) - igmpHeader.SetMaxRespTime(respTime) - if got, want := igmpHeader.MaxRespTime(), header.DecisecondToDuration(respTime); got != want { - t.Errorf("got igmpHeader.MaxRespTime() = %s, want = %s", got, want) - } - - checksum := uint16(0x0102) - igmpHeader.SetChecksum(checksum) - if got := igmpHeader.Checksum(); got != checksum { - t.Errorf("got igmpHeader.Checksum() = %x, want = %x", got, checksum) - } - - groupAddress := testutil.MustParse4("4.3.2.1") - igmpHeader.SetGroupAddress(groupAddress) - if got := igmpHeader.GroupAddress(); got != groupAddress { - t.Errorf("got igmpHeader.GroupAddress() = %s, want = %s", got, groupAddress) - } -} - -// TestIGMPChecksum ensures that the checksum calculator produces the expected -// checksum. -func TestIGMPChecksum(t *testing.T) { - b := []byte{ - 0x11, // IGMP Type, Membership Query - 0xF0, // Maximum Response Time - 0xC0, 0xC0, // Checksum - 0x01, 0x02, 0x03, 0x04, // Group Address - } - - igmpHeader := header.IGMP(b) - - // Calculate the initial checksum after setting the checksum temporarily to 0 - // to avoid checksumming the checksum. - initialChecksum := igmpHeader.Checksum() - igmpHeader.SetChecksum(0) - checksum := ^header.Checksum(b, 0) - igmpHeader.SetChecksum(initialChecksum) - - if got := header.IGMPCalculateChecksum(igmpHeader); got != checksum { - t.Errorf("got IGMPCalculateChecksum = %x, want %x", got, checksum) - } -} - -func TestDecisecondToDuration(t *testing.T) { - const valueInDeciseconds = 5 - if got, want := header.DecisecondToDuration(valueInDeciseconds), valueInDeciseconds*time.Second/10; got != want { - t.Fatalf("got header.DecisecondToDuration(%d) = %s, want = %s", valueInDeciseconds, got, want) - } -} diff --git a/pkg/tcpip/header/ipv4_test.go b/pkg/tcpip/header/ipv4_test.go deleted file mode 100644 index c02fe898b..000000000 --- a/pkg/tcpip/header/ipv4_test.go +++ /dev/null @@ -1,254 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package header_test - -import ( - "testing" - - "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/header" -) - -func TestIPv4OptionsSerializer(t *testing.T) { - optCases := []struct { - name string - option []header.IPv4SerializableOption - expect []byte - }{ - { - name: "NOP", - option: []header.IPv4SerializableOption{ - &header.IPv4SerializableNOPOption{}, - }, - expect: []byte{1, 0, 0, 0}, - }, - { - name: "ListEnd", - option: []header.IPv4SerializableOption{ - &header.IPv4SerializableListEndOption{}, - }, - expect: []byte{0, 0, 0, 0}, - }, - { - name: "RouterAlert", - option: []header.IPv4SerializableOption{ - &header.IPv4SerializableRouterAlertOption{}, - }, - expect: []byte{148, 4, 0, 0}, - }, { - name: "NOP and RouterAlert", - option: []header.IPv4SerializableOption{ - &header.IPv4SerializableNOPOption{}, - &header.IPv4SerializableRouterAlertOption{}, - }, - expect: []byte{1, 148, 4, 0, 0, 0, 0, 0}, - }, - } - - for _, opt := range optCases { - t.Run(opt.name, func(t *testing.T) { - s := header.IPv4OptionsSerializer(opt.option) - l := s.Length() - if got := len(opt.expect); got != int(l) { - t.Fatalf("s.Length() = %d, want = %d", got, l) - } - b := make([]byte, l) - for i := range b { - // Fill the buffer with full bytes to ensure padding is being set - // correctly. - b[i] = 0xFF - } - if serializedLength := s.Serialize(b); serializedLength != l { - t.Fatalf("s.Serialize(_) = %d, want %d", serializedLength, l) - } - if diff := cmp.Diff(opt.expect, b); diff != "" { - t.Errorf("mismatched serialized option (-want +got):\n%s", diff) - } - }) - } -} - -// TestIPv4Encode checks that ipv4.Encode correctly fills out the requested -// fields when options are supplied. -func TestIPv4EncodeOptions(t *testing.T) { - tests := []struct { - name string - numberOfNops int - encodedOptions header.IPv4Options // reply should look like this - wantIHL int - }{ - { - name: "valid no options", - wantIHL: header.IPv4MinimumSize, - }, - { - name: "one byte options", - numberOfNops: 1, - encodedOptions: header.IPv4Options{1, 0, 0, 0}, - wantIHL: header.IPv4MinimumSize + 4, - }, - { - name: "two byte options", - numberOfNops: 2, - encodedOptions: header.IPv4Options{1, 1, 0, 0}, - wantIHL: header.IPv4MinimumSize + 4, - }, - { - name: "three byte options", - numberOfNops: 3, - encodedOptions: header.IPv4Options{1, 1, 1, 0}, - wantIHL: header.IPv4MinimumSize + 4, - }, - { - name: "four byte options", - numberOfNops: 4, - encodedOptions: header.IPv4Options{1, 1, 1, 1}, - wantIHL: header.IPv4MinimumSize + 4, - }, - { - name: "five byte options", - numberOfNops: 5, - encodedOptions: header.IPv4Options{1, 1, 1, 1, 1, 0, 0, 0}, - wantIHL: header.IPv4MinimumSize + 8, - }, - { - name: "thirty nine byte options", - numberOfNops: 39, - encodedOptions: header.IPv4Options{ - 1, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 1, 1, 1, 1, 1, 0, - }, - wantIHL: header.IPv4MinimumSize + 40, - }, - } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - serializeOpts := header.IPv4OptionsSerializer(make([]header.IPv4SerializableOption, test.numberOfNops)) - for i := range serializeOpts { - serializeOpts[i] = &header.IPv4SerializableNOPOption{} - } - paddedOptionLength := serializeOpts.Length() - ipHeaderLength := int(header.IPv4MinimumSize + paddedOptionLength) - if ipHeaderLength > header.IPv4MaximumHeaderSize { - t.Fatalf("IP header length too large: got = %d, want <= %d ", ipHeaderLength, header.IPv4MaximumHeaderSize) - } - totalLen := uint16(ipHeaderLength) - hdr := buffer.NewPrependable(int(totalLen)) - ip := header.IPv4(hdr.Prepend(ipHeaderLength)) - // To check the padding works, poison the last byte of the options space. - if paddedOptionLength != serializeOpts.Length() { - ip.SetHeaderLength(uint8(ipHeaderLength)) - ip.Options()[paddedOptionLength-1] = 0xff - ip.SetHeaderLength(0) - } - ip.Encode(&header.IPv4Fields{ - Options: serializeOpts, - }) - options := ip.Options() - wantOptions := test.encodedOptions - if got, want := int(ip.HeaderLength()), test.wantIHL; got != want { - t.Errorf("got IHL of %d, want %d", got, want) - } - - // cmp.Diff does not consider nil slices equal to empty slices, but we do. - if len(wantOptions) == 0 && len(options) == 0 { - return - } - - if diff := cmp.Diff(wantOptions, options); diff != "" { - t.Errorf("options mismatch (-want +got):\n%s", diff) - } - }) - } -} - -func TestIsV4LinkLocalUnicastAddress(t *testing.T) { - tests := []struct { - name string - addr tcpip.Address - expected bool - }{ - { - name: "Valid (lowest)", - addr: "\xa9\xfe\x00\x00", - expected: true, - }, - { - name: "Valid (highest)", - addr: "\xa9\xfe\xff\xff", - expected: true, - }, - { - name: "Invalid (before subnet)", - addr: "\xa9\xfd\xff\xff", - expected: false, - }, - { - name: "Invalid (after subnet)", - addr: "\xa9\xff\x00\x00", - expected: false, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - if got := header.IsV4LinkLocalUnicastAddress(test.addr); got != test.expected { - t.Errorf("got header.IsV4LinkLocalUnicastAddress(%s) = %t, want = %t", test.addr, got, test.expected) - } - }) - } -} - -func TestIsV4LinkLocalMulticastAddress(t *testing.T) { - tests := []struct { - name string - addr tcpip.Address - expected bool - }{ - { - name: "Valid (lowest)", - addr: "\xe0\x00\x00\x00", - expected: true, - }, - { - name: "Valid (highest)", - addr: "\xe0\x00\x00\xff", - expected: true, - }, - { - name: "Invalid (before subnet)", - addr: "\xdf\xff\xff\xff", - expected: false, - }, - { - name: "Invalid (after subnet)", - addr: "\xe0\x00\x01\x00", - expected: false, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - if got := header.IsV4LinkLocalMulticastAddress(test.addr); got != test.expected { - t.Errorf("got header.IsV4LinkLocalMulticastAddress(%s) = %t, want = %t", test.addr, got, test.expected) - } - }) - } -} diff --git a/pkg/tcpip/header/ipv6_extension_headers_test.go b/pkg/tcpip/header/ipv6_extension_headers_test.go deleted file mode 100644 index 65adc6250..000000000 --- a/pkg/tcpip/header/ipv6_extension_headers_test.go +++ /dev/null @@ -1,1346 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package header - -import ( - "bytes" - "errors" - "io" - "testing" - - "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" -) - -// Equal returns true of a and b are equivalent. -// -// Note, Equal will return true if a and b hold the same Identifier value and -// contain the same bytes in Buf, even if the bytes are split across views -// differently. -// -// Needed to use cmp.Equal on IPv6RawPayloadHeader as it contains unexported -// fields. -func (a IPv6RawPayloadHeader) Equal(b IPv6RawPayloadHeader) bool { - return a.Identifier == b.Identifier && bytes.Equal(a.Buf.ToView(), b.Buf.ToView()) -} - -// Equal returns true of a and b are equivalent. -// -// Note, Equal will return true if a and b hold equivalent ipv6OptionsExtHdrs. -// -// Needed to use cmp.Equal on IPv6RawPayloadHeader as it contains unexported -// fields. -func (a IPv6HopByHopOptionsExtHdr) Equal(b IPv6HopByHopOptionsExtHdr) bool { - return bytes.Equal(a.ipv6OptionsExtHdr, b.ipv6OptionsExtHdr) -} - -// Equal returns true of a and b are equivalent. -// -// Note, Equal will return true if a and b hold equivalent ipv6OptionsExtHdrs. -// -// Needed to use cmp.Equal on IPv6RawPayloadHeader as it contains unexported -// fields. -func (a IPv6DestinationOptionsExtHdr) Equal(b IPv6DestinationOptionsExtHdr) bool { - return bytes.Equal(a.ipv6OptionsExtHdr, b.ipv6OptionsExtHdr) -} - -func TestIPv6UnknownExtHdrOption(t *testing.T) { - tests := []struct { - name string - identifier IPv6ExtHdrOptionIdentifier - expectedUnknownAction IPv6OptionUnknownAction - }{ - { - name: "Skip with zero LSBs", - identifier: 0, - expectedUnknownAction: IPv6OptionUnknownActionSkip, - }, - { - name: "Discard with zero LSBs", - identifier: 64, - expectedUnknownAction: IPv6OptionUnknownActionDiscard, - }, - { - name: "Discard and ICMP with zero LSBs", - identifier: 128, - expectedUnknownAction: IPv6OptionUnknownActionDiscardSendICMP, - }, - { - name: "Discard and ICMP for non multicast destination with zero LSBs", - identifier: 192, - expectedUnknownAction: IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest, - }, - { - name: "Skip with non-zero LSBs", - identifier: 63, - expectedUnknownAction: IPv6OptionUnknownActionSkip, - }, - { - name: "Discard with non-zero LSBs", - identifier: 127, - expectedUnknownAction: IPv6OptionUnknownActionDiscard, - }, - { - name: "Discard and ICMP with non-zero LSBs", - identifier: 191, - expectedUnknownAction: IPv6OptionUnknownActionDiscardSendICMP, - }, - { - name: "Discard and ICMP for non multicast destination with non-zero LSBs", - identifier: 255, - expectedUnknownAction: IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - opt := &IPv6UnknownExtHdrOption{Identifier: test.identifier, Data: []byte{1, 2, 3, 4}} - if a := opt.UnknownAction(); a != test.expectedUnknownAction { - t.Fatalf("got UnknownAction() = %d, want = %d", a, test.expectedUnknownAction) - } - }) - } - -} - -func TestIPv6OptionsExtHdrIterErr(t *testing.T) { - tests := []struct { - name string - bytes []byte - err error - }{ - { - name: "Single unknown with zero length", - bytes: []byte{255, 0}, - }, - { - name: "Single unknown with non-zero length", - bytes: []byte{255, 3, 1, 2, 3}, - }, - { - name: "Two options", - bytes: []byte{ - 255, 0, - 254, 1, 1, - }, - }, - { - name: "Three options", - bytes: []byte{ - 255, 0, - 254, 1, 1, - 253, 4, 2, 3, 4, 5, - }, - }, - { - name: "Single unknown only identifier", - bytes: []byte{255}, - err: io.ErrUnexpectedEOF, - }, - { - name: "Single unknown too small with length = 1", - bytes: []byte{255, 1}, - err: io.ErrUnexpectedEOF, - }, - { - name: "Single unknown too small with length = 2", - bytes: []byte{255, 2, 1}, - err: io.ErrUnexpectedEOF, - }, - { - name: "Valid first with second unknown only identifier", - bytes: []byte{ - 255, 0, - 254, - }, - err: io.ErrUnexpectedEOF, - }, - { - name: "Valid first with second unknown missing data", - bytes: []byte{ - 255, 0, - 254, 1, - }, - err: io.ErrUnexpectedEOF, - }, - { - name: "Valid first with second unknown too small", - bytes: []byte{ - 255, 0, - 254, 2, 1, - }, - err: io.ErrUnexpectedEOF, - }, - { - name: "One Pad1", - bytes: []byte{0}, - }, - { - name: "Multiple Pad1", - bytes: []byte{0, 0, 0}, - }, - { - name: "Multiple PadN", - bytes: []byte{ - // Pad3 - 1, 1, 1, - - // Pad5 - 1, 3, 1, 2, 3, - }, - }, - { - name: "Pad5 too small middle of data buffer", - bytes: []byte{1, 3, 1, 2}, - err: io.ErrUnexpectedEOF, - }, - { - name: "Pad5 no data", - bytes: []byte{1, 3}, - err: io.ErrUnexpectedEOF, - }, - { - name: "Router alert without data", - bytes: []byte{byte(ipv6RouterAlertHopByHopOptionIdentifier), 0}, - err: ErrMalformedIPv6ExtHdrOption, - }, - { - name: "Router alert with partial data", - bytes: []byte{byte(ipv6RouterAlertHopByHopOptionIdentifier), 1, 1}, - err: ErrMalformedIPv6ExtHdrOption, - }, - { - name: "Router alert with partial data and Pad1", - bytes: []byte{byte(ipv6RouterAlertHopByHopOptionIdentifier), 1, 1, 0}, - err: ErrMalformedIPv6ExtHdrOption, - }, - { - name: "Router alert with extra data", - bytes: []byte{byte(ipv6RouterAlertHopByHopOptionIdentifier), 3, 1, 2, 3}, - err: ErrMalformedIPv6ExtHdrOption, - }, - { - name: "Router alert with missing data", - bytes: []byte{byte(ipv6RouterAlertHopByHopOptionIdentifier), 1}, - err: io.ErrUnexpectedEOF, - }, - } - - check := func(t *testing.T, it IPv6OptionsExtHdrOptionsIterator, expectedErr error) { - for i := 0; ; i++ { - _, done, err := it.Next() - if err != nil { - // If we encountered a non-nil error while iterating, make sure it is - // is the same error as expectedErr. - if !errors.Is(err, expectedErr) { - t.Fatalf("got %d-th Next() = %v, want = %v", i, err, expectedErr) - } - - return - } - if done { - // If we are done (without an error), make sure that we did not expect - // an error. - if expectedErr != nil { - t.Fatalf("expected error when iterating; want = %s", expectedErr) - } - - return - } - } - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - t.Run("Hop By Hop", func(t *testing.T) { - extHdr := IPv6HopByHopOptionsExtHdr{ipv6OptionsExtHdr: test.bytes} - check(t, extHdr.Iter(), test.err) - }) - - t.Run("Destination", func(t *testing.T) { - extHdr := IPv6DestinationOptionsExtHdr{ipv6OptionsExtHdr: test.bytes} - check(t, extHdr.Iter(), test.err) - }) - }) - } -} - -func TestIPv6OptionsExtHdrIter(t *testing.T) { - tests := []struct { - name string - bytes []byte - expected []IPv6ExtHdrOption - }{ - { - name: "Single unknown with zero length", - bytes: []byte{255, 0}, - expected: []IPv6ExtHdrOption{ - &IPv6UnknownExtHdrOption{Identifier: 255, Data: []byte{}}, - }, - }, - { - name: "Single unknown with non-zero length", - bytes: []byte{255, 3, 1, 2, 3}, - expected: []IPv6ExtHdrOption{ - &IPv6UnknownExtHdrOption{Identifier: 255, Data: []byte{1, 2, 3}}, - }, - }, - { - name: "Single Pad1", - bytes: []byte{0}, - }, - { - name: "Two Pad1", - bytes: []byte{0, 0}, - }, - { - name: "Single Pad3", - bytes: []byte{1, 1, 1}, - }, - { - name: "Single Pad5", - bytes: []byte{1, 3, 1, 2, 3}, - }, - { - name: "Multiple Pad", - bytes: []byte{ - // Pad1 - 0, - - // Pad2 - 1, 0, - - // Pad3 - 1, 1, 1, - - // Pad4 - 1, 2, 1, 2, - - // Pad5 - 1, 3, 1, 2, 3, - }, - }, - { - name: "Multiple options", - bytes: []byte{ - // Pad1 - 0, - - // Unknown - 255, 0, - - // Pad2 - 1, 0, - - // Unknown - 254, 1, 1, - - // Pad3 - 1, 1, 1, - - // Unknown - 253, 4, 2, 3, 4, 5, - - // Pad4 - 1, 2, 1, 2, - }, - expected: []IPv6ExtHdrOption{ - &IPv6UnknownExtHdrOption{Identifier: 255, Data: []byte{}}, - &IPv6UnknownExtHdrOption{Identifier: 254, Data: []byte{1}}, - &IPv6UnknownExtHdrOption{Identifier: 253, Data: []byte{2, 3, 4, 5}}, - }, - }, - } - - checkIter := func(t *testing.T, it IPv6OptionsExtHdrOptionsIterator, expected []IPv6ExtHdrOption) { - for i, e := range expected { - opt, done, err := it.Next() - if err != nil { - t.Errorf("(i=%d) Next(): %s", i, err) - } - if done { - t.Errorf("(i=%d) unexpectedly done iterating", i) - } - if diff := cmp.Diff(e, opt); diff != "" { - t.Errorf("(i=%d) got option mismatch (-want +got):\n%s", i, diff) - } - - if t.Failed() { - t.FailNow() - } - } - - opt, done, err := it.Next() - if err != nil { - t.Errorf("(last) Next(): %s", err) - } - if !done { - t.Errorf("(last) iterator unexpectedly not done") - } - if opt != nil { - t.Errorf("(last) got Next() = %T, want = nil", opt) - } - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - t.Run("Hop By Hop", func(t *testing.T) { - extHdr := IPv6HopByHopOptionsExtHdr{ipv6OptionsExtHdr: test.bytes} - checkIter(t, extHdr.Iter(), test.expected) - }) - - t.Run("Destination", func(t *testing.T) { - extHdr := IPv6DestinationOptionsExtHdr{ipv6OptionsExtHdr: test.bytes} - checkIter(t, extHdr.Iter(), test.expected) - }) - }) - } -} - -func TestIPv6RoutingExtHdr(t *testing.T) { - tests := []struct { - name string - bytes []byte - segmentsLeft uint8 - }{ - { - name: "Zeroes", - bytes: []byte{0, 0, 0, 0, 0, 0}, - segmentsLeft: 0, - }, - { - name: "Ones", - bytes: []byte{1, 1, 1, 1, 1, 1}, - segmentsLeft: 1, - }, - { - name: "Mixed", - bytes: []byte{1, 2, 3, 4, 5, 6}, - segmentsLeft: 2, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - extHdr := IPv6RoutingExtHdr(test.bytes) - if got := extHdr.SegmentsLeft(); got != test.segmentsLeft { - t.Errorf("got SegmentsLeft() = %d, want = %d", got, test.segmentsLeft) - } - }) - } -} - -func TestIPv6FragmentExtHdr(t *testing.T) { - tests := []struct { - name string - bytes [6]byte - fragmentOffset uint16 - more bool - id uint32 - }{ - { - name: "Zeroes", - bytes: [6]byte{0, 0, 0, 0, 0, 0}, - fragmentOffset: 0, - more: false, - id: 0, - }, - { - name: "Ones", - bytes: [6]byte{0, 9, 0, 0, 0, 1}, - fragmentOffset: 1, - more: true, - id: 1, - }, - { - name: "Mixed", - bytes: [6]byte{68, 9, 128, 4, 2, 1}, - fragmentOffset: 2177, - more: true, - id: 2147746305, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - extHdr := IPv6FragmentExtHdr(test.bytes) - if got := extHdr.FragmentOffset(); got != test.fragmentOffset { - t.Errorf("got FragmentOffset() = %d, want = %d", got, test.fragmentOffset) - } - if got := extHdr.More(); got != test.more { - t.Errorf("got More() = %t, want = %t", got, test.more) - } - if got := extHdr.ID(); got != test.id { - t.Errorf("got ID() = %d, want = %d", got, test.id) - } - }) - } -} - -func makeVectorisedViewFromByteBuffers(bs ...[]byte) buffer.VectorisedView { - size := 0 - var vs []buffer.View - - for _, b := range bs { - vs = append(vs, buffer.View(b)) - size += len(b) - } - - return buffer.NewVectorisedView(size, vs) -} - -func TestIPv6ExtHdrIterErr(t *testing.T) { - tests := []struct { - name string - firstNextHdr IPv6ExtensionHeaderIdentifier - payload buffer.VectorisedView - err error - }{ - { - name: "Upper layer only without data", - firstNextHdr: 255, - }, - { - name: "Upper layer only with data", - firstNextHdr: 255, - payload: makeVectorisedViewFromByteBuffers([]byte{1, 2, 3, 4}), - }, - { - name: "No next header", - firstNextHdr: IPv6NoNextHeaderIdentifier, - }, - { - name: "No next header with data", - firstNextHdr: IPv6NoNextHeaderIdentifier, - payload: makeVectorisedViewFromByteBuffers([]byte{1, 2, 3, 4}), - }, - { - name: "Valid single hop by hop", - firstNextHdr: IPv6HopByHopOptionsExtHdrIdentifier, - payload: makeVectorisedViewFromByteBuffers([]byte{255, 0, 1, 4, 1, 2, 3, 4}), - }, - { - name: "Hop by hop too small", - firstNextHdr: IPv6HopByHopOptionsExtHdrIdentifier, - payload: makeVectorisedViewFromByteBuffers([]byte{255, 0, 1, 4, 1, 2, 3}), - err: io.ErrUnexpectedEOF, - }, - { - name: "Valid single fragment", - firstNextHdr: IPv6FragmentExtHdrIdentifier, - payload: makeVectorisedViewFromByteBuffers([]byte{255, 0, 68, 9, 128, 4, 2, 1}), - }, - { - name: "Fragment too small", - firstNextHdr: IPv6FragmentExtHdrIdentifier, - payload: makeVectorisedViewFromByteBuffers([]byte{255, 0, 68, 9, 128, 4, 2}), - err: io.ErrUnexpectedEOF, - }, - { - name: "Valid single destination", - firstNextHdr: IPv6DestinationOptionsExtHdrIdentifier, - payload: makeVectorisedViewFromByteBuffers([]byte{255, 0, 1, 4, 1, 2, 3, 4}), - }, - { - name: "Destination too small", - firstNextHdr: IPv6DestinationOptionsExtHdrIdentifier, - payload: makeVectorisedViewFromByteBuffers([]byte{255, 0, 1, 4, 1, 2, 3}), - err: io.ErrUnexpectedEOF, - }, - { - name: "Valid single routing", - firstNextHdr: IPv6RoutingExtHdrIdentifier, - payload: makeVectorisedViewFromByteBuffers([]byte{255, 0, 1, 2, 3, 4, 5, 6}), - }, - { - name: "Valid single routing across views", - firstNextHdr: IPv6RoutingExtHdrIdentifier, - payload: makeVectorisedViewFromByteBuffers([]byte{255, 0, 1, 2}, []byte{3, 4, 5, 6}), - }, - { - name: "Routing too small with zero length field", - firstNextHdr: IPv6RoutingExtHdrIdentifier, - payload: makeVectorisedViewFromByteBuffers([]byte{255, 0, 1, 2, 3, 4, 5}), - err: io.ErrUnexpectedEOF, - }, - { - name: "Valid routing with non-zero length field", - firstNextHdr: IPv6RoutingExtHdrIdentifier, - payload: makeVectorisedViewFromByteBuffers([]byte{255, 1, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 7, 8}), - }, - { - name: "Valid routing with non-zero length field across views", - firstNextHdr: IPv6RoutingExtHdrIdentifier, - payload: makeVectorisedViewFromByteBuffers([]byte{255, 1, 1, 2, 3, 4, 5, 6}, []byte{1, 2, 3, 4, 5, 6, 7, 8}), - }, - { - name: "Routing too small with non-zero length field", - firstNextHdr: IPv6RoutingExtHdrIdentifier, - payload: makeVectorisedViewFromByteBuffers([]byte{255, 1, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 7}), - err: io.ErrUnexpectedEOF, - }, - { - name: "Routing too small with non-zero length field across views", - firstNextHdr: IPv6RoutingExtHdrIdentifier, - payload: makeVectorisedViewFromByteBuffers([]byte{255, 1, 1, 2, 3, 4, 5, 6}, []byte{1, 2, 3, 4, 5, 6, 7}), - err: io.ErrUnexpectedEOF, - }, - { - name: "Mixed", - firstNextHdr: IPv6HopByHopOptionsExtHdrIdentifier, - payload: makeVectorisedViewFromByteBuffers([]byte{ - // Hop By Hop Options extension header. - uint8(IPv6FragmentExtHdrIdentifier), 0, 1, 4, 1, 2, 3, 4, - - // (Atomic) Fragment extension header. - // - // Reserved bits are 1 which should not affect anything. - uint8(IPv6RoutingExtHdrIdentifier), 255, 0, 6, 128, 4, 2, 1, - - // Routing extension header. - uint8(IPv6DestinationOptionsExtHdrIdentifier), 0, 1, 2, 3, 4, 5, 6, - - // Destination Options extension header. - 255, 0, 255, 4, 1, 2, 3, 4, - - // Upper layer data. - 1, 2, 3, 4, - }), - }, - { - name: "Mixed without upper layer data", - firstNextHdr: IPv6HopByHopOptionsExtHdrIdentifier, - payload: makeVectorisedViewFromByteBuffers([]byte{ - // Hop By Hop Options extension header. - uint8(IPv6FragmentExtHdrIdentifier), 0, 1, 4, 1, 2, 3, 4, - - // (Atomic) Fragment extension header. - // - // Reserved bits are 1 which should not affect anything. - uint8(IPv6RoutingExtHdrIdentifier), 255, 0, 6, 128, 4, 2, 1, - - // Routing extension header. - uint8(IPv6DestinationOptionsExtHdrIdentifier), 0, 1, 2, 3, 4, 5, 6, - - // Destination Options extension header. - 255, 0, 255, 4, 1, 2, 3, 4, - }), - }, - { - name: "Mixed without upper layer data but last ext hdr too small", - firstNextHdr: IPv6HopByHopOptionsExtHdrIdentifier, - payload: makeVectorisedViewFromByteBuffers([]byte{ - // Hop By Hop Options extension header. - uint8(IPv6FragmentExtHdrIdentifier), 0, 1, 4, 1, 2, 3, 4, - - // (Atomic) Fragment extension header. - // - // Reserved bits are 1 which should not affect anything. - uint8(IPv6RoutingExtHdrIdentifier), 255, 0, 6, 128, 4, 2, 1, - - // Routing extension header. - uint8(IPv6DestinationOptionsExtHdrIdentifier), 0, 1, 2, 3, 4, 5, 6, - - // Destination Options extension header. - 255, 0, 255, 4, 1, 2, 3, - }), - err: io.ErrUnexpectedEOF, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - it := MakeIPv6PayloadIterator(test.firstNextHdr, test.payload) - - for i := 0; ; i++ { - _, done, err := it.Next() - if err != nil { - // If we encountered a non-nil error while iterating, make sure it is - // is the same error as test.err. - if !errors.Is(err, test.err) { - t.Fatalf("got %d-th Next() = %v, want = %v", i, err, test.err) - } - - return - } - if done { - // If we are done (without an error), make sure that we did not expect - // an error. - if test.err != nil { - t.Fatalf("expected error when iterating; want = %s", test.err) - } - - return - } - } - }) - } -} - -func TestIPv6ExtHdrIter(t *testing.T) { - routingExtHdrWithUpperLayerData := buffer.View([]byte{255, 0, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4}) - upperLayerData := buffer.View([]byte{1, 2, 3, 4}) - tests := []struct { - name string - firstNextHdr IPv6ExtensionHeaderIdentifier - payload buffer.VectorisedView - expected []IPv6PayloadHeader - }{ - // With a non-atomic fragment that is not the first fragment, the payload - // after the fragment will not be parsed because the payload is expected to - // only hold upper layer data. - { - name: "hopbyhop - fragment (not first) - routing - upper", - firstNextHdr: IPv6HopByHopOptionsExtHdrIdentifier, - payload: makeVectorisedViewFromByteBuffers([]byte{ - // Hop By Hop extension header. - uint8(IPv6FragmentExtHdrIdentifier), 0, 1, 4, 1, 2, 3, 4, - - // Fragment extension header. - // - // More = 1, Fragment Offset = 2117, ID = 2147746305 - uint8(IPv6RoutingExtHdrIdentifier), 0, 68, 9, 128, 4, 2, 1, - - // Routing extension header. - // - // Even though we have a routing ext header here, it should be - // be interpretted as raw bytes as only the first fragment is expected - // to hold headers. - 255, 0, 1, 2, 3, 4, 5, 6, - - // Upper layer data. - 1, 2, 3, 4, - }), - expected: []IPv6PayloadHeader{ - IPv6HopByHopOptionsExtHdr{ipv6OptionsExtHdr: []byte{1, 4, 1, 2, 3, 4}}, - IPv6FragmentExtHdr([6]byte{68, 9, 128, 4, 2, 1}), - IPv6RawPayloadHeader{ - Identifier: IPv6RoutingExtHdrIdentifier, - Buf: routingExtHdrWithUpperLayerData.ToVectorisedView(), - }, - }, - }, - { - name: "hopbyhop - fragment (first) - routing - upper", - firstNextHdr: IPv6HopByHopOptionsExtHdrIdentifier, - payload: makeVectorisedViewFromByteBuffers([]byte{ - // Hop By Hop extension header. - uint8(IPv6FragmentExtHdrIdentifier), 0, 1, 4, 1, 2, 3, 4, - - // Fragment extension header. - // - // More = 1, Fragment Offset = 0, ID = 2147746305 - uint8(IPv6RoutingExtHdrIdentifier), 0, 0, 1, 128, 4, 2, 1, - - // Routing extension header. - 255, 0, 1, 2, 3, 4, 5, 6, - - // Upper layer data. - 1, 2, 3, 4, - }), - expected: []IPv6PayloadHeader{ - IPv6HopByHopOptionsExtHdr{ipv6OptionsExtHdr: []byte{1, 4, 1, 2, 3, 4}}, - IPv6FragmentExtHdr([6]byte{0, 1, 128, 4, 2, 1}), - IPv6RoutingExtHdr([]byte{1, 2, 3, 4, 5, 6}), - IPv6RawPayloadHeader{ - Identifier: 255, - Buf: upperLayerData.ToVectorisedView(), - }, - }, - }, - { - name: "fragment - routing - upper (across views)", - firstNextHdr: IPv6FragmentExtHdrIdentifier, - payload: makeVectorisedViewFromByteBuffers([]byte{ - // Fragment extension header. - uint8(IPv6RoutingExtHdrIdentifier), 0, 68, 9, 128, 4, 2, 1, - - // Routing extension header. - 255, 0, 1, 2}, []byte{3, 4, 5, 6, - - // Upper layer data. - 1, 2, 3, 4, - }), - expected: []IPv6PayloadHeader{ - IPv6FragmentExtHdr([6]byte{68, 9, 128, 4, 2, 1}), - IPv6RawPayloadHeader{ - Identifier: IPv6RoutingExtHdrIdentifier, - Buf: routingExtHdrWithUpperLayerData.ToVectorisedView(), - }, - }, - }, - - // If we have an atomic fragment, the payload following the fragment - // extension header should be parsed normally. - { - name: "atomic fragment - routing - destination - upper", - firstNextHdr: IPv6FragmentExtHdrIdentifier, - payload: makeVectorisedViewFromByteBuffers([]byte{ - // Fragment extension header. - // - // Reserved bits are 1 which should not affect anything. - uint8(IPv6RoutingExtHdrIdentifier), 255, 0, 6, 128, 4, 2, 1, - - // Routing extension header. - uint8(IPv6DestinationOptionsExtHdrIdentifier), 0, 1, 2, 3, 4, 5, 6, - - // Destination Options extension header. - 255, 0, 1, 4, 1, 2, 3, 4, - - // Upper layer data. - 1, 2, 3, 4, - }), - expected: []IPv6PayloadHeader{ - IPv6FragmentExtHdr([6]byte{0, 6, 128, 4, 2, 1}), - IPv6RoutingExtHdr([]byte{1, 2, 3, 4, 5, 6}), - IPv6DestinationOptionsExtHdr{ipv6OptionsExtHdr: []byte{1, 4, 1, 2, 3, 4}}, - IPv6RawPayloadHeader{ - Identifier: 255, - Buf: upperLayerData.ToVectorisedView(), - }, - }, - }, - { - name: "atomic fragment - routing - upper (across views)", - firstNextHdr: IPv6FragmentExtHdrIdentifier, - payload: makeVectorisedViewFromByteBuffers([]byte{ - // Fragment extension header. - // - // Reserved bits are 1 which should not affect anything. - uint8(IPv6RoutingExtHdrIdentifier), 255, 0, 6}, []byte{128, 4, 2, 1, - - // Routing extension header. - 255, 0, 1, 2}, []byte{3, 4, 5, 6, - - // Upper layer data. - 1, 2}, []byte{3, 4}), - expected: []IPv6PayloadHeader{ - IPv6FragmentExtHdr([6]byte{0, 6, 128, 4, 2, 1}), - IPv6RoutingExtHdr([]byte{1, 2, 3, 4, 5, 6}), - IPv6RawPayloadHeader{ - Identifier: 255, - Buf: makeVectorisedViewFromByteBuffers(upperLayerData[:2], upperLayerData[2:]), - }, - }, - }, - { - name: "atomic fragment - destination - no next header", - firstNextHdr: IPv6FragmentExtHdrIdentifier, - payload: makeVectorisedViewFromByteBuffers([]byte{ - // Fragment extension header. - // - // Res (Reserved) bits are 1 which should not affect anything. - uint8(IPv6DestinationOptionsExtHdrIdentifier), 0, 0, 6, 128, 4, 2, 1, - - // Destination Options extension header. - uint8(IPv6NoNextHeaderIdentifier), 0, 1, 4, 1, 2, 3, 4, - - // Random data. - 1, 2, 3, 4, - }), - expected: []IPv6PayloadHeader{ - IPv6FragmentExtHdr([6]byte{0, 6, 128, 4, 2, 1}), - IPv6DestinationOptionsExtHdr{ipv6OptionsExtHdr: []byte{1, 4, 1, 2, 3, 4}}, - }, - }, - { - name: "routing - atomic fragment - no next header", - firstNextHdr: IPv6RoutingExtHdrIdentifier, - payload: makeVectorisedViewFromByteBuffers([]byte{ - // Routing extension header. - uint8(IPv6FragmentExtHdrIdentifier), 0, 1, 2, 3, 4, 5, 6, - - // Fragment extension header. - // - // Reserved bits are 1 which should not affect anything. - uint8(IPv6NoNextHeaderIdentifier), 0, 0, 6, 128, 4, 2, 1, - - // Random data. - 1, 2, 3, 4, - }), - expected: []IPv6PayloadHeader{ - IPv6RoutingExtHdr([]byte{1, 2, 3, 4, 5, 6}), - IPv6FragmentExtHdr([6]byte{0, 6, 128, 4, 2, 1}), - }, - }, - { - name: "routing - atomic fragment - no next header (across views)", - firstNextHdr: IPv6RoutingExtHdrIdentifier, - payload: makeVectorisedViewFromByteBuffers([]byte{ - // Routing extension header. - uint8(IPv6FragmentExtHdrIdentifier), 0, 1, 2, 3, 4, 5, 6, - - // Fragment extension header. - // - // Reserved bits are 1 which should not affect anything. - uint8(IPv6NoNextHeaderIdentifier), 255, 0, 6}, []byte{128, 4, 2, 1, - - // Random data. - 1, 2, 3, 4, - }), - expected: []IPv6PayloadHeader{ - IPv6RoutingExtHdr([]byte{1, 2, 3, 4, 5, 6}), - IPv6FragmentExtHdr([6]byte{0, 6, 128, 4, 2, 1}), - }, - }, - { - name: "hopbyhop - routing - fragment - no next header", - firstNextHdr: IPv6HopByHopOptionsExtHdrIdentifier, - payload: makeVectorisedViewFromByteBuffers([]byte{ - // Hop By Hop Options extension header. - uint8(IPv6RoutingExtHdrIdentifier), 0, 1, 4, 1, 2, 3, 4, - - // Routing extension header. - uint8(IPv6FragmentExtHdrIdentifier), 0, 1, 2, 3, 4, 5, 6, - - // Fragment extension header. - // - // Fragment Offset = 32; Res = 6. - uint8(IPv6NoNextHeaderIdentifier), 0, 1, 6, 128, 4, 2, 1, - - // Random data. - 1, 2, 3, 4, - }), - expected: []IPv6PayloadHeader{ - IPv6HopByHopOptionsExtHdr{ipv6OptionsExtHdr: []byte{1, 4, 1, 2, 3, 4}}, - IPv6RoutingExtHdr([]byte{1, 2, 3, 4, 5, 6}), - IPv6FragmentExtHdr([6]byte{1, 6, 128, 4, 2, 1}), - IPv6RawPayloadHeader{ - Identifier: IPv6NoNextHeaderIdentifier, - Buf: upperLayerData.ToVectorisedView(), - }, - }, - }, - - // Test the raw payload for common transport layer protocol numbers. - { - name: "TCP raw payload", - firstNextHdr: IPv6ExtensionHeaderIdentifier(TCPProtocolNumber), - payload: makeVectorisedViewFromByteBuffers(upperLayerData), - expected: []IPv6PayloadHeader{IPv6RawPayloadHeader{ - Identifier: IPv6ExtensionHeaderIdentifier(TCPProtocolNumber), - Buf: upperLayerData.ToVectorisedView(), - }}, - }, - { - name: "UDP raw payload", - firstNextHdr: IPv6ExtensionHeaderIdentifier(UDPProtocolNumber), - payload: makeVectorisedViewFromByteBuffers(upperLayerData), - expected: []IPv6PayloadHeader{IPv6RawPayloadHeader{ - Identifier: IPv6ExtensionHeaderIdentifier(UDPProtocolNumber), - Buf: upperLayerData.ToVectorisedView(), - }}, - }, - { - name: "ICMPv4 raw payload", - firstNextHdr: IPv6ExtensionHeaderIdentifier(ICMPv4ProtocolNumber), - payload: makeVectorisedViewFromByteBuffers(upperLayerData), - expected: []IPv6PayloadHeader{IPv6RawPayloadHeader{ - Identifier: IPv6ExtensionHeaderIdentifier(ICMPv4ProtocolNumber), - Buf: upperLayerData.ToVectorisedView(), - }}, - }, - { - name: "ICMPv6 raw payload", - firstNextHdr: IPv6ExtensionHeaderIdentifier(ICMPv6ProtocolNumber), - payload: makeVectorisedViewFromByteBuffers(upperLayerData), - expected: []IPv6PayloadHeader{IPv6RawPayloadHeader{ - Identifier: IPv6ExtensionHeaderIdentifier(ICMPv6ProtocolNumber), - Buf: upperLayerData.ToVectorisedView(), - }}, - }, - { - name: "Unknwon next header raw payload", - firstNextHdr: 255, - payload: makeVectorisedViewFromByteBuffers(upperLayerData), - expected: []IPv6PayloadHeader{IPv6RawPayloadHeader{ - Identifier: 255, - Buf: upperLayerData.ToVectorisedView(), - }}, - }, - { - name: "Unknwon next header raw payload (across views)", - firstNextHdr: 255, - payload: makeVectorisedViewFromByteBuffers(upperLayerData[:2], upperLayerData[2:]), - expected: []IPv6PayloadHeader{IPv6RawPayloadHeader{ - Identifier: 255, - Buf: makeVectorisedViewFromByteBuffers(upperLayerData[:2], upperLayerData[2:]), - }}, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - it := MakeIPv6PayloadIterator(test.firstNextHdr, test.payload) - - for i, e := range test.expected { - extHdr, done, err := it.Next() - if err != nil { - t.Errorf("(i=%d) Next(): %s", i, err) - } - if done { - t.Errorf("(i=%d) unexpectedly done iterating", i) - } - if diff := cmp.Diff(e, extHdr); diff != "" { - t.Errorf("(i=%d) got ext hdr mismatch (-want +got):\n%s", i, diff) - } - - if t.Failed() { - t.FailNow() - } - } - - extHdr, done, err := it.Next() - if err != nil { - t.Errorf("(last) Next(): %s", err) - } - if !done { - t.Errorf("(last) iterator unexpectedly not done") - } - if extHdr != nil { - t.Errorf("(last) got Next() = %T, want = nil", extHdr) - } - }) - } -} - -var _ IPv6SerializableHopByHopOption = (*dummyHbHOptionSerializer)(nil) - -// dummyHbHOptionSerializer provides a generic implementation of -// IPv6SerializableHopByHopOption for use in tests. -type dummyHbHOptionSerializer struct { - id IPv6ExtHdrOptionIdentifier - payload []byte - align int - alignOffset int -} - -// identifier implements IPv6SerializableHopByHopOption. -func (s *dummyHbHOptionSerializer) identifier() IPv6ExtHdrOptionIdentifier { - return s.id -} - -// length implements IPv6SerializableHopByHopOption. -func (s *dummyHbHOptionSerializer) length() uint8 { - return uint8(len(s.payload)) -} - -// alignment implements IPv6SerializableHopByHopOption. -func (s *dummyHbHOptionSerializer) alignment() (int, int) { - align := 1 - if s.align != 0 { - align = s.align - } - return align, s.alignOffset -} - -// serializeInto implements IPv6SerializableHopByHopOption. -func (s *dummyHbHOptionSerializer) serializeInto(b []byte) uint8 { - return uint8(copy(b, s.payload)) -} - -func TestIPv6HopByHopSerializer(t *testing.T) { - validateDummies := func(t *testing.T, serializable IPv6SerializableHopByHopOption, deserialized IPv6ExtHdrOption) { - t.Helper() - dummy, ok := serializable.(*dummyHbHOptionSerializer) - if !ok { - t.Fatalf("got serializable = %T, want = *dummyHbHOptionSerializer", serializable) - } - unknown, ok := deserialized.(*IPv6UnknownExtHdrOption) - if !ok { - t.Fatalf("got deserialized = %T, want = %T", deserialized, &IPv6UnknownExtHdrOption{}) - } - if dummy.id != unknown.Identifier { - t.Errorf("got deserialized identifier = %d, want = %d", unknown.Identifier, dummy.id) - } - if diff := cmp.Diff(dummy.payload, unknown.Data); diff != "" { - t.Errorf("option payload deserialization mismatch (-want +got):\n%s", diff) - } - } - tests := []struct { - name string - nextHeader uint8 - options []IPv6SerializableHopByHopOption - expect []byte - validate func(*testing.T, IPv6SerializableHopByHopOption, IPv6ExtHdrOption) - }{ - { - name: "single option", - nextHeader: 13, - options: []IPv6SerializableHopByHopOption{ - &dummyHbHOptionSerializer{ - id: 15, - payload: []byte{9, 8, 7, 6}, - }, - }, - expect: []byte{13, 0, 15, 4, 9, 8, 7, 6}, - validate: validateDummies, - }, - { - name: "short option padN zero", - nextHeader: 88, - options: []IPv6SerializableHopByHopOption{ - &dummyHbHOptionSerializer{ - id: 22, - payload: []byte{4, 5}, - }, - }, - expect: []byte{88, 0, 22, 2, 4, 5, 1, 0}, - validate: validateDummies, - }, - { - name: "short option pad1", - nextHeader: 11, - options: []IPv6SerializableHopByHopOption{ - &dummyHbHOptionSerializer{ - id: 33, - payload: []byte{1, 2, 3}, - }, - }, - expect: []byte{11, 0, 33, 3, 1, 2, 3, 0}, - validate: validateDummies, - }, - { - name: "long option padN", - nextHeader: 55, - options: []IPv6SerializableHopByHopOption{ - &dummyHbHOptionSerializer{ - id: 77, - payload: []byte{1, 2, 3, 4, 5, 6, 7, 8}, - }, - }, - expect: []byte{55, 1, 77, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 0, 0}, - validate: validateDummies, - }, - { - name: "two options", - nextHeader: 33, - options: []IPv6SerializableHopByHopOption{ - &dummyHbHOptionSerializer{ - id: 11, - payload: []byte{1, 2, 3}, - }, - &dummyHbHOptionSerializer{ - id: 22, - payload: []byte{4, 5, 6}, - }, - }, - expect: []byte{33, 1, 11, 3, 1, 2, 3, 22, 3, 4, 5, 6, 1, 2, 0, 0}, - validate: validateDummies, - }, - { - name: "two options align 2n", - nextHeader: 33, - options: []IPv6SerializableHopByHopOption{ - &dummyHbHOptionSerializer{ - id: 11, - payload: []byte{1, 2, 3}, - }, - &dummyHbHOptionSerializer{ - id: 22, - payload: []byte{4, 5, 6}, - align: 2, - }, - }, - expect: []byte{33, 1, 11, 3, 1, 2, 3, 0, 22, 3, 4, 5, 6, 1, 1, 0}, - validate: validateDummies, - }, - { - name: "two options align 8n+1", - nextHeader: 33, - options: []IPv6SerializableHopByHopOption{ - &dummyHbHOptionSerializer{ - id: 11, - payload: []byte{1, 2}, - }, - &dummyHbHOptionSerializer{ - id: 22, - payload: []byte{4, 5, 6}, - align: 8, - alignOffset: 1, - }, - }, - expect: []byte{33, 1, 11, 2, 1, 2, 1, 1, 0, 22, 3, 4, 5, 6, 1, 0}, - validate: validateDummies, - }, - { - name: "no options", - nextHeader: 33, - options: []IPv6SerializableHopByHopOption{}, - expect: []byte{33, 0, 1, 4, 0, 0, 0, 0}, - }, - { - name: "Router Alert", - nextHeader: 33, - options: []IPv6SerializableHopByHopOption{&IPv6RouterAlertOption{Value: IPv6RouterAlertMLD}}, - expect: []byte{33, 0, 5, 2, 0, 0, 1, 0}, - validate: func(t *testing.T, _ IPv6SerializableHopByHopOption, deserialized IPv6ExtHdrOption) { - t.Helper() - routerAlert, ok := deserialized.(*IPv6RouterAlertOption) - if !ok { - t.Fatalf("got deserialized = %T, want = *IPv6RouterAlertOption", deserialized) - } - if routerAlert.Value != IPv6RouterAlertMLD { - t.Errorf("got routerAlert.Value = %d, want = %d", routerAlert.Value, IPv6RouterAlertMLD) - } - }, - }, - } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := IPv6SerializableHopByHopExtHdr(test.options) - length := s.length() - if length != len(test.expect) { - t.Fatalf("got s.length() = %d, want = %d", length, len(test.expect)) - } - b := make([]byte, length) - for i := range b { - // Fill the buffer with ones to ensure all padding is correctly set. - b[i] = 0xFF - } - if got := s.serializeInto(test.nextHeader, b); got != length { - t.Fatalf("got s.serializeInto(..) = %d, want = %d", got, length) - } - if diff := cmp.Diff(test.expect, b); diff != "" { - t.Fatalf("serialization mismatch (-want +got):\n%s", diff) - } - - // Deserialize the options and verify them. - optLen := (b[ipv6HopByHopExtHdrLengthOffset] + ipv6HopByHopExtHdrUnaccountedLenWords) * ipv6ExtHdrLenBytesPerUnit - iter := ipv6OptionsExtHdr(b[ipv6HopByHopExtHdrOptionsOffset:optLen]).Iter() - for _, testOpt := range test.options { - opt, done, err := iter.Next() - if err != nil { - t.Fatalf("iter.Next(): %s", err) - } - if done { - t.Fatalf("got iter.Next() = (%T, %t, _), want = (_, false, _)", opt, done) - } - test.validate(t, testOpt, opt) - } - opt, done, err := iter.Next() - if err != nil { - t.Fatalf("iter.Next(): %s", err) - } - if !done { - t.Fatalf("got iter.Next() = (%T, %t, _), want = (_, true, _)", opt, done) - } - }) - } -} - -var _ IPv6SerializableExtHdr = (*dummyIPv6ExtHdrSerializer)(nil) - -// dummyIPv6ExtHdrSerializer provides a generic implementation of -// IPv6SerializableExtHdr for use in tests. -// -// The dummy header always carries the nextHeader value in the first byte. -type dummyIPv6ExtHdrSerializer struct { - id IPv6ExtensionHeaderIdentifier - headerContents []byte -} - -// identifier implements IPv6SerializableExtHdr. -func (s *dummyIPv6ExtHdrSerializer) identifier() IPv6ExtensionHeaderIdentifier { - return s.id -} - -// length implements IPv6SerializableExtHdr. -func (s *dummyIPv6ExtHdrSerializer) length() int { - return len(s.headerContents) + 1 -} - -// serializeInto implements IPv6SerializableExtHdr. -func (s *dummyIPv6ExtHdrSerializer) serializeInto(nextHeader uint8, b []byte) int { - b[0] = nextHeader - return copy(b[1:], s.headerContents) + 1 -} - -func TestIPv6ExtHdrSerializer(t *testing.T) { - tests := []struct { - name string - headers []IPv6SerializableExtHdr - nextHeader tcpip.TransportProtocolNumber - expectSerialized []byte - expectNextHeader uint8 - }{ - { - name: "one header", - headers: []IPv6SerializableExtHdr{ - &dummyIPv6ExtHdrSerializer{ - id: 15, - headerContents: []byte{1, 2, 3, 4}, - }, - }, - nextHeader: TCPProtocolNumber, - expectSerialized: []byte{byte(TCPProtocolNumber), 1, 2, 3, 4}, - expectNextHeader: 15, - }, - { - name: "two headers", - headers: []IPv6SerializableExtHdr{ - &dummyIPv6ExtHdrSerializer{ - id: 22, - headerContents: []byte{1, 2, 3}, - }, - &dummyIPv6ExtHdrSerializer{ - id: 23, - headerContents: []byte{4, 5, 6}, - }, - }, - nextHeader: ICMPv6ProtocolNumber, - expectSerialized: []byte{ - 23, 1, 2, 3, - byte(ICMPv6ProtocolNumber), 4, 5, 6, - }, - expectNextHeader: 22, - }, - { - name: "no headers", - headers: []IPv6SerializableExtHdr{}, - nextHeader: UDPProtocolNumber, - expectSerialized: []byte{}, - expectNextHeader: byte(UDPProtocolNumber), - }, - } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := IPv6ExtHdrSerializer(test.headers) - l := s.Length() - if got, want := l, len(test.expectSerialized); got != want { - t.Fatalf("got serialized length = %d, want = %d", got, want) - } - b := make([]byte, l) - for i := range b { - // Fill the buffer with garbage to make sure we're writing to all bytes. - b[i] = 0xFF - } - nextHeader, serializedLen := s.Serialize(test.nextHeader, b) - if serializedLen != len(test.expectSerialized) || nextHeader != test.expectNextHeader { - t.Errorf( - "got s.Serialize(..) = (%d, %d), want = (%d, %d)", - nextHeader, - serializedLen, - test.expectNextHeader, - len(test.expectSerialized), - ) - } - if diff := cmp.Diff(test.expectSerialized, b); diff != "" { - t.Errorf("serialization mismatch (-want +got):\n%s", diff) - } - }) - } -} diff --git a/pkg/tcpip/header/ipv6_test.go b/pkg/tcpip/header/ipv6_test.go deleted file mode 100644 index 89be84068..000000000 --- a/pkg/tcpip/header/ipv6_test.go +++ /dev/null @@ -1,457 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package header_test - -import ( - "bytes" - "crypto/sha256" - "fmt" - "testing" - - "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/rand" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/testutil" -) - -const linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") - -var ( - linkLocalAddr = testutil.MustParse6("fe80::1") - linkLocalMulticastAddr = testutil.MustParse6("ff02::1") - uniqueLocalAddr1 = testutil.MustParse6("fc00::1") - uniqueLocalAddr2 = testutil.MustParse6("fd00::2") - globalAddr = testutil.MustParse6("a000::1") -) - -func TestEthernetAdddressToModifiedEUI64(t *testing.T) { - expectedIID := [header.IIDSize]byte{0, 2, 3, 255, 254, 4, 5, 6} - - if diff := cmp.Diff(expectedIID, header.EthernetAddressToModifiedEUI64(linkAddr)); diff != "" { - t.Errorf("EthernetAddressToModifiedEUI64(%s) mismatch (-want +got):\n%s", linkAddr, diff) - } - - var buf [header.IIDSize]byte - header.EthernetAdddressToModifiedEUI64IntoBuf(linkAddr, buf[:]) - if diff := cmp.Diff(expectedIID, buf); diff != "" { - t.Errorf("EthernetAddressToModifiedEUI64IntoBuf(%s, _) mismatch (-want +got):\n%s", linkAddr, diff) - } -} - -func TestLinkLocalAddr(t *testing.T) { - if got, want := header.LinkLocalAddr(linkAddr), testutil.MustParse6("fe80::2:3ff:fe04:506"); got != want { - t.Errorf("got LinkLocalAddr(%s) = %s, want = %s", linkAddr, got, want) - } -} - -func TestAppendOpaqueInterfaceIdentifier(t *testing.T) { - var secretKeyBuf [header.OpaqueIIDSecretKeyMinBytes * 2]byte - if n, err := rand.Read(secretKeyBuf[:]); err != nil { - t.Fatalf("rand.Read(_): %s", err) - } else if want := header.OpaqueIIDSecretKeyMinBytes * 2; n != want { - t.Fatalf("expected rand.Read to read %d bytes, read %d bytes", want, n) - } - - tests := []struct { - name string - prefix tcpip.Subnet - nicName string - dadCounter uint8 - secretKey []byte - }{ - { - name: "SecretKey of minimum size", - prefix: header.IPv6LinkLocalPrefix.Subnet(), - nicName: "eth0", - dadCounter: 0, - secretKey: secretKeyBuf[:header.OpaqueIIDSecretKeyMinBytes], - }, - { - name: "SecretKey of less than minimum size", - prefix: func() tcpip.Subnet { - addrWithPrefix := tcpip.AddressWithPrefix{ - Address: "\x01\x02\x03\x03\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", - PrefixLen: header.IIDOffsetInIPv6Address * 8, - } - return addrWithPrefix.Subnet() - }(), - nicName: "eth10", - dadCounter: 1, - secretKey: secretKeyBuf[:header.OpaqueIIDSecretKeyMinBytes/2], - }, - { - name: "SecretKey of more than minimum size", - prefix: func() tcpip.Subnet { - addrWithPrefix := tcpip.AddressWithPrefix{ - Address: "\x01\x02\x03\x04\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", - PrefixLen: header.IIDOffsetInIPv6Address * 8, - } - return addrWithPrefix.Subnet() - }(), - nicName: "eth11", - dadCounter: 2, - secretKey: secretKeyBuf[:header.OpaqueIIDSecretKeyMinBytes*2], - }, - { - name: "Nil SecretKey and empty nicName", - prefix: func() tcpip.Subnet { - addrWithPrefix := tcpip.AddressWithPrefix{ - Address: "\x01\x02\x03\x05\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", - PrefixLen: header.IIDOffsetInIPv6Address * 8, - } - return addrWithPrefix.Subnet() - }(), - nicName: "", - dadCounter: 3, - secretKey: nil, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - h := sha256.New() - h.Write([]byte(test.prefix.ID()[:header.IIDOffsetInIPv6Address])) - h.Write([]byte(test.nicName)) - h.Write([]byte{test.dadCounter}) - if k := test.secretKey; k != nil { - h.Write(k) - } - var hashSum [sha256.Size]byte - h.Sum(hashSum[:0]) - want := hashSum[:header.IIDSize] - - // Passing a nil buffer should result in a new buffer returned with the - // IID. - if got := header.AppendOpaqueInterfaceIdentifier(nil, test.prefix, test.nicName, test.dadCounter, test.secretKey); !bytes.Equal(got, want) { - t.Errorf("got AppendOpaqueInterfaceIdentifier(nil, %s, %s, %d, %x) = %x, want = %x", test.prefix, test.nicName, test.dadCounter, test.secretKey, got, want) - } - - // Passing a buffer with sufficient capacity for the IID should populate - // the buffer provided. - var iidBuf [header.IIDSize]byte - if got := header.AppendOpaqueInterfaceIdentifier(iidBuf[:0], test.prefix, test.nicName, test.dadCounter, test.secretKey); !bytes.Equal(got, want) { - t.Errorf("got AppendOpaqueInterfaceIdentifier(iidBuf[:0], %s, %s, %d, %x) = %x, want = %x", test.prefix, test.nicName, test.dadCounter, test.secretKey, got, want) - } - if got := iidBuf[:]; !bytes.Equal(got, want) { - t.Errorf("got iidBuf = %x, want = %x", got, want) - } - }) - } -} - -func TestLinkLocalAddrWithOpaqueIID(t *testing.T) { - var secretKeyBuf [header.OpaqueIIDSecretKeyMinBytes * 2]byte - if n, err := rand.Read(secretKeyBuf[:]); err != nil { - t.Fatalf("rand.Read(_): %s", err) - } else if want := header.OpaqueIIDSecretKeyMinBytes * 2; n != want { - t.Fatalf("expected rand.Read to read %d bytes, read %d bytes", want, n) - } - - prefix := header.IPv6LinkLocalPrefix.Subnet() - - tests := []struct { - name string - prefix tcpip.Subnet - nicName string - dadCounter uint8 - secretKey []byte - }{ - { - name: "SecretKey of minimum size", - nicName: "eth0", - dadCounter: 0, - secretKey: secretKeyBuf[:header.OpaqueIIDSecretKeyMinBytes], - }, - { - name: "SecretKey of less than minimum size", - nicName: "eth10", - dadCounter: 1, - secretKey: secretKeyBuf[:header.OpaqueIIDSecretKeyMinBytes/2], - }, - { - name: "SecretKey of more than minimum size", - nicName: "eth11", - dadCounter: 2, - secretKey: secretKeyBuf[:header.OpaqueIIDSecretKeyMinBytes*2], - }, - { - name: "Nil SecretKey and empty nicName", - nicName: "", - dadCounter: 3, - secretKey: nil, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - addrBytes := [header.IPv6AddressSize]byte{ - 0: 0xFE, - 1: 0x80, - } - - want := tcpip.Address(header.AppendOpaqueInterfaceIdentifier( - addrBytes[:header.IIDOffsetInIPv6Address], - prefix, - test.nicName, - test.dadCounter, - test.secretKey, - )) - - if got := header.LinkLocalAddrWithOpaqueIID(test.nicName, test.dadCounter, test.secretKey); got != want { - t.Errorf("got LinkLocalAddrWithOpaqueIID(%s, %d, %x) = %s, want = %s", test.nicName, test.dadCounter, test.secretKey, got, want) - } - }) - } -} - -func TestIsV6LinkLocalMulticastAddress(t *testing.T) { - tests := []struct { - name string - addr tcpip.Address - expected bool - }{ - { - name: "Valid Link Local Multicast", - addr: linkLocalMulticastAddr, - expected: true, - }, - { - name: "Valid Link Local Multicast with flags", - addr: "\xff\xf2\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", - expected: true, - }, - { - name: "Link Local Unicast", - addr: linkLocalAddr, - expected: false, - }, - { - name: "IPv4 Multicast", - addr: "\xe0\x00\x00\x01", - expected: false, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - if got := header.IsV6LinkLocalMulticastAddress(test.addr); got != test.expected { - t.Errorf("got header.IsV6LinkLocalMulticastAddress(%s) = %t, want = %t", test.addr, got, test.expected) - } - }) - } -} - -func TestIsV6LinkLocalUnicastAddress(t *testing.T) { - tests := []struct { - name string - addr tcpip.Address - expected bool - }{ - { - name: "Valid Link Local Unicast", - addr: linkLocalAddr, - expected: true, - }, - { - name: "Link Local Multicast", - addr: linkLocalMulticastAddr, - expected: false, - }, - { - name: "Unique Local", - addr: uniqueLocalAddr1, - expected: false, - }, - { - name: "Global", - addr: globalAddr, - expected: false, - }, - { - name: "IPv4 Link Local", - addr: "\xa9\xfe\x00\x01", - expected: false, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - if got := header.IsV6LinkLocalUnicastAddress(test.addr); got != test.expected { - t.Errorf("got header.IsV6LinkLocalUnicastAddress(%s) = %t, want = %t", test.addr, got, test.expected) - } - }) - } -} - -func TestScopeForIPv6Address(t *testing.T) { - tests := []struct { - name string - addr tcpip.Address - scope header.IPv6AddressScope - err tcpip.Error - }{ - { - name: "Unique Local", - addr: uniqueLocalAddr1, - scope: header.GlobalScope, - err: nil, - }, - { - name: "Link Local Unicast", - addr: linkLocalAddr, - scope: header.LinkLocalScope, - err: nil, - }, - { - name: "Link Local Multicast", - addr: linkLocalMulticastAddr, - scope: header.LinkLocalScope, - err: nil, - }, - { - name: "Global", - addr: globalAddr, - scope: header.GlobalScope, - err: nil, - }, - { - name: "IPv4", - addr: "\x01\x02\x03\x04", - scope: header.GlobalScope, - err: &tcpip.ErrBadAddress{}, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - got, err := header.ScopeForIPv6Address(test.addr) - if diff := cmp.Diff(test.err, err); diff != "" { - t.Errorf("unexpected error from header.IsV6UniqueLocalAddress(%s), (-want, +got):\n%s", test.addr, diff) - } - if got != test.scope { - t.Errorf("got header.IsV6UniqueLocalAddress(%s) = (%d, _), want = (%d, _)", test.addr, got, test.scope) - } - }) - } -} - -func TestSolicitedNodeAddr(t *testing.T) { - tests := []struct { - addr tcpip.Address - want tcpip.Address - }{ - { - addr: "\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\xa0", - want: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\xff\x0e\x0f\xa0", - }, - { - addr: "\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\xdd\x0e\x0f\xa0", - want: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\xff\x0e\x0f\xa0", - }, - { - addr: "\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\xdd\x01\x02\x03", - want: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\xff\x01\x02\x03", - }, - } - - for _, test := range tests { - t.Run(fmt.Sprintf("%s", test.addr), func(t *testing.T) { - if got := header.SolicitedNodeAddr(test.addr); got != test.want { - t.Fatalf("got header.SolicitedNodeAddr(%s) = %s, want = %s", test.addr, got, test.want) - } - }) - } -} - -func TestV6MulticastScope(t *testing.T) { - tests := []struct { - addr tcpip.Address - want header.IPv6MulticastScope - }{ - { - addr: "\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", - want: header.IPv6Reserved0MulticastScope, - }, - { - addr: "\xff\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", - want: header.IPv6InterfaceLocalMulticastScope, - }, - { - addr: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", - want: header.IPv6LinkLocalMulticastScope, - }, - { - addr: "\xff\x03\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", - want: header.IPv6RealmLocalMulticastScope, - }, - { - addr: "\xff\x04\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", - want: header.IPv6AdminLocalMulticastScope, - }, - { - addr: "\xff\x05\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", - want: header.IPv6SiteLocalMulticastScope, - }, - { - addr: "\xff\x06\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", - want: header.IPv6MulticastScope(6), - }, - { - addr: "\xff\x07\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", - want: header.IPv6MulticastScope(7), - }, - { - addr: "\xff\x08\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", - want: header.IPv6OrganizationLocalMulticastScope, - }, - { - addr: "\xff\x09\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", - want: header.IPv6MulticastScope(9), - }, - { - addr: "\xff\x0a\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", - want: header.IPv6MulticastScope(10), - }, - { - addr: "\xff\x0b\x03\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", - want: header.IPv6MulticastScope(11), - }, - { - addr: "\xff\x0c\x04\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", - want: header.IPv6MulticastScope(12), - }, - { - addr: "\xff\x0d\x05\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", - want: header.IPv6MulticastScope(13), - }, - { - addr: "\xff\x0e\x06\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", - want: header.IPv6GlobalMulticastScope, - }, - { - addr: "\xff\x0f\x06\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", - want: header.IPv6ReservedFMulticastScope, - }, - } - - for _, test := range tests { - t.Run(fmt.Sprintf("%s", test.addr), func(t *testing.T) { - if got := header.V6MulticastScope(test.addr); got != test.want { - t.Fatalf("got header.V6MulticastScope(%s) = %d, want = %d", test.addr, got, test.want) - } - }) - } -} diff --git a/pkg/tcpip/header/ipversion_test.go b/pkg/tcpip/header/ipversion_test.go deleted file mode 100644 index b5540bf66..000000000 --- a/pkg/tcpip/header/ipversion_test.go +++ /dev/null @@ -1,67 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package header_test - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/tcpip/header" -) - -func TestIPv4(t *testing.T) { - b := header.IPv4(make([]byte, header.IPv4MinimumSize)) - b.Encode(&header.IPv4Fields{}) - - const want = header.IPv4Version - if v := header.IPVersion(b); v != want { - t.Fatalf("Bad version, want %v, got %v", want, v) - } -} - -func TestIPv6(t *testing.T) { - b := header.IPv6(make([]byte, header.IPv6MinimumSize)) - b.Encode(&header.IPv6Fields{}) - - const want = header.IPv6Version - if v := header.IPVersion(b); v != want { - t.Fatalf("Bad version, want %v, got %v", want, v) - } -} - -func TestOtherVersion(t *testing.T) { - const want = header.IPv4Version + header.IPv6Version - b := make([]byte, 1) - b[0] = want << 4 - - if v := header.IPVersion(b); v != want { - t.Fatalf("Bad version, want %v, got %v", want, v) - } -} - -func TestTooShort(t *testing.T) { - b := make([]byte, 1) - b[0] = (header.IPv4Version + header.IPv6Version) << 4 - - // Get the version of a zero-length slice. - const want = -1 - if v := header.IPVersion(b[:0]); v != want { - t.Fatalf("Bad version, want %v, got %v", want, v) - } - - // Get the version of a nil slice. - if v := header.IPVersion(nil); v != want { - t.Fatalf("Bad version, want %v, got %v", want, v) - } -} diff --git a/pkg/tcpip/header/mld_test.go b/pkg/tcpip/header/mld_test.go deleted file mode 100644 index 0cecf10d4..000000000 --- a/pkg/tcpip/header/mld_test.go +++ /dev/null @@ -1,61 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package header - -import ( - "encoding/binary" - "testing" - "time" - - "gvisor.dev/gvisor/pkg/tcpip" -) - -func TestMLD(t *testing.T) { - b := []byte{ - // Maximum Response Delay - 0, 0, - - // Reserved - 0, 0, - - // MulticastAddress - 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, - } - - const maxRespDelay = 513 - binary.BigEndian.PutUint16(b, maxRespDelay) - - mld := MLD(b) - - if got, want := mld.MaximumResponseDelay(), maxRespDelay*time.Millisecond; got != want { - t.Errorf("got mld.MaximumResponseDelay() = %s, want = %s", got, want) - } - - const newMaxRespDelay = 1234 - mld.SetMaximumResponseDelay(newMaxRespDelay) - if got, want := mld.MaximumResponseDelay(), newMaxRespDelay*time.Millisecond; got != want { - t.Errorf("got mld.MaximumResponseDelay() = %s, want = %s", got, want) - } - - if got, want := mld.MulticastAddress(), tcpip.Address([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6}); got != want { - t.Errorf("got mld.MulticastAddress() = %s, want = %s", got, want) - } - - multicastAddress := tcpip.Address([]byte{15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0}) - mld.SetMulticastAddress(multicastAddress) - if got := mld.MulticastAddress(); got != multicastAddress { - t.Errorf("got mld.MulticastAddress() = %s, want = %s", got, multicastAddress) - } -} diff --git a/pkg/tcpip/header/ndp_test.go b/pkg/tcpip/header/ndp_test.go deleted file mode 100644 index 2a897e938..000000000 --- a/pkg/tcpip/header/ndp_test.go +++ /dev/null @@ -1,1748 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package header - -import ( - "bytes" - "encoding/binary" - "errors" - "fmt" - "io" - "regexp" - "strings" - "testing" - "time" - - "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/testutil" -) - -// TestNDPNeighborSolicit tests the functions of NDPNeighborSolicit. -func TestNDPNeighborSolicit(t *testing.T) { - b := []byte{ - 0, 0, 0, 0, - 1, 2, 3, 4, - 5, 6, 7, 8, - 9, 10, 11, 12, - 13, 14, 15, 16, - } - - // Test getting the Target Address. - ns := NDPNeighborSolicit(b) - addr := testutil.MustParse6("102:304:506:708:90a:b0c:d0e:f10") - if got := ns.TargetAddress(); got != addr { - t.Errorf("got ns.TargetAddress = %s, want %s", got, addr) - } - - // Test updating the Target Address. - addr2 := testutil.MustParse6("1112:1314:1516:1718:191a:1b1c:1d1e:1f11") - ns.SetTargetAddress(addr2) - if got := ns.TargetAddress(); got != addr2 { - t.Errorf("got ns.TargetAddress = %s, want %s", got, addr2) - } - // Make sure the address got updated in the backing buffer. - if got := tcpip.Address(b[ndpNSTargetAddessOffset:][:IPv6AddressSize]); got != addr2 { - t.Errorf("got targetaddress buffer = %s, want %s", got, addr2) - } -} - -func TestNDPRouteInformationOption(t *testing.T) { - tests := []struct { - name string - - length uint8 - prefixLength uint8 - prf NDPRoutePreference - lifetimeS uint32 - prefixBytes []byte - expectedPrefix tcpip.Subnet - - expectedErr error - }{ - { - name: "Length=1 with Prefix Length = 0", - length: 1, - prefixLength: 0, - prf: MediumRoutePreference, - lifetimeS: 1, - prefixBytes: nil, - expectedPrefix: IPv6EmptySubnet, - }, - { - name: "Length=1 but Prefix Length > 0", - length: 1, - prefixLength: 1, - prf: MediumRoutePreference, - lifetimeS: 1, - prefixBytes: nil, - expectedErr: ErrNDPOptMalformedBody, - }, - { - name: "Length=2 with Prefix Length = 0", - length: 2, - prefixLength: 0, - prf: MediumRoutePreference, - lifetimeS: 1, - prefixBytes: nil, - expectedPrefix: IPv6EmptySubnet, - }, - { - name: "Length=2 with Prefix Length in [1, 64] (1)", - length: 2, - prefixLength: 1, - prf: LowRoutePreference, - lifetimeS: 1, - prefixBytes: nil, - expectedPrefix: tcpip.AddressWithPrefix{ - Address: tcpip.Address(strings.Repeat("\x00", IPv6AddressSize)), - PrefixLen: 1, - }.Subnet(), - }, - { - name: "Length=2 with Prefix Length in [1, 64] (64)", - length: 2, - prefixLength: 64, - prf: HighRoutePreference, - lifetimeS: 1, - prefixBytes: nil, - expectedPrefix: tcpip.AddressWithPrefix{ - Address: tcpip.Address(strings.Repeat("\x00", IPv6AddressSize)), - PrefixLen: 64, - }.Subnet(), - }, - { - name: "Length=2 with Prefix Length > 64", - length: 2, - prefixLength: 65, - prf: HighRoutePreference, - lifetimeS: 1, - prefixBytes: nil, - expectedErr: ErrNDPOptMalformedBody, - }, - { - name: "Length=3 with Prefix Length = 0", - length: 3, - prefixLength: 0, - prf: MediumRoutePreference, - lifetimeS: 1, - prefixBytes: nil, - expectedPrefix: IPv6EmptySubnet, - }, - { - name: "Length=3 with Prefix Length in [1, 64] (1)", - length: 3, - prefixLength: 1, - prf: LowRoutePreference, - lifetimeS: 1, - prefixBytes: nil, - expectedPrefix: tcpip.AddressWithPrefix{ - Address: tcpip.Address(strings.Repeat("\x00", IPv6AddressSize)), - PrefixLen: 1, - }.Subnet(), - }, - { - name: "Length=3 with Prefix Length in [1, 64] (64)", - length: 3, - prefixLength: 64, - prf: HighRoutePreference, - lifetimeS: 1, - prefixBytes: nil, - expectedPrefix: tcpip.AddressWithPrefix{ - Address: tcpip.Address(strings.Repeat("\x00", IPv6AddressSize)), - PrefixLen: 64, - }.Subnet(), - }, - { - name: "Length=3 with Prefix Length in [65, 128] (65)", - length: 3, - prefixLength: 65, - prf: HighRoutePreference, - lifetimeS: 1, - prefixBytes: nil, - expectedPrefix: tcpip.AddressWithPrefix{ - Address: tcpip.Address(strings.Repeat("\x00", IPv6AddressSize)), - PrefixLen: 65, - }.Subnet(), - }, - { - name: "Length=3 with Prefix Length in [65, 128] (128)", - length: 3, - prefixLength: 128, - prf: HighRoutePreference, - lifetimeS: 1, - prefixBytes: nil, - expectedPrefix: tcpip.AddressWithPrefix{ - Address: tcpip.Address(strings.Repeat("\x00", IPv6AddressSize)), - PrefixLen: 128, - }.Subnet(), - }, - { - name: "Length=3 with (invalid) Prefix Length > 128", - length: 3, - prefixLength: 129, - prf: HighRoutePreference, - lifetimeS: 1, - prefixBytes: nil, - expectedErr: ErrNDPOptMalformedBody, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - expectedRouteInformationBytes := [...]byte{ - // Type, Length - 24, test.length, - - // Prefix Length, Prf - uint8(test.prefixLength), uint8(test.prf) << 3, - - // Route Lifetime - 0, 0, 0, 0, - - 0, 0, 0, 0, - 0, 0, 0, 0, - 0, 0, 0, 0, - 0, 0, 0, 0, - } - binary.BigEndian.PutUint32(expectedRouteInformationBytes[4:], test.lifetimeS) - _ = copy(expectedRouteInformationBytes[8:], test.prefixBytes) - - opts := NDPOptions(expectedRouteInformationBytes[:test.length*lengthByteUnits]) - it, err := opts.Iter(false) - if err != nil { - t.Fatalf("got Iter(false) = (_, %s), want = (_, nil)", err) - } - opt, done, err := it.Next() - if !errors.Is(err, test.expectedErr) { - t.Fatalf("got Next() = (_, _, %s), want = (_, _, %s)", err, test.expectedErr) - } - if want := test.expectedErr != nil; done != want { - t.Fatalf("got Next() = (_, %t, _), want = (_, %t, _)", done, want) - } - if test.expectedErr != nil { - return - } - - if got := opt.kind(); got != ndpRouteInformationType { - t.Errorf("got kind() = %d, want = %d", got, ndpRouteInformationType) - } - - ri, ok := opt.(NDPRouteInformation) - if !ok { - t.Fatalf("got opt = %T, want = NDPRouteInformation", opt) - } - - if got := ri.PrefixLength(); got != test.prefixLength { - t.Errorf("got PrefixLength() = %d, want = %d", got, test.prefixLength) - } - if got := ri.RoutePreference(); got != test.prf { - t.Errorf("got RoutePreference() = %d, want = %d", got, test.prf) - } - if got, want := ri.RouteLifetime(), time.Duration(test.lifetimeS)*time.Second; got != want { - t.Errorf("got RouteLifetime() = %s, want = %s", got, want) - } - if got, err := ri.Prefix(); err != nil { - t.Errorf("Prefix(): %s", err) - } else if got != test.expectedPrefix { - t.Errorf("got Prefix() = %s, want = %s", got, test.expectedPrefix) - } - - // Iterator should not return anything else. - { - next, done, err := it.Next() - if err != nil { - t.Errorf("got Next() = (_, _, %s), want = (_, _, nil)", err) - } - if !done { - t.Error("got Next() = (_, false, _), want = (_, true, _)") - } - if next != nil { - t.Errorf("got Next() = (%x, _, _), want = (nil, _, _)", next) - } - } - }) - } -} - -// TestNDPNeighborAdvert tests the functions of NDPNeighborAdvert. -func TestNDPNeighborAdvert(t *testing.T) { - b := []byte{ - 160, 0, 0, 0, - 1, 2, 3, 4, - 5, 6, 7, 8, - 9, 10, 11, 12, - 13, 14, 15, 16, - } - - // Test getting the Target Address. - na := NDPNeighborAdvert(b) - addr := testutil.MustParse6("102:304:506:708:90a:b0c:d0e:f10") - if got := na.TargetAddress(); got != addr { - t.Errorf("got TargetAddress = %s, want %s", got, addr) - } - - // Test getting the Router Flag. - if got := na.RouterFlag(); !got { - t.Errorf("got RouterFlag = false, want = true") - } - - // Test getting the Solicited Flag. - if got := na.SolicitedFlag(); got { - t.Errorf("got SolicitedFlag = true, want = false") - } - - // Test getting the Override Flag. - if got := na.OverrideFlag(); !got { - t.Errorf("got OverrideFlag = false, want = true") - } - - // Test updating the Target Address. - addr2 := testutil.MustParse6("1112:1314:1516:1718:191a:1b1c:1d1e:1f11") - na.SetTargetAddress(addr2) - if got := na.TargetAddress(); got != addr2 { - t.Errorf("got TargetAddress = %s, want %s", got, addr2) - } - // Make sure the address got updated in the backing buffer. - if got := tcpip.Address(b[ndpNATargetAddressOffset:][:IPv6AddressSize]); got != addr2 { - t.Errorf("got targetaddress buffer = %s, want %s", got, addr2) - } - - // Test updating the Router Flag. - na.SetRouterFlag(false) - if got := na.RouterFlag(); got { - t.Errorf("got RouterFlag = true, want = false") - } - - // Test updating the Solicited Flag. - na.SetSolicitedFlag(true) - if got := na.SolicitedFlag(); !got { - t.Errorf("got SolicitedFlag = false, want = true") - } - - // Test updating the Override Flag. - na.SetOverrideFlag(false) - if got := na.OverrideFlag(); got { - t.Errorf("got OverrideFlag = true, want = false") - } - - // Make sure flags got updated in the backing buffer. - if got := b[ndpNAFlagsOffset]; got != 64 { - t.Errorf("got flags byte = %d, want = 64", got) - } -} - -func TestNDPRouterAdvert(t *testing.T) { - tests := []struct { - hopLimit uint8 - managedFlag, otherConfFlag bool - prf NDPRoutePreference - routerLifetimeS uint16 - reachableTimeMS, retransTimerMS uint32 - }{ - { - hopLimit: 1, - managedFlag: false, - otherConfFlag: true, - prf: HighRoutePreference, - routerLifetimeS: 2, - reachableTimeMS: 3, - retransTimerMS: 4, - }, - { - hopLimit: 64, - managedFlag: true, - otherConfFlag: false, - prf: LowRoutePreference, - routerLifetimeS: 258, - reachableTimeMS: 78492, - retransTimerMS: 13213, - }, - } - - for i, test := range tests { - t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { - flags := uint8(0) - if test.managedFlag { - flags |= 1 << 7 - } - if test.otherConfFlag { - flags |= 1 << 6 - } - flags |= uint8(test.prf) << 3 - - b := []byte{ - test.hopLimit, flags, 1, 2, - 3, 4, 5, 6, - 7, 8, 9, 10, - } - binary.BigEndian.PutUint16(b[2:], test.routerLifetimeS) - binary.BigEndian.PutUint32(b[4:], test.reachableTimeMS) - binary.BigEndian.PutUint32(b[8:], test.retransTimerMS) - - ra := NDPRouterAdvert(b) - - if got := ra.CurrHopLimit(); got != test.hopLimit { - t.Errorf("got ra.CurrHopLimit() = %d, want = %d", got, test.hopLimit) - } - - if got := ra.ManagedAddrConfFlag(); got != test.managedFlag { - t.Errorf("got ManagedAddrConfFlag() = %t, want = %t", got, test.managedFlag) - } - - if got := ra.OtherConfFlag(); got != test.otherConfFlag { - t.Errorf("got OtherConfFlag() = %t, want = %t", got, test.otherConfFlag) - } - - if got := ra.DefaultRouterPreference(); got != test.prf { - t.Errorf("got DefaultRouterPreference() = %d, want = %d", got, test.prf) - } - - if got, want := ra.RouterLifetime(), time.Second*time.Duration(test.routerLifetimeS); got != want { - t.Errorf("got ra.RouterLifetime() = %d, want = %d", got, want) - } - - if got, want := ra.ReachableTime(), time.Millisecond*time.Duration(test.reachableTimeMS); got != want { - t.Errorf("got ra.ReachableTime() = %d, want = %d", got, want) - } - - if got, want := ra.RetransTimer(), time.Millisecond*time.Duration(test.retransTimerMS); got != want { - t.Errorf("got ra.RetransTimer() = %d, want = %d", got, want) - } - }) - } -} - -// TestNDPSourceLinkLayerAddressOptionEthernetAddress tests getting the -// Ethernet address from an NDPSourceLinkLayerAddressOption. -func TestNDPSourceLinkLayerAddressOptionEthernetAddress(t *testing.T) { - tests := []struct { - name string - buf []byte - expected tcpip.LinkAddress - }{ - { - "ValidMAC", - []byte{1, 2, 3, 4, 5, 6}, - tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"), - }, - { - "SLLBodyTooShort", - []byte{1, 2, 3, 4, 5}, - tcpip.LinkAddress([]byte(nil)), - }, - { - "SLLBodyLargerThanNeeded", - []byte{1, 2, 3, 4, 5, 6, 7, 8}, - tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"), - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - sll := NDPSourceLinkLayerAddressOption(test.buf) - if got := sll.EthernetAddress(); got != test.expected { - t.Errorf("got sll.EthernetAddress = %s, want = %s", got, test.expected) - } - }) - } -} - -// TestNDPTargetLinkLayerAddressOptionEthernetAddress tests getting the -// Ethernet address from an NDPTargetLinkLayerAddressOption. -func TestNDPTargetLinkLayerAddressOptionEthernetAddress(t *testing.T) { - tests := []struct { - name string - buf []byte - expected tcpip.LinkAddress - }{ - { - "ValidMAC", - []byte{1, 2, 3, 4, 5, 6}, - tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"), - }, - { - "TLLBodyTooShort", - []byte{1, 2, 3, 4, 5}, - tcpip.LinkAddress([]byte(nil)), - }, - { - "TLLBodyLargerThanNeeded", - []byte{1, 2, 3, 4, 5, 6, 7, 8}, - tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"), - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - tll := NDPTargetLinkLayerAddressOption(test.buf) - if got := tll.EthernetAddress(); got != test.expected { - t.Errorf("got tll.EthernetAddress = %s, want = %s", got, test.expected) - } - }) - } -} - -func TestOpts(t *testing.T) { - const optionHeaderLen = 2 - - checkNonce := func(expectedNonce []byte) func(*testing.T, NDPOption) { - return func(t *testing.T, opt NDPOption) { - if got := opt.kind(); got != ndpNonceOptionType { - t.Errorf("got kind() = %d, want = %d", got, ndpNonceOptionType) - } - nonce, ok := opt.(NDPNonceOption) - if !ok { - t.Fatalf("got nonce = %T, want = NDPNonceOption", opt) - } - if diff := cmp.Diff(expectedNonce, nonce.Nonce()); diff != "" { - t.Errorf("nonce mismatch (-want +got):\n%s", diff) - } - } - } - - checkTLL := func(expectedAddr tcpip.LinkAddress) func(*testing.T, NDPOption) { - return func(t *testing.T, opt NDPOption) { - if got := opt.kind(); got != ndpTargetLinkLayerAddressOptionType { - t.Errorf("got kind() = %d, want = %d", got, ndpTargetLinkLayerAddressOptionType) - } - tll, ok := opt.(NDPTargetLinkLayerAddressOption) - if !ok { - t.Fatalf("got tll = %T, want = NDPTargetLinkLayerAddressOption", opt) - } - if got, want := tll.EthernetAddress(), expectedAddr; got != want { - t.Errorf("got tll.EthernetAddress = %s, want = %s", got, want) - } - } - } - - checkSLL := func(expectedAddr tcpip.LinkAddress) func(*testing.T, NDPOption) { - return func(t *testing.T, opt NDPOption) { - if got := opt.kind(); got != ndpSourceLinkLayerAddressOptionType { - t.Errorf("got kind() = %d, want = %d", got, ndpSourceLinkLayerAddressOptionType) - } - sll, ok := opt.(NDPSourceLinkLayerAddressOption) - if !ok { - t.Fatalf("got sll = %T, want = NDPSourceLinkLayerAddressOption", opt) - } - if got, want := sll.EthernetAddress(), expectedAddr; got != want { - t.Errorf("got sll.EthernetAddress = %s, want = %s", got, want) - } - } - } - - const validLifetimeSeconds = 16909060 - address := testutil.MustParse6("90a:b0c:d0e:f10:1112:1314:1516:1718") - - expectedRDNSSBytes := [...]byte{ - // Type, Length - 25, 3, - - // Reserved - 0, 0, - - // Lifetime - 1, 2, 4, 8, - - // Address - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, - } - binary.BigEndian.PutUint32(expectedRDNSSBytes[4:], validLifetimeSeconds) - if n := copy(expectedRDNSSBytes[8:], address); n != IPv6AddressSize { - t.Fatalf("got copy(...) = %d, want = %d", n, IPv6AddressSize) - } - // Update reserved fields to non zero values to make sure serializing sets - // them to zero. - rdnssBytes := expectedRDNSSBytes - rdnssBytes[1] = 1 - rdnssBytes[2] = 2 - - const searchListPaddingBytes = 3 - const domainName = "abc.abcd.e" - expectedSearchListBytes := [...]byte{ - // Type, Length - 31, 3, - - // Reserved - 0, 0, - - // Lifetime - 1, 0, 0, 0, - - // Domain names - 3, 'a', 'b', 'c', - 4, 'a', 'b', 'c', 'd', - 1, 'e', - 0, - 0, 0, 0, 0, - } - binary.BigEndian.PutUint32(expectedSearchListBytes[4:], validLifetimeSeconds) - // Update reserved fields to non zero values to make sure serializing sets - // them to zero. - searchListBytes := expectedSearchListBytes - searchListBytes[2] = 1 - searchListBytes[3] = 2 - - const prefixLength = 43 - const onLinkFlag = false - const slaacFlag = true - const preferredLifetimeSeconds = 84281096 - const onLinkFlagBit = 7 - const slaacFlagBit = 6 - boolToByte := func(v bool) byte { - if v { - return 1 - } - return 0 - } - flags := boolToByte(onLinkFlag)<<onLinkFlagBit | boolToByte(slaacFlag)<<slaacFlagBit - expectedPrefixInformationBytes := [...]byte{ - // Type, Length - 3, 4, - - prefixLength, flags, - - // Valid Lifetime - 1, 2, 3, 4, - - // Preferred Lifetime - 5, 6, 7, 8, - - // Reserved2 - 0, 0, 0, 0, - - // Address - 9, 10, 11, 12, - 13, 14, 15, 16, - 17, 18, 19, 20, - 21, 22, 23, 24, - } - binary.BigEndian.PutUint32(expectedPrefixInformationBytes[4:], validLifetimeSeconds) - binary.BigEndian.PutUint32(expectedPrefixInformationBytes[8:], preferredLifetimeSeconds) - if n := copy(expectedPrefixInformationBytes[16:], address); n != IPv6AddressSize { - t.Fatalf("got copy(...) = %d, want = %d", n, IPv6AddressSize) - } - // Update reserved fields to non zero values to make sure serializing sets - // them to zero. - prefixInformationBytes := expectedPrefixInformationBytes - prefixInformationBytes[3] |= (1 << slaacFlagBit) - 1 - binary.BigEndian.PutUint32(prefixInformationBytes[12:], validLifetimeSeconds+1) - tests := []struct { - name string - buf []byte - opt NDPOption - expectedBuf []byte - check func(*testing.T, NDPOption) - }{ - { - name: "Nonce", - buf: make([]byte, 8), - opt: NDPNonceOption([]byte{1, 2, 3, 4, 5, 6}), - expectedBuf: []byte{14, 1, 1, 2, 3, 4, 5, 6}, - check: checkNonce([]byte{1, 2, 3, 4, 5, 6}), - }, - { - name: "Nonce with padding", - buf: []byte{1, 1, 1, 1, 1, 1, 1, 1}, - opt: NDPNonceOption([]byte{1, 2, 3, 4, 5}), - expectedBuf: []byte{14, 1, 1, 2, 3, 4, 5, 0}, - check: checkNonce([]byte{1, 2, 3, 4, 5, 0}), - }, - - { - name: "TLL Ethernet", - buf: make([]byte, 8), - opt: NDPTargetLinkLayerAddressOption("\x01\x02\x03\x04\x05\x06"), - expectedBuf: []byte{2, 1, 1, 2, 3, 4, 5, 6}, - check: checkTLL("\x01\x02\x03\x04\x05\x06"), - }, - { - name: "TLL Padding", - buf: []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, - opt: NDPTargetLinkLayerAddressOption("\x01\x02\x03\x04\x05\x06\x07\x08"), - expectedBuf: []byte{2, 2, 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0}, - check: checkTLL("\x01\x02\x03\x04\x05\x06"), - }, - { - name: "TLL Empty", - buf: nil, - opt: NDPTargetLinkLayerAddressOption(""), - expectedBuf: nil, - }, - - { - name: "SLL Ethernet", - buf: make([]byte, 8), - opt: NDPSourceLinkLayerAddressOption("\x01\x02\x03\x04\x05\x06"), - expectedBuf: []byte{1, 1, 1, 2, 3, 4, 5, 6}, - check: checkSLL("\x01\x02\x03\x04\x05\x06"), - }, - { - name: "SLL Padding", - buf: []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, - opt: NDPSourceLinkLayerAddressOption("\x01\x02\x03\x04\x05\x06\x07\x08"), - expectedBuf: []byte{1, 2, 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0}, - check: checkSLL("\x01\x02\x03\x04\x05\x06"), - }, - { - name: "SLL Empty", - buf: nil, - opt: NDPSourceLinkLayerAddressOption(""), - expectedBuf: nil, - }, - - { - name: "RDNSS", - buf: []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, - // NDPRecursiveDNSServer holds the option after the header bytes. - opt: NDPRecursiveDNSServer(rdnssBytes[optionHeaderLen:]), - expectedBuf: expectedRDNSSBytes[:], - check: func(t *testing.T, opt NDPOption) { - if got := opt.kind(); got != ndpRecursiveDNSServerOptionType { - t.Errorf("got kind() = %d, want = %d", got, ndpRecursiveDNSServerOptionType) - } - rdnss, ok := opt.(NDPRecursiveDNSServer) - if !ok { - t.Fatalf("got opt = %T, want = NDPRecursiveDNSServer", opt) - } - if got, want := rdnss.length(), len(expectedRDNSSBytes[optionHeaderLen:]); got != want { - t.Errorf("got length() = %d, want = %d", got, want) - } - if got, want := rdnss.Lifetime(), validLifetimeSeconds*time.Second; got != want { - t.Errorf("got Lifetime() = %s, want = %s", got, want) - } - if addrs, err := rdnss.Addresses(); err != nil { - t.Errorf("Addresses(): %s", err) - } else if diff := cmp.Diff([]tcpip.Address{address}, addrs); diff != "" { - t.Errorf("mismatched addresses (-want +got):\n%s", diff) - } - }, - }, - - { - name: "Search list", - buf: []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, - opt: NDPDNSSearchList(searchListBytes[optionHeaderLen:]), - expectedBuf: expectedSearchListBytes[:], - check: func(t *testing.T, opt NDPOption) { - if got := opt.kind(); got != ndpDNSSearchListOptionType { - t.Errorf("got kind() = %d, want = %d", got, ndpDNSSearchListOptionType) - } - - dnssl, ok := opt.(NDPDNSSearchList) - if !ok { - t.Fatalf("got opt = %T, want = NDPDNSSearchList", opt) - } - if got, want := dnssl.length(), len(expectedRDNSSBytes[optionHeaderLen:]); got != want { - t.Errorf("got length() = %d, want = %d", got, want) - } - if got, want := dnssl.Lifetime(), validLifetimeSeconds*time.Second; got != want { - t.Errorf("got Lifetime() = %s, want = %s", got, want) - } - - if domainNames, err := dnssl.DomainNames(); err != nil { - t.Errorf("DomainNames(): %s", err) - } else if diff := cmp.Diff([]string{domainName}, domainNames); diff != "" { - t.Errorf("domain names mismatch (-want +got):\n%s", diff) - } - }, - }, - - { - name: "Prefix Information", - buf: []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, - // NDPPrefixInformation holds the option after the header bytes. - opt: NDPPrefixInformation(prefixInformationBytes[optionHeaderLen:]), - expectedBuf: expectedPrefixInformationBytes[:], - check: func(t *testing.T, opt NDPOption) { - if got := opt.kind(); got != ndpPrefixInformationType { - t.Errorf("got kind() = %d, want = %d", got, ndpPrefixInformationType) - } - - pi, ok := opt.(NDPPrefixInformation) - if !ok { - t.Fatalf("got opt = %T, want = NDPPrefixInformation", opt) - } - - if got, want := pi.length(), len(expectedPrefixInformationBytes[optionHeaderLen:]); got != want { - t.Errorf("got length() = %d, want = %d", got, want) - } - if got := pi.PrefixLength(); got != prefixLength { - t.Errorf("got PrefixLength() = %d, want = %d", got, prefixLength) - } - if got := pi.OnLinkFlag(); got != onLinkFlag { - t.Errorf("got OnLinkFlag() = %t, want = %t", got, onLinkFlag) - } - if got := pi.AutonomousAddressConfigurationFlag(); got != slaacFlag { - t.Errorf("got AutonomousAddressConfigurationFlag() = %t, want = %t", got, slaacFlag) - } - if got, want := pi.ValidLifetime(), validLifetimeSeconds*time.Second; got != want { - t.Errorf("got ValidLifetime() = %s, want = %s", got, want) - } - if got, want := pi.PreferredLifetime(), preferredLifetimeSeconds*time.Second; got != want { - t.Errorf("got PreferredLifetime() = %s, want = %s", got, want) - } - if got := pi.Prefix(); got != address { - t.Errorf("got Prefix() = %s, want = %s", got, address) - } - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - opts := NDPOptions(test.buf) - serializer := NDPOptionsSerializer{ - test.opt, - } - if got, want := int(serializer.Length()), len(test.expectedBuf); got != want { - t.Fatalf("got Length() = %d, want = %d", got, want) - } - opts.Serialize(serializer) - if diff := cmp.Diff(test.expectedBuf, test.buf); diff != "" { - t.Fatalf("serialized buffer mismatch (-want +got):\n%s", diff) - } - - it, err := opts.Iter(true) - if err != nil { - t.Fatalf("got Iter(true) = (_, %s), want = (_, nil)", err) - } - - if len(test.expectedBuf) > 0 { - next, done, err := it.Next() - if err != nil { - t.Fatalf("got Next() = (_, _, %s), want = (_, _, nil)", err) - } - if done { - t.Fatal("got Next() = (_, true, _), want = (_, false, _)") - } - test.check(t, next) - } - - // Iterator should not return anything else. - next, done, err := it.Next() - if err != nil { - t.Errorf("got Next() = (_, _, %s), want = (_, _, nil)", err) - } - if !done { - t.Error("got Next() = (_, false, _), want = (_, true, _)") - } - if next != nil { - t.Errorf("got Next() = (%x, _, _), want = (nil, _, _)", next) - } - }) - } -} - -func TestNDPRecursiveDNSServerOption(t *testing.T) { - tests := []struct { - name string - buf []byte - lifetime time.Duration - addrs []tcpip.Address - }{ - { - "Valid1Addr", - []byte{ - 25, 3, 0, 0, - 0, 0, 0, 0, - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, - }, - 0, - []tcpip.Address{ - "\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f", - }, - }, - { - "Valid2Addr", - []byte{ - 25, 5, 0, 0, - 0, 0, 0, 0, - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, - 17, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 16, - }, - 0, - []tcpip.Address{ - "\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f", - "\x11\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x10", - }, - }, - { - "Valid3Addr", - []byte{ - 25, 7, 0, 0, - 0, 0, 0, 0, - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, - 17, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 16, - 17, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 17, - }, - 0, - []tcpip.Address{ - "\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f", - "\x11\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x10", - "\x11\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x11", - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - opts := NDPOptions(test.buf) - it, err := opts.Iter(true) - if err != nil { - t.Fatalf("got Iter = (_, %s), want = (_, nil)", err) - } - - // Iterator should get our option. - next, done, err := it.Next() - if err != nil { - t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err) - } - if done { - t.Fatal("got Next = (_, true, _), want = (_, false, _)") - } - if got := next.kind(); got != ndpRecursiveDNSServerOptionType { - t.Fatalf("got Type = %d, want = %d", got, ndpRecursiveDNSServerOptionType) - } - - opt, ok := next.(NDPRecursiveDNSServer) - if !ok { - t.Fatalf("next (type = %T) cannot be casted to an NDPRecursiveDNSServer", next) - } - if got := opt.Lifetime(); got != test.lifetime { - t.Errorf("got Lifetime = %d, want = %d", got, test.lifetime) - } - addrs, err := opt.Addresses() - if err != nil { - t.Errorf("opt.Addresses() = %s", err) - } - if diff := cmp.Diff(addrs, test.addrs); diff != "" { - t.Errorf("mismatched addresses (-want +got):\n%s", diff) - } - - // Iterator should not return anything else. - next, done, err = it.Next() - if err != nil { - t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err) - } - if !done { - t.Error("got Next = (_, false, _), want = (_, true, _)") - } - if next != nil { - t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next) - } - }) - } -} - -// TestNDPDNSSearchListOption tests the getters of NDPDNSSearchList. -func TestNDPDNSSearchListOption(t *testing.T) { - tests := []struct { - name string - buf []byte - lifetime time.Duration - domainNames []string - err error - }{ - { - name: "Valid1Label", - buf: []byte{ - 0, 0, - 0, 0, 0, 1, - 3, 'a', 'b', 'c', - 0, - 0, 0, 0, - }, - lifetime: time.Second, - domainNames: []string{ - "abc", - }, - err: nil, - }, - { - name: "Valid2Label", - buf: []byte{ - 0, 0, - 0, 0, 0, 5, - 3, 'a', 'b', 'c', - 4, 'a', 'b', 'c', 'd', - 0, - 0, 0, 0, 0, 0, 0, - }, - lifetime: 5 * time.Second, - domainNames: []string{ - "abc.abcd", - }, - err: nil, - }, - { - name: "Valid3Label", - buf: []byte{ - 0, 0, - 1, 0, 0, 0, - 3, 'a', 'b', 'c', - 4, 'a', 'b', 'c', 'd', - 1, 'e', - 0, - 0, 0, 0, 0, - }, - lifetime: 16777216 * time.Second, - domainNames: []string{ - "abc.abcd.e", - }, - err: nil, - }, - { - name: "Valid2Domains", - buf: []byte{ - 0, 0, - 1, 2, 3, 4, - 3, 'a', 'b', 'c', - 0, - 2, 'd', 'e', - 3, 'x', 'y', 'z', - 0, - 0, 0, 0, - }, - lifetime: 16909060 * time.Second, - domainNames: []string{ - "abc", - "de.xyz", - }, - err: nil, - }, - { - name: "Valid3DomainsMixedCase", - buf: []byte{ - 0, 0, - 0, 0, 0, 0, - 3, 'a', 'B', 'c', - 0, - 2, 'd', 'E', - 3, 'X', 'y', 'z', - 0, - 1, 'J', - 0, - }, - lifetime: 0, - domainNames: []string{ - "abc", - "de.xyz", - "j", - }, - err: nil, - }, - { - name: "ValidDomainAfterNULL", - buf: []byte{ - 0, 0, - 0, 0, 0, 0, - 3, 'a', 'B', 'c', - 0, 0, 0, 0, - 2, 'd', 'E', - 3, 'X', 'y', 'z', - 0, - }, - lifetime: 0, - domainNames: []string{ - "abc", - "de.xyz", - }, - err: nil, - }, - { - name: "Valid0Domains", - buf: []byte{ - 0, 0, - 0, 0, 0, 0, - 0, - 0, 0, 0, 0, 0, 0, 0, - }, - lifetime: 0, - domainNames: nil, - err: nil, - }, - { - name: "NoTrailingNull", - buf: []byte{ - 0, 0, - 0, 0, 0, 0, - 7, 'a', 'b', 'c', 'd', 'e', 'f', 'g', - }, - lifetime: 0, - domainNames: nil, - err: io.ErrUnexpectedEOF, - }, - { - name: "IncorrectLength", - buf: []byte{ - 0, 0, - 0, 0, 0, 0, - 8, 'a', 'b', 'c', 'd', 'e', 'f', 'g', - }, - lifetime: 0, - domainNames: nil, - err: io.ErrUnexpectedEOF, - }, - { - name: "IncorrectLengthWithNULL", - buf: []byte{ - 0, 0, - 0, 0, 0, 0, - 7, 'a', 'b', 'c', 'd', 'e', 'f', - 0, - }, - lifetime: 0, - domainNames: nil, - err: ErrNDPOptMalformedBody, - }, - { - name: "LabelOfLength63", - buf: []byte{ - 0, 0, - 0, 0, 0, 0, - 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', - 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', - 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', - 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', - 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', - 0, - }, - lifetime: 0, - domainNames: []string{ - "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijk", - }, - err: nil, - }, - { - name: "LabelOfLength64", - buf: []byte{ - 0, 0, - 0, 0, 0, 0, - 64, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', - 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', - 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', - 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', - 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', 'l', - 0, - }, - lifetime: 0, - domainNames: nil, - err: ErrNDPOptMalformedBody, - }, - { - name: "DomainNameOfLength255", - buf: []byte{ - 0, 0, - 0, 0, 0, 0, - 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', - 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', - 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', - 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', - 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', - 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', - 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', - 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', - 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', - 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', - 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', - 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', - 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', - 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', - 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', - 62, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', - 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', - 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', - 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', - 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', - 0, - }, - lifetime: 0, - domainNames: []string{ - "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijk.abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijk.abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijk.abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghij", - }, - err: nil, - }, - { - name: "DomainNameOfLength256", - buf: []byte{ - 0, 0, - 0, 0, 0, 0, - 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', - 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', - 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', - 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', - 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', - 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', - 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', - 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', - 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', - 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', - 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', - 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', - 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', - 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', - 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', - 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', - 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', - 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', - 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', - 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', - 0, - }, - lifetime: 0, - domainNames: nil, - err: ErrNDPOptMalformedBody, - }, - { - name: "StartingDigitForLabel", - buf: []byte{ - 0, 0, - 0, 0, 0, 1, - 3, '9', 'b', 'c', - 0, - 0, 0, 0, - }, - lifetime: time.Second, - domainNames: nil, - err: ErrNDPOptMalformedBody, - }, - { - name: "StartingHyphenForLabel", - buf: []byte{ - 0, 0, - 0, 0, 0, 1, - 3, '-', 'b', 'c', - 0, - 0, 0, 0, - }, - lifetime: time.Second, - domainNames: nil, - err: ErrNDPOptMalformedBody, - }, - { - name: "EndingHyphenForLabel", - buf: []byte{ - 0, 0, - 0, 0, 0, 1, - 3, 'a', 'b', '-', - 0, - 0, 0, 0, - }, - lifetime: time.Second, - domainNames: nil, - err: ErrNDPOptMalformedBody, - }, - { - name: "EndingDigitForLabel", - buf: []byte{ - 0, 0, - 0, 0, 0, 1, - 3, 'a', 'b', '9', - 0, - 0, 0, 0, - }, - lifetime: time.Second, - domainNames: []string{ - "ab9", - }, - err: nil, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - opt := NDPDNSSearchList(test.buf) - - if got := opt.Lifetime(); got != test.lifetime { - t.Errorf("got Lifetime = %d, want = %d", got, test.lifetime) - } - domainNames, err := opt.DomainNames() - if !errors.Is(err, test.err) { - t.Errorf("opt.DomainNames() = %s", err) - } - if diff := cmp.Diff(domainNames, test.domainNames); diff != "" { - t.Errorf("mismatched domain names (-want +got):\n%s", diff) - } - }) - } -} - -func TestNDPSearchListOptionDomainNameLabelInvalidSymbols(t *testing.T) { - for r := rune(0); r <= 255; r++ { - t.Run(fmt.Sprintf("RuneVal=%d", r), func(t *testing.T) { - buf := []byte{ - 0, 0, - 0, 0, 0, 0, - 3, 'a', 0 /* will be replaced */, 'c', - 0, - 0, 0, 0, - } - buf[8] = uint8(r) - opt := NDPDNSSearchList(buf) - - // As per RFC 1035 section 2.3.1, the label must only include ASCII - // letters, digits and hyphens (a-z, A-Z, 0-9, -). - var expectedErr error - re := regexp.MustCompile(`[a-zA-Z0-9-]`) - if !re.Match([]byte{byte(r)}) { - expectedErr = ErrNDPOptMalformedBody - } - - if domainNames, err := opt.DomainNames(); !errors.Is(err, expectedErr) { - t.Errorf("got opt.DomainNames() = (%s, %v), want = (_, %v)", domainNames, err, ErrNDPOptMalformedBody) - } - }) - } -} - -// TestNDPOptionsIterCheck tests that Iter will return false if the NDPOptions -// the iterator was returned for is malformed. -func TestNDPOptionsIterCheck(t *testing.T) { - tests := []struct { - name string - buf []byte - expectedErr error - }{ - { - name: "ZeroLengthField", - buf: []byte{0, 0, 0, 0, 0, 0, 0, 0}, - expectedErr: ErrNDPOptMalformedHeader, - }, - { - name: "ValidSourceLinkLayerAddressOption", - buf: []byte{1, 1, 1, 2, 3, 4, 5, 6}, - expectedErr: nil, - }, - { - name: "TooSmallSourceLinkLayerAddressOption", - buf: []byte{1, 1, 1, 2, 3, 4, 5}, - expectedErr: io.ErrUnexpectedEOF, - }, - { - name: "ValidTargetLinkLayerAddressOption", - buf: []byte{2, 1, 1, 2, 3, 4, 5, 6}, - expectedErr: nil, - }, - { - name: "TooSmallTargetLinkLayerAddressOption", - buf: []byte{2, 1, 1, 2, 3, 4, 5}, - expectedErr: io.ErrUnexpectedEOF, - }, - { - name: "ValidPrefixInformation", - buf: []byte{ - 3, 4, 43, 64, - 1, 2, 3, 4, - 5, 6, 7, 8, - 0, 0, 0, 0, - 9, 10, 11, 12, - 13, 14, 15, 16, - 17, 18, 19, 20, - 21, 22, 23, 24, - }, - expectedErr: nil, - }, - { - name: "TooSmallPrefixInformation", - buf: []byte{ - 3, 4, 43, 64, - 1, 2, 3, 4, - 5, 6, 7, 8, - 0, 0, 0, 0, - 9, 10, 11, 12, - 13, 14, 15, 16, - 17, 18, 19, 20, - 21, 22, 23, - }, - expectedErr: io.ErrUnexpectedEOF, - }, - { - name: "InvalidPrefixInformationLength", - buf: []byte{ - 3, 3, 43, 64, - 1, 2, 3, 4, - 5, 6, 7, 8, - 0, 0, 0, 0, - 9, 10, 11, 12, - 13, 14, 15, 16, - }, - expectedErr: ErrNDPOptMalformedBody, - }, - { - name: "ValidSourceAndTargetLinkLayerAddressWithPrefixInformation", - buf: []byte{ - // Source Link-Layer Address. - 1, 1, 1, 2, 3, 4, 5, 6, - - // Target Link-Layer Address. - 2, 1, 7, 8, 9, 10, 11, 12, - - // Prefix information. - 3, 4, 43, 64, - 1, 2, 3, 4, - 5, 6, 7, 8, - 0, 0, 0, 0, - 9, 10, 11, 12, - 13, 14, 15, 16, - 17, 18, 19, 20, - 21, 22, 23, 24, - }, - expectedErr: nil, - }, - { - name: "ValidSourceAndTargetLinkLayerAddressWithPrefixInformationWithUnrecognized", - buf: []byte{ - // Source Link-Layer Address. - 1, 1, 1, 2, 3, 4, 5, 6, - - // Target Link-Layer Address. - 2, 1, 7, 8, 9, 10, 11, 12, - - // 255 is an unrecognized type. If 255 ends up - // being the type for some recognized type, - // update 255 to some other unrecognized value. - 255, 2, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 7, 8, - - // Prefix information. - 3, 4, 43, 64, - 1, 2, 3, 4, - 5, 6, 7, 8, - 0, 0, 0, 0, - 9, 10, 11, 12, - 13, 14, 15, 16, - 17, 18, 19, 20, - 21, 22, 23, 24, - }, - expectedErr: nil, - }, - { - name: "InvalidRecursiveDNSServerCutsOffAddress", - buf: []byte{ - 25, 4, 0, 0, - 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, - 0, 1, 2, 3, 4, 5, 6, 7, - }, - expectedErr: ErrNDPOptMalformedBody, - }, - { - name: "InvalidRecursiveDNSServerInvalidLengthField", - buf: []byte{ - 25, 2, 0, 0, - 0, 0, 0, 0, - 0, 1, 2, 3, 4, 5, 6, 7, 8, - }, - expectedErr: io.ErrUnexpectedEOF, - }, - { - name: "RecursiveDNSServerTooSmall", - buf: []byte{ - 25, 1, 0, 0, - 0, 0, 0, - }, - expectedErr: io.ErrUnexpectedEOF, - }, - { - name: "RecursiveDNSServerMulticast", - buf: []byte{ - 25, 3, 0, 0, - 0, 0, 0, 0, - 255, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, - }, - expectedErr: ErrNDPOptMalformedBody, - }, - { - name: "RecursiveDNSServerUnspecified", - buf: []byte{ - 25, 3, 0, 0, - 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - }, - expectedErr: ErrNDPOptMalformedBody, - }, - { - name: "DNSSearchListLargeCompliantRFC1035", - buf: []byte{ - 31, 33, 0, 0, - 0, 0, 0, 0, - 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', - 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', - 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', - 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', - 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', - 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', - 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', - 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', - 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', - 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', - 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', - 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', - 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', - 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', - 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', - 62, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', - 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', - 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', - 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', - 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', - 0, - }, - expectedErr: nil, - }, - { - name: "DNSSearchListNonCompliantRFC1035", - buf: []byte{ - 31, 33, 0, 0, - 0, 0, 0, 0, - 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', - 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', - 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', - 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', - 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', - 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', - 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', - 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', - 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', - 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', - 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', - 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', - 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', - 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', - 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', - 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', - 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', - 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', - 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', - 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', - 0, - 0, 0, 0, 0, 0, 0, 0, 0, - }, - expectedErr: ErrNDPOptMalformedBody, - }, - { - name: "DNSSearchListValidSmall", - buf: []byte{ - 31, 2, 0, 0, - 0, 0, 0, 0, - 6, 'a', 'b', 'c', 'd', 'e', 'f', - 0, - }, - expectedErr: nil, - }, - { - name: "DNSSearchListTooSmall", - buf: []byte{ - 31, 1, 0, 0, - 0, 0, 0, - }, - expectedErr: io.ErrUnexpectedEOF, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - opts := NDPOptions(test.buf) - - if _, err := opts.Iter(true); !errors.Is(err, test.expectedErr) { - t.Fatalf("got Iter(true) = (_, %v), want = (_, %v)", err, test.expectedErr) - } - - // test.buf may be malformed but we chose not to check - // the iterator so it must return true. - if _, err := opts.Iter(false); err != nil { - t.Fatalf("got Iter(false) = (_, %s), want = (_, nil)", err) - } - }) - } -} - -// TestNDPOptionsIter tests that we can iterator over a valid NDPOptions. Note, -// this test does not actually check any of the option's getters, it simply -// checks the option Type and Body. We have other tests that tests the option -// field gettings given an option body and don't need to duplicate those tests -// here. -func TestNDPOptionsIter(t *testing.T) { - buf := []byte{ - // Source Link-Layer Address. - 1, 1, 1, 2, 3, 4, 5, 6, - - // Target Link-Layer Address. - 2, 1, 7, 8, 9, 10, 11, 12, - - // 255 is an unrecognized type. If 255 ends up being the type - // for some recognized type, update 255 to some other - // unrecognized value. Note, this option should be skipped when - // iterating. - 255, 2, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 7, 8, - - // Prefix information. - 3, 4, 43, 64, - 1, 2, 3, 4, - 5, 6, 7, 8, - 0, 0, 0, 0, - 9, 10, 11, 12, - 13, 14, 15, 16, - 17, 18, 19, 20, - 21, 22, 23, 24, - } - - opts := NDPOptions(buf) - it, err := opts.Iter(true) - if err != nil { - t.Fatalf("got Iter = (_, %s), want = (_, nil)", err) - } - - // Test the first (Source Link-Layer) option. - next, done, err := it.Next() - if err != nil { - t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err) - } - if done { - t.Fatal("got Next = (_, true, _), want = (_, false, _)") - } - if got, want := []byte(next.(NDPSourceLinkLayerAddressOption)), buf[2:][:6]; !bytes.Equal(got, want) { - t.Errorf("got Next = (%x, _, _), want = (%x, _, _)", got, want) - } - if got := next.kind(); got != ndpSourceLinkLayerAddressOptionType { - t.Errorf("got Type = %d, want = %d", got, ndpSourceLinkLayerAddressOptionType) - } - - // Test the next (Target Link-Layer) option. - next, done, err = it.Next() - if err != nil { - t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err) - } - if done { - t.Fatal("got Next = (_, true, _), want = (_, false, _)") - } - if got, want := []byte(next.(NDPTargetLinkLayerAddressOption)), buf[10:][:6]; !bytes.Equal(got, want) { - t.Errorf("got Next = (%x, _, _), want = (%x, _, _)", got, want) - } - if got := next.kind(); got != ndpTargetLinkLayerAddressOptionType { - t.Errorf("got Type = %d, want = %d", got, ndpTargetLinkLayerAddressOptionType) - } - - // Test the next (Prefix Information) option. - // Note, the unrecognized option should be skipped. - next, done, err = it.Next() - if err != nil { - t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err) - } - if done { - t.Fatal("got Next = (_, true, _), want = (_, false, _)") - } - if got, want := next.(NDPPrefixInformation), buf[34:][:30]; !bytes.Equal(got, want) { - t.Errorf("got Next = (%x, _, _), want = (%x, _, _)", got, want) - } - if got := next.kind(); got != ndpPrefixInformationType { - t.Errorf("got Type = %d, want = %d", got, ndpPrefixInformationType) - } - - // Iterator should not return anything else. - next, done, err = it.Next() - if err != nil { - t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err) - } - if !done { - t.Error("got Next = (_, false, _), want = (_, true, _)") - } - if next != nil { - t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next) - } -} - -func TestNDPRoutePreferenceStringer(t *testing.T) { - p := NDPRoutePreference(0) - for { - var wantStr string - switch p { - case 0b01: - wantStr = "HighRoutePreference" - case 0b00: - wantStr = "MediumRoutePreference" - case 0b11: - wantStr = "LowRoutePreference" - case 0b10: - wantStr = "ReservedRoutePreference" - default: - wantStr = fmt.Sprintf("NDPRoutePreference(%d)", p) - } - - if gotStr := p.String(); gotStr != wantStr { - t.Errorf("got NDPRoutePreference(%d).String() = %s, want = %s", p, gotStr, wantStr) - } - - p++ - if p == 0 { - // Overflowed, we hit all values. - break - } - } -} diff --git a/pkg/tcpip/header/parse/BUILD b/pkg/tcpip/header/parse/BUILD deleted file mode 100644 index 2adee9288..000000000 --- a/pkg/tcpip/header/parse/BUILD +++ /dev/null @@ -1,15 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "parse", - srcs = ["parse.go"], - visibility = ["//visibility:public"], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/stack", - ], -) diff --git a/pkg/tcpip/header/parse/parse_state_autogen.go b/pkg/tcpip/header/parse/parse_state_autogen.go new file mode 100644 index 000000000..ad047be32 --- /dev/null +++ b/pkg/tcpip/header/parse/parse_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package parse diff --git a/pkg/tcpip/header/tcp_test.go b/pkg/tcpip/header/tcp_test.go deleted file mode 100644 index 96db8460f..000000000 --- a/pkg/tcpip/header/tcp_test.go +++ /dev/null @@ -1,168 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package header_test - -import ( - "reflect" - "testing" - - "gvisor.dev/gvisor/pkg/tcpip/header" -) - -func TestEncodeSACKBlocks(t *testing.T) { - testCases := []struct { - sackBlocks []header.SACKBlock - want []header.SACKBlock - bufSize int - }{ - { - []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}, {42, 50}, {52, 60}, {62, 70}}, - []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}, {42, 50}}, - 40, - }, - { - []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}, {42, 50}, {52, 60}, {62, 70}}, - []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, - 30, - }, - { - []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}, {42, 50}, {52, 60}, {62, 70}}, - []header.SACKBlock{{10, 20}, {22, 30}}, - 20, - }, - { - []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}, {42, 50}, {52, 60}, {62, 70}}, - []header.SACKBlock{{10, 20}}, - 10, - }, - { - []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}, {42, 50}, {52, 60}, {62, 70}}, - nil, - 8, - }, - { - []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}, {42, 50}, {52, 60}, {62, 70}}, - []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}, {42, 50}}, - 60, - }, - } - for _, tc := range testCases { - b := make([]byte, tc.bufSize) - t.Logf("testing: %v", tc) - header.EncodeSACKBlocks(tc.sackBlocks, b) - opts := header.ParseTCPOptions(b) - if got, want := opts.SACKBlocks, tc.want; !reflect.DeepEqual(got, want) { - t.Errorf("header.EncodeSACKBlocks(%v, %v), encoded blocks got: %v, want: %v", tc.sackBlocks, b, got, want) - } - } -} - -func TestTCPParseOptions(t *testing.T) { - type tsOption struct { - tsVal uint32 - tsEcr uint32 - } - - generateOptions := func(tsOpt *tsOption, sackBlocks []header.SACKBlock) []byte { - l := 0 - if tsOpt != nil { - l += 10 - } - if len(sackBlocks) != 0 { - l += len(sackBlocks)*8 + 2 - } - b := make([]byte, l) - offset := 0 - if tsOpt != nil { - offset = header.EncodeTSOption(tsOpt.tsVal, tsOpt.tsEcr, b) - } - header.EncodeSACKBlocks(sackBlocks, b[offset:]) - return b - } - - testCases := []struct { - b []byte - want header.TCPOptions - }{ - // Trivial cases. - {nil, header.TCPOptions{false, 0, 0, nil}}, - {[]byte{header.TCPOptionNOP}, header.TCPOptions{false, 0, 0, nil}}, - {[]byte{header.TCPOptionNOP, header.TCPOptionNOP}, header.TCPOptions{false, 0, 0, nil}}, - {[]byte{header.TCPOptionEOL}, header.TCPOptions{false, 0, 0, nil}}, - {[]byte{header.TCPOptionNOP, header.TCPOptionEOL, header.TCPOptionTS, 10, 1, 1}, header.TCPOptions{false, 0, 0, nil}}, - - // Test timestamp parsing. - {[]byte{header.TCPOptionNOP, header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1}, header.TCPOptions{true, 1, 1, nil}}, - {[]byte{header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1}, header.TCPOptions{true, 1, 1, nil}}, - - // Test malformed timestamp option. - {[]byte{header.TCPOptionTS, 8, 1, 1}, header.TCPOptions{false, 0, 0, nil}}, - {[]byte{header.TCPOptionNOP, header.TCPOptionTS, 8, 1, 1}, header.TCPOptions{false, 0, 0, nil}}, - {[]byte{header.TCPOptionNOP, header.TCPOptionTS, 8, 0, 0, 0, 1, 0, 0, 0, 1}, header.TCPOptions{false, 0, 0, nil}}, - - // Test SACKBlock parsing. - {[]byte{header.TCPOptionSACK, 10, 0, 0, 0, 1, 0, 0, 0, 10}, header.TCPOptions{false, 0, 0, []header.SACKBlock{{1, 10}}}}, - {[]byte{header.TCPOptionSACK, 18, 0, 0, 0, 1, 0, 0, 0, 10, 0, 0, 0, 11, 0, 0, 0, 12}, header.TCPOptions{false, 0, 0, []header.SACKBlock{{1, 10}, {11, 12}}}}, - - // Test malformed SACK option. - {[]byte{header.TCPOptionSACK, 0}, header.TCPOptions{false, 0, 0, nil}}, - {[]byte{header.TCPOptionSACK, 8, 0, 0, 0, 1, 0, 0, 0, 10}, header.TCPOptions{false, 0, 0, nil}}, - {[]byte{header.TCPOptionSACK, 11, 0, 0, 0, 1, 0, 0, 0, 10, 0, 0, 0, 11, 0, 0, 0, 12}, header.TCPOptions{false, 0, 0, nil}}, - {[]byte{header.TCPOptionSACK, 17, 0, 0, 0, 1, 0, 0, 0, 10, 0, 0, 0, 11, 0, 0, 0, 12}, header.TCPOptions{false, 0, 0, nil}}, - {[]byte{header.TCPOptionSACK}, header.TCPOptions{false, 0, 0, nil}}, - {[]byte{header.TCPOptionSACK, 10}, header.TCPOptions{false, 0, 0, nil}}, - {[]byte{header.TCPOptionSACK, 10, 0, 0, 0, 1, 0, 0, 0}, header.TCPOptions{false, 0, 0, nil}}, - - // Test Timestamp + SACK block parsing. - {generateOptions(&tsOption{1, 1}, []header.SACKBlock{{1, 10}, {11, 12}}), header.TCPOptions{true, 1, 1, []header.SACKBlock{{1, 10}, {11, 12}}}}, - {generateOptions(&tsOption{1, 2}, []header.SACKBlock{{1, 10}, {11, 12}}), header.TCPOptions{true, 1, 2, []header.SACKBlock{{1, 10}, {11, 12}}}}, - {generateOptions(&tsOption{1, 3}, []header.SACKBlock{{1, 10}, {11, 12}, {13, 14}, {14, 15}, {15, 16}}), header.TCPOptions{true, 1, 3, []header.SACKBlock{{1, 10}, {11, 12}, {13, 14}, {14, 15}}}}, - - // Test valid timestamp + malformed SACK block parsing. - {[]byte{header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1, header.TCPOptionSACK}, header.TCPOptions{true, 1, 1, nil}}, - {[]byte{header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1, header.TCPOptionSACK, 10}, header.TCPOptions{true, 1, 1, nil}}, - {[]byte{header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1, header.TCPOptionSACK, 10, 0, 0, 0}, header.TCPOptions{true, 1, 1, nil}}, - {[]byte{header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1, header.TCPOptionSACK, 11, 0, 0, 0, 1, 0, 0, 0, 1}, header.TCPOptions{true, 1, 1, nil}}, - {[]byte{header.TCPOptionSACK, header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1}, header.TCPOptions{false, 0, 0, nil}}, - {[]byte{header.TCPOptionSACK, 10, header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1}, header.TCPOptions{false, 0, 0, []header.SACKBlock{{134873088, 65536}}}}, - {[]byte{header.TCPOptionSACK, 10, 0, 0, 0, header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1}, header.TCPOptions{false, 0, 0, []header.SACKBlock{{8, 167772160}}}}, - {[]byte{header.TCPOptionSACK, 11, 0, 0, 0, 1, 0, 0, 0, 1, header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1}, header.TCPOptions{false, 0, 0, nil}}, - } - for _, tc := range testCases { - if got, want := header.ParseTCPOptions(tc.b), tc.want; !reflect.DeepEqual(got, want) { - t.Errorf("ParseTCPOptions(%v) = %v, want: %v", tc.b, got, tc.want) - } - } -} - -func TestTCPFlags(t *testing.T) { - for _, tt := range []struct { - flags header.TCPFlags - want string - }{ - {header.TCPFlagFin, "F "}, - {header.TCPFlagSyn, " S "}, - {header.TCPFlagRst, " R "}, - {header.TCPFlagPsh, " P "}, - {header.TCPFlagAck, " A "}, - {header.TCPFlagUrg, " U"}, - {header.TCPFlagSyn | header.TCPFlagAck, " S A "}, - {header.TCPFlagFin | header.TCPFlagAck, "F A "}, - } { - if got := tt.flags.String(); got != tt.want { - t.Errorf("got TCPFlags(%#b).String() = %s, want = %s", tt.flags, got, tt.want) - } - } -} diff --git a/pkg/tcpip/link/channel/BUILD b/pkg/tcpip/link/channel/BUILD deleted file mode 100644 index 973f06cbc..000000000 --- a/pkg/tcpip/link/channel/BUILD +++ /dev/null @@ -1,15 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "channel", - srcs = ["channel.go"], - visibility = ["//visibility:public"], - deps = [ - "//pkg/sync", - "//pkg/tcpip", - "//pkg/tcpip/header", - "//pkg/tcpip/stack", - ], -) diff --git a/pkg/tcpip/link/channel/channel_state_autogen.go b/pkg/tcpip/link/channel/channel_state_autogen.go new file mode 100644 index 000000000..7730b59b8 --- /dev/null +++ b/pkg/tcpip/link/channel/channel_state_autogen.go @@ -0,0 +1,36 @@ +// automatically generated by stateify. + +package channel + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (n *NotificationHandle) StateTypeName() string { + return "pkg/tcpip/link/channel.NotificationHandle" +} + +func (n *NotificationHandle) StateFields() []string { + return []string{ + "n", + } +} + +func (n *NotificationHandle) beforeSave() {} + +// +checklocksignore +func (n *NotificationHandle) StateSave(stateSinkObject state.Sink) { + n.beforeSave() + stateSinkObject.Save(0, &n.n) +} + +func (n *NotificationHandle) afterLoad() {} + +// +checklocksignore +func (n *NotificationHandle) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &n.n) +} + +func init() { + state.Register((*NotificationHandle)(nil)) +} diff --git a/pkg/tcpip/link/ethernet/BUILD b/pkg/tcpip/link/ethernet/BUILD deleted file mode 100644 index 0ae0d201a..000000000 --- a/pkg/tcpip/link/ethernet/BUILD +++ /dev/null @@ -1,29 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "ethernet", - srcs = ["ethernet.go"], - visibility = ["//visibility:public"], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/header", - "//pkg/tcpip/link/nested", - "//pkg/tcpip/stack", - ], -) - -go_test( - name = "ethernet_test", - size = "small", - srcs = ["ethernet_test.go"], - deps = [ - ":ethernet", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/link/channel", - "//pkg/tcpip/stack", - ], -) diff --git a/pkg/tcpip/link/ethernet/ethernet_state_autogen.go b/pkg/tcpip/link/ethernet/ethernet_state_autogen.go new file mode 100644 index 000000000..71d255c20 --- /dev/null +++ b/pkg/tcpip/link/ethernet/ethernet_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package ethernet diff --git a/pkg/tcpip/link/ethernet/ethernet_test.go b/pkg/tcpip/link/ethernet/ethernet_test.go deleted file mode 100644 index 08a7f1ce1..000000000 --- a/pkg/tcpip/link/ethernet/ethernet_test.go +++ /dev/null @@ -1,71 +0,0 @@ -// Copyright 2021 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ethernet_test - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/link/ethernet" - "gvisor.dev/gvisor/pkg/tcpip/stack" -) - -var _ stack.NetworkDispatcher = (*testNetworkDispatcher)(nil) - -type testNetworkDispatcher struct { - networkPackets int -} - -func (t *testNetworkDispatcher) DeliverNetworkPacket(_, _ tcpip.LinkAddress, _ tcpip.NetworkProtocolNumber, _ *stack.PacketBuffer) { - t.networkPackets++ -} - -func (*testNetworkDispatcher) DeliverOutboundPacket(_, _ tcpip.LinkAddress, _ tcpip.NetworkProtocolNumber, _ *stack.PacketBuffer) { -} - -func TestDeliverNetworkPacket(t *testing.T) { - const ( - linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") - otherLinkAddr1 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x07") - otherLinkAddr2 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x08") - ) - - e := ethernet.New(channel.New(0, 0, linkAddr)) - var networkDispatcher testNetworkDispatcher - e.Attach(&networkDispatcher) - - if networkDispatcher.networkPackets != 0 { - t.Fatalf("got networkDispatcher.networkPackets = %d, want = 0", networkDispatcher.networkPackets) - } - - // An ethernet frame with a destination link address that is not assigned to - // our ethernet link endpoint should still be delivered to the network - // dispatcher since the ethernet endpoint is not expected to filter frames. - eth := buffer.NewView(header.EthernetMinimumSize) - header.Ethernet(eth).Encode(&header.EthernetFields{ - SrcAddr: otherLinkAddr1, - DstAddr: otherLinkAddr2, - Type: header.IPv4ProtocolNumber, - }) - e.DeliverNetworkPacket("", "", 0, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: eth.ToVectorisedView(), - })) - if networkDispatcher.networkPackets != 1 { - t.Fatalf("got networkDispatcher.networkPackets = %d, want = 1", networkDispatcher.networkPackets) - } -} diff --git a/pkg/tcpip/link/fdbased/BUILD b/pkg/tcpip/link/fdbased/BUILD deleted file mode 100644 index 1d0163823..000000000 --- a/pkg/tcpip/link/fdbased/BUILD +++ /dev/null @@ -1,40 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "fdbased", - srcs = [ - "endpoint.go", - "endpoint_unsafe.go", - "mmap.go", - "mmap_stub.go", - "mmap_unsafe.go", - "packet_dispatchers.go", - ], - visibility = ["//visibility:public"], - deps = [ - "//pkg/sync", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/link/rawfile", - "//pkg/tcpip/stack", - "@org_golang_x_sys//unix:go_default_library", - ], -) - -go_test( - name = "fdbased_test", - size = "small", - srcs = ["endpoint_test.go"], - library = ":fdbased", - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/stack", - "@com_github_google_go_cmp//cmp:go_default_library", - "@org_golang_x_sys//unix:go_default_library", - ], -) diff --git a/pkg/tcpip/link/fdbased/endpoint_test.go b/pkg/tcpip/link/fdbased/endpoint_test.go deleted file mode 100644 index eccd21579..000000000 --- a/pkg/tcpip/link/fdbased/endpoint_test.go +++ /dev/null @@ -1,624 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//go:build linux -// +build linux - -package fdbased - -import ( - "bytes" - "fmt" - "math/rand" - "reflect" - "testing" - "time" - "unsafe" - - "github.com/google/go-cmp/cmp" - "golang.org/x/sys/unix" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/stack" -) - -const ( - mtu = 1500 - laddr = tcpip.LinkAddress("\x11\x22\x33\x44\x55\x66") - raddr = tcpip.LinkAddress("\x77\x88\x99\xaa\xbb\xcc") - proto = 10 - csumOffset = 48 - gsoMSS = 500 -) - -type packetInfo struct { - Raddr tcpip.LinkAddress - Proto tcpip.NetworkProtocolNumber - Contents *stack.PacketBuffer -} - -type packetContents struct { - LinkHeader buffer.View - NetworkHeader buffer.View - TransportHeader buffer.View - Data buffer.View -} - -func checkPacketInfoEqual(t *testing.T, got, want packetInfo) { - t.Helper() - if diff := cmp.Diff( - want, got, - cmp.Transformer("ExtractPacketBuffer", func(pk *stack.PacketBuffer) *packetContents { - if pk == nil { - return nil - } - return &packetContents{ - LinkHeader: pk.LinkHeader().View(), - NetworkHeader: pk.NetworkHeader().View(), - TransportHeader: pk.TransportHeader().View(), - Data: pk.Data().AsRange().ToOwnedView(), - } - }), - ); diff != "" { - t.Errorf("unexpected packetInfo (-want +got):\n%s", diff) - } -} - -type context struct { - t *testing.T - readFDs []int - writeFDs []int - ep stack.LinkEndpoint - ch chan packetInfo - done chan struct{} -} - -func newContext(t *testing.T, opt *Options) *context { - firstFDPair, err := unix.Socketpair(unix.AF_UNIX, unix.SOCK_SEQPACKET, 0) - if err != nil { - t.Fatalf("Socketpair failed: %v", err) - } - secondFDPair, err := unix.Socketpair(unix.AF_UNIX, unix.SOCK_SEQPACKET, 0) - if err != nil { - t.Fatalf("Socketpair failed: %v", err) - } - - done := make(chan struct{}, 2) - opt.ClosedFunc = func(tcpip.Error) { - done <- struct{}{} - } - - opt.FDs = []int{firstFDPair[1], secondFDPair[1]} - ep, err := New(opt) - if err != nil { - t.Fatalf("Failed to create FD endpoint: %v", err) - } - - c := &context{ - t: t, - readFDs: []int{firstFDPair[0], secondFDPair[0]}, - writeFDs: opt.FDs, - ep: ep, - ch: make(chan packetInfo, 100), - done: done, - } - - ep.Attach(c) - - return c -} - -func (c *context) cleanup() { - for _, fd := range c.readFDs { - unix.Close(fd) - } - <-c.done - <-c.done - for _, fd := range c.writeFDs { - unix.Close(fd) - } -} - -func (c *context) DeliverNetworkPacket(remote tcpip.LinkAddress, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { - c.ch <- packetInfo{remote, protocol, pkt} -} - -func (c *context) DeliverOutboundPacket(remote tcpip.LinkAddress, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { - panic("unimplemented") -} - -func TestNoEthernetProperties(t *testing.T) { - c := newContext(t, &Options{MTU: mtu}) - defer c.cleanup() - - if want, v := uint16(0), c.ep.MaxHeaderLength(); want != v { - t.Fatalf("MaxHeaderLength() = %v, want %v", v, want) - } - - if want, v := uint32(mtu), c.ep.MTU(); want != v { - t.Fatalf("MTU() = %v, want %v", v, want) - } -} - -func TestEthernetProperties(t *testing.T) { - c := newContext(t, &Options{EthernetHeader: true, MTU: mtu}) - defer c.cleanup() - - if want, v := uint16(header.EthernetMinimumSize), c.ep.MaxHeaderLength(); want != v { - t.Fatalf("MaxHeaderLength() = %v, want %v", v, want) - } - - if want, v := uint32(mtu), c.ep.MTU(); want != v { - t.Fatalf("MTU() = %v, want %v", v, want) - } -} - -func TestAddress(t *testing.T) { - addrs := []tcpip.LinkAddress{"", "abc", "def"} - for _, a := range addrs { - t.Run(fmt.Sprintf("Address: %q", a), func(t *testing.T) { - c := newContext(t, &Options{Address: a, MTU: mtu}) - defer c.cleanup() - - if want, v := a, c.ep.LinkAddress(); want != v { - t.Fatalf("LinkAddress() = %v, want %v", v, want) - } - }) - } -} - -func testWritePacket(t *testing.T, plen int, eth bool, gsoMaxSize uint32, hash uint32) { - c := newContext(t, &Options{Address: laddr, MTU: mtu, EthernetHeader: eth, GSOMaxSize: gsoMaxSize}) - defer c.cleanup() - - var r stack.RouteInfo - r.RemoteLinkAddress = raddr - - // Build payload. - payload := buffer.NewView(plen) - if _, err := rand.Read(payload); err != nil { - t.Fatalf("rand.Read(payload): %s", err) - } - - // Build packet buffer. - const netHdrLen = 100 - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(c.ep.MaxHeaderLength()) + netHdrLen, - Data: payload.ToVectorisedView(), - }) - pkt.Hash = hash - - // Build header. - b := pkt.NetworkHeader().Push(netHdrLen) - if _, err := rand.Read(b); err != nil { - t.Fatalf("rand.Read(b): %s", err) - } - - // Write. - want := append(append(buffer.View(nil), b...), payload...) - const l3HdrLen = header.IPv6MinimumSize - if gsoMaxSize != 0 { - pkt.GSOOptions = stack.GSO{ - Type: stack.GSOTCPv6, - NeedsCsum: true, - CsumOffset: csumOffset, - MSS: gsoMSS, - L3HdrLen: l3HdrLen, - } - } - if err := c.ep.WritePacket(r, proto, pkt); err != nil { - t.Fatalf("WritePacket failed: %v", err) - } - - // Read from the corresponding FD, then compare with what we wrote. - b = make([]byte, mtu) - fd := c.readFDs[hash%uint32(len(c.readFDs))] - n, err := unix.Read(fd, b) - if err != nil { - t.Fatalf("Read failed: %v", err) - } - b = b[:n] - if gsoMaxSize != 0 { - vnetHdr := *(*virtioNetHdr)(unsafe.Pointer(&b[0])) - if vnetHdr.flags&_VIRTIO_NET_HDR_F_NEEDS_CSUM == 0 { - t.Fatalf("virtioNetHdr.flags %v doesn't contain %v", vnetHdr.flags, _VIRTIO_NET_HDR_F_NEEDS_CSUM) - } - const csumStart = header.EthernetMinimumSize + l3HdrLen - if vnetHdr.csumStart != csumStart { - t.Fatalf("vnetHdr.csumStart = %v, want %v", vnetHdr.csumStart, csumStart) - } - if vnetHdr.csumOffset != csumOffset { - t.Fatalf("vnetHdr.csumOffset = %v, want %v", vnetHdr.csumOffset, csumOffset) - } - gsoType := uint8(0) - if plen > gsoMSS { - gsoType = _VIRTIO_NET_HDR_GSO_TCPV6 - } - if vnetHdr.gsoType != gsoType { - t.Fatalf("vnetHdr.gsoType = %v, want %v", vnetHdr.gsoType, gsoType) - } - b = b[virtioNetHdrSize:] - } - if eth { - h := header.Ethernet(b) - b = b[header.EthernetMinimumSize:] - - if a := h.SourceAddress(); a != laddr { - t.Fatalf("SourceAddress() = %v, want %v", a, laddr) - } - - if a := h.DestinationAddress(); a != raddr { - t.Fatalf("DestinationAddress() = %v, want %v", a, raddr) - } - - if et := h.Type(); et != proto { - t.Fatalf("Type() = %v, want %v", et, proto) - } - } - if len(b) != len(want) { - t.Fatalf("Read returned %v bytes, want %v", len(b), len(want)) - } - if !bytes.Equal(b, want) { - t.Fatalf("Read returned %x, want %x", b, want) - } -} - -func TestWritePacket(t *testing.T) { - lengths := []int{0, 100, 1000} - eths := []bool{true, false} - gsos := []uint32{0, 32768} - - for _, eth := range eths { - for _, plen := range lengths { - for _, gso := range gsos { - t.Run( - fmt.Sprintf("Eth=%v,PayloadLen=%v,GSOMaxSize=%v", eth, plen, gso), - func(t *testing.T) { - testWritePacket(t, plen, eth, gso, 0) - }, - ) - } - } - } -} - -func TestHashedWritePacket(t *testing.T) { - lengths := []int{0, 100, 1000} - eths := []bool{true, false} - gsos := []uint32{0, 32768} - hashes := []uint32{0, 1} - for _, eth := range eths { - for _, plen := range lengths { - for _, gso := range gsos { - for _, hash := range hashes { - t.Run( - fmt.Sprintf("Eth=%v,PayloadLen=%v,GSOMaxSize=%v,Hash=%d", eth, plen, gso, hash), - func(t *testing.T) { - testWritePacket(t, plen, eth, gso, hash) - }, - ) - } - } - } - } -} - -func TestPreserveSrcAddress(t *testing.T) { - baddr := tcpip.LinkAddress("\xcc\xbb\xaa\x77\x88\x99") - - c := newContext(t, &Options{Address: laddr, MTU: mtu, EthernetHeader: true}) - defer c.cleanup() - - // Set LocalLinkAddress in route to the value of the bridged address. - var r stack.RouteInfo - r.LocalLinkAddress = baddr - r.RemoteLinkAddress = raddr - - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - // WritePacket panics given a prependable with anything less than - // the minimum size of the ethernet header. - // TODO(b/153685824): Figure out if this should use c.ep.MaxHeaderLength(). - ReserveHeaderBytes: header.EthernetMinimumSize, - Data: buffer.VectorisedView{}, - }) - if err := c.ep.WritePacket(r, proto, pkt); err != nil { - t.Fatalf("WritePacket failed: %v", err) - } - - // Read from the FD, then compare with what we wrote. - b := make([]byte, mtu) - n, err := unix.Read(c.readFDs[0], b) - if err != nil { - t.Fatalf("Read failed: %v", err) - } - b = b[:n] - h := header.Ethernet(b) - - if a := h.SourceAddress(); a != baddr { - t.Fatalf("SourceAddress() = %v, want %v", a, baddr) - } -} - -func TestDeliverPacket(t *testing.T) { - lengths := []int{100, 1000} - eths := []bool{true, false} - - for _, eth := range eths { - for _, plen := range lengths { - t.Run(fmt.Sprintf("Eth=%v,PayloadLen=%v", eth, plen), func(t *testing.T) { - c := newContext(t, &Options{Address: laddr, MTU: mtu, EthernetHeader: eth}) - defer c.cleanup() - - // Build packet. - all := make([]byte, plen) - if _, err := rand.Read(all); err != nil { - t.Fatalf("rand.Read(all): %s", err) - } - // Make it look like an IPv4 packet. - all[0] = 0x40 - - wantPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: header.EthernetMinimumSize, - Data: buffer.NewViewFromBytes(all).ToVectorisedView(), - }) - if eth { - hdr := header.Ethernet(wantPkt.LinkHeader().Push(header.EthernetMinimumSize)) - hdr.Encode(&header.EthernetFields{ - SrcAddr: raddr, - DstAddr: laddr, - Type: proto, - }) - all = append(hdr, all...) - } - - // Write packet via the file descriptor. - if _, err := unix.Write(c.readFDs[0], all); err != nil { - t.Fatalf("Write failed: %v", err) - } - - // Receive packet through the endpoint. - select { - case pi := <-c.ch: - want := packetInfo{ - Raddr: raddr, - Proto: proto, - Contents: wantPkt, - } - if !eth { - want.Proto = header.IPv4ProtocolNumber - want.Raddr = "" - } - checkPacketInfoEqual(t, pi, want) - case <-time.After(10 * time.Second): - t.Fatalf("Timed out waiting for packet") - } - }) - } - } -} - -func TestBufConfigMaxLength(t *testing.T) { - got := 0 - for _, i := range BufConfig { - got += i - } - want := header.MaxIPPacketSize // maximum TCP packet size - if got < want { - t.Errorf("total buffer size is invalid: got %d, want >= %d", got, want) - } -} - -func TestBufConfigFirst(t *testing.T) { - // The stack assumes that the TCP/IP header is enterily contained in the first view. - // Therefore, the first view needs to be large enough to contain the maximum TCP/IP - // header, which is 120 bytes (60 bytes for IP + 60 bytes for TCP). - want := 120 - got := BufConfig[0] - if got < want { - t.Errorf("first view has an invalid size: got %d, want >= %d", got, want) - } -} - -var capLengthTestCases = []struct { - comment string - config []int - n int - wantUsed int - wantLengths []int -}{ - { - comment: "Single slice", - config: []int{2}, - n: 1, - wantUsed: 1, - wantLengths: []int{1}, - }, - { - comment: "Multiple slices", - config: []int{1, 2}, - n: 2, - wantUsed: 2, - wantLengths: []int{1, 1}, - }, - { - comment: "Entire buffer", - config: []int{1, 2}, - n: 3, - wantUsed: 2, - wantLengths: []int{1, 2}, - }, - { - comment: "Entire buffer but not on the last slice", - config: []int{1, 2, 3}, - n: 3, - wantUsed: 2, - wantLengths: []int{1, 2}, - }, -} - -func TestIovecBuffer(t *testing.T) { - for _, c := range capLengthTestCases { - t.Run(c.comment, func(t *testing.T) { - b := newIovecBuffer(c.config, false /* skipsVnetHdr */) - - // Test initial allocation. - iovecs := b.nextIovecs() - if got, want := len(iovecs), len(c.config); got != want { - t.Fatalf("len(iovecs) = %d, want %d", got, want) - } - - // Make a copy as iovecs points to internal slice. We will need this state - // later. - oldIovecs := append([]unix.Iovec(nil), iovecs...) - - // Test the views that get pulled. - vv := b.pullViews(c.n) - var lengths []int - for _, v := range vv.Views() { - lengths = append(lengths, len(v)) - } - if !reflect.DeepEqual(lengths, c.wantLengths) { - t.Errorf("Pulled view lengths = %v, want %v", lengths, c.wantLengths) - } - - // Test that new views get reallocated. - for i, newIov := range b.nextIovecs() { - if i < c.wantUsed { - if newIov.Base == oldIovecs[i].Base { - t.Errorf("b.views[%d] should have been reallocated", i) - } - } else { - if newIov.Base != oldIovecs[i].Base { - t.Errorf("b.views[%d] should not have been reallocated", i) - } - } - } - }) - } -} - -func TestIovecBufferSkipVnetHdr(t *testing.T) { - for _, test := range []struct { - desc string - readN int - wantLen int - }{ - { - desc: "nothing read", - readN: 0, - wantLen: 0, - }, - { - desc: "smaller than vnet header", - readN: virtioNetHdrSize - 1, - wantLen: 0, - }, - { - desc: "header skipped", - readN: virtioNetHdrSize + 100, - wantLen: 100, - }, - } { - t.Run(test.desc, func(t *testing.T) { - b := newIovecBuffer([]int{10, 20, 50, 50}, true) - // Pretend a read happend. - b.nextIovecs() - vv := b.pullViews(test.readN) - if got, want := vv.Size(), test.wantLen; got != want { - t.Errorf("b.pullView(%d).Size() = %d; want %d", test.readN, got, want) - } - if got, want := len(vv.ToOwnedView()), test.wantLen; got != want { - t.Errorf("b.pullView(%d).ToOwnedView() has length %d; want %d", test.readN, got, want) - } - }) - } -} - -// fakeNetworkDispatcher delivers packets to pkts. -type fakeNetworkDispatcher struct { - pkts []*stack.PacketBuffer -} - -func (d *fakeNetworkDispatcher) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { - d.pkts = append(d.pkts, pkt) -} - -func (d *fakeNetworkDispatcher) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { - panic("unimplemented") -} - -func TestDispatchPacketFormat(t *testing.T) { - for _, test := range []struct { - name string - newDispatcher func(fd int, e *endpoint) (linkDispatcher, error) - }{ - { - name: "readVDispatcher", - newDispatcher: newReadVDispatcher, - }, - { - name: "recvMMsgDispatcher", - newDispatcher: newRecvMMsgDispatcher, - }, - } { - t.Run(test.name, func(t *testing.T) { - // Create a socket pair to send/recv. - fds, err := unix.Socketpair(unix.AF_UNIX, unix.SOCK_DGRAM, 0) - if err != nil { - t.Fatal(err) - } - defer unix.Close(fds[0]) - defer unix.Close(fds[1]) - - data := []byte{ - // Ethernet header. - 1, 2, 3, 4, 5, 60, - 1, 2, 3, 4, 5, 61, - 8, 0, - // Mock network header. - 40, 41, 42, 43, - } - err = unix.Sendmsg(fds[1], data, nil, nil, 0) - if err != nil { - t.Fatal(err) - } - - // Create and run dispatcher once. - sink := &fakeNetworkDispatcher{} - d, err := test.newDispatcher(fds[0], &endpoint{ - hdrSize: header.EthernetMinimumSize, - dispatcher: sink, - }) - if err != nil { - t.Fatal(err) - } - if ok, err := d.dispatch(); !ok || err != nil { - t.Fatalf("d.dispatch() = %v, %v", ok, err) - } - - // Verify packet. - if got, want := len(sink.pkts), 1; got != want { - t.Fatalf("len(sink.pkts) = %d, want %d", got, want) - } - pkt := sink.pkts[0] - if got, want := pkt.LinkHeader().View().Size(), header.EthernetMinimumSize; got != want { - t.Errorf("pkt.LinkHeader().View().Size() = %d, want %d", got, want) - } - if got, want := pkt.Data().Size(), 4; got != want { - t.Errorf("pkt.Data().Size() = %d, want %d", got, want) - } - }) - } -} diff --git a/pkg/tcpip/link/fdbased/fdbased_state_autogen.go b/pkg/tcpip/link/fdbased/fdbased_state_autogen.go new file mode 100644 index 000000000..586f166a4 --- /dev/null +++ b/pkg/tcpip/link/fdbased/fdbased_state_autogen.go @@ -0,0 +1,9 @@ +// automatically generated by stateify. + +//go:build linux && ((linux && amd64) || (linux && arm64)) && (!linux || (!amd64 && !arm64)) && linux +// +build linux +// +build linux,amd64 linux,arm64 +// +build !linux !amd64,!arm64 +// +build linux + +package fdbased diff --git a/pkg/tcpip/link/fdbased/fdbased_unsafe_state_autogen.go b/pkg/tcpip/link/fdbased/fdbased_unsafe_state_autogen.go new file mode 100644 index 000000000..6a5ed4a3c --- /dev/null +++ b/pkg/tcpip/link/fdbased/fdbased_unsafe_state_autogen.go @@ -0,0 +1,7 @@ +// automatically generated by stateify. + +//go:build linux && ((linux && amd64) || (linux && arm64)) +// +build linux +// +build linux,amd64 linux,arm64 + +package fdbased diff --git a/pkg/tcpip/link/loopback/BUILD b/pkg/tcpip/link/loopback/BUILD deleted file mode 100644 index 6bf3805b7..000000000 --- a/pkg/tcpip/link/loopback/BUILD +++ /dev/null @@ -1,15 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "loopback", - srcs = ["loopback.go"], - visibility = ["//visibility:public"], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/stack", - ], -) diff --git a/pkg/tcpip/link/loopback/loopback_state_autogen.go b/pkg/tcpip/link/loopback/loopback_state_autogen.go new file mode 100644 index 000000000..c00fd9f19 --- /dev/null +++ b/pkg/tcpip/link/loopback/loopback_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package loopback diff --git a/pkg/tcpip/link/muxed/BUILD b/pkg/tcpip/link/muxed/BUILD deleted file mode 100644 index 193524525..000000000 --- a/pkg/tcpip/link/muxed/BUILD +++ /dev/null @@ -1,29 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "muxed", - srcs = ["injectable.go"], - visibility = ["//visibility:public"], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/header", - "//pkg/tcpip/stack", - ], -) - -go_test( - name = "muxed_test", - size = "small", - srcs = ["injectable_test.go"], - library = ":muxed", - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/link/fdbased", - "//pkg/tcpip/network/ipv4", - "//pkg/tcpip/stack", - "@org_golang_x_sys//unix:go_default_library", - ], -) diff --git a/pkg/tcpip/link/muxed/injectable_test.go b/pkg/tcpip/link/muxed/injectable_test.go deleted file mode 100644 index 040e3a35b..000000000 --- a/pkg/tcpip/link/muxed/injectable_test.go +++ /dev/null @@ -1,101 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package muxed - -import ( - "bytes" - "net" - "os" - "testing" - - "golang.org/x/sys/unix" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/link/fdbased" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/stack" -) - -func TestInjectableEndpointRawDispatch(t *testing.T) { - endpoint, sock, dstIP := makeTestInjectableEndpoint(t) - - endpoint.InjectOutbound(dstIP, []byte{0xFA}) - - buf := make([]byte, ipv4.MaxTotalSize) - bytesRead, err := sock.Read(buf) - if err != nil { - t.Fatalf("Unable to read from socketpair: %v", err) - } - if got, want := buf[:bytesRead], []byte{0xFA}; !bytes.Equal(got, want) { - t.Fatalf("Read %v from the socketpair, wanted %v", got, want) - } -} - -func TestInjectableEndpointDispatch(t *testing.T) { - endpoint, sock, dstIP := makeTestInjectableEndpoint(t) - - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: 1, - Data: buffer.NewViewFromBytes([]byte{0xFB}).ToVectorisedView(), - }) - pkt.TransportHeader().Push(1)[0] = 0xFA - var packetRoute stack.RouteInfo - packetRoute.RemoteAddress = dstIP - - endpoint.WritePacket(packetRoute, ipv4.ProtocolNumber, pkt) - - buf := make([]byte, 6500) - bytesRead, err := sock.Read(buf) - if err != nil { - t.Fatalf("Unable to read from socketpair: %v", err) - } - if got, want := buf[:bytesRead], []byte{0xFA, 0xFB}; !bytes.Equal(got, want) { - t.Fatalf("Read %v from the socketpair, wanted %v", got, want) - } -} - -func TestInjectableEndpointDispatchHdrOnly(t *testing.T) { - endpoint, sock, dstIP := makeTestInjectableEndpoint(t) - - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: 1, - Data: buffer.NewView(0).ToVectorisedView(), - }) - pkt.TransportHeader().Push(1)[0] = 0xFA - var packetRoute stack.RouteInfo - packetRoute.RemoteAddress = dstIP - endpoint.WritePacket(packetRoute, ipv4.ProtocolNumber, pkt) - buf := make([]byte, 6500) - bytesRead, err := sock.Read(buf) - if err != nil { - t.Fatalf("Unable to read from socketpair: %v", err) - } - if got, want := buf[:bytesRead], []byte{0xFA}; !bytes.Equal(got, want) { - t.Fatalf("Read %v from the socketpair, wanted %v", got, want) - } -} - -func makeTestInjectableEndpoint(t *testing.T) (*InjectableEndpoint, *os.File, tcpip.Address) { - dstIP := tcpip.Address(net.ParseIP("1.2.3.4").To4()) - pair, err := unix.Socketpair(unix.AF_UNIX, - unix.SOCK_SEQPACKET|unix.SOCK_CLOEXEC|unix.SOCK_NONBLOCK, 0) - if err != nil { - t.Fatal("Failed to create socket pair:", err) - } - underlyingEndpoint := fdbased.NewInjectable(pair[1], 6500, stack.CapabilityNone) - routes := map[tcpip.Address]stack.InjectableLinkEndpoint{dstIP: underlyingEndpoint} - endpoint := NewInjectableEndpoint(routes) - return endpoint, os.NewFile(uintptr(pair[0]), "test route end"), dstIP -} diff --git a/pkg/tcpip/link/muxed/muxed_state_autogen.go b/pkg/tcpip/link/muxed/muxed_state_autogen.go new file mode 100644 index 000000000..56330e2a5 --- /dev/null +++ b/pkg/tcpip/link/muxed/muxed_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package muxed diff --git a/pkg/tcpip/link/nested/BUILD b/pkg/tcpip/link/nested/BUILD deleted file mode 100644 index 00b42b924..000000000 --- a/pkg/tcpip/link/nested/BUILD +++ /dev/null @@ -1,31 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "nested", - srcs = [ - "nested.go", - ], - visibility = ["//visibility:public"], - deps = [ - "//pkg/sync", - "//pkg/tcpip", - "//pkg/tcpip/header", - "//pkg/tcpip/stack", - ], -) - -go_test( - name = "nested_test", - size = "small", - srcs = [ - "nested_test.go", - ], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/header", - "//pkg/tcpip/link/nested", - "//pkg/tcpip/stack", - ], -) diff --git a/pkg/tcpip/link/nested/nested_state_autogen.go b/pkg/tcpip/link/nested/nested_state_autogen.go new file mode 100644 index 000000000..9e1b5ca4e --- /dev/null +++ b/pkg/tcpip/link/nested/nested_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package nested diff --git a/pkg/tcpip/link/nested/nested_test.go b/pkg/tcpip/link/nested/nested_test.go deleted file mode 100644 index c1f9d308c..000000000 --- a/pkg/tcpip/link/nested/nested_test.go +++ /dev/null @@ -1,109 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package nested_test - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/nested" - "gvisor.dev/gvisor/pkg/tcpip/stack" -) - -type parentEndpoint struct { - nested.Endpoint -} - -var _ stack.LinkEndpoint = (*parentEndpoint)(nil) -var _ stack.NetworkDispatcher = (*parentEndpoint)(nil) - -type childEndpoint struct { - stack.LinkEndpoint - dispatcher stack.NetworkDispatcher -} - -var _ stack.LinkEndpoint = (*childEndpoint)(nil) - -func (c *childEndpoint) Attach(dispatcher stack.NetworkDispatcher) { - c.dispatcher = dispatcher -} - -func (c *childEndpoint) IsAttached() bool { - return c.dispatcher != nil -} - -type counterDispatcher struct { - count int -} - -var _ stack.NetworkDispatcher = (*counterDispatcher)(nil) - -func (d *counterDispatcher) DeliverNetworkPacket(tcpip.LinkAddress, tcpip.LinkAddress, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) { - d.count++ -} - -func (d *counterDispatcher) DeliverOutboundPacket(tcpip.LinkAddress, tcpip.LinkAddress, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) { - panic("unimplemented") -} - -func TestNestedLinkEndpoint(t *testing.T) { - const emptyAddress = tcpip.LinkAddress("") - - var ( - childEP childEndpoint - nestedEP parentEndpoint - disp counterDispatcher - ) - nestedEP.Endpoint.Init(&childEP, &nestedEP) - - if childEP.IsAttached() { - t.Error("On init, childEP.IsAttached() = true, want = false") - } - if nestedEP.IsAttached() { - t.Error("On init, nestedEP.IsAttached() = true, want = false") - } - - nestedEP.Attach(&disp) - if disp.count != 0 { - t.Fatalf("After attach, got disp.count = %d, want = 0", disp.count) - } - if !childEP.IsAttached() { - t.Error("After attach, childEP.IsAttached() = false, want = true") - } - if !nestedEP.IsAttached() { - t.Error("After attach, nestedEP.IsAttached() = false, want = true") - } - - nestedEP.DeliverNetworkPacket(emptyAddress, emptyAddress, header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{})) - if disp.count != 1 { - t.Errorf("After first packet with dispatcher attached, got disp.count = %d, want = 1", disp.count) - } - - nestedEP.Attach(nil) - if childEP.IsAttached() { - t.Error("After detach, childEP.IsAttached() = true, want = false") - } - if nestedEP.IsAttached() { - t.Error("After detach, nestedEP.IsAttached() = true, want = false") - } - - disp.count = 0 - nestedEP.DeliverNetworkPacket(emptyAddress, emptyAddress, header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{})) - if disp.count != 0 { - t.Errorf("After second packet with dispatcher detached, got disp.count = %d, want = 0", disp.count) - } - -} diff --git a/pkg/tcpip/link/packetsocket/BUILD b/pkg/tcpip/link/packetsocket/BUILD deleted file mode 100644 index 6fff160ce..000000000 --- a/pkg/tcpip/link/packetsocket/BUILD +++ /dev/null @@ -1,14 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "packetsocket", - srcs = ["endpoint.go"], - visibility = ["//visibility:public"], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/link/nested", - "//pkg/tcpip/stack", - ], -) diff --git a/pkg/tcpip/link/packetsocket/packetsocket_state_autogen.go b/pkg/tcpip/link/packetsocket/packetsocket_state_autogen.go new file mode 100644 index 000000000..6b3221fd8 --- /dev/null +++ b/pkg/tcpip/link/packetsocket/packetsocket_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package packetsocket diff --git a/pkg/tcpip/link/pipe/BUILD b/pkg/tcpip/link/pipe/BUILD deleted file mode 100644 index 9f31c1ffc..000000000 --- a/pkg/tcpip/link/pipe/BUILD +++ /dev/null @@ -1,15 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "pipe", - srcs = ["pipe.go"], - visibility = ["//visibility:public"], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/stack", - ], -) diff --git a/pkg/tcpip/link/pipe/pipe_state_autogen.go b/pkg/tcpip/link/pipe/pipe_state_autogen.go new file mode 100644 index 000000000..d3b40feb4 --- /dev/null +++ b/pkg/tcpip/link/pipe/pipe_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package pipe diff --git a/pkg/tcpip/link/qdisc/fifo/BUILD b/pkg/tcpip/link/qdisc/fifo/BUILD deleted file mode 100644 index 5bea598eb..000000000 --- a/pkg/tcpip/link/qdisc/fifo/BUILD +++ /dev/null @@ -1,19 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "fifo", - srcs = [ - "endpoint.go", - "packet_buffer_queue.go", - ], - visibility = ["//visibility:public"], - deps = [ - "//pkg/sleep", - "//pkg/sync", - "//pkg/tcpip", - "//pkg/tcpip/header", - "//pkg/tcpip/stack", - ], -) diff --git a/pkg/tcpip/link/qdisc/fifo/fifo_state_autogen.go b/pkg/tcpip/link/qdisc/fifo/fifo_state_autogen.go new file mode 100644 index 000000000..9eb52b1cb --- /dev/null +++ b/pkg/tcpip/link/qdisc/fifo/fifo_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package fifo diff --git a/pkg/tcpip/link/rawfile/BUILD b/pkg/tcpip/link/rawfile/BUILD deleted file mode 100644 index 4efd7c45e..000000000 --- a/pkg/tcpip/link/rawfile/BUILD +++ /dev/null @@ -1,33 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "rawfile", - srcs = [ - "blockingpoll_amd64.s", - "blockingpoll_arm64.s", - "blockingpoll_noyield_unsafe.go", - "blockingpoll_yield_unsafe.go", - "errors.go", - "rawfile_unsafe.go", - ], - visibility = ["//visibility:public"], - deps = [ - "//pkg/tcpip", - "@org_golang_x_sys//unix:go_default_library", - ], -) - -go_test( - name = "rawfile_test", - srcs = [ - "errors_test.go", - ], - library = "rawfile", - deps = [ - "//pkg/tcpip", - "@com_github_google_go_cmp//cmp:go_default_library", - "@org_golang_x_sys//unix:go_default_library", - ], -) diff --git a/pkg/tcpip/link/rawfile/errors_test.go b/pkg/tcpip/link/rawfile/errors_test.go deleted file mode 100644 index 1b88c309b..000000000 --- a/pkg/tcpip/link/rawfile/errors_test.go +++ /dev/null @@ -1,55 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//go:build linux -// +build linux - -package rawfile - -import ( - "testing" - - "github.com/google/go-cmp/cmp" - "golang.org/x/sys/unix" - "gvisor.dev/gvisor/pkg/tcpip" -) - -func TestTranslateErrno(t *testing.T) { - for _, test := range []struct { - errno unix.Errno - translated tcpip.Error - }{ - { - errno: unix.Errno(0), - translated: &tcpip.ErrInvalidEndpointState{}, - }, - { - errno: unix.Errno(maxErrno), - translated: &tcpip.ErrInvalidEndpointState{}, - }, - { - errno: unix.Errno(514), - translated: &tcpip.ErrInvalidEndpointState{}, - }, - { - errno: unix.EEXIST, - translated: &tcpip.ErrDuplicateAddress{}, - }, - } { - got := TranslateErrno(test.errno) - if diff := cmp.Diff(test.translated, got); diff != "" { - t.Errorf("unexpected result from TranslateErrno(%q), (-want, +got):\n%s", test.errno, diff) - } - } -} diff --git a/pkg/tcpip/link/rawfile/rawfile_state_autogen.go b/pkg/tcpip/link/rawfile/rawfile_state_autogen.go new file mode 100644 index 000000000..00708246f --- /dev/null +++ b/pkg/tcpip/link/rawfile/rawfile_state_autogen.go @@ -0,0 +1,6 @@ +// automatically generated by stateify. + +//go:build linux +// +build linux + +package rawfile diff --git a/pkg/tcpip/link/rawfile/rawfile_unsafe_state_autogen.go b/pkg/tcpip/link/rawfile/rawfile_unsafe_state_autogen.go new file mode 100644 index 000000000..c42f3a3b6 --- /dev/null +++ b/pkg/tcpip/link/rawfile/rawfile_unsafe_state_autogen.go @@ -0,0 +1,11 @@ +// automatically generated by stateify. + +//go:build linux && !amd64 && !arm64 && ((linux && amd64) || (linux && arm64)) && go1.12 && linux +// +build linux +// +build !amd64 +// +build !arm64 +// +build linux,amd64 linux,arm64 +// +build go1.12 +// +build linux + +package rawfile diff --git a/pkg/tcpip/link/sharedmem/BUILD b/pkg/tcpip/link/sharedmem/BUILD deleted file mode 100644 index 4215ee852..000000000 --- a/pkg/tcpip/link/sharedmem/BUILD +++ /dev/null @@ -1,43 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "sharedmem", - srcs = [ - "rx.go", - "sharedmem.go", - "sharedmem_unsafe.go", - "tx.go", - ], - visibility = ["//visibility:public"], - deps = [ - "//pkg/log", - "//pkg/sync", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/link/rawfile", - "//pkg/tcpip/link/sharedmem/queue", - "//pkg/tcpip/stack", - "@org_golang_x_sys//unix:go_default_library", - ], -) - -go_test( - name = "sharedmem_test", - srcs = [ - "sharedmem_test.go", - ], - library = ":sharedmem", - deps = [ - "//pkg/sync", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/link/sharedmem/pipe", - "//pkg/tcpip/link/sharedmem/queue", - "//pkg/tcpip/stack", - "@org_golang_x_sys//unix:go_default_library", - ], -) diff --git a/pkg/tcpip/link/sharedmem/pipe/BUILD b/pkg/tcpip/link/sharedmem/pipe/BUILD deleted file mode 100644 index 87020ec08..000000000 --- a/pkg/tcpip/link/sharedmem/pipe/BUILD +++ /dev/null @@ -1,23 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "pipe", - srcs = [ - "pipe.go", - "pipe_unsafe.go", - "rx.go", - "tx.go", - ], - visibility = ["//visibility:public"], -) - -go_test( - name = "pipe_test", - srcs = [ - "pipe_test.go", - ], - library = ":pipe", - deps = ["//pkg/sync"], -) diff --git a/pkg/tcpip/link/sharedmem/pipe/pipe_state_autogen.go b/pkg/tcpip/link/sharedmem/pipe/pipe_state_autogen.go new file mode 100644 index 000000000..d3b40feb4 --- /dev/null +++ b/pkg/tcpip/link/sharedmem/pipe/pipe_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package pipe diff --git a/pkg/tcpip/link/sharedmem/pipe/pipe_test.go b/pkg/tcpip/link/sharedmem/pipe/pipe_test.go deleted file mode 100644 index 2777f1411..000000000 --- a/pkg/tcpip/link/sharedmem/pipe/pipe_test.go +++ /dev/null @@ -1,512 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package pipe - -import ( - "math/rand" - "reflect" - "runtime" - "testing" - - "gvisor.dev/gvisor/pkg/sync" -) - -func TestSimpleReadWrite(t *testing.T) { - // Check that a simple write can be properly read from the rx side. - tr := rand.New(rand.NewSource(99)) - rr := rand.New(rand.NewSource(99)) - - b := make([]byte, 100) - var tx Tx - tx.Init(b) - - wb := tx.Push(10) - if wb == nil { - t.Fatalf("Push failed on empty pipe") - } - for i := range wb { - wb[i] = byte(tr.Intn(256)) - } - tx.Flush() - - var rx Rx - rx.Init(b) - rb := rx.Pull() - if len(rb) != 10 { - t.Fatalf("Bad buffer size returned: got %v, want %v", len(rb), 10) - } - - for i := range rb { - if v := byte(rr.Intn(256)); v != rb[i] { - t.Fatalf("Bad read buffer at index %v: got %v, want %v", i, rb[i], v) - } - } - rx.Flush() -} - -func TestEmptyRead(t *testing.T) { - // Check that pulling from an empty pipe fails. - b := make([]byte, 100) - var tx Tx - tx.Init(b) - - var rx Rx - rx.Init(b) - if rb := rx.Pull(); rb != nil { - t.Fatalf("Pull succeeded on empty pipe") - } -} - -func TestTooLargeWrite(t *testing.T) { - // Check that writes that are too large are properly rejected. - b := make([]byte, 96) - var tx Tx - tx.Init(b) - - if wb := tx.Push(96); wb != nil { - t.Fatalf("Write of 96 bytes succeeded on 96-byte pipe") - } - - if wb := tx.Push(88); wb != nil { - t.Fatalf("Write of 88 bytes succeeded on 96-byte pipe") - } - - if wb := tx.Push(80); wb == nil { - t.Fatalf("Write of 80 bytes failed on 96-byte pipe") - } -} - -func TestFullWrite(t *testing.T) { - // Check that writes fail when the pipe is full. - b := make([]byte, 100) - var tx Tx - tx.Init(b) - - if wb := tx.Push(80); wb == nil { - t.Fatalf("Write of 80 bytes failed on 96-byte pipe") - } - - if wb := tx.Push(1); wb != nil { - t.Fatalf("Write succeeded on full pipe") - } -} - -func TestFullAndFlushedWrite(t *testing.T) { - // Check that writes fail when the pipe is full and has already been - // flushed. - b := make([]byte, 100) - var tx Tx - tx.Init(b) - - if wb := tx.Push(80); wb == nil { - t.Fatalf("Write of 80 bytes failed on 96-byte pipe") - } - - tx.Flush() - - if wb := tx.Push(1); wb != nil { - t.Fatalf("Write succeeded on full pipe") - } -} - -func TestTxFlushTwice(t *testing.T) { - // Checks that a second consecutive tx flush is a no-op. - b := make([]byte, 100) - var tx Tx - tx.Init(b) - - if wb := tx.Push(50); wb == nil { - t.Fatalf("Push failed on empty pipe") - } - tx.Flush() - - // Make copy of original tx queue, flush it, then check that it didn't - // change. - orig := tx - tx.Flush() - - if !reflect.DeepEqual(orig, tx) { - t.Fatalf("Flush mutated tx pipe: got %v, want %v", tx, orig) - } -} - -func TestRxFlushTwice(t *testing.T) { - // Checks that a second consecutive rx flush is a no-op. - b := make([]byte, 100) - var tx Tx - tx.Init(b) - - if wb := tx.Push(50); wb == nil { - t.Fatalf("Push failed on empty pipe") - } - tx.Flush() - - var rx Rx - rx.Init(b) - if rb := rx.Pull(); rb == nil { - t.Fatalf("Pull failed on non-empty pipe") - } - rx.Flush() - - // Make copy of original rx queue, flush it, then check that it didn't - // change. - orig := rx - rx.Flush() - - if !reflect.DeepEqual(orig, rx) { - t.Fatalf("Flush mutated rx pipe: got %v, want %v", rx, orig) - } -} - -func TestWrapInMiddleOfTransaction(t *testing.T) { - // Check that writes are not flushed when we need to wrap the buffer - // around. - b := make([]byte, 100) - var tx Tx - tx.Init(b) - - if wb := tx.Push(50); wb == nil { - t.Fatalf("Push failed on empty pipe") - } - tx.Flush() - - var rx Rx - rx.Init(b) - if rb := rx.Pull(); rb == nil { - t.Fatalf("Pull failed on non-empty pipe") - } - rx.Flush() - - // At this point the ring buffer is empty, but the write is at offset - // 64 (50 + sizeOfSlotHeader + padding-for-8-byte-alignment). - if wb := tx.Push(10); wb == nil { - t.Fatalf("Push failed on empty pipe") - } - - if wb := tx.Push(50); wb == nil { - t.Fatalf("Push failed on non-full pipe") - } - - // We haven't flushed yet, so pull must return nil. - if rb := rx.Pull(); rb != nil { - t.Fatalf("Pull succeeded on non-flushed pipe") - } - - tx.Flush() - - // The two buffers must be available now. - if rb := rx.Pull(); rb == nil { - t.Fatalf("Pull failed on non-empty pipe") - } - - if rb := rx.Pull(); rb == nil { - t.Fatalf("Pull failed on non-empty pipe") - } -} - -func TestWriteAbort(t *testing.T) { - // Check that a read fails on a pipe that has had data pushed to it but - // has aborted the push. - b := make([]byte, 100) - var tx Tx - tx.Init(b) - - if wb := tx.Push(10); wb == nil { - t.Fatalf("Write failed on empty pipe") - } - - var rx Rx - rx.Init(b) - if rb := rx.Pull(); rb != nil { - t.Fatalf("Pull succeeded on empty pipe") - } - - tx.Abort() - if rb := rx.Pull(); rb != nil { - t.Fatalf("Pull succeeded on empty pipe") - } -} - -func TestWrappedWriteAbort(t *testing.T) { - // Check that writes are properly aborted even if the writes wrap - // around. - b := make([]byte, 100) - var tx Tx - tx.Init(b) - - if wb := tx.Push(50); wb == nil { - t.Fatalf("Push failed on empty pipe") - } - tx.Flush() - - var rx Rx - rx.Init(b) - if rb := rx.Pull(); rb == nil { - t.Fatalf("Pull failed on non-empty pipe") - } - rx.Flush() - - // At this point the ring buffer is empty, but the write is at offset - // 64 (50 + sizeOfSlotHeader + padding-for-8-byte-alignment). - if wb := tx.Push(10); wb == nil { - t.Fatalf("Push failed on empty pipe") - } - - if wb := tx.Push(50); wb == nil { - t.Fatalf("Push failed on non-full pipe") - } - - // We haven't flushed yet, so pull must return nil. - if rb := rx.Pull(); rb != nil { - t.Fatalf("Pull succeeded on non-flushed pipe") - } - - tx.Abort() - - // The pushes were aborted, so no data should be readable. - if rb := rx.Pull(); rb != nil { - t.Fatalf("Pull succeeded on non-flushed pipe") - } - - // Try the same transactions again, but flush this time. - if wb := tx.Push(10); wb == nil { - t.Fatalf("Push failed on empty pipe") - } - - if wb := tx.Push(50); wb == nil { - t.Fatalf("Push failed on non-full pipe") - } - - tx.Flush() - - // The two buffers must be available now. - if rb := rx.Pull(); rb == nil { - t.Fatalf("Pull failed on non-empty pipe") - } - - if rb := rx.Pull(); rb == nil { - t.Fatalf("Pull failed on non-empty pipe") - } -} - -func TestEmptyReadOnNonFlushedWrite(t *testing.T) { - // Check that a read fails on a pipe that has had data pushed to it - // but not yet flushed. - b := make([]byte, 100) - var tx Tx - tx.Init(b) - - if wb := tx.Push(10); wb == nil { - t.Fatalf("Write failed on empty pipe") - } - - var rx Rx - rx.Init(b) - if rb := rx.Pull(); rb != nil { - t.Fatalf("Pull succeeded on empty pipe") - } - - tx.Flush() - if rb := rx.Pull(); rb == nil { - t.Fatalf("Pull on failed on non-empty pipe") - } -} - -func TestPullAfterPullingEntirePipe(t *testing.T) { - // Check that Pull fails when the pipe is full, but all of it has - // already been pulled but not yet flushed. - b := make([]byte, 100) - var tx Tx - tx.Init(b) - - if wb := tx.Push(50); wb == nil { - t.Fatalf("Push failed on empty pipe") - } - tx.Flush() - - var rx Rx - rx.Init(b) - if rb := rx.Pull(); rb == nil { - t.Fatalf("Pull failed on non-empty pipe") - } - rx.Flush() - - // At this point the ring buffer is empty, but the write is at offset - // 64 (50 + sizeOfSlotHeader + padding-for-8-byte-alignment). Write 3 - // buffers that will fill the pipe. - if wb := tx.Push(10); wb == nil { - t.Fatalf("Push failed on empty pipe") - } - - if wb := tx.Push(20); wb == nil { - t.Fatalf("Push failed on non-full pipe") - } - - if wb := tx.Push(24); wb == nil { - t.Fatalf("Push failed on non-full pipe") - } - - tx.Flush() - - // The three buffers must be available now. - if rb := rx.Pull(); rb == nil { - t.Fatalf("Pull failed on non-empty pipe") - } - - if rb := rx.Pull(); rb == nil { - t.Fatalf("Pull failed on non-empty pipe") - } - - if rb := rx.Pull(); rb == nil { - t.Fatalf("Pull failed on non-empty pipe") - } - - // Fourth pull must fail. - if rb := rx.Pull(); rb != nil { - t.Fatalf("Pull succeeded on empty pipe") - } -} - -func TestNoRoomToWrapOnPush(t *testing.T) { - // Check that Push fails when it tries to allocate room to add a wrap - // message. - b := make([]byte, 100) - var tx Tx - tx.Init(b) - - if wb := tx.Push(50); wb == nil { - t.Fatalf("Push failed on empty pipe") - } - tx.Flush() - - var rx Rx - rx.Init(b) - if rb := rx.Pull(); rb == nil { - t.Fatalf("Pull failed on non-empty pipe") - } - rx.Flush() - - // At this point the ring buffer is empty, but the write is at offset - // 64 (50 + sizeOfSlotHeader + padding-for-8-byte-alignment). Write 20, - // which won't fit (64+20+8+padding = 96, which wouldn't leave room for - // the padding), so it wraps around. - if wb := tx.Push(20); wb == nil { - t.Fatalf("Push failed on empty pipe") - } - - tx.Flush() - - // Buffer offset is at 28. Try to write 70, which would require a wrap - // slot which cannot be created now. - if wb := tx.Push(70); wb != nil { - t.Fatalf("Push succeeded on pipe with no room for wrap message") - } -} - -func TestRxImplicitFlushOfWrapMessage(t *testing.T) { - // Check if the first read is that of a wrapping message, that it gets - // immediately flushed. - b := make([]byte, 100) - var tx Tx - tx.Init(b) - - if wb := tx.Push(50); wb == nil { - t.Fatalf("Push failed on empty pipe") - } - tx.Flush() - - // This will cause a wrapping message to written. - if wb := tx.Push(60); wb != nil { - t.Fatalf("Push succeeded when there is no room in pipe") - } - - var rx Rx - rx.Init(b) - - // Read the first message. - if rb := rx.Pull(); rb == nil { - t.Fatalf("Pull failed on non-empty pipe") - } - rx.Flush() - - // This should fail because of the wrapping message is taking up space. - if wb := tx.Push(60); wb != nil { - t.Fatalf("Push succeeded when there is no room in pipe") - } - - // Try to read the next one. This should consume the wrapping message. - rx.Pull() - - // This must now succeed. - if wb := tx.Push(60); wb == nil { - t.Fatalf("Push failed on empty pipe") - } -} - -func TestConcurrentReaderWriter(t *testing.T) { - // Push a million buffers of random sizes and random contents. Check - // that buffers read match what was written. - tr := rand.New(rand.NewSource(99)) - rr := rand.New(rand.NewSource(99)) - - b := make([]byte, 100) - var tx Tx - tx.Init(b) - - var rx Rx - rx.Init(b) - - const count = 1000000 - var wg sync.WaitGroup - defer wg.Wait() - wg.Add(1) - go func() { - defer wg.Done() - runtime.Gosched() - for i := 0; i < count; i++ { - n := 1 + tr.Intn(80) - wb := tx.Push(uint64(n)) - for wb == nil { - wb = tx.Push(uint64(n)) - } - - for j := range wb { - wb[j] = byte(tr.Intn(256)) - } - - tx.Flush() - } - }() - - for i := 0; i < count; i++ { - n := 1 + rr.Intn(80) - rb := rx.Pull() - for rb == nil { - rb = rx.Pull() - } - - if n != len(rb) { - t.Fatalf("Bad %v-th buffer length: got %v, want %v", i, len(rb), n) - } - - for j := range rb { - if v := byte(rr.Intn(256)); v != rb[j] { - t.Fatalf("Bad %v-th read buffer at index %v: got %v, want %v", i, j, rb[j], v) - } - } - - rx.Flush() - } -} diff --git a/pkg/tcpip/link/sharedmem/pipe/pipe_unsafe_state_autogen.go b/pkg/tcpip/link/sharedmem/pipe/pipe_unsafe_state_autogen.go new file mode 100644 index 000000000..d3b40feb4 --- /dev/null +++ b/pkg/tcpip/link/sharedmem/pipe/pipe_unsafe_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package pipe diff --git a/pkg/tcpip/link/sharedmem/queue/BUILD b/pkg/tcpip/link/sharedmem/queue/BUILD deleted file mode 100644 index 3ba06af73..000000000 --- a/pkg/tcpip/link/sharedmem/queue/BUILD +++ /dev/null @@ -1,27 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "queue", - srcs = [ - "rx.go", - "tx.go", - ], - visibility = ["//visibility:public"], - deps = [ - "//pkg/log", - "//pkg/tcpip/link/sharedmem/pipe", - ], -) - -go_test( - name = "queue_test", - srcs = [ - "queue_test.go", - ], - library = ":queue", - deps = [ - "//pkg/tcpip/link/sharedmem/pipe", - ], -) diff --git a/pkg/tcpip/link/sharedmem/queue/queue_state_autogen.go b/pkg/tcpip/link/sharedmem/queue/queue_state_autogen.go new file mode 100644 index 000000000..563d4fbb4 --- /dev/null +++ b/pkg/tcpip/link/sharedmem/queue/queue_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package queue diff --git a/pkg/tcpip/link/sharedmem/queue/queue_test.go b/pkg/tcpip/link/sharedmem/queue/queue_test.go deleted file mode 100644 index 9a0aad5d7..000000000 --- a/pkg/tcpip/link/sharedmem/queue/queue_test.go +++ /dev/null @@ -1,517 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package queue - -import ( - "encoding/binary" - "reflect" - "testing" - - "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/pipe" -) - -func TestBasicTxQueue(t *testing.T) { - // Tests that a basic transmit on a queue works, and that completion - // gets properly reported as well. - pb1 := make([]byte, 100) - pb2 := make([]byte, 100) - - var rxp pipe.Rx - rxp.Init(pb1) - - var txp pipe.Tx - txp.Init(pb2) - - var q Tx - q.Init(pb1, pb2) - - // Enqueue two buffers. - b := []TxBuffer{ - {nil, 100, 60}, - {nil, 200, 40}, - } - - b[0].Next = &b[1] - - const usedID = 1002 - const usedTotalSize = 100 - if !q.Enqueue(usedID, usedTotalSize, 2, &b[0]) { - t.Fatalf("Enqueue failed on empty queue") - } - - // Check the contents of the pipe. - d := rxp.Pull() - if d == nil { - t.Fatalf("Tx pipe is empty after Enqueue") - } - - want := []byte{ - 234, 3, 0, 0, 0, 0, 0, 0, // id - 100, 0, 0, 0, // total size - 0, 0, 0, 0, // reserved - 100, 0, 0, 0, 0, 0, 0, 0, // offset 1 - 60, 0, 0, 0, // size 1 - 200, 0, 0, 0, 0, 0, 0, 0, // offset 2 - 40, 0, 0, 0, // size 2 - } - - if !reflect.DeepEqual(want, d) { - t.Fatalf("Bad posted packet: got %v, want %v", d, want) - } - - rxp.Flush() - - // Check that there are no completions yet. - if _, ok := q.CompletedPacket(); ok { - t.Fatalf("Packet reported as completed too soon") - } - - // Post a completion. - d = txp.Push(8) - if d == nil { - t.Fatalf("Unable to push to rx pipe") - } - binary.LittleEndian.PutUint64(d, usedID) - txp.Flush() - - // Check that completion is properly reported. - id, ok := q.CompletedPacket() - if !ok { - t.Fatalf("Completion not reported") - } - - if id != usedID { - t.Fatalf("Bad completion id: got %v, want %v", id, usedID) - } -} - -func TestBasicRxQueue(t *testing.T) { - // Tests that a basic receive on a queue works. - pb1 := make([]byte, 100) - pb2 := make([]byte, 100) - - var rxp pipe.Rx - rxp.Init(pb1) - - var txp pipe.Tx - txp.Init(pb2) - - var q Rx - q.Init(pb1, pb2, nil) - - // Post two buffers. - b := []RxBuffer{ - {100, 60, 1077, 0}, - {200, 40, 2123, 0}, - } - - if !q.PostBuffers(b) { - t.Fatalf("PostBuffers failed on empty queue") - } - - // Check the contents of the pipe. - want := [][]byte{ - { - 100, 0, 0, 0, 0, 0, 0, 0, // Offset1 - 60, 0, 0, 0, // Size1 - 0, 0, 0, 0, // Remaining in group 1 - 0, 0, 0, 0, 0, 0, 0, 0, // User data 1 - 53, 4, 0, 0, 0, 0, 0, 0, // ID 1 - }, - { - 200, 0, 0, 0, 0, 0, 0, 0, // Offset2 - 40, 0, 0, 0, // Size2 - 0, 0, 0, 0, // Remaining in group 2 - 0, 0, 0, 0, 0, 0, 0, 0, // User data 2 - 75, 8, 0, 0, 0, 0, 0, 0, // ID 2 - }, - } - - for i := range b { - d := rxp.Pull() - if d == nil { - t.Fatalf("Tx pipe is empty after PostBuffers") - } - - if !reflect.DeepEqual(want[i], d) { - t.Fatalf("Bad posted packet: got %v, want %v", d, want[i]) - } - - rxp.Flush() - } - - // Check that there are no completions. - if _, n := q.Dequeue(nil); n != 0 { - t.Fatalf("Packet reported as received too soon") - } - - // Post a completion. - d := txp.Push(sizeOfConsumedPacketHeader + 2*sizeOfConsumedBuffer) - if d == nil { - t.Fatalf("Unable to push to rx pipe") - } - - copy(d, []byte{ - 100, 0, 0, 0, // packet size - 0, 0, 0, 0, // reserved - - 100, 0, 0, 0, 0, 0, 0, 0, // offset 1 - 60, 0, 0, 0, // size 1 - 0, 0, 0, 0, 0, 0, 0, 0, // user data 1 - 53, 4, 0, 0, 0, 0, 0, 0, // ID 1 - - 200, 0, 0, 0, 0, 0, 0, 0, // offset 2 - 40, 0, 0, 0, // size 2 - 0, 0, 0, 0, 0, 0, 0, 0, // user data 2 - 75, 8, 0, 0, 0, 0, 0, 0, // ID 2 - }) - - txp.Flush() - - // Check that completion is properly reported. - bufs, n := q.Dequeue(nil) - if n != 100 { - t.Fatalf("Bad packet size: got %v, want %v", n, 100) - } - - if !reflect.DeepEqual(bufs, b) { - t.Fatalf("Bad returned buffers: got %v, want %v", bufs, b) - } -} - -func TestBadTxCompletion(t *testing.T) { - // Check that tx completions with bad sizes are properly ignored. - pb1 := make([]byte, 100) - pb2 := make([]byte, 100) - - var rxp pipe.Rx - rxp.Init(pb1) - - var txp pipe.Tx - txp.Init(pb2) - - var q Tx - q.Init(pb1, pb2) - - // Post a completion that is too short, and check that it is ignored. - if d := txp.Push(7); d == nil { - t.Fatalf("Unable to push to rx pipe") - } - txp.Flush() - - if _, ok := q.CompletedPacket(); ok { - t.Fatalf("Bad completion not ignored") - } - - // Post a completion that is too long, and check that it is ignored. - if d := txp.Push(10); d == nil { - t.Fatalf("Unable to push to rx pipe") - } - txp.Flush() - - if _, ok := q.CompletedPacket(); ok { - t.Fatalf("Bad completion not ignored") - } -} - -func TestBadRxCompletion(t *testing.T) { - // Check that bad rx completions are properly ignored. - pb1 := make([]byte, 100) - pb2 := make([]byte, 100) - - var rxp pipe.Rx - rxp.Init(pb1) - - var txp pipe.Tx - txp.Init(pb2) - - var q Rx - q.Init(pb1, pb2, nil) - - // Post a completion that is too short, and check that it is ignored. - if d := txp.Push(7); d == nil { - t.Fatalf("Unable to push to rx pipe") - } - txp.Flush() - - if b, _ := q.Dequeue(nil); b != nil { - t.Fatalf("Bad completion not ignored") - } - - // Post a completion whose buffer sizes add up to less than the total - // size. - d := txp.Push(sizeOfConsumedPacketHeader + 2*sizeOfConsumedBuffer) - if d == nil { - t.Fatalf("Unable to push to rx pipe") - } - - copy(d, []byte{ - 100, 0, 0, 0, // packet size - 0, 0, 0, 0, // reserved - - 100, 0, 0, 0, 0, 0, 0, 0, // offset 1 - 10, 0, 0, 0, // size 1 - 0, 0, 0, 0, 0, 0, 0, 0, // user data 1 - 53, 4, 0, 0, 0, 0, 0, 0, // ID 1 - - 200, 0, 0, 0, 0, 0, 0, 0, // offset 2 - 10, 0, 0, 0, // size 2 - 0, 0, 0, 0, 0, 0, 0, 0, // user data 2 - 75, 8, 0, 0, 0, 0, 0, 0, // ID 2 - }) - - txp.Flush() - if b, _ := q.Dequeue(nil); b != nil { - t.Fatalf("Bad completion not ignored") - } - - // Post a completion whose buffer sizes will cause a 32-bit overflow, - // but adds up to the right number. - d = txp.Push(sizeOfConsumedPacketHeader + 2*sizeOfConsumedBuffer) - if d == nil { - t.Fatalf("Unable to push to rx pipe") - } - - copy(d, []byte{ - 100, 0, 0, 0, // packet size - 0, 0, 0, 0, // reserved - - 100, 0, 0, 0, 0, 0, 0, 0, // offset 1 - 255, 255, 255, 255, // size 1 - 0, 0, 0, 0, 0, 0, 0, 0, // user data 1 - 53, 4, 0, 0, 0, 0, 0, 0, // ID 1 - - 200, 0, 0, 0, 0, 0, 0, 0, // offset 2 - 101, 0, 0, 0, // size 2 - 0, 0, 0, 0, 0, 0, 0, 0, // user data 2 - 75, 8, 0, 0, 0, 0, 0, 0, // ID 2 - }) - - txp.Flush() - if b, _ := q.Dequeue(nil); b != nil { - t.Fatalf("Bad completion not ignored") - } -} - -func TestFillTxPipe(t *testing.T) { - // Check that transmitting a new buffer when the buffer pipe is full - // fails gracefully. - pb1 := make([]byte, 104) - pb2 := make([]byte, 104) - - var rxp pipe.Rx - rxp.Init(pb1) - - var txp pipe.Tx - txp.Init(pb2) - - var q Tx - q.Init(pb1, pb2) - - // Transmit twice, which should fill the tx pipe. - b := []TxBuffer{ - {nil, 100, 60}, - {nil, 200, 40}, - } - - b[0].Next = &b[1] - - const usedID = 1002 - const usedTotalSize = 100 - for i := uint64(0); i < 2; i++ { - if !q.Enqueue(usedID+i, usedTotalSize, 2, &b[0]) { - t.Fatalf("Failed to transmit buffer") - } - } - - // Transmit another packet now that the tx pipe is full. - if q.Enqueue(usedID+2, usedTotalSize, 2, &b[0]) { - t.Fatalf("Enqueue succeeded when tx pipe is full") - } -} - -func TestFillRxPipe(t *testing.T) { - // Check that posting a new buffer when the buffer pipe is full fails - // gracefully. - pb1 := make([]byte, 100) - pb2 := make([]byte, 100) - - var rxp pipe.Rx - rxp.Init(pb1) - - var txp pipe.Tx - txp.Init(pb2) - - var q Rx - q.Init(pb1, pb2, nil) - - // Post a buffer twice, it should fill the tx pipe. - b := []RxBuffer{ - {100, 60, 1077, 0}, - } - - for i := 0; i < 2; i++ { - if !q.PostBuffers(b) { - t.Fatalf("PostBuffers failed on non-full queue") - } - } - - // Post another buffer now that the tx pipe is full. - if q.PostBuffers(b) { - t.Fatalf("PostBuffers succeeded on full queue") - } -} - -func TestLotsOfTransmissions(t *testing.T) { - // Make sure pipes are being properly flushed when transmitting packets. - pb1 := make([]byte, 100) - pb2 := make([]byte, 100) - - var rxp pipe.Rx - rxp.Init(pb1) - - var txp pipe.Tx - txp.Init(pb2) - - var q Tx - q.Init(pb1, pb2) - - // Prepare packet with two buffers. - b := []TxBuffer{ - {nil, 100, 60}, - {nil, 200, 40}, - } - - b[0].Next = &b[1] - - const usedID = 1002 - const usedTotalSize = 100 - - // Post 100000 packets and completions. - for i := 100000; i > 0; i-- { - if !q.Enqueue(usedID, usedTotalSize, 2, &b[0]) { - t.Fatalf("Enqueue failed on non-full queue") - } - - if d := rxp.Pull(); d == nil { - t.Fatalf("Tx pipe is empty after Enqueue") - } - rxp.Flush() - - d := txp.Push(8) - if d == nil { - t.Fatalf("Unable to write to rx pipe") - } - binary.LittleEndian.PutUint64(d, usedID) - txp.Flush() - if _, ok := q.CompletedPacket(); !ok { - t.Fatalf("Completion not returned") - } - } -} - -func TestLotsOfReceptions(t *testing.T) { - // Make sure pipes are being properly flushed when receiving packets. - pb1 := make([]byte, 100) - pb2 := make([]byte, 100) - - var rxp pipe.Rx - rxp.Init(pb1) - - var txp pipe.Tx - txp.Init(pb2) - - var q Rx - q.Init(pb1, pb2, nil) - - // Prepare for posting two buffers. - b := []RxBuffer{ - {100, 60, 1077, 0}, - {200, 40, 2123, 0}, - } - - // Post 100000 buffers and completions. - for i := 100000; i > 0; i-- { - if !q.PostBuffers(b) { - t.Fatalf("PostBuffers failed on non-full queue") - } - - if d := rxp.Pull(); d == nil { - t.Fatalf("Tx pipe is empty after PostBuffers") - } - rxp.Flush() - - if d := rxp.Pull(); d == nil { - t.Fatalf("Tx pipe is empty after PostBuffers") - } - rxp.Flush() - - d := txp.Push(sizeOfConsumedPacketHeader + 2*sizeOfConsumedBuffer) - if d == nil { - t.Fatalf("Unable to push to rx pipe") - } - - copy(d, []byte{ - 100, 0, 0, 0, // packet size - 0, 0, 0, 0, // reserved - - 100, 0, 0, 0, 0, 0, 0, 0, // offset 1 - 60, 0, 0, 0, // size 1 - 0, 0, 0, 0, 0, 0, 0, 0, // user data 1 - 53, 4, 0, 0, 0, 0, 0, 0, // ID 1 - - 200, 0, 0, 0, 0, 0, 0, 0, // offset 2 - 40, 0, 0, 0, // size 2 - 0, 0, 0, 0, 0, 0, 0, 0, // user data 2 - 75, 8, 0, 0, 0, 0, 0, 0, // ID 2 - }) - - txp.Flush() - - if _, n := q.Dequeue(nil); n == 0 { - t.Fatalf("Dequeue failed when there is a completion") - } - } -} - -func TestRxEnableNotification(t *testing.T) { - // Check that enabling nofifications results in properly updated state. - pb1 := make([]byte, 100) - pb2 := make([]byte, 100) - - var state uint32 - var q Rx - q.Init(pb1, pb2, &state) - - q.EnableNotification() - if state != eventFDEnabled { - t.Fatalf("Bad value in shared state: got %v, want %v", state, eventFDEnabled) - } -} - -func TestRxDisableNotification(t *testing.T) { - // Check that disabling nofifications results in properly updated state. - pb1 := make([]byte, 100) - pb2 := make([]byte, 100) - - var state uint32 - var q Rx - q.Init(pb1, pb2, &state) - - q.DisableNotification() - if state != eventFDDisabled { - t.Fatalf("Bad value in shared state: got %v, want %v", state, eventFDDisabled) - } -} diff --git a/pkg/tcpip/link/sharedmem/sharedmem_state_autogen.go b/pkg/tcpip/link/sharedmem/sharedmem_state_autogen.go new file mode 100644 index 000000000..86551c9f5 --- /dev/null +++ b/pkg/tcpip/link/sharedmem/sharedmem_state_autogen.go @@ -0,0 +1,6 @@ +// automatically generated by stateify. + +//go:build linux && linux +// +build linux,linux + +package sharedmem diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go deleted file mode 100644 index d6d953085..000000000 --- a/pkg/tcpip/link/sharedmem/sharedmem_test.go +++ /dev/null @@ -1,815 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//go:build linux -// +build linux - -package sharedmem - -import ( - "bytes" - "io/ioutil" - "math/rand" - "os" - "strings" - "testing" - "time" - - "golang.org/x/sys/unix" - "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/pipe" - "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/queue" - "gvisor.dev/gvisor/pkg/tcpip/stack" -) - -const ( - localLinkAddr = "\xde\xad\xbe\xef\x56\x78" - remoteLinkAddr = "\xde\xad\xbe\xef\x12\x34" - - queueDataSize = 1024 * 1024 - queuePipeSize = 4096 -) - -type queueBuffers struct { - data []byte - rx pipe.Tx - tx pipe.Rx -} - -func initQueue(t *testing.T, q *queueBuffers, c *QueueConfig) { - // Prepare tx pipe. - b, err := getBuffer(c.TxPipeFD) - if err != nil { - t.Fatalf("getBuffer failed: %v", err) - } - q.tx.Init(b) - - // Prepare rx pipe. - b, err = getBuffer(c.RxPipeFD) - if err != nil { - t.Fatalf("getBuffer failed: %v", err) - } - q.rx.Init(b) - - // Get data slice. - q.data, err = getBuffer(c.DataFD) - if err != nil { - t.Fatalf("getBuffer failed: %v", err) - } -} - -func (q *queueBuffers) cleanup() { - unix.Munmap(q.tx.Bytes()) - unix.Munmap(q.rx.Bytes()) - unix.Munmap(q.data) -} - -type packetInfo struct { - addr tcpip.LinkAddress - proto tcpip.NetworkProtocolNumber - data buffer.View - linkHeader buffer.View -} - -type testContext struct { - t *testing.T - ep *endpoint - txCfg QueueConfig - rxCfg QueueConfig - txq queueBuffers - rxq queueBuffers - - packetCh chan struct{} - mu sync.Mutex - packets []packetInfo -} - -func newTestContext(t *testing.T, mtu, bufferSize uint32, addr tcpip.LinkAddress) *testContext { - var err error - c := &testContext{ - t: t, - packetCh: make(chan struct{}, 1000000), - } - c.txCfg = createQueueFDs(t, queueSizes{ - dataSize: queueDataSize, - txPipeSize: queuePipeSize, - rxPipeSize: queuePipeSize, - sharedDataSize: 4096, - }) - - c.rxCfg = createQueueFDs(t, queueSizes{ - dataSize: queueDataSize, - txPipeSize: queuePipeSize, - rxPipeSize: queuePipeSize, - sharedDataSize: 4096, - }) - - initQueue(t, &c.txq, &c.txCfg) - initQueue(t, &c.rxq, &c.rxCfg) - - ep, err := New(mtu, bufferSize, addr, c.txCfg, c.rxCfg) - if err != nil { - t.Fatalf("New failed: %v", err) - } - - c.ep = ep.(*endpoint) - c.ep.Attach(c) - - return c -} - -func (c *testContext) DeliverNetworkPacket(remoteLinkAddr, localLinkAddr tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { - c.mu.Lock() - c.packets = append(c.packets, packetInfo{ - addr: remoteLinkAddr, - proto: proto, - data: pkt.Data().AsRange().ToOwnedView(), - }) - c.mu.Unlock() - - c.packetCh <- struct{}{} -} - -func (c *testContext) DeliverOutboundPacket(remoteLinkAddr, localLinkAddr tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { - panic("unimplemented") -} - -func (c *testContext) cleanup() { - c.ep.Close() - closeFDs(&c.txCfg) - closeFDs(&c.rxCfg) - c.txq.cleanup() - c.rxq.cleanup() -} - -func (c *testContext) waitForPackets(n int, to <-chan time.Time, errorStr string) { - for i := 0; i < n; i++ { - select { - case <-c.packetCh: - case <-to: - c.t.Fatalf(errorStr) - } - } -} - -func (c *testContext) pushRxCompletion(size uint32, bs []queue.RxBuffer) { - b := c.rxq.rx.Push(queue.RxCompletionSize(len(bs))) - queue.EncodeRxCompletion(b, size, 0) - for i := range bs { - queue.EncodeRxCompletionBuffer(b, i, queue.RxBuffer{ - Offset: bs[i].Offset, - Size: bs[i].Size, - ID: bs[i].ID, - }) - } -} - -func randomFill(b []byte) { - for i := range b { - b[i] = byte(rand.Intn(256)) - } -} - -func shuffle(b []int) { - for i := len(b) - 1; i >= 0; i-- { - j := rand.Intn(i + 1) - b[i], b[j] = b[j], b[i] - } -} - -func createFile(t *testing.T, size int64, initQueue bool) int { - tmpDir, ok := os.LookupEnv("TEST_TMPDIR") - if !ok { - tmpDir = os.Getenv("TMPDIR") - } - f, err := ioutil.TempFile(tmpDir, "sharedmem_test") - if err != nil { - t.Fatalf("TempFile failed: %v", err) - } - defer f.Close() - unix.Unlink(f.Name()) - - if initQueue { - // Write the "slot-free" flag in the initial queue. - _, err := f.WriteAt([]byte{0, 0, 0, 0, 0, 0, 0, 0x80}, 0) - if err != nil { - t.Fatalf("WriteAt failed: %v", err) - } - } - - fd, err := unix.Dup(int(f.Fd())) - if err != nil { - t.Fatalf("Dup failed: %v", err) - } - - if err := unix.Ftruncate(fd, size); err != nil { - unix.Close(fd) - t.Fatalf("Ftruncate failed: %v", err) - } - - return fd -} - -func closeFDs(c *QueueConfig) { - unix.Close(c.DataFD) - unix.Close(c.EventFD) - unix.Close(c.TxPipeFD) - unix.Close(c.RxPipeFD) - unix.Close(c.SharedDataFD) -} - -type queueSizes struct { - dataSize int64 - txPipeSize int64 - rxPipeSize int64 - sharedDataSize int64 -} - -func createQueueFDs(t *testing.T, s queueSizes) QueueConfig { - fd, _, err := unix.RawSyscall(unix.SYS_EVENTFD2, 0, 0, 0) - if err != 0 { - t.Fatalf("eventfd failed: %v", error(err)) - } - - return QueueConfig{ - EventFD: int(fd), - DataFD: createFile(t, s.dataSize, false), - TxPipeFD: createFile(t, s.txPipeSize, true), - RxPipeFD: createFile(t, s.rxPipeSize, true), - SharedDataFD: createFile(t, s.sharedDataSize, false), - } -} - -// TestSimpleSend sends 1000 packets with random header and payload sizes, -// then checks that the right payload is received on the shared memory queues. -func TestSimpleSend(t *testing.T) { - c := newTestContext(t, 20000, 1500, localLinkAddr) - defer c.cleanup() - - // Prepare route. - var r stack.RouteInfo - r.RemoteLinkAddress = remoteLinkAddr - - for iters := 1000; iters > 0; iters-- { - func() { - hdrLen, dataLen := rand.Intn(10000), rand.Intn(10000) - - // Prepare and send packet. - hdrBuf := buffer.NewView(hdrLen) - randomFill(hdrBuf) - - data := buffer.NewView(dataLen) - randomFill(data) - - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: hdrLen + int(c.ep.MaxHeaderLength()), - Data: data.ToVectorisedView(), - }) - copy(pkt.NetworkHeader().Push(hdrLen), hdrBuf) - - proto := tcpip.NetworkProtocolNumber(rand.Intn(0x10000)) - if err := c.ep.WritePacket(r, proto, pkt); err != nil { - t.Fatalf("WritePacket failed: %v", err) - } - - // Receive packet. - desc := c.txq.tx.Pull() - pi := queue.DecodeTxPacketHeader(desc) - if pi.Reserved != 0 { - t.Fatalf("Reserved value is non-zero: 0x%x", pi.Reserved) - } - contents := make([]byte, 0, pi.Size) - for i := 0; i < pi.BufferCount; i++ { - bi := queue.DecodeTxBufferHeader(desc, i) - contents = append(contents, c.txq.data[bi.Offset:][:bi.Size]...) - } - c.txq.tx.Flush() - - defer func() { - // Tell the endpoint about the completion of the write. - b := c.txq.rx.Push(8) - queue.EncodeTxCompletion(b, pi.ID) - c.txq.rx.Flush() - }() - - // Check the ethernet header. - ethTemplate := make(header.Ethernet, header.EthernetMinimumSize) - ethTemplate.Encode(&header.EthernetFields{ - SrcAddr: localLinkAddr, - DstAddr: remoteLinkAddr, - Type: proto, - }) - if got := contents[:header.EthernetMinimumSize]; !bytes.Equal(got, []byte(ethTemplate)) { - t.Fatalf("Bad ethernet header in packet: got %x, want %x", got, ethTemplate) - } - - // Compare contents skipping the ethernet header added by the - // endpoint. - merged := append(hdrBuf, data...) - if uint32(len(contents)) < pi.Size { - t.Fatalf("Sum of buffers is less than packet size: %v < %v", len(contents), pi.Size) - } - contents = contents[:pi.Size][header.EthernetMinimumSize:] - - if !bytes.Equal(contents, merged) { - t.Fatalf("Buffers are different: got %x (%v bytes), want %x (%v bytes)", contents, len(contents), merged, len(merged)) - } - }() - } -} - -// TestPreserveSrcAddressInSend calls WritePacket once with LocalLinkAddress -// set in Route (using much of the same code as TestSimpleSend), then checks -// that the encoded ethernet header received includes the correct SrcAddr. -func TestPreserveSrcAddressInSend(t *testing.T) { - c := newTestContext(t, 20000, 1500, localLinkAddr) - defer c.cleanup() - - newLocalLinkAddress := tcpip.LinkAddress(strings.Repeat("0xFE", 6)) - // Set both remote and local link address in route. - var r stack.RouteInfo - r.LocalLinkAddress = newLocalLinkAddress - r.RemoteLinkAddress = remoteLinkAddr - - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - // WritePacket panics given a prependable with anything less than - // the minimum size of the ethernet header. - ReserveHeaderBytes: header.EthernetMinimumSize, - }) - - proto := tcpip.NetworkProtocolNumber(rand.Intn(0x10000)) - if err := c.ep.WritePacket(r, proto, pkt); err != nil { - t.Fatalf("WritePacket failed: %v", err) - } - - // Receive packet. - desc := c.txq.tx.Pull() - pi := queue.DecodeTxPacketHeader(desc) - if pi.Reserved != 0 { - t.Fatalf("Reserved value is non-zero: 0x%x", pi.Reserved) - } - contents := make([]byte, 0, pi.Size) - for i := 0; i < pi.BufferCount; i++ { - bi := queue.DecodeTxBufferHeader(desc, i) - contents = append(contents, c.txq.data[bi.Offset:][:bi.Size]...) - } - c.txq.tx.Flush() - - defer func() { - // Tell the endpoint about the completion of the write. - b := c.txq.rx.Push(8) - queue.EncodeTxCompletion(b, pi.ID) - c.txq.rx.Flush() - }() - - // Check that the ethernet header contains the expected SrcAddr. - ethTemplate := make(header.Ethernet, header.EthernetMinimumSize) - ethTemplate.Encode(&header.EthernetFields{ - SrcAddr: newLocalLinkAddress, - DstAddr: remoteLinkAddr, - Type: proto, - }) - if got := contents[:header.EthernetMinimumSize]; !bytes.Equal(got, []byte(ethTemplate)) { - t.Fatalf("Bad ethernet header in packet: got %x, want %x", got, ethTemplate) - } -} - -// TestFillTxQueue sends packets until the queue is full. -func TestFillTxQueue(t *testing.T) { - c := newTestContext(t, 20000, 1500, localLinkAddr) - defer c.cleanup() - - // Prepare to send a packet. - var r stack.RouteInfo - r.RemoteLinkAddress = remoteLinkAddr - - buf := buffer.NewView(100) - - // Each packet is uses no more than 40 bytes, so write that many packets - // until the tx queue if full. - ids := make(map[uint64]struct{}) - for i := queuePipeSize / 40; i > 0; i-- { - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), - Data: buf.ToVectorisedView(), - }) - - if err := c.ep.WritePacket(r, header.IPv4ProtocolNumber, pkt); err != nil { - t.Fatalf("WritePacket failed unexpectedly: %v", err) - } - - // Check that they have different IDs. - desc := c.txq.tx.Pull() - pi := queue.DecodeTxPacketHeader(desc) - if _, ok := ids[pi.ID]; ok { - t.Fatalf("ID (%v) reused", pi.ID) - } - ids[pi.ID] = struct{}{} - } - - // Next attempt to write must fail. - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), - Data: buf.ToVectorisedView(), - }) - err := c.ep.WritePacket(r, header.IPv4ProtocolNumber, pkt) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Fatalf("got WritePacket(...) = %v, want %s", err, &tcpip.ErrWouldBlock{}) - } -} - -// TestFillTxQueueAfterBadCompletion sends a bad completion, then sends packets -// until the queue is full. -func TestFillTxQueueAfterBadCompletion(t *testing.T) { - c := newTestContext(t, 20000, 1500, localLinkAddr) - defer c.cleanup() - - // Send a bad completion. - queue.EncodeTxCompletion(c.txq.rx.Push(8), 1) - c.txq.rx.Flush() - - // Prepare to send a packet. - var r stack.RouteInfo - r.RemoteLinkAddress = remoteLinkAddr - - buf := buffer.NewView(100) - - // Send two packets so that the id slice has at least two slots. - for i := 2; i > 0; i-- { - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), - Data: buf.ToVectorisedView(), - }) - if err := c.ep.WritePacket(r, header.IPv4ProtocolNumber, pkt); err != nil { - t.Fatalf("WritePacket failed unexpectedly: %v", err) - } - } - - // Complete the two writes twice. - for i := 2; i > 0; i-- { - pi := queue.DecodeTxPacketHeader(c.txq.tx.Pull()) - - queue.EncodeTxCompletion(c.txq.rx.Push(8), pi.ID) - queue.EncodeTxCompletion(c.txq.rx.Push(8), pi.ID) - c.txq.rx.Flush() - } - c.txq.tx.Flush() - - // Each packet is uses no more than 40 bytes, so write that many packets - // until the tx queue if full. - ids := make(map[uint64]struct{}) - for i := queuePipeSize / 40; i > 0; i-- { - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), - Data: buf.ToVectorisedView(), - }) - if err := c.ep.WritePacket(r, header.IPv4ProtocolNumber, pkt); err != nil { - t.Fatalf("WritePacket failed unexpectedly: %v", err) - } - - // Check that they have different IDs. - desc := c.txq.tx.Pull() - pi := queue.DecodeTxPacketHeader(desc) - if _, ok := ids[pi.ID]; ok { - t.Fatalf("ID (%v) reused", pi.ID) - } - ids[pi.ID] = struct{}{} - } - - // Next attempt to write must fail. - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), - Data: buf.ToVectorisedView(), - }) - err := c.ep.WritePacket(r, header.IPv4ProtocolNumber, pkt) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Fatalf("got WritePacket(...) = %v, want %s", err, &tcpip.ErrWouldBlock{}) - } -} - -// TestFillTxMemory sends packets until the we run out of shared memory. -func TestFillTxMemory(t *testing.T) { - const bufferSize = 1500 - c := newTestContext(t, 20000, bufferSize, localLinkAddr) - defer c.cleanup() - - // Prepare to send a packet. - var r stack.RouteInfo - r.RemoteLinkAddress = remoteLinkAddr - - buf := buffer.NewView(100) - - // Each packet is uses up one buffer, so write as many as possible until - // we fill the memory. - ids := make(map[uint64]struct{}) - for i := queueDataSize / bufferSize; i > 0; i-- { - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), - Data: buf.ToVectorisedView(), - }) - if err := c.ep.WritePacket(r, header.IPv4ProtocolNumber, pkt); err != nil { - t.Fatalf("WritePacket failed unexpectedly: %v", err) - } - - // Check that they have different IDs. - desc := c.txq.tx.Pull() - pi := queue.DecodeTxPacketHeader(desc) - if _, ok := ids[pi.ID]; ok { - t.Fatalf("ID (%v) reused", pi.ID) - } - ids[pi.ID] = struct{}{} - c.txq.tx.Flush() - } - - // Next attempt to write must fail. - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), - Data: buf.ToVectorisedView(), - }) - err := c.ep.WritePacket(r, header.IPv4ProtocolNumber, pkt) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Fatalf("got WritePacket(...) = %v, want %s", err, &tcpip.ErrWouldBlock{}) - } -} - -// TestFillTxMemoryWithMultiBuffer sends packets until the we run out of -// shared memory for a 2-buffer packet, but still with room for a 1-buffer -// packet. -func TestFillTxMemoryWithMultiBuffer(t *testing.T) { - const bufferSize = 1500 - c := newTestContext(t, 20000, bufferSize, localLinkAddr) - defer c.cleanup() - - // Prepare to send a packet. - var r stack.RouteInfo - r.RemoteLinkAddress = remoteLinkAddr - - buf := buffer.NewView(100) - - // Each packet is uses up one buffer, so write as many as possible - // until there is only one buffer left. - for i := queueDataSize/bufferSize - 1; i > 0; i-- { - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), - Data: buf.ToVectorisedView(), - }) - if err := c.ep.WritePacket(r, header.IPv4ProtocolNumber, pkt); err != nil { - t.Fatalf("WritePacket failed unexpectedly: %v", err) - } - - // Pull the posted buffer. - c.txq.tx.Pull() - c.txq.tx.Flush() - } - - // Attempt to write a two-buffer packet. It must fail. - { - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), - Data: buffer.NewView(bufferSize).ToVectorisedView(), - }) - err := c.ep.WritePacket(r, header.IPv4ProtocolNumber, pkt) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Fatalf("got WritePacket(...) = %v, want %s", err, &tcpip.ErrWouldBlock{}) - } - } - - // Attempt to write the one-buffer packet again. It must succeed. - { - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), - Data: buf.ToVectorisedView(), - }) - if err := c.ep.WritePacket(r, header.IPv4ProtocolNumber, pkt); err != nil { - t.Fatalf("WritePacket failed unexpectedly: %v", err) - } - } -} - -func pollPull(t *testing.T, p *pipe.Rx, to <-chan time.Time, errStr string) []byte { - t.Helper() - - for { - b := p.Pull() - if b != nil { - return b - } - - select { - case <-time.After(10 * time.Millisecond): - case <-to: - t.Fatal(errStr) - } - } -} - -// TestSimpleReceive completes 1000 different receives with random payload and -// random number of buffers. It checks that the contents match the expected -// values. -func TestSimpleReceive(t *testing.T) { - const bufferSize = 1500 - c := newTestContext(t, 20000, bufferSize, localLinkAddr) - defer c.cleanup() - - // Check that buffers have been posted. - limit := c.ep.rx.q.PostedBuffersLimit() - for i := uint64(0); i < limit; i++ { - timeout := time.After(2 * time.Second) - bi := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, timeout, "Timeout waiting for all buffers to be posted")) - - if want := i * bufferSize; want != bi.Offset { - t.Fatalf("Bad posted offset: got %v, want %v", bi.Offset, want) - } - - if want := i; want != bi.ID { - t.Fatalf("Bad posted ID: got %v, want %v", bi.ID, want) - } - - if bufferSize != bi.Size { - t.Fatalf("Bad posted bufferSize: got %v, want %v", bi.Size, bufferSize) - } - } - c.rxq.tx.Flush() - - // Create a slice with the indices 0..limit-1. - idx := make([]int, limit) - for i := range idx { - idx[i] = i - } - - // Complete random packets 1000 times. - for iters := 1000; iters > 0; iters-- { - timeout := time.After(2 * time.Second) - // Prepare a random packet. - shuffle(idx) - n := 1 + rand.Intn(10) - bufs := make([]queue.RxBuffer, n) - contents := make([]byte, bufferSize*n-rand.Intn(500)) - randomFill(contents) - for i := range bufs { - j := idx[i] - bufs[i].Size = bufferSize - bufs[i].Offset = uint64(bufferSize * j) - bufs[i].ID = uint64(j) - - copy(c.rxq.data[bufs[i].Offset:][:bufferSize], contents[i*bufferSize:]) - } - - // Push completion. - c.pushRxCompletion(uint32(len(contents)), bufs) - c.rxq.rx.Flush() - unix.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0}) - - // Wait for packet to be received, then check it. - c.waitForPackets(1, time.After(5*time.Second), "Timeout waiting for packet") - c.mu.Lock() - rcvd := []byte(c.packets[0].data) - c.packets = c.packets[:0] - c.mu.Unlock() - - if contents := contents[header.EthernetMinimumSize:]; !bytes.Equal(contents, rcvd) { - t.Fatalf("Unexpected buffer contents: got %x, want %x", rcvd, contents) - } - - // Check that buffers have been reposted. - for i := range bufs { - bi := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, timeout, "Timeout waiting for buffers to be reposted")) - if bi != bufs[i] { - t.Fatalf("Unexpected buffer reposted: got %x, want %x", bi, bufs[i]) - } - } - c.rxq.tx.Flush() - } -} - -// TestRxBuffersReposted tests that rx buffers get reposted after they have been -// completed. -func TestRxBuffersReposted(t *testing.T) { - const bufferSize = 1500 - c := newTestContext(t, 20000, bufferSize, localLinkAddr) - defer c.cleanup() - - // Receive all posted buffers. - limit := c.ep.rx.q.PostedBuffersLimit() - buffers := make([]queue.RxBuffer, 0, limit) - for i := limit; i > 0; i-- { - timeout := time.After(2 * time.Second) - buffers = append(buffers, queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, timeout, "Timeout waiting for all buffers"))) - } - c.rxq.tx.Flush() - - // Check that all buffers are reposted when individually completed. - for i := range buffers { - timeout := time.After(2 * time.Second) - // Complete the buffer. - c.pushRxCompletion(buffers[i].Size, buffers[i:][:1]) - c.rxq.rx.Flush() - unix.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0}) - - // Wait for it to be reposted. - bi := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, timeout, "Timeout waiting for buffer to be reposted")) - if bi != buffers[i] { - t.Fatalf("Different buffer posted: got %v, want %v", bi, buffers[i]) - } - } - c.rxq.tx.Flush() - - // Check that all buffers are reposted when completed in pairs. - for i := 0; i < len(buffers)/2; i++ { - timeout := time.After(2 * time.Second) - // Complete with two buffers. - c.pushRxCompletion(2*bufferSize, buffers[2*i:][:2]) - c.rxq.rx.Flush() - unix.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0}) - - // Wait for them to be reposted. - for j := 0; j < 2; j++ { - bi := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, timeout, "Timeout waiting for buffer to be reposted")) - if bi != buffers[2*i+j] { - t.Fatalf("Different buffer posted: got %v, want %v", bi, buffers[2*i+j]) - } - } - } - c.rxq.tx.Flush() -} - -// TestReceivePostingIsFull checks that the endpoint will properly handle the -// case when a received buffer cannot be immediately reposted because it hasn't -// been pulled from the tx pipe yet. -func TestReceivePostingIsFull(t *testing.T) { - const bufferSize = 1500 - c := newTestContext(t, 20000, bufferSize, localLinkAddr) - defer c.cleanup() - - // Complete first posted buffer before flushing it from the tx pipe. - first := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, time.After(time.Second), "Timeout waiting for first buffer to be posted")) - c.pushRxCompletion(first.Size, []queue.RxBuffer{first}) - c.rxq.rx.Flush() - unix.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0}) - - // Check that packet is received. - c.waitForPackets(1, time.After(time.Second), "Timeout waiting for completed packet") - - // Complete another buffer. - second := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, time.After(time.Second), "Timeout waiting for second buffer to be posted")) - c.pushRxCompletion(second.Size, []queue.RxBuffer{second}) - c.rxq.rx.Flush() - unix.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0}) - - // Check that no packet is received yet, as the worker is blocked trying - // to repost. - select { - case <-time.After(500 * time.Millisecond): - case <-c.packetCh: - t.Fatalf("Unexpected packet received") - } - - // Flush tx queue, which will allow the first buffer to be reposted, - // and the second completion to be pulled. - c.rxq.tx.Flush() - unix.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0}) - - // Check that second packet completes. - c.waitForPackets(1, time.After(time.Second), "Timeout waiting for second completed packet") -} - -// TestCloseWhileWaitingToPost closes the endpoint while it is waiting to -// repost a buffer. Make sure it backs out. -func TestCloseWhileWaitingToPost(t *testing.T) { - const bufferSize = 1500 - c := newTestContext(t, 20000, bufferSize, localLinkAddr) - cleaned := false - defer func() { - if !cleaned { - c.cleanup() - } - }() - - // Complete first posted buffer before flushing it from the tx pipe. - bi := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, time.After(time.Second), "Timeout waiting for initial buffer to be posted")) - c.pushRxCompletion(bi.Size, []queue.RxBuffer{bi}) - c.rxq.rx.Flush() - unix.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0}) - - // Wait for packet to be indicated. - c.waitForPackets(1, time.After(time.Second), "Timeout waiting for completed packet") - - // Cleanup and wait for worker to complete. - c.cleanup() - cleaned = true - c.ep.Wait() -} diff --git a/pkg/tcpip/link/sharedmem/sharedmem_unsafe_state_autogen.go b/pkg/tcpip/link/sharedmem/sharedmem_unsafe_state_autogen.go new file mode 100644 index 000000000..ac3a66520 --- /dev/null +++ b/pkg/tcpip/link/sharedmem/sharedmem_unsafe_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package sharedmem diff --git a/pkg/tcpip/link/sniffer/BUILD b/pkg/tcpip/link/sniffer/BUILD deleted file mode 100644 index 4aac12a8c..000000000 --- a/pkg/tcpip/link/sniffer/BUILD +++ /dev/null @@ -1,21 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "sniffer", - srcs = [ - "pcap.go", - "sniffer.go", - ], - visibility = ["//visibility:public"], - deps = [ - "//pkg/log", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/header/parse", - "//pkg/tcpip/link/nested", - "//pkg/tcpip/stack", - ], -) diff --git a/pkg/tcpip/link/sniffer/sniffer_state_autogen.go b/pkg/tcpip/link/sniffer/sniffer_state_autogen.go new file mode 100644 index 000000000..8d79defea --- /dev/null +++ b/pkg/tcpip/link/sniffer/sniffer_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package sniffer diff --git a/pkg/tcpip/link/tun/BUILD b/pkg/tcpip/link/tun/BUILD deleted file mode 100644 index c3e4c3455..000000000 --- a/pkg/tcpip/link/tun/BUILD +++ /dev/null @@ -1,42 +0,0 @@ -load("//tools:defs.bzl", "go_library") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -package(licenses = ["notice"]) - -go_template_instance( - name = "tun_endpoint_refs", - out = "tun_endpoint_refs.go", - package = "tun", - prefix = "tunEndpoint", - template = "//pkg/refsvfs2:refs_template", - types = { - "T": "tunEndpoint", - }, -) - -go_library( - name = "tun", - srcs = [ - "device.go", - "protocol.go", - "tun_endpoint_refs.go", - "tun_unsafe.go", - ], - visibility = ["//visibility:public"], - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/errors/linuxerr", - "//pkg/log", - "//pkg/refs", - "//pkg/refsvfs2", - "//pkg/sync", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/link/channel", - "//pkg/tcpip/stack", - "//pkg/waiter", - "@org_golang_x_sys//unix:go_default_library", - ], -) diff --git a/pkg/tcpip/link/tun/tun_endpoint_refs.go b/pkg/tcpip/link/tun/tun_endpoint_refs.go new file mode 100644 index 000000000..a3bee1c05 --- /dev/null +++ b/pkg/tcpip/link/tun/tun_endpoint_refs.go @@ -0,0 +1,140 @@ +package tun + +import ( + "fmt" + "sync/atomic" + + "gvisor.dev/gvisor/pkg/refsvfs2" +) + +// enableLogging indicates whether reference-related events should be logged (with +// stack traces). This is false by default and should only be set to true for +// debugging purposes, as it can generate an extremely large amount of output +// and drastically degrade performance. +const tunEndpointenableLogging = false + +// obj is used to customize logging. Note that we use a pointer to T so that +// we do not copy the entire object when passed as a format parameter. +var tunEndpointobj *tunEndpoint + +// Refs implements refs.RefCounter. It keeps a reference count using atomic +// operations and calls the destructor when the count reaches zero. +// +// NOTE: Do not introduce additional fields to the Refs struct. It is used by +// many filesystem objects, and we want to keep it as small as possible (i.e., +// the same size as using an int64 directly) to avoid taking up extra cache +// space. In general, this template should not be extended at the cost of +// performance. If it does not offer enough flexibility for a particular object +// (example: b/187877947), we should implement the RefCounter/CheckedObject +// interfaces manually. +// +// +stateify savable +type tunEndpointRefs struct { + // refCount is composed of two fields: + // + // [32-bit speculative references]:[32-bit real references] + // + // Speculative references are used for TryIncRef, to avoid a CompareAndSwap + // loop. See IncRef, DecRef and TryIncRef for details of how these fields are + // used. + refCount int64 +} + +// InitRefs initializes r with one reference and, if enabled, activates leak +// checking. +func (r *tunEndpointRefs) InitRefs() { + atomic.StoreInt64(&r.refCount, 1) + refsvfs2.Register(r) +} + +// RefType implements refsvfs2.CheckedObject.RefType. +func (r *tunEndpointRefs) RefType() string { + return fmt.Sprintf("%T", tunEndpointobj)[1:] +} + +// LeakMessage implements refsvfs2.CheckedObject.LeakMessage. +func (r *tunEndpointRefs) LeakMessage() string { + return fmt.Sprintf("[%s %p] reference count of %d instead of 0", r.RefType(), r, r.ReadRefs()) +} + +// LogRefs implements refsvfs2.CheckedObject.LogRefs. +func (r *tunEndpointRefs) LogRefs() bool { + return tunEndpointenableLogging +} + +// ReadRefs returns the current number of references. The returned count is +// inherently racy and is unsafe to use without external synchronization. +func (r *tunEndpointRefs) ReadRefs() int64 { + return atomic.LoadInt64(&r.refCount) +} + +// IncRef implements refs.RefCounter.IncRef. +// +//go:nosplit +func (r *tunEndpointRefs) IncRef() { + v := atomic.AddInt64(&r.refCount, 1) + if tunEndpointenableLogging { + refsvfs2.LogIncRef(r, v) + } + if v <= 1 { + panic(fmt.Sprintf("Incrementing non-positive count %p on %s", r, r.RefType())) + } +} + +// TryIncRef implements refs.TryRefCounter.TryIncRef. +// +// To do this safely without a loop, a speculative reference is first acquired +// on the object. This allows multiple concurrent TryIncRef calls to distinguish +// other TryIncRef calls from genuine references held. +// +//go:nosplit +func (r *tunEndpointRefs) TryIncRef() bool { + const speculativeRef = 1 << 32 + if v := atomic.AddInt64(&r.refCount, speculativeRef); int32(v) == 0 { + + atomic.AddInt64(&r.refCount, -speculativeRef) + return false + } + + v := atomic.AddInt64(&r.refCount, -speculativeRef+1) + if tunEndpointenableLogging { + refsvfs2.LogTryIncRef(r, v) + } + return true +} + +// DecRef implements refs.RefCounter.DecRef. +// +// Note that speculative references are counted here. Since they were added +// prior to real references reaching zero, they will successfully convert to +// real references. In other words, we see speculative references only in the +// following case: +// +// A: TryIncRef [speculative increase => sees non-negative references] +// B: DecRef [real decrease] +// A: TryIncRef [transform speculative to real] +// +//go:nosplit +func (r *tunEndpointRefs) DecRef(destroy func()) { + v := atomic.AddInt64(&r.refCount, -1) + if tunEndpointenableLogging { + refsvfs2.LogDecRef(r, v) + } + switch { + case v < 0: + panic(fmt.Sprintf("Decrementing non-positive ref count %p, owned by %s", r, r.RefType())) + + case v == 0: + refsvfs2.Unregister(r) + + if destroy != nil { + destroy() + } + } +} + +func (r *tunEndpointRefs) afterLoad() { + if r.ReadRefs() > 0 { + refsvfs2.Register(r) + } +} diff --git a/pkg/tcpip/link/tun/tun_state_autogen.go b/pkg/tcpip/link/tun/tun_state_autogen.go new file mode 100644 index 000000000..c5773cc11 --- /dev/null +++ b/pkg/tcpip/link/tun/tun_state_autogen.go @@ -0,0 +1,68 @@ +// automatically generated by stateify. + +package tun + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (d *Device) StateTypeName() string { + return "pkg/tcpip/link/tun.Device" +} + +func (d *Device) StateFields() []string { + return []string{ + "Queue", + "endpoint", + "notifyHandle", + "flags", + } +} + +// +checklocksignore +func (d *Device) StateSave(stateSinkObject state.Sink) { + d.beforeSave() + stateSinkObject.Save(0, &d.Queue) + stateSinkObject.Save(1, &d.endpoint) + stateSinkObject.Save(2, &d.notifyHandle) + stateSinkObject.Save(3, &d.flags) +} + +func (d *Device) afterLoad() {} + +// +checklocksignore +func (d *Device) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &d.Queue) + stateSourceObject.Load(1, &d.endpoint) + stateSourceObject.Load(2, &d.notifyHandle) + stateSourceObject.Load(3, &d.flags) +} + +func (r *tunEndpointRefs) StateTypeName() string { + return "pkg/tcpip/link/tun.tunEndpointRefs" +} + +func (r *tunEndpointRefs) StateFields() []string { + return []string{ + "refCount", + } +} + +func (r *tunEndpointRefs) beforeSave() {} + +// +checklocksignore +func (r *tunEndpointRefs) StateSave(stateSinkObject state.Sink) { + r.beforeSave() + stateSinkObject.Save(0, &r.refCount) +} + +// +checklocksignore +func (r *tunEndpointRefs) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &r.refCount) + stateSourceObject.AfterLoad(r.afterLoad) +} + +func init() { + state.Register((*Device)(nil)) + state.Register((*tunEndpointRefs)(nil)) +} diff --git a/pkg/tcpip/link/tun/tun_unsafe_state_autogen.go b/pkg/tcpip/link/tun/tun_unsafe_state_autogen.go new file mode 100644 index 000000000..8d82ad324 --- /dev/null +++ b/pkg/tcpip/link/tun/tun_unsafe_state_autogen.go @@ -0,0 +1,6 @@ +// automatically generated by stateify. + +//go:build linux +// +build linux + +package tun diff --git a/pkg/tcpip/link/waitable/BUILD b/pkg/tcpip/link/waitable/BUILD deleted file mode 100644 index b8d417b7d..000000000 --- a/pkg/tcpip/link/waitable/BUILD +++ /dev/null @@ -1,30 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "waitable", - srcs = [ - "waitable.go", - ], - visibility = ["//visibility:public"], - deps = [ - "//pkg/sync", - "//pkg/tcpip", - "//pkg/tcpip/header", - "//pkg/tcpip/stack", - ], -) - -go_test( - name = "waitable_test", - srcs = [ - "waitable_test.go", - ], - library = ":waitable", - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/header", - "//pkg/tcpip/stack", - ], -) diff --git a/pkg/tcpip/link/waitable/waitable_state_autogen.go b/pkg/tcpip/link/waitable/waitable_state_autogen.go new file mode 100644 index 000000000..059424fa0 --- /dev/null +++ b/pkg/tcpip/link/waitable/waitable_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package waitable diff --git a/pkg/tcpip/link/waitable/waitable_test.go b/pkg/tcpip/link/waitable/waitable_test.go deleted file mode 100644 index a71400ee9..000000000 --- a/pkg/tcpip/link/waitable/waitable_test.go +++ /dev/null @@ -1,182 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package waitable - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/stack" -) - -type countedEndpoint struct { - dispatchCount int - writeCount int - attachCount int - - mtu uint32 - capabilities stack.LinkEndpointCapabilities - hdrLen uint16 - linkAddr tcpip.LinkAddress - - dispatcher stack.NetworkDispatcher -} - -func (e *countedEndpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { - e.dispatchCount++ -} - -func (e *countedEndpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { - panic("unimplemented") -} - -func (e *countedEndpoint) Attach(dispatcher stack.NetworkDispatcher) { - e.attachCount++ - e.dispatcher = dispatcher -} - -// IsAttached implements stack.LinkEndpoint.IsAttached. -func (e *countedEndpoint) IsAttached() bool { - return e.dispatcher != nil -} - -func (e *countedEndpoint) MTU() uint32 { - return e.mtu -} - -func (e *countedEndpoint) Capabilities() stack.LinkEndpointCapabilities { - return e.capabilities -} - -func (e *countedEndpoint) MaxHeaderLength() uint16 { - return e.hdrLen -} - -func (e *countedEndpoint) LinkAddress() tcpip.LinkAddress { - return e.linkAddr -} - -func (e *countedEndpoint) WritePacket(stack.RouteInfo, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) tcpip.Error { - e.writeCount++ - return nil -} - -// WritePackets implements stack.LinkEndpoint.WritePackets. -func (e *countedEndpoint) WritePackets(_ stack.RouteInfo, pkts stack.PacketBufferList, _ tcpip.NetworkProtocolNumber) (int, tcpip.Error) { - e.writeCount += pkts.Len() - return pkts.Len(), nil -} - -// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. -func (*countedEndpoint) ARPHardwareType() header.ARPHardwareType { - panic("unimplemented") -} - -// Wait implements stack.LinkEndpoint.Wait. -func (*countedEndpoint) Wait() {} - -// AddHeader implements stack.LinkEndpoint.AddHeader. -func (e *countedEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { - panic("unimplemented") -} - -func TestWaitWrite(t *testing.T) { - ep := &countedEndpoint{} - wep := New(ep) - - // Write and check that it goes through. - wep.WritePacket(stack.RouteInfo{}, 0, stack.NewPacketBuffer(stack.PacketBufferOptions{})) - if want := 1; ep.writeCount != want { - t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want) - } - - // Wait on dispatches, then try to write. It must go through. - wep.WaitDispatch() - wep.WritePacket(stack.RouteInfo{}, 0, stack.NewPacketBuffer(stack.PacketBufferOptions{})) - if want := 2; ep.writeCount != want { - t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want) - } - - // Wait on writes, then try to write. It must not go through. - wep.WaitWrite() - wep.WritePacket(stack.RouteInfo{}, 0, stack.NewPacketBuffer(stack.PacketBufferOptions{})) - if want := 2; ep.writeCount != want { - t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want) - } -} - -func TestWaitDispatch(t *testing.T) { - ep := &countedEndpoint{} - wep := New(ep) - - // Check that attach happens. - wep.Attach(ep) - if want := 1; ep.attachCount != want { - t.Fatalf("Unexpected attachCount: got=%v, want=%v", ep.attachCount, want) - } - - // Dispatch and check that it goes through. - ep.dispatcher.DeliverNetworkPacket("", "", 0, stack.NewPacketBuffer(stack.PacketBufferOptions{})) - if want := 1; ep.dispatchCount != want { - t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want) - } - - // Wait on writes, then try to dispatch. It must go through. - wep.WaitWrite() - ep.dispatcher.DeliverNetworkPacket("", "", 0, stack.NewPacketBuffer(stack.PacketBufferOptions{})) - if want := 2; ep.dispatchCount != want { - t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want) - } - - // Wait on dispatches, then try to dispatch. It must not go through. - wep.WaitDispatch() - ep.dispatcher.DeliverNetworkPacket("", "", 0, stack.NewPacketBuffer(stack.PacketBufferOptions{})) - if want := 2; ep.dispatchCount != want { - t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want) - } -} - -func TestOtherMethods(t *testing.T) { - const ( - mtu = 0xdead - capabilities = 0xbeef - hdrLen = 0x1234 - linkAddr = "test address" - ) - ep := &countedEndpoint{ - mtu: mtu, - capabilities: capabilities, - hdrLen: hdrLen, - linkAddr: linkAddr, - } - wep := New(ep) - - if v := wep.MTU(); v != mtu { - t.Fatalf("Unexpected mtu: got=%v, want=%v", v, mtu) - } - - if v := wep.Capabilities(); v != capabilities { - t.Fatalf("Unexpected capabilities: got=%v, want=%v", v, capabilities) - } - - if v := wep.MaxHeaderLength(); v != hdrLen { - t.Fatalf("Unexpected MaxHeaderLength: got=%v, want=%v", v, hdrLen) - } - - if v := wep.LinkAddress(); v != linkAddr { - t.Fatalf("Unexpected LinkAddress: got=%q, want=%q", v, linkAddr) - } -} diff --git a/pkg/tcpip/network/BUILD b/pkg/tcpip/network/BUILD deleted file mode 100644 index 7b1ff44f4..000000000 --- a/pkg/tcpip/network/BUILD +++ /dev/null @@ -1,30 +0,0 @@ -load("//tools:defs.bzl", "go_test") - -package(licenses = ["notice"]) - -go_test( - name = "ip_test", - size = "small", - srcs = [ - "ip_test.go", - "multicast_group_test.go", - ], - deps = [ - "//pkg/sync", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/checker", - "//pkg/tcpip/faketime", - "//pkg/tcpip/header", - "//pkg/tcpip/link/channel", - "//pkg/tcpip/link/loopback", - "//pkg/tcpip/network/ipv4", - "//pkg/tcpip/network/ipv6", - "//pkg/tcpip/stack", - "//pkg/tcpip/testutil", - "//pkg/tcpip/transport/icmp", - "//pkg/tcpip/transport/tcp", - "//pkg/tcpip/transport/udp", - "@com_github_google_go_cmp//cmp:go_default_library", - ], -) diff --git a/pkg/tcpip/network/arp/BUILD b/pkg/tcpip/network/arp/BUILD deleted file mode 100644 index 6fa1aee18..000000000 --- a/pkg/tcpip/network/arp/BUILD +++ /dev/null @@ -1,53 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "arp", - srcs = [ - "arp.go", - "stats.go", - ], - visibility = ["//visibility:public"], - deps = [ - "//pkg/sync", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/header/parse", - "//pkg/tcpip/network/internal/ip", - "//pkg/tcpip/stack", - ], -) - -go_test( - name = "arp_test", - size = "small", - srcs = ["arp_test.go"], - deps = [ - ":arp", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/faketime", - "//pkg/tcpip/header", - "//pkg/tcpip/link/channel", - "//pkg/tcpip/link/sniffer", - "//pkg/tcpip/network/ipv4", - "//pkg/tcpip/stack", - "//pkg/tcpip/testutil", - "@com_github_google_go_cmp//cmp:go_default_library", - "@com_github_google_go_cmp//cmp/cmpopts:go_default_library", - ], -) - -go_test( - name = "stats_test", - size = "small", - srcs = ["stats_test.go"], - library = ":arp", - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/stack", - "//pkg/tcpip/testutil", - ], -) diff --git a/pkg/tcpip/network/arp/arp_state_autogen.go b/pkg/tcpip/network/arp/arp_state_autogen.go new file mode 100644 index 000000000..5cd8535e3 --- /dev/null +++ b/pkg/tcpip/network/arp/arp_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package arp diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go deleted file mode 100644 index 5fcbfeaa2..000000000 --- a/pkg/tcpip/network/arp/arp_test.go +++ /dev/null @@ -1,680 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package arp_test - -import ( - "fmt" - "testing" - - "github.com/google/go-cmp/cmp" - "github.com/google/go-cmp/cmp/cmpopts" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/faketime" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" - "gvisor.dev/gvisor/pkg/tcpip/network/arp" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/testutil" -) - -const ( - nicID = 1 - - stackLinkAddr = tcpip.LinkAddress("\x0a\x0a\x0b\x0b\x0c\x0c") - remoteLinkAddr = tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06") -) - -var ( - stackAddr = testutil.MustParse4("10.0.0.1") - remoteAddr = testutil.MustParse4("10.0.0.2") - unknownAddr = testutil.MustParse4("10.0.0.3") -) - -type eventType uint8 - -const ( - entryAdded eventType = iota - entryChanged - entryRemoved -) - -func (t eventType) String() string { - switch t { - case entryAdded: - return "add" - case entryChanged: - return "change" - case entryRemoved: - return "remove" - default: - return fmt.Sprintf("unknown (%d)", t) - } -} - -type eventInfo struct { - eventType eventType - nicID tcpip.NICID - entry stack.NeighborEntry -} - -func (e eventInfo) String() string { - return fmt.Sprintf("%s event for NIC #%d, %#v", e.eventType, e.nicID, e.entry) -} - -// arpDispatcher implements NUDDispatcher to validate the dispatching of -// events upon certain NUD state machine events. -type arpDispatcher struct { - // C is where events are queued - C chan eventInfo -} - -var _ stack.NUDDispatcher = (*arpDispatcher)(nil) - -func (d *arpDispatcher) OnNeighborAdded(nicID tcpip.NICID, entry stack.NeighborEntry) { - e := eventInfo{ - eventType: entryAdded, - nicID: nicID, - entry: entry, - } - d.C <- e -} - -func (d *arpDispatcher) OnNeighborChanged(nicID tcpip.NICID, entry stack.NeighborEntry) { - e := eventInfo{ - eventType: entryChanged, - nicID: nicID, - entry: entry, - } - d.C <- e -} - -func (d *arpDispatcher) OnNeighborRemoved(nicID tcpip.NICID, entry stack.NeighborEntry) { - e := eventInfo{ - eventType: entryRemoved, - nicID: nicID, - entry: entry, - } - d.C <- e -} - -func (d *arpDispatcher) nextEvent() (eventInfo, bool) { - select { - case event := <-d.C: - return event, true - default: - return eventInfo{}, false - } -} - -type testContext struct { - s *stack.Stack - linkEP *channel.Endpoint - nudDisp arpDispatcher -} - -func makeTestContext(t *testing.T, eventDepth int, packetDepth int) testContext { - t.Helper() - - tc := testContext{ - nudDisp: arpDispatcher{ - C: make(chan eventInfo, eventDepth), - }, - } - - tc.s = stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, arp.NewProtocol}, - NUDDisp: &tc.nudDisp, - Clock: &faketime.NullClock{}, - }) - - tc.linkEP = channel.New(packetDepth, header.IPv4MinimumMTU, stackLinkAddr) - tc.linkEP.LinkEPCapabilities |= stack.CapabilityResolutionRequired - - wep := stack.LinkEndpoint(tc.linkEP) - if testing.Verbose() { - wep = sniffer.New(wep) - } - if err := tc.s.CreateNIC(nicID, wep); err != nil { - t.Fatalf("CreateNIC failed: %s", err) - } - - if err := tc.s.AddAddress(nicID, ipv4.ProtocolNumber, stackAddr); err != nil { - t.Fatalf("AddAddress for ipv4 failed: %s", err) - } - - tc.s.SetRouteTable([]tcpip.Route{{ - Destination: header.IPv4EmptySubnet, - NIC: nicID, - }}) - - return tc -} - -func (c *testContext) cleanup() { - c.linkEP.Close() -} - -func TestMalformedPacket(t *testing.T) { - c := makeTestContext(t, 0, 0) - defer c.cleanup() - - v := make(buffer.View, header.ARPSize) - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: v.ToVectorisedView(), - }) - - c.linkEP.InjectInbound(arp.ProtocolNumber, pkt) - - if got := c.s.Stats().ARP.PacketsReceived.Value(); got != 1 { - t.Errorf("got c.s.Stats().ARP.PacketsReceived.Value() = %d, want = 1", got) - } - if got := c.s.Stats().ARP.MalformedPacketsReceived.Value(); got != 1 { - t.Errorf("got c.s.Stats().ARP.MalformedPacketsReceived.Value() = %d, want = 1", got) - } -} - -func TestDisabledEndpoint(t *testing.T) { - c := makeTestContext(t, 0, 0) - defer c.cleanup() - - ep, err := c.s.GetNetworkEndpoint(nicID, header.ARPProtocolNumber) - if err != nil { - t.Fatalf("GetNetworkEndpoint(%d, header.ARPProtocolNumber) failed: %s", nicID, err) - } - ep.Disable() - - v := make(buffer.View, header.ARPSize) - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: v.ToVectorisedView(), - }) - - c.linkEP.InjectInbound(arp.ProtocolNumber, pkt) - - if got := c.s.Stats().ARP.PacketsReceived.Value(); got != 1 { - t.Errorf("got c.s.Stats().ARP.PacketsReceived.Value() = %d, want = 1", got) - } - if got := c.s.Stats().ARP.DisabledPacketsReceived.Value(); got != 1 { - t.Errorf("got c.s.Stats().ARP.DisabledPacketsReceived.Value() = %d, want = 1", got) - } -} - -func TestDirectReply(t *testing.T) { - c := makeTestContext(t, 0, 0) - defer c.cleanup() - - const senderMAC = "\x01\x02\x03\x04\x05\x06" - const senderIPv4 = "\x0a\x00\x00\x02" - - v := make(buffer.View, header.ARPSize) - h := header.ARP(v) - h.SetIPv4OverEthernet() - h.SetOp(header.ARPReply) - - copy(h.HardwareAddressSender(), senderMAC) - copy(h.ProtocolAddressSender(), senderIPv4) - copy(h.HardwareAddressTarget(), stackLinkAddr) - copy(h.ProtocolAddressTarget(), stackAddr) - - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: v.ToVectorisedView(), - }) - - c.linkEP.InjectInbound(arp.ProtocolNumber, pkt) - - if got := c.s.Stats().ARP.PacketsReceived.Value(); got != 1 { - t.Errorf("got c.s.Stats().ARP.PacketsReceived.Value() = %d, want = 1", got) - } - if got := c.s.Stats().ARP.RepliesReceived.Value(); got != 1 { - t.Errorf("got c.s.Stats().ARP.PacketsReceived.Value() = %d, want = 1", got) - } -} - -func TestDirectRequest(t *testing.T) { - c := makeTestContext(t, 1, 1) - defer c.cleanup() - - tests := []struct { - name string - senderAddr tcpip.Address - senderLinkAddr tcpip.LinkAddress - targetAddr tcpip.Address - isValid bool - }{ - { - name: "Loopback", - senderAddr: stackAddr, - senderLinkAddr: stackLinkAddr, - targetAddr: stackAddr, - isValid: true, - }, - { - name: "Remote", - senderAddr: remoteAddr, - senderLinkAddr: remoteLinkAddr, - targetAddr: stackAddr, - isValid: true, - }, - { - name: "RemoteInvalidTarget", - senderAddr: remoteAddr, - senderLinkAddr: remoteLinkAddr, - targetAddr: unknownAddr, - isValid: false, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - packetsRecv := c.s.Stats().ARP.PacketsReceived.Value() - requestsRecv := c.s.Stats().ARP.RequestsReceived.Value() - requestsRecvUnknownAddr := c.s.Stats().ARP.RequestsReceivedUnknownTargetAddress.Value() - outgoingReplies := c.s.Stats().ARP.OutgoingRepliesSent.Value() - - // Inject an incoming ARP request. - v := make(buffer.View, header.ARPSize) - h := header.ARP(v) - h.SetIPv4OverEthernet() - h.SetOp(header.ARPRequest) - copy(h.HardwareAddressSender(), test.senderLinkAddr) - copy(h.ProtocolAddressSender(), test.senderAddr) - copy(h.ProtocolAddressTarget(), test.targetAddr) - c.linkEP.InjectInbound(arp.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: v.ToVectorisedView(), - })) - - if got, want := c.s.Stats().ARP.PacketsReceived.Value(), packetsRecv+1; got != want { - t.Errorf("got c.s.Stats().ARP.PacketsReceived.Value() = %d, want = %d", got, want) - } - if got, want := c.s.Stats().ARP.RequestsReceived.Value(), requestsRecv+1; got != want { - t.Errorf("got c.s.Stats().ARP.PacketsReceived.Value() = %d, want = %d", got, want) - } - - if !test.isValid { - // No packets should be sent after receiving an invalid ARP request. - // There is no need to perform a blocking read here, since packets are - // sent in the same function that handles ARP requests. - if pkt, ok := c.linkEP.Read(); ok { - t.Errorf("unexpected packet sent with network protocol number %d", pkt.Proto) - } - if got, want := c.s.Stats().ARP.RequestsReceivedUnknownTargetAddress.Value(), requestsRecvUnknownAddr+1; got != want { - t.Errorf("got c.s.Stats().ARP.RequestsReceivedUnknownTargetAddress.Value() = %d, want = %d", got, want) - } - if got, want := c.s.Stats().ARP.OutgoingRepliesSent.Value(), outgoingReplies; got != want { - t.Errorf("got c.s.Stats().ARP.OutgoingRepliesSent.Value() = %d, want = %d", got, want) - } - - return - } - - if got, want := c.s.Stats().ARP.OutgoingRepliesSent.Value(), outgoingReplies+1; got != want { - t.Errorf("got c.s.Stats().ARP.OutgoingRepliesSent.Value() = %d, want = %d", got, want) - } - - // Verify an ARP response was sent. - pi, ok := c.linkEP.Read() - if !ok { - t.Fatal("expected ARP response to be sent, got none") - } - - if pi.Proto != arp.ProtocolNumber { - t.Fatalf("expected ARP response, got network protocol number %d", pi.Proto) - } - rep := header.ARP(pi.Pkt.NetworkHeader().View()) - if !rep.IsValid() { - t.Fatalf("invalid ARP response: len = %d; response = %x", len(rep), rep) - } - if got, want := tcpip.LinkAddress(rep.HardwareAddressSender()), stackLinkAddr; got != want { - t.Errorf("got HardwareAddressSender() = %s, want = %s", got, want) - } - if got, want := tcpip.Address(rep.ProtocolAddressSender()), tcpip.Address(h.ProtocolAddressTarget()); got != want { - t.Errorf("got ProtocolAddressSender() = %s, want = %s", got, want) - } - if got, want := tcpip.LinkAddress(rep.HardwareAddressTarget()), tcpip.LinkAddress(h.HardwareAddressSender()); got != want { - t.Errorf("got HardwareAddressTarget() = %s, want = %s", got, want) - } - if got, want := tcpip.Address(rep.ProtocolAddressTarget()), tcpip.Address(h.ProtocolAddressSender()); got != want { - t.Errorf("got ProtocolAddressTarget() = %s, want = %s", got, want) - } - - // Verify the sender was saved in the neighbor cache. - if got, ok := c.nudDisp.nextEvent(); ok { - want := eventInfo{ - eventType: entryAdded, - nicID: nicID, - entry: stack.NeighborEntry{ - Addr: test.senderAddr, - LinkAddr: test.senderLinkAddr, - State: stack.Stale, - }, - } - if diff := cmp.Diff(want, got, cmp.AllowUnexported(eventInfo{}), cmpopts.IgnoreFields(stack.NeighborEntry{}, "UpdatedAt")); diff != "" { - t.Errorf("got invalid event (-want +got):\n%s", diff) - } - } else { - t.Fatal("event didn't arrive") - } - - neighbors, err := c.s.Neighbors(nicID, ipv4.ProtocolNumber) - if err != nil { - t.Fatalf("c.s.Neighbors(%d, %d): %s", nicID, ipv4.ProtocolNumber, err) - } - - neighborByAddr := make(map[tcpip.Address]stack.NeighborEntry) - for _, n := range neighbors { - if existing, ok := neighborByAddr[n.Addr]; ok { - if diff := cmp.Diff(existing, n); diff != "" { - t.Fatalf("duplicate neighbor entry found (-existing +got):\n%s", diff) - } - t.Fatalf("exact neighbor entry duplicate found for addr=%s", n.Addr) - } - neighborByAddr[n.Addr] = n - } - - neigh, ok := neighborByAddr[test.senderAddr] - if !ok { - t.Fatalf("expected neighbor entry with Addr = %s", test.senderAddr) - } - if got, want := neigh.LinkAddr, test.senderLinkAddr; got != want { - t.Errorf("got neighbor LinkAddr = %s, want = %s", got, want) - } - if got, want := neigh.State, stack.Stale; got != want { - t.Errorf("got neighbor State = %s, want = %s", got, want) - } - - // No more events should be dispatched - for { - event, ok := c.nudDisp.nextEvent() - if !ok { - break - } - t.Errorf("unexpected %s", event) - } - }) - } -} - -var _ stack.LinkEndpoint = (*testLinkEndpoint)(nil) - -type testLinkEndpoint struct { - stack.LinkEndpoint - - writeErr tcpip.Error -} - -func (t *testLinkEndpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { - if t.writeErr != nil { - return t.writeErr - } - - return t.LinkEndpoint.WritePacket(r, protocol, pkt) -} - -func TestLinkAddressRequest(t *testing.T) { - const nicID = 1 - - testAddr := tcpip.Address([]byte{1, 2, 3, 4}) - - tests := []struct { - name string - nicAddr tcpip.Address - localAddr tcpip.Address - remoteLinkAddr tcpip.LinkAddress - linkErr tcpip.Error - expectedErr tcpip.Error - expectedLocalAddr tcpip.Address - expectedRemoteLinkAddr tcpip.LinkAddress - expectedRequestsSent uint64 - expectedRequestBadLocalAddressErrors uint64 - expectedRequestInterfaceHasNoLocalAddressErrors uint64 - expectedRequestDroppedErrors uint64 - }{ - { - name: "Unicast", - nicAddr: stackAddr, - localAddr: stackAddr, - remoteLinkAddr: remoteLinkAddr, - expectedLocalAddr: stackAddr, - expectedRemoteLinkAddr: remoteLinkAddr, - expectedRequestsSent: 1, - expectedRequestBadLocalAddressErrors: 0, - expectedRequestInterfaceHasNoLocalAddressErrors: 0, - expectedRequestDroppedErrors: 0, - }, - { - name: "Multicast", - nicAddr: stackAddr, - localAddr: stackAddr, - remoteLinkAddr: "", - expectedLocalAddr: stackAddr, - expectedRemoteLinkAddr: header.EthernetBroadcastAddress, - expectedRequestsSent: 1, - expectedRequestBadLocalAddressErrors: 0, - expectedRequestInterfaceHasNoLocalAddressErrors: 0, - expectedRequestDroppedErrors: 0, - }, - { - name: "Unicast with unspecified source", - nicAddr: stackAddr, - localAddr: "", - remoteLinkAddr: remoteLinkAddr, - expectedLocalAddr: stackAddr, - expectedRemoteLinkAddr: remoteLinkAddr, - expectedRequestsSent: 1, - expectedRequestBadLocalAddressErrors: 0, - expectedRequestInterfaceHasNoLocalAddressErrors: 0, - expectedRequestDroppedErrors: 0, - }, - { - name: "Multicast with unspecified source", - nicAddr: stackAddr, - localAddr: "", - remoteLinkAddr: "", - expectedLocalAddr: stackAddr, - expectedRemoteLinkAddr: header.EthernetBroadcastAddress, - expectedRequestsSent: 1, - expectedRequestBadLocalAddressErrors: 0, - expectedRequestInterfaceHasNoLocalAddressErrors: 0, - expectedRequestDroppedErrors: 0, - }, - { - name: "Unicast with unassigned address", - nicAddr: stackAddr, - localAddr: testAddr, - remoteLinkAddr: remoteLinkAddr, - expectedErr: &tcpip.ErrBadLocalAddress{}, - expectedRequestsSent: 0, - expectedRequestBadLocalAddressErrors: 1, - expectedRequestInterfaceHasNoLocalAddressErrors: 0, - expectedRequestDroppedErrors: 0, - }, - { - name: "Multicast with unassigned address", - nicAddr: stackAddr, - localAddr: testAddr, - remoteLinkAddr: "", - expectedErr: &tcpip.ErrBadLocalAddress{}, - expectedRequestsSent: 0, - expectedRequestBadLocalAddressErrors: 1, - expectedRequestInterfaceHasNoLocalAddressErrors: 0, - expectedRequestDroppedErrors: 0, - }, - { - name: "Unicast with no local address available", - nicAddr: "", - localAddr: "", - remoteLinkAddr: remoteLinkAddr, - expectedErr: &tcpip.ErrNetworkUnreachable{}, - expectedRequestsSent: 0, - expectedRequestBadLocalAddressErrors: 0, - expectedRequestInterfaceHasNoLocalAddressErrors: 1, - expectedRequestDroppedErrors: 0, - }, - { - name: "Multicast with no local address available", - nicAddr: "", - localAddr: "", - remoteLinkAddr: "", - expectedErr: &tcpip.ErrNetworkUnreachable{}, - expectedRequestsSent: 0, - expectedRequestBadLocalAddressErrors: 0, - expectedRequestInterfaceHasNoLocalAddressErrors: 1, - expectedRequestDroppedErrors: 0, - }, - { - name: "Link error", - nicAddr: stackAddr, - localAddr: stackAddr, - remoteLinkAddr: remoteLinkAddr, - linkErr: &tcpip.ErrInvalidEndpointState{}, - expectedErr: &tcpip.ErrInvalidEndpointState{}, - expectedRequestsSent: 0, - expectedRequestBadLocalAddressErrors: 0, - expectedRequestInterfaceHasNoLocalAddressErrors: 0, - expectedRequestDroppedErrors: 1, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol}, - }) - linkEP := channel.New(1, header.IPv4MinimumMTU, stackLinkAddr) - if err := s.CreateNIC(nicID, &testLinkEndpoint{LinkEndpoint: linkEP, writeErr: test.linkErr}); err != nil { - t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) - } - - ep, err := s.GetNetworkEndpoint(nicID, arp.ProtocolNumber) - if err != nil { - t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, arp.ProtocolNumber, err) - } - linkRes, ok := ep.(stack.LinkAddressResolver) - if !ok { - t.Fatalf("expected %T to implement stack.LinkAddressResolver", ep) - } - - if len(test.nicAddr) != 0 { - if err := s.AddAddress(nicID, ipv4.ProtocolNumber, test.nicAddr); err != nil { - t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, test.nicAddr, err) - } - } - - { - err := linkRes.LinkAddressRequest(remoteAddr, test.localAddr, test.remoteLinkAddr) - if diff := cmp.Diff(test.expectedErr, err); diff != "" { - t.Fatalf("unexpected error from p.LinkAddressRequest(%s, %s, %s, _), (-want, +got):\n%s", remoteAddr, test.localAddr, test.remoteLinkAddr, diff) - } - } - - if got := s.Stats().ARP.OutgoingRequestsSent.Value(); got != test.expectedRequestsSent { - t.Errorf("got s.Stats().ARP.OutgoingRequestsSent.Value() = %d, want = %d", got, test.expectedRequestsSent) - } - if got := s.Stats().ARP.OutgoingRequestInterfaceHasNoLocalAddressErrors.Value(); got != test.expectedRequestInterfaceHasNoLocalAddressErrors { - t.Errorf("got s.Stats().ARP.OutgoingRequestInterfaceHasNoLocalAddressErrors.Value() = %d, want = %d", got, test.expectedRequestInterfaceHasNoLocalAddressErrors) - } - if got := s.Stats().ARP.OutgoingRequestBadLocalAddressErrors.Value(); got != test.expectedRequestBadLocalAddressErrors { - t.Errorf("got s.Stats().ARP.OutgoingRequestBadLocalAddressErrors.Value() = %d, want = %d", got, test.expectedRequestBadLocalAddressErrors) - } - if got := s.Stats().ARP.OutgoingRequestsDropped.Value(); got != test.expectedRequestDroppedErrors { - t.Errorf("got s.Stats().ARP.OutgoingRequestsDropped.Value() = %d, want = %d", got, test.expectedRequestDroppedErrors) - } - - if test.expectedErr != nil { - return - } - - pkt, ok := linkEP.Read() - if !ok { - t.Fatal("expected to send a link address request") - } - - if pkt.Route.RemoteLinkAddress != test.expectedRemoteLinkAddr { - t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", pkt.Route.RemoteLinkAddress, test.expectedRemoteLinkAddr) - } - - rep := header.ARP(stack.PayloadSince(pkt.Pkt.NetworkHeader())) - if got := rep.Op(); got != header.ARPRequest { - t.Errorf("got Op = %d, want = %d", got, header.ARPRequest) - } - if got := tcpip.LinkAddress(rep.HardwareAddressSender()); got != stackLinkAddr { - t.Errorf("got HardwareAddressSender = %s, want = %s", got, stackLinkAddr) - } - if got := tcpip.Address(rep.ProtocolAddressSender()); got != test.expectedLocalAddr { - t.Errorf("got ProtocolAddressSender = %s, want = %s", got, test.expectedLocalAddr) - } - if got, want := tcpip.LinkAddress(rep.HardwareAddressTarget()), tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00"); got != want { - t.Errorf("got HardwareAddressTarget = %s, want = %s", got, want) - } - if got := tcpip.Address(rep.ProtocolAddressTarget()); got != remoteAddr { - t.Errorf("got ProtocolAddressTarget = %s, want = %s", got, remoteAddr) - } - }) - } -} - -func TestDADARPRequestPacket(t *testing.T) { - clock := faketime.NewManualClock() - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocolWithOptions(arp.Options{ - DADConfigs: stack.DADConfigurations{ - DupAddrDetectTransmits: 1, - }, - }), ipv4.NewProtocol}, - Clock: clock, - }) - e := channel.New(1, header.IPv4MinimumMTU, stackLinkAddr) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) - } - - if res, err := s.CheckDuplicateAddress(nicID, header.IPv4ProtocolNumber, remoteAddr, func(stack.DADResult) {}); err != nil { - t.Fatalf("s.CheckDuplicateAddress(%d, %d, %s, _): %s", nicID, header.IPv4ProtocolNumber, remoteAddr, err) - } else if res != stack.DADStarting { - t.Fatalf("got s.CheckDuplicateAddress(%d, %d, %s, _) = %d, want = %d", nicID, header.IPv4ProtocolNumber, remoteAddr, res, stack.DADStarting) - } - - clock.RunImmediatelyScheduledJobs() - pkt, ok := e.Read() - if !ok { - t.Fatal("expected to send an ARP request") - } - - if pkt.Route.RemoteLinkAddress != header.EthernetBroadcastAddress { - t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", pkt.Route.RemoteLinkAddress, header.EthernetBroadcastAddress) - } - - req := header.ARP(stack.PayloadSince(pkt.Pkt.NetworkHeader())) - if !req.IsValid() { - t.Errorf("got req.IsValid() = false, want = true") - } - if got := req.Op(); got != header.ARPRequest { - t.Errorf("got req.Op() = %d, want = %d", got, header.ARPRequest) - } - if got := tcpip.LinkAddress(req.HardwareAddressSender()); got != stackLinkAddr { - t.Errorf("got req.HardwareAddressSender() = %s, want = %s", got, stackLinkAddr) - } - if got := tcpip.Address(req.ProtocolAddressSender()); got != header.IPv4Any { - t.Errorf("got req.ProtocolAddressSender() = %s, want = %s", got, header.IPv4Any) - } - if got, want := tcpip.LinkAddress(req.HardwareAddressTarget()), tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00"); got != want { - t.Errorf("got req.HardwareAddressTarget() = %s, want = %s", got, want) - } - if got := tcpip.Address(req.ProtocolAddressTarget()); got != remoteAddr { - t.Errorf("got req.ProtocolAddressTarget() = %s, want = %s", got, remoteAddr) - } -} diff --git a/pkg/tcpip/network/arp/stats_test.go b/pkg/tcpip/network/arp/stats_test.go deleted file mode 100644 index 0df39ae81..000000000 --- a/pkg/tcpip/network/arp/stats_test.go +++ /dev/null @@ -1,51 +0,0 @@ -// Copyright 2021 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package arp - -import ( - "reflect" - "testing" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/testutil" -) - -var _ stack.NetworkInterface = (*testInterface)(nil) - -type testInterface struct { - stack.NetworkInterface - nicID tcpip.NICID -} - -func (t *testInterface) ID() tcpip.NICID { - return t.nicID -} - -func TestMultiCounterStatsInitialization(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - }) - proto := s.NetworkProtocolInstance(ProtocolNumber).(*protocol) - var nic testInterface - ep := proto.NewEndpoint(&nic, nil).(*endpoint) - // At this point, the Stack's stats and the NetworkEndpoint's stats are - // expected to be bound by a MultiCounterStat. - refStack := s.Stats() - refEP := ep.stats.localStats - if err := testutil.ValidateMultiCounterStats(reflect.ValueOf(&ep.stats.arp).Elem(), []reflect.Value{reflect.ValueOf(&refEP.ARP).Elem(), reflect.ValueOf(&refStack.ARP).Elem()}); err != nil { - t.Error(err) - } -} diff --git a/pkg/tcpip/network/hash/BUILD b/pkg/tcpip/network/hash/BUILD deleted file mode 100644 index 872165866..000000000 --- a/pkg/tcpip/network/hash/BUILD +++ /dev/null @@ -1,13 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "hash", - srcs = ["hash.go"], - visibility = ["//visibility:public"], - deps = [ - "//pkg/rand", - "//pkg/tcpip/header", - ], -) diff --git a/pkg/tcpip/network/hash/hash_state_autogen.go b/pkg/tcpip/network/hash/hash_state_autogen.go new file mode 100644 index 000000000..9467fe298 --- /dev/null +++ b/pkg/tcpip/network/hash/hash_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package hash diff --git a/pkg/tcpip/network/internal/fragmentation/BUILD b/pkg/tcpip/network/internal/fragmentation/BUILD deleted file mode 100644 index 274f09092..000000000 --- a/pkg/tcpip/network/internal/fragmentation/BUILD +++ /dev/null @@ -1,54 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -package(licenses = ["notice"]) - -go_template_instance( - name = "reassembler_list", - out = "reassembler_list.go", - package = "fragmentation", - prefix = "reassembler", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*reassembler", - "Linker": "*reassembler", - }, -) - -go_library( - name = "fragmentation", - srcs = [ - "fragmentation.go", - "reassembler.go", - "reassembler_list.go", - ], - visibility = [ - "//pkg/tcpip/network/ipv4:__pkg__", - "//pkg/tcpip/network/ipv6:__pkg__", - ], - deps = [ - "//pkg/log", - "//pkg/sync", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/stack", - ], -) - -go_test( - name = "fragmentation_test", - size = "small", - srcs = [ - "fragmentation_test.go", - "reassembler_test.go", - ], - library = ":fragmentation", - deps = [ - "//pkg/tcpip/buffer", - "//pkg/tcpip/faketime", - "//pkg/tcpip/network/internal/testutil", - "//pkg/tcpip/stack", - "@com_github_google_go_cmp//cmp:go_default_library", - ], -) diff --git a/pkg/tcpip/network/internal/fragmentation/fragmentation_state_autogen.go b/pkg/tcpip/network/internal/fragmentation/fragmentation_state_autogen.go new file mode 100644 index 000000000..21c5774e9 --- /dev/null +++ b/pkg/tcpip/network/internal/fragmentation/fragmentation_state_autogen.go @@ -0,0 +1,68 @@ +// automatically generated by stateify. + +package fragmentation + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (l *reassemblerList) StateTypeName() string { + return "pkg/tcpip/network/internal/fragmentation.reassemblerList" +} + +func (l *reassemblerList) StateFields() []string { + return []string{ + "head", + "tail", + } +} + +func (l *reassemblerList) beforeSave() {} + +// +checklocksignore +func (l *reassemblerList) StateSave(stateSinkObject state.Sink) { + l.beforeSave() + stateSinkObject.Save(0, &l.head) + stateSinkObject.Save(1, &l.tail) +} + +func (l *reassemblerList) afterLoad() {} + +// +checklocksignore +func (l *reassemblerList) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &l.head) + stateSourceObject.Load(1, &l.tail) +} + +func (e *reassemblerEntry) StateTypeName() string { + return "pkg/tcpip/network/internal/fragmentation.reassemblerEntry" +} + +func (e *reassemblerEntry) StateFields() []string { + return []string{ + "next", + "prev", + } +} + +func (e *reassemblerEntry) beforeSave() {} + +// +checklocksignore +func (e *reassemblerEntry) StateSave(stateSinkObject state.Sink) { + e.beforeSave() + stateSinkObject.Save(0, &e.next) + stateSinkObject.Save(1, &e.prev) +} + +func (e *reassemblerEntry) afterLoad() {} + +// +checklocksignore +func (e *reassemblerEntry) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &e.next) + stateSourceObject.Load(1, &e.prev) +} + +func init() { + state.Register((*reassemblerList)(nil)) + state.Register((*reassemblerEntry)(nil)) +} diff --git a/pkg/tcpip/network/internal/fragmentation/fragmentation_test.go b/pkg/tcpip/network/internal/fragmentation/fragmentation_test.go deleted file mode 100644 index dadfc28cc..000000000 --- a/pkg/tcpip/network/internal/fragmentation/fragmentation_test.go +++ /dev/null @@ -1,648 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package fragmentation - -import ( - "errors" - "testing" - "time" - - "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/faketime" - "gvisor.dev/gvisor/pkg/tcpip/network/internal/testutil" - "gvisor.dev/gvisor/pkg/tcpip/stack" -) - -// reassembleTimeout is dummy timeout used for testing, where the clock never -// advances. -const reassembleTimeout = 1 - -// vv is a helper to build VectorisedView from different strings. -func vv(size int, pieces ...string) buffer.VectorisedView { - views := make([]buffer.View, len(pieces)) - for i, p := range pieces { - views[i] = []byte(p) - } - - return buffer.NewVectorisedView(size, views) -} - -func pkt(size int, pieces ...string) *stack.PacketBuffer { - return stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: vv(size, pieces...), - }) -} - -type processInput struct { - id FragmentID - first uint16 - last uint16 - more bool - proto uint8 - pkt *stack.PacketBuffer -} - -type processOutput struct { - vv buffer.VectorisedView - proto uint8 - done bool -} - -var processTestCases = []struct { - comment string - in []processInput - out []processOutput -}{ - { - comment: "One ID", - in: []processInput{ - {id: FragmentID{ID: 0}, first: 0, last: 1, more: true, pkt: pkt(2, "01")}, - {id: FragmentID{ID: 0}, first: 2, last: 3, more: false, pkt: pkt(2, "23")}, - }, - out: []processOutput{ - {vv: buffer.VectorisedView{}, done: false}, - {vv: vv(4, "01", "23"), done: true}, - }, - }, - { - comment: "Next Header protocol mismatch", - in: []processInput{ - {id: FragmentID{ID: 0}, first: 0, last: 1, more: true, proto: 6, pkt: pkt(2, "01")}, - {id: FragmentID{ID: 0}, first: 2, last: 3, more: false, proto: 17, pkt: pkt(2, "23")}, - }, - out: []processOutput{ - {vv: buffer.VectorisedView{}, done: false}, - {vv: vv(4, "01", "23"), proto: 6, done: true}, - }, - }, - { - comment: "Two IDs", - in: []processInput{ - {id: FragmentID{ID: 0}, first: 0, last: 1, more: true, pkt: pkt(2, "01")}, - {id: FragmentID{ID: 1}, first: 0, last: 1, more: true, pkt: pkt(2, "ab")}, - {id: FragmentID{ID: 1}, first: 2, last: 3, more: false, pkt: pkt(2, "cd")}, - {id: FragmentID{ID: 0}, first: 2, last: 3, more: false, pkt: pkt(2, "23")}, - }, - out: []processOutput{ - {vv: buffer.VectorisedView{}, done: false}, - {vv: buffer.VectorisedView{}, done: false}, - {vv: vv(4, "ab", "cd"), done: true}, - {vv: vv(4, "01", "23"), done: true}, - }, - }, -} - -func TestFragmentationProcess(t *testing.T) { - for _, c := range processTestCases { - t.Run(c.comment, func(t *testing.T) { - f := NewFragmentation(minBlockSize, 1024, 512, reassembleTimeout, &faketime.NullClock{}, nil) - firstFragmentProto := c.in[0].proto - for i, in := range c.in { - resPkt, proto, done, err := f.Process(in.id, in.first, in.last, in.more, in.proto, in.pkt) - if err != nil { - t.Fatalf("f.Process(%+v, %d, %d, %t, %d, %#v) failed: %s", - in.id, in.first, in.last, in.more, in.proto, in.pkt, err) - } - if done != c.out[i].done { - t.Errorf("got Process(%+v, %d, %d, %t, %d, _) = (_, _, %t, _), want = (_, _, %t, _)", - in.id, in.first, in.last, in.more, in.proto, done, c.out[i].done) - } - if c.out[i].done { - if diff := cmp.Diff(c.out[i].vv.ToOwnedView(), resPkt.Data().AsRange().ToOwnedView()); diff != "" { - t.Errorf("got Process(%+v, %d, %d, %t, %d, %#v) result mismatch (-want, +got):\n%s", - in.id, in.first, in.last, in.more, in.proto, in.pkt, diff) - } - if firstFragmentProto != proto { - t.Errorf("got Process(%+v, %d, %d, %t, %d, _) = (_, %d, _, _), want = (_, %d, _, _)", - in.id, in.first, in.last, in.more, in.proto, proto, firstFragmentProto) - } - if _, ok := f.reassemblers[in.id]; ok { - t.Errorf("Process(%d) did not remove buffer from reassemblers", i) - } - for n := f.rList.Front(); n != nil; n = n.Next() { - if n.id == in.id { - t.Errorf("Process(%d) did not remove buffer from rList", i) - } - } - } - } - }) - } -} - -func TestReassemblingTimeout(t *testing.T) { - const ( - reassemblyTimeout = time.Millisecond - protocol = 0xff - ) - - type fragment struct { - first uint16 - last uint16 - more bool - data string - } - - type event struct { - // name is a nickname of this event. - name string - - // clockAdvance is a duration to advance the clock. The clock advances - // before a fragment specified in the fragment field is processed. - clockAdvance time.Duration - - // fragment is a fragment to process. This can be nil if there is no - // fragment to process. - fragment *fragment - - // expectDone is true if the fragmentation instance should report the - // reassembly is done after the fragment is processd. - expectDone bool - - // memSizeAfterEvent is the expected memory size of the fragmentation - // instance after the event. - memSizeAfterEvent int - } - - memSizeOfFrags := func(frags ...*fragment) int { - var size int - for _, frag := range frags { - size += pkt(len(frag.data), frag.data).MemSize() - } - return size - } - - half1 := &fragment{first: 0, last: 0, more: true, data: "0"} - half2 := &fragment{first: 1, last: 1, more: false, data: "1"} - - tests := []struct { - name string - events []event - }{ - { - name: "half1 and half2 are reassembled successfully", - events: []event{ - { - name: "half1", - fragment: half1, - expectDone: false, - memSizeAfterEvent: memSizeOfFrags(half1), - }, - { - name: "half2", - fragment: half2, - expectDone: true, - memSizeAfterEvent: 0, - }, - }, - }, - { - name: "half1 timeout, half2 timeout", - events: []event{ - { - name: "half1", - fragment: half1, - expectDone: false, - memSizeAfterEvent: memSizeOfFrags(half1), - }, - { - name: "half1 just before reassembly timeout", - clockAdvance: reassemblyTimeout - 1, - memSizeAfterEvent: memSizeOfFrags(half1), - }, - { - name: "half1 reassembly timeout", - clockAdvance: 1, - memSizeAfterEvent: 0, - }, - { - name: "half2", - fragment: half2, - expectDone: false, - memSizeAfterEvent: memSizeOfFrags(half2), - }, - { - name: "half2 just before reassembly timeout", - clockAdvance: reassemblyTimeout - 1, - memSizeAfterEvent: memSizeOfFrags(half2), - }, - { - name: "half2 reassembly timeout", - clockAdvance: 1, - memSizeAfterEvent: 0, - }, - }, - }, - } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - clock := faketime.NewManualClock() - f := NewFragmentation(minBlockSize, HighFragThreshold, LowFragThreshold, reassemblyTimeout, clock, nil) - for _, event := range test.events { - clock.Advance(event.clockAdvance) - if frag := event.fragment; frag != nil { - _, _, done, err := f.Process(FragmentID{}, frag.first, frag.last, frag.more, protocol, pkt(len(frag.data), frag.data)) - if err != nil { - t.Fatalf("%s: f.Process failed: %s", event.name, err) - } - if done != event.expectDone { - t.Fatalf("%s: got done = %t, want = %t", event.name, done, event.expectDone) - } - } - if got, want := f.memSize, event.memSizeAfterEvent; got != want { - t.Errorf("%s: got f.memSize = %d, want = %d", event.name, got, want) - } - } - }) - } -} - -func TestMemoryLimits(t *testing.T) { - lowLimit := pkt(1, "0").MemSize() - highLimit := 3 * lowLimit // Allow at most 3 such packets. - f := NewFragmentation(minBlockSize, highLimit, lowLimit, reassembleTimeout, &faketime.NullClock{}, nil) - // Send first fragment with id = 0. - if _, _, _, err := f.Process(FragmentID{ID: 0}, 0, 0, true, 0xFF, pkt(1, "0")); err != nil { - t.Fatal(err) - } - // Send first fragment with id = 1. - if _, _, _, err := f.Process(FragmentID{ID: 1}, 0, 0, true, 0xFF, pkt(1, "1")); err != nil { - t.Fatal(err) - } - // Send first fragment with id = 2. - if _, _, _, err := f.Process(FragmentID{ID: 2}, 0, 0, true, 0xFF, pkt(1, "2")); err != nil { - t.Fatal(err) - } - - // Send first fragment with id = 3. This should caused id = 0 and id = 1 to be - // evicted. - if _, _, _, err := f.Process(FragmentID{ID: 3}, 0, 0, true, 0xFF, pkt(1, "3")); err != nil { - t.Fatal(err) - } - - if _, ok := f.reassemblers[FragmentID{ID: 0}]; ok { - t.Errorf("Memory limits are not respected: id=0 has not been evicted.") - } - if _, ok := f.reassemblers[FragmentID{ID: 1}]; ok { - t.Errorf("Memory limits are not respected: id=1 has not been evicted.") - } - if _, ok := f.reassemblers[FragmentID{ID: 3}]; !ok { - t.Errorf("Implementation of memory limits is wrong: id=3 is not present.") - } -} - -func TestMemoryLimitsIgnoresDuplicates(t *testing.T) { - memSize := pkt(1, "0").MemSize() - f := NewFragmentation(minBlockSize, memSize, 0, reassembleTimeout, &faketime.NullClock{}, nil) - // Send first fragment with id = 0. - if _, _, _, err := f.Process(FragmentID{}, 0, 0, true, 0xFF, pkt(1, "0")); err != nil { - t.Fatal(err) - } - // Send the same packet again. - if _, _, _, err := f.Process(FragmentID{}, 0, 0, true, 0xFF, pkt(1, "0")); err != nil { - t.Fatal(err) - } - - if got, want := f.memSize, memSize; got != want { - t.Errorf("Wrong size, duplicates are not handled correctly: got=%d, want=%d.", got, want) - } -} - -func TestErrors(t *testing.T) { - tests := []struct { - name string - blockSize uint16 - first uint16 - last uint16 - more bool - data string - err error - }{ - { - name: "exact block size without more", - blockSize: 2, - first: 2, - last: 3, - more: false, - data: "01", - }, - { - name: "exact block size with more", - blockSize: 2, - first: 2, - last: 3, - more: true, - data: "01", - }, - { - name: "exact block size with more and extra data", - blockSize: 2, - first: 2, - last: 3, - more: true, - data: "012", - err: ErrInvalidArgs, - }, - { - name: "exact block size with more and too little data", - blockSize: 2, - first: 2, - last: 3, - more: true, - data: "0", - err: ErrInvalidArgs, - }, - { - name: "not exact block size with more", - blockSize: 2, - first: 2, - last: 2, - more: true, - data: "0", - err: ErrInvalidArgs, - }, - { - name: "not exact block size without more", - blockSize: 2, - first: 2, - last: 2, - more: false, - data: "0", - }, - { - name: "first not a multiple of block size", - blockSize: 2, - first: 3, - last: 4, - more: true, - data: "01", - err: ErrInvalidArgs, - }, - { - name: "first more than last", - blockSize: 2, - first: 4, - last: 3, - more: true, - data: "01", - err: ErrInvalidArgs, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - f := NewFragmentation(test.blockSize, HighFragThreshold, LowFragThreshold, reassembleTimeout, &faketime.NullClock{}, nil) - _, _, done, err := f.Process(FragmentID{}, test.first, test.last, test.more, 0, pkt(len(test.data), test.data)) - if !errors.Is(err, test.err) { - t.Errorf("got Process(_, %d, %d, %t, _, %q) = (_, _, _, %v), want = (_, _, _, %v)", test.first, test.last, test.more, test.data, err, test.err) - } - if done { - t.Errorf("got Process(_, %d, %d, %t, _, %q) = (_, _, true, _), want = (_, _, false, _)", test.first, test.last, test.more, test.data) - } - }) - } -} - -type fragmentInfo struct { - remaining int - copied int - offset int - more bool -} - -func TestPacketFragmenter(t *testing.T) { - const ( - reserve = 60 - proto = 0 - ) - - tests := []struct { - name string - fragmentPayloadLen uint32 - transportHeaderLen int - payloadSize int - wantFragments []fragmentInfo - }{ - { - name: "Packet exactly fits in MTU", - fragmentPayloadLen: 1280, - transportHeaderLen: 0, - payloadSize: 1280, - wantFragments: []fragmentInfo{ - {remaining: 0, copied: 1280, offset: 0, more: false}, - }, - }, - { - name: "Packet exactly does not fit in MTU", - fragmentPayloadLen: 1000, - transportHeaderLen: 0, - payloadSize: 1001, - wantFragments: []fragmentInfo{ - {remaining: 1, copied: 1000, offset: 0, more: true}, - {remaining: 0, copied: 1, offset: 1000, more: false}, - }, - }, - { - name: "Packet has a transport header", - fragmentPayloadLen: 560, - transportHeaderLen: 40, - payloadSize: 560, - wantFragments: []fragmentInfo{ - {remaining: 1, copied: 560, offset: 0, more: true}, - {remaining: 0, copied: 40, offset: 560, more: false}, - }, - }, - { - name: "Packet has a huge transport header", - fragmentPayloadLen: 500, - transportHeaderLen: 1300, - payloadSize: 500, - wantFragments: []fragmentInfo{ - {remaining: 3, copied: 500, offset: 0, more: true}, - {remaining: 2, copied: 500, offset: 500, more: true}, - {remaining: 1, copied: 500, offset: 1000, more: true}, - {remaining: 0, copied: 300, offset: 1500, more: false}, - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - pkt := testutil.MakeRandPkt(test.transportHeaderLen, reserve, []int{test.payloadSize}, proto) - originalPayload := stack.PayloadSince(pkt.TransportHeader()) - var reassembledPayload buffer.VectorisedView - pf := MakePacketFragmenter(pkt, test.fragmentPayloadLen, reserve) - for i := 0; ; i++ { - fragPkt, offset, copied, more := pf.BuildNextFragment() - wantFragment := test.wantFragments[i] - if got := pf.RemainingFragmentCount(); got != wantFragment.remaining { - t.Errorf("(fragment #%d) got pf.RemainingFragmentCount() = %d, want = %d", i, got, wantFragment.remaining) - } - if copied != wantFragment.copied { - t.Errorf("(fragment #%d) got copied = %d, want = %d", i, copied, wantFragment.copied) - } - if offset != wantFragment.offset { - t.Errorf("(fragment #%d) got offset = %d, want = %d", i, offset, wantFragment.offset) - } - if more != wantFragment.more { - t.Errorf("(fragment #%d) got more = %t, want = %t", i, more, wantFragment.more) - } - if got := uint32(fragPkt.Size()); got > test.fragmentPayloadLen { - t.Errorf("(fragment #%d) got fragPkt.Size() = %d, want <= %d", i, got, test.fragmentPayloadLen) - } - if got := fragPkt.AvailableHeaderBytes(); got != reserve { - t.Errorf("(fragment #%d) got fragPkt.AvailableHeaderBytes() = %d, want = %d", i, got, reserve) - } - if got := fragPkt.TransportHeader().View().Size(); got != 0 { - t.Errorf("(fragment #%d) got fragPkt.TransportHeader().View().Size() = %d, want = 0", i, got) - } - reassembledPayload.AppendViews(fragPkt.Data().Views()) - if !more { - if i != len(test.wantFragments)-1 { - t.Errorf("got fragment count = %d, want = %d", i, len(test.wantFragments)-1) - } - break - } - } - if diff := cmp.Diff(reassembledPayload.ToView(), originalPayload); diff != "" { - t.Errorf("reassembledPayload mismatch (-want +got):\n%s", diff) - } - }) - } -} - -type testTimeoutHandler struct { - pkt *stack.PacketBuffer -} - -func (h *testTimeoutHandler) OnReassemblyTimeout(pkt *stack.PacketBuffer) { - h.pkt = pkt -} - -func TestTimeoutHandler(t *testing.T) { - const ( - proto = 99 - ) - - pk1 := pkt(1, "1") - pk2 := pkt(1, "2") - - type processParam struct { - first uint16 - last uint16 - more bool - pkt *stack.PacketBuffer - } - - tests := []struct { - name string - params []processParam - wantError bool - wantPkt *stack.PacketBuffer - }{ - { - name: "onTimeout runs", - params: []processParam{ - { - first: 0, - last: 0, - more: true, - pkt: pk1, - }, - }, - wantError: false, - wantPkt: pk1, - }, - { - name: "no first fragment", - params: []processParam{ - { - first: 1, - last: 1, - more: true, - pkt: pk1, - }, - }, - wantError: false, - wantPkt: nil, - }, - { - name: "second pkt is ignored", - params: []processParam{ - { - first: 0, - last: 0, - more: true, - pkt: pk1, - }, - { - first: 0, - last: 0, - more: true, - pkt: pk2, - }, - }, - wantError: false, - wantPkt: pk1, - }, - { - name: "invalid args - first is greater than last", - params: []processParam{ - { - first: 1, - last: 0, - more: true, - pkt: pk1, - }, - }, - wantError: true, - wantPkt: nil, - }, - } - - id := FragmentID{ID: 0} - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - handler := &testTimeoutHandler{pkt: nil} - - f := NewFragmentation(minBlockSize, HighFragThreshold, LowFragThreshold, reassembleTimeout, &faketime.NullClock{}, handler) - - for _, p := range test.params { - if _, _, _, err := f.Process(id, p.first, p.last, p.more, proto, p.pkt); err != nil && !test.wantError { - t.Errorf("f.Process error = %s", err) - } - } - if !test.wantError { - r, ok := f.reassemblers[id] - if !ok { - t.Fatal("Reassembler not found") - } - f.release(r, true) - } - switch { - case handler.pkt != nil && test.wantPkt == nil: - t.Errorf("got handler.pkt = not nil (pkt.Data = %x), want = nil", handler.pkt.Data().AsRange().ToOwnedView()) - case handler.pkt == nil && test.wantPkt != nil: - t.Errorf("got handler.pkt = nil, want = not nil (pkt.Data = %x)", test.wantPkt.Data().AsRange().ToOwnedView()) - case handler.pkt != nil && test.wantPkt != nil: - if diff := cmp.Diff(test.wantPkt.Data().AsRange().ToOwnedView(), handler.pkt.Data().AsRange().ToOwnedView()); diff != "" { - t.Errorf("pkt.Data mismatch (-want, +got):\n%s", diff) - } - } - }) - } -} diff --git a/pkg/tcpip/network/internal/fragmentation/reassembler_list.go b/pkg/tcpip/network/internal/fragmentation/reassembler_list.go new file mode 100644 index 000000000..673bb11b0 --- /dev/null +++ b/pkg/tcpip/network/internal/fragmentation/reassembler_list.go @@ -0,0 +1,221 @@ +package fragmentation + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type reassemblerElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (reassemblerElementMapper) linkerFor(elem *reassembler) *reassembler { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type reassemblerList struct { + head *reassembler + tail *reassembler +} + +// Reset resets list l to the empty state. +func (l *reassemblerList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +// +//go:nosplit +func (l *reassemblerList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +// +//go:nosplit +func (l *reassemblerList) Front() *reassembler { + return l.head +} + +// Back returns the last element of list l or nil. +// +//go:nosplit +func (l *reassemblerList) Back() *reassembler { + return l.tail +} + +// Len returns the number of elements in the list. +// +// NOTE: This is an O(n) operation. +// +//go:nosplit +func (l *reassemblerList) Len() (count int) { + for e := l.Front(); e != nil; e = (reassemblerElementMapper{}.linkerFor(e)).Next() { + count++ + } + return count +} + +// PushFront inserts the element e at the front of list l. +// +//go:nosplit +func (l *reassemblerList) PushFront(e *reassembler) { + linker := reassemblerElementMapper{}.linkerFor(e) + linker.SetNext(l.head) + linker.SetPrev(nil) + if l.head != nil { + reassemblerElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +// +//go:nosplit +func (l *reassemblerList) PushBack(e *reassembler) { + linker := reassemblerElementMapper{}.linkerFor(e) + linker.SetNext(nil) + linker.SetPrev(l.tail) + if l.tail != nil { + reassemblerElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +// +//go:nosplit +func (l *reassemblerList) PushBackList(m *reassemblerList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + reassemblerElementMapper{}.linkerFor(l.tail).SetNext(m.head) + reassemblerElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +// +//go:nosplit +func (l *reassemblerList) InsertAfter(b, e *reassembler) { + bLinker := reassemblerElementMapper{}.linkerFor(b) + eLinker := reassemblerElementMapper{}.linkerFor(e) + + a := bLinker.Next() + + eLinker.SetNext(a) + eLinker.SetPrev(b) + bLinker.SetNext(e) + + if a != nil { + reassemblerElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +// +//go:nosplit +func (l *reassemblerList) InsertBefore(a, e *reassembler) { + aLinker := reassemblerElementMapper{}.linkerFor(a) + eLinker := reassemblerElementMapper{}.linkerFor(e) + + b := aLinker.Prev() + eLinker.SetNext(a) + eLinker.SetPrev(b) + aLinker.SetPrev(e) + + if b != nil { + reassemblerElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +// +//go:nosplit +func (l *reassemblerList) Remove(e *reassembler) { + linker := reassemblerElementMapper{}.linkerFor(e) + prev := linker.Prev() + next := linker.Next() + + if prev != nil { + reassemblerElementMapper{}.linkerFor(prev).SetNext(next) + } else if l.head == e { + l.head = next + } + + if next != nil { + reassemblerElementMapper{}.linkerFor(next).SetPrev(prev) + } else if l.tail == e { + l.tail = prev + } + + linker.SetNext(nil) + linker.SetPrev(nil) +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type reassemblerEntry struct { + next *reassembler + prev *reassembler +} + +// Next returns the entry that follows e in the list. +// +//go:nosplit +func (e *reassemblerEntry) Next() *reassembler { + return e.next +} + +// Prev returns the entry that precedes e in the list. +// +//go:nosplit +func (e *reassemblerEntry) Prev() *reassembler { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +// +//go:nosplit +func (e *reassemblerEntry) SetNext(elem *reassembler) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +// +//go:nosplit +func (e *reassemblerEntry) SetPrev(elem *reassembler) { + e.prev = elem +} diff --git a/pkg/tcpip/network/internal/fragmentation/reassembler_test.go b/pkg/tcpip/network/internal/fragmentation/reassembler_test.go deleted file mode 100644 index cfd9f00ef..000000000 --- a/pkg/tcpip/network/internal/fragmentation/reassembler_test.go +++ /dev/null @@ -1,233 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package fragmentation - -import ( - "bytes" - "math" - "testing" - - "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/faketime" - "gvisor.dev/gvisor/pkg/tcpip/stack" -) - -type processParams struct { - first uint16 - last uint16 - more bool - pkt *stack.PacketBuffer - wantDone bool - wantError error -} - -func TestReassemblerProcess(t *testing.T) { - const proto = 99 - - v := func(size int) buffer.View { - payload := buffer.NewView(size) - for i := 1; i < size; i++ { - payload[i] = uint8(i) * 3 - } - return payload - } - - pkt := func(sizes ...int) *stack.PacketBuffer { - var vv buffer.VectorisedView - for _, size := range sizes { - vv.AppendView(v(size)) - } - return stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: vv, - }) - } - - var tests = []struct { - name string - params []processParams - want []hole - wantPkt *stack.PacketBuffer - }{ - { - name: "No fragments", - params: nil, - want: []hole{{first: 0, last: math.MaxUint16, filled: false, final: true}}, - }, - { - name: "One fragment at beginning", - params: []processParams{{first: 0, last: 1, more: true, pkt: pkt(2), wantDone: false, wantError: nil}}, - want: []hole{ - {first: 0, last: 1, filled: true, final: false, pkt: pkt(2)}, - {first: 2, last: math.MaxUint16, filled: false, final: true}, - }, - }, - { - name: "One fragment in the middle", - params: []processParams{{first: 1, last: 2, more: true, pkt: pkt(2), wantDone: false, wantError: nil}}, - want: []hole{ - {first: 1, last: 2, filled: true, final: false, pkt: pkt(2)}, - {first: 0, last: 0, filled: false, final: false}, - {first: 3, last: math.MaxUint16, filled: false, final: true}, - }, - }, - { - name: "One fragment at the end", - params: []processParams{{first: 1, last: 2, more: false, pkt: pkt(2), wantDone: false, wantError: nil}}, - want: []hole{ - {first: 1, last: 2, filled: true, final: true, pkt: pkt(2)}, - {first: 0, last: 0, filled: false}, - }, - }, - { - name: "One fragment completing a packet", - params: []processParams{{first: 0, last: 1, more: false, pkt: pkt(2), wantDone: true, wantError: nil}}, - want: []hole{ - {first: 0, last: 1, filled: true, final: true}, - }, - wantPkt: pkt(2), - }, - { - name: "Two fragments completing a packet", - params: []processParams{ - {first: 0, last: 1, more: true, pkt: pkt(2), wantDone: false, wantError: nil}, - {first: 2, last: 3, more: false, pkt: pkt(2), wantDone: true, wantError: nil}, - }, - want: []hole{ - {first: 0, last: 1, filled: true, final: false}, - {first: 2, last: 3, filled: true, final: true}, - }, - wantPkt: pkt(2, 2), - }, - { - name: "Two fragments completing a packet with a duplicate", - params: []processParams{ - {first: 0, last: 1, more: true, pkt: pkt(2), wantDone: false, wantError: nil}, - {first: 0, last: 1, more: true, pkt: pkt(2), wantDone: false, wantError: nil}, - {first: 2, last: 3, more: false, pkt: pkt(2), wantDone: true, wantError: nil}, - }, - want: []hole{ - {first: 0, last: 1, filled: true, final: false}, - {first: 2, last: 3, filled: true, final: true}, - }, - wantPkt: pkt(2, 2), - }, - { - name: "Two fragments completing a packet with a partial duplicate", - params: []processParams{ - {first: 0, last: 3, more: true, pkt: pkt(4), wantDone: false, wantError: nil}, - {first: 1, last: 2, more: true, pkt: pkt(2), wantDone: false, wantError: nil}, - {first: 4, last: 5, more: false, pkt: pkt(2), wantDone: true, wantError: nil}, - }, - want: []hole{ - {first: 0, last: 3, filled: true, final: false}, - {first: 4, last: 5, filled: true, final: true}, - }, - wantPkt: pkt(4, 2), - }, - { - name: "Two overlapping fragments", - params: []processParams{ - {first: 0, last: 10, more: true, pkt: pkt(11), wantDone: false, wantError: nil}, - {first: 5, last: 15, more: false, pkt: pkt(11), wantDone: false, wantError: ErrFragmentOverlap}, - }, - want: []hole{ - {first: 0, last: 10, filled: true, final: false, pkt: pkt(11)}, - {first: 11, last: math.MaxUint16, filled: false, final: true}, - }, - }, - { - name: "Two final fragments with different ends", - params: []processParams{ - {first: 10, last: 14, more: false, pkt: pkt(5), wantDone: false, wantError: nil}, - {first: 0, last: 9, more: false, pkt: pkt(10), wantDone: false, wantError: ErrFragmentConflict}, - }, - want: []hole{ - {first: 10, last: 14, filled: true, final: true, pkt: pkt(5)}, - {first: 0, last: 9, filled: false, final: false}, - }, - }, - { - name: "Two final fragments - duplicate", - params: []processParams{ - {first: 5, last: 14, more: false, pkt: pkt(10), wantDone: false, wantError: nil}, - {first: 10, last: 14, more: false, pkt: pkt(5), wantDone: false, wantError: nil}, - }, - want: []hole{ - {first: 5, last: 14, filled: true, final: true, pkt: pkt(10)}, - {first: 0, last: 4, filled: false, final: false}, - }, - }, - { - name: "Two final fragments - duplicate, with different ends", - params: []processParams{ - {first: 5, last: 14, more: false, pkt: pkt(10), wantDone: false, wantError: nil}, - {first: 10, last: 13, more: false, pkt: pkt(4), wantDone: false, wantError: ErrFragmentConflict}, - }, - want: []hole{ - {first: 5, last: 14, filled: true, final: true, pkt: pkt(10)}, - {first: 0, last: 4, filled: false, final: false}, - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - r := newReassembler(FragmentID{}, &faketime.NullClock{}) - var resPkt *stack.PacketBuffer - var isDone bool - for _, param := range test.params { - pkt, _, done, _, err := r.process(param.first, param.last, param.more, proto, param.pkt) - if done != param.wantDone || err != param.wantError { - t.Errorf("got r.process(%d, %d, %t, %d, _) = (_, _, %t, _, %v), want = (%t, %v)", param.first, param.last, param.more, proto, done, err, param.wantDone, param.wantError) - } - if done { - resPkt = pkt - isDone = true - } - } - - ignorePkt := func(a, b *stack.PacketBuffer) bool { return true } - cmpPktData := func(a, b *stack.PacketBuffer) bool { - if a == nil || b == nil { - return a == b - } - return bytes.Equal(a.Data().AsRange().ToOwnedView(), b.Data().AsRange().ToOwnedView()) - } - - if isDone { - if diff := cmp.Diff( - test.want, r.holes, - cmp.AllowUnexported(hole{}), - // Do not compare pkt in hole. Data will be altered. - cmp.Comparer(ignorePkt), - ); diff != "" { - t.Errorf("r.holes mismatch (-want +got):\n%s", diff) - } - if diff := cmp.Diff(test.wantPkt, resPkt, cmp.Comparer(cmpPktData)); diff != "" { - t.Errorf("Reassembled pkt mismatch (-want +got):\n%s", diff) - } - } else { - if diff := cmp.Diff( - test.want, r.holes, - cmp.AllowUnexported(hole{}), - cmp.Comparer(cmpPktData), - ); diff != "" { - t.Errorf("r.holes mismatch (-want +got):\n%s", diff) - } - } - }) - } -} diff --git a/pkg/tcpip/network/internal/ip/BUILD b/pkg/tcpip/network/internal/ip/BUILD deleted file mode 100644 index fd944ce99..000000000 --- a/pkg/tcpip/network/internal/ip/BUILD +++ /dev/null @@ -1,40 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "ip", - srcs = [ - "duplicate_address_detection.go", - "errors.go", - "generic_multicast_protocol.go", - "stats.go", - ], - visibility = [ - "//pkg/tcpip/network/arp:__pkg__", - "//pkg/tcpip/network/ipv4:__pkg__", - "//pkg/tcpip/network/ipv6:__pkg__", - ], - deps = [ - "//pkg/sync", - "//pkg/tcpip", - "//pkg/tcpip/stack", - ], -) - -go_test( - name = "ip_x_test", - size = "small", - srcs = [ - "duplicate_address_detection_test.go", - "generic_multicast_protocol_test.go", - ], - deps = [ - ":ip", - "//pkg/sync", - "//pkg/tcpip", - "//pkg/tcpip/faketime", - "//pkg/tcpip/stack", - "@com_github_google_go_cmp//cmp:go_default_library", - ], -) diff --git a/pkg/tcpip/network/internal/ip/duplicate_address_detection_test.go b/pkg/tcpip/network/internal/ip/duplicate_address_detection_test.go deleted file mode 100644 index 24687cf06..000000000 --- a/pkg/tcpip/network/internal/ip/duplicate_address_detection_test.go +++ /dev/null @@ -1,381 +0,0 @@ -// Copyright 2021 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ip_test - -import ( - "bytes" - "testing" - "time" - - "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/faketime" - "gvisor.dev/gvisor/pkg/tcpip/network/internal/ip" - "gvisor.dev/gvisor/pkg/tcpip/stack" -) - -type mockDADProtocol struct { - t *testing.T - - mu struct { - sync.Mutex - - dad ip.DAD - sentNonces map[tcpip.Address][][]byte - } -} - -func (m *mockDADProtocol) init(t *testing.T, c stack.DADConfigurations, opts ip.DADOptions) { - m.mu.Lock() - defer m.mu.Unlock() - - m.t = t - opts.Protocol = m - m.mu.dad.Init(&m.mu, c, opts) - m.initLocked() -} - -func (m *mockDADProtocol) initLocked() { - m.mu.sentNonces = make(map[tcpip.Address][][]byte) -} - -func (m *mockDADProtocol) SendDADMessage(addr tcpip.Address, nonce []byte) tcpip.Error { - m.mu.Lock() - defer m.mu.Unlock() - m.mu.sentNonces[addr] = append(m.mu.sentNonces[addr], nonce) - return nil -} - -func (m *mockDADProtocol) check(addrs []tcpip.Address) string { - sentNonces := make(map[tcpip.Address][][]byte) - for _, a := range addrs { - sentNonces[a] = append(sentNonces[a], nil) - } - - return m.checkWithNonce(sentNonces) -} - -func (m *mockDADProtocol) checkWithNonce(expectedSentNonces map[tcpip.Address][][]byte) string { - m.mu.Lock() - defer m.mu.Unlock() - - diff := cmp.Diff(expectedSentNonces, m.mu.sentNonces) - m.initLocked() - return diff -} - -func (m *mockDADProtocol) checkDuplicateAddress(addr tcpip.Address, h stack.DADCompletionHandler) stack.DADCheckAddressDisposition { - m.mu.Lock() - defer m.mu.Unlock() - return m.mu.dad.CheckDuplicateAddressLocked(addr, h) -} - -func (m *mockDADProtocol) stop(addr tcpip.Address, reason stack.DADResult) { - m.mu.Lock() - defer m.mu.Unlock() - m.mu.dad.StopLocked(addr, reason) -} - -func (m *mockDADProtocol) extendIfNonceEqual(addr tcpip.Address, nonce []byte) ip.ExtendIfNonceEqualLockedDisposition { - m.mu.Lock() - defer m.mu.Unlock() - return m.mu.dad.ExtendIfNonceEqualLocked(addr, nonce) -} - -func (m *mockDADProtocol) setConfigs(c stack.DADConfigurations) { - m.mu.Lock() - defer m.mu.Unlock() - m.mu.dad.SetConfigsLocked(c) -} - -const ( - addr1 = tcpip.Address("\x01") - addr2 = tcpip.Address("\x02") - addr3 = tcpip.Address("\x03") - addr4 = tcpip.Address("\x04") -) - -type dadResult struct { - Addr tcpip.Address - R stack.DADResult -} - -func handler(ch chan<- dadResult, a tcpip.Address) func(stack.DADResult) { - return func(r stack.DADResult) { - ch <- dadResult{Addr: a, R: r} - } -} - -func TestDADCheckDuplicateAddress(t *testing.T) { - var dad mockDADProtocol - clock := faketime.NewManualClock() - dad.init(t, stack.DADConfigurations{}, ip.DADOptions{ - Clock: clock, - }) - - ch := make(chan dadResult, 2) - - // DAD should initially be disabled. - if res := dad.checkDuplicateAddress(addr1, handler(nil, "")); res != stack.DADDisabled { - t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr1, res, stack.DADDisabled) - } - // Wait for any initially fired timers to complete. - clock.RunImmediatelyScheduledJobs() - if diff := dad.check(nil); diff != "" { - t.Errorf("dad check mismatch (-want +got):\n%s", diff) - } - - // Enable and request DAD. - dadConfigs1 := stack.DADConfigurations{ - DupAddrDetectTransmits: 1, - RetransmitTimer: time.Second, - } - dad.setConfigs(dadConfigs1) - if res := dad.checkDuplicateAddress(addr1, handler(ch, addr1)); res != stack.DADStarting { - t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr1, res, stack.DADStarting) - } - clock.RunImmediatelyScheduledJobs() - if diff := dad.check([]tcpip.Address{addr1}); diff != "" { - t.Errorf("dad check mismatch (-want +got):\n%s", diff) - } - // The second request for DAD on the same address should use the original - // request since it has not completed yet. - if res := dad.checkDuplicateAddress(addr1, handler(ch, addr1)); res != stack.DADAlreadyRunning { - t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr1, res, stack.DADAlreadyRunning) - } - clock.RunImmediatelyScheduledJobs() - if diff := dad.check(nil); diff != "" { - t.Errorf("dad check mismatch (-want +got):\n%s", diff) - } - - dadConfigs2 := stack.DADConfigurations{ - DupAddrDetectTransmits: 2, - RetransmitTimer: time.Second, - } - dad.setConfigs(dadConfigs2) - // A new address should start a new DAD process. - if res := dad.checkDuplicateAddress(addr2, handler(ch, addr2)); res != stack.DADStarting { - t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr2, res, stack.DADStarting) - } - clock.RunImmediatelyScheduledJobs() - if diff := dad.check([]tcpip.Address{addr2}); diff != "" { - t.Errorf("dad check mismatch (-want +got):\n%s", diff) - } - - // Make sure DAD for addr1 only resolves after the expected timeout. - const delta = time.Nanosecond - dadConfig1Duration := time.Duration(dadConfigs1.DupAddrDetectTransmits) * dadConfigs1.RetransmitTimer - clock.Advance(dadConfig1Duration - delta) - select { - case r := <-ch: - t.Fatalf("unexpectedly got a DAD result before the expected timeout of %s; r = %#v", dadConfig1Duration, r) - default: - } - clock.Advance(delta) - for i := 0; i < 2; i++ { - if diff := cmp.Diff(dadResult{Addr: addr1, R: &stack.DADSucceeded{}}, <-ch); diff != "" { - t.Errorf("(i=%d) dad result mismatch (-want +got):\n%s", i, diff) - } - } - - // Make sure DAD for addr2 only resolves after the expected timeout. - dadConfig2Duration := time.Duration(dadConfigs2.DupAddrDetectTransmits) * dadConfigs2.RetransmitTimer - clock.Advance(dadConfig2Duration - dadConfig1Duration - delta) - select { - case r := <-ch: - t.Fatalf("unexpectedly got a DAD result before the expected timeout of %s; r = %#v", dadConfig2Duration, r) - default: - } - clock.Advance(delta) - if diff := cmp.Diff(dadResult{Addr: addr2, R: &stack.DADSucceeded{}}, <-ch); diff != "" { - t.Errorf("dad result mismatch (-want +got):\n%s", diff) - } - - // Should be able to restart DAD for addr2 after it resolved. - if res := dad.checkDuplicateAddress(addr2, handler(ch, addr2)); res != stack.DADStarting { - t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr2, res, stack.DADStarting) - } - clock.RunImmediatelyScheduledJobs() - if diff := dad.check([]tcpip.Address{addr2, addr2}); diff != "" { - t.Errorf("dad check mismatch (-want +got):\n%s", diff) - } - clock.Advance(dadConfig2Duration) - if diff := cmp.Diff(dadResult{Addr: addr2, R: &stack.DADSucceeded{}}, <-ch); diff != "" { - t.Errorf("dad result mismatch (-want +got):\n%s", diff) - } - - // Should not have anymore results. - select { - case r := <-ch: - t.Fatalf("unexpectedly got an extra DAD result; r = %#v", r) - default: - } -} - -func TestDADStop(t *testing.T) { - var dad mockDADProtocol - clock := faketime.NewManualClock() - dadConfigs := stack.DADConfigurations{ - DupAddrDetectTransmits: 1, - RetransmitTimer: time.Second, - } - dad.init(t, dadConfigs, ip.DADOptions{ - Clock: clock, - }) - - ch := make(chan dadResult, 1) - - if res := dad.checkDuplicateAddress(addr1, handler(ch, addr1)); res != stack.DADStarting { - t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr1, res, stack.DADStarting) - } - if res := dad.checkDuplicateAddress(addr2, handler(ch, addr2)); res != stack.DADStarting { - t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr2, res, stack.DADStarting) - } - if res := dad.checkDuplicateAddress(addr3, handler(ch, addr3)); res != stack.DADStarting { - t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr2, res, stack.DADStarting) - } - clock.RunImmediatelyScheduledJobs() - if diff := dad.check([]tcpip.Address{addr1, addr2, addr3}); diff != "" { - t.Errorf("dad check mismatch (-want +got):\n%s", diff) - } - - dad.stop(addr1, &stack.DADAborted{}) - if diff := cmp.Diff(dadResult{Addr: addr1, R: &stack.DADAborted{}}, <-ch); diff != "" { - t.Errorf("dad result mismatch (-want +got):\n%s", diff) - } - - dad.stop(addr2, &stack.DADDupAddrDetected{}) - if diff := cmp.Diff(dadResult{Addr: addr2, R: &stack.DADDupAddrDetected{}}, <-ch); diff != "" { - t.Errorf("dad result mismatch (-want +got):\n%s", diff) - } - - dadResolutionDuration := time.Duration(dadConfigs.DupAddrDetectTransmits) * dadConfigs.RetransmitTimer - clock.Advance(dadResolutionDuration) - if diff := cmp.Diff(dadResult{Addr: addr3, R: &stack.DADSucceeded{}}, <-ch); diff != "" { - t.Errorf("dad result mismatch (-want +got):\n%s", diff) - } - - // Should be able to restart DAD for an address we stopped DAD on. - if res := dad.checkDuplicateAddress(addr1, handler(ch, addr1)); res != stack.DADStarting { - t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr1, res, stack.DADStarting) - } - clock.RunImmediatelyScheduledJobs() - if diff := dad.check([]tcpip.Address{addr1}); diff != "" { - t.Errorf("dad check mismatch (-want +got):\n%s", diff) - } - clock.Advance(dadResolutionDuration) - if diff := cmp.Diff(dadResult{Addr: addr1, R: &stack.DADSucceeded{}}, <-ch); diff != "" { - t.Errorf("dad result mismatch (-want +got):\n%s", diff) - } - - // Should not have anymore updates. - select { - case r := <-ch: - t.Fatalf("unexpectedly got an extra DAD result; r = %#v", r) - default: - } -} - -func TestNonce(t *testing.T) { - const ( - nonceSize = 2 - - extendRequestAttempts = 2 - - dupAddrDetectTransmits = 2 - extendTransmits = 5 - ) - - var secureRNGBytes [nonceSize * (dupAddrDetectTransmits + extendTransmits)]byte - for i := range secureRNGBytes { - secureRNGBytes[i] = byte(i) - } - - tests := []struct { - name string - mockedReceivedNonce []byte - expectedResults [extendRequestAttempts]ip.ExtendIfNonceEqualLockedDisposition - expectedTransmits int - }{ - { - name: "not matching", - mockedReceivedNonce: []byte{0, 0}, - expectedResults: [extendRequestAttempts]ip.ExtendIfNonceEqualLockedDisposition{ip.NonceNotEqual, ip.NonceNotEqual}, - expectedTransmits: dupAddrDetectTransmits, - }, - { - name: "matching nonce", - mockedReceivedNonce: secureRNGBytes[:nonceSize], - expectedResults: [extendRequestAttempts]ip.ExtendIfNonceEqualLockedDisposition{ip.Extended, ip.AlreadyExtended}, - expectedTransmits: dupAddrDetectTransmits + extendTransmits, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - var dad mockDADProtocol - clock := faketime.NewManualClock() - dadConfigs := stack.DADConfigurations{ - DupAddrDetectTransmits: dupAddrDetectTransmits, - RetransmitTimer: time.Second, - } - - var secureRNG bytes.Reader - secureRNG.Reset(secureRNGBytes[:]) - dad.init(t, dadConfigs, ip.DADOptions{ - Clock: clock, - SecureRNG: &secureRNG, - NonceSize: nonceSize, - ExtendDADTransmits: extendTransmits, - }) - - ch := make(chan dadResult, 1) - if res := dad.checkDuplicateAddress(addr1, handler(ch, addr1)); res != stack.DADStarting { - t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr1, res, stack.DADStarting) - } - - clock.RunImmediatelyScheduledJobs() - for i, want := range test.expectedResults { - if got := dad.extendIfNonceEqual(addr1, test.mockedReceivedNonce); got != want { - t.Errorf("(i=%d) got dad.extendIfNonceEqual(%s, _) = %d, want = %d", i, addr1, got, want) - } - } - - for i := 0; i < test.expectedTransmits; i++ { - if diff := dad.checkWithNonce(map[tcpip.Address][][]byte{ - addr1: { - secureRNGBytes[nonceSize*i:][:nonceSize], - }, - }); diff != "" { - t.Errorf("(i=%d) dad check mismatch (-want +got):\n%s", i, diff) - } - - clock.Advance(dadConfigs.RetransmitTimer) - } - - if diff := cmp.Diff(dadResult{Addr: addr1, R: &stack.DADSucceeded{}}, <-ch); diff != "" { - t.Errorf("dad result mismatch (-want +got):\n%s", diff) - } - - // Should not have anymore updates. - select { - case r := <-ch: - t.Fatalf("unexpectedly got an extra DAD result; r = %#v", r) - default: - } - }) - } -} diff --git a/pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go b/pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go deleted file mode 100644 index 1261ad414..000000000 --- a/pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go +++ /dev/null @@ -1,808 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ip_test - -import ( - "math/rand" - "testing" - "time" - - "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/faketime" - "gvisor.dev/gvisor/pkg/tcpip/network/internal/ip" -) - -const maxUnsolicitedReportDelay = time.Second - -var _ ip.MulticastGroupProtocol = (*mockMulticastGroupProtocol)(nil) - -type mockMulticastGroupProtocolProtectedFields struct { - sync.RWMutex - - genericMulticastGroup ip.GenericMulticastProtocolState - sendReportGroupAddrCount map[tcpip.Address]int - sendLeaveGroupAddrCount map[tcpip.Address]int - makeQueuePackets bool - disabled bool -} - -type mockMulticastGroupProtocol struct { - t *testing.T - - skipProtocolAddress tcpip.Address - - mu mockMulticastGroupProtocolProtectedFields -} - -func (m *mockMulticastGroupProtocol) init(opts ip.GenericMulticastProtocolOptions) { - m.mu.Lock() - defer m.mu.Unlock() - m.initLocked() - opts.Protocol = m - m.mu.genericMulticastGroup.Init(&m.mu.RWMutex, opts) -} - -func (m *mockMulticastGroupProtocol) initLocked() { - m.mu.sendReportGroupAddrCount = make(map[tcpip.Address]int) - m.mu.sendLeaveGroupAddrCount = make(map[tcpip.Address]int) -} - -func (m *mockMulticastGroupProtocol) setEnabled(v bool) { - m.mu.Lock() - defer m.mu.Unlock() - m.mu.disabled = !v -} - -func (m *mockMulticastGroupProtocol) setQueuePackets(v bool) { - m.mu.Lock() - defer m.mu.Unlock() - m.mu.makeQueuePackets = v -} - -func (m *mockMulticastGroupProtocol) joinGroup(addr tcpip.Address) { - m.mu.Lock() - defer m.mu.Unlock() - m.mu.genericMulticastGroup.JoinGroupLocked(addr) -} - -func (m *mockMulticastGroupProtocol) leaveGroup(addr tcpip.Address) bool { - m.mu.Lock() - defer m.mu.Unlock() - return m.mu.genericMulticastGroup.LeaveGroupLocked(addr) -} - -func (m *mockMulticastGroupProtocol) handleReport(addr tcpip.Address) { - m.mu.Lock() - defer m.mu.Unlock() - m.mu.genericMulticastGroup.HandleReportLocked(addr) -} - -func (m *mockMulticastGroupProtocol) handleQuery(addr tcpip.Address, maxRespTime time.Duration) { - m.mu.Lock() - defer m.mu.Unlock() - m.mu.genericMulticastGroup.HandleQueryLocked(addr, maxRespTime) -} - -func (m *mockMulticastGroupProtocol) isLocallyJoined(addr tcpip.Address) bool { - m.mu.RLock() - defer m.mu.RUnlock() - return m.mu.genericMulticastGroup.IsLocallyJoinedRLocked(addr) -} - -func (m *mockMulticastGroupProtocol) makeAllNonMember() { - m.mu.Lock() - defer m.mu.Unlock() - m.mu.genericMulticastGroup.MakeAllNonMemberLocked() -} - -func (m *mockMulticastGroupProtocol) initializeGroups() { - m.mu.Lock() - defer m.mu.Unlock() - m.mu.genericMulticastGroup.InitializeGroupsLocked() -} - -func (m *mockMulticastGroupProtocol) sendQueuedReports() { - m.mu.Lock() - defer m.mu.Unlock() - m.mu.genericMulticastGroup.SendQueuedReportsLocked() -} - -// Enabled implements ip.MulticastGroupProtocol. -// -// Precondition: m.mu must be read locked. -func (m *mockMulticastGroupProtocol) Enabled() bool { - if m.mu.TryLock() { - m.mu.Unlock() // +checklocksforce: TryLock. - m.t.Fatal("got write lock, expected to not take the lock; generic multicast protocol must take the read or write lock before calling Enabled") - } - - return !m.mu.disabled -} - -// SendReport implements ip.MulticastGroupProtocol. -// -// Precondition: m.mu must be locked. -func (m *mockMulticastGroupProtocol) SendReport(groupAddress tcpip.Address) (bool, tcpip.Error) { - if m.mu.TryLock() { - m.mu.Unlock() // +checklocksforce: TryLock. - m.t.Fatalf("got write lock, expected to not take the lock; generic multicast protocol must take the write lock before sending report for %s", groupAddress) - } - if m.mu.TryRLock() { - m.mu.RUnlock() // +checklocksforce: TryLock. - m.t.Fatalf("got read lock, expected to not take the lock; generic multicast protocol must take the write lock before sending report for %s", groupAddress) - } - - m.mu.sendReportGroupAddrCount[groupAddress]++ - return !m.mu.makeQueuePackets, nil -} - -// SendLeave implements ip.MulticastGroupProtocol. -// -// Precondition: m.mu must be locked. -func (m *mockMulticastGroupProtocol) SendLeave(groupAddress tcpip.Address) tcpip.Error { - if m.mu.TryLock() { - m.mu.Unlock() // +checklocksforce: TryLock. - m.t.Fatalf("got write lock, expected to not take the lock; generic multicast protocol must take the write lock before sending leave for %s", groupAddress) - } - if m.mu.TryRLock() { - m.mu.RUnlock() // +checklocksforce: TryLock. - m.t.Fatalf("got read lock, expected to not take the lock; generic multicast protocol must take the write lock before sending leave for %s", groupAddress) - } - - m.mu.sendLeaveGroupAddrCount[groupAddress]++ - return nil -} - -// ShouldPerformProtocol implements ip.MulticastGroupProtocol. -func (m *mockMulticastGroupProtocol) ShouldPerformProtocol(groupAddress tcpip.Address) bool { - return groupAddress != m.skipProtocolAddress -} - -func (m *mockMulticastGroupProtocol) check(sendReportGroupAddresses []tcpip.Address, sendLeaveGroupAddresses []tcpip.Address) string { - m.mu.Lock() - defer m.mu.Unlock() - - sendReportGroupAddrCount := make(map[tcpip.Address]int) - for _, a := range sendReportGroupAddresses { - sendReportGroupAddrCount[a] = 1 - } - - sendLeaveGroupAddrCount := make(map[tcpip.Address]int) - for _, a := range sendLeaveGroupAddresses { - sendLeaveGroupAddrCount[a] = 1 - } - - diff := cmp.Diff( - &mockMulticastGroupProtocol{ - mu: mockMulticastGroupProtocolProtectedFields{ - sendReportGroupAddrCount: sendReportGroupAddrCount, - sendLeaveGroupAddrCount: sendLeaveGroupAddrCount, - }, - }, - m, - cmp.AllowUnexported(mockMulticastGroupProtocol{}), - cmp.AllowUnexported(mockMulticastGroupProtocolProtectedFields{}), - // ignore mockMulticastGroupProtocol.mu and mockMulticastGroupProtocol.t - cmp.FilterPath( - func(p cmp.Path) bool { - switch p.Last().String() { - case ".RWMutex", ".t", ".makeQueuePackets", ".disabled", ".genericMulticastGroup", ".skipProtocolAddress": - return true - default: - return false - } - }, - cmp.Ignore(), - ), - ) - m.initLocked() - return diff -} - -func TestJoinGroup(t *testing.T) { - tests := []struct { - name string - addr tcpip.Address - shouldSendReports bool - }{ - { - name: "Normal group", - addr: addr1, - shouldSendReports: true, - }, - { - name: "All-nodes group", - addr: addr2, - shouldSendReports: false, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - mgp := mockMulticastGroupProtocol{t: t, skipProtocolAddress: addr2} - clock := faketime.NewManualClock() - - mgp.init(ip.GenericMulticastProtocolOptions{ - Rand: rand.New(rand.NewSource(0)), - Clock: clock, - MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, - }) - - // Joining a group should send a report immediately and another after - // a random interval between 0 and the maximum unsolicited report delay. - mgp.joinGroup(test.addr) - if test.shouldSendReports { - if diff := mgp.check([]tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // Generic multicast protocol timers are expected to take the job mutex. - clock.Advance(maxUnsolicitedReportDelay) - if diff := mgp.check([]tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - } - - // Should have no more messages to send. - clock.Advance(time.Hour) - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - }) - } -} - -func TestLeaveGroup(t *testing.T) { - tests := []struct { - name string - addr tcpip.Address - shouldSendMessages bool - }{ - { - name: "Normal group", - addr: addr1, - shouldSendMessages: true, - }, - { - name: "All-nodes group", - addr: addr2, - shouldSendMessages: false, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - mgp := mockMulticastGroupProtocol{t: t, skipProtocolAddress: addr2} - clock := faketime.NewManualClock() - - mgp.init(ip.GenericMulticastProtocolOptions{ - Rand: rand.New(rand.NewSource(1)), - Clock: clock, - MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, - }) - - mgp.joinGroup(test.addr) - if test.shouldSendMessages { - if diff := mgp.check([]tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - } - - // Leaving a group should send a leave report immediately and cancel any - // delayed reports. - { - - if !mgp.leaveGroup(test.addr) { - t.Fatalf("got mgp.leaveGroup(%s) = false, want = true", test.addr) - } - } - if test.shouldSendMessages { - if diff := mgp.check(nil /* sendReportGroupAddresses */, []tcpip.Address{test.addr} /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - } - - // Should have no more messages to send. - // - // Generic multicast protocol timers are expected to take the job mutex. - clock.Advance(time.Hour) - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - }) - } -} - -func TestHandleReport(t *testing.T) { - tests := []struct { - name string - reportAddr tcpip.Address - expectReportsFor []tcpip.Address - }{ - { - name: "Unpecified empty", - reportAddr: "", - expectReportsFor: []tcpip.Address{addr1, addr2}, - }, - { - name: "Unpecified any", - reportAddr: "\x00", - expectReportsFor: []tcpip.Address{addr1, addr2}, - }, - { - name: "Specified", - reportAddr: addr1, - expectReportsFor: []tcpip.Address{addr2}, - }, - { - name: "Specified all-nodes", - reportAddr: addr3, - expectReportsFor: []tcpip.Address{addr1, addr2}, - }, - { - name: "Specified other", - reportAddr: addr4, - expectReportsFor: []tcpip.Address{addr1, addr2}, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - mgp := mockMulticastGroupProtocol{t: t, skipProtocolAddress: addr3} - clock := faketime.NewManualClock() - - mgp.init(ip.GenericMulticastProtocolOptions{ - Rand: rand.New(rand.NewSource(2)), - Clock: clock, - MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, - }) - - mgp.joinGroup(addr1) - if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - mgp.joinGroup(addr2) - if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - mgp.joinGroup(addr3) - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // Receiving a report for a group we have a timer scheduled for should - // cancel our delayed report timer for the group. - mgp.handleReport(test.reportAddr) - if len(test.expectReportsFor) != 0 { - // Generic multicast protocol timers are expected to take the job mutex. - clock.Advance(maxUnsolicitedReportDelay) - if diff := mgp.check(test.expectReportsFor /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - } - - // Should have no more messages to send. - clock.Advance(time.Hour) - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - }) - } -} - -func TestHandleQuery(t *testing.T) { - tests := []struct { - name string - queryAddr tcpip.Address - maxDelay time.Duration - expectQueriedReportsFor []tcpip.Address - expectDelayedReportsFor []tcpip.Address - }{ - { - name: "Unpecified empty", - queryAddr: "", - maxDelay: 0, - expectQueriedReportsFor: []tcpip.Address{addr1, addr2}, - expectDelayedReportsFor: nil, - }, - { - name: "Unpecified any", - queryAddr: "\x00", - maxDelay: 1, - expectQueriedReportsFor: []tcpip.Address{addr1, addr2}, - expectDelayedReportsFor: nil, - }, - { - name: "Specified", - queryAddr: addr1, - maxDelay: 2, - expectQueriedReportsFor: []tcpip.Address{addr1}, - expectDelayedReportsFor: []tcpip.Address{addr2}, - }, - { - name: "Specified all-nodes", - queryAddr: addr3, - maxDelay: 3, - expectQueriedReportsFor: nil, - expectDelayedReportsFor: []tcpip.Address{addr1, addr2}, - }, - { - name: "Specified other", - queryAddr: addr4, - maxDelay: 4, - expectQueriedReportsFor: nil, - expectDelayedReportsFor: []tcpip.Address{addr1, addr2}, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - mgp := mockMulticastGroupProtocol{t: t, skipProtocolAddress: addr3} - clock := faketime.NewManualClock() - - mgp.init(ip.GenericMulticastProtocolOptions{ - Rand: rand.New(rand.NewSource(3)), - Clock: clock, - MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, - }) - - mgp.joinGroup(addr1) - if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - mgp.joinGroup(addr2) - if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - mgp.joinGroup(addr3) - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // Receiving a query should make us reschedule our delayed report timer - // to some time within the new max response delay. - mgp.handleQuery(test.queryAddr, test.maxDelay) - clock.Advance(test.maxDelay) - if diff := mgp.check(test.expectQueriedReportsFor /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // The groups that were not affected by the query should still send a - // report after the max unsolicited report delay. - clock.Advance(maxUnsolicitedReportDelay) - if diff := mgp.check(test.expectDelayedReportsFor /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // Should have no more messages to send. - clock.Advance(time.Hour) - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - }) - } -} - -func TestJoinCount(t *testing.T) { - mgp := mockMulticastGroupProtocol{t: t} - clock := faketime.NewManualClock() - - mgp.init(ip.GenericMulticastProtocolOptions{ - Rand: rand.New(rand.NewSource(4)), - Clock: clock, - MaxUnsolicitedReportDelay: time.Second, - }) - - // Set the join count to 2 for a group. - mgp.joinGroup(addr1) - if !mgp.isLocallyJoined(addr1) { - t.Fatalf("got mgp.isLocallyJoined(%s) = false, want = true", addr1) - } - // Only the first join should trigger a report to be sent. - if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - mgp.joinGroup(addr1) - if !mgp.isLocallyJoined(addr1) { - t.Errorf("got mgp.isLocallyJoined(%s) = false, want = true", addr1) - } - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - if t.Failed() { - t.FailNow() - } - - // Group should still be considered joined after leaving once. - if !mgp.leaveGroup(addr1) { - t.Errorf("got mgp.leaveGroup(%s) = false, want = true", addr1) - } - if !mgp.isLocallyJoined(addr1) { - t.Errorf("got mgp.isLocallyJoined(%s) = false, want = true", addr1) - } - // A leave report should only be sent once the join count reaches 0. - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - if t.Failed() { - t.FailNow() - } - - // Leaving once more should actually remove us from the group. - if !mgp.leaveGroup(addr1) { - t.Errorf("got mgp.leaveGroup(%s) = false, want = true", addr1) - } - if mgp.isLocallyJoined(addr1) { - t.Errorf("got mgp.isLocallyJoined(%s) = true, want = false", addr1) - } - if diff := mgp.check(nil /* sendReportGroupAddresses */, []tcpip.Address{addr1} /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - if t.Failed() { - t.FailNow() - } - - // Group should no longer be joined so we should not have anything to - // leave. - if mgp.leaveGroup(addr1) { - t.Errorf("got mgp.leaveGroup(%s) = true, want = false", addr1) - } - if mgp.isLocallyJoined(addr1) { - t.Errorf("got mgp.isLocallyJoined(%s) = true, want = false", addr1) - } - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // Should have no more messages to send. - // - // Generic multicast protocol timers are expected to take the job mutex. - clock.Advance(time.Hour) - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } -} - -func TestMakeAllNonMemberAndInitialize(t *testing.T) { - mgp := mockMulticastGroupProtocol{t: t, skipProtocolAddress: addr3} - clock := faketime.NewManualClock() - - mgp.init(ip.GenericMulticastProtocolOptions{ - Rand: rand.New(rand.NewSource(3)), - Clock: clock, - MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, - }) - - mgp.joinGroup(addr1) - if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - mgp.joinGroup(addr2) - if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - mgp.joinGroup(addr3) - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // Should send the leave reports for each but still consider them locally - // joined. - mgp.makeAllNonMember() - if diff := mgp.check(nil /* sendReportGroupAddresses */, []tcpip.Address{addr1, addr2} /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - // Generic multicast protocol timers are expected to take the job mutex. - clock.Advance(time.Hour) - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - for _, group := range []tcpip.Address{addr1, addr2, addr3} { - if !mgp.isLocallyJoined(group) { - t.Fatalf("got mgp.isLocallyJoined(%s) = false, want = true", group) - } - } - - // Should send the initial set of unsolcited reports. - mgp.initializeGroups() - if diff := mgp.check([]tcpip.Address{addr1, addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - clock.Advance(maxUnsolicitedReportDelay) - if diff := mgp.check([]tcpip.Address{addr1, addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // Should have no more messages to send. - clock.Advance(time.Hour) - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } -} - -// TestGroupStateNonMember tests that groups do not send packets when in the -// non-member state, but are still considered locally joined. -func TestGroupStateNonMember(t *testing.T) { - mgp := mockMulticastGroupProtocol{t: t} - clock := faketime.NewManualClock() - - mgp.init(ip.GenericMulticastProtocolOptions{ - Rand: rand.New(rand.NewSource(3)), - Clock: clock, - MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, - }) - mgp.setEnabled(false) - - // Joining groups should not send any reports. - mgp.joinGroup(addr1) - if !mgp.isLocallyJoined(addr1) { - t.Fatalf("got mgp.isLocallyJoined(%s) = false, want = true", addr1) - } - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - mgp.joinGroup(addr2) - if !mgp.isLocallyJoined(addr1) { - t.Fatalf("got mgp.isLocallyJoined(%s) = false, want = true", addr2) - } - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // Receiving a query should not send any reports. - mgp.handleQuery(addr1, time.Nanosecond) - // Generic multicast protocol timers are expected to take the job mutex. - clock.Advance(time.Nanosecond) - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // Leaving groups should not send any leave messages. - if !mgp.leaveGroup(addr1) { - t.Errorf("got mgp.leaveGroup(%s) = false, want = true", addr2) - } - if mgp.isLocallyJoined(addr1) { - t.Errorf("got mgp.isLocallyJoined(%s) = true, want = false", addr2) - } - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - clock.Advance(time.Hour) - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } -} - -func TestQueuedPackets(t *testing.T) { - clock := faketime.NewManualClock() - mgp := mockMulticastGroupProtocol{t: t} - mgp.init(ip.GenericMulticastProtocolOptions{ - Rand: rand.New(rand.NewSource(4)), - Clock: clock, - MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, - }) - - // Joining should trigger a SendReport, but mgp should report that we did not - // send the packet. - mgp.setQueuePackets(true) - mgp.joinGroup(addr1) - if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // The delayed report timer should have been cancelled since we did not send - // the initial report earlier. - clock.Advance(time.Hour) - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // Mock being able to successfully send the report. - mgp.setQueuePackets(false) - mgp.sendQueuedReports() - if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // The delayed report (sent after the initial report) should now be sent. - clock.Advance(maxUnsolicitedReportDelay) - if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // Should not have anything else to send (we should be idle). - mgp.sendQueuedReports() - clock.Advance(time.Hour) - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // Receive a query but mock being unable to send reports again. - mgp.setQueuePackets(true) - mgp.handleQuery(addr1, time.Nanosecond) - clock.Advance(time.Nanosecond) - if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // Mock being able to send reports again - we should have a packet queued to - // send. - mgp.setQueuePackets(false) - mgp.sendQueuedReports() - if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // Should not have anything else to send. - mgp.sendQueuedReports() - clock.Advance(time.Hour) - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // Receive a query again, but mock being unable to send reports. - mgp.setQueuePackets(true) - mgp.handleQuery(addr1, time.Nanosecond) - clock.Advance(time.Nanosecond) - if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // Receiving a report should should transition us into the idle member state, - // even if we had a packet queued. We should no longer have any packets to - // send. - mgp.handleReport(addr1) - mgp.sendQueuedReports() - clock.Advance(time.Hour) - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // When we fail to send the initial set of reports, incoming reports should - // not affect a newly joined group's reports from being sent. - mgp.setQueuePackets(true) - mgp.joinGroup(addr2) - if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - mgp.handleReport(addr2) - // Attempting to send queued reports while still unable to send reports should - // not change the host state. - mgp.sendQueuedReports() - if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - // Mock being able to successfully send the report. - mgp.setQueuePackets(false) - mgp.sendQueuedReports() - if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - // The delayed report (sent after the initial report) should now be sent. - clock.Advance(maxUnsolicitedReportDelay) - if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // Should not have anything else to send. - mgp.sendQueuedReports() - clock.Advance(time.Hour) - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } -} diff --git a/pkg/tcpip/network/internal/ip/ip_state_autogen.go b/pkg/tcpip/network/internal/ip/ip_state_autogen.go new file mode 100644 index 000000000..360922bfe --- /dev/null +++ b/pkg/tcpip/network/internal/ip/ip_state_autogen.go @@ -0,0 +1,32 @@ +// automatically generated by stateify. + +package ip + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (e *ErrMessageTooLong) StateTypeName() string { + return "pkg/tcpip/network/internal/ip.ErrMessageTooLong" +} + +func (e *ErrMessageTooLong) StateFields() []string { + return []string{} +} + +func (e *ErrMessageTooLong) beforeSave() {} + +// +checklocksignore +func (e *ErrMessageTooLong) StateSave(stateSinkObject state.Sink) { + e.beforeSave() +} + +func (e *ErrMessageTooLong) afterLoad() {} + +// +checklocksignore +func (e *ErrMessageTooLong) StateLoad(stateSourceObject state.Source) { +} + +func init() { + state.Register((*ErrMessageTooLong)(nil)) +} diff --git a/pkg/tcpip/network/internal/testutil/BUILD b/pkg/tcpip/network/internal/testutil/BUILD deleted file mode 100644 index a180e5c75..000000000 --- a/pkg/tcpip/network/internal/testutil/BUILD +++ /dev/null @@ -1,21 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "testutil", - srcs = ["testutil.go"], - visibility = [ - "//pkg/tcpip/network/arp:__pkg__", - "//pkg/tcpip/network/internal/fragmentation:__pkg__", - "//pkg/tcpip/network/ipv4:__pkg__", - "//pkg/tcpip/network/ipv6:__pkg__", - "//pkg/tcpip/tests/integration:__pkg__", - ], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/stack", - ], -) diff --git a/pkg/tcpip/network/internal/testutil/testutil.go b/pkg/tcpip/network/internal/testutil/testutil.go deleted file mode 100644 index 605e9ef8d..000000000 --- a/pkg/tcpip/network/internal/testutil/testutil.go +++ /dev/null @@ -1,129 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package testutil defines types and functions used to test Network Layer -// functionality such as IP fragmentation. -package testutil - -import ( - "fmt" - "math/rand" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/stack" -) - -// MockLinkEndpoint is an endpoint used for testing, it stores packets written -// to it and can mock errors. -type MockLinkEndpoint struct { - // WrittenPackets is where packets written to the endpoint are stored. - WrittenPackets []*stack.PacketBuffer - - mtu uint32 - err tcpip.Error - allowPackets int -} - -// NewMockLinkEndpoint creates a new MockLinkEndpoint. -// -// err is the error that will be returned once allowPackets packets are written -// to the endpoint. -func NewMockLinkEndpoint(mtu uint32, err tcpip.Error, allowPackets int) *MockLinkEndpoint { - return &MockLinkEndpoint{ - mtu: mtu, - err: err, - allowPackets: allowPackets, - } -} - -// MTU implements LinkEndpoint.MTU. -func (ep *MockLinkEndpoint) MTU() uint32 { return ep.mtu } - -// Capabilities implements LinkEndpoint.Capabilities. -func (*MockLinkEndpoint) Capabilities() stack.LinkEndpointCapabilities { return 0 } - -// MaxHeaderLength implements LinkEndpoint.MaxHeaderLength. -func (*MockLinkEndpoint) MaxHeaderLength() uint16 { return 0 } - -// LinkAddress implements LinkEndpoint.LinkAddress. -func (*MockLinkEndpoint) LinkAddress() tcpip.LinkAddress { return "" } - -// WritePacket implements LinkEndpoint.WritePacket. -func (ep *MockLinkEndpoint) WritePacket(_ stack.RouteInfo, _ tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { - if ep.allowPackets == 0 { - return ep.err - } - ep.allowPackets-- - ep.WrittenPackets = append(ep.WrittenPackets, pkt) - return nil -} - -// WritePackets implements LinkEndpoint.WritePackets. -func (ep *MockLinkEndpoint) WritePackets(r stack.RouteInfo, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { - var n int - - for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { - if err := ep.WritePacket(r, protocol, pkt); err != nil { - return n, err - } - n++ - } - - return n, nil -} - -// Attach implements LinkEndpoint.Attach. -func (*MockLinkEndpoint) Attach(stack.NetworkDispatcher) {} - -// IsAttached implements LinkEndpoint.IsAttached. -func (*MockLinkEndpoint) IsAttached() bool { return false } - -// Wait implements LinkEndpoint.Wait. -func (*MockLinkEndpoint) Wait() {} - -// ARPHardwareType implements LinkEndpoint.ARPHardwareType. -func (*MockLinkEndpoint) ARPHardwareType() header.ARPHardwareType { return header.ARPHardwareNone } - -// AddHeader implements LinkEndpoint.AddHeader. -func (*MockLinkEndpoint) AddHeader(_, _ tcpip.LinkAddress, _ tcpip.NetworkProtocolNumber, _ *stack.PacketBuffer) { -} - -// MakeRandPkt generates a randomized packet. transportHeaderLength indicates -// how many random bytes will be copied in the Transport Header. -// extraHeaderReserveLength indicates how much extra space will be reserved for -// the other headers. The payload is made from Views of the sizes listed in -// viewSizes. -func MakeRandPkt(transportHeaderLength int, extraHeaderReserveLength int, viewSizes []int, proto tcpip.NetworkProtocolNumber) *stack.PacketBuffer { - var views buffer.VectorisedView - - for _, s := range viewSizes { - newView := buffer.NewView(s) - if _, err := rand.Read(newView); err != nil { - panic(fmt.Sprintf("rand.Read: %s", err)) - } - views.AppendView(newView) - } - - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: transportHeaderLength + extraHeaderReserveLength, - Data: views, - }) - pkt.NetworkProtocolNumber = proto - if _, err := rand.Read(pkt.TransportHeader().Push(transportHeaderLength)); err != nil { - panic(fmt.Sprintf("rand.Read: %s", err)) - } - return pkt -} diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go deleted file mode 100644 index 771b9173a..000000000 --- a/pkg/tcpip/network/ip_test.go +++ /dev/null @@ -1,2034 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ip_test - -import ( - "fmt" - "strings" - "testing" - - "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/checker" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/link/loopback" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/testutil" - "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" - "gvisor.dev/gvisor/pkg/tcpip/transport/udp" -) - -const nicID = 1 - -var ( - localIPv4Addr = testutil.MustParse4("10.0.0.1") - remoteIPv4Addr = testutil.MustParse4("10.0.0.2") - ipv4SubnetAddr = testutil.MustParse4("10.0.0.0") - ipv4SubnetMask = testutil.MustParse4("255.255.255.0") - ipv4Gateway = testutil.MustParse4("10.0.0.3") - localIPv6Addr = testutil.MustParse6("a00::1") - remoteIPv6Addr = testutil.MustParse6("a00::2") - ipv6SubnetAddr = testutil.MustParse6("a00::") - ipv6SubnetMask = testutil.MustParse6("ffff:ffff:ffff:ffff:ffff:ffff:ffff:ff00") - ipv6Gateway = testutil.MustParse6("a00::3") -) - -var localIPv4AddrWithPrefix = tcpip.AddressWithPrefix{ - Address: localIPv4Addr, - PrefixLen: 24, -} - -var localIPv6AddrWithPrefix = tcpip.AddressWithPrefix{ - Address: localIPv6Addr, - PrefixLen: 120, -} - -type transportError struct { - origin tcpip.SockErrOrigin - typ uint8 - code uint8 - info uint32 - kind stack.TransportErrorKind -} - -// testObject implements two interfaces: LinkEndpoint and TransportDispatcher. -// The former is used to pretend that it's a link endpoint so that we can -// inspect packets written by the network endpoints. The latter is used to -// pretend that it's the network stack so that it can inspect incoming packets -// that have been handled by the network endpoints. -// -// Packets are checked by comparing their fields/values against the expected -// values stored in the test object itself. -type testObject struct { - t *testing.T - protocol tcpip.TransportProtocolNumber - contents []byte - srcAddr tcpip.Address - dstAddr tcpip.Address - v4 bool - transErr transportError - - dataCalls int - controlCalls int - rawCalls int -} - -// checkValues verifies that the transport protocol, data contents, src & dst -// addresses of a packet match what's expected. If any field doesn't match, the -// test fails. -func (t *testObject) checkValues(protocol tcpip.TransportProtocolNumber, v buffer.View, srcAddr, dstAddr tcpip.Address) { - if protocol != t.protocol { - t.t.Errorf("protocol = %v, want %v", protocol, t.protocol) - } - - if srcAddr != t.srcAddr { - t.t.Errorf("srcAddr = %v, want %v", srcAddr, t.srcAddr) - } - - if dstAddr != t.dstAddr { - t.t.Errorf("dstAddr = %v, want %v", dstAddr, t.dstAddr) - } - - if len(v) != len(t.contents) { - t.t.Fatalf("len(payload) = %v, want %v", len(v), len(t.contents)) - } - - for i := range t.contents { - if t.contents[i] != v[i] { - t.t.Fatalf("payload[%v] = %v, want %v", i, v[i], t.contents[i]) - } - } -} - -// DeliverTransportPacket is called by network endpoints after parsing incoming -// packets. This is used by the test object to verify that the results of the -// parsing are expected. -func (t *testObject) DeliverTransportPacket(protocol tcpip.TransportProtocolNumber, pkt *stack.PacketBuffer) stack.TransportPacketDisposition { - netHdr := pkt.Network() - t.checkValues(protocol, pkt.Data().AsRange().ToOwnedView(), netHdr.SourceAddress(), netHdr.DestinationAddress()) - t.dataCalls++ - return stack.TransportPacketHandled -} - -// DeliverTransportError is called by network endpoints after parsing -// incoming control (ICMP) packets. This is used by the test object to verify -// that the results of the parsing are expected. -func (t *testObject) DeliverTransportError(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, transErr stack.TransportError, pkt *stack.PacketBuffer) { - t.checkValues(trans, pkt.Data().AsRange().ToOwnedView(), remote, local) - if diff := cmp.Diff( - t.transErr, - transportError{ - origin: transErr.Origin(), - typ: transErr.Type(), - code: transErr.Code(), - info: transErr.Info(), - kind: transErr.Kind(), - }, - cmp.AllowUnexported(transportError{}), - ); diff != "" { - t.t.Errorf("transport error mismatch (-want +got):\n%s", diff) - } - t.controlCalls++ -} - -func (t *testObject) DeliverRawPacket(tcpip.TransportProtocolNumber, *stack.PacketBuffer) { - t.rawCalls++ -} - -// Attach is only implemented to satisfy the LinkEndpoint interface. -func (*testObject) Attach(stack.NetworkDispatcher) {} - -// IsAttached implements stack.LinkEndpoint.IsAttached. -func (*testObject) IsAttached() bool { - return true -} - -// MTU implements stack.LinkEndpoint.MTU. It just returns a constant that -// matches the linux loopback MTU. -func (*testObject) MTU() uint32 { - return 65536 -} - -// Capabilities implements stack.LinkEndpoint.Capabilities. -func (*testObject) Capabilities() stack.LinkEndpointCapabilities { - return 0 -} - -// MaxHeaderLength is only implemented to satisfy the LinkEndpoint interface. -func (*testObject) MaxHeaderLength() uint16 { - return 0 -} - -// LinkAddress returns the link address of this endpoint. -func (*testObject) LinkAddress() tcpip.LinkAddress { - return "" -} - -// Wait implements stack.LinkEndpoint.Wait. -func (*testObject) Wait() {} - -// WritePacket is called by network endpoints after producing a packet and -// writing it to the link endpoint. This is used by the test object to verify -// that the produced packet is as expected. -func (t *testObject) WritePacket(_ *stack.Route, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { - var prot tcpip.TransportProtocolNumber - var srcAddr tcpip.Address - var dstAddr tcpip.Address - - if t.v4 { - h := header.IPv4(pkt.NetworkHeader().View()) - prot = tcpip.TransportProtocolNumber(h.Protocol()) - srcAddr = h.SourceAddress() - dstAddr = h.DestinationAddress() - - } else { - h := header.IPv6(pkt.NetworkHeader().View()) - prot = tcpip.TransportProtocolNumber(h.NextHeader()) - srcAddr = h.SourceAddress() - dstAddr = h.DestinationAddress() - } - t.checkValues(prot, pkt.Data().AsRange().ToOwnedView(), srcAddr, dstAddr) - return nil -} - -// WritePackets implements stack.LinkEndpoint.WritePackets. -func (*testObject) WritePackets(_ *stack.Route, pkt stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { - panic("not implemented") -} - -// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. -func (*testObject) ARPHardwareType() header.ARPHardwareType { - panic("not implemented") -} - -// AddHeader implements stack.LinkEndpoint.AddHeader. -func (*testObject) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { - panic("not implemented") -} - -func buildIPv4Route(local, remote tcpip.Address) (*stack.Route, tcpip.Error) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol}, - }) - s.CreateNIC(nicID, loopback.New()) - s.AddAddress(nicID, ipv4.ProtocolNumber, local) - s.SetRouteTable([]tcpip.Route{{ - Destination: header.IPv4EmptySubnet, - Gateway: ipv4Gateway, - NIC: 1, - }}) - - return s.FindRoute(nicID, local, remote, ipv4.ProtocolNumber, false /* multicastLoop */) -} - -func buildIPv6Route(local, remote tcpip.Address) (*stack.Route, tcpip.Error) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol}, - }) - s.CreateNIC(nicID, loopback.New()) - s.AddAddress(nicID, ipv6.ProtocolNumber, local) - s.SetRouteTable([]tcpip.Route{{ - Destination: header.IPv6EmptySubnet, - Gateway: ipv6Gateway, - NIC: 1, - }}) - - return s.FindRoute(nicID, local, remote, ipv6.ProtocolNumber, false /* multicastLoop */) -} - -func buildDummyStackWithLinkEndpoint(t *testing.T, mtu uint32) (*stack.Stack, *channel.Endpoint) { - t.Helper() - - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol}, - }) - e := channel.New(1, mtu, "") - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - - v4Addr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: localIPv4AddrWithPrefix} - if err := s.AddProtocolAddress(nicID, v4Addr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v) = %s", nicID, v4Addr, err) - } - - v6Addr := tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: localIPv6AddrWithPrefix} - if err := s.AddProtocolAddress(nicID, v6Addr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v) = %s", nicID, v6Addr, err) - } - - return s, e -} - -func buildDummyStack(t *testing.T) *stack.Stack { - t.Helper() - - s, _ := buildDummyStackWithLinkEndpoint(t, header.IPv6MinimumMTU) - return s -} - -var _ stack.NetworkInterface = (*testInterface)(nil) - -type testInterface struct { - testObject - - mu struct { - sync.RWMutex - disabled bool - } -} - -func (*testInterface) ID() tcpip.NICID { - return nicID -} - -func (*testInterface) IsLoopback() bool { - return false -} - -func (*testInterface) Name() string { - return "" -} - -func (t *testInterface) Enabled() bool { - t.mu.RLock() - defer t.mu.RUnlock() - return !t.mu.disabled -} - -func (*testInterface) Promiscuous() bool { - return false -} - -func (*testInterface) Spoofing() bool { - return false -} - -func (t *testInterface) setEnabled(v bool) { - t.mu.Lock() - defer t.mu.Unlock() - t.mu.disabled = !v -} - -func (*testInterface) WritePacketToRemote(tcpip.LinkAddress, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) tcpip.Error { - return &tcpip.ErrNotSupported{} -} - -func (*testInterface) HandleNeighborProbe(tcpip.NetworkProtocolNumber, tcpip.Address, tcpip.LinkAddress) tcpip.Error { - return nil -} - -func (*testInterface) HandleNeighborConfirmation(tcpip.NetworkProtocolNumber, tcpip.Address, tcpip.LinkAddress, stack.ReachabilityConfirmationFlags) tcpip.Error { - return nil -} - -func (*testInterface) PrimaryAddress(tcpip.NetworkProtocolNumber) (tcpip.AddressWithPrefix, tcpip.Error) { - return tcpip.AddressWithPrefix{}, nil -} - -func (*testInterface) CheckLocalAddress(tcpip.NetworkProtocolNumber, tcpip.Address) bool { - return false -} - -func TestSourceAddressValidation(t *testing.T) { - rxIPv4ICMP := func(e *channel.Endpoint, src tcpip.Address) { - totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize - hdr := buffer.NewPrependable(totalLen) - pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) - pkt.SetType(header.ICMPv4Echo) - pkt.SetCode(0) - pkt.SetChecksum(0) - pkt.SetChecksum(^header.Checksum(pkt, 0)) - ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) - ip.Encode(&header.IPv4Fields{ - TotalLength: uint16(totalLen), - Protocol: uint8(icmp.ProtocolNumber4), - TTL: ipv4.DefaultTTL, - SrcAddr: src, - DstAddr: localIPv4Addr, - }) - ip.SetChecksum(^ip.CalculateChecksum()) - - e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - })) - } - - rxIPv6ICMP := func(e *channel.Endpoint, src tcpip.Address) { - totalLen := header.IPv6MinimumSize + header.ICMPv6MinimumSize - hdr := buffer.NewPrependable(totalLen) - pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize)) - pkt.SetType(header.ICMPv6EchoRequest) - pkt.SetCode(0) - pkt.SetChecksum(0) - pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ - Header: pkt, - Src: src, - Dst: localIPv6Addr, - })) - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: header.ICMPv6MinimumSize, - TransportProtocol: icmp.ProtocolNumber6, - HopLimit: ipv6.DefaultTTL, - SrcAddr: src, - DstAddr: localIPv6Addr, - }) - e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - })) - } - - tests := []struct { - name string - srcAddress tcpip.Address - rxICMP func(*channel.Endpoint, tcpip.Address) - valid bool - }{ - { - name: "IPv4 valid", - srcAddress: "\x01\x02\x03\x04", - rxICMP: rxIPv4ICMP, - valid: true, - }, - { - name: "IPv6 valid", - srcAddress: "\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10", - rxICMP: rxIPv6ICMP, - valid: true, - }, - { - name: "IPv4 unspecified", - srcAddress: header.IPv4Any, - rxICMP: rxIPv4ICMP, - valid: true, - }, - { - name: "IPv6 unspecified", - srcAddress: header.IPv4Any, - rxICMP: rxIPv6ICMP, - valid: true, - }, - { - name: "IPv4 multicast", - srcAddress: "\xe0\x00\x00\x01", - rxICMP: rxIPv4ICMP, - valid: false, - }, - { - name: "IPv6 multicast", - srcAddress: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", - rxICMP: rxIPv6ICMP, - valid: false, - }, - { - name: "IPv4 broadcast", - srcAddress: header.IPv4Broadcast, - rxICMP: rxIPv4ICMP, - valid: false, - }, - { - name: "IPv4 subnet broadcast", - srcAddress: func() tcpip.Address { - subnet := localIPv4AddrWithPrefix.Subnet() - return subnet.Broadcast() - }(), - rxICMP: rxIPv4ICMP, - valid: false, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s, e := buildDummyStackWithLinkEndpoint(t, header.IPv6MinimumMTU) - test.rxICMP(e, test.srcAddress) - - var wantValid uint64 - if test.valid { - wantValid = 1 - } - - if got, want := s.Stats().IP.InvalidSourceAddressesReceived.Value(), 1-wantValid; got != want { - t.Errorf("got s.Stats().IP.InvalidSourceAddressesReceived.Value() = %d, want = %d", got, want) - } - if got := s.Stats().IP.PacketsDelivered.Value(); got != wantValid { - t.Errorf("got s.Stats().IP.PacketsDelivered.Value() = %d, want = %d", got, wantValid) - } - }) - } -} - -func TestEnableWhenNICDisabled(t *testing.T) { - tests := []struct { - name string - protocolFactory stack.NetworkProtocolFactory - protoNum tcpip.NetworkProtocolNumber - }{ - { - name: "IPv4", - protocolFactory: ipv4.NewProtocol, - protoNum: ipv4.ProtocolNumber, - }, - { - name: "IPv6", - protocolFactory: ipv6.NewProtocol, - protoNum: ipv6.ProtocolNumber, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - var nic testInterface - nic.setEnabled(false) - - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{test.protocolFactory}, - }) - p := s.NetworkProtocolInstance(test.protoNum) - - // We pass nil for all parameters except the NetworkInterface and Stack - // since Enable only depends on these. - ep := p.NewEndpoint(&nic, nil) - - // The endpoint should initially be disabled, regardless the NIC's enabled - // status. - if ep.Enabled() { - t.Fatal("got ep.Enabled() = true, want = false") - } - nic.setEnabled(true) - if ep.Enabled() { - t.Fatal("got ep.Enabled() = true, want = false") - } - - // Attempting to enable the endpoint while the NIC is disabled should - // fail. - nic.setEnabled(false) - err := ep.Enable() - if _, ok := err.(*tcpip.ErrNotPermitted); !ok { - t.Fatalf("got ep.Enable() = %s, want = %s", err, &tcpip.ErrNotPermitted{}) - } - // ep should consider the NIC's enabled status when determining its own - // enabled status so we "enable" the NIC to read just the endpoint's - // enabled status. - nic.setEnabled(true) - if ep.Enabled() { - t.Fatal("got ep.Enabled() = true, want = false") - } - - // Enabling the interface after the NIC has been enabled should succeed. - if err := ep.Enable(); err != nil { - t.Fatalf("ep.Enable(): %s", err) - } - if !ep.Enabled() { - t.Fatal("got ep.Enabled() = false, want = true") - } - - // ep should consider the NIC's enabled status when determining its own - // enabled status. - nic.setEnabled(false) - if ep.Enabled() { - t.Fatal("got ep.Enabled() = true, want = false") - } - - // Disabling the endpoint when the NIC is enabled should make the endpoint - // disabled. - nic.setEnabled(true) - ep.Disable() - if ep.Enabled() { - t.Fatal("got ep.Enabled() = true, want = false") - } - }) - } -} - -func TestIPv4Send(t *testing.T) { - s := buildDummyStack(t) - proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber) - nic := testInterface{ - testObject: testObject{ - t: t, - v4: true, - }, - } - ep := proto.NewEndpoint(&nic, nil) - defer ep.Close() - - // Allocate and initialize the payload view. - payload := buffer.NewView(100) - for i := 0; i < len(payload); i++ { - payload[i] = uint8(i) - } - - // Setup the packet buffer. - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(ep.MaxHeaderLength()), - Data: payload.ToVectorisedView(), - }) - - // Issue the write. - nic.testObject.protocol = 123 - nic.testObject.srcAddr = localIPv4Addr - nic.testObject.dstAddr = remoteIPv4Addr - nic.testObject.contents = payload - - r, err := buildIPv4Route(localIPv4Addr, remoteIPv4Addr) - if err != nil { - t.Fatalf("could not find route: %v", err) - } - if err := ep.WritePacket(r, stack.NetworkHeaderParams{ - Protocol: 123, - TTL: 123, - TOS: stack.DefaultTOS, - }, pkt); err != nil { - t.Fatalf("WritePacket failed: %v", err) - } -} - -func TestReceive(t *testing.T) { - tests := []struct { - name string - protoFactory stack.NetworkProtocolFactory - protoNum tcpip.NetworkProtocolNumber - v4 bool - epAddr tcpip.AddressWithPrefix - handlePacket func(*testing.T, stack.NetworkEndpoint, *testInterface) - }{ - { - name: "IPv4", - protoFactory: ipv4.NewProtocol, - protoNum: ipv4.ProtocolNumber, - v4: true, - epAddr: localIPv4Addr.WithPrefix(), - handlePacket: func(t *testing.T, ep stack.NetworkEndpoint, nic *testInterface) { - const totalLen = header.IPv4MinimumSize + 30 /* payload length */ - - view := buffer.NewView(totalLen) - ip := header.IPv4(view) - ip.Encode(&header.IPv4Fields{ - TotalLength: totalLen, - TTL: ipv4.DefaultTTL, - Protocol: 10, - SrcAddr: remoteIPv4Addr, - DstAddr: localIPv4Addr, - }) - ip.SetChecksum(^ip.CalculateChecksum()) - - // Make payload be non-zero. - for i := header.IPv4MinimumSize; i < len(view); i++ { - view[i] = uint8(i) - } - - // Give packet to ipv4 endpoint, dispatcher will validate that it's ok. - nic.testObject.protocol = 10 - nic.testObject.srcAddr = remoteIPv4Addr - nic.testObject.dstAddr = localIPv4Addr - nic.testObject.contents = view[header.IPv4MinimumSize:totalLen] - - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: view.ToVectorisedView(), - }) - ep.HandlePacket(pkt) - }, - }, - { - name: "IPv6", - protoFactory: ipv6.NewProtocol, - protoNum: ipv6.ProtocolNumber, - v4: false, - epAddr: localIPv6Addr.WithPrefix(), - handlePacket: func(t *testing.T, ep stack.NetworkEndpoint, nic *testInterface) { - const payloadLen = 30 - view := buffer.NewView(header.IPv6MinimumSize + payloadLen) - ip := header.IPv6(view) - ip.Encode(&header.IPv6Fields{ - PayloadLength: payloadLen, - TransportProtocol: 10, - HopLimit: ipv6.DefaultTTL, - SrcAddr: remoteIPv6Addr, - DstAddr: localIPv6Addr, - }) - - // Make payload be non-zero. - for i := header.IPv6MinimumSize; i < len(view); i++ { - view[i] = uint8(i) - } - - // Give packet to ipv6 endpoint, dispatcher will validate that it's ok. - nic.testObject.protocol = 10 - nic.testObject.srcAddr = remoteIPv6Addr - nic.testObject.dstAddr = localIPv6Addr - nic.testObject.contents = view[header.IPv6MinimumSize:][:payloadLen] - - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: view.ToVectorisedView(), - }) - ep.HandlePacket(pkt) - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{test.protoFactory}, - }) - nic := testInterface{ - testObject: testObject{ - t: t, - v4: test.v4, - }, - } - ep := s.NetworkProtocolInstance(test.protoNum).NewEndpoint(&nic, &nic.testObject) - defer ep.Close() - - if err := ep.Enable(); err != nil { - t.Fatalf("ep.Enable(): %s", err) - } - - addressableEndpoint, ok := ep.(stack.AddressableEndpoint) - if !ok { - t.Fatalf("expected network endpoint with number = %d to implement stack.AddressableEndpoint", test.protoNum) - } - if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(test.epAddr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil { - t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", test.epAddr, err) - } else { - ep.DecRef() - } - - stat := s.Stats().IP.PacketsReceived - if got := stat.Value(); got != 0 { - t.Fatalf("got s.Stats().IP.PacketsReceived.Value() = %d, want = 0", got) - } - test.handlePacket(t, ep, &nic) - if nic.testObject.dataCalls != 1 { - t.Errorf("Bad number of data calls: got %d, want 1", nic.testObject.dataCalls) - } - if nic.testObject.rawCalls != 1 { - t.Errorf("Bad number of raw calls: got %d, want 1", nic.testObject.rawCalls) - } - if got := stat.Value(); got != 1 { - t.Errorf("got s.Stats().IP.PacketsReceived.Value() = %d, want = 1", got) - } - }) - } -} - -func TestIPv4ReceiveControl(t *testing.T) { - const ( - mtu = 0xbeef - header.IPv4MinimumSize - dataLen = 8 - ) - - cases := []struct { - name string - expectedCount int - fragmentOffset uint16 - code header.ICMPv4Code - transErr transportError - trunc int - }{ - { - name: "FragmentationNeeded", - expectedCount: 1, - fragmentOffset: 0, - code: header.ICMPv4FragmentationNeeded, - transErr: transportError{ - origin: tcpip.SockExtErrorOriginICMP, - typ: uint8(header.ICMPv4DstUnreachable), - code: uint8(header.ICMPv4FragmentationNeeded), - info: mtu, - kind: stack.PacketTooBigTransportError, - }, - trunc: 0, - }, - { - name: "Truncated (missing IPv4 header)", - expectedCount: 0, - fragmentOffset: 0, - code: header.ICMPv4FragmentationNeeded, - trunc: header.IPv4MinimumSize + header.ICMPv4MinimumSize, - }, - { - name: "Truncated (partial offending packet's IP header)", - expectedCount: 0, - fragmentOffset: 0, - code: header.ICMPv4FragmentationNeeded, - trunc: header.IPv4MinimumSize + header.ICMPv4MinimumSize + header.IPv4MinimumSize - 1, - }, - { - name: "Truncated (partial offending packet's data)", - expectedCount: 0, - fragmentOffset: 0, - code: header.ICMPv4FragmentationNeeded, - trunc: header.ICMPv4MinimumSize + header.ICMPv4MinimumSize + header.IPv4MinimumSize + dataLen - 1, - }, - { - name: "Port unreachable", - expectedCount: 1, - fragmentOffset: 0, - code: header.ICMPv4PortUnreachable, - transErr: transportError{ - origin: tcpip.SockExtErrorOriginICMP, - typ: uint8(header.ICMPv4DstUnreachable), - code: uint8(header.ICMPv4PortUnreachable), - kind: stack.DestinationPortUnreachableTransportError, - }, - trunc: 0, - }, - { - name: "Non-zero fragment offset", - expectedCount: 0, - fragmentOffset: 100, - code: header.ICMPv4PortUnreachable, - trunc: 0, - }, - { - name: "Zero-length packet", - expectedCount: 0, - fragmentOffset: 100, - code: header.ICMPv4PortUnreachable, - trunc: 2*header.IPv4MinimumSize + header.ICMPv4MinimumSize + dataLen, - }, - } - for _, c := range cases { - t.Run(c.name, func(t *testing.T) { - s := buildDummyStack(t) - proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber) - nic := testInterface{ - testObject: testObject{ - t: t, - }, - } - ep := proto.NewEndpoint(&nic, &nic.testObject) - defer ep.Close() - - if err := ep.Enable(); err != nil { - t.Fatalf("ep.Enable(): %s", err) - } - - const dataOffset = header.IPv4MinimumSize*2 + header.ICMPv4MinimumSize - view := buffer.NewView(dataOffset + dataLen) - - // Create the outer IPv4 header. - ip := header.IPv4(view) - ip.Encode(&header.IPv4Fields{ - TotalLength: uint16(len(view) - c.trunc), - TTL: 20, - Protocol: uint8(header.ICMPv4ProtocolNumber), - SrcAddr: "\x0a\x00\x00\xbb", - DstAddr: localIPv4Addr, - }) - ip.SetChecksum(^ip.CalculateChecksum()) - - // Create the ICMP header. - icmp := header.ICMPv4(view[header.IPv4MinimumSize:]) - icmp.SetType(header.ICMPv4DstUnreachable) - icmp.SetCode(c.code) - icmp.SetIdent(0xdead) - icmp.SetSequence(0xbeef) - - // Create the inner IPv4 header. - ip = header.IPv4(view[header.IPv4MinimumSize+header.ICMPv4MinimumSize:]) - ip.Encode(&header.IPv4Fields{ - TotalLength: 100, - TTL: 20, - Protocol: 10, - FragmentOffset: c.fragmentOffset, - SrcAddr: localIPv4Addr, - DstAddr: remoteIPv4Addr, - }) - ip.SetChecksum(^ip.CalculateChecksum()) - - // Make payload be non-zero. - for i := dataOffset; i < len(view); i++ { - view[i] = uint8(i) - } - - icmp.SetChecksum(0) - checksum := ^header.Checksum(icmp, 0 /* initial */) - icmp.SetChecksum(checksum) - - // Give packet to IPv4 endpoint, dispatcher will validate that - // it's ok. - nic.testObject.protocol = 10 - nic.testObject.srcAddr = remoteIPv4Addr - nic.testObject.dstAddr = localIPv4Addr - nic.testObject.contents = view[dataOffset:] - nic.testObject.transErr = c.transErr - - addressableEndpoint, ok := ep.(stack.AddressableEndpoint) - if !ok { - t.Fatal("expected IPv4 network endpoint to implement stack.AddressableEndpoint") - } - addr := localIPv4Addr.WithPrefix() - if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil { - t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err) - } else { - ep.DecRef() - } - - pkt := truncatedPacket(view, c.trunc, header.IPv4MinimumSize) - ep.HandlePacket(pkt) - if want := c.expectedCount; nic.testObject.controlCalls != want { - t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, nic.testObject.controlCalls, want) - } - }) - } -} - -func TestIPv4FragmentationReceive(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, - }) - proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber) - nic := testInterface{ - testObject: testObject{ - t: t, - v4: true, - }, - } - ep := proto.NewEndpoint(&nic, &nic.testObject) - defer ep.Close() - - if err := ep.Enable(); err != nil { - t.Fatalf("ep.Enable(): %s", err) - } - - totalLen := header.IPv4MinimumSize + 24 - - frag1 := buffer.NewView(totalLen) - ip1 := header.IPv4(frag1) - ip1.Encode(&header.IPv4Fields{ - TotalLength: uint16(totalLen), - TTL: 20, - Protocol: 10, - FragmentOffset: 0, - Flags: header.IPv4FlagMoreFragments, - SrcAddr: remoteIPv4Addr, - DstAddr: localIPv4Addr, - }) - ip1.SetChecksum(^ip1.CalculateChecksum()) - - // Make payload be non-zero. - for i := header.IPv4MinimumSize; i < totalLen; i++ { - frag1[i] = uint8(i) - } - - frag2 := buffer.NewView(totalLen) - ip2 := header.IPv4(frag2) - ip2.Encode(&header.IPv4Fields{ - TotalLength: uint16(totalLen), - TTL: 20, - Protocol: 10, - FragmentOffset: 24, - SrcAddr: remoteIPv4Addr, - DstAddr: localIPv4Addr, - }) - ip2.SetChecksum(^ip2.CalculateChecksum()) - - // Make payload be non-zero. - for i := header.IPv4MinimumSize; i < totalLen; i++ { - frag2[i] = uint8(i) - } - - // Give packet to ipv4 endpoint, dispatcher will validate that it's ok. - nic.testObject.protocol = 10 - nic.testObject.srcAddr = remoteIPv4Addr - nic.testObject.dstAddr = localIPv4Addr - nic.testObject.contents = append(frag1[header.IPv4MinimumSize:totalLen], frag2[header.IPv4MinimumSize:totalLen]...) - - // Send first segment. - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: frag1.ToVectorisedView(), - }) - - addressableEndpoint, ok := ep.(stack.AddressableEndpoint) - if !ok { - t.Fatal("expected IPv4 network endpoint to implement stack.AddressableEndpoint") - } - addr := localIPv4Addr.WithPrefix() - if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil { - t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err) - } else { - ep.DecRef() - } - - ep.HandlePacket(pkt) - if nic.testObject.dataCalls != 0 { - t.Fatalf("Bad number of data calls: got %d, want 0", nic.testObject.dataCalls) - } - if nic.testObject.rawCalls != 0 { - t.Errorf("Bad number of raw calls: got %d, want 0", nic.testObject.rawCalls) - } - - // Send second segment. - pkt = stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: frag2.ToVectorisedView(), - }) - ep.HandlePacket(pkt) - if nic.testObject.dataCalls != 1 { - t.Fatalf("Bad number of data calls: got %d, want 1", nic.testObject.dataCalls) - } - if nic.testObject.rawCalls != 1 { - t.Errorf("Bad number of raw calls: got %d, want 1", nic.testObject.rawCalls) - } -} - -func TestIPv6Send(t *testing.T) { - s := buildDummyStack(t) - proto := s.NetworkProtocolInstance(ipv6.ProtocolNumber) - nic := testInterface{ - testObject: testObject{ - t: t, - }, - } - ep := proto.NewEndpoint(&nic, nil) - defer ep.Close() - - if err := ep.Enable(); err != nil { - t.Fatalf("ep.Enable(): %s", err) - } - - // Allocate and initialize the payload view. - payload := buffer.NewView(100) - for i := 0; i < len(payload); i++ { - payload[i] = uint8(i) - } - - // Setup the packet buffer. - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(ep.MaxHeaderLength()), - Data: payload.ToVectorisedView(), - }) - - // Issue the write. - nic.testObject.protocol = 123 - nic.testObject.srcAddr = localIPv6Addr - nic.testObject.dstAddr = remoteIPv6Addr - nic.testObject.contents = payload - - r, err := buildIPv6Route(localIPv6Addr, remoteIPv6Addr) - if err != nil { - t.Fatalf("could not find route: %v", err) - } - if err := ep.WritePacket(r, stack.NetworkHeaderParams{ - Protocol: 123, - TTL: 123, - TOS: stack.DefaultTOS, - }, pkt); err != nil { - t.Fatalf("WritePacket failed: %v", err) - } -} - -func TestIPv6ReceiveControl(t *testing.T) { - const ( - mtu = 0xffff - outerSrcAddr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xaa" - dataLen = 8 - ) - - newUint16 := func(v uint16) *uint16 { return &v } - - portUnreachableTransErr := transportError{ - origin: tcpip.SockExtErrorOriginICMP6, - typ: uint8(header.ICMPv6DstUnreachable), - code: uint8(header.ICMPv6PortUnreachable), - kind: stack.DestinationPortUnreachableTransportError, - } - - cases := []struct { - name string - expectedCount int - fragmentOffset *uint16 - typ header.ICMPv6Type - code header.ICMPv6Code - transErr transportError - trunc int - }{ - { - name: "PacketTooBig", - expectedCount: 1, - fragmentOffset: nil, - typ: header.ICMPv6PacketTooBig, - code: header.ICMPv6UnusedCode, - transErr: transportError{ - origin: tcpip.SockExtErrorOriginICMP6, - typ: uint8(header.ICMPv6PacketTooBig), - code: uint8(header.ICMPv6UnusedCode), - info: mtu, - kind: stack.PacketTooBigTransportError, - }, - trunc: 0, - }, - { - name: "Truncated (missing offending packet's IPv6 header)", - expectedCount: 0, - fragmentOffset: nil, - typ: header.ICMPv6PacketTooBig, - code: header.ICMPv6UnusedCode, - trunc: header.IPv6MinimumSize + header.ICMPv6PacketTooBigMinimumSize, - }, - { - name: "Truncated PacketTooBig (partial offending packet's IPv6 header)", - expectedCount: 0, - fragmentOffset: nil, - typ: header.ICMPv6PacketTooBig, - code: header.ICMPv6UnusedCode, - trunc: header.IPv6MinimumSize + header.ICMPv6PacketTooBigMinimumSize + header.IPv6MinimumSize - 1, - }, - { - name: "Truncated (partial offending packet's data)", - expectedCount: 0, - fragmentOffset: nil, - typ: header.ICMPv6PacketTooBig, - code: header.ICMPv6UnusedCode, - trunc: header.IPv6MinimumSize + header.ICMPv6PacketTooBigMinimumSize + header.IPv6MinimumSize + dataLen - 1, - }, - { - name: "Port unreachable", - expectedCount: 1, - fragmentOffset: nil, - typ: header.ICMPv6DstUnreachable, - code: header.ICMPv6PortUnreachable, - transErr: portUnreachableTransErr, - trunc: 0, - }, - { - name: "Truncated DstPortUnreachable (partial offending packet's IP header)", - expectedCount: 0, - fragmentOffset: nil, - typ: header.ICMPv6DstUnreachable, - code: header.ICMPv6PortUnreachable, - trunc: header.IPv6MinimumSize + header.ICMPv6DstUnreachableMinimumSize + header.IPv6MinimumSize - 1, - }, - { - name: "DstPortUnreachable for Fragmented, zero offset", - expectedCount: 1, - fragmentOffset: newUint16(0), - typ: header.ICMPv6DstUnreachable, - code: header.ICMPv6PortUnreachable, - transErr: portUnreachableTransErr, - trunc: 0, - }, - { - name: "DstPortUnreachable for Non-zero fragment offset", - expectedCount: 0, - fragmentOffset: newUint16(100), - typ: header.ICMPv6DstUnreachable, - code: header.ICMPv6PortUnreachable, - transErr: portUnreachableTransErr, - trunc: 0, - }, - { - name: "Zero-length packet", - expectedCount: 0, - fragmentOffset: nil, - typ: header.ICMPv6DstUnreachable, - code: header.ICMPv6PortUnreachable, - trunc: 2*header.IPv6MinimumSize + header.ICMPv6DstUnreachableMinimumSize + dataLen, - }, - } - for _, c := range cases { - t.Run(c.name, func(t *testing.T) { - s := buildDummyStack(t) - proto := s.NetworkProtocolInstance(ipv6.ProtocolNumber) - nic := testInterface{ - testObject: testObject{ - t: t, - }, - } - ep := proto.NewEndpoint(&nic, &nic.testObject) - defer ep.Close() - - if err := ep.Enable(); err != nil { - t.Fatalf("ep.Enable(): %s", err) - } - - dataOffset := header.IPv6MinimumSize*2 + header.ICMPv6MinimumSize - if c.fragmentOffset != nil { - dataOffset += header.IPv6FragmentHeaderSize - } - view := buffer.NewView(dataOffset + dataLen) - - // Create the outer IPv6 header. - ip := header.IPv6(view) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(len(view) - header.IPv6MinimumSize - c.trunc), - TransportProtocol: header.ICMPv6ProtocolNumber, - HopLimit: 20, - SrcAddr: outerSrcAddr, - DstAddr: localIPv6Addr, - }) - - // Create the ICMP header. - icmp := header.ICMPv6(view[header.IPv6MinimumSize:]) - icmp.SetType(c.typ) - icmp.SetCode(c.code) - icmp.SetIdent(0xdead) - icmp.SetSequence(0xbeef) - - var extHdrs header.IPv6ExtHdrSerializer - // Build the fragmentation header if needed. - if c.fragmentOffset != nil { - extHdrs = append(extHdrs, &header.IPv6SerializableFragmentExtHdr{ - FragmentOffset: *c.fragmentOffset, - M: true, - Identification: 0x12345678, - }) - } - - // Create the inner IPv6 header. - ip = header.IPv6(view[header.IPv6MinimumSize+header.ICMPv6PayloadOffset:]) - ip.Encode(&header.IPv6Fields{ - PayloadLength: 100, - TransportProtocol: 10, - HopLimit: 20, - SrcAddr: localIPv6Addr, - DstAddr: remoteIPv6Addr, - ExtensionHeaders: extHdrs, - }) - - // Make payload be non-zero. - for i := dataOffset; i < len(view); i++ { - view[i] = uint8(i) - } - - // Give packet to IPv6 endpoint, dispatcher will validate that - // it's ok. - nic.testObject.protocol = 10 - nic.testObject.srcAddr = remoteIPv6Addr - nic.testObject.dstAddr = localIPv6Addr - nic.testObject.contents = view[dataOffset:] - nic.testObject.transErr = c.transErr - - // Set ICMPv6 checksum. - icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ - Header: icmp, - Src: outerSrcAddr, - Dst: localIPv6Addr, - })) - - addressableEndpoint, ok := ep.(stack.AddressableEndpoint) - if !ok { - t.Fatal("expected IPv6 network endpoint to implement stack.AddressableEndpoint") - } - addr := localIPv6Addr.WithPrefix() - if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil { - t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err) - } else { - ep.DecRef() - } - pkt := truncatedPacket(view, c.trunc, header.IPv6MinimumSize) - ep.HandlePacket(pkt) - if want := c.expectedCount; nic.testObject.controlCalls != want { - t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, nic.testObject.controlCalls, want) - } - }) - } -} - -// truncatedPacket returns a PacketBuffer based on a truncated view. If view, -// after truncation, is large enough to hold a network header, it makes part of -// view the packet's NetworkHeader and the rest its Data. Otherwise all of view -// becomes Data. -func truncatedPacket(view buffer.View, trunc, netHdrLen int) *stack.PacketBuffer { - v := view[:len(view)-trunc] - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: v.ToVectorisedView(), - }) - return pkt -} - -func TestWriteHeaderIncludedPacket(t *testing.T) { - const ( - nicID = 1 - transportProto = 5 - - dataLen = 4 - ) - - dataBuf := [dataLen]byte{1, 2, 3, 4} - data := dataBuf[:] - - ipv4Options := header.IPv4OptionsSerializer{ - &header.IPv4SerializableListEndOption{}, - &header.IPv4SerializableNOPOption{}, - &header.IPv4SerializableListEndOption{}, - &header.IPv4SerializableNOPOption{}, - } - - expectOptions := header.IPv4Options{ - byte(header.IPv4OptionListEndType), - byte(header.IPv4OptionNOPType), - byte(header.IPv4OptionListEndType), - byte(header.IPv4OptionNOPType), - } - - ipv6FragmentExtHdrBuf := [header.IPv6FragmentExtHdrLength]byte{transportProto, 0, 62, 4, 1, 2, 3, 4} - ipv6FragmentExtHdr := ipv6FragmentExtHdrBuf[:] - - var ipv6PayloadWithExtHdrBuf [dataLen + header.IPv6FragmentExtHdrLength]byte - ipv6PayloadWithExtHdr := ipv6PayloadWithExtHdrBuf[:] - if n := copy(ipv6PayloadWithExtHdr, ipv6FragmentExtHdr); n != len(ipv6FragmentExtHdr) { - t.Fatalf("copied %d bytes, expected %d bytes", n, len(ipv6FragmentExtHdr)) - } - if n := copy(ipv6PayloadWithExtHdr[header.IPv6FragmentExtHdrLength:], data); n != len(data) { - t.Fatalf("copied %d bytes, expected %d bytes", n, len(data)) - } - - tests := []struct { - name string - protoFactory stack.NetworkProtocolFactory - protoNum tcpip.NetworkProtocolNumber - nicAddr tcpip.Address - remoteAddr tcpip.Address - pktGen func(*testing.T, tcpip.Address) buffer.VectorisedView - checker func(*testing.T, *stack.PacketBuffer, tcpip.Address) - expectedErr tcpip.Error - }{ - { - name: "IPv4", - protoFactory: ipv4.NewProtocol, - protoNum: ipv4.ProtocolNumber, - nicAddr: localIPv4Addr, - remoteAddr: remoteIPv4Addr, - pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { - totalLen := header.IPv4MinimumSize + len(data) - hdr := buffer.NewPrependable(totalLen) - if n := copy(hdr.Prepend(len(data)), data); n != len(data) { - t.Fatalf("copied %d bytes, expected %d bytes", n, len(data)) - } - ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) - ip.Encode(&header.IPv4Fields{ - Protocol: transportProto, - TTL: ipv4.DefaultTTL, - SrcAddr: src, - DstAddr: remoteIPv4Addr, - }) - return hdr.View().ToVectorisedView() - }, - checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) { - if src == header.IPv4Any { - src = localIPv4Addr - } - - netHdr := pkt.NetworkHeader() - - if len(netHdr.View()) != header.IPv4MinimumSize { - t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), header.IPv4MinimumSize) - } - - checker.IPv4(t, stack.PayloadSince(netHdr), - checker.SrcAddr(src), - checker.DstAddr(remoteIPv4Addr), - checker.IPv4HeaderLength(header.IPv4MinimumSize), - checker.IPFullLength(uint16(header.IPv4MinimumSize+len(data))), - checker.IPPayload(data), - ) - }, - }, - { - name: "IPv4 with IHL too small", - protoFactory: ipv4.NewProtocol, - protoNum: ipv4.ProtocolNumber, - nicAddr: localIPv4Addr, - remoteAddr: remoteIPv4Addr, - pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { - totalLen := header.IPv4MinimumSize + len(data) - hdr := buffer.NewPrependable(totalLen) - if n := copy(hdr.Prepend(len(data)), data); n != len(data) { - t.Fatalf("copied %d bytes, expected %d bytes", n, len(data)) - } - ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) - ip.Encode(&header.IPv4Fields{ - Protocol: transportProto, - TTL: ipv4.DefaultTTL, - SrcAddr: src, - DstAddr: remoteIPv4Addr, - }) - ip.SetHeaderLength(header.IPv4MinimumSize - 1) - return hdr.View().ToVectorisedView() - }, - expectedErr: &tcpip.ErrMalformedHeader{}, - }, - { - name: "IPv4 too small", - protoFactory: ipv4.NewProtocol, - protoNum: ipv4.ProtocolNumber, - nicAddr: localIPv4Addr, - remoteAddr: remoteIPv4Addr, - pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { - ip := header.IPv4(make([]byte, header.IPv4MinimumSize)) - ip.Encode(&header.IPv4Fields{ - Protocol: transportProto, - TTL: ipv4.DefaultTTL, - SrcAddr: src, - DstAddr: remoteIPv4Addr, - }) - return buffer.View(ip[:len(ip)-1]).ToVectorisedView() - }, - expectedErr: &tcpip.ErrMalformedHeader{}, - }, - { - name: "IPv4 minimum size", - protoFactory: ipv4.NewProtocol, - protoNum: ipv4.ProtocolNumber, - nicAddr: localIPv4Addr, - remoteAddr: remoteIPv4Addr, - pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { - ip := header.IPv4(make([]byte, header.IPv4MinimumSize)) - ip.Encode(&header.IPv4Fields{ - Protocol: transportProto, - TTL: ipv4.DefaultTTL, - SrcAddr: src, - DstAddr: remoteIPv4Addr, - }) - return buffer.View(ip).ToVectorisedView() - }, - checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) { - if src == header.IPv4Any { - src = localIPv4Addr - } - - netHdr := pkt.NetworkHeader() - - if len(netHdr.View()) != header.IPv4MinimumSize { - t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), header.IPv4MinimumSize) - } - - checker.IPv4(t, stack.PayloadSince(netHdr), - checker.SrcAddr(src), - checker.DstAddr(remoteIPv4Addr), - checker.IPv4HeaderLength(header.IPv4MinimumSize), - checker.IPFullLength(header.IPv4MinimumSize), - checker.IPPayload(nil), - ) - }, - }, - { - name: "IPv4 with options", - protoFactory: ipv4.NewProtocol, - protoNum: ipv4.ProtocolNumber, - nicAddr: localIPv4Addr, - remoteAddr: remoteIPv4Addr, - pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { - ipHdrLen := int(header.IPv4MinimumSize + ipv4Options.Length()) - totalLen := ipHdrLen + len(data) - hdr := buffer.NewPrependable(totalLen) - if n := copy(hdr.Prepend(len(data)), data); n != len(data) { - t.Fatalf("copied %d bytes, expected %d bytes", n, len(data)) - } - ip := header.IPv4(hdr.Prepend(ipHdrLen)) - ip.Encode(&header.IPv4Fields{ - Protocol: transportProto, - TTL: ipv4.DefaultTTL, - SrcAddr: src, - DstAddr: remoteIPv4Addr, - Options: ipv4Options, - }) - return hdr.View().ToVectorisedView() - }, - checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) { - if src == header.IPv4Any { - src = localIPv4Addr - } - - netHdr := pkt.NetworkHeader() - - hdrLen := int(header.IPv4MinimumSize + ipv4Options.Length()) - if len(netHdr.View()) != hdrLen { - t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), hdrLen) - } - - checker.IPv4(t, stack.PayloadSince(netHdr), - checker.SrcAddr(src), - checker.DstAddr(remoteIPv4Addr), - checker.IPv4HeaderLength(hdrLen), - checker.IPFullLength(uint16(hdrLen+len(data))), - checker.IPv4Options(expectOptions), - checker.IPPayload(data), - ) - }, - }, - { - name: "IPv4 with options and data across views", - protoFactory: ipv4.NewProtocol, - protoNum: ipv4.ProtocolNumber, - nicAddr: localIPv4Addr, - remoteAddr: remoteIPv4Addr, - pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { - ip := header.IPv4(make([]byte, header.IPv4MinimumSize+ipv4Options.Length())) - ip.Encode(&header.IPv4Fields{ - Protocol: transportProto, - TTL: ipv4.DefaultTTL, - SrcAddr: src, - DstAddr: remoteIPv4Addr, - Options: ipv4Options, - }) - vv := buffer.View(ip).ToVectorisedView() - vv.AppendView(data) - return vv - }, - checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) { - if src == header.IPv4Any { - src = localIPv4Addr - } - - netHdr := pkt.NetworkHeader() - - hdrLen := int(header.IPv4MinimumSize + ipv4Options.Length()) - if len(netHdr.View()) != hdrLen { - t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), hdrLen) - } - - checker.IPv4(t, stack.PayloadSince(netHdr), - checker.SrcAddr(src), - checker.DstAddr(remoteIPv4Addr), - checker.IPv4HeaderLength(hdrLen), - checker.IPFullLength(uint16(hdrLen+len(data))), - checker.IPv4Options(expectOptions), - checker.IPPayload(data), - ) - }, - }, - { - name: "IPv6", - protoFactory: ipv6.NewProtocol, - protoNum: ipv6.ProtocolNumber, - nicAddr: localIPv6Addr, - remoteAddr: remoteIPv6Addr, - pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { - totalLen := header.IPv6MinimumSize + len(data) - hdr := buffer.NewPrependable(totalLen) - if n := copy(hdr.Prepend(len(data)), data); n != len(data) { - t.Fatalf("copied %d bytes, expected %d bytes", n, len(data)) - } - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - TransportProtocol: transportProto, - HopLimit: ipv6.DefaultTTL, - SrcAddr: src, - DstAddr: remoteIPv6Addr, - }) - return hdr.View().ToVectorisedView() - }, - checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) { - if src == header.IPv6Any { - src = localIPv6Addr - } - - netHdr := pkt.NetworkHeader() - - if len(netHdr.View()) != header.IPv6MinimumSize { - t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), header.IPv6MinimumSize) - } - - checker.IPv6(t, stack.PayloadSince(netHdr), - checker.SrcAddr(src), - checker.DstAddr(remoteIPv6Addr), - checker.IPFullLength(uint16(header.IPv6MinimumSize+len(data))), - checker.IPPayload(data), - ) - }, - }, - { - name: "IPv6 with extension header", - protoFactory: ipv6.NewProtocol, - protoNum: ipv6.ProtocolNumber, - nicAddr: localIPv6Addr, - remoteAddr: remoteIPv6Addr, - pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { - totalLen := header.IPv6MinimumSize + len(ipv6FragmentExtHdr) + len(data) - hdr := buffer.NewPrependable(totalLen) - if n := copy(hdr.Prepend(len(data)), data); n != len(data) { - t.Fatalf("copied %d bytes, expected %d bytes", n, len(data)) - } - if n := copy(hdr.Prepend(len(ipv6FragmentExtHdr)), ipv6FragmentExtHdr); n != len(ipv6FragmentExtHdr) { - t.Fatalf("copied %d bytes, expected %d bytes", n, len(ipv6FragmentExtHdr)) - } - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - // NB: we're lying about transport protocol here to verify the raw - // fragment header bytes. - TransportProtocol: tcpip.TransportProtocolNumber(header.IPv6FragmentExtHdrIdentifier), - HopLimit: ipv6.DefaultTTL, - SrcAddr: src, - DstAddr: remoteIPv6Addr, - }) - return hdr.View().ToVectorisedView() - }, - checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) { - if src == header.IPv6Any { - src = localIPv6Addr - } - - netHdr := pkt.NetworkHeader() - - if want := header.IPv6MinimumSize + len(ipv6FragmentExtHdr); len(netHdr.View()) != want { - t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), want) - } - - checker.IPv6(t, stack.PayloadSince(netHdr), - checker.SrcAddr(src), - checker.DstAddr(remoteIPv6Addr), - checker.IPFullLength(uint16(header.IPv6MinimumSize+len(ipv6PayloadWithExtHdr))), - checker.IPPayload(ipv6PayloadWithExtHdr), - ) - }, - }, - { - name: "IPv6 minimum size", - protoFactory: ipv6.NewProtocol, - protoNum: ipv6.ProtocolNumber, - nicAddr: localIPv6Addr, - remoteAddr: remoteIPv6Addr, - pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { - ip := header.IPv6(make([]byte, header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - TransportProtocol: transportProto, - HopLimit: ipv6.DefaultTTL, - SrcAddr: src, - DstAddr: remoteIPv6Addr, - }) - return buffer.View(ip).ToVectorisedView() - }, - checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) { - if src == header.IPv6Any { - src = localIPv6Addr - } - - netHdr := pkt.NetworkHeader() - - if len(netHdr.View()) != header.IPv6MinimumSize { - t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), header.IPv6MinimumSize) - } - - checker.IPv6(t, stack.PayloadSince(netHdr), - checker.SrcAddr(src), - checker.DstAddr(remoteIPv6Addr), - checker.IPFullLength(header.IPv6MinimumSize), - checker.IPPayload(nil), - ) - }, - }, - { - name: "IPv6 too small", - protoFactory: ipv6.NewProtocol, - protoNum: ipv6.ProtocolNumber, - nicAddr: localIPv6Addr, - remoteAddr: remoteIPv6Addr, - pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { - ip := header.IPv6(make([]byte, header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - TransportProtocol: transportProto, - HopLimit: ipv6.DefaultTTL, - SrcAddr: src, - DstAddr: remoteIPv4Addr, - }) - return buffer.View(ip[:len(ip)-1]).ToVectorisedView() - }, - expectedErr: &tcpip.ErrMalformedHeader{}, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - subTests := []struct { - name string - srcAddr tcpip.Address - }{ - { - name: "unspecified source", - srcAddr: tcpip.Address(strings.Repeat("\x00", len(test.nicAddr))), - }, - { - name: "random source", - srcAddr: tcpip.Address(strings.Repeat("\xab", len(test.nicAddr))), - }, - } - - for _, subTest := range subTests { - t.Run(subTest.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{test.protoFactory}, - }) - e := channel.New(1, header.IPv6MinimumMTU, "") - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) - } - if err := s.AddAddress(nicID, test.protoNum, test.nicAddr); err != nil { - t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, test.protoNum, test.nicAddr, err) - } - - s.SetRouteTable([]tcpip.Route{{Destination: test.remoteAddr.WithPrefix().Subnet(), NIC: nicID}}) - - r, err := s.FindRoute(nicID, test.nicAddr, test.remoteAddr, test.protoNum, false /* multicastLoop */) - if err != nil { - t.Fatalf("s.FindRoute(%d, %s, %s, %d, false): %s", nicID, test.remoteAddr, test.nicAddr, test.protoNum, err) - } - defer r.Release() - - { - err := r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: test.pktGen(t, subTest.srcAddr), - })) - if diff := cmp.Diff(test.expectedErr, err); diff != "" { - t.Fatalf("unexpected error from r.WriteHeaderIncludedPacket(_), (-want, +got):\n%s", diff) - } - } - - if test.expectedErr != nil { - return - } - - pkt, ok := e.Read() - if !ok { - t.Fatal("expected a packet to be written") - } - test.checker(t, pkt.Pkt, subTest.srcAddr) - }) - } - }) - } -} - -// Test that the included data in an ICMP error packet conforms to the -// requirements of RFC 972, RFC 4443 section 2.4 and RFC 1812 Section 4.3.2.3 -func TestICMPInclusionSize(t *testing.T) { - const ( - replyHeaderLength4 = header.IPv4MinimumSize + header.IPv4MinimumSize + header.ICMPv4MinimumSize - replyHeaderLength6 = header.IPv6MinimumSize + header.IPv6MinimumSize + header.ICMPv6MinimumSize - targetSize4 = header.IPv4MinimumProcessableDatagramSize - targetSize6 = header.IPv6MinimumMTU - // A protocol number that will cause an error response. - reservedProtocol = 254 - ) - - // IPv4 function to create a IP packet and send it to the stack. - // The packet should generate an error response. We can do that by using an - // unknown transport protocol (254). - rxIPv4Bad := func(e *channel.Endpoint, src tcpip.Address, payload []byte) buffer.View { - totalLen := header.IPv4MinimumSize + len(payload) - hdr := buffer.NewPrependable(header.IPv4MinimumSize) - ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) - ip.Encode(&header.IPv4Fields{ - TotalLength: uint16(totalLen), - Protocol: reservedProtocol, - TTL: ipv4.DefaultTTL, - SrcAddr: src, - DstAddr: localIPv4Addr, - }) - ip.SetChecksum(^ip.CalculateChecksum()) - vv := hdr.View().ToVectorisedView() - vv.AppendView(buffer.View(payload)) - // Take a copy before InjectInbound takes ownership of vv - // as vv may be changed during the call. - v := vv.ToView() - e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: vv, - })) - return v - } - - // IPv6 function to create a packet and send it to the stack. - // The packet should be errant in a way that causes the stack to send an - // ICMP error response and have enough data to allow the testing of the - // inclusion of the errant packet. Use `unknown next header' to generate - // the error. - rxIPv6Bad := func(e *channel.Endpoint, src tcpip.Address, payload []byte) buffer.View { - hdr := buffer.NewPrependable(header.IPv6MinimumSize) - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(len(payload)), - TransportProtocol: reservedProtocol, - HopLimit: ipv6.DefaultTTL, - SrcAddr: src, - DstAddr: localIPv6Addr, - }) - vv := hdr.View().ToVectorisedView() - vv.AppendView(buffer.View(payload)) - // Take a copy before InjectInbound takes ownership of vv - // as vv may be changed during the call. - v := vv.ToView() - - e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: vv, - })) - return v - } - - v4Checker := func(t *testing.T, pkt *stack.PacketBuffer, payload buffer.View) { - // We already know the entire packet is the right size so we can use its - // length to calculate the right payload size to check. - expectedPayloadLength := pkt.Size() - header.IPv4MinimumSize - header.ICMPv4MinimumSize - checker.IPv4(t, stack.PayloadSince(pkt.NetworkHeader()), - checker.SrcAddr(localIPv4Addr), - checker.DstAddr(remoteIPv4Addr), - checker.IPv4HeaderLength(header.IPv4MinimumSize), - checker.IPFullLength(uint16(header.IPv4MinimumSize+header.ICMPv4MinimumSize+expectedPayloadLength)), - checker.ICMPv4( - checker.ICMPv4Checksum(), - checker.ICMPv4Type(header.ICMPv4DstUnreachable), - checker.ICMPv4Code(header.ICMPv4ProtoUnreachable), - checker.ICMPv4Payload(payload[:expectedPayloadLength]), - ), - ) - } - - v6Checker := func(t *testing.T, pkt *stack.PacketBuffer, payload buffer.View) { - // We already know the entire packet is the right size so we can use its - // length to calculate the right payload size to check. - expectedPayloadLength := pkt.Size() - header.IPv6MinimumSize - header.ICMPv6MinimumSize - checker.IPv6(t, stack.PayloadSince(pkt.NetworkHeader()), - checker.SrcAddr(localIPv6Addr), - checker.DstAddr(remoteIPv6Addr), - checker.IPFullLength(uint16(header.IPv6MinimumSize+header.ICMPv6MinimumSize+expectedPayloadLength)), - checker.ICMPv6( - checker.ICMPv6Type(header.ICMPv6ParamProblem), - checker.ICMPv6Code(header.ICMPv6UnknownHeader), - checker.ICMPv6Payload(payload[:expectedPayloadLength]), - ), - ) - } - tests := []struct { - name string - srcAddress tcpip.Address - injector func(*channel.Endpoint, tcpip.Address, []byte) buffer.View - checker func(*testing.T, *stack.PacketBuffer, buffer.View) - payloadLength int // Not including IP header. - linkMTU uint32 // Largest IP packet that the link can send as payload. - replyLength int // Total size of IP/ICMP packet expected back. - }{ - { - name: "IPv4 exact match", - srcAddress: remoteIPv4Addr, - injector: rxIPv4Bad, - checker: v4Checker, - payloadLength: targetSize4 - replyHeaderLength4, - linkMTU: targetSize4, - replyLength: targetSize4, - }, - { - name: "IPv4 larger MTU", - srcAddress: remoteIPv4Addr, - injector: rxIPv4Bad, - checker: v4Checker, - payloadLength: targetSize4, - linkMTU: targetSize4 + 1000, - replyLength: targetSize4, - }, - { - name: "IPv4 smaller MTU", - srcAddress: remoteIPv4Addr, - injector: rxIPv4Bad, - checker: v4Checker, - payloadLength: targetSize4, - linkMTU: targetSize4 - 50, - replyLength: targetSize4 - 50, - }, - { - name: "IPv4 payload exceeds", - srcAddress: remoteIPv4Addr, - injector: rxIPv4Bad, - checker: v4Checker, - payloadLength: targetSize4 + 10, - linkMTU: targetSize4, - replyLength: targetSize4, - }, - { - name: "IPv4 1 byte less", - srcAddress: remoteIPv4Addr, - injector: rxIPv4Bad, - checker: v4Checker, - payloadLength: targetSize4 - replyHeaderLength4 - 1, - linkMTU: targetSize4, - replyLength: targetSize4 - 1, - }, - { - name: "IPv4 No payload", - srcAddress: remoteIPv4Addr, - injector: rxIPv4Bad, - checker: v4Checker, - payloadLength: 0, - linkMTU: targetSize4, - replyLength: replyHeaderLength4, - }, - { - name: "IPv6 exact match", - srcAddress: remoteIPv6Addr, - injector: rxIPv6Bad, - checker: v6Checker, - payloadLength: targetSize6 - replyHeaderLength6, - linkMTU: targetSize6, - replyLength: targetSize6, - }, - { - name: "IPv6 larger MTU", - srcAddress: remoteIPv6Addr, - injector: rxIPv6Bad, - checker: v6Checker, - payloadLength: targetSize6, - linkMTU: targetSize6 + 400, - replyLength: targetSize6, - }, - // NB. No "smaller MTU" test here as less than 1280 is not permitted - // in IPv6. - { - name: "IPv6 payload exceeds", - srcAddress: remoteIPv6Addr, - injector: rxIPv6Bad, - checker: v6Checker, - payloadLength: targetSize6, - linkMTU: targetSize6, - replyLength: targetSize6, - }, - { - name: "IPv6 1 byte less", - srcAddress: remoteIPv6Addr, - injector: rxIPv6Bad, - checker: v6Checker, - payloadLength: targetSize6 - replyHeaderLength6 - 1, - linkMTU: targetSize6, - replyLength: targetSize6 - 1, - }, - { - name: "IPv6 no payload", - srcAddress: remoteIPv6Addr, - injector: rxIPv6Bad, - checker: v6Checker, - payloadLength: 0, - linkMTU: targetSize6, - replyLength: replyHeaderLength6, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s, e := buildDummyStackWithLinkEndpoint(t, test.linkMTU) - // Allocate and initialize the payload view. - payload := buffer.NewView(test.payloadLength) - for i := 0; i < len(payload); i++ { - payload[i] = uint8(i) - } - // Default routes for IPv4&6 so ICMP can find a route to the remote - // node when attempting to send the ICMP error Reply. - s.SetRouteTable([]tcpip.Route{ - { - Destination: header.IPv4EmptySubnet, - NIC: nicID, - }, - { - Destination: header.IPv6EmptySubnet, - NIC: nicID, - }, - }) - v := test.injector(e, test.srcAddress, payload) - pkt, ok := e.Read() - if !ok { - t.Fatal("expected a packet to be written") - } - if got, want := pkt.Pkt.Size(), test.replyLength; got != want { - t.Fatalf("got %d bytes of icmp error packet, want %d", got, want) - } - test.checker(t, pkt.Pkt, v) - }) - } -} - -func TestJoinLeaveAllRoutersGroup(t *testing.T) { - const nicID = 1 - - tests := []struct { - name string - netProto tcpip.NetworkProtocolNumber - protoFactory stack.NetworkProtocolFactory - allRoutersAddr tcpip.Address - }{ - { - name: "IPv4", - netProto: ipv4.ProtocolNumber, - protoFactory: ipv4.NewProtocol, - allRoutersAddr: header.IPv4AllRoutersGroup, - }, - { - name: "IPv6 Interface Local", - netProto: ipv6.ProtocolNumber, - protoFactory: ipv6.NewProtocol, - allRoutersAddr: header.IPv6AllRoutersInterfaceLocalMulticastAddress, - }, - { - name: "IPv6 Link Local", - netProto: ipv6.ProtocolNumber, - protoFactory: ipv6.NewProtocol, - allRoutersAddr: header.IPv6AllRoutersLinkLocalMulticastAddress, - }, - { - name: "IPv6 Site Local", - netProto: ipv6.ProtocolNumber, - protoFactory: ipv6.NewProtocol, - allRoutersAddr: header.IPv6AllRoutersSiteLocalMulticastAddress, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - for _, nicDisabled := range [...]bool{true, false} { - t.Run(fmt.Sprintf("NIC Disabled = %t", nicDisabled), func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol}, - }) - opts := stack.NICOptions{Disabled: nicDisabled} - if err := s.CreateNICWithOptions(nicID, channel.New(0, 0, ""), opts); err != nil { - t.Fatalf("CreateNICWithOptions(%d, _, %#v) = %s", nicID, opts, err) - } - - if got, err := s.IsInGroup(nicID, test.allRoutersAddr); err != nil { - t.Fatalf("s.IsInGroup(%d, %s): %s", nicID, test.allRoutersAddr, err) - } else if got { - t.Fatalf("got s.IsInGroup(%d, %s) = true, want = false", nicID, test.allRoutersAddr) - } - - if err := s.SetForwardingDefaultAndAllNICs(test.netProto, true); err != nil { - t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", test.netProto, err) - } - if got, err := s.IsInGroup(nicID, test.allRoutersAddr); err != nil { - t.Fatalf("s.IsInGroup(%d, %s): %s", nicID, test.allRoutersAddr, err) - } else if !got { - t.Fatalf("got s.IsInGroup(%d, %s) = false, want = true", nicID, test.allRoutersAddr) - } - - if err := s.SetForwardingDefaultAndAllNICs(test.netProto, false); err != nil { - t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, false): %s", test.netProto, err) - } - if got, err := s.IsInGroup(nicID, test.allRoutersAddr); err != nil { - t.Fatalf("s.IsInGroup(%d, %s): %s", nicID, test.allRoutersAddr, err) - } else if got { - t.Fatalf("got s.IsInGroup(%d, %s) = true, want = false", nicID, test.allRoutersAddr) - } - }) - } - }) - } -} diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD deleted file mode 100644 index 2257f728e..000000000 --- a/pkg/tcpip/network/ipv4/BUILD +++ /dev/null @@ -1,67 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "ipv4", - srcs = [ - "icmp.go", - "igmp.go", - "ipv4.go", - "stats.go", - ], - visibility = ["//visibility:public"], - deps = [ - "//pkg/sync", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/header/parse", - "//pkg/tcpip/network/hash", - "//pkg/tcpip/network/internal/fragmentation", - "//pkg/tcpip/network/internal/ip", - "//pkg/tcpip/stack", - ], -) - -go_test( - name = "ipv4_test", - size = "small", - srcs = [ - "igmp_test.go", - "ipv4_test.go", - ], - deps = [ - "//pkg/sync", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/checker", - "//pkg/tcpip/faketime", - "//pkg/tcpip/header", - "//pkg/tcpip/link/channel", - "//pkg/tcpip/link/sniffer", - "//pkg/tcpip/network/arp", - "//pkg/tcpip/network/internal/testutil", - "//pkg/tcpip/network/ipv4", - "//pkg/tcpip/stack", - "//pkg/tcpip/testutil", - "//pkg/tcpip/transport/icmp", - "//pkg/tcpip/transport/raw", - "//pkg/tcpip/transport/tcp", - "//pkg/tcpip/transport/udp", - "//pkg/waiter", - "@com_github_google_go_cmp//cmp:go_default_library", - ], -) - -go_test( - name = "stats_test", - size = "small", - srcs = ["stats_test.go"], - library = ":ipv4", - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/stack", - "//pkg/tcpip/testutil", - ], -) diff --git a/pkg/tcpip/network/ipv4/igmp_test.go b/pkg/tcpip/network/ipv4/igmp_test.go deleted file mode 100644 index 4bd6f462e..000000000 --- a/pkg/tcpip/network/ipv4/igmp_test.go +++ /dev/null @@ -1,387 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ipv4_test - -import ( - "testing" - "time" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/checker" - "gvisor.dev/gvisor/pkg/tcpip/faketime" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/testutil" -) - -const ( - linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") - nicID = 1 - defaultTTL = 1 - defaultPrefixLength = 24 -) - -var ( - stackAddr = testutil.MustParse4("10.0.0.1") - remoteAddr = testutil.MustParse4("10.0.0.2") - multicastAddr = testutil.MustParse4("224.0.0.3") -) - -// validateIgmpPacket checks that a passed PacketInfo is an IPv4 IGMP packet -// sent to the provided address with the passed fields set. Raises a t.Error if -// any field does not match. -func validateIgmpPacket(t *testing.T, p channel.PacketInfo, igmpType header.IGMPType, maxRespTime byte, srcAddr, dstAddr, groupAddress tcpip.Address) { - t.Helper() - - payload := header.IPv4(stack.PayloadSince(p.Pkt.NetworkHeader())) - checker.IPv4(t, payload, - checker.SrcAddr(srcAddr), - checker.DstAddr(dstAddr), - // TTL for an IGMP message must be 1 as per RFC 2236 section 2. - checker.TTL(1), - checker.IPv4RouterAlert(), - checker.IGMP( - checker.IGMPType(igmpType), - checker.IGMPMaxRespTime(header.DecisecondToDuration(maxRespTime)), - checker.IGMPGroupAddress(groupAddress), - ), - ) -} - -func createStack(t *testing.T, igmpEnabled bool) (*channel.Endpoint, *stack.Stack, *faketime.ManualClock) { - t.Helper() - - // Create an endpoint of queue size 1, since no more than 1 packets are ever - // queued in the tests in this file. - e := channel.New(1, 1280, linkAddr) - clock := faketime.NewManualClock() - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocolWithOptions(ipv4.Options{ - IGMP: ipv4.IGMPOptions{ - Enabled: igmpEnabled, - }, - })}, - Clock: clock, - }) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - return e, s, clock -} - -func createAndInjectIGMPPacket(e *channel.Endpoint, igmpType header.IGMPType, maxRespTime byte, ttl uint8, srcAddr, dstAddr, groupAddress tcpip.Address, hasRouterAlertOption bool) { - var options header.IPv4OptionsSerializer - if hasRouterAlertOption { - options = header.IPv4OptionsSerializer{ - &header.IPv4SerializableRouterAlertOption{}, - } - } - buf := buffer.NewView(header.IPv4MinimumSize + int(options.Length()) + header.IGMPQueryMinimumSize) - - ip := header.IPv4(buf) - ip.Encode(&header.IPv4Fields{ - TotalLength: uint16(len(buf)), - TTL: ttl, - Protocol: uint8(header.IGMPProtocolNumber), - SrcAddr: srcAddr, - DstAddr: dstAddr, - Options: options, - }) - ip.SetChecksum(^ip.CalculateChecksum()) - - igmp := header.IGMP(ip.Payload()) - igmp.SetType(igmpType) - igmp.SetMaxRespTime(maxRespTime) - igmp.SetGroupAddress(groupAddress) - igmp.SetChecksum(header.IGMPCalculateChecksum(igmp)) - - e.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) -} - -// TestIGMPV1Present tests the node's ability to fallback to V1 when a V1 -// router is detected. V1 present status is expected to be reset when the NIC -// cycles. -func TestIGMPV1Present(t *testing.T) { - e, s, clock := createStack(t, true) - addr := tcpip.AddressWithPrefix{Address: stackAddr, PrefixLen: defaultPrefixLength} - if err := s.AddAddressWithPrefix(nicID, ipv4.ProtocolNumber, addr); err != nil { - t.Fatalf("AddAddressWithPrefix(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, addr, err) - } - - if err := s.JoinGroup(ipv4.ProtocolNumber, nicID, multicastAddr); err != nil { - t.Fatalf("JoinGroup(ipv4, nic, %s) = %s", multicastAddr, err) - } - - // This NIC will send an IGMPv2 report immediately, before this test can get - // the IGMPv1 General Membership Query in. - { - p, ok := e.Read() - if !ok { - t.Fatal("unable to Read IGMP packet, expected V2MembershipReport") - } - if got := s.Stats().IGMP.PacketsSent.V2MembershipReport.Value(); got != 1 { - t.Fatalf("got V2MembershipReport messages sent = %d, want = 1", got) - } - validateIgmpPacket(t, p, header.IGMPv2MembershipReport, 0, stackAddr, multicastAddr, multicastAddr) - } - if t.Failed() { - t.FailNow() - } - - // Inject an IGMPv1 General Membership Query which is identical to a standard - // membership query except the Max Response Time is set to 0, which will tell - // the stack that this is a router using IGMPv1. Send it to the all systems - // group which is the only group this host belongs to. - createAndInjectIGMPPacket(e, header.IGMPMembershipQuery, 0, defaultTTL, remoteAddr, stackAddr, header.IPv4AllSystems, true /* hasRouterAlertOption */) - if got := s.Stats().IGMP.PacketsReceived.MembershipQuery.Value(); got != 1 { - t.Fatalf("got Membership Queries received = %d, want = 1", got) - } - - // Before advancing the clock, verify that this host has not sent a - // V1MembershipReport yet. - if got := s.Stats().IGMP.PacketsSent.V1MembershipReport.Value(); got != 0 { - t.Fatalf("got V1MembershipReport messages sent = %d, want = 0", got) - } - - // Verify the solicited Membership Report is sent. Now that this NIC has seen - // an IGMPv1 query, it should send an IGMPv1 Membership Report. - if p, ok := e.Read(); ok { - t.Fatalf("sent unexpected packet, expected V1MembershipReport only after advancing the clock = %+v", p.Pkt) - } - clock.Advance(ipv4.UnsolicitedReportIntervalMax) - { - p, ok := e.Read() - if !ok { - t.Fatal("unable to Read IGMP packet, expected V1MembershipReport") - } - if got := s.Stats().IGMP.PacketsSent.V1MembershipReport.Value(); got != 1 { - t.Fatalf("got V1MembershipReport messages sent = %d, want = 1", got) - } - validateIgmpPacket(t, p, header.IGMPv1MembershipReport, 0, stackAddr, multicastAddr, multicastAddr) - } - - // Cycling the interface should reset the V1 present flag. - if err := s.DisableNIC(nicID); err != nil { - t.Fatalf("s.DisableNIC(%d): %s", nicID, err) - } - if err := s.EnableNIC(nicID); err != nil { - t.Fatalf("s.EnableNIC(%d): %s", nicID, err) - } - { - p, ok := e.Read() - if !ok { - t.Fatal("unable to Read IGMP packet, expected V2MembershipReport") - } - if got := s.Stats().IGMP.PacketsSent.V2MembershipReport.Value(); got != 2 { - t.Fatalf("got V2MembershipReport messages sent = %d, want = 2", got) - } - validateIgmpPacket(t, p, header.IGMPv2MembershipReport, 0, stackAddr, multicastAddr, multicastAddr) - } -} - -func TestSendQueuedIGMPReports(t *testing.T) { - e, s, clock := createStack(t, true) - - // Joining a group without an assigned address should queue IGMP packets; none - // should be sent without an assigned address. - if err := s.JoinGroup(ipv4.ProtocolNumber, nicID, multicastAddr); err != nil { - t.Fatalf("JoinGroup(%d, %d, %s): %s", ipv4.ProtocolNumber, nicID, multicastAddr, err) - } - reportStat := s.Stats().IGMP.PacketsSent.V2MembershipReport - if got := reportStat.Value(); got != 0 { - t.Errorf("got reportStat.Value() = %d, want = 0", got) - } - clock.Advance(time.Hour) - if p, ok := e.Read(); ok { - t.Fatalf("got unexpected packet = %#v", p) - } - - // The initial set of IGMP reports that were queued should be sent once an - // address is assigned. - if err := s.AddAddress(nicID, ipv4.ProtocolNumber, stackAddr); err != nil { - t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, stackAddr, err) - } - if got := reportStat.Value(); got != 1 { - t.Errorf("got reportStat.Value() = %d, want = 1", got) - } - if p, ok := e.Read(); !ok { - t.Error("expected to send an IGMP membership report") - } else { - validateIgmpPacket(t, p, header.IGMPv2MembershipReport, 0, stackAddr, multicastAddr, multicastAddr) - } - if t.Failed() { - t.FailNow() - } - clock.Advance(ipv4.UnsolicitedReportIntervalMax) - if got := reportStat.Value(); got != 2 { - t.Errorf("got reportStat.Value() = %d, want = 2", got) - } - if p, ok := e.Read(); !ok { - t.Error("expected to send an IGMP membership report") - } else { - validateIgmpPacket(t, p, header.IGMPv2MembershipReport, 0, stackAddr, multicastAddr, multicastAddr) - } - if t.Failed() { - t.FailNow() - } - - // Should have no more packets to send after the initial set of unsolicited - // reports. - clock.Advance(time.Hour) - if p, ok := e.Read(); ok { - t.Fatalf("got unexpected packet = %#v", p) - } -} - -func TestIGMPPacketValidation(t *testing.T) { - tests := []struct { - name string - messageType header.IGMPType - stackAddresses []tcpip.AddressWithPrefix - srcAddr tcpip.Address - includeRouterAlertOption bool - ttl uint8 - expectValidIGMP bool - getMessageTypeStatValue func(tcpip.Stats) uint64 - }{ - { - name: "valid", - messageType: header.IGMPLeaveGroup, - includeRouterAlertOption: true, - stackAddresses: []tcpip.AddressWithPrefix{{Address: stackAddr, PrefixLen: 24}}, - srcAddr: remoteAddr, - ttl: 1, - expectValidIGMP: true, - getMessageTypeStatValue: func(stats tcpip.Stats) uint64 { return stats.IGMP.PacketsReceived.LeaveGroup.Value() }, - }, - { - name: "bad ttl", - messageType: header.IGMPv1MembershipReport, - includeRouterAlertOption: true, - stackAddresses: []tcpip.AddressWithPrefix{{Address: stackAddr, PrefixLen: 24}}, - srcAddr: remoteAddr, - ttl: 2, - expectValidIGMP: false, - getMessageTypeStatValue: func(stats tcpip.Stats) uint64 { return stats.IGMP.PacketsReceived.V1MembershipReport.Value() }, - }, - { - name: "missing router alert ip option", - messageType: header.IGMPv2MembershipReport, - includeRouterAlertOption: false, - stackAddresses: []tcpip.AddressWithPrefix{{Address: stackAddr, PrefixLen: 24}}, - srcAddr: remoteAddr, - ttl: 1, - expectValidIGMP: false, - getMessageTypeStatValue: func(stats tcpip.Stats) uint64 { return stats.IGMP.PacketsReceived.V2MembershipReport.Value() }, - }, - { - name: "igmp leave group and src ip does not belong to nic subnet", - messageType: header.IGMPLeaveGroup, - includeRouterAlertOption: true, - stackAddresses: []tcpip.AddressWithPrefix{{Address: stackAddr, PrefixLen: 24}}, - srcAddr: testutil.MustParse4("10.0.1.2"), - ttl: 1, - expectValidIGMP: false, - getMessageTypeStatValue: func(stats tcpip.Stats) uint64 { return stats.IGMP.PacketsReceived.LeaveGroup.Value() }, - }, - { - name: "igmp query and src ip does not belong to nic subnet", - messageType: header.IGMPMembershipQuery, - includeRouterAlertOption: true, - stackAddresses: []tcpip.AddressWithPrefix{{Address: stackAddr, PrefixLen: 24}}, - srcAddr: testutil.MustParse4("10.0.1.2"), - ttl: 1, - expectValidIGMP: true, - getMessageTypeStatValue: func(stats tcpip.Stats) uint64 { return stats.IGMP.PacketsReceived.MembershipQuery.Value() }, - }, - { - name: "igmp report v1 and src ip does not belong to nic subnet", - messageType: header.IGMPv1MembershipReport, - includeRouterAlertOption: true, - stackAddresses: []tcpip.AddressWithPrefix{{Address: stackAddr, PrefixLen: 24}}, - srcAddr: testutil.MustParse4("10.0.1.2"), - ttl: 1, - expectValidIGMP: false, - getMessageTypeStatValue: func(stats tcpip.Stats) uint64 { return stats.IGMP.PacketsReceived.V1MembershipReport.Value() }, - }, - { - name: "igmp report v2 and src ip does not belong to nic subnet", - messageType: header.IGMPv2MembershipReport, - includeRouterAlertOption: true, - stackAddresses: []tcpip.AddressWithPrefix{{Address: stackAddr, PrefixLen: 24}}, - srcAddr: testutil.MustParse4("10.0.1.2"), - ttl: 1, - expectValidIGMP: false, - getMessageTypeStatValue: func(stats tcpip.Stats) uint64 { return stats.IGMP.PacketsReceived.V2MembershipReport.Value() }, - }, - { - name: "src ip belongs to the subnet of the nic's second address", - messageType: header.IGMPv2MembershipReport, - includeRouterAlertOption: true, - stackAddresses: []tcpip.AddressWithPrefix{ - {Address: testutil.MustParse4("10.0.15.1"), PrefixLen: 24}, - {Address: stackAddr, PrefixLen: 24}, - }, - srcAddr: remoteAddr, - ttl: 1, - expectValidIGMP: true, - getMessageTypeStatValue: func(stats tcpip.Stats) uint64 { return stats.IGMP.PacketsReceived.V2MembershipReport.Value() }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - e, s, _ := createStack(t, true) - for _, address := range test.stackAddresses { - if err := s.AddAddressWithPrefix(nicID, ipv4.ProtocolNumber, address); err != nil { - t.Fatalf("AddAddressWithPrefix(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, address, err) - } - } - stats := s.Stats() - // Verify that every relevant stats is zero'd before we send a packet. - if got := test.getMessageTypeStatValue(s.Stats()); got != 0 { - t.Errorf("got test.getMessageTypeStatValue(s.Stats()) = %d, want = 0", got) - } - if got := stats.IGMP.PacketsReceived.Invalid.Value(); got != 0 { - t.Errorf("got stats.IGMP.PacketsReceived.Invalid.Value() = %d, want = 0", got) - } - if got := stats.IP.PacketsDelivered.Value(); got != 0 { - t.Fatalf("got stats.IP.PacketsDelivered.Value() = %d, want = 0", got) - } - createAndInjectIGMPPacket(e, test.messageType, 0, test.ttl, test.srcAddr, header.IPv4AllSystems, header.IPv4AllSystems, test.includeRouterAlertOption) - // We always expect the packet to pass IP validation. - if got := stats.IP.PacketsDelivered.Value(); got != 1 { - t.Fatalf("got stats.IP.PacketsDelivered.Value() = %d, want = 1", got) - } - // Even when the IGMP-specific validation checks fail, we expect the - // corresponding IGMP counter to be incremented. - if got := test.getMessageTypeStatValue(s.Stats()); got != 1 { - t.Errorf("got test.getMessageTypeStatValue(s.Stats()) = %d, want = 1", got) - } - var expectedInvalidCount uint64 - if !test.expectValidIGMP { - expectedInvalidCount = 1 - } - if got := stats.IGMP.PacketsReceived.Invalid.Value(); got != expectedInvalidCount { - t.Errorf("got stats.IGMP.PacketsReceived.Invalid.Value() = %d, want = %d", got, expectedInvalidCount) - } - }) - } -} diff --git a/pkg/tcpip/network/ipv4/ipv4_state_autogen.go b/pkg/tcpip/network/ipv4/ipv4_state_autogen.go new file mode 100644 index 000000000..19b672251 --- /dev/null +++ b/pkg/tcpip/network/ipv4/ipv4_state_autogen.go @@ -0,0 +1,113 @@ +// automatically generated by stateify. + +package ipv4 + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (i *icmpv4DestinationUnreachableSockError) StateTypeName() string { + return "pkg/tcpip/network/ipv4.icmpv4DestinationUnreachableSockError" +} + +func (i *icmpv4DestinationUnreachableSockError) StateFields() []string { + return []string{} +} + +func (i *icmpv4DestinationUnreachableSockError) beforeSave() {} + +// +checklocksignore +func (i *icmpv4DestinationUnreachableSockError) StateSave(stateSinkObject state.Sink) { + i.beforeSave() +} + +func (i *icmpv4DestinationUnreachableSockError) afterLoad() {} + +// +checklocksignore +func (i *icmpv4DestinationUnreachableSockError) StateLoad(stateSourceObject state.Source) { +} + +func (i *icmpv4DestinationHostUnreachableSockError) StateTypeName() string { + return "pkg/tcpip/network/ipv4.icmpv4DestinationHostUnreachableSockError" +} + +func (i *icmpv4DestinationHostUnreachableSockError) StateFields() []string { + return []string{ + "icmpv4DestinationUnreachableSockError", + } +} + +func (i *icmpv4DestinationHostUnreachableSockError) beforeSave() {} + +// +checklocksignore +func (i *icmpv4DestinationHostUnreachableSockError) StateSave(stateSinkObject state.Sink) { + i.beforeSave() + stateSinkObject.Save(0, &i.icmpv4DestinationUnreachableSockError) +} + +func (i *icmpv4DestinationHostUnreachableSockError) afterLoad() {} + +// +checklocksignore +func (i *icmpv4DestinationHostUnreachableSockError) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &i.icmpv4DestinationUnreachableSockError) +} + +func (i *icmpv4DestinationPortUnreachableSockError) StateTypeName() string { + return "pkg/tcpip/network/ipv4.icmpv4DestinationPortUnreachableSockError" +} + +func (i *icmpv4DestinationPortUnreachableSockError) StateFields() []string { + return []string{ + "icmpv4DestinationUnreachableSockError", + } +} + +func (i *icmpv4DestinationPortUnreachableSockError) beforeSave() {} + +// +checklocksignore +func (i *icmpv4DestinationPortUnreachableSockError) StateSave(stateSinkObject state.Sink) { + i.beforeSave() + stateSinkObject.Save(0, &i.icmpv4DestinationUnreachableSockError) +} + +func (i *icmpv4DestinationPortUnreachableSockError) afterLoad() {} + +// +checklocksignore +func (i *icmpv4DestinationPortUnreachableSockError) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &i.icmpv4DestinationUnreachableSockError) +} + +func (e *icmpv4FragmentationNeededSockError) StateTypeName() string { + return "pkg/tcpip/network/ipv4.icmpv4FragmentationNeededSockError" +} + +func (e *icmpv4FragmentationNeededSockError) StateFields() []string { + return []string{ + "icmpv4DestinationUnreachableSockError", + "mtu", + } +} + +func (e *icmpv4FragmentationNeededSockError) beforeSave() {} + +// +checklocksignore +func (e *icmpv4FragmentationNeededSockError) StateSave(stateSinkObject state.Sink) { + e.beforeSave() + stateSinkObject.Save(0, &e.icmpv4DestinationUnreachableSockError) + stateSinkObject.Save(1, &e.mtu) +} + +func (e *icmpv4FragmentationNeededSockError) afterLoad() {} + +// +checklocksignore +func (e *icmpv4FragmentationNeededSockError) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &e.icmpv4DestinationUnreachableSockError) + stateSourceObject.Load(1, &e.mtu) +} + +func init() { + state.Register((*icmpv4DestinationUnreachableSockError)(nil)) + state.Register((*icmpv4DestinationHostUnreachableSockError)(nil)) + state.Register((*icmpv4DestinationPortUnreachableSockError)(nil)) + state.Register((*icmpv4FragmentationNeededSockError)(nil)) +} diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go deleted file mode 100644 index 73407be67..000000000 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ /dev/null @@ -1,3351 +0,0 @@ -// Copyright 2021 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ipv4_test - -import ( - "bytes" - "encoding/hex" - "fmt" - "io/ioutil" - "math" - "net" - "testing" - "time" - - "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/checker" - "gvisor.dev/gvisor/pkg/tcpip/faketime" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" - "gvisor.dev/gvisor/pkg/tcpip/network/arp" - iptestutil "gvisor.dev/gvisor/pkg/tcpip/network/internal/testutil" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/testutil" - "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" - "gvisor.dev/gvisor/pkg/tcpip/transport/raw" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" - "gvisor.dev/gvisor/pkg/tcpip/transport/udp" - "gvisor.dev/gvisor/pkg/waiter" -) - -const ( - extraHeaderReserve = 50 - defaultMTU = 65536 -) - -func TestExcludeBroadcast(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, - }) - - ep := stack.LinkEndpoint(channel.New(256, defaultMTU, "")) - if testing.Verbose() { - ep = sniffer.New(ep) - } - if err := s.CreateNIC(1, ep); err != nil { - t.Fatalf("CreateNIC failed: %v", err) - } - - s.SetRouteTable([]tcpip.Route{{ - Destination: header.IPv4EmptySubnet, - NIC: 1, - }}) - - randomAddr := tcpip.FullAddress{NIC: 1, Addr: "\x0a\x00\x00\x01", Port: 53} - - var wq waiter.Queue - t.Run("WithoutPrimaryAddress", func(t *testing.T) { - ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq) - if err != nil { - t.Fatal(err) - } - defer ep.Close() - - // Cannot connect using a broadcast address as the source. - { - err := ep.Connect(randomAddr) - if _, ok := err.(*tcpip.ErrNoRoute); !ok { - t.Errorf("got ep.Connect(...) = %v, want = %v", err, &tcpip.ErrNoRoute{}) - } - } - - // However, we can bind to a broadcast address to listen. - if err := ep.Bind(tcpip.FullAddress{Addr: header.IPv4Broadcast, Port: 53, NIC: 1}); err != nil { - t.Errorf("Bind failed: %v", err) - } - }) - - t.Run("WithPrimaryAddress", func(t *testing.T) { - ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq) - if err != nil { - t.Fatal(err) - } - defer ep.Close() - - // Add a valid primary endpoint address, now we can connect. - if err := s.AddAddress(1, ipv4.ProtocolNumber, "\x0a\x00\x00\x02"); err != nil { - t.Fatalf("AddAddress failed: %v", err) - } - if err := ep.Connect(randomAddr); err != nil { - t.Errorf("Connect failed: %v", err) - } - }) -} - -func TestForwarding(t *testing.T) { - const ( - incomingNICID = 1 - outgoingNICID = 2 - randomSequence = 123 - randomIdent = 42 - randomTimeOffset = 0x10203040 - ) - - incomingIPv4Addr := tcpip.AddressWithPrefix{ - Address: tcpip.Address(net.ParseIP("10.0.0.1").To4()), - PrefixLen: 8, - } - outgoingIPv4Addr := tcpip.AddressWithPrefix{ - Address: tcpip.Address(net.ParseIP("11.0.0.1").To4()), - PrefixLen: 8, - } - outgoingLinkAddr := tcpip.LinkAddress("\x02\x03\x03\x04\x05\x06") - remoteIPv4Addr1 := testutil.MustParse4("10.0.0.2") - remoteIPv4Addr2 := testutil.MustParse4("11.0.0.2") - unreachableIPv4Addr := testutil.MustParse4("12.0.0.2") - multicastIPv4Addr := testutil.MustParse4("225.0.0.0") - linkLocalIPv4Addr := testutil.MustParse4("169.254.0.0") - - tests := []struct { - name string - TTL uint8 - sourceAddr tcpip.Address - destAddr tcpip.Address - expectErrorICMP bool - ipFlags uint8 - mtu uint32 - payloadLength int - options header.IPv4Options - forwardedOptions header.IPv4Options - icmpType header.ICMPv4Type - icmpCode header.ICMPv4Code - expectPacketUnrouteableError bool - expectLinkLocalSourceError bool - expectLinkLocalDestError bool - expectPacketForwarded bool - expectedFragmentsForwarded []fragmentInfo - }{ - { - name: "TTL of zero", - TTL: 0, - sourceAddr: remoteIPv4Addr1, - destAddr: remoteIPv4Addr2, - expectErrorICMP: true, - icmpType: header.ICMPv4TimeExceeded, - icmpCode: header.ICMPv4TTLExceeded, - mtu: ipv4.MaxTotalSize, - }, - { - name: "TTL of one", - TTL: 1, - sourceAddr: remoteIPv4Addr1, - destAddr: remoteIPv4Addr2, - expectPacketForwarded: true, - mtu: ipv4.MaxTotalSize, - }, - { - name: "TTL of two", - TTL: 2, - sourceAddr: remoteIPv4Addr1, - destAddr: remoteIPv4Addr2, - expectPacketForwarded: true, - mtu: ipv4.MaxTotalSize, - }, - { - name: "Max TTL", - TTL: math.MaxUint8, - sourceAddr: remoteIPv4Addr1, - destAddr: remoteIPv4Addr2, - expectPacketForwarded: true, - mtu: ipv4.MaxTotalSize, - }, - { - name: "four EOL options", - TTL: 2, - sourceAddr: remoteIPv4Addr1, - destAddr: remoteIPv4Addr2, - expectPacketForwarded: true, - mtu: ipv4.MaxTotalSize, - options: header.IPv4Options{0, 0, 0, 0}, - forwardedOptions: header.IPv4Options{0, 0, 0, 0}, - }, - { - name: "TS type 1 full", - TTL: 2, - sourceAddr: remoteIPv4Addr1, - destAddr: remoteIPv4Addr2, - mtu: ipv4.MaxTotalSize, - options: header.IPv4Options{ - 68, 12, 13, 0xF1, - 192, 168, 1, 12, - 1, 2, 3, 4, - }, - expectErrorICMP: true, - icmpType: header.ICMPv4ParamProblem, - icmpCode: header.ICMPv4UnusedCode, - }, - { - name: "TS type 0", - TTL: 2, - sourceAddr: remoteIPv4Addr1, - destAddr: remoteIPv4Addr2, - mtu: ipv4.MaxTotalSize, - options: header.IPv4Options{ - 68, 24, 21, 0x00, - 1, 2, 3, 4, - 5, 6, 7, 8, - 9, 10, 11, 12, - 13, 14, 15, 16, - 0, 0, 0, 0, - }, - forwardedOptions: header.IPv4Options{ - 68, 24, 25, 0x00, - 1, 2, 3, 4, - 5, 6, 7, 8, - 9, 10, 11, 12, - 13, 14, 15, 16, - 0x00, 0xad, 0x1c, 0x40, // time we expect from fakeclock - }, - expectPacketForwarded: true, - }, - { - name: "end of options list", - TTL: 2, - sourceAddr: remoteIPv4Addr1, - destAddr: remoteIPv4Addr2, - mtu: ipv4.MaxTotalSize, - options: header.IPv4Options{ - 68, 12, 13, 0x11, - 192, 168, 1, 12, - 1, 2, 3, 4, - 0, 10, 3, 99, // EOL followed by junk - 1, 2, 3, 4, - }, - forwardedOptions: header.IPv4Options{ - 68, 12, 13, 0x21, - 192, 168, 1, 12, - 1, 2, 3, 4, - 0, // End of Options hides following bytes. - 0, 0, 0, // 7 bytes unknown option removed. - 0, 0, 0, 0, - }, - expectPacketForwarded: true, - }, - { - name: "Network unreachable", - TTL: 2, - sourceAddr: remoteIPv4Addr1, - destAddr: unreachableIPv4Addr, - expectErrorICMP: true, - mtu: ipv4.MaxTotalSize, - icmpType: header.ICMPv4DstUnreachable, - icmpCode: header.ICMPv4NetUnreachable, - expectPacketUnrouteableError: true, - }, - { - name: "Multicast destination", - TTL: 2, - destAddr: multicastIPv4Addr, - expectPacketUnrouteableError: true, - }, - { - name: "Link local destination", - TTL: 2, - sourceAddr: remoteIPv4Addr1, - destAddr: linkLocalIPv4Addr, - expectLinkLocalDestError: true, - }, - { - name: "Link local source", - TTL: 2, - sourceAddr: linkLocalIPv4Addr, - destAddr: remoteIPv4Addr2, - expectLinkLocalSourceError: true, - }, - { - name: "Fragmentation needed and DF set", - TTL: 2, - sourceAddr: remoteIPv4Addr1, - destAddr: remoteIPv4Addr2, - ipFlags: header.IPv4FlagDontFragment, - // We've picked this MTU because it is: - // - // 1) Greater than the minimum MTU that IPv4 hosts are required to process - // (576 bytes). As per RFC 1812, Section 4.3.2.3: - // - // The ICMP datagram SHOULD contain as much of the original datagram as - // possible without the length of the ICMP datagram exceeding 576 bytes. - // - // Therefore, setting an MTU greater than 576 bytes ensures that we can fit a - // complete ICMP packet on the incoming endpoint (and make assertions about - // it). - // - // 2) Less than `ipv4.MaxTotalSize`, which lets us build an IPv4 packet whose - // size exceeds the MTU. - mtu: 1000, - payloadLength: 1004, - expectErrorICMP: true, - icmpType: header.ICMPv4DstUnreachable, - icmpCode: header.ICMPv4FragmentationNeeded, - }, - { - name: "Fragmentation needed and DF not set", - TTL: 2, - sourceAddr: remoteIPv4Addr1, - destAddr: remoteIPv4Addr2, - mtu: 1000, - payloadLength: 1004, - expectPacketForwarded: true, - // Combined, these fragments have length of 1012 octets, which is equal to - // the length of the payload (1004 octets), plus the length of the ICMP - // header (8 octets). - expectedFragmentsForwarded: []fragmentInfo{ - // The first fragment has a length of the greatest multiple of 8 which is - // less than or equal to to `mtu - header.IPv4MinimumSize`. - {offset: 0, payloadSize: uint16(976), more: true}, - // The next fragment holds the rest of the packet. - {offset: uint16(976), payloadSize: 36, more: false}, - }, - }, - } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - clock := faketime.NewManualClock() - - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4}, - Clock: clock, - }) - - // Advance the clock by some unimportant amount to make - // it give a more recognisable signature than 00,00,00,00. - clock.Advance(time.Millisecond * randomTimeOffset) - - // We expect at most a single packet in response to our ICMP Echo Request. - incomingEndpoint := channel.New(1, test.mtu, "") - if err := s.CreateNIC(incomingNICID, incomingEndpoint); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", incomingNICID, err) - } - incomingIPv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: incomingIPv4Addr} - if err := s.AddProtocolAddress(incomingNICID, incomingIPv4ProtoAddr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", incomingNICID, incomingIPv4ProtoAddr, err) - } - - expectedEmittedPacketCount := 1 - if len(test.expectedFragmentsForwarded) > expectedEmittedPacketCount { - expectedEmittedPacketCount = len(test.expectedFragmentsForwarded) - } - outgoingEndpoint := channel.New(expectedEmittedPacketCount, test.mtu, outgoingLinkAddr) - if err := s.CreateNIC(outgoingNICID, outgoingEndpoint); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", outgoingNICID, err) - } - outgoingIPv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: outgoingIPv4Addr} - if err := s.AddProtocolAddress(outgoingNICID, outgoingIPv4ProtoAddr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", outgoingNICID, outgoingIPv4ProtoAddr, err) - } - - s.SetRouteTable([]tcpip.Route{ - { - Destination: incomingIPv4Addr.Subnet(), - NIC: incomingNICID, - }, - { - Destination: outgoingIPv4Addr.Subnet(), - NIC: outgoingNICID, - }, - }) - - if err := s.SetForwardingDefaultAndAllNICs(header.IPv4ProtocolNumber, true); err != nil { - t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", header.IPv4ProtocolNumber, err) - } - - ipHeaderLength := header.IPv4MinimumSize + len(test.options) - if ipHeaderLength > header.IPv4MaximumHeaderSize { - t.Fatalf("got ipHeaderLength = %d, want <= %d ", ipHeaderLength, header.IPv4MaximumHeaderSize) - } - icmpHeaderLength := header.ICMPv4MinimumSize - totalLength := ipHeaderLength + icmpHeaderLength + test.payloadLength - hdr := buffer.NewPrependable(totalLength) - hdr.Prepend(test.payloadLength) - icmpH := header.ICMPv4(hdr.Prepend(icmpHeaderLength)) - icmpH.SetIdent(randomIdent) - icmpH.SetSequence(randomSequence) - icmpH.SetType(header.ICMPv4Echo) - icmpH.SetCode(header.ICMPv4UnusedCode) - icmpH.SetChecksum(0) - icmpH.SetChecksum(^header.Checksum(icmpH, 0)) - ip := header.IPv4(hdr.Prepend(ipHeaderLength)) - ip.Encode(&header.IPv4Fields{ - TotalLength: uint16(totalLength), - Protocol: uint8(header.ICMPv4ProtocolNumber), - TTL: test.TTL, - SrcAddr: test.sourceAddr, - DstAddr: test.destAddr, - Flags: test.ipFlags, - }) - if len(test.options) != 0 { - ip.SetHeaderLength(uint8(ipHeaderLength)) - // Copy options manually. We do not use Encode for options so we can - // verify malformed options with handcrafted payloads. - if want, got := copy(ip.Options(), test.options), len(test.options); want != got { - t.Fatalf("got copy(ip.Options(), test.options) = %d, want = %d", got, want) - } - } - ip.SetChecksum(0) - ip.SetChecksum(^ip.CalculateChecksum()) - requestPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - }) - requestPkt.NetworkProtocolNumber = header.IPv4ProtocolNumber - incomingEndpoint.InjectInbound(header.IPv4ProtocolNumber, requestPkt) - - reply, ok := incomingEndpoint.Read() - - if test.expectErrorICMP { - if !ok { - t.Fatalf("expected ICMP packet type %d through incoming NIC", test.icmpType) - } - - // We expect the ICMP packet to contain as much of the original packet as - // possible up to a limit of 576 bytes, split between payload, IP header, - // and ICMP header. - expectedICMPPayloadLength := func() int { - maxICMPPacketLength := header.IPv4MinimumProcessableDatagramSize - maxICMPPayloadLength := maxICMPPacketLength - icmpHeaderLength - ipHeaderLength - if len(hdr.View()) > maxICMPPayloadLength { - return maxICMPPayloadLength - } - return len(hdr.View()) - } - - checker.IPv4(t, stack.PayloadSince(reply.Pkt.NetworkHeader()), - checker.SrcAddr(incomingIPv4Addr.Address), - checker.DstAddr(test.sourceAddr), - checker.TTL(ipv4.DefaultTTL), - checker.ICMPv4( - checker.ICMPv4Checksum(), - checker.ICMPv4Type(test.icmpType), - checker.ICMPv4Code(test.icmpCode), - checker.ICMPv4Payload(hdr.View()[:expectedICMPPayloadLength()]), - ), - ) - } else if ok { - t.Fatalf("expected no ICMP packet through incoming NIC, instead found: %#v", reply) - } - - if test.expectPacketForwarded { - if len(test.expectedFragmentsForwarded) != 0 { - var fragmentedPackets []*stack.PacketBuffer - for i := 0; i < len(test.expectedFragmentsForwarded); i++ { - reply, ok = outgoingEndpoint.Read() - if !ok { - t.Fatal("expected ICMP Echo fragment through outgoing NIC") - } - fragmentedPackets = append(fragmentedPackets, reply.Pkt) - } - - // The forwarded packet's TTL will have been decremented. - ipHeader := header.IPv4(requestPkt.NetworkHeader().View()) - ipHeader.SetTTL(ipHeader.TTL() - 1) - - // Forwarded packets have available header bytes equalling the sum of the - // maximum IP header size and the maximum size allocated for link layer - // headers. In this case, no size is allocated for link layer headers. - expectedAvailableHeaderBytes := header.IPv4MaximumHeaderSize - if err := compareFragments(fragmentedPackets, requestPkt, test.mtu, test.expectedFragmentsForwarded, header.ICMPv4ProtocolNumber, true /* withIPHeader */, expectedAvailableHeaderBytes); err != nil { - t.Error(err) - } - } else { - reply, ok = outgoingEndpoint.Read() - if !ok { - t.Fatal("expected ICMP Echo packet through outgoing NIC") - } - - checker.IPv4(t, stack.PayloadSince(reply.Pkt.NetworkHeader()), - checker.SrcAddr(test.sourceAddr), - checker.DstAddr(test.destAddr), - checker.TTL(test.TTL-1), - checker.IPv4Options(test.forwardedOptions), - checker.ICMPv4( - checker.ICMPv4Checksum(), - checker.ICMPv4Type(header.ICMPv4Echo), - checker.ICMPv4Code(header.ICMPv4UnusedCode), - checker.ICMPv4Payload(nil), - ), - ) - } - } else { - if reply, ok = outgoingEndpoint.Read(); ok { - t.Fatalf("expected no ICMP Echo packet through outgoing NIC, instead found: %#v", reply) - } - } - boolToInt := func(val bool) uint64 { - if val { - return 1 - } - return 0 - } - - if got, want := s.Stats().IP.Forwarding.LinkLocalSource.Value(), boolToInt(test.expectLinkLocalSourceError); got != want { - t.Errorf("got s.Stats().IP.Forwarding.LinkLocalSource.Value() = %d, want = %d", got, want) - } - - if got, want := s.Stats().IP.Forwarding.LinkLocalDestination.Value(), boolToInt(test.expectLinkLocalDestError); got != want { - t.Errorf("got s.Stats().IP.Forwarding.LinkLocalDestination.Value() = %d, want = %d", got, want) - } - - if got, want := s.Stats().IP.MalformedPacketsReceived.Value(), boolToInt(test.icmpType == header.ICMPv4ParamProblem); got != want { - t.Errorf("got s.Stats().IP.MalformedPacketsReceived.Value() = %d, want = %d", got, want) - } - - if got, want := s.Stats().IP.Forwarding.ExhaustedTTL.Value(), boolToInt(test.TTL <= 0); got != want { - t.Errorf("got s.Stats().IP.Forwarding.ExhaustedTTL.Value() = %d, want = %d", got, want) - } - - if got, want := s.Stats().IP.Forwarding.Unrouteable.Value(), boolToInt(test.expectPacketUnrouteableError); got != want { - t.Errorf("got s.Stats().IP.Forwarding.Unrouteable.Value() = %d, want = %d", got, want) - } - - if got, want := s.Stats().IP.Forwarding.Errors.Value(), boolToInt(!test.expectPacketForwarded); got != want { - t.Errorf("got s.Stats().IP.Forwarding.Errors.Value() = %d, want = %d", got, want) - } - - if got, want := s.Stats().IP.Forwarding.PacketTooBig.Value(), boolToInt(test.icmpCode == header.ICMPv4FragmentationNeeded); got != want { - t.Errorf("got s.Stats().IP.Forwarding.PacketTooBig.Value() = %d, want = %d", got, want) - } - }) - } -} - -// TestIPv4Sanity sends IP/ICMP packets with various problems to the stack and -// checks the response. -func TestIPv4Sanity(t *testing.T) { - const ( - ttl = 255 - nicID = 1 - randomSequence = 123 - randomIdent = 42 - // In some cases Linux sets the error pointer to the start of the option - // (offset 0) instead of the actual wrong value, which is the length byte - // (offset 1). For compatibility we must do the same. Use this constant - // to indicate where this happens. - pointerOffsetForInvalidLength = 0 - randomTimeOffset = 0x10203040 - ) - var ( - ipv4Addr = tcpip.AddressWithPrefix{ - Address: tcpip.Address(net.ParseIP("192.168.1.58").To4()), - PrefixLen: 24, - } - remoteIPv4Addr = tcpip.Address(net.ParseIP("10.0.0.1").To4()) - ) - - tests := []struct { - name string - headerLength uint8 // value of 0 means "use correct size" - badHeaderChecksum bool - maxTotalLength uint16 - transportProtocol uint8 - TTL uint8 - options header.IPv4Options - replyOptions header.IPv4Options // reply should look like this - shouldFail bool - expectErrorICMP bool - ICMPType header.ICMPv4Type - ICMPCode header.ICMPv4Code - paramProblemPointer uint8 - }{ - { - name: "valid no options", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - }, - { - name: "bad header checksum", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - badHeaderChecksum: true, - shouldFail: true, - }, - // The TTL tests check that we are not rejecting an incoming packet - // with a zero or one TTL, which has been a point of confusion in the - // past as RFC 791 says: "If this field contains the value zero, then the - // datagram must be destroyed". However RFC 1122 section 3.2.1.7 clarifies - // for the case of the destination host, stating as follows. - // - // A host MUST NOT send a datagram with a Time-to-Live (TTL) - // value of zero. - // - // A host MUST NOT discard a datagram just because it was - // received with TTL less than 2. - { - name: "zero TTL", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: 0, - }, - { - name: "one TTL", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: 1, - }, - { - name: "End options", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{0, 0, 0, 0}, - replyOptions: header.IPv4Options{0, 0, 0, 0}, - }, - { - name: "NOP options", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{1, 1, 1, 1}, - replyOptions: header.IPv4Options{1, 1, 1, 1}, - }, - { - name: "NOP and End options", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{1, 1, 0, 0}, - replyOptions: header.IPv4Options{1, 1, 0, 0}, - }, - { - name: "bad header length", - headerLength: header.IPv4MinimumSize - 1, - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - shouldFail: true, - }, - { - name: "bad total length (0)", - maxTotalLength: 0, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - shouldFail: true, - }, - { - name: "bad total length (ip - 1)", - maxTotalLength: uint16(header.IPv4MinimumSize - 1), - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - shouldFail: true, - }, - { - name: "bad total length (ip + icmp - 1)", - maxTotalLength: uint16(header.IPv4MinimumSize + header.ICMPv4MinimumSize - 1), - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - shouldFail: true, - }, - { - name: "bad protocol", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: 99, - TTL: ttl, - shouldFail: true, - expectErrorICMP: true, - ICMPType: header.ICMPv4DstUnreachable, - ICMPCode: header.ICMPv4ProtoUnreachable, - }, - { - name: "timestamp option overflow", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{ - 68, 12, 13, 0x11, - 192, 168, 1, 12, - 1, 2, 3, 4, - }, - replyOptions: header.IPv4Options{ - 68, 12, 13, 0x21, - 192, 168, 1, 12, - 1, 2, 3, 4, - }, - }, - { - name: "timestamp option overflow full", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{ - 68, 12, 13, 0xF1, - // ^ Counter full (15/0xF) - 192, 168, 1, 12, - 1, 2, 3, 4, - }, - shouldFail: true, - expectErrorICMP: true, - ICMPType: header.ICMPv4ParamProblem, - ICMPCode: header.ICMPv4UnusedCode, - paramProblemPointer: header.IPv4MinimumSize + 3, - replyOptions: header.IPv4Options{}, - }, - { - name: "unknown option", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{10, 4, 9, 0}, - // ^^ - // The unknown option should be stripped out of the reply. - replyOptions: header.IPv4Options{}, - }, - { - name: "bad option - no length", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{ - 1, 1, 1, 68, - // ^-start of timestamp.. but no length.. - }, - shouldFail: true, - expectErrorICMP: true, - ICMPType: header.ICMPv4ParamProblem, - ICMPCode: header.ICMPv4UnusedCode, - paramProblemPointer: header.IPv4MinimumSize + 3, - }, - { - name: "bad option - length 0", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{ - 68, 0, 9, 0, - // ^ - 1, 2, 3, 4, - }, - shouldFail: true, - expectErrorICMP: true, - ICMPType: header.ICMPv4ParamProblem, - ICMPCode: header.ICMPv4UnusedCode, - paramProblemPointer: header.IPv4MinimumSize + pointerOffsetForInvalidLength, - }, - { - name: "bad option - length 1", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{ - 68, 1, 9, 0, - // ^ - 1, 2, 3, 4, - }, - shouldFail: true, - expectErrorICMP: true, - ICMPType: header.ICMPv4ParamProblem, - ICMPCode: header.ICMPv4UnusedCode, - paramProblemPointer: header.IPv4MinimumSize + pointerOffsetForInvalidLength, - }, - { - name: "bad option - length big", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{ - 68, 9, 9, 0, - // ^ - // There are only 8 bytes allocated to options so 9 bytes of timestamp - // space is not possible. (Second byte) - 1, 2, 3, 4, - }, - shouldFail: true, - expectErrorICMP: true, - ICMPType: header.ICMPv4ParamProblem, - ICMPCode: header.ICMPv4UnusedCode, - paramProblemPointer: header.IPv4MinimumSize + pointerOffsetForInvalidLength, - }, - { - // This tests for some linux compatible behaviour. - // The ICMP pointer returned is 22 for Linux but the - // error is actually in spot 21. - name: "bad option - length bad", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - // Timestamps are in multiples of 4 or 8 but never 7. - // The option space should be padded out. - options: header.IPv4Options{ - 68, 7, 5, 0, - // ^ ^ Linux points here which is wrong. - // | Not a multiple of 4 - 1, 2, 3, 0, - }, - shouldFail: true, - expectErrorICMP: true, - ICMPType: header.ICMPv4ParamProblem, - ICMPCode: header.ICMPv4UnusedCode, - paramProblemPointer: header.IPv4MinimumSize + header.IPv4OptTSPointerOffset, - }, - { - name: "multiple type 0 with room", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{ - 68, 24, 21, 0x00, - 1, 2, 3, 4, - 5, 6, 7, 8, - 9, 10, 11, 12, - 13, 14, 15, 16, - 0, 0, 0, 0, - }, - replyOptions: header.IPv4Options{ - 68, 24, 25, 0x00, - 1, 2, 3, 4, - 5, 6, 7, 8, - 9, 10, 11, 12, - 13, 14, 15, 16, - 0x00, 0xad, 0x1c, 0x40, // time we expect from fakeclock - }, - }, - { - // The timestamp area is full so add to the overflow count. - name: "multiple type 1 timestamps", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{ - 68, 20, 21, 0x11, - // ^ - 192, 168, 1, 12, - 1, 2, 3, 4, - 192, 168, 1, 13, - 5, 6, 7, 8, - }, - // Overflow count is the top nibble of the 4th byte. - replyOptions: header.IPv4Options{ - 68, 20, 21, 0x21, - // ^ - 192, 168, 1, 12, - 1, 2, 3, 4, - 192, 168, 1, 13, - 5, 6, 7, 8, - }, - }, - { - name: "multiple type 1 timestamps with room", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{ - 68, 28, 21, 0x01, - 192, 168, 1, 12, - 1, 2, 3, 4, - 192, 168, 1, 13, - 5, 6, 7, 8, - 0, 0, 0, 0, - 0, 0, 0, 0, - }, - replyOptions: header.IPv4Options{ - 68, 28, 29, 0x01, - 192, 168, 1, 12, - 1, 2, 3, 4, - 192, 168, 1, 13, - 5, 6, 7, 8, - 192, 168, 1, 58, // New IP Address. - 0x00, 0xad, 0x1c, 0x40, // time we expect from fakeclock - }, - }, - { - // Timestamp pointer uses one based counting so 0 is invalid. - name: "timestamp pointer invalid", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{ - 68, 8, 0, 0x00, - // ^ 0 instead of 5 or more. - 0, 0, 0, 0, - }, - shouldFail: true, - expectErrorICMP: true, - ICMPType: header.ICMPv4ParamProblem, - ICMPCode: header.ICMPv4UnusedCode, - paramProblemPointer: header.IPv4MinimumSize + 2, - }, - { - // Timestamp pointer cannot be less than 5. It must point past the header - // which is 4 bytes. (1 based counting) - name: "timestamp pointer too small by 1", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{ - 68, 8, header.IPv4OptionTimestampHdrLength, 0x00, - // ^ header is 4 bytes, so 4 should fail. - 0, 0, 0, 0, - }, - shouldFail: true, - expectErrorICMP: true, - ICMPType: header.ICMPv4ParamProblem, - ICMPCode: header.ICMPv4UnusedCode, - paramProblemPointer: header.IPv4MinimumSize + header.IPv4OptTSPointerOffset, - }, - { - name: "valid timestamp pointer", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{ - 68, 8, header.IPv4OptionTimestampHdrLength + 1, 0x00, - // ^ header is 4 bytes, so 5 should succeed. - 0, 0, 0, 0, - }, - replyOptions: header.IPv4Options{ - 68, 8, 9, 0x00, - 0x00, 0xad, 0x1c, 0x40, // time we expect from fakeclock - }, - }, - { - // Needs 8 bytes for a type 1 timestamp but there are only 4 free. - name: "bad timer element alignment", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{ - 68, 20, 17, 0x01, - // ^^ ^^ 20 byte area, next free spot at 17. - 192, 168, 1, 12, - 1, 2, 3, 4, - 0, 0, 0, 0, - 0, 0, 0, 0, - }, - shouldFail: true, - expectErrorICMP: true, - ICMPType: header.ICMPv4ParamProblem, - ICMPCode: header.ICMPv4UnusedCode, - paramProblemPointer: header.IPv4MinimumSize + header.IPv4OptTSPointerOffset, - }, - // End of option list with illegal option after it, which should be ignored. - { - name: "end of options list", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{ - 68, 12, 13, 0x11, - 192, 168, 1, 12, - 1, 2, 3, 4, - 0, 10, 3, 99, // EOL followed by junk - }, - replyOptions: header.IPv4Options{ - 68, 12, 13, 0x21, - 192, 168, 1, 12, - 1, 2, 3, 4, - 0, // End of Options hides following bytes. - 0, 0, 0, // 3 bytes unknown option removed. - }, - }, - { - // Timestamp with a size much too small. - name: "timestamp truncated", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{ - 68, 1, 0, 0, - // ^ Smallest possible is 8. Linux points at the 68. - }, - shouldFail: true, - expectErrorICMP: true, - ICMPType: header.ICMPv4ParamProblem, - ICMPCode: header.ICMPv4UnusedCode, - paramProblemPointer: header.IPv4MinimumSize + pointerOffsetForInvalidLength, - }, - { - name: "single record route with room", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{ - 7, 7, 4, // 3 byte header - 0, 0, 0, 0, - 0, - }, - replyOptions: header.IPv4Options{ - 7, 7, 8, // 3 byte header - 192, 168, 1, 58, // New IP Address. - 0, // padding to multiple of 4 bytes. - }, - }, - { - name: "multiple record route with room", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{ - 7, 23, 20, // 3 byte header - 1, 2, 3, 4, - 5, 6, 7, 8, - 9, 10, 11, 12, - 13, 14, 15, 16, - 0, 0, 0, 0, - 0, - }, - replyOptions: header.IPv4Options{ - 7, 23, 24, - 1, 2, 3, 4, - 5, 6, 7, 8, - 9, 10, 11, 12, - 13, 14, 15, 16, - 192, 168, 1, 58, // New IP Address. - 0, // padding to multiple of 4 bytes. - }, - }, - { - name: "single record route with no room", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{ - 7, 7, 8, // 3 byte header - 1, 2, 3, 4, - 0, - }, - replyOptions: header.IPv4Options{ - 7, 7, 8, // 3 byte header - 1, 2, 3, 4, - 0, // padding to multiple of 4 bytes. - }, - }, - { - // Unlike timestamp, this should just succeed. - name: "multiple record route with no room", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{ - 7, 23, 24, // 3 byte header - 1, 2, 3, 4, - 5, 6, 7, 8, - 9, 10, 11, 12, - 13, 14, 15, 16, - 17, 18, 19, 20, - 0, - }, - replyOptions: header.IPv4Options{ - 7, 23, 24, - 1, 2, 3, 4, - 5, 6, 7, 8, - 9, 10, 11, 12, - 13, 14, 15, 16, - 17, 18, 19, 20, - 0, // padding to multiple of 4 bytes. - }, - }, - { - // Pointer uses one based counting so 0 is invalid. - name: "record route pointer zero", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{ - 7, 8, 0, // 3 byte header - 0, 0, 0, 0, - 0, - }, - shouldFail: true, - expectErrorICMP: true, - ICMPType: header.ICMPv4ParamProblem, - ICMPCode: header.ICMPv4UnusedCode, - paramProblemPointer: header.IPv4MinimumSize + header.IPv4OptRRPointerOffset, - }, - { - // Pointer must be 4 or more as it must point past the 3 byte header - // using 1 based counting. 3 should fail. - name: "record route pointer too small by 1", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{ - 7, 8, header.IPv4OptionRecordRouteHdrLength, // 3 byte header - 0, 0, 0, 0, - 0, - }, - shouldFail: true, - expectErrorICMP: true, - ICMPType: header.ICMPv4ParamProblem, - ICMPCode: header.ICMPv4UnusedCode, - paramProblemPointer: header.IPv4MinimumSize + header.IPv4OptRRPointerOffset, - }, - { - // Pointer must be 4 or more as it must point past the 3 byte header - // using 1 based counting. Check 4 passes. (Duplicates "single - // record route with room") - name: "valid record route pointer", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{ - 7, 7, header.IPv4OptionRecordRouteHdrLength + 1, // 3 byte header - 0, 0, 0, 0, - 0, - }, - replyOptions: header.IPv4Options{ - 7, 7, 8, // 3 byte header - 192, 168, 1, 58, // New IP Address. - 0, // padding to multiple of 4 bytes. - }, - }, - { - // Confirm Linux bug for bug compatibility. - // Linux returns slot 22 but the error is in slot 21. - name: "multiple record route with not enough room", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{ - 7, 8, 8, // 3 byte header - // ^ ^ Linux points here. We must too. - // | Not enough room. 1 byte free, need 4. - 1, 2, 3, 4, - 0, - }, - shouldFail: true, - expectErrorICMP: true, - ICMPType: header.ICMPv4ParamProblem, - ICMPCode: header.ICMPv4UnusedCode, - paramProblemPointer: header.IPv4MinimumSize + header.IPv4OptRRPointerOffset, - }, - { - name: "duplicate record route", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{ - 7, 7, 8, // 3 byte header - 1, 2, 3, 4, - 7, 7, 8, // 3 byte header - 1, 2, 3, 4, - 0, 0, // pad - }, - shouldFail: true, - expectErrorICMP: true, - ICMPType: header.ICMPv4ParamProblem, - ICMPCode: header.ICMPv4UnusedCode, - paramProblemPointer: header.IPv4MinimumSize + 7, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - clock := faketime.NewManualClock() - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4}, - Clock: clock, - }) - // Advance the clock by some unimportant amount to make - // it give a more recognisable signature than 00,00,00,00. - clock.Advance(time.Millisecond * randomTimeOffset) - - // We expect at most a single packet in response to our ICMP Echo Request. - e := channel.New(1, ipv4.MaxTotalSize, "") - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID, err) - } - ipv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: ipv4Addr} - if err := s.AddProtocolAddress(nicID, ipv4ProtoAddr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, ipv4ProtoAddr, err) - } - - // Default routes for IPv4 so ICMP can find a route to the remote - // node when attempting to send the ICMP Echo Reply. - s.SetRouteTable([]tcpip.Route{ - { - Destination: header.IPv4EmptySubnet, - NIC: nicID, - }, - }) - - if len(test.options)%4 != 0 { - t.Fatalf("options must be aligned to 32 bits, invalid test options: %x (len=%d)", test.options, len(test.options)) - } - ipHeaderLength := header.IPv4MinimumSize + len(test.options) - if ipHeaderLength > header.IPv4MaximumHeaderSize { - t.Fatalf("IP header length too large: got = %d, want <= %d ", ipHeaderLength, header.IPv4MaximumHeaderSize) - } - totalLen := uint16(ipHeaderLength + header.ICMPv4MinimumSize) - hdr := buffer.NewPrependable(int(totalLen)) - icmpH := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) - - // Specify ident/seq to make sure we get the same in the response. - icmpH.SetIdent(randomIdent) - icmpH.SetSequence(randomSequence) - icmpH.SetType(header.ICMPv4Echo) - icmpH.SetCode(header.ICMPv4UnusedCode) - icmpH.SetChecksum(0) - icmpH.SetChecksum(^header.Checksum(icmpH, 0)) - ip := header.IPv4(hdr.Prepend(ipHeaderLength)) - if test.maxTotalLength < totalLen { - totalLen = test.maxTotalLength - } - ip.Encode(&header.IPv4Fields{ - TotalLength: totalLen, - Protocol: test.transportProtocol, - TTL: test.TTL, - SrcAddr: remoteIPv4Addr, - DstAddr: ipv4Addr.Address, - }) - if test.headerLength != 0 { - ip.SetHeaderLength(test.headerLength) - } else { - // Set the calculated header length, since we may manually add options. - ip.SetHeaderLength(uint8(ipHeaderLength)) - } - if len(test.options) != 0 { - // Copy options manually. We do not use Encode for options so we can - // verify malformed options with handcrafted payloads. - if want, got := copy(ip.Options(), test.options), len(test.options); want != got { - t.Fatalf("got copy(ip.Options(), test.options) = %d, want = %d", got, want) - } - } - ip.SetChecksum(0) - ipHeaderChecksum := ip.CalculateChecksum() - if test.badHeaderChecksum { - ipHeaderChecksum += 42 - } - ip.SetChecksum(^ipHeaderChecksum) - requestPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - }) - e.InjectInbound(header.IPv4ProtocolNumber, requestPkt) - reply, ok := e.Read() - if !ok { - if test.shouldFail { - if test.expectErrorICMP { - t.Fatalf("ICMP error response (type %d, code %d) missing", test.ICMPType, test.ICMPCode) - } - return // Expected silent failure. - } - t.Fatal("expected ICMP echo reply missing") - } - - // We didn't expect a packet. Register our surprise but carry on to - // provide more information about what we got. - if test.shouldFail && !test.expectErrorICMP { - t.Error("unexpected packet response") - } - - // Check the route that brought the packet to us. - if reply.Route.LocalAddress != ipv4Addr.Address { - t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", reply.Route.LocalAddress, ipv4Addr.Address) - } - if reply.Route.RemoteAddress != remoteIPv4Addr { - t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", reply.Route.RemoteAddress, remoteIPv4Addr) - } - - // Make sure it's all in one buffer for checker. - replyIPHeader := header.IPv4(stack.PayloadSince(reply.Pkt.NetworkHeader())) - - // At this stage we only know it's probably an IP+ICMP header so verify - // that much. - checker.IPv4(t, replyIPHeader, - checker.SrcAddr(ipv4Addr.Address), - checker.DstAddr(remoteIPv4Addr), - checker.ICMPv4( - checker.ICMPv4Checksum(), - ), - ) - - // Don't proceed any further if the checker found problems. - if t.Failed() { - t.FailNow() - } - - // OK it's ICMP. We can safely look at the type now. - replyICMPHeader := header.ICMPv4(replyIPHeader.Payload()) - switch replyICMPHeader.Type() { - case header.ICMPv4ParamProblem: - if !test.shouldFail { - t.Fatalf("got Parameter Problem with pointer %d, wanted Echo Reply", replyICMPHeader.Pointer()) - } - if !test.expectErrorICMP { - t.Fatalf("got Parameter Problem with pointer %d, wanted no response", replyICMPHeader.Pointer()) - } - checker.IPv4(t, replyIPHeader, - checker.IPFullLength(uint16(header.IPv4MinimumSize+header.ICMPv4MinimumSize+requestPkt.Size())), - checker.IPv4HeaderLength(header.IPv4MinimumSize), - checker.ICMPv4( - checker.ICMPv4Type(test.ICMPType), - checker.ICMPv4Code(test.ICMPCode), - checker.ICMPv4Pointer(test.paramProblemPointer), - checker.ICMPv4Payload(hdr.View()), - ), - ) - return - case header.ICMPv4DstUnreachable: - if !test.shouldFail { - t.Fatalf("got ICMP error packet type %d, code %d, wanted Echo Reply", - header.ICMPv4DstUnreachable, replyICMPHeader.Code()) - } - if !test.expectErrorICMP { - t.Fatalf("got ICMP error packet type %d, code %d, wanted no response", - header.ICMPv4DstUnreachable, replyICMPHeader.Code()) - } - checker.IPv4(t, replyIPHeader, - checker.IPFullLength(uint16(header.IPv4MinimumSize+header.ICMPv4MinimumSize+requestPkt.Size())), - checker.IPv4HeaderLength(header.IPv4MinimumSize), - checker.ICMPv4( - checker.ICMPv4Type(test.ICMPType), - checker.ICMPv4Code(test.ICMPCode), - checker.ICMPv4Payload(hdr.View()), - ), - ) - return - case header.ICMPv4EchoReply: - if test.shouldFail { - if !test.expectErrorICMP { - t.Error("got Echo Reply packet, want no response") - } else { - t.Errorf("got Echo Reply, want ICMP error type %d, code %d", test.ICMPType, test.ICMPCode) - } - } - // If the IP options change size then the packet will change size, so - // some IP header fields will need to be adjusted for the checks. - sizeChange := len(test.replyOptions) - len(test.options) - - checker.IPv4(t, replyIPHeader, - checker.IPv4HeaderLength(ipHeaderLength+sizeChange), - checker.IPv4Options(test.replyOptions), - checker.IPFullLength(uint16(requestPkt.Size()+sizeChange)), - checker.ICMPv4( - checker.ICMPv4Checksum(), - checker.ICMPv4Code(header.ICMPv4UnusedCode), - checker.ICMPv4Seq(randomSequence), - checker.ICMPv4Ident(randomIdent), - ), - ) - default: - t.Fatalf("unexpected ICMP response, got type %d, want = %d, %d or %d", - replyICMPHeader.Type(), header.ICMPv4EchoReply, header.ICMPv4DstUnreachable, header.ICMPv4ParamProblem) - } - }) - } -} - -// compareFragments compares the contents of a set of fragmented packets against -// the contents of a source packet. -// -// If withIPHeader is set to true, we will validate the fragmented packets' IP -// headers against the source packet's IP header. If set to false, we validate -// the fragmented packets' IP headers against each other. -func compareFragments(packets []*stack.PacketBuffer, sourcePacket *stack.PacketBuffer, mtu uint32, wantFragments []fragmentInfo, proto tcpip.TransportProtocolNumber, withIPHeader bool, expectedAvailableHeaderBytes int) error { - // Make a complete array of the sourcePacket packet. - var source header.IPv4 - vv := buffer.NewVectorisedView(sourcePacket.Size(), sourcePacket.Views()) - - // If the packet to be fragmented contains an IPv4 header, use that header for - // validating fragment headers. Else, use the header of the first fragment. - if withIPHeader { - source = header.IPv4(vv.ToView()) - } else { - source = header.IPv4(packets[0].NetworkHeader().View()) - source = append(source, vv.ToView()...) - } - - // Make a copy of the IP header, which will be modified in some fields to make - // an expected header. - sourceCopy := header.IPv4(append(buffer.View(nil), source[:source.HeaderLength()]...)) - sourceCopy.SetChecksum(0) - sourceCopy.SetFlagsFragmentOffset(0, 0) - sourceCopy.SetTotalLength(0) - // Build up an array of the bytes sent. - var reassembledPayload buffer.VectorisedView - for i, packet := range packets { - // Confirm that the packet is valid. - allBytes := buffer.NewVectorisedView(packet.Size(), packet.Views()) - fragmentIPHeader := header.IPv4(allBytes.ToView()) - if !fragmentIPHeader.IsValid(len(fragmentIPHeader)) { - return fmt.Errorf("fragment #%d: IP packet is invalid:\n%s", i, hex.Dump(fragmentIPHeader)) - } - if got := len(fragmentIPHeader); got > int(mtu) { - return fmt.Errorf("fragment #%d: got len(fragmentIPHeader) = %d, want <= %d", i, got, mtu) - } - if got := fragmentIPHeader.TransportProtocol(); got != proto { - return fmt.Errorf("fragment #%d: got fragmentIPHeader.TransportProtocol() = %d, want = %d", i, got, uint8(proto)) - } - if got, want := packet.NetworkProtocolNumber, sourcePacket.NetworkProtocolNumber; got != want { - return fmt.Errorf("fragment #%d: got fragment.NetworkProtocolNumber = %d, want = %d", i, got, want) - } - if got := packet.AvailableHeaderBytes(); got != expectedAvailableHeaderBytes { - return fmt.Errorf("fragment #%d: got packet.AvailableHeaderBytes() = %d, want = %d", i, got, expectedAvailableHeaderBytes) - } - if got, want := fragmentIPHeader.CalculateChecksum(), uint16(0xffff); got != want { - return fmt.Errorf("fragment #%d: got ip.CalculateChecksum() = %#x, want = %#x", i, got, want) - } - if wantFragments[i].more { - sourceCopy.SetFlagsFragmentOffset(sourceCopy.Flags()|header.IPv4FlagMoreFragments, wantFragments[i].offset) - } else { - sourceCopy.SetFlagsFragmentOffset(sourceCopy.Flags()&^header.IPv4FlagMoreFragments, wantFragments[i].offset) - } - reassembledPayload.AppendView(packet.TransportHeader().View()) - reassembledPayload.AppendView(packet.Data().AsRange().ToOwnedView()) - // Clear out the checksum and length from the ip because we can't compare - // it. - sourceCopy.SetTotalLength(wantFragments[i].payloadSize + header.IPv4MinimumSize) - sourceCopy.SetChecksum(0) - sourceCopy.SetChecksum(^sourceCopy.CalculateChecksum()) - - // If we are validating against the original IP header, we should exclude the - // ID field, which will only be set fo fragmented packets. - if withIPHeader { - fragmentIPHeader.SetID(0) - fragmentIPHeader.SetChecksum(0) - fragmentIPHeader.SetChecksum(^fragmentIPHeader.CalculateChecksum()) - } - if diff := cmp.Diff(fragmentIPHeader[:fragmentIPHeader.HeaderLength()], sourceCopy[:sourceCopy.HeaderLength()]); diff != "" { - return fmt.Errorf("fragment #%d: fragmentIPHeader mismatch (-want +got):\n%s", i, diff) - } - } - - expected := buffer.View(source[source.HeaderLength():]) - if diff := cmp.Diff(expected, reassembledPayload.ToView()); diff != "" { - return fmt.Errorf("reassembledPayload mismatch (-want +got):\n%s", diff) - } - - return nil -} - -type fragmentInfo struct { - offset uint16 - more bool - payloadSize uint16 -} - -var fragmentationTests = []struct { - description string - mtu uint32 - transportHeaderLength int - payloadSize int - wantFragments []fragmentInfo -}{ - { - description: "No fragmentation", - mtu: 1280, - transportHeaderLength: 0, - payloadSize: 1000, - wantFragments: []fragmentInfo{ - {offset: 0, payloadSize: 1000, more: false}, - }, - }, - { - description: "Fragmented", - mtu: 1280, - transportHeaderLength: 0, - payloadSize: 2000, - wantFragments: []fragmentInfo{ - {offset: 0, payloadSize: 1256, more: true}, - {offset: 1256, payloadSize: 744, more: false}, - }, - }, - { - description: "Fragmented with the minimum mtu", - mtu: header.IPv4MinimumMTU, - transportHeaderLength: 0, - payloadSize: 100, - wantFragments: []fragmentInfo{ - {offset: 0, payloadSize: 48, more: true}, - {offset: 48, payloadSize: 48, more: true}, - {offset: 96, payloadSize: 4, more: false}, - }, - }, - { - description: "Fragmented with mtu not a multiple of 8", - mtu: header.IPv4MinimumMTU + 1, - transportHeaderLength: 0, - payloadSize: 100, - wantFragments: []fragmentInfo{ - {offset: 0, payloadSize: 48, more: true}, - {offset: 48, payloadSize: 48, more: true}, - {offset: 96, payloadSize: 4, more: false}, - }, - }, - { - description: "No fragmentation with big header", - mtu: 2000, - transportHeaderLength: 100, - payloadSize: 1000, - wantFragments: []fragmentInfo{ - {offset: 0, payloadSize: 1100, more: false}, - }, - }, - { - description: "Fragmented with big header", - mtu: 1280, - transportHeaderLength: 100, - payloadSize: 1200, - wantFragments: []fragmentInfo{ - {offset: 0, payloadSize: 1256, more: true}, - {offset: 1256, payloadSize: 44, more: false}, - }, - }, - { - description: "Fragmented with MTU smaller than header", - mtu: 300, - transportHeaderLength: 1000, - payloadSize: 500, - wantFragments: []fragmentInfo{ - {offset: 0, payloadSize: 280, more: true}, - {offset: 280, payloadSize: 280, more: true}, - {offset: 560, payloadSize: 280, more: true}, - {offset: 840, payloadSize: 280, more: true}, - {offset: 1120, payloadSize: 280, more: true}, - {offset: 1400, payloadSize: 100, more: false}, - }, - }, -} - -func TestFragmentationWritePacket(t *testing.T) { - const ttl = 42 - - for _, ft := range fragmentationTests { - t.Run(ft.description, func(t *testing.T) { - ep := iptestutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32) - r := buildRoute(t, ep) - pkt := iptestutil.MakeRandPkt(ft.transportHeaderLength, extraHeaderReserve+header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber) - source := pkt.Clone() - err := r.WritePacket(stack.NetworkHeaderParams{ - Protocol: tcp.ProtocolNumber, - TTL: ttl, - TOS: stack.DefaultTOS, - }, pkt) - if err != nil { - t.Fatalf("r.WritePacket(...): %s", err) - } - if got := len(ep.WrittenPackets); got != len(ft.wantFragments) { - t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, len(ft.wantFragments)) - } - if got := int(r.Stats().IP.PacketsSent.Value()); got != len(ft.wantFragments) { - t.Errorf("got c.Route.Stats().IP.PacketsSent.Value() = %d, want = %d", got, len(ft.wantFragments)) - } - if got := r.Stats().IP.OutgoingPacketErrors.Value(); got != 0 { - t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = 0", got) - } - if err := compareFragments(ep.WrittenPackets, source, ft.mtu, ft.wantFragments, tcp.ProtocolNumber, false /* withIPHeader */, extraHeaderReserve); err != nil { - t.Error(err) - } - }) - } -} - -func TestFragmentationWritePackets(t *testing.T) { - const ttl = 42 - writePacketsTests := []struct { - description string - insertBefore int - insertAfter int - }{ - { - description: "Single packet", - insertBefore: 0, - insertAfter: 0, - }, - { - description: "With packet before", - insertBefore: 1, - insertAfter: 0, - }, - { - description: "With packet after", - insertBefore: 0, - insertAfter: 1, - }, - { - description: "With packet before and after", - insertBefore: 1, - insertAfter: 1, - }, - } - tinyPacket := iptestutil.MakeRandPkt(header.TCPMinimumSize, extraHeaderReserve+header.IPv4MinimumSize, []int{1}, header.IPv4ProtocolNumber) - - for _, test := range writePacketsTests { - t.Run(test.description, func(t *testing.T) { - for _, ft := range fragmentationTests { - t.Run(ft.description, func(t *testing.T) { - var pkts stack.PacketBufferList - for i := 0; i < test.insertBefore; i++ { - pkts.PushBack(tinyPacket.Clone()) - } - pkt := iptestutil.MakeRandPkt(ft.transportHeaderLength, extraHeaderReserve+header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber) - pkts.PushBack(pkt.Clone()) - for i := 0; i < test.insertAfter; i++ { - pkts.PushBack(tinyPacket.Clone()) - } - - ep := iptestutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32) - r := buildRoute(t, ep) - - wantTotalPackets := len(ft.wantFragments) + test.insertBefore + test.insertAfter - n, err := r.WritePackets(pkts, stack.NetworkHeaderParams{ - Protocol: tcp.ProtocolNumber, - TTL: ttl, - TOS: stack.DefaultTOS, - }) - if err != nil { - t.Errorf("got WritePackets(_, _, _) = (_, %s), want = (_, nil)", err) - } - if n != wantTotalPackets { - t.Errorf("got WritePackets(_, _, _) = (%d, _), want = (%d, _)", n, wantTotalPackets) - } - if got := len(ep.WrittenPackets); got != wantTotalPackets { - t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, wantTotalPackets) - } - if got := int(r.Stats().IP.PacketsSent.Value()); got != wantTotalPackets { - t.Errorf("got c.Route.Stats().IP.PacketsSent.Value() = %d, want = %d", got, wantTotalPackets) - } - if got := int(r.Stats().IP.OutgoingPacketErrors.Value()); got != 0 { - t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = 0", got) - } - - if wantTotalPackets == 0 { - return - } - - fragments := ep.WrittenPackets[test.insertBefore : len(ft.wantFragments)+test.insertBefore] - if err := compareFragments(fragments, pkt, ft.mtu, ft.wantFragments, tcp.ProtocolNumber, false /* withIPHeader */, extraHeaderReserve); err != nil { - t.Error(err) - } - }) - } - }) - } -} - -// TestFragmentationErrors checks that errors are returned from WritePacket -// correctly. -func TestFragmentationErrors(t *testing.T) { - const ttl = 42 - - tests := []struct { - description string - mtu uint32 - transportHeaderLength int - payloadSize int - allowPackets int - outgoingErrors int - mockError tcpip.Error - wantError tcpip.Error - }{ - { - description: "No frag", - mtu: 2000, - payloadSize: 1000, - transportHeaderLength: 0, - allowPackets: 0, - outgoingErrors: 1, - mockError: &tcpip.ErrAborted{}, - wantError: &tcpip.ErrAborted{}, - }, - { - description: "Error on first frag", - mtu: 500, - payloadSize: 1000, - transportHeaderLength: 0, - allowPackets: 0, - outgoingErrors: 3, - mockError: &tcpip.ErrAborted{}, - wantError: &tcpip.ErrAborted{}, - }, - { - description: "Error on second frag", - mtu: 500, - payloadSize: 1000, - transportHeaderLength: 0, - allowPackets: 1, - outgoingErrors: 2, - mockError: &tcpip.ErrAborted{}, - wantError: &tcpip.ErrAborted{}, - }, - { - description: "Error on first frag MTU smaller than header", - mtu: 500, - transportHeaderLength: 1000, - payloadSize: 500, - allowPackets: 0, - outgoingErrors: 4, - mockError: &tcpip.ErrAborted{}, - wantError: &tcpip.ErrAborted{}, - }, - { - description: "Error when MTU is smaller than IPv4 minimum MTU", - mtu: header.IPv4MinimumMTU - 1, - transportHeaderLength: 0, - payloadSize: 500, - allowPackets: 0, - outgoingErrors: 1, - mockError: nil, - wantError: &tcpip.ErrInvalidEndpointState{}, - }, - } - - for _, ft := range tests { - t.Run(ft.description, func(t *testing.T) { - pkt := iptestutil.MakeRandPkt(ft.transportHeaderLength, extraHeaderReserve+header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber) - ep := iptestutil.NewMockLinkEndpoint(ft.mtu, ft.mockError, ft.allowPackets) - r := buildRoute(t, ep) - err := r.WritePacket(stack.NetworkHeaderParams{ - Protocol: tcp.ProtocolNumber, - TTL: ttl, - TOS: stack.DefaultTOS, - }, pkt) - if diff := cmp.Diff(ft.wantError, err); diff != "" { - t.Fatalf("unexpected error from r.WritePacket(_, _, _), (-want, +got):\n%s", diff) - } - if got := int(r.Stats().IP.PacketsSent.Value()); got != ft.allowPackets { - t.Errorf("got r.Stats().IP.PacketsSent.Value() = %d, want = %d", got, ft.allowPackets) - } - if got := int(r.Stats().IP.OutgoingPacketErrors.Value()); got != ft.outgoingErrors { - t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = %d", got, ft.outgoingErrors) - } - }) - } -} - -func TestInvalidFragments(t *testing.T) { - const ( - nicID = 1 - linkAddr = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e") - addr1 = "\x0a\x00\x00\x01" - addr2 = "\x0a\x00\x00\x02" - tos = 0 - ident = 1 - ttl = 48 - protocol = 6 - ) - - payloadGen := func(payloadLen int) []byte { - payload := make([]byte, payloadLen) - for i := 0; i < len(payload); i++ { - payload[i] = 0x30 - } - return payload - } - - type fragmentData struct { - ipv4fields header.IPv4Fields - // 0 means insert the correct IHL. Non 0 means override the correct IHL. - overrideIHL int // For 0 use 1 as it is an int and will be divided by 4. - payload []byte - autoChecksum bool // If true, the Checksum field will be overwritten. - } - - tests := []struct { - name string - fragments []fragmentData - wantMalformedIPPackets uint64 - wantMalformedFragments uint64 - }{ - { - name: "IHL and TotalLength zero, FragmentOffset non-zero", - fragments: []fragmentData{ - { - ipv4fields: header.IPv4Fields{ - TOS: tos, - TotalLength: 0, - ID: ident, - Flags: header.IPv4FlagDontFragment | header.IPv4FlagMoreFragments, - FragmentOffset: 59776, - TTL: ttl, - Protocol: protocol, - SrcAddr: addr1, - DstAddr: addr2, - }, - overrideIHL: 1, // See note above. - payload: payloadGen(12), - autoChecksum: true, - }, - }, - wantMalformedIPPackets: 1, - wantMalformedFragments: 0, - }, - { - name: "IHL and TotalLength zero, FragmentOffset zero", - fragments: []fragmentData{ - { - ipv4fields: header.IPv4Fields{ - TOS: tos, - TotalLength: 0, - ID: ident, - Flags: header.IPv4FlagMoreFragments, - FragmentOffset: 0, - TTL: ttl, - Protocol: protocol, - SrcAddr: addr1, - DstAddr: addr2, - }, - overrideIHL: 1, // See note above. - payload: payloadGen(12), - autoChecksum: true, - }, - }, - wantMalformedIPPackets: 1, - wantMalformedFragments: 0, - }, - { - // Payload 17 octets and Fragment offset 65520 - // Leading to the fragment end to be past 65536. - name: "fragment ends past 65536", - fragments: []fragmentData{ - { - ipv4fields: header.IPv4Fields{ - TOS: tos, - TotalLength: header.IPv4MinimumSize + 17, - ID: ident, - Flags: 0, - FragmentOffset: 65520, - TTL: ttl, - Protocol: protocol, - SrcAddr: addr1, - DstAddr: addr2, - }, - payload: payloadGen(17), - autoChecksum: true, - }, - }, - wantMalformedIPPackets: 1, - wantMalformedFragments: 1, - }, - { - // Payload 16 octets and fragment offset 65520 - // Leading to the fragment end to be exactly 65536. - name: "fragment ends exactly at 65536", - fragments: []fragmentData{ - { - ipv4fields: header.IPv4Fields{ - TOS: tos, - TotalLength: header.IPv4MinimumSize + 16, - ID: ident, - Flags: 0, - FragmentOffset: 65520, - TTL: ttl, - Protocol: protocol, - SrcAddr: addr1, - DstAddr: addr2, - }, - payload: payloadGen(16), - autoChecksum: true, - }, - }, - wantMalformedIPPackets: 0, - wantMalformedFragments: 0, - }, - { - name: "IHL less than IPv4 minimum size", - fragments: []fragmentData{ - { - ipv4fields: header.IPv4Fields{ - TOS: tos, - TotalLength: header.IPv4MinimumSize + 28, - ID: ident, - Flags: 0, - FragmentOffset: 1944, - TTL: ttl, - Protocol: protocol, - SrcAddr: addr1, - DstAddr: addr2, - }, - payload: payloadGen(28), - overrideIHL: header.IPv4MinimumSize - 12, - autoChecksum: true, - }, - { - ipv4fields: header.IPv4Fields{ - TOS: tos, - TotalLength: header.IPv4MinimumSize - 12, - ID: ident, - Flags: header.IPv4FlagMoreFragments, - FragmentOffset: 0, - TTL: ttl, - Protocol: protocol, - SrcAddr: addr1, - DstAddr: addr2, - }, - payload: payloadGen(28), - overrideIHL: header.IPv4MinimumSize - 12, - autoChecksum: true, - }, - }, - wantMalformedIPPackets: 2, - wantMalformedFragments: 0, - }, - { - name: "fragment with short TotalLength and extra payload", - fragments: []fragmentData{ - { - ipv4fields: header.IPv4Fields{ - TOS: tos, - TotalLength: header.IPv4MinimumSize + 28, - ID: ident, - Flags: 0, - FragmentOffset: 28816, - TTL: ttl, - Protocol: protocol, - SrcAddr: addr1, - DstAddr: addr2, - }, - payload: payloadGen(28), - overrideIHL: header.IPv4MinimumSize + 4, - autoChecksum: true, - }, - { - ipv4fields: header.IPv4Fields{ - TOS: tos, - TotalLength: header.IPv4MinimumSize + 4, - ID: ident, - Flags: header.IPv4FlagMoreFragments, - FragmentOffset: 0, - TTL: ttl, - Protocol: protocol, - SrcAddr: addr1, - DstAddr: addr2, - }, - payload: payloadGen(28), - overrideIHL: header.IPv4MinimumSize + 4, - autoChecksum: true, - }, - }, - wantMalformedIPPackets: 1, - wantMalformedFragments: 1, - }, - { - name: "multiple fragments with More Fragments flag set to false", - fragments: []fragmentData{ - { - ipv4fields: header.IPv4Fields{ - TOS: tos, - TotalLength: header.IPv4MinimumSize + 8, - ID: ident, - Flags: 0, - FragmentOffset: 128, - TTL: ttl, - Protocol: protocol, - SrcAddr: addr1, - DstAddr: addr2, - }, - payload: payloadGen(8), - autoChecksum: true, - }, - { - ipv4fields: header.IPv4Fields{ - TOS: tos, - TotalLength: header.IPv4MinimumSize + 8, - ID: ident, - Flags: 0, - FragmentOffset: 8, - TTL: ttl, - Protocol: protocol, - SrcAddr: addr1, - DstAddr: addr2, - }, - payload: payloadGen(8), - autoChecksum: true, - }, - { - ipv4fields: header.IPv4Fields{ - TOS: tos, - TotalLength: header.IPv4MinimumSize + 8, - ID: ident, - Flags: header.IPv4FlagMoreFragments, - FragmentOffset: 0, - TTL: ttl, - Protocol: protocol, - SrcAddr: addr1, - DstAddr: addr2, - }, - payload: payloadGen(8), - autoChecksum: true, - }, - }, - wantMalformedIPPackets: 1, - wantMalformedFragments: 1, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ - ipv4.NewProtocol, - }, - }) - e := channel.New(0, 1500, linkAddr) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - if err := s.AddAddress(nicID, ipv4.ProtocolNumber, addr2); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, addr2, err) - } - - for _, f := range test.fragments { - pktSize := header.IPv4MinimumSize + len(f.payload) - hdr := buffer.NewPrependable(pktSize) - - ip := header.IPv4(hdr.Prepend(pktSize)) - ip.Encode(&f.ipv4fields) - if want, got := len(f.payload), copy(ip[header.IPv4MinimumSize:], f.payload); want != got { - t.Fatalf("copied %d bytes, expected %d bytes.", got, want) - } - // Encode sets this up correctly. If we want a different value for - // testing then we need to overwrite the good value. - if f.overrideIHL != 0 { - ip.SetHeaderLength(uint8(f.overrideIHL)) - // If we are asked to add options (type not specified) then pad - // with 0 (EOL). RFC 791 page 23 says "The padding is zero". - for i := header.IPv4MinimumSize; i < f.overrideIHL; i++ { - ip[i] = byte(header.IPv4OptionListEndType) - } - } - - if f.autoChecksum { - ip.SetChecksum(0) - ip.SetChecksum(^ip.CalculateChecksum()) - } - - vv := hdr.View().ToVectorisedView() - e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: vv, - })) - } - - if got, want := s.Stats().IP.MalformedPacketsReceived.Value(), test.wantMalformedIPPackets; got != want { - t.Errorf("incorrect Stats.IP.MalformedPacketsReceived, got: %d, want: %d", got, want) - } - if got, want := s.Stats().IP.MalformedFragmentsReceived.Value(), test.wantMalformedFragments; got != want { - t.Errorf("incorrect Stats.IP.MalformedFragmentsReceived, got: %d, want: %d", got, want) - } - }) - } -} - -func TestFragmentReassemblyTimeout(t *testing.T) { - const ( - nicID = 1 - linkAddr = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e") - addr1 = "\x0a\x00\x00\x01" - addr2 = "\x0a\x00\x00\x02" - tos = 0 - ident = 1 - ttl = 48 - protocol = 99 - data = "TEST_FRAGMENT_REASSEMBLY_TIMEOUT" - ) - - type fragmentData struct { - ipv4fields header.IPv4Fields - payload []byte - } - - tests := []struct { - name string - fragments []fragmentData - expectICMP bool - }{ - { - name: "first fragment only", - fragments: []fragmentData{ - { - ipv4fields: header.IPv4Fields{ - TOS: tos, - TotalLength: header.IPv4MinimumSize + 16, - ID: ident, - Flags: header.IPv4FlagMoreFragments, - FragmentOffset: 0, - TTL: ttl, - Protocol: protocol, - SrcAddr: addr1, - DstAddr: addr2, - }, - payload: []byte(data)[:16], - }, - }, - expectICMP: true, - }, - { - name: "two first fragments", - fragments: []fragmentData{ - { - ipv4fields: header.IPv4Fields{ - TOS: tos, - TotalLength: header.IPv4MinimumSize + 16, - ID: ident, - Flags: header.IPv4FlagMoreFragments, - FragmentOffset: 0, - TTL: ttl, - Protocol: protocol, - SrcAddr: addr1, - DstAddr: addr2, - }, - payload: []byte(data)[:16], - }, - { - ipv4fields: header.IPv4Fields{ - TOS: tos, - TotalLength: header.IPv4MinimumSize + 16, - ID: ident, - Flags: header.IPv4FlagMoreFragments, - FragmentOffset: 0, - TTL: ttl, - Protocol: protocol, - SrcAddr: addr1, - DstAddr: addr2, - }, - payload: []byte(data)[:16], - }, - }, - expectICMP: true, - }, - { - name: "second fragment only", - fragments: []fragmentData{ - { - ipv4fields: header.IPv4Fields{ - TOS: tos, - TotalLength: uint16(header.IPv4MinimumSize + len(data) - 16), - ID: ident, - Flags: 0, - FragmentOffset: 8, - TTL: ttl, - Protocol: protocol, - SrcAddr: addr1, - DstAddr: addr2, - }, - payload: []byte(data)[16:], - }, - }, - expectICMP: false, - }, - { - name: "two fragments with a gap", - fragments: []fragmentData{ - { - ipv4fields: header.IPv4Fields{ - TOS: tos, - TotalLength: header.IPv4MinimumSize + 8, - ID: ident, - Flags: header.IPv4FlagMoreFragments, - FragmentOffset: 0, - TTL: ttl, - Protocol: protocol, - SrcAddr: addr1, - DstAddr: addr2, - }, - payload: []byte(data)[:8], - }, - { - ipv4fields: header.IPv4Fields{ - TOS: tos, - TotalLength: uint16(header.IPv4MinimumSize + len(data) - 16), - ID: ident, - Flags: 0, - FragmentOffset: 16, - TTL: ttl, - Protocol: protocol, - SrcAddr: addr1, - DstAddr: addr2, - }, - payload: []byte(data)[16:], - }, - }, - expectICMP: true, - }, - { - name: "two fragments with a gap in reverse order", - fragments: []fragmentData{ - { - ipv4fields: header.IPv4Fields{ - TOS: tos, - TotalLength: uint16(header.IPv4MinimumSize + len(data) - 16), - ID: ident, - Flags: 0, - FragmentOffset: 16, - TTL: ttl, - Protocol: protocol, - SrcAddr: addr1, - DstAddr: addr2, - }, - payload: []byte(data)[16:], - }, - { - ipv4fields: header.IPv4Fields{ - TOS: tos, - TotalLength: header.IPv4MinimumSize + 8, - ID: ident, - Flags: header.IPv4FlagMoreFragments, - FragmentOffset: 0, - TTL: ttl, - Protocol: protocol, - SrcAddr: addr1, - DstAddr: addr2, - }, - payload: []byte(data)[:8], - }, - }, - expectICMP: true, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - clock := faketime.NewManualClock() - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ - ipv4.NewProtocol, - }, - Clock: clock, - }) - e := channel.New(1, 1500, linkAddr) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - if err := s.AddAddress(nicID, ipv4.ProtocolNumber, addr2); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, addr2, err) - } - s.SetRouteTable([]tcpip.Route{{ - Destination: header.IPv4EmptySubnet, - NIC: nicID, - }}) - - var firstFragmentSent buffer.View - for _, f := range test.fragments { - pktSize := header.IPv4MinimumSize - hdr := buffer.NewPrependable(pktSize) - - ip := header.IPv4(hdr.Prepend(pktSize)) - ip.Encode(&f.ipv4fields) - - ip.SetChecksum(0) - ip.SetChecksum(^ip.CalculateChecksum()) - - vv := hdr.View().ToVectorisedView() - vv.AppendView(f.payload) - - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: vv, - }) - - if firstFragmentSent == nil && ip.FragmentOffset() == 0 { - firstFragmentSent = stack.PayloadSince(pkt.NetworkHeader()) - } - - e.InjectInbound(header.IPv4ProtocolNumber, pkt) - } - - clock.Advance(ipv4.ReassembleTimeout) - - reply, ok := e.Read() - if !test.expectICMP { - if ok { - t.Fatalf("unexpected ICMP error message received: %#v", reply) - } - return - } - if !ok { - t.Fatal("expected ICMP error message missing") - } - if firstFragmentSent == nil { - t.Fatalf("unexpected ICMP error message received: %#v", reply) - } - - checker.IPv4(t, stack.PayloadSince(reply.Pkt.NetworkHeader()), - checker.SrcAddr(addr2), - checker.DstAddr(addr1), - checker.IPFullLength(uint16(header.IPv4MinimumSize+header.ICMPv4MinimumSize+firstFragmentSent.Size())), - checker.IPv4HeaderLength(header.IPv4MinimumSize), - checker.ICMPv4( - checker.ICMPv4Type(header.ICMPv4TimeExceeded), - checker.ICMPv4Code(header.ICMPv4ReassemblyTimeout), - checker.ICMPv4Checksum(), - checker.ICMPv4Payload(firstFragmentSent), - ), - ) - }) - } -} - -// TestReceiveFragments feeds fragments in through the incoming packet path to -// test reassembly -func TestReceiveFragments(t *testing.T) { - const ( - nicID = 1 - - addr1 = "\x0c\xa8\x00\x01" // 192.168.0.1 - addr2 = "\x0c\xa8\x00\x02" // 192.168.0.2 - addr3 = "\x0c\xa8\x00\x03" // 192.168.0.3 - ) - - // Build and return a UDP header containing payload. - udpGen := func(payloadLen int, multiplier uint8, src, dst tcpip.Address) buffer.View { - payload := buffer.NewView(payloadLen) - for i := 0; i < len(payload); i++ { - payload[i] = uint8(i) * multiplier - } - - udpLength := header.UDPMinimumSize + len(payload) - - hdr := buffer.NewPrependable(udpLength) - u := header.UDP(hdr.Prepend(udpLength)) - u.Encode(&header.UDPFields{ - SrcPort: 5555, - DstPort: 80, - Length: uint16(udpLength), - }) - copy(u.Payload(), payload) - sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, src, dst, uint16(udpLength)) - sum = header.Checksum(payload, sum) - u.SetChecksum(^u.CalculateChecksum(sum)) - return hdr.View() - } - - // UDP header plus a payload of 0..256 - ipv4Payload1Addr1ToAddr2 := udpGen(256, 1, addr1, addr2) - udpPayload1Addr1ToAddr2 := ipv4Payload1Addr1ToAddr2[header.UDPMinimumSize:] - ipv4Payload1Addr3ToAddr2 := udpGen(256, 1, addr3, addr2) - udpPayload1Addr3ToAddr2 := ipv4Payload1Addr3ToAddr2[header.UDPMinimumSize:] - // UDP header plus a payload of 0..256 in increments of 2. - ipv4Payload2Addr1ToAddr2 := udpGen(128, 2, addr1, addr2) - udpPayload2Addr1ToAddr2 := ipv4Payload2Addr1ToAddr2[header.UDPMinimumSize:] - // UDP header plus a payload of 0..256 in increments of 3. - // Used to test cases where the fragment blocks are not a multiple of - // the fragment block size of 8 (RFC 791 section 3.1 page 14). - ipv4Payload3Addr1ToAddr2 := udpGen(127, 3, addr1, addr2) - udpPayload3Addr1ToAddr2 := ipv4Payload3Addr1ToAddr2[header.UDPMinimumSize:] - // Used to test the max reassembled IPv4 payload length. - ipv4Payload4Addr1ToAddr2 := udpGen(header.UDPMaximumSize-header.UDPMinimumSize, 4, addr1, addr2) - udpPayload4Addr1ToAddr2 := ipv4Payload4Addr1ToAddr2[header.UDPMinimumSize:] - - type fragmentData struct { - srcAddr tcpip.Address - dstAddr tcpip.Address - id uint16 - flags uint8 - fragmentOffset uint16 - payload buffer.View - } - - tests := []struct { - name string - fragments []fragmentData - expectedPayloads [][]byte - }{ - { - name: "No fragmentation", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - id: 1, - flags: 0, - fragmentOffset: 0, - payload: ipv4Payload1Addr1ToAddr2, - }, - }, - expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2}, - }, - { - name: "No fragmentation with size not a multiple of fragment block size", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - id: 1, - flags: 0, - fragmentOffset: 0, - payload: ipv4Payload3Addr1ToAddr2, - }, - }, - expectedPayloads: [][]byte{udpPayload3Addr1ToAddr2}, - }, - { - name: "More fragments without payload", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - id: 1, - flags: header.IPv4FlagMoreFragments, - fragmentOffset: 0, - payload: ipv4Payload1Addr1ToAddr2, - }, - }, - expectedPayloads: nil, - }, - { - name: "Non-zero fragment offset without payload", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - id: 1, - flags: 0, - fragmentOffset: 8, - payload: ipv4Payload1Addr1ToAddr2, - }, - }, - expectedPayloads: nil, - }, - { - name: "Two fragments", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - id: 1, - flags: header.IPv4FlagMoreFragments, - fragmentOffset: 0, - payload: ipv4Payload1Addr1ToAddr2[:64], - }, - { - srcAddr: addr1, - dstAddr: addr2, - id: 1, - flags: 0, - fragmentOffset: 64, - payload: ipv4Payload1Addr1ToAddr2[64:], - }, - }, - expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2}, - }, - { - name: "Two fragments out of order", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - id: 1, - flags: 0, - fragmentOffset: 64, - payload: ipv4Payload1Addr1ToAddr2[64:], - }, - { - srcAddr: addr1, - dstAddr: addr2, - id: 1, - flags: header.IPv4FlagMoreFragments, - fragmentOffset: 0, - payload: ipv4Payload1Addr1ToAddr2[:64], - }, - }, - expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2}, - }, - { - name: "Two fragments with last fragment size not a multiple of fragment block size", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - id: 1, - flags: header.IPv4FlagMoreFragments, - fragmentOffset: 0, - payload: ipv4Payload3Addr1ToAddr2[:64], - }, - { - srcAddr: addr1, - dstAddr: addr2, - id: 1, - flags: 0, - fragmentOffset: 64, - payload: ipv4Payload3Addr1ToAddr2[64:], - }, - }, - expectedPayloads: [][]byte{udpPayload3Addr1ToAddr2}, - }, - { - name: "Two fragments with first fragment size not a multiple of fragment block size", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - id: 1, - flags: header.IPv4FlagMoreFragments, - fragmentOffset: 0, - payload: ipv4Payload3Addr1ToAddr2[:63], - }, - { - srcAddr: addr1, - dstAddr: addr2, - id: 1, - flags: 0, - fragmentOffset: 63, - payload: ipv4Payload3Addr1ToAddr2[63:], - }, - }, - expectedPayloads: nil, - }, - { - name: "Second fragment has MoreFlags set", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - id: 1, - flags: header.IPv4FlagMoreFragments, - fragmentOffset: 0, - payload: ipv4Payload1Addr1ToAddr2[:64], - }, - { - srcAddr: addr1, - dstAddr: addr2, - id: 1, - flags: header.IPv4FlagMoreFragments, - fragmentOffset: 64, - payload: ipv4Payload1Addr1ToAddr2[64:], - }, - }, - expectedPayloads: nil, - }, - { - name: "Two fragments with different IDs", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - id: 1, - flags: header.IPv4FlagMoreFragments, - fragmentOffset: 0, - payload: ipv4Payload1Addr1ToAddr2[:64], - }, - { - srcAddr: addr1, - dstAddr: addr2, - id: 2, - flags: 0, - fragmentOffset: 64, - payload: ipv4Payload1Addr1ToAddr2[64:], - }, - }, - expectedPayloads: nil, - }, - { - name: "Two interleaved fragmented packets", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - id: 1, - flags: header.IPv4FlagMoreFragments, - fragmentOffset: 0, - payload: ipv4Payload1Addr1ToAddr2[:64], - }, - { - srcAddr: addr1, - dstAddr: addr2, - id: 2, - flags: header.IPv4FlagMoreFragments, - fragmentOffset: 0, - payload: ipv4Payload2Addr1ToAddr2[:64], - }, - { - srcAddr: addr1, - dstAddr: addr2, - id: 1, - flags: 0, - fragmentOffset: 64, - payload: ipv4Payload1Addr1ToAddr2[64:], - }, - { - srcAddr: addr1, - dstAddr: addr2, - id: 2, - flags: 0, - fragmentOffset: 64, - payload: ipv4Payload2Addr1ToAddr2[64:], - }, - }, - expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2, udpPayload2Addr1ToAddr2}, - }, - { - name: "Two interleaved fragmented packets from different sources but with same ID", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - id: 1, - flags: header.IPv4FlagMoreFragments, - fragmentOffset: 0, - payload: ipv4Payload1Addr1ToAddr2[:64], - }, - { - srcAddr: addr3, - dstAddr: addr2, - id: 1, - flags: header.IPv4FlagMoreFragments, - fragmentOffset: 0, - payload: ipv4Payload1Addr3ToAddr2[:32], - }, - { - srcAddr: addr1, - dstAddr: addr2, - id: 1, - flags: 0, - fragmentOffset: 64, - payload: ipv4Payload1Addr1ToAddr2[64:], - }, - { - srcAddr: addr3, - dstAddr: addr2, - id: 1, - flags: 0, - fragmentOffset: 32, - payload: ipv4Payload1Addr3ToAddr2[32:], - }, - }, - expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2, udpPayload1Addr3ToAddr2}, - }, - { - name: "Fragment without followup", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - id: 1, - flags: header.IPv4FlagMoreFragments, - fragmentOffset: 0, - payload: ipv4Payload1Addr1ToAddr2[:64], - }, - }, - expectedPayloads: nil, - }, - { - name: "Two fragments reassembled into a maximum UDP packet", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - id: 1, - flags: header.IPv4FlagMoreFragments, - fragmentOffset: 0, - payload: ipv4Payload4Addr1ToAddr2[:65512], - }, - { - srcAddr: addr1, - dstAddr: addr2, - id: 1, - flags: 0, - fragmentOffset: 65512, - payload: ipv4Payload4Addr1ToAddr2[65512:], - }, - }, - expectedPayloads: [][]byte{udpPayload4Addr1ToAddr2}, - }, - { - name: "Two fragments with MF flag reassembled into a maximum UDP packet", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - id: 1, - flags: header.IPv4FlagMoreFragments, - fragmentOffset: 0, - payload: ipv4Payload4Addr1ToAddr2[:65512], - }, - { - srcAddr: addr1, - dstAddr: addr2, - id: 1, - flags: header.IPv4FlagMoreFragments, - fragmentOffset: 65512, - payload: ipv4Payload4Addr1ToAddr2[65512:], - }, - }, - expectedPayloads: nil, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - // Setup a stack and endpoint. - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, - RawFactory: raw.EndpointFactory{}, - }) - e := channel.New(0, 1280, "\xf0\x00") - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - if err := s.AddAddress(nicID, header.IPv4ProtocolNumber, addr2); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, addr2, err) - } - - wq := waiter.Queue{} - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.ReadableEvents) - defer wq.EventUnregister(&we) - defer close(ch) - ep, err := s.NewEndpoint(udp.ProtocolNumber, header.IPv4ProtocolNumber, &wq) - if err != nil { - t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, header.IPv4ProtocolNumber, err) - } - defer ep.Close() - - bindAddr := tcpip.FullAddress{Addr: addr2, Port: 80} - if err := ep.Bind(bindAddr); err != nil { - t.Fatalf("Bind(%+v): %s", bindAddr, err) - } - - // Bring up a raw endpoint so we can examine network headers. - epRaw, err := s.NewRawEndpoint(udp.ProtocolNumber, header.IPv4ProtocolNumber, &wq, true /* associated */) - if err != nil { - t.Fatalf("NewRawEndpoint(%d, %d, _, true): %s", udp.ProtocolNumber, header.IPv4ProtocolNumber, err) - } - defer epRaw.Close() - - // Prepare and send the fragments. - for _, frag := range test.fragments { - hdr := buffer.NewPrependable(header.IPv4MinimumSize) - - // Serialize IPv4 fixed header. - ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) - ip.Encode(&header.IPv4Fields{ - TotalLength: header.IPv4MinimumSize + uint16(len(frag.payload)), - ID: frag.id, - Flags: frag.flags, - FragmentOffset: frag.fragmentOffset, - TTL: 64, - Protocol: uint8(header.UDPProtocolNumber), - SrcAddr: frag.srcAddr, - DstAddr: frag.dstAddr, - }) - ip.SetChecksum(^ip.CalculateChecksum()) - - vv := hdr.View().ToVectorisedView() - vv.AppendView(frag.payload) - - e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: vv, - })) - } - - if got, want := s.Stats().UDP.PacketsReceived.Value(), uint64(len(test.expectedPayloads)); got != want { - t.Errorf("got UDP Rx Packets = %d, want = %d", got, want) - } - - for i, expectedPayload := range test.expectedPayloads { - // Check UDP payload delivered by UDP endpoint. - var buf bytes.Buffer - result, err := ep.Read(&buf, tcpip.ReadOptions{}) - if err != nil { - t.Fatalf("(i=%d) ep.Read: %s", i, err) - } - if diff := cmp.Diff(tcpip.ReadResult{ - Count: len(expectedPayload), - Total: len(expectedPayload), - }, result, checker.IgnoreCmpPath("ControlMessages")); diff != "" { - t.Errorf("(i=%d) ep.Read: unexpected result (-want +got):\n%s", i, diff) - } - if diff := cmp.Diff(expectedPayload, buf.Bytes()); diff != "" { - t.Errorf("(i=%d) ep.Read: UDP payload mismatch (-want +got):\n%s", i, diff) - } - - // Check IPv4 header in packet delivered by raw endpoint. - buf.Reset() - result, err = epRaw.Read(&buf, tcpip.ReadOptions{}) - if err != nil { - t.Fatalf("(i=%d) epRaw.Read: %s", i, err) - } - // Reassambly does not take care of checksum. Here we write our own - // check routine instead of using checker.IPv4. - ip := header.IPv4(buf.Bytes()) - for _, check := range []checker.NetworkChecker{ - checker.FragmentFlags(0), - checker.FragmentOffset(0), - checker.IPFullLength(uint16(header.IPv4MinimumSize + header.UDPMinimumSize + len(expectedPayload))), - } { - check(t, []header.Network{ip}) - } - } - - res, err := ep.Read(ioutil.Discard, tcpip.ReadOptions{}) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Fatalf("(last) got Read = (%#v, %v), want = (_, %s)", res, err, &tcpip.ErrWouldBlock{}) - } - }) - } -} - -func TestWriteStats(t *testing.T) { - const nPackets = 3 - - tests := []struct { - name string - setup func(*testing.T, *stack.Stack) - allowPackets int - expectSent int - expectOutputDropped int - expectPostroutingDropped int - expectWritten int - }{ - { - name: "Accept all", - // No setup needed, tables accept everything by default. - setup: func(*testing.T, *stack.Stack) {}, - allowPackets: math.MaxInt32, - expectSent: nPackets, - expectOutputDropped: 0, - expectPostroutingDropped: 0, - expectWritten: nPackets, - }, { - name: "Accept all with error", - // No setup needed, tables accept everything by default. - setup: func(*testing.T, *stack.Stack) {}, - allowPackets: nPackets - 1, - expectSent: nPackets - 1, - expectOutputDropped: 0, - expectPostroutingDropped: 0, - expectWritten: nPackets - 1, - }, { - name: "Drop all with Output chain", - setup: func(t *testing.T, stk *stack.Stack) { - // Install Output DROP rule. - ipt := stk.IPTables() - filter := ipt.GetTable(stack.FilterID, false /* ipv6 */) - ruleIdx := filter.BuiltinChains[stack.Output] - filter.Rules[ruleIdx].Target = &stack.DropTarget{} - if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil { - t.Fatalf("failed to replace table: %s", err) - } - }, - allowPackets: math.MaxInt32, - expectSent: 0, - expectOutputDropped: nPackets, - expectPostroutingDropped: 0, - expectWritten: nPackets, - }, { - name: "Drop all with Postrouting chain", - setup: func(t *testing.T, stk *stack.Stack) { - ipt := stk.IPTables() - filter := ipt.GetTable(stack.NATID, false /* ipv6 */) - ruleIdx := filter.BuiltinChains[stack.Postrouting] - filter.Rules[ruleIdx].Target = &stack.DropTarget{} - if err := ipt.ReplaceTable(stack.NATID, filter, false /* ipv6 */); err != nil { - t.Fatalf("failed to replace table: %s", err) - } - }, - allowPackets: math.MaxInt32, - expectSent: 0, - expectOutputDropped: 0, - expectPostroutingDropped: nPackets, - expectWritten: nPackets, - }, { - name: "Drop some with Output chain", - setup: func(t *testing.T, stk *stack.Stack) { - // Install Output DROP rule that matches only 1 - // of the 3 packets. - ipt := stk.IPTables() - filter := ipt.GetTable(stack.FilterID, false /* ipv6 */) - // We'll match and DROP the last packet. - ruleIdx := filter.BuiltinChains[stack.Output] - filter.Rules[ruleIdx].Target = &stack.DropTarget{} - filter.Rules[ruleIdx].Matchers = []stack.Matcher{&limitedMatcher{nPackets - 1}} - // Make sure the next rule is ACCEPT. - filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} - if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil { - t.Fatalf("failed to replace table: %s", err) - } - }, - allowPackets: math.MaxInt32, - expectSent: nPackets - 1, - expectOutputDropped: 1, - expectPostroutingDropped: 0, - expectWritten: nPackets, - }, { - name: "Drop some with Postrouting chain", - setup: func(t *testing.T, stk *stack.Stack) { - // Install Postrouting DROP rule that matches only 1 - // of the 3 packets. - ipt := stk.IPTables() - filter := ipt.GetTable(stack.NATID, false /* ipv6 */) - // We'll match and DROP the last packet. - ruleIdx := filter.BuiltinChains[stack.Postrouting] - filter.Rules[ruleIdx].Target = &stack.DropTarget{} - filter.Rules[ruleIdx].Matchers = []stack.Matcher{&limitedMatcher{nPackets - 1}} - // Make sure the next rule is ACCEPT. - filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} - if err := ipt.ReplaceTable(stack.NATID, filter, false /* ipv6 */); err != nil { - t.Fatalf("failed to replace table: %s", err) - } - }, - allowPackets: math.MaxInt32, - expectSent: nPackets - 1, - expectOutputDropped: 0, - expectPostroutingDropped: 1, - expectWritten: nPackets, - }, - } - - // Parameterize the tests to run with both WritePacket and WritePackets. - writers := []struct { - name string - writePackets func(*stack.Route, stack.PacketBufferList) (int, tcpip.Error) - }{ - { - name: "WritePacket", - writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, tcpip.Error) { - nWritten := 0 - for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { - if err := rt.WritePacket(stack.NetworkHeaderParams{}, pkt); err != nil { - return nWritten, err - } - nWritten++ - } - return nWritten, nil - }, - }, { - name: "WritePackets", - writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, tcpip.Error) { - return rt.WritePackets(pkts, stack.NetworkHeaderParams{}) - }, - }, - } - - for _, writer := range writers { - t.Run(writer.name, func(t *testing.T) { - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - ep := iptestutil.NewMockLinkEndpoint(header.IPv4MinimumMTU, &tcpip.ErrInvalidEndpointState{}, test.allowPackets) - rt := buildRoute(t, ep) - - var pkts stack.PacketBufferList - for i := 0; i < nPackets; i++ { - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: header.UDPMinimumSize + int(rt.MaxHeaderLength()), - Data: buffer.NewView(0).ToVectorisedView(), - }) - pkt.TransportHeader().Push(header.UDPMinimumSize) - pkts.PushBack(pkt) - } - - test.setup(t, rt.Stack()) - - nWritten, _ := writer.writePackets(rt, pkts) - - if got := int(rt.Stats().IP.PacketsSent.Value()); got != test.expectSent { - t.Errorf("got rt.Stats().IP.PacketsSent.Value() = %d, want = %d", got, test.expectSent) - } - if got := int(rt.Stats().IP.IPTablesOutputDropped.Value()); got != test.expectOutputDropped { - t.Errorf("got rt.Stats().IP.IPTablesOutputDropped.Value() = %d, want = %d", got, test.expectOutputDropped) - } - if got := int(rt.Stats().IP.IPTablesPostroutingDropped.Value()); got != test.expectPostroutingDropped { - t.Errorf("got rt.Stats().IP.IPTablesPostroutingDropped.Value() = %d, want = %d", got, test.expectPostroutingDropped) - } - if nWritten != test.expectWritten { - t.Errorf("got nWritten = %d, want = %d", nWritten, test.expectWritten) - } - }) - } - }) - } -} - -func buildRoute(t *testing.T, ep stack.LinkEndpoint) *stack.Route { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, - }) - if err := s.CreateNIC(1, ep); err != nil { - t.Fatalf("CreateNIC(1, _) failed: %s", err) - } - const ( - src = "\x10\x00\x00\x01" - dst = "\x10\x00\x00\x02" - ) - if err := s.AddAddress(1, ipv4.ProtocolNumber, src); err != nil { - t.Fatalf("AddAddress(1, %d, %s) failed: %s", ipv4.ProtocolNumber, src, err) - } - { - mask := tcpip.AddressMask(header.IPv4Broadcast) - subnet, err := tcpip.NewSubnet(dst, mask) - if err != nil { - t.Fatalf("NewSubnet(%s, %s) failed: %v", dst, mask, err) - } - s.SetRouteTable([]tcpip.Route{{ - Destination: subnet, - NIC: 1, - }}) - } - rt, err := s.FindRoute(1, src, dst, ipv4.ProtocolNumber, false /* multicastLoop */) - if err != nil { - t.Fatalf("FindRoute(1, %s, %s, %d, false) = %s", src, dst, ipv4.ProtocolNumber, err) - } - return rt -} - -// limitedMatcher is an iptables matcher that matches after a certain number of -// packets are checked against it. -type limitedMatcher struct { - limit int -} - -// Name implements Matcher.Name. -func (*limitedMatcher) Name() string { - return "limitedMatcher" -} - -// Match implements Matcher.Match. -func (lm *limitedMatcher) Match(stack.Hook, *stack.PacketBuffer, string, string) (bool, bool) { - if lm.limit == 0 { - return true, false - } - lm.limit-- - return false, false -} - -func TestPacketQueuing(t *testing.T) { - const nicID = 1 - - var ( - host1NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x06") - host2NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x09") - - host1IPv4Addr = tcpip.ProtocolAddress{ - Protocol: ipv4.ProtocolNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: tcpip.Address(net.ParseIP("192.168.0.1").To4()), - PrefixLen: 24, - }, - } - host2IPv4Addr = tcpip.ProtocolAddress{ - Protocol: ipv4.ProtocolNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: tcpip.Address(net.ParseIP("192.168.0.2").To4()), - PrefixLen: 8, - }, - } - ) - - tests := []struct { - name string - rxPkt func(*channel.Endpoint) - checkResp func(*testing.T, *channel.Endpoint) - }{ - { - name: "ICMP Error", - rxPkt: func(e *channel.Endpoint) { - hdr := buffer.NewPrependable(header.IPv4MinimumSize + header.UDPMinimumSize) - u := header.UDP(hdr.Prepend(header.UDPMinimumSize)) - u.Encode(&header.UDPFields{ - SrcPort: 5555, - DstPort: 80, - Length: header.UDPMinimumSize, - }) - sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, host2IPv4Addr.AddressWithPrefix.Address, host1IPv4Addr.AddressWithPrefix.Address, header.UDPMinimumSize) - sum = header.Checksum(nil, sum) - u.SetChecksum(^u.CalculateChecksum(sum)) - ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) - ip.Encode(&header.IPv4Fields{ - TotalLength: header.IPv4MinimumSize + header.UDPMinimumSize, - TTL: ipv4.DefaultTTL, - Protocol: uint8(udp.ProtocolNumber), - SrcAddr: host2IPv4Addr.AddressWithPrefix.Address, - DstAddr: host1IPv4Addr.AddressWithPrefix.Address, - }) - ip.SetChecksum(^ip.CalculateChecksum()) - e.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - })) - }, - checkResp: func(t *testing.T, e *channel.Endpoint) { - p, ok := e.Read() - if !ok { - t.Fatalf("timed out waiting for packet") - } - if p.Proto != header.IPv4ProtocolNumber { - t.Errorf("got p.Proto = %d, want = %d", p.Proto, header.IPv4ProtocolNumber) - } - if p.Route.RemoteLinkAddress != host2NICLinkAddr { - t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr) - } - checker.IPv4(t, stack.PayloadSince(p.Pkt.NetworkHeader()), - checker.SrcAddr(host1IPv4Addr.AddressWithPrefix.Address), - checker.DstAddr(host2IPv4Addr.AddressWithPrefix.Address), - checker.ICMPv4( - checker.ICMPv4Type(header.ICMPv4DstUnreachable), - checker.ICMPv4Code(header.ICMPv4PortUnreachable))) - }, - }, - - { - name: "Ping", - rxPkt: func(e *channel.Endpoint) { - totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize - hdr := buffer.NewPrependable(totalLen) - pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) - pkt.SetType(header.ICMPv4Echo) - pkt.SetCode(0) - pkt.SetChecksum(0) - pkt.SetChecksum(^header.Checksum(pkt, 0)) - ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) - ip.Encode(&header.IPv4Fields{ - TotalLength: uint16(totalLen), - Protocol: uint8(icmp.ProtocolNumber4), - TTL: ipv4.DefaultTTL, - SrcAddr: host2IPv4Addr.AddressWithPrefix.Address, - DstAddr: host1IPv4Addr.AddressWithPrefix.Address, - }) - ip.SetChecksum(^ip.CalculateChecksum()) - e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - })) - }, - checkResp: func(t *testing.T, e *channel.Endpoint) { - p, ok := e.Read() - if !ok { - t.Fatalf("timed out waiting for packet") - } - if p.Proto != header.IPv4ProtocolNumber { - t.Errorf("got p.Proto = %d, want = %d", p.Proto, header.IPv4ProtocolNumber) - } - if p.Route.RemoteLinkAddress != host2NICLinkAddr { - t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr) - } - checker.IPv4(t, stack.PayloadSince(p.Pkt.NetworkHeader()), - checker.SrcAddr(host1IPv4Addr.AddressWithPrefix.Address), - checker.DstAddr(host2IPv4Addr.AddressWithPrefix.Address), - checker.ICMPv4( - checker.ICMPv4Type(header.ICMPv4EchoReply), - checker.ICMPv4Code(header.ICMPv4UnusedCode))) - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - e := channel.New(1, defaultMTU, host1NICLinkAddr) - e.LinkEPCapabilities |= stack.CapabilityResolutionRequired - clock := faketime.NewManualClock() - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, - Clock: clock, - }) - - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) - } - if err := s.AddProtocolAddress(nicID, host1IPv4Addr); err != nil { - t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID, host1IPv4Addr, err) - } - - s.SetRouteTable([]tcpip.Route{ - { - Destination: host1IPv4Addr.AddressWithPrefix.Subnet(), - NIC: nicID, - }, - }) - - // Receive a packet to trigger link resolution before a response is sent. - test.rxPkt(e) - - // Wait for a ARP request since link address resolution should be - // performed. - { - clock.RunImmediatelyScheduledJobs() - p, ok := e.Read() - if !ok { - t.Fatalf("timed out waiting for packet") - } - if p.Proto != arp.ProtocolNumber { - t.Errorf("got p.Proto = %d, want = %d", p.Proto, arp.ProtocolNumber) - } - if p.Route.RemoteLinkAddress != header.EthernetBroadcastAddress { - t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, header.EthernetBroadcastAddress) - } - rep := header.ARP(p.Pkt.NetworkHeader().View()) - if got := rep.Op(); got != header.ARPRequest { - t.Errorf("got Op() = %d, want = %d", got, header.ARPRequest) - } - if got := tcpip.LinkAddress(rep.HardwareAddressSender()); got != host1NICLinkAddr { - t.Errorf("got HardwareAddressSender = %s, want = %s", got, host1NICLinkAddr) - } - if got := tcpip.Address(rep.ProtocolAddressSender()); got != host1IPv4Addr.AddressWithPrefix.Address { - t.Errorf("got ProtocolAddressSender = %s, want = %s", got, host1IPv4Addr.AddressWithPrefix.Address) - } - if got := tcpip.Address(rep.ProtocolAddressTarget()); got != host2IPv4Addr.AddressWithPrefix.Address { - t.Errorf("got ProtocolAddressTarget = %s, want = %s", got, host2IPv4Addr.AddressWithPrefix.Address) - } - } - - // Send an ARP reply to complete link address resolution. - { - hdr := buffer.View(make([]byte, header.ARPSize)) - packet := header.ARP(hdr) - packet.SetIPv4OverEthernet() - packet.SetOp(header.ARPReply) - copy(packet.HardwareAddressSender(), host2NICLinkAddr) - copy(packet.ProtocolAddressSender(), host2IPv4Addr.AddressWithPrefix.Address) - copy(packet.HardwareAddressTarget(), host1NICLinkAddr) - copy(packet.ProtocolAddressTarget(), host1IPv4Addr.AddressWithPrefix.Address) - e.InjectInbound(arp.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.ToVectorisedView(), - })) - } - - // Expect the response now that the link address has resolved. - clock.RunImmediatelyScheduledJobs() - test.checkResp(t, e) - - // Since link resolution was already performed, it shouldn't be performed - // again. - test.rxPkt(e) - test.checkResp(t, e) - }) - } -} - -// TestCloseLocking test that lock ordering is followed when closing an -// endpoint. -func TestCloseLocking(t *testing.T) { - const ( - nicID1 = 1 - nicID2 = 2 - - iterations = 1000 - ) - - var ( - src = testutil.MustParse4("16.0.0.1") - dst = testutil.MustParse4("16.0.0.2") - ) - - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, - }) - - // Perform NAT so that the endpoint tries to search for a sibling endpoint - // which ends up taking the protocol and endpoint lock (in that order). - table := stack.Table{ - Rules: []stack.Rule{ - {Target: &stack.AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, - {Target: &stack.AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, - {Target: &stack.RedirectTarget{Port: 5, NetworkProtocol: header.IPv4ProtocolNumber}}, - {Target: &stack.AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, - {Target: &stack.ErrorTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, - }, - BuiltinChains: [stack.NumHooks]int{ - stack.Prerouting: 0, - stack.Input: 1, - stack.Forward: stack.HookUnset, - stack.Output: 2, - stack.Postrouting: 3, - }, - Underflows: [stack.NumHooks]int{ - stack.Prerouting: 0, - stack.Input: 1, - stack.Forward: stack.HookUnset, - stack.Output: 2, - stack.Postrouting: 3, - }, - } - if err := s.IPTables().ReplaceTable(stack.NATID, table, false /* ipv6 */); err != nil { - t.Fatalf("s.IPTables().ReplaceTable(...): %s", err) - } - - e := channel.New(0, defaultMTU, "") - if err := s.CreateNIC(nicID1, e); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) - } - - if err := s.AddAddress(nicID1, ipv4.ProtocolNumber, src); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) failed: %s", nicID1, ipv4.ProtocolNumber, src, err) - } - - s.SetRouteTable([]tcpip.Route{{ - Destination: header.IPv4EmptySubnet, - NIC: nicID1, - }}) - - var wq waiter.Queue - ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq) - if err != nil { - t.Fatal(err) - } - defer ep.Close() - - addr := tcpip.FullAddress{NIC: nicID1, Addr: dst, Port: 53} - if err := ep.Connect(addr); err != nil { - t.Errorf("ep.Connect(%#v): %s", addr, err) - } - - var wg sync.WaitGroup - defer wg.Wait() - - // Writing packets should trigger NAT which requires the stack to search the - // protocol for network endpoints with the destination address. - // - // Creating and removing interfaces should modify the protocol and endpoint - // which requires taking the locks of each. - // - // We expect the protocol > endpoint lock ordering to be followed here. - wg.Add(2) - go func() { - defer wg.Done() - - data := []byte{1, 2, 3, 4} - - for i := 0; i < iterations; i++ { - var r bytes.Reader - r.Reset(data) - if n, err := ep.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Errorf("ep.Write(_, _): %s", err) - return - } else if want := int64(len(data)); n != want { - t.Errorf("got ep.Write(_, _) = (%d, _), want = (%d, _)", n, want) - return - } - } - }() - go func() { - defer wg.Done() - - for i := 0; i < iterations; i++ { - if err := s.CreateNIC(nicID2, stack.LinkEndpoint(channel.New(0, defaultMTU, ""))); err != nil { - t.Errorf("CreateNIC(%d, _): %s", nicID2, err) - return - } - if err := s.RemoveNIC(nicID2); err != nil { - t.Errorf("RemoveNIC(%d): %s", nicID2, err) - return - } - } - }() -} diff --git a/pkg/tcpip/network/ipv4/stats_test.go b/pkg/tcpip/network/ipv4/stats_test.go deleted file mode 100644 index d1f9e3cf5..000000000 --- a/pkg/tcpip/network/ipv4/stats_test.go +++ /dev/null @@ -1,99 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ipv4 - -import ( - "reflect" - "testing" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/testutil" -) - -var _ stack.NetworkInterface = (*testInterface)(nil) - -type testInterface struct { - stack.NetworkInterface - nicID tcpip.NICID -} - -func (t *testInterface) ID() tcpip.NICID { - return t.nicID -} - -func knownNICIDs(proto *protocol) []tcpip.NICID { - var nicIDs []tcpip.NICID - - for k := range proto.mu.eps { - nicIDs = append(nicIDs, k) - } - - return nicIDs -} - -func TestClearEndpointFromProtocolOnClose(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - }) - proto := s.NetworkProtocolInstance(ProtocolNumber).(*protocol) - nic := testInterface{nicID: 1} - ep := proto.NewEndpoint(&nic, nil).(*endpoint) - var nicIDs []tcpip.NICID - - proto.mu.Lock() - foundEP, hasEndpointBeforeClose := proto.mu.eps[nic.ID()] - nicIDs = knownNICIDs(proto) - proto.mu.Unlock() - - if !hasEndpointBeforeClose { - t.Fatalf("expected to find the nic id %d in the protocol's endpoint map (%v)", nic.ID(), nicIDs) - } - if foundEP != ep { - t.Fatalf("found an incorrect endpoint mapped to nic id %d", nic.ID()) - } - - ep.Close() - - proto.mu.Lock() - _, hasEP := proto.mu.eps[nic.ID()] - nicIDs = knownNICIDs(proto) - proto.mu.Unlock() - if hasEP { - t.Fatalf("unexpectedly found an endpoint mapped to the nic id %d in the protocol's known nic ids (%v)", nic.ID(), nicIDs) - } -} - -func TestMultiCounterStatsInitialization(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - }) - proto := s.NetworkProtocolInstance(ProtocolNumber).(*protocol) - var nic testInterface - ep := proto.NewEndpoint(&nic, nil).(*endpoint) - // At this point, the Stack's stats and the NetworkEndpoint's stats are - // expected to be bound by a MultiCounterStat. - refStack := s.Stats() - refEP := ep.stats.localStats - if err := testutil.ValidateMultiCounterStats(reflect.ValueOf(&ep.stats.ip).Elem(), []reflect.Value{reflect.ValueOf(&refEP.IP).Elem(), reflect.ValueOf(&refStack.IP).Elem()}); err != nil { - t.Error(err) - } - if err := testutil.ValidateMultiCounterStats(reflect.ValueOf(&ep.stats.icmp).Elem(), []reflect.Value{reflect.ValueOf(&refEP.ICMP).Elem(), reflect.ValueOf(&refStack.ICMP.V4).Elem()}); err != nil { - t.Error(err) - } - if err := testutil.ValidateMultiCounterStats(reflect.ValueOf(&ep.stats.igmp).Elem(), []reflect.Value{reflect.ValueOf(&refEP.IGMP).Elem(), reflect.ValueOf(&refStack.IGMP).Elem()}); err != nil { - t.Error(err) - } -} diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD deleted file mode 100644 index f99cbf8f3..000000000 --- a/pkg/tcpip/network/ipv6/BUILD +++ /dev/null @@ -1,72 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "ipv6", - srcs = [ - "dhcpv6configurationfromndpra_string.go", - "icmp.go", - "ipv6.go", - "mld.go", - "ndp.go", - "stats.go", - ], - visibility = ["//visibility:public"], - deps = [ - "//pkg/sync", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/header/parse", - "//pkg/tcpip/network/hash", - "//pkg/tcpip/network/internal/fragmentation", - "//pkg/tcpip/network/internal/ip", - "//pkg/tcpip/stack", - ], -) - -go_test( - name = "ipv6_test", - size = "small", - srcs = [ - "icmp_test.go", - "ipv6_test.go", - "ndp_test.go", - ], - library = ":ipv6", - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/checker", - "//pkg/tcpip/faketime", - "//pkg/tcpip/header", - "//pkg/tcpip/link/channel", - "//pkg/tcpip/link/sniffer", - "//pkg/tcpip/network/internal/testutil", - "//pkg/tcpip/stack", - "//pkg/tcpip/testutil", - "//pkg/tcpip/transport/icmp", - "//pkg/tcpip/transport/tcp", - "//pkg/tcpip/transport/udp", - "//pkg/waiter", - "@com_github_google_go_cmp//cmp:go_default_library", - ], -) - -go_test( - name = "ipv6_x_test", - size = "small", - srcs = ["mld_test.go"], - deps = [ - ":ipv6", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/checker", - "//pkg/tcpip/faketime", - "//pkg/tcpip/header", - "//pkg/tcpip/link/channel", - "//pkg/tcpip/stack", - "//pkg/tcpip/testutil", - ], -) diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go deleted file mode 100644 index 7c2a3e56b..000000000 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ /dev/null @@ -1,1730 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ipv6 - -import ( - "bytes" - "net" - "reflect" - "strings" - "testing" - - "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/checker" - "gvisor.dev/gvisor/pkg/tcpip/faketime" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" - "gvisor.dev/gvisor/pkg/tcpip/transport/udp" - "gvisor.dev/gvisor/pkg/waiter" -) - -const ( - nicID = 1 - - linkAddr0 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") - linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e") - linkAddr2 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0f") - - defaultChannelSize = 1 - defaultMTU = 65536 - - arbitraryHopLimit = 42 -) - -var ( - lladdr0 = header.LinkLocalAddr(linkAddr0) - lladdr1 = header.LinkLocalAddr(linkAddr1) -) - -type stubLinkEndpoint struct { - stack.LinkEndpoint -} - -func (*stubLinkEndpoint) MTU() uint32 { - return defaultMTU -} - -func (*stubLinkEndpoint) Capabilities() stack.LinkEndpointCapabilities { - // Indicate that resolution for link layer addresses is required to send - // packets over this link. This is needed so the NIC knows to allocate a - // neighbor table. - return stack.CapabilityResolutionRequired -} - -func (*stubLinkEndpoint) MaxHeaderLength() uint16 { - return 0 -} - -func (*stubLinkEndpoint) LinkAddress() tcpip.LinkAddress { - return "" -} - -func (*stubLinkEndpoint) WritePacket(stack.RouteInfo, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) tcpip.Error { - return nil -} - -func (*stubLinkEndpoint) Attach(stack.NetworkDispatcher) {} - -type stubDispatcher struct { - stack.TransportDispatcher -} - -func (*stubDispatcher) DeliverTransportPacket(tcpip.TransportProtocolNumber, *stack.PacketBuffer) stack.TransportPacketDisposition { - return stack.TransportPacketHandled -} - -func (*stubDispatcher) DeliverRawPacket(tcpip.TransportProtocolNumber, *stack.PacketBuffer) { - // No-op. -} - -var _ stack.NetworkInterface = (*testInterface)(nil) - -type testInterface struct { - stack.LinkEndpoint - - probeCount int - confirmationCount int - - nicID tcpip.NICID -} - -func (*testInterface) ID() tcpip.NICID { - return nicID -} - -func (*testInterface) IsLoopback() bool { - return false -} - -func (*testInterface) Name() string { - return "" -} - -func (*testInterface) Enabled() bool { - return true -} - -func (*testInterface) Promiscuous() bool { - return false -} - -func (*testInterface) Spoofing() bool { - return false -} - -func (t *testInterface) WritePacket(r *stack.Route, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { - return t.LinkEndpoint.WritePacket(r.Fields(), protocol, pkt) -} - -func (t *testInterface) WritePackets(r *stack.Route, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { - return t.LinkEndpoint.WritePackets(r.Fields(), pkts, protocol) -} - -func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { - var r stack.RouteInfo - r.NetProto = protocol - r.RemoteLinkAddress = remoteLinkAddr - return t.LinkEndpoint.WritePacket(r, protocol, pkt) -} - -func (t *testInterface) HandleNeighborProbe(tcpip.NetworkProtocolNumber, tcpip.Address, tcpip.LinkAddress) tcpip.Error { - t.probeCount++ - return nil -} - -func (t *testInterface) HandleNeighborConfirmation(tcpip.NetworkProtocolNumber, tcpip.Address, tcpip.LinkAddress, stack.ReachabilityConfirmationFlags) tcpip.Error { - t.confirmationCount++ - return nil -} - -func (*testInterface) PrimaryAddress(tcpip.NetworkProtocolNumber) (tcpip.AddressWithPrefix, tcpip.Error) { - return tcpip.AddressWithPrefix{}, nil -} - -func (*testInterface) CheckLocalAddress(tcpip.NetworkProtocolNumber, tcpip.Address) bool { - return false -} - -func handleICMPInIPv6(ep stack.NetworkEndpoint, src, dst tcpip.Address, icmp header.ICMPv6, hopLimit uint8, includeRouterAlert bool) { - var extensionHeaders header.IPv6ExtHdrSerializer - if includeRouterAlert { - extensionHeaders = header.IPv6ExtHdrSerializer{ - header.IPv6SerializableHopByHopExtHdr{ - &header.IPv6RouterAlertOption{Value: header.IPv6RouterAlertMLD}, - }, - } - } - ip := buffer.NewView(header.IPv6MinimumSize + extensionHeaders.Length()) - header.IPv6(ip).Encode(&header.IPv6Fields{ - PayloadLength: uint16(len(icmp)), - TransportProtocol: header.ICMPv6ProtocolNumber, - HopLimit: hopLimit, - SrcAddr: src, - DstAddr: dst, - ExtensionHeaders: extensionHeaders, - }) - - vv := ip.ToVectorisedView() - vv.AppendView(buffer.View(icmp)) - ep.HandlePacket(stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: vv, - })) -} - -func TestICMPCounts(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6}, - }) - if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil { - t.Fatalf("CreateNIC(_, _) = %s", err) - } - { - subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) - if err != nil { - t.Fatal(err) - } - s.SetRouteTable( - []tcpip.Route{{ - Destination: subnet, - NIC: nicID, - }}, - ) - } - - netProto := s.NetworkProtocolInstance(ProtocolNumber) - if netProto == nil { - t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber) - } - ep := netProto.NewEndpoint(&testInterface{}, &stubDispatcher{}) - defer ep.Close() - - if err := ep.Enable(); err != nil { - t.Fatalf("ep.Enable(): %s", err) - } - - addressableEndpoint, ok := ep.(stack.AddressableEndpoint) - if !ok { - t.Fatalf("expected network endpoint to implement stack.AddressableEndpoint") - } - addr := lladdr0.WithPrefix() - if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil { - t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err) - } else { - ep.DecRef() - } - - var tllData [header.NDPLinkLayerAddressSize]byte - header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{ - header.NDPTargetLinkLayerAddressOption(linkAddr1), - }) - - types := []struct { - typ header.ICMPv6Type - hopLimit uint8 - includeRouterAlert bool - size int - extraData []byte - }{ - { - typ: header.ICMPv6DstUnreachable, - hopLimit: arbitraryHopLimit, - size: header.ICMPv6DstUnreachableMinimumSize, - }, - { - typ: header.ICMPv6PacketTooBig, - hopLimit: arbitraryHopLimit, - size: header.ICMPv6PacketTooBigMinimumSize, - }, - { - typ: header.ICMPv6TimeExceeded, - hopLimit: arbitraryHopLimit, - size: header.ICMPv6MinimumSize, - }, - { - typ: header.ICMPv6ParamProblem, - hopLimit: arbitraryHopLimit, - size: header.ICMPv6MinimumSize, - }, - { - typ: header.ICMPv6EchoRequest, - hopLimit: arbitraryHopLimit, - size: header.ICMPv6EchoMinimumSize, - }, - { - typ: header.ICMPv6EchoReply, - hopLimit: arbitraryHopLimit, - size: header.ICMPv6EchoMinimumSize, - }, - { - typ: header.ICMPv6RouterSolicit, - hopLimit: header.NDPHopLimit, - size: header.ICMPv6MinimumSize, - }, - { - typ: header.ICMPv6RouterAdvert, - hopLimit: header.NDPHopLimit, - size: header.ICMPv6HeaderSize + header.NDPRAMinimumSize, - }, - { - typ: header.ICMPv6NeighborSolicit, - hopLimit: header.NDPHopLimit, - size: header.ICMPv6NeighborSolicitMinimumSize, - }, - { - typ: header.ICMPv6NeighborAdvert, - hopLimit: header.NDPHopLimit, - size: header.ICMPv6NeighborAdvertMinimumSize, - extraData: tllData[:], - }, - { - typ: header.ICMPv6RedirectMsg, - hopLimit: header.NDPHopLimit, - size: header.ICMPv6MinimumSize, - }, - { - typ: header.ICMPv6MulticastListenerQuery, - hopLimit: header.MLDHopLimit, - includeRouterAlert: true, - size: header.MLDMinimumSize + header.ICMPv6HeaderSize, - }, - { - typ: header.ICMPv6MulticastListenerReport, - hopLimit: header.MLDHopLimit, - includeRouterAlert: true, - size: header.MLDMinimumSize + header.ICMPv6HeaderSize, - }, - { - typ: header.ICMPv6MulticastListenerDone, - hopLimit: header.MLDHopLimit, - includeRouterAlert: true, - size: header.MLDMinimumSize + header.ICMPv6HeaderSize, - }, - { - typ: 255, /* Unrecognized */ - size: 50, - }, - } - - for _, typ := range types { - icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData))) - copy(icmp[typ.size:], typ.extraData) - icmp.SetType(typ.typ) - icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ - Header: icmp[:typ.size], - Src: lladdr0, - Dst: lladdr1, - PayloadCsum: header.Checksum(typ.extraData, 0 /* initial */), - PayloadLen: len(typ.extraData), - })) - handleICMPInIPv6(ep, lladdr1, lladdr0, icmp, typ.hopLimit, typ.includeRouterAlert) - } - - // Construct an empty ICMP packet so that - // Stats().ICMP.ICMPv6ReceivedPacketStats.Invalid is incremented. - handleICMPInIPv6(ep, lladdr1, lladdr0, header.ICMPv6(buffer.NewView(header.IPv6MinimumSize)), arbitraryHopLimit, false) - - icmpv6Stats := s.Stats().ICMP.V6.PacketsReceived - visitStats(reflect.ValueOf(&icmpv6Stats).Elem(), func(name string, s *tcpip.StatCounter) { - if got, want := s.Value(), uint64(1); got != want { - t.Errorf("got %s = %d, want = %d", name, got, want) - } - }) - if t.Failed() { - t.Logf("stats:\n%+v", s.Stats()) - } -} - -func visitStats(v reflect.Value, f func(string, *tcpip.StatCounter)) { - t := v.Type() - for i := 0; i < v.NumField(); i++ { - v := v.Field(i) - if s, ok := v.Interface().(*tcpip.StatCounter); ok { - f(t.Field(i).Name, s) - } else { - visitStats(v, f) - } - } -} - -type testContext struct { - s0 *stack.Stack - s1 *stack.Stack - - linkEP0 *channel.Endpoint - linkEP1 *channel.Endpoint - - clock *faketime.ManualClock -} - -type endpointWithResolutionCapability struct { - stack.LinkEndpoint -} - -func (e endpointWithResolutionCapability) Capabilities() stack.LinkEndpointCapabilities { - return e.LinkEndpoint.Capabilities() | stack.CapabilityResolutionRequired -} - -func newTestContext(t *testing.T) *testContext { - clock := faketime.NewManualClock() - c := &testContext{ - s0: stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6}, - Clock: clock, - }), - s1: stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6}, - Clock: clock, - }), - clock: clock, - } - - c.linkEP0 = channel.New(defaultChannelSize, defaultMTU, linkAddr0) - - wrappedEP0 := stack.LinkEndpoint(endpointWithResolutionCapability{LinkEndpoint: c.linkEP0}) - if testing.Verbose() { - wrappedEP0 = sniffer.New(wrappedEP0) - } - if err := c.s0.CreateNIC(nicID, wrappedEP0); err != nil { - t.Fatalf("CreateNIC s0: %v", err) - } - if err := c.s0.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress lladdr0: %v", err) - } - - c.linkEP1 = channel.New(defaultChannelSize, defaultMTU, linkAddr1) - wrappedEP1 := stack.LinkEndpoint(endpointWithResolutionCapability{LinkEndpoint: c.linkEP1}) - if err := c.s1.CreateNIC(nicID, wrappedEP1); err != nil { - t.Fatalf("CreateNIC failed: %v", err) - } - if err := c.s1.AddAddress(nicID, ProtocolNumber, lladdr1); err != nil { - t.Fatalf("AddAddress lladdr1: %v", err) - } - - subnet0, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) - if err != nil { - t.Fatal(err) - } - c.s0.SetRouteTable( - []tcpip.Route{{ - Destination: subnet0, - NIC: nicID, - }}, - ) - subnet1, err := tcpip.NewSubnet(lladdr0, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr0)))) - if err != nil { - t.Fatal(err) - } - c.s1.SetRouteTable( - []tcpip.Route{{ - Destination: subnet1, - NIC: nicID, - }}, - ) - - t.Cleanup(func() { - if err := c.s0.RemoveNIC(nicID); err != nil { - t.Errorf("c.s0.RemoveNIC(%d): %s", nicID, err) - } - if err := c.s1.RemoveNIC(nicID); err != nil { - t.Errorf("c.s1.RemoveNIC(%d): %s", nicID, err) - } - - c.linkEP0.Close() - c.linkEP1.Close() - }) - - return c -} - -type routeArgs struct { - src, dst *channel.Endpoint - typ header.ICMPv6Type - remoteLinkAddr tcpip.LinkAddress -} - -func routeICMPv6Packet(t *testing.T, clock *faketime.ManualClock, args routeArgs, fn func(*testing.T, header.ICMPv6)) { - t.Helper() - - clock.RunImmediatelyScheduledJobs() - pi, ok := args.src.Read() - if !ok { - t.Fatal("packet didn't arrive") - } - - { - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buffer.NewVectorisedView(pi.Pkt.Size(), pi.Pkt.Views()), - }) - args.dst.InjectLinkAddr(pi.Proto, args.dst.LinkAddress(), pkt) - } - - if pi.Proto != ProtocolNumber { - t.Errorf("unexpected protocol number %d", pi.Proto) - return - } - - if len(args.remoteLinkAddr) != 0 && pi.Route.RemoteLinkAddress != args.remoteLinkAddr { - t.Errorf("got remote link address = %s, want = %s", pi.Route.RemoteLinkAddress, args.remoteLinkAddr) - } - - // Pull the full payload since network header. Needed for header.IPv6 to - // extract its payload. - ipv6 := header.IPv6(stack.PayloadSince(pi.Pkt.NetworkHeader())) - transProto := tcpip.TransportProtocolNumber(ipv6.NextHeader()) - if transProto != header.ICMPv6ProtocolNumber { - t.Errorf("unexpected transport protocol number %d", transProto) - return - } - icmpv6 := header.ICMPv6(ipv6.Payload()) - if got, want := icmpv6.Type(), args.typ; got != want { - t.Errorf("got ICMPv6 type = %d, want = %d", got, want) - return - } - if fn != nil { - fn(t, icmpv6) - } -} - -func TestLinkResolution(t *testing.T) { - c := newTestContext(t) - - r, err := c.s0.FindRoute(nicID, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */) - if err != nil { - t.Fatalf("FindRoute(%d, %s, %s, _, false) = (_, %s), want = (_, nil)", nicID, lladdr0, lladdr1, err) - } - defer r.Release() - - hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.IPv6MinimumSize + header.ICMPv6EchoMinimumSize) - pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize)) - pkt.SetType(header.ICMPv6EchoRequest) - pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ - Header: pkt, - Src: r.LocalAddress(), - Dst: r.RemoteAddress(), - })) - - // We can't send our payload directly over the route because that - // doesn't provoke NDP discovery. - var wq waiter.Queue - ep, err := c.s0.NewEndpoint(header.ICMPv6ProtocolNumber, ProtocolNumber, &wq) - if err != nil { - t.Fatalf("NewEndpoint(_) = (_, %s), want = (_, nil)", err) - } - - { - var r bytes.Reader - r.Reset(hdr.View()) - if _, err := ep.Write(&r, tcpip.WriteOptions{To: &tcpip.FullAddress{NIC: nicID, Addr: lladdr1}}); err != nil { - t.Fatalf("ep.Write(_): %s", err) - } - } - for _, args := range []routeArgs{ - {src: c.linkEP0, dst: c.linkEP1, typ: header.ICMPv6NeighborSolicit, remoteLinkAddr: header.EthernetAddressFromMulticastIPv6Address(header.SolicitedNodeAddr(lladdr1))}, - {src: c.linkEP1, dst: c.linkEP0, typ: header.ICMPv6NeighborAdvert}, - } { - routeICMPv6Packet(t, c.clock, args, func(t *testing.T, icmpv6 header.ICMPv6) { - if got, want := tcpip.Address(icmpv6[8:][:16]), lladdr1; got != want { - t.Errorf("%d: got target = %s, want = %s", icmpv6.Type(), got, want) - } - }) - } - - for _, args := range []routeArgs{ - {src: c.linkEP0, dst: c.linkEP1, typ: header.ICMPv6EchoRequest}, - {src: c.linkEP1, dst: c.linkEP0, typ: header.ICMPv6EchoReply}, - } { - routeICMPv6Packet(t, c.clock, args, nil) - } -} - -func TestICMPChecksumValidationSimple(t *testing.T) { - var tllData [header.NDPLinkLayerAddressSize]byte - header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{ - header.NDPTargetLinkLayerAddressOption(linkAddr1), - }) - - types := []struct { - name string - typ header.ICMPv6Type - size int - extraData []byte - statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter - routerOnly bool - }{ - { - name: "DstUnreachable", - typ: header.ICMPv6DstUnreachable, - size: header.ICMPv6DstUnreachableMinimumSize, - statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.DstUnreachable - }, - }, - { - name: "PacketTooBig", - typ: header.ICMPv6PacketTooBig, - size: header.ICMPv6PacketTooBigMinimumSize, - statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.PacketTooBig - }, - }, - { - name: "TimeExceeded", - typ: header.ICMPv6TimeExceeded, - size: header.ICMPv6MinimumSize, - statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.TimeExceeded - }, - }, - { - name: "ParamProblem", - typ: header.ICMPv6ParamProblem, - size: header.ICMPv6MinimumSize, - statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.ParamProblem - }, - }, - { - name: "EchoRequest", - typ: header.ICMPv6EchoRequest, - size: header.ICMPv6EchoMinimumSize, - statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.EchoRequest - }, - }, - { - name: "EchoReply", - typ: header.ICMPv6EchoReply, - size: header.ICMPv6EchoMinimumSize, - statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.EchoReply - }, - }, - { - name: "RouterSolicit", - typ: header.ICMPv6RouterSolicit, - size: header.ICMPv6MinimumSize, - statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.RouterSolicit - }, - // Hosts MUST silently discard any received Router Solicitation messages. - routerOnly: true, - }, - { - name: "RouterAdvert", - typ: header.ICMPv6RouterAdvert, - size: header.ICMPv6HeaderSize + header.NDPRAMinimumSize, - statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.RouterAdvert - }, - }, - { - name: "NeighborSolicit", - typ: header.ICMPv6NeighborSolicit, - size: header.ICMPv6NeighborSolicitMinimumSize, - statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.NeighborSolicit - }, - }, - { - name: "NeighborAdvert", - typ: header.ICMPv6NeighborAdvert, - size: header.ICMPv6NeighborAdvertMinimumSize, - extraData: tllData[:], - statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.NeighborAdvert - }, - }, - { - name: "RedirectMsg", - typ: header.ICMPv6RedirectMsg, - size: header.ICMPv6MinimumSize, - statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.RedirectMsg - }, - }, - } - - for _, typ := range types { - for _, isRouter := range []bool{false, true} { - name := typ.name - if isRouter { - name += " (Router)" - } - t.Run(name, func(t *testing.T) { - e := channel.New(0, 1280, linkAddr0) - - // Indicate that resolution for link layer addresses is required to - // send packets over this link. This is needed so the NIC knows to - // allocate a neighbor table. - e.LinkEPCapabilities |= stack.CapabilityResolutionRequired - - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - }) - if isRouter { - if err := s.SetForwardingDefaultAndAllNICs(ProtocolNumber, true); err != nil { - t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", ProtocolNumber, err) - } - } - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(_, _) = %s", err) - } - - if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err) - } - { - subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) - if err != nil { - t.Fatal(err) - } - s.SetRouteTable( - []tcpip.Route{{ - Destination: subnet, - NIC: nicID, - }}, - ) - } - - handleIPv6Payload := func(checksum bool) { - icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData))) - copy(icmp[typ.size:], typ.extraData) - icmp.SetType(typ.typ) - if checksum { - icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ - Header: icmp, - Src: lladdr1, - Dst: lladdr0, - })) - } - ip := header.IPv6(buffer.NewView(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(len(icmp)), - TransportProtocol: header.ICMPv6ProtocolNumber, - HopLimit: header.NDPHopLimit, - SrcAddr: lladdr1, - DstAddr: lladdr0, - }) - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buffer.NewVectorisedView(len(ip)+len(icmp), []buffer.View{buffer.View(ip), buffer.View(icmp)}), - }) - e.InjectInbound(ProtocolNumber, pkt) - } - - stats := s.Stats().ICMP.V6.PacketsReceived - invalid := stats.Invalid - routerOnly := stats.RouterOnlyPacketsDroppedByHost - typStat := typ.statCounter(stats) - - // Initial stat counts should be 0. - if got := invalid.Value(); got != 0 { - t.Fatalf("got invalid = %d, want = 0", got) - } - if got := routerOnly.Value(); got != 0 { - t.Fatalf("got RouterOnlyPacketsReceivedByHost = %d, want = 0", got) - } - if got := typStat.Value(); got != 0 { - t.Fatalf("got %s = %d, want = 0", typ.name, got) - } - - // Without setting checksum, the incoming packet should - // be invalid. - handleIPv6Payload(false) - if got := invalid.Value(); got != 1 { - t.Fatalf("got invalid = %d, want = 1", got) - } - // Router only count should not have increased. - if got := routerOnly.Value(); got != 0 { - t.Fatalf("got RouterOnlyPacketsReceivedByHost = %d, want = 0", got) - } - // Rx count of type typ.typ should not have increased. - if got := typStat.Value(); got != 0 { - t.Fatalf("got %s = %d, want = 0", typ.name, got) - } - - // When checksum is set, it should be received. - handleIPv6Payload(true) - if got := typStat.Value(); got != 1 { - t.Fatalf("got %s = %d, want = 1", typ.name, got) - } - // Invalid count should not have increased again. - if got := invalid.Value(); got != 1 { - t.Fatalf("got invalid = %d, want = 1", got) - } - if !isRouter && typ.routerOnly { - // Router only count should have increased. - if got := routerOnly.Value(); got != 1 { - t.Fatalf("got RouterOnlyPacketsReceivedByHost = %d, want = 1", got) - } - } - }) - } - } -} - -func TestICMPChecksumValidationWithPayload(t *testing.T) { - const simpleBodySize = 64 - simpleBody := func(view buffer.View) { - for i := 0; i < simpleBodySize; i++ { - view[i] = uint8(i) - } - } - - const errorICMPBodySize = header.IPv6MinimumSize + simpleBodySize - errorICMPBody := func(view buffer.View) { - ip := header.IPv6(view) - ip.Encode(&header.IPv6Fields{ - PayloadLength: simpleBodySize, - TransportProtocol: 10, - HopLimit: 20, - SrcAddr: lladdr0, - DstAddr: lladdr1, - }) - simpleBody(view[header.IPv6MinimumSize:]) - } - - types := []struct { - name string - typ header.ICMPv6Type - size int - statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter - payloadSize int - payload func(buffer.View) - }{ - { - "DstUnreachable", - header.ICMPv6DstUnreachable, - header.ICMPv6DstUnreachableMinimumSize, - func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.DstUnreachable - }, - errorICMPBodySize, - errorICMPBody, - }, - { - "PacketTooBig", - header.ICMPv6PacketTooBig, - header.ICMPv6PacketTooBigMinimumSize, - func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.PacketTooBig - }, - errorICMPBodySize, - errorICMPBody, - }, - { - "TimeExceeded", - header.ICMPv6TimeExceeded, - header.ICMPv6MinimumSize, - func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.TimeExceeded - }, - errorICMPBodySize, - errorICMPBody, - }, - { - "ParamProblem", - header.ICMPv6ParamProblem, - header.ICMPv6MinimumSize, - func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.ParamProblem - }, - errorICMPBodySize, - errorICMPBody, - }, - { - "EchoRequest", - header.ICMPv6EchoRequest, - header.ICMPv6EchoMinimumSize, - func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.EchoRequest - }, - simpleBodySize, - simpleBody, - }, - { - "EchoReply", - header.ICMPv6EchoReply, - header.ICMPv6EchoMinimumSize, - func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.EchoReply - }, - simpleBodySize, - simpleBody, - }, - } - - for _, typ := range types { - t.Run(typ.name, func(t *testing.T) { - e := channel.New(10, 1280, linkAddr0) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - }) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(_, _) = %s", err) - } - - if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err) - } - { - subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) - if err != nil { - t.Fatal(err) - } - s.SetRouteTable( - []tcpip.Route{{ - Destination: subnet, - NIC: nicID, - }}, - ) - } - - handleIPv6Payload := func(typ header.ICMPv6Type, size, payloadSize int, payloadFn func(buffer.View), checksum bool) { - icmpSize := size + payloadSize - hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize) - icmpHdr := header.ICMPv6(hdr.Prepend(icmpSize)) - icmpHdr.SetType(typ) - payloadFn(icmpHdr.Payload()) - - if checksum { - icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ - Header: icmpHdr, - Src: lladdr1, - Dst: lladdr0, - })) - } - - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(icmpSize), - TransportProtocol: header.ICMPv6ProtocolNumber, - HopLimit: header.NDPHopLimit, - SrcAddr: lladdr1, - DstAddr: lladdr0, - }) - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - }) - e.InjectInbound(ProtocolNumber, pkt) - } - - stats := s.Stats().ICMP.V6.PacketsReceived - invalid := stats.Invalid - typStat := typ.statCounter(stats) - - // Initial stat counts should be 0. - if got := invalid.Value(); got != 0 { - t.Fatalf("got invalid = %d, want = 0", got) - } - if got := typStat.Value(); got != 0 { - t.Fatalf("got = %d, want = 0", got) - } - - // Without setting checksum, the incoming packet should - // be invalid. - handleIPv6Payload(typ.typ, typ.size, typ.payloadSize, typ.payload, false) - if got := invalid.Value(); got != 1 { - t.Fatalf("got invalid = %d, want = 1", got) - } - // Rx count of type typ.typ should not have increased. - if got := typStat.Value(); got != 0 { - t.Fatalf("got = %d, want = 0", got) - } - - // When checksum is set, it should be received. - handleIPv6Payload(typ.typ, typ.size, typ.payloadSize, typ.payload, true) - if got := typStat.Value(); got != 1 { - t.Fatalf("got = %d, want = 0", got) - } - // Invalid count should not have increased again. - if got := invalid.Value(); got != 1 { - t.Fatalf("got invalid = %d, want = 1", got) - } - }) - } -} - -func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) { - const simpleBodySize = 64 - simpleBody := func(view buffer.View) { - for i := 0; i < simpleBodySize; i++ { - view[i] = uint8(i) - } - } - - const errorICMPBodySize = header.IPv6MinimumSize + simpleBodySize - errorICMPBody := func(view buffer.View) { - ip := header.IPv6(view) - ip.Encode(&header.IPv6Fields{ - PayloadLength: simpleBodySize, - TransportProtocol: 10, - HopLimit: 20, - SrcAddr: lladdr0, - DstAddr: lladdr1, - }) - simpleBody(view[header.IPv6MinimumSize:]) - } - - types := []struct { - name string - typ header.ICMPv6Type - size int - statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter - payloadSize int - payload func(buffer.View) - }{ - { - "DstUnreachable", - header.ICMPv6DstUnreachable, - header.ICMPv6DstUnreachableMinimumSize, - func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.DstUnreachable - }, - errorICMPBodySize, - errorICMPBody, - }, - { - "PacketTooBig", - header.ICMPv6PacketTooBig, - header.ICMPv6PacketTooBigMinimumSize, - func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.PacketTooBig - }, - errorICMPBodySize, - errorICMPBody, - }, - { - "TimeExceeded", - header.ICMPv6TimeExceeded, - header.ICMPv6MinimumSize, - func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.TimeExceeded - }, - errorICMPBodySize, - errorICMPBody, - }, - { - "ParamProblem", - header.ICMPv6ParamProblem, - header.ICMPv6MinimumSize, - func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.ParamProblem - }, - errorICMPBodySize, - errorICMPBody, - }, - { - "EchoRequest", - header.ICMPv6EchoRequest, - header.ICMPv6EchoMinimumSize, - func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.EchoRequest - }, - simpleBodySize, - simpleBody, - }, - { - "EchoReply", - header.ICMPv6EchoReply, - header.ICMPv6EchoMinimumSize, - func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.EchoReply - }, - simpleBodySize, - simpleBody, - }, - } - - for _, typ := range types { - t.Run(typ.name, func(t *testing.T) { - e := channel.New(10, 1280, linkAddr0) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - }) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - - if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err) - } - { - subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) - if err != nil { - t.Fatal(err) - } - s.SetRouteTable( - []tcpip.Route{{ - Destination: subnet, - NIC: nicID, - }}, - ) - } - - handleIPv6Payload := func(typ header.ICMPv6Type, size, payloadSize int, payloadFn func(buffer.View), checksum bool) { - hdr := buffer.NewPrependable(header.IPv6MinimumSize + size) - icmpHdr := header.ICMPv6(hdr.Prepend(size)) - icmpHdr.SetType(typ) - - payload := buffer.NewView(payloadSize) - payloadFn(payload) - - if checksum { - icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ - Header: icmpHdr, - Src: lladdr1, - Dst: lladdr0, - PayloadCsum: header.Checksum(payload, 0 /* initial */), - PayloadLen: len(payload), - })) - } - - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(size + payloadSize), - TransportProtocol: header.ICMPv6ProtocolNumber, - HopLimit: header.NDPHopLimit, - SrcAddr: lladdr1, - DstAddr: lladdr0, - }) - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buffer.NewVectorisedView(header.IPv6MinimumSize+size+payloadSize, []buffer.View{hdr.View(), payload}), - }) - e.InjectInbound(ProtocolNumber, pkt) - } - - stats := s.Stats().ICMP.V6.PacketsReceived - invalid := stats.Invalid - typStat := typ.statCounter(stats) - - // Initial stat counts should be 0. - if got := invalid.Value(); got != 0 { - t.Fatalf("got invalid = %d, want = 0", got) - } - if got := typStat.Value(); got != 0 { - t.Fatalf("got = %d, want = 0", got) - } - - // Without setting checksum, the incoming packet should - // be invalid. - handleIPv6Payload(typ.typ, typ.size, typ.payloadSize, typ.payload, false) - if got := invalid.Value(); got != 1 { - t.Fatalf("got invalid = %d, want = 1", got) - } - // Rx count of type typ.typ should not have increased. - if got := typStat.Value(); got != 0 { - t.Fatalf("got = %d, want = 0", got) - } - - // When checksum is set, it should be received. - handleIPv6Payload(typ.typ, typ.size, typ.payloadSize, typ.payload, true) - if got := typStat.Value(); got != 1 { - t.Fatalf("got = %d, want = 0", got) - } - // Invalid count should not have increased again. - if got := invalid.Value(); got != 1 { - t.Fatalf("got invalid = %d, want = 1", got) - } - }) - } -} - -func TestLinkAddressRequest(t *testing.T) { - const nicID = 1 - - snaddr := header.SolicitedNodeAddr(lladdr0) - mcaddr := header.EthernetAddressFromMulticastIPv6Address(snaddr) - - tests := []struct { - name string - nicAddr tcpip.Address - localAddr tcpip.Address - remoteLinkAddr tcpip.LinkAddress - - expectedErr tcpip.Error - expectedRemoteAddr tcpip.Address - expectedRemoteLinkAddr tcpip.LinkAddress - }{ - { - name: "Unicast", - nicAddr: lladdr1, - localAddr: lladdr1, - remoteLinkAddr: linkAddr1, - expectedRemoteAddr: lladdr0, - expectedRemoteLinkAddr: linkAddr1, - }, - { - name: "Multicast", - nicAddr: lladdr1, - localAddr: lladdr1, - remoteLinkAddr: "", - expectedRemoteAddr: snaddr, - expectedRemoteLinkAddr: mcaddr, - }, - { - name: "Unicast with unspecified source", - nicAddr: lladdr1, - remoteLinkAddr: linkAddr1, - expectedRemoteAddr: lladdr0, - expectedRemoteLinkAddr: linkAddr1, - }, - { - name: "Multicast with unspecified source", - nicAddr: lladdr1, - remoteLinkAddr: "", - expectedRemoteAddr: snaddr, - expectedRemoteLinkAddr: mcaddr, - }, - { - name: "Unicast with unassigned address", - localAddr: lladdr1, - remoteLinkAddr: linkAddr1, - expectedErr: &tcpip.ErrBadLocalAddress{}, - }, - { - name: "Multicast with unassigned address", - localAddr: lladdr1, - remoteLinkAddr: "", - expectedErr: &tcpip.ErrBadLocalAddress{}, - }, - { - name: "Unicast with no local address available", - remoteLinkAddr: linkAddr1, - expectedErr: &tcpip.ErrNetworkUnreachable{}, - }, - { - name: "Multicast with no local address available", - remoteLinkAddr: "", - expectedErr: &tcpip.ErrNetworkUnreachable{}, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - }) - - linkEP := channel.New(defaultChannelSize, defaultMTU, linkAddr0) - if err := s.CreateNIC(nicID, linkEP); err != nil { - t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) - } - - ep, err := s.GetNetworkEndpoint(nicID, ProtocolNumber) - if err != nil { - t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, ProtocolNumber, err) - } - linkRes, ok := ep.(stack.LinkAddressResolver) - if !ok { - t.Fatalf("expected %T to implement stack.LinkAddressResolver", ep) - } - - if len(test.nicAddr) != 0 { - if err := s.AddAddress(nicID, ProtocolNumber, test.nicAddr); err != nil { - t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ProtocolNumber, test.nicAddr, err) - } - } - - { - err := linkRes.LinkAddressRequest(lladdr0, test.localAddr, test.remoteLinkAddr) - if diff := cmp.Diff(test.expectedErr, err); diff != "" { - t.Fatalf("unexpected error from p.LinkAddressRequest(%s, %s, %s, _), (-want, +got):\n%s", lladdr0, test.localAddr, test.remoteLinkAddr, diff) - } - } - - if test.expectedErr != nil { - return - } - - pkt, ok := linkEP.Read() - if !ok { - t.Fatal("expected to send a link address request") - } - - var want stack.RouteInfo - want.NetProto = ProtocolNumber - want.RemoteLinkAddress = test.expectedRemoteLinkAddr - if diff := cmp.Diff(want, pkt.Route, cmp.AllowUnexported(want)); diff != "" { - t.Errorf("route info mismatch (-want +got):\n%s", diff) - } - checker.IPv6(t, stack.PayloadSince(pkt.Pkt.NetworkHeader()), - checker.SrcAddr(lladdr1), - checker.DstAddr(test.expectedRemoteAddr), - checker.TTL(header.NDPHopLimit), - checker.NDPNS( - checker.NDPNSTargetAddress(lladdr0), - checker.NDPNSOptions([]header.NDPOption{header.NDPSourceLinkLayerAddressOption(linkAddr0)}), - )) - }) - } -} - -func TestPacketQueing(t *testing.T) { - const nicID = 1 - - var ( - host1NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x06") - host2NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x09") - - host1IPv6Addr = tcpip.ProtocolAddress{ - Protocol: ProtocolNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: tcpip.Address(net.ParseIP("a::1").To16()), - PrefixLen: 64, - }, - } - host2IPv6Addr = tcpip.ProtocolAddress{ - Protocol: ProtocolNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: tcpip.Address(net.ParseIP("a::2").To16()), - PrefixLen: 64, - }, - } - ) - - tests := []struct { - name string - rxPkt func(*channel.Endpoint) - checkResp func(*testing.T, *channel.Endpoint) - }{ - { - name: "ICMP Error", - rxPkt: func(e *channel.Endpoint) { - hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.UDPMinimumSize) - u := header.UDP(hdr.Prepend(header.UDPMinimumSize)) - u.Encode(&header.UDPFields{ - SrcPort: 5555, - DstPort: 80, - Length: header.UDPMinimumSize, - }) - sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, host2IPv6Addr.AddressWithPrefix.Address, host1IPv6Addr.AddressWithPrefix.Address, header.UDPMinimumSize) - sum = header.Checksum(nil, sum) - u.SetChecksum(^u.CalculateChecksum(sum)) - payloadLength := hdr.UsedLength() - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - TransportProtocol: udp.ProtocolNumber, - HopLimit: DefaultTTL, - SrcAddr: host2IPv6Addr.AddressWithPrefix.Address, - DstAddr: host1IPv6Addr.AddressWithPrefix.Address, - }) - e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - })) - }, - checkResp: func(t *testing.T, e *channel.Endpoint) { - p, ok := e.Read() - if !ok { - t.Fatalf("timed out waiting for packet") - } - if p.Proto != ProtocolNumber { - t.Errorf("got p.Proto = %d, want = %d", p.Proto, ProtocolNumber) - } - if p.Route.RemoteLinkAddress != host2NICLinkAddr { - t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr) - } - checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), - checker.SrcAddr(host1IPv6Addr.AddressWithPrefix.Address), - checker.DstAddr(host2IPv6Addr.AddressWithPrefix.Address), - checker.ICMPv6( - checker.ICMPv6Type(header.ICMPv6DstUnreachable), - checker.ICMPv6Code(header.ICMPv6PortUnreachable))) - }, - }, - - { - name: "Ping", - rxPkt: func(e *channel.Endpoint) { - totalLen := header.IPv6MinimumSize + header.ICMPv6MinimumSize - hdr := buffer.NewPrependable(totalLen) - pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize)) - pkt.SetType(header.ICMPv6EchoRequest) - pkt.SetCode(0) - pkt.SetChecksum(0) - pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ - Header: pkt, - Src: host2IPv6Addr.AddressWithPrefix.Address, - Dst: host1IPv6Addr.AddressWithPrefix.Address, - })) - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: header.ICMPv6MinimumSize, - TransportProtocol: icmp.ProtocolNumber6, - HopLimit: DefaultTTL, - SrcAddr: host2IPv6Addr.AddressWithPrefix.Address, - DstAddr: host1IPv6Addr.AddressWithPrefix.Address, - }) - e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - })) - }, - checkResp: func(t *testing.T, e *channel.Endpoint) { - p, ok := e.Read() - if !ok { - t.Fatalf("timed out waiting for packet") - } - if p.Proto != ProtocolNumber { - t.Errorf("got p.Proto = %d, want = %d", p.Proto, ProtocolNumber) - } - if p.Route.RemoteLinkAddress != host2NICLinkAddr { - t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr) - } - checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), - checker.SrcAddr(host1IPv6Addr.AddressWithPrefix.Address), - checker.DstAddr(host2IPv6Addr.AddressWithPrefix.Address), - checker.ICMPv6( - checker.ICMPv6Type(header.ICMPv6EchoReply), - checker.ICMPv6Code(header.ICMPv6UnusedCode))) - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - - e := channel.New(1, header.IPv6MinimumMTU, host1NICLinkAddr) - e.LinkEPCapabilities |= stack.CapabilityResolutionRequired - clock := faketime.NewManualClock() - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, - Clock: clock, - }) - - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) - } - if err := s.AddProtocolAddress(nicID, host1IPv6Addr); err != nil { - t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID, host1IPv6Addr, err) - } - - s.SetRouteTable([]tcpip.Route{ - { - Destination: host1IPv6Addr.AddressWithPrefix.Subnet(), - NIC: nicID, - }, - }) - - // Receive a packet to trigger link resolution before a response is sent. - test.rxPkt(e) - - // Wait for a neighbor solicitation since link address resolution should - // be performed. - { - clock.RunImmediatelyScheduledJobs() - p, ok := e.Read() - if !ok { - t.Fatalf("timed out waiting for packet") - } - if p.Proto != ProtocolNumber { - t.Errorf("got Proto = %d, want = %d", p.Proto, ProtocolNumber) - } - snmc := header.SolicitedNodeAddr(host2IPv6Addr.AddressWithPrefix.Address) - if want := header.EthernetAddressFromMulticastIPv6Address(snmc); p.Route.RemoteLinkAddress != want { - t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, want) - } - checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), - checker.SrcAddr(host1IPv6Addr.AddressWithPrefix.Address), - checker.DstAddr(snmc), - checker.TTL(header.NDPHopLimit), - checker.NDPNS( - checker.NDPNSTargetAddress(host2IPv6Addr.AddressWithPrefix.Address), - checker.NDPNSOptions([]header.NDPOption{header.NDPSourceLinkLayerAddressOption(host1NICLinkAddr)}), - )) - } - - // Send a neighbor advertisement to complete link address resolution. - { - naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize - hdr := buffer.NewPrependable(header.IPv6MinimumSize + naSize) - pkt := header.ICMPv6(hdr.Prepend(naSize)) - pkt.SetType(header.ICMPv6NeighborAdvert) - na := header.NDPNeighborAdvert(pkt.MessageBody()) - na.SetSolicitedFlag(true) - na.SetOverrideFlag(true) - na.SetTargetAddress(host2IPv6Addr.AddressWithPrefix.Address) - na.Options().Serialize(header.NDPOptionsSerializer{ - header.NDPTargetLinkLayerAddressOption(host2NICLinkAddr), - }) - pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ - Header: pkt, - Src: host2IPv6Addr.AddressWithPrefix.Address, - Dst: host1IPv6Addr.AddressWithPrefix.Address, - })) - payloadLength := hdr.UsedLength() - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - TransportProtocol: icmp.ProtocolNumber6, - HopLimit: header.NDPHopLimit, - SrcAddr: host2IPv6Addr.AddressWithPrefix.Address, - DstAddr: host1IPv6Addr.AddressWithPrefix.Address, - }) - e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - })) - } - - // Expect the response now that the link address has resolved. - clock.RunImmediatelyScheduledJobs() - test.checkResp(t, e) - - // Since link resolution was already performed, it shouldn't be performed - // again. - test.rxPkt(e) - test.checkResp(t, e) - }) - } -} - -func TestCallsToNeighborCache(t *testing.T) { - tests := []struct { - name string - createPacket func() header.ICMPv6 - multicast bool - source tcpip.Address - destination tcpip.Address - wantProbeCount int - wantConfirmationCount int - }{ - { - name: "Unicast Neighbor Solicitation without source link-layer address option", - createPacket: func() header.ICMPv6 { - nsSize := header.ICMPv6NeighborSolicitMinimumSize + header.NDPLinkLayerAddressSize - icmp := header.ICMPv6(buffer.NewView(nsSize)) - icmp.SetType(header.ICMPv6NeighborSolicit) - ns := header.NDPNeighborSolicit(icmp.MessageBody()) - ns.SetTargetAddress(lladdr0) - return icmp - }, - source: lladdr1, - destination: lladdr0, - // "The source link-layer address option SHOULD be included in unicast - // solicitations." - RFC 4861 section 4.3 - // - // A Neighbor Advertisement needs to be sent in response, but the - // Neighbor Cache shouldn't be updated since we have no useful - // information about the sender. - wantProbeCount: 0, - }, - { - name: "Unicast Neighbor Solicitation with source link-layer address option", - createPacket: func() header.ICMPv6 { - nsSize := header.ICMPv6NeighborSolicitMinimumSize + header.NDPLinkLayerAddressSize - icmp := header.ICMPv6(buffer.NewView(nsSize)) - icmp.SetType(header.ICMPv6NeighborSolicit) - ns := header.NDPNeighborSolicit(icmp.MessageBody()) - ns.SetTargetAddress(lladdr0) - ns.Options().Serialize(header.NDPOptionsSerializer{ - header.NDPSourceLinkLayerAddressOption(linkAddr1), - }) - return icmp - }, - source: lladdr1, - destination: lladdr0, - wantProbeCount: 1, - }, - { - name: "Multicast Neighbor Solicitation without source link-layer address option", - createPacket: func() header.ICMPv6 { - nsSize := header.ICMPv6NeighborSolicitMinimumSize + header.NDPLinkLayerAddressSize - icmp := header.ICMPv6(buffer.NewView(nsSize)) - icmp.SetType(header.ICMPv6NeighborSolicit) - ns := header.NDPNeighborSolicit(icmp.MessageBody()) - ns.SetTargetAddress(lladdr0) - return icmp - }, - source: lladdr1, - destination: header.SolicitedNodeAddr(lladdr0), - // "The source link-layer address option MUST be included in multicast - // solicitations." - RFC 4861 section 4.3 - wantProbeCount: 0, - }, - { - name: "Multicast Neighbor Solicitation with source link-layer address option", - createPacket: func() header.ICMPv6 { - nsSize := header.ICMPv6NeighborSolicitMinimumSize + header.NDPLinkLayerAddressSize - icmp := header.ICMPv6(buffer.NewView(nsSize)) - icmp.SetType(header.ICMPv6NeighborSolicit) - ns := header.NDPNeighborSolicit(icmp.MessageBody()) - ns.SetTargetAddress(lladdr0) - ns.Options().Serialize(header.NDPOptionsSerializer{ - header.NDPSourceLinkLayerAddressOption(linkAddr1), - }) - return icmp - }, - source: lladdr1, - destination: header.SolicitedNodeAddr(lladdr0), - wantProbeCount: 1, - }, - { - name: "Unicast Neighbor Advertisement without target link-layer address option", - createPacket: func() header.ICMPv6 { - naSize := header.ICMPv6NeighborAdvertMinimumSize - icmp := header.ICMPv6(buffer.NewView(naSize)) - icmp.SetType(header.ICMPv6NeighborAdvert) - na := header.NDPNeighborAdvert(icmp.MessageBody()) - na.SetSolicitedFlag(true) - na.SetOverrideFlag(false) - na.SetTargetAddress(lladdr1) - return icmp - }, - source: lladdr1, - destination: lladdr0, - // "When responding to unicast solicitations, the target link-layer - // address option can be omitted since the sender of the solicitation has - // the correct link-layer address; otherwise, it would not be able to - // send the unicast solicitation in the first place." - // - RFC 4861 section 4.4 - wantConfirmationCount: 1, - }, - { - name: "Unicast Neighbor Advertisement with target link-layer address option", - createPacket: func() header.ICMPv6 { - naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize - icmp := header.ICMPv6(buffer.NewView(naSize)) - icmp.SetType(header.ICMPv6NeighborAdvert) - na := header.NDPNeighborAdvert(icmp.MessageBody()) - na.SetSolicitedFlag(true) - na.SetOverrideFlag(false) - na.SetTargetAddress(lladdr1) - na.Options().Serialize(header.NDPOptionsSerializer{ - header.NDPTargetLinkLayerAddressOption(linkAddr1), - }) - return icmp - }, - source: lladdr1, - destination: lladdr0, - wantConfirmationCount: 1, - }, - { - name: "Multicast Neighbor Advertisement without target link-layer address option", - createPacket: func() header.ICMPv6 { - naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize - icmp := header.ICMPv6(buffer.NewView(naSize)) - icmp.SetType(header.ICMPv6NeighborAdvert) - na := header.NDPNeighborAdvert(icmp.MessageBody()) - na.SetSolicitedFlag(false) - na.SetOverrideFlag(false) - na.SetTargetAddress(lladdr1) - return icmp - }, - source: lladdr1, - destination: header.IPv6AllNodesMulticastAddress, - // "Target link-layer address MUST be included for multicast solicitations - // in order to avoid infinite Neighbor Solicitation "recursion" when the - // peer node does not have a cache entry to return a Neighbor - // Advertisements message." - RFC 4861 section 4.4 - wantConfirmationCount: 0, - }, - { - name: "Multicast Neighbor Advertisement with target link-layer address option", - createPacket: func() header.ICMPv6 { - naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize - icmp := header.ICMPv6(buffer.NewView(naSize)) - icmp.SetType(header.ICMPv6NeighborAdvert) - na := header.NDPNeighborAdvert(icmp.MessageBody()) - na.SetSolicitedFlag(false) - na.SetOverrideFlag(false) - na.SetTargetAddress(lladdr1) - na.Options().Serialize(header.NDPOptionsSerializer{ - header.NDPTargetLinkLayerAddressOption(linkAddr1), - }) - return icmp - }, - source: lladdr1, - destination: header.IPv6AllNodesMulticastAddress, - wantConfirmationCount: 1, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6}, - }) - { - if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil { - t.Fatalf("CreateNIC(_, _) = %s", err) - } - if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err) - } - } - { - subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) - if err != nil { - t.Fatal(err) - } - s.SetRouteTable( - []tcpip.Route{{ - Destination: subnet, - NIC: nicID, - }}, - ) - } - - netProto := s.NetworkProtocolInstance(ProtocolNumber) - if netProto == nil { - t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber) - } - - testInterface := testInterface{LinkEndpoint: channel.New(0, header.IPv6MinimumMTU, linkAddr0)} - ep := netProto.NewEndpoint(&testInterface, &stubDispatcher{}) - defer ep.Close() - - if err := ep.Enable(); err != nil { - t.Fatalf("ep.Enable(): %s", err) - } - - addressableEndpoint, ok := ep.(stack.AddressableEndpoint) - if !ok { - t.Fatalf("expected network endpoint to implement stack.AddressableEndpoint") - } - addr := lladdr0.WithPrefix() - if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil { - t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err) - } else { - ep.DecRef() - } - - icmp := test.createPacket() - icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ - Header: icmp, - Src: test.source, - Dst: test.destination, - })) - handleICMPInIPv6(ep, test.source, test.destination, icmp, header.NDPHopLimit, false) - - // Confirm the endpoint calls the correct NUDHandler method. - if testInterface.probeCount != test.wantProbeCount { - t.Errorf("got testInterface.probeCount = %d, want = %d", testInterface.probeCount, test.wantProbeCount) - } - if testInterface.confirmationCount != test.wantConfirmationCount { - t.Errorf("got testInterface.confirmationCount = %d, want = %d", testInterface.confirmationCount, test.wantConfirmationCount) - } - }) - } -} diff --git a/pkg/tcpip/network/ipv6/ipv6_state_autogen.go b/pkg/tcpip/network/ipv6/ipv6_state_autogen.go new file mode 100644 index 000000000..13d427822 --- /dev/null +++ b/pkg/tcpip/network/ipv6/ipv6_state_autogen.go @@ -0,0 +1,136 @@ +// automatically generated by stateify. + +package ipv6 + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (i *icmpv6DestinationUnreachableSockError) StateTypeName() string { + return "pkg/tcpip/network/ipv6.icmpv6DestinationUnreachableSockError" +} + +func (i *icmpv6DestinationUnreachableSockError) StateFields() []string { + return []string{} +} + +func (i *icmpv6DestinationUnreachableSockError) beforeSave() {} + +// +checklocksignore +func (i *icmpv6DestinationUnreachableSockError) StateSave(stateSinkObject state.Sink) { + i.beforeSave() +} + +func (i *icmpv6DestinationUnreachableSockError) afterLoad() {} + +// +checklocksignore +func (i *icmpv6DestinationUnreachableSockError) StateLoad(stateSourceObject state.Source) { +} + +func (i *icmpv6DestinationNetworkUnreachableSockError) StateTypeName() string { + return "pkg/tcpip/network/ipv6.icmpv6DestinationNetworkUnreachableSockError" +} + +func (i *icmpv6DestinationNetworkUnreachableSockError) StateFields() []string { + return []string{ + "icmpv6DestinationUnreachableSockError", + } +} + +func (i *icmpv6DestinationNetworkUnreachableSockError) beforeSave() {} + +// +checklocksignore +func (i *icmpv6DestinationNetworkUnreachableSockError) StateSave(stateSinkObject state.Sink) { + i.beforeSave() + stateSinkObject.Save(0, &i.icmpv6DestinationUnreachableSockError) +} + +func (i *icmpv6DestinationNetworkUnreachableSockError) afterLoad() {} + +// +checklocksignore +func (i *icmpv6DestinationNetworkUnreachableSockError) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &i.icmpv6DestinationUnreachableSockError) +} + +func (i *icmpv6DestinationPortUnreachableSockError) StateTypeName() string { + return "pkg/tcpip/network/ipv6.icmpv6DestinationPortUnreachableSockError" +} + +func (i *icmpv6DestinationPortUnreachableSockError) StateFields() []string { + return []string{ + "icmpv6DestinationUnreachableSockError", + } +} + +func (i *icmpv6DestinationPortUnreachableSockError) beforeSave() {} + +// +checklocksignore +func (i *icmpv6DestinationPortUnreachableSockError) StateSave(stateSinkObject state.Sink) { + i.beforeSave() + stateSinkObject.Save(0, &i.icmpv6DestinationUnreachableSockError) +} + +func (i *icmpv6DestinationPortUnreachableSockError) afterLoad() {} + +// +checklocksignore +func (i *icmpv6DestinationPortUnreachableSockError) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &i.icmpv6DestinationUnreachableSockError) +} + +func (i *icmpv6DestinationAddressUnreachableSockError) StateTypeName() string { + return "pkg/tcpip/network/ipv6.icmpv6DestinationAddressUnreachableSockError" +} + +func (i *icmpv6DestinationAddressUnreachableSockError) StateFields() []string { + return []string{ + "icmpv6DestinationUnreachableSockError", + } +} + +func (i *icmpv6DestinationAddressUnreachableSockError) beforeSave() {} + +// +checklocksignore +func (i *icmpv6DestinationAddressUnreachableSockError) StateSave(stateSinkObject state.Sink) { + i.beforeSave() + stateSinkObject.Save(0, &i.icmpv6DestinationUnreachableSockError) +} + +func (i *icmpv6DestinationAddressUnreachableSockError) afterLoad() {} + +// +checklocksignore +func (i *icmpv6DestinationAddressUnreachableSockError) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &i.icmpv6DestinationUnreachableSockError) +} + +func (e *icmpv6PacketTooBigSockError) StateTypeName() string { + return "pkg/tcpip/network/ipv6.icmpv6PacketTooBigSockError" +} + +func (e *icmpv6PacketTooBigSockError) StateFields() []string { + return []string{ + "mtu", + } +} + +func (e *icmpv6PacketTooBigSockError) beforeSave() {} + +// +checklocksignore +func (e *icmpv6PacketTooBigSockError) StateSave(stateSinkObject state.Sink) { + e.beforeSave() + stateSinkObject.Save(0, &e.mtu) +} + +func (e *icmpv6PacketTooBigSockError) afterLoad() {} + +// +checklocksignore +func (e *icmpv6PacketTooBigSockError) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &e.mtu) +} + +func init() { + state.Register((*icmpv6DestinationUnreachableSockError)(nil)) + state.Register((*icmpv6DestinationNetworkUnreachableSockError)(nil)) + state.Register((*icmpv6DestinationPortUnreachableSockError)(nil)) + state.Register((*icmpv6DestinationAddressUnreachableSockError)(nil)) + state.Register((*icmpv6PacketTooBigSockError)(nil)) +} diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go deleted file mode 100644 index d2a23fd4f..000000000 --- a/pkg/tcpip/network/ipv6/ipv6_test.go +++ /dev/null @@ -1,3491 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ipv6 - -import ( - "bytes" - "encoding/hex" - "fmt" - "io/ioutil" - "math" - "net" - "reflect" - "testing" - - "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/checker" - "gvisor.dev/gvisor/pkg/tcpip/faketime" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" - iptestutil "gvisor.dev/gvisor/pkg/tcpip/network/internal/testutil" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/testutil" - "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" - "gvisor.dev/gvisor/pkg/tcpip/transport/udp" - "gvisor.dev/gvisor/pkg/waiter" -) - -const ( - addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - addr2 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" - // The least significant 3 bytes are the same as addr2 so both addr2 and - // addr3 will have the same solicited-node address. - addr3 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x02" - addr4 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x03" - - // Tests use the extension header identifier values as uint8 instead of - // header.IPv6ExtensionHeaderIdentifier. - hopByHopExtHdrID = uint8(header.IPv6HopByHopOptionsExtHdrIdentifier) - routingExtHdrID = uint8(header.IPv6RoutingExtHdrIdentifier) - fragmentExtHdrID = uint8(header.IPv6FragmentExtHdrIdentifier) - destinationExtHdrID = uint8(header.IPv6DestinationOptionsExtHdrIdentifier) - noNextHdrID = uint8(header.IPv6NoNextHeaderIdentifier) - unknownHdrID = uint8(header.IPv6UnknownExtHdrIdentifier) - - extraHeaderReserve = 50 -) - -// testReceiveICMP tests receiving an ICMP packet from src to dst. want is the -// expected Neighbor Advertisement received count after receiving the packet. -func testReceiveICMP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64) { - t.Helper() - - // Receive ICMP packet. - hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6NeighborAdvertMinimumSize) - pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertMinimumSize)) - pkt.SetType(header.ICMPv6NeighborAdvert) - pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ - Header: pkt, - Src: src, - Dst: dst, - })) - payloadLength := hdr.UsedLength() - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - TransportProtocol: header.ICMPv6ProtocolNumber, - HopLimit: 255, - SrcAddr: src, - DstAddr: dst, - }) - - e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - })) - - stats := s.Stats().ICMP.V6.PacketsReceived - - if got := stats.NeighborAdvert.Value(); got != want { - t.Fatalf("got NeighborAdvert = %d, want = %d", got, want) - } -} - -// testReceiveUDP tests receiving a UDP packet from src to dst. want is the -// expected UDP received count after receiving the packet. -func testReceiveUDP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64) { - t.Helper() - - wq := waiter.Queue{} - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.ReadableEvents) - defer wq.EventUnregister(&we) - defer close(ch) - - ep, err := s.NewEndpoint(udp.ProtocolNumber, ProtocolNumber, &wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - defer ep.Close() - - if err := ep.Bind(tcpip.FullAddress{Addr: dst, Port: 80}); err != nil { - t.Fatalf("ep.Bind(...) failed: %v", err) - } - - // Receive UDP Packet. - hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.UDPMinimumSize) - u := header.UDP(hdr.Prepend(header.UDPMinimumSize)) - u.Encode(&header.UDPFields{ - SrcPort: 5555, - DstPort: 80, - Length: header.UDPMinimumSize, - }) - - // UDP pseudo-header checksum. - sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, src, dst, header.UDPMinimumSize) - - // UDP checksum - sum = header.Checksum(nil, sum) - u.SetChecksum(^u.CalculateChecksum(sum)) - - payloadLength := hdr.UsedLength() - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - TransportProtocol: udp.ProtocolNumber, - HopLimit: 255, - SrcAddr: src, - DstAddr: dst, - }) - - e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - })) - - stat := s.Stats().UDP.PacketsReceived - - if got := stat.Value(); got != want { - t.Fatalf("got UDPPacketsReceived = %d, want = %d", got, want) - } -} - -func compareFragments(packets []*stack.PacketBuffer, sourcePacket *stack.PacketBuffer, mtu uint32, wantFragments []fragmentInfo, proto tcpip.TransportProtocolNumber) error { - // sourcePacket does not have its IP Header populated. Let's copy the one - // from the first fragment. - source := header.IPv6(packets[0].NetworkHeader().View()) - sourceIPHeadersLen := len(source) - vv := buffer.NewVectorisedView(sourcePacket.Size(), sourcePacket.Views()) - source = append(source, vv.ToView()...) - - var reassembledPayload buffer.VectorisedView - for i, fragment := range packets { - // Confirm that the packet is valid. - allBytes := buffer.NewVectorisedView(fragment.Size(), fragment.Views()) - fragmentIPHeaders := header.IPv6(allBytes.ToView()) - if !fragmentIPHeaders.IsValid(len(fragmentIPHeaders)) { - return fmt.Errorf("fragment #%d: IP packet is invalid:\n%s", i, hex.Dump(fragmentIPHeaders)) - } - - fragmentIPHeadersLength := fragment.NetworkHeader().View().Size() - if fragmentIPHeadersLength != sourceIPHeadersLen { - return fmt.Errorf("fragment #%d: got fragmentIPHeadersLength = %d, want = %d", i, fragmentIPHeadersLength, sourceIPHeadersLen) - } - - if got := len(fragmentIPHeaders); got > int(mtu) { - return fmt.Errorf("fragment #%d: got len(fragmentIPHeaders) = %d, want <= %d", i, got, mtu) - } - - sourceIPHeader := source[:header.IPv6MinimumSize] - fragmentIPHeader := fragmentIPHeaders[:header.IPv6MinimumSize] - - if got := fragmentIPHeaders.PayloadLength(); got != wantFragments[i].payloadSize { - return fmt.Errorf("fragment #%d: got fragmentIPHeaders.PayloadLength() = %d, want = %d", i, got, wantFragments[i].payloadSize) - } - - // We expect the IPv6 Header to be similar across each fragment, besides the - // payload length. - sourceIPHeader.SetPayloadLength(0) - fragmentIPHeader.SetPayloadLength(0) - if diff := cmp.Diff(fragmentIPHeader, sourceIPHeader); diff != "" { - return fmt.Errorf("fragment #%d: fragmentIPHeader mismatch (-want +got):\n%s", i, diff) - } - - if got := fragment.AvailableHeaderBytes(); got != extraHeaderReserve { - return fmt.Errorf("fragment #%d: got packet.AvailableHeaderBytes() = %d, want = %d", i, got, extraHeaderReserve) - } - if fragment.NetworkProtocolNumber != sourcePacket.NetworkProtocolNumber { - return fmt.Errorf("fragment #%d: got fragment.NetworkProtocolNumber = %d, want = %d", i, fragment.NetworkProtocolNumber, sourcePacket.NetworkProtocolNumber) - } - - if len(packets) > 1 { - // If the source packet was big enough that it needed fragmentation, let's - // inspect the fragment header. Because no other extension headers are - // supported, it will always be the last extension header. - fragmentHeader := header.IPv6Fragment(fragmentIPHeaders[fragmentIPHeadersLength-header.IPv6FragmentHeaderSize : fragmentIPHeadersLength]) - - if got := fragmentHeader.More(); got != wantFragments[i].more { - return fmt.Errorf("fragment #%d: got fragmentHeader.More() = %t, want = %t", i, got, wantFragments[i].more) - } - if got := fragmentHeader.FragmentOffset(); got != wantFragments[i].offset { - return fmt.Errorf("fragment #%d: got fragmentHeader.FragmentOffset() = %d, want = %d", i, got, wantFragments[i].offset) - } - if got := fragmentHeader.NextHeader(); got != uint8(proto) { - return fmt.Errorf("fragment #%d: got fragmentHeader.NextHeader() = %d, want = %d", i, got, uint8(proto)) - } - } - - // Store the reassembled payload as we parse each fragment. The payload - // includes the Transport header and everything after. - reassembledPayload.AppendView(fragment.TransportHeader().View()) - reassembledPayload.AppendView(fragment.Data().AsRange().ToOwnedView()) - } - - if diff := cmp.Diff(buffer.View(source[sourceIPHeadersLen:]), reassembledPayload.ToView()); diff != "" { - return fmt.Errorf("reassembledPayload mismatch (-want +got):\n%s", diff) - } - - return nil -} - -// TestReceiveOnAllNodesMulticastAddr tests that IPv6 endpoints receive ICMP and -// UDP packets destined to the IPv6 link-local all-nodes multicast address. -func TestReceiveOnAllNodesMulticastAddr(t *testing.T) { - tests := []struct { - name string - protocolFactory stack.TransportProtocolFactory - rxf func(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64) - }{ - {"ICMP", icmp.NewProtocol6, testReceiveICMP}, - {"UDP", udp.NewProtocol, testReceiveUDP}, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{test.protocolFactory}, - }) - e := channel.New(10, header.IPv6MinimumMTU, linkAddr1) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(_) = %s", err) - } - - // Should receive a packet destined to the all-nodes - // multicast address. - test.rxf(t, s, e, addr1, header.IPv6AllNodesMulticastAddress, 1) - }) - } -} - -// TestReceiveOnSolicitedNodeAddr tests that IPv6 endpoints receive ICMP and UDP -// packets destined to the IPv6 solicited-node address of an assigned IPv6 -// address. -func TestReceiveOnSolicitedNodeAddr(t *testing.T) { - tests := []struct { - name string - protocolFactory stack.TransportProtocolFactory - rxf func(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64) - }{ - {"ICMP", icmp.NewProtocol6, testReceiveICMP}, - {"UDP", udp.NewProtocol, testReceiveUDP}, - } - - snmc := header.SolicitedNodeAddr(addr2) - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{test.protocolFactory}, - }) - e := channel.New(1, header.IPv6MinimumMTU, linkAddr1) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - - s.SetRouteTable([]tcpip.Route{ - { - Destination: header.IPv6EmptySubnet, - NIC: nicID, - }, - }) - - // Should not receive a packet destined to the solicited node address of - // addr2/addr3 yet as we haven't added those addresses. - test.rxf(t, s, e, addr1, snmc, 0) - - if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err) - } - - // Should receive a packet destined to the solicited node address of - // addr2/addr3 now that we have added added addr2. - test.rxf(t, s, e, addr1, snmc, 1) - - if err := s.AddAddress(nicID, ProtocolNumber, addr3); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr3, err) - } - - // Should still receive a packet destined to the solicited node address of - // addr2/addr3 now that we have added addr3. - test.rxf(t, s, e, addr1, snmc, 2) - - if err := s.RemoveAddress(nicID, addr2); err != nil { - t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr2, err) - } - - // Should still receive a packet destined to the solicited node address of - // addr2/addr3 now that we have removed addr2. - test.rxf(t, s, e, addr1, snmc, 3) - - // Make sure addr3's endpoint does not get removed from the NIC by - // incrementing its reference count with a route. - r, err := s.FindRoute(nicID, addr3, addr4, ProtocolNumber, false) - if err != nil { - t.Fatalf("FindRoute(%d, %s, %s, %d, false): %s", nicID, addr3, addr4, ProtocolNumber, err) - } - defer r.Release() - - if err := s.RemoveAddress(nicID, addr3); err != nil { - t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr3, err) - } - - // Should not receive a packet destined to the solicited node address of - // addr2/addr3 yet as both of them got removed, even though a route using - // addr3 exists. - test.rxf(t, s, e, addr1, snmc, 3) - }) - } -} - -// TestAddIpv6Address tests adding IPv6 addresses. -func TestAddIpv6Address(t *testing.T) { - const nicID = 1 - - tests := []struct { - name string - addr tcpip.Address - }{ - // This test is in response to b/140943433. - { - "Nil", - tcpip.Address([]byte(nil)), - }, - { - "ValidUnicast", - addr1, - }, - { - "ValidLinkLocalUnicast", - lladdr0, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - }) - if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - - if err := s.AddAddress(nicID, ProtocolNumber, test.addr); err != nil { - t.Fatalf("AddAddress(%d, %d, nil) = %s", nicID, ProtocolNumber, err) - } - - if addr, err := s.GetMainNICAddress(nicID, ProtocolNumber); err != nil { - t.Fatalf("stack.GetMainNICAddress(%d, %d): %s", nicID, ProtocolNumber, err) - } else if addr.Address != test.addr { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, ProtocolNumber, addr.Address, test.addr) - } - }) - } -} - -func TestReceiveIPv6ExtHdrs(t *testing.T) { - tests := []struct { - name string - extHdr func(nextHdr uint8) ([]byte, uint8) - shouldAccept bool - countersToBeIncremented func(*tcpip.Stats) []*tcpip.StatCounter - // Should we expect an ICMP response and if so, with what contents? - expectICMP bool - ICMPType header.ICMPv6Type - ICMPCode header.ICMPv6Code - pointer uint32 - multicast bool - }{ - { - name: "None", - extHdr: func(nextHdr uint8) ([]byte, uint8) { return nil, nextHdr }, - shouldAccept: true, - expectICMP: false, - }, - { - name: "hopbyhop with router alert option", - extHdr: func(nextHdr uint8) ([]byte, uint8) { - return []byte{ - nextHdr, 0, - - // Router Alert option. - 5, 2, 0, 0, 0, 0, - }, hopByHopExtHdrID - }, - shouldAccept: true, - countersToBeIncremented: func(stats *tcpip.Stats) []*tcpip.StatCounter { - return []*tcpip.StatCounter{stats.IP.OptionRouterAlertReceived} - }, - }, - { - name: "hopbyhop with two router alert options", - extHdr: func(nextHdr uint8) ([]byte, uint8) { - return []byte{ - nextHdr, 1, - - // Router Alert option. - 5, 2, 0, 0, 0, 0, - - // Router Alert option. - 5, 2, 0, 0, 0, 0, 0, 0, - }, hopByHopExtHdrID - }, - shouldAccept: false, - countersToBeIncremented: func(stats *tcpip.Stats) []*tcpip.StatCounter { - return []*tcpip.StatCounter{ - stats.IP.OptionRouterAlertReceived, - stats.IP.MalformedPacketsReceived, - } - }, - }, - { - name: "hopbyhop with unknown option skippable action", - extHdr: func(nextHdr uint8) ([]byte, uint8) { - return []byte{ - nextHdr, 1, - - // Skippable unknown. - 63, 4, 1, 2, 3, 4, - - // Skippable unknown. - 62, 6, 1, 2, 3, 4, 5, 6, - }, hopByHopExtHdrID - }, - shouldAccept: true, - }, - { - name: "hopbyhop with unknown option discard action", - extHdr: func(nextHdr uint8) ([]byte, uint8) { - return []byte{ - nextHdr, 1, - - // Skippable unknown. - 63, 4, 1, 2, 3, 4, - - // Discard unknown. - 127, 6, 1, 2, 3, 4, 5, 6, - }, hopByHopExtHdrID - }, - shouldAccept: false, - expectICMP: false, - }, - { - name: "hopbyhop with unknown option discard and send icmp action (unicast)", - extHdr: func(nextHdr uint8) ([]byte, uint8) { - return []byte{ - nextHdr, 1, - - // Skippable unknown. - 63, 4, 1, 2, 3, 4, - - // Discard & send ICMP if option is unknown. - 191, 6, 1, 2, 3, 4, 5, 6, - //^ Unknown option. - }, hopByHopExtHdrID - }, - shouldAccept: false, - expectICMP: true, - ICMPType: header.ICMPv6ParamProblem, - ICMPCode: header.ICMPv6UnknownOption, - pointer: header.IPv6FixedHeaderSize + 8, - }, - { - name: "hopbyhop with unknown option discard and send icmp action (multicast)", - extHdr: func(nextHdr uint8) ([]byte, uint8) { - return []byte{ - nextHdr, 1, - - // Skippable unknown. - 63, 4, 1, 2, 3, 4, - - // Discard & send ICMP if option is unknown. - 191, 6, 1, 2, 3, 4, 5, 6, - //^ Unknown option. - }, hopByHopExtHdrID - }, - multicast: true, - shouldAccept: false, - expectICMP: true, - ICMPType: header.ICMPv6ParamProblem, - ICMPCode: header.ICMPv6UnknownOption, - pointer: header.IPv6FixedHeaderSize + 8, - }, - { - name: "hopbyhop with unknown option discard and send icmp action unless multicast dest (unicast)", - extHdr: func(nextHdr uint8) ([]byte, uint8) { - return []byte{ - nextHdr, 1, - - // Skippable unknown. - 63, 4, 1, 2, 3, 4, - - // Discard & send ICMP unless packet is for multicast destination if - // option is unknown. - 255, 6, 1, 2, 3, 4, 5, 6, - //^ Unknown option. - }, hopByHopExtHdrID - }, - expectICMP: true, - ICMPType: header.ICMPv6ParamProblem, - ICMPCode: header.ICMPv6UnknownOption, - pointer: header.IPv6FixedHeaderSize + 8, - }, - { - name: "hopbyhop with unknown option discard and send icmp action unless multicast dest (multicast)", - extHdr: func(nextHdr uint8) ([]byte, uint8) { - return []byte{ - nextHdr, 1, - - // Skippable unknown. - 63, 4, 1, 2, 3, 4, - - // Discard & send ICMP unless packet is for multicast destination if - // option is unknown. - 255, 6, 1, 2, 3, 4, 5, 6, - //^ Unknown option. - }, hopByHopExtHdrID - }, - multicast: true, - shouldAccept: false, - expectICMP: false, - }, - { - name: "routing with zero segments left", - extHdr: func(nextHdr uint8) ([]byte, uint8) { - return []byte{ - nextHdr, 0, - 1, 0, 2, 3, 4, 5, - }, routingExtHdrID - }, - shouldAccept: true, - }, - { - name: "routing with non-zero segments left", - extHdr: func(nextHdr uint8) ([]byte, uint8) { - return []byte{ - nextHdr, 0, - 1, 1, 2, 3, 4, 5, - }, routingExtHdrID - }, - shouldAccept: false, - expectICMP: true, - ICMPType: header.ICMPv6ParamProblem, - ICMPCode: header.ICMPv6ErroneousHeader, - pointer: header.IPv6FixedHeaderSize + 2, - }, - { - name: "atomic fragment with zero ID", - extHdr: func(nextHdr uint8) ([]byte, uint8) { - return []byte{ - nextHdr, 0, - 0, 0, 0, 0, 0, 0, - }, fragmentExtHdrID - }, - shouldAccept: true, - }, - { - name: "atomic fragment with non-zero ID", - extHdr: func(nextHdr uint8) ([]byte, uint8) { - return []byte{ - nextHdr, 0, - 0, 0, 1, 2, 3, 4, - }, fragmentExtHdrID - }, - shouldAccept: true, - expectICMP: false, - }, - { - name: "fragment", - extHdr: func(nextHdr uint8) ([]byte, uint8) { - return []byte{ - nextHdr, 0, - 1, 0, 1, 2, 3, 4, - }, fragmentExtHdrID - }, - shouldAccept: false, - expectICMP: false, - }, - { - name: "No next header", - extHdr: func(nextHdr uint8) ([]byte, uint8) { - return nil, noNextHdrID - }, - shouldAccept: false, - expectICMP: false, - }, - { - name: "unknown next header (first)", - extHdr: func(nextHdr uint8) ([]byte, uint8) { - return []byte{ - nextHdr, 0, 63, 4, 1, 2, 3, 4, - }, unknownHdrID - }, - shouldAccept: false, - expectICMP: true, - ICMPType: header.ICMPv6ParamProblem, - ICMPCode: header.ICMPv6UnknownHeader, - pointer: header.IPv6NextHeaderOffset, - }, - { - name: "unknown next header (not first)", - extHdr: func(nextHdr uint8) ([]byte, uint8) { - return []byte{ - unknownHdrID, 0, - 63, 4, 1, 2, 3, 4, - }, hopByHopExtHdrID - }, - shouldAccept: false, - expectICMP: true, - ICMPType: header.ICMPv6ParamProblem, - ICMPCode: header.ICMPv6UnknownHeader, - pointer: header.IPv6FixedHeaderSize, - }, - { - name: "destination with unknown option skippable action", - extHdr: func(nextHdr uint8) ([]byte, uint8) { - return []byte{ - nextHdr, 1, - - // Skippable unknown. - 63, 4, 1, 2, 3, 4, - - // Skippable unknown. - 62, 6, 1, 2, 3, 4, 5, 6, - }, destinationExtHdrID - }, - shouldAccept: true, - expectICMP: false, - }, - { - name: "destination with unknown option discard action", - extHdr: func(nextHdr uint8) ([]byte, uint8) { - return []byte{ - nextHdr, 1, - - // Skippable unknown. - 63, 4, 1, 2, 3, 4, - - // Discard unknown. - 127, 6, 1, 2, 3, 4, 5, 6, - }, destinationExtHdrID - }, - shouldAccept: false, - expectICMP: false, - }, - { - name: "destination with unknown option discard and send icmp action (unicast)", - extHdr: func(nextHdr uint8) ([]byte, uint8) { - return []byte{ - nextHdr, 1, - - // Skippable unknown. - 63, 4, 1, 2, 3, 4, - - // Discard & send ICMP if option is unknown. - 191, 6, 1, 2, 3, 4, 5, 6, - //^ 191 is an unknown option. - }, destinationExtHdrID - }, - shouldAccept: false, - expectICMP: true, - ICMPType: header.ICMPv6ParamProblem, - ICMPCode: header.ICMPv6UnknownOption, - pointer: header.IPv6FixedHeaderSize + 8, - }, - { - name: "destination with unknown option discard and send icmp action (muilticast)", - extHdr: func(nextHdr uint8) ([]byte, uint8) { - return []byte{ - nextHdr, 1, - - // Skippable unknown. - 63, 4, 1, 2, 3, 4, - - // Discard & send ICMP if option is unknown. - 191, 6, 1, 2, 3, 4, 5, 6, - //^ 191 is an unknown option. - }, destinationExtHdrID - }, - multicast: true, - shouldAccept: false, - expectICMP: true, - ICMPType: header.ICMPv6ParamProblem, - ICMPCode: header.ICMPv6UnknownOption, - pointer: header.IPv6FixedHeaderSize + 8, - }, - { - name: "destination with unknown option discard and send icmp action unless multicast dest (unicast)", - extHdr: func(nextHdr uint8) ([]byte, uint8) { - return []byte{ - nextHdr, 1, - - // Skippable unknown. - 63, 4, 1, 2, 3, 4, - - // Discard & send ICMP unless packet is for multicast destination if - // option is unknown. - 255, 6, 1, 2, 3, 4, 5, 6, - //^ 255 is unknown. - }, destinationExtHdrID - }, - shouldAccept: false, - expectICMP: true, - ICMPType: header.ICMPv6ParamProblem, - ICMPCode: header.ICMPv6UnknownOption, - pointer: header.IPv6FixedHeaderSize + 8, - }, - { - name: "destination with unknown option discard and send icmp action unless multicast dest (multicast)", - extHdr: func(nextHdr uint8) ([]byte, uint8) { - return []byte{ - nextHdr, 1, - - // Skippable unknown. - 63, 4, 1, 2, 3, 4, - - // Discard & send ICMP unless packet is for multicast destination if - // option is unknown. - 255, 6, 1, 2, 3, 4, 5, 6, - //^ 255 is unknown. - }, destinationExtHdrID - }, - shouldAccept: false, - expectICMP: false, - multicast: true, - }, - { - name: "atomic fragment - routing", - extHdr: func(nextHdr uint8) ([]byte, uint8) { - return []byte{ - // Fragment extension header. - routingExtHdrID, 0, 0, 0, 1, 2, 3, 4, - - // Routing extension header. - nextHdr, 0, 1, 0, 2, 3, 4, 5, - }, fragmentExtHdrID - }, - shouldAccept: true, - }, - { - name: "hop by hop (with skippable unknown) - routing", - extHdr: func(nextHdr uint8) ([]byte, uint8) { - return []byte{ - // Hop By Hop extension header with skippable unknown option. - routingExtHdrID, 0, 62, 4, 1, 2, 3, 4, - - // Routing extension header. - nextHdr, 0, 1, 0, 2, 3, 4, 5, - }, hopByHopExtHdrID - }, - shouldAccept: true, - }, - { - name: "routing - hop by hop (with skippable unknown)", - extHdr: func(nextHdr uint8) ([]byte, uint8) { - return []byte{ - // Routing extension header. - hopByHopExtHdrID, 0, 1, 0, 2, 3, 4, 5, - // ^^^ The HopByHop extension header may not appear after the first - // extension header. - - // Hop By Hop extension header with skippable unknown option. - nextHdr, 0, 62, 4, 1, 2, 3, 4, - }, routingExtHdrID - }, - shouldAccept: false, - expectICMP: true, - ICMPType: header.ICMPv6ParamProblem, - ICMPCode: header.ICMPv6UnknownHeader, - pointer: header.IPv6FixedHeaderSize, - }, - { - name: "routing - hop by hop (with send icmp unknown)", - extHdr: func(nextHdr uint8) ([]byte, uint8) { - return []byte{ - // Routing extension header. - hopByHopExtHdrID, 0, 1, 0, 2, 3, 4, 5, - // ^^^ The HopByHop extension header may not appear after the first - // extension header. - - nextHdr, 1, - - // Skippable unknown. - 63, 4, 1, 2, 3, 4, - - // Skippable unknown. - 191, 6, 1, 2, 3, 4, 5, 6, - }, routingExtHdrID - }, - shouldAccept: false, - expectICMP: true, - ICMPType: header.ICMPv6ParamProblem, - ICMPCode: header.ICMPv6UnknownHeader, - pointer: header.IPv6FixedHeaderSize, - }, - { - name: "hopbyhop (with skippable unknown) - routing - atomic fragment - destination (with skippable unknown)", - extHdr: func(nextHdr uint8) ([]byte, uint8) { - return []byte{ - // Hop By Hop extension header with skippable unknown option. - routingExtHdrID, 0, 62, 4, 1, 2, 3, 4, - - // Routing extension header. - fragmentExtHdrID, 0, 1, 0, 2, 3, 4, 5, - - // Fragment extension header. - destinationExtHdrID, 0, 0, 0, 1, 2, 3, 4, - - // Destination extension header with skippable unknown option. - nextHdr, 0, 63, 4, 1, 2, 3, 4, - }, hopByHopExtHdrID - }, - shouldAccept: true, - }, - { - name: "hopbyhop (with discard unknown) - routing - atomic fragment - destination (with skippable unknown)", - extHdr: func(nextHdr uint8) ([]byte, uint8) { - return []byte{ - // Hop By Hop extension header with discard action for unknown option. - routingExtHdrID, 0, 65, 4, 1, 2, 3, 4, - - // Routing extension header. - fragmentExtHdrID, 0, 1, 0, 2, 3, 4, 5, - - // Fragment extension header. - destinationExtHdrID, 0, 0, 0, 1, 2, 3, 4, - - // Destination extension header with skippable unknown option. - nextHdr, 0, 63, 4, 1, 2, 3, 4, - }, hopByHopExtHdrID - }, - shouldAccept: false, - expectICMP: false, - }, - { - name: "hopbyhop (with skippable unknown) - routing - atomic fragment - destination (with discard unknown)", - extHdr: func(nextHdr uint8) ([]byte, uint8) { - return []byte{ - // Hop By Hop extension header with skippable unknown option. - routingExtHdrID, 0, 62, 4, 1, 2, 3, 4, - - // Routing extension header. - fragmentExtHdrID, 0, 1, 0, 2, 3, 4, 5, - - // Fragment extension header. - destinationExtHdrID, 0, 0, 0, 1, 2, 3, 4, - - // Destination extension header with discard action for unknown - // option. - nextHdr, 0, 65, 4, 1, 2, 3, 4, - }, hopByHopExtHdrID - }, - shouldAccept: false, - expectICMP: false, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, - }) - e := channel.New(1, header.IPv6MinimumMTU, linkAddr1) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err) - } - - // Add a default route so that a return packet knows where to go. - s.SetRouteTable([]tcpip.Route{ - { - Destination: header.IPv6EmptySubnet, - NIC: nicID, - }, - }) - - wq := waiter.Queue{} - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.WritableEvents) - defer wq.EventUnregister(&we) - defer close(ch) - ep, err := s.NewEndpoint(udp.ProtocolNumber, ProtocolNumber, &wq) - if err != nil { - t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, ProtocolNumber, err) - } - defer ep.Close() - - bindAddr := tcpip.FullAddress{Addr: addr2, Port: 80} - if err := ep.Bind(bindAddr); err != nil { - t.Fatalf("Bind(%+v): %s", bindAddr, err) - } - - udpPayload := []byte{1, 2, 3, 4, 5, 6, 7, 8} - udpLength := header.UDPMinimumSize + len(udpPayload) - extHdrBytes, ipv6NextHdr := test.extHdr(uint8(header.UDPProtocolNumber)) - extHdrLen := len(extHdrBytes) - hdr := buffer.NewPrependable(header.IPv6MinimumSize + extHdrLen + udpLength) - - // Serialize UDP message. - u := header.UDP(hdr.Prepend(udpLength)) - u.Encode(&header.UDPFields{ - SrcPort: 5555, - DstPort: 80, - Length: uint16(udpLength), - }) - copy(u.Payload(), udpPayload) - - dstAddr := tcpip.Address(addr2) - if test.multicast { - dstAddr = header.IPv6AllNodesMulticastAddress - } - - sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, addr1, dstAddr, uint16(udpLength)) - sum = header.Checksum(udpPayload, sum) - u.SetChecksum(^u.CalculateChecksum(sum)) - - // Copy extension header bytes between the UDP message and the IPv6 - // fixed header. - copy(hdr.Prepend(extHdrLen), extHdrBytes) - - // Serialize IPv6 fixed header. - payloadLength := hdr.UsedLength() - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - // We're lying about transport protocol here to be able to generate - // raw extension headers from the test definitions. - TransportProtocol: tcpip.TransportProtocolNumber(ipv6NextHdr), - HopLimit: 255, - SrcAddr: addr1, - DstAddr: dstAddr, - }) - - stats := s.Stats() - var counters []*tcpip.StatCounter - // Make sure that the counters we expect to be incremented are initially - // set to zero. - if fn := test.countersToBeIncremented; fn != nil { - counters = fn(&stats) - } - for i := range counters { - if got := counters[i].Value(); got != 0 { - t.Errorf("before writing packet: got test.countersToBeIncremented(&stats)[%d].Value() = %d, want = 0", i, got) - } - } - - e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - })) - - for i := range counters { - if got := counters[i].Value(); got != 1 { - t.Errorf("after writing packet: got test.countersToBeIncremented(&stats)[%d].Value() = %d, want = 1", i, got) - } - } - - udpReceiveStat := stats.UDP.PacketsReceived - if !test.shouldAccept { - if got := udpReceiveStat.Value(); got != 0 { - t.Errorf("got UDP Rx Packets = %d, want = 0", got) - } - - if !test.expectICMP { - if p, ok := e.Read(); ok { - t.Fatalf("unexpected packet received: %#v", p) - } - return - } - - // ICMP required. - p, ok := e.Read() - if !ok { - t.Fatalf("expected packet wasn't written out") - } - - // Pack the output packet into a single buffer.View as the checkers - // assume that. - vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views()) - pkt := vv.ToView() - if got, want := len(pkt), header.IPv6FixedHeaderSize+header.ICMPv6MinimumSize+hdr.UsedLength(); got != want { - t.Fatalf("got an ICMP packet of size = %d, want = %d", got, want) - } - - ipHdr := header.IPv6(pkt) - checker.IPv6(t, ipHdr, checker.ICMPv6( - checker.ICMPv6Type(test.ICMPType), - checker.ICMPv6Code(test.ICMPCode))) - - // We know we are looking at no extension headers in the error ICMP - // packets. - icmpPkt := header.ICMPv6(ipHdr.Payload()) - // We know we sent small packets that won't be truncated when reflected - // back to us. - originalPacket := icmpPkt.Payload() - if got, want := icmpPkt.TypeSpecific(), test.pointer; got != want { - t.Errorf("unexpected ICMPv6 pointer, got = %d, want = %d\n", got, want) - } - if diff := cmp.Diff(hdr.View(), buffer.View(originalPacket)); diff != "" { - t.Errorf("ICMPv6 payload mismatch (-want +got):\n%s", diff) - } - return - } - - // Expect a UDP packet. - if got := udpReceiveStat.Value(); got != 1 { - t.Errorf("got UDP Rx Packets = %d, want = 1", got) - } - var buf bytes.Buffer - result, err := ep.Read(&buf, tcpip.ReadOptions{}) - if err != nil { - t.Fatalf("Read: %s", err) - } - if diff := cmp.Diff(tcpip.ReadResult{ - Count: len(udpPayload), - Total: len(udpPayload), - }, result, checker.IgnoreCmpPath("ControlMessages")); diff != "" { - t.Errorf("Read: unexpected result (-want +got):\n%s", diff) - } - if diff := cmp.Diff(udpPayload, buf.Bytes()); diff != "" { - t.Errorf("got UDP payload mismatch (-want +got):\n%s", diff) - } - - // Should not have any more UDP packets. - res, err := ep.Read(ioutil.Discard, tcpip.ReadOptions{}) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Fatalf("got Read = (%v, %v), want = (_, %s)", res, err, &tcpip.ErrWouldBlock{}) - } - }) - } -} - -// fragmentData holds the IPv6 payload for a fragmented IPv6 packet. -type fragmentData struct { - srcAddr tcpip.Address - dstAddr tcpip.Address - nextHdr uint8 - data buffer.VectorisedView -} - -func TestReceiveIPv6Fragments(t *testing.T) { - const ( - udpPayload1Length = 256 - udpPayload2Length = 128 - // Used to test cases where the fragment blocks are not a multiple of - // the fragment block size of 8 (RFC 8200 section 4.5). - udpPayload3Length = 127 - udpPayload4Length = header.IPv6MaximumPayloadSize - header.UDPMinimumSize - udpMaximumSizeMinus15 = header.UDPMaximumSize - 15 - fragmentExtHdrLen = 8 - // Note, not all routing extension headers will be 8 bytes but this test - // uses 8 byte routing extension headers for most sub tests. - routingExtHdrLen = 8 - ) - - udpGen := func(payload []byte, multiplier uint8, src, dst tcpip.Address) buffer.View { - payloadLen := len(payload) - for i := 0; i < payloadLen; i++ { - payload[i] = uint8(i) * multiplier - } - - udpLength := header.UDPMinimumSize + payloadLen - - hdr := buffer.NewPrependable(udpLength) - u := header.UDP(hdr.Prepend(udpLength)) - u.Encode(&header.UDPFields{ - SrcPort: 5555, - DstPort: 80, - Length: uint16(udpLength), - }) - copy(u.Payload(), payload) - sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, src, dst, uint16(udpLength)) - sum = header.Checksum(payload, sum) - u.SetChecksum(^u.CalculateChecksum(sum)) - return hdr.View() - } - - var udpPayload1Addr1ToAddr2Buf [udpPayload1Length]byte - udpPayload1Addr1ToAddr2 := udpPayload1Addr1ToAddr2Buf[:] - ipv6Payload1Addr1ToAddr2 := udpGen(udpPayload1Addr1ToAddr2, 1, addr1, addr2) - - var udpPayload1Addr3ToAddr2Buf [udpPayload1Length]byte - udpPayload1Addr3ToAddr2 := udpPayload1Addr3ToAddr2Buf[:] - ipv6Payload1Addr3ToAddr2 := udpGen(udpPayload1Addr3ToAddr2, 4, addr3, addr2) - - var udpPayload2Addr1ToAddr2Buf [udpPayload2Length]byte - udpPayload2Addr1ToAddr2 := udpPayload2Addr1ToAddr2Buf[:] - ipv6Payload2Addr1ToAddr2 := udpGen(udpPayload2Addr1ToAddr2, 2, addr1, addr2) - - var udpPayload3Addr1ToAddr2Buf [udpPayload3Length]byte - udpPayload3Addr1ToAddr2 := udpPayload3Addr1ToAddr2Buf[:] - ipv6Payload3Addr1ToAddr2 := udpGen(udpPayload3Addr1ToAddr2, 3, addr1, addr2) - - var udpPayload4Addr1ToAddr2Buf [udpPayload4Length]byte - udpPayload4Addr1ToAddr2 := udpPayload4Addr1ToAddr2Buf[:] - ipv6Payload4Addr1ToAddr2 := udpGen(udpPayload4Addr1ToAddr2, 4, addr1, addr2) - - tests := []struct { - name string - expectedPayload []byte - fragments []fragmentData - expectedPayloads [][]byte - }{ - { - name: "No fragmentation", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: uint8(header.UDPProtocolNumber), - data: ipv6Payload1Addr1ToAddr2.ToVectorisedView(), - }, - }, - expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2}, - }, - { - name: "Atomic fragment", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2), - []buffer.View{ - // Fragment extension header. - []byte{uint8(header.UDPProtocolNumber), 0, 0, 0, 0, 0, 0, 0}, - - ipv6Payload1Addr1ToAddr2, - }, - ), - }, - }, - expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2}, - }, - { - name: "Atomic fragment with size not a multiple of fragment block size", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - fragmentExtHdrLen+len(ipv6Payload3Addr1ToAddr2), - []buffer.View{ - // Fragment extension header. - []byte{uint8(header.UDPProtocolNumber), 0, 0, 0, 0, 0, 0, 0}, - - ipv6Payload3Addr1ToAddr2, - }, - ), - }, - }, - expectedPayloads: [][]byte{udpPayload3Addr1ToAddr2}, - }, - { - name: "Two fragments", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - fragmentExtHdrLen+64, - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = 0, More = true, ID = 1 - []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}, - - ipv6Payload1Addr1ToAddr2[:64], - }, - ), - }, - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64, - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = 8, More = false, ID = 1 - []byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}, - - ipv6Payload1Addr1ToAddr2[64:], - }, - ), - }, - }, - expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2}, - }, - { - name: "Two fragments out of order", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64, - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = 8, More = false, ID = 1 - []byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}, - - ipv6Payload1Addr1ToAddr2[64:], - }, - ), - }, - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - fragmentExtHdrLen+64, - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = 0, More = true, ID = 1 - []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}, - - ipv6Payload1Addr1ToAddr2[:64], - }, - ), - }, - }, - expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2}, - }, - { - name: "Two fragments with different Next Header values", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - fragmentExtHdrLen+64, - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = 0, More = true, ID = 1 - []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}, - - ipv6Payload1Addr1ToAddr2[:64], - }, - ), - }, - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64, - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = 8, More = false, ID = 1 - // NextHeader value is different than the one in the first fragment, so - // this NextHeader should be ignored. - []byte{uint8(header.IPv6NoNextHeaderIdentifier), 0, 0, 64, 0, 0, 0, 1}, - - ipv6Payload1Addr1ToAddr2[64:], - }, - ), - }, - }, - expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2}, - }, - { - name: "Two fragments with last fragment size not a multiple of fragment block size", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - fragmentExtHdrLen+64, - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = 0, More = true, ID = 1 - []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}, - - ipv6Payload3Addr1ToAddr2[:64], - }, - ), - }, - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - fragmentExtHdrLen+len(ipv6Payload3Addr1ToAddr2)-64, - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = 8, More = false, ID = 1 - []byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}, - - ipv6Payload3Addr1ToAddr2[64:], - }, - ), - }, - }, - expectedPayloads: [][]byte{udpPayload3Addr1ToAddr2}, - }, - { - name: "Two fragments with first fragment size not a multiple of fragment block size", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - fragmentExtHdrLen+63, - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = 0, More = true, ID = 1 - []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}, - - ipv6Payload3Addr1ToAddr2[:63], - }, - ), - }, - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - fragmentExtHdrLen+len(ipv6Payload3Addr1ToAddr2)-63, - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = 8, More = false, ID = 1 - []byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}, - - ipv6Payload3Addr1ToAddr2[63:], - }, - ), - }, - }, - expectedPayloads: nil, - }, - { - name: "Two fragments with different IDs", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - fragmentExtHdrLen+64, - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = 0, More = true, ID = 1 - []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}, - - ipv6Payload1Addr1ToAddr2[:64], - }, - ), - }, - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64, - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = 8, More = false, ID = 2 - []byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 2}, - - ipv6Payload1Addr1ToAddr2[64:], - }, - ), - }, - }, - expectedPayloads: nil, - }, - { - name: "Two fragments reassembled into a maximum UDP packet", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - fragmentExtHdrLen+udpMaximumSizeMinus15, - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = 0, More = true, ID = 1 - []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}, - - ipv6Payload4Addr1ToAddr2[:udpMaximumSizeMinus15], - }, - ), - }, - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - fragmentExtHdrLen+len(ipv6Payload4Addr1ToAddr2)-udpMaximumSizeMinus15, - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = udpMaximumSizeMinus15/8, More = false, ID = 1 - []byte{uint8(header.UDPProtocolNumber), 0, - udpMaximumSizeMinus15 >> 8, - udpMaximumSizeMinus15 & 0xff, - 0, 0, 0, 1}, - - ipv6Payload4Addr1ToAddr2[udpMaximumSizeMinus15:], - }, - ), - }, - }, - expectedPayloads: [][]byte{udpPayload4Addr1ToAddr2}, - }, - { - name: "Two fragments with MF flag reassembled into a maximum UDP packet", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - fragmentExtHdrLen+udpMaximumSizeMinus15, - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = 0, More = true, ID = 1 - []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}, - - ipv6Payload4Addr1ToAddr2[:udpMaximumSizeMinus15], - }, - ), - }, - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - fragmentExtHdrLen+len(ipv6Payload4Addr1ToAddr2)-udpMaximumSizeMinus15, - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = udpMaximumSizeMinus15/8, More = true, ID = 1 - []byte{uint8(header.UDPProtocolNumber), 0, - udpMaximumSizeMinus15 >> 8, - (udpMaximumSizeMinus15 & 0xff) + 1, - 0, 0, 0, 1}, - - ipv6Payload4Addr1ToAddr2[udpMaximumSizeMinus15:], - }, - ), - }, - }, - expectedPayloads: nil, - }, - { - name: "Two fragments with per-fragment routing header with zero segments left", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: routingExtHdrID, - data: buffer.NewVectorisedView( - routingExtHdrLen+fragmentExtHdrLen+64, - []buffer.View{ - // Routing extension header. - // - // Segments left = 0. - []byte{fragmentExtHdrID, 0, 1, 0, 2, 3, 4, 5}, - - // Fragment extension header. - // - // Fragment offset = 0, More = true, ID = 1 - []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}, - - ipv6Payload1Addr1ToAddr2[:64], - }, - ), - }, - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: routingExtHdrID, - data: buffer.NewVectorisedView( - routingExtHdrLen+fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64, - []buffer.View{ - // Routing extension header. - // - // Segments left = 0. - []byte{fragmentExtHdrID, 0, 1, 0, 2, 3, 4, 5}, - - // Fragment extension header. - // - // Fragment offset = 8, More = false, ID = 1 - []byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}, - - ipv6Payload1Addr1ToAddr2[64:], - }, - ), - }, - }, - expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2}, - }, - { - name: "Two fragments with per-fragment routing header with non-zero segments left", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: routingExtHdrID, - data: buffer.NewVectorisedView( - routingExtHdrLen+fragmentExtHdrLen+64, - []buffer.View{ - // Routing extension header. - // - // Segments left = 1. - []byte{fragmentExtHdrID, 0, 1, 1, 2, 3, 4, 5}, - - // Fragment extension header. - // - // Fragment offset = 0, More = true, ID = 1 - []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}, - - ipv6Payload1Addr1ToAddr2[:64], - }, - ), - }, - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: routingExtHdrID, - data: buffer.NewVectorisedView( - routingExtHdrLen+fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64, - []buffer.View{ - // Routing extension header. - // - // Segments left = 1. - []byte{fragmentExtHdrID, 0, 1, 1, 2, 3, 4, 5}, - - // Fragment extension header. - // - // Fragment offset = 9, More = false, ID = 1 - []byte{uint8(header.UDPProtocolNumber), 0, 0, 72, 0, 0, 0, 1}, - - ipv6Payload1Addr1ToAddr2[64:], - }, - ), - }, - }, - expectedPayloads: nil, - }, - { - name: "Two fragments with routing header with zero segments left", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - routingExtHdrLen+fragmentExtHdrLen+64, - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = 0, More = true, ID = 1 - []byte{routingExtHdrID, 0, 0, 1, 0, 0, 0, 1}, - - // Routing extension header. - // - // Segments left = 0. - []byte{uint8(header.UDPProtocolNumber), 0, 1, 0, 2, 3, 4, 5}, - - ipv6Payload1Addr1ToAddr2[:64], - }, - ), - }, - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64, - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = 9, More = false, ID = 1 - []byte{routingExtHdrID, 0, 0, 72, 0, 0, 0, 1}, - - ipv6Payload1Addr1ToAddr2[64:], - }, - ), - }, - }, - expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2}, - }, - { - name: "Two fragments with routing header with non-zero segments left", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - routingExtHdrLen+fragmentExtHdrLen+64, - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = 0, More = true, ID = 1 - []byte{routingExtHdrID, 0, 0, 1, 0, 0, 0, 1}, - - // Routing extension header. - // - // Segments left = 1. - []byte{uint8(header.UDPProtocolNumber), 0, 1, 1, 2, 3, 4, 5}, - - ipv6Payload1Addr1ToAddr2[:64], - }, - ), - }, - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64, - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = 9, More = false, ID = 1 - []byte{routingExtHdrID, 0, 0, 72, 0, 0, 0, 1}, - - ipv6Payload1Addr1ToAddr2[64:], - }, - ), - }, - }, - expectedPayloads: nil, - }, - { - name: "Two fragments with routing header with zero segments left across fragments", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - // The length of this payload is fragmentExtHdrLen+8 because the - // first 8 bytes of the 16 byte routing extension header is in - // this fragment. - fragmentExtHdrLen+8, - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = 0, More = true, ID = 1 - []byte{routingExtHdrID, 0, 0, 1, 0, 0, 0, 1}, - - // Routing extension header (part 1) - // - // Segments left = 0. - []byte{uint8(header.UDPProtocolNumber), 1, 1, 0, 2, 3, 4, 5}, - }, - ), - }, - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - // The length of this payload is - // fragmentExtHdrLen+8+len(ipv6Payload1Addr1ToAddr2) because the last 8 bytes of - // the 16 byte routing extension header is in this fagment. - fragmentExtHdrLen+8+len(ipv6Payload1Addr1ToAddr2), - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = 1, More = false, ID = 1 - []byte{routingExtHdrID, 0, 0, 8, 0, 0, 0, 1}, - - // Routing extension header (part 2) - []byte{6, 7, 8, 9, 10, 11, 12, 13}, - - ipv6Payload1Addr1ToAddr2, - }, - ), - }, - }, - expectedPayloads: nil, - }, - { - name: "Two fragments with routing header with non-zero segments left across fragments", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - // The length of this payload is fragmentExtHdrLen+8 because the - // first 8 bytes of the 16 byte routing extension header is in - // this fragment. - fragmentExtHdrLen+8, - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = 0, More = true, ID = 1 - []byte{routingExtHdrID, 0, 0, 1, 0, 0, 0, 1}, - - // Routing extension header (part 1) - // - // Segments left = 1. - []byte{uint8(header.UDPProtocolNumber), 1, 1, 1, 2, 3, 4, 5}, - }, - ), - }, - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - // The length of this payload is - // fragmentExtHdrLen+8+len(ipv6Payload1Addr1ToAddr2) because the last 8 bytes of - // the 16 byte routing extension header is in this fagment. - fragmentExtHdrLen+8+len(ipv6Payload1Addr1ToAddr2), - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = 1, More = false, ID = 1 - []byte{routingExtHdrID, 0, 0, 8, 0, 0, 0, 1}, - - // Routing extension header (part 2) - []byte{6, 7, 8, 9, 10, 11, 12, 13}, - - ipv6Payload1Addr1ToAddr2, - }, - ), - }, - }, - expectedPayloads: nil, - }, - // As per RFC 6946, IPv6 atomic fragments MUST NOT interfere with "normal" - // fragmented traffic. - { - name: "Two fragments with atomic", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - fragmentExtHdrLen+64, - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = 0, More = true, ID = 1 - []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}, - - ipv6Payload1Addr1ToAddr2[:64], - }, - ), - }, - // This fragment has the same ID as the other fragments but is an atomic - // fragment. It should not interfere with the other fragments. - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - fragmentExtHdrLen+len(ipv6Payload2Addr1ToAddr2), - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = 0, More = false, ID = 1 - []byte{uint8(header.UDPProtocolNumber), 0, 0, 0, 0, 0, 0, 1}, - - ipv6Payload2Addr1ToAddr2, - }, - ), - }, - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64, - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = 8, More = false, ID = 1 - []byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}, - - ipv6Payload1Addr1ToAddr2[64:], - }, - ), - }, - }, - expectedPayloads: [][]byte{udpPayload2Addr1ToAddr2, udpPayload1Addr1ToAddr2}, - }, - { - name: "Two interleaved fragmented packets", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - fragmentExtHdrLen+64, - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = 0, More = true, ID = 1 - []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}, - - ipv6Payload1Addr1ToAddr2[:64], - }, - ), - }, - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - fragmentExtHdrLen+32, - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = 0, More = true, ID = 2 - []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 2}, - - ipv6Payload2Addr1ToAddr2[:32], - }, - ), - }, - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64, - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = 8, More = false, ID = 1 - []byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}, - - ipv6Payload1Addr1ToAddr2[64:], - }, - ), - }, - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - fragmentExtHdrLen+len(ipv6Payload2Addr1ToAddr2)-32, - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = 4, More = false, ID = 2 - []byte{uint8(header.UDPProtocolNumber), 0, 0, 32, 0, 0, 0, 2}, - - ipv6Payload2Addr1ToAddr2[32:], - }, - ), - }, - }, - expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2, udpPayload2Addr1ToAddr2}, - }, - { - name: "Two interleaved fragmented packets from different sources but with same ID", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - fragmentExtHdrLen+64, - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = 0, More = true, ID = 1 - []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}, - - ipv6Payload1Addr1ToAddr2[:64], - }, - ), - }, - { - srcAddr: addr3, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - fragmentExtHdrLen+32, - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = 0, More = true, ID = 1 - []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}, - - ipv6Payload1Addr3ToAddr2[:32], - }, - ), - }, - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64, - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = 8, More = false, ID = 1 - []byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}, - - ipv6Payload1Addr1ToAddr2[64:], - }, - ), - }, - { - srcAddr: addr3, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-32, - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = 4, More = false, ID = 1 - []byte{uint8(header.UDPProtocolNumber), 0, 0, 32, 0, 0, 0, 1}, - - ipv6Payload1Addr3ToAddr2[32:], - }, - ), - }, - }, - expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2, udpPayload1Addr3ToAddr2}, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, - }) - e := channel.New(0, header.IPv6MinimumMTU, linkAddr1) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err) - } - - wq := waiter.Queue{} - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.ReadableEvents) - defer wq.EventUnregister(&we) - defer close(ch) - ep, err := s.NewEndpoint(udp.ProtocolNumber, ProtocolNumber, &wq) - if err != nil { - t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, ProtocolNumber, err) - } - defer ep.Close() - - bindAddr := tcpip.FullAddress{Addr: addr2, Port: 80} - if err := ep.Bind(bindAddr); err != nil { - t.Fatalf("Bind(%+v): %s", bindAddr, err) - } - - for _, f := range test.fragments { - hdr := buffer.NewPrependable(header.IPv6MinimumSize) - - // Serialize IPv6 fixed header. - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(f.data.Size()), - // We're lying about transport protocol here so that we can generate - // raw extension headers for the tests. - TransportProtocol: tcpip.TransportProtocolNumber(f.nextHdr), - HopLimit: 255, - SrcAddr: f.srcAddr, - DstAddr: f.dstAddr, - }) - - vv := hdr.View().ToVectorisedView() - vv.Append(f.data) - - e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: vv, - })) - } - - if got, want := s.Stats().UDP.PacketsReceived.Value(), uint64(len(test.expectedPayloads)); got != want { - t.Errorf("got UDP Rx Packets = %d, want = %d", got, want) - } - - for i, p := range test.expectedPayloads { - var buf bytes.Buffer - _, err := ep.Read(&buf, tcpip.ReadOptions{}) - if err != nil { - t.Fatalf("(i=%d) Read: %s", i, err) - } - if diff := cmp.Diff(p, buf.Bytes()); diff != "" { - t.Errorf("(i=%d) got UDP payload mismatch (-want +got):\n%s", i, diff) - } - } - - res, err := ep.Read(ioutil.Discard, tcpip.ReadOptions{}) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Fatalf("(last) got Read = (%v, %v), want = (_, %s)", res, err, &tcpip.ErrWouldBlock{}) - } - }) - } -} - -func TestInvalidIPv6Fragments(t *testing.T) { - const ( - addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - addr2 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" - linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e") - nicID = 1 - hoplimit = 255 - ident = 1 - data = "TEST_INVALID_IPV6_FRAGMENTS" - ) - - type fragmentData struct { - ipv6Fields header.IPv6Fields - ipv6FragmentFields header.IPv6SerializableFragmentExtHdr - payload []byte - } - - tests := []struct { - name string - fragments []fragmentData - wantMalformedIPPackets uint64 - wantMalformedFragments uint64 - expectICMP bool - expectICMPType header.ICMPv6Type - expectICMPCode header.ICMPv6Code - expectICMPTypeSpecific uint32 - }{ - { - name: "fragment size is not a multiple of 8 and the M flag is true", - fragments: []fragmentData{ - { - ipv6Fields: header.IPv6Fields{ - PayloadLength: header.IPv6FragmentHeaderSize + 9, - TransportProtocol: header.UDPProtocolNumber, - HopLimit: hoplimit, - SrcAddr: addr1, - DstAddr: addr2, - }, - ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{ - FragmentOffset: 0 >> 3, - M: true, - Identification: ident, - }, - payload: []byte(data)[:9], - }, - }, - wantMalformedIPPackets: 1, - wantMalformedFragments: 1, - expectICMP: true, - expectICMPType: header.ICMPv6ParamProblem, - expectICMPCode: header.ICMPv6ErroneousHeader, - expectICMPTypeSpecific: header.IPv6PayloadLenOffset, - }, - { - name: "fragments reassembled into a payload exceeding the max IPv6 payload size", - fragments: []fragmentData{ - { - ipv6Fields: header.IPv6Fields{ - PayloadLength: header.IPv6FragmentHeaderSize + 16, - TransportProtocol: header.UDPProtocolNumber, - HopLimit: hoplimit, - SrcAddr: addr1, - DstAddr: addr2, - }, - ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{ - FragmentOffset: ((header.IPv6MaximumPayloadSize + 1) - 16) >> 3, - M: false, - Identification: ident, - }, - payload: []byte(data)[:16], - }, - }, - wantMalformedIPPackets: 1, - wantMalformedFragments: 1, - expectICMP: true, - expectICMPType: header.ICMPv6ParamProblem, - expectICMPCode: header.ICMPv6ErroneousHeader, - expectICMPTypeSpecific: header.IPv6MinimumSize + 2, /* offset for 'Fragment Offset' in the fragment header */ - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ - NewProtocol, - }, - }) - e := channel.New(1, 1500, linkAddr1) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err) - } - s.SetRouteTable([]tcpip.Route{{ - Destination: header.IPv6EmptySubnet, - NIC: nicID, - }}) - - var expectICMPPayload buffer.View - for _, f := range test.fragments { - hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.IPv6FragmentHeaderSize) - - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize + header.IPv6FragmentHeaderSize)) - encodeArgs := f.ipv6Fields - encodeArgs.ExtensionHeaders = append(encodeArgs.ExtensionHeaders, &f.ipv6FragmentFields) - ip.Encode(&encodeArgs) - - vv := hdr.View().ToVectorisedView() - vv.AppendView(f.payload) - - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: vv, - }) - - if test.expectICMP { - expectICMPPayload = stack.PayloadSince(pkt.NetworkHeader()) - } - - e.InjectInbound(ProtocolNumber, pkt) - } - - if got, want := s.Stats().IP.MalformedPacketsReceived.Value(), test.wantMalformedIPPackets; got != want { - t.Errorf("got Stats.IP.MalformedPacketsReceived = %d, want = %d", got, want) - } - if got, want := s.Stats().IP.MalformedFragmentsReceived.Value(), test.wantMalformedFragments; got != want { - t.Errorf("got Stats.IP.MalformedFragmentsReceived = %d, want = %d", got, want) - } - - reply, ok := e.Read() - if !test.expectICMP { - if ok { - t.Fatalf("unexpected ICMP error message received: %#v", reply) - } - return - } - if !ok { - t.Fatal("expected ICMP error message missing") - } - - checker.IPv6(t, stack.PayloadSince(reply.Pkt.NetworkHeader()), - checker.SrcAddr(addr2), - checker.DstAddr(addr1), - checker.IPFullLength(uint16(header.IPv6MinimumSize+header.ICMPv6MinimumSize+expectICMPPayload.Size())), - checker.ICMPv6( - checker.ICMPv6Type(test.expectICMPType), - checker.ICMPv6Code(test.expectICMPCode), - checker.ICMPv6TypeSpecific(test.expectICMPTypeSpecific), - checker.ICMPv6Payload(expectICMPPayload), - ), - ) - }) - } -} - -func TestFragmentReassemblyTimeout(t *testing.T) { - const ( - addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - addr2 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" - linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e") - nicID = 1 - hoplimit = 255 - ident = 1 - data = "TEST_FRAGMENT_REASSEMBLY_TIMEOUT" - ) - - type fragmentData struct { - ipv6Fields header.IPv6Fields - ipv6FragmentFields header.IPv6SerializableFragmentExtHdr - payload []byte - } - - tests := []struct { - name string - fragments []fragmentData - expectICMP bool - }{ - { - name: "first fragment only", - fragments: []fragmentData{ - { - ipv6Fields: header.IPv6Fields{ - PayloadLength: header.IPv6FragmentHeaderSize + 16, - TransportProtocol: header.UDPProtocolNumber, - HopLimit: hoplimit, - SrcAddr: addr1, - DstAddr: addr2, - }, - ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{ - FragmentOffset: 0, - M: true, - Identification: ident, - }, - payload: []byte(data)[:16], - }, - }, - expectICMP: true, - }, - { - name: "two first fragments", - fragments: []fragmentData{ - { - ipv6Fields: header.IPv6Fields{ - PayloadLength: header.IPv6FragmentHeaderSize + 16, - TransportProtocol: header.UDPProtocolNumber, - HopLimit: hoplimit, - SrcAddr: addr1, - DstAddr: addr2, - }, - ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{ - FragmentOffset: 0, - M: true, - Identification: ident, - }, - payload: []byte(data)[:16], - }, - { - ipv6Fields: header.IPv6Fields{ - PayloadLength: header.IPv6FragmentHeaderSize + 16, - TransportProtocol: header.UDPProtocolNumber, - HopLimit: hoplimit, - SrcAddr: addr1, - DstAddr: addr2, - }, - ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{ - FragmentOffset: 0, - M: true, - Identification: ident, - }, - payload: []byte(data)[:16], - }, - }, - expectICMP: true, - }, - { - name: "second fragment only", - fragments: []fragmentData{ - { - ipv6Fields: header.IPv6Fields{ - PayloadLength: uint16(header.IPv6FragmentHeaderSize + len(data) - 16), - TransportProtocol: header.UDPProtocolNumber, - HopLimit: hoplimit, - SrcAddr: addr1, - DstAddr: addr2, - }, - ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{ - FragmentOffset: 8, - M: false, - Identification: ident, - }, - payload: []byte(data)[16:], - }, - }, - expectICMP: false, - }, - { - name: "two fragments with a gap", - fragments: []fragmentData{ - { - ipv6Fields: header.IPv6Fields{ - PayloadLength: header.IPv6FragmentHeaderSize + 16, - TransportProtocol: header.UDPProtocolNumber, - HopLimit: hoplimit, - SrcAddr: addr1, - DstAddr: addr2, - }, - ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{ - FragmentOffset: 0, - M: true, - Identification: ident, - }, - payload: []byte(data)[:16], - }, - { - ipv6Fields: header.IPv6Fields{ - PayloadLength: uint16(header.IPv6FragmentHeaderSize + len(data) - 16), - TransportProtocol: header.UDPProtocolNumber, - HopLimit: hoplimit, - SrcAddr: addr1, - DstAddr: addr2, - }, - ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{ - FragmentOffset: 8, - M: false, - Identification: ident, - }, - payload: []byte(data)[16:], - }, - }, - expectICMP: true, - }, - { - name: "two fragments with a gap in reverse order", - fragments: []fragmentData{ - { - ipv6Fields: header.IPv6Fields{ - PayloadLength: uint16(header.IPv6FragmentHeaderSize + len(data) - 16), - TransportProtocol: header.UDPProtocolNumber, - HopLimit: hoplimit, - SrcAddr: addr1, - DstAddr: addr2, - }, - ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{ - FragmentOffset: 8, - M: false, - Identification: ident, - }, - payload: []byte(data)[16:], - }, - { - ipv6Fields: header.IPv6Fields{ - PayloadLength: header.IPv6FragmentHeaderSize + 16, - TransportProtocol: header.UDPProtocolNumber, - HopLimit: hoplimit, - SrcAddr: addr1, - DstAddr: addr2, - }, - ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{ - FragmentOffset: 0, - M: true, - Identification: ident, - }, - payload: []byte(data)[:16], - }, - }, - expectICMP: true, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - clock := faketime.NewManualClock() - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ - NewProtocol, - }, - Clock: clock, - }) - - e := channel.New(1, 1500, linkAddr1) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr2, err) - } - s.SetRouteTable([]tcpip.Route{{ - Destination: header.IPv6EmptySubnet, - NIC: nicID, - }}) - - var firstFragmentSent buffer.View - for _, f := range test.fragments { - hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.IPv6FragmentHeaderSize) - - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize + header.IPv6FragmentHeaderSize)) - encodeArgs := f.ipv6Fields - encodeArgs.ExtensionHeaders = append(encodeArgs.ExtensionHeaders, &f.ipv6FragmentFields) - ip.Encode(&encodeArgs) - - fragHDR := header.IPv6Fragment(hdr.View()[header.IPv6MinimumSize:]) - - vv := hdr.View().ToVectorisedView() - vv.AppendView(f.payload) - - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: vv, - }) - - if firstFragmentSent == nil && fragHDR.FragmentOffset() == 0 { - firstFragmentSent = stack.PayloadSince(pkt.NetworkHeader()) - } - - e.InjectInbound(ProtocolNumber, pkt) - } - - clock.Advance(ReassembleTimeout) - - reply, ok := e.Read() - if !test.expectICMP { - if ok { - t.Fatalf("unexpected ICMP error message received: %#v", reply) - } - return - } - if !ok { - t.Fatal("expected ICMP error message missing") - } - if firstFragmentSent == nil { - t.Fatalf("unexpected ICMP error message received: %#v", reply) - } - - checker.IPv6(t, stack.PayloadSince(reply.Pkt.NetworkHeader()), - checker.SrcAddr(addr2), - checker.DstAddr(addr1), - checker.IPFullLength(uint16(header.IPv6MinimumSize+header.ICMPv6MinimumSize+firstFragmentSent.Size())), - checker.ICMPv6( - checker.ICMPv6Type(header.ICMPv6TimeExceeded), - checker.ICMPv6Code(header.ICMPv6ReassemblyTimeout), - checker.ICMPv6Payload(firstFragmentSent), - ), - ) - }) - } -} - -func TestWriteStats(t *testing.T) { - const nPackets = 3 - tests := []struct { - name string - setup func(*testing.T, *stack.Stack) - allowPackets int - expectSent int - expectOutputDropped int - expectPostroutingDropped int - expectWritten int - }{ - { - name: "Accept all", - // No setup needed, tables accept everything by default. - setup: func(*testing.T, *stack.Stack) {}, - allowPackets: math.MaxInt32, - expectSent: nPackets, - expectOutputDropped: 0, - expectPostroutingDropped: 0, - expectWritten: nPackets, - }, { - name: "Accept all with error", - // No setup needed, tables accept everything by default. - setup: func(*testing.T, *stack.Stack) {}, - allowPackets: nPackets - 1, - expectSent: nPackets - 1, - expectOutputDropped: 0, - expectPostroutingDropped: 0, - expectWritten: nPackets - 1, - }, { - name: "Drop all with Output chain", - setup: func(t *testing.T, stk *stack.Stack) { - // Install Output DROP rule. - ipt := stk.IPTables() - filter := ipt.GetTable(stack.FilterID, true /* ipv6 */) - ruleIdx := filter.BuiltinChains[stack.Output] - filter.Rules[ruleIdx].Target = &stack.DropTarget{} - if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil { - t.Fatalf("failed to replace table: %v", err) - } - }, - allowPackets: math.MaxInt32, - expectSent: 0, - expectOutputDropped: nPackets, - expectPostroutingDropped: 0, - expectWritten: nPackets, - }, { - name: "Drop all with Postrouting chain", - setup: func(t *testing.T, stk *stack.Stack) { - // Install Output DROP rule. - ipt := stk.IPTables() - filter := ipt.GetTable(stack.NATID, true /* ipv6 */) - ruleIdx := filter.BuiltinChains[stack.Postrouting] - filter.Rules[ruleIdx].Target = &stack.DropTarget{} - if err := ipt.ReplaceTable(stack.NATID, filter, true /* ipv6 */); err != nil { - t.Fatalf("failed to replace table: %v", err) - } - }, - allowPackets: math.MaxInt32, - expectSent: 0, - expectOutputDropped: 0, - expectPostroutingDropped: nPackets, - expectWritten: nPackets, - }, { - name: "Drop some with Output chain", - setup: func(t *testing.T, stk *stack.Stack) { - // Install Output DROP rule that matches only 1 - // of the 3 packets. - ipt := stk.IPTables() - filter := ipt.GetTable(stack.FilterID, true /* ipv6 */) - // We'll match and DROP the last packet. - ruleIdx := filter.BuiltinChains[stack.Output] - filter.Rules[ruleIdx].Target = &stack.DropTarget{} - filter.Rules[ruleIdx].Matchers = []stack.Matcher{&limitedMatcher{nPackets - 1}} - // Make sure the next rule is ACCEPT. - filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} - if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil { - t.Fatalf("failed to replace table: %v", err) - } - }, - allowPackets: math.MaxInt32, - expectSent: nPackets - 1, - expectOutputDropped: 1, - expectPostroutingDropped: 0, - expectWritten: nPackets, - }, { - name: "Drop some with Postrouting chain", - setup: func(t *testing.T, stk *stack.Stack) { - // Install Postrouting DROP rule that matches only 1 - // of the 3 packets. - ipt := stk.IPTables() - filter := ipt.GetTable(stack.NATID, true /* ipv6 */) - // We'll match and DROP the last packet. - ruleIdx := filter.BuiltinChains[stack.Postrouting] - filter.Rules[ruleIdx].Target = &stack.DropTarget{} - filter.Rules[ruleIdx].Matchers = []stack.Matcher{&limitedMatcher{nPackets - 1}} - // Make sure the next rule is ACCEPT. - filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} - if err := ipt.ReplaceTable(stack.NATID, filter, true /* ipv6 */); err != nil { - t.Fatalf("failed to replace table: %v", err) - } - }, - allowPackets: math.MaxInt32, - expectSent: nPackets - 1, - expectOutputDropped: 0, - expectPostroutingDropped: 1, - expectWritten: nPackets, - }, - } - - writers := []struct { - name string - writePackets func(*stack.Route, stack.PacketBufferList) (int, tcpip.Error) - }{ - { - name: "WritePacket", - writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, tcpip.Error) { - nWritten := 0 - for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { - if err := rt.WritePacket(stack.NetworkHeaderParams{}, pkt); err != nil { - return nWritten, err - } - nWritten++ - } - return nWritten, nil - }, - }, { - name: "WritePackets", - writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, tcpip.Error) { - return rt.WritePackets(pkts, stack.NetworkHeaderParams{}) - }, - }, - } - - for _, writer := range writers { - t.Run(writer.name, func(t *testing.T) { - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - ep := iptestutil.NewMockLinkEndpoint(header.IPv6MinimumMTU, &tcpip.ErrInvalidEndpointState{}, test.allowPackets) - rt := buildRoute(t, ep) - var pkts stack.PacketBufferList - for i := 0; i < nPackets; i++ { - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: header.UDPMinimumSize + int(rt.MaxHeaderLength()), - Data: buffer.NewView(0).ToVectorisedView(), - }) - pkt.TransportHeader().Push(header.UDPMinimumSize) - pkts.PushBack(pkt) - } - - test.setup(t, rt.Stack()) - - nWritten, _ := writer.writePackets(rt, pkts) - - if got := int(rt.Stats().IP.PacketsSent.Value()); got != test.expectSent { - t.Errorf("got rt.Stats().IP.PacketsSent.Value() = %d, want = %d", got, test.expectSent) - } - if got := int(rt.Stats().IP.IPTablesOutputDropped.Value()); got != test.expectOutputDropped { - t.Errorf("got rt.Stats().IP.IPTablesOutputDropped.Value() = %d, want = %d", got, test.expectOutputDropped) - } - if got := int(rt.Stats().IP.IPTablesPostroutingDropped.Value()); got != test.expectPostroutingDropped { - t.Errorf("got r.Stats().IP.IPTablesPostroutingDropped.Value() = %d, want = %d", got, test.expectPostroutingDropped) - } - if nWritten != test.expectWritten { - t.Errorf("got nWritten = %d, want = %d", nWritten, test.expectWritten) - } - }) - } - }) - } -} - -func buildRoute(t *testing.T, ep stack.LinkEndpoint) *stack.Route { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - }) - if err := s.CreateNIC(1, ep); err != nil { - t.Fatalf("CreateNIC(1, _) failed: %s", err) - } - const ( - src = "\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - dst = "\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" - ) - if err := s.AddAddress(1, ProtocolNumber, src); err != nil { - t.Fatalf("AddAddress(1, %d, %s) failed: %s", ProtocolNumber, src, err) - } - { - mask := tcpip.AddressMask("\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff") - subnet, err := tcpip.NewSubnet(dst, mask) - if err != nil { - t.Fatalf("NewSubnet(%s, %s) failed: %v", dst, mask, err) - } - s.SetRouteTable([]tcpip.Route{{ - Destination: subnet, - NIC: 1, - }}) - } - rt, err := s.FindRoute(1, src, dst, ProtocolNumber, false /* multicastLoop */) - if err != nil { - t.Fatalf("FindRoute(1, %s, %s, %d, false) = %s, want = nil", src, dst, ProtocolNumber, err) - } - return rt -} - -// limitedMatcher is an iptables matcher that matches after a certain number of -// packets are checked against it. -type limitedMatcher struct { - limit int -} - -// Name implements Matcher.Name. -func (*limitedMatcher) Name() string { - return "limitedMatcher" -} - -// Match implements Matcher.Match. -func (lm *limitedMatcher) Match(stack.Hook, *stack.PacketBuffer, string, string) (bool, bool) { - if lm.limit == 0 { - return true, false - } - lm.limit-- - return false, false -} - -func knownNICIDs(proto *protocol) []tcpip.NICID { - var nicIDs []tcpip.NICID - - for k := range proto.mu.eps { - nicIDs = append(nicIDs, k) - } - - return nicIDs -} - -func TestClearEndpointFromProtocolOnClose(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - }) - proto := s.NetworkProtocolInstance(ProtocolNumber).(*protocol) - var nic testInterface - ep := proto.NewEndpoint(&nic, nil).(*endpoint) - var nicIDs []tcpip.NICID - - proto.mu.Lock() - foundEP, hasEndpointBeforeClose := proto.mu.eps[nic.ID()] - nicIDs = knownNICIDs(proto) - proto.mu.Unlock() - if !hasEndpointBeforeClose { - t.Fatalf("expected to find the nic id %d in the protocol's known nic ids (%v)", nic.ID(), nicIDs) - } - if foundEP != ep { - t.Fatalf("found an incorrect endpoint mapped to nic id %d", nic.ID()) - } - - ep.Close() - - proto.mu.Lock() - _, hasEndpointAfterClose := proto.mu.eps[nic.ID()] - nicIDs = knownNICIDs(proto) - proto.mu.Unlock() - if hasEndpointAfterClose { - t.Fatalf("unexpectedly found an endpoint mapped to the nic id %d in the protocol's known nic ids (%v)", nic.ID(), nicIDs) - } -} - -type fragmentInfo struct { - offset uint16 - more bool - payloadSize uint16 -} - -var fragmentationTests = []struct { - description string - mtu uint32 - transHdrLen int - payloadSize int - wantFragments []fragmentInfo -}{ - { - description: "No fragmentation", - mtu: header.IPv6MinimumMTU, - transHdrLen: 0, - payloadSize: 1000, - wantFragments: []fragmentInfo{ - {offset: 0, payloadSize: 1000, more: false}, - }, - }, - { - description: "Fragmented", - mtu: header.IPv6MinimumMTU, - transHdrLen: 0, - payloadSize: 2000, - wantFragments: []fragmentInfo{ - {offset: 0, payloadSize: 1240, more: true}, - {offset: 154, payloadSize: 776, more: false}, - }, - }, - { - description: "Fragmented with mtu not a multiple of 8", - mtu: header.IPv6MinimumMTU + 1, - transHdrLen: 0, - payloadSize: 2000, - wantFragments: []fragmentInfo{ - {offset: 0, payloadSize: 1240, more: true}, - {offset: 154, payloadSize: 776, more: false}, - }, - }, - { - description: "No fragmentation with big header", - mtu: 2000, - transHdrLen: 100, - payloadSize: 1000, - wantFragments: []fragmentInfo{ - {offset: 0, payloadSize: 1100, more: false}, - }, - }, - { - description: "Fragmented with big header", - mtu: header.IPv6MinimumMTU, - transHdrLen: 100, - payloadSize: 1200, - wantFragments: []fragmentInfo{ - {offset: 0, payloadSize: 1240, more: true}, - {offset: 154, payloadSize: 76, more: false}, - }, - }, -} - -func TestFragmentationWritePacket(t *testing.T) { - const ttl = 42 - - for _, ft := range fragmentationTests { - t.Run(ft.description, func(t *testing.T) { - pkt := iptestutil.MakeRandPkt(ft.transHdrLen, extraHeaderReserve+header.IPv6MinimumSize, []int{ft.payloadSize}, header.IPv6ProtocolNumber) - source := pkt.Clone() - ep := iptestutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32) - r := buildRoute(t, ep) - err := r.WritePacket(stack.NetworkHeaderParams{ - Protocol: tcp.ProtocolNumber, - TTL: ttl, - TOS: stack.DefaultTOS, - }, pkt) - if err != nil { - t.Fatalf("WritePacket(_, _, _): = %s", err) - } - if got := len(ep.WrittenPackets); got != len(ft.wantFragments) { - t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, len(ft.wantFragments)) - } - if got := int(r.Stats().IP.PacketsSent.Value()); got != len(ft.wantFragments) { - t.Errorf("got c.Route.Stats().IP.PacketsSent.Value() = %d, want = %d", got, len(ft.wantFragments)) - } - if got := r.Stats().IP.OutgoingPacketErrors.Value(); got != 0 { - t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = 0", got) - } - if err := compareFragments(ep.WrittenPackets, source, ft.mtu, ft.wantFragments, tcp.ProtocolNumber); err != nil { - t.Error(err) - } - }) - } -} - -func TestFragmentationWritePackets(t *testing.T) { - const ttl = 42 - tests := []struct { - description string - insertBefore int - insertAfter int - }{ - { - description: "Single packet", - insertBefore: 0, - insertAfter: 0, - }, - { - description: "With packet before", - insertBefore: 1, - insertAfter: 0, - }, - { - description: "With packet after", - insertBefore: 0, - insertAfter: 1, - }, - { - description: "With packet before and after", - insertBefore: 1, - insertAfter: 1, - }, - } - tinyPacket := iptestutil.MakeRandPkt(header.TCPMinimumSize, extraHeaderReserve+header.IPv6MinimumSize, []int{1}, header.IPv6ProtocolNumber) - - for _, test := range tests { - t.Run(test.description, func(t *testing.T) { - for _, ft := range fragmentationTests { - t.Run(ft.description, func(t *testing.T) { - var pkts stack.PacketBufferList - for i := 0; i < test.insertBefore; i++ { - pkts.PushBack(tinyPacket.Clone()) - } - pkt := iptestutil.MakeRandPkt(ft.transHdrLen, extraHeaderReserve+header.IPv6MinimumSize, []int{ft.payloadSize}, header.IPv6ProtocolNumber) - source := pkt - pkts.PushBack(pkt.Clone()) - for i := 0; i < test.insertAfter; i++ { - pkts.PushBack(tinyPacket.Clone()) - } - - ep := iptestutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32) - r := buildRoute(t, ep) - - wantTotalPackets := len(ft.wantFragments) + test.insertBefore + test.insertAfter - n, err := r.WritePackets(pkts, stack.NetworkHeaderParams{ - Protocol: tcp.ProtocolNumber, - TTL: ttl, - TOS: stack.DefaultTOS, - }) - if n != wantTotalPackets || err != nil { - t.Errorf("got WritePackets(_, _, _) = (%d, %s), want = (%d, nil)", n, err, wantTotalPackets) - } - if got := len(ep.WrittenPackets); got != wantTotalPackets { - t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, wantTotalPackets) - } - if got := int(r.Stats().IP.PacketsSent.Value()); got != wantTotalPackets { - t.Errorf("got c.Route.Stats().IP.PacketsSent.Value() = %d, want = %d", got, wantTotalPackets) - } - if got := r.Stats().IP.OutgoingPacketErrors.Value(); got != 0 { - t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = 0", got) - } - - if wantTotalPackets == 0 { - return - } - - fragments := ep.WrittenPackets[test.insertBefore : len(ft.wantFragments)+test.insertBefore] - if err := compareFragments(fragments, source, ft.mtu, ft.wantFragments, tcp.ProtocolNumber); err != nil { - t.Error(err) - } - }) - } - }) - } -} - -// TestFragmentationErrors checks that errors are returned from WritePacket -// correctly. -func TestFragmentationErrors(t *testing.T) { - const ttl = 42 - - tests := []struct { - description string - mtu uint32 - transHdrLen int - payloadSize int - allowPackets int - outgoingErrors int - mockError tcpip.Error - wantError tcpip.Error - }{ - { - description: "No frag", - mtu: 2000, - payloadSize: 1000, - transHdrLen: 0, - allowPackets: 0, - outgoingErrors: 1, - mockError: &tcpip.ErrAborted{}, - wantError: &tcpip.ErrAborted{}, - }, - { - description: "Error on first frag", - mtu: 1300, - payloadSize: 3000, - transHdrLen: 0, - allowPackets: 0, - outgoingErrors: 3, - mockError: &tcpip.ErrAborted{}, - wantError: &tcpip.ErrAborted{}, - }, - { - description: "Error on second frag", - mtu: 1500, - payloadSize: 4000, - transHdrLen: 0, - allowPackets: 1, - outgoingErrors: 2, - mockError: &tcpip.ErrAborted{}, - wantError: &tcpip.ErrAborted{}, - }, - { - description: "Error when MTU is smaller than transport header", - mtu: header.IPv6MinimumMTU, - transHdrLen: 1500, - payloadSize: 500, - allowPackets: 0, - outgoingErrors: 1, - mockError: nil, - wantError: &tcpip.ErrMessageTooLong{}, - }, - { - description: "Error when MTU is smaller than IPv6 minimum MTU", - mtu: header.IPv6MinimumMTU - 1, - transHdrLen: 0, - payloadSize: 500, - allowPackets: 0, - outgoingErrors: 1, - mockError: nil, - wantError: &tcpip.ErrInvalidEndpointState{}, - }, - } - - for _, ft := range tests { - t.Run(ft.description, func(t *testing.T) { - pkt := iptestutil.MakeRandPkt(ft.transHdrLen, extraHeaderReserve+header.IPv6MinimumSize, []int{ft.payloadSize}, header.IPv6ProtocolNumber) - ep := iptestutil.NewMockLinkEndpoint(ft.mtu, ft.mockError, ft.allowPackets) - r := buildRoute(t, ep) - err := r.WritePacket(stack.NetworkHeaderParams{ - Protocol: tcp.ProtocolNumber, - TTL: ttl, - TOS: stack.DefaultTOS, - }, pkt) - if diff := cmp.Diff(ft.wantError, err); diff != "" { - t.Errorf("unexpected error from WritePacket(_, _, _), (-want, +got):\n%s", diff) - } - if got := int(r.Stats().IP.PacketsSent.Value()); got != ft.allowPackets { - t.Errorf("got r.Stats().IP.PacketsSent.Value() = %d, want = %d", got, ft.allowPackets) - } - if got := int(r.Stats().IP.OutgoingPacketErrors.Value()); got != ft.outgoingErrors { - t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = %d", got, ft.outgoingErrors) - } - }) - } -} - -func TestForwarding(t *testing.T) { - const ( - incomingNICID = 1 - outgoingNICID = 2 - randomSequence = 123 - randomIdent = 42 - ) - - incomingIPv6Addr := tcpip.AddressWithPrefix{ - Address: tcpip.Address(net.ParseIP("10::1").To16()), - PrefixLen: 64, - } - outgoingIPv6Addr := tcpip.AddressWithPrefix{ - Address: tcpip.Address(net.ParseIP("11::1").To16()), - PrefixLen: 64, - } - multicastIPv6Addr := tcpip.AddressWithPrefix{ - Address: tcpip.Address(net.ParseIP("ff00::").To16()), - PrefixLen: 64, - } - - remoteIPv6Addr1 := tcpip.Address(net.ParseIP("10::2").To16()) - remoteIPv6Addr2 := tcpip.Address(net.ParseIP("11::2").To16()) - unreachableIPv6Addr := tcpip.Address(net.ParseIP("12::2").To16()) - linkLocalIPv6Addr := tcpip.Address(net.ParseIP("fe80::").To16()) - - tests := []struct { - name string - extHdr func(nextHdr uint8) ([]byte, uint8, checker.NetworkChecker) - TTL uint8 - expectErrorICMP bool - expectPacketForwarded bool - payloadLength int - countUnrouteablePackets uint64 - sourceAddr tcpip.Address - destAddr tcpip.Address - icmpType header.ICMPv6Type - icmpCode header.ICMPv6Code - expectPacketUnrouteableError bool - expectLinkLocalSourceError bool - expectLinkLocalDestError bool - expectExtensionHeaderError bool - }{ - { - name: "TTL of zero", - TTL: 0, - expectErrorICMP: true, - sourceAddr: remoteIPv6Addr1, - destAddr: remoteIPv6Addr2, - icmpType: header.ICMPv6TimeExceeded, - icmpCode: header.ICMPv6HopLimitExceeded, - }, - { - name: "TTL of one", - TTL: 1, - expectErrorICMP: true, - sourceAddr: remoteIPv6Addr1, - destAddr: remoteIPv6Addr2, - icmpType: header.ICMPv6TimeExceeded, - icmpCode: header.ICMPv6HopLimitExceeded, - }, - { - name: "TTL of two", - TTL: 2, - expectPacketForwarded: true, - sourceAddr: remoteIPv6Addr1, - destAddr: remoteIPv6Addr2, - }, - { - name: "TTL of three", - TTL: 3, - expectPacketForwarded: true, - sourceAddr: remoteIPv6Addr1, - destAddr: remoteIPv6Addr2, - }, - { - name: "Max TTL", - TTL: math.MaxUint8, - expectPacketForwarded: true, - sourceAddr: remoteIPv6Addr1, - destAddr: remoteIPv6Addr2, - }, - { - name: "Network unreachable", - TTL: 2, - expectErrorICMP: true, - sourceAddr: remoteIPv6Addr1, - destAddr: unreachableIPv6Addr, - icmpType: header.ICMPv6DstUnreachable, - icmpCode: header.ICMPv6NetworkUnreachable, - expectPacketUnrouteableError: true, - }, - { - name: "Multicast destination", - TTL: 2, - countUnrouteablePackets: 1, - sourceAddr: remoteIPv6Addr1, - destAddr: multicastIPv6Addr.Address, - expectPacketForwarded: true, - }, - { - name: "Link local destination", - TTL: 2, - sourceAddr: remoteIPv6Addr1, - destAddr: linkLocalIPv6Addr, - expectLinkLocalDestError: true, - }, - { - name: "Link local source", - TTL: 2, - sourceAddr: linkLocalIPv6Addr, - destAddr: remoteIPv6Addr2, - expectLinkLocalSourceError: true, - }, - { - name: "Hopbyhop with unknown option skippable action", - TTL: 2, - sourceAddr: remoteIPv6Addr1, - destAddr: remoteIPv6Addr2, - extHdr: func(nextHdr uint8) ([]byte, uint8, checker.NetworkChecker) { - return []byte{ - nextHdr, 1, - - // Skippable unknown. - 63, 4, 1, 2, 3, 4, - - // Skippable unknown. - 62, 6, 1, 2, 3, 4, 5, 6, - }, hopByHopExtHdrID, checker.IPv6ExtHdr(checker.IPv6HopByHopExtensionHeader(checker.IPv6UnknownOption(), checker.IPv6UnknownOption())) - }, - expectPacketForwarded: true, - }, - { - name: "Hopbyhop with unknown option discard action", - TTL: 2, - sourceAddr: remoteIPv6Addr1, - destAddr: remoteIPv6Addr2, - extHdr: func(nextHdr uint8) ([]byte, uint8, checker.NetworkChecker) { - return []byte{ - nextHdr, 1, - - // Skippable unknown. - 63, 4, 1, 2, 3, 4, - - // Discard unknown. - 127, 6, 1, 2, 3, 4, 5, 6, - }, hopByHopExtHdrID, nil - }, - expectExtensionHeaderError: true, - }, - { - name: "Hopbyhop with unknown option discard and send icmp action (unicast)", - TTL: 2, - sourceAddr: remoteIPv6Addr1, - destAddr: remoteIPv6Addr2, - extHdr: func(nextHdr uint8) ([]byte, uint8, checker.NetworkChecker) { - return []byte{ - nextHdr, 1, - - // Skippable unknown. - 63, 4, 1, 2, 3, 4, - - // Discard & send ICMP if option is unknown. - 191, 6, 1, 2, 3, 4, 5, 6, - }, hopByHopExtHdrID, nil - }, - expectErrorICMP: true, - icmpType: header.ICMPv6ParamProblem, - icmpCode: header.ICMPv6UnknownOption, - expectExtensionHeaderError: true, - }, - { - name: "Hopbyhop with unknown option discard and send icmp action (multicast)", - TTL: 2, - sourceAddr: remoteIPv6Addr1, - destAddr: multicastIPv6Addr.Address, - extHdr: func(nextHdr uint8) ([]byte, uint8, checker.NetworkChecker) { - return []byte{ - nextHdr, 1, - - // Skippable unknown. - 63, 4, 1, 2, 3, 4, - - // Discard & send ICMP if option is unknown. - 191, 6, 1, 2, 3, 4, 5, 6, - }, hopByHopExtHdrID, nil - }, - expectErrorICMP: true, - icmpType: header.ICMPv6ParamProblem, - icmpCode: header.ICMPv6UnknownOption, - expectExtensionHeaderError: true, - }, - { - name: "Hopbyhop with unknown option discard and send icmp action unless multicast dest (unicast)", - TTL: 2, - sourceAddr: remoteIPv6Addr1, - destAddr: remoteIPv6Addr2, - extHdr: func(nextHdr uint8) ([]byte, uint8, checker.NetworkChecker) { - return []byte{ - nextHdr, 1, - - // Skippable unknown. - 63, 4, 1, 2, 3, 4, - - // Discard & send ICMP unless packet is for multicast destination if - // option is unknown. - 255, 6, 1, 2, 3, 4, 5, 6, - }, hopByHopExtHdrID, nil - }, - expectErrorICMP: true, - icmpType: header.ICMPv6ParamProblem, - icmpCode: header.ICMPv6UnknownOption, - expectExtensionHeaderError: true, - }, - { - name: "Hopbyhop with unknown option discard and send icmp action unless multicast dest (multicast)", - TTL: 2, - sourceAddr: remoteIPv6Addr1, - destAddr: multicastIPv6Addr.Address, - extHdr: func(nextHdr uint8) ([]byte, uint8, checker.NetworkChecker) { - return []byte{ - nextHdr, 1, - - // Skippable unknown. - 63, 4, 1, 2, 3, 4, - - // Discard & send ICMP unless packet is for multicast destination if - // option is unknown. - 255, 6, 1, 2, 3, 4, 5, 6, - }, hopByHopExtHdrID, nil - }, - expectExtensionHeaderError: true, - }, - { - name: "Hopbyhop with router alert option", - TTL: 2, - sourceAddr: remoteIPv6Addr1, - destAddr: remoteIPv6Addr2, - extHdr: func(nextHdr uint8) ([]byte, uint8, checker.NetworkChecker) { - return []byte{ - nextHdr, 0, - - // Router Alert option. - 5, 2, 0, 0, 0, 0, - }, hopByHopExtHdrID, checker.IPv6ExtHdr(checker.IPv6HopByHopExtensionHeader(checker.IPv6RouterAlert(header.IPv6RouterAlertMLD))) - }, - expectPacketForwarded: true, - }, - { - name: "Hopbyhop with two router alert options", - TTL: 2, - sourceAddr: remoteIPv6Addr1, - destAddr: remoteIPv6Addr2, - extHdr: func(nextHdr uint8) ([]byte, uint8, checker.NetworkChecker) { - return []byte{ - nextHdr, 1, - - // Router Alert option. - 5, 2, 0, 0, 0, 0, - - // Router Alert option. - 5, 2, 0, 0, 0, 0, - }, hopByHopExtHdrID, nil - }, - expectExtensionHeaderError: true, - }, - { - name: "Can't fragment", - TTL: 2, - payloadLength: header.IPv6MinimumMTU + 1, - expectErrorICMP: true, - sourceAddr: remoteIPv6Addr1, - destAddr: remoteIPv6Addr2, - icmpType: header.ICMPv6PacketTooBig, - icmpCode: header.ICMPv6UnusedCode, - }, - { - name: "Can't fragment multicast", - TTL: 2, - payloadLength: header.IPv6MinimumMTU + 1, - sourceAddr: remoteIPv6Addr1, - destAddr: multicastIPv6Addr.Address, - expectErrorICMP: true, - icmpType: header.ICMPv6PacketTooBig, - icmpCode: header.ICMPv6UnusedCode, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6}, - }) - // We expect at most a single packet in response to our ICMP Echo Request. - incomingEndpoint := channel.New(1, header.IPv6MinimumMTU, "") - if err := s.CreateNIC(incomingNICID, incomingEndpoint); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", incomingNICID, err) - } - incomingIPv6ProtoAddr := tcpip.ProtocolAddress{Protocol: ProtocolNumber, AddressWithPrefix: incomingIPv6Addr} - if err := s.AddProtocolAddress(incomingNICID, incomingIPv6ProtoAddr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", incomingNICID, incomingIPv6ProtoAddr, err) - } - - outgoingEndpoint := channel.New(1, header.IPv6MinimumMTU, "") - if err := s.CreateNIC(outgoingNICID, outgoingEndpoint); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", outgoingNICID, err) - } - outgoingIPv6ProtoAddr := tcpip.ProtocolAddress{Protocol: ProtocolNumber, AddressWithPrefix: outgoingIPv6Addr} - if err := s.AddProtocolAddress(outgoingNICID, outgoingIPv6ProtoAddr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", outgoingNICID, outgoingIPv6ProtoAddr, err) - } - - s.SetRouteTable([]tcpip.Route{ - { - Destination: incomingIPv6Addr.Subnet(), - NIC: incomingNICID, - }, - { - Destination: outgoingIPv6Addr.Subnet(), - NIC: outgoingNICID, - }, - { - Destination: multicastIPv6Addr.Subnet(), - NIC: outgoingNICID, - }, - }) - - if err := s.SetForwardingDefaultAndAllNICs(ProtocolNumber, true); err != nil { - t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", ProtocolNumber, err) - } - - transportProtocol := header.ICMPv6ProtocolNumber - var extHdrBytes []byte - extHdrChecker := checker.IPv6ExtHdr() - if test.extHdr != nil { - nextHdrID := hopByHopExtHdrID - extHdrBytes, nextHdrID, extHdrChecker = test.extHdr(uint8(header.ICMPv6ProtocolNumber)) - transportProtocol = tcpip.TransportProtocolNumber(nextHdrID) - } - extHdrLen := len(extHdrBytes) - - ipHeaderLength := header.IPv6MinimumSize - icmpHeaderLength := header.ICMPv6MinimumSize - totalLength := ipHeaderLength + icmpHeaderLength + test.payloadLength + extHdrLen - hdr := buffer.NewPrependable(totalLength) - hdr.Prepend(test.payloadLength) - icmpH := header.ICMPv6(hdr.Prepend(icmpHeaderLength)) - - icmpH.SetIdent(randomIdent) - icmpH.SetSequence(randomSequence) - icmpH.SetType(header.ICMPv6EchoRequest) - icmpH.SetCode(header.ICMPv6UnusedCode) - icmpH.SetChecksum(0) - icmpH.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ - Header: icmpH, - Src: test.sourceAddr, - Dst: test.destAddr, - })) - copy(hdr.Prepend(extHdrLen), extHdrBytes) - ip := header.IPv6(hdr.Prepend(ipHeaderLength)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(header.ICMPv6MinimumSize + test.payloadLength), - TransportProtocol: transportProtocol, - HopLimit: test.TTL, - SrcAddr: test.sourceAddr, - DstAddr: test.destAddr, - }) - requestPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - }) - incomingEndpoint.InjectInbound(ProtocolNumber, requestPkt) - - reply, ok := incomingEndpoint.Read() - - if test.expectErrorICMP { - if !ok { - t.Fatalf("expected ICMP packet type %d through incoming NIC", test.icmpType) - } - - // As per RFC 4443, page 9: - // - // The returned ICMP packet will contain as much of invoking packet - // as possible without the ICMPv6 packet exceeding the minimum IPv6 - // MTU. - expectedICMPPayloadLength := func() int { - maxICMPPayloadLength := header.IPv6MinimumMTU - ipHeaderLength - icmpHeaderLength - if len(hdr.View()) > maxICMPPayloadLength { - return maxICMPPayloadLength - } - return len(hdr.View()) - } - - checker.IPv6(t, stack.PayloadSince(reply.Pkt.NetworkHeader()), - checker.SrcAddr(incomingIPv6Addr.Address), - checker.DstAddr(test.sourceAddr), - checker.TTL(DefaultTTL), - checker.ICMPv6( - checker.ICMPv6Type(test.icmpType), - checker.ICMPv6Code(test.icmpCode), - checker.ICMPv6Payload(hdr.View()[:expectedICMPPayloadLength()]), - ), - ) - - if n := outgoingEndpoint.Drain(); n != 0 { - t.Fatalf("got e2.Drain() = %d, want = 0", n) - } - } else if ok { - t.Fatalf("expected no ICMP packet through incoming NIC, instead found: %#v", reply) - } - - reply, ok = outgoingEndpoint.Read() - if test.expectPacketForwarded { - if !ok { - t.Fatal("expected ICMP Echo Request packet through outgoing NIC") - } - - checker.IPv6WithExtHdr(t, stack.PayloadSince(reply.Pkt.NetworkHeader()), - checker.SrcAddr(test.sourceAddr), - checker.DstAddr(test.destAddr), - checker.TTL(test.TTL-1), - extHdrChecker, - checker.ICMPv6( - checker.ICMPv6Type(header.ICMPv6EchoRequest), - checker.ICMPv6Code(header.ICMPv6UnusedCode), - checker.ICMPv6Payload(nil), - ), - ) - - if n := incomingEndpoint.Drain(); n != 0 { - t.Fatalf("got e1.Drain() = %d, want = 0", n) - } - } else if ok { - t.Fatalf("expected no ICMP Echo packet through outgoing NIC, instead found: %#v", reply) - } - - boolToInt := func(val bool) uint64 { - if val { - return 1 - } - return 0 - } - - if got, want := s.Stats().IP.Forwarding.LinkLocalSource.Value(), boolToInt(test.expectLinkLocalSourceError); got != want { - t.Errorf("got s.Stats().IP.Forwarding.LinkLocalSource.Value() = %d, want = %d", got, want) - } - - if got, want := s.Stats().IP.Forwarding.LinkLocalDestination.Value(), boolToInt(test.expectLinkLocalDestError); got != want { - t.Errorf("got s.Stats().IP.Forwarding.LinkLocalDestination.Value() = %d, want = %d", got, want) - } - - if got, want := s.Stats().IP.Forwarding.ExhaustedTTL.Value(), boolToInt(test.TTL <= 1); got != want { - t.Errorf("got rt.Stats().IP.Forwarding.ExhaustedTTL.Value() = %d, want = %d", got, want) - } - - if got, want := s.Stats().IP.Forwarding.Unrouteable.Value(), boolToInt(test.expectPacketUnrouteableError); got != want { - t.Errorf("got s.Stats().IP.Forwarding.Unrouteable.Value() = %d, want = %d", got, want) - } - - if got, want := s.Stats().IP.Forwarding.Errors.Value(), boolToInt(!test.expectPacketForwarded); got != want { - t.Errorf("got s.Stats().IP.Forwarding.Errors.Value() = %d, want = %d", got, want) - } - - if got, want := s.Stats().IP.Forwarding.ExtensionHeaderProblem.Value(), boolToInt(test.expectExtensionHeaderError); got != want { - t.Errorf("got s.Stats().IP.Forwarding.ExtensionHeaderProblem.Value() = %d, want = %d", got, want) - } - - if got, want := s.Stats().IP.Forwarding.PacketTooBig.Value(), boolToInt(test.icmpType == header.ICMPv6PacketTooBig); got != want { - t.Errorf("got s.Stats().IP.Forwarding.PacketTooBig.Value() = %d, want = %d", got, want) - } - }) - } -} - -func TestMultiCounterStatsInitialization(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - }) - proto := s.NetworkProtocolInstance(ProtocolNumber).(*protocol) - var nic testInterface - ep := proto.NewEndpoint(&nic, nil).(*endpoint) - // At this point, the Stack's stats and the NetworkEndpoint's stats are - // supposed to be bound. - refStack := s.Stats() - refEP := ep.stats.localStats - if err := testutil.ValidateMultiCounterStats(reflect.ValueOf(&ep.stats.ip).Elem(), []reflect.Value{reflect.ValueOf(&refStack.IP).Elem(), reflect.ValueOf(&refEP.IP).Elem()}); err != nil { - t.Error(err) - } - if err := testutil.ValidateMultiCounterStats(reflect.ValueOf(&ep.stats.icmp).Elem(), []reflect.Value{reflect.ValueOf(&refStack.ICMP.V6).Elem(), reflect.ValueOf(&refEP.ICMP).Elem()}); err != nil { - t.Error(err) - } -} diff --git a/pkg/tcpip/network/ipv6/mld_test.go b/pkg/tcpip/network/ipv6/mld_test.go deleted file mode 100644 index bc9cf6999..000000000 --- a/pkg/tcpip/network/ipv6/mld_test.go +++ /dev/null @@ -1,603 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ipv6_test - -import ( - "bytes" - "math/rand" - "testing" - "time" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/checker" - "gvisor.dev/gvisor/pkg/tcpip/faketime" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/testutil" -) - -var ( - linkLocalAddr = testutil.MustParse6("fe80::1") - globalAddr = testutil.MustParse6("a80::1") - globalMulticastAddr = testutil.MustParse6("ff05:100::2") - - linkLocalAddrSNMC = header.SolicitedNodeAddr(linkLocalAddr) - globalAddrSNMC = header.SolicitedNodeAddr(globalAddr) -) - -func validateMLDPacket(t *testing.T, p buffer.View, localAddress, remoteAddress tcpip.Address, mldType header.ICMPv6Type, groupAddress tcpip.Address) { - t.Helper() - - checker.IPv6WithExtHdr(t, p, - checker.IPv6ExtHdr( - checker.IPv6HopByHopExtensionHeader(checker.IPv6RouterAlert(header.IPv6RouterAlertMLD)), - ), - checker.SrcAddr(localAddress), - checker.DstAddr(remoteAddress), - checker.TTL(header.MLDHopLimit), - checker.MLD(mldType, header.MLDMinimumSize, - checker.MLDMaxRespDelay(0), - checker.MLDMulticastAddress(groupAddress), - ), - ) -} - -func TestIPv6JoinLeaveSolicitedNodeAddressPerformsMLD(t *testing.T) { - const nicID = 1 - - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - MLD: ipv6.MLDOptions{ - Enabled: true, - }, - })}, - }) - e := channel.New(1, header.IPv6MinimumMTU, "") - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID, err) - } - - // The stack will join an address's solicited node multicast address when - // an address is added. An MLD report message should be sent for the - // solicited-node group. - if err := s.AddAddress(nicID, ipv6.ProtocolNumber, linkLocalAddr); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ipv6.ProtocolNumber, linkLocalAddr, err) - } - if p, ok := e.Read(); !ok { - t.Fatal("expected a report message to be sent") - } else { - validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), linkLocalAddr, linkLocalAddrSNMC, header.ICMPv6MulticastListenerReport, linkLocalAddrSNMC) - } - - // The stack will leave an address's solicited node multicast address when - // an address is removed. An MLD done message should be sent for the - // solicited-node group. - if err := s.RemoveAddress(nicID, linkLocalAddr); err != nil { - t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, linkLocalAddr, err) - } - if p, ok := e.Read(); !ok { - t.Fatal("expected a done message to be sent") - } else { - validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), header.IPv6Any, header.IPv6AllRoutersLinkLocalMulticastAddress, header.ICMPv6MulticastListenerDone, linkLocalAddrSNMC) - } -} - -func TestSendQueuedMLDReports(t *testing.T) { - const ( - nicID = 1 - maxReports = 2 - ) - - tests := []struct { - name string - dadTransmits uint8 - retransmitTimer time.Duration - }{ - { - name: "DAD Disabled", - dadTransmits: 0, - retransmitTimer: 0, - }, - { - name: "DAD Enabled", - dadTransmits: 1, - retransmitTimer: time.Second, - }, - } - - nonce := [...]byte{ - 1, 2, 3, 4, 5, 6, - } - - const maxNSMessages = 2 - secureRNGBytes := make([]byte, len(nonce)*maxNSMessages) - for b := secureRNGBytes[:]; len(b) > 0; b = b[len(nonce):] { - if n := copy(b, nonce[:]); n != len(nonce) { - t.Fatalf("got copy(...) = %d, want = %d", n, len(nonce)) - } - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - dadResolutionTime := test.retransmitTimer * time.Duration(test.dadTransmits) - clock := faketime.NewManualClock() - var secureRNG bytes.Reader - secureRNG.Reset(secureRNGBytes[:]) - s := stack.New(stack.Options{ - SecureRNG: &secureRNG, - RandSource: rand.NewSource(time.Now().UnixNano()), - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - DADConfigs: stack.DADConfigurations{ - DupAddrDetectTransmits: test.dadTransmits, - RetransmitTimer: test.retransmitTimer, - }, - MLD: ipv6.MLDOptions{ - Enabled: true, - }, - })}, - Clock: clock, - }) - - // Allow space for an extra packet so we can observe packets that were - // unexpectedly sent. - e := channel.New(maxReports+int(test.dadTransmits)+1 /* extra */, header.IPv6MinimumMTU, "") - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID, err) - } - - resolveDAD := func(addr, snmc tcpip.Address) { - clock.Advance(dadResolutionTime) - if p, ok := e.Read(); !ok { - t.Fatal("expected DAD packet") - } else { - checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), - checker.SrcAddr(header.IPv6Any), - checker.DstAddr(snmc), - checker.TTL(header.NDPHopLimit), - checker.NDPNS( - checker.NDPNSTargetAddress(addr), - checker.NDPNSOptions([]header.NDPOption{header.NDPNonceOption(nonce[:])}), - )) - } - } - - var reportCounter uint64 - reportStat := s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport - if got := reportStat.Value(); got != reportCounter { - t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter) - } - var doneCounter uint64 - doneStat := s.Stats().ICMP.V6.PacketsSent.MulticastListenerDone - if got := doneStat.Value(); got != doneCounter { - t.Errorf("got doneStat.Value() = %d, want = %d", got, doneCounter) - } - - // Joining a group without an assigned address should send an MLD report - // with the unspecified address. - if err := s.JoinGroup(ipv6.ProtocolNumber, nicID, globalMulticastAddr); err != nil { - t.Fatalf("JoinGroup(%d, %d, %s): %s", ipv6.ProtocolNumber, nicID, globalMulticastAddr, err) - } - reportCounter++ - if got := reportStat.Value(); got != reportCounter { - t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter) - } - if p, ok := e.Read(); !ok { - t.Errorf("expected MLD report for %s", globalMulticastAddr) - } else { - validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), header.IPv6Any, globalMulticastAddr, header.ICMPv6MulticastListenerReport, globalMulticastAddr) - } - clock.Advance(time.Hour) - if p, ok := e.Read(); ok { - t.Errorf("got unexpected packet = %#v", p) - } - if t.Failed() { - t.FailNow() - } - - // Adding a global address should not send reports for the already joined - // group since we should only send queued reports when a link-local - // address is assigned. - // - // Note, we will still expect to send a report for the global address's - // solicited node address from the unspecified address as per RFC 3590 - // section 4. - if err := s.AddAddressWithOptions(nicID, ipv6.ProtocolNumber, globalAddr, stack.FirstPrimaryEndpoint); err != nil { - t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, ipv6.ProtocolNumber, globalAddr, stack.FirstPrimaryEndpoint, err) - } - reportCounter++ - if got := reportStat.Value(); got != reportCounter { - t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter) - } - if p, ok := e.Read(); !ok { - t.Errorf("expected MLD report for %s", globalAddrSNMC) - } else { - validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), header.IPv6Any, globalAddrSNMC, header.ICMPv6MulticastListenerReport, globalAddrSNMC) - } - if dadResolutionTime != 0 { - // Reports should not be sent when the address resolves. - resolveDAD(globalAddr, globalAddrSNMC) - if got := reportStat.Value(); got != reportCounter { - t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter) - } - } - // Leave the group since we don't care about the global address's - // solicited node multicast group membership. - if err := s.LeaveGroup(ipv6.ProtocolNumber, nicID, globalAddrSNMC); err != nil { - t.Fatalf("LeaveGroup(%d, %d, %s): %s", ipv6.ProtocolNumber, nicID, globalAddrSNMC, err) - } - if got := doneStat.Value(); got != doneCounter { - t.Errorf("got doneStat.Value() = %d, want = %d", got, doneCounter) - } - if p, ok := e.Read(); ok { - t.Errorf("got unexpected packet = %#v", p) - } - if t.Failed() { - t.FailNow() - } - - // Adding a link-local address should send a report for its solicited node - // address and globalMulticastAddr. - if err := s.AddAddressWithOptions(nicID, ipv6.ProtocolNumber, linkLocalAddr, stack.CanBePrimaryEndpoint); err != nil { - t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, ipv6.ProtocolNumber, linkLocalAddr, stack.CanBePrimaryEndpoint, err) - } - if dadResolutionTime != 0 { - reportCounter++ - if got := reportStat.Value(); got != reportCounter { - t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter) - } - if p, ok := e.Read(); !ok { - t.Errorf("expected MLD report for %s", linkLocalAddrSNMC) - } else { - validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), header.IPv6Any, linkLocalAddrSNMC, header.ICMPv6MulticastListenerReport, linkLocalAddrSNMC) - } - resolveDAD(linkLocalAddr, linkLocalAddrSNMC) - } - - // We expect two batches of reports to be sent (1 batch when the - // link-local address is assigned, and another after the maximum - // unsolicited report interval. - for i := 0; i < 2; i++ { - // We expect reports to be sent (one for globalMulticastAddr and another - // for linkLocalAddrSNMC). - reportCounter += maxReports - if got := reportStat.Value(); got != reportCounter { - t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter) - } - - addrs := map[tcpip.Address]bool{ - globalMulticastAddr: false, - linkLocalAddrSNMC: false, - } - for range addrs { - p, ok := e.Read() - if !ok { - t.Fatalf("expected MLD report for %s and %s; addrs = %#v", globalMulticastAddr, linkLocalAddrSNMC, addrs) - } - - addr := header.IPv6(stack.PayloadSince(p.Pkt.NetworkHeader())).DestinationAddress() - if seen, ok := addrs[addr]; !ok { - t.Fatalf("got unexpected packet destined to %s", addr) - } else if seen { - t.Fatalf("got another packet destined to %s", addr) - } - - addrs[addr] = true - validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), linkLocalAddr, addr, header.ICMPv6MulticastListenerReport, addr) - - clock.Advance(ipv6.UnsolicitedReportIntervalMax) - } - } - - // Should not send any more reports. - clock.Advance(time.Hour) - if p, ok := e.Read(); ok { - t.Errorf("got unexpected packet = %#v", p) - } - }) - } -} - -// createAndInjectMLDPacket creates and injects an MLD packet with the -// specified fields. -func createAndInjectMLDPacket(e *channel.Endpoint, mldType header.ICMPv6Type, hopLimit uint8, srcAddress tcpip.Address, withRouterAlertOption bool, routerAlertValue header.IPv6RouterAlertValue) { - var extensionHeaders header.IPv6ExtHdrSerializer - if withRouterAlertOption { - extensionHeaders = header.IPv6ExtHdrSerializer{ - header.IPv6SerializableHopByHopExtHdr{ - &header.IPv6RouterAlertOption{Value: routerAlertValue}, - }, - } - } - - extensionHeadersLength := extensionHeaders.Length() - payloadLength := extensionHeadersLength + header.ICMPv6HeaderSize + header.MLDMinimumSize - buf := buffer.NewView(header.IPv6MinimumSize + payloadLength) - - ip := header.IPv6(buf) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - HopLimit: hopLimit, - TransportProtocol: header.ICMPv6ProtocolNumber, - SrcAddr: srcAddress, - DstAddr: header.IPv6AllNodesMulticastAddress, - ExtensionHeaders: extensionHeaders, - }) - - icmp := header.ICMPv6(ip.Payload()[extensionHeadersLength:]) - icmp.SetType(mldType) - mld := header.MLD(icmp.MessageBody()) - mld.SetMaximumResponseDelay(0) - mld.SetMulticastAddress(header.IPv6Any) - icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ - Header: icmp, - Src: srcAddress, - Dst: header.IPv6AllNodesMulticastAddress, - })) - - e.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) -} - -func TestMLDPacketValidation(t *testing.T) { - const nicID = 1 - linkLocalAddr2 := testutil.MustParse6("fe80::2") - - tests := []struct { - name string - messageType header.ICMPv6Type - srcAddr tcpip.Address - includeRouterAlertOption bool - routerAlertValue header.IPv6RouterAlertValue - hopLimit uint8 - expectValidMLD bool - getMessageTypeStatValue func(tcpip.Stats) uint64 - }{ - { - name: "valid", - messageType: header.ICMPv6MulticastListenerQuery, - includeRouterAlertOption: true, - routerAlertValue: header.IPv6RouterAlertMLD, - srcAddr: linkLocalAddr2, - hopLimit: header.MLDHopLimit, - expectValidMLD: true, - getMessageTypeStatValue: func(stats tcpip.Stats) uint64 { return stats.ICMP.V6.PacketsReceived.MulticastListenerQuery.Value() }, - }, - { - name: "bad hop limit", - messageType: header.ICMPv6MulticastListenerReport, - includeRouterAlertOption: true, - routerAlertValue: header.IPv6RouterAlertMLD, - srcAddr: linkLocalAddr2, - hopLimit: header.MLDHopLimit + 1, - expectValidMLD: false, - getMessageTypeStatValue: func(stats tcpip.Stats) uint64 { return stats.ICMP.V6.PacketsReceived.MulticastListenerReport.Value() }, - }, - { - name: "src ip not link local", - messageType: header.ICMPv6MulticastListenerReport, - includeRouterAlertOption: true, - routerAlertValue: header.IPv6RouterAlertMLD, - srcAddr: globalAddr, - hopLimit: header.MLDHopLimit, - expectValidMLD: false, - getMessageTypeStatValue: func(stats tcpip.Stats) uint64 { return stats.ICMP.V6.PacketsReceived.MulticastListenerReport.Value() }, - }, - { - name: "missing router alert ip option", - messageType: header.ICMPv6MulticastListenerDone, - includeRouterAlertOption: false, - srcAddr: linkLocalAddr2, - hopLimit: header.MLDHopLimit, - expectValidMLD: false, - getMessageTypeStatValue: func(stats tcpip.Stats) uint64 { return stats.ICMP.V6.PacketsReceived.MulticastListenerDone.Value() }, - }, - { - name: "incorrect router alert value", - messageType: header.ICMPv6MulticastListenerDone, - includeRouterAlertOption: true, - routerAlertValue: header.IPv6RouterAlertRSVP, - srcAddr: linkLocalAddr2, - hopLimit: header.MLDHopLimit, - expectValidMLD: false, - getMessageTypeStatValue: func(stats tcpip.Stats) uint64 { return stats.ICMP.V6.PacketsReceived.MulticastListenerDone.Value() }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - MLD: ipv6.MLDOptions{ - Enabled: true, - }, - })}, - }) - e := channel.New(nicID, header.IPv6MinimumMTU, "") - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID, err) - } - stats := s.Stats() - // Verify that every relevant stats is zero'd before we send a packet. - if got := test.getMessageTypeStatValue(s.Stats()); got != 0 { - t.Errorf("got test.getMessageTypeStatValue(s.Stats()) = %d, want = 0", got) - } - if got := stats.ICMP.V6.PacketsReceived.Invalid.Value(); got != 0 { - t.Errorf("got stats.ICMP.V6.PacketsReceived.Invalid.Value() = %d, want = 0", got) - } - if got := stats.IP.PacketsDelivered.Value(); got != 0 { - t.Fatalf("got stats.IP.PacketsDelivered.Value() = %d, want = 0", got) - } - createAndInjectMLDPacket(e, test.messageType, test.hopLimit, test.srcAddr, test.includeRouterAlertOption, test.routerAlertValue) - // We always expect the packet to pass IP validation. - if got := stats.IP.PacketsDelivered.Value(); got != 1 { - t.Fatalf("got stats.IP.PacketsDelivered.Value() = %d, want = 1", got) - } - // Even when the MLD-specific validation checks fail, we expect the - // corresponding MLD counter to be incremented. - if got := test.getMessageTypeStatValue(s.Stats()); got != 1 { - t.Errorf("got test.getMessageTypeStatValue(s.Stats()) = %d, want = 1", got) - } - var expectedInvalidCount uint64 - if !test.expectValidMLD { - expectedInvalidCount = 1 - } - if got := stats.ICMP.V6.PacketsReceived.Invalid.Value(); got != expectedInvalidCount { - t.Errorf("got stats.ICMP.V6.PacketsReceived.Invalid.Value() = %d, want = %d", got, expectedInvalidCount) - } - }) - } -} - -func TestMLDSkipProtocol(t *testing.T) { - const nicID = 1 - - tests := []struct { - name string - group tcpip.Address - expectReport bool - }{ - { - name: "Reserverd0", - group: "\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11", - expectReport: false, - }, - { - name: "Interface Local", - group: "\xff\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11", - expectReport: false, - }, - { - name: "Link Local", - group: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11", - expectReport: true, - }, - { - name: "Realm Local", - group: "\xff\x03\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11", - expectReport: true, - }, - { - name: "Admin Local", - group: "\xff\x04\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11", - expectReport: true, - }, - { - name: "Site Local", - group: "\xff\x05\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11", - expectReport: true, - }, - { - name: "Unassigned(6)", - group: "\xff\x06\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11", - expectReport: true, - }, - { - name: "Unassigned(7)", - group: "\xff\x07\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11", - expectReport: true, - }, - { - name: "Organization Local", - group: "\xff\x08\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11", - expectReport: true, - }, - { - name: "Unassigned(9)", - group: "\xff\x09\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11", - expectReport: true, - }, - { - name: "Unassigned(A)", - group: "\xff\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11", - expectReport: true, - }, - { - name: "Unassigned(B)", - group: "\xff\x0b\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11", - expectReport: true, - }, - { - name: "Unassigned(C)", - group: "\xff\x0c\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11", - expectReport: true, - }, - { - name: "Unassigned(D)", - group: "\xff\x0d\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11", - expectReport: true, - }, - { - name: "Global", - group: "\xff\x0e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11", - expectReport: true, - }, - { - name: "ReservedF", - group: "\xff\x0f\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11", - expectReport: true, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - MLD: ipv6.MLDOptions{ - Enabled: true, - }, - })}, - }) - e := channel.New(1, header.IPv6MinimumMTU, "") - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID, err) - } - if err := s.AddAddress(nicID, ipv6.ProtocolNumber, linkLocalAddr); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ipv6.ProtocolNumber, linkLocalAddr, err) - } - if p, ok := e.Read(); !ok { - t.Fatal("expected a report message to be sent") - } else { - validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), linkLocalAddr, linkLocalAddrSNMC, header.ICMPv6MulticastListenerReport, linkLocalAddrSNMC) - } - - if err := s.JoinGroup(ipv6.ProtocolNumber, nicID, test.group); err != nil { - t.Fatalf("s.JoinGroup(%d, %d, %s): %s", ipv6.ProtocolNumber, nicID, test.group, err) - } - if isInGroup, err := s.IsInGroup(nicID, test.group); err != nil { - t.Fatalf("IsInGroup(%d, %s): %s", nicID, test.group, err) - } else if !isInGroup { - t.Fatalf("got IsInGroup(%d, %s) = false, want = true", nicID, test.group) - } - - if !test.expectReport { - if p, ok := e.Read(); ok { - t.Fatalf("got e.Read() = (%#v, true), want = (_, false)", p) - } - - return - } - - if p, ok := e.Read(); !ok { - t.Fatal("expected a report message to be sent") - } else { - validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), linkLocalAddr, test.group, header.ICMPv6MulticastListenerReport, test.group) - } - }) - } -} diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go deleted file mode 100644 index f0186c64e..000000000 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ /dev/null @@ -1,1341 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ipv6 - -import ( - "bytes" - "math/rand" - "strings" - "testing" - "time" - - "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/checker" - "gvisor.dev/gvisor/pkg/tcpip/faketime" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" -) - -var _ NDPDispatcher = (*testNDPDispatcher)(nil) - -// testNDPDispatcher is an NDPDispatcher only allows default router discovery. -type testNDPDispatcher struct { - addr tcpip.Address -} - -func (*testNDPDispatcher) OnDuplicateAddressDetectionResult(tcpip.NICID, tcpip.Address, stack.DADResult) { -} - -func (t *testNDPDispatcher) OnOffLinkRouteUpdated(_ tcpip.NICID, _ tcpip.Subnet, addr tcpip.Address, _ header.NDPRoutePreference) { - t.addr = addr -} - -func (t *testNDPDispatcher) OnOffLinkRouteInvalidated(_ tcpip.NICID, _ tcpip.Subnet, addr tcpip.Address) { - t.addr = addr -} - -func (*testNDPDispatcher) OnOnLinkPrefixDiscovered(tcpip.NICID, tcpip.Subnet) { -} - -func (*testNDPDispatcher) OnOnLinkPrefixInvalidated(tcpip.NICID, tcpip.Subnet) { -} - -func (*testNDPDispatcher) OnAutoGenAddress(tcpip.NICID, tcpip.AddressWithPrefix) { -} - -func (*testNDPDispatcher) OnAutoGenAddressDeprecated(tcpip.NICID, tcpip.AddressWithPrefix) { -} - -func (*testNDPDispatcher) OnAutoGenAddressInvalidated(tcpip.NICID, tcpip.AddressWithPrefix) { -} - -func (*testNDPDispatcher) OnRecursiveDNSServerOption(tcpip.NICID, []tcpip.Address, time.Duration) { -} - -func (*testNDPDispatcher) OnDNSSearchListOption(tcpip.NICID, []string, time.Duration) { -} - -func (*testNDPDispatcher) OnDHCPv6Configuration(tcpip.NICID, DHCPv6ConfigurationFromNDPRA) { -} - -func TestStackNDPEndpointInvalidateDefaultRouter(t *testing.T) { - var ndpDisp testNDPDispatcher - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocolWithOptions(Options{ - NDPDisp: &ndpDisp, - })}, - }) - - if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil { - t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) - } - - ep, err := s.GetNetworkEndpoint(nicID, ProtocolNumber) - if err != nil { - t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, ProtocolNumber, err) - } - - ipv6EP := ep.(*endpoint) - ipv6EP.mu.Lock() - ipv6EP.mu.ndp.handleOffLinkRouteDiscovery(offLinkRoute{dest: header.IPv6EmptySubnet, router: lladdr1}, time.Hour, header.MediumRoutePreference) - ipv6EP.mu.Unlock() - - if ndpDisp.addr != lladdr1 { - t.Fatalf("got ndpDisp.addr = %s, want = %s", ndpDisp.addr, lladdr1) - } - - ndpDisp.addr = "" - ndpEP := ep.(stack.NDPEndpoint) - ndpEP.InvalidateDefaultRouter(lladdr1) - if ndpDisp.addr != lladdr1 { - t.Fatalf("got ndpDisp.addr = %s, want = %s", ndpDisp.addr, lladdr1) - } -} - -// TestNeighborSolicitationWithSourceLinkLayerOption tests that receiving a -// valid NDP NS message with the Source Link Layer Address option results in a -// new entry in the link address cache for the sender of the message. -func TestNeighborSolicitationWithSourceLinkLayerOption(t *testing.T) { - const nicID = 1 - - tests := []struct { - name string - optsBuf []byte - expectedLinkAddr tcpip.LinkAddress - }{ - { - name: "Valid", - optsBuf: []byte{1, 1, 2, 3, 4, 5, 6, 7}, - expectedLinkAddr: "\x02\x03\x04\x05\x06\x07", - }, - { - name: "Too Small", - optsBuf: []byte{1, 1, 2, 3, 4, 5, 6}, - }, - { - name: "Invalid Length", - optsBuf: []byte{1, 2, 2, 3, 4, 5, 6, 7}, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - }) - e := channel.New(0, 1280, linkAddr0) - e.LinkEPCapabilities |= stack.CapabilityResolutionRequired - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err) - } - - ndpNSSize := header.ICMPv6NeighborSolicitMinimumSize + len(test.optsBuf) - hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize) - pkt := header.ICMPv6(hdr.Prepend(ndpNSSize)) - pkt.SetType(header.ICMPv6NeighborSolicit) - ns := header.NDPNeighborSolicit(pkt.MessageBody()) - ns.SetTargetAddress(lladdr0) - opts := ns.Options() - copy(opts, test.optsBuf) - pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ - Header: pkt, - Src: lladdr1, - Dst: lladdr0, - })) - payloadLength := hdr.UsedLength() - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - TransportProtocol: header.ICMPv6ProtocolNumber, - HopLimit: 255, - SrcAddr: lladdr1, - DstAddr: lladdr0, - }) - - invalid := s.Stats().ICMP.V6.PacketsReceived.Invalid - - // Invalid count should initially be 0. - if got := invalid.Value(); got != 0 { - t.Fatalf("got invalid = %d, want = 0", got) - } - - e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - })) - - neighbors, err := s.Neighbors(nicID, ProtocolNumber) - if err != nil { - t.Fatalf("s.Neighbors(%d, %d): %s", nicID, ProtocolNumber, err) - } - - neighborByAddr := make(map[tcpip.Address]stack.NeighborEntry) - for _, n := range neighbors { - if existing, ok := neighborByAddr[n.Addr]; ok { - if diff := cmp.Diff(existing, n); diff != "" { - t.Fatalf("s.Neighbors(%d, %d) returned unexpected duplicate neighbor entry (-existing +got):\n%s", nicID, ProtocolNumber, diff) - } - t.Fatalf("s.Neighbors(%d, %d) returned unexpected duplicate neighbor entry: %#v", nicID, ProtocolNumber, existing) - } - neighborByAddr[n.Addr] = n - } - - if neigh, ok := neighborByAddr[lladdr1]; len(test.expectedLinkAddr) != 0 { - // Invalid count should not have increased. - if got := invalid.Value(); got != 0 { - t.Errorf("got invalid = %d, want = 0", got) - } - - if !ok { - t.Fatalf("expected a neighbor entry for %q", lladdr1) - } - if neigh.LinkAddr != test.expectedLinkAddr { - t.Errorf("got link address = %s, want = %s", neigh.LinkAddr, test.expectedLinkAddr) - } - if neigh.State != stack.Stale { - t.Errorf("got NUD state = %s, want = %s", neigh.State, stack.Stale) - } - } else { - // Invalid count should have increased. - if got := invalid.Value(); got != 1 { - t.Errorf("got invalid = %d, want = 1", got) - } - - if ok { - t.Fatalf("unexpectedly got neighbor entry: %#v", neigh) - } - } - }) - } -} - -func TestNeighborSolicitationResponse(t *testing.T) { - const nicID = 1 - nicAddr := lladdr0 - remoteAddr := lladdr1 - nicAddrSNMC := header.SolicitedNodeAddr(nicAddr) - nicLinkAddr := linkAddr0 - remoteLinkAddr0 := linkAddr1 - remoteLinkAddr1 := linkAddr2 - - tests := []struct { - name string - nsOpts header.NDPOptionsSerializer - nsSrcLinkAddr tcpip.LinkAddress - nsSrc tcpip.Address - nsDst tcpip.Address - nsInvalid bool - naDstLinkAddr tcpip.LinkAddress - naSolicited bool - naSrc tcpip.Address - naDst tcpip.Address - performsLinkResolution bool - }{ - { - name: "Unspecified source to solicited-node multicast destination", - nsOpts: nil, - nsSrcLinkAddr: remoteLinkAddr0, - nsSrc: header.IPv6Any, - nsDst: nicAddrSNMC, - nsInvalid: false, - naDstLinkAddr: header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllNodesMulticastAddress), - naSolicited: false, - naSrc: nicAddr, - naDst: header.IPv6AllNodesMulticastAddress, - }, - { - name: "Unspecified source with source ll option to multicast destination", - nsOpts: header.NDPOptionsSerializer{ - header.NDPSourceLinkLayerAddressOption(remoteLinkAddr0[:]), - }, - nsSrcLinkAddr: remoteLinkAddr0, - nsSrc: header.IPv6Any, - nsDst: nicAddrSNMC, - nsInvalid: true, - }, - { - name: "Unspecified source to unicast destination", - nsOpts: nil, - nsSrcLinkAddr: remoteLinkAddr0, - nsSrc: header.IPv6Any, - nsDst: nicAddr, - nsInvalid: true, - }, - { - name: "Unspecified source with source ll option to unicast destination", - nsOpts: header.NDPOptionsSerializer{ - header.NDPSourceLinkLayerAddressOption(remoteLinkAddr0[:]), - }, - nsSrcLinkAddr: remoteLinkAddr0, - nsSrc: header.IPv6Any, - nsDst: nicAddr, - nsInvalid: true, - }, - { - name: "Specified source with 1 source ll to multicast destination", - nsOpts: header.NDPOptionsSerializer{ - header.NDPSourceLinkLayerAddressOption(remoteLinkAddr0[:]), - }, - nsSrcLinkAddr: remoteLinkAddr0, - nsSrc: remoteAddr, - nsDst: nicAddrSNMC, - nsInvalid: false, - naDstLinkAddr: remoteLinkAddr0, - naSolicited: true, - naSrc: nicAddr, - naDst: remoteAddr, - }, - { - name: "Specified source with 1 source ll different from route to multicast destination", - nsOpts: header.NDPOptionsSerializer{ - header.NDPSourceLinkLayerAddressOption(remoteLinkAddr1[:]), - }, - nsSrcLinkAddr: remoteLinkAddr0, - nsSrc: remoteAddr, - nsDst: nicAddrSNMC, - nsInvalid: false, - naDstLinkAddr: remoteLinkAddr1, - naSolicited: true, - naSrc: nicAddr, - naDst: remoteAddr, - }, - { - name: "Specified source to multicast destination", - nsOpts: nil, - nsSrcLinkAddr: remoteLinkAddr0, - nsSrc: remoteAddr, - nsDst: nicAddrSNMC, - nsInvalid: true, - }, - { - name: "Specified source with 2 source ll to multicast destination", - nsOpts: header.NDPOptionsSerializer{ - header.NDPSourceLinkLayerAddressOption(remoteLinkAddr0[:]), - header.NDPSourceLinkLayerAddressOption(remoteLinkAddr1[:]), - }, - nsSrcLinkAddr: remoteLinkAddr0, - nsSrc: remoteAddr, - nsDst: nicAddrSNMC, - nsInvalid: true, - }, - - { - name: "Specified source to unicast destination", - nsOpts: nil, - nsSrcLinkAddr: remoteLinkAddr0, - nsSrc: remoteAddr, - nsDst: nicAddr, - nsInvalid: false, - naDstLinkAddr: remoteLinkAddr0, - naSolicited: true, - naSrc: nicAddr, - naDst: remoteAddr, - // Since we send a unicast solicitations to a node without an entry for - // the remote, the node needs to perform neighbor discovery to get the - // remote's link address to send the advertisement response. - performsLinkResolution: true, - }, - { - name: "Specified source with 1 source ll to unicast destination", - nsOpts: header.NDPOptionsSerializer{ - header.NDPSourceLinkLayerAddressOption(remoteLinkAddr0[:]), - }, - nsSrcLinkAddr: remoteLinkAddr0, - nsSrc: remoteAddr, - nsDst: nicAddr, - nsInvalid: false, - naDstLinkAddr: remoteLinkAddr0, - naSolicited: true, - naSrc: nicAddr, - naDst: remoteAddr, - }, - { - name: "Specified source with 1 source ll different from route to unicast destination", - nsOpts: header.NDPOptionsSerializer{ - header.NDPSourceLinkLayerAddressOption(remoteLinkAddr1[:]), - }, - nsSrcLinkAddr: remoteLinkAddr0, - nsSrc: remoteAddr, - nsDst: nicAddr, - nsInvalid: false, - naDstLinkAddr: remoteLinkAddr1, - naSolicited: true, - naSrc: nicAddr, - naDst: remoteAddr, - }, - { - name: "Specified source with 2 source ll to unicast destination", - nsOpts: header.NDPOptionsSerializer{ - header.NDPSourceLinkLayerAddressOption(remoteLinkAddr0[:]), - header.NDPSourceLinkLayerAddressOption(remoteLinkAddr1[:]), - }, - nsSrcLinkAddr: remoteLinkAddr0, - nsSrc: remoteAddr, - nsDst: nicAddr, - nsInvalid: true, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - clock := faketime.NewManualClock() - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - Clock: clock, - }) - e := channel.New(1, 1280, nicLinkAddr) - e.LinkEPCapabilities |= stack.CapabilityResolutionRequired - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - if err := s.AddAddress(nicID, ProtocolNumber, nicAddr); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, nicAddr, err) - } - - s.SetRouteTable([]tcpip.Route{ - { - Destination: header.IPv6EmptySubnet, - NIC: 1, - }, - }) - - ndpNSSize := header.ICMPv6NeighborSolicitMinimumSize + test.nsOpts.Length() - hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize) - pkt := header.ICMPv6(hdr.Prepend(ndpNSSize)) - pkt.SetType(header.ICMPv6NeighborSolicit) - ns := header.NDPNeighborSolicit(pkt.MessageBody()) - ns.SetTargetAddress(nicAddr) - opts := ns.Options() - opts.Serialize(test.nsOpts) - pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ - Header: pkt, - Src: test.nsSrc, - Dst: test.nsDst, - })) - payloadLength := hdr.UsedLength() - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - TransportProtocol: header.ICMPv6ProtocolNumber, - HopLimit: 255, - SrcAddr: test.nsSrc, - DstAddr: test.nsDst, - }) - - invalid := s.Stats().ICMP.V6.PacketsReceived.Invalid - - // Invalid count should initially be 0. - if got := invalid.Value(); got != 0 { - t.Fatalf("got invalid = %d, want = 0", got) - } - - e.InjectLinkAddr(ProtocolNumber, test.nsSrcLinkAddr, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - })) - - if test.nsInvalid { - if got := invalid.Value(); got != 1 { - t.Fatalf("got invalid = %d, want = 1", got) - } - - if p, got := e.Read(); got { - t.Fatalf("unexpected response to an invalid NS = %+v", p.Pkt) - } - - // If we expected the NS to be invalid, we have nothing else to check. - return - } - - if got := invalid.Value(); got != 0 { - t.Fatalf("got invalid = %d, want = 0", got) - } - - if test.performsLinkResolution { - clock.RunImmediatelyScheduledJobs() - p, got := e.Read() - if !got { - t.Fatal("expected an NDP NS response") - } - - respNSDst := header.SolicitedNodeAddr(test.nsSrc) - var want stack.RouteInfo - want.NetProto = ProtocolNumber - want.RemoteLinkAddress = header.EthernetAddressFromMulticastIPv6Address(respNSDst) - if diff := cmp.Diff(want, p.Route, cmp.AllowUnexported(want)); diff != "" { - t.Errorf("route info mismatch (-want +got):\n%s", diff) - } - - checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), - checker.SrcAddr(nicAddr), - checker.DstAddr(respNSDst), - checker.TTL(header.NDPHopLimit), - checker.NDPNS( - checker.NDPNSTargetAddress(test.nsSrc), - checker.NDPNSOptions([]header.NDPOption{ - header.NDPSourceLinkLayerAddressOption(nicLinkAddr), - }), - )) - - ser := header.NDPOptionsSerializer{ - header.NDPTargetLinkLayerAddressOption(linkAddr1), - } - ndpNASize := header.ICMPv6NeighborAdvertMinimumSize + ser.Length() - hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNASize) - pkt := header.ICMPv6(hdr.Prepend(ndpNASize)) - pkt.SetType(header.ICMPv6NeighborAdvert) - na := header.NDPNeighborAdvert(pkt.MessageBody()) - na.SetSolicitedFlag(true) - na.SetOverrideFlag(true) - na.SetTargetAddress(test.nsSrc) - na.Options().Serialize(ser) - pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ - Header: pkt, - Src: test.nsSrc, - Dst: nicAddr, - })) - payloadLength := hdr.UsedLength() - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - TransportProtocol: header.ICMPv6ProtocolNumber, - HopLimit: header.NDPHopLimit, - SrcAddr: test.nsSrc, - DstAddr: nicAddr, - }) - e.InjectLinkAddr(ProtocolNumber, "", stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - })) - } - - clock.RunImmediatelyScheduledJobs() - p, got := e.Read() - if !got { - t.Fatal("expected an NDP NA response") - } - - if p.Route.LocalAddress != test.naSrc { - t.Errorf("got p.Route.LocalAddress = %s, want = %s", p.Route.LocalAddress, test.naSrc) - } - if p.Route.LocalLinkAddress != nicLinkAddr { - t.Errorf("p.Route.LocalLinkAddress = %s, want = %s", p.Route.LocalLinkAddress, nicLinkAddr) - } - if p.Route.RemoteAddress != test.naDst { - t.Errorf("got p.Route.RemoteAddress = %s, want = %s", p.Route.RemoteAddress, test.naDst) - } - if p.Route.RemoteLinkAddress != test.naDstLinkAddr { - t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, test.naDstLinkAddr) - } - - checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), - checker.SrcAddr(test.naSrc), - checker.DstAddr(test.naDst), - checker.TTL(header.NDPHopLimit), - checker.NDPNA( - checker.NDPNASolicitedFlag(test.naSolicited), - checker.NDPNATargetAddress(nicAddr), - checker.NDPNAOptions([]header.NDPOption{ - header.NDPTargetLinkLayerAddressOption(nicLinkAddr[:]), - }), - )) - }) - } -} - -// TestNeighborAdvertisementWithTargetLinkLayerOption tests that receiving a -// valid NDP NA message with the Target Link Layer Address option does not -// result in a new entry in the neighbor cache for the target of the message. -func TestNeighborAdvertisementWithTargetLinkLayerOption(t *testing.T) { - const nicID = 1 - - tests := []struct { - name string - optsBuf []byte - isValid bool - }{ - { - name: "Valid", - optsBuf: []byte{2, 1, 2, 3, 4, 5, 6, 7}, - isValid: true, - }, - { - name: "Too Small", - optsBuf: []byte{2, 1, 2, 3, 4, 5, 6}, - }, - { - name: "Invalid Length", - optsBuf: []byte{2, 2, 2, 3, 4, 5, 6, 7}, - }, - { - name: "Multiple", - optsBuf: []byte{ - 2, 1, 2, 3, 4, 5, 6, 7, - 2, 1, 2, 3, 4, 5, 6, 8, - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - }) - e := channel.New(0, 1280, linkAddr0) - e.LinkEPCapabilities |= stack.CapabilityResolutionRequired - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err) - } - - ndpNASize := header.ICMPv6NeighborAdvertMinimumSize + len(test.optsBuf) - hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNASize) - pkt := header.ICMPv6(hdr.Prepend(ndpNASize)) - pkt.SetType(header.ICMPv6NeighborAdvert) - ns := header.NDPNeighborAdvert(pkt.MessageBody()) - ns.SetTargetAddress(lladdr1) - opts := ns.Options() - copy(opts, test.optsBuf) - pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ - Header: pkt, - Src: lladdr1, - Dst: lladdr0, - })) - payloadLength := hdr.UsedLength() - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - TransportProtocol: header.ICMPv6ProtocolNumber, - HopLimit: 255, - SrcAddr: lladdr1, - DstAddr: lladdr0, - }) - - invalid := s.Stats().ICMP.V6.PacketsReceived.Invalid - - // Invalid count should initially be 0. - if got := invalid.Value(); got != 0 { - t.Fatalf("got invalid = %d, want = 0", got) - } - - e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - })) - - neighbors, err := s.Neighbors(nicID, ProtocolNumber) - if err != nil { - t.Fatalf("s.Neighbors(%d, %d): %s", nicID, ProtocolNumber, err) - } - - neighborByAddr := make(map[tcpip.Address]stack.NeighborEntry) - for _, n := range neighbors { - if existing, ok := neighborByAddr[n.Addr]; ok { - if diff := cmp.Diff(existing, n); diff != "" { - t.Fatalf("s.Neighbors(%d, %d) returned unexpected duplicate neighbor entry (-existing +got):\n%s", nicID, ProtocolNumber, diff) - } - t.Fatalf("s.Neighbors(%d, %d) returned unexpected duplicate neighbor entry: %#v", nicID, ProtocolNumber, existing) - } - neighborByAddr[n.Addr] = n - } - - if neigh, ok := neighborByAddr[lladdr1]; ok { - t.Fatalf("unexpectedly got neighbor entry: %#v", neigh) - } - - if test.isValid { - // Invalid count should not have increased. - if got := invalid.Value(); got != 0 { - t.Errorf("got invalid = %d, want = 0", got) - } - } else { - // Invalid count should have increased. - if got := invalid.Value(); got != 1 { - t.Errorf("got invalid = %d, want = 1", got) - } - } - }) - } -} - -func TestNDPValidation(t *testing.T) { - const nicID = 1 - - handleIPv6Payload := func(payload buffer.View, hopLimit uint8, atomicFragment bool, ep stack.NetworkEndpoint) { - var extHdrs header.IPv6ExtHdrSerializer - if atomicFragment { - extHdrs = append(extHdrs, &header.IPv6SerializableFragmentExtHdr{}) - } - extHdrsLen := extHdrs.Length() - - ip := buffer.NewView(header.IPv6MinimumSize + extHdrsLen) - header.IPv6(ip).Encode(&header.IPv6Fields{ - PayloadLength: uint16(len(payload) + extHdrsLen), - TransportProtocol: header.ICMPv6ProtocolNumber, - HopLimit: hopLimit, - SrcAddr: lladdr1, - DstAddr: lladdr0, - ExtensionHeaders: extHdrs, - }) - vv := ip.ToVectorisedView() - vv.AppendView(payload) - ep.HandlePacket(stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: vv, - })) - } - - var tllData [header.NDPLinkLayerAddressSize]byte - header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{ - header.NDPTargetLinkLayerAddressOption(linkAddr1), - }) - - var sllData [header.NDPLinkLayerAddressSize]byte - header.NDPOptions(sllData[:]).Serialize(header.NDPOptionsSerializer{ - header.NDPSourceLinkLayerAddressOption(linkAddr1), - }) - - types := []struct { - name string - typ header.ICMPv6Type - size int - extraData []byte - statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter - routerOnly bool - }{ - { - name: "RouterSolicit", - typ: header.ICMPv6RouterSolicit, - size: header.ICMPv6MinimumSize, - statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.RouterSolicit - }, - routerOnly: true, - }, - { - name: "RouterAdvert", - typ: header.ICMPv6RouterAdvert, - size: header.ICMPv6HeaderSize + header.NDPRAMinimumSize, - statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.RouterAdvert - }, - }, - { - name: "NeighborSolicit", - typ: header.ICMPv6NeighborSolicit, - size: header.ICMPv6NeighborSolicitMinimumSize, - extraData: sllData[:], - statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.NeighborSolicit - }, - }, - { - name: "NeighborAdvert", - typ: header.ICMPv6NeighborAdvert, - size: header.ICMPv6NeighborAdvertMinimumSize, - extraData: tllData[:], - statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.NeighborAdvert - }, - }, - { - name: "RedirectMsg", - typ: header.ICMPv6RedirectMsg, - size: header.ICMPv6MinimumSize, - statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.RedirectMsg - }, - }, - } - - subTests := []struct { - name string - atomicFragment bool - hopLimit uint8 - code header.ICMPv6Code - valid bool - }{ - { - name: "Valid", - atomicFragment: false, - hopLimit: header.NDPHopLimit, - code: 0, - valid: true, - }, - { - name: "Fragmented", - atomicFragment: true, - hopLimit: header.NDPHopLimit, - code: 0, - valid: false, - }, - { - name: "Invalid hop limit", - atomicFragment: false, - hopLimit: header.NDPHopLimit - 1, - code: 0, - valid: false, - }, - { - name: "Invalid ICMPv6 code", - atomicFragment: false, - hopLimit: header.NDPHopLimit, - code: 1, - valid: false, - }, - } - - subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr0)))) - if err != nil { - t.Fatal(err) - } - - for _, typ := range types { - for _, isRouter := range []bool{false, true} { - name := typ.name - if isRouter { - name += " (Router)" - } - - t.Run(name, func(t *testing.T) { - for _, test := range subTests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6}, - }) - - if isRouter { - if err := s.SetForwardingDefaultAndAllNICs(ProtocolNumber, true); err != nil { - t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", ProtocolNumber, err) - } - } - - if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID, err) - } - - if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ProtocolNumber, lladdr0, err) - } - - ep, err := s.GetNetworkEndpoint(nicID, ProtocolNumber) - if err != nil { - t.Fatal("cannot find network endpoint instance for IPv6") - } - - s.SetRouteTable([]tcpip.Route{{ - Destination: subnet, - NIC: nicID, - }}) - - stats := s.Stats().ICMP.V6.PacketsReceived - invalid := stats.Invalid - routerOnly := stats.RouterOnlyPacketsDroppedByHost - typStat := typ.statCounter(stats) - - icmpH := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData))) - copy(icmpH[typ.size:], typ.extraData) - icmpH.SetType(typ.typ) - icmpH.SetCode(test.code) - icmpH.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ - Header: icmpH[:typ.size], - Src: lladdr0, - Dst: lladdr1, - PayloadCsum: header.Checksum(typ.extraData /* initial */, 0), - PayloadLen: len(typ.extraData), - })) - - // Rx count of the NDP message should initially be 0. - if got := typStat.Value(); got != 0 { - t.Errorf("got %s = %d, want = 0", typ.name, got) - } - - // Invalid count should initially be 0. - if got := invalid.Value(); got != 0 { - t.Errorf("got invalid.Value() = %d, want = 0", got) - } - - // Should initially not have dropped any packets. - if got := routerOnly.Value(); got != 0 { - t.Errorf("got routerOnly.Value() = %d, want = 0", got) - } - - if t.Failed() { - t.FailNow() - } - - handleIPv6Payload(buffer.View(icmpH), test.hopLimit, test.atomicFragment, ep) - - // Rx count of the NDP packet should have increased. - if got := typStat.Value(); got != 1 { - t.Errorf("got %s = %d, want = 1", typ.name, got) - } - - want := uint64(0) - if !test.valid { - // Invalid count should have increased. - want = 1 - } - if got := invalid.Value(); got != want { - t.Errorf("got invalid.Value() = %d, want = %d", got, want) - } - - want = 0 - if test.valid && !isRouter && typ.routerOnly { - // Router only packets are expected to be dropped when operating - // as a host. - want = 1 - } - if got := routerOnly.Value(); got != want { - t.Errorf("got routerOnly.Value() = %d, want = %d", got, want) - } - }) - } - }) - } - } -} - -// TestNeighborAdvertisementValidation tests that the NIC validates received -// Neighbor Advertisements. -// -// In particular, if the IP Destination Address is a multicast address, and the -// Solicited flag is not zero, the Neighbor Advertisement is invalid and should -// be discarded. -func TestNeighborAdvertisementValidation(t *testing.T) { - tests := []struct { - name string - ipDstAddr tcpip.Address - solicitedFlag bool - valid bool - }{ - { - name: "Multicast IP destination address with Solicited flag set", - ipDstAddr: header.IPv6AllNodesMulticastAddress, - solicitedFlag: true, - valid: false, - }, - { - name: "Multicast IP destination address with Solicited flag unset", - ipDstAddr: header.IPv6AllNodesMulticastAddress, - solicitedFlag: false, - valid: true, - }, - { - name: "Unicast IP destination address with Solicited flag set", - ipDstAddr: lladdr0, - solicitedFlag: true, - valid: true, - }, - { - name: "Unicast IP destination address with Solicited flag unset", - ipDstAddr: lladdr0, - solicitedFlag: false, - valid: true, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - }) - e := channel.New(0, header.IPv6MinimumMTU, linkAddr0) - e.LinkEPCapabilities |= stack.CapabilityResolutionRequired - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err) - } - - ndpNASize := header.ICMPv6NeighborAdvertMinimumSize - hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNASize) - pkt := header.ICMPv6(hdr.Prepend(ndpNASize)) - pkt.SetType(header.ICMPv6NeighborAdvert) - na := header.NDPNeighborAdvert(pkt.MessageBody()) - na.SetTargetAddress(lladdr1) - na.SetSolicitedFlag(test.solicitedFlag) - pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ - Header: pkt, - Src: lladdr1, - Dst: test.ipDstAddr, - })) - payloadLength := hdr.UsedLength() - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - TransportProtocol: header.ICMPv6ProtocolNumber, - HopLimit: 255, - SrcAddr: lladdr1, - DstAddr: test.ipDstAddr, - }) - - stats := s.Stats().ICMP.V6.PacketsReceived - invalid := stats.Invalid - rxNA := stats.NeighborAdvert - - if got := rxNA.Value(); got != 0 { - t.Fatalf("got rxNA = %d, want = 0", got) - } - if got := invalid.Value(); got != 0 { - t.Fatalf("got invalid = %d, want = 0", got) - } - - e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - })) - - if got := rxNA.Value(); got != 1 { - t.Fatalf("got rxNA = %d, want = 1", got) - } - var wantInvalid uint64 = 1 - if test.valid { - wantInvalid = 0 - } - if got := invalid.Value(); got != wantInvalid { - t.Fatalf("got invalid = %d, want = %d", got, wantInvalid) - } - // As per RFC 4861 section 7.2.5: - // When a valid Neighbor Advertisement is received ... - // If no entry exists, the advertisement SHOULD be silently discarded. - // There is no need to create an entry if none exists, since the - // recipient has apparently not initiated any communication with the - // target. - if neighbors, err := s.Neighbors(nicID, ProtocolNumber); err != nil { - t.Fatalf("s.Neighbors(%d, %d): %s", nicID, ProtocolNumber, err) - } else if len(neighbors) != 0 { - t.Fatalf("got len(neighbors) = %d, want = 0; neighbors = %#v", len(neighbors), neighbors) - } - }) - } -} - -// TestRouterAdvertValidation tests that when the NIC is configured to handle -// NDP Router Advertisement packets, it validates the Router Advertisement -// properly before handling them. -func TestRouterAdvertValidation(t *testing.T) { - tests := []struct { - name string - src tcpip.Address - hopLimit uint8 - code header.ICMPv6Code - ndpPayload []byte - expectedSuccess bool - }{ - { - "OK", - lladdr0, - 255, - 0, - []byte{ - 0, 0, 0, 0, - 0, 0, 0, 0, - 0, 0, 0, 0, - }, - true, - }, - { - "NonLinkLocalSourceAddr", - addr1, - 255, - 0, - []byte{ - 0, 0, 0, 0, - 0, 0, 0, 0, - 0, 0, 0, 0, - }, - false, - }, - { - "HopLimitNot255", - lladdr0, - 254, - 0, - []byte{ - 0, 0, 0, 0, - 0, 0, 0, 0, - 0, 0, 0, 0, - }, - false, - }, - { - "NonZeroCode", - lladdr0, - 255, - 1, - []byte{ - 0, 0, 0, 0, - 0, 0, 0, 0, - 0, 0, 0, 0, - }, - false, - }, - { - "NDPPayloadTooSmall", - lladdr0, - 255, - 0, - []byte{ - 0, 0, 0, 0, - 0, 0, 0, 0, - 0, 0, 0, - }, - false, - }, - { - "OKWithOptions", - lladdr0, - 255, - 0, - []byte{ - // RA payload - 0, 0, 0, 0, - 0, 0, 0, 0, - 0, 0, 0, 0, - - // Option #1 (TargetLinkLayerAddress) - 2, 1, 0, 0, 0, 0, 0, 0, - - // Option #2 (unrecognized) - 255, 1, 0, 0, 0, 0, 0, 0, - - // Option #3 (PrefixInformation) - 3, 4, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - }, - true, - }, - { - "OptionWithZeroLength", - lladdr0, - 255, - 0, - []byte{ - // RA payload - 0, 0, 0, 0, - 0, 0, 0, 0, - 0, 0, 0, 0, - - // Option #1 (TargetLinkLayerAddress) - // Invalid as it has 0 length. - 2, 0, 0, 0, 0, 0, 0, 0, - - // Option #2 (unrecognized) - 255, 1, 0, 0, 0, 0, 0, 0, - - // Option #3 (PrefixInformation) - 3, 4, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - }, - false, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - e := channel.New(10, 1280, linkAddr1) - e.LinkEPCapabilities |= stack.CapabilityResolutionRequired - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - }) - - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(_) = %s", err) - } - - icmpSize := header.ICMPv6HeaderSize + len(test.ndpPayload) - hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize) - pkt := header.ICMPv6(hdr.Prepend(icmpSize)) - pkt.SetType(header.ICMPv6RouterAdvert) - pkt.SetCode(test.code) - copy(pkt.MessageBody(), test.ndpPayload) - payloadLength := hdr.UsedLength() - pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ - Header: pkt, - Src: test.src, - Dst: header.IPv6AllNodesMulticastAddress, - })) - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - TransportProtocol: icmp.ProtocolNumber6, - HopLimit: test.hopLimit, - SrcAddr: test.src, - DstAddr: header.IPv6AllNodesMulticastAddress, - }) - - stats := s.Stats().ICMP.V6.PacketsReceived - invalid := stats.Invalid - rxRA := stats.RouterAdvert - - if got := invalid.Value(); got != 0 { - t.Fatalf("got invalid = %d, want = 0", got) - } - if got := rxRA.Value(); got != 0 { - t.Fatalf("got rxRA = %d, want = 0", got) - } - - e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - })) - - if got := rxRA.Value(); got != 1 { - t.Fatalf("got rxRA = %d, want = 1", got) - } - - if test.expectedSuccess { - if got := invalid.Value(); got != 0 { - t.Fatalf("got invalid = %d, want = 0", got) - } - } else { - if got := invalid.Value(); got != 1 { - t.Fatalf("got invalid = %d, want = 1", got) - } - } - }) - } -} - -// TestCheckDuplicateAddress checks that calls to CheckDuplicateAddress and DAD -// performed when adding new addresses do not interfere with each other. -func TestCheckDuplicateAddress(t *testing.T) { - const nicID = 1 - - clock := faketime.NewManualClock() - dadConfigs := stack.DADConfigurations{ - DupAddrDetectTransmits: 1, - RetransmitTimer: time.Second, - } - - nonces := [...][]byte{ - {1, 2, 3, 4, 5, 6}, - {7, 8, 9, 10, 11, 12}, - } - - var secureRNGBytes []byte - for _, n := range nonces { - secureRNGBytes = append(secureRNGBytes, n...) - } - var secureRNG bytes.Reader - secureRNG.Reset(secureRNGBytes[:]) - s := stack.New(stack.Options{ - Clock: clock, - RandSource: rand.NewSource(time.Now().UnixNano()), - SecureRNG: &secureRNG, - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocolWithOptions(Options{ - DADConfigs: dadConfigs, - })}, - }) - // This test is expected to send at max 2 DAD messages. We allow an extra - // packet to be stored to catch unexpected packets. - e := channel.New(3, header.IPv6MinimumMTU, linkAddr0) - e.LinkEPCapabilities |= stack.CapabilityResolutionRequired - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - - dadPacketsSent := 0 - snmc := header.SolicitedNodeAddr(lladdr0) - remoteLinkAddr := header.EthernetAddressFromMulticastIPv6Address(snmc) - checkDADMsg := func() { - clock.RunImmediatelyScheduledJobs() - p, ok := e.Read() - if !ok { - t.Fatalf("expected %d-th DAD message", dadPacketsSent) - } - - if p.Proto != header.IPv6ProtocolNumber { - t.Errorf("(i=%d) got p.Proto = %d, want = %d", dadPacketsSent, p.Proto, header.IPv6ProtocolNumber) - } - - if p.Route.RemoteLinkAddress != remoteLinkAddr { - t.Errorf("(i=%d) got p.Route.RemoteLinkAddress = %s, want = %s", dadPacketsSent, p.Route.RemoteLinkAddress, remoteLinkAddr) - } - - checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), - checker.SrcAddr(header.IPv6Any), - checker.DstAddr(snmc), - checker.TTL(header.NDPHopLimit), - checker.NDPNS( - checker.NDPNSTargetAddress(lladdr0), - checker.NDPNSOptions([]header.NDPOption{header.NDPNonceOption(nonces[dadPacketsSent])}), - )) - } - if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err) - } - checkDADMsg() - - // Start DAD for the address we just added. - // - // Even though the stack will perform DAD before the added address transitions - // from tentative to assigned, this DAD request should be independent of that. - ch := make(chan stack.DADResult, 3) - dadRequestsMade := 1 - dadPacketsSent++ - if res, err := s.CheckDuplicateAddress(nicID, ProtocolNumber, lladdr0, func(r stack.DADResult) { - ch <- r - }); err != nil { - t.Fatalf("s.CheckDuplicateAddress(%d, %d, %s, _): %s", nicID, ProtocolNumber, lladdr0, err) - } else if res != stack.DADStarting { - t.Fatalf("got s.CheckDuplicateAddress(%d, %d, %s, _) = %d, want = %d", nicID, ProtocolNumber, lladdr0, res, stack.DADStarting) - } - checkDADMsg() - - // Remove the address and make sure our DAD request was not stopped. - if err := s.RemoveAddress(nicID, lladdr0); err != nil { - t.Fatalf("RemoveAddress(%d, %s): %s", nicID, lladdr0, err) - } - // Should not restart DAD since we already requested DAD above - the handler - // should be called when the original request completes so we should not send - // an extra DAD message here. - dadRequestsMade++ - if res, err := s.CheckDuplicateAddress(nicID, ProtocolNumber, lladdr0, func(r stack.DADResult) { - ch <- r - }); err != nil { - t.Fatalf("s.CheckDuplicateAddress(%d, %d, %s, _): %s", nicID, ProtocolNumber, lladdr0, err) - } else if res != stack.DADAlreadyRunning { - t.Fatalf("got s.CheckDuplicateAddress(%d, %d, %s, _) = %d, want = %d", nicID, ProtocolNumber, lladdr0, res, stack.DADAlreadyRunning) - } - - // Wait for DAD to complete. - clock.Advance(time.Duration(dadConfigs.DupAddrDetectTransmits) * dadConfigs.RetransmitTimer) - for i := 0; i < dadRequestsMade; i++ { - if diff := cmp.Diff(&stack.DADSucceeded{}, <-ch); diff != "" { - t.Errorf("(i=%d) DAD result mismatch (-want +got):\n%s", i, diff) - } - } - // Should have no more results. - select { - case r := <-ch: - t.Errorf("unexpectedly got an extra DAD result; r = %#v", r) - default: - } - - // Should have no more packets. - if p, ok := e.Read(); ok { - t.Errorf("got unexpected packet = %#v", p) - } -} diff --git a/pkg/tcpip/network/multicast_group_test.go b/pkg/tcpip/network/multicast_group_test.go deleted file mode 100644 index 1b96b1fb8..000000000 --- a/pkg/tcpip/network/multicast_group_test.go +++ /dev/null @@ -1,1278 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ip_test - -import ( - "fmt" - "strings" - "testing" - "time" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/checker" - "gvisor.dev/gvisor/pkg/tcpip/faketime" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/link/loopback" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/testutil" -) - -const ( - linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") - - defaultIPv4PrefixLength = 24 - - igmpMembershipQuery = uint8(header.IGMPMembershipQuery) - igmpv1MembershipReport = uint8(header.IGMPv1MembershipReport) - igmpv2MembershipReport = uint8(header.IGMPv2MembershipReport) - igmpLeaveGroup = uint8(header.IGMPLeaveGroup) - mldQuery = uint8(header.ICMPv6MulticastListenerQuery) - mldReport = uint8(header.ICMPv6MulticastListenerReport) - mldDone = uint8(header.ICMPv6MulticastListenerDone) - - maxUnsolicitedReports = 2 -) - -var ( - stackIPv4Addr = testutil.MustParse4("10.0.0.1") - linkLocalIPv6Addr1 = testutil.MustParse6("fe80::1") - linkLocalIPv6Addr2 = testutil.MustParse6("fe80::2") - - ipv4MulticastAddr1 = testutil.MustParse4("224.0.0.3") - ipv4MulticastAddr2 = testutil.MustParse4("224.0.0.4") - ipv4MulticastAddr3 = testutil.MustParse4("224.0.0.5") - ipv6MulticastAddr1 = testutil.MustParse6("ff02::3") - ipv6MulticastAddr2 = testutil.MustParse6("ff02::4") - ipv6MulticastAddr3 = testutil.MustParse6("ff02::5") -) - -var ( - // unsolicitedIGMPReportIntervalMaxTenthSec is the maximum amount of time the - // NIC will wait before sending an unsolicited report after joining a - // multicast group, in deciseconds. - unsolicitedIGMPReportIntervalMaxTenthSec = func() uint8 { - const decisecond = time.Second / 10 - if ipv4.UnsolicitedReportIntervalMax%decisecond != 0 { - panic(fmt.Sprintf("UnsolicitedReportIntervalMax of %d is a lossy conversion to deciseconds", ipv4.UnsolicitedReportIntervalMax)) - } - return uint8(ipv4.UnsolicitedReportIntervalMax / decisecond) - }() - - ipv6AddrSNMC = header.SolicitedNodeAddr(linkLocalIPv6Addr1) -) - -// validateMLDPacket checks that a passed PacketInfo is an IPv6 MLD packet -// sent to the provided address with the passed fields set. -func validateMLDPacket(t *testing.T, p channel.PacketInfo, remoteAddress tcpip.Address, mldType uint8, maxRespTime byte, groupAddress tcpip.Address) { - t.Helper() - - payload := header.IPv6(stack.PayloadSince(p.Pkt.NetworkHeader())) - checker.IPv6WithExtHdr(t, payload, - checker.IPv6ExtHdr( - checker.IPv6HopByHopExtensionHeader(checker.IPv6RouterAlert(header.IPv6RouterAlertMLD)), - ), - checker.SrcAddr(linkLocalIPv6Addr1), - checker.DstAddr(remoteAddress), - // Hop Limit for an MLD message must be 1 as per RFC 2710 section 3. - checker.TTL(1), - checker.MLD(header.ICMPv6Type(mldType), header.MLDMinimumSize, - checker.MLDMaxRespDelay(time.Duration(maxRespTime)*time.Millisecond), - checker.MLDMulticastAddress(groupAddress), - ), - ) -} - -// validateIGMPPacket checks that a passed PacketInfo is an IPv4 IGMP packet -// sent to the provided address with the passed fields set. -func validateIGMPPacket(t *testing.T, p channel.PacketInfo, remoteAddress tcpip.Address, igmpType uint8, maxRespTime byte, groupAddress tcpip.Address) { - t.Helper() - - payload := header.IPv4(stack.PayloadSince(p.Pkt.NetworkHeader())) - checker.IPv4(t, payload, - checker.SrcAddr(stackIPv4Addr), - checker.DstAddr(remoteAddress), - // TTL for an IGMP message must be 1 as per RFC 2236 section 2. - checker.TTL(1), - checker.IPv4RouterAlert(), - checker.IGMP( - checker.IGMPType(header.IGMPType(igmpType)), - checker.IGMPMaxRespTime(header.DecisecondToDuration(maxRespTime)), - checker.IGMPGroupAddress(groupAddress), - ), - ) -} - -func createStack(t *testing.T, v4, mgpEnabled bool) (*channel.Endpoint, *stack.Stack, *faketime.ManualClock) { - t.Helper() - - e := channel.New(maxUnsolicitedReports, header.IPv6MinimumMTU, linkAddr) - s, clock := createStackWithLinkEndpoint(t, v4, mgpEnabled, e) - return e, s, clock -} - -func createStackWithLinkEndpoint(t *testing.T, v4, mgpEnabled bool, e stack.LinkEndpoint) (*stack.Stack, *faketime.ManualClock) { - t.Helper() - - igmpEnabled := v4 && mgpEnabled - mldEnabled := !v4 && mgpEnabled - - clock := faketime.NewManualClock() - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ - ipv4.NewProtocolWithOptions(ipv4.Options{ - IGMP: ipv4.IGMPOptions{ - Enabled: igmpEnabled, - }, - }), - ipv6.NewProtocolWithOptions(ipv6.Options{ - MLD: ipv6.MLDOptions{ - Enabled: mldEnabled, - }, - }), - }, - Clock: clock, - }) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - addr := tcpip.AddressWithPrefix{ - Address: stackIPv4Addr, - PrefixLen: defaultIPv4PrefixLength, - } - if err := s.AddAddressWithPrefix(nicID, ipv4.ProtocolNumber, addr); err != nil { - t.Fatalf("AddAddressWithPrefix(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, addr, err) - } - if err := s.AddAddress(nicID, ipv6.ProtocolNumber, linkLocalIPv6Addr1); err != nil { - t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, linkLocalIPv6Addr1, err) - } - - return s, clock -} - -// checkInitialIPv6Groups checks the initial IPv6 groups that a NIC will join -// when it is created with an IPv6 address. -// -// To not interfere with tests, checkInitialIPv6Groups will leave the added -// address's solicited node multicast group so that the tests can all assume -// the NIC has not joined any IPv6 groups. -func checkInitialIPv6Groups(t *testing.T, e *channel.Endpoint, s *stack.Stack, clock *faketime.ManualClock) (reportCounter uint64, leaveCounter uint64) { - t.Helper() - - stats := s.Stats().ICMP.V6.PacketsSent - - reportCounter++ - if got := stats.MulticastListenerReport.Value(); got != reportCounter { - t.Errorf("got stats.MulticastListenerReport.Value() = %d, want = %d", got, reportCounter) - } - if p, ok := e.Read(); !ok { - t.Fatal("expected a report message to be sent") - } else { - validateMLDPacket(t, p, ipv6AddrSNMC, mldReport, 0, ipv6AddrSNMC) - } - - // Leave the group to not affect the tests. This is fine since we are not - // testing DAD or the solicited node address specifically. - if err := s.LeaveGroup(ipv6.ProtocolNumber, nicID, ipv6AddrSNMC); err != nil { - t.Fatalf("LeaveGroup(%d, %d, %s): %s", ipv6.ProtocolNumber, nicID, ipv6AddrSNMC, err) - } - leaveCounter++ - if got := stats.MulticastListenerDone.Value(); got != leaveCounter { - t.Errorf("got stats.MulticastListenerDone.Value() = %d, want = %d", got, leaveCounter) - } - if p, ok := e.Read(); !ok { - t.Fatal("expected a report message to be sent") - } else { - validateMLDPacket(t, p, header.IPv6AllRoutersLinkLocalMulticastAddress, mldDone, 0, ipv6AddrSNMC) - } - - // Should not send any more packets. - clock.Advance(time.Hour) - if p, ok := e.Read(); ok { - t.Fatalf("sent unexpected packet = %#v", p) - } - - return reportCounter, leaveCounter -} - -// createAndInjectIGMPPacket creates and injects an IGMP packet with the -// specified fields. -func createAndInjectIGMPPacket(e *channel.Endpoint, igmpType byte, maxRespTime byte, groupAddress tcpip.Address) { - options := header.IPv4OptionsSerializer{ - &header.IPv4SerializableRouterAlertOption{}, - } - buf := buffer.NewView(header.IPv4MinimumSize + int(options.Length()) + header.IGMPQueryMinimumSize) - ip := header.IPv4(buf) - ip.Encode(&header.IPv4Fields{ - TotalLength: uint16(len(buf)), - TTL: header.IGMPTTL, - Protocol: uint8(header.IGMPProtocolNumber), - SrcAddr: remoteIPv4Addr, - DstAddr: header.IPv4AllSystems, - Options: options, - }) - ip.SetChecksum(^ip.CalculateChecksum()) - - igmp := header.IGMP(ip.Payload()) - igmp.SetType(header.IGMPType(igmpType)) - igmp.SetMaxRespTime(maxRespTime) - igmp.SetGroupAddress(groupAddress) - igmp.SetChecksum(header.IGMPCalculateChecksum(igmp)) - - e.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) -} - -// createAndInjectMLDPacket creates and injects an MLD packet with the -// specified fields. -func createAndInjectMLDPacket(e *channel.Endpoint, mldType uint8, maxRespDelay byte, groupAddress tcpip.Address) { - extensionHeaders := header.IPv6ExtHdrSerializer{ - header.IPv6SerializableHopByHopExtHdr{ - &header.IPv6RouterAlertOption{Value: header.IPv6RouterAlertMLD}, - }, - } - - extensionHeadersLength := extensionHeaders.Length() - payloadLength := extensionHeadersLength + header.ICMPv6HeaderSize + header.MLDMinimumSize - buf := buffer.NewView(header.IPv6MinimumSize + payloadLength) - - ip := header.IPv6(buf) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - HopLimit: header.MLDHopLimit, - TransportProtocol: header.ICMPv6ProtocolNumber, - SrcAddr: linkLocalIPv6Addr2, - DstAddr: header.IPv6AllNodesMulticastAddress, - ExtensionHeaders: extensionHeaders, - }) - - icmp := header.ICMPv6(ip.Payload()[extensionHeadersLength:]) - icmp.SetType(header.ICMPv6Type(mldType)) - mld := header.MLD(icmp.MessageBody()) - mld.SetMaximumResponseDelay(uint16(maxRespDelay)) - mld.SetMulticastAddress(groupAddress) - icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ - Header: icmp, - Src: linkLocalIPv6Addr2, - Dst: header.IPv6AllNodesMulticastAddress, - })) - - e.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) -} - -// TestMGPDisabled tests that the multicast group protocol is not enabled by -// default. -func TestMGPDisabled(t *testing.T) { - tests := []struct { - name string - protoNum tcpip.NetworkProtocolNumber - multicastAddr tcpip.Address - sentReportStat func(*stack.Stack) *tcpip.StatCounter - receivedQueryStat func(*stack.Stack) *tcpip.StatCounter - rxQuery func(*channel.Endpoint) - }{ - { - name: "IGMP", - protoNum: ipv4.ProtocolNumber, - multicastAddr: ipv4MulticastAddr1, - sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { - return s.Stats().IGMP.PacketsSent.V2MembershipReport - }, - receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter { - return s.Stats().IGMP.PacketsReceived.MembershipQuery - }, - rxQuery: func(e *channel.Endpoint) { - createAndInjectIGMPPacket(e, igmpMembershipQuery, unsolicitedIGMPReportIntervalMaxTenthSec, header.IPv4Any) - }, - }, - { - name: "MLD", - protoNum: ipv6.ProtocolNumber, - multicastAddr: ipv6MulticastAddr1, - sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { - return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport - }, - receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter { - return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerQuery - }, - rxQuery: func(e *channel.Endpoint) { - createAndInjectMLDPacket(e, mldQuery, 0, header.IPv6Any) - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, false /* mgpEnabled */) - - // This NIC may join multicast groups when it is enabled but since MGP is - // disabled, no reports should be sent. - sentReportStat := test.sentReportStat(s) - if got := sentReportStat.Value(); got != 0 { - t.Fatalf("got sentReportStat.Value() = %d, want = 0", got) - } - clock.Advance(time.Hour) - if p, ok := e.Read(); ok { - t.Fatalf("sent unexpected packet, stack with disabled MGP sent packet = %#v", p.Pkt) - } - - // Test joining a specific group explicitly and verify that no reports are - // sent. - if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil { - t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err) - } - if got := sentReportStat.Value(); got != 0 { - t.Fatalf("got sentReportStat.Value() = %d, want = 0", got) - } - clock.Advance(time.Hour) - if p, ok := e.Read(); ok { - t.Fatalf("sent unexpected packet, stack with disabled IGMP sent packet = %#v", p.Pkt) - } - - // Inject a general query message. This should only trigger a report to be - // sent if the MGP was enabled. - test.rxQuery(e) - if got := test.receivedQueryStat(s).Value(); got != 1 { - t.Fatalf("got receivedQueryStat(_).Value() = %d, want = 1", got) - } - clock.Advance(time.Hour) - if p, ok := e.Read(); ok { - t.Fatalf("sent unexpected packet, stack with disabled IGMP sent packet = %+v", p.Pkt) - } - }) - } -} - -func TestMGPReceiveCounters(t *testing.T) { - tests := []struct { - name string - headerType uint8 - maxRespTime byte - groupAddress tcpip.Address - statCounter func(*stack.Stack) *tcpip.StatCounter - rxMGPkt func(*channel.Endpoint, byte, byte, tcpip.Address) - }{ - { - name: "IGMP Membership Query", - headerType: igmpMembershipQuery, - maxRespTime: unsolicitedIGMPReportIntervalMaxTenthSec, - groupAddress: header.IPv4Any, - statCounter: func(s *stack.Stack) *tcpip.StatCounter { - return s.Stats().IGMP.PacketsReceived.MembershipQuery - }, - rxMGPkt: createAndInjectIGMPPacket, - }, - { - name: "IGMPv1 Membership Report", - headerType: igmpv1MembershipReport, - maxRespTime: 0, - groupAddress: header.IPv4AllSystems, - statCounter: func(s *stack.Stack) *tcpip.StatCounter { - return s.Stats().IGMP.PacketsReceived.V1MembershipReport - }, - rxMGPkt: createAndInjectIGMPPacket, - }, - { - name: "IGMPv2 Membership Report", - headerType: igmpv2MembershipReport, - maxRespTime: 0, - groupAddress: header.IPv4AllSystems, - statCounter: func(s *stack.Stack) *tcpip.StatCounter { - return s.Stats().IGMP.PacketsReceived.V2MembershipReport - }, - rxMGPkt: createAndInjectIGMPPacket, - }, - { - name: "IGMP Leave Group", - headerType: igmpLeaveGroup, - maxRespTime: 0, - groupAddress: header.IPv4AllRoutersGroup, - statCounter: func(s *stack.Stack) *tcpip.StatCounter { - return s.Stats().IGMP.PacketsReceived.LeaveGroup - }, - rxMGPkt: createAndInjectIGMPPacket, - }, - { - name: "MLD Query", - headerType: mldQuery, - maxRespTime: 0, - groupAddress: header.IPv6Any, - statCounter: func(s *stack.Stack) *tcpip.StatCounter { - return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerQuery - }, - rxMGPkt: createAndInjectMLDPacket, - }, - { - name: "MLD Report", - headerType: mldReport, - maxRespTime: 0, - groupAddress: header.IPv6Any, - statCounter: func(s *stack.Stack) *tcpip.StatCounter { - return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerReport - }, - rxMGPkt: createAndInjectMLDPacket, - }, - { - name: "MLD Done", - headerType: mldDone, - maxRespTime: 0, - groupAddress: header.IPv6Any, - statCounter: func(s *stack.Stack) *tcpip.StatCounter { - return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerDone - }, - rxMGPkt: createAndInjectMLDPacket, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - e, s, _ := createStack(t, len(test.groupAddress) == header.IPv4AddressSize /* v4 */, true /* mgpEnabled */) - - test.rxMGPkt(e, test.headerType, test.maxRespTime, test.groupAddress) - if got := test.statCounter(s).Value(); got != 1 { - t.Fatalf("got %s received = %d, want = 1", test.name, got) - } - }) - } -} - -// TestMGPJoinGroup tests that when explicitly joining a multicast group, the -// stack schedules and sends correct Membership Reports. -func TestMGPJoinGroup(t *testing.T) { - tests := []struct { - name string - protoNum tcpip.NetworkProtocolNumber - multicastAddr tcpip.Address - maxUnsolicitedResponseDelay time.Duration - sentReportStat func(*stack.Stack) *tcpip.StatCounter - receivedQueryStat func(*stack.Stack) *tcpip.StatCounter - validateReport func(*testing.T, channel.PacketInfo) - checkInitialGroups func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) (uint64, uint64) - }{ - { - name: "IGMP", - protoNum: ipv4.ProtocolNumber, - multicastAddr: ipv4MulticastAddr1, - maxUnsolicitedResponseDelay: ipv4.UnsolicitedReportIntervalMax, - sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { - return s.Stats().IGMP.PacketsSent.V2MembershipReport - }, - receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter { - return s.Stats().IGMP.PacketsReceived.MembershipQuery - }, - validateReport: func(t *testing.T, p channel.PacketInfo) { - t.Helper() - - validateIGMPPacket(t, p, ipv4MulticastAddr1, igmpv2MembershipReport, 0, ipv4MulticastAddr1) - }, - }, - { - name: "MLD", - protoNum: ipv6.ProtocolNumber, - multicastAddr: ipv6MulticastAddr1, - maxUnsolicitedResponseDelay: ipv6.UnsolicitedReportIntervalMax, - sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { - return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport - }, - receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter { - return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerQuery - }, - validateReport: func(t *testing.T, p channel.PacketInfo) { - t.Helper() - - validateMLDPacket(t, p, ipv6MulticastAddr1, mldReport, 0, ipv6MulticastAddr1) - }, - checkInitialGroups: checkInitialIPv6Groups, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */) - - var reportCounter uint64 - if test.checkInitialGroups != nil { - reportCounter, _ = test.checkInitialGroups(t, e, s, clock) - } - - // Test joining a specific address explicitly and verify a Report is sent - // immediately. - if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil { - t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err) - } - reportCounter++ - sentReportStat := test.sentReportStat(s) - if got := sentReportStat.Value(); got != reportCounter { - t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter) - } - if p, ok := e.Read(); !ok { - t.Fatal("expected a report message to be sent") - } else { - test.validateReport(t, p) - } - if t.Failed() { - t.FailNow() - } - - // Verify the second report is sent by the maximum unsolicited response - // interval. - p, ok := e.Read() - if ok { - t.Fatalf("sent unexpected packet, expected report only after advancing the clock = %#v", p.Pkt) - } - clock.Advance(test.maxUnsolicitedResponseDelay) - reportCounter++ - if got := sentReportStat.Value(); got != reportCounter { - t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter) - } - if p, ok := e.Read(); !ok { - t.Fatal("expected a report message to be sent") - } else { - test.validateReport(t, p) - } - - // Should not send any more packets. - clock.Advance(time.Hour) - if p, ok := e.Read(); ok { - t.Fatalf("sent unexpected packet = %#v", p) - } - }) - } -} - -// TestMGPLeaveGroup tests that when leaving a previously joined multicast -// group the stack sends a leave/done message. -func TestMGPLeaveGroup(t *testing.T) { - tests := []struct { - name string - protoNum tcpip.NetworkProtocolNumber - multicastAddr tcpip.Address - sentReportStat func(*stack.Stack) *tcpip.StatCounter - sentLeaveStat func(*stack.Stack) *tcpip.StatCounter - validateReport func(*testing.T, channel.PacketInfo) - validateLeave func(*testing.T, channel.PacketInfo) - checkInitialGroups func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) (uint64, uint64) - }{ - { - name: "IGMP", - protoNum: ipv4.ProtocolNumber, - multicastAddr: ipv4MulticastAddr1, - sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { - return s.Stats().IGMP.PacketsSent.V2MembershipReport - }, - sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter { - return s.Stats().IGMP.PacketsSent.LeaveGroup - }, - validateReport: func(t *testing.T, p channel.PacketInfo) { - t.Helper() - - validateIGMPPacket(t, p, ipv4MulticastAddr1, igmpv2MembershipReport, 0, ipv4MulticastAddr1) - }, - validateLeave: func(t *testing.T, p channel.PacketInfo) { - t.Helper() - - validateIGMPPacket(t, p, header.IPv4AllRoutersGroup, igmpLeaveGroup, 0, ipv4MulticastAddr1) - }, - }, - { - name: "MLD", - protoNum: ipv6.ProtocolNumber, - multicastAddr: ipv6MulticastAddr1, - sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { - return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport - }, - sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter { - return s.Stats().ICMP.V6.PacketsSent.MulticastListenerDone - }, - validateReport: func(t *testing.T, p channel.PacketInfo) { - t.Helper() - - validateMLDPacket(t, p, ipv6MulticastAddr1, mldReport, 0, ipv6MulticastAddr1) - }, - validateLeave: func(t *testing.T, p channel.PacketInfo) { - t.Helper() - - validateMLDPacket(t, p, header.IPv6AllRoutersLinkLocalMulticastAddress, mldDone, 0, ipv6MulticastAddr1) - }, - checkInitialGroups: checkInitialIPv6Groups, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */) - - var reportCounter uint64 - var leaveCounter uint64 - if test.checkInitialGroups != nil { - reportCounter, leaveCounter = test.checkInitialGroups(t, e, s, clock) - } - - if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil { - t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err) - } - reportCounter++ - if got := test.sentReportStat(s).Value(); got != reportCounter { - t.Errorf("got sentReportStat(_).Value() = %d, want = %d", got, reportCounter) - } - if p, ok := e.Read(); !ok { - t.Fatal("expected a report message to be sent") - } else { - test.validateReport(t, p) - } - if t.Failed() { - t.FailNow() - } - - // Leaving the group should trigger an leave/done message to be sent. - if err := s.LeaveGroup(test.protoNum, nicID, test.multicastAddr); err != nil { - t.Fatalf("LeaveGroup(%d, nic, %s): %s", test.protoNum, test.multicastAddr, err) - } - leaveCounter++ - if got := test.sentLeaveStat(s).Value(); got != leaveCounter { - t.Fatalf("got sentLeaveStat(_).Value() = %d, want = %d", got, leaveCounter) - } - if p, ok := e.Read(); !ok { - t.Fatal("expected a leave message to be sent") - } else { - test.validateLeave(t, p) - } - - // Should not send any more packets. - clock.Advance(time.Hour) - if p, ok := e.Read(); ok { - t.Fatalf("sent unexpected packet = %#v", p) - } - }) - } -} - -// TestMGPQueryMessages tests that a report is sent in response to query -// messages. -func TestMGPQueryMessages(t *testing.T) { - tests := []struct { - name string - protoNum tcpip.NetworkProtocolNumber - multicastAddr tcpip.Address - maxUnsolicitedResponseDelay time.Duration - sentReportStat func(*stack.Stack) *tcpip.StatCounter - receivedQueryStat func(*stack.Stack) *tcpip.StatCounter - rxQuery func(*channel.Endpoint, uint8, tcpip.Address) - validateReport func(*testing.T, channel.PacketInfo) - maxRespTimeToDuration func(uint8) time.Duration - checkInitialGroups func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) (uint64, uint64) - }{ - { - name: "IGMP", - protoNum: ipv4.ProtocolNumber, - multicastAddr: ipv4MulticastAddr1, - maxUnsolicitedResponseDelay: ipv4.UnsolicitedReportIntervalMax, - sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { - return s.Stats().IGMP.PacketsSent.V2MembershipReport - }, - receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter { - return s.Stats().IGMP.PacketsReceived.MembershipQuery - }, - rxQuery: func(e *channel.Endpoint, maxRespTime uint8, groupAddress tcpip.Address) { - createAndInjectIGMPPacket(e, igmpMembershipQuery, maxRespTime, groupAddress) - }, - validateReport: func(t *testing.T, p channel.PacketInfo) { - t.Helper() - - validateIGMPPacket(t, p, ipv4MulticastAddr1, igmpv2MembershipReport, 0, ipv4MulticastAddr1) - }, - maxRespTimeToDuration: header.DecisecondToDuration, - }, - { - name: "MLD", - protoNum: ipv6.ProtocolNumber, - multicastAddr: ipv6MulticastAddr1, - maxUnsolicitedResponseDelay: ipv6.UnsolicitedReportIntervalMax, - sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { - return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport - }, - receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter { - return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerQuery - }, - rxQuery: func(e *channel.Endpoint, maxRespTime uint8, groupAddress tcpip.Address) { - createAndInjectMLDPacket(e, mldQuery, maxRespTime, groupAddress) - }, - validateReport: func(t *testing.T, p channel.PacketInfo) { - t.Helper() - - validateMLDPacket(t, p, ipv6MulticastAddr1, mldReport, 0, ipv6MulticastAddr1) - }, - maxRespTimeToDuration: func(d uint8) time.Duration { - return time.Duration(d) * time.Millisecond - }, - checkInitialGroups: checkInitialIPv6Groups, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - subTests := []struct { - name string - multicastAddr tcpip.Address - expectReport bool - }{ - { - name: "Unspecified", - multicastAddr: tcpip.Address(strings.Repeat("\x00", len(test.multicastAddr))), - expectReport: true, - }, - { - name: "Specified", - multicastAddr: test.multicastAddr, - expectReport: true, - }, - { - name: "Specified other address", - multicastAddr: func() tcpip.Address { - addrBytes := []byte(test.multicastAddr) - addrBytes[len(addrBytes)-1]++ - return tcpip.Address(addrBytes) - }(), - expectReport: false, - }, - } - - for _, subTest := range subTests { - t.Run(subTest.name, func(t *testing.T) { - e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */) - - var reportCounter uint64 - if test.checkInitialGroups != nil { - reportCounter, _ = test.checkInitialGroups(t, e, s, clock) - } - - if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil { - t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err) - } - sentReportStat := test.sentReportStat(s) - for i := 0; i < maxUnsolicitedReports; i++ { - sentReportStat := test.sentReportStat(s) - reportCounter++ - if got := sentReportStat.Value(); got != reportCounter { - t.Errorf("(i=%d) got sentReportStat.Value() = %d, want = %d", i, got, reportCounter) - } - if p, ok := e.Read(); !ok { - t.Fatalf("expected %d-th report message to be sent", i) - } else { - test.validateReport(t, p) - } - clock.Advance(test.maxUnsolicitedResponseDelay) - } - if t.Failed() { - t.FailNow() - } - - // Should not send any more packets until a query. - clock.Advance(time.Hour) - if p, ok := e.Read(); ok { - t.Fatalf("sent unexpected packet = %#v", p) - } - - // Receive a query message which should trigger a report to be sent at - // some time before the maximum response time if the report is - // targeted at the host. - const maxRespTime = 100 - test.rxQuery(e, maxRespTime, subTest.multicastAddr) - if p, ok := e.Read(); ok { - t.Fatalf("sent unexpected packet = %#v", p.Pkt) - } - - if subTest.expectReport { - clock.Advance(test.maxRespTimeToDuration(maxRespTime)) - reportCounter++ - if got := sentReportStat.Value(); got != reportCounter { - t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter) - } - if p, ok := e.Read(); !ok { - t.Fatal("expected a report message to be sent") - } else { - test.validateReport(t, p) - } - } - - // Should not send any more packets. - clock.Advance(time.Hour) - if p, ok := e.Read(); ok { - t.Fatalf("sent unexpected packet = %#v", p) - } - }) - } - }) - } -} - -// TestMGPQueryMessages tests that no further reports or leave/done messages -// are sent after receiving a report. -func TestMGPReportMessages(t *testing.T) { - tests := []struct { - name string - protoNum tcpip.NetworkProtocolNumber - multicastAddr tcpip.Address - sentReportStat func(*stack.Stack) *tcpip.StatCounter - sentLeaveStat func(*stack.Stack) *tcpip.StatCounter - rxReport func(*channel.Endpoint) - validateReport func(*testing.T, channel.PacketInfo) - maxRespTimeToDuration func(uint8) time.Duration - checkInitialGroups func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) (uint64, uint64) - }{ - { - name: "IGMP", - protoNum: ipv4.ProtocolNumber, - multicastAddr: ipv4MulticastAddr1, - sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { - return s.Stats().IGMP.PacketsSent.V2MembershipReport - }, - sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter { - return s.Stats().IGMP.PacketsSent.LeaveGroup - }, - rxReport: func(e *channel.Endpoint) { - createAndInjectIGMPPacket(e, igmpv2MembershipReport, 0, ipv4MulticastAddr1) - }, - validateReport: func(t *testing.T, p channel.PacketInfo) { - t.Helper() - - validateIGMPPacket(t, p, ipv4MulticastAddr1, igmpv2MembershipReport, 0, ipv4MulticastAddr1) - }, - maxRespTimeToDuration: header.DecisecondToDuration, - }, - { - name: "MLD", - protoNum: ipv6.ProtocolNumber, - multicastAddr: ipv6MulticastAddr1, - sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { - return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport - }, - sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter { - return s.Stats().ICMP.V6.PacketsSent.MulticastListenerDone - }, - rxReport: func(e *channel.Endpoint) { - createAndInjectMLDPacket(e, mldReport, 0, ipv6MulticastAddr1) - }, - validateReport: func(t *testing.T, p channel.PacketInfo) { - t.Helper() - - validateMLDPacket(t, p, ipv6MulticastAddr1, mldReport, 0, ipv6MulticastAddr1) - }, - maxRespTimeToDuration: func(d uint8) time.Duration { - return time.Duration(d) * time.Millisecond - }, - checkInitialGroups: checkInitialIPv6Groups, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */) - - var reportCounter uint64 - var leaveCounter uint64 - if test.checkInitialGroups != nil { - reportCounter, leaveCounter = test.checkInitialGroups(t, e, s, clock) - } - - if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil { - t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err) - } - sentReportStat := test.sentReportStat(s) - reportCounter++ - if got := sentReportStat.Value(); got != reportCounter { - t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter) - } - if p, ok := e.Read(); !ok { - t.Fatal("expected a report message to be sent") - } else { - test.validateReport(t, p) - } - if t.Failed() { - t.FailNow() - } - - // Receiving a report for a group we joined should cancel any further - // reports. - test.rxReport(e) - clock.Advance(time.Hour) - if got := sentReportStat.Value(); got != reportCounter { - t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter) - } - if p, ok := e.Read(); ok { - t.Errorf("sent unexpected packet = %#v", p) - } - if t.Failed() { - t.FailNow() - } - - // Leaving a group after getting a report should not send a leave/done - // message. - if err := s.LeaveGroup(test.protoNum, nicID, test.multicastAddr); err != nil { - t.Fatalf("LeaveGroup(%d, nic, %s): %s", test.protoNum, test.multicastAddr, err) - } - clock.Advance(time.Hour) - if got := test.sentLeaveStat(s).Value(); got != leaveCounter { - t.Fatalf("got sentLeaveStat(_).Value() = %d, want = %d", got, leaveCounter) - } - - // Should not send any more packets. - clock.Advance(time.Hour) - if p, ok := e.Read(); ok { - t.Fatalf("sent unexpected packet = %#v", p) - } - }) - } -} - -func TestMGPWithNICLifecycle(t *testing.T) { - tests := []struct { - name string - protoNum tcpip.NetworkProtocolNumber - multicastAddrs []tcpip.Address - finalMulticastAddr tcpip.Address - maxUnsolicitedResponseDelay time.Duration - sentReportStat func(*stack.Stack) *tcpip.StatCounter - sentLeaveStat func(*stack.Stack) *tcpip.StatCounter - validateReport func(*testing.T, channel.PacketInfo, tcpip.Address) - validateLeave func(*testing.T, channel.PacketInfo, tcpip.Address) - getAndCheckGroupAddress func(*testing.T, map[tcpip.Address]bool, channel.PacketInfo) tcpip.Address - checkInitialGroups func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) (uint64, uint64) - }{ - { - name: "IGMP", - protoNum: ipv4.ProtocolNumber, - multicastAddrs: []tcpip.Address{ipv4MulticastAddr1, ipv4MulticastAddr2}, - finalMulticastAddr: ipv4MulticastAddr3, - maxUnsolicitedResponseDelay: ipv4.UnsolicitedReportIntervalMax, - sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { - return s.Stats().IGMP.PacketsSent.V2MembershipReport - }, - sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter { - return s.Stats().IGMP.PacketsSent.LeaveGroup - }, - validateReport: func(t *testing.T, p channel.PacketInfo, addr tcpip.Address) { - t.Helper() - - validateIGMPPacket(t, p, addr, igmpv2MembershipReport, 0, addr) - }, - validateLeave: func(t *testing.T, p channel.PacketInfo, addr tcpip.Address) { - t.Helper() - - validateIGMPPacket(t, p, header.IPv4AllRoutersGroup, igmpLeaveGroup, 0, addr) - }, - getAndCheckGroupAddress: func(t *testing.T, seen map[tcpip.Address]bool, p channel.PacketInfo) tcpip.Address { - t.Helper() - - ipv4 := header.IPv4(stack.PayloadSince(p.Pkt.NetworkHeader())) - if got := tcpip.TransportProtocolNumber(ipv4.Protocol()); got != header.IGMPProtocolNumber { - t.Fatalf("got ipv4.Protocol() = %d, want = %d", got, header.IGMPProtocolNumber) - } - addr := header.IGMP(ipv4.Payload()).GroupAddress() - s, ok := seen[addr] - if !ok { - t.Fatalf("unexpectedly got a packet for group %s", addr) - } - if s { - t.Fatalf("already saw packet for group %s", addr) - } - seen[addr] = true - return addr - }, - }, - { - name: "MLD", - protoNum: ipv6.ProtocolNumber, - multicastAddrs: []tcpip.Address{ipv6MulticastAddr1, ipv6MulticastAddr2}, - finalMulticastAddr: ipv6MulticastAddr3, - maxUnsolicitedResponseDelay: ipv6.UnsolicitedReportIntervalMax, - sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { - return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport - }, - sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter { - return s.Stats().ICMP.V6.PacketsSent.MulticastListenerDone - }, - validateReport: func(t *testing.T, p channel.PacketInfo, addr tcpip.Address) { - t.Helper() - - validateMLDPacket(t, p, addr, mldReport, 0, addr) - }, - validateLeave: func(t *testing.T, p channel.PacketInfo, addr tcpip.Address) { - t.Helper() - - validateMLDPacket(t, p, header.IPv6AllRoutersLinkLocalMulticastAddress, mldDone, 0, addr) - }, - getAndCheckGroupAddress: func(t *testing.T, seen map[tcpip.Address]bool, p channel.PacketInfo) tcpip.Address { - t.Helper() - - ipv6 := header.IPv6(stack.PayloadSince(p.Pkt.NetworkHeader())) - - ipv6HeaderIter := header.MakeIPv6PayloadIterator( - header.IPv6ExtensionHeaderIdentifier(ipv6.NextHeader()), - buffer.View(ipv6.Payload()).ToVectorisedView(), - ) - - var transport header.IPv6RawPayloadHeader - for { - h, done, err := ipv6HeaderIter.Next() - if err != nil { - t.Fatalf("ipv6HeaderIter.Next(): %s", err) - } - if done { - t.Fatalf("ipv6HeaderIter.Next() = (%T, %t, _), want = (_, false, _)", h, done) - } - if t, ok := h.(header.IPv6RawPayloadHeader); ok { - transport = t - break - } - } - - if got := tcpip.TransportProtocolNumber(transport.Identifier); got != header.ICMPv6ProtocolNumber { - t.Fatalf("got ipv6.NextHeader() = %d, want = %d", got, header.ICMPv6ProtocolNumber) - } - icmpv6 := header.ICMPv6(transport.Buf.ToView()) - if got := icmpv6.Type(); got != header.ICMPv6MulticastListenerReport && got != header.ICMPv6MulticastListenerDone { - t.Fatalf("got icmpv6.Type() = %d, want = %d or %d", got, header.ICMPv6MulticastListenerReport, header.ICMPv6MulticastListenerDone) - } - addr := header.MLD(icmpv6.MessageBody()).MulticastAddress() - s, ok := seen[addr] - if !ok { - t.Fatalf("unexpectedly got a packet for group %s", addr) - } - if s { - t.Fatalf("already saw packet for group %s", addr) - } - seen[addr] = true - return addr - }, - checkInitialGroups: checkInitialIPv6Groups, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */) - - var reportCounter uint64 - var leaveCounter uint64 - if test.checkInitialGroups != nil { - reportCounter, leaveCounter = test.checkInitialGroups(t, e, s, clock) - } - - sentReportStat := test.sentReportStat(s) - for _, a := range test.multicastAddrs { - if err := s.JoinGroup(test.protoNum, nicID, a); err != nil { - t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, a, err) - } - reportCounter++ - if got := sentReportStat.Value(); got != reportCounter { - t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter) - } - if p, ok := e.Read(); !ok { - t.Fatalf("expected a report message to be sent for %s", a) - } else { - test.validateReport(t, p, a) - } - } - if t.Failed() { - t.FailNow() - } - - // Leave messages should be sent for the joined groups when the NIC is - // disabled. - if err := s.DisableNIC(nicID); err != nil { - t.Fatalf("DisableNIC(%d): %s", nicID, err) - } - sentLeaveStat := test.sentLeaveStat(s) - leaveCounter += uint64(len(test.multicastAddrs)) - if got := sentLeaveStat.Value(); got != leaveCounter { - t.Errorf("got sentLeaveStat.Value() = %d, want = %d", got, leaveCounter) - } - { - seen := make(map[tcpip.Address]bool) - for _, a := range test.multicastAddrs { - seen[a] = false - } - - for i := range test.multicastAddrs { - p, ok := e.Read() - if !ok { - t.Fatalf("expected (%d-th) leave message to be sent", i) - } - - test.validateLeave(t, p, test.getAndCheckGroupAddress(t, seen, p)) - } - } - if t.Failed() { - t.FailNow() - } - - // Reports should be sent for the joined groups when the NIC is enabled. - if err := s.EnableNIC(nicID); err != nil { - t.Fatalf("EnableNIC(%d): %s", nicID, err) - } - reportCounter += uint64(len(test.multicastAddrs)) - if got := sentReportStat.Value(); got != reportCounter { - t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter) - } - { - seen := make(map[tcpip.Address]bool) - for _, a := range test.multicastAddrs { - seen[a] = false - } - - for i := range test.multicastAddrs { - p, ok := e.Read() - if !ok { - t.Fatalf("expected (%d-th) report message to be sent", i) - } - - test.validateReport(t, p, test.getAndCheckGroupAddress(t, seen, p)) - } - } - if t.Failed() { - t.FailNow() - } - - // Joining/leaving a group while disabled should not send any messages. - if err := s.DisableNIC(nicID); err != nil { - t.Fatalf("DisableNIC(%d): %s", nicID, err) - } - leaveCounter += uint64(len(test.multicastAddrs)) - if got := sentLeaveStat.Value(); got != leaveCounter { - t.Errorf("got sentLeaveStat.Value() = %d, want = %d", got, leaveCounter) - } - for i := range test.multicastAddrs { - if _, ok := e.Read(); !ok { - t.Fatalf("expected (%d-th) leave message to be sent", i) - } - } - for _, a := range test.multicastAddrs { - if err := s.LeaveGroup(test.protoNum, nicID, a); err != nil { - t.Fatalf("LeaveGroup(%d, nic, %s): %s", test.protoNum, a, err) - } - if got := sentLeaveStat.Value(); got != leaveCounter { - t.Errorf("got sentLeaveStat.Value() = %d, want = %d", got, leaveCounter) - } - if p, ok := e.Read(); ok { - t.Fatalf("leaving group %s on disabled NIC sent unexpected packet = %#v", a, p.Pkt) - } - } - if err := s.JoinGroup(test.protoNum, nicID, test.finalMulticastAddr); err != nil { - t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.finalMulticastAddr, err) - } - if got := sentReportStat.Value(); got != reportCounter { - t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter) - } - if p, ok := e.Read(); ok { - t.Fatalf("joining group %s on disabled NIC sent unexpected packet = %#v", test.finalMulticastAddr, p.Pkt) - } - - // A report should only be sent for the group we last joined after - // enabling the NIC since the original groups were all left. - if err := s.EnableNIC(nicID); err != nil { - t.Fatalf("EnableNIC(%d): %s", nicID, err) - } - reportCounter++ - if got := sentReportStat.Value(); got != reportCounter { - t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter) - } - if p, ok := e.Read(); !ok { - t.Fatal("expected a report message to be sent") - } else { - test.validateReport(t, p, test.finalMulticastAddr) - } - - clock.Advance(test.maxUnsolicitedResponseDelay) - reportCounter++ - if got := sentReportStat.Value(); got != reportCounter { - t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter) - } - if p, ok := e.Read(); !ok { - t.Fatal("expected a report message to be sent") - } else { - test.validateReport(t, p, test.finalMulticastAddr) - } - - // Should not send any more packets. - clock.Advance(time.Hour) - if p, ok := e.Read(); ok { - t.Fatalf("sent unexpected packet = %#v", p) - } - }) - } -} - -// TestMGPDisabledOnLoopback tests that the multicast group protocol is not -// performed on loopback interfaces since they have no neighbours. -func TestMGPDisabledOnLoopback(t *testing.T) { - tests := []struct { - name string - protoNum tcpip.NetworkProtocolNumber - multicastAddr tcpip.Address - sentReportStat func(*stack.Stack) *tcpip.StatCounter - }{ - { - name: "IGMP", - protoNum: ipv4.ProtocolNumber, - multicastAddr: ipv4MulticastAddr1, - sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { - return s.Stats().IGMP.PacketsSent.V2MembershipReport - }, - }, - { - name: "MLD", - protoNum: ipv6.ProtocolNumber, - multicastAddr: ipv6MulticastAddr1, - sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { - return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s, clock := createStackWithLinkEndpoint(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */, loopback.New()) - - sentReportStat := test.sentReportStat(s) - if got := sentReportStat.Value(); got != 0 { - t.Fatalf("got sentReportStat.Value() = %d, want = 0", got) - } - clock.Advance(time.Hour) - if got := sentReportStat.Value(); got != 0 { - t.Fatalf("got sentReportStat.Value() = %d, want = 0", got) - } - - // Test joining a specific group explicitly and verify that no reports are - // sent. - if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil { - t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err) - } - if got := sentReportStat.Value(); got != 0 { - t.Fatalf("got sentReportStat.Value() = %d, want = 0", got) - } - clock.Advance(time.Hour) - if got := sentReportStat.Value(); got != 0 { - t.Fatalf("got sentReportStat.Value() = %d, want = 0", got) - } - }) - } -} diff --git a/pkg/tcpip/ports/BUILD b/pkg/tcpip/ports/BUILD deleted file mode 100644 index fe98a52af..000000000 --- a/pkg/tcpip/ports/BUILD +++ /dev/null @@ -1,28 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "ports", - srcs = [ - "flags.go", - "ports.go", - ], - visibility = ["//visibility:public"], - deps = [ - "//pkg/sync", - "//pkg/tcpip", - "//pkg/tcpip/header", - ], -) - -go_test( - name = "ports_test", - srcs = ["ports_test.go"], - library = ":ports", - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/testutil", - "@com_github_google_go_cmp//cmp:go_default_library", - ], -) diff --git a/pkg/tcpip/ports/ports_state_autogen.go b/pkg/tcpip/ports/ports_state_autogen.go new file mode 100644 index 000000000..2719f6c41 --- /dev/null +++ b/pkg/tcpip/ports/ports_state_autogen.go @@ -0,0 +1,42 @@ +// automatically generated by stateify. + +package ports + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (f *Flags) StateTypeName() string { + return "pkg/tcpip/ports.Flags" +} + +func (f *Flags) StateFields() []string { + return []string{ + "MostRecent", + "LoadBalanced", + "TupleOnly", + } +} + +func (f *Flags) beforeSave() {} + +// +checklocksignore +func (f *Flags) StateSave(stateSinkObject state.Sink) { + f.beforeSave() + stateSinkObject.Save(0, &f.MostRecent) + stateSinkObject.Save(1, &f.LoadBalanced) + stateSinkObject.Save(2, &f.TupleOnly) +} + +func (f *Flags) afterLoad() {} + +// +checklocksignore +func (f *Flags) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &f.MostRecent) + stateSourceObject.Load(1, &f.LoadBalanced) + stateSourceObject.Load(2, &f.TupleOnly) +} + +func init() { + state.Register((*Flags)(nil)) +} diff --git a/pkg/tcpip/ports/ports_test.go b/pkg/tcpip/ports/ports_test.go deleted file mode 100644 index a91b130df..000000000 --- a/pkg/tcpip/ports/ports_test.go +++ /dev/null @@ -1,525 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ports - -import ( - "math" - "math/rand" - "testing" - "time" - - "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/testutil" -) - -const ( - fakeTransNumber tcpip.TransportProtocolNumber = 1 - fakeNetworkNumber tcpip.NetworkProtocolNumber = 2 -) - -var ( - fakeIPAddress = testutil.MustParse4("8.8.8.8") - fakeIPAddress1 = testutil.MustParse4("8.8.8.9") -) - -type portReserveTestAction struct { - port uint16 - ip tcpip.Address - want tcpip.Error - flags Flags - release bool - device tcpip.NICID - dest tcpip.FullAddress -} - -func TestPortReservation(t *testing.T) { - for _, test := range []struct { - tname string - actions []portReserveTestAction - }{ - { - tname: "bind to ip", - actions: []portReserveTestAction{ - {port: 80, ip: fakeIPAddress, want: nil}, - {port: 80, ip: fakeIPAddress1, want: nil}, - /* N.B. Order of tests matters! */ - {port: 80, ip: anyIPAddress, want: &tcpip.ErrPortInUse{}}, - {port: 80, ip: fakeIPAddress, want: &tcpip.ErrPortInUse{}, flags: Flags{LoadBalanced: true}}, - }, - }, - { - tname: "bind to inaddr any", - actions: []portReserveTestAction{ - {port: 22, ip: anyIPAddress, want: nil}, - {port: 22, ip: fakeIPAddress, want: &tcpip.ErrPortInUse{}}, - /* release fakeIPAddress, but anyIPAddress is still inuse */ - {port: 22, ip: fakeIPAddress, release: true}, - {port: 22, ip: fakeIPAddress, want: &tcpip.ErrPortInUse{}}, - {port: 22, ip: fakeIPAddress, want: &tcpip.ErrPortInUse{}, flags: Flags{LoadBalanced: true}}, - /* Release port 22 from any IP address, then try to reserve fake IP address on 22 */ - {port: 22, ip: anyIPAddress, want: nil, release: true}, - {port: 22, ip: fakeIPAddress, want: nil}, - }, - }, { - tname: "bind to zero port", - actions: []portReserveTestAction{ - {port: 00, ip: fakeIPAddress, want: nil}, - {port: 00, ip: fakeIPAddress, want: nil}, - {port: 00, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, - }, - }, { - tname: "bind to ip with reuseport", - actions: []portReserveTestAction{ - {port: 25, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 25, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, - - {port: 25, ip: fakeIPAddress, flags: Flags{}, want: &tcpip.ErrPortInUse{}}, - {port: 25, ip: anyIPAddress, flags: Flags{}, want: &tcpip.ErrPortInUse{}}, - - {port: 25, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, - }, - }, { - tname: "bind to inaddr any with reuseport", - actions: []portReserveTestAction{ - {port: 24, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 24, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, - - {port: 24, ip: anyIPAddress, flags: Flags{}, want: &tcpip.ErrPortInUse{}}, - {port: 24, ip: fakeIPAddress, flags: Flags{}, want: &tcpip.ErrPortInUse{}}, - - {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, release: true, want: nil}, - - {port: 24, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, release: true}, - {port: 24, ip: anyIPAddress, flags: Flags{}, want: &tcpip.ErrPortInUse{}}, - - {port: 24, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, release: true}, - {port: 24, ip: anyIPAddress, flags: Flags{}, want: nil}, - }, - }, { - tname: "bind twice with device fails", - actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, device: 3, want: nil}, - {port: 24, ip: fakeIPAddress, device: 3, want: &tcpip.ErrPortInUse{}}, - }, - }, { - tname: "bind to device", - actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, device: 1, want: nil}, - {port: 24, ip: fakeIPAddress, device: 2, want: nil}, - }, - }, { - tname: "bind to device and then without device", - actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, device: 123, want: nil}, - {port: 24, ip: fakeIPAddress, device: 0, want: &tcpip.ErrPortInUse{}}, - }, - }, { - tname: "bind without device", - actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, want: nil}, - {port: 24, ip: fakeIPAddress, device: 123, want: &tcpip.ErrPortInUse{}}, - {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: &tcpip.ErrPortInUse{}}, - {port: 24, ip: fakeIPAddress, want: &tcpip.ErrPortInUse{}}, - {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: &tcpip.ErrPortInUse{}}, - }, - }, { - tname: "bind with device", - actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, device: 123, want: nil}, - {port: 24, ip: fakeIPAddress, device: 123, want: &tcpip.ErrPortInUse{}}, - {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: &tcpip.ErrPortInUse{}}, - {port: 24, ip: fakeIPAddress, device: 0, want: &tcpip.ErrPortInUse{}}, - {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: &tcpip.ErrPortInUse{}}, - {port: 24, ip: fakeIPAddress, device: 456, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, device: 789, want: nil}, - {port: 24, ip: fakeIPAddress, want: &tcpip.ErrPortInUse{}}, - {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: &tcpip.ErrPortInUse{}}, - }, - }, { - tname: "bind with reuseport", - actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, device: 123, want: &tcpip.ErrPortInUse{}}, - {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, device: 0, want: &tcpip.ErrPortInUse{}}, - {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: nil}, - }, - }, { - tname: "binding with reuseport and device", - actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, device: 123, want: &tcpip.ErrPortInUse{}}, - {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, device: 0, want: &tcpip.ErrPortInUse{}}, - {port: 24, ip: fakeIPAddress, device: 456, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, device: 789, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, device: 999, want: &tcpip.ErrPortInUse{}}, - }, - }, { - tname: "mixing reuseport and not reuseport by binding to device", - actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, device: 456, want: nil}, - {port: 24, ip: fakeIPAddress, device: 789, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, device: 999, want: nil}, - }, - }, { - tname: "can't bind to 0 after mixing reuseport and not reuseport", - actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, device: 456, want: nil}, - {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: &tcpip.ErrPortInUse{}}, - }, - }, { - tname: "bind and release", - actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, device: 345, flags: Flags{}, want: &tcpip.ErrPortInUse{}}, - {port: 24, ip: fakeIPAddress, device: 789, flags: Flags{LoadBalanced: true}, want: nil}, - - // Release the bind to device 0 and try again. - {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: nil, release: true}, - {port: 24, ip: fakeIPAddress, device: 345, flags: Flags{}, want: nil}, - }, - }, { - tname: "bind twice with reuseport once", - actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{}, want: nil}, - {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: &tcpip.ErrPortInUse{}}, - }, - }, { - tname: "release an unreserved device", - actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{}, want: nil}, - {port: 24, ip: fakeIPAddress, device: 456, flags: Flags{}, want: nil}, - // The below don't exist. - {port: 24, ip: fakeIPAddress, device: 345, flags: Flags{}, want: nil, release: true}, - {port: 9999, ip: fakeIPAddress, device: 123, flags: Flags{}, want: nil, release: true}, - // Release all. - {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{}, want: nil, release: true}, - {port: 24, ip: fakeIPAddress, device: 456, flags: Flags{}, want: nil, release: true}, - }, - }, { - tname: "bind with reuseaddr", - actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: nil}, - {port: 24, ip: fakeIPAddress, device: 123, want: &tcpip.ErrPortInUse{}}, - {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{MostRecent: true}, want: nil}, - {port: 24, ip: fakeIPAddress, device: 0, want: &tcpip.ErrPortInUse{}}, - {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{MostRecent: true}, want: nil}, - }, - }, { - tname: "bind twice with reuseaddr once", - actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{}, want: nil}, - {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{MostRecent: true}, want: &tcpip.ErrPortInUse{}}, - }, - }, { - tname: "bind with reuseaddr and reuseport", - actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, - }, - }, { - tname: "bind with reuseaddr and reuseport, and then reuseaddr", - actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: nil}, - {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: &tcpip.ErrPortInUse{}}, - }, - }, { - tname: "bind with reuseaddr and reuseport, and then reuseport", - actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: &tcpip.ErrPortInUse{}}, - }, - }, { - tname: "bind with reuseaddr and reuseport twice, and then reuseaddr", - actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: nil}, - }, - }, { - tname: "bind with reuseaddr and reuseport twice, and then reuseport", - actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, - }, - }, { - tname: "bind with reuseaddr, and then reuseaddr and reuseport", - actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: nil}, - {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: &tcpip.ErrPortInUse{}}, - }, - }, { - tname: "bind with reuseport, and then reuseaddr and reuseport", - actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: &tcpip.ErrPortInUse{}}, - }, - }, { - tname: "bind tuple with reuseaddr, and then wildcard with reuseaddr", - actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 24}, want: nil}, - {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{}, want: nil}, - }, - }, { - tname: "bind tuple with reuseaddr, and then wildcard", - actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 24}, want: nil}, - {port: 24, ip: fakeIPAddress, want: &tcpip.ErrPortInUse{}}, - }, - }, { - tname: "bind wildcard with reuseaddr, and then tuple with reuseaddr", - actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{}, want: nil}, - {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 24}, want: nil}, - }, - }, { - tname: "bind tuple with reuseaddr, and then wildcard", - actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, want: nil}, - {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 24}, want: &tcpip.ErrPortInUse{}}, - }, - }, { - tname: "bind two tuples with reuseaddr", - actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 24}, want: nil}, - {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 25}, want: nil}, - }, - }, { - tname: "bind two tuples", - actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 24}, want: nil}, - {port: 24, ip: fakeIPAddress, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 25}, want: nil}, - }, - }, { - tname: "bind wildcard, and then tuple with reuseaddr", - actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, dest: tcpip.FullAddress{}, want: nil}, - {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 24}, want: &tcpip.ErrPortInUse{}}, - }, - }, { - tname: "bind wildcard twice with reuseaddr", - actions: []portReserveTestAction{ - {port: 24, ip: anyIPAddress, flags: Flags{TupleOnly: true}, want: nil}, - {port: 24, ip: anyIPAddress, flags: Flags{TupleOnly: true}, want: nil}, - }, - }, - } { - t.Run(test.tname, func(t *testing.T) { - pm := NewPortManager() - net := []tcpip.NetworkProtocolNumber{fakeNetworkNumber} - rng := rand.New(rand.NewSource(time.Now().UnixNano())) - - for _, test := range test.actions { - first, _ := pm.PortRange() - if test.release { - portRes := Reservation{ - Networks: net, - Transport: fakeTransNumber, - Addr: test.ip, - Port: test.port, - Flags: test.flags, - BindToDevice: test.device, - Dest: test.dest, - } - pm.ReleasePort(portRes) - continue - } - portRes := Reservation{ - Networks: net, - Transport: fakeTransNumber, - Addr: test.ip, - Port: test.port, - Flags: test.flags, - BindToDevice: test.device, - Dest: test.dest, - } - gotPort, err := pm.ReservePort(rng, portRes, nil /* testPort */) - if diff := cmp.Diff(test.want, err); diff != "" { - t.Fatalf("unexpected error from ReservePort(%+v, _), (-want, +got):\n%s", portRes, diff) - } - if test.port == 0 && (gotPort == 0 || gotPort < first) { - t.Fatalf("ReservePort(%+v, _) = %d, want port number >= %d to be picked", portRes, gotPort, first) - } - } - }) - } -} - -func TestPickEphemeralPort(t *testing.T) { - const ( - firstEphemeral = 32000 - numEphemeralPorts = 1000 - ) - - for _, test := range []struct { - name string - f func(port uint16) (bool, tcpip.Error) - wantErr tcpip.Error - wantPort uint16 - }{ - { - name: "no-port-available", - f: func(port uint16) (bool, tcpip.Error) { - return false, nil - }, - wantErr: &tcpip.ErrNoPortAvailable{}, - }, - { - name: "port-tester-error", - f: func(port uint16) (bool, tcpip.Error) { - return false, &tcpip.ErrBadBuffer{} - }, - wantErr: &tcpip.ErrBadBuffer{}, - }, - { - name: "only-port-16042-available", - f: func(port uint16) (bool, tcpip.Error) { - if port == firstEphemeral+42 { - return true, nil - } - return false, nil - }, - wantPort: firstEphemeral + 42, - }, - { - name: "only-port-under-16000-available", - f: func(port uint16) (bool, tcpip.Error) { - if port < firstEphemeral { - return true, nil - } - return false, nil - }, - wantErr: &tcpip.ErrNoPortAvailable{}, - }, - } { - t.Run(test.name, func(t *testing.T) { - pm := NewPortManager() - rng := rand.New(rand.NewSource(time.Now().UnixNano())) - if err := pm.SetPortRange(firstEphemeral, firstEphemeral+numEphemeralPorts); err != nil { - t.Fatalf("failed to set ephemeral port range: %s", err) - } - port, err := pm.PickEphemeralPort(rng, test.f) - if diff := cmp.Diff(test.wantErr, err); diff != "" { - t.Fatalf("unexpected error from PickEphemeralPort(..), (-want, +got):\n%s", diff) - } - if port != test.wantPort { - t.Errorf("got PickEphemeralPort(..) = (%d, nil); want (%d, nil)", port, test.wantPort) - } - }) - } -} - -func TestPickEphemeralPortStable(t *testing.T) { - const ( - firstEphemeral = 32000 - numEphemeralPorts = 1000 - ) - - for _, test := range []struct { - name string - f func(port uint16) (bool, tcpip.Error) - wantErr tcpip.Error - wantPort uint16 - }{ - { - name: "no-port-available", - f: func(port uint16) (bool, tcpip.Error) { - return false, nil - }, - wantErr: &tcpip.ErrNoPortAvailable{}, - }, - { - name: "port-tester-error", - f: func(port uint16) (bool, tcpip.Error) { - return false, &tcpip.ErrBadBuffer{} - }, - wantErr: &tcpip.ErrBadBuffer{}, - }, - { - name: "only-port-16042-available", - f: func(port uint16) (bool, tcpip.Error) { - if port == firstEphemeral+42 { - return true, nil - } - return false, nil - }, - wantPort: firstEphemeral + 42, - }, - { - name: "only-port-under-16000-available", - f: func(port uint16) (bool, tcpip.Error) { - if port < firstEphemeral { - return true, nil - } - return false, nil - }, - wantErr: &tcpip.ErrNoPortAvailable{}, - }, - } { - t.Run(test.name, func(t *testing.T) { - pm := NewPortManager() - if err := pm.SetPortRange(firstEphemeral, firstEphemeral+numEphemeralPorts); err != nil { - t.Fatalf("failed to set ephemeral port range: %s", err) - } - portOffset := uint32(rand.Int31n(int32(numEphemeralPorts))) - port, err := pm.PickEphemeralPortStable(portOffset, test.f) - if diff := cmp.Diff(test.wantErr, err); diff != "" { - t.Fatalf("unexpected error from PickEphemeralPort(..), (-want, +got):\n%s", diff) - } - if port != test.wantPort { - t.Errorf("got PickEphemeralPort(..) = (%d, nil); want (%d, nil)", port, test.wantPort) - } - }) - } -} - -// TestOverflow addresses b/183593432, wherein an overflowing uint16 causes a -// port allocation failure. -func TestOverflow(t *testing.T) { - // Use a small range and start at offsets that will cause an overflow. - count := uint16(50) - for offset := uint32(math.MaxUint16 - count); offset < math.MaxUint16; offset++ { - reservedPorts := make(map[uint16]struct{}) - // Ensure we can reserve everything in the allowed range. - for i := uint16(0); i < count; i++ { - port, err := pickEphemeralPort(offset, firstEphemeral, count, func(port uint16) (bool, tcpip.Error) { - if _, ok := reservedPorts[port]; !ok { - reservedPorts[port] = struct{}{} - return true, nil - } - return false, nil - }) - if err != nil { - t.Fatalf("port picking failed at iteration %d, for offset %d, len(reserved): %+v", i, offset, len(reservedPorts)) - } - if port < firstEphemeral || port > firstEphemeral+count { - t.Fatalf("reserved port %d, which is not in range [%d, %d]", port, firstEphemeral, firstEphemeral+count-1) - } - } - } -} diff --git a/pkg/tcpip/sample/tun_tcp_connect/BUILD b/pkg/tcpip/sample/tun_tcp_connect/BUILD deleted file mode 100644 index db9b91815..000000000 --- a/pkg/tcpip/sample/tun_tcp_connect/BUILD +++ /dev/null @@ -1,21 +0,0 @@ -load("//tools:defs.bzl", "go_binary") - -package(licenses = ["notice"]) - -go_binary( - name = "tun_tcp_connect", - srcs = ["main.go"], - visibility = ["//:sandbox"], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/header", - "//pkg/tcpip/link/fdbased", - "//pkg/tcpip/link/rawfile", - "//pkg/tcpip/link/sniffer", - "//pkg/tcpip/link/tun", - "//pkg/tcpip/network/ipv4", - "//pkg/tcpip/stack", - "//pkg/tcpip/transport/tcp", - "//pkg/waiter", - ], -) diff --git a/pkg/tcpip/sample/tun_tcp_connect/main.go b/pkg/tcpip/sample/tun_tcp_connect/main.go deleted file mode 100644 index 009cab643..000000000 --- a/pkg/tcpip/sample/tun_tcp_connect/main.go +++ /dev/null @@ -1,220 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//go:build linux -// +build linux - -// This sample creates a stack with TCP and IPv4 protocols on top of a TUN -// device, and connects to a peer. Similar to "nc <address> <port>". While the -// sample is running, attempts to connect to its IPv4 address will result in -// a RST segment. -// -// As an example of how to run it, a TUN device can be created and enabled on -// a linux host as follows (this only needs to be done once per boot): -// -// [sudo] ip tuntap add user <username> mode tun <device-name> -// [sudo] ip link set <device-name> up -// [sudo] ip addr add <ipv4-address>/<mask-length> dev <device-name> -// -// A concrete example: -// -// $ sudo ip tuntap add user wedsonaf mode tun tun0 -// $ sudo ip link set tun0 up -// $ sudo ip addr add 192.168.1.1/24 dev tun0 -// -// Then one can run tun_tcp_connect as such: -// -// $ ./tun/tun_tcp_connect tun0 192.168.1.2 0 192.168.1.1 1234 -// -// This will attempt to connect to the linux host's stack. One can run nc in -// listen mode to accept a connect from tun_tcp_connect and exchange data. -package main - -import ( - "bytes" - "fmt" - "log" - "math/rand" - "net" - "os" - "strconv" - "time" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/fdbased" - "gvisor.dev/gvisor/pkg/tcpip/link/rawfile" - "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" - "gvisor.dev/gvisor/pkg/tcpip/link/tun" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" - "gvisor.dev/gvisor/pkg/waiter" -) - -// writer reads from standard input and writes to the endpoint until standard -// input is closed. It signals that it's done by closing the provided channel. -func writer(ch chan struct{}, ep tcpip.Endpoint) { - defer func() { - ep.Shutdown(tcpip.ShutdownWrite) - close(ch) - }() - - var b bytes.Buffer - if err := func() error { - for { - if _, err := b.ReadFrom(os.Stdin); err != nil { - return fmt.Errorf("b.ReadFrom failed: %w", err) - } - - for b.Len() != 0 { - if _, err := ep.Write(&b, tcpip.WriteOptions{Atomic: true}); err != nil { - return fmt.Errorf("ep.Write failed: %s", err) - } - } - } - }(); err != nil { - fmt.Println(err) - } -} - -func main() { - if len(os.Args) != 6 { - log.Fatal("Usage: ", os.Args[0], " <tun-device> <local-ipv4-address> <local-port> <remote-ipv4-address> <remote-port>") - } - - tunName := os.Args[1] - addrName := os.Args[2] - portName := os.Args[3] - remoteAddrName := os.Args[4] - remotePortName := os.Args[5] - - rand.Seed(time.Now().UnixNano()) - - addr := tcpip.Address(net.ParseIP(addrName).To4()) - remote := tcpip.FullAddress{ - NIC: 1, - Addr: tcpip.Address(net.ParseIP(remoteAddrName).To4()), - } - - var localPort uint16 - if v, err := strconv.Atoi(portName); err != nil { - log.Fatalf("Unable to convert port %v: %v", portName, err) - } else { - localPort = uint16(v) - } - - if v, err := strconv.Atoi(remotePortName); err != nil { - log.Fatalf("Unable to convert port %v: %v", remotePortName, err) - } else { - remote.Port = uint16(v) - } - - // Create the stack with ipv4 and tcp protocols, then add a tun-based - // NIC and ipv4 address. - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, - }) - - mtu, err := rawfile.GetMTU(tunName) - if err != nil { - log.Fatal(err) - } - - fd, err := tun.Open(tunName) - if err != nil { - log.Fatal(err) - } - - linkEP, err := fdbased.New(&fdbased.Options{FDs: []int{fd}, MTU: mtu}) - if err != nil { - log.Fatal(err) - } - if err := s.CreateNIC(1, sniffer.New(linkEP)); err != nil { - log.Fatal(err) - } - - if err := s.AddAddress(1, ipv4.ProtocolNumber, addr); err != nil { - log.Fatal(err) - } - - // Add default route. - s.SetRouteTable([]tcpip.Route{ - { - Destination: header.IPv4EmptySubnet, - NIC: 1, - }, - }) - - // Create TCP endpoint. - var wq waiter.Queue - ep, e := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq) - if e != nil { - log.Fatal(e) - } - - // Bind if a port is specified. - if localPort != 0 { - if err := ep.Bind(tcpip.FullAddress{0, "", localPort}); err != nil { - log.Fatal("Bind failed: ", err) - } - } - - // Issue connect request and wait for it to complete. - waitEntry, notifyCh := waiter.NewChannelEntry(nil) - wq.EventRegister(&waitEntry, waiter.WritableEvents) - terr := ep.Connect(remote) - if _, ok := terr.(*tcpip.ErrConnectStarted); ok { - fmt.Println("Connect is pending...") - <-notifyCh - terr = ep.LastError() - } - wq.EventUnregister(&waitEntry) - - if terr != nil { - log.Fatal("Unable to connect: ", terr) - } - - fmt.Println("Connected") - - // Start the writer in its own goroutine. - writerCompletedCh := make(chan struct{}) - go writer(writerCompletedCh, ep) // S/R-SAFE: sample code. - - // Read data and write to standard output until the peer closes the - // connection from its side. - wq.EventRegister(&waitEntry, waiter.ReadableEvents) - for { - _, err := ep.Read(os.Stdout, tcpip.ReadOptions{}) - if err != nil { - if _, ok := err.(*tcpip.ErrClosedForReceive); ok { - break - } - - if _, ok := err.(*tcpip.ErrWouldBlock); ok { - <-notifyCh - continue - } - - log.Fatal("Read() failed:", err) - } - } - wq.EventUnregister(&waitEntry) - - // The reader has completed. Now wait for the writer as well. - <-writerCompletedCh - - ep.Close() -} diff --git a/pkg/tcpip/sample/tun_tcp_echo/BUILD b/pkg/tcpip/sample/tun_tcp_echo/BUILD deleted file mode 100644 index 43264b76d..000000000 --- a/pkg/tcpip/sample/tun_tcp_echo/BUILD +++ /dev/null @@ -1,21 +0,0 @@ -load("//tools:defs.bzl", "go_binary") - -package(licenses = ["notice"]) - -go_binary( - name = "tun_tcp_echo", - srcs = ["main.go"], - visibility = ["//:sandbox"], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/link/fdbased", - "//pkg/tcpip/link/rawfile", - "//pkg/tcpip/link/tun", - "//pkg/tcpip/network/arp", - "//pkg/tcpip/network/ipv4", - "//pkg/tcpip/network/ipv6", - "//pkg/tcpip/stack", - "//pkg/tcpip/transport/tcp", - "//pkg/waiter", - ], -) diff --git a/pkg/tcpip/sample/tun_tcp_echo/main.go b/pkg/tcpip/sample/tun_tcp_echo/main.go deleted file mode 100644 index c10b19aa0..000000000 --- a/pkg/tcpip/sample/tun_tcp_echo/main.go +++ /dev/null @@ -1,231 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//go:build linux -// +build linux - -// This sample creates a stack with TCP and IPv4 protocols on top of a TUN -// device, and listens on a port. Data received by the server in the accepted -// connections is echoed back to the clients. -package main - -import ( - "bytes" - "flag" - "io" - "log" - "math/rand" - "net" - "os" - "strconv" - "strings" - "time" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/link/fdbased" - "gvisor.dev/gvisor/pkg/tcpip/link/rawfile" - "gvisor.dev/gvisor/pkg/tcpip/link/tun" - "gvisor.dev/gvisor/pkg/tcpip/network/arp" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" - "gvisor.dev/gvisor/pkg/waiter" -) - -var tap = flag.Bool("tap", false, "use tap istead of tun") -var mac = flag.String("mac", "aa:00:01:01:01:01", "mac address to use in tap device") - -type endpointWriter struct { - ep tcpip.Endpoint -} - -type tcpipError struct { - inner tcpip.Error -} - -func (e *tcpipError) Error() string { - return e.inner.String() -} - -func (e *endpointWriter) Write(p []byte) (int, error) { - var r bytes.Reader - r.Reset(p) - n, err := e.ep.Write(&r, tcpip.WriteOptions{}) - if err != nil { - return int(n), &tcpipError{ - inner: err, - } - } - if n != int64(len(p)) { - return int(n), io.ErrShortWrite - } - return int(n), nil -} - -func echo(wq *waiter.Queue, ep tcpip.Endpoint) { - defer ep.Close() - - // Create wait queue entry that notifies a channel. - waitEntry, notifyCh := waiter.NewChannelEntry(nil) - - wq.EventRegister(&waitEntry, waiter.ReadableEvents) - defer wq.EventUnregister(&waitEntry) - - w := endpointWriter{ - ep: ep, - } - - for { - _, err := ep.Read(&w, tcpip.ReadOptions{}) - if err != nil { - if _, ok := err.(*tcpip.ErrWouldBlock); ok { - <-notifyCh - continue - } - - return - } - } -} - -func main() { - flag.Parse() - if len(flag.Args()) != 3 { - log.Fatal("Usage: ", os.Args[0], " <tun-device> <local-address> <local-port>") - } - - tunName := flag.Arg(0) - addrName := flag.Arg(1) - portName := flag.Arg(2) - - rand.Seed(time.Now().UnixNano()) - - // Parse the mac address. - maddr, err := net.ParseMAC(*mac) - if err != nil { - log.Fatalf("Bad MAC address: %v", *mac) - } - - // Parse the IP address. Support both ipv4 and ipv6. - parsedAddr := net.ParseIP(addrName) - if parsedAddr == nil { - log.Fatalf("Bad IP address: %v", addrName) - } - - var addr tcpip.Address - var proto tcpip.NetworkProtocolNumber - if parsedAddr.To4() != nil { - addr = tcpip.Address(parsedAddr.To4()) - proto = ipv4.ProtocolNumber - } else if parsedAddr.To16() != nil { - addr = tcpip.Address(parsedAddr.To16()) - proto = ipv6.ProtocolNumber - } else { - log.Fatalf("Unknown IP type: %v", addrName) - } - - localPort, err := strconv.Atoi(portName) - if err != nil { - log.Fatalf("Unable to convert port %v: %v", portName, err) - } - - // Create the stack with ip and tcp protocols, then add a tun-based - // NIC and address. - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol, arp.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, - }) - - mtu, err := rawfile.GetMTU(tunName) - if err != nil { - log.Fatal(err) - } - - var fd int - if *tap { - fd, err = tun.OpenTAP(tunName) - } else { - fd, err = tun.Open(tunName) - } - if err != nil { - log.Fatal(err) - } - - linkEP, err := fdbased.New(&fdbased.Options{ - FDs: []int{fd}, - MTU: mtu, - EthernetHeader: *tap, - Address: tcpip.LinkAddress(maddr), - }) - if err != nil { - log.Fatal(err) - } - if err := s.CreateNIC(1, linkEP); err != nil { - log.Fatal(err) - } - - if err := s.AddAddress(1, proto, addr); err != nil { - log.Fatal(err) - } - - subnet, err := tcpip.NewSubnet(tcpip.Address(strings.Repeat("\x00", len(addr))), tcpip.AddressMask(strings.Repeat("\x00", len(addr)))) - if err != nil { - log.Fatal(err) - } - - // Add default route. - s.SetRouteTable([]tcpip.Route{ - { - Destination: subnet, - NIC: 1, - }, - }) - - // Create TCP endpoint, bind it, then start listening. - var wq waiter.Queue - ep, e := s.NewEndpoint(tcp.ProtocolNumber, proto, &wq) - if e != nil { - log.Fatal(e) - } - - defer ep.Close() - - if err := ep.Bind(tcpip.FullAddress{0, "", uint16(localPort)}); err != nil { - log.Fatal("Bind failed: ", err) - } - - if err := ep.Listen(10); err != nil { - log.Fatal("Listen failed: ", err) - } - - // Wait for connections to appear. - waitEntry, notifyCh := waiter.NewChannelEntry(nil) - wq.EventRegister(&waitEntry, waiter.ReadableEvents) - defer wq.EventUnregister(&waitEntry) - - for { - n, wq, err := ep.Accept(nil) - if err != nil { - if _, ok := err.(*tcpip.ErrWouldBlock); ok { - <-notifyCh - continue - } - - log.Fatal("Accept() failed:", err) - } - - go echo(wq, n) // S/R-SAFE: sample code. - } -} diff --git a/pkg/tcpip/seqnum/BUILD b/pkg/tcpip/seqnum/BUILD deleted file mode 100644 index 45f503845..000000000 --- a/pkg/tcpip/seqnum/BUILD +++ /dev/null @@ -1,9 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "seqnum", - srcs = ["seqnum.go"], - visibility = ["//visibility:public"], -) diff --git a/pkg/tcpip/seqnum/seqnum_state_autogen.go b/pkg/tcpip/seqnum/seqnum_state_autogen.go new file mode 100644 index 000000000..23e79811d --- /dev/null +++ b/pkg/tcpip/seqnum/seqnum_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package seqnum diff --git a/pkg/tcpip/sock_err_list.go b/pkg/tcpip/sock_err_list.go new file mode 100644 index 000000000..0be1993af --- /dev/null +++ b/pkg/tcpip/sock_err_list.go @@ -0,0 +1,221 @@ +package tcpip + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type sockErrorElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (sockErrorElementMapper) linkerFor(elem *SockError) *SockError { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type sockErrorList struct { + head *SockError + tail *SockError +} + +// Reset resets list l to the empty state. +func (l *sockErrorList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +// +//go:nosplit +func (l *sockErrorList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +// +//go:nosplit +func (l *sockErrorList) Front() *SockError { + return l.head +} + +// Back returns the last element of list l or nil. +// +//go:nosplit +func (l *sockErrorList) Back() *SockError { + return l.tail +} + +// Len returns the number of elements in the list. +// +// NOTE: This is an O(n) operation. +// +//go:nosplit +func (l *sockErrorList) Len() (count int) { + for e := l.Front(); e != nil; e = (sockErrorElementMapper{}.linkerFor(e)).Next() { + count++ + } + return count +} + +// PushFront inserts the element e at the front of list l. +// +//go:nosplit +func (l *sockErrorList) PushFront(e *SockError) { + linker := sockErrorElementMapper{}.linkerFor(e) + linker.SetNext(l.head) + linker.SetPrev(nil) + if l.head != nil { + sockErrorElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +// +//go:nosplit +func (l *sockErrorList) PushBack(e *SockError) { + linker := sockErrorElementMapper{}.linkerFor(e) + linker.SetNext(nil) + linker.SetPrev(l.tail) + if l.tail != nil { + sockErrorElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +// +//go:nosplit +func (l *sockErrorList) PushBackList(m *sockErrorList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + sockErrorElementMapper{}.linkerFor(l.tail).SetNext(m.head) + sockErrorElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +// +//go:nosplit +func (l *sockErrorList) InsertAfter(b, e *SockError) { + bLinker := sockErrorElementMapper{}.linkerFor(b) + eLinker := sockErrorElementMapper{}.linkerFor(e) + + a := bLinker.Next() + + eLinker.SetNext(a) + eLinker.SetPrev(b) + bLinker.SetNext(e) + + if a != nil { + sockErrorElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +// +//go:nosplit +func (l *sockErrorList) InsertBefore(a, e *SockError) { + aLinker := sockErrorElementMapper{}.linkerFor(a) + eLinker := sockErrorElementMapper{}.linkerFor(e) + + b := aLinker.Prev() + eLinker.SetNext(a) + eLinker.SetPrev(b) + aLinker.SetPrev(e) + + if b != nil { + sockErrorElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +// +//go:nosplit +func (l *sockErrorList) Remove(e *SockError) { + linker := sockErrorElementMapper{}.linkerFor(e) + prev := linker.Prev() + next := linker.Next() + + if prev != nil { + sockErrorElementMapper{}.linkerFor(prev).SetNext(next) + } else if l.head == e { + l.head = next + } + + if next != nil { + sockErrorElementMapper{}.linkerFor(next).SetPrev(prev) + } else if l.tail == e { + l.tail = prev + } + + linker.SetNext(nil) + linker.SetPrev(nil) +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type sockErrorEntry struct { + next *SockError + prev *SockError +} + +// Next returns the entry that follows e in the list. +// +//go:nosplit +func (e *sockErrorEntry) Next() *SockError { + return e.next +} + +// Prev returns the entry that precedes e in the list. +// +//go:nosplit +func (e *sockErrorEntry) Prev() *SockError { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +// +//go:nosplit +func (e *sockErrorEntry) SetNext(elem *SockError) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +// +//go:nosplit +func (e *sockErrorEntry) SetPrev(elem *SockError) { + e.prev = elem +} diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD deleted file mode 100644 index e0847e58a..000000000 --- a/pkg/tcpip/stack/BUILD +++ /dev/null @@ -1,152 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test", "most_shards") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -package(licenses = ["notice"]) - -go_template_instance( - name = "neighbor_entry_list", - out = "neighbor_entry_list.go", - package = "stack", - prefix = "neighborEntry", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*neighborEntry", - "Linker": "*neighborEntry", - }, -) - -go_template_instance( - name = "packet_buffer_list", - out = "packet_buffer_list.go", - package = "stack", - prefix = "PacketBuffer", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*PacketBuffer", - "Linker": "*PacketBuffer", - }, -) - -go_template_instance( - name = "tuple_list", - out = "tuple_list.go", - package = "stack", - prefix = "tuple", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*tuple", - "Linker": "*tuple", - }, -) - -go_library( - name = "stack", - srcs = [ - "addressable_endpoint_state.go", - "conntrack.go", - "headertype_string.go", - "hook_string.go", - "icmp_rate_limit.go", - "iptables.go", - "iptables_state.go", - "iptables_targets.go", - "iptables_types.go", - "neighbor_cache.go", - "neighbor_entry.go", - "neighbor_entry_list.go", - "neighborstate_string.go", - "nic.go", - "nic_stats.go", - "nud.go", - "packet_buffer.go", - "packet_buffer_list.go", - "packet_buffer_unsafe.go", - "pending_packets.go", - "rand.go", - "registration.go", - "route.go", - "stack.go", - "stack_global_state.go", - "stack_options.go", - "tcp.go", - "transport_demuxer.go", - "tuple_list.go", - ], - visibility = ["//visibility:public"], - deps = [ - "//pkg/atomicbitops", - "//pkg/buffer", - "//pkg/ilist", - "//pkg/log", - "//pkg/rand", - "//pkg/sleep", - "//pkg/sync", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/hash/jenkins", - "//pkg/tcpip/header", - "//pkg/tcpip/ports", - "//pkg/tcpip/seqnum", - "//pkg/tcpip/transport/tcpconntrack", - "//pkg/waiter", - "@org_golang_x_time//rate:go_default_library", - ], -) - -go_test( - name = "stack_x_test", - size = "small", - srcs = [ - "addressable_endpoint_state_test.go", - "ndp_test.go", - "nud_test.go", - "stack_test.go", - "transport_demuxer_test.go", - "transport_test.go", - ], - shard_count = most_shards, - deps = [ - ":stack", - "//pkg/rand", - "//pkg/sync", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/checker", - "//pkg/tcpip/faketime", - "//pkg/tcpip/header", - "//pkg/tcpip/link/channel", - "//pkg/tcpip/link/loopback", - "//pkg/tcpip/network/arp", - "//pkg/tcpip/network/ipv4", - "//pkg/tcpip/network/ipv6", - "//pkg/tcpip/ports", - "//pkg/tcpip/testutil", - "//pkg/tcpip/transport/icmp", - "//pkg/tcpip/transport/udp", - "//pkg/waiter", - "@com_github_google_go_cmp//cmp:go_default_library", - ], -) - -go_test( - name = "stack_test", - size = "small", - srcs = [ - "forwarding_test.go", - "neighbor_cache_test.go", - "neighbor_entry_test.go", - "nic_test.go", - "packet_buffer_test.go", - ], - library = ":stack", - deps = [ - "//pkg/sync", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/faketime", - "//pkg/tcpip/header", - "//pkg/tcpip/testutil", - "@com_github_google_go_cmp//cmp:go_default_library", - "@com_github_google_go_cmp//cmp/cmpopts:go_default_library", - ], -) diff --git a/pkg/tcpip/stack/addressable_endpoint_state_test.go b/pkg/tcpip/stack/addressable_endpoint_state_test.go deleted file mode 100644 index 140f146f6..000000000 --- a/pkg/tcpip/stack/addressable_endpoint_state_test.go +++ /dev/null @@ -1,61 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package stack_test - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/stack" -) - -// TestAddressableEndpointStateCleanup tests that cleaning up an addressable -// endpoint state removes permanent addresses and leaves groups. -func TestAddressableEndpointStateCleanup(t *testing.T) { - var ep fakeNetworkEndpoint - if err := ep.Enable(); err != nil { - t.Fatalf("ep.Enable(): %s", err) - } - - var s stack.AddressableEndpointState - s.Init(&ep) - - addr := tcpip.AddressWithPrefix{ - Address: "\x01", - PrefixLen: 8, - } - - { - ep, err := s.AddAndAcquirePermanentAddress(addr, stack.NeverPrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */) - if err != nil { - t.Fatalf("s.AddAndAcquirePermanentAddress(%s, %d, %d, false): %s", addr, stack.NeverPrimaryEndpoint, stack.AddressConfigStatic, err) - } - // We don't need the address endpoint. - ep.DecRef() - } - { - ep := s.AcquireAssignedAddress(addr.Address, false /* allowTemp */, stack.NeverPrimaryEndpoint) - if ep == nil { - t.Fatalf("got s.AcquireAssignedAddress(%s, false, NeverPrimaryEndpoint) = nil, want = non-nil", addr.Address) - } - ep.DecRef() - } - - s.Cleanup() - if ep := s.AcquireAssignedAddress(addr.Address, false /* allowTemp */, stack.NeverPrimaryEndpoint); ep != nil { - ep.DecRef() - t.Fatalf("got s.AcquireAssignedAddress(%s, false, NeverPrimaryEndpoint) = %s, want = nil", addr.Address, ep.AddressWithPrefix()) - } -} diff --git a/pkg/tcpip/stack/forwarding_test.go b/pkg/tcpip/stack/forwarding_test.go deleted file mode 100644 index 72f66441f..000000000 --- a/pkg/tcpip/stack/forwarding_test.go +++ /dev/null @@ -1,790 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package stack - -import ( - "encoding/binary" - "math" - "testing" - "time" - - "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/faketime" - "gvisor.dev/gvisor/pkg/tcpip/header" -) - -const ( - fwdTestNetNumber tcpip.NetworkProtocolNumber = math.MaxUint32 - fwdTestNetHeaderLen = 12 - fwdTestNetDefaultPrefixLen = 8 - - // fwdTestNetDefaultMTU is the MTU, in bytes, used throughout the tests, - // except where another value is explicitly used. It is chosen to match - // the MTU of loopback interfaces on linux systems. - fwdTestNetDefaultMTU = 65536 - - dstAddrOffset = 0 - srcAddrOffset = 1 - protocolNumberOffset = 2 -) - -var _ LinkAddressResolver = (*fwdTestNetworkEndpoint)(nil) -var _ NetworkEndpoint = (*fwdTestNetworkEndpoint)(nil) - -// fwdTestNetworkEndpoint is a network-layer protocol endpoint. -// Headers of this protocol are fwdTestNetHeaderLen bytes, but we currently only -// use the first three: destination address, source address, and transport -// protocol. They're all one byte fields to simplify parsing. -type fwdTestNetworkEndpoint struct { - AddressableEndpointState - - nic NetworkInterface - proto *fwdTestNetworkProtocol - dispatcher TransportDispatcher - - mu struct { - sync.RWMutex - forwarding bool - } -} - -func (*fwdTestNetworkEndpoint) Enable() tcpip.Error { - return nil -} - -func (*fwdTestNetworkEndpoint) Enabled() bool { - return true -} - -func (*fwdTestNetworkEndpoint) Disable() {} - -func (f *fwdTestNetworkEndpoint) MTU() uint32 { - return f.nic.MTU() - uint32(f.MaxHeaderLength()) -} - -func (*fwdTestNetworkEndpoint) DefaultTTL() uint8 { - return 123 -} - -func (f *fwdTestNetworkEndpoint) HandlePacket(pkt *PacketBuffer) { - if _, _, ok := f.proto.Parse(pkt); !ok { - return - } - - netHdr := pkt.NetworkHeader().View() - _, dst := f.proto.ParseAddresses(netHdr) - - addressEndpoint := f.AcquireAssignedAddress(dst, f.nic.Promiscuous(), CanBePrimaryEndpoint) - if addressEndpoint != nil { - addressEndpoint.DecRef() - // Dispatch the packet to the transport protocol. - f.dispatcher.DeliverTransportPacket(tcpip.TransportProtocolNumber(netHdr[protocolNumberOffset]), pkt) - return - } - - r, err := f.proto.stack.FindRoute(0, "", dst, fwdTestNetNumber, false /* multicastLoop */) - if err != nil { - return - } - defer r.Release() - - vv := buffer.NewVectorisedView(pkt.Size(), pkt.Views()) - pkt = NewPacketBuffer(PacketBufferOptions{ - ReserveHeaderBytes: int(r.MaxHeaderLength()), - Data: vv.ToView().ToVectorisedView(), - }) - // TODO(gvisor.dev/issue/1085) Decrease the TTL field in forwarded packets. - _ = r.WriteHeaderIncludedPacket(pkt) -} - -func (f *fwdTestNetworkEndpoint) MaxHeaderLength() uint16 { - return f.nic.MaxHeaderLength() + fwdTestNetHeaderLen -} - -func (f *fwdTestNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { - return f.proto.Number() -} - -func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, params NetworkHeaderParams, pkt *PacketBuffer) tcpip.Error { - // Add the protocol's header to the packet and send it to the link - // endpoint. - b := pkt.NetworkHeader().Push(fwdTestNetHeaderLen) - b[dstAddrOffset] = r.RemoteAddress()[0] - b[srcAddrOffset] = r.LocalAddress()[0] - b[protocolNumberOffset] = byte(params.Protocol) - - return f.nic.WritePacket(r, fwdTestNetNumber, pkt) -} - -// WritePackets implements LinkEndpoint.WritePackets. -func (*fwdTestNetworkEndpoint) WritePackets(*Route, PacketBufferList, NetworkHeaderParams) (int, tcpip.Error) { - panic("not implemented") -} - -func (f *fwdTestNetworkEndpoint) WriteHeaderIncludedPacket(r *Route, pkt *PacketBuffer) tcpip.Error { - // The network header should not already be populated. - if _, ok := pkt.NetworkHeader().Consume(fwdTestNetHeaderLen); !ok { - return &tcpip.ErrMalformedHeader{} - } - - return f.nic.WritePacket(r, fwdTestNetNumber, pkt) -} - -func (f *fwdTestNetworkEndpoint) Close() { - f.AddressableEndpointState.Cleanup() -} - -// Stats implements stack.NetworkEndpoint. -func (*fwdTestNetworkEndpoint) Stats() NetworkEndpointStats { - return &fwdTestNetworkEndpointStats{} -} - -var _ NetworkEndpointStats = (*fwdTestNetworkEndpointStats)(nil) - -type fwdTestNetworkEndpointStats struct{} - -// IsNetworkEndpointStats implements stack.NetworkEndpointStats. -func (*fwdTestNetworkEndpointStats) IsNetworkEndpointStats() {} - -var _ NetworkProtocol = (*fwdTestNetworkProtocol)(nil) - -// fwdTestNetworkProtocol is a network-layer protocol that implements Address -// resolution. -type fwdTestNetworkProtocol struct { - stack *Stack - - neigh *neighborCache - addrResolveDelay time.Duration - onLinkAddressResolved func(*neighborCache, tcpip.Address, tcpip.LinkAddress) - onResolveStaticAddress func(tcpip.Address) (tcpip.LinkAddress, bool) -} - -func (*fwdTestNetworkProtocol) Number() tcpip.NetworkProtocolNumber { - return fwdTestNetNumber -} - -func (*fwdTestNetworkProtocol) MinimumPacketSize() int { - return fwdTestNetHeaderLen -} - -func (*fwdTestNetworkProtocol) DefaultPrefixLen() int { - return fwdTestNetDefaultPrefixLen -} - -func (*fwdTestNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { - return tcpip.Address(v[srcAddrOffset : srcAddrOffset+1]), tcpip.Address(v[dstAddrOffset : dstAddrOffset+1]) -} - -func (*fwdTestNetworkProtocol) Parse(pkt *PacketBuffer) (tcpip.TransportProtocolNumber, bool, bool) { - netHeader, ok := pkt.NetworkHeader().Consume(fwdTestNetHeaderLen) - if !ok { - return 0, false, false - } - return tcpip.TransportProtocolNumber(netHeader[protocolNumberOffset]), true, true -} - -func (f *fwdTestNetworkProtocol) NewEndpoint(nic NetworkInterface, dispatcher TransportDispatcher) NetworkEndpoint { - e := &fwdTestNetworkEndpoint{ - nic: nic, - proto: f, - dispatcher: dispatcher, - } - e.AddressableEndpointState.Init(e) - return e -} - -func (*fwdTestNetworkProtocol) SetOption(tcpip.SettableNetworkProtocolOption) tcpip.Error { - return &tcpip.ErrUnknownProtocolOption{} -} - -func (*fwdTestNetworkProtocol) Option(tcpip.GettableNetworkProtocolOption) tcpip.Error { - return &tcpip.ErrUnknownProtocolOption{} -} - -func (*fwdTestNetworkProtocol) Close() {} - -func (*fwdTestNetworkProtocol) Wait() {} - -func (f *fwdTestNetworkEndpoint) LinkAddressRequest(addr, _ tcpip.Address, remoteLinkAddr tcpip.LinkAddress) tcpip.Error { - if fn := f.proto.onLinkAddressResolved; fn != nil { - f.proto.stack.clock.AfterFunc(f.proto.addrResolveDelay, func() { - fn(f.proto.neigh, addr, remoteLinkAddr) - }) - } - return nil -} - -func (f *fwdTestNetworkEndpoint) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) { - if fn := f.proto.onResolveStaticAddress; fn != nil { - return fn(addr) - } - return "", false -} - -func (*fwdTestNetworkEndpoint) LinkAddressProtocol() tcpip.NetworkProtocolNumber { - return fwdTestNetNumber -} - -// Forwarding implements stack.ForwardingNetworkEndpoint. -func (f *fwdTestNetworkEndpoint) Forwarding() bool { - f.mu.RLock() - defer f.mu.RUnlock() - return f.mu.forwarding - -} - -// SetForwarding implements stack.ForwardingNetworkEndpoint. -func (f *fwdTestNetworkEndpoint) SetForwarding(v bool) { - f.mu.Lock() - defer f.mu.Unlock() - f.mu.forwarding = v -} - -// fwdTestPacketInfo holds all the information about an outbound packet. -type fwdTestPacketInfo struct { - RemoteLinkAddress tcpip.LinkAddress - LocalLinkAddress tcpip.LinkAddress - Pkt *PacketBuffer -} - -var _ LinkEndpoint = (*fwdTestLinkEndpoint)(nil) - -type fwdTestLinkEndpoint struct { - dispatcher NetworkDispatcher - mtu uint32 - linkAddr tcpip.LinkAddress - - // C is where outbound packets are queued. - C chan fwdTestPacketInfo -} - -// InjectInbound injects an inbound packet. -func (e *fwdTestLinkEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { - e.InjectLinkAddr(protocol, "", pkt) -} - -// InjectLinkAddr injects an inbound packet with a remote link address. -func (e *fwdTestLinkEndpoint) InjectLinkAddr(protocol tcpip.NetworkProtocolNumber, remote tcpip.LinkAddress, pkt *PacketBuffer) { - e.dispatcher.DeliverNetworkPacket(remote, "" /* local */, protocol, pkt) -} - -// Attach saves the stack network-layer dispatcher for use later when packets -// are injected. -func (e *fwdTestLinkEndpoint) Attach(dispatcher NetworkDispatcher) { - e.dispatcher = dispatcher -} - -// IsAttached implements stack.LinkEndpoint.IsAttached. -func (e *fwdTestLinkEndpoint) IsAttached() bool { - return e.dispatcher != nil -} - -// MTU implements stack.LinkEndpoint.MTU. It returns the value initialized -// during construction. -func (e *fwdTestLinkEndpoint) MTU() uint32 { - return e.mtu -} - -// Capabilities implements stack.LinkEndpoint.Capabilities. -func (e fwdTestLinkEndpoint) Capabilities() LinkEndpointCapabilities { - caps := LinkEndpointCapabilities(0) - return caps | CapabilityResolutionRequired -} - -// MaxHeaderLength returns the maximum size of the link layer header. Given it -// doesn't have a header, it just returns 0. -func (*fwdTestLinkEndpoint) MaxHeaderLength() uint16 { - return 0 -} - -// LinkAddress returns the link address of this endpoint. -func (e *fwdTestLinkEndpoint) LinkAddress() tcpip.LinkAddress { - return e.linkAddr -} - -func (e fwdTestLinkEndpoint) WritePacket(r RouteInfo, _ tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error { - p := fwdTestPacketInfo{ - RemoteLinkAddress: r.RemoteLinkAddress, - LocalLinkAddress: r.LocalLinkAddress, - Pkt: pkt, - } - - select { - case e.C <- p: - default: - } - - return nil -} - -// WritePackets stores outbound packets into the channel. -func (e *fwdTestLinkEndpoint) WritePackets(r RouteInfo, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { - n := 0 - for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { - e.WritePacket(r, protocol, pkt) - n++ - } - - return n, nil -} - -// Wait implements stack.LinkEndpoint.Wait. -func (*fwdTestLinkEndpoint) Wait() {} - -// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. -func (*fwdTestLinkEndpoint) ARPHardwareType() header.ARPHardwareType { - panic("not implemented") -} - -// AddHeader implements stack.LinkEndpoint.AddHeader. -func (e *fwdTestLinkEndpoint) AddHeader(tcpip.LinkAddress, tcpip.LinkAddress, tcpip.NetworkProtocolNumber, *PacketBuffer) { - panic("not implemented") -} - -func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (*faketime.ManualClock, *fwdTestLinkEndpoint, *fwdTestLinkEndpoint) { - clock := faketime.NewManualClock() - // Create a stack with the network protocol and two NICs. - s := New(Options{ - NetworkProtocols: []NetworkProtocolFactory{func(s *Stack) NetworkProtocol { - proto.stack = s - return proto - }}, - Clock: clock, - }) - - protoNum := proto.Number() - if err := s.SetForwardingDefaultAndAllNICs(protoNum, true); err != nil { - t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", protoNum, err) - } - - // NIC 1 has the link address "a", and added the network address 1. - ep1 := &fwdTestLinkEndpoint{ - C: make(chan fwdTestPacketInfo, 300), - mtu: fwdTestNetDefaultMTU, - linkAddr: "a", - } - if err := s.CreateNIC(1, ep1); err != nil { - t.Fatal("CreateNIC #1 failed:", err) - } - if err := s.AddAddress(1, fwdTestNetNumber, "\x01"); err != nil { - t.Fatal("AddAddress #1 failed:", err) - } - - // NIC 2 has the link address "b", and added the network address 2. - ep2 := &fwdTestLinkEndpoint{ - C: make(chan fwdTestPacketInfo, 300), - mtu: fwdTestNetDefaultMTU, - linkAddr: "b", - } - if err := s.CreateNIC(2, ep2); err != nil { - t.Fatal("CreateNIC #2 failed:", err) - } - if err := s.AddAddress(2, fwdTestNetNumber, "\x02"); err != nil { - t.Fatal("AddAddress #2 failed:", err) - } - - nic, ok := s.nics[2] - if !ok { - t.Fatal("NIC 2 does not exist") - } - - if l, ok := nic.linkAddrResolvers[fwdTestNetNumber]; ok { - proto.neigh = &l.neigh - } - - // Route all packets to NIC 2. - { - subnet, err := tcpip.NewSubnet("\x00", "\x00") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, NIC: 2}}) - } - - return clock, ep1, ep2 -} - -func TestForwardingWithStaticResolver(t *testing.T) { - // Create a network protocol with a static resolver. - proto := &fwdTestNetworkProtocol{ - onResolveStaticAddress: - // The network address 3 is resolved to the link address "c". - func(addr tcpip.Address) (tcpip.LinkAddress, bool) { - if addr == "\x03" { - return "c", true - } - return "", false - }, - } - - clock, ep1, ep2 := fwdTestNetFactory(t, proto) - - // Inject an inbound packet to address 3 on NIC 1, and see if it is - // forwarded to NIC 2. - buf := buffer.NewView(30) - buf[dstAddrOffset] = 3 - ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - - var p fwdTestPacketInfo - - clock.Advance(proto.addrResolveDelay) - select { - case p = <-ep2.C: - default: - t.Fatal("packet not forwarded") - } - - // Test that the static address resolution happened correctly. - if p.RemoteLinkAddress != "c" { - t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) - } - if p.LocalLinkAddress != "b" { - t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) - } -} - -func TestForwardingWithFakeResolver(t *testing.T) { - proto := fwdTestNetworkProtocol{ - addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(neigh *neighborCache, addr tcpip.Address, linkAddr tcpip.LinkAddress) { - t.Helper() - if len(linkAddr) != 0 { - t.Fatalf("got linkAddr=%q, want unspecified", linkAddr) - } - // Any address will be resolved to the link address "c". - neigh.handleConfirmation(addr, "c", ReachabilityConfirmationFlags{ - Solicited: true, - Override: false, - IsRouter: false, - }) - }, - } - clock, ep1, ep2 := fwdTestNetFactory(t, &proto) - - // Inject an inbound packet to address 3 on NIC 1, and see if it is - // forwarded to NIC 2. - buf := buffer.NewView(30) - buf[dstAddrOffset] = 3 - ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - - var p fwdTestPacketInfo - - clock.Advance(proto.addrResolveDelay) - select { - case p = <-ep2.C: - default: - t.Fatal("packet not forwarded") - } - - // Test that the address resolution happened correctly. - if p.RemoteLinkAddress != "c" { - t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) - } - if p.LocalLinkAddress != "b" { - t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) - } -} - -func TestForwardingWithNoResolver(t *testing.T) { - // Create a network protocol without a resolver. - proto := &fwdTestNetworkProtocol{} - - // Whether or not we use the neighbor cache here does not matter since - // neither linkAddrCache nor neighborCache will be used. - clock, ep1, ep2 := fwdTestNetFactory(t, proto) - - // inject an inbound packet to address 3 on NIC 1, and see if it is - // forwarded to NIC 2. - buf := buffer.NewView(30) - buf[dstAddrOffset] = 3 - ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - - clock.Advance(proto.addrResolveDelay) - select { - case <-ep2.C: - t.Fatal("Packet should not be forwarded") - default: - } -} - -func TestForwardingResolutionFailsForQueuedPackets(t *testing.T) { - proto := &fwdTestNetworkProtocol{ - addrResolveDelay: 50 * time.Millisecond, - onLinkAddressResolved: func(*neighborCache, tcpip.Address, tcpip.LinkAddress) { - // Don't resolve the link address. - }, - } - - clock, ep1, ep2 := fwdTestNetFactory(t, proto) - - const numPackets int = 5 - // These packets will all be enqueued in the packet queue to wait for link - // address resolution. - for i := 0; i < numPackets; i++ { - buf := buffer.NewView(30) - buf[dstAddrOffset] = 3 - ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - } - - // All packets should fail resolution. - for i := 0; i < numPackets; i++ { - clock.Advance(proto.addrResolveDelay) - select { - case got := <-ep2.C: - t.Fatalf("got %#v; packets should have failed resolution and not been forwarded", got) - default: - } - } -} - -func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) { - proto := fwdTestNetworkProtocol{ - addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(neigh *neighborCache, addr tcpip.Address, linkAddr tcpip.LinkAddress) { - t.Helper() - if len(linkAddr) != 0 { - t.Fatalf("got linkAddr=%q, want unspecified", linkAddr) - } - // Only packets to address 3 will be resolved to the - // link address "c". - if addr == "\x03" { - neigh.handleConfirmation(addr, "c", ReachabilityConfirmationFlags{ - Solicited: true, - Override: false, - IsRouter: false, - }) - } - }, - } - clock, ep1, ep2 := fwdTestNetFactory(t, &proto) - - // Inject an inbound packet to address 4 on NIC 1. This packet should - // not be forwarded. - buf := buffer.NewView(30) - buf[dstAddrOffset] = 4 - ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - - // Inject an inbound packet to address 3 on NIC 1, and see if it is - // forwarded to NIC 2. - buf = buffer.NewView(30) - buf[dstAddrOffset] = 3 - ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - - var p fwdTestPacketInfo - - clock.Advance(proto.addrResolveDelay) - select { - case p = <-ep2.C: - default: - t.Fatal("packet not forwarded") - } - - if nh := PayloadSince(p.Pkt.NetworkHeader()); nh[dstAddrOffset] != 3 { - t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want = 3", nh[dstAddrOffset]) - } - - // Test that the address resolution happened correctly. - if p.RemoteLinkAddress != "c" { - t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) - } - if p.LocalLinkAddress != "b" { - t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) - } -} - -func TestForwardingWithFakeResolverTwoPackets(t *testing.T) { - proto := fwdTestNetworkProtocol{ - addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(neigh *neighborCache, addr tcpip.Address, linkAddr tcpip.LinkAddress) { - t.Helper() - if len(linkAddr) != 0 { - t.Fatalf("got linkAddr=%q, want unspecified", linkAddr) - } - // Any packets will be resolved to the link address "c". - neigh.handleConfirmation(addr, "c", ReachabilityConfirmationFlags{ - Solicited: true, - Override: false, - IsRouter: false, - }) - }, - } - clock, ep1, ep2 := fwdTestNetFactory(t, &proto) - - // Inject two inbound packets to address 3 on NIC 1. - for i := 0; i < 2; i++ { - buf := buffer.NewView(30) - buf[dstAddrOffset] = 3 - ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - } - - for i := 0; i < 2; i++ { - var p fwdTestPacketInfo - - clock.Advance(proto.addrResolveDelay) - select { - case p = <-ep2.C: - default: - t.Fatal("packet not forwarded") - } - - if nh := PayloadSince(p.Pkt.NetworkHeader()); nh[dstAddrOffset] != 3 { - t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want = 3", nh[dstAddrOffset]) - } - - // Test that the address resolution happened correctly. - if p.RemoteLinkAddress != "c" { - t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) - } - if p.LocalLinkAddress != "b" { - t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) - } - } -} - -func TestForwardingWithFakeResolverManyPackets(t *testing.T) { - proto := fwdTestNetworkProtocol{ - addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(neigh *neighborCache, addr tcpip.Address, linkAddr tcpip.LinkAddress) { - t.Helper() - if len(linkAddr) != 0 { - t.Fatalf("got linkAddr=%q, want unspecified", linkAddr) - } - // Any packets will be resolved to the link address "c". - neigh.handleConfirmation(addr, "c", ReachabilityConfirmationFlags{ - Solicited: true, - Override: false, - IsRouter: false, - }) - }, - } - clock, ep1, ep2 := fwdTestNetFactory(t, &proto) - - for i := 0; i < maxPendingPacketsPerResolution+5; i++ { - // Inject inbound 'maxPendingPacketsPerResolution + 5' packets on NIC 1. - buf := buffer.NewView(30) - buf[dstAddrOffset] = 3 - // Set the packet sequence number. - binary.BigEndian.PutUint16(buf[fwdTestNetHeaderLen:], uint16(i)) - ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - } - - for i := 0; i < maxPendingPacketsPerResolution; i++ { - var p fwdTestPacketInfo - - clock.Advance(proto.addrResolveDelay) - select { - case p = <-ep2.C: - default: - t.Fatal("packet not forwarded") - } - - b := PayloadSince(p.Pkt.NetworkHeader()) - if b[dstAddrOffset] != 3 { - t.Fatalf("got b[dstAddrOffset] = %d, want = 3", b[dstAddrOffset]) - } - if len(b) < fwdTestNetHeaderLen+2 { - t.Fatalf("packet is too short to hold a sequence number: len(b) = %d", b) - } - seqNumBuf := b[fwdTestNetHeaderLen:] - - // The first 5 packets should not be forwarded so the sequence number should - // start with 5. - want := uint16(i + 5) - if n := binary.BigEndian.Uint16(seqNumBuf); n != want { - t.Fatalf("got the packet #%d, want = #%d", n, want) - } - - // Test that the address resolution happened correctly. - if p.RemoteLinkAddress != "c" { - t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) - } - if p.LocalLinkAddress != "b" { - t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) - } - } -} - -func TestForwardingWithFakeResolverManyResolutions(t *testing.T) { - proto := fwdTestNetworkProtocol{ - addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(neigh *neighborCache, addr tcpip.Address, linkAddr tcpip.LinkAddress) { - t.Helper() - if len(linkAddr) != 0 { - t.Fatalf("got linkAddr=%q, want unspecified", linkAddr) - } - // Any packets will be resolved to the link address "c". - neigh.handleConfirmation(addr, "c", ReachabilityConfirmationFlags{ - Solicited: true, - Override: false, - IsRouter: false, - }) - }, - } - clock, ep1, ep2 := fwdTestNetFactory(t, &proto) - - for i := 0; i < maxPendingResolutions+5; i++ { - // Inject inbound 'maxPendingResolutions + 5' packets on NIC 1. - // Each packet has a different destination address (3 to - // maxPendingResolutions + 7). - buf := buffer.NewView(30) - buf[dstAddrOffset] = byte(3 + i) - ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - } - - for i := 0; i < maxPendingResolutions; i++ { - var p fwdTestPacketInfo - - clock.Advance(proto.addrResolveDelay) - select { - case p = <-ep2.C: - default: - t.Fatal("packet not forwarded") - } - - // The first 5 packets (address 3 to 7) should not be forwarded - // because their address resolutions are interrupted. - if nh := PayloadSince(p.Pkt.NetworkHeader()); nh[dstAddrOffset] < 8 { - t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want p.Pkt.NetworkHeader[dstAddrOffset] >= 8", nh[dstAddrOffset]) - } - - // Test that the address resolution happened correctly. - if p.RemoteLinkAddress != "c" { - t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) - } - if p.LocalLinkAddress != "b" { - t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) - } - } -} diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go deleted file mode 100644 index 4d5431da1..000000000 --- a/pkg/tcpip/stack/ndp_test.go +++ /dev/null @@ -1,5569 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package stack_test - -import ( - "bytes" - "encoding/binary" - "fmt" - "math/rand" - "testing" - "time" - - "github.com/google/go-cmp/cmp" - cryptorand "gvisor.dev/gvisor/pkg/rand" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/checker" - "gvisor.dev/gvisor/pkg/tcpip/faketime" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/link/loopback" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/testutil" - "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" - "gvisor.dev/gvisor/pkg/tcpip/transport/udp" - "gvisor.dev/gvisor/pkg/waiter" -) - -var ( - addr1 = testutil.MustParse6("a00::1") - addr2 = testutil.MustParse6("a00::2") - addr3 = testutil.MustParse6("a00::3") -) - -const ( - linkAddr1 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") - linkAddr2 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x07") - linkAddr3 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x08") - linkAddr4 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x09") - - defaultPrefixLen = 128 -) - -var ( - llAddr1 = header.LinkLocalAddr(linkAddr1) - llAddr2 = header.LinkLocalAddr(linkAddr2) - llAddr3 = header.LinkLocalAddr(linkAddr3) - llAddr4 = header.LinkLocalAddr(linkAddr4) - dstAddr = tcpip.FullAddress{ - Addr: "\x0a\x0b\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", - Port: 25, - } -) - -func addrForSubnet(subnet tcpip.Subnet, linkAddr tcpip.LinkAddress) tcpip.AddressWithPrefix { - if !header.IsValidUnicastEthernetAddress(linkAddr) { - return tcpip.AddressWithPrefix{} - } - - addrBytes := []byte(subnet.ID()) - header.EthernetAdddressToModifiedEUI64IntoBuf(linkAddr, addrBytes[header.IIDOffsetInIPv6Address:]) - return tcpip.AddressWithPrefix{ - Address: tcpip.Address(addrBytes), - PrefixLen: 64, - } -} - -// prefixSubnetAddr returns a prefix (Address + Length), the prefix's equivalent -// tcpip.Subnet, and an address where the lower half of the address is composed -// of the EUI-64 of linkAddr if it is a valid unicast ethernet address. -func prefixSubnetAddr(offset uint8, linkAddr tcpip.LinkAddress) (tcpip.AddressWithPrefix, tcpip.Subnet, tcpip.AddressWithPrefix) { - prefixBytes := []byte{1, 2, 3, 4, 5, 6, 7, 8 + offset, 0, 0, 0, 0, 0, 0, 0, 0} - prefix := tcpip.AddressWithPrefix{ - Address: tcpip.Address(prefixBytes), - PrefixLen: 64, - } - - subnet := prefix.Subnet() - - return prefix, subnet, addrForSubnet(subnet, linkAddr) -} - -// ndpDADEvent is a set of parameters that was passed to -// ndpDispatcher.OnDuplicateAddressDetectionResult. -type ndpDADEvent struct { - nicID tcpip.NICID - addr tcpip.Address - res stack.DADResult -} - -type ndpOffLinkRouteEvent struct { - nicID tcpip.NICID - subnet tcpip.Subnet - router tcpip.Address - prf header.NDPRoutePreference - // true if route was updated, false if invalidated. - updated bool -} - -type ndpPrefixEvent struct { - nicID tcpip.NICID - prefix tcpip.Subnet - // true if prefix was discovered, false if invalidated. - discovered bool -} - -type ndpAutoGenAddrEventType int - -const ( - newAddr ndpAutoGenAddrEventType = iota - deprecatedAddr - invalidatedAddr -) - -type ndpAutoGenAddrEvent struct { - nicID tcpip.NICID - addr tcpip.AddressWithPrefix - eventType ndpAutoGenAddrEventType -} - -func (e ndpAutoGenAddrEvent) String() string { - return fmt.Sprintf("%T{nicID=%d addr=%s eventType=%d}", e, e.nicID, e.addr, e.eventType) -} - -type ndpRDNSS struct { - addrs []tcpip.Address - lifetime time.Duration -} - -type ndpRDNSSEvent struct { - nicID tcpip.NICID - rdnss ndpRDNSS -} - -type ndpDNSSLEvent struct { - nicID tcpip.NICID - domainNames []string - lifetime time.Duration -} - -type ndpDHCPv6Event struct { - nicID tcpip.NICID - configuration ipv6.DHCPv6ConfigurationFromNDPRA -} - -var _ ipv6.NDPDispatcher = (*ndpDispatcher)(nil) - -// ndpDispatcher implements NDPDispatcher so tests can know when various NDP -// related events happen for test purposes. -type ndpDispatcher struct { - dadC chan ndpDADEvent - offLinkRouteC chan ndpOffLinkRouteEvent - prefixC chan ndpPrefixEvent - autoGenAddrC chan ndpAutoGenAddrEvent - rdnssC chan ndpRDNSSEvent - dnsslC chan ndpDNSSLEvent - routeTable []tcpip.Route - dhcpv6ConfigurationC chan ndpDHCPv6Event -} - -// Implements ipv6.NDPDispatcher.OnDuplicateAddressDetectionResult. -func (n *ndpDispatcher) OnDuplicateAddressDetectionResult(nicID tcpip.NICID, addr tcpip.Address, res stack.DADResult) { - if n.dadC != nil { - n.dadC <- ndpDADEvent{ - nicID, - addr, - res, - } - } -} - -// Implements ipv6.NDPDispatcher.OnOffLinkRouteUpdated. -func (n *ndpDispatcher) OnOffLinkRouteUpdated(nicID tcpip.NICID, subnet tcpip.Subnet, router tcpip.Address, prf header.NDPRoutePreference) { - if c := n.offLinkRouteC; c != nil { - c <- ndpOffLinkRouteEvent{ - nicID, - subnet, - router, - prf, - true, - } - } -} - -// Implements ipv6.NDPDispatcher.OnOffLinkRouteInvalidated. -func (n *ndpDispatcher) OnOffLinkRouteInvalidated(nicID tcpip.NICID, subnet tcpip.Subnet, router tcpip.Address) { - if c := n.offLinkRouteC; c != nil { - var prf header.NDPRoutePreference - c <- ndpOffLinkRouteEvent{ - nicID, - subnet, - router, - prf, - false, - } - } -} - -// Implements ipv6.NDPDispatcher.OnOnLinkPrefixDiscovered. -func (n *ndpDispatcher) OnOnLinkPrefixDiscovered(nicID tcpip.NICID, prefix tcpip.Subnet) { - if c := n.prefixC; c != nil { - c <- ndpPrefixEvent{ - nicID, - prefix, - true, - } - } -} - -// Implements ipv6.NDPDispatcher.OnOnLinkPrefixInvalidated. -func (n *ndpDispatcher) OnOnLinkPrefixInvalidated(nicID tcpip.NICID, prefix tcpip.Subnet) { - if c := n.prefixC; c != nil { - c <- ndpPrefixEvent{ - nicID, - prefix, - false, - } - } -} - -func (n *ndpDispatcher) OnAutoGenAddress(nicID tcpip.NICID, addr tcpip.AddressWithPrefix) { - if c := n.autoGenAddrC; c != nil { - c <- ndpAutoGenAddrEvent{ - nicID, - addr, - newAddr, - } - } -} - -func (n *ndpDispatcher) OnAutoGenAddressDeprecated(nicID tcpip.NICID, addr tcpip.AddressWithPrefix) { - if c := n.autoGenAddrC; c != nil { - c <- ndpAutoGenAddrEvent{ - nicID, - addr, - deprecatedAddr, - } - } -} - -func (n *ndpDispatcher) OnAutoGenAddressInvalidated(nicID tcpip.NICID, addr tcpip.AddressWithPrefix) { - if c := n.autoGenAddrC; c != nil { - c <- ndpAutoGenAddrEvent{ - nicID, - addr, - invalidatedAddr, - } - } -} - -// Implements ipv6.NDPDispatcher.OnRecursiveDNSServerOption. -func (n *ndpDispatcher) OnRecursiveDNSServerOption(nicID tcpip.NICID, addrs []tcpip.Address, lifetime time.Duration) { - if c := n.rdnssC; c != nil { - c <- ndpRDNSSEvent{ - nicID, - ndpRDNSS{ - addrs, - lifetime, - }, - } - } -} - -// Implements ipv6.NDPDispatcher.OnDNSSearchListOption. -func (n *ndpDispatcher) OnDNSSearchListOption(nicID tcpip.NICID, domainNames []string, lifetime time.Duration) { - if n.dnsslC != nil { - n.dnsslC <- ndpDNSSLEvent{ - nicID, - domainNames, - lifetime, - } - } -} - -// Implements ipv6.NDPDispatcher.OnDHCPv6Configuration. -func (n *ndpDispatcher) OnDHCPv6Configuration(nicID tcpip.NICID, configuration ipv6.DHCPv6ConfigurationFromNDPRA) { - if c := n.dhcpv6ConfigurationC; c != nil { - c <- ndpDHCPv6Event{ - nicID, - configuration, - } - } -} - -// channelLinkWithHeaderLength is a channel.Endpoint with a configurable -// header length. -type channelLinkWithHeaderLength struct { - *channel.Endpoint - headerLength uint16 -} - -func (l *channelLinkWithHeaderLength) MaxHeaderLength() uint16 { - return l.headerLength -} - -// Check e to make sure that the event is for addr on nic with ID 1, and the -// resolved flag set to resolved with the specified err. -func checkDADEvent(e ndpDADEvent, nicID tcpip.NICID, addr tcpip.Address, res stack.DADResult) string { - return cmp.Diff(ndpDADEvent{nicID: nicID, addr: addr, res: res}, e, cmp.AllowUnexported(e)) -} - -// TestDADDisabled tests that an address successfully resolves immediately -// when DAD is not enabled (the default for an empty stack.Options). -func TestDADDisabled(t *testing.T) { - const nicID = 1 - ndpDisp := ndpDispatcher{ - dadC: make(chan ndpDADEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPDisp: &ndpDisp, - })}, - }) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - - addrWithPrefix := tcpip.AddressWithPrefix{ - Address: addr1, - PrefixLen: defaultPrefixLen, - } - if err := s.AddAddressWithPrefix(nicID, header.IPv6ProtocolNumber, addrWithPrefix); err != nil { - t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addrWithPrefix, err) - } - - // Should get the address immediately since we should not have performed - // DAD on it. - select { - case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr1, &stack.DADSucceeded{}); diff != "" { - t.Errorf("DAD event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected DAD event") - } - if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addrWithPrefix); err != nil { - t.Fatal(err) - } - - // We should not have sent any NDP NS messages. - if got := s.Stats().ICMP.V6.PacketsSent.NeighborSolicit.Value(); got != 0 { - t.Fatalf("got NeighborSolicit = %d, want = 0", got) - } -} - -func TestDADResolveLoopback(t *testing.T) { - const nicID = 1 - ndpDisp := ndpDispatcher{ - dadC: make(chan ndpDADEvent, 1), - } - - dadConfigs := stack.DADConfigurations{ - RetransmitTimer: time.Second, - DupAddrDetectTransmits: 1, - } - clock := faketime.NewManualClock() - s := stack.New(stack.Options{ - Clock: clock, - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPDisp: &ndpDisp, - DADConfigs: dadConfigs, - })}, - }) - if err := s.CreateNIC(nicID, loopback.New()); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - - addrWithPrefix := tcpip.AddressWithPrefix{ - Address: addr1, - PrefixLen: defaultPrefixLen, - } - if err := s.AddAddressWithPrefix(nicID, header.IPv6ProtocolNumber, addrWithPrefix); err != nil { - t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addrWithPrefix, err) - } - - // Address should not be considered bound to the NIC yet (DAD ongoing). - if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { - t.Fatal(err) - } - - // DAD should not resolve after the normal resolution time since our DAD - // message was looped back - we should extend our DAD process. - dadResolutionTime := time.Duration(dadConfigs.DupAddrDetectTransmits) * dadConfigs.RetransmitTimer - clock.Advance(dadResolutionTime) - if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { - t.Error(err) - } - - // Make sure the address does not resolve before the extended resolution time - // has passed. - const delta = time.Nanosecond - // DAD will send extra NS probes if an NS message is looped back. - const extraTransmits = 3 - clock.Advance(dadResolutionTime*extraTransmits - delta) - if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { - t.Error(err) - } - - // DAD should now resolve. - clock.Advance(delta) - if diff := checkDADEvent(<-ndpDisp.dadC, nicID, addr1, &stack.DADSucceeded{}); diff != "" { - t.Errorf("DAD event mismatch (-want +got):\n%s", diff) - } -} - -// TestDADResolve tests that an address successfully resolves after performing -// DAD for various values of DupAddrDetectTransmits and RetransmitTimer. -// Included in the subtests is a test to make sure that an invalid -// RetransmitTimer (<1ms) values get fixed to the default RetransmitTimer of 1s. -// This tests also validates the NDP NS packet that is transmitted. -func TestDADResolve(t *testing.T) { - const nicID = 1 - - tests := []struct { - name string - linkHeaderLen uint16 - dupAddrDetectTransmits uint8 - retransTimer time.Duration - expectedRetransmitTimer time.Duration - }{ - { - name: "1:1s:1s", - dupAddrDetectTransmits: 1, - retransTimer: time.Second, - expectedRetransmitTimer: time.Second, - }, - { - name: "2:1s:1s", - linkHeaderLen: 1, - dupAddrDetectTransmits: 2, - retransTimer: time.Second, - expectedRetransmitTimer: time.Second, - }, - { - name: "1:2s:2s", - linkHeaderLen: 2, - dupAddrDetectTransmits: 1, - retransTimer: 2 * time.Second, - expectedRetransmitTimer: 2 * time.Second, - }, - // 0s is an invalid RetransmitTimer timer and will be fixed to - // the default RetransmitTimer value of 1s. - { - name: "1:0s:1s", - linkHeaderLen: 3, - dupAddrDetectTransmits: 1, - retransTimer: 0, - expectedRetransmitTimer: time.Second, - }, - } - - nonces := [][]byte{ - {1, 2, 3, 4, 5, 6}, - {7, 8, 9, 10, 11, 12}, - } - - var secureRNGBytes []byte - for _, n := range nonces { - secureRNGBytes = append(secureRNGBytes, n...) - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - ndpDisp := ndpDispatcher{ - dadC: make(chan ndpDADEvent, 1), - } - - e := channelLinkWithHeaderLength{ - Endpoint: channel.New(int(test.dupAddrDetectTransmits), 1280, linkAddr1), - headerLength: test.linkHeaderLen, - } - e.Endpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired - - var secureRNG bytes.Reader - secureRNG.Reset(secureRNGBytes) - - clock := faketime.NewManualClock() - s := stack.New(stack.Options{ - Clock: clock, - RandSource: rand.NewSource(time.Now().UnixNano()), - SecureRNG: &secureRNG, - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPDisp: &ndpDisp, - DADConfigs: stack.DADConfigurations{ - RetransmitTimer: test.retransTimer, - DupAddrDetectTransmits: test.dupAddrDetectTransmits, - }, - })}, - }) - if err := s.CreateNIC(nicID, &e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - - // We add a default route so the call to FindRoute below will succeed - // once we have an assigned address. - s.SetRouteTable([]tcpip.Route{{ - Destination: header.IPv6EmptySubnet, - Gateway: addr3, - NIC: nicID, - }}) - - addrWithPrefix := tcpip.AddressWithPrefix{ - Address: addr1, - PrefixLen: defaultPrefixLen, - } - if err := s.AddAddressWithPrefix(nicID, header.IPv6ProtocolNumber, addrWithPrefix); err != nil { - t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addrWithPrefix, err) - } - - // Make sure the address does not resolve before the resolution time has - // passed. - const delta = time.Nanosecond - clock.Advance(test.expectedRetransmitTimer*time.Duration(test.dupAddrDetectTransmits) - delta) - if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { - t.Error(err) - } - // Should not get a route even if we specify the local address as the - // tentative address. - { - r, err := s.FindRoute(nicID, "", addr2, header.IPv6ProtocolNumber, false) - if _, ok := err.(*tcpip.ErrNoRoute); !ok { - t.Errorf("got FindRoute(%d, '', %s, %d, false) = (%+v, %v), want = (_, %s)", nicID, addr2, header.IPv6ProtocolNumber, r, err, &tcpip.ErrNoRoute{}) - } - if r != nil { - r.Release() - } - } - { - r, err := s.FindRoute(nicID, addr1, addr2, header.IPv6ProtocolNumber, false) - if _, ok := err.(*tcpip.ErrNoRoute); !ok { - t.Errorf("got FindRoute(%d, %s, %s, %d, false) = (%+v, %v), want = (_, %s)", nicID, addr1, addr2, header.IPv6ProtocolNumber, r, err, &tcpip.ErrNoRoute{}) - } - if r != nil { - r.Release() - } - } - - if t.Failed() { - t.FailNow() - } - - // Wait for DAD to resolve. - clock.Advance(delta) - select { - case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr1, &stack.DADSucceeded{}); diff != "" { - t.Errorf("DAD event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatalf("expected DAD event for %s on NIC(%d)", addr1, nicID) - } - if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addrWithPrefix); err != nil { - t.Error(err) - } - // Should get a route using the address now that it is resolved. - { - r, err := s.FindRoute(nicID, "", addr2, header.IPv6ProtocolNumber, false) - if err != nil { - t.Errorf("got FindRoute(%d, '', %s, %d, false): %s", nicID, addr2, header.IPv6ProtocolNumber, err) - } else if r.LocalAddress() != addr1 { - t.Errorf("got r.LocalAddress() = %s, want = %s", r.LocalAddress(), addr1) - } - r.Release() - } - { - r, err := s.FindRoute(nicID, addr1, addr2, header.IPv6ProtocolNumber, false) - if err != nil { - t.Errorf("got FindRoute(%d, %s, %s, %d, false): %s", nicID, addr1, addr2, header.IPv6ProtocolNumber, err) - } else if r.LocalAddress() != addr1 { - t.Errorf("got r.LocalAddress() = %s, want = %s", r.LocalAddress(), addr1) - } - if r != nil { - r.Release() - } - } - - if t.Failed() { - t.FailNow() - } - - // Should not have sent any more NS messages. - if got := s.Stats().ICMP.V6.PacketsSent.NeighborSolicit.Value(); got != uint64(test.dupAddrDetectTransmits) { - t.Fatalf("got NeighborSolicit = %d, want = %d", got, test.dupAddrDetectTransmits) - } - - // Validate the sent Neighbor Solicitation messages. - for i := uint8(0); i < test.dupAddrDetectTransmits; i++ { - p, ok := e.Read() - if !ok { - t.Fatal("packet didn't arrive") - } - - // Make sure its an IPv6 packet. - if p.Proto != header.IPv6ProtocolNumber { - t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber) - } - - // Make sure the right remote link address is used. - snmc := header.SolicitedNodeAddr(addr1) - if want := header.EthernetAddressFromMulticastIPv6Address(snmc); p.Route.RemoteLinkAddress != want { - t.Errorf("got remote link address = %s, want = %s", p.Route.RemoteLinkAddress, want) - } - - // Check NDP NS packet. - // - // As per RFC 4861 section 4.3, a possible option is the Source Link - // Layer option, but this option MUST NOT be included when the source - // address of the packet is the unspecified address. - checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), - checker.SrcAddr(header.IPv6Any), - checker.DstAddr(snmc), - checker.TTL(header.NDPHopLimit), - checker.NDPNS( - checker.NDPNSTargetAddress(addr1), - checker.NDPNSOptions([]header.NDPOption{header.NDPNonceOption(nonces[i])}), - )) - - if l, want := p.Pkt.AvailableHeaderBytes(), int(test.linkHeaderLen); l != want { - t.Errorf("got p.Pkt.AvailableHeaderBytes() = %d; want = %d", l, want) - } - } - }) - } -} - -func rxNDPSolicit(e *channel.Endpoint, tgt tcpip.Address) { - hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6NeighborSolicitMinimumSize) - pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborSolicitMinimumSize)) - pkt.SetType(header.ICMPv6NeighborSolicit) - ns := header.NDPNeighborSolicit(pkt.MessageBody()) - ns.SetTargetAddress(tgt) - snmc := header.SolicitedNodeAddr(tgt) - pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ - Header: pkt, - Src: header.IPv6Any, - Dst: snmc, - })) - payloadLength := hdr.UsedLength() - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - TransportProtocol: icmp.ProtocolNumber6, - HopLimit: 255, - SrcAddr: header.IPv6Any, - DstAddr: snmc, - }) - e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{Data: hdr.View().ToVectorisedView()})) -} - -// TestDADFail tests to make sure that the DAD process fails if another node is -// detected to be performing DAD on the same address (receive an NS message from -// a node doing DAD for the same address), or if another node is detected to own -// the address already (receive an NA message for the tentative address). -func TestDADFail(t *testing.T) { - const nicID = 1 - - tests := []struct { - name string - rxPkt func(e *channel.Endpoint, tgt tcpip.Address) - getStat func(s tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter - expectedHolderLinkAddress tcpip.LinkAddress - }{ - { - name: "RxSolicit", - rxPkt: rxNDPSolicit, - getStat: func(s tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return s.NeighborSolicit - }, - expectedHolderLinkAddress: "", - }, - { - name: "RxAdvert", - rxPkt: func(e *channel.Endpoint, tgt tcpip.Address) { - naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize - hdr := buffer.NewPrependable(header.IPv6MinimumSize + naSize) - pkt := header.ICMPv6(hdr.Prepend(naSize)) - pkt.SetType(header.ICMPv6NeighborAdvert) - na := header.NDPNeighborAdvert(pkt.MessageBody()) - na.SetSolicitedFlag(true) - na.SetOverrideFlag(true) - na.SetTargetAddress(tgt) - na.Options().Serialize(header.NDPOptionsSerializer{ - header.NDPTargetLinkLayerAddressOption(linkAddr1), - }) - pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ - Header: pkt, - Src: tgt, - Dst: header.IPv6AllNodesMulticastAddress, - })) - payloadLength := hdr.UsedLength() - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - TransportProtocol: icmp.ProtocolNumber6, - HopLimit: 255, - SrcAddr: tgt, - DstAddr: header.IPv6AllNodesMulticastAddress, - }) - e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{Data: hdr.View().ToVectorisedView()})) - }, - getStat: func(s tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return s.NeighborAdvert - }, - expectedHolderLinkAddress: linkAddr1, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - ndpDisp := ndpDispatcher{ - dadC: make(chan ndpDADEvent, 1), - } - dadConfigs := stack.DefaultDADConfigurations() - dadConfigs.RetransmitTimer = time.Second * 2 - - e := channel.New(0, 1280, linkAddr1) - clock := faketime.NewManualClock() - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPDisp: &ndpDisp, - DADConfigs: dadConfigs, - })}, - Clock: clock, - }) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - - if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr1, err) - } - - // Address should not be considered bound to the NIC yet - // (DAD ongoing). - if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { - t.Fatal(err) - } - - // Receive a packet to simulate an address conflict. - test.rxPkt(e, addr1) - - stat := test.getStat(s.Stats().ICMP.V6.PacketsReceived) - if got := stat.Value(); got != 1 { - t.Fatalf("got stat = %d, want = 1", got) - } - - // Wait for DAD to fail and make sure the address did - // not get resolved. - clock.Advance(time.Duration(dadConfigs.DupAddrDetectTransmits) * dadConfigs.RetransmitTimer) - select { - case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr1, &stack.DADDupAddrDetected{HolderLinkAddress: test.expectedHolderLinkAddress}); diff != "" { - t.Errorf("DAD event mismatch (-want +got):\n%s", diff) - } - default: - // If we don't get a failure event after the - // expected resolution time + extra 1s buffer, - // something is wrong. - t.Fatal("timed out waiting for DAD failure") - } - if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { - t.Fatal(err) - } - - // Attempting to add the address again should not fail if the address's - // state was cleaned up when DAD failed. - if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr1, err) - } - }) - } -} - -func TestDADStop(t *testing.T) { - const nicID = 1 - - tests := []struct { - name string - stopFn func(t *testing.T, s *stack.Stack) - skipFinalAddrCheck bool - }{ - // Tests to make sure that DAD stops when an address is removed. - { - name: "Remove address", - stopFn: func(t *testing.T, s *stack.Stack) { - if err := s.RemoveAddress(nicID, addr1); err != nil { - t.Fatalf("RemoveAddress(%d, %s): %s", nicID, addr1, err) - } - }, - }, - - // Tests to make sure that DAD stops when the NIC is disabled. - { - name: "Disable NIC", - stopFn: func(t *testing.T, s *stack.Stack) { - if err := s.DisableNIC(nicID); err != nil { - t.Fatalf("DisableNIC(%d): %s", nicID, err) - } - }, - }, - - // Tests to make sure that DAD stops when the NIC is removed. - { - name: "Remove NIC", - stopFn: func(t *testing.T, s *stack.Stack) { - if err := s.RemoveNIC(nicID); err != nil { - t.Fatalf("RemoveNIC(%d): %s", nicID, err) - } - }, - // The NIC is removed so we can't check its addresses after calling - // stopFn. - skipFinalAddrCheck: true, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - ndpDisp := ndpDispatcher{ - dadC: make(chan ndpDADEvent, 1), - } - - dadConfigs := stack.DADConfigurations{ - RetransmitTimer: time.Second, - DupAddrDetectTransmits: 2, - } - - e := channel.New(0, 1280, linkAddr1) - clock := faketime.NewManualClock() - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPDisp: &ndpDisp, - DADConfigs: dadConfigs, - })}, - Clock: clock, - }) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID, err) - } - - if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil { - t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, header.IPv6ProtocolNumber, addr1, err) - } - - // Address should not be considered bound to the NIC yet (DAD ongoing). - if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { - t.Fatal(err) - } - - test.stopFn(t, s) - - // Wait for DAD to fail (since the address was removed during DAD). - clock.Advance(time.Duration(dadConfigs.DupAddrDetectTransmits) * dadConfigs.RetransmitTimer) - select { - case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr1, &stack.DADAborted{}); diff != "" { - t.Errorf("DAD event mismatch (-want +got):\n%s", diff) - } - default: - // If we don't get a failure event after the expected resolution - // time + extra 1s buffer, something is wrong. - t.Fatal("timed out waiting for DAD failure") - } - - if !test.skipFinalAddrCheck { - if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { - t.Fatal(err) - } - } - - // Should not have sent more than 1 NS message. - if got := s.Stats().ICMP.V6.PacketsSent.NeighborSolicit.Value(); got > 1 { - t.Errorf("got NeighborSolicit = %d, want <= 1", got) - } - }) - } -} - -// TestSetNDPConfigurations tests that we can update and use per-interface NDP -// configurations without affecting the default NDP configurations or other -// interfaces' configurations. -func TestSetNDPConfigurations(t *testing.T) { - const nicID1 = 1 - const nicID2 = 2 - const nicID3 = 3 - - tests := []struct { - name string - dupAddrDetectTransmits uint8 - retransmitTimer time.Duration - expectedRetransmitTimer time.Duration - }{ - { - "OK", - 1, - time.Second, - time.Second, - }, - { - "Invalid Retransmit Timer", - 1, - 0, - time.Second, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - ndpDisp := ndpDispatcher{ - dadC: make(chan ndpDADEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - clock := faketime.NewManualClock() - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPDisp: &ndpDisp, - })}, - Clock: clock, - }) - - expectDADEvent := func(nicID tcpip.NICID, addr tcpip.Address) { - select { - case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr, &stack.DADSucceeded{}); diff != "" { - t.Errorf("DAD event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatalf("expected DAD event for %s", addr) - } - } - - // This NIC(1)'s NDP configurations will be updated to - // be different from the default. - if err := s.CreateNIC(nicID1, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID1, err) - } - - // Created before updating NIC(1)'s NDP configurations - // but updating NIC(1)'s NDP configurations should not - // affect other existing NICs. - if err := s.CreateNIC(nicID2, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID2, err) - } - - // Update the configurations on NIC(1) to use DAD. - if ipv6Ep, err := s.GetNetworkEndpoint(nicID1, header.IPv6ProtocolNumber); err != nil { - t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID1, header.IPv6ProtocolNumber, err) - } else { - dad := ipv6Ep.(stack.DuplicateAddressDetector) - dad.SetDADConfigurations(stack.DADConfigurations{ - DupAddrDetectTransmits: test.dupAddrDetectTransmits, - RetransmitTimer: test.retransmitTimer, - }) - } - - // Created after updating NIC(1)'s NDP configurations - // but the stack's default NDP configurations should not - // have been updated. - if err := s.CreateNIC(nicID3, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID3, err) - } - - // Add addresses for each NIC. - addrWithPrefix1 := tcpip.AddressWithPrefix{Address: addr1, PrefixLen: defaultPrefixLen} - if err := s.AddAddressWithPrefix(nicID1, header.IPv6ProtocolNumber, addrWithPrefix1); err != nil { - t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID1, header.IPv6ProtocolNumber, addrWithPrefix1, err) - } - addrWithPrefix2 := tcpip.AddressWithPrefix{Address: addr2, PrefixLen: defaultPrefixLen} - if err := s.AddAddressWithPrefix(nicID2, header.IPv6ProtocolNumber, addrWithPrefix2); err != nil { - t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID2, header.IPv6ProtocolNumber, addrWithPrefix2, err) - } - expectDADEvent(nicID2, addr2) - addrWithPrefix3 := tcpip.AddressWithPrefix{Address: addr3, PrefixLen: defaultPrefixLen} - if err := s.AddAddressWithPrefix(nicID3, header.IPv6ProtocolNumber, addrWithPrefix3); err != nil { - t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID3, header.IPv6ProtocolNumber, addrWithPrefix3, err) - } - expectDADEvent(nicID3, addr3) - - // Address should not be considered bound to NIC(1) yet - // (DAD ongoing). - if err := checkGetMainNICAddress(s, nicID1, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { - t.Fatal(err) - } - - // Should get the address on NIC(2) and NIC(3) - // immediately since we should not have performed DAD on - // it as the stack was configured to not do DAD by - // default and we only updated the NDP configurations on - // NIC(1). - if err := checkGetMainNICAddress(s, nicID2, header.IPv6ProtocolNumber, addrWithPrefix2); err != nil { - t.Fatal(err) - } - if err := checkGetMainNICAddress(s, nicID3, header.IPv6ProtocolNumber, addrWithPrefix3); err != nil { - t.Fatal(err) - } - - // Sleep until right before resolution to make sure the address didn't - // resolve on NIC(1) yet. - const delta = 1 - clock.Advance(time.Duration(test.dupAddrDetectTransmits)*test.expectedRetransmitTimer - delta) - if err := checkGetMainNICAddress(s, nicID1, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { - t.Fatal(err) - } - - // Wait for DAD to resolve. - clock.Advance(delta) - select { - case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID1, addr1, &stack.DADSucceeded{}); diff != "" { - t.Errorf("DAD event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("timed out waiting for DAD resolution") - } - if err := checkGetMainNICAddress(s, nicID1, header.IPv6ProtocolNumber, addrWithPrefix1); err != nil { - t.Fatal(err) - } - }) - } -} - -// raBuf returns a valid NDP Router Advertisement with options, router -// preference and DHCPv6 configurations specified. -func raBuf(ip tcpip.Address, rl uint16, managedAddress, otherConfigurations bool, prf header.NDPRoutePreference, optSer header.NDPOptionsSerializer) *stack.PacketBuffer { - const flagsByte = 1 - const routerLifetimeOffset = 2 - - icmpSize := header.ICMPv6HeaderSize + header.NDPRAMinimumSize + optSer.Length() - hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize) - pkt := header.ICMPv6(hdr.Prepend(icmpSize)) - pkt.SetType(header.ICMPv6RouterAdvert) - pkt.SetCode(0) - raPayload := pkt.MessageBody() - ra := header.NDPRouterAdvert(raPayload) - // Populate the Router Lifetime. - binary.BigEndian.PutUint16(raPayload[routerLifetimeOffset:], rl) - // Populate the Managed Address flag field. - if managedAddress { - // The Managed Addresses flag field is the 7th bit of the flags byte. - raPayload[flagsByte] |= 1 << 7 - } - // Populate the Other Configurations flag field. - if otherConfigurations { - // The Other Configurations flag field is the 6th bit of the flags byte. - raPayload[flagsByte] |= 1 << 6 - } - // The Prf field is held in the flags byte. - raPayload[flagsByte] |= byte(prf) << 3 - opts := ra.Options() - opts.Serialize(optSer) - pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ - Header: pkt, - Src: ip, - Dst: header.IPv6AllNodesMulticastAddress, - })) - payloadLength := hdr.UsedLength() - iph := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - iph.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - TransportProtocol: icmp.ProtocolNumber6, - HopLimit: header.NDPHopLimit, - SrcAddr: ip, - DstAddr: header.IPv6AllNodesMulticastAddress, - }) - - return stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - }) -} - -// raBufWithOpts returns a valid NDP Router Advertisement with options. -// -// Note, raBufWithOpts does not populate any of the RA fields other than the -// Router Lifetime. -func raBufWithOpts(ip tcpip.Address, rl uint16, optSer header.NDPOptionsSerializer) *stack.PacketBuffer { - return raBuf(ip, rl, false /* managedAddress */, false /* otherConfigurations */, 0 /* prf */, optSer) -} - -// raBufWithDHCPv6 returns a valid NDP Router Advertisement with DHCPv6 related -// fields set. -// -// Note, raBufWithDHCPv6 does not populate any of the RA fields other than the -// DHCPv6 related ones. -func raBufWithDHCPv6(ip tcpip.Address, managedAddresses, otherConfigurations bool) *stack.PacketBuffer { - return raBuf(ip, 0, managedAddresses, otherConfigurations, 0 /* prf */, header.NDPOptionsSerializer{}) -} - -// raBuf returns a valid NDP Router Advertisement. -// -// Note, raBuf does not populate any of the RA fields other than the -// Router Lifetime. -func raBufSimple(ip tcpip.Address, rl uint16) *stack.PacketBuffer { - return raBufWithOpts(ip, rl, header.NDPOptionsSerializer{}) -} - -// raBufWithPrf returns a valid NDP Router Advertisement with a preference. -// -// Note, raBufWithPrf does not populate any of the RA fields other than the -// Router Lifetime and Default Router Preference fields. -func raBufWithPrf(ip tcpip.Address, rl uint16, prf header.NDPRoutePreference) *stack.PacketBuffer { - return raBuf(ip, rl, false /* managedAddress */, false /* otherConfigurations */, prf, header.NDPOptionsSerializer{}) -} - -// raBufWithPI returns a valid NDP Router Advertisement with a single Prefix -// Information option. -// -// Note, raBufWithPI does not populate any of the RA fields other than the -// Router Lifetime. -func raBufWithPI(ip tcpip.Address, rl uint16, prefix tcpip.AddressWithPrefix, onLink, auto bool, vl, pl uint32) *stack.PacketBuffer { - flags := uint8(0) - if onLink { - // The OnLink flag is the 7th bit in the flags byte. - flags |= 1 << 7 - } - if auto { - // The Address Auto-Configuration flag is the 6th bit in the - // flags byte. - flags |= 1 << 6 - } - - // A valid header.NDPPrefixInformation must be 30 bytes. - buf := [30]byte{} - // The first byte in a header.NDPPrefixInformation is the Prefix Length - // field. - buf[0] = uint8(prefix.PrefixLen) - // The 2nd byte within a header.NDPPrefixInformation is the Flags field. - buf[1] = flags - // The Valid Lifetime field starts after the 2nd byte within a - // header.NDPPrefixInformation. - binary.BigEndian.PutUint32(buf[2:], vl) - // The Preferred Lifetime field starts after the 6th byte within a - // header.NDPPrefixInformation. - binary.BigEndian.PutUint32(buf[6:], pl) - // The Prefix Address field starts after the 14th byte within a - // header.NDPPrefixInformation. - copy(buf[14:], prefix.Address) - return raBufWithOpts(ip, rl, header.NDPOptionsSerializer{ - header.NDPPrefixInformation(buf[:]), - }) -} - -// raBufWithRIO returns a valid NDP Router Advertisement with a single Route -// Information option. -// -// All fields in the RA will be zero except the RIO option. -func raBufWithRIO(t *testing.T, ip tcpip.Address, prefix tcpip.AddressWithPrefix, lifetimeSeconds uint32, prf header.NDPRoutePreference) *stack.PacketBuffer { - // buf will hold the route information option after the Type and Length - // fields. - // - // 2.3. Route Information Option - // - // 0 1 2 3 - // 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 - // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - // | Type | Length | Prefix Length |Resvd|Prf|Resvd| - // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - // | Route Lifetime | - // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - // | Prefix (Variable Length) | - // . . - // . . - // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - var buf [22]byte - buf[0] = uint8(prefix.PrefixLen) - buf[1] = byte(prf) << 3 - binary.BigEndian.PutUint32(buf[2:], lifetimeSeconds) - if n := copy(buf[6:], prefix.Address); n != len(prefix.Address) { - t.Fatalf("got copy(...) = %d, want = %d", n, len(prefix.Address)) - } - return raBufWithOpts(ip, 0 /* router lifetime */, header.NDPOptionsSerializer{ - header.NDPRouteInformation(buf[:]), - }) -} - -func TestDynamicConfigurationsDisabled(t *testing.T) { - const ( - nicID = 1 - maxRtrSolicitDelay = time.Second - ) - - prefix := tcpip.AddressWithPrefix{ - Address: testutil.MustParse6("102:304:506:708::"), - PrefixLen: 64, - } - - tests := []struct { - name string - config func(bool) ipv6.NDPConfigurations - ra *stack.PacketBuffer - }{ - { - name: "No Router Discovery", - config: func(enable bool) ipv6.NDPConfigurations { - return ipv6.NDPConfigurations{DiscoverDefaultRouters: enable} - }, - ra: raBufSimple(llAddr2, 1000), - }, - { - name: "No Prefix Discovery", - config: func(enable bool) ipv6.NDPConfigurations { - return ipv6.NDPConfigurations{DiscoverOnLinkPrefixes: enable} - }, - ra: raBufWithPI(llAddr2, 0, prefix, true, false, 10, 0), - }, - { - name: "No Autogenerate Addresses", - config: func(enable bool) ipv6.NDPConfigurations { - return ipv6.NDPConfigurations{AutoGenGlobalAddresses: enable} - }, - ra: raBufWithPI(llAddr2, 0, prefix, false, true, 10, 0), - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - // Being configured to discover routers/prefixes or auto-generate - // addresses means RAs must be handled, and router/prefix discovery or - // SLAAC must be enabled. - // - // This tests all possible combinations of the configurations where - // router/prefix discovery or SLAAC are disabled. - for i := 0; i < 7; i++ { - handle := ipv6.HandlingRAsDisabled - if i&1 != 0 { - handle = ipv6.HandlingRAsEnabledWhenForwardingDisabled - } - enable := i&2 != 0 - forwarding := i&4 == 0 - - t.Run(fmt.Sprintf("HandleRAs(%s), Forwarding(%t), Enabled(%t)", handle, forwarding, enable), func(t *testing.T) { - ndpDisp := ndpDispatcher{ - offLinkRouteC: make(chan ndpOffLinkRouteEvent, 1), - prefixC: make(chan ndpPrefixEvent, 1), - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), - } - ndpConfigs := test.config(enable) - ndpConfigs.HandleRAs = handle - ndpConfigs.MaxRtrSolicitations = 1 - ndpConfigs.RtrSolicitationInterval = maxRtrSolicitDelay - ndpConfigs.MaxRtrSolicitationDelay = maxRtrSolicitDelay - clock := faketime.NewManualClock() - s := stack.New(stack.Options{ - Clock: clock, - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ndpConfigs, - NDPDisp: &ndpDisp, - })}, - }) - if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, forwarding); err != nil { - t.Fatalf("SetForwardingDefaultAndAllNICs(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err) - } - - e := channel.New(1, 1280, linkAddr1) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID, err) - } - - handleRAsDisabled := handle == ipv6.HandlingRAsDisabled || forwarding - ep, err := s.GetNetworkEndpoint(nicID, ipv6.ProtocolNumber) - if err != nil { - t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, ipv6.ProtocolNumber, err) - } - stats := ep.Stats() - v6Stats, ok := stats.(*ipv6.Stats) - if !ok { - t.Fatalf("got v6Stats = %T, expected = %T", stats, v6Stats) - } - - // Make sure that when handling RAs are enabled, we solicit routers. - clock.Advance(maxRtrSolicitDelay) - if got, want := v6Stats.ICMP.PacketsSent.RouterSolicit.Value(), boolToUint64(!handleRAsDisabled); got != want { - t.Errorf("got v6Stats.ICMP.PacketsSent.RouterSolicit.Value() = %d, want = %d", got, want) - } - if handleRAsDisabled { - if p, ok := e.Read(); ok { - t.Errorf("unexpectedly got a packet = %#v", p) - } - } else if p, ok := e.Read(); !ok { - t.Error("expected router solicitation packet") - } else if p.Proto != header.IPv6ProtocolNumber { - t.Errorf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber) - } else { - if want := header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersLinkLocalMulticastAddress); p.Route.RemoteLinkAddress != want { - t.Errorf("got remote link address = %s, want = %s", p.Route.RemoteLinkAddress, want) - } - - checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), - checker.SrcAddr(header.IPv6Any), - checker.DstAddr(header.IPv6AllRoutersLinkLocalMulticastAddress), - checker.TTL(header.NDPHopLimit), - checker.NDPRS(checker.NDPRSOptions(nil)), - ) - } - - // Make sure we do not discover any routers or prefixes, or perform - // SLAAC on reception of an RA. - e.InjectInbound(header.IPv6ProtocolNumber, test.ra.Clone()) - // Make sure that the unhandled RA stat is only incremented when - // handling RAs is disabled. - if got, want := v6Stats.UnhandledRouterAdvertisements.Value(), boolToUint64(handleRAsDisabled); got != want { - t.Errorf("got v6Stats.UnhandledRouterAdvertisements.Value() = %d, want = %d", got, want) - } - select { - case e := <-ndpDisp.offLinkRouteC: - t.Errorf("unexpectedly updated an off-link route when configured not to: %#v", e) - default: - } - select { - case e := <-ndpDisp.prefixC: - t.Errorf("unexpectedly discovered a prefix when configured not to: %#v", e) - default: - } - select { - case e := <-ndpDisp.autoGenAddrC: - t.Errorf("unexpectedly auto-generated an address when configured not to: %#v", e) - default: - } - }) - } - }) - } -} - -func boolToUint64(v bool) uint64 { - if v { - return 1 - } - return 0 -} - -func checkOffLinkRouteEvent(e ndpOffLinkRouteEvent, nicID tcpip.NICID, subnet tcpip.Subnet, router tcpip.Address, prf header.NDPRoutePreference, updated bool) string { - return cmp.Diff(ndpOffLinkRouteEvent{nicID: nicID, subnet: subnet, router: router, prf: prf, updated: updated}, e, cmp.AllowUnexported(e)) -} - -func testWithRAs(t *testing.T, f func(*testing.T, ipv6.HandleRAsConfiguration, bool)) { - tests := [...]struct { - name string - handleRAs ipv6.HandleRAsConfiguration - forwarding bool - }{ - { - name: "Handle RAs when forwarding disabled", - handleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, - forwarding: false, - }, - { - name: "Always Handle RAs with forwarding disabled", - handleRAs: ipv6.HandlingRAsAlwaysEnabled, - forwarding: false, - }, - { - name: "Always Handle RAs with forwarding enabled", - handleRAs: ipv6.HandlingRAsAlwaysEnabled, - forwarding: true, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - f(t, test.handleRAs, test.forwarding) - }) - } -} - -func TestOffLinkRouteDiscovery(t *testing.T) { - const nicID = 1 - - moreSpecificPrefix := tcpip.AddressWithPrefix{Address: testutil.MustParse6("a00::"), PrefixLen: 16} - tests := []struct { - name string - - discoverDefaultRouters bool - discoverMoreSpecificRoutes bool - - dest tcpip.Subnet - ra func(*testing.T, tcpip.Address, uint16, header.NDPRoutePreference) *stack.PacketBuffer - }{ - { - name: "Default router discovery", - discoverDefaultRouters: true, - discoverMoreSpecificRoutes: false, - dest: header.IPv6EmptySubnet, - ra: func(_ *testing.T, router tcpip.Address, lifetimeSeconds uint16, prf header.NDPRoutePreference) *stack.PacketBuffer { - return raBufWithPrf(router, lifetimeSeconds, prf) - }, - }, - { - name: "More-specific route discovery", - discoverDefaultRouters: false, - discoverMoreSpecificRoutes: true, - dest: moreSpecificPrefix.Subnet(), - ra: func(t *testing.T, router tcpip.Address, lifetimeSeconds uint16, prf header.NDPRoutePreference) *stack.PacketBuffer { - return raBufWithRIO(t, router, moreSpecificPrefix, uint32(lifetimeSeconds), prf) - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - testWithRAs(t, func(t *testing.T, handleRAs ipv6.HandleRAsConfiguration, forwarding bool) { - ndpDisp := ndpDispatcher{ - offLinkRouteC: make(chan ndpOffLinkRouteEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - clock := faketime.NewManualClock() - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: handleRAs, - DiscoverDefaultRouters: test.discoverDefaultRouters, - DiscoverMoreSpecificRoutes: test.discoverMoreSpecificRoutes, - }, - NDPDisp: &ndpDisp, - })}, - Clock: clock, - }) - - expectOffLinkRouteEvent := func(addr tcpip.Address, prf header.NDPRoutePreference, updated bool) { - t.Helper() - - select { - case e := <-ndpDisp.offLinkRouteC: - if diff := checkOffLinkRouteEvent(e, nicID, test.dest, addr, prf, updated); diff != "" { - t.Errorf("off-link route event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected router discovery event") - } - } - - expectAsyncOffLinkRouteInvalidationEvent := func(addr tcpip.Address, timeout time.Duration) { - t.Helper() - - clock.Advance(timeout) - select { - case e := <-ndpDisp.offLinkRouteC: - var prf header.NDPRoutePreference - if diff := checkOffLinkRouteEvent(e, nicID, test.dest, addr, prf, false); diff != "" { - t.Errorf("off-link route event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("timed out waiting for router discovery event") - } - } - - if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, forwarding); err != nil { - t.Fatalf("SetForwardingDefaultAndAllNICs(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err) - } - - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID, err) - } - - // Rx an RA from lladdr2 with zero lifetime. It should not be - // remembered. - e.InjectInbound(header.IPv6ProtocolNumber, test.ra(t, llAddr2, 0, header.MediumRoutePreference)) - select { - case <-ndpDisp.offLinkRouteC: - t.Fatal("unexpectedly updated an off-link route with 0 lifetime") - default: - } - - // Discover an off-link route through llAddr2. - e.InjectInbound(header.IPv6ProtocolNumber, test.ra(t, llAddr2, 1000, header.ReservedRoutePreference)) - if test.discoverMoreSpecificRoutes { - // The reserved value is considered invalid with more-specific route - // discovery so we inject the same packet but with the default - // (medium) preference value. - select { - case <-ndpDisp.offLinkRouteC: - t.Fatal("unexpectedly updated an off-link route with a reserved preference value") - default: - } - e.InjectInbound(header.IPv6ProtocolNumber, test.ra(t, llAddr2, 1000, header.MediumRoutePreference)) - } - expectOffLinkRouteEvent(llAddr2, header.MediumRoutePreference, true) - - // Rx an RA from another router (lladdr3) with non-zero lifetime and - // non-default preference value. - const l3LifetimeSeconds = 6 - e.InjectInbound(header.IPv6ProtocolNumber, test.ra(t, llAddr3, l3LifetimeSeconds, header.HighRoutePreference)) - expectOffLinkRouteEvent(llAddr3, header.HighRoutePreference, true) - - // Rx an RA from lladdr2 with lesser lifetime and default (medium) - // preference value. - const l2LifetimeSeconds = 2 - e.InjectInbound(header.IPv6ProtocolNumber, test.ra(t, llAddr2, l2LifetimeSeconds, header.MediumRoutePreference)) - select { - case <-ndpDisp.offLinkRouteC: - t.Fatal("should not receive a off-link route event when updating lifetimes for known routers") - default: - } - - // Rx an RA from lladdr2 with a different preference. - e.InjectInbound(header.IPv6ProtocolNumber, test.ra(t, llAddr2, l2LifetimeSeconds, header.LowRoutePreference)) - expectOffLinkRouteEvent(llAddr2, header.LowRoutePreference, true) - - // Wait for lladdr2's router invalidation job to execute. The lifetime - // of the router should have been updated to the most recent (smaller) - // lifetime. - // - // Wait for the normal lifetime plus an extra bit for the - // router to get invalidated. If we don't get an invalidation - // event after this time, then something is wrong. - expectAsyncOffLinkRouteInvalidationEvent(llAddr2, l2LifetimeSeconds*time.Second) - - // Rx an RA from lladdr2 with huge lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, test.ra(t, llAddr2, 1000, header.MediumRoutePreference)) - expectOffLinkRouteEvent(llAddr2, header.MediumRoutePreference, true) - - // Rx an RA from lladdr2 with zero lifetime. It should be invalidated. - e.InjectInbound(header.IPv6ProtocolNumber, test.ra(t, llAddr2, 0, header.MediumRoutePreference)) - expectOffLinkRouteEvent(llAddr2, header.MediumRoutePreference, false) - - // Wait for lladdr3's router invalidation job to execute. The lifetime - // of the router should have been updated to the most recent (smaller) - // lifetime. - // - // Wait for the normal lifetime plus an extra bit for the - // router to get invalidated. If we don't get an invalidation - // event after this time, then something is wrong. - expectAsyncOffLinkRouteInvalidationEvent(llAddr3, l3LifetimeSeconds*time.Second) - }) - }) - } -} - -// TestRouterDiscoveryMaxRouters tests that only -// ipv6.MaxDiscoveredOffLinkRoutes discovered routers are remembered. -func TestRouterDiscoveryMaxRouters(t *testing.T) { - const nicID = 1 - - ndpDisp := ndpDispatcher{ - offLinkRouteC: make(chan ndpOffLinkRouteEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, - DiscoverDefaultRouters: true, - }, - NDPDisp: &ndpDisp, - })}, - }) - - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID, err) - } - - // Receive an RA from 2 more than the max number of discovered routers. - for i := 1; i <= ipv6.MaxDiscoveredOffLinkRoutes+2; i++ { - linkAddr := []byte{2, 2, 3, 4, 5, 0} - linkAddr[5] = byte(i) - llAddr := header.LinkLocalAddr(tcpip.LinkAddress(linkAddr)) - - e.InjectInbound(header.IPv6ProtocolNumber, raBufSimple(llAddr, 5)) - - if i <= ipv6.MaxDiscoveredOffLinkRoutes { - select { - case e := <-ndpDisp.offLinkRouteC: - if diff := checkOffLinkRouteEvent(e, nicID, header.IPv6EmptySubnet, llAddr, header.MediumRoutePreference, true); diff != "" { - t.Errorf("off-link route event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected router discovery event") - } - - } else { - select { - case <-ndpDisp.offLinkRouteC: - t.Fatal("should not have discovered a new router after we already discovered the max number of routers") - default: - } - } - } -} - -// Check e to make sure that the event is for prefix on nic with ID 1, and the -// discovered flag set to discovered. -func checkPrefixEvent(e ndpPrefixEvent, prefix tcpip.Subnet, discovered bool) string { - return cmp.Diff(ndpPrefixEvent{nicID: 1, prefix: prefix, discovered: discovered}, e, cmp.AllowUnexported(e)) -} - -func TestPrefixDiscovery(t *testing.T) { - prefix1, subnet1, _ := prefixSubnetAddr(0, "") - prefix2, subnet2, _ := prefixSubnetAddr(1, "") - prefix3, subnet3, _ := prefixSubnetAddr(2, "") - - testWithRAs(t, func(t *testing.T, handleRAs ipv6.HandleRAsConfiguration, forwarding bool) { - ndpDisp := ndpDispatcher{ - prefixC: make(chan ndpPrefixEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - clock := faketime.NewManualClock() - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: handleRAs, - DiscoverOnLinkPrefixes: true, - }, - NDPDisp: &ndpDisp, - })}, - Clock: clock, - }) - - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - - expectPrefixEvent := func(prefix tcpip.Subnet, discovered bool) { - t.Helper() - - select { - case e := <-ndpDisp.prefixC: - if diff := checkPrefixEvent(e, prefix, discovered); diff != "" { - t.Errorf("prefix event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected prefix discovery event") - } - } - - if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, forwarding); err != nil { - t.Fatalf("SetForwardingDefaultAndAllNICs(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err) - } - - // Receive an RA with prefix1 in an NDP Prefix Information option (PI) - // with zero valid lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 0, 0)) - select { - case <-ndpDisp.prefixC: - t.Fatal("unexpectedly discovered a prefix with 0 lifetime") - default: - } - - // Receive an RA with prefix1 in an NDP Prefix Information option (PI) - // with non-zero lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 100, 0)) - expectPrefixEvent(subnet1, true) - - // Receive an RA with prefix2 in a PI. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, false, 100, 0)) - expectPrefixEvent(subnet2, true) - - // Receive an RA with prefix3 in a PI. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix3, true, false, 100, 0)) - expectPrefixEvent(subnet3, true) - - // Receive an RA with prefix1 in a PI with lifetime = 0. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 0, 0)) - expectPrefixEvent(subnet1, false) - - // Receive an RA with prefix2 in a PI with lesser lifetime. - lifetime := uint32(2) - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, false, lifetime, 0)) - select { - case <-ndpDisp.prefixC: - t.Fatal("unexpectedly received prefix event when updating lifetime") - default: - } - - // Wait for prefix2's most recent invalidation job plus some buffer to - // expire. - clock.Advance(time.Duration(lifetime) * time.Second) - select { - case e := <-ndpDisp.prefixC: - if diff := checkPrefixEvent(e, subnet2, false); diff != "" { - t.Errorf("prefix event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("timed out waiting for prefix discovery event") - } - - // Receive RA to invalidate prefix3. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix3, true, false, 0, 0)) - expectPrefixEvent(subnet3, false) - }) -} - -func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { - prefix := tcpip.AddressWithPrefix{ - Address: testutil.MustParse6("102:304:506:708::"), - PrefixLen: 64, - } - subnet := prefix.Subnet() - - ndpDisp := ndpDispatcher{ - prefixC: make(chan ndpPrefixEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - clock := faketime.NewManualClock() - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, - DiscoverOnLinkPrefixes: true, - }, - NDPDisp: &ndpDisp, - })}, - Clock: clock, - }) - - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - - expectPrefixEvent := func(prefix tcpip.Subnet, discovered bool) { - t.Helper() - - select { - case e := <-ndpDisp.prefixC: - if diff := checkPrefixEvent(e, prefix, discovered); diff != "" { - t.Errorf("prefix event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected prefix discovery event") - } - } - - // Receive an RA with prefix in an NDP Prefix Information option (PI) - // with infinite valid lifetime which should not get invalidated. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, infiniteLifetimeSeconds, 0)) - expectPrefixEvent(subnet, true) - clock.Advance(header.NDPInfiniteLifetime) - select { - case <-ndpDisp.prefixC: - t.Fatal("unexpectedly invalidated a prefix with infinite lifetime") - default: - } - - // Receive an RA with finite lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, infiniteLifetimeSeconds-1, 0)) - clock.Advance(header.NDPInfiniteLifetime - time.Second) - select { - case e := <-ndpDisp.prefixC: - if diff := checkPrefixEvent(e, subnet, false); diff != "" { - t.Errorf("prefix event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("timed out waiting for prefix discovery event") - } - - // Receive an RA with finite lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, infiniteLifetimeSeconds-1, 0)) - expectPrefixEvent(subnet, true) - - // Receive an RA with prefix with an infinite lifetime. - // The prefix should not be invalidated. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, infiniteLifetimeSeconds, 0)) - clock.Advance(header.NDPInfiniteLifetime) - select { - case <-ndpDisp.prefixC: - t.Fatal("unexpectedly invalidated a prefix with infinite lifetime") - default: - } - - // Receive an RA with 0 lifetime. - // The prefix should get invalidated. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, 0, 0)) - expectPrefixEvent(subnet, false) -} - -// TestPrefixDiscoveryMaxRouters tests that only -// ipv6.MaxDiscoveredOnLinkPrefixes discovered on-link prefixes are remembered. -func TestPrefixDiscoveryMaxOnLinkPrefixes(t *testing.T) { - ndpDisp := ndpDispatcher{ - prefixC: make(chan ndpPrefixEvent, ipv6.MaxDiscoveredOnLinkPrefixes+3), - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, - DiscoverDefaultRouters: false, - DiscoverOnLinkPrefixes: true, - }, - NDPDisp: &ndpDisp, - })}, - }) - - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - - optSer := make(header.NDPOptionsSerializer, ipv6.MaxDiscoveredOnLinkPrefixes+2) - prefixes := [ipv6.MaxDiscoveredOnLinkPrefixes + 2]tcpip.Subnet{} - - // Receive an RA with 2 more than the max number of discovered on-link - // prefixes. - for i := 0; i < ipv6.MaxDiscoveredOnLinkPrefixes+2; i++ { - prefixAddr := [16]byte{1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 0} - prefixAddr[7] = byte(i) - prefix := tcpip.AddressWithPrefix{ - Address: tcpip.Address(prefixAddr[:]), - PrefixLen: 64, - } - prefixes[i] = prefix.Subnet() - buf := [30]byte{} - buf[0] = uint8(prefix.PrefixLen) - buf[1] = 128 - binary.BigEndian.PutUint32(buf[2:], 10) - copy(buf[14:], prefix.Address) - - optSer[i] = header.NDPPrefixInformation(buf[:]) - } - - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithOpts(llAddr1, 0, optSer)) - for i := 0; i < ipv6.MaxDiscoveredOnLinkPrefixes+2; i++ { - if i < ipv6.MaxDiscoveredOnLinkPrefixes { - select { - case e := <-ndpDisp.prefixC: - if diff := checkPrefixEvent(e, prefixes[i], true); diff != "" { - t.Errorf("prefix event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected prefix discovery event") - } - } else { - select { - case <-ndpDisp.prefixC: - t.Fatal("should not have discovered a new prefix after we already discovered the max number of prefixes") - default: - } - } - } -} - -// Checks to see if list contains an IPv6 address, item. -func containsV6Addr(list []tcpip.ProtocolAddress, item tcpip.AddressWithPrefix) bool { - protocolAddress := tcpip.ProtocolAddress{ - Protocol: header.IPv6ProtocolNumber, - AddressWithPrefix: item, - } - - return containsAddr(list, protocolAddress) -} - -// Check e to make sure that the event is for addr on nic with ID 1, and the -// event type is set to eventType. -func checkAutoGenAddrEvent(e ndpAutoGenAddrEvent, addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) string { - return cmp.Diff(ndpAutoGenAddrEvent{nicID: 1, addr: addr, eventType: eventType}, e, cmp.AllowUnexported(e)) -} - -const minVLSeconds = uint32(ipv6.MinPrefixInformationValidLifetimeForUpdate / time.Second) -const infiniteLifetimeSeconds = uint32(header.NDPInfiniteLifetime / time.Second) - -// TestAutoGenAddr tests that an address is properly generated and invalidated -// when configured to do so. -func TestAutoGenAddr(t *testing.T) { - prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) - prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) - - testWithRAs(t, func(t *testing.T, handleRAs ipv6.HandleRAsConfiguration, forwarding bool) { - ndpDisp := ndpDispatcher{ - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - clock := faketime.NewManualClock() - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: handleRAs, - AutoGenGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, - })}, - Clock: clock, - }) - - if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, forwarding); err != nil { - t.Fatalf("SetForwardingDefaultAndAllNICs(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err) - } - - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - - expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { - t.Helper() - - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected addr auto gen event") - } - } - - // Receive an RA with prefix1 in an NDP Prefix Information option (PI) - // with zero valid lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 0, 0)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly auto-generated an address with 0 lifetime") - default: - } - - // Receive an RA with prefix1 in an NDP Prefix Information option (PI) - // with non-zero lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0)) - expectAutoGenAddrEvent(addr1, newAddr) - if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) { - t.Fatalf("Should have %s in the list of addresses", addr1) - } - - // Receive an RA with prefix2 in an NDP Prefix Information option (PI) - // with preferred lifetime > valid lifetime - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 5, 6)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly auto-generated an address with preferred lifetime > valid lifetime") - default: - } - - // Receive an RA with prefix2 in a PI with a valid lifetime that exceeds - // the minimum. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, minVLSeconds+1, 0)) - expectAutoGenAddrEvent(addr2, newAddr) - if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) { - t.Fatalf("Should have %s in the list of addresses", addr1) - } - if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr2) { - t.Fatalf("Should have %s in the list of addresses", addr2) - } - - // Refresh valid lifetime for addr of prefix1. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly auto-generated an address when we already have an address for a prefix") - default: - } - - // Wait for addr of prefix1 to be invalidated. - clock.Advance(ipv6.MinPrefixInformationValidLifetimeForUpdate) - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("timed out waiting for addr auto gen event") - } - if containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) { - t.Fatalf("Should not have %s in the list of addresses", addr1) - } - if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr2) { - t.Fatalf("Should have %s in the list of addresses", addr2) - } - }) -} - -func addressCheck(addrs []tcpip.ProtocolAddress, containList, notContainList []tcpip.AddressWithPrefix) string { - ret := "" - for _, c := range containList { - if !containsV6Addr(addrs, c) { - ret += fmt.Sprintf("should have %s in the list of addresses\n", c) - } - } - for _, c := range notContainList { - if containsV6Addr(addrs, c) { - ret += fmt.Sprintf("should not have %s in the list of addresses\n", c) - } - } - return ret -} - -// TestAutoGenTempAddr tests that temporary SLAAC addresses are generated when -// configured to do so as part of IPv6 Privacy Extensions. -func TestAutoGenTempAddr(t *testing.T) { - const nicID = 1 - - prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) - prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) - - tests := []struct { - name string - dupAddrTransmits uint8 - retransmitTimer time.Duration - }{ - { - name: "DAD disabled", - }, - { - name: "DAD enabled", - dupAddrTransmits: 1, - retransmitTimer: time.Second, - }, - } - - for i, test := range tests { - t.Run(test.name, func(t *testing.T) { - seed := []byte{uint8(i)} - var tempIIDHistory [header.IIDSize]byte - header.InitialTempIID(tempIIDHistory[:], seed, nicID) - newTempAddr := func(stableAddr tcpip.Address) tcpip.AddressWithPrefix { - return header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], stableAddr) - } - - ndpDisp := ndpDispatcher{ - dadC: make(chan ndpDADEvent, 2), - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2), - } - e := channel.New(0, 1280, linkAddr1) - clock := faketime.NewManualClock() - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - DADConfigs: stack.DADConfigurations{ - DupAddrDetectTransmits: test.dupAddrTransmits, - RetransmitTimer: test.retransmitTimer, - }, - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, - AutoGenGlobalAddresses: true, - AutoGenTempGlobalAddresses: true, - MaxTempAddrValidLifetime: 2 * ipv6.MinPrefixInformationValidLifetimeForUpdate, - MaxTempAddrPreferredLifetime: 2 * ipv6.MinPrefixInformationValidLifetimeForUpdate, - }, - NDPDisp: &ndpDisp, - TempIIDSeed: seed, - })}, - Clock: clock, - }) - - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - - expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { - t.Helper() - - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected addr auto gen event") - } - } - - expectAutoGenAddrEventAsync := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { - t.Helper() - - clock.RunImmediatelyScheduledJobs() - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("timed out waiting for addr auto gen event") - } - } - - expectDADEventAsync := func(addr tcpip.Address) { - t.Helper() - - clock.Advance(time.Duration(test.dupAddrTransmits) * test.retransmitTimer) - select { - case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr, &stack.DADSucceeded{}); diff != "" { - t.Errorf("DAD event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("timed out waiting for DAD event") - } - } - - // Receive an RA with prefix1 in an NDP Prefix Information option (PI) - // with zero valid lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 0, 0)) - select { - case e := <-ndpDisp.autoGenAddrC: - t.Fatalf("unexpectedly auto-generated an address with 0 lifetime; event = %+v", e) - default: - } - - // Receive an RA with prefix1 in an NDP Prefix Information option (PI) - // with non-zero valid lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0)) - expectAutoGenAddrEvent(addr1, newAddr) - expectDADEventAsync(addr1.Address) - select { - case e := <-ndpDisp.autoGenAddrC: - t.Fatalf("unexpectedly got an auto gen addr event = %+v", e) - default: - } - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1}, nil); mismatch != "" { - t.Fatal(mismatch) - } - - // Receive an RA with prefix1 in an NDP Prefix Information option (PI) - // with non-zero valid & preferred lifetimes. - tempAddr1 := newTempAddr(addr1.Address) - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100)) - expectAutoGenAddrEvent(tempAddr1, newAddr) - expectDADEventAsync(tempAddr1.Address) - if mismatch := addressCheck(s.NICInfo()[1].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1}, nil); mismatch != "" { - t.Fatal(mismatch) - } - - // Receive an RA with prefix2 in an NDP Prefix Information option (PI) - // with preferred lifetime > valid lifetime - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 5, 6)) - select { - case e := <-ndpDisp.autoGenAddrC: - t.Fatalf("unexpectedly auto-generated an address with preferred lifetime > valid lifetime; event = %+v", e) - default: - } - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1}, nil); mismatch != "" { - t.Fatal(mismatch) - } - - // Receive an RA with prefix2 in a PI with a valid lifetime that exceeds - // the minimum and won't be reached in this test. - tempAddr2 := newTempAddr(addr2.Address) - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 2*minVLSeconds, 2*minVLSeconds)) - expectAutoGenAddrEvent(addr2, newAddr) - expectDADEventAsync(addr2.Address) - expectAutoGenAddrEventAsync(tempAddr2, newAddr) - expectDADEventAsync(tempAddr2.Address) - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" { - t.Fatal(mismatch) - } - - // Deprecate prefix1. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0)) - expectAutoGenAddrEvent(addr1, deprecatedAddr) - expectAutoGenAddrEvent(tempAddr1, deprecatedAddr) - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" { - t.Fatal(mismatch) - } - - // Refresh lifetimes for prefix1. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100)) - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" { - t.Fatal(mismatch) - } - - // Reduce valid lifetime and deprecate addresses of prefix1. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, minVLSeconds, 0)) - expectAutoGenAddrEvent(addr1, deprecatedAddr) - expectAutoGenAddrEvent(tempAddr1, deprecatedAddr) - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" { - t.Fatal(mismatch) - } - - // Wait for addrs of prefix1 to be invalidated. They should be - // invalidated at the same time. - clock.Advance(ipv6.MinPrefixInformationValidLifetimeForUpdate) - select { - case e := <-ndpDisp.autoGenAddrC: - var nextAddr tcpip.AddressWithPrefix - if e.addr == addr1 { - if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - nextAddr = tempAddr1 - } else { - if diff := checkAutoGenAddrEvent(e, tempAddr1, invalidatedAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - nextAddr = addr1 - } - - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, nextAddr, invalidatedAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("timed out waiting for addr auto gen event") - } - default: - t.Fatal("timed out waiting for addr auto gen event") - } - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr2, tempAddr2}, []tcpip.AddressWithPrefix{addr1, tempAddr1}); mismatch != "" { - t.Fatal(mismatch) - } - - // Receive an RA with prefix2 in a PI w/ 0 lifetimes. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 0, 0)) - expectAutoGenAddrEvent(addr2, deprecatedAddr) - expectAutoGenAddrEvent(tempAddr2, deprecatedAddr) - select { - case e := <-ndpDisp.autoGenAddrC: - t.Errorf("got unexpected auto gen addr event = %+v", e) - default: - } - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr2, tempAddr2}, []tcpip.AddressWithPrefix{addr1, tempAddr1}); mismatch != "" { - t.Fatal(mismatch) - } - }) - } -} - -// TestNoAutoGenTempAddrForLinkLocal test that temporary SLAAC addresses are not -// generated for auto generated link-local addresses. -func TestNoAutoGenTempAddrForLinkLocal(t *testing.T) { - const nicID = 1 - - tests := []struct { - name string - dupAddrTransmits uint8 - retransmitTimer time.Duration - }{ - { - name: "DAD disabled", - }, - { - name: "DAD enabled", - dupAddrTransmits: 1, - retransmitTimer: time.Second, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - ndpDisp := ndpDispatcher{ - dadC: make(chan ndpDADEvent, 1), - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - clock := faketime.NewManualClock() - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - AutoGenTempGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, - AutoGenLinkLocal: true, - })}, - Clock: clock, - }) - - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - - // The stable link-local address should auto-generate and resolve DAD. - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, tcpip.AddressWithPrefix{Address: llAddr1, PrefixLen: header.IIDOffsetInIPv6Address * 8}, newAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected addr auto gen event") - } - clock.Advance(time.Duration(test.dupAddrTransmits) * test.retransmitTimer) - select { - case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, llAddr1, &stack.DADSucceeded{}); diff != "" { - t.Errorf("DAD event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("timed out waiting for DAD event") - } - - // No new addresses should be generated. - select { - case e := <-ndpDisp.autoGenAddrC: - t.Errorf("got unxpected auto gen addr event = %+v", e) - default: - } - }) - } -} - -// TestNoAutoGenTempAddrWithoutStableAddr tests that a temporary SLAAC address -// will not be generated until after DAD completes, even if a new Router -// Advertisement is received to refresh lifetimes. -func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) { - const ( - nicID = 1 - dadTransmits = 1 - retransmitTimer = 2 * time.Second - ) - - prefix, _, addr := prefixSubnetAddr(0, linkAddr1) - var tempIIDHistory [header.IIDSize]byte - header.InitialTempIID(tempIIDHistory[:], nil, nicID) - tempAddr := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address) - - ndpDisp := ndpDispatcher{ - dadC: make(chan ndpDADEvent, 1), - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - clock := faketime.NewManualClock() - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - DADConfigs: stack.DADConfigurations{ - DupAddrDetectTransmits: dadTransmits, - RetransmitTimer: retransmitTimer, - }, - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, - AutoGenGlobalAddresses: true, - AutoGenTempGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, - })}, - Clock: clock, - }) - - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - - // Receive an RA to trigger SLAAC for prefix. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100)) - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, newAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected addr auto gen event") - } - - // DAD on the stable address for prefix has not yet completed. Receiving a new - // RA that would refresh lifetimes should not generate a temporary SLAAC - // address for the prefix. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100)) - select { - case e := <-ndpDisp.autoGenAddrC: - t.Fatalf("unexpected auto gen addr event = %+v", e) - default: - } - - // Wait for DAD to complete for the stable address then expect the temporary - // address to be generated. - clock.Advance(dadTransmits * retransmitTimer) - select { - case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr.Address, &stack.DADSucceeded{}); diff != "" { - t.Errorf("DAD event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("timed out waiting for DAD event") - } - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, tempAddr, newAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("timed out waiting for addr auto gen event") - } -} - -// TestAutoGenTempAddrRegen tests that temporary SLAAC addresses are -// regenerated. -func TestAutoGenTempAddrRegen(t *testing.T) { - const ( - nicID = 1 - regenAdv = 2 * time.Second - - numTempAddrs = 3 - maxTempAddrValidLifetime = numTempAddrs * ipv6.MinPrefixInformationValidLifetimeForUpdate - ) - - prefix, _, addr := prefixSubnetAddr(0, linkAddr1) - var tempIIDHistory [header.IIDSize]byte - header.InitialTempIID(tempIIDHistory[:], nil, nicID) - var tempAddrs [numTempAddrs]tcpip.AddressWithPrefix - for i := 0; i < len(tempAddrs); i++ { - tempAddrs[i] = header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address) - } - - ndpDisp := ndpDispatcher{ - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2), - } - e := channel.New(0, 1280, linkAddr1) - ndpConfigs := ipv6.NDPConfigurations{ - HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, - AutoGenGlobalAddresses: true, - AutoGenTempGlobalAddresses: true, - RegenAdvanceDuration: regenAdv, - MaxTempAddrValidLifetime: maxTempAddrValidLifetime, - MaxTempAddrPreferredLifetime: ipv6.MinPrefixInformationValidLifetimeForUpdate, - } - clock := faketime.NewManualClock() - randSource := savingRandSource{ - s: rand.NewSource(time.Now().UnixNano()), - } - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ndpConfigs, - NDPDisp: &ndpDisp, - })}, - Clock: clock, - RandSource: &randSource, - }) - - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - - expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { - t.Helper() - - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected addr auto gen event") - } - } - - expectAutoGenAddrEventAsync := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType, timeout time.Duration) { - t.Helper() - - clock.Advance(timeout) - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("timed out waiting for addr auto gen event") - } - } - - tempDesyncFactor := time.Duration(randSource.lastInt63) % ipv6.MaxDesyncFactor - effectiveMaxTempAddrPL := ipv6.MinPrefixInformationValidLifetimeForUpdate - tempDesyncFactor - // The time since the last regeneration before a new temporary address is - // generated. - tempAddrRegenenerationTime := effectiveMaxTempAddrPL - regenAdv - - // Receive an RA with prefix1 in an NDP Prefix Information option (PI) - // with non-zero valid & preferred lifetimes. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, minVLSeconds, minVLSeconds)) - expectAutoGenAddrEvent(addr, newAddr) - expectAutoGenAddrEvent(tempAddrs[0], newAddr) - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddrs[0]}, nil); mismatch != "" { - t.Fatal(mismatch) - } - - // Wait for regeneration - expectAutoGenAddrEventAsync(tempAddrs[1], newAddr, tempAddrRegenenerationTime) - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, minVLSeconds, minVLSeconds)) - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddrs[0], tempAddrs[1]}, nil); mismatch != "" { - t.Fatal(mismatch) - } - expectAutoGenAddrEventAsync(tempAddrs[0], deprecatedAddr, regenAdv) - - // Wait for regeneration - expectAutoGenAddrEventAsync(tempAddrs[2], newAddr, tempAddrRegenenerationTime-regenAdv) - expectAutoGenAddrEventAsync(tempAddrs[1], deprecatedAddr, regenAdv) - - // Stop generating temporary addresses - ndpConfigs.AutoGenTempGlobalAddresses = false - if ipv6Ep, err := s.GetNetworkEndpoint(nicID, header.IPv6ProtocolNumber); err != nil { - t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err) - } else { - ndpEP := ipv6Ep.(ipv6.NDPEndpoint) - ndpEP.SetNDPConfigurations(ndpConfigs) - } - - // Refresh lifetimes and wait for the last temporary address to be deprecated. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, minVLSeconds, minVLSeconds)) - expectAutoGenAddrEventAsync(tempAddrs[2], deprecatedAddr, effectiveMaxTempAddrPL-regenAdv) - - // Refresh lifetimes such that the prefix is valid and preferred forever. - // - // This should not affect the lifetimes of temporary addresses because they - // are capped by the maximum valid and preferred lifetimes for temporary - // addresses. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, infiniteLifetimeSeconds, infiniteLifetimeSeconds)) - - // Wait for all the temporary addresses to get invalidated. - invalidateAfter := maxTempAddrValidLifetime - clock.NowMonotonic().Sub(tcpip.MonotonicTime{}) - for _, addr := range tempAddrs { - expectAutoGenAddrEventAsync(addr, invalidatedAddr, invalidateAfter) - invalidateAfter = tempAddrRegenenerationTime - } - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr}, tempAddrs[:]); mismatch != "" { - t.Fatal(mismatch) - } -} - -// TestAutoGenTempAddrRegenJobUpdates tests that a temporary address's -// regeneration job gets updated when refreshing the address's lifetimes. -func TestAutoGenTempAddrRegenJobUpdates(t *testing.T) { - const ( - nicID = 1 - regenAdv = 2 * time.Second - - numTempAddrs = 3 - maxTempAddrPreferredLifetime = ipv6.MinPrefixInformationValidLifetimeForUpdate - maxTempAddrPreferredLifetimeSeconds = uint32(maxTempAddrPreferredLifetime / time.Second) - ) - - prefix, _, addr := prefixSubnetAddr(0, linkAddr1) - var tempIIDHistory [header.IIDSize]byte - header.InitialTempIID(tempIIDHistory[:], nil, nicID) - var tempAddrs [numTempAddrs]tcpip.AddressWithPrefix - for i := 0; i < len(tempAddrs); i++ { - tempAddrs[i] = header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address) - } - - ndpDisp := ndpDispatcher{ - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2), - } - e := channel.New(0, 1280, linkAddr1) - ndpConfigs := ipv6.NDPConfigurations{ - HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, - AutoGenGlobalAddresses: true, - AutoGenTempGlobalAddresses: true, - RegenAdvanceDuration: regenAdv, - MaxTempAddrPreferredLifetime: maxTempAddrPreferredLifetime, - MaxTempAddrValidLifetime: maxTempAddrPreferredLifetime * 2, - } - clock := faketime.NewManualClock() - initialTime := clock.NowMonotonic() - randSource := savingRandSource{ - s: rand.NewSource(time.Now().UnixNano()), - } - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ndpConfigs, - NDPDisp: &ndpDisp, - })}, - Clock: clock, - RandSource: &randSource, - }) - - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - - tempDesyncFactor := time.Duration(randSource.lastInt63) % ipv6.MaxDesyncFactor - - expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { - t.Helper() - - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected addr auto gen event") - } - } - - expectAutoGenAddrEventAsync := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType, timeout time.Duration) { - t.Helper() - - clock.Advance(timeout) - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("timed out waiting for addr auto gen event") - } - } - - // Receive an RA with prefix1 in an NDP Prefix Information option (PI) - // with non-zero valid & preferred lifetimes. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, maxTempAddrPreferredLifetimeSeconds, maxTempAddrPreferredLifetimeSeconds)) - expectAutoGenAddrEvent(addr, newAddr) - expectAutoGenAddrEvent(tempAddrs[0], newAddr) - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddrs[0]}, nil); mismatch != "" { - t.Fatal(mismatch) - } - - // Deprecate the prefix. - // - // A new temporary address should be generated after the regeneration - // time has passed since the prefix is deprecated. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, maxTempAddrPreferredLifetimeSeconds, 0)) - expectAutoGenAddrEvent(addr, deprecatedAddr) - expectAutoGenAddrEvent(tempAddrs[0], deprecatedAddr) - select { - case e := <-ndpDisp.autoGenAddrC: - t.Fatalf("unexpected auto gen addr event = %#v", e) - default: - } - - effectiveMaxTempAddrPL := maxTempAddrPreferredLifetime - tempDesyncFactor - // The time since the last regeneration before a new temporary address is - // generated. - tempAddrRegenenerationTime := effectiveMaxTempAddrPL - regenAdv - - // Advance the clock by the regeneration time but don't expect a new temporary - // address as the prefix is deprecated. - clock.Advance(tempAddrRegenenerationTime) - select { - case e := <-ndpDisp.autoGenAddrC: - t.Fatalf("unexpected auto gen addr event = %#v", e) - default: - } - - // Prefer the prefix again. - // - // A new temporary address should immediately be generated since the - // regeneration time has already passed since the last address was generated - // - this regeneration does not depend on a job. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, maxTempAddrPreferredLifetimeSeconds, maxTempAddrPreferredLifetimeSeconds)) - expectAutoGenAddrEvent(tempAddrs[1], newAddr) - // Wait for the first temporary address to be deprecated. - expectAutoGenAddrEventAsync(tempAddrs[0], deprecatedAddr, regenAdv) - select { - case e := <-ndpDisp.autoGenAddrC: - t.Fatalf("unexpected auto gen addr event = %s", e) - default: - } - - // Increase the maximum lifetimes for temporary addresses to large values - // then refresh the lifetimes of the prefix. - // - // A new address should not be generated after the regeneration time that was - // expected for the previous check. This is because the preferred lifetime for - // the temporary addresses has increased, so it will take more time to - // regenerate a new temporary address. Note, new addresses are only - // regenerated after the preferred lifetime - the regenerate advance duration - // as paased. - const largeLifetimeSeconds = minVLSeconds * 2 - const largeLifetime = time.Duration(largeLifetimeSeconds) * time.Second - ndpConfigs.MaxTempAddrValidLifetime = 2 * largeLifetime - ndpConfigs.MaxTempAddrPreferredLifetime = largeLifetime - ipv6Ep, err := s.GetNetworkEndpoint(nicID, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err) - } - ndpEP := ipv6Ep.(ipv6.NDPEndpoint) - ndpEP.SetNDPConfigurations(ndpConfigs) - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, largeLifetimeSeconds, largeLifetimeSeconds)) - timeSinceInitialTime := clock.NowMonotonic().Sub(initialTime) - clock.Advance(largeLifetime - timeSinceInitialTime) - expectAutoGenAddrEvent(tempAddrs[0], deprecatedAddr) - // to offset the advement of time to test the first temporary address's - // deprecation after the second was generated - advLess := regenAdv - expectAutoGenAddrEventAsync(tempAddrs[2], newAddr, timeSinceInitialTime-advLess-(tempDesyncFactor+regenAdv)) - expectAutoGenAddrEventAsync(tempAddrs[1], deprecatedAddr, regenAdv) - select { - case e := <-ndpDisp.autoGenAddrC: - t.Fatalf("unexpected auto gen addr event = %+v", e) - default: - } -} - -// TestMixedSLAACAddrConflictRegen tests SLAAC address regeneration in response -// to a mix of DAD conflicts and NIC-local conflicts. -func TestMixedSLAACAddrConflictRegen(t *testing.T) { - const ( - nicID = 1 - nicName = "nic" - lifetimeSeconds = 9999 - // From stack.maxSLAACAddrLocalRegenAttempts - maxSLAACAddrLocalRegenAttempts = 10 - // We use 2 more addreses than the maximum local regeneration attempts - // because we want to also trigger regeneration in response to a DAD - // conflicts for this test. - maxAddrs = maxSLAACAddrLocalRegenAttempts + 2 - dupAddrTransmits = 1 - retransmitTimer = time.Second - ) - - var tempIIDHistoryWithModifiedEUI64 [header.IIDSize]byte - header.InitialTempIID(tempIIDHistoryWithModifiedEUI64[:], nil, nicID) - - var tempIIDHistoryWithOpaqueIID [header.IIDSize]byte - header.InitialTempIID(tempIIDHistoryWithOpaqueIID[:], nil, nicID) - - prefix, subnet, stableAddrWithModifiedEUI64 := prefixSubnetAddr(0, linkAddr1) - var stableAddrsWithOpaqueIID [maxAddrs]tcpip.AddressWithPrefix - var tempAddrsWithOpaqueIID [maxAddrs]tcpip.AddressWithPrefix - var tempAddrsWithModifiedEUI64 [maxAddrs]tcpip.AddressWithPrefix - addrBytes := []byte(subnet.ID()) - for i := 0; i < maxAddrs; i++ { - stableAddrsWithOpaqueIID[i] = tcpip.AddressWithPrefix{ - Address: tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], subnet, nicName, uint8(i), nil)), - PrefixLen: header.IIDOffsetInIPv6Address * 8, - } - // When generating temporary addresses, the resolved stable address for the - // SLAAC prefix will be the first address stable address generated for the - // prefix as we will not simulate address conflicts for the stable addresses - // in tests involving temporary addresses. Address conflicts for stable - // addresses will be done in their own tests. - tempAddrsWithOpaqueIID[i] = header.GenerateTempIPv6SLAACAddr(tempIIDHistoryWithOpaqueIID[:], stableAddrsWithOpaqueIID[0].Address) - tempAddrsWithModifiedEUI64[i] = header.GenerateTempIPv6SLAACAddr(tempIIDHistoryWithModifiedEUI64[:], stableAddrWithModifiedEUI64.Address) - } - - tests := []struct { - name string - addrs []tcpip.AddressWithPrefix - tempAddrs bool - initialExpect tcpip.AddressWithPrefix - maxAddrs int - nicNameFromID func(tcpip.NICID, string) string - }{ - { - name: "Stable addresses with opaque IIDs", - addrs: stableAddrsWithOpaqueIID[:], - maxAddrs: 1, - nicNameFromID: func(tcpip.NICID, string) string { - return nicName - }, - }, - { - name: "Temporary addresses with opaque IIDs", - addrs: tempAddrsWithOpaqueIID[:], - tempAddrs: true, - initialExpect: stableAddrsWithOpaqueIID[0], - maxAddrs: 1 /* initial (stable) address */ + maxSLAACAddrLocalRegenAttempts, - nicNameFromID: func(tcpip.NICID, string) string { - return nicName - }, - }, - { - name: "Temporary addresses with modified EUI64", - addrs: tempAddrsWithModifiedEUI64[:], - tempAddrs: true, - maxAddrs: 1 /* initial (stable) address */ + maxSLAACAddrLocalRegenAttempts, - initialExpect: stableAddrWithModifiedEUI64, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - ndpDisp := ndpDispatcher{ - // We may receive a deprecated and invalidated event for each SLAAC - // address that is assigned. - autoGenAddrC: make(chan ndpAutoGenAddrEvent, test.maxAddrs*2), - } - e := channel.New(0, 1280, linkAddr1) - clock := faketime.NewManualClock() - s := stack.New(stack.Options{ - Clock: clock, - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, - AutoGenGlobalAddresses: true, - AutoGenTempGlobalAddresses: test.tempAddrs, - AutoGenAddressConflictRetries: 1, - }, - NDPDisp: &ndpDisp, - OpaqueIIDOpts: ipv6.OpaqueInterfaceIdentifierOptions{ - NICNameFromID: test.nicNameFromID, - }, - })}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, - }) - - s.SetRouteTable([]tcpip.Route{{ - Destination: header.IPv6EmptySubnet, - Gateway: llAddr2, - NIC: nicID, - }}) - - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - - manuallyAssignedAddresses := make(map[tcpip.Address]struct{}) - for j := 0; j < len(test.addrs)-1; j++ { - // The NIC will not attempt to generate an address in response to a - // NIC-local conflict after some maximum number of attempts. We skip - // creating a conflict for the address that would be generated as part - // of the last attempt so we can simulate a DAD conflict for this - // address and restart the NIC-local generation process. - if j == maxSLAACAddrLocalRegenAttempts-1 { - continue - } - - if err := s.AddAddress(nicID, ipv6.ProtocolNumber, test.addrs[j].Address); err != nil { - t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, test.addrs[j].Address, err) - } - - manuallyAssignedAddresses[test.addrs[j].Address] = struct{}{} - } - - expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { - t.Helper() - - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected addr auto gen event") - } - } - - expectAutoGenAddrAsyncEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { - t.Helper() - - if diff := checkAutoGenAddrEvent(<-ndpDisp.autoGenAddrC, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - } - - expectDADEventAsync := func(addr tcpip.Address) { - t.Helper() - - clock.Advance(dupAddrTransmits * retransmitTimer) - if diff := checkDADEvent(<-ndpDisp.dadC, nicID, addr, &stack.DADSucceeded{}); diff != "" { - t.Errorf("DAD event mismatch (-want +got):\n%s", diff) - } - } - - // Enable DAD. - ndpDisp.dadC = make(chan ndpDADEvent, 2) - if ipv6Ep, err := s.GetNetworkEndpoint(nicID, header.IPv6ProtocolNumber); err != nil { - t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err) - } else { - ndpEP := ipv6Ep.(stack.DuplicateAddressDetector) - ndpEP.SetDADConfigurations(stack.DADConfigurations{ - DupAddrDetectTransmits: dupAddrTransmits, - RetransmitTimer: retransmitTimer, - }) - } - - // Do SLAAC for prefix. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetimeSeconds, lifetimeSeconds)) - if test.initialExpect != (tcpip.AddressWithPrefix{}) { - expectAutoGenAddrEvent(test.initialExpect, newAddr) - expectDADEventAsync(test.initialExpect.Address) - } - - // The last local generation attempt should succeed, but we introduce a - // DAD failure to restart the local generation process. - addr := test.addrs[maxSLAACAddrLocalRegenAttempts-1] - expectAutoGenAddrAsyncEvent(addr, newAddr) - rxNDPSolicit(e, addr.Address) - select { - case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr.Address, &stack.DADDupAddrDetected{}); diff != "" { - t.Errorf("DAD event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected DAD event") - } - expectAutoGenAddrEvent(addr, invalidatedAddr) - - // The last address generated should resolve DAD. - addr = test.addrs[len(test.addrs)-1] - expectAutoGenAddrAsyncEvent(addr, newAddr) - expectDADEventAsync(addr.Address) - - select { - case e := <-ndpDisp.autoGenAddrC: - t.Fatalf("unexpected auto gen addr event = %+v", e) - default: - } - - // Wait for all the SLAAC addresses to be invalidated. - clock.Advance(lifetimeSeconds * time.Second) - gotAddresses := make(map[tcpip.Address]struct{}) - for _, a := range s.NICInfo()[nicID].ProtocolAddresses { - gotAddresses[a.AddressWithPrefix.Address] = struct{}{} - } - if diff := cmp.Diff(manuallyAssignedAddresses, gotAddresses); diff != "" { - t.Fatalf("assigned addresses mismatch (-want +got):\n%s", diff) - } - }) - } -} - -// stackAndNdpDispatcherWithDefaultRoute returns an ndpDispatcher, -// channel.Endpoint and stack.Stack. -// -// stack.Stack will have a default route through the router (llAddr3) installed -// and a static link-address (linkAddr3) added to the link address cache for the -// router. -func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID) (*ndpDispatcher, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) { - t.Helper() - ndpDisp := &ndpDispatcher{ - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - e.LinkEPCapabilities |= stack.CapabilityResolutionRequired - clock := faketime.NewManualClock() - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, - AutoGenGlobalAddresses: true, - }, - NDPDisp: ndpDisp, - })}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, - Clock: clock, - }) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - s.SetRouteTable([]tcpip.Route{{ - Destination: header.IPv6EmptySubnet, - Gateway: llAddr3, - NIC: nicID, - }}) - - if err := s.AddStaticNeighbor(nicID, ipv6.ProtocolNumber, llAddr3, linkAddr3); err != nil { - t.Fatalf("s.AddStaticNeighbor(%d, %d, %s, %s): %s", nicID, ipv6.ProtocolNumber, llAddr3, linkAddr3, err) - } - return ndpDisp, e, s, clock -} - -// addrForNewConnectionTo returns the local address used when creating a new -// connection to addr. -func addrForNewConnectionTo(t *testing.T, s *stack.Stack, addr tcpip.FullAddress) tcpip.Address { - t.Helper() - - wq := waiter.Queue{} - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.ReadableEvents) - defer wq.EventUnregister(&we) - defer close(ch) - ep, err := s.NewEndpoint(header.UDPProtocolNumber, header.IPv6ProtocolNumber, &wq) - if err != nil { - t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err) - } - defer ep.Close() - ep.SocketOptions().SetV6Only(true) - if err := ep.Connect(addr); err != nil { - t.Fatalf("ep.Connect(%+v): %s", addr, err) - } - got, err := ep.GetLocalAddress() - if err != nil { - t.Fatalf("ep.GetLocalAddress(): %s", err) - } - return got.Addr -} - -// addrForNewConnection returns the local address used when creating a new -// connection. -func addrForNewConnection(t *testing.T, s *stack.Stack) tcpip.Address { - t.Helper() - - return addrForNewConnectionTo(t, s, dstAddr) -} - -// addrForNewConnectionWithAddr returns the local address used when creating a -// new connection with a specific local address. -func addrForNewConnectionWithAddr(t *testing.T, s *stack.Stack, addr tcpip.FullAddress) tcpip.Address { - t.Helper() - - wq := waiter.Queue{} - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.ReadableEvents) - defer wq.EventUnregister(&we) - defer close(ch) - ep, err := s.NewEndpoint(header.UDPProtocolNumber, header.IPv6ProtocolNumber, &wq) - if err != nil { - t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err) - } - defer ep.Close() - ep.SocketOptions().SetV6Only(true) - if err := ep.Bind(addr); err != nil { - t.Fatalf("ep.Bind(%+v): %s", addr, err) - } - if err := ep.Connect(dstAddr); err != nil { - t.Fatalf("ep.Connect(%+v): %s", dstAddr, err) - } - got, err := ep.GetLocalAddress() - if err != nil { - t.Fatalf("ep.GetLocalAddress(): %s", err) - } - return got.Addr -} - -// TestAutoGenAddrDeprecateFromPI tests deprecating a SLAAC address when -// receiving a PI with 0 preferred lifetime. -func TestAutoGenAddrDeprecateFromPI(t *testing.T) { - const nicID = 1 - - prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) - prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) - - ndpDisp, e, s, _ := stackAndNdpDispatcherWithDefaultRoute(t, nicID) - - expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { - t.Helper() - - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected addr auto gen event") - } - } - - expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) { - t.Helper() - - if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addr); err != nil { - t.Fatal(err) - } - - if got := addrForNewConnection(t, s); got != addr.Address { - t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address) - } - } - - // Receive PI for prefix1. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100)) - expectAutoGenAddrEvent(addr1, newAddr) - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { - t.Fatalf("should have %s in the list of addresses", addr1) - } - expectPrimaryAddr(addr1) - - // Deprecate addr for prefix1 immedaitely. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0)) - expectAutoGenAddrEvent(addr1, deprecatedAddr) - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { - t.Fatalf("should have %s in the list of addresses", addr1) - } - // addr should still be the primary endpoint as there are no other addresses. - expectPrimaryAddr(addr1) - - // Refresh lifetimes of addr generated from prefix1. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly got an auto-generated event") - default: - } - expectPrimaryAddr(addr1) - - // Receive PI for prefix2. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100)) - expectAutoGenAddrEvent(addr2, newAddr) - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { - t.Fatalf("should have %s in the list of addresses", addr2) - } - expectPrimaryAddr(addr2) - - // Deprecate addr for prefix2 immedaitely. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0)) - expectAutoGenAddrEvent(addr2, deprecatedAddr) - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { - t.Fatalf("should have %s in the list of addresses", addr2) - } - // addr1 should be the primary endpoint now since addr2 is deprecated but - // addr1 is not. - expectPrimaryAddr(addr1) - // addr2 is deprecated but if explicitly requested, it should be used. - fullAddr2 := tcpip.FullAddress{Addr: addr2.Address, NIC: nicID} - if got := addrForNewConnectionWithAddr(t, s, fullAddr2); got != addr2.Address { - t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr2, got, addr2.Address) - } - - // Another PI w/ 0 preferred lifetime should not result in a deprecation - // event. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly got an auto-generated event") - default: - } - expectPrimaryAddr(addr1) - if got := addrForNewConnectionWithAddr(t, s, fullAddr2); got != addr2.Address { - t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr2, got, addr2.Address) - } - - // Refresh lifetimes of addr generated from prefix2. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly got an auto-generated event") - default: - } - expectPrimaryAddr(addr2) -} - -// TestAutoGenAddrJobDeprecation tests that an address is properly deprecated -// when its preferred lifetime expires. -func TestAutoGenAddrJobDeprecation(t *testing.T) { - const nicID = 1 - - prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) - prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) - - ndpDisp, e, s, clock := stackAndNdpDispatcherWithDefaultRoute(t, nicID) - - expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { - t.Helper() - - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected addr auto gen event") - } - } - - expectAutoGenAddrEventAfter := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType, timeout time.Duration) { - t.Helper() - - clock.Advance(timeout) - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("timed out waiting for addr auto gen event") - } - } - - expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) { - t.Helper() - - if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addr); err != nil { - t.Fatal(err) - } - - if got := addrForNewConnection(t, s); got != addr.Address { - t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address) - } - } - - // Receive PI for prefix2. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, infiniteLifetimeSeconds, infiniteLifetimeSeconds)) - expectAutoGenAddrEvent(addr2, newAddr) - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { - t.Fatalf("should have %s in the list of addresses", addr2) - } - expectPrimaryAddr(addr2) - - // Receive a PI for prefix1. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 90)) - expectAutoGenAddrEvent(addr1, newAddr) - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { - t.Fatalf("should have %s in the list of addresses", addr1) - } - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { - t.Fatalf("should have %s in the list of addresses", addr2) - } - expectPrimaryAddr(addr1) - - // Refresh lifetime for addr of prefix1. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, minVLSeconds, minVLSeconds-1)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly got an auto-generated event") - default: - } - expectPrimaryAddr(addr1) - - // Wait for addr of prefix1 to be deprecated. - expectAutoGenAddrEventAfter(addr1, deprecatedAddr, ipv6.MinPrefixInformationValidLifetimeForUpdate-time.Second) - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { - t.Fatalf("should not have %s in the list of addresses", addr1) - } - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { - t.Fatalf("should have %s in the list of addresses", addr2) - } - // addr2 should be the primary endpoint now since addr1 is deprecated but - // addr2 is not. - expectPrimaryAddr(addr2) - - // addr1 is deprecated but if explicitly requested, it should be used. - fullAddr1 := tcpip.FullAddress{Addr: addr1.Address, NIC: nicID} - if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address { - t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr1, got, addr1.Address) - } - - // Refresh valid lifetime for addr of prefix1, w/ 0 preferred lifetime to make - // sure we do not get a deprecation event again. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, minVLSeconds, 0)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly got an auto-generated event") - default: - } - expectPrimaryAddr(addr2) - if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address { - t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr1, got, addr1.Address) - } - - // Refresh lifetimes for addr of prefix1. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, minVLSeconds, minVLSeconds-1)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly got an auto-generated event") - default: - } - // addr1 is the primary endpoint again since it is non-deprecated now. - expectPrimaryAddr(addr1) - - // Wait for addr of prefix1 to be deprecated. - expectAutoGenAddrEventAfter(addr1, deprecatedAddr, ipv6.MinPrefixInformationValidLifetimeForUpdate-time.Second) - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { - t.Fatalf("should not have %s in the list of addresses", addr1) - } - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { - t.Fatalf("should have %s in the list of addresses", addr2) - } - // addr2 should be the primary endpoint now since it is not deprecated. - expectPrimaryAddr(addr2) - if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address { - t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr1, got, addr1.Address) - } - - // Wait for addr of prefix1 to be invalidated. - expectAutoGenAddrEventAfter(addr1, invalidatedAddr, time.Second) - if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { - t.Fatalf("should not have %s in the list of addresses", addr1) - } - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { - t.Fatalf("should have %s in the list of addresses", addr2) - } - expectPrimaryAddr(addr2) - - // Refresh both lifetimes for addr of prefix2 to the same value. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, minVLSeconds, minVLSeconds)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly got an auto-generated event") - default: - } - - // Wait for a deprecation then invalidation events, or just an invalidation - // event. We need to cover both cases but cannot deterministically hit both - // cases because the deprecation and invalidation handlers could be handled in - // either deprecation then invalidation, or invalidation then deprecation - // (which should be cancelled by the invalidation handler). - // - // Since we're about to cause both events to fire, we need the dispatcher - // channel to be able to hold both. - if got, want := len(ndpDisp.autoGenAddrC), 0; got != want { - t.Fatalf("got len(ndpDisp.autoGenAddrC) = %d, want %d", got, want) - } - if got, want := cap(ndpDisp.autoGenAddrC), 1; got != want { - t.Fatalf("got cap(ndpDisp.autoGenAddrC) = %d, want %d", got, want) - } - ndpDisp.autoGenAddrC = make(chan ndpAutoGenAddrEvent, 2) - clock.Advance(ipv6.MinPrefixInformationValidLifetimeForUpdate) - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr2, deprecatedAddr); diff == "" { - // If we get a deprecation event first, we should get an invalidation - // event almost immediately after. - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("timed out waiting for addr auto gen event") - } - } else if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff == "" { - // If we get an invalidation event first, we should not get a deprecation - // event after. - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly got an auto-generated event") - default: - } - } else { - t.Fatalf("got unexpected auto-generated event") - } - default: - t.Fatal("timed out waiting for addr auto gen event") - } - if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { - t.Fatalf("should not have %s in the list of addresses", addr1) - } - if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { - t.Fatalf("should not have %s in the list of addresses", addr2) - } - // Should not have any primary endpoints. - if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { - t.Fatal(err) - } - wq := waiter.Queue{} - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.ReadableEvents) - defer wq.EventUnregister(&we) - defer close(ch) - ep, err := s.NewEndpoint(header.UDPProtocolNumber, header.IPv6ProtocolNumber, &wq) - if err != nil { - t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err) - } - defer ep.Close() - ep.SocketOptions().SetV6Only(true) - - { - err := ep.Connect(dstAddr) - if _, ok := err.(*tcpip.ErrNoRoute); !ok { - t.Errorf("got ep.Connect(%+v) = %s, want = %s", dstAddr, err, &tcpip.ErrNoRoute{}) - } - } -} - -// Tests transitioning a SLAAC address's valid lifetime between finite and -// infinite values. -func TestAutoGenAddrFiniteToInfiniteToFiniteVL(t *testing.T) { - const infiniteVLSeconds = 2 - - prefix, _, addr := prefixSubnetAddr(0, linkAddr1) - - tests := []struct { - name string - infiniteVL uint32 - }{ - { - name: "EqualToInfiniteVL", - infiniteVL: infiniteVLSeconds, - }, - // Our implementation supports changing header.NDPInfiniteLifetime for tests - // such that a packet can be received where the lifetime field has a value - // greater than header.NDPInfiniteLifetime. Because of this, we test to make - // sure that receiving a value greater than header.NDPInfiniteLifetime is - // handled the same as when receiving a value equal to - // header.NDPInfiniteLifetime. - { - name: "MoreThanInfiniteVL", - infiniteVL: infiniteVLSeconds + 1, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - ndpDisp := ndpDispatcher{ - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - clock := faketime.NewManualClock() - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, - AutoGenGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, - })}, - Clock: clock, - }) - - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - - // Receive an RA with finite prefix. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, minVLSeconds, 0)) - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, newAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - - default: - t.Fatal("expected addr auto gen event") - } - - // Receive an new RA with prefix with infinite VL. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.infiniteVL, 0)) - - // Receive a new RA with prefix with finite VL. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, minVLSeconds, 0)) - - clock.Advance(ipv6.MinPrefixInformationValidLifetimeForUpdate) - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - - default: - t.Fatal("timeout waiting for addr auto gen event") - } - }) - } -} - -// TestAutoGenAddrValidLifetimeUpdates tests that the valid lifetime of an -// auto-generated address only gets updated when required to, as specified in -// RFC 4862 section 5.5.3.e. -func TestAutoGenAddrValidLifetimeUpdates(t *testing.T) { - const infiniteVL = 4294967295 - - prefix, _, addr := prefixSubnetAddr(0, linkAddr1) - - tests := []struct { - name string - ovl uint32 - nvl uint32 - evl uint32 - }{ - // Should update the VL to the minimum VL for updating if the - // new VL is less than minVLSeconds but was originally greater than - // it. - { - "LargeVLToVLLessThanMinVLForUpdate", - 9999, - 1, - minVLSeconds, - }, - { - "LargeVLTo0", - 9999, - 0, - minVLSeconds, - }, - { - "InfiniteVLToVLLessThanMinVLForUpdate", - infiniteVL, - 1, - minVLSeconds, - }, - { - "InfiniteVLTo0", - infiniteVL, - 0, - minVLSeconds, - }, - - // Should not update VL if original VL was less than minVLSeconds - // and the new VL is also less than minVLSeconds. - { - "ShouldNotUpdateWhenBothOldAndNewAreLessThanMinVLForUpdate", - minVLSeconds - 1, - minVLSeconds - 3, - minVLSeconds - 1, - }, - - // Should take the new VL if the new VL is greater than the - // remaining time or is greater than minVLSeconds. - { - "MorethanMinVLToLesserButStillMoreThanMinVLForUpdate", - minVLSeconds + 5, - minVLSeconds + 3, - minVLSeconds + 3, - }, - { - "SmallVLToGreaterVLButStillLessThanMinVLForUpdate", - minVLSeconds - 3, - minVLSeconds - 1, - minVLSeconds - 1, - }, - { - "SmallVLToGreaterVLThatIsMoreThaMinVLForUpdate", - minVLSeconds - 3, - minVLSeconds + 1, - minVLSeconds + 1, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - ndpDisp := ndpDispatcher{ - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 10), - } - e := channel.New(10, 1280, linkAddr1) - clock := faketime.NewManualClock() - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, - AutoGenGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, - })}, - Clock: clock, - }) - - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - - // Receive an RA with prefix with initial VL, - // test.ovl. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.ovl, 0)) - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, newAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected addr auto gen event") - } - - // Receive an new RA with prefix with new VL, - // test.nvl. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.nvl, 0)) - - // - // Validate that the VL for the address got set - // to test.evl. - // - - // The address should not be invalidated until the effective valid - // lifetime has passed. - const delta = 1 - clock.Advance(time.Duration(test.evl)*time.Second - delta) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly received an auto gen addr event") - default: - } - - // Wait for the invalidation event. - clock.Advance(delta) - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("timeout waiting for addr auto gen event") - } - }) - } -} - -// TestAutoGenAddrRemoval tests that when auto-generated addresses are removed -// by the user, its resources will be cleaned up and an invalidation event will -// be sent to the integrator. -func TestAutoGenAddrRemoval(t *testing.T) { - prefix, _, addr := prefixSubnetAddr(0, linkAddr1) - - ndpDisp := ndpDispatcher{ - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - clock := faketime.NewManualClock() - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, - AutoGenGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, - })}, - Clock: clock, - }) - - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - - expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { - t.Helper() - - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected addr auto gen event") - } - } - - // Receive a PI to auto-generate an address. - const lifetimeSeconds = 1 - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetimeSeconds, 0)) - expectAutoGenAddrEvent(addr, newAddr) - - // Removing the address should result in an invalidation event - // immediately. - if err := s.RemoveAddress(1, addr.Address); err != nil { - t.Fatalf("RemoveAddress(_, %s) = %s", addr.Address, err) - } - expectAutoGenAddrEvent(addr, invalidatedAddr) - - // Wait for the original valid lifetime to make sure the original job got - // cancelled/cleaned up. - clock.Advance(lifetimeSeconds * time.Second) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly received an auto gen addr event") - default: - } -} - -// TestAutoGenAddrAfterRemoval tests adding a SLAAC address that was previously -// assigned to the NIC but is in the permanentExpired state. -func TestAutoGenAddrAfterRemoval(t *testing.T) { - const nicID = 1 - - prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) - prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) - ndpDisp, e, s, _ := stackAndNdpDispatcherWithDefaultRoute(t, nicID) - - expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { - t.Helper() - - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected addr auto gen event") - } - } - - expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) { - t.Helper() - - if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addr); err != nil { - t.Fatal(err) - } - - if got := addrForNewConnection(t, s); got != addr.Address { - t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address) - } - } - - // Receive a PI to auto-generate addr1 with a large valid and preferred - // lifetime. - const largeLifetimeSeconds = 999 - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix1, true, true, largeLifetimeSeconds, largeLifetimeSeconds)) - expectAutoGenAddrEvent(addr1, newAddr) - expectPrimaryAddr(addr1) - - // Add addr2 as a static address. - protoAddr2 := tcpip.ProtocolAddress{ - Protocol: header.IPv6ProtocolNumber, - AddressWithPrefix: addr2, - } - if err := s.AddProtocolAddressWithOptions(nicID, protoAddr2, stack.FirstPrimaryEndpoint); err != nil { - t.Fatalf("AddProtocolAddressWithOptions(%d, %+v, %d) = %s", nicID, protoAddr2, stack.FirstPrimaryEndpoint, err) - } - // addr2 should be more preferred now since it is at the front of the primary - // list. - expectPrimaryAddr(addr2) - - // Get a route using addr2 to increment its reference count then remove it - // to leave it in the permanentExpired state. - r, err := s.FindRoute(nicID, addr2.Address, addr3, header.IPv6ProtocolNumber, false) - if err != nil { - t.Fatalf("FindRoute(%d, %s, %s, %d, false): %s", nicID, addr2.Address, addr3, header.IPv6ProtocolNumber, err) - } - defer r.Release() - if err := s.RemoveAddress(nicID, addr2.Address); err != nil { - t.Fatalf("s.RemoveAddress(%d, %s): %s", nicID, addr2.Address, err) - } - // addr1 should be preferred again since addr2 is in the expired state. - expectPrimaryAddr(addr1) - - // Receive a PI to auto-generate addr2 as valid and preferred. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix2, true, true, largeLifetimeSeconds, largeLifetimeSeconds)) - expectAutoGenAddrEvent(addr2, newAddr) - // addr2 should be more preferred now that it is closer to the front of the - // primary list and not deprecated. - expectPrimaryAddr(addr2) - - // Removing the address should result in an invalidation event immediately. - // It should still be in the permanentExpired state because r is still held. - // - // We remove addr2 here to make sure addr2 was marked as a SLAAC address - // (it was previously marked as a static address). - if err := s.RemoveAddress(1, addr2.Address); err != nil { - t.Fatalf("RemoveAddress(_, %s) = %s", addr2.Address, err) - } - expectAutoGenAddrEvent(addr2, invalidatedAddr) - // addr1 should be more preferred since addr2 is in the expired state. - expectPrimaryAddr(addr1) - - // Receive a PI to auto-generate addr2 as valid and deprecated. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix2, true, true, largeLifetimeSeconds, 0)) - expectAutoGenAddrEvent(addr2, newAddr) - // addr1 should still be more preferred since addr2 is deprecated, even though - // it is closer to the front of the primary list. - expectPrimaryAddr(addr1) - - // Receive a PI to refresh addr2's preferred lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix2, true, true, largeLifetimeSeconds, largeLifetimeSeconds)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly got an auto gen addr event") - default: - } - // addr2 should be more preferred now that it is not deprecated. - expectPrimaryAddr(addr2) - - if err := s.RemoveAddress(1, addr2.Address); err != nil { - t.Fatalf("RemoveAddress(_, %s) = %s", addr2.Address, err) - } - expectAutoGenAddrEvent(addr2, invalidatedAddr) - expectPrimaryAddr(addr1) -} - -// TestAutoGenAddrStaticConflict tests that if SLAAC generates an address that -// is already assigned to the NIC, the static address remains. -func TestAutoGenAddrStaticConflict(t *testing.T) { - prefix, _, addr := prefixSubnetAddr(0, linkAddr1) - - ndpDisp := ndpDispatcher{ - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - clock := faketime.NewManualClock() - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, - AutoGenGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, - })}, - Clock: clock, - }) - - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - - // Add the address as a static address before SLAAC tries to add it. - if err := s.AddProtocolAddress(1, tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: addr}); err != nil { - t.Fatalf("AddAddress(_, %d, %s) = %s", header.IPv6ProtocolNumber, addr.Address, err) - } - if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr) { - t.Fatalf("Should have %s in the list of addresses", addr1) - } - - // Receive a PI where the generated address will be the same as the one - // that we already have assigned statically. - const lifetimeSeconds = 1 - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetimeSeconds, 0)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly received an auto gen addr event for an address we already have statically") - default: - } - if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr) { - t.Fatalf("Should have %s in the list of addresses", addr1) - } - - // Should not get an invalidation event after the PI's invalidation - // time. - clock.Advance(lifetimeSeconds * time.Second) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly received an auto gen addr event") - default: - } - if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr) { - t.Fatalf("Should have %s in the list of addresses", addr1) - } -} - -func makeSecretKey(t *testing.T) []byte { - secretKey := make([]byte, header.OpaqueIIDSecretKeyMinBytes) - n, err := cryptorand.Read(secretKey) - if err != nil { - t.Fatalf("cryptorand.Read(_): %s", err) - } - if l := len(secretKey); n != l { - t.Fatalf("got cryptorand.Read(_) = (%d, nil), want = (%d, nil)", n, l) - } - return secretKey -} - -// TestAutoGenAddrWithOpaqueIID tests that SLAAC generated addresses will use -// opaque interface identifiers when configured to do so. -func TestAutoGenAddrWithOpaqueIID(t *testing.T) { - const nicID = 1 - const nicName = "nic1" - - secretKey := makeSecretKey(t) - - prefix1, subnet1, _ := prefixSubnetAddr(0, linkAddr1) - prefix2, subnet2, _ := prefixSubnetAddr(1, linkAddr1) - // addr1 and addr2 are the addresses that are expected to be generated when - // stack.Stack is configured to generate opaque interface identifiers as - // defined by RFC 7217. - addrBytes := []byte(subnet1.ID()) - addr1 := tcpip.AddressWithPrefix{ - Address: tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], subnet1, nicName, 0, secretKey)), - PrefixLen: 64, - } - addrBytes = []byte(subnet2.ID()) - addr2 := tcpip.AddressWithPrefix{ - Address: tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], subnet2, nicName, 0, secretKey)), - PrefixLen: 64, - } - - ndpDisp := ndpDispatcher{ - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - clock := faketime.NewManualClock() - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, - AutoGenGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, - OpaqueIIDOpts: ipv6.OpaqueInterfaceIdentifierOptions{ - NICNameFromID: func(_ tcpip.NICID, nicName string) string { - return nicName - }, - SecretKey: secretKey, - }, - })}, - Clock: clock, - }) - opts := stack.NICOptions{Name: nicName} - if err := s.CreateNICWithOptions(nicID, e, opts); err != nil { - t.Fatalf("CreateNICWithOptions(%d, _, %+v, _) = %s", nicID, opts, err) - } - - expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { - t.Helper() - - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected addr auto gen event") - } - } - - // Receive an RA with prefix1 in a PI. - const validLifetimeSecondPrefix1 = 1 - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, validLifetimeSecondPrefix1, 0)) - expectAutoGenAddrEvent(addr1, newAddr) - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { - t.Fatalf("should have %s in the list of addresses", addr1) - } - - // Receive an RA with prefix2 in a PI with a large valid lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0)) - expectAutoGenAddrEvent(addr2, newAddr) - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { - t.Fatalf("should have %s in the list of addresses", addr1) - } - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { - t.Fatalf("should have %s in the list of addresses", addr2) - } - - // Wait for addr of prefix1 to be invalidated. - clock.Advance(validLifetimeSecondPrefix1 * time.Second) - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("timed out waiting for addr auto gen event") - } - if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { - t.Fatalf("should not have %s in the list of addresses", addr1) - } - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { - t.Fatalf("should have %s in the list of addresses", addr2) - } -} - -func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { - const nicID = 1 - const nicName = "nic" - const dadTransmits = 1 - const retransmitTimer = time.Second - const maxMaxRetries = 3 - const lifetimeSeconds = 10 - - secretKey := makeSecretKey(t) - - prefix, subnet, _ := prefixSubnetAddr(0, linkAddr1) - - addrForSubnet := func(subnet tcpip.Subnet, dadCounter uint8) tcpip.AddressWithPrefix { - addrBytes := []byte(subnet.ID()) - return tcpip.AddressWithPrefix{ - Address: tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], subnet, nicName, dadCounter, secretKey)), - PrefixLen: 64, - } - } - - expectAutoGenAddrEvent := func(t *testing.T, ndpDisp *ndpDispatcher, addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { - t.Helper() - - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected addr auto gen event") - } - } - - expectAutoGenAddrEventAsync := func(t *testing.T, clock *faketime.ManualClock, ndpDisp *ndpDispatcher, addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { - t.Helper() - - clock.RunImmediatelyScheduledJobs() - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("timed out waiting for addr auto gen event") - } - } - - expectDADEvent := func(t *testing.T, clock *faketime.ManualClock, ndpDisp *ndpDispatcher, addr tcpip.Address, res stack.DADResult) { - t.Helper() - - clock.RunImmediatelyScheduledJobs() - select { - case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr, res); diff != "" { - t.Errorf("DAD event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected DAD event") - } - } - - expectDADEventAsync := func(t *testing.T, clock *faketime.ManualClock, ndpDisp *ndpDispatcher, addr tcpip.Address, res stack.DADResult) { - t.Helper() - - clock.Advance(dadTransmits * retransmitTimer) - select { - case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr, res); diff != "" { - t.Errorf("DAD event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("timed out waiting for DAD event") - } - } - - stableAddrForTempAddrTest := addrForSubnet(subnet, 0) - - addrTypes := []struct { - name string - ndpConfigs ipv6.NDPConfigurations - autoGenLinkLocal bool - prepareFn func(t *testing.T, clock *faketime.ManualClock, ndpDisp *ndpDispatcher, e *channel.Endpoint, tempIIDHistory []byte) []tcpip.AddressWithPrefix - addrGenFn func(dadCounter uint8, tempIIDHistory []byte) tcpip.AddressWithPrefix - }{ - { - name: "Global address", - ndpConfigs: ipv6.NDPConfigurations{ - HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, - AutoGenGlobalAddresses: true, - }, - prepareFn: func(_ *testing.T, _ *faketime.ManualClock, _ *ndpDispatcher, e *channel.Endpoint, _ []byte) []tcpip.AddressWithPrefix { - // Receive an RA with prefix1 in a PI. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetimeSeconds, lifetimeSeconds)) - return nil - - }, - addrGenFn: func(dadCounter uint8, _ []byte) tcpip.AddressWithPrefix { - return addrForSubnet(subnet, dadCounter) - }, - }, - { - name: "LinkLocal address", - ndpConfigs: ipv6.NDPConfigurations{}, - autoGenLinkLocal: true, - prepareFn: func(*testing.T, *faketime.ManualClock, *ndpDispatcher, *channel.Endpoint, []byte) []tcpip.AddressWithPrefix { - return nil - }, - addrGenFn: func(dadCounter uint8, _ []byte) tcpip.AddressWithPrefix { - return addrForSubnet(header.IPv6LinkLocalPrefix.Subnet(), dadCounter) - }, - }, - { - name: "Temporary address", - ndpConfigs: ipv6.NDPConfigurations{ - HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, - AutoGenGlobalAddresses: true, - AutoGenTempGlobalAddresses: true, - }, - prepareFn: func(t *testing.T, clock *faketime.ManualClock, ndpDisp *ndpDispatcher, e *channel.Endpoint, tempIIDHistory []byte) []tcpip.AddressWithPrefix { - header.InitialTempIID(tempIIDHistory, nil, nicID) - - // Generate a stable SLAAC address so temporary addresses will be - // generated. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100)) - expectAutoGenAddrEvent(t, ndpDisp, stableAddrForTempAddrTest, newAddr) - expectDADEventAsync(t, clock, ndpDisp, stableAddrForTempAddrTest.Address, &stack.DADSucceeded{}) - - // The stable address will be assigned throughout the test. - return []tcpip.AddressWithPrefix{stableAddrForTempAddrTest} - }, - addrGenFn: func(_ uint8, tempIIDHistory []byte) tcpip.AddressWithPrefix { - return header.GenerateTempIPv6SLAACAddr(tempIIDHistory, stableAddrForTempAddrTest.Address) - }, - }, - } - - for _, addrType := range addrTypes { - t.Run(addrType.name, func(t *testing.T) { - for maxRetries := uint8(0); maxRetries <= maxMaxRetries; maxRetries++ { - for numFailures := uint8(0); numFailures <= maxRetries+1; numFailures++ { - maxRetries := maxRetries - numFailures := numFailures - addrType := addrType - - t.Run(fmt.Sprintf("%d max retries and %d failures", maxRetries, numFailures), func(t *testing.T) { - ndpDisp := ndpDispatcher{ - dadC: make(chan ndpDADEvent, 1), - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2), - } - e := channel.New(0, 1280, linkAddr1) - ndpConfigs := addrType.ndpConfigs - ndpConfigs.AutoGenAddressConflictRetries = maxRetries - clock := faketime.NewManualClock() - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - AutoGenLinkLocal: addrType.autoGenLinkLocal, - DADConfigs: stack.DADConfigurations{ - DupAddrDetectTransmits: dadTransmits, - RetransmitTimer: retransmitTimer, - }, - NDPConfigs: ndpConfigs, - NDPDisp: &ndpDisp, - OpaqueIIDOpts: ipv6.OpaqueInterfaceIdentifierOptions{ - NICNameFromID: func(_ tcpip.NICID, nicName string) string { - return nicName - }, - SecretKey: secretKey, - }, - })}, - Clock: clock, - }) - opts := stack.NICOptions{Name: nicName} - if err := s.CreateNICWithOptions(nicID, e, opts); err != nil { - t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, opts, err) - } - - var tempIIDHistory [header.IIDSize]byte - stableAddrs := addrType.prepareFn(t, clock, &ndpDisp, e, tempIIDHistory[:]) - - // Simulate DAD conflicts so the address is regenerated. - for i := uint8(0); i < numFailures; i++ { - addr := addrType.addrGenFn(i, tempIIDHistory[:]) - expectAutoGenAddrEventAsync(t, clock, &ndpDisp, addr, newAddr) - - // Should not have any new addresses assigned to the NIC. - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, stableAddrs, nil); mismatch != "" { - t.Fatal(mismatch) - } - - // Simulate a DAD conflict. - rxNDPSolicit(e, addr.Address) - expectAutoGenAddrEvent(t, &ndpDisp, addr, invalidatedAddr) - expectDADEvent(t, clock, &ndpDisp, addr.Address, &stack.DADDupAddrDetected{}) - - // Attempting to add the address manually should not fail if the - // address's state was cleaned up when DAD failed. - if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr.Address); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr.Address, err) - } - if err := s.RemoveAddress(nicID, addr.Address); err != nil { - t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr.Address, err) - } - expectDADEvent(t, clock, &ndpDisp, addr.Address, &stack.DADAborted{}) - } - - // Should not have any new addresses assigned to the NIC. - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, stableAddrs, nil); mismatch != "" { - t.Fatal(mismatch) - } - - // If we had less failures than generation attempts, we should have - // an address after DAD resolves. - if maxRetries+1 > numFailures { - addr := addrType.addrGenFn(numFailures, tempIIDHistory[:]) - expectAutoGenAddrEventAsync(t, clock, &ndpDisp, addr, newAddr) - expectDADEventAsync(t, clock, &ndpDisp, addr.Address, &stack.DADSucceeded{}) - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, append(stableAddrs, addr), nil); mismatch != "" { - t.Fatal(mismatch) - } - } - - // Should not attempt address generation again. - select { - case e := <-ndpDisp.autoGenAddrC: - t.Fatalf("unexpectedly got an auto-generated address event = %+v", e) - default: - } - }) - } - } - }) - } -} - -// TestAutoGenAddrWithEUI64IIDNoDADRetries tests that a regeneration attempt is -// not made for SLAAC addresses generated with an IID based on the NIC's link -// address. -func TestAutoGenAddrWithEUI64IIDNoDADRetries(t *testing.T) { - const nicID = 1 - const dadTransmits = 1 - const retransmitTimer = time.Second - const maxRetries = 3 - const lifetimeSeconds = 10 - - prefix, subnet, _ := prefixSubnetAddr(0, linkAddr1) - - addrTypes := []struct { - name string - ndpConfigs ipv6.NDPConfigurations - autoGenLinkLocal bool - subnet tcpip.Subnet - triggerSLAACFn func(e *channel.Endpoint) - }{ - { - name: "Global address", - ndpConfigs: ipv6.NDPConfigurations{ - HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, - AutoGenGlobalAddresses: true, - AutoGenAddressConflictRetries: maxRetries, - }, - subnet: subnet, - triggerSLAACFn: func(e *channel.Endpoint) { - // Receive an RA with prefix1 in a PI. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetimeSeconds, lifetimeSeconds)) - - }, - }, - { - name: "LinkLocal address", - ndpConfigs: ipv6.NDPConfigurations{ - AutoGenAddressConflictRetries: maxRetries, - }, - autoGenLinkLocal: true, - subnet: header.IPv6LinkLocalPrefix.Subnet(), - triggerSLAACFn: func(e *channel.Endpoint) {}, - }, - } - - for _, addrType := range addrTypes { - addrType := addrType - - t.Run(addrType.name, func(t *testing.T) { - ndpDisp := ndpDispatcher{ - dadC: make(chan ndpDADEvent, 1), - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2), - } - e := channel.New(0, 1280, linkAddr1) - clock := faketime.NewManualClock() - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - AutoGenLinkLocal: addrType.autoGenLinkLocal, - NDPConfigs: addrType.ndpConfigs, - NDPDisp: &ndpDisp, - DADConfigs: stack.DADConfigurations{ - DupAddrDetectTransmits: dadTransmits, - RetransmitTimer: retransmitTimer, - }, - })}, - Clock: clock, - }) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - - expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { - t.Helper() - - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected addr auto gen event") - } - } - - addrType.triggerSLAACFn(e) - - addrBytes := []byte(addrType.subnet.ID()) - header.EthernetAdddressToModifiedEUI64IntoBuf(linkAddr1, addrBytes[header.IIDOffsetInIPv6Address:]) - addr := tcpip.AddressWithPrefix{ - Address: tcpip.Address(addrBytes), - PrefixLen: 64, - } - expectAutoGenAddrEvent(addr, newAddr) - - // Simulate a DAD conflict. - rxNDPSolicit(e, addr.Address) - expectAutoGenAddrEvent(addr, invalidatedAddr) - select { - case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr.Address, &stack.DADDupAddrDetected{}); diff != "" { - t.Errorf("DAD event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected DAD event") - } - - // Should not attempt address regeneration. - select { - case e := <-ndpDisp.autoGenAddrC: - t.Fatalf("unexpectedly got an auto-generated address event = %+v", e) - default: - } - }) - } -} - -// TestAutoGenAddrContinuesLifetimesAfterRetry tests that retrying address -// generation in response to DAD conflicts does not refresh the lifetimes. -func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) { - const nicID = 1 - const nicName = "nic" - const dadTransmits = 1 - const retransmitTimer = 2 * time.Second - const failureTimer = time.Second - const maxRetries = 1 - const lifetimeSeconds = 5 - - secretKey := makeSecretKey(t) - - prefix, subnet, _ := prefixSubnetAddr(0, linkAddr1) - - ndpDisp := ndpDispatcher{ - dadC: make(chan ndpDADEvent, 1), - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2), - } - e := channel.New(0, 1280, linkAddr1) - clock := faketime.NewManualClock() - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - DADConfigs: stack.DADConfigurations{ - DupAddrDetectTransmits: dadTransmits, - RetransmitTimer: retransmitTimer, - }, - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, - AutoGenGlobalAddresses: true, - AutoGenAddressConflictRetries: maxRetries, - }, - NDPDisp: &ndpDisp, - OpaqueIIDOpts: ipv6.OpaqueInterfaceIdentifierOptions{ - NICNameFromID: func(_ tcpip.NICID, nicName string) string { - return nicName - }, - SecretKey: secretKey, - }, - })}, - Clock: clock, - }) - opts := stack.NICOptions{Name: nicName} - if err := s.CreateNICWithOptions(nicID, e, opts); err != nil { - t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, opts, err) - } - - expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { - t.Helper() - - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected addr auto gen event") - } - } - - // Receive an RA with prefix in a PI. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetimeSeconds, lifetimeSeconds)) - - addrBytes := []byte(subnet.ID()) - addr := tcpip.AddressWithPrefix{ - Address: tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], subnet, nicName, 0, secretKey)), - PrefixLen: 64, - } - expectAutoGenAddrEvent(addr, newAddr) - - // Simulate a DAD conflict after some time has passed. - clock.Advance(failureTimer) - rxNDPSolicit(e, addr.Address) - expectAutoGenAddrEvent(addr, invalidatedAddr) - select { - case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr.Address, &stack.DADDupAddrDetected{}); diff != "" { - t.Errorf("DAD event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected DAD event") - } - - // Let the next address resolve. - addr.Address = tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], subnet, nicName, 1, secretKey)) - expectAutoGenAddrEvent(addr, newAddr) - clock.Advance(dadTransmits * retransmitTimer) - select { - case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr.Address, &stack.DADSucceeded{}); diff != "" { - t.Errorf("DAD event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("timed out waiting for DAD event") - } - - // Address should be deprecated/invalidated after the lifetime expires. - // - // Note, the remaining lifetime is calculated from when the PI was first - // processed. Since we wait for some time before simulating a DAD conflict - // and more time for the new address to resolve, the new address is only - // expected to be valid for the remaining time. The DAD conflict should - // not have reset the lifetimes. - // - // We expect either just the invalidation event or the deprecation event - // followed by the invalidation event. - clock.Advance(lifetimeSeconds*time.Second - failureTimer - dadTransmits*retransmitTimer) - select { - case e := <-ndpDisp.autoGenAddrC: - if e.eventType == deprecatedAddr { - if diff := checkAutoGenAddrEvent(e, addr, deprecatedAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("timed out waiting for invalidated auto gen addr event after deprecation") - } - } else { - if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - } - default: - t.Fatal("timed out waiting for auto gen addr event") - } -} - -// TestNDPRecursiveDNSServerDispatch tests that we properly dispatch an event -// to the integrator when an RA is received with the NDP Recursive DNS Server -// option with at least one valid address. -func TestNDPRecursiveDNSServerDispatch(t *testing.T) { - tests := []struct { - name string - opt header.NDPRecursiveDNSServer - expected *ndpRDNSS - }{ - { - "Unspecified", - header.NDPRecursiveDNSServer([]byte{ - 0, 0, - 0, 0, 0, 2, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - }), - nil, - }, - { - "Multicast", - header.NDPRecursiveDNSServer([]byte{ - 0, 0, - 0, 0, 0, 2, - 255, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, - }), - nil, - }, - { - "OptionTooSmall", - header.NDPRecursiveDNSServer([]byte{ - 0, 0, - 0, 0, 0, 2, - 1, 2, 3, 4, 5, 6, 7, 8, - }), - nil, - }, - { - "0Addresses", - header.NDPRecursiveDNSServer([]byte{ - 0, 0, - 0, 0, 0, 2, - }), - nil, - }, - { - "Valid1Address", - header.NDPRecursiveDNSServer([]byte{ - 0, 0, - 0, 0, 0, 2, - 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 1, - }), - &ndpRDNSS{ - []tcpip.Address{ - "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x01", - }, - 2 * time.Second, - }, - }, - { - "Valid2Addresses", - header.NDPRecursiveDNSServer([]byte{ - 0, 0, - 0, 0, 0, 1, - 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 1, - 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 2, - }), - &ndpRDNSS{ - []tcpip.Address{ - "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x01", - "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x02", - }, - time.Second, - }, - }, - { - "Valid3Addresses", - header.NDPRecursiveDNSServer([]byte{ - 0, 0, - 0, 0, 0, 0, - 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 1, - 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 2, - 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 3, - }), - &ndpRDNSS{ - []tcpip.Address{ - "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x01", - "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x02", - "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x03", - }, - 0, - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - ndpDisp := ndpDispatcher{ - // We do not expect more than a single RDNSS - // event at any time for this test. - rdnssC: make(chan ndpRDNSSEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, - }, - NDPDisp: &ndpDisp, - })}, - }) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithOpts(llAddr1, 0, header.NDPOptionsSerializer{test.opt})) - - if test.expected != nil { - select { - case e := <-ndpDisp.rdnssC: - if e.nicID != 1 { - t.Errorf("got rdnss nicID = %d, want = 1", e.nicID) - } - if diff := cmp.Diff(e.rdnss.addrs, test.expected.addrs); diff != "" { - t.Errorf("rdnss addrs mismatch (-want +got):\n%s", diff) - } - if e.rdnss.lifetime != test.expected.lifetime { - t.Errorf("got rdnss lifetime = %s, want = %s", e.rdnss.lifetime, test.expected.lifetime) - } - default: - t.Fatal("expected an RDNSS option event") - } - } - - // Should have no more RDNSS options. - select { - case e := <-ndpDisp.rdnssC: - t.Fatalf("unexpectedly got a new RDNSS option event: %+v", e) - default: - } - }) - } -} - -// TestNDPDNSSearchListDispatch tests that the integrator is informed when an -// NDP DNS Search List option is received with at least one domain name in the -// search list. -func TestNDPDNSSearchListDispatch(t *testing.T) { - const nicID = 1 - - ndpDisp := ndpDispatcher{ - dnsslC: make(chan ndpDNSSLEvent, 3), - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, - }, - NDPDisp: &ndpDisp, - })}, - }) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - - optSer := header.NDPOptionsSerializer{ - header.NDPDNSSearchList([]byte{ - 0, 0, - 0, 0, 0, 0, - 2, 'h', 'i', - 0, - }), - header.NDPDNSSearchList([]byte{ - 0, 0, - 0, 0, 0, 1, - 1, 'i', - 0, - 2, 'a', 'm', - 2, 'm', 'e', - 0, - }), - header.NDPDNSSearchList([]byte{ - 0, 0, - 0, 0, 1, 0, - 3, 'x', 'y', 'z', - 0, - 5, 'h', 'e', 'l', 'l', 'o', - 5, 'w', 'o', 'r', 'l', 'd', - 0, - 4, 't', 'h', 'i', 's', - 2, 'i', 's', - 1, 'a', - 4, 't', 'e', 's', 't', - 0, - }), - } - expected := []struct { - domainNames []string - lifetime time.Duration - }{ - { - domainNames: []string{ - "hi", - }, - lifetime: 0, - }, - { - domainNames: []string{ - "i", - "am.me", - }, - lifetime: time.Second, - }, - { - domainNames: []string{ - "xyz", - "hello.world", - "this.is.a.test", - }, - lifetime: 256 * time.Second, - }, - } - - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithOpts(llAddr1, 0, optSer)) - - for i, expected := range expected { - select { - case dnssl := <-ndpDisp.dnsslC: - if dnssl.nicID != nicID { - t.Errorf("got %d-th dnssl nicID = %d, want = %d", i, dnssl.nicID, nicID) - } - if diff := cmp.Diff(dnssl.domainNames, expected.domainNames); diff != "" { - t.Errorf("%d-th dnssl domain names mismatch (-want +got):\n%s", i, diff) - } - if dnssl.lifetime != expected.lifetime { - t.Errorf("got %d-th dnssl lifetime = %s, want = %s", i, dnssl.lifetime, expected.lifetime) - } - default: - t.Fatal("expected a DNSSL event") - } - } - - // Should have no more DNSSL options. - select { - case <-ndpDisp.dnsslC: - t.Fatal("unexpectedly got a DNSSL event") - default: - } -} - -func TestNoCleanupNDPStateWhenForwardingEnabled(t *testing.T) { - const ( - lifetimeSeconds = 999 - nicID = 1 - ) - - ndpDisp := ndpDispatcher{ - offLinkRouteC: make(chan ndpOffLinkRouteEvent, 1), - prefixC: make(chan ndpPrefixEvent, 1), - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), - } - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - AutoGenLinkLocal: true, - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, - DiscoverDefaultRouters: true, - DiscoverOnLinkPrefixes: true, - AutoGenGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, - })}, - }) - - e1 := channel.New(0, header.IPv6MinimumMTU, linkAddr1) - if err := s.CreateNIC(nicID, e1); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - llAddr := tcpip.AddressWithPrefix{Address: llAddr1, PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen} - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, llAddr, newAddr); diff != "" { - t.Errorf("auto-gen addr mismatch (-want +got):\n%s", diff) - } - default: - t.Errorf("expected auto-gen addr event for %s on NIC(%d)", llAddr, nicID) - } - - prefix, subnet, addr := prefixSubnetAddr(0, linkAddr1) - e1.InjectInbound( - header.IPv6ProtocolNumber, - raBufWithPI( - llAddr3, - lifetimeSeconds, - prefix, - true, /* onLink */ - true, /* auto */ - lifetimeSeconds, - lifetimeSeconds, - ), - ) - select { - case e := <-ndpDisp.offLinkRouteC: - if diff := checkOffLinkRouteEvent(e, nicID, header.IPv6EmptySubnet, llAddr3, header.MediumRoutePreference, true /* discovered */); diff != "" { - t.Errorf("off-link route event mismatch (-want +got):\n%s", diff) - } - default: - t.Errorf("expected off-link route event for %s on NIC(%d)", llAddr3, nicID) - } - select { - case e := <-ndpDisp.prefixC: - if diff := checkPrefixEvent(e, subnet, true /* discovered */); diff != "" { - t.Errorf("off-link route event mismatch (-want +got):\n%s", diff) - } - default: - t.Errorf("expected prefix event for %s on NIC(%d)", prefix, nicID) - } - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, newAddr); diff != "" { - t.Errorf("auto-gen addr mismatch (-want +got):\n%s", diff) - } - default: - t.Errorf("expected auto-gen addr event for %s on NIC(%d)", addr, nicID) - } - - // Enabling or disabling forwarding should not invalidate discovered prefixes - // or routers, or auto-generated address. - for _, forwarding := range [...]bool{true, false} { - t.Run(fmt.Sprintf("Transition forwarding to %t", forwarding), func(t *testing.T) { - if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, forwarding); err != nil { - t.Fatalf("SetForwardingDefaultAndAllNICs(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err) - } - select { - case e := <-ndpDisp.offLinkRouteC: - t.Errorf("unexpected off-link route event = %#v", e) - default: - } - select { - case e := <-ndpDisp.prefixC: - t.Errorf("unexpected prefix event = %#v", e) - default: - } - select { - case e := <-ndpDisp.autoGenAddrC: - t.Errorf("unexpected auto-gen addr event = %#v", e) - default: - } - }) - } -} - -func TestCleanupNDPState(t *testing.T) { - const ( - lifetimeSeconds = 5 - maxRouterAndPrefixEvents = 4 - nicID1 = 1 - nicID2 = 2 - ) - - prefix1, subnet1, e1Addr1 := prefixSubnetAddr(0, linkAddr1) - prefix2, subnet2, e1Addr2 := prefixSubnetAddr(1, linkAddr1) - e2Addr1 := addrForSubnet(subnet1, linkAddr2) - e2Addr2 := addrForSubnet(subnet2, linkAddr2) - llAddrWithPrefix1 := tcpip.AddressWithPrefix{ - Address: llAddr1, - PrefixLen: 64, - } - llAddrWithPrefix2 := tcpip.AddressWithPrefix{ - Address: llAddr2, - PrefixLen: 64, - } - - tests := []struct { - name string - cleanupFn func(t *testing.T, s *stack.Stack) - keepAutoGenLinkLocal bool - maxAutoGenAddrEvents int - skipFinalAddrCheck bool - }{ - // A NIC should cleanup all NDP state when it is disabled. - { - name: "Disable NIC", - cleanupFn: func(t *testing.T, s *stack.Stack) { - t.Helper() - - if err := s.DisableNIC(nicID1); err != nil { - t.Fatalf("s.DisableNIC(%d): %s", nicID1, err) - } - if err := s.DisableNIC(nicID2); err != nil { - t.Fatalf("s.DisableNIC(%d): %s", nicID2, err) - } - }, - keepAutoGenLinkLocal: false, - maxAutoGenAddrEvents: 6, - }, - - // A NIC should cleanup all NDP state when it is removed. - { - name: "Remove NIC", - cleanupFn: func(t *testing.T, s *stack.Stack) { - t.Helper() - - if err := s.RemoveNIC(nicID1); err != nil { - t.Fatalf("s.RemoveNIC(%d): %s", nicID1, err) - } - if err := s.RemoveNIC(nicID2); err != nil { - t.Fatalf("s.RemoveNIC(%d): %s", nicID2, err) - } - }, - keepAutoGenLinkLocal: false, - maxAutoGenAddrEvents: 6, - // The NICs are removed so we can't check their addresses after calling - // stopFn. - skipFinalAddrCheck: true, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - ndpDisp := ndpDispatcher{ - offLinkRouteC: make(chan ndpOffLinkRouteEvent, maxRouterAndPrefixEvents), - prefixC: make(chan ndpPrefixEvent, maxRouterAndPrefixEvents), - autoGenAddrC: make(chan ndpAutoGenAddrEvent, test.maxAutoGenAddrEvents), - } - clock := faketime.NewManualClock() - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - AutoGenLinkLocal: true, - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, - DiscoverDefaultRouters: true, - DiscoverOnLinkPrefixes: true, - AutoGenGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, - })}, - Clock: clock, - }) - - expectOffLinkRouteEvent := func() (bool, ndpOffLinkRouteEvent) { - select { - case e := <-ndpDisp.offLinkRouteC: - return true, e - default: - } - - return false, ndpOffLinkRouteEvent{} - } - - expectPrefixEvent := func() (bool, ndpPrefixEvent) { - select { - case e := <-ndpDisp.prefixC: - return true, e - default: - } - - return false, ndpPrefixEvent{} - } - - expectAutoGenAddrEvent := func() (bool, ndpAutoGenAddrEvent) { - select { - case e := <-ndpDisp.autoGenAddrC: - return true, e - default: - } - - return false, ndpAutoGenAddrEvent{} - } - - e1 := channel.New(0, 1280, linkAddr1) - if err := s.CreateNIC(nicID1, e1); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID1, err) - } - // We have other tests that make sure we receive the *correct* events - // on normal discovery of routers/prefixes, and auto-generated - // addresses. Here we just make sure we get an event and let other tests - // handle the correctness check. - expectAutoGenAddrEvent() - - e2 := channel.New(0, 1280, linkAddr2) - if err := s.CreateNIC(nicID2, e2); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID2, err) - } - expectAutoGenAddrEvent() - - // Receive RAs on NIC(1) and NIC(2) from default routers (llAddr3 and - // llAddr4) w/ PI (for prefix1 in RA from llAddr3 and prefix2 in RA from - // llAddr4) to discover multiple routers and prefixes, and auto-gen - // multiple addresses. - - e1.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, lifetimeSeconds, prefix1, true, true, lifetimeSeconds, lifetimeSeconds)) - if ok, _ := expectOffLinkRouteEvent(); !ok { - t.Errorf("expected off-link route event for %s on NIC(%d)", llAddr3, nicID1) - } - if ok, _ := expectPrefixEvent(); !ok { - t.Errorf("expected prefix event for %s on NIC(%d)", prefix1, nicID1) - } - if ok, _ := expectAutoGenAddrEvent(); !ok { - t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e1Addr1, nicID1) - } - - e1.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr4, lifetimeSeconds, prefix2, true, true, lifetimeSeconds, lifetimeSeconds)) - if ok, _ := expectOffLinkRouteEvent(); !ok { - t.Errorf("expected off-link route event for %s on NIC(%d)", llAddr4, nicID1) - } - if ok, _ := expectPrefixEvent(); !ok { - t.Errorf("expected prefix event for %s on NIC(%d)", prefix2, nicID1) - } - if ok, _ := expectAutoGenAddrEvent(); !ok { - t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e1Addr2, nicID1) - } - - e2.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, lifetimeSeconds, prefix1, true, true, lifetimeSeconds, lifetimeSeconds)) - if ok, _ := expectOffLinkRouteEvent(); !ok { - t.Errorf("expected off-link route event for %s on NIC(%d)", llAddr3, nicID2) - } - if ok, _ := expectPrefixEvent(); !ok { - t.Errorf("expected prefix event for %s on NIC(%d)", prefix1, nicID2) - } - if ok, _ := expectAutoGenAddrEvent(); !ok { - t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e1Addr2, nicID2) - } - - e2.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr4, lifetimeSeconds, prefix2, true, true, lifetimeSeconds, lifetimeSeconds)) - if ok, _ := expectOffLinkRouteEvent(); !ok { - t.Errorf("expected off-link route event for %s on NIC(%d)", llAddr4, nicID2) - } - if ok, _ := expectPrefixEvent(); !ok { - t.Errorf("expected prefix event for %s on NIC(%d)", prefix2, nicID2) - } - if ok, _ := expectAutoGenAddrEvent(); !ok { - t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e2Addr2, nicID2) - } - - // We should have the auto-generated addresses added. - nicinfo := s.NICInfo() - nic1Addrs := nicinfo[nicID1].ProtocolAddresses - nic2Addrs := nicinfo[nicID2].ProtocolAddresses - if !containsV6Addr(nic1Addrs, llAddrWithPrefix1) { - t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix1, nicID1, nic1Addrs) - } - if !containsV6Addr(nic1Addrs, e1Addr1) { - t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e1Addr1, nicID1, nic1Addrs) - } - if !containsV6Addr(nic1Addrs, e1Addr2) { - t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e1Addr2, nicID1, nic1Addrs) - } - if !containsV6Addr(nic2Addrs, llAddrWithPrefix2) { - t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix2, nicID2, nic2Addrs) - } - if !containsV6Addr(nic2Addrs, e2Addr1) { - t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e2Addr1, nicID2, nic2Addrs) - } - if !containsV6Addr(nic2Addrs, e2Addr2) { - t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e2Addr2, nicID2, nic2Addrs) - } - - // We can't proceed any further if we already failed the test (missing - // some discovery/auto-generated address events or addresses). - if t.Failed() { - t.FailNow() - } - - test.cleanupFn(t, s) - - // Collect invalidation events after having NDP state cleaned up. - gotOffLinkRouteEvents := make(map[ndpOffLinkRouteEvent]int) - for i := 0; i < maxRouterAndPrefixEvents; i++ { - ok, e := expectOffLinkRouteEvent() - if !ok { - t.Errorf("expected %d off-link route events after becoming a router; got = %d", maxRouterAndPrefixEvents, i) - break - } - gotOffLinkRouteEvents[e]++ - } - gotPrefixEvents := make(map[ndpPrefixEvent]int) - for i := 0; i < maxRouterAndPrefixEvents; i++ { - ok, e := expectPrefixEvent() - if !ok { - t.Errorf("expected %d prefix events after becoming a router; got = %d", maxRouterAndPrefixEvents, i) - break - } - gotPrefixEvents[e]++ - } - gotAutoGenAddrEvents := make(map[ndpAutoGenAddrEvent]int) - for i := 0; i < test.maxAutoGenAddrEvents; i++ { - ok, e := expectAutoGenAddrEvent() - if !ok { - t.Errorf("expected %d auto-generated address events after becoming a router; got = %d", test.maxAutoGenAddrEvents, i) - break - } - gotAutoGenAddrEvents[e]++ - } - - // No need to proceed any further if we already failed the test (missing - // some invalidation events). - if t.Failed() { - t.FailNow() - } - - expectedOffLinkRouteEvents := map[ndpOffLinkRouteEvent]int{ - {nicID: nicID1, subnet: header.IPv6EmptySubnet, router: llAddr3, updated: false}: 1, - {nicID: nicID1, subnet: header.IPv6EmptySubnet, router: llAddr4, updated: false}: 1, - {nicID: nicID2, subnet: header.IPv6EmptySubnet, router: llAddr3, updated: false}: 1, - {nicID: nicID2, subnet: header.IPv6EmptySubnet, router: llAddr4, updated: false}: 1, - } - if diff := cmp.Diff(expectedOffLinkRouteEvents, gotOffLinkRouteEvents); diff != "" { - t.Errorf("off-link route events mismatch (-want +got):\n%s", diff) - } - expectedPrefixEvents := map[ndpPrefixEvent]int{ - {nicID: nicID1, prefix: subnet1, discovered: false}: 1, - {nicID: nicID1, prefix: subnet2, discovered: false}: 1, - {nicID: nicID2, prefix: subnet1, discovered: false}: 1, - {nicID: nicID2, prefix: subnet2, discovered: false}: 1, - } - if diff := cmp.Diff(expectedPrefixEvents, gotPrefixEvents); diff != "" { - t.Errorf("prefix events mismatch (-want +got):\n%s", diff) - } - expectedAutoGenAddrEvents := map[ndpAutoGenAddrEvent]int{ - {nicID: nicID1, addr: e1Addr1, eventType: invalidatedAddr}: 1, - {nicID: nicID1, addr: e1Addr2, eventType: invalidatedAddr}: 1, - {nicID: nicID2, addr: e2Addr1, eventType: invalidatedAddr}: 1, - {nicID: nicID2, addr: e2Addr2, eventType: invalidatedAddr}: 1, - } - - if !test.keepAutoGenLinkLocal { - expectedAutoGenAddrEvents[ndpAutoGenAddrEvent{nicID: nicID1, addr: llAddrWithPrefix1, eventType: invalidatedAddr}] = 1 - expectedAutoGenAddrEvents[ndpAutoGenAddrEvent{nicID: nicID2, addr: llAddrWithPrefix2, eventType: invalidatedAddr}] = 1 - } - - if diff := cmp.Diff(expectedAutoGenAddrEvents, gotAutoGenAddrEvents); diff != "" { - t.Errorf("auto-generated address events mismatch (-want +got):\n%s", diff) - } - - if !test.skipFinalAddrCheck { - // Make sure the auto-generated addresses got removed. - nicinfo = s.NICInfo() - nic1Addrs = nicinfo[nicID1].ProtocolAddresses - nic2Addrs = nicinfo[nicID2].ProtocolAddresses - if containsV6Addr(nic1Addrs, llAddrWithPrefix1) != test.keepAutoGenLinkLocal { - if test.keepAutoGenLinkLocal { - t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix1, nicID1, nic1Addrs) - } else { - t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", llAddrWithPrefix1, nicID1, nic1Addrs) - } - } - if containsV6Addr(nic1Addrs, e1Addr1) { - t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e1Addr1, nicID1, nic1Addrs) - } - if containsV6Addr(nic1Addrs, e1Addr2) { - t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e1Addr2, nicID1, nic1Addrs) - } - if containsV6Addr(nic2Addrs, llAddrWithPrefix2) != test.keepAutoGenLinkLocal { - if test.keepAutoGenLinkLocal { - t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix2, nicID2, nic2Addrs) - } else { - t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", llAddrWithPrefix2, nicID2, nic2Addrs) - } - } - if containsV6Addr(nic2Addrs, e2Addr1) { - t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e2Addr1, nicID2, nic2Addrs) - } - if containsV6Addr(nic2Addrs, e2Addr2) { - t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e2Addr2, nicID2, nic2Addrs) - } - } - - // Should not get any more events (invalidation timers should have been - // cancelled when the NDP state was cleaned up). - clock.Advance(lifetimeSeconds * time.Second) - select { - case <-ndpDisp.offLinkRouteC: - t.Error("unexpected off-link route event") - default: - } - select { - case <-ndpDisp.prefixC: - t.Error("unexpected prefix event") - default: - } - select { - case <-ndpDisp.autoGenAddrC: - t.Error("unexpected auto-generated address event") - default: - } - }) - } -} - -// TestDHCPv6ConfigurationFromNDPDA tests that the NDPDispatcher is properly -// informed when new information about what configurations are available via -// DHCPv6 is learned. -func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) { - const nicID = 1 - - ndpDisp := ndpDispatcher{ - dhcpv6ConfigurationC: make(chan ndpDHCPv6Event, 1), - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, - }, - NDPDisp: &ndpDisp, - })}, - }) - - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - - expectDHCPv6Event := func(configuration ipv6.DHCPv6ConfigurationFromNDPRA) { - t.Helper() - select { - case e := <-ndpDisp.dhcpv6ConfigurationC: - if diff := cmp.Diff(ndpDHCPv6Event{nicID: nicID, configuration: configuration}, e, cmp.AllowUnexported(e)); diff != "" { - t.Errorf("dhcpv6 event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected DHCPv6 configuration event") - } - } - - expectNoDHCPv6Event := func() { - t.Helper() - select { - case <-ndpDisp.dhcpv6ConfigurationC: - t.Fatal("unexpected DHCPv6 configuration event") - default: - } - } - - // Even if the first RA reports no DHCPv6 configurations are available, the - // dispatcher should get an event. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, false)) - expectDHCPv6Event(ipv6.DHCPv6NoConfiguration) - // Receiving the same update again should not result in an event to the - // dispatcher. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, false)) - expectNoDHCPv6Event() - - // Receive an RA that updates the DHCPv6 configuration to Other - // Configurations. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true)) - expectDHCPv6Event(ipv6.DHCPv6OtherConfigurations) - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true)) - expectNoDHCPv6Event() - - // Receive an RA that updates the DHCPv6 configuration to Managed Address. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, false)) - expectDHCPv6Event(ipv6.DHCPv6ManagedAddress) - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, false)) - expectNoDHCPv6Event() - - // Receive an RA that updates the DHCPv6 configuration to none. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, false)) - expectDHCPv6Event(ipv6.DHCPv6NoConfiguration) - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, false)) - expectNoDHCPv6Event() - - // Receive an RA that updates the DHCPv6 configuration to Managed Address. - // - // Note, when the M flag is set, the O flag is redundant. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, true)) - expectDHCPv6Event(ipv6.DHCPv6ManagedAddress) - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, true)) - expectNoDHCPv6Event() - // Even though the DHCPv6 flags are different, the effective configuration is - // the same so we should not receive a new event. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, false)) - expectNoDHCPv6Event() - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, true)) - expectNoDHCPv6Event() - - // Receive an RA that updates the DHCPv6 configuration to Other - // Configurations. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true)) - expectDHCPv6Event(ipv6.DHCPv6OtherConfigurations) - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true)) - expectNoDHCPv6Event() - - // Cycling the NIC should cause the last DHCPv6 configuration to be cleared. - if err := s.DisableNIC(nicID); err != nil { - t.Fatalf("s.DisableNIC(%d): %s", nicID, err) - } - if err := s.EnableNIC(nicID); err != nil { - t.Fatalf("s.EnableNIC(%d): %s", nicID, err) - } - - // Receive an RA that updates the DHCPv6 configuration to Other - // Configurations. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true)) - expectDHCPv6Event(ipv6.DHCPv6OtherConfigurations) - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true)) - expectNoDHCPv6Event() -} - -var _ rand.Source = (*savingRandSource)(nil) - -type savingRandSource struct { - s rand.Source - - lastInt63 int64 -} - -func (d *savingRandSource) Int63() int64 { - i := d.s.Int63() - d.lastInt63 = i - return i -} -func (d *savingRandSource) Seed(seed int64) { - d.s.Seed(seed) -} - -// TestRouterSolicitation tests the initial Router Solicitations that are sent -// when a NIC newly becomes enabled. -func TestRouterSolicitation(t *testing.T) { - const nicID = 1 - - tests := []struct { - name string - linkHeaderLen uint16 - linkAddr tcpip.LinkAddress - nicAddr tcpip.Address - expectedSrcAddr tcpip.Address - expectedNDPOpts []header.NDPOption - maxRtrSolicit uint8 - rtrSolicitInt time.Duration - effectiveRtrSolicitInt time.Duration - maxRtrSolicitDelay time.Duration - effectiveMaxRtrSolicitDelay time.Duration - }{ - { - name: "Single RS with 2s delay and interval", - expectedSrcAddr: header.IPv6Any, - maxRtrSolicit: 1, - rtrSolicitInt: 2 * time.Second, - effectiveRtrSolicitInt: 2 * time.Second, - maxRtrSolicitDelay: 2 * time.Second, - effectiveMaxRtrSolicitDelay: 2 * time.Second, - }, - { - name: "Single RS with 4s delay and interval", - expectedSrcAddr: header.IPv6Any, - maxRtrSolicit: 1, - rtrSolicitInt: 4 * time.Second, - effectiveRtrSolicitInt: 4 * time.Second, - maxRtrSolicitDelay: 4 * time.Second, - effectiveMaxRtrSolicitDelay: 4 * time.Second, - }, - { - name: "Two RS with delay", - linkHeaderLen: 1, - nicAddr: llAddr1, - expectedSrcAddr: llAddr1, - maxRtrSolicit: 2, - rtrSolicitInt: 2 * time.Second, - effectiveRtrSolicitInt: 2 * time.Second, - maxRtrSolicitDelay: 500 * time.Millisecond, - effectiveMaxRtrSolicitDelay: 500 * time.Millisecond, - }, - { - name: "Single RS without delay", - linkHeaderLen: 2, - linkAddr: linkAddr1, - nicAddr: llAddr1, - expectedSrcAddr: llAddr1, - expectedNDPOpts: []header.NDPOption{ - header.NDPSourceLinkLayerAddressOption(linkAddr1), - }, - maxRtrSolicit: 1, - rtrSolicitInt: 2 * time.Second, - effectiveRtrSolicitInt: 2 * time.Second, - maxRtrSolicitDelay: 0, - effectiveMaxRtrSolicitDelay: 0, - }, - { - name: "Two RS without delay and invalid zero interval", - linkHeaderLen: 3, - linkAddr: linkAddr1, - expectedSrcAddr: header.IPv6Any, - maxRtrSolicit: 2, - rtrSolicitInt: 0, - effectiveRtrSolicitInt: 4 * time.Second, - maxRtrSolicitDelay: 0, - effectiveMaxRtrSolicitDelay: 0, - }, - { - name: "Three RS without delay", - linkAddr: linkAddr1, - expectedSrcAddr: header.IPv6Any, - maxRtrSolicit: 3, - rtrSolicitInt: 500 * time.Millisecond, - effectiveRtrSolicitInt: 500 * time.Millisecond, - maxRtrSolicitDelay: 0, - effectiveMaxRtrSolicitDelay: 0, - }, - { - name: "Two RS with invalid negative delay", - linkAddr: linkAddr1, - expectedSrcAddr: header.IPv6Any, - maxRtrSolicit: 2, - rtrSolicitInt: time.Second, - effectiveRtrSolicitInt: time.Second, - maxRtrSolicitDelay: -3 * time.Second, - effectiveMaxRtrSolicitDelay: time.Second, - }, - } - - subTests := []struct { - name string - handleRAs ipv6.HandleRAsConfiguration - afterFirstRS func(*testing.T, *stack.Stack) - }{ - { - name: "Handle RAs when forwarding disabled", - handleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, - afterFirstRS: func(*testing.T, *stack.Stack) {}, - }, - - // Enabling forwarding when RAs are always configured to be handled - // should not stop router solicitations. - { - name: "Handle RAs always", - handleRAs: ipv6.HandlingRAsAlwaysEnabled, - afterFirstRS: func(t *testing.T, s *stack.Stack) { - if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, true); err != nil { - t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", ipv6.ProtocolNumber, err) - } - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - for _, subTest := range subTests { - t.Run(subTest.name, func(t *testing.T) { - clock := faketime.NewManualClock() - e := channelLinkWithHeaderLength{ - Endpoint: channel.New(int(test.maxRtrSolicit), 1280, test.linkAddr), - headerLength: test.linkHeaderLen, - } - e.Endpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired - waitForPkt := func(timeout time.Duration) { - t.Helper() - - clock.Advance(timeout) - p, ok := e.Read() - if !ok { - t.Fatal("expected router solicitation packet") - } - - if p.Proto != header.IPv6ProtocolNumber { - t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber) - } - - // Make sure the right remote link address is used. - if want := header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersLinkLocalMulticastAddress); p.Route.RemoteLinkAddress != want { - t.Errorf("got remote link address = %s, want = %s", p.Route.RemoteLinkAddress, want) - } - - checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), - checker.SrcAddr(test.expectedSrcAddr), - checker.DstAddr(header.IPv6AllRoutersLinkLocalMulticastAddress), - checker.TTL(header.NDPHopLimit), - checker.NDPRS(checker.NDPRSOptions(test.expectedNDPOpts)), - ) - - if l, want := p.Pkt.AvailableHeaderBytes(), int(test.linkHeaderLen); l != want { - t.Errorf("got p.Pkt.AvailableHeaderBytes() = %d; want = %d", l, want) - } - } - waitForNothing := func(timeout time.Duration) { - t.Helper() - - clock.Advance(timeout) - if p, ok := e.Read(); ok { - t.Fatalf("unexpectedly got a packet = %#v", p) - } - } - randSource := savingRandSource{ - s: rand.NewSource(time.Now().UnixNano()), - } - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: subTest.handleRAs, - MaxRtrSolicitations: test.maxRtrSolicit, - RtrSolicitationInterval: test.rtrSolicitInt, - MaxRtrSolicitationDelay: test.maxRtrSolicitDelay, - }, - })}, - Clock: clock, - RandSource: &randSource, - }) - - opts := stack.NICOptions{Disabled: true} - if err := s.CreateNICWithOptions(nicID, &e, opts); err != nil { - t.Fatalf("CreateNICWithOptions(%d, _, %#v) = %s", nicID, opts, err) - } - - if addr := test.nicAddr; addr != "" { - if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr, err) - } - } - - if err := s.EnableNIC(nicID); err != nil { - t.Fatalf("EnableNIC(%d): %s", nicID, err) - } - - // Make sure each RS is sent at the right time. - remaining := test.maxRtrSolicit - if remaining != 0 { - maxRtrSolicitDelay := test.maxRtrSolicitDelay - if maxRtrSolicitDelay < 0 { - maxRtrSolicitDelay = ipv6.DefaultNDPConfigurations().MaxRtrSolicitationDelay - } - var actualRtrSolicitDelay time.Duration - if maxRtrSolicitDelay != 0 { - actualRtrSolicitDelay = time.Duration(randSource.lastInt63) % maxRtrSolicitDelay - } - waitForPkt(actualRtrSolicitDelay) - remaining-- - } - - subTest.afterFirstRS(t, s) - - for ; remaining != 0; remaining-- { - if test.effectiveRtrSolicitInt != 0 { - waitForNothing(test.effectiveRtrSolicitInt - time.Nanosecond) - waitForPkt(time.Nanosecond) - } else { - waitForPkt(0) - } - } - - // Make sure no more RS. - if test.effectiveRtrSolicitInt > test.effectiveMaxRtrSolicitDelay { - waitForNothing(test.effectiveRtrSolicitInt) - } else { - waitForNothing(test.effectiveMaxRtrSolicitDelay) - } - - if got, want := s.Stats().ICMP.V6.PacketsSent.RouterSolicit.Value(), uint64(test.maxRtrSolicit); got != want { - t.Fatalf("got sent RouterSolicit = %d, want = %d", got, want) - } - }) - } - }) - } -} - -func TestStopStartSolicitingRouters(t *testing.T) { - const nicID = 1 - const delay = 0 - const interval = 500 * time.Millisecond - const maxRtrSolicitations = 3 - - tests := []struct { - name string - startFn func(t *testing.T, s *stack.Stack) - // first is used to tell stopFn that it is being called for the first time - // after router solicitations were last enabled. - stopFn func(t *testing.T, s *stack.Stack, first bool) - }{ - // Tests that when forwarding is enabled or disabled, router solicitations - // are stopped or started, respectively. - { - name: "Enable and disable forwarding", - startFn: func(t *testing.T, s *stack.Stack) { - t.Helper() - - if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, false); err != nil { - t.Fatalf("SetForwardingDefaultAndAllNICs(%d, false): %s", ipv6.ProtocolNumber, err) - } - }, - stopFn: func(t *testing.T, s *stack.Stack, _ bool) { - t.Helper() - - if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, true); err != nil { - t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", ipv6.ProtocolNumber, err) - } - }, - }, - - // Tests that when a NIC is enabled or disabled, router solicitations - // are started or stopped, respectively. - { - name: "Enable and disable NIC", - startFn: func(t *testing.T, s *stack.Stack) { - t.Helper() - - if err := s.EnableNIC(nicID); err != nil { - t.Fatalf("s.EnableNIC(%d): %s", nicID, err) - } - }, - stopFn: func(t *testing.T, s *stack.Stack, _ bool) { - t.Helper() - - if err := s.DisableNIC(nicID); err != nil { - t.Fatalf("s.DisableNIC(%d): %s", nicID, err) - } - }, - }, - - // Tests that when a NIC is removed, router solicitations are stopped. We - // cannot start router solications on a removed NIC. - { - name: "Remove NIC", - stopFn: func(t *testing.T, s *stack.Stack, first bool) { - t.Helper() - - // Only try to remove the NIC the first time stopFn is called since it's - // impossible to remove an already removed NIC. - if !first { - return - } - - if err := s.RemoveNIC(nicID); err != nil { - t.Fatalf("s.RemoveNIC(%d): %s", nicID, err) - } - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - e := channel.New(maxRtrSolicitations, 1280, linkAddr1) - waitForPkt := func(clock *faketime.ManualClock, timeout time.Duration) { - t.Helper() - - clock.Advance(timeout) - p, ok := e.Read() - if !ok { - t.Fatal("timed out waiting for packet") - } - - if p.Proto != header.IPv6ProtocolNumber { - t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber) - } - checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), - checker.SrcAddr(header.IPv6Any), - checker.DstAddr(header.IPv6AllRoutersLinkLocalMulticastAddress), - checker.TTL(header.NDPHopLimit), - checker.NDPRS()) - } - clock := faketime.NewManualClock() - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, - MaxRtrSolicitations: maxRtrSolicitations, - RtrSolicitationInterval: interval, - MaxRtrSolicitationDelay: delay, - }, - })}, - Clock: clock, - }) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - - // Stop soliciting routers. - test.stopFn(t, s, true /* first */) - clock.Advance(delay) - if _, ok := e.Read(); ok { - // A single RS may have been sent before solicitations were stopped. - clock.Advance(interval) - if _, ok = e.Read(); ok { - t.Fatal("should not have sent more than one RS message") - } - } - - // Stopping router solicitations after it has already been stopped should - // do nothing. - test.stopFn(t, s, false /* first */) - clock.Advance(delay) - if _, ok := e.Read(); ok { - t.Fatal("unexpectedly got a packet after router solicitation has been stopepd") - } - - // If test.startFn is nil, there is no way to restart router solications. - if test.startFn == nil { - return - } - - // Start soliciting routers. - test.startFn(t, s) - waitForPkt(clock, delay) - waitForPkt(clock, interval) - waitForPkt(clock, interval) - clock.Advance(interval) - if _, ok := e.Read(); ok { - t.Fatal("unexpectedly got an extra packet after sending out the expected RSs") - } - - // Starting router solicitations after it has already completed should do - // nothing. - test.startFn(t, s) - clock.Advance(interval) - if _, ok := e.Read(); ok { - t.Fatal("unexpectedly got a packet after finishing router solicitations") - } - }) - } -} diff --git a/pkg/tcpip/stack/neighbor_cache_test.go b/pkg/tcpip/stack/neighbor_cache_test.go deleted file mode 100644 index 7de25fe37..000000000 --- a/pkg/tcpip/stack/neighbor_cache_test.go +++ /dev/null @@ -1,1584 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package stack - -import ( - "fmt" - "math" - "math/rand" - "strings" - "sync" - "sync/atomic" - "testing" - "time" - - "github.com/google/go-cmp/cmp" - "github.com/google/go-cmp/cmp/cmpopts" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/faketime" -) - -const ( - // entryStoreSize is the default number of entries that will be generated and - // added to the entry store. This number needs to be larger than the size of - // the neighbor cache to give ample opportunity for verifying behavior during - // cache overflows. Four times the size of the neighbor cache allows for - // three complete cache overflows. - entryStoreSize = 4 * neighborCacheSize - - // typicalLatency is the typical latency for an ARP or NDP packet to travel - // to a router and back. - typicalLatency = time.Millisecond - - // testEntryBroadcastAddr is a special address that indicates a packet should - // be sent to all nodes. - testEntryBroadcastAddr = tcpip.Address("broadcast") - - // testEntryBroadcastLinkAddr is a special link address sent back to - // multicast neighbor probes. - testEntryBroadcastLinkAddr = tcpip.LinkAddress("mac_broadcast") - - // infiniteDuration indicates that a task will not occur in our lifetime. - infiniteDuration = time.Duration(math.MaxInt64) -) - -// unorderedEventsDiffOpts returns options passed to cmp.Diff to sort slices of -// events for cases where ordering must be ignored. -func unorderedEventsDiffOpts() []cmp.Option { - return []cmp.Option{ - cmpopts.SortSlices(func(a, b testEntryEventInfo) bool { - return strings.Compare(string(a.Entry.Addr), string(b.Entry.Addr)) < 0 - }), - } -} - -// unorderedEntriesDiffOpts returns options passed to cmp.Diff to sort slices of -// entries for cases where ordering must be ignored. -func unorderedEntriesDiffOpts() []cmp.Option { - return []cmp.Option{ - cmpopts.SortSlices(func(a, b NeighborEntry) bool { - return strings.Compare(string(a.Addr), string(b.Addr)) < 0 - }), - } -} - -func newTestNeighborResolver(nudDisp NUDDispatcher, config NUDConfigurations, clock tcpip.Clock) *testNeighborResolver { - config.resetInvalidFields() - rng := rand.New(rand.NewSource(time.Now().UnixNano())) - linkRes := &testNeighborResolver{ - clock: clock, - entries: newTestEntryStore(), - delay: typicalLatency, - } - linkRes.neigh.init(&nic{ - stack: &Stack{ - clock: clock, - nudDisp: nudDisp, - nudConfigs: config, - randomGenerator: rng, - }, - id: 1, - stats: makeNICStats(tcpip.NICStats{}.FillIn()), - }, linkRes) - return linkRes -} - -// testEntryStore contains a set of IP to NeighborEntry mappings. -type testEntryStore struct { - mu sync.RWMutex - entriesMap map[tcpip.Address]NeighborEntry -} - -func toAddress(i uint16) tcpip.Address { - return tcpip.Address([]byte{ - 1, - 0, - byte(i >> 8), - byte(i), - }) -} - -func toLinkAddress(i uint16) tcpip.LinkAddress { - return tcpip.LinkAddress([]byte{ - 1, - 0, - 0, - 0, - byte(i >> 8), - byte(i), - }) -} - -// newTestEntryStore returns a testEntryStore pre-populated with entries. -func newTestEntryStore() *testEntryStore { - store := &testEntryStore{ - entriesMap: make(map[tcpip.Address]NeighborEntry), - } - for i := uint16(0); i < entryStoreSize; i++ { - addr := toAddress(i) - linkAddr := toLinkAddress(i) - - store.entriesMap[addr] = NeighborEntry{ - Addr: addr, - LinkAddr: linkAddr, - } - } - return store -} - -// size returns the number of entries in the store. -func (s *testEntryStore) size() uint16 { - s.mu.RLock() - defer s.mu.RUnlock() - return uint16(len(s.entriesMap)) -} - -// entry returns the entry at index i. Returns an empty entry and false if i is -// out of bounds. -func (s *testEntryStore) entry(i uint16) (NeighborEntry, bool) { - return s.entryByAddr(toAddress(i)) -} - -// entryByAddr returns the entry matching addr for situations when the index is -// not available. Returns an empty entry and false if no entries match addr. -func (s *testEntryStore) entryByAddr(addr tcpip.Address) (NeighborEntry, bool) { - s.mu.RLock() - defer s.mu.RUnlock() - entry, ok := s.entriesMap[addr] - return entry, ok -} - -// entries returns all entries in the store. -func (s *testEntryStore) entries() []NeighborEntry { - entries := make([]NeighborEntry, 0, len(s.entriesMap)) - s.mu.RLock() - defer s.mu.RUnlock() - for i := uint16(0); i < entryStoreSize; i++ { - addr := toAddress(i) - if entry, ok := s.entriesMap[addr]; ok { - entries = append(entries, entry) - } - } - return entries -} - -// set modifies the link addresses of an entry. -func (s *testEntryStore) set(i uint16, linkAddr tcpip.LinkAddress) { - addr := toAddress(i) - s.mu.Lock() - defer s.mu.Unlock() - if entry, ok := s.entriesMap[addr]; ok { - entry.LinkAddr = linkAddr - s.entriesMap[addr] = entry - } -} - -// testNeighborResolver implements LinkAddressResolver to emulate sending a -// neighbor probe. -type testNeighborResolver struct { - clock tcpip.Clock - neigh neighborCache - entries *testEntryStore - delay time.Duration - onLinkAddressRequest func() - dropReplies bool -} - -var _ LinkAddressResolver = (*testNeighborResolver)(nil) - -func (r *testNeighborResolver) LinkAddressRequest(targetAddr, _ tcpip.Address, _ tcpip.LinkAddress) tcpip.Error { - if !r.dropReplies { - // Delay handling the request to emulate network latency. - r.clock.AfterFunc(r.delay, func() { - r.fakeRequest(targetAddr) - }) - } - - // Execute post address resolution action, if available. - if f := r.onLinkAddressRequest; f != nil { - f() - } - return nil -} - -// fakeRequest emulates handling a response for a link address request. -func (r *testNeighborResolver) fakeRequest(addr tcpip.Address) { - if entry, ok := r.entries.entryByAddr(addr); ok { - r.neigh.handleConfirmation(addr, entry.LinkAddr, ReachabilityConfirmationFlags{ - Solicited: true, - Override: false, - IsRouter: false, - }) - } -} - -func (*testNeighborResolver) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) { - if addr == testEntryBroadcastAddr { - return testEntryBroadcastLinkAddr, true - } - return "", false -} - -func (*testNeighborResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumber { - return 0 -} - -func TestNeighborCacheGetConfig(t *testing.T) { - nudDisp := testNUDDispatcher{} - c := DefaultNUDConfigurations() - clock := faketime.NewManualClock() - linkRes := newTestNeighborResolver(&nudDisp, c, clock) - - if got, want := linkRes.neigh.config(), c; got != want { - t.Errorf("got linkRes.neigh.config() = %+v, want = %+v", got, want) - } - - // No events should have been dispatched. - nudDisp.mu.Lock() - defer nudDisp.mu.Unlock() - if diff := cmp.Diff([]testEntryEventInfo(nil), nudDisp.mu.events); diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } -} - -func TestNeighborCacheSetConfig(t *testing.T) { - nudDisp := testNUDDispatcher{} - c := DefaultNUDConfigurations() - clock := faketime.NewManualClock() - linkRes := newTestNeighborResolver(&nudDisp, c, clock) - - c.MinRandomFactor = 1 - c.MaxRandomFactor = 1 - linkRes.neigh.setConfig(c) - - if got, want := linkRes.neigh.config(), c; got != want { - t.Errorf("got linkRes.neigh.config() = %+v, want = %+v", got, want) - } - - // No events should have been dispatched. - nudDisp.mu.Lock() - defer nudDisp.mu.Unlock() - if diff := cmp.Diff([]testEntryEventInfo(nil), nudDisp.mu.events); diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } -} - -func addReachableEntryWithRemoved(nudDisp *testNUDDispatcher, clock *faketime.ManualClock, linkRes *testNeighborResolver, entry NeighborEntry, removed []NeighborEntry) error { - var gotLinkResolutionResult LinkResolutionResult - - _, ch, err := linkRes.neigh.entry(entry.Addr, "", func(r LinkResolutionResult) { - gotLinkResolutionResult = r - }) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - return fmt.Errorf("got linkRes.neigh.entry(%s, '', _) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) - } - - { - var wantEvents []testEntryEventInfo - - for _, removedEntry := range removed { - wantEvents = append(wantEvents, testEntryEventInfo{ - EventType: entryTestRemoved, - NICID: 1, - Entry: NeighborEntry{ - Addr: removedEntry.Addr, - LinkAddr: removedEntry.LinkAddr, - State: Reachable, - UpdatedAt: clock.Now(), - }, - }) - } - - wantEvents = append(wantEvents, testEntryEventInfo{ - EventType: entryTestAdded, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: "", - State: Incomplete, - UpdatedAt: clock.Now(), - }, - }) - - nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, nudDisp.mu.events) - nudDisp.mu.events = nil - nudDisp.mu.Unlock() - if diff != "" { - return fmt.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - } - - clock.Advance(typicalLatency) - - select { - case <-ch: - default: - return fmt.Errorf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _)", entry.Addr) - } - wantLinkResolutionResult := LinkResolutionResult{LinkAddress: entry.LinkAddr, Err: nil} - if diff := cmp.Diff(wantLinkResolutionResult, gotLinkResolutionResult); diff != "" { - return fmt.Errorf("got link resolution result mismatch (-want +got):\n%s", diff) - } - - { - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestChanged, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - UpdatedAt: clock.Now(), - }, - }, - } - nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, nudDisp.mu.events) - nudDisp.mu.events = nil - nudDisp.mu.Unlock() - if diff != "" { - return fmt.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - } - - return nil -} - -func addReachableEntry(nudDisp *testNUDDispatcher, clock *faketime.ManualClock, linkRes *testNeighborResolver, entry NeighborEntry) error { - return addReachableEntryWithRemoved(nudDisp, clock, linkRes, entry, nil /* removed */) -} - -func TestNeighborCacheEntry(t *testing.T) { - c := DefaultNUDConfigurations() - nudDisp := testNUDDispatcher{} - clock := faketime.NewManualClock() - linkRes := newTestNeighborResolver(&nudDisp, c, clock) - - entry, ok := linkRes.entries.entry(0) - if !ok { - t.Fatal("got linkRes.entries.entry(0) = _, false, want = true ") - } - if err := addReachableEntry(&nudDisp, clock, linkRes, entry); err != nil { - t.Fatalf("addReachableEntry(...) = %s", err) - } - - if _, _, err := linkRes.neigh.entry(entry.Addr, "", nil); err != nil { - t.Fatalf("unexpected error from linkRes.neigh.entry(%s, '', nil): %s", entry.Addr, err) - } - - // No more events should have been dispatched. - nudDisp.mu.Lock() - defer nudDisp.mu.Unlock() - if diff := cmp.Diff([]testEntryEventInfo(nil), nudDisp.mu.events); diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } -} - -func TestNeighborCacheRemoveEntry(t *testing.T) { - config := DefaultNUDConfigurations() - - nudDisp := testNUDDispatcher{} - clock := faketime.NewManualClock() - linkRes := newTestNeighborResolver(&nudDisp, config, clock) - - entry, ok := linkRes.entries.entry(0) - if !ok { - t.Fatal("got linkRes.entries.entry(0) = _, false, want = true ") - } - if err := addReachableEntry(&nudDisp, clock, linkRes, entry); err != nil { - t.Fatalf("addReachableEntry(...) = %s", err) - } - - linkRes.neigh.removeEntry(entry.Addr) - - { - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestRemoved, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - UpdatedAt: clock.Now(), - }, - }, - } - nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, nudDisp.mu.events) - nudDisp.mu.Unlock() - if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - } - - { - _, _, err := linkRes.neigh.entry(entry.Addr, "", nil) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Errorf("got linkRes.neigh.entry(%s, '', nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) - } - } -} - -type testContext struct { - clock *faketime.ManualClock - linkRes *testNeighborResolver - nudDisp *testNUDDispatcher -} - -func newTestContext(c NUDConfigurations) testContext { - nudDisp := &testNUDDispatcher{} - clock := faketime.NewManualClock() - linkRes := newTestNeighborResolver(nudDisp, c, clock) - - return testContext{ - clock: clock, - linkRes: linkRes, - nudDisp: nudDisp, - } -} - -type overflowOptions struct { - startAtEntryIndex uint16 - wantStaticEntries []NeighborEntry -} - -func (c *testContext) overflowCache(opts overflowOptions) error { - // Fill the neighbor cache to capacity to verify the LRU eviction strategy is - // working properly after the entry removal. - for i := opts.startAtEntryIndex; i < c.linkRes.entries.size(); i++ { - var removedEntries []NeighborEntry - - // When beyond the full capacity, the cache will evict an entry as per the - // LRU eviction strategy. Note that the number of static entries should not - // affect the total number of dynamic entries that can be added. - if i >= neighborCacheSize+opts.startAtEntryIndex { - removedEntry, ok := c.linkRes.entries.entry(i - neighborCacheSize) - if !ok { - return fmt.Errorf("got linkRes.entries.entry(%d) = _, false, want = true", i-neighborCacheSize) - } - removedEntries = append(removedEntries, removedEntry) - } - - entry, ok := c.linkRes.entries.entry(i) - if !ok { - return fmt.Errorf("got c.linkRes.entries.entry(%d) = _, false, want = true", i) - } - if err := addReachableEntryWithRemoved(c.nudDisp, c.clock, c.linkRes, entry, removedEntries); err != nil { - return fmt.Errorf("addReachableEntryWithRemoved(...) = %s", err) - } - } - - // Expect to find only the most recent entries. The order of entries reported - // by entries() is nondeterministic, so entries have to be sorted before - // comparison. - wantUnorderedEntries := opts.wantStaticEntries - for i := c.linkRes.entries.size() - neighborCacheSize; i < c.linkRes.entries.size(); i++ { - entry, ok := c.linkRes.entries.entry(i) - if !ok { - return fmt.Errorf("got c.linkRes.entries.entry(%d) = _, false, want = true", i) - } - durationReachableNanos := time.Duration(c.linkRes.entries.size()-i-1) * typicalLatency - wantEntry := NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - UpdatedAt: c.clock.Now().Add(-durationReachableNanos), - } - wantUnorderedEntries = append(wantUnorderedEntries, wantEntry) - } - - if diff := cmp.Diff(wantUnorderedEntries, c.linkRes.neigh.entries(), unorderedEntriesDiffOpts()...); diff != "" { - return fmt.Errorf("neighbor entries mismatch (-want, +got):\n%s", diff) - } - - // No more events should have been dispatched. - c.nudDisp.mu.Lock() - defer c.nudDisp.mu.Unlock() - if diff := cmp.Diff([]testEntryEventInfo(nil), c.nudDisp.mu.events); diff != "" { - return fmt.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - - return nil -} - -// TestNeighborCacheOverflow verifies that the LRU cache eviction strategy -// respects the dynamic entry count. -func TestNeighborCacheOverflow(t *testing.T) { - config := DefaultNUDConfigurations() - // Stay in Reachable so the cache can overflow - config.BaseReachableTime = infiniteDuration - config.MinRandomFactor = 1 - config.MaxRandomFactor = 1 - - c := newTestContext(config) - opts := overflowOptions{ - startAtEntryIndex: 0, - } - if err := c.overflowCache(opts); err != nil { - t.Errorf("c.overflowCache(%+v): %s", opts, err) - } -} - -// TestNeighborCacheRemoveEntryThenOverflow verifies that the LRU cache -// eviction strategy respects the dynamic entry count when an entry is removed. -func TestNeighborCacheRemoveEntryThenOverflow(t *testing.T) { - config := DefaultNUDConfigurations() - // Stay in Reachable so the cache can overflow - config.BaseReachableTime = infiniteDuration - config.MinRandomFactor = 1 - config.MaxRandomFactor = 1 - - c := newTestContext(config) - - // Add a dynamic entry - entry, ok := c.linkRes.entries.entry(0) - if !ok { - t.Fatal("got c.linkRes.entries.entry(0) = _, false, want = true ") - } - if err := addReachableEntry(c.nudDisp, c.clock, c.linkRes, entry); err != nil { - t.Fatalf("addReachableEntry(...) = %s", err) - } - - // Remove the entry - c.linkRes.neigh.removeEntry(entry.Addr) - - { - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestRemoved, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - UpdatedAt: c.clock.Now(), - }, - }, - } - c.nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, c.nudDisp.mu.events) - c.nudDisp.mu.events = nil - c.nudDisp.mu.Unlock() - if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - } - - opts := overflowOptions{ - startAtEntryIndex: 0, - } - if err := c.overflowCache(opts); err != nil { - t.Errorf("c.overflowCache(%+v): %s", opts, err) - } -} - -// TestNeighborCacheDuplicateStaticEntryWithSameLinkAddress verifies that -// adding a duplicate static entry with the same link address does not dispatch -// any events. -func TestNeighborCacheDuplicateStaticEntryWithSameLinkAddress(t *testing.T) { - config := DefaultNUDConfigurations() - c := newTestContext(config) - - // Add a static entry - entry, ok := c.linkRes.entries.entry(0) - if !ok { - t.Fatal("got c.linkRes.entries.entry(0) = _, false, want = true ") - } - staticLinkAddr := entry.LinkAddr + "static" - c.linkRes.neigh.addStaticEntry(entry.Addr, staticLinkAddr) - - { - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestAdded, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, - UpdatedAt: c.clock.Now(), - }, - }, - } - c.nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, c.nudDisp.mu.events) - c.nudDisp.mu.events = nil - c.nudDisp.mu.Unlock() - if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - } - - // Add a duplicate static entry with the same link address. - c.linkRes.neigh.addStaticEntry(entry.Addr, staticLinkAddr) - - c.nudDisp.mu.Lock() - defer c.nudDisp.mu.Unlock() - if diff := cmp.Diff([]testEntryEventInfo(nil), c.nudDisp.mu.events); diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } -} - -// TestNeighborCacheDuplicateStaticEntryWithDifferentLinkAddress verifies that -// adding a duplicate static entry with a different link address dispatches a -// change event. -func TestNeighborCacheDuplicateStaticEntryWithDifferentLinkAddress(t *testing.T) { - config := DefaultNUDConfigurations() - c := newTestContext(config) - - // Add a static entry - entry, ok := c.linkRes.entries.entry(0) - if !ok { - t.Fatal("got c.linkRes.entries.entry(0) = _, false, want = true ") - } - staticLinkAddr := entry.LinkAddr + "static" - c.linkRes.neigh.addStaticEntry(entry.Addr, staticLinkAddr) - - { - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestAdded, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, - UpdatedAt: c.clock.Now(), - }, - }, - } - c.nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, c.nudDisp.mu.events) - c.nudDisp.mu.events = nil - c.nudDisp.mu.Unlock() - if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - } - - // Add a duplicate entry with a different link address - staticLinkAddr += "duplicate" - c.linkRes.neigh.addStaticEntry(entry.Addr, staticLinkAddr) - - { - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestChanged, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, - UpdatedAt: c.clock.Now(), - }, - }, - } - c.nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, c.nudDisp.mu.events) - c.nudDisp.mu.events = nil - c.nudDisp.mu.Unlock() - if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - } -} - -// TestNeighborCacheRemoveStaticEntryThenOverflow verifies that the LRU cache -// eviction strategy respects the dynamic entry count when a static entry is -// added then removed. In this case, the dynamic entry count shouldn't have -// been touched. -func TestNeighborCacheRemoveStaticEntryThenOverflow(t *testing.T) { - config := DefaultNUDConfigurations() - // Stay in Reachable so the cache can overflow - config.BaseReachableTime = infiniteDuration - config.MinRandomFactor = 1 - config.MaxRandomFactor = 1 - - c := newTestContext(config) - - // Add a static entry - entry, ok := c.linkRes.entries.entry(0) - if !ok { - t.Fatal("got c.linkRes.entries.entry(0) = _, false, want = true ") - } - staticLinkAddr := entry.LinkAddr + "static" - c.linkRes.neigh.addStaticEntry(entry.Addr, staticLinkAddr) - - { - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestAdded, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, - UpdatedAt: c.clock.Now(), - }, - }, - } - c.nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, c.nudDisp.mu.events) - c.nudDisp.mu.events = nil - c.nudDisp.mu.Unlock() - if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - } - - // Remove the static entry that was just added - c.linkRes.neigh.removeEntry(entry.Addr) - - { - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestRemoved, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, - UpdatedAt: c.clock.Now(), - }, - }, - } - c.nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, c.nudDisp.mu.events) - c.nudDisp.mu.events = nil - c.nudDisp.mu.Unlock() - if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - } - - opts := overflowOptions{ - startAtEntryIndex: 0, - } - if err := c.overflowCache(opts); err != nil { - t.Errorf("c.overflowCache(%+v): %s", opts, err) - } -} - -// TestNeighborCacheOverwriteWithStaticEntryThenOverflow verifies that the LRU -// cache eviction strategy keeps count of the dynamic entry count when an entry -// is overwritten by a static entry. Static entries should not count towards -// the size of the LRU cache. -func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) { - config := DefaultNUDConfigurations() - // Stay in Reachable so the cache can overflow - config.BaseReachableTime = infiniteDuration - config.MinRandomFactor = 1 - config.MaxRandomFactor = 1 - - c := newTestContext(config) - - // Add a dynamic entry - entry, ok := c.linkRes.entries.entry(0) - if !ok { - t.Fatal("got c.linkRes.entries.entry(0) = _, false, want = true ") - } - if err := addReachableEntry(c.nudDisp, c.clock, c.linkRes, entry); err != nil { - t.Fatalf("addReachableEntry(...) = %s", err) - } - - // Override the entry with a static one using the same address - staticLinkAddr := entry.LinkAddr + "static" - c.linkRes.neigh.addStaticEntry(entry.Addr, staticLinkAddr) - - { - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestRemoved, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - UpdatedAt: c.clock.Now(), - }, - }, - { - EventType: entryTestAdded, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, - UpdatedAt: c.clock.Now(), - }, - }, - } - c.nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, c.nudDisp.mu.events) - c.nudDisp.mu.events = nil - c.nudDisp.mu.Unlock() - if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - } - - opts := overflowOptions{ - startAtEntryIndex: 1, - wantStaticEntries: []NeighborEntry{ - { - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, - UpdatedAt: c.clock.Now(), - }, - }, - } - if err := c.overflowCache(opts); err != nil { - t.Errorf("c.overflowCache(%+v): %s", opts, err) - } -} - -func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) { - config := DefaultNUDConfigurations() - // Stay in Reachable so the cache can overflow - config.BaseReachableTime = infiniteDuration - config.MinRandomFactor = 1 - config.MaxRandomFactor = 1 - - c := newTestContext(config) - - entry, ok := c.linkRes.entries.entry(0) - if !ok { - t.Fatal("got c.linkRes.entries.entry(0) = _, false, want = true ") - } - c.linkRes.neigh.addStaticEntry(entry.Addr, entry.LinkAddr) - e, _, err := c.linkRes.neigh.entry(entry.Addr, "", nil) - if err != nil { - t.Errorf("unexpected error from c.linkRes.neigh.entry(%s, \"\", nil): %s", entry.Addr, err) - } - want := NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Static, - UpdatedAt: c.clock.Now(), - } - if diff := cmp.Diff(want, e); diff != "" { - t.Errorf("c.linkRes.neigh.entry(%s, \"\", nil) mismatch (-want, +got):\n%s", entry.Addr, diff) - } - - { - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestAdded, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Static, - UpdatedAt: c.clock.Now(), - }, - }, - } - c.nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, c.nudDisp.mu.events) - c.nudDisp.mu.events = nil - c.nudDisp.mu.Unlock() - if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - } - - opts := overflowOptions{ - startAtEntryIndex: 1, - wantStaticEntries: []NeighborEntry{ - { - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Static, - UpdatedAt: c.clock.Now(), - }, - }, - } - if err := c.overflowCache(opts); err != nil { - t.Errorf("c.overflowCache(%+v): %s", opts, err) - } -} - -func TestNeighborCacheClear(t *testing.T) { - config := DefaultNUDConfigurations() - - nudDisp := testNUDDispatcher{} - clock := faketime.NewManualClock() - linkRes := newTestNeighborResolver(&nudDisp, config, clock) - - // Add a dynamic entry. - entry, ok := linkRes.entries.entry(0) - if !ok { - t.Fatal("got linkRes.entries.entry(0) = _, false, want = true ") - } - if err := addReachableEntry(&nudDisp, clock, linkRes, entry); err != nil { - t.Fatalf("addReachableEntry(...) = %s", err) - } - - // Add a static entry. - linkRes.neigh.addStaticEntry(entryTestAddr1, entryTestLinkAddr1) - - { - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestAdded, - NICID: 1, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Static, - UpdatedAt: clock.Now(), - }, - }, - } - nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, nudDisp.mu.events) - nudDisp.mu.events = nil - nudDisp.mu.Unlock() - if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - } - - // Clear should remove both dynamic and static entries. - linkRes.neigh.clear() - - // Remove events dispatched from clear() have no deterministic order so they - // need to be sorted before comparison. - wantUnorderedEvents := []testEntryEventInfo{ - { - EventType: entryTestRemoved, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - UpdatedAt: clock.Now(), - }, - }, - { - EventType: entryTestRemoved, - NICID: 1, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Static, - UpdatedAt: clock.Now(), - }, - }, - } - nudDisp.mu.Lock() - defer nudDisp.mu.Unlock() - if diff := cmp.Diff(wantUnorderedEvents, nudDisp.mu.events, unorderedEventsDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } -} - -// TestNeighborCacheClearThenOverflow verifies that the LRU cache eviction -// strategy keeps count of the dynamic entry count when all entries are -// cleared. -func TestNeighborCacheClearThenOverflow(t *testing.T) { - config := DefaultNUDConfigurations() - // Stay in Reachable so the cache can overflow - config.BaseReachableTime = infiniteDuration - config.MinRandomFactor = 1 - config.MaxRandomFactor = 1 - - c := newTestContext(config) - - // Add a dynamic entry - entry, ok := c.linkRes.entries.entry(0) - if !ok { - t.Fatal("got c.linkRes.entries.entry(0) = _, false, want = true ") - } - if err := addReachableEntry(c.nudDisp, c.clock, c.linkRes, entry); err != nil { - t.Fatalf("addReachableEntry(...) = %s", err) - } - - // Clear the cache. - c.linkRes.neigh.clear() - - { - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestRemoved, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - UpdatedAt: c.clock.Now(), - }, - }, - } - c.nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, c.nudDisp.mu.events) - c.nudDisp.mu.events = nil - c.nudDisp.mu.Unlock() - if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - } - - opts := overflowOptions{ - startAtEntryIndex: 0, - } - if err := c.overflowCache(opts); err != nil { - t.Errorf("c.overflowCache(%+v): %s", opts, err) - } -} - -func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { - config := DefaultNUDConfigurations() - // Stay in Reachable so the cache can overflow - config.BaseReachableTime = infiniteDuration - config.MinRandomFactor = 1 - config.MaxRandomFactor = 1 - - nudDisp := testNUDDispatcher{} - clock := faketime.NewManualClock() - linkRes := newTestNeighborResolver(&nudDisp, config, clock) - - startedAt := clock.Now() - - // The following logic is very similar to overflowCache, but - // periodically refreshes the frequently used entry. - - // Fill the neighbor cache to capacity - for i := uint16(0); i < neighborCacheSize; i++ { - entry, ok := linkRes.entries.entry(i) - if !ok { - t.Fatalf("got linkRes.entries.entry(%d) = _, false, want = true", i) - } - if err := addReachableEntry(&nudDisp, clock, linkRes, entry); err != nil { - t.Fatalf("addReachableEntry(...) = %s", err) - } - } - - frequentlyUsedEntry, ok := linkRes.entries.entry(0) - if !ok { - t.Fatal("got linkRes.entries.entry(0) = _, false, want = true ") - } - - // Keep adding more entries - for i := uint16(neighborCacheSize); i < linkRes.entries.size(); i++ { - // Periodically refresh the frequently used entry - if i%(neighborCacheSize/2) == 0 { - if _, _, err := linkRes.neigh.entry(frequentlyUsedEntry.Addr, "", nil); err != nil { - t.Errorf("unexpected error from linkRes.neigh.entry(%s, '', nil): %s", frequentlyUsedEntry.Addr, err) - } - } - - entry, ok := linkRes.entries.entry(i) - if !ok { - t.Fatalf("got linkRes.entries.entry(%d) = _, false, want = true", i) - } - - // An entry should have been removed, as per the LRU eviction strategy - removedEntry, ok := linkRes.entries.entry(i - neighborCacheSize + 1) - if !ok { - t.Fatalf("got linkRes.entries.entry(%d) = _, false, want = true", i-neighborCacheSize+1) - } - - if err := addReachableEntryWithRemoved(&nudDisp, clock, linkRes, entry, []NeighborEntry{removedEntry}); err != nil { - t.Fatalf("addReachableEntryWithRemoved(...) = %s", err) - } - } - - // Expect to find only the frequently used entry and the most recent entries. - // The order of entries reported by entries() is nondeterministic, so entries - // have to be sorted before comparison. - wantUnsortedEntries := []NeighborEntry{ - { - Addr: frequentlyUsedEntry.Addr, - LinkAddr: frequentlyUsedEntry.LinkAddr, - State: Reachable, - // Can be inferred since the frequently used entry is the first to - // be created and transitioned to Reachable. - UpdatedAt: startedAt.Add(typicalLatency), - }, - } - - for i := linkRes.entries.size() - neighborCacheSize + 1; i < linkRes.entries.size(); i++ { - entry, ok := linkRes.entries.entry(i) - if !ok { - t.Fatalf("got linkRes.entries.entry(%d) = _, false, want = true", i) - } - durationReachableNanos := time.Duration(linkRes.entries.size()-i-1) * typicalLatency - wantUnsortedEntries = append(wantUnsortedEntries, NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - UpdatedAt: clock.Now().Add(-durationReachableNanos), - }) - } - - if diff := cmp.Diff(wantUnsortedEntries, linkRes.neigh.entries(), unorderedEntriesDiffOpts()...); diff != "" { - t.Errorf("neighbor entries mismatch (-want, +got):\n%s", diff) - } - - // No more events should have been dispatched. - nudDisp.mu.Lock() - defer nudDisp.mu.Unlock() - if diff := cmp.Diff([]testEntryEventInfo(nil), nudDisp.mu.events); diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } -} - -func TestNeighborCacheConcurrent(t *testing.T) { - const concurrentProcesses = 16 - - config := DefaultNUDConfigurations() - - nudDisp := testNUDDispatcher{} - clock := faketime.NewManualClock() - linkRes := newTestNeighborResolver(&nudDisp, config, clock) - - storeEntries := linkRes.entries.entries() - for _, entry := range storeEntries { - var wg sync.WaitGroup - for r := 0; r < concurrentProcesses; r++ { - wg.Add(1) - go func(entry NeighborEntry) { - defer wg.Done() - switch e, _, err := linkRes.neigh.entry(entry.Addr, "", nil); err.(type) { - case nil, *tcpip.ErrWouldBlock: - default: - t.Errorf("got linkRes.neigh.entry(%s, '', nil) = (%+v, _, %s), want (_, _, nil) or (_, _, %s)", entry.Addr, e, err, &tcpip.ErrWouldBlock{}) - } - }(entry) - } - - // Wait for all goroutines to send a request - wg.Wait() - - // Process all the requests for a single entry concurrently - clock.Advance(typicalLatency) - } - - // All goroutines add in the same order and add more values than can fit in - // the cache. Our eviction strategy requires that the last entries are - // present, up to the size of the neighbor cache, and the rest are missing. - // The order of entries reported by entries() is nondeterministic, so entries - // have to be sorted before comparison. - var wantUnsortedEntries []NeighborEntry - for i := linkRes.entries.size() - neighborCacheSize; i < linkRes.entries.size(); i++ { - entry, ok := linkRes.entries.entry(i) - if !ok { - t.Errorf("got linkRes.entries.entry(%d) = _, false, want = true", i) - } - durationReachableNanos := time.Duration(linkRes.entries.size()-i-1) * typicalLatency - wantUnsortedEntries = append(wantUnsortedEntries, NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - UpdatedAt: clock.Now().Add(-durationReachableNanos), - }) - } - - if diff := cmp.Diff(wantUnsortedEntries, linkRes.neigh.entries(), unorderedEntriesDiffOpts()...); diff != "" { - t.Errorf("neighbor entries mismatch (-want, +got):\n%s", diff) - } -} - -func TestNeighborCacheReplace(t *testing.T) { - config := DefaultNUDConfigurations() - - nudDisp := testNUDDispatcher{} - clock := faketime.NewManualClock() - linkRes := newTestNeighborResolver(&nudDisp, config, clock) - - entry, ok := linkRes.entries.entry(0) - if !ok { - t.Fatal("got linkRes.entries.entry(0) = _, false, want = true ") - } - if err := addReachableEntry(&nudDisp, clock, linkRes, entry); err != nil { - t.Fatalf("addReachableEntry(...) = %s", err) - } - - // Notify of a link address change - var updatedLinkAddr tcpip.LinkAddress - { - entry, ok := linkRes.entries.entry(1) - if !ok { - t.Fatal("got linkRes.entries.entry(1) = _, false, want = true") - } - updatedLinkAddr = entry.LinkAddr - } - linkRes.entries.set(0, updatedLinkAddr) - linkRes.neigh.handleConfirmation(entry.Addr, updatedLinkAddr, ReachabilityConfirmationFlags{ - Solicited: false, - Override: true, - IsRouter: false, - }) - - // Requesting the entry again should start neighbor reachability confirmation. - // - // Verify the entry's new link address and the new state. - { - e, _, err := linkRes.neigh.entry(entry.Addr, "", nil) - if err != nil { - t.Fatalf("linkRes.neigh.entry(%s, '', nil): %s", entry.Addr, err) - } - want := NeighborEntry{ - Addr: entry.Addr, - LinkAddr: updatedLinkAddr, - State: Delay, - UpdatedAt: clock.Now(), - } - if diff := cmp.Diff(want, e); diff != "" { - t.Errorf("linkRes.neigh.entry(%s, '', nil) mismatch (-want, +got):\n%s", entry.Addr, diff) - } - } - - clock.Advance(config.DelayFirstProbeTime + typicalLatency) - - // Verify that the neighbor is now reachable. - { - e, _, err := linkRes.neigh.entry(entry.Addr, "", nil) - if err != nil { - t.Errorf("unexpected error from linkRes.neigh.entry(%s, '', nil): %s", entry.Addr, err) - } - want := NeighborEntry{ - Addr: entry.Addr, - LinkAddr: updatedLinkAddr, - State: Reachable, - UpdatedAt: clock.Now(), - } - if diff := cmp.Diff(want, e); diff != "" { - t.Errorf("linkRes.neigh.entry(%s, '', nil) mismatch (-want, +got):\n%s", entry.Addr, diff) - } - } -} - -func TestNeighborCacheResolutionFailed(t *testing.T) { - config := DefaultNUDConfigurations() - - nudDisp := testNUDDispatcher{} - clock := faketime.NewManualClock() - linkRes := newTestNeighborResolver(&nudDisp, config, clock) - - var requestCount uint32 - linkRes.onLinkAddressRequest = func() { - atomic.AddUint32(&requestCount, 1) - } - - entry, ok := linkRes.entries.entry(0) - if !ok { - t.Fatal("got linkRes.entries.entry(0) = _, false, want = true ") - } - - // First, sanity check that resolution is working - if err := addReachableEntry(&nudDisp, clock, linkRes, entry); err != nil { - t.Fatalf("addReachableEntry(...) = %s", err) - } - - got, _, err := linkRes.neigh.entry(entry.Addr, "", nil) - if err != nil { - t.Fatalf("unexpected error from linkRes.neigh.entry(%s, '', nil): %s", entry.Addr, err) - } - want := NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - UpdatedAt: clock.Now(), - } - if diff := cmp.Diff(want, got); diff != "" { - t.Errorf("linkRes.neigh.entry(%s, '', nil) mismatch (-want, +got):\n%s", entry.Addr, diff) - } - - // Verify address resolution fails for an unknown address. - before := atomic.LoadUint32(&requestCount) - - entry.Addr += "2" - { - _, ch, err := linkRes.neigh.entry(entry.Addr, "", func(r LinkResolutionResult) { - if diff := cmp.Diff(LinkResolutionResult{Err: &tcpip.ErrTimeout{}}, r); diff != "" { - t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) - } - }) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Fatalf("got linkRes.neigh.entry(%s, '', _) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) - } - waitFor := config.DelayFirstProbeTime + typicalLatency*time.Duration(config.MaxMulticastProbes) - clock.Advance(waitFor) - select { - case <-ch: - default: - t.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _)", entry.Addr) - } - } - - maxAttempts := linkRes.neigh.config().MaxUnicastProbes - if got, want := atomic.LoadUint32(&requestCount)-before, maxAttempts; got != want { - t.Errorf("got link address request count = %d, want = %d", got, want) - } -} - -// TestNeighborCacheResolutionTimeout simulates sending MaxMulticastProbes -// probes and not retrieving a confirmation before the duration defined by -// MaxMulticastProbes * RetransmitTimer. -func TestNeighborCacheResolutionTimeout(t *testing.T) { - config := DefaultNUDConfigurations() - config.RetransmitTimer = time.Millisecond // small enough to cause timeout - - clock := faketime.NewManualClock() - linkRes := newTestNeighborResolver(nil, config, clock) - // large enough to cause timeout - linkRes.delay = time.Minute - - entry, ok := linkRes.entries.entry(0) - if !ok { - t.Fatal("got linkRes.entries.entry(0) = _, false, want = true ") - } - - _, ch, err := linkRes.neigh.entry(entry.Addr, "", func(r LinkResolutionResult) { - if diff := cmp.Diff(LinkResolutionResult{Err: &tcpip.ErrTimeout{}}, r); diff != "" { - t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) - } - }) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Fatalf("got linkRes.neigh.entry(%s, '', _) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) - } - waitFor := config.RetransmitTimer * time.Duration(config.MaxMulticastProbes) - clock.Advance(waitFor) - - select { - case <-ch: - default: - t.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _)", entry.Addr) - } -} - -// TestNeighborCacheRetryResolution simulates retrying communication after -// failing to perform address resolution. -func TestNeighborCacheRetryResolution(t *testing.T) { - config := DefaultNUDConfigurations() - nudDisp := testNUDDispatcher{} - clock := faketime.NewManualClock() - linkRes := newTestNeighborResolver(&nudDisp, config, clock) - // Simulate a faulty link. - linkRes.dropReplies = true - - entry, ok := linkRes.entries.entry(0) - if !ok { - t.Fatal("got linkRes.entries.entry(0) = _, false, want = true ") - } - - // Perform address resolution with a faulty link, which will fail. - { - _, ch, err := linkRes.neigh.entry(entry.Addr, "", func(r LinkResolutionResult) { - if diff := cmp.Diff(LinkResolutionResult{Err: &tcpip.ErrTimeout{}}, r); diff != "" { - t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) - } - }) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Fatalf("got linkRes.neigh.entry(%s, '', _) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) - } - - { - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestAdded, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: "", - State: Incomplete, - UpdatedAt: clock.Now(), - }, - }, - } - nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, nudDisp.mu.events) - nudDisp.mu.events = nil - nudDisp.mu.Unlock() - if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - } - - waitFor := config.RetransmitTimer * time.Duration(config.MaxMulticastProbes) - clock.Advance(waitFor) - - select { - case <-ch: - default: - t.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _)", entry.Addr) - } - - { - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestChanged, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: "", - State: Unreachable, - UpdatedAt: clock.Now(), - }, - }, - } - nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, nudDisp.mu.events) - nudDisp.mu.events = nil - nudDisp.mu.Unlock() - if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - } - - { - wantEntries := []NeighborEntry{ - { - Addr: entry.Addr, - LinkAddr: "", - State: Unreachable, - UpdatedAt: clock.Now(), - }, - } - if diff := cmp.Diff(linkRes.neigh.entries(), wantEntries, unorderedEntriesDiffOpts()...); diff != "" { - t.Fatalf("neighbor entries mismatch (-got, +want):\n%s", diff) - } - } - } - - // Retry address resolution with a working link. - linkRes.dropReplies = false - { - incompleteEntry, ch, err := linkRes.neigh.entry(entry.Addr, "", func(r LinkResolutionResult) { - if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Err: nil}, r); diff != "" { - t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) - } - }) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Fatalf("got linkRes.neigh.entry(%s, '', _) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) - } - if incompleteEntry.State != Incomplete { - t.Fatalf("got entry.State = %s, want = %s", incompleteEntry.State, Incomplete) - } - - { - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestChanged, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: "", - State: Incomplete, - UpdatedAt: clock.Now(), - }, - }, - } - nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, nudDisp.mu.events) - nudDisp.mu.events = nil - nudDisp.mu.Unlock() - if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - } - - clock.Advance(typicalLatency) - - select { - case <-ch: - default: - t.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _)", entry.Addr) - } - - { - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestChanged, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - UpdatedAt: clock.Now(), - }, - }, - } - nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, nudDisp.mu.events) - nudDisp.mu.events = nil - nudDisp.mu.Unlock() - if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - } - - { - gotEntry, _, err := linkRes.neigh.entry(entry.Addr, "", nil) - if err != nil { - t.Fatalf("linkRes.neigh.entry(%s, '', _): %s", entry.Addr, err) - } - - wantEntry := NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - UpdatedAt: clock.Now(), - } - if diff := cmp.Diff(gotEntry, wantEntry); diff != "" { - t.Fatalf("neighbor entry mismatch (-got, +want):\n%s", diff) - } - } - } -} - -func BenchmarkCacheClear(b *testing.B) { - b.StopTimer() - config := DefaultNUDConfigurations() - clock := tcpip.NewStdClock() - linkRes := newTestNeighborResolver(nil, config, clock) - linkRes.delay = 0 - - // Clear for every possible size of the cache - for cacheSize := uint16(0); cacheSize < neighborCacheSize; cacheSize++ { - // Fill the neighbor cache to capacity. - for i := uint16(0); i < cacheSize; i++ { - entry, ok := linkRes.entries.entry(i) - if !ok { - b.Fatalf("got linkRes.entries.entry(%d) = _, false, want = true", i) - } - - _, ch, err := linkRes.neigh.entry(entry.Addr, "", func(r LinkResolutionResult) { - if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Err: nil}, r); diff != "" { - b.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) - } - }) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - b.Fatalf("got linkRes.neigh.entry(%s, '', _) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) - } - - select { - case <-ch: - default: - b.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _)", entry.Addr) - } - } - - b.StartTimer() - linkRes.neigh.clear() - b.StopTimer() - } -} diff --git a/pkg/tcpip/stack/neighbor_entry_list.go b/pkg/tcpip/stack/neighbor_entry_list.go new file mode 100644 index 000000000..d78430080 --- /dev/null +++ b/pkg/tcpip/stack/neighbor_entry_list.go @@ -0,0 +1,221 @@ +package stack + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type neighborEntryElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (neighborEntryElementMapper) linkerFor(elem *neighborEntry) *neighborEntry { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type neighborEntryList struct { + head *neighborEntry + tail *neighborEntry +} + +// Reset resets list l to the empty state. +func (l *neighborEntryList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +// +//go:nosplit +func (l *neighborEntryList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +// +//go:nosplit +func (l *neighborEntryList) Front() *neighborEntry { + return l.head +} + +// Back returns the last element of list l or nil. +// +//go:nosplit +func (l *neighborEntryList) Back() *neighborEntry { + return l.tail +} + +// Len returns the number of elements in the list. +// +// NOTE: This is an O(n) operation. +// +//go:nosplit +func (l *neighborEntryList) Len() (count int) { + for e := l.Front(); e != nil; e = (neighborEntryElementMapper{}.linkerFor(e)).Next() { + count++ + } + return count +} + +// PushFront inserts the element e at the front of list l. +// +//go:nosplit +func (l *neighborEntryList) PushFront(e *neighborEntry) { + linker := neighborEntryElementMapper{}.linkerFor(e) + linker.SetNext(l.head) + linker.SetPrev(nil) + if l.head != nil { + neighborEntryElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +// +//go:nosplit +func (l *neighborEntryList) PushBack(e *neighborEntry) { + linker := neighborEntryElementMapper{}.linkerFor(e) + linker.SetNext(nil) + linker.SetPrev(l.tail) + if l.tail != nil { + neighborEntryElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +// +//go:nosplit +func (l *neighborEntryList) PushBackList(m *neighborEntryList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + neighborEntryElementMapper{}.linkerFor(l.tail).SetNext(m.head) + neighborEntryElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +// +//go:nosplit +func (l *neighborEntryList) InsertAfter(b, e *neighborEntry) { + bLinker := neighborEntryElementMapper{}.linkerFor(b) + eLinker := neighborEntryElementMapper{}.linkerFor(e) + + a := bLinker.Next() + + eLinker.SetNext(a) + eLinker.SetPrev(b) + bLinker.SetNext(e) + + if a != nil { + neighborEntryElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +// +//go:nosplit +func (l *neighborEntryList) InsertBefore(a, e *neighborEntry) { + aLinker := neighborEntryElementMapper{}.linkerFor(a) + eLinker := neighborEntryElementMapper{}.linkerFor(e) + + b := aLinker.Prev() + eLinker.SetNext(a) + eLinker.SetPrev(b) + aLinker.SetPrev(e) + + if b != nil { + neighborEntryElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +// +//go:nosplit +func (l *neighborEntryList) Remove(e *neighborEntry) { + linker := neighborEntryElementMapper{}.linkerFor(e) + prev := linker.Prev() + next := linker.Next() + + if prev != nil { + neighborEntryElementMapper{}.linkerFor(prev).SetNext(next) + } else if l.head == e { + l.head = next + } + + if next != nil { + neighborEntryElementMapper{}.linkerFor(next).SetPrev(prev) + } else if l.tail == e { + l.tail = prev + } + + linker.SetNext(nil) + linker.SetPrev(nil) +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type neighborEntryEntry struct { + next *neighborEntry + prev *neighborEntry +} + +// Next returns the entry that follows e in the list. +// +//go:nosplit +func (e *neighborEntryEntry) Next() *neighborEntry { + return e.next +} + +// Prev returns the entry that precedes e in the list. +// +//go:nosplit +func (e *neighborEntryEntry) Prev() *neighborEntry { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +// +//go:nosplit +func (e *neighborEntryEntry) SetNext(elem *neighborEntry) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +// +//go:nosplit +func (e *neighborEntryEntry) SetPrev(elem *neighborEntry) { + e.prev = elem +} diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go deleted file mode 100644 index 59d86d6d4..000000000 --- a/pkg/tcpip/stack/neighbor_entry_test.go +++ /dev/null @@ -1,2269 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package stack - -import ( - "fmt" - "math" - "math/rand" - "sync" - "testing" - "time" - - "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/faketime" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/testutil" -) - -const ( - entryTestNetNumber tcpip.NetworkProtocolNumber = math.MaxUint32 - - entryTestNICID tcpip.NICID = 1 - - entryTestLinkAddr1 = tcpip.LinkAddress("\x0a\x00\x00\x00\x00\x01") - entryTestLinkAddr2 = tcpip.LinkAddress("\x0a\x00\x00\x00\x00\x02") -) - -var ( - entryTestAddr1 = testutil.MustParse6("a::1") - entryTestAddr2 = testutil.MustParse6("a::2") -) - -// runImmediatelyScheduledJobs runs all jobs scheduled to run at the current -// time. -func runImmediatelyScheduledJobs(clock *faketime.ManualClock) { - clock.Advance(immediateDuration) -} - -// The following unit tests exercise every state transition and verify its -// behavior with RFC 4681 and RFC 7048. -// -// | From | To | Cause | Update | Action | Event | -// | =========== | =========== | ========================================== | ======== | ===========| ======= | -// | Unknown | Unknown | Confirmation w/ unknown address | | | Added | -// | Unknown | Incomplete | Packet queued to unknown address | | Send probe | Added | -// | Unknown | Stale | Probe | | | Added | -// | Incomplete | Incomplete | Retransmit timer expired | | Send probe | Changed | -// | Incomplete | Reachable | Solicited confirmation | LinkAddr | Notify | Changed | -// | Incomplete | Stale | Unsolicited confirmation | LinkAddr | Notify | Changed | -// | Incomplete | Stale | Probe | LinkAddr | Notify | Changed | -// | Incomplete | Unreachable | Max probes sent without reply | | Notify | Changed | -// | Reachable | Reachable | Confirmation w/ different isRouter flag | IsRouter | | | -// | Reachable | Stale | Reachable timer expired | | | Changed | -// | Reachable | Stale | Probe or confirmation w/ different address | | | Changed | -// | Stale | Reachable | Solicited override confirmation | LinkAddr | | Changed | -// | Stale | Reachable | Solicited confirmation w/o address | | Notify | Changed | -// | Stale | Stale | Override confirmation | LinkAddr | | Changed | -// | Stale | Stale | Probe w/ different address | LinkAddr | | Changed | -// | Stale | Delay | Packet sent | | | Changed | -// | Delay | Reachable | Upper-layer confirmation | | | Changed | -// | Delay | Reachable | Solicited override confirmation | LinkAddr | | Changed | -// | Delay | Reachable | Solicited confirmation w/o address | | Notify | Changed | -// | Delay | Stale | Probe or confirmation w/ different address | | | Changed | -// | Delay | Probe | Delay timer expired | | Send probe | Changed | -// | Probe | Reachable | Solicited override confirmation | LinkAddr | | Changed | -// | Probe | Reachable | Solicited confirmation w/ same address | | Notify | Changed | -// | Probe | Reachable | Solicited confirmation w/o address | | Notify | Changed | -// | Probe | Stale | Probe or confirmation w/ different address | | | Changed | -// | Probe | Probe | Retransmit timer expired | | | Changed | -// | Probe | Unreachable | Max probes sent without reply | | Notify | Changed | -// | Unreachable | Incomplete | Packet queued | | Send probe | Changed | -// | Unreachable | Stale | Probe w/ different address | LinkAddr | | Changed | - -type testEntryEventType uint8 - -const ( - entryTestAdded testEntryEventType = iota - entryTestChanged - entryTestRemoved -) - -func (t testEntryEventType) String() string { - switch t { - case entryTestAdded: - return "add" - case entryTestChanged: - return "change" - case entryTestRemoved: - return "remove" - default: - return fmt.Sprintf("unknown (%d)", t) - } -} - -// Fields are exported for use with cmp.Diff. -type testEntryEventInfo struct { - EventType testEntryEventType - NICID tcpip.NICID - Entry NeighborEntry -} - -func (e testEntryEventInfo) String() string { - return fmt.Sprintf("%s event for NIC #%d, %#v", e.EventType, e.NICID, e.Entry) -} - -// testNUDDispatcher implements NUDDispatcher to validate the dispatching of -// events upon certain NUD state machine events. -type testNUDDispatcher struct { - mu struct { - sync.Mutex - events []testEntryEventInfo - } -} - -var _ NUDDispatcher = (*testNUDDispatcher)(nil) - -func (d *testNUDDispatcher) queueEvent(e testEntryEventInfo) { - d.mu.Lock() - defer d.mu.Unlock() - d.mu.events = append(d.mu.events, e) -} - -func (d *testNUDDispatcher) OnNeighborAdded(nicID tcpip.NICID, entry NeighborEntry) { - d.queueEvent(testEntryEventInfo{ - EventType: entryTestAdded, - NICID: nicID, - Entry: entry, - }) -} - -func (d *testNUDDispatcher) OnNeighborChanged(nicID tcpip.NICID, entry NeighborEntry) { - d.queueEvent(testEntryEventInfo{ - EventType: entryTestChanged, - NICID: nicID, - Entry: entry, - }) -} - -func (d *testNUDDispatcher) OnNeighborRemoved(nicID tcpip.NICID, entry NeighborEntry) { - d.queueEvent(testEntryEventInfo{ - EventType: entryTestRemoved, - NICID: nicID, - Entry: entry, - }) -} - -type entryTestLinkResolver struct { - mu struct { - sync.Mutex - probes []entryTestProbeInfo - } -} - -var _ LinkAddressResolver = (*entryTestLinkResolver)(nil) - -type entryTestProbeInfo struct { - RemoteAddress tcpip.Address - RemoteLinkAddress tcpip.LinkAddress - LocalAddress tcpip.Address -} - -func (p entryTestProbeInfo) String() string { - return fmt.Sprintf("probe with RemoteAddress=%q, RemoteLinkAddress=%q, LocalAddress=%q", p.RemoteAddress, p.RemoteLinkAddress, p.LocalAddress) -} - -// LinkAddressRequest sends a request for the LinkAddress of addr. Broadcasts -// to the local network if linkAddr is the zero value. -func (r *entryTestLinkResolver) LinkAddressRequest(targetAddr, localAddr tcpip.Address, linkAddr tcpip.LinkAddress) tcpip.Error { - r.mu.Lock() - defer r.mu.Unlock() - r.mu.probes = append(r.mu.probes, entryTestProbeInfo{ - RemoteAddress: targetAddr, - RemoteLinkAddress: linkAddr, - LocalAddress: localAddr, - }) - return nil -} - -// ResolveStaticAddress attempts to resolve address without sending requests. -// It either resolves the name immediately or returns the empty LinkAddress. -func (*entryTestLinkResolver) ResolveStaticAddress(tcpip.Address) (tcpip.LinkAddress, bool) { - return "", false -} - -// LinkAddressProtocol returns the network protocol of the addresses this -// resolver can resolve. -func (*entryTestLinkResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumber { - return entryTestNetNumber -} - -func entryTestSetup(c NUDConfigurations) (*neighborEntry, *testNUDDispatcher, *entryTestLinkResolver, *faketime.ManualClock) { - clock := faketime.NewManualClock() - disp := testNUDDispatcher{} - nic := nic{ - LinkEndpoint: nil, // entryTestLinkResolver doesn't use a LinkEndpoint - - id: entryTestNICID, - stack: &Stack{ - clock: clock, - nudDisp: &disp, - nudConfigs: c, - randomGenerator: rand.New(rand.NewSource(time.Now().UnixNano())), - }, - stats: makeNICStats(tcpip.NICStats{}.FillIn()), - } - netEP := (&testIPv6Protocol{}).NewEndpoint(&nic, nil) - nic.networkEndpoints = map[tcpip.NetworkProtocolNumber]NetworkEndpoint{ - header.IPv6ProtocolNumber: netEP, - } - - var linkRes entryTestLinkResolver - // Stub out the neighbor cache to verify deletion from the cache. - l := &linkResolver{ - resolver: &linkRes, - } - l.neigh.init(&nic, &linkRes) - - entry := newNeighborEntry(&l.neigh, entryTestAddr1 /* remoteAddr */, l.neigh.state) - l.neigh.mu.Lock() - l.neigh.mu.cache[entryTestAddr1] = entry - l.neigh.mu.Unlock() - nic.linkAddrResolvers = map[tcpip.NetworkProtocolNumber]*linkResolver{ - header.IPv6ProtocolNumber: l, - } - - return entry, &disp, &linkRes, clock -} - -// TestEntryInitiallyUnknown verifies that the state of a newly created -// neighborEntry is Unknown. -func TestEntryInitiallyUnknown(t *testing.T) { - c := DefaultNUDConfigurations() - e, nudDisp, linkRes, clock := entryTestSetup(c) - - e.mu.Lock() - if e.mu.neigh.State != Unknown { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Unknown) - } - e.mu.Unlock() - - clock.Advance(c.RetransmitTimer) - - // No probes should have been sent. - linkRes.mu.Lock() - diff := cmp.Diff([]entryTestProbeInfo(nil), linkRes.mu.probes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - - // No events should have been dispatched. - nudDisp.mu.Lock() - if diff := cmp.Diff([]testEntryEventInfo(nil), nudDisp.mu.events); diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - nudDisp.mu.Unlock() -} - -func TestEntryUnknownToUnknownWhenConfirmationWithUnknownAddress(t *testing.T) { - c := DefaultNUDConfigurations() - e, nudDisp, linkRes, clock := entryTestSetup(c) - - e.mu.Lock() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) - if e.mu.neigh.State != Unknown { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Unknown) - } - e.mu.Unlock() - - clock.Advance(time.Hour) - - // No probes should have been sent. - linkRes.mu.Lock() - diff := cmp.Diff([]entryTestProbeInfo(nil), linkRes.mu.probes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - - // No events should have been dispatched. - nudDisp.mu.Lock() - if diff := cmp.Diff([]testEntryEventInfo(nil), nudDisp.mu.events); diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - nudDisp.mu.Unlock() -} - -func TestEntryUnknownToIncomplete(t *testing.T) { - c := DefaultNUDConfigurations() - e, nudDisp, linkRes, clock := entryTestSetup(c) - - if err := unknownToIncomplete(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("unknownToIncomplete(...) = %s", err) - } -} - -func unknownToIncomplete(e *neighborEntry, nudDisp *testNUDDispatcher, linkRes *entryTestLinkResolver, clock *faketime.ManualClock) error { - if err := func() error { - e.mu.Lock() - defer e.mu.Unlock() - - if e.mu.neigh.State != Unknown { - return fmt.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Unknown) - } - e.handlePacketQueuedLocked(entryTestAddr2) - if e.mu.neigh.State != Incomplete { - return fmt.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Incomplete) - } - return nil - }(); err != nil { - return err - } - - runImmediatelyScheduledJobs(clock) - wantProbes := []entryTestProbeInfo{ - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(wantProbes, linkRes.mu.probes) - linkRes.mu.probes = nil - linkRes.mu.Unlock() - if diff != "" { - return fmt.Errorf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestAdded, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, - UpdatedAt: clock.Now(), - }, - }, - } - { - nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, nudDisp.mu.events) - nudDisp.mu.events = nil - nudDisp.mu.Unlock() - if diff != "" { - return fmt.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - } - return nil -} - -func TestEntryUnknownToStale(t *testing.T) { - c := DefaultNUDConfigurations() - e, nudDisp, linkRes, clock := entryTestSetup(c) - - if err := unknownToStale(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("unknownToStale(...) = %s", err) - } -} - -func unknownToStale(e *neighborEntry, nudDisp *testNUDDispatcher, linkRes *entryTestLinkResolver, clock *faketime.ManualClock) error { - if err := func() error { - e.mu.Lock() - defer e.mu.Unlock() - - if e.mu.neigh.State != Unknown { - return fmt.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Unknown) - } - e.handleProbeLocked(entryTestLinkAddr1) - if e.mu.neigh.State != Stale { - return fmt.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale) - } - return nil - }(); err != nil { - return err - } - - // No probes should have been sent. - runImmediatelyScheduledJobs(clock) - { - linkRes.mu.Lock() - diff := cmp.Diff([]entryTestProbeInfo(nil), linkRes.mu.probes) - linkRes.mu.Unlock() - if diff != "" { - return fmt.Errorf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - } - - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestAdded, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, - UpdatedAt: clock.Now(), - }, - }, - } - { - nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, nudDisp.mu.events) - nudDisp.mu.events = nil - nudDisp.mu.Unlock() - if diff != "" { - return fmt.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - } - - return nil -} - -func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) { - c := DefaultNUDConfigurations() - c.MaxMulticastProbes = 3 - e, nudDisp, linkRes, clock := entryTestSetup(c) - - if err := unknownToIncomplete(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("unknownToIncomplete(...) = %s", err) - } - - // UpdatedAt should remain the same during address resolution. - e.mu.Lock() - startedAt := e.mu.neigh.UpdatedAt - e.mu.Unlock() - - // Wait for the rest of the reachability probe transmissions, signifying - // Incomplete to Incomplete transitions. - for i := uint32(1); i < c.MaxMulticastProbes; i++ { - clock.Advance(c.RetransmitTimer) - - wantProbes := []entryTestProbeInfo{ - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(wantProbes, linkRes.mu.probes) - linkRes.mu.probes = nil - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - - e.mu.Lock() - if got, want := e.mu.neigh.UpdatedAt, startedAt; got != want { - t.Errorf("got e.mu.neigh.UpdatedAt = %q, want = %q", got, want) - } - e.mu.Unlock() - } - - // UpdatedAt should change after failing address resolution. Timing out after - // sending the last probe transitions the entry to Unreachable. - clock.Advance(c.RetransmitTimer) - - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Unreachable, - UpdatedAt: clock.Now(), - }, - }, - } - nudDisp.mu.Lock() - if diff := cmp.Diff(wantEvents, nudDisp.mu.events); diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - nudDisp.mu.Unlock() -} - -func TestEntryIncompleteToReachable(t *testing.T) { - c := DefaultNUDConfigurations() - e, nudDisp, linkRes, clock := entryTestSetup(c) - - if err := unknownToIncomplete(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("unknownToIncomplete(...) = %s", err) - } - if err := incompleteToReachable(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("incompleteToReachable(...) = %s", err) - } -} - -func incompleteToReachableWithFlags(e *neighborEntry, nudDisp *testNUDDispatcher, linkRes *entryTestLinkResolver, clock *faketime.ManualClock, flags ReachabilityConfirmationFlags) error { - if err := func() error { - e.mu.Lock() - defer e.mu.Unlock() - - if e.mu.neigh.State != Incomplete { - return fmt.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Incomplete) - } - e.handleConfirmationLocked(entryTestLinkAddr1, flags) - if e.mu.neigh.State != Reachable { - return fmt.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Reachable) - } - if e.mu.neigh.LinkAddr != entryTestLinkAddr1 { - return fmt.Errorf("got e.mu.neigh.LinkAddr = %q, want = %q", e.mu.neigh.LinkAddr, entryTestLinkAddr1) - } - return nil - }(); err != nil { - return err - } - - // No probes should have been sent. - runImmediatelyScheduledJobs(clock) - { - linkRes.mu.Lock() - diff := cmp.Diff([]entryTestProbeInfo(nil), linkRes.mu.probes) - linkRes.mu.Unlock() - if diff != "" { - return fmt.Errorf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - } - - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, - UpdatedAt: clock.Now(), - }, - }, - } - { - nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, nudDisp.mu.events) - nudDisp.mu.events = nil - nudDisp.mu.Unlock() - if diff != "" { - return fmt.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - } - - return nil -} - -func incompleteToReachable(e *neighborEntry, nudDisp *testNUDDispatcher, linkRes *entryTestLinkResolver, clock *faketime.ManualClock) error { - if err := incompleteToReachableWithFlags(e, nudDisp, linkRes, clock, ReachabilityConfirmationFlags{ - Solicited: true, - Override: false, - IsRouter: false, - }); err != nil { - return err - } - - e.mu.Lock() - isRouter := e.mu.isRouter - e.mu.Unlock() - if isRouter { - return fmt.Errorf("got e.mu.isRouter = %t, want = false", isRouter) - } - - return nil -} - -func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) { - c := DefaultNUDConfigurations() - e, nudDisp, linkRes, clock := entryTestSetup(c) - - if err := unknownToIncomplete(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("unknownToIncomplete(...) = %s", err) - } - if err := incompleteToReachableWithRouterFlag(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("incompleteToReachableWithRouterFlag(...) = %s", err) - } -} - -func incompleteToReachableWithRouterFlag(e *neighborEntry, nudDisp *testNUDDispatcher, linkRes *entryTestLinkResolver, clock *faketime.ManualClock) error { - if err := incompleteToReachableWithFlags(e, nudDisp, linkRes, clock, ReachabilityConfirmationFlags{ - Solicited: true, - Override: false, - IsRouter: true, - }); err != nil { - return err - } - - e.mu.Lock() - isRouter := e.mu.isRouter - e.mu.Unlock() - if !isRouter { - return fmt.Errorf("got e.mu.isRouter = %t, want = true", isRouter) - } - - return nil -} - -func TestEntryIncompleteToStaleWhenUnsolicitedConfirmation(t *testing.T) { - c := DefaultNUDConfigurations() - e, nudDisp, linkRes, clock := entryTestSetup(c) - - if err := unknownToIncomplete(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("unknownToIncomplete(...) = %s", err) - } - - e.mu.Lock() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) - if e.mu.neigh.State != Stale { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale) - } - if e.mu.isRouter { - t.Errorf("got e.mu.isRouter = %t, want = false", e.mu.isRouter) - } - e.mu.Unlock() - - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, - UpdatedAt: clock.Now(), - }, - }, - } - nudDisp.mu.Lock() - if diff := cmp.Diff(wantEvents, nudDisp.mu.events); diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - nudDisp.mu.Unlock() -} - -func TestEntryIncompleteToStaleWhenProbe(t *testing.T) { - c := DefaultNUDConfigurations() - e, nudDisp, linkRes, clock := entryTestSetup(c) - - if err := unknownToIncomplete(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("unknownToIncomplete(...) = %s", err) - } - - e.mu.Lock() - e.handleProbeLocked(entryTestLinkAddr1) - if e.mu.neigh.State != Stale { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale) - } - e.mu.Unlock() - - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, - UpdatedAt: clock.Now(), - }, - }, - } - nudDisp.mu.Lock() - if diff := cmp.Diff(wantEvents, nudDisp.mu.events); diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - nudDisp.mu.Unlock() -} - -func TestEntryIncompleteToUnreachable(t *testing.T) { - c := DefaultNUDConfigurations() - c.MaxMulticastProbes = 3 - e, nudDisp, linkRes, clock := entryTestSetup(c) - - if err := unknownToIncomplete(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("unknownToIncomplete(...) = %s", err) - } - if err := incompleteToUnreachable(c, e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("incompleteToUnreachable(...) = %s", err) - } -} - -func incompleteToUnreachable(c NUDConfigurations, e *neighborEntry, nudDisp *testNUDDispatcher, linkRes *entryTestLinkResolver, clock *faketime.ManualClock) error { - { - e.mu.Lock() - state := e.mu.neigh.State - e.mu.Unlock() - if state != Incomplete { - return fmt.Errorf("got e.mu.neigh.State = %q, want = %q", state, Incomplete) - } - } - - // The first probe was sent in the transition from Unknown to Incomplete. - clock.Advance(c.RetransmitTimer) - - // Observe each subsequent multicast probe transmitted. - for i := uint32(1); i < c.MaxMulticastProbes; i++ { - wantProbes := []entryTestProbeInfo{{ - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: "", - LocalAddress: entryTestAddr2, - }} - linkRes.mu.Lock() - diff := cmp.Diff(wantProbes, linkRes.mu.probes) - linkRes.mu.probes = nil - linkRes.mu.Unlock() - if diff != "" { - return fmt.Errorf("link address resolver probe #%d mismatch (-want, +got):\n%s", i+1, diff) - } - - e.mu.Lock() - state := e.mu.neigh.State - e.mu.Unlock() - if state != Incomplete { - return fmt.Errorf("got e.mu.neigh.State = %q, want = %q", state, Incomplete) - } - - clock.Advance(c.RetransmitTimer) - } - - { - e.mu.Lock() - state := e.mu.neigh.State - e.mu.Unlock() - if state != Unreachable { - return fmt.Errorf("got e.mu.neigh.State = %q, want = %q", state, Unreachable) - } - } - - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Unreachable, - UpdatedAt: clock.Now(), - }, - }, - } - nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, nudDisp.mu.events) - nudDisp.mu.events = nil - nudDisp.mu.Unlock() - if diff != "" { - return fmt.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - - return nil -} - -type testLocker struct{} - -var _ sync.Locker = (*testLocker)(nil) - -func (*testLocker) Lock() {} -func (*testLocker) Unlock() {} - -func TestEntryReachableToReachableClearsRouterWhenConfirmationWithoutRouter(t *testing.T) { - c := DefaultNUDConfigurations() - e, nudDisp, linkRes, clock := entryTestSetup(c) - - if err := unknownToIncomplete(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("unknownToIncomplete(...) = %s", err) - } - if err := incompleteToReachableWithRouterFlag(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("incompleteToReachableWithRouterFlag(...) = %s", err) - } - - e.mu.Lock() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) - if e.mu.neigh.State != Reachable { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Reachable) - } - if got, want := e.mu.isRouter, false; got != want { - t.Errorf("got e.mu.isRouter = %t, want = %t", got, want) - } - ipv6EP := e.cache.nic.networkEndpoints[header.IPv6ProtocolNumber].(*testIPv6Endpoint) - if ipv6EP.invalidatedRtr != e.mu.neigh.Addr { - t.Errorf("got ipv6EP.invalidatedRtr = %s, want = %s", ipv6EP.invalidatedRtr, e.mu.neigh.Addr) - } - e.mu.Unlock() - - // No probes should have been sent. - runImmediatelyScheduledJobs(clock) - { - linkRes.mu.Lock() - diff := cmp.Diff([]entryTestProbeInfo(nil), linkRes.mu.probes) - linkRes.mu.Unlock() - if diff != "" { - t.Errorf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - } - - // No events should have been dispatched. - nudDisp.mu.Lock() - diff := cmp.Diff([]testEntryEventInfo(nil), nudDisp.mu.events) - nudDisp.mu.Unlock() - if diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } -} - -func TestEntryReachableToReachableWhenProbeWithSameAddress(t *testing.T) { - c := DefaultNUDConfigurations() - e, nudDisp, linkRes, clock := entryTestSetup(c) - - if err := unknownToIncomplete(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("unknownToIncomplete(...) = %s", err) - } - if err := incompleteToReachable(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("incompleteToReachable(...) = %s", err) - } - - e.mu.Lock() - e.handleProbeLocked(entryTestLinkAddr1) - if e.mu.neigh.State != Reachable { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Reachable) - } - if e.mu.neigh.LinkAddr != entryTestLinkAddr1 { - t.Errorf("got e.mu.neigh.LinkAddr = %q, want = %q", e.mu.neigh.LinkAddr, entryTestLinkAddr1) - } - e.mu.Unlock() - - // No probes should have been sent. - runImmediatelyScheduledJobs(clock) - { - linkRes.mu.Lock() - diff := cmp.Diff([]entryTestProbeInfo(nil), linkRes.mu.probes) - linkRes.mu.Unlock() - if diff != "" { - t.Errorf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - } - - // No events should have been dispatched. - nudDisp.mu.Lock() - diff := cmp.Diff([]testEntryEventInfo(nil), nudDisp.mu.events) - nudDisp.mu.Unlock() - if diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } -} - -func TestEntryReachableToStaleWhenTimeout(t *testing.T) { - c := DefaultNUDConfigurations() - // Eliminate random factors from ReachableTime computation so the transition - // from Stale to Reachable will only take BaseReachableTime duration. - c.MinRandomFactor = 1 - c.MaxRandomFactor = 1 - - e, nudDisp, linkRes, clock := entryTestSetup(c) - - if err := unknownToIncomplete(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("unknownToIncomplete(...) = %s", err) - } - if err := incompleteToReachable(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("incompleteToReachable(...) = %s", err) - } - if err := reachableToStale(c, e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("reachableToStale(...) = %s", err) - } -} - -// reachableToStale transitions a neighborEntry in Reachable state to Stale -// state. Depends on the elimination of random factors in the ReachableTime -// computation. -// -// c.MinRandomFactor = 1 -// c.MaxRandomFactor = 1 -func reachableToStale(c NUDConfigurations, e *neighborEntry, nudDisp *testNUDDispatcher, linkRes *entryTestLinkResolver, clock *faketime.ManualClock) error { - // Ensure there are no random factors in the ReachableTime computation. - if c.MinRandomFactor != 1 { - return fmt.Errorf("got c.MinRandomFactor = %f, want = 1", c.MinRandomFactor) - } - if c.MaxRandomFactor != 1 { - return fmt.Errorf("got c.MaxRandomFactor = %f, want = 1", c.MaxRandomFactor) - } - - { - e.mu.Lock() - state := e.mu.neigh.State - e.mu.Unlock() - if state != Reachable { - return fmt.Errorf("got e.mu.neigh.State = %q, want = %q", state, Reachable) - } - } - - clock.Advance(c.BaseReachableTime) - - { - e.mu.Lock() - state := e.mu.neigh.State - e.mu.Unlock() - if state != Stale { - return fmt.Errorf("got e.mu.neigh.State = %q, want = %q", state, Stale) - } - } - - // No probes should have been sent. - runImmediatelyScheduledJobs(clock) - { - linkRes.mu.Lock() - diff := cmp.Diff([]entryTestProbeInfo(nil), linkRes.mu.probes) - linkRes.mu.Unlock() - if diff != "" { - return fmt.Errorf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - } - - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, - UpdatedAt: clock.Now(), - }, - }, - } - { - - nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, nudDisp.mu.events) - nudDisp.mu.events = nil - nudDisp.mu.Unlock() - if diff != "" { - return fmt.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - } - - return nil -} - -func TestEntryReachableToStaleWhenProbeWithDifferentAddress(t *testing.T) { - c := DefaultNUDConfigurations() - e, nudDisp, linkRes, clock := entryTestSetup(c) - - if err := unknownToIncomplete(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("unknownToIncomplete(...) = %s", err) - } - if err := incompleteToReachable(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("incompleteToReachable(...) = %s", err) - } - - e.mu.Lock() - e.handleProbeLocked(entryTestLinkAddr2) - if e.mu.neigh.State != Stale { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale) - } - e.mu.Unlock() - - // No probes should have been sent. - runImmediatelyScheduledJobs(clock) - { - linkRes.mu.Lock() - diff := cmp.Diff([]entryTestProbeInfo(nil), linkRes.mu.probes) - linkRes.mu.Unlock() - if diff != "" { - t.Errorf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - } - - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, - UpdatedAt: clock.Now(), - }, - }, - } - nudDisp.mu.Lock() - if diff := cmp.Diff(wantEvents, nudDisp.mu.events); diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - nudDisp.mu.Unlock() -} - -func TestEntryReachableToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { - c := DefaultNUDConfigurations() - e, nudDisp, linkRes, clock := entryTestSetup(c) - - if err := unknownToIncomplete(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("unknownToIncomplete(...) = %s", err) - } - if err := incompleteToReachable(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("incompleteToReachable(...) = %s", err) - } - - e.mu.Lock() - e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) - if e.mu.neigh.State != Stale { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale) - } - e.mu.Unlock() - - // No probes should have been sent. - runImmediatelyScheduledJobs(clock) - { - linkRes.mu.Lock() - diff := cmp.Diff([]entryTestProbeInfo(nil), linkRes.mu.probes) - linkRes.mu.Unlock() - if diff != "" { - t.Errorf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - } - - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, - UpdatedAt: clock.Now(), - }, - }, - } - nudDisp.mu.Lock() - if diff := cmp.Diff(wantEvents, nudDisp.mu.events); diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - nudDisp.mu.Unlock() -} - -func TestEntryReachableToStaleWhenConfirmationWithDifferentAddressAndOverride(t *testing.T) { - c := DefaultNUDConfigurations() - e, nudDisp, linkRes, clock := entryTestSetup(c) - - if err := unknownToIncomplete(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("unknownToIncomplete(...) = %s", err) - } - if err := incompleteToReachable(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("incompleteToReachable(...) = %s", err) - } - - e.mu.Lock() - e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ - Solicited: false, - Override: true, - IsRouter: false, - }) - if e.mu.neigh.State != Stale { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale) - } - e.mu.Unlock() - - // No probes should have been sent. - runImmediatelyScheduledJobs(clock) - { - linkRes.mu.Lock() - diff := cmp.Diff([]entryTestProbeInfo(nil), linkRes.mu.probes) - linkRes.mu.Unlock() - if diff != "" { - t.Errorf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - } - - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, - UpdatedAt: clock.Now(), - }, - }, - } - nudDisp.mu.Lock() - if diff := cmp.Diff(wantEvents, nudDisp.mu.events); diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - nudDisp.mu.Unlock() -} - -func TestEntryStaleToStaleWhenProbeWithSameAddress(t *testing.T) { - c := DefaultNUDConfigurations() - e, nudDisp, linkRes, clock := entryTestSetup(c) - - if err := unknownToStale(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("unknownToStale(...) = %s", err) - } - - e.mu.Lock() - e.handleProbeLocked(entryTestLinkAddr1) - if e.mu.neigh.State != Stale { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale) - } - if e.mu.neigh.LinkAddr != entryTestLinkAddr1 { - t.Errorf("got e.mu.neigh.LinkAddr = %q, want = %q", e.mu.neigh.LinkAddr, entryTestLinkAddr1) - } - e.mu.Unlock() - - // No probes should have been sent. - runImmediatelyScheduledJobs(clock) - { - linkRes.mu.Lock() - diff := cmp.Diff([]entryTestProbeInfo(nil), linkRes.mu.probes) - linkRes.mu.Unlock() - if diff != "" { - t.Errorf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - } - - // No events should have been dispatched. - nudDisp.mu.Lock() - if diff := cmp.Diff([]testEntryEventInfo(nil), nudDisp.mu.events); diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - nudDisp.mu.Unlock() -} - -func TestEntryStaleToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { - c := DefaultNUDConfigurations() - e, nudDisp, linkRes, clock := entryTestSetup(c) - - if err := unknownToStale(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("unknownToStale(...) = %s", err) - } - - e.mu.Lock() - e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ - Solicited: true, - Override: true, - IsRouter: false, - }) - if e.mu.neigh.State != Reachable { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Reachable) - } - if e.mu.neigh.LinkAddr != entryTestLinkAddr2 { - t.Errorf("got e.mu.neigh.LinkAddr = %q, want = %q", e.mu.neigh.LinkAddr, entryTestLinkAddr2) - } - e.mu.Unlock() - - // No probes should have been sent. - runImmediatelyScheduledJobs(clock) - { - linkRes.mu.Lock() - diff := cmp.Diff([]entryTestProbeInfo(nil), linkRes.mu.probes) - linkRes.mu.Unlock() - if diff != "" { - t.Errorf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - } - - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Reachable, - UpdatedAt: clock.Now(), - }, - }, - } - nudDisp.mu.Lock() - if diff := cmp.Diff(wantEvents, nudDisp.mu.events); diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - nudDisp.mu.Unlock() -} - -func TestEntryStaleToReachableWhenSolicitedConfirmationWithoutAddress(t *testing.T) { - c := DefaultNUDConfigurations() - e, nudDisp, linkRes, clock := entryTestSetup(c) - - if err := unknownToStale(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("unknownToStale(...) = %s", err) - } - - e.mu.Lock() - e.handleConfirmationLocked("" /* linkAddr */, ReachabilityConfirmationFlags{ - Solicited: true, - Override: false, - IsRouter: false, - }) - if e.mu.neigh.State != Reachable { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Reachable) - } - if e.mu.neigh.LinkAddr != entryTestLinkAddr1 { - t.Errorf("got e.mu.neigh.LinkAddr = %q, want = %q", e.mu.neigh.LinkAddr, entryTestLinkAddr1) - } - e.mu.Unlock() - - // No probes should have been sent. - runImmediatelyScheduledJobs(clock) - { - linkRes.mu.Lock() - diff := cmp.Diff([]entryTestProbeInfo(nil), linkRes.mu.probes) - linkRes.mu.Unlock() - if diff != "" { - t.Errorf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - } - - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, - UpdatedAt: clock.Now(), - }, - }, - } - nudDisp.mu.Lock() - if diff := cmp.Diff(wantEvents, nudDisp.mu.events); diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - nudDisp.mu.Unlock() -} - -func TestEntryStaleToStaleWhenOverrideConfirmation(t *testing.T) { - c := DefaultNUDConfigurations() - e, nudDisp, linkRes, clock := entryTestSetup(c) - - if err := unknownToStale(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("unknownToStale(...) = %s", err) - } - - e.mu.Lock() - e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ - Solicited: false, - Override: true, - IsRouter: false, - }) - if e.mu.neigh.State != Stale { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale) - } - if e.mu.neigh.LinkAddr != entryTestLinkAddr2 { - t.Errorf("got e.mu.neigh.LinkAddr = %q, want = %q", e.mu.neigh.LinkAddr, entryTestLinkAddr2) - } - e.mu.Unlock() - - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, - UpdatedAt: clock.Now(), - }, - }, - } - nudDisp.mu.Lock() - if diff := cmp.Diff(wantEvents, nudDisp.mu.events); diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - nudDisp.mu.Unlock() -} - -func TestEntryStaleToStaleWhenProbeUpdateAddress(t *testing.T) { - c := DefaultNUDConfigurations() - e, nudDisp, linkRes, clock := entryTestSetup(c) - - if err := unknownToStale(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("unknownToStale(...) = %s", err) - } - - e.mu.Lock() - e.handleProbeLocked(entryTestLinkAddr2) - if e.mu.neigh.State != Stale { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale) - } - if e.mu.neigh.LinkAddr != entryTestLinkAddr2 { - t.Errorf("got e.mu.neigh.LinkAddr = %q, want = %q", e.mu.neigh.LinkAddr, entryTestLinkAddr2) - } - e.mu.Unlock() - - // No probes should have been sent. - runImmediatelyScheduledJobs(clock) - { - linkRes.mu.Lock() - diff := cmp.Diff([]entryTestProbeInfo(nil), linkRes.mu.probes) - linkRes.mu.Unlock() - if diff != "" { - t.Errorf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - } - - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, - UpdatedAt: clock.Now(), - }, - }, - } - nudDisp.mu.Lock() - if diff := cmp.Diff(wantEvents, nudDisp.mu.events); diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - nudDisp.mu.Unlock() -} - -func TestEntryStaleToDelay(t *testing.T) { - c := DefaultNUDConfigurations() - e, nudDisp, linkRes, clock := entryTestSetup(c) - - if err := unknownToStale(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("unknownToStale(...) = %s", err) - } - if err := staleToDelay(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("staleToDelay(...) = %s", err) - } -} - -func staleToDelay(e *neighborEntry, nudDisp *testNUDDispatcher, linkRes *entryTestLinkResolver, clock *faketime.ManualClock) error { - if err := func() error { - e.mu.Lock() - defer e.mu.Unlock() - - if e.mu.neigh.State != Stale { - return fmt.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale) - } - e.handlePacketQueuedLocked(entryTestAddr2) - if e.mu.neigh.State != Delay { - return fmt.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Delay) - } - return nil - }(); err != nil { - return err - } - - // No probes should have been sent. - runImmediatelyScheduledJobs(clock) - { - linkRes.mu.Lock() - diff := cmp.Diff([]entryTestProbeInfo(nil), linkRes.mu.probes) - linkRes.mu.Unlock() - if diff != "" { - return fmt.Errorf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - } - - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, - UpdatedAt: clock.Now(), - }, - }, - } - nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, nudDisp.mu.events) - nudDisp.mu.events = nil - nudDisp.mu.Unlock() - if diff != "" { - return fmt.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - - return nil -} - -func TestEntryDelayToReachableWhenUpperLevelConfirmation(t *testing.T) { - c := DefaultNUDConfigurations() - e, nudDisp, linkRes, clock := entryTestSetup(c) - - if err := unknownToStale(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("unknownToStale(...) = %s", err) - } - if err := staleToDelay(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("staleToDelay(...) = %s", err) - } - - e.mu.Lock() - e.handleUpperLevelConfirmationLocked() - if e.mu.neigh.State != Reachable { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Reachable) - } - e.mu.Unlock() - - // No probes should have been sent. - runImmediatelyScheduledJobs(clock) - { - linkRes.mu.Lock() - diff := cmp.Diff([]entryTestProbeInfo(nil), linkRes.mu.probes) - linkRes.mu.Unlock() - if diff != "" { - t.Errorf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - } - - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, - UpdatedAt: clock.Now(), - }, - }, - } - nudDisp.mu.Lock() - if diff := cmp.Diff(wantEvents, nudDisp.mu.events); diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - nudDisp.mu.Unlock() -} - -func TestEntryDelayToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { - c := DefaultNUDConfigurations() - e, nudDisp, linkRes, clock := entryTestSetup(c) - - if err := unknownToStale(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("unknownToStale(...) = %s", err) - } - if err := staleToDelay(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("staleToDelay(...) = %s", err) - } - - e.mu.Lock() - e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ - Solicited: true, - Override: true, - IsRouter: false, - }) - if e.mu.neigh.State != Reachable { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Reachable) - } - if e.mu.neigh.LinkAddr != entryTestLinkAddr2 { - t.Errorf("got e.mu.neigh.LinkAddr = %q, want = %q", e.mu.neigh.LinkAddr, entryTestLinkAddr2) - } - e.mu.Unlock() - - // No probes should have been sent. - runImmediatelyScheduledJobs(clock) - { - linkRes.mu.Lock() - diff := cmp.Diff([]entryTestProbeInfo(nil), linkRes.mu.probes) - linkRes.mu.Unlock() - if diff != "" { - t.Errorf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - } - - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Reachable, - UpdatedAt: clock.Now(), - }, - }, - } - nudDisp.mu.Lock() - if diff := cmp.Diff(wantEvents, nudDisp.mu.events); diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - nudDisp.mu.Unlock() -} - -func TestEntryDelayToReachableWhenSolicitedConfirmationWithoutAddress(t *testing.T) { - c := DefaultNUDConfigurations() - e, nudDisp, linkRes, clock := entryTestSetup(c) - - if err := unknownToStale(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("unknownToStale(...) = %s", err) - } - if err := staleToDelay(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("staleToDelay(...) = %s", err) - } - - e.mu.Lock() - e.handleConfirmationLocked("" /* linkAddr */, ReachabilityConfirmationFlags{ - Solicited: true, - Override: false, - IsRouter: false, - }) - if e.mu.neigh.State != Reachable { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Reachable) - } - if e.mu.neigh.LinkAddr != entryTestLinkAddr1 { - t.Errorf("got e.mu.neigh.LinkAddr = %q, want = %q", e.mu.neigh.LinkAddr, entryTestLinkAddr1) - } - e.mu.Unlock() - - // No probes should have been sent. - runImmediatelyScheduledJobs(clock) - { - linkRes.mu.Lock() - diff := cmp.Diff([]entryTestProbeInfo(nil), linkRes.mu.probes) - linkRes.mu.Unlock() - if diff != "" { - t.Errorf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - } - - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, - UpdatedAt: clock.Now(), - }, - }, - } - nudDisp.mu.Lock() - if diff := cmp.Diff(wantEvents, nudDisp.mu.events); diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - nudDisp.mu.Unlock() -} - -func TestEntryDelayToDelayWhenOverrideConfirmationWithSameAddress(t *testing.T) { - c := DefaultNUDConfigurations() - e, nudDisp, linkRes, clock := entryTestSetup(c) - - if err := unknownToStale(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("unknownToStale(...) = %s", err) - } - if err := staleToDelay(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("staleToDelay(...) = %s", err) - } - - e.mu.Lock() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: true, - IsRouter: false, - }) - if e.mu.neigh.State != Delay { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Delay) - } - if e.mu.neigh.LinkAddr != entryTestLinkAddr1 { - t.Errorf("got e.mu.neigh.LinkAddr = %q, want = %q", e.mu.neigh.LinkAddr, entryTestLinkAddr1) - } - e.mu.Unlock() - - // No probes should have been sent. - runImmediatelyScheduledJobs(clock) - { - linkRes.mu.Lock() - diff := cmp.Diff([]entryTestProbeInfo(nil), linkRes.mu.probes) - linkRes.mu.Unlock() - if diff != "" { - t.Errorf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - } - - // No events should have been dispatched. - nudDisp.mu.Lock() - if diff := cmp.Diff([]testEntryEventInfo(nil), nudDisp.mu.events); diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - nudDisp.mu.Unlock() -} - -func TestEntryDelayToStaleWhenProbeWithDifferentAddress(t *testing.T) { - c := DefaultNUDConfigurations() - e, nudDisp, linkRes, clock := entryTestSetup(c) - - if err := unknownToStale(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("unknownToStale(...) = %s", err) - } - if err := staleToDelay(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("staleToDelay(...) = %s", err) - } - - e.mu.Lock() - e.handleProbeLocked(entryTestLinkAddr2) - if e.mu.neigh.State != Stale { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale) - } - e.mu.Unlock() - - // No probes should have been sent. - runImmediatelyScheduledJobs(clock) - { - linkRes.mu.Lock() - diff := cmp.Diff([]entryTestProbeInfo(nil), linkRes.mu.probes) - linkRes.mu.Unlock() - if diff != "" { - t.Errorf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - } - - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, - UpdatedAt: clock.Now(), - }, - }, - } - nudDisp.mu.Lock() - if diff := cmp.Diff(wantEvents, nudDisp.mu.events); diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - nudDisp.mu.Unlock() -} - -func TestEntryDelayToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { - c := DefaultNUDConfigurations() - e, nudDisp, linkRes, clock := entryTestSetup(c) - - if err := unknownToStale(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("unknownToStale(...) = %s", err) - } - if err := staleToDelay(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("staleToDelay(...) = %s", err) - } - - e.mu.Lock() - e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ - Solicited: false, - Override: true, - IsRouter: false, - }) - if e.mu.neigh.State != Stale { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale) - } - e.mu.Unlock() - - // No probes should have been sent. - runImmediatelyScheduledJobs(clock) - { - linkRes.mu.Lock() - diff := cmp.Diff([]entryTestProbeInfo(nil), linkRes.mu.probes) - linkRes.mu.Unlock() - if diff != "" { - t.Errorf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - } - - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, - UpdatedAt: clock.Now(), - }, - }, - } - nudDisp.mu.Lock() - if diff := cmp.Diff(wantEvents, nudDisp.mu.events); diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - nudDisp.mu.Unlock() -} - -func TestEntryDelayToProbe(t *testing.T) { - c := DefaultNUDConfigurations() - e, nudDisp, linkRes, clock := entryTestSetup(c) - - if err := unknownToStale(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("unknownToStale(...) = %s", err) - } - if err := staleToDelay(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("staleToDelay(...) = %s", err) - } - if err := delayToProbe(c, e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("delayToProbe(...) = %s", err) - } -} - -func delayToProbe(c NUDConfigurations, e *neighborEntry, nudDisp *testNUDDispatcher, linkRes *entryTestLinkResolver, clock *faketime.ManualClock) error { - { - e.mu.Lock() - state := e.mu.neigh.State - e.mu.Unlock() - if state != Delay { - return fmt.Errorf("got e.mu.neigh.State = %q, want = %q", state, Delay) - } - } - - // Wait for the first unicast probe to be transmitted, marking the - // transition from Delay to Probe. - clock.Advance(c.DelayFirstProbeTime) - - { - e.mu.Lock() - state := e.mu.neigh.State - e.mu.Unlock() - if state != Probe { - return fmt.Errorf("got e.mu.neigh.State = %q, want = %q", state, Probe) - } - } - - wantProbes := []entryTestProbeInfo{ - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - }, - } - { - linkRes.mu.Lock() - diff := cmp.Diff(wantProbes, linkRes.mu.probes) - linkRes.mu.probes = nil - linkRes.mu.Unlock() - if diff != "" { - return fmt.Errorf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - } - - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Probe, - UpdatedAt: clock.Now(), - }, - }, - } - { - nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, nudDisp.mu.events) - nudDisp.mu.events = nil - nudDisp.mu.Unlock() - if diff != "" { - return fmt.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - } - - return nil -} - -func TestEntryProbeToStaleWhenProbeWithDifferentAddress(t *testing.T) { - c := DefaultNUDConfigurations() - e, nudDisp, linkRes, clock := entryTestSetup(c) - - if err := unknownToStale(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("unknownToStale(...) = %s", err) - } - if err := staleToDelay(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("staleToDelay(...) = %s", err) - } - if err := delayToProbe(c, e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("delayToProbe(...) = %s", err) - } - - e.mu.Lock() - e.handleProbeLocked(entryTestLinkAddr2) - if e.mu.neigh.State != Stale { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale) - } - e.mu.Unlock() - - // No probes should have been sent. - runImmediatelyScheduledJobs(clock) - { - linkRes.mu.Lock() - diff := cmp.Diff([]entryTestProbeInfo(nil), linkRes.mu.probes) - linkRes.mu.Unlock() - if diff != "" { - t.Errorf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - } - - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, - UpdatedAt: clock.Now(), - }, - }, - } - nudDisp.mu.Lock() - if diff := cmp.Diff(wantEvents, nudDisp.mu.events); diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - nudDisp.mu.Unlock() -} - -func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { - c := DefaultNUDConfigurations() - e, nudDisp, linkRes, clock := entryTestSetup(c) - - if err := unknownToStale(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("unknownToStale(...) = %s", err) - } - if err := staleToDelay(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("staleToDelay(...) = %s", err) - } - if err := delayToProbe(c, e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("delayToProbe(...) = %s", err) - } - - e.mu.Lock() - e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ - Solicited: false, - Override: true, - IsRouter: false, - }) - if e.mu.neigh.State != Stale { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale) - } - e.mu.Unlock() - - // No probes should have been sent. - runImmediatelyScheduledJobs(clock) - { - linkRes.mu.Lock() - diff := cmp.Diff([]entryTestProbeInfo(nil), linkRes.mu.probes) - linkRes.mu.Unlock() - if diff != "" { - t.Errorf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - } - - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, - UpdatedAt: clock.Now(), - }, - }, - } - nudDisp.mu.Lock() - if diff := cmp.Diff(wantEvents, nudDisp.mu.events); diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - nudDisp.mu.Unlock() -} - -func TestEntryProbeToProbeWhenOverrideConfirmationWithSameAddress(t *testing.T) { - c := DefaultNUDConfigurations() - e, nudDisp, linkRes, clock := entryTestSetup(c) - - if err := unknownToStale(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("unknownToStale(...) = %s", err) - } - if err := staleToDelay(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("staleToDelay(...) = %s", err) - } - if err := delayToProbe(c, e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("delayToProbe(...) = %s", err) - } - - e.mu.Lock() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: true, - IsRouter: false, - }) - if e.mu.neigh.State != Probe { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Probe) - } - if got, want := e.mu.neigh.LinkAddr, entryTestLinkAddr1; got != want { - t.Errorf("got e.mu.neigh.LinkAddr = %q, want = %q", got, want) - } - e.mu.Unlock() - - // No probes should have been sent. - runImmediatelyScheduledJobs(clock) - { - linkRes.mu.Lock() - diff := cmp.Diff([]entryTestProbeInfo(nil), linkRes.mu.probes) - linkRes.mu.Unlock() - if diff != "" { - t.Errorf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - } - - // No events should have been dispatched. - nudDisp.mu.Lock() - diff := cmp.Diff([]testEntryEventInfo(nil), nudDisp.mu.events) - nudDisp.mu.Unlock() - if diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } -} - -func TestEntryProbeToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { - c := DefaultNUDConfigurations() - e, nudDisp, linkRes, clock := entryTestSetup(c) - - if err := unknownToStale(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("unknownToStale(...) = %s", err) - } - if err := staleToDelay(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("staleToDelay(...) = %s", err) - } - if err := delayToProbe(c, e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("delayToProbe(...) = %s", err) - } - if err := probeToReachableWithOverride(e, nudDisp, linkRes, clock, entryTestLinkAddr2); err != nil { - t.Fatalf("probeToReachableWithOverride(...) = %s", err) - } -} - -func probeToReachableWithFlags(e *neighborEntry, nudDisp *testNUDDispatcher, linkRes *entryTestLinkResolver, clock *faketime.ManualClock, linkAddr tcpip.LinkAddress, flags ReachabilityConfirmationFlags) error { - if err := func() error { - e.mu.Lock() - defer e.mu.Unlock() - - prevLinkAddr := e.mu.neigh.LinkAddr - if e.mu.neigh.State != Probe { - return fmt.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Probe) - } - e.handleConfirmationLocked(linkAddr, flags) - - if e.mu.neigh.State != Reachable { - return fmt.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Reachable) - } - if linkAddr == "" { - linkAddr = prevLinkAddr - } - if e.mu.neigh.LinkAddr != linkAddr { - return fmt.Errorf("got e.mu.neigh.LinkAddr = %q, want = %q", e.mu.neigh.LinkAddr, linkAddr) - } - return nil - }(); err != nil { - return err - } - - // No probes should have been sent. - runImmediatelyScheduledJobs(clock) - { - linkRes.mu.Lock() - diff := cmp.Diff([]entryTestProbeInfo(nil), linkRes.mu.probes) - linkRes.mu.Unlock() - if diff != "" { - return fmt.Errorf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - } - - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: linkAddr, - State: Reachable, - UpdatedAt: clock.Now(), - }, - }, - } - { - nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, nudDisp.mu.events) - nudDisp.mu.events = nil - nudDisp.mu.Unlock() - if diff != "" { - return fmt.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - } - - return nil -} - -func probeToReachableWithOverride(e *neighborEntry, nudDisp *testNUDDispatcher, linkRes *entryTestLinkResolver, clock *faketime.ManualClock, linkAddr tcpip.LinkAddress) error { - return probeToReachableWithFlags(e, nudDisp, linkRes, clock, linkAddr, ReachabilityConfirmationFlags{ - Solicited: true, - Override: true, - IsRouter: false, - }) -} - -func TestEntryProbeToReachableWhenSolicitedConfirmationWithSameAddress(t *testing.T) { - c := DefaultNUDConfigurations() - e, nudDisp, linkRes, clock := entryTestSetup(c) - - if err := unknownToStale(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("unknownToStale(...) = %s", err) - } - if err := staleToDelay(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("staleToDelay(...) = %s", err) - } - if err := delayToProbe(c, e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("delayToProbe(...) = %s", err) - } - if err := probeToReachable(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("probeToReachable(...) = %s", err) - } -} - -func probeToReachable(e *neighborEntry, nudDisp *testNUDDispatcher, linkRes *entryTestLinkResolver, clock *faketime.ManualClock) error { - return probeToReachableWithFlags(e, nudDisp, linkRes, clock, entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: true, - Override: false, - IsRouter: false, - }) -} - -func TestEntryProbeToReachableWhenSolicitedConfirmationWithoutAddress(t *testing.T) { - c := DefaultNUDConfigurations() - e, nudDisp, linkRes, clock := entryTestSetup(c) - - if err := unknownToStale(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("unknownToStale(...) = %s", err) - } - if err := staleToDelay(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("staleToDelay(...) = %s", err) - } - if err := delayToProbe(c, e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("delayToProbe(...) = %s", err) - } - if err := probeToReachableWithoutAddress(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("probeToReachableWithoutAddress(...) = %s", err) - } -} - -func probeToReachableWithoutAddress(e *neighborEntry, nudDisp *testNUDDispatcher, linkRes *entryTestLinkResolver, clock *faketime.ManualClock) error { - return probeToReachableWithFlags(e, nudDisp, linkRes, clock, "" /* linkAddr */, ReachabilityConfirmationFlags{ - Solicited: true, - Override: false, - IsRouter: false, - }) -} - -func TestEntryProbeToUnreachable(t *testing.T) { - c := DefaultNUDConfigurations() - c.MaxMulticastProbes = 3 - c.MaxUnicastProbes = 3 - c.DelayFirstProbeTime = c.RetransmitTimer - e, nudDisp, linkRes, clock := entryTestSetup(c) - - if err := unknownToStale(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("unknownToStale(...) = %s", err) - } - if err := staleToDelay(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("staleToDelay(...) = %s", err) - } - if err := delayToProbe(c, e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("delayToProbe(...) = %s", err) - } - if err := probeToUnreachable(c, e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("probeToUnreachable(...) = %s", err) - } -} - -func probeToUnreachable(c NUDConfigurations, e *neighborEntry, nudDisp *testNUDDispatcher, linkRes *entryTestLinkResolver, clock *faketime.ManualClock) error { - { - e.mu.Lock() - state := e.mu.neigh.State - e.mu.Unlock() - if state != Probe { - return fmt.Errorf("got e.mu.neigh.State = %q, want = %q", state, Probe) - } - } - - // The first probe was sent in the transition from Delay to Probe. - clock.Advance(c.RetransmitTimer) - - // Observe each subsequent unicast probe transmitted. - for i := uint32(1); i < c.MaxUnicastProbes; i++ { - wantProbes := []entryTestProbeInfo{{ - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - }} - linkRes.mu.Lock() - diff := cmp.Diff(wantProbes, linkRes.mu.probes) - linkRes.mu.probes = nil - linkRes.mu.Unlock() - if diff != "" { - return fmt.Errorf("link address resolver probe #%d mismatch (-want, +got):\n%s", i+1, diff) - } - - e.mu.Lock() - state := e.mu.neigh.State - e.mu.Unlock() - if state != Probe { - return fmt.Errorf("got e.mu.neigh.State = %q, want = %q", state, Probe) - } - - clock.Advance(c.RetransmitTimer) - } - - { - e.mu.Lock() - state := e.mu.neigh.State - e.mu.Unlock() - if state != Unreachable { - return fmt.Errorf("got e.mu.neigh.State = %q, want = %q", state, Unreachable) - } - } - - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Unreachable, - UpdatedAt: clock.Now(), - }, - }, - } - nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, nudDisp.mu.events) - nudDisp.mu.events = nil - nudDisp.mu.Unlock() - if diff != "" { - return fmt.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - - return nil -} - -func TestEntryUnreachableToIncomplete(t *testing.T) { - c := DefaultNUDConfigurations() - c.MaxMulticastProbes = 3 - e, nudDisp, linkRes, clock := entryTestSetup(c) - - if err := unknownToIncomplete(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("unknownToIncomplete(...) = %s", err) - } - if err := incompleteToUnreachable(c, e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("incompleteToUnreachable(...) = %s", err) - } - if err := unreachableToIncomplete(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("unreachableToIncomplete(...) = %s", err) - } -} - -func unreachableToIncomplete(e *neighborEntry, nudDisp *testNUDDispatcher, linkRes *entryTestLinkResolver, clock *faketime.ManualClock) error { - if err := func() error { - e.mu.Lock() - defer e.mu.Unlock() - - if e.mu.neigh.State != Unreachable { - return fmt.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Unreachable) - } - e.handlePacketQueuedLocked(entryTestAddr2) - if e.mu.neigh.State != Incomplete { - return fmt.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Incomplete) - } - return nil - }(); err != nil { - return err - } - - runImmediatelyScheduledJobs(clock) - wantProbes := []entryTestProbeInfo{ - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(wantProbes, linkRes.mu.probes) - linkRes.mu.probes = nil - linkRes.mu.Unlock() - if diff != "" { - return fmt.Errorf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, - UpdatedAt: clock.Now(), - }, - }, - } - { - nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, nudDisp.mu.events) - nudDisp.mu.events = nil - nudDisp.mu.Unlock() - if diff != "" { - return fmt.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - } - return nil -} - -func TestEntryUnreachableToStale(t *testing.T) { - c := DefaultNUDConfigurations() - c.MaxMulticastProbes = 3 - // Eliminate random factors from ReachableTime computation so the transition - // from Stale to Reachable will only take BaseReachableTime duration. - c.MinRandomFactor = 1 - c.MaxRandomFactor = 1 - - e, nudDisp, linkRes, clock := entryTestSetup(c) - - if err := unknownToIncomplete(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("unknownToIncomplete(...) = %s", err) - } - if err := incompleteToUnreachable(c, e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("incompleteToUnreachable(...) = %s", err) - } - if err := unreachableToIncomplete(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("unreachableToIncomplete(...) = %s", err) - } - if err := incompleteToReachable(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("incompleteToReachable(...) = %s", err) - } - if err := reachableToStale(c, e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("reachableToStale(...) = %s", err) - } -} diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go deleted file mode 100644 index 5cb342f78..000000000 --- a/pkg/tcpip/stack/nic_test.go +++ /dev/null @@ -1,224 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package stack - -import ( - "reflect" - "testing" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/testutil" -) - -var _ AddressableEndpoint = (*testIPv6Endpoint)(nil) -var _ NetworkEndpoint = (*testIPv6Endpoint)(nil) -var _ NDPEndpoint = (*testIPv6Endpoint)(nil) - -// An IPv6 NetworkEndpoint that throws away outgoing packets. -// -// We use this instead of ipv6.endpoint because the ipv6 package depends on -// the stack package which this test lives in, causing a cyclic dependency. -type testIPv6Endpoint struct { - AddressableEndpointState - - nic NetworkInterface - protocol *testIPv6Protocol - - invalidatedRtr tcpip.Address -} - -func (*testIPv6Endpoint) Enable() tcpip.Error { - return nil -} - -func (*testIPv6Endpoint) Enabled() bool { - return true -} - -func (*testIPv6Endpoint) Disable() {} - -// DefaultTTL implements NetworkEndpoint.DefaultTTL. -func (*testIPv6Endpoint) DefaultTTL() uint8 { - return 0 -} - -// MTU implements NetworkEndpoint.MTU. -func (e *testIPv6Endpoint) MTU() uint32 { - return e.nic.MTU() - header.IPv6MinimumSize -} - -// MaxHeaderLength implements NetworkEndpoint.MaxHeaderLength. -func (e *testIPv6Endpoint) MaxHeaderLength() uint16 { - return e.nic.MaxHeaderLength() + header.IPv6MinimumSize -} - -// WritePacket implements NetworkEndpoint.WritePacket. -func (*testIPv6Endpoint) WritePacket(*Route, NetworkHeaderParams, *PacketBuffer) tcpip.Error { - return nil -} - -// WritePackets implements NetworkEndpoint.WritePackets. -func (*testIPv6Endpoint) WritePackets(*Route, PacketBufferList, NetworkHeaderParams) (int, tcpip.Error) { - // Our tests don't use this so we don't support it. - return 0, &tcpip.ErrNotSupported{} -} - -// WriteHeaderIncludedPacket implements -// NetworkEndpoint.WriteHeaderIncludedPacket. -func (*testIPv6Endpoint) WriteHeaderIncludedPacket(*Route, *PacketBuffer) tcpip.Error { - // Our tests don't use this so we don't support it. - return &tcpip.ErrNotSupported{} -} - -// HandlePacket implements NetworkEndpoint.HandlePacket. -func (*testIPv6Endpoint) HandlePacket(*PacketBuffer) {} - -// Close implements NetworkEndpoint.Close. -func (e *testIPv6Endpoint) Close() { - e.AddressableEndpointState.Cleanup() -} - -// NetworkProtocolNumber implements NetworkEndpoint.NetworkProtocolNumber. -func (*testIPv6Endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { - return header.IPv6ProtocolNumber -} - -func (e *testIPv6Endpoint) InvalidateDefaultRouter(rtr tcpip.Address) { - e.invalidatedRtr = rtr -} - -// Stats implements NetworkEndpoint. -func (*testIPv6Endpoint) Stats() NetworkEndpointStats { - return &testIPv6EndpointStats{} -} - -var _ NetworkEndpointStats = (*testIPv6EndpointStats)(nil) - -type testIPv6EndpointStats struct{} - -// IsNetworkEndpointStats implements stack.NetworkEndpointStats. -func (*testIPv6EndpointStats) IsNetworkEndpointStats() {} - -// We use this instead of ipv6.protocol because the ipv6 package depends on -// the stack package which this test lives in, causing a cyclic dependency. -type testIPv6Protocol struct{} - -// Number implements NetworkProtocol.Number. -func (*testIPv6Protocol) Number() tcpip.NetworkProtocolNumber { - return header.IPv6ProtocolNumber -} - -// MinimumPacketSize implements NetworkProtocol.MinimumPacketSize. -func (*testIPv6Protocol) MinimumPacketSize() int { - return header.IPv6MinimumSize -} - -// DefaultPrefixLen implements NetworkProtocol.DefaultPrefixLen. -func (*testIPv6Protocol) DefaultPrefixLen() int { - return header.IPv6AddressSize * 8 -} - -// ParseAddresses implements NetworkProtocol.ParseAddresses. -func (*testIPv6Protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { - h := header.IPv6(v) - return h.SourceAddress(), h.DestinationAddress() -} - -// NewEndpoint implements NetworkProtocol.NewEndpoint. -func (p *testIPv6Protocol) NewEndpoint(nic NetworkInterface, _ TransportDispatcher) NetworkEndpoint { - e := &testIPv6Endpoint{ - nic: nic, - protocol: p, - } - e.AddressableEndpointState.Init(e) - return e -} - -// SetOption implements NetworkProtocol.SetOption. -func (*testIPv6Protocol) SetOption(tcpip.SettableNetworkProtocolOption) tcpip.Error { - return nil -} - -// Option implements NetworkProtocol.Option. -func (*testIPv6Protocol) Option(tcpip.GettableNetworkProtocolOption) tcpip.Error { - return nil -} - -// Close implements NetworkProtocol.Close. -func (*testIPv6Protocol) Close() {} - -// Wait implements NetworkProtocol.Wait. -func (*testIPv6Protocol) Wait() {} - -// Parse implements NetworkProtocol.Parse. -func (*testIPv6Protocol) Parse(*PacketBuffer) (tcpip.TransportProtocolNumber, bool, bool) { - return 0, false, false -} - -func TestDisabledRxStatsWhenNICDisabled(t *testing.T) { - // When the NIC is disabled, the only field that matters is the stats field. - // This test is limited to stats counter checks. - nic := nic{ - stats: makeNICStats(tcpip.NICStats{}.FillIn()), - } - - if got := nic.stats.local.DisabledRx.Packets.Value(); got != 0 { - t.Errorf("got DisabledRx.Packets = %d, want = 0", got) - } - if got := nic.stats.local.DisabledRx.Bytes.Value(); got != 0 { - t.Errorf("got DisabledRx.Bytes = %d, want = 0", got) - } - if got := nic.stats.local.Rx.Packets.Value(); got != 0 { - t.Errorf("got Rx.Packets = %d, want = 0", got) - } - if got := nic.stats.local.Rx.Bytes.Value(); got != 0 { - t.Errorf("got Rx.Bytes = %d, want = 0", got) - } - - if t.Failed() { - t.FailNow() - } - - nic.DeliverNetworkPacket("", "", 0, NewPacketBuffer(PacketBufferOptions{ - Data: buffer.View([]byte{1, 2, 3, 4}).ToVectorisedView(), - })) - - if got := nic.stats.local.DisabledRx.Packets.Value(); got != 1 { - t.Errorf("got DisabledRx.Packets = %d, want = 1", got) - } - if got := nic.stats.local.DisabledRx.Bytes.Value(); got != 4 { - t.Errorf("got DisabledRx.Bytes = %d, want = 4", got) - } - if got := nic.stats.local.Rx.Packets.Value(); got != 0 { - t.Errorf("got Rx.Packets = %d, want = 0", got) - } - if got := nic.stats.local.Rx.Bytes.Value(); got != 0 { - t.Errorf("got Rx.Bytes = %d, want = 0", got) - } -} - -func TestMultiCounterStatsInitialization(t *testing.T) { - global := tcpip.NICStats{}.FillIn() - nic := nic{ - stats: makeNICStats(global), - } - multi := nic.stats.multiCounterNICStats - local := nic.stats.local - if err := testutil.ValidateMultiCounterStats(reflect.ValueOf(&multi).Elem(), []reflect.Value{reflect.ValueOf(&local).Elem(), reflect.ValueOf(&global).Elem()}); err != nil { - t.Error(err) - } -} diff --git a/pkg/tcpip/stack/nud_test.go b/pkg/tcpip/stack/nud_test.go deleted file mode 100644 index 1aeb2f8a5..000000000 --- a/pkg/tcpip/stack/nud_test.go +++ /dev/null @@ -1,816 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package stack_test - -import ( - "math" - "math/rand" - "testing" - "time" - - "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/faketime" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" - "gvisor.dev/gvisor/pkg/tcpip/stack" -) - -const ( - defaultBaseReachableTime = 30 * time.Second - minimumBaseReachableTime = time.Millisecond - defaultMinRandomFactor = 0.5 - defaultMaxRandomFactor = 1.5 - defaultRetransmitTimer = time.Second - minimumRetransmitTimer = time.Millisecond - defaultDelayFirstProbeTime = 5 * time.Second - defaultMaxMulticastProbes = 3 - defaultMaxUnicastProbes = 3 - - defaultFakeRandomNum = 0.5 -) - -// fakeRand is a deterministic random number generator. -type fakeRand struct { - num float32 -} - -var _ rand.Source = (*fakeRand)(nil) - -func (f *fakeRand) Int63() int64 { - return int64(f.num * float32(1<<63)) -} - -func (*fakeRand) Seed(int64) {} - -func TestNUDFunctions(t *testing.T) { - const nicID = 1 - - tests := []struct { - name string - nicID tcpip.NICID - netProtoFactory []stack.NetworkProtocolFactory - extraLinkCapabilities stack.LinkEndpointCapabilities - expectedErr tcpip.Error - }{ - { - name: "Invalid NICID", - nicID: nicID + 1, - netProtoFactory: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, - extraLinkCapabilities: stack.CapabilityResolutionRequired, - expectedErr: &tcpip.ErrUnknownNICID{}, - }, - { - name: "No network protocol", - nicID: nicID, - expectedErr: &tcpip.ErrNotSupported{}, - }, - { - name: "With IPv6", - nicID: nicID, - netProtoFactory: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, - expectedErr: &tcpip.ErrNotSupported{}, - }, - { - name: "With resolution capability", - nicID: nicID, - extraLinkCapabilities: stack.CapabilityResolutionRequired, - expectedErr: &tcpip.ErrNotSupported{}, - }, - { - name: "With IPv6 and resolution capability", - nicID: nicID, - netProtoFactory: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, - extraLinkCapabilities: stack.CapabilityResolutionRequired, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - clock := faketime.NewManualClock() - s := stack.New(stack.Options{ - NUDConfigs: stack.DefaultNUDConfigurations(), - NetworkProtocols: test.netProtoFactory, - Clock: clock, - }) - - e := channel.New(0, 0, linkAddr1) - e.LinkEPCapabilities &^= stack.CapabilityResolutionRequired - e.LinkEPCapabilities |= test.extraLinkCapabilities - - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - - configs := stack.DefaultNUDConfigurations() - configs.BaseReachableTime = time.Hour - - { - err := s.SetNUDConfigurations(test.nicID, ipv6.ProtocolNumber, configs) - if diff := cmp.Diff(test.expectedErr, err); diff != "" { - t.Errorf("s.SetNUDConfigurations(%d, %d, _) error mismatch (-want +got):\n%s", test.nicID, ipv6.ProtocolNumber, diff) - } - } - - { - gotConfigs, err := s.NUDConfigurations(test.nicID, ipv6.ProtocolNumber) - if diff := cmp.Diff(test.expectedErr, err); diff != "" { - t.Errorf("s.NUDConfigurations(%d, %d) error mismatch (-want +got):\n%s", test.nicID, ipv6.ProtocolNumber, diff) - } else if test.expectedErr == nil { - if diff := cmp.Diff(configs, gotConfigs); diff != "" { - t.Errorf("got configs mismatch (-want +got):\n%s", diff) - } - } - } - - for _, addr := range []tcpip.Address{llAddr1, llAddr2} { - { - err := s.AddStaticNeighbor(test.nicID, ipv6.ProtocolNumber, addr, linkAddr1) - if diff := cmp.Diff(test.expectedErr, err); diff != "" { - t.Errorf("s.AddStaticNeighbor(%d, %d, %s, %s) error mismatch (-want +got):\n%s", test.nicID, ipv6.ProtocolNumber, addr, linkAddr1, diff) - } - } - } - - { - wantErr := test.expectedErr - for i := 0; i < 2; i++ { - { - err := s.RemoveNeighbor(test.nicID, ipv6.ProtocolNumber, llAddr1) - if diff := cmp.Diff(wantErr, err); diff != "" { - t.Errorf("s.RemoveNeighbor(%d, %d, '') error mismatch (-want +got):\n%s", test.nicID, ipv6.ProtocolNumber, diff) - } - } - - if test.expectedErr != nil { - break - } - - // Removing a neighbor that does not exist should give us a bad address - // error. - wantErr = &tcpip.ErrBadAddress{} - } - } - - { - neighbors, err := s.Neighbors(test.nicID, ipv6.ProtocolNumber) - if diff := cmp.Diff(test.expectedErr, err); diff != "" { - t.Errorf("s.Neigbors(%d, %d) error mismatch (-want +got):\n%s", test.nicID, ipv6.ProtocolNumber, diff) - } else if test.expectedErr == nil { - if diff := cmp.Diff( - []stack.NeighborEntry{{Addr: llAddr2, LinkAddr: linkAddr1, State: stack.Static, UpdatedAt: clock.Now()}}, - neighbors, - ); diff != "" { - t.Errorf("neighbors mismatch (-want +got):\n%s", diff) - } - } - } - - { - err := s.ClearNeighbors(test.nicID, ipv6.ProtocolNumber) - if diff := cmp.Diff(test.expectedErr, err); diff != "" { - t.Errorf("s.ClearNeigbors(%d, %d) error mismatch (-want +got):\n%s", test.nicID, ipv6.ProtocolNumber, diff) - } else if test.expectedErr == nil { - if neighbors, err := s.Neighbors(test.nicID, ipv6.ProtocolNumber); err != nil { - t.Errorf("s.Neighbors(%d, %d): %s", test.nicID, ipv6.ProtocolNumber, err) - } else if len(neighbors) != 0 { - t.Errorf("got len(neighbors) = %d, want = 0; neighbors = %#v", len(neighbors), neighbors) - } - } - } - }) - } -} - -func TestDefaultNUDConfigurations(t *testing.T) { - const nicID = 1 - - e := channel.New(0, 1280, linkAddr1) - e.LinkEPCapabilities |= stack.CapabilityResolutionRequired - - s := stack.New(stack.Options{ - // A neighbor cache is required to store NUDConfigurations. The networking - // stack will only allocate neighbor caches if a protocol providing link - // address resolution is specified (e.g. ARP or IPv6). - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, - NUDConfigs: stack.DefaultNUDConfigurations(), - }) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - c, err := s.NUDConfigurations(nicID, ipv6.ProtocolNumber) - if err != nil { - t.Fatalf("got stack.NUDConfigurations(%d, %d) = %s", nicID, ipv6.ProtocolNumber, err) - } - if got, want := c, stack.DefaultNUDConfigurations(); got != want { - t.Errorf("got stack.NUDConfigurations(%d, %d) = %+v, want = %+v", nicID, ipv6.ProtocolNumber, got, want) - } -} - -func TestNUDConfigurationsBaseReachableTime(t *testing.T) { - tests := []struct { - name string - baseReachableTime time.Duration - want time.Duration - }{ - // Invalid cases - { - name: "EqualToZero", - baseReachableTime: 0, - want: defaultBaseReachableTime, - }, - // Valid cases - { - name: "MoreThanZero", - baseReachableTime: time.Millisecond, - want: time.Millisecond, - }, - { - name: "MoreThanDefaultBaseReachableTime", - baseReachableTime: 2 * defaultBaseReachableTime, - want: 2 * defaultBaseReachableTime, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - const nicID = 1 - - c := stack.DefaultNUDConfigurations() - c.BaseReachableTime = test.baseReachableTime - - e := channel.New(0, 1280, linkAddr1) - e.LinkEPCapabilities |= stack.CapabilityResolutionRequired - - s := stack.New(stack.Options{ - // A neighbor cache is required to store NUDConfigurations. The - // networking stack will only allocate neighbor caches if a protocol - // providing link address resolution is specified (e.g. ARP or IPv6). - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, - NUDConfigs: c, - }) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - sc, err := s.NUDConfigurations(nicID, ipv6.ProtocolNumber) - if err != nil { - t.Fatalf("got stack.NUDConfigurations(%d, %d) = %s", nicID, ipv6.ProtocolNumber, err) - } - if got := sc.BaseReachableTime; got != test.want { - t.Errorf("got BaseReachableTime = %q, want = %q", got, test.want) - } - }) - } -} - -func TestNUDConfigurationsMinRandomFactor(t *testing.T) { - tests := []struct { - name string - minRandomFactor float32 - want float32 - }{ - // Invalid cases - { - name: "LessThanZero", - minRandomFactor: -1, - want: defaultMinRandomFactor, - }, - { - name: "EqualToZero", - minRandomFactor: 0, - want: defaultMinRandomFactor, - }, - // Valid cases - { - name: "MoreThanZero", - minRandomFactor: 1, - want: 1, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - const nicID = 1 - - c := stack.DefaultNUDConfigurations() - c.MinRandomFactor = test.minRandomFactor - - e := channel.New(0, 1280, linkAddr1) - e.LinkEPCapabilities |= stack.CapabilityResolutionRequired - - s := stack.New(stack.Options{ - // A neighbor cache is required to store NUDConfigurations. The - // networking stack will only allocate neighbor caches if a protocol - // providing link address resolution is specified (e.g. ARP or IPv6). - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, - NUDConfigs: c, - }) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - sc, err := s.NUDConfigurations(nicID, ipv6.ProtocolNumber) - if err != nil { - t.Fatalf("got stack.NUDConfigurations(%d, %d) = %s", nicID, ipv6.ProtocolNumber, err) - } - if got := sc.MinRandomFactor; got != test.want { - t.Errorf("got MinRandomFactor = %f, want = %f", got, test.want) - } - }) - } -} - -func TestNUDConfigurationsMaxRandomFactor(t *testing.T) { - tests := []struct { - name string - minRandomFactor float32 - maxRandomFactor float32 - want float32 - }{ - // Invalid cases - { - name: "LessThanZero", - minRandomFactor: defaultMinRandomFactor, - maxRandomFactor: -1, - want: defaultMaxRandomFactor, - }, - { - name: "EqualToZero", - minRandomFactor: defaultMinRandomFactor, - maxRandomFactor: 0, - want: defaultMaxRandomFactor, - }, - { - name: "LessThanMinRandomFactor", - minRandomFactor: defaultMinRandomFactor, - maxRandomFactor: defaultMinRandomFactor * 0.99, - want: defaultMaxRandomFactor, - }, - { - name: "MoreThanMinRandomFactorWhenMinRandomFactorIsLargerThanMaxRandomFactorDefault", - minRandomFactor: defaultMaxRandomFactor * 2, - maxRandomFactor: defaultMaxRandomFactor, - want: defaultMaxRandomFactor * 6, - }, - // Valid cases - { - name: "EqualToMinRandomFactor", - minRandomFactor: defaultMinRandomFactor, - maxRandomFactor: defaultMinRandomFactor, - want: defaultMinRandomFactor, - }, - { - name: "MoreThanMinRandomFactor", - minRandomFactor: defaultMinRandomFactor, - maxRandomFactor: defaultMinRandomFactor * 1.1, - want: defaultMinRandomFactor * 1.1, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - const nicID = 1 - - c := stack.DefaultNUDConfigurations() - c.MinRandomFactor = test.minRandomFactor - c.MaxRandomFactor = test.maxRandomFactor - - e := channel.New(0, 1280, linkAddr1) - e.LinkEPCapabilities |= stack.CapabilityResolutionRequired - - s := stack.New(stack.Options{ - // A neighbor cache is required to store NUDConfigurations. The - // networking stack will only allocate neighbor caches if a protocol - // providing link address resolution is specified (e.g. ARP or IPv6). - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, - NUDConfigs: c, - }) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - sc, err := s.NUDConfigurations(nicID, ipv6.ProtocolNumber) - if err != nil { - t.Fatalf("got stack.NUDConfigurations(%d, %d) = %s", nicID, ipv6.ProtocolNumber, err) - } - if got := sc.MaxRandomFactor; got != test.want { - t.Errorf("got MaxRandomFactor = %f, want = %f", got, test.want) - } - }) - } -} - -func TestNUDConfigurationsRetransmitTimer(t *testing.T) { - tests := []struct { - name string - retransmitTimer time.Duration - want time.Duration - }{ - // Invalid cases - { - name: "EqualToZero", - retransmitTimer: 0, - want: defaultRetransmitTimer, - }, - { - name: "LessThanMinimumRetransmitTimer", - retransmitTimer: minimumRetransmitTimer - time.Nanosecond, - want: defaultRetransmitTimer, - }, - // Valid cases - { - name: "EqualToMinimumRetransmitTimer", - retransmitTimer: minimumRetransmitTimer, - want: minimumBaseReachableTime, - }, - { - name: "LargetThanMinimumRetransmitTimer", - retransmitTimer: 2 * minimumBaseReachableTime, - want: 2 * minimumBaseReachableTime, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - const nicID = 1 - - c := stack.DefaultNUDConfigurations() - c.RetransmitTimer = test.retransmitTimer - - e := channel.New(0, 1280, linkAddr1) - e.LinkEPCapabilities |= stack.CapabilityResolutionRequired - - s := stack.New(stack.Options{ - // A neighbor cache is required to store NUDConfigurations. The - // networking stack will only allocate neighbor caches if a protocol - // providing link address resolution is specified (e.g. ARP or IPv6). - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, - NUDConfigs: c, - }) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - sc, err := s.NUDConfigurations(nicID, ipv6.ProtocolNumber) - if err != nil { - t.Fatalf("got stack.NUDConfigurations(%d, %d) = %s", nicID, ipv6.ProtocolNumber, err) - } - if got := sc.RetransmitTimer; got != test.want { - t.Errorf("got RetransmitTimer = %q, want = %q", got, test.want) - } - }) - } -} - -func TestNUDConfigurationsDelayFirstProbeTime(t *testing.T) { - tests := []struct { - name string - delayFirstProbeTime time.Duration - want time.Duration - }{ - // Invalid cases - { - name: "EqualToZero", - delayFirstProbeTime: 0, - want: defaultDelayFirstProbeTime, - }, - // Valid cases - { - name: "MoreThanZero", - delayFirstProbeTime: time.Millisecond, - want: time.Millisecond, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - const nicID = 1 - - c := stack.DefaultNUDConfigurations() - c.DelayFirstProbeTime = test.delayFirstProbeTime - - e := channel.New(0, 1280, linkAddr1) - e.LinkEPCapabilities |= stack.CapabilityResolutionRequired - - s := stack.New(stack.Options{ - // A neighbor cache is required to store NUDConfigurations. The - // networking stack will only allocate neighbor caches if a protocol - // providing link address resolution is specified (e.g. ARP or IPv6). - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, - NUDConfigs: c, - }) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - sc, err := s.NUDConfigurations(nicID, ipv6.ProtocolNumber) - if err != nil { - t.Fatalf("got stack.NUDConfigurations(%d, %d) = %s", nicID, ipv6.ProtocolNumber, err) - } - if got := sc.DelayFirstProbeTime; got != test.want { - t.Errorf("got DelayFirstProbeTime = %q, want = %q", got, test.want) - } - }) - } -} - -func TestNUDConfigurationsMaxMulticastProbes(t *testing.T) { - tests := []struct { - name string - maxMulticastProbes uint32 - want uint32 - }{ - // Invalid cases - { - name: "EqualToZero", - maxMulticastProbes: 0, - want: defaultMaxMulticastProbes, - }, - // Valid cases - { - name: "MoreThanZero", - maxMulticastProbes: 1, - want: 1, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - const nicID = 1 - - c := stack.DefaultNUDConfigurations() - c.MaxMulticastProbes = test.maxMulticastProbes - - e := channel.New(0, 1280, linkAddr1) - e.LinkEPCapabilities |= stack.CapabilityResolutionRequired - - s := stack.New(stack.Options{ - // A neighbor cache is required to store NUDConfigurations. The - // networking stack will only allocate neighbor caches if a protocol - // providing link address resolution is specified (e.g. ARP or IPv6). - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, - NUDConfigs: c, - }) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - sc, err := s.NUDConfigurations(nicID, ipv6.ProtocolNumber) - if err != nil { - t.Fatalf("got stack.NUDConfigurations(%d, %d) = %s", nicID, ipv6.ProtocolNumber, err) - } - if got := sc.MaxMulticastProbes; got != test.want { - t.Errorf("got MaxMulticastProbes = %q, want = %q", got, test.want) - } - }) - } -} - -func TestNUDConfigurationsMaxUnicastProbes(t *testing.T) { - tests := []struct { - name string - maxUnicastProbes uint32 - want uint32 - }{ - // Invalid cases - { - name: "EqualToZero", - maxUnicastProbes: 0, - want: defaultMaxUnicastProbes, - }, - // Valid cases - { - name: "MoreThanZero", - maxUnicastProbes: 1, - want: 1, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - const nicID = 1 - - c := stack.DefaultNUDConfigurations() - c.MaxUnicastProbes = test.maxUnicastProbes - - e := channel.New(0, 1280, linkAddr1) - e.LinkEPCapabilities |= stack.CapabilityResolutionRequired - - s := stack.New(stack.Options{ - // A neighbor cache is required to store NUDConfigurations. The - // networking stack will only allocate neighbor caches if a protocol - // providing link address resolution is specified (e.g. ARP or IPv6). - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, - NUDConfigs: c, - }) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - sc, err := s.NUDConfigurations(nicID, ipv6.ProtocolNumber) - if err != nil { - t.Fatalf("got stack.NUDConfigurations(%d, %d) = %s", nicID, ipv6.ProtocolNumber, err) - } - if got := sc.MaxUnicastProbes; got != test.want { - t.Errorf("got MaxUnicastProbes = %q, want = %q", got, test.want) - } - }) - } -} - -// TestNUDStateReachableTime verifies the correctness of the ReachableTime -// computation. -func TestNUDStateReachableTime(t *testing.T) { - tests := []struct { - name string - baseReachableTime time.Duration - minRandomFactor float32 - maxRandomFactor float32 - want time.Duration - }{ - { - name: "AllZeros", - baseReachableTime: 0, - minRandomFactor: 0, - maxRandomFactor: 0, - want: 0, - }, - { - name: "ZeroMaxRandomFactor", - baseReachableTime: time.Second, - minRandomFactor: 0, - maxRandomFactor: 0, - want: 0, - }, - { - name: "ZeroMinRandomFactor", - baseReachableTime: time.Second, - minRandomFactor: 0, - maxRandomFactor: 1, - want: time.Duration(defaultFakeRandomNum * float32(time.Second)), - }, - { - name: "FractionalRandomFactor", - baseReachableTime: time.Duration(math.MaxInt64), - minRandomFactor: 0.001, - maxRandomFactor: 0.002, - want: time.Duration((0.001 + (0.001 * defaultFakeRandomNum)) * float32(math.MaxInt64)), - }, - { - name: "MinAndMaxRandomFactorsEqual", - baseReachableTime: time.Second, - minRandomFactor: 1, - maxRandomFactor: 1, - want: time.Second, - }, - { - name: "MinAndMaxRandomFactorsDifferent", - baseReachableTime: time.Second, - minRandomFactor: 1, - maxRandomFactor: 2, - want: time.Duration((1.0 + defaultFakeRandomNum) * float32(time.Second)), - }, - { - name: "MaxInt64", - baseReachableTime: time.Duration(math.MaxInt64), - minRandomFactor: 1, - maxRandomFactor: 1, - want: time.Duration(math.MaxInt64), - }, - { - name: "Overflow", - baseReachableTime: time.Duration(math.MaxInt64), - minRandomFactor: 1.5, - maxRandomFactor: 1.5, - want: time.Duration(math.MaxInt64), - }, - { - name: "DoubleOverflow", - baseReachableTime: time.Duration(math.MaxInt64), - minRandomFactor: 2.5, - maxRandomFactor: 2.5, - want: time.Duration(math.MaxInt64), - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - c := stack.NUDConfigurations{ - BaseReachableTime: test.baseReachableTime, - MinRandomFactor: test.minRandomFactor, - MaxRandomFactor: test.maxRandomFactor, - } - // A fake random number generator is used to ensure deterministic - // results. - rng := fakeRand{ - num: defaultFakeRandomNum, - } - var clock faketime.NullClock - s := stack.NewNUDState(c, &clock, rand.New(&rng)) - if got, want := s.ReachableTime(), test.want; got != want { - t.Errorf("got ReachableTime = %q, want = %q", got, want) - } - }) - } -} - -// TestNUDStateRecomputeReachableTime exercises the ReachableTime function -// twice to verify recomputation of reachable time when the min random factor, -// max random factor, or base reachable time changes. -func TestNUDStateRecomputeReachableTime(t *testing.T) { - const defaultBase = time.Second - const defaultMin = 2.0 * defaultMaxRandomFactor - const defaultMax = 3.0 * defaultMaxRandomFactor - - tests := []struct { - name string - baseReachableTime time.Duration - minRandomFactor float32 - maxRandomFactor float32 - want time.Duration - }{ - { - name: "BaseReachableTime", - baseReachableTime: 2 * defaultBase, - minRandomFactor: defaultMin, - maxRandomFactor: defaultMax, - want: time.Duration((defaultMin + (defaultMax-defaultMin)*defaultFakeRandomNum) * float32(2*defaultBase)), - }, - { - name: "MinRandomFactor", - baseReachableTime: defaultBase, - minRandomFactor: defaultMax, - maxRandomFactor: defaultMax, - want: time.Duration(defaultMax * float32(defaultBase)), - }, - { - name: "MaxRandomFactor", - baseReachableTime: defaultBase, - minRandomFactor: defaultMin, - maxRandomFactor: defaultMin, - want: time.Duration(defaultMin * float32(defaultBase)), - }, - { - name: "BothRandomFactor", - baseReachableTime: defaultBase, - minRandomFactor: 2 * defaultMin, - maxRandomFactor: 2 * defaultMax, - want: time.Duration((2*defaultMin + (2*defaultMax-2*defaultMin)*defaultFakeRandomNum) * float32(defaultBase)), - }, - { - name: "BaseReachableTimeAndBothRandomFactors", - baseReachableTime: 2 * defaultBase, - minRandomFactor: 2 * defaultMin, - maxRandomFactor: 2 * defaultMax, - want: time.Duration((2*defaultMin + (2*defaultMax-2*defaultMin)*defaultFakeRandomNum) * float32(2*defaultBase)), - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - c := stack.DefaultNUDConfigurations() - c.BaseReachableTime = defaultBase - c.MinRandomFactor = defaultMin - c.MaxRandomFactor = defaultMax - - // A fake random number generator is used to ensure deterministic - // results. - rng := fakeRand{ - num: defaultFakeRandomNum, - } - var clock faketime.NullClock - s := stack.NewNUDState(c, &clock, rand.New(&rng)) - old := s.ReachableTime() - - if got, want := s.ReachableTime(), old; got != want { - t.Errorf("got ReachableTime = %q, want = %q", got, want) - } - - // Check for recomputation when changing the min random factor, the max - // random factor, the base reachability time, or any permutation of those - // three options. - c.BaseReachableTime = test.baseReachableTime - c.MinRandomFactor = test.minRandomFactor - c.MaxRandomFactor = test.maxRandomFactor - s.SetConfig(c) - - if got, want := s.ReachableTime(), test.want; got != want { - t.Errorf("got ReachableTime = %q, want = %q", got, want) - } - - // Verify that ReachableTime isn't recomputed when none of the - // configuration options change. The random factor is changed so that if - // a recompution were to occur, ReachableTime would change. - rng.num = defaultFakeRandomNum / 2.0 - if got, want := s.ReachableTime(), test.want; got != want { - t.Errorf("got ReachableTime = %q, want = %q", got, want) - } - }) - } -} diff --git a/pkg/tcpip/stack/packet_buffer_list.go b/pkg/tcpip/stack/packet_buffer_list.go new file mode 100644 index 000000000..ce7057d6b --- /dev/null +++ b/pkg/tcpip/stack/packet_buffer_list.go @@ -0,0 +1,221 @@ +package stack + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type PacketBufferElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (PacketBufferElementMapper) linkerFor(elem *PacketBuffer) *PacketBuffer { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type PacketBufferList struct { + head *PacketBuffer + tail *PacketBuffer +} + +// Reset resets list l to the empty state. +func (l *PacketBufferList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +// +//go:nosplit +func (l *PacketBufferList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +// +//go:nosplit +func (l *PacketBufferList) Front() *PacketBuffer { + return l.head +} + +// Back returns the last element of list l or nil. +// +//go:nosplit +func (l *PacketBufferList) Back() *PacketBuffer { + return l.tail +} + +// Len returns the number of elements in the list. +// +// NOTE: This is an O(n) operation. +// +//go:nosplit +func (l *PacketBufferList) Len() (count int) { + for e := l.Front(); e != nil; e = (PacketBufferElementMapper{}.linkerFor(e)).Next() { + count++ + } + return count +} + +// PushFront inserts the element e at the front of list l. +// +//go:nosplit +func (l *PacketBufferList) PushFront(e *PacketBuffer) { + linker := PacketBufferElementMapper{}.linkerFor(e) + linker.SetNext(l.head) + linker.SetPrev(nil) + if l.head != nil { + PacketBufferElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +// +//go:nosplit +func (l *PacketBufferList) PushBack(e *PacketBuffer) { + linker := PacketBufferElementMapper{}.linkerFor(e) + linker.SetNext(nil) + linker.SetPrev(l.tail) + if l.tail != nil { + PacketBufferElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +// +//go:nosplit +func (l *PacketBufferList) PushBackList(m *PacketBufferList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + PacketBufferElementMapper{}.linkerFor(l.tail).SetNext(m.head) + PacketBufferElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +// +//go:nosplit +func (l *PacketBufferList) InsertAfter(b, e *PacketBuffer) { + bLinker := PacketBufferElementMapper{}.linkerFor(b) + eLinker := PacketBufferElementMapper{}.linkerFor(e) + + a := bLinker.Next() + + eLinker.SetNext(a) + eLinker.SetPrev(b) + bLinker.SetNext(e) + + if a != nil { + PacketBufferElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +// +//go:nosplit +func (l *PacketBufferList) InsertBefore(a, e *PacketBuffer) { + aLinker := PacketBufferElementMapper{}.linkerFor(a) + eLinker := PacketBufferElementMapper{}.linkerFor(e) + + b := aLinker.Prev() + eLinker.SetNext(a) + eLinker.SetPrev(b) + aLinker.SetPrev(e) + + if b != nil { + PacketBufferElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +// +//go:nosplit +func (l *PacketBufferList) Remove(e *PacketBuffer) { + linker := PacketBufferElementMapper{}.linkerFor(e) + prev := linker.Prev() + next := linker.Next() + + if prev != nil { + PacketBufferElementMapper{}.linkerFor(prev).SetNext(next) + } else if l.head == e { + l.head = next + } + + if next != nil { + PacketBufferElementMapper{}.linkerFor(next).SetPrev(prev) + } else if l.tail == e { + l.tail = prev + } + + linker.SetNext(nil) + linker.SetPrev(nil) +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type PacketBufferEntry struct { + next *PacketBuffer + prev *PacketBuffer +} + +// Next returns the entry that follows e in the list. +// +//go:nosplit +func (e *PacketBufferEntry) Next() *PacketBuffer { + return e.next +} + +// Prev returns the entry that precedes e in the list. +// +//go:nosplit +func (e *PacketBufferEntry) Prev() *PacketBuffer { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +// +//go:nosplit +func (e *PacketBufferEntry) SetNext(elem *PacketBuffer) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +// +//go:nosplit +func (e *PacketBufferEntry) SetPrev(elem *PacketBuffer) { + e.prev = elem +} diff --git a/pkg/tcpip/stack/packet_buffer_test.go b/pkg/tcpip/stack/packet_buffer_test.go deleted file mode 100644 index 87b023445..000000000 --- a/pkg/tcpip/stack/packet_buffer_test.go +++ /dev/null @@ -1,675 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at // -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package stack - -import ( - "bytes" - "fmt" - "testing" - - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/header" -) - -func TestPacketHeaderPush(t *testing.T) { - for _, test := range []struct { - name string - reserved int - link []byte - network []byte - transport []byte - data []byte - }{ - { - name: "construct empty packet", - }, - { - name: "construct link header only packet", - reserved: 60, - link: makeView(10), - }, - { - name: "construct link and network header only packet", - reserved: 60, - link: makeView(10), - network: makeView(20), - }, - { - name: "construct header only packet", - reserved: 60, - link: makeView(10), - network: makeView(20), - transport: makeView(30), - }, - { - name: "construct data only packet", - data: makeView(40), - }, - { - name: "construct L3 packet", - reserved: 60, - network: makeView(20), - transport: makeView(30), - data: makeView(40), - }, - { - name: "construct L2 packet", - reserved: 60, - link: makeView(10), - network: makeView(20), - transport: makeView(30), - data: makeView(40), - }, - } { - t.Run(test.name, func(t *testing.T) { - pk := NewPacketBuffer(PacketBufferOptions{ - ReserveHeaderBytes: test.reserved, - // Make a copy of data to make sure our truth data won't be taint by - // PacketBuffer. - Data: buffer.NewViewFromBytes(test.data).ToVectorisedView(), - }) - - allHdrSize := len(test.link) + len(test.network) + len(test.transport) - - // Check the initial values for packet. - checkInitialPacketBuffer(t, pk, PacketBufferOptions{ - ReserveHeaderBytes: test.reserved, - Data: buffer.View(test.data).ToVectorisedView(), - }) - - // Push headers. - if v := test.transport; len(v) > 0 { - copy(pk.TransportHeader().Push(len(v)), v) - } - if v := test.network; len(v) > 0 { - copy(pk.NetworkHeader().Push(len(v)), v) - } - if v := test.link; len(v) > 0 { - copy(pk.LinkHeader().Push(len(v)), v) - } - - // Check the after values for packet. - if got, want := pk.ReservedHeaderBytes(), test.reserved; got != want { - t.Errorf("After pk.ReservedHeaderBytes() = %d, want %d", got, want) - } - if got, want := pk.AvailableHeaderBytes(), test.reserved-allHdrSize; got != want { - t.Errorf("After pk.AvailableHeaderBytes() = %d, want %d", got, want) - } - if got, want := pk.HeaderSize(), allHdrSize; got != want { - t.Errorf("After pk.HeaderSize() = %d, want %d", got, want) - } - if got, want := pk.Size(), allHdrSize+len(test.data); got != want { - t.Errorf("After pk.Size() = %d, want %d", got, want) - } - // Check the after state. - checkPacketContents(t, "After ", pk, packetContents{ - link: test.link, - network: test.network, - transport: test.transport, - data: test.data, - }) - }) - } -} - -func TestPacketBufferClone(t *testing.T) { - data := concatViews(makeView(20), makeView(30), makeView(40)) - pk := NewPacketBuffer(PacketBufferOptions{ - // Make a copy of data to make sure our truth data won't be taint by - // PacketBuffer. - Data: buffer.NewViewFromBytes(data).ToVectorisedView(), - }) - - bytesToDelete := 30 - originalSize := data.Size() - - clonedPks := []*PacketBuffer{ - pk.Clone(), - pk.CloneToInbound(), - } - pk.Data().DeleteFront(bytesToDelete) - if got, want := pk.Data().Size(), originalSize-bytesToDelete; got != want { - t.Errorf("original packet was not changed: size expected = %d, got = %d", want, got) - } - for _, clonedPk := range clonedPks { - if got := clonedPk.Data().Size(); got != originalSize { - t.Errorf("cloned packet should not be modified: expected size = %d, got = %d", originalSize, got) - } - } -} - -func TestPacketHeaderConsume(t *testing.T) { - for _, test := range []struct { - name string - data []byte - link int - network int - transport int - }{ - { - name: "parse L2 packet", - data: concatViews(makeView(10), makeView(20), makeView(30), makeView(40)), - link: 10, - network: 20, - transport: 30, - }, - { - name: "parse L3 packet", - data: concatViews(makeView(20), makeView(30), makeView(40)), - network: 20, - transport: 30, - }, - } { - t.Run(test.name, func(t *testing.T) { - pk := NewPacketBuffer(PacketBufferOptions{ - // Make a copy of data to make sure our truth data won't be taint by - // PacketBuffer. - Data: buffer.NewViewFromBytes(test.data).ToVectorisedView(), - }) - - // Check the initial values for packet. - checkInitialPacketBuffer(t, pk, PacketBufferOptions{ - Data: buffer.View(test.data).ToVectorisedView(), - }) - - // Consume headers. - if size := test.link; size > 0 { - if _, ok := pk.LinkHeader().Consume(size); !ok { - t.Fatalf("pk.LinkHeader().Consume() = false, want true") - } - } - if size := test.network; size > 0 { - if _, ok := pk.NetworkHeader().Consume(size); !ok { - t.Fatalf("pk.NetworkHeader().Consume() = false, want true") - } - } - if size := test.transport; size > 0 { - if _, ok := pk.TransportHeader().Consume(size); !ok { - t.Fatalf("pk.TransportHeader().Consume() = false, want true") - } - } - - allHdrSize := test.link + test.network + test.transport - - // Check the after values for packet. - if got, want := pk.ReservedHeaderBytes(), 0; got != want { - t.Errorf("After pk.ReservedHeaderBytes() = %d, want %d", got, want) - } - if got, want := pk.AvailableHeaderBytes(), 0; got != want { - t.Errorf("After pk.AvailableHeaderBytes() = %d, want %d", got, want) - } - if got, want := pk.HeaderSize(), allHdrSize; got != want { - t.Errorf("After pk.HeaderSize() = %d, want %d", got, want) - } - if got, want := pk.Size(), len(test.data); got != want { - t.Errorf("After pk.Size() = %d, want %d", got, want) - } - // Check the after state of pk. - checkPacketContents(t, "After ", pk, packetContents{ - link: test.data[:test.link], - network: test.data[test.link:][:test.network], - transport: test.data[test.link+test.network:][:test.transport], - data: test.data[allHdrSize:], - }) - }) - } -} - -func TestPacketHeaderConsumeDataTooShort(t *testing.T) { - data := makeView(10) - - pk := NewPacketBuffer(PacketBufferOptions{ - // Make a copy of data to make sure our truth data won't be taint by - // PacketBuffer. - Data: buffer.NewViewFromBytes(data).ToVectorisedView(), - }) - - // Consume should fail if pkt.Data is too short. - if _, ok := pk.LinkHeader().Consume(11); ok { - t.Fatalf("pk.LinkHeader().Consume() = _, true; want _, false") - } - if _, ok := pk.NetworkHeader().Consume(11); ok { - t.Fatalf("pk.NetworkHeader().Consume() = _, true; want _, false") - } - if _, ok := pk.TransportHeader().Consume(11); ok { - t.Fatalf("pk.TransportHeader().Consume() = _, true; want _, false") - } - - // Check packet should look the same as initial packet. - checkInitialPacketBuffer(t, pk, PacketBufferOptions{ - Data: buffer.View(data).ToVectorisedView(), - }) -} - -// This is a very obscure use-case seen in the code that verifies packets -// before sending them out. It tries to parse the headers to verify. -// PacketHeader was initially not designed to mix Push() and Consume(), but it -// works and it's been relied upon. Include a test here. -func TestPacketHeaderPushConsumeMixed(t *testing.T) { - link := makeView(10) - network := makeView(20) - data := makeView(30) - - initData := append([]byte(nil), network...) - initData = append(initData, data...) - pk := NewPacketBuffer(PacketBufferOptions{ - ReserveHeaderBytes: len(link), - Data: buffer.NewViewFromBytes(initData).ToVectorisedView(), - }) - - // 1. Consume network header - gotNetwork, ok := pk.NetworkHeader().Consume(len(network)) - if !ok { - t.Fatalf("pk.NetworkHeader().Consume(%d) = _, false; want _, true", len(network)) - } - checkViewEqual(t, "gotNetwork", gotNetwork, network) - - // 2. Push link header - copy(pk.LinkHeader().Push(len(link)), link) - - checkPacketContents(t, "" /* prefix */, pk, packetContents{ - link: link, - network: network, - data: data, - }) -} - -func TestPacketHeaderPushConsumeMixedTooLong(t *testing.T) { - link := makeView(10) - network := makeView(20) - data := makeView(30) - - initData := concatViews(network, data) - pk := NewPacketBuffer(PacketBufferOptions{ - ReserveHeaderBytes: len(link), - Data: buffer.NewViewFromBytes(initData).ToVectorisedView(), - }) - - // 1. Push link header - copy(pk.LinkHeader().Push(len(link)), link) - - checkPacketContents(t, "" /* prefix */, pk, packetContents{ - link: link, - data: initData, - }) - - // 2. Consume network header, with a number of bytes too large. - gotNetwork, ok := pk.NetworkHeader().Consume(len(initData) + 1) - if ok { - t.Fatalf("pk.NetworkHeader().Consume(%d) = %q, true; want _, false", len(initData)+1, gotNetwork) - } - - checkPacketContents(t, "" /* prefix */, pk, packetContents{ - link: link, - data: initData, - }) -} - -func TestPacketHeaderPushCalledAtMostOnce(t *testing.T) { - const headerSize = 10 - - pk := NewPacketBuffer(PacketBufferOptions{ - ReserveHeaderBytes: headerSize * int(numHeaderType), - }) - - for _, h := range []PacketHeader{ - pk.TransportHeader(), - pk.NetworkHeader(), - pk.LinkHeader(), - } { - t.Run("PushedTwice/"+h.typ.String(), func(t *testing.T) { - h.Push(headerSize) - - defer func() { recover() }() - h.Push(headerSize) - t.Fatal("Second push should have panicked") - }) - } -} - -func TestPacketHeaderConsumeCalledAtMostOnce(t *testing.T) { - const headerSize = 10 - - pk := NewPacketBuffer(PacketBufferOptions{ - Data: makeView(headerSize * int(numHeaderType)).ToVectorisedView(), - }) - - for _, h := range []PacketHeader{ - pk.LinkHeader(), - pk.NetworkHeader(), - pk.TransportHeader(), - } { - t.Run("ConsumedTwice/"+h.typ.String(), func(t *testing.T) { - if _, ok := h.Consume(headerSize); !ok { - t.Fatal("First consume should succeed") - } - - defer func() { recover() }() - h.Consume(headerSize) - t.Fatal("Second consume should have panicked") - }) - } -} - -func TestPacketHeaderPushThenConsumePanics(t *testing.T) { - const headerSize = 10 - - pk := NewPacketBuffer(PacketBufferOptions{ - ReserveHeaderBytes: headerSize * int(numHeaderType), - }) - - for _, h := range []PacketHeader{ - pk.TransportHeader(), - pk.NetworkHeader(), - pk.LinkHeader(), - } { - t.Run(h.typ.String(), func(t *testing.T) { - h.Push(headerSize) - - defer func() { recover() }() - h.Consume(headerSize) - t.Fatal("Consume should have panicked") - }) - } -} - -func TestPacketHeaderConsumeThenPushPanics(t *testing.T) { - const headerSize = 10 - - pk := NewPacketBuffer(PacketBufferOptions{ - Data: makeView(headerSize * int(numHeaderType)).ToVectorisedView(), - }) - - for _, h := range []PacketHeader{ - pk.LinkHeader(), - pk.NetworkHeader(), - pk.TransportHeader(), - } { - t.Run(h.typ.String(), func(t *testing.T) { - h.Consume(headerSize) - - defer func() { recover() }() - h.Push(headerSize) - t.Fatal("Push should have panicked") - }) - } -} - -func TestPacketBufferData(t *testing.T) { - for _, tc := range []struct { - name string - makePkt func(*testing.T) *PacketBuffer - data string - }{ - { - name: "inbound packet", - makePkt: func(*testing.T) *PacketBuffer { - pkt := NewPacketBuffer(PacketBufferOptions{ - Data: vv("aabbbbccccccDATA"), - }) - pkt.LinkHeader().Consume(2) - pkt.NetworkHeader().Consume(4) - pkt.TransportHeader().Consume(6) - return pkt - }, - data: "DATA", - }, - { - name: "outbound packet", - makePkt: func(*testing.T) *PacketBuffer { - pkt := NewPacketBuffer(PacketBufferOptions{ - ReserveHeaderBytes: 12, - Data: vv("DATA"), - }) - copy(pkt.TransportHeader().Push(6), []byte("cccccc")) - copy(pkt.NetworkHeader().Push(4), []byte("bbbb")) - copy(pkt.LinkHeader().Push(2), []byte("aa")) - return pkt - }, - data: "DATA", - }, - } { - t.Run(tc.name, func(t *testing.T) { - // PullUp - for _, n := range []int{1, len(tc.data)} { - t.Run(fmt.Sprintf("PullUp%d", n), func(t *testing.T) { - pkt := tc.makePkt(t) - v, ok := pkt.Data().PullUp(n) - wantV := []byte(tc.data)[:n] - if !ok || !bytes.Equal(v, wantV) { - t.Errorf("pkt.Data().PullUp(%d) = %q, %t; want %q, true", n, v, ok, wantV) - } - }) - } - t.Run("PullUpOutOfBounds", func(t *testing.T) { - n := len(tc.data) + 1 - pkt := tc.makePkt(t) - v, ok := pkt.Data().PullUp(n) - if ok || v != nil { - t.Errorf("pkt.Data().PullUp(%d) = %q, %t; want nil, false", n, v, ok) - } - }) - - // DeleteFront - for _, n := range []int{1, len(tc.data)} { - t.Run(fmt.Sprintf("DeleteFront%d", n), func(t *testing.T) { - pkt := tc.makePkt(t) - pkt.Data().DeleteFront(n) - - checkData(t, pkt, []byte(tc.data)[n:]) - }) - } - - // CapLength - for _, n := range []int{0, 1, len(tc.data)} { - t.Run(fmt.Sprintf("CapLength%d", n), func(t *testing.T) { - pkt := tc.makePkt(t) - pkt.Data().CapLength(n) - - want := []byte(tc.data) - if n < len(want) { - want = want[:n] - } - checkData(t, pkt, want) - }) - } - - // Views - t.Run("Views", func(t *testing.T) { - pkt := tc.makePkt(t) - checkData(t, pkt, []byte(tc.data)) - }) - - // AppendView - t.Run("AppendView", func(t *testing.T) { - s := "APPEND" - - pkt := tc.makePkt(t) - pkt.Data().AppendView(buffer.View(s)) - - checkData(t, pkt, []byte(tc.data+s)) - }) - - // ReadFromVV - for _, n := range []int{0, 1, 2, 7, 10, 14, 20} { - t.Run(fmt.Sprintf("ReadFromVV%d", n), func(t *testing.T) { - s := "TO READ" - srcVV := vv(s, s) - s += s - - pkt := tc.makePkt(t) - pkt.Data().ReadFromVV(&srcVV, n) - - if n < len(s) { - s = s[:n] - } - checkData(t, pkt, []byte(tc.data+s)) - }) - } - - // ExtractVV - t.Run("ExtractVV", func(t *testing.T) { - pkt := tc.makePkt(t) - extractedVV := pkt.Data().ExtractVV() - - got := extractedVV.ToOwnedView() - want := []byte(tc.data) - if !bytes.Equal(got, want) { - t.Errorf("pkt.Data().ExtractVV().ToOwnedView() = %q, want %q", got, want) - } - }) - }) - } -} - -type packetContents struct { - link buffer.View - network buffer.View - transport buffer.View - data buffer.View -} - -func checkPacketContents(t *testing.T, prefix string, pk *PacketBuffer, want packetContents) { - t.Helper() - // Headers. - checkPacketHeader(t, prefix+"pk.LinkHeader", pk.LinkHeader(), want.link) - checkPacketHeader(t, prefix+"pk.NetworkHeader", pk.NetworkHeader(), want.network) - checkPacketHeader(t, prefix+"pk.TransportHeader", pk.TransportHeader(), want.transport) - // Data. - checkData(t, pk, want.data) - // Whole packet. - checkViewEqual(t, prefix+"pk.Views()", - concatViews(pk.Views()...), - concatViews(want.link, want.network, want.transport, want.data)) - // PayloadSince. - checkViewEqual(t, prefix+"PayloadSince(LinkHeader)", - PayloadSince(pk.LinkHeader()), - concatViews(want.link, want.network, want.transport, want.data)) - checkViewEqual(t, prefix+"PayloadSince(NetworkHeader)", - PayloadSince(pk.NetworkHeader()), - concatViews(want.network, want.transport, want.data)) - checkViewEqual(t, prefix+"PayloadSince(TransportHeader)", - PayloadSince(pk.TransportHeader()), - concatViews(want.transport, want.data)) -} - -func checkInitialPacketBuffer(t *testing.T, pk *PacketBuffer, opts PacketBufferOptions) { - t.Helper() - reserved := opts.ReserveHeaderBytes - if got, want := pk.ReservedHeaderBytes(), reserved; got != want { - t.Errorf("Initial pk.ReservedHeaderBytes() = %d, want %d", got, want) - } - if got, want := pk.AvailableHeaderBytes(), reserved; got != want { - t.Errorf("Initial pk.AvailableHeaderBytes() = %d, want %d", got, want) - } - if got, want := pk.HeaderSize(), 0; got != want { - t.Errorf("Initial pk.HeaderSize() = %d, want %d", got, want) - } - data := opts.Data.ToView() - if got, want := pk.Size(), len(data); got != want { - t.Errorf("Initial pk.Size() = %d, want %d", got, want) - } - checkPacketContents(t, "Initial ", pk, packetContents{ - data: data, - }) -} - -func checkPacketHeader(t *testing.T, name string, h PacketHeader, want []byte) { - t.Helper() - checkViewEqual(t, name+".View()", h.View(), want) -} - -func checkViewEqual(t *testing.T, what string, got, want buffer.View) { - t.Helper() - if !bytes.Equal(got, want) { - t.Errorf("%s = %x, want %x", what, got, want) - } -} - -func checkData(t *testing.T, pkt *PacketBuffer, want []byte) { - t.Helper() - if got := concatViews(pkt.Data().Views()...); !bytes.Equal(got, want) { - t.Errorf("pkt.Data().Views() = 0x%x, want 0x%x", got, want) - } - if got := pkt.Data().Size(); got != len(want) { - t.Errorf("pkt.Data().Size() = %d, want %d", got, len(want)) - } - - t.Run("AsRange", func(t *testing.T) { - // Full range - checkRange(t, pkt.Data().AsRange(), want) - - // SubRange - for _, off := range []int{0, 1, len(want), len(want) + 1} { - t.Run(fmt.Sprintf("SubRange%d", off), func(t *testing.T) { - // Empty when off is greater than the size of range. - var sub []byte - if off < len(want) { - sub = want[off:] - } - checkRange(t, pkt.Data().AsRange().SubRange(off), sub) - }) - } - - // Capped - for _, n := range []int{0, 1, len(want), len(want) + 1} { - t.Run(fmt.Sprintf("Capped%d", n), func(t *testing.T) { - sub := want - if n < len(sub) { - sub = sub[:n] - } - checkRange(t, pkt.Data().AsRange().Capped(n), sub) - }) - } - }) -} - -func checkRange(t *testing.T, r Range, data []byte) { - if got, want := r.Size(), len(data); got != want { - t.Errorf("r.Size() = %d, want %d", got, want) - } - if got := r.AsView(); !bytes.Equal(got, data) { - t.Errorf("r.AsView() = %x, want %x", got, data) - } - if got := r.ToOwnedView(); !bytes.Equal(got, data) { - t.Errorf("r.ToOwnedView() = %x, want %x", got, data) - } - if got, want := r.Checksum(), header.Checksum(data, 0 /* initial */); got != want { - t.Errorf("r.Checksum() = %x, want %x", got, want) - } -} - -func vv(pieces ...string) buffer.VectorisedView { - var views []buffer.View - var size int - for _, p := range pieces { - v := buffer.View([]byte(p)) - size += len(v) - views = append(views, v) - } - return buffer.NewVectorisedView(size, views) -} - -func makeView(size int) buffer.View { - b := byte(size) - return bytes.Repeat([]byte{b}, size) -} - -func concatViews(views ...buffer.View) buffer.View { - var all buffer.View - for _, v := range views { - all = append(all, v...) - } - return all -} diff --git a/pkg/tcpip/stack/stack_state_autogen.go b/pkg/tcpip/stack/stack_state_autogen.go new file mode 100644 index 000000000..54f2574ac --- /dev/null +++ b/pkg/tcpip/stack/stack_state_autogen.go @@ -0,0 +1,1288 @@ +// automatically generated by stateify. + +package stack + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (t *tuple) StateTypeName() string { + return "pkg/tcpip/stack.tuple" +} + +func (t *tuple) StateFields() []string { + return []string{ + "tupleEntry", + "tupleID", + "conn", + "direction", + } +} + +func (t *tuple) beforeSave() {} + +// +checklocksignore +func (t *tuple) StateSave(stateSinkObject state.Sink) { + t.beforeSave() + stateSinkObject.Save(0, &t.tupleEntry) + stateSinkObject.Save(1, &t.tupleID) + stateSinkObject.Save(2, &t.conn) + stateSinkObject.Save(3, &t.direction) +} + +func (t *tuple) afterLoad() {} + +// +checklocksignore +func (t *tuple) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &t.tupleEntry) + stateSourceObject.Load(1, &t.tupleID) + stateSourceObject.Load(2, &t.conn) + stateSourceObject.Load(3, &t.direction) +} + +func (ti *tupleID) StateTypeName() string { + return "pkg/tcpip/stack.tupleID" +} + +func (ti *tupleID) StateFields() []string { + return []string{ + "srcAddr", + "srcPort", + "dstAddr", + "dstPort", + "transProto", + "netProto", + } +} + +func (ti *tupleID) beforeSave() {} + +// +checklocksignore +func (ti *tupleID) StateSave(stateSinkObject state.Sink) { + ti.beforeSave() + stateSinkObject.Save(0, &ti.srcAddr) + stateSinkObject.Save(1, &ti.srcPort) + stateSinkObject.Save(2, &ti.dstAddr) + stateSinkObject.Save(3, &ti.dstPort) + stateSinkObject.Save(4, &ti.transProto) + stateSinkObject.Save(5, &ti.netProto) +} + +func (ti *tupleID) afterLoad() {} + +// +checklocksignore +func (ti *tupleID) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &ti.srcAddr) + stateSourceObject.Load(1, &ti.srcPort) + stateSourceObject.Load(2, &ti.dstAddr) + stateSourceObject.Load(3, &ti.dstPort) + stateSourceObject.Load(4, &ti.transProto) + stateSourceObject.Load(5, &ti.netProto) +} + +func (cn *conn) StateTypeName() string { + return "pkg/tcpip/stack.conn" +} + +func (cn *conn) StateFields() []string { + return []string{ + "original", + "reply", + "manip", + "tcbHook", + "tcb", + "lastUsed", + } +} + +func (cn *conn) beforeSave() {} + +// +checklocksignore +func (cn *conn) StateSave(stateSinkObject state.Sink) { + cn.beforeSave() + var lastUsedValue unixTime + lastUsedValue = cn.saveLastUsed() + stateSinkObject.SaveValue(5, lastUsedValue) + stateSinkObject.Save(0, &cn.original) + stateSinkObject.Save(1, &cn.reply) + stateSinkObject.Save(2, &cn.manip) + stateSinkObject.Save(3, &cn.tcbHook) + stateSinkObject.Save(4, &cn.tcb) +} + +func (cn *conn) afterLoad() {} + +// +checklocksignore +func (cn *conn) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &cn.original) + stateSourceObject.Load(1, &cn.reply) + stateSourceObject.Load(2, &cn.manip) + stateSourceObject.Load(3, &cn.tcbHook) + stateSourceObject.Load(4, &cn.tcb) + stateSourceObject.LoadValue(5, new(unixTime), func(y interface{}) { cn.loadLastUsed(y.(unixTime)) }) +} + +func (ct *ConnTrack) StateTypeName() string { + return "pkg/tcpip/stack.ConnTrack" +} + +func (ct *ConnTrack) StateFields() []string { + return []string{ + "seed", + "buckets", + } +} + +// +checklocksignore +func (ct *ConnTrack) StateSave(stateSinkObject state.Sink) { + ct.beforeSave() + stateSinkObject.Save(0, &ct.seed) + stateSinkObject.Save(1, &ct.buckets) +} + +func (ct *ConnTrack) afterLoad() {} + +// +checklocksignore +func (ct *ConnTrack) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &ct.seed) + stateSourceObject.Load(1, &ct.buckets) +} + +func (b *bucket) StateTypeName() string { + return "pkg/tcpip/stack.bucket" +} + +func (b *bucket) StateFields() []string { + return []string{ + "tuples", + } +} + +func (b *bucket) beforeSave() {} + +// +checklocksignore +func (b *bucket) StateSave(stateSinkObject state.Sink) { + b.beforeSave() + stateSinkObject.Save(0, &b.tuples) +} + +func (b *bucket) afterLoad() {} + +// +checklocksignore +func (b *bucket) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &b.tuples) +} + +func (u *unixTime) StateTypeName() string { + return "pkg/tcpip/stack.unixTime" +} + +func (u *unixTime) StateFields() []string { + return []string{ + "second", + "nano", + } +} + +func (u *unixTime) beforeSave() {} + +// +checklocksignore +func (u *unixTime) StateSave(stateSinkObject state.Sink) { + u.beforeSave() + stateSinkObject.Save(0, &u.second) + stateSinkObject.Save(1, &u.nano) +} + +func (u *unixTime) afterLoad() {} + +// +checklocksignore +func (u *unixTime) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &u.second) + stateSourceObject.Load(1, &u.nano) +} + +func (it *IPTables) StateTypeName() string { + return "pkg/tcpip/stack.IPTables" +} + +func (it *IPTables) StateFields() []string { + return []string{ + "mu", + "v4Tables", + "v6Tables", + "modified", + "priorities", + "connections", + "reaperDone", + } +} + +// +checklocksignore +func (it *IPTables) StateSave(stateSinkObject state.Sink) { + it.beforeSave() + stateSinkObject.Save(0, &it.mu) + stateSinkObject.Save(1, &it.v4Tables) + stateSinkObject.Save(2, &it.v6Tables) + stateSinkObject.Save(3, &it.modified) + stateSinkObject.Save(4, &it.priorities) + stateSinkObject.Save(5, &it.connections) + stateSinkObject.Save(6, &it.reaperDone) +} + +// +checklocksignore +func (it *IPTables) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &it.mu) + stateSourceObject.Load(1, &it.v4Tables) + stateSourceObject.Load(2, &it.v6Tables) + stateSourceObject.Load(3, &it.modified) + stateSourceObject.Load(4, &it.priorities) + stateSourceObject.Load(5, &it.connections) + stateSourceObject.Load(6, &it.reaperDone) + stateSourceObject.AfterLoad(it.afterLoad) +} + +func (table *Table) StateTypeName() string { + return "pkg/tcpip/stack.Table" +} + +func (table *Table) StateFields() []string { + return []string{ + "Rules", + "BuiltinChains", + "Underflows", + } +} + +func (table *Table) beforeSave() {} + +// +checklocksignore +func (table *Table) StateSave(stateSinkObject state.Sink) { + table.beforeSave() + stateSinkObject.Save(0, &table.Rules) + stateSinkObject.Save(1, &table.BuiltinChains) + stateSinkObject.Save(2, &table.Underflows) +} + +func (table *Table) afterLoad() {} + +// +checklocksignore +func (table *Table) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &table.Rules) + stateSourceObject.Load(1, &table.BuiltinChains) + stateSourceObject.Load(2, &table.Underflows) +} + +func (r *Rule) StateTypeName() string { + return "pkg/tcpip/stack.Rule" +} + +func (r *Rule) StateFields() []string { + return []string{ + "Filter", + "Matchers", + "Target", + } +} + +func (r *Rule) beforeSave() {} + +// +checklocksignore +func (r *Rule) StateSave(stateSinkObject state.Sink) { + r.beforeSave() + stateSinkObject.Save(0, &r.Filter) + stateSinkObject.Save(1, &r.Matchers) + stateSinkObject.Save(2, &r.Target) +} + +func (r *Rule) afterLoad() {} + +// +checklocksignore +func (r *Rule) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &r.Filter) + stateSourceObject.Load(1, &r.Matchers) + stateSourceObject.Load(2, &r.Target) +} + +func (fl *IPHeaderFilter) StateTypeName() string { + return "pkg/tcpip/stack.IPHeaderFilter" +} + +func (fl *IPHeaderFilter) StateFields() []string { + return []string{ + "Protocol", + "CheckProtocol", + "Dst", + "DstMask", + "DstInvert", + "Src", + "SrcMask", + "SrcInvert", + "InputInterface", + "InputInterfaceMask", + "InputInterfaceInvert", + "OutputInterface", + "OutputInterfaceMask", + "OutputInterfaceInvert", + } +} + +func (fl *IPHeaderFilter) beforeSave() {} + +// +checklocksignore +func (fl *IPHeaderFilter) StateSave(stateSinkObject state.Sink) { + fl.beforeSave() + stateSinkObject.Save(0, &fl.Protocol) + stateSinkObject.Save(1, &fl.CheckProtocol) + stateSinkObject.Save(2, &fl.Dst) + stateSinkObject.Save(3, &fl.DstMask) + stateSinkObject.Save(4, &fl.DstInvert) + stateSinkObject.Save(5, &fl.Src) + stateSinkObject.Save(6, &fl.SrcMask) + stateSinkObject.Save(7, &fl.SrcInvert) + stateSinkObject.Save(8, &fl.InputInterface) + stateSinkObject.Save(9, &fl.InputInterfaceMask) + stateSinkObject.Save(10, &fl.InputInterfaceInvert) + stateSinkObject.Save(11, &fl.OutputInterface) + stateSinkObject.Save(12, &fl.OutputInterfaceMask) + stateSinkObject.Save(13, &fl.OutputInterfaceInvert) +} + +func (fl *IPHeaderFilter) afterLoad() {} + +// +checklocksignore +func (fl *IPHeaderFilter) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &fl.Protocol) + stateSourceObject.Load(1, &fl.CheckProtocol) + stateSourceObject.Load(2, &fl.Dst) + stateSourceObject.Load(3, &fl.DstMask) + stateSourceObject.Load(4, &fl.DstInvert) + stateSourceObject.Load(5, &fl.Src) + stateSourceObject.Load(6, &fl.SrcMask) + stateSourceObject.Load(7, &fl.SrcInvert) + stateSourceObject.Load(8, &fl.InputInterface) + stateSourceObject.Load(9, &fl.InputInterfaceMask) + stateSourceObject.Load(10, &fl.InputInterfaceInvert) + stateSourceObject.Load(11, &fl.OutputInterface) + stateSourceObject.Load(12, &fl.OutputInterfaceMask) + stateSourceObject.Load(13, &fl.OutputInterfaceInvert) +} + +func (l *neighborEntryList) StateTypeName() string { + return "pkg/tcpip/stack.neighborEntryList" +} + +func (l *neighborEntryList) StateFields() []string { + return []string{ + "head", + "tail", + } +} + +func (l *neighborEntryList) beforeSave() {} + +// +checklocksignore +func (l *neighborEntryList) StateSave(stateSinkObject state.Sink) { + l.beforeSave() + stateSinkObject.Save(0, &l.head) + stateSinkObject.Save(1, &l.tail) +} + +func (l *neighborEntryList) afterLoad() {} + +// +checklocksignore +func (l *neighborEntryList) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &l.head) + stateSourceObject.Load(1, &l.tail) +} + +func (e *neighborEntryEntry) StateTypeName() string { + return "pkg/tcpip/stack.neighborEntryEntry" +} + +func (e *neighborEntryEntry) StateFields() []string { + return []string{ + "next", + "prev", + } +} + +func (e *neighborEntryEntry) beforeSave() {} + +// +checklocksignore +func (e *neighborEntryEntry) StateSave(stateSinkObject state.Sink) { + e.beforeSave() + stateSinkObject.Save(0, &e.next) + stateSinkObject.Save(1, &e.prev) +} + +func (e *neighborEntryEntry) afterLoad() {} + +// +checklocksignore +func (e *neighborEntryEntry) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &e.next) + stateSourceObject.Load(1, &e.prev) +} + +func (p *PacketBufferList) StateTypeName() string { + return "pkg/tcpip/stack.PacketBufferList" +} + +func (p *PacketBufferList) StateFields() []string { + return []string{ + "head", + "tail", + } +} + +func (p *PacketBufferList) beforeSave() {} + +// +checklocksignore +func (p *PacketBufferList) StateSave(stateSinkObject state.Sink) { + p.beforeSave() + stateSinkObject.Save(0, &p.head) + stateSinkObject.Save(1, &p.tail) +} + +func (p *PacketBufferList) afterLoad() {} + +// +checklocksignore +func (p *PacketBufferList) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &p.head) + stateSourceObject.Load(1, &p.tail) +} + +func (e *PacketBufferEntry) StateTypeName() string { + return "pkg/tcpip/stack.PacketBufferEntry" +} + +func (e *PacketBufferEntry) StateFields() []string { + return []string{ + "next", + "prev", + } +} + +func (e *PacketBufferEntry) beforeSave() {} + +// +checklocksignore +func (e *PacketBufferEntry) StateSave(stateSinkObject state.Sink) { + e.beforeSave() + stateSinkObject.Save(0, &e.next) + stateSinkObject.Save(1, &e.prev) +} + +func (e *PacketBufferEntry) afterLoad() {} + +// +checklocksignore +func (e *PacketBufferEntry) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &e.next) + stateSourceObject.Load(1, &e.prev) +} + +func (t *TransportEndpointID) StateTypeName() string { + return "pkg/tcpip/stack.TransportEndpointID" +} + +func (t *TransportEndpointID) StateFields() []string { + return []string{ + "LocalPort", + "LocalAddress", + "RemotePort", + "RemoteAddress", + } +} + +func (t *TransportEndpointID) beforeSave() {} + +// +checklocksignore +func (t *TransportEndpointID) StateSave(stateSinkObject state.Sink) { + t.beforeSave() + stateSinkObject.Save(0, &t.LocalPort) + stateSinkObject.Save(1, &t.LocalAddress) + stateSinkObject.Save(2, &t.RemotePort) + stateSinkObject.Save(3, &t.RemoteAddress) +} + +func (t *TransportEndpointID) afterLoad() {} + +// +checklocksignore +func (t *TransportEndpointID) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &t.LocalPort) + stateSourceObject.Load(1, &t.LocalAddress) + stateSourceObject.Load(2, &t.RemotePort) + stateSourceObject.Load(3, &t.RemoteAddress) +} + +func (g *GSOType) StateTypeName() string { + return "pkg/tcpip/stack.GSOType" +} + +func (g *GSOType) StateFields() []string { + return nil +} + +func (g *GSO) StateTypeName() string { + return "pkg/tcpip/stack.GSO" +} + +func (g *GSO) StateFields() []string { + return []string{ + "Type", + "NeedsCsum", + "CsumOffset", + "MSS", + "L3HdrLen", + "MaxSize", + } +} + +func (g *GSO) beforeSave() {} + +// +checklocksignore +func (g *GSO) StateSave(stateSinkObject state.Sink) { + g.beforeSave() + stateSinkObject.Save(0, &g.Type) + stateSinkObject.Save(1, &g.NeedsCsum) + stateSinkObject.Save(2, &g.CsumOffset) + stateSinkObject.Save(3, &g.MSS) + stateSinkObject.Save(4, &g.L3HdrLen) + stateSinkObject.Save(5, &g.MaxSize) +} + +func (g *GSO) afterLoad() {} + +// +checklocksignore +func (g *GSO) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &g.Type) + stateSourceObject.Load(1, &g.NeedsCsum) + stateSourceObject.Load(2, &g.CsumOffset) + stateSourceObject.Load(3, &g.MSS) + stateSourceObject.Load(4, &g.L3HdrLen) + stateSourceObject.Load(5, &g.MaxSize) +} + +func (t *TransportEndpointInfo) StateTypeName() string { + return "pkg/tcpip/stack.TransportEndpointInfo" +} + +func (t *TransportEndpointInfo) StateFields() []string { + return []string{ + "NetProto", + "TransProto", + "ID", + "BindNICID", + "BindAddr", + "RegisterNICID", + } +} + +func (t *TransportEndpointInfo) beforeSave() {} + +// +checklocksignore +func (t *TransportEndpointInfo) StateSave(stateSinkObject state.Sink) { + t.beforeSave() + stateSinkObject.Save(0, &t.NetProto) + stateSinkObject.Save(1, &t.TransProto) + stateSinkObject.Save(2, &t.ID) + stateSinkObject.Save(3, &t.BindNICID) + stateSinkObject.Save(4, &t.BindAddr) + stateSinkObject.Save(5, &t.RegisterNICID) +} + +func (t *TransportEndpointInfo) afterLoad() {} + +// +checklocksignore +func (t *TransportEndpointInfo) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &t.NetProto) + stateSourceObject.Load(1, &t.TransProto) + stateSourceObject.Load(2, &t.ID) + stateSourceObject.Load(3, &t.BindNICID) + stateSourceObject.Load(4, &t.BindAddr) + stateSourceObject.Load(5, &t.RegisterNICID) +} + +func (t *TCPCubicState) StateTypeName() string { + return "pkg/tcpip/stack.TCPCubicState" +} + +func (t *TCPCubicState) StateFields() []string { + return []string{ + "WLastMax", + "WMax", + "T", + "TimeSinceLastCongestion", + "C", + "K", + "Beta", + "WC", + "WEst", + } +} + +func (t *TCPCubicState) beforeSave() {} + +// +checklocksignore +func (t *TCPCubicState) StateSave(stateSinkObject state.Sink) { + t.beforeSave() + stateSinkObject.Save(0, &t.WLastMax) + stateSinkObject.Save(1, &t.WMax) + stateSinkObject.Save(2, &t.T) + stateSinkObject.Save(3, &t.TimeSinceLastCongestion) + stateSinkObject.Save(4, &t.C) + stateSinkObject.Save(5, &t.K) + stateSinkObject.Save(6, &t.Beta) + stateSinkObject.Save(7, &t.WC) + stateSinkObject.Save(8, &t.WEst) +} + +func (t *TCPCubicState) afterLoad() {} + +// +checklocksignore +func (t *TCPCubicState) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &t.WLastMax) + stateSourceObject.Load(1, &t.WMax) + stateSourceObject.Load(2, &t.T) + stateSourceObject.Load(3, &t.TimeSinceLastCongestion) + stateSourceObject.Load(4, &t.C) + stateSourceObject.Load(5, &t.K) + stateSourceObject.Load(6, &t.Beta) + stateSourceObject.Load(7, &t.WC) + stateSourceObject.Load(8, &t.WEst) +} + +func (t *TCPRACKState) StateTypeName() string { + return "pkg/tcpip/stack.TCPRACKState" +} + +func (t *TCPRACKState) StateFields() []string { + return []string{ + "XmitTime", + "EndSequence", + "FACK", + "RTT", + "Reord", + "DSACKSeen", + "ReoWnd", + "ReoWndIncr", + "ReoWndPersist", + "RTTSeq", + } +} + +func (t *TCPRACKState) beforeSave() {} + +// +checklocksignore +func (t *TCPRACKState) StateSave(stateSinkObject state.Sink) { + t.beforeSave() + stateSinkObject.Save(0, &t.XmitTime) + stateSinkObject.Save(1, &t.EndSequence) + stateSinkObject.Save(2, &t.FACK) + stateSinkObject.Save(3, &t.RTT) + stateSinkObject.Save(4, &t.Reord) + stateSinkObject.Save(5, &t.DSACKSeen) + stateSinkObject.Save(6, &t.ReoWnd) + stateSinkObject.Save(7, &t.ReoWndIncr) + stateSinkObject.Save(8, &t.ReoWndPersist) + stateSinkObject.Save(9, &t.RTTSeq) +} + +func (t *TCPRACKState) afterLoad() {} + +// +checklocksignore +func (t *TCPRACKState) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &t.XmitTime) + stateSourceObject.Load(1, &t.EndSequence) + stateSourceObject.Load(2, &t.FACK) + stateSourceObject.Load(3, &t.RTT) + stateSourceObject.Load(4, &t.Reord) + stateSourceObject.Load(5, &t.DSACKSeen) + stateSourceObject.Load(6, &t.ReoWnd) + stateSourceObject.Load(7, &t.ReoWndIncr) + stateSourceObject.Load(8, &t.ReoWndPersist) + stateSourceObject.Load(9, &t.RTTSeq) +} + +func (t *TCPEndpointID) StateTypeName() string { + return "pkg/tcpip/stack.TCPEndpointID" +} + +func (t *TCPEndpointID) StateFields() []string { + return []string{ + "LocalPort", + "LocalAddress", + "RemotePort", + "RemoteAddress", + } +} + +func (t *TCPEndpointID) beforeSave() {} + +// +checklocksignore +func (t *TCPEndpointID) StateSave(stateSinkObject state.Sink) { + t.beforeSave() + stateSinkObject.Save(0, &t.LocalPort) + stateSinkObject.Save(1, &t.LocalAddress) + stateSinkObject.Save(2, &t.RemotePort) + stateSinkObject.Save(3, &t.RemoteAddress) +} + +func (t *TCPEndpointID) afterLoad() {} + +// +checklocksignore +func (t *TCPEndpointID) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &t.LocalPort) + stateSourceObject.Load(1, &t.LocalAddress) + stateSourceObject.Load(2, &t.RemotePort) + stateSourceObject.Load(3, &t.RemoteAddress) +} + +func (t *TCPFastRecoveryState) StateTypeName() string { + return "pkg/tcpip/stack.TCPFastRecoveryState" +} + +func (t *TCPFastRecoveryState) StateFields() []string { + return []string{ + "Active", + "First", + "Last", + "MaxCwnd", + "HighRxt", + "RescueRxt", + } +} + +func (t *TCPFastRecoveryState) beforeSave() {} + +// +checklocksignore +func (t *TCPFastRecoveryState) StateSave(stateSinkObject state.Sink) { + t.beforeSave() + stateSinkObject.Save(0, &t.Active) + stateSinkObject.Save(1, &t.First) + stateSinkObject.Save(2, &t.Last) + stateSinkObject.Save(3, &t.MaxCwnd) + stateSinkObject.Save(4, &t.HighRxt) + stateSinkObject.Save(5, &t.RescueRxt) +} + +func (t *TCPFastRecoveryState) afterLoad() {} + +// +checklocksignore +func (t *TCPFastRecoveryState) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &t.Active) + stateSourceObject.Load(1, &t.First) + stateSourceObject.Load(2, &t.Last) + stateSourceObject.Load(3, &t.MaxCwnd) + stateSourceObject.Load(4, &t.HighRxt) + stateSourceObject.Load(5, &t.RescueRxt) +} + +func (t *TCPReceiverState) StateTypeName() string { + return "pkg/tcpip/stack.TCPReceiverState" +} + +func (t *TCPReceiverState) StateFields() []string { + return []string{ + "RcvNxt", + "RcvAcc", + "RcvWndScale", + "PendingBufUsed", + } +} + +func (t *TCPReceiverState) beforeSave() {} + +// +checklocksignore +func (t *TCPReceiverState) StateSave(stateSinkObject state.Sink) { + t.beforeSave() + stateSinkObject.Save(0, &t.RcvNxt) + stateSinkObject.Save(1, &t.RcvAcc) + stateSinkObject.Save(2, &t.RcvWndScale) + stateSinkObject.Save(3, &t.PendingBufUsed) +} + +func (t *TCPReceiverState) afterLoad() {} + +// +checklocksignore +func (t *TCPReceiverState) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &t.RcvNxt) + stateSourceObject.Load(1, &t.RcvAcc) + stateSourceObject.Load(2, &t.RcvWndScale) + stateSourceObject.Load(3, &t.PendingBufUsed) +} + +func (t *TCPRTTState) StateTypeName() string { + return "pkg/tcpip/stack.TCPRTTState" +} + +func (t *TCPRTTState) StateFields() []string { + return []string{ + "SRTT", + "RTTVar", + "SRTTInited", + } +} + +func (t *TCPRTTState) beforeSave() {} + +// +checklocksignore +func (t *TCPRTTState) StateSave(stateSinkObject state.Sink) { + t.beforeSave() + stateSinkObject.Save(0, &t.SRTT) + stateSinkObject.Save(1, &t.RTTVar) + stateSinkObject.Save(2, &t.SRTTInited) +} + +func (t *TCPRTTState) afterLoad() {} + +// +checklocksignore +func (t *TCPRTTState) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &t.SRTT) + stateSourceObject.Load(1, &t.RTTVar) + stateSourceObject.Load(2, &t.SRTTInited) +} + +func (t *TCPSenderState) StateTypeName() string { + return "pkg/tcpip/stack.TCPSenderState" +} + +func (t *TCPSenderState) StateFields() []string { + return []string{ + "LastSendTime", + "DupAckCount", + "SndCwnd", + "Ssthresh", + "SndCAAckCount", + "Outstanding", + "SackedOut", + "SndWnd", + "SndUna", + "SndNxt", + "RTTMeasureSeqNum", + "RTTMeasureTime", + "Closed", + "RTO", + "RTTState", + "MaxPayloadSize", + "SndWndScale", + "MaxSentAck", + "FastRecovery", + "Cubic", + "RACKState", + } +} + +func (t *TCPSenderState) beforeSave() {} + +// +checklocksignore +func (t *TCPSenderState) StateSave(stateSinkObject state.Sink) { + t.beforeSave() + stateSinkObject.Save(0, &t.LastSendTime) + stateSinkObject.Save(1, &t.DupAckCount) + stateSinkObject.Save(2, &t.SndCwnd) + stateSinkObject.Save(3, &t.Ssthresh) + stateSinkObject.Save(4, &t.SndCAAckCount) + stateSinkObject.Save(5, &t.Outstanding) + stateSinkObject.Save(6, &t.SackedOut) + stateSinkObject.Save(7, &t.SndWnd) + stateSinkObject.Save(8, &t.SndUna) + stateSinkObject.Save(9, &t.SndNxt) + stateSinkObject.Save(10, &t.RTTMeasureSeqNum) + stateSinkObject.Save(11, &t.RTTMeasureTime) + stateSinkObject.Save(12, &t.Closed) + stateSinkObject.Save(13, &t.RTO) + stateSinkObject.Save(14, &t.RTTState) + stateSinkObject.Save(15, &t.MaxPayloadSize) + stateSinkObject.Save(16, &t.SndWndScale) + stateSinkObject.Save(17, &t.MaxSentAck) + stateSinkObject.Save(18, &t.FastRecovery) + stateSinkObject.Save(19, &t.Cubic) + stateSinkObject.Save(20, &t.RACKState) +} + +func (t *TCPSenderState) afterLoad() {} + +// +checklocksignore +func (t *TCPSenderState) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &t.LastSendTime) + stateSourceObject.Load(1, &t.DupAckCount) + stateSourceObject.Load(2, &t.SndCwnd) + stateSourceObject.Load(3, &t.Ssthresh) + stateSourceObject.Load(4, &t.SndCAAckCount) + stateSourceObject.Load(5, &t.Outstanding) + stateSourceObject.Load(6, &t.SackedOut) + stateSourceObject.Load(7, &t.SndWnd) + stateSourceObject.Load(8, &t.SndUna) + stateSourceObject.Load(9, &t.SndNxt) + stateSourceObject.Load(10, &t.RTTMeasureSeqNum) + stateSourceObject.Load(11, &t.RTTMeasureTime) + stateSourceObject.Load(12, &t.Closed) + stateSourceObject.Load(13, &t.RTO) + stateSourceObject.Load(14, &t.RTTState) + stateSourceObject.Load(15, &t.MaxPayloadSize) + stateSourceObject.Load(16, &t.SndWndScale) + stateSourceObject.Load(17, &t.MaxSentAck) + stateSourceObject.Load(18, &t.FastRecovery) + stateSourceObject.Load(19, &t.Cubic) + stateSourceObject.Load(20, &t.RACKState) +} + +func (t *TCPSACKInfo) StateTypeName() string { + return "pkg/tcpip/stack.TCPSACKInfo" +} + +func (t *TCPSACKInfo) StateFields() []string { + return []string{ + "Blocks", + "ReceivedBlocks", + "MaxSACKED", + } +} + +func (t *TCPSACKInfo) beforeSave() {} + +// +checklocksignore +func (t *TCPSACKInfo) StateSave(stateSinkObject state.Sink) { + t.beforeSave() + stateSinkObject.Save(0, &t.Blocks) + stateSinkObject.Save(1, &t.ReceivedBlocks) + stateSinkObject.Save(2, &t.MaxSACKED) +} + +func (t *TCPSACKInfo) afterLoad() {} + +// +checklocksignore +func (t *TCPSACKInfo) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &t.Blocks) + stateSourceObject.Load(1, &t.ReceivedBlocks) + stateSourceObject.Load(2, &t.MaxSACKED) +} + +func (r *RcvBufAutoTuneParams) StateTypeName() string { + return "pkg/tcpip/stack.RcvBufAutoTuneParams" +} + +func (r *RcvBufAutoTuneParams) StateFields() []string { + return []string{ + "MeasureTime", + "CopiedBytes", + "PrevCopiedBytes", + "RcvBufSize", + "RTT", + "RTTVar", + "RTTMeasureSeqNumber", + "RTTMeasureTime", + "Disabled", + } +} + +func (r *RcvBufAutoTuneParams) beforeSave() {} + +// +checklocksignore +func (r *RcvBufAutoTuneParams) StateSave(stateSinkObject state.Sink) { + r.beforeSave() + stateSinkObject.Save(0, &r.MeasureTime) + stateSinkObject.Save(1, &r.CopiedBytes) + stateSinkObject.Save(2, &r.PrevCopiedBytes) + stateSinkObject.Save(3, &r.RcvBufSize) + stateSinkObject.Save(4, &r.RTT) + stateSinkObject.Save(5, &r.RTTVar) + stateSinkObject.Save(6, &r.RTTMeasureSeqNumber) + stateSinkObject.Save(7, &r.RTTMeasureTime) + stateSinkObject.Save(8, &r.Disabled) +} + +func (r *RcvBufAutoTuneParams) afterLoad() {} + +// +checklocksignore +func (r *RcvBufAutoTuneParams) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &r.MeasureTime) + stateSourceObject.Load(1, &r.CopiedBytes) + stateSourceObject.Load(2, &r.PrevCopiedBytes) + stateSourceObject.Load(3, &r.RcvBufSize) + stateSourceObject.Load(4, &r.RTT) + stateSourceObject.Load(5, &r.RTTVar) + stateSourceObject.Load(6, &r.RTTMeasureSeqNumber) + stateSourceObject.Load(7, &r.RTTMeasureTime) + stateSourceObject.Load(8, &r.Disabled) +} + +func (t *TCPRcvBufState) StateTypeName() string { + return "pkg/tcpip/stack.TCPRcvBufState" +} + +func (t *TCPRcvBufState) StateFields() []string { + return []string{ + "RcvBufUsed", + "RcvAutoParams", + "RcvClosed", + } +} + +func (t *TCPRcvBufState) beforeSave() {} + +// +checklocksignore +func (t *TCPRcvBufState) StateSave(stateSinkObject state.Sink) { + t.beforeSave() + stateSinkObject.Save(0, &t.RcvBufUsed) + stateSinkObject.Save(1, &t.RcvAutoParams) + stateSinkObject.Save(2, &t.RcvClosed) +} + +func (t *TCPRcvBufState) afterLoad() {} + +// +checklocksignore +func (t *TCPRcvBufState) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &t.RcvBufUsed) + stateSourceObject.Load(1, &t.RcvAutoParams) + stateSourceObject.Load(2, &t.RcvClosed) +} + +func (t *TCPSndBufState) StateTypeName() string { + return "pkg/tcpip/stack.TCPSndBufState" +} + +func (t *TCPSndBufState) StateFields() []string { + return []string{ + "SndBufSize", + "SndBufUsed", + "SndClosed", + "PacketTooBigCount", + "SndMTU", + "AutoTuneSndBufDisabled", + } +} + +func (t *TCPSndBufState) beforeSave() {} + +// +checklocksignore +func (t *TCPSndBufState) StateSave(stateSinkObject state.Sink) { + t.beforeSave() + stateSinkObject.Save(0, &t.SndBufSize) + stateSinkObject.Save(1, &t.SndBufUsed) + stateSinkObject.Save(2, &t.SndClosed) + stateSinkObject.Save(3, &t.PacketTooBigCount) + stateSinkObject.Save(4, &t.SndMTU) + stateSinkObject.Save(5, &t.AutoTuneSndBufDisabled) +} + +func (t *TCPSndBufState) afterLoad() {} + +// +checklocksignore +func (t *TCPSndBufState) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &t.SndBufSize) + stateSourceObject.Load(1, &t.SndBufUsed) + stateSourceObject.Load(2, &t.SndClosed) + stateSourceObject.Load(3, &t.PacketTooBigCount) + stateSourceObject.Load(4, &t.SndMTU) + stateSourceObject.Load(5, &t.AutoTuneSndBufDisabled) +} + +func (t *TCPEndpointStateInner) StateTypeName() string { + return "pkg/tcpip/stack.TCPEndpointStateInner" +} + +func (t *TCPEndpointStateInner) StateFields() []string { + return []string{ + "TSOffset", + "SACKPermitted", + "SendTSOk", + "RecentTS", + } +} + +func (t *TCPEndpointStateInner) beforeSave() {} + +// +checklocksignore +func (t *TCPEndpointStateInner) StateSave(stateSinkObject state.Sink) { + t.beforeSave() + stateSinkObject.Save(0, &t.TSOffset) + stateSinkObject.Save(1, &t.SACKPermitted) + stateSinkObject.Save(2, &t.SendTSOk) + stateSinkObject.Save(3, &t.RecentTS) +} + +func (t *TCPEndpointStateInner) afterLoad() {} + +// +checklocksignore +func (t *TCPEndpointStateInner) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &t.TSOffset) + stateSourceObject.Load(1, &t.SACKPermitted) + stateSourceObject.Load(2, &t.SendTSOk) + stateSourceObject.Load(3, &t.RecentTS) +} + +func (t *TCPEndpointState) StateTypeName() string { + return "pkg/tcpip/stack.TCPEndpointState" +} + +func (t *TCPEndpointState) StateFields() []string { + return []string{ + "TCPEndpointStateInner", + "ID", + "SegTime", + "RcvBufState", + "SndBufState", + "SACK", + "Receiver", + "Sender", + } +} + +func (t *TCPEndpointState) beforeSave() {} + +// +checklocksignore +func (t *TCPEndpointState) StateSave(stateSinkObject state.Sink) { + t.beforeSave() + stateSinkObject.Save(0, &t.TCPEndpointStateInner) + stateSinkObject.Save(1, &t.ID) + stateSinkObject.Save(2, &t.SegTime) + stateSinkObject.Save(3, &t.RcvBufState) + stateSinkObject.Save(4, &t.SndBufState) + stateSinkObject.Save(5, &t.SACK) + stateSinkObject.Save(6, &t.Receiver) + stateSinkObject.Save(7, &t.Sender) +} + +func (t *TCPEndpointState) afterLoad() {} + +// +checklocksignore +func (t *TCPEndpointState) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &t.TCPEndpointStateInner) + stateSourceObject.Load(1, &t.ID) + stateSourceObject.Load(2, &t.SegTime) + stateSourceObject.Load(3, &t.RcvBufState) + stateSourceObject.Load(4, &t.SndBufState) + stateSourceObject.Load(5, &t.SACK) + stateSourceObject.Load(6, &t.Receiver) + stateSourceObject.Load(7, &t.Sender) +} + +func (ep *multiPortEndpoint) StateTypeName() string { + return "pkg/tcpip/stack.multiPortEndpoint" +} + +func (ep *multiPortEndpoint) StateFields() []string { + return []string{ + "demux", + "netProto", + "transProto", + "endpoints", + "flags", + } +} + +func (ep *multiPortEndpoint) beforeSave() {} + +// +checklocksignore +func (ep *multiPortEndpoint) StateSave(stateSinkObject state.Sink) { + ep.beforeSave() + stateSinkObject.Save(0, &ep.demux) + stateSinkObject.Save(1, &ep.netProto) + stateSinkObject.Save(2, &ep.transProto) + stateSinkObject.Save(3, &ep.endpoints) + stateSinkObject.Save(4, &ep.flags) +} + +func (ep *multiPortEndpoint) afterLoad() {} + +// +checklocksignore +func (ep *multiPortEndpoint) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &ep.demux) + stateSourceObject.Load(1, &ep.netProto) + stateSourceObject.Load(2, &ep.transProto) + stateSourceObject.Load(3, &ep.endpoints) + stateSourceObject.Load(4, &ep.flags) +} + +func (l *tupleList) StateTypeName() string { + return "pkg/tcpip/stack.tupleList" +} + +func (l *tupleList) StateFields() []string { + return []string{ + "head", + "tail", + } +} + +func (l *tupleList) beforeSave() {} + +// +checklocksignore +func (l *tupleList) StateSave(stateSinkObject state.Sink) { + l.beforeSave() + stateSinkObject.Save(0, &l.head) + stateSinkObject.Save(1, &l.tail) +} + +func (l *tupleList) afterLoad() {} + +// +checklocksignore +func (l *tupleList) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &l.head) + stateSourceObject.Load(1, &l.tail) +} + +func (e *tupleEntry) StateTypeName() string { + return "pkg/tcpip/stack.tupleEntry" +} + +func (e *tupleEntry) StateFields() []string { + return []string{ + "next", + "prev", + } +} + +func (e *tupleEntry) beforeSave() {} + +// +checklocksignore +func (e *tupleEntry) StateSave(stateSinkObject state.Sink) { + e.beforeSave() + stateSinkObject.Save(0, &e.next) + stateSinkObject.Save(1, &e.prev) +} + +func (e *tupleEntry) afterLoad() {} + +// +checklocksignore +func (e *tupleEntry) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &e.next) + stateSourceObject.Load(1, &e.prev) +} + +func init() { + state.Register((*tuple)(nil)) + state.Register((*tupleID)(nil)) + state.Register((*conn)(nil)) + state.Register((*ConnTrack)(nil)) + state.Register((*bucket)(nil)) + state.Register((*unixTime)(nil)) + state.Register((*IPTables)(nil)) + state.Register((*Table)(nil)) + state.Register((*Rule)(nil)) + state.Register((*IPHeaderFilter)(nil)) + state.Register((*neighborEntryList)(nil)) + state.Register((*neighborEntryEntry)(nil)) + state.Register((*PacketBufferList)(nil)) + state.Register((*PacketBufferEntry)(nil)) + state.Register((*TransportEndpointID)(nil)) + state.Register((*GSOType)(nil)) + state.Register((*GSO)(nil)) + state.Register((*TransportEndpointInfo)(nil)) + state.Register((*TCPCubicState)(nil)) + state.Register((*TCPRACKState)(nil)) + state.Register((*TCPEndpointID)(nil)) + state.Register((*TCPFastRecoveryState)(nil)) + state.Register((*TCPReceiverState)(nil)) + state.Register((*TCPRTTState)(nil)) + state.Register((*TCPSenderState)(nil)) + state.Register((*TCPSACKInfo)(nil)) + state.Register((*RcvBufAutoTuneParams)(nil)) + state.Register((*TCPRcvBufState)(nil)) + state.Register((*TCPSndBufState)(nil)) + state.Register((*TCPEndpointStateInner)(nil)) + state.Register((*TCPEndpointState)(nil)) + state.Register((*multiPortEndpoint)(nil)) + state.Register((*tupleList)(nil)) + state.Register((*tupleEntry)(nil)) +} diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go deleted file mode 100644 index 3089c0ef4..000000000 --- a/pkg/tcpip/stack/stack_test.go +++ /dev/null @@ -1,4564 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package stack_test contains tests for the stack. It is in its own package so -// that the tests can also validate that all definitions needed to implement -// transport and network protocols are properly exported by the stack package. -package stack_test - -import ( - "bytes" - "fmt" - "math" - "net" - "sort" - "testing" - "time" - - "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/rand" - "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/faketime" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/link/loopback" - "gvisor.dev/gvisor/pkg/tcpip/network/arp" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/testutil" - "gvisor.dev/gvisor/pkg/tcpip/transport/udp" -) - -const ( - fakeNetNumber tcpip.NetworkProtocolNumber = math.MaxUint32 - fakeNetHeaderLen = 12 - fakeDefaultPrefixLen = 8 - - // fakeControlProtocol is used for control packets that represent - // destination port unreachable. - fakeControlProtocol tcpip.TransportProtocolNumber = 2 - - // defaultMTU is the MTU, in bytes, used throughout the tests, except - // where another value is explicitly used. It is chosen to match the MTU - // of loopback interfaces on linux systems. - defaultMTU = 65536 - - dstAddrOffset = 0 - srcAddrOffset = 1 - protocolNumberOffset = 2 -) - -func checkGetMainNICAddress(s *stack.Stack, nicID tcpip.NICID, proto tcpip.NetworkProtocolNumber, want tcpip.AddressWithPrefix) error { - if addr, err := s.GetMainNICAddress(nicID, proto); err != nil { - return fmt.Errorf("stack.GetMainNICAddress(%d, %d): %s", nicID, proto, err) - } else if addr != want { - return fmt.Errorf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, proto, addr, want) - } - return nil -} - -// fakeNetworkEndpoint is a network-layer protocol endpoint. It counts sent and -// received packets; the counts of all endpoints are aggregated in the protocol -// descriptor. -// -// Headers of this protocol are fakeNetHeaderLen bytes, but we currently only -// use the first three: destination address, source address, and transport -// protocol. They're all one byte fields to simplify parsing. -type fakeNetworkEndpoint struct { - stack.AddressableEndpointState - - mu struct { - sync.RWMutex - - enabled bool - forwarding bool - } - - nic stack.NetworkInterface - proto *fakeNetworkProtocol - dispatcher stack.TransportDispatcher -} - -func (f *fakeNetworkEndpoint) Enable() tcpip.Error { - f.mu.Lock() - defer f.mu.Unlock() - f.mu.enabled = true - return nil -} - -func (f *fakeNetworkEndpoint) Enabled() bool { - f.mu.RLock() - defer f.mu.RUnlock() - return f.mu.enabled -} - -func (f *fakeNetworkEndpoint) Disable() { - f.mu.Lock() - defer f.mu.Unlock() - f.mu.enabled = false -} - -func (f *fakeNetworkEndpoint) MTU() uint32 { - return f.nic.MTU() - uint32(f.MaxHeaderLength()) -} - -func (*fakeNetworkEndpoint) DefaultTTL() uint8 { - return 123 -} - -func (f *fakeNetworkEndpoint) HandlePacket(pkt *stack.PacketBuffer) { - if _, _, ok := f.proto.Parse(pkt); !ok { - return - } - - // Increment the received packet count in the protocol descriptor. - netHdr := pkt.NetworkHeader().View() - - dst := tcpip.Address(netHdr[dstAddrOffset:][:1]) - addressEndpoint := f.AcquireAssignedAddress(dst, f.nic.Promiscuous(), stack.CanBePrimaryEndpoint) - if addressEndpoint == nil { - return - } - addressEndpoint.DecRef() - - f.proto.packetCount[int(dst[0])%len(f.proto.packetCount)]++ - - // Handle control packets. - if netHdr[protocolNumberOffset] == uint8(fakeControlProtocol) { - hdr, ok := pkt.Data().PullUp(fakeNetHeaderLen) - if !ok { - return - } - // DeleteFront invalidates slices. Make a copy before trimming. - nb := append([]byte(nil), hdr...) - pkt.Data().DeleteFront(fakeNetHeaderLen) - f.dispatcher.DeliverTransportError( - tcpip.Address(nb[srcAddrOffset:srcAddrOffset+1]), - tcpip.Address(nb[dstAddrOffset:dstAddrOffset+1]), - fakeNetNumber, - tcpip.TransportProtocolNumber(nb[protocolNumberOffset]), - // Nothing checks the error. - nil, /* transport error */ - pkt, - ) - return - } - - // Dispatch the packet to the transport protocol. - f.dispatcher.DeliverTransportPacket(tcpip.TransportProtocolNumber(pkt.NetworkHeader().View()[protocolNumberOffset]), pkt) -} - -func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 { - return f.nic.MaxHeaderLength() + fakeNetHeaderLen -} - -func (f *fakeNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { - return f.proto.Number() -} - -func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) tcpip.Error { - // Increment the sent packet count in the protocol descriptor. - f.proto.sendPacketCount[int(r.RemoteAddress()[0])%len(f.proto.sendPacketCount)]++ - - // Add the protocol's header to the packet and send it to the link - // endpoint. - hdr := pkt.NetworkHeader().Push(fakeNetHeaderLen) - pkt.NetworkProtocolNumber = fakeNetNumber - hdr[dstAddrOffset] = r.RemoteAddress()[0] - hdr[srcAddrOffset] = r.LocalAddress()[0] - hdr[protocolNumberOffset] = byte(params.Protocol) - - if r.Loop()&stack.PacketLoop != 0 { - f.HandlePacket(pkt.Clone()) - } - if r.Loop()&stack.PacketOut == 0 { - return nil - } - - return f.nic.WritePacket(r, fakeNetNumber, pkt) -} - -// WritePackets implements stack.LinkEndpoint.WritePackets. -func (*fakeNetworkEndpoint) WritePackets(*stack.Route, stack.PacketBufferList, stack.NetworkHeaderParams) (int, tcpip.Error) { - panic("not implemented") -} - -func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(*stack.Route, *stack.PacketBuffer) tcpip.Error { - return &tcpip.ErrNotSupported{} -} - -func (f *fakeNetworkEndpoint) Close() { - f.AddressableEndpointState.Cleanup() -} - -// Stats implements NetworkEndpoint. -func (*fakeNetworkEndpoint) Stats() stack.NetworkEndpointStats { - return &fakeNetworkEndpointStats{} -} - -var _ stack.NetworkEndpointStats = (*fakeNetworkEndpointStats)(nil) - -type fakeNetworkEndpointStats struct{} - -// IsNetworkEndpointStats implements stack.NetworkEndpointStats. -func (*fakeNetworkEndpointStats) IsNetworkEndpointStats() {} - -// fakeNetworkProtocol is a network-layer protocol descriptor. It aggregates the -// number of packets sent and received via endpoints of this protocol. The index -// where packets are added is given by the packet's destination address MOD 10. -type fakeNetworkProtocol struct { - packetCount [10]int - sendPacketCount [10]int - defaultTTL uint8 -} - -func (*fakeNetworkProtocol) Number() tcpip.NetworkProtocolNumber { - return fakeNetNumber -} - -func (*fakeNetworkProtocol) MinimumPacketSize() int { - return fakeNetHeaderLen -} - -func (*fakeNetworkProtocol) DefaultPrefixLen() int { - return fakeDefaultPrefixLen -} - -func (f *fakeNetworkProtocol) PacketCount(intfAddr byte) int { - return f.packetCount[int(intfAddr)%len(f.packetCount)] -} - -func (*fakeNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { - return tcpip.Address(v[srcAddrOffset : srcAddrOffset+1]), tcpip.Address(v[dstAddrOffset : dstAddrOffset+1]) -} - -func (f *fakeNetworkProtocol) NewEndpoint(nic stack.NetworkInterface, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint { - e := &fakeNetworkEndpoint{ - nic: nic, - proto: f, - dispatcher: dispatcher, - } - e.AddressableEndpointState.Init(e) - return e -} - -func (f *fakeNetworkProtocol) SetOption(option tcpip.SettableNetworkProtocolOption) tcpip.Error { - switch v := option.(type) { - case *tcpip.DefaultTTLOption: - f.defaultTTL = uint8(*v) - return nil - default: - return &tcpip.ErrUnknownProtocolOption{} - } -} - -func (f *fakeNetworkProtocol) Option(option tcpip.GettableNetworkProtocolOption) tcpip.Error { - switch v := option.(type) { - case *tcpip.DefaultTTLOption: - *v = tcpip.DefaultTTLOption(f.defaultTTL) - return nil - default: - return &tcpip.ErrUnknownProtocolOption{} - } -} - -// Close implements NetworkProtocol.Close. -func (*fakeNetworkProtocol) Close() {} - -// Wait implements NetworkProtocol.Wait. -func (*fakeNetworkProtocol) Wait() {} - -// Parse implements NetworkProtocol.Parse. -func (*fakeNetworkProtocol) Parse(pkt *stack.PacketBuffer) (tcpip.TransportProtocolNumber, bool, bool) { - hdr, ok := pkt.NetworkHeader().Consume(fakeNetHeaderLen) - if !ok { - return 0, false, false - } - pkt.NetworkProtocolNumber = fakeNetNumber - return tcpip.TransportProtocolNumber(hdr[protocolNumberOffset]), true, true -} - -// Forwarding implements stack.ForwardingNetworkEndpoint. -func (f *fakeNetworkEndpoint) Forwarding() bool { - f.mu.RLock() - defer f.mu.RUnlock() - return f.mu.forwarding -} - -// SetForwarding implements stack.ForwardingNetworkEndpoint. -func (f *fakeNetworkEndpoint) SetForwarding(v bool) { - f.mu.Lock() - defer f.mu.Unlock() - f.mu.forwarding = v -} - -func fakeNetFactory(*stack.Stack) stack.NetworkProtocol { - return &fakeNetworkProtocol{} -} - -// linkEPWithMockedAttach is a stack.LinkEndpoint that tests can use to verify -// that LinkEndpoint.Attach was called. -type linkEPWithMockedAttach struct { - stack.LinkEndpoint - attached bool -} - -// Attach implements stack.LinkEndpoint.Attach. -func (l *linkEPWithMockedAttach) Attach(d stack.NetworkDispatcher) { - l.LinkEndpoint.Attach(d) - l.attached = d != nil -} - -func (l *linkEPWithMockedAttach) isAttached() bool { - return l.attached -} - -// Checks to see if list contains an address. -func containsAddr(list []tcpip.ProtocolAddress, item tcpip.ProtocolAddress) bool { - for _, i := range list { - if i == item { - return true - } - } - - return false -} - -func TestNetworkReceive(t *testing.T) { - // Create a stack with the fake network protocol, one nic, and two - // addresses attached to it: 1 & 2. - ep := channel.New(10, defaultMTU, "") - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - if err := s.CreateNIC(1, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatal("AddAddress failed:", err) - } - - if err := s.AddAddress(1, fakeNetNumber, "\x02"); err != nil { - t.Fatal("AddAddress failed:", err) - } - - fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) - - buf := buffer.NewView(30) - - // Make sure packet with wrong address is not delivered. - buf[dstAddrOffset] = 3 - ep.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - if fakeNet.packetCount[1] != 0 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 0) - } - if fakeNet.packetCount[2] != 0 { - t.Errorf("packetCount[2] = %d, want %d", fakeNet.packetCount[2], 0) - } - - // Make sure packet is delivered to first endpoint. - buf[dstAddrOffset] = 1 - ep.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - if fakeNet.packetCount[1] != 1 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) - } - if fakeNet.packetCount[2] != 0 { - t.Errorf("packetCount[2] = %d, want %d", fakeNet.packetCount[2], 0) - } - - // Make sure packet is delivered to second endpoint. - buf[dstAddrOffset] = 2 - ep.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - if fakeNet.packetCount[1] != 1 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) - } - if fakeNet.packetCount[2] != 1 { - t.Errorf("packetCount[2] = %d, want %d", fakeNet.packetCount[2], 1) - } - - // Make sure packet is not delivered if protocol number is wrong. - ep.InjectInbound(fakeNetNumber-1, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - if fakeNet.packetCount[1] != 1 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) - } - if fakeNet.packetCount[2] != 1 { - t.Errorf("packetCount[2] = %d, want %d", fakeNet.packetCount[2], 1) - } - - // Make sure packet that is too small is dropped. - buf.CapLength(2) - ep.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - if fakeNet.packetCount[1] != 1 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) - } - if fakeNet.packetCount[2] != 1 { - t.Errorf("packetCount[2] = %d, want %d", fakeNet.packetCount[2], 1) - } -} - -func sendTo(s *stack.Stack, addr tcpip.Address, payload buffer.View) tcpip.Error { - r, err := s.FindRoute(0, "", addr, fakeNetNumber, false /* multicastLoop */) - if err != nil { - return err - } - defer r.Release() - return send(r, payload) -} - -func send(r *stack.Route, payload buffer.View) tcpip.Error { - return r.WritePacket(stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(r.MaxHeaderLength()), - Data: payload.ToVectorisedView(), - })) -} - -func testSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, ep *channel.Endpoint, payload buffer.View) { - t.Helper() - ep.Drain() - if err := sendTo(s, addr, payload); err != nil { - t.Error("sendTo failed:", err) - } - if got, want := ep.Drain(), 1; got != want { - t.Errorf("sendTo packet count: got = %d, want %d", got, want) - } -} - -func testSend(t *testing.T, r *stack.Route, ep *channel.Endpoint, payload buffer.View) { - t.Helper() - ep.Drain() - if err := send(r, payload); err != nil { - t.Error("send failed:", err) - } - if got, want := ep.Drain(), 1; got != want { - t.Errorf("send packet count: got = %d, want %d", got, want) - } -} - -func testFailingSend(t *testing.T, r *stack.Route, payload buffer.View, wantErr tcpip.Error) { - t.Helper() - if gotErr := send(r, payload); gotErr != wantErr { - t.Errorf("send failed: got = %s, want = %s ", gotErr, wantErr) - } -} - -func testFailingSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, payload buffer.View, wantErr tcpip.Error) { - t.Helper() - if gotErr := sendTo(s, addr, payload); gotErr != wantErr { - t.Errorf("sendto failed: got = %s, want = %s ", gotErr, wantErr) - } -} - -func testRecv(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, ep *channel.Endpoint, buf buffer.View) { - t.Helper() - // testRecvInternal injects one packet, and we expect to receive it. - want := fakeNet.PacketCount(localAddrByte) + 1 - testRecvInternal(t, fakeNet, localAddrByte, ep, buf, want) -} - -func testFailingRecv(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, ep *channel.Endpoint, buf buffer.View) { - t.Helper() - // testRecvInternal injects one packet, and we do NOT expect to receive it. - want := fakeNet.PacketCount(localAddrByte) - testRecvInternal(t, fakeNet, localAddrByte, ep, buf, want) -} - -func testRecvInternal(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, ep *channel.Endpoint, buf buffer.View, want int) { - t.Helper() - ep.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - if got := fakeNet.PacketCount(localAddrByte); got != want { - t.Errorf("receive packet count: got = %d, want %d", got, want) - } -} - -func TestNetworkSend(t *testing.T) { - // Create a stack with the fake network protocol, one nic, and one - // address: 1. The route table sends all packets through the only - // existing nic. - ep := channel.New(10, defaultMTU, "") - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - if err := s.CreateNIC(1, ep); err != nil { - t.Fatal("NewNIC failed:", err) - } - - { - subnet, err := tcpip.NewSubnet("\x00", "\x00") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) - } - - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatal("AddAddress failed:", err) - } - - // Make sure that the link-layer endpoint received the outbound packet. - testSendTo(t, s, "\x03", ep, nil) -} - -func TestNetworkSendMultiRoute(t *testing.T) { - // Create a stack with the fake network protocol, two nics, and two - // addresses per nic, the first nic has odd address, the second one has - // even addresses. - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - - ep1 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, ep1); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatal("AddAddress failed:", err) - } - - if err := s.AddAddress(1, fakeNetNumber, "\x03"); err != nil { - t.Fatal("AddAddress failed:", err) - } - - ep2 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(2, ep2); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil { - t.Fatal("AddAddress failed:", err) - } - - if err := s.AddAddress(2, fakeNetNumber, "\x04"); err != nil { - t.Fatal("AddAddress failed:", err) - } - - // Set a route table that sends all packets with odd destination - // addresses through the first NIC, and all even destination address - // through the second one. - { - subnet0, err := tcpip.NewSubnet("\x00", "\x01") - if err != nil { - t.Fatal(err) - } - subnet1, err := tcpip.NewSubnet("\x01", "\x01") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{ - {Destination: subnet1, Gateway: "\x00", NIC: 1}, - {Destination: subnet0, Gateway: "\x00", NIC: 2}, - }) - } - - // Send a packet to an odd destination. - testSendTo(t, s, "\x05", ep1, nil) - - // Send a packet to an even destination. - testSendTo(t, s, "\x06", ep2, nil) -} - -func testRoute(t *testing.T, s *stack.Stack, nic tcpip.NICID, srcAddr, dstAddr, expectedSrcAddr tcpip.Address) { - r, err := s.FindRoute(nic, srcAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) - if err != nil { - t.Fatal("FindRoute failed:", err) - } - - defer r.Release() - - if r.LocalAddress() != expectedSrcAddr { - t.Fatalf("got Route.LocalAddress() = %s, want = %s", expectedSrcAddr, r.LocalAddress()) - } - - if r.RemoteAddress() != dstAddr { - t.Fatalf("got Route.RemoteAddress() = %s, want = %s", dstAddr, r.RemoteAddress()) - } -} - -func testNoRoute(t *testing.T, s *stack.Stack, nic tcpip.NICID, srcAddr, dstAddr tcpip.Address) { - _, err := s.FindRoute(nic, srcAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) - if _, ok := err.(*tcpip.ErrNoRoute); !ok { - t.Fatalf("FindRoute returned unexpected error, got = %v, want = %s", err, &tcpip.ErrNoRoute{}) - } -} - -// TestAttachToLinkEndpointImmediately tests that a LinkEndpoint is attached to -// a NetworkDispatcher when the NIC is created. -func TestAttachToLinkEndpointImmediately(t *testing.T) { - const nicID = 1 - - tests := []struct { - name string - nicOpts stack.NICOptions - }{ - { - name: "Create enabled NIC", - nicOpts: stack.NICOptions{Disabled: false}, - }, - { - name: "Create disabled NIC", - nicOpts: stack.NICOptions{Disabled: true}, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - - e := linkEPWithMockedAttach{ - LinkEndpoint: loopback.New(), - } - - if err := s.CreateNICWithOptions(nicID, &e, test.nicOpts); err != nil { - t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, test.nicOpts, err) - } - if !e.isAttached() { - t.Fatal("link endpoint not attached to a network dispatcher") - } - }) - } -} - -func TestDisableUnknownNIC(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - - err := s.DisableNIC(1) - if _, ok := err.(*tcpip.ErrUnknownNICID); !ok { - t.Fatalf("got s.DisableNIC(1) = %v, want = %s", err, &tcpip.ErrUnknownNICID{}) - } -} - -func TestDisabledNICsNICInfoAndCheckNIC(t *testing.T) { - const nicID = 1 - - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - - e := loopback.New() - nicOpts := stack.NICOptions{Disabled: true} - if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { - t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, nicOpts, err) - } - - checkNIC := func(enabled bool) { - t.Helper() - - allNICInfo := s.NICInfo() - nicInfo, ok := allNICInfo[nicID] - if !ok { - t.Errorf("entry for %d missing from allNICInfo = %+v", nicID, allNICInfo) - } else if nicInfo.Flags.Running != enabled { - t.Errorf("got nicInfo.Flags.Running = %t, want = %t", nicInfo.Flags.Running, enabled) - } - - if got := s.CheckNIC(nicID); got != enabled { - t.Errorf("got s.CheckNIC(%d) = %t, want = %t", nicID, got, enabled) - } - } - - // NIC should initially report itself as disabled. - checkNIC(false) - - if err := s.EnableNIC(nicID); err != nil { - t.Fatalf("s.EnableNIC(%d): %s", nicID, err) - } - checkNIC(true) - - // If the NIC is not reporting a correct enabled status, we cannot trust the - // next check so end the test here. - if t.Failed() { - t.FailNow() - } - - if err := s.DisableNIC(nicID); err != nil { - t.Fatalf("s.DisableNIC(%d): %s", nicID, err) - } - checkNIC(false) -} - -func TestRemoveUnknownNIC(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - - err := s.RemoveNIC(1) - if _, ok := err.(*tcpip.ErrUnknownNICID); !ok { - t.Fatalf("got s.RemoveNIC(1) = %v, want = %s", err, &tcpip.ErrUnknownNICID{}) - } -} - -func TestRemoveNIC(t *testing.T) { - for _, tt := range []struct { - name string - linkep stack.LinkEndpoint - expectErr tcpip.Error - }{ - { - name: "loopback", - linkep: loopback.New(), - expectErr: &tcpip.ErrNotSupported{}, - }, - { - name: "channel", - linkep: channel.New(0, defaultMTU, ""), - expectErr: nil, - }, - } { - t.Run(tt.name, func(t *testing.T) { - const nicID = 1 - - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - - e := linkEPWithMockedAttach{ - LinkEndpoint: tt.linkep, - } - if err := s.CreateNIC(nicID, &e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - - // NIC should be present in NICInfo and attached to a NetworkDispatcher. - allNICInfo := s.NICInfo() - if _, ok := allNICInfo[nicID]; !ok { - t.Errorf("entry for %d missing from allNICInfo = %+v", nicID, allNICInfo) - } - if !e.isAttached() { - t.Fatal("link endpoint not attached to a network dispatcher") - } - - // Removing a NIC should remove it from NICInfo and e should be detached from - // the NetworkDispatcher. - if got, want := s.RemoveNIC(nicID), tt.expectErr; got != want { - t.Fatalf("got s.RemoveNIC(%d) = %s, want %s", nicID, got, want) - } - if tt.expectErr == nil { - if nicInfo, ok := s.NICInfo()[nicID]; ok { - t.Errorf("got unexpected NICInfo entry for deleted NIC %d = %+v", nicID, nicInfo) - } - if e.isAttached() { - t.Error("link endpoint for removed NIC still attached to a network dispatcher") - } - } - }) - } -} - -func TestRouteWithDownNIC(t *testing.T) { - tests := []struct { - name string - downFn func(s *stack.Stack, nicID tcpip.NICID) tcpip.Error - upFn func(s *stack.Stack, nicID tcpip.NICID) tcpip.Error - }{ - { - name: "Disabled NIC", - downFn: (*stack.Stack).DisableNIC, - upFn: (*stack.Stack).EnableNIC, - }, - - // Once a NIC is removed, it cannot be brought up. - { - name: "Removed NIC", - downFn: (*stack.Stack).RemoveNIC, - }, - } - - const unspecifiedNIC = 0 - const nicID1 = 1 - const nicID2 = 2 - const addr1 = tcpip.Address("\x01") - const addr2 = tcpip.Address("\x02") - const nic1Dst = tcpip.Address("\x05") - const nic2Dst = tcpip.Address("\x06") - - setup := func(t *testing.T) (*stack.Stack, *channel.Endpoint, *channel.Endpoint) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - - ep1 := channel.New(1, defaultMTU, "") - if err := s.CreateNIC(nicID1, ep1); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) - } - - if err := s.AddAddress(nicID1, fakeNetNumber, addr1); err != nil { - t.Fatalf("AddAddress(%d, %d, %s): %s", nicID1, fakeNetNumber, addr1, err) - } - - ep2 := channel.New(1, defaultMTU, "") - if err := s.CreateNIC(nicID2, ep2); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID2, err) - } - - if err := s.AddAddress(nicID2, fakeNetNumber, addr2); err != nil { - t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, fakeNetNumber, addr2, err) - } - - // Set a route table that sends all packets with odd destination - // addresses through the first NIC, and all even destination address - // through the second one. - { - subnet0, err := tcpip.NewSubnet("\x00", "\x01") - if err != nil { - t.Fatal(err) - } - subnet1, err := tcpip.NewSubnet("\x01", "\x01") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{ - {Destination: subnet1, Gateway: "\x00", NIC: nicID1}, - {Destination: subnet0, Gateway: "\x00", NIC: nicID2}, - }) - } - - return s, ep1, ep2 - } - - // Tests that routes through a down NIC are not used when looking up a route - // for a destination. - t.Run("Find", func(t *testing.T) { - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s, _, _ := setup(t) - - // Test routes to odd address. - testRoute(t, s, unspecifiedNIC, "", "\x05", addr1) - testRoute(t, s, unspecifiedNIC, addr1, "\x05", addr1) - testRoute(t, s, nicID1, addr1, "\x05", addr1) - - // Test routes to even address. - testRoute(t, s, unspecifiedNIC, "", "\x06", addr2) - testRoute(t, s, unspecifiedNIC, addr2, "\x06", addr2) - testRoute(t, s, nicID2, addr2, "\x06", addr2) - - // Bringing NIC1 down should result in no routes to odd addresses. Routes to - // even addresses should continue to be available as NIC2 is still up. - if err := test.downFn(s, nicID1); err != nil { - t.Fatalf("test.downFn(_, %d): %s", nicID1, err) - } - testNoRoute(t, s, unspecifiedNIC, "", nic1Dst) - testNoRoute(t, s, unspecifiedNIC, addr1, nic1Dst) - testNoRoute(t, s, nicID1, addr1, nic1Dst) - testRoute(t, s, unspecifiedNIC, "", nic2Dst, addr2) - testRoute(t, s, unspecifiedNIC, addr2, nic2Dst, addr2) - testRoute(t, s, nicID2, addr2, nic2Dst, addr2) - - // Bringing NIC2 down should result in no routes to even addresses. No - // route should be available to any address as routes to odd addresses - // were made unavailable by bringing NIC1 down above. - if err := test.downFn(s, nicID2); err != nil { - t.Fatalf("test.downFn(_, %d): %s", nicID2, err) - } - testNoRoute(t, s, unspecifiedNIC, "", nic1Dst) - testNoRoute(t, s, unspecifiedNIC, addr1, nic1Dst) - testNoRoute(t, s, nicID1, addr1, nic1Dst) - testNoRoute(t, s, unspecifiedNIC, "", nic2Dst) - testNoRoute(t, s, unspecifiedNIC, addr2, nic2Dst) - testNoRoute(t, s, nicID2, addr2, nic2Dst) - - if upFn := test.upFn; upFn != nil { - // Bringing NIC1 up should make routes to odd addresses available - // again. Routes to even addresses should continue to be unavailable - // as NIC2 is still down. - if err := upFn(s, nicID1); err != nil { - t.Fatalf("test.upFn(_, %d): %s", nicID1, err) - } - testRoute(t, s, unspecifiedNIC, "", nic1Dst, addr1) - testRoute(t, s, unspecifiedNIC, addr1, nic1Dst, addr1) - testRoute(t, s, nicID1, addr1, nic1Dst, addr1) - testNoRoute(t, s, unspecifiedNIC, "", nic2Dst) - testNoRoute(t, s, unspecifiedNIC, addr2, nic2Dst) - testNoRoute(t, s, nicID2, addr2, nic2Dst) - } - }) - } - }) - - // Tests that writing a packet using a Route through a down NIC fails. - t.Run("WritePacket", func(t *testing.T) { - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s, ep1, ep2 := setup(t) - - r1, err := s.FindRoute(nicID1, addr1, nic1Dst, fakeNetNumber, false /* multicastLoop */) - if err != nil { - t.Errorf("FindRoute(%d, %s, %s, %d, false): %s", nicID1, addr1, nic1Dst, fakeNetNumber, err) - } - defer r1.Release() - - r2, err := s.FindRoute(nicID2, addr2, nic2Dst, fakeNetNumber, false /* multicastLoop */) - if err != nil { - t.Errorf("FindRoute(%d, %s, %s, %d, false): %s", nicID2, addr2, nic2Dst, fakeNetNumber, err) - } - defer r2.Release() - - // If we failed to get routes r1 or r2, we cannot proceed with the test. - if t.Failed() { - t.FailNow() - } - - buf := buffer.View([]byte{1}) - testSend(t, r1, ep1, buf) - testSend(t, r2, ep2, buf) - - // Writes with Routes that use NIC1 after being brought down should fail. - if err := test.downFn(s, nicID1); err != nil { - t.Fatalf("test.downFn(_, %d): %s", nicID1, err) - } - testFailingSend(t, r1, buf, &tcpip.ErrInvalidEndpointState{}) - testSend(t, r2, ep2, buf) - - // Writes with Routes that use NIC2 after being brought down should fail. - if err := test.downFn(s, nicID2); err != nil { - t.Fatalf("test.downFn(_, %d): %s", nicID2, err) - } - testFailingSend(t, r1, buf, &tcpip.ErrInvalidEndpointState{}) - testFailingSend(t, r2, buf, &tcpip.ErrInvalidEndpointState{}) - - if upFn := test.upFn; upFn != nil { - // Writes with Routes that use NIC1 after being brought up should - // succeed. - // - // TODO(gvisor.dev/issue/1491): Should we instead completely - // invalidate all Routes that were bound to a NIC that was brought - // down at some point? - if err := upFn(s, nicID1); err != nil { - t.Fatalf("test.upFn(_, %d): %s", nicID1, err) - } - testSend(t, r1, ep1, buf) - testFailingSend(t, r2, buf, &tcpip.ErrInvalidEndpointState{}) - } - }) - } - }) -} - -func TestRoutes(t *testing.T) { - // Create a stack with the fake network protocol, two nics, and two - // addresses per nic, the first nic has odd address, the second one has - // even addresses. - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - - ep1 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, ep1); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatal("AddAddress failed:", err) - } - - if err := s.AddAddress(1, fakeNetNumber, "\x03"); err != nil { - t.Fatal("AddAddress failed:", err) - } - - ep2 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(2, ep2); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil { - t.Fatal("AddAddress failed:", err) - } - - if err := s.AddAddress(2, fakeNetNumber, "\x04"); err != nil { - t.Fatal("AddAddress failed:", err) - } - - // Set a route table that sends all packets with odd destination - // addresses through the first NIC, and all even destination address - // through the second one. - { - subnet0, err := tcpip.NewSubnet("\x00", "\x01") - if err != nil { - t.Fatal(err) - } - subnet1, err := tcpip.NewSubnet("\x01", "\x01") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{ - {Destination: subnet1, Gateway: "\x00", NIC: 1}, - {Destination: subnet0, Gateway: "\x00", NIC: 2}, - }) - } - - // Test routes to odd address. - testRoute(t, s, 0, "", "\x05", "\x01") - testRoute(t, s, 0, "\x01", "\x05", "\x01") - testRoute(t, s, 1, "\x01", "\x05", "\x01") - testRoute(t, s, 0, "\x03", "\x05", "\x03") - testRoute(t, s, 1, "\x03", "\x05", "\x03") - - // Test routes to even address. - testRoute(t, s, 0, "", "\x06", "\x02") - testRoute(t, s, 0, "\x02", "\x06", "\x02") - testRoute(t, s, 2, "\x02", "\x06", "\x02") - testRoute(t, s, 0, "\x04", "\x06", "\x04") - testRoute(t, s, 2, "\x04", "\x06", "\x04") - - // Try to send to odd numbered address from even numbered ones, then - // vice-versa. - testNoRoute(t, s, 0, "\x02", "\x05") - testNoRoute(t, s, 2, "\x02", "\x05") - testNoRoute(t, s, 0, "\x04", "\x05") - testNoRoute(t, s, 2, "\x04", "\x05") - - testNoRoute(t, s, 0, "\x01", "\x06") - testNoRoute(t, s, 1, "\x01", "\x06") - testNoRoute(t, s, 0, "\x03", "\x06") - testNoRoute(t, s, 1, "\x03", "\x06") -} - -func TestAddressRemoval(t *testing.T) { - const localAddrByte byte = 0x01 - localAddr := tcpip.Address([]byte{localAddrByte}) - remoteAddr := tcpip.Address("\x02") - - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - if err := s.AddAddress(1, fakeNetNumber, localAddr); err != nil { - t.Fatal("AddAddress failed:", err) - } - { - subnet, err := tcpip.NewSubnet("\x00", "\x00") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) - } - - fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) - - buf := buffer.NewView(30) - - // Send and receive packets, and verify they are received. - buf[dstAddrOffset] = localAddrByte - testRecv(t, fakeNet, localAddrByte, ep, buf) - testSendTo(t, s, remoteAddr, ep, nil) - - // Remove the address, then check that send/receive doesn't work anymore. - if err := s.RemoveAddress(1, localAddr); err != nil { - t.Fatal("RemoveAddress failed:", err) - } - testFailingRecv(t, fakeNet, localAddrByte, ep, buf) - testFailingSendTo(t, s, remoteAddr, nil, &tcpip.ErrNoRoute{}) - - // Check that removing the same address fails. - err := s.RemoveAddress(1, localAddr) - if _, ok := err.(*tcpip.ErrBadLocalAddress); !ok { - t.Fatalf("RemoveAddress returned unexpected error, got = %v, want = %s", err, &tcpip.ErrBadLocalAddress{}) - } -} - -func TestAddressRemovalWithRouteHeld(t *testing.T) { - const localAddrByte byte = 0x01 - localAddr := tcpip.Address([]byte{localAddrByte}) - remoteAddr := tcpip.Address("\x02") - - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, ep); err != nil { - t.Fatalf("CreateNIC failed: %v", err) - } - fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) - buf := buffer.NewView(30) - - if err := s.AddAddress(1, fakeNetNumber, localAddr); err != nil { - t.Fatal("AddAddress failed:", err) - } - { - subnet, err := tcpip.NewSubnet("\x00", "\x00") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) - } - - r, err := s.FindRoute(0, "", remoteAddr, fakeNetNumber, false /* multicastLoop */) - if err != nil { - t.Fatal("FindRoute failed:", err) - } - - // Send and receive packets, and verify they are received. - buf[dstAddrOffset] = localAddrByte - testRecv(t, fakeNet, localAddrByte, ep, buf) - testSend(t, r, ep, nil) - testSendTo(t, s, remoteAddr, ep, nil) - - // Remove the address, then check that send/receive doesn't work anymore. - if err := s.RemoveAddress(1, localAddr); err != nil { - t.Fatal("RemoveAddress failed:", err) - } - testFailingRecv(t, fakeNet, localAddrByte, ep, buf) - testFailingSend(t, r, nil, &tcpip.ErrInvalidEndpointState{}) - testFailingSendTo(t, s, remoteAddr, nil, &tcpip.ErrNoRoute{}) - - // Check that removing the same address fails. - { - err := s.RemoveAddress(1, localAddr) - if _, ok := err.(*tcpip.ErrBadLocalAddress); !ok { - t.Fatalf("RemoveAddress returned unexpected error, got = %v, want = %s", err, &tcpip.ErrBadLocalAddress{}) - } - } -} - -func verifyAddress(t *testing.T, s *stack.Stack, nicID tcpip.NICID, addr tcpip.Address) { - t.Helper() - info, ok := s.NICInfo()[nicID] - if !ok { - t.Fatalf("NICInfo() failed to find nicID=%d", nicID) - } - if len(addr) == 0 { - // No address given, verify that there is no address assigned to the NIC. - for _, a := range info.ProtocolAddresses { - if a.Protocol == fakeNetNumber && a.AddressWithPrefix != (tcpip.AddressWithPrefix{}) { - t.Errorf("verify no-address: got = %s, want = %s", a.AddressWithPrefix, tcpip.AddressWithPrefix{}) - } - } - return - } - // Address given, verify the address is assigned to the NIC and no other - // address is. - found := false - for _, a := range info.ProtocolAddresses { - if a.Protocol == fakeNetNumber { - if a.AddressWithPrefix.Address == addr { - found = true - } else { - t.Errorf("verify address: got = %s, want = %s", a.AddressWithPrefix.Address, addr) - } - } - } - if !found { - t.Errorf("verify address: couldn't find %s on the NIC", addr) - } -} - -func TestEndpointExpiration(t *testing.T) { - const ( - localAddrByte byte = 0x01 - remoteAddr tcpip.Address = "\x03" - noAddr tcpip.Address = "" - nicID tcpip.NICID = 1 - ) - localAddr := tcpip.Address([]byte{localAddrByte}) - - for _, promiscuous := range []bool{true, false} { - for _, spoofing := range []bool{true, false} { - t.Run(fmt.Sprintf("promiscuous=%t spoofing=%t", promiscuous, spoofing), func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicID, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - { - subnet, err := tcpip.NewSubnet("\x00", "\x00") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) - } - - fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) - buf := buffer.NewView(30) - buf[dstAddrOffset] = localAddrByte - - if promiscuous { - if err := s.SetPromiscuousMode(nicID, true); err != nil { - t.Fatal("SetPromiscuousMode failed:", err) - } - } - - if spoofing { - if err := s.SetSpoofing(nicID, true); err != nil { - t.Fatal("SetSpoofing failed:", err) - } - } - - // 1. No Address yet, send should only work for spoofing, receive for - // promiscuous mode. - //----------------------- - verifyAddress(t, s, nicID, noAddr) - if promiscuous { - testRecv(t, fakeNet, localAddrByte, ep, buf) - } else { - testFailingRecv(t, fakeNet, localAddrByte, ep, buf) - } - if spoofing { - // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. - // testSendTo(t, s, remoteAddr, ep, nil) - } else { - testFailingSendTo(t, s, remoteAddr, nil, &tcpip.ErrNoRoute{}) - } - - // 2. Add Address, everything should work. - //----------------------- - if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil { - t.Fatal("AddAddress failed:", err) - } - verifyAddress(t, s, nicID, localAddr) - testRecv(t, fakeNet, localAddrByte, ep, buf) - testSendTo(t, s, remoteAddr, ep, nil) - - // 3. Remove the address, send should only work for spoofing, receive - // for promiscuous mode. - //----------------------- - if err := s.RemoveAddress(nicID, localAddr); err != nil { - t.Fatal("RemoveAddress failed:", err) - } - verifyAddress(t, s, nicID, noAddr) - if promiscuous { - testRecv(t, fakeNet, localAddrByte, ep, buf) - } else { - testFailingRecv(t, fakeNet, localAddrByte, ep, buf) - } - if spoofing { - // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. - // testSendTo(t, s, remoteAddr, ep, nil) - } else { - testFailingSendTo(t, s, remoteAddr, nil, &tcpip.ErrNoRoute{}) - } - - // 4. Add Address back, everything should work again. - //----------------------- - if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil { - t.Fatal("AddAddress failed:", err) - } - verifyAddress(t, s, nicID, localAddr) - testRecv(t, fakeNet, localAddrByte, ep, buf) - testSendTo(t, s, remoteAddr, ep, nil) - - // 5. Take a reference to the endpoint by getting a route. Verify that - // we can still send/receive, including sending using the route. - //----------------------- - r, err := s.FindRoute(0, "", remoteAddr, fakeNetNumber, false /* multicastLoop */) - if err != nil { - t.Fatal("FindRoute failed:", err) - } - testRecv(t, fakeNet, localAddrByte, ep, buf) - testSendTo(t, s, remoteAddr, ep, nil) - testSend(t, r, ep, nil) - - // 6. Remove the address. Send should only work for spoofing, receive - // for promiscuous mode. - //----------------------- - if err := s.RemoveAddress(nicID, localAddr); err != nil { - t.Fatal("RemoveAddress failed:", err) - } - verifyAddress(t, s, nicID, noAddr) - if promiscuous { - testRecv(t, fakeNet, localAddrByte, ep, buf) - } else { - testFailingRecv(t, fakeNet, localAddrByte, ep, buf) - } - if spoofing { - testSend(t, r, ep, nil) - testSendTo(t, s, remoteAddr, ep, nil) - } else { - testFailingSend(t, r, nil, &tcpip.ErrInvalidEndpointState{}) - testFailingSendTo(t, s, remoteAddr, nil, &tcpip.ErrNoRoute{}) - } - - // 7. Add Address back, everything should work again. - //----------------------- - if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil { - t.Fatal("AddAddress failed:", err) - } - verifyAddress(t, s, nicID, localAddr) - testRecv(t, fakeNet, localAddrByte, ep, buf) - testSendTo(t, s, remoteAddr, ep, nil) - testSend(t, r, ep, nil) - - // 8. Remove the route, sendTo/recv should still work. - //----------------------- - r.Release() - verifyAddress(t, s, nicID, localAddr) - testRecv(t, fakeNet, localAddrByte, ep, buf) - testSendTo(t, s, remoteAddr, ep, nil) - - // 9. Remove the address. Send should only work for spoofing, receive - // for promiscuous mode. - //----------------------- - if err := s.RemoveAddress(nicID, localAddr); err != nil { - t.Fatal("RemoveAddress failed:", err) - } - verifyAddress(t, s, nicID, noAddr) - if promiscuous { - testRecv(t, fakeNet, localAddrByte, ep, buf) - } else { - testFailingRecv(t, fakeNet, localAddrByte, ep, buf) - } - if spoofing { - // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. - // testSendTo(t, s, remoteAddr, ep, nil) - } else { - testFailingSendTo(t, s, remoteAddr, nil, &tcpip.ErrNoRoute{}) - } - }) - } - } -} - -func TestPromiscuousMode(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - { - subnet, err := tcpip.NewSubnet("\x00", "\x00") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) - } - - fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) - - buf := buffer.NewView(30) - - // Write a packet, and check that it doesn't get delivered as we don't - // have a matching endpoint. - const localAddrByte byte = 0x01 - buf[dstAddrOffset] = localAddrByte - testFailingRecv(t, fakeNet, localAddrByte, ep, buf) - - // Set promiscuous mode, then check that packet is delivered. - if err := s.SetPromiscuousMode(1, true); err != nil { - t.Fatal("SetPromiscuousMode failed:", err) - } - testRecv(t, fakeNet, localAddrByte, ep, buf) - - // Check that we can't get a route as there is no local address. - _, err := s.FindRoute(0, "", "\x02", fakeNetNumber, false /* multicastLoop */) - if _, ok := err.(*tcpip.ErrNoRoute); !ok { - t.Fatalf("FindRoute returned unexpected error: got = %v, want = %s", err, &tcpip.ErrNoRoute{}) - } - - // Set promiscuous mode to false, then check that packet can't be - // delivered anymore. - if err := s.SetPromiscuousMode(1, false); err != nil { - t.Fatal("SetPromiscuousMode failed:", err) - } - testFailingRecv(t, fakeNet, localAddrByte, ep, buf) -} - -// TestExternalSendWithHandleLocal tests that the stack creates a non-local -// route when spoofing or promiscuous mode are enabled. -// -// This test makes sure that packets are transmitted from the stack. -func TestExternalSendWithHandleLocal(t *testing.T) { - const ( - unspecifiedNICID = 0 - nicID = 1 - - localAddr = tcpip.Address("\x01") - dstAddr = tcpip.Address("\x03") - ) - - subnet, err := tcpip.NewSubnet("\x00", "\x00") - if err != nil { - t.Fatal(err) - } - - tests := []struct { - name string - configureStack func(*testing.T, *stack.Stack) - }{ - { - name: "Default", - configureStack: func(*testing.T, *stack.Stack) {}, - }, - { - name: "Spoofing", - configureStack: func(t *testing.T, s *stack.Stack) { - if err := s.SetSpoofing(nicID, true); err != nil { - t.Fatalf("s.SetSpoofing(%d, true): %s", nicID, err) - } - }, - }, - { - name: "Promiscuous", - configureStack: func(t *testing.T, s *stack.Stack) { - if err := s.SetPromiscuousMode(nicID, true); err != nil { - t.Fatalf("s.SetPromiscuousMode(%d, true): %s", nicID, err) - } - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - for _, handleLocal := range []bool{true, false} { - t.Run(fmt.Sprintf("HandleLocal=%t", handleLocal), func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - HandleLocal: handleLocal, - }) - - ep := channel.New(1, defaultMTU, "") - if err := s.CreateNIC(nicID, ep); err != nil { - t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) - } - if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil { - t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, fakeNetNumber, localAddr, err) - } - - s.SetRouteTable([]tcpip.Route{{Destination: subnet, NIC: nicID}}) - - test.configureStack(t, s) - - r, err := s.FindRoute(unspecifiedNICID, localAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) - if err != nil { - t.Fatalf("s.FindRoute(%d, %s, %s, %d, false): %s", unspecifiedNICID, localAddr, dstAddr, fakeNetNumber, err) - } - defer r.Release() - - if r.LocalAddress() != localAddr { - t.Errorf("got r.LocalAddress() = %s, want = %s", r.LocalAddress(), localAddr) - } - if r.RemoteAddress() != dstAddr { - t.Errorf("got r.RemoteAddress() = %s, want = %s", r.RemoteAddress(), dstAddr) - } - - if n := ep.Drain(); n != 0 { - t.Fatalf("got ep.Drain() = %d, want = 0", n) - } - if err := r.WritePacket(stack.NetworkHeaderParams{ - Protocol: fakeTransNumber, - TTL: 123, - TOS: stack.DefaultTOS, - }, stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(r.MaxHeaderLength()), - Data: buffer.NewView(10).ToVectorisedView(), - })); err != nil { - t.Fatalf("r.WritePacket(nil, _, _): %s", err) - } - if n := ep.Drain(); n != 1 { - t.Fatalf("got ep.Drain() = %d, want = 1", n) - } - }) - } - }) - } -} - -func TestSpoofingWithAddress(t *testing.T) { - localAddr := tcpip.Address("\x01") - nonExistentLocalAddr := tcpip.Address("\x02") - dstAddr := tcpip.Address("\x03") - - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - if err := s.AddAddress(1, fakeNetNumber, localAddr); err != nil { - t.Fatal("AddAddress failed:", err) - } - - { - subnet, err := tcpip.NewSubnet("\x00", "\x00") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) - } - - // With address spoofing disabled, FindRoute does not permit an address - // that was not added to the NIC to be used as the source. - r, err := s.FindRoute(0, nonExistentLocalAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) - if err == nil { - t.Errorf("FindRoute succeeded with route %+v when it should have failed", r) - } - - // With address spoofing enabled, FindRoute permits any address to be used - // as the source. - if err := s.SetSpoofing(1, true); err != nil { - t.Fatal("SetSpoofing failed:", err) - } - r, err = s.FindRoute(0, nonExistentLocalAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) - if err != nil { - t.Fatal("FindRoute failed:", err) - } - if r.LocalAddress() != nonExistentLocalAddr { - t.Errorf("got Route.LocalAddress() = %s, want = %s", r.LocalAddress(), nonExistentLocalAddr) - } - if r.RemoteAddress() != dstAddr { - t.Errorf("got Route.RemoteAddress() = %s, want = %s", r.RemoteAddress(), dstAddr) - } - // Sending a packet works. - testSendTo(t, s, dstAddr, ep, nil) - testSend(t, r, ep, nil) - - // FindRoute should also work with a local address that exists on the NIC. - r, err = s.FindRoute(0, localAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) - if err != nil { - t.Fatal("FindRoute failed:", err) - } - if r.LocalAddress() != localAddr { - t.Errorf("got Route.LocalAddress() = %s, want = %s", r.LocalAddress(), nonExistentLocalAddr) - } - if r.RemoteAddress() != dstAddr { - t.Errorf("got Route.RemoteAddress() = %s, want = %s", r.RemoteAddress(), dstAddr) - } - // Sending a packet using the route works. - testSend(t, r, ep, nil) -} - -func TestSpoofingNoAddress(t *testing.T) { - nonExistentLocalAddr := tcpip.Address("\x01") - dstAddr := tcpip.Address("\x02") - - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - { - subnet, err := tcpip.NewSubnet("\x00", "\x00") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) - } - - // With address spoofing disabled, FindRoute does not permit an address - // that was not added to the NIC to be used as the source. - r, err := s.FindRoute(0, nonExistentLocalAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) - if err == nil { - t.Errorf("FindRoute succeeded with route %+v when it should have failed", r) - } - // Sending a packet fails. - testFailingSendTo(t, s, dstAddr, nil, &tcpip.ErrNoRoute{}) - - // With address spoofing enabled, FindRoute permits any address to be used - // as the source. - if err := s.SetSpoofing(1, true); err != nil { - t.Fatal("SetSpoofing failed:", err) - } - r, err = s.FindRoute(0, nonExistentLocalAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) - if err != nil { - t.Fatal("FindRoute failed:", err) - } - if r.LocalAddress() != nonExistentLocalAddr { - t.Errorf("got Route.LocalAddress() = %s, want = %s", r.LocalAddress(), nonExistentLocalAddr) - } - if r.RemoteAddress() != dstAddr { - t.Errorf("got Route.RemoteAddress() = %s, want = %s", r.RemoteAddress(), dstAddr) - } - // Sending a packet works. - // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. - // testSendTo(t, s, remoteAddr, ep, nil) -} - -func TestOutgoingBroadcastWithEmptyRouteTable(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - s.SetRouteTable([]tcpip.Route{}) - - // If there is no endpoint, it won't work. - { - _, err := s.FindRoute(1, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */) - if _, ok := err.(*tcpip.ErrNetworkUnreachable); !ok { - t.Fatalf("got FindRoute(1, %s, %s, %d) = %s, want = %s", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, &tcpip.ErrNetworkUnreachable{}) - } - } - - protoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: tcpip.AddressWithPrefix{Address: header.IPv4Any}} - if err := s.AddProtocolAddress(1, protoAddr); err != nil { - t.Fatalf("AddProtocolAddress(1, %v) failed: %v", protoAddr, err) - } - r, err := s.FindRoute(1, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */) - if err != nil { - t.Fatalf("FindRoute(1, %v, %v, %d) failed: %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err) - } - if r.LocalAddress() != header.IPv4Any { - t.Errorf("got Route.LocalAddress() = %s, want = %s", r.LocalAddress(), header.IPv4Any) - } - - if r.RemoteAddress() != header.IPv4Broadcast { - t.Errorf("got Route.RemoteAddress() = %s, want = %s", r.RemoteAddress(), header.IPv4Broadcast) - } - - // If the NIC doesn't exist, it won't work. - { - _, err := s.FindRoute(2, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */) - if _, ok := err.(*tcpip.ErrNetworkUnreachable); !ok { - t.Fatalf("got FindRoute(2, %v, %v, %d) = %v want = %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, &tcpip.ErrNetworkUnreachable{}) - } - } -} - -func TestOutgoingBroadcastWithRouteTable(t *testing.T) { - defaultAddr := tcpip.AddressWithPrefix{Address: header.IPv4Any} - // Local subnet on NIC1: 192.168.1.58/24, gateway 192.168.1.1. - nic1Addr := tcpip.AddressWithPrefix{Address: "\xc0\xa8\x01\x3a", PrefixLen: 24} - nic1Gateway := testutil.MustParse4("192.168.1.1") - // Local subnet on NIC2: 10.10.10.5/24, gateway 10.10.10.1. - nic2Addr := tcpip.AddressWithPrefix{Address: "\x0a\x0a\x0a\x05", PrefixLen: 24} - nic2Gateway := testutil.MustParse4("10.10.10.1") - - // Create a new stack with two NICs. - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, ep); err != nil { - t.Fatalf("CreateNIC failed: %s", err) - } - if err := s.CreateNIC(2, ep); err != nil { - t.Fatalf("CreateNIC failed: %s", err) - } - nic1ProtoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: nic1Addr} - if err := s.AddProtocolAddress(1, nic1ProtoAddr); err != nil { - t.Fatalf("AddProtocolAddress(1, %v) failed: %v", nic1ProtoAddr, err) - } - - nic2ProtoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: nic2Addr} - if err := s.AddProtocolAddress(2, nic2ProtoAddr); err != nil { - t.Fatalf("AddAddress(2, %v) failed: %v", nic2ProtoAddr, err) - } - - // Set the initial route table. - rt := []tcpip.Route{ - {Destination: nic1Addr.Subnet(), NIC: 1}, - {Destination: nic2Addr.Subnet(), NIC: 2}, - {Destination: defaultAddr.Subnet(), Gateway: nic2Gateway, NIC: 2}, - {Destination: defaultAddr.Subnet(), Gateway: nic1Gateway, NIC: 1}, - } - s.SetRouteTable(rt) - - // When an interface is given, the route for a broadcast goes through it. - r, err := s.FindRoute(1, nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */) - if err != nil { - t.Fatalf("FindRoute(1, %v, %v, %d) failed: %v", nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, err) - } - if r.LocalAddress() != nic1Addr.Address { - t.Errorf("got Route.LocalAddress() = %s, want = %s", r.LocalAddress(), nic1Addr.Address) - } - - if r.RemoteAddress() != header.IPv4Broadcast { - t.Errorf("got Route.RemoteAddress() = %s, want = %s", r.RemoteAddress(), header.IPv4Broadcast) - } - - // When an interface is not given, it consults the route table. - // 1. Case: Using the default route. - r, err = s.FindRoute(0, "", header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */) - if err != nil { - t.Fatalf("FindRoute(0, \"\", %s, %d) failed: %s", header.IPv4Broadcast, fakeNetNumber, err) - } - if r.LocalAddress() != nic2Addr.Address { - t.Errorf("got Route.LocalAddress() = %s, want = %s", r.LocalAddress(), nic2Addr.Address) - } - - if r.RemoteAddress() != header.IPv4Broadcast { - t.Errorf("got Route.RemoteAddress() = %s, want = %s", r.RemoteAddress(), header.IPv4Broadcast) - } - - // 2. Case: Having an explicit route for broadcast will select that one. - rt = append( - []tcpip.Route{ - {Destination: tcpip.AddressWithPrefix{Address: header.IPv4Broadcast, PrefixLen: 8 * header.IPv4AddressSize}.Subnet(), NIC: 1}, - }, - rt..., - ) - s.SetRouteTable(rt) - r, err = s.FindRoute(0, "", header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */) - if err != nil { - t.Fatalf("FindRoute(0, \"\", %s, %d) failed: %s", header.IPv4Broadcast, fakeNetNumber, err) - } - if r.LocalAddress() != nic1Addr.Address { - t.Errorf("got Route.LocalAddress() = %s, want = %s", r.LocalAddress(), nic1Addr.Address) - } - - if r.RemoteAddress() != header.IPv4Broadcast { - t.Errorf("got Route.RemoteAddress() = %s, want = %s", r.RemoteAddress(), header.IPv4Broadcast) - } -} - -func TestMulticastOrIPv6LinkLocalNeedsNoRoute(t *testing.T) { - for _, tc := range []struct { - name string - routeNeeded bool - address tcpip.Address - }{ - // IPv4 multicast address range: 224.0.0.0 - 239.255.255.255 - // <=> 0xe0.0x00.0x00.0x00 - 0xef.0xff.0xff.0xff - {"IPv4 Multicast 1", false, "\xe0\x00\x00\x00"}, - {"IPv4 Multicast 2", false, "\xef\xff\xff\xff"}, - {"IPv4 Unicast 1", true, "\xdf\xff\xff\xff"}, - {"IPv4 Unicast 2", true, "\xf0\x00\x00\x00"}, - {"IPv4 Unicast 3", true, "\x00\x00\x00\x00"}, - - // IPv6 multicast address is 0xff[8] + flags[4] + scope[4] + groupId[112] - {"IPv6 Multicast 1", false, "\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"}, - {"IPv6 Multicast 2", false, "\xff\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"}, - {"IPv6 Multicast 3", false, "\xff\x0f\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff"}, - - // IPv6 link-local address starts with fe80::/10. - {"IPv6 Unicast Link-Local 1", false, "\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"}, - {"IPv6 Unicast Link-Local 2", false, "\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"}, - {"IPv6 Unicast Link-Local 3", false, "\xfe\x80\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff"}, - {"IPv6 Unicast Link-Local 4", false, "\xfe\xbf\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"}, - {"IPv6 Unicast Link-Local 5", false, "\xfe\xbf\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff"}, - - // IPv6 addresses that are neither multicast nor link-local. - {"IPv6 Unicast Not Link-Local 1", true, "\xf0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"}, - {"IPv6 Unicast Not Link-Local 2", true, "\xf0\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff"}, - {"IPv6 Unicast Not Link-local 3", true, "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"}, - {"IPv6 Unicast Not Link-Local 4", true, "\xfe\xc0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"}, - {"IPv6 Unicast Not Link-Local 5", true, "\xfe\xdf\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"}, - {"IPv6 Unicast Not Link-Local 6", true, "\xfd\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"}, - {"IPv6 Unicast Not Link-Local 7", true, "\xf0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"}, - } { - t.Run(tc.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - s.SetRouteTable([]tcpip.Route{}) - - var anyAddr tcpip.Address - if len(tc.address) == header.IPv4AddressSize { - anyAddr = header.IPv4Any - } else { - anyAddr = header.IPv6Any - } - - var want tcpip.Error = &tcpip.ErrNetworkUnreachable{} - if tc.routeNeeded { - want = &tcpip.ErrNoRoute{} - } - - // If there is no endpoint, it won't work. - if _, err := s.FindRoute(1, anyAddr, tc.address, fakeNetNumber, false /* multicastLoop */); err != want { - t.Fatalf("got FindRoute(1, %v, %v, %v) = %v, want = %v", anyAddr, tc.address, fakeNetNumber, err, want) - } - - if err := s.AddAddress(1, fakeNetNumber, anyAddr); err != nil { - t.Fatalf("AddAddress(%v, %v) failed: %v", fakeNetNumber, anyAddr, err) - } - - if r, err := s.FindRoute(1, anyAddr, tc.address, fakeNetNumber, false /* multicastLoop */); tc.routeNeeded { - // Route table is empty but we need a route, this should cause an error. - if _, ok := err.(*tcpip.ErrNoRoute); !ok { - t.Fatalf("got FindRoute(1, %v, %v, %v) = %v, want = %v", anyAddr, tc.address, fakeNetNumber, err, &tcpip.ErrNoRoute{}) - } - } else { - if err != nil { - t.Fatalf("FindRoute(1, %v, %v, %v) failed: %v", anyAddr, tc.address, fakeNetNumber, err) - } - if r.LocalAddress() != anyAddr { - t.Errorf("Bad local address: got %v, want = %v", r.LocalAddress(), anyAddr) - } - if r.RemoteAddress() != tc.address { - t.Errorf("Bad remote address: got %v, want = %v", r.RemoteAddress(), tc.address) - } - } - // If the NIC doesn't exist, it won't work. - if _, err := s.FindRoute(2, anyAddr, tc.address, fakeNetNumber, false /* multicastLoop */); err != want { - t.Fatalf("got FindRoute(2, %v, %v, %v) = %v want = %v", anyAddr, tc.address, fakeNetNumber, err, want) - } - }) - } -} - -func TestNetworkOption(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - TransportProtocols: []stack.TransportProtocolFactory{}, - }) - - opt := tcpip.DefaultTTLOption(5) - if err := s.SetNetworkProtocolOption(fakeNetNumber, &opt); err != nil { - t.Fatalf("s.SetNetworkProtocolOption(%d, &%T(%d)): %s", fakeNetNumber, opt, opt, err) - } - - var optGot tcpip.DefaultTTLOption - if err := s.NetworkProtocolOption(fakeNetNumber, &optGot); err != nil { - t.Fatalf("s.NetworkProtocolOption(%d, &%T): %s", fakeNetNumber, optGot, err) - } - - if opt != optGot { - t.Errorf("got optGot = %d, want = %d", optGot, opt) - } -} - -func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) { - const nicID = 1 - - for _, addrLen := range []int{4, 16} { - t.Run(fmt.Sprintf("addrLen=%d", addrLen), func(t *testing.T) { - for canBe := 0; canBe < 3; canBe++ { - t.Run(fmt.Sprintf("canBe=%d", canBe), func(t *testing.T) { - for never := 0; never < 3; never++ { - t.Run(fmt.Sprintf("never=%d", never), func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicID, ep); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID, err) - } - // Insert <canBe> primary and <never> never-primary addresses. - // Each one will add a network endpoint to the NIC. - primaryAddrAdded := make(map[tcpip.AddressWithPrefix]struct{}) - for i := 0; i < canBe+never; i++ { - var behavior stack.PrimaryEndpointBehavior - if i < canBe { - behavior = stack.CanBePrimaryEndpoint - } else { - behavior = stack.NeverPrimaryEndpoint - } - // Add an address and in case of a primary one include a - // prefixLen. - address := tcpip.Address(bytes.Repeat([]byte{byte(i)}, addrLen)) - if behavior == stack.CanBePrimaryEndpoint { - protocolAddress := tcpip.ProtocolAddress{ - Protocol: fakeNetNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: address, - PrefixLen: addrLen * 8, - }, - } - if err := s.AddProtocolAddressWithOptions(nicID, protocolAddress, behavior); err != nil { - t.Fatalf("AddProtocolAddressWithOptions(%d, %#v, %d): %s", nicID, protocolAddress, behavior, err) - } - // Remember the address/prefix. - primaryAddrAdded[protocolAddress.AddressWithPrefix] = struct{}{} - } else { - if err := s.AddAddressWithOptions(nicID, fakeNetNumber, address, behavior); err != nil { - t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s:", nicID, fakeNetNumber, address, behavior, err) - } - } - } - // Check that GetMainNICAddress returns an address if at least - // one primary address was added. In that case make sure the - // address/prefixLen matches what we added. - gotAddr, err := s.GetMainNICAddress(nicID, fakeNetNumber) - if err != nil { - t.Fatalf("GetMainNICAddress(%d, %d): %s", nicID, fakeNetNumber, err) - } - if len(primaryAddrAdded) == 0 { - // No primary addresses present. - if wantAddr := (tcpip.AddressWithPrefix{}); gotAddr != wantAddr { - t.Fatalf("got GetMainNICAddress(%d, %d) = %s, want = %s", nicID, fakeNetNumber, gotAddr, wantAddr) - } - } else { - // At least one primary address was added, verify the returned - // address is in the list of primary addresses we added. - if _, ok := primaryAddrAdded[gotAddr]; !ok { - t.Fatalf("got GetMainNICAddress(%d, %d) = %s, want = %s", nicID, fakeNetNumber, gotAddr, primaryAddrAdded) - } - } - }) - } - }) - } - }) - } -} - -func TestGetMainNICAddressErrors(t *testing.T) { - const nicID = 1 - - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, arp.NewProtocol}, - }) - if err := s.CreateNIC(nicID, loopback.New()); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID, err) - } - - // Sanity check with a successful call. - if addr, err := s.GetMainNICAddress(nicID, ipv4.ProtocolNumber); err != nil { - t.Errorf("s.GetMainNICAddress(%d, %d): %s", nicID, ipv4.ProtocolNumber, err) - } else if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, ipv4.ProtocolNumber, addr, want) - } - - const unknownNICID = nicID + 1 - switch addr, err := s.GetMainNICAddress(unknownNICID, ipv4.ProtocolNumber); err.(type) { - case *tcpip.ErrUnknownNICID: - default: - t.Errorf("got s.GetMainNICAddress(%d, %d) = (%s, %T), want = (_, tcpip.ErrUnknownNICID)", unknownNICID, ipv4.ProtocolNumber, addr, err) - } - - // ARP is not an addressable network endpoint. - switch addr, err := s.GetMainNICAddress(nicID, arp.ProtocolNumber); err.(type) { - case *tcpip.ErrNotSupported: - default: - t.Errorf("got s.GetMainNICAddress(%d, %d) = (%s, %T), want = (_, tcpip.ErrNotSupported)", nicID, arp.ProtocolNumber, addr, err) - } - - const unknownProtocolNumber = 1234 - switch addr, err := s.GetMainNICAddress(nicID, unknownProtocolNumber); err.(type) { - case *tcpip.ErrUnknownProtocol: - default: - t.Errorf("got s.GetMainNICAddress(%d, %d) = (%s, %T), want = (_, tcpip.ErrUnknownProtocol)", nicID, unknownProtocolNumber, addr, err) - } -} - -func TestGetMainNICAddressAddRemove(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - for _, tc := range []struct { - name string - address tcpip.Address - prefixLen int - }{ - {"IPv4", "\x01\x01\x01\x01", 24}, - {"IPv6", "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", 116}, - } { - t.Run(tc.name, func(t *testing.T) { - protocolAddress := tcpip.ProtocolAddress{ - Protocol: fakeNetNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: tc.address, - PrefixLen: tc.prefixLen, - }, - } - if err := s.AddProtocolAddress(1, protocolAddress); err != nil { - t.Fatal("AddProtocolAddress failed:", err) - } - - // Check that we get the right initial address and prefix length. - if err := checkGetMainNICAddress(s, 1, fakeNetNumber, protocolAddress.AddressWithPrefix); err != nil { - t.Fatal(err) - } - - if err := s.RemoveAddress(1, protocolAddress.AddressWithPrefix.Address); err != nil { - t.Fatal("RemoveAddress failed:", err) - } - - // Check that we get no address after removal. - if err := checkGetMainNICAddress(s, 1, fakeNetNumber, tcpip.AddressWithPrefix{}); err != nil { - t.Fatal(err) - } - }) - } -} - -// Simple network address generator. Good for 255 addresses. -type addressGenerator struct{ cnt byte } - -func (g *addressGenerator) next(addrLen int) tcpip.Address { - g.cnt++ - return tcpip.Address(bytes.Repeat([]byte{g.cnt}, addrLen)) -} - -func verifyAddresses(t *testing.T, expectedAddresses, gotAddresses []tcpip.ProtocolAddress) { - t.Helper() - - if len(gotAddresses) != len(expectedAddresses) { - t.Fatalf("got len(addresses) = %d, want = %d", len(gotAddresses), len(expectedAddresses)) - } - - sort.Slice(gotAddresses, func(i, j int) bool { - return gotAddresses[i].AddressWithPrefix.Address < gotAddresses[j].AddressWithPrefix.Address - }) - sort.Slice(expectedAddresses, func(i, j int) bool { - return expectedAddresses[i].AddressWithPrefix.Address < expectedAddresses[j].AddressWithPrefix.Address - }) - - for i, gotAddr := range gotAddresses { - expectedAddr := expectedAddresses[i] - if gotAddr != expectedAddr { - t.Errorf("got address = %+v, wanted = %+v", gotAddr, expectedAddr) - } - } -} - -func TestAddAddress(t *testing.T) { - const nicID = 1 - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicID, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - var addrGen addressGenerator - expectedAddresses := make([]tcpip.ProtocolAddress, 0, 2) - for _, addrLen := range []int{4, 16} { - address := addrGen.next(addrLen) - if err := s.AddAddress(nicID, fakeNetNumber, address); err != nil { - t.Fatalf("AddAddress(address=%s) failed: %s", address, err) - } - expectedAddresses = append(expectedAddresses, tcpip.ProtocolAddress{ - Protocol: fakeNetNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{Address: address, PrefixLen: fakeDefaultPrefixLen}, - }) - } - - gotAddresses := s.AllAddresses()[nicID] - verifyAddresses(t, expectedAddresses, gotAddresses) -} - -func TestAddProtocolAddress(t *testing.T) { - const nicID = 1 - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicID, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - var addrGen addressGenerator - addrLenRange := []int{4, 16} - prefixLenRange := []int{8, 13, 20, 32} - expectedAddresses := make([]tcpip.ProtocolAddress, 0, len(addrLenRange)*len(prefixLenRange)) - for _, addrLen := range addrLenRange { - for _, prefixLen := range prefixLenRange { - protocolAddress := tcpip.ProtocolAddress{ - Protocol: fakeNetNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: addrGen.next(addrLen), - PrefixLen: prefixLen, - }, - } - if err := s.AddProtocolAddress(nicID, protocolAddress); err != nil { - t.Errorf("AddProtocolAddress(%+v) failed: %s", protocolAddress, err) - } - expectedAddresses = append(expectedAddresses, protocolAddress) - } - } - - gotAddresses := s.AllAddresses()[nicID] - verifyAddresses(t, expectedAddresses, gotAddresses) -} - -func TestAddAddressWithOptions(t *testing.T) { - const nicID = 1 - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicID, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - addrLenRange := []int{4, 16} - behaviorRange := []stack.PrimaryEndpointBehavior{stack.CanBePrimaryEndpoint, stack.FirstPrimaryEndpoint, stack.NeverPrimaryEndpoint} - expectedAddresses := make([]tcpip.ProtocolAddress, 0, len(addrLenRange)*len(behaviorRange)) - var addrGen addressGenerator - for _, addrLen := range addrLenRange { - for _, behavior := range behaviorRange { - address := addrGen.next(addrLen) - if err := s.AddAddressWithOptions(nicID, fakeNetNumber, address, behavior); err != nil { - t.Fatalf("AddAddressWithOptions(address=%s, behavior=%d) failed: %s", address, behavior, err) - } - expectedAddresses = append(expectedAddresses, tcpip.ProtocolAddress{ - Protocol: fakeNetNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{Address: address, PrefixLen: fakeDefaultPrefixLen}, - }) - } - } - - gotAddresses := s.AllAddresses()[nicID] - verifyAddresses(t, expectedAddresses, gotAddresses) -} - -func TestAddProtocolAddressWithOptions(t *testing.T) { - const nicID = 1 - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicID, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - addrLenRange := []int{4, 16} - prefixLenRange := []int{8, 13, 20, 32} - behaviorRange := []stack.PrimaryEndpointBehavior{stack.CanBePrimaryEndpoint, stack.FirstPrimaryEndpoint, stack.NeverPrimaryEndpoint} - expectedAddresses := make([]tcpip.ProtocolAddress, 0, len(addrLenRange)*len(prefixLenRange)*len(behaviorRange)) - var addrGen addressGenerator - for _, addrLen := range addrLenRange { - for _, prefixLen := range prefixLenRange { - for _, behavior := range behaviorRange { - protocolAddress := tcpip.ProtocolAddress{ - Protocol: fakeNetNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: addrGen.next(addrLen), - PrefixLen: prefixLen, - }, - } - if err := s.AddProtocolAddressWithOptions(nicID, protocolAddress, behavior); err != nil { - t.Fatalf("AddProtocolAddressWithOptions(%+v, %d) failed: %s", protocolAddress, behavior, err) - } - expectedAddresses = append(expectedAddresses, protocolAddress) - } - } - } - - gotAddresses := s.AllAddresses()[nicID] - verifyAddresses(t, expectedAddresses, gotAddresses) -} - -func TestCreateNICWithOptions(t *testing.T) { - type callArgsAndExpect struct { - nicID tcpip.NICID - opts stack.NICOptions - err tcpip.Error - } - - tests := []struct { - desc string - calls []callArgsAndExpect - }{ - { - desc: "DuplicateNICID", - calls: []callArgsAndExpect{ - { - nicID: tcpip.NICID(1), - opts: stack.NICOptions{Name: "eth1"}, - err: nil, - }, - { - nicID: tcpip.NICID(1), - opts: stack.NICOptions{Name: "eth2"}, - err: &tcpip.ErrDuplicateNICID{}, - }, - }, - }, - { - desc: "DuplicateName", - calls: []callArgsAndExpect{ - { - nicID: tcpip.NICID(1), - opts: stack.NICOptions{Name: "lo"}, - err: nil, - }, - { - nicID: tcpip.NICID(2), - opts: stack.NICOptions{Name: "lo"}, - err: &tcpip.ErrDuplicateNICID{}, - }, - }, - }, - { - desc: "Unnamed", - calls: []callArgsAndExpect{ - { - nicID: tcpip.NICID(1), - opts: stack.NICOptions{}, - err: nil, - }, - { - nicID: tcpip.NICID(2), - opts: stack.NICOptions{}, - err: nil, - }, - }, - }, - { - desc: "UnnamedDuplicateNICID", - calls: []callArgsAndExpect{ - { - nicID: tcpip.NICID(1), - opts: stack.NICOptions{}, - err: nil, - }, - { - nicID: tcpip.NICID(1), - opts: stack.NICOptions{}, - err: &tcpip.ErrDuplicateNICID{}, - }, - }, - }, - } - for _, test := range tests { - t.Run(test.desc, func(t *testing.T) { - s := stack.New(stack.Options{}) - ep := channel.New(0, 0, "\x00\x00\x00\x00\x00\x00") - for _, call := range test.calls { - if got, want := s.CreateNICWithOptions(call.nicID, ep, call.opts), call.err; got != want { - t.Fatalf("CreateNICWithOptions(%v, _, %+v) = %v, want %v", call.nicID, call.opts, got, want) - } - } - }) - } -} - -func TestNICStats(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - - nics := []struct { - addr tcpip.Address - txByteCount int - rxByteCount int - }{ - { - addr: "\x01", - txByteCount: 30, - rxByteCount: 10, - }, - { - addr: "\x02", - txByteCount: 50, - rxByteCount: 20, - }, - } - - var txBytesTotal, rxBytesTotal, txPacketsTotal, rxPacketsTotal int - for i, nic := range nics { - nicid := tcpip.NICID(i) - ep := channel.New(1, defaultMTU, "") - if err := s.CreateNIC(nicid, ep); err != nil { - t.Fatal("CreateNIC failed: ", err) - } - if err := s.AddAddress(nicid, fakeNetNumber, nic.addr); err != nil { - t.Fatal("AddAddress failed:", err) - } - - { - subnet, err := tcpip.NewSubnet(nic.addr, "\xff") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: nicid}}) - } - - nicStats := s.NICInfo()[nicid].Stats - - // Inbound packet. - rxBuffer := buffer.NewView(nic.rxByteCount) - ep.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: rxBuffer.ToVectorisedView(), - })) - if got, want := nicStats.Rx.Packets.Value(), uint64(1); got != want { - t.Errorf("got Rx.Packets.Value() = %d, want = %d", got, want) - } - if got, want := nicStats.Rx.Bytes.Value(), uint64(nic.rxByteCount); got != want { - t.Errorf("got Rx.Bytes.Value() = %d, want = %d", got, want) - } - rxPacketsTotal++ - rxBytesTotal += nic.rxByteCount - - // Outbound packet. - txBuffer := buffer.NewView(nic.txByteCount) - actualTxLength := nic.txByteCount + fakeNetHeaderLen - if err := sendTo(s, nic.addr, txBuffer); err != nil { - t.Fatal("sendTo failed: ", err) - } - want := ep.Drain() - if got := nicStats.Tx.Packets.Value(); got != uint64(want) { - t.Errorf("got Tx.Packets.Value() = %d, ep.Drain() = %d", got, want) - } - if got, want := nicStats.Tx.Bytes.Value(), uint64(actualTxLength); got != want { - t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want) - } - txPacketsTotal += want - txBytesTotal += actualTxLength - } - - // Now verify that each NIC stats was correctly aggregated at the stack level. - if got, want := s.Stats().NICs.Rx.Packets.Value(), uint64(rxPacketsTotal); got != want { - t.Errorf("got s.Stats().NIC.Rx.Packets.Value() = %d, want = %d", got, want) - } - if got, want := s.Stats().NICs.Rx.Bytes.Value(), uint64(rxBytesTotal); got != want { - t.Errorf("got s.Stats().Rx.Bytes.Value() = %d, want = %d", got, want) - } - if got, want := s.Stats().NICs.Tx.Packets.Value(), uint64(txPacketsTotal); got != want { - t.Errorf("got Tx.Packets.Value() = %d, ep.Drain() = %d", got, want) - } - if got, want := s.Stats().NICs.Tx.Bytes.Value(), uint64(txBytesTotal); got != want { - t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want) - } -} - -// TestNICContextPreservation tests that you can read out via stack.NICInfo the -// Context data you pass via NICContext.Context in stack.CreateNICWithOptions. -func TestNICContextPreservation(t *testing.T) { - var ctx *int - tests := []struct { - name string - opts stack.NICOptions - want stack.NICContext - }{ - { - "context_set", - stack.NICOptions{Context: ctx}, - ctx, - }, - { - "context_not_set", - stack.NICOptions{}, - nil, - }, - } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{}) - id := tcpip.NICID(1) - ep := channel.New(0, 0, "\x00\x00\x00\x00\x00\x00") - if err := s.CreateNICWithOptions(id, ep, test.opts); err != nil { - t.Fatalf("got stack.CreateNICWithOptions(%d, %+v, %+v) = %s, want nil", id, ep, test.opts, err) - } - nicinfos := s.NICInfo() - nicinfo, ok := nicinfos[id] - if !ok { - t.Fatalf("got nicinfos[%d] = _, %t, want _, true; nicinfos = %+v", id, ok, nicinfos) - } - if got, want := nicinfo.Context == test.want, true; got != want { - t.Fatalf("got nicinfo.Context == ctx = %t, want %t; nicinfo.Context = %p, ctx = %p", got, want, nicinfo.Context, test.want) - } - }) - } -} - -// TestNICAutoGenLinkLocalAddr tests the auto-generation of IPv6 link-local -// addresses. -func TestNICAutoGenLinkLocalAddr(t *testing.T) { - const nicID = 1 - - var secretKey [header.OpaqueIIDSecretKeyMinBytes]byte - n, err := rand.Read(secretKey[:]) - if err != nil { - t.Fatalf("rand.Read(_): %s", err) - } - if n != header.OpaqueIIDSecretKeyMinBytes { - t.Fatalf("expected rand.Read to read %d bytes, read %d bytes", header.OpaqueIIDSecretKeyMinBytes, n) - } - - nicNameFunc := func(_ tcpip.NICID, name string) string { - return name - } - - tests := []struct { - name string - nicName string - autoGen bool - linkAddr tcpip.LinkAddress - iidOpts ipv6.OpaqueInterfaceIdentifierOptions - shouldGen bool - expectedAddr tcpip.Address - }{ - { - name: "Disabled", - nicName: "nic1", - autoGen: false, - linkAddr: linkAddr1, - shouldGen: false, - }, - { - name: "Disabled without OIID options", - nicName: "nic1", - autoGen: false, - linkAddr: linkAddr1, - iidOpts: ipv6.OpaqueInterfaceIdentifierOptions{ - NICNameFromID: nicNameFunc, - SecretKey: secretKey[:], - }, - shouldGen: false, - }, - - // Tests for EUI64 based addresses. - { - name: "EUI64 Enabled", - autoGen: true, - linkAddr: linkAddr1, - shouldGen: true, - expectedAddr: header.LinkLocalAddr(linkAddr1), - }, - { - name: "EUI64 Empty MAC", - autoGen: true, - shouldGen: false, - }, - { - name: "EUI64 Invalid MAC", - autoGen: true, - linkAddr: "\x01\x02\x03", - shouldGen: false, - }, - { - name: "EUI64 Multicast MAC", - autoGen: true, - linkAddr: "\x01\x02\x03\x04\x05\x06", - shouldGen: false, - }, - { - name: "EUI64 Unspecified MAC", - autoGen: true, - linkAddr: "\x00\x00\x00\x00\x00\x00", - shouldGen: false, - }, - - // Tests for Opaque IID based addresses. - { - name: "OIID Enabled", - nicName: "nic1", - autoGen: true, - linkAddr: linkAddr1, - iidOpts: ipv6.OpaqueInterfaceIdentifierOptions{ - NICNameFromID: nicNameFunc, - SecretKey: secretKey[:], - }, - shouldGen: true, - expectedAddr: header.LinkLocalAddrWithOpaqueIID("nic1", 0, secretKey[:]), - }, - // These are all cases where we would not have generated a - // link-local address if opaque IIDs were disabled. - { - name: "OIID Empty MAC and empty nicName", - autoGen: true, - iidOpts: ipv6.OpaqueInterfaceIdentifierOptions{ - NICNameFromID: nicNameFunc, - SecretKey: secretKey[:1], - }, - shouldGen: true, - expectedAddr: header.LinkLocalAddrWithOpaqueIID("", 0, secretKey[:1]), - }, - { - name: "OIID Invalid MAC", - nicName: "test", - autoGen: true, - linkAddr: "\x01\x02\x03", - iidOpts: ipv6.OpaqueInterfaceIdentifierOptions{ - NICNameFromID: nicNameFunc, - SecretKey: secretKey[:2], - }, - shouldGen: true, - expectedAddr: header.LinkLocalAddrWithOpaqueIID("test", 0, secretKey[:2]), - }, - { - name: "OIID Multicast MAC", - nicName: "test2", - autoGen: true, - linkAddr: "\x01\x02\x03\x04\x05\x06", - iidOpts: ipv6.OpaqueInterfaceIdentifierOptions{ - NICNameFromID: nicNameFunc, - SecretKey: secretKey[:3], - }, - shouldGen: true, - expectedAddr: header.LinkLocalAddrWithOpaqueIID("test2", 0, secretKey[:3]), - }, - { - name: "OIID Unspecified MAC and nil SecretKey", - nicName: "test3", - autoGen: true, - linkAddr: "\x00\x00\x00\x00\x00\x00", - iidOpts: ipv6.OpaqueInterfaceIdentifierOptions{ - NICNameFromID: nicNameFunc, - }, - shouldGen: true, - expectedAddr: header.LinkLocalAddrWithOpaqueIID("test3", 0, nil), - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - ndpDisp := ndpDispatcher{ - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), - } - opts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - AutoGenLinkLocal: test.autoGen, - NDPDisp: &ndpDisp, - OpaqueIIDOpts: test.iidOpts, - })}, - } - - e := channel.New(0, 1280, test.linkAddr) - s := stack.New(opts) - nicOpts := stack.NICOptions{Name: test.nicName, Disabled: true} - if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { - t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, opts, err) - } - - // A new disabled NIC should not have any address, even if auto generation - // was enabled. - allStackAddrs := s.AllAddresses() - allNICAddrs, ok := allStackAddrs[nicID] - if !ok { - t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs) - } - if l := len(allNICAddrs); l != 0 { - t.Fatalf("got len(allNICAddrs) = %d, want = 0", l) - } - - // Enabling the NIC should attempt auto-generation of a link-local - // address. - if err := s.EnableNIC(nicID); err != nil { - t.Fatalf("s.EnableNIC(%d): %s", nicID, err) - } - - var expectedMainAddr tcpip.AddressWithPrefix - if test.shouldGen { - expectedMainAddr = tcpip.AddressWithPrefix{ - Address: test.expectedAddr, - PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen, - } - - // Should have auto-generated an address and resolved immediately (DAD - // is disabled). - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, expectedMainAddr, newAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected addr auto gen event") - } - } else { - // Should not have auto-generated an address. - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly auto-generated an address") - default: - } - } - - if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, expectedMainAddr); err != nil { - t.Fatal(err) - } - - // Disabling the NIC should remove the auto-generated address. - if err := s.DisableNIC(nicID); err != nil { - t.Fatalf("s.DisableNIC(%d): %s", nicID, err) - } - if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { - t.Fatal(err) - } - }) - } -} - -// TestNoLinkLocalAutoGenForLoopbackNIC tests that IPv6 link-local addresses are -// not auto-generated for loopback NICs. -func TestNoLinkLocalAutoGenForLoopbackNIC(t *testing.T) { - const nicID = 1 - const nicName = "nicName" - - tests := []struct { - name string - opaqueIIDOpts ipv6.OpaqueInterfaceIdentifierOptions - }{ - { - name: "IID From MAC", - opaqueIIDOpts: ipv6.OpaqueInterfaceIdentifierOptions{}, - }, - { - name: "Opaque IID", - opaqueIIDOpts: ipv6.OpaqueInterfaceIdentifierOptions{ - NICNameFromID: func(_ tcpip.NICID, nicName string) string { - return nicName - }, - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - opts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - AutoGenLinkLocal: true, - OpaqueIIDOpts: test.opaqueIIDOpts, - })}, - } - - e := loopback.New() - s := stack.New(opts) - nicOpts := stack.NICOptions{Name: nicName} - if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { - t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, nicOpts, err) - } - - if err := checkGetMainNICAddress(s, 1, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { - t.Fatal(err) - } - }) - } -} - -// TestNICAutoGenAddrDoesDAD tests that the successful auto-generation of IPv6 -// link-local addresses will only be assigned after the DAD process resolves. -func TestNICAutoGenAddrDoesDAD(t *testing.T) { - const nicID = 1 - - ndpDisp := ndpDispatcher{ - dadC: make(chan ndpDADEvent, 1), - } - dadConfigs := stack.DefaultDADConfigurations() - clock := faketime.NewManualClock() - opts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - AutoGenLinkLocal: true, - NDPDisp: &ndpDisp, - DADConfigs: dadConfigs, - })}, - Clock: clock, - } - - e := channel.New(int(dadConfigs.DupAddrDetectTransmits), 1280, linkAddr1) - s := stack.New(opts) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - - // Address should not be considered bound to the - // NIC yet (DAD ongoing). - if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { - t.Fatal(err) - } - - linkLocalAddr := header.LinkLocalAddr(linkAddr1) - - // Wait for DAD to resolve. - clock.Advance(time.Duration(dadConfigs.DupAddrDetectTransmits) * dadConfigs.RetransmitTimer) - select { - case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, linkLocalAddr, &stack.DADSucceeded{}); diff != "" { - t.Errorf("dad event mismatch (-want +got):\n%s", diff) - } - default: - // We should get a resolution event after 1s (default time to - // resolve as per default NDP configurations). Waiting for that - // resolution time + an extra 1s without a resolution event - // means something is wrong. - t.Fatal("timed out waiting for DAD resolution") - } - if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{Address: linkLocalAddr, PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen}); err != nil { - t.Fatal(err) - } -} - -// TestNewPEB tests that a new PrimaryEndpointBehavior value (peb) is respected -// when an address's kind gets "promoted" to permanent from permanentExpired. -func TestNewPEBOnPromotionToPermanent(t *testing.T) { - const nicID = 1 - - pebs := []stack.PrimaryEndpointBehavior{ - stack.NeverPrimaryEndpoint, - stack.CanBePrimaryEndpoint, - stack.FirstPrimaryEndpoint, - } - - for _, pi := range pebs { - for _, ps := range pebs { - t.Run(fmt.Sprintf("%d-to-%d", pi, ps), func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - ep1 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicID, ep1); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID, err) - } - - // Add a permanent address with initial - // PrimaryEndpointBehavior (peb), pi. If pi is - // NeverPrimaryEndpoint, the address should not - // be returned by a call to GetMainNICAddress; - // else, it should. - const address1 = tcpip.Address("\x01") - if err := s.AddAddressWithOptions(nicID, fakeNetNumber, address1, pi); err != nil { - t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, fakeNetNumber, address1, pi, err) - } - addr, err := s.GetMainNICAddress(nicID, fakeNetNumber) - if err != nil { - t.Fatalf("GetMainNICAddress(%d, %d): %s", nicID, fakeNetNumber, err) - } - if pi == stack.NeverPrimaryEndpoint { - if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got GetMainNICAddress(%d, %d) = %s, want = %s", nicID, fakeNetNumber, addr, want) - - } - } else if addr.Address != address1 { - t.Fatalf("got GetMainNICAddress(%d, %d) = %s, want = %s", nicID, fakeNetNumber, addr.Address, address1) - } - - { - subnet, err := tcpip.NewSubnet("\x00", "\x00") - if err != nil { - t.Fatalf("NewSubnet failed: %v", err) - } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) - } - - // Take a route through the address so its ref - // count gets incremented and does not actually - // get deleted when RemoveAddress is called - // below. This is because we want to test that a - // new peb is respected when an address gets - // "promoted" to permanent from a - // permanentExpired kind. - const address2 = tcpip.Address("\x02") - r, err := s.FindRoute(nicID, address1, address2, fakeNetNumber, false) - if err != nil { - t.Fatalf("FindRoute(%d, %s, %s, %d, false): %s", nicID, address1, address2, fakeNetNumber, err) - } - defer r.Release() - if err := s.RemoveAddress(nicID, address1); err != nil { - t.Fatalf("RemoveAddress(%d, %s): %s", nicID, address1, err) - } - - // - // At this point, the address should still be - // known by the NIC, but have its - // kind = permanentExpired. - // - - // Add some other address with peb set to - // FirstPrimaryEndpoint. - const address3 = tcpip.Address("\x03") - if err := s.AddAddressWithOptions(nicID, fakeNetNumber, address3, stack.FirstPrimaryEndpoint); err != nil { - t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, fakeNetNumber, address3, stack.FirstPrimaryEndpoint, err) - - } - - // Add back the address we removed earlier and - // make sure the new peb was respected. - // (The address should just be promoted now). - if err := s.AddAddressWithOptions(nicID, fakeNetNumber, address1, ps); err != nil { - t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, fakeNetNumber, address1, pi, err) - } - var primaryAddrs []tcpip.Address - for _, pa := range s.NICInfo()[nicID].ProtocolAddresses { - primaryAddrs = append(primaryAddrs, pa.AddressWithPrefix.Address) - } - var expectedList []tcpip.Address - switch ps { - case stack.FirstPrimaryEndpoint: - expectedList = []tcpip.Address{ - "\x01", - "\x03", - } - case stack.CanBePrimaryEndpoint: - expectedList = []tcpip.Address{ - "\x03", - "\x01", - } - case stack.NeverPrimaryEndpoint: - expectedList = []tcpip.Address{ - "\x03", - } - } - if !cmp.Equal(primaryAddrs, expectedList) { - t.Fatalf("got NIC's primary addresses = %v, want = %v", primaryAddrs, expectedList) - } - - // Once we remove the other address, if the new - // peb, ps, was NeverPrimaryEndpoint, no address - // should be returned by a call to - // GetMainNICAddress; else, our original address - // should be returned. - if err := s.RemoveAddress(nicID, address3); err != nil { - t.Fatalf("RemoveAddress(%d, %s): %s", nicID, address3, err) - } - addr, err = s.GetMainNICAddress(nicID, fakeNetNumber) - if err != nil { - t.Fatalf("GetMainNICAddress(%d, %d): %s", nicID, fakeNetNumber, err) - } - if ps == stack.NeverPrimaryEndpoint { - if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got GetMainNICAddress(%d, %d) = %s, want = %s", nicID, fakeNetNumber, addr, want) - } - } else { - if addr.Address != address1 { - t.Fatalf("got GetMainNICAddress(%d, %d) = %s, want = %s", nicID, fakeNetNumber, addr.Address, address1) - } - } - }) - } - } -} - -func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) { - const ( - nicID = 1 - lifetimeSeconds = 9999 - ) - - var ( - linkLocalAddr1 = testutil.MustParse6("fe80::1") - linkLocalAddr2 = testutil.MustParse6("fe80::2") - linkLocalMulticastAddr = testutil.MustParse6("ff02::1") - uniqueLocalAddr1 = testutil.MustParse6("fc00::1") - uniqueLocalAddr2 = testutil.MustParse6("fd00::2") - globalAddr1 = testutil.MustParse6("a000::1") - globalAddr2 = testutil.MustParse6("a000::2") - globalAddr3 = testutil.MustParse6("a000::3") - ipv4MappedIPv6Addr1 = testutil.MustParse6("::ffff:0.0.0.1") - ipv4MappedIPv6Addr2 = testutil.MustParse6("::ffff:0.0.0.2") - toredoAddr1 = testutil.MustParse6("2001::1") - toredoAddr2 = testutil.MustParse6("2001::2") - ipv6ToIPv4Addr1 = testutil.MustParse6("2002::1") - ipv6ToIPv4Addr2 = testutil.MustParse6("2002::2") - ) - - prefix1, _, stableGlobalAddr1 := prefixSubnetAddr(0, linkAddr1) - prefix2, _, stableGlobalAddr2 := prefixSubnetAddr(1, linkAddr1) - - var tempIIDHistory [header.IIDSize]byte - header.InitialTempIID(tempIIDHistory[:], nil, nicID) - tempGlobalAddr1 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], stableGlobalAddr1.Address).Address - tempGlobalAddr2 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], stableGlobalAddr2.Address).Address - - // Rule 3 is not tested here, and is instead tested by NDP's AutoGenAddr test. - tests := []struct { - name string - slaacPrefixForTempAddrBeforeNICAddrAdd tcpip.AddressWithPrefix - nicAddrs []tcpip.Address - slaacPrefixForTempAddrAfterNICAddrAdd tcpip.AddressWithPrefix - remoteAddr tcpip.Address - expectedLocalAddr tcpip.Address - }{ - // Test Rule 1 of RFC 6724 section 5 (prefer same address). - { - name: "Same Global most preferred (last address)", - nicAddrs: []tcpip.Address{linkLocalAddr1, globalAddr1}, - remoteAddr: globalAddr1, - expectedLocalAddr: globalAddr1, - }, - { - name: "Same Global most preferred (first address)", - nicAddrs: []tcpip.Address{globalAddr1, uniqueLocalAddr1}, - remoteAddr: globalAddr1, - expectedLocalAddr: globalAddr1, - }, - { - name: "Same Link Local most preferred (last address)", - nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1}, - remoteAddr: linkLocalAddr1, - expectedLocalAddr: linkLocalAddr1, - }, - { - name: "Same Link Local most preferred (first address)", - nicAddrs: []tcpip.Address{linkLocalAddr1, globalAddr1}, - remoteAddr: linkLocalAddr1, - expectedLocalAddr: linkLocalAddr1, - }, - { - name: "Same Unique Local most preferred (last address)", - nicAddrs: []tcpip.Address{uniqueLocalAddr1, globalAddr1}, - remoteAddr: uniqueLocalAddr1, - expectedLocalAddr: uniqueLocalAddr1, - }, - { - name: "Same Unique Local most preferred (first address)", - nicAddrs: []tcpip.Address{globalAddr1, uniqueLocalAddr1}, - remoteAddr: uniqueLocalAddr1, - expectedLocalAddr: uniqueLocalAddr1, - }, - - // Test Rule 2 of RFC 6724 section 5 (prefer appropriate scope). - { - name: "Global most preferred (last address)", - nicAddrs: []tcpip.Address{linkLocalAddr1, globalAddr1}, - remoteAddr: globalAddr2, - expectedLocalAddr: globalAddr1, - }, - { - name: "Global most preferred (first address)", - nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1}, - remoteAddr: globalAddr2, - expectedLocalAddr: globalAddr1, - }, - { - name: "Link Local most preferred (last address)", - nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1}, - remoteAddr: linkLocalAddr2, - expectedLocalAddr: linkLocalAddr1, - }, - { - name: "Link Local most preferred (first address)", - nicAddrs: []tcpip.Address{linkLocalAddr1, globalAddr1}, - remoteAddr: linkLocalAddr2, - expectedLocalAddr: linkLocalAddr1, - }, - { - name: "Link Local most preferred for link local multicast (last address)", - nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1}, - remoteAddr: linkLocalMulticastAddr, - expectedLocalAddr: linkLocalAddr1, - }, - { - name: "Link Local most preferred for link local multicast (first address)", - nicAddrs: []tcpip.Address{linkLocalAddr1, globalAddr1}, - remoteAddr: linkLocalMulticastAddr, - expectedLocalAddr: linkLocalAddr1, - }, - - // Test Rule 6 of 6724 section 5 (prefer matching label). - { - name: "Unique Local most preferred (last address)", - nicAddrs: []tcpip.Address{uniqueLocalAddr1, globalAddr1, ipv4MappedIPv6Addr1, toredoAddr1, ipv6ToIPv4Addr1}, - remoteAddr: uniqueLocalAddr2, - expectedLocalAddr: uniqueLocalAddr1, - }, - { - name: "Unique Local most preferred (first address)", - nicAddrs: []tcpip.Address{globalAddr1, ipv4MappedIPv6Addr1, toredoAddr1, ipv6ToIPv4Addr1, uniqueLocalAddr1}, - remoteAddr: uniqueLocalAddr2, - expectedLocalAddr: uniqueLocalAddr1, - }, - { - name: "Toredo most preferred (first address)", - nicAddrs: []tcpip.Address{toredoAddr1, uniqueLocalAddr1, globalAddr1, ipv4MappedIPv6Addr1, ipv6ToIPv4Addr1}, - remoteAddr: toredoAddr2, - expectedLocalAddr: toredoAddr1, - }, - { - name: "Toredo most preferred (last address)", - nicAddrs: []tcpip.Address{globalAddr1, ipv4MappedIPv6Addr1, ipv6ToIPv4Addr1, uniqueLocalAddr1, toredoAddr1}, - remoteAddr: toredoAddr2, - expectedLocalAddr: toredoAddr1, - }, - { - name: "6To4 most preferred (first address)", - nicAddrs: []tcpip.Address{ipv6ToIPv4Addr1, toredoAddr1, uniqueLocalAddr1, globalAddr1, ipv4MappedIPv6Addr1}, - remoteAddr: ipv6ToIPv4Addr2, - expectedLocalAddr: ipv6ToIPv4Addr1, - }, - { - name: "6To4 most preferred (last address)", - nicAddrs: []tcpip.Address{globalAddr1, ipv4MappedIPv6Addr1, uniqueLocalAddr1, toredoAddr1, ipv6ToIPv4Addr1}, - remoteAddr: ipv6ToIPv4Addr2, - expectedLocalAddr: ipv6ToIPv4Addr1, - }, - { - name: "IPv4 mapped IPv6 most preferred (first address)", - nicAddrs: []tcpip.Address{ipv4MappedIPv6Addr1, ipv6ToIPv4Addr1, toredoAddr1, uniqueLocalAddr1, globalAddr1}, - remoteAddr: ipv4MappedIPv6Addr2, - expectedLocalAddr: ipv4MappedIPv6Addr1, - }, - { - name: "IPv4 mapped IPv6 most preferred (last address)", - nicAddrs: []tcpip.Address{globalAddr1, ipv6ToIPv4Addr1, uniqueLocalAddr1, toredoAddr1, ipv4MappedIPv6Addr1}, - remoteAddr: ipv4MappedIPv6Addr2, - expectedLocalAddr: ipv4MappedIPv6Addr1, - }, - - // Test Rule 7 of RFC 6724 section 5 (prefer temporary addresses). - { - name: "Temp Global most preferred (last address)", - slaacPrefixForTempAddrBeforeNICAddrAdd: prefix1, - nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1}, - remoteAddr: globalAddr2, - expectedLocalAddr: tempGlobalAddr1, - }, - { - name: "Temp Global most preferred (first address)", - nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1}, - slaacPrefixForTempAddrAfterNICAddrAdd: prefix1, - remoteAddr: globalAddr2, - expectedLocalAddr: tempGlobalAddr1, - }, - - // Test Rule 8 of RFC 6724 section 5 (use longest matching prefix). - { - name: "Longest prefix matched most preferred (first address)", - nicAddrs: []tcpip.Address{globalAddr2, globalAddr1}, - remoteAddr: globalAddr3, - expectedLocalAddr: globalAddr2, - }, - { - name: "Longest prefix matched most preferred (last address)", - nicAddrs: []tcpip.Address{globalAddr1, globalAddr2}, - remoteAddr: globalAddr3, - expectedLocalAddr: globalAddr2, - }, - - // Test returning the endpoint that is closest to the front when - // candidate addresses are "equal" from the perspective of RFC 6724 - // section 5. - { - name: "Unique Local for Global", - nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, uniqueLocalAddr2}, - remoteAddr: globalAddr2, - expectedLocalAddr: uniqueLocalAddr1, - }, - { - name: "Link Local for Global", - nicAddrs: []tcpip.Address{linkLocalAddr1, linkLocalAddr2}, - remoteAddr: globalAddr2, - expectedLocalAddr: linkLocalAddr1, - }, - { - name: "Link Local for Unique Local", - nicAddrs: []tcpip.Address{linkLocalAddr1, linkLocalAddr2}, - remoteAddr: uniqueLocalAddr2, - expectedLocalAddr: linkLocalAddr1, - }, - { - name: "Temp Global for Global", - slaacPrefixForTempAddrBeforeNICAddrAdd: prefix1, - slaacPrefixForTempAddrAfterNICAddrAdd: prefix2, - remoteAddr: globalAddr1, - expectedLocalAddr: tempGlobalAddr2, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, - AutoGenGlobalAddresses: true, - AutoGenTempGlobalAddresses: true, - }, - NDPDisp: &ndpDispatcher{}, - })}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, - }) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - - if test.slaacPrefixForTempAddrBeforeNICAddrAdd != (tcpip.AddressWithPrefix{}) { - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, test.slaacPrefixForTempAddrBeforeNICAddrAdd, true, true, lifetimeSeconds, lifetimeSeconds)) - } - - for _, a := range test.nicAddrs { - if err := s.AddAddress(nicID, ipv6.ProtocolNumber, a); err != nil { - t.Errorf("s.AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, a, err) - } - } - - if test.slaacPrefixForTempAddrAfterNICAddrAdd != (tcpip.AddressWithPrefix{}) { - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, test.slaacPrefixForTempAddrAfterNICAddrAdd, true, true, lifetimeSeconds, lifetimeSeconds)) - } - - if t.Failed() { - t.FailNow() - } - - netEP, err := s.GetNetworkEndpoint(nicID, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err) - } - - addressableEndpoint, ok := netEP.(stack.AddressableEndpoint) - if !ok { - t.Fatal("network endpoint is not addressable") - } - - addressEP := addressableEndpoint.AcquireOutgoingPrimaryAddress(test.remoteAddr, false /* allowExpired */) - if addressEP == nil { - t.Fatal("expected a non-nil address endpoint") - } - defer addressEP.DecRef() - - if got := addressEP.AddressWithPrefix().Address; got != test.expectedLocalAddr { - t.Errorf("got local address = %s, want = %s", got, test.expectedLocalAddr) - } - }) - } -} - -func TestAddRemoveIPv4BroadcastAddressOnNICEnableDisable(t *testing.T) { - const nicID = 1 - broadcastAddr := tcpip.ProtocolAddress{ - Protocol: header.IPv4ProtocolNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: header.IPv4Broadcast, - PrefixLen: 32, - }, - } - - e := loopback.New() - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, - }) - nicOpts := stack.NICOptions{Disabled: true} - if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { - t.Fatalf("CreateNIC(%d, _, %+v) = %s", nicID, nicOpts, err) - } - - { - allStackAddrs := s.AllAddresses() - if allNICAddrs, ok := allStackAddrs[nicID]; !ok { - t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs) - } else if containsAddr(allNICAddrs, broadcastAddr) { - t.Fatalf("got allNICAddrs = %+v, don't want = %+v", allNICAddrs, broadcastAddr) - } - } - - // Enabling the NIC should add the IPv4 broadcast address. - if err := s.EnableNIC(nicID); err != nil { - t.Fatalf("s.EnableNIC(%d): %s", nicID, err) - } - - { - allStackAddrs := s.AllAddresses() - if allNICAddrs, ok := allStackAddrs[nicID]; !ok { - t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs) - } else if !containsAddr(allNICAddrs, broadcastAddr) { - t.Fatalf("got allNICAddrs = %+v, want = %+v", allNICAddrs, broadcastAddr) - } - } - - // Disabling the NIC should remove the IPv4 broadcast address. - if err := s.DisableNIC(nicID); err != nil { - t.Fatalf("s.DisableNIC(%d): %s", nicID, err) - } - - { - allStackAddrs := s.AllAddresses() - if allNICAddrs, ok := allStackAddrs[nicID]; !ok { - t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs) - } else if containsAddr(allNICAddrs, broadcastAddr) { - t.Fatalf("got allNICAddrs = %+v, don't want = %+v", allNICAddrs, broadcastAddr) - } - } -} - -// TestLeaveIPv6SolicitedNodeAddrBeforeAddrRemoval tests that removing an IPv6 -// address after leaving its solicited node multicast address does not result in -// an error. -func TestLeaveIPv6SolicitedNodeAddrBeforeAddrRemoval(t *testing.T) { - const nicID = 1 - - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, - }) - e := channel.New(10, 1280, linkAddr1) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID, err) - } - - if err := s.AddAddress(nicID, ipv6.ProtocolNumber, addr1); err != nil { - t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, addr1, err) - } - - // The NIC should have joined addr1's solicited node multicast address. - snmc := header.SolicitedNodeAddr(addr1) - in, err := s.IsInGroup(nicID, snmc) - if err != nil { - t.Fatalf("IsInGroup(%d, %s): %s", nicID, snmc, err) - } - if !in { - t.Fatalf("got IsInGroup(%d, %s) = false, want = true", nicID, snmc) - } - - if err := s.LeaveGroup(ipv6.ProtocolNumber, nicID, snmc); err != nil { - t.Fatalf("LeaveGroup(%d, %d, %s): %s", ipv6.ProtocolNumber, nicID, snmc, err) - } - in, err = s.IsInGroup(nicID, snmc) - if err != nil { - t.Fatalf("IsInGroup(%d, %s): %s", nicID, snmc, err) - } - if in { - t.Fatalf("got IsInGroup(%d, %s) = true, want = false", nicID, snmc) - } - - if err := s.RemoveAddress(nicID, addr1); err != nil { - t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr1, err) - } -} - -func TestJoinLeaveMulticastOnNICEnableDisable(t *testing.T) { - const nicID = 1 - - tests := []struct { - name string - proto tcpip.NetworkProtocolNumber - addr tcpip.Address - }{ - { - name: "IPv6 All-Nodes", - proto: header.IPv6ProtocolNumber, - addr: header.IPv6AllNodesMulticastAddress, - }, - { - name: "IPv4 All-Systems", - proto: header.IPv4ProtocolNumber, - addr: header.IPv4AllSystems, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - e := loopback.New() - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, - }) - nicOpts := stack.NICOptions{Disabled: true} - if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { - t.Fatalf("CreateNIC(%d, _, %+v) = %s", nicID, nicOpts, err) - } - - // Should not be in the multicast group yet because the NIC has not been - // enabled yet. - if isInGroup, err := s.IsInGroup(nicID, test.addr); err != nil { - t.Fatalf("IsInGroup(%d, %s): %s", nicID, test.addr, err) - } else if isInGroup { - t.Fatalf("got IsInGroup(%d, %s) = true, want = false", nicID, test.addr) - } - - // The all-nodes multicast group should be joined when the NIC is enabled. - if err := s.EnableNIC(nicID); err != nil { - t.Fatalf("s.EnableNIC(%d): %s", nicID, err) - } - - if isInGroup, err := s.IsInGroup(nicID, test.addr); err != nil { - t.Fatalf("IsInGroup(%d, %s): %s", nicID, test.addr, err) - } else if !isInGroup { - t.Fatalf("got IsInGroup(%d, %s) = false, want = true", nicID, test.addr) - } - - // The multicast group should be left when the NIC is disabled. - if err := s.DisableNIC(nicID); err != nil { - t.Fatalf("s.DisableNIC(%d): %s", nicID, err) - } - - if isInGroup, err := s.IsInGroup(nicID, test.addr); err != nil { - t.Fatalf("IsInGroup(%d, %s): %s", nicID, test.addr, err) - } else if isInGroup { - t.Fatalf("got IsInGroup(%d, %s) = true, want = false", nicID, test.addr) - } - - // The all-nodes multicast group should be joined when the NIC is enabled. - if err := s.EnableNIC(nicID); err != nil { - t.Fatalf("s.EnableNIC(%d): %s", nicID, err) - } - - if isInGroup, err := s.IsInGroup(nicID, test.addr); err != nil { - t.Fatalf("IsInGroup(%d, %s): %s", nicID, test.addr, err) - } else if !isInGroup { - t.Fatalf("got IsInGroup(%d, %s) = false, want = true", nicID, test.addr) - } - - // Leaving the group before disabling the NIC should not cause an error. - if err := s.LeaveGroup(test.proto, nicID, test.addr); err != nil { - t.Fatalf("s.LeaveGroup(%d, %d, %s): %s", test.proto, nicID, test.addr, err) - } - - if err := s.DisableNIC(nicID); err != nil { - t.Fatalf("s.DisableNIC(%d): %s", nicID, err) - } - - if isInGroup, err := s.IsInGroup(nicID, test.addr); err != nil { - t.Fatalf("IsInGroup(%d, %s): %s", nicID, test.addr, err) - } else if isInGroup { - t.Fatalf("got IsInGroup(%d, %s) = true, want = false", nicID, test.addr) - } - }) - } -} - -// TestDoDADWhenNICEnabled tests that IPv6 endpoints that were added while a NIC -// was disabled have DAD performed on them when the NIC is enabled. -func TestDoDADWhenNICEnabled(t *testing.T) { - const dadTransmits = 1 - const retransmitTimer = time.Second - const nicID = 1 - - ndpDisp := ndpDispatcher{ - dadC: make(chan ndpDADEvent, 1), - } - clock := faketime.NewManualClock() - opts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - DADConfigs: stack.DADConfigurations{ - DupAddrDetectTransmits: dadTransmits, - RetransmitTimer: retransmitTimer, - }, - NDPDisp: &ndpDisp, - })}, - Clock: clock, - } - - e := channel.New(dadTransmits, 1280, linkAddr1) - s := stack.New(opts) - nicOpts := stack.NICOptions{Disabled: true} - if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { - t.Fatalf("CreateNIC(%d, _, %+v) = %s", nicID, nicOpts, err) - } - - addr := tcpip.ProtocolAddress{ - Protocol: header.IPv6ProtocolNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: llAddr1, - PrefixLen: 128, - }, - } - if err := s.AddProtocolAddress(nicID, addr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, addr, err) - } - - // Address should be in the list of all addresses. - if addrs := s.AllAddresses()[nicID]; !containsV6Addr(addrs, addr.AddressWithPrefix) { - t.Fatalf("got s.AllAddresses()[%d] = %+v, want = %+v", nicID, addrs, addr) - } - - // Address should be tentative so it should not be a main address. - if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { - t.Fatal(err) - } - - // Enabling the NIC should start DAD for the address. - if err := s.EnableNIC(nicID); err != nil { - t.Fatalf("s.EnableNIC(%d): %s", nicID, err) - } - if addrs := s.AllAddresses()[nicID]; !containsV6Addr(addrs, addr.AddressWithPrefix) { - t.Fatalf("got s.AllAddresses()[%d] = %+v, want = %+v", nicID, addrs, addr) - } - - // Address should not be considered bound to the NIC yet (DAD ongoing). - if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { - t.Fatal(err) - } - - // Wait for DAD to resolve. - clock.Advance(dadTransmits * retransmitTimer) - select { - case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr.AddressWithPrefix.Address, &stack.DADSucceeded{}); diff != "" { - t.Errorf("dad event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("timed out waiting for DAD resolution") - } - if addrs := s.AllAddresses()[nicID]; !containsV6Addr(addrs, addr.AddressWithPrefix) { - t.Fatalf("got s.AllAddresses()[%d] = %+v, want = %+v", nicID, addrs, addr) - } - if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addr.AddressWithPrefix); err != nil { - t.Fatal(err) - } - - // Enabling the NIC again should be a no-op. - if err := s.EnableNIC(nicID); err != nil { - t.Fatalf("s.EnableNIC(%d): %s", nicID, err) - } - if addrs := s.AllAddresses()[nicID]; !containsV6Addr(addrs, addr.AddressWithPrefix) { - t.Fatalf("got s.AllAddresses()[%d] = %+v, want = %+v", nicID, addrs, addr) - } - if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addr.AddressWithPrefix); err != nil { - t.Fatal(err) - } -} - -func TestStackReceiveBufferSizeOption(t *testing.T) { - const sMin = stack.MinBufferSize - testCases := []struct { - name string - rs tcpip.ReceiveBufferSizeOption - err tcpip.Error - }{ - // Invalid configurations. - {"min_below_zero", tcpip.ReceiveBufferSizeOption{Min: -1, Default: sMin, Max: sMin}, &tcpip.ErrInvalidOptionValue{}}, - {"min_zero", tcpip.ReceiveBufferSizeOption{Min: 0, Default: sMin, Max: sMin}, &tcpip.ErrInvalidOptionValue{}}, - {"default_below_min", tcpip.ReceiveBufferSizeOption{Min: sMin, Default: sMin - 1, Max: sMin - 1}, &tcpip.ErrInvalidOptionValue{}}, - {"default_above_max", tcpip.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin}, &tcpip.ErrInvalidOptionValue{}}, - {"max_below_min", tcpip.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin - 1}, &tcpip.ErrInvalidOptionValue{}}, - - // Valid Configurations - {"in_ascending_order", tcpip.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 2}, nil}, - {"all_equal", tcpip.ReceiveBufferSizeOption{Min: sMin, Default: sMin, Max: sMin}, nil}, - {"min_default_equal", tcpip.ReceiveBufferSizeOption{Min: sMin, Default: sMin, Max: sMin + 1}, nil}, - {"default_max_equal", tcpip.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 1}, nil}, - } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - s := stack.New(stack.Options{}) - defer s.Close() - if err := s.SetOption(tc.rs); err != tc.err { - t.Fatalf("s.SetOption(%#v) = %v, want: %v", tc.rs, err, tc.err) - } - var rs tcpip.ReceiveBufferSizeOption - if tc.err == nil { - if err := s.Option(&rs); err != nil { - t.Fatalf("s.Option(%#v) = %v, want: nil", rs, err) - } - if got, want := rs, tc.rs; got != want { - t.Fatalf("s.Option(..) returned unexpected value got: %#v, want: %#v", got, want) - } - } - }) - } -} - -func TestStackSendBufferSizeOption(t *testing.T) { - const sMin = stack.MinBufferSize - testCases := []struct { - name string - ss tcpip.SendBufferSizeOption - err tcpip.Error - }{ - // Invalid configurations. - {"min_below_zero", tcpip.SendBufferSizeOption{Min: -1, Default: sMin, Max: sMin}, &tcpip.ErrInvalidOptionValue{}}, - {"min_zero", tcpip.SendBufferSizeOption{Min: 0, Default: sMin, Max: sMin}, &tcpip.ErrInvalidOptionValue{}}, - {"default_below_min", tcpip.SendBufferSizeOption{Min: 0, Default: sMin - 1, Max: sMin - 1}, &tcpip.ErrInvalidOptionValue{}}, - {"default_above_max", tcpip.SendBufferSizeOption{Min: 0, Default: sMin + 1, Max: sMin}, &tcpip.ErrInvalidOptionValue{}}, - {"max_below_min", tcpip.SendBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin - 1}, &tcpip.ErrInvalidOptionValue{}}, - - // Valid Configurations - {"in_ascending_order", tcpip.SendBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 2}, nil}, - {"all_equal", tcpip.SendBufferSizeOption{Min: sMin, Default: sMin, Max: sMin}, nil}, - {"min_default_equal", tcpip.SendBufferSizeOption{Min: sMin, Default: sMin, Max: sMin + 1}, nil}, - {"default_max_equal", tcpip.SendBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 1}, nil}, - } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - s := stack.New(stack.Options{}) - defer s.Close() - err := s.SetOption(tc.ss) - if diff := cmp.Diff(tc.err, err); diff != "" { - t.Fatalf("unexpected error from s.SetOption(%+v), (-want, +got):\n%s", tc.ss, diff) - } - if tc.err == nil { - var ss tcpip.SendBufferSizeOption - if err := s.Option(&ss); err != nil { - t.Fatalf("s.Option(%+v) = %v, want: nil", ss, err) - } - if got, want := ss, tc.ss; got != want { - t.Fatalf("s.Option(..) returned unexpected value got: %#v, want: %#v", got, want) - } - } - }) - } -} - -func TestOutgoingSubnetBroadcast(t *testing.T) { - const ( - unspecifiedNICID = 0 - nicID1 = 1 - ) - - defaultAddr := tcpip.AddressWithPrefix{ - Address: header.IPv4Any, - PrefixLen: 0, - } - defaultSubnet := defaultAddr.Subnet() - ipv4Addr := tcpip.AddressWithPrefix{ - Address: "\xc0\xa8\x01\x3a", - PrefixLen: 24, - } - ipv4Subnet := ipv4Addr.Subnet() - ipv4SubnetBcast := ipv4Subnet.Broadcast() - ipv4Gateway := testutil.MustParse4("192.168.1.1") - ipv4AddrPrefix31 := tcpip.AddressWithPrefix{ - Address: "\xc0\xa8\x01\x3a", - PrefixLen: 31, - } - ipv4Subnet31 := ipv4AddrPrefix31.Subnet() - ipv4Subnet31Bcast := ipv4Subnet31.Broadcast() - ipv4AddrPrefix32 := tcpip.AddressWithPrefix{ - Address: "\xc0\xa8\x01\x3a", - PrefixLen: 32, - } - ipv4Subnet32 := ipv4AddrPrefix32.Subnet() - ipv4Subnet32Bcast := ipv4Subnet32.Broadcast() - ipv6Addr := tcpip.AddressWithPrefix{ - Address: "\x20\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", - PrefixLen: 64, - } - ipv6Subnet := ipv6Addr.Subnet() - ipv6SubnetBcast := ipv6Subnet.Broadcast() - remNetAddr := tcpip.AddressWithPrefix{ - Address: "\x64\x0a\x7b\x18", - PrefixLen: 24, - } - remNetSubnet := remNetAddr.Subnet() - remNetSubnetBcast := remNetSubnet.Broadcast() - - tests := []struct { - name string - nicAddr tcpip.ProtocolAddress - routes []tcpip.Route - remoteAddr tcpip.Address - expectedLocalAddress tcpip.Address - expectedRemoteAddress tcpip.Address - expectedRemoteLinkAddress tcpip.LinkAddress - expectedNextHop tcpip.Address - expectedNetProto tcpip.NetworkProtocolNumber - expectedLoop stack.PacketLooping - }{ - // Broadcast to a locally attached subnet populates the broadcast MAC. - { - name: "IPv4 Broadcast to local subnet", - nicAddr: tcpip.ProtocolAddress{ - Protocol: header.IPv4ProtocolNumber, - AddressWithPrefix: ipv4Addr, - }, - routes: []tcpip.Route{ - { - Destination: ipv4Subnet, - NIC: nicID1, - }, - }, - remoteAddr: ipv4SubnetBcast, - expectedLocalAddress: ipv4Addr.Address, - expectedRemoteAddress: ipv4SubnetBcast, - expectedRemoteLinkAddress: header.EthernetBroadcastAddress, - expectedNetProto: header.IPv4ProtocolNumber, - expectedLoop: stack.PacketOut | stack.PacketLoop, - }, - // Broadcast to a locally attached /31 subnet does not populate the - // broadcast MAC. - { - name: "IPv4 Broadcast to local /31 subnet", - nicAddr: tcpip.ProtocolAddress{ - Protocol: header.IPv4ProtocolNumber, - AddressWithPrefix: ipv4AddrPrefix31, - }, - routes: []tcpip.Route{ - { - Destination: ipv4Subnet31, - NIC: nicID1, - }, - }, - remoteAddr: ipv4Subnet31Bcast, - expectedLocalAddress: ipv4AddrPrefix31.Address, - expectedRemoteAddress: ipv4Subnet31Bcast, - expectedNetProto: header.IPv4ProtocolNumber, - expectedLoop: stack.PacketOut, - }, - // Broadcast to a locally attached /32 subnet does not populate the - // broadcast MAC. - { - name: "IPv4 Broadcast to local /32 subnet", - nicAddr: tcpip.ProtocolAddress{ - Protocol: header.IPv4ProtocolNumber, - AddressWithPrefix: ipv4AddrPrefix32, - }, - routes: []tcpip.Route{ - { - Destination: ipv4Subnet32, - NIC: nicID1, - }, - }, - remoteAddr: ipv4Subnet32Bcast, - expectedLocalAddress: ipv4AddrPrefix32.Address, - expectedRemoteAddress: ipv4Subnet32Bcast, - expectedNetProto: header.IPv4ProtocolNumber, - expectedLoop: stack.PacketOut, - }, - // IPv6 has no notion of a broadcast. - { - name: "IPv6 'Broadcast' to local subnet", - nicAddr: tcpip.ProtocolAddress{ - Protocol: header.IPv6ProtocolNumber, - AddressWithPrefix: ipv6Addr, - }, - routes: []tcpip.Route{ - { - Destination: ipv6Subnet, - NIC: nicID1, - }, - }, - remoteAddr: ipv6SubnetBcast, - expectedLocalAddress: ipv6Addr.Address, - expectedRemoteAddress: ipv6SubnetBcast, - expectedNetProto: header.IPv6ProtocolNumber, - expectedLoop: stack.PacketOut, - }, - // Broadcast to a remote subnet in the route table is send to the next-hop - // gateway. - { - name: "IPv4 Broadcast to remote subnet", - nicAddr: tcpip.ProtocolAddress{ - Protocol: header.IPv4ProtocolNumber, - AddressWithPrefix: ipv4Addr, - }, - routes: []tcpip.Route{ - { - Destination: remNetSubnet, - Gateway: ipv4Gateway, - NIC: nicID1, - }, - }, - remoteAddr: remNetSubnetBcast, - expectedLocalAddress: ipv4Addr.Address, - expectedRemoteAddress: remNetSubnetBcast, - expectedNextHop: ipv4Gateway, - expectedNetProto: header.IPv4ProtocolNumber, - expectedLoop: stack.PacketOut, - }, - // Broadcast to an unknown subnet follows the default route. Note that this - // is essentially just routing an unknown destination IP, because w/o any - // subnet prefix information a subnet broadcast address is just a normal IP. - { - name: "IPv4 Broadcast to unknown subnet", - nicAddr: tcpip.ProtocolAddress{ - Protocol: header.IPv4ProtocolNumber, - AddressWithPrefix: ipv4Addr, - }, - routes: []tcpip.Route{ - { - Destination: defaultSubnet, - Gateway: ipv4Gateway, - NIC: nicID1, - }, - }, - remoteAddr: remNetSubnetBcast, - expectedLocalAddress: ipv4Addr.Address, - expectedRemoteAddress: remNetSubnetBcast, - expectedNextHop: ipv4Gateway, - expectedNetProto: header.IPv4ProtocolNumber, - expectedLoop: stack.PacketOut, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, - }) - ep := channel.New(0, defaultMTU, "") - ep.LinkEPCapabilities |= stack.CapabilityResolutionRequired - if err := s.CreateNIC(nicID1, ep); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) - } - if err := s.AddProtocolAddress(nicID1, test.nicAddr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID1, test.nicAddr, err) - } - - s.SetRouteTable(test.routes) - - var netProto tcpip.NetworkProtocolNumber - switch l := len(test.remoteAddr); l { - case header.IPv4AddressSize: - netProto = header.IPv4ProtocolNumber - case header.IPv6AddressSize: - netProto = header.IPv6ProtocolNumber - default: - t.Fatalf("got unexpected address length = %d bytes", l) - } - - r, err := s.FindRoute(unspecifiedNICID, "" /* localAddr */, test.remoteAddr, netProto, false /* multicastLoop */) - if err != nil { - t.Fatalf("FindRoute(%d, '', %s, %d): %s", unspecifiedNICID, test.remoteAddr, netProto, err) - } - if r.LocalAddress() != test.expectedLocalAddress { - t.Errorf("got r.LocalAddress() = %s, want = %s", r.LocalAddress(), test.expectedLocalAddress) - } - if r.RemoteAddress() != test.expectedRemoteAddress { - t.Errorf("got r.RemoteAddress = %s, want = %s", r.RemoteAddress(), test.expectedRemoteAddress) - } - if got := r.RemoteLinkAddress(); got != test.expectedRemoteLinkAddress { - t.Errorf("got r.RemoteLinkAddress() = %s, want = %s", got, test.expectedRemoteLinkAddress) - } - if r.NextHop() != test.expectedNextHop { - t.Errorf("got r.NextHop() = %s, want = %s", r.NextHop(), test.expectedNextHop) - } - if r.NetProto() != test.expectedNetProto { - t.Errorf("got r.NetProto() = %d, want = %d", r.NetProto(), test.expectedNetProto) - } - if r.Loop() != test.expectedLoop { - t.Errorf("got r.Loop() = %x, want = %x", r.Loop(), test.expectedLoop) - } - }) - } -} - -func TestResolveWith(t *testing.T) { - const ( - unspecifiedNICID = 0 - nicID = 1 - ) - - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, arp.NewProtocol}, - }) - ep := channel.New(0, defaultMTU, "") - ep.LinkEPCapabilities |= stack.CapabilityResolutionRequired - if err := s.CreateNIC(nicID, ep); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID, err) - } - addr := tcpip.ProtocolAddress{ - Protocol: header.IPv4ProtocolNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: tcpip.Address([]byte{192, 168, 1, 58}), - PrefixLen: 24, - }, - } - if err := s.AddProtocolAddress(nicID, addr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, addr, err) - } - - s.SetRouteTable([]tcpip.Route{{Destination: header.IPv4EmptySubnet, NIC: nicID}}) - - remoteAddr := tcpip.Address([]byte{192, 168, 1, 59}) - r, err := s.FindRoute(unspecifiedNICID, "" /* localAddr */, remoteAddr, header.IPv4ProtocolNumber, false /* multicastLoop */) - if err != nil { - t.Fatalf("FindRoute(%d, '', %s, %d): %s", unspecifiedNICID, remoteAddr, header.IPv4ProtocolNumber, err) - } - defer r.Release() - - // Should initially require resolution. - if !r.IsResolutionRequired() { - t.Fatal("got r.IsResolutionRequired() = false, want = true") - } - - // Manually resolving the route should no longer require resolution. - r.ResolveWith("\x01") - if r.IsResolutionRequired() { - t.Fatal("got r.IsResolutionRequired() = true, want = false") - } -} - -// TestRouteReleaseAfterAddrRemoval tests that releasing a Route after its -// associated address is removed should not cause a panic. -func TestRouteReleaseAfterAddrRemoval(t *testing.T) { - const ( - nicID = 1 - localAddr = tcpip.Address("\x01") - remoteAddr = tcpip.Address("\x02") - ) - - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - - ep := channel.New(0, defaultMTU, "") - if err := s.CreateNIC(nicID, ep); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID, err) - } - if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil { - t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, fakeNetNumber, localAddr, err) - } - { - subnet, err := tcpip.NewSubnet("\x00", "\x00") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) - } - - r, err := s.FindRoute(nicID, localAddr, remoteAddr, fakeNetNumber, false /* multicastLoop */) - if err != nil { - t.Fatalf("s.FindRoute(%d, %s, %s, %d, false): %s", nicID, localAddr, remoteAddr, fakeNetNumber, err) - } - // Should not panic. - defer r.Release() - - // Check that removing the same address fails. - if err := s.RemoveAddress(nicID, localAddr); err != nil { - t.Fatalf("s.RemoveAddress(%d, %s): %s", nicID, localAddr, err) - } -} - -func TestGetNetworkEndpoint(t *testing.T) { - const nicID = 1 - - tests := []struct { - name string - protoFactory stack.NetworkProtocolFactory - protoNum tcpip.NetworkProtocolNumber - }{ - { - name: "IPv4", - protoFactory: ipv4.NewProtocol, - protoNum: ipv4.ProtocolNumber, - }, - { - name: "IPv6", - protoFactory: ipv6.NewProtocol, - protoNum: ipv6.ProtocolNumber, - }, - } - - factories := make([]stack.NetworkProtocolFactory, 0, len(tests)) - for _, test := range tests { - factories = append(factories, test.protoFactory) - } - - s := stack.New(stack.Options{ - NetworkProtocols: factories, - }) - - if err := s.CreateNIC(nicID, channel.New(0, defaultMTU, "")); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID, err) - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - ep, err := s.GetNetworkEndpoint(nicID, test.protoNum) - if err != nil { - t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, test.protoNum, err) - } - - if got := ep.NetworkProtocolNumber(); got != test.protoNum { - t.Fatalf("got ep.NetworkProtocolNumber() = %d, want = %d", got, test.protoNum) - } - }) - } -} - -func TestGetMainNICAddressWhenNICDisabled(t *testing.T) { - const nicID = 1 - - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - - if err := s.CreateNIC(nicID, channel.New(0, defaultMTU, "")); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID, err) - } - - protocolAddress := tcpip.ProtocolAddress{ - Protocol: fakeNetNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: "\x01", - PrefixLen: 8, - }, - } - if err := s.AddProtocolAddress(nicID, protocolAddress); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, protocolAddress, err) - } - - // Check that we get the right initial address and prefix length. - if err := checkGetMainNICAddress(s, nicID, fakeNetNumber, protocolAddress.AddressWithPrefix); err != nil { - t.Fatal(err) - } - - // Should still get the address when the NIC is diabled. - if err := s.DisableNIC(nicID); err != nil { - t.Fatalf("DisableNIC(%d): %s", nicID, err) - } - if err := checkGetMainNICAddress(s, nicID, fakeNetNumber, protocolAddress.AddressWithPrefix); err != nil { - t.Fatal(err) - } -} - -// TestAddRoute tests Stack.AddRoute -func TestAddRoute(t *testing.T) { - s := stack.New(stack.Options{}) - - subnet1, err := tcpip.NewSubnet("\x00", "\x00") - if err != nil { - t.Fatal(err) - } - - subnet2, err := tcpip.NewSubnet("\x01", "\x01") - if err != nil { - t.Fatal(err) - } - - expected := []tcpip.Route{ - {Destination: subnet1, Gateway: "\x00", NIC: 1}, - {Destination: subnet2, Gateway: "\x00", NIC: 1}, - } - - // Initialize the route table with one route. - s.SetRouteTable([]tcpip.Route{expected[0]}) - - // Add another route. - s.AddRoute(expected[1]) - - rt := s.GetRouteTable() - if got, want := len(rt), len(expected); got != want { - t.Fatalf("Unexpected route table length got = %d, want = %d", got, want) - } - for i, route := range rt { - if got, want := route, expected[i]; got != want { - t.Fatalf("Unexpected route got = %#v, want = %#v", got, want) - } - } -} - -// TestRemoveRoutes tests Stack.RemoveRoutes -func TestRemoveRoutes(t *testing.T) { - s := stack.New(stack.Options{}) - - addressToRemove := tcpip.Address("\x01") - subnet1, err := tcpip.NewSubnet(addressToRemove, "\x01") - if err != nil { - t.Fatal(err) - } - - subnet2, err := tcpip.NewSubnet(addressToRemove, "\x01") - if err != nil { - t.Fatal(err) - } - - subnet3, err := tcpip.NewSubnet("\x02", "\x02") - if err != nil { - t.Fatal(err) - } - - // Initialize the route table with three routes. - s.SetRouteTable([]tcpip.Route{ - {Destination: subnet1, Gateway: "\x00", NIC: 1}, - {Destination: subnet2, Gateway: "\x00", NIC: 1}, - {Destination: subnet3, Gateway: "\x00", NIC: 1}, - }) - - // Remove routes with the specific address. - s.RemoveRoutes(func(r tcpip.Route) bool { - return r.Destination.ID() == addressToRemove - }) - - expected := []tcpip.Route{{Destination: subnet3, Gateway: "\x00", NIC: 1}} - rt := s.GetRouteTable() - if got, want := len(rt), len(expected); got != want { - t.Fatalf("Unexpected route table length got = %d, want = %d", got, want) - } - for i, route := range rt { - if got, want := route, expected[i]; got != want { - t.Fatalf("Unexpected route got = %#v, want = %#v", got, want) - } - } -} - -func TestFindRouteWithForwarding(t *testing.T) { - const ( - nicID1 = 1 - nicID2 = 2 - - nic1Addr = tcpip.Address("\x01") - nic2Addr = tcpip.Address("\x02") - remoteAddr = tcpip.Address("\x03") - ) - - type netCfg struct { - proto tcpip.NetworkProtocolNumber - factory stack.NetworkProtocolFactory - nic1Addr tcpip.Address - nic2Addr tcpip.Address - remoteAddr tcpip.Address - } - - fakeNetCfg := netCfg{ - proto: fakeNetNumber, - factory: fakeNetFactory, - nic1Addr: nic1Addr, - nic2Addr: nic2Addr, - remoteAddr: remoteAddr, - } - - globalIPv6Addr1 := tcpip.Address(net.ParseIP("a::1").To16()) - globalIPv6Addr2 := tcpip.Address(net.ParseIP("a::2").To16()) - - ipv6LinkLocalNIC1WithGlobalRemote := netCfg{ - proto: ipv6.ProtocolNumber, - factory: ipv6.NewProtocol, - nic1Addr: llAddr1, - nic2Addr: globalIPv6Addr2, - remoteAddr: globalIPv6Addr1, - } - ipv6GlobalNIC1WithLinkLocalRemote := netCfg{ - proto: ipv6.ProtocolNumber, - factory: ipv6.NewProtocol, - nic1Addr: globalIPv6Addr1, - nic2Addr: llAddr1, - remoteAddr: llAddr2, - } - ipv6GlobalNIC1WithLinkLocalMulticastRemote := netCfg{ - proto: ipv6.ProtocolNumber, - factory: ipv6.NewProtocol, - nic1Addr: globalIPv6Addr1, - nic2Addr: globalIPv6Addr2, - remoteAddr: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", - } - - tests := []struct { - name string - - netCfg netCfg - forwardingEnabled bool - - addrNIC tcpip.NICID - localAddr tcpip.Address - - findRouteErr tcpip.Error - dependentOnForwarding bool - }{ - { - name: "forwarding disabled and localAddr not on specified NIC but route from different NIC", - netCfg: fakeNetCfg, - forwardingEnabled: false, - addrNIC: nicID1, - localAddr: fakeNetCfg.nic2Addr, - findRouteErr: &tcpip.ErrNoRoute{}, - dependentOnForwarding: false, - }, - { - name: "forwarding enabled and localAddr not on specified NIC but route from different NIC", - netCfg: fakeNetCfg, - forwardingEnabled: true, - addrNIC: nicID1, - localAddr: fakeNetCfg.nic2Addr, - findRouteErr: &tcpip.ErrNoRoute{}, - dependentOnForwarding: false, - }, - { - name: "forwarding disabled and localAddr on specified NIC but route from different NIC", - netCfg: fakeNetCfg, - forwardingEnabled: false, - addrNIC: nicID1, - localAddr: fakeNetCfg.nic1Addr, - findRouteErr: &tcpip.ErrNoRoute{}, - dependentOnForwarding: false, - }, - { - name: "forwarding enabled and localAddr on specified NIC but route from different NIC", - netCfg: fakeNetCfg, - forwardingEnabled: true, - addrNIC: nicID1, - localAddr: fakeNetCfg.nic1Addr, - findRouteErr: nil, - dependentOnForwarding: true, - }, - { - name: "forwarding disabled and localAddr on specified NIC and route from same NIC", - netCfg: fakeNetCfg, - forwardingEnabled: false, - addrNIC: nicID2, - localAddr: fakeNetCfg.nic2Addr, - findRouteErr: nil, - dependentOnForwarding: false, - }, - { - name: "forwarding enabled and localAddr on specified NIC and route from same NIC", - netCfg: fakeNetCfg, - forwardingEnabled: true, - addrNIC: nicID2, - localAddr: fakeNetCfg.nic2Addr, - findRouteErr: nil, - dependentOnForwarding: false, - }, - { - name: "forwarding disabled and localAddr not on specified NIC but route from same NIC", - netCfg: fakeNetCfg, - forwardingEnabled: false, - addrNIC: nicID2, - localAddr: fakeNetCfg.nic1Addr, - findRouteErr: &tcpip.ErrNoRoute{}, - dependentOnForwarding: false, - }, - { - name: "forwarding enabled and localAddr not on specified NIC but route from same NIC", - netCfg: fakeNetCfg, - forwardingEnabled: true, - addrNIC: nicID2, - localAddr: fakeNetCfg.nic1Addr, - findRouteErr: &tcpip.ErrNoRoute{}, - dependentOnForwarding: false, - }, - { - name: "forwarding disabled and localAddr on same NIC as route", - netCfg: fakeNetCfg, - forwardingEnabled: false, - localAddr: fakeNetCfg.nic2Addr, - findRouteErr: nil, - dependentOnForwarding: false, - }, - { - name: "forwarding enabled and localAddr on same NIC as route", - netCfg: fakeNetCfg, - forwardingEnabled: false, - localAddr: fakeNetCfg.nic2Addr, - findRouteErr: nil, - dependentOnForwarding: false, - }, - { - name: "forwarding disabled and localAddr on different NIC as route", - netCfg: fakeNetCfg, - forwardingEnabled: false, - localAddr: fakeNetCfg.nic1Addr, - findRouteErr: &tcpip.ErrNoRoute{}, - dependentOnForwarding: false, - }, - { - name: "forwarding enabled and localAddr on different NIC as route", - netCfg: fakeNetCfg, - forwardingEnabled: true, - localAddr: fakeNetCfg.nic1Addr, - findRouteErr: nil, - dependentOnForwarding: true, - }, - { - name: "forwarding disabled and specified NIC only has link-local addr with route on different NIC", - netCfg: ipv6LinkLocalNIC1WithGlobalRemote, - forwardingEnabled: false, - addrNIC: nicID1, - findRouteErr: &tcpip.ErrNoRoute{}, - dependentOnForwarding: false, - }, - { - name: "forwarding enabled and specified NIC only has link-local addr with route on different NIC", - netCfg: ipv6LinkLocalNIC1WithGlobalRemote, - forwardingEnabled: true, - addrNIC: nicID1, - findRouteErr: &tcpip.ErrNoRoute{}, - dependentOnForwarding: false, - }, - { - name: "forwarding disabled and link-local local addr with route on different NIC", - netCfg: ipv6LinkLocalNIC1WithGlobalRemote, - forwardingEnabled: false, - localAddr: ipv6LinkLocalNIC1WithGlobalRemote.nic1Addr, - findRouteErr: &tcpip.ErrNoRoute{}, - dependentOnForwarding: false, - }, - { - name: "forwarding enabled and link-local local addr with route on same NIC", - netCfg: ipv6LinkLocalNIC1WithGlobalRemote, - forwardingEnabled: true, - localAddr: ipv6LinkLocalNIC1WithGlobalRemote.nic1Addr, - findRouteErr: &tcpip.ErrNoRoute{}, - dependentOnForwarding: false, - }, - { - name: "forwarding disabled and global local addr with route on same NIC", - netCfg: ipv6LinkLocalNIC1WithGlobalRemote, - forwardingEnabled: true, - localAddr: ipv6LinkLocalNIC1WithGlobalRemote.nic2Addr, - findRouteErr: nil, - dependentOnForwarding: false, - }, - { - name: "forwarding disabled and link-local local addr with route on same NIC", - netCfg: ipv6GlobalNIC1WithLinkLocalRemote, - forwardingEnabled: false, - localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic2Addr, - findRouteErr: nil, - dependentOnForwarding: false, - }, - { - name: "forwarding enabled and link-local local addr with route on same NIC", - netCfg: ipv6GlobalNIC1WithLinkLocalRemote, - forwardingEnabled: true, - localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic2Addr, - findRouteErr: nil, - dependentOnForwarding: false, - }, - { - name: "forwarding disabled and global local addr with link-local remote on different NIC", - netCfg: ipv6GlobalNIC1WithLinkLocalRemote, - forwardingEnabled: false, - localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic1Addr, - findRouteErr: &tcpip.ErrNetworkUnreachable{}, - dependentOnForwarding: false, - }, - { - name: "forwarding enabled and global local addr with link-local remote on different NIC", - netCfg: ipv6GlobalNIC1WithLinkLocalRemote, - forwardingEnabled: true, - localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic1Addr, - findRouteErr: &tcpip.ErrNetworkUnreachable{}, - dependentOnForwarding: false, - }, - { - name: "forwarding disabled and global local addr with link-local multicast remote on different NIC", - netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote, - forwardingEnabled: false, - localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic1Addr, - findRouteErr: &tcpip.ErrNetworkUnreachable{}, - dependentOnForwarding: false, - }, - { - name: "forwarding enabled and global local addr with link-local multicast remote on different NIC", - netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote, - forwardingEnabled: true, - localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic1Addr, - findRouteErr: &tcpip.ErrNetworkUnreachable{}, - dependentOnForwarding: false, - }, - { - name: "forwarding disabled and global local addr with link-local multicast remote on same NIC", - netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote, - forwardingEnabled: false, - localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic2Addr, - findRouteErr: nil, - dependentOnForwarding: false, - }, - { - name: "forwarding enabled and global local addr with link-local multicast remote on same NIC", - netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote, - forwardingEnabled: true, - localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic2Addr, - findRouteErr: nil, - dependentOnForwarding: false, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{test.netCfg.factory}, - }) - - ep1 := channel.New(1, defaultMTU, "") - if err := s.CreateNIC(nicID1, ep1); err != nil { - t.Fatalf("CreateNIC(%d, _): %s:", nicID1, err) - } - - ep2 := channel.New(1, defaultMTU, "") - if err := s.CreateNIC(nicID2, ep2); err != nil { - t.Fatalf("CreateNIC(%d, _): %s:", nicID2, err) - } - - if err := s.AddAddress(nicID1, test.netCfg.proto, test.netCfg.nic1Addr); err != nil { - t.Fatalf("AddAddress(%d, %d, %s): %s", nicID1, test.netCfg.proto, test.netCfg.nic1Addr, err) - } - - if err := s.AddAddress(nicID2, test.netCfg.proto, test.netCfg.nic2Addr); err != nil { - t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, test.netCfg.proto, test.netCfg.nic2Addr, err) - } - - if err := s.SetForwardingDefaultAndAllNICs(test.netCfg.proto, test.forwardingEnabled); err != nil { - t.Fatalf("SetForwardingDefaultAndAllNICs(%d, %t): %s", test.netCfg.proto, test.forwardingEnabled, err) - } - - s.SetRouteTable([]tcpip.Route{{Destination: test.netCfg.remoteAddr.WithPrefix().Subnet(), NIC: nicID2}}) - - r, err := s.FindRoute(test.addrNIC, test.localAddr, test.netCfg.remoteAddr, test.netCfg.proto, false /* multicastLoop */) - if err == nil { - defer r.Release() - } - if diff := cmp.Diff(test.findRouteErr, err); diff != "" { - t.Fatalf("unexpected error from FindRoute(%d, %s, %s, %d, false), (-want, +got):\n%s", test.addrNIC, test.localAddr, test.netCfg.remoteAddr, test.netCfg.proto, diff) - } - - if test.findRouteErr != nil { - return - } - - if r.LocalAddress() != test.localAddr { - t.Errorf("got r.LocalAddress() = %s, want = %s", r.LocalAddress(), test.localAddr) - } - if r.RemoteAddress() != test.netCfg.remoteAddr { - t.Errorf("got r.RemoteAddress() = %s, want = %s", r.RemoteAddress(), test.netCfg.remoteAddr) - } - - if t.Failed() { - t.FailNow() - } - - // Sending a packet should always go through NIC2 since we only install a - // route to test.netCfg.remoteAddr through NIC2. - data := buffer.View([]byte{1, 2, 3, 4}) - if err := send(r, data); err != nil { - t.Fatalf("send(_, _): %s", err) - } - if n := ep1.Drain(); n != 0 { - t.Errorf("got %d unexpected packets from ep1", n) - } - pkt, ok := ep2.Read() - if !ok { - t.Fatal("packet not sent through ep2") - } - if pkt.Route.LocalAddress != test.localAddr { - t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", pkt.Route.LocalAddress, test.localAddr) - } - if pkt.Route.RemoteAddress != test.netCfg.remoteAddr { - t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, test.netCfg.remoteAddr) - } - - if !test.forwardingEnabled || !test.dependentOnForwarding { - return - } - - // Disabling forwarding when the route is dependent on forwarding being - // enabled should make the route invalid. - if err := s.SetForwardingDefaultAndAllNICs(test.netCfg.proto, false); err != nil { - t.Fatalf("SetForwardingDefaultAndAllNICs(%d, false): %s", test.netCfg.proto, err) - } - { - err := send(r, data) - if _, ok := err.(*tcpip.ErrInvalidEndpointState); !ok { - t.Fatalf("got send(_, _) = %s, want = %s", err, &tcpip.ErrInvalidEndpointState{}) - } - } - if n := ep1.Drain(); n != 0 { - t.Errorf("got %d unexpected packets from ep1", n) - } - if n := ep2.Drain(); n != 0 { - t.Errorf("got %d unexpected packets from ep2", n) - } - }) - } -} - -func TestWritePacketToRemote(t *testing.T) { - const nicID = 1 - const MTU = 1280 - e := channel.New(1, MTU, linkAddr1) - s := stack.New(stack.Options{}) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - if err := s.EnableNIC(nicID); err != nil { - t.Fatalf("CreateNIC(%d) = %s", nicID, err) - } - tests := []struct { - name string - protocol tcpip.NetworkProtocolNumber - payload []byte - }{ - { - name: "SuccessIPv4", - protocol: header.IPv4ProtocolNumber, - payload: []byte{1, 2, 3, 4}, - }, - { - name: "SuccessIPv6", - protocol: header.IPv6ProtocolNumber, - payload: []byte{5, 6, 7, 8}, - }, - } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - if err := s.WritePacketToRemote(nicID, linkAddr2, test.protocol, buffer.View(test.payload).ToVectorisedView()); err != nil { - t.Fatalf("s.WritePacketToRemote(_, _, _, _) = %s", err) - } - - pkt, ok := e.Read() - if got, want := ok, true; got != want { - t.Fatalf("e.Read() = %t, want %t", got, want) - } - if got, want := pkt.Proto, test.protocol; got != want { - t.Fatalf("pkt.Proto = %d, want %d", got, want) - } - if pkt.Route.RemoteLinkAddress != linkAddr2 { - t.Fatalf("pkt.Route.RemoteAddress = %s, want %s", pkt.Route.RemoteLinkAddress, linkAddr2) - } - if diff := cmp.Diff(pkt.Pkt.Data().AsRange().ToOwnedView(), buffer.View(test.payload)); diff != "" { - t.Errorf("pkt.Pkt.Data mismatch (-want +got):\n%s", diff) - } - }) - } - - t.Run("InvalidNICID", func(t *testing.T) { - err := s.WritePacketToRemote(234, linkAddr2, header.IPv4ProtocolNumber, buffer.View([]byte{1}).ToVectorisedView()) - if _, ok := err.(*tcpip.ErrUnknownDevice); !ok { - t.Fatalf("s.WritePacketToRemote(_, _, _, _) = %s, want = %s", err, &tcpip.ErrUnknownDevice{}) - } - pkt, ok := e.Read() - if got, want := ok, false; got != want { - t.Fatalf("e.Read() = %t, %v; want %t", got, pkt, want) - } - }) -} - -func TestClearNeighborCacheOnNICDisable(t *testing.T) { - const ( - nicID = 1 - linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") - ) - - var ( - ipv4Addr = testutil.MustParse4("1.2.3.4") - ipv6Addr = testutil.MustParse6("102:304:102:304:102:304:102:304") - ) - - clock := faketime.NewManualClock() - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, - Clock: clock, - }) - e := channel.New(0, 0, "") - e.LinkEPCapabilities |= stack.CapabilityResolutionRequired - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - - addrs := []struct { - proto tcpip.NetworkProtocolNumber - addr tcpip.Address - }{ - { - proto: ipv4.ProtocolNumber, - addr: ipv4Addr, - }, - { - proto: ipv6.ProtocolNumber, - addr: ipv6Addr, - }, - } - for _, addr := range addrs { - if err := s.AddStaticNeighbor(nicID, addr.proto, addr.addr, linkAddr); err != nil { - t.Fatalf("s.AddStaticNeighbor(%d, %d, %s, %s): %s", nicID, addr.proto, addr.addr, linkAddr, err) - } - - if neighbors, err := s.Neighbors(nicID, addr.proto); err != nil { - t.Fatalf("s.Neighbors(%d, %d): %s", nicID, addr.proto, err) - } else if diff := cmp.Diff( - []stack.NeighborEntry{{Addr: addr.addr, LinkAddr: linkAddr, State: stack.Static, UpdatedAt: clock.Now()}}, - neighbors, - ); diff != "" { - t.Fatalf("proto=%d neighbors mismatch (-want +got):\n%s", addr.proto, diff) - } - } - - // Disabling the NIC should clear the neighbor table. - if err := s.DisableNIC(nicID); err != nil { - t.Fatalf("s.DisableNIC(%d): %s", nicID, err) - } - for _, addr := range addrs { - if neighbors, err := s.Neighbors(nicID, addr.proto); err != nil { - t.Fatalf("s.Neighbors(%d, %d): %s", nicID, addr.proto, err) - } else if len(neighbors) != 0 { - t.Fatalf("got proto=%d len(neighbors) = %d, want = 0; neighbors = %#v", addr.proto, len(neighbors), neighbors) - } - } - - // Enabling the NIC should have an empty neighbor table. - if err := s.EnableNIC(nicID); err != nil { - t.Fatalf("s.EnableNIC(%d): %s", nicID, err) - } - for _, addr := range addrs { - if neighbors, err := s.Neighbors(nicID, addr.proto); err != nil { - t.Fatalf("s.Neighbors(%d, %d): %s", nicID, addr.proto, err) - } else if len(neighbors) != 0 { - t.Fatalf("got proto=%d len(neighbors) = %d, want = 0; neighbors = %#v", addr.proto, len(neighbors), neighbors) - } - } -} - -func TestGetLinkAddressErrors(t *testing.T) { - const ( - nicID = 1 - unknownNICID = nicID + 1 - ) - - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, - }) - if err := s.CreateNIC(nicID, channel.New(0, 0, "")); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - - { - err := s.GetLinkAddress(unknownNICID, "", "", ipv4.ProtocolNumber, nil) - if _, ok := err.(*tcpip.ErrUnknownNICID); !ok { - t.Errorf("got s.GetLinkAddress(%d, '', '', %d, nil) = %s, want = %s", unknownNICID, ipv4.ProtocolNumber, err, &tcpip.ErrUnknownNICID{}) - } - } - { - err := s.GetLinkAddress(nicID, "", "", ipv4.ProtocolNumber, nil) - if _, ok := err.(*tcpip.ErrNotSupported); !ok { - t.Errorf("got s.GetLinkAddress(%d, '', '', %d, nil) = %s, want = %s", unknownNICID, ipv4.ProtocolNumber, err, &tcpip.ErrNotSupported{}) - } - } -} - -func TestStaticGetLinkAddress(t *testing.T) { - const ( - nicID = 1 - ) - - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, - }) - e := channel.New(0, 0, "") - e.LinkEPCapabilities |= stack.CapabilityResolutionRequired - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - - tests := []struct { - name string - proto tcpip.NetworkProtocolNumber - addr tcpip.Address - expectedLinkAddr tcpip.LinkAddress - }{ - { - name: "IPv4", - proto: ipv4.ProtocolNumber, - addr: header.IPv4Broadcast, - expectedLinkAddr: header.EthernetBroadcastAddress, - }, - { - name: "IPv6", - proto: ipv6.ProtocolNumber, - addr: header.IPv6AllNodesMulticastAddress, - expectedLinkAddr: header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllNodesMulticastAddress), - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - ch := make(chan stack.LinkResolutionResult, 1) - if err := s.GetLinkAddress(nicID, test.addr, "", test.proto, func(r stack.LinkResolutionResult) { - ch <- r - }); err != nil { - t.Fatalf("s.GetLinkAddress(%d, %s, '', %d, _): %s", nicID, test.addr, test.proto, err) - } - - if diff := cmp.Diff(stack.LinkResolutionResult{LinkAddress: test.expectedLinkAddr, Err: nil}, <-ch); diff != "" { - t.Fatalf("link resolution result mismatch (-want +got):\n%s", diff) - } - }) - } -} diff --git a/pkg/tcpip/stack/stack_unsafe_state_autogen.go b/pkg/tcpip/stack/stack_unsafe_state_autogen.go new file mode 100644 index 000000000..758ab3457 --- /dev/null +++ b/pkg/tcpip/stack/stack_unsafe_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package stack diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go deleted file mode 100644 index 45b09110d..000000000 --- a/pkg/tcpip/stack/transport_demuxer_test.go +++ /dev/null @@ -1,446 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package stack_test - -import ( - "io/ioutil" - "math" - "math/rand" - "strconv" - "testing" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" - "gvisor.dev/gvisor/pkg/tcpip/ports" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/transport/udp" - "gvisor.dev/gvisor/pkg/waiter" -) - -const ( - testSrcAddrV6 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - testDstAddrV6 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" - - testSrcAddrV4 = "\x0a\x00\x00\x01" - testDstAddrV4 = "\x0a\x00\x00\x02" - - testDstPort = 1234 - testSrcPort = 4096 -) - -type testContext struct { - linkEps map[tcpip.NICID]*channel.Endpoint - s *stack.Stack - wq waiter.Queue -} - -// newDualTestContextMultiNIC creates the testing context and also linkEpIDs NICs. -func newDualTestContextMultiNIC(t *testing.T, mtu uint32, linkEpIDs []tcpip.NICID) *testContext { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, - }) - linkEps := make(map[tcpip.NICID]*channel.Endpoint) - for _, linkEpID := range linkEpIDs { - channelEp := channel.New(256, mtu, "") - if err := s.CreateNIC(linkEpID, channelEp); err != nil { - t.Fatalf("CreateNIC failed: %s", err) - } - linkEps[linkEpID] = channelEp - - if err := s.AddAddress(linkEpID, ipv4.ProtocolNumber, testDstAddrV4); err != nil { - t.Fatalf("AddAddress IPv4 failed: %s", err) - } - - if err := s.AddAddress(linkEpID, ipv6.ProtocolNumber, testDstAddrV6); err != nil { - t.Fatalf("AddAddress IPv6 failed: %s", err) - } - } - - s.SetRouteTable([]tcpip.Route{ - {Destination: header.IPv4EmptySubnet, NIC: 1}, - {Destination: header.IPv6EmptySubnet, NIC: 1}, - }) - - return &testContext{ - s: s, - linkEps: linkEps, - } -} - -type headers struct { - srcPort uint16 - dstPort uint16 -} - -func newPayload() []byte { - b := make([]byte, 30+rand.Intn(100)) - for i := range b { - b[i] = byte(rand.Intn(256)) - } - return b -} - -func (c *testContext) sendV4Packet(payload []byte, h *headers, linkEpID tcpip.NICID) { - buf := buffer.NewView(header.UDPMinimumSize + header.IPv4MinimumSize + len(payload)) - payloadStart := len(buf) - len(payload) - copy(buf[payloadStart:], payload) - - // Initialize the IP header. - ip := header.IPv4(buf) - ip.Encode(&header.IPv4Fields{ - TOS: 0x80, - TotalLength: uint16(len(buf)), - TTL: 65, - Protocol: uint8(udp.ProtocolNumber), - SrcAddr: testSrcAddrV4, - DstAddr: testDstAddrV4, - }) - ip.SetChecksum(^ip.CalculateChecksum()) - - // Initialize the UDP header. - u := header.UDP(buf[header.IPv4MinimumSize:]) - u.Encode(&header.UDPFields{ - SrcPort: h.srcPort, - DstPort: h.dstPort, - Length: uint16(header.UDPMinimumSize + len(payload)), - }) - - // Calculate the UDP pseudo-header checksum. - xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, testSrcAddrV4, testDstAddrV4, uint16(len(u))) - - // Calculate the UDP checksum and set it. - xsum = header.Checksum(payload, xsum) - u.SetChecksum(^u.CalculateChecksum(xsum)) - - // Inject packet. - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buf.ToVectorisedView(), - }) - c.linkEps[linkEpID].InjectInbound(ipv4.ProtocolNumber, pkt) -} - -func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpID tcpip.NICID) { - // Allocate a buffer for data and headers. - buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload)) - copy(buf[len(buf)-len(payload):], payload) - - // Initialize the IP header. - ip := header.IPv6(buf) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(header.UDPMinimumSize + len(payload)), - TransportProtocol: udp.ProtocolNumber, - HopLimit: 65, - SrcAddr: testSrcAddrV6, - DstAddr: testDstAddrV6, - }) - - // Initialize the UDP header. - u := header.UDP(buf[header.IPv6MinimumSize:]) - u.Encode(&header.UDPFields{ - SrcPort: h.srcPort, - DstPort: h.dstPort, - Length: uint16(header.UDPMinimumSize + len(payload)), - }) - - // Calculate the UDP pseudo-header checksum. - xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, testSrcAddrV6, testDstAddrV6, uint16(len(u))) - - // Calculate the UDP checksum and set it. - xsum = header.Checksum(payload, xsum) - u.SetChecksum(^u.CalculateChecksum(xsum)) - - // Inject packet. - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buf.ToVectorisedView(), - }) - c.linkEps[linkEpID].InjectInbound(ipv6.ProtocolNumber, pkt) -} - -func TestTransportDemuxerRegister(t *testing.T) { - for _, test := range []struct { - name string - proto tcpip.NetworkProtocolNumber - want tcpip.Error - }{ - {"failure", ipv6.ProtocolNumber, &tcpip.ErrUnknownProtocol{}}, - {"success", ipv4.ProtocolNumber, nil}, - } { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, - }) - var wq waiter.Queue - ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq) - if err != nil { - t.Fatal(err) - } - tEP, ok := ep.(stack.TransportEndpoint) - if !ok { - t.Fatalf("%T does not implement stack.TransportEndpoint", ep) - } - if got, want := s.RegisterTransportEndpoint([]tcpip.NetworkProtocolNumber{test.proto}, udp.ProtocolNumber, stack.TransportEndpointID{}, tEP, ports.Flags{}, 0), test.want; got != want { - t.Fatalf("s.RegisterTransportEndpoint(...) = %s, want %s", got, want) - } - }) - } -} - -func TestTransportDemuxerRegisterMultiple(t *testing.T) { - type test struct { - flags ports.Flags - want tcpip.Error - } - for _, subtest := range []struct { - name string - tests []test - }{ - {"zeroFlags", []test{ - {ports.Flags{}, nil}, - {ports.Flags{}, &tcpip.ErrPortInUse{}}, - }}, - {"multibindFlags", []test{ - // Allow multiple registrations same TransportEndpointID with multibind flags. - {ports.Flags{LoadBalanced: true, MostRecent: true}, nil}, - {ports.Flags{LoadBalanced: true, MostRecent: true}, nil}, - // Disallow registration w/same ID for a non-multibindflag. - {ports.Flags{TupleOnly: true}, &tcpip.ErrPortInUse{}}, - }}, - } { - t.Run(subtest.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, - }) - var eps []tcpip.Endpoint - for idx, test := range subtest.tests { - var wq waiter.Queue - ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq) - if err != nil { - t.Fatal(err) - } - eps = append(eps, ep) - tEP, ok := ep.(stack.TransportEndpoint) - if !ok { - t.Fatalf("%T does not implement stack.TransportEndpoint", ep) - } - id := stack.TransportEndpointID{LocalPort: 1} - if got, want := s.RegisterTransportEndpoint([]tcpip.NetworkProtocolNumber{ipv4.ProtocolNumber}, udp.ProtocolNumber, id, tEP, test.flags, 0), test.want; got != want { - t.Fatalf("test index: %d, s.RegisterTransportEndpoint(ipv4.ProtocolNumber, udp.ProtocolNumber, _, _, %+v, 0) = %s, want %s", idx, test.flags, got, want) - } - } - for _, ep := range eps { - ep.Close() - } - }) - } -} - -// TestBindToDeviceDistribution injects varied packets on input devices and checks that -// the distribution of packets received matches expectations. -func TestBindToDeviceDistribution(t *testing.T) { - type endpointSockopts struct { - reuse bool - bindToDevice tcpip.NICID - } - tcs := []struct { - name string - // endpoints will received the inject packets. - endpoints []endpointSockopts - // wantDistributions is the want ratio of packets received on each - // endpoint for each NIC on which packets are injected. - wantDistributions map[tcpip.NICID][]float64 - }{ - { - name: "BindPortReuse", - // 5 endpoints that all have reuse set. - endpoints: []endpointSockopts{ - {reuse: true, bindToDevice: 0}, - {reuse: true, bindToDevice: 0}, - {reuse: true, bindToDevice: 0}, - {reuse: true, bindToDevice: 0}, - {reuse: true, bindToDevice: 0}, - }, - wantDistributions: map[tcpip.NICID][]float64{ - // Injected packets on dev0 get distributed evenly. - 1: {0.2, 0.2, 0.2, 0.2, 0.2}, - }, - }, - { - name: "BindToDevice", - // 3 endpoints with various bindings. - endpoints: []endpointSockopts{ - {reuse: false, bindToDevice: 1}, - {reuse: false, bindToDevice: 2}, - {reuse: false, bindToDevice: 3}, - }, - wantDistributions: map[tcpip.NICID][]float64{ - // Injected packets on dev0 go only to the endpoint bound to dev0. - 1: {1, 0, 0}, - // Injected packets on dev1 go only to the endpoint bound to dev1. - 2: {0, 1, 0}, - // Injected packets on dev2 go only to the endpoint bound to dev2. - 3: {0, 0, 1}, - }, - }, - { - name: "ReuseAndBindToDevice", - // 6 endpoints with various bindings. - endpoints: []endpointSockopts{ - {reuse: true, bindToDevice: 1}, - {reuse: true, bindToDevice: 1}, - {reuse: true, bindToDevice: 2}, - {reuse: true, bindToDevice: 2}, - {reuse: true, bindToDevice: 2}, - {reuse: true, bindToDevice: 0}, - }, - wantDistributions: map[tcpip.NICID][]float64{ - // Injected packets on dev0 get distributed among endpoints bound to - // dev0. - 1: {0.5, 0.5, 0, 0, 0, 0}, - // Injected packets on dev1 get distributed among endpoints bound to - // dev1 or unbound. - 2: {0, 0, 1. / 3, 1. / 3, 1. / 3, 0}, - // Injected packets on dev999 go only to the unbound. - 1000: {0, 0, 0, 0, 0, 1}, - }, - }, - } - protos := map[string]tcpip.NetworkProtocolNumber{ - "IPv4": ipv4.ProtocolNumber, - "IPv6": ipv6.ProtocolNumber, - } - - for _, test := range tcs { - for protoName, protoNum := range protos { - for device, wantDistribution := range test.wantDistributions { - t.Run(test.name+protoName+"-"+strconv.Itoa(int(device)), func(t *testing.T) { - // Create the NICs. - var devices []tcpip.NICID - for d := range test.wantDistributions { - devices = append(devices, d) - } - c := newDualTestContextMultiNIC(t, defaultMTU, devices) - - // Create endpoints and bind each to a NIC, sometimes reusing ports. - eps := make(map[tcpip.Endpoint]int) - pollChannel := make(chan tcpip.Endpoint) - for i, endpoint := range test.endpoints { - // Try to receive the data. - wq := waiter.Queue{} - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.ReadableEvents) - t.Cleanup(func() { - wq.EventUnregister(&we) - close(ch) - }) - - var err tcpip.Error - ep, err := c.s.NewEndpoint(udp.ProtocolNumber, protoNum, &wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - t.Cleanup(ep.Close) - eps[ep] = i - - go func(ep tcpip.Endpoint) { - for range ch { - pollChannel <- ep - } - }(ep) - - ep.SocketOptions().SetReusePort(endpoint.reuse) - if err := ep.SocketOptions().SetBindToDevice(int32(endpoint.bindToDevice)); err != nil { - t.Fatalf("SetSockOpt(&%T(%d)) on endpoint %d failed: %s", endpoint.bindToDevice, endpoint.bindToDevice, i, err) - } - - var dstAddr tcpip.Address - switch protoNum { - case ipv4.ProtocolNumber: - dstAddr = testDstAddrV4 - case ipv6.ProtocolNumber: - dstAddr = testDstAddrV6 - default: - t.Fatalf("unexpected protocol number: %d", protoNum) - } - if err := ep.Bind(tcpip.FullAddress{Addr: dstAddr, Port: testDstPort}); err != nil { - t.Fatalf("ep.Bind(...) on endpoint %d failed: %s", i, err) - } - } - - // Send packets across a range of ports, checking that packets from - // the same source port are always demultiplexed to the same - // destination endpoint. - npackets := 10_000 - nports := 1_000 - if got, want := len(test.endpoints), len(wantDistribution); got != want { - t.Fatalf("got len(test.endpoints) = %d, want %d", got, want) - } - endpoints := make(map[uint16]tcpip.Endpoint) - stats := make(map[tcpip.Endpoint]int) - for i := 0; i < npackets; i++ { - // Send a packet. - port := uint16(i % nports) - payload := newPayload() - hdrs := &headers{ - srcPort: testSrcPort + port, - dstPort: testDstPort, - } - switch protoNum { - case ipv4.ProtocolNumber: - c.sendV4Packet(payload, hdrs, device) - case ipv6.ProtocolNumber: - c.sendV6Packet(payload, hdrs, device) - default: - t.Fatalf("unexpected protocol number: %d", protoNum) - } - - ep := <-pollChannel - if _, err := ep.Read(ioutil.Discard, tcpip.ReadOptions{}); err != nil { - t.Fatalf("Read on endpoint %d failed: %s", eps[ep], err) - } - stats[ep]++ - if i < nports { - endpoints[uint16(i)] = ep - } else { - // Check that all packets from one client are handled by the same - // socket. - if want, got := endpoints[port], ep; want != got { - t.Fatalf("Packet sent on port %d expected on endpoint %d but received on endpoint %d", port, eps[want], eps[got]) - } - } - } - - // Check that a packet distribution is as expected. - for ep, i := range eps { - wantRatio := wantDistribution[i] - wantRecv := wantRatio * float64(npackets) - actualRecv := stats[ep] - actualRatio := float64(stats[ep]) / float64(npackets) - // The deviation is less than 10%. - if math.Abs(actualRatio-wantRatio) > 0.05 { - t.Errorf("want about %.0f%% (%.0f of %d) packets to arrive on endpoint %d, got %.0f%% (%d of %d)", wantRatio*100, wantRecv, npackets, i, actualRatio*100, actualRecv, npackets) - } - } - }) - } - } - } -} diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go deleted file mode 100644 index 839178809..000000000 --- a/pkg/tcpip/stack/transport_test.go +++ /dev/null @@ -1,555 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package stack_test - -import ( - "bytes" - "io" - "testing" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/ports" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/waiter" -) - -const ( - fakeTransNumber tcpip.TransportProtocolNumber = 1 - fakeTransHeaderLen int = 3 -) - -// fakeTransportEndpoint is a transport-layer protocol endpoint. It counts -// received packets; the counts of all endpoints are aggregated in the protocol -// descriptor. -// -// Headers of this protocol are fakeTransHeaderLen bytes, but we currently don't -// use it. -type fakeTransportEndpoint struct { - stack.TransportEndpointInfo - tcpip.DefaultSocketOptionsHandler - - proto *fakeTransportProtocol - peerAddr tcpip.Address - route *stack.Route - uniqueID uint64 - - // acceptQueue is non-nil iff bound. - acceptQueue []*fakeTransportEndpoint - - // ops is used to set and get socket options. - ops tcpip.SocketOptions -} - -func (f *fakeTransportEndpoint) Info() tcpip.EndpointInfo { - return &f.TransportEndpointInfo -} - -func (*fakeTransportEndpoint) Stats() tcpip.EndpointStats { - return nil -} - -func (*fakeTransportEndpoint) SetOwner(owner tcpip.PacketOwner) {} - -func (f *fakeTransportEndpoint) SocketOptions() *tcpip.SocketOptions { - return &f.ops -} - -func newFakeTransportEndpoint(proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber, s *stack.Stack) tcpip.Endpoint { - ep := &fakeTransportEndpoint{TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: s.UniqueID()} - ep.ops.InitHandler(ep, s, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits) - return ep -} - -func (f *fakeTransportEndpoint) Abort() { - f.Close() -} - -func (f *fakeTransportEndpoint) Close() { - // TODO(gvisor.dev/issue/5153): Consider retaining the route. - f.route.Release() -} - -func (*fakeTransportEndpoint) Readiness(mask waiter.EventMask) waiter.EventMask { - return mask -} - -func (*fakeTransportEndpoint) Read(io.Writer, tcpip.ReadOptions) (tcpip.ReadResult, tcpip.Error) { - return tcpip.ReadResult{}, nil -} - -func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) { - if len(f.route.RemoteAddress()) == 0 { - return 0, &tcpip.ErrNoRoute{} - } - - v := make([]byte, p.Len()) - if _, err := io.ReadFull(p, v); err != nil { - return 0, &tcpip.ErrBadBuffer{} - } - - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(f.route.MaxHeaderLength()) + fakeTransHeaderLen, - Data: buffer.View(v).ToVectorisedView(), - }) - _ = pkt.TransportHeader().Push(fakeTransHeaderLen) - if err := f.route.WritePacket(stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, pkt); err != nil { - return 0, err - } - - return int64(len(v)), nil -} - -// SetSockOpt sets a socket option. Currently not supported. -func (*fakeTransportEndpoint) SetSockOpt(tcpip.SettableSocketOption) tcpip.Error { - return &tcpip.ErrInvalidEndpointState{} -} - -// SetSockOptInt sets a socket option. Currently not supported. -func (*fakeTransportEndpoint) SetSockOptInt(tcpip.SockOptInt, int) tcpip.Error { - return &tcpip.ErrInvalidEndpointState{} -} - -// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. -func (*fakeTransportEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { - return -1, &tcpip.ErrUnknownProtocolOption{} -} - -// GetSockOpt implements tcpip.Endpoint.GetSockOpt. -func (*fakeTransportEndpoint) GetSockOpt(tcpip.GettableSocketOption) tcpip.Error { - return &tcpip.ErrInvalidEndpointState{} -} - -// Disconnect implements tcpip.Endpoint.Disconnect. -func (*fakeTransportEndpoint) Disconnect() tcpip.Error { - return &tcpip.ErrNotSupported{} -} - -func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) tcpip.Error { - f.peerAddr = addr.Addr - - // Find the route. - r, err := f.proto.stack.FindRoute(addr.NIC, "", addr.Addr, fakeNetNumber, false /* multicastLoop */) - if err != nil { - return &tcpip.ErrNoRoute{} - } - - // Try to register so that we can start receiving packets. - f.ID.RemoteAddress = addr.Addr - err = f.proto.stack.RegisterTransportEndpoint([]tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.ID, f, ports.Flags{}, 0 /* bindToDevice */) - if err != nil { - r.Release() - return err - } - - f.route = r - - return nil -} - -func (f *fakeTransportEndpoint) UniqueID() uint64 { - return f.uniqueID -} - -func (*fakeTransportEndpoint) ConnectEndpoint(e tcpip.Endpoint) tcpip.Error { - return nil -} - -func (*fakeTransportEndpoint) Shutdown(tcpip.ShutdownFlags) tcpip.Error { - return nil -} - -func (*fakeTransportEndpoint) Reset() { -} - -func (*fakeTransportEndpoint) Listen(int) tcpip.Error { - return nil -} - -func (f *fakeTransportEndpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, tcpip.Error) { - if len(f.acceptQueue) == 0 { - return nil, nil, nil - } - a := f.acceptQueue[0] - f.acceptQueue = f.acceptQueue[1:] - return a, nil, nil -} - -func (f *fakeTransportEndpoint) Bind(a tcpip.FullAddress) tcpip.Error { - if err := f.proto.stack.RegisterTransportEndpoint( - []tcpip.NetworkProtocolNumber{fakeNetNumber}, - fakeTransNumber, - stack.TransportEndpointID{LocalAddress: a.Addr}, - f, - ports.Flags{}, - 0, /* bindtoDevice */ - ); err != nil { - return err - } - f.acceptQueue = []*fakeTransportEndpoint{} - return nil -} - -func (*fakeTransportEndpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) { - return tcpip.FullAddress{}, nil -} - -func (*fakeTransportEndpoint) GetRemoteAddress() (tcpip.FullAddress, tcpip.Error) { - return tcpip.FullAddress{}, nil -} - -func (f *fakeTransportEndpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) { - // Increment the number of received packets. - f.proto.packetCount++ - if f.acceptQueue == nil { - return - } - - netHdr := pkt.NetworkHeader().View() - route, err := f.proto.stack.FindRoute(pkt.NICID, tcpip.Address(netHdr[dstAddrOffset]), tcpip.Address(netHdr[srcAddrOffset]), pkt.NetworkProtocolNumber, false /* multicastLoop */) - if err != nil { - return - } - - ep := &fakeTransportEndpoint{ - TransportEndpointInfo: stack.TransportEndpointInfo{ - ID: f.ID, - NetProto: f.NetProto, - }, - proto: f.proto, - peerAddr: route.RemoteAddress(), - route: route, - } - ep.ops.InitHandler(ep, f.proto.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits) - f.acceptQueue = append(f.acceptQueue, ep) -} - -func (f *fakeTransportEndpoint) HandleError(stack.TransportError, *stack.PacketBuffer) { - // Increment the number of received control packets. - f.proto.controlCount++ -} - -func (*fakeTransportEndpoint) State() uint32 { - return 0 -} - -func (*fakeTransportEndpoint) ModerateRecvBuf(copied int) {} - -func (*fakeTransportEndpoint) Resume(*stack.Stack) {} - -func (*fakeTransportEndpoint) Wait() {} - -func (*fakeTransportEndpoint) LastError() tcpip.Error { - return nil -} - -type fakeTransportGoodOption bool - -type fakeTransportBadOption bool - -type fakeTransportInvalidValueOption int - -type fakeTransportProtocolOptions struct { - good bool -} - -// fakeTransportProtocol is a transport-layer protocol descriptor. It -// aggregates the number of packets received via endpoints of this protocol. -type fakeTransportProtocol struct { - stack *stack.Stack - - packetCount int - controlCount int - opts fakeTransportProtocolOptions -} - -func (*fakeTransportProtocol) Number() tcpip.TransportProtocolNumber { - return fakeTransNumber -} - -func (f *fakeTransportProtocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { - return newFakeTransportEndpoint(f, netProto, f.stack), nil -} - -func (*fakeTransportProtocol) NewRawEndpoint(tcpip.NetworkProtocolNumber, *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { - return nil, &tcpip.ErrUnknownProtocol{} -} - -func (*fakeTransportProtocol) MinimumPacketSize() int { - return fakeTransHeaderLen -} - -func (*fakeTransportProtocol) ParsePorts(buffer.View) (src, dst uint16, err tcpip.Error) { - return 0, 0, nil -} - -func (*fakeTransportProtocol) HandleUnknownDestinationPacket(stack.TransportEndpointID, *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition { - return stack.UnknownDestinationPacketHandled -} - -func (f *fakeTransportProtocol) SetOption(option tcpip.SettableTransportProtocolOption) tcpip.Error { - switch v := option.(type) { - case *tcpip.TCPModerateReceiveBufferOption: - f.opts.good = bool(*v) - return nil - default: - return &tcpip.ErrUnknownProtocolOption{} - } -} - -func (f *fakeTransportProtocol) Option(option tcpip.GettableTransportProtocolOption) tcpip.Error { - switch v := option.(type) { - case *tcpip.TCPModerateReceiveBufferOption: - *v = tcpip.TCPModerateReceiveBufferOption(f.opts.good) - return nil - default: - return &tcpip.ErrUnknownProtocolOption{} - } -} - -// Abort implements TransportProtocol.Abort. -func (*fakeTransportProtocol) Abort() {} - -// Close implements tcpip.Endpoint.Close. -func (*fakeTransportProtocol) Close() {} - -// Wait implements TransportProtocol.Wait. -func (*fakeTransportProtocol) Wait() {} - -// Parse implements TransportProtocol.Parse. -func (*fakeTransportProtocol) Parse(pkt *stack.PacketBuffer) bool { - _, ok := pkt.TransportHeader().Consume(fakeTransHeaderLen) - return ok -} - -func fakeTransFactory(s *stack.Stack) stack.TransportProtocol { - return &fakeTransportProtocol{stack: s} -} - -func TestTransportReceive(t *testing.T) { - linkEP := channel.New(10, defaultMTU, "") - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - TransportProtocols: []stack.TransportProtocolFactory{fakeTransFactory}, - }) - if err := s.CreateNIC(1, linkEP); err != nil { - t.Fatalf("CreateNIC failed: %v", err) - } - - { - subnet, err := tcpip.NewSubnet("\x00", "\x00") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) - } - - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatalf("AddAddress failed: %v", err) - } - - // Create endpoint and connect to remote address. - wq := waiter.Queue{} - ep, err := s.NewEndpoint(fakeTransNumber, fakeNetNumber, &wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - - if err := ep.Connect(tcpip.FullAddress{0, "\x02", 0}); err != nil { - t.Fatalf("Connect failed: %v", err) - } - - fakeTrans := s.TransportProtocolInstance(fakeTransNumber).(*fakeTransportProtocol) - - // Create buffer that will hold the packet. - buf := buffer.NewView(30) - - // Make sure packet with wrong protocol is not delivered. - buf[0] = 1 - buf[2] = 0 - linkEP.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - if fakeTrans.packetCount != 0 { - t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 0) - } - - // Make sure packet from the wrong source is not delivered. - buf[0] = 1 - buf[1] = 3 - buf[2] = byte(fakeTransNumber) - linkEP.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - if fakeTrans.packetCount != 0 { - t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 0) - } - - // Make sure packet is delivered. - buf[0] = 1 - buf[1] = 2 - buf[2] = byte(fakeTransNumber) - linkEP.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - if fakeTrans.packetCount != 1 { - t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 1) - } -} - -func TestTransportControlReceive(t *testing.T) { - linkEP := channel.New(10, defaultMTU, "") - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - TransportProtocols: []stack.TransportProtocolFactory{fakeTransFactory}, - }) - if err := s.CreateNIC(1, linkEP); err != nil { - t.Fatalf("CreateNIC failed: %v", err) - } - - { - subnet, err := tcpip.NewSubnet("\x00", "\x00") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) - } - - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatalf("AddAddress failed: %v", err) - } - - // Create endpoint and connect to remote address. - wq := waiter.Queue{} - ep, err := s.NewEndpoint(fakeTransNumber, fakeNetNumber, &wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - - if err := ep.Connect(tcpip.FullAddress{0, "\x02", 0}); err != nil { - t.Fatalf("Connect failed: %v", err) - } - - fakeTrans := s.TransportProtocolInstance(fakeTransNumber).(*fakeTransportProtocol) - - // Create buffer that will hold the control packet. - buf := buffer.NewView(2*fakeNetHeaderLen + 30) - - // Outer packet contains the control protocol number. - buf[0] = 1 - buf[1] = 0xfe - buf[2] = uint8(fakeControlProtocol) - - // Make sure packet with wrong protocol is not delivered. - buf[fakeNetHeaderLen+0] = 0 - buf[fakeNetHeaderLen+1] = 1 - buf[fakeNetHeaderLen+2] = 0 - linkEP.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - if fakeTrans.controlCount != 0 { - t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 0) - } - - // Make sure packet from the wrong source is not delivered. - buf[fakeNetHeaderLen+0] = 3 - buf[fakeNetHeaderLen+1] = 1 - buf[fakeNetHeaderLen+2] = byte(fakeTransNumber) - linkEP.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - if fakeTrans.controlCount != 0 { - t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 0) - } - - // Make sure packet is delivered. - buf[fakeNetHeaderLen+0] = 2 - buf[fakeNetHeaderLen+1] = 1 - buf[fakeNetHeaderLen+2] = byte(fakeTransNumber) - linkEP.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - if fakeTrans.controlCount != 1 { - t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 1) - } -} - -func TestTransportSend(t *testing.T) { - linkEP := channel.New(10, defaultMTU, "") - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - TransportProtocols: []stack.TransportProtocolFactory{fakeTransFactory}, - }) - if err := s.CreateNIC(1, linkEP); err != nil { - t.Fatalf("CreateNIC failed: %v", err) - } - - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatalf("AddAddress failed: %v", err) - } - - { - subnet, err := tcpip.NewSubnet("\x00", "\x00") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) - } - - // Create endpoint and bind it. - wq := waiter.Queue{} - ep, err := s.NewEndpoint(fakeTransNumber, fakeNetNumber, &wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - - if err := ep.Connect(tcpip.FullAddress{0, "\x02", 0}); err != nil { - t.Fatalf("Connect failed: %v", err) - } - - // Create buffer that will hold the payload. - b := make([]byte, 30) - var r bytes.Reader - r.Reset(b) - if _, err := ep.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("write failed: %v", err) - } - - fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) - - if fakeNet.sendPacketCount[2] != 1 { - t.Errorf("sendPacketCount = %d, want %d", fakeNet.sendPacketCount[2], 1) - } -} - -func TestTransportOptions(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - TransportProtocols: []stack.TransportProtocolFactory{fakeTransFactory}, - }) - - v := tcpip.TCPModerateReceiveBufferOption(true) - if err := s.SetTransportProtocolOption(fakeTransNumber, &v); err != nil { - t.Errorf("s.SetTransportProtocolOption(fakeTrans, &%T(%t)): %s", v, v, err) - } - v = false - if err := s.TransportProtocolOption(fakeTransNumber, &v); err != nil { - t.Fatalf("s.TransportProtocolOption(fakeTransNumber, &%T): %s", v, err) - } - if !v { - t.Fatalf("got tcpip.TCPModerateReceiveBufferOption = false, want = true") - } -} diff --git a/pkg/tcpip/stack/tuple_list.go b/pkg/tcpip/stack/tuple_list.go new file mode 100644 index 000000000..31d0feefa --- /dev/null +++ b/pkg/tcpip/stack/tuple_list.go @@ -0,0 +1,221 @@ +package stack + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type tupleElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (tupleElementMapper) linkerFor(elem *tuple) *tuple { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type tupleList struct { + head *tuple + tail *tuple +} + +// Reset resets list l to the empty state. +func (l *tupleList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +// +//go:nosplit +func (l *tupleList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +// +//go:nosplit +func (l *tupleList) Front() *tuple { + return l.head +} + +// Back returns the last element of list l or nil. +// +//go:nosplit +func (l *tupleList) Back() *tuple { + return l.tail +} + +// Len returns the number of elements in the list. +// +// NOTE: This is an O(n) operation. +// +//go:nosplit +func (l *tupleList) Len() (count int) { + for e := l.Front(); e != nil; e = (tupleElementMapper{}.linkerFor(e)).Next() { + count++ + } + return count +} + +// PushFront inserts the element e at the front of list l. +// +//go:nosplit +func (l *tupleList) PushFront(e *tuple) { + linker := tupleElementMapper{}.linkerFor(e) + linker.SetNext(l.head) + linker.SetPrev(nil) + if l.head != nil { + tupleElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +// +//go:nosplit +func (l *tupleList) PushBack(e *tuple) { + linker := tupleElementMapper{}.linkerFor(e) + linker.SetNext(nil) + linker.SetPrev(l.tail) + if l.tail != nil { + tupleElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +// +//go:nosplit +func (l *tupleList) PushBackList(m *tupleList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + tupleElementMapper{}.linkerFor(l.tail).SetNext(m.head) + tupleElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +// +//go:nosplit +func (l *tupleList) InsertAfter(b, e *tuple) { + bLinker := tupleElementMapper{}.linkerFor(b) + eLinker := tupleElementMapper{}.linkerFor(e) + + a := bLinker.Next() + + eLinker.SetNext(a) + eLinker.SetPrev(b) + bLinker.SetNext(e) + + if a != nil { + tupleElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +// +//go:nosplit +func (l *tupleList) InsertBefore(a, e *tuple) { + aLinker := tupleElementMapper{}.linkerFor(a) + eLinker := tupleElementMapper{}.linkerFor(e) + + b := aLinker.Prev() + eLinker.SetNext(a) + eLinker.SetPrev(b) + aLinker.SetPrev(e) + + if b != nil { + tupleElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +// +//go:nosplit +func (l *tupleList) Remove(e *tuple) { + linker := tupleElementMapper{}.linkerFor(e) + prev := linker.Prev() + next := linker.Next() + + if prev != nil { + tupleElementMapper{}.linkerFor(prev).SetNext(next) + } else if l.head == e { + l.head = next + } + + if next != nil { + tupleElementMapper{}.linkerFor(next).SetPrev(prev) + } else if l.tail == e { + l.tail = prev + } + + linker.SetNext(nil) + linker.SetPrev(nil) +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type tupleEntry struct { + next *tuple + prev *tuple +} + +// Next returns the entry that follows e in the list. +// +//go:nosplit +func (e *tupleEntry) Next() *tuple { + return e.next +} + +// Prev returns the entry that precedes e in the list. +// +//go:nosplit +func (e *tupleEntry) Prev() *tuple { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +// +//go:nosplit +func (e *tupleEntry) SetNext(elem *tuple) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +// +//go:nosplit +func (e *tupleEntry) SetPrev(elem *tuple) { + e.prev = elem +} diff --git a/pkg/tcpip/tcpip_state_autogen.go b/pkg/tcpip/tcpip_state_autogen.go new file mode 100644 index 000000000..9c34de9ef --- /dev/null +++ b/pkg/tcpip/tcpip_state_autogen.go @@ -0,0 +1,1323 @@ +// automatically generated by stateify. + +package tcpip + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (e *ErrAborted) StateTypeName() string { + return "pkg/tcpip.ErrAborted" +} + +func (e *ErrAborted) StateFields() []string { + return []string{} +} + +func (e *ErrAborted) beforeSave() {} + +// +checklocksignore +func (e *ErrAborted) StateSave(stateSinkObject state.Sink) { + e.beforeSave() +} + +func (e *ErrAborted) afterLoad() {} + +// +checklocksignore +func (e *ErrAborted) StateLoad(stateSourceObject state.Source) { +} + +func (e *ErrAddressFamilyNotSupported) StateTypeName() string { + return "pkg/tcpip.ErrAddressFamilyNotSupported" +} + +func (e *ErrAddressFamilyNotSupported) StateFields() []string { + return []string{} +} + +func (e *ErrAddressFamilyNotSupported) beforeSave() {} + +// +checklocksignore +func (e *ErrAddressFamilyNotSupported) StateSave(stateSinkObject state.Sink) { + e.beforeSave() +} + +func (e *ErrAddressFamilyNotSupported) afterLoad() {} + +// +checklocksignore +func (e *ErrAddressFamilyNotSupported) StateLoad(stateSourceObject state.Source) { +} + +func (e *ErrAlreadyBound) StateTypeName() string { + return "pkg/tcpip.ErrAlreadyBound" +} + +func (e *ErrAlreadyBound) StateFields() []string { + return []string{} +} + +func (e *ErrAlreadyBound) beforeSave() {} + +// +checklocksignore +func (e *ErrAlreadyBound) StateSave(stateSinkObject state.Sink) { + e.beforeSave() +} + +func (e *ErrAlreadyBound) afterLoad() {} + +// +checklocksignore +func (e *ErrAlreadyBound) StateLoad(stateSourceObject state.Source) { +} + +func (e *ErrAlreadyConnected) StateTypeName() string { + return "pkg/tcpip.ErrAlreadyConnected" +} + +func (e *ErrAlreadyConnected) StateFields() []string { + return []string{} +} + +func (e *ErrAlreadyConnected) beforeSave() {} + +// +checklocksignore +func (e *ErrAlreadyConnected) StateSave(stateSinkObject state.Sink) { + e.beforeSave() +} + +func (e *ErrAlreadyConnected) afterLoad() {} + +// +checklocksignore +func (e *ErrAlreadyConnected) StateLoad(stateSourceObject state.Source) { +} + +func (e *ErrAlreadyConnecting) StateTypeName() string { + return "pkg/tcpip.ErrAlreadyConnecting" +} + +func (e *ErrAlreadyConnecting) StateFields() []string { + return []string{} +} + +func (e *ErrAlreadyConnecting) beforeSave() {} + +// +checklocksignore +func (e *ErrAlreadyConnecting) StateSave(stateSinkObject state.Sink) { + e.beforeSave() +} + +func (e *ErrAlreadyConnecting) afterLoad() {} + +// +checklocksignore +func (e *ErrAlreadyConnecting) StateLoad(stateSourceObject state.Source) { +} + +func (e *ErrBadAddress) StateTypeName() string { + return "pkg/tcpip.ErrBadAddress" +} + +func (e *ErrBadAddress) StateFields() []string { + return []string{} +} + +func (e *ErrBadAddress) beforeSave() {} + +// +checklocksignore +func (e *ErrBadAddress) StateSave(stateSinkObject state.Sink) { + e.beforeSave() +} + +func (e *ErrBadAddress) afterLoad() {} + +// +checklocksignore +func (e *ErrBadAddress) StateLoad(stateSourceObject state.Source) { +} + +func (e *ErrBadBuffer) StateTypeName() string { + return "pkg/tcpip.ErrBadBuffer" +} + +func (e *ErrBadBuffer) StateFields() []string { + return []string{} +} + +func (e *ErrBadBuffer) beforeSave() {} + +// +checklocksignore +func (e *ErrBadBuffer) StateSave(stateSinkObject state.Sink) { + e.beforeSave() +} + +func (e *ErrBadBuffer) afterLoad() {} + +// +checklocksignore +func (e *ErrBadBuffer) StateLoad(stateSourceObject state.Source) { +} + +func (e *ErrBadLocalAddress) StateTypeName() string { + return "pkg/tcpip.ErrBadLocalAddress" +} + +func (e *ErrBadLocalAddress) StateFields() []string { + return []string{} +} + +func (e *ErrBadLocalAddress) beforeSave() {} + +// +checklocksignore +func (e *ErrBadLocalAddress) StateSave(stateSinkObject state.Sink) { + e.beforeSave() +} + +func (e *ErrBadLocalAddress) afterLoad() {} + +// +checklocksignore +func (e *ErrBadLocalAddress) StateLoad(stateSourceObject state.Source) { +} + +func (e *ErrBroadcastDisabled) StateTypeName() string { + return "pkg/tcpip.ErrBroadcastDisabled" +} + +func (e *ErrBroadcastDisabled) StateFields() []string { + return []string{} +} + +func (e *ErrBroadcastDisabled) beforeSave() {} + +// +checklocksignore +func (e *ErrBroadcastDisabled) StateSave(stateSinkObject state.Sink) { + e.beforeSave() +} + +func (e *ErrBroadcastDisabled) afterLoad() {} + +// +checklocksignore +func (e *ErrBroadcastDisabled) StateLoad(stateSourceObject state.Source) { +} + +func (e *ErrClosedForReceive) StateTypeName() string { + return "pkg/tcpip.ErrClosedForReceive" +} + +func (e *ErrClosedForReceive) StateFields() []string { + return []string{} +} + +func (e *ErrClosedForReceive) beforeSave() {} + +// +checklocksignore +func (e *ErrClosedForReceive) StateSave(stateSinkObject state.Sink) { + e.beforeSave() +} + +func (e *ErrClosedForReceive) afterLoad() {} + +// +checklocksignore +func (e *ErrClosedForReceive) StateLoad(stateSourceObject state.Source) { +} + +func (e *ErrClosedForSend) StateTypeName() string { + return "pkg/tcpip.ErrClosedForSend" +} + +func (e *ErrClosedForSend) StateFields() []string { + return []string{} +} + +func (e *ErrClosedForSend) beforeSave() {} + +// +checklocksignore +func (e *ErrClosedForSend) StateSave(stateSinkObject state.Sink) { + e.beforeSave() +} + +func (e *ErrClosedForSend) afterLoad() {} + +// +checklocksignore +func (e *ErrClosedForSend) StateLoad(stateSourceObject state.Source) { +} + +func (e *ErrConnectStarted) StateTypeName() string { + return "pkg/tcpip.ErrConnectStarted" +} + +func (e *ErrConnectStarted) StateFields() []string { + return []string{} +} + +func (e *ErrConnectStarted) beforeSave() {} + +// +checklocksignore +func (e *ErrConnectStarted) StateSave(stateSinkObject state.Sink) { + e.beforeSave() +} + +func (e *ErrConnectStarted) afterLoad() {} + +// +checklocksignore +func (e *ErrConnectStarted) StateLoad(stateSourceObject state.Source) { +} + +func (e *ErrConnectionAborted) StateTypeName() string { + return "pkg/tcpip.ErrConnectionAborted" +} + +func (e *ErrConnectionAborted) StateFields() []string { + return []string{} +} + +func (e *ErrConnectionAborted) beforeSave() {} + +// +checklocksignore +func (e *ErrConnectionAborted) StateSave(stateSinkObject state.Sink) { + e.beforeSave() +} + +func (e *ErrConnectionAborted) afterLoad() {} + +// +checklocksignore +func (e *ErrConnectionAborted) StateLoad(stateSourceObject state.Source) { +} + +func (e *ErrConnectionRefused) StateTypeName() string { + return "pkg/tcpip.ErrConnectionRefused" +} + +func (e *ErrConnectionRefused) StateFields() []string { + return []string{} +} + +func (e *ErrConnectionRefused) beforeSave() {} + +// +checklocksignore +func (e *ErrConnectionRefused) StateSave(stateSinkObject state.Sink) { + e.beforeSave() +} + +func (e *ErrConnectionRefused) afterLoad() {} + +// +checklocksignore +func (e *ErrConnectionRefused) StateLoad(stateSourceObject state.Source) { +} + +func (e *ErrConnectionReset) StateTypeName() string { + return "pkg/tcpip.ErrConnectionReset" +} + +func (e *ErrConnectionReset) StateFields() []string { + return []string{} +} + +func (e *ErrConnectionReset) beforeSave() {} + +// +checklocksignore +func (e *ErrConnectionReset) StateSave(stateSinkObject state.Sink) { + e.beforeSave() +} + +func (e *ErrConnectionReset) afterLoad() {} + +// +checklocksignore +func (e *ErrConnectionReset) StateLoad(stateSourceObject state.Source) { +} + +func (e *ErrDestinationRequired) StateTypeName() string { + return "pkg/tcpip.ErrDestinationRequired" +} + +func (e *ErrDestinationRequired) StateFields() []string { + return []string{} +} + +func (e *ErrDestinationRequired) beforeSave() {} + +// +checklocksignore +func (e *ErrDestinationRequired) StateSave(stateSinkObject state.Sink) { + e.beforeSave() +} + +func (e *ErrDestinationRequired) afterLoad() {} + +// +checklocksignore +func (e *ErrDestinationRequired) StateLoad(stateSourceObject state.Source) { +} + +func (e *ErrDuplicateAddress) StateTypeName() string { + return "pkg/tcpip.ErrDuplicateAddress" +} + +func (e *ErrDuplicateAddress) StateFields() []string { + return []string{} +} + +func (e *ErrDuplicateAddress) beforeSave() {} + +// +checklocksignore +func (e *ErrDuplicateAddress) StateSave(stateSinkObject state.Sink) { + e.beforeSave() +} + +func (e *ErrDuplicateAddress) afterLoad() {} + +// +checklocksignore +func (e *ErrDuplicateAddress) StateLoad(stateSourceObject state.Source) { +} + +func (e *ErrDuplicateNICID) StateTypeName() string { + return "pkg/tcpip.ErrDuplicateNICID" +} + +func (e *ErrDuplicateNICID) StateFields() []string { + return []string{} +} + +func (e *ErrDuplicateNICID) beforeSave() {} + +// +checklocksignore +func (e *ErrDuplicateNICID) StateSave(stateSinkObject state.Sink) { + e.beforeSave() +} + +func (e *ErrDuplicateNICID) afterLoad() {} + +// +checklocksignore +func (e *ErrDuplicateNICID) StateLoad(stateSourceObject state.Source) { +} + +func (e *ErrInvalidEndpointState) StateTypeName() string { + return "pkg/tcpip.ErrInvalidEndpointState" +} + +func (e *ErrInvalidEndpointState) StateFields() []string { + return []string{} +} + +func (e *ErrInvalidEndpointState) beforeSave() {} + +// +checklocksignore +func (e *ErrInvalidEndpointState) StateSave(stateSinkObject state.Sink) { + e.beforeSave() +} + +func (e *ErrInvalidEndpointState) afterLoad() {} + +// +checklocksignore +func (e *ErrInvalidEndpointState) StateLoad(stateSourceObject state.Source) { +} + +func (e *ErrInvalidOptionValue) StateTypeName() string { + return "pkg/tcpip.ErrInvalidOptionValue" +} + +func (e *ErrInvalidOptionValue) StateFields() []string { + return []string{} +} + +func (e *ErrInvalidOptionValue) beforeSave() {} + +// +checklocksignore +func (e *ErrInvalidOptionValue) StateSave(stateSinkObject state.Sink) { + e.beforeSave() +} + +func (e *ErrInvalidOptionValue) afterLoad() {} + +// +checklocksignore +func (e *ErrInvalidOptionValue) StateLoad(stateSourceObject state.Source) { +} + +func (e *ErrInvalidPortRange) StateTypeName() string { + return "pkg/tcpip.ErrInvalidPortRange" +} + +func (e *ErrInvalidPortRange) StateFields() []string { + return []string{} +} + +func (e *ErrInvalidPortRange) beforeSave() {} + +// +checklocksignore +func (e *ErrInvalidPortRange) StateSave(stateSinkObject state.Sink) { + e.beforeSave() +} + +func (e *ErrInvalidPortRange) afterLoad() {} + +// +checklocksignore +func (e *ErrInvalidPortRange) StateLoad(stateSourceObject state.Source) { +} + +func (e *ErrMalformedHeader) StateTypeName() string { + return "pkg/tcpip.ErrMalformedHeader" +} + +func (e *ErrMalformedHeader) StateFields() []string { + return []string{} +} + +func (e *ErrMalformedHeader) beforeSave() {} + +// +checklocksignore +func (e *ErrMalformedHeader) StateSave(stateSinkObject state.Sink) { + e.beforeSave() +} + +func (e *ErrMalformedHeader) afterLoad() {} + +// +checklocksignore +func (e *ErrMalformedHeader) StateLoad(stateSourceObject state.Source) { +} + +func (e *ErrMessageTooLong) StateTypeName() string { + return "pkg/tcpip.ErrMessageTooLong" +} + +func (e *ErrMessageTooLong) StateFields() []string { + return []string{} +} + +func (e *ErrMessageTooLong) beforeSave() {} + +// +checklocksignore +func (e *ErrMessageTooLong) StateSave(stateSinkObject state.Sink) { + e.beforeSave() +} + +func (e *ErrMessageTooLong) afterLoad() {} + +// +checklocksignore +func (e *ErrMessageTooLong) StateLoad(stateSourceObject state.Source) { +} + +func (e *ErrNetworkUnreachable) StateTypeName() string { + return "pkg/tcpip.ErrNetworkUnreachable" +} + +func (e *ErrNetworkUnreachable) StateFields() []string { + return []string{} +} + +func (e *ErrNetworkUnreachable) beforeSave() {} + +// +checklocksignore +func (e *ErrNetworkUnreachable) StateSave(stateSinkObject state.Sink) { + e.beforeSave() +} + +func (e *ErrNetworkUnreachable) afterLoad() {} + +// +checklocksignore +func (e *ErrNetworkUnreachable) StateLoad(stateSourceObject state.Source) { +} + +func (e *ErrNoBufferSpace) StateTypeName() string { + return "pkg/tcpip.ErrNoBufferSpace" +} + +func (e *ErrNoBufferSpace) StateFields() []string { + return []string{} +} + +func (e *ErrNoBufferSpace) beforeSave() {} + +// +checklocksignore +func (e *ErrNoBufferSpace) StateSave(stateSinkObject state.Sink) { + e.beforeSave() +} + +func (e *ErrNoBufferSpace) afterLoad() {} + +// +checklocksignore +func (e *ErrNoBufferSpace) StateLoad(stateSourceObject state.Source) { +} + +func (e *ErrNoPortAvailable) StateTypeName() string { + return "pkg/tcpip.ErrNoPortAvailable" +} + +func (e *ErrNoPortAvailable) StateFields() []string { + return []string{} +} + +func (e *ErrNoPortAvailable) beforeSave() {} + +// +checklocksignore +func (e *ErrNoPortAvailable) StateSave(stateSinkObject state.Sink) { + e.beforeSave() +} + +func (e *ErrNoPortAvailable) afterLoad() {} + +// +checklocksignore +func (e *ErrNoPortAvailable) StateLoad(stateSourceObject state.Source) { +} + +func (e *ErrNoRoute) StateTypeName() string { + return "pkg/tcpip.ErrNoRoute" +} + +func (e *ErrNoRoute) StateFields() []string { + return []string{} +} + +func (e *ErrNoRoute) beforeSave() {} + +// +checklocksignore +func (e *ErrNoRoute) StateSave(stateSinkObject state.Sink) { + e.beforeSave() +} + +func (e *ErrNoRoute) afterLoad() {} + +// +checklocksignore +func (e *ErrNoRoute) StateLoad(stateSourceObject state.Source) { +} + +func (e *ErrNoSuchFile) StateTypeName() string { + return "pkg/tcpip.ErrNoSuchFile" +} + +func (e *ErrNoSuchFile) StateFields() []string { + return []string{} +} + +func (e *ErrNoSuchFile) beforeSave() {} + +// +checklocksignore +func (e *ErrNoSuchFile) StateSave(stateSinkObject state.Sink) { + e.beforeSave() +} + +func (e *ErrNoSuchFile) afterLoad() {} + +// +checklocksignore +func (e *ErrNoSuchFile) StateLoad(stateSourceObject state.Source) { +} + +func (e *ErrNotConnected) StateTypeName() string { + return "pkg/tcpip.ErrNotConnected" +} + +func (e *ErrNotConnected) StateFields() []string { + return []string{} +} + +func (e *ErrNotConnected) beforeSave() {} + +// +checklocksignore +func (e *ErrNotConnected) StateSave(stateSinkObject state.Sink) { + e.beforeSave() +} + +func (e *ErrNotConnected) afterLoad() {} + +// +checklocksignore +func (e *ErrNotConnected) StateLoad(stateSourceObject state.Source) { +} + +func (e *ErrNotPermitted) StateTypeName() string { + return "pkg/tcpip.ErrNotPermitted" +} + +func (e *ErrNotPermitted) StateFields() []string { + return []string{} +} + +func (e *ErrNotPermitted) beforeSave() {} + +// +checklocksignore +func (e *ErrNotPermitted) StateSave(stateSinkObject state.Sink) { + e.beforeSave() +} + +func (e *ErrNotPermitted) afterLoad() {} + +// +checklocksignore +func (e *ErrNotPermitted) StateLoad(stateSourceObject state.Source) { +} + +func (e *ErrNotSupported) StateTypeName() string { + return "pkg/tcpip.ErrNotSupported" +} + +func (e *ErrNotSupported) StateFields() []string { + return []string{} +} + +func (e *ErrNotSupported) beforeSave() {} + +// +checklocksignore +func (e *ErrNotSupported) StateSave(stateSinkObject state.Sink) { + e.beforeSave() +} + +func (e *ErrNotSupported) afterLoad() {} + +// +checklocksignore +func (e *ErrNotSupported) StateLoad(stateSourceObject state.Source) { +} + +func (e *ErrPortInUse) StateTypeName() string { + return "pkg/tcpip.ErrPortInUse" +} + +func (e *ErrPortInUse) StateFields() []string { + return []string{} +} + +func (e *ErrPortInUse) beforeSave() {} + +// +checklocksignore +func (e *ErrPortInUse) StateSave(stateSinkObject state.Sink) { + e.beforeSave() +} + +func (e *ErrPortInUse) afterLoad() {} + +// +checklocksignore +func (e *ErrPortInUse) StateLoad(stateSourceObject state.Source) { +} + +func (e *ErrQueueSizeNotSupported) StateTypeName() string { + return "pkg/tcpip.ErrQueueSizeNotSupported" +} + +func (e *ErrQueueSizeNotSupported) StateFields() []string { + return []string{} +} + +func (e *ErrQueueSizeNotSupported) beforeSave() {} + +// +checklocksignore +func (e *ErrQueueSizeNotSupported) StateSave(stateSinkObject state.Sink) { + e.beforeSave() +} + +func (e *ErrQueueSizeNotSupported) afterLoad() {} + +// +checklocksignore +func (e *ErrQueueSizeNotSupported) StateLoad(stateSourceObject state.Source) { +} + +func (e *ErrTimeout) StateTypeName() string { + return "pkg/tcpip.ErrTimeout" +} + +func (e *ErrTimeout) StateFields() []string { + return []string{} +} + +func (e *ErrTimeout) beforeSave() {} + +// +checklocksignore +func (e *ErrTimeout) StateSave(stateSinkObject state.Sink) { + e.beforeSave() +} + +func (e *ErrTimeout) afterLoad() {} + +// +checklocksignore +func (e *ErrTimeout) StateLoad(stateSourceObject state.Source) { +} + +func (e *ErrUnknownDevice) StateTypeName() string { + return "pkg/tcpip.ErrUnknownDevice" +} + +func (e *ErrUnknownDevice) StateFields() []string { + return []string{} +} + +func (e *ErrUnknownDevice) beforeSave() {} + +// +checklocksignore +func (e *ErrUnknownDevice) StateSave(stateSinkObject state.Sink) { + e.beforeSave() +} + +func (e *ErrUnknownDevice) afterLoad() {} + +// +checklocksignore +func (e *ErrUnknownDevice) StateLoad(stateSourceObject state.Source) { +} + +func (e *ErrUnknownNICID) StateTypeName() string { + return "pkg/tcpip.ErrUnknownNICID" +} + +func (e *ErrUnknownNICID) StateFields() []string { + return []string{} +} + +func (e *ErrUnknownNICID) beforeSave() {} + +// +checklocksignore +func (e *ErrUnknownNICID) StateSave(stateSinkObject state.Sink) { + e.beforeSave() +} + +func (e *ErrUnknownNICID) afterLoad() {} + +// +checklocksignore +func (e *ErrUnknownNICID) StateLoad(stateSourceObject state.Source) { +} + +func (e *ErrUnknownProtocol) StateTypeName() string { + return "pkg/tcpip.ErrUnknownProtocol" +} + +func (e *ErrUnknownProtocol) StateFields() []string { + return []string{} +} + +func (e *ErrUnknownProtocol) beforeSave() {} + +// +checklocksignore +func (e *ErrUnknownProtocol) StateSave(stateSinkObject state.Sink) { + e.beforeSave() +} + +func (e *ErrUnknownProtocol) afterLoad() {} + +// +checklocksignore +func (e *ErrUnknownProtocol) StateLoad(stateSourceObject state.Source) { +} + +func (e *ErrUnknownProtocolOption) StateTypeName() string { + return "pkg/tcpip.ErrUnknownProtocolOption" +} + +func (e *ErrUnknownProtocolOption) StateFields() []string { + return []string{} +} + +func (e *ErrUnknownProtocolOption) beforeSave() {} + +// +checklocksignore +func (e *ErrUnknownProtocolOption) StateSave(stateSinkObject state.Sink) { + e.beforeSave() +} + +func (e *ErrUnknownProtocolOption) afterLoad() {} + +// +checklocksignore +func (e *ErrUnknownProtocolOption) StateLoad(stateSourceObject state.Source) { +} + +func (e *ErrWouldBlock) StateTypeName() string { + return "pkg/tcpip.ErrWouldBlock" +} + +func (e *ErrWouldBlock) StateFields() []string { + return []string{} +} + +func (e *ErrWouldBlock) beforeSave() {} + +// +checklocksignore +func (e *ErrWouldBlock) StateSave(stateSinkObject state.Sink) { + e.beforeSave() +} + +func (e *ErrWouldBlock) afterLoad() {} + +// +checklocksignore +func (e *ErrWouldBlock) StateLoad(stateSourceObject state.Source) { +} + +func (l *sockErrorList) StateTypeName() string { + return "pkg/tcpip.sockErrorList" +} + +func (l *sockErrorList) StateFields() []string { + return []string{ + "head", + "tail", + } +} + +func (l *sockErrorList) beforeSave() {} + +// +checklocksignore +func (l *sockErrorList) StateSave(stateSinkObject state.Sink) { + l.beforeSave() + stateSinkObject.Save(0, &l.head) + stateSinkObject.Save(1, &l.tail) +} + +func (l *sockErrorList) afterLoad() {} + +// +checklocksignore +func (l *sockErrorList) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &l.head) + stateSourceObject.Load(1, &l.tail) +} + +func (e *sockErrorEntry) StateTypeName() string { + return "pkg/tcpip.sockErrorEntry" +} + +func (e *sockErrorEntry) StateFields() []string { + return []string{ + "next", + "prev", + } +} + +func (e *sockErrorEntry) beforeSave() {} + +// +checklocksignore +func (e *sockErrorEntry) StateSave(stateSinkObject state.Sink) { + e.beforeSave() + stateSinkObject.Save(0, &e.next) + stateSinkObject.Save(1, &e.prev) +} + +func (e *sockErrorEntry) afterLoad() {} + +// +checklocksignore +func (e *sockErrorEntry) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &e.next) + stateSourceObject.Load(1, &e.prev) +} + +func (so *SocketOptions) StateTypeName() string { + return "pkg/tcpip.SocketOptions" +} + +func (so *SocketOptions) StateFields() []string { + return []string{ + "handler", + "broadcastEnabled", + "passCredEnabled", + "noChecksumEnabled", + "reuseAddressEnabled", + "reusePortEnabled", + "keepAliveEnabled", + "multicastLoopEnabled", + "receiveTOSEnabled", + "receiveTClassEnabled", + "receivePacketInfoEnabled", + "hdrIncludedEnabled", + "v6OnlyEnabled", + "quickAckEnabled", + "delayOptionEnabled", + "corkOptionEnabled", + "receiveOriginalDstAddress", + "recvErrEnabled", + "errQueue", + "bindToDevice", + "sendBufferSize", + "receiveBufferSize", + "linger", + } +} + +func (so *SocketOptions) beforeSave() {} + +// +checklocksignore +func (so *SocketOptions) StateSave(stateSinkObject state.Sink) { + so.beforeSave() + stateSinkObject.Save(0, &so.handler) + stateSinkObject.Save(1, &so.broadcastEnabled) + stateSinkObject.Save(2, &so.passCredEnabled) + stateSinkObject.Save(3, &so.noChecksumEnabled) + stateSinkObject.Save(4, &so.reuseAddressEnabled) + stateSinkObject.Save(5, &so.reusePortEnabled) + stateSinkObject.Save(6, &so.keepAliveEnabled) + stateSinkObject.Save(7, &so.multicastLoopEnabled) + stateSinkObject.Save(8, &so.receiveTOSEnabled) + stateSinkObject.Save(9, &so.receiveTClassEnabled) + stateSinkObject.Save(10, &so.receivePacketInfoEnabled) + stateSinkObject.Save(11, &so.hdrIncludedEnabled) + stateSinkObject.Save(12, &so.v6OnlyEnabled) + stateSinkObject.Save(13, &so.quickAckEnabled) + stateSinkObject.Save(14, &so.delayOptionEnabled) + stateSinkObject.Save(15, &so.corkOptionEnabled) + stateSinkObject.Save(16, &so.receiveOriginalDstAddress) + stateSinkObject.Save(17, &so.recvErrEnabled) + stateSinkObject.Save(18, &so.errQueue) + stateSinkObject.Save(19, &so.bindToDevice) + stateSinkObject.Save(20, &so.sendBufferSize) + stateSinkObject.Save(21, &so.receiveBufferSize) + stateSinkObject.Save(22, &so.linger) +} + +func (so *SocketOptions) afterLoad() {} + +// +checklocksignore +func (so *SocketOptions) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &so.handler) + stateSourceObject.Load(1, &so.broadcastEnabled) + stateSourceObject.Load(2, &so.passCredEnabled) + stateSourceObject.Load(3, &so.noChecksumEnabled) + stateSourceObject.Load(4, &so.reuseAddressEnabled) + stateSourceObject.Load(5, &so.reusePortEnabled) + stateSourceObject.Load(6, &so.keepAliveEnabled) + stateSourceObject.Load(7, &so.multicastLoopEnabled) + stateSourceObject.Load(8, &so.receiveTOSEnabled) + stateSourceObject.Load(9, &so.receiveTClassEnabled) + stateSourceObject.Load(10, &so.receivePacketInfoEnabled) + stateSourceObject.Load(11, &so.hdrIncludedEnabled) + stateSourceObject.Load(12, &so.v6OnlyEnabled) + stateSourceObject.Load(13, &so.quickAckEnabled) + stateSourceObject.Load(14, &so.delayOptionEnabled) + stateSourceObject.Load(15, &so.corkOptionEnabled) + stateSourceObject.Load(16, &so.receiveOriginalDstAddress) + stateSourceObject.Load(17, &so.recvErrEnabled) + stateSourceObject.Load(18, &so.errQueue) + stateSourceObject.Load(19, &so.bindToDevice) + stateSourceObject.Load(20, &so.sendBufferSize) + stateSourceObject.Load(21, &so.receiveBufferSize) + stateSourceObject.Load(22, &so.linger) +} + +func (l *LocalSockError) StateTypeName() string { + return "pkg/tcpip.LocalSockError" +} + +func (l *LocalSockError) StateFields() []string { + return []string{ + "info", + } +} + +func (l *LocalSockError) beforeSave() {} + +// +checklocksignore +func (l *LocalSockError) StateSave(stateSinkObject state.Sink) { + l.beforeSave() + stateSinkObject.Save(0, &l.info) +} + +func (l *LocalSockError) afterLoad() {} + +// +checklocksignore +func (l *LocalSockError) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &l.info) +} + +func (s *SockError) StateTypeName() string { + return "pkg/tcpip.SockError" +} + +func (s *SockError) StateFields() []string { + return []string{ + "sockErrorEntry", + "Err", + "Cause", + "Payload", + "Dst", + "Offender", + "NetProto", + } +} + +func (s *SockError) beforeSave() {} + +// +checklocksignore +func (s *SockError) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + stateSinkObject.Save(0, &s.sockErrorEntry) + stateSinkObject.Save(1, &s.Err) + stateSinkObject.Save(2, &s.Cause) + stateSinkObject.Save(3, &s.Payload) + stateSinkObject.Save(4, &s.Dst) + stateSinkObject.Save(5, &s.Offender) + stateSinkObject.Save(6, &s.NetProto) +} + +func (s *SockError) afterLoad() {} + +// +checklocksignore +func (s *SockError) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.sockErrorEntry) + stateSourceObject.Load(1, &s.Err) + stateSourceObject.Load(2, &s.Cause) + stateSourceObject.Load(3, &s.Payload) + stateSourceObject.Load(4, &s.Dst) + stateSourceObject.Load(5, &s.Offender) + stateSourceObject.Load(6, &s.NetProto) +} + +func (s *stdClock) StateTypeName() string { + return "pkg/tcpip.stdClock" +} + +func (s *stdClock) StateFields() []string { + return []string{ + "maxMonotonic", + } +} + +func (s *stdClock) beforeSave() {} + +// +checklocksignore +func (s *stdClock) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + stateSinkObject.Save(0, &s.maxMonotonic) +} + +// +checklocksignore +func (s *stdClock) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.maxMonotonic) + stateSourceObject.AfterLoad(s.afterLoad) +} + +func (mt *MonotonicTime) StateTypeName() string { + return "pkg/tcpip.MonotonicTime" +} + +func (mt *MonotonicTime) StateFields() []string { + return []string{ + "nanoseconds", + } +} + +func (mt *MonotonicTime) beforeSave() {} + +// +checklocksignore +func (mt *MonotonicTime) StateSave(stateSinkObject state.Sink) { + mt.beforeSave() + stateSinkObject.Save(0, &mt.nanoseconds) +} + +func (mt *MonotonicTime) afterLoad() {} + +// +checklocksignore +func (mt *MonotonicTime) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &mt.nanoseconds) +} + +func (f *FullAddress) StateTypeName() string { + return "pkg/tcpip.FullAddress" +} + +func (f *FullAddress) StateFields() []string { + return []string{ + "NIC", + "Addr", + "Port", + } +} + +func (f *FullAddress) beforeSave() {} + +// +checklocksignore +func (f *FullAddress) StateSave(stateSinkObject state.Sink) { + f.beforeSave() + stateSinkObject.Save(0, &f.NIC) + stateSinkObject.Save(1, &f.Addr) + stateSinkObject.Save(2, &f.Port) +} + +func (f *FullAddress) afterLoad() {} + +// +checklocksignore +func (f *FullAddress) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &f.NIC) + stateSourceObject.Load(1, &f.Addr) + stateSourceObject.Load(2, &f.Port) +} + +func (c *ControlMessages) StateTypeName() string { + return "pkg/tcpip.ControlMessages" +} + +func (c *ControlMessages) StateFields() []string { + return []string{ + "HasTimestamp", + "Timestamp", + "HasInq", + "Inq", + "HasTOS", + "TOS", + "HasTClass", + "TClass", + "HasIPPacketInfo", + "PacketInfo", + "HasOriginalDstAddress", + "OriginalDstAddress", + "SockErr", + } +} + +func (c *ControlMessages) beforeSave() {} + +// +checklocksignore +func (c *ControlMessages) StateSave(stateSinkObject state.Sink) { + c.beforeSave() + stateSinkObject.Save(0, &c.HasTimestamp) + stateSinkObject.Save(1, &c.Timestamp) + stateSinkObject.Save(2, &c.HasInq) + stateSinkObject.Save(3, &c.Inq) + stateSinkObject.Save(4, &c.HasTOS) + stateSinkObject.Save(5, &c.TOS) + stateSinkObject.Save(6, &c.HasTClass) + stateSinkObject.Save(7, &c.TClass) + stateSinkObject.Save(8, &c.HasIPPacketInfo) + stateSinkObject.Save(9, &c.PacketInfo) + stateSinkObject.Save(10, &c.HasOriginalDstAddress) + stateSinkObject.Save(11, &c.OriginalDstAddress) + stateSinkObject.Save(12, &c.SockErr) +} + +func (c *ControlMessages) afterLoad() {} + +// +checklocksignore +func (c *ControlMessages) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &c.HasTimestamp) + stateSourceObject.Load(1, &c.Timestamp) + stateSourceObject.Load(2, &c.HasInq) + stateSourceObject.Load(3, &c.Inq) + stateSourceObject.Load(4, &c.HasTOS) + stateSourceObject.Load(5, &c.TOS) + stateSourceObject.Load(6, &c.HasTClass) + stateSourceObject.Load(7, &c.TClass) + stateSourceObject.Load(8, &c.HasIPPacketInfo) + stateSourceObject.Load(9, &c.PacketInfo) + stateSourceObject.Load(10, &c.HasOriginalDstAddress) + stateSourceObject.Load(11, &c.OriginalDstAddress) + stateSourceObject.Load(12, &c.SockErr) +} + +func (l *LinkPacketInfo) StateTypeName() string { + return "pkg/tcpip.LinkPacketInfo" +} + +func (l *LinkPacketInfo) StateFields() []string { + return []string{ + "Protocol", + "PktType", + } +} + +func (l *LinkPacketInfo) beforeSave() {} + +// +checklocksignore +func (l *LinkPacketInfo) StateSave(stateSinkObject state.Sink) { + l.beforeSave() + stateSinkObject.Save(0, &l.Protocol) + stateSinkObject.Save(1, &l.PktType) +} + +func (l *LinkPacketInfo) afterLoad() {} + +// +checklocksignore +func (l *LinkPacketInfo) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &l.Protocol) + stateSourceObject.Load(1, &l.PktType) +} + +func (l *LingerOption) StateTypeName() string { + return "pkg/tcpip.LingerOption" +} + +func (l *LingerOption) StateFields() []string { + return []string{ + "Enabled", + "Timeout", + } +} + +func (l *LingerOption) beforeSave() {} + +// +checklocksignore +func (l *LingerOption) StateSave(stateSinkObject state.Sink) { + l.beforeSave() + stateSinkObject.Save(0, &l.Enabled) + stateSinkObject.Save(1, &l.Timeout) +} + +func (l *LingerOption) afterLoad() {} + +// +checklocksignore +func (l *LingerOption) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &l.Enabled) + stateSourceObject.Load(1, &l.Timeout) +} + +func (i *IPPacketInfo) StateTypeName() string { + return "pkg/tcpip.IPPacketInfo" +} + +func (i *IPPacketInfo) StateFields() []string { + return []string{ + "NIC", + "LocalAddr", + "DestinationAddr", + } +} + +func (i *IPPacketInfo) beforeSave() {} + +// +checklocksignore +func (i *IPPacketInfo) StateSave(stateSinkObject state.Sink) { + i.beforeSave() + stateSinkObject.Save(0, &i.NIC) + stateSinkObject.Save(1, &i.LocalAddr) + stateSinkObject.Save(2, &i.DestinationAddr) +} + +func (i *IPPacketInfo) afterLoad() {} + +// +checklocksignore +func (i *IPPacketInfo) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &i.NIC) + stateSourceObject.Load(1, &i.LocalAddr) + stateSourceObject.Load(2, &i.DestinationAddr) +} + +func init() { + state.Register((*ErrAborted)(nil)) + state.Register((*ErrAddressFamilyNotSupported)(nil)) + state.Register((*ErrAlreadyBound)(nil)) + state.Register((*ErrAlreadyConnected)(nil)) + state.Register((*ErrAlreadyConnecting)(nil)) + state.Register((*ErrBadAddress)(nil)) + state.Register((*ErrBadBuffer)(nil)) + state.Register((*ErrBadLocalAddress)(nil)) + state.Register((*ErrBroadcastDisabled)(nil)) + state.Register((*ErrClosedForReceive)(nil)) + state.Register((*ErrClosedForSend)(nil)) + state.Register((*ErrConnectStarted)(nil)) + state.Register((*ErrConnectionAborted)(nil)) + state.Register((*ErrConnectionRefused)(nil)) + state.Register((*ErrConnectionReset)(nil)) + state.Register((*ErrDestinationRequired)(nil)) + state.Register((*ErrDuplicateAddress)(nil)) + state.Register((*ErrDuplicateNICID)(nil)) + state.Register((*ErrInvalidEndpointState)(nil)) + state.Register((*ErrInvalidOptionValue)(nil)) + state.Register((*ErrInvalidPortRange)(nil)) + state.Register((*ErrMalformedHeader)(nil)) + state.Register((*ErrMessageTooLong)(nil)) + state.Register((*ErrNetworkUnreachable)(nil)) + state.Register((*ErrNoBufferSpace)(nil)) + state.Register((*ErrNoPortAvailable)(nil)) + state.Register((*ErrNoRoute)(nil)) + state.Register((*ErrNoSuchFile)(nil)) + state.Register((*ErrNotConnected)(nil)) + state.Register((*ErrNotPermitted)(nil)) + state.Register((*ErrNotSupported)(nil)) + state.Register((*ErrPortInUse)(nil)) + state.Register((*ErrQueueSizeNotSupported)(nil)) + state.Register((*ErrTimeout)(nil)) + state.Register((*ErrUnknownDevice)(nil)) + state.Register((*ErrUnknownNICID)(nil)) + state.Register((*ErrUnknownProtocol)(nil)) + state.Register((*ErrUnknownProtocolOption)(nil)) + state.Register((*ErrWouldBlock)(nil)) + state.Register((*sockErrorList)(nil)) + state.Register((*sockErrorEntry)(nil)) + state.Register((*SocketOptions)(nil)) + state.Register((*LocalSockError)(nil)) + state.Register((*SockError)(nil)) + state.Register((*stdClock)(nil)) + state.Register((*MonotonicTime)(nil)) + state.Register((*FullAddress)(nil)) + state.Register((*ControlMessages)(nil)) + state.Register((*LinkPacketInfo)(nil)) + state.Register((*LingerOption)(nil)) + state.Register((*IPPacketInfo)(nil)) +} diff --git a/pkg/tcpip/tcpip_test.go b/pkg/tcpip/tcpip_test.go deleted file mode 100644 index c96ae2f02..000000000 --- a/pkg/tcpip/tcpip_test.go +++ /dev/null @@ -1,325 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tcpip - -import ( - "bytes" - "fmt" - "io" - "net" - "testing" - - "github.com/google/go-cmp/cmp" -) - -func TestLimitedWriter_Write(t *testing.T) { - var b bytes.Buffer - l := LimitedWriter{ - W: &b, - N: 5, - } - if n, err := l.Write([]byte{0, 1, 2}); err != nil { - t.Errorf("got l.Write(3/5) = (_, %s), want nil", err) - } else if n != 3 { - t.Errorf("got l.Write(3/5) = (%d, _), want 3", n) - } - if n, err := l.Write([]byte{3, 4, 5}); err != io.ErrShortWrite { - t.Errorf("got l.Write(3/2) = (_, %s), want io.ErrShortWrite", err) - } else if n != 2 { - t.Errorf("got l.Write(3/2) = (%d, _), want 2", n) - } - if l.N != 0 { - t.Errorf("got l.N = %d, want 0", l.N) - } - l.N = 1 - if n, err := l.Write([]byte{5}); err != nil { - t.Errorf("got l.Write(1/1) = (_, %s), want nil", err) - } else if n != 1 { - t.Errorf("got l.Write(1/1) = (%d, _), want 1", n) - } - if diff := cmp.Diff(b.Bytes(), []byte{0, 1, 2, 3, 4, 5}); diff != "" { - t.Errorf("%T wrote incorrect data: (-want +got):\n%s", l, diff) - } -} - -func TestSubnetContains(t *testing.T) { - tests := []struct { - s Address - m AddressMask - a Address - want bool - }{ - {"\xa0", "\xf0", "\x90", false}, - {"\xa0", "\xf0", "\xa0", true}, - {"\xa0", "\xf0", "\xa5", true}, - {"\xa0", "\xf0", "\xaf", true}, - {"\xa0", "\xf0", "\xb0", false}, - {"\xa0", "\xf0", "", false}, - {"\xa0", "\xf0", "\xa0\x00", false}, - {"\xc2\x80", "\xff\xf0", "\xc2\x80", true}, - {"\xc2\x80", "\xff\xf0", "\xc2\x00", false}, - {"\xc2\x00", "\xff\xf0", "\xc2\x00", true}, - {"\xc2\x00", "\xff\xf0", "\xc2\x80", false}, - } - for _, tt := range tests { - s, err := NewSubnet(tt.s, tt.m) - if err != nil { - t.Errorf("NewSubnet(%v, %v) = %v", tt.s, tt.m, err) - continue - } - if got := s.Contains(tt.a); got != tt.want { - t.Errorf("Subnet(%v).Contains(%v) = %v, want %v", s, tt.a, got, tt.want) - } - } -} - -func TestSubnetBits(t *testing.T) { - tests := []struct { - a AddressMask - want1 int - want0 int - }{ - {"\x00", 0, 8}, - {"\x00\x00", 0, 16}, - {"\x36", 0, 8}, - {"\x5c", 0, 8}, - {"\x5c\x5c", 0, 16}, - {"\x5c\x36", 0, 16}, - {"\x36\x5c", 0, 16}, - {"\x36\x36", 0, 16}, - {"\xff", 8, 0}, - {"\xff\xff", 16, 0}, - } - for _, tt := range tests { - s := &Subnet{mask: tt.a} - got1, got0 := s.Bits() - if got1 != tt.want1 || got0 != tt.want0 { - t.Errorf("Subnet{mask: %x}.Bits() = %d, %d, want %d, %d", tt.a, got1, got0, tt.want1, tt.want0) - } - } -} - -func TestSubnetPrefix(t *testing.T) { - tests := []struct { - a AddressMask - want int - }{ - {"\x00", 0}, - {"\x00\x00", 0}, - {"\x36", 0}, - {"\x86", 1}, - {"\xc5", 2}, - {"\xff\x00", 8}, - {"\xff\x36", 8}, - {"\xff\x8c", 9}, - {"\xff\xc8", 10}, - {"\xff", 8}, - {"\xff\xff", 16}, - } - for _, tt := range tests { - s := &Subnet{mask: tt.a} - got := s.Prefix() - if got != tt.want { - t.Errorf("Subnet{mask: %x}.Bits() = %d want %d", tt.a, got, tt.want) - } - } -} - -func TestSubnetCreation(t *testing.T) { - tests := []struct { - a Address - m AddressMask - want error - }{ - {"\xa0", "\xf0", nil}, - {"\xa0\xa0", "\xf0", errSubnetLengthMismatch}, - {"\xaa", "\xf0", errSubnetAddressMasked}, - {"", "", nil}, - } - for _, tt := range tests { - if _, err := NewSubnet(tt.a, tt.m); err != tt.want { - t.Errorf("NewSubnet(%v, %v) = %v, want %v", tt.a, tt.m, err, tt.want) - } - } -} - -func TestAddressString(t *testing.T) { - for _, want := range []string{ - // Taken from stdlib. - "2001:db8::123:12:1", - "2001:db8::1", - "2001:db8:0:1:0:1:0:1", - "2001:db8:1:0:1:0:1:0", - "2001::1:0:0:1", - "2001:db8:0:0:1::", - "2001:db8::1:0:0:1", - "2001:db8::a:b:c:d", - - // Leading zeros. - "::1", - // Trailing zeros. - "8::", - // No zeros. - "1:1:1:1:1:1:1:1", - // Longer sequence is after other zeros, but not at the end. - "1:0:0:1::1", - // Longer sequence is at the beginning, shorter sequence is at - // the end. - "::1:1:1:0:0", - // Longer sequence is not at the beginning, shorter sequence is - // at the end. - "1::1:1:0:0", - // Longer sequence is at the beginning, shorter sequence is not - // at the end. - "::1:1:0:0:1", - // Neither sequence is at an end, longer is after shorter. - "1:0:0:1::1", - // Shorter sequence is at the beginning, longer sequence is not - // at the end. - "0:0:1:1::1", - // Shorter sequence is at the beginning, longer sequence is at - // the end. - "0:0:1:1:1::", - // Short sequences at both ends, longer one in the middle. - "0:1:1::1:1:0", - // Short sequences at both ends, longer one in the middle. - "0:1::1:0:0", - // Short sequences at both ends, longer one in the middle. - "0:0:1::1:0", - // Longer sequence surrounded by shorter sequences, but none at - // the end. - "1:0:1::1:0:1", - } { - addr := Address(net.ParseIP(want)) - if got := addr.String(); got != want { - t.Errorf("Address(%x).String() = '%s', want = '%s'", addr, got, want) - } - } -} - -func TestAddressWithPrefixSubnet(t *testing.T) { - tests := []struct { - addr Address - prefixLen int - subnetAddr Address - subnetMask AddressMask - }{ - {"\xaa\x55\x33\x42", -1, "\x00\x00\x00\x00", "\x00\x00\x00\x00"}, - {"\xaa\x55\x33\x42", 0, "\x00\x00\x00\x00", "\x00\x00\x00\x00"}, - {"\xaa\x55\x33\x42", 1, "\x80\x00\x00\x00", "\x80\x00\x00\x00"}, - {"\xaa\x55\x33\x42", 7, "\xaa\x00\x00\x00", "\xfe\x00\x00\x00"}, - {"\xaa\x55\x33\x42", 8, "\xaa\x00\x00\x00", "\xff\x00\x00\x00"}, - {"\xaa\x55\x33\x42", 24, "\xaa\x55\x33\x00", "\xff\xff\xff\x00"}, - {"\xaa\x55\x33\x42", 31, "\xaa\x55\x33\x42", "\xff\xff\xff\xfe"}, - {"\xaa\x55\x33\x42", 32, "\xaa\x55\x33\x42", "\xff\xff\xff\xff"}, - {"\xaa\x55\x33\x42", 33, "\xaa\x55\x33\x42", "\xff\xff\xff\xff"}, - } - for _, tt := range tests { - ap := AddressWithPrefix{Address: tt.addr, PrefixLen: tt.prefixLen} - gotSubnet := ap.Subnet() - wantSubnet, err := NewSubnet(tt.subnetAddr, tt.subnetMask) - if err != nil { - t.Errorf("NewSubnet(%q, %q) failed: %s", tt.subnetAddr, tt.subnetMask, err) - continue - } - if gotSubnet != wantSubnet { - t.Errorf("got subnet = %q, want = %q", gotSubnet, wantSubnet) - } - } -} - -func TestAddressUnspecified(t *testing.T) { - tests := []struct { - addr Address - unspecified bool - }{ - { - addr: "", - unspecified: true, - }, - { - addr: "\x00", - unspecified: true, - }, - { - addr: "\x01", - unspecified: false, - }, - { - addr: "\x00\x00", - unspecified: true, - }, - { - addr: "\x01\x00", - unspecified: false, - }, - { - addr: "\x00\x01", - unspecified: false, - }, - { - addr: "\x01\x01", - unspecified: false, - }, - } - - for _, test := range tests { - t.Run(fmt.Sprintf("addr=%s", test.addr), func(t *testing.T) { - if got := test.addr.Unspecified(); got != test.unspecified { - t.Fatalf("got addr.Unspecified() = %t, want = %t", got, test.unspecified) - } - }) - } -} - -func TestAddressMatchingPrefix(t *testing.T) { - tests := []struct { - addrA Address - addrB Address - prefix uint8 - }{ - { - addrA: "\x01\x01", - addrB: "\x01\x01", - prefix: 16, - }, - { - addrA: "\x01\x01", - addrB: "\x01\x00", - prefix: 15, - }, - { - addrA: "\x01\x01", - addrB: "\x81\x00", - prefix: 0, - }, - { - addrA: "\x01\x01", - addrB: "\x01\x80", - prefix: 8, - }, - { - addrA: "\x01\x01", - addrB: "\x02\x80", - prefix: 6, - }, - } - - for _, test := range tests { - if got := test.addrA.MatchingPrefix(test.addrB); got != test.prefix { - t.Errorf("got (%s).MatchingPrefix(%s) = %d, want = %d", test.addrA, test.addrB, got, test.prefix) - } - } -} diff --git a/pkg/tcpip/tests/integration/BUILD b/pkg/tcpip/tests/integration/BUILD deleted file mode 100644 index 181ef799e..000000000 --- a/pkg/tcpip/tests/integration/BUILD +++ /dev/null @@ -1,141 +0,0 @@ -load("//tools:defs.bzl", "go_test") - -package(licenses = ["notice"]) - -go_test( - name = "forward_test", - size = "small", - srcs = ["forward_test.go"], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/checker", - "//pkg/tcpip/header", - "//pkg/tcpip/link/channel", - "//pkg/tcpip/network/arp", - "//pkg/tcpip/network/ipv4", - "//pkg/tcpip/network/ipv6", - "//pkg/tcpip/stack", - "//pkg/tcpip/tests/utils", - "//pkg/tcpip/testutil", - "//pkg/tcpip/transport/tcp", - "//pkg/tcpip/transport/udp", - "//pkg/waiter", - "@com_github_google_go_cmp//cmp:go_default_library", - ], -) - -go_test( - name = "iptables_test", - size = "small", - srcs = ["iptables_test.go"], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/checker", - "//pkg/tcpip/header", - "//pkg/tcpip/link/channel", - "//pkg/tcpip/network/ipv4", - "//pkg/tcpip/network/ipv6", - "//pkg/tcpip/stack", - "//pkg/tcpip/tests/utils", - "//pkg/tcpip/testutil", - "//pkg/tcpip/transport/udp", - ], -) - -go_test( - name = "link_resolution_test", - size = "small", - srcs = ["link_resolution_test.go"], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/checker", - "//pkg/tcpip/faketime", - "//pkg/tcpip/header", - "//pkg/tcpip/link/channel", - "//pkg/tcpip/link/pipe", - "//pkg/tcpip/network/arp", - "//pkg/tcpip/network/ipv4", - "//pkg/tcpip/network/ipv6", - "//pkg/tcpip/stack", - "//pkg/tcpip/tests/utils", - "//pkg/tcpip/testutil", - "//pkg/tcpip/transport/icmp", - "//pkg/tcpip/transport/tcp", - "//pkg/tcpip/transport/udp", - "//pkg/waiter", - "@com_github_google_go_cmp//cmp:go_default_library", - "@com_github_google_go_cmp//cmp/cmpopts:go_default_library", - ], -) - -go_test( - name = "loopback_test", - size = "small", - srcs = ["loopback_test.go"], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/checker", - "//pkg/tcpip/header", - "//pkg/tcpip/link/channel", - "//pkg/tcpip/link/loopback", - "//pkg/tcpip/network/ipv4", - "//pkg/tcpip/network/ipv6", - "//pkg/tcpip/stack", - "//pkg/tcpip/tests/utils", - "//pkg/tcpip/testutil", - "//pkg/tcpip/transport/icmp", - "//pkg/tcpip/transport/tcp", - "//pkg/tcpip/transport/udp", - "//pkg/waiter", - "@com_github_google_go_cmp//cmp:go_default_library", - ], -) - -go_test( - name = "multicast_broadcast_test", - size = "small", - srcs = ["multicast_broadcast_test.go"], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/checker", - "//pkg/tcpip/header", - "//pkg/tcpip/link/channel", - "//pkg/tcpip/link/loopback", - "//pkg/tcpip/network/ipv4", - "//pkg/tcpip/network/ipv6", - "//pkg/tcpip/stack", - "//pkg/tcpip/tests/utils", - "//pkg/tcpip/testutil", - "//pkg/tcpip/transport/icmp", - "//pkg/tcpip/transport/udp", - "//pkg/waiter", - "@com_github_google_go_cmp//cmp:go_default_library", - ], -) - -go_test( - name = "route_test", - size = "small", - srcs = ["route_test.go"], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/checker", - "//pkg/tcpip/header", - "//pkg/tcpip/link/channel", - "//pkg/tcpip/link/loopback", - "//pkg/tcpip/network/ipv4", - "//pkg/tcpip/network/ipv6", - "//pkg/tcpip/stack", - "//pkg/tcpip/tests/utils", - "//pkg/tcpip/testutil", - "//pkg/tcpip/transport/icmp", - "//pkg/tcpip/transport/udp", - "//pkg/waiter", - "@com_github_google_go_cmp//cmp:go_default_library", - ], -) diff --git a/pkg/tcpip/tests/integration/forward_test.go b/pkg/tcpip/tests/integration/forward_test.go deleted file mode 100644 index 92fa6257d..000000000 --- a/pkg/tcpip/tests/integration/forward_test.go +++ /dev/null @@ -1,690 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package forward_test - -import ( - "bytes" - "fmt" - "testing" - - "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/checker" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/network/arp" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/tests/utils" - "gvisor.dev/gvisor/pkg/tcpip/testutil" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" - "gvisor.dev/gvisor/pkg/tcpip/transport/udp" - "gvisor.dev/gvisor/pkg/waiter" -) - -const ttl = 64 - -var ( - ipv4GlobalMulticastAddr = testutil.MustParse4("224.0.1.10") - ipv6GlobalMulticastAddr = testutil.MustParse6("ff0e::a") -) - -func rxICMPv4EchoRequest(e *channel.Endpoint, src, dst tcpip.Address) { - utils.RxICMPv4EchoRequest(e, src, dst, ttl) -} - -func rxICMPv6EchoRequest(e *channel.Endpoint, src, dst tcpip.Address) { - utils.RxICMPv6EchoRequest(e, src, dst, ttl) -} - -func forwardedICMPv4EchoRequestChecker(t *testing.T, b []byte, src, dst tcpip.Address) { - checker.IPv4(t, b, - checker.SrcAddr(src), - checker.DstAddr(dst), - checker.TTL(ttl-1), - checker.ICMPv4( - checker.ICMPv4Type(header.ICMPv4Echo))) -} - -func forwardedICMPv6EchoRequestChecker(t *testing.T, b []byte, src, dst tcpip.Address) { - checker.IPv6(t, b, - checker.SrcAddr(src), - checker.DstAddr(dst), - checker.TTL(ttl-1), - checker.ICMPv6( - checker.ICMPv6Type(header.ICMPv6EchoRequest))) -} - -func TestForwarding(t *testing.T) { - const listenPort = 8080 - - type endpointAndAddresses struct { - serverEP tcpip.Endpoint - serverAddr tcpip.Address - serverReadableCH chan struct{} - - clientEP tcpip.Endpoint - clientAddr tcpip.Address - clientReadableCH chan struct{} - } - - newEP := func(t *testing.T, s *stack.Stack, transProto tcpip.TransportProtocolNumber, netProto tcpip.NetworkProtocolNumber) (tcpip.Endpoint, chan struct{}) { - t.Helper() - var wq waiter.Queue - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.ReadableEvents) - ep, err := s.NewEndpoint(transProto, netProto, &wq) - if err != nil { - t.Fatalf("s.NewEndpoint(%d, %d, _): %s", transProto, netProto, err) - } - - t.Cleanup(func() { - wq.EventUnregister(&we) - }) - - return ep, ch - } - - tests := []struct { - name string - epAndAddrs func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses - }{ - { - name: "IPv4 host1 server with host2 client", - epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses { - ep1, ep1WECH := newEP(t, host1Stack, proto, ipv4.ProtocolNumber) - ep2, ep2WECH := newEP(t, host2Stack, proto, ipv4.ProtocolNumber) - return endpointAndAddresses{ - serverEP: ep1, - serverAddr: utils.Host1IPv4Addr.AddressWithPrefix.Address, - serverReadableCH: ep1WECH, - - clientEP: ep2, - clientAddr: utils.Host2IPv4Addr.AddressWithPrefix.Address, - clientReadableCH: ep2WECH, - } - }, - }, - { - name: "IPv6 host2 server with host1 client", - epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses { - ep1, ep1WECH := newEP(t, host2Stack, proto, ipv6.ProtocolNumber) - ep2, ep2WECH := newEP(t, host1Stack, proto, ipv6.ProtocolNumber) - return endpointAndAddresses{ - serverEP: ep1, - serverAddr: utils.Host2IPv6Addr.AddressWithPrefix.Address, - serverReadableCH: ep1WECH, - - clientEP: ep2, - clientAddr: utils.Host1IPv6Addr.AddressWithPrefix.Address, - clientReadableCH: ep2WECH, - } - }, - }, - { - name: "IPv4 host2 server with routerNIC1 client", - epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses { - ep1, ep1WECH := newEP(t, host2Stack, proto, ipv4.ProtocolNumber) - ep2, ep2WECH := newEP(t, routerStack, proto, ipv4.ProtocolNumber) - return endpointAndAddresses{ - serverEP: ep1, - serverAddr: utils.Host2IPv4Addr.AddressWithPrefix.Address, - serverReadableCH: ep1WECH, - - clientEP: ep2, - clientAddr: utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address, - clientReadableCH: ep2WECH, - } - }, - }, - { - name: "IPv6 routerNIC2 server with host1 client", - epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses { - ep1, ep1WECH := newEP(t, routerStack, proto, ipv6.ProtocolNumber) - ep2, ep2WECH := newEP(t, host1Stack, proto, ipv6.ProtocolNumber) - return endpointAndAddresses{ - serverEP: ep1, - serverAddr: utils.RouterNIC2IPv6Addr.AddressWithPrefix.Address, - serverReadableCH: ep1WECH, - - clientEP: ep2, - clientAddr: utils.Host1IPv6Addr.AddressWithPrefix.Address, - clientReadableCH: ep2WECH, - } - }, - }, - } - - subTests := []struct { - name string - proto tcpip.TransportProtocolNumber - expectedConnectErr tcpip.Error - setupServer func(t *testing.T, ep tcpip.Endpoint) - setupServerConn func(t *testing.T, ep tcpip.Endpoint, ch <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{}) - needRemoteAddr bool - }{ - { - name: "UDP", - proto: udp.ProtocolNumber, - expectedConnectErr: nil, - setupServerConn: func(t *testing.T, ep tcpip.Endpoint, _ <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{}) { - t.Helper() - - if err := ep.Connect(clientAddr); err != nil { - t.Fatalf("ep.Connect(%#v): %s", clientAddr, err) - } - return nil, nil - }, - needRemoteAddr: true, - }, - { - name: "TCP", - proto: tcp.ProtocolNumber, - expectedConnectErr: &tcpip.ErrConnectStarted{}, - setupServer: func(t *testing.T, ep tcpip.Endpoint) { - t.Helper() - - if err := ep.Listen(1); err != nil { - t.Fatalf("ep.Listen(1): %s", err) - } - }, - setupServerConn: func(t *testing.T, ep tcpip.Endpoint, ch <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{}) { - t.Helper() - - var addr tcpip.FullAddress - for { - newEP, wq, err := ep.Accept(&addr) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { - <-ch - continue - } - if err != nil { - t.Fatalf("ep.Accept(_): %s", err) - } - if diff := cmp.Diff(clientAddr, addr, checker.IgnoreCmpPath( - "NIC", - )); diff != "" { - t.Errorf("accepted address mismatch (-want +got):\n%s", diff) - } - - we, newCH := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.ReadableEvents) - return newEP, newCH - } - }, - needRemoteAddr: false, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - for _, subTest := range subTests { - t.Run(subTest.name, func(t *testing.T) { - stackOpts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol}, - } - - host1Stack := stack.New(stackOpts) - routerStack := stack.New(stackOpts) - host2Stack := stack.New(stackOpts) - utils.SetupRoutedStacks(t, host1Stack, routerStack, host2Stack) - - epsAndAddrs := test.epAndAddrs(t, host1Stack, routerStack, host2Stack, subTest.proto) - defer epsAndAddrs.serverEP.Close() - defer epsAndAddrs.clientEP.Close() - - serverAddr := tcpip.FullAddress{Addr: epsAndAddrs.serverAddr, Port: listenPort} - if err := epsAndAddrs.serverEP.Bind(serverAddr); err != nil { - t.Fatalf("epsAndAddrs.serverEP.Bind(%#v): %s", serverAddr, err) - } - clientAddr := tcpip.FullAddress{Addr: epsAndAddrs.clientAddr} - if err := epsAndAddrs.clientEP.Bind(clientAddr); err != nil { - t.Fatalf("epsAndAddrs.clientEP.Bind(%#v): %s", clientAddr, err) - } - - if subTest.setupServer != nil { - subTest.setupServer(t, epsAndAddrs.serverEP) - } - { - err := epsAndAddrs.clientEP.Connect(serverAddr) - if diff := cmp.Diff(subTest.expectedConnectErr, err); diff != "" { - t.Fatalf("unexpected error from epsAndAddrs.clientEP.Connect(%#v), (-want, +got):\n%s", serverAddr, diff) - } - } - if addr, err := epsAndAddrs.clientEP.GetLocalAddress(); err != nil { - t.Fatalf("epsAndAddrs.clientEP.GetLocalAddress(): %s", err) - } else { - clientAddr = addr - clientAddr.NIC = 0 - } - - serverEP := epsAndAddrs.serverEP - serverCH := epsAndAddrs.serverReadableCH - if ep, ch := subTest.setupServerConn(t, serverEP, serverCH, clientAddr); ep != nil { - defer ep.Close() - serverEP = ep - serverCH = ch - } - - write := func(ep tcpip.Endpoint, data []byte) { - t.Helper() - - var r bytes.Reader - r.Reset(data) - var wOpts tcpip.WriteOptions - n, err := ep.Write(&r, wOpts) - if err != nil { - t.Fatalf("ep.Write(_, %#v): %s", wOpts, err) - } - if want := int64(len(data)); n != want { - t.Fatalf("got ep.Write(_, %#v) = (%d, _), want = (%d, _)", wOpts, n, want) - } - } - - data := []byte{1, 2, 3, 4} - write(epsAndAddrs.clientEP, data) - - read := func(ch chan struct{}, ep tcpip.Endpoint, data []byte, expectedFrom tcpip.FullAddress) { - t.Helper() - - var buf bytes.Buffer - var res tcpip.ReadResult - for { - var err tcpip.Error - opts := tcpip.ReadOptions{NeedRemoteAddr: subTest.needRemoteAddr} - res, err = ep.Read(&buf, opts) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { - <-ch - continue - } - if err != nil { - t.Fatalf("ep.Read(_, %d, %#v): %s", len(data), opts, err) - } - break - } - - readResult := tcpip.ReadResult{ - Count: len(data), - Total: len(data), - } - if subTest.needRemoteAddr { - readResult.RemoteAddr = expectedFrom - } - if diff := cmp.Diff(readResult, res, checker.IgnoreCmpPath( - "ControlMessages", - "RemoteAddr.NIC", - )); diff != "" { - t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff) - } - if diff := cmp.Diff(buf.Bytes(), data); diff != "" { - t.Errorf("received data mismatch (-want +got):\n%s", diff) - } - - if t.Failed() { - t.FailNow() - } - } - - read(serverCH, serverEP, data, clientAddr) - - data = []byte{5, 6, 7, 8, 9, 10, 11, 12} - write(serverEP, data) - read(epsAndAddrs.clientReadableCH, epsAndAddrs.clientEP, data, serverAddr) - }) - } - }) - } -} - -func TestMulticastForwarding(t *testing.T) { - const ( - nicID1 = 1 - nicID2 = 2 - ) - - var ( - ipv4LinkLocalUnicastAddr = testutil.MustParse4("169.254.0.10") - ipv4LinkLocalMulticastAddr = testutil.MustParse4("224.0.0.10") - - ipv6LinkLocalUnicastAddr = testutil.MustParse6("fe80::a") - ipv6LinkLocalMulticastAddr = testutil.MustParse6("ff02::a") - ) - - tests := []struct { - name string - srcAddr, dstAddr tcpip.Address - rx func(*channel.Endpoint, tcpip.Address, tcpip.Address) - expectForward bool - checker func(*testing.T, []byte) - }{ - { - name: "IPv4 link-local multicast destination", - srcAddr: utils.RemoteIPv4Addr, - dstAddr: ipv4LinkLocalMulticastAddr, - rx: rxICMPv4EchoRequest, - expectForward: false, - }, - { - name: "IPv4 link-local source", - srcAddr: ipv4LinkLocalUnicastAddr, - dstAddr: utils.RemoteIPv4Addr, - rx: rxICMPv4EchoRequest, - expectForward: false, - }, - { - name: "IPv4 link-local destination", - srcAddr: utils.RemoteIPv4Addr, - dstAddr: ipv4LinkLocalUnicastAddr, - rx: rxICMPv4EchoRequest, - expectForward: false, - }, - { - name: "IPv4 non-link-local unicast", - srcAddr: utils.RemoteIPv4Addr, - dstAddr: utils.Ipv4Addr2.AddressWithPrefix.Address, - rx: rxICMPv4EchoRequest, - expectForward: true, - checker: func(t *testing.T, b []byte) { - forwardedICMPv4EchoRequestChecker(t, b, utils.RemoteIPv4Addr, utils.Ipv4Addr2.AddressWithPrefix.Address) - }, - }, - { - name: "IPv4 non-link-local multicast", - srcAddr: utils.RemoteIPv4Addr, - dstAddr: ipv4GlobalMulticastAddr, - rx: rxICMPv4EchoRequest, - expectForward: true, - checker: func(t *testing.T, b []byte) { - forwardedICMPv4EchoRequestChecker(t, b, utils.RemoteIPv4Addr, ipv4GlobalMulticastAddr) - }, - }, - - { - name: "IPv6 link-local multicast destination", - srcAddr: utils.RemoteIPv6Addr, - dstAddr: ipv6LinkLocalMulticastAddr, - rx: rxICMPv6EchoRequest, - expectForward: false, - }, - { - name: "IPv6 link-local source", - srcAddr: ipv6LinkLocalUnicastAddr, - dstAddr: utils.RemoteIPv6Addr, - rx: rxICMPv6EchoRequest, - expectForward: false, - }, - { - name: "IPv6 link-local destination", - srcAddr: utils.RemoteIPv6Addr, - dstAddr: ipv6LinkLocalUnicastAddr, - rx: rxICMPv6EchoRequest, - expectForward: false, - }, - { - name: "IPv6 non-link-local unicast", - srcAddr: utils.RemoteIPv6Addr, - dstAddr: utils.Ipv6Addr2.AddressWithPrefix.Address, - rx: rxICMPv6EchoRequest, - expectForward: true, - checker: func(t *testing.T, b []byte) { - forwardedICMPv6EchoRequestChecker(t, b, utils.RemoteIPv6Addr, utils.Ipv6Addr2.AddressWithPrefix.Address) - }, - }, - { - name: "IPv6 non-link-local multicast", - srcAddr: utils.RemoteIPv6Addr, - dstAddr: ipv6GlobalMulticastAddr, - rx: rxICMPv6EchoRequest, - expectForward: true, - checker: func(t *testing.T, b []byte) { - forwardedICMPv6EchoRequestChecker(t, b, utils.RemoteIPv6Addr, ipv6GlobalMulticastAddr) - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, - }) - - e1 := channel.New(1, header.IPv6MinimumMTU, "") - if err := s.CreateNIC(nicID1, e1); err != nil { - t.Fatalf("s.CreateNIC(%d, _): %s", nicID1, err) - } - - e2 := channel.New(1, header.IPv6MinimumMTU, "") - if err := s.CreateNIC(nicID2, e2); err != nil { - t.Fatalf("s.CreateNIC(%d, _): %s", nicID2, err) - } - - if err := s.AddAddress(nicID2, ipv4.ProtocolNumber, utils.Ipv4Addr.Address); err != nil { - t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID2, ipv4.ProtocolNumber, utils.Ipv4Addr.Address, err) - } - if err := s.AddAddress(nicID2, ipv6.ProtocolNumber, utils.Ipv6Addr.Address); err != nil { - t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID2, ipv6.ProtocolNumber, utils.Ipv6Addr.Address, err) - } - - if err := s.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil { - t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", ipv4.ProtocolNumber, err) - } - if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, true); err != nil { - t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", ipv6.ProtocolNumber, err) - } - - s.SetRouteTable([]tcpip.Route{ - { - Destination: header.IPv4EmptySubnet, - NIC: nicID2, - }, - { - Destination: header.IPv6EmptySubnet, - NIC: nicID2, - }, - }) - - test.rx(e1, test.srcAddr, test.dstAddr) - - p, ok := e2.Read() - if ok != test.expectForward { - t.Fatalf("got e2.Read() = (%#v, %t), want = (_, %t)", p, ok, test.expectForward) - } - - if test.expectForward { - test.checker(t, stack.PayloadSince(p.Pkt.NetworkHeader())) - } - }) - } -} - -func TestPerInterfaceForwarding(t *testing.T) { - const ( - nicID1 = 1 - nicID2 = 2 - ) - - tests := []struct { - name string - srcAddr, dstAddr tcpip.Address - rx func(*channel.Endpoint, tcpip.Address, tcpip.Address) - checker func(*testing.T, []byte) - }{ - { - name: "IPv4 unicast", - srcAddr: utils.RemoteIPv4Addr, - dstAddr: utils.Ipv4Addr2.AddressWithPrefix.Address, - rx: rxICMPv4EchoRequest, - checker: func(t *testing.T, b []byte) { - forwardedICMPv4EchoRequestChecker(t, b, utils.RemoteIPv4Addr, utils.Ipv4Addr2.AddressWithPrefix.Address) - }, - }, - { - name: "IPv4 multicast", - srcAddr: utils.RemoteIPv4Addr, - dstAddr: ipv4GlobalMulticastAddr, - rx: rxICMPv4EchoRequest, - checker: func(t *testing.T, b []byte) { - forwardedICMPv4EchoRequestChecker(t, b, utils.RemoteIPv4Addr, ipv4GlobalMulticastAddr) - }, - }, - - { - name: "IPv6 unicast", - srcAddr: utils.RemoteIPv6Addr, - dstAddr: utils.Ipv6Addr2.AddressWithPrefix.Address, - rx: rxICMPv6EchoRequest, - checker: func(t *testing.T, b []byte) { - forwardedICMPv6EchoRequestChecker(t, b, utils.RemoteIPv6Addr, utils.Ipv6Addr2.AddressWithPrefix.Address) - }, - }, - { - name: "IPv6 multicast", - srcAddr: utils.RemoteIPv6Addr, - dstAddr: ipv6GlobalMulticastAddr, - rx: rxICMPv6EchoRequest, - checker: func(t *testing.T, b []byte) { - forwardedICMPv6EchoRequestChecker(t, b, utils.RemoteIPv6Addr, ipv6GlobalMulticastAddr) - }, - }, - } - - netProtos := [...]tcpip.NetworkProtocolNumber{ipv4.ProtocolNumber, ipv6.ProtocolNumber} - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ - // ARP is not used in this test but it is a network protocol that does - // not support forwarding. We install the protocol to make sure that - // forwarding information for a NIC is only reported for network - // protocols that support forwarding. - arp.NewProtocol, - - ipv4.NewProtocol, - ipv6.NewProtocol, - }, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, - }) - - e1 := channel.New(1, header.IPv6MinimumMTU, "") - if err := s.CreateNIC(nicID1, e1); err != nil { - t.Fatalf("s.CreateNIC(%d, _): %s", nicID1, err) - } - - e2 := channel.New(1, header.IPv6MinimumMTU, "") - if err := s.CreateNIC(nicID2, e2); err != nil { - t.Fatalf("s.CreateNIC(%d, _): %s", nicID2, err) - } - - for _, add := range [...]struct { - nicID tcpip.NICID - addr tcpip.ProtocolAddress - }{ - { - nicID: nicID1, - addr: utils.RouterNIC1IPv4Addr, - }, - { - nicID: nicID1, - addr: utils.RouterNIC1IPv6Addr, - }, - { - nicID: nicID2, - addr: utils.RouterNIC2IPv4Addr, - }, - { - nicID: nicID2, - addr: utils.RouterNIC2IPv6Addr, - }, - } { - if err := s.AddProtocolAddress(add.nicID, add.addr); err != nil { - t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", add.nicID, add.addr, err) - } - } - - // Only enable forwarding on NIC1 and make sure that only packets arriving - // on NIC1 are forwarded. - for _, netProto := range netProtos { - if err := s.SetNICForwarding(nicID1, netProto, true); err != nil { - t.Fatalf("s.SetNICForwarding(%d, %d, true): %s", nicID1, netProtos, err) - } - } - - nicsInfo := s.NICInfo() - for _, subTest := range [...]struct { - nicID tcpip.NICID - nicEP *channel.Endpoint - otherNICID tcpip.NICID - otherNICEP *channel.Endpoint - expectForwarding bool - }{ - { - nicID: nicID1, - nicEP: e1, - otherNICID: nicID2, - otherNICEP: e2, - expectForwarding: true, - }, - { - nicID: nicID2, - nicEP: e2, - otherNICID: nicID2, - otherNICEP: e1, - expectForwarding: false, - }, - } { - t.Run(fmt.Sprintf("Packet arriving at NIC%d", subTest.nicID), func(t *testing.T) { - nicInfo, ok := nicsInfo[subTest.nicID] - if !ok { - t.Errorf("expected NIC info for NIC %d; got = %#v", subTest.nicID, nicsInfo) - } else { - forwarding := make(map[tcpip.NetworkProtocolNumber]bool) - for _, netProto := range netProtos { - forwarding[netProto] = subTest.expectForwarding - } - - if diff := cmp.Diff(forwarding, nicInfo.Forwarding); diff != "" { - t.Errorf("nicsInfo[%d].Forwarding mismatch (-want +got):\n%s", subTest.nicID, diff) - } - } - - s.SetRouteTable([]tcpip.Route{ - { - Destination: header.IPv4EmptySubnet, - NIC: subTest.otherNICID, - }, - { - Destination: header.IPv6EmptySubnet, - NIC: subTest.otherNICID, - }, - }) - - test.rx(subTest.nicEP, test.srcAddr, test.dstAddr) - if p, ok := subTest.nicEP.Read(); ok { - t.Errorf("unexpectedly got a response from the interface the packet arrived on: %#v", p) - } - if p, ok := subTest.otherNICEP.Read(); ok != subTest.expectForwarding { - t.Errorf("got otherNICEP.Read() = (%#v, %t), want = (_, %t)", p, ok, subTest.expectForwarding) - } else if subTest.expectForwarding { - test.checker(t, stack.PayloadSince(p.Pkt.NetworkHeader())) - } - }) - } - }) - } -} diff --git a/pkg/tcpip/tests/integration/iptables_test.go b/pkg/tcpip/tests/integration/iptables_test.go deleted file mode 100644 index f9ab7d0af..000000000 --- a/pkg/tcpip/tests/integration/iptables_test.go +++ /dev/null @@ -1,1134 +0,0 @@ -// Copyright 2021 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package iptables_test - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/checker" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/tests/utils" - "gvisor.dev/gvisor/pkg/tcpip/testutil" - "gvisor.dev/gvisor/pkg/tcpip/transport/udp" -) - -type inputIfNameMatcher struct { - name string -} - -var _ stack.Matcher = (*inputIfNameMatcher)(nil) - -func (*inputIfNameMatcher) Name() string { - return "inputIfNameMatcher" -} - -func (im *inputIfNameMatcher) Match(hook stack.Hook, _ *stack.PacketBuffer, inNicName, _ string) (bool, bool) { - return (hook == stack.Input && im.name != "" && im.name == inNicName), false -} - -const ( - nicID = 1 - nicName = "nic1" - anotherNicName = "nic2" - linkAddr = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e") - srcAddrV4 = "\x0a\x00\x00\x01" - dstAddrV4 = "\x0a\x00\x00\x02" - srcAddrV6 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - dstAddrV6 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" - payloadSize = 20 -) - -func genStackV6(t *testing.T) (*stack.Stack, *channel.Endpoint) { - t.Helper() - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, - }) - e := channel.New(0, header.IPv6MinimumMTU, linkAddr) - nicOpts := stack.NICOptions{Name: nicName} - if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { - t.Fatalf("CreateNICWithOptions(%d, _, %#v) = %s", nicID, nicOpts, err) - } - if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, dstAddrV6); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, dstAddrV6, err) - } - return s, e -} - -func genStackV4(t *testing.T) (*stack.Stack, *channel.Endpoint) { - t.Helper() - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, - }) - e := channel.New(0, header.IPv4MinimumMTU, linkAddr) - nicOpts := stack.NICOptions{Name: nicName} - if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { - t.Fatalf("CreateNICWithOptions(%d, _, %#v) = %s", nicID, nicOpts, err) - } - if err := s.AddAddress(nicID, header.IPv4ProtocolNumber, dstAddrV4); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, dstAddrV4, err) - } - return s, e -} - -func genPacketV6() *stack.PacketBuffer { - pktSize := header.IPv6MinimumSize + payloadSize - hdr := buffer.NewPrependable(pktSize) - ip := header.IPv6(hdr.Prepend(pktSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: payloadSize, - TransportProtocol: 99, - HopLimit: 255, - SrcAddr: srcAddrV6, - DstAddr: dstAddrV6, - }) - vv := hdr.View().ToVectorisedView() - return stack.NewPacketBuffer(stack.PacketBufferOptions{Data: vv}) -} - -func genPacketV4() *stack.PacketBuffer { - pktSize := header.IPv4MinimumSize + payloadSize - hdr := buffer.NewPrependable(pktSize) - ip := header.IPv4(hdr.Prepend(pktSize)) - ip.Encode(&header.IPv4Fields{ - TOS: 0, - TotalLength: uint16(pktSize), - ID: 1, - Flags: 0, - FragmentOffset: 16, - TTL: 48, - Protocol: 99, - SrcAddr: srcAddrV4, - DstAddr: dstAddrV4, - }) - ip.SetChecksum(0) - ip.SetChecksum(^ip.CalculateChecksum()) - vv := hdr.View().ToVectorisedView() - return stack.NewPacketBuffer(stack.PacketBufferOptions{Data: vv}) -} - -func TestIPTablesStatsForInput(t *testing.T) { - tests := []struct { - name string - setupStack func(*testing.T) (*stack.Stack, *channel.Endpoint) - setupFilter func(*testing.T, *stack.Stack) - genPacket func() *stack.PacketBuffer - proto tcpip.NetworkProtocolNumber - expectReceived int - expectInputDropped int - }{ - { - name: "IPv6 Accept", - setupStack: genStackV6, - setupFilter: func(*testing.T, *stack.Stack) { /* no filter */ }, - genPacket: genPacketV6, - proto: header.IPv6ProtocolNumber, - expectReceived: 1, - expectInputDropped: 0, - }, - { - name: "IPv4 Accept", - setupStack: genStackV4, - setupFilter: func(*testing.T, *stack.Stack) { /* no filter */ }, - genPacket: genPacketV4, - proto: header.IPv4ProtocolNumber, - expectReceived: 1, - expectInputDropped: 0, - }, - { - name: "IPv6 Drop (input interface matches)", - setupStack: genStackV6, - setupFilter: func(t *testing.T, s *stack.Stack) { - t.Helper() - ipt := s.IPTables() - filter := ipt.GetTable(stack.FilterID, true /* ipv6 */) - ruleIdx := filter.BuiltinChains[stack.Input] - filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{InputInterface: nicName} - filter.Rules[ruleIdx].Target = &stack.DropTarget{} - filter.Rules[ruleIdx].Matchers = []stack.Matcher{&inputIfNameMatcher{nicName}} - // Make sure the packet is not dropped by the next rule. - filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} - if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil { - t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, true, err) - } - }, - genPacket: genPacketV6, - proto: header.IPv6ProtocolNumber, - expectReceived: 1, - expectInputDropped: 1, - }, - { - name: "IPv4 Drop (input interface matches)", - setupStack: genStackV4, - setupFilter: func(t *testing.T, s *stack.Stack) { - t.Helper() - ipt := s.IPTables() - filter := ipt.GetTable(stack.FilterID, false /* ipv6 */) - ruleIdx := filter.BuiltinChains[stack.Input] - filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{InputInterface: nicName} - filter.Rules[ruleIdx].Target = &stack.DropTarget{} - filter.Rules[ruleIdx].Matchers = []stack.Matcher{&inputIfNameMatcher{nicName}} - filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} - if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil { - t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, false, err) - } - }, - genPacket: genPacketV4, - proto: header.IPv4ProtocolNumber, - expectReceived: 1, - expectInputDropped: 1, - }, - { - name: "IPv6 Accept (input interface does not match)", - setupStack: genStackV6, - setupFilter: func(t *testing.T, s *stack.Stack) { - t.Helper() - ipt := s.IPTables() - filter := ipt.GetTable(stack.FilterID, true /* ipv6 */) - ruleIdx := filter.BuiltinChains[stack.Input] - filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{InputInterface: anotherNicName} - filter.Rules[ruleIdx].Target = &stack.DropTarget{} - filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} - if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil { - t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, true, err) - } - }, - genPacket: genPacketV6, - proto: header.IPv6ProtocolNumber, - expectReceived: 1, - expectInputDropped: 0, - }, - { - name: "IPv4 Accept (input interface does not match)", - setupStack: genStackV4, - setupFilter: func(t *testing.T, s *stack.Stack) { - t.Helper() - ipt := s.IPTables() - filter := ipt.GetTable(stack.FilterID, false /* ipv6 */) - ruleIdx := filter.BuiltinChains[stack.Input] - filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{InputInterface: anotherNicName} - filter.Rules[ruleIdx].Target = &stack.DropTarget{} - filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} - if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil { - t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, false, err) - } - }, - genPacket: genPacketV4, - proto: header.IPv4ProtocolNumber, - expectReceived: 1, - expectInputDropped: 0, - }, - { - name: "IPv6 Drop (input interface does not match but invert is true)", - setupStack: genStackV6, - setupFilter: func(t *testing.T, s *stack.Stack) { - t.Helper() - ipt := s.IPTables() - filter := ipt.GetTable(stack.FilterID, true /* ipv6 */) - ruleIdx := filter.BuiltinChains[stack.Input] - filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{ - InputInterface: anotherNicName, - InputInterfaceInvert: true, - } - filter.Rules[ruleIdx].Target = &stack.DropTarget{} - filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} - if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil { - t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, true, err) - } - }, - genPacket: genPacketV6, - proto: header.IPv6ProtocolNumber, - expectReceived: 1, - expectInputDropped: 1, - }, - { - name: "IPv4 Drop (input interface does not match but invert is true)", - setupStack: genStackV4, - setupFilter: func(t *testing.T, s *stack.Stack) { - t.Helper() - ipt := s.IPTables() - filter := ipt.GetTable(stack.FilterID, false /* ipv6 */) - ruleIdx := filter.BuiltinChains[stack.Input] - filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{ - InputInterface: anotherNicName, - InputInterfaceInvert: true, - } - filter.Rules[ruleIdx].Target = &stack.DropTarget{} - filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} - if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil { - t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, false, err) - } - }, - genPacket: genPacketV4, - proto: header.IPv4ProtocolNumber, - expectReceived: 1, - expectInputDropped: 1, - }, - { - name: "IPv6 Accept (input interface does not match using a matcher)", - setupStack: genStackV6, - setupFilter: func(t *testing.T, s *stack.Stack) { - t.Helper() - ipt := s.IPTables() - filter := ipt.GetTable(stack.FilterID, true /* ipv6 */) - ruleIdx := filter.BuiltinChains[stack.Input] - filter.Rules[ruleIdx].Target = &stack.DropTarget{} - filter.Rules[ruleIdx].Matchers = []stack.Matcher{&inputIfNameMatcher{anotherNicName}} - filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} - if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil { - t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, true, err) - } - }, - genPacket: genPacketV6, - proto: header.IPv6ProtocolNumber, - expectReceived: 1, - expectInputDropped: 0, - }, - { - name: "IPv4 Accept (input interface does not match using a matcher)", - setupStack: genStackV4, - setupFilter: func(t *testing.T, s *stack.Stack) { - t.Helper() - ipt := s.IPTables() - filter := ipt.GetTable(stack.FilterID, false /* ipv6 */) - ruleIdx := filter.BuiltinChains[stack.Input] - filter.Rules[ruleIdx].Target = &stack.DropTarget{} - filter.Rules[ruleIdx].Matchers = []stack.Matcher{&inputIfNameMatcher{anotherNicName}} - filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} - if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil { - t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, false, err) - } - }, - genPacket: genPacketV4, - proto: header.IPv4ProtocolNumber, - expectReceived: 1, - expectInputDropped: 0, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s, e := test.setupStack(t) - test.setupFilter(t, s) - e.InjectInbound(test.proto, test.genPacket()) - - if got := int(s.Stats().IP.PacketsReceived.Value()); got != test.expectReceived { - t.Errorf("got PacketReceived = %d, want = %d", got, test.expectReceived) - } - if got := int(s.Stats().IP.IPTablesInputDropped.Value()); got != test.expectInputDropped { - t.Errorf("got IPTablesInputDropped = %d, want = %d", got, test.expectInputDropped) - } - }) - } -} - -var _ stack.LinkEndpoint = (*channelEndpointWithoutWritePacket)(nil) - -// channelEndpointWithoutWritePacket is a channel endpoint that does not support -// stack.LinkEndpoint.WritePacket. -type channelEndpointWithoutWritePacket struct { - *channel.Endpoint - - t *testing.T -} - -func (c *channelEndpointWithoutWritePacket) WritePacket(stack.RouteInfo, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) tcpip.Error { - c.t.Error("unexpectedly called WritePacket; all writes should go through WritePackets") - return &tcpip.ErrNotSupported{} -} - -var _ stack.Matcher = (*udpSourcePortMatcher)(nil) - -type udpSourcePortMatcher struct { - port uint16 -} - -func (*udpSourcePortMatcher) Name() string { - return "udpSourcePortMatcher" -} - -func (m *udpSourcePortMatcher) Match(_ stack.Hook, pkt *stack.PacketBuffer, _, _ string) (matches, hotdrop bool) { - udp := header.UDP(pkt.TransportHeader().View()) - if len(udp) < header.UDPMinimumSize { - // Drop immediately as the packet is invalid. - return false, true - } - - return udp.SourcePort() == m.port, false -} - -func TestIPTableWritePackets(t *testing.T) { - const ( - nicID = 1 - - dropLocalPort = utils.LocalPort - 1 - acceptPackets = 2 - dropPackets = 3 - ) - - udpHdr := func(hdr buffer.View, srcAddr, dstAddr tcpip.Address, srcPort, dstPort uint16) { - u := header.UDP(hdr) - u.Encode(&header.UDPFields{ - SrcPort: srcPort, - DstPort: dstPort, - Length: header.UDPMinimumSize, - }) - sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, srcAddr, dstAddr, header.UDPMinimumSize) - sum = header.Checksum(hdr, sum) - u.SetChecksum(^u.CalculateChecksum(sum)) - } - - tests := []struct { - name string - setupFilter func(*testing.T, *stack.Stack) - genPacket func(*stack.Route) stack.PacketBufferList - proto tcpip.NetworkProtocolNumber - remoteAddr tcpip.Address - expectSent uint64 - expectOutputDropped uint64 - }{ - { - name: "IPv4 Accept", - setupFilter: func(*testing.T, *stack.Stack) { /* no filter */ }, - genPacket: func(r *stack.Route) stack.PacketBufferList { - var pkts stack.PacketBufferList - - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(r.MaxHeaderLength() + header.UDPMinimumSize), - }) - hdr := pkt.TransportHeader().Push(header.UDPMinimumSize) - udpHdr(hdr, r.LocalAddress(), r.RemoteAddress(), utils.LocalPort, utils.RemotePort) - pkts.PushFront(pkt) - - return pkts - }, - proto: header.IPv4ProtocolNumber, - remoteAddr: dstAddrV4, - expectSent: 1, - expectOutputDropped: 0, - }, - { - name: "IPv4 Drop Other Port", - setupFilter: func(t *testing.T, s *stack.Stack) { - t.Helper() - - table := stack.Table{ - Rules: []stack.Rule{ - { - Target: &stack.AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}, - }, - { - Target: &stack.AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}, - }, - { - Matchers: []stack.Matcher{&udpSourcePortMatcher{port: dropLocalPort}}, - Target: &stack.DropTarget{NetworkProtocol: header.IPv4ProtocolNumber}, - }, - { - Target: &stack.AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}, - }, - { - Target: &stack.ErrorTarget{NetworkProtocol: header.IPv4ProtocolNumber}, - }, - }, - BuiltinChains: [stack.NumHooks]int{ - stack.Prerouting: stack.HookUnset, - stack.Input: 0, - stack.Forward: 1, - stack.Output: 2, - stack.Postrouting: stack.HookUnset, - }, - Underflows: [stack.NumHooks]int{ - stack.Prerouting: stack.HookUnset, - stack.Input: 0, - stack.Forward: 1, - stack.Output: 2, - stack.Postrouting: stack.HookUnset, - }, - } - - if err := s.IPTables().ReplaceTable(stack.FilterID, table, false /* ipv4 */); err != nil { - t.Fatalf("ReplaceTable(%d, _, false): %s", stack.FilterID, err) - } - }, - genPacket: func(r *stack.Route) stack.PacketBufferList { - var pkts stack.PacketBufferList - - for i := 0; i < acceptPackets; i++ { - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(r.MaxHeaderLength() + header.UDPMinimumSize), - }) - hdr := pkt.TransportHeader().Push(header.UDPMinimumSize) - udpHdr(hdr, r.LocalAddress(), r.RemoteAddress(), utils.LocalPort, utils.RemotePort) - pkts.PushFront(pkt) - } - for i := 0; i < dropPackets; i++ { - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(r.MaxHeaderLength() + header.UDPMinimumSize), - }) - hdr := pkt.TransportHeader().Push(header.UDPMinimumSize) - udpHdr(hdr, r.LocalAddress(), r.RemoteAddress(), dropLocalPort, utils.RemotePort) - pkts.PushFront(pkt) - } - - return pkts - }, - proto: header.IPv4ProtocolNumber, - remoteAddr: dstAddrV4, - expectSent: acceptPackets, - expectOutputDropped: dropPackets, - }, - { - name: "IPv6 Accept", - setupFilter: func(*testing.T, *stack.Stack) { /* no filter */ }, - genPacket: func(r *stack.Route) stack.PacketBufferList { - var pkts stack.PacketBufferList - - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(r.MaxHeaderLength() + header.UDPMinimumSize), - }) - hdr := pkt.TransportHeader().Push(header.UDPMinimumSize) - udpHdr(hdr, r.LocalAddress(), r.RemoteAddress(), utils.LocalPort, utils.RemotePort) - pkts.PushFront(pkt) - - return pkts - }, - proto: header.IPv6ProtocolNumber, - remoteAddr: dstAddrV6, - expectSent: 1, - expectOutputDropped: 0, - }, - { - name: "IPv6 Drop Other Port", - setupFilter: func(t *testing.T, s *stack.Stack) { - t.Helper() - - table := stack.Table{ - Rules: []stack.Rule{ - { - Target: &stack.AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}, - }, - { - Target: &stack.AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}, - }, - { - Matchers: []stack.Matcher{&udpSourcePortMatcher{port: dropLocalPort}}, - Target: &stack.DropTarget{NetworkProtocol: header.IPv6ProtocolNumber}, - }, - { - Target: &stack.AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}, - }, - { - Target: &stack.ErrorTarget{NetworkProtocol: header.IPv6ProtocolNumber}, - }, - }, - BuiltinChains: [stack.NumHooks]int{ - stack.Prerouting: stack.HookUnset, - stack.Input: 0, - stack.Forward: 1, - stack.Output: 2, - stack.Postrouting: stack.HookUnset, - }, - Underflows: [stack.NumHooks]int{ - stack.Prerouting: stack.HookUnset, - stack.Input: 0, - stack.Forward: 1, - stack.Output: 2, - stack.Postrouting: stack.HookUnset, - }, - } - - if err := s.IPTables().ReplaceTable(stack.FilterID, table, true /* ipv6 */); err != nil { - t.Fatalf("ReplaceTable(%d, _, true): %s", stack.FilterID, err) - } - }, - genPacket: func(r *stack.Route) stack.PacketBufferList { - var pkts stack.PacketBufferList - - for i := 0; i < acceptPackets; i++ { - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(r.MaxHeaderLength() + header.UDPMinimumSize), - }) - hdr := pkt.TransportHeader().Push(header.UDPMinimumSize) - udpHdr(hdr, r.LocalAddress(), r.RemoteAddress(), utils.LocalPort, utils.RemotePort) - pkts.PushFront(pkt) - } - for i := 0; i < dropPackets; i++ { - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(r.MaxHeaderLength() + header.UDPMinimumSize), - }) - hdr := pkt.TransportHeader().Push(header.UDPMinimumSize) - udpHdr(hdr, r.LocalAddress(), r.RemoteAddress(), dropLocalPort, utils.RemotePort) - pkts.PushFront(pkt) - } - - return pkts - }, - proto: header.IPv6ProtocolNumber, - remoteAddr: dstAddrV6, - expectSent: acceptPackets, - expectOutputDropped: dropPackets, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, - }) - e := channelEndpointWithoutWritePacket{ - Endpoint: channel.New(4, header.IPv6MinimumMTU, linkAddr), - t: t, - } - if err := s.CreateNIC(nicID, &e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, srcAddrV6); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, srcAddrV6, err) - } - if err := s.AddAddress(nicID, header.IPv4ProtocolNumber, srcAddrV4); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, srcAddrV4, err) - } - - s.SetRouteTable([]tcpip.Route{ - { - Destination: header.IPv4EmptySubnet, - NIC: nicID, - }, - { - Destination: header.IPv6EmptySubnet, - NIC: nicID, - }, - }) - - test.setupFilter(t, s) - - r, err := s.FindRoute(nicID, "", test.remoteAddr, test.proto, false) - if err != nil { - t.Fatalf("FindRoute(%d, '', %s, %d, false): %s", nicID, test.remoteAddr, test.proto, err) - } - defer r.Release() - - pkts := test.genPacket(r) - pktsLen := pkts.Len() - if n, err := r.WritePackets(pkts, stack.NetworkHeaderParams{ - Protocol: header.UDPProtocolNumber, - TTL: 64, - }); err != nil { - t.Fatalf("WritePackets(...): %s", err) - } else if n != pktsLen { - t.Fatalf("got WritePackets(...) = %d, want = %d", n, pktsLen) - } - - if got := s.Stats().IP.PacketsSent.Value(); got != test.expectSent { - t.Errorf("got PacketSent = %d, want = %d", got, test.expectSent) - } - if got := s.Stats().IP.IPTablesOutputDropped.Value(); got != test.expectOutputDropped { - t.Errorf("got IPTablesOutputDropped = %d, want = %d", got, test.expectOutputDropped) - } - }) - } -} - -const ttl = 64 - -var ( - ipv4GlobalMulticastAddr = testutil.MustParse4("224.0.1.10") - ipv6GlobalMulticastAddr = testutil.MustParse6("ff0e::a") -) - -func rxICMPv4EchoReply(e *channel.Endpoint, src, dst tcpip.Address) { - utils.RxICMPv4EchoReply(e, src, dst, ttl) -} - -func rxICMPv6EchoReply(e *channel.Endpoint, src, dst tcpip.Address) { - utils.RxICMPv6EchoReply(e, src, dst, ttl) -} - -func forwardedICMPv4EchoReplyChecker(t *testing.T, b []byte, src, dst tcpip.Address) { - checker.IPv4(t, b, - checker.SrcAddr(src), - checker.DstAddr(dst), - checker.TTL(ttl-1), - checker.ICMPv4( - checker.ICMPv4Type(header.ICMPv4EchoReply))) -} - -func forwardedICMPv6EchoReplyChecker(t *testing.T, b []byte, src, dst tcpip.Address) { - checker.IPv6(t, b, - checker.SrcAddr(src), - checker.DstAddr(dst), - checker.TTL(ttl-1), - checker.ICMPv6( - checker.ICMPv6Type(header.ICMPv6EchoReply))) -} - -func boolToInt(v bool) uint64 { - if v { - return 1 - } - return 0 -} - -func setupDropFilter(hook stack.Hook, f stack.IPHeaderFilter) func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber) { - return func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber) { - t.Helper() - - ipv6 := netProto == ipv6.ProtocolNumber - - ipt := s.IPTables() - filter := ipt.GetTable(stack.FilterID, ipv6) - ruleIdx := filter.BuiltinChains[hook] - filter.Rules[ruleIdx].Filter = f - filter.Rules[ruleIdx].Target = &stack.DropTarget{NetworkProtocol: netProto} - // Make sure the packet is not dropped by the next rule. - filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{NetworkProtocol: netProto} - if err := ipt.ReplaceTable(stack.FilterID, filter, ipv6); err != nil { - t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, ipv6, err) - } - } -} - -func TestForwardingHook(t *testing.T) { - const ( - nicID1 = 1 - nicID2 = 2 - - nic1Name = "nic1" - nic2Name = "nic2" - - otherNICName = "otherNIC" - ) - - tests := []struct { - name string - netProto tcpip.NetworkProtocolNumber - local bool - srcAddr, dstAddr tcpip.Address - rx func(*channel.Endpoint, tcpip.Address, tcpip.Address) - checker func(*testing.T, []byte) - }{ - { - name: "IPv4 remote", - netProto: ipv4.ProtocolNumber, - local: false, - srcAddr: utils.RemoteIPv4Addr, - dstAddr: utils.Ipv4Addr2.AddressWithPrefix.Address, - rx: rxICMPv4EchoReply, - checker: func(t *testing.T, b []byte) { - forwardedICMPv4EchoReplyChecker(t, b, utils.RemoteIPv4Addr, utils.Ipv4Addr2.AddressWithPrefix.Address) - }, - }, - { - name: "IPv4 local", - netProto: ipv4.ProtocolNumber, - local: true, - srcAddr: utils.RemoteIPv4Addr, - dstAddr: utils.Ipv4Addr.Address, - rx: rxICMPv4EchoReply, - }, - { - name: "IPv6 remote", - netProto: ipv6.ProtocolNumber, - local: false, - srcAddr: utils.RemoteIPv6Addr, - dstAddr: utils.Ipv6Addr2.AddressWithPrefix.Address, - rx: rxICMPv6EchoReply, - checker: func(t *testing.T, b []byte) { - forwardedICMPv6EchoReplyChecker(t, b, utils.RemoteIPv6Addr, utils.Ipv6Addr2.AddressWithPrefix.Address) - }, - }, - { - name: "IPv6 local", - netProto: ipv6.ProtocolNumber, - local: true, - srcAddr: utils.RemoteIPv6Addr, - dstAddr: utils.Ipv6Addr.Address, - rx: rxICMPv6EchoReply, - }, - } - - subTests := []struct { - name string - setupFilter func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber) - expectForward bool - }{ - { - name: "Accept", - setupFilter: func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber) { /* no filter */ }, - expectForward: true, - }, - - { - name: "Drop", - setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{}), - expectForward: false, - }, - { - name: "Drop with input NIC filtering", - setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: nic1Name}), - expectForward: false, - }, - { - name: "Drop with output NIC filtering", - setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{OutputInterface: nic2Name}), - expectForward: false, - }, - { - name: "Drop with input and output NIC filtering", - setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: nic1Name, OutputInterface: nic2Name}), - expectForward: false, - }, - - { - name: "Drop with other input NIC filtering", - setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: otherNICName}), - expectForward: true, - }, - { - name: "Drop with other output NIC filtering", - setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{OutputInterface: otherNICName}), - expectForward: true, - }, - { - name: "Drop with other input and output NIC filtering", - setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: otherNICName, OutputInterface: nic2Name}), - expectForward: true, - }, - { - name: "Drop with input and other output NIC filtering", - setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: nic1Name, OutputInterface: otherNICName}), - expectForward: true, - }, - { - name: "Drop with other input and other output NIC filtering", - setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: otherNICName, OutputInterface: otherNICName}), - expectForward: true, - }, - - { - name: "Drop with inverted input NIC filtering", - setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: nic1Name, InputInterfaceInvert: true}), - expectForward: true, - }, - { - name: "Drop with inverted output NIC filtering", - setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{OutputInterface: nic2Name, OutputInterfaceInvert: true}), - expectForward: true, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - for _, subTest := range subTests { - t.Run(subTest.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, - }) - - subTest.setupFilter(t, s, test.netProto) - - e1 := channel.New(1, header.IPv6MinimumMTU, "") - if err := s.CreateNICWithOptions(nicID1, e1, stack.NICOptions{Name: nic1Name}); err != nil { - t.Fatalf("s.CreateNICWithOptions(%d, _, _): %s", nicID1, err) - } - - e2 := channel.New(1, header.IPv6MinimumMTU, "") - if err := s.CreateNICWithOptions(nicID2, e2, stack.NICOptions{Name: nic2Name}); err != nil { - t.Fatalf("s.CreateNICWithOptions(%d, _, _): %s", nicID2, err) - } - - if err := s.AddAddress(nicID2, ipv4.ProtocolNumber, utils.Ipv4Addr.Address); err != nil { - t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID2, ipv4.ProtocolNumber, utils.Ipv4Addr.Address, err) - } - if err := s.AddAddress(nicID2, ipv6.ProtocolNumber, utils.Ipv6Addr.Address); err != nil { - t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID2, ipv6.ProtocolNumber, utils.Ipv6Addr.Address, err) - } - - if err := s.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil { - t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", ipv4.ProtocolNumber, err) - } - if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, true); err != nil { - t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", ipv6.ProtocolNumber, err) - } - - s.SetRouteTable([]tcpip.Route{ - { - Destination: header.IPv4EmptySubnet, - NIC: nicID2, - }, - { - Destination: header.IPv6EmptySubnet, - NIC: nicID2, - }, - }) - - test.rx(e1, test.srcAddr, test.dstAddr) - - expectTransmitPacket := subTest.expectForward && !test.local - - ep1, err := s.GetNetworkEndpoint(nicID1, test.netProto) - if err != nil { - t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID1, test.netProto, err) - } - ep1Stats := ep1.Stats() - ipEP1Stats, ok := ep1Stats.(stack.IPNetworkEndpointStats) - if !ok { - t.Fatalf("got ep1Stats = %T, want = stack.IPNetworkEndpointStats", ep1Stats) - } - ip1Stats := ipEP1Stats.IPStats() - - if got := ip1Stats.PacketsReceived.Value(); got != 1 { - t.Errorf("got ip1Stats.PacketsReceived.Value() = %d, want = 1", got) - } - if got := ip1Stats.ValidPacketsReceived.Value(); got != 1 { - t.Errorf("got ip1Stats.ValidPacketsReceived.Value() = %d, want = 1", got) - } - if got, want := ip1Stats.IPTablesForwardDropped.Value(), boolToInt(!subTest.expectForward); got != want { - t.Errorf("got ip1Stats.IPTablesForwardDropped.Value() = %d, want = %d", got, want) - } - if got := ip1Stats.PacketsSent.Value(); got != 0 { - t.Errorf("got ip1Stats.PacketsSent.Value() = %d, want = 0", got) - } - - ep2, err := s.GetNetworkEndpoint(nicID2, test.netProto) - if err != nil { - t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID2, test.netProto, err) - } - ep2Stats := ep2.Stats() - ipEP2Stats, ok := ep2Stats.(stack.IPNetworkEndpointStats) - if !ok { - t.Fatalf("got ep2Stats = %T, want = stack.IPNetworkEndpointStats", ep2Stats) - } - ip2Stats := ipEP2Stats.IPStats() - if got := ip2Stats.PacketsReceived.Value(); got != 0 { - t.Errorf("got ip2Stats.PacketsReceived.Value() = %d, want = 0", got) - } - if got, want := ip2Stats.ValidPacketsReceived.Value(), boolToInt(subTest.expectForward && test.local); got != want { - t.Errorf("got ip2Stats.ValidPacketsReceived.Value() = %d, want = %d", got, want) - } - if got, want := ip2Stats.PacketsSent.Value(), boolToInt(expectTransmitPacket); got != want { - t.Errorf("got ip2Stats.PacketsSent.Value() = %d, want = %d", got, want) - } - - p, ok := e2.Read() - if ok != expectTransmitPacket { - t.Fatalf("got e2.Read() = (%#v, %t), want = (_, %t)", p, ok, expectTransmitPacket) - } - if expectTransmitPacket { - test.checker(t, stack.PayloadSince(p.Pkt.NetworkHeader())) - } - }) - } - }) - } -} - -func TestInputHookWithLocalForwarding(t *testing.T) { - const ( - nicID1 = 1 - nicID2 = 2 - - nic1Name = "nic1" - nic2Name = "nic2" - - otherNICName = "otherNIC" - ) - - tests := []struct { - name string - netProto tcpip.NetworkProtocolNumber - rx func(*channel.Endpoint) - checker func(*testing.T, []byte) - }{ - { - name: "IPv4", - netProto: ipv4.ProtocolNumber, - rx: func(e *channel.Endpoint) { - utils.RxICMPv4EchoRequest(e, utils.RemoteIPv4Addr, utils.Ipv4Addr2.AddressWithPrefix.Address, ttl) - }, - checker: func(t *testing.T, b []byte) { - checker.IPv4(t, b, - checker.SrcAddr(utils.Ipv4Addr2.AddressWithPrefix.Address), - checker.DstAddr(utils.RemoteIPv4Addr), - checker.ICMPv4( - checker.ICMPv4Type(header.ICMPv4EchoReply))) - }, - }, - { - name: "IPv6", - netProto: ipv6.ProtocolNumber, - rx: func(e *channel.Endpoint) { - utils.RxICMPv6EchoRequest(e, utils.RemoteIPv6Addr, utils.Ipv6Addr2.AddressWithPrefix.Address, ttl) - }, - checker: func(t *testing.T, b []byte) { - checker.IPv6(t, b, - checker.SrcAddr(utils.Ipv6Addr2.AddressWithPrefix.Address), - checker.DstAddr(utils.RemoteIPv6Addr), - checker.ICMPv6( - checker.ICMPv6Type(header.ICMPv6EchoReply))) - }, - }, - } - - subTests := []struct { - name string - setupFilter func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber) - expectDrop bool - }{ - { - name: "Accept", - setupFilter: func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber) { /* no filter */ }, - expectDrop: false, - }, - - { - name: "Drop", - setupFilter: setupDropFilter(stack.Input, stack.IPHeaderFilter{}), - expectDrop: true, - }, - { - name: "Drop with input NIC filtering on arrival NIC", - setupFilter: setupDropFilter(stack.Input, stack.IPHeaderFilter{InputInterface: nic1Name}), - expectDrop: true, - }, - { - name: "Drop with input NIC filtering on delivered NIC", - setupFilter: setupDropFilter(stack.Input, stack.IPHeaderFilter{InputInterface: nic2Name}), - expectDrop: false, - }, - - { - name: "Drop with input NIC filtering on other NIC", - setupFilter: setupDropFilter(stack.Input, stack.IPHeaderFilter{InputInterface: otherNICName}), - expectDrop: false, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - for _, subTest := range subTests { - t.Run(subTest.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, - }) - - subTest.setupFilter(t, s, test.netProto) - - e1 := channel.New(1, header.IPv6MinimumMTU, "") - if err := s.CreateNICWithOptions(nicID1, e1, stack.NICOptions{Name: nic1Name}); err != nil { - t.Fatalf("s.CreateNICWithOptions(%d, _, _): %s", nicID1, err) - } - if err := s.AddProtocolAddress(nicID1, utils.Ipv4Addr1); err != nil { - t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID1, utils.Ipv4Addr1, err) - } - if err := s.AddProtocolAddress(nicID1, utils.Ipv6Addr1); err != nil { - t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID1, utils.Ipv6Addr1, err) - } - - e2 := channel.New(1, header.IPv6MinimumMTU, "") - if err := s.CreateNICWithOptions(nicID2, e2, stack.NICOptions{Name: nic2Name}); err != nil { - t.Fatalf("s.CreateNICWithOptions(%d, _, _): %s", nicID2, err) - } - if err := s.AddProtocolAddress(nicID2, utils.Ipv4Addr2); err != nil { - t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID2, utils.Ipv4Addr2, err) - } - if err := s.AddProtocolAddress(nicID2, utils.Ipv6Addr2); err != nil { - t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID2, utils.Ipv6Addr2, err) - } - - if err := s.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil { - t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", ipv4.ProtocolNumber, err) - } - if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, true); err != nil { - t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", ipv6.ProtocolNumber, err) - } - - s.SetRouteTable([]tcpip.Route{ - { - Destination: header.IPv4EmptySubnet, - NIC: nicID1, - }, - { - Destination: header.IPv6EmptySubnet, - NIC: nicID1, - }, - }) - - test.rx(e1) - - ep1, err := s.GetNetworkEndpoint(nicID1, test.netProto) - if err != nil { - t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID1, test.netProto, err) - } - ep1Stats := ep1.Stats() - ipEP1Stats, ok := ep1Stats.(stack.IPNetworkEndpointStats) - if !ok { - t.Fatalf("got ep1Stats = %T, want = stack.IPNetworkEndpointStats", ep1Stats) - } - ip1Stats := ipEP1Stats.IPStats() - - if got := ip1Stats.PacketsReceived.Value(); got != 1 { - t.Errorf("got ip1Stats.PacketsReceived.Value() = %d, want = 1", got) - } - if got := ip1Stats.ValidPacketsReceived.Value(); got != 1 { - t.Errorf("got ip1Stats.ValidPacketsReceived.Value() = %d, want = 1", got) - } - if got, want := ip1Stats.PacketsSent.Value(), boolToInt(!subTest.expectDrop); got != want { - t.Errorf("got ip1Stats.PacketsSent.Value() = %d, want = %d", got, want) - } - - ep2, err := s.GetNetworkEndpoint(nicID2, test.netProto) - if err != nil { - t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID2, test.netProto, err) - } - ep2Stats := ep2.Stats() - ipEP2Stats, ok := ep2Stats.(stack.IPNetworkEndpointStats) - if !ok { - t.Fatalf("got ep2Stats = %T, want = stack.IPNetworkEndpointStats", ep2Stats) - } - ip2Stats := ipEP2Stats.IPStats() - if got := ip2Stats.PacketsReceived.Value(); got != 0 { - t.Errorf("got ip2Stats.PacketsReceived.Value() = %d, want = 0", got) - } - if got := ip2Stats.ValidPacketsReceived.Value(); got != 1 { - t.Errorf("got ip2Stats.ValidPacketsReceived.Value() = %d, want = 1", got) - } - if got, want := ip2Stats.IPTablesInputDropped.Value(), boolToInt(subTest.expectDrop); got != want { - t.Errorf("got ip2Stats.IPTablesInputDropped.Value() = %d, want = %d", got, want) - } - if got := ip2Stats.PacketsSent.Value(); got != 0 { - t.Errorf("got ip2Stats.PacketsSent.Value() = %d, want = 0", got) - } - - if p, ok := e1.Read(); ok == subTest.expectDrop { - t.Errorf("got e1.Read() = (%#v, %t), want = (_, %t)", p, ok, !subTest.expectDrop) - } else if !subTest.expectDrop { - test.checker(t, stack.PayloadSince(p.Pkt.NetworkHeader())) - } - if p, ok := e2.Read(); ok { - t.Errorf("got e1.Read() = (%#v, true), want = (_, false)", p) - } - }) - } - }) - } -} diff --git a/pkg/tcpip/tests/integration/link_resolution_test.go b/pkg/tcpip/tests/integration/link_resolution_test.go deleted file mode 100644 index 27caa0c28..000000000 --- a/pkg/tcpip/tests/integration/link_resolution_test.go +++ /dev/null @@ -1,1640 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package link_resolution_test - -import ( - "bytes" - "fmt" - "net" - "runtime" - "testing" - "time" - - "github.com/google/go-cmp/cmp" - "github.com/google/go-cmp/cmp/cmpopts" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/checker" - "gvisor.dev/gvisor/pkg/tcpip/faketime" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/link/pipe" - "gvisor.dev/gvisor/pkg/tcpip/network/arp" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/tests/utils" - tcptestutil "gvisor.dev/gvisor/pkg/tcpip/testutil" - "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" - "gvisor.dev/gvisor/pkg/tcpip/transport/udp" - "gvisor.dev/gvisor/pkg/waiter" -) - -func setupStack(t *testing.T, stackOpts stack.Options, host1NICID, host2NICID tcpip.NICID) (*stack.Stack, *stack.Stack) { - host1Stack := stack.New(stackOpts) - host2Stack := stack.New(stackOpts) - - host1NIC, host2NIC := pipe.New(utils.LinkAddr1, utils.LinkAddr2) - - if err := host1Stack.CreateNIC(host1NICID, utils.NewEthernetEndpoint(host1NIC)); err != nil { - t.Fatalf("host1Stack.CreateNIC(%d, _): %s", host1NICID, err) - } - if err := host2Stack.CreateNIC(host2NICID, utils.NewEthernetEndpoint(host2NIC)); err != nil { - t.Fatalf("host2Stack.CreateNIC(%d, _): %s", host2NICID, err) - } - - if err := host1Stack.AddProtocolAddress(host1NICID, utils.Ipv4Addr1); err != nil { - t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, utils.Ipv4Addr1, err) - } - if err := host2Stack.AddProtocolAddress(host2NICID, utils.Ipv4Addr2); err != nil { - t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, utils.Ipv4Addr2, err) - } - if err := host1Stack.AddProtocolAddress(host1NICID, utils.Ipv6Addr1); err != nil { - t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, utils.Ipv6Addr1, err) - } - if err := host2Stack.AddProtocolAddress(host2NICID, utils.Ipv6Addr2); err != nil { - t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, utils.Ipv6Addr2, err) - } - - host1Stack.SetRouteTable([]tcpip.Route{ - { - Destination: utils.Ipv4Addr1.AddressWithPrefix.Subnet(), - NIC: host1NICID, - }, - { - Destination: utils.Ipv6Addr1.AddressWithPrefix.Subnet(), - NIC: host1NICID, - }, - }) - host2Stack.SetRouteTable([]tcpip.Route{ - { - Destination: utils.Ipv4Addr2.AddressWithPrefix.Subnet(), - NIC: host2NICID, - }, - { - Destination: utils.Ipv6Addr2.AddressWithPrefix.Subnet(), - NIC: host2NICID, - }, - }) - - return host1Stack, host2Stack -} - -// TestPing tests that two hosts can ping eachother when link resolution is -// enabled. -func TestPing(t *testing.T) { - const ( - host1NICID = 1 - host2NICID = 4 - - // icmpDataOffset is the offset to the data in both ICMPv4 and ICMPv6 echo - // request/reply packets. - icmpDataOffset = 8 - ) - - tests := []struct { - name string - transProto tcpip.TransportProtocolNumber - netProto tcpip.NetworkProtocolNumber - remoteAddr tcpip.Address - icmpBuf func(*testing.T) []byte - }{ - { - name: "IPv4 Ping", - transProto: icmp.ProtocolNumber4, - netProto: ipv4.ProtocolNumber, - remoteAddr: utils.Ipv4Addr2.AddressWithPrefix.Address, - icmpBuf: func(t *testing.T) []byte { - data := [8]byte{1, 2, 3, 4, 5, 6, 7, 8} - hdr := header.ICMPv4(make([]byte, header.ICMPv4MinimumSize+len(data))) - hdr.SetType(header.ICMPv4Echo) - if n := copy(hdr.Payload(), data[:]); n != len(data) { - t.Fatalf("copied %d bytes but expected to copy %d bytes", n, len(data)) - } - return hdr - }, - }, - { - name: "IPv6 Ping", - transProto: icmp.ProtocolNumber6, - netProto: ipv6.ProtocolNumber, - remoteAddr: utils.Ipv6Addr2.AddressWithPrefix.Address, - icmpBuf: func(t *testing.T) []byte { - data := [8]byte{1, 2, 3, 4, 5, 6, 7, 8} - hdr := header.ICMPv6(make([]byte, header.ICMPv6MinimumSize+len(data))) - hdr.SetType(header.ICMPv6EchoRequest) - if n := copy(hdr.Payload(), data[:]); n != len(data) { - t.Fatalf("copied %d bytes but expected to copy %d bytes", n, len(data)) - } - return hdr - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - stackOpts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4, icmp.NewProtocol6}, - } - - host1Stack, _ := setupStack(t, stackOpts, host1NICID, host2NICID) - - var wq waiter.Queue - we, waiterCH := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.ReadableEvents) - ep, err := host1Stack.NewEndpoint(test.transProto, test.netProto, &wq) - if err != nil { - t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", test.transProto, test.netProto, err) - } - defer ep.Close() - - icmpBuf := test.icmpBuf(t) - var r bytes.Reader - r.Reset(icmpBuf) - wOpts := tcpip.WriteOptions{To: &tcpip.FullAddress{Addr: test.remoteAddr}} - if n, err := ep.Write(&r, wOpts); err != nil { - t.Fatalf("ep.Write(_, _): %s", err) - } else if want := int64(len(icmpBuf)); n != want { - t.Fatalf("got ep.Write(_, _) = (%d, _), want = (%d, _)", n, want) - } - - // Wait for the endpoint to be readable. - <-waiterCH - - var buf bytes.Buffer - opts := tcpip.ReadOptions{NeedRemoteAddr: true} - res, err := ep.Read(&buf, opts) - if err != nil { - t.Fatalf("ep.Read(_, %d, %#v): %s", len(icmpBuf), opts, err) - } - if diff := cmp.Diff(tcpip.ReadResult{ - Count: buf.Len(), - Total: buf.Len(), - RemoteAddr: tcpip.FullAddress{Addr: test.remoteAddr}, - }, res, checker.IgnoreCmpPath( - "ControlMessages", - "RemoteAddr.NIC", - "RemoteAddr.Port", - )); diff != "" { - t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff) - } - if diff := cmp.Diff(buf.Bytes()[icmpDataOffset:], icmpBuf[icmpDataOffset:]); diff != "" { - t.Errorf("received data mismatch (-want +got):\n%s", diff) - } - }) - } -} - -type transportError struct { - origin tcpip.SockErrOrigin - typ uint8 - code uint8 - info uint32 - kind stack.TransportErrorKind -} - -func TestTCPLinkResolutionFailure(t *testing.T) { - const ( - host1NICID = 1 - host2NICID = 4 - ) - - tests := []struct { - name string - netProto tcpip.NetworkProtocolNumber - remoteAddr tcpip.Address - expectedWriteErr tcpip.Error - sockError tcpip.SockError - transErr transportError - }{ - { - name: "IPv4 with resolvable remote", - netProto: ipv4.ProtocolNumber, - remoteAddr: utils.Ipv4Addr2.AddressWithPrefix.Address, - expectedWriteErr: nil, - }, - { - name: "IPv6 with resolvable remote", - netProto: ipv6.ProtocolNumber, - remoteAddr: utils.Ipv6Addr2.AddressWithPrefix.Address, - expectedWriteErr: nil, - }, - { - name: "IPv4 without resolvable remote", - netProto: ipv4.ProtocolNumber, - remoteAddr: utils.Ipv4Addr3.AddressWithPrefix.Address, - expectedWriteErr: &tcpip.ErrNoRoute{}, - sockError: tcpip.SockError{ - Err: &tcpip.ErrNoRoute{}, - Dst: tcpip.FullAddress{ - NIC: host1NICID, - Addr: utils.Ipv4Addr3.AddressWithPrefix.Address, - Port: 1234, - }, - Offender: tcpip.FullAddress{ - NIC: host1NICID, - Addr: utils.Ipv4Addr1.AddressWithPrefix.Address, - }, - NetProto: ipv4.ProtocolNumber, - }, - transErr: transportError{ - origin: tcpip.SockExtErrorOriginICMP, - typ: uint8(header.ICMPv4DstUnreachable), - code: uint8(header.ICMPv4HostUnreachable), - kind: stack.DestinationHostUnreachableTransportError, - }, - }, - { - name: "IPv6 without resolvable remote", - netProto: ipv6.ProtocolNumber, - remoteAddr: utils.Ipv6Addr3.AddressWithPrefix.Address, - expectedWriteErr: &tcpip.ErrNoRoute{}, - sockError: tcpip.SockError{ - Err: &tcpip.ErrNoRoute{}, - Dst: tcpip.FullAddress{ - NIC: host1NICID, - Addr: utils.Ipv6Addr3.AddressWithPrefix.Address, - Port: 1234, - }, - Offender: tcpip.FullAddress{ - NIC: host1NICID, - Addr: utils.Ipv6Addr1.AddressWithPrefix.Address, - }, - NetProto: ipv6.ProtocolNumber, - }, - transErr: transportError{ - origin: tcpip.SockExtErrorOriginICMP6, - typ: uint8(header.ICMPv6DstUnreachable), - code: uint8(header.ICMPv6AddressUnreachable), - kind: stack.DestinationHostUnreachableTransportError, - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - clock := faketime.NewManualClock() - stackOpts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, - Clock: clock, - } - - host1Stack, host2Stack := setupStack(t, stackOpts, host1NICID, host2NICID) - - var listenerWQ waiter.Queue - listenerEP, err := host2Stack.NewEndpoint(tcp.ProtocolNumber, test.netProto, &listenerWQ) - if err != nil { - t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, test.netProto, err) - } - defer listenerEP.Close() - - listenerAddr := tcpip.FullAddress{Port: 1234} - if err := listenerEP.Bind(listenerAddr); err != nil { - t.Fatalf("listenerEP.Bind(%#v): %s", listenerAddr, err) - } - - if err := listenerEP.Listen(1); err != nil { - t.Fatalf("listenerEP.Listen(1): %s", err) - } - - var clientWQ waiter.Queue - we, ch := waiter.NewChannelEntry(nil) - clientWQ.EventRegister(&we, waiter.WritableEvents|waiter.EventErr) - clientEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, test.netProto, &clientWQ) - if err != nil { - t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, test.netProto, err) - } - defer clientEP.Close() - - sockOpts := clientEP.SocketOptions() - sockOpts.SetRecvError(true) - - remoteAddr := listenerAddr - remoteAddr.Addr = test.remoteAddr - { - err := clientEP.Connect(remoteAddr) - if _, ok := err.(*tcpip.ErrConnectStarted); !ok { - t.Fatalf("got clientEP.Connect(%#v) = %s, want = %s", remoteAddr, err, &tcpip.ErrConnectStarted{}) - } - } - - // Wait for an error due to link resolution failing, or the endpoint to be - // writable. - if test.expectedWriteErr != nil { - nudConfigs, err := host1Stack.NUDConfigurations(host1NICID, test.netProto) - if err != nil { - t.Fatalf("host1Stack.NUDConfigurations(%d, %d): %s", host1NICID, test.netProto, err) - } - clock.Advance(time.Duration(nudConfigs.MaxMulticastProbes) * nudConfigs.RetransmitTimer) - } else { - clock.RunImmediatelyScheduledJobs() - } - <-ch - - { - var r bytes.Reader - r.Reset([]byte{0}) - var wOpts tcpip.WriteOptions - _, err := clientEP.Write(&r, wOpts) - if diff := cmp.Diff(test.expectedWriteErr, err); diff != "" { - t.Errorf("unexpected error from clientEP.Write(_, %#v), (-want, +got):\n%s", wOpts, diff) - } - } - - if test.expectedWriteErr == nil { - return - } - - sockErr := sockOpts.DequeueErr() - if sockErr == nil { - t.Fatalf("got sockOpts.DequeueErr() = nil, want = non-nil") - } - - sockErrCmpOpts := []cmp.Option{ - cmpopts.IgnoreUnexported(tcpip.SockError{}), - cmp.Comparer(func(a, b tcpip.Error) bool { - // tcpip.Error holds an unexported field but the errors netstack uses - // are pre defined so we can simply compare pointers. - return a == b - }), - checker.IgnoreCmpPath( - // Ignore the payload since we do not know the TCP seq/ack numbers. - "Payload", - // Ignore the cause since we will compare its properties separately - // since the concrete type of the cause is unknown. - "Cause", - ), - } - - if addr, err := clientEP.GetLocalAddress(); err != nil { - t.Fatalf("clientEP.GetLocalAddress(): %s", err) - } else { - test.sockError.Offender.Port = addr.Port - } - if diff := cmp.Diff(&test.sockError, sockErr, sockErrCmpOpts...); diff != "" { - t.Errorf("socket error mismatch (-want +got):\n%s", diff) - } - - transErr, ok := sockErr.Cause.(stack.TransportError) - if !ok { - t.Fatalf("socket error cause is not a transport error; cause = %#v", sockErr.Cause) - } - if diff := cmp.Diff( - test.transErr, - transportError{ - origin: transErr.Origin(), - typ: transErr.Type(), - code: transErr.Code(), - info: transErr.Info(), - kind: transErr.Kind(), - }, - cmp.AllowUnexported(transportError{}), - ); diff != "" { - t.Errorf("socket error mismatch (-want +got):\n%s", diff) - } - }) - } -} - -func TestForwardingWithLinkResolutionFailure(t *testing.T) { - const ( - incomingNICID = 1 - outgoingNICID = 2 - ttl = 2 - expectedHostUnreachableErrorCount = 1 - ) - outgoingLinkAddr := tcptestutil.MustParseLink("02:03:03:04:05:06") - - rxICMPv4EchoRequest := func(e *channel.Endpoint, src, dst tcpip.Address) { - utils.RxICMPv4EchoRequest(e, src, dst, ttl) - } - - rxICMPv6EchoRequest := func(e *channel.Endpoint, src, dst tcpip.Address) { - utils.RxICMPv6EchoRequest(e, src, dst, ttl) - } - - arpChecker := func(t *testing.T, request channel.PacketInfo, src, dst tcpip.Address) { - if request.Proto != arp.ProtocolNumber { - t.Errorf("got request.Proto = %d, want = %d", request.Proto, arp.ProtocolNumber) - } - if request.Route.RemoteLinkAddress != header.EthernetBroadcastAddress { - t.Errorf("got request.Route.RemoteLinkAddress = %s, want = %s", request.Route.RemoteLinkAddress, header.EthernetBroadcastAddress) - } - rep := header.ARP(request.Pkt.NetworkHeader().View()) - if got := rep.Op(); got != header.ARPRequest { - t.Errorf("got Op() = %d, want = %d", got, header.ARPRequest) - } - if got := tcpip.LinkAddress(rep.HardwareAddressSender()); got != outgoingLinkAddr { - t.Errorf("got HardwareAddressSender = %s, want = %s", got, outgoingLinkAddr) - } - if got := tcpip.Address(rep.ProtocolAddressSender()); got != src { - t.Errorf("got ProtocolAddressSender = %s, want = %s", got, src) - } - if got := tcpip.Address(rep.ProtocolAddressTarget()); got != dst { - t.Errorf("got ProtocolAddressTarget = %s, want = %s", got, dst) - } - } - - ndpChecker := func(t *testing.T, request channel.PacketInfo, src, dst tcpip.Address) { - if request.Proto != header.IPv6ProtocolNumber { - t.Fatalf("got Proto = %d, want = %d", request.Proto, header.IPv6ProtocolNumber) - } - - snmc := header.SolicitedNodeAddr(dst) - if want := header.EthernetAddressFromMulticastIPv6Address(snmc); request.Route.RemoteLinkAddress != want { - t.Errorf("got remote link address = %s, want = %s", request.Route.RemoteLinkAddress, want) - } - - checker.IPv6(t, stack.PayloadSince(request.Pkt.NetworkHeader()), - checker.SrcAddr(src), - checker.DstAddr(snmc), - checker.TTL(header.NDPHopLimit), - checker.NDPNS( - checker.NDPNSTargetAddress(dst), - )) - } - - icmpv4Checker := func(t *testing.T, b []byte, src, dst tcpip.Address) { - checker.IPv4(t, b, - checker.SrcAddr(src), - checker.DstAddr(dst), - checker.TTL(ipv4.DefaultTTL), - checker.ICMPv4( - checker.ICMPv4Checksum(), - checker.ICMPv4Type(header.ICMPv4DstUnreachable), - checker.ICMPv4Code(header.ICMPv4HostUnreachable), - ), - ) - } - - icmpv6Checker := func(t *testing.T, b []byte, src, dst tcpip.Address) { - checker.IPv6(t, b, - checker.SrcAddr(src), - checker.DstAddr(dst), - checker.TTL(ipv6.DefaultTTL), - checker.ICMPv6( - checker.ICMPv6Type(header.ICMPv6DstUnreachable), - checker.ICMPv6Code(header.ICMPv6AddressUnreachable), - ), - ) - } - - tests := []struct { - name string - networkProtocolFactory []stack.NetworkProtocolFactory - networkProtocolNumber tcpip.NetworkProtocolNumber - sourceAddr tcpip.Address - destAddr tcpip.Address - incomingAddr tcpip.AddressWithPrefix - outgoingAddr tcpip.AddressWithPrefix - transportProtocol func(*stack.Stack) stack.TransportProtocol - rx func(*channel.Endpoint, tcpip.Address, tcpip.Address) - linkResolutionRequestChecker func(*testing.T, channel.PacketInfo, tcpip.Address, tcpip.Address) - icmpReplyChecker func(*testing.T, []byte, tcpip.Address, tcpip.Address) - mtu uint32 - }{ - { - name: "IPv4 Host unreachable", - networkProtocolFactory: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol}, - networkProtocolNumber: header.IPv4ProtocolNumber, - sourceAddr: tcptestutil.MustParse4("10.0.0.2"), - destAddr: tcptestutil.MustParse4("11.0.0.2"), - incomingAddr: tcpip.AddressWithPrefix{ - Address: tcpip.Address(net.ParseIP("10.0.0.1").To4()), - PrefixLen: 8, - }, - outgoingAddr: tcpip.AddressWithPrefix{ - Address: tcpip.Address(net.ParseIP("11.0.0.1").To4()), - PrefixLen: 8, - }, - transportProtocol: icmp.NewProtocol4, - linkResolutionRequestChecker: arpChecker, - icmpReplyChecker: icmpv4Checker, - rx: rxICMPv4EchoRequest, - mtu: ipv4.MaxTotalSize, - }, - { - name: "IPv6 Host unreachable", - networkProtocolFactory: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, - networkProtocolNumber: header.IPv6ProtocolNumber, - sourceAddr: tcptestutil.MustParse6("10::2"), - destAddr: tcptestutil.MustParse6("11::2"), - incomingAddr: tcpip.AddressWithPrefix{ - Address: tcpip.Address(net.ParseIP("10::1").To16()), - PrefixLen: 64, - }, - outgoingAddr: tcpip.AddressWithPrefix{ - Address: tcpip.Address(net.ParseIP("11::1").To16()), - PrefixLen: 64, - }, - transportProtocol: icmp.NewProtocol6, - linkResolutionRequestChecker: ndpChecker, - icmpReplyChecker: icmpv6Checker, - rx: rxICMPv6EchoRequest, - mtu: header.IPv6MinimumMTU, - }, - } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - clock := faketime.NewManualClock() - - s := stack.New(stack.Options{ - NetworkProtocols: test.networkProtocolFactory, - TransportProtocols: []stack.TransportProtocolFactory{test.transportProtocol}, - Clock: clock, - }) - - // Set up endpoint through which we will receive packets. - incomingEndpoint := channel.New(1, test.mtu, "") - if err := s.CreateNIC(incomingNICID, incomingEndpoint); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", incomingNICID, err) - } - incomingProtoAddr := tcpip.ProtocolAddress{ - Protocol: test.networkProtocolNumber, - AddressWithPrefix: test.incomingAddr, - } - if err := s.AddProtocolAddress(incomingNICID, incomingProtoAddr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", incomingNICID, incomingProtoAddr, err) - } - - // Set up endpoint through which we will attempt to forward packets. - outgoingEndpoint := channel.New(1, test.mtu, outgoingLinkAddr) - outgoingEndpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired - if err := s.CreateNIC(outgoingNICID, outgoingEndpoint); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", outgoingNICID, err) - } - outgoingProtoAddr := tcpip.ProtocolAddress{ - Protocol: test.networkProtocolNumber, - AddressWithPrefix: test.outgoingAddr, - } - if err := s.AddProtocolAddress(outgoingNICID, outgoingProtoAddr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", outgoingNICID, outgoingProtoAddr, err) - } - - s.SetRouteTable([]tcpip.Route{ - { - Destination: test.incomingAddr.Subnet(), - NIC: incomingNICID, - }, - { - Destination: test.outgoingAddr.Subnet(), - NIC: outgoingNICID, - }, - }) - - if err := s.SetForwardingDefaultAndAllNICs(test.networkProtocolNumber, true); err != nil { - t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", test.networkProtocolNumber, err) - } - - test.rx(incomingEndpoint, test.sourceAddr, test.destAddr) - - nudConfigs, err := s.NUDConfigurations(outgoingNICID, test.networkProtocolNumber) - if err != nil { - t.Fatalf("s.NUDConfigurations(%d, %d): %s", outgoingNICID, test.networkProtocolNumber, err) - } - // Trigger the first packet on the endpoint. - clock.RunImmediatelyScheduledJobs() - - for i := 0; i < int(nudConfigs.MaxMulticastProbes); i++ { - request, ok := outgoingEndpoint.Read() - if !ok { - t.Fatal("expected ARP packet through outgoing NIC") - } - - test.linkResolutionRequestChecker(t, request, test.outgoingAddr.Address, test.destAddr) - - // Advance the clock the span of one request timeout. - clock.Advance(nudConfigs.RetransmitTimer) - } - - // Next, we make a blocking read to retrieve the error packet. This is - // necessary because outgoing packets are dequeued asynchronously when - // link resolution fails, and this dequeue is what triggers the ICMP - // error. - reply, ok := incomingEndpoint.Read() - if !ok { - t.Fatal("expected ICMP packet through incoming NIC") - } - - test.icmpReplyChecker(t, stack.PayloadSince(reply.Pkt.NetworkHeader()), test.incomingAddr.Address, test.sourceAddr) - - // Since link resolution failed, we don't expect the packet to be - // forwarded. - forwardedPacket, ok := outgoingEndpoint.Read() - if ok { - t.Fatalf("expected no ICMP Echo packet through outgoing NIC, instead found: %#v", forwardedPacket) - } - - if got, want := s.Stats().IP.Forwarding.HostUnreachable.Value(), expectedHostUnreachableErrorCount; int(got) != want { - t.Errorf("got rt.Stats().IP.Forwarding.HostUnreachable.Value() = %d, want = %d", got, want) - } - }) - } -} - -func TestGetLinkAddress(t *testing.T) { - const ( - host1NICID = 1 - host2NICID = 4 - ) - - tests := []struct { - name string - netProto tcpip.NetworkProtocolNumber - remoteAddr, localAddr tcpip.Address - expectedErr tcpip.Error - }{ - { - name: "IPv4 resolvable", - netProto: ipv4.ProtocolNumber, - remoteAddr: utils.Ipv4Addr2.AddressWithPrefix.Address, - expectedErr: nil, - }, - { - name: "IPv6 resolvable", - netProto: ipv6.ProtocolNumber, - remoteAddr: utils.Ipv6Addr2.AddressWithPrefix.Address, - expectedErr: nil, - }, - { - name: "IPv4 not resolvable", - netProto: ipv4.ProtocolNumber, - remoteAddr: utils.Ipv4Addr3.AddressWithPrefix.Address, - expectedErr: &tcpip.ErrTimeout{}, - }, - { - name: "IPv6 not resolvable", - netProto: ipv6.ProtocolNumber, - remoteAddr: utils.Ipv6Addr3.AddressWithPrefix.Address, - expectedErr: &tcpip.ErrTimeout{}, - }, - { - name: "IPv4 bad local address", - netProto: ipv4.ProtocolNumber, - remoteAddr: utils.Ipv4Addr2.AddressWithPrefix.Address, - localAddr: utils.Ipv4Addr2.AddressWithPrefix.Address, - expectedErr: &tcpip.ErrBadLocalAddress{}, - }, - { - name: "IPv6 bad local address", - netProto: ipv6.ProtocolNumber, - remoteAddr: utils.Ipv6Addr2.AddressWithPrefix.Address, - localAddr: utils.Ipv6Addr2.AddressWithPrefix.Address, - expectedErr: &tcpip.ErrBadLocalAddress{}, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - clock := faketime.NewManualClock() - stackOpts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, - Clock: clock, - } - - host1Stack, _ := setupStack(t, stackOpts, host1NICID, host2NICID) - - ch := make(chan stack.LinkResolutionResult, 1) - err := host1Stack.GetLinkAddress(host1NICID, test.remoteAddr, test.localAddr, test.netProto, func(r stack.LinkResolutionResult) { - ch <- r - }) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Fatalf("got host1Stack.GetLinkAddress(%d, %s, '', %d, _) = %s, want = %s", host1NICID, test.remoteAddr, test.netProto, err, &tcpip.ErrWouldBlock{}) - } - wantRes := stack.LinkResolutionResult{Err: test.expectedErr} - if test.expectedErr == nil { - wantRes.LinkAddress = utils.LinkAddr2 - } - - nudConfigs, err := host1Stack.NUDConfigurations(host1NICID, test.netProto) - if err != nil { - t.Fatalf("host1Stack.NUDConfigurations(%d, %d): %s", host1NICID, test.netProto, err) - } - - clock.Advance(time.Duration(nudConfigs.MaxMulticastProbes) * nudConfigs.RetransmitTimer) - select { - case got := <-ch: - if diff := cmp.Diff(wantRes, got); diff != "" { - t.Fatalf("link resolution result mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("event didn't arrive") - } - }) - } -} - -func TestRouteResolvedFields(t *testing.T) { - const ( - host1NICID = 1 - host2NICID = 4 - ) - - tests := []struct { - name string - netProto tcpip.NetworkProtocolNumber - localAddr tcpip.Address - remoteAddr tcpip.Address - immediatelyResolvable bool - expectedErr tcpip.Error - expectedLinkAddr tcpip.LinkAddress - }{ - { - name: "IPv4 immediately resolvable", - netProto: ipv4.ProtocolNumber, - localAddr: utils.Ipv4Addr1.AddressWithPrefix.Address, - remoteAddr: header.IPv4AllSystems, - immediatelyResolvable: true, - expectedErr: nil, - expectedLinkAddr: header.EthernetAddressFromMulticastIPv4Address(header.IPv4AllSystems), - }, - { - name: "IPv6 immediately resolvable", - netProto: ipv6.ProtocolNumber, - localAddr: utils.Ipv6Addr1.AddressWithPrefix.Address, - remoteAddr: header.IPv6AllNodesMulticastAddress, - immediatelyResolvable: true, - expectedErr: nil, - expectedLinkAddr: header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllNodesMulticastAddress), - }, - { - name: "IPv4 resolvable", - netProto: ipv4.ProtocolNumber, - localAddr: utils.Ipv4Addr1.AddressWithPrefix.Address, - remoteAddr: utils.Ipv4Addr2.AddressWithPrefix.Address, - immediatelyResolvable: false, - expectedErr: nil, - expectedLinkAddr: utils.LinkAddr2, - }, - { - name: "IPv6 resolvable", - netProto: ipv6.ProtocolNumber, - localAddr: utils.Ipv6Addr1.AddressWithPrefix.Address, - remoteAddr: utils.Ipv6Addr2.AddressWithPrefix.Address, - immediatelyResolvable: false, - expectedErr: nil, - expectedLinkAddr: utils.LinkAddr2, - }, - { - name: "IPv4 not resolvable", - netProto: ipv4.ProtocolNumber, - localAddr: utils.Ipv4Addr1.AddressWithPrefix.Address, - remoteAddr: utils.Ipv4Addr3.AddressWithPrefix.Address, - immediatelyResolvable: false, - expectedErr: &tcpip.ErrTimeout{}, - }, - { - name: "IPv6 not resolvable", - netProto: ipv6.ProtocolNumber, - localAddr: utils.Ipv6Addr1.AddressWithPrefix.Address, - remoteAddr: utils.Ipv6Addr3.AddressWithPrefix.Address, - immediatelyResolvable: false, - expectedErr: &tcpip.ErrTimeout{}, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - clock := faketime.NewManualClock() - stackOpts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, - Clock: clock, - } - - host1Stack, _ := setupStack(t, stackOpts, host1NICID, host2NICID) - r, err := host1Stack.FindRoute(host1NICID, test.localAddr, test.remoteAddr, test.netProto, false /* multicastLoop */) - if err != nil { - t.Fatalf("host1Stack.FindRoute(%d, %s, %s, %d, false): %s", host1NICID, test.localAddr, test.remoteAddr, test.netProto, err) - } - defer r.Release() - - var wantRouteInfo stack.RouteInfo - wantRouteInfo.LocalLinkAddress = utils.LinkAddr1 - wantRouteInfo.LocalAddress = test.localAddr - wantRouteInfo.RemoteAddress = test.remoteAddr - wantRouteInfo.NetProto = test.netProto - wantRouteInfo.Loop = stack.PacketOut - wantRouteInfo.RemoteLinkAddress = test.expectedLinkAddr - - ch := make(chan stack.ResolvedFieldsResult, 1) - - if !test.immediatelyResolvable { - wantUnresolvedRouteInfo := wantRouteInfo - wantUnresolvedRouteInfo.RemoteLinkAddress = "" - - err := r.ResolvedFields(func(r stack.ResolvedFieldsResult) { - ch <- r - }) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Errorf("got r.ResolvedFields(_) = %s, want = %s", err, &tcpip.ErrWouldBlock{}) - } - - nudConfigs, err := host1Stack.NUDConfigurations(host1NICID, test.netProto) - if err != nil { - t.Fatalf("host1Stack.NUDConfigurations(%d, %d): %s", host1NICID, test.netProto, err) - } - clock.Advance(time.Duration(nudConfigs.MaxMulticastProbes) * nudConfigs.RetransmitTimer) - - select { - case got := <-ch: - if diff := cmp.Diff(stack.ResolvedFieldsResult{RouteInfo: wantRouteInfo, Err: test.expectedErr}, got, cmp.AllowUnexported(stack.RouteInfo{})); diff != "" { - t.Errorf("route resolve result mismatch (-want +got):\n%s", diff) - } - default: - t.Fatalf("event didn't arrive") - } - - if test.expectedErr != nil { - return - } - - // At this point the neighbor table should be populated so the route - // should be immediately resolvable. - } - - if err := r.ResolvedFields(func(r stack.ResolvedFieldsResult) { - ch <- r - }); err != nil { - t.Errorf("r.ResolvedFields(_): %s", err) - } - select { - case routeResolveRes := <-ch: - if diff := cmp.Diff(stack.ResolvedFieldsResult{RouteInfo: wantRouteInfo, Err: nil}, routeResolveRes, cmp.AllowUnexported(stack.RouteInfo{})); diff != "" { - t.Errorf("route resolve result from resolved route mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected route to be immediately resolvable") - } - }) - } -} - -func TestWritePacketsLinkResolution(t *testing.T) { - const ( - host1NICID = 1 - host2NICID = 4 - ) - - tests := []struct { - name string - netProto tcpip.NetworkProtocolNumber - remoteAddr tcpip.Address - expectedWriteErr tcpip.Error - }{ - { - name: "IPv4", - netProto: ipv4.ProtocolNumber, - remoteAddr: utils.Ipv4Addr2.AddressWithPrefix.Address, - expectedWriteErr: nil, - }, - { - name: "IPv6", - netProto: ipv6.ProtocolNumber, - remoteAddr: utils.Ipv6Addr2.AddressWithPrefix.Address, - expectedWriteErr: nil, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - stackOpts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, - } - - host1Stack, host2Stack := setupStack(t, stackOpts, host1NICID, host2NICID) - - var serverWQ waiter.Queue - serverWE, serverCH := waiter.NewChannelEntry(nil) - serverWQ.EventRegister(&serverWE, waiter.ReadableEvents) - serverEP, err := host2Stack.NewEndpoint(udp.ProtocolNumber, test.netProto, &serverWQ) - if err != nil { - t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, test.netProto, err) - } - defer serverEP.Close() - - serverAddr := tcpip.FullAddress{Port: 1234} - if err := serverEP.Bind(serverAddr); err != nil { - t.Fatalf("serverEP.Bind(%#v): %s", serverAddr, err) - } - - r, err := host1Stack.FindRoute(host1NICID, "", test.remoteAddr, test.netProto, false /* multicastLoop */) - if err != nil { - t.Fatalf("host1Stack.FindRoute(%d, '', %s, %d, false): %s", host1NICID, test.remoteAddr, test.netProto, err) - } - defer r.Release() - - data := []byte{1, 2} - var pkts stack.PacketBufferList - for _, d := range data { - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: header.UDPMinimumSize + int(r.MaxHeaderLength()), - Data: buffer.View([]byte{d}).ToVectorisedView(), - }) - pkt.TransportProtocolNumber = udp.ProtocolNumber - length := uint16(pkt.Size()) - udpHdr := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize)) - udpHdr.Encode(&header.UDPFields{ - SrcPort: 5555, - DstPort: serverAddr.Port, - Length: length, - }) - xsum := r.PseudoHeaderChecksum(udp.ProtocolNumber, length) - xsum = header.ChecksumCombine(xsum, pkt.Data().AsRange().Checksum()) - udpHdr.SetChecksum(^udpHdr.CalculateChecksum(xsum)) - - pkts.PushBack(pkt) - } - - params := stack.NetworkHeaderParams{ - Protocol: udp.ProtocolNumber, - TTL: 64, - TOS: stack.DefaultTOS, - } - - if n, err := r.WritePackets(pkts, params); err != nil { - t.Fatalf("r.WritePackets(_, %#v): %s", params, err) - } else if want := pkts.Len(); want != n { - t.Fatalf("got r.WritePackets(_, %#v) = %d, want = %d", params, n, want) - } - - var writer bytes.Buffer - count := 0 - for { - var rOpts tcpip.ReadOptions - res, err := serverEP.Read(&writer, rOpts) - if err != nil { - if _, ok := err.(*tcpip.ErrWouldBlock); ok { - // Should not have anymore bytes to read after we read the sent - // number of bytes. - if count == len(data) { - break - } - - <-serverCH - continue - } - - t.Fatalf("serverEP.Read(_, %#v): %s", rOpts, err) - } - count += res.Count - } - - if got, want := host2Stack.Stats().UDP.PacketsReceived.Value(), uint64(len(data)); got != want { - t.Errorf("got host2Stack.Stats().UDP.PacketsReceived.Value() = %d, want = %d", got, want) - } - if diff := cmp.Diff(data, writer.Bytes()); diff != "" { - t.Errorf("read bytes mismatch (-want +got):\n%s", diff) - } - }) - } -} - -type eventType int - -const ( - entryAdded eventType = iota - entryChanged - entryRemoved -) - -func (t eventType) String() string { - switch t { - case entryAdded: - return "add" - case entryChanged: - return "change" - case entryRemoved: - return "remove" - default: - return fmt.Sprintf("unknown (%d)", t) - } -} - -type eventInfo struct { - eventType eventType - nicID tcpip.NICID - entry stack.NeighborEntry -} - -func (e eventInfo) String() string { - return fmt.Sprintf("%s event for NIC #%d, %#v", e.eventType, e.nicID, e.entry) -} - -var _ stack.NUDDispatcher = (*nudDispatcher)(nil) - -type nudDispatcher struct { - c chan eventInfo -} - -func (d *nudDispatcher) OnNeighborAdded(nicID tcpip.NICID, entry stack.NeighborEntry) { - e := eventInfo{ - eventType: entryAdded, - nicID: nicID, - entry: entry, - } - d.c <- e -} - -func (d *nudDispatcher) OnNeighborChanged(nicID tcpip.NICID, entry stack.NeighborEntry) { - e := eventInfo{ - eventType: entryChanged, - nicID: nicID, - entry: entry, - } - d.c <- e -} - -func (d *nudDispatcher) OnNeighborRemoved(nicID tcpip.NICID, entry stack.NeighborEntry) { - e := eventInfo{ - eventType: entryRemoved, - nicID: nicID, - entry: entry, - } - d.c <- e -} - -func (d *nudDispatcher) expectEvent(want eventInfo) error { - select { - case got := <-d.c: - if diff := cmp.Diff(want, got, cmp.AllowUnexported(eventInfo{}), cmpopts.IgnoreFields(stack.NeighborEntry{}, "UpdatedAt")); diff != "" { - return fmt.Errorf("got invalid event (-want +got):\n%s", diff) - } - return nil - default: - return fmt.Errorf("event didn't arrive") - } -} - -// TestTCPConfirmNeighborReachability tests that TCP informs layers beneath it -// that the neighbor used for a route is reachable. -func TestTCPConfirmNeighborReachability(t *testing.T) { - tests := []struct { - name string - netProto tcpip.NetworkProtocolNumber - remoteAddr tcpip.Address - neighborAddr tcpip.Address - getEndpoints func(*testing.T, *stack.Stack, *stack.Stack, *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{}) - isHost1Listener bool - }{ - { - name: "IPv4 active connection through neighbor", - netProto: ipv4.ProtocolNumber, - remoteAddr: utils.Host2IPv4Addr.AddressWithPrefix.Address, - neighborAddr: utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address, - getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{}) { - var listenerWQ waiter.Queue - listenerWE, listenerCH := waiter.NewChannelEntry(nil) - listenerWQ.EventRegister(&listenerWE, waiter.EventIn) - listenerEP, err := host2Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &listenerWQ) - if err != nil { - t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err) - } - t.Cleanup(listenerEP.Close) - - var clientWQ waiter.Queue - clientWE, clientCH := waiter.NewChannelEntry(nil) - clientWQ.EventRegister(&clientWE, waiter.ReadableEvents|waiter.WritableEvents) - clientEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &clientWQ) - if err != nil { - t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err) - } - - return listenerEP, listenerCH, clientEP, clientCH - }, - }, - { - name: "IPv6 active connection through neighbor", - netProto: ipv6.ProtocolNumber, - remoteAddr: utils.Host2IPv6Addr.AddressWithPrefix.Address, - neighborAddr: utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address, - getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{}) { - var listenerWQ waiter.Queue - listenerWE, listenerCH := waiter.NewChannelEntry(nil) - listenerWQ.EventRegister(&listenerWE, waiter.EventIn) - listenerEP, err := host2Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &listenerWQ) - if err != nil { - t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err) - } - t.Cleanup(listenerEP.Close) - - var clientWQ waiter.Queue - clientWE, clientCH := waiter.NewChannelEntry(nil) - clientWQ.EventRegister(&clientWE, waiter.ReadableEvents|waiter.WritableEvents) - clientEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &clientWQ) - if err != nil { - t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err) - } - - return listenerEP, listenerCH, clientEP, clientCH - }, - }, - { - name: "IPv4 active connection to neighbor", - netProto: ipv4.ProtocolNumber, - remoteAddr: utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address, - neighborAddr: utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address, - getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{}) { - var listenerWQ waiter.Queue - listenerWE, listenerCH := waiter.NewChannelEntry(nil) - listenerWQ.EventRegister(&listenerWE, waiter.EventIn) - listenerEP, err := routerStack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &listenerWQ) - if err != nil { - t.Fatalf("routerStack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err) - } - t.Cleanup(listenerEP.Close) - - var clientWQ waiter.Queue - clientWE, clientCH := waiter.NewChannelEntry(nil) - clientWQ.EventRegister(&clientWE, waiter.ReadableEvents|waiter.WritableEvents) - clientEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &clientWQ) - if err != nil { - t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err) - } - - return listenerEP, listenerCH, clientEP, clientCH - }, - }, - { - name: "IPv6 active connection to neighbor", - netProto: ipv6.ProtocolNumber, - remoteAddr: utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address, - neighborAddr: utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address, - getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{}) { - var listenerWQ waiter.Queue - listenerWE, listenerCH := waiter.NewChannelEntry(nil) - listenerWQ.EventRegister(&listenerWE, waiter.EventIn) - listenerEP, err := routerStack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &listenerWQ) - if err != nil { - t.Fatalf("routerStack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err) - } - t.Cleanup(listenerEP.Close) - - var clientWQ waiter.Queue - clientWE, clientCH := waiter.NewChannelEntry(nil) - clientWQ.EventRegister(&clientWE, waiter.ReadableEvents|waiter.WritableEvents) - clientEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &clientWQ) - if err != nil { - t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err) - } - - return listenerEP, listenerCH, clientEP, clientCH - }, - }, - { - name: "IPv4 passive connection to neighbor", - netProto: ipv4.ProtocolNumber, - remoteAddr: utils.Host1IPv4Addr.AddressWithPrefix.Address, - neighborAddr: utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address, - getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{}) { - var listenerWQ waiter.Queue - listenerWE, listenerCH := waiter.NewChannelEntry(nil) - listenerWQ.EventRegister(&listenerWE, waiter.EventIn) - listenerEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &listenerWQ) - if err != nil { - t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err) - } - t.Cleanup(listenerEP.Close) - - var clientWQ waiter.Queue - clientWE, clientCH := waiter.NewChannelEntry(nil) - clientWQ.EventRegister(&clientWE, waiter.ReadableEvents|waiter.WritableEvents) - clientEP, err := routerStack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &clientWQ) - if err != nil { - t.Fatalf("routerStack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err) - } - - return listenerEP, listenerCH, clientEP, clientCH - }, - isHost1Listener: true, - }, - { - name: "IPv6 passive connection to neighbor", - netProto: ipv6.ProtocolNumber, - remoteAddr: utils.Host1IPv6Addr.AddressWithPrefix.Address, - neighborAddr: utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address, - getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{}) { - var listenerWQ waiter.Queue - listenerWE, listenerCH := waiter.NewChannelEntry(nil) - listenerWQ.EventRegister(&listenerWE, waiter.EventIn) - listenerEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &listenerWQ) - if err != nil { - t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err) - } - t.Cleanup(listenerEP.Close) - - var clientWQ waiter.Queue - clientWE, clientCH := waiter.NewChannelEntry(nil) - clientWQ.EventRegister(&clientWE, waiter.ReadableEvents|waiter.WritableEvents) - clientEP, err := routerStack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &clientWQ) - if err != nil { - t.Fatalf("routerStack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err) - } - - return listenerEP, listenerCH, clientEP, clientCH - }, - isHost1Listener: true, - }, - { - name: "IPv4 passive connection through neighbor", - netProto: ipv4.ProtocolNumber, - remoteAddr: utils.Host1IPv4Addr.AddressWithPrefix.Address, - neighborAddr: utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address, - getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{}) { - var listenerWQ waiter.Queue - listenerWE, listenerCH := waiter.NewChannelEntry(nil) - listenerWQ.EventRegister(&listenerWE, waiter.EventIn) - listenerEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &listenerWQ) - if err != nil { - t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err) - } - t.Cleanup(listenerEP.Close) - - var clientWQ waiter.Queue - clientWE, clientCH := waiter.NewChannelEntry(nil) - clientWQ.EventRegister(&clientWE, waiter.ReadableEvents|waiter.WritableEvents) - clientEP, err := host2Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &clientWQ) - if err != nil { - t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err) - } - - return listenerEP, listenerCH, clientEP, clientCH - }, - isHost1Listener: true, - }, - { - name: "IPv6 passive connection through neighbor", - netProto: ipv6.ProtocolNumber, - remoteAddr: utils.Host1IPv6Addr.AddressWithPrefix.Address, - neighborAddr: utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address, - getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{}) { - var listenerWQ waiter.Queue - listenerWE, listenerCH := waiter.NewChannelEntry(nil) - listenerWQ.EventRegister(&listenerWE, waiter.EventIn) - listenerEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &listenerWQ) - if err != nil { - t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err) - } - t.Cleanup(listenerEP.Close) - - var clientWQ waiter.Queue - clientWE, clientCH := waiter.NewChannelEntry(nil) - clientWQ.EventRegister(&clientWE, waiter.ReadableEvents|waiter.WritableEvents) - clientEP, err := host2Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &clientWQ) - if err != nil { - t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err) - } - - return listenerEP, listenerCH, clientEP, clientCH - }, - isHost1Listener: true, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - clock := faketime.NewManualClock() - nudDisp := nudDispatcher{ - c: make(chan eventInfo, 3), - } - stackOpts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, - Clock: clock, - } - host1StackOpts := stackOpts - host1StackOpts.NUDDisp = &nudDisp - - host1Stack := stack.New(host1StackOpts) - routerStack := stack.New(stackOpts) - host2Stack := stack.New(stackOpts) - utils.SetupRoutedStacks(t, host1Stack, routerStack, host2Stack) - - // Add a reachable dynamic entry to our neighbor table for the remote. - { - ch := make(chan stack.LinkResolutionResult, 1) - err := host1Stack.GetLinkAddress(utils.Host1NICID, test.neighborAddr, "", test.netProto, func(r stack.LinkResolutionResult) { - ch <- r - }) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Fatalf("got host1Stack.GetLinkAddress(%d, %s, '', %d, _) = %s, want = %s", utils.Host1NICID, test.neighborAddr, test.netProto, err, &tcpip.ErrWouldBlock{}) - } - if diff := cmp.Diff(stack.LinkResolutionResult{LinkAddress: utils.LinkAddr2, Err: nil}, <-ch); diff != "" { - t.Fatalf("link resolution mismatch (-want +got):\n%s", diff) - } - } - if err := nudDisp.expectEvent(eventInfo{ - eventType: entryAdded, - nicID: utils.Host1NICID, - entry: stack.NeighborEntry{State: stack.Incomplete, Addr: test.neighborAddr}, - }); err != nil { - t.Fatalf("error waiting for initial NUD event: %s", err) - } - if err := nudDisp.expectEvent(eventInfo{ - eventType: entryChanged, - nicID: utils.Host1NICID, - entry: stack.NeighborEntry{State: stack.Reachable, Addr: test.neighborAddr, LinkAddr: utils.LinkAddr2}, - }); err != nil { - t.Fatalf("error waiting for reachable NUD event: %s", err) - } - - // Wait for the remote's neighbor entry to be stale before creating a - // TCP connection from host1 to some remote. - nudConfigs, err := host1Stack.NUDConfigurations(utils.Host1NICID, test.netProto) - if err != nil { - t.Fatalf("host1Stack.NUDConfigurations(%d, %d): %s", utils.Host1NICID, test.netProto, err) - } - // The maximum reachable time for a neighbor is some maximum random factor - // applied to the base reachable time. - // - // See NUDConfigurations.BaseReachableTime for more information. - maxReachableTime := time.Duration(float32(nudConfigs.BaseReachableTime) * nudConfigs.MaxRandomFactor) - clock.Advance(maxReachableTime) - if err := nudDisp.expectEvent(eventInfo{ - eventType: entryChanged, - nicID: utils.Host1NICID, - entry: stack.NeighborEntry{State: stack.Stale, Addr: test.neighborAddr, LinkAddr: utils.LinkAddr2}, - }); err != nil { - t.Fatalf("error waiting for stale NUD event: %s", err) - } - - listenerEP, listenerCH, clientEP, clientCH := test.getEndpoints(t, host1Stack, routerStack, host2Stack) - defer clientEP.Close() - listenerAddr := tcpip.FullAddress{Addr: test.remoteAddr, Port: 1234} - if err := listenerEP.Bind(listenerAddr); err != nil { - t.Fatalf("listenerEP.Bind(%#v): %s", listenerAddr, err) - } - if err := listenerEP.Listen(1); err != nil { - t.Fatalf("listenerEP.Listen(1): %s", err) - } - { - err := clientEP.Connect(listenerAddr) - if _, ok := err.(*tcpip.ErrConnectStarted); !ok { - t.Fatalf("got clientEP.Connect(%#v) = %s, want = %s", listenerAddr, err, &tcpip.ErrConnectStarted{}) - } - } - - // Wait for the TCP handshake to complete then make sure the neighbor is - // reachable without entering the probe state as TCP should provide NUD - // with confirmation that the neighbor is reachable (indicated by a - // successful 3-way handshake). - <-clientCH - if err := nudDisp.expectEvent(eventInfo{ - eventType: entryChanged, - nicID: utils.Host1NICID, - entry: stack.NeighborEntry{State: stack.Delay, Addr: test.neighborAddr, LinkAddr: utils.LinkAddr2}, - }); err != nil { - t.Fatalf("error waiting for delay NUD event: %s", err) - } - <-listenerCH - if err := nudDisp.expectEvent(eventInfo{ - eventType: entryChanged, - nicID: utils.Host1NICID, - entry: stack.NeighborEntry{State: stack.Reachable, Addr: test.neighborAddr, LinkAddr: utils.LinkAddr2}, - }); err != nil { - t.Fatalf("error waiting for reachable NUD event: %s", err) - } - - peerEP, peerWQ, err := listenerEP.Accept(nil) - if err != nil { - t.Fatalf("listenerEP.Accept(): %s", err) - } - defer peerEP.Close() - peerWE, peerCH := waiter.NewChannelEntry(nil) - peerWQ.EventRegister(&peerWE, waiter.ReadableEvents) - - // Wait for the neighbor to be stale again then send data to the remote. - // - // On successful transmission, the neighbor should become reachable - // without probing the neighbor as a TCP ACK would be received which is an - // indication of the neighbor being reachable. - clock.Advance(maxReachableTime) - if err := nudDisp.expectEvent(eventInfo{ - eventType: entryChanged, - nicID: utils.Host1NICID, - entry: stack.NeighborEntry{State: stack.Stale, Addr: test.neighborAddr, LinkAddr: utils.LinkAddr2}, - }); err != nil { - t.Fatalf("error waiting for stale NUD event: %s", err) - } - { - var r bytes.Reader - r.Reset([]byte{0}) - var wOpts tcpip.WriteOptions - if _, err := clientEP.Write(&r, wOpts); err != nil { - t.Errorf("clientEP.Write(_, %#v): %s", wOpts, err) - } - } - // Heads up, there is a race here. - // - // Incoming TCP segments are handled in - // tcp.(*endpoint).handleSegmentLocked: - // - // - tcp.(*endpoint).rcv.handleRcvdSegment puts the segment on the - // segment queue and notifies waiting readers (such as this channel) - // - // - tcp.(*endpoint).snd.handleRcvdSegment sends an ACK for the segment - // and notifies the NUD machinery that the peer is reachable - // - // Thus we must permit a delay between the readable signal and the - // expected NUD event. - // - // At the time of writing, this race is reliably hit with gotsan. - <-peerCH - for len(nudDisp.c) == 0 { - runtime.Gosched() - } - if err := nudDisp.expectEvent(eventInfo{ - eventType: entryChanged, - nicID: utils.Host1NICID, - entry: stack.NeighborEntry{State: stack.Delay, Addr: test.neighborAddr, LinkAddr: utils.LinkAddr2}, - }); err != nil { - t.Fatalf("error waiting for delay NUD event: %s", err) - } - if test.isHost1Listener { - // If host1 is not the client, host1 does not send any data so TCP - // has no way to know it is making forward progress. Because of this, - // TCP should not mark the route reachable and NUD should go through the - // probe state. - clock.Advance(nudConfigs.DelayFirstProbeTime) - if err := nudDisp.expectEvent(eventInfo{ - eventType: entryChanged, - nicID: utils.Host1NICID, - entry: stack.NeighborEntry{State: stack.Probe, Addr: test.neighborAddr, LinkAddr: utils.LinkAddr2}, - }); err != nil { - t.Fatalf("error waiting for probe NUD event: %s", err) - } - } - { - var r bytes.Reader - r.Reset([]byte{0}) - var wOpts tcpip.WriteOptions - if _, err := peerEP.Write(&r, wOpts); err != nil { - t.Errorf("peerEP.Write(_, %#v): %s", wOpts, err) - } - } - <-clientCH - if err := nudDisp.expectEvent(eventInfo{ - eventType: entryChanged, - nicID: utils.Host1NICID, - entry: stack.NeighborEntry{State: stack.Reachable, Addr: test.neighborAddr, LinkAddr: utils.LinkAddr2}, - }); err != nil { - t.Fatalf("error waiting for reachable NUD event: %s", err) - } - }) - } -} - -func TestDAD(t *testing.T) { - dadConfigs := stack.DADConfigurations{ - DupAddrDetectTransmits: 1, - RetransmitTimer: time.Second, - } - - tests := []struct { - name string - netProto tcpip.NetworkProtocolNumber - dadNetProto tcpip.NetworkProtocolNumber - remoteAddr tcpip.Address - expectedResult stack.DADResult - }{ - { - name: "IPv4 own address", - netProto: ipv4.ProtocolNumber, - dadNetProto: arp.ProtocolNumber, - remoteAddr: utils.Ipv4Addr1.AddressWithPrefix.Address, - expectedResult: &stack.DADSucceeded{}, - }, - { - name: "IPv6 own address", - netProto: ipv6.ProtocolNumber, - dadNetProto: ipv6.ProtocolNumber, - remoteAddr: utils.Ipv6Addr1.AddressWithPrefix.Address, - expectedResult: &stack.DADSucceeded{}, - }, - { - name: "IPv4 duplicate address", - netProto: ipv4.ProtocolNumber, - dadNetProto: arp.ProtocolNumber, - remoteAddr: utils.Ipv4Addr2.AddressWithPrefix.Address, - expectedResult: &stack.DADDupAddrDetected{HolderLinkAddress: utils.LinkAddr2}, - }, - { - name: "IPv6 duplicate address", - netProto: ipv6.ProtocolNumber, - dadNetProto: ipv6.ProtocolNumber, - remoteAddr: utils.Ipv6Addr2.AddressWithPrefix.Address, - expectedResult: &stack.DADDupAddrDetected{HolderLinkAddress: utils.LinkAddr2}, - }, - { - name: "IPv4 no duplicate address", - netProto: ipv4.ProtocolNumber, - dadNetProto: arp.ProtocolNumber, - remoteAddr: utils.Ipv4Addr3.AddressWithPrefix.Address, - expectedResult: &stack.DADSucceeded{}, - }, - { - name: "IPv6 no duplicate address", - netProto: ipv6.ProtocolNumber, - dadNetProto: ipv6.ProtocolNumber, - remoteAddr: utils.Ipv6Addr3.AddressWithPrefix.Address, - expectedResult: &stack.DADSucceeded{}, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - clock := faketime.NewManualClock() - stackOpts := stack.Options{ - Clock: clock, - NetworkProtocols: []stack.NetworkProtocolFactory{ - arp.NewProtocol, - ipv4.NewProtocol, - ipv6.NewProtocol, - }, - } - - host1Stack, _ := setupStack(t, stackOpts, utils.Host1NICID, utils.Host2NICID) - - // DAD should be disabled by default. - if res, err := host1Stack.CheckDuplicateAddress(utils.Host1NICID, test.netProto, test.remoteAddr, func(r stack.DADResult) { - t.Errorf("unexpectedly called DAD completion handler when DAD was supposed to be disabled") - }); err != nil { - t.Fatalf("host1Stack.CheckDuplicateAddress(%d, %d, %s, _): %s", utils.Host1NICID, test.netProto, test.remoteAddr, err) - } else if res != stack.DADDisabled { - t.Errorf("got host1Stack.CheckDuplicateAddress(%d, %d, %s, _) = %d, want = %d", utils.Host1NICID, test.netProto, test.remoteAddr, res, stack.DADDisabled) - } - - // Enable DAD then attempt to check if an address is duplicated. - netEP, err := host1Stack.GetNetworkEndpoint(utils.Host1NICID, test.dadNetProto) - if err != nil { - t.Fatalf("host1Stack.GetNetworkEndpoint(%d, %d): %s", utils.Host1NICID, test.dadNetProto, err) - } - dad, ok := netEP.(stack.DuplicateAddressDetector) - if !ok { - t.Fatalf("expected %T to implement stack.DuplicateAddressDetector", netEP) - } - dad.SetDADConfigurations(dadConfigs) - ch := make(chan stack.DADResult, 3) - if res, err := host1Stack.CheckDuplicateAddress(utils.Host1NICID, test.netProto, test.remoteAddr, func(r stack.DADResult) { - ch <- r - }); err != nil { - t.Fatalf("host1Stack.CheckDuplicateAddress(%d, %d, %s, _): %s", utils.Host1NICID, test.netProto, test.remoteAddr, err) - } else if res != stack.DADStarting { - t.Errorf("got host1Stack.CheckDuplicateAddress(%d, %d, %s, _) = %d, want = %d", utils.Host1NICID, test.netProto, test.remoteAddr, res, stack.DADStarting) - } - - expectResults := 1 - if _, ok := test.expectedResult.(*stack.DADSucceeded); ok { - const delta = time.Nanosecond - clock.Advance(time.Duration(dadConfigs.DupAddrDetectTransmits)*dadConfigs.RetransmitTimer - delta) - select { - case r := <-ch: - t.Fatalf("unexpectedly got DAD result before the DAD timeout; r = %#v", r) - default: - } - - // If we expect the resolve to succeed try requesting DAD again on the - // same address. The handler for the new request should be called once - // the original DAD request completes. - expectResults = 2 - if res, err := host1Stack.CheckDuplicateAddress(utils.Host1NICID, test.netProto, test.remoteAddr, func(r stack.DADResult) { - ch <- r - }); err != nil { - t.Fatalf("host1Stack.CheckDuplicateAddress(%d, %d, %s, _): %s", utils.Host1NICID, test.netProto, test.remoteAddr, err) - } else if res != stack.DADAlreadyRunning { - t.Errorf("got host1Stack.CheckDuplicateAddress(%d, %d, %s, _) = %d, want = %d", utils.Host1NICID, test.netProto, test.remoteAddr, res, stack.DADAlreadyRunning) - } - - clock.Advance(delta) - } - - for i := 0; i < expectResults; i++ { - if diff := cmp.Diff(test.expectedResult, <-ch); diff != "" { - t.Errorf("(i=%d) DAD result mismatch (-want +got):\n%s", i, diff) - } - } - - // Should have no more results. - select { - case r := <-ch: - t.Errorf("unexpectedly got an extra DAD result; r = %#v", r) - default: - } - }) - } -} diff --git a/pkg/tcpip/tests/integration/loopback_test.go b/pkg/tcpip/tests/integration/loopback_test.go deleted file mode 100644 index b2008f0b2..000000000 --- a/pkg/tcpip/tests/integration/loopback_test.go +++ /dev/null @@ -1,763 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package loopback_test - -import ( - "bytes" - "testing" - "time" - - "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/checker" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/link/loopback" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/tests/utils" - "gvisor.dev/gvisor/pkg/tcpip/testutil" - "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" - "gvisor.dev/gvisor/pkg/tcpip/transport/udp" - "gvisor.dev/gvisor/pkg/waiter" -) - -var _ ipv6.NDPDispatcher = (*ndpDispatcher)(nil) - -type ndpDispatcher struct{} - -func (*ndpDispatcher) OnDuplicateAddressDetectionResult(tcpip.NICID, tcpip.Address, stack.DADResult) { -} - -func (*ndpDispatcher) OnOffLinkRouteUpdated(tcpip.NICID, tcpip.Subnet, tcpip.Address, header.NDPRoutePreference) { -} - -func (*ndpDispatcher) OnOffLinkRouteInvalidated(tcpip.NICID, tcpip.Subnet, tcpip.Address) {} - -func (*ndpDispatcher) OnOnLinkPrefixDiscovered(tcpip.NICID, tcpip.Subnet) { -} - -func (*ndpDispatcher) OnOnLinkPrefixInvalidated(tcpip.NICID, tcpip.Subnet) {} - -func (*ndpDispatcher) OnAutoGenAddress(tcpip.NICID, tcpip.AddressWithPrefix) { -} - -func (*ndpDispatcher) OnAutoGenAddressDeprecated(tcpip.NICID, tcpip.AddressWithPrefix) {} - -func (*ndpDispatcher) OnAutoGenAddressInvalidated(tcpip.NICID, tcpip.AddressWithPrefix) {} - -func (*ndpDispatcher) OnRecursiveDNSServerOption(tcpip.NICID, []tcpip.Address, time.Duration) {} - -func (*ndpDispatcher) OnDNSSearchListOption(tcpip.NICID, []string, time.Duration) {} - -func (*ndpDispatcher) OnDHCPv6Configuration(tcpip.NICID, ipv6.DHCPv6ConfigurationFromNDPRA) {} - -// TestInitialLoopbackAddresses tests that the loopback interface does not -// auto-generate a link-local address when it is brought up. -func TestInitialLoopbackAddresses(t *testing.T) { - const nicID = 1 - - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPDisp: &ndpDispatcher{}, - AutoGenLinkLocal: true, - OpaqueIIDOpts: ipv6.OpaqueInterfaceIdentifierOptions{ - NICNameFromID: func(nicID tcpip.NICID, nicName string) string { - t.Fatalf("should not attempt to get name for NIC with ID = %d; nicName = %s", nicID, nicName) - return "" - }, - }, - })}, - }) - - if err := s.CreateNIC(nicID, loopback.New()); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID, err) - } - - nicsInfo := s.NICInfo() - if nicInfo, ok := nicsInfo[nicID]; !ok { - t.Fatalf("did not find NIC with ID = %d in s.NICInfo() = %#v", nicID, nicsInfo) - } else if got := len(nicInfo.ProtocolAddresses); got != 0 { - t.Fatalf("got len(nicInfo.ProtocolAddresses) = %d, want = 0; nicInfo.ProtocolAddresses = %#v", got, nicInfo.ProtocolAddresses) - } -} - -// TestLoopbackAcceptAllInSubnetUDP tests that a loopback interface considers -// itself bound to all addresses in the subnet of an assigned address and UDP -// traffic is sent/received correctly. -func TestLoopbackAcceptAllInSubnetUDP(t *testing.T) { - const ( - nicID = 1 - localPort = 80 - ) - - data := []byte{1, 2, 3, 4} - - ipv4ProtocolAddress := tcpip.ProtocolAddress{ - Protocol: header.IPv4ProtocolNumber, - AddressWithPrefix: utils.Ipv4Addr, - } - ipv4Bytes := []byte(ipv4ProtocolAddress.AddressWithPrefix.Address) - ipv4Bytes[len(ipv4Bytes)-1]++ - otherIPv4Address := tcpip.Address(ipv4Bytes) - - ipv6ProtocolAddress := tcpip.ProtocolAddress{ - Protocol: header.IPv6ProtocolNumber, - AddressWithPrefix: utils.Ipv6Addr, - } - ipv6Bytes := []byte(utils.Ipv6Addr.Address) - ipv6Bytes[len(ipv6Bytes)-1]++ - otherIPv6Address := tcpip.Address(ipv6Bytes) - - tests := []struct { - name string - addAddress tcpip.ProtocolAddress - bindAddr tcpip.Address - dstAddr tcpip.Address - expectRx bool - }{ - { - name: "IPv4 bind to wildcard and send to assigned address", - addAddress: ipv4ProtocolAddress, - dstAddr: ipv4ProtocolAddress.AddressWithPrefix.Address, - expectRx: true, - }, - { - name: "IPv4 bind to wildcard and send to other subnet-local address", - addAddress: ipv4ProtocolAddress, - dstAddr: otherIPv4Address, - expectRx: true, - }, - { - name: "IPv4 bind to wildcard send to other address", - addAddress: ipv4ProtocolAddress, - dstAddr: utils.RemoteIPv4Addr, - expectRx: false, - }, - { - name: "IPv4 bind to other subnet-local address and send to assigned address", - addAddress: ipv4ProtocolAddress, - bindAddr: otherIPv4Address, - dstAddr: ipv4ProtocolAddress.AddressWithPrefix.Address, - expectRx: false, - }, - { - name: "IPv4 bind and send to other subnet-local address", - addAddress: ipv4ProtocolAddress, - bindAddr: otherIPv4Address, - dstAddr: otherIPv4Address, - expectRx: true, - }, - { - name: "IPv4 bind to assigned address and send to other subnet-local address", - addAddress: ipv4ProtocolAddress, - bindAddr: ipv4ProtocolAddress.AddressWithPrefix.Address, - dstAddr: otherIPv4Address, - expectRx: false, - }, - - { - name: "IPv6 bind and send to assigned address", - addAddress: ipv6ProtocolAddress, - bindAddr: utils.Ipv6Addr.Address, - dstAddr: utils.Ipv6Addr.Address, - expectRx: true, - }, - { - name: "IPv6 bind to wildcard and send to other subnet-local address", - addAddress: ipv6ProtocolAddress, - dstAddr: otherIPv6Address, - expectRx: false, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, - }) - if err := s.CreateNIC(nicID, loopback.New()); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID, err) - } - if err := s.AddProtocolAddress(nicID, test.addAddress); err != nil { - t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, test.addAddress, err) - } - s.SetRouteTable([]tcpip.Route{ - { - Destination: header.IPv4EmptySubnet, - NIC: nicID, - }, - { - Destination: header.IPv6EmptySubnet, - NIC: nicID, - }, - }) - - var wq waiter.Queue - rep, err := s.NewEndpoint(udp.ProtocolNumber, test.addAddress.Protocol, &wq) - if err != nil { - t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, test.addAddress.Protocol, err) - } - defer rep.Close() - - bindAddr := tcpip.FullAddress{Addr: test.bindAddr, Port: localPort} - if err := rep.Bind(bindAddr); err != nil { - t.Fatalf("rep.Bind(%+v): %s", bindAddr, err) - } - - sep, err := s.NewEndpoint(udp.ProtocolNumber, test.addAddress.Protocol, &wq) - if err != nil { - t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, test.addAddress.Protocol, err) - } - defer sep.Close() - - wopts := tcpip.WriteOptions{ - To: &tcpip.FullAddress{ - Addr: test.dstAddr, - Port: localPort, - }, - } - var r bytes.Reader - r.Reset(data) - n, err := sep.Write(&r, wopts) - if err != nil { - t.Fatalf("sep.Write(_, _): %s", err) - } - if want := int64(len(data)); n != want { - t.Fatalf("got sep.Write(_, _) = (%d, nil), want = (%d, nil)", n, want) - } - - var buf bytes.Buffer - opts := tcpip.ReadOptions{NeedRemoteAddr: true} - if res, err := rep.Read(&buf, opts); test.expectRx { - if err != nil { - t.Fatalf("rep.Read(_, %#v): %s", opts, err) - } - if diff := cmp.Diff(tcpip.ReadResult{ - Count: buf.Len(), - Total: buf.Len(), - RemoteAddr: tcpip.FullAddress{ - Addr: test.addAddress.AddressWithPrefix.Address, - }, - }, res, - checker.IgnoreCmpPath("ControlMessages", "RemoteAddr.NIC", "RemoteAddr.Port"), - ); diff != "" { - t.Errorf("rep.Read: unexpected result (-want +got):\n%s", diff) - } - if diff := cmp.Diff(data, buf.Bytes()); diff != "" { - t.Errorf("got UDP payload mismatch (-want +got):\n%s", diff) - } - } else if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Fatalf("got rep.Read = (%v, %s) [with data %x], want = (_, %s)", res, err, buf.Bytes(), &tcpip.ErrWouldBlock{}) - } - }) - } -} - -// TestLoopbackSubnetLifetimeBoundToAddr tests that the lifetime of an address -// in a loopback interface's associated subnet is bound to the permanently bound -// address. -func TestLoopbackSubnetLifetimeBoundToAddr(t *testing.T) { - const nicID = 1 - - protoAddr := tcpip.ProtocolAddress{ - Protocol: ipv4.ProtocolNumber, - AddressWithPrefix: utils.Ipv4Addr, - } - addrBytes := []byte(utils.Ipv4Addr.Address) - addrBytes[len(addrBytes)-1]++ - otherAddr := tcpip.Address(addrBytes) - - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, - }) - if err := s.CreateNIC(nicID, loopback.New()); err != nil { - t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) - } - if err := s.AddProtocolAddress(nicID, protoAddr); err != nil { - t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID, protoAddr, err) - } - s.SetRouteTable([]tcpip.Route{ - { - Destination: header.IPv4EmptySubnet, - NIC: nicID, - }, - }) - - r, err := s.FindRoute(nicID, otherAddr, utils.RemoteIPv4Addr, ipv4.ProtocolNumber, false /* multicastLoop */) - if err != nil { - t.Fatalf("s.FindRoute(%d, %s, %s, %d, false): %s", nicID, otherAddr, utils.RemoteIPv4Addr, ipv4.ProtocolNumber, err) - } - defer r.Release() - - params := stack.NetworkHeaderParams{ - Protocol: 111, - TTL: 64, - TOS: stack.DefaultTOS, - } - data := buffer.View([]byte{1, 2, 3, 4}) - if err := r.WritePacket(params, stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(r.MaxHeaderLength()), - Data: data.ToVectorisedView(), - })); err != nil { - t.Fatalf("r.WritePacket(%#v, _): %s", params, err) - } - - // Removing the address should make the endpoint invalid. - if err := s.RemoveAddress(nicID, protoAddr.AddressWithPrefix.Address); err != nil { - t.Fatalf("s.RemoveAddress(%d, %s): %s", nicID, protoAddr.AddressWithPrefix.Address, err) - } - { - err := r.WritePacket(params, stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(r.MaxHeaderLength()), - Data: data.ToVectorisedView(), - })) - if _, ok := err.(*tcpip.ErrInvalidEndpointState); !ok { - t.Fatalf("got r.WritePacket(%#v, _) = %s, want = %s", params, err, &tcpip.ErrInvalidEndpointState{}) - } - } -} - -// TestLoopbackAcceptAllInSubnetTCP tests that a loopback interface considers -// itself bound to all addresses in the subnet of an assigned address and TCP -// traffic is sent/received correctly. -func TestLoopbackAcceptAllInSubnetTCP(t *testing.T) { - const ( - nicID = 1 - localPort = 80 - ) - - ipv4ProtocolAddress := tcpip.ProtocolAddress{ - Protocol: header.IPv4ProtocolNumber, - AddressWithPrefix: utils.Ipv4Addr, - } - ipv4ProtocolAddress.AddressWithPrefix.PrefixLen = 8 - ipv4Bytes := []byte(ipv4ProtocolAddress.AddressWithPrefix.Address) - ipv4Bytes[len(ipv4Bytes)-1]++ - otherIPv4Address := tcpip.Address(ipv4Bytes) - - ipv6ProtocolAddress := tcpip.ProtocolAddress{ - Protocol: header.IPv6ProtocolNumber, - AddressWithPrefix: utils.Ipv6Addr, - } - ipv6Bytes := []byte(utils.Ipv6Addr.Address) - ipv6Bytes[len(ipv6Bytes)-1]++ - otherIPv6Address := tcpip.Address(ipv6Bytes) - - tests := []struct { - name string - addAddress tcpip.ProtocolAddress - bindAddr tcpip.Address - dstAddr tcpip.Address - expectAccept bool - }{ - { - name: "IPv4 bind to wildcard and send to assigned address", - addAddress: ipv4ProtocolAddress, - dstAddr: ipv4ProtocolAddress.AddressWithPrefix.Address, - expectAccept: true, - }, - { - name: "IPv4 bind to wildcard and send to other subnet-local address", - addAddress: ipv4ProtocolAddress, - dstAddr: otherIPv4Address, - expectAccept: true, - }, - { - name: "IPv4 bind to wildcard send to other address", - addAddress: ipv4ProtocolAddress, - dstAddr: utils.RemoteIPv4Addr, - expectAccept: false, - }, - { - name: "IPv4 bind to other subnet-local address and send to assigned address", - addAddress: ipv4ProtocolAddress, - bindAddr: otherIPv4Address, - dstAddr: ipv4ProtocolAddress.AddressWithPrefix.Address, - expectAccept: false, - }, - { - name: "IPv4 bind and send to other subnet-local address", - addAddress: ipv4ProtocolAddress, - bindAddr: otherIPv4Address, - dstAddr: otherIPv4Address, - expectAccept: true, - }, - { - name: "IPv4 bind to assigned address and send to other subnet-local address", - addAddress: ipv4ProtocolAddress, - bindAddr: ipv4ProtocolAddress.AddressWithPrefix.Address, - dstAddr: otherIPv4Address, - expectAccept: false, - }, - - { - name: "IPv6 bind and send to assigned address", - addAddress: ipv6ProtocolAddress, - bindAddr: utils.Ipv6Addr.Address, - dstAddr: utils.Ipv6Addr.Address, - expectAccept: true, - }, - { - name: "IPv6 bind to wildcard and send to other subnet-local address", - addAddress: ipv6ProtocolAddress, - dstAddr: otherIPv6Address, - expectAccept: false, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, - }) - if err := s.CreateNIC(nicID, loopback.New()); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID, err) - } - if err := s.AddProtocolAddress(nicID, test.addAddress); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, test.addAddress, err) - } - s.SetRouteTable([]tcpip.Route{ - { - Destination: header.IPv4EmptySubnet, - NIC: nicID, - }, - { - Destination: header.IPv6EmptySubnet, - NIC: nicID, - }, - }) - - var wq waiter.Queue - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.ReadableEvents) - defer wq.EventUnregister(&we) - listeningEndpoint, err := s.NewEndpoint(tcp.ProtocolNumber, test.addAddress.Protocol, &wq) - if err != nil { - t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, test.addAddress.Protocol, err) - } - defer listeningEndpoint.Close() - - bindAddr := tcpip.FullAddress{Addr: test.bindAddr, Port: localPort} - if err := listeningEndpoint.Bind(bindAddr); err != nil { - t.Fatalf("listeningEndpoint.Bind(%#v): %s", bindAddr, err) - } - - if err := listeningEndpoint.Listen(1); err != nil { - t.Fatalf("listeningEndpoint.Listen(1): %s", err) - } - - connectingEndpoint, err := s.NewEndpoint(tcp.ProtocolNumber, test.addAddress.Protocol, &wq) - if err != nil { - t.Fatalf("s.NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, test.addAddress.Protocol, err) - } - defer connectingEndpoint.Close() - - connectAddr := tcpip.FullAddress{ - Addr: test.dstAddr, - Port: localPort, - } - { - err := connectingEndpoint.Connect(connectAddr) - if _, ok := err.(*tcpip.ErrConnectStarted); !ok { - t.Fatalf("connectingEndpoint.Connect(%#v): %s", connectAddr, err) - } - } - - if !test.expectAccept { - _, _, err := listeningEndpoint.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Fatalf("got listeningEndpoint.Accept(nil) = %s, want = %s", err, &tcpip.ErrWouldBlock{}) - } - return - } - - // Wait for the listening endpoint to be "readable". That is, wait for a - // new connection. - <-ch - var addr tcpip.FullAddress - if _, _, err := listeningEndpoint.Accept(&addr); err != nil { - t.Fatalf("listeningEndpoint.Accept(nil): %s", err) - } - if addr.Addr != test.addAddress.AddressWithPrefix.Address { - t.Errorf("got addr.Addr = %s, want = %s", addr.Addr, test.addAddress.AddressWithPrefix.Address) - } - }) - } -} - -func TestExternalLoopbackTraffic(t *testing.T) { - const ( - nicID1 = 1 - nicID2 = 2 - - numPackets = 1 - ttl = 64 - ) - ipv4Loopback := testutil.MustParse4("127.0.0.1") - - loopbackSourcedICMPv4 := func(e *channel.Endpoint) { - utils.RxICMPv4EchoRequest(e, ipv4Loopback, utils.Ipv4Addr.Address, ttl) - } - - loopbackSourcedICMPv6 := func(e *channel.Endpoint) { - utils.RxICMPv6EchoRequest(e, header.IPv6Loopback, utils.Ipv6Addr.Address, ttl) - } - - loopbackDestinedICMPv4 := func(e *channel.Endpoint) { - utils.RxICMPv4EchoRequest(e, utils.RemoteIPv4Addr, ipv4Loopback, ttl) - } - - loopbackDestinedICMPv6 := func(e *channel.Endpoint) { - utils.RxICMPv6EchoRequest(e, utils.RemoteIPv6Addr, header.IPv6Loopback, ttl) - } - - invalidSrcAddrStat := func(s tcpip.IPStats) *tcpip.StatCounter { - return s.InvalidSourceAddressesReceived - } - - invalidDestAddrStat := func(s tcpip.IPStats) *tcpip.StatCounter { - return s.InvalidDestinationAddressesReceived - } - - tests := []struct { - name string - allowExternalLoopback bool - forwarding bool - rxICMP func(*channel.Endpoint) - invalidAddressStat func(tcpip.IPStats) *tcpip.StatCounter - shouldAccept bool - }{ - { - name: "IPv4 external loopback sourced traffic without forwarding and drop external loopback disabled", - allowExternalLoopback: true, - forwarding: false, - rxICMP: loopbackSourcedICMPv4, - invalidAddressStat: invalidSrcAddrStat, - shouldAccept: true, - }, - { - name: "IPv4 external loopback sourced traffic without forwarding and drop external loopback enabled", - allowExternalLoopback: false, - forwarding: false, - rxICMP: loopbackSourcedICMPv4, - invalidAddressStat: invalidSrcAddrStat, - shouldAccept: false, - }, - { - name: "IPv4 external loopback sourced traffic with forwarding and drop external loopback disabled", - allowExternalLoopback: true, - forwarding: true, - rxICMP: loopbackSourcedICMPv4, - invalidAddressStat: invalidSrcAddrStat, - shouldAccept: true, - }, - { - name: "IPv4 external loopback sourced traffic with forwarding and drop external loopback enabled", - allowExternalLoopback: false, - forwarding: true, - rxICMP: loopbackSourcedICMPv4, - invalidAddressStat: invalidSrcAddrStat, - shouldAccept: false, - }, - { - name: "IPv4 external loopback destined traffic without forwarding and drop external loopback disabled", - allowExternalLoopback: true, - forwarding: false, - rxICMP: loopbackDestinedICMPv4, - invalidAddressStat: invalidDestAddrStat, - shouldAccept: false, - }, - { - name: "IPv4 external loopback destined traffic without forwarding and drop external loopback enabled", - allowExternalLoopback: false, - forwarding: false, - rxICMP: loopbackDestinedICMPv4, - invalidAddressStat: invalidDestAddrStat, - shouldAccept: false, - }, - { - name: "IPv4 external loopback destined traffic with forwarding and drop external loopback disabled", - allowExternalLoopback: true, - forwarding: true, - rxICMP: loopbackDestinedICMPv4, - invalidAddressStat: invalidDestAddrStat, - shouldAccept: true, - }, - { - name: "IPv4 external loopback destined traffic with forwarding and drop external loopback enabled", - allowExternalLoopback: false, - forwarding: true, - rxICMP: loopbackDestinedICMPv4, - invalidAddressStat: invalidDestAddrStat, - shouldAccept: false, - }, - - { - name: "IPv6 external loopback sourced traffic without forwarding and drop external loopback disabled", - allowExternalLoopback: true, - forwarding: false, - rxICMP: loopbackSourcedICMPv6, - invalidAddressStat: invalidSrcAddrStat, - shouldAccept: true, - }, - { - name: "IPv6 external loopback sourced traffic without forwarding and drop external loopback enabled", - allowExternalLoopback: false, - forwarding: false, - rxICMP: loopbackSourcedICMPv6, - invalidAddressStat: invalidSrcAddrStat, - shouldAccept: false, - }, - { - name: "IPv6 external loopback sourced traffic with forwarding and drop external loopback disabled", - allowExternalLoopback: true, - forwarding: true, - rxICMP: loopbackSourcedICMPv6, - invalidAddressStat: invalidSrcAddrStat, - shouldAccept: true, - }, - { - name: "IPv6 external loopback sourced traffic with forwarding and drop external loopback enabled", - allowExternalLoopback: false, - forwarding: true, - rxICMP: loopbackSourcedICMPv6, - invalidAddressStat: invalidSrcAddrStat, - shouldAccept: false, - }, - { - name: "IPv6 external loopback destined traffic without forwarding and drop external loopback disabled", - allowExternalLoopback: true, - forwarding: false, - rxICMP: loopbackDestinedICMPv6, - invalidAddressStat: invalidDestAddrStat, - shouldAccept: false, - }, - { - name: "IPv6 external loopback destined traffic without forwarding and drop external loopback enabled", - allowExternalLoopback: false, - forwarding: false, - rxICMP: loopbackDestinedICMPv6, - invalidAddressStat: invalidDestAddrStat, - shouldAccept: false, - }, - { - name: "IPv6 external loopback destined traffic with forwarding and drop external loopback disabled", - allowExternalLoopback: true, - forwarding: true, - rxICMP: loopbackDestinedICMPv6, - invalidAddressStat: invalidDestAddrStat, - shouldAccept: true, - }, - { - name: "IPv6 external loopback destined traffic with forwarding and drop external loopback enabled", - allowExternalLoopback: false, - forwarding: true, - rxICMP: loopbackDestinedICMPv6, - invalidAddressStat: invalidDestAddrStat, - shouldAccept: false, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ - ipv4.NewProtocolWithOptions(ipv4.Options{ - AllowExternalLoopbackTraffic: test.allowExternalLoopback, - }), - ipv6.NewProtocolWithOptions(ipv6.Options{ - AllowExternalLoopbackTraffic: test.allowExternalLoopback, - }), - }, - TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4, icmp.NewProtocol6}, - }) - e := channel.New(1, header.IPv6MinimumMTU, "") - if err := s.CreateNIC(nicID1, e); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) - } - if err := s.AddAddressWithPrefix(nicID1, ipv4.ProtocolNumber, utils.Ipv4Addr); err != nil { - t.Fatalf("AddAddressWithPrefix(%d, %d, %s): %s", nicID1, ipv4.ProtocolNumber, utils.Ipv4Addr, err) - } - if err := s.AddAddressWithPrefix(nicID1, ipv6.ProtocolNumber, utils.Ipv6Addr); err != nil { - t.Fatalf("AddAddressWithPrefix(%d, %d, %s): %s", nicID1, ipv6.ProtocolNumber, utils.Ipv6Addr, err) - } - - if err := s.CreateNIC(nicID2, loopback.New()); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID2, err) - } - if err := s.AddAddress(nicID2, ipv4.ProtocolNumber, ipv4Loopback); err != nil { - t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, ipv4.ProtocolNumber, ipv4Loopback, err) - } - if err := s.AddAddress(nicID2, ipv6.ProtocolNumber, header.IPv6Loopback); err != nil { - t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, ipv6.ProtocolNumber, header.IPv6Loopback, err) - } - - if test.forwarding { - if err := s.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil { - t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", ipv4.ProtocolNumber, err) - } - if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, true); err != nil { - t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", ipv6.ProtocolNumber, err) - } - } - - s.SetRouteTable([]tcpip.Route{ - tcpip.Route{ - Destination: header.IPv4EmptySubnet, - NIC: nicID1, - }, - tcpip.Route{ - Destination: header.IPv6EmptySubnet, - NIC: nicID1, - }, - tcpip.Route{ - Destination: ipv4Loopback.WithPrefix().Subnet(), - NIC: nicID2, - }, - tcpip.Route{ - Destination: header.IPv6Loopback.WithPrefix().Subnet(), - NIC: nicID2, - }, - }) - - stats := s.Stats().IP - invalidAddressStat := test.invalidAddressStat(stats) - deliveredPacketsStat := stats.PacketsDelivered - if got := invalidAddressStat.Value(); got != 0 { - t.Fatalf("got invalidAddressStat.Value() = %d, want = 0", got) - } - if got := deliveredPacketsStat.Value(); got != 0 { - t.Fatalf("got deliveredPacketsStat.Value() = %d, want = 0", got) - } - test.rxICMP(e) - var expectedInvalidPackets uint64 - if !test.shouldAccept { - expectedInvalidPackets = numPackets - } - if got := invalidAddressStat.Value(); got != expectedInvalidPackets { - t.Fatalf("got invalidAddressStat.Value() = %d, want = %d", got, expectedInvalidPackets) - } - if got, want := deliveredPacketsStat.Value(), numPackets-expectedInvalidPackets; got != want { - t.Fatalf("got deliveredPacketsStat.Value() = %d, want = %d", got, want) - } - }) - } -} diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go deleted file mode 100644 index 2d0a6e6a7..000000000 --- a/pkg/tcpip/tests/integration/multicast_broadcast_test.go +++ /dev/null @@ -1,723 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package multicast_broadcast_test - -import ( - "bytes" - "testing" - - "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/checker" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/link/loopback" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/tests/utils" - "gvisor.dev/gvisor/pkg/tcpip/testutil" - "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" - "gvisor.dev/gvisor/pkg/tcpip/transport/udp" - "gvisor.dev/gvisor/pkg/waiter" -) - -const ( - defaultMTU = 1280 - ttl = 255 -) - -// TestPingMulticastBroadcast tests that responding to an Echo Request destined -// to a multicast or broadcast address uses a unicast source address for the -// reply. -func TestPingMulticastBroadcast(t *testing.T) { - const ( - nicID = 1 - ttl = 64 - ) - - tests := []struct { - name string - protoNum tcpip.NetworkProtocolNumber - rxICMP func(*channel.Endpoint, tcpip.Address, tcpip.Address, uint8) - srcAddr tcpip.Address - dstAddr tcpip.Address - expectedSrc tcpip.Address - }{ - { - name: "IPv4 unicast", - protoNum: header.IPv4ProtocolNumber, - dstAddr: utils.Ipv4Addr.Address, - srcAddr: utils.RemoteIPv4Addr, - rxICMP: utils.RxICMPv4EchoRequest, - expectedSrc: utils.Ipv4Addr.Address, - }, - { - name: "IPv4 directed broadcast", - protoNum: header.IPv4ProtocolNumber, - rxICMP: utils.RxICMPv4EchoRequest, - srcAddr: utils.RemoteIPv4Addr, - dstAddr: utils.Ipv4SubnetBcast, - expectedSrc: utils.Ipv4Addr.Address, - }, - { - name: "IPv4 broadcast", - protoNum: header.IPv4ProtocolNumber, - rxICMP: utils.RxICMPv4EchoRequest, - srcAddr: utils.RemoteIPv4Addr, - dstAddr: header.IPv4Broadcast, - expectedSrc: utils.Ipv4Addr.Address, - }, - { - name: "IPv4 all-systems multicast", - protoNum: header.IPv4ProtocolNumber, - rxICMP: utils.RxICMPv4EchoRequest, - srcAddr: utils.RemoteIPv4Addr, - dstAddr: header.IPv4AllSystems, - expectedSrc: utils.Ipv4Addr.Address, - }, - { - name: "IPv6 unicast", - protoNum: header.IPv6ProtocolNumber, - rxICMP: utils.RxICMPv6EchoRequest, - srcAddr: utils.RemoteIPv6Addr, - dstAddr: utils.Ipv6Addr.Address, - expectedSrc: utils.Ipv6Addr.Address, - }, - { - name: "IPv6 all-nodes multicast", - protoNum: header.IPv6ProtocolNumber, - rxICMP: utils.RxICMPv6EchoRequest, - srcAddr: utils.RemoteIPv6Addr, - dstAddr: header.IPv6AllNodesMulticastAddress, - expectedSrc: utils.Ipv6Addr.Address, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4, icmp.NewProtocol6}, - }) - // We only expect a single packet in response to our ICMP Echo Request. - e := channel.New(1, defaultMTU, "") - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID, err) - } - ipv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: utils.Ipv4Addr} - if err := s.AddProtocolAddress(nicID, ipv4ProtoAddr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, ipv4ProtoAddr, err) - } - ipv6ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: utils.Ipv6Addr} - if err := s.AddProtocolAddress(nicID, ipv6ProtoAddr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, ipv6ProtoAddr, err) - } - - // Default routes for IPv4 and IPv6 so ICMP can find a route to the remote - // node when attempting to send the ICMP Echo Reply. - s.SetRouteTable([]tcpip.Route{ - { - Destination: header.IPv6EmptySubnet, - NIC: nicID, - }, - { - Destination: header.IPv4EmptySubnet, - NIC: nicID, - }, - }) - - test.rxICMP(e, test.srcAddr, test.dstAddr, ttl) - pkt, ok := e.Read() - if !ok { - t.Fatal("expected ICMP response") - } - - if pkt.Route.LocalAddress != test.expectedSrc { - t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", pkt.Route.LocalAddress, test.expectedSrc) - } - // The destination of the response packet should be the source of the - // original packet. - if pkt.Route.RemoteAddress != test.srcAddr { - t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, test.srcAddr) - } - - src, dst := s.NetworkProtocolInstance(test.protoNum).ParseAddresses(stack.PayloadSince(pkt.Pkt.NetworkHeader())) - if src != test.expectedSrc { - t.Errorf("got pkt source = %s, want = %s", src, test.expectedSrc) - } - // The destination of the response packet should be the source of the - // original packet. - if dst != test.srcAddr { - t.Errorf("got pkt destination = %s, want = %s", dst, test.srcAddr) - } - }) - } - -} - -func rxIPv4UDP(e *channel.Endpoint, src, dst tcpip.Address, data []byte) { - payloadLen := header.UDPMinimumSize + len(data) - totalLen := header.IPv4MinimumSize + payloadLen - hdr := buffer.NewPrependable(totalLen) - u := header.UDP(hdr.Prepend(payloadLen)) - u.Encode(&header.UDPFields{ - SrcPort: utils.RemotePort, - DstPort: utils.LocalPort, - Length: uint16(payloadLen), - }) - copy(u.Payload(), data) - sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, src, dst, uint16(payloadLen)) - sum = header.Checksum(data, sum) - u.SetChecksum(^u.CalculateChecksum(sum)) - - ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) - ip.Encode(&header.IPv4Fields{ - TotalLength: uint16(totalLen), - Protocol: uint8(udp.ProtocolNumber), - TTL: ttl, - SrcAddr: src, - DstAddr: dst, - }) - ip.SetChecksum(^ip.CalculateChecksum()) - - e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - })) -} - -func rxIPv6UDP(e *channel.Endpoint, src, dst tcpip.Address, data []byte) { - payloadLen := header.UDPMinimumSize + len(data) - hdr := buffer.NewPrependable(header.IPv6MinimumSize + payloadLen) - u := header.UDP(hdr.Prepend(payloadLen)) - u.Encode(&header.UDPFields{ - SrcPort: utils.RemotePort, - DstPort: utils.LocalPort, - Length: uint16(payloadLen), - }) - copy(u.Payload(), data) - sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, src, dst, uint16(payloadLen)) - sum = header.Checksum(data, sum) - u.SetChecksum(^u.CalculateChecksum(sum)) - - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLen), - TransportProtocol: udp.ProtocolNumber, - HopLimit: ttl, - SrcAddr: src, - DstAddr: dst, - }) - - e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - })) -} - -// TestIncomingMulticastAndBroadcast tests receiving a packet destined to some -// multicast or broadcast address. -func TestIncomingMulticastAndBroadcast(t *testing.T) { - const nicID = 1 - - data := []byte{1, 2, 3, 4} - - tests := []struct { - name string - proto tcpip.NetworkProtocolNumber - remoteAddr tcpip.Address - localAddr tcpip.AddressWithPrefix - rxUDP func(*channel.Endpoint, tcpip.Address, tcpip.Address, []byte) - bindAddr tcpip.Address - dstAddr tcpip.Address - expectRx bool - }{ - { - name: "IPv4 unicast binding to unicast", - proto: header.IPv4ProtocolNumber, - remoteAddr: utils.RemoteIPv4Addr, - localAddr: utils.Ipv4Addr, - rxUDP: rxIPv4UDP, - bindAddr: utils.Ipv4Addr.Address, - dstAddr: utils.Ipv4Addr.Address, - expectRx: true, - }, - { - name: "IPv4 unicast binding to broadcast", - proto: header.IPv4ProtocolNumber, - remoteAddr: utils.RemoteIPv4Addr, - localAddr: utils.Ipv4Addr, - rxUDP: rxIPv4UDP, - bindAddr: header.IPv4Broadcast, - dstAddr: utils.Ipv4Addr.Address, - expectRx: false, - }, - { - name: "IPv4 unicast binding to wildcard", - proto: header.IPv4ProtocolNumber, - remoteAddr: utils.RemoteIPv4Addr, - localAddr: utils.Ipv4Addr, - rxUDP: rxIPv4UDP, - dstAddr: utils.Ipv4Addr.Address, - expectRx: true, - }, - - { - name: "IPv4 directed broadcast binding to subnet broadcast", - proto: header.IPv4ProtocolNumber, - remoteAddr: utils.RemoteIPv4Addr, - localAddr: utils.Ipv4Addr, - rxUDP: rxIPv4UDP, - bindAddr: utils.Ipv4SubnetBcast, - dstAddr: utils.Ipv4SubnetBcast, - expectRx: true, - }, - { - name: "IPv4 directed broadcast binding to broadcast", - proto: header.IPv4ProtocolNumber, - remoteAddr: utils.RemoteIPv4Addr, - localAddr: utils.Ipv4Addr, - rxUDP: rxIPv4UDP, - bindAddr: header.IPv4Broadcast, - dstAddr: utils.Ipv4SubnetBcast, - expectRx: false, - }, - { - name: "IPv4 directed broadcast binding to wildcard", - proto: header.IPv4ProtocolNumber, - remoteAddr: utils.RemoteIPv4Addr, - localAddr: utils.Ipv4Addr, - rxUDP: rxIPv4UDP, - dstAddr: utils.Ipv4SubnetBcast, - expectRx: true, - }, - - { - name: "IPv4 broadcast binding to broadcast", - proto: header.IPv4ProtocolNumber, - remoteAddr: utils.RemoteIPv4Addr, - localAddr: utils.Ipv4Addr, - rxUDP: rxIPv4UDP, - bindAddr: header.IPv4Broadcast, - dstAddr: header.IPv4Broadcast, - expectRx: true, - }, - { - name: "IPv4 broadcast binding to subnet broadcast", - proto: header.IPv4ProtocolNumber, - remoteAddr: utils.RemoteIPv4Addr, - localAddr: utils.Ipv4Addr, - rxUDP: rxIPv4UDP, - bindAddr: utils.Ipv4SubnetBcast, - dstAddr: header.IPv4Broadcast, - expectRx: false, - }, - { - name: "IPv4 broadcast binding to wildcard", - proto: header.IPv4ProtocolNumber, - remoteAddr: utils.RemoteIPv4Addr, - localAddr: utils.Ipv4Addr, - rxUDP: rxIPv4UDP, - dstAddr: utils.Ipv4SubnetBcast, - expectRx: true, - }, - - { - name: "IPv4 all-systems multicast binding to all-systems multicast", - proto: header.IPv4ProtocolNumber, - remoteAddr: utils.RemoteIPv4Addr, - localAddr: utils.Ipv4Addr, - rxUDP: rxIPv4UDP, - bindAddr: header.IPv4AllSystems, - dstAddr: header.IPv4AllSystems, - expectRx: true, - }, - { - name: "IPv4 all-systems multicast binding to wildcard", - proto: header.IPv4ProtocolNumber, - remoteAddr: utils.RemoteIPv4Addr, - localAddr: utils.Ipv4Addr, - rxUDP: rxIPv4UDP, - dstAddr: header.IPv4AllSystems, - expectRx: true, - }, - { - name: "IPv4 all-systems multicast binding to unicast", - proto: header.IPv4ProtocolNumber, - remoteAddr: utils.RemoteIPv4Addr, - localAddr: utils.Ipv4Addr, - rxUDP: rxIPv4UDP, - bindAddr: utils.Ipv4Addr.Address, - dstAddr: header.IPv4AllSystems, - expectRx: false, - }, - - // IPv6 has no notion of a broadcast. - { - name: "IPv6 unicast binding to wildcard", - dstAddr: utils.Ipv6Addr.Address, - proto: header.IPv6ProtocolNumber, - remoteAddr: utils.RemoteIPv6Addr, - localAddr: utils.Ipv6Addr, - rxUDP: rxIPv6UDP, - expectRx: true, - }, - { - name: "IPv6 broadcast-like address binding to wildcard", - dstAddr: utils.Ipv6SubnetBcast, - proto: header.IPv6ProtocolNumber, - remoteAddr: utils.RemoteIPv6Addr, - localAddr: utils.Ipv6Addr, - rxUDP: rxIPv6UDP, - expectRx: false, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, - }) - e := channel.New(0, defaultMTU, "") - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID, err) - } - protoAddr := tcpip.ProtocolAddress{Protocol: test.proto, AddressWithPrefix: test.localAddr} - if err := s.AddProtocolAddress(nicID, protoAddr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, protoAddr, err) - } - - var wq waiter.Queue - ep, err := s.NewEndpoint(udp.ProtocolNumber, test.proto, &wq) - if err != nil { - t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, test.proto, err) - } - defer ep.Close() - - bindAddr := tcpip.FullAddress{Addr: test.bindAddr, Port: utils.LocalPort} - if err := ep.Bind(bindAddr); err != nil { - t.Fatalf("ep.Bind(%#v): %s", bindAddr, err) - } - - test.rxUDP(e, test.remoteAddr, test.dstAddr, data) - var buf bytes.Buffer - var opts tcpip.ReadOptions - if res, err := ep.Read(&buf, opts); test.expectRx { - if err != nil { - t.Fatalf("ep.Read(_, %#v): %s", opts, err) - } - if diff := cmp.Diff(tcpip.ReadResult{ - Count: buf.Len(), - Total: buf.Len(), - }, res, checker.IgnoreCmpPath("ControlMessages")); diff != "" { - t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff) - } - if diff := cmp.Diff(data, buf.Bytes()); diff != "" { - t.Errorf("got UDP payload mismatch (-want +got):\n%s", diff) - } - } else if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Fatalf("got Read = (%v, %s) [with data %x], want = (_, %s)", res, err, buf.Bytes(), &tcpip.ErrWouldBlock{}) - } - }) - } -} - -// TestReuseAddrAndBroadcast makes sure broadcast packets are received by all -// interested endpoints. -func TestReuseAddrAndBroadcast(t *testing.T) { - const ( - nicID = 1 - localPort = 9000 - ) - loopbackBroadcast := testutil.MustParse4("127.255.255.255") - - tests := []struct { - name string - broadcastAddr tcpip.Address - }{ - { - name: "Subnet directed broadcast", - broadcastAddr: loopbackBroadcast, - }, - { - name: "IPv4 broadcast", - broadcastAddr: header.IPv4Broadcast, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, - }) - if err := s.CreateNIC(nicID, loopback.New()); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID, err) - } - protoAddr := tcpip.ProtocolAddress{ - Protocol: header.IPv4ProtocolNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: "\x7f\x00\x00\x01", - PrefixLen: 8, - }, - } - if err := s.AddProtocolAddress(nicID, protoAddr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, protoAddr, err) - } - - s.SetRouteTable([]tcpip.Route{ - { - // We use the empty subnet instead of just the loopback subnet so we - // also have a route to the IPv4 Broadcast address. - Destination: header.IPv4EmptySubnet, - NIC: nicID, - }, - }) - - type endpointAndWaiter struct { - ep tcpip.Endpoint - ch chan struct{} - } - var eps []endpointAndWaiter - // We create endpoints that bind to both the wildcard address and the - // broadcast address to make sure both of these types of "broadcast - // interested" endpoints receive broadcast packets. - for _, bindWildcard := range []bool{false, true} { - // Create multiple endpoints for each type of "broadcast interested" - // endpoint so we can test that all endpoints receive the broadcast - // packet. - for i := 0; i < 2; i++ { - var wq waiter.Queue - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.ReadableEvents) - ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq) - if err != nil { - t.Fatalf("(eps[%d]) NewEndpoint(%d, %d, _): %s", len(eps), udp.ProtocolNumber, ipv4.ProtocolNumber, err) - } - defer ep.Close() - - ep.SocketOptions().SetReuseAddress(true) - ep.SocketOptions().SetBroadcast(true) - - bindAddr := tcpip.FullAddress{Port: localPort} - if bindWildcard { - if err := ep.Bind(bindAddr); err != nil { - t.Fatalf("eps[%d].Bind(%#v): %s", len(eps), bindAddr, err) - } - } else { - bindAddr.Addr = test.broadcastAddr - if err := ep.Bind(bindAddr); err != nil { - t.Fatalf("eps[%d].Bind(%#v): %s", len(eps), bindAddr, err) - } - } - - eps = append(eps, endpointAndWaiter{ep: ep, ch: ch}) - } - } - - for i, wep := range eps { - writeOpts := tcpip.WriteOptions{ - To: &tcpip.FullAddress{ - Addr: test.broadcastAddr, - Port: localPort, - }, - } - data := []byte{byte(i), 2, 3, 4} - var r bytes.Reader - r.Reset(data) - if n, err := wep.ep.Write(&r, writeOpts); err != nil { - t.Fatalf("eps[%d].Write(_, _): %s", i, err) - } else if want := int64(len(data)); n != want { - t.Fatalf("got eps[%d].Write(_, _) = (%d, nil), want = (%d, nil)", i, n, want) - } - - for j, rep := range eps { - // Wait for the endpoint to become readable. - <-rep.ch - - var buf bytes.Buffer - result, err := rep.ep.Read(&buf, tcpip.ReadOptions{}) - if err != nil { - t.Errorf("(eps[%d] write) eps[%d].Read: %s", i, j, err) - continue - } - if diff := cmp.Diff(tcpip.ReadResult{ - Count: buf.Len(), - Total: buf.Len(), - }, result, checker.IgnoreCmpPath("ControlMessages")); diff != "" { - t.Errorf("(eps[%d] write) eps[%d].Read: unexpected result (-want +got):\n%s", i, j, diff) - } - if diff := cmp.Diff([]byte(data), buf.Bytes()); diff != "" { - t.Errorf("(eps[%d] write) got UDP payload from eps[%d] mismatch (-want +got):\n%s", i, j, diff) - } - } - } - }) - } -} - -func TestUDPAddRemoveMembershipSocketOption(t *testing.T) { - const ( - nicID = 1 - ) - - data := []byte{1, 2, 3, 4} - - tests := []struct { - name string - proto tcpip.NetworkProtocolNumber - remoteAddr tcpip.Address - localAddr tcpip.AddressWithPrefix - rxUDP func(*channel.Endpoint, tcpip.Address, tcpip.Address, []byte) - multicastAddr tcpip.Address - }{ - { - name: "IPv4 unicast binding to unicast", - multicastAddr: "\xe0\x01\x02\x03", - proto: header.IPv4ProtocolNumber, - remoteAddr: utils.RemoteIPv4Addr, - localAddr: utils.Ipv4Addr, - rxUDP: rxIPv4UDP, - }, - { - name: "IPv6 broadcast-like address binding to wildcard", - multicastAddr: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x02\x03\x04", - proto: header.IPv6ProtocolNumber, - remoteAddr: utils.RemoteIPv6Addr, - localAddr: utils.Ipv6Addr, - rxUDP: rxIPv6UDP, - }, - } - - subTests := []struct { - name string - specifyNICID bool - specifyNICAddr bool - }{ - { - name: "Specify NIC ID and NIC address", - specifyNICID: true, - specifyNICAddr: true, - }, - { - name: "Don't specify NIC ID or NIC address", - specifyNICID: false, - specifyNICAddr: false, - }, - { - name: "Specify NIC ID but don't specify NIC address", - specifyNICID: true, - specifyNICAddr: false, - }, - { - name: "Don't specify NIC ID but specify NIC address", - specifyNICID: false, - specifyNICAddr: true, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - for _, subTest := range subTests { - t.Run(subTest.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, - }) - e := channel.New(0, defaultMTU, "") - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID, err) - } - protoAddr := tcpip.ProtocolAddress{Protocol: test.proto, AddressWithPrefix: test.localAddr} - if err := s.AddProtocolAddress(nicID, protoAddr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, protoAddr, err) - } - - // Set the route table so that UDP can find a NIC that is - // routable to the multicast address when the NIC isn't specified. - if !subTest.specifyNICID && !subTest.specifyNICAddr { - s.SetRouteTable([]tcpip.Route{ - { - Destination: header.IPv6EmptySubnet, - NIC: nicID, - }, - { - Destination: header.IPv4EmptySubnet, - NIC: nicID, - }, - }) - } - - var wq waiter.Queue - ep, err := s.NewEndpoint(udp.ProtocolNumber, test.proto, &wq) - if err != nil { - t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, test.proto, err) - } - defer ep.Close() - - bindAddr := tcpip.FullAddress{Port: utils.LocalPort} - if err := ep.Bind(bindAddr); err != nil { - t.Fatalf("ep.Bind(%#v): %s", bindAddr, err) - } - - memOpt := tcpip.MembershipOption{MulticastAddr: test.multicastAddr} - if subTest.specifyNICID { - memOpt.NIC = nicID - } - if subTest.specifyNICAddr { - memOpt.InterfaceAddr = test.localAddr.Address - } - - // We should receive UDP packets to the group once we join the - // multicast group. - addOpt := tcpip.AddMembershipOption(memOpt) - if err := ep.SetSockOpt(&addOpt); err != nil { - t.Fatalf("ep.SetSockOpt(&%#v): %s", addOpt, err) - } - test.rxUDP(e, test.remoteAddr, test.multicastAddr, data) - var buf bytes.Buffer - result, err := ep.Read(&buf, tcpip.ReadOptions{}) - if err != nil { - t.Fatalf("ep.Read: %s", err) - } else { - if diff := cmp.Diff(tcpip.ReadResult{ - Count: buf.Len(), - Total: buf.Len(), - }, result, checker.IgnoreCmpPath("ControlMessages")); diff != "" { - t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff) - } - if diff := cmp.Diff(data, buf.Bytes()); diff != "" { - t.Errorf("got UDP payload mismatch (-want +got):\n%s", diff) - } - } - - // We should not receive UDP packets to the group once we leave - // the multicast group. - removeOpt := tcpip.RemoveMembershipOption(memOpt) - if err := ep.SetSockOpt(&removeOpt); err != nil { - t.Fatalf("ep.SetSockOpt(&%#v): %s", removeOpt, err) - } - { - _, err := ep.Read(&buf, tcpip.ReadOptions{}) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Fatalf("got ep.Read = (_, %s), want = (_, %s)", err, &tcpip.ErrWouldBlock{}) - } - } - }) - } - }) - } -} diff --git a/pkg/tcpip/tests/integration/route_test.go b/pkg/tcpip/tests/integration/route_test.go deleted file mode 100644 index ac3c703d4..000000000 --- a/pkg/tcpip/tests/integration/route_test.go +++ /dev/null @@ -1,433 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package route_test - -import ( - "bytes" - "fmt" - "testing" - - "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/checker" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/link/loopback" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/tests/utils" - "gvisor.dev/gvisor/pkg/tcpip/testutil" - "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" - "gvisor.dev/gvisor/pkg/tcpip/transport/udp" - "gvisor.dev/gvisor/pkg/waiter" -) - -// TestLocalPing tests pinging a remote that is local the stack. -// -// This tests that a local route is created and packets do not leave the stack. -func TestLocalPing(t *testing.T) { - const ( - nicID = 1 - - // icmpDataOffset is the offset to the data in both ICMPv4 and ICMPv6 echo - // request/reply packets. - icmpDataOffset = 8 - ) - ipv4Loopback := testutil.MustParse4("127.0.0.1") - - channelEP := func() stack.LinkEndpoint { return channel.New(1, header.IPv6MinimumMTU, "") } - channelEPCheck := func(t *testing.T, e stack.LinkEndpoint) { - channelEP := e.(*channel.Endpoint) - if n := channelEP.Drain(); n != 0 { - t.Fatalf("got channelEP.Drain() = %d, want = 0", n) - } - } - - ipv4ICMPBuf := func(t *testing.T) buffer.View { - data := [8]byte{1, 2, 3, 4, 5, 6, 7, 8} - hdr := header.ICMPv4(make([]byte, header.ICMPv4MinimumSize+len(data))) - hdr.SetType(header.ICMPv4Echo) - if n := copy(hdr.Payload(), data[:]); n != len(data) { - t.Fatalf("copied %d bytes but expected to copy %d bytes", n, len(data)) - } - return buffer.View(hdr) - } - - ipv6ICMPBuf := func(t *testing.T) buffer.View { - data := [8]byte{1, 2, 3, 4, 5, 6, 7, 9} - hdr := header.ICMPv6(make([]byte, header.ICMPv6MinimumSize+len(data))) - hdr.SetType(header.ICMPv6EchoRequest) - if n := copy(hdr.Payload(), data[:]); n != len(data) { - t.Fatalf("copied %d bytes but expected to copy %d bytes", n, len(data)) - } - return buffer.View(hdr) - } - - tests := []struct { - name string - transProto tcpip.TransportProtocolNumber - netProto tcpip.NetworkProtocolNumber - linkEndpoint func() stack.LinkEndpoint - localAddr tcpip.Address - icmpBuf func(*testing.T) buffer.View - expectedConnectErr tcpip.Error - checkLinkEndpoint func(t *testing.T, e stack.LinkEndpoint) - }{ - { - name: "IPv4 loopback", - transProto: icmp.ProtocolNumber4, - netProto: ipv4.ProtocolNumber, - linkEndpoint: loopback.New, - localAddr: ipv4Loopback, - icmpBuf: ipv4ICMPBuf, - checkLinkEndpoint: func(*testing.T, stack.LinkEndpoint) {}, - }, - { - name: "IPv6 loopback", - transProto: icmp.ProtocolNumber6, - netProto: ipv6.ProtocolNumber, - linkEndpoint: loopback.New, - localAddr: header.IPv6Loopback, - icmpBuf: ipv6ICMPBuf, - checkLinkEndpoint: func(*testing.T, stack.LinkEndpoint) {}, - }, - { - name: "IPv4 non-loopback", - transProto: icmp.ProtocolNumber4, - netProto: ipv4.ProtocolNumber, - linkEndpoint: channelEP, - localAddr: utils.Ipv4Addr.Address, - icmpBuf: ipv4ICMPBuf, - checkLinkEndpoint: channelEPCheck, - }, - { - name: "IPv6 non-loopback", - transProto: icmp.ProtocolNumber6, - netProto: ipv6.ProtocolNumber, - linkEndpoint: channelEP, - localAddr: utils.Ipv6Addr.Address, - icmpBuf: ipv6ICMPBuf, - checkLinkEndpoint: channelEPCheck, - }, - { - name: "IPv4 loopback without local address", - transProto: icmp.ProtocolNumber4, - netProto: ipv4.ProtocolNumber, - linkEndpoint: loopback.New, - icmpBuf: ipv4ICMPBuf, - expectedConnectErr: &tcpip.ErrNoRoute{}, - checkLinkEndpoint: func(*testing.T, stack.LinkEndpoint) {}, - }, - { - name: "IPv6 loopback without local address", - transProto: icmp.ProtocolNumber6, - netProto: ipv6.ProtocolNumber, - linkEndpoint: loopback.New, - icmpBuf: ipv6ICMPBuf, - expectedConnectErr: &tcpip.ErrNoRoute{}, - checkLinkEndpoint: func(*testing.T, stack.LinkEndpoint) {}, - }, - { - name: "IPv4 non-loopback without local address", - transProto: icmp.ProtocolNumber4, - netProto: ipv4.ProtocolNumber, - linkEndpoint: channelEP, - icmpBuf: ipv4ICMPBuf, - expectedConnectErr: &tcpip.ErrNoRoute{}, - checkLinkEndpoint: channelEPCheck, - }, - { - name: "IPv6 non-loopback without local address", - transProto: icmp.ProtocolNumber6, - netProto: ipv6.ProtocolNumber, - linkEndpoint: channelEP, - icmpBuf: ipv6ICMPBuf, - expectedConnectErr: &tcpip.ErrNoRoute{}, - checkLinkEndpoint: channelEPCheck, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - for _, allowExternalLoopback := range []bool{true, false} { - t.Run(fmt.Sprintf("AllowExternalLoopback=%t", allowExternalLoopback), func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ - ipv4.NewProtocolWithOptions(ipv4.Options{ - AllowExternalLoopbackTraffic: allowExternalLoopback, - }), - ipv6.NewProtocolWithOptions(ipv6.Options{ - AllowExternalLoopbackTraffic: allowExternalLoopback, - }), - }, - TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4, icmp.NewProtocol6}, - HandleLocal: true, - }) - e := test.linkEndpoint() - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) - } - - if len(test.localAddr) != 0 { - if err := s.AddAddress(nicID, test.netProto, test.localAddr); err != nil { - t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, test.netProto, test.localAddr, err) - } - } - - var wq waiter.Queue - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.ReadableEvents) - ep, err := s.NewEndpoint(test.transProto, test.netProto, &wq) - if err != nil { - t.Fatalf("s.NewEndpoint(%d, %d, _): %s", test.transProto, test.netProto, err) - } - defer ep.Close() - - connAddr := tcpip.FullAddress{Addr: test.localAddr} - if err := ep.Connect(connAddr); err != test.expectedConnectErr { - t.Fatalf("got ep.Connect(%#v) = %s, want = %s", connAddr, err, test.expectedConnectErr) - } - - if test.expectedConnectErr != nil { - return - } - - var r bytes.Reader - payload := test.icmpBuf(t) - r.Reset(payload) - var wOpts tcpip.WriteOptions - if n, err := ep.Write(&r, wOpts); err != nil { - t.Fatalf("ep.Write(%#v, %#v): %s", payload, wOpts, err) - } else if n != int64(len(payload)) { - t.Fatalf("got ep.Write(%#v, %#v) = (%d, _, nil), want = (%d, _, nil)", payload, wOpts, n, len(payload)) - } - - // Wait for the endpoint to become readable. - <-ch - - var w bytes.Buffer - rr, err := ep.Read(&w, tcpip.ReadOptions{ - NeedRemoteAddr: true, - }) - if err != nil { - t.Fatalf("ep.Read(...): %s", err) - } - if diff := cmp.Diff(buffer.View(w.Bytes()[icmpDataOffset:]), payload[icmpDataOffset:]); diff != "" { - t.Errorf("received data mismatch (-want +got):\n%s", diff) - } - if rr.RemoteAddr.Addr != test.localAddr { - t.Errorf("got addr.Addr = %s, want = %s", rr.RemoteAddr.Addr, test.localAddr) - } - - test.checkLinkEndpoint(t, e) - }) - } - }) - } -} - -// TestLocalUDP tests sending UDP packets between two endpoints that are local -// to the stack. -// -// This tests that that packets never leave the stack and the addresses -// used when sending a packet. -func TestLocalUDP(t *testing.T) { - const ( - nicID = 1 - ) - - tests := []struct { - name string - canBePrimaryAddr tcpip.ProtocolAddress - firstPrimaryAddr tcpip.ProtocolAddress - }{ - { - name: "IPv4", - canBePrimaryAddr: utils.Ipv4Addr1, - firstPrimaryAddr: utils.Ipv4Addr2, - }, - { - name: "IPv6", - canBePrimaryAddr: utils.Ipv6Addr1, - firstPrimaryAddr: utils.Ipv6Addr2, - }, - } - - subTests := []struct { - name string - addAddress bool - expectedWriteErr tcpip.Error - }{ - { - name: "Unassigned local address", - addAddress: false, - expectedWriteErr: &tcpip.ErrNoRoute{}, - }, - { - name: "Assigned local address", - addAddress: true, - expectedWriteErr: nil, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - for _, subTest := range subTests { - t.Run(subTest.name, func(t *testing.T) { - stackOpts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, - HandleLocal: true, - } - - s := stack.New(stackOpts) - ep := channel.New(1, header.IPv6MinimumMTU, "") - - if err := s.CreateNIC(nicID, ep); err != nil { - t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) - } - - if subTest.addAddress { - if err := s.AddProtocolAddressWithOptions(nicID, test.canBePrimaryAddr, stack.CanBePrimaryEndpoint); err != nil { - t.Fatalf("s.AddProtocolAddressWithOptions(%d, %#v, %d): %s", nicID, test.canBePrimaryAddr, stack.FirstPrimaryEndpoint, err) - } - if err := s.AddProtocolAddressWithOptions(nicID, test.firstPrimaryAddr, stack.FirstPrimaryEndpoint); err != nil { - t.Fatalf("s.AddProtocolAddressWithOptions(%d, %#v, %d): %s", nicID, test.firstPrimaryAddr, stack.FirstPrimaryEndpoint, err) - } - } - - var serverWQ waiter.Queue - serverWE, serverCH := waiter.NewChannelEntry(nil) - serverWQ.EventRegister(&serverWE, waiter.ReadableEvents) - server, err := s.NewEndpoint(udp.ProtocolNumber, test.firstPrimaryAddr.Protocol, &serverWQ) - if err != nil { - t.Fatalf("s.NewEndpoint(%d, %d): %s", udp.ProtocolNumber, test.firstPrimaryAddr.Protocol, err) - } - defer server.Close() - - bindAddr := tcpip.FullAddress{Port: 80} - if err := server.Bind(bindAddr); err != nil { - t.Fatalf("server.Bind(%#v): %s", bindAddr, err) - } - - var clientWQ waiter.Queue - clientWE, clientCH := waiter.NewChannelEntry(nil) - clientWQ.EventRegister(&clientWE, waiter.ReadableEvents) - client, err := s.NewEndpoint(udp.ProtocolNumber, test.firstPrimaryAddr.Protocol, &clientWQ) - if err != nil { - t.Fatalf("s.NewEndpoint(%d, %d): %s", udp.ProtocolNumber, test.firstPrimaryAddr.Protocol, err) - } - defer client.Close() - - serverAddr := tcpip.FullAddress{ - Addr: test.canBePrimaryAddr.AddressWithPrefix.Address, - Port: 80, - } - - clientPayload := []byte{1, 2, 3, 4} - { - var r bytes.Reader - r.Reset(clientPayload) - wOpts := tcpip.WriteOptions{ - To: &serverAddr, - } - if n, err := client.Write(&r, wOpts); err != subTest.expectedWriteErr { - t.Fatalf("got client.Write(%#v, %#v) = (%d, %s), want = (_, %s)", clientPayload, wOpts, n, err, subTest.expectedWriteErr) - } else if subTest.expectedWriteErr != nil { - // Nothing else to test if we expected not to be able to send the - // UDP packet. - return - } else if n != int64(len(clientPayload)) { - t.Fatalf("got client.Write(%#v, %#v) = (%d, nil), want = (%d, nil)", clientPayload, wOpts, n, len(clientPayload)) - } - } - - // Wait for the server endpoint to become readable. - <-serverCH - - var clientAddr tcpip.FullAddress - var readBuf bytes.Buffer - if read, err := server.Read(&readBuf, tcpip.ReadOptions{NeedRemoteAddr: true}); err != nil { - t.Fatalf("server.Read(_): %s", err) - } else { - clientAddr = read.RemoteAddr - - if diff := cmp.Diff(tcpip.ReadResult{ - Count: readBuf.Len(), - Total: readBuf.Len(), - RemoteAddr: tcpip.FullAddress{ - Addr: test.canBePrimaryAddr.AddressWithPrefix.Address, - }, - }, read, checker.IgnoreCmpPath( - "ControlMessages", - "RemoteAddr.NIC", - "RemoteAddr.Port", - )); diff != "" { - t.Errorf("server.Read: unexpected result (-want +got):\n%s", diff) - } - if diff := cmp.Diff(buffer.View(clientPayload), buffer.View(readBuf.Bytes())); diff != "" { - t.Errorf("server read clientPayload mismatch (-want +got):\n%s", diff) - } - if t.Failed() { - t.FailNow() - } - } - - serverPayload := []byte{1, 2, 3, 4} - { - var r bytes.Reader - r.Reset(serverPayload) - wOpts := tcpip.WriteOptions{ - To: &clientAddr, - } - if n, err := server.Write(&r, wOpts); err != nil { - t.Fatalf("server.Write(%#v, %#v): %s", serverPayload, wOpts, err) - } else if n != int64(len(serverPayload)) { - t.Fatalf("got server.Write(%#v, %#v) = (%d, nil), want = (%d, nil)", serverPayload, wOpts, n, len(serverPayload)) - } - } - - // Wait for the client endpoint to become readable. - <-clientCH - - readBuf.Reset() - if read, err := client.Read(&readBuf, tcpip.ReadOptions{NeedRemoteAddr: true}); err != nil { - t.Fatalf("client.Read(_): %s", err) - } else { - if diff := cmp.Diff(tcpip.ReadResult{ - Count: readBuf.Len(), - Total: readBuf.Len(), - RemoteAddr: tcpip.FullAddress{Addr: serverAddr.Addr}, - }, read, checker.IgnoreCmpPath( - "ControlMessages", - "RemoteAddr.NIC", - "RemoteAddr.Port", - )); diff != "" { - t.Errorf("client.Read: unexpected result (-want +got):\n%s", diff) - } - if diff := cmp.Diff(buffer.View(serverPayload), buffer.View(readBuf.Bytes())); diff != "" { - t.Errorf("client read serverPayload mismatch (-want +got):\n%s", diff) - } - if t.Failed() { - t.FailNow() - } - } - }) - } - }) - } -} diff --git a/pkg/tcpip/tests/utils/BUILD b/pkg/tcpip/tests/utils/BUILD deleted file mode 100644 index a9699a367..000000000 --- a/pkg/tcpip/tests/utils/BUILD +++ /dev/null @@ -1,22 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "utils", - srcs = ["utils.go"], - visibility = ["//pkg/tcpip/tests:__subpackages__"], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/link/channel", - "//pkg/tcpip/link/ethernet", - "//pkg/tcpip/link/nested", - "//pkg/tcpip/link/pipe", - "//pkg/tcpip/network/ipv4", - "//pkg/tcpip/network/ipv6", - "//pkg/tcpip/stack", - "//pkg/tcpip/transport/icmp", - ], -) diff --git a/pkg/tcpip/tests/utils/utils.go b/pkg/tcpip/tests/utils/utils.go deleted file mode 100644 index 2e6ae55ea..000000000 --- a/pkg/tcpip/tests/utils/utils.go +++ /dev/null @@ -1,390 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package utils holds common testing utilities for tcpip. -package utils - -import ( - "net" - "testing" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/link/ethernet" - "gvisor.dev/gvisor/pkg/tcpip/link/nested" - "gvisor.dev/gvisor/pkg/tcpip/link/pipe" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" -) - -// Common NIC IDs used by tests. -const ( - Host1NICID = 1 - RouterNICID1 = 2 - RouterNICID2 = 3 - Host2NICID = 4 -) - -// Common link addresses used by tests. -const ( - LinkAddr1 = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x06") - LinkAddr2 = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x07") - LinkAddr3 = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x08") - LinkAddr4 = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x09") -) - -// Common IP addresses used by tests. -var ( - Ipv4Addr = tcpip.AddressWithPrefix{ - Address: tcpip.Address(net.ParseIP("192.168.1.58").To4()), - PrefixLen: 24, - } - Ipv4Subnet = Ipv4Addr.Subnet() - Ipv4SubnetBcast = Ipv4Subnet.Broadcast() - - Ipv6Addr = tcpip.AddressWithPrefix{ - Address: tcpip.Address(net.ParseIP("200a::1").To16()), - PrefixLen: 64, - } - Ipv6Subnet = Ipv6Addr.Subnet() - Ipv6SubnetBcast = Ipv6Subnet.Broadcast() - - Ipv4Addr1 = tcpip.ProtocolAddress{ - Protocol: ipv4.ProtocolNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: tcpip.Address(net.ParseIP("192.168.0.1").To4()), - PrefixLen: 24, - }, - } - Ipv4Addr2 = tcpip.ProtocolAddress{ - Protocol: ipv4.ProtocolNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: tcpip.Address(net.ParseIP("192.168.0.2").To4()), - PrefixLen: 8, - }, - } - Ipv4Addr3 = tcpip.ProtocolAddress{ - Protocol: ipv4.ProtocolNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: tcpip.Address(net.ParseIP("192.168.0.3").To4()), - PrefixLen: 8, - }, - } - Ipv6Addr1 = tcpip.ProtocolAddress{ - Protocol: ipv6.ProtocolNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: tcpip.Address(net.ParseIP("a::1").To16()), - PrefixLen: 64, - }, - } - Ipv6Addr2 = tcpip.ProtocolAddress{ - Protocol: ipv6.ProtocolNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: tcpip.Address(net.ParseIP("a::2").To16()), - PrefixLen: 64, - }, - } - Ipv6Addr3 = tcpip.ProtocolAddress{ - Protocol: ipv6.ProtocolNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: tcpip.Address(net.ParseIP("a::3").To16()), - PrefixLen: 64, - }, - } - - // Remote addrs. - RemoteIPv4Addr = tcpip.Address(net.ParseIP("10.0.0.1").To4()) - RemoteIPv6Addr = tcpip.Address(net.ParseIP("200b::1").To16()) -) - -// Common ports for testing. -const ( - RemotePort = 5555 - LocalPort = 80 -) - -// Common IP addresses used for testing. -var ( - Host1IPv4Addr = tcpip.ProtocolAddress{ - Protocol: ipv4.ProtocolNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: tcpip.Address(net.ParseIP("192.168.0.2").To4()), - PrefixLen: 24, - }, - } - RouterNIC1IPv4Addr = tcpip.ProtocolAddress{ - Protocol: ipv4.ProtocolNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: tcpip.Address(net.ParseIP("192.168.0.1").To4()), - PrefixLen: 24, - }, - } - RouterNIC2IPv4Addr = tcpip.ProtocolAddress{ - Protocol: ipv4.ProtocolNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: tcpip.Address(net.ParseIP("10.0.0.1").To4()), - PrefixLen: 8, - }, - } - Host2IPv4Addr = tcpip.ProtocolAddress{ - Protocol: ipv4.ProtocolNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: tcpip.Address(net.ParseIP("10.0.0.2").To4()), - PrefixLen: 8, - }, - } - Host1IPv6Addr = tcpip.ProtocolAddress{ - Protocol: ipv6.ProtocolNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: tcpip.Address(net.ParseIP("a::2").To16()), - PrefixLen: 64, - }, - } - RouterNIC1IPv6Addr = tcpip.ProtocolAddress{ - Protocol: ipv6.ProtocolNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: tcpip.Address(net.ParseIP("a::1").To16()), - PrefixLen: 64, - }, - } - RouterNIC2IPv6Addr = tcpip.ProtocolAddress{ - Protocol: ipv6.ProtocolNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: tcpip.Address(net.ParseIP("b::1").To16()), - PrefixLen: 64, - }, - } - Host2IPv6Addr = tcpip.ProtocolAddress{ - Protocol: ipv6.ProtocolNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: tcpip.Address(net.ParseIP("b::2").To16()), - PrefixLen: 64, - }, - } -) - -// NewEthernetEndpoint returns an ethernet link endpoint that wraps an inner -// link endpoint and checks the destination link address before delivering -// network packets to the network dispatcher. -// -// See ethernet.Endpoint for more details. -func NewEthernetEndpoint(ep stack.LinkEndpoint) *EndpointWithDestinationCheck { - var e EndpointWithDestinationCheck - e.Endpoint.Init(ethernet.New(ep), &e) - return &e -} - -// EndpointWithDestinationCheck is a link endpoint that checks the destination -// link address before delivering network packets to the network dispatcher. -type EndpointWithDestinationCheck struct { - nested.Endpoint -} - -var _ stack.NetworkDispatcher = (*EndpointWithDestinationCheck)(nil) -var _ stack.LinkEndpoint = (*EndpointWithDestinationCheck)(nil) - -// DeliverNetworkPacket implements stack.NetworkDispatcher. -func (e *EndpointWithDestinationCheck) DeliverNetworkPacket(src, dst tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { - if dst == e.Endpoint.LinkAddress() || dst == header.EthernetBroadcastAddress || header.IsMulticastEthernetAddress(dst) { - e.Endpoint.DeliverNetworkPacket(src, dst, proto, pkt) - } -} - -// SetupRoutedStacks creates the NICs, sets forwarding, adds addresses and sets -// the route tables for the passed stacks. -func SetupRoutedStacks(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack) { - host1NIC, routerNIC1 := pipe.New(LinkAddr1, LinkAddr2) - routerNIC2, host2NIC := pipe.New(LinkAddr3, LinkAddr4) - - if err := host1Stack.CreateNIC(Host1NICID, NewEthernetEndpoint(host1NIC)); err != nil { - t.Fatalf("host1Stack.CreateNIC(%d, _): %s", Host1NICID, err) - } - if err := routerStack.CreateNIC(RouterNICID1, NewEthernetEndpoint(routerNIC1)); err != nil { - t.Fatalf("routerStack.CreateNIC(%d, _): %s", RouterNICID1, err) - } - if err := routerStack.CreateNIC(RouterNICID2, NewEthernetEndpoint(routerNIC2)); err != nil { - t.Fatalf("routerStack.CreateNIC(%d, _): %s", RouterNICID2, err) - } - if err := host2Stack.CreateNIC(Host2NICID, NewEthernetEndpoint(host2NIC)); err != nil { - t.Fatalf("host2Stack.CreateNIC(%d, _): %s", Host2NICID, err) - } - - if err := routerStack.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil { - t.Fatalf("routerStack.SetForwardingDefaultAndAllNICs(%d): %s", ipv4.ProtocolNumber, err) - } - if err := routerStack.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, true); err != nil { - t.Fatalf("routerStack.SetForwardingDefaultAndAllNICs(%d): %s", ipv6.ProtocolNumber, err) - } - - if err := host1Stack.AddProtocolAddress(Host1NICID, Host1IPv4Addr); err != nil { - t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", Host1NICID, Host1IPv4Addr, err) - } - if err := routerStack.AddProtocolAddress(RouterNICID1, RouterNIC1IPv4Addr); err != nil { - t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", RouterNICID1, RouterNIC1IPv4Addr, err) - } - if err := routerStack.AddProtocolAddress(RouterNICID2, RouterNIC2IPv4Addr); err != nil { - t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", RouterNICID2, RouterNIC2IPv4Addr, err) - } - if err := host2Stack.AddProtocolAddress(Host2NICID, Host2IPv4Addr); err != nil { - t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", Host2NICID, Host2IPv4Addr, err) - } - if err := host1Stack.AddProtocolAddress(Host1NICID, Host1IPv6Addr); err != nil { - t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", Host1NICID, Host1IPv6Addr, err) - } - if err := routerStack.AddProtocolAddress(RouterNICID1, RouterNIC1IPv6Addr); err != nil { - t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", RouterNICID1, RouterNIC1IPv6Addr, err) - } - if err := routerStack.AddProtocolAddress(RouterNICID2, RouterNIC2IPv6Addr); err != nil { - t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", RouterNICID2, RouterNIC2IPv6Addr, err) - } - if err := host2Stack.AddProtocolAddress(Host2NICID, Host2IPv6Addr); err != nil { - t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", Host2NICID, Host2IPv6Addr, err) - } - - host1Stack.SetRouteTable([]tcpip.Route{ - { - Destination: Host1IPv4Addr.AddressWithPrefix.Subnet(), - NIC: Host1NICID, - }, - { - Destination: Host1IPv6Addr.AddressWithPrefix.Subnet(), - NIC: Host1NICID, - }, - { - Destination: Host2IPv4Addr.AddressWithPrefix.Subnet(), - Gateway: RouterNIC1IPv4Addr.AddressWithPrefix.Address, - NIC: Host1NICID, - }, - { - Destination: Host2IPv6Addr.AddressWithPrefix.Subnet(), - Gateway: RouterNIC1IPv6Addr.AddressWithPrefix.Address, - NIC: Host1NICID, - }, - }) - routerStack.SetRouteTable([]tcpip.Route{ - { - Destination: RouterNIC1IPv4Addr.AddressWithPrefix.Subnet(), - NIC: RouterNICID1, - }, - { - Destination: RouterNIC1IPv6Addr.AddressWithPrefix.Subnet(), - NIC: RouterNICID1, - }, - { - Destination: RouterNIC2IPv4Addr.AddressWithPrefix.Subnet(), - NIC: RouterNICID2, - }, - { - Destination: RouterNIC2IPv6Addr.AddressWithPrefix.Subnet(), - NIC: RouterNICID2, - }, - }) - host2Stack.SetRouteTable([]tcpip.Route{ - { - Destination: Host2IPv4Addr.AddressWithPrefix.Subnet(), - NIC: Host2NICID, - }, - { - Destination: Host2IPv6Addr.AddressWithPrefix.Subnet(), - NIC: Host2NICID, - }, - { - Destination: Host1IPv4Addr.AddressWithPrefix.Subnet(), - Gateway: RouterNIC2IPv4Addr.AddressWithPrefix.Address, - NIC: Host2NICID, - }, - { - Destination: Host1IPv6Addr.AddressWithPrefix.Subnet(), - Gateway: RouterNIC2IPv6Addr.AddressWithPrefix.Address, - NIC: Host2NICID, - }, - }) -} - -func rxICMPv4Echo(e *channel.Endpoint, src, dst tcpip.Address, ttl uint8, ty header.ICMPv4Type) { - totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize - hdr := buffer.NewPrependable(totalLen) - pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) - pkt.SetType(ty) - pkt.SetCode(header.ICMPv4UnusedCode) - pkt.SetChecksum(0) - pkt.SetChecksum(^header.Checksum(pkt, 0)) - ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) - ip.Encode(&header.IPv4Fields{ - TotalLength: uint16(totalLen), - Protocol: uint8(icmp.ProtocolNumber4), - TTL: ttl, - SrcAddr: src, - DstAddr: dst, - }) - ip.SetChecksum(^ip.CalculateChecksum()) - - e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - })) -} - -// RxICMPv4EchoRequest constructs and injects an ICMPv4 echo request packet on -// the provided endpoint. -func RxICMPv4EchoRequest(e *channel.Endpoint, src, dst tcpip.Address, ttl uint8) { - rxICMPv4Echo(e, src, dst, ttl, header.ICMPv4Echo) -} - -// RxICMPv4EchoReply constructs and injects an ICMPv4 echo reply packet on -// the provided endpoint. -func RxICMPv4EchoReply(e *channel.Endpoint, src, dst tcpip.Address, ttl uint8) { - rxICMPv4Echo(e, src, dst, ttl, header.ICMPv4EchoReply) -} - -func rxICMPv6Echo(e *channel.Endpoint, src, dst tcpip.Address, ttl uint8, ty header.ICMPv6Type) { - totalLen := header.IPv6MinimumSize + header.ICMPv6MinimumSize - hdr := buffer.NewPrependable(totalLen) - pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize)) - pkt.SetType(ty) - pkt.SetCode(header.ICMPv6UnusedCode) - pkt.SetChecksum(0) - pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ - Header: pkt, - Src: src, - Dst: dst, - })) - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: header.ICMPv6MinimumSize, - TransportProtocol: icmp.ProtocolNumber6, - HopLimit: ttl, - SrcAddr: src, - DstAddr: dst, - }) - - e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - })) -} - -// RxICMPv6EchoRequest constructs and injects an ICMPv6 echo request packet on -// the provided endpoint. -func RxICMPv6EchoRequest(e *channel.Endpoint, src, dst tcpip.Address, ttl uint8) { - rxICMPv6Echo(e, src, dst, ttl, header.ICMPv6EchoRequest) -} - -// RxICMPv6EchoReply constructs and injects an ICMPv6 echo reply packet on -// the provided endpoint. -func RxICMPv6EchoReply(e *channel.Endpoint, src, dst tcpip.Address, ttl uint8) { - rxICMPv6Echo(e, src, dst, ttl, header.ICMPv6EchoReply) -} diff --git a/pkg/tcpip/testutil/BUILD b/pkg/tcpip/testutil/BUILD deleted file mode 100644 index 02ee86ff1..000000000 --- a/pkg/tcpip/testutil/BUILD +++ /dev/null @@ -1,21 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "testutil", - testonly = True, - srcs = [ - "testutil.go", - "testutil_unsafe.go", - ], - visibility = ["//visibility:public"], - deps = ["//pkg/tcpip"], -) - -go_test( - name = "testutil_test", - srcs = ["testutil_test.go"], - library = ":testutil", - deps = ["//pkg/tcpip"], -) diff --git a/pkg/tcpip/testutil/testutil.go b/pkg/tcpip/testutil/testutil.go deleted file mode 100644 index 94b580a70..000000000 --- a/pkg/tcpip/testutil/testutil.go +++ /dev/null @@ -1,123 +0,0 @@ -// Copyright 2021 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package testutil provides helper functions for netstack unit tests. -package testutil - -import ( - "fmt" - "net" - "reflect" - "strings" - - "gvisor.dev/gvisor/pkg/tcpip" -) - -// MustParse4 parses an IPv4 string (e.g. "192.168.1.1") into a tcpip.Address. -// Passing an IPv4-mapped IPv6 address will yield only the 4 IPv4 bytes. -func MustParse4(addr string) tcpip.Address { - ip := net.ParseIP(addr).To4() - if ip == nil { - panic(fmt.Sprintf("Parse4 expects IPv4 addresses, but was passed %q", addr)) - } - return tcpip.Address(ip) -} - -// MustParse6 parses an IPv6 string (e.g. "fe80::1") into a tcpip.Address. Passing -// an IPv4 address will yield an IPv4-mapped IPv6 address. -func MustParse6(addr string) tcpip.Address { - ip := net.ParseIP(addr).To16() - if ip == nil { - panic(fmt.Sprintf("Parse6 was passed malformed address %q", addr)) - } - return tcpip.Address(ip) -} - -func checkFieldCounts(ref, multi reflect.Value) error { - refTypeName := ref.Type().Name() - multiTypeName := multi.Type().Name() - refNumField := ref.NumField() - multiNumField := multi.NumField() - - if refNumField != multiNumField { - return fmt.Errorf("type %s has an incorrect number of fields: got = %d, want = %d (same as type %s)", multiTypeName, multiNumField, refNumField, refTypeName) - } - - return nil -} - -func validateField(ref reflect.Value, refName string, m tcpip.MultiCounterStat, multiName string) error { - s, ok := ref.Addr().Interface().(**tcpip.StatCounter) - if !ok { - return fmt.Errorf("expected ref type's to be *StatCounter, but its type is %s", ref.Type().Elem().Name()) - } - - // The field names are expected to match (case insensitive). - if !strings.EqualFold(refName, multiName) { - return fmt.Errorf("wrong field name: got = %s, want = %s", multiName, refName) - } - - base := (*s).Value() - m.Increment() - if (*s).Value() != base+1 { - return fmt.Errorf("updates to the '%s MultiCounterStat' counters are not reflected in the '%s CounterStat'", multiName, refName) - } - - return nil -} - -// ValidateMultiCounterStats verifies that every counter stored in multi is -// correctly tracking its counterpart in the given counters. -func ValidateMultiCounterStats(multi reflect.Value, counters []reflect.Value) error { - for _, c := range counters { - if err := checkFieldCounts(c, multi); err != nil { - return err - } - } - - for i := 0; i < multi.NumField(); i++ { - multiName := multi.Type().Field(i).Name - multiUnsafe := unsafeExposeUnexportedFields(multi.Field(i)) - - if m, ok := multiUnsafe.Addr().Interface().(*tcpip.MultiCounterStat); ok { - for _, c := range counters { - if err := validateField(unsafeExposeUnexportedFields(c.Field(i)), c.Type().Field(i).Name, *m, multiName); err != nil { - return err - } - } - } else { - var countersNextField []reflect.Value - for _, c := range counters { - countersNextField = append(countersNextField, c.Field(i)) - } - if err := ValidateMultiCounterStats(multi.Field(i), countersNextField); err != nil { - return err - } - } - } - - return nil -} - -// MustParseLink parses a Link string into a tcpip.LinkAddress, panicking on -// error. -// -// The string must be in the format aa:bb:cc:dd:ee:ff or aa-bb-cc-dd-ee-ff. -func MustParseLink(addr string) tcpip.LinkAddress { - parsed, err := tcpip.ParseMACAddress(addr) - if err != nil { - panic(fmt.Sprintf("tcpip.ParseMACAddress(%s): %s", addr, err)) - } - return parsed -} diff --git a/pkg/tcpip/testutil/testutil_test.go b/pkg/tcpip/testutil/testutil_test.go deleted file mode 100644 index 6aad9585d..000000000 --- a/pkg/tcpip/testutil/testutil_test.go +++ /dev/null @@ -1,103 +0,0 @@ -// Copyright 2021 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package testutil - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/tcpip" -) - -// Who tests the testutils? - -func TestMustParse4(t *testing.T) { - tcs := []struct { - str string - addr tcpip.Address - shouldPanic bool - }{ - { - str: "127.0.0.1", - addr: "\x7f\x00\x00\x01", - }, { - str: "", - shouldPanic: true, - }, { - str: "fe80::1", - shouldPanic: true, - }, { - // In an ideal world this panics too, but net.IP - // doesn't distinguish between IPv4 and IPv4-mapped - // addresses. - str: "::ffff:0.0.0.1", - addr: "\x00\x00\x00\x01", - }, - } - - for _, tc := range tcs { - t.Run(tc.str, func(t *testing.T) { - if tc.shouldPanic { - defer func() { - if r := recover(); r == nil { - t.Errorf("panic expected, but did not occur") - } - }() - } - if got := MustParse4(tc.str); got != tc.addr { - t.Errorf("got MustParse4(%s) = %s, want = %s", tc.str, got, tc.addr) - } - }) - } -} - -func TestMustParse6(t *testing.T) { - tcs := []struct { - str string - addr tcpip.Address - shouldPanic bool - }{ - { - // In an ideal world this panics too, but net.IP - // doesn't distinguish between IPv4 and IPv4-mapped - // addresses. - str: "127.0.0.1", - addr: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x7f\x00\x00\x01", - }, { - str: "", - shouldPanic: true, - }, { - str: "fe80::1", - addr: "\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", - }, { - str: "::ffff:0.0.0.1", - addr: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x01", - }, - } - - for _, tc := range tcs { - t.Run(tc.str, func(t *testing.T) { - if tc.shouldPanic { - defer func() { - if r := recover(); r == nil { - t.Errorf("panic expected, but did not occur") - } - }() - } - if got := MustParse6(tc.str); got != tc.addr { - t.Errorf("got MustParse6(%s) = %s, want = %s", tc.str, got, tc.addr) - } - }) - } -} diff --git a/pkg/tcpip/testutil/testutil_unsafe.go b/pkg/tcpip/testutil/testutil_unsafe.go deleted file mode 100644 index 5ff764800..000000000 --- a/pkg/tcpip/testutil/testutil_unsafe.go +++ /dev/null @@ -1,26 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package testutil - -import ( - "reflect" - "unsafe" -) - -// unsafeExposeUnexportedFields takes a Value and returns a version of it in -// which even unexported fields can be read and written. -func unsafeExposeUnexportedFields(a reflect.Value) reflect.Value { - return reflect.NewAt(a.Type(), unsafe.Pointer(a.UnsafeAddr())).Elem() -} diff --git a/pkg/tcpip/timer_test.go b/pkg/tcpip/timer_test.go deleted file mode 100644 index ed1ed8ac6..000000000 --- a/pkg/tcpip/timer_test.go +++ /dev/null @@ -1,353 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tcpip_test - -import ( - "math" - "sync" - "testing" - "time" - - "gvisor.dev/gvisor/pkg/tcpip" -) - -func TestMonotonicTimeBefore(t *testing.T) { - var mt tcpip.MonotonicTime - if mt.Before(mt) { - t.Errorf("%#v.Before(%#v)", mt, mt) - } - - one := mt.Add(1) - if one.Before(mt) { - t.Errorf("%#v.Before(%#v)", one, mt) - } - if !mt.Before(one) { - t.Errorf("!%#v.Before(%#v)", mt, one) - } -} - -func TestMonotonicTimeAfter(t *testing.T) { - var mt tcpip.MonotonicTime - if mt.After(mt) { - t.Errorf("%#v.After(%#v)", mt, mt) - } - - one := mt.Add(1) - if mt.After(one) { - t.Errorf("%#v.After(%#v)", mt, one) - } - if !one.After(mt) { - t.Errorf("!%#v.After(%#v)", one, mt) - } -} - -func TestMonotonicTimeAddSub(t *testing.T) { - var mt tcpip.MonotonicTime - if one, two := mt.Add(2), mt.Add(1).Add(1); one != two { - t.Errorf("mt.Add(2) != mt.Add(1).Add(1) (%#v != %#v)", one, two) - } - - min := mt.Add(math.MinInt64) - max := mt.Add(math.MaxInt64) - - if overflow := mt.Add(1).Add(math.MaxInt64); overflow != max { - t.Errorf("mt.Add(math.MaxInt64) != mt.Add(1).Add(math.MaxInt64) (%#v != %#v)", max, overflow) - } - if underflow := mt.Add(-1).Add(math.MinInt64); underflow != min { - t.Errorf("mt.Add(math.MinInt64) != mt.Add(-1).Add(math.MinInt64) (%#v != %#v)", min, underflow) - } - - if got, want := min.Sub(min), time.Duration(0); want != got { - t.Errorf("got min.Sub(min) = %d, want %d", got, want) - } - if got, want := max.Sub(max), time.Duration(0); want != got { - t.Errorf("got max.Sub(max) = %d, want %d", got, want) - } - - if overflow, want := max.Sub(min), time.Duration(math.MaxInt64); overflow != want { - t.Errorf("mt.Add(math.MaxInt64).Sub(mt.Add(math.MinInt64) != %s (%#v)", want, overflow) - } - if underflow, want := min.Sub(max), time.Duration(math.MinInt64); underflow != want { - t.Errorf("mt.Add(math.MinInt64).Sub(mt.Add(math.MaxInt64) != %s (%#v)", want, underflow) - } -} - -func TestMonotonicTimeSub(t *testing.T) { - var mt tcpip.MonotonicTime - - if one, two := mt.Add(2), mt.Add(1).Add(1); one != two { - t.Errorf("mt.Add(2) != mt.Add(1).Add(1) (%#v != %#v)", one, two) - } - - if max, overflow := mt.Add(math.MaxInt64), mt.Add(1).Add(math.MaxInt64); max != overflow { - t.Errorf("mt.Add(math.MaxInt64) != mt.Add(1).Add(math.MaxInt64) (%#v != %#v)", max, overflow) - } - if max, underflow := mt.Add(math.MinInt64), mt.Add(-1).Add(math.MinInt64); max != underflow { - t.Errorf("mt.Add(math.MinInt64) != mt.Add(-1).Add(math.MinInt64) (%#v != %#v)", max, underflow) - } -} - -const ( - shortDuration = 1 * time.Nanosecond - middleDuration = 100 * time.Millisecond -) - -func TestJobReschedule(t *testing.T) { - clock := tcpip.NewStdClock() - var wg sync.WaitGroup - var lock sync.Mutex - - for i := 0; i < 2; i++ { - wg.Add(1) - - go func() { - lock.Lock() - // Assigning a new timer value updates the timer's locker and function. - // This test makes sure there is no data race when reassigning a timer - // that has an active timer (even if it has been stopped as a stopped - // timer may be blocked on a lock before it can check if it has been - // stopped while another goroutine holds the same lock). - job := tcpip.NewJob(clock, &lock, func() { - wg.Done() - }) - job.Schedule(shortDuration) - lock.Unlock() - }() - } - wg.Wait() -} - -func stdClockWithAfter() (tcpip.Clock, func(time.Duration) <-chan time.Time) { - return tcpip.NewStdClock(), time.After -} - -func TestJobExecution(t *testing.T) { - t.Parallel() - - clock, after := stdClockWithAfter() - var lock sync.Mutex - ch := make(chan struct{}) - - job := tcpip.NewJob(clock, &lock, func() { - ch <- struct{}{} - }) - job.Schedule(shortDuration) - - // Wait for timer to fire. - select { - case <-ch: - case <-after(middleDuration): - t.Fatal("timed out waiting for timer to fire") - } - - // The timer should have fired only once. - select { - case <-ch: - t.Fatal("no other timers should have fired") - case <-after(middleDuration): - } -} - -func TestCancellableTimerResetFromLongDuration(t *testing.T) { - t.Parallel() - - clock, after := stdClockWithAfter() - var lock sync.Mutex - ch := make(chan struct{}) - - job := tcpip.NewJob(clock, &lock, func() { ch <- struct{}{} }) - job.Schedule(middleDuration) - - lock.Lock() - job.Cancel() - lock.Unlock() - - job.Schedule(shortDuration) - - // Wait for timer to fire. - select { - case <-ch: - case <-after(middleDuration): - t.Fatal("timed out waiting for timer to fire") - } - - // The timer should have fired only once. - select { - case <-ch: - t.Fatal("no other timers should have fired") - case <-after(middleDuration): - } -} - -func TestJobRescheduleFromShortDuration(t *testing.T) { - t.Parallel() - - clock, after := stdClockWithAfter() - var lock sync.Mutex - ch := make(chan struct{}) - - lock.Lock() - job := tcpip.NewJob(clock, &lock, func() { ch <- struct{}{} }) - job.Schedule(shortDuration) - job.Cancel() - lock.Unlock() - - // Wait for timer to fire if it wasn't correctly stopped. - select { - case <-ch: - t.Fatal("timer fired after being stopped") - case <-after(middleDuration): - } - - job.Schedule(shortDuration) - - // Wait for timer to fire. - select { - case <-ch: - case <-after(middleDuration): - t.Fatal("timed out waiting for timer to fire") - } - - // The timer should have fired only once. - select { - case <-ch: - t.Fatal("no other timers should have fired") - case <-after(middleDuration): - } -} - -func TestJobImmediatelyCancel(t *testing.T) { - t.Parallel() - - clock, after := stdClockWithAfter() - var lock sync.Mutex - ch := make(chan struct{}) - - for i := 0; i < 1000; i++ { - lock.Lock() - job := tcpip.NewJob(clock, &lock, func() { ch <- struct{}{} }) - job.Schedule(shortDuration) - job.Cancel() - lock.Unlock() - } - - // Wait for timer to fire if it wasn't correctly stopped. - select { - case <-ch: - t.Fatal("timer fired after being stopped") - case <-after(middleDuration): - } -} - -func stdClockWithAfterAndSleep() (tcpip.Clock, func(time.Duration) <-chan time.Time, func(time.Duration)) { - clock, after := stdClockWithAfter() - return clock, after, time.Sleep -} - -func TestJobCancelledRescheduleWithoutLock(t *testing.T) { - t.Parallel() - - clock, after, sleep := stdClockWithAfterAndSleep() - var lock sync.Mutex - ch := make(chan struct{}) - - lock.Lock() - job := tcpip.NewJob(clock, &lock, func() { ch <- struct{}{} }) - job.Schedule(shortDuration) - job.Cancel() - lock.Unlock() - - for i := 0; i < 10; i++ { - job.Schedule(middleDuration) - - lock.Lock() - // Sleep until the timer fires and gets blocked trying to take the lock. - sleep(middleDuration * 2) - job.Cancel() - lock.Unlock() - } - - // Wait for double the duration so timers that weren't correctly stopped can - // fire. - select { - case <-ch: - t.Fatal("timer fired after being stopped") - case <-after(middleDuration * 2): - } -} - -func TestManyCancellableTimerResetAfterBlockedOnLock(t *testing.T) { - t.Parallel() - - clock, after, sleep := stdClockWithAfterAndSleep() - var lock sync.Mutex - ch := make(chan struct{}) - - lock.Lock() - job := tcpip.NewJob(clock, &lock, func() { ch <- struct{}{} }) - job.Schedule(shortDuration) - for i := 0; i < 10; i++ { - // Sleep until the timer fires and gets blocked trying to take the lock. - sleep(middleDuration) - job.Cancel() - job.Schedule(shortDuration) - } - lock.Unlock() - - // Wait for double the duration for the last timer to fire. - select { - case <-ch: - case <-after(middleDuration): - t.Fatal("timed out waiting for timer to fire") - } - - // The timer should have fired only once. - select { - case <-ch: - t.Fatal("no other timers should have fired") - case <-after(middleDuration): - } -} - -func TestManyJobReschedulesUnderLock(t *testing.T) { - t.Parallel() - - clock, after := stdClockWithAfter() - var lock sync.Mutex - ch := make(chan struct{}) - - lock.Lock() - job := tcpip.NewJob(clock, &lock, func() { ch <- struct{}{} }) - job.Schedule(shortDuration) - for i := 0; i < 10; i++ { - job.Cancel() - job.Schedule(shortDuration) - } - lock.Unlock() - - // Wait for double the duration for the last timer to fire. - select { - case <-ch: - case <-after(middleDuration): - t.Fatal("timed out waiting for timer to fire") - } - - // The timer should have fired only once. - select { - case <-ch: - t.Fatal("no other timers should have fired") - case <-after(middleDuration): - } -} diff --git a/pkg/tcpip/transport/icmp/BUILD b/pkg/tcpip/transport/icmp/BUILD deleted file mode 100644 index bbc0e3ecc..000000000 --- a/pkg/tcpip/transport/icmp/BUILD +++ /dev/null @@ -1,59 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -package(licenses = ["notice"]) - -go_template_instance( - name = "icmp_packet_list", - out = "icmp_packet_list.go", - package = "icmp", - prefix = "icmpPacket", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*icmpPacket", - "Linker": "*icmpPacket", - }, -) - -go_library( - name = "icmp", - srcs = [ - "endpoint.go", - "endpoint_state.go", - "icmp_packet_list.go", - "protocol.go", - ], - imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"], - visibility = ["//visibility:public"], - deps = [ - "//pkg/sleep", - "//pkg/sync", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/ports", - "//pkg/tcpip/stack", - "//pkg/tcpip/transport/raw", - "//pkg/tcpip/transport/tcp", - "//pkg/waiter", - ], -) - -go_test( - name = "icmp_x_test", - size = "small", - srcs = ["icmp_test.go"], - deps = [ - ":icmp", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/checker", - "//pkg/tcpip/header", - "//pkg/tcpip/link/channel", - "//pkg/tcpip/link/sniffer", - "//pkg/tcpip/network/ipv4", - "//pkg/tcpip/stack", - "//pkg/tcpip/testutil", - "//pkg/waiter", - ], -) diff --git a/pkg/tcpip/transport/icmp/icmp_packet_list.go b/pkg/tcpip/transport/icmp/icmp_packet_list.go new file mode 100644 index 000000000..0aacdad3f --- /dev/null +++ b/pkg/tcpip/transport/icmp/icmp_packet_list.go @@ -0,0 +1,221 @@ +package icmp + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type icmpPacketElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (icmpPacketElementMapper) linkerFor(elem *icmpPacket) *icmpPacket { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type icmpPacketList struct { + head *icmpPacket + tail *icmpPacket +} + +// Reset resets list l to the empty state. +func (l *icmpPacketList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +// +//go:nosplit +func (l *icmpPacketList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +// +//go:nosplit +func (l *icmpPacketList) Front() *icmpPacket { + return l.head +} + +// Back returns the last element of list l or nil. +// +//go:nosplit +func (l *icmpPacketList) Back() *icmpPacket { + return l.tail +} + +// Len returns the number of elements in the list. +// +// NOTE: This is an O(n) operation. +// +//go:nosplit +func (l *icmpPacketList) Len() (count int) { + for e := l.Front(); e != nil; e = (icmpPacketElementMapper{}.linkerFor(e)).Next() { + count++ + } + return count +} + +// PushFront inserts the element e at the front of list l. +// +//go:nosplit +func (l *icmpPacketList) PushFront(e *icmpPacket) { + linker := icmpPacketElementMapper{}.linkerFor(e) + linker.SetNext(l.head) + linker.SetPrev(nil) + if l.head != nil { + icmpPacketElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +// +//go:nosplit +func (l *icmpPacketList) PushBack(e *icmpPacket) { + linker := icmpPacketElementMapper{}.linkerFor(e) + linker.SetNext(nil) + linker.SetPrev(l.tail) + if l.tail != nil { + icmpPacketElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +// +//go:nosplit +func (l *icmpPacketList) PushBackList(m *icmpPacketList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + icmpPacketElementMapper{}.linkerFor(l.tail).SetNext(m.head) + icmpPacketElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +// +//go:nosplit +func (l *icmpPacketList) InsertAfter(b, e *icmpPacket) { + bLinker := icmpPacketElementMapper{}.linkerFor(b) + eLinker := icmpPacketElementMapper{}.linkerFor(e) + + a := bLinker.Next() + + eLinker.SetNext(a) + eLinker.SetPrev(b) + bLinker.SetNext(e) + + if a != nil { + icmpPacketElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +// +//go:nosplit +func (l *icmpPacketList) InsertBefore(a, e *icmpPacket) { + aLinker := icmpPacketElementMapper{}.linkerFor(a) + eLinker := icmpPacketElementMapper{}.linkerFor(e) + + b := aLinker.Prev() + eLinker.SetNext(a) + eLinker.SetPrev(b) + aLinker.SetPrev(e) + + if b != nil { + icmpPacketElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +// +//go:nosplit +func (l *icmpPacketList) Remove(e *icmpPacket) { + linker := icmpPacketElementMapper{}.linkerFor(e) + prev := linker.Prev() + next := linker.Next() + + if prev != nil { + icmpPacketElementMapper{}.linkerFor(prev).SetNext(next) + } else if l.head == e { + l.head = next + } + + if next != nil { + icmpPacketElementMapper{}.linkerFor(next).SetPrev(prev) + } else if l.tail == e { + l.tail = prev + } + + linker.SetNext(nil) + linker.SetPrev(nil) +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type icmpPacketEntry struct { + next *icmpPacket + prev *icmpPacket +} + +// Next returns the entry that follows e in the list. +// +//go:nosplit +func (e *icmpPacketEntry) Next() *icmpPacket { + return e.next +} + +// Prev returns the entry that precedes e in the list. +// +//go:nosplit +func (e *icmpPacketEntry) Prev() *icmpPacket { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +// +//go:nosplit +func (e *icmpPacketEntry) SetNext(elem *icmpPacket) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +// +//go:nosplit +func (e *icmpPacketEntry) SetPrev(elem *icmpPacket) { + e.prev = elem +} diff --git a/pkg/tcpip/transport/icmp/icmp_state_autogen.go b/pkg/tcpip/transport/icmp/icmp_state_autogen.go new file mode 100644 index 000000000..d90b76d9c --- /dev/null +++ b/pkg/tcpip/transport/icmp/icmp_state_autogen.go @@ -0,0 +1,170 @@ +// automatically generated by stateify. + +package icmp + +import ( + "gvisor.dev/gvisor/pkg/state" + "gvisor.dev/gvisor/pkg/tcpip/buffer" +) + +func (p *icmpPacket) StateTypeName() string { + return "pkg/tcpip/transport/icmp.icmpPacket" +} + +func (p *icmpPacket) StateFields() []string { + return []string{ + "icmpPacketEntry", + "senderAddress", + "data", + "receivedAt", + } +} + +func (p *icmpPacket) beforeSave() {} + +// +checklocksignore +func (p *icmpPacket) StateSave(stateSinkObject state.Sink) { + p.beforeSave() + var dataValue buffer.VectorisedView + dataValue = p.saveData() + stateSinkObject.SaveValue(2, dataValue) + var receivedAtValue int64 + receivedAtValue = p.saveReceivedAt() + stateSinkObject.SaveValue(3, receivedAtValue) + stateSinkObject.Save(0, &p.icmpPacketEntry) + stateSinkObject.Save(1, &p.senderAddress) +} + +func (p *icmpPacket) afterLoad() {} + +// +checklocksignore +func (p *icmpPacket) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &p.icmpPacketEntry) + stateSourceObject.Load(1, &p.senderAddress) + stateSourceObject.LoadValue(2, new(buffer.VectorisedView), func(y interface{}) { p.loadData(y.(buffer.VectorisedView)) }) + stateSourceObject.LoadValue(3, new(int64), func(y interface{}) { p.loadReceivedAt(y.(int64)) }) +} + +func (e *endpoint) StateTypeName() string { + return "pkg/tcpip/transport/icmp.endpoint" +} + +func (e *endpoint) StateFields() []string { + return []string{ + "TransportEndpointInfo", + "DefaultSocketOptionsHandler", + "waiterQueue", + "uniqueID", + "rcvReady", + "rcvList", + "rcvBufSize", + "rcvClosed", + "shutdownFlags", + "state", + "ttl", + "owner", + "ops", + "frozen", + } +} + +// +checklocksignore +func (e *endpoint) StateSave(stateSinkObject state.Sink) { + e.beforeSave() + stateSinkObject.Save(0, &e.TransportEndpointInfo) + stateSinkObject.Save(1, &e.DefaultSocketOptionsHandler) + stateSinkObject.Save(2, &e.waiterQueue) + stateSinkObject.Save(3, &e.uniqueID) + stateSinkObject.Save(4, &e.rcvReady) + stateSinkObject.Save(5, &e.rcvList) + stateSinkObject.Save(6, &e.rcvBufSize) + stateSinkObject.Save(7, &e.rcvClosed) + stateSinkObject.Save(8, &e.shutdownFlags) + stateSinkObject.Save(9, &e.state) + stateSinkObject.Save(10, &e.ttl) + stateSinkObject.Save(11, &e.owner) + stateSinkObject.Save(12, &e.ops) + stateSinkObject.Save(13, &e.frozen) +} + +// +checklocksignore +func (e *endpoint) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &e.TransportEndpointInfo) + stateSourceObject.Load(1, &e.DefaultSocketOptionsHandler) + stateSourceObject.Load(2, &e.waiterQueue) + stateSourceObject.Load(3, &e.uniqueID) + stateSourceObject.Load(4, &e.rcvReady) + stateSourceObject.Load(5, &e.rcvList) + stateSourceObject.Load(6, &e.rcvBufSize) + stateSourceObject.Load(7, &e.rcvClosed) + stateSourceObject.Load(8, &e.shutdownFlags) + stateSourceObject.Load(9, &e.state) + stateSourceObject.Load(10, &e.ttl) + stateSourceObject.Load(11, &e.owner) + stateSourceObject.Load(12, &e.ops) + stateSourceObject.Load(13, &e.frozen) + stateSourceObject.AfterLoad(e.afterLoad) +} + +func (l *icmpPacketList) StateTypeName() string { + return "pkg/tcpip/transport/icmp.icmpPacketList" +} + +func (l *icmpPacketList) StateFields() []string { + return []string{ + "head", + "tail", + } +} + +func (l *icmpPacketList) beforeSave() {} + +// +checklocksignore +func (l *icmpPacketList) StateSave(stateSinkObject state.Sink) { + l.beforeSave() + stateSinkObject.Save(0, &l.head) + stateSinkObject.Save(1, &l.tail) +} + +func (l *icmpPacketList) afterLoad() {} + +// +checklocksignore +func (l *icmpPacketList) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &l.head) + stateSourceObject.Load(1, &l.tail) +} + +func (e *icmpPacketEntry) StateTypeName() string { + return "pkg/tcpip/transport/icmp.icmpPacketEntry" +} + +func (e *icmpPacketEntry) StateFields() []string { + return []string{ + "next", + "prev", + } +} + +func (e *icmpPacketEntry) beforeSave() {} + +// +checklocksignore +func (e *icmpPacketEntry) StateSave(stateSinkObject state.Sink) { + e.beforeSave() + stateSinkObject.Save(0, &e.next) + stateSinkObject.Save(1, &e.prev) +} + +func (e *icmpPacketEntry) afterLoad() {} + +// +checklocksignore +func (e *icmpPacketEntry) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &e.next) + stateSourceObject.Load(1, &e.prev) +} + +func init() { + state.Register((*icmpPacket)(nil)) + state.Register((*endpoint)(nil)) + state.Register((*icmpPacketList)(nil)) + state.Register((*icmpPacketEntry)(nil)) +} diff --git a/pkg/tcpip/transport/icmp/icmp_test.go b/pkg/tcpip/transport/icmp/icmp_test.go deleted file mode 100644 index cc950cbde..000000000 --- a/pkg/tcpip/transport/icmp/icmp_test.go +++ /dev/null @@ -1,235 +0,0 @@ -// Copyright 2021 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package icmp_test - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/checker" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/testutil" - "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" - "gvisor.dev/gvisor/pkg/waiter" -) - -// TODO(https://gvisor.dev/issues/5623): Finish unit testing the icmp package. -// See the issue for remaining areas of work. - -var ( - localV4Addr1 = testutil.MustParse4("10.0.0.1") - localV4Addr2 = testutil.MustParse4("10.0.0.2") - remoteV4Addr = testutil.MustParse4("10.0.0.3") -) - -func addNICWithDefaultRoute(t *testing.T, s *stack.Stack, id tcpip.NICID, name string, addrV4 tcpip.Address) *channel.Endpoint { - t.Helper() - - ep := channel.New(1 /* size */, header.IPv4MinimumMTU, "" /* linkAddr */) - t.Cleanup(ep.Close) - - wep := stack.LinkEndpoint(ep) - if testing.Verbose() { - wep = sniffer.New(ep) - } - - opts := stack.NICOptions{Name: name} - if err := s.CreateNICWithOptions(id, wep, opts); err != nil { - t.Fatalf("s.CreateNIC(%d, _) = %s", id, err) - } - - if err := s.AddAddress(id, ipv4.ProtocolNumber, addrV4); err != nil { - t.Fatalf("s.AddAddress(%d, %d, %s) = %s", id, ipv4.ProtocolNumber, addrV4, err) - } - - s.AddRoute(tcpip.Route{ - Destination: header.IPv4EmptySubnet, - NIC: id, - }) - - return ep -} - -func writePayload(buf []byte) { - for i := range buf { - buf[i] = byte(i) - } -} - -func newICMPv4EchoRequest(payloadSize uint32) buffer.View { - buf := buffer.NewView(header.ICMPv4MinimumSize + int(payloadSize)) - writePayload(buf[header.ICMPv4MinimumSize:]) - - icmp := header.ICMPv4(buf) - icmp.SetType(header.ICMPv4Echo) - // No need to set the checksum; it is reset by the socket before the packet - // is sent. - - return buf -} - -// TestWriteUnboundWithBindToDevice exercises writing to an unbound ICMP socket -// when SO_BINDTODEVICE is set to the non-default NIC for that subnet. -// -// Only IPv4 is tested. The logic to determine which NIC to use is agnostic to -// the version of IP. -func TestWriteUnboundWithBindToDevice(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4}, - HandleLocal: true, - }) - - // Add two NICs, both with default routes on the same subnet. The first NIC - // added will be the default NIC for that subnet. - defaultEP := addNICWithDefaultRoute(t, s, 1, "default", localV4Addr1) - alternateEP := addNICWithDefaultRoute(t, s, 2, "alternate", localV4Addr2) - - socket, err := s.NewEndpoint(icmp.ProtocolNumber4, ipv4.ProtocolNumber, &waiter.Queue{}) - if err != nil { - t.Fatalf("s.NewEndpoint(%d, %d, _) = %s", icmp.ProtocolNumber4, ipv4.ProtocolNumber, err) - } - defer socket.Close() - - echoPayloadSize := defaultEP.MTU() - header.IPv4MinimumSize - header.ICMPv4MinimumSize - - // Send a packet without SO_BINDTODEVICE. This verifies that the first NIC - // to be added is the default NIC to send packets when not explicitly bound. - { - buf := newICMPv4EchoRequest(echoPayloadSize) - r := buf.Reader() - n, err := socket.Write(&r, tcpip.WriteOptions{ - To: &tcpip.FullAddress{Addr: remoteV4Addr}, - }) - if err != nil { - t.Fatalf("socket.Write(_, {To:%s}) = %s", remoteV4Addr, err) - } - if n != int64(len(buf)) { - t.Fatalf("got n = %d, want n = %d", n, len(buf)) - } - - // Verify the packet was sent out the default NIC. - p, ok := defaultEP.Read() - if !ok { - t.Fatalf("got defaultEP.Read(_) = _, false; want = _, true (packet wasn't written out)") - } - - vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views()) - b := vv.ToView() - - checker.IPv4(t, b, []checker.NetworkChecker{ - checker.SrcAddr(localV4Addr1), - checker.DstAddr(remoteV4Addr), - checker.ICMPv4( - checker.ICMPv4Type(header.ICMPv4Echo), - checker.ICMPv4Payload(buf[header.ICMPv4MinimumSize:]), - ), - }...) - - // Verify the packet was not sent out the alternate NIC. - if p, ok := alternateEP.Read(); ok { - t.Fatalf("got alternateEP.Read(_) = %+v, true; want = _, false", p) - } - } - - // Send a packet with SO_BINDTODEVICE. This exercises reliance on - // SO_BINDTODEVICE to route the packet to the alternate NIC. - { - // Use SO_BINDTODEVICE to send over the alternate NIC by default. - socket.SocketOptions().SetBindToDevice(2) - - buf := newICMPv4EchoRequest(echoPayloadSize) - r := buf.Reader() - n, err := socket.Write(&r, tcpip.WriteOptions{ - To: &tcpip.FullAddress{Addr: remoteV4Addr}, - }) - if err != nil { - t.Fatalf("socket.Write(_, {To:%s}) = %s", tcpip.Address(remoteV4Addr), err) - } - if n != int64(len(buf)) { - t.Fatalf("got n = %d, want n = %d", n, len(buf)) - } - - // Verify the packet was not sent out the default NIC. - if p, ok := defaultEP.Read(); ok { - t.Fatalf("got defaultEP.Read(_) = %+v, true; want = _, false", p) - } - - // Verify the packet was sent out the alternate NIC. - p, ok := alternateEP.Read() - if !ok { - t.Fatalf("got alternateEP.Read(_) = _, false; want = _, true (packet wasn't written out)") - } - - vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views()) - b := vv.ToView() - - checker.IPv4(t, b, []checker.NetworkChecker{ - checker.SrcAddr(localV4Addr2), - checker.DstAddr(remoteV4Addr), - checker.ICMPv4( - checker.ICMPv4Type(header.ICMPv4Echo), - checker.ICMPv4Payload(buf[header.ICMPv4MinimumSize:]), - ), - }...) - } - - // Send a packet with SO_BINDTODEVICE cleared. This verifies that clearing - // the device binding will fallback to using the default NIC to send - // packets. - { - socket.SocketOptions().SetBindToDevice(0) - - buf := newICMPv4EchoRequest(echoPayloadSize) - r := buf.Reader() - n, err := socket.Write(&r, tcpip.WriteOptions{ - To: &tcpip.FullAddress{Addr: remoteV4Addr}, - }) - if err != nil { - t.Fatalf("socket.Write(_, {To:%s}) = %s", tcpip.Address(remoteV4Addr), err) - } - if n != int64(len(buf)) { - t.Fatalf("got n = %d, want n = %d", n, len(buf)) - } - - // Verify the packet was sent out the default NIC. - p, ok := defaultEP.Read() - if !ok { - t.Fatalf("got defaultEP.Read(_) = _, false; want = _, true (packet wasn't written out)") - } - - vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views()) - b := vv.ToView() - - checker.IPv4(t, b, []checker.NetworkChecker{ - checker.SrcAddr(localV4Addr1), - checker.DstAddr(remoteV4Addr), - checker.ICMPv4( - checker.ICMPv4Type(header.ICMPv4Echo), - checker.ICMPv4Payload(buf[header.ICMPv4MinimumSize:]), - ), - }...) - - // Verify the packet was not sent out the alternate NIC. - if p, ok := alternateEP.Read(); ok { - t.Fatalf("got alternateEP.Read(_) = %+v, true; want = _, false", p) - } - } -} diff --git a/pkg/tcpip/transport/packet/BUILD b/pkg/tcpip/transport/packet/BUILD deleted file mode 100644 index b989b1209..000000000 --- a/pkg/tcpip/transport/packet/BUILD +++ /dev/null @@ -1,37 +0,0 @@ -load("//tools:defs.bzl", "go_library") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -package(licenses = ["notice"]) - -go_template_instance( - name = "packet_list", - out = "packet_list.go", - package = "packet", - prefix = "packet", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*packet", - "Linker": "*packet", - }, -) - -go_library( - name = "packet", - srcs = [ - "endpoint.go", - "endpoint_state.go", - "packet_list.go", - ], - imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"], - visibility = ["//visibility:public"], - deps = [ - "//pkg/log", - "//pkg/sleep", - "//pkg/sync", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/stack", - "//pkg/waiter", - ], -) diff --git a/pkg/tcpip/transport/packet/packet_list.go b/pkg/tcpip/transport/packet/packet_list.go new file mode 100644 index 000000000..2c983aad0 --- /dev/null +++ b/pkg/tcpip/transport/packet/packet_list.go @@ -0,0 +1,221 @@ +package packet + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type packetElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (packetElementMapper) linkerFor(elem *packet) *packet { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type packetList struct { + head *packet + tail *packet +} + +// Reset resets list l to the empty state. +func (l *packetList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +// +//go:nosplit +func (l *packetList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +// +//go:nosplit +func (l *packetList) Front() *packet { + return l.head +} + +// Back returns the last element of list l or nil. +// +//go:nosplit +func (l *packetList) Back() *packet { + return l.tail +} + +// Len returns the number of elements in the list. +// +// NOTE: This is an O(n) operation. +// +//go:nosplit +func (l *packetList) Len() (count int) { + for e := l.Front(); e != nil; e = (packetElementMapper{}.linkerFor(e)).Next() { + count++ + } + return count +} + +// PushFront inserts the element e at the front of list l. +// +//go:nosplit +func (l *packetList) PushFront(e *packet) { + linker := packetElementMapper{}.linkerFor(e) + linker.SetNext(l.head) + linker.SetPrev(nil) + if l.head != nil { + packetElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +// +//go:nosplit +func (l *packetList) PushBack(e *packet) { + linker := packetElementMapper{}.linkerFor(e) + linker.SetNext(nil) + linker.SetPrev(l.tail) + if l.tail != nil { + packetElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +// +//go:nosplit +func (l *packetList) PushBackList(m *packetList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + packetElementMapper{}.linkerFor(l.tail).SetNext(m.head) + packetElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +// +//go:nosplit +func (l *packetList) InsertAfter(b, e *packet) { + bLinker := packetElementMapper{}.linkerFor(b) + eLinker := packetElementMapper{}.linkerFor(e) + + a := bLinker.Next() + + eLinker.SetNext(a) + eLinker.SetPrev(b) + bLinker.SetNext(e) + + if a != nil { + packetElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +// +//go:nosplit +func (l *packetList) InsertBefore(a, e *packet) { + aLinker := packetElementMapper{}.linkerFor(a) + eLinker := packetElementMapper{}.linkerFor(e) + + b := aLinker.Prev() + eLinker.SetNext(a) + eLinker.SetPrev(b) + aLinker.SetPrev(e) + + if b != nil { + packetElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +// +//go:nosplit +func (l *packetList) Remove(e *packet) { + linker := packetElementMapper{}.linkerFor(e) + prev := linker.Prev() + next := linker.Next() + + if prev != nil { + packetElementMapper{}.linkerFor(prev).SetNext(next) + } else if l.head == e { + l.head = next + } + + if next != nil { + packetElementMapper{}.linkerFor(next).SetPrev(prev) + } else if l.tail == e { + l.tail = prev + } + + linker.SetNext(nil) + linker.SetPrev(nil) +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type packetEntry struct { + next *packet + prev *packet +} + +// Next returns the entry that follows e in the list. +// +//go:nosplit +func (e *packetEntry) Next() *packet { + return e.next +} + +// Prev returns the entry that precedes e in the list. +// +//go:nosplit +func (e *packetEntry) Prev() *packet { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +// +//go:nosplit +func (e *packetEntry) SetNext(elem *packet) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +// +//go:nosplit +func (e *packetEntry) SetPrev(elem *packet) { + e.prev = elem +} diff --git a/pkg/tcpip/transport/packet/packet_state_autogen.go b/pkg/tcpip/transport/packet/packet_state_autogen.go new file mode 100644 index 000000000..2640a55d0 --- /dev/null +++ b/pkg/tcpip/transport/packet/packet_state_autogen.go @@ -0,0 +1,173 @@ +// automatically generated by stateify. + +package packet + +import ( + "gvisor.dev/gvisor/pkg/state" + "gvisor.dev/gvisor/pkg/tcpip/buffer" +) + +func (p *packet) StateTypeName() string { + return "pkg/tcpip/transport/packet.packet" +} + +func (p *packet) StateFields() []string { + return []string{ + "packetEntry", + "data", + "receivedAt", + "senderAddr", + "packetInfo", + } +} + +func (p *packet) beforeSave() {} + +// +checklocksignore +func (p *packet) StateSave(stateSinkObject state.Sink) { + p.beforeSave() + var dataValue buffer.VectorisedView + dataValue = p.saveData() + stateSinkObject.SaveValue(1, dataValue) + var receivedAtValue int64 + receivedAtValue = p.saveReceivedAt() + stateSinkObject.SaveValue(2, receivedAtValue) + stateSinkObject.Save(0, &p.packetEntry) + stateSinkObject.Save(3, &p.senderAddr) + stateSinkObject.Save(4, &p.packetInfo) +} + +func (p *packet) afterLoad() {} + +// +checklocksignore +func (p *packet) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &p.packetEntry) + stateSourceObject.Load(3, &p.senderAddr) + stateSourceObject.Load(4, &p.packetInfo) + stateSourceObject.LoadValue(1, new(buffer.VectorisedView), func(y interface{}) { p.loadData(y.(buffer.VectorisedView)) }) + stateSourceObject.LoadValue(2, new(int64), func(y interface{}) { p.loadReceivedAt(y.(int64)) }) +} + +func (ep *endpoint) StateTypeName() string { + return "pkg/tcpip/transport/packet.endpoint" +} + +func (ep *endpoint) StateFields() []string { + return []string{ + "TransportEndpointInfo", + "DefaultSocketOptionsHandler", + "netProto", + "waiterQueue", + "cooked", + "rcvList", + "rcvBufSize", + "rcvClosed", + "closed", + "bound", + "boundNIC", + "lastError", + "ops", + "frozen", + } +} + +// +checklocksignore +func (ep *endpoint) StateSave(stateSinkObject state.Sink) { + ep.beforeSave() + stateSinkObject.Save(0, &ep.TransportEndpointInfo) + stateSinkObject.Save(1, &ep.DefaultSocketOptionsHandler) + stateSinkObject.Save(2, &ep.netProto) + stateSinkObject.Save(3, &ep.waiterQueue) + stateSinkObject.Save(4, &ep.cooked) + stateSinkObject.Save(5, &ep.rcvList) + stateSinkObject.Save(6, &ep.rcvBufSize) + stateSinkObject.Save(7, &ep.rcvClosed) + stateSinkObject.Save(8, &ep.closed) + stateSinkObject.Save(9, &ep.bound) + stateSinkObject.Save(10, &ep.boundNIC) + stateSinkObject.Save(11, &ep.lastError) + stateSinkObject.Save(12, &ep.ops) + stateSinkObject.Save(13, &ep.frozen) +} + +// +checklocksignore +func (ep *endpoint) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &ep.TransportEndpointInfo) + stateSourceObject.Load(1, &ep.DefaultSocketOptionsHandler) + stateSourceObject.Load(2, &ep.netProto) + stateSourceObject.Load(3, &ep.waiterQueue) + stateSourceObject.Load(4, &ep.cooked) + stateSourceObject.Load(5, &ep.rcvList) + stateSourceObject.Load(6, &ep.rcvBufSize) + stateSourceObject.Load(7, &ep.rcvClosed) + stateSourceObject.Load(8, &ep.closed) + stateSourceObject.Load(9, &ep.bound) + stateSourceObject.Load(10, &ep.boundNIC) + stateSourceObject.Load(11, &ep.lastError) + stateSourceObject.Load(12, &ep.ops) + stateSourceObject.Load(13, &ep.frozen) + stateSourceObject.AfterLoad(ep.afterLoad) +} + +func (l *packetList) StateTypeName() string { + return "pkg/tcpip/transport/packet.packetList" +} + +func (l *packetList) StateFields() []string { + return []string{ + "head", + "tail", + } +} + +func (l *packetList) beforeSave() {} + +// +checklocksignore +func (l *packetList) StateSave(stateSinkObject state.Sink) { + l.beforeSave() + stateSinkObject.Save(0, &l.head) + stateSinkObject.Save(1, &l.tail) +} + +func (l *packetList) afterLoad() {} + +// +checklocksignore +func (l *packetList) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &l.head) + stateSourceObject.Load(1, &l.tail) +} + +func (e *packetEntry) StateTypeName() string { + return "pkg/tcpip/transport/packet.packetEntry" +} + +func (e *packetEntry) StateFields() []string { + return []string{ + "next", + "prev", + } +} + +func (e *packetEntry) beforeSave() {} + +// +checklocksignore +func (e *packetEntry) StateSave(stateSinkObject state.Sink) { + e.beforeSave() + stateSinkObject.Save(0, &e.next) + stateSinkObject.Save(1, &e.prev) +} + +func (e *packetEntry) afterLoad() {} + +// +checklocksignore +func (e *packetEntry) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &e.next) + stateSourceObject.Load(1, &e.prev) +} + +func init() { + state.Register((*packet)(nil)) + state.Register((*endpoint)(nil)) + state.Register((*packetList)(nil)) + state.Register((*packetEntry)(nil)) +} diff --git a/pkg/tcpip/transport/raw/BUILD b/pkg/tcpip/transport/raw/BUILD deleted file mode 100644 index 2eab09088..000000000 --- a/pkg/tcpip/transport/raw/BUILD +++ /dev/null @@ -1,39 +0,0 @@ -load("//tools:defs.bzl", "go_library") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -package(licenses = ["notice"]) - -go_template_instance( - name = "raw_packet_list", - out = "raw_packet_list.go", - package = "raw", - prefix = "rawPacket", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*rawPacket", - "Linker": "*rawPacket", - }, -) - -go_library( - name = "raw", - srcs = [ - "endpoint.go", - "endpoint_state.go", - "protocol.go", - "raw_packet_list.go", - ], - imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"], - visibility = ["//visibility:public"], - deps = [ - "//pkg/log", - "//pkg/sleep", - "//pkg/sync", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/stack", - "//pkg/tcpip/transport/packet", - "//pkg/waiter", - ], -) diff --git a/pkg/tcpip/transport/raw/raw_packet_list.go b/pkg/tcpip/transport/raw/raw_packet_list.go new file mode 100644 index 000000000..48804ff1b --- /dev/null +++ b/pkg/tcpip/transport/raw/raw_packet_list.go @@ -0,0 +1,221 @@ +package raw + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type rawPacketElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (rawPacketElementMapper) linkerFor(elem *rawPacket) *rawPacket { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type rawPacketList struct { + head *rawPacket + tail *rawPacket +} + +// Reset resets list l to the empty state. +func (l *rawPacketList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +// +//go:nosplit +func (l *rawPacketList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +// +//go:nosplit +func (l *rawPacketList) Front() *rawPacket { + return l.head +} + +// Back returns the last element of list l or nil. +// +//go:nosplit +func (l *rawPacketList) Back() *rawPacket { + return l.tail +} + +// Len returns the number of elements in the list. +// +// NOTE: This is an O(n) operation. +// +//go:nosplit +func (l *rawPacketList) Len() (count int) { + for e := l.Front(); e != nil; e = (rawPacketElementMapper{}.linkerFor(e)).Next() { + count++ + } + return count +} + +// PushFront inserts the element e at the front of list l. +// +//go:nosplit +func (l *rawPacketList) PushFront(e *rawPacket) { + linker := rawPacketElementMapper{}.linkerFor(e) + linker.SetNext(l.head) + linker.SetPrev(nil) + if l.head != nil { + rawPacketElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +// +//go:nosplit +func (l *rawPacketList) PushBack(e *rawPacket) { + linker := rawPacketElementMapper{}.linkerFor(e) + linker.SetNext(nil) + linker.SetPrev(l.tail) + if l.tail != nil { + rawPacketElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +// +//go:nosplit +func (l *rawPacketList) PushBackList(m *rawPacketList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + rawPacketElementMapper{}.linkerFor(l.tail).SetNext(m.head) + rawPacketElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +// +//go:nosplit +func (l *rawPacketList) InsertAfter(b, e *rawPacket) { + bLinker := rawPacketElementMapper{}.linkerFor(b) + eLinker := rawPacketElementMapper{}.linkerFor(e) + + a := bLinker.Next() + + eLinker.SetNext(a) + eLinker.SetPrev(b) + bLinker.SetNext(e) + + if a != nil { + rawPacketElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +// +//go:nosplit +func (l *rawPacketList) InsertBefore(a, e *rawPacket) { + aLinker := rawPacketElementMapper{}.linkerFor(a) + eLinker := rawPacketElementMapper{}.linkerFor(e) + + b := aLinker.Prev() + eLinker.SetNext(a) + eLinker.SetPrev(b) + aLinker.SetPrev(e) + + if b != nil { + rawPacketElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +// +//go:nosplit +func (l *rawPacketList) Remove(e *rawPacket) { + linker := rawPacketElementMapper{}.linkerFor(e) + prev := linker.Prev() + next := linker.Next() + + if prev != nil { + rawPacketElementMapper{}.linkerFor(prev).SetNext(next) + } else if l.head == e { + l.head = next + } + + if next != nil { + rawPacketElementMapper{}.linkerFor(next).SetPrev(prev) + } else if l.tail == e { + l.tail = prev + } + + linker.SetNext(nil) + linker.SetPrev(nil) +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type rawPacketEntry struct { + next *rawPacket + prev *rawPacket +} + +// Next returns the entry that follows e in the list. +// +//go:nosplit +func (e *rawPacketEntry) Next() *rawPacket { + return e.next +} + +// Prev returns the entry that precedes e in the list. +// +//go:nosplit +func (e *rawPacketEntry) Prev() *rawPacket { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +// +//go:nosplit +func (e *rawPacketEntry) SetNext(elem *rawPacket) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +// +//go:nosplit +func (e *rawPacketEntry) SetPrev(elem *rawPacket) { + e.prev = elem +} diff --git a/pkg/tcpip/transport/raw/raw_state_autogen.go b/pkg/tcpip/transport/raw/raw_state_autogen.go new file mode 100644 index 000000000..0d65444d2 --- /dev/null +++ b/pkg/tcpip/transport/raw/raw_state_autogen.go @@ -0,0 +1,167 @@ +// automatically generated by stateify. + +package raw + +import ( + "gvisor.dev/gvisor/pkg/state" + "gvisor.dev/gvisor/pkg/tcpip/buffer" +) + +func (p *rawPacket) StateTypeName() string { + return "pkg/tcpip/transport/raw.rawPacket" +} + +func (p *rawPacket) StateFields() []string { + return []string{ + "rawPacketEntry", + "data", + "receivedAt", + "senderAddr", + } +} + +func (p *rawPacket) beforeSave() {} + +// +checklocksignore +func (p *rawPacket) StateSave(stateSinkObject state.Sink) { + p.beforeSave() + var dataValue buffer.VectorisedView + dataValue = p.saveData() + stateSinkObject.SaveValue(1, dataValue) + var receivedAtValue int64 + receivedAtValue = p.saveReceivedAt() + stateSinkObject.SaveValue(2, receivedAtValue) + stateSinkObject.Save(0, &p.rawPacketEntry) + stateSinkObject.Save(3, &p.senderAddr) +} + +func (p *rawPacket) afterLoad() {} + +// +checklocksignore +func (p *rawPacket) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &p.rawPacketEntry) + stateSourceObject.Load(3, &p.senderAddr) + stateSourceObject.LoadValue(1, new(buffer.VectorisedView), func(y interface{}) { p.loadData(y.(buffer.VectorisedView)) }) + stateSourceObject.LoadValue(2, new(int64), func(y interface{}) { p.loadReceivedAt(y.(int64)) }) +} + +func (e *endpoint) StateTypeName() string { + return "pkg/tcpip/transport/raw.endpoint" +} + +func (e *endpoint) StateFields() []string { + return []string{ + "TransportEndpointInfo", + "DefaultSocketOptionsHandler", + "waiterQueue", + "associated", + "rcvList", + "rcvBufSize", + "rcvClosed", + "closed", + "connected", + "bound", + "owner", + "ops", + "frozen", + } +} + +// +checklocksignore +func (e *endpoint) StateSave(stateSinkObject state.Sink) { + e.beforeSave() + stateSinkObject.Save(0, &e.TransportEndpointInfo) + stateSinkObject.Save(1, &e.DefaultSocketOptionsHandler) + stateSinkObject.Save(2, &e.waiterQueue) + stateSinkObject.Save(3, &e.associated) + stateSinkObject.Save(4, &e.rcvList) + stateSinkObject.Save(5, &e.rcvBufSize) + stateSinkObject.Save(6, &e.rcvClosed) + stateSinkObject.Save(7, &e.closed) + stateSinkObject.Save(8, &e.connected) + stateSinkObject.Save(9, &e.bound) + stateSinkObject.Save(10, &e.owner) + stateSinkObject.Save(11, &e.ops) + stateSinkObject.Save(12, &e.frozen) +} + +// +checklocksignore +func (e *endpoint) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &e.TransportEndpointInfo) + stateSourceObject.Load(1, &e.DefaultSocketOptionsHandler) + stateSourceObject.Load(2, &e.waiterQueue) + stateSourceObject.Load(3, &e.associated) + stateSourceObject.Load(4, &e.rcvList) + stateSourceObject.Load(5, &e.rcvBufSize) + stateSourceObject.Load(6, &e.rcvClosed) + stateSourceObject.Load(7, &e.closed) + stateSourceObject.Load(8, &e.connected) + stateSourceObject.Load(9, &e.bound) + stateSourceObject.Load(10, &e.owner) + stateSourceObject.Load(11, &e.ops) + stateSourceObject.Load(12, &e.frozen) + stateSourceObject.AfterLoad(e.afterLoad) +} + +func (l *rawPacketList) StateTypeName() string { + return "pkg/tcpip/transport/raw.rawPacketList" +} + +func (l *rawPacketList) StateFields() []string { + return []string{ + "head", + "tail", + } +} + +func (l *rawPacketList) beforeSave() {} + +// +checklocksignore +func (l *rawPacketList) StateSave(stateSinkObject state.Sink) { + l.beforeSave() + stateSinkObject.Save(0, &l.head) + stateSinkObject.Save(1, &l.tail) +} + +func (l *rawPacketList) afterLoad() {} + +// +checklocksignore +func (l *rawPacketList) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &l.head) + stateSourceObject.Load(1, &l.tail) +} + +func (e *rawPacketEntry) StateTypeName() string { + return "pkg/tcpip/transport/raw.rawPacketEntry" +} + +func (e *rawPacketEntry) StateFields() []string { + return []string{ + "next", + "prev", + } +} + +func (e *rawPacketEntry) beforeSave() {} + +// +checklocksignore +func (e *rawPacketEntry) StateSave(stateSinkObject state.Sink) { + e.beforeSave() + stateSinkObject.Save(0, &e.next) + stateSinkObject.Save(1, &e.prev) +} + +func (e *rawPacketEntry) afterLoad() {} + +// +checklocksignore +func (e *rawPacketEntry) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &e.next) + stateSourceObject.Load(1, &e.prev) +} + +func init() { + state.Register((*rawPacket)(nil)) + state.Register((*endpoint)(nil)) + state.Register((*rawPacketList)(nil)) + state.Register((*rawPacketEntry)(nil)) +} diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD deleted file mode 100644 index 8436d2cf0..000000000 --- a/pkg/tcpip/transport/tcp/BUILD +++ /dev/null @@ -1,139 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test", "more_shards") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -package(licenses = ["notice"]) - -go_template_instance( - name = "tcp_segment_list", - out = "tcp_segment_list.go", - package = "tcp", - prefix = "segment", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*segment", - "Linker": "*segment", - }, -) - -go_template_instance( - name = "tcp_endpoint_list", - out = "tcp_endpoint_list.go", - package = "tcp", - prefix = "endpoint", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*endpoint", - "Linker": "*endpoint", - }, -) - -go_library( - name = "tcp", - srcs = [ - "accept.go", - "connect.go", - "connect_unsafe.go", - "cubic.go", - "dispatcher.go", - "endpoint.go", - "endpoint_state.go", - "forwarder.go", - "protocol.go", - "rack.go", - "rcv.go", - "reno.go", - "reno_recovery.go", - "sack.go", - "sack_recovery.go", - "sack_scoreboard.go", - "segment.go", - "segment_heap.go", - "segment_queue.go", - "segment_state.go", - "segment_unsafe.go", - "snd.go", - "tcp_endpoint_list.go", - "tcp_segment_list.go", - "timer.go", - ], - imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"], - visibility = ["//visibility:public"], - deps = [ - "//pkg/log", - "//pkg/rand", - "//pkg/sleep", - "//pkg/sync", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/hash/jenkins", - "//pkg/tcpip/header", - "//pkg/tcpip/header/parse", - "//pkg/tcpip/ports", - "//pkg/tcpip/seqnum", - "//pkg/tcpip/stack", - "//pkg/tcpip/transport/raw", - "//pkg/waiter", - "@com_github_google_btree//:go_default_library", - ], -) - -go_test( - name = "tcp_x_test", - size = "medium", - srcs = [ - "dual_stack_test.go", - "sack_scoreboard_test.go", - "tcp_noracedetector_test.go", - "tcp_rack_test.go", - "tcp_sack_test.go", - "tcp_test.go", - "tcp_timestamp_test.go", - ], - shard_count = more_shards, - deps = [ - ":tcp", - "//pkg/rand", - "//pkg/sync", - "//pkg/tcpip", - "//pkg/tcpip/checker", - "//pkg/tcpip/header", - "//pkg/tcpip/link/loopback", - "//pkg/tcpip/link/sniffer", - "//pkg/tcpip/network/ipv4", - "//pkg/tcpip/network/ipv6", - "//pkg/tcpip/seqnum", - "//pkg/tcpip/stack", - "//pkg/tcpip/testutil", - "//pkg/tcpip/transport/tcp/testing/context", - "//pkg/test/testutil", - "//pkg/waiter", - "@com_github_google_go_cmp//cmp:go_default_library", - ], -) - -go_test( - name = "rcv_test", - size = "small", - srcs = ["rcv_test.go"], - deps = [ - "//pkg/tcpip/header", - "//pkg/tcpip/seqnum", - ], -) - -go_test( - name = "tcp_test", - size = "small", - srcs = [ - "segment_test.go", - "timer_test.go", - ], - library = ":tcp", - deps = [ - "//pkg/sleep", - "//pkg/tcpip/buffer", - "//pkg/tcpip/faketime", - "//pkg/tcpip/stack", - "@com_github_google_go_cmp//cmp:go_default_library", - ], -) diff --git a/pkg/tcpip/transport/tcp/dual_stack_test.go b/pkg/tcpip/transport/tcp/dual_stack_test.go deleted file mode 100644 index 5342aacfd..000000000 --- a/pkg/tcpip/transport/tcp/dual_stack_test.go +++ /dev/null @@ -1,650 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tcp_test - -import ( - "strings" - "testing" - "time" - - "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/checker" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/seqnum" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context" - "gvisor.dev/gvisor/pkg/waiter" -) - -func TestV4MappedConnectOnV6Only(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(true) - - // Start connection attempt, it must fail. - err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV4MappedAddr, Port: context.TestPort}) - if d := cmp.Diff(&tcpip.ErrNoRoute{}, err); d != "" { - t.Fatalf("c.EP.Connect(...) mismatch (-want +got):\n%s", d) - } -} - -func testV4Connect(t *testing.T, c *context.Context, checkers ...checker.NetworkChecker) { - // Start connection attempt. - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.WritableEvents) - defer c.WQ.EventUnregister(&we) - - err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV4MappedAddr, Port: context.TestPort}) - if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" { - t.Fatalf("c.EP.Connect(...) mismatch (-want +got):\n%s", d) - } - - // Receive SYN packet. - b := c.GetPacket() - synCheckers := append(checkers, checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagSyn), - )) - checker.IPv4(t, b, synCheckers...) - - tcp := header.TCP(header.IPv4(b).Payload()) - c.IRS = seqnum.Value(tcp.SequenceNumber()) - - iss := seqnum.Value(789) - c.SendPacket(nil, &context.Headers{ - SrcPort: tcp.DestinationPort(), - DstPort: tcp.SourcePort(), - Flags: header.TCPFlagSyn | header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - // Receive ACK packet. - ackCheckers := append(checkers, checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)+1), - )) - checker.IPv4(t, c.GetPacket(), ackCheckers...) - - // Wait for connection to be established. - select { - case <-ch: - if err := c.EP.LastError(); err != nil { - t.Fatalf("Unexpected error when connecting: %v", err) - } - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for connection") - } -} - -func TestV4MappedConnect(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(false) - - // Test the connection request. - testV4Connect(t, c) -} - -func TestV4ConnectWhenBoundToWildcard(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(false) - - // Bind to wildcard. - if err := c.EP.Bind(tcpip.FullAddress{}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - // Test the connection request. - testV4Connect(t, c) -} - -func TestV4ConnectWhenBoundToV4MappedWildcard(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(false) - - // Bind to v4 mapped wildcard. - if err := c.EP.Bind(tcpip.FullAddress{Addr: context.V4MappedWildcardAddr}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - // Test the connection request. - testV4Connect(t, c) -} - -func TestV4ConnectWhenBoundToV4Mapped(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(false) - - // Bind to v4 mapped address. - if err := c.EP.Bind(tcpip.FullAddress{Addr: context.StackV4MappedAddr}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - // Test the connection request. - testV4Connect(t, c) -} - -func testV6Connect(t *testing.T, c *context.Context, checkers ...checker.NetworkChecker) { - // Start connection attempt to IPv6 address. - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.WritableEvents) - defer c.WQ.EventUnregister(&we) - - err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV6Addr, Port: context.TestPort}) - if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" { - t.Fatalf("Connect(...) mismatch (-want +got):\n%s", d) - } - - // Receive SYN packet. - b := c.GetV6Packet() - synCheckers := append(checkers, checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagSyn), - )) - checker.IPv6(t, b, synCheckers...) - - tcp := header.TCP(header.IPv6(b).Payload()) - c.IRS = seqnum.Value(tcp.SequenceNumber()) - - iss := seqnum.Value(789) - c.SendV6Packet(nil, &context.Headers{ - SrcPort: tcp.DestinationPort(), - DstPort: tcp.SourcePort(), - Flags: header.TCPFlagSyn | header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - // Receive ACK packet. - ackCheckers := append(checkers, checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)+1), - )) - checker.IPv6(t, c.GetV6Packet(), ackCheckers...) - - // Wait for connection to be established. - select { - case <-ch: - if err := c.EP.LastError(); err != nil { - t.Fatalf("Unexpected error when connecting: %v", err) - } - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for connection") - } -} - -func TestV6Connect(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(false) - - // Test the connection request. - testV6Connect(t, c) -} - -func TestV6ConnectV6Only(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(true) - - // Test the connection request. - testV6Connect(t, c) -} - -func TestV6ConnectWhenBoundToWildcard(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(false) - - // Bind to wildcard. - if err := c.EP.Bind(tcpip.FullAddress{}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - // Test the connection request. - testV6Connect(t, c) -} - -func TestStackV6OnlyConnectWhenBoundToWildcard(t *testing.T) { - c := context.NewWithOpts(t, context.Options{ - EnableV6: true, - MTU: defaultMTU, - }) - defer c.Cleanup() - - // Create a v6 endpoint but don't set the v6-only TCP option. - c.CreateV6Endpoint(false) - - // Bind to wildcard. - if err := c.EP.Bind(tcpip.FullAddress{}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - // Test the connection request. - testV6Connect(t, c) -} - -func TestV6ConnectWhenBoundToLocalAddress(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(false) - - // Bind to local address. - if err := c.EP.Bind(tcpip.FullAddress{Addr: context.StackV6Addr}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - // Test the connection request. - testV6Connect(t, c) -} - -func TestV4RefuseOnV6Only(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(true) - - // Bind to wildcard. - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - // Start listening. - if err := c.EP.Listen(10); err != nil { - t.Fatalf("Listen failed: %v", err) - } - - // Send a SYN request. - irs := seqnum.Value(789) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: irs, - RcvWnd: 30000, - }) - - // Receive the RST reply. - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck), - checker.TCPAckNum(uint32(irs)+1), - ), - ) -} - -func TestV6RefuseOnBoundToV4Mapped(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(false) - - // Bind and listen. - if err := c.EP.Bind(tcpip.FullAddress{Addr: context.V4MappedWildcardAddr, Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - if err := c.EP.Listen(10); err != nil { - t.Fatalf("Listen failed: %v", err) - } - - // Send a SYN request. - irs := seqnum.Value(789) - c.SendV6Packet(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: irs, - RcvWnd: 30000, - }) - - // Receive the RST reply. - checker.IPv6(t, c.GetV6Packet(), - checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck), - checker.TCPAckNum(uint32(irs)+1), - ), - ) -} - -func testV4Accept(t *testing.T, c *context.Context) { - c.SetGSOEnabled(true) - defer c.SetGSOEnabled(false) - - // Start listening. - if err := c.EP.Listen(10); err != nil { - t.Fatalf("Listen failed: %v", err) - } - - // Send a SYN request. - irs := seqnum.Value(789) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: irs, - RcvWnd: 30000, - }) - - // Receive the SYN-ACK reply. - b := c.GetPacket() - tcp := header.TCP(header.IPv4(b).Payload()) - iss := seqnum.Value(tcp.SequenceNumber()) - checker.IPv4(t, b, - checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn), - checker.TCPAckNum(uint32(irs)+1), - ), - ) - - // Send ACK. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: irs + 1, - AckNum: iss + 1, - RcvWnd: 30000, - }) - - // Try to accept the connection. - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.ReadableEvents) - defer c.WQ.EventUnregister(&we) - - nep, _, err := c.EP.Accept(nil) - if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { - // Wait for connection to be established. - select { - case <-ch: - nep, _, err = c.EP.Accept(nil) - if err != nil { - t.Fatalf("Accept failed: %v", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - - // Check the peer address. - addr, err := nep.GetRemoteAddress() - if err != nil { - t.Fatalf("GetRemoteAddress failed failed: %v", err) - } - - if addr.Addr != context.TestAddr { - t.Fatalf("Unexpected remote address: got %v, want %v", addr.Addr, context.TestAddr) - } - - var r strings.Reader - data := "Don't panic" - r.Reset(data) - nep.Write(&r, tcpip.WriteOptions{}) - b = c.GetPacket() - tcp = header.IPv4(b).Payload() - if string(tcp.Payload()) != data { - t.Fatalf("Unexpected data: got %v, want %v", string(tcp.Payload()), data) - } -} - -func TestV4AcceptOnV6(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(false) - - // Bind to wildcard. - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - // Test acceptance. - testV4Accept(t, c) -} - -func TestV4AcceptOnBoundToV4MappedWildcard(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(false) - - // Bind to v4 mapped wildcard. - if err := c.EP.Bind(tcpip.FullAddress{Addr: context.V4MappedWildcardAddr, Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - // Test acceptance. - testV4Accept(t, c) -} - -func TestV4AcceptOnBoundToV4Mapped(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(false) - - // Bind and listen. - if err := c.EP.Bind(tcpip.FullAddress{Addr: context.StackV4MappedAddr, Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - // Test acceptance. - testV4Accept(t, c) -} - -func TestV6AcceptOnV6(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(false) - - // Bind and listen. - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - if err := c.EP.Listen(10); err != nil { - t.Fatalf("Listen failed: %v", err) - } - - // Send a SYN request. - irs := seqnum.Value(789) - c.SendV6Packet(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: irs, - RcvWnd: 30000, - }) - - // Receive the SYN-ACK reply. - b := c.GetV6Packet() - tcp := header.TCP(header.IPv6(b).Payload()) - iss := seqnum.Value(tcp.SequenceNumber()) - checker.IPv6(t, b, - checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn), - checker.TCPAckNum(uint32(irs)+1), - ), - ) - - // Send ACK. - c.SendV6Packet(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: irs + 1, - AckNum: iss + 1, - RcvWnd: 30000, - }) - - // Try to accept the connection. - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.ReadableEvents) - defer c.WQ.EventUnregister(&we) - var addr tcpip.FullAddress - _, _, err := c.EP.Accept(&addr) - if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { - // Wait for connection to be established. - select { - case <-ch: - _, _, err = c.EP.Accept(&addr) - if err != nil { - t.Fatalf("Accept failed: %v", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - - if addr.Addr != context.TestV6Addr { - t.Errorf("Unexpected remote address: got %s, want %s", addr.Addr, context.TestV6Addr) - } -} - -func TestV4AcceptOnV4(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Create TCP endpoint. - var err tcpip.Error - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - - // Bind to wildcard. - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - // Test acceptance. - testV4Accept(t, c) -} - -func testV4ListenClose(t *testing.T, c *context.Context) { - opt := tcpip.TCPAlwaysUseSynCookies(true) - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err) - } - - const n = 32 - - // Start listening. - if err := c.EP.Listen(n); err != nil { - t.Fatalf("Listen failed: %v", err) - } - - irs := seqnum.Value(789) - for i := uint16(0); i < n; i++ { - // Send a SYN request. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort + i, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: irs, - RcvWnd: 30000, - }) - } - - // Each of these ACKs will cause a syn-cookie based connection to be - // accepted and delivered to the listening endpoint. - for i := 0; i < n; i++ { - b := c.GetPacket() - tcp := header.TCP(header.IPv4(b).Payload()) - iss := seqnum.Value(tcp.SequenceNumber()) - // Send ACK. - c.SendPacket(nil, &context.Headers{ - SrcPort: tcp.DestinationPort(), - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: irs + 1, - AckNum: iss + 1, - RcvWnd: 30000, - }) - } - - // Try to accept the connection. - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.ReadableEvents) - defer c.WQ.EventUnregister(&we) - nep, _, err := c.EP.Accept(nil) - if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { - // Wait for connection to be established. - select { - case <-ch: - nep, _, err = c.EP.Accept(nil) - if err != nil { - t.Fatalf("Accept failed: %v", err) - } - - case <-time.After(10 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - nep.Close() - c.EP.Close() -} - -func TestV4ListenCloseOnV4(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Create TCP endpoint. - var err tcpip.Error - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - - // Bind to wildcard. - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - // Test acceptance. - testV4ListenClose(t, c) -} diff --git a/pkg/tcpip/transport/tcp/rcv_test.go b/pkg/tcpip/transport/tcp/rcv_test.go deleted file mode 100644 index 8a026ec46..000000000 --- a/pkg/tcpip/transport/tcp/rcv_test.go +++ /dev/null @@ -1,74 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package rcv_test - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/seqnum" -) - -func TestAcceptable(t *testing.T) { - for _, tt := range []struct { - segSeq seqnum.Value - segLen seqnum.Size - rcvNxt, rcvAcc seqnum.Value - want bool - }{ - // The segment is smaller than the window. - {105, 2, 100, 104, false}, - {105, 2, 101, 105, true}, - {105, 2, 102, 106, true}, - {105, 2, 103, 107, true}, - {105, 2, 104, 108, true}, - {105, 2, 105, 109, true}, - {105, 2, 106, 110, true}, - {105, 2, 107, 111, false}, - - // The segment is larger than the window. - {105, 4, 103, 105, true}, - {105, 4, 104, 106, true}, - {105, 4, 105, 107, true}, - {105, 4, 106, 108, true}, - {105, 4, 107, 109, true}, - {105, 4, 108, 110, true}, - {105, 4, 109, 111, false}, - {105, 4, 110, 112, false}, - - // The segment has no width. - {105, 0, 100, 102, false}, - {105, 0, 101, 103, false}, - {105, 0, 102, 104, false}, - {105, 0, 103, 105, true}, - {105, 0, 104, 106, true}, - {105, 0, 105, 107, true}, - {105, 0, 106, 108, false}, - {105, 0, 107, 109, false}, - - // The receive window has no width. - {105, 2, 103, 103, false}, - {105, 2, 104, 104, false}, - {105, 2, 105, 105, false}, - {105, 2, 106, 106, false}, - {105, 2, 107, 107, false}, - {105, 2, 108, 108, false}, - {105, 2, 109, 109, false}, - } { - if got := header.Acceptable(tt.segSeq, tt.segLen, tt.rcvNxt, tt.rcvAcc); got != tt.want { - t.Errorf("header.Acceptable(%d, %d, %d, %d) = %t, want %t", tt.segSeq, tt.segLen, tt.rcvNxt, tt.rcvAcc, got, tt.want) - } - } -} diff --git a/pkg/tcpip/transport/tcp/sack_scoreboard_test.go b/pkg/tcpip/transport/tcp/sack_scoreboard_test.go deleted file mode 100644 index b4e5ba0df..000000000 --- a/pkg/tcpip/transport/tcp/sack_scoreboard_test.go +++ /dev/null @@ -1,249 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tcp_test - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/seqnum" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" -) - -const smss = 1500 - -func initScoreboard(blocks []header.SACKBlock, iss seqnum.Value) *tcp.SACKScoreboard { - s := tcp.NewSACKScoreboard(smss, iss) - for _, blk := range blocks { - s.Insert(blk) - } - return s -} - -func TestSACKScoreboardIsSACKED(t *testing.T) { - type blockTest struct { - block header.SACKBlock - sacked bool - } - testCases := []struct { - comment string - scoreboardBlocks []header.SACKBlock - blockTests []blockTest - iss seqnum.Value - }{ - { - "Test holes and unsacked SACK blocks in SACKed ranges and insertion of overlapping SACK blocks", - []header.SACKBlock{{10, 20}, {10, 30}, {30, 40}, {41, 50}, {5, 10}, {1, 50}, {111, 120}, {101, 110}, {52, 120}}, - []blockTest{ - {header.SACKBlock{15, 21}, true}, - {header.SACKBlock{200, 201}, false}, - {header.SACKBlock{50, 51}, false}, - {header.SACKBlock{53, 120}, true}, - }, - 0, - }, - { - "Test disjoint SACKBlocks", - []header.SACKBlock{{2288624809, 2288810057}, {2288811477, 2288838565}}, - []blockTest{ - {header.SACKBlock{2288624809, 2288810057}, true}, - {header.SACKBlock{2288811477, 2288838565}, true}, - {header.SACKBlock{2288810057, 2288811477}, false}, - }, - 2288624809, - }, - { - "Test sequence number wrap around", - []header.SACKBlock{{4294254144, 225652}, {5340409, 5350509}}, - []blockTest{ - {header.SACKBlock{4294254144, 4294254145}, true}, - {header.SACKBlock{4294254143, 4294254144}, false}, - {header.SACKBlock{4294254144, 1}, true}, - {header.SACKBlock{225652, 5350509}, false}, - {header.SACKBlock{5340409, 5350509}, true}, - {header.SACKBlock{5350509, 5350609}, false}, - }, - 4294254144, - }, - { - "Test disjoint SACKBlocks out of order", - []header.SACKBlock{{827450276, 827454536}, {827426028, 827428868}}, - []blockTest{ - {header.SACKBlock{827426028, 827428867}, true}, - {header.SACKBlock{827450168, 827450275}, false}, - }, - 827426000, - }, - } - for _, tc := range testCases { - sb := initScoreboard(tc.scoreboardBlocks, tc.iss) - for _, blkTest := range tc.blockTests { - if want, got := blkTest.sacked, sb.IsSACKED(blkTest.block); got != want { - t.Errorf("%s: s.IsSACKED(%v) = %v, want %v", tc.comment, blkTest.block, got, want) - } - } - } -} - -func TestSACKScoreboardIsRangeLost(t *testing.T) { - s := tcp.NewSACKScoreboard(10, 0) - s.Insert(header.SACKBlock{1, 25}) - s.Insert(header.SACKBlock{25, 50}) - s.Insert(header.SACKBlock{51, 100}) - s.Insert(header.SACKBlock{111, 120}) - s.Insert(header.SACKBlock{101, 110}) - s.Insert(header.SACKBlock{121, 141}) - s.Insert(header.SACKBlock{145, 146}) - s.Insert(header.SACKBlock{147, 148}) - s.Insert(header.SACKBlock{149, 150}) - s.Insert(header.SACKBlock{153, 154}) - s.Insert(header.SACKBlock{155, 156}) - testCases := []struct { - block header.SACKBlock - lost bool - }{ - // Block not covered by SACK block and has more than - // nDupAckThreshold discontiguous SACK blocks after it as well - // as (nDupAckThreshold -1) * 10 (smss) bytes that have been - // SACKED above the sequence number covered by this block. - {block: header.SACKBlock{0, 1}, lost: true}, - - // These blocks have all been SACKed and should not be - // considered lost. - {block: header.SACKBlock{1, 2}, lost: false}, - {block: header.SACKBlock{25, 26}, lost: false}, - {block: header.SACKBlock{1, 45}, lost: false}, - - // Same as the first case above. - {block: header.SACKBlock{50, 51}, lost: true}, - - // This block has been SACKed and should not be considered lost. - {block: header.SACKBlock{119, 120}, lost: false}, - - // This one should return true because there are > - // (nDupAckThreshold - 1) * 10 (smss) bytes that have been - // sacked above this sequence number. - {block: header.SACKBlock{120, 121}, lost: true}, - - // This block has been SACKed and should not be considered lost. - {block: header.SACKBlock{125, 126}, lost: false}, - - // This block has not been SACKed and there are nDupAckThreshold - // number of SACKed blocks after it. - {block: header.SACKBlock{141, 145}, lost: true}, - - // This block has not been SACKed and there are less than - // nDupAckThreshold SACKed sequences after it. - {block: header.SACKBlock{151, 152}, lost: false}, - } - for _, tc := range testCases { - if want, got := tc.lost, s.IsRangeLost(tc.block); got != want { - t.Errorf("s.IsRangeLost(%v) = %v, want %v", tc.block, got, want) - } - } -} - -func TestSACKScoreboardIsLost(t *testing.T) { - s := tcp.NewSACKScoreboard(10, 0) - s.Insert(header.SACKBlock{1, 25}) - s.Insert(header.SACKBlock{25, 50}) - s.Insert(header.SACKBlock{51, 100}) - s.Insert(header.SACKBlock{111, 120}) - s.Insert(header.SACKBlock{101, 110}) - s.Insert(header.SACKBlock{121, 141}) - s.Insert(header.SACKBlock{121, 141}) - s.Insert(header.SACKBlock{145, 146}) - s.Insert(header.SACKBlock{147, 148}) - s.Insert(header.SACKBlock{149, 150}) - s.Insert(header.SACKBlock{153, 154}) - s.Insert(header.SACKBlock{155, 156}) - testCases := []struct { - seq seqnum.Value - lost bool - }{ - // Sequence number not covered by SACK block and has more than - // nDupAckThreshold discontiguous SACK blocks after it as well - // as (nDupAckThreshold -1) * 10 (smss) bytes that have been - // SACKED above the sequence number. - {seq: 0, lost: true}, - - // These sequence numbers have all been SACKed and should not be - // considered lost. - {seq: 1, lost: false}, - {seq: 25, lost: false}, - {seq: 45, lost: false}, - - // Same as first case above. - {seq: 50, lost: true}, - - // This block has been SACKed and should not be considered lost. - {seq: 119, lost: false}, - - // This one should return true because there are > - // (nDupAckThreshold - 1) * 10 (smss) bytes that have been - // sacked above this sequence number. - {seq: 120, lost: true}, - - // This sequence number has been SACKed and should not be - // considered lost. - {seq: 125, lost: false}, - - // This sequence number has not been SACKed and there are - // nDupAckThreshold number of SACKed blocks after it. - {seq: 141, lost: true}, - - // This sequence number has not been SACKed and there are less - // than nDupAckThreshold SACKed sequences after it. - {seq: 151, lost: false}, - } - for _, tc := range testCases { - if want, got := tc.lost, s.IsLost(tc.seq); got != want { - t.Errorf("s.IsLost(%v) = %v, want %v", tc.seq, got, want) - } - } -} - -func TestSACKScoreboardDelete(t *testing.T) { - blocks := []header.SACKBlock{{4294254144, 225652}, {5340409, 5350509}} - s := initScoreboard(blocks, 4294254143) - s.Delete(5340408) - if s.Empty() { - t.Fatalf("s.Empty() = true, want false") - } - if got, want := s.Sacked(), blocks[1].Start.Size(blocks[1].End); got != want { - t.Fatalf("incorrect sacked bytes in scoreboard got: %v, want: %v", got, want) - } - s.Delete(5340410) - if s.Empty() { - t.Fatal("s.Empty() = true, want false") - } - newSB := header.SACKBlock{5340410, 5350509} - if !s.IsSACKED(newSB) { - t.Fatalf("s.IsSACKED(%v) = false, want true, scoreboard: %v", newSB, s) - } - s.Delete(5350509) - lastOctet := header.SACKBlock{5350508, 5350509} - if s.IsSACKED(lastOctet) { - t.Fatalf("s.IsSACKED(%v) = false, want true", lastOctet) - } - - s.Delete(5350510) - if !s.Empty() { - t.Fatal("s.Empty() = false, want true") - } - if got, want := s.Sacked(), seqnum.Size(0); got != want { - t.Fatalf("incorrect sacked bytes in scoreboard got: %v, want: %v", got, want) - } -} diff --git a/pkg/tcpip/transport/tcp/segment_test.go b/pkg/tcpip/transport/tcp/segment_test.go deleted file mode 100644 index 2e6ea06f5..000000000 --- a/pkg/tcpip/transport/tcp/segment_test.go +++ /dev/null @@ -1,69 +0,0 @@ -// Copyright 2021 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tcp - -import ( - "testing" - - "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/faketime" - "gvisor.dev/gvisor/pkg/tcpip/stack" -) - -type segmentSizeWants struct { - DataSize int - SegMemSize int -} - -func checkSegmentSize(t *testing.T, name string, seg *segment, want segmentSizeWants) { - t.Helper() - got := segmentSizeWants{ - DataSize: seg.data.Size(), - SegMemSize: seg.segMemSize(), - } - if diff := cmp.Diff(got, want); diff != "" { - t.Errorf("%s differs (-want +got):\n%s", name, diff) - } -} - -func TestSegmentMerge(t *testing.T) { - var clock faketime.NullClock - id := stack.TransportEndpointID{} - seg1 := newOutgoingSegment(id, &clock, buffer.NewView(10)) - defer seg1.decRef() - seg2 := newOutgoingSegment(id, &clock, buffer.NewView(20)) - defer seg2.decRef() - - checkSegmentSize(t, "seg1", seg1, segmentSizeWants{ - DataSize: 10, - SegMemSize: SegSize + 10, - }) - checkSegmentSize(t, "seg2", seg2, segmentSizeWants{ - DataSize: 20, - SegMemSize: SegSize + 20, - }) - - seg1.merge(seg2) - - checkSegmentSize(t, "seg1", seg1, segmentSizeWants{ - DataSize: 30, - SegMemSize: SegSize + 30, - }) - checkSegmentSize(t, "seg2", seg2, segmentSizeWants{ - DataSize: 0, - SegMemSize: SegSize, - }) -} diff --git a/pkg/tcpip/transport/tcp/tcp_endpoint_list.go b/pkg/tcpip/transport/tcp/tcp_endpoint_list.go new file mode 100644 index 000000000..a7dc5df81 --- /dev/null +++ b/pkg/tcpip/transport/tcp/tcp_endpoint_list.go @@ -0,0 +1,221 @@ +package tcp + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type endpointElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (endpointElementMapper) linkerFor(elem *endpoint) *endpoint { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type endpointList struct { + head *endpoint + tail *endpoint +} + +// Reset resets list l to the empty state. +func (l *endpointList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +// +//go:nosplit +func (l *endpointList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +// +//go:nosplit +func (l *endpointList) Front() *endpoint { + return l.head +} + +// Back returns the last element of list l or nil. +// +//go:nosplit +func (l *endpointList) Back() *endpoint { + return l.tail +} + +// Len returns the number of elements in the list. +// +// NOTE: This is an O(n) operation. +// +//go:nosplit +func (l *endpointList) Len() (count int) { + for e := l.Front(); e != nil; e = (endpointElementMapper{}.linkerFor(e)).Next() { + count++ + } + return count +} + +// PushFront inserts the element e at the front of list l. +// +//go:nosplit +func (l *endpointList) PushFront(e *endpoint) { + linker := endpointElementMapper{}.linkerFor(e) + linker.SetNext(l.head) + linker.SetPrev(nil) + if l.head != nil { + endpointElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +// +//go:nosplit +func (l *endpointList) PushBack(e *endpoint) { + linker := endpointElementMapper{}.linkerFor(e) + linker.SetNext(nil) + linker.SetPrev(l.tail) + if l.tail != nil { + endpointElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +// +//go:nosplit +func (l *endpointList) PushBackList(m *endpointList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + endpointElementMapper{}.linkerFor(l.tail).SetNext(m.head) + endpointElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +// +//go:nosplit +func (l *endpointList) InsertAfter(b, e *endpoint) { + bLinker := endpointElementMapper{}.linkerFor(b) + eLinker := endpointElementMapper{}.linkerFor(e) + + a := bLinker.Next() + + eLinker.SetNext(a) + eLinker.SetPrev(b) + bLinker.SetNext(e) + + if a != nil { + endpointElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +// +//go:nosplit +func (l *endpointList) InsertBefore(a, e *endpoint) { + aLinker := endpointElementMapper{}.linkerFor(a) + eLinker := endpointElementMapper{}.linkerFor(e) + + b := aLinker.Prev() + eLinker.SetNext(a) + eLinker.SetPrev(b) + aLinker.SetPrev(e) + + if b != nil { + endpointElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +// +//go:nosplit +func (l *endpointList) Remove(e *endpoint) { + linker := endpointElementMapper{}.linkerFor(e) + prev := linker.Prev() + next := linker.Next() + + if prev != nil { + endpointElementMapper{}.linkerFor(prev).SetNext(next) + } else if l.head == e { + l.head = next + } + + if next != nil { + endpointElementMapper{}.linkerFor(next).SetPrev(prev) + } else if l.tail == e { + l.tail = prev + } + + linker.SetNext(nil) + linker.SetPrev(nil) +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type endpointEntry struct { + next *endpoint + prev *endpoint +} + +// Next returns the entry that follows e in the list. +// +//go:nosplit +func (e *endpointEntry) Next() *endpoint { + return e.next +} + +// Prev returns the entry that precedes e in the list. +// +//go:nosplit +func (e *endpointEntry) Prev() *endpoint { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +// +//go:nosplit +func (e *endpointEntry) SetNext(elem *endpoint) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +// +//go:nosplit +func (e *endpointEntry) SetPrev(elem *endpoint) { + e.prev = elem +} diff --git a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go b/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go deleted file mode 100644 index 84fb1c416..000000000 --- a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go +++ /dev/null @@ -1,559 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// -// These tests are flaky when run under the go race detector due to some -// iterations taking long enough that the retransmit timer can kick in causing -// the congestion window measurements to fail due to extra packets etc. -// -//go:build !race -// +build !race - -package tcp_test - -import ( - "bytes" - "fmt" - "math" - "testing" - "time" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context" - "gvisor.dev/gvisor/pkg/test/testutil" -) - -func TestFastRecovery(t *testing.T) { - maxPayload := 32 - c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) - defer c.Cleanup() - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - - const iterations = 3 - data := make([]byte, 2*maxPayload*(tcp.InitialCwnd<<(iterations+1))) - for i := range data { - data[i] = byte(i) - } - - // Write all the data in one shot. Packets will only be written at the - // MTU size though. - var r bytes.Reader - r.Reset(data) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - // Do slow start for a few iterations. - expected := tcp.InitialCwnd - bytesRead := 0 - for i := 0; i < iterations; i++ { - expected = tcp.InitialCwnd << uint(i) - if i > 0 { - // Acknowledge all the data received so far if not on - // first iteration. - c.SendAck(790, bytesRead) - } - - // Read all packets expected on this iteration. Don't - // acknowledge any of them just yet, so that we can measure the - // congestion window. - for j := 0; j < expected; j++ { - c.ReceiveAndCheckPacket(data, bytesRead, maxPayload) - bytesRead += maxPayload - } - - // Check we don't receive any more packets on this iteration. - // The timeout can't be too high or we'll trigger a timeout. - c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond) - } - - // Send 3 duplicate acks. This should force an immediate retransmit of - // the pending packet and put the sender into fast recovery. - rtxOffset := bytesRead - maxPayload*expected - for i := 0; i < 3; i++ { - c.SendAck(790, rtxOffset) - } - - // Receive the retransmitted packet. - c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload) - - // Wait before checking metrics. - metricPollFn := func() error { - if got, want := c.Stack().Stats().TCP.FastRetransmit.Value(), uint64(1); got != want { - return fmt.Errorf("got stats.TCP.FastRetransmit.Value = %d, want = %d", got, want) - } - if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(1); got != want { - return fmt.Errorf("got stats.TCP.Retransmit.Value = %d, want = %d", got, want) - } - - if got, want := c.Stack().Stats().TCP.FastRecovery.Value(), uint64(1); got != want { - return fmt.Errorf("got stats.TCP.FastRecovery.Value = %d, want = %d", got, want) - } - return nil - } - - if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil { - t.Error(err) - } - - // Now send 7 mode duplicate acks. Each of these should cause a window - // inflation by 1 and cause the sender to send an extra packet. - for i := 0; i < 7; i++ { - c.SendAck(790, rtxOffset) - } - - recover := bytesRead - - // Ensure no new packets arrive. - c.CheckNoPacketTimeout("More packets received than expected during recovery after dupacks for this cwnd.", - 50*time.Millisecond) - - // Acknowledge half of the pending data. - rtxOffset = bytesRead - expected*maxPayload/2 - c.SendAck(790, rtxOffset) - - // Receive the retransmit due to partial ack. - c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload) - - // Wait before checking metrics. - metricPollFn = func() error { - if got, want := c.Stack().Stats().TCP.FastRetransmit.Value(), uint64(2); got != want { - return fmt.Errorf("got stats.TCP.FastRetransmit.Value = %d, want = %d", got, want) - } - if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(2); got != want { - return fmt.Errorf("got stats.TCP.Retransmit.Value = %d, want = %d", got, want) - } - return nil - } - if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil { - t.Error(err) - } - - // Receive the 10 extra packets that should have been released due to - // the congestion window inflation in recovery. - for i := 0; i < 10; i++ { - c.ReceiveAndCheckPacket(data, bytesRead, maxPayload) - bytesRead += maxPayload - } - - // A partial ACK during recovery should reduce congestion window by the - // number acked. Since we had "expected" packets outstanding before sending - // partial ack and we acked expected/2 , the cwnd and outstanding should - // be expected/2 + 10 (7 dupAcks + 3 for the original 3 dupacks that triggered - // fast recovery). Which means the sender should not send any more packets - // till we ack this one. - c.CheckNoPacketTimeout("More packets received than expected during recovery after partial ack for this cwnd.", - 50*time.Millisecond) - - // Acknowledge all pending data to recover point. - c.SendAck(790, recover) - - // At this point, the cwnd should reset to expected/2 and there are 10 - // packets outstanding. - // - // NOTE: Technically netstack is incorrect in that we adjust the cwnd on - // the same segment that takes us out of recovery. But because of that - // the actual cwnd at exit of recovery will be expected/2 + 1 as we - // acked a cwnd worth of packets which will increase the cwnd further by - // 1 in congestion avoidance. - // - // Now in the first iteration since there are 10 packets outstanding. - // We would expect to get expected/2 +1 - 10 packets. But subsequent - // iterations will send us expected/2 + 1 + 1 (per iteration). - expected = expected/2 + 1 - 10 - for i := 0; i < iterations; i++ { - // Read all packets expected on this iteration. Don't - // acknowledge any of them just yet, so that we can measure the - // congestion window. - for j := 0; j < expected; j++ { - c.ReceiveAndCheckPacket(data, bytesRead, maxPayload) - bytesRead += maxPayload - } - - // Check we don't receive any more packets on this iteration. - // The timeout can't be too high or we'll trigger a timeout. - c.CheckNoPacketTimeout(fmt.Sprintf("More packets received(after deflation) than expected %d for this cwnd.", expected), 50*time.Millisecond) - - // Acknowledge all the data received so far. - c.SendAck(790, bytesRead) - - // In cogestion avoidance, the packets trains increase by 1 in - // each iteration. - if i == 0 { - // After the first iteration we expect to get the full - // congestion window worth of packets in every - // iteration. - expected += 10 - } - expected++ - } -} - -func TestExponentialIncreaseDuringSlowStart(t *testing.T) { - maxPayload := 32 - c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) - defer c.Cleanup() - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - - const iterations = 3 - data := make([]byte, maxPayload*(tcp.InitialCwnd<<(iterations+1))) - for i := range data { - data[i] = byte(i) - } - - // Write all the data in one shot. Packets will only be written at the - // MTU size though. - var r bytes.Reader - r.Reset(data) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - expected := tcp.InitialCwnd - bytesRead := 0 - for i := 0; i < iterations; i++ { - // Read all packets expected on this iteration. Don't - // acknowledge any of them just yet, so that we can measure the - // congestion window. - for j := 0; j < expected; j++ { - c.ReceiveAndCheckPacket(data, bytesRead, maxPayload) - bytesRead += maxPayload - } - - // Check we don't receive any more packets on this iteration. - // The timeout can't be too high or we'll trigger a timeout. - c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond) - - // Acknowledge all the data received so far. - c.SendAck(790, bytesRead) - - // Double the number of expected packets for the next iteration. - expected *= 2 - } -} - -func TestCongestionAvoidance(t *testing.T) { - maxPayload := 32 - c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) - defer c.Cleanup() - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - - const iterations = 3 - data := make([]byte, 2*maxPayload*(tcp.InitialCwnd<<(iterations+1))) - for i := range data { - data[i] = byte(i) - } - - // Write all the data in one shot. Packets will only be written at the - // MTU size though. - var r bytes.Reader - r.Reset(data) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - // Do slow start for a few iterations. - expected := tcp.InitialCwnd - bytesRead := 0 - for i := 0; i < iterations; i++ { - expected = tcp.InitialCwnd << uint(i) - if i > 0 { - // Acknowledge all the data received so far if not on - // first iteration. - c.SendAck(790, bytesRead) - } - - // Read all packets expected on this iteration. Don't - // acknowledge any of them just yet, so that we can measure the - // congestion window. - for j := 0; j < expected; j++ { - c.ReceiveAndCheckPacket(data, bytesRead, maxPayload) - bytesRead += maxPayload - } - - // Check we don't receive any more packets on this iteration. - // The timeout can't be too high or we'll trigger a timeout. - c.CheckNoPacketTimeout("More packets received than expected for this cwnd (slow start phase).", 50*time.Millisecond) - } - - // Don't acknowledge the first packet of the last packet train. Let's - // wait for them to time out, which will trigger a restart of slow - // start, and initialization of ssthresh to cwnd/2. - rtxOffset := bytesRead - maxPayload*expected - c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload) - - // Acknowledge all the data received so far. - c.SendAck(790, bytesRead) - - // This part is tricky: when the timeout happened, we had "expected" - // packets pending, cwnd reset to 1, and ssthresh set to expected/2. - // By acknowledging "expected" packets, the slow-start part will - // increase cwnd to expected/2 (which "consumes" expected/2-1 of the - // acknowledgements), then the congestion avoidance part will consume - // an extra expected/2 acks to take cwnd to expected/2 + 1. One ack - // remains in the "ack count" (which will cause cwnd to be incremented - // once it reaches cwnd acks). - // - // So we're straight into congestion avoidance with cwnd set to - // expected/2 + 1. - // - // Check that packets trains of cwnd packets are sent, and that cwnd is - // incremented by 1 after we acknowledge each packet. - expected = expected/2 + 1 - for i := 0; i < iterations; i++ { - // Read all packets expected on this iteration. Don't - // acknowledge any of them just yet, so that we can measure the - // congestion window. - for j := 0; j < expected; j++ { - c.ReceiveAndCheckPacket(data, bytesRead, maxPayload) - bytesRead += maxPayload - } - - // Check we don't receive any more packets on this iteration. - // The timeout can't be too high or we'll trigger a timeout. - c.CheckNoPacketTimeout("More packets received than expected for this cwnd (congestion avoidance phase).", 50*time.Millisecond) - - // Acknowledge all the data received so far. - c.SendAck(790, bytesRead) - - // In cogestion avoidance, the packets trains increase by 1 in - // each iteration. - expected++ - } -} - -// cubicCwnd returns an estimate of a cubic window given the -// originalCwnd, wMax, last congestion event time and sRTT. -func cubicCwnd(origCwnd int, wMax int, congEventTime time.Time, sRTT time.Duration) int { - cwnd := float64(origCwnd) - // We wait 50ms between each iteration so sRTT as computed by cubic - // should be close to 50ms. - elapsed := (time.Since(congEventTime) + sRTT).Seconds() - k := math.Cbrt(float64(wMax) * 0.3 / 0.7) - wtRTT := 0.4*math.Pow(elapsed-k, 3) + float64(wMax) - cwnd += (wtRTT - cwnd) / cwnd - return int(cwnd) -} - -func TestCubicCongestionAvoidance(t *testing.T) { - maxPayload := 32 - c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) - defer c.Cleanup() - - enableCUBIC(t, c) - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - - const iterations = 3 - data := make([]byte, 2*maxPayload*(tcp.InitialCwnd<<(iterations+1))) - for i := range data { - data[i] = byte(i) - } - - // Write all the data in one shot. Packets will only be written at the - // MTU size though. - var r bytes.Reader - r.Reset(data) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - // Do slow start for a few iterations. - expected := tcp.InitialCwnd - bytesRead := 0 - for i := 0; i < iterations; i++ { - expected = tcp.InitialCwnd << uint(i) - if i > 0 { - // Acknowledge all the data received so far if not on - // first iteration. - c.SendAck(790, bytesRead) - } - - // Read all packets expected on this iteration. Don't - // acknowledge any of them just yet, so that we can measure the - // congestion window. - for j := 0; j < expected; j++ { - c.ReceiveAndCheckPacket(data, bytesRead, maxPayload) - bytesRead += maxPayload - } - - // Check we don't receive any more packets on this iteration. - // The timeout can't be too high or we'll trigger a timeout. - c.CheckNoPacketTimeout("More packets received than expected for this cwnd (during slow-start phase).", 50*time.Millisecond) - } - - // Don't acknowledge the first packet of the last packet train. Let's - // wait for them to time out, which will trigger a restart of slow - // start, and initialization of ssthresh to cwnd * 0.7. - rtxOffset := bytesRead - maxPayload*expected - c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload) - - // Acknowledge all pending data. - c.SendAck(790, bytesRead) - - // Store away the time we sent the ACK and assuming a 200ms RTO - // we estimate that the sender will have an RTO 200ms from now - // and go back into slow start. - packetDropTime := time.Now().Add(200 * time.Millisecond) - - // This part is tricky: when the timeout happened, we had "expected" - // packets pending, cwnd reset to 1, and ssthresh set to expected * 0.7. - // By acknowledging "expected" packets, the slow-start part will - // increase cwnd to expected/2 essentially putting the connection - // straight into congestion avoidance. - wMax := expected - // Lower expected as per cubic spec after a congestion event. - expected = int(float64(expected) * 0.7) - cwnd := expected - for i := 0; i < iterations; i++ { - // Cubic grows window independent of ACKs. Cubic Window growth - // is a function of time elapsed since last congestion event. - // As a result the congestion window does not grow - // deterministically in response to ACKs. - // - // We need to roughly estimate what the cwnd of the sender is - // based on when we sent the dupacks. - cwnd := cubicCwnd(cwnd, wMax, packetDropTime, 50*time.Millisecond) - - packetsExpected := cwnd - for j := 0; j < packetsExpected; j++ { - c.ReceiveAndCheckPacket(data, bytesRead, maxPayload) - bytesRead += maxPayload - } - t.Logf("expected packets received, next trying to receive any extra packets that may come") - - // If our estimate was correct there should be no more pending packets. - // We attempt to read a packet a few times with a short sleep in between - // to ensure that we don't see the sender send any unexpected packets. - unexpectedPackets := 0 - for { - gotPacket := c.ReceiveNonBlockingAndCheckPacket(data, bytesRead, maxPayload) - if !gotPacket { - break - } - bytesRead += maxPayload - unexpectedPackets++ - time.Sleep(1 * time.Millisecond) - } - if unexpectedPackets != 0 { - t.Fatalf("received %d unexpected packets for iteration %d", unexpectedPackets, i) - } - // Check we don't receive any more packets on this iteration. - // The timeout can't be too high or we'll trigger a timeout. - c.CheckNoPacketTimeout("More packets received than expected for this cwnd(congestion avoidance)", 5*time.Millisecond) - - // Acknowledge all the data received so far. - c.SendAck(790, bytesRead) - } -} - -func TestRetransmit(t *testing.T) { - maxPayload := 32 - c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) - defer c.Cleanup() - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - - const iterations = 3 - data := make([]byte, maxPayload*(tcp.InitialCwnd<<(iterations+1))) - for i := range data { - data[i] = byte(i) - } - - // Write all the data in two shots. Packets will only be written at the - // MTU size though. - var r bytes.Reader - r.Reset(data[:len(data)/2]) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - r.Reset(data[len(data)/2:]) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - // Do slow start for a few iterations. - expected := tcp.InitialCwnd - bytesRead := 0 - for i := 0; i < iterations; i++ { - expected = tcp.InitialCwnd << uint(i) - if i > 0 { - // Acknowledge all the data received so far if not on - // first iteration. - c.SendAck(790, bytesRead) - } - - // Read all packets expected on this iteration. Don't - // acknowledge any of them just yet, so that we can measure the - // congestion window. - for j := 0; j < expected; j++ { - c.ReceiveAndCheckPacket(data, bytesRead, maxPayload) - bytesRead += maxPayload - } - - // Check we don't receive any more packets on this iteration. - // The timeout can't be too high or we'll trigger a timeout. - c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond) - } - - // Wait for a timeout and retransmit. - rtxOffset := bytesRead - maxPayload*expected - c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload) - - metricPollFn := func() error { - if got, want := c.Stack().Stats().TCP.Timeouts.Value(), uint64(1); got != want { - return fmt.Errorf("got stats.TCP.Timeouts.Value = %d, want = %d", got, want) - } - - if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(1); got != want { - return fmt.Errorf("got stats.TCP.Retransmits.Value = %d, want = %d", got, want) - } - - if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Timeouts.Value(), uint64(1); got != want { - return fmt.Errorf("got EP SendErrors.Timeouts.Value = %d, want = %d", got, want) - } - - if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Retransmits.Value(), uint64(1); got != want { - return fmt.Errorf("got EP stats SendErrors.Retransmits.Value = %d, want = %d", got, want) - } - - if got, want := c.Stack().Stats().TCP.SlowStartRetransmits.Value(), uint64(1); got != want { - return fmt.Errorf("got stats.TCP.SlowStartRetransmits.Value = %d, want = %d", got, want) - } - - return nil - } - - // Poll when checking metrics. - if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil { - t.Error(err) - } - - // Acknowledge half of the pending data. - rtxOffset = bytesRead - expected*maxPayload/2 - c.SendAck(790, rtxOffset) - - // Receive the remaining data, making sure that acknowledged data is not - // retransmitted. - for offset := rtxOffset; offset < len(data); offset += maxPayload { - c.ReceiveAndCheckPacket(data, offset, maxPayload) - c.SendAck(790, offset+maxPayload) - } - - c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond) -} diff --git a/pkg/tcpip/transport/tcp/tcp_rack_test.go b/pkg/tcpip/transport/tcp/tcp_rack_test.go deleted file mode 100644 index 89e9fb886..000000000 --- a/pkg/tcpip/transport/tcp/tcp_rack_test.go +++ /dev/null @@ -1,1102 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tcp_test - -import ( - "bytes" - "fmt" - "testing" - "time" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/seqnum" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context" - "gvisor.dev/gvisor/pkg/test/testutil" -) - -const ( - maxPayload = 10 - tsOptionSize = 12 - maxTCPOptionSize = 40 - mtu = header.TCPMinimumSize + header.IPv4MinimumSize + maxTCPOptionSize + maxPayload - latency = 5 * time.Millisecond -) - -func setStackTCPRecovery(t *testing.T, c *context.Context, recovery int) { - t.Helper() - opt := tcpip.TCPRecovery(recovery) - if err := c.Stack().SetTransportProtocolOption(header.TCPProtocolNumber, &opt); err != nil { - t.Fatalf("c.s.SetTransportProtocolOption(%d, &%v(%v)): %s", header.TCPProtocolNumber, opt, opt, err) - } -} - -// TestRACKUpdate tests the RACK related fields are updated when an ACK is -// received on a SACK enabled connection. -func TestRACKUpdate(t *testing.T) { - c := context.New(t, uint32(mtu)) - defer c.Cleanup() - - var xmitTime tcpip.MonotonicTime - probeDone := make(chan struct{}) - c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) { - // Validate that the endpoint Sender.RACKState is what we expect. - if state.Sender.RACKState.XmitTime.Before(xmitTime) { - t.Fatalf("RACK transmit time failed to update when an ACK is received") - } - - gotSeq := state.Sender.RACKState.EndSequence - wantSeq := state.Sender.SndNxt - if !gotSeq.LessThanEq(wantSeq) || gotSeq.LessThan(wantSeq) { - t.Fatalf("RACK sequence number failed to update, got: %v, but want: %v", gotSeq, wantSeq) - } - - if state.Sender.RACKState.RTT == 0 { - t.Fatalf("RACK RTT failed to update when an ACK is received, got RACKState.RTT == 0 want != 0") - } - close(probeDone) - }) - setStackSACKPermitted(t, c, true) - createConnectedWithSACKAndTS(c) - - data := make([]byte, maxPayload) - for i := range data { - data[i] = byte(i) - } - - // Write the data. - xmitTime = c.Stack().Clock().NowMonotonic() - var r bytes.Reader - r.Reset(data) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - bytesRead := 0 - c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) - bytesRead += maxPayload - c.SendAck(seqnum.Value(context.TestInitialSequenceNumber).Add(1), bytesRead) - - // Wait for the probe function to finish processing the ACK before the - // test completes. - <-probeDone -} - -// TestRACKDetectReorder tests that RACK detects packet reordering. -func TestRACKDetectReorder(t *testing.T) { - c := context.New(t, uint32(mtu)) - defer c.Cleanup() - - t.Skipf("Skipping this test as reorder detection does not consider DSACK.") - - var n int - const ackNumToVerify = 2 - probeDone := make(chan struct{}) - c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) { - gotSeq := state.Sender.RACKState.FACK - wantSeq := state.Sender.SndNxt - // FACK should be updated to the highest ending sequence number of the - // segment acknowledged most recently. - if !gotSeq.LessThanEq(wantSeq) || gotSeq.LessThan(wantSeq) { - t.Fatalf("RACK FACK failed to update, got: %v, but want: %v", gotSeq, wantSeq) - } - - n++ - if n < ackNumToVerify { - if state.Sender.RACKState.Reord { - t.Fatalf("RACK reorder detected when there is no reordering") - } - return - } - - if state.Sender.RACKState.Reord == false { - t.Fatalf("RACK reorder detection failed") - } - close(probeDone) - }) - setStackSACKPermitted(t, c, true) - createConnectedWithSACKAndTS(c) - data := make([]byte, ackNumToVerify*maxPayload) - for i := range data { - data[i] = byte(i) - } - - // Write the data. - var r bytes.Reader - r.Reset(data) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - bytesRead := 0 - for i := 0; i < ackNumToVerify; i++ { - c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) - bytesRead += maxPayload - } - - start := c.IRS.Add(maxPayload + 1) - end := start.Add(maxPayload) - seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - c.SendAckWithSACK(seq, 0, []header.SACKBlock{{start, end}}) - c.SendAck(seq, bytesRead) - - // Wait for the probe function to finish processing the ACK before the - // test completes. - <-probeDone -} - -func sendAndReceiveWithSACK(t *testing.T, c *context.Context, numPackets int, enableRACK bool) []byte { - setStackSACKPermitted(t, c, true) - if !enableRACK { - setStackTCPRecovery(t, c, 0) - } - createConnectedWithSACKAndTS(c) - - data := make([]byte, numPackets*maxPayload) - for i := range data { - data[i] = byte(i) - } - - // Write the data. - var r bytes.Reader - r.Reset(data) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - bytesRead := 0 - for i := 0; i < numPackets; i++ { - c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) - bytesRead += maxPayload - // This delay is added to increase RTT as low RTT can cause TLP - // before sending ACK. - time.Sleep(latency) - } - - return data -} - -const ( - validDSACKDetected = 1 - failedToDetectDSACK = 2 - invalidDSACKDetected = 3 -) - -func addDSACKSeenCheckerProbe(t *testing.T, c *context.Context, numACK int, probeDone chan int) { - var n int - c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) { - // Validate that RACK detects DSACK. - n++ - if n < numACK { - if state.Sender.RACKState.DSACKSeen { - probeDone <- invalidDSACKDetected - } - return - } - - if !state.Sender.RACKState.DSACKSeen { - probeDone <- failedToDetectDSACK - return - } - probeDone <- validDSACKDetected - }) -} - -// TestRACKTLPRecovery tests that RACK sends a tail loss probe (TLP) in the -// case of a tail loss. This simulates a situation where the TLP is able to -// insinuate the SACK holes and sender is able to retransmit the rest. -func TestRACKTLPRecovery(t *testing.T) { - c := context.New(t, uint32(mtu)) - defer c.Cleanup() - - // Send 8 packets. - numPackets := 8 - data := sendAndReceiveWithSACK(t, c, numPackets, true /* enableRACK */) - - // Packets [6-8] are lost. Send cumulative ACK for [1-5]. - seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - bytesRead := 5 * maxPayload - c.SendAck(seq, bytesRead) - - // PTO should fire and send #8 packet as a TLP. - c.ReceiveAndCheckPacketWithOptions(data, 7*maxPayload, maxPayload, tsOptionSize) - var info tcpip.TCPInfoOption - if err := c.EP.GetSockOpt(&info); err != nil { - t.Fatalf("GetSockOpt failed: %v", err) - } - - // Send the SACK after RTT because RACK RFC states that if the ACK for a - // retransmission arrives before the smoothed RTT then the sender should not - // update RACK state as it could be a spurious inference. - time.Sleep(info.RTT) - - // Okay, let the sender know we got #8 using a SACK block. - eighthPStart := c.IRS.Add(1 + seqnum.Size(7*maxPayload)) - eighthPEnd := eighthPStart.Add(maxPayload) - c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{eighthPStart, eighthPEnd}}) - - // The sender should be entering RACK based loss-recovery and sending #6 and - // #7 one after another. - c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) - bytesRead += maxPayload - c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) - bytesRead += 2 * maxPayload - c.SendAck(seq, bytesRead) - - metricPollFn := func() error { - tcpStats := c.Stack().Stats().TCP - stats := []struct { - stat *tcpip.StatCounter - name string - want uint64 - }{ - // One fast retransmit after the SACK. - {tcpStats.FastRetransmit, "stats.TCP.FastRetransmit", 1}, - // Recovery should be SACK recovery. - {tcpStats.SACKRecovery, "stats.TCP.SACKRecovery", 1}, - // Packets 6, 7 and 8 were retransmitted. - {tcpStats.Retransmits, "stats.TCP.Retransmits", 3}, - // TLP recovery should have been detected. - {tcpStats.TLPRecovery, "stats.TCP.TLPRecovery", 1}, - // No RTOs should have occurred. - {tcpStats.Timeouts, "stats.TCP.Timeouts", 0}, - } - for _, s := range stats { - if got, want := s.stat.Value(), s.want; got != want { - return fmt.Errorf("got %s.Value() = %d, want = %d", s.name, got, want) - } - } - return nil - } - if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil { - t.Error(err) - } -} - -// TestRACKTLPFallbackRTO tests that RACK sends a tail loss probe (TLP) in the -// case of a tail loss. This simulates a situation where either the TLP or its -// ACK is lost. The sender should retransmit when RTO fires. -func TestRACKTLPFallbackRTO(t *testing.T) { - c := context.New(t, uint32(mtu)) - defer c.Cleanup() - - // Send 8 packets. - numPackets := 8 - data := sendAndReceiveWithSACK(t, c, numPackets, true /* enableRACK */) - - // Packets [6-8] are lost. Send cumulative ACK for [1-5]. - seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - bytesRead := 5 * maxPayload - c.SendAck(seq, bytesRead) - - // PTO should fire and send #8 packet as a TLP. - c.ReceiveAndCheckPacketWithOptions(data, 7*maxPayload, maxPayload, tsOptionSize) - - // Either the TLP or the ACK the receiver sent with SACK blocks was lost. - - // Confirm that RTO fires and retransmits packet #6. - c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) - - metricPollFn := func() error { - tcpStats := c.Stack().Stats().TCP - stats := []struct { - stat *tcpip.StatCounter - name string - want uint64 - }{ - // No fast retransmits happened. - {tcpStats.FastRetransmit, "stats.TCP.FastRetransmit", 0}, - // No SACK recovery happened. - {tcpStats.SACKRecovery, "stats.TCP.SACKRecovery", 0}, - // TLP was unsuccessful. - {tcpStats.TLPRecovery, "stats.TCP.TLPRecovery", 0}, - // RTO should have fired. - {tcpStats.Timeouts, "stats.TCP.Timeouts", 1}, - } - for _, s := range stats { - if got, want := s.stat.Value(), s.want; got != want { - return fmt.Errorf("got %s.Value() = %d, want = %d", s.name, got, want) - } - } - return nil - } - if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil { - t.Error(err) - } -} - -// TestNoTLPRecoveryOnDSACK tests the scenario where the sender speculates a -// tail loss and sends a TLP. Everything is received and acked. The probe -// segment is DSACKed. No fast recovery should be triggered in this case. -func TestNoTLPRecoveryOnDSACK(t *testing.T) { - c := context.New(t, uint32(mtu)) - defer c.Cleanup() - - // Send 8 packets. - numPackets := 8 - data := sendAndReceiveWithSACK(t, c, numPackets, true /* enableRACK */) - - // Packets [1-5] are received first. [6-8] took a detour and will take a - // while to arrive. Ack [1-5]. - seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - bytesRead := 5 * maxPayload - c.SendAck(seq, bytesRead) - - // The tail loss probe (#8 packet) is received. - c.ReceiveAndCheckPacketWithOptions(data, 7*maxPayload, maxPayload, tsOptionSize) - - // Now that all 8 packets are received + duplicate 8th packet, send ack. - bytesRead += 3 * maxPayload - eighthPStart := c.IRS.Add(1 + seqnum.Size(7*maxPayload)) - eighthPEnd := eighthPStart.Add(maxPayload) - c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{eighthPStart, eighthPEnd}}) - - // Wait for RTO and make sure that nothing else is received. - var info tcpip.TCPInfoOption - if err := c.EP.GetSockOpt(&info); err != nil { - t.Fatalf("GetSockOpt failed: %v", err) - } - if p := c.GetPacketWithTimeout(info.RTO); p != nil { - t.Errorf("received an unexpected packet: %v", p) - } - - metricPollFn := func() error { - tcpStats := c.Stack().Stats().TCP - stats := []struct { - stat *tcpip.StatCounter - name string - want uint64 - }{ - // Make sure no recovery was entered. - {tcpStats.FastRetransmit, "stats.TCP.FastRetransmit", 0}, - {tcpStats.SACKRecovery, "stats.TCP.SACKRecovery", 0}, - {tcpStats.TLPRecovery, "stats.TCP.TLPRecovery", 0}, - // RTO should not have fired. - {tcpStats.Timeouts, "stats.TCP.Timeouts", 0}, - // Only #8 was retransmitted. - {tcpStats.Retransmits, "stats.TCP.Retransmits", 1}, - } - for _, s := range stats { - if got, want := s.stat.Value(), s.want; got != want { - return fmt.Errorf("got %s.Value() = %d, want = %d", s.name, got, want) - } - } - return nil - } - if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil { - t.Error(err) - } -} - -// TestNoTLPOnSACK tests the scenario where there is not exactly a tail loss -// due to the presence of multiple SACK holes. In such a scenario, TLP should -// not be sent. -func TestNoTLPOnSACK(t *testing.T) { - c := context.New(t, uint32(mtu)) - defer c.Cleanup() - - // Send 8 packets. - numPackets := 8 - data := sendAndReceiveWithSACK(t, c, numPackets, true /* enableRACK */) - - // Packets [1-5] and #7 were received. #6 and #8 were dropped. - seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - bytesRead := 5 * maxPayload - seventhStart := c.IRS.Add(1 + seqnum.Size(6*maxPayload)) - seventhEnd := seventhStart.Add(maxPayload) - c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{seventhStart, seventhEnd}}) - - // The sender should retransmit #6. If the sender sends a TLP, then #8 will - // received and fail this test. - c.ReceiveAndCheckPacketWithOptions(data, 5*maxPayload, maxPayload, tsOptionSize) - - metricPollFn := func() error { - tcpStats := c.Stack().Stats().TCP - stats := []struct { - stat *tcpip.StatCounter - name string - want uint64 - }{ - // #6 was retransmitted due to SACK recovery. - {tcpStats.FastRetransmit, "stats.TCP.FastRetransmit", 1}, - {tcpStats.SACKRecovery, "stats.TCP.SACKRecovery", 1}, - {tcpStats.TLPRecovery, "stats.TCP.TLPRecovery", 0}, - // RTO should not have fired. - {tcpStats.Timeouts, "stats.TCP.Timeouts", 0}, - // Only #6 was retransmitted. - {tcpStats.Retransmits, "stats.TCP.Retransmits", 1}, - } - for _, s := range stats { - if got, want := s.stat.Value(), s.want; got != want { - return fmt.Errorf("got %s.Value() = %d, want = %d", s.name, got, want) - } - } - return nil - } - if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil { - t.Error(err) - } -} - -// TestRACKOnePacketTailLoss tests the trivial case of a tail loss of only one -// packet. The probe should itself repairs the loss instead of having to go -// into any recovery. -func TestRACKOnePacketTailLoss(t *testing.T) { - c := context.New(t, uint32(mtu)) - defer c.Cleanup() - - // Send 3 packets. - numPackets := 3 - data := sendAndReceiveWithSACK(t, c, numPackets, true /* enableRACK */) - - // Packets [1-2] are received. #3 is lost. - seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - bytesRead := 2 * maxPayload - c.SendAck(seq, bytesRead) - - // PTO should fire and send #3 packet as a TLP. - c.ReceiveAndCheckPacketWithOptions(data, 2*maxPayload, maxPayload, tsOptionSize) - bytesRead += maxPayload - c.SendAck(seq, bytesRead) - - metricPollFn := func() error { - tcpStats := c.Stack().Stats().TCP - stats := []struct { - stat *tcpip.StatCounter - name string - want uint64 - }{ - // #3 was retransmitted as TLP. - {tcpStats.FastRetransmit, "stats.TCP.FastRetransmit", 0}, - {tcpStats.SACKRecovery, "stats.TCP.SACKRecovery", 1}, - {tcpStats.TLPRecovery, "stats.TCP.TLPRecovery", 0}, - // RTO should not have fired. - {tcpStats.Timeouts, "stats.TCP.Timeouts", 0}, - // Only #3 was retransmitted. - {tcpStats.Retransmits, "stats.TCP.Retransmits", 1}, - } - for _, s := range stats { - if got, want := s.stat.Value(), s.want; got != want { - return fmt.Errorf("got %s.Value() = %d, want = %d", s.name, got, want) - } - } - return nil - } - if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil { - t.Error(err) - } -} - -// TestRACKDetectDSACK tests that RACK detects DSACK with duplicate segments. -// See: https://tools.ietf.org/html/rfc2883#section-4.1.1. -func TestRACKDetectDSACK(t *testing.T) { - c := context.New(t, uint32(mtu)) - defer c.Cleanup() - - probeDone := make(chan int) - const ackNumToVerify = 2 - addDSACKSeenCheckerProbe(t, c, ackNumToVerify, probeDone) - - numPackets := 8 - data := sendAndReceiveWithSACK(t, c, numPackets, true /* enableRACK */) - - // Cumulative ACK for [1-5] packets and SACK #8 packet (to prevent TLP). - seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - bytesRead := 5 * maxPayload - eighthPStart := c.IRS.Add(1 + seqnum.Size(7*maxPayload)) - eighthPEnd := eighthPStart.Add(maxPayload) - c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{eighthPStart, eighthPEnd}}) - - // Expect retransmission of #6 packet after RTO expires. - c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) - - // Send DSACK block for #6 packet indicating both - // initial and retransmitted packet are received and - // packets [1-8] are received. - start := c.IRS.Add(1 + seqnum.Size(bytesRead)) - end := start.Add(maxPayload) - bytesRead += 3 * maxPayload - c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}}) - - // Wait for the probe function to finish processing the - // ACK before the test completes. - err := <-probeDone - switch err { - case failedToDetectDSACK: - t.Fatalf("RACK DSACK detection failed") - case invalidDSACKDetected: - t.Fatalf("RACK DSACK detected when there is no duplicate SACK") - } - - metricPollFn := func() error { - tcpStats := c.Stack().Stats().TCP - stats := []struct { - stat *tcpip.StatCounter - name string - want uint64 - }{ - // Check DSACK was received for one segment. - {tcpStats.SegmentsAckedWithDSACK, "stats.TCP.SegmentsAckedWithDSACK", 1}, - } - for _, s := range stats { - if got, want := s.stat.Value(), s.want; got != want { - return fmt.Errorf("got %s.Value() = %d, want = %d", s.name, got, want) - } - } - return nil - } - - if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil { - t.Error(err) - } -} - -// TestRACKDetectDSACKWithOutOfOrder tests that RACK detects DSACK with out of -// order segments. -// See: https://tools.ietf.org/html/rfc2883#section-4.1.2. -func TestRACKDetectDSACKWithOutOfOrder(t *testing.T) { - c := context.New(t, uint32(mtu)) - defer c.Cleanup() - - probeDone := make(chan int) - const ackNumToVerify = 2 - addDSACKSeenCheckerProbe(t, c, ackNumToVerify, probeDone) - - numPackets := 10 - data := sendAndReceiveWithSACK(t, c, numPackets, true /* enableRACK */) - - // Cumulative ACK for [1-5] packets and SACK for #7 packet (to prevent TLP). - seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - bytesRead := 5 * maxPayload - seventhPStart := c.IRS.Add(1 + seqnum.Size(6*maxPayload)) - seventhPEnd := seventhPStart.Add(maxPayload) - c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{seventhPStart, seventhPEnd}}) - - // Expect retransmission of #6 packet. - c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) - - // Send DSACK block for #6 packet indicating both - // initial and retransmitted packet are received and - // packets [1-7] are received. - start := c.IRS.Add(1 + seqnum.Size(bytesRead)) - end := start.Add(maxPayload) - bytesRead += 2 * maxPayload - // Send DSACK block for #6 along with SACK for out of - // order #9 packet. - start1 := c.IRS.Add(1 + seqnum.Size(bytesRead) + maxPayload) - end1 := start1.Add(maxPayload) - c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}, {start1, end1}}) - - // Wait for the probe function to finish processing the - // ACK before the test completes. - err := <-probeDone - switch err { - case failedToDetectDSACK: - t.Fatalf("RACK DSACK detection failed") - case invalidDSACKDetected: - t.Fatalf("RACK DSACK detected when there is no duplicate SACK") - } -} - -// TestRACKDetectDSACKWithOutOfOrderDup tests that DSACK is detected on a -// duplicate of out of order packet. -// See: https://tools.ietf.org/html/rfc2883#section-4.1.3 -func TestRACKDetectDSACKWithOutOfOrderDup(t *testing.T) { - c := context.New(t, uint32(mtu)) - defer c.Cleanup() - - probeDone := make(chan int) - const ackNumToVerify = 4 - addDSACKSeenCheckerProbe(t, c, ackNumToVerify, probeDone) - - numPackets := 10 - sendAndReceiveWithSACK(t, c, numPackets, true /* enableRACK */) - - // ACK [1-5] packets. - seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - bytesRead := 5 * maxPayload - c.SendAck(seq, bytesRead) - - // Send SACK indicating #6 packet is missing and received #7 packet. - offset := seqnum.Size(bytesRead + maxPayload) - start := c.IRS.Add(1 + offset) - end := start.Add(maxPayload) - c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}}) - - // Send SACK with #6 packet is missing and received [7-8] packets. - end = start.Add(2 * maxPayload) - c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}}) - - // Consider #8 packet is duplicated on the network and send DSACK. - dsackStart := c.IRS.Add(1 + offset + maxPayload) - dsackEnd := dsackStart.Add(maxPayload) - c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{dsackStart, dsackEnd}, {start, end}}) - - // Wait for the probe function to finish processing the ACK before the - // test completes. - err := <-probeDone - switch err { - case failedToDetectDSACK: - t.Fatalf("RACK DSACK detection failed") - case invalidDSACKDetected: - t.Fatalf("RACK DSACK detected when there is no duplicate SACK") - } -} - -// TestRACKDetectDSACKSingleDup tests DSACK for a single duplicate subsegment. -// See: https://tools.ietf.org/html/rfc2883#section-4.2.1. -func TestRACKDetectDSACKSingleDup(t *testing.T) { - c := context.New(t, uint32(mtu)) - defer c.Cleanup() - - probeDone := make(chan int) - const ackNumToVerify = 4 - addDSACKSeenCheckerProbe(t, c, ackNumToVerify, probeDone) - - numPackets := 4 - data := sendAndReceiveWithSACK(t, c, numPackets, true /* enableRACK */) - - // Send ACK for #1 packet. - bytesRead := maxPayload - seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - c.SendAck(seq, bytesRead) - - // Missing [2-3] packets and received #4 packet. - seq = seqnum.Value(context.TestInitialSequenceNumber).Add(1) - start := c.IRS.Add(1 + seqnum.Size(3*maxPayload)) - end := start.Add(seqnum.Size(maxPayload)) - c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}}) - - // Expect retransmission of #2 packet. - c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) - - // ACK for retransmitted #2 packet. - bytesRead += maxPayload - c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}}) - - // Simulate receving delayed subsegment of #2 packet and delayed #3 packet by - // sending DSACK block for the subsegment. - dsackStart := c.IRS.Add(1 + seqnum.Size(bytesRead)) - dsackEnd := dsackStart.Add(seqnum.Size(maxPayload / 2)) - c.SendAckWithSACK(seq, numPackets*maxPayload, []header.SACKBlock{{dsackStart, dsackEnd}}) - - // Wait for the probe function to finish processing the ACK before the - // test completes. - err := <-probeDone - switch err { - case failedToDetectDSACK: - t.Fatalf("RACK DSACK detection failed") - case invalidDSACKDetected: - t.Fatalf("RACK DSACK detected when there is no duplicate SACK") - } - - metricPollFn := func() error { - tcpStats := c.Stack().Stats().TCP - stats := []struct { - stat *tcpip.StatCounter - name string - want uint64 - }{ - // Check DSACK was received for a subsegment. - {tcpStats.SegmentsAckedWithDSACK, "stats.TCP.SegmentsAckedWithDSACK", 1}, - } - for _, s := range stats { - if got, want := s.stat.Value(), s.want; got != want { - return fmt.Errorf("got %s.Value() = %d, want = %d", s.name, got, want) - } - } - return nil - } - - if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil { - t.Error(err) - } -} - -// TestRACKDetectDSACKDupWithCumulativeACK tests DSACK for two non-contiguous -// duplicate subsegments covered by the cumulative acknowledgement. -// See: https://tools.ietf.org/html/rfc2883#section-4.2.2. -func TestRACKDetectDSACKDupWithCumulativeACK(t *testing.T) { - c := context.New(t, uint32(mtu)) - defer c.Cleanup() - - probeDone := make(chan int) - const ackNumToVerify = 5 - addDSACKSeenCheckerProbe(t, c, ackNumToVerify, probeDone) - - numPackets := 6 - data := sendAndReceiveWithSACK(t, c, numPackets, true /* enableRACK */) - - // Send ACK for #1 packet. - bytesRead := maxPayload - seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - c.SendAck(seq, bytesRead) - - // Missing [2-5] packets and received #6 packet. - seq = seqnum.Value(context.TestInitialSequenceNumber).Add(1) - start := c.IRS.Add(1 + seqnum.Size(5*maxPayload)) - end := start.Add(seqnum.Size(maxPayload)) - c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}}) - - // Expect retransmission of #2 packet. - c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) - - // Received delayed #2 packet. - bytesRead += maxPayload - c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}}) - - // Received delayed #4 packet. - start1 := c.IRS.Add(1 + seqnum.Size(3*maxPayload)) - end1 := start1.Add(seqnum.Size(maxPayload)) - c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start1, end1}, {start, end}}) - - // Simulate receiving retransmitted subsegment for #2 packet and delayed #3 - // packet by sending DSACK block for #2 packet. - dsackStart := c.IRS.Add(1 + seqnum.Size(maxPayload)) - dsackEnd := dsackStart.Add(seqnum.Size(maxPayload / 2)) - c.SendAckWithSACK(seq, 4*maxPayload, []header.SACKBlock{{dsackStart, dsackEnd}, {start, end}}) - - // Wait for the probe function to finish processing the ACK before the - // test completes. - err := <-probeDone - switch err { - case failedToDetectDSACK: - t.Fatalf("RACK DSACK detection failed") - case invalidDSACKDetected: - t.Fatalf("RACK DSACK detected when there is no duplicate SACK") - } -} - -// TestRACKDetectDSACKDup tests two non-contiguous duplicate subsegments not -// covered by the cumulative acknowledgement. -// See: https://tools.ietf.org/html/rfc2883#section-4.2.3. -func TestRACKDetectDSACKDup(t *testing.T) { - c := context.New(t, uint32(mtu)) - defer c.Cleanup() - - probeDone := make(chan int) - const ackNumToVerify = 5 - addDSACKSeenCheckerProbe(t, c, ackNumToVerify, probeDone) - - numPackets := 7 - data := sendAndReceiveWithSACK(t, c, numPackets, true /* enableRACK */) - - // Send ACK for #1 packet. - bytesRead := maxPayload - seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - c.SendAck(seq, bytesRead) - - // Missing [2-6] packets and SACK #7 packet. - seq = seqnum.Value(context.TestInitialSequenceNumber).Add(1) - start := c.IRS.Add(1 + seqnum.Size(6*maxPayload)) - end := start.Add(seqnum.Size(maxPayload)) - c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}}) - - // Received delayed #3 packet. - start1 := c.IRS.Add(1 + seqnum.Size(2*maxPayload)) - end1 := start1.Add(seqnum.Size(maxPayload)) - c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start1, end1}, {start, end}}) - - // Expect retransmission of #2 packet. - c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) - - // Consider #2 packet has been dropped and SACK #4 packet. - start2 := c.IRS.Add(1 + seqnum.Size(3*maxPayload)) - end2 := start2.Add(seqnum.Size(maxPayload)) - c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start2, end2}, {start1, end1}, {start, end}}) - - // Simulate receiving retransmitted subsegment for #3 packet and delayed #5 - // packet by sending DSACK block for the subsegment. - dsackStart := c.IRS.Add(1 + seqnum.Size(2*maxPayload)) - dsackEnd := dsackStart.Add(seqnum.Size(maxPayload / 2)) - end1 = end1.Add(seqnum.Size(2 * maxPayload)) - c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{dsackStart, dsackEnd}, {start1, end1}}) - - // Wait for the probe function to finish processing the ACK before the - // test completes. - err := <-probeDone - switch err { - case failedToDetectDSACK: - t.Fatalf("RACK DSACK detection failed") - case invalidDSACKDetected: - t.Fatalf("RACK DSACK detected when there is no duplicate SACK") - } -} - -// TestRACKWithInvalidDSACKBlock tests that DSACK is not detected when DSACK -// is not the first SACK block. -func TestRACKWithInvalidDSACKBlock(t *testing.T) { - c := context.New(t, uint32(mtu)) - defer c.Cleanup() - - probeDone := make(chan struct{}) - const ackNumToVerify = 2 - var n int - c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) { - // Validate that RACK does not detect DSACK when DSACK block is - // not the first SACK block. - n++ - t.Helper() - if state.Sender.RACKState.DSACKSeen { - t.Fatalf("RACK DSACK detected when there is no duplicate SACK") - } - - if n == ackNumToVerify { - close(probeDone) - } - }) - - numPackets := 10 - data := sendAndReceiveWithSACK(t, c, numPackets, true /* enableRACK */) - - // Cumulative ACK for [1-5] packets and SACK for #7 packet (to prevent TLP). - seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - bytesRead := 5 * maxPayload - seventhPStart := c.IRS.Add(1 + seqnum.Size(6*maxPayload)) - seventhPEnd := seventhPStart.Add(maxPayload) - c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{seventhPStart, seventhPEnd}}) - - // Expect retransmission of #6 packet. - c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) - - // Send DSACK block for #6 packet indicating both - // initial and retransmitted packet are received and - // packets [1-7] are received. - start := c.IRS.Add(1 + seqnum.Size(bytesRead)) - end := start.Add(maxPayload) - bytesRead += 2 * maxPayload - - // Send DSACK block as second block. The first block is a SACK for #9 packet. - start1 := c.IRS.Add(1 + seqnum.Size(bytesRead) + maxPayload) - end1 := start1.Add(maxPayload) - c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start1, end1}, {start, end}}) - - // Wait for the probe function to finish processing the - // ACK before the test completes. - <-probeDone -} - -func addReorderWindowCheckerProbe(c *context.Context, numACK int, probeDone chan error) { - var n int - c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) { - // Validate that RACK detects DSACK. - n++ - if n < numACK { - return - } - - if state.Sender.RACKState.ReoWnd == 0 || state.Sender.RACKState.ReoWnd > state.Sender.RTTState.SRTT { - probeDone <- fmt.Errorf("got RACKState.ReoWnd: %d, expected it to be greater than 0 and less than %d", state.Sender.RACKState.ReoWnd, state.Sender.RTTState.SRTT) - return - } - - if state.Sender.RACKState.ReoWndIncr != 1 { - probeDone <- fmt.Errorf("got RACKState.ReoWndIncr: %v, want: 1", state.Sender.RACKState.ReoWndIncr) - return - } - - if state.Sender.RACKState.ReoWndPersist > 0 { - probeDone <- fmt.Errorf("got RACKState.ReoWndPersist: %v, want: greater than 0", state.Sender.RACKState.ReoWndPersist) - return - } - probeDone <- nil - }) -} - -func TestRACKCheckReorderWindow(t *testing.T) { - c := context.New(t, uint32(mtu)) - defer c.Cleanup() - - probeDone := make(chan error) - const ackNumToVerify = 3 - addReorderWindowCheckerProbe(c, ackNumToVerify, probeDone) - - const numPackets = 7 - sendAndReceiveWithSACK(t, c, numPackets, true /* enableRACK */) - - // Send ACK for #1 packet. - bytesRead := maxPayload - seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - c.SendAck(seq, bytesRead) - - // Missing [2-6] packets and SACK #7 packet. - start := c.IRS.Add(1 + seqnum.Size(6*maxPayload)) - end := start.Add(seqnum.Size(maxPayload)) - c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}}) - - // Received delayed packets [2-6] which indicates there is reordering - // in the connection. - bytesRead += 6 * maxPayload - c.SendAck(seq, bytesRead) - - // Wait for the probe function to finish processing the ACK before the - // test completes. - if err := <-probeDone; err != nil { - t.Fatalf("unexpected values for RACK variables: %v", err) - } -} - -func TestRACKWithDuplicateACK(t *testing.T) { - c := context.New(t, uint32(mtu)) - defer c.Cleanup() - - const numPackets = 4 - data := sendAndReceiveWithSACK(t, c, numPackets, true /* enableRACK */) - - // Send three duplicate ACKs to trigger fast recovery. The first - // segment is considered as lost and will be retransmitted after - // receiving the duplicate ACKs. - seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - start := c.IRS.Add(1 + seqnum.Size(maxPayload)) - end := start.Add(seqnum.Size(maxPayload)) - for i := 0; i < 3; i++ { - c.SendAckWithSACK(seq, 0, []header.SACKBlock{{start, end}}) - end = end.Add(seqnum.Size(maxPayload)) - } - - // Receive the retransmitted packet. - c.ReceiveAndCheckPacketWithOptions(data, 0, maxPayload, tsOptionSize) - - metricPollFn := func() error { - tcpStats := c.Stack().Stats().TCP - stats := []struct { - stat *tcpip.StatCounter - name string - want uint64 - }{ - {tcpStats.FastRetransmit, "stats.TCP.FastRetransmit", 1}, - {tcpStats.SACKRecovery, "stats.TCP.SACKRecovery", 1}, - {tcpStats.FastRecovery, "stats.TCP.FastRecovery", 0}, - } - for _, s := range stats { - if got, want := s.stat.Value(), s.want; got != want { - return fmt.Errorf("got %s.Value() = %d, want = %d", s.name, got, want) - } - } - return nil - } - - if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil { - t.Error(err) - } -} - -// TestRACKUpdateSackedOut tests the sacked out field is updated when a SACK -// is received. -func TestRACKUpdateSackedOut(t *testing.T) { - c := context.New(t, uint32(mtu)) - defer c.Cleanup() - - probeDone := make(chan struct{}) - ackNum := 0 - c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) { - // Validate that the endpoint Sender.SackedOut is what we expect. - if state.Sender.SackedOut != 2 && ackNum == 0 { - t.Fatalf("SackedOut got updated to wrong value got: %v want: 2", state.Sender.SackedOut) - } - - if state.Sender.SackedOut != 0 && ackNum == 1 { - t.Fatalf("SackedOut got updated to wrong value got: %v want: 0", state.Sender.SackedOut) - } - if ackNum > 0 { - close(probeDone) - } - ackNum++ - }) - - sendAndReceiveWithSACK(t, c, 8, true /* enableRACK */) - - // ACK for [3-5] packets. - seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - start := c.IRS.Add(seqnum.Size(1 + 3*maxPayload)) - bytesRead := 2 * maxPayload - end := start.Add(seqnum.Size(bytesRead)) - c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}}) - - bytesRead += 3 * maxPayload - c.SendAck(seq, bytesRead) - - // Wait for the probe function to finish processing the ACK before the - // test completes. - <-probeDone -} - -// TestRACKWithWindowFull tests that RACK honors the receive window size. -func TestRACKWithWindowFull(t *testing.T) { - c := context.New(t, uint32(mtu)) - defer c.Cleanup() - - setStackSACKPermitted(t, c, true) - createConnectedWithSACKAndTS(c) - - seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - const numPkts = 10 - data := make([]byte, numPkts*maxPayload) - for i := range data { - data[i] = byte(i) - } - - // Write the data. - var r bytes.Reader - r.Reset(data) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - bytesRead := 0 - for i := 0; i < numPkts; i++ { - c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) - bytesRead += maxPayload - if i == 0 { - // Send ACK for the first packet to establish RTT. - c.SendAck(seq, maxPayload) - } - } - - // SACK for #10 packet. - start := c.IRS.Add(seqnum.Size(1 + (numPkts-1)*maxPayload)) - end := start.Add(seqnum.Size(maxPayload)) - c.SendAckWithSACK(seq, 2*maxPayload, []header.SACKBlock{{start, end}}) - - var info tcpip.TCPInfoOption - if err := c.EP.GetSockOpt(&info); err != nil { - t.Fatalf("GetSockOpt failed: %v", err) - } - // Wait for RTT to trigger recovery. - time.Sleep(info.RTT) - - // Expect retransmission of #2 packet. - c.ReceiveAndCheckPacketWithOptions(data, 2*maxPayload, maxPayload, tsOptionSize) - - // Send ACK for #2 packet. - c.SendAck(seq, 3*maxPayload) - - // Expect retransmission of #3 packet. - c.ReceiveAndCheckPacketWithOptions(data, 3*maxPayload, maxPayload, tsOptionSize) - - // Send ACK with zero window size. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: seq, - AckNum: c.IRS.Add(1 + 4*maxPayload), - RcvWnd: 0, - }) - - // No packet should be received as the receive window size is zero. - c.CheckNoPacket("unexpected packet received after userTimeout has expired") -} diff --git a/pkg/tcpip/transport/tcp/tcp_sack_test.go b/pkg/tcpip/transport/tcp/tcp_sack_test.go deleted file mode 100644 index 83e0653b9..000000000 --- a/pkg/tcpip/transport/tcp/tcp_sack_test.go +++ /dev/null @@ -1,704 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tcp_test - -import ( - "bytes" - "fmt" - "log" - "reflect" - "testing" - "time" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/seqnum" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context" - "gvisor.dev/gvisor/pkg/test/testutil" -) - -// createConnectedWithSACKPermittedOption creates and connects c.ep with the -// SACKPermitted option enabled if the stack in the context has the SACK support -// enabled. -func createConnectedWithSACKPermittedOption(c *context.Context) *context.RawEndpoint { - return c.CreateConnectedWithOptions(header.TCPSynOptions{SACKPermitted: c.SACKEnabled()}) -} - -// createConnectedWithSACKAndTS creates and connects c.ep with the SACK & TS -// option enabled if the stack in the context has SACK and TS enabled. -func createConnectedWithSACKAndTS(c *context.Context) *context.RawEndpoint { - return c.CreateConnectedWithOptions(header.TCPSynOptions{SACKPermitted: c.SACKEnabled(), TS: true}) -} - -func setStackSACKPermitted(t *testing.T, c *context.Context, enable bool) { - t.Helper() - opt := tcpip.TCPSACKEnabled(enable) - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("c.s.SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err) - } -} - -// TestSackPermittedConnect establishes a connection with the SACK option -// enabled. -func TestSackPermittedConnect(t *testing.T) { - for _, sackEnabled := range []bool{false, true} { - t.Run(fmt.Sprintf("stack.sackEnabled: %v", sackEnabled), func(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - setStackSACKPermitted(t, c, sackEnabled) - setStackTCPRecovery(t, c, 0) - rep := createConnectedWithSACKPermittedOption(c) - data := []byte{1, 2, 3} - - rep.SendPacket(data, nil) - savedSeqNum := rep.NextSeqNum - rep.VerifyACKNoSACK() - - // Make an out of order packet and send it. - rep.NextSeqNum += 3 - sackBlocks := []header.SACKBlock{ - {rep.NextSeqNum, rep.NextSeqNum.Add(seqnum.Size(len(data)))}, - } - rep.SendPacket(data, nil) - - // Restore the saved sequence number so that the - // VerifyXXX calls use the right sequence number for - // checking ACK numbers. - rep.NextSeqNum = savedSeqNum - if sackEnabled { - rep.VerifyACKHasSACK(sackBlocks) - } else { - rep.VerifyACKNoSACK() - } - - // Send the missing segment. - rep.SendPacket(data, nil) - // The ACK should contain the cumulative ACK for all 9 - // bytes sent and no SACK blocks. - rep.NextSeqNum += 3 - // Check that no SACK block is returned in the ACK. - rep.VerifyACKNoSACK() - }) - } -} - -// TestSackDisabledConnect establishes a connection with the SACK option -// disabled and verifies that no SACKs are sent for out of order segments. -func TestSackDisabledConnect(t *testing.T) { - for _, sackEnabled := range []bool{false, true} { - t.Run(fmt.Sprintf("sackEnabled: %v", sackEnabled), func(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - setStackSACKPermitted(t, c, sackEnabled) - setStackTCPRecovery(t, c, 0) - - rep := c.CreateConnectedWithOptions(header.TCPSynOptions{}) - - data := []byte{1, 2, 3} - - rep.SendPacket(data, nil) - savedSeqNum := rep.NextSeqNum - rep.VerifyACKNoSACK() - - // Make an out of order packet and send it. - rep.NextSeqNum += 3 - rep.SendPacket(data, nil) - - // The ACK should contain the older sequence number and - // no SACK blocks. - rep.NextSeqNum = savedSeqNum - rep.VerifyACKNoSACK() - - // Send the missing segment. - rep.SendPacket(data, nil) - // The ACK should contain the cumulative ACK for all 9 - // bytes sent and no SACK blocks. - rep.NextSeqNum += 3 - // Check that no SACK block is returned in the ACK. - rep.VerifyACKNoSACK() - }) - } -} - -// TestSackPermittedAccept accepts and establishes a connection with the -// SACKPermitted option enabled if the connection request specifies the -// SACKPermitted option. In case of SYN cookies SACK should be disabled as we -// don't encode the SACK information in the cookie. -func TestSackPermittedAccept(t *testing.T) { - type testCase struct { - cookieEnabled bool - sackPermitted bool - wndScale int - wndSize uint16 - } - - testCases := []testCase{ - // When cookie is used window scaling is disabled. - {true, false, -1, 0xffff}, // When cookie is used window scaling is disabled. - {false, true, 5, 0x8000}, // 0x8000 * 2^5 = 1<<20 = 1MB window (the default). - } - - for _, tc := range testCases { - t.Run(fmt.Sprintf("test: %#v", tc), func(t *testing.T) { - for _, sackEnabled := range []bool{false, true} { - t.Run(fmt.Sprintf("test stack.sackEnabled: %v", sackEnabled), func(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - if tc.cookieEnabled { - opt := tcpip.TCPAlwaysUseSynCookies(true) - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err) - } - } - setStackSACKPermitted(t, c, sackEnabled) - setStackTCPRecovery(t, c, 0) - - rep := c.AcceptWithOptions(tc.wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS, SACKPermitted: tc.sackPermitted}) - // Now verify no SACK blocks are - // received when sack is disabled. - data := []byte{1, 2, 3} - rep.SendPacket(data, nil) - rep.VerifyACKNoSACK() - - savedSeqNum := rep.NextSeqNum - - // Make an out of order packet and send - // it. - rep.NextSeqNum += 3 - sackBlocks := []header.SACKBlock{ - {rep.NextSeqNum, rep.NextSeqNum.Add(seqnum.Size(len(data)))}, - } - rep.SendPacket(data, nil) - - // The ACK should contain the older - // sequence number. - rep.NextSeqNum = savedSeqNum - if sackEnabled && tc.sackPermitted { - rep.VerifyACKHasSACK(sackBlocks) - } else { - rep.VerifyACKNoSACK() - } - - // Send the missing segment. - rep.SendPacket(data, nil) - // The ACK should contain the cumulative - // ACK for all 9 bytes sent and no SACK - // blocks. - rep.NextSeqNum += 3 - // Check that no SACK block is returned - // in the ACK. - rep.VerifyACKNoSACK() - }) - } - }) - } -} - -// TestSackDisabledAccept accepts and establishes a connection with -// the SACKPermitted option disabled and verifies that no SACKs are -// sent for out of order packets. -func TestSackDisabledAccept(t *testing.T) { - type testCase struct { - cookieEnabled bool - wndScale int - wndSize uint16 - } - - testCases := []testCase{ - // When cookie is used window scaling is disabled. - {true, -1, 0xffff}, // When cookie is used window scaling is disabled. - {false, 5, 0x8000}, // 0x8000 * 2^5 = 1<<20 = 1MB window (the default). - } - - for _, tc := range testCases { - t.Run(fmt.Sprintf("test: %#v", tc), func(t *testing.T) { - for _, sackEnabled := range []bool{false, true} { - t.Run(fmt.Sprintf("test: sackEnabled: %v", sackEnabled), func(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - if tc.cookieEnabled { - opt := tcpip.TCPAlwaysUseSynCookies(true) - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err) - } - } - - setStackSACKPermitted(t, c, sackEnabled) - setStackTCPRecovery(t, c, 0) - - rep := c.AcceptWithOptions(tc.wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS}) - - // Now verify no SACK blocks are - // received when sack is disabled. - data := []byte{1, 2, 3} - rep.SendPacket(data, nil) - rep.VerifyACKNoSACK() - savedSeqNum := rep.NextSeqNum - - // Make an out of order packet and send - // it. - rep.NextSeqNum += 3 - rep.SendPacket(data, nil) - - // The ACK should contain the older - // sequence number and no SACK blocks. - rep.NextSeqNum = savedSeqNum - rep.VerifyACKNoSACK() - - // Send the missing segment. - rep.SendPacket(data, nil) - // The ACK should contain the cumulative - // ACK for all 9 bytes sent and no SACK - // blocks. - rep.NextSeqNum += 3 - // Check that no SACK block is returned - // in the ACK. - rep.VerifyACKNoSACK() - }) - } - }) - } -} - -func TestUpdateSACKBlocks(t *testing.T) { - testCases := []struct { - segStart seqnum.Value - segEnd seqnum.Value - rcvNxt seqnum.Value - sackBlocks []header.SACKBlock - updated []header.SACKBlock - }{ - // Trivial cases where current SACK block list is empty and we - // have an out of order delivery. - {10, 11, 2, []header.SACKBlock{}, []header.SACKBlock{{10, 11}}}, - {10, 12, 2, []header.SACKBlock{}, []header.SACKBlock{{10, 12}}}, - {10, 20, 2, []header.SACKBlock{}, []header.SACKBlock{{10, 20}}}, - - // Cases where current SACK block list is not empty and we have - // an out of order delivery. Tests that the updated SACK block - // list has the first block as the one that contains the new - // SACK block representing the segment that was just delivered. - {10, 11, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{10, 11}, {12, 20}}}, - {24, 30, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{24, 30}, {12, 20}}}, - {24, 30, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{24, 30}, {12, 20}, {32, 40}}}, - - // Ensure that we only retain header.MaxSACKBlocks and drop the - // oldest one if adding a new block exceeds - // header.MaxSACKBlocks. - {24, 30, 9, - []header.SACKBlock{{12, 20}, {32, 40}, {42, 50}, {52, 60}, {62, 70}, {72, 80}}, - []header.SACKBlock{{24, 30}, {12, 20}, {32, 40}, {42, 50}, {52, 60}, {62, 70}}}, - - // Cases where segment extends an existing SACK block. - {10, 12, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{10, 20}}}, - {10, 22, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{10, 22}}}, - {10, 22, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{10, 22}}}, - {15, 22, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{12, 22}}}, - {15, 25, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{12, 25}}}, - {11, 25, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{11, 25}}}, - {10, 12, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{10, 20}, {32, 40}}}, - {10, 22, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{10, 22}, {32, 40}}}, - {10, 22, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{10, 22}, {32, 40}}}, - {15, 22, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{12, 22}, {32, 40}}}, - {15, 25, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{12, 25}, {32, 40}}}, - {11, 25, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{11, 25}, {32, 40}}}, - - // Cases where segment contains rcvNxt. - {10, 20, 15, []header.SACKBlock{{20, 30}, {40, 50}}, []header.SACKBlock{{40, 50}}}, - } - - for _, tc := range testCases { - var sack tcp.SACKInfo - copy(sack.Blocks[:], tc.sackBlocks) - sack.NumBlocks = len(tc.sackBlocks) - tcp.UpdateSACKBlocks(&sack, tc.segStart, tc.segEnd, tc.rcvNxt) - if got, want := sack.Blocks[:sack.NumBlocks], tc.updated; !reflect.DeepEqual(got, want) { - t.Errorf("UpdateSACKBlocks(%v, %v, %v, %v), got: %v, want: %v", tc.sackBlocks, tc.segStart, tc.segEnd, tc.rcvNxt, got, want) - } - - } -} - -func TestTrimSackBlockList(t *testing.T) { - testCases := []struct { - rcvNxt seqnum.Value - sackBlocks []header.SACKBlock - trimmed []header.SACKBlock - }{ - // Simple cases where we trim whole entries. - {2, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}}, - {21, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{{22, 30}, {32, 40}}}, - {31, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{{32, 40}}}, - {40, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{}}, - // Cases where we need to update a block. - {12, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{{12, 20}, {22, 30}, {32, 40}}}, - {23, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{{23, 30}, {32, 40}}}, - {33, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{{33, 40}}}, - {41, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{}}, - } - for _, tc := range testCases { - var sack tcp.SACKInfo - copy(sack.Blocks[:], tc.sackBlocks) - sack.NumBlocks = len(tc.sackBlocks) - tcp.TrimSACKBlockList(&sack, tc.rcvNxt) - if got, want := sack.Blocks[:sack.NumBlocks], tc.trimmed; !reflect.DeepEqual(got, want) { - t.Errorf("TrimSackBlockList(%v, %v), got: %v, want: %v", tc.sackBlocks, tc.rcvNxt, got, want) - } - } -} - -func TestSACKRecovery(t *testing.T) { - const maxPayload = 10 - // See: tcp.makeOptions for why tsOptionSize is set to 12 here. - const tsOptionSize = 12 - // Enabling SACK means the payload size is reduced to account - // for the extra space required for the TCP options. - // - // We increase the MTU by 40 bytes to account for SACK and Timestamp - // options. - const maxTCPOptionSize = 40 - - c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxTCPOptionSize+maxPayload)) - defer c.Cleanup() - - c.Stack().AddTCPProbe(func(s stack.TCPEndpointState) { - // We use log.Printf instead of t.Logf here because this probe - // can fire even when the test function has finished. This is - // because closing the endpoint in cleanup() does not mean the - // actual worker loop terminates immediately as it still has to - // do a full TCP shutdown. But this test can finish running - // before the shutdown is done. Using t.Logf in such a case - // causes the test to panic due to logging after test finished. - log.Printf("state: %+v\n", s) - }) - setStackSACKPermitted(t, c, true) - setStackTCPRecovery(t, c, 0) - createConnectedWithSACKAndTS(c) - - const iterations = 3 - data := make([]byte, 2*maxPayload*(tcp.InitialCwnd<<(iterations+1))) - for i := range data { - data[i] = byte(i) - } - - // Write all the data in one shot. Packets will only be written at the - // MTU size though. - var r bytes.Reader - r.Reset(data) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - // Do slow start for a few iterations. - seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - expected := tcp.InitialCwnd - bytesRead := 0 - for i := 0; i < iterations; i++ { - expected = tcp.InitialCwnd << uint(i) - if i > 0 { - // Acknowledge all the data received so far if not on - // first iteration. - c.SendAck(seq, bytesRead) - } - - // Read all packets expected on this iteration. Don't - // acknowledge any of them just yet, so that we can measure the - // congestion window. - for j := 0; j < expected; j++ { - c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) - bytesRead += maxPayload - } - - // Check we don't receive any more packets on this iteration. - // The timeout can't be too high or we'll trigger a timeout. - c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond) - } - - // Send 3 duplicate acks. This should force an immediate retransmit of - // the pending packet and put the sender into fast recovery. - rtxOffset := bytesRead - maxPayload*expected - start := c.IRS.Add(seqnum.Size(rtxOffset) + 30 + 1) - end := start.Add(10) - for i := 0; i < 3; i++ { - c.SendAckWithSACK(seq, rtxOffset, []header.SACKBlock{{start, end}}) - end = end.Add(10) - } - - // Receive the retransmitted packet. - c.ReceiveAndCheckPacketWithOptions(data, rtxOffset, maxPayload, tsOptionSize) - - metricPollFn := func() error { - tcpStats := c.Stack().Stats().TCP - stats := []struct { - stat *tcpip.StatCounter - name string - want uint64 - }{ - {tcpStats.FastRetransmit, "stats.TCP.FastRetransmit", 1}, - {tcpStats.Retransmits, "stats.TCP.Retransmits", 1}, - {tcpStats.SACKRecovery, "stats.TCP.SACKRecovery", 1}, - {tcpStats.FastRecovery, "stats.TCP.FastRecovery", 0}, - } - for _, s := range stats { - if got, want := s.stat.Value(), s.want; got != want { - return fmt.Errorf("got %s.Value() = %d, want = %d", s.name, got, want) - } - } - return nil - } - - if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil { - t.Error(err) - } - - // Now send 7 mode duplicate ACKs. In SACK TCP dupAcks do not cause - // window inflation and sending of packets is completely handled by the - // SACK Recovery algorithm. We should see no packets being released, as - // the cwnd at this point after entering recovery should be half of the - // outstanding number of packets in flight. - for i := 0; i < 7; i++ { - c.SendAckWithSACK(seq, rtxOffset, []header.SACKBlock{{start, end}}) - end = end.Add(10) - } - - recover := bytesRead - - // Ensure no new packets arrive. - c.CheckNoPacketTimeout("More packets received than expected during recovery after dupacks for this cwnd.", - 50*time.Millisecond) - - // Acknowledge half of the pending data. This along with the 10 sacked - // segments above should reduce the outstanding below the current - // congestion window allowing the sender to transmit data. - rtxOffset = bytesRead - expected*maxPayload/2 - - // Now send a partial ACK w/ a SACK block that indicates that the next 3 - // segments are lost and we have received 6 segments after the lost - // segments. This should cause the sender to immediately transmit all 3 - // segments in response to this ACK unlike in FastRecovery where only 1 - // segment is retransmitted per ACK. - start = c.IRS.Add(seqnum.Size(rtxOffset) + 30 + 1) - end = start.Add(60) - c.SendAckWithSACK(seq, rtxOffset, []header.SACKBlock{{start, end}}) - - // At this point, we acked expected/2 packets and we SACKED 6 packets and - // 3 segments were considered lost due to the SACK block we sent. - // - // So total packets outstanding can be calculated as follows after 7 - // iterations of slow start -> 10/20/40/80/160/320/640. So expected - // should be 640 at start, then we went to recover at which point the - // cwnd should be set to 320 + 3 (for the 3 dupAcks which have left the - // network). - // Outstanding at this point after acking half the window - // (320 packets) will be: - // outstanding = 640-320-6(due to SACK block)-3 = 311 - // - // The last 3 is due to the fact that the first 3 packets after - // rtxOffset will be considered lost due to the SACK blocks sent. - // Receive the retransmit due to partial ack. - - c.ReceiveAndCheckPacketWithOptions(data, rtxOffset, maxPayload, tsOptionSize) - // Receive the 2 extra packets that should have been retransmitted as - // those should be considered lost and immediately retransmitted based - // on the SACK information in the previous ACK sent above. - for i := 0; i < 2; i++ { - c.ReceiveAndCheckPacketWithOptions(data, rtxOffset+maxPayload*(i+1), maxPayload, tsOptionSize) - } - - // Now we should get 9 more new unsent packets as the cwnd is 323 and - // outstanding is 311. - for i := 0; i < 9; i++ { - c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) - bytesRead += maxPayload - } - - metricPollFn = func() error { - // In SACK recovery only the first segment is fast retransmitted when - // entering recovery. - if got, want := c.Stack().Stats().TCP.FastRetransmit.Value(), uint64(1); got != want { - return fmt.Errorf("got stats.TCP.FastRetransmit.Value = %d, want = %d", got, want) - } - - if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.FastRetransmit.Value(), uint64(1); got != want { - return fmt.Errorf("got EP stats SendErrors.FastRetransmit = %d, want = %d", got, want) - } - - if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(4); got != want { - return fmt.Errorf("got stats.TCP.Retransmits.Value = %d, want = %d", got, want) - } - - if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Retransmits.Value(), uint64(4); got != want { - return fmt.Errorf("got EP stats Stats.SendErrors.Retransmits = %d, want = %d", got, want) - } - return nil - } - if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil { - t.Error(err) - } - - c.CheckNoPacketTimeout("More packets received than expected during recovery after partial ack for this cwnd.", 50*time.Millisecond) - - // Acknowledge all pending data to recover point. - c.SendAck(seq, recover) - - // At this point, the cwnd should reset to expected/2 and there are 9 - // packets outstanding. - // - // Now in the first iteration since there are 9 packets outstanding. - // We would expect to get expected/2 - 9 packets. But subsequent - // iterations will send us expected/2 + 1 (per iteration). - expected = expected/2 - 9 - for i := 0; i < iterations; i++ { - // Read all packets expected on this iteration. Don't - // acknowledge any of them just yet, so that we can measure the - // congestion window. - for j := 0; j < expected; j++ { - c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) - bytesRead += maxPayload - } - // Check we don't receive any more packets on this iteration. - // The timeout can't be too high or we'll trigger a timeout. - c.CheckNoPacketTimeout(fmt.Sprintf("More packets received(after deflation) than expected %d for this cwnd and iteration: %d.", expected, i), 50*time.Millisecond) - - // Acknowledge all the data received so far. - c.SendAck(seq, bytesRead) - - // In cogestion avoidance, the packets trains increase by 1 in - // each iteration. - if i == 0 { - // After the first iteration we expect to get the full - // congestion window worth of packets in every - // iteration. - expected += 9 - } - expected++ - } -} - -// TestRecoveryEntry tests the following two properties of entering recovery: -// - Fast SACK recovery is entered when SND.UNA is considered lost by the SACK -// scoreboard but dupack count is still below threshold. -// - Only enter recovery when at least one more byte of data beyond the highest -// byte that was outstanding when fast retransmit was last entered is acked. -func TestRecoveryEntry(t *testing.T) { - c := context.New(t, uint32(mtu)) - defer c.Cleanup() - - numPackets := 5 - data := sendAndReceiveWithSACK(t, c, numPackets, false /* enableRACK */) - - // Ack #1 packet. - seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - c.SendAck(seq, maxPayload) - - // Now SACK #3, #4 and #5 packets. This will simulate a situation where - // SND.UNA should be considered lost and the sender should enter fast recovery - // (even though dupack count is still below threshold). - p3Start := c.IRS.Add(1 + seqnum.Size(2*maxPayload)) - p3End := p3Start.Add(maxPayload) - p4Start := p3End - p4End := p4Start.Add(maxPayload) - p5Start := p4End - p5End := p5Start.Add(maxPayload) - c.SendAckWithSACK(seq, maxPayload, []header.SACKBlock{{p3Start, p3End}, {p4Start, p4End}, {p5Start, p5End}}) - - // Expect #2 to be retransmitted. - c.ReceiveAndCheckPacketWithOptions(data, maxPayload, maxPayload, tsOptionSize) - - metricPollFn := func() error { - tcpStats := c.Stack().Stats().TCP - stats := []struct { - stat *tcpip.StatCounter - name string - want uint64 - }{ - // SACK recovery must have happened. - {tcpStats.FastRetransmit, "stats.TCP.FastRetransmit", 1}, - {tcpStats.SACKRecovery, "stats.TCP.SACKRecovery", 1}, - // #2 was retransmitted. - {tcpStats.Retransmits, "stats.TCP.Retransmits", 1}, - // No RTOs should have fired yet. - {tcpStats.Timeouts, "stats.TCP.Timeouts", 0}, - } - for _, s := range stats { - if got, want := s.stat.Value(), s.want; got != want { - return fmt.Errorf("got %s.Value() = %d, want = %d", s.name, got, want) - } - } - return nil - } - if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil { - t.Error(err) - } - - // Send 4 more packets. - var r bytes.Reader - data = append(data, data...) - r.Reset(data[5*maxPayload : 9*maxPayload]) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - var sackBlocks []header.SACKBlock - bytesRead := numPackets * maxPayload - for i := 0; i < 4; i++ { - c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) - if i > 0 { - pStart := c.IRS.Add(1 + seqnum.Size(bytesRead)) - sackBlocks = append(sackBlocks, header.SACKBlock{pStart, pStart.Add(maxPayload)}) - c.SendAckWithSACK(seq, 5*maxPayload, sackBlocks) - } - bytesRead += maxPayload - } - - // #6 should be retransmitted after RTO. The sender should NOT enter fast - // recovery because the highest byte that was outstanding when fast recovery - // was last entered is #5 packet's end. And the sender requires at least one - // more byte beyond that (#6 packet start) to be acked to enter recovery. - c.ReceiveAndCheckPacketWithOptions(data, 5*maxPayload, maxPayload, tsOptionSize) - c.SendAck(seq, 9*maxPayload) - - metricPollFn = func() error { - tcpStats := c.Stack().Stats().TCP - stats := []struct { - stat *tcpip.StatCounter - name string - want uint64 - }{ - // Only 1 SACK recovery must have happened. - {tcpStats.FastRetransmit, "stats.TCP.FastRetransmit", 1}, - {tcpStats.SACKRecovery, "stats.TCP.SACKRecovery", 1}, - // #2 and #6 were retransmitted. - {tcpStats.Retransmits, "stats.TCP.Retransmits", 2}, - // RTO should have fired once. - {tcpStats.Timeouts, "stats.TCP.Timeouts", 1}, - } - for _, s := range stats { - if got, want := s.stat.Value(), s.want; got != want { - return fmt.Errorf("got %s.Value() = %d, want = %d", s.name, got, want) - } - } - return nil - } - if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil { - t.Error(err) - } -} diff --git a/pkg/tcpip/transport/tcp/tcp_segment_list.go b/pkg/tcpip/transport/tcp/tcp_segment_list.go new file mode 100644 index 000000000..a14cff27e --- /dev/null +++ b/pkg/tcpip/transport/tcp/tcp_segment_list.go @@ -0,0 +1,221 @@ +package tcp + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type segmentElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (segmentElementMapper) linkerFor(elem *segment) *segment { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type segmentList struct { + head *segment + tail *segment +} + +// Reset resets list l to the empty state. +func (l *segmentList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +// +//go:nosplit +func (l *segmentList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +// +//go:nosplit +func (l *segmentList) Front() *segment { + return l.head +} + +// Back returns the last element of list l or nil. +// +//go:nosplit +func (l *segmentList) Back() *segment { + return l.tail +} + +// Len returns the number of elements in the list. +// +// NOTE: This is an O(n) operation. +// +//go:nosplit +func (l *segmentList) Len() (count int) { + for e := l.Front(); e != nil; e = (segmentElementMapper{}.linkerFor(e)).Next() { + count++ + } + return count +} + +// PushFront inserts the element e at the front of list l. +// +//go:nosplit +func (l *segmentList) PushFront(e *segment) { + linker := segmentElementMapper{}.linkerFor(e) + linker.SetNext(l.head) + linker.SetPrev(nil) + if l.head != nil { + segmentElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +// +//go:nosplit +func (l *segmentList) PushBack(e *segment) { + linker := segmentElementMapper{}.linkerFor(e) + linker.SetNext(nil) + linker.SetPrev(l.tail) + if l.tail != nil { + segmentElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +// +//go:nosplit +func (l *segmentList) PushBackList(m *segmentList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + segmentElementMapper{}.linkerFor(l.tail).SetNext(m.head) + segmentElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +// +//go:nosplit +func (l *segmentList) InsertAfter(b, e *segment) { + bLinker := segmentElementMapper{}.linkerFor(b) + eLinker := segmentElementMapper{}.linkerFor(e) + + a := bLinker.Next() + + eLinker.SetNext(a) + eLinker.SetPrev(b) + bLinker.SetNext(e) + + if a != nil { + segmentElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +// +//go:nosplit +func (l *segmentList) InsertBefore(a, e *segment) { + aLinker := segmentElementMapper{}.linkerFor(a) + eLinker := segmentElementMapper{}.linkerFor(e) + + b := aLinker.Prev() + eLinker.SetNext(a) + eLinker.SetPrev(b) + aLinker.SetPrev(e) + + if b != nil { + segmentElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +// +//go:nosplit +func (l *segmentList) Remove(e *segment) { + linker := segmentElementMapper{}.linkerFor(e) + prev := linker.Prev() + next := linker.Next() + + if prev != nil { + segmentElementMapper{}.linkerFor(prev).SetNext(next) + } else if l.head == e { + l.head = next + } + + if next != nil { + segmentElementMapper{}.linkerFor(next).SetPrev(prev) + } else if l.tail == e { + l.tail = prev + } + + linker.SetNext(nil) + linker.SetPrev(nil) +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type segmentEntry struct { + next *segment + prev *segment +} + +// Next returns the entry that follows e in the list. +// +//go:nosplit +func (e *segmentEntry) Next() *segment { + return e.next +} + +// Prev returns the entry that precedes e in the list. +// +//go:nosplit +func (e *segmentEntry) Prev() *segment { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +// +//go:nosplit +func (e *segmentEntry) SetNext(elem *segment) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +// +//go:nosplit +func (e *segmentEntry) SetPrev(elem *segment) { + e.prev = elem +} diff --git a/pkg/tcpip/transport/tcp/tcp_state_autogen.go b/pkg/tcpip/transport/tcp/tcp_state_autogen.go new file mode 100644 index 000000000..13061d2b1 --- /dev/null +++ b/pkg/tcpip/transport/tcp/tcp_state_autogen.go @@ -0,0 +1,901 @@ +// automatically generated by stateify. + +package tcp + +import ( + "gvisor.dev/gvisor/pkg/state" + "gvisor.dev/gvisor/pkg/tcpip/buffer" +) + +func (c *cubicState) StateTypeName() string { + return "pkg/tcpip/transport/tcp.cubicState" +} + +func (c *cubicState) StateFields() []string { + return []string{ + "TCPCubicState", + "numCongestionEvents", + "s", + } +} + +func (c *cubicState) beforeSave() {} + +// +checklocksignore +func (c *cubicState) StateSave(stateSinkObject state.Sink) { + c.beforeSave() + stateSinkObject.Save(0, &c.TCPCubicState) + stateSinkObject.Save(1, &c.numCongestionEvents) + stateSinkObject.Save(2, &c.s) +} + +func (c *cubicState) afterLoad() {} + +// +checklocksignore +func (c *cubicState) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &c.TCPCubicState) + stateSourceObject.Load(1, &c.numCongestionEvents) + stateSourceObject.Load(2, &c.s) +} + +func (s *SACKInfo) StateTypeName() string { + return "pkg/tcpip/transport/tcp.SACKInfo" +} + +func (s *SACKInfo) StateFields() []string { + return []string{ + "Blocks", + "NumBlocks", + } +} + +func (s *SACKInfo) beforeSave() {} + +// +checklocksignore +func (s *SACKInfo) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + stateSinkObject.Save(0, &s.Blocks) + stateSinkObject.Save(1, &s.NumBlocks) +} + +func (s *SACKInfo) afterLoad() {} + +// +checklocksignore +func (s *SACKInfo) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.Blocks) + stateSourceObject.Load(1, &s.NumBlocks) +} + +func (s *sndQueueInfo) StateTypeName() string { + return "pkg/tcpip/transport/tcp.sndQueueInfo" +} + +func (s *sndQueueInfo) StateFields() []string { + return []string{ + "TCPSndBufState", + } +} + +func (s *sndQueueInfo) beforeSave() {} + +// +checklocksignore +func (s *sndQueueInfo) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + stateSinkObject.Save(0, &s.TCPSndBufState) +} + +func (s *sndQueueInfo) afterLoad() {} + +// +checklocksignore +func (s *sndQueueInfo) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.TCPSndBufState) +} + +func (r *rcvQueueInfo) StateTypeName() string { + return "pkg/tcpip/transport/tcp.rcvQueueInfo" +} + +func (r *rcvQueueInfo) StateFields() []string { + return []string{ + "TCPRcvBufState", + "rcvQueue", + } +} + +func (r *rcvQueueInfo) beforeSave() {} + +// +checklocksignore +func (r *rcvQueueInfo) StateSave(stateSinkObject state.Sink) { + r.beforeSave() + stateSinkObject.Save(0, &r.TCPRcvBufState) + stateSinkObject.Save(1, &r.rcvQueue) +} + +func (r *rcvQueueInfo) afterLoad() {} + +// +checklocksignore +func (r *rcvQueueInfo) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &r.TCPRcvBufState) + stateSourceObject.LoadWait(1, &r.rcvQueue) +} + +func (a *accepted) StateTypeName() string { + return "pkg/tcpip/transport/tcp.accepted" +} + +func (a *accepted) StateFields() []string { + return []string{ + "endpoints", + "cap", + } +} + +func (a *accepted) beforeSave() {} + +// +checklocksignore +func (a *accepted) StateSave(stateSinkObject state.Sink) { + a.beforeSave() + var endpointsValue []*endpoint + endpointsValue = a.saveEndpoints() + stateSinkObject.SaveValue(0, endpointsValue) + stateSinkObject.Save(1, &a.cap) +} + +func (a *accepted) afterLoad() {} + +// +checklocksignore +func (a *accepted) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(1, &a.cap) + stateSourceObject.LoadValue(0, new([]*endpoint), func(y interface{}) { a.loadEndpoints(y.([]*endpoint)) }) +} + +func (e *endpoint) StateTypeName() string { + return "pkg/tcpip/transport/tcp.endpoint" +} + +func (e *endpoint) StateFields() []string { + return []string{ + "TCPEndpointStateInner", + "TransportEndpointInfo", + "DefaultSocketOptionsHandler", + "waiterQueue", + "uniqueID", + "hardError", + "lastError", + "rcvQueueInfo", + "rcvMemUsed", + "ownedByUser", + "state", + "boundNICID", + "ttl", + "isConnectNotified", + "portFlags", + "boundBindToDevice", + "boundPortFlags", + "boundDest", + "effectiveNetProtos", + "workerRunning", + "workerCleanup", + "recentTSTime", + "shutdownFlags", + "tcpRecovery", + "sack", + "delay", + "scoreboard", + "segmentQueue", + "synRcvdCount", + "userMSS", + "maxSynRetries", + "windowClamp", + "sndQueueInfo", + "cc", + "keepalive", + "userTimeout", + "deferAccept", + "accepted", + "rcv", + "snd", + "connectingAddress", + "amss", + "sendTOS", + "gso", + "tcpLingerTimeout", + "closed", + "txHash", + "owner", + "ops", + "lastOutOfWindowAckTime", + } +} + +// +checklocksignore +func (e *endpoint) StateSave(stateSinkObject state.Sink) { + e.beforeSave() + var stateValue EndpointState + stateValue = e.saveState() + stateSinkObject.SaveValue(10, stateValue) + stateSinkObject.Save(0, &e.TCPEndpointStateInner) + stateSinkObject.Save(1, &e.TransportEndpointInfo) + stateSinkObject.Save(2, &e.DefaultSocketOptionsHandler) + stateSinkObject.Save(3, &e.waiterQueue) + stateSinkObject.Save(4, &e.uniqueID) + stateSinkObject.Save(5, &e.hardError) + stateSinkObject.Save(6, &e.lastError) + stateSinkObject.Save(7, &e.rcvQueueInfo) + stateSinkObject.Save(8, &e.rcvMemUsed) + stateSinkObject.Save(9, &e.ownedByUser) + stateSinkObject.Save(11, &e.boundNICID) + stateSinkObject.Save(12, &e.ttl) + stateSinkObject.Save(13, &e.isConnectNotified) + stateSinkObject.Save(14, &e.portFlags) + stateSinkObject.Save(15, &e.boundBindToDevice) + stateSinkObject.Save(16, &e.boundPortFlags) + stateSinkObject.Save(17, &e.boundDest) + stateSinkObject.Save(18, &e.effectiveNetProtos) + stateSinkObject.Save(19, &e.workerRunning) + stateSinkObject.Save(20, &e.workerCleanup) + stateSinkObject.Save(21, &e.recentTSTime) + stateSinkObject.Save(22, &e.shutdownFlags) + stateSinkObject.Save(23, &e.tcpRecovery) + stateSinkObject.Save(24, &e.sack) + stateSinkObject.Save(25, &e.delay) + stateSinkObject.Save(26, &e.scoreboard) + stateSinkObject.Save(27, &e.segmentQueue) + stateSinkObject.Save(28, &e.synRcvdCount) + stateSinkObject.Save(29, &e.userMSS) + stateSinkObject.Save(30, &e.maxSynRetries) + stateSinkObject.Save(31, &e.windowClamp) + stateSinkObject.Save(32, &e.sndQueueInfo) + stateSinkObject.Save(33, &e.cc) + stateSinkObject.Save(34, &e.keepalive) + stateSinkObject.Save(35, &e.userTimeout) + stateSinkObject.Save(36, &e.deferAccept) + stateSinkObject.Save(37, &e.accepted) + stateSinkObject.Save(38, &e.rcv) + stateSinkObject.Save(39, &e.snd) + stateSinkObject.Save(40, &e.connectingAddress) + stateSinkObject.Save(41, &e.amss) + stateSinkObject.Save(42, &e.sendTOS) + stateSinkObject.Save(43, &e.gso) + stateSinkObject.Save(44, &e.tcpLingerTimeout) + stateSinkObject.Save(45, &e.closed) + stateSinkObject.Save(46, &e.txHash) + stateSinkObject.Save(47, &e.owner) + stateSinkObject.Save(48, &e.ops) + stateSinkObject.Save(49, &e.lastOutOfWindowAckTime) +} + +// +checklocksignore +func (e *endpoint) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &e.TCPEndpointStateInner) + stateSourceObject.Load(1, &e.TransportEndpointInfo) + stateSourceObject.Load(2, &e.DefaultSocketOptionsHandler) + stateSourceObject.LoadWait(3, &e.waiterQueue) + stateSourceObject.Load(4, &e.uniqueID) + stateSourceObject.Load(5, &e.hardError) + stateSourceObject.Load(6, &e.lastError) + stateSourceObject.Load(7, &e.rcvQueueInfo) + stateSourceObject.Load(8, &e.rcvMemUsed) + stateSourceObject.Load(9, &e.ownedByUser) + stateSourceObject.Load(11, &e.boundNICID) + stateSourceObject.Load(12, &e.ttl) + stateSourceObject.Load(13, &e.isConnectNotified) + stateSourceObject.Load(14, &e.portFlags) + stateSourceObject.Load(15, &e.boundBindToDevice) + stateSourceObject.Load(16, &e.boundPortFlags) + stateSourceObject.Load(17, &e.boundDest) + stateSourceObject.Load(18, &e.effectiveNetProtos) + stateSourceObject.Load(19, &e.workerRunning) + stateSourceObject.Load(20, &e.workerCleanup) + stateSourceObject.Load(21, &e.recentTSTime) + stateSourceObject.Load(22, &e.shutdownFlags) + stateSourceObject.Load(23, &e.tcpRecovery) + stateSourceObject.Load(24, &e.sack) + stateSourceObject.Load(25, &e.delay) + stateSourceObject.Load(26, &e.scoreboard) + stateSourceObject.LoadWait(27, &e.segmentQueue) + stateSourceObject.Load(28, &e.synRcvdCount) + stateSourceObject.Load(29, &e.userMSS) + stateSourceObject.Load(30, &e.maxSynRetries) + stateSourceObject.Load(31, &e.windowClamp) + stateSourceObject.Load(32, &e.sndQueueInfo) + stateSourceObject.Load(33, &e.cc) + stateSourceObject.Load(34, &e.keepalive) + stateSourceObject.Load(35, &e.userTimeout) + stateSourceObject.Load(36, &e.deferAccept) + stateSourceObject.Load(37, &e.accepted) + stateSourceObject.LoadWait(38, &e.rcv) + stateSourceObject.LoadWait(39, &e.snd) + stateSourceObject.Load(40, &e.connectingAddress) + stateSourceObject.Load(41, &e.amss) + stateSourceObject.Load(42, &e.sendTOS) + stateSourceObject.Load(43, &e.gso) + stateSourceObject.Load(44, &e.tcpLingerTimeout) + stateSourceObject.Load(45, &e.closed) + stateSourceObject.Load(46, &e.txHash) + stateSourceObject.Load(47, &e.owner) + stateSourceObject.Load(48, &e.ops) + stateSourceObject.Load(49, &e.lastOutOfWindowAckTime) + stateSourceObject.LoadValue(10, new(EndpointState), func(y interface{}) { e.loadState(y.(EndpointState)) }) + stateSourceObject.AfterLoad(e.afterLoad) +} + +func (k *keepalive) StateTypeName() string { + return "pkg/tcpip/transport/tcp.keepalive" +} + +func (k *keepalive) StateFields() []string { + return []string{ + "idle", + "interval", + "count", + "unacked", + } +} + +func (k *keepalive) beforeSave() {} + +// +checklocksignore +func (k *keepalive) StateSave(stateSinkObject state.Sink) { + k.beforeSave() + stateSinkObject.Save(0, &k.idle) + stateSinkObject.Save(1, &k.interval) + stateSinkObject.Save(2, &k.count) + stateSinkObject.Save(3, &k.unacked) +} + +func (k *keepalive) afterLoad() {} + +// +checklocksignore +func (k *keepalive) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &k.idle) + stateSourceObject.Load(1, &k.interval) + stateSourceObject.Load(2, &k.count) + stateSourceObject.Load(3, &k.unacked) +} + +func (rc *rackControl) StateTypeName() string { + return "pkg/tcpip/transport/tcp.rackControl" +} + +func (rc *rackControl) StateFields() []string { + return []string{ + "TCPRACKState", + "exitedRecovery", + "minRTT", + "tlpRxtOut", + "tlpHighRxt", + "snd", + } +} + +func (rc *rackControl) beforeSave() {} + +// +checklocksignore +func (rc *rackControl) StateSave(stateSinkObject state.Sink) { + rc.beforeSave() + stateSinkObject.Save(0, &rc.TCPRACKState) + stateSinkObject.Save(1, &rc.exitedRecovery) + stateSinkObject.Save(2, &rc.minRTT) + stateSinkObject.Save(3, &rc.tlpRxtOut) + stateSinkObject.Save(4, &rc.tlpHighRxt) + stateSinkObject.Save(5, &rc.snd) +} + +func (rc *rackControl) afterLoad() {} + +// +checklocksignore +func (rc *rackControl) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &rc.TCPRACKState) + stateSourceObject.Load(1, &rc.exitedRecovery) + stateSourceObject.Load(2, &rc.minRTT) + stateSourceObject.Load(3, &rc.tlpRxtOut) + stateSourceObject.Load(4, &rc.tlpHighRxt) + stateSourceObject.Load(5, &rc.snd) +} + +func (r *receiver) StateTypeName() string { + return "pkg/tcpip/transport/tcp.receiver" +} + +func (r *receiver) StateFields() []string { + return []string{ + "TCPReceiverState", + "ep", + "rcvWnd", + "rcvWUP", + "prevBufUsed", + "closed", + "pendingRcvdSegments", + "lastRcvdAckTime", + } +} + +func (r *receiver) beforeSave() {} + +// +checklocksignore +func (r *receiver) StateSave(stateSinkObject state.Sink) { + r.beforeSave() + stateSinkObject.Save(0, &r.TCPReceiverState) + stateSinkObject.Save(1, &r.ep) + stateSinkObject.Save(2, &r.rcvWnd) + stateSinkObject.Save(3, &r.rcvWUP) + stateSinkObject.Save(4, &r.prevBufUsed) + stateSinkObject.Save(5, &r.closed) + stateSinkObject.Save(6, &r.pendingRcvdSegments) + stateSinkObject.Save(7, &r.lastRcvdAckTime) +} + +func (r *receiver) afterLoad() {} + +// +checklocksignore +func (r *receiver) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &r.TCPReceiverState) + stateSourceObject.Load(1, &r.ep) + stateSourceObject.Load(2, &r.rcvWnd) + stateSourceObject.Load(3, &r.rcvWUP) + stateSourceObject.Load(4, &r.prevBufUsed) + stateSourceObject.Load(5, &r.closed) + stateSourceObject.Load(6, &r.pendingRcvdSegments) + stateSourceObject.Load(7, &r.lastRcvdAckTime) +} + +func (r *renoState) StateTypeName() string { + return "pkg/tcpip/transport/tcp.renoState" +} + +func (r *renoState) StateFields() []string { + return []string{ + "s", + } +} + +func (r *renoState) beforeSave() {} + +// +checklocksignore +func (r *renoState) StateSave(stateSinkObject state.Sink) { + r.beforeSave() + stateSinkObject.Save(0, &r.s) +} + +func (r *renoState) afterLoad() {} + +// +checklocksignore +func (r *renoState) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &r.s) +} + +func (rr *renoRecovery) StateTypeName() string { + return "pkg/tcpip/transport/tcp.renoRecovery" +} + +func (rr *renoRecovery) StateFields() []string { + return []string{ + "s", + } +} + +func (rr *renoRecovery) beforeSave() {} + +// +checklocksignore +func (rr *renoRecovery) StateSave(stateSinkObject state.Sink) { + rr.beforeSave() + stateSinkObject.Save(0, &rr.s) +} + +func (rr *renoRecovery) afterLoad() {} + +// +checklocksignore +func (rr *renoRecovery) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &rr.s) +} + +func (sr *sackRecovery) StateTypeName() string { + return "pkg/tcpip/transport/tcp.sackRecovery" +} + +func (sr *sackRecovery) StateFields() []string { + return []string{ + "s", + } +} + +func (sr *sackRecovery) beforeSave() {} + +// +checklocksignore +func (sr *sackRecovery) StateSave(stateSinkObject state.Sink) { + sr.beforeSave() + stateSinkObject.Save(0, &sr.s) +} + +func (sr *sackRecovery) afterLoad() {} + +// +checklocksignore +func (sr *sackRecovery) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &sr.s) +} + +func (s *SACKScoreboard) StateTypeName() string { + return "pkg/tcpip/transport/tcp.SACKScoreboard" +} + +func (s *SACKScoreboard) StateFields() []string { + return []string{ + "smss", + "maxSACKED", + } +} + +func (s *SACKScoreboard) beforeSave() {} + +// +checklocksignore +func (s *SACKScoreboard) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + stateSinkObject.Save(0, &s.smss) + stateSinkObject.Save(1, &s.maxSACKED) +} + +func (s *SACKScoreboard) afterLoad() {} + +// +checklocksignore +func (s *SACKScoreboard) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.smss) + stateSourceObject.Load(1, &s.maxSACKED) +} + +func (s *segment) StateTypeName() string { + return "pkg/tcpip/transport/tcp.segment" +} + +func (s *segment) StateFields() []string { + return []string{ + "segmentEntry", + "refCnt", + "ep", + "qFlags", + "srcAddr", + "dstAddr", + "netProto", + "nicID", + "data", + "hdr", + "sequenceNumber", + "ackNumber", + "flags", + "window", + "csum", + "csumValid", + "parsedOptions", + "options", + "hasNewSACKInfo", + "rcvdTime", + "xmitTime", + "xmitCount", + "acked", + "dataMemSize", + "lost", + } +} + +func (s *segment) beforeSave() {} + +// +checklocksignore +func (s *segment) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + var dataValue buffer.VectorisedView + dataValue = s.saveData() + stateSinkObject.SaveValue(8, dataValue) + var optionsValue []byte + optionsValue = s.saveOptions() + stateSinkObject.SaveValue(17, optionsValue) + stateSinkObject.Save(0, &s.segmentEntry) + stateSinkObject.Save(1, &s.refCnt) + stateSinkObject.Save(2, &s.ep) + stateSinkObject.Save(3, &s.qFlags) + stateSinkObject.Save(4, &s.srcAddr) + stateSinkObject.Save(5, &s.dstAddr) + stateSinkObject.Save(6, &s.netProto) + stateSinkObject.Save(7, &s.nicID) + stateSinkObject.Save(9, &s.hdr) + stateSinkObject.Save(10, &s.sequenceNumber) + stateSinkObject.Save(11, &s.ackNumber) + stateSinkObject.Save(12, &s.flags) + stateSinkObject.Save(13, &s.window) + stateSinkObject.Save(14, &s.csum) + stateSinkObject.Save(15, &s.csumValid) + stateSinkObject.Save(16, &s.parsedOptions) + stateSinkObject.Save(18, &s.hasNewSACKInfo) + stateSinkObject.Save(19, &s.rcvdTime) + stateSinkObject.Save(20, &s.xmitTime) + stateSinkObject.Save(21, &s.xmitCount) + stateSinkObject.Save(22, &s.acked) + stateSinkObject.Save(23, &s.dataMemSize) + stateSinkObject.Save(24, &s.lost) +} + +func (s *segment) afterLoad() {} + +// +checklocksignore +func (s *segment) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.segmentEntry) + stateSourceObject.Load(1, &s.refCnt) + stateSourceObject.Load(2, &s.ep) + stateSourceObject.Load(3, &s.qFlags) + stateSourceObject.Load(4, &s.srcAddr) + stateSourceObject.Load(5, &s.dstAddr) + stateSourceObject.Load(6, &s.netProto) + stateSourceObject.Load(7, &s.nicID) + stateSourceObject.Load(9, &s.hdr) + stateSourceObject.Load(10, &s.sequenceNumber) + stateSourceObject.Load(11, &s.ackNumber) + stateSourceObject.Load(12, &s.flags) + stateSourceObject.Load(13, &s.window) + stateSourceObject.Load(14, &s.csum) + stateSourceObject.Load(15, &s.csumValid) + stateSourceObject.Load(16, &s.parsedOptions) + stateSourceObject.Load(18, &s.hasNewSACKInfo) + stateSourceObject.Load(19, &s.rcvdTime) + stateSourceObject.Load(20, &s.xmitTime) + stateSourceObject.Load(21, &s.xmitCount) + stateSourceObject.Load(22, &s.acked) + stateSourceObject.Load(23, &s.dataMemSize) + stateSourceObject.Load(24, &s.lost) + stateSourceObject.LoadValue(8, new(buffer.VectorisedView), func(y interface{}) { s.loadData(y.(buffer.VectorisedView)) }) + stateSourceObject.LoadValue(17, new([]byte), func(y interface{}) { s.loadOptions(y.([]byte)) }) +} + +func (q *segmentQueue) StateTypeName() string { + return "pkg/tcpip/transport/tcp.segmentQueue" +} + +func (q *segmentQueue) StateFields() []string { + return []string{ + "list", + "ep", + "frozen", + } +} + +func (q *segmentQueue) beforeSave() {} + +// +checklocksignore +func (q *segmentQueue) StateSave(stateSinkObject state.Sink) { + q.beforeSave() + stateSinkObject.Save(0, &q.list) + stateSinkObject.Save(1, &q.ep) + stateSinkObject.Save(2, &q.frozen) +} + +func (q *segmentQueue) afterLoad() {} + +// +checklocksignore +func (q *segmentQueue) StateLoad(stateSourceObject state.Source) { + stateSourceObject.LoadWait(0, &q.list) + stateSourceObject.Load(1, &q.ep) + stateSourceObject.Load(2, &q.frozen) +} + +func (s *sender) StateTypeName() string { + return "pkg/tcpip/transport/tcp.sender" +} + +func (s *sender) StateFields() []string { + return []string{ + "TCPSenderState", + "ep", + "lr", + "firstRetransmittedSegXmitTime", + "writeNext", + "writeList", + "rtt", + "minRTO", + "maxRTO", + "maxRetries", + "gso", + "state", + "cc", + "rc", + } +} + +func (s *sender) beforeSave() {} + +// +checklocksignore +func (s *sender) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + stateSinkObject.Save(0, &s.TCPSenderState) + stateSinkObject.Save(1, &s.ep) + stateSinkObject.Save(2, &s.lr) + stateSinkObject.Save(3, &s.firstRetransmittedSegXmitTime) + stateSinkObject.Save(4, &s.writeNext) + stateSinkObject.Save(5, &s.writeList) + stateSinkObject.Save(6, &s.rtt) + stateSinkObject.Save(7, &s.minRTO) + stateSinkObject.Save(8, &s.maxRTO) + stateSinkObject.Save(9, &s.maxRetries) + stateSinkObject.Save(10, &s.gso) + stateSinkObject.Save(11, &s.state) + stateSinkObject.Save(12, &s.cc) + stateSinkObject.Save(13, &s.rc) +} + +func (s *sender) afterLoad() {} + +// +checklocksignore +func (s *sender) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.TCPSenderState) + stateSourceObject.Load(1, &s.ep) + stateSourceObject.Load(2, &s.lr) + stateSourceObject.Load(3, &s.firstRetransmittedSegXmitTime) + stateSourceObject.Load(4, &s.writeNext) + stateSourceObject.Load(5, &s.writeList) + stateSourceObject.Load(6, &s.rtt) + stateSourceObject.Load(7, &s.minRTO) + stateSourceObject.Load(8, &s.maxRTO) + stateSourceObject.Load(9, &s.maxRetries) + stateSourceObject.Load(10, &s.gso) + stateSourceObject.Load(11, &s.state) + stateSourceObject.Load(12, &s.cc) + stateSourceObject.Load(13, &s.rc) +} + +func (r *rtt) StateTypeName() string { + return "pkg/tcpip/transport/tcp.rtt" +} + +func (r *rtt) StateFields() []string { + return []string{ + "TCPRTTState", + } +} + +func (r *rtt) beforeSave() {} + +// +checklocksignore +func (r *rtt) StateSave(stateSinkObject state.Sink) { + r.beforeSave() + stateSinkObject.Save(0, &r.TCPRTTState) +} + +func (r *rtt) afterLoad() {} + +// +checklocksignore +func (r *rtt) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &r.TCPRTTState) +} + +func (l *endpointList) StateTypeName() string { + return "pkg/tcpip/transport/tcp.endpointList" +} + +func (l *endpointList) StateFields() []string { + return []string{ + "head", + "tail", + } +} + +func (l *endpointList) beforeSave() {} + +// +checklocksignore +func (l *endpointList) StateSave(stateSinkObject state.Sink) { + l.beforeSave() + stateSinkObject.Save(0, &l.head) + stateSinkObject.Save(1, &l.tail) +} + +func (l *endpointList) afterLoad() {} + +// +checklocksignore +func (l *endpointList) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &l.head) + stateSourceObject.Load(1, &l.tail) +} + +func (e *endpointEntry) StateTypeName() string { + return "pkg/tcpip/transport/tcp.endpointEntry" +} + +func (e *endpointEntry) StateFields() []string { + return []string{ + "next", + "prev", + } +} + +func (e *endpointEntry) beforeSave() {} + +// +checklocksignore +func (e *endpointEntry) StateSave(stateSinkObject state.Sink) { + e.beforeSave() + stateSinkObject.Save(0, &e.next) + stateSinkObject.Save(1, &e.prev) +} + +func (e *endpointEntry) afterLoad() {} + +// +checklocksignore +func (e *endpointEntry) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &e.next) + stateSourceObject.Load(1, &e.prev) +} + +func (l *segmentList) StateTypeName() string { + return "pkg/tcpip/transport/tcp.segmentList" +} + +func (l *segmentList) StateFields() []string { + return []string{ + "head", + "tail", + } +} + +func (l *segmentList) beforeSave() {} + +// +checklocksignore +func (l *segmentList) StateSave(stateSinkObject state.Sink) { + l.beforeSave() + stateSinkObject.Save(0, &l.head) + stateSinkObject.Save(1, &l.tail) +} + +func (l *segmentList) afterLoad() {} + +// +checklocksignore +func (l *segmentList) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &l.head) + stateSourceObject.Load(1, &l.tail) +} + +func (e *segmentEntry) StateTypeName() string { + return "pkg/tcpip/transport/tcp.segmentEntry" +} + +func (e *segmentEntry) StateFields() []string { + return []string{ + "next", + "prev", + } +} + +func (e *segmentEntry) beforeSave() {} + +// +checklocksignore +func (e *segmentEntry) StateSave(stateSinkObject state.Sink) { + e.beforeSave() + stateSinkObject.Save(0, &e.next) + stateSinkObject.Save(1, &e.prev) +} + +func (e *segmentEntry) afterLoad() {} + +// +checklocksignore +func (e *segmentEntry) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &e.next) + stateSourceObject.Load(1, &e.prev) +} + +func init() { + state.Register((*cubicState)(nil)) + state.Register((*SACKInfo)(nil)) + state.Register((*sndQueueInfo)(nil)) + state.Register((*rcvQueueInfo)(nil)) + state.Register((*accepted)(nil)) + state.Register((*endpoint)(nil)) + state.Register((*keepalive)(nil)) + state.Register((*rackControl)(nil)) + state.Register((*receiver)(nil)) + state.Register((*renoState)(nil)) + state.Register((*renoRecovery)(nil)) + state.Register((*sackRecovery)(nil)) + state.Register((*SACKScoreboard)(nil)) + state.Register((*segment)(nil)) + state.Register((*segmentQueue)(nil)) + state.Register((*sender)(nil)) + state.Register((*rtt)(nil)) + state.Register((*endpointList)(nil)) + state.Register((*endpointEntry)(nil)) + state.Register((*segmentList)(nil)) + state.Register((*segmentEntry)(nil)) +} diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go deleted file mode 100644 index db6b0955a..000000000 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ /dev/null @@ -1,8058 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tcp_test - -import ( - "bytes" - "fmt" - "io/ioutil" - "math" - "strings" - "testing" - "time" - - "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/rand" - "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/checker" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/loopback" - "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" - "gvisor.dev/gvisor/pkg/tcpip/seqnum" - "gvisor.dev/gvisor/pkg/tcpip/stack" - tcpiptestutil "gvisor.dev/gvisor/pkg/tcpip/testutil" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context" - "gvisor.dev/gvisor/pkg/test/testutil" - "gvisor.dev/gvisor/pkg/waiter" -) - -// endpointTester provides helper functions to test a tcpip.Endpoint. -type endpointTester struct { - ep tcpip.Endpoint -} - -// CheckReadError issues a read to the endpoint and checking for an error. -func (e *endpointTester) CheckReadError(t *testing.T, want tcpip.Error) { - t.Helper() - res, got := e.ep.Read(ioutil.Discard, tcpip.ReadOptions{}) - if got != want { - t.Fatalf("ep.Read = %s, want %s", got, want) - } - if diff := cmp.Diff(tcpip.ReadResult{}, res); diff != "" { - t.Errorf("ep.Read: unexpected non-zero result (-want +got):\n%s", diff) - } -} - -// CheckRead issues a read to the endpoint and checking for a success, returning -// the data read. -func (e *endpointTester) CheckRead(t *testing.T) []byte { - t.Helper() - var buf bytes.Buffer - res, err := e.ep.Read(&buf, tcpip.ReadOptions{}) - if err != nil { - t.Fatalf("ep.Read = _, %s; want _, nil", err) - } - if diff := cmp.Diff(tcpip.ReadResult{ - Count: buf.Len(), - Total: buf.Len(), - }, res, checker.IgnoreCmpPath("ControlMessages")); diff != "" { - t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff) - } - return buf.Bytes() -} - -// CheckReadFull reads from the endpoint for exactly count bytes. -func (e *endpointTester) CheckReadFull(t *testing.T, count int, notifyRead <-chan struct{}, timeout time.Duration) []byte { - t.Helper() - var buf bytes.Buffer - w := tcpip.LimitedWriter{ - W: &buf, - N: int64(count), - } - for w.N != 0 { - _, err := e.ep.Read(&w, tcpip.ReadOptions{}) - if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { - // Wait for receive to be notified. - select { - case <-notifyRead: - case <-time.After(timeout): - t.Fatalf("Timed out waiting for data to arrive") - } - continue - } else if err != nil { - t.Fatalf("ep.Read = _, %s; want _, nil", err) - } - } - return buf.Bytes() -} - -const ( - // defaultMTU is the MTU, in bytes, used throughout the tests, except - // where another value is explicitly used. It is chosen to match the MTU - // of loopback interfaces on linux systems. - defaultMTU = 65535 - - // defaultIPv4MSS is the MSS sent by the network stack in SYN/SYN-ACK for an - // IPv4 endpoint when the MTU is set to defaultMTU in the test. - defaultIPv4MSS = defaultMTU - header.IPv4MinimumSize - header.TCPMinimumSize -) - -func TestGiveUpConnect(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - var wq waiter.Queue - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - - // Register for notification, then start connection attempt. - waitEntry, notifyCh := waiter.NewChannelEntry(nil) - wq.EventRegister(&waitEntry, waiter.EventHUp) - defer wq.EventUnregister(&waitEntry) - - { - err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) - if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" { - t.Fatalf("ep.Connect(...) mismatch (-want +got):\n%s", d) - } - } - - // Close the connection, wait for completion. - ep.Close() - - // Wait for ep to become writable. - <-notifyCh - - // Call Connect again to retreive the handshake failure status - // and stats updates. - { - err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) - if d := cmp.Diff(&tcpip.ErrAborted{}, err); d != "" { - t.Fatalf("ep.Connect(...) mismatch (-want +got):\n%s", d) - } - } - - if got := c.Stack().Stats().TCP.FailedConnectionAttempts.Value(); got != 1 { - t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %d, want = 1", got) - } - - if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { - t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got) - } -} - -// Test for ICMP error handling without completing handshake. -func TestConnectICMPError(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - var wq waiter.Queue - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - - waitEntry, notifyCh := waiter.NewChannelEntry(nil) - wq.EventRegister(&waitEntry, waiter.EventHUp) - defer wq.EventUnregister(&waitEntry) - - { - err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) - if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" { - t.Fatalf("ep.Connect(...) mismatch (-want +got):\n%s", d) - } - } - - syn := c.GetPacket() - checker.IPv4(t, syn, checker.TCP(checker.TCPFlags(header.TCPFlagSyn))) - - wep := ep.(interface { - StopWork() - ResumeWork() - LastErrorLocked() tcpip.Error - }) - - // Stop the protocol loop, ensure that the ICMP error is processed and - // the last ICMP error is read before the loop is resumed. This sanity - // tests the handshake completion logic on ICMP errors. - wep.StopWork() - - c.SendICMPPacket(header.ICMPv4DstUnreachable, header.ICMPv4HostUnreachable, nil, syn, defaultMTU) - - for { - if err := wep.LastErrorLocked(); err != nil { - if d := cmp.Diff(&tcpip.ErrNoRoute{}, err); d != "" { - t.Errorf("ep.LastErrorLocked() mismatch (-want +got):\n%s", d) - } - break - } - time.Sleep(time.Millisecond) - } - - wep.ResumeWork() - - <-notifyCh - - // The stack would have unregistered the endpoint because of the ICMP error. - // Expect a RST for any subsequent packets sent to the endpoint. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: seqnum.Value(context.TestInitialSequenceNumber) + 1, - AckNum: c.IRS + 1, - }) - - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS+1)), - checker.TCPAckNum(0), - checker.TCPFlags(header.TCPFlagRst))) -} - -func TestConnectIncrementActiveConnection(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - stats := c.Stack().Stats() - want := stats.TCP.ActiveConnectionOpenings.Value() + 1 - - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - if got := stats.TCP.ActiveConnectionOpenings.Value(); got != want { - t.Errorf("got stats.TCP.ActtiveConnectionOpenings.Value() = %d, want = %d", got, want) - } -} - -func TestConnectDoesNotIncrementFailedConnectionAttempts(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - stats := c.Stack().Stats() - want := stats.TCP.FailedConnectionAttempts.Value() - - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - if got := stats.TCP.FailedConnectionAttempts.Value(); got != want { - t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %d, want = %d", got, want) - } - if got := c.EP.Stats().(*tcp.Stats).FailedConnectionAttempts.Value(); got != want { - t.Errorf("got EP stats.FailedConnectionAttempts = %d, want = %d", got, want) - } -} - -func TestActiveFailedConnectionAttemptIncrement(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - stats := c.Stack().Stats() - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - c.EP = ep - want := stats.TCP.FailedConnectionAttempts.Value() + 1 - - { - err := c.EP.Connect(tcpip.FullAddress{NIC: 2, Addr: context.TestAddr, Port: context.TestPort}) - if d := cmp.Diff(&tcpip.ErrNoRoute{}, err); d != "" { - t.Errorf("c.EP.Connect(...) mismatch (-want +got):\n%s", d) - } - } - - if got := stats.TCP.FailedConnectionAttempts.Value(); got != want { - t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %d, want = %d", got, want) - } - if got := c.EP.Stats().(*tcp.Stats).FailedConnectionAttempts.Value(); got != want { - t.Errorf("got EP stats FailedConnectionAttempts = %d, want = %d", got, want) - } -} - -func TestCloseWithoutConnect(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Create TCP endpoint. - var err tcpip.Error - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - - c.EP.Close() - - if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 { - t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got) - } -} - -func TestTCPSegmentsSentIncrement(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - stats := c.Stack().Stats() - // SYN and ACK - want := stats.TCP.SegmentsSent.Value() + 2 - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - - if got := stats.TCP.SegmentsSent.Value(); got != want { - t.Errorf("got stats.TCP.SegmentsSent.Value() = %d, want = %d", got, want) - } - if got := c.EP.Stats().(*tcp.Stats).SegmentsSent.Value(); got != want { - t.Errorf("got EP stats SegmentsSent.Value() = %d, want = %d", got, want) - } -} - -func TestTCPResetsSentIncrement(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - stats := c.Stack().Stats() - wq := &waiter.Queue{} - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - want := stats.TCP.SegmentsSent.Value() + 1 - - if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - - if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %s", err) - } - - // Send a SYN request. - iss := seqnum.Value(context.TestInitialSequenceNumber) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: iss, - }) - - // Receive the SYN-ACK reply. - b := c.GetPacket() - tcpHdr := header.TCP(header.IPv4(b).Payload()) - c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) - - ackHeaders := &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: iss + 1, - // If the AckNum is not the increment of the last sequence number, a RST - // segment is sent back in response. - AckNum: c.IRS + 2, - } - - // Send ACK. - c.SendPacket(nil, ackHeaders) - - c.GetPacket() - - metricPollFn := func() error { - if got := stats.TCP.ResetsSent.Value(); got != want { - return fmt.Errorf("got stats.TCP.ResetsSent.Value() = %d, want = %d", got, want) - } - return nil - } - if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil { - t.Error(err) - } -} - -// TestTCPResetsSentNoICMP confirms that we don't get an ICMP -// DstUnreachable packet when we try send a packet which is not part -// of an active session. -func TestTCPResetsSentNoICMP(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - stats := c.Stack().Stats() - - // Send a SYN request for a closed port. This should elicit an RST - // but NOT an ICMPv4 DstUnreachable packet. - iss := seqnum.Value(context.TestInitialSequenceNumber) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: iss, - }) - - // Receive whatever comes back. - b := c.GetPacket() - ipHdr := header.IPv4(b) - if got, want := ipHdr.Protocol(), uint8(header.TCPProtocolNumber); got != want { - t.Errorf("unexpected protocol, got = %d, want = %d", got, want) - } - - // Read outgoing ICMP stats and check no ICMP DstUnreachable was recorded. - sent := stats.ICMP.V4.PacketsSent - if got, want := sent.DstUnreachable.Value(), uint64(0); got != want { - t.Errorf("got ICMP DstUnreachable.Value() = %d, want = %d", got, want) - } -} - -// TestTCPResetSentForACKWhenNotUsingSynCookies checks that the stack generates -// a RST if an ACK is received on the listening socket for which there is no -// active handshake in progress and we are not using SYN cookies. -func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Set TCPLingerTimeout to 5 seconds so that sockets are marked closed - wq := &waiter.Queue{} - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - - if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %s", err) - } - - // Send a SYN request. - iss := seqnum.Value(context.TestInitialSequenceNumber) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: iss, - }) - - // Receive the SYN-ACK reply. - b := c.GetPacket() - tcpHdr := header.TCP(header.IPv4(b).Payload()) - c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) - - ackHeaders := &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: iss + 1, - AckNum: c.IRS + 1, - } - - // Send ACK. - c.SendPacket(nil, ackHeaders) - - // Try to accept the connection. - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.ReadableEvents) - defer wq.EventUnregister(&we) - - c.EP, _, err = ep.Accept(nil) - if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { - // Wait for connection to be established. - select { - case <-ch: - c.EP, _, err = ep.Accept(nil) - if err != nil { - t.Fatalf("Accept failed: %s", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - - // Lower stackwide TIME_WAIT timeout so that the reservations - // are released instantly on Close. - tcpTW := tcpip.TCPTimeWaitTimeoutOption(1 * time.Millisecond) - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &tcpTW); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, tcpTW, tcpTW, err) - } - - c.EP.Close() - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS+1)), - checker.TCPAckNum(uint32(iss)+1), - checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) - finHeaders := &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: iss + 1, - AckNum: c.IRS + 2, - } - - c.SendPacket(nil, finHeaders) - - // Get the ACK to the FIN we just sent. - c.GetPacket() - - // Since an active close was done we need to wait for a little more than - // tcpLingerTimeout for the port reservations to be released and the - // socket to move to a CLOSED state. - time.Sleep(20 * time.Millisecond) - - // Now resend the same ACK, this ACK should generate a RST as there - // should be no endpoint in SYN-RCVD state and we are not using - // syn-cookies yet. The reason we send the same ACK is we need a valid - // cookie(IRS) generated by the netstack without which the ACK will be - // rejected. - c.SendPacket(nil, ackHeaders) - - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS+1)), - checker.TCPAckNum(0), - checker.TCPFlags(header.TCPFlagRst))) -} - -func TestTCPResetsReceivedIncrement(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - stats := c.Stack().Stats() - want := stats.TCP.ResetsReceived.Value() + 1 - iss := seqnum.Value(context.TestInitialSequenceNumber) - rcvWnd := seqnum.Size(30000) - c.CreateConnected(iss, rcvWnd, -1 /* epRcvBuf */) - - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - SeqNum: iss.Add(1), - AckNum: c.IRS.Add(1), - RcvWnd: rcvWnd, - Flags: header.TCPFlagRst, - }) - - if got := stats.TCP.ResetsReceived.Value(); got != want { - t.Errorf("got stats.TCP.ResetsReceived.Value() = %d, want = %d", got, want) - } -} - -func TestTCPResetsDoNotGenerateResets(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - stats := c.Stack().Stats() - want := stats.TCP.ResetsReceived.Value() + 1 - iss := seqnum.Value(context.TestInitialSequenceNumber) - rcvWnd := seqnum.Size(30000) - c.CreateConnected(iss, rcvWnd, -1 /* epRcvBuf */) - - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - SeqNum: iss.Add(1), - AckNum: c.IRS.Add(1), - RcvWnd: rcvWnd, - Flags: header.TCPFlagRst, - }) - - if got := stats.TCP.ResetsReceived.Value(); got != want { - t.Errorf("got stats.TCP.ResetsReceived.Value() = %d, want = %d", got, want) - } - c.CheckNoPacketTimeout("got an unexpected packet", 100*time.Millisecond) -} - -func TestActiveHandshake(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) -} - -func TestNonBlockingClose(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - ep := c.EP - c.EP = nil - - // Close the endpoint and measure how long it takes. - t0 := time.Now() - ep.Close() - if diff := time.Now().Sub(t0); diff > 3*time.Second { - t.Fatalf("Took too long to close: %s", diff) - } -} - -func TestConnectResetAfterClose(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Set TCPLinger to 3 seconds so that sockets are marked closed - // after 3 second in FIN_WAIT2 state. - tcpLingerTimeout := 3 * time.Second - opt := tcpip.TCPLingerTimeoutOption(tcpLingerTimeout) - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) - } - - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - ep := c.EP - c.EP = nil - - // Close the endpoint, make sure we get a FIN segment, then acknowledge - // to complete closure of sender, but don't send our own FIN. - ep.Close() - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), - ), - ) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(2), - RcvWnd: 30000, - }) - - // Wait for the ep to give up waiting for a FIN. - time.Sleep(tcpLingerTimeout + 1*time.Second) - - // Now send an ACK and it should trigger a RST as the endpoint should - // not exist anymore. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(2), - RcvWnd: 30000, - }) - - for { - b := c.GetPacket() - tcpHdr := header.TCP(header.IPv4(b).Payload()) - if tcpHdr.Flags() == header.TCPFlagAck|header.TCPFlagFin { - // This is a retransmit of the FIN, ignore it. - continue - } - - checker.IPv4(t, b, - checker.TCP( - checker.DstPort(context.TestPort), - // RST is always generated with sndNxt which if the FIN - // has been sent will be 1 higher than the sequence number - // of the FIN itself. - checker.TCPSeqNum(uint32(c.IRS)+2), - checker.TCPAckNum(0), - checker.TCPFlags(header.TCPFlagRst), - ), - ) - break - } -} - -// TestCurrentConnectedIncrement tests increment of the current -// established and connected counters. -func TestCurrentConnectedIncrement(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Set TCPTimeWaitTimeout to 1 seconds so that sockets are marked closed - // after 1 second in TIME_WAIT state. - tcpTimeWaitTimeout := 1 * time.Second - opt := tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout) - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) - } - - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - ep := c.EP - c.EP = nil - - if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 1 { - t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 1", got) - } - gotConnected := c.Stack().Stats().TCP.CurrentConnected.Value() - if gotConnected != 1 { - t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 1", gotConnected) - } - - ep.Close() - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), - ), - ) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(2), - RcvWnd: 30000, - }) - - if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { - t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got) - } - if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != gotConnected { - t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = %d", got, gotConnected) - } - - // Ack and send FIN as well. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: iss, - AckNum: c.IRS.Add(2), - RcvWnd: 30000, - }) - - // Check that the stack acks the FIN. - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+2), - checker.TCPAckNum(uint32(iss)+1), - checker.TCPFlags(header.TCPFlagAck), - ), - ) - - // Wait for a little more than the TIME-WAIT duration for the socket to - // transition to CLOSED state. - time.Sleep(1200 * time.Millisecond) - - if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { - t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got) - } - if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 { - t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got) - } -} - -// TestClosingWithEnqueuedSegments tests handling of still enqueued segments -// when the endpoint transitions to StateClose. The in-flight segments would be -// re-enqueued to a any listening endpoint. -func TestClosingWithEnqueuedSegments(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - ep := c.EP - c.EP = nil - - if got, want := tcp.EndpointState(ep.State()), tcp.StateEstablished; got != want { - t.Errorf("unexpected endpoint state: want %d, got %d", want, got) - } - - // Send a FIN for ESTABLISHED --> CLOSED-WAIT - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagFin | header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - // Get the ACK for the FIN we sent. - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)+1), - checker.TCPFlags(header.TCPFlagAck), - ), - ) - - // Give the stack a few ms to transition the endpoint out of ESTABLISHED - // state. - time.Sleep(10 * time.Millisecond) - - if got, want := tcp.EndpointState(ep.State()), tcp.StateCloseWait; got != want { - t.Errorf("unexpected endpoint state: want %d, got %d", want, got) - } - - // Close the application endpoint for CLOSE_WAIT --> LAST_ACK - ep.Close() - - // Get the FIN - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)+1), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), - ), - ) - - if got, want := tcp.EndpointState(ep.State()), tcp.StateLastAck; got != want { - t.Errorf("unexpected endpoint state: want %s, got %s", want, got) - } - - // Pause the endpoint`s protocolMainLoop. - ep.(interface{ StopWork() }).StopWork() - - // Enqueue last ACK followed by an ACK matching the endpoint - // - // Send Last ACK for LAST_ACK --> CLOSED - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss.Add(1), - AckNum: c.IRS.Add(2), - RcvWnd: 30000, - }) - - // Send a packet with ACK set, this would generate RST when - // not using SYN cookies as in this test. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: iss.Add(2), - AckNum: c.IRS.Add(2), - RcvWnd: 30000, - }) - - // Unpause endpoint`s protocolMainLoop. - ep.(interface{ ResumeWork() }).ResumeWork() - - // Wait for the protocolMainLoop to resume and update state. - time.Sleep(10 * time.Millisecond) - - // Expect the endpoint to be closed. - if got, want := tcp.EndpointState(ep.State()), tcp.StateClose; got != want { - t.Errorf("unexpected endpoint state: want %s, got %s", want, got) - } - - if got := c.Stack().Stats().TCP.EstablishedClosed.Value(); got != 1 { - t.Errorf("got c.Stack().Stats().TCP.EstablishedClosed = %d, want = 1", got) - } - - if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { - t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got) - } - - // Check if the endpoint was moved to CLOSED and netstack a reset in - // response to the ACK packet that we sent after last-ACK. - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+2), - checker.TCPAckNum(0), - checker.TCPFlags(header.TCPFlagRst), - ), - ) -} - -func TestSimpleReceive(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.ReadableEvents) - defer c.WQ.EventUnregister(&we) - - ept := endpointTester{c.EP} - - data := []byte{1, 2, 3} - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - c.SendPacket(data, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - // Wait for receive to be notified. - select { - case <-ch: - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for data to arrive") - } - - // Receive data. - v := ept.CheckRead(t) - if !bytes.Equal(data, v) { - t.Fatalf("got data = %v, want = %v", v, data) - } - - // Check that ACK is received. - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)+uint32(len(data))), - checker.TCPFlags(header.TCPFlagAck), - ), - ) -} - -// TestUserSuppliedMSSOnConnect tests that the user supplied MSS is used when -// creating a new active TCP socket. It should be present in the sent TCP -// SYN segment. -func TestUserSuppliedMSSOnConnect(t *testing.T) { - const mtu = 5000 - - ips := []struct { - name string - createEP func(*context.Context) - connectAddr tcpip.Address - checker func(*testing.T, *context.Context, uint16, int) - maxMSS uint16 - }{ - { - name: "IPv4", - createEP: func(c *context.Context) { - c.Create(-1) - }, - connectAddr: context.TestAddr, - checker: func(t *testing.T, c *context.Context, mss uint16, ws int) { - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagSyn), - checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: ws}))) - }, - maxMSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize, - }, - { - name: "IPv6", - createEP: func(c *context.Context) { - c.CreateV6Endpoint(true) - }, - connectAddr: context.TestV6Addr, - checker: func(t *testing.T, c *context.Context, mss uint16, ws int) { - checker.IPv6(t, c.GetV6Packet(), checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagSyn), - checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: ws}))) - }, - maxMSS: mtu - header.IPv6MinimumSize - header.TCPMinimumSize, - }, - } - - for _, ip := range ips { - t.Run(ip.name, func(t *testing.T) { - tests := []struct { - name string - setMSS uint16 - expMSS uint16 - }{ - { - name: "EqualToMaxMSS", - setMSS: ip.maxMSS, - expMSS: ip.maxMSS, - }, - { - name: "LessThanMaxMSS", - setMSS: ip.maxMSS - 1, - expMSS: ip.maxMSS - 1, - }, - { - name: "GreaterThanMaxMSS", - setMSS: ip.maxMSS + 1, - expMSS: ip.maxMSS, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - c := context.New(t, mtu) - defer c.Cleanup() - - ip.createEP(c) - - // Set the MSS socket option. - if err := c.EP.SetSockOptInt(tcpip.MaxSegOption, int(test.setMSS)); err != nil { - t.Fatalf("SetSockOptInt(MaxSegOption, %d): %s", test.setMSS, err) - } - - // Get expected window size. - rcvBufSize := c.EP.SocketOptions().GetReceiveBufferSize() - ws := tcp.FindWndScale(seqnum.Size(rcvBufSize)) - - connectAddr := tcpip.FullAddress{Addr: ip.connectAddr, Port: context.TestPort} - { - err := c.EP.Connect(connectAddr) - if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" { - t.Fatalf("Connect(%+v) mismatch (-want +got):\n%s", connectAddr, d) - } - } - - // Receive SYN packet with our user supplied MSS. - ip.checker(t, c, test.expMSS, ws) - }) - } - }) - } -} - -// TestUserSuppliedMSSOnListenAccept tests that the user supplied MSS is used -// when completing the handshake for a new TCP connection from a TCP -// listening socket. It should be present in the sent TCP SYN-ACK segment. -func TestUserSuppliedMSSOnListenAccept(t *testing.T) { - const mtu = 5000 - - ips := []struct { - name string - createEP func(*context.Context) - sendPkt func(*context.Context, *context.Headers) - checker func(*testing.T, *context.Context, uint16, uint16) - maxMSS uint16 - }{ - { - name: "IPv4", - createEP: func(c *context.Context) { - c.Create(-1) - }, - sendPkt: func(c *context.Context, h *context.Headers) { - c.SendPacket(nil, h) - }, - checker: func(t *testing.T, c *context.Context, srcPort, mss uint16) { - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.DstPort(srcPort), - checker.TCPFlags(header.TCPFlagSyn|header.TCPFlagAck), - checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: -1}))) - }, - maxMSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize, - }, - { - name: "IPv6", - createEP: func(c *context.Context) { - c.CreateV6Endpoint(false) - }, - sendPkt: func(c *context.Context, h *context.Headers) { - c.SendV6Packet(nil, h) - }, - checker: func(t *testing.T, c *context.Context, srcPort, mss uint16) { - checker.IPv6(t, c.GetV6Packet(), checker.TCP( - checker.DstPort(srcPort), - checker.TCPFlags(header.TCPFlagSyn|header.TCPFlagAck), - checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: -1}))) - }, - maxMSS: mtu - header.IPv6MinimumSize - header.TCPMinimumSize, - }, - } - - for _, ip := range ips { - t.Run(ip.name, func(t *testing.T) { - tests := []struct { - name string - setMSS uint16 - expMSS uint16 - }{ - { - name: "EqualToMaxMSS", - setMSS: ip.maxMSS, - expMSS: ip.maxMSS, - }, - { - name: "LessThanMaxMSS", - setMSS: ip.maxMSS - 1, - expMSS: ip.maxMSS - 1, - }, - { - name: "GreaterThanMaxMSS", - setMSS: ip.maxMSS + 1, - expMSS: ip.maxMSS, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - c := context.New(t, mtu) - defer c.Cleanup() - - ip.createEP(c) - - if err := c.EP.SetSockOptInt(tcpip.MaxSegOption, int(test.setMSS)); err != nil { - t.Fatalf("SetSockOptInt(MaxSegOption, %d): %s", test.setMSS, err) - } - - bindAddr := tcpip.FullAddress{Port: context.StackPort} - if err := c.EP.Bind(bindAddr); err != nil { - t.Fatalf("Bind(%+v): %s:", bindAddr, err) - } - - backlog := 5 - // Keep the number of client requests twice to the backlog - // such that half of the connections do not use syncookies - // and the other half does. - clientConnects := backlog * 2 - - if err := c.EP.Listen(backlog); err != nil { - t.Fatalf("Listen(%d): %s:", backlog, err) - } - - for i := 0; i < clientConnects; i++ { - // Send a SYN requests. - iss := seqnum.Value(i) - srcPort := context.TestPort + uint16(i) - ip.sendPkt(c, &context.Headers{ - SrcPort: srcPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: iss, - }) - - // Receive the SYN-ACK reply. - ip.checker(t, c, srcPort, test.expMSS) - } - }) - } - }) - } -} -func TestSendRstOnListenerRxSynAckV4(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.Create(-1) - - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatal("Bind failed:", err) - } - - if err := c.EP.Listen(10); err != nil { - t.Fatal("Listen failed:", err) - } - - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn | header.TCPFlagAck, - SeqNum: 100, - AckNum: 200, - }) - - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagRst), - checker.TCPSeqNum(200))) -} - -func TestSendRstOnListenerRxSynAckV6(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(true) - - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatal("Bind failed:", err) - } - - if err := c.EP.Listen(10); err != nil { - t.Fatal("Listen failed:", err) - } - - c.SendV6Packet(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn | header.TCPFlagAck, - SeqNum: 100, - AckNum: 200, - }) - - checker.IPv6(t, c.GetV6Packet(), checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagRst), - checker.TCPSeqNum(200))) -} - -// TestTCPAckBeforeAcceptV4 tests that once the 3-way handshake is complete, -// peers can send data and expect a response within a reasonable ammount of time -// without calling Accept on the listening endpoint first. -// -// This test uses IPv4. -func TestTCPAckBeforeAcceptV4(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.Create(-1) - - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatal("Bind failed:", err) - } - - if err := c.EP.Listen(10); err != nil { - t.Fatal("Listen failed:", err) - } - - irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */) - - // Send data before accepting the connection. - c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: irs + 1, - AckNum: iss + 1, - }) - - // Receive ACK for the data we sent. - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck), - checker.TCPSeqNum(uint32(iss+1)), - checker.TCPAckNum(uint32(irs+5)))) -} - -// TestTCPAckBeforeAcceptV6 tests that once the 3-way handshake is complete, -// peers can send data and expect a response within a reasonable ammount of time -// without calling Accept on the listening endpoint first. -// -// This test uses IPv6. -func TestTCPAckBeforeAcceptV6(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(true) - - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatal("Bind failed:", err) - } - - if err := c.EP.Listen(10); err != nil { - t.Fatal("Listen failed:", err) - } - - irs, iss := executeV6Handshake(t, c, context.TestPort, false /* synCookiesInUse */) - - // Send data before accepting the connection. - c.SendV6Packet([]byte{1, 2, 3, 4}, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: irs + 1, - AckNum: iss + 1, - }) - - // Receive ACK for the data we sent. - checker.IPv6(t, c.GetV6Packet(), checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck), - checker.TCPSeqNum(uint32(iss+1)), - checker.TCPAckNum(uint32(irs+5)))) -} - -func TestSendRstOnListenerRxAckV4(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.Create(-1 /* epRcvBuf */) - - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatal("Bind failed:", err) - } - - if err := c.EP.Listen(10 /* backlog */); err != nil { - t.Fatal("Listen failed:", err) - } - - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagFin | header.TCPFlagAck, - SeqNum: 100, - AckNum: 200, - }) - - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagRst), - checker.TCPSeqNum(200))) -} - -func TestSendRstOnListenerRxAckV6(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(true /* v6Only */) - - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatal("Bind failed:", err) - } - - if err := c.EP.Listen(10 /* backlog */); err != nil { - t.Fatal("Listen failed:", err) - } - - c.SendV6Packet(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagFin | header.TCPFlagAck, - SeqNum: 100, - AckNum: 200, - }) - - checker.IPv6(t, c.GetV6Packet(), checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagRst), - checker.TCPSeqNum(200))) -} - -// TestListenShutdown tests for the listening endpoint replying with RST -// on read shutdown. -func TestListenShutdown(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.Create(-1 /* epRcvBuf */) - - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatal("Bind failed:", err) - } - - if err := c.EP.Listen(1 /* backlog */); err != nil { - t.Fatal("Listen failed:", err) - } - - if err := c.EP.Shutdown(tcpip.ShutdownRead); err != nil { - t.Fatal("Shutdown failed:", err) - } - - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: 100, - AckNum: 200, - }) - - // Expect the listening endpoint to reset the connection. - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst), - )) -} - -var _ waiter.EntryCallback = (callback)(nil) - -type callback func(*waiter.Entry, waiter.EventMask) - -func (cb callback) Callback(entry *waiter.Entry, mask waiter.EventMask) { - cb(entry, mask) -} - -func TestListenerReadinessOnEvent(t *testing.T) { - s := stack.New(stack.Options{ - TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, - }) - { - ep := loopback.New() - if testing.Verbose() { - ep = sniffer.New(ep) - } - const id = 1 - if err := s.CreateNIC(id, ep); err != nil { - t.Fatalf("CreateNIC(%d, %T): %s", id, ep, err) - } - if err := s.AddAddress(id, ipv4.ProtocolNumber, context.StackAddr); err != nil { - t.Fatalf("AddAddress(%d, ipv4.ProtocolNumber, %s): %s", id, context.StackAddr, err) - } - s.SetRouteTable([]tcpip.Route{ - {Destination: header.IPv4EmptySubnet, NIC: id}, - }) - } - - var wq waiter.Queue - ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq) - if err != nil { - t.Fatalf("NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, _): %s", err) - } - defer ep.Close() - - if err := ep.Bind(tcpip.FullAddress{Addr: context.StackAddr}); err != nil { - t.Fatalf("Bind(%s): %s", context.StackAddr, err) - } - const backlog = 1 - if err := ep.Listen(backlog); err != nil { - t.Fatalf("Listen(%d): %s", backlog, err) - } - - address, err := ep.GetLocalAddress() - if err != nil { - t.Fatalf("GetLocalAddress(): %s", err) - } - - conn, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq) - if err != nil { - t.Fatalf("NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, _): %s", err) - } - defer conn.Close() - - events := make(chan waiter.EventMask) - // Scope `entry` to allow a binding of the same name below. - { - entry := waiter.Entry{Callback: callback(func(_ *waiter.Entry, mask waiter.EventMask) { - events <- ep.Readiness(mask) - })} - wq.EventRegister(&entry, waiter.EventIn) - defer wq.EventUnregister(&entry) - } - - entry, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&entry, waiter.EventOut) - defer wq.EventUnregister(&entry) - - switch err := conn.Connect(address).(type) { - case *tcpip.ErrConnectStarted: - default: - t.Fatalf("Connect(%#v): %v", address, err) - } - - // Read at least one event. - got := <-events - for { - select { - case event := <-events: - got |= event - continue - case <-ch: - if want := waiter.ReadableEvents; got != want { - t.Errorf("observed events = %b, want %b", got, want) - } - } - break - } -} - -// TestListenCloseWhileConnect tests for the listening endpoint to -// drain the accept-queue when closed. This should reset all of the -// pending connections that are waiting to be accepted. -func TestListenCloseWhileConnect(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.Create(-1 /* epRcvBuf */) - - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatal("Bind failed:", err) - } - - if err := c.EP.Listen(1 /* backlog */); err != nil { - t.Fatal("Listen failed:", err) - } - - waitEntry, notifyCh := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&waitEntry, waiter.ReadableEvents) - defer c.WQ.EventUnregister(&waitEntry) - - executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */) - // Wait for the new endpoint created because of handshake to be delivered - // to the listening endpoint's accept queue. - <-notifyCh - - // Close the listening endpoint. - c.EP.Close() - - // Expect the listening endpoint to reset the connection. - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst), - )) -} - -func TestTOSV4(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - c.EP = ep - - const tos = 0xC0 - if err := c.EP.SetSockOptInt(tcpip.IPv4TOSOption, tos); err != nil { - t.Errorf("SetSockOptInt(IPv4TOSOption, %d) failed: %s", tos, err) - } - - v, err := c.EP.GetSockOptInt(tcpip.IPv4TOSOption) - if err != nil { - t.Errorf("GetSockoptInt(IPv4TOSOption) failed: %s", err) - } - - if v != tos { - t.Errorf("got GetSockOptInt(IPv4TOSOption) = %d, want = %d", v, tos) - } - - testV4Connect(t, c, checker.TOS(tos, 0)) - - data := []byte{1, 2, 3} - var r bytes.Reader - r.Reset(data) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - // Check that data is received. - b := c.GetPacket() - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - checker.IPv4(t, b, - checker.PayloadLen(len(data)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)), // Acknum is initial sequence number + 1 - checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), - ), - checker.TOS(tos, 0), - ) - - if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) { - t.Errorf("got data = %x, want = %x", p, data) - } -} - -func TestTrafficClassV6(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(false) - - const tos = 0xC0 - if err := c.EP.SetSockOptInt(tcpip.IPv6TrafficClassOption, tos); err != nil { - t.Errorf("SetSockOpInt(IPv6TrafficClassOption, %d) failed: %s", tos, err) - } - - v, err := c.EP.GetSockOptInt(tcpip.IPv6TrafficClassOption) - if err != nil { - t.Fatalf("GetSockoptInt(IPv6TrafficClassOption) failed: %s", err) - } - - if v != tos { - t.Errorf("got GetSockOptInt(IPv6TrafficClassOption) = %d, want = %d", v, tos) - } - - // Test the connection request. - testV6Connect(t, c, checker.TOS(tos, 0)) - - data := []byte{1, 2, 3} - var r bytes.Reader - r.Reset(data) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - // Check that data is received. - b := c.GetV6Packet() - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - checker.IPv6(t, b, - checker.PayloadLen(len(data)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), - ), - checker.TOS(tos, 0), - ) - - if p := b[header.IPv6MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) { - t.Errorf("got data = %x, want = %x", p, data) - } -} - -func TestConnectBindToDevice(t *testing.T) { - for _, test := range []struct { - name string - device tcpip.NICID - want tcp.EndpointState - }{ - {"RightDevice", 1, tcp.StateEstablished}, - {"WrongDevice", 2, tcp.StateSynSent}, - {"AnyDevice", 0, tcp.StateEstablished}, - } { - t.Run(test.name, func(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.Create(-1) - if err := c.EP.SocketOptions().SetBindToDevice(int32(test.device)); err != nil { - t.Fatalf("c.EP.SetSockOpt(&%T(%d)): %s", test.device, test.device, err) - } - // Start connection attempt. - waitEntry, _ := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&waitEntry, waiter.WritableEvents) - defer c.WQ.EventUnregister(&waitEntry) - - err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) - if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" { - t.Fatalf("c.EP.Connect(...) mismatch (-want +got):\n%s", d) - } - - // Receive SYN packet. - b := c.GetPacket() - checker.IPv4(t, b, - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagSyn), - ), - ) - if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want { - t.Fatalf("unexpected endpoint state: want %s, got %s", want, got) - } - tcpHdr := header.TCP(header.IPv4(b).Payload()) - c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) - - iss := seqnum.Value(context.TestInitialSequenceNumber) - rcvWnd := seqnum.Size(30000) - c.SendPacket(nil, &context.Headers{ - SrcPort: tcpHdr.DestinationPort(), - DstPort: tcpHdr.SourcePort(), - Flags: header.TCPFlagSyn | header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(1), - RcvWnd: rcvWnd, - TCPOpts: nil, - }) - - c.GetPacket() - if got, want := tcp.EndpointState(c.EP.State()), test.want; got != want { - t.Fatalf("unexpected endpoint state: want %s, got %s", want, got) - } - }) - } -} - -func TestSynSent(t *testing.T) { - for _, test := range []struct { - name string - reset bool - }{ - {"RstOnSynSent", true}, - {"CloseOnSynSent", false}, - } { - t.Run(test.name, func(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Create an endpoint, don't handshake because we want to interfere with the - // handshake process. - c.Create(-1) - - // Start connection attempt. - waitEntry, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&waitEntry, waiter.EventHUp) - defer c.WQ.EventUnregister(&waitEntry) - - addr := tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort} - err := c.EP.Connect(addr) - if d := cmp.Diff(err, &tcpip.ErrConnectStarted{}); d != "" { - t.Fatalf("Connect(...) mismatch (-want +got):\n%s", d) - } - - // Receive SYN packet. - b := c.GetPacket() - checker.IPv4(t, b, - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagSyn), - ), - ) - - if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want { - t.Fatalf("got State() = %s, want %s", got, want) - } - tcpHdr := header.TCP(header.IPv4(b).Payload()) - c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) - - if test.reset { - // Send a packet with a proper ACK and a RST flag to cause the socket - // to error and close out. - iss := seqnum.Value(context.TestInitialSequenceNumber) - rcvWnd := seqnum.Size(30000) - c.SendPacket(nil, &context.Headers{ - SrcPort: tcpHdr.DestinationPort(), - DstPort: tcpHdr.SourcePort(), - Flags: header.TCPFlagRst | header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(1), - RcvWnd: rcvWnd, - TCPOpts: nil, - }) - } else { - c.EP.Close() - } - - // Wait for receive to be notified. - select { - case <-ch: - case <-time.After(3 * time.Second): - t.Fatal("timed out waiting for packet to arrive") - } - - ept := endpointTester{c.EP} - if test.reset { - ept.CheckReadError(t, &tcpip.ErrConnectionRefused{}) - } else { - ept.CheckReadError(t, &tcpip.ErrAborted{}) - } - - if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 { - t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got) - } - - // Due to the RST the endpoint should be in an error state. - if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want { - t.Fatalf("got State() = %s, want %s", got, want) - } - }) - } -} - -func TestOutOfOrderReceive(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.ReadableEvents) - defer c.WQ.EventUnregister(&we) - - ept := endpointTester{c.EP} - ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) - - // Send second half of data first, with seqnum 3 ahead of expected. - data := []byte{1, 2, 3, 4, 5, 6} - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - c.SendPacket(data[3:], &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss.Add(3), - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - // Check that we get an ACK specifying which seqnum is expected. - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlags(header.TCPFlagAck), - ), - ) - - // Wait 200ms and check that no data has been received. - time.Sleep(200 * time.Millisecond) - ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) - - // Send the first 3 bytes now. - c.SendPacket(data[:3], &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - // Receive data. - read := ept.CheckReadFull(t, 6, ch, 5*time.Second) - - // Check that we received the data in proper order. - if !bytes.Equal(data, read) { - t.Fatalf("got data = %v, want = %v", read, data) - } - - // Check that the whole data is acknowledged. - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)+uint32(len(data))), - checker.TCPFlags(header.TCPFlagAck), - ), - ) -} - -func TestOutOfOrderFlood(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - rcvBufSz := math.MaxUint16 - c.CreateConnected(context.TestInitialSequenceNumber, 30000, rcvBufSz) - - ept := endpointTester{c.EP} - ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) - - // Send 100 packets before the actual one that is expected. - data := []byte{1, 2, 3, 4, 5, 6} - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - for i := 0; i < 100; i++ { - c.SendPacket(data[3:], &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss.Add(6), - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlags(header.TCPFlagAck), - ), - ) - } - - // Send packet with seqnum as initial + 3. It must be discarded because the - // out-of-order buffer was filled by the previous packets. - c.SendPacket(data[3:], &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss.Add(3), - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlags(header.TCPFlagAck), - ), - ) - - // Now send the expected packet with initial sequence number. - c.SendPacket(data[:3], &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - // Check that only packet with initial sequence number is acknowledged. - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)+3), - checker.TCPFlags(header.TCPFlagAck), - ), - ) -} - -func TestRstOnCloseWithUnreadData(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.ReadableEvents) - defer c.WQ.EventUnregister(&we) - - ept := endpointTester{c.EP} - ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) - - data := []byte{1, 2, 3} - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - c.SendPacket(data, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - // Wait for receive to be notified. - select { - case <-ch: - case <-time.After(3 * time.Second): - t.Fatalf("Timed out waiting for data to arrive") - } - - // Check that ACK is received, this happens regardless of the read. - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)+uint32(len(data))), - checker.TCPFlags(header.TCPFlagAck), - ), - ) - - // Now that we know we have unread data, let's just close the connection - // and verify that netstack sends an RST rather than a FIN. - c.EP.Close() - - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst), - // We shouldn't consume a sequence number on RST. - checker.TCPSeqNum(uint32(c.IRS)+1), - )) - // The RST puts the endpoint into an error state. - if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want { - t.Errorf("unexpected endpoint state: want %s, got %s", want, got) - } - - // This final ACK should be ignored because an ACK on a reset doesn't mean - // anything. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss.Add(seqnum.Size(len(data))), - AckNum: c.IRS.Add(seqnum.Size(2)), - RcvWnd: 30000, - }) -} - -func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.ReadableEvents) - defer c.WQ.EventUnregister(&we) - - ept := endpointTester{c.EP} - ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) - - data := []byte{1, 2, 3} - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - c.SendPacket(data, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - // Wait for receive to be notified. - select { - case <-ch: - case <-time.After(3 * time.Second): - t.Fatalf("Timed out waiting for data to arrive") - } - - // Check that ACK is received, this happens regardless of the read. - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)+uint32(len(data))), - checker.TCPFlags(header.TCPFlagAck), - ), - ) - - // Cause a FIN to be generated. - c.EP.Shutdown(tcpip.ShutdownWrite) - - // Make sure we get the FIN but DON't ACK IT. - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), - checker.TCPSeqNum(uint32(c.IRS)+1), - )) - - if got, want := tcp.EndpointState(c.EP.State()), tcp.StateFinWait1; got != want { - t.Errorf("unexpected endpoint state: want %s, got %s", want, got) - } - - // Cause a RST to be generated by closing the read end now since we have - // unread data. - c.EP.Shutdown(tcpip.ShutdownRead) - - // Make sure we get the RST - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst), - // RST is always generated with sndNxt which if the FIN - // has been sent will be 1 higher than the sequence - // number of the FIN itself. - checker.TCPSeqNum(uint32(c.IRS)+2), - )) - // The RST puts the endpoint into an error state. - if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want { - t.Errorf("unexpected endpoint state: want %s, got %s", want, got) - } - - // The ACK to the FIN should now be rejected since the connection has been - // closed by a RST. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss.Add(seqnum.Size(len(data))), - AckNum: c.IRS.Add(seqnum.Size(2)), - RcvWnd: 30000, - }) -} - -func TestShutdownRead(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - - ept := endpointTester{c.EP} - ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) - - if err := c.EP.Shutdown(tcpip.ShutdownRead); err != nil { - t.Fatalf("Shutdown failed: %s", err) - } - - ept.CheckReadError(t, &tcpip.ErrClosedForReceive{}) - var want uint64 = 1 - if got := c.EP.Stats().(*tcp.Stats).ReadErrors.ReadClosed.Value(); got != want { - t.Fatalf("got EP stats Stats.ReadErrors.ReadClosed got %d want %d", got, want) - } -} - -func TestFullWindowReceive(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - const rcvBufSz = 10 - c.CreateConnected(context.TestInitialSequenceNumber, 30000, rcvBufSz) - - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.ReadableEvents) - defer c.WQ.EventUnregister(&we) - - ept := endpointTester{c.EP} - ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) - - // Fill up the window w/ tcp.SegOverheadFactor*rcvBufSz as netstack multiplies - // the provided buffer value by tcp.SegOverheadFactor to calculate the actual - // receive buffer size. - data := make([]byte, tcp.SegOverheadFactor*rcvBufSz) - for i := range data { - data[i] = byte(i % 255) - } - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - c.SendPacket(data, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - // Wait for receive to be notified. - select { - case <-ch: - case <-time.After(5 * time.Second): - t.Fatalf("Timed out waiting for data to arrive") - } - - // Check that data is acknowledged, and window goes to zero. - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)+uint32(len(data))), - checker.TCPFlags(header.TCPFlagAck), - checker.TCPWindow(0), - ), - ) - - // Receive data and check it. - v := ept.CheckRead(t) - if !bytes.Equal(data, v) { - t.Fatalf("got data = %v, want = %v", v, data) - } - - var want uint64 = 1 - if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.ZeroRcvWindowState.Value(); got != want { - t.Fatalf("got EP stats ReceiveErrors.ZeroRcvWindowState got %d want %d", got, want) - } - - // Check that we get an ACK for the newly non-zero window. - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)+uint32(len(data))), - checker.TCPFlags(header.TCPFlagAck), - checker.TCPWindow(10), - ), - ) -} - -// Test the stack receive window advertisement on receiving segments smaller than -// segment overhead. It tests for the right edge of the window to not grow when -// the endpoint is not being read from. -func TestSmallSegReceiveWindowAdvertisement(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - opt := tcpip.TCPReceiveBufferSizeRangeOption{ - Min: 1, - Default: tcp.DefaultReceiveBufferSize, - Max: tcp.DefaultReceiveBufferSize << tcp.FindWndScale(seqnum.Size(tcp.DefaultReceiveBufferSize)), - } - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err) - } - - c.AcceptWithOptions(tcp.FindWndScale(seqnum.Size(opt.Default)), header.TCPSynOptions{MSS: defaultIPv4MSS}) - - // Bump up the receive buffer size such that, when the receive window grows, - // the scaled window exceeds maxUint16. - c.EP.SocketOptions().SetReceiveBufferSize(int64(opt.Max)*2, true /* notify */) - - // Keep the payload size < segment overhead and such that it is a multiple - // of the window scaled value. This enables the test to perform equality - // checks on the incoming receive window. - payloadSize := 1 << c.RcvdWindowScale - if payloadSize >= tcp.SegSize { - t.Fatalf("payload size of %d is not less than the segment overhead of %d", payloadSize, tcp.SegSize) - } - payload := generateRandomPayload(t, payloadSize) - payloadLen := seqnum.Size(len(payload)) - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - - // Send payload to the endpoint and return the advertised receive window - // from the endpoint. - getIncomingRcvWnd := func() uint32 { - c.SendPacket(payload, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - SeqNum: iss, - AckNum: c.IRS.Add(1), - Flags: header.TCPFlagAck, - RcvWnd: 30000, - }) - iss = iss.Add(payloadLen) - - pkt := c.GetPacket() - return uint32(header.TCP(header.IPv4(pkt).Payload()).WindowSize()) << c.RcvdWindowScale - } - - // Read the advertised receive window with the ACK for payload. - rcvWnd := getIncomingRcvWnd() - - // Check if the subsequent ACK to our send has not grown the right edge of - // the window. - if got, want := getIncomingRcvWnd(), rcvWnd-uint32(len(payload)); got != want { - t.Fatalf("got incomingRcvwnd %d want %d", got, want) - } - - // Read the data so that the subsequent ACK from the endpoint - // grows the right edge of the window. - var buf bytes.Buffer - if _, err := c.EP.Read(&buf, tcpip.ReadOptions{}); err != nil { - t.Fatalf("c.EP.Read: %s", err) - } - - // Check if we have received max uint16 as our advertised - // scaled window now after a read above. - maxRcv := uint32(math.MaxUint16 << c.RcvdWindowScale) - if got, want := getIncomingRcvWnd(), maxRcv; got != want { - t.Fatalf("got incomingRcvwnd %d want %d", got, want) - } - - // Check if the subsequent ACK to our send has not grown the right edge of - // the window. - if got, want := getIncomingRcvWnd(), maxRcv-uint32(len(payload)); got != want { - t.Fatalf("got incomingRcvwnd %d want %d", got, want) - } -} - -func TestNoWindowShrinking(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Start off with a certain receive buffer then cut it in half and verify that - // the right edge of the window does not shrink. - // NOTE: Netstack doubles the value specified here. - rcvBufSize := 65536 - // Enable window scaling with a scale of zero from our end. - c.CreateConnectedWithRawOptions(context.TestInitialSequenceNumber, 30000, rcvBufSize, []byte{ - header.TCPOptionWS, 3, 0, header.TCPOptionNOP, - }) - - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.ReadableEvents) - defer c.WQ.EventUnregister(&we) - - ept := endpointTester{c.EP} - ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) - - // Send a 1 byte payload so that we can record the current receive window. - // Send a payload of half the size of rcvBufSize. - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - payload := []byte{1} - c.SendPacket(payload, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - // Wait for receive to be notified. - select { - case <-ch: - case <-time.After(5 * time.Second): - t.Fatalf("Timed out waiting for data to arrive") - } - - // Read the 1 byte payload we just sent. - if got, want := payload, ept.CheckRead(t); !bytes.Equal(got, want) { - t.Fatalf("got data: %v, want: %v", got, want) - } - - // Verify that the ACK does not shrink the window. - pkt := c.GetPacket() - iss = iss.Add(1) - checker.IPv4(t, pkt, - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlags(header.TCPFlagAck), - ), - ) - // Stash the initial window. - initialWnd := header.TCP(header.IPv4(pkt).Payload()).WindowSize() << c.RcvdWindowScale - initialLastAcceptableSeq := iss.Add(seqnum.Size(initialWnd)) - // Now shrink the receive buffer to half its original size. - c.EP.SocketOptions().SetReceiveBufferSize(int64(rcvBufSize), true /* notify */) - - data := generateRandomPayload(t, rcvBufSize) - // Send a payload of half the size of rcvBufSize. - c.SendPacket(data[:rcvBufSize/2], &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - iss = iss.Add(seqnum.Size(rcvBufSize / 2)) - - // Verify that the ACK does not shrink the window. - pkt = c.GetPacket() - checker.IPv4(t, pkt, - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlags(header.TCPFlagAck), - ), - ) - newWnd := header.TCP(header.IPv4(pkt).Payload()).WindowSize() << c.RcvdWindowScale - newLastAcceptableSeq := iss.Add(seqnum.Size(newWnd)) - if newLastAcceptableSeq.LessThan(initialLastAcceptableSeq) { - t.Fatalf("receive window shrunk unexpectedly got: %d, want >= %d", newLastAcceptableSeq, initialLastAcceptableSeq) - } - - // Send another payload of half the size of rcvBufSize. This should fill up the - // socket receive buffer and we should see a zero window. - c.SendPacket(data[rcvBufSize/2:], &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - iss = iss.Add(seqnum.Size(rcvBufSize / 2)) - - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlags(header.TCPFlagAck), - checker.TCPWindow(0), - ), - ) - - // Receive data and check it. - read := ept.CheckReadFull(t, len(data), ch, 5*time.Second) - if !bytes.Equal(data, read) { - t.Fatalf("got data = %v, want = %v", read, data) - } - - // Check that we get an ACK for the newly non-zero window, which is the new - // receive buffer size we set after the connection was established. - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlags(header.TCPFlagAck), - checker.TCPWindow(uint16(rcvBufSize/2)>>c.RcvdWindowScale), - ), - ) -} - -func TestSimpleSend(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - - data := []byte{1, 2, 3} - var r bytes.Reader - r.Reset(data) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - // Check that data is received. - b := c.GetPacket() - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - checker.IPv4(t, b, - checker.PayloadLen(len(data)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), - ), - ) - - if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) { - t.Fatalf("got data = %v, want = %v", p, data) - } - - // Acknowledge the data. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(1 + seqnum.Size(len(data))), - RcvWnd: 30000, - }) -} - -func TestZeroWindowSend(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(context.TestInitialSequenceNumber, 0 /* rcvWnd */, -1 /* epRcvBuf */) - - data := []byte{1, 2, 3} - var r bytes.Reader - r.Reset(data) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - // Check if we got a zero-window probe. - b := c.GetPacket() - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - checker.IPv4(t, b, - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), - ), - ) - - // Open up the window. Data should be received now. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - // Check that data is received. - b = c.GetPacket() - checker.IPv4(t, b, - checker.PayloadLen(len(data)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), - ), - ) - - if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) { - t.Fatalf("got data = %v, want = %v", p, data) - } - - // Acknowledge the data. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(1 + seqnum.Size(len(data))), - RcvWnd: 30000, - }) -} - -func TestScaledWindowConnect(t *testing.T) { - // This test ensures that window scaling is used when the peer - // does advertise it and connection is established with Connect(). - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Set the window size greater than the maximum non-scaled window. - c.CreateConnectedWithRawOptions(context.TestInitialSequenceNumber, 30000, 65535*3, []byte{ - header.TCPOptionWS, 3, 0, header.TCPOptionNOP, - }) - - data := []byte{1, 2, 3} - var r bytes.Reader - r.Reset(data) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - // Check that data is received, and that advertised window is 0x5fff, - // that is, that it is scaled. - b := c.GetPacket() - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - checker.IPv4(t, b, - checker.PayloadLen(len(data)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)), - checker.TCPWindow(0x5fff), - checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), - ), - ) -} - -func TestNonScaledWindowConnect(t *testing.T) { - // This test ensures that window scaling is not used when the peer - // doesn't advertise it and connection is established with Connect(). - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Set the window size greater than the maximum non-scaled window. - c.CreateConnected(context.TestInitialSequenceNumber, 30000, 65535*3) - - data := []byte{1, 2, 3} - var r bytes.Reader - r.Reset(data) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - // Check that data is received, and that advertised window is 0xffff, - // that is, that it's not scaled. - b := c.GetPacket() - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - checker.IPv4(t, b, - checker.PayloadLen(len(data)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)), - checker.TCPWindow(0xffff), - checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), - ), - ) -} - -func TestScaledWindowAccept(t *testing.T) { - // This test ensures that window scaling is used when the peer - // does advertise it and connection is established with Accept(). - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Create EP and start listening. - wq := &waiter.Queue{} - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - defer ep.Close() - - // Set the window size greater than the maximum non-scaled window. - ep.SocketOptions().SetReceiveBufferSize(65535*6, true /* notify */) - - if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - - if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %s", err) - } - - // Do 3-way handshake. - // wndScale expected is 3 as 65535 * 3 * 2 < 65535 * 2^3 but > 65535 *2 *2 - c.PassiveConnectWithOptions(100, 3 /* wndScale */, header.TCPSynOptions{MSS: defaultIPv4MSS}) - - // Try to accept the connection. - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.ReadableEvents) - defer wq.EventUnregister(&we) - - c.EP, _, err = ep.Accept(nil) - if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { - // Wait for connection to be established. - select { - case <-ch: - c.EP, _, err = ep.Accept(nil) - if err != nil { - t.Fatalf("Accept failed: %s", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - - data := []byte{1, 2, 3} - var r bytes.Reader - r.Reset(data) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - // Check that data is received, and that advertised window is 0x5fff, - // that is, that it is scaled. - b := c.GetPacket() - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - checker.IPv4(t, b, - checker.PayloadLen(len(data)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)), - checker.TCPWindow(0x5fff), - checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), - ), - ) -} - -func TestNonScaledWindowAccept(t *testing.T) { - // This test ensures that window scaling is not used when the peer - // doesn't advertise it and connection is established with Accept(). - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Create EP and start listening. - wq := &waiter.Queue{} - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - defer ep.Close() - - // Set the window size greater than the maximum non-scaled window. - ep.SocketOptions().SetReceiveBufferSize(65535*6, true /* notify */) - - if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - - if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %s", err) - } - - // Do 3-way handshake w/ window scaling disabled. The SYN-ACK to the SYN - // should not carry the window scaling option. - c.PassiveConnect(100, -1, header.TCPSynOptions{MSS: defaultIPv4MSS}) - - // Try to accept the connection. - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.ReadableEvents) - defer wq.EventUnregister(&we) - - c.EP, _, err = ep.Accept(nil) - if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { - // Wait for connection to be established. - select { - case <-ch: - c.EP, _, err = ep.Accept(nil) - if err != nil { - t.Fatalf("Accept failed: %s", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - - data := []byte{1, 2, 3} - var r bytes.Reader - r.Reset(data) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - // Check that data is received, and that advertised window is 0xffff, - // that is, that it's not scaled. - b := c.GetPacket() - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - checker.IPv4(t, b, - checker.PayloadLen(len(data)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)), - checker.TCPWindow(0xffff), - checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), - ), - ) -} - -func TestZeroScaledWindowReceive(t *testing.T) { - // This test ensures that the endpoint sends a non-zero window size - // advertisement when the scaled window transitions from 0 to non-zero, - // but the actual window (not scaled) hasn't gotten to zero. - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Set the buffer size such that a window scale of 5 will be used. - const bufSz = 65535 * 10 - const ws = uint32(5) - c.CreateConnectedWithRawOptions(context.TestInitialSequenceNumber, 30000, bufSz, []byte{ - header.TCPOptionWS, 3, 0, header.TCPOptionNOP, - }) - - // Write chunks of 50000 bytes. - remain := 0 - sent := 0 - data := make([]byte, 50000) - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - // Keep writing till the window drops below len(data). - for { - c.SendPacket(data, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss.Add(seqnum.Size(sent)), - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - sent += len(data) - pkt := c.GetPacket() - checker.IPv4(t, pkt, - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)+uint32(sent)), - checker.TCPFlags(header.TCPFlagAck), - ), - ) - // Don't reduce window to zero here. - if wnd := int(header.TCP(header.IPv4(pkt).Payload()).WindowSize()); wnd<<ws < len(data) { - remain = wnd << ws - break - } - } - - // Make the window non-zero, but the scaled window zero. - for remain >= 16 { - data = data[:remain-15] - c.SendPacket(data, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss.Add(seqnum.Size(sent)), - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - sent += len(data) - pkt := c.GetPacket() - checker.IPv4(t, pkt, - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)+uint32(sent)), - checker.TCPFlags(header.TCPFlagAck), - ), - ) - // Since the receive buffer is split between window advertisement and - // application data buffer the window does not always reflect the space - // available and actual space available can be a bit more than what is - // advertised in the window. - wnd := int(header.TCP(header.IPv4(pkt).Payload()).WindowSize()) - if wnd == 0 { - break - } - remain = wnd << ws - } - - // Read at least 2MSS of data. An ack should be sent in response to that. - // Since buffer space is now split in half between window and application - // data we need to read more than 1 MSS(65536) of data for a non-zero window - // update to be sent. For 1MSS worth of window to be available we need to - // read at least 128KB. Since our segments above were 50KB each it means - // we need to read at 3 packets. - w := tcpip.LimitedWriter{ - W: ioutil.Discard, - N: defaultMTU * 2, - } - for w.N != 0 { - res, err := c.EP.Read(&w, tcpip.ReadOptions{}) - t.Logf("err=%v res=%#v", err, res) - if err != nil { - t.Fatalf("Read failed: %s", err) - } - } - - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)+uint32(sent)), - checker.TCPWindowGreaterThanEq(uint16(defaultMTU>>ws)), - checker.TCPFlags(header.TCPFlagAck), - ), - ) -} - -func TestSegmentMerging(t *testing.T) { - tests := []struct { - name string - stop func(tcpip.Endpoint) - resume func(tcpip.Endpoint) - }{ - { - "stop work", - func(ep tcpip.Endpoint) { - ep.(interface{ StopWork() }).StopWork() - }, - func(ep tcpip.Endpoint) { - ep.(interface{ ResumeWork() }).ResumeWork() - }, - }, - { - "cork", - func(ep tcpip.Endpoint) { - ep.SocketOptions().SetCorkOption(true) - }, - func(ep tcpip.Endpoint) { - ep.SocketOptions().SetCorkOption(false) - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - - // Send tcp.InitialCwnd number of segments to fill up - // InitialWindow but don't ACK. That should prevent - // anymore packets from going out. - var r bytes.Reader - for i := 0; i < tcp.InitialCwnd; i++ { - r.Reset([]byte{0}) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write #%d failed: %s", i+1, err) - } - } - - // Now send the segments that should get merged as the congestion - // window is full and we won't be able to send any more packets. - var allData []byte - for i, data := range [][]byte{{1, 2, 3, 4}, {5, 6, 7}, {8, 9}, {10}, {11}} { - allData = append(allData, data...) - r.Reset(data) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write #%d failed: %s", i+1, err) - } - } - - // Check that we get tcp.InitialCwnd packets. - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - for i := 0; i < tcp.InitialCwnd; i++ { - b := c.GetPacket() - checker.IPv4(t, b, - checker.PayloadLen(header.TCPMinimumSize+1), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+uint32(i)+1), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), - ), - ) - } - - // Acknowledge the data. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(1 + 10), // 10 for the 10 bytes of payload. - RcvWnd: 30000, - }) - - // Check that data is received. - b := c.GetPacket() - checker.IPv4(t, b, - checker.PayloadLen(len(allData)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+11), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), - ), - ) - - if got := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(got, allData) { - t.Fatalf("got data = %v, want = %v", got, allData) - } - - // Acknowledge the data. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(11 + seqnum.Size(len(allData))), - RcvWnd: 30000, - }) - }) - } -} - -func TestDelay(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - - c.EP.SocketOptions().SetDelayOption(true) - - var allData []byte - for i, data := range [][]byte{{0}, {1, 2, 3, 4}, {5, 6, 7}, {8, 9}, {10}, {11}} { - allData = append(allData, data...) - var r bytes.Reader - r.Reset(data) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write #%d failed: %s", i+1, err) - } - } - - seq := c.IRS.Add(1) - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - for _, want := range [][]byte{allData[:1], allData[1:]} { - // Check that data is received. - b := c.GetPacket() - checker.IPv4(t, b, - checker.PayloadLen(len(want)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(seq)), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), - ), - ) - - if got := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(got, want) { - t.Fatalf("got data = %v, want = %v", got, want) - } - - seq = seq.Add(seqnum.Size(len(want))) - // Acknowledge the data. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: seq, - RcvWnd: 30000, - }) - } -} - -func TestUndelay(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - - c.EP.SocketOptions().SetDelayOption(true) - - allData := [][]byte{{0}, {1, 2, 3}} - for i, data := range allData { - var r bytes.Reader - r.Reset(data) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write #%d failed: %s", i+1, err) - } - } - - seq := c.IRS.Add(1) - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - // Check that data is received. - first := c.GetPacket() - checker.IPv4(t, first, - checker.PayloadLen(len(allData[0])+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(seq)), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), - ), - ) - - if got, want := first[header.IPv4MinimumSize+header.TCPMinimumSize:], allData[0]; !bytes.Equal(got, want) { - t.Fatalf("got first packet's data = %v, want = %v", got, want) - } - - seq = seq.Add(seqnum.Size(len(allData[0]))) - - // Check that we don't get the second packet yet. - c.CheckNoPacketTimeout("delayed second packet transmitted", 100*time.Millisecond) - - c.EP.SocketOptions().SetDelayOption(false) - - // Check that data is received. - second := c.GetPacket() - checker.IPv4(t, second, - checker.PayloadLen(len(allData[1])+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(seq)), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), - ), - ) - - if got, want := second[header.IPv4MinimumSize+header.TCPMinimumSize:], allData[1]; !bytes.Equal(got, want) { - t.Fatalf("got second packet's data = %v, want = %v", got, want) - } - - seq = seq.Add(seqnum.Size(len(allData[1]))) - - // Acknowledge the data. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: seq, - RcvWnd: 30000, - }) -} - -func TestMSSNotDelayed(t *testing.T) { - tests := []struct { - name string - fn func(tcpip.Endpoint) - }{ - {"no-op", func(tcpip.Endpoint) {}}, - {"delay", func(ep tcpip.Endpoint) { ep.SocketOptions().SetDelayOption(true) }}, - {"cork", func(ep tcpip.Endpoint) { ep.SocketOptions().SetCorkOption(true) }}, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - const maxPayload = 100 - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnectedWithRawOptions(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */, []byte{ - header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256), - }) - - test.fn(c.EP) - - allData := [][]byte{{0}, make([]byte, maxPayload), make([]byte, maxPayload)} - for i, data := range allData { - var r bytes.Reader - r.Reset(data) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write #%d failed: %s", i+1, err) - } - } - - seq := c.IRS.Add(1) - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - for i, data := range allData { - // Check that data is received. - packet := c.GetPacket() - checker.IPv4(t, packet, - checker.PayloadLen(len(data)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(seq)), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), - ), - ) - - if got, want := packet[header.IPv4MinimumSize+header.TCPMinimumSize:], data; !bytes.Equal(got, want) { - t.Fatalf("got packet #%d's data = %v, want = %v", i+1, got, want) - } - - seq = seq.Add(seqnum.Size(len(data))) - } - - // Acknowledge the data. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: seq, - RcvWnd: 30000, - }) - }) - } -} - -func testBrokenUpWrite(t *testing.T, c *context.Context, maxPayload int) { - payloadMultiplier := 10 - dataLen := payloadMultiplier * maxPayload - data := make([]byte, dataLen) - for i := range data { - data[i] = byte(i) - } - - var r bytes.Reader - r.Reset(data) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - // Check that data is received in chunks. - bytesReceived := 0 - numPackets := 0 - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - for bytesReceived != dataLen { - b := c.GetPacket() - numPackets++ - tcpHdr := header.TCP(header.IPv4(b).Payload()) - payloadLen := len(tcpHdr.Payload()) - checker.IPv4(t, b, - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1+uint32(bytesReceived)), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), - ), - ) - - pdata := data[bytesReceived : bytesReceived+payloadLen] - if p := tcpHdr.Payload(); !bytes.Equal(pdata, p) { - t.Fatalf("got data = %v, want = %v", p, pdata) - } - bytesReceived += payloadLen - var options []byte - if c.TimeStampEnabled { - // If timestamp option is enabled, echo back the timestamp and increment - // the TSEcr value included in the packet and send that back as the TSVal. - parsedOpts := tcpHdr.ParsedOptions() - tsOpt := [12]byte{header.TCPOptionNOP, header.TCPOptionNOP} - header.EncodeTSOption(parsedOpts.TSEcr+1, parsedOpts.TSVal, tsOpt[2:]) - options = tsOpt[:] - } - // Acknowledge the data. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(1 + seqnum.Size(bytesReceived)), - RcvWnd: 30000, - TCPOpts: options, - }) - } - if numPackets == 1 { - t.Fatalf("expected write to be broken up into multiple packets, but got 1 packet") - } -} - -func TestSendGreaterThanMTU(t *testing.T) { - const maxPayload = 100 - c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) - defer c.Cleanup() - - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - testBrokenUpWrite(t, c, maxPayload) -} - -func TestSetTTL(t *testing.T) { - for _, wantTTL := range []uint8{1, 2, 50, 64, 128, 254, 255} { - t.Run(fmt.Sprintf("TTL:%d", wantTTL), func(t *testing.T) { - c := context.New(t, 65535) - defer c.Cleanup() - - var err tcpip.Error - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - - if err := c.EP.SetSockOptInt(tcpip.TTLOption, int(wantTTL)); err != nil { - t.Fatalf("SetSockOptInt(TTLOption, %d) failed: %s", wantTTL, err) - } - - { - err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) - if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" { - t.Fatalf("c.EP.Connect(...) mismatch (-want +got):\n%s", d) - } - } - - // Receive SYN packet. - b := c.GetPacket() - - checker.IPv4(t, b, checker.TTL(wantTTL)) - }) - } -} - -func TestActiveSendMSSLessThanMTU(t *testing.T) { - const maxPayload = 100 - c := context.New(t, 65535) - defer c.Cleanup() - - c.CreateConnectedWithRawOptions(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */, []byte{ - header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256), - }) - testBrokenUpWrite(t, c, maxPayload) -} - -func TestPassiveSendMSSLessThanMTU(t *testing.T) { - const maxPayload = 100 - const mtu = 1200 - c := context.New(t, mtu) - defer c.Cleanup() - - // Create EP and start listening. - wq := &waiter.Queue{} - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - defer ep.Close() - - // Set the buffer size to a deterministic size so that we can check the - // window scaling option. - const rcvBufferSize = 0x20000 - ep.SocketOptions().SetReceiveBufferSize(rcvBufferSize*2, true /* notify */) - - if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - - if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %s", err) - } - - // Do 3-way handshake. - c.PassiveConnect(maxPayload, -1, header.TCPSynOptions{MSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize}) - - // Try to accept the connection. - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.ReadableEvents) - defer wq.EventUnregister(&we) - - c.EP, _, err = ep.Accept(nil) - if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { - // Wait for connection to be established. - select { - case <-ch: - c.EP, _, err = ep.Accept(nil) - if err != nil { - t.Fatalf("Accept failed: %s", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - - // Check that data gets properly segmented. - testBrokenUpWrite(t, c, maxPayload) -} - -func TestSynCookiePassiveSendMSSLessThanMTU(t *testing.T) { - const maxPayload = 536 - const mtu = 2000 - c := context.New(t, mtu) - defer c.Cleanup() - - opt := tcpip.TCPAlwaysUseSynCookies(true) - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err) - } - - // Create EP and start listening. - wq := &waiter.Queue{} - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - defer ep.Close() - - if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - - if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %s", err) - } - - // Do 3-way handshake. - c.PassiveConnect(maxPayload, -1, header.TCPSynOptions{MSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize}) - - // Try to accept the connection. - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.ReadableEvents) - defer wq.EventUnregister(&we) - - c.EP, _, err = ep.Accept(nil) - if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { - // Wait for connection to be established. - select { - case <-ch: - c.EP, _, err = ep.Accept(nil) - if err != nil { - t.Fatalf("Accept failed: %s", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - - // Check that data gets properly segmented. - testBrokenUpWrite(t, c, maxPayload) -} - -func TestForwarderSendMSSLessThanMTU(t *testing.T) { - const maxPayload = 100 - const mtu = 1200 - c := context.New(t, mtu) - defer c.Cleanup() - - s := c.Stack() - ch := make(chan tcpip.Error, 1) - f := tcp.NewForwarder(s, 65536, 10, func(r *tcp.ForwarderRequest) { - var err tcpip.Error - c.EP, err = r.CreateEndpoint(&c.WQ) - ch <- err - }) - s.SetTransportProtocolHandler(tcp.ProtocolNumber, f.HandlePacket) - - // Do 3-way handshake. - c.PassiveConnect(maxPayload, -1, header.TCPSynOptions{MSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize}) - - // Wait for connection to be available. - select { - case err := <-ch: - if err != nil { - t.Fatalf("Error creating endpoint: %s", err) - } - case <-time.After(2 * time.Second): - t.Fatalf("Timed out waiting for connection") - } - - // Check that data gets properly segmented. - testBrokenUpWrite(t, c, maxPayload) -} - -func TestSynOptionsOnActiveConnect(t *testing.T) { - const mtu = 1400 - c := context.New(t, mtu) - defer c.Cleanup() - - // Create TCP endpoint. - var err tcpip.Error - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - - // Set the buffer size to a deterministic size so that we can check the - // window scaling option. - const rcvBufferSize = 0x20000 - const wndScale = 3 - c.EP.SocketOptions().SetReceiveBufferSize(rcvBufferSize*2, true /* notify */) - - // Start connection attempt. - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.WritableEvents) - defer c.WQ.EventUnregister(&we) - - { - err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) - if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" { - t.Fatalf("c.EP.Connect(...) mismatch (-want +got):\n%s", d) - } - } - - // Receive SYN packet. - b := c.GetPacket() - mss := uint16(mtu - header.IPv4MinimumSize - header.TCPMinimumSize) - checker.IPv4(t, b, - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagSyn), - checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: wndScale}), - ), - ) - - tcpHdr := header.TCP(header.IPv4(b).Payload()) - c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) - - // Wait for retransmit. - time.Sleep(1 * time.Second) - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagSyn), - checker.SrcPort(tcpHdr.SourcePort()), - checker.TCPSeqNum(tcpHdr.SequenceNumber()), - checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: wndScale}), - ), - ) - - // Send SYN-ACK. - iss := seqnum.Value(context.TestInitialSequenceNumber) - c.SendPacket(nil, &context.Headers{ - SrcPort: tcpHdr.DestinationPort(), - DstPort: tcpHdr.SourcePort(), - Flags: header.TCPFlagSyn | header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - // Receive ACK packet. - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)+1), - ), - ) - - // Wait for connection to be established. - select { - case <-ch: - if err := c.EP.LastError(); err != nil { - t.Fatalf("Connect failed: %s", err) - } - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for connection") - } -} - -func TestCloseListener(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Create listener. - var wq waiter.Queue - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - - if err := ep.Bind(tcpip.FullAddress{}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - - if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %s", err) - } - - // Close the listener and measure how long it takes. - t0 := time.Now() - ep.Close() - if diff := time.Now().Sub(t0); diff > 3*time.Second { - t.Fatalf("Took too long to close: %s", diff) - } -} - -func TestReceiveOnResetConnection(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - // Send RST segment. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagRst, - SeqNum: iss, - RcvWnd: 30000, - }) - - // Try to read. - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.ReadableEvents) - defer c.WQ.EventUnregister(&we) - -loop: - for { - switch _, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{}); err.(type) { - case *tcpip.ErrWouldBlock: - <-ch - // Expect the state to be StateError and subsequent Reads to fail with HardError. - _, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{}) - if d := cmp.Diff(&tcpip.ErrConnectionReset{}, err); d != "" { - t.Fatalf("c.EP.Read() mismatch (-want +got):\n%s", d) - } - break loop - case *tcpip.ErrConnectionReset: - break loop - default: - t.Fatalf("got c.EP.Read(nil) = %v, want = %s", err, &tcpip.ErrConnectionReset{}) - } - } - - if tcp.EndpointState(c.EP.State()) != tcp.StateError { - t.Fatalf("got EP state is not StateError") - } - - checkValid := func() []error { - var errors []error - if got := c.Stack().Stats().TCP.EstablishedResets.Value(); got != 1 { - errors = append(errors, fmt.Errorf("got stats.TCP.EstablishedResets.Value() = %d, want = 1", got)) - } - if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { - errors = append(errors, fmt.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got)) - } - if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 { - errors = append(errors, fmt.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got)) - } - return errors - } - - start := time.Now() - for time.Since(start) < time.Minute && len(checkValid()) > 0 { - time.Sleep(50 * time.Millisecond) - } - for _, err := range checkValid() { - t.Error(err) - } -} - -func TestSendOnResetConnection(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - - // Send RST segment. - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagRst, - SeqNum: iss, - RcvWnd: 30000, - }) - - // Wait for the RST to be received. - time.Sleep(1 * time.Second) - - // Try to write. - var r bytes.Reader - r.Reset(make([]byte, 10)) - _, err := c.EP.Write(&r, tcpip.WriteOptions{}) - if d := cmp.Diff(&tcpip.ErrConnectionReset{}, err); d != "" { - t.Fatalf("c.EP.Write(...) mismatch (-want +got):\n%s", d) - } -} - -// TestMaxRetransmitsTimeout tests if the connection is timed out after -// a segment has been retransmitted MaxRetries times. -func TestMaxRetransmitsTimeout(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - const numRetries = 2 - opt := tcpip.TCPMaxRetriesOption(numRetries) - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) - } - - c.CreateConnected(context.TestInitialSequenceNumber, 30000 /* rcvWnd */, -1 /* epRcvBuf */) - - waitEntry, notifyCh := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&waitEntry, waiter.EventHUp) - defer c.WQ.EventUnregister(&waitEntry) - - var r bytes.Reader - r.Reset(make([]byte, 1)) - _, err := c.EP.Write(&r, tcpip.WriteOptions{}) - if err != nil { - t.Fatalf("Write failed: %s", err) - } - - // Expect first transmit and MaxRetries retransmits. - for i := 0; i < numRetries+1; i++ { - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagPsh), - ), - ) - } - // Wait for the connection to timeout after MaxRetries retransmits. - initRTO := 1 * time.Second - select { - case <-notifyCh: - case <-time.After((2 << numRetries) * initRTO): - t.Fatalf("connection still alive after maximum retransmits.\n") - } - - // Send an ACK and expect a RST as the connection would have been closed. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - }) - - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagRst), - ), - ) - - if got := c.Stack().Stats().TCP.EstablishedTimedout.Value(); got != 1 { - t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout.Value() = %d, want = 1", got) - } - if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 { - t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got) - } -} - -// TestMaxRTO tests if the retransmit interval caps to MaxRTO. -func TestMaxRTO(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - rto := 1 * time.Second - opt := tcpip.TCPMaxRTOOption(rto) - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) - } - - c.CreateConnected(context.TestInitialSequenceNumber, 30000 /* rcvWnd */, -1 /* epRcvBuf */) - - var r bytes.Reader - r.Reset(make([]byte, 1)) - _, err := c.EP.Write(&r, tcpip.WriteOptions{}) - if err != nil { - t.Fatalf("Write failed: %s", err) - } - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), - ), - ) - const numRetransmits = 2 - for i := 0; i < numRetransmits; i++ { - start := time.Now() - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), - ), - ) - if time.Since(start).Round(time.Second).Seconds() != rto.Seconds() { - t.Errorf("Retransmit interval not capped to MaxRTO.\n") - } - } -} - -// TestZeroSizedWriteRetransmit tests that a zero sized write should not -// result in a panic on an RTO as no segment should have been queued for -// a zero sized write. -func TestZeroSizedWriteRetransmit(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(context.TestInitialSequenceNumber, 30000 /* rcvWnd */, -1 /* epRcvBuf */) - - var r bytes.Reader - _, err := c.EP.Write(&r, tcpip.WriteOptions{}) - if err != nil { - t.Fatalf("Write failed: %s", err) - } - // Now do a non-zero sized write to trigger actual sending of data. - r.Reset(make([]byte, 1)) - _, err = c.EP.Write(&r, tcpip.WriteOptions{}) - if err != nil { - t.Fatalf("Write failed: %s", err) - } - // Do not ACK the packet and expect an original transmit and a - // retransmit. This should not cause a panic. - for i := 0; i < 2; i++ { - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), - ), - ) - } -} - -// TestRetransmitIPv4IDUniqueness tests that the IPv4 Identification field is -// unique on retransmits. -func TestRetransmitIPv4IDUniqueness(t *testing.T) { - for _, tc := range []struct { - name string - size int - }{ - {"1Byte", 1}, - {"512Bytes", 512}, - } { - t.Run(tc.name, func(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(context.TestInitialSequenceNumber, 30000 /* rcvWnd */, -1 /* epRcvBuf */) - - // Disabling PMTU discovery causes all packets sent from this socket to - // have DF=0. This needs to be done because the IPv4 ID uniqueness - // applies only to non-atomic IPv4 datagrams as defined in RFC 6864 - // Section 4, and datagrams with DF=0 are non-atomic. - if err := c.EP.SetSockOptInt(tcpip.MTUDiscoverOption, tcpip.PMTUDiscoveryDont); err != nil { - t.Fatalf("disabling PMTU discovery via sockopt to force DF=0 failed: %s", err) - } - - var r bytes.Reader - r.Reset(make([]byte, tc.size)) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - pkt := c.GetPacket() - checker.IPv4(t, pkt, - checker.FragmentFlags(0), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), - ), - ) - idSet := map[uint16]struct{}{header.IPv4(pkt).ID(): {}} - // Expect two retransmitted packets, and that all packets received have - // unique IPv4 ID values. - for i := 0; i <= 2; i++ { - pkt := c.GetPacket() - checker.IPv4(t, pkt, - checker.FragmentFlags(0), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), - ), - ) - id := header.IPv4(pkt).ID() - if _, exists := idSet[id]; exists { - t.Fatalf("duplicate IPv4 ID=%d found in retransmitted packet", id) - } - idSet[id] = struct{}{} - } - }) - } -} - -func TestFinImmediately(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - - // Shutdown immediately, check that we get a FIN. - if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { - t.Fatalf("Shutdown failed: %s", err) - } - - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), - ), - ) - - // Ack and send FIN as well. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: iss, - AckNum: c.IRS.Add(2), - RcvWnd: 30000, - }) - - // Check that the stack acks the FIN. - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+2), - checker.TCPAckNum(uint32(iss)+1), - checker.TCPFlags(header.TCPFlagAck), - ), - ) -} - -func TestFinRetransmit(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - - // Shutdown immediately, check that we get a FIN. - if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { - t.Fatalf("Shutdown failed: %s", err) - } - - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), - ), - ) - - // Don't acknowledge yet. We should get a retransmit of the FIN. - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), - ), - ) - - // Ack and send FIN as well. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: iss, - AckNum: c.IRS.Add(2), - RcvWnd: 30000, - }) - - // Check that the stack acks the FIN. - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+2), - checker.TCPAckNum(uint32(iss)+1), - checker.TCPFlags(header.TCPFlagAck), - ), - ) -} - -func TestFinWithNoPendingData(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - - // Write something out, and have it acknowledged. - view := make([]byte, 10) - var r bytes.Reader - r.Reset(view) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - next := uint32(c.IRS) + 1 - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(len(view)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(next), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), - ), - ) - next += uint32(len(view)) - - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: seqnum.Value(next), - RcvWnd: 30000, - }) - - // Shutdown, check that we get a FIN. - if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { - t.Fatalf("Shutdown failed: %s", err) - } - - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(next), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), - ), - ) - next++ - - // Ack and send FIN as well. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: iss, - AckNum: seqnum.Value(next), - RcvWnd: 30000, - }) - - // Check that the stack acks the FIN. - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(next), - checker.TCPAckNum(uint32(iss)+1), - checker.TCPFlags(header.TCPFlagAck), - ), - ) -} - -func TestFinWithPendingDataCwndFull(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - - // Write enough segments to fill the congestion window before ACK'ing - // any of them. - view := make([]byte, 10) - var r bytes.Reader - for i := tcp.InitialCwnd; i > 0; i-- { - r.Reset(view) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - } - - next := uint32(c.IRS) + 1 - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - for i := tcp.InitialCwnd; i > 0; i-- { - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(len(view)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(next), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), - ), - ) - next += uint32(len(view)) - } - - // Shutdown the connection, check that the FIN segment isn't sent - // because the congestion window doesn't allow it. Wait until a - // retransmit is received. - if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { - t.Fatalf("Shutdown failed: %s", err) - } - - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(len(view)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), - ), - ) - - // Send the ACK that will allow the FIN to be sent as well. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: seqnum.Value(next), - RcvWnd: 30000, - }) - - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(next), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), - ), - ) - next++ - - // Send a FIN that acknowledges everything. Get an ACK back. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: iss, - AckNum: seqnum.Value(next), - RcvWnd: 30000, - }) - - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(next), - checker.TCPAckNum(uint32(iss)+1), - checker.TCPFlags(header.TCPFlagAck), - ), - ) -} - -func TestFinWithPendingData(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - - // Write something out, and acknowledge it to get cwnd to 2. - view := make([]byte, 10) - var r bytes.Reader - r.Reset(view) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - next := uint32(c.IRS) + 1 - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(len(view)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(next), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), - ), - ) - next += uint32(len(view)) - - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: seqnum.Value(next), - RcvWnd: 30000, - }) - - // Write new data, but don't acknowledge it. - r.Reset(view) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(len(view)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(next), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), - ), - ) - next += uint32(len(view)) - - // Shutdown the connection, check that we do get a FIN. - if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { - t.Fatalf("Shutdown failed: %s", err) - } - - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(next), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), - ), - ) - next++ - - // Send a FIN that acknowledges everything. Get an ACK back. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: iss, - AckNum: seqnum.Value(next), - RcvWnd: 30000, - }) - - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(next), - checker.TCPAckNum(uint32(iss)+1), - checker.TCPFlags(header.TCPFlagAck), - ), - ) -} - -func TestFinWithPartialAck(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - - // Write something out, and acknowledge it to get cwnd to 2. Also send - // FIN from the test side. - view := make([]byte, 10) - var r bytes.Reader - r.Reset(view) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - next := uint32(c.IRS) + 1 - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(len(view)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(next), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), - ), - ) - next += uint32(len(view)) - - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: iss, - AckNum: seqnum.Value(next), - RcvWnd: 30000, - }) - - // Check that we get an ACK for the fin. - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(next), - checker.TCPAckNum(uint32(iss)+1), - checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), - ), - ) - - // Write new data, but don't acknowledge it. - r.Reset(view) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(len(view)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(next), - checker.TCPAckNum(uint32(iss)+1), - checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), - ), - ) - next += uint32(len(view)) - - // Shutdown the connection, check that we do get a FIN. - if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { - t.Fatalf("Shutdown failed: %s", err) - } - - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(next), - checker.TCPAckNum(uint32(iss)+1), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), - ), - ) - next++ - - // Send an ACK for the data, but not for the FIN yet. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss.Add(1), - AckNum: seqnum.Value(next - 1), - RcvWnd: 30000, - }) - - // Check that we don't get a retransmit of the FIN. - c.CheckNoPacketTimeout("FIN retransmitted when data was ack'd", 100*time.Millisecond) - - // Ack the FIN. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: iss.Add(1), - AckNum: seqnum.Value(next), - RcvWnd: 30000, - }) -} - -func TestUpdateListenBacklog(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Create listener. - var wq waiter.Queue - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - - if err := ep.Bind(tcpip.FullAddress{}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - - if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %s", err) - } - - // Update the backlog with another Listen() on the same endpoint. - if err := ep.Listen(20); err != nil { - t.Fatalf("Listen failed to update backlog: %s", err) - } - - ep.Close() -} - -func scaledSendWindow(t *testing.T, scale uint8) { - // This test ensures that the endpoint is using the right scaling by - // sending a buffer that is larger than the window size, and ensuring - // that the endpoint doesn't send more than allowed. - c := context.New(t, defaultMTU) - defer c.Cleanup() - - maxPayload := defaultMTU - header.IPv4MinimumSize - header.TCPMinimumSize - c.CreateConnectedWithRawOptions(context.TestInitialSequenceNumber, 0, -1 /* epRcvBuf */, []byte{ - header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256), - header.TCPOptionWS, 3, scale, header.TCPOptionNOP, - }) - - // Open up the window with a scaled value. - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(1), - RcvWnd: 1, - }) - - // Send some data. Check that it's capped by the window size. - view := make([]byte, 65535) - var r bytes.Reader - r.Reset(view) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - // Check that only data that fits in the scaled window is sent. - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen((1<<scale)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), - ), - ) - - // Reset the connection to free resources. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagRst, - SeqNum: iss, - }) -} - -func TestScaledSendWindow(t *testing.T) { - for scale := uint8(0); scale <= 14; scale++ { - scaledSendWindow(t, scale) - } -} - -func TestReceivedValidSegmentCountIncrement(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - stats := c.Stack().Stats() - want := stats.TCP.ValidSegmentsReceived.Value() + 1 - - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - if got := stats.TCP.ValidSegmentsReceived.Value(); got != want { - t.Errorf("got stats.TCP.ValidSegmentsReceived.Value() = %d, want = %d", got, want) - } - if got := c.EP.Stats().(*tcp.Stats).SegmentsReceived.Value(); got != want { - t.Errorf("got EP stats Stats.SegmentsReceived = %d, want = %d", got, want) - } - // Ensure there were no errors during handshake. If these stats have - // incremented, then the connection should not have been established. - if got := c.EP.Stats().(*tcp.Stats).SendErrors.NoRoute.Value(); got != 0 { - t.Errorf("got EP stats Stats.SendErrors.NoRoute = %d, want = %d", got, 0) - } -} - -func TestReceivedInvalidSegmentCountIncrement(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - stats := c.Stack().Stats() - want := stats.TCP.InvalidSegmentsReceived.Value() + 1 - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - vv := c.BuildSegment(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - tcpbuf := vv.ToView()[header.IPv4MinimumSize:] - tcpbuf[header.TCPDataOffset] = ((header.TCPMinimumSize - 1) / 4) << 4 - - c.SendSegment(vv) - - if got := stats.TCP.InvalidSegmentsReceived.Value(); got != want { - t.Errorf("got stats.TCP.InvalidSegmentsReceived.Value() = %d, want = %d", got, want) - } - if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.MalformedPacketsReceived.Value(); got != want { - t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %d, want = %d", got, want) - } -} - -func TestReceivedIncorrectChecksumIncrement(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - stats := c.Stack().Stats() - want := stats.TCP.ChecksumErrors.Value() + 1 - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - vv := c.BuildSegment([]byte{0x1, 0x2, 0x3}, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - tcpbuf := vv.ToView()[header.IPv4MinimumSize:] - // Overwrite a byte in the payload which should cause checksum - // verification to fail. - tcpbuf[(tcpbuf[header.TCPDataOffset]>>4)*4] = 0x4 - - c.SendSegment(vv) - - if got := stats.TCP.ChecksumErrors.Value(); got != want { - t.Errorf("got stats.TCP.ChecksumErrors.Value() = %d, want = %d", got, want) - } - if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.ChecksumErrors.Value(); got != want { - t.Errorf("got EP stats Stats.ReceiveErrors.ChecksumErrors = %d, want = %d", got, want) - } -} - -func TestReceivedSegmentQueuing(t *testing.T) { - // This test sends 200 segments containing a few bytes each to an - // endpoint and checks that they're all received and acknowledged by - // the endpoint, that is, that none of the segments are dropped by - // internal queues. - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - - // Send 200 segments. - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - data := []byte{1, 2, 3} - for i := 0; i < 200; i++ { - c.SendPacket(data, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss.Add(seqnum.Size(i * len(data))), - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - } - - // Receive ACKs for all segments. - last := iss.Add(seqnum.Size(200 * len(data))) - for { - b := c.GetPacket() - checker.IPv4(t, b, - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPFlags(header.TCPFlagAck), - ), - ) - tcpHdr := header.TCP(header.IPv4(b).Payload()) - ack := seqnum.Value(tcpHdr.AckNumber()) - if ack == last { - break - } - - if last.LessThan(ack) { - t.Fatalf("Acknowledge (%v) beyond the expected (%v)", ack, last) - } - } -} - -func TestReadAfterClosedState(t *testing.T) { - // This test ensures that calling Read() or Peek() after the endpoint - // has transitioned to closedState still works if there is pending - // data. To transition to stateClosed without calling Close(), we must - // shutdown the send path and the peer must send its own FIN. - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Set TCPTimeWaitTimeout to 1 seconds so that sockets are marked closed - // after 1 second in TIME_WAIT state. - tcpTimeWaitTimeout := 1 * time.Second - opt := tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout) - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) - } - - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.ReadableEvents) - defer c.WQ.EventUnregister(&we) - - ept := endpointTester{c.EP} - ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) - - // Shutdown immediately for write, check that we get a FIN. - if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { - t.Fatalf("Shutdown failed: %s", err) - } - - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), - ), - ) - - if got, want := tcp.EndpointState(c.EP.State()), tcp.StateFinWait1; got != want { - t.Errorf("unexpected endpoint state: want %s, got %s", want, got) - } - - // Send some data and acknowledge the FIN. - data := []byte{1, 2, 3} - c.SendPacket(data, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: iss, - AckNum: c.IRS.Add(2), - RcvWnd: 30000, - }) - - // Check that ACK is received. - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+2), - checker.TCPAckNum(uint32(iss)+uint32(len(data))+1), - checker.TCPFlags(header.TCPFlagAck), - ), - ) - - // Give the stack the chance to transition to closed state from - // TIME_WAIT. - time.Sleep(tcpTimeWaitTimeout * 2) - - if got, want := tcp.EndpointState(c.EP.State()), tcp.StateClose; got != want { - t.Errorf("unexpected endpoint state: want %s, got %s", want, got) - } - - // Wait for receive to be notified. - select { - case <-ch: - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for data to arrive") - } - - // Check that peek works. - var peekBuf bytes.Buffer - res, err := c.EP.Read(&peekBuf, tcpip.ReadOptions{Peek: true}) - if err != nil { - t.Fatalf("Peek failed: %s", err) - } - - if got, want := res.Count, len(data); got != want { - t.Fatalf("res.Count = %d, want %d", got, want) - } - if !bytes.Equal(data, peekBuf.Bytes()) { - t.Fatalf("got data = %v, want = %v", peekBuf.Bytes(), data) - } - - // Receive data. - v := ept.CheckRead(t) - if !bytes.Equal(data, v) { - t.Fatalf("got data = %v, want = %v", v, data) - } - - // Now that we drained the queue, check that functions fail with the - // right error code. - ept.CheckReadError(t, &tcpip.ErrClosedForReceive{}) - var buf bytes.Buffer - { - _, err := c.EP.Read(&buf, tcpip.ReadOptions{Peek: true}) - if d := cmp.Diff(&tcpip.ErrClosedForReceive{}, err); d != "" { - t.Fatalf("c.EP.Read(_, {Peek: true}) mismatch (-want +got):\n%s", d) - } - } -} - -func TestReusePort(t *testing.T) { - // This test ensures that ports are immediately available for reuse - // after Close on the endpoints using them returns. - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // First case, just an endpoint that was bound. - var err tcpip.Error - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) - if err != nil { - t.Fatalf("NewEndpoint failed; %s", err) - } - c.EP.SocketOptions().SetReuseAddress(true) - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - - c.EP.Close() - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) - if err != nil { - t.Fatalf("NewEndpoint failed; %s", err) - } - c.EP.SocketOptions().SetReuseAddress(true) - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - c.EP.Close() - - // Second case, an endpoint that was bound and is connecting.. - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) - if err != nil { - t.Fatalf("NewEndpoint failed; %s", err) - } - c.EP.SocketOptions().SetReuseAddress(true) - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - { - err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) - if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" { - t.Fatalf("c.EP.Connect(...) mismatch (-want +got):\n%s", d) - } - } - c.EP.Close() - - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) - if err != nil { - t.Fatalf("NewEndpoint failed; %s", err) - } - c.EP.SocketOptions().SetReuseAddress(true) - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - c.EP.Close() - - // Third case, an endpoint that was bound and is listening. - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) - if err != nil { - t.Fatalf("NewEndpoint failed; %s", err) - } - c.EP.SocketOptions().SetReuseAddress(true) - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - if err := c.EP.Listen(10); err != nil { - t.Fatalf("Listen failed: %s", err) - } - c.EP.Close() - - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) - if err != nil { - t.Fatalf("NewEndpoint failed; %s", err) - } - c.EP.SocketOptions().SetReuseAddress(true) - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - if err := c.EP.Listen(10); err != nil { - t.Fatalf("Listen failed: %s", err) - } -} - -func checkRecvBufferSize(t *testing.T, ep tcpip.Endpoint, v int) { - t.Helper() - - s := ep.SocketOptions().GetReceiveBufferSize() - if int(s) != v { - t.Fatalf("got receive buffer size = %d, want = %d", s, v) - } -} - -func checkSendBufferSize(t *testing.T, ep tcpip.Endpoint, v int) { - t.Helper() - - if s := ep.SocketOptions().GetSendBufferSize(); int(s) != v { - t.Fatalf("got send buffer size = %d, want = %d", s, v) - } -} - -func TestDefaultBufferSizes(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, - }) - - // Check the default values. - ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) - if err != nil { - t.Fatalf("NewEndpoint failed; %s", err) - } - defer func() { - if ep != nil { - ep.Close() - } - }() - - checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize) - checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize) - - // Change the default send buffer size. - { - opt := tcpip.TCPSendBufferSizeRangeOption{ - Min: 1, - Default: tcp.DefaultSendBufferSize * 2, - Max: tcp.DefaultSendBufferSize * 20, - } - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err) - } - } - - ep.Close() - ep, err = s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) - if err != nil { - t.Fatalf("NewEndpoint failed; %s", err) - } - - checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*2) - checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize) - - // Change the default receive buffer size. - { - opt := tcpip.TCPReceiveBufferSizeRangeOption{ - Min: 1, - Default: tcp.DefaultReceiveBufferSize * 3, - Max: tcp.DefaultReceiveBufferSize * 30, - } - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err) - } - } - - ep.Close() - ep, err = s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) - if err != nil { - t.Fatalf("NewEndpoint failed; %s", err) - } - - checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*2) - checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize*3) -} - -func TestBindToDeviceOption(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}}) - - ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) - if err != nil { - t.Fatalf("NewEndpoint failed; %s", err) - } - defer ep.Close() - - if err := s.CreateNIC(321, loopback.New()); err != nil { - t.Errorf("CreateNIC failed: %s", err) - } - - // nicIDPtr is used instead of taking the address of NICID literals, which is - // a compiler error. - nicIDPtr := func(s tcpip.NICID) *tcpip.NICID { - return &s - } - - testActions := []struct { - name string - setBindToDevice *tcpip.NICID - setBindToDeviceError tcpip.Error - getBindToDevice int32 - }{ - {"GetDefaultValue", nil, nil, 0}, - {"BindToNonExistent", nicIDPtr(999), &tcpip.ErrUnknownDevice{}, 0}, - {"BindToExistent", nicIDPtr(321), nil, 321}, - {"UnbindToDevice", nicIDPtr(0), nil, 0}, - } - for _, testAction := range testActions { - t.Run(testAction.name, func(t *testing.T) { - if testAction.setBindToDevice != nil { - bindToDevice := int32(*testAction.setBindToDevice) - if gotErr, wantErr := ep.SocketOptions().SetBindToDevice(bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr { - t.Errorf("got SetSockOpt(&%T(%d)) = %s, want = %s", bindToDevice, bindToDevice, gotErr, wantErr) - } - } - bindToDevice := ep.SocketOptions().GetBindToDevice() - if bindToDevice != testAction.getBindToDevice { - t.Errorf("got bindToDevice = %d, want %d", bindToDevice, testAction.getBindToDevice) - } - }) - } -} - -func makeStack() (*stack.Stack, tcpip.Error) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ - ipv4.NewProtocol, - ipv6.NewProtocol, - }, - TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, - }) - - id := loopback.New() - if testing.Verbose() { - id = sniffer.New(id) - } - - if err := s.CreateNIC(1, id); err != nil { - return nil, err - } - - for _, ct := range []struct { - number tcpip.NetworkProtocolNumber - address tcpip.Address - }{ - {ipv4.ProtocolNumber, context.StackAddr}, - {ipv6.ProtocolNumber, context.StackV6Addr}, - } { - if err := s.AddAddress(1, ct.number, ct.address); err != nil { - return nil, err - } - } - - s.SetRouteTable([]tcpip.Route{ - { - Destination: header.IPv4EmptySubnet, - NIC: 1, - }, - { - Destination: header.IPv6EmptySubnet, - NIC: 1, - }, - }) - - return s, nil -} - -func TestSelfConnect(t *testing.T) { - // This test ensures that intentional self-connects work. In particular, - // it checks that if an endpoint binds to say 127.0.0.1:1000 then - // connects to 127.0.0.1:1000, then it will be connected to itself, and - // is able to send and receive data through the same endpoint. - s, err := makeStack() - if err != nil { - t.Fatal(err) - } - - var wq waiter.Queue - ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - defer ep.Close() - - if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - - // Register for notification, then start connection attempt. - waitEntry, notifyCh := waiter.NewChannelEntry(nil) - wq.EventRegister(&waitEntry, waiter.WritableEvents) - defer wq.EventUnregister(&waitEntry) - - { - err := ep.Connect(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}) - if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" { - t.Fatalf("ep.Connect(...) mismatch (-want +got):\n%s", d) - } - } - - <-notifyCh - if err := ep.LastError(); err != nil { - t.Fatalf("Connect failed: %s", err) - } - - // Write something. - data := []byte{1, 2, 3} - var r bytes.Reader - r.Reset(data) - if _, err := ep.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - // Read back what was written. - wq.EventUnregister(&waitEntry) - wq.EventRegister(&waitEntry, waiter.ReadableEvents) - ept := endpointTester{ep} - rd := ept.CheckReadFull(t, len(data), notifyCh, 5*time.Second) - - if !bytes.Equal(data, rd) { - t.Fatalf("got data = %v, want = %v", rd, data) - } -} - -func TestConnectAvoidsBoundPorts(t *testing.T) { - addressTypes := func(t *testing.T, network string) []string { - switch network { - case "ipv4": - return []string{"v4"} - case "ipv6": - return []string{"v6"} - case "dual": - return []string{"v6", "mapped"} - default: - t.Fatalf("unknown network: '%s'", network) - } - - panic("unreachable") - } - - address := func(t *testing.T, addressType string, isAny bool) tcpip.Address { - switch addressType { - case "v4": - if isAny { - return "" - } - return context.StackAddr - case "v6": - if isAny { - return "" - } - return context.StackV6Addr - case "mapped": - if isAny { - return context.V4MappedWildcardAddr - } - return context.StackV4MappedAddr - default: - t.Fatalf("unknown address type: '%s'", addressType) - } - - panic("unreachable") - } - // This test ensures that Endpoint.Connect doesn't select already-bound ports. - networks := []string{"ipv4", "ipv6", "dual"} - for _, exhaustedNetwork := range networks { - t.Run(fmt.Sprintf("exhaustedNetwork=%s", exhaustedNetwork), func(t *testing.T) { - for _, exhaustedAddressType := range addressTypes(t, exhaustedNetwork) { - t.Run(fmt.Sprintf("exhaustedAddressType=%s", exhaustedAddressType), func(t *testing.T) { - for _, isAny := range []bool{false, true} { - t.Run(fmt.Sprintf("isAny=%t", isAny), func(t *testing.T) { - for _, candidateNetwork := range networks { - t.Run(fmt.Sprintf("candidateNetwork=%s", candidateNetwork), func(t *testing.T) { - for _, candidateAddressType := range addressTypes(t, candidateNetwork) { - t.Run(fmt.Sprintf("candidateAddressType=%s", candidateAddressType), func(t *testing.T) { - s, err := makeStack() - if err != nil { - t.Fatal(err) - } - - var wq waiter.Queue - var eps []tcpip.Endpoint - defer func() { - for _, ep := range eps { - ep.Close() - } - }() - makeEP := func(network string) tcpip.Endpoint { - var networkProtocolNumber tcpip.NetworkProtocolNumber - switch network { - case "ipv4": - networkProtocolNumber = ipv4.ProtocolNumber - case "ipv6", "dual": - networkProtocolNumber = ipv6.ProtocolNumber - default: - t.Fatalf("unknown network: '%s'", network) - } - ep, err := s.NewEndpoint(tcp.ProtocolNumber, networkProtocolNumber, &wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - eps = append(eps, ep) - switch network { - case "ipv4": - case "ipv6": - ep.SocketOptions().SetV6Only(true) - case "dual": - ep.SocketOptions().SetV6Only(false) - default: - t.Fatalf("unknown network: '%s'", network) - } - return ep - } - - var v4reserved, v6reserved bool - switch exhaustedAddressType { - case "v4", "mapped": - v4reserved = true - case "v6": - v6reserved = true - // Dual stack sockets bound to v6 any reserve on v4 as - // well. - if isAny { - switch exhaustedNetwork { - case "ipv6": - case "dual": - v4reserved = true - default: - t.Fatalf("unknown address type: '%s'", exhaustedNetwork) - } - } - default: - t.Fatalf("unknown address type: '%s'", exhaustedAddressType) - } - var collides bool - switch candidateAddressType { - case "v4", "mapped": - collides = v4reserved - case "v6": - collides = v6reserved - default: - t.Fatalf("unknown address type: '%s'", candidateAddressType) - } - - const ( - start = 16000 - end = 16050 - ) - if err := s.SetPortRange(start, end); err != nil { - t.Fatalf("got s.SetPortRange(%d, %d) = %s, want = nil", start, end, err) - } - for i := start; i <= end; i++ { - if makeEP(exhaustedNetwork).Bind(tcpip.FullAddress{Addr: address(t, exhaustedAddressType, isAny), Port: uint16(i)}); err != nil { - t.Fatalf("Bind(%d) failed: %s", i, err) - } - } - var want tcpip.Error = &tcpip.ErrConnectStarted{} - if collides { - want = &tcpip.ErrNoPortAvailable{} - } - if err := makeEP(candidateNetwork).Connect(tcpip.FullAddress{Addr: address(t, candidateAddressType, false), Port: 31337}); err != want { - t.Fatalf("got ep.Connect(..) = %s, want = %s", err, want) - } - }) - } - }) - } - }) - } - }) - } - }) - } -} - -func TestPathMTUDiscovery(t *testing.T) { - // This test verifies the stack retransmits packets after it receives an - // ICMP packet indicating that the path MTU has been exceeded. - c := context.New(t, 1500) - defer c.Cleanup() - - // Create new connection with MSS of 1460. - const maxPayload = 1500 - header.TCPMinimumSize - header.IPv4MinimumSize - c.CreateConnectedWithRawOptions(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */, []byte{ - header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256), - }) - - // Send 3200 bytes of data. - const writeSize = 3200 - data := make([]byte, writeSize) - for i := range data { - data[i] = byte(i) - } - var r bytes.Reader - r.Reset(data) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - receivePackets := func(c *context.Context, sizes []int, which int, seqNum uint32) []byte { - var ret []byte - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - for i, size := range sizes { - p := c.GetPacket() - if i == which { - ret = p - } - checker.IPv4(t, p, - checker.PayloadLen(size+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(seqNum), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), - ), - ) - seqNum += uint32(size) - } - return ret - } - - // Receive three packets. - sizes := []int{maxPayload, maxPayload, writeSize - 2*maxPayload} - first := receivePackets(c, sizes, 0, uint32(c.IRS)+1) - - // Send "packet too big" messages back to netstack. - const newMTU = 1200 - const newMaxPayload = newMTU - header.IPv4MinimumSize - header.TCPMinimumSize - mtu := []byte{0, 0, newMTU / 256, newMTU % 256} - c.SendICMPPacket(header.ICMPv4DstUnreachable, header.ICMPv4FragmentationNeeded, mtu, first, newMTU) - - // See retransmitted packets. None exceeding the new max. - sizes = []int{newMaxPayload, maxPayload - newMaxPayload, newMaxPayload, maxPayload - newMaxPayload, writeSize - 2*maxPayload} - receivePackets(c, sizes, -1, uint32(c.IRS)+1) -} - -func TestTCPEndpointProbe(t *testing.T) { - c := context.New(t, 1500) - defer c.Cleanup() - - invoked := make(chan struct{}) - c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) { - // Validate that the endpoint ID is what we expect. - // - // We don't do an extensive validation of every field but a - // basic sanity test. - if got, want := state.ID.LocalAddress, tcpip.Address(context.StackAddr); got != want { - t.Fatalf("got LocalAddress: %q, want: %q", got, want) - } - if got, want := state.ID.LocalPort, c.Port; got != want { - t.Fatalf("got LocalPort: %d, want: %d", got, want) - } - if got, want := state.ID.RemoteAddress, tcpip.Address(context.TestAddr); got != want { - t.Fatalf("got RemoteAddress: %q, want: %q", got, want) - } - if got, want := state.ID.RemotePort, uint16(context.TestPort); got != want { - t.Fatalf("got RemotePort: %d, want: %d", got, want) - } - - invoked <- struct{}{} - }) - - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - - data := []byte{1, 2, 3} - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - c.SendPacket(data, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - select { - case <-invoked: - case <-time.After(100 * time.Millisecond): - t.Fatalf("TCP Probe function was not called") - } -} - -func TestStackSetCongestionControl(t *testing.T) { - testCases := []struct { - cc tcpip.CongestionControlOption - err tcpip.Error - }{ - {"reno", nil}, - {"cubic", nil}, - {"blahblah", &tcpip.ErrNoSuchFile{}}, - } - - for _, tc := range testCases { - t.Run(fmt.Sprintf("SetTransportProtocolOption(.., %v)", tc.cc), func(t *testing.T) { - c := context.New(t, 1500) - defer c.Cleanup() - - s := c.Stack() - - var oldCC tcpip.CongestionControlOption - if err := s.TransportProtocolOption(tcp.ProtocolNumber, &oldCC); err != nil { - t.Fatalf("s.TransportProtocolOption(%v, %v) = %s", tcp.ProtocolNumber, &oldCC, err) - } - - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &tc.cc); err != tc.err { - t.Fatalf("s.SetTransportProtocolOption(%d, &%T(%s)) = %s, want = %s", tcp.ProtocolNumber, tc.cc, tc.cc, err, tc.err) - } - - var cc tcpip.CongestionControlOption - if err := s.TransportProtocolOption(tcp.ProtocolNumber, &cc); err != nil { - t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &cc, err) - } - - got, want := cc, oldCC - // If SetTransportProtocolOption is expected to succeed - // then the returned value for congestion control should - // match the one specified in the - // SetTransportProtocolOption call above, else it should - // be what it was before the call to - // SetTransportProtocolOption. - if tc.err == nil { - want = tc.cc - } - if got != want { - t.Fatalf("got congestion control: %v, want: %v", got, want) - } - }) - } -} - -func TestStackAvailableCongestionControl(t *testing.T) { - c := context.New(t, 1500) - defer c.Cleanup() - - s := c.Stack() - - // Query permitted congestion control algorithms. - var aCC tcpip.TCPAvailableCongestionControlOption - if err := s.TransportProtocolOption(tcp.ProtocolNumber, &aCC); err != nil { - t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &aCC, err) - } - if got, want := aCC, tcpip.TCPAvailableCongestionControlOption("reno cubic"); got != want { - t.Fatalf("got tcpip.TCPAvailableCongestionControlOption: %v, want: %v", got, want) - } -} - -func TestStackSetAvailableCongestionControl(t *testing.T) { - c := context.New(t, 1500) - defer c.Cleanup() - - s := c.Stack() - - // Setting AvailableCongestionControlOption should fail. - aCC := tcpip.TCPAvailableCongestionControlOption("xyz") - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &aCC); err == nil { - t.Fatalf("s.SetTransportProtocolOption(%d, &%T(%s)) = nil, want non-nil", tcp.ProtocolNumber, aCC, aCC) - } - - // Verify that we still get the expected list of congestion control options. - var cc tcpip.TCPAvailableCongestionControlOption - if err := s.TransportProtocolOption(tcp.ProtocolNumber, &cc); err != nil { - t.Fatalf("s.TransportProtocolOptio(%d, &%T(%s)): %s", tcp.ProtocolNumber, cc, cc, err) - } - if got, want := cc, tcpip.TCPAvailableCongestionControlOption("reno cubic"); got != want { - t.Fatalf("got tcpip.TCPAvailableCongestionControlOption = %s, want = %s", got, want) - } -} - -func TestEndpointSetCongestionControl(t *testing.T) { - testCases := []struct { - cc tcpip.CongestionControlOption - err tcpip.Error - }{ - {"reno", nil}, - {"cubic", nil}, - {"blahblah", &tcpip.ErrNoSuchFile{}}, - } - - for _, connected := range []bool{false, true} { - for _, tc := range testCases { - t.Run(fmt.Sprintf("SetSockOpt(.., %v) w/ connected = %v", tc.cc, connected), func(t *testing.T) { - c := context.New(t, 1500) - defer c.Cleanup() - - // Create TCP endpoint. - var err tcpip.Error - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - - var oldCC tcpip.CongestionControlOption - if err := c.EP.GetSockOpt(&oldCC); err != nil { - t.Fatalf("c.EP.GetSockOpt(&%T) = %s", oldCC, err) - } - - if connected { - c.Connect(context.TestInitialSequenceNumber, 32768 /* rcvWnd */, nil) - } - - if err := c.EP.SetSockOpt(&tc.cc); err != tc.err { - t.Fatalf("got c.EP.SetSockOpt(&%#v) = %s, want %s", tc.cc, err, tc.err) - } - - var cc tcpip.CongestionControlOption - if err := c.EP.GetSockOpt(&cc); err != nil { - t.Fatalf("c.EP.GetSockOpt(&%T): %s", cc, err) - } - - got, want := cc, oldCC - // If SetSockOpt is expected to succeed then the - // returned value for congestion control should match - // the one specified in the SetSockOpt above, else it - // should be what it was before the call to SetSockOpt. - if tc.err == nil { - want = tc.cc - } - if got != want { - t.Fatalf("got congestion control = %+v, want = %+v", got, want) - } - }) - } - } -} - -func enableCUBIC(t *testing.T, c *context.Context) { - t.Helper() - opt := tcpip.CongestionControlOption("cubic") - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%s)) %s", tcp.ProtocolNumber, opt, opt, err) - } -} - -func TestKeepalive(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - - const keepAliveIdle = 100 * time.Millisecond - const keepAliveInterval = 3 * time.Second - keepAliveIdleOpt := tcpip.KeepaliveIdleOption(keepAliveIdle) - if err := c.EP.SetSockOpt(&keepAliveIdleOpt); err != nil { - t.Fatalf("c.EP.SetSockOpt(&%T(%s)): %s", keepAliveIdleOpt, keepAliveIdle, err) - } - keepAliveIntervalOpt := tcpip.KeepaliveIntervalOption(keepAliveInterval) - if err := c.EP.SetSockOpt(&keepAliveIntervalOpt); err != nil { - t.Fatalf("c.EP.SetSockOpt(&%T(%s)): %s", keepAliveIntervalOpt, keepAliveInterval, err) - } - c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 5) - if err := c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 5); err != nil { - t.Fatalf("c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 5): %s", err) - } - c.EP.SocketOptions().SetKeepAlive(true) - - // 5 unacked keepalives are sent. ACK each one, and check that the - // connection stays alive after 5. - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - for i := 0; i < 10; i++ { - b := c.GetPacket() - checker.IPv4(t, b, - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlags(header.TCPFlagAck), - ), - ) - - // Acknowledge the keepalive. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS, - RcvWnd: 30000, - }) - } - - // Check that the connection is still alive. - ept := endpointTester{c.EP} - ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) - - // Send some data and wait before ACKing it. Keepalives should be disabled - // during this period. - view := make([]byte, 3) - var r bytes.Reader - r.Reset(view) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - next := uint32(c.IRS) + 1 - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(len(view)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(next), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), - ), - ) - - // Wait for the packet to be retransmitted. Verify that no keepalives - // were sent. - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(len(view)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(next), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagPsh), - ), - ) - c.CheckNoPacket("Keepalive packet received while unACKed data is pending") - - next += uint32(len(view)) - - // Send ACK. Keepalives should start sending again. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: seqnum.Value(next), - RcvWnd: 30000, - }) - - // Now receive 5 keepalives, but don't ACK them. The connection - // should be reset after 5. - for i := 0; i < 5; i++ { - b := c.GetPacket() - checker.IPv4(t, b, - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(next-1), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlags(header.TCPFlagAck), - ), - ) - } - - // Sleep for a litte over the KeepAlive interval to make sure - // the timer has time to fire after the last ACK and close the - // close the socket. - time.Sleep(keepAliveInterval + keepAliveInterval/2) - - // The connection should be terminated after 5 unacked keepalives. - // Send an ACK to trigger a RST from the stack as the endpoint should - // be dead. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: seqnum.Value(next), - RcvWnd: 30000, - }) - - checker.IPv4(t, c.GetPacket(), - checker.TCP(checker.DstPort(context.TestPort), checker.TCPSeqNum(next), checker.TCPAckNum(uint32(0)), checker.TCPFlags(header.TCPFlagRst)), - ) - - if got := c.Stack().Stats().TCP.EstablishedTimedout.Value(); got != 1 { - t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout.Value() = %d, want = 1", got) - } - - ept.CheckReadError(t, &tcpip.ErrTimeout{}) - - if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { - t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got) - } - if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 { - t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got) - } -} - -func executeHandshake(t *testing.T, c *context.Context, srcPort uint16, synCookieInUse bool) (irs, iss seqnum.Value) { - t.Helper() - // Send a SYN request. - irs = seqnum.Value(context.TestInitialSequenceNumber) - c.SendPacket(nil, &context.Headers{ - SrcPort: srcPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: irs, - RcvWnd: 30000, - }) - - // Receive the SYN-ACK reply. - b := c.GetPacket() - tcp := header.TCP(header.IPv4(b).Payload()) - iss = seqnum.Value(tcp.SequenceNumber()) - tcpCheckers := []checker.TransportChecker{ - checker.SrcPort(context.StackPort), - checker.DstPort(srcPort), - checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn), - checker.TCPAckNum(uint32(irs) + 1), - } - - if synCookieInUse { - // When cookies are in use window scaling is disabled. - tcpCheckers = append(tcpCheckers, checker.TCPSynOptions(header.TCPSynOptions{ - WS: -1, - MSS: c.MSSWithoutOptions(), - })) - } - - checker.IPv4(t, b, checker.TCP(tcpCheckers...)) - - // Send ACK. - c.SendPacket(nil, &context.Headers{ - SrcPort: srcPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: irs + 1, - AckNum: iss + 1, - RcvWnd: 30000, - }) - return irs, iss -} - -func executeV6Handshake(t *testing.T, c *context.Context, srcPort uint16, synCookieInUse bool) (irs, iss seqnum.Value) { - t.Helper() - // Send a SYN request. - irs = seqnum.Value(context.TestInitialSequenceNumber) - c.SendV6Packet(nil, &context.Headers{ - SrcPort: srcPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: irs, - RcvWnd: 30000, - }) - - // Receive the SYN-ACK reply. - b := c.GetV6Packet() - tcp := header.TCP(header.IPv6(b).Payload()) - iss = seqnum.Value(tcp.SequenceNumber()) - tcpCheckers := []checker.TransportChecker{ - checker.SrcPort(context.StackPort), - checker.DstPort(srcPort), - checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn), - checker.TCPAckNum(uint32(irs) + 1), - } - - if synCookieInUse { - // When cookies are in use window scaling is disabled. - tcpCheckers = append(tcpCheckers, checker.TCPSynOptions(header.TCPSynOptions{ - WS: -1, - MSS: c.MSSWithoutOptionsV6(), - })) - } - - checker.IPv6(t, b, checker.TCP(tcpCheckers...)) - - // Send ACK. - c.SendV6Packet(nil, &context.Headers{ - SrcPort: srcPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: irs + 1, - AckNum: iss + 1, - RcvWnd: 30000, - }) - return irs, iss -} - -// TestListenBacklogFull tests that netstack does not complete handshakes if the -// listen backlog for the endpoint is full. -func TestListenBacklogFull(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Create TCP endpoint. - var err tcpip.Error - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - - // Bind to wildcard. - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - - // Test acceptance. - // Start listening. - listenBacklog := 10 - if err := c.EP.Listen(listenBacklog); err != nil { - t.Fatalf("Listen failed: %s", err) - } - - lastPortOffset := uint16(0) - for ; int(lastPortOffset) < listenBacklog; lastPortOffset++ { - executeHandshake(t, c, context.TestPort+lastPortOffset, false /*synCookieInUse */) - } - - time.Sleep(50 * time.Millisecond) - - // Now execute send one more SYN. The stack should not respond as the backlog - // is full at this point. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort + lastPortOffset, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: seqnum.Value(context.TestInitialSequenceNumber), - RcvWnd: 30000, - }) - c.CheckNoPacketTimeout("unexpected packet received", 50*time.Millisecond) - - // Try to accept the connections in the backlog. - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.ReadableEvents) - defer c.WQ.EventUnregister(&we) - - for i := 0; i < listenBacklog; i++ { - _, _, err = c.EP.Accept(nil) - if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { - // Wait for connection to be established. - select { - case <-ch: - _, _, err = c.EP.Accept(nil) - if err != nil { - t.Fatalf("Accept failed: %s", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - } - - // Now verify that there are no more connections that can be accepted. - _, _, err = c.EP.Accept(nil) - if !cmp.Equal(&tcpip.ErrWouldBlock{}, err) { - select { - case <-ch: - t.Fatalf("unexpected endpoint delivered on Accept: %+v", c.EP) - case <-time.After(1 * time.Second): - } - } - - // Now a new handshake must succeed. - executeHandshake(t, c, context.TestPort+lastPortOffset, false /*synCookieInUse */) - - newEP, _, err := c.EP.Accept(nil) - if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { - // Wait for connection to be established. - select { - case <-ch: - newEP, _, err = c.EP.Accept(nil) - if err != nil { - t.Fatalf("Accept failed: %s", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - - // Now verify that the TCP socket is usable and in a connected state. - data := "Don't panic" - var r strings.Reader - r.Reset(data) - newEP.Write(&r, tcpip.WriteOptions{}) - b := c.GetPacket() - tcp := header.TCP(header.IPv4(b).Payload()) - if string(tcp.Payload()) != data { - t.Fatalf("unexpected data: got %s, want %s", string(tcp.Payload()), data) - } -} - -// TestListenNoAcceptMulticastBroadcastV4 makes sure that TCP segments with a -// non unicast IPv4 address are not accepted. -func TestListenNoAcceptNonUnicastV4(t *testing.T) { - multicastAddr := tcpiptestutil.MustParse4("224.0.1.2") - otherMulticastAddr := tcpiptestutil.MustParse4("224.0.1.3") - subnet := context.StackAddrWithPrefix.Subnet() - subnetBroadcastAddr := subnet.Broadcast() - - tests := []struct { - name string - srcAddr tcpip.Address - dstAddr tcpip.Address - }{ - { - name: "SourceUnspecified", - srcAddr: header.IPv4Any, - dstAddr: context.StackAddr, - }, - { - name: "SourceBroadcast", - srcAddr: header.IPv4Broadcast, - dstAddr: context.StackAddr, - }, - { - name: "SourceOurMulticast", - srcAddr: multicastAddr, - dstAddr: context.StackAddr, - }, - { - name: "SourceOtherMulticast", - srcAddr: otherMulticastAddr, - dstAddr: context.StackAddr, - }, - { - name: "DestUnspecified", - srcAddr: context.TestAddr, - dstAddr: header.IPv4Any, - }, - { - name: "DestBroadcast", - srcAddr: context.TestAddr, - dstAddr: header.IPv4Broadcast, - }, - { - name: "DestOurMulticast", - srcAddr: context.TestAddr, - dstAddr: multicastAddr, - }, - { - name: "DestOtherMulticast", - srcAddr: context.TestAddr, - dstAddr: otherMulticastAddr, - }, - { - name: "SrcSubnetBroadcast", - srcAddr: subnetBroadcastAddr, - dstAddr: context.StackAddr, - }, - { - name: "DestSubnetBroadcast", - srcAddr: context.TestAddr, - dstAddr: subnetBroadcastAddr, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.Create(-1) - - if err := c.Stack().JoinGroup(header.IPv4ProtocolNumber, 1, multicastAddr); err != nil { - t.Fatalf("JoinGroup failed: %s", err) - } - - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - - if err := c.EP.Listen(1); err != nil { - t.Fatalf("Listen failed: %s", err) - } - - irs := seqnum.Value(context.TestInitialSequenceNumber) - c.SendPacketWithAddrs(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: irs, - RcvWnd: 30000, - }, test.srcAddr, test.dstAddr) - c.CheckNoPacket("Should not have received a response") - - // Handle normal packet. - c.SendPacketWithAddrs(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: irs, - RcvWnd: 30000, - }, context.TestAddr, context.StackAddr) - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn), - checker.TCPAckNum(uint32(irs)+1))) - }) - } -} - -// TestListenNoAcceptMulticastBroadcastV6 makes sure that TCP segments with a -// non unicast IPv6 address are not accepted. -func TestListenNoAcceptNonUnicastV6(t *testing.T) { - multicastAddr := tcpiptestutil.MustParse6("ff0e::101") - otherMulticastAddr := tcpiptestutil.MustParse6("ff0e::102") - - tests := []struct { - name string - srcAddr tcpip.Address - dstAddr tcpip.Address - }{ - { - "SourceUnspecified", - header.IPv6Any, - context.StackV6Addr, - }, - { - "SourceAllNodes", - header.IPv6AllNodesMulticastAddress, - context.StackV6Addr, - }, - { - "SourceOurMulticast", - multicastAddr, - context.StackV6Addr, - }, - { - "SourceOtherMulticast", - otherMulticastAddr, - context.StackV6Addr, - }, - { - "DestUnspecified", - context.TestV6Addr, - header.IPv6Any, - }, - { - "DestAllNodes", - context.TestV6Addr, - header.IPv6AllNodesMulticastAddress, - }, - { - "DestOurMulticast", - context.TestV6Addr, - multicastAddr, - }, - { - "DestOtherMulticast", - context.TestV6Addr, - otherMulticastAddr, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(true) - - if err := c.Stack().JoinGroup(header.IPv6ProtocolNumber, 1, multicastAddr); err != nil { - t.Fatalf("JoinGroup failed: %s", err) - } - - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - - if err := c.EP.Listen(1); err != nil { - t.Fatalf("Listen failed: %s", err) - } - - irs := seqnum.Value(context.TestInitialSequenceNumber) - c.SendV6PacketWithAddrs(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: irs, - RcvWnd: 30000, - }, test.srcAddr, test.dstAddr) - c.CheckNoPacket("Should not have received a response") - - // Handle normal packet. - c.SendV6PacketWithAddrs(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: irs, - RcvWnd: 30000, - }, context.TestV6Addr, context.StackV6Addr) - checker.IPv6(t, c.GetV6Packet(), - checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn), - checker.TCPAckNum(uint32(irs)+1))) - }) - } -} - -func TestListenSynRcvdQueueFull(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Create TCP endpoint. - var err tcpip.Error - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - - // Bind to wildcard. - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - - // Test acceptance. - if err := c.EP.Listen(1); err != nil { - t.Fatalf("Listen failed: %s", err) - } - - // Send two SYN's the first one should get a SYN-ACK, the - // second one should not get any response and is dropped as - // the accept queue is full. - irs := seqnum.Value(context.TestInitialSequenceNumber) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: irs, - RcvWnd: 30000, - }) - - // Receive the SYN-ACK reply. - b := c.GetPacket() - tcp := header.TCP(header.IPv4(b).Payload()) - iss := seqnum.Value(tcp.SequenceNumber()) - tcpCheckers := []checker.TransportChecker{ - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn), - checker.TCPAckNum(uint32(irs) + 1), - } - checker.IPv4(t, b, checker.TCP(tcpCheckers...)) - - // Now complete the previous connection. - // Send ACK. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: irs + 1, - AckNum: iss + 1, - RcvWnd: 30000, - }) - - // Verify if that is delivered to the accept queue. - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.ReadableEvents) - defer c.WQ.EventUnregister(&we) - <-ch - - // Now execute send one more SYN. The stack should not respond as the backlog - // is full at this point. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort + 1, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: seqnum.Value(889), - RcvWnd: 30000, - }) - c.CheckNoPacketTimeout("unexpected packet received", 50*time.Millisecond) - - // Try to accept the connections in the backlog. - newEP, _, err := c.EP.Accept(nil) - if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { - // Wait for connection to be established. - select { - case <-ch: - newEP, _, err = c.EP.Accept(nil) - if err != nil { - t.Fatalf("Accept failed: %s", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - - // Now verify that the TCP socket is usable and in a connected state. - data := "Don't panic" - var r strings.Reader - r.Reset(data) - newEP.Write(&r, tcpip.WriteOptions{}) - pkt := c.GetPacket() - tcp = header.IPv4(pkt).Payload() - if string(tcp.Payload()) != data { - t.Fatalf("unexpected data: got %s, want %s", string(tcp.Payload()), data) - } -} - -func TestListenBacklogFullSynCookieInUse(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Create TCP endpoint. - var err tcpip.Error - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - - // Bind to wildcard. - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - - // Test for SynCookies usage after filling up the backlog. - if err := c.EP.Listen(1); err != nil { - t.Fatalf("Listen failed: %s", err) - } - - executeHandshake(t, c, context.TestPort, false) - - // Wait for this to be delivered to the accept queue. - time.Sleep(50 * time.Millisecond) - - // Send a SYN request. - irs := seqnum.Value(context.TestInitialSequenceNumber) - c.SendPacket(nil, &context.Headers{ - // pick a different src port for new SYN. - SrcPort: context.TestPort + 1, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: irs, - RcvWnd: 30000, - }) - // The Syn should be dropped as the endpoint's backlog is full. - c.CheckNoPacketTimeout("unexpected packet received", 50*time.Millisecond) - - // Verify that there is only one acceptable connection at this point. - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.ReadableEvents) - defer c.WQ.EventUnregister(&we) - - _, _, err = c.EP.Accept(nil) - if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { - // Wait for connection to be established. - select { - case <-ch: - _, _, err = c.EP.Accept(nil) - if err != nil { - t.Fatalf("Accept failed: %s", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - - // Now verify that there are no more connections that can be accepted. - _, _, err = c.EP.Accept(nil) - if !cmp.Equal(&tcpip.ErrWouldBlock{}, err) { - select { - case <-ch: - t.Fatalf("unexpected endpoint delivered on Accept: %+v", c.EP) - case <-time.After(1 * time.Second): - } - } -} - -func TestSYNRetransmit(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Create TCP endpoint. - var err tcpip.Error - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - - // Bind to wildcard. - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - - // Start listening. - if err := c.EP.Listen(10); err != nil { - t.Fatalf("Listen failed: %s", err) - } - - // Send the same SYN packet multiple times. We should still get a valid SYN-ACK - // reply. - irs := seqnum.Value(context.TestInitialSequenceNumber) - for i := 0; i < 5; i++ { - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: irs, - RcvWnd: 30000, - }) - } - - // Receive the SYN-ACK reply. - tcpCheckers := []checker.TransportChecker{ - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn), - checker.TCPAckNum(uint32(irs) + 1), - } - checker.IPv4(t, c.GetPacket(), checker.TCP(tcpCheckers...)) -} - -func TestSynRcvdBadSeqNumber(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Create TCP endpoint. - var err tcpip.Error - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - - // Bind to wildcard. - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - - // Start listening. - if err := c.EP.Listen(10); err != nil { - t.Fatalf("Listen failed: %s", err) - } - - // Send a SYN to get a SYN-ACK. This should put the ep into SYN-RCVD state - irs := seqnum.Value(context.TestInitialSequenceNumber) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: irs, - RcvWnd: 30000, - }) - - // Receive the SYN-ACK reply. - b := c.GetPacket() - tcpHdr := header.TCP(header.IPv4(b).Payload()) - iss := seqnum.Value(tcpHdr.SequenceNumber()) - tcpCheckers := []checker.TransportChecker{ - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn), - checker.TCPAckNum(uint32(irs) + 1), - } - checker.IPv4(t, b, checker.TCP(tcpCheckers...)) - - // Now send a packet with an out-of-window sequence number - largeSeqnum := irs + seqnum.Value(tcpHdr.WindowSize()) + 1 - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: largeSeqnum, - AckNum: iss + 1, - RcvWnd: 30000, - }) - - // Should receive an ACK with the expected SEQ number - b = c.GetPacket() - tcpCheckers = []checker.TransportChecker{ - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck), - checker.TCPAckNum(uint32(irs) + 1), - checker.TCPSeqNum(uint32(iss + 1)), - } - checker.IPv4(t, b, checker.TCP(tcpCheckers...)) - - // Now that the socket replied appropriately with the ACK, - // complete the connection to test that the large SEQ num - // did not change the state from SYN-RCVD. - - // Get setup to be notified about connection establishment. - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.ReadableEvents) - defer c.WQ.EventUnregister(&we) - - // Send ACK to move to ESTABLISHED state. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: irs + 1, - AckNum: iss + 1, - RcvWnd: 30000, - }) - - <-ch - newEP, _, err := c.EP.Accept(nil) - if err != nil { - t.Fatalf("Accept failed: %s", err) - } - - // Now verify that the TCP socket is usable and in a connected state. - data := "Don't panic" - var r strings.Reader - r.Reset(data) - if _, err := newEP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - pkt := c.GetPacket() - tcpHdr = header.IPv4(pkt).Payload() - if string(tcpHdr.Payload()) != data { - t.Fatalf("unexpected data: got %s, want %s", string(tcpHdr.Payload()), data) - } -} - -func TestPassiveConnectionAttemptIncrement(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - c.EP = ep - if err := ep.Bind(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - if got, want := tcp.EndpointState(ep.State()), tcp.StateBound; got != want { - t.Errorf("unexpected endpoint state: want %s, got %s", want, got) - } - if err := c.EP.Listen(1); err != nil { - t.Fatalf("Listen failed: %s", err) - } - if got, want := tcp.EndpointState(c.EP.State()), tcp.StateListen; got != want { - t.Errorf("unexpected endpoint state: want %s, got %s", want, got) - } - - stats := c.Stack().Stats() - want := stats.TCP.PassiveConnectionOpenings.Value() + 1 - - srcPort := uint16(context.TestPort) - executeHandshake(t, c, srcPort+1, false) - - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.ReadableEvents) - defer c.WQ.EventUnregister(&we) - - // Verify that there is only one acceptable connection at this point. - _, _, err = c.EP.Accept(nil) - if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { - // Wait for connection to be established. - select { - case <-ch: - _, _, err = c.EP.Accept(nil) - if err != nil { - t.Fatalf("Accept failed: %s", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - - if got := stats.TCP.PassiveConnectionOpenings.Value(); got != want { - t.Errorf("got stats.TCP.PassiveConnectionOpenings.Value() = %d, want = %d", got, want) - } -} - -func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - stats := c.Stack().Stats() - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - c.EP = ep - if err := c.EP.Bind(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - if err := c.EP.Listen(1); err != nil { - t.Fatalf("Listen failed: %s", err) - } - - srcPort := uint16(context.TestPort) - // Now attempt a handshakes it will fill up the accept backlog. - executeHandshake(t, c, srcPort, false) - - // Give time for the final ACK to be processed as otherwise the next handshake could - // get accepted before the previous one based on goroutine scheduling. - time.Sleep(50 * time.Millisecond) - - want := stats.TCP.ListenOverflowSynDrop.Value() + 1 - - // Now we will send one more SYN and this one should get dropped - // Send a SYN request. - c.SendPacket(nil, &context.Headers{ - SrcPort: srcPort + 2, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: seqnum.Value(context.TestInitialSequenceNumber), - RcvWnd: 30000, - }) - - checkValid := func() []error { - var errors []error - if got := stats.TCP.ListenOverflowSynDrop.Value(); got != want { - errors = append(errors, fmt.Errorf("got stats.TCP.ListenOverflowSynDrop.Value() = %d, want = %d", got, want)) - } - if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.ListenOverflowSynDrop.Value(); got != want { - errors = append(errors, fmt.Errorf("got EP stats Stats.ReceiveErrors.ListenOverflowSynDrop = %d, want = %d", got, want)) - } - return errors - } - - start := time.Now() - for time.Since(start) < time.Minute && len(checkValid()) > 0 { - time.Sleep(50 * time.Millisecond) - } - for _, err := range checkValid() { - t.Error(err) - } - if t.Failed() { - t.FailNow() - } - - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.ReadableEvents) - defer c.WQ.EventUnregister(&we) - - // Now check that there is one acceptable connections. - _, _, err = c.EP.Accept(nil) - if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { - // Wait for connection to be established. - <-ch - _, _, err = c.EP.Accept(nil) - if err != nil { - t.Fatalf("Accept failed: %s", err) - } - } -} - -func TestListenDropIncrement(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - stats := c.Stack().Stats() - c.Create(-1 /*epRcvBuf*/) - - if err := c.EP.Bind(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - if err := c.EP.Listen(1 /*backlog*/); err != nil { - t.Fatalf("Listen failed: %s", err) - } - - initialDropped := stats.DroppedPackets.Value() - - // Send RST, FIN segments, that are expected to be dropped by the listener. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagRst, - }) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagFin, - }) - - // To ensure that the RST, FIN sent earlier are indeed received and ignored - // by the listener, send a SYN and wait for the SYN to be ACKd. - irs := seqnum.Value(context.TestInitialSequenceNumber) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: irs, - }) - checker.IPv4(t, c.GetPacket(), checker.TCP(checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn), - checker.TCPAckNum(uint32(irs)+1), - )) - - if got, want := stats.DroppedPackets.Value(), initialDropped+2; got != want { - t.Fatalf("got stats.DroppedPackets.Value() = %d, want = %d", got, want) - } -} - -func TestEndpointBindListenAcceptState(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - wq := &waiter.Queue{} - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - - if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - if got, want := tcp.EndpointState(ep.State()), tcp.StateBound; got != want { - t.Errorf("unexpected endpoint state: want %s, got %s", want, got) - } - - ept := endpointTester{ep} - ept.CheckReadError(t, &tcpip.ErrNotConnected{}) - if got := ep.Stats().(*tcp.Stats).ReadErrors.NotConnected.Value(); got != 1 { - t.Errorf("got EP stats Stats.ReadErrors.NotConnected got %d want %d", got, 1) - } - - if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %s", err) - } - if got, want := tcp.EndpointState(ep.State()), tcp.StateListen; got != want { - t.Errorf("unexpected endpoint state: want %s, got %s", want, got) - } - - c.PassiveConnectWithOptions(100, 5, header.TCPSynOptions{MSS: defaultIPv4MSS}) - - // Try to accept the connection. - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.ReadableEvents) - defer wq.EventUnregister(&we) - - aep, _, err := ep.Accept(nil) - if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { - // Wait for connection to be established. - select { - case <-ch: - aep, _, err = ep.Accept(nil) - if err != nil { - t.Fatalf("Accept failed: %s", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - if got, want := tcp.EndpointState(aep.State()), tcp.StateEstablished; got != want { - t.Errorf("unexpected endpoint state: want %s, got %s", want, got) - } - { - err := aep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) - if d := cmp.Diff(&tcpip.ErrAlreadyConnected{}, err); d != "" { - t.Errorf("Connect(...) mismatch (-want +got):\n%s", d) - } - } - // Listening endpoint remains in listen state. - if got, want := tcp.EndpointState(ep.State()), tcp.StateListen; got != want { - t.Errorf("unexpected endpoint state: want %s, got %s", want, got) - } - - ep.Close() - // Give worker goroutines time to receive the close notification. - time.Sleep(1 * time.Second) - if got, want := tcp.EndpointState(ep.State()), tcp.StateClose; got != want { - t.Errorf("unexpected endpoint state: want %s, got %s", want, got) - } - // Accepted endpoint remains open when the listen endpoint is closed. - if got, want := tcp.EndpointState(aep.State()), tcp.StateEstablished; got != want { - t.Errorf("unexpected endpoint state: want %s, got %s", want, got) - } - -} - -// This test verifies that the auto tuning does not grow the receive buffer if -// the application is not reading the data actively. -func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) { - const mtu = 1500 - const mss = mtu - header.IPv4MinimumSize - header.TCPMinimumSize - - c := context.New(t, mtu) - defer c.Cleanup() - - stk := c.Stack() - // Set lower limits for auto-tuning tests. This is required because the - // test stops the worker which can cause packets to be dropped because - // the segment queue holding unprocessed packets is limited to 500. - const receiveBufferSize = 80 << 10 // 80KB. - const maxReceiveBufferSize = receiveBufferSize * 10 - { - opt := tcpip.TCPReceiveBufferSizeRangeOption{Min: 1, Default: receiveBufferSize, Max: maxReceiveBufferSize} - if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err) - } - } - - // Enable auto-tuning. - { - opt := tcpip.TCPModerateReceiveBufferOption(true) - if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err) - } - } - // Change the expected window scale to match the value needed for the - // maximum buffer size defined above. - c.WindowScale = uint8(tcp.FindWndScale(maxReceiveBufferSize)) - - rawEP := c.CreateConnectedWithOptions(header.TCPSynOptions{TS: true, WS: 4}) - - // NOTE: The timestamp values in the sent packets are meaningless to the - // peer so we just increment the timestamp value by 1 every batch as we - // are not really using them for anything. Send a single byte to verify - // the advertised window. - tsVal := rawEP.TSVal + 1 - - // Introduce a 25ms latency by delaying the first byte. - latency := 25 * time.Millisecond - time.Sleep(latency) - // Send an initial payload with atleast segment overhead size. The receive - // window would not grow for smaller segments. - rawEP.SendPacketWithTS(make([]byte, tcp.SegSize), tsVal) - - pkt := rawEP.VerifyAndReturnACKWithTS(tsVal) - rcvWnd := header.TCP(header.IPv4(pkt).Payload()).WindowSize() - - time.Sleep(25 * time.Millisecond) - - // Allocate a large enough payload for the test. - payloadSize := receiveBufferSize * 2 - b := make([]byte, payloadSize) - - worker := (c.EP).(interface { - StopWork() - ResumeWork() - }) - tsVal++ - - // Stop the worker goroutine. - worker.StopWork() - start := 0 - end := payloadSize / 2 - packetsSent := 0 - for ; start < end; start += mss { - packetEnd := start + mss - if start+mss > end { - packetEnd = end - } - rawEP.SendPacketWithTS(b[start:packetEnd], tsVal) - packetsSent++ - } - - // Resume the worker so that it only sees the packets once all of them - // are waiting to be read. - worker.ResumeWork() - - // Since we sent almost the full receive buffer worth of data (some may have - // been dropped due to segment overheads), we should get a zero window back. - pkt = c.GetPacket() - tcpHdr := header.TCP(header.IPv4(pkt).Payload()) - gotRcvWnd := tcpHdr.WindowSize() - wantAckNum := tcpHdr.AckNumber() - if got, want := int(gotRcvWnd), 0; got != want { - t.Fatalf("got rcvWnd: %d, want: %d", got, want) - } - - time.Sleep(25 * time.Millisecond) - // Verify that sending more data when receiveBuffer is exhausted. - rawEP.SendPacketWithTS(b[start:start+mss], tsVal) - - // Now read all the data from the endpoint and verify that advertised - // window increases to the full available buffer size. - for { - _, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{}) - if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { - break - } - } - - // Verify that we receive a non-zero window update ACK. When running - // under thread santizer this test can end up sending more than 1 - // ack, 1 for the non-zero window - p := c.GetPacket() - checker.IPv4(t, p, checker.TCP( - checker.TCPAckNum(wantAckNum), - func(t *testing.T, h header.Transport) { - tcp, ok := h.(header.TCP) - if !ok { - return - } - // We use 10% here as the error margin upwards as the initial window we - // got was afer 1 segment was already in the receive buffer queue. - tolerance := 1.1 - if w := tcp.WindowSize(); w == 0 || w > uint16(float64(rcvWnd)*tolerance) { - t.Errorf("expected a non-zero window: got %d, want <= %d", w, uint16(float64(rcvWnd)*tolerance)) - } - }, - )) -} - -// This test verifies that the advertised window is auto-tuned up as the -// application is reading the data that is being received. -func TestReceiveBufferAutoTuning(t *testing.T) { - const mtu = 1500 - const mss = mtu - header.IPv4MinimumSize - header.TCPMinimumSize - - c := context.New(t, mtu) - defer c.Cleanup() - - // Enable Auto-tuning. - stk := c.Stack() - // Disable out of window rate limiting for this test by setting it to 0 as we - // use out of window ACKs to measure the advertised window. - var tcpInvalidRateLimit stack.TCPInvalidRateLimitOption - if err := stk.SetOption(tcpInvalidRateLimit); err != nil { - t.Fatalf("e.stack.SetOption(%#v) = %s", tcpInvalidRateLimit, err) - } - - const receiveBufferSize = 80 << 10 // 80KB. - const maxReceiveBufferSize = receiveBufferSize * 10 - { - opt := tcpip.TCPReceiveBufferSizeRangeOption{Min: 1, Default: receiveBufferSize, Max: maxReceiveBufferSize} - if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err) - } - } - - // Enable auto-tuning. - { - opt := tcpip.TCPModerateReceiveBufferOption(true) - if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err) - } - } - // Change the expected window scale to match the value needed for the - // maximum buffer size used by stack. - c.WindowScale = uint8(tcp.FindWndScale(maxReceiveBufferSize)) - - rawEP := c.CreateConnectedWithOptions(header.TCPSynOptions{TS: true, WS: 4}) - tsVal := rawEP.TSVal - rawEP.NextSeqNum-- - rawEP.SendPacketWithTS(nil, tsVal) - rawEP.NextSeqNum++ - pkt := rawEP.VerifyAndReturnACKWithTS(tsVal) - curRcvWnd := int(header.TCP(header.IPv4(pkt).Payload()).WindowSize()) << c.WindowScale - scaleRcvWnd := func(rcvWnd int) uint16 { - return uint16(rcvWnd >> c.WindowScale) - } - // Allocate a large array to send to the endpoint. - b := make([]byte, receiveBufferSize*48) - - // In every iteration we will send double the number of bytes sent in - // the previous iteration and read the same from the app. The received - // window should grow by at least 2x of bytes read by the app in every - // RTT. - offset := 0 - payloadSize := receiveBufferSize / 8 - worker := (c.EP).(interface { - StopWork() - ResumeWork() - }) - latency := 1 * time.Millisecond - for i := 0; i < 5; i++ { - tsVal++ - - // Stop the worker goroutine. - worker.StopWork() - start := offset - end := offset + payloadSize - totalSent := 0 - packetsSent := 0 - for ; start < end; start += mss { - rawEP.SendPacketWithTS(b[start:start+mss], tsVal) - totalSent += mss - packetsSent++ - } - - // Resume it so that it only sees the packets once all of them - // are waiting to be read. - worker.ResumeWork() - - // Give 1ms for the worker to process the packets. - time.Sleep(1 * time.Millisecond) - - lastACK := c.GetPacket() - // Discard any intermediate ACKs and only check the last ACK we get in a - // short time period of few ms. - for { - time.Sleep(1 * time.Millisecond) - pkt := c.GetPacketNonBlocking() - if pkt == nil { - break - } - lastACK = pkt - } - if got, want := int(header.TCP(header.IPv4(lastACK).Payload()).WindowSize()), int(scaleRcvWnd(curRcvWnd)); got > want { - t.Fatalf("advertised window got: %d, want <= %d", got, want) - } - - // Now read all the data from the endpoint and invoke the - // moderation API to allow for receive buffer auto-tuning - // to happen before we measure the new window. - totalCopied := 0 - for { - res, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{}) - if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { - break - } - totalCopied += res.Count - } - - // Invoke the moderation API. This is required for auto-tuning - // to happen. This method is normally expected to be invoked - // from a higher layer than tcpip.Endpoint. So we simulate - // copying to userspace by invoking it explicitly here. - c.EP.ModerateRecvBuf(totalCopied) - - // Now send a keep-alive packet to trigger an ACK so that we can - // measure the new window. - rawEP.NextSeqNum-- - rawEP.SendPacketWithTS(nil, tsVal) - rawEP.NextSeqNum++ - - if i == 0 { - // In the first iteration the receiver based RTT is not - // yet known as a result the moderation code should not - // increase the advertised window. - rawEP.VerifyACKRcvWnd(scaleRcvWnd(curRcvWnd)) - } else { - // Read loop above could generate an ACK if the window had dropped to - // zero and then read had opened it up. - lastACK := c.GetPacket() - // Discard any intermediate ACKs and only check the last ACK we get in a - // short time period of few ms. - for { - time.Sleep(1 * time.Millisecond) - pkt := c.GetPacketNonBlocking() - if pkt == nil { - break - } - lastACK = pkt - } - curRcvWnd = int(header.TCP(header.IPv4(lastACK).Payload()).WindowSize()) << c.WindowScale - // If thew new current window is close maxReceiveBufferSize then terminate - // the loop. This can happen before all iterations are done due to timing - // differences when running the test. - if int(float64(curRcvWnd)*1.1) > maxReceiveBufferSize/2 { - break - } - // Increase the latency after first two iterations to - // establish a low RTT value in the receiver since it - // only tracks the lowest value. This ensures that when - // ModerateRcvBuf is called the elapsed time is always > - // rtt. Without this the test is flaky due to delays due - // to scheduling/wakeup etc. - latency += 50 * time.Millisecond - } - time.Sleep(latency) - offset += payloadSize - payloadSize *= 2 - } - // Check that at the end of our iterations the receive window grew close to the maximum - // permissible size of maxReceiveBufferSize/2 - if got, want := int(float64(curRcvWnd)*1.1), maxReceiveBufferSize/2; got < want { - t.Fatalf("unexpected rcvWnd got: %d, want > %d", got, want) - } - -} - -func TestDelayEnabled(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - checkDelayOption(t, c, false, false) // Delay is disabled by default. - - for _, delayEnabled := range []bool{false, true} { - t.Run(fmt.Sprintf("delayEnabled=%t", delayEnabled), func(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - opt := tcpip.TCPDelayEnabled(delayEnabled) - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, delayEnabled, err) - } - checkDelayOption(t, c, opt, delayEnabled) - }) - } -} - -func checkDelayOption(t *testing.T, c *context.Context, wantDelayEnabled tcpip.TCPDelayEnabled, wantDelayOption bool) { - t.Helper() - - var gotDelayEnabled tcpip.TCPDelayEnabled - if err := c.Stack().TransportProtocolOption(tcp.ProtocolNumber, &gotDelayEnabled); err != nil { - t.Fatalf("TransportProtocolOption(tcp, &gotDelayEnabled) failed: %s", err) - } - if gotDelayEnabled != wantDelayEnabled { - t.Errorf("TransportProtocolOption(tcp, &gotDelayEnabled) got %t, want %t", gotDelayEnabled, wantDelayEnabled) - } - - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, new(waiter.Queue)) - if err != nil { - t.Fatalf("NewEndPoint(tcp, ipv4, new(waiter.Queue)) failed: %s", err) - } - gotDelayOption := ep.SocketOptions().GetDelayOption() - if gotDelayOption != wantDelayOption { - t.Errorf("ep.GetSockOptBool(tcpip.DelayOption) got: %t, want: %t", gotDelayOption, wantDelayOption) - } -} - -func TestTCPLingerTimeout(t *testing.T) { - c := context.New(t, 1500 /* mtu */) - defer c.Cleanup() - - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - - testCases := []struct { - name string - tcpLingerTimeout time.Duration - want time.Duration - }{ - {"NegativeLingerTimeout", -123123, -1}, - // Zero is treated same as the stack's default TCP_LINGER2 timeout. - {"ZeroLingerTimeout", 0, tcp.DefaultTCPLingerTimeout}, - {"InRangeLingerTimeout", 10 * time.Second, 10 * time.Second}, - // Values > stack's TCPLingerTimeout are capped to the stack's - // value. Defaults to tcp.DefaultTCPLingerTimeout(60 seconds) - {"AboveMaxLingerTimeout", tcp.MaxTCPLingerTimeout + 5*time.Second, tcp.MaxTCPLingerTimeout}, - } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - v := tcpip.TCPLingerTimeoutOption(tc.tcpLingerTimeout) - if err := c.EP.SetSockOpt(&v); err != nil { - t.Fatalf("SetSockOpt(&%T(%s)) = %s", v, tc.tcpLingerTimeout, err) - } - - v = 0 - if err := c.EP.GetSockOpt(&v); err != nil { - t.Fatalf("GetSockOpt(&%T) = %s", v, err) - } - if got, want := time.Duration(v), tc.want; got != want { - t.Fatalf("got linger timeout = %s, want = %s", got, want) - } - }) - } -} - -func TestTCPTimeWaitRSTIgnored(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - wq := &waiter.Queue{} - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - - if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %s", err) - } - - // Send a SYN request. - iss := seqnum.Value(context.TestInitialSequenceNumber) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: iss, - RcvWnd: 30000, - }) - - // Receive the SYN-ACK reply. - b := c.GetPacket() - tcpHdr := header.TCP(header.IPv4(b).Payload()) - c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) - - ackHeaders := &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: iss + 1, - AckNum: c.IRS + 1, - } - - // Send ACK. - c.SendPacket(nil, ackHeaders) - - // Try to accept the connection. - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.ReadableEvents) - defer wq.EventUnregister(&we) - - c.EP, _, err = ep.Accept(nil) - if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { - // Wait for connection to be established. - select { - case <-ch: - c.EP, _, err = ep.Accept(nil) - if err != nil { - t.Fatalf("Accept failed: %s", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - - c.EP.Close() - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS+1)), - checker.TCPAckNum(uint32(iss)+1), - checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) - - finHeaders := &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: iss + 1, - AckNum: c.IRS + 2, - } - - c.SendPacket(nil, finHeaders) - - // Get the ACK to the FIN we just sent. - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS+2)), - checker.TCPAckNum(uint32(iss)+2), - checker.TCPFlags(header.TCPFlagAck))) - - // Now send a RST and this should be ignored and not - // generate an ACK. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagRst, - SeqNum: iss + 1, - AckNum: c.IRS + 2, - }) - - c.CheckNoPacketTimeout("unexpected packet received in TIME_WAIT state", 1*time.Second) - - // Out of order ACK should generate an immediate ACK in - // TIME_WAIT. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: iss + 1, - AckNum: c.IRS + 3, - }) - - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS+2)), - checker.TCPAckNum(uint32(iss)+2), - checker.TCPFlags(header.TCPFlagAck))) -} - -func TestTCPTimeWaitOutOfOrder(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - wq := &waiter.Queue{} - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - - if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %s", err) - } - - // Send a SYN request. - iss := seqnum.Value(context.TestInitialSequenceNumber) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: iss, - RcvWnd: 30000, - }) - - // Receive the SYN-ACK reply. - b := c.GetPacket() - tcpHdr := header.TCP(header.IPv4(b).Payload()) - c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) - - ackHeaders := &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: iss + 1, - AckNum: c.IRS + 1, - } - - // Send ACK. - c.SendPacket(nil, ackHeaders) - - // Try to accept the connection. - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.ReadableEvents) - defer wq.EventUnregister(&we) - - c.EP, _, err = ep.Accept(nil) - if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { - // Wait for connection to be established. - select { - case <-ch: - c.EP, _, err = ep.Accept(nil) - if err != nil { - t.Fatalf("Accept failed: %s", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - - c.EP.Close() - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS+1)), - checker.TCPAckNum(uint32(iss)+1), - checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) - - finHeaders := &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: iss + 1, - AckNum: c.IRS + 2, - } - - c.SendPacket(nil, finHeaders) - - // Get the ACK to the FIN we just sent. - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS+2)), - checker.TCPAckNum(uint32(iss)+2), - checker.TCPFlags(header.TCPFlagAck))) - - // Out of order ACK should generate an immediate ACK in - // TIME_WAIT. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: iss + 1, - AckNum: c.IRS + 3, - }) - - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS+2)), - checker.TCPAckNum(uint32(iss)+2), - checker.TCPFlags(header.TCPFlagAck))) -} - -func TestTCPTimeWaitNewSyn(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - wq := &waiter.Queue{} - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - - if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %s", err) - } - - // Send a SYN request. - iss := seqnum.Value(context.TestInitialSequenceNumber) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: iss, - RcvWnd: 30000, - }) - - // Receive the SYN-ACK reply. - b := c.GetPacket() - tcpHdr := header.TCP(header.IPv4(b).Payload()) - c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) - - ackHeaders := &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: iss + 1, - AckNum: c.IRS + 1, - } - - // Send ACK. - c.SendPacket(nil, ackHeaders) - - // Try to accept the connection. - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.ReadableEvents) - defer wq.EventUnregister(&we) - - c.EP, _, err = ep.Accept(nil) - if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { - // Wait for connection to be established. - select { - case <-ch: - c.EP, _, err = ep.Accept(nil) - if err != nil { - t.Fatalf("Accept failed: %s", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - - c.EP.Close() - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS+1)), - checker.TCPAckNum(uint32(iss)+1), - checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) - - finHeaders := &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: iss + 1, - AckNum: c.IRS + 2, - } - - c.SendPacket(nil, finHeaders) - - // Get the ACK to the FIN we just sent. - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS+2)), - checker.TCPAckNum(uint32(iss)+2), - checker.TCPFlags(header.TCPFlagAck))) - - // Send a SYN request w/ sequence number lower than - // the highest sequence number sent. We just reuse - // the same number. - iss = seqnum.Value(context.TestInitialSequenceNumber) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: iss, - RcvWnd: 30000, - }) - - c.CheckNoPacketTimeout("unexpected packet received in response to SYN", 1*time.Second) - - // drain any older notifications from the notification channel before attempting - // 2nd connection. - select { - case <-ch: - default: - } - - // Send a SYN request w/ sequence number higher than - // the highest sequence number sent. - iss = iss.Add(3) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: iss, - RcvWnd: 30000, - }) - - // Receive the SYN-ACK reply. - b = c.GetPacket() - tcpHdr = header.IPv4(b).Payload() - c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) - - ackHeaders = &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: iss + 1, - AckNum: c.IRS + 1, - } - - // Send ACK. - c.SendPacket(nil, ackHeaders) - - // Try to accept the connection. - c.EP, _, err = ep.Accept(nil) - if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { - // Wait for connection to be established. - select { - case <-ch: - c.EP, _, err = ep.Accept(nil) - if err != nil { - t.Fatalf("Accept failed: %s", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } -} - -func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Set TCPTimeWaitTimeout to 5 seconds so that sockets are marked closed - // after 5 seconds in TIME_WAIT state. - tcpTimeWaitTimeout := 5 * time.Second - opt := tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout) - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%s)): %s", tcp.ProtocolNumber, opt, tcpTimeWaitTimeout, err) - } - - want := c.Stack().Stats().TCP.EstablishedClosed.Value() + 1 - - wq := &waiter.Queue{} - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - - if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %s", err) - } - - // Send a SYN request. - iss := seqnum.Value(context.TestInitialSequenceNumber) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: iss, - RcvWnd: 30000, - }) - - // Receive the SYN-ACK reply. - b := c.GetPacket() - tcpHdr := header.TCP(header.IPv4(b).Payload()) - c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) - - ackHeaders := &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: iss + 1, - AckNum: c.IRS + 1, - } - - // Send ACK. - c.SendPacket(nil, ackHeaders) - - // Try to accept the connection. - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.ReadableEvents) - defer wq.EventUnregister(&we) - - c.EP, _, err = ep.Accept(nil) - if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { - // Wait for connection to be established. - select { - case <-ch: - c.EP, _, err = ep.Accept(nil) - if err != nil { - t.Fatalf("Accept failed: %s", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - - c.EP.Close() - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS+1)), - checker.TCPAckNum(uint32(iss)+1), - checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) - - finHeaders := &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: iss + 1, - AckNum: c.IRS + 2, - } - - c.SendPacket(nil, finHeaders) - - // Get the ACK to the FIN we just sent. - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS+2)), - checker.TCPAckNum(uint32(iss)+2), - checker.TCPFlags(header.TCPFlagAck))) - - time.Sleep(2 * time.Second) - - // Now send a duplicate FIN. This should cause the TIME_WAIT to extend - // by another 5 seconds and also send us a duplicate ACK as it should - // indicate that the final ACK was potentially lost. - c.SendPacket(nil, finHeaders) - - // Get the ACK to the FIN we just sent. - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS+2)), - checker.TCPAckNum(uint32(iss)+2), - checker.TCPFlags(header.TCPFlagAck))) - - // Sleep for 4 seconds so at this point we are 1 second past the - // original tcpLingerTimeout of 5 seconds. - time.Sleep(4 * time.Second) - - // Send an ACK and it should not generate any packet as the socket - // should still be in TIME_WAIT for another another 5 seconds due - // to the duplicate FIN we sent earlier. - *ackHeaders = *finHeaders - ackHeaders.SeqNum = ackHeaders.SeqNum + 1 - ackHeaders.Flags = header.TCPFlagAck - c.SendPacket(nil, ackHeaders) - - c.CheckNoPacketTimeout("unexpected packet received from endpoint in TIME_WAIT", 1*time.Second) - // Now sleep for another 2 seconds so that we are past the - // extended TIME_WAIT of 7 seconds (2 + 5). - time.Sleep(2 * time.Second) - - // Resend the same ACK. - c.SendPacket(nil, ackHeaders) - - // Receive the RST that should be generated as there is no valid - // endpoint. - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(ackHeaders.AckNum)), - checker.TCPAckNum(0), - checker.TCPFlags(header.TCPFlagRst))) - - if got := c.Stack().Stats().TCP.EstablishedClosed.Value(); got != want { - t.Errorf("got c.Stack().Stats().TCP.EstablishedClosed = %d, want = %d", got, want) - } - if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { - t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got) - } -} - -func TestTCPCloseWithData(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Set TCPTimeWaitTimeout to 5 seconds so that sockets are marked closed - // after 5 seconds in TIME_WAIT state. - tcpTimeWaitTimeout := 5 * time.Second - opt := tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout) - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%s)): %s", tcp.ProtocolNumber, opt, tcpTimeWaitTimeout, err) - } - - wq := &waiter.Queue{} - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - - if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %s", err) - } - - // Send a SYN request. - iss := seqnum.Value(context.TestInitialSequenceNumber) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: iss, - RcvWnd: 30000, - }) - - // Receive the SYN-ACK reply. - b := c.GetPacket() - tcpHdr := header.TCP(header.IPv4(b).Payload()) - c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) - - ackHeaders := &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: iss + 1, - AckNum: c.IRS + 1, - RcvWnd: 30000, - } - - // Send ACK. - c.SendPacket(nil, ackHeaders) - - // Try to accept the connection. - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.ReadableEvents) - defer wq.EventUnregister(&we) - - c.EP, _, err = ep.Accept(nil) - if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { - // Wait for connection to be established. - select { - case <-ch: - c.EP, _, err = ep.Accept(nil) - if err != nil { - t.Fatalf("Accept failed: %s", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - - // Now trigger a passive close by sending a FIN. - finHeaders := &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: iss + 1, - AckNum: c.IRS + 2, - RcvWnd: 30000, - } - - c.SendPacket(nil, finHeaders) - - // Get the ACK to the FIN we just sent. - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS+1)), - checker.TCPAckNum(uint32(iss)+2), - checker.TCPFlags(header.TCPFlagAck))) - - // Now write a few bytes and then close the endpoint. - data := []byte{1, 2, 3} - - var r bytes.Reader - r.Reset(data) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - // Check that data is received. - b = c.GetPacket() - checker.IPv4(t, b, - checker.PayloadLen(len(data)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)+2), // Acknum is initial sequence number + 1 - checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), - ), - ) - - if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) { - t.Errorf("got data = %x, want = %x", p, data) - } - - c.EP.Close() - // Check the FIN. - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS+1)+uint32(len(data))), - checker.TCPAckNum(uint32(iss+2)), - checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) - - // First send a partial ACK. - ackHeaders = &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: iss + 2, - AckNum: c.IRS + 1 + seqnum.Value(len(data)-1), - RcvWnd: 30000, - } - c.SendPacket(nil, ackHeaders) - - // Now send a full ACK. - ackHeaders = &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: iss + 2, - AckNum: c.IRS + 1 + seqnum.Value(len(data)), - RcvWnd: 30000, - } - c.SendPacket(nil, ackHeaders) - - // Now ACK the FIN. - ackHeaders.AckNum++ - c.SendPacket(nil, ackHeaders) - - // Now send an ACK and we should get a RST back as the endpoint should - // be in CLOSED state. - ackHeaders = &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: iss + 2, - AckNum: c.IRS + 1 + seqnum.Value(len(data)), - RcvWnd: 30000, - } - c.SendPacket(nil, ackHeaders) - - // Check the RST. - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(ackHeaders.AckNum)), - checker.TCPAckNum(0), - checker.TCPFlags(header.TCPFlagRst))) -} - -func TestTCPUserTimeout(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - - waitEntry, notifyCh := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&waitEntry, waiter.EventHUp) - defer c.WQ.EventUnregister(&waitEntry) - - origEstablishedTimedout := c.Stack().Stats().TCP.EstablishedTimedout.Value() - - // Ensure that on the next retransmit timer fire, the user timeout has - // expired. - initRTO := 1 * time.Second - userTimeout := initRTO / 2 - v := tcpip.TCPUserTimeoutOption(userTimeout) - if err := c.EP.SetSockOpt(&v); err != nil { - t.Fatalf("c.EP.SetSockOpt(&%T(%s): %s", v, userTimeout, err) - } - - // Send some data and wait before ACKing it. - view := make([]byte, 3) - var r bytes.Reader - r.Reset(view) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - next := uint32(c.IRS) + 1 - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(len(view)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(next), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), - ), - ) - - // Wait for the retransmit timer to be fired and the user timeout to cause - // close of the connection. - select { - case <-notifyCh: - case <-time.After(2 * initRTO): - t.Fatalf("connection still alive after %s, should have been closed after %s", 2*initRTO, userTimeout) - } - - // No packet should be received as the connection should be silently - // closed due to timeout. - c.CheckNoPacket("unexpected packet received after userTimeout has expired") - - next += uint32(len(view)) - - // The connection should be terminated after userTimeout has expired. - // Send an ACK to trigger a RST from the stack as the endpoint should - // be dead. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: seqnum.Value(next), - RcvWnd: 30000, - }) - - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(next), - checker.TCPAckNum(uint32(0)), - checker.TCPFlags(header.TCPFlagRst), - ), - ) - - ept := endpointTester{c.EP} - ept.CheckReadError(t, &tcpip.ErrTimeout{}) - - if got, want := c.Stack().Stats().TCP.EstablishedTimedout.Value(), origEstablishedTimedout+1; got != want { - t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout = %d, want = %d", got, want) - } - if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 { - t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got) - } -} - -func TestKeepaliveWithUserTimeout(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - - origEstablishedTimedout := c.Stack().Stats().TCP.EstablishedTimedout.Value() - - const keepAliveIdle = 100 * time.Millisecond - const keepAliveInterval = 3 * time.Second - keepAliveIdleOption := tcpip.KeepaliveIdleOption(keepAliveIdle) - if err := c.EP.SetSockOpt(&keepAliveIdleOption); err != nil { - t.Fatalf("c.EP.SetSockOpt(&%T(%s)): %s", keepAliveIdleOption, keepAliveIdle, err) - } - keepAliveIntervalOption := tcpip.KeepaliveIntervalOption(keepAliveInterval) - if err := c.EP.SetSockOpt(&keepAliveIntervalOption); err != nil { - t.Fatalf("c.EP.SetSockOpt(&%T(%s)): %s", keepAliveIntervalOption, keepAliveInterval, err) - } - if err := c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 10); err != nil { - t.Fatalf("c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 10): %s", err) - } - c.EP.SocketOptions().SetKeepAlive(true) - - // Set userTimeout to be the duration to be 1 keepalive - // probes. Which means that after the first probe is sent - // the second one should cause the connection to be - // closed due to userTimeout being hit. - userTimeout := tcpip.TCPUserTimeoutOption(keepAliveInterval) - if err := c.EP.SetSockOpt(&userTimeout); err != nil { - t.Fatalf("c.EP.SetSockOpt(&%T(%s)): %s", userTimeout, keepAliveInterval, err) - } - - // Check that the connection is still alive. - ept := endpointTester{c.EP} - ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) - - // Now receive 1 keepalives, but don't ACK it. - b := c.GetPacket() - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - checker.IPv4(t, b, - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlags(header.TCPFlagAck), - ), - ) - - // Sleep for a litte over the KeepAlive interval to make sure - // the timer has time to fire after the last ACK and close the - // close the socket. - time.Sleep(keepAliveInterval + keepAliveInterval/2) - - // The connection should be closed with a timeout. - // Send an ACK to trigger a RST from the stack as the endpoint should - // be dead. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS + 1, - RcvWnd: 30000, - }) - - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS+1)), - checker.TCPAckNum(uint32(0)), - checker.TCPFlags(header.TCPFlagRst), - ), - ) - - ept.CheckReadError(t, &tcpip.ErrTimeout{}) - if got, want := c.Stack().Stats().TCP.EstablishedTimedout.Value(), origEstablishedTimedout+1; got != want { - t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout = %d, want = %d", got, want) - } - if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 { - t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got) - } -} - -func TestIncreaseWindowOnRead(t *testing.T) { - // This test ensures that the endpoint sends an ack, - // after read() when the window grows by more than 1 MSS. - c := context.New(t, defaultMTU) - defer c.Cleanup() - - const rcvBuf = 65535 * 10 - c.CreateConnected(context.TestInitialSequenceNumber, 30000, rcvBuf) - - // Write chunks of ~30000 bytes. It's important that two - // payloads make it equal or longer than MSS. - remain := rcvBuf * 2 - sent := 0 - data := make([]byte, defaultMTU/2) - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - for remain > len(data) { - c.SendPacket(data, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss.Add(seqnum.Size(sent)), - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - sent += len(data) - remain -= len(data) - pkt := c.GetPacket() - checker.IPv4(t, pkt, - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)+uint32(sent)), - checker.TCPFlags(header.TCPFlagAck), - ), - ) - // Break once the window drops below defaultMTU/2 - if wnd := header.TCP(header.IPv4(pkt).Payload()).WindowSize(); wnd < defaultMTU/2 { - break - } - } - - // We now have < 1 MSS in the buffer space. Read at least > 2 MSS - // worth of data as receive buffer space - w := tcpip.LimitedWriter{ - W: ioutil.Discard, - // defaultMTU is a good enough estimate for the MSS used for this - // connection. - N: defaultMTU * 2, - } - for w.N != 0 { - _, err := c.EP.Read(&w, tcpip.ReadOptions{}) - if err != nil { - t.Fatalf("Read failed: %s", err) - } - } - - // After reading > MSS worth of data, we surely crossed MSS. See the ack: - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)+uint32(sent)), - checker.TCPWindow(uint16(0xffff)), - checker.TCPFlags(header.TCPFlagAck), - ), - ) -} - -func TestIncreaseWindowOnBufferResize(t *testing.T) { - // This test ensures that the endpoint sends an ack, - // after available recv buffer grows to more than 1 MSS. - c := context.New(t, defaultMTU) - defer c.Cleanup() - - const rcvBuf = 65535 * 10 - c.CreateConnected(context.TestInitialSequenceNumber, 30000, rcvBuf) - - // Write chunks of ~30000 bytes. It's important that two - // payloads make it equal or longer than MSS. - remain := rcvBuf - sent := 0 - data := make([]byte, defaultMTU/2) - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - for remain > len(data) { - c.SendPacket(data, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss.Add(seqnum.Size(sent)), - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - sent += len(data) - remain -= len(data) - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)+uint32(sent)), - checker.TCPWindowLessThanEq(0xffff), - checker.TCPFlags(header.TCPFlagAck), - ), - ) - } - - // Increasing the buffer from should generate an ACK, - // since window grew from small value to larger equal MSS - c.EP.SocketOptions().SetReceiveBufferSize(rcvBuf*4, true /* notify */) - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)+uint32(sent)), - checker.TCPWindow(uint16(0xffff)), - checker.TCPFlags(header.TCPFlagAck), - ), - ) -} - -func TestTCPDeferAccept(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.Create(-1) - - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatal("Bind failed:", err) - } - - if err := c.EP.Listen(10); err != nil { - t.Fatal("Listen failed:", err) - } - - const tcpDeferAccept = 1 * time.Second - tcpDeferAcceptOption := tcpip.TCPDeferAcceptOption(tcpDeferAccept) - if err := c.EP.SetSockOpt(&tcpDeferAcceptOption); err != nil { - t.Fatalf("c.EP.SetSockOpt(&%T(%s)): %s", tcpDeferAcceptOption, tcpDeferAccept, err) - } - - irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */) - - _, _, err := c.EP.Accept(nil) - if d := cmp.Diff(&tcpip.ErrWouldBlock{}, err); d != "" { - t.Fatalf("c.EP.Accept(nil) mismatch (-want +got):\n%s", d) - } - - // Send data. This should result in an acceptable endpoint. - c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: irs + 1, - AckNum: iss + 1, - }) - - // Receive ACK for the data we sent. - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck), - checker.TCPSeqNum(uint32(iss+1)), - checker.TCPAckNum(uint32(irs+5)))) - - // Give a bit of time for the socket to be delivered to the accept queue. - time.Sleep(50 * time.Millisecond) - aep, _, err := c.EP.Accept(nil) - if err != nil { - t.Fatalf("got c.EP.Accept(nil) = %s, want: nil", err) - } - - aep.Close() - // Closing aep without reading the data should trigger a RST. - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck), - checker.TCPSeqNum(uint32(iss+1)), - checker.TCPAckNum(uint32(irs+5)))) -} - -func TestTCPDeferAcceptTimeout(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.Create(-1) - - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatal("Bind failed:", err) - } - - if err := c.EP.Listen(10); err != nil { - t.Fatal("Listen failed:", err) - } - - const tcpDeferAccept = 1 * time.Second - tcpDeferAcceptOpt := tcpip.TCPDeferAcceptOption(tcpDeferAccept) - if err := c.EP.SetSockOpt(&tcpDeferAcceptOpt); err != nil { - t.Fatalf("c.EP.SetSockOpt(&%T(%s)) failed: %s", tcpDeferAcceptOpt, tcpDeferAccept, err) - } - - irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */) - - _, _, err := c.EP.Accept(nil) - if d := cmp.Diff(&tcpip.ErrWouldBlock{}, err); d != "" { - t.Fatalf("c.EP.Accept(nil) mismatch (-want +got):\n%s", d) - } - - // Sleep for a little of the tcpDeferAccept timeout. - time.Sleep(tcpDeferAccept + 100*time.Millisecond) - - // On timeout expiry we should get a SYN-ACK retransmission. - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn), - checker.TCPAckNum(uint32(irs)+1))) - - // Send data. This should result in an acceptable endpoint. - c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: irs + 1, - AckNum: iss + 1, - }) - - // Receive ACK for the data we sent. - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck), - checker.TCPSeqNum(uint32(iss+1)), - checker.TCPAckNum(uint32(irs+5)))) - - // Give sometime for the endpoint to be delivered to the accept queue. - time.Sleep(50 * time.Millisecond) - aep, _, err := c.EP.Accept(nil) - if err != nil { - t.Fatalf("got c.EP.Accept(nil) = %s, want: nil", err) - } - - aep.Close() - // Closing aep without reading the data should trigger a RST. - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck), - checker.TCPSeqNum(uint32(iss+1)), - checker.TCPAckNum(uint32(irs+5)))) -} - -func TestResetDuringClose(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRecvBuf */) - // Send some data to make sure there is some unread - // data to trigger a reset on c.Close. - irs := c.IRS - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: irs.Add(1), - RcvWnd: 30000, - }) - - // Receive ACK for the data we sent. - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck), - checker.TCPSeqNum(uint32(irs.Add(1))), - checker.TCPAckNum(uint32(iss)+4))) - - // Close in a separate goroutine so that we can trigger - // a race with the RST we send below. This should not - // panic due to the route being released depeding on - // whether Close() sends an active RST or the RST sent - // below is processed by the worker first. - var wg sync.WaitGroup - - wg.Add(1) - go func() { - defer wg.Done() - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - SeqNum: iss.Add(4), - AckNum: c.IRS.Add(5), - RcvWnd: 30000, - Flags: header.TCPFlagRst, - }) - }() - - wg.Add(1) - go func() { - defer wg.Done() - c.EP.Close() - }() - - wg.Wait() -} - -func TestStackTimeWaitReuse(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - s := c.Stack() - var twReuse tcpip.TCPTimeWaitReuseOption - if err := s.TransportProtocolOption(tcp.ProtocolNumber, &twReuse); err != nil { - t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &twReuse, err) - } - if got, want := twReuse, tcpip.TCPTimeWaitReuseLoopbackOnly; got != want { - t.Fatalf("got tcpip.TCPTimeWaitReuseOption: %v, want: %v", got, want) - } -} - -func TestSetStackTimeWaitReuse(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - s := c.Stack() - testCases := []struct { - v int - err tcpip.Error - }{ - {int(tcpip.TCPTimeWaitReuseDisabled), nil}, - {int(tcpip.TCPTimeWaitReuseGlobal), nil}, - {int(tcpip.TCPTimeWaitReuseLoopbackOnly), nil}, - {int(tcpip.TCPTimeWaitReuseLoopbackOnly) + 1, &tcpip.ErrInvalidOptionValue{}}, - {int(tcpip.TCPTimeWaitReuseDisabled) - 1, &tcpip.ErrInvalidOptionValue{}}, - } - - for _, tc := range testCases { - opt := tcpip.TCPTimeWaitReuseOption(tc.v) - err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt) - if got, want := err, tc.err; got != want { - t.Fatalf("s.SetTransportProtocolOption(%d, &%T(%d)) = %s, want = %s", tcp.ProtocolNumber, tc.v, tc.v, err, tc.err) - } - if tc.err != nil { - continue - } - - var twReuse tcpip.TCPTimeWaitReuseOption - if err := s.TransportProtocolOption(tcp.ProtocolNumber, &twReuse); err != nil { - t.Fatalf("s.TransportProtocolOption(%v, %v) = %v, want nil", tcp.ProtocolNumber, &twReuse, err) - } - - if got, want := twReuse, tcpip.TCPTimeWaitReuseOption(tc.v); got != want { - t.Fatalf("got tcpip.TCPTimeWaitReuseOption: %v, want: %v", got, want) - } - } -} - -// generateRandomPayload generates a random byte slice of the specified length -// causing a fatal test failure if it is unable to do so. -func generateRandomPayload(t *testing.T, n int) []byte { - t.Helper() - buf := make([]byte, n) - if _, err := rand.Read(buf); err != nil { - t.Fatalf("rand.Read(buf) failed: %s", err) - } - return buf -} - -func TestSendBufferTuning(t *testing.T) { - const maxPayload = 536 - const mtu = header.TCPMinimumSize + header.IPv4MinimumSize + maxTCPOptionSize + maxPayload - const packetOverheadFactor = 2 - - testCases := []struct { - name string - autoTuningDisabled bool - }{ - {"autoTuningDisabled", true}, - {"autoTuningEnabled", false}, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - c := context.New(t, mtu) - defer c.Cleanup() - - // Set the stack option for send buffer size. - const defaultSndBufSz = maxPayload * tcp.InitialCwnd - const maxSndBufSz = defaultSndBufSz * 10 - { - opt := tcpip.TCPSendBufferSizeRangeOption{Min: 1, Default: defaultSndBufSz, Max: maxSndBufSz} - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err) - } - } - - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - - oldSz := c.EP.SocketOptions().GetSendBufferSize() - if oldSz != defaultSndBufSz { - t.Fatalf("Wrong send buffer size got %d want %d", oldSz, defaultSndBufSz) - } - - if tc.autoTuningDisabled { - c.EP.SocketOptions().SetSendBufferSize(defaultSndBufSz, true /* notify */) - } - - data := make([]byte, maxPayload) - for i := range data { - data[i] = byte(i) - } - - w, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&w, waiter.WritableEvents) - defer c.WQ.EventUnregister(&w) - - bytesRead := 0 - for { - // Packets will be sent till the send buffer - // size is reached. - var r bytes.Reader - r.Reset(data[bytesRead : bytesRead+maxPayload]) - _, err := c.EP.Write(&r, tcpip.WriteOptions{}) - if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { - break - } - - c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, 0) - bytesRead += maxPayload - data = append(data, data...) - } - - // Send an ACK and wait for connection to become writable again. - c.SendAck(seqnum.Value(context.TestInitialSequenceNumber).Add(1), bytesRead) - select { - case <-ch: - if err := c.EP.LastError(); err != nil { - t.Fatalf("Write failed: %s", err) - } - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for connection") - } - - outSz := int64(defaultSndBufSz) - if !tc.autoTuningDisabled { - // Calculate the new auto tuned send buffer. - var info tcpip.TCPInfoOption - if err := c.EP.GetSockOpt(&info); err != nil { - t.Fatalf("GetSockOpt failed: %v", err) - } - outSz = (int64(info.SndCwnd) * packetOverheadFactor * (maxPayload)) - } - - if newSz := c.EP.SocketOptions().GetSendBufferSize(); newSz != outSz { - t.Fatalf("Wrong send buffer size, got %d want %d", newSz, outSz) - } - }) - } -} diff --git a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go deleted file mode 100644 index 1deb1fe4d..000000000 --- a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go +++ /dev/null @@ -1,311 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tcp_test - -import ( - "bytes" - "math/rand" - "testing" - "time" - - "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/checker" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context" - "gvisor.dev/gvisor/pkg/waiter" -) - -// createConnectedWithTimestampOption creates and connects c.ep with the -// timestamp option enabled. -func createConnectedWithTimestampOption(c *context.Context) *context.RawEndpoint { - return c.CreateConnectedWithOptions(header.TCPSynOptions{TS: true, TSVal: 1}) -} - -// TestTimeStampEnabledConnect tests that netstack sends the timestamp option on -// an active connect and sets the TS Echo Reply fields correctly when the -// SYN-ACK also indicates support for the TS option and provides a TSVal. -func TestTimeStampEnabledConnect(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - rep := createConnectedWithTimestampOption(c) - - // Register for read and validate that we have data to read. - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.ReadableEvents) - defer c.WQ.EventUnregister(&we) - - // The following tests ensure that TS option once enabled behaves - // correctly as described in - // https://tools.ietf.org/html/rfc7323#section-4.3. - // - // We are not testing delayed ACKs here, but we do test out of order - // packet delivery and filling the sequence number hole created due to - // the out of order packet. - // - // The test also verifies that the sequence numbers and timestamps are - // as expected. - data := []byte{1, 2, 3} - - // First we increment tsVal by a small amount. - tsVal := rep.TSVal + 100 - rep.SendPacketWithTS(data, tsVal) - rep.VerifyACKWithTS(tsVal) - - // Next we send an out of order packet. - rep.NextSeqNum += 3 - tsVal += 200 - rep.SendPacketWithTS(data, tsVal) - - // The ACK should contain the original sequenceNumber and an older TS. - rep.NextSeqNum -= 6 - rep.VerifyACKWithTS(tsVal - 200) - - // Next we fill the hole and the returned ACK should contain the - // cumulative sequence number acking all data sent till now and have the - // latest timestamp sent below in its TSEcr field. - tsVal -= 100 - rep.SendPacketWithTS(data, tsVal) - rep.NextSeqNum += 3 - rep.VerifyACKWithTS(tsVal) - - // Increment tsVal by a large value that doesn't result in a wrap around. - tsVal += 0x7fffffff - rep.SendPacketWithTS(data, tsVal) - rep.VerifyACKWithTS(tsVal) - - // Increment tsVal again by a large value which should cause the - // timestamp value to wrap around. The returned ACK should contain the - // wrapped around timestamp in its tsEcr field and not the tsVal from - // the previous packet sent above. - tsVal += 0x7fffffff - rep.SendPacketWithTS(data, tsVal) - rep.VerifyACKWithTS(tsVal) - - select { - case <-ch: - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for data to arrive") - } - - // There should be 5 views to read and each of them should - // contain the same data. - for i := 0; i < 5; i++ { - buf := make([]byte, len(data)) - w := tcpip.SliceWriter(buf) - result, err := c.EP.Read(&w, tcpip.ReadOptions{}) - if err != nil { - t.Fatalf("Unexpected error from Read: %v", err) - } - if diff := cmp.Diff(tcpip.ReadResult{ - Count: len(buf), - Total: len(buf), - }, result, checker.IgnoreCmpPath("ControlMessages")); diff != "" { - t.Errorf("Read: unexpected result (-want +got):\n%s", diff) - } - if got, want := buf, data; bytes.Compare(got, want) != 0 { - t.Fatalf("Data is different: got: %v, want: %v", got, want) - } - } -} - -// TestTimeStampDisabledConnect tests that netstack sends timestamp option on an -// active connect but if the SYN-ACK doesn't specify the TS option then -// timestamp option is not enabled and future packets do not contain a -// timestamp. -func TestTimeStampDisabledConnect(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnectedWithOptions(header.TCPSynOptions{}) -} - -func timeStampEnabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndSize uint16) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - if cookieEnabled { - opt := tcpip.TCPAlwaysUseSynCookies(true) - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err) - } - } - - t.Logf("Test w/ CookieEnabled = %v", cookieEnabled) - tsVal := rand.Uint32() - c.AcceptWithOptions(wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS, TS: true, TSVal: tsVal}) - - // Now send some data and validate that timestamp is echoed correctly in the ACK. - data := []byte{1, 2, 3} - - var r bytes.Reader - r.Reset(data) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Unexpected error from Write: %s", err) - } - - // Check that data is received and that the timestamp option TSEcr field - // matches the expected value. - b := c.GetPacket() - checker.IPv4(t, b, - // Add 12 bytes for the timestamp option + 2 NOPs to align at 4 - // byte boundary. - checker.PayloadLen(len(data)+header.TCPMinimumSize+12), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(790), - checker.TCPWindow(wndSize), - checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), - checker.TCPTimestampChecker(true, 0, tsVal+1), - ), - ) -} - -// TestTimeStampEnabledAccept tests that if the SYN on a passive connect -// specifies the Timestamp option then the Timestamp option is sent on a SYN-ACK -// and echoes the tsVal field of the original SYN in the tcEcr field of the -// SYN-ACK. We cover the cases where SYN cookies are enabled/disabled and verify -// that Timestamp option is enabled in both cases if requested in the original -// SYN. -func TestTimeStampEnabledAccept(t *testing.T) { - testCases := []struct { - cookieEnabled bool - wndScale int - wndSize uint16 - }{ - {true, -1, 0xffff}, // When cookie is used window scaling is disabled. - // DefaultReceiveBufferSize is 1MB >> 5. Advertised window will be 1/2 of that. - {false, 5, 0x4000}, - } - for _, tc := range testCases { - timeStampEnabledAccept(t, tc.cookieEnabled, tc.wndScale, tc.wndSize) - } -} - -func timeStampDisabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndSize uint16) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - if cookieEnabled { - opt := tcpip.TCPAlwaysUseSynCookies(true) - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err) - } - } - - t.Logf("Test w/ CookieEnabled = %v", cookieEnabled) - c.AcceptWithOptions(wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS}) - - // Now send some data with the accepted connection endpoint and validate - // that no timestamp option is sent in the TCP segment. - data := []byte{1, 2, 3} - - var r bytes.Reader - r.Reset(data) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Unexpected error from Write: %s", err) - } - - // Check that data is received and that the timestamp option is disabled - // when SYN cookies are enabled/disabled. - b := c.GetPacket() - checker.IPv4(t, b, - checker.PayloadLen(len(data)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(790), - checker.TCPWindow(wndSize), - checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), - checker.TCPTimestampChecker(false, 0, 0), - ), - ) -} - -// TestTimeStampDisabledAccept tests that Timestamp option is not used when the -// peer doesn't advertise it and connection is established with Accept(). -func TestTimeStampDisabledAccept(t *testing.T) { - testCases := []struct { - cookieEnabled bool - wndScale int - wndSize uint16 - }{ - {true, -1, 0xffff}, // When cookie is used window scaling is disabled. - // DefaultReceiveBufferSize is 1MB >> 5. Advertised window will be half of - // that. - {false, 5, 0x4000}, - } - for _, tc := range testCases { - timeStampDisabledAccept(t, tc.cookieEnabled, tc.wndScale, tc.wndSize) - } -} - -func TestSendGreaterThanMTUWithOptions(t *testing.T) { - const maxPayload = 100 - c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) - defer c.Cleanup() - - createConnectedWithTimestampOption(c) - testBrokenUpWrite(t, c, maxPayload) -} - -func TestSegmentNotDroppedWhenTimestampMissing(t *testing.T) { - const maxPayload = 100 - c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) - defer c.Cleanup() - - rep := createConnectedWithTimestampOption(c) - - // Register for read. - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.ReadableEvents) - defer c.WQ.EventUnregister(&we) - - droppedPacketsStat := c.Stack().Stats().DroppedPackets - droppedPackets := droppedPacketsStat.Value() - data := []byte{1, 2, 3} - // Send a packet with no TCP options/timestamp. - rep.SendPacket(data, nil) - - select { - case <-ch: - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for data to arrive") - } - - // Assert that DroppedPackets was not incremented. - if got, want := droppedPacketsStat.Value(), droppedPackets; got != want { - t.Fatalf("incorrect number of dropped packets, got: %v, want: %v", got, want) - } - - // Issue a read and we should data. - var buf bytes.Buffer - result, err := c.EP.Read(&buf, tcpip.ReadOptions{}) - if err != nil { - t.Fatalf("Unexpected error from Read: %v", err) - } - if diff := cmp.Diff(tcpip.ReadResult{ - Count: buf.Len(), - Total: buf.Len(), - }, result, checker.IgnoreCmpPath("ControlMessages")); diff != "" { - t.Errorf("Read: unexpected result (-want +got):\n%s", diff) - } - if got, want := buf.Bytes(), data; bytes.Compare(got, want) != 0 { - t.Fatalf("Data is different: got: %v, want: %v", got, want) - } -} diff --git a/pkg/tcpip/transport/tcp/tcp_unsafe_state_autogen.go b/pkg/tcpip/transport/tcp/tcp_unsafe_state_autogen.go new file mode 100644 index 000000000..4cb82fcc9 --- /dev/null +++ b/pkg/tcpip/transport/tcp/tcp_unsafe_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package tcp diff --git a/pkg/tcpip/transport/tcp/testing/context/BUILD b/pkg/tcpip/transport/tcp/testing/context/BUILD deleted file mode 100644 index ce6a2c31d..000000000 --- a/pkg/tcpip/transport/tcp/testing/context/BUILD +++ /dev/null @@ -1,26 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "context", - testonly = 1, - srcs = ["context.go"], - visibility = [ - "//visibility:public", - ], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/checker", - "//pkg/tcpip/header", - "//pkg/tcpip/link/channel", - "//pkg/tcpip/link/sniffer", - "//pkg/tcpip/network/ipv4", - "//pkg/tcpip/network/ipv6", - "//pkg/tcpip/seqnum", - "//pkg/tcpip/stack", - "//pkg/tcpip/transport/tcp", - "//pkg/waiter", - ], -) diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go deleted file mode 100644 index 96e4849d2..000000000 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ /dev/null @@ -1,1233 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package context provides a test context for use in tcp tests. It also -// provides helper methods to assert/check certain behaviours. -package context - -import ( - "bytes" - "context" - "testing" - "time" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/checker" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" - "gvisor.dev/gvisor/pkg/tcpip/seqnum" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" - "gvisor.dev/gvisor/pkg/waiter" -) - -const ( - // StackAddr is the IPv4 address assigned to the stack. - StackAddr = "\x0a\x00\x00\x01" - - // StackPort is used as the listening port in tests for passive - // connects. - StackPort = 1234 - - // TestAddr is the source address for packets sent to the stack via the - // link layer endpoint. - TestAddr = "\x0a\x00\x00\x02" - - // TestPort is the TCP port used for packets sent to the stack - // via the link layer endpoint. - TestPort = 4096 - - // StackV6Addr is the IPv6 address assigned to the stack. - StackV6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - - // TestV6Addr is the source address for packets sent to the stack via - // the link layer endpoint. - TestV6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" - - // StackV4MappedAddr is StackAddr as a mapped v6 address. - StackV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + StackAddr - - // TestV4MappedAddr is TestAddr as a mapped v6 address. - TestV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + TestAddr - - // V4MappedWildcardAddr is the mapped v6 representation of 0.0.0.0. - V4MappedWildcardAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x00" - - // TestInitialSequenceNumber is the initial sequence number sent in packets that - // are sent in response to a SYN or in the initial SYN sent to the stack. - TestInitialSequenceNumber = 789 -) - -// StackAddrWithPrefix is StackAddr with its associated prefix length. -var StackAddrWithPrefix = tcpip.AddressWithPrefix{ - Address: StackAddr, - PrefixLen: 24, -} - -// StackV6AddrWithPrefix is StackV6Addr with its associated prefix length. -var StackV6AddrWithPrefix = tcpip.AddressWithPrefix{ - Address: StackV6Addr, - PrefixLen: header.IIDOffsetInIPv6Address * 8, -} - -// Headers is used to represent the TCP header fields when building a -// new packet. -type Headers struct { - // SrcPort holds the src port value to be used in the packet. - SrcPort uint16 - - // DstPort holds the destination port value to be used in the packet. - DstPort uint16 - - // SeqNum is the value of the sequence number field in the TCP header. - SeqNum seqnum.Value - - // AckNum represents the acknowledgement number field in the TCP header. - AckNum seqnum.Value - - // Flags are the TCP flags in the TCP header. - Flags header.TCPFlags - - // RcvWnd is the window to be advertised in the ReceiveWindow field of - // the TCP header. - RcvWnd seqnum.Size - - // TCPOpts holds the options to be sent in the option field of the TCP - // header. - TCPOpts []byte -} - -// Options contains options for creating a new test context. -type Options struct { - // EnableV4 indicates whether IPv4 should be enabled. - EnableV4 bool - - // EnableV6 indicates whether IPv4 should be enabled. - EnableV6 bool - - // MTU indicates the maximum transmission unit on the link layer. - MTU uint32 -} - -// Context provides an initialized Network stack and a link layer endpoint -// for use in TCP tests. -type Context struct { - t *testing.T - linkEP *channel.Endpoint - s *stack.Stack - - // IRS holds the initial sequence number in the SYN sent by endpoint in - // case of an active connect or the sequence number sent by the endpoint - // in the SYN-ACK sent in response to a SYN when listening in passive - // mode. - IRS seqnum.Value - - // Port holds the port bound by EP below in case of an active connect or - // the listening port number in case of a passive connect. - Port uint16 - - // EP is the test endpoint in the stack owned by this context. This endpoint - // is used in various tests to either initiate an active connect or is used - // as a passive listening endpoint to accept inbound connections. - EP tcpip.Endpoint - - // Wq is the wait queue associated with EP and is used to block for events - // on EP. - WQ waiter.Queue - - // TimeStampEnabled is true if ep is connected with the timestamp option - // enabled. - TimeStampEnabled bool - - // WindowScale is the expected window scale in SYN packets sent by - // the stack. - WindowScale uint8 - - // RcvdWindowScale is the actual window scale sent by the stack in - // SYN/SYN-ACK. - RcvdWindowScale uint8 -} - -// New allocates and initializes a test context containing a new -// stack and a link-layer endpoint. -func New(t *testing.T, mtu uint32) *Context { - return NewWithOpts(t, Options{ - EnableV4: true, - EnableV6: true, - MTU: mtu, - }) -} - -// NewWithOpts allocates and initializes a test context containing a new -// stack and a link-layer endpoint with specific options. -func NewWithOpts(t *testing.T, opts Options) *Context { - if opts.MTU == 0 { - panic("MTU must be greater than 0") - } - - stackOpts := stack.Options{ - TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, - } - if opts.EnableV4 { - stackOpts.NetworkProtocols = append(stackOpts.NetworkProtocols, ipv4.NewProtocol) - } - if opts.EnableV6 { - stackOpts.NetworkProtocols = append(stackOpts.NetworkProtocols, ipv6.NewProtocol) - } - s := stack.New(stackOpts) - - const sendBufferSize = 1 << 20 // 1 MiB - const recvBufferSize = 1 << 20 // 1 MiB - // Allow minimum send/receive buffer sizes to be 1 during tests. - sendBufOpt := tcpip.TCPSendBufferSizeRangeOption{Min: 1, Default: sendBufferSize, Max: 10 * sendBufferSize} - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &sendBufOpt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%#v) failed: %s", tcp.ProtocolNumber, sendBufOpt, err) - } - - rcvBufOpt := tcpip.TCPReceiveBufferSizeRangeOption{Min: 1, Default: recvBufferSize, Max: 10 * recvBufferSize} - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &rcvBufOpt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%#v) failed: %s", tcp.ProtocolNumber, rcvBufOpt, err) - } - - // Increase minimum RTO in tests to avoid test flakes due to early - // retransmit in case the test executors are overloaded and cause timers - // to fire earlier than expected. - minRTOOpt := tcpip.TCPMinRTOOption(3 * time.Second) - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &minRTOOpt); err != nil { - t.Fatalf("s.SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, minRTOOpt, minRTOOpt, err) - } - - // Some of the congestion control tests send up to 640 packets, we so - // set the channel size to 1000. - ep := channel.New(1000, opts.MTU, "") - wep := stack.LinkEndpoint(ep) - if testing.Verbose() { - wep = sniffer.New(ep) - } - nicOpts := stack.NICOptions{Name: "nic1"} - if err := s.CreateNICWithOptions(1, wep, nicOpts); err != nil { - t.Fatalf("CreateNICWithOptions(_, _, %+v) failed: %v", opts, err) - } - wep2 := stack.LinkEndpoint(channel.New(1000, opts.MTU, "")) - if testing.Verbose() { - wep2 = sniffer.New(channel.New(1000, opts.MTU, "")) - } - opts2 := stack.NICOptions{Name: "nic2"} - if err := s.CreateNICWithOptions(2, wep2, opts2); err != nil { - t.Fatalf("CreateNICWithOptions(_, _, %+v) failed: %v", opts2, err) - } - - var routeTable []tcpip.Route - - if opts.EnableV4 { - v4ProtocolAddr := tcpip.ProtocolAddress{ - Protocol: ipv4.ProtocolNumber, - AddressWithPrefix: StackAddrWithPrefix, - } - if err := s.AddProtocolAddress(1, v4ProtocolAddr); err != nil { - t.Fatalf("AddProtocolAddress(1, %#v): %s", v4ProtocolAddr, err) - } - routeTable = append(routeTable, tcpip.Route{ - Destination: header.IPv4EmptySubnet, - NIC: 1, - }) - } - - if opts.EnableV6 { - v6ProtocolAddr := tcpip.ProtocolAddress{ - Protocol: ipv6.ProtocolNumber, - AddressWithPrefix: StackV6AddrWithPrefix, - } - if err := s.AddProtocolAddress(1, v6ProtocolAddr); err != nil { - t.Fatalf("AddProtocolAddress(1, %#v): %s", v6ProtocolAddr, err) - } - routeTable = append(routeTable, tcpip.Route{ - Destination: header.IPv6EmptySubnet, - NIC: 1, - }) - } - - s.SetRouteTable(routeTable) - - return &Context{ - t: t, - s: s, - linkEP: ep, - WindowScale: uint8(tcp.FindWndScale(recvBufferSize)), - } -} - -// Cleanup closes the context endpoint if required. -func (c *Context) Cleanup() { - if c.EP != nil { - c.EP.Close() - } - c.Stack().Close() -} - -// Stack returns a reference to the stack in the Context. -func (c *Context) Stack() *stack.Stack { - return c.s -} - -// CheckNoPacketTimeout verifies that no packet is received during the time -// specified by wait. -func (c *Context) CheckNoPacketTimeout(errMsg string, wait time.Duration) { - c.t.Helper() - - ctx, cancel := context.WithTimeout(context.Background(), wait) - defer cancel() - if _, ok := c.linkEP.ReadContext(ctx); ok { - c.t.Fatal(errMsg) - } -} - -// CheckNoPacket verifies that no packet is received for 1 second. -func (c *Context) CheckNoPacket(errMsg string) { - c.CheckNoPacketTimeout(errMsg, 1*time.Second) -} - -// GetPacketWithTimeout reads a packet from the link layer endpoint and verifies -// that it is an IPv4 packet with the expected source and destination -// addresses. If no packet is received in the specified timeout it will return -// nil. -func (c *Context) GetPacketWithTimeout(timeout time.Duration) []byte { - c.t.Helper() - - ctx, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() - p, ok := c.linkEP.ReadContext(ctx) - if !ok { - return nil - } - - if p.Proto != ipv4.ProtocolNumber { - c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber) - } - - // Just check that the stack set the transport protocol number for outbound - // TCP messages. - // TODO(gvisor.dev/issues/3810): Remove when protocol numbers are part - // of the headerinfo. - if p.Pkt.TransportProtocolNumber != tcp.ProtocolNumber { - c.t.Fatalf("got p.Pkt.TransportProtocolNumber = %d, want = %d", p.Pkt.TransportProtocolNumber, tcp.ProtocolNumber) - } - - vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views()) - b := vv.ToView() - - if p.Pkt.GSOOptions.Type != stack.GSONone && p.Pkt.GSOOptions.L3HdrLen != header.IPv4MinimumSize { - c.t.Errorf("got L3HdrLen = %d, want = %d", p.Pkt.GSOOptions.L3HdrLen, header.IPv4MinimumSize) - } - - checker.IPv4(c.t, b, checker.SrcAddr(StackAddr), checker.DstAddr(TestAddr)) - return b -} - -// GetPacket reads a packet from the link layer endpoint and verifies -// that it is an IPv4 packet with the expected source and destination -// addresses. -func (c *Context) GetPacket() []byte { - c.t.Helper() - - p := c.GetPacketWithTimeout(5 * time.Second) - if p == nil { - c.t.Fatalf("Packet wasn't written out") - return nil - } - - return p -} - -// GetPacketNonBlocking reads a packet from the link layer endpoint -// and verifies that it is an IPv4 packet with the expected source -// and destination address. If no packet is available it will return -// nil immediately. -func (c *Context) GetPacketNonBlocking() []byte { - c.t.Helper() - - p, ok := c.linkEP.Read() - if !ok { - return nil - } - - if p.Proto != ipv4.ProtocolNumber { - c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber) - } - - // Just check that the stack set the transport protocol number for outbound - // TCP messages. - // TODO(gvisor.dev/issues/3810): Remove when protocol numbers are part - // of the headerinfo. - if p.Pkt.TransportProtocolNumber != tcp.ProtocolNumber { - c.t.Fatalf("got p.Pkt.TransportProtocolNumber = %d, want = %d", p.Pkt.TransportProtocolNumber, tcp.ProtocolNumber) - } - - vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views()) - b := vv.ToView() - - checker.IPv4(c.t, b, checker.SrcAddr(StackAddr), checker.DstAddr(TestAddr)) - return b -} - -// SendICMPPacket builds and sends an ICMPv4 packet via the link layer endpoint. -func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code header.ICMPv4Code, p1, p2 []byte, maxTotalSize int) { - // Allocate a buffer data and headers. - buf := buffer.NewView(header.IPv4MinimumSize + header.ICMPv4PayloadOffset + len(p2)) - if len(buf) > maxTotalSize { - buf = buf[:maxTotalSize] - } - - ip := header.IPv4(buf) - ip.Encode(&header.IPv4Fields{ - TotalLength: uint16(len(buf)), - TTL: 65, - Protocol: uint8(header.ICMPv4ProtocolNumber), - SrcAddr: TestAddr, - DstAddr: StackAddr, - }) - ip.SetChecksum(^ip.CalculateChecksum()) - - icmp := header.ICMPv4(buf[header.IPv4MinimumSize:]) - icmp.SetType(typ) - icmp.SetCode(code) - const icmpv4VariableHeaderOffset = 4 - copy(icmp[icmpv4VariableHeaderOffset:], p1) - copy(icmp[header.ICMPv4PayloadOffset:], p2) - icmp.SetChecksum(0) - checksum := ^header.Checksum(icmp, 0 /* initial */) - icmp.SetChecksum(checksum) - - // Inject packet. - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buf.ToVectorisedView(), - }) - c.linkEP.InjectInbound(ipv4.ProtocolNumber, pkt) -} - -// BuildSegment builds a TCP segment based on the given Headers and payload. -func (c *Context) BuildSegment(payload []byte, h *Headers) buffer.VectorisedView { - return c.BuildSegmentWithAddrs(payload, h, TestAddr, StackAddr) -} - -// BuildSegmentWithAddrs builds a TCP segment based on the given Headers, -// payload and source and destination IPv4 addresses. -func (c *Context) BuildSegmentWithAddrs(payload []byte, h *Headers, src, dst tcpip.Address) buffer.VectorisedView { - // Allocate a buffer for data and headers. - buf := buffer.NewView(header.TCPMinimumSize + header.IPv4MinimumSize + len(h.TCPOpts) + len(payload)) - copy(buf[len(buf)-len(payload):], payload) - copy(buf[len(buf)-len(payload)-len(h.TCPOpts):], h.TCPOpts) - - // Initialize the IP header. - ip := header.IPv4(buf) - ip.Encode(&header.IPv4Fields{ - TotalLength: uint16(len(buf)), - TTL: 65, - Protocol: uint8(tcp.ProtocolNumber), - SrcAddr: src, - DstAddr: dst, - }) - ip.SetChecksum(^ip.CalculateChecksum()) - - // Initialize the TCP header. - t := header.TCP(buf[header.IPv4MinimumSize:]) - t.Encode(&header.TCPFields{ - SrcPort: h.SrcPort, - DstPort: h.DstPort, - SeqNum: uint32(h.SeqNum), - AckNum: uint32(h.AckNum), - DataOffset: uint8(header.TCPMinimumSize + len(h.TCPOpts)), - Flags: h.Flags, - WindowSize: uint16(h.RcvWnd), - }) - - // Calculate the TCP pseudo-header checksum. - xsum := header.PseudoHeaderChecksum(tcp.ProtocolNumber, src, dst, uint16(len(t))) - - // Calculate the TCP checksum and set it. - xsum = header.Checksum(payload, xsum) - t.SetChecksum(^t.CalculateChecksum(xsum)) - - // Inject packet. - return buf.ToVectorisedView() -} - -// SendSegment sends a TCP segment that has already been built and written to a -// buffer.VectorisedView. -func (c *Context) SendSegment(s buffer.VectorisedView) { - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: s, - }) - c.linkEP.InjectInbound(ipv4.ProtocolNumber, pkt) -} - -// SendPacket builds and sends a TCP segment(with the provided payload & TCP -// headers) in an IPv4 packet via the link layer endpoint. -func (c *Context) SendPacket(payload []byte, h *Headers) { - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: c.BuildSegment(payload, h), - }) - c.linkEP.InjectInbound(ipv4.ProtocolNumber, pkt) -} - -// SendPacketWithAddrs builds and sends a TCP segment(with the provided payload -// & TCPheaders) in an IPv4 packet via the link layer endpoint using the -// provided source and destination IPv4 addresses. -func (c *Context) SendPacketWithAddrs(payload []byte, h *Headers, src, dst tcpip.Address) { - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: c.BuildSegmentWithAddrs(payload, h, src, dst), - }) - c.linkEP.InjectInbound(ipv4.ProtocolNumber, pkt) -} - -// SendAck sends an ACK packet. -func (c *Context) SendAck(seq seqnum.Value, bytesReceived int) { - c.SendAckWithSACK(seq, bytesReceived, nil) -} - -// SendAckWithSACK sends an ACK packet which includes the sackBlocks specified. -func (c *Context) SendAckWithSACK(seq seqnum.Value, bytesReceived int, sackBlocks []header.SACKBlock) { - options := make([]byte, 40) - offset := 0 - if len(sackBlocks) > 0 { - offset += header.EncodeNOP(options[offset:]) - offset += header.EncodeNOP(options[offset:]) - offset += header.EncodeSACKBlocks(sackBlocks, options[offset:]) - } - - c.SendPacket(nil, &Headers{ - SrcPort: TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: seq, - AckNum: c.IRS.Add(1 + seqnum.Size(bytesReceived)), - RcvWnd: 30000, - TCPOpts: options[:offset], - }) -} - -// ReceiveAndCheckPacket reads a packet from the link layer endpoint and -// verifies that the packet packet payload of packet matches the slice -// of data indicated by offset & size. -func (c *Context) ReceiveAndCheckPacket(data []byte, offset, size int) { - c.t.Helper() - - c.ReceiveAndCheckPacketWithOptions(data, offset, size, 0) -} - -// ReceiveAndCheckPacketWithOptions reads a packet from the link layer endpoint -// and verifies that the packet packet payload of packet matches the slice of -// data indicated by offset & size and skips optlen bytes in addition to the IP -// TCP headers when comparing the data. -func (c *Context) ReceiveAndCheckPacketWithOptions(data []byte, offset, size, optlen int) { - c.t.Helper() - - b := c.GetPacket() - checker.IPv4(c.t, b, - checker.PayloadLen(size+header.TCPMinimumSize+optlen), - checker.TCP( - checker.DstPort(TestPort), - checker.TCPSeqNum(uint32(c.IRS.Add(seqnum.Size(1+offset)))), - checker.TCPAckNum(uint32(seqnum.Value(TestInitialSequenceNumber).Add(1))), - checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), - ), - ) - - pdata := data[offset:][:size] - if p := b[header.IPv4MinimumSize+header.TCPMinimumSize+optlen:]; bytes.Compare(pdata, p) != 0 { - c.t.Fatalf("Data is different: expected %v, got %v", pdata, p) - } -} - -// ReceiveNonBlockingAndCheckPacket reads a packet from the link layer endpoint -// and verifies that the packet packet payload of packet matches the slice of -// data indicated by offset & size. It returns true if a packet was received and -// processed. -func (c *Context) ReceiveNonBlockingAndCheckPacket(data []byte, offset, size int) bool { - c.t.Helper() - - b := c.GetPacketNonBlocking() - if b == nil { - return false - } - checker.IPv4(c.t, b, - checker.PayloadLen(size+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(TestPort), - checker.TCPSeqNum(uint32(c.IRS.Add(seqnum.Size(1+offset)))), - checker.TCPAckNum(uint32(seqnum.Value(TestInitialSequenceNumber).Add(1))), - checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), - ), - ) - - pdata := data[offset:][:size] - if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; bytes.Compare(pdata, p) != 0 { - c.t.Fatalf("Data is different: expected %v, got %v", pdata, p) - } - return true -} - -// CreateV6Endpoint creates and initializes c.ep as a IPv6 Endpoint. If v6Only -// is true then it sets the IP_V6ONLY option on the socket to make it a IPv6 -// only endpoint instead of a default dual stack socket. -func (c *Context) CreateV6Endpoint(v6only bool) { - var err tcpip.Error - c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &c.WQ) - if err != nil { - c.t.Fatalf("NewEndpoint failed: %v", err) - } - - c.EP.SocketOptions().SetV6Only(v6only) -} - -// GetV6Packet reads a single packet from the link layer endpoint of the context -// and asserts that it is an IPv6 Packet with the expected src/dest addresses. -func (c *Context) GetV6Packet() []byte { - c.t.Helper() - - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) - defer cancel() - p, ok := c.linkEP.ReadContext(ctx) - if !ok { - c.t.Fatalf("Packet wasn't written out") - return nil - } - - if p.Proto != ipv6.ProtocolNumber { - c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv6.ProtocolNumber) - } - vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views()) - b := vv.ToView() - - checker.IPv6(c.t, b, checker.SrcAddr(StackV6Addr), checker.DstAddr(TestV6Addr)) - return b -} - -// SendV6Packet builds and sends an IPv6 Packet via the link layer endpoint of -// the context. -func (c *Context) SendV6Packet(payload []byte, h *Headers) { - c.SendV6PacketWithAddrs(payload, h, TestV6Addr, StackV6Addr) -} - -// SendV6PacketWithAddrs builds and sends an IPv6 Packet via the link layer -// endpoint of the context using the provided source and destination IPv6 -// addresses. -func (c *Context) SendV6PacketWithAddrs(payload []byte, h *Headers, src, dst tcpip.Address) { - // Allocate a buffer for data and headers. - buf := buffer.NewView(header.TCPMinimumSize + header.IPv6MinimumSize + len(payload)) - copy(buf[len(buf)-len(payload):], payload) - - // Initialize the IP header. - ip := header.IPv6(buf) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(header.TCPMinimumSize + len(payload)), - TransportProtocol: tcp.ProtocolNumber, - HopLimit: 65, - SrcAddr: src, - DstAddr: dst, - }) - - // Initialize the TCP header. - t := header.TCP(buf[header.IPv6MinimumSize:]) - t.Encode(&header.TCPFields{ - SrcPort: h.SrcPort, - DstPort: h.DstPort, - SeqNum: uint32(h.SeqNum), - AckNum: uint32(h.AckNum), - DataOffset: header.TCPMinimumSize, - Flags: h.Flags, - WindowSize: uint16(h.RcvWnd), - }) - - // Calculate the TCP pseudo-header checksum. - xsum := header.PseudoHeaderChecksum(tcp.ProtocolNumber, src, dst, uint16(len(t))) - - // Calculate the TCP checksum and set it. - xsum = header.Checksum(payload, xsum) - t.SetChecksum(^t.CalculateChecksum(xsum)) - - // Inject packet. - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buf.ToVectorisedView(), - }) - c.linkEP.InjectInbound(ipv6.ProtocolNumber, pkt) -} - -// CreateConnected creates a connected TCP endpoint. -func (c *Context) CreateConnected(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf int) { - c.CreateConnectedWithRawOptions(iss, rcvWnd, epRcvBuf, nil) -} - -// Connect performs the 3-way handshake for c.EP with the provided Initial -// Sequence Number (iss) and receive window(rcvWnd) and any options if -// specified. -// -// It also sets the receive buffer for the endpoint to the specified -// value in epRcvBuf. -// -// PreCondition: c.EP must already be created. -func (c *Context) Connect(iss seqnum.Value, rcvWnd seqnum.Size, options []byte) { - c.t.Helper() - - // Start connection attempt. - waitEntry, notifyCh := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&waitEntry, waiter.WritableEvents) - defer c.WQ.EventUnregister(&waitEntry) - - err := c.EP.Connect(tcpip.FullAddress{Addr: TestAddr, Port: TestPort}) - if _, ok := err.(*tcpip.ErrConnectStarted); !ok { - c.t.Fatalf("Unexpected return value from Connect: %v", err) - } - - // Receive SYN packet. - b := c.GetPacket() - checker.IPv4(c.t, b, - checker.TCP( - checker.DstPort(TestPort), - checker.TCPFlags(header.TCPFlagSyn), - ), - ) - if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want { - c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got) - } - - tcpHdr := header.TCP(header.IPv4(b).Payload()) - synOpts := header.ParseSynOptions(tcpHdr.Options(), false /* isAck */) - c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) - - c.SendPacket(nil, &Headers{ - SrcPort: tcpHdr.DestinationPort(), - DstPort: tcpHdr.SourcePort(), - Flags: header.TCPFlagSyn | header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(1), - RcvWnd: rcvWnd, - TCPOpts: options, - }) - - // Receive ACK packet. - checker.IPv4(c.t, c.GetPacket(), - checker.TCP( - checker.DstPort(TestPort), - checker.TCPFlags(header.TCPFlagAck), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)+1), - ), - ) - - // Wait for connection to be established. - select { - case <-notifyCh: - if err := c.EP.LastError(); err != nil { - c.t.Fatalf("Unexpected error when connecting: %v", err) - } - case <-time.After(1 * time.Second): - c.t.Fatalf("Timed out waiting for connection") - } - if got, want := tcp.EndpointState(c.EP.State()), tcp.StateEstablished; got != want { - c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got) - } - - c.RcvdWindowScale = uint8(synOpts.WS) - c.Port = tcpHdr.SourcePort() -} - -// Create creates a TCP endpoint. -func (c *Context) Create(epRcvBuf int) { - // Create TCP endpoint. - var err tcpip.Error - c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - c.t.Fatalf("NewEndpoint failed: %v", err) - } - - if epRcvBuf != -1 { - c.EP.SocketOptions().SetReceiveBufferSize(int64(epRcvBuf)*2, true /* notify */) - } -} - -// CreateConnectedWithRawOptions creates a connected TCP endpoint and sends -// the specified option bytes as the Option field in the initial SYN packet. -// -// It also sets the receive buffer for the endpoint to the specified -// value in epRcvBuf. -func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf int, options []byte) { - c.Create(epRcvBuf) - c.Connect(iss, rcvWnd, options) -} - -// RawEndpoint is just a small wrapper around a TCP endpoint's state to make -// sending data and ACK packets easy while being able to manipulate the sequence -// numbers and timestamp values as needed. -type RawEndpoint struct { - C *Context - SrcPort uint16 - DstPort uint16 - Flags header.TCPFlags - NextSeqNum seqnum.Value - AckNum seqnum.Value - WndSize seqnum.Size - RecentTS uint32 // Stores the latest timestamp to echo back. - TSVal uint32 // TSVal stores the last timestamp sent by this endpoint. - - // SackPermitted is true if SACKPermitted option was negotiated for this endpoint. - SACKPermitted bool -} - -// SendPacketWithTS embeds the provided tsVal in the Timestamp option -// for the packet to be sent out. -func (r *RawEndpoint) SendPacketWithTS(payload []byte, tsVal uint32) { - r.TSVal = tsVal - tsOpt := [12]byte{header.TCPOptionNOP, header.TCPOptionNOP} - header.EncodeTSOption(r.TSVal, r.RecentTS, tsOpt[2:]) - r.SendPacket(payload, tsOpt[:]) -} - -// SendPacket is a small wrapper function to build and send packets. -func (r *RawEndpoint) SendPacket(payload []byte, opts []byte) { - packetHeaders := &Headers{ - SrcPort: r.SrcPort, - DstPort: r.DstPort, - Flags: r.Flags, - SeqNum: r.NextSeqNum, - AckNum: r.AckNum, - RcvWnd: r.WndSize, - TCPOpts: opts, - } - r.C.SendPacket(payload, packetHeaders) - r.NextSeqNum = r.NextSeqNum.Add(seqnum.Size(len(payload))) -} - -// VerifyAndReturnACKWithTS verifies that the tsEcr field int he ACK matches -// the provided tsVal as well as returns the original packet. -func (r *RawEndpoint) VerifyAndReturnACKWithTS(tsVal uint32) []byte { - r.C.t.Helper() - // Read ACK and verify that tsEcr of ACK packet is [1,2,3,4] - ackPacket := r.C.GetPacket() - checker.IPv4(r.C.t, ackPacket, - checker.TCP( - checker.DstPort(r.SrcPort), - checker.TCPFlags(header.TCPFlagAck), - checker.TCPSeqNum(uint32(r.AckNum)), - checker.TCPAckNum(uint32(r.NextSeqNum)), - checker.TCPTimestampChecker(true, 0, tsVal), - ), - ) - // Store the parsed TSVal from the ack as recentTS. - tcpSeg := header.TCP(header.IPv4(ackPacket).Payload()) - opts := tcpSeg.ParsedOptions() - r.RecentTS = opts.TSVal - return ackPacket -} - -// VerifyACKWithTS verifies that the tsEcr field in the ack matches the provided -// tsVal. -func (r *RawEndpoint) VerifyACKWithTS(tsVal uint32) { - r.C.t.Helper() - _ = r.VerifyAndReturnACKWithTS(tsVal) -} - -// VerifyACKRcvWnd verifies that the window advertised by the incoming ACK -// matches the provided rcvWnd. -func (r *RawEndpoint) VerifyACKRcvWnd(rcvWnd uint16) { - r.C.t.Helper() - ackPacket := r.C.GetPacket() - checker.IPv4(r.C.t, ackPacket, - checker.TCP( - checker.DstPort(r.SrcPort), - checker.TCPFlags(header.TCPFlagAck), - checker.TCPSeqNum(uint32(r.AckNum)), - checker.TCPAckNum(uint32(r.NextSeqNum)), - checker.TCPWindow(rcvWnd), - ), - ) -} - -// VerifyACKNoSACK verifies that the ACK does not contain a SACK block. -func (r *RawEndpoint) VerifyACKNoSACK() { - r.VerifyACKHasSACK(nil) -} - -// VerifyACKHasSACK verifies that the ACK contains the specified SACKBlocks. -func (r *RawEndpoint) VerifyACKHasSACK(sackBlocks []header.SACKBlock) { - // Read ACK and verify that the TCP options in the segment do - // not contain a SACK block. - ackPacket := r.C.GetPacket() - checker.IPv4(r.C.t, ackPacket, - checker.TCP( - checker.DstPort(r.SrcPort), - checker.TCPFlags(header.TCPFlagAck), - checker.TCPSeqNum(uint32(r.AckNum)), - checker.TCPAckNum(uint32(r.NextSeqNum)), - checker.TCPSACKBlockChecker(sackBlocks), - ), - ) -} - -// CreateConnectedWithOptions creates and connects c.ep with the specified TCP -// options enabled and returns a RawEndpoint which represents the other end of -// the connection. -// -// It also verifies where required(eg.Timestamp) that the ACK to the SYN-ACK -// does not carry an option that was not requested. -func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) *RawEndpoint { - var err tcpip.Error - c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - c.t.Fatalf("c.s.NewEndpoint(tcp, ipv4...) = %v", err) - } - if got, want := tcp.EndpointState(c.EP.State()), tcp.StateInitial; got != want { - c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got) - } - - // Start connection attempt. - waitEntry, notifyCh := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&waitEntry, waiter.WritableEvents) - defer c.WQ.EventUnregister(&waitEntry) - - testFullAddr := tcpip.FullAddress{Addr: TestAddr, Port: TestPort} - err = c.EP.Connect(testFullAddr) - if _, ok := err.(*tcpip.ErrConnectStarted); !ok { - c.t.Fatalf("c.ep.Connect(%v) = %v", testFullAddr, err) - } - // Receive SYN packet. - b := c.GetPacket() - // Validate that the syn has the timestamp option and a valid - // TS value. - mss := uint16(c.linkEP.MTU() - header.IPv4MinimumSize - header.TCPMinimumSize) - - checker.IPv4(c.t, b, - checker.TCP( - checker.DstPort(TestPort), - checker.TCPFlags(header.TCPFlagSyn), - checker.TCPSynOptions(header.TCPSynOptions{ - MSS: mss, - TS: true, - WS: int(c.WindowScale), - SACKPermitted: c.SACKEnabled(), - }), - ), - ) - if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want { - c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got) - } - - tcpSeg := header.TCP(header.IPv4(b).Payload()) - synOptions := header.ParseSynOptions(tcpSeg.Options(), false) - - // Build options w/ tsVal to be sent in the SYN-ACK. - synAckOptions := make([]byte, header.TCPOptionsMaximumSize) - offset := 0 - if wantOptions.WS != -1 { - offset += header.EncodeWSOption(wantOptions.WS, synAckOptions[offset:]) - } - if wantOptions.TS { - offset += header.EncodeTSOption(wantOptions.TSVal, synOptions.TSVal, synAckOptions[offset:]) - } - if wantOptions.SACKPermitted { - offset += header.EncodeSACKPermittedOption(synAckOptions[offset:]) - } - - offset += header.AddTCPOptionPadding(synAckOptions, offset) - - // Build SYN-ACK. - c.IRS = seqnum.Value(tcpSeg.SequenceNumber()) - iss := seqnum.Value(TestInitialSequenceNumber) - c.SendPacket(nil, &Headers{ - SrcPort: tcpSeg.DestinationPort(), - DstPort: tcpSeg.SourcePort(), - Flags: header.TCPFlagSyn | header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - TCPOpts: synAckOptions[:offset], - }) - - // Read ACK. - ackPacket := c.GetPacket() - - // Verify TCP header fields. - tcpCheckers := []checker.TransportChecker{ - checker.DstPort(TestPort), - checker.TCPFlags(header.TCPFlagAck), - checker.TCPSeqNum(uint32(c.IRS) + 1), - checker.TCPAckNum(uint32(iss) + 1), - } - - // Verify that tsEcr of ACK packet is wantOptions.TSVal if the - // timestamp option was enabled, if not then we verify that - // there is no timestamp in the ACK packet. - if wantOptions.TS { - tcpCheckers = append(tcpCheckers, checker.TCPTimestampChecker(true, 0, wantOptions.TSVal)) - } else { - tcpCheckers = append(tcpCheckers, checker.TCPTimestampChecker(false, 0, 0)) - } - - checker.IPv4(c.t, ackPacket, checker.TCP(tcpCheckers...)) - - ackSeg := header.TCP(header.IPv4(ackPacket).Payload()) - ackOptions := ackSeg.ParsedOptions() - - // Wait for connection to be established. - select { - case <-notifyCh: - if err := c.EP.LastError(); err != nil { - c.t.Fatalf("Unexpected error when connecting: %v", err) - } - case <-time.After(1 * time.Second): - c.t.Fatalf("Timed out waiting for connection") - } - if got, want := tcp.EndpointState(c.EP.State()), tcp.StateEstablished; got != want { - c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got) - } - - // Store the source port in use by the endpoint. - c.Port = tcpSeg.SourcePort() - - // Mark in context that timestamp option is enabled for this endpoint. - c.TimeStampEnabled = true - c.RcvdWindowScale = uint8(synOptions.WS) - return &RawEndpoint{ - C: c, - SrcPort: tcpSeg.DestinationPort(), - DstPort: tcpSeg.SourcePort(), - Flags: header.TCPFlagAck | header.TCPFlagPsh, - NextSeqNum: iss + 1, - AckNum: c.IRS.Add(1), - WndSize: 30000, - RecentTS: ackOptions.TSVal, - TSVal: wantOptions.TSVal, - SACKPermitted: wantOptions.SACKPermitted, - } -} - -// AcceptWithOptions initializes a listening endpoint and connects to it with the -// provided options enabled. It also verifies that the SYN-ACK has the expected -// values for the provided options. -// -// The function returns a RawEndpoint representing the other end of the accepted -// endpoint. -func (c *Context) AcceptWithOptions(wndScale int, synOptions header.TCPSynOptions) *RawEndpoint { - // Create EP and start listening. - wq := &waiter.Queue{} - ep, err := c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) - if err != nil { - c.t.Fatalf("NewEndpoint failed: %v", err) - } - defer ep.Close() - - if err := ep.Bind(tcpip.FullAddress{Port: StackPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) - } - if got, want := tcp.EndpointState(ep.State()), tcp.StateBound; got != want { - c.t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) - } - - if err := ep.Listen(10); err != nil { - c.t.Fatalf("Listen failed: %v", err) - } - if got, want := tcp.EndpointState(ep.State()), tcp.StateListen; got != want { - c.t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) - } - - rep := c.PassiveConnectWithOptions(100, wndScale, synOptions) - - // Try to accept the connection. - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.ReadableEvents) - defer wq.EventUnregister(&we) - - c.EP, _, err = ep.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { - // Wait for connection to be established. - select { - case <-ch: - c.EP, _, err = ep.Accept(nil) - if err != nil { - c.t.Fatalf("Accept failed: %v", err) - } - - case <-time.After(1 * time.Second): - c.t.Fatalf("Timed out waiting for accept") - } - } - if got, want := tcp.EndpointState(c.EP.State()), tcp.StateEstablished; got != want { - c.t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) - } - - return rep -} - -// PassiveConnect just disables WindowScaling and delegates the call to -// PassiveConnectWithOptions. -func (c *Context) PassiveConnect(maxPayload, wndScale int, synOptions header.TCPSynOptions) { - synOptions.WS = -1 - c.PassiveConnectWithOptions(maxPayload, wndScale, synOptions) -} - -// PassiveConnectWithOptions initiates a new connection (with the specified TCP -// options enabled) to the port on which the Context.ep is listening for new -// connections. It also validates that the SYN-ACK has the expected values for -// the enabled options. -// -// NOTE: MSS is not a negotiated option and it can be asymmetric -// in each direction. This function uses the maxPayload to set the MSS to be -// sent to the peer on a connect and validates that the MSS in the SYN-ACK -// response is equal to the MTU - (tcphdr len + iphdr len). -// -// wndScale is the expected window scale in the SYN-ACK and synOptions.WS is the -// value of the window scaling option to be sent in the SYN. If synOptions.WS > -// 0 then we send the WindowScale option. -func (c *Context) PassiveConnectWithOptions(maxPayload, wndScale int, synOptions header.TCPSynOptions) *RawEndpoint { - c.t.Helper() - opts := make([]byte, header.TCPOptionsMaximumSize) - offset := 0 - offset += header.EncodeMSSOption(uint32(maxPayload), opts) - - if synOptions.WS >= 0 { - offset += header.EncodeWSOption(3, opts[offset:]) - } - if synOptions.TS { - offset += header.EncodeTSOption(synOptions.TSVal, synOptions.TSEcr, opts[offset:]) - } - - if synOptions.SACKPermitted { - offset += header.EncodeSACKPermittedOption(opts[offset:]) - } - - paddingToAdd := 4 - offset%4 - // Now add any padding bytes that might be required to quad align the - // options. - for i := offset; i < offset+paddingToAdd; i++ { - opts[i] = header.TCPOptionNOP - } - offset += paddingToAdd - - // Send a SYN request. - iss := seqnum.Value(TestInitialSequenceNumber) - c.SendPacket(nil, &Headers{ - SrcPort: TestPort, - DstPort: StackPort, - Flags: header.TCPFlagSyn, - SeqNum: iss, - RcvWnd: 30000, - TCPOpts: opts[:offset], - }) - - // Receive the SYN-ACK reply. Make sure MSS and other expected options - // are present. - b := c.GetPacket() - tcp := header.TCP(header.IPv4(b).Payload()) - rcvdSynOptions := header.ParseSynOptions(tcp.Options(), true /* isAck */) - c.IRS = seqnum.Value(tcp.SequenceNumber()) - - tcpCheckers := []checker.TransportChecker{ - checker.SrcPort(StackPort), - checker.DstPort(TestPort), - checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn), - checker.TCPAckNum(uint32(iss) + 1), - checker.TCPSynOptions(header.TCPSynOptions{MSS: synOptions.MSS, WS: wndScale, SACKPermitted: synOptions.SACKPermitted && c.SACKEnabled()}), - } - - // If TS option was enabled in the original SYN then add a checker to - // validate the Timestamp option in the SYN-ACK. - if synOptions.TS { - tcpCheckers = append(tcpCheckers, checker.TCPTimestampChecker(synOptions.TS, 0, synOptions.TSVal)) - } else { - tcpCheckers = append(tcpCheckers, checker.TCPTimestampChecker(false, 0, 0)) - } - - checker.IPv4(c.t, b, checker.TCP(tcpCheckers...)) - rcvWnd := seqnum.Size(30000) - ackHeaders := &Headers{ - SrcPort: TestPort, - DstPort: StackPort, - Flags: header.TCPFlagAck, - SeqNum: iss + 1, - AckNum: c.IRS + 1, - RcvWnd: rcvWnd, - } - - // If WS was expected to be in effect then scale the advertised window - // correspondingly. - if synOptions.WS > 0 { - ackHeaders.RcvWnd = rcvWnd >> byte(synOptions.WS) - } - - parsedOpts := tcp.ParsedOptions() - if synOptions.TS { - // Echo the tsVal back to the peer in the tsEcr field of the - // timestamp option. - // Increment TSVal by 1 from the value sent in the SYN and echo - // the TSVal in the SYN-ACK in the TSEcr field. - opts := [12]byte{header.TCPOptionNOP, header.TCPOptionNOP} - header.EncodeTSOption(synOptions.TSVal+1, parsedOpts.TSVal, opts[2:]) - ackHeaders.TCPOpts = opts[:] - } - - // Send ACK. - c.SendPacket(nil, ackHeaders) - - c.RcvdWindowScale = uint8(rcvdSynOptions.WS) - c.Port = StackPort - - return &RawEndpoint{ - C: c, - SrcPort: TestPort, - DstPort: StackPort, - Flags: header.TCPFlagPsh | header.TCPFlagAck, - NextSeqNum: iss + 1, - AckNum: c.IRS + 1, - WndSize: rcvWnd, - SACKPermitted: synOptions.SACKPermitted && c.SACKEnabled(), - RecentTS: parsedOpts.TSVal, - TSVal: synOptions.TSVal + 1, - } -} - -// SACKEnabled returns true if the TCP Protocol option SACKEnabled is set to true -// for the Stack in the context. -func (c *Context) SACKEnabled() bool { - var v tcpip.TCPSACKEnabled - if err := c.Stack().TransportProtocolOption(tcp.ProtocolNumber, &v); err != nil { - // Stack doesn't support SACK. So just return. - return false - } - return bool(v) -} - -// SetGSOEnabled enables or disables generic segmentation offload. -func (c *Context) SetGSOEnabled(enable bool) { - if enable { - c.linkEP.SupportedGSOKind = stack.HWGSOSupported - } else { - c.linkEP.SupportedGSOKind = stack.GSONotSupported - } -} - -// MSSWithoutOptions returns the value for the MSS used by the stack when no -// options are in use. -func (c *Context) MSSWithoutOptions() uint16 { - return uint16(c.linkEP.MTU() - header.IPv4MinimumSize - header.TCPMinimumSize) -} - -// MSSWithoutOptionsV6 returns the value for the MSS used by the stack when no -// options are in use for IPv6 packets. -func (c *Context) MSSWithoutOptionsV6() uint16 { - return uint16(c.linkEP.MTU() - header.IPv6MinimumSize - header.TCPMinimumSize) -} diff --git a/pkg/tcpip/transport/tcp/timer_test.go b/pkg/tcpip/transport/tcp/timer_test.go deleted file mode 100644 index 479752de7..000000000 --- a/pkg/tcpip/transport/tcp/timer_test.go +++ /dev/null @@ -1,50 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tcp - -import ( - "testing" - "time" - - "gvisor.dev/gvisor/pkg/sleep" - "gvisor.dev/gvisor/pkg/tcpip/faketime" -) - -func TestCleanup(t *testing.T) { - const ( - timerDurationSeconds = 2 - isAssertedTimeoutSeconds = timerDurationSeconds + 1 - ) - - clock := faketime.NewManualClock() - - tmr := timer{} - w := sleep.Waker{} - tmr.init(clock, &w) - tmr.enable(timerDurationSeconds * time.Second) - tmr.cleanup() - - if want := (timer{}); tmr != want { - t.Errorf("got tmr = %+v, want = %+v", tmr, want) - } - - // The waker should not be asserted. - for i := 0; i < isAssertedTimeoutSeconds; i++ { - clock.Advance(time.Second) - if w.IsAsserted() { - t.Fatalf("waker asserted unexpectedly") - } - } -} diff --git a/pkg/tcpip/transport/tcpconntrack/BUILD b/pkg/tcpip/transport/tcpconntrack/BUILD deleted file mode 100644 index 3ad6994a7..000000000 --- a/pkg/tcpip/transport/tcpconntrack/BUILD +++ /dev/null @@ -1,23 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "tcpconntrack", - srcs = ["tcp_conntrack.go"], - visibility = ["//visibility:public"], - deps = [ - "//pkg/tcpip/header", - "//pkg/tcpip/seqnum", - ], -) - -go_test( - name = "tcpconntrack_test", - size = "small", - srcs = ["tcp_conntrack_test.go"], - deps = [ - ":tcpconntrack", - "//pkg/tcpip/header", - ], -) diff --git a/pkg/tcpip/transport/tcpconntrack/tcp_conntrack_test.go b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack_test.go deleted file mode 100644 index 6c5ddc3c7..000000000 --- a/pkg/tcpip/transport/tcpconntrack/tcp_conntrack_test.go +++ /dev/null @@ -1,511 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tcpconntrack_test - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcpconntrack" -) - -// connected creates a connection tracker TCB and sets it to a connected state -// by performing a 3-way handshake. -func connected(t *testing.T, iss, irs uint32, isw, irw uint16) *tcpconntrack.TCB { - // Send SYN. - tcp := make(header.TCP, header.TCPMinimumSize) - tcp.Encode(&header.TCPFields{ - SeqNum: iss, - AckNum: 0, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagSyn, - WindowSize: irw, - }) - - tcb := tcpconntrack.TCB{} - tcb.Init(tcp) - - // Receive SYN-ACK. - tcp.Encode(&header.TCPFields{ - SeqNum: irs, - AckNum: iss + 1, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagSyn | header.TCPFlagAck, - WindowSize: isw, - }) - - if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - - // Send ACK. - tcp.Encode(&header.TCPFields{ - SeqNum: iss + 1, - AckNum: irs + 1, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagAck, - WindowSize: irw, - }) - - if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - - return &tcb -} - -func TestConnectionRefused(t *testing.T) { - // Send SYN. - tcp := make(header.TCP, header.TCPMinimumSize) - tcp.Encode(&header.TCPFields{ - SeqNum: 1234, - AckNum: 0, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagSyn, - WindowSize: 30000, - }) - - tcb := tcpconntrack.TCB{} - tcb.Init(tcp) - - // Receive RST. - tcp.Encode(&header.TCPFields{ - SeqNum: 789, - AckNum: 1235, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagRst | header.TCPFlagAck, - WindowSize: 50000, - }) - - if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultReset { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultReset) - } -} - -func TestConnectionRefusedInSynRcvd(t *testing.T) { - // Send SYN. - tcp := make(header.TCP, header.TCPMinimumSize) - tcp.Encode(&header.TCPFields{ - SeqNum: 1234, - AckNum: 0, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagSyn, - WindowSize: 30000, - }) - - tcb := tcpconntrack.TCB{} - tcb.Init(tcp) - - // Receive SYN. - tcp.Encode(&header.TCPFields{ - SeqNum: 789, - AckNum: 0, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagSyn, - WindowSize: 50000, - }) - - if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - - // Receive RST with no ACK. - tcp.Encode(&header.TCPFields{ - SeqNum: 790, - AckNum: 0, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagRst, - WindowSize: 50000, - }) - - if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultReset { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultReset) - } -} - -func TestConnectionResetInSynRcvd(t *testing.T) { - // Send SYN. - tcp := make(header.TCP, header.TCPMinimumSize) - tcp.Encode(&header.TCPFields{ - SeqNum: 1234, - AckNum: 0, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagSyn, - WindowSize: 30000, - }) - - tcb := tcpconntrack.TCB{} - tcb.Init(tcp) - - // Receive SYN. - tcp.Encode(&header.TCPFields{ - SeqNum: 789, - AckNum: 0, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagSyn, - WindowSize: 50000, - }) - - if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - - // Send RST with no ACK. - tcp.Encode(&header.TCPFields{ - SeqNum: 1235, - AckNum: 0, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagRst, - }) - - if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultReset { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultReset) - } -} - -func TestRetransmitOnSynSent(t *testing.T) { - // Send initial SYN. - tcp := make(header.TCP, header.TCPMinimumSize) - tcp.Encode(&header.TCPFields{ - SeqNum: 1234, - AckNum: 0, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagSyn, - WindowSize: 30000, - }) - - tcb := tcpconntrack.TCB{} - tcb.Init(tcp) - - // Retransmit the same SYN. - if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultConnecting { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultConnecting) - } -} - -func TestRetransmitOnSynRcvd(t *testing.T) { - // Send initial SYN. - tcp := make(header.TCP, header.TCPMinimumSize) - tcp.Encode(&header.TCPFields{ - SeqNum: 1234, - AckNum: 0, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagSyn, - WindowSize: 30000, - }) - - tcb := tcpconntrack.TCB{} - tcb.Init(tcp) - - // Receive SYN. This will cause the state to go to SYN-RCVD. - tcp.Encode(&header.TCPFields{ - SeqNum: 789, - AckNum: 0, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagSyn, - WindowSize: 50000, - }) - - if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - - // Retransmit the original SYN. - tcp.Encode(&header.TCPFields{ - SeqNum: 1234, - AckNum: 0, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagSyn, - WindowSize: 30000, - }) - - if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - - // Transmit a SYN-ACK. - tcp.Encode(&header.TCPFields{ - SeqNum: 1234, - AckNum: 790, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagSyn | header.TCPFlagAck, - WindowSize: 30000, - }) - - if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } -} - -func TestClosedBySelf(t *testing.T) { - tcb := connected(t, 1234, 789, 30000, 50000) - - // Send FIN. - tcp := make(header.TCP, header.TCPMinimumSize) - tcp.Encode(&header.TCPFields{ - SeqNum: 1235, - AckNum: 790, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagAck | header.TCPFlagFin, - WindowSize: 30000, - }) - - if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - - // Receive FIN/ACK. - tcp.Encode(&header.TCPFields{ - SeqNum: 790, - AckNum: 1236, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagAck | header.TCPFlagFin, - WindowSize: 50000, - }) - - if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - - // Send ACK. - tcp.Encode(&header.TCPFields{ - SeqNum: 1236, - AckNum: 791, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagAck, - WindowSize: 30000, - }) - - if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultClosedBySelf { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultClosedBySelf) - } -} - -func TestClosedByPeer(t *testing.T) { - tcb := connected(t, 1234, 789, 30000, 50000) - - // Receive FIN. - tcp := make(header.TCP, header.TCPMinimumSize) - tcp.Encode(&header.TCPFields{ - SeqNum: 790, - AckNum: 1235, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagAck | header.TCPFlagFin, - WindowSize: 50000, - }) - - if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - - // Send FIN/ACK. - tcp.Encode(&header.TCPFields{ - SeqNum: 1235, - AckNum: 791, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagAck | header.TCPFlagFin, - WindowSize: 30000, - }) - - if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - - // Receive ACK. - tcp.Encode(&header.TCPFields{ - SeqNum: 791, - AckNum: 1236, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagAck, - WindowSize: 50000, - }) - - if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultClosedByPeer { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultClosedByPeer) - } -} - -func TestSendAndReceiveDataClosedBySelf(t *testing.T) { - sseq := uint32(1234) - rseq := uint32(789) - tcb := connected(t, sseq, rseq, 30000, 50000) - sseq++ - rseq++ - - // Send some data. - tcp := make(header.TCP, header.TCPMinimumSize+1024) - - for i := uint32(0); i < 10; i++ { - // Send some data. - tcp.Encode(&header.TCPFields{ - SeqNum: sseq, - AckNum: rseq, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagAck, - WindowSize: 30000, - }) - sseq += uint32(len(tcp)) - header.TCPMinimumSize - - if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - - // Receive ack for data. - tcp.Encode(&header.TCPFields{ - SeqNum: rseq, - AckNum: sseq, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagAck, - WindowSize: 50000, - }) - - if r := tcb.UpdateStateInbound(tcp[:header.TCPMinimumSize]); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - } - - for i := uint32(0); i < 10; i++ { - // Receive some data. - tcp.Encode(&header.TCPFields{ - SeqNum: rseq, - AckNum: sseq, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagAck, - WindowSize: 50000, - }) - rseq += uint32(len(tcp)) - header.TCPMinimumSize - - if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - - // Send ack for data. - tcp.Encode(&header.TCPFields{ - SeqNum: sseq, - AckNum: rseq, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagAck, - WindowSize: 30000, - }) - - if r := tcb.UpdateStateOutbound(tcp[:header.TCPMinimumSize]); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - } - - // Send FIN. - tcp = tcp[:header.TCPMinimumSize] - tcp.Encode(&header.TCPFields{ - SeqNum: sseq, - AckNum: rseq, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagAck | header.TCPFlagFin, - WindowSize: 30000, - }) - sseq++ - - if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - - // Receive FIN/ACK. - tcp.Encode(&header.TCPFields{ - SeqNum: rseq, - AckNum: sseq, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagAck | header.TCPFlagFin, - WindowSize: 50000, - }) - rseq++ - - if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - - // Send ACK. - tcp.Encode(&header.TCPFields{ - SeqNum: sseq, - AckNum: rseq, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagAck, - WindowSize: 30000, - }) - - if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultClosedBySelf { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultClosedBySelf) - } -} - -func TestIgnoreBadResetOnSynSent(t *testing.T) { - // Send SYN. - tcp := make(header.TCP, header.TCPMinimumSize) - tcp.Encode(&header.TCPFields{ - SeqNum: 1234, - AckNum: 0, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagSyn, - WindowSize: 30000, - }) - - tcb := tcpconntrack.TCB{} - tcb.Init(tcp) - - // Receive a RST with a bad ACK, it should not cause the connection to - // be reset. - acks := []uint32{1234, 1236, 1000, 5000} - flags := []header.TCPFlags{header.TCPFlagRst, header.TCPFlagRst | header.TCPFlagAck} - for _, a := range acks { - for _, f := range flags { - tcp.Encode(&header.TCPFields{ - SeqNum: 789, - AckNum: a, - DataOffset: header.TCPMinimumSize, - Flags: f, - WindowSize: 50000, - }) - - if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultConnecting { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - } - } - - // Complete the handshake. - // Receive SYN-ACK. - tcp.Encode(&header.TCPFields{ - SeqNum: 789, - AckNum: 1235, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagSyn | header.TCPFlagAck, - WindowSize: 50000, - }) - - if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - - // Send ACK. - tcp.Encode(&header.TCPFields{ - SeqNum: 1235, - AckNum: 790, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagAck, - WindowSize: 30000, - }) - - if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } -} diff --git a/pkg/tcpip/transport/tcpconntrack/tcpconntrack_state_autogen.go b/pkg/tcpip/transport/tcpconntrack/tcpconntrack_state_autogen.go new file mode 100644 index 000000000..ff53204da --- /dev/null +++ b/pkg/tcpip/transport/tcpconntrack/tcpconntrack_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package tcpconntrack diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD deleted file mode 100644 index cdc344ab7..000000000 --- a/pkg/tcpip/transport/udp/BUILD +++ /dev/null @@ -1,65 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -package(licenses = ["notice"]) - -go_template_instance( - name = "udp_packet_list", - out = "udp_packet_list.go", - package = "udp", - prefix = "udpPacket", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*udpPacket", - "Linker": "*udpPacket", - }, -) - -go_library( - name = "udp", - srcs = [ - "endpoint.go", - "endpoint_state.go", - "forwarder.go", - "protocol.go", - "udp_packet_list.go", - ], - imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"], - visibility = ["//visibility:public"], - deps = [ - "//pkg/sleep", - "//pkg/sync", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/header/parse", - "//pkg/tcpip/ports", - "//pkg/tcpip/stack", - "//pkg/tcpip/transport/raw", - "//pkg/waiter", - ], -) - -go_test( - name = "udp_x_test", - size = "small", - srcs = ["udp_test.go"], - deps = [ - ":udp", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/checker", - "//pkg/tcpip/faketime", - "//pkg/tcpip/header", - "//pkg/tcpip/link/channel", - "//pkg/tcpip/link/loopback", - "//pkg/tcpip/link/sniffer", - "//pkg/tcpip/network/ipv4", - "//pkg/tcpip/network/ipv6", - "//pkg/tcpip/stack", - "//pkg/tcpip/testutil", - "//pkg/tcpip/transport/icmp", - "//pkg/waiter", - "@com_github_google_go_cmp//cmp:go_default_library", - ], -) diff --git a/pkg/tcpip/transport/udp/udp_packet_list.go b/pkg/tcpip/transport/udp/udp_packet_list.go new file mode 100644 index 000000000..c396f77c9 --- /dev/null +++ b/pkg/tcpip/transport/udp/udp_packet_list.go @@ -0,0 +1,221 @@ +package udp + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type udpPacketElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (udpPacketElementMapper) linkerFor(elem *udpPacket) *udpPacket { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type udpPacketList struct { + head *udpPacket + tail *udpPacket +} + +// Reset resets list l to the empty state. +func (l *udpPacketList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +// +//go:nosplit +func (l *udpPacketList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +// +//go:nosplit +func (l *udpPacketList) Front() *udpPacket { + return l.head +} + +// Back returns the last element of list l or nil. +// +//go:nosplit +func (l *udpPacketList) Back() *udpPacket { + return l.tail +} + +// Len returns the number of elements in the list. +// +// NOTE: This is an O(n) operation. +// +//go:nosplit +func (l *udpPacketList) Len() (count int) { + for e := l.Front(); e != nil; e = (udpPacketElementMapper{}.linkerFor(e)).Next() { + count++ + } + return count +} + +// PushFront inserts the element e at the front of list l. +// +//go:nosplit +func (l *udpPacketList) PushFront(e *udpPacket) { + linker := udpPacketElementMapper{}.linkerFor(e) + linker.SetNext(l.head) + linker.SetPrev(nil) + if l.head != nil { + udpPacketElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +// +//go:nosplit +func (l *udpPacketList) PushBack(e *udpPacket) { + linker := udpPacketElementMapper{}.linkerFor(e) + linker.SetNext(nil) + linker.SetPrev(l.tail) + if l.tail != nil { + udpPacketElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +// +//go:nosplit +func (l *udpPacketList) PushBackList(m *udpPacketList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + udpPacketElementMapper{}.linkerFor(l.tail).SetNext(m.head) + udpPacketElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +// +//go:nosplit +func (l *udpPacketList) InsertAfter(b, e *udpPacket) { + bLinker := udpPacketElementMapper{}.linkerFor(b) + eLinker := udpPacketElementMapper{}.linkerFor(e) + + a := bLinker.Next() + + eLinker.SetNext(a) + eLinker.SetPrev(b) + bLinker.SetNext(e) + + if a != nil { + udpPacketElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +// +//go:nosplit +func (l *udpPacketList) InsertBefore(a, e *udpPacket) { + aLinker := udpPacketElementMapper{}.linkerFor(a) + eLinker := udpPacketElementMapper{}.linkerFor(e) + + b := aLinker.Prev() + eLinker.SetNext(a) + eLinker.SetPrev(b) + aLinker.SetPrev(e) + + if b != nil { + udpPacketElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +// +//go:nosplit +func (l *udpPacketList) Remove(e *udpPacket) { + linker := udpPacketElementMapper{}.linkerFor(e) + prev := linker.Prev() + next := linker.Next() + + if prev != nil { + udpPacketElementMapper{}.linkerFor(prev).SetNext(next) + } else if l.head == e { + l.head = next + } + + if next != nil { + udpPacketElementMapper{}.linkerFor(next).SetPrev(prev) + } else if l.tail == e { + l.tail = prev + } + + linker.SetNext(nil) + linker.SetPrev(nil) +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type udpPacketEntry struct { + next *udpPacket + prev *udpPacket +} + +// Next returns the entry that follows e in the list. +// +//go:nosplit +func (e *udpPacketEntry) Next() *udpPacket { + return e.next +} + +// Prev returns the entry that precedes e in the list. +// +//go:nosplit +func (e *udpPacketEntry) Prev() *udpPacket { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +// +//go:nosplit +func (e *udpPacketEntry) SetNext(elem *udpPacket) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +// +//go:nosplit +func (e *udpPacketEntry) SetPrev(elem *udpPacket) { + e.prev = elem +} diff --git a/pkg/tcpip/transport/udp/udp_state_autogen.go b/pkg/tcpip/transport/udp/udp_state_autogen.go new file mode 100644 index 000000000..fd87cf524 --- /dev/null +++ b/pkg/tcpip/transport/udp/udp_state_autogen.go @@ -0,0 +1,241 @@ +// automatically generated by stateify. + +package udp + +import ( + "gvisor.dev/gvisor/pkg/state" + "gvisor.dev/gvisor/pkg/tcpip/buffer" +) + +func (p *udpPacket) StateTypeName() string { + return "pkg/tcpip/transport/udp.udpPacket" +} + +func (p *udpPacket) StateFields() []string { + return []string{ + "udpPacketEntry", + "senderAddress", + "destinationAddress", + "packetInfo", + "data", + "receivedAt", + "tos", + } +} + +func (p *udpPacket) beforeSave() {} + +// +checklocksignore +func (p *udpPacket) StateSave(stateSinkObject state.Sink) { + p.beforeSave() + var dataValue buffer.VectorisedView + dataValue = p.saveData() + stateSinkObject.SaveValue(4, dataValue) + var receivedAtValue int64 + receivedAtValue = p.saveReceivedAt() + stateSinkObject.SaveValue(5, receivedAtValue) + stateSinkObject.Save(0, &p.udpPacketEntry) + stateSinkObject.Save(1, &p.senderAddress) + stateSinkObject.Save(2, &p.destinationAddress) + stateSinkObject.Save(3, &p.packetInfo) + stateSinkObject.Save(6, &p.tos) +} + +func (p *udpPacket) afterLoad() {} + +// +checklocksignore +func (p *udpPacket) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &p.udpPacketEntry) + stateSourceObject.Load(1, &p.senderAddress) + stateSourceObject.Load(2, &p.destinationAddress) + stateSourceObject.Load(3, &p.packetInfo) + stateSourceObject.Load(6, &p.tos) + stateSourceObject.LoadValue(4, new(buffer.VectorisedView), func(y interface{}) { p.loadData(y.(buffer.VectorisedView)) }) + stateSourceObject.LoadValue(5, new(int64), func(y interface{}) { p.loadReceivedAt(y.(int64)) }) +} + +func (e *endpoint) StateTypeName() string { + return "pkg/tcpip/transport/udp.endpoint" +} + +func (e *endpoint) StateFields() []string { + return []string{ + "TransportEndpointInfo", + "DefaultSocketOptionsHandler", + "waiterQueue", + "uniqueID", + "rcvReady", + "rcvList", + "rcvBufSize", + "rcvClosed", + "state", + "dstPort", + "ttl", + "multicastTTL", + "multicastAddr", + "multicastNICID", + "portFlags", + "lastError", + "boundBindToDevice", + "boundPortFlags", + "sendTOS", + "shutdownFlags", + "multicastMemberships", + "effectiveNetProtos", + "owner", + "ops", + "frozen", + } +} + +// +checklocksignore +func (e *endpoint) StateSave(stateSinkObject state.Sink) { + e.beforeSave() + stateSinkObject.Save(0, &e.TransportEndpointInfo) + stateSinkObject.Save(1, &e.DefaultSocketOptionsHandler) + stateSinkObject.Save(2, &e.waiterQueue) + stateSinkObject.Save(3, &e.uniqueID) + stateSinkObject.Save(4, &e.rcvReady) + stateSinkObject.Save(5, &e.rcvList) + stateSinkObject.Save(6, &e.rcvBufSize) + stateSinkObject.Save(7, &e.rcvClosed) + stateSinkObject.Save(8, &e.state) + stateSinkObject.Save(9, &e.dstPort) + stateSinkObject.Save(10, &e.ttl) + stateSinkObject.Save(11, &e.multicastTTL) + stateSinkObject.Save(12, &e.multicastAddr) + stateSinkObject.Save(13, &e.multicastNICID) + stateSinkObject.Save(14, &e.portFlags) + stateSinkObject.Save(15, &e.lastError) + stateSinkObject.Save(16, &e.boundBindToDevice) + stateSinkObject.Save(17, &e.boundPortFlags) + stateSinkObject.Save(18, &e.sendTOS) + stateSinkObject.Save(19, &e.shutdownFlags) + stateSinkObject.Save(20, &e.multicastMemberships) + stateSinkObject.Save(21, &e.effectiveNetProtos) + stateSinkObject.Save(22, &e.owner) + stateSinkObject.Save(23, &e.ops) + stateSinkObject.Save(24, &e.frozen) +} + +// +checklocksignore +func (e *endpoint) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &e.TransportEndpointInfo) + stateSourceObject.Load(1, &e.DefaultSocketOptionsHandler) + stateSourceObject.Load(2, &e.waiterQueue) + stateSourceObject.Load(3, &e.uniqueID) + stateSourceObject.Load(4, &e.rcvReady) + stateSourceObject.Load(5, &e.rcvList) + stateSourceObject.Load(6, &e.rcvBufSize) + stateSourceObject.Load(7, &e.rcvClosed) + stateSourceObject.Load(8, &e.state) + stateSourceObject.Load(9, &e.dstPort) + stateSourceObject.Load(10, &e.ttl) + stateSourceObject.Load(11, &e.multicastTTL) + stateSourceObject.Load(12, &e.multicastAddr) + stateSourceObject.Load(13, &e.multicastNICID) + stateSourceObject.Load(14, &e.portFlags) + stateSourceObject.Load(15, &e.lastError) + stateSourceObject.Load(16, &e.boundBindToDevice) + stateSourceObject.Load(17, &e.boundPortFlags) + stateSourceObject.Load(18, &e.sendTOS) + stateSourceObject.Load(19, &e.shutdownFlags) + stateSourceObject.Load(20, &e.multicastMemberships) + stateSourceObject.Load(21, &e.effectiveNetProtos) + stateSourceObject.Load(22, &e.owner) + stateSourceObject.Load(23, &e.ops) + stateSourceObject.Load(24, &e.frozen) + stateSourceObject.AfterLoad(e.afterLoad) +} + +func (m *multicastMembership) StateTypeName() string { + return "pkg/tcpip/transport/udp.multicastMembership" +} + +func (m *multicastMembership) StateFields() []string { + return []string{ + "nicID", + "multicastAddr", + } +} + +func (m *multicastMembership) beforeSave() {} + +// +checklocksignore +func (m *multicastMembership) StateSave(stateSinkObject state.Sink) { + m.beforeSave() + stateSinkObject.Save(0, &m.nicID) + stateSinkObject.Save(1, &m.multicastAddr) +} + +func (m *multicastMembership) afterLoad() {} + +// +checklocksignore +func (m *multicastMembership) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &m.nicID) + stateSourceObject.Load(1, &m.multicastAddr) +} + +func (l *udpPacketList) StateTypeName() string { + return "pkg/tcpip/transport/udp.udpPacketList" +} + +func (l *udpPacketList) StateFields() []string { + return []string{ + "head", + "tail", + } +} + +func (l *udpPacketList) beforeSave() {} + +// +checklocksignore +func (l *udpPacketList) StateSave(stateSinkObject state.Sink) { + l.beforeSave() + stateSinkObject.Save(0, &l.head) + stateSinkObject.Save(1, &l.tail) +} + +func (l *udpPacketList) afterLoad() {} + +// +checklocksignore +func (l *udpPacketList) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &l.head) + stateSourceObject.Load(1, &l.tail) +} + +func (e *udpPacketEntry) StateTypeName() string { + return "pkg/tcpip/transport/udp.udpPacketEntry" +} + +func (e *udpPacketEntry) StateFields() []string { + return []string{ + "next", + "prev", + } +} + +func (e *udpPacketEntry) beforeSave() {} + +// +checklocksignore +func (e *udpPacketEntry) StateSave(stateSinkObject state.Sink) { + e.beforeSave() + stateSinkObject.Save(0, &e.next) + stateSinkObject.Save(1, &e.prev) +} + +func (e *udpPacketEntry) afterLoad() {} + +// +checklocksignore +func (e *udpPacketEntry) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &e.next) + stateSourceObject.Load(1, &e.prev) +} + +func init() { + state.Register((*udpPacket)(nil)) + state.Register((*endpoint)(nil)) + state.Register((*multicastMembership)(nil)) + state.Register((*udpPacketList)(nil)) + state.Register((*udpPacketEntry)(nil)) +} diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go deleted file mode 100644 index 4008cacf2..000000000 --- a/pkg/tcpip/transport/udp/udp_test.go +++ /dev/null @@ -1,2559 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package udp_test - -import ( - "bytes" - "fmt" - "io/ioutil" - "math/rand" - "testing" - - "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/checker" - "gvisor.dev/gvisor/pkg/tcpip/faketime" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/link/loopback" - "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/testutil" - "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" - "gvisor.dev/gvisor/pkg/tcpip/transport/udp" - "gvisor.dev/gvisor/pkg/waiter" -) - -// Addresses and ports used for testing. It is recommended that tests stick to -// using these addresses as it allows using the testFlow helper. -// Naming rules: 'stack*'' denotes local addresses and ports, while 'test*' -// represents the remote endpoint. -const ( - v4MappedAddrPrefix = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" - stackV6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - testV6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" - stackV4MappedAddr = v4MappedAddrPrefix + stackAddr - testV4MappedAddr = v4MappedAddrPrefix + testAddr - multicastV4MappedAddr = v4MappedAddrPrefix + multicastAddr - broadcastV4MappedAddr = v4MappedAddrPrefix + broadcastAddr - v4MappedWildcardAddr = v4MappedAddrPrefix + "\x00\x00\x00\x00" - - stackAddr = "\x0a\x00\x00\x01" - stackPort = 1234 - testAddr = "\x0a\x00\x00\x02" - testPort = 4096 - invalidPort = 8192 - multicastAddr = "\xe8\x2b\xd3\xea" - multicastV6Addr = "\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" - broadcastAddr = header.IPv4Broadcast - testTOS = 0x80 - - // defaultMTU is the MTU, in bytes, used throughout the tests, except - // where another value is explicitly used. It is chosen to match the MTU - // of loopback interfaces on linux systems. - defaultMTU = 65536 -) - -// header4Tuple stores the 4-tuple {src-IP, src-port, dst-IP, dst-port} used in -// a packet header. These values are used to populate a header or verify one. -// Note that because they are used in packet headers, the addresses are never in -// a V4-mapped format. -type header4Tuple struct { - srcAddr tcpip.FullAddress - dstAddr tcpip.FullAddress -} - -// testFlow implements a helper type used for sending and receiving test -// packets. A given test flow value defines 1) the socket endpoint used for the -// test and 2) the type of packet send or received on the endpoint. E.g., a -// multicastV6Only flow is a V6 multicast packet passing through a V6-only -// endpoint. The type provides helper methods to characterize the flow (e.g., -// isV4) as well as return a proper header4Tuple for it. -type testFlow int - -const ( - unicastV4 testFlow = iota // V4 unicast on a V4 socket - unicastV4in6 // V4-mapped unicast on a V6-dual socket - unicastV6 // V6 unicast on a V6 socket - unicastV6Only // V6 unicast on a V6-only socket - multicastV4 // V4 multicast on a V4 socket - multicastV4in6 // V4-mapped multicast on a V6-dual socket - multicastV6 // V6 multicast on a V6 socket - multicastV6Only // V6 multicast on a V6-only socket - broadcast // V4 broadcast on a V4 socket - broadcastIn6 // V4-mapped broadcast on a V6-dual socket - reverseMulticast4 // V4 multicast src. Must fail. - reverseMulticast6 // V6 multicast src. Must fail. -) - -func (flow testFlow) String() string { - switch flow { - case unicastV4: - return "unicastV4" - case unicastV6: - return "unicastV6" - case unicastV6Only: - return "unicastV6Only" - case unicastV4in6: - return "unicastV4in6" - case multicastV4: - return "multicastV4" - case multicastV6: - return "multicastV6" - case multicastV6Only: - return "multicastV6Only" - case multicastV4in6: - return "multicastV4in6" - case broadcast: - return "broadcast" - case broadcastIn6: - return "broadcastIn6" - case reverseMulticast4: - return "reverseMulticast4" - case reverseMulticast6: - return "reverseMulticast6" - default: - return "unknown" - } -} - -// packetDirection explains if a flow is incoming (read) or outgoing (write). -type packetDirection int - -const ( - incoming packetDirection = iota - outgoing -) - -// header4Tuple returns the header4Tuple for the given flow and direction. Note -// that the tuple contains no mapped addresses as those only exist at the socket -// level but not at the packet header level. -func (flow testFlow) header4Tuple(d packetDirection) header4Tuple { - var h header4Tuple - if flow.isV4() { - if d == outgoing { - h = header4Tuple{ - srcAddr: tcpip.FullAddress{Addr: stackAddr, Port: stackPort}, - dstAddr: tcpip.FullAddress{Addr: testAddr, Port: testPort}, - } - } else { - h = header4Tuple{ - srcAddr: tcpip.FullAddress{Addr: testAddr, Port: testPort}, - dstAddr: tcpip.FullAddress{Addr: stackAddr, Port: stackPort}, - } - } - if flow.isMulticast() { - h.dstAddr.Addr = multicastAddr - } else if flow.isBroadcast() { - h.dstAddr.Addr = broadcastAddr - } - } else { // IPv6 - if d == outgoing { - h = header4Tuple{ - srcAddr: tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort}, - dstAddr: tcpip.FullAddress{Addr: testV6Addr, Port: testPort}, - } - } else { - h = header4Tuple{ - srcAddr: tcpip.FullAddress{Addr: testV6Addr, Port: testPort}, - dstAddr: tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort}, - } - } - if flow.isMulticast() { - h.dstAddr.Addr = multicastV6Addr - } - } - if flow.isReverseMulticast() { - h.srcAddr.Addr = flow.getMcastAddr() - } - return h -} - -func (flow testFlow) getMcastAddr() tcpip.Address { - if flow.isV4() { - return multicastAddr - } - return multicastV6Addr -} - -// mapAddrIfApplicable converts the given V4 address into its V4-mapped version -// if it is applicable to the flow. -func (flow testFlow) mapAddrIfApplicable(v4Addr tcpip.Address) tcpip.Address { - if flow.isMapped() { - return v4MappedAddrPrefix + v4Addr - } - return v4Addr -} - -// netProto returns the protocol number used for the network packet. -func (flow testFlow) netProto() tcpip.NetworkProtocolNumber { - if flow.isV4() { - return ipv4.ProtocolNumber - } - return ipv6.ProtocolNumber -} - -// sockProto returns the protocol number used when creating the socket -// endpoint for this flow. -func (flow testFlow) sockProto() tcpip.NetworkProtocolNumber { - switch flow { - case unicastV4in6, unicastV6, unicastV6Only, multicastV4in6, multicastV6, multicastV6Only, broadcastIn6, reverseMulticast6: - return ipv6.ProtocolNumber - case unicastV4, multicastV4, broadcast, reverseMulticast4: - return ipv4.ProtocolNumber - default: - panic(fmt.Sprintf("invalid testFlow given: %d", flow)) - } -} - -func (flow testFlow) checkerFn() func(*testing.T, []byte, ...checker.NetworkChecker) { - if flow.isV4() { - return checker.IPv4 - } - return checker.IPv6 -} - -func (flow testFlow) isV6() bool { return !flow.isV4() } -func (flow testFlow) isV4() bool { - return flow.sockProto() == ipv4.ProtocolNumber || flow.isMapped() -} - -func (flow testFlow) isV6Only() bool { - switch flow { - case unicastV6Only, multicastV6Only: - return true - case unicastV4, unicastV4in6, unicastV6, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6, reverseMulticast4, reverseMulticast6: - return false - default: - panic(fmt.Sprintf("invalid testFlow given: %d", flow)) - } -} - -func (flow testFlow) isMulticast() bool { - switch flow { - case multicastV4, multicastV4in6, multicastV6, multicastV6Only: - return true - case unicastV4, unicastV4in6, unicastV6, unicastV6Only, broadcast, broadcastIn6, reverseMulticast4, reverseMulticast6: - return false - default: - panic(fmt.Sprintf("invalid testFlow given: %d", flow)) - } -} - -func (flow testFlow) isBroadcast() bool { - switch flow { - case broadcast, broadcastIn6: - return true - case unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, multicastV6Only, reverseMulticast4, reverseMulticast6: - return false - default: - panic(fmt.Sprintf("invalid testFlow given: %d", flow)) - } -} - -func (flow testFlow) isMapped() bool { - switch flow { - case unicastV4in6, multicastV4in6, broadcastIn6: - return true - case unicastV4, unicastV6, unicastV6Only, multicastV4, multicastV6, multicastV6Only, broadcast, reverseMulticast4, reverseMulticast6: - return false - default: - panic(fmt.Sprintf("invalid testFlow given: %d", flow)) - } -} - -func (flow testFlow) isReverseMulticast() bool { - switch flow { - case reverseMulticast4, reverseMulticast6: - return true - default: - return false - } -} - -type testContext struct { - t *testing.T - linkEP *channel.Endpoint - s *stack.Stack - - ep tcpip.Endpoint - wq waiter.Queue -} - -func newDualTestContext(t *testing.T, mtu uint32) *testContext { - t.Helper() - return newDualTestContextWithHandleLocal(t, mtu, true) -} - -func newDualTestContextWithHandleLocal(t *testing.T, mtu uint32, handleLocal bool) *testContext { - t.Helper() - - options := stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4}, - HandleLocal: handleLocal, - Clock: &faketime.NullClock{}, - } - s := stack.New(options) - ep := channel.New(256, mtu, "") - wep := stack.LinkEndpoint(ep) - - if testing.Verbose() { - wep = sniffer.New(ep) - } - if err := s.CreateNIC(1, wep); err != nil { - t.Fatalf("CreateNIC failed: %s", err) - } - - if err := s.AddAddress(1, ipv4.ProtocolNumber, stackAddr); err != nil { - t.Fatalf("AddAddress failed: %s", err) - } - - if err := s.AddAddress(1, ipv6.ProtocolNumber, stackV6Addr); err != nil { - t.Fatalf("AddAddress failed: %s", err) - } - - s.SetRouteTable([]tcpip.Route{ - { - Destination: header.IPv4EmptySubnet, - NIC: 1, - }, - { - Destination: header.IPv6EmptySubnet, - NIC: 1, - }, - }) - - return &testContext{ - t: t, - s: s, - linkEP: ep, - } -} - -func (c *testContext) cleanup() { - if c.ep != nil { - c.ep.Close() - } -} - -func (c *testContext) createEndpoint(proto tcpip.NetworkProtocolNumber) { - c.t.Helper() - - var err tcpip.Error - c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, proto, &c.wq) - if err != nil { - c.t.Fatal("NewEndpoint failed: ", err) - } -} - -func (c *testContext) createEndpointForFlow(flow testFlow) { - c.t.Helper() - - c.createEndpoint(flow.sockProto()) - if flow.isV6Only() { - c.ep.SocketOptions().SetV6Only(true) - } else if flow.isBroadcast() { - c.ep.SocketOptions().SetBroadcast(true) - } -} - -// getPacketAndVerify reads a packet from the link endpoint and verifies the -// header against expected values from the given test flow. In addition, it -// calls any extra checker functions provided. -func (c *testContext) getPacketAndVerify(flow testFlow, checkers ...checker.NetworkChecker) []byte { - c.t.Helper() - - p, ok := c.linkEP.Read() - if !ok { - c.t.Fatalf("Packet wasn't written out") - return nil - } - - if p.Proto != flow.netProto() { - c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, flow.netProto()) - } - - if got, want := p.Pkt.TransportProtocolNumber, header.UDPProtocolNumber; got != want { - c.t.Errorf("got p.Pkt.TransportProtocolNumber = %d, want = %d", got, want) - } - - vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views()) - b := vv.ToView() - - h := flow.header4Tuple(outgoing) - checkers = append( - checkers, - checker.SrcAddr(h.srcAddr.Addr), - checker.DstAddr(h.dstAddr.Addr), - checker.UDP(checker.DstPort(h.dstAddr.Port)), - ) - flow.checkerFn()(c.t, b, checkers...) - return b -} - -// injectPacket creates a packet of the given flow and with the given payload, -// and injects it into the link endpoint. If badChecksum is true, the packet has -// a bad checksum in the UDP header. -func (c *testContext) injectPacket(flow testFlow, payload []byte, badChecksum bool) { - c.t.Helper() - - h := flow.header4Tuple(incoming) - if flow.isV4() { - buf := c.buildV4Packet(payload, &h) - if badChecksum { - // Invalidate the UDP header checksum field, taking care to avoid - // overflow to zero, which would disable checksum validation. - for u := header.UDP(buf[header.IPv4MinimumSize:]); ; { - u.SetChecksum(u.Checksum() + 1) - if u.Checksum() != 0 { - break - } - } - } - c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - } else { - buf := c.buildV6Packet(payload, &h) - if badChecksum { - // Invalidate the UDP header checksum field (Unlike IPv4, zero is - // a valid checksum value for IPv6 so no need to avoid it). - u := header.UDP(buf[header.IPv6MinimumSize:]) - u.SetChecksum(u.Checksum() + 1) - } - c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - } -} - -// buildV6Packet creates a V6 test packet with the given payload and header -// values in a buffer. -func (c *testContext) buildV6Packet(payload []byte, h *header4Tuple) buffer.View { - // Allocate a buffer for data and headers. - buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload)) - payloadStart := len(buf) - len(payload) - copy(buf[payloadStart:], payload) - - // Initialize the IP header. - ip := header.IPv6(buf) - ip.Encode(&header.IPv6Fields{ - TrafficClass: testTOS, - PayloadLength: uint16(header.UDPMinimumSize + len(payload)), - TransportProtocol: udp.ProtocolNumber, - HopLimit: 65, - SrcAddr: h.srcAddr.Addr, - DstAddr: h.dstAddr.Addr, - }) - - // Initialize the UDP header. - u := header.UDP(buf[header.IPv6MinimumSize:]) - u.Encode(&header.UDPFields{ - SrcPort: h.srcAddr.Port, - DstPort: h.dstAddr.Port, - Length: uint16(header.UDPMinimumSize + len(payload)), - }) - - // Calculate the UDP pseudo-header checksum. - xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, h.srcAddr.Addr, h.dstAddr.Addr, uint16(len(u))) - - // Calculate the UDP checksum and set it. - xsum = header.Checksum(payload, xsum) - u.SetChecksum(^u.CalculateChecksum(xsum)) - - return buf -} - -// buildV4Packet creates a V4 test packet with the given payload and header -// values in a buffer. -func (c *testContext) buildV4Packet(payload []byte, h *header4Tuple) buffer.View { - // Allocate a buffer for data and headers. - buf := buffer.NewView(header.UDPMinimumSize + header.IPv4MinimumSize + len(payload)) - payloadStart := len(buf) - len(payload) - copy(buf[payloadStart:], payload) - - // Initialize the IP header. - ip := header.IPv4(buf) - ip.Encode(&header.IPv4Fields{ - TOS: testTOS, - TotalLength: uint16(len(buf)), - TTL: 65, - Protocol: uint8(udp.ProtocolNumber), - SrcAddr: h.srcAddr.Addr, - DstAddr: h.dstAddr.Addr, - }) - ip.SetChecksum(^ip.CalculateChecksum()) - - // Initialize the UDP header. - u := header.UDP(buf[header.IPv4MinimumSize:]) - u.Encode(&header.UDPFields{ - SrcPort: h.srcAddr.Port, - DstPort: h.dstAddr.Port, - Length: uint16(header.UDPMinimumSize + len(payload)), - }) - - // Calculate the UDP pseudo-header checksum. - xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, h.srcAddr.Addr, h.dstAddr.Addr, uint16(len(u))) - - // Calculate the UDP checksum and set it. - xsum = header.Checksum(payload, xsum) - u.SetChecksum(^u.CalculateChecksum(xsum)) - - return buf -} - -func newPayload() []byte { - return newMinPayload(30) -} - -func newMinPayload(minSize int) []byte { - b := make([]byte, minSize+rand.Intn(100)) - for i := range b { - b[i] = byte(rand.Intn(256)) - } - return b -} - -func TestBindToDeviceOption(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, - Clock: &faketime.NullClock{}, - }) - - ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) - if err != nil { - t.Fatalf("NewEndpoint failed; %s", err) - } - defer ep.Close() - - opts := stack.NICOptions{Name: "my_device"} - if err := s.CreateNICWithOptions(321, loopback.New(), opts); err != nil { - t.Errorf("CreateNICWithOptions(_, _, %+v) failed: %s", opts, err) - } - - // nicIDPtr is used instead of taking the address of NICID literals, which is - // a compiler error. - nicIDPtr := func(s tcpip.NICID) *tcpip.NICID { - return &s - } - - testActions := []struct { - name string - setBindToDevice *tcpip.NICID - setBindToDeviceError tcpip.Error - getBindToDevice int32 - }{ - {"GetDefaultValue", nil, nil, 0}, - {"BindToNonExistent", nicIDPtr(999), &tcpip.ErrUnknownDevice{}, 0}, - {"BindToExistent", nicIDPtr(321), nil, 321}, - {"UnbindToDevice", nicIDPtr(0), nil, 0}, - } - for _, testAction := range testActions { - t.Run(testAction.name, func(t *testing.T) { - if testAction.setBindToDevice != nil { - bindToDevice := int32(*testAction.setBindToDevice) - if gotErr, wantErr := ep.SocketOptions().SetBindToDevice(bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr { - t.Errorf("got SetSockOpt(&%T(%d)) = %s, want = %s", bindToDevice, bindToDevice, gotErr, wantErr) - } - } - bindToDevice := ep.SocketOptions().GetBindToDevice() - if bindToDevice != testAction.getBindToDevice { - t.Errorf("got bindToDevice = %d, want = %d", bindToDevice, testAction.getBindToDevice) - } - }) - } -} - -// testReadInternal sends a packet of the given test flow into the stack by -// injecting it into the link endpoint. It then attempts to read it from the -// UDP endpoint and depending on if this was expected to succeed verifies its -// correctness including any additional checker functions provided. -func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expectReadError bool, checkers ...checker.ControlMessagesChecker) { - c.t.Helper() - - payload := newPayload() - c.injectPacket(flow, payload, false) - - // Try to receive the data. - we, ch := waiter.NewChannelEntry(nil) - c.wq.EventRegister(&we, waiter.ReadableEvents) - defer c.wq.EventUnregister(&we) - - // Take a snapshot of the stats to validate them at the end of the test. - epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone() - - var buf bytes.Buffer - res, err := c.ep.Read(&buf, tcpip.ReadOptions{NeedRemoteAddr: true}) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { - // Wait for data to become available. - select { - case <-ch: - res, err = c.ep.Read(&buf, tcpip.ReadOptions{NeedRemoteAddr: true}) - - default: - if packetShouldBeDropped { - return // expected to time out - } - c.t.Fatal("timed out waiting for data") - } - } - - if expectReadError && err != nil { - c.checkEndpointReadStats(1, epstats, err) - return - } - - if err != nil { - c.t.Fatal("Read failed:", err) - } - - if packetShouldBeDropped { - c.t.Fatalf("Read unexpectedly received data from %s", res.RemoteAddr.Addr) - } - - // Check the read result. - h := flow.header4Tuple(incoming) - if diff := cmp.Diff(tcpip.ReadResult{ - Count: buf.Len(), - Total: buf.Len(), - RemoteAddr: tcpip.FullAddress{Addr: h.srcAddr.Addr}, - }, res, checker.IgnoreCmpPath( - "ControlMessages", // ControlMessages will be checked later. - "RemoteAddr.NIC", - "RemoteAddr.Port", - )); diff != "" { - c.t.Fatalf("Read: unexpected result (-want +got):\n%s", diff) - } - - // Check the payload. - v := buf.Bytes() - if !bytes.Equal(payload, v) { - c.t.Fatalf("got payload = %x, want = %x", v, payload) - } - - // Run any checkers against the ControlMessages. - for _, f := range checkers { - f(c.t, res.ControlMessages) - } - - c.checkEndpointReadStats(1, epstats, err) -} - -// testRead sends a packet of the given test flow into the stack by injecting it -// into the link endpoint. It then reads it from the UDP endpoint and verifies -// its correctness including any additional checker functions provided. -func testRead(c *testContext, flow testFlow, checkers ...checker.ControlMessagesChecker) { - c.t.Helper() - testReadInternal(c, flow, false /* packetShouldBeDropped */, false /* expectReadError */, checkers...) -} - -// testFailingRead sends a packet of the given test flow into the stack by -// injecting it into the link endpoint. It then tries to read it from the UDP -// endpoint and expects this to fail. -func testFailingRead(c *testContext, flow testFlow, expectReadError bool) { - c.t.Helper() - testReadInternal(c, flow, true /* packetShouldBeDropped */, expectReadError) -} - -func TestBindEphemeralPort(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(ipv6.ProtocolNumber) - - if err := c.ep.Bind(tcpip.FullAddress{}); err != nil { - t.Fatalf("ep.Bind(...) failed: %s", err) - } -} - -func TestBindReservedPort(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(ipv6.ProtocolNumber) - - if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil { - c.t.Fatalf("Connect failed: %s", err) - } - - addr, err := c.ep.GetLocalAddress() - if err != nil { - t.Fatalf("GetLocalAddress failed: %s", err) - } - - // We can't bind the address reserved by the connected endpoint above. - { - ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &c.wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - defer ep.Close() - { - err := ep.Bind(addr) - if _, ok := err.(*tcpip.ErrPortInUse); !ok { - t.Fatalf("got ep.Bind(...) = %s, want = %s", err, &tcpip.ErrPortInUse{}) - } - } - } - - func() { - ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &c.wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - defer ep.Close() - // We can't bind ipv4-any on the port reserved by the connected endpoint - // above, since the endpoint is dual-stack. - { - err := ep.Bind(tcpip.FullAddress{Port: addr.Port}) - if _, ok := err.(*tcpip.ErrPortInUse); !ok { - t.Fatalf("got ep.Bind(...) = %s, want = %s", err, &tcpip.ErrPortInUse{}) - } - } - // We can bind an ipv4 address on this port, though. - if err := ep.Bind(tcpip.FullAddress{Addr: stackAddr, Port: addr.Port}); err != nil { - t.Fatalf("ep.Bind(...) failed: %s", err) - } - }() - - // Once the connected endpoint releases its port reservation, we are able to - // bind ipv4-any once again. - c.ep.Close() - func() { - ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &c.wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - defer ep.Close() - if err := ep.Bind(tcpip.FullAddress{Port: addr.Port}); err != nil { - t.Fatalf("ep.Bind(...) failed: %s", err) - } - }() -} - -func TestV4ReadOnV6(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(unicastV4in6) - - // Bind to wildcard. - if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %s", err) - } - - // Test acceptance. - testRead(c, unicastV4in6) -} - -func TestV4ReadOnBoundToV4MappedWildcard(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(unicastV4in6) - - // Bind to v4 mapped wildcard. - if err := c.ep.Bind(tcpip.FullAddress{Addr: v4MappedWildcardAddr, Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %s", err) - } - - // Test acceptance. - testRead(c, unicastV4in6) -} - -func TestV4ReadOnBoundToV4Mapped(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(unicastV4in6) - - // Bind to local address. - if err := c.ep.Bind(tcpip.FullAddress{Addr: stackV4MappedAddr, Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %s", err) - } - - // Test acceptance. - testRead(c, unicastV4in6) -} - -func TestV6ReadOnV6(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(unicastV6) - - // Bind to wildcard. - if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %s", err) - } - - // Test acceptance. - testRead(c, unicastV6) -} - -// TestV4ReadSelfSource checks that packets coming from a local IP address are -// correctly dropped when handleLocal is true and not otherwise. -func TestV4ReadSelfSource(t *testing.T) { - for _, tt := range []struct { - name string - handleLocal bool - wantErr tcpip.Error - wantInvalidSource uint64 - }{ - {"HandleLocal", false, nil, 0}, - {"NoHandleLocal", true, &tcpip.ErrWouldBlock{}, 1}, - } { - t.Run(tt.name, func(t *testing.T) { - c := newDualTestContextWithHandleLocal(t, defaultMTU, tt.handleLocal) - defer c.cleanup() - - c.createEndpointForFlow(unicastV4) - - if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - - payload := newPayload() - h := unicastV4.header4Tuple(incoming) - h.srcAddr = h.dstAddr - - buf := c.buildV4Packet(payload, &h) - c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - - if got := c.s.Stats().IP.InvalidSourceAddressesReceived.Value(); got != tt.wantInvalidSource { - t.Errorf("c.s.Stats().IP.InvalidSourceAddressesReceived got %d, want %d", got, tt.wantInvalidSource) - } - - if _, err := c.ep.Read(ioutil.Discard, tcpip.ReadOptions{}); err != tt.wantErr { - t.Errorf("got c.ep.Read = %s, want = %s", err, tt.wantErr) - } - }) - } -} - -func TestV4ReadOnV4(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(unicastV4) - - // Bind to wildcard. - if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %s", err) - } - - // Test acceptance. - testRead(c, unicastV4) -} - -// TestReadOnBoundToMulticast checks that an endpoint can bind to a multicast -// address and receive data sent to that address. -func TestReadOnBoundToMulticast(t *testing.T) { - // FIXME(b/128189410): multicastV4in6 currently doesn't work as - // AddMembershipOption doesn't handle V4in6 addresses. - for _, flow := range []testFlow{multicastV4, multicastV6, multicastV6Only} { - t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(flow) - - // Bind to multicast address. - mcastAddr := flow.mapAddrIfApplicable(flow.getMcastAddr()) - if err := c.ep.Bind(tcpip.FullAddress{Addr: mcastAddr, Port: stackPort}); err != nil { - c.t.Fatal("Bind failed:", err) - } - - // Join multicast group. - ifoptSet := tcpip.AddMembershipOption{NIC: 1, MulticastAddr: mcastAddr} - if err := c.ep.SetSockOpt(&ifoptSet); err != nil { - c.t.Fatalf("SetSockOpt(&%#v): %s", ifoptSet, err) - } - - // Check that we receive multicast packets but not unicast or broadcast - // ones. - testRead(c, flow) - testFailingRead(c, broadcast, false /* expectReadError */) - testFailingRead(c, unicastV4, false /* expectReadError */) - }) - } -} - -// TestV4ReadOnBoundToBroadcast checks that an endpoint can bind to a broadcast -// address and can receive only broadcast data. -func TestV4ReadOnBoundToBroadcast(t *testing.T) { - for _, flow := range []testFlow{broadcast, broadcastIn6} { - t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(flow) - - // Bind to broadcast address. - bcastAddr := flow.mapAddrIfApplicable(broadcastAddr) - if err := c.ep.Bind(tcpip.FullAddress{Addr: bcastAddr, Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %s", err) - } - - // Check that we receive broadcast packets but not unicast ones. - testRead(c, flow) - testFailingRead(c, unicastV4, false /* expectReadError */) - }) - } -} - -// TestReadFromMulticast checks that an endpoint will NOT receive a packet -// that was sent with multicast SOURCE address. -func TestReadFromMulticast(t *testing.T) { - for _, flow := range []testFlow{reverseMulticast4, reverseMulticast6} { - t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(flow) - - if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - testFailingRead(c, flow, false /* expectReadError */) - }) - } -} - -// TestV4ReadBroadcastOnBoundToWildcard checks that an endpoint can bind to ANY -// and receive broadcast and unicast data. -func TestV4ReadBroadcastOnBoundToWildcard(t *testing.T) { - for _, flow := range []testFlow{broadcast, broadcastIn6} { - t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(flow) - - // Bind to wildcard. - if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %s (", err) - } - - // Check that we receive both broadcast and unicast packets. - testRead(c, flow) - testRead(c, unicastV4) - }) - } -} - -// testFailingWrite sends a packet of the given test flow into the UDP endpoint -// and verifies it fails with the provided error code. -func testFailingWrite(c *testContext, flow testFlow, wantErr tcpip.Error) { - c.t.Helper() - // Take a snapshot of the stats to validate them at the end of the test. - epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone() - h := flow.header4Tuple(outgoing) - writeDstAddr := flow.mapAddrIfApplicable(h.dstAddr.Addr) - - var r bytes.Reader - r.Reset(newPayload()) - _, gotErr := c.ep.Write(&r, tcpip.WriteOptions{ - To: &tcpip.FullAddress{Addr: writeDstAddr, Port: h.dstAddr.Port}, - }) - c.checkEndpointWriteStats(1, epstats, gotErr) - if gotErr != wantErr { - c.t.Fatalf("Write returned unexpected error: got %v, want %v", gotErr, wantErr) - } -} - -// testWrite sends a packet of the given test flow from the UDP endpoint to the -// flow's destination address:port. It then receives it from the link endpoint -// and verifies its correctness including any additional checker functions -// provided. -func testWrite(c *testContext, flow testFlow, checkers ...checker.NetworkChecker) uint16 { - c.t.Helper() - return testWriteAndVerifyInternal(c, flow, true, checkers...) -} - -// testWriteWithoutDestination sends a packet of the given test flow from the -// UDP endpoint without giving a destination address:port. It then receives it -// from the link endpoint and verifies its correctness including any additional -// checker functions provided. -func testWriteWithoutDestination(c *testContext, flow testFlow, checkers ...checker.NetworkChecker) uint16 { - c.t.Helper() - return testWriteAndVerifyInternal(c, flow, false, checkers...) -} - -func testWriteNoVerify(c *testContext, flow testFlow, setDest bool) buffer.View { - c.t.Helper() - // Take a snapshot of the stats to validate them at the end of the test. - epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone() - - writeOpts := tcpip.WriteOptions{} - if setDest { - h := flow.header4Tuple(outgoing) - writeDstAddr := flow.mapAddrIfApplicable(h.dstAddr.Addr) - writeOpts = tcpip.WriteOptions{ - To: &tcpip.FullAddress{Addr: writeDstAddr, Port: h.dstAddr.Port}, - } - } - var r bytes.Reader - payload := newPayload() - r.Reset(payload) - n, err := c.ep.Write(&r, writeOpts) - if err != nil { - c.t.Fatalf("Write failed: %s", err) - } - if n != int64(len(payload)) { - c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload)) - } - c.checkEndpointWriteStats(1, epstats, err) - return payload -} - -func testWriteAndVerifyInternal(c *testContext, flow testFlow, setDest bool, checkers ...checker.NetworkChecker) uint16 { - c.t.Helper() - payload := testWriteNoVerify(c, flow, setDest) - // Received the packet and check the payload. - b := c.getPacketAndVerify(flow, checkers...) - var udpH header.UDP - if flow.isV4() { - udpH = header.IPv4(b).Payload() - } else { - udpH = header.IPv6(b).Payload() - } - if !bytes.Equal(payload, udpH.Payload()) { - c.t.Fatalf("Bad payload: got %x, want %x", udpH.Payload(), payload) - } - - return udpH.SourcePort() -} - -func testDualWrite(c *testContext) uint16 { - c.t.Helper() - - v4Port := testWrite(c, unicastV4in6) - v6Port := testWrite(c, unicastV6) - if v4Port != v6Port { - c.t.Fatalf("expected v4 and v6 ports to be equal: got v4Port = %d, v6Port = %d", v4Port, v6Port) - } - - return v4Port -} - -func TestDualWriteUnbound(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(ipv6.ProtocolNumber) - - testDualWrite(c) -} - -func TestDualWriteBoundToWildcard(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(ipv6.ProtocolNumber) - - // Bind to wildcard. - if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %s", err) - } - - p := testDualWrite(c) - if p != stackPort { - c.t.Fatalf("Bad port: got %v, want %v", p, stackPort) - } -} - -func TestDualWriteConnectedToV6(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(ipv6.ProtocolNumber) - - // Connect to v6 address. - if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil { - c.t.Fatalf("Bind failed: %s", err) - } - - testWrite(c, unicastV6) - - // Write to V4 mapped address. - testFailingWrite(c, unicastV4in6, &tcpip.ErrNetworkUnreachable{}) - const want = 1 - if got := c.ep.Stats().(*tcpip.TransportEndpointStats).SendErrors.NoRoute.Value(); got != want { - c.t.Fatalf("Endpoint stat not updated. got %d want %d", got, want) - } -} - -func TestDualWriteConnectedToV4Mapped(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(ipv6.ProtocolNumber) - - // Connect to v4 mapped address. - if err := c.ep.Connect(tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}); err != nil { - c.t.Fatalf("Bind failed: %s", err) - } - - testWrite(c, unicastV4in6) - - // Write to v6 address. - testFailingWrite(c, unicastV6, &tcpip.ErrInvalidEndpointState{}) -} - -func TestV4WriteOnV6Only(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(unicastV6Only) - - // Write to V4 mapped address. - testFailingWrite(c, unicastV4in6, &tcpip.ErrNoRoute{}) -} - -func TestV6WriteOnBoundToV4Mapped(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(ipv6.ProtocolNumber) - - // Bind to v4 mapped address. - if err := c.ep.Bind(tcpip.FullAddress{Addr: stackV4MappedAddr, Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %s", err) - } - - // Write to v6 address. - testFailingWrite(c, unicastV6, &tcpip.ErrInvalidEndpointState{}) -} - -func TestV6WriteOnConnected(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(ipv6.ProtocolNumber) - - // Connect to v6 address. - if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil { - c.t.Fatalf("Connect failed: %s", err) - } - - testWriteWithoutDestination(c, unicastV6) -} - -func TestV4WriteOnConnected(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(ipv6.ProtocolNumber) - - // Connect to v4 mapped address. - if err := c.ep.Connect(tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}); err != nil { - c.t.Fatalf("Connect failed: %s", err) - } - - testWriteWithoutDestination(c, unicastV4) -} - -func TestWriteOnConnectedInvalidPort(t *testing.T) { - protocols := map[string]tcpip.NetworkProtocolNumber{ - "ipv4": ipv4.ProtocolNumber, - "ipv6": ipv6.ProtocolNumber, - } - for name, pn := range protocols { - t.Run(name, func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(pn) - if err := c.ep.Connect(tcpip.FullAddress{Addr: stackAddr, Port: invalidPort}); err != nil { - c.t.Fatalf("Connect failed: %s", err) - } - writeOpts := tcpip.WriteOptions{ - To: &tcpip.FullAddress{Addr: stackAddr, Port: invalidPort}, - } - var r bytes.Reader - payload := newPayload() - r.Reset(payload) - n, err := c.ep.Write(&r, writeOpts) - if err != nil { - c.t.Fatalf("c.ep.Write(...) = %s, want nil", err) - } - if got, want := n, int64(len(payload)); got != want { - c.t.Fatalf("c.ep.Write(...) wrote %d bytes, want %d bytes", got, want) - } - - { - err := c.ep.LastError() - if _, ok := err.(*tcpip.ErrConnectionRefused); !ok { - c.t.Fatalf("expected c.ep.LastError() == ErrConnectionRefused, got: %+v", err) - } - } - }) - } -} - -// TestWriteOnBoundToV4Multicast checks that we can send packets out of a socket -// that is bound to a V4 multicast address. -func TestWriteOnBoundToV4Multicast(t *testing.T) { - for _, flow := range []testFlow{unicastV4, multicastV4, broadcast} { - t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(flow) - - // Bind to V4 mcast address. - if err := c.ep.Bind(tcpip.FullAddress{Addr: multicastAddr, Port: stackPort}); err != nil { - c.t.Fatal("Bind failed:", err) - } - - testWrite(c, flow) - }) - } -} - -// TestWriteOnBoundToV4MappedMulticast checks that we can send packets out of a -// socket that is bound to a V4-mapped multicast address. -func TestWriteOnBoundToV4MappedMulticast(t *testing.T) { - for _, flow := range []testFlow{unicastV4in6, multicastV4in6, broadcastIn6} { - t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(flow) - - // Bind to V4Mapped mcast address. - if err := c.ep.Bind(tcpip.FullAddress{Addr: multicastV4MappedAddr, Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %s", err) - } - - testWrite(c, flow) - }) - } -} - -// TestWriteOnBoundToV6Multicast checks that we can send packets out of a -// socket that is bound to a V6 multicast address. -func TestWriteOnBoundToV6Multicast(t *testing.T) { - for _, flow := range []testFlow{unicastV6, multicastV6} { - t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(flow) - - // Bind to V6 mcast address. - if err := c.ep.Bind(tcpip.FullAddress{Addr: multicastV6Addr, Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %s", err) - } - - testWrite(c, flow) - }) - } -} - -// TestWriteOnBoundToV6Multicast checks that we can send packets out of a -// V6-only socket that is bound to a V6 multicast address. -func TestWriteOnBoundToV6OnlyMulticast(t *testing.T) { - for _, flow := range []testFlow{unicastV6Only, multicastV6Only} { - t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(flow) - - // Bind to V6 mcast address. - if err := c.ep.Bind(tcpip.FullAddress{Addr: multicastV6Addr, Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %s", err) - } - - testWrite(c, flow) - }) - } -} - -// TestWriteOnBoundToBroadcast checks that we can send packets out of a -// socket that is bound to the broadcast address. -func TestWriteOnBoundToBroadcast(t *testing.T) { - for _, flow := range []testFlow{unicastV4, multicastV4, broadcast} { - t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(flow) - - // Bind to V4 broadcast address. - if err := c.ep.Bind(tcpip.FullAddress{Addr: broadcastAddr, Port: stackPort}); err != nil { - c.t.Fatal("Bind failed:", err) - } - - testWrite(c, flow) - }) - } -} - -// TestWriteOnBoundToV4MappedBroadcast checks that we can send packets out of a -// socket that is bound to the V4-mapped broadcast address. -func TestWriteOnBoundToV4MappedBroadcast(t *testing.T) { - for _, flow := range []testFlow{unicastV4in6, multicastV4in6, broadcastIn6} { - t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(flow) - - // Bind to V4Mapped mcast address. - if err := c.ep.Bind(tcpip.FullAddress{Addr: broadcastV4MappedAddr, Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %s", err) - } - - testWrite(c, flow) - }) - } -} - -func TestReadIncrementsPacketsReceived(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - // Create IPv4 UDP endpoint - c.createEndpoint(ipv6.ProtocolNumber) - - // Bind to wildcard. - if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %s", err) - } - - testRead(c, unicastV4) - - var want uint64 = 1 - if got := c.s.Stats().UDP.PacketsReceived.Value(); got != want { - c.t.Fatalf("Read did not increment PacketsReceived: got %v, want %v", got, want) - } -} - -func TestReadIPPacketInfo(t *testing.T) { - tests := []struct { - name string - proto tcpip.NetworkProtocolNumber - flow testFlow - expectedLocalAddr tcpip.Address - expectedDestAddr tcpip.Address - }{ - { - name: "IPv4 unicast", - proto: header.IPv4ProtocolNumber, - flow: unicastV4, - expectedLocalAddr: stackAddr, - expectedDestAddr: stackAddr, - }, - { - name: "IPv4 multicast", - proto: header.IPv4ProtocolNumber, - flow: multicastV4, - // This should actually be a unicast address assigned to the interface. - // - // TODO(gvisor.dev/issue/3556): This check is validating incorrect - // behaviour. We still include the test so that once the bug is - // resolved, this test will start to fail and the individual tasked - // with fixing this bug knows to also fix this test :). - expectedLocalAddr: multicastAddr, - expectedDestAddr: multicastAddr, - }, - { - name: "IPv4 broadcast", - proto: header.IPv4ProtocolNumber, - flow: broadcast, - // This should actually be a unicast address assigned to the interface. - // - // TODO(gvisor.dev/issue/3556): This check is validating incorrect - // behaviour. We still include the test so that once the bug is - // resolved, this test will start to fail and the individual tasked - // with fixing this bug knows to also fix this test :). - expectedLocalAddr: broadcastAddr, - expectedDestAddr: broadcastAddr, - }, - { - name: "IPv6 unicast", - proto: header.IPv6ProtocolNumber, - flow: unicastV6, - expectedLocalAddr: stackV6Addr, - expectedDestAddr: stackV6Addr, - }, - { - name: "IPv6 multicast", - proto: header.IPv6ProtocolNumber, - flow: multicastV6, - // This should actually be a unicast address assigned to the interface. - // - // TODO(gvisor.dev/issue/3556): This check is validating incorrect - // behaviour. We still include the test so that once the bug is - // resolved, this test will start to fail and the individual tasked - // with fixing this bug knows to also fix this test :). - expectedLocalAddr: multicastV6Addr, - expectedDestAddr: multicastV6Addr, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(test.proto) - - bindAddr := tcpip.FullAddress{Port: stackPort} - if err := c.ep.Bind(bindAddr); err != nil { - t.Fatalf("Bind(%+v): %s", bindAddr, err) - } - - if test.flow.isMulticast() { - ifoptSet := tcpip.AddMembershipOption{NIC: 1, MulticastAddr: test.flow.getMcastAddr()} - if err := c.ep.SetSockOpt(&ifoptSet); err != nil { - c.t.Fatalf("SetSockOpt(&%#v): %s:", ifoptSet, err) - } - } - - c.ep.SocketOptions().SetReceivePacketInfo(true) - - testRead(c, test.flow, checker.ReceiveIPPacketInfo(tcpip.IPPacketInfo{ - NIC: 1, - LocalAddr: test.expectedLocalAddr, - DestinationAddr: test.expectedDestAddr, - })) - - if got := c.s.Stats().UDP.PacketsReceived.Value(); got != 1 { - t.Fatalf("Read did not increment PacketsReceived: got = %d, want = 1", got) - } - }) - } -} - -func TestReadRecvOriginalDstAddr(t *testing.T) { - tests := []struct { - name string - proto tcpip.NetworkProtocolNumber - flow testFlow - expectedOriginalDstAddr tcpip.FullAddress - }{ - { - name: "IPv4 unicast", - proto: header.IPv4ProtocolNumber, - flow: unicastV4, - expectedOriginalDstAddr: tcpip.FullAddress{NIC: 1, Addr: stackAddr, Port: stackPort}, - }, - { - name: "IPv4 multicast", - proto: header.IPv4ProtocolNumber, - flow: multicastV4, - // This should actually be a unicast address assigned to the interface. - // - // TODO(gvisor.dev/issue/3556): This check is validating incorrect - // behaviour. We still include the test so that once the bug is - // resolved, this test will start to fail and the individual tasked - // with fixing this bug knows to also fix this test :). - expectedOriginalDstAddr: tcpip.FullAddress{NIC: 1, Addr: multicastAddr, Port: stackPort}, - }, - { - name: "IPv4 broadcast", - proto: header.IPv4ProtocolNumber, - flow: broadcast, - // This should actually be a unicast address assigned to the interface. - // - // TODO(gvisor.dev/issue/3556): This check is validating incorrect - // behaviour. We still include the test so that once the bug is - // resolved, this test will start to fail and the individual tasked - // with fixing this bug knows to also fix this test :). - expectedOriginalDstAddr: tcpip.FullAddress{NIC: 1, Addr: broadcastAddr, Port: stackPort}, - }, - { - name: "IPv6 unicast", - proto: header.IPv6ProtocolNumber, - flow: unicastV6, - expectedOriginalDstAddr: tcpip.FullAddress{NIC: 1, Addr: stackV6Addr, Port: stackPort}, - }, - { - name: "IPv6 multicast", - proto: header.IPv6ProtocolNumber, - flow: multicastV6, - // This should actually be a unicast address assigned to the interface. - // - // TODO(gvisor.dev/issue/3556): This check is validating incorrect - // behaviour. We still include the test so that once the bug is - // resolved, this test will start to fail and the individual tasked - // with fixing this bug knows to also fix this test :). - expectedOriginalDstAddr: tcpip.FullAddress{NIC: 1, Addr: multicastV6Addr, Port: stackPort}, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(test.proto) - - bindAddr := tcpip.FullAddress{Port: stackPort} - if err := c.ep.Bind(bindAddr); err != nil { - t.Fatalf("Bind(%#v): %s", bindAddr, err) - } - - if test.flow.isMulticast() { - ifoptSet := tcpip.AddMembershipOption{NIC: 1, MulticastAddr: test.flow.getMcastAddr()} - if err := c.ep.SetSockOpt(&ifoptSet); err != nil { - c.t.Fatalf("SetSockOpt(&%#v): %s:", ifoptSet, err) - } - } - - c.ep.SocketOptions().SetReceiveOriginalDstAddress(true) - - testRead(c, test.flow, checker.ReceiveOriginalDstAddr(test.expectedOriginalDstAddr)) - - if got := c.s.Stats().UDP.PacketsReceived.Value(); got != 1 { - t.Fatalf("Read did not increment PacketsReceived: got = %d, want = 1", got) - } - }) - } -} - -func TestWriteIncrementsPacketsSent(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(ipv6.ProtocolNumber) - - testDualWrite(c) - - var want uint64 = 2 - if got := c.s.Stats().UDP.PacketsSent.Value(); got != want { - c.t.Fatalf("Write did not increment PacketsSent: got %v, want %v", got, want) - } -} - -func TestNoChecksum(t *testing.T) { - for _, flow := range []testFlow{unicastV4, unicastV6} { - t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(flow) - - // Disable the checksum generation. - c.ep.SocketOptions().SetNoChecksum(true) - // This option is effective on IPv4 only. - testWrite(c, flow, checker.UDP(checker.NoChecksum(flow.isV4()))) - - // Enable the checksum generation. - c.ep.SocketOptions().SetNoChecksum(false) - testWrite(c, flow, checker.UDP(checker.NoChecksum(false))) - }) - } -} - -var _ stack.NetworkInterface = (*testInterface)(nil) - -type testInterface struct { - stack.NetworkInterface -} - -func (*testInterface) ID() tcpip.NICID { - return 0 -} - -func (*testInterface) Enabled() bool { - return true -} - -func TestTTL(t *testing.T) { - for _, flow := range []testFlow{unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6} { - t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(flow) - - const multicastTTL = 42 - if err := c.ep.SetSockOptInt(tcpip.MulticastTTLOption, multicastTTL); err != nil { - c.t.Fatalf("SetSockOptInt failed: %s", err) - } - - var wantTTL uint8 - if flow.isMulticast() { - wantTTL = multicastTTL - } else { - var p stack.NetworkProtocolFactory - var n tcpip.NetworkProtocolNumber - if flow.isV4() { - p = ipv4.NewProtocol - n = ipv4.ProtocolNumber - } else { - p = ipv6.NewProtocol - n = ipv6.ProtocolNumber - } - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{p}, - Clock: &faketime.NullClock{}, - }) - ep := s.NetworkProtocolInstance(n).NewEndpoint(&testInterface{}, nil) - wantTTL = ep.DefaultTTL() - ep.Close() - } - - testWrite(c, flow, checker.TTL(wantTTL)) - }) - } -} - -func TestSetTTL(t *testing.T) { - for _, flow := range []testFlow{unicastV4, unicastV4in6, unicastV6, unicastV6Only, broadcast, broadcastIn6} { - t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { - for _, wantTTL := range []uint8{1, 2, 50, 64, 128, 254, 255} { - t.Run(fmt.Sprintf("TTL:%d", wantTTL), func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(flow) - - if err := c.ep.SetSockOptInt(tcpip.TTLOption, int(wantTTL)); err != nil { - c.t.Fatalf("SetSockOptInt(TTLOption, %d) failed: %s", wantTTL, err) - } - - testWrite(c, flow, checker.TTL(wantTTL)) - }) - } - }) - } -} - -func TestSetTOS(t *testing.T) { - for _, flow := range []testFlow{unicastV4, multicastV4, broadcast} { - t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(flow) - - const tos = testTOS - v, err := c.ep.GetSockOptInt(tcpip.IPv4TOSOption) - if err != nil { - c.t.Errorf("GetSockOptInt(IPv4TOSOption) failed: %s", err) - } - // Test for expected default value. - if v != 0 { - c.t.Errorf("got GetSockOptInt(IPv4TOSOption) = 0x%x, want = 0x%x", v, 0) - } - - if err := c.ep.SetSockOptInt(tcpip.IPv4TOSOption, tos); err != nil { - c.t.Errorf("SetSockOptInt(IPv4TOSOption, 0x%x) failed: %s", tos, err) - } - - v, err = c.ep.GetSockOptInt(tcpip.IPv4TOSOption) - if err != nil { - c.t.Errorf("GetSockOptInt(IPv4TOSOption) failed: %s", err) - } - - if v != tos { - c.t.Errorf("got GetSockOptInt(IPv4TOSOption) = 0x%x, want = 0x%x", v, tos) - } - - testWrite(c, flow, checker.TOS(tos, 0)) - }) - } -} - -func TestSetTClass(t *testing.T) { - for _, flow := range []testFlow{unicastV4in6, unicastV6, unicastV6Only, multicastV4in6, multicastV6, broadcastIn6} { - t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(flow) - - const tClass = testTOS - v, err := c.ep.GetSockOptInt(tcpip.IPv6TrafficClassOption) - if err != nil { - c.t.Errorf("GetSockOptInt(IPv6TrafficClassOption) failed: %s", err) - } - // Test for expected default value. - if v != 0 { - c.t.Errorf("got GetSockOptInt(IPv6TrafficClassOption) = 0x%x, want = 0x%x", v, 0) - } - - if err := c.ep.SetSockOptInt(tcpip.IPv6TrafficClassOption, tClass); err != nil { - c.t.Errorf("SetSockOptInt(IPv6TrafficClassOption, 0x%x) failed: %s", tClass, err) - } - - v, err = c.ep.GetSockOptInt(tcpip.IPv6TrafficClassOption) - if err != nil { - c.t.Errorf("GetSockOptInt(IPv6TrafficClassOption) failed: %s", err) - } - - if v != tClass { - c.t.Errorf("got GetSockOptInt(IPv6TrafficClassOption) = 0x%x, want = 0x%x", v, tClass) - } - - // The header getter for TClass is called TOS, so use that checker. - testWrite(c, flow, checker.TOS(tClass, 0)) - }) - } -} - -func TestReceiveTosTClass(t *testing.T) { - const RcvTOSOpt = "ReceiveTosOption" - const RcvTClassOpt = "ReceiveTClassOption" - - testCases := []struct { - name string - tests []testFlow - }{ - {RcvTOSOpt, []testFlow{unicastV4, broadcast}}, - {RcvTClassOpt, []testFlow{unicastV4in6, unicastV6, unicastV6Only, broadcastIn6}}, - } - for _, testCase := range testCases { - for _, flow := range testCase.tests { - t.Run(fmt.Sprintf("%s:flow:%s", testCase.name, flow), func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(flow) - name := testCase.name - - var optionGetter func() bool - var optionSetter func(bool) - switch name { - case RcvTOSOpt: - optionGetter = c.ep.SocketOptions().GetReceiveTOS - optionSetter = c.ep.SocketOptions().SetReceiveTOS - case RcvTClassOpt: - optionGetter = c.ep.SocketOptions().GetReceiveTClass - optionSetter = c.ep.SocketOptions().SetReceiveTClass - default: - t.Fatalf("unkown test variant: %s", name) - } - - // Verify that setting and reading the option works. - v := optionGetter() - // Test for expected default value. - if v != false { - c.t.Errorf("got GetSockOptBool(%s) = %t, want = %t", name, v, false) - } - - const want = true - optionSetter(want) - - got := optionGetter() - if got != want { - c.t.Errorf("got GetSockOptBool(%s) = %t, want = %t", name, got, want) - } - - // Verify that the correct received TOS or TClass is handed through as - // ancillary data to the ControlMessages struct. - if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %s", err) - } - switch name { - case RcvTClassOpt: - testRead(c, flow, checker.ReceiveTClass(testTOS)) - case RcvTOSOpt: - testRead(c, flow, checker.ReceiveTOS(testTOS)) - default: - t.Fatalf("unknown test variant: %s", name) - } - }) - } - } -} - -func TestMulticastInterfaceOption(t *testing.T) { - for _, flow := range []testFlow{multicastV4, multicastV4in6, multicastV6, multicastV6Only} { - t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { - for _, bindTyp := range []string{"bound", "unbound"} { - t.Run(bindTyp, func(t *testing.T) { - for _, optTyp := range []string{"use local-addr", "use NICID", "use local-addr and NIC"} { - t.Run(optTyp, func(t *testing.T) { - h := flow.header4Tuple(outgoing) - mcastAddr := h.dstAddr.Addr - localIfAddr := h.srcAddr.Addr - - var ifoptSet tcpip.MulticastInterfaceOption - switch optTyp { - case "use local-addr": - ifoptSet.InterfaceAddr = localIfAddr - case "use NICID": - ifoptSet.NIC = 1 - case "use local-addr and NIC": - ifoptSet.InterfaceAddr = localIfAddr - ifoptSet.NIC = 1 - default: - t.Fatal("unknown test variant") - } - - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(flow.sockProto()) - - if bindTyp == "bound" { - // Bind the socket by connecting to the multicast address. - // This may have an influence on how the multicast interface - // is set. - addr := tcpip.FullAddress{ - Addr: flow.mapAddrIfApplicable(mcastAddr), - Port: stackPort, - } - if err := c.ep.Connect(addr); err != nil { - c.t.Fatalf("Connect failed: %s", err) - } - } - - if err := c.ep.SetSockOpt(&ifoptSet); err != nil { - c.t.Fatalf("SetSockOpt(&%#v): %s", ifoptSet, err) - } - - // Verify multicast interface addr and NIC were set correctly. - // Note that NIC must be 1 since this is our outgoing interface. - var ifoptGot tcpip.MulticastInterfaceOption - if err := c.ep.GetSockOpt(&ifoptGot); err != nil { - c.t.Fatalf("GetSockOpt(&%T): %s", ifoptGot, err) - } else if ifoptWant := (tcpip.MulticastInterfaceOption{NIC: 1, InterfaceAddr: ifoptSet.InterfaceAddr}); ifoptGot != ifoptWant { - c.t.Errorf("got multicast interface option = %#v, want = %#v", ifoptGot, ifoptWant) - } - }) - } - }) - } - }) - } -} - -// TestV4UnknownDestination verifies that we generate an ICMPv4 Destination -// Unreachable message when a udp datagram is received on ports for which there -// is no bound udp socket. -func TestV4UnknownDestination(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - testCases := []struct { - flow testFlow - icmpRequired bool - // largePayload if true, will result in a payload large enough - // so that the final generated IPv4 packet is larger than - // header.IPv4MinimumProcessableDatagramSize. - largePayload bool - // badChecksum if true, will set an invalid checksum in the - // header. - badChecksum bool - }{ - {unicastV4, true, false, false}, - {unicastV4, true, true, false}, - {unicastV4, false, false, true}, - {unicastV4, false, true, true}, - {multicastV4, false, false, false}, - {multicastV4, false, true, false}, - {broadcast, false, false, false}, - {broadcast, false, true, false}, - } - checksumErrors := uint64(0) - for _, tc := range testCases { - t.Run(fmt.Sprintf("flow:%s icmpRequired:%t largePayload:%t badChecksum:%t", tc.flow, tc.icmpRequired, tc.largePayload, tc.badChecksum), func(t *testing.T) { - payload := newPayload() - if tc.largePayload { - payload = newMinPayload(576) - } - c.injectPacket(tc.flow, payload, tc.badChecksum) - if tc.badChecksum { - checksumErrors++ - if got, want := c.s.Stats().UDP.ChecksumErrors.Value(), checksumErrors; got != want { - t.Fatalf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want) - } - } - if !tc.icmpRequired { - if p, ok := c.linkEP.Read(); ok { - t.Fatalf("unexpected packet received: %+v", p) - } - return - } - - // ICMP required. - p, ok := c.linkEP.Read() - if !ok { - t.Fatalf("packet wasn't written out") - return - } - - vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views()) - pkt := vv.ToView() - if got, want := len(pkt), header.IPv4MinimumProcessableDatagramSize; got > want { - t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want) - } - - hdr := header.IPv4(pkt) - checker.IPv4(t, hdr, checker.ICMPv4( - checker.ICMPv4Type(header.ICMPv4DstUnreachable), - checker.ICMPv4Code(header.ICMPv4PortUnreachable))) - - // We need to compare the included data part of the UDP packet that is in - // the ICMP packet with the matching original data. - icmpPkt := header.ICMPv4(hdr.Payload()) - payloadIPHeader := header.IPv4(icmpPkt.Payload()) - incomingHeaderLength := header.IPv4MinimumSize + header.UDPMinimumSize - wantLen := len(payload) - if tc.largePayload { - // To work out the data size we need to simulate what the sender would - // have done. The wanted size is the total available minus the sum of - // the headers in the UDP AND ICMP packets, given that we know the test - // had only a minimal IP header but the ICMP sender will have allowed - // for a maximally sized packet header. - wantLen = header.IPv4MinimumProcessableDatagramSize - header.IPv4MaximumHeaderSize - header.ICMPv4MinimumSize - incomingHeaderLength - } - - // In the case of large payloads the IP packet may be truncated. Update - // the length field before retrieving the udp datagram payload. - // Add back the two headers within the payload. - payloadIPHeader.SetTotalLength(uint16(wantLen + incomingHeaderLength)) - - origDgram := header.UDP(payloadIPHeader.Payload()) - if got, want := len(origDgram.Payload()), wantLen; got != want { - t.Fatalf("unexpected payload length got: %d, want: %d", got, want) - } - if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) { - t.Fatalf("unexpected payload got: %d, want: %d", got, want) - } - }) - } -} - -// TestV6UnknownDestination verifies that we generate an ICMPv6 Destination -// Unreachable message when a udp datagram is received on ports for which there -// is no bound udp socket. -func TestV6UnknownDestination(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - testCases := []struct { - flow testFlow - icmpRequired bool - // largePayload if true will result in a payload large enough to - // create an IPv6 packet > header.IPv6MinimumMTU bytes. - largePayload bool - // badChecksum if true, will set an invalid checksum in the - // header. - badChecksum bool - }{ - {unicastV6, true, false, false}, - {unicastV6, true, true, false}, - {unicastV6, false, false, true}, - {unicastV6, false, true, true}, - {multicastV6, false, false, false}, - {multicastV6, false, true, false}, - } - checksumErrors := uint64(0) - for _, tc := range testCases { - t.Run(fmt.Sprintf("flow:%s icmpRequired:%t largePayload:%t badChecksum:%t", tc.flow, tc.icmpRequired, tc.largePayload, tc.badChecksum), func(t *testing.T) { - payload := newPayload() - if tc.largePayload { - payload = newMinPayload(1280) - } - c.injectPacket(tc.flow, payload, tc.badChecksum) - if tc.badChecksum { - checksumErrors++ - if got, want := c.s.Stats().UDP.ChecksumErrors.Value(), checksumErrors; got != want { - t.Fatalf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want) - } - } - if !tc.icmpRequired { - if p, ok := c.linkEP.Read(); ok { - t.Fatalf("unexpected packet received: %+v", p) - } - return - } - - // ICMP required. - p, ok := c.linkEP.Read() - if !ok { - t.Fatalf("packet wasn't written out") - return - } - - vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views()) - pkt := vv.ToView() - if got, want := len(pkt), header.IPv6MinimumMTU; got > want { - t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want) - } - - hdr := header.IPv6(pkt) - checker.IPv6(t, hdr, checker.ICMPv6( - checker.ICMPv6Type(header.ICMPv6DstUnreachable), - checker.ICMPv6Code(header.ICMPv6PortUnreachable))) - - icmpPkt := header.ICMPv6(hdr.Payload()) - payloadIPHeader := header.IPv6(icmpPkt.Payload()) - wantLen := len(payload) - if tc.largePayload { - wantLen = header.IPv6MinimumMTU - header.IPv6MinimumSize*2 - header.ICMPv6MinimumSize - header.UDPMinimumSize - } - // In case of large payloads the IP packet may be truncated. Update - // the length field before retrieving the udp datagram payload. - payloadIPHeader.SetPayloadLength(uint16(wantLen + header.UDPMinimumSize)) - - origDgram := header.UDP(payloadIPHeader.Payload()) - if got, want := len(origDgram.Payload()), wantLen; got != want { - t.Fatalf("unexpected payload length got: %d, want: %d", got, want) - } - if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) { - t.Fatalf("unexpected payload got: %v, want: %v", got, want) - } - }) - } -} - -// TestIncrementMalformedPacketsReceived verifies if the malformed received -// global and endpoint stats are incremented. -func TestIncrementMalformedPacketsReceived(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(ipv6.ProtocolNumber) - // Bind to wildcard. - if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %s", err) - } - - payload := newPayload() - h := unicastV6.header4Tuple(incoming) - buf := c.buildV6Packet(payload, &h) - - // Invalidate the UDP header length field. - u := header.UDP(buf[header.IPv6MinimumSize:]) - u.SetLength(u.Length() + 1) - - c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - - const want = 1 - if got := c.s.Stats().UDP.MalformedPacketsReceived.Value(); got != want { - t.Errorf("got stats.UDP.MalformedPacketsReceived.Value() = %d, want = %d", got, want) - } - if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.MalformedPacketsReceived.Value(); got != want { - t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %d, want = %d", got, want) - } -} - -// TestShortHeader verifies that when a packet with a too-short UDP header is -// received, the malformed received global stat gets incremented. -func TestShortHeader(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(ipv6.ProtocolNumber) - // Bind to wildcard. - if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %s", err) - } - - h := unicastV6.header4Tuple(incoming) - - // Allocate a buffer for an IPv6 and too-short UDP header. - const udpSize = header.UDPMinimumSize - 1 - buf := buffer.NewView(header.IPv6MinimumSize + udpSize) - // Initialize the IP header. - ip := header.IPv6(buf) - ip.Encode(&header.IPv6Fields{ - TrafficClass: testTOS, - PayloadLength: uint16(udpSize), - TransportProtocol: udp.ProtocolNumber, - HopLimit: 65, - SrcAddr: h.srcAddr.Addr, - DstAddr: h.dstAddr.Addr, - }) - - // Initialize the UDP header. - udpHdr := header.UDP(buffer.NewView(header.UDPMinimumSize)) - udpHdr.Encode(&header.UDPFields{ - SrcPort: h.srcAddr.Port, - DstPort: h.dstAddr.Port, - Length: header.UDPMinimumSize, - }) - // Calculate the UDP pseudo-header checksum. - xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, h.srcAddr.Addr, h.dstAddr.Addr, uint16(len(udpHdr))) - udpHdr.SetChecksum(^udpHdr.CalculateChecksum(xsum)) - // Copy all but the last byte of the UDP header into the packet. - copy(buf[header.IPv6MinimumSize:], udpHdr) - - // Inject packet. - c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - - if got, want := c.s.Stats().NICs.MalformedL4RcvdPackets.Value(), uint64(1); got != want { - t.Errorf("got c.s.Stats().NIC.MalformedL4RcvdPackets.Value() = %d, want = %d", got, want) - } -} - -// TestBadChecksumErrors verifies if a checksum error is detected, -// global and endpoint stats are incremented. -func TestBadChecksumErrors(t *testing.T) { - for _, flow := range []testFlow{unicastV4, unicastV6} { - t.Run(flow.String(), func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(flow.sockProto()) - // Bind to wildcard. - if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %s", err) - } - - payload := newPayload() - c.injectPacket(flow, payload, true /* badChecksum */) - - const want = 1 - if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want { - t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want) - } - if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want { - t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want) - } - }) - } -} - -// TestPayloadModifiedV4 verifies if a checksum error is detected, -// global and endpoint stats are incremented. -func TestPayloadModifiedV4(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(ipv4.ProtocolNumber) - // Bind to wildcard. - if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %s", err) - } - - payload := newPayload() - h := unicastV4.header4Tuple(incoming) - buf := c.buildV4Packet(payload, &h) - // Modify the payload so that the checksum value in the UDP header will be - // incorrect. - buf[len(buf)-1]++ - c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - - const want = 1 - if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want { - t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want) - } - if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want { - t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want) - } -} - -// TestPayloadModifiedV6 verifies if a checksum error is detected, -// global and endpoint stats are incremented. -func TestPayloadModifiedV6(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(ipv6.ProtocolNumber) - // Bind to wildcard. - if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %s", err) - } - - payload := newPayload() - h := unicastV6.header4Tuple(incoming) - buf := c.buildV6Packet(payload, &h) - // Modify the payload so that the checksum value in the UDP header will be - // incorrect. - buf[len(buf)-1]++ - c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - - const want = 1 - if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want { - t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want) - } - if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want { - t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want) - } -} - -// TestChecksumZeroV4 verifies if the checksum value is zero, global and -// endpoint states are *not* incremented (UDP checksum is optional on IPv4). -func TestChecksumZeroV4(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(ipv4.ProtocolNumber) - // Bind to wildcard. - if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %s", err) - } - - payload := newPayload() - h := unicastV4.header4Tuple(incoming) - buf := c.buildV4Packet(payload, &h) - // Set the checksum field in the UDP header to zero. - u := header.UDP(buf[header.IPv4MinimumSize:]) - u.SetChecksum(0) - c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - - const want = 0 - if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want { - t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want) - } - if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want { - t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want) - } -} - -// TestChecksumZeroV6 verifies if the checksum value is zero, global and -// endpoint states are incremented (UDP checksum is *not* optional on IPv6). -func TestChecksumZeroV6(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(ipv6.ProtocolNumber) - // Bind to wildcard. - if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %s", err) - } - - payload := newPayload() - h := unicastV6.header4Tuple(incoming) - buf := c.buildV6Packet(payload, &h) - // Set the checksum field in the UDP header to zero. - u := header.UDP(buf[header.IPv6MinimumSize:]) - u.SetChecksum(0) - c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - - const want = 1 - if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want { - t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want) - } - if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want { - t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want) - } -} - -// TestShutdownRead verifies endpoint read shutdown and error -// stats increment on packet receive. -func TestShutdownRead(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(ipv6.ProtocolNumber) - - // Bind to wildcard. - if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %s", err) - } - - if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil { - c.t.Fatalf("Connect failed: %s", err) - } - - if err := c.ep.Shutdown(tcpip.ShutdownRead); err != nil { - t.Fatalf("Shutdown failed: %s", err) - } - - testFailingRead(c, unicastV6, true /* expectReadError */) - - var want uint64 = 1 - if got := c.s.Stats().UDP.ReceiveBufferErrors.Value(); got != want { - t.Errorf("got stats.UDP.ReceiveBufferErrors.Value() = %v, want = %v", got, want) - } - if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ClosedReceiver.Value(); got != want { - t.Errorf("got EP Stats.ReceiveErrors.ClosedReceiver stats = %v, want = %v", got, want) - } -} - -// TestShutdownWrite verifies endpoint write shutdown and error -// stats increment on packet write. -func TestShutdownWrite(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(ipv6.ProtocolNumber) - - if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil { - c.t.Fatalf("Connect failed: %s", err) - } - - if err := c.ep.Shutdown(tcpip.ShutdownWrite); err != nil { - t.Fatalf("Shutdown failed: %s", err) - } - - testFailingWrite(c, unicastV6, &tcpip.ErrClosedForSend{}) -} - -func (c *testContext) checkEndpointWriteStats(incr uint64, want tcpip.TransportEndpointStats, err tcpip.Error) { - got := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone() - switch err.(type) { - case nil: - want.PacketsSent.IncrementBy(incr) - case *tcpip.ErrMessageTooLong, *tcpip.ErrInvalidOptionValue: - want.WriteErrors.InvalidArgs.IncrementBy(incr) - case *tcpip.ErrClosedForSend: - want.WriteErrors.WriteClosed.IncrementBy(incr) - case *tcpip.ErrInvalidEndpointState: - want.WriteErrors.InvalidEndpointState.IncrementBy(incr) - case *tcpip.ErrNoRoute, *tcpip.ErrBroadcastDisabled, *tcpip.ErrNetworkUnreachable: - want.SendErrors.NoRoute.IncrementBy(incr) - default: - want.SendErrors.SendToNetworkFailed.IncrementBy(incr) - } - if got != want { - c.t.Errorf("Endpoint stats not matching for error %s got %+v want %+v", err, got, want) - } -} - -func (c *testContext) checkEndpointReadStats(incr uint64, want tcpip.TransportEndpointStats, err tcpip.Error) { - got := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone() - switch err.(type) { - case nil, *tcpip.ErrWouldBlock: - case *tcpip.ErrClosedForReceive: - want.ReadErrors.ReadClosed.IncrementBy(incr) - default: - c.t.Errorf("Endpoint error missing stats update err %v", err) - } - if got != want { - c.t.Errorf("Endpoint stats not matching for error %s got %+v want %+v", err, got, want) - } -} - -func TestOutgoingSubnetBroadcast(t *testing.T) { - const nicID1 = 1 - - ipv4Addr := tcpip.AddressWithPrefix{ - Address: "\xc0\xa8\x01\x3a", - PrefixLen: 24, - } - ipv4Subnet := ipv4Addr.Subnet() - ipv4SubnetBcast := ipv4Subnet.Broadcast() - ipv4Gateway := testutil.MustParse4("192.168.1.1") - ipv4AddrPrefix31 := tcpip.AddressWithPrefix{ - Address: "\xc0\xa8\x01\x3a", - PrefixLen: 31, - } - ipv4Subnet31 := ipv4AddrPrefix31.Subnet() - ipv4Subnet31Bcast := ipv4Subnet31.Broadcast() - ipv4AddrPrefix32 := tcpip.AddressWithPrefix{ - Address: "\xc0\xa8\x01\x3a", - PrefixLen: 32, - } - ipv4Subnet32 := ipv4AddrPrefix32.Subnet() - ipv4Subnet32Bcast := ipv4Subnet32.Broadcast() - ipv6Addr := tcpip.AddressWithPrefix{ - Address: "\x20\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", - PrefixLen: 64, - } - ipv6Subnet := ipv6Addr.Subnet() - ipv6SubnetBcast := ipv6Subnet.Broadcast() - remNetAddr := tcpip.AddressWithPrefix{ - Address: "\x64\x0a\x7b\x18", - PrefixLen: 24, - } - remNetSubnet := remNetAddr.Subnet() - remNetSubnetBcast := remNetSubnet.Broadcast() - - tests := []struct { - name string - nicAddr tcpip.ProtocolAddress - routes []tcpip.Route - remoteAddr tcpip.Address - requiresBroadcastOpt bool - }{ - { - name: "IPv4 Broadcast to local subnet", - nicAddr: tcpip.ProtocolAddress{ - Protocol: header.IPv4ProtocolNumber, - AddressWithPrefix: ipv4Addr, - }, - routes: []tcpip.Route{ - { - Destination: ipv4Subnet, - NIC: nicID1, - }, - }, - remoteAddr: ipv4SubnetBcast, - requiresBroadcastOpt: true, - }, - { - name: "IPv4 Broadcast to local /31 subnet", - nicAddr: tcpip.ProtocolAddress{ - Protocol: header.IPv4ProtocolNumber, - AddressWithPrefix: ipv4AddrPrefix31, - }, - routes: []tcpip.Route{ - { - Destination: ipv4Subnet31, - NIC: nicID1, - }, - }, - remoteAddr: ipv4Subnet31Bcast, - requiresBroadcastOpt: false, - }, - { - name: "IPv4 Broadcast to local /32 subnet", - nicAddr: tcpip.ProtocolAddress{ - Protocol: header.IPv4ProtocolNumber, - AddressWithPrefix: ipv4AddrPrefix32, - }, - routes: []tcpip.Route{ - { - Destination: ipv4Subnet32, - NIC: nicID1, - }, - }, - remoteAddr: ipv4Subnet32Bcast, - requiresBroadcastOpt: false, - }, - // IPv6 has no notion of a broadcast. - { - name: "IPv6 'Broadcast' to local subnet", - nicAddr: tcpip.ProtocolAddress{ - Protocol: header.IPv6ProtocolNumber, - AddressWithPrefix: ipv6Addr, - }, - routes: []tcpip.Route{ - { - Destination: ipv6Subnet, - NIC: nicID1, - }, - }, - remoteAddr: ipv6SubnetBcast, - requiresBroadcastOpt: false, - }, - { - name: "IPv4 Broadcast to remote subnet", - nicAddr: tcpip.ProtocolAddress{ - Protocol: header.IPv4ProtocolNumber, - AddressWithPrefix: ipv4Addr, - }, - routes: []tcpip.Route{ - { - Destination: remNetSubnet, - Gateway: ipv4Gateway, - NIC: nicID1, - }, - }, - remoteAddr: remNetSubnetBcast, - // TODO(gvisor.dev/issue/3938): Once we support marking a route as - // broadcast, this test should require the broadcast option to be set. - requiresBroadcastOpt: false, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, - Clock: &faketime.NullClock{}, - }) - e := channel.New(0, defaultMTU, "") - if err := s.CreateNIC(nicID1, e); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) - } - if err := s.AddProtocolAddress(nicID1, test.nicAddr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID1, test.nicAddr, err) - } - - s.SetRouteTable(test.routes) - - var netProto tcpip.NetworkProtocolNumber - switch l := len(test.remoteAddr); l { - case header.IPv4AddressSize: - netProto = header.IPv4ProtocolNumber - case header.IPv6AddressSize: - netProto = header.IPv6ProtocolNumber - default: - t.Fatalf("got unexpected address length = %d bytes", l) - } - - wq := waiter.Queue{} - ep, err := s.NewEndpoint(udp.ProtocolNumber, netProto, &wq) - if err != nil { - t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, netProto, err) - } - defer ep.Close() - - var r bytes.Reader - data := []byte{1, 2, 3, 4} - to := tcpip.FullAddress{ - Addr: test.remoteAddr, - Port: 80, - } - opts := tcpip.WriteOptions{To: &to} - expectedErrWithoutBcastOpt := func(err tcpip.Error) tcpip.Error { - if _, ok := err.(*tcpip.ErrBroadcastDisabled); ok { - return nil - } - return &tcpip.ErrBroadcastDisabled{} - } - if !test.requiresBroadcastOpt { - expectedErrWithoutBcastOpt = nil - } - - r.Reset(data) - { - n, err := ep.Write(&r, opts) - if expectedErrWithoutBcastOpt != nil { - if want := expectedErrWithoutBcastOpt(err); want != nil { - t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, %s)", opts, n, err, want) - } - } else if err != nil { - t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, nil)", opts, n, err) - } - } - - ep.SocketOptions().SetBroadcast(true) - - r.Reset(data) - if n, err := ep.Write(&r, opts); err != nil { - t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, nil)", opts, n, err) - } - - ep.SocketOptions().SetBroadcast(false) - - r.Reset(data) - { - n, err := ep.Write(&r, opts) - if expectedErrWithoutBcastOpt != nil { - if want := expectedErrWithoutBcastOpt(err); want != nil { - t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, %s)", opts, n, err, want) - } - } else if err != nil { - t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, nil)", opts, n, err) - } - } - }) - } -} diff --git a/pkg/test/criutil/BUILD b/pkg/test/criutil/BUILD deleted file mode 100644 index a7b082cee..000000000 --- a/pkg/test/criutil/BUILD +++ /dev/null @@ -1,14 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "criutil", - testonly = 1, - srcs = ["criutil.go"], - visibility = ["//:sandbox"], - deps = [ - "//pkg/test/dockerutil", - "//pkg/test/testutil", - ], -) diff --git a/pkg/test/criutil/criutil.go b/pkg/test/criutil/criutil.go deleted file mode 100644 index 3b41a2824..000000000 --- a/pkg/test/criutil/criutil.go +++ /dev/null @@ -1,372 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package criutil contains utility functions for interacting with the -// Container Runtime Interface (CRI), principally via the crictl command line -// tool. This requires critools to be installed on the local system. -package criutil - -import ( - "encoding/json" - "fmt" - "os" - "os/exec" - "path" - "regexp" - "strconv" - "strings" - "time" - - "gvisor.dev/gvisor/pkg/test/dockerutil" - "gvisor.dev/gvisor/pkg/test/testutil" -) - -// Crictl contains information required to run the crictl utility. -type Crictl struct { - logger testutil.Logger - endpoint string - cleanup []func() -} - -// ResolvePath attempts to find binary paths. It may set the path to invalid, -// which will cause the execution to fail with a sensible error. -func ResolvePath(executable string) string { - runtime, err := dockerutil.RuntimePath() - if err == nil { - // Check first the directory of the runtime itself. - if dir := path.Dir(runtime); dir != "" && dir != "." { - guess := path.Join(dir, executable) - if fi, err := os.Stat(guess); err == nil && (fi.Mode()&0111) != 0 { - return guess - } - } - } - - // Favor /usr/local/bin, if it exists. - localBin := fmt.Sprintf("/usr/local/bin/%s", executable) - if _, err := os.Stat(localBin); err == nil { - return localBin - } - - // Try to find via the path. - guess, _ := exec.LookPath(executable) - if err == nil { - return guess - } - - // Return a bare path; this generates a suitable error. - return executable -} - -// NewCrictl returns a Crictl configured with a timeout and an endpoint over -// which it will talk to containerd. -func NewCrictl(logger testutil.Logger, endpoint string) *Crictl { - // Attempt to find the executable, but don't bother propagating the - // error at this point. The first command executed will return with a - // binary not found error. - return &Crictl{ - logger: logger, - endpoint: endpoint, - } -} - -// CleanUp executes cleanup functions. -func (cc *Crictl) CleanUp() { - for _, c := range cc.cleanup { - c() - } - cc.cleanup = nil -} - -// RunPod creates a sandbox. It corresponds to `crictl runp`. -func (cc *Crictl) RunPod(runtime, sbSpecFile string) (string, error) { - podID, err := cc.run("runp", "--runtime", runtime, sbSpecFile) - if err != nil { - return "", fmt.Errorf("runp failed: %v", err) - } - // Strip the trailing newline from crictl output. - return strings.TrimSpace(podID), nil -} - -// Create creates a container within a sandbox. It corresponds to `crictl -// create`. -func (cc *Crictl) Create(podID, contSpecFile, sbSpecFile string) (string, error) { - // In version 1.16.0, crictl annoying starting attempting to pull the - // container, even if it was already available locally. We therefore - // need to parse the version and add an appropriate --no-pull argument - // since the image has already been loaded locally. - out, err := cc.run("-v") - if err != nil { - return "", err - } - r := regexp.MustCompile("crictl version ([0-9]+)\\.([0-9]+)\\.([0-9+])") - vs := r.FindStringSubmatch(out) - if len(vs) != 4 { - return "", fmt.Errorf("crictl -v had unexpected output: %s", out) - } - major, err := strconv.ParseUint(vs[1], 10, 64) - if err != nil { - return "", fmt.Errorf("crictl had invalid version: %v (%s)", err, out) - } - minor, err := strconv.ParseUint(vs[2], 10, 64) - if err != nil { - return "", fmt.Errorf("crictl had invalid version: %v (%s)", err, out) - } - - args := []string{"create"} - if (major == 1 && minor >= 16) || major > 1 { - args = append(args, "--no-pull") - } - args = append(args, podID) - args = append(args, contSpecFile) - args = append(args, sbSpecFile) - - podID, err = cc.run(args...) - if err != nil { - time.Sleep(10 * time.Minute) // XXX - return "", fmt.Errorf("create failed: %v", err) - } - - // Strip the trailing newline from crictl output. - return strings.TrimSpace(podID), nil -} - -// Start starts a container. It corresponds to `crictl start`. -func (cc *Crictl) Start(contID string) (string, error) { - output, err := cc.run("start", contID) - if err != nil { - return "", fmt.Errorf("start failed: %v", err) - } - return output, nil -} - -// Stop stops a container. It corresponds to `crictl stop`. -func (cc *Crictl) Stop(contID string) error { - _, err := cc.run("stop", contID) - return err -} - -// Exec execs a program inside a container. It corresponds to `crictl exec`. -func (cc *Crictl) Exec(contID string, args ...string) (string, error) { - a := []string{"exec", contID} - a = append(a, args...) - output, err := cc.run(a...) - if err != nil { - return "", fmt.Errorf("exec failed: %v", err) - } - return output, nil -} - -// Logs retrieves the container logs. It corresponds to `crictl logs`. -func (cc *Crictl) Logs(contID string, args ...string) (string, error) { - a := []string{"logs", contID} - a = append(a, args...) - output, err := cc.run(a...) - if err != nil { - return "", fmt.Errorf("logs failed: %v", err) - } - return output, nil -} - -// Rm removes a container. It corresponds to `crictl rm`. -func (cc *Crictl) Rm(contID string) error { - _, err := cc.run("rm", contID) - return err -} - -// StopPod stops a pod. It corresponds to `crictl stopp`. -func (cc *Crictl) StopPod(podID string) error { - _, err := cc.run("stopp", podID) - return err -} - -// containsConfig is a minimal copy of -// https://github.com/kubernetes/kubernetes/blob/master/pkg/kubelet/apis/cri/runtime/v1alpha2/api.proto -// It only contains fields needed for testing. -type containerConfig struct { - Status containerStatus -} - -type containerStatus struct { - Network containerNetwork -} - -type containerNetwork struct { - IP string -} - -// PodIP returns a pod's IP address. -func (cc *Crictl) PodIP(podID string) (string, error) { - output, err := cc.run("inspectp", podID) - if err != nil { - return "", err - } - conf := &containerConfig{} - if err := json.Unmarshal([]byte(output), conf); err != nil { - return "", fmt.Errorf("failed to unmarshal JSON: %v, %s", err, output) - } - if conf.Status.Network.IP == "" { - return "", fmt.Errorf("no IP found in config: %s", output) - } - return conf.Status.Network.IP, nil -} - -// RmPod removes a container. It corresponds to `crictl rmp`. -func (cc *Crictl) RmPod(podID string) error { - _, err := cc.run("rmp", podID) - return err -} - -// Import imports the given container from the local Docker instance. -func (cc *Crictl) Import(image string) error { - // Note that we provide a 10 minute timeout after connect because we may - // be pushing a lot of bytes in order to import the image. The connect - // timeout stays the same and is inherited from the Crictl instance. - cmd := testutil.Command(cc.logger, - ResolvePath("ctr"), - fmt.Sprintf("--connect-timeout=%s", 30*time.Second), - fmt.Sprintf("--address=%s", cc.endpoint), - "-n", "k8s.io", "images", "import", "-") - cmd.Stderr = os.Stderr // Pass through errors. - - // Create a pipe and start the program. - w, err := cmd.StdinPipe() - if err != nil { - return err - } - if err := cmd.Start(); err != nil { - return err - } - - // Save the image on the other end. - if err := dockerutil.Save(cc.logger, image, w); err != nil { - cmd.Wait() - return err - } - - // Close our pipe reference & see if it was loaded. - if err := w.Close(); err != nil { - return w.Close() - } - - return cmd.Wait() -} - -// StartContainer pulls the given image ands starts the container in the -// sandbox with the given podID. -// -// Note that the image will always be imported from the local docker daemon. -func (cc *Crictl) StartContainer(podID, image, sbSpec, contSpec string) (string, error) { - if err := cc.Import(image); err != nil { - return "", err - } - - // Write the specs to files that can be read by crictl. - sbSpecFile, cleanup, err := testutil.WriteTmpFile("sbSpec", sbSpec) - if err != nil { - return "", fmt.Errorf("failed to write sandbox spec: %v", err) - } - cc.cleanup = append(cc.cleanup, cleanup) - contSpecFile, cleanup, err := testutil.WriteTmpFile("contSpec", contSpec) - if err != nil { - return "", fmt.Errorf("failed to write container spec: %v", err) - } - cc.cleanup = append(cc.cleanup, cleanup) - - return cc.startContainer(podID, image, sbSpecFile, contSpecFile) -} - -func (cc *Crictl) startContainer(podID, image, sbSpecFile, contSpecFile string) (string, error) { - contID, err := cc.Create(podID, contSpecFile, sbSpecFile) - if err != nil { - return "", fmt.Errorf("failed to create container in pod %q: %v", podID, err) - } - - if _, err := cc.Start(contID); err != nil { - return "", fmt.Errorf("failed to start container %q in pod %q: %v", contID, podID, err) - } - - return contID, nil -} - -// StopContainer stops and deletes the container with the given container ID. -func (cc *Crictl) StopContainer(contID string) error { - if err := cc.Stop(contID); err != nil { - return fmt.Errorf("failed to stop container %q: %v", contID, err) - } - - if err := cc.Rm(contID); err != nil { - return fmt.Errorf("failed to remove container %q: %v", contID, err) - } - - return nil -} - -// StartPodAndContainer starts a sandbox and container in that sandbox. It -// returns the pod ID and container ID. -func (cc *Crictl) StartPodAndContainer(runtime, image, sbSpec, contSpec string) (string, string, error) { - if err := cc.Import(image); err != nil { - return "", "", err - } - - // Write the specs to files that can be read by crictl. - sbSpecFile, cleanup, err := testutil.WriteTmpFile("sbSpec", sbSpec) - if err != nil { - return "", "", fmt.Errorf("failed to write sandbox spec: %v", err) - } - cc.cleanup = append(cc.cleanup, cleanup) - contSpecFile, cleanup, err := testutil.WriteTmpFile("contSpec", contSpec) - if err != nil { - return "", "", fmt.Errorf("failed to write container spec: %v", err) - } - cc.cleanup = append(cc.cleanup, cleanup) - - podID, err := cc.RunPod(runtime, sbSpecFile) - if err != nil { - return "", "", err - } - - contID, err := cc.startContainer(podID, image, sbSpecFile, contSpecFile) - - return podID, contID, err -} - -// StopPodAndContainer stops a container and pod. -func (cc *Crictl) StopPodAndContainer(podID, contID string) error { - if err := cc.StopContainer(contID); err != nil { - return fmt.Errorf("failed to stop container %q in pod %q: %v", contID, podID, err) - } - - if err := cc.StopPod(podID); err != nil { - return fmt.Errorf("failed to stop pod %q: %v", podID, err) - } - - if err := cc.RmPod(podID); err != nil { - return fmt.Errorf("failed to remove pod %q: %v", podID, err) - } - - return nil -} - -// run runs crictl with the given args. -func (cc *Crictl) run(args ...string) (string, error) { - defaultArgs := []string{ - ResolvePath("crictl"), - "--image-endpoint", fmt.Sprintf("unix://%s", cc.endpoint), - "--runtime-endpoint", fmt.Sprintf("unix://%s", cc.endpoint), - } - fullArgs := append(defaultArgs, args...) - out, err := testutil.Command(cc.logger, fullArgs...).CombinedOutput() - return string(out), err -} diff --git a/pkg/test/dockerutil/BUILD b/pkg/test/dockerutil/BUILD deleted file mode 100644 index 366f068e3..000000000 --- a/pkg/test/dockerutil/BUILD +++ /dev/null @@ -1,43 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "dockerutil", - testonly = 1, - srcs = [ - "container.go", - "dockerutil.go", - "exec.go", - "network.go", - "profile.go", - ], - visibility = ["//:sandbox"], - deps = [ - "//pkg/test/testutil", - "@com_github_docker_docker//api/types:go_default_library", - "@com_github_docker_docker//api/types/container:go_default_library", - "@com_github_docker_docker//api/types/mount:go_default_library", - "@com_github_docker_docker//api/types/network:go_default_library", - "@com_github_docker_docker//client:go_default_library", - "@com_github_docker_docker//pkg/stdcopy:go_default_library", - "@com_github_docker_go_connections//nat:go_default_library", - "@org_golang_x_sys//unix:go_default_library", - ], -) - -go_test( - name = "profile_test", - size = "large", - srcs = [ - "profile_test.go", - ], - library = ":dockerutil", - tags = [ - # Requires docker and runsc to be configured before test runs. - # Also requires the test to be run as root. - "local", - "manual", - ], - visibility = ["//:sandbox"], -) diff --git a/pkg/test/dockerutil/README.md b/pkg/test/dockerutil/README.md deleted file mode 100644 index 870292096..000000000 --- a/pkg/test/dockerutil/README.md +++ /dev/null @@ -1,86 +0,0 @@ -# dockerutil - -This package is for creating and controlling docker containers for testing -runsc, gVisor's docker/kubernetes binary. A simple test may look like: - -``` - func TestSuperCool(t *testing.T) { - ctx := context.Background() - c := dockerutil.MakeContainer(ctx, t) - got, err := c.Run(ctx, dockerutil.RunOpts{ - Image: "basic/alpine" - }, "echo", "super cool") - if err != nil { - t.Fatalf("err was not nil: %v", err) - } - want := "super cool" - if !strings.Contains(got, want){ - t.Fatalf("want: %s, got: %s", want, got) - } - } -``` - -For further examples, see many of our end to end tests elsewhere in the repo, -such as those in //test/e2e or benchmarks at //test/benchmarks. - -dockerutil uses the "official" docker golang api, which is -[very powerful](https://godoc.org/github.com/docker/docker/client). dockerutil -is a thin wrapper around this API, allowing desired new use cases to be easily -implemented. - -## Profiling - -dockerutil is capable of generating profiles. Currently, the only option is to -use pprof profiles generated by `runsc debug`. The profiler will generate Block, -CPU, Heap, Goroutine, and Mutex profiles. To generate profiles: - -* Install runsc with the `--profile` flag: `make configure RUNTIME=myrunsc - ARGS="--profile"` Also add other flags with ARGS like `--platform=kvm` or - `--vfs2`. -* Restart docker: `sudo service docker restart` - -To run and generate CPU profiles run: - -``` -make sudo TARGETS=//path/to:target \ - ARGS="--runtime=myrunsc -test.v -test.bench=. --pprof-cpu" OPTIONS="-c opt" -``` - -Profiles would be at: `/tmp/profile/myrunsc/CONTAINERNAME/cpu.pprof` - -Container name in most tests and benchmarks in gVisor is usually the test name -and some random characters like so: -`BenchmarkABSL-CleanCache-JF2J2ZYF3U7SL47QAA727CSJI3C4ZAW2` - -Profiling requires root as runsc debug inspects running containers in /var/run -among other things. - -### Writing for Profiling - -The below shows an example of using profiles with dockerutil. - -``` -func TestSuperCool(t *testing.T){ - ctx := context.Background() - // profiled and using runtime from dockerutil.runtime flag - profiled := MakeContainer() - - // not profiled and using runtime runc - native := MakeNativeContainer() - - err := profiled.Spawn(ctx, RunOpts{ - Image: "some/image", - }, "sleep", "100000") - // profiling has begun here - ... - expensive setup that I don't want to profile. - ... - profiled.RestartProfiles() - // profiled activity -} -``` - -In the above example, `profiled` would be profiled and `native` would not. The -call to `RestartProfiles()` restarts the clock on profiling. This is useful if -the main activity being tested is done with `docker exec` or `container.Spawn()` -followed by one or more `container.Exec()` calls. diff --git a/pkg/test/dockerutil/container.go b/pkg/test/dockerutil/container.go deleted file mode 100644 index 06152a444..000000000 --- a/pkg/test/dockerutil/container.go +++ /dev/null @@ -1,548 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package dockerutil - -import ( - "bytes" - "context" - "errors" - "fmt" - "io/ioutil" - "net" - "os" - "path" - "path/filepath" - "regexp" - "strconv" - "strings" - "time" - - "github.com/docker/docker/api/types" - "github.com/docker/docker/api/types/container" - "github.com/docker/docker/api/types/mount" - "github.com/docker/docker/api/types/network" - "github.com/docker/docker/client" - "github.com/docker/docker/pkg/stdcopy" - "github.com/docker/go-connections/nat" - "gvisor.dev/gvisor/pkg/test/testutil" -) - -// Container represents a Docker Container allowing -// user to configure and control as one would with the 'docker' -// client. Container is backed by the offical golang docker API. -// See: https://pkg.go.dev/github.com/docker/docker. -type Container struct { - Name string - runtime string - - logger testutil.Logger - client *client.Client - id string - mounts []mount.Mount - links []string - copyErr error - cleanups []func() - - // profile is the profiling hook associated with this container. - profile *profile -} - -// RunOpts are options for running a container. -type RunOpts struct { - // Image is the image relative to images/. This will be mangled - // appropriately, to ensure that only first-party images are used. - Image string - - // Memory is the memory limit in bytes. - Memory int - - // Cpus in which to allow execution. ("0", "1", "0-2"). - CpusetCpus string - - // Ports are the ports to be allocated. - Ports []int - - // WorkDir sets the working directory. - WorkDir string - - // ReadOnly sets the read-only flag. - ReadOnly bool - - // Env are additional environment variables. - Env []string - - // User is the user to use. - User string - - // Privileged enables privileged mode. - Privileged bool - - // CapAdd are the extra set of capabilities to add. - CapAdd []string - - // CapDrop are the extra set of capabilities to drop. - CapDrop []string - - // Mounts is the list of directories/files to be mounted inside the container. - Mounts []mount.Mount - - // Links is the list of containers to be connected to the container. - Links []string -} - -func makeContainer(ctx context.Context, logger testutil.Logger, runtime string) *Container { - // Slashes are not allowed in container names. - name := testutil.RandomID(logger.Name()) - name = strings.ReplaceAll(name, "/", "-") - client, err := client.NewClientWithOpts(client.FromEnv) - if err != nil { - return nil - } - client.NegotiateAPIVersion(ctx) - return &Container{ - logger: logger, - Name: name, - runtime: runtime, - client: client, - } -} - -// MakeContainer constructs a suitable Container object. -// -// The runtime used is determined by the runtime flag. -// -// Containers will check flags for profiling requests. -func MakeContainer(ctx context.Context, logger testutil.Logger) *Container { - return makeContainer(ctx, logger, *runtime) -} - -// MakeNativeContainer constructs a suitable Container object. -// -// The runtime used will be the system default. -// -// Native containers aren't profiled. -func MakeNativeContainer(ctx context.Context, logger testutil.Logger) *Container { - return makeContainer(ctx, logger, "" /*runtime*/) -} - -// Spawn is analogous to 'docker run -d'. -func (c *Container) Spawn(ctx context.Context, r RunOpts, args ...string) error { - if err := c.create(ctx, r.Image, c.config(r, args), c.hostConfig(r), nil); err != nil { - return err - } - return c.Start(ctx) -} - -// SpawnProcess is analogous to 'docker run -it'. It returns a process -// which represents the root process. -func (c *Container) SpawnProcess(ctx context.Context, r RunOpts, args ...string) (Process, error) { - config, hostconf, netconf := c.ConfigsFrom(r, args...) - config.Tty = true - config.OpenStdin = true - - if err := c.CreateFrom(ctx, r.Image, config, hostconf, netconf); err != nil { - return Process{}, err - } - - // Open a connection to the container for parsing logs and for TTY. - stream, err := c.client.ContainerAttach(ctx, c.id, - types.ContainerAttachOptions{ - Stream: true, - Stdin: true, - Stdout: true, - Stderr: true, - }) - if err != nil { - return Process{}, fmt.Errorf("connect failed container id %s: %v", c.id, err) - } - - c.cleanups = append(c.cleanups, func() { stream.Close() }) - - if err := c.Start(ctx); err != nil { - return Process{}, err - } - - return Process{container: c, conn: stream}, nil -} - -// Run is analogous to 'docker run'. -func (c *Container) Run(ctx context.Context, r RunOpts, args ...string) (string, error) { - if err := c.create(ctx, r.Image, c.config(r, args), c.hostConfig(r), nil); err != nil { - return "", err - } - - if err := c.Start(ctx); err != nil { - return "", err - } - - if err := c.Wait(ctx); err != nil { - return "", err - } - - return c.Logs(ctx) -} - -// ConfigsFrom returns container configs from RunOpts and args. The caller should call 'CreateFrom' -// and Start. -func (c *Container) ConfigsFrom(r RunOpts, args ...string) (*container.Config, *container.HostConfig, *network.NetworkingConfig) { - return c.config(r, args), c.hostConfig(r), &network.NetworkingConfig{} -} - -// MakeLink formats a link to add to a RunOpts. -func (c *Container) MakeLink(target string) string { - return fmt.Sprintf("%s:%s", c.Name, target) -} - -// CreateFrom creates a container from the given configs. -func (c *Container) CreateFrom(ctx context.Context, profileImage string, conf *container.Config, hostconf *container.HostConfig, netconf *network.NetworkingConfig) error { - return c.create(ctx, profileImage, conf, hostconf, netconf) -} - -// Create is analogous to 'docker create'. -func (c *Container) Create(ctx context.Context, r RunOpts, args ...string) error { - return c.create(ctx, r.Image, c.config(r, args), c.hostConfig(r), nil) -} - -func (c *Container) create(ctx context.Context, profileImage string, conf *container.Config, hostconf *container.HostConfig, netconf *network.NetworkingConfig) error { - if c.runtime != "" && c.runtime != "runc" { - // Use the image name as provided here; which normally represents the - // unmodified "basic/alpine" image name. This should be easy to grok. - c.profileInit(profileImage) - } - cont, err := c.client.ContainerCreate(ctx, conf, hostconf, nil, c.Name) - if err != nil { - return err - } - c.id = cont.ID - return nil -} - -func (c *Container) config(r RunOpts, args []string) *container.Config { - ports := nat.PortSet{} - for _, p := range r.Ports { - port := nat.Port(fmt.Sprintf("%d", p)) - ports[port] = struct{}{} - } - env := append(r.Env, fmt.Sprintf("RUNSC_TEST_NAME=%s", c.Name)) - - return &container.Config{ - Image: testutil.ImageByName(r.Image), - Cmd: args, - ExposedPorts: ports, - Env: env, - WorkingDir: r.WorkDir, - User: r.User, - } -} - -func (c *Container) hostConfig(r RunOpts) *container.HostConfig { - c.mounts = append(c.mounts, r.Mounts...) - - return &container.HostConfig{ - Runtime: c.runtime, - Mounts: c.mounts, - PublishAllPorts: true, - Links: r.Links, - CapAdd: r.CapAdd, - CapDrop: r.CapDrop, - Privileged: r.Privileged, - ReadonlyRootfs: r.ReadOnly, - Resources: container.Resources{ - Memory: int64(r.Memory), // In bytes. - CpusetCpus: r.CpusetCpus, - }, - } -} - -// Start is analogous to 'docker start'. -func (c *Container) Start(ctx context.Context) error { - if err := c.client.ContainerStart(ctx, c.id, types.ContainerStartOptions{}); err != nil { - return fmt.Errorf("ContainerStart failed: %v", err) - } - - if c.profile != nil { - if err := c.profile.Start(c); err != nil { - c.logger.Logf("profile.Start failed: %v", err) - } - } - - return nil -} - -// Stop is analogous to 'docker stop'. -func (c *Container) Stop(ctx context.Context) error { - return c.client.ContainerStop(ctx, c.id, nil) -} - -// Pause is analogous to'docker pause'. -func (c *Container) Pause(ctx context.Context) error { - return c.client.ContainerPause(ctx, c.id) -} - -// Unpause is analogous to 'docker unpause'. -func (c *Container) Unpause(ctx context.Context) error { - return c.client.ContainerUnpause(ctx, c.id) -} - -// Checkpoint is analogous to 'docker checkpoint'. -func (c *Container) Checkpoint(ctx context.Context, name string) error { - return c.client.CheckpointCreate(ctx, c.Name, types.CheckpointCreateOptions{CheckpointID: name, Exit: true}) -} - -// Restore is analogous to 'docker start --checkname [name]'. -func (c *Container) Restore(ctx context.Context, name string) error { - return c.client.ContainerStart(ctx, c.id, types.ContainerStartOptions{CheckpointID: name}) -} - -// Logs is analogous 'docker logs'. -func (c *Container) Logs(ctx context.Context) (string, error) { - var out bytes.Buffer - err := c.logs(ctx, &out, &out) - return out.String(), err -} - -func (c *Container) logs(ctx context.Context, stdout, stderr *bytes.Buffer) error { - opts := types.ContainerLogsOptions{ShowStdout: true, ShowStderr: true} - writer, err := c.client.ContainerLogs(ctx, c.id, opts) - if err != nil { - return err - } - defer writer.Close() - _, err = stdcopy.StdCopy(stdout, stderr, writer) - - return err -} - -// ID returns the container id. -func (c *Container) ID() string { - return c.id -} - -// SandboxPid returns the container's pid. -func (c *Container) SandboxPid(ctx context.Context) (int, error) { - resp, err := c.client.ContainerInspect(ctx, c.id) - if err != nil { - return -1, err - } - return resp.ContainerJSONBase.State.Pid, nil -} - -// ErrNoIP indicates that no IP address is available. -var ErrNoIP = errors.New("no IP available") - -// FindIP returns the IP address of the container. -func (c *Container) FindIP(ctx context.Context, ipv6 bool) (net.IP, error) { - resp, err := c.client.ContainerInspect(ctx, c.id) - if err != nil { - return nil, err - } - - var ip net.IP - if ipv6 { - ip = net.ParseIP(resp.NetworkSettings.DefaultNetworkSettings.GlobalIPv6Address) - } else { - ip = net.ParseIP(resp.NetworkSettings.DefaultNetworkSettings.IPAddress) - } - if ip == nil { - return net.IP{}, ErrNoIP - } - return ip, nil -} - -// FindPort returns the host port that is mapped to 'sandboxPort'. -func (c *Container) FindPort(ctx context.Context, sandboxPort int) (int, error) { - desc, err := c.client.ContainerInspect(ctx, c.id) - if err != nil { - return -1, fmt.Errorf("error retrieving port: %v", err) - } - - format := fmt.Sprintf("%d/tcp", sandboxPort) - ports, ok := desc.NetworkSettings.Ports[nat.Port(format)] - if !ok { - return -1, fmt.Errorf("error retrieving port: %v", err) - - } - - port, err := strconv.Atoi(ports[0].HostPort) - if err != nil { - return -1, fmt.Errorf("error parsing port %q: %v", port, err) - } - return port, nil -} - -// CopyFiles copies in and mounts the given files. They are always ReadOnly. -func (c *Container) CopyFiles(opts *RunOpts, target string, sources ...string) { - dir, err := ioutil.TempDir("", c.Name) - if err != nil { - c.copyErr = fmt.Errorf("ioutil.TempDir failed: %v", err) - return - } - c.cleanups = append(c.cleanups, func() { os.RemoveAll(dir) }) - if err := os.Chmod(dir, 0755); err != nil { - c.copyErr = fmt.Errorf("os.Chmod(%q, 0755) failed: %v", dir, err) - return - } - for _, name := range sources { - src := name - if !filepath.IsAbs(src) { - src, err = testutil.FindFile(name) - if err != nil { - c.copyErr = fmt.Errorf("testutil.FindFile(%q) failed: %w", name, err) - return - } - } - dst := path.Join(dir, path.Base(name)) - if err := testutil.Copy(src, dst); err != nil { - c.copyErr = fmt.Errorf("testutil.Copy(%q, %q) failed: %v", src, dst, err) - return - } - c.logger.Logf("copy: %s -> %s", src, dst) - } - opts.Mounts = append(opts.Mounts, mount.Mount{ - Type: mount.TypeBind, - Source: dir, - Target: target, - ReadOnly: false, - }) -} - -// Status inspects the container returns its status. -func (c *Container) Status(ctx context.Context) (types.ContainerState, error) { - resp, err := c.client.ContainerInspect(ctx, c.id) - if err != nil { - return types.ContainerState{}, err - } - return *resp.State, err -} - -// Wait waits for the container to exit. -func (c *Container) Wait(ctx context.Context) error { - defer c.stopProfiling() - statusChan, errChan := c.client.ContainerWait(ctx, c.id, container.WaitConditionNotRunning) - select { - case err := <-errChan: - return err - case res := <-statusChan: - if res.StatusCode != 0 { - var msg string - if res.Error != nil { - msg = res.Error.Message - } - return fmt.Errorf("container returned non-zero status: %d, msg: %q", res.StatusCode, msg) - } - return nil - } -} - -// WaitTimeout waits for the container to exit with a timeout. -func (c *Container) WaitTimeout(ctx context.Context, timeout time.Duration) error { - ctx, cancel := context.WithTimeout(ctx, timeout) - defer cancel() - statusChan, errChan := c.client.ContainerWait(ctx, c.id, container.WaitConditionNotRunning) - select { - case <-ctx.Done(): - if ctx.Err() == context.DeadlineExceeded { - return fmt.Errorf("container %s timed out after %v seconds", c.Name, timeout.Seconds()) - } - return nil - case err := <-errChan: - return err - case <-statusChan: - return nil - } -} - -// WaitForOutput searches container logs for pattern and returns or timesout. -func (c *Container) WaitForOutput(ctx context.Context, pattern string, timeout time.Duration) (string, error) { - matches, err := c.WaitForOutputSubmatch(ctx, pattern, timeout) - if err != nil { - return "", err - } - if len(matches) == 0 { - return "", fmt.Errorf("didn't find pattern %s logs", pattern) - } - return matches[0], nil -} - -// WaitForOutputSubmatch searches container logs for the given -// pattern or times out. It returns any regexp submatches as well. -func (c *Container) WaitForOutputSubmatch(ctx context.Context, pattern string, timeout time.Duration) ([]string, error) { - ctx, cancel := context.WithTimeout(ctx, timeout) - defer cancel() - re := regexp.MustCompile(pattern) - for { - logs, err := c.Logs(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get logs: %v logs: %s", err, logs) - } - if matches := re.FindStringSubmatch(logs); matches != nil { - return matches, nil - } - time.Sleep(50 * time.Millisecond) - } -} - -// stopProfiling stops profiling. -func (c *Container) stopProfiling() { - if c.profile != nil { - if err := c.profile.Stop(c); err != nil { - // This most likely means that the runtime for the container - // was too short to connect and actually get a profile. - c.logger.Logf("warning: profile.Stop failed: %v", err) - } - } -} - -// Kill kills the container. -func (c *Container) Kill(ctx context.Context) error { - defer c.stopProfiling() - return c.client.ContainerKill(ctx, c.id, "") -} - -// Remove is analogous to 'docker rm'. -func (c *Container) Remove(ctx context.Context) error { - // Remove the image. - remove := types.ContainerRemoveOptions{ - RemoveVolumes: c.mounts != nil, - RemoveLinks: c.links != nil, - Force: true, - } - return c.client.ContainerRemove(ctx, c.Name, remove) -} - -// CleanUp kills and deletes the container (best effort). -func (c *Container) CleanUp(ctx context.Context) { - // Execute all cleanups. We execute cleanups here to close any - // open connections to the container before closing. Open connections - // can cause Kill and Remove to hang. - for _, c := range c.cleanups { - c() - } - c.cleanups = nil - - // Kill the container. - if err := c.Kill(ctx); err != nil && !strings.Contains(err.Error(), "is not running") { - // Just log; can't do anything here. - c.logger.Logf("error killing container %q: %v", c.Name, err) - } - - // Remove the image. - if err := c.Remove(ctx); err != nil { - c.logger.Logf("error removing container %q: %v", c.Name, err) - } - - // Forget all mounts. - c.mounts = nil -} diff --git a/pkg/test/dockerutil/dockerutil.go b/pkg/test/dockerutil/dockerutil.go deleted file mode 100644 index a40005799..000000000 --- a/pkg/test/dockerutil/dockerutil.go +++ /dev/null @@ -1,172 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package dockerutil is a collection of utility functions. -package dockerutil - -import ( - "encoding/json" - "flag" - "fmt" - "io" - "io/ioutil" - "log" - "os/exec" - "regexp" - "strconv" - "time" - - "gvisor.dev/gvisor/pkg/test/testutil" -) - -var ( - // runtime is the runtime to use for tests. This will be applied to all - // containers. Note that the default here ("runsc") corresponds to the - // default used by the installations. This is important, because the - // default installer for vm_tests (in tools/installers:head, invoked - // via tools/vm:defs.bzl) will install with this name. So without - // changing anything, tests should have a runsc runtime available to - // them. Otherwise installers should update the existing runtime - // instead of installing a new one. - runtime = flag.String("runtime", "runsc", "specify which runtime to use") - - // config is the default Docker daemon configuration path. - config = flag.String("config_path", "/etc/docker/daemon.json", "configuration file for reading paths") - - // The following flags are for the "pprof" profiler tool. - - // pprofBaseDir allows the user to change the directory to which profiles are - // written. By default, profiles will appear under: - // /tmp/profile/RUNTIME/CONTAINER_NAME/*.pprof. - pprofBaseDir = flag.String("pprof-dir", "/tmp/profile", "base directory in: BASEDIR/RUNTIME/CONTINER_NAME/FILENAME (e.g. /tmp/profile/runtime/mycontainer/cpu.pprof)") - pprofDuration = flag.Duration("pprof-duration", time.Hour, "profiling duration (automatically stopped at container exit)") - - // The below flags enable each type of profile. Multiple profiles can be - // enabled for each run. The profile will be collected from the start. - pprofBlock = flag.Bool("pprof-block", false, "enables block profiling with runsc debug") - pprofCPU = flag.Bool("pprof-cpu", false, "enables CPU profiling with runsc debug") - pprofHeap = flag.Bool("pprof-heap", false, "enables heap profiling with runsc debug") - pprofMutex = flag.Bool("pprof-mutex", false, "enables mutex profiling with runsc debug") -) - -// EnsureSupportedDockerVersion checks if correct docker is installed. -// -// This logs directly to stderr, as it is typically called from a Main wrapper. -func EnsureSupportedDockerVersion() { - cmd := exec.Command("docker", "version") - out, err := cmd.CombinedOutput() - if err != nil { - log.Fatalf("error running %q: %v", "docker version", err) - } - re := regexp.MustCompile(`Version:\s+(\d+)\.(\d+)\.\d.*`) - matches := re.FindStringSubmatch(string(out)) - if len(matches) != 3 { - log.Fatalf("Invalid docker output: %s", out) - } - major, _ := strconv.Atoi(matches[1]) - minor, _ := strconv.Atoi(matches[2]) - if major < 17 || (major == 17 && minor < 9) { - log.Fatalf("Docker version 17.09.0 or greater is required, found: %02d.%02d", major, minor) - } -} - -// RuntimePath returns the binary path for the current runtime. -func RuntimePath() (string, error) { - rs, err := runtimeMap() - if err != nil { - return "", err - } - - p, ok := rs["path"].(string) - if !ok { - // The runtime does not declare a path. - return "", fmt.Errorf("runtime does not declare a path: %v", rs) - } - return p, nil -} - -// UsingVFS2 returns true if the 'runtime' has the vfs2 flag set. -// TODO(gvisor.dev/issue/1624): Remove. -func UsingVFS2() (bool, error) { - rMap, err := runtimeMap() - if err != nil { - return false, err - } - - list, ok := rMap["runtimeArgs"].([]interface{}) - if !ok { - return false, fmt.Errorf("unexpected format: %v", rMap) - } - - for _, element := range list { - if element == "--vfs2" { - return true, nil - } - } - return false, nil -} - -func runtimeMap() (map[string]interface{}, error) { - // Read the configuration data; the file must exist. - configBytes, err := ioutil.ReadFile(*config) - if err != nil { - return nil, err - } - - // Unmarshal the configuration. - c := make(map[string]interface{}) - if err := json.Unmarshal(configBytes, &c); err != nil { - return nil, err - } - - // Decode the expected configuration. - r, ok := c["runtimes"] - if !ok { - return nil, fmt.Errorf("no runtimes declared: %v", c) - } - rs, ok := r.(map[string]interface{}) - if !ok { - // The runtimes are not a map. - return nil, fmt.Errorf("unexpected format: %v", rs) - } - r, ok = rs[*runtime] - if !ok { - // The expected runtime is not declared. - return nil, fmt.Errorf("runtime %q not found: %v", *runtime, rs) - } - rs, ok = r.(map[string]interface{}) - if !ok { - // The runtime is not a map. - return nil, fmt.Errorf("unexpected format: %v", r) - } - return rs, nil -} - -// Save exports a container image to the given Writer. -// -// Note that the writer should be actively consuming the output, otherwise it -// is not guaranteed that the Save will make any progress and the call may -// stall indefinitely. -// -// This is called by criutil in order to import imports. -func Save(logger testutil.Logger, image string, w io.Writer) error { - cmd := testutil.Command(logger, "docker", "save", testutil.ImageByName(image)) - cmd.Stdout = w // Send directly to the writer. - return cmd.Run() -} - -// Runtime returns the value of the flag runtime. -func Runtime() string { - return *runtime -} diff --git a/pkg/test/dockerutil/exec.go b/pkg/test/dockerutil/exec.go deleted file mode 100644 index bf968acec..000000000 --- a/pkg/test/dockerutil/exec.go +++ /dev/null @@ -1,188 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package dockerutil - -import ( - "bytes" - "context" - "fmt" - "time" - - "github.com/docker/docker/api/types" - "github.com/docker/docker/pkg/stdcopy" -) - -// ExecOpts holds arguments for Exec calls. -type ExecOpts struct { - // Env are additional environment variables. - Env []string - - // Privileged enables privileged mode. - Privileged bool - - // User is the user to use. - User string - - // Enables Tty and stdin for the created process. - UseTTY bool - - // WorkDir is the working directory of the process. - WorkDir string -} - -// Exec creates a process inside the container. -func (c *Container) Exec(ctx context.Context, opts ExecOpts, args ...string) (string, error) { - p, err := c.doExec(ctx, opts, args) - if err != nil { - return "", err - } - - if exitStatus, err := p.WaitExitStatus(ctx); err != nil { - return "", err - } else if exitStatus != 0 { - out, _ := p.Logs() - return out, fmt.Errorf("process terminated with status: %d", exitStatus) - } - - return p.Logs() -} - -// ExecProcess creates a process inside the container and returns a process struct -// for the caller to use. -func (c *Container) ExecProcess(ctx context.Context, opts ExecOpts, args ...string) (Process, error) { - return c.doExec(ctx, opts, args) -} - -func (c *Container) doExec(ctx context.Context, r ExecOpts, args []string) (Process, error) { - config := c.execConfig(r, args) - resp, err := c.client.ContainerExecCreate(ctx, c.id, config) - if err != nil { - return Process{}, fmt.Errorf("exec create failed with err: %v", err) - } - - hijack, err := c.client.ContainerExecAttach(ctx, resp.ID, types.ExecStartCheck{}) - if err != nil { - return Process{}, fmt.Errorf("exec attach failed with err: %v", err) - } - - return Process{ - container: c, - execid: resp.ID, - conn: hijack, - }, nil -} - -func (c *Container) execConfig(r ExecOpts, cmd []string) types.ExecConfig { - env := append(r.Env, fmt.Sprintf("RUNSC_TEST_NAME=%s", c.Name)) - return types.ExecConfig{ - AttachStdin: r.UseTTY, - AttachStderr: true, - AttachStdout: true, - Cmd: cmd, - Privileged: r.Privileged, - WorkingDir: r.WorkDir, - Env: env, - Tty: r.UseTTY, - User: r.User, - } - -} - -// Process represents a containerized process. -type Process struct { - container *Container - execid string - conn types.HijackedResponse -} - -// Write writes buf to the process's stdin. -func (p *Process) Write(timeout time.Duration, buf []byte) (int, error) { - p.conn.Conn.SetDeadline(time.Now().Add(timeout)) - return p.conn.Conn.Write(buf) -} - -// Read returns process's stdout and stderr. -func (p *Process) Read() (string, string, error) { - var stdout, stderr bytes.Buffer - if err := p.read(&stdout, &stderr); err != nil { - return "", "", err - } - return stdout.String(), stderr.String(), nil -} - -// Logs returns combined stdout/stderr from the process. -func (p *Process) Logs() (string, error) { - var out bytes.Buffer - if err := p.read(&out, &out); err != nil { - return "", err - } - return out.String(), nil -} - -func (p *Process) read(stdout, stderr *bytes.Buffer) error { - _, err := stdcopy.StdCopy(stdout, stderr, p.conn.Reader) - return err -} - -// ExitCode returns the process's exit code. -func (p *Process) ExitCode(ctx context.Context) (int, error) { - _, exitCode, err := p.runningExitCode(ctx) - return exitCode, err -} - -// IsRunning checks if the process is running. -func (p *Process) IsRunning(ctx context.Context) (bool, error) { - running, _, err := p.runningExitCode(ctx) - return running, err -} - -// WaitExitStatus until process completes and returns exit status. -func (p *Process) WaitExitStatus(ctx context.Context) (int, error) { - waitChan := make(chan (int)) - errChan := make(chan (error)) - - go func() { - for { - running, exitcode, err := p.runningExitCode(ctx) - if err != nil { - errChan <- fmt.Errorf("error waiting process %s: container %v", p.execid, p.container.Name) - } - if !running { - waitChan <- exitcode - } - time.Sleep(time.Millisecond * 500) - } - }() - - select { - case ws := <-waitChan: - return ws, nil - case err := <-errChan: - return -1, err - } -} - -// runningExitCode collects if the process is running and the exit code. -// The exit code is only valid if the process has exited. -func (p *Process) runningExitCode(ctx context.Context) (bool, int, error) { - // If execid is not empty, this is a execed process. - if p.execid != "" { - status, err := p.container.client.ContainerExecInspect(ctx, p.execid) - return status.Running, status.ExitCode, err - } - // else this is the root process. - status, err := p.container.Status(ctx) - return status.Running, status.ExitCode, err -} diff --git a/pkg/test/dockerutil/network.go b/pkg/test/dockerutil/network.go deleted file mode 100644 index dbe17fa5e..000000000 --- a/pkg/test/dockerutil/network.go +++ /dev/null @@ -1,110 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package dockerutil - -import ( - "context" - "net" - - "github.com/docker/docker/api/types" - "github.com/docker/docker/api/types/network" - "github.com/docker/docker/client" - "gvisor.dev/gvisor/pkg/test/testutil" -) - -// Network is a docker network. -type Network struct { - client *client.Client - id string - logger testutil.Logger - Name string - containers []*Container - Subnet *net.IPNet -} - -// NewNetwork sets up the struct for a Docker network. Names of networks -// will be unique. -func NewNetwork(ctx context.Context, logger testutil.Logger) *Network { - client, err := client.NewClientWithOpts(client.FromEnv) - if err != nil { - logger.Logf("create client failed with: %v", err) - return nil - } - client.NegotiateAPIVersion(ctx) - - return &Network{ - logger: logger, - Name: testutil.RandomID(logger.Name()), - client: client, - } -} - -func (n *Network) networkCreate() types.NetworkCreate { - - var subnet string - if n.Subnet != nil { - subnet = n.Subnet.String() - } - - ipam := network.IPAM{ - Config: []network.IPAMConfig{{ - Subnet: subnet, - }}, - } - - return types.NetworkCreate{ - CheckDuplicate: true, - IPAM: &ipam, - } -} - -// Create is analogous to 'docker network create'. -func (n *Network) Create(ctx context.Context) error { - - opts := n.networkCreate() - resp, err := n.client.NetworkCreate(ctx, n.Name, opts) - if err != nil { - return err - } - n.id = resp.ID - return nil -} - -// Connect is analogous to 'docker network connect' with the arguments provided. -func (n *Network) Connect(ctx context.Context, container *Container, ipv4, ipv6 string) error { - settings := network.EndpointSettings{ - IPAMConfig: &network.EndpointIPAMConfig{ - IPv4Address: ipv4, - IPv6Address: ipv6, - }, - } - err := n.client.NetworkConnect(ctx, n.id, container.id, &settings) - if err == nil { - n.containers = append(n.containers, container) - } - return err -} - -// Inspect returns this network's info. -func (n *Network) Inspect(ctx context.Context) (types.NetworkResource, error) { - return n.client.NetworkInspect(ctx, n.id, types.NetworkInspectOptions{Verbose: true}) -} - -// Cleanup cleans up the docker network. -func (n *Network) Cleanup(ctx context.Context) error { - n.containers = nil - - return n.client.NetworkRemove(ctx, n.id) -} diff --git a/pkg/test/dockerutil/profile.go b/pkg/test/dockerutil/profile.go deleted file mode 100644 index 12fe98b16..000000000 --- a/pkg/test/dockerutil/profile.go +++ /dev/null @@ -1,150 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package dockerutil - -import ( - "context" - "fmt" - "os" - "os/exec" - "path/filepath" - "time" - - "golang.org/x/sys/unix" -) - -// profile represents profile-like operations on a container. -// -// It is meant to be added to containers such that the container type calls -// the profile during its lifecycle. Standard implementations are below. - -// profile is for running profiles with 'runsc debug'. -type profile struct { - BasePath string - Types []string - Duration time.Duration - cmd *exec.Cmd -} - -// profileInit initializes a profile object, if required. -// -// N.B. The profiling filename initialized here will use the *image* -// name, and not the unique container name. This is intentional. Most -// of the time, profiling will be used for benchmarks. Benchmarks will -// be run iteratively until a sufficiently large N is reached. It is -// useful in this context to overwrite previous runs, and generate a -// single profile result for the final test. -func (c *Container) profileInit(image string) { - if !*pprofBlock && !*pprofCPU && !*pprofMutex && !*pprofHeap { - return // Nothing to do. - } - c.profile = &profile{ - BasePath: filepath.Join(*pprofBaseDir, c.runtime, c.logger.Name(), image), - Duration: *pprofDuration, - } - if *pprofCPU { - c.profile.Types = append(c.profile.Types, "cpu") - } - if *pprofHeap { - c.profile.Types = append(c.profile.Types, "heap") - } - if *pprofMutex { - c.profile.Types = append(c.profile.Types, "mutex") - } - if *pprofBlock { - c.profile.Types = append(c.profile.Types, "block") - } -} - -// createProcess creates the collection process. -func (p *profile) createProcess(c *Container) error { - // Ensure our directory exists. - if err := os.MkdirAll(p.BasePath, 0755); err != nil { - return err - } - - // Find the runtime to invoke. - path, err := RuntimePath() - if err != nil { - return fmt.Errorf("failed to get runtime path: %v", err) - } - - // The root directory of this container's runtime. - rootDir := fmt.Sprintf("/var/run/docker/runtime-%s/moby", c.runtime) - if _, err := os.Stat(rootDir); os.IsNotExist(err) { - // In docker v20+, due to https://github.com/moby/moby/issues/42345 the - // rootDir seems to always be the following. - rootDir = "/var/run/docker/runtime-runc/moby" - } - - // Format is `runsc --root=rootDir debug --profile-*=file --duration=24h containerID`. - args := []string{fmt.Sprintf("--root=%s", rootDir), "debug"} - for _, profileArg := range p.Types { - outputPath := filepath.Join(p.BasePath, fmt.Sprintf("%s.pprof", profileArg)) - args = append(args, fmt.Sprintf("--profile-%s=%s", profileArg, outputPath)) - } - args = append(args, fmt.Sprintf("--duration=%s", p.Duration)) // Or until container exits. - args = append(args, fmt.Sprintf("--delay=%s", p.Duration)) // Ditto. - args = append(args, c.ID()) - - // Best effort wait until container is running. - for now := time.Now(); time.Since(now) < 5*time.Second; { - if status, err := c.Status(context.Background()); err != nil { - return fmt.Errorf("failed to get status with: %v", err) - } else if status.Running { - break - } - time.Sleep(100 * time.Millisecond) - } - p.cmd = exec.Command(path, args...) - p.cmd.Stderr = os.Stderr // Pass through errors. - if err := p.cmd.Start(); err != nil { - return fmt.Errorf("start process failed: %v", err) - } - - return nil -} - -// killProcess kills the process, if running. -func (p *profile) killProcess() error { - if p.cmd != nil && p.cmd.Process != nil { - return p.cmd.Process.Signal(unix.SIGTERM) - } - return nil -} - -// waitProcess waits for the process, if running. -func (p *profile) waitProcess() error { - defer func() { p.cmd = nil }() - if p.cmd != nil { - return p.cmd.Wait() - } - return nil -} - -// Start is called when profiling is started. -func (p *profile) Start(c *Container) error { - return p.createProcess(c) -} - -// Stop is called when profiling is started. -func (p *profile) Stop(c *Container) error { - killErr := p.killProcess() - waitErr := p.waitProcess() - if waitErr != nil && killErr != nil { - return killErr - } - return waitErr // Ignore okay wait, err kill. -} diff --git a/pkg/test/dockerutil/profile_test.go b/pkg/test/dockerutil/profile_test.go deleted file mode 100644 index 4fe9ce15c..000000000 --- a/pkg/test/dockerutil/profile_test.go +++ /dev/null @@ -1,125 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package dockerutil - -import ( - "context" - "fmt" - "io/ioutil" - "os" - "path/filepath" - "testing" - "time" -) - -type testCase struct { - name string - profile profile - expectedFiles []string -} - -func TestProfile(t *testing.T) { - // Basepath and expected file names for each type of profile. - tmpDir, err := ioutil.TempDir("", "") - if err != nil { - t.Fatalf("unable to create temporary directory: %v", err) - } - defer os.RemoveAll(tmpDir) - - // All expected names. - basePath := tmpDir - block := "block.pprof" - cpu := "cpu.pprof" - heap := "heap.pprof" - mutex := "mutex.pprof" - - testCases := []testCase{ - { - name: "One", - profile: profile{ - BasePath: basePath, - Types: []string{"cpu"}, - Duration: 2 * time.Second, - }, - expectedFiles: []string{cpu}, - }, - { - name: "All", - profile: profile{ - BasePath: basePath, - Types: []string{"block", "cpu", "heap", "mutex"}, - Duration: 2 * time.Second, - }, - expectedFiles: []string{block, cpu, heap, mutex}, - }, - } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - ctx := context.Background() - c := MakeContainer(ctx, t) - - // Set basepath to include the container name so there are no conflicts. - localProfile := tc.profile // Copy it. - localProfile.BasePath = filepath.Join(localProfile.BasePath, tc.name) - - // Set directly on the container, to avoid flags. - c.profile = &localProfile - - func() { - defer c.CleanUp(ctx) - - // Start a container. - if err := c.Spawn(ctx, RunOpts{ - Image: "basic/alpine", - }, "sleep", "1000"); err != nil { - t.Fatalf("run failed with: %v", err) - } - - if status, err := c.Status(context.Background()); !status.Running { - t.Fatalf("container is not yet running: %+v err: %v", status, err) - } - - // End early if the expected files exist and have data. - for start := time.Now(); time.Since(start) < localProfile.Duration; time.Sleep(100 * time.Millisecond) { - if err := checkFiles(localProfile.BasePath, tc.expectedFiles); err == nil { - break - } - } - }() - - // Check all expected files exist and have data. - if err := checkFiles(localProfile.BasePath, tc.expectedFiles); err != nil { - t.Fatalf(err.Error()) - } - }) - } -} - -func checkFiles(basePath string, expectedFiles []string) error { - for _, file := range expectedFiles { - stat, err := os.Stat(filepath.Join(basePath, file)) - if err != nil { - return fmt.Errorf("stat failed with: %v", err) - } else if stat.Size() < 1 { - return fmt.Errorf("file not written to: %+v", stat) - } - } - return nil -} - -func TestMain(m *testing.M) { - EnsureSupportedDockerVersion() - os.Exit(m.Run()) -} diff --git a/pkg/test/testutil/BUILD b/pkg/test/testutil/BUILD deleted file mode 100644 index 7ff13cf12..000000000 --- a/pkg/test/testutil/BUILD +++ /dev/null @@ -1,24 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "testutil", - testonly = 1, - srcs = [ - "sh.go", - "testutil.go", - "testutil_runfiles.go", - ], - visibility = ["//:sandbox"], - deps = [ - "//pkg/sentry/watchdog", - "//pkg/sync", - "//runsc/config", - "//runsc/specutils", - "@com_github_cenkalti_backoff//:go_default_library", - "@com_github_kr_pty//:go_default_library", - "@com_github_opencontainers_runtime_spec//specs-go:go_default_library", - "@org_golang_x_sys//unix:go_default_library", - ], -) diff --git a/pkg/test/testutil/sh.go b/pkg/test/testutil/sh.go deleted file mode 100644 index cd5b0557a..000000000 --- a/pkg/test/testutil/sh.go +++ /dev/null @@ -1,515 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package testutil - -import ( - "bytes" - "context" - "fmt" - "io" - "os" - "os/exec" - "strings" - "time" - - "github.com/kr/pty" - "golang.org/x/sys/unix" -) - -// Prompt is used as shell prompt. -// It is meant to be unique enough to not be seen in command outputs. -const Prompt = "PROMPT> " - -// Simplistic shell string escape. -func shellEscape(s string) string { - // specialChars is used to determine whether s needs quoting at all. - const specialChars = "\\'\"`${[|&;<>()*?! \t\n" - // If s needs quoting, escapedChars is the set of characters that are - // escaped with a backslash. - const escapedChars = "\\\"$`" - if len(s) == 0 { - return "''" - } - if !strings.ContainsAny(s, specialChars) { - return s - } - var b bytes.Buffer - b.WriteString("\"") - for _, c := range s { - if strings.ContainsAny(string(c), escapedChars) { - b.WriteString("\\") - } - b.WriteRune(c) - } - b.WriteString("\"") - return b.String() -} - -type byteOrError struct { - b byte - err error -} - -// Shell manages a /bin/sh invocation with convenience functions to handle I/O. -// The shell is run in its own interactive TTY and should present its prompt. -type Shell struct { - // cmd is a reference to the underlying sh process. - cmd *exec.Cmd - // cmdFinished is closed when cmd exits. - cmdFinished chan struct{} - - // echo is whether the shell will echo input back to us. - // This helps setting expectations of getting feedback of written bytes. - echo bool - // Control characters we expect to see in the shell. - controlCharIntr string - controlCharEOF string - - // ptyMaster and ptyReplica are the TTY pair associated with the shell. - ptyMaster *os.File - ptyReplica *os.File - // readCh is a channel where everything read from ptyMaster is written. - readCh chan byteOrError - - // logger is used for logging. It may be nil. - logger Logger -} - -// cleanup kills the shell process and closes the TTY. -// Users of this library get a reference to this function with NewShell. -func (s *Shell) cleanup() { - s.logf("cleanup", "Shell cleanup started.") - if s.cmd.ProcessState == nil { - if err := s.cmd.Process.Kill(); err != nil { - s.logf("cleanup", "cannot kill shell process: %v", err) - } - // We don't log the error returned by Wait because the monitorExit - // goroutine will already do so. - s.cmd.Wait() - } - s.ptyReplica.Close() - s.ptyMaster.Close() - // Wait for monitorExit goroutine to write exit status to the debug log. - <-s.cmdFinished - // Empty out everything in the readCh, but don't wait too long for it. - var extraBytes bytes.Buffer - unreadTimeout := time.After(100 * time.Millisecond) -unreadLoop: - for { - select { - case r, ok := <-s.readCh: - if !ok { - break unreadLoop - } else if r.err == nil { - extraBytes.WriteByte(r.b) - } - case <-unreadTimeout: - break unreadLoop - } - } - if extraBytes.Len() > 0 { - s.logIO("unread", extraBytes.Bytes(), nil) - } - s.logf("cleanup", "Shell cleanup complete.") -} - -// logIO logs byte I/O to both standard logging and the test log, if provided. -func (s *Shell) logIO(prefix string, b []byte, err error) { - var sb strings.Builder - if len(b) > 0 { - sb.WriteString(fmt.Sprintf("%q", b)) - } else { - sb.WriteString("(nothing)") - } - if err != nil { - sb.WriteString(fmt.Sprintf(" [error: %v]", err)) - } - s.logf(prefix, "%s", sb.String()) -} - -// logf logs something to both standard logging and the test log, if provided. -func (s *Shell) logf(prefix, format string, values ...interface{}) { - if s.logger != nil { - s.logger.Logf("[%s] %s", prefix, fmt.Sprintf(format, values...)) - } -} - -// monitorExit waits for the shell process to exit and logs the exit result. -func (s *Shell) monitorExit() { - if err := s.cmd.Wait(); err != nil { - s.logf("cmd", "shell process terminated: %v", err) - } else { - s.logf("cmd", "shell process terminated successfully") - } - close(s.cmdFinished) -} - -// reader continuously reads the shell output and populates readCh. -func (s *Shell) reader(ctx context.Context) { - b := make([]byte, 4096) - defer close(s.readCh) - for { - select { - case <-s.cmdFinished: - // Shell process terminated; stop trying to read. - return - case <-ctx.Done(): - // Shell process will also have terminated in this case; - // stop trying to read. - // We don't print an error here because doing so would print this in the - // normal case where the context passed to NewShell is canceled at the - // end of a successful test. - return - default: - // Shell still running, try reading. - } - if got, err := s.ptyMaster.Read(b); err != nil { - s.readCh <- byteOrError{err: err} - if err == io.EOF { - return - } - } else { - for i := 0; i < got; i++ { - s.readCh <- byteOrError{b: b[i]} - } - } - } -} - -// readByte reads a single byte, respecting the context. -func (s *Shell) readByte(ctx context.Context) (byte, error) { - select { - case <-ctx.Done(): - return 0, ctx.Err() - case r := <-s.readCh: - return r.b, r.err - } -} - -// readLoop reads as many bytes as possible until the context expires, b is -// full, or a short time passes. It returns how many bytes it has successfully -// read. -func (s *Shell) readLoop(ctx context.Context, b []byte) (int, error) { - soonCtx, soonCancel := context.WithTimeout(ctx, 5*time.Second) - defer soonCancel() - var i int - for i = 0; i < len(b) && soonCtx.Err() == nil; i++ { - next, err := s.readByte(soonCtx) - if err != nil { - if i > 0 { - s.logIO("read", b[:i-1], err) - } else { - s.logIO("read", nil, err) - } - return i, err - } - b[i] = next - } - s.logIO("read", b[:i], soonCtx.Err()) - return i, soonCtx.Err() -} - -// readLine reads a single line. Strips out all \r characters for convenience. -// Upon error, it will still return what it has read so far. -// It will also exit quickly if the line content it has read so far (without a -// line break) matches `prompt`. -func (s *Shell) readLine(ctx context.Context, prompt string) ([]byte, error) { - soonCtx, soonCancel := context.WithTimeout(ctx, 5*time.Second) - defer soonCancel() - var lineData bytes.Buffer - var b byte - var err error - for soonCtx.Err() == nil && b != '\n' { - b, err = s.readByte(soonCtx) - if err != nil { - data := lineData.Bytes() - s.logIO("read", data, err) - return data, err - } - if b != '\r' { - lineData.WriteByte(b) - } - if bytes.Equal(lineData.Bytes(), []byte(prompt)) { - // Assume that there will not be any further output if we get the prompt. - // This avoids waiting for the read deadline just to read the prompt. - break - } - } - data := lineData.Bytes() - s.logIO("read", data, soonCtx.Err()) - return data, soonCtx.Err() -} - -// Expect verifies that the next `len(want)` bytes we read match `want`. -func (s *Shell) Expect(ctx context.Context, want []byte) error { - errPrefix := fmt.Sprintf("want(%q)", want) - b := make([]byte, len(want)) - got, err := s.readLoop(ctx, b) - if err != nil { - if ctx.Err() != nil { - return fmt.Errorf("%s: context done (%w), got: %q", errPrefix, err, b[:got]) - } - return fmt.Errorf("%s: %w", errPrefix, err) - } - if got < len(want) { - return fmt.Errorf("%s: short read (read %d bytes, expected %d): %q", errPrefix, got, len(want), b[:got]) - } - if !bytes.Equal(b, want) { - return fmt.Errorf("got %q want %q", b, want) - } - return nil -} - -// ExpectString verifies that the next `len(want)` bytes we read match `want`. -func (s *Shell) ExpectString(ctx context.Context, want string) error { - return s.Expect(ctx, []byte(want)) -} - -// ExpectPrompt verifies that the next few bytes we read are the shell prompt. -func (s *Shell) ExpectPrompt(ctx context.Context) error { - return s.ExpectString(ctx, Prompt) -} - -// ExpectEmptyLine verifies that the next few bytes we read are an empty line, -// as defined by any number of carriage or line break characters. -func (s *Shell) ExpectEmptyLine(ctx context.Context) error { - line, err := s.readLine(ctx, Prompt) - if err != nil { - return fmt.Errorf("cannot read line: %w", err) - } - if strings.Trim(string(line), "\r\n") != "" { - return fmt.Errorf("line was not empty: %q", line) - } - return nil -} - -// ExpectLine verifies that the next `len(want)` bytes we read match `want`, -// followed by carriage returns or newline characters. -func (s *Shell) ExpectLine(ctx context.Context, want string) error { - if err := s.ExpectString(ctx, want); err != nil { - return err - } - if err := s.ExpectEmptyLine(ctx); err != nil { - return fmt.Errorf("ExpectLine(%q): no line break: %w", want, err) - } - return nil -} - -// Write writes `b` to the shell and verifies that all of them get written. -func (s *Shell) Write(b []byte) error { - written, err := s.ptyMaster.Write(b) - s.logIO("write", b[:written], err) - if err != nil { - return fmt.Errorf("write(%q): %w", b, err) - } - if written != len(b) { - return fmt.Errorf("write(%q): wrote %d of %d bytes (%q)", b, written, len(b), b[:written]) - } - return nil -} - -// WriteLine writes `line` (to which \n will be appended) to the shell. -// If the shell is in `echo` mode, it will also check that we got these bytes -// back to read. -func (s *Shell) WriteLine(ctx context.Context, line string) error { - if err := s.Write([]byte(line + "\n")); err != nil { - return err - } - if s.echo { - // We expect to see everything we've typed. - if err := s.ExpectLine(ctx, line); err != nil { - return fmt.Errorf("echo: %w", err) - } - } - return nil -} - -// StartCommand is a convenience wrapper for WriteLine that mimics entering a -// command line and pressing Enter. It does some basic shell argument escaping. -func (s *Shell) StartCommand(ctx context.Context, cmd ...string) error { - escaped := make([]string, len(cmd)) - for i, arg := range cmd { - escaped[i] = shellEscape(arg) - } - return s.WriteLine(ctx, strings.Join(escaped, " ")) -} - -// GetCommandOutput gets all following bytes until the prompt is encountered. -// This is useful for matching the output of a command. -// All \r are removed for ease of matching. -func (s *Shell) GetCommandOutput(ctx context.Context) ([]byte, error) { - return s.ReadUntil(ctx, Prompt) -} - -// ReadUntil gets all following bytes until a certain line is encountered. -// This final line is not returned as part of the output, but everything before -// it (including the \n) is included. -// This is useful for matching the output of a command. -// All \r are removed for ease of matching. -func (s *Shell) ReadUntil(ctx context.Context, finalLine string) ([]byte, error) { - var output bytes.Buffer - for ctx.Err() == nil { - line, err := s.readLine(ctx, finalLine) - if err != nil { - return nil, err - } - if bytes.Equal(line, []byte(finalLine)) { - break - } - // readLine ensures that `line` either matches `finalLine` or contains \n. - // Thus we can be confident that `line` has a \n here. - output.Write(line) - } - return output.Bytes(), ctx.Err() -} - -// RunCommand is a convenience wrapper for StartCommand + GetCommandOutput. -func (s *Shell) RunCommand(ctx context.Context, cmd ...string) ([]byte, error) { - if err := s.StartCommand(ctx, cmd...); err != nil { - return nil, err - } - return s.GetCommandOutput(ctx) -} - -// RefreshSTTY interprets output from `stty -a` to check whether we are in echo -// mode and other settings. -// It will assume that any line matching `expectPrompt` means the end of -// the `stty -a` output. -// Why do this rather than using `tcgets`? Because this function can be used in -// conjunction with sub-shell processes that can allocate their own TTYs. -func (s *Shell) RefreshSTTY(ctx context.Context, expectPrompt string) error { - // Temporarily assume we will not get any output. - // If echo is actually on, we'll get the "stty -a" line as if it was command - // output. This is OK because we parse the output generously. - s.echo = false - if err := s.WriteLine(ctx, "stty -a"); err != nil { - return fmt.Errorf("could not run `stty -a`: %w", err) - } - sttyOutput, err := s.ReadUntil(ctx, expectPrompt) - if err != nil { - return fmt.Errorf("cannot get `stty -a` output: %w", err) - } - - // Set default control characters in case we can't see them in the output. - s.controlCharIntr = "^C" - s.controlCharEOF = "^D" - // stty output has two general notations: - // `a = b;` (for control characters), and `option` vs `-option` (for boolean - // options). We parse both kinds here. - // For `a = b;`, `controlChar` contains `a`, and `previousToken` is used to - // set `controlChar` to `previousToken` when we see an "=" token. - var previousToken, controlChar string - for _, token := range strings.Fields(string(sttyOutput)) { - if controlChar != "" { - value := strings.TrimSuffix(token, ";") - switch controlChar { - case "intr": - s.controlCharIntr = value - case "eof": - s.controlCharEOF = value - } - controlChar = "" - } else { - switch token { - case "=": - controlChar = previousToken - case "-echo": - s.echo = false - case "echo": - s.echo = true - } - } - previousToken = token - } - s.logf("stty", "refreshed settings: echo=%v, intr=%q, eof=%q", s.echo, s.controlCharIntr, s.controlCharEOF) - return nil -} - -// sendControlCode sends `code` to the shell and expects to see `repr`. -// If `expectLinebreak` is true, it also expects to see a linebreak. -func (s *Shell) sendControlCode(ctx context.Context, code byte, repr string, expectLinebreak bool) error { - if err := s.Write([]byte{code}); err != nil { - return fmt.Errorf("cannot send %q: %w", code, err) - } - if err := s.ExpectString(ctx, repr); err != nil { - return fmt.Errorf("did not see %s: %w", repr, err) - } - if expectLinebreak { - if err := s.ExpectEmptyLine(ctx); err != nil { - return fmt.Errorf("linebreak after %s: %v", repr, err) - } - } - return nil -} - -// SendInterrupt sends the \x03 (Ctrl+C) control character to the shell. -func (s *Shell) SendInterrupt(ctx context.Context, expectLinebreak bool) error { - return s.sendControlCode(ctx, 0x03, s.controlCharIntr, expectLinebreak) -} - -// SendEOF sends the \x04 (Ctrl+D) control character to the shell. -func (s *Shell) SendEOF(ctx context.Context, expectLinebreak bool) error { - return s.sendControlCode(ctx, 0x04, s.controlCharEOF, expectLinebreak) -} - -// NewShell returns a new managed sh process along with a cleanup function. -// The caller is expected to call this function once it no longer needs the -// shell. -// The optional passed-in logger will be used for logging. -func NewShell(ctx context.Context, logger Logger) (*Shell, func(), error) { - ptyMaster, ptyReplica, err := pty.Open() - if err != nil { - return nil, nil, fmt.Errorf("cannot create PTY: %w", err) - } - cmd := exec.CommandContext(ctx, "/bin/sh", "--noprofile", "--norc", "-i") - cmd.Stdin = ptyReplica - cmd.Stdout = ptyReplica - cmd.Stderr = ptyReplica - cmd.SysProcAttr = &unix.SysProcAttr{ - Setsid: true, - Setctty: true, - Ctty: 0, - } - cmd.Env = append(cmd.Env, fmt.Sprintf("PS1=%s", Prompt)) - if err := cmd.Start(); err != nil { - return nil, nil, fmt.Errorf("cannot start shell: %w", err) - } - s := &Shell{ - cmd: cmd, - cmdFinished: make(chan struct{}), - ptyMaster: ptyMaster, - ptyReplica: ptyReplica, - readCh: make(chan byteOrError, 1<<20), - logger: logger, - } - s.logf("creation", "Shell spawned.") - go s.monitorExit() - go s.reader(ctx) - setupCtx, setupCancel := context.WithTimeout(ctx, 5*time.Second) - defer setupCancel() - // We expect to see the prompt immediately on startup, - // since the shell is started in interactive mode. - if err := s.ExpectPrompt(setupCtx); err != nil { - s.cleanup() - return nil, nil, fmt.Errorf("did not get initial prompt: %w", err) - } - s.logf("creation", "Initial prompt observed.") - // Get initial TTY settings. - if err := s.RefreshSTTY(setupCtx, Prompt); err != nil { - s.cleanup() - return nil, nil, fmt.Errorf("cannot get initial STTY settings: %w", err) - } - return s, s.cleanup, nil -} diff --git a/pkg/test/testutil/testutil.go b/pkg/test/testutil/testutil.go deleted file mode 100644 index f6a3e34c7..000000000 --- a/pkg/test/testutil/testutil.go +++ /dev/null @@ -1,599 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package testutil contains utility functions for runsc tests. -package testutil - -import ( - "bufio" - "context" - "debug/elf" - "encoding/base32" - "encoding/json" - "flag" - "fmt" - "io" - "io/ioutil" - "log" - "math" - "math/rand" - "net/http" - "os" - "os/exec" - "os/signal" - "path" - "path/filepath" - "strconv" - "strings" - "testing" - "time" - - "github.com/cenkalti/backoff" - specs "github.com/opencontainers/runtime-spec/specs-go" - "golang.org/x/sys/unix" - "gvisor.dev/gvisor/pkg/sentry/watchdog" - "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/runsc/config" - "gvisor.dev/gvisor/runsc/specutils" -) - -var ( - checkpoint = flag.Bool("checkpoint", true, "control checkpoint/restore support") - partition = flag.Int("partition", 1, "partition number, this is 1-indexed") - totalPartitions = flag.Int("total_partitions", 1, "total number of partitions") - isRunningWithHostNet = flag.Bool("hostnet", false, "whether test is running with hostnet") -) - -// IsCheckpointSupported returns the relevant command line flag. -func IsCheckpointSupported() bool { - return *checkpoint -} - -// IsRunningWithHostNet returns the relevant command line flag. -func IsRunningWithHostNet() bool { - return *isRunningWithHostNet -} - -// ImageByName mangles the image name used locally. This depends on the image -// build infrastructure in images/ and tools/vm. -func ImageByName(name string) string { - return fmt.Sprintf("gvisor.dev/images/%s", name) -} - -// ConfigureExePath configures the executable for runsc in the test environment. -func ConfigureExePath() error { - path, err := FindFile("runsc/runsc") - if err != nil { - return err - } - specutils.ExePath = path - return nil -} - -// TmpDir returns the absolute path to a writable directory that can be used as -// scratch by the test. -func TmpDir() string { - if dir, ok := os.LookupEnv("TEST_TMPDIR"); ok { - return dir - } - return "/tmp" -} - -// Logger is a simple logging wrapper. -// -// This is designed to be implemented by *testing.T. -type Logger interface { - Name() string - Logf(fmt string, args ...interface{}) -} - -// DefaultLogger logs using the log package. -type DefaultLogger string - -// Name implements Logger.Name. -func (d DefaultLogger) Name() string { - return string(d) -} - -// Logf implements Logger.Logf. -func (d DefaultLogger) Logf(fmt string, args ...interface{}) { - log.Printf(fmt, args...) -} - -// multiLogger logs to multiple Loggers. -type multiLogger []Logger - -// Name implements Logger.Name. -func (m multiLogger) Name() string { - names := make([]string, len(m)) - for i, l := range m { - names[i] = l.Name() - } - return strings.Join(names, "+") -} - -// Logf implements Logger.Logf. -func (m multiLogger) Logf(fmt string, args ...interface{}) { - for _, l := range m { - l.Logf(fmt, args...) - } -} - -// NewMultiLogger returns a new Logger that logs on multiple Loggers. -func NewMultiLogger(loggers ...Logger) Logger { - return multiLogger(loggers) -} - -// Cmd is a simple wrapper. -type Cmd struct { - logger Logger - *exec.Cmd -} - -// CombinedOutput returns the output and logs. -func (c *Cmd) CombinedOutput() ([]byte, error) { - out, err := c.Cmd.CombinedOutput() - if len(out) > 0 { - c.logger.Logf("output: %s", string(out)) - } - if err != nil { - c.logger.Logf("error: %v", err) - } - return out, err -} - -// Command is a simple wrapper around exec.Command, that logs. -func Command(logger Logger, args ...string) *Cmd { - logger.Logf("command: %s", strings.Join(args, " ")) - return &Cmd{ - logger: logger, - Cmd: exec.Command(args[0], args[1:]...), - } -} - -// TestConfig returns the default configuration to use in tests. Note that -// 'RootDir' must be set by caller if required. -func TestConfig(t *testing.T) *config.Config { - logDir := os.TempDir() - if dir, ok := os.LookupEnv("TEST_UNDECLARED_OUTPUTS_DIR"); ok { - logDir = dir + "/" - } - - // Only register flags if config is being used. Otherwise anyone that uses - // testutil will get flags registered and they may conflict. - config.RegisterFlags() - - conf, err := config.NewFromFlags() - if err != nil { - panic(err) - } - // Change test defaults. - conf.Debug = true - conf.DebugLog = path.Join(logDir, "runsc.log."+t.Name()+".%TIMESTAMP%.%COMMAND%") - conf.LogPackets = true - conf.Network = config.NetworkNone - conf.Strace = true - conf.TestOnlyAllowRunAsCurrentUserWithoutChroot = true - conf.WatchdogAction = watchdog.Panic - return conf -} - -// NewSpecWithArgs creates a simple spec with the given args suitable for use -// in tests. -func NewSpecWithArgs(args ...string) *specs.Spec { - return &specs.Spec{ - // The host filesystem root is the container root. - Root: &specs.Root{ - Path: "/", - Readonly: true, - }, - Process: &specs.Process{ - Args: args, - Env: []string{ - "PATH=" + os.Getenv("PATH"), - }, - Capabilities: specutils.AllCapabilities(), - }, - Mounts: []specs.Mount{ - // Hide the host /etc to avoid any side-effects. - // For example, bash reads /etc/passwd and if it is - // very big, tests can fail by timeout. - { - Type: "tmpfs", - Destination: "/etc", - }, - // Root is readonly, but many tests want to write to tmpdir. - // This creates a writable mount inside the root. Also, when tmpdir points - // to "/tmp", it makes the the actual /tmp to be mounted and not a tmpfs - // inside the sentry. - { - Type: "bind", - Destination: TmpDir(), - Source: TmpDir(), - }, - }, - Hostname: "runsc-test-hostname", - } -} - -// SetupRootDir creates a root directory for containers. -func SetupRootDir() (string, func(), error) { - rootDir, err := ioutil.TempDir(TmpDir(), "containers") - if err != nil { - return "", nil, fmt.Errorf("error creating root dir: %v", err) - } - return rootDir, func() { os.RemoveAll(rootDir) }, nil -} - -// SetupContainer creates a bundle and root dir for the container, generates a -// test config, and writes the spec to config.json in the bundle dir. -func SetupContainer(spec *specs.Spec, conf *config.Config) (rootDir, bundleDir string, cleanup func(), err error) { - rootDir, rootCleanup, err := SetupRootDir() - if err != nil { - return "", "", nil, err - } - conf.RootDir = rootDir - bundleDir, bundleCleanup, err := SetupBundleDir(spec) - if err != nil { - rootCleanup() - return "", "", nil, err - } - return rootDir, bundleDir, func() { - bundleCleanup() - rootCleanup() - }, err -} - -// SetupBundleDir creates a bundle dir and writes the spec to config.json. -func SetupBundleDir(spec *specs.Spec) (string, func(), error) { - bundleDir, err := ioutil.TempDir(TmpDir(), "bundle") - if err != nil { - return "", nil, fmt.Errorf("error creating bundle dir: %v", err) - } - cleanup := func() { os.RemoveAll(bundleDir) } - if err := writeSpec(bundleDir, spec); err != nil { - cleanup() - return "", nil, fmt.Errorf("error writing spec: %v", err) - } - return bundleDir, cleanup, nil -} - -// writeSpec writes the spec to disk in the given directory. -func writeSpec(dir string, spec *specs.Spec) error { - b, err := json.Marshal(spec) - if err != nil { - return err - } - return ioutil.WriteFile(filepath.Join(dir, "config.json"), b, 0755) -} - -// idRandomSrc is a pseudo random generator used to in RandomID. -var idRandomSrc = rand.New(rand.NewSource(time.Now().UnixNano())) - -// idRandomSrcMtx is the mutex protecting idRandomSrc.Read from being used -// concurrently in differnt goroutines. -var idRandomSrcMtx sync.Mutex - -// RandomID returns 20 random bytes following the given prefix. -func RandomID(prefix string) string { - // Read 20 random bytes. - b := make([]byte, 20) - // Rand.Read is not safe for concurrent use. Packetimpact tests can be run in - // parallel now, so we have to protect the Read with a mutex. Otherwise we'll - // run into name conflicts. - // https://golang.org/pkg/math/rand/#Rand.Read - idRandomSrcMtx.Lock() - // "[Read] always returns len(p) and a nil error." --godoc - if _, err := idRandomSrc.Read(b); err != nil { - idRandomSrcMtx.Unlock() - panic("rand.Read failed: " + err.Error()) - } - idRandomSrcMtx.Unlock() - if prefix != "" { - prefix = prefix + "-" - } - return fmt.Sprintf("%s%s", prefix, base32.StdEncoding.EncodeToString(b)) -} - -// RandomContainerID generates a random container id for each test. -// -// The container id is used to create an abstract unix domain socket, which -// must be unique. While the container forbids creating two containers with the -// same name, sometimes between test runs the socket does not get cleaned up -// quickly enough, causing container creation to fail. -func RandomContainerID() string { - return RandomID("test-container") -} - -// Copy copies file from src to dst. -func Copy(src, dst string) error { - in, err := os.Open(src) - if err != nil { - return err - } - defer in.Close() - - st, err := in.Stat() - if err != nil { - return err - } - - out, err := os.OpenFile(dst, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, st.Mode().Perm()) - if err != nil { - return err - } - defer out.Close() - - // Mirror the local user's permissions across all users. This is - // because as we inject things into the container, the UID/GID will - // change. Also, the build system may generate artifacts with different - // modes. At the top-level (volume mapping) we have a big read-only - // knob that can be applied to prevent modifications. - // - // Note that this must be done via a separate Chmod call, otherwise the - // current process's umask will get in the way. - var mode os.FileMode - if st.Mode()&0100 != 0 { - mode |= 0111 - } - if st.Mode()&0200 != 0 { - mode |= 0222 - } - if st.Mode()&0400 != 0 { - mode |= 0444 - } - if err := os.Chmod(dst, mode); err != nil { - return err - } - - _, err = io.Copy(out, in) - return err -} - -// Poll is a shorthand function to poll for something with given timeout. -func Poll(cb func() error, timeout time.Duration) error { - ctx, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() - return PollContext(ctx, cb) -} - -// PollContext is like Poll, but takes a context instead of a timeout. -func PollContext(ctx context.Context, cb func() error) error { - b := backoff.WithContext(backoff.NewConstantBackOff(100*time.Millisecond), ctx) - return backoff.Retry(cb, b) -} - -// WaitForHTTP tries GET requests on a port until the call succeeds or timeout. -func WaitForHTTP(ip string, port int, timeout time.Duration) error { - cb := func() error { - c := &http.Client{ - // Calculate timeout to be able to do minimum 5 attempts. - Timeout: timeout / 5, - } - url := fmt.Sprintf("http://%s:%d/", ip, port) - resp, err := c.Get(url) - if err != nil { - log.Printf("Waiting %s: %v", url, err) - return err - } - resp.Body.Close() - return nil - } - return Poll(cb, timeout) -} - -// Reaper reaps child processes. -type Reaper struct { - // mu protects ch, which will be nil if the reaper is not running. - mu sync.Mutex - ch chan os.Signal -} - -// Start starts reaping child processes. -func (r *Reaper) Start() { - r.mu.Lock() - defer r.mu.Unlock() - - if r.ch != nil { - panic("reaper.Start called on a running reaper") - } - - r.ch = make(chan os.Signal, 1) - signal.Notify(r.ch, unix.SIGCHLD) - - go func() { - for { - r.mu.Lock() - ch := r.ch - r.mu.Unlock() - if ch == nil { - return - } - - _, ok := <-ch - if !ok { - // Channel closed. - return - } - for { - cpid, _ := unix.Wait4(-1, nil, unix.WNOHANG, nil) - if cpid < 1 { - break - } - } - } - }() -} - -// Stop stops reaping child processes. -func (r *Reaper) Stop() { - r.mu.Lock() - defer r.mu.Unlock() - - if r.ch == nil { - panic("reaper.Stop called on a stopped reaper") - } - - signal.Stop(r.ch) - close(r.ch) - r.ch = nil -} - -// StartReaper is a helper that starts a new Reaper and returns a function to -// stop it. -func StartReaper() func() { - r := &Reaper{} - r.Start() - return r.Stop -} - -// WaitUntilRead reads from the given reader until the wanted string is found -// or until timeout. -func WaitUntilRead(r io.Reader, want string, timeout time.Duration) error { - sc := bufio.NewScanner(r) - // done must be accessed atomically. A value greater than 0 indicates - // that the read loop can exit. - doneCh := make(chan bool) - defer close(doneCh) - go func() { - for sc.Scan() { - t := sc.Text() - if strings.Contains(t, want) { - doneCh <- true - return - } - select { - case <-doneCh: - return - default: - } - } - doneCh <- false - }() - - select { - case <-time.After(timeout): - return fmt.Errorf("timeout waiting to read %q", want) - case res := <-doneCh: - if !res { - return fmt.Errorf("reader closed while waiting to read %q", want) - } - return nil - } -} - -// KillCommand kills the process running cmd unless it hasn't been started. It -// returns an error if it cannot kill the process unless the reason is that the -// process has already exited. -// -// KillCommand will also reap the process. -func KillCommand(cmd *exec.Cmd) error { - if cmd.Process == nil { - return nil - } - if err := cmd.Process.Kill(); err != nil { - if !strings.Contains(err.Error(), "process already finished") { - return fmt.Errorf("failed to kill process %v: %v", cmd, err) - } - } - return cmd.Wait() -} - -// WriteTmpFile writes text to a temporary file, closes the file, and returns -// the name of the file. A cleanup function is also returned. -func WriteTmpFile(pattern, text string) (string, func(), error) { - file, err := ioutil.TempFile(TmpDir(), pattern) - if err != nil { - return "", nil, err - } - defer file.Close() - if _, err := file.Write([]byte(text)); err != nil { - return "", nil, err - } - return file.Name(), func() { os.RemoveAll(file.Name()) }, nil -} - -// IsStatic returns true iff the given file is a static binary. -func IsStatic(filename string) (bool, error) { - f, err := elf.Open(filename) - if err != nil { - return false, err - } - for _, prog := range f.Progs { - if prog.Type == elf.PT_INTERP { - return false, nil // Has interpreter. - } - } - return true, nil -} - -// TouchShardStatusFile indicates to Bazel that the test runner supports -// sharding by creating or updating the last modified date of the file -// specified by TEST_SHARD_STATUS_FILE. -// -// See https://docs.bazel.build/versions/master/test-encyclopedia.html#role-of-the-test-runner. -func TouchShardStatusFile() error { - if statusFile, ok := os.LookupEnv("TEST_SHARD_STATUS_FILE"); ok { - cmd := exec.Command("touch", statusFile) - if b, err := cmd.CombinedOutput(); err != nil { - return fmt.Errorf("touch %q failed:\n output: %s\n error: %s", statusFile, string(b), err.Error()) - } - } - return nil -} - -// TestIndicesForShard returns indices for this test shard based on the -// TEST_SHARD_INDEX and TEST_TOTAL_SHARDS environment vars, as well as -// the passed partition flags. -// -// If either of the env vars are not present, then the function will return all -// tests. If there are more shards than there are tests, then the returned list -// may be empty. -func TestIndicesForShard(numTests int) ([]int, error) { - var ( - shardIndex = 0 - shardTotal = 1 - ) - - indexStr, indexOk := os.LookupEnv("TEST_SHARD_INDEX") - totalStr, totalOk := os.LookupEnv("TEST_TOTAL_SHARDS") - if indexOk && totalOk { - // Parse index and total to ints. - var err error - shardIndex, err = strconv.Atoi(indexStr) - if err != nil { - return nil, fmt.Errorf("invalid TEST_SHARD_INDEX %q: %v", indexStr, err) - } - shardTotal, err = strconv.Atoi(totalStr) - if err != nil { - return nil, fmt.Errorf("invalid TEST_TOTAL_SHARDS %q: %v", totalStr, err) - } - } - - // Combine with the partitions. - partitionSize := shardTotal - shardTotal = (*totalPartitions) * shardTotal - shardIndex = partitionSize*(*partition-1) + shardIndex - - // Calculate! - var indices []int - numBlocks := int(math.Ceil(float64(numTests) / float64(shardTotal))) - for i := 0; i < numBlocks; i++ { - pick := i*shardTotal + shardIndex - if pick < numTests { - indices = append(indices, pick) - } - } - return indices, nil -} diff --git a/pkg/test/testutil/testutil_runfiles.go b/pkg/test/testutil/testutil_runfiles.go deleted file mode 100644 index 1dbd48a47..000000000 --- a/pkg/test/testutil/testutil_runfiles.go +++ /dev/null @@ -1,78 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//go:build go1.1 -// +build go1.1 - -package testutil - -import ( - "fmt" - "os" - "path/filepath" -) - -// FindFile searchs for a file inside the test run environment. It returns the -// full path to the file. It fails if none or more than one file is found. -func FindFile(path string) (string, error) { - wd, err := os.Getwd() - if err != nil { - return "", err - } - - // The test root is demarcated by a path element called "__main__". Search for - // it backwards from the working directory. - root := wd - for { - dir, name := filepath.Split(root) - if name == "__main__" { - break - } - if len(dir) == 0 { - return "", fmt.Errorf("directory __main__ not found in %q", wd) - } - // Remove ending slash to loop around. - root = dir[:len(dir)-1] - } - - // Annoyingly, bazel adds the build type to the directory path for go - // binaries, but not for c++ binaries. We use two different patterns to - // to find our file. - patterns := []string{ - // Try the obvious path first. - filepath.Join(root, path), - // If it was a go binary, use a wildcard to match the build - // type. The pattern is: /test-path/__main__/directories/*/file. - filepath.Join(root, filepath.Dir(path), "*", filepath.Base(path)), - } - - for _, p := range patterns { - matches, err := filepath.Glob(p) - if err != nil { - // "The only possible returned error is ErrBadPattern, - // when pattern is malformed." -godoc - return "", fmt.Errorf("error globbing %q: %v", p, err) - } - switch len(matches) { - case 0: - // Try the next pattern. - case 1: - // We found it. - return matches[0], nil - default: - return "", fmt.Errorf("more than one match found for %q: %s", path, matches) - } - } - return "", fmt.Errorf("file %q not found", path) -} diff --git a/pkg/unet/BUILD b/pkg/unet/BUILD deleted file mode 100644 index 234125c38..000000000 --- a/pkg/unet/BUILD +++ /dev/null @@ -1,29 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "unet", - srcs = [ - "unet.go", - "unet_unsafe.go", - ], - visibility = ["//visibility:public"], - deps = [ - "//pkg/sync", - "@org_golang_x_sys//unix:go_default_library", - ], -) - -go_test( - name = "unet_test", - size = "small", - srcs = [ - "unet_test.go", - ], - library = ":unet", - deps = [ - "//pkg/sync", - "@org_golang_x_sys//unix:go_default_library", - ], -) diff --git a/pkg/unet/unet_state_autogen.go b/pkg/unet/unet_state_autogen.go new file mode 100644 index 000000000..9bbf31d35 --- /dev/null +++ b/pkg/unet/unet_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package unet diff --git a/pkg/unet/unet_test.go b/pkg/unet/unet_test.go deleted file mode 100644 index 9875a3cda..000000000 --- a/pkg/unet/unet_test.go +++ /dev/null @@ -1,739 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package unet - -import ( - "io/ioutil" - "os" - "path/filepath" - "reflect" - "testing" - "time" - - "golang.org/x/sys/unix" - "gvisor.dev/gvisor/pkg/sync" -) - -func randomFilename() (string, error) { - // Return a randomly generated file in the test dir. - f, err := ioutil.TempFile("", "unet-test") - if err != nil { - return "", err - } - file := f.Name() - os.Remove(file) - f.Close() - - cwd, err := os.Getwd() - if err != nil { - return "", err - } - - // NOTE(b/26918832): We try to use relative path if possible. This is - // to help conforming to the unix path length limit. - if rel, err := filepath.Rel(cwd, file); err == nil { - return rel, nil - } - - return file, nil -} - -func TestConnectFailure(t *testing.T) { - name, err := randomFilename() - if err != nil { - t.Fatalf("Unable to generate file, got err %v expected nil", err) - } - - if _, err := Connect(name, false); err == nil { - t.Fatalf("Connect was successful, expected err") - } -} - -func TestBindFailure(t *testing.T) { - name, err := randomFilename() - if err != nil { - t.Fatalf("Unable to generate file, got err %v expected nil", err) - } - - ss, err := BindAndListen(name, false) - if err != nil { - t.Fatalf("First bind failed, got err %v expected nil", err) - } - defer ss.Close() - - if _, err = BindAndListen(name, false); err == nil { - t.Fatalf("Second bind succeeded, expected non-nil err") - } -} - -func TestMultipleAccept(t *testing.T) { - name, err := randomFilename() - if err != nil { - t.Fatalf("Unable to generate file, got err %v expected nil", err) - } - - ss, err := BindAndListen(name, false) - if err != nil { - t.Fatalf("First bind failed, got err %v expected nil", err) - } - defer ss.Close() - - // Connect backlog times asynchronously. - var wg sync.WaitGroup - defer wg.Wait() - for i := 0; i < backlog; i++ { - wg.Add(1) - go func() { - defer wg.Done() - s, err := Connect(name, false) - if err != nil { - t.Errorf("Connect failed, got err %v expected nil", err) - return - } - s.Close() - }() - } - - // Accept backlog times. - for i := 0; i < backlog; i++ { - s, err := ss.Accept() - if err != nil { - t.Errorf("Accept failed, got err %v expected nil", err) - continue - } - s.Close() - } -} - -func TestServerClose(t *testing.T) { - name, err := randomFilename() - if err != nil { - t.Fatalf("Unable to generate file, got err %v expected nil", err) - } - - ss, err := BindAndListen(name, false) - if err != nil { - t.Fatalf("First bind failed, got err %v expected nil", err) - } - - // Make sure the first close succeeds. - if err := ss.Close(); err != nil { - t.Fatalf("First close failed, got err %v expected nil", err) - } - - // The second one should fail. - if err := ss.Close(); err == nil { - t.Fatalf("Second close succeeded, expected non-nil err") - } -} - -func socketPair(t *testing.T, packet bool) (*Socket, *Socket) { - name, err := randomFilename() - if err != nil { - t.Fatalf("Unable to generate file, got err %v expected nil", err) - } - - // Bind a server. - ss, err := BindAndListen(name, packet) - if err != nil { - t.Fatalf("Error binding, got %v expected nil", err) - } - defer ss.Close() - - // Accept a client. - acceptSocket := make(chan *Socket) - acceptErr := make(chan error) - go func() { - server, err := ss.Accept() - if err != nil { - acceptErr <- err - } - acceptSocket <- server - }() - - // Connect the client. - client, err := Connect(name, packet) - if err != nil { - t.Fatalf("Error connecting, got %v expected nil", err) - } - - // Grab the server handle. - select { - case server := <-acceptSocket: - return server, client - case err := <-acceptErr: - t.Fatalf("Accept error: %v", err) - } - panic("unreachable") -} - -func TestSendRecv(t *testing.T) { - server, client := socketPair(t, false) - defer server.Close() - defer client.Close() - - // Write on the client. - w := client.Writer(true) - if n, err := w.WriteVec([][]byte{{'a'}}); n != 1 || err != nil { - t.Fatalf("For client write, got n=%d err=%v, expected n=1 err=nil", n, err) - } - - // Read on the server. - b := [][]byte{{'b'}} - r := server.Reader(true) - if n, err := r.ReadVec(b); n != 1 || err != nil { - t.Fatalf("For server read, got n=%d err=%v, expected n=1 err=nil", n, err) - } - if b[0][0] != 'a' { - t.Fatalf("Got bad read data, got %c, expected a", b[0][0]) - } -} - -// TestSymmetric exists to assert that the two sockets received from socketPair -// are interchangeable. They should be, this just provides a basic sanity check -// by running TestSendRecv "backwards". -func TestSymmetric(t *testing.T) { - server, client := socketPair(t, false) - defer server.Close() - defer client.Close() - - // Write on the server. - w := server.Writer(true) - if n, err := w.WriteVec([][]byte{{'a'}}); n != 1 || err != nil { - t.Fatalf("For server write, got n=%d err=%v, expected n=1 err=nil", n, err) - } - - // Read on the client. - b := [][]byte{{'b'}} - r := client.Reader(true) - if n, err := r.ReadVec(b); n != 1 || err != nil { - t.Fatalf("For client read, got n=%d err=%v, expected n=1 err=nil", n, err) - } - if b[0][0] != 'a' { - t.Fatalf("Got bad read data, got %c, expected a", b[0][0]) - } -} - -func TestPacket(t *testing.T) { - server, client := socketPair(t, true) - defer server.Close() - defer client.Close() - - // Write on the client. - w := client.Writer(true) - if n, err := w.WriteVec([][]byte{{'a'}}); n != 1 || err != nil { - t.Fatalf("For client write, got n=%d err=%v, expected n=1 err=nil", n, err) - } - - // Write on the client again. - w = client.Writer(true) - if n, err := w.WriteVec([][]byte{{'a'}}); n != 1 || err != nil { - t.Fatalf("For client write, got n=%d err=%v, expected n=1 err=nil", n, err) - } - - // Read on the server. - // - // This should only get back a single byte, despite the buffer - // being size two. This is because it's a _packet_ buffer. - b := [][]byte{{'b', 'b'}} - r := server.Reader(true) - if n, err := r.ReadVec(b); n != 1 || err != nil { - t.Fatalf("For server read, got n=%d err=%v, expected n=1 err=nil", n, err) - } - if b[0][0] != 'a' { - t.Fatalf("Got bad read data, got %c, expected a", b[0][0]) - } - - // Do it again. - r = server.Reader(true) - if n, err := r.ReadVec(b); n != 1 || err != nil { - t.Fatalf("For server read, got n=%d err=%v, expected n=1 err=nil", n, err) - } - if b[0][0] != 'a' { - t.Fatalf("Got bad read data, got %c, expected a", b[0][0]) - } -} - -func TestClose(t *testing.T) { - server, client := socketPair(t, false) - defer server.Close() - - // Make sure the first close succeeds. - if err := client.Close(); err != nil { - t.Fatalf("First close failed, got err %v expected nil", err) - } - - // The second one should fail. - if err := client.Close(); err == nil { - t.Fatalf("Second close succeeded, expected non-nil err") - } -} - -func TestNonBlockingSend(t *testing.T) { - server, client := socketPair(t, false) - defer server.Close() - defer client.Close() - - // Try up to 1000 writes, of 1000 bytes. - blockCount := 0 - for i := 0; i < 1000; i++ { - w := client.Writer(false) - if n, err := w.WriteVec([][]byte{make([]byte, 1000)}); n != 1000 || err != nil { - if err == unix.EWOULDBLOCK || err == unix.EAGAIN { - // We're good. That's what we wanted. - blockCount++ - } else { - t.Fatalf("For client write, got n=%d err=%v, expected n=1000 err=nil", n, err) - } - } - } - - if blockCount == 1000 { - // Shouldn't have _always_ blocked. - t.Fatalf("Socket always blocked!") - } else if blockCount == 0 { - // Should have started blocking eventually. - t.Fatalf("Socket never blocked!") - } -} - -func TestNonBlockingRecv(t *testing.T) { - server, client := socketPair(t, false) - defer server.Close() - defer client.Close() - - b := [][]byte{{'b'}} - r := client.Reader(false) - - // Expected to block immediately. - _, err := r.ReadVec(b) - if err != unix.EWOULDBLOCK && err != unix.EAGAIN { - t.Fatalf("Read didn't block, got err %v expected blocking err", err) - } - - // Put some data in the pipe. - w := server.Writer(false) - if n, err := w.WriteVec(b); n != 1 || err != nil { - t.Fatalf("Write failed with n=%d err=%v, expected n=1 err=nil", n, err) - } - - // Expect it not to block. - if n, err := r.ReadVec(b); n != 1 || err != nil { - t.Fatalf("Read failed with n=%d err=%v, expected n=1 err=nil", n, err) - } - - // Expect it to return a block error again. - r = client.Reader(false) - _, err = r.ReadVec(b) - if err != unix.EWOULDBLOCK && err != unix.EAGAIN { - t.Fatalf("Read didn't block, got err %v expected blocking err", err) - } -} - -func TestRecvVectors(t *testing.T) { - server, client := socketPair(t, false) - defer server.Close() - defer client.Close() - - // Write on the client. - w := client.Writer(true) - if n, err := w.WriteVec([][]byte{{'a', 'b'}}); n != 2 || err != nil { - t.Fatalf("For client write, got n=%d err=%v, expected n=2 err=nil", n, err) - } - - // Read on the server. - b := [][]byte{{'c'}, {'c'}} - r := server.Reader(true) - if n, err := r.ReadVec(b); n != 2 || err != nil { - t.Fatalf("For server read, got n=%d err=%v, expected n=2 err=nil", n, err) - } - if b[0][0] != 'a' || b[1][0] != 'b' { - t.Fatalf("Got bad read data, got %c,%c, expected a,b", b[0][0], b[1][0]) - } -} - -func TestSendVectors(t *testing.T) { - server, client := socketPair(t, false) - defer server.Close() - defer client.Close() - - // Write on the client. - w := client.Writer(true) - if n, err := w.WriteVec([][]byte{{'a'}, {'b'}}); n != 2 || err != nil { - t.Fatalf("For client write, got n=%d err=%v, expected n=2 err=nil", n, err) - } - - // Read on the server. - b := [][]byte{{'c', 'c'}} - r := server.Reader(true) - if n, err := r.ReadVec(b); n != 2 || err != nil { - t.Fatalf("For server read, got n=%d err=%v, expected n=2 err=nil", n, err) - } - if b[0][0] != 'a' || b[0][1] != 'b' { - t.Fatalf("Got bad read data, got %c,%c, expected a,b", b[0][0], b[0][1]) - } -} - -func TestSendFDsNotEnabled(t *testing.T) { - server, client := socketPair(t, false) - defer server.Close() - defer client.Close() - - // Write on the server. - w := server.Writer(true) - w.PackFDs(0, 1, 2) - if n, err := w.WriteVec([][]byte{{'a'}}); n != 1 || err != nil { - t.Fatalf("For server write, got n=%d err=%v, expected n=1 err=nil", n, err) - } - - // Read on the client, without enabling FDs. - b := [][]byte{{'b'}} - r := client.Reader(true) - if n, err := r.ReadVec(b); n != 1 || err != nil { - t.Fatalf("For client read, got n=%d err=%v, expected n=1 err=nil", n, err) - } - if b[0][0] != 'a' { - t.Fatalf("Got bad read data, got %c, expected a", b[0][0]) - } - - // Make sure the FDs are not received. - fds, err := r.ExtractFDs() - if len(fds) != 0 || err != nil { - t.Fatalf("Got fds=%v err=%v, expected len(fds)=0 err=nil", fds, err) - } -} - -func sendFDs(t *testing.T, s *Socket, fds []int) { - w := s.Writer(true) - w.PackFDs(fds...) - if n, err := w.WriteVec([][]byte{{'a'}}); n != 1 || err != nil { - t.Fatalf("For write, got n=%d err=%v, expected n=1 err=nil", n, err) - } -} - -func recvFDs(t *testing.T, s *Socket, enableSize int, origFDs []int) { - expected := len(origFDs) - - // Count the number of FDs. - preEntries, err := ioutil.ReadDir("/proc/self/fd") - if err != nil { - t.Fatalf("Can't readdir, got err %v expected nil", err) - } - - // Read on the client. - b := [][]byte{{'b'}} - r := s.Reader(true) - if enableSize >= 0 { - r.EnableFDs(enableSize) - } - if n, err := r.ReadVec(b); n != 1 || err != nil { - t.Fatalf("For client read, got n=%d err=%v, expected n=1 err=nil", n, err) - } - if b[0][0] != 'a' { - t.Fatalf("Got bad read data, got %c, expected a", b[0][0]) - } - - // Count the new number of FDs. - postEntries, err := ioutil.ReadDir("/proc/self/fd") - if err != nil { - t.Fatalf("Can't readdir, got err %v expected nil", err) - } - if len(preEntries)+expected != len(postEntries) { - t.Errorf("Process fd count isn't right, expected %d got %d", len(preEntries)+expected, len(postEntries)) - } - - // Make sure the FDs are there. - fds, err := r.ExtractFDs() - if len(fds) != expected || err != nil { - t.Fatalf("Got fds=%v err=%v, expected len(fds)=%d err=nil", fds, err, expected) - } - - // Make sure they are different from the originals. - for i := 0; i < len(fds); i++ { - if fds[i] == origFDs[i] { - t.Errorf("Got original fd for index %d, expected different", i) - } - } - - // Make sure they can be accessed as expected. - for i := 0; i < len(fds); i++ { - var st unix.Stat_t - if err := unix.Fstat(fds[i], &st); err != nil { - t.Errorf("fds[%d] can't be stated, got err %v expected nil", i, err) - } - } - - // Close them off. - r.CloseFDs() - - // Make sure the count is back to normal. - finalEntries, err := ioutil.ReadDir("/proc/self/fd") - if err != nil { - t.Fatalf("Can't readdir, got err %v expected nil", err) - } - if len(finalEntries) != len(preEntries) { - t.Errorf("Process fd count isn't right, expected %d got %d", len(preEntries), len(finalEntries)) - } -} - -func TestFDsSingle(t *testing.T) { - server, client := socketPair(t, false) - defer server.Close() - defer client.Close() - - sendFDs(t, server, []int{0}) - recvFDs(t, client, 1, []int{0}) -} - -func TestFDsMultiple(t *testing.T) { - server, client := socketPair(t, false) - defer server.Close() - defer client.Close() - - // Basic case, multiple FDs. - sendFDs(t, server, []int{0, 1, 2}) - recvFDs(t, client, 3, []int{0, 1, 2}) -} - -// See TestSymmetric above. -func TestFDsSymmetric(t *testing.T) { - server, client := socketPair(t, false) - defer server.Close() - defer client.Close() - - sendFDs(t, server, []int{0, 1, 2}) - recvFDs(t, client, 3, []int{0, 1, 2}) -} - -func TestFDsReceiveLargeBuffer(t *testing.T) { - server, client := socketPair(t, false) - defer server.Close() - defer client.Close() - - sendFDs(t, server, []int{0}) - recvFDs(t, client, 3, []int{0}) -} - -func TestFDsReceiveSmallBuffer(t *testing.T) { - server, client := socketPair(t, false) - defer server.Close() - defer client.Close() - - sendFDs(t, server, []int{0, 1, 2}) - - // Per the spec, we may still receive more than the buffer. In fact, - // it'll be rounded up and we can receive two with a size one buffer. - recvFDs(t, client, 1, []int{0, 1}) -} - -func TestFDsReceiveNotEnabled(t *testing.T) { - server, client := socketPair(t, false) - defer server.Close() - defer client.Close() - - sendFDs(t, server, []int{0}) - recvFDs(t, client, -1, []int{}) -} - -func TestFDsReceiveSizeZero(t *testing.T) { - server, client := socketPair(t, false) - defer server.Close() - defer client.Close() - - sendFDs(t, server, []int{0}) - recvFDs(t, client, 0, []int{}) -} - -func TestGetPeerCred(t *testing.T) { - server, client := socketPair(t, false) - defer server.Close() - defer client.Close() - - want := &unix.Ucred{ - Pid: int32(os.Getpid()), - Uid: uint32(os.Getuid()), - Gid: uint32(os.Getgid()), - } - - if got, err := client.GetPeerCred(); err != nil || !reflect.DeepEqual(got, want) { - t.Errorf("GetPeerCred() = %v, %v, want = %+v, %+v", got, err, want, nil) - } -} - -func newClosedSocket() (*Socket, error) { - fd, err := unix.Socket(unix.AF_UNIX, unix.SOCK_STREAM, 0) - if err != nil { - return nil, err - } - - s, err := NewSocket(fd) - if err != nil { - unix.Close(fd) - return nil, err - } - - return s, s.Close() -} - -func TestGetPeerCredFailure(t *testing.T) { - s, err := newClosedSocket() - if err != nil { - t.Fatalf("newClosedSocket got error %v want nil", err) - } - - want := "bad file descriptor" - if _, err := s.GetPeerCred(); err == nil || err.Error() != want { - t.Errorf("s.GetPeerCred() = %v, want = %s", err, want) - } -} - -func TestAcceptClosed(t *testing.T) { - name, err := randomFilename() - if err != nil { - t.Fatalf("Unable to generate file, got err %v expected nil", err) - } - - ss, err := BindAndListen(name, false) - if err != nil { - t.Fatalf("Bind failed, got err %v expected nil", err) - } - - if err := ss.Close(); err != nil { - t.Fatalf("Close failed, got err %v expected nil", err) - } - - if _, err := ss.Accept(); err == nil { - t.Errorf("Accept on closed SocketServer, got err %v, want != nil", err) - } -} - -func TestCloseAfterAcceptStart(t *testing.T) { - name, err := randomFilename() - if err != nil { - t.Fatalf("Unable to generate file, got err %v expected nil", err) - } - - ss, err := BindAndListen(name, false) - if err != nil { - t.Fatalf("Bind failed, got err %v expected nil", err) - } - - wg := sync.WaitGroup{} - wg.Add(1) - go func() { - defer wg.Done() - time.Sleep(50 * time.Millisecond) - if err := ss.Close(); err != nil { - t.Errorf("Close failed, got err %v expected nil", err) - } - }() - - if _, err := ss.Accept(); err == nil { - t.Errorf("Accept on closed SocketServer, got err %v, want != nil", err) - } - - wg.Wait() -} - -func TestReleaseAfterAcceptStart(t *testing.T) { - name, err := randomFilename() - if err != nil { - t.Fatalf("Unable to generate file, got err %v expected nil", err) - } - - ss, err := BindAndListen(name, false) - if err != nil { - t.Fatalf("Bind failed, got err %v expected nil", err) - } - - wg := sync.WaitGroup{} - wg.Add(1) - go func() { - defer wg.Done() - time.Sleep(50 * time.Millisecond) - fd, err := ss.Release() - if err != nil { - t.Errorf("Release failed, got err %v expected nil", err) - } - unix.Close(fd) - }() - - if _, err := ss.Accept(); err == nil { - t.Errorf("Accept on closed SocketServer, got err %v, want != nil", err) - } - - wg.Wait() -} - -func TestControlMessage(t *testing.T) { - for i := 0; i <= 10; i++ { - var want []int - for j := 0; j < i; j++ { - want = append(want, i+j+1) - } - - var cm ControlMessage - cm.EnableFDs(i) - cm.PackFDs(want...) - got, err := cm.ExtractFDs() - if err != nil || !reflect.DeepEqual(got, want) { - t.Errorf("cm.ExtractFDs() = %v, %v, want = %v, %v", got, err, want, nil) - } - } -} - -func benchmarkSendRecv(b *testing.B, packet bool) { - server, client, err := SocketPair(packet) - if err != nil { - b.Fatalf("SocketPair: got %v, wanted nil", err) - } - defer server.Close() - defer client.Close() - go func() { - buf := make([]byte, 1) - for i := 0; i < b.N; i++ { - n, err := server.Read(buf) - if n != 1 || err != nil { - b.Errorf("server.Read: got (%d, %v), wanted (1, nil)", n, err) - return - } - n, err = server.Write(buf) - if n != 1 || err != nil { - b.Errorf("server.Write: got (%d, %v), wanted (1, nil)", n, err) - return - } - } - }() - buf := make([]byte, 1) - b.ResetTimer() - for i := 0; i < b.N; i++ { - n, err := client.Write(buf) - if n != 1 || err != nil { - b.Fatalf("client.Write: got (%d, %v), wanted (1, nil)", n, err) - } - n, err = client.Read(buf) - if n != 1 || err != nil { - b.Fatalf("client.Read: got (%d, %v), wanted (1, nil)", n, err) - } - } -} - -func BenchmarkSendRecvStream(b *testing.B) { - benchmarkSendRecv(b, false) -} - -func BenchmarkSendRecvPacket(b *testing.B) { - benchmarkSendRecv(b, true) -} diff --git a/pkg/unet/unet_unsafe_state_autogen.go b/pkg/unet/unet_unsafe_state_autogen.go new file mode 100644 index 000000000..9bbf31d35 --- /dev/null +++ b/pkg/unet/unet_unsafe_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package unet diff --git a/pkg/urpc/BUILD b/pkg/urpc/BUILD deleted file mode 100644 index 850c34ed0..000000000 --- a/pkg/urpc/BUILD +++ /dev/null @@ -1,23 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "urpc", - srcs = ["urpc.go"], - visibility = ["//:sandbox"], - deps = [ - "//pkg/fd", - "//pkg/log", - "//pkg/sync", - "//pkg/unet", - ], -) - -go_test( - name = "urpc_test", - size = "small", - srcs = ["urpc_test.go"], - library = ":urpc", - deps = ["//pkg/unet"], -) diff --git a/pkg/urpc/urpc_state_autogen.go b/pkg/urpc/urpc_state_autogen.go new file mode 100644 index 000000000..5fdca6717 --- /dev/null +++ b/pkg/urpc/urpc_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package urpc diff --git a/pkg/urpc/urpc_test.go b/pkg/urpc/urpc_test.go deleted file mode 100644 index c6c7ce9d4..000000000 --- a/pkg/urpc/urpc_test.go +++ /dev/null @@ -1,210 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package urpc - -import ( - "errors" - "os" - "testing" - - "gvisor.dev/gvisor/pkg/unet" -) - -type test struct { -} - -type testArg struct { - StringArg string - IntArg int - FilePayload -} - -type testResult struct { - StringResult string - IntResult int - FilePayload -} - -func (t test) Func(a *testArg, r *testResult) error { - r.StringResult = a.StringArg - r.IntResult = a.IntArg - return nil -} - -func (t test) Err(a *testArg, r *testResult) error { - return errors.New("test error") -} - -func (t test) FailNoFile(a *testArg, r *testResult) error { - if a.Files == nil { - return errors.New("no file found") - } - - return nil -} - -func (t test) SendFile(a *testArg, r *testResult) error { - r.Files = []*os.File{os.Stdin, os.Stdout, os.Stderr} - return nil -} - -func (t test) TooManyFiles(a *testArg, r *testResult) error { - for i := 0; i <= maxFiles; i++ { - r.Files = append(r.Files, os.Stdin) - } - return nil -} - -func startServer(socket *unet.Socket) { - s := NewServer() - s.Register(test{}) - s.StartHandling(socket) -} - -func testClient() (*Client, error) { - serverSock, clientSock, err := unet.SocketPair(false) - if err != nil { - return nil, err - } - startServer(serverSock) - - return NewClient(clientSock), nil -} - -func TestCall(t *testing.T) { - c, err := testClient() - if err != nil { - t.Fatalf("error creating test client: %v", err) - } - defer c.Close() - - var r testResult - if err := c.Call("test.Func", &testArg{}, &r); err != nil { - t.Errorf("basic call failed: %v", err) - } else if r.StringResult != "" || r.IntResult != 0 { - t.Errorf("unexpected result, got %v expected zero value", r) - } - if err := c.Call("test.Func", &testArg{StringArg: "hello"}, &r); err != nil { - t.Errorf("basic call failed: %v", err) - } else if r.StringResult != "hello" { - t.Errorf("unexpected result, got %v expected hello", r.StringResult) - } - if err := c.Call("test.Func", &testArg{IntArg: 1}, &r); err != nil { - t.Errorf("basic call failed: %v", err) - } else if r.IntResult != 1 { - t.Errorf("unexpected result, got %v expected 1", r.IntResult) - } -} - -func TestUnknownMethod(t *testing.T) { - c, err := testClient() - if err != nil { - t.Fatalf("error creating test client: %v", err) - } - defer c.Close() - - var r testResult - if err := c.Call("test.Unknown", &testArg{}, &r); err == nil { - t.Errorf("expected non-nil err, got nil") - } else if err.Error() != ErrUnknownMethod.Error() { - t.Errorf("expected test error, got %v", err) - } -} - -func TestErr(t *testing.T) { - c, err := testClient() - if err != nil { - t.Fatalf("error creating test client: %v", err) - } - defer c.Close() - - var r testResult - if err := c.Call("test.Err", &testArg{}, &r); err == nil { - t.Errorf("expected non-nil err, got nil") - } else if err.Error() != "test error" { - t.Errorf("expected test error, got %v", err) - } -} - -func TestSendFile(t *testing.T) { - c, err := testClient() - if err != nil { - t.Fatalf("error creating test client: %v", err) - } - defer c.Close() - - var r testResult - if err := c.Call("test.FailNoFile", &testArg{}, &r); err == nil { - t.Errorf("expected non-nil err, got nil") - } - if err := c.Call("test.FailNoFile", &testArg{FilePayload: FilePayload{Files: []*os.File{os.Stdin, os.Stdout, os.Stdin}}}, &r); err != nil { - t.Errorf("expected nil err, got %v", err) - } -} - -func TestRecvFile(t *testing.T) { - c, err := testClient() - if err != nil { - t.Fatalf("error creating test client: %v", err) - } - defer c.Close() - - var r testResult - if err := c.Call("test.SendFile", &testArg{}, &r); err != nil { - t.Errorf("expected nil err, got %v", err) - } - if r.Files == nil { - t.Errorf("expected file, got nil") - } -} - -func TestShutdown(t *testing.T) { - serverSock, clientSock, err := unet.SocketPair(false) - if err != nil { - t.Fatalf("error creating test client: %v", err) - } - clientSock.Close() - - s := NewServer() - if err := s.Handle(serverSock); err == nil { - t.Errorf("expected non-nil err, got nil") - } -} - -func TestTooManyFiles(t *testing.T) { - c, err := testClient() - if err != nil { - t.Fatalf("error creating test client: %v", err) - } - defer c.Close() - - var r testResult - var a testArg - for i := 0; i <= maxFiles; i++ { - a.Files = append(a.Files, os.Stdin) - } - - // Client-side error. - if err := c.Call("test.Func", &a, &r); err != ErrTooManyFiles { - t.Errorf("expected ErrTooManyFiles, got %v", err) - } - - // Server-side error. - if err := c.Call("test.TooManyFiles", &testArg{}, &r); err == nil { - t.Errorf("expected non-nil err, got nil") - } else if err.Error() != "too many files" { - t.Errorf("expected too many files, got %v", err.Error()) - } -} diff --git a/pkg/usermem/BUILD b/pkg/usermem/BUILD deleted file mode 100644 index 9c37a9626..000000000 --- a/pkg/usermem/BUILD +++ /dev/null @@ -1,37 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "usermem", - srcs = [ - "bytes_io.go", - "bytes_io_unsafe.go", - "marshal.go", - "usermem.go", - ], - visibility = ["//:sandbox"], - deps = [ - "//pkg/atomicbitops", - "//pkg/context", - "//pkg/errors/linuxerr", - "//pkg/gohacks", - "//pkg/hostarch", - "//pkg/safemem", - ], -) - -go_test( - name = "usermem_test", - size = "small", - srcs = [ - "usermem_test.go", - ], - library = ":usermem", - deps = [ - "//pkg/context", - "//pkg/errors/linuxerr", - "//pkg/hostarch", - "//pkg/safemem", - ], -) diff --git a/pkg/usermem/README.md b/pkg/usermem/README.md deleted file mode 100644 index f6d2137eb..000000000 --- a/pkg/usermem/README.md +++ /dev/null @@ -1,31 +0,0 @@ -This package defines primitives for sentry access to application memory. - -Major types: - -- The `IO` interface represents a virtual address space and provides I/O - methods on that address space. `IO` is the lowest-level primitive. The - primary implementation of the `IO` interface is `mm.MemoryManager`. - -- `IOSequence` represents a collection of individually-contiguous address - ranges in a `IO` that is operated on sequentially, analogous to Linux's - `struct iov_iter`. - -Major usage patterns: - -- Access to a task's virtual memory, subject to the application's memory - protections and while running on that task's goroutine, from a context that - is at or above the level of the `kernel` package (e.g. most syscall - implementations in `syscalls/linux`); use the `kernel.Task.Copy*` wrappers - defined in `kernel/task_usermem.go`. - -- Access to a task's virtual memory, from a context that is at or above the - level of the `kernel` package, but where any of the above constraints does - not hold (e.g. `PTRACE_POKEDATA`, which ignores application memory - protections); obtain the task's `mm.MemoryManager` by calling - `kernel.Task.MemoryManager`, and call its `IO` methods directly. - -- Access to a task's virtual memory, from a context that is below the level of - the `kernel` package (e.g. filesystem I/O); clients must pass I/O arguments - from higher layers, usually in the form of an `IOSequence`. The - `kernel.Task.SingleIOSequence` and `kernel.Task.IovecsIOSequence` functions - in `kernel/task_usermem.go` are convenience functions for doing so. diff --git a/pkg/usermem/usermem_state_autogen.go b/pkg/usermem/usermem_state_autogen.go new file mode 100644 index 000000000..62f8af4c9 --- /dev/null +++ b/pkg/usermem/usermem_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package usermem diff --git a/pkg/usermem/usermem_test.go b/pkg/usermem/usermem_test.go deleted file mode 100644 index a5e2fe69e..000000000 --- a/pkg/usermem/usermem_test.go +++ /dev/null @@ -1,407 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package usermem - -import ( - "bytes" - "fmt" - "reflect" - "strings" - "testing" - - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/errors/linuxerr" - "gvisor.dev/gvisor/pkg/hostarch" - "gvisor.dev/gvisor/pkg/safemem" -) - -// newContext returns a context.Context that we can use in these tests (we -// can't use contexttest because it depends on usermem). -func newContext() context.Context { - return context.Background() -} - -func newBytesIOString(s string) *BytesIO { - return &BytesIO{[]byte(s)} -} - -func TestBytesIOCopyOutSuccess(t *testing.T) { - b := newBytesIOString("ABCDE") - n, err := b.CopyOut(newContext(), 1, []byte("foo"), IOOpts{}) - if wantN := 3; n != wantN || err != nil { - t.Errorf("CopyOut: got (%v, %v), wanted (%v, nil)", n, err, wantN) - } - if got, want := b.Bytes, []byte("AfooE"); !bytes.Equal(got, want) { - t.Errorf("Bytes: got %q, wanted %q", got, want) - } -} - -func TestBytesIOCopyOutFailure(t *testing.T) { - b := newBytesIOString("ABC") - n, err := b.CopyOut(newContext(), 1, []byte("foo"), IOOpts{}) - if wantN, wantErr := 2, linuxerr.EFAULT; n != wantN || err != wantErr { - t.Errorf("CopyOut: got (%v, %v), wanted (%v, %v)", n, err, wantN, wantErr) - } - if got, want := b.Bytes, []byte("Afo"); !bytes.Equal(got, want) { - t.Errorf("Bytes: got %q, wanted %q", got, want) - } -} - -func TestBytesIOCopyInSuccess(t *testing.T) { - b := newBytesIOString("AfooE") - var dst [3]byte - n, err := b.CopyIn(newContext(), 1, dst[:], IOOpts{}) - if wantN := 3; n != wantN || err != nil { - t.Errorf("CopyIn: got (%v, %v), wanted (%v, nil)", n, err, wantN) - } - if got, want := dst[:], []byte("foo"); !bytes.Equal(got, want) { - t.Errorf("dst: got %q, wanted %q", got, want) - } -} - -func TestBytesIOCopyInFailure(t *testing.T) { - b := newBytesIOString("Afo") - var dst [3]byte - n, err := b.CopyIn(newContext(), 1, dst[:], IOOpts{}) - if wantN, wantErr := 2, linuxerr.EFAULT; n != wantN || err != wantErr { - t.Errorf("CopyIn: got (%v, %v), wanted (%v, %v)", n, err, wantN, wantErr) - } - if got, want := dst[:], []byte("fo\x00"); !bytes.Equal(got, want) { - t.Errorf("dst: got %q, wanted %q", got, want) - } -} - -func TestBytesIOZeroOutSuccess(t *testing.T) { - b := newBytesIOString("ABCD") - n, err := b.ZeroOut(newContext(), 1, 2, IOOpts{}) - if wantN := int64(2); n != wantN || err != nil { - t.Errorf("ZeroOut: got (%v, %v), wanted (%v, nil)", n, err, wantN) - } - if got, want := b.Bytes, []byte("A\x00\x00D"); !bytes.Equal(got, want) { - t.Errorf("Bytes: got %q, wanted %q", got, want) - } -} - -func TestBytesIOZeroOutFailure(t *testing.T) { - b := newBytesIOString("ABC") - n, err := b.ZeroOut(newContext(), 1, 3, IOOpts{}) - if wantN, wantErr := int64(2), linuxerr.EFAULT; n != wantN || err != wantErr { - t.Errorf("ZeroOut: got (%v, %v), wanted (%v, %v)", n, err, wantN, wantErr) - } - if got, want := b.Bytes, []byte("A\x00\x00"); !bytes.Equal(got, want) { - t.Errorf("Bytes: got %q, wanted %q", got, want) - } -} - -func TestBytesIOCopyOutFromSuccess(t *testing.T) { - b := newBytesIOString("ABCDEFGH") - n, err := b.CopyOutFrom(newContext(), hostarch.AddrRangeSeqFromSlice([]hostarch.AddrRange{ - {Start: 4, End: 7}, - {Start: 1, End: 4}, - }), safemem.FromIOReader{bytes.NewBufferString("barfoo")}, IOOpts{}) - if wantN := int64(6); n != wantN || err != nil { - t.Errorf("CopyOutFrom: got (%v, %v), wanted (%v, nil)", n, err, wantN) - } - if got, want := b.Bytes, []byte("AfoobarH"); !bytes.Equal(got, want) { - t.Errorf("Bytes: got %q, wanted %q", got, want) - } -} - -func TestBytesIOCopyOutFromFailure(t *testing.T) { - b := newBytesIOString("ABCDE") - n, err := b.CopyOutFrom(newContext(), hostarch.AddrRangeSeqFromSlice([]hostarch.AddrRange{ - {Start: 1, End: 4}, - {Start: 4, End: 7}, - }), safemem.FromIOReader{bytes.NewBufferString("foobar")}, IOOpts{}) - if wantN, wantErr := int64(4), linuxerr.EFAULT; n != wantN || err != wantErr { - t.Errorf("CopyOutFrom: got (%v, %v), wanted (%v, %v)", n, err, wantN, wantErr) - } - if got, want := b.Bytes, []byte("Afoob"); !bytes.Equal(got, want) { - t.Errorf("Bytes: got %q, wanted %q", got, want) - } -} - -func TestBytesIOCopyInToSuccess(t *testing.T) { - b := newBytesIOString("AfoobarH") - var dst bytes.Buffer - n, err := b.CopyInTo(newContext(), hostarch.AddrRangeSeqFromSlice([]hostarch.AddrRange{ - {Start: 4, End: 7}, - {Start: 1, End: 4}, - }), safemem.FromIOWriter{&dst}, IOOpts{}) - if wantN := int64(6); n != wantN || err != nil { - t.Errorf("CopyInTo: got (%v, %v), wanted (%v, nil)", n, err, wantN) - } - if got, want := dst.Bytes(), []byte("barfoo"); !bytes.Equal(got, want) { - t.Errorf("dst.Bytes(): got %q, wanted %q", got, want) - } -} - -func TestBytesIOCopyInToFailure(t *testing.T) { - b := newBytesIOString("Afoob") - var dst bytes.Buffer - n, err := b.CopyInTo(newContext(), hostarch.AddrRangeSeqFromSlice([]hostarch.AddrRange{ - {Start: 1, End: 4}, - {Start: 4, End: 7}, - }), safemem.FromIOWriter{&dst}, IOOpts{}) - if wantN, wantErr := int64(4), linuxerr.EFAULT; n != wantN || err != wantErr { - t.Errorf("CopyOutFrom: got (%v, %v), wanted (%v, %v)", n, err, wantN, wantErr) - } - if got, want := dst.Bytes(), []byte("foob"); !bytes.Equal(got, want) { - t.Errorf("dst.Bytes(): got %q, wanted %q", got, want) - } -} - -type testStruct struct { - Int8 int8 - Uint8 uint8 - Int16 int16 - Uint16 uint16 - Int32 int32 - Uint32 uint32 - Int64 int64 - Uint64 uint64 -} - -func TestCopyStringInShort(t *testing.T) { - // Tests for string length <= copyStringIncrement. - want := strings.Repeat("A", copyStringIncrement-2) - mem := want + "\x00" - if got, err := CopyStringIn(newContext(), newBytesIOString(mem), 0, 2*copyStringIncrement, IOOpts{}); got != want || err != nil { - t.Errorf("CopyStringIn: got (%q, %v), wanted (%q, nil)", got, err, want) - } -} - -func TestCopyStringInLong(t *testing.T) { - // Tests for copyStringIncrement < string length <= copyStringMaxInitBufLen - // (requiring multiple calls to IO.CopyIn()). - want := strings.Repeat("A", copyStringIncrement*3/4) + strings.Repeat("B", copyStringIncrement*3/4) - mem := want + "\x00" - if got, err := CopyStringIn(newContext(), newBytesIOString(mem), 0, 2*copyStringIncrement, IOOpts{}); got != want || err != nil { - t.Errorf("CopyStringIn: got (%q, %v), wanted (%q, nil)", got, err, want) - } -} - -func TestCopyStringInVeryLong(t *testing.T) { - // Tests for string length > copyStringMaxInitBufLen (requiring buffer - // reallocation). - want := strings.Repeat("A", copyStringMaxInitBufLen*3/4) + strings.Repeat("B", copyStringMaxInitBufLen*3/4) - mem := want + "\x00" - if got, err := CopyStringIn(newContext(), newBytesIOString(mem), 0, 2*copyStringMaxInitBufLen, IOOpts{}); got != want || err != nil { - t.Errorf("CopyStringIn: got (%q, %v), wanted (%q, nil)", got, err, want) - } -} - -func TestCopyStringInNoTerminatingZeroByte(t *testing.T) { - want := strings.Repeat("A", copyStringIncrement-1) - got, err := CopyStringIn(newContext(), newBytesIOString(want), 0, 2*copyStringIncrement, IOOpts{}) - if wantErr := linuxerr.EFAULT; got != want || err != wantErr { - t.Errorf("CopyStringIn: got (%q, %v), wanted (%q, %v)", got, err, want, wantErr) - } -} - -func TestCopyStringInTruncatedByMaxlen(t *testing.T) { - got, err := CopyStringIn(newContext(), newBytesIOString(strings.Repeat("A", 10)), 0, 5, IOOpts{}) - if want, wantErr := strings.Repeat("A", 5), linuxerr.ENAMETOOLONG; got != want || err != wantErr { - t.Errorf("CopyStringIn: got (%q, %v), wanted (%q, %v)", got, err, want, wantErr) - } -} - -func TestCopyInt32StringsInVec(t *testing.T) { - for _, test := range []struct { - str string - n int - initial []int32 - final []int32 - }{ - { - str: "100 200", - n: len("100 200"), - initial: []int32{1, 2}, - final: []int32{100, 200}, - }, - { - // Fewer values ok - str: "100", - n: len("100"), - initial: []int32{1, 2}, - final: []int32{100, 2}, - }, - { - // Extra values ok - str: "100 200 300", - n: len("100 200 "), - initial: []int32{1, 2}, - final: []int32{100, 200}, - }, - { - // Leading and trailing whitespace ok - str: " 100\t200\n", - n: len(" 100\t200\n"), - initial: []int32{1, 2}, - final: []int32{100, 200}, - }, - } { - t.Run(fmt.Sprintf("%q", test.str), func(t *testing.T) { - src := BytesIOSequence([]byte(test.str)) - dsts := append([]int32(nil), test.initial...) - if n, err := CopyInt32StringsInVec(newContext(), src.IO, src.Addrs, dsts, src.Opts); n != int64(test.n) || err != nil { - t.Errorf("CopyInt32StringsInVec: got (%d, %v), wanted (%d, nil)", n, err, test.n) - } - if !reflect.DeepEqual(dsts, test.final) { - t.Errorf("dsts: got %v, wanted %v", dsts, test.final) - } - }) - } -} - -func TestCopyInt32StringsInVecRequiresOneValidValue(t *testing.T) { - for _, s := range []string{"", "\n", "a123"} { - t.Run(fmt.Sprintf("%q", s), func(t *testing.T) { - src := BytesIOSequence([]byte(s)) - initial := []int32{1, 2} - dsts := append([]int32(nil), initial...) - if n, err := CopyInt32StringsInVec(newContext(), src.IO, src.Addrs, dsts, src.Opts); !linuxerr.Equals(linuxerr.EINVAL, err) { - t.Errorf("CopyInt32StringsInVec: got (%d, %v), wanted (_, %v)", n, err, linuxerr.EINVAL) - } - if !reflect.DeepEqual(dsts, initial) { - t.Errorf("dsts: got %v, wanted %v", dsts, initial) - } - }) - } -} - -func TestIOSequenceCopyOut(t *testing.T) { - buf := []byte("ABCD") - s := BytesIOSequence(buf) - - // CopyOut limited by len(src). - n, err := s.CopyOut(newContext(), []byte("fo")) - if wantN := 2; n != wantN || err != nil { - t.Errorf("CopyOut: got (%v, %v), wanted (%v, nil)", n, err, wantN) - } - if want := []byte("foCD"); !bytes.Equal(buf, want) { - t.Errorf("buf: got %q, wanted %q", buf, want) - } - s = s.DropFirst(2) - if got, want := s.NumBytes(), int64(2); got != want { - t.Errorf("NumBytes: got %v, wanted %v", got, want) - } - - // CopyOut limited by s.NumBytes(). - n, err = s.CopyOut(newContext(), []byte("obar")) - if wantN := 2; n != wantN || err != nil { - t.Errorf("CopyOut: got (%v, %v), wanted (%v, nil)", n, err, wantN) - } - if want := []byte("foob"); !bytes.Equal(buf, want) { - t.Errorf("buf: got %q, wanted %q", buf, want) - } - s = s.DropFirst(2) - if got, want := s.NumBytes(), int64(0); got != want { - t.Errorf("NumBytes: got %v, wanted %v", got, want) - } -} - -func TestIOSequenceCopyIn(t *testing.T) { - s := BytesIOSequence([]byte("foob")) - dst := []byte("ABCDEF") - - // CopyIn limited by len(dst). - n, err := s.CopyIn(newContext(), dst[:2]) - if wantN := 2; n != wantN || err != nil { - t.Errorf("CopyIn: got (%v, %v), wanted (%v, nil)", n, err, wantN) - } - if want := []byte("foCDEF"); !bytes.Equal(dst, want) { - t.Errorf("dst: got %q, wanted %q", dst, want) - } - s = s.DropFirst(2) - if got, want := s.NumBytes(), int64(2); got != want { - t.Errorf("NumBytes: got %v, wanted %v", got, want) - } - - // CopyIn limited by s.Remaining(). - n, err = s.CopyIn(newContext(), dst[2:]) - if wantN := 2; n != wantN || err != nil { - t.Errorf("CopyIn: got (%v, %v), wanted (%v, nil)", n, err, wantN) - } - if want := []byte("foobEF"); !bytes.Equal(dst, want) { - t.Errorf("dst: got %q, wanted %q", dst, want) - } - s = s.DropFirst(2) - if got, want := s.NumBytes(), int64(0); got != want { - t.Errorf("NumBytes: got %v, wanted %v", got, want) - } -} - -func TestIOSequenceZeroOut(t *testing.T) { - buf := []byte("ABCD") - s := BytesIOSequence(buf) - - // ZeroOut limited by toZero. - n, err := s.ZeroOut(newContext(), 2) - if wantN := int64(2); n != wantN || err != nil { - t.Errorf("ZeroOut: got (%v, %v), wanted (%v, nil)", n, err, wantN) - } - if want := []byte("\x00\x00CD"); !bytes.Equal(buf, want) { - t.Errorf("buf: got %q, wanted %q", buf, want) - } - s = s.DropFirst(2) - if got, want := s.NumBytes(), int64(2); got != want { - t.Errorf("NumBytes: got %v, wanted %v", got, want) - } - - // ZeroOut limited by s.NumBytes(). - n, err = s.ZeroOut(newContext(), 4) - if wantN := int64(2); n != wantN || err != nil { - t.Errorf("CopyOut: got (%v, %v), wanted (%v, nil)", n, err, wantN) - } - if want := []byte("\x00\x00\x00\x00"); !bytes.Equal(buf, want) { - t.Errorf("buf: got %q, wanted %q", buf, want) - } - s = s.DropFirst(2) - if got, want := s.NumBytes(), int64(0); got != want { - t.Errorf("NumBytes: got %v, wanted %v", got, want) - } -} - -func TestIOSequenceTakeFirst(t *testing.T) { - s := BytesIOSequence([]byte("foobar")) - if got, want := s.NumBytes(), int64(6); got != want { - t.Errorf("NumBytes: got %v, wanted %v", got, want) - } - - s = s.TakeFirst(3) - if got, want := s.NumBytes(), int64(3); got != want { - t.Errorf("NumBytes: got %v, wanted %v", got, want) - } - - // TakeFirst(n) where n > s.NumBytes() is a no-op. - s = s.TakeFirst(9) - if got, want := s.NumBytes(), int64(3); got != want { - t.Errorf("NumBytes: got %v, wanted %v", got, want) - } - - var dst [3]byte - n, err := s.CopyIn(newContext(), dst[:]) - if wantN := 3; n != wantN || err != nil { - t.Errorf("CopyIn: got (%v, %v), wanted (%v, nil)", n, err, wantN) - } - if got, want := dst[:], []byte("foo"); !bytes.Equal(got, want) { - t.Errorf("dst: got %q, wanted %q", got, want) - } - s = s.DropFirst(3) - if got, want := s.NumBytes(), int64(0); got != want { - t.Errorf("NumBytes: got %v, wanted %v", got, want) - } -} diff --git a/pkg/usermem/usermem_unsafe_state_autogen.go b/pkg/usermem/usermem_unsafe_state_autogen.go new file mode 100644 index 000000000..62f8af4c9 --- /dev/null +++ b/pkg/usermem/usermem_unsafe_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package usermem diff --git a/pkg/waiter/BUILD b/pkg/waiter/BUILD deleted file mode 100644 index 852480a09..000000000 --- a/pkg/waiter/BUILD +++ /dev/null @@ -1,35 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -package(licenses = ["notice"]) - -go_template_instance( - name = "waiter_list", - out = "waiter_list.go", - package = "waiter", - prefix = "waiter", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*Entry", - "Linker": "*Entry", - }, -) - -go_library( - name = "waiter", - srcs = [ - "waiter.go", - "waiter_list.go", - ], - visibility = ["//visibility:public"], - deps = ["//pkg/sync"], -) - -go_test( - name = "waiter_test", - size = "small", - srcs = [ - "waiter_test.go", - ], - library = ":waiter", -) diff --git a/pkg/waiter/waiter_list.go b/pkg/waiter/waiter_list.go new file mode 100644 index 000000000..f83aeb49d --- /dev/null +++ b/pkg/waiter/waiter_list.go @@ -0,0 +1,221 @@ +package waiter + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type waiterElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (waiterElementMapper) linkerFor(elem *Entry) *Entry { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type waiterList struct { + head *Entry + tail *Entry +} + +// Reset resets list l to the empty state. +func (l *waiterList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +// +//go:nosplit +func (l *waiterList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +// +//go:nosplit +func (l *waiterList) Front() *Entry { + return l.head +} + +// Back returns the last element of list l or nil. +// +//go:nosplit +func (l *waiterList) Back() *Entry { + return l.tail +} + +// Len returns the number of elements in the list. +// +// NOTE: This is an O(n) operation. +// +//go:nosplit +func (l *waiterList) Len() (count int) { + for e := l.Front(); e != nil; e = (waiterElementMapper{}.linkerFor(e)).Next() { + count++ + } + return count +} + +// PushFront inserts the element e at the front of list l. +// +//go:nosplit +func (l *waiterList) PushFront(e *Entry) { + linker := waiterElementMapper{}.linkerFor(e) + linker.SetNext(l.head) + linker.SetPrev(nil) + if l.head != nil { + waiterElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +// +//go:nosplit +func (l *waiterList) PushBack(e *Entry) { + linker := waiterElementMapper{}.linkerFor(e) + linker.SetNext(nil) + linker.SetPrev(l.tail) + if l.tail != nil { + waiterElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +// +//go:nosplit +func (l *waiterList) PushBackList(m *waiterList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + waiterElementMapper{}.linkerFor(l.tail).SetNext(m.head) + waiterElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +// +//go:nosplit +func (l *waiterList) InsertAfter(b, e *Entry) { + bLinker := waiterElementMapper{}.linkerFor(b) + eLinker := waiterElementMapper{}.linkerFor(e) + + a := bLinker.Next() + + eLinker.SetNext(a) + eLinker.SetPrev(b) + bLinker.SetNext(e) + + if a != nil { + waiterElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +// +//go:nosplit +func (l *waiterList) InsertBefore(a, e *Entry) { + aLinker := waiterElementMapper{}.linkerFor(a) + eLinker := waiterElementMapper{}.linkerFor(e) + + b := aLinker.Prev() + eLinker.SetNext(a) + eLinker.SetPrev(b) + aLinker.SetPrev(e) + + if b != nil { + waiterElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +// +//go:nosplit +func (l *waiterList) Remove(e *Entry) { + linker := waiterElementMapper{}.linkerFor(e) + prev := linker.Prev() + next := linker.Next() + + if prev != nil { + waiterElementMapper{}.linkerFor(prev).SetNext(next) + } else if l.head == e { + l.head = next + } + + if next != nil { + waiterElementMapper{}.linkerFor(next).SetPrev(prev) + } else if l.tail == e { + l.tail = prev + } + + linker.SetNext(nil) + linker.SetPrev(nil) +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type waiterEntry struct { + next *Entry + prev *Entry +} + +// Next returns the entry that follows e in the list. +// +//go:nosplit +func (e *waiterEntry) Next() *Entry { + return e.next +} + +// Prev returns the entry that precedes e in the list. +// +//go:nosplit +func (e *waiterEntry) Prev() *Entry { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +// +//go:nosplit +func (e *waiterEntry) SetNext(elem *Entry) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +// +//go:nosplit +func (e *waiterEntry) SetPrev(elem *Entry) { + e.prev = elem +} diff --git a/pkg/waiter/waiter_state_autogen.go b/pkg/waiter/waiter_state_autogen.go new file mode 100644 index 000000000..498b7e79e --- /dev/null +++ b/pkg/waiter/waiter_state_autogen.go @@ -0,0 +1,126 @@ +// automatically generated by stateify. + +package waiter + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (e *Entry) StateTypeName() string { + return "pkg/waiter.Entry" +} + +func (e *Entry) StateFields() []string { + return []string{ + "Callback", + "mask", + "waiterEntry", + } +} + +func (e *Entry) beforeSave() {} + +// +checklocksignore +func (e *Entry) StateSave(stateSinkObject state.Sink) { + e.beforeSave() + stateSinkObject.Save(0, &e.Callback) + stateSinkObject.Save(1, &e.mask) + stateSinkObject.Save(2, &e.waiterEntry) +} + +func (e *Entry) afterLoad() {} + +// +checklocksignore +func (e *Entry) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &e.Callback) + stateSourceObject.Load(1, &e.mask) + stateSourceObject.Load(2, &e.waiterEntry) +} + +func (q *Queue) StateTypeName() string { + return "pkg/waiter.Queue" +} + +func (q *Queue) StateFields() []string { + return []string{ + "list", + } +} + +func (q *Queue) beforeSave() {} + +// +checklocksignore +func (q *Queue) StateSave(stateSinkObject state.Sink) { + q.beforeSave() + stateSinkObject.Save(0, &q.list) +} + +func (q *Queue) afterLoad() {} + +// +checklocksignore +func (q *Queue) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &q.list) +} + +func (l *waiterList) StateTypeName() string { + return "pkg/waiter.waiterList" +} + +func (l *waiterList) StateFields() []string { + return []string{ + "head", + "tail", + } +} + +func (l *waiterList) beforeSave() {} + +// +checklocksignore +func (l *waiterList) StateSave(stateSinkObject state.Sink) { + l.beforeSave() + stateSinkObject.Save(0, &l.head) + stateSinkObject.Save(1, &l.tail) +} + +func (l *waiterList) afterLoad() {} + +// +checklocksignore +func (l *waiterList) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &l.head) + stateSourceObject.Load(1, &l.tail) +} + +func (e *waiterEntry) StateTypeName() string { + return "pkg/waiter.waiterEntry" +} + +func (e *waiterEntry) StateFields() []string { + return []string{ + "next", + "prev", + } +} + +func (e *waiterEntry) beforeSave() {} + +// +checklocksignore +func (e *waiterEntry) StateSave(stateSinkObject state.Sink) { + e.beforeSave() + stateSinkObject.Save(0, &e.next) + stateSinkObject.Save(1, &e.prev) +} + +func (e *waiterEntry) afterLoad() {} + +// +checklocksignore +func (e *waiterEntry) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &e.next) + stateSourceObject.Load(1, &e.prev) +} + +func init() { + state.Register((*Entry)(nil)) + state.Register((*Queue)(nil)) + state.Register((*waiterList)(nil)) + state.Register((*waiterEntry)(nil)) +} diff --git a/pkg/waiter/waiter_test.go b/pkg/waiter/waiter_test.go deleted file mode 100644 index 6928f28b4..000000000 --- a/pkg/waiter/waiter_test.go +++ /dev/null @@ -1,198 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package waiter - -import ( - "sync/atomic" - "testing" -) - -type callbackStub struct { - f func(e *Entry, m EventMask) -} - -// Callback implements EntryCallback.Callback. -func (c *callbackStub) Callback(e *Entry, m EventMask) { - c.f(e, m) -} - -func TestEmptyQueue(t *testing.T) { - var q Queue - - // Notify the zero-value of a queue. - q.Notify(EventIn) - - // Register then unregister a waiter, then notify the queue. - cnt := 0 - e := Entry{Callback: &callbackStub{func(*Entry, EventMask) { cnt++ }}} - q.EventRegister(&e, EventIn) - q.EventUnregister(&e) - q.Notify(EventIn) - if cnt != 0 { - t.Errorf("Callback was called when it shouldn't have been") - } -} - -func TestMask(t *testing.T) { - // Register a waiter. - var q Queue - var cnt int - e := Entry{Callback: &callbackStub{func(*Entry, EventMask) { cnt++ }}} - q.EventRegister(&e, EventIn|EventErr) - - // Notify with an overlapping mask. - cnt = 0 - q.Notify(EventIn | EventOut) - if cnt != 1 { - t.Errorf("Callback wasn't called when it should have been") - } - - // Notify with a subset mask. - cnt = 0 - q.Notify(EventIn) - if cnt != 1 { - t.Errorf("Callback wasn't called when it should have been") - } - - // Notify with a superset mask. - cnt = 0 - q.Notify(EventIn | EventErr | EventOut) - if cnt != 1 { - t.Errorf("Callback wasn't called when it should have been") - } - - // Notify with the exact same mask. - cnt = 0 - q.Notify(EventIn | EventErr) - if cnt != 1 { - t.Errorf("Callback wasn't called when it should have been") - } - - // Notify with a disjoint mask. - cnt = 0 - q.Notify(EventOut | EventHUp) - if cnt != 0 { - t.Errorf("Callback was called when it shouldn't have been") - } -} - -func TestConcurrentRegistration(t *testing.T) { - var q Queue - var cnt int - const concurrency = 1000 - - ch1 := make(chan struct{}) - ch2 := make(chan struct{}) - ch3 := make(chan struct{}) - - // Create goroutines that will all register/unregister concurrently. - for i := 0; i < concurrency; i++ { - go func() { - var e Entry - e.Callback = &callbackStub{func(entry *Entry, mask EventMask) { - cnt++ - if entry != &e { - t.Errorf("entry = %p, want %p", entry, &e) - } - if mask != EventIn { - t.Errorf("mask = %#x want %#x", mask, EventIn) - } - }} - - // Wait for notification, then register. - <-ch1 - q.EventRegister(&e, EventIn|EventErr) - - // Tell main goroutine that we're done registering. - ch2 <- struct{}{} - - // Wait for notification, then unregister. - <-ch3 - q.EventUnregister(&e) - - // Tell main goroutine that we're done unregistering. - ch2 <- struct{}{} - }() - } - - // Let the goroutines register. - close(ch1) - for i := 0; i < concurrency; i++ { - <-ch2 - } - - // Issue a notification. - q.Notify(EventIn) - if cnt != concurrency { - t.Errorf("cnt = %d, want %d", cnt, concurrency) - } - - // Let the goroutine unregister. - close(ch3) - for i := 0; i < concurrency; i++ { - <-ch2 - } - - // Issue a notification. - q.Notify(EventIn) - if cnt != concurrency { - t.Errorf("cnt = %d, want %d", cnt, concurrency) - } -} - -func TestConcurrentNotification(t *testing.T) { - var q Queue - var cnt int32 - const concurrency = 1000 - const waiterCount = 1000 - - // Register waiters. - for i := 0; i < waiterCount; i++ { - var e Entry - e.Callback = &callbackStub{func(entry *Entry, mask EventMask) { - atomic.AddInt32(&cnt, 1) - if entry != &e { - t.Errorf("entry = %p, want %p", entry, &e) - } - if mask != EventIn { - t.Errorf("mask = %#x want %#x", mask, EventIn) - } - }} - - q.EventRegister(&e, EventIn|EventErr) - } - - // Launch notifiers. - ch1 := make(chan struct{}) - ch2 := make(chan struct{}) - for i := 0; i < concurrency; i++ { - go func() { - <-ch1 - q.Notify(EventIn) - ch2 <- struct{}{} - }() - } - - // Let notifiers go. - close(ch1) - for i := 0; i < concurrency; i++ { - <-ch2 - } - - // Check the count. - if cnt != concurrency*waiterCount { - t.Errorf("cnt = %d, want %d", cnt, concurrency*waiterCount) - } -} |