summaryrefslogtreecommitdiffhomepage
path: root/pkg
diff options
context:
space:
mode:
Diffstat (limited to 'pkg')
-rw-r--r--pkg/abi/linux/BUILD3
-rw-r--r--pkg/abi/linux/futex.go18
-rw-r--r--pkg/abi/linux/netfilter.go146
-rw-r--r--pkg/abi/linux/socket.go17
-rw-r--r--pkg/sentry/arch/arch_aarch64.go23
-rw-r--r--pkg/sentry/arch/arch_amd64.go4
-rw-r--r--pkg/sentry/arch/arch_arm64.go4
-rw-r--r--pkg/sentry/arch/arch_x86.go16
-rw-r--r--pkg/sentry/fsimpl/devpts/slave.go2
-rw-r--r--pkg/sentry/fsimpl/fuse/BUILD4
-rw-r--r--pkg/sentry/fsimpl/fuse/dev.go28
-rw-r--r--pkg/sentry/fsimpl/fuse/fusefs.go200
-rw-r--r--pkg/sentry/fsimpl/gofer/filesystem.go19
-rw-r--r--pkg/sentry/fsimpl/gofer/gofer.go29
-rw-r--r--pkg/sentry/fsimpl/gofer/regular_file.go41
-rw-r--r--pkg/sentry/fsimpl/gofer/special_file.go56
-rw-r--r--pkg/sentry/fsimpl/host/host.go7
-rw-r--r--pkg/sentry/fsimpl/kernfs/fd_impl_util.go6
-rw-r--r--pkg/sentry/fsimpl/kernfs/inode_impl_util.go2
-rw-r--r--pkg/sentry/fsimpl/overlay/filesystem.go2
-rw-r--r--pkg/sentry/fsimpl/overlay/non_directory.go4
-rw-r--r--pkg/sentry/fsimpl/proc/subtasks.go2
-rw-r--r--pkg/sentry/fsimpl/tmpfs/filesystem.go2
-rw-r--r--pkg/sentry/fsimpl/tmpfs/regular_file.go35
-rw-r--r--pkg/sentry/fsimpl/tmpfs/tmpfs.go7
-rw-r--r--pkg/sentry/kernel/BUILD1
-rw-r--r--pkg/sentry/kernel/futex/futex.go8
-rw-r--r--pkg/sentry/kernel/kernel.go5
-rw-r--r--pkg/sentry/kernel/task.go19
-rw-r--r--pkg/sentry/kernel/task_exec.go3
-rw-r--r--pkg/sentry/kernel/task_exit.go3
-rw-r--r--pkg/sentry/kernel/task_futex.go125
-rw-r--r--pkg/sentry/kernel/task_run.go17
-rw-r--r--pkg/sentry/kernel/task_work.go38
-rw-r--r--pkg/sentry/kernel/time/BUILD1
-rw-r--r--pkg/sentry/kernel/time/tcpip.go131
-rw-r--r--pkg/sentry/platform/ring0/kernel_arm64.go6
-rw-r--r--pkg/sentry/socket/BUILD1
-rw-r--r--pkg/sentry/socket/hostinet/BUILD2
-rw-r--r--pkg/sentry/socket/hostinet/socket.go7
-rw-r--r--pkg/sentry/socket/netfilter/netfilter.go49
-rw-r--r--pkg/sentry/socket/netlink/BUILD2
-rw-r--r--pkg/sentry/socket/netlink/socket.go14
-rw-r--r--pkg/sentry/socket/netstack/BUILD2
-rw-r--r--pkg/sentry/socket/netstack/netstack.go224
-rw-r--r--pkg/sentry/socket/netstack/netstack_vfs2.go16
-rw-r--r--pkg/sentry/socket/socket.go3
-rw-r--r--pkg/sentry/socket/unix/BUILD1
-rw-r--r--pkg/sentry/socket/unix/unix.go3
-rw-r--r--pkg/sentry/socket/unix/unix_vfs2.go3
-rw-r--r--pkg/sentry/syscalls/linux/BUILD2
-rw-r--r--pkg/sentry/syscalls/linux/linux64.go4
-rw-r--r--pkg/sentry/syscalls/linux/sys_futex.go48
-rw-r--r--pkg/sentry/syscalls/linux/sys_socket.go17
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/BUILD2
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/mount.go9
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/setstat.go5
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/socket.go17
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/splice.go321
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/vfs2.go2
-rw-r--r--pkg/sentry/vfs/options.go17
-rw-r--r--pkg/sentry/vfs/permissions.go8
-rw-r--r--pkg/tcpip/header/icmpv4.go1
-rw-r--r--pkg/tcpip/header/icmpv6.go11
-rw-r--r--pkg/tcpip/link/channel/channel.go4
-rw-r--r--pkg/tcpip/link/fdbased/endpoint.go36
-rw-r--r--pkg/tcpip/link/fdbased/endpoint_test.go8
-rw-r--r--pkg/tcpip/link/loopback/loopback.go3
-rw-r--r--pkg/tcpip/link/muxed/injectable.go4
-rw-r--r--pkg/tcpip/link/nested/nested.go15
-rw-r--r--pkg/tcpip/link/nested/nested_test.go4
-rw-r--r--pkg/tcpip/link/packetsocket/BUILD14
-rw-r--r--pkg/tcpip/link/packetsocket/endpoint.go50
-rw-r--r--pkg/tcpip/link/qdisc/fifo/endpoint.go12
-rw-r--r--pkg/tcpip/link/sharedmem/sharedmem.go21
-rw-r--r--pkg/tcpip/link/sharedmem/sharedmem_test.go4
-rw-r--r--pkg/tcpip/link/sniffer/sniffer.go5
-rw-r--r--pkg/tcpip/link/tun/device.go43
-rw-r--r--pkg/tcpip/link/waitable/waitable.go14
-rw-r--r--pkg/tcpip/link/waitable/waitable_test.go9
-rw-r--r--pkg/tcpip/network/ip_test.go5
-rw-r--r--pkg/tcpip/network/ipv4/icmp.go3
-rw-r--r--pkg/tcpip/network/ipv6/icmp.go2
-rw-r--r--pkg/tcpip/stack/conntrack.go136
-rw-r--r--pkg/tcpip/stack/forwarder_test.go5
-rw-r--r--pkg/tcpip/stack/iptables.go173
-rw-r--r--pkg/tcpip/stack/iptables_targets.go2
-rw-r--r--pkg/tcpip/stack/iptables_types.go22
-rw-r--r--pkg/tcpip/stack/ndp.go140
-rw-r--r--pkg/tcpip/stack/ndp_test.go28
-rw-r--r--pkg/tcpip/stack/nic.go30
-rw-r--r--pkg/tcpip/stack/nic_test.go5
-rw-r--r--pkg/tcpip/stack/packet_buffer.go4
-rw-r--r--pkg/tcpip/stack/registration.go21
-rw-r--r--pkg/tcpip/stack/stack.go12
-rw-r--r--pkg/tcpip/tcpip.go52
-rw-r--r--pkg/tcpip/time_unsafe.go30
-rw-r--r--pkg/tcpip/timer.go147
-rw-r--r--pkg/tcpip/timer_test.go91
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go2
-rw-r--r--pkg/tcpip/transport/packet/endpoint.go54
-rw-r--r--pkg/tcpip/transport/raw/endpoint.go4
-rw-r--r--pkg/tcpip/transport/tcp/connect.go6
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go26
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go2
-rw-r--r--pkg/test/dockerutil/exec.go1
106 files changed, 2403 insertions, 671 deletions
diff --git a/pkg/abi/linux/BUILD b/pkg/abi/linux/BUILD
index 2b789c4ec..a4bb62013 100644
--- a/pkg/abi/linux/BUILD
+++ b/pkg/abi/linux/BUILD
@@ -72,6 +72,9 @@ go_library(
"//pkg/abi",
"//pkg/binary",
"//pkg/bits",
+ "//pkg/usermem",
+ "//tools/go_marshal/marshal",
+ "//tools/go_marshal/primitive",
],
)
diff --git a/pkg/abi/linux/futex.go b/pkg/abi/linux/futex.go
index 08bfde3b5..8138088a6 100644
--- a/pkg/abi/linux/futex.go
+++ b/pkg/abi/linux/futex.go
@@ -60,3 +60,21 @@ const (
FUTEX_WAITERS = 0x80000000
FUTEX_OWNER_DIED = 0x40000000
)
+
+// FUTEX_BITSET_MATCH_ANY has all bits set.
+const FUTEX_BITSET_MATCH_ANY = 0xffffffff
+
+// ROBUST_LIST_LIMIT protects against a deliberately circular list.
+const ROBUST_LIST_LIMIT = 2048
+
+// RobustListHead corresponds to Linux's struct robust_list_head.
+//
+// +marshal
+type RobustListHead struct {
+ List uint64
+ FutexOffset uint64
+ ListOpPending uint64
+}
+
+// SizeOfRobustListHead is the size of a RobustListHead struct.
+var SizeOfRobustListHead = (*RobustListHead)(nil).SizeBytes()
diff --git a/pkg/abi/linux/netfilter.go b/pkg/abi/linux/netfilter.go
index 46d8b0b42..a91f9f018 100644
--- a/pkg/abi/linux/netfilter.go
+++ b/pkg/abi/linux/netfilter.go
@@ -14,6 +14,14 @@
package linux
+import (
+ "io"
+
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/tools/go_marshal/marshal"
+ "gvisor.dev/gvisor/tools/go_marshal/primitive"
+)
+
// This file contains structures required to support netfilter, specifically
// the iptables tool.
@@ -76,6 +84,8 @@ const (
// IPTEntry is an iptable rule. It corresponds to struct ipt_entry in
// include/uapi/linux/netfilter_ipv4/ip_tables.h.
+//
+// +marshal
type IPTEntry struct {
// IP is used to filter packets based on the IP header.
IP IPTIP
@@ -112,21 +122,41 @@ type IPTEntry struct {
// SizeOfIPTEntry is the size of an IPTEntry.
const SizeOfIPTEntry = 112
-// KernelIPTEntry is identical to IPTEntry, but includes the Elems field. This
-// struct marshaled via the binary package to write an IPTEntry to userspace.
+// KernelIPTEntry is identical to IPTEntry, but includes the Elems field.
+// KernelIPTEntry itself is not Marshallable but it implements some methods of
+// marshal.Marshallable that help in other implementations of Marshallable.
type KernelIPTEntry struct {
- IPTEntry
+ Entry IPTEntry
// Elems holds the data for all this rule's matches followed by the
// target. It is variable length -- users have to iterate over any
// matches and use TargetOffset and NextOffset to make sense of the
// data.
- Elems []byte
+ Elems primitive.ByteSlice
+}
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (ke *KernelIPTEntry) SizeBytes() int {
+ return ke.Entry.SizeBytes() + ke.Elems.SizeBytes()
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (ke *KernelIPTEntry) MarshalBytes(dst []byte) {
+ ke.Entry.MarshalBytes(dst)
+ ke.Elems.MarshalBytes(dst[ke.Entry.SizeBytes():])
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (ke *KernelIPTEntry) UnmarshalBytes(src []byte) {
+ ke.Entry.UnmarshalBytes(src)
+ ke.Elems.UnmarshalBytes(src[ke.Entry.SizeBytes():])
}
// IPTIP contains information for matching a packet's IP header.
// It corresponds to struct ipt_ip in
// include/uapi/linux/netfilter_ipv4/ip_tables.h.
+//
+// +marshal
type IPTIP struct {
// Src is the source IP address.
Src InetAddr
@@ -189,6 +219,8 @@ const SizeOfIPTIP = 84
// XTCounters holds packet and byte counts for a rule. It corresponds to struct
// xt_counters in include/uapi/linux/netfilter/x_tables.h.
+//
+// +marshal
type XTCounters struct {
// Pcnt is the packet count.
Pcnt uint64
@@ -321,6 +353,8 @@ const SizeOfXTRedirectTarget = 56
// IPTGetinfo is the argument for the IPT_SO_GET_INFO sockopt. It corresponds
// to struct ipt_getinfo in include/uapi/linux/netfilter_ipv4/ip_tables.h.
+//
+// +marshal
type IPTGetinfo struct {
Name TableName
ValidHooks uint32
@@ -336,6 +370,8 @@ const SizeOfIPTGetinfo = 84
// IPTGetEntries is the argument for the IPT_SO_GET_ENTRIES sockopt. It
// corresponds to struct ipt_get_entries in
// include/uapi/linux/netfilter_ipv4/ip_tables.h.
+//
+// +marshal
type IPTGetEntries struct {
Name TableName
Size uint32
@@ -350,13 +386,103 @@ type IPTGetEntries struct {
const SizeOfIPTGetEntries = 40
// KernelIPTGetEntries is identical to IPTGetEntries, but includes the
-// Entrytable field. This struct marshaled via the binary package to write an
-// KernelIPTGetEntries to userspace.
+// Entrytable field. This has been manually made marshal.Marshallable since it
+// is dynamically sized.
type KernelIPTGetEntries struct {
IPTGetEntries
Entrytable []KernelIPTEntry
}
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (ke *KernelIPTGetEntries) SizeBytes() int {
+ res := ke.IPTGetEntries.SizeBytes()
+ for _, entry := range ke.Entrytable {
+ res += entry.SizeBytes()
+ }
+ return res
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (ke *KernelIPTGetEntries) MarshalBytes(dst []byte) {
+ ke.IPTGetEntries.MarshalBytes(dst)
+ marshalledUntil := ke.IPTGetEntries.SizeBytes()
+ for i := 0; i < len(ke.Entrytable); i++ {
+ ke.Entrytable[i].MarshalBytes(dst[marshalledUntil:])
+ marshalledUntil += ke.Entrytable[i].SizeBytes()
+ }
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (ke *KernelIPTGetEntries) UnmarshalBytes(src []byte) {
+ ke.IPTGetEntries.UnmarshalBytes(src)
+ unmarshalledUntil := ke.IPTGetEntries.SizeBytes()
+ for i := 0; i < len(ke.Entrytable); i++ {
+ ke.Entrytable[i].UnmarshalBytes(src[unmarshalledUntil:])
+ unmarshalledUntil += ke.Entrytable[i].SizeBytes()
+ }
+}
+
+// Packed implements marshal.Marshallable.Packed.
+func (ke *KernelIPTGetEntries) Packed() bool {
+ // KernelIPTGetEntries isn't packed because the ke.Entrytable contains an
+ // indirection to the actual data we want to marshal (the slice data
+ // pointer), and the memory for KernelIPTGetEntries contains the slice
+ // header which we don't want to marshal.
+ return false
+}
+
+// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.
+func (ke *KernelIPTGetEntries) MarshalUnsafe(dst []byte) {
+ // Fall back to safe Marshal because the type in not packed.
+ ke.MarshalBytes(dst)
+}
+
+// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.
+func (ke *KernelIPTGetEntries) UnmarshalUnsafe(src []byte) {
+ // Fall back to safe Unmarshal because the type in not packed.
+ ke.UnmarshalBytes(src)
+}
+
+// CopyIn implements marshal.Marshallable.CopyIn.
+func (ke *KernelIPTGetEntries) CopyIn(task marshal.Task, addr usermem.Addr) (int, error) {
+ buf := task.CopyScratchBuffer(ke.SizeBytes()) // escapes: okay.
+ length, err := task.CopyInBytes(addr, buf) // escapes: okay.
+ // Unmarshal unconditionally. If we had a short copy-in, this results in a
+ // partially unmarshalled struct.
+ ke.UnmarshalBytes(buf) // escapes: fallback.
+ return length, err
+}
+
+// CopyOut implements marshal.Marshallable.CopyOut.
+func (ke *KernelIPTGetEntries) CopyOut(task marshal.Task, addr usermem.Addr) (int, error) {
+ // Type KernelIPTGetEntries doesn't have a packed layout in memory, fall
+ // back to MarshalBytes.
+ return task.CopyOutBytes(addr, ke.marshalAll(task))
+}
+
+// CopyOutN implements marshal.Marshallable.CopyOutN.
+func (ke *KernelIPTGetEntries) CopyOutN(task marshal.Task, addr usermem.Addr, limit int) (int, error) {
+ // Type KernelIPTGetEntries doesn't have a packed layout in memory, fall
+ // back to MarshalBytes.
+ return task.CopyOutBytes(addr, ke.marshalAll(task)[:limit])
+}
+
+func (ke *KernelIPTGetEntries) marshalAll(task marshal.Task) []byte {
+ buf := task.CopyScratchBuffer(ke.SizeBytes())
+ ke.MarshalBytes(buf)
+ return buf
+}
+
+// WriteTo implements io.WriterTo.WriteTo.
+func (ke *KernelIPTGetEntries) WriteTo(w io.Writer) (int64, error) {
+ buf := make([]byte, ke.SizeBytes())
+ ke.MarshalBytes(buf)
+ length, err := w.Write(buf)
+ return int64(length), err
+}
+
+var _ marshal.Marshallable = (*KernelIPTGetEntries)(nil)
+
// IPTReplace is the argument for the IPT_SO_SET_REPLACE sockopt. It
// corresponds to struct ipt_replace in
// include/uapi/linux/netfilter_ipv4/ip_tables.h.
@@ -374,12 +500,6 @@ type IPTReplace struct {
// Entries [0]IPTEntry
}
-// KernelIPTReplace is identical to IPTReplace, but includes the Entries field.
-type KernelIPTReplace struct {
- IPTReplace
- Entries [0]IPTEntry
-}
-
// SizeOfIPTReplace is the size of an IPTReplace.
const SizeOfIPTReplace = 96
@@ -392,6 +512,8 @@ func (en ExtensionName) String() string {
}
// TableName holds the name of a netfilter table.
+//
+// +marshal
type TableName [XT_TABLE_MAXNAMELEN]byte
// String implements fmt.Stringer.
diff --git a/pkg/abi/linux/socket.go b/pkg/abi/linux/socket.go
index 4a14ef691..c24a8216e 100644
--- a/pkg/abi/linux/socket.go
+++ b/pkg/abi/linux/socket.go
@@ -134,6 +134,15 @@ const (
SHUT_RDWR = 2
)
+// Packet types from <linux/if_packet.h>
+const (
+ PACKET_HOST = 0 // To us
+ PACKET_BROADCAST = 1 // To all
+ PACKET_MULTICAST = 2 // To group
+ PACKET_OTHERHOST = 3 // To someone else
+ PACKET_OUTGOING = 4 // Outgoing of any type
+)
+
// Socket options from socket.h.
const (
SO_DEBUG = 1
@@ -225,6 +234,8 @@ const (
const SockAddrMax = 128
// InetAddr is struct in_addr, from uapi/linux/in.h.
+//
+// +marshal
type InetAddr [4]byte
// SockAddrInet is struct sockaddr_in, from uapi/linux/in.h.
@@ -294,6 +305,8 @@ func (s *SockAddrUnix) implementsSockAddr() {}
func (s *SockAddrNetlink) implementsSockAddr() {}
// Linger is struct linger, from include/linux/socket.h.
+//
+// +marshal
type Linger struct {
OnOff int32
Linger int32
@@ -308,6 +321,8 @@ const SizeOfLinger = 8
// the end of this struct or within existing unusued space, so its size grows
// over time. The current iteration is based on linux v4.17. New versions are
// always backwards compatible.
+//
+// +marshal
type TCPInfo struct {
State uint8
CaState uint8
@@ -405,6 +420,8 @@ var SizeOfControlMessageHeader = int(binary.Size(ControlMessageHeader{}))
// A ControlMessageCredentials is an SCM_CREDENTIALS socket control message.
//
// ControlMessageCredentials represents struct ucred from linux/socket.h.
+//
+// +marshal
type ControlMessageCredentials struct {
PID int32
UID uint32
diff --git a/pkg/sentry/arch/arch_aarch64.go b/pkg/sentry/arch/arch_aarch64.go
index daba8b172..fd95eb2d2 100644
--- a/pkg/sentry/arch/arch_aarch64.go
+++ b/pkg/sentry/arch/arch_aarch64.go
@@ -28,7 +28,14 @@ import (
)
// Registers represents the CPU registers for this architecture.
-type Registers = linux.PtraceRegs
+//
+// +stateify savable
+type Registers struct {
+ linux.PtraceRegs
+
+ // TPIDR_EL0 is the EL0 Read/Write Software Thread ID Register.
+ TPIDR_EL0 uint64
+}
const (
// SyscallWidth is the width of insturctions.
@@ -101,9 +108,6 @@ type State struct {
// Our floating point state.
aarch64FPState `state:"wait"`
- // TLS pointer
- TPValue uint64
-
// FeatureSet is a pointer to the currently active feature set.
FeatureSet *cpuid.FeatureSet
@@ -157,7 +161,6 @@ func (s *State) Fork() State {
return State{
Regs: s.Regs,
aarch64FPState: s.aarch64FPState.fork(),
- TPValue: s.TPValue,
FeatureSet: s.FeatureSet,
OrigR0: s.OrigR0,
}
@@ -241,18 +244,18 @@ func (s *State) ptraceGetRegs() Registers {
return s.Regs
}
-var registersSize = (*Registers)(nil).SizeBytes()
+var ptraceRegistersSize = (*linux.PtraceRegs)(nil).SizeBytes()
// PtraceSetRegs implements Context.PtraceSetRegs.
func (s *State) PtraceSetRegs(src io.Reader) (int, error) {
var regs Registers
- buf := make([]byte, registersSize)
+ buf := make([]byte, ptraceRegistersSize)
if _, err := io.ReadFull(src, buf); err != nil {
return 0, err
}
regs.UnmarshalUnsafe(buf)
s.Regs = regs
- return registersSize, nil
+ return ptraceRegistersSize, nil
}
// PtraceGetFPRegs implements Context.PtraceGetFPRegs.
@@ -278,7 +281,7 @@ const (
func (s *State) PtraceGetRegSet(regset uintptr, dst io.Writer, maxlen int) (int, error) {
switch regset {
case _NT_PRSTATUS:
- if maxlen < registersSize {
+ if maxlen < ptraceRegistersSize {
return 0, syserror.EFAULT
}
return s.PtraceGetRegs(dst)
@@ -291,7 +294,7 @@ func (s *State) PtraceGetRegSet(regset uintptr, dst io.Writer, maxlen int) (int,
func (s *State) PtraceSetRegSet(regset uintptr, src io.Reader, maxlen int) (int, error) {
switch regset {
case _NT_PRSTATUS:
- if maxlen < registersSize {
+ if maxlen < ptraceRegistersSize {
return 0, syserror.EFAULT
}
return s.PtraceSetRegs(src)
diff --git a/pkg/sentry/arch/arch_amd64.go b/pkg/sentry/arch/arch_amd64.go
index 3b3a0a272..1c3e3c14c 100644
--- a/pkg/sentry/arch/arch_amd64.go
+++ b/pkg/sentry/arch/arch_amd64.go
@@ -300,7 +300,7 @@ func (c *context64) PtracePeekUser(addr uintptr) (interface{}, error) {
// PTRACE_PEEKUSER and PTRACE_POKEUSER are only effective on regs and
// u_debugreg, returning 0 or silently no-oping for other fields
// respectively.
- if addr < uintptr(registersSize) {
+ if addr < uintptr(ptraceRegistersSize) {
regs := c.ptraceGetRegs()
buf := make([]byte, regs.SizeBytes())
regs.MarshalUnsafe(buf)
@@ -315,7 +315,7 @@ func (c *context64) PtracePokeUser(addr, data uintptr) error {
if addr&7 != 0 || addr >= userStructSize {
return syscall.EIO
}
- if addr < uintptr(registersSize) {
+ if addr < uintptr(ptraceRegistersSize) {
regs := c.ptraceGetRegs()
buf := make([]byte, regs.SizeBytes())
regs.MarshalUnsafe(buf)
diff --git a/pkg/sentry/arch/arch_arm64.go b/pkg/sentry/arch/arch_arm64.go
index ada7ac7b8..cabbf60e0 100644
--- a/pkg/sentry/arch/arch_arm64.go
+++ b/pkg/sentry/arch/arch_arm64.go
@@ -142,7 +142,7 @@ func (c *context64) SetStack(value uintptr) {
// TLS returns the current TLS pointer.
func (c *context64) TLS() uintptr {
- return uintptr(c.TPValue)
+ return uintptr(c.Regs.TPIDR_EL0)
}
// SetTLS sets the current TLS pointer. Returns false if value is invalid.
@@ -151,7 +151,7 @@ func (c *context64) SetTLS(value uintptr) bool {
return false
}
- c.TPValue = uint64(value)
+ c.Regs.TPIDR_EL0 = uint64(value)
return true
}
diff --git a/pkg/sentry/arch/arch_x86.go b/pkg/sentry/arch/arch_x86.go
index dc458b37f..b9405b320 100644
--- a/pkg/sentry/arch/arch_x86.go
+++ b/pkg/sentry/arch/arch_x86.go
@@ -31,7 +31,11 @@ import (
)
// Registers represents the CPU registers for this architecture.
-type Registers = linux.PtraceRegs
+//
+// +stateify savable
+type Registers struct {
+ linux.PtraceRegs
+}
// System-related constants for x86.
const (
@@ -311,12 +315,12 @@ func (s *State) ptraceGetRegs() Registers {
return regs
}
-var registersSize = (*Registers)(nil).SizeBytes()
+var ptraceRegistersSize = (*linux.PtraceRegs)(nil).SizeBytes()
// PtraceSetRegs implements Context.PtraceSetRegs.
func (s *State) PtraceSetRegs(src io.Reader) (int, error) {
var regs Registers
- buf := make([]byte, registersSize)
+ buf := make([]byte, ptraceRegistersSize)
if _, err := io.ReadFull(src, buf); err != nil {
return 0, err
}
@@ -374,7 +378,7 @@ func (s *State) PtraceSetRegs(src io.Reader) (int, error) {
}
regs.Eflags = (s.Regs.Eflags &^ eflagsPtraceMutable) | (regs.Eflags & eflagsPtraceMutable)
s.Regs = regs
- return registersSize, nil
+ return ptraceRegistersSize, nil
}
// isUserSegmentSelector returns true if the given segment selector specifies a
@@ -543,7 +547,7 @@ const (
func (s *State) PtraceGetRegSet(regset uintptr, dst io.Writer, maxlen int) (int, error) {
switch regset {
case _NT_PRSTATUS:
- if maxlen < registersSize {
+ if maxlen < ptraceRegistersSize {
return 0, syserror.EFAULT
}
return s.PtraceGetRegs(dst)
@@ -563,7 +567,7 @@ func (s *State) PtraceGetRegSet(regset uintptr, dst io.Writer, maxlen int) (int,
func (s *State) PtraceSetRegSet(regset uintptr, src io.Reader, maxlen int) (int, error) {
switch regset {
case _NT_PRSTATUS:
- if maxlen < registersSize {
+ if maxlen < ptraceRegistersSize {
return 0, syserror.EFAULT
}
return s.PtraceSetRegs(src)
diff --git a/pkg/sentry/fsimpl/devpts/slave.go b/pkg/sentry/fsimpl/devpts/slave.go
index 2018b978a..a91cae3ef 100644
--- a/pkg/sentry/fsimpl/devpts/slave.go
+++ b/pkg/sentry/fsimpl/devpts/slave.go
@@ -132,7 +132,7 @@ func (sfd *slaveFileDescription) Write(ctx context.Context, src usermem.IOSequen
return sfd.inode.t.ld.outputQueueWrite(ctx, src)
}
-// Ioctl implements vfs.FileDescripionImpl.Ioctl.
+// Ioctl implements vfs.FileDescriptionImpl.Ioctl.
func (sfd *slaveFileDescription) Ioctl(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
switch cmd := args[1].Uint(); cmd {
case linux.FIONREAD: // linux.FIONREAD == linux.TIOCINQ
diff --git a/pkg/sentry/fsimpl/fuse/BUILD b/pkg/sentry/fsimpl/fuse/BUILD
index 3e00c2abb..737007748 100644
--- a/pkg/sentry/fsimpl/fuse/BUILD
+++ b/pkg/sentry/fsimpl/fuse/BUILD
@@ -6,13 +6,17 @@ go_library(
name = "fuse",
srcs = [
"dev.go",
+ "fusefs.go",
],
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
"//pkg/context",
+ "//pkg/log",
"//pkg/sentry/fsimpl/devtmpfs",
+ "//pkg/sentry/fsimpl/kernfs",
"//pkg/sentry/kernel",
+ "//pkg/sentry/kernel/auth",
"//pkg/sentry/vfs",
"//pkg/syserror",
"//pkg/usermem",
diff --git a/pkg/sentry/fsimpl/fuse/dev.go b/pkg/sentry/fsimpl/fuse/dev.go
index dc33268af..c9e12a94f 100644
--- a/pkg/sentry/fsimpl/fuse/dev.go
+++ b/pkg/sentry/fsimpl/fuse/dev.go
@@ -51,6 +51,9 @@ type DeviceFD struct {
vfs.DentryMetadataFileDescriptionImpl
vfs.NoLockFD
+ // mounted specifies whether a FUSE filesystem was mounted using the DeviceFD.
+ mounted bool
+
// TODO(gvisor.dev/issue/2987): Add all the data structures needed to enqueue
// and deque requests, control synchronization and establish communication
// between the FUSE kernel module and the /dev/fuse character device.
@@ -61,26 +64,51 @@ func (fd *DeviceFD) Release() {}
// PRead implements vfs.FileDescriptionImpl.PRead.
func (fd *DeviceFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
+ // Operations on /dev/fuse don't make sense until a FUSE filesystem is mounted.
+ if !fd.mounted {
+ return 0, syserror.EPERM
+ }
+
return 0, syserror.ENOSYS
}
// Read implements vfs.FileDescriptionImpl.Read.
func (fd *DeviceFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
+ // Operations on /dev/fuse don't make sense until a FUSE filesystem is mounted.
+ if !fd.mounted {
+ return 0, syserror.EPERM
+ }
+
return 0, syserror.ENOSYS
}
// PWrite implements vfs.FileDescriptionImpl.PWrite.
func (fd *DeviceFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
+ // Operations on /dev/fuse don't make sense until a FUSE filesystem is mounted.
+ if !fd.mounted {
+ return 0, syserror.EPERM
+ }
+
return 0, syserror.ENOSYS
}
// Write implements vfs.FileDescriptionImpl.Write.
func (fd *DeviceFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
+ // Operations on /dev/fuse don't make sense until a FUSE filesystem is mounted.
+ if !fd.mounted {
+ return 0, syserror.EPERM
+ }
+
return 0, syserror.ENOSYS
}
// Seek implements vfs.FileDescriptionImpl.Seek.
func (fd *DeviceFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) {
+ // Operations on /dev/fuse don't make sense until a FUSE filesystem is mounted.
+ if !fd.mounted {
+ return 0, syserror.EPERM
+ }
+
return 0, syserror.ENOSYS
}
diff --git a/pkg/sentry/fsimpl/fuse/fusefs.go b/pkg/sentry/fsimpl/fuse/fusefs.go
new file mode 100644
index 000000000..f7775fb9b
--- /dev/null
+++ b/pkg/sentry/fsimpl/fuse/fusefs.go
@@ -0,0 +1,200 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package fuse implements fusefs.
+package fuse
+
+import (
+ "strconv"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// Name is the default filesystem name.
+const Name = "fuse"
+
+// FilesystemType implements vfs.FilesystemType.
+type FilesystemType struct{}
+
+type filesystemOptions struct {
+ // userID specifies the numeric uid of the mount owner.
+ // This option should not be specified by the filesystem owner.
+ // It is set by libfuse (or, if libfuse is not used, must be set
+ // by the filesystem itself). For more information, see man page
+ // for fuse(8)
+ userID uint32
+
+ // groupID specifies the numeric gid of the mount owner.
+ // This option should not be specified by the filesystem owner.
+ // It is set by libfuse (or, if libfuse is not used, must be set
+ // by the filesystem itself). For more information, see man page
+ // for fuse(8)
+ groupID uint32
+
+ // rootMode specifies the the file mode of the filesystem's root.
+ rootMode linux.FileMode
+}
+
+// filesystem implements vfs.FilesystemImpl.
+type filesystem struct {
+ kernfs.Filesystem
+ devMinor uint32
+
+ // fuseFD is the FD returned when opening /dev/fuse. It is used for communication
+ // between the FUSE server daemon and the sentry fusefs.
+ fuseFD *DeviceFD
+
+ // opts is the options the fusefs is initialized with.
+ opts filesystemOptions
+}
+
+// Name implements vfs.FilesystemType.Name.
+func (FilesystemType) Name() string {
+ return Name
+}
+
+// GetFilesystem implements vfs.FilesystemType.GetFilesystem.
+func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
+ devMinor, err := vfsObj.GetAnonBlockDevMinor()
+ if err != nil {
+ return nil, nil, err
+ }
+
+ var fsopts filesystemOptions
+ mopts := vfs.GenericParseMountOptions(opts.Data)
+ deviceDescriptorStr, ok := mopts["fd"]
+ if !ok {
+ log.Warningf("%s.GetFilesystem: communication file descriptor N (obtained by opening /dev/fuse) must be specified as 'fd=N'", fsType.Name())
+ return nil, nil, syserror.EINVAL
+ }
+ delete(mopts, "fd")
+
+ deviceDescriptor, err := strconv.ParseInt(deviceDescriptorStr, 10 /* base */, 32 /* bitSize */)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ kernelTask := kernel.TaskFromContext(ctx)
+ if kernelTask == nil {
+ log.Warningf("%s.GetFilesystem: couldn't get kernel task from context", fsType.Name())
+ return nil, nil, syserror.EINVAL
+ }
+ fuseFd := kernelTask.GetFileVFS2(int32(deviceDescriptor))
+
+ // Parse and set all the other supported FUSE mount options.
+ // TODO: Expand the supported mount options.
+ if userIDStr, ok := mopts["user_id"]; ok {
+ delete(mopts, "user_id")
+ userID, err := strconv.ParseUint(userIDStr, 10, 32)
+ if err != nil {
+ log.Warningf("%s.GetFilesystem: invalid user_id: user_id=%s", fsType.Name(), userIDStr)
+ return nil, nil, syserror.EINVAL
+ }
+ fsopts.userID = uint32(userID)
+ }
+
+ if groupIDStr, ok := mopts["group_id"]; ok {
+ delete(mopts, "group_id")
+ groupID, err := strconv.ParseUint(groupIDStr, 10, 32)
+ if err != nil {
+ log.Warningf("%s.GetFilesystem: invalid group_id: group_id=%s", fsType.Name(), groupIDStr)
+ return nil, nil, syserror.EINVAL
+ }
+ fsopts.groupID = uint32(groupID)
+ }
+
+ rootMode := linux.FileMode(0777)
+ modeStr, ok := mopts["rootmode"]
+ if ok {
+ delete(mopts, "rootmode")
+ mode, err := strconv.ParseUint(modeStr, 8, 32)
+ if err != nil {
+ log.Warningf("%s.GetFilesystem: invalid mode: %q", fsType.Name(), modeStr)
+ return nil, nil, syserror.EINVAL
+ }
+ rootMode = linux.FileMode(mode)
+ }
+ fsopts.rootMode = rootMode
+
+ // Check for unparsed options.
+ if len(mopts) != 0 {
+ log.Warningf("%s.GetFilesystem: unknown options: %v", fsType.Name(), mopts)
+ return nil, nil, syserror.EINVAL
+ }
+
+ // Mark the device as ready so it can be used. /dev/fuse can only be used if the FD was used to
+ // mount a FUSE filesystem.
+ fuseFD := fuseFd.Impl().(*DeviceFD)
+ fuseFD.mounted = true
+
+ fs := &filesystem{
+ devMinor: devMinor,
+ fuseFD: fuseFD,
+ opts: fsopts,
+ }
+
+ fs.VFSFilesystem().Init(vfsObj, &fsType, fs)
+
+ // TODO: dispatch a FUSE_INIT request to the FUSE daemon server before
+ // returning. Mount will not block on this dispatched request.
+
+ // root is the fusefs root directory.
+ root := fs.newInode(creds, fsopts.rootMode)
+
+ return fs.VFSFilesystem(), root.VFSDentry(), nil
+}
+
+// Release implements vfs.FilesystemImpl.Release.
+func (fs *filesystem) Release() {
+ fs.Filesystem.VFSFilesystem().VirtualFilesystem().PutAnonBlockDevMinor(fs.devMinor)
+ fs.Filesystem.Release()
+}
+
+// Inode implements kernfs.Inode.
+type Inode struct {
+ kernfs.InodeAttrs
+ kernfs.InodeNoDynamicLookup
+ kernfs.InodeNotSymlink
+ kernfs.InodeDirectoryNoNewChildren
+ kernfs.OrderedChildren
+
+ locks vfs.FileLocks
+
+ dentry kernfs.Dentry
+}
+
+func (fs *filesystem) newInode(creds *auth.Credentials, mode linux.FileMode) *kernfs.Dentry {
+ i := &Inode{}
+ i.InodeAttrs.Init(creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0755)
+ i.OrderedChildren.Init(kernfs.OrderedChildrenOptions{})
+ i.dentry.Init(i)
+
+ return &i.dentry
+}
+
+// Open implements kernfs.Inode.Open.
+func (i *Inode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &i.OrderedChildren, &i.locks, &opts)
+ if err != nil {
+ return nil, err
+ }
+ return fd.VFSFileDescription(), nil
+}
diff --git a/pkg/sentry/fsimpl/gofer/filesystem.go b/pkg/sentry/fsimpl/gofer/filesystem.go
index cd5f5049e..00e3c99cd 100644
--- a/pkg/sentry/fsimpl/gofer/filesystem.go
+++ b/pkg/sentry/fsimpl/gofer/filesystem.go
@@ -150,11 +150,9 @@ afterSymlink:
return nil, err
}
if d != d.parent && !d.cachedMetadataAuthoritative() {
- _, attrMask, attr, err := d.parent.file.getAttr(ctx, dentryAttrMask())
- if err != nil {
+ if err := d.parent.updateFromGetattr(ctx); err != nil {
return nil, err
}
- d.parent.updateFromP9Attrs(attrMask, &attr)
}
rp.Advance()
return d.parent, nil
@@ -209,17 +207,28 @@ func (fs *filesystem) getChildLocked(ctx context.Context, vfsObj *vfs.VirtualFil
// Preconditions: As for getChildLocked. !parent.isSynthetic().
func (fs *filesystem) revalidateChildLocked(ctx context.Context, vfsObj *vfs.VirtualFilesystem, parent *dentry, name string, child *dentry, ds **[]*dentry) (*dentry, error) {
+ if child != nil {
+ // Need to lock child.metadataMu because we might be updating child
+ // metadata. We need to hold the lock *before* getting metadata from the
+ // server and release it after updating local metadata.
+ child.metadataMu.Lock()
+ }
qid, file, attrMask, attr, err := parent.file.walkGetAttrOne(ctx, name)
if err != nil && err != syserror.ENOENT {
+ if child != nil {
+ child.metadataMu.Unlock()
+ }
return nil, err
}
if child != nil {
if !file.isNil() && inoFromPath(qid.Path) == child.ino {
// The file at this path hasn't changed. Just update cached metadata.
file.close(ctx)
- child.updateFromP9Attrs(attrMask, &attr)
+ child.updateFromP9AttrsLocked(attrMask, &attr)
+ child.metadataMu.Unlock()
return child, nil
}
+ child.metadataMu.Unlock()
if file.isNil() && child.isSynthetic() {
// We have a synthetic file, and no remote file has arisen to
// replace it.
@@ -1325,7 +1334,7 @@ func (fs *filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts
fs.renameMuRUnlockAndCheckCaching(&ds)
return err
}
- if err := d.setStat(ctx, rp.Credentials(), &opts.Stat, rp.Mount()); err != nil {
+ if err := d.setStat(ctx, rp.Credentials(), &opts, rp.Mount()); err != nil {
fs.renameMuRUnlockAndCheckCaching(&ds)
return err
}
diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go
index b74d489a0..e20de84b5 100644
--- a/pkg/sentry/fsimpl/gofer/gofer.go
+++ b/pkg/sentry/fsimpl/gofer/gofer.go
@@ -785,8 +785,8 @@ func (d *dentry) cachedMetadataAuthoritative() bool {
// updateFromP9Attrs is called to update d's metadata after an update from the
// remote filesystem.
-func (d *dentry) updateFromP9Attrs(mask p9.AttrMask, attr *p9.Attr) {
- d.metadataMu.Lock()
+// Precondition: d.metadataMu must be locked.
+func (d *dentry) updateFromP9AttrsLocked(mask p9.AttrMask, attr *p9.Attr) {
if mask.Mode {
if got, want := uint32(attr.Mode.FileType()), d.fileType(); got != want {
d.metadataMu.Unlock()
@@ -822,7 +822,6 @@ func (d *dentry) updateFromP9Attrs(mask p9.AttrMask, attr *p9.Attr) {
if mask.Size {
d.updateFileSizeLocked(attr.Size)
}
- d.metadataMu.Unlock()
}
// Preconditions: !d.isSynthetic()
@@ -834,6 +833,10 @@ func (d *dentry) updateFromGetattr(ctx context.Context) error {
file p9file
handleMuRLocked bool
)
+ // d.metadataMu must be locked *before* we getAttr so that we do not end up
+ // updating stale attributes in d.updateFromP9AttrsLocked().
+ d.metadataMu.Lock()
+ defer d.metadataMu.Unlock()
d.handleMu.RLock()
if !d.handle.file.isNil() {
file = d.handle.file
@@ -849,7 +852,7 @@ func (d *dentry) updateFromGetattr(ctx context.Context) error {
if err != nil {
return err
}
- d.updateFromP9Attrs(attrMask, &attr)
+ d.updateFromP9AttrsLocked(attrMask, &attr)
return nil
}
@@ -885,7 +888,8 @@ func (d *dentry) statTo(stat *linux.Statx) {
stat.DevMinor = d.fs.devMinor
}
-func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, stat *linux.Statx, mnt *vfs.Mount) error {
+func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, opts *vfs.SetStatOptions, mnt *vfs.Mount) error {
+ stat := &opts.Stat
if stat.Mask == 0 {
return nil
}
@@ -893,7 +897,7 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, stat *lin
return syserror.EPERM
}
mode := linux.FileMode(atomic.LoadUint32(&d.mode))
- if err := vfs.CheckSetStat(ctx, creds, stat, mode, auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid))); err != nil {
+ if err := vfs.CheckSetStat(ctx, creds, opts, mode, auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid))); err != nil {
return err
}
if err := mnt.CheckBeginWrite(); err != nil {
@@ -934,6 +938,17 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, stat *lin
}
if !d.isSynthetic() {
if stat.Mask != 0 {
+ if stat.Mask&linux.STATX_SIZE != 0 {
+ // Check whether to allow a truncate request to be made.
+ switch d.mode & linux.S_IFMT {
+ case linux.S_IFREG:
+ // Allow.
+ case linux.S_IFDIR:
+ return syserror.EISDIR
+ default:
+ return syserror.EINVAL
+ }
+ }
if err := d.file.setAttr(ctx, p9.SetAttrMask{
Permissions: stat.Mask&linux.STATX_MODE != 0,
UID: stat.Mask&linux.STATX_UID != 0,
@@ -1495,7 +1510,7 @@ func (fd *fileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linu
// SetStat implements vfs.FileDescriptionImpl.SetStat.
func (fd *fileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions) error {
- if err := fd.dentry().setStat(ctx, auth.CredentialsFromContext(ctx), &opts.Stat, fd.vfsfd.Mount()); err != nil {
+ if err := fd.dentry().setStat(ctx, auth.CredentialsFromContext(ctx), &opts, fd.vfsfd.Mount()); err != nil {
return err
}
if ev := vfs.InotifyEventFromStatMask(opts.Stat.Mask); ev != 0 {
diff --git a/pkg/sentry/fsimpl/gofer/regular_file.go b/pkg/sentry/fsimpl/gofer/regular_file.go
index f10350c97..02317a133 100644
--- a/pkg/sentry/fsimpl/gofer/regular_file.go
+++ b/pkg/sentry/fsimpl/gofer/regular_file.go
@@ -155,26 +155,53 @@ func (fd *regularFileFD) Read(ctx context.Context, dst usermem.IOSequence, opts
// PWrite implements vfs.FileDescriptionImpl.PWrite.
func (fd *regularFileFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
+ n, _, err := fd.pwrite(ctx, src, offset, opts)
+ return n, err
+}
+
+// pwrite returns the number of bytes written, final offset, error. The final
+// offset should be ignored by PWrite.
+func (fd *regularFileFD) pwrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (written, finalOff int64, err error) {
if offset < 0 {
- return 0, syserror.EINVAL
+ return 0, offset, syserror.EINVAL
}
// Check that flags are supported.
//
// TODO(gvisor.dev/issue/2601): Support select pwritev2 flags.
if opts.Flags&^linux.RWF_HIPRI != 0 {
- return 0, syserror.EOPNOTSUPP
+ return 0, offset, syserror.EOPNOTSUPP
+ }
+
+ d := fd.dentry()
+ // If the fd was opened with O_APPEND, make sure the file size is updated.
+ // There is a possible race here if size is modified externally after
+ // metadata cache is updated.
+ if fd.vfsfd.StatusFlags()&linux.O_APPEND != 0 && !d.cachedMetadataAuthoritative() {
+ if err := d.updateFromGetattr(ctx); err != nil {
+ return 0, offset, err
+ }
}
+ d.metadataMu.Lock()
+ defer d.metadataMu.Unlock()
+ // Set offset to file size if the fd was opened with O_APPEND.
+ if fd.vfsfd.StatusFlags()&linux.O_APPEND != 0 {
+ // Holding d.metadataMu is sufficient for reading d.size.
+ offset = int64(d.size)
+ }
limit, err := vfs.CheckLimit(ctx, offset, src.NumBytes())
if err != nil {
- return 0, err
+ return 0, offset, err
}
src = src.TakeFirst64(limit)
+ n, err := fd.pwriteLocked(ctx, src, offset, opts)
+ return n, offset + n, err
+}
+// Preconditions: fd.dentry().metatdataMu must be locked.
+func (fd *regularFileFD) pwriteLocked(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
d := fd.dentry()
- d.metadataMu.Lock()
- defer d.metadataMu.Unlock()
if d.fs.opts.interop != InteropModeShared {
// Compare Linux's mm/filemap.c:__generic_file_write_iter() =>
// file_update_time(). This is d.touchCMtime(), but without locking
@@ -237,8 +264,8 @@ func (fd *regularFileFD) PWrite(ctx context.Context, src usermem.IOSequence, off
// Write implements vfs.FileDescriptionImpl.Write.
func (fd *regularFileFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
fd.mu.Lock()
- n, err := fd.PWrite(ctx, src, fd.off, opts)
- fd.off += n
+ n, off, err := fd.pwrite(ctx, src, fd.off, opts)
+ fd.off = off
fd.mu.Unlock()
return n, err
}
diff --git a/pkg/sentry/fsimpl/gofer/special_file.go b/pkg/sentry/fsimpl/gofer/special_file.go
index a7b53b2d2..811528982 100644
--- a/pkg/sentry/fsimpl/gofer/special_file.go
+++ b/pkg/sentry/fsimpl/gofer/special_file.go
@@ -16,6 +16,7 @@ package gofer
import (
"sync"
+ "sync/atomic"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
@@ -144,7 +145,7 @@ func (fd *specialFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offs
// mmap due to lock ordering; MM locks precede dentry.dataMu. That doesn't
// hold here since specialFileFD doesn't client-cache data. Just buffer the
// read instead.
- if d := fd.dentry(); d.fs.opts.interop != InteropModeShared {
+ if d := fd.dentry(); d.cachedMetadataAuthoritative() {
d.touchAtime(fd.vfsfd.Mount())
}
buf := make([]byte, dst.NumBytes())
@@ -176,39 +177,76 @@ func (fd *specialFileFD) Read(ctx context.Context, dst usermem.IOSequence, opts
// PWrite implements vfs.FileDescriptionImpl.PWrite.
func (fd *specialFileFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
+ n, _, err := fd.pwrite(ctx, src, offset, opts)
+ return n, err
+}
+
+// pwrite returns the number of bytes written, final offset, error. The final
+// offset should be ignored by PWrite.
+func (fd *specialFileFD) pwrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (written, finalOff int64, err error) {
if fd.seekable && offset < 0 {
- return 0, syserror.EINVAL
+ return 0, offset, syserror.EINVAL
}
// Check that flags are supported.
//
// TODO(gvisor.dev/issue/2601): Support select pwritev2 flags.
if opts.Flags&^linux.RWF_HIPRI != 0 {
- return 0, syserror.EOPNOTSUPP
+ return 0, offset, syserror.EOPNOTSUPP
+ }
+
+ d := fd.dentry()
+ // If the regular file fd was opened with O_APPEND, make sure the file size
+ // is updated. There is a possible race here if size is modified externally
+ // after metadata cache is updated.
+ if fd.seekable && fd.vfsfd.StatusFlags()&linux.O_APPEND != 0 && !d.cachedMetadataAuthoritative() {
+ if err := d.updateFromGetattr(ctx); err != nil {
+ return 0, offset, err
+ }
}
if fd.seekable {
+ // We need to hold the metadataMu *while* writing to a regular file.
+ d.metadataMu.Lock()
+ defer d.metadataMu.Unlock()
+
+ // Set offset to file size if the regular file was opened with O_APPEND.
+ if fd.vfsfd.StatusFlags()&linux.O_APPEND != 0 {
+ // Holding d.metadataMu is sufficient for reading d.size.
+ offset = int64(d.size)
+ }
limit, err := vfs.CheckLimit(ctx, offset, src.NumBytes())
if err != nil {
- return 0, err
+ return 0, offset, err
}
src = src.TakeFirst64(limit)
}
// Do a buffered write. See rationale in PRead.
- if d := fd.dentry(); d.fs.opts.interop != InteropModeShared {
+ if d.cachedMetadataAuthoritative() {
d.touchCMtime()
}
buf := make([]byte, src.NumBytes())
// Don't do partial writes if we get a partial read from src.
if _, err := src.CopyIn(ctx, buf); err != nil {
- return 0, err
+ return 0, offset, err
}
n, err := fd.handle.writeFromBlocksAt(ctx, safemem.BlockSeqOf(safemem.BlockFromSafeSlice(buf)), uint64(offset))
if err == syserror.EAGAIN {
err = syserror.ErrWouldBlock
}
- return int64(n), err
+ finalOff = offset
+ // Update file size for regular files.
+ if fd.seekable {
+ finalOff += int64(n)
+ // d.metadataMu is already locked at this point.
+ if uint64(finalOff) > d.size {
+ d.dataMu.Lock()
+ defer d.dataMu.Unlock()
+ atomic.StoreUint64(&d.size, uint64(finalOff))
+ }
+ }
+ return int64(n), finalOff, err
}
// Write implements vfs.FileDescriptionImpl.Write.
@@ -218,8 +256,8 @@ func (fd *specialFileFD) Write(ctx context.Context, src usermem.IOSequence, opts
}
fd.mu.Lock()
- n, err := fd.PWrite(ctx, src, fd.off, opts)
- fd.off += n
+ n, off, err := fd.pwrite(ctx, src, fd.off, opts)
+ fd.off = off
fd.mu.Unlock()
return n, err
}
diff --git a/pkg/sentry/fsimpl/host/host.go b/pkg/sentry/fsimpl/host/host.go
index 1a88cb657..c894f2ca0 100644
--- a/pkg/sentry/fsimpl/host/host.go
+++ b/pkg/sentry/fsimpl/host/host.go
@@ -373,7 +373,7 @@ func (i *inode) fstat(fs *filesystem) (linux.Statx, error) {
// SetStat implements kernfs.Inode.
func (i *inode) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Credentials, opts vfs.SetStatOptions) error {
- s := opts.Stat
+ s := &opts.Stat
m := s.Mask
if m == 0 {
@@ -386,7 +386,7 @@ func (i *inode) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Cre
if err := syscall.Fstat(i.hostFD, &hostStat); err != nil {
return err
}
- if err := vfs.CheckSetStat(ctx, creds, &s, linux.FileMode(hostStat.Mode&linux.PermissionsMask), auth.KUID(hostStat.Uid), auth.KGID(hostStat.Gid)); err != nil {
+ if err := vfs.CheckSetStat(ctx, creds, &opts, linux.FileMode(hostStat.Mode), auth.KUID(hostStat.Uid), auth.KGID(hostStat.Gid)); err != nil {
return err
}
@@ -396,6 +396,9 @@ func (i *inode) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Cre
}
}
if m&linux.STATX_SIZE != 0 {
+ if hostStat.Mode&linux.S_IFMT != linux.S_IFREG {
+ return syserror.EINVAL
+ }
if err := syscall.Ftruncate(i.hostFD, int64(s.Size)); err != nil {
return err
}
diff --git a/pkg/sentry/fsimpl/kernfs/fd_impl_util.go b/pkg/sentry/fsimpl/kernfs/fd_impl_util.go
index 3f0aea73a..1d37ccb98 100644
--- a/pkg/sentry/fsimpl/kernfs/fd_impl_util.go
+++ b/pkg/sentry/fsimpl/kernfs/fd_impl_util.go
@@ -112,7 +112,7 @@ func (fd *GenericDirectoryFD) PWrite(ctx context.Context, src usermem.IOSequence
return fd.DirectoryFileDescriptionDefaultImpl.PWrite(ctx, src, offset, opts)
}
-// Release implements vfs.FileDecriptionImpl.Release.
+// Release implements vfs.FileDescriptionImpl.Release.
func (fd *GenericDirectoryFD) Release() {}
func (fd *GenericDirectoryFD) filesystem() *vfs.Filesystem {
@@ -123,7 +123,7 @@ func (fd *GenericDirectoryFD) inode() Inode {
return fd.vfsfd.VirtualDentry().Dentry().Impl().(*Dentry).inode
}
-// IterDirents implements vfs.FileDecriptionImpl.IterDirents. IterDirents holds
+// IterDirents implements vfs.FileDescriptionImpl.IterDirents. IterDirents holds
// o.mu when calling cb.
func (fd *GenericDirectoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback) error {
fd.mu.Lock()
@@ -198,7 +198,7 @@ func (fd *GenericDirectoryFD) IterDirents(ctx context.Context, cb vfs.IterDirent
return err
}
-// Seek implements vfs.FileDecriptionImpl.Seek.
+// Seek implements vfs.FileDescriptionImpl.Seek.
func (fd *GenericDirectoryFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) {
fd.mu.Lock()
defer fd.mu.Unlock()
diff --git a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go
index 2ab3f1761..579e627f0 100644
--- a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go
+++ b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go
@@ -267,7 +267,7 @@ func (a *InodeAttrs) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *aut
if opts.Stat.Mask&^(linux.STATX_MODE|linux.STATX_UID|linux.STATX_GID) != 0 {
return syserror.EPERM
}
- if err := vfs.CheckSetStat(ctx, creds, &opts.Stat, a.Mode(), auth.KUID(atomic.LoadUint32(&a.uid)), auth.KGID(atomic.LoadUint32(&a.gid))); err != nil {
+ if err := vfs.CheckSetStat(ctx, creds, &opts, a.Mode(), auth.KUID(atomic.LoadUint32(&a.uid)), auth.KGID(atomic.LoadUint32(&a.gid))); err != nil {
return err
}
diff --git a/pkg/sentry/fsimpl/overlay/filesystem.go b/pkg/sentry/fsimpl/overlay/filesystem.go
index ff82e1f20..6b705e955 100644
--- a/pkg/sentry/fsimpl/overlay/filesystem.go
+++ b/pkg/sentry/fsimpl/overlay/filesystem.go
@@ -1104,7 +1104,7 @@ func (fs *filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts
}
mode := linux.FileMode(atomic.LoadUint32(&d.mode))
- if err := vfs.CheckSetStat(ctx, rp.Credentials(), &opts.Stat, mode, auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid))); err != nil {
+ if err := vfs.CheckSetStat(ctx, rp.Credentials(), &opts, mode, auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid))); err != nil {
return err
}
mnt := rp.Mount()
diff --git a/pkg/sentry/fsimpl/overlay/non_directory.go b/pkg/sentry/fsimpl/overlay/non_directory.go
index a3c1f7a8d..c0749e711 100644
--- a/pkg/sentry/fsimpl/overlay/non_directory.go
+++ b/pkg/sentry/fsimpl/overlay/non_directory.go
@@ -151,7 +151,7 @@ func (fd *nonDirectoryFD) Stat(ctx context.Context, opts vfs.StatOptions) (linux
func (fd *nonDirectoryFD) SetStat(ctx context.Context, opts vfs.SetStatOptions) error {
d := fd.dentry()
mode := linux.FileMode(atomic.LoadUint32(&d.mode))
- if err := vfs.CheckSetStat(ctx, auth.CredentialsFromContext(ctx), &opts.Stat, mode, auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid))); err != nil {
+ if err := vfs.CheckSetStat(ctx, auth.CredentialsFromContext(ctx), &opts, mode, auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid))); err != nil {
return err
}
mnt := fd.vfsfd.Mount()
@@ -176,7 +176,7 @@ func (fd *nonDirectoryFD) SetStat(ctx context.Context, opts vfs.SetStatOptions)
return nil
}
-// StatFS implements vfs.FileDesciptionImpl.StatFS.
+// StatFS implements vfs.FileDescriptionImpl.StatFS.
func (fd *nonDirectoryFD) StatFS(ctx context.Context) (linux.Statfs, error) {
return fd.filesystem().statFS(ctx)
}
diff --git a/pkg/sentry/fsimpl/proc/subtasks.go b/pkg/sentry/fsimpl/proc/subtasks.go
index dad4db1a7..79c2725f3 100644
--- a/pkg/sentry/fsimpl/proc/subtasks.go
+++ b/pkg/sentry/fsimpl/proc/subtasks.go
@@ -128,7 +128,7 @@ func (fd *subtasksFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallbac
return fd.GenericDirectoryFD.IterDirents(ctx, cb)
}
-// Seek implements vfs.FileDecriptionImpl.Seek.
+// Seek implements vfs.FileDescriptionImpl.Seek.
func (fd *subtasksFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) {
if fd.task.ExitState() >= kernel.TaskExitZombie {
return 0, syserror.ENOENT
diff --git a/pkg/sentry/fsimpl/tmpfs/filesystem.go b/pkg/sentry/fsimpl/tmpfs/filesystem.go
index a0f20c2d4..ef210a69b 100644
--- a/pkg/sentry/fsimpl/tmpfs/filesystem.go
+++ b/pkg/sentry/fsimpl/tmpfs/filesystem.go
@@ -649,7 +649,7 @@ func (fs *filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts
fs.mu.RUnlock()
return err
}
- if err := d.inode.setStat(ctx, rp.Credentials(), &opts.Stat); err != nil {
+ if err := d.inode.setStat(ctx, rp.Credentials(), &opts); err != nil {
fs.mu.RUnlock()
return err
}
diff --git a/pkg/sentry/fsimpl/tmpfs/regular_file.go b/pkg/sentry/fsimpl/tmpfs/regular_file.go
index 1cdb46e6f..abbaa5d60 100644
--- a/pkg/sentry/fsimpl/tmpfs/regular_file.go
+++ b/pkg/sentry/fsimpl/tmpfs/regular_file.go
@@ -325,8 +325,15 @@ func (fd *regularFileFD) Read(ctx context.Context, dst usermem.IOSequence, opts
// PWrite implements vfs.FileDescriptionImpl.PWrite.
func (fd *regularFileFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
+ n, _, err := fd.pwrite(ctx, src, offset, opts)
+ return n, err
+}
+
+// pwrite returns the number of bytes written, final offset and error. The
+// final offset should be ignored by PWrite.
+func (fd *regularFileFD) pwrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (written, finalOff int64, err error) {
if offset < 0 {
- return 0, syserror.EINVAL
+ return 0, offset, syserror.EINVAL
}
// Check that flags are supported. RWF_DSYNC/RWF_SYNC can be ignored since
@@ -334,40 +341,44 @@ func (fd *regularFileFD) PWrite(ctx context.Context, src usermem.IOSequence, off
//
// TODO(gvisor.dev/issue/2601): Support select preadv2 flags.
if opts.Flags&^(linux.RWF_HIPRI|linux.RWF_DSYNC|linux.RWF_SYNC) != 0 {
- return 0, syserror.EOPNOTSUPP
+ return 0, offset, syserror.EOPNOTSUPP
}
srclen := src.NumBytes()
if srclen == 0 {
- return 0, nil
+ return 0, offset, nil
}
f := fd.inode().impl.(*regularFile)
+ f.inode.mu.Lock()
+ defer f.inode.mu.Unlock()
+ // If the file is opened with O_APPEND, update offset to file size.
+ if fd.vfsfd.StatusFlags()&linux.O_APPEND != 0 {
+ // Locking f.inode.mu is sufficient for reading f.size.
+ offset = int64(f.size)
+ }
if end := offset + srclen; end < offset {
// Overflow.
- return 0, syserror.EINVAL
+ return 0, offset, syserror.EINVAL
}
- var err error
srclen, err = vfs.CheckLimit(ctx, offset, srclen)
if err != nil {
- return 0, err
+ return 0, offset, err
}
src = src.TakeFirst64(srclen)
- f.inode.mu.Lock()
rw := getRegularFileReadWriter(f, offset)
n, err := src.CopyInTo(ctx, rw)
- fd.inode().touchCMtimeLocked()
- f.inode.mu.Unlock()
+ f.inode.touchCMtimeLocked()
putRegularFileReadWriter(rw)
- return n, err
+ return n, n + offset, err
}
// Write implements vfs.FileDescriptionImpl.Write.
func (fd *regularFileFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
fd.offMu.Lock()
- n, err := fd.PWrite(ctx, src, fd.off, opts)
- fd.off += n
+ n, off, err := fd.pwrite(ctx, src, fd.off, opts)
+ fd.off = off
fd.offMu.Unlock()
return n, err
}
diff --git a/pkg/sentry/fsimpl/tmpfs/tmpfs.go b/pkg/sentry/fsimpl/tmpfs/tmpfs.go
index d7f4f0779..2545d88e9 100644
--- a/pkg/sentry/fsimpl/tmpfs/tmpfs.go
+++ b/pkg/sentry/fsimpl/tmpfs/tmpfs.go
@@ -452,7 +452,8 @@ func (i *inode) statTo(stat *linux.Statx) {
}
}
-func (i *inode) setStat(ctx context.Context, creds *auth.Credentials, stat *linux.Statx) error {
+func (i *inode) setStat(ctx context.Context, creds *auth.Credentials, opts *vfs.SetStatOptions) error {
+ stat := &opts.Stat
if stat.Mask == 0 {
return nil
}
@@ -460,7 +461,7 @@ func (i *inode) setStat(ctx context.Context, creds *auth.Credentials, stat *linu
return syserror.EPERM
}
mode := linux.FileMode(atomic.LoadUint32(&i.mode))
- if err := vfs.CheckSetStat(ctx, creds, stat, mode, auth.KUID(atomic.LoadUint32(&i.uid)), auth.KGID(atomic.LoadUint32(&i.gid))); err != nil {
+ if err := vfs.CheckSetStat(ctx, creds, opts, mode, auth.KUID(atomic.LoadUint32(&i.uid)), auth.KGID(atomic.LoadUint32(&i.gid))); err != nil {
return err
}
i.mu.Lock()
@@ -695,7 +696,7 @@ func (fd *fileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linu
func (fd *fileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions) error {
creds := auth.CredentialsFromContext(ctx)
d := fd.dentry()
- if err := d.inode.setStat(ctx, creds, &opts.Stat); err != nil {
+ if err := d.inode.setStat(ctx, creds, &opts); err != nil {
return err
}
diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD
index 25fe1921b..f6886a758 100644
--- a/pkg/sentry/kernel/BUILD
+++ b/pkg/sentry/kernel/BUILD
@@ -132,6 +132,7 @@ go_library(
"task_stop.go",
"task_syscall.go",
"task_usermem.go",
+ "task_work.go",
"thread_group.go",
"threads.go",
"timekeeper.go",
diff --git a/pkg/sentry/kernel/futex/futex.go b/pkg/sentry/kernel/futex/futex.go
index 732e66da4..bcc1b29a8 100644
--- a/pkg/sentry/kernel/futex/futex.go
+++ b/pkg/sentry/kernel/futex/futex.go
@@ -717,10 +717,10 @@ func (m *Manager) lockPILocked(w *Waiter, t Target, addr usermem.Addr, tid uint3
}
}
-// UnlockPI unlock the futex following the Priority-inheritance futex
-// rules. The address provided must contain the caller's TID. If there are
-// waiters, TID of the next waiter (FIFO) is set to the given address, and the
-// waiter woken up. If there are no waiters, 0 is set to the address.
+// UnlockPI unlocks the futex following the Priority-inheritance futex rules.
+// The address provided must contain the caller's TID. If there are waiters,
+// TID of the next waiter (FIFO) is set to the given address, and the waiter
+// woken up. If there are no waiters, 0 is set to the address.
func (m *Manager) UnlockPI(t Target, addr usermem.Addr, tid uint32, private bool) error {
k, err := getKey(t, addr, private)
if err != nil {
diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go
index 240cd6fe0..15dae0f5b 100644
--- a/pkg/sentry/kernel/kernel.go
+++ b/pkg/sentry/kernel/kernel.go
@@ -1469,6 +1469,11 @@ func (k *Kernel) NowMonotonic() int64 {
return now
}
+// AfterFunc implements tcpip.Clock.AfterFunc.
+func (k *Kernel) AfterFunc(d time.Duration, f func()) tcpip.Timer {
+ return ktime.TcpipAfterFunc(k.realtimeClock, d, f)
+}
+
// SetMemoryFile sets Kernel.mf. SetMemoryFile must be called before Init or
// LoadFrom.
func (k *Kernel) SetMemoryFile(mf *pgalloc.MemoryFile) {
diff --git a/pkg/sentry/kernel/task.go b/pkg/sentry/kernel/task.go
index f48247c94..c4db05bd8 100644
--- a/pkg/sentry/kernel/task.go
+++ b/pkg/sentry/kernel/task.go
@@ -68,6 +68,21 @@ type Task struct {
// runState is exclusive to the task goroutine.
runState taskRunState
+ // taskWorkCount represents the current size of the task work queue. It is
+ // used to avoid acquiring taskWorkMu when the queue is empty.
+ //
+ // Must accessed with atomic memory operations.
+ taskWorkCount int32
+
+ // taskWorkMu protects taskWork.
+ taskWorkMu sync.Mutex `state:"nosave"`
+
+ // taskWork is a queue of work to be executed before resuming user execution.
+ // It is similar to the task_work mechanism in Linux.
+ //
+ // taskWork is exclusive to the task goroutine.
+ taskWork []TaskWorker
+
// haveSyscallReturn is true if tc.Arch().Return() represents a value
// returned by a syscall (or set by ptrace after a syscall).
//
@@ -550,6 +565,10 @@ type Task struct {
// futexWaiter is exclusive to the task goroutine.
futexWaiter *futex.Waiter `state:"nosave"`
+ // robustList is a pointer to the head of the tasks's robust futex
+ // list.
+ robustList usermem.Addr
+
// startTime is the real time at which the task started. It is set when
// a Task is created or invokes execve(2).
//
diff --git a/pkg/sentry/kernel/task_exec.go b/pkg/sentry/kernel/task_exec.go
index 9b69f3cbe..7803b98d0 100644
--- a/pkg/sentry/kernel/task_exec.go
+++ b/pkg/sentry/kernel/task_exec.go
@@ -207,6 +207,9 @@ func (r *runSyscallAfterExecStop) execute(t *Task) taskRunState {
return flags.CloseOnExec
})
+ // Handle the robust futex list.
+ t.exitRobustList()
+
// NOTE(b/30815691): We currently do not implement privileged
// executables (set-user/group-ID bits and file capabilities). This
// allows us to unconditionally enable user dumpability on the new mm.
diff --git a/pkg/sentry/kernel/task_exit.go b/pkg/sentry/kernel/task_exit.go
index c4ade6e8e..231ac548a 100644
--- a/pkg/sentry/kernel/task_exit.go
+++ b/pkg/sentry/kernel/task_exit.go
@@ -253,6 +253,9 @@ func (*runExitMain) execute(t *Task) taskRunState {
}
}
+ // Handle the robust futex list.
+ t.exitRobustList()
+
// Deactivate the address space and update max RSS before releasing the
// task's MM.
t.Deactivate()
diff --git a/pkg/sentry/kernel/task_futex.go b/pkg/sentry/kernel/task_futex.go
index a53e77c9f..4b535c949 100644
--- a/pkg/sentry/kernel/task_futex.go
+++ b/pkg/sentry/kernel/task_futex.go
@@ -15,6 +15,7 @@
package kernel
import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/kernel/futex"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -52,3 +53,127 @@ func (t *Task) LoadUint32(addr usermem.Addr) (uint32, error) {
func (t *Task) GetSharedKey(addr usermem.Addr) (futex.Key, error) {
return t.MemoryManager().GetSharedFutexKey(t, addr)
}
+
+// GetRobustList sets the robust futex list for the task.
+func (t *Task) GetRobustList() usermem.Addr {
+ t.mu.Lock()
+ addr := t.robustList
+ t.mu.Unlock()
+ return addr
+}
+
+// SetRobustList sets the robust futex list for the task.
+func (t *Task) SetRobustList(addr usermem.Addr) {
+ t.mu.Lock()
+ t.robustList = addr
+ t.mu.Unlock()
+}
+
+// exitRobustList walks the robust futex list, marking locks dead and notifying
+// wakers. It corresponds to Linux's exit_robust_list(). Following Linux,
+// errors are silently ignored.
+func (t *Task) exitRobustList() {
+ t.mu.Lock()
+ addr := t.robustList
+ t.robustList = 0
+ t.mu.Unlock()
+
+ if addr == 0 {
+ return
+ }
+
+ var rl linux.RobustListHead
+ if _, err := rl.CopyIn(t, usermem.Addr(addr)); err != nil {
+ return
+ }
+
+ next := rl.List
+ done := 0
+ var pendingLockAddr usermem.Addr
+ if rl.ListOpPending != 0 {
+ pendingLockAddr = usermem.Addr(rl.ListOpPending + rl.FutexOffset)
+ }
+
+ // Wake up normal elements.
+ for usermem.Addr(next) != addr {
+ // We traverse to the next element of the list before we
+ // actually wake anything. This prevents the race where waking
+ // this futex causes a modification of the list.
+ thisLockAddr := usermem.Addr(next + rl.FutexOffset)
+
+ // Try to decode the next element in the list before waking the
+ // current futex. But don't check the error until after we've
+ // woken the current futex. Linux does it in this order too
+ _, nextErr := t.CopyIn(usermem.Addr(next), &next)
+
+ // Wakeup the current futex if it's not pending.
+ if thisLockAddr != pendingLockAddr {
+ t.wakeRobustListOne(thisLockAddr)
+ }
+
+ // If there was an error copying the next futex, we must bail.
+ if nextErr != nil {
+ break
+ }
+
+ // This is a user structure, so it could be a massive list, or
+ // even contain a loop if they are trying to mess with us. We
+ // cap traversal to prevent that.
+ done++
+ if done >= linux.ROBUST_LIST_LIMIT {
+ break
+ }
+ }
+
+ // Is there a pending entry to wake?
+ if pendingLockAddr != 0 {
+ t.wakeRobustListOne(pendingLockAddr)
+ }
+}
+
+// wakeRobustListOne wakes a single futex from the robust list.
+func (t *Task) wakeRobustListOne(addr usermem.Addr) {
+ // Bit 0 in address signals PI futex.
+ pi := addr&1 == 1
+ addr = addr &^ 1
+
+ // Load the futex.
+ f, err := t.LoadUint32(addr)
+ if err != nil {
+ // Can't read this single value? Ignore the problem.
+ // We can wake the other futexes in the list.
+ return
+ }
+
+ tid := uint32(t.ThreadID())
+ for {
+ // Is this held by someone else?
+ if f&linux.FUTEX_TID_MASK != tid {
+ return
+ }
+
+ // This thread is dying and it's holding this futex. We need to
+ // set the owner died bit and wake up any waiters.
+ newF := (f & linux.FUTEX_WAITERS) | linux.FUTEX_OWNER_DIED
+ if curF, err := t.CompareAndSwapUint32(addr, f, newF); err != nil {
+ return
+ } else if curF != f {
+ // Futex changed out from under us. Try again...
+ f = curF
+ continue
+ }
+
+ // Wake waiters if there are any.
+ if f&linux.FUTEX_WAITERS != 0 {
+ private := f&linux.FUTEX_PRIVATE_FLAG != 0
+ if pi {
+ t.Futex().UnlockPI(t, addr, tid, private)
+ return
+ }
+ t.Futex().Wake(t, addr, private, linux.FUTEX_BITSET_MATCH_ANY, 1)
+ }
+
+ // Done.
+ return
+ }
+}
diff --git a/pkg/sentry/kernel/task_run.go b/pkg/sentry/kernel/task_run.go
index d654dd997..7d4f44caf 100644
--- a/pkg/sentry/kernel/task_run.go
+++ b/pkg/sentry/kernel/task_run.go
@@ -167,7 +167,22 @@ func (app *runApp) execute(t *Task) taskRunState {
return (*runInterrupt)(nil)
}
- // We're about to switch to the application again. If there's still a
+ // Execute any task work callbacks before returning to user space.
+ if atomic.LoadInt32(&t.taskWorkCount) > 0 {
+ t.taskWorkMu.Lock()
+ queue := t.taskWork
+ t.taskWork = nil
+ atomic.StoreInt32(&t.taskWorkCount, 0)
+ t.taskWorkMu.Unlock()
+
+ // Do not hold taskWorkMu while executing task work, which may register
+ // more work.
+ for _, work := range queue {
+ work.TaskWork(t)
+ }
+ }
+
+ // We're about to switch to the application again. If there's still an
// unhandled SyscallRestartErrno that wasn't translated to an EINTR,
// restart the syscall that was interrupted. If there's a saved signal
// mask, restore it. (Note that restoring the saved signal mask may unblock
diff --git a/pkg/sentry/kernel/task_work.go b/pkg/sentry/kernel/task_work.go
new file mode 100644
index 000000000..dda5a433a
--- /dev/null
+++ b/pkg/sentry/kernel/task_work.go
@@ -0,0 +1,38 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+import "sync/atomic"
+
+// TaskWorker is a deferred task.
+//
+// This must be savable.
+type TaskWorker interface {
+ // TaskWork will be executed prior to returning to user space. Note that
+ // TaskWork may call RegisterWork again, but this will not be executed until
+ // the next return to user space, unlike in Linux. This effectively allows
+ // registration of indefinite user return hooks, but not by default.
+ TaskWork(t *Task)
+}
+
+// RegisterWork can be used to register additional task work that will be
+// performed prior to returning to user space. See TaskWorker.TaskWork for
+// semantics regarding registration.
+func (t *Task) RegisterWork(work TaskWorker) {
+ t.taskWorkMu.Lock()
+ defer t.taskWorkMu.Unlock()
+ atomic.AddInt32(&t.taskWorkCount, 1)
+ t.taskWork = append(t.taskWork, work)
+}
diff --git a/pkg/sentry/kernel/time/BUILD b/pkg/sentry/kernel/time/BUILD
index 7ba7dc50c..2817aa3ba 100644
--- a/pkg/sentry/kernel/time/BUILD
+++ b/pkg/sentry/kernel/time/BUILD
@@ -6,6 +6,7 @@ go_library(
name = "time",
srcs = [
"context.go",
+ "tcpip.go",
"time.go",
],
visibility = ["//pkg/sentry:internal"],
diff --git a/pkg/sentry/kernel/time/tcpip.go b/pkg/sentry/kernel/time/tcpip.go
new file mode 100644
index 000000000..c4474c0cf
--- /dev/null
+++ b/pkg/sentry/kernel/time/tcpip.go
@@ -0,0 +1,131 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package time
+
+import (
+ "sync"
+ "time"
+)
+
+// TcpipAfterFunc waits for duration to elapse according to clock then runs fn.
+// The timer is started immediately and will fire exactly once.
+func TcpipAfterFunc(clock Clock, duration time.Duration, fn func()) *TcpipTimer {
+ timer := &TcpipTimer{
+ clock: clock,
+ }
+ timer.notifier = functionNotifier{
+ fn: func() {
+ // tcpip.Timer.Stop() explicitly states that the function is called in a
+ // separate goroutine that Stop() does not synchronize with.
+ // Timer.Destroy() synchronizes with calls to TimerListener.Notify().
+ // This is semantically meaningful because, in the former case, it's
+ // legal to call tcpip.Timer.Stop() while holding locks that may also be
+ // taken by the function, but this isn't so in the latter case. Most
+ // immediately, Timer calls TimerListener.Notify() while holding
+ // Timer.mu. A deadlock occurs without spawning a goroutine:
+ // T1: (Timer expires)
+ // => Timer.Tick() <- Timer.mu.Lock() called
+ // => TimerListener.Notify()
+ // => Timer.Stop()
+ // => Timer.Destroy() <- Timer.mu.Lock() called, deadlock!
+ //
+ // Spawning a goroutine avoids the deadlock:
+ // T1: (Timer expires)
+ // => Timer.Tick() <- Timer.mu.Lock() called
+ // => TimerListener.Notify() <- Launches T2
+ // T2:
+ // => Timer.Stop()
+ // => Timer.Destroy() <- Timer.mu.Lock() called, blocks
+ // T1:
+ // => (returns) <- Timer.mu.Unlock() called
+ // T2:
+ // => (continues) <- No deadlock!
+ go func() {
+ timer.Stop()
+ fn()
+ }()
+ },
+ }
+ timer.Reset(duration)
+ return timer
+}
+
+// TcpipTimer is a resettable timer with variable duration expirations.
+// Implements tcpip.Timer, which does not define a Destroy method; instead, all
+// resources are released after timer expiration and calls to Timer.Stop.
+//
+// Must be created by AfterFunc.
+type TcpipTimer struct {
+ // clock is the time source. clock is immutable.
+ clock Clock
+
+ // notifier is called when the Timer expires. notifier is immutable.
+ notifier functionNotifier
+
+ // mu protects t.
+ mu sync.Mutex
+
+ // t stores the latest running Timer. This is replaced whenever Reset is
+ // called since Timer cannot be restarted once it has been Destroyed by Stop.
+ //
+ // This field is nil iff Stop has been called.
+ t *Timer
+}
+
+// Stop implements tcpip.Timer.Stop.
+func (r *TcpipTimer) Stop() bool {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+
+ if r.t == nil {
+ return false
+ }
+ _, lastSetting := r.t.Swap(Setting{})
+ r.t.Destroy()
+ r.t = nil
+ return lastSetting.Enabled
+}
+
+// Reset implements tcpip.Timer.Reset.
+func (r *TcpipTimer) Reset(d time.Duration) {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+
+ if r.t == nil {
+ r.t = NewTimer(r.clock, &r.notifier)
+ }
+
+ r.t.Swap(Setting{
+ Enabled: true,
+ Period: 0,
+ Next: r.clock.Now().Add(d),
+ })
+}
+
+// functionNotifier is a TimerListener that runs a function.
+//
+// functionNotifier cannot be saved or loaded.
+type functionNotifier struct {
+ fn func()
+}
+
+// Notify implements ktime.TimerListener.Notify.
+func (f *functionNotifier) Notify(uint64, Setting) (Setting, bool) {
+ f.fn()
+ return Setting{}, false
+}
+
+// Destroy implements ktime.TimerListener.Destroy.
+func (f *functionNotifier) Destroy() {}
diff --git a/pkg/sentry/platform/ring0/kernel_arm64.go b/pkg/sentry/platform/ring0/kernel_arm64.go
index ccacaea6b..fca3a5478 100644
--- a/pkg/sentry/platform/ring0/kernel_arm64.go
+++ b/pkg/sentry/platform/ring0/kernel_arm64.go
@@ -58,7 +58,13 @@ func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) {
regs.Pstate &= ^uint64(UserFlagsClear)
regs.Pstate |= UserFlagsSet
+
+ SetTLS(regs.TPIDR_EL0)
+
kernelExitToEl0()
+
+ regs.TPIDR_EL0 = GetTLS()
+
vector = c.vecCode
// Perform the switch.
diff --git a/pkg/sentry/socket/BUILD b/pkg/sentry/socket/BUILD
index c40c6d673..c0fd3425b 100644
--- a/pkg/sentry/socket/BUILD
+++ b/pkg/sentry/socket/BUILD
@@ -20,5 +20,6 @@ go_library(
"//pkg/syserr",
"//pkg/tcpip",
"//pkg/usermem",
+ "//tools/go_marshal/marshal",
],
)
diff --git a/pkg/sentry/socket/hostinet/BUILD b/pkg/sentry/socket/hostinet/BUILD
index ff81ea6e6..e76e498de 100644
--- a/pkg/sentry/socket/hostinet/BUILD
+++ b/pkg/sentry/socket/hostinet/BUILD
@@ -40,6 +40,8 @@ go_library(
"//pkg/tcpip/stack",
"//pkg/usermem",
"//pkg/waiter",
+ "//tools/go_marshal/marshal",
+ "//tools/go_marshal/primitive",
"@org_golang_x_sys//unix:go_default_library",
],
)
diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go
index a92aed2c9..ec5506efc 100644
--- a/pkg/sentry/socket/hostinet/socket.go
+++ b/pkg/sentry/socket/hostinet/socket.go
@@ -36,6 +36,8 @@ import (
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
+ "gvisor.dev/gvisor/tools/go_marshal/marshal"
+ "gvisor.dev/gvisor/tools/go_marshal/primitive"
)
const (
@@ -319,7 +321,7 @@ func (s *socketOpsCommon) Shutdown(t *kernel.Task, how int) *syserr.Error {
}
// GetSockOpt implements socket.Socket.GetSockOpt.
-func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) {
+func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) {
if outLen < 0 {
return nil, syserr.ErrInvalidArgument
}
@@ -364,7 +366,8 @@ func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr
if err != nil {
return nil, syserr.FromError(err)
}
- return opt, nil
+ optP := primitive.ByteSlice(opt)
+ return &optP, nil
}
// SetSockOpt implements socket.Socket.SetSockOpt.
diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go
index f7abe77d3..d9394055d 100644
--- a/pkg/sentry/socket/netfilter/netfilter.go
+++ b/pkg/sentry/socket/netfilter/netfilter.go
@@ -145,7 +145,7 @@ func convertNetstackToBinary(stack *stack.Stack, tablename linux.TableName) (lin
// Each rule corresponds to an entry.
entry := linux.KernelIPTEntry{
- IPTEntry: linux.IPTEntry{
+ Entry: linux.IPTEntry{
IP: linux.IPTIP{
Protocol: uint16(rule.Filter.Protocol),
},
@@ -153,20 +153,20 @@ func convertNetstackToBinary(stack *stack.Stack, tablename linux.TableName) (lin
TargetOffset: linux.SizeOfIPTEntry,
},
}
- copy(entry.IPTEntry.IP.Dst[:], rule.Filter.Dst)
- copy(entry.IPTEntry.IP.DstMask[:], rule.Filter.DstMask)
- copy(entry.IPTEntry.IP.Src[:], rule.Filter.Src)
- copy(entry.IPTEntry.IP.SrcMask[:], rule.Filter.SrcMask)
- copy(entry.IPTEntry.IP.OutputInterface[:], rule.Filter.OutputInterface)
- copy(entry.IPTEntry.IP.OutputInterfaceMask[:], rule.Filter.OutputInterfaceMask)
+ copy(entry.Entry.IP.Dst[:], rule.Filter.Dst)
+ copy(entry.Entry.IP.DstMask[:], rule.Filter.DstMask)
+ copy(entry.Entry.IP.Src[:], rule.Filter.Src)
+ copy(entry.Entry.IP.SrcMask[:], rule.Filter.SrcMask)
+ copy(entry.Entry.IP.OutputInterface[:], rule.Filter.OutputInterface)
+ copy(entry.Entry.IP.OutputInterfaceMask[:], rule.Filter.OutputInterfaceMask)
if rule.Filter.DstInvert {
- entry.IPTEntry.IP.InverseFlags |= linux.IPT_INV_DSTIP
+ entry.Entry.IP.InverseFlags |= linux.IPT_INV_DSTIP
}
if rule.Filter.SrcInvert {
- entry.IPTEntry.IP.InverseFlags |= linux.IPT_INV_SRCIP
+ entry.Entry.IP.InverseFlags |= linux.IPT_INV_SRCIP
}
if rule.Filter.OutputInterfaceInvert {
- entry.IPTEntry.IP.InverseFlags |= linux.IPT_INV_VIA_OUT
+ entry.Entry.IP.InverseFlags |= linux.IPT_INV_VIA_OUT
}
for _, matcher := range rule.Matchers {
@@ -178,8 +178,8 @@ func convertNetstackToBinary(stack *stack.Stack, tablename linux.TableName) (lin
panic(fmt.Sprintf("matcher %T is not 64-bit aligned", matcher))
}
entry.Elems = append(entry.Elems, serialized...)
- entry.NextOffset += uint16(len(serialized))
- entry.TargetOffset += uint16(len(serialized))
+ entry.Entry.NextOffset += uint16(len(serialized))
+ entry.Entry.TargetOffset += uint16(len(serialized))
}
// Serialize and append the target.
@@ -188,11 +188,11 @@ func convertNetstackToBinary(stack *stack.Stack, tablename linux.TableName) (lin
panic(fmt.Sprintf("target %T is not 64-bit aligned", rule.Target))
}
entry.Elems = append(entry.Elems, serialized...)
- entry.NextOffset += uint16(len(serialized))
+ entry.Entry.NextOffset += uint16(len(serialized))
nflog("convert to binary: adding entry: %+v", entry)
- entries.Size += uint32(entry.NextOffset)
+ entries.Size += uint32(entry.Entry.NextOffset)
entries.Entrytable = append(entries.Entrytable, entry)
info.NumEntries++
}
@@ -342,10 +342,10 @@ func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error {
// TODO(gvisor.dev/issue/170): Support other tables.
var table stack.Table
switch replace.Name.String() {
- case stack.TablenameFilter:
+ case stack.FilterTable:
table = stack.EmptyFilterTable()
- case stack.TablenameNat:
- table = stack.EmptyNatTable()
+ case stack.NATTable:
+ table = stack.EmptyNATTable()
default:
nflog("we don't yet support writing to the %q table (gvisor.dev/issue/170)", replace.Name.String())
return syserr.ErrInvalidArgument
@@ -431,6 +431,8 @@ func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error {
for hook, _ := range replace.HookEntry {
if table.ValidHooks()&(1<<hook) != 0 {
hk := hookFromLinux(hook)
+ table.BuiltinChains[hk] = stack.HookUnset
+ table.Underflows[hk] = stack.HookUnset
for offset, ruleIdx := range offsets {
if offset == replace.HookEntry[hook] {
table.BuiltinChains[hk] = ruleIdx
@@ -456,8 +458,7 @@ func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error {
// Add the user chains.
for ruleIdx, rule := range table.Rules {
- target, ok := rule.Target.(stack.UserChainTarget)
- if !ok {
+ if _, ok := rule.Target.(stack.UserChainTarget); !ok {
continue
}
@@ -473,7 +474,6 @@ func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error {
nflog("user chain's first node must have no matchers")
return syserr.ErrInvalidArgument
}
- table.UserChains[target.Name] = ruleIdx + 1
}
// Set each jump to point to the appropriate rule. Right now they hold byte
@@ -499,7 +499,10 @@ func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error {
// Since we only support modifying the INPUT, PREROUTING and OUTPUT chain right now,
// make sure all other chains point to ACCEPT rules.
for hook, ruleIdx := range table.BuiltinChains {
- if hook == stack.Forward || hook == stack.Postrouting {
+ if hook := stack.Hook(hook); hook == stack.Forward || hook == stack.Postrouting {
+ if ruleIdx == stack.HookUnset {
+ continue
+ }
if !isUnconditionalAccept(table.Rules[ruleIdx]) {
nflog("hook %d is unsupported.", hook)
return syserr.ErrInvalidArgument
@@ -512,9 +515,7 @@ func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error {
// - There are no chains without an unconditional final rule.
// - There are no chains without an unconditional underflow rule.
- stk.IPTables().ReplaceTable(replace.Name.String(), table)
-
- return nil
+ return syserr.TranslateNetstackError(stk.IPTables().ReplaceTable(replace.Name.String(), table))
}
// parseMatchers parses 0 or more matchers from optVal. optVal should contain
diff --git a/pkg/sentry/socket/netlink/BUILD b/pkg/sentry/socket/netlink/BUILD
index d5ca3ac56..0546801bf 100644
--- a/pkg/sentry/socket/netlink/BUILD
+++ b/pkg/sentry/socket/netlink/BUILD
@@ -36,6 +36,8 @@ go_library(
"//pkg/tcpip",
"//pkg/usermem",
"//pkg/waiter",
+ "//tools/go_marshal/marshal",
+ "//tools/go_marshal/primitive",
],
)
diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go
index 81f34c5a2..98ca7add0 100644
--- a/pkg/sentry/socket/netlink/socket.go
+++ b/pkg/sentry/socket/netlink/socket.go
@@ -38,6 +38,8 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
+ "gvisor.dev/gvisor/tools/go_marshal/marshal"
+ "gvisor.dev/gvisor/tools/go_marshal/primitive"
)
const sizeOfInt32 int = 4
@@ -330,7 +332,7 @@ func (s *socketOpsCommon) Shutdown(t *kernel.Task, how int) *syserr.Error {
}
// GetSockOpt implements socket.Socket.GetSockOpt.
-func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) {
+func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) {
switch level {
case linux.SOL_SOCKET:
switch name {
@@ -340,24 +342,26 @@ func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr
}
s.mu.Lock()
defer s.mu.Unlock()
- return int32(s.sendBufferSize), nil
+ sendBufferSizeP := primitive.Int32(s.sendBufferSize)
+ return &sendBufferSizeP, nil
case linux.SO_RCVBUF:
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
// We don't have limit on receiving size.
- return int32(math.MaxInt32), nil
+ recvBufferSizeP := primitive.Int32(math.MaxInt32)
+ return &recvBufferSizeP, nil
case linux.SO_PASSCRED:
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
- var passcred int32
+ var passcred primitive.Int32
if s.Passcred() {
passcred = 1
}
- return passcred, nil
+ return &passcred, nil
default:
socket.GetSockOptEmitUnimplementedEvent(t, name)
diff --git a/pkg/sentry/socket/netstack/BUILD b/pkg/sentry/socket/netstack/BUILD
index ea6ebd0e2..1fb777a6c 100644
--- a/pkg/sentry/socket/netstack/BUILD
+++ b/pkg/sentry/socket/netstack/BUILD
@@ -51,6 +51,8 @@ go_library(
"//pkg/tcpip/transport/udp",
"//pkg/usermem",
"//pkg/waiter",
+ "//tools/go_marshal/marshal",
+ "//tools/go_marshal/primitive",
"@org_golang_x_sys//unix:go_default_library",
],
)
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index 49a04e613..9856ab8c5 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -26,6 +26,7 @@ package netstack
import (
"bytes"
+ "fmt"
"io"
"math"
"reflect"
@@ -61,6 +62,8 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
+ "gvisor.dev/gvisor/tools/go_marshal/marshal"
+ "gvisor.dev/gvisor/tools/go_marshal/primitive"
)
func mustCreateMetric(name, description string) *tcpip.StatCounter {
@@ -909,7 +912,7 @@ func (s *socketOpsCommon) Shutdown(t *kernel.Task, how int) *syserr.Error {
// GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by
// tcpip.Endpoint.
-func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) {
+func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) {
// TODO(b/78348848): Unlike other socket options, SO_TIMESTAMP is
// implemented specifically for netstack.SocketOperations rather than
// commonEndpoint. commonEndpoint should be extended to support socket
@@ -919,25 +922,25 @@ func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr us
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
- val := int32(0)
+ val := primitive.Int32(0)
s.readMu.Lock()
defer s.readMu.Unlock()
if s.sockOptTimestamp {
val = 1
}
- return val, nil
+ return &val, nil
}
if level == linux.SOL_TCP && name == linux.TCP_INQ {
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
- val := int32(0)
+ val := primitive.Int32(0)
s.readMu.Lock()
defer s.readMu.Unlock()
if s.sockOptInq {
val = 1
}
- return val, nil
+ return &val, nil
}
if s.skType == linux.SOCK_RAW && level == linux.IPPROTO_IP {
@@ -955,7 +958,7 @@ func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr us
if err != nil {
return nil, err
}
- return info, nil
+ return &info, nil
case linux.IPT_SO_GET_ENTRIES:
if outLen < linux.SizeOfIPTGetEntries {
@@ -970,7 +973,7 @@ func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr us
if err != nil {
return nil, err
}
- return entries, nil
+ return &entries, nil
}
}
@@ -980,7 +983,7 @@ func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr us
// GetSockOpt can be used to implement the linux syscall getsockopt(2) for
// sockets backed by a commonEndpoint.
-func GetSockOpt(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family int, skType linux.SockType, level, name, outLen int) (interface{}, *syserr.Error) {
+func GetSockOpt(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family int, skType linux.SockType, level, name, outLen int) (marshal.Marshallable, *syserr.Error) {
switch level {
case linux.SOL_SOCKET:
return getSockOptSocket(t, s, ep, family, skType, name, outLen)
@@ -1013,7 +1016,7 @@ func boolToInt32(v bool) int32 {
}
// getSockOptSocket implements GetSockOpt when level is SOL_SOCKET.
-func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family int, skType linux.SockType, name, outLen int) (interface{}, *syserr.Error) {
+func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family int, skType linux.SockType, name, outLen int) (marshal.Marshallable, *syserr.Error) {
// TODO(b/124056281): Stop rejecting short optLen values in getsockopt.
switch name {
case linux.SO_ERROR:
@@ -1024,9 +1027,12 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
// Get the last error and convert it.
err := ep.GetSockOpt(tcpip.ErrorOption{})
if err == nil {
- return int32(0), nil
+ optP := primitive.Int32(0)
+ return &optP, nil
}
- return int32(syserr.TranslateNetstackError(err).ToLinux().Number()), nil
+
+ optP := primitive.Int32(syserr.TranslateNetstackError(err).ToLinux().Number())
+ return &optP, nil
case linux.SO_PEERCRED:
if family != linux.AF_UNIX || outLen < syscall.SizeofUcred {
@@ -1034,11 +1040,12 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
}
tcred := t.Credentials()
- return syscall.Ucred{
- Pid: int32(t.ThreadGroup().ID()),
- Uid: uint32(tcred.EffectiveKUID.In(tcred.UserNamespace).OrOverflow()),
- Gid: uint32(tcred.EffectiveKGID.In(tcred.UserNamespace).OrOverflow()),
- }, nil
+ creds := linux.ControlMessageCredentials{
+ PID: int32(t.ThreadGroup().ID()),
+ UID: uint32(tcred.EffectiveKUID.In(tcred.UserNamespace).OrOverflow()),
+ GID: uint32(tcred.EffectiveKGID.In(tcred.UserNamespace).OrOverflow()),
+ }
+ return &creds, nil
case linux.SO_PASSCRED:
if outLen < sizeOfInt32 {
@@ -1049,7 +1056,9 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- return boolToInt32(v), nil
+
+ vP := primitive.Int32(boolToInt32(v))
+ return &vP, nil
case linux.SO_SNDBUF:
if outLen < sizeOfInt32 {
@@ -1065,7 +1074,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
size = math.MaxInt32
}
- return int32(size), nil
+ sizeP := primitive.Int32(size)
+ return &sizeP, nil
case linux.SO_RCVBUF:
if outLen < sizeOfInt32 {
@@ -1081,7 +1091,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
size = math.MaxInt32
}
- return int32(size), nil
+ sizeP := primitive.Int32(size)
+ return &sizeP, nil
case linux.SO_REUSEADDR:
if outLen < sizeOfInt32 {
@@ -1092,7 +1103,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- return boolToInt32(v), nil
+ vP := primitive.Int32(boolToInt32(v))
+ return &vP, nil
case linux.SO_REUSEPORT:
if outLen < sizeOfInt32 {
@@ -1103,7 +1115,9 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- return boolToInt32(v), nil
+
+ vP := primitive.Int32(boolToInt32(v))
+ return &vP, nil
case linux.SO_BINDTODEVICE:
var v tcpip.BindToDeviceOption
@@ -1111,7 +1125,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
return nil, syserr.TranslateNetstackError(err)
}
if v == 0 {
- return []byte{}, nil
+ var b primitive.ByteSlice
+ return &b, nil
}
if outLen < linux.IFNAMSIZ {
return nil, syserr.ErrInvalidArgument
@@ -1126,7 +1141,9 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
// interface was removed.
return nil, syserr.ErrUnknownDevice
}
- return append([]byte(nic.Name), 0), nil
+
+ name := primitive.ByteSlice(append([]byte(nic.Name), 0))
+ return &name, nil
case linux.SO_BROADCAST:
if outLen < sizeOfInt32 {
@@ -1137,7 +1154,9 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- return boolToInt32(v), nil
+
+ vP := primitive.Int32(boolToInt32(v))
+ return &vP, nil
case linux.SO_KEEPALIVE:
if outLen < sizeOfInt32 {
@@ -1148,13 +1167,17 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- return boolToInt32(v), nil
+
+ vP := primitive.Int32(boolToInt32(v))
+ return &vP, nil
case linux.SO_LINGER:
if outLen < linux.SizeOfLinger {
return nil, syserr.ErrInvalidArgument
}
- return linux.Linger{}, nil
+
+ linger := linux.Linger{}
+ return &linger, nil
case linux.SO_SNDTIMEO:
// TODO(igudger): Linux allows shorter lengths for partial results.
@@ -1162,7 +1185,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
return nil, syserr.ErrInvalidArgument
}
- return linux.NsecToTimeval(s.SendTimeout()), nil
+ sendTimeout := linux.NsecToTimeval(s.SendTimeout())
+ return &sendTimeout, nil
case linux.SO_RCVTIMEO:
// TODO(igudger): Linux allows shorter lengths for partial results.
@@ -1170,7 +1194,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
return nil, syserr.ErrInvalidArgument
}
- return linux.NsecToTimeval(s.RecvTimeout()), nil
+ recvTimeout := linux.NsecToTimeval(s.RecvTimeout())
+ return &recvTimeout, nil
case linux.SO_OOBINLINE:
if outLen < sizeOfInt32 {
@@ -1182,7 +1207,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
return nil, syserr.TranslateNetstackError(err)
}
- return int32(v), nil
+ vP := primitive.Int32(v)
+ return &vP, nil
case linux.SO_NO_CHECK:
if outLen < sizeOfInt32 {
@@ -1193,7 +1219,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- return boolToInt32(v), nil
+ vP := primitive.Int32(boolToInt32(v))
+ return &vP, nil
default:
socket.GetSockOptEmitUnimplementedEvent(t, name)
@@ -1202,7 +1229,7 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
}
// getSockOptTCP implements GetSockOpt when level is SOL_TCP.
-func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interface{}, *syserr.Error) {
+func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (marshal.Marshallable, *syserr.Error) {
switch name {
case linux.TCP_NODELAY:
if outLen < sizeOfInt32 {
@@ -1213,7 +1240,9 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- return boolToInt32(!v), nil
+
+ vP := primitive.Int32(boolToInt32(!v))
+ return &vP, nil
case linux.TCP_CORK:
if outLen < sizeOfInt32 {
@@ -1224,7 +1253,9 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- return boolToInt32(v), nil
+
+ vP := primitive.Int32(boolToInt32(v))
+ return &vP, nil
case linux.TCP_QUICKACK:
if outLen < sizeOfInt32 {
@@ -1235,7 +1266,9 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- return boolToInt32(v), nil
+
+ vP := primitive.Int32(boolToInt32(v))
+ return &vP, nil
case linux.TCP_MAXSEG:
if outLen < sizeOfInt32 {
@@ -1246,8 +1279,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
-
- return int32(v), nil
+ vP := primitive.Int32(v)
+ return &vP, nil
case linux.TCP_KEEPIDLE:
if outLen < sizeOfInt32 {
@@ -1258,8 +1291,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
if err := ep.GetSockOpt(&v); err != nil {
return nil, syserr.TranslateNetstackError(err)
}
-
- return int32(time.Duration(v) / time.Second), nil
+ keepAliveIdle := primitive.Int32(time.Duration(v) / time.Second)
+ return &keepAliveIdle, nil
case linux.TCP_KEEPINTVL:
if outLen < sizeOfInt32 {
@@ -1270,8 +1303,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
if err := ep.GetSockOpt(&v); err != nil {
return nil, syserr.TranslateNetstackError(err)
}
-
- return int32(time.Duration(v) / time.Second), nil
+ keepAliveInterval := primitive.Int32(time.Duration(v) / time.Second)
+ return &keepAliveInterval, nil
case linux.TCP_KEEPCNT:
if outLen < sizeOfInt32 {
@@ -1282,8 +1315,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
-
- return int32(v), nil
+ vP := primitive.Int32(v)
+ return &vP, nil
case linux.TCP_USER_TIMEOUT:
if outLen < sizeOfInt32 {
@@ -1294,8 +1327,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
if err := ep.GetSockOpt(&v); err != nil {
return nil, syserr.TranslateNetstackError(err)
}
-
- return int32(time.Duration(v) / time.Millisecond), nil
+ tcpUserTimeout := primitive.Int32(time.Duration(v) / time.Millisecond)
+ return &tcpUserTimeout, nil
case linux.TCP_INFO:
var v tcpip.TCPInfoOption
@@ -1308,12 +1341,13 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
info := linux.TCPInfo{}
// Linux truncates the output binary to outLen.
- ib := binary.Marshal(nil, usermem.ByteOrder, &info)
- if len(ib) > outLen {
- ib = ib[:outLen]
+ buf := t.CopyScratchBuffer(info.SizeBytes())
+ info.MarshalUnsafe(buf)
+ if len(buf) > outLen {
+ buf = buf[:outLen]
}
-
- return ib, nil
+ bufP := primitive.ByteSlice(buf)
+ return &bufP, nil
case linux.TCP_CC_INFO,
linux.TCP_NOTSENT_LOWAT,
@@ -1343,7 +1377,9 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
}
b := make([]byte, toCopy)
copy(b, v)
- return b, nil
+
+ bP := primitive.ByteSlice(b)
+ return &bP, nil
case linux.TCP_LINGER2:
if outLen < sizeOfInt32 {
@@ -1355,7 +1391,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
return nil, syserr.TranslateNetstackError(err)
}
- return int32(time.Duration(v) / time.Second), nil
+ lingerTimeout := primitive.Int32(time.Duration(v) / time.Second)
+ return &lingerTimeout, nil
case linux.TCP_DEFER_ACCEPT:
if outLen < sizeOfInt32 {
@@ -1367,7 +1404,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
return nil, syserr.TranslateNetstackError(err)
}
- return int32(time.Duration(v) / time.Second), nil
+ tcpDeferAccept := primitive.Int32(time.Duration(v) / time.Second)
+ return &tcpDeferAccept, nil
case linux.TCP_SYNCNT:
if outLen < sizeOfInt32 {
@@ -1378,8 +1416,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
-
- return int32(v), nil
+ vP := primitive.Int32(v)
+ return &vP, nil
case linux.TCP_WINDOW_CLAMP:
if outLen < sizeOfInt32 {
@@ -1390,8 +1428,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
-
- return int32(v), nil
+ vP := primitive.Int32(v)
+ return &vP, nil
default:
emitUnimplementedEventTCP(t, name)
}
@@ -1399,7 +1437,7 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
}
// getSockOptIPv6 implements GetSockOpt when level is SOL_IPV6.
-func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interface{}, *syserr.Error) {
+func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (marshal.Marshallable, *syserr.Error) {
switch name {
case linux.IPV6_V6ONLY:
if outLen < sizeOfInt32 {
@@ -1410,7 +1448,9 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- return boolToInt32(v), nil
+
+ vP := primitive.Int32(boolToInt32(v))
+ return &vP, nil
case linux.IPV6_PATHMTU:
t.Kernel().EmitUnimplementedEvent(t)
@@ -1418,21 +1458,24 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf
case linux.IPV6_TCLASS:
// Length handling for parity with Linux.
if outLen == 0 {
- return make([]byte, 0), nil
+ var b primitive.ByteSlice
+ return &b, nil
}
v, err := ep.GetSockOptInt(tcpip.IPv6TrafficClassOption)
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- uintv := uint32(v)
+ uintv := primitive.Uint32(v)
// Linux truncates the output binary to outLen.
- ib := binary.Marshal(nil, usermem.ByteOrder, &uintv)
+ ib := t.CopyScratchBuffer(uintv.SizeBytes())
+ uintv.MarshalUnsafe(ib)
// Handle cases where outLen is lesser than sizeOfInt32.
if len(ib) > outLen {
ib = ib[:outLen]
}
- return ib, nil
+ ibP := primitive.ByteSlice(ib)
+ return &ibP, nil
case linux.IPV6_RECVTCLASS:
if outLen < sizeOfInt32 {
@@ -1443,7 +1486,9 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- return boolToInt32(v), nil
+
+ vP := primitive.Int32(boolToInt32(v))
+ return &vP, nil
default:
emitUnimplementedEventIPv6(t, name)
@@ -1452,7 +1497,7 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf
}
// getSockOptIP implements GetSockOpt when level is SOL_IP.
-func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family int) (interface{}, *syserr.Error) {
+func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family int) (marshal.Marshallable, *syserr.Error) {
switch name {
case linux.IP_TTL:
if outLen < sizeOfInt32 {
@@ -1465,11 +1510,12 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in
}
// Fill in the default value, if needed.
- if v == 0 {
- v = DefaultTTL
+ vP := primitive.Int32(v)
+ if vP == 0 {
+ vP = DefaultTTL
}
- return int32(v), nil
+ return &vP, nil
case linux.IP_MULTICAST_TTL:
if outLen < sizeOfInt32 {
@@ -1481,7 +1527,8 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in
return nil, syserr.TranslateNetstackError(err)
}
- return int32(v), nil
+ vP := primitive.Int32(v)
+ return &vP, nil
case linux.IP_MULTICAST_IF:
if outLen < len(linux.InetAddr{}) {
@@ -1495,7 +1542,7 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in
a, _ := ConvertAddress(linux.AF_INET, tcpip.FullAddress{Addr: v.InterfaceAddr})
- return a.(*linux.SockAddrInet).Addr, nil
+ return &a.(*linux.SockAddrInet).Addr, nil
case linux.IP_MULTICAST_LOOP:
if outLen < sizeOfInt32 {
@@ -1506,21 +1553,26 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- return boolToInt32(v), nil
+
+ vP := primitive.Int32(boolToInt32(v))
+ return &vP, nil
case linux.IP_TOS:
// Length handling for parity with Linux.
if outLen == 0 {
- return []byte(nil), nil
+ var b primitive.ByteSlice
+ return &b, nil
}
v, err := ep.GetSockOptInt(tcpip.IPv4TOSOption)
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
if outLen < sizeOfInt32 {
- return uint8(v), nil
+ vP := primitive.Uint8(v)
+ return &vP, nil
}
- return int32(v), nil
+ vP := primitive.Int32(v)
+ return &vP, nil
case linux.IP_RECVTOS:
if outLen < sizeOfInt32 {
@@ -1531,7 +1583,9 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- return boolToInt32(v), nil
+
+ vP := primitive.Int32(boolToInt32(v))
+ return &vP, nil
case linux.IP_PKTINFO:
if outLen < sizeOfInt32 {
@@ -1542,7 +1596,9 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- return boolToInt32(v), nil
+
+ vP := primitive.Int32(boolToInt32(v))
+ return &vP, nil
default:
emitUnimplementedEventIP(t, name)
@@ -2468,6 +2524,23 @@ func (s *socketOpsCommon) fillCmsgInq(cmsg *socket.ControlMessages) {
cmsg.IP.Inq = int32(len(s.readView) + rcvBufUsed)
}
+func toLinuxPacketType(pktType tcpip.PacketType) uint8 {
+ switch pktType {
+ case tcpip.PacketHost:
+ return linux.PACKET_HOST
+ case tcpip.PacketOtherHost:
+ return linux.PACKET_OTHERHOST
+ case tcpip.PacketOutgoing:
+ return linux.PACKET_OUTGOING
+ case tcpip.PacketBroadcast:
+ return linux.PACKET_BROADCAST
+ case tcpip.PacketMulticast:
+ return linux.PACKET_MULTICAST
+ default:
+ panic(fmt.Sprintf("unknown packet type: %d", pktType))
+ }
+}
+
// nonBlockingRead issues a non-blocking read.
//
// TODO(b/78348848): Support timestamps for stream sockets.
@@ -2526,6 +2599,7 @@ func (s *socketOpsCommon) nonBlockingRead(ctx context.Context, dst usermem.IOSeq
switch v := addr.(type) {
case *linux.SockAddrLink:
v.Protocol = htons(uint16(s.linkPacketInfo.Protocol))
+ v.PacketType = toLinuxPacketType(s.linkPacketInfo.PktType)
}
}
diff --git a/pkg/sentry/socket/netstack/netstack_vfs2.go b/pkg/sentry/socket/netstack/netstack_vfs2.go
index d65a89316..a9025b0ec 100644
--- a/pkg/sentry/socket/netstack/netstack_vfs2.go
+++ b/pkg/sentry/socket/netstack/netstack_vfs2.go
@@ -31,6 +31,8 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
+ "gvisor.dev/gvisor/tools/go_marshal/marshal"
+ "gvisor.dev/gvisor/tools/go_marshal/primitive"
)
// SocketVFS2 encapsulates all the state needed to represent a network stack
@@ -200,7 +202,7 @@ func (s *SocketVFS2) Ioctl(ctx context.Context, uio usermem.IO, args arch.Syscal
// GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by
// tcpip.Endpoint.
-func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) {
+func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) {
// TODO(b/78348848): Unlike other socket options, SO_TIMESTAMP is
// implemented specifically for netstack.SocketVFS2 rather than
// commonEndpoint. commonEndpoint should be extended to support socket
@@ -210,25 +212,25 @@ func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
- val := int32(0)
+ val := primitive.Int32(0)
s.readMu.Lock()
defer s.readMu.Unlock()
if s.sockOptTimestamp {
val = 1
}
- return val, nil
+ return &val, nil
}
if level == linux.SOL_TCP && name == linux.TCP_INQ {
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
- val := int32(0)
+ val := primitive.Int32(0)
s.readMu.Lock()
defer s.readMu.Unlock()
if s.sockOptInq {
val = 1
}
- return val, nil
+ return &val, nil
}
if s.skType == linux.SOCK_RAW && level == linux.IPPROTO_IP {
@@ -246,7 +248,7 @@ func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.
if err != nil {
return nil, err
}
- return info, nil
+ return &info, nil
case linux.IPT_SO_GET_ENTRIES:
if outLen < linux.SizeOfIPTGetEntries {
@@ -261,7 +263,7 @@ func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.
if err != nil {
return nil, err
}
- return entries, nil
+ return &entries, nil
}
}
diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go
index fcd7f9d7f..d112757fb 100644
--- a/pkg/sentry/socket/socket.go
+++ b/pkg/sentry/socket/socket.go
@@ -35,6 +35,7 @@ import (
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/tools/go_marshal/marshal"
)
// ControlMessages represents the union of unix control messages and tcpip
@@ -86,7 +87,7 @@ type SocketOps interface {
Shutdown(t *kernel.Task, how int) *syserr.Error
// GetSockOpt implements the getsockopt(2) linux syscall.
- GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error)
+ GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error)
// SetSockOpt implements the setsockopt(2) linux syscall.
SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *syserr.Error
diff --git a/pkg/sentry/socket/unix/BUILD b/pkg/sentry/socket/unix/BUILD
index cca5e70f1..061a689a9 100644
--- a/pkg/sentry/socket/unix/BUILD
+++ b/pkg/sentry/socket/unix/BUILD
@@ -35,5 +35,6 @@ go_library(
"//pkg/tcpip",
"//pkg/usermem",
"//pkg/waiter",
+ "//tools/go_marshal/marshal",
],
)
diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go
index 4bb2b6ff4..0482d33cf 100644
--- a/pkg/sentry/socket/unix/unix.go
+++ b/pkg/sentry/socket/unix/unix.go
@@ -40,6 +40,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
+ "gvisor.dev/gvisor/tools/go_marshal/marshal"
)
// SocketOperations is a Unix socket. It is similar to a netstack socket,
@@ -184,7 +185,7 @@ func (s *SocketOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO,
// GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by
// a transport.Endpoint.
-func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) {
+func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) {
return netstack.GetSockOpt(t, s, s.ep, linux.AF_UNIX, s.ep.Type(), level, name, outLen)
}
diff --git a/pkg/sentry/socket/unix/unix_vfs2.go b/pkg/sentry/socket/unix/unix_vfs2.go
index ff2149250..05c16fcfe 100644
--- a/pkg/sentry/socket/unix/unix_vfs2.go
+++ b/pkg/sentry/socket/unix/unix_vfs2.go
@@ -32,6 +32,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
+ "gvisor.dev/gvisor/tools/go_marshal/marshal"
)
// SocketVFS2 implements socket.SocketVFS2 (and by extension,
@@ -89,7 +90,7 @@ func NewFileDescription(ep transport.Endpoint, stype linux.SockType, flags uint3
// GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by
// a transport.Endpoint.
-func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) {
+func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) {
return netstack.GetSockOpt(t, s, s.ep, linux.AF_UNIX, s.ep.Type(), level, name, outLen)
}
diff --git a/pkg/sentry/syscalls/linux/BUILD b/pkg/sentry/syscalls/linux/BUILD
index 217fcfef2..4a9b04fd0 100644
--- a/pkg/sentry/syscalls/linux/BUILD
+++ b/pkg/sentry/syscalls/linux/BUILD
@@ -99,5 +99,7 @@ go_library(
"//pkg/syserror",
"//pkg/usermem",
"//pkg/waiter",
+ "//tools/go_marshal/marshal",
+ "//tools/go_marshal/primitive",
],
)
diff --git a/pkg/sentry/syscalls/linux/linux64.go b/pkg/sentry/syscalls/linux/linux64.go
index ea4f9b1a7..80c65164a 100644
--- a/pkg/sentry/syscalls/linux/linux64.go
+++ b/pkg/sentry/syscalls/linux/linux64.go
@@ -325,8 +325,8 @@ var AMD64 = &kernel.SyscallTable{
270: syscalls.Supported("pselect", Pselect),
271: syscalls.Supported("ppoll", Ppoll),
272: syscalls.PartiallySupported("unshare", Unshare, "Mount, cgroup namespaces not supported. Network namespaces supported but must be empty.", nil),
- 273: syscalls.Error("set_robust_list", syserror.ENOSYS, "Obsolete.", nil),
- 274: syscalls.Error("get_robust_list", syserror.ENOSYS, "Obsolete.", nil),
+ 273: syscalls.Supported("set_robust_list", SetRobustList),
+ 274: syscalls.Supported("get_robust_list", GetRobustList),
275: syscalls.Supported("splice", Splice),
276: syscalls.Supported("tee", Tee),
277: syscalls.PartiallySupported("sync_file_range", SyncFileRange, "Full data flush is not guaranteed at this time.", nil),
diff --git a/pkg/sentry/syscalls/linux/sys_futex.go b/pkg/sentry/syscalls/linux/sys_futex.go
index b68261f72..f04d78856 100644
--- a/pkg/sentry/syscalls/linux/sys_futex.go
+++ b/pkg/sentry/syscalls/linux/sys_futex.go
@@ -198,7 +198,7 @@ func Futex(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
switch cmd {
case linux.FUTEX_WAIT:
// WAIT uses a relative timeout.
- mask = ^uint32(0)
+ mask = linux.FUTEX_BITSET_MATCH_ANY
var timeoutDur time.Duration
if !forever {
timeoutDur = time.Duration(timespec.ToNsecCapped()) * time.Nanosecond
@@ -286,3 +286,49 @@ func Futex(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
return 0, nil, syserror.ENOSYS
}
}
+
+// SetRobustList implements linux syscall set_robust_list(2).
+func SetRobustList(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ // Despite the syscall using the name 'pid' for this variable, it is
+ // very much a tid.
+ head := args[0].Pointer()
+ length := args[1].SizeT()
+
+ if length != uint(linux.SizeOfRobustListHead) {
+ return 0, nil, syserror.EINVAL
+ }
+ t.SetRobustList(head)
+ return 0, nil, nil
+}
+
+// GetRobustList implements linux syscall get_robust_list(2).
+func GetRobustList(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ // Despite the syscall using the name 'pid' for this variable, it is
+ // very much a tid.
+ tid := args[0].Int()
+ head := args[1].Pointer()
+ size := args[2].Pointer()
+
+ if tid < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ ot := t
+ if tid != 0 {
+ if ot = t.PIDNamespace().TaskWithID(kernel.ThreadID(tid)); ot == nil {
+ return 0, nil, syserror.ESRCH
+ }
+ }
+
+ // Copy out head pointer.
+ if _, err := t.CopyOut(head, uint64(ot.GetRobustList())); err != nil {
+ return 0, nil, err
+ }
+
+ // Copy out size, which is a constant.
+ if _, err := t.CopyOut(size, uint64(linux.SizeOfRobustListHead)); err != nil {
+ return 0, nil, err
+ }
+
+ return 0, nil, nil
+}
diff --git a/pkg/sentry/syscalls/linux/sys_socket.go b/pkg/sentry/syscalls/linux/sys_socket.go
index 0760af77b..414fce8e3 100644
--- a/pkg/sentry/syscalls/linux/sys_socket.go
+++ b/pkg/sentry/syscalls/linux/sys_socket.go
@@ -29,6 +29,8 @@ import (
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/tools/go_marshal/marshal"
+ "gvisor.dev/gvisor/tools/go_marshal/primitive"
)
// LINT.IfChange
@@ -474,7 +476,7 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
}
if v != nil {
- if _, err := t.CopyOut(optValAddr, v); err != nil {
+ if _, err := v.CopyOut(t, optValAddr); err != nil {
return 0, nil, err
}
}
@@ -484,7 +486,7 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
// getSockOpt tries to handle common socket options, or dispatches to a specific
// socket implementation.
-func getSockOpt(t *kernel.Task, s socket.Socket, level, name int, optValAddr usermem.Addr, len int) (interface{}, *syserr.Error) {
+func getSockOpt(t *kernel.Task, s socket.Socket, level, name int, optValAddr usermem.Addr, len int) (marshal.Marshallable, *syserr.Error) {
if level == linux.SOL_SOCKET {
switch name {
case linux.SO_TYPE, linux.SO_DOMAIN, linux.SO_PROTOCOL:
@@ -496,13 +498,16 @@ func getSockOpt(t *kernel.Task, s socket.Socket, level, name int, optValAddr use
switch name {
case linux.SO_TYPE:
_, skType, _ := s.Type()
- return int32(skType), nil
+ v := primitive.Int32(skType)
+ return &v, nil
case linux.SO_DOMAIN:
family, _, _ := s.Type()
- return int32(family), nil
+ v := primitive.Int32(family)
+ return &v, nil
case linux.SO_PROTOCOL:
_, _, protocol := s.Type()
- return int32(protocol), nil
+ v := primitive.Int32(protocol)
+ return &v, nil
}
}
@@ -539,7 +544,7 @@ func SetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
return 0, nil, syserror.EINVAL
}
buf := t.CopyScratchBuffer(int(optLen))
- if _, err := t.CopyIn(optValAddr, &buf); err != nil {
+ if _, err := t.CopyInBytes(optValAddr, buf); err != nil {
return 0, nil, err
}
diff --git a/pkg/sentry/syscalls/linux/vfs2/BUILD b/pkg/sentry/syscalls/linux/vfs2/BUILD
index 0c740335b..64696b438 100644
--- a/pkg/sentry/syscalls/linux/vfs2/BUILD
+++ b/pkg/sentry/syscalls/linux/vfs2/BUILD
@@ -72,5 +72,7 @@ go_library(
"//pkg/syserror",
"//pkg/usermem",
"//pkg/waiter",
+ "//tools/go_marshal/marshal",
+ "//tools/go_marshal/primitive",
],
)
diff --git a/pkg/sentry/syscalls/linux/vfs2/mount.go b/pkg/sentry/syscalls/linux/vfs2/mount.go
index adeaa39cc..ea337de7c 100644
--- a/pkg/sentry/syscalls/linux/vfs2/mount.go
+++ b/pkg/sentry/syscalls/linux/vfs2/mount.go
@@ -77,8 +77,7 @@ func Mount(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
// Silently allow MS_NOSUID, since we don't implement set-id bits
// anyway.
- const unsupportedFlags = linux.MS_NODEV |
- linux.MS_NODIRATIME | linux.MS_STRICTATIME
+ const unsupportedFlags = linux.MS_NODIRATIME | linux.MS_STRICTATIME
// Linux just allows passing any flags to mount(2) - it won't fail when
// unknown or unsupported flags are passed. Since we don't implement
@@ -94,6 +93,12 @@ func Mount(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
if flags&linux.MS_NOEXEC == linux.MS_NOEXEC {
opts.Flags.NoExec = true
}
+ if flags&linux.MS_NODEV == linux.MS_NODEV {
+ opts.Flags.NoDev = true
+ }
+ if flags&linux.MS_NOSUID == linux.MS_NOSUID {
+ opts.Flags.NoSUID = true
+ }
if flags&linux.MS_RDONLY == linux.MS_RDONLY {
opts.ReadOnly = true
}
diff --git a/pkg/sentry/syscalls/linux/vfs2/setstat.go b/pkg/sentry/syscalls/linux/vfs2/setstat.go
index 09ecfed26..6daedd173 100644
--- a/pkg/sentry/syscalls/linux/vfs2/setstat.go
+++ b/pkg/sentry/syscalls/linux/vfs2/setstat.go
@@ -178,6 +178,7 @@ func Truncate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
Mask: linux.STATX_SIZE,
Size: uint64(length),
},
+ NeedWritePerm: true,
})
return 0, nil, handleSetSizeError(t, err)
}
@@ -197,6 +198,10 @@ func Ftruncate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
}
defer file.DecRef()
+ if !file.IsWritable() {
+ return 0, nil, syserror.EINVAL
+ }
+
err := file.SetStat(t, vfs.SetStatOptions{
Stat: linux.Statx{
Mask: linux.STATX_SIZE,
diff --git a/pkg/sentry/syscalls/linux/vfs2/socket.go b/pkg/sentry/syscalls/linux/vfs2/socket.go
index 10b668477..8096a8f9c 100644
--- a/pkg/sentry/syscalls/linux/vfs2/socket.go
+++ b/pkg/sentry/syscalls/linux/vfs2/socket.go
@@ -30,6 +30,8 @@ import (
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/tools/go_marshal/marshal"
+ "gvisor.dev/gvisor/tools/go_marshal/primitive"
)
// minListenBacklog is the minimum reasonable backlog for listening sockets.
@@ -477,7 +479,7 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
}
if v != nil {
- if _, err := t.CopyOut(optValAddr, v); err != nil {
+ if _, err := v.CopyOut(t, optValAddr); err != nil {
return 0, nil, err
}
}
@@ -487,7 +489,7 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
// getSockOpt tries to handle common socket options, or dispatches to a specific
// socket implementation.
-func getSockOpt(t *kernel.Task, s socket.SocketVFS2, level, name int, optValAddr usermem.Addr, len int) (interface{}, *syserr.Error) {
+func getSockOpt(t *kernel.Task, s socket.SocketVFS2, level, name int, optValAddr usermem.Addr, len int) (marshal.Marshallable, *syserr.Error) {
if level == linux.SOL_SOCKET {
switch name {
case linux.SO_TYPE, linux.SO_DOMAIN, linux.SO_PROTOCOL:
@@ -499,13 +501,16 @@ func getSockOpt(t *kernel.Task, s socket.SocketVFS2, level, name int, optValAddr
switch name {
case linux.SO_TYPE:
_, skType, _ := s.Type()
- return int32(skType), nil
+ v := primitive.Int32(skType)
+ return &v, nil
case linux.SO_DOMAIN:
family, _, _ := s.Type()
- return int32(family), nil
+ v := primitive.Int32(family)
+ return &v, nil
case linux.SO_PROTOCOL:
_, _, protocol := s.Type()
- return int32(protocol), nil
+ v := primitive.Int32(protocol)
+ return &v, nil
}
}
@@ -542,7 +547,7 @@ func SetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
return 0, nil, syserror.EINVAL
}
buf := t.CopyScratchBuffer(int(optLen))
- if _, err := t.CopyIn(optValAddr, &buf); err != nil {
+ if _, err := t.CopyInBytes(optValAddr, buf); err != nil {
return 0, nil, err
}
diff --git a/pkg/sentry/syscalls/linux/vfs2/splice.go b/pkg/sentry/syscalls/linux/vfs2/splice.go
index 945a364a7..63ab11f8c 100644
--- a/pkg/sentry/syscalls/linux/vfs2/splice.go
+++ b/pkg/sentry/syscalls/linux/vfs2/splice.go
@@ -15,12 +15,15 @@
package vfs2
import (
+ "io"
+
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/pipe"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -110,16 +113,20 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
// Move data.
var (
- n int64
- err error
- inCh chan struct{}
- outCh chan struct{}
+ n int64
+ err error
)
+ dw := dualWaiter{
+ inFile: inFile,
+ outFile: outFile,
+ }
+ defer dw.destroy()
for {
// If both input and output are pipes, delegate to the pipe
- // implementation. Otherwise, exactly one end is a pipe, which we
- // ensure is consistently ordered after the non-pipe FD's locks by
- // passing the pipe FD as usermem.IO to the non-pipe end.
+ // implementation. Otherwise, exactly one end is a pipe, which
+ // we ensure is consistently ordered after the non-pipe FD's
+ // locks by passing the pipe FD as usermem.IO to the non-pipe
+ // end.
switch {
case inIsPipe && outIsPipe:
n, err = pipe.Splice(t, outPipeFD, inPipeFD, count)
@@ -137,38 +144,15 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
} else {
n, err = inFile.Read(t, outPipeFD.IOSequence(count), vfs.ReadOptions{})
}
+ default:
+ panic("not possible")
}
+
if n != 0 || err != syserror.ErrWouldBlock || nonBlock {
break
}
-
- // Note that the blocking behavior here is a bit different than the
- // normal pattern. Because we need to have both data to read and data
- // to write simultaneously, we actually explicitly block on both of
- // these cases in turn before returning to the splice operation.
- if inFile.Readiness(eventMaskRead)&eventMaskRead == 0 {
- if inCh == nil {
- inCh = make(chan struct{}, 1)
- inW, _ := waiter.NewChannelEntry(inCh)
- inFile.EventRegister(&inW, eventMaskRead)
- defer inFile.EventUnregister(&inW)
- continue // Need to refresh readiness.
- }
- if err = t.Block(inCh); err != nil {
- break
- }
- }
- if outFile.Readiness(eventMaskWrite)&eventMaskWrite == 0 {
- if outCh == nil {
- outCh = make(chan struct{}, 1)
- outW, _ := waiter.NewChannelEntry(outCh)
- outFile.EventRegister(&outW, eventMaskWrite)
- defer outFile.EventUnregister(&outW)
- continue // Need to refresh readiness.
- }
- if err = t.Block(outCh); err != nil {
- break
- }
+ if err = dw.waitForBoth(t); err != nil {
+ break
}
}
@@ -247,45 +231,256 @@ func Tee(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallCo
// Copy data.
var (
- inCh chan struct{}
- outCh chan struct{}
+ n int64
+ err error
)
+ dw := dualWaiter{
+ inFile: inFile,
+ outFile: outFile,
+ }
+ defer dw.destroy()
for {
- n, err := pipe.Tee(t, outPipeFD, inPipeFD, count)
- if n != 0 {
- return uintptr(n), nil, nil
+ n, err = pipe.Tee(t, outPipeFD, inPipeFD, count)
+ if n != 0 || err != syserror.ErrWouldBlock || nonBlock {
+ break
+ }
+ if err = dw.waitForBoth(t); err != nil {
+ break
+ }
+ }
+ if n == 0 {
+ return 0, nil, err
+ }
+ outFile.Dentry().InotifyWithParent(linux.IN_MODIFY, 0, vfs.PathEvent)
+ return uintptr(n), nil, nil
+}
+
+// Sendfile implements linux system call sendfile(2).
+func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ outFD := args[0].Int()
+ inFD := args[1].Int()
+ offsetAddr := args[2].Pointer()
+ count := int64(args[3].SizeT())
+
+ inFile := t.GetFileVFS2(inFD)
+ if inFile == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer inFile.DecRef()
+ if !inFile.IsReadable() {
+ return 0, nil, syserror.EBADF
+ }
+
+ outFile := t.GetFileVFS2(outFD)
+ if outFile == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer outFile.DecRef()
+ if !outFile.IsWritable() {
+ return 0, nil, syserror.EBADF
+ }
+
+ // Verify that the outFile Append flag is not set.
+ if outFile.StatusFlags()&linux.O_APPEND != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Verify that inFile is a regular file or block device. This is a
+ // requirement; the same check appears in Linux
+ // (fs/splice.c:splice_direct_to_actor).
+ if stat, err := inFile.Stat(t, vfs.StatOptions{Mask: linux.STATX_TYPE}); err != nil {
+ return 0, nil, err
+ } else if stat.Mask&linux.STATX_TYPE == 0 ||
+ (stat.Mode&linux.S_IFMT != linux.S_IFREG && stat.Mode&linux.S_IFMT != linux.S_IFBLK) {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Copy offset if it exists.
+ offset := int64(-1)
+ if offsetAddr != 0 {
+ if inFile.Options().DenyPRead {
+ return 0, nil, syserror.ESPIPE
}
- if err != syserror.ErrWouldBlock || nonBlock {
+ if _, err := t.CopyIn(offsetAddr, &offset); err != nil {
return 0, nil, err
}
+ if offset < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ if offset+count < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ }
+
+ // Validate count. This must come after offset checks.
+ if count < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ if count == 0 {
+ return 0, nil, nil
+ }
+ if count > int64(kernel.MAX_RW_COUNT) {
+ count = int64(kernel.MAX_RW_COUNT)
+ }
- // Note that the blocking behavior here is a bit different than the
- // normal pattern. Because we need to have both data to read and data
- // to write simultaneously, we actually explicitly block on both of
- // these cases in turn before returning to the tee operation.
- if inFile.Readiness(eventMaskRead)&eventMaskRead == 0 {
- if inCh == nil {
- inCh = make(chan struct{}, 1)
- inW, _ := waiter.NewChannelEntry(inCh)
- inFile.EventRegister(&inW, eventMaskRead)
- defer inFile.EventUnregister(&inW)
- continue // Need to refresh readiness.
+ // Copy data.
+ var (
+ n int64
+ err error
+ )
+ dw := dualWaiter{
+ inFile: inFile,
+ outFile: outFile,
+ }
+ defer dw.destroy()
+ outPipeFD, outIsPipe := outFile.Impl().(*pipe.VFSPipeFD)
+ // Reading from input file should never block, since it is regular or
+ // block device. We only need to check if writing to the output file
+ // can block.
+ nonBlock := outFile.StatusFlags()&linux.O_NONBLOCK != 0
+ if outIsPipe {
+ for n < count {
+ var spliceN int64
+ if offset != -1 {
+ spliceN, err = inFile.PRead(t, outPipeFD.IOSequence(count), offset, vfs.ReadOptions{})
+ offset += spliceN
+ } else {
+ spliceN, err = inFile.Read(t, outPipeFD.IOSequence(count), vfs.ReadOptions{})
}
- if err := t.Block(inCh); err != nil {
- return 0, nil, err
+ n += spliceN
+ if err == syserror.ErrWouldBlock && !nonBlock {
+ err = dw.waitForBoth(t)
+ }
+ if err != nil {
+ break
}
}
- if outFile.Readiness(eventMaskWrite)&eventMaskWrite == 0 {
- if outCh == nil {
- outCh = make(chan struct{}, 1)
- outW, _ := waiter.NewChannelEntry(outCh)
- outFile.EventRegister(&outW, eventMaskWrite)
- defer outFile.EventUnregister(&outW)
- continue // Need to refresh readiness.
+ } else {
+ // Read inFile to buffer, then write the contents to outFile.
+ buf := make([]byte, count)
+ for n < count {
+ var readN int64
+ if offset != -1 {
+ readN, err = inFile.PRead(t, usermem.BytesIOSequence(buf), offset, vfs.ReadOptions{})
+ offset += readN
+ } else {
+ readN, err = inFile.Read(t, usermem.BytesIOSequence(buf), vfs.ReadOptions{})
+ }
+ if readN == 0 && err == io.EOF {
+ // We reached the end of the file. Eat the
+ // error and exit the loop.
+ err = nil
+ break
}
- if err := t.Block(outCh); err != nil {
- return 0, nil, err
+ n += readN
+ if err != nil {
+ break
+ }
+
+ // Write all of the bytes that we read. This may need
+ // multiple write calls to complete.
+ wbuf := buf[:n]
+ for len(wbuf) > 0 {
+ var writeN int64
+ writeN, err = outFile.Write(t, usermem.BytesIOSequence(wbuf), vfs.WriteOptions{})
+ wbuf = wbuf[writeN:]
+ if err == syserror.ErrWouldBlock && !nonBlock {
+ err = dw.waitForOut(t)
+ }
+ if err != nil {
+ // We didn't complete the write. Only
+ // report the bytes that were actually
+ // written, and rewind the offset.
+ notWritten := int64(len(wbuf))
+ n -= notWritten
+ if offset != -1 {
+ offset -= notWritten
+ }
+ break
+ }
+ }
+ if err == syserror.ErrWouldBlock && !nonBlock {
+ err = dw.waitForBoth(t)
}
+ if err != nil {
+ break
+ }
+ }
+ }
+
+ if offsetAddr != 0 {
+ // Copy out the new offset.
+ if _, err := t.CopyOut(offsetAddr, offset); err != nil {
+ return 0, nil, err
+ }
+ }
+
+ if n == 0 {
+ return 0, nil, err
+ }
+
+ inFile.Dentry().InotifyWithParent(linux.IN_ACCESS, 0, vfs.PathEvent)
+ outFile.Dentry().InotifyWithParent(linux.IN_MODIFY, 0, vfs.PathEvent)
+ return uintptr(n), nil, nil
+}
+
+// dualWaiter is used to wait on one or both vfs.FileDescriptions. It is not
+// thread-safe, and does not take a reference on the vfs.FileDescriptions.
+//
+// Users must call destroy() when finished.
+type dualWaiter struct {
+ inFile *vfs.FileDescription
+ outFile *vfs.FileDescription
+
+ inW waiter.Entry
+ inCh chan struct{}
+ outW waiter.Entry
+ outCh chan struct{}
+}
+
+// waitForBoth waits for both dw.inFile and dw.outFile to be ready.
+func (dw *dualWaiter) waitForBoth(t *kernel.Task) error {
+ if dw.inFile.Readiness(eventMaskRead)&eventMaskRead == 0 {
+ if dw.inCh == nil {
+ dw.inW, dw.inCh = waiter.NewChannelEntry(nil)
+ dw.inFile.EventRegister(&dw.inW, eventMaskRead)
+ // We might be ready now. Try again before blocking.
+ return nil
+ }
+ if err := t.Block(dw.inCh); err != nil {
+ return err
+ }
+ }
+ return dw.waitForOut(t)
+}
+
+// waitForOut waits for dw.outfile to be read.
+func (dw *dualWaiter) waitForOut(t *kernel.Task) error {
+ if dw.outFile.Readiness(eventMaskWrite)&eventMaskWrite == 0 {
+ if dw.outCh == nil {
+ dw.outW, dw.outCh = waiter.NewChannelEntry(nil)
+ dw.outFile.EventRegister(&dw.outW, eventMaskWrite)
+ // We might be ready now. Try again before blocking.
+ return nil
}
+ if err := t.Block(dw.outCh); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+// destroy cleans up resources help by dw. No more calls to wait* can occur
+// after destroy is called.
+func (dw *dualWaiter) destroy() {
+ if dw.inCh != nil {
+ dw.inFile.EventUnregister(&dw.inW)
+ dw.inCh = nil
+ }
+ if dw.outCh != nil {
+ dw.outFile.EventUnregister(&dw.outW)
+ dw.outCh = nil
}
+ dw.inFile = nil
+ dw.outFile = nil
}
diff --git a/pkg/sentry/syscalls/linux/vfs2/vfs2.go b/pkg/sentry/syscalls/linux/vfs2/vfs2.go
index 8f497ecc7..1b2cfad7d 100644
--- a/pkg/sentry/syscalls/linux/vfs2/vfs2.go
+++ b/pkg/sentry/syscalls/linux/vfs2/vfs2.go
@@ -44,7 +44,7 @@ func Override() {
s.Table[23] = syscalls.Supported("select", Select)
s.Table[32] = syscalls.Supported("dup", Dup)
s.Table[33] = syscalls.Supported("dup2", Dup2)
- delete(s.Table, 40) // sendfile
+ s.Table[40] = syscalls.Supported("sendfile", Sendfile)
s.Table[41] = syscalls.Supported("socket", Socket)
s.Table[42] = syscalls.Supported("connect", Connect)
s.Table[43] = syscalls.Supported("accept", Accept)
diff --git a/pkg/sentry/vfs/options.go b/pkg/sentry/vfs/options.go
index f223aeda8..dfc8573fd 100644
--- a/pkg/sentry/vfs/options.go
+++ b/pkg/sentry/vfs/options.go
@@ -79,6 +79,17 @@ type MountFlags struct {
// NoATime is equivalent to MS_NOATIME and indicates that the
// filesystem should not update access time in-place.
NoATime bool
+
+ // NoDev is equivalent to MS_NODEV and indicates that the
+ // filesystem should not allow access to devices (special files).
+ // TODO(gVisor.dev/issue/3186): respect this flag in non FUSE
+ // filesystems.
+ NoDev bool
+
+ // NoSUID is equivalent to MS_NOSUID and indicates that the
+ // filesystem should not honor set-user-ID and set-group-ID bits or
+ // file capabilities when executing programs.
+ NoSUID bool
}
// MountOptions contains options to VirtualFilesystem.MountAt().
@@ -153,6 +164,12 @@ type SetStatOptions struct {
// == UTIME_OMIT (VFS users must unset the corresponding bit in Stat.Mask
// instead).
Stat linux.Statx
+
+ // NeedWritePerm indicates that write permission on the file is needed for
+ // this operation. This is needed for truncate(2) (note that ftruncate(2)
+ // does not require the same check--instead, it checks that the fd is
+ // writable).
+ NeedWritePerm bool
}
// BoundEndpointOptions contains options to VirtualFilesystem.BoundEndpointAt()
diff --git a/pkg/sentry/vfs/permissions.go b/pkg/sentry/vfs/permissions.go
index 9cb050597..33389c1df 100644
--- a/pkg/sentry/vfs/permissions.go
+++ b/pkg/sentry/vfs/permissions.go
@@ -183,7 +183,8 @@ func MayWriteFileWithOpenFlags(flags uint32) bool {
// CheckSetStat checks that creds has permission to change the metadata of a
// file with the given permissions, UID, and GID as specified by stat, subject
// to the rules of Linux's fs/attr.c:setattr_prepare().
-func CheckSetStat(ctx context.Context, creds *auth.Credentials, stat *linux.Statx, mode linux.FileMode, kuid auth.KUID, kgid auth.KGID) error {
+func CheckSetStat(ctx context.Context, creds *auth.Credentials, opts *SetStatOptions, mode linux.FileMode, kuid auth.KUID, kgid auth.KGID) error {
+ stat := &opts.Stat
if stat.Mask&linux.STATX_SIZE != 0 {
limit, err := CheckLimit(ctx, 0, int64(stat.Size))
if err != nil {
@@ -215,6 +216,11 @@ func CheckSetStat(ctx context.Context, creds *auth.Credentials, stat *linux.Stat
return syserror.EPERM
}
}
+ if opts.NeedWritePerm && !creds.HasCapability(linux.CAP_DAC_OVERRIDE) {
+ if err := GenericCheckPermissions(creds, MayWrite, mode, kuid, kgid); err != nil {
+ return err
+ }
+ }
if stat.Mask&(linux.STATX_ATIME|linux.STATX_MTIME|linux.STATX_CTIME) != 0 {
if !CanActAsOwner(creds, kuid) {
if (stat.Mask&linux.STATX_ATIME != 0 && stat.Atime.Nsec != linux.UTIME_NOW) ||
diff --git a/pkg/tcpip/header/icmpv4.go b/pkg/tcpip/header/icmpv4.go
index 7908c5744..1a631b31a 100644
--- a/pkg/tcpip/header/icmpv4.go
+++ b/pkg/tcpip/header/icmpv4.go
@@ -72,6 +72,7 @@ const (
// Values for ICMP code as defined in RFC 792.
const (
ICMPv4TTLExceeded = 0
+ ICMPv4HostUnreachable = 1
ICMPv4PortUnreachable = 3
ICMPv4FragmentationNeeded = 4
)
diff --git a/pkg/tcpip/header/icmpv6.go b/pkg/tcpip/header/icmpv6.go
index c7ee2de57..a13b4b809 100644
--- a/pkg/tcpip/header/icmpv6.go
+++ b/pkg/tcpip/header/icmpv6.go
@@ -110,9 +110,16 @@ const (
ICMPv6RedirectMsg ICMPv6Type = 137
)
-// Values for ICMP code as defined in RFC 4443.
+// Values for ICMP destination unreachable code as defined in RFC 4443 section
+// 3.1.
const (
- ICMPv6PortUnreachable = 4
+ ICMPv6NetworkUnreachable = 0
+ ICMPv6Prohibited = 1
+ ICMPv6BeyondScope = 2
+ ICMPv6AddressUnreachable = 3
+ ICMPv6PortUnreachable = 4
+ ICMPv6Policy = 5
+ ICMPv6RejectRoute = 6
)
// Type is the ICMP type field.
diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go
index a2bb773d4..e12a5929b 100644
--- a/pkg/tcpip/link/channel/channel.go
+++ b/pkg/tcpip/link/channel/channel.go
@@ -302,3 +302,7 @@ func (e *Endpoint) RemoveNotify(handle *NotificationHandle) {
func (*Endpoint) ARPHardwareType() header.ARPHardwareType {
return header.ARPHardwareNone
}
+
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (e *Endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+}
diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go
index 6aa1badc7..c18bb91fb 100644
--- a/pkg/tcpip/link/fdbased/endpoint.go
+++ b/pkg/tcpip/link/fdbased/endpoint.go
@@ -386,26 +386,33 @@ const (
_VIRTIO_NET_HDR_GSO_TCPV6 = 4
)
-// WritePacket writes outbound packets to the file descriptor. If it is not
-// currently writable, the packet is dropped.
-func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
if e.hdrSize > 0 {
// Add ethernet header if needed.
eth := header.Ethernet(pkt.Header.Prepend(header.EthernetMinimumSize))
pkt.LinkHeader = buffer.View(eth)
ethHdr := &header.EthernetFields{
- DstAddr: r.RemoteLinkAddress,
+ DstAddr: remote,
Type: protocol,
}
// Preserve the src address if it's set in the route.
- if r.LocalLinkAddress != "" {
- ethHdr.SrcAddr = r.LocalLinkAddress
+ if local != "" {
+ ethHdr.SrcAddr = local
} else {
ethHdr.SrcAddr = e.addr
}
eth.Encode(ethHdr)
}
+}
+
+// WritePacket writes outbound packets to the file descriptor. If it is not
+// currently writable, the packet is dropped.
+func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+ if e.hdrSize > 0 {
+ e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt)
+ }
var builder iovec.Builder
@@ -448,22 +455,8 @@ func (e *endpoint) sendBatch(batchFD int, batch []*stack.PacketBuffer) (int, *tc
// Send a batch of packets through batchFD.
mmsgHdrs := make([]rawfile.MMsgHdr, 0, len(batch))
for _, pkt := range batch {
- var eth header.Ethernet
if e.hdrSize > 0 {
- // Add ethernet header if needed.
- eth = make(header.Ethernet, header.EthernetMinimumSize)
- ethHdr := &header.EthernetFields{
- DstAddr: pkt.EgressRoute.RemoteLinkAddress,
- Type: pkt.NetworkProtocolNumber,
- }
-
- // Preserve the src address if it's set in the route.
- if pkt.EgressRoute.LocalLinkAddress != "" {
- ethHdr.SrcAddr = pkt.EgressRoute.LocalLinkAddress
- } else {
- ethHdr.SrcAddr = e.addr
- }
- eth.Encode(ethHdr)
+ e.AddHeader(pkt.EgressRoute.LocalLinkAddress, pkt.EgressRoute.RemoteLinkAddress, pkt.NetworkProtocolNumber, pkt)
}
var vnetHdrBuf []byte
@@ -493,7 +486,6 @@ func (e *endpoint) sendBatch(batchFD int, batch []*stack.PacketBuffer) (int, *tc
var builder iovec.Builder
builder.Add(vnetHdrBuf)
- builder.Add(eth)
builder.Add(pkt.Header.View())
for _, v := range pkt.Data.Views() {
builder.Add(v)
diff --git a/pkg/tcpip/link/fdbased/endpoint_test.go b/pkg/tcpip/link/fdbased/endpoint_test.go
index 4bad930c7..7b995b85a 100644
--- a/pkg/tcpip/link/fdbased/endpoint_test.go
+++ b/pkg/tcpip/link/fdbased/endpoint_test.go
@@ -107,6 +107,10 @@ func (c *context) DeliverNetworkPacket(remote tcpip.LinkAddress, local tcpip.Lin
c.ch <- packetInfo{remote, protocol, pkt}
}
+func (c *context) DeliverOutboundPacket(remote tcpip.LinkAddress, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ panic("unimplemented")
+}
+
func TestNoEthernetProperties(t *testing.T) {
c := newContext(t, &Options{MTU: mtu})
defer c.cleanup()
@@ -510,6 +514,10 @@ func (d *fakeNetworkDispatcher) DeliverNetworkPacket(remote, local tcpip.LinkAdd
d.pkts = append(d.pkts, pkt)
}
+func (d *fakeNetworkDispatcher) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ panic("unimplemented")
+}
+
func TestDispatchPacketFormat(t *testing.T) {
for _, test := range []struct {
name string
diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go
index 3b17d8c28..781cdd317 100644
--- a/pkg/tcpip/link/loopback/loopback.go
+++ b/pkg/tcpip/link/loopback/loopback.go
@@ -118,3 +118,6 @@ func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
func (*endpoint) ARPHardwareType() header.ARPHardwareType {
return header.ARPHardwareLoopback
}
+
+func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+}
diff --git a/pkg/tcpip/link/muxed/injectable.go b/pkg/tcpip/link/muxed/injectable.go
index c305d9e86..56a611825 100644
--- a/pkg/tcpip/link/muxed/injectable.go
+++ b/pkg/tcpip/link/muxed/injectable.go
@@ -135,6 +135,10 @@ func (*InjectableEndpoint) ARPHardwareType() header.ARPHardwareType {
panic("unsupported operation")
}
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (*InjectableEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+}
+
// NewInjectableEndpoint creates a new multi-endpoint injectable endpoint.
func NewInjectableEndpoint(routes map[tcpip.Address]stack.InjectableLinkEndpoint) *InjectableEndpoint {
return &InjectableEndpoint{
diff --git a/pkg/tcpip/link/nested/nested.go b/pkg/tcpip/link/nested/nested.go
index 328bd048e..d40de54df 100644
--- a/pkg/tcpip/link/nested/nested.go
+++ b/pkg/tcpip/link/nested/nested.go
@@ -61,6 +61,16 @@ func (e *Endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protoco
}
}
+// DeliverOutboundPacket implements stack.NetworkDispatcher.DeliverOutboundPacket.
+func (e *Endpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ e.mu.RLock()
+ d := e.dispatcher
+ e.mu.RUnlock()
+ if d != nil {
+ d.DeliverOutboundPacket(remote, local, protocol, pkt)
+ }
+}
+
// Attach implements stack.LinkEndpoint.
func (e *Endpoint) Attach(dispatcher stack.NetworkDispatcher) {
e.mu.Lock()
@@ -135,3 +145,8 @@ func (e *Endpoint) GSOMaxSize() uint32 {
func (e *Endpoint) ARPHardwareType() header.ARPHardwareType {
return e.child.ARPHardwareType()
}
+
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (e *Endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ e.child.AddHeader(local, remote, protocol, pkt)
+}
diff --git a/pkg/tcpip/link/nested/nested_test.go b/pkg/tcpip/link/nested/nested_test.go
index c1a219f02..7d9249c1c 100644
--- a/pkg/tcpip/link/nested/nested_test.go
+++ b/pkg/tcpip/link/nested/nested_test.go
@@ -55,6 +55,10 @@ func (d *counterDispatcher) DeliverNetworkPacket(tcpip.LinkAddress, tcpip.LinkAd
d.count++
}
+func (d *counterDispatcher) DeliverOutboundPacket(tcpip.LinkAddress, tcpip.LinkAddress, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) {
+ panic("unimplemented")
+}
+
func TestNestedLinkEndpoint(t *testing.T) {
const emptyAddress = tcpip.LinkAddress("")
diff --git a/pkg/tcpip/link/packetsocket/BUILD b/pkg/tcpip/link/packetsocket/BUILD
new file mode 100644
index 000000000..6fff160ce
--- /dev/null
+++ b/pkg/tcpip/link/packetsocket/BUILD
@@ -0,0 +1,14 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "packetsocket",
+ srcs = ["endpoint.go"],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/link/nested",
+ "//pkg/tcpip/stack",
+ ],
+)
diff --git a/pkg/tcpip/link/packetsocket/endpoint.go b/pkg/tcpip/link/packetsocket/endpoint.go
new file mode 100644
index 000000000..3922c2a04
--- /dev/null
+++ b/pkg/tcpip/link/packetsocket/endpoint.go
@@ -0,0 +1,50 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package packetsocket provides a link layer endpoint that provides the ability
+// to loop outbound packets to any AF_PACKET sockets that may be interested in
+// the outgoing packet.
+package packetsocket
+
+import (
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/link/nested"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+type endpoint struct {
+ nested.Endpoint
+}
+
+// New creates a new packetsocket LinkEndpoint.
+func New(lower stack.LinkEndpoint) stack.LinkEndpoint {
+ e := &endpoint{}
+ e.Endpoint.Init(lower, e)
+ return e
+}
+
+// WritePacket implements stack.LinkEndpoint.WritePacket.
+func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+ e.Endpoint.DeliverOutboundPacket(r.RemoteLinkAddress, r.LocalLinkAddress, protocol, pkt)
+ return e.Endpoint.WritePacket(r, gso, protocol, pkt)
+}
+
+// WritePackets implements stack.LinkEndpoint.WritePackets.
+func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, proto tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+ for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
+ e.Endpoint.DeliverOutboundPacket(pkt.EgressRoute.RemoteLinkAddress, pkt.EgressRoute.LocalLinkAddress, pkt.NetworkProtocolNumber, pkt)
+ }
+
+ return e.Endpoint.WritePackets(r, gso, pkts, proto)
+}
diff --git a/pkg/tcpip/link/qdisc/fifo/endpoint.go b/pkg/tcpip/link/qdisc/fifo/endpoint.go
index c84fe1bb9..467083239 100644
--- a/pkg/tcpip/link/qdisc/fifo/endpoint.go
+++ b/pkg/tcpip/link/qdisc/fifo/endpoint.go
@@ -107,6 +107,11 @@ func (e *endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protoco
e.dispatcher.DeliverNetworkPacket(remote, local, protocol, pkt)
}
+// DeliverOutboundPacket implements stack.NetworkDispatcher.DeliverOutboundPacket.
+func (e *endpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ e.dispatcher.DeliverOutboundPacket(remote, local, protocol, pkt)
+}
+
// Attach implements stack.LinkEndpoint.Attach.
func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) {
e.dispatcher = dispatcher
@@ -194,6 +199,8 @@ func (e *endpoint) WritePackets(_ *stack.Route, _ *stack.GSO, pkts stack.PacketB
// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
+ // TODO(gvisor.dev/issue/3267/): Queue these packets as well once
+ // WriteRawPacket takes PacketBuffer instead of VectorisedView.
return e.lower.WriteRawPacket(vv)
}
@@ -213,3 +220,8 @@ func (e *endpoint) Wait() {
func (e *endpoint) ARPHardwareType() header.ARPHardwareType {
return e.lower.ARPHardwareType()
}
+
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ e.lower.AddHeader(local, remote, protocol, pkt)
+}
diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go
index a36862c67..507c76b76 100644
--- a/pkg/tcpip/link/sharedmem/sharedmem.go
+++ b/pkg/tcpip/link/sharedmem/sharedmem.go
@@ -183,22 +183,29 @@ func (e *endpoint) LinkAddress() tcpip.LinkAddress {
return e.addr
}
-// WritePacket writes outbound packets to the file descriptor. If it is not
-// currently writable, the packet is dropped.
-func (e *endpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
- // Add the ethernet header here.
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ // Add ethernet header if needed.
eth := header.Ethernet(pkt.Header.Prepend(header.EthernetMinimumSize))
pkt.LinkHeader = buffer.View(eth)
ethHdr := &header.EthernetFields{
- DstAddr: r.RemoteLinkAddress,
+ DstAddr: remote,
Type: protocol,
}
- if r.LocalLinkAddress != "" {
- ethHdr.SrcAddr = r.LocalLinkAddress
+
+ // Preserve the src address if it's set in the route.
+ if local != "" {
+ ethHdr.SrcAddr = local
} else {
ethHdr.SrcAddr = e.addr
}
eth.Encode(ethHdr)
+}
+
+// WritePacket writes outbound packets to the file descriptor. If it is not
+// currently writable, the packet is dropped.
+func (e *endpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+ e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt)
v := pkt.Data.ToView()
// Transmit the packet.
diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go
index 28a2e88ba..8f3cd9449 100644
--- a/pkg/tcpip/link/sharedmem/sharedmem_test.go
+++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go
@@ -143,6 +143,10 @@ func (c *testContext) DeliverNetworkPacket(remoteLinkAddr, localLinkAddr tcpip.L
c.packetCh <- struct{}{}
}
+func (c *testContext) DeliverOutboundPacket(remoteLinkAddr, localLinkAddr tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ panic("unimplemented")
+}
+
func (c *testContext) cleanup() {
c.ep.Close()
closeFDs(&c.txCfg)
diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go
index d9cd4e83a..509076643 100644
--- a/pkg/tcpip/link/sniffer/sniffer.go
+++ b/pkg/tcpip/link/sniffer/sniffer.go
@@ -123,6 +123,11 @@ func (e *endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protoco
e.Endpoint.DeliverNetworkPacket(remote, local, protocol, pkt)
}
+// DeliverOutboundPacket implements stack.NetworkDispatcher.DeliverOutboundPacket.
+func (e *endpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ e.Endpoint.DeliverOutboundPacket(remote, local, protocol, pkt)
+}
+
func (e *endpoint) dumpPacket(prefix string, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
writer := e.writer
if writer == nil && atomic.LoadUint32(&LogPackets) == 1 {
diff --git a/pkg/tcpip/link/tun/device.go b/pkg/tcpip/link/tun/device.go
index 47446efec..04ae58e59 100644
--- a/pkg/tcpip/link/tun/device.go
+++ b/pkg/tcpip/link/tun/device.go
@@ -272,21 +272,9 @@ func (d *Device) encodePkt(info *channel.PacketInfo) (buffer.View, bool) {
if d.hasFlags(linux.IFF_TAP) {
// Add ethernet header if not provided.
if info.Pkt.LinkHeader == nil {
- hdr := &header.EthernetFields{
- SrcAddr: info.Route.LocalLinkAddress,
- DstAddr: info.Route.RemoteLinkAddress,
- Type: info.Proto,
- }
- if hdr.SrcAddr == "" {
- hdr.SrcAddr = d.endpoint.LinkAddress()
- }
-
- eth := make(header.Ethernet, header.EthernetMinimumSize)
- eth.Encode(hdr)
- vv.AppendView(buffer.View(eth))
- } else {
- vv.AppendView(info.Pkt.LinkHeader)
+ d.endpoint.AddHeader(info.Route.LocalLinkAddress, info.Route.RemoteLinkAddress, info.Proto, info.Pkt)
}
+ vv.AppendView(info.Pkt.LinkHeader)
}
// Append upper headers.
@@ -366,3 +354,30 @@ func (e *tunEndpoint) ARPHardwareType() header.ARPHardwareType {
}
return header.ARPHardwareNone
}
+
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (e *tunEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ if !e.isTap {
+ return
+ }
+ eth := header.Ethernet(pkt.Header.Prepend(header.EthernetMinimumSize))
+ pkt.LinkHeader = buffer.View(eth)
+ hdr := &header.EthernetFields{
+ SrcAddr: local,
+ DstAddr: remote,
+ Type: protocol,
+ }
+ if hdr.SrcAddr == "" {
+ hdr.SrcAddr = e.LinkAddress()
+ }
+
+ eth.Encode(hdr)
+}
+
+// MaxHeaderLength returns the maximum size of the link layer header.
+func (e *tunEndpoint) MaxHeaderLength() uint16 {
+ if e.isTap {
+ return header.EthernetMinimumSize
+ }
+ return 0
+}
diff --git a/pkg/tcpip/link/waitable/waitable.go b/pkg/tcpip/link/waitable/waitable.go
index 24a8dc2eb..b152a0f26 100644
--- a/pkg/tcpip/link/waitable/waitable.go
+++ b/pkg/tcpip/link/waitable/waitable.go
@@ -60,6 +60,15 @@ func (e *Endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protoco
e.dispatchGate.Leave()
}
+// DeliverOutboundPacket implements stack.NetworkDispatcher.DeliverOutboundPacket.
+func (e *Endpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ if !e.dispatchGate.Enter() {
+ return
+ }
+ e.dispatcher.DeliverOutboundPacket(remote, local, protocol, pkt)
+ e.dispatchGate.Leave()
+}
+
// Attach implements stack.LinkEndpoint.Attach. It saves the dispatcher and
// registers with the lower endpoint as its dispatcher so that "e" is called
// for inbound packets.
@@ -153,3 +162,8 @@ func (e *Endpoint) Wait() {}
func (e *Endpoint) ARPHardwareType() header.ARPHardwareType {
return e.lower.ARPHardwareType()
}
+
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (e *Endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ e.lower.AddHeader(local, remote, protocol, pkt)
+}
diff --git a/pkg/tcpip/link/waitable/waitable_test.go b/pkg/tcpip/link/waitable/waitable_test.go
index ffb2354be..c448a888f 100644
--- a/pkg/tcpip/link/waitable/waitable_test.go
+++ b/pkg/tcpip/link/waitable/waitable_test.go
@@ -40,6 +40,10 @@ func (e *countedEndpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress,
e.dispatchCount++
}
+func (e *countedEndpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ panic("unimplemented")
+}
+
func (e *countedEndpoint) Attach(dispatcher stack.NetworkDispatcher) {
e.attachCount++
e.dispatcher = dispatcher
@@ -90,6 +94,11 @@ func (*countedEndpoint) ARPHardwareType() header.ARPHardwareType {
// Wait implements stack.LinkEndpoint.Wait.
func (*countedEndpoint) Wait() {}
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (e *countedEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ panic("unimplemented")
+}
+
func TestWaitWrite(t *testing.T) {
ep := &countedEndpoint{}
wep := New(ep)
diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go
index a5b780ca2..615bae648 100644
--- a/pkg/tcpip/network/ip_test.go
+++ b/pkg/tcpip/network/ip_test.go
@@ -185,6 +185,11 @@ func (*testObject) ARPHardwareType() header.ARPHardwareType {
panic("not implemented")
}
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (*testObject) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ panic("not implemented")
+}
+
func buildIPv4Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go
index 1b67aa066..83e71cb8c 100644
--- a/pkg/tcpip/network/ipv4/icmp.go
+++ b/pkg/tcpip/network/ipv4/icmp.go
@@ -129,6 +129,9 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) {
pkt.Data.TrimFront(header.ICMPv4MinimumSize)
switch h.Code() {
+ case header.ICMPv4HostUnreachable:
+ e.handleControl(stack.ControlNoRoute, 0, pkt)
+
case header.ICMPv4PortUnreachable:
e.handleControl(stack.ControlPortUnreachable, 0, pkt)
diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go
index 3b79749b5..ff1cb53dd 100644
--- a/pkg/tcpip/network/ipv6/icmp.go
+++ b/pkg/tcpip/network/ipv6/icmp.go
@@ -128,6 +128,8 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
}
pkt.Data.TrimFront(header.ICMPv6DstUnreachableMinimumSize)
switch header.ICMPv6(hdr).Code() {
+ case header.ICMPv6NetworkUnreachable:
+ e.handleControl(stack.ControlNetworkUnreachable, 0, pkt)
case header.ICMPv6PortUnreachable:
e.handleControl(stack.ControlPortUnreachable, 0, pkt)
}
diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go
index d39baf620..559a1c4dd 100644
--- a/pkg/tcpip/stack/conntrack.go
+++ b/pkg/tcpip/stack/conntrack.go
@@ -49,7 +49,8 @@ const (
type manipType int
const (
- manipDstPrerouting manipType = iota
+ manipNone manipType = iota
+ manipDstPrerouting
manipDstOutput
)
@@ -113,13 +114,11 @@ type conn struct {
// update the state of tcb. It is immutable.
tcbHook Hook
- // mu protects tcb.
+ // mu protects all mutable state.
mu sync.Mutex `state:"nosave"`
-
// tcb is TCB control block. It is used to keep track of states
// of tcp connection and is protected by mu.
tcb tcpconntrack.TCB
-
// lastUsed is the last time the connection saw a relevant packet, and
// is updated by each packet on the connection. It is protected by mu.
lastUsed time.Time `state:".(unixTime)"`
@@ -141,8 +140,26 @@ func (cn *conn) timedOut(now time.Time) bool {
return now.Sub(cn.lastUsed) > defaultTimeout
}
+// update the connection tracking state.
+//
+// Precondition: ct.mu must be held.
+func (ct *conn) updateLocked(tcpHeader header.TCP, hook Hook) {
+ // Update the state of tcb. tcb assumes it's always initialized on the
+ // client. However, we only need to know whether the connection is
+ // established or not, so the client/server distinction isn't important.
+ // TODO(gvisor.dev/issue/170): Add support in tcpconntrack to handle
+ // other tcp states.
+ if ct.tcb.IsEmpty() {
+ ct.tcb.Init(tcpHeader)
+ } else if hook == ct.tcbHook {
+ ct.tcb.UpdateStateOutbound(tcpHeader)
+ } else {
+ ct.tcb.UpdateStateInbound(tcpHeader)
+ }
+}
+
// ConnTrack tracks all connections created for NAT rules. Most users are
-// expected to only call handlePacket and createConnFor.
+// expected to only call handlePacket, insertRedirectConn, and maybeInsertNoop.
//
// ConnTrack keeps all connections in a slice of buckets, each of which holds a
// linked list of tuples. This gives us some desirable properties:
@@ -248,8 +265,7 @@ func (ct *ConnTrack) connFor(pkt *PacketBuffer) (*conn, direction) {
return nil, dirOriginal
}
-// createConnFor creates a new conn for pkt.
-func (ct *ConnTrack) createConnFor(pkt *PacketBuffer, hook Hook, rt RedirectTarget) *conn {
+func (ct *ConnTrack) insertRedirectConn(pkt *PacketBuffer, hook Hook, rt RedirectTarget) *conn {
tid, err := packetToTupleID(pkt)
if err != nil {
return nil
@@ -272,10 +288,15 @@ func (ct *ConnTrack) createConnFor(pkt *PacketBuffer, hook Hook, rt RedirectTarg
manip = manipDstOutput
}
conn := newConn(tid, replyTID, manip, hook)
+ ct.insertConn(conn)
+ return conn
+}
+// insertConn inserts conn into the appropriate table bucket.
+func (ct *ConnTrack) insertConn(conn *conn) {
// Lock the buckets in the correct order.
- tupleBucket := ct.bucket(tid)
- replyBucket := ct.bucket(replyTID)
+ tupleBucket := ct.bucket(conn.original.tupleID)
+ replyBucket := ct.bucket(conn.reply.tupleID)
ct.mu.RLock()
defer ct.mu.RUnlock()
if tupleBucket < replyBucket {
@@ -289,22 +310,37 @@ func (ct *ConnTrack) createConnFor(pkt *PacketBuffer, hook Hook, rt RedirectTarg
ct.buckets[tupleBucket].mu.Lock()
}
- // Add the tuple to the map.
- ct.buckets[tupleBucket].tuples.PushFront(&conn.original)
- ct.buckets[replyBucket].tuples.PushFront(&conn.reply)
+ // Now that we hold the locks, ensure the tuple hasn't been inserted by
+ // another thread.
+ alreadyInserted := false
+ for other := ct.buckets[tupleBucket].tuples.Front(); other != nil; other = other.Next() {
+ if other.tupleID == conn.original.tupleID {
+ alreadyInserted = true
+ break
+ }
+ }
+
+ if !alreadyInserted {
+ // Add the tuple to the map.
+ ct.buckets[tupleBucket].tuples.PushFront(&conn.original)
+ ct.buckets[replyBucket].tuples.PushFront(&conn.reply)
+ }
// Unlocking can happen in any order.
ct.buckets[tupleBucket].mu.Unlock()
if tupleBucket != replyBucket {
ct.buckets[replyBucket].mu.Unlock()
}
-
- return conn
}
// handlePacketPrerouting manipulates ports for packets in Prerouting hook.
// TODO(gvisor.dev/issue/170): Change address for Prerouting hook.
func handlePacketPrerouting(pkt *PacketBuffer, conn *conn, dir direction) {
+ // If this is a noop entry, don't do anything.
+ if conn.manip == manipNone {
+ return
+ }
+
netHeader := header.IPv4(pkt.NetworkHeader)
tcpHeader := header.TCP(pkt.TransportHeader)
@@ -322,12 +358,22 @@ func handlePacketPrerouting(pkt *PacketBuffer, conn *conn, dir direction) {
netHeader.SetSourceAddress(conn.original.dstAddr)
}
+ // TODO(gvisor.dev/issue/170): TCP checksums aren't usually validated
+ // on inbound packets, so we don't recalculate them. However, we should
+ // support cases when they are validated, e.g. when we can't offload
+ // receive checksumming.
+
netHeader.SetChecksum(0)
netHeader.SetChecksum(^netHeader.CalculateChecksum())
}
// handlePacketOutput manipulates ports for packets in Output hook.
func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir direction) {
+ // If this is a noop entry, don't do anything.
+ if conn.manip == manipNone {
+ return
+ }
+
netHeader := header.IPv4(pkt.NetworkHeader)
tcpHeader := header.TCP(pkt.TransportHeader)
@@ -362,20 +408,31 @@ func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir d
}
// handlePacket will manipulate the port and address of the packet if the
-// connection exists.
-func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Route) {
+// connection exists. Returns whether, after the packet traverses the tables,
+// it should create a new entry in the table.
+func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Route) bool {
if pkt.NatDone {
- return
+ return false
}
if hook != Prerouting && hook != Output {
- return
+ return false
+ }
+
+ // TODO(gvisor.dev/issue/170): Support other transport protocols.
+ if pkt.NetworkHeader == nil || header.IPv4(pkt.NetworkHeader).TransportProtocol() != header.TCPProtocolNumber {
+ return false
}
conn, dir := ct.connFor(pkt)
+ // Connection or Rule not found for the packet.
if conn == nil {
- // Connection not found for the packet or the packet is invalid.
- return
+ return true
+ }
+
+ tcpHeader := header.TCP(pkt.TransportHeader)
+ if tcpHeader == nil {
+ return false
}
switch hook {
@@ -395,14 +452,39 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Rou
// Mark the connection as having been used recently so it isn't reaped.
conn.lastUsed = time.Now()
// Update connection state.
- if tcpHeader := header.TCP(pkt.TransportHeader); conn.tcb.IsEmpty() {
- conn.tcb.Init(tcpHeader)
- conn.tcbHook = hook
- } else if hook == conn.tcbHook {
- conn.tcb.UpdateStateOutbound(tcpHeader)
- } else {
- conn.tcb.UpdateStateInbound(tcpHeader)
+ conn.updateLocked(header.TCP(pkt.TransportHeader), hook)
+
+ return false
+}
+
+// maybeInsertNoop tries to insert a no-op connection entry to keep connections
+// from getting clobbered when replies arrive. It only inserts if there isn't
+// already a connection for pkt.
+//
+// This should be called after traversing iptables rules only, to ensure that
+// pkt.NatDone is set correctly.
+func (ct *ConnTrack) maybeInsertNoop(pkt *PacketBuffer, hook Hook) {
+ // If there were a rule applying to this packet, it would be marked
+ // with NatDone.
+ if pkt.NatDone {
+ return
+ }
+
+ // We only track TCP connections.
+ if pkt.NetworkHeader == nil || header.IPv4(pkt.NetworkHeader).TransportProtocol() != header.TCPProtocolNumber {
+ return
+ }
+
+ // This is the first packet we're seeing for the TCP connection. Insert
+ // the noop entry (an identity mapping) so that the response doesn't
+ // get NATed, breaking the connection.
+ tid, err := packetToTupleID(pkt)
+ if err != nil {
+ return
}
+ conn := newConn(tid, tid.reply(), manipNone, hook)
+ conn.updateLocked(header.TCP(pkt.TransportHeader), hook)
+ ct.insertConn(conn)
}
// bucket gets the conntrack bucket for a tupleID.
diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarder_test.go
index eefb4b07f..bca1d940b 100644
--- a/pkg/tcpip/stack/forwarder_test.go
+++ b/pkg/tcpip/stack/forwarder_test.go
@@ -307,6 +307,11 @@ func (*fwdTestLinkEndpoint) ARPHardwareType() header.ARPHardwareType {
panic("not implemented")
}
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (e *fwdTestLinkEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
+ panic("not implemented")
+}
+
func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (ep1, ep2 *fwdTestLinkEndpoint) {
// Create a stack with the network protocol and two NICs.
s := New(Options{
diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go
index bbf3b60e8..cbbae4224 100644
--- a/pkg/tcpip/stack/iptables.go
+++ b/pkg/tcpip/stack/iptables.go
@@ -22,22 +22,30 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/header"
)
-// Table names.
+// tableID is an index into IPTables.tables.
+type tableID int
+
const (
- TablenameNat = "nat"
- TablenameMangle = "mangle"
- TablenameFilter = "filter"
+ natID tableID = iota
+ mangleID
+ filterID
+ numTables
)
-// Chain names as defined by net/ipv4/netfilter/ip_tables.c.
+// Table names.
const (
- ChainNamePrerouting = "PREROUTING"
- ChainNameInput = "INPUT"
- ChainNameForward = "FORWARD"
- ChainNameOutput = "OUTPUT"
- ChainNamePostrouting = "POSTROUTING"
+ NATTable = "nat"
+ MangleTable = "mangle"
+ FilterTable = "filter"
)
+// nameToID is immutable.
+var nameToID = map[string]tableID{
+ NATTable: natID,
+ MangleTable: mangleID,
+ FilterTable: filterID,
+}
+
// HookUnset indicates that there is no hook set for an entrypoint or
// underflow.
const HookUnset = -1
@@ -48,11 +56,9 @@ const reaperDelay = 5 * time.Second
// DefaultTables returns a default set of tables. Each chain is set to accept
// all packets.
func DefaultTables() *IPTables {
- // TODO(gvisor.dev/issue/170): We may be able to swap out some strings for
- // iotas.
return &IPTables{
- tables: map[string]Table{
- TablenameNat: Table{
+ tables: [numTables]Table{
+ natID: Table{
Rules: []Rule{
Rule{Target: AcceptTarget{}},
Rule{Target: AcceptTarget{}},
@@ -60,60 +66,66 @@ func DefaultTables() *IPTables {
Rule{Target: AcceptTarget{}},
Rule{Target: ErrorTarget{}},
},
- BuiltinChains: map[Hook]int{
+ BuiltinChains: [NumHooks]int{
Prerouting: 0,
Input: 1,
+ Forward: HookUnset,
Output: 2,
Postrouting: 3,
},
- Underflows: map[Hook]int{
+ Underflows: [NumHooks]int{
Prerouting: 0,
Input: 1,
+ Forward: HookUnset,
Output: 2,
Postrouting: 3,
},
- UserChains: map[string]int{},
},
- TablenameMangle: Table{
+ mangleID: Table{
Rules: []Rule{
Rule{Target: AcceptTarget{}},
Rule{Target: AcceptTarget{}},
Rule{Target: ErrorTarget{}},
},
- BuiltinChains: map[Hook]int{
+ BuiltinChains: [NumHooks]int{
Prerouting: 0,
Output: 1,
},
- Underflows: map[Hook]int{
- Prerouting: 0,
- Output: 1,
+ Underflows: [NumHooks]int{
+ Prerouting: 0,
+ Input: HookUnset,
+ Forward: HookUnset,
+ Output: 1,
+ Postrouting: HookUnset,
},
- UserChains: map[string]int{},
},
- TablenameFilter: Table{
+ filterID: Table{
Rules: []Rule{
Rule{Target: AcceptTarget{}},
Rule{Target: AcceptTarget{}},
Rule{Target: AcceptTarget{}},
Rule{Target: ErrorTarget{}},
},
- BuiltinChains: map[Hook]int{
- Input: 0,
- Forward: 1,
- Output: 2,
+ BuiltinChains: [NumHooks]int{
+ Prerouting: HookUnset,
+ Input: 0,
+ Forward: 1,
+ Output: 2,
+ Postrouting: HookUnset,
},
- Underflows: map[Hook]int{
- Input: 0,
- Forward: 1,
- Output: 2,
+ Underflows: [NumHooks]int{
+ Prerouting: HookUnset,
+ Input: 0,
+ Forward: 1,
+ Output: 2,
+ Postrouting: HookUnset,
},
- UserChains: map[string]int{},
},
},
- priorities: map[Hook][]string{
- Input: []string{TablenameNat, TablenameFilter},
- Prerouting: []string{TablenameMangle, TablenameNat},
- Output: []string{TablenameMangle, TablenameNat, TablenameFilter},
+ priorities: [NumHooks][]tableID{
+ Prerouting: []tableID{mangleID, natID},
+ Input: []tableID{natID, filterID},
+ Output: []tableID{mangleID, natID, filterID},
},
connections: ConnTrack{
seed: generateRandUint32(),
@@ -127,51 +139,48 @@ func DefaultTables() *IPTables {
func EmptyFilterTable() Table {
return Table{
Rules: []Rule{},
- BuiltinChains: map[Hook]int{
- Input: HookUnset,
- Forward: HookUnset,
- Output: HookUnset,
+ BuiltinChains: [NumHooks]int{
+ Prerouting: HookUnset,
+ Postrouting: HookUnset,
},
- Underflows: map[Hook]int{
- Input: HookUnset,
- Forward: HookUnset,
- Output: HookUnset,
+ Underflows: [NumHooks]int{
+ Prerouting: HookUnset,
+ Postrouting: HookUnset,
},
- UserChains: map[string]int{},
}
}
-// EmptyNatTable returns a Table with no rules and the filter table chains
+// EmptyNATTable returns a Table with no rules and the filter table chains
// mapped to HookUnset.
-func EmptyNatTable() Table {
+func EmptyNATTable() Table {
return Table{
Rules: []Rule{},
- BuiltinChains: map[Hook]int{
- Prerouting: HookUnset,
- Input: HookUnset,
- Output: HookUnset,
- Postrouting: HookUnset,
+ BuiltinChains: [NumHooks]int{
+ Forward: HookUnset,
},
- Underflows: map[Hook]int{
- Prerouting: HookUnset,
- Input: HookUnset,
- Output: HookUnset,
- Postrouting: HookUnset,
+ Underflows: [NumHooks]int{
+ Forward: HookUnset,
},
- UserChains: map[string]int{},
}
}
-// GetTable returns table by name.
+// GetTable returns a table by name.
func (it *IPTables) GetTable(name string) (Table, bool) {
+ id, ok := nameToID[name]
+ if !ok {
+ return Table{}, false
+ }
it.mu.RLock()
defer it.mu.RUnlock()
- t, ok := it.tables[name]
- return t, ok
+ return it.tables[id], true
}
// ReplaceTable replaces or inserts table by name.
-func (it *IPTables) ReplaceTable(name string, table Table) {
+func (it *IPTables) ReplaceTable(name string, table Table) *tcpip.Error {
+ id, ok := nameToID[name]
+ if !ok {
+ return tcpip.ErrInvalidOptionValue
+ }
it.mu.Lock()
defer it.mu.Unlock()
// If iptables is being enabled, initialize the conntrack table and
@@ -181,14 +190,8 @@ func (it *IPTables) ReplaceTable(name string, table Table) {
it.startReaper(reaperDelay)
}
it.modified = true
- it.tables[name] = table
-}
-
-// GetPriorities returns slice of priorities associated with hook.
-func (it *IPTables) GetPriorities(hook Hook) []string {
- it.mu.RLock()
- defer it.mu.RUnlock()
- return it.priorities[hook]
+ it.tables[id] = table
+ return nil
}
// A chainVerdict is what a table decides should be done with a packet.
@@ -223,11 +226,19 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, addr
// Packets are manipulated only if connection and matching
// NAT rule exists.
- it.connections.handlePacket(pkt, hook, gso, r)
+ shouldTrack := it.connections.handlePacket(pkt, hook, gso, r)
// Go through each table containing the hook.
- for _, tablename := range it.GetPriorities(hook) {
- table, _ := it.GetTable(tablename)
+ it.mu.RLock()
+ defer it.mu.RUnlock()
+ priorities := it.priorities[hook]
+ for _, tableID := range priorities {
+ // If handlePacket already NATed the packet, we don't need to
+ // check the NAT table.
+ if tableID == natID && pkt.NatDone {
+ continue
+ }
+ table := it.tables[tableID]
ruleIdx := table.BuiltinChains[hook]
switch verdict := it.checkChain(hook, pkt, table, ruleIdx, gso, r, address, nicName); verdict {
// If the table returns Accept, move on to the next table.
@@ -256,6 +267,20 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, addr
}
}
+ // If this connection should be tracked, try to add an entry for it. If
+ // traversing the nat table didn't end in adding an entry,
+ // maybeInsertNoop will add a no-op entry for the connection. This is
+ // needeed when establishing connections so that the SYN/ACK reply to an
+ // outgoing SYN is delivered to the correct endpoint rather than being
+ // redirected by a prerouting rule.
+ //
+ // From the iptables documentation: "If there is no rule, a `null'
+ // binding is created: this usually does not map the packet, but exists
+ // to ensure we don't map another stream over an existing one."
+ if shouldTrack {
+ it.connections.maybeInsertNoop(pkt, hook)
+ }
+
// Every table returned Accept.
return true
}
diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go
index d43f60c67..dc88033c7 100644
--- a/pkg/tcpip/stack/iptables_targets.go
+++ b/pkg/tcpip/stack/iptables_targets.go
@@ -153,7 +153,7 @@ func (rt RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso
// Set up conection for matching NAT rule. Only the first
// packet of the connection comes here. Other packets will be
// manipulated in connection tracking.
- if conn := ct.createConnFor(pkt, hook, rt); conn != nil {
+ if conn := ct.insertRedirectConn(pkt, hook, rt); conn != nil {
ct.handlePacket(pkt, hook, gso, r)
}
default:
diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go
index eb70e3104..73274ada9 100644
--- a/pkg/tcpip/stack/iptables_types.go
+++ b/pkg/tcpip/stack/iptables_types.go
@@ -84,14 +84,14 @@ type IPTables struct {
// mu protects tables, priorities, and modified.
mu sync.RWMutex
- // tables maps table names to tables. User tables have arbitrary names.
- // mu needs to be locked for accessing.
- tables map[string]Table
+ // tables maps tableIDs to tables. Holds builtin tables only, not user
+ // tables. mu must be locked for accessing.
+ tables [numTables]Table
// priorities maps each hook to a list of table names. The order of the
// list is the order in which each table should be visited for that
// hook. mu needs to be locked for accessing.
- priorities map[Hook][]string
+ priorities [NumHooks][]tableID
// modified is whether tables have been modified at least once. It is
// used to elide the iptables performance overhead for workloads that
@@ -113,22 +113,20 @@ type Table struct {
Rules []Rule
// BuiltinChains maps builtin chains to their entrypoint rule in Rules.
- BuiltinChains map[Hook]int
+ BuiltinChains [NumHooks]int
// Underflows maps builtin chains to their underflow rule in Rules
// (i.e. the rule to execute if the chain returns without a verdict).
- Underflows map[Hook]int
-
- // UserChains holds user-defined chains for the keyed by name. Users
- // can give their chains arbitrary names.
- UserChains map[string]int
+ Underflows [NumHooks]int
}
// ValidHooks returns a bitmap of the builtin hooks for the given table.
func (table *Table) ValidHooks() uint32 {
hooks := uint32(0)
- for hook := range table.BuiltinChains {
- hooks |= 1 << hook
+ for hook, ruleIdx := range table.BuiltinChains {
+ if ruleIdx != HookUnset {
+ hooks |= 1 << hook
+ }
}
return hooks
}
diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go
index e28c23d66..9dce11a97 100644
--- a/pkg/tcpip/stack/ndp.go
+++ b/pkg/tcpip/stack/ndp.go
@@ -469,7 +469,7 @@ type ndpState struct {
rtrSolicit struct {
// The timer used to send the next router solicitation message.
- timer *time.Timer
+ timer tcpip.Timer
// Used to let the Router Solicitation timer know that it has been stopped.
//
@@ -503,7 +503,7 @@ type ndpState struct {
// to the DAD goroutine that DAD should stop.
type dadState struct {
// The DAD timer to send the next NS message, or resolve the address.
- timer *time.Timer
+ timer tcpip.Timer
// Used to let the DAD timer know that it has been stopped.
//
@@ -515,38 +515,38 @@ type dadState struct {
// defaultRouterState holds data associated with a default router discovered by
// a Router Advertisement (RA).
type defaultRouterState struct {
- // Timer to invalidate the default router.
+ // Job to invalidate the default router.
//
// Must not be nil.
- invalidationTimer *tcpip.CancellableTimer
+ invalidationJob *tcpip.Job
}
// onLinkPrefixState holds data associated with an on-link prefix discovered by
// a Router Advertisement's Prefix Information option (PI) when the NDP
// configurations was configured to do so.
type onLinkPrefixState struct {
- // Timer to invalidate the on-link prefix.
+ // Job to invalidate the on-link prefix.
//
// Must not be nil.
- invalidationTimer *tcpip.CancellableTimer
+ invalidationJob *tcpip.Job
}
// tempSLAACAddrState holds state associated with a temporary SLAAC address.
type tempSLAACAddrState struct {
- // Timer to deprecate the temporary SLAAC address.
+ // Job to deprecate the temporary SLAAC address.
//
// Must not be nil.
- deprecationTimer *tcpip.CancellableTimer
+ deprecationJob *tcpip.Job
- // Timer to invalidate the temporary SLAAC address.
+ // Job to invalidate the temporary SLAAC address.
//
// Must not be nil.
- invalidationTimer *tcpip.CancellableTimer
+ invalidationJob *tcpip.Job
- // Timer to regenerate the temporary SLAAC address.
+ // Job to regenerate the temporary SLAAC address.
//
// Must not be nil.
- regenTimer *tcpip.CancellableTimer
+ regenJob *tcpip.Job
createdAt time.Time
@@ -561,15 +561,15 @@ type tempSLAACAddrState struct {
// slaacPrefixState holds state associated with a SLAAC prefix.
type slaacPrefixState struct {
- // Timer to deprecate the prefix.
+ // Job to deprecate the prefix.
//
// Must not be nil.
- deprecationTimer *tcpip.CancellableTimer
+ deprecationJob *tcpip.Job
- // Timer to invalidate the prefix.
+ // Job to invalidate the prefix.
//
// Must not be nil.
- invalidationTimer *tcpip.CancellableTimer
+ invalidationJob *tcpip.Job
// Nonzero only when the address is not valid forever.
validUntil time.Time
@@ -651,12 +651,12 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref
}
var done bool
- var timer *time.Timer
+ var timer tcpip.Timer
// We initially start a timer to fire immediately because some of the DAD work
// cannot be done while holding the NIC's lock. This is effectively the same
// as starting a goroutine but we use a timer that fires immediately so we can
// reset it for the next DAD iteration.
- timer = time.AfterFunc(0, func() {
+ timer = ndp.nic.stack.Clock().AfterFunc(0, func() {
ndp.nic.mu.Lock()
defer ndp.nic.mu.Unlock()
@@ -871,9 +871,9 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) {
case ok && rl != 0:
// This is an already discovered default router. Update
- // the invalidation timer.
- rtr.invalidationTimer.StopLocked()
- rtr.invalidationTimer.Reset(rl)
+ // the invalidation job.
+ rtr.invalidationJob.Cancel()
+ rtr.invalidationJob.Schedule(rl)
ndp.defaultRouters[ip] = rtr
case ok && rl == 0:
@@ -950,7 +950,7 @@ func (ndp *ndpState) invalidateDefaultRouter(ip tcpip.Address) {
return
}
- rtr.invalidationTimer.StopLocked()
+ rtr.invalidationJob.Cancel()
delete(ndp.defaultRouters, ip)
// Let the integrator know a discovered default router is invalidated.
@@ -979,12 +979,12 @@ func (ndp *ndpState) rememberDefaultRouter(ip tcpip.Address, rl time.Duration) {
}
state := defaultRouterState{
- invalidationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() {
+ invalidationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() {
ndp.invalidateDefaultRouter(ip)
}),
}
- state.invalidationTimer.Reset(rl)
+ state.invalidationJob.Schedule(rl)
ndp.defaultRouters[ip] = state
}
@@ -1009,13 +1009,13 @@ func (ndp *ndpState) rememberOnLinkPrefix(prefix tcpip.Subnet, l time.Duration)
}
state := onLinkPrefixState{
- invalidationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() {
+ invalidationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() {
ndp.invalidateOnLinkPrefix(prefix)
}),
}
if l < header.NDPInfiniteLifetime {
- state.invalidationTimer.Reset(l)
+ state.invalidationJob.Schedule(l)
}
ndp.onLinkPrefixes[prefix] = state
@@ -1033,7 +1033,7 @@ func (ndp *ndpState) invalidateOnLinkPrefix(prefix tcpip.Subnet) {
return
}
- s.invalidationTimer.StopLocked()
+ s.invalidationJob.Cancel()
delete(ndp.onLinkPrefixes, prefix)
// Let the integrator know a discovered on-link prefix is invalidated.
@@ -1082,14 +1082,14 @@ func (ndp *ndpState) handleOnLinkPrefixInformation(pi header.NDPPrefixInformatio
// This is an already discovered on-link prefix with a
// new non-zero valid lifetime.
//
- // Update the invalidation timer.
+ // Update the invalidation job.
- prefixState.invalidationTimer.StopLocked()
+ prefixState.invalidationJob.Cancel()
if vl < header.NDPInfiniteLifetime {
- // Prefix is valid for a finite lifetime, reset the timer to expire after
+ // Prefix is valid for a finite lifetime, schedule the job to execute after
// the new valid lifetime.
- prefixState.invalidationTimer.Reset(vl)
+ prefixState.invalidationJob.Schedule(vl)
}
ndp.onLinkPrefixes[prefix] = prefixState
@@ -1154,7 +1154,7 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) {
}
state := slaacPrefixState{
- deprecationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() {
+ deprecationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() {
state, ok := ndp.slaacPrefixes[prefix]
if !ok {
panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for the deprecated SLAAC prefix %s", prefix))
@@ -1162,7 +1162,7 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) {
ndp.deprecateSLAACAddress(state.stableAddr.ref)
}),
- invalidationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() {
+ invalidationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() {
state, ok := ndp.slaacPrefixes[prefix]
if !ok {
panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for the invalidated SLAAC prefix %s", prefix))
@@ -1184,19 +1184,19 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) {
if !ndp.generateSLAACAddr(prefix, &state) {
// We were unable to generate an address for the prefix, we do not nothing
- // further as there is no reason to maintain state or timers for a prefix we
+ // further as there is no reason to maintain state or jobs for a prefix we
// do not have an address for.
return
}
- // Setup the initial timers to deprecate and invalidate prefix.
+ // Setup the initial jobs to deprecate and invalidate prefix.
if pl < header.NDPInfiniteLifetime && pl != 0 {
- state.deprecationTimer.Reset(pl)
+ state.deprecationJob.Schedule(pl)
}
if vl < header.NDPInfiniteLifetime {
- state.invalidationTimer.Reset(vl)
+ state.invalidationJob.Schedule(vl)
state.validUntil = now.Add(vl)
}
@@ -1428,7 +1428,7 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla
}
state := tempSLAACAddrState{
- deprecationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() {
+ deprecationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() {
prefixState, ok := ndp.slaacPrefixes[prefix]
if !ok {
panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for %s to deprecate temporary address %s", prefix, generatedAddr))
@@ -1441,7 +1441,7 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla
ndp.deprecateSLAACAddress(tempAddrState.ref)
}),
- invalidationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() {
+ invalidationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() {
prefixState, ok := ndp.slaacPrefixes[prefix]
if !ok {
panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for %s to invalidate temporary address %s", prefix, generatedAddr))
@@ -1454,7 +1454,7 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla
ndp.invalidateTempSLAACAddr(prefixState.tempAddrs, generatedAddr.Address, tempAddrState)
}),
- regenTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() {
+ regenJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() {
prefixState, ok := ndp.slaacPrefixes[prefix]
if !ok {
panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for %s to regenerate temporary address after %s", prefix, generatedAddr))
@@ -1481,9 +1481,9 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla
ref: ref,
}
- state.deprecationTimer.Reset(pl)
- state.invalidationTimer.Reset(vl)
- state.regenTimer.Reset(pl - ndp.configs.RegenAdvanceDuration)
+ state.deprecationJob.Schedule(pl)
+ state.invalidationJob.Schedule(vl)
+ state.regenJob.Schedule(pl - ndp.configs.RegenAdvanceDuration)
prefixState.generationAttempts++
prefixState.tempAddrs[generatedAddr.Address] = state
@@ -1518,16 +1518,16 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat
prefixState.stableAddr.ref.deprecated = false
}
- // If prefix was preferred for some finite lifetime before, stop the
- // deprecation timer so it can be reset.
- prefixState.deprecationTimer.StopLocked()
+ // If prefix was preferred for some finite lifetime before, cancel the
+ // deprecation job so it can be reset.
+ prefixState.deprecationJob.Cancel()
now := time.Now()
- // Reset the deprecation timer if prefix has a finite preferred lifetime.
+ // Schedule the deprecation job if prefix has a finite preferred lifetime.
if pl < header.NDPInfiniteLifetime {
if !deprecated {
- prefixState.deprecationTimer.Reset(pl)
+ prefixState.deprecationJob.Schedule(pl)
}
prefixState.preferredUntil = now.Add(pl)
} else {
@@ -1546,9 +1546,9 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat
// 3) Otherwise, reset the valid lifetime of the prefix to 2 hours.
if vl >= header.NDPInfiniteLifetime {
- // Handle the infinite valid lifetime separately as we do not keep a timer
- // in this case.
- prefixState.invalidationTimer.StopLocked()
+ // Handle the infinite valid lifetime separately as we do not schedule a
+ // job in this case.
+ prefixState.invalidationJob.Cancel()
prefixState.validUntil = time.Time{}
} else {
var effectiveVl time.Duration
@@ -1569,8 +1569,8 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat
}
if effectiveVl != 0 {
- prefixState.invalidationTimer.StopLocked()
- prefixState.invalidationTimer.Reset(effectiveVl)
+ prefixState.invalidationJob.Cancel()
+ prefixState.invalidationJob.Schedule(effectiveVl)
prefixState.validUntil = now.Add(effectiveVl)
}
}
@@ -1582,7 +1582,7 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat
}
// Note, we do not need to update the entries in the temporary address map
- // after updating the timers because the timers are held as pointers.
+ // after updating the jobs because the jobs are held as pointers.
var regenForAddr tcpip.Address
allAddressesRegenerated := true
for tempAddr, tempAddrState := range prefixState.tempAddrs {
@@ -1596,14 +1596,14 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat
}
// If the address is no longer valid, invalidate it immediately. Otherwise,
- // reset the invalidation timer.
+ // reset the invalidation job.
newValidLifetime := validUntil.Sub(now)
if newValidLifetime <= 0 {
ndp.invalidateTempSLAACAddr(prefixState.tempAddrs, tempAddr, tempAddrState)
continue
}
- tempAddrState.invalidationTimer.StopLocked()
- tempAddrState.invalidationTimer.Reset(newValidLifetime)
+ tempAddrState.invalidationJob.Cancel()
+ tempAddrState.invalidationJob.Schedule(newValidLifetime)
// As per RFC 4941 section 3.3 step 4, the preferred lifetime of a temporary
// address is the lower of the preferred lifetime of the stable address or
@@ -1616,17 +1616,17 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat
}
// If the address is no longer preferred, deprecate it immediately.
- // Otherwise, reset the deprecation timer.
+ // Otherwise, schedule the deprecation job again.
newPreferredLifetime := preferredUntil.Sub(now)
- tempAddrState.deprecationTimer.StopLocked()
+ tempAddrState.deprecationJob.Cancel()
if newPreferredLifetime <= 0 {
ndp.deprecateSLAACAddress(tempAddrState.ref)
} else {
tempAddrState.ref.deprecated = false
- tempAddrState.deprecationTimer.Reset(newPreferredLifetime)
+ tempAddrState.deprecationJob.Schedule(newPreferredLifetime)
}
- tempAddrState.regenTimer.StopLocked()
+ tempAddrState.regenJob.Cancel()
if tempAddrState.regenerated {
} else {
allAddressesRegenerated = false
@@ -1637,7 +1637,7 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat
// immediately after we finish iterating over the temporary addresses.
regenForAddr = tempAddr
} else {
- tempAddrState.regenTimer.Reset(newPreferredLifetime - ndp.configs.RegenAdvanceDuration)
+ tempAddrState.regenJob.Schedule(newPreferredLifetime - ndp.configs.RegenAdvanceDuration)
}
}
}
@@ -1717,7 +1717,7 @@ func (ndp *ndpState) cleanupSLAACAddrResourcesAndNotify(addr tcpip.AddressWithPr
ndp.cleanupSLAACPrefixResources(prefix, state)
}
-// cleanupSLAACPrefixResources cleansup a SLAAC prefix's timers and entry.
+// cleanupSLAACPrefixResources cleans up a SLAAC prefix's jobs and entry.
//
// Panics if the SLAAC prefix is not known.
//
@@ -1729,8 +1729,8 @@ func (ndp *ndpState) cleanupSLAACPrefixResources(prefix tcpip.Subnet, state slaa
}
state.stableAddr.ref = nil
- state.deprecationTimer.StopLocked()
- state.invalidationTimer.StopLocked()
+ state.deprecationJob.Cancel()
+ state.invalidationJob.Cancel()
delete(ndp.slaacPrefixes, prefix)
}
@@ -1775,13 +1775,13 @@ func (ndp *ndpState) cleanupTempSLAACAddrResourcesAndNotify(addr tcpip.AddressWi
}
// cleanupTempSLAACAddrResourcesAndNotify cleans up a temporary SLAAC address's
-// timers and entry.
+// jobs and entry.
//
// The NIC that ndp belongs to MUST be locked.
func (ndp *ndpState) cleanupTempSLAACAddrResources(tempAddrs map[tcpip.Address]tempSLAACAddrState, tempAddr tcpip.Address, tempAddrState tempSLAACAddrState) {
- tempAddrState.deprecationTimer.StopLocked()
- tempAddrState.invalidationTimer.StopLocked()
- tempAddrState.regenTimer.StopLocked()
+ tempAddrState.deprecationJob.Cancel()
+ tempAddrState.invalidationJob.Cancel()
+ tempAddrState.regenJob.Cancel()
delete(tempAddrs, tempAddr)
}
@@ -1860,7 +1860,7 @@ func (ndp *ndpState) startSolicitingRouters() {
var done bool
ndp.rtrSolicit.done = &done
- ndp.rtrSolicit.timer = time.AfterFunc(delay, func() {
+ ndp.rtrSolicit.timer = ndp.nic.stack.Clock().AfterFunc(delay, func() {
ndp.nic.mu.Lock()
if done {
// If we reach this point, it means that the RS timer fired after another
diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go
index 6f86abc98..644ba7c33 100644
--- a/pkg/tcpip/stack/ndp_test.go
+++ b/pkg/tcpip/stack/ndp_test.go
@@ -1254,7 +1254,7 @@ func TestRouterDiscovery(t *testing.T) {
default:
}
- // Wait for lladdr2's router invalidation timer to fire. The lifetime
+ // Wait for lladdr2's router invalidation job to execute. The lifetime
// of the router should have been updated to the most recent (smaller)
// lifetime.
//
@@ -1271,7 +1271,7 @@ func TestRouterDiscovery(t *testing.T) {
e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 0))
expectRouterEvent(llAddr2, false)
- // Wait for lladdr3's router invalidation timer to fire. The lifetime
+ // Wait for lladdr3's router invalidation job to execute. The lifetime
// of the router should have been updated to the most recent (smaller)
// lifetime.
//
@@ -1502,7 +1502,7 @@ func TestPrefixDiscovery(t *testing.T) {
default:
}
- // Wait for prefix2's most recent invalidation timer plus some buffer to
+ // Wait for prefix2's most recent invalidation job plus some buffer to
// expire.
select {
case e := <-ndpDisp.prefixC:
@@ -2395,7 +2395,7 @@ func TestAutoGenTempAddrRegen(t *testing.T) {
for _, addr := range tempAddrs {
// Wait for a deprecation then invalidation event, or just an invalidation
// event. We need to cover both cases but cannot deterministically hit both
- // cases because the deprecation and invalidation timers could fire in any
+ // cases because the deprecation and invalidation jobs could execute in any
// order.
select {
case e := <-ndpDisp.autoGenAddrC:
@@ -2432,9 +2432,9 @@ func TestAutoGenTempAddrRegen(t *testing.T) {
}
}
-// TestAutoGenTempAddrRegenTimerUpdates tests that a temporary address's
-// regeneration timer gets updated when refreshing the address's lifetimes.
-func TestAutoGenTempAddrRegenTimerUpdates(t *testing.T) {
+// TestAutoGenTempAddrRegenJobUpdates tests that a temporary address's
+// regeneration job gets updated when refreshing the address's lifetimes.
+func TestAutoGenTempAddrRegenJobUpdates(t *testing.T) {
const (
nicID = 1
regenAfter = 2 * time.Second
@@ -2533,7 +2533,7 @@ func TestAutoGenTempAddrRegenTimerUpdates(t *testing.T) {
//
// A new temporary address should immediately be generated since the
// regeneration time has already passed since the last address was generated
- // - this regeneration does not depend on a timer.
+ // - this regeneration does not depend on a job.
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100))
expectAutoGenAddrEvent(tempAddr2, newAddr)
@@ -2559,11 +2559,11 @@ func TestAutoGenTempAddrRegenTimerUpdates(t *testing.T) {
}
// Set the maximum lifetimes for temporary addresses such that on the next
- // RA, the regeneration timer gets reset.
+ // RA, the regeneration job gets scheduled again.
//
// The maximum lifetime is the sum of the minimum lifetimes for temporary
// addresses + the time that has already passed since the last address was
- // generated so that the regeneration timer is needed to generate the next
+ // generated so that the regeneration job is needed to generate the next
// address.
newLifetimes := newMinVLDuration + regenAfter + defaultAsyncNegativeEventTimeout
ndpConfigs.MaxTempAddrValidLifetime = newLifetimes
@@ -2993,9 +2993,9 @@ func TestAutoGenAddrDeprecateFromPI(t *testing.T) {
expectPrimaryAddr(addr2)
}
-// TestAutoGenAddrTimerDeprecation tests that an address is properly deprecated
+// TestAutoGenAddrJobDeprecation tests that an address is properly deprecated
// when its preferred lifetime expires.
-func TestAutoGenAddrTimerDeprecation(t *testing.T) {
+func TestAutoGenAddrJobDeprecation(t *testing.T) {
const nicID = 1
const newMinVL = 2
newMinVLDuration := newMinVL * time.Second
@@ -3513,8 +3513,8 @@ func TestAutoGenAddrRemoval(t *testing.T) {
}
expectAutoGenAddrEvent(addr, invalidatedAddr)
- // Wait for the original valid lifetime to make sure the original timer
- // got stopped/cleaned up.
+ // Wait for the original valid lifetime to make sure the original job got
+ // cancelled/cleaned up.
select {
case <-ndpDisp.autoGenAddrC:
t.Fatal("unexpectedly received an auto gen addr event")
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 7b80534e6..fea0ce7e8 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -1200,15 +1200,13 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
// Are any packet sockets listening for this network protocol?
packetEPs := n.mu.packetEPs[protocol]
- // Check whether there are packet sockets listening for every protocol.
- // If we received a packet with protocol EthernetProtocolAll, then the
- // previous for loop will have handled it.
- if protocol != header.EthernetProtocolAll {
- packetEPs = append(packetEPs, n.mu.packetEPs[header.EthernetProtocolAll]...)
- }
+ // Add any other packet sockets that maybe listening for all protocols.
+ packetEPs = append(packetEPs, n.mu.packetEPs[header.EthernetProtocolAll]...)
n.mu.RUnlock()
for _, ep := range packetEPs {
- ep.HandlePacket(n.id, local, protocol, pkt.Clone())
+ p := pkt.Clone()
+ p.PktType = tcpip.PacketHost
+ ep.HandlePacket(n.id, local, protocol, p)
}
if netProto.Number() == header.IPv4ProtocolNumber || netProto.Number() == header.IPv6ProtocolNumber {
@@ -1311,6 +1309,24 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
}
}
+// DeliverOutboundPacket implements NetworkDispatcher.DeliverOutboundPacket.
+func (n *NIC) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
+ n.mu.RLock()
+ // We do not deliver to protocol specific packet endpoints as on Linux
+ // only ETH_P_ALL endpoints get outbound packets.
+ // Add any other packet sockets that maybe listening for all protocols.
+ packetEPs := n.mu.packetEPs[header.EthernetProtocolAll]
+ n.mu.RUnlock()
+ for _, ep := range packetEPs {
+ p := pkt.Clone()
+ p.PktType = tcpip.PacketOutgoing
+ // Add the link layer header as outgoing packets are intercepted
+ // before the link layer header is created.
+ n.linkEP.AddHeader(local, remote, protocol, p)
+ ep.HandlePacket(n.id, local, protocol, p)
+ }
+}
+
func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
// TODO(b/143425874) Decrease the TTL field in forwarded packets.
// TODO(b/151227689): Avoid copying the packet when forwarding. We can do this
diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go
index 3bc9fd831..c477e31d8 100644
--- a/pkg/tcpip/stack/nic_test.go
+++ b/pkg/tcpip/stack/nic_test.go
@@ -89,6 +89,11 @@ func (*testLinkEndpoint) ARPHardwareType() header.ARPHardwareType {
panic("not implemented")
}
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (e *testLinkEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
+ panic("not implemented")
+}
+
var _ NetworkEndpoint = (*testIPv6Endpoint)(nil)
// An IPv6 NetworkEndpoint that throws away outgoing packets.
diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go
index e3556d5d2..5d6865e35 100644
--- a/pkg/tcpip/stack/packet_buffer.go
+++ b/pkg/tcpip/stack/packet_buffer.go
@@ -79,6 +79,10 @@ type PacketBuffer struct {
// NatDone indicates if the packet has been manipulated as per NAT
// iptables rule.
NatDone bool
+
+ // PktType indicates the SockAddrLink.PacketType of the packet as defined in
+ // https://www.man7.org/linux/man-pages/man7/packet.7.html.
+ PktType tcpip.PacketType
}
// Clone makes a copy of pk. It clones the Data field, which creates a new
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index f260eeb7f..9e1b2d25f 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -52,8 +52,11 @@ type TransportEndpointID struct {
type ControlType int
// The following are the allowed values for ControlType values.
+// TODO(http://gvisor.dev/issue/3210): Support time exceeded messages.
const (
- ControlPacketTooBig ControlType = iota
+ ControlNetworkUnreachable ControlType = iota
+ ControlNoRoute
+ ControlPacketTooBig
ControlPortUnreachable
ControlUnknown
)
@@ -330,8 +333,7 @@ type NetworkProtocol interface {
}
// NetworkDispatcher contains the methods used by the network stack to deliver
-// packets to the appropriate network endpoint after it has been handled by
-// the data link layer.
+// inbound/outbound packets to the appropriate network/packet(if any) endpoints.
type NetworkDispatcher interface {
// DeliverNetworkPacket finds the appropriate network protocol endpoint
// and hands the packet over for further processing.
@@ -342,6 +344,16 @@ type NetworkDispatcher interface {
//
// DeliverNetworkPacket takes ownership of pkt.
DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer)
+
+ // DeliverOutboundPacket is called by link layer when a packet is being
+ // sent out.
+ //
+ // pkt.LinkHeader may or may not be set before calling
+ // DeliverOutboundPacket. Some packets do not have link headers (e.g.
+ // packets sent via loopback), and won't have the field set.
+ //
+ // DeliverOutboundPacket takes ownership of pkt.
+ DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer)
}
// LinkEndpointCapabilities is the type associated with the capabilities
@@ -443,6 +455,9 @@ type LinkEndpoint interface {
// See:
// https://github.com/torvalds/linux/blob/aa0c9086b40c17a7ad94425b3b70dd1fdd7497bf/include/uapi/linux/if_arp.h#L30
ARPHardwareType() header.ARPHardwareType
+
+ // AddHeader adds a link layer header to pkt if required.
+ AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer)
}
// InjectableLinkEndpoint is a LinkEndpoint where inbound packets are
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index 2b7ece851..a6faa22c2 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -728,6 +728,11 @@ func New(opts Options) *Stack {
return s
}
+// newJob returns a tcpip.Job using the Stack clock.
+func (s *Stack) newJob(l sync.Locker, f func()) *tcpip.Job {
+ return tcpip.NewJob(s.clock, l, f)
+}
+
// UniqueID returns a unique identifier.
func (s *Stack) UniqueID() uint64 {
return s.uniqueIDGenerator.UniqueID()
@@ -801,9 +806,10 @@ func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h f
}
}
-// NowNanoseconds implements tcpip.Clock.NowNanoseconds.
-func (s *Stack) NowNanoseconds() int64 {
- return s.clock.NowNanoseconds()
+// Clock returns the Stack's clock for retrieving the current time and
+// scheduling work.
+func (s *Stack) Clock() tcpip.Clock {
+ return s.clock
}
// Stats returns a mutable copy of the current stats.
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index 48ad56d4d..21aafb0a2 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -192,7 +192,7 @@ func (e ErrSaveRejection) Error() string {
return "save rejected due to unsupported networking state: " + e.Err.Error()
}
-// A Clock provides the current time.
+// A Clock provides the current time and schedules work for execution.
//
// Times returned by a Clock should always be used for application-visible
// time. Only monotonic times should be used for netstack internal timekeeping.
@@ -203,6 +203,31 @@ type Clock interface {
// NowMonotonic returns a monotonic time value.
NowMonotonic() int64
+
+ // AfterFunc waits for the duration to elapse and then calls f in its own
+ // goroutine. It returns a Timer that can be used to cancel the call using
+ // its Stop method.
+ AfterFunc(d time.Duration, f func()) Timer
+}
+
+// Timer represents a single event. A Timer must be created with
+// Clock.AfterFunc.
+type Timer interface {
+ // Stop prevents the Timer from firing. It returns true if the call stops the
+ // timer, false if the timer has already expired or been stopped.
+ //
+ // If Stop returns false, then the timer has already expired and the function
+ // f of Clock.AfterFunc(d, f) has been started in its own goroutine; Stop
+ // does not wait for f to complete before returning. If the caller needs to
+ // know whether f is completed, it must coordinate with f explicitly.
+ Stop() bool
+
+ // Reset changes the timer to expire after duration d.
+ //
+ // Reset should be invoked only on stopped or expired timers. If the timer is
+ // known to have expired, Reset can be used directly. Otherwise, the caller
+ // must coordinate with the function f of Clock.AfterFunc(d, f).
+ Reset(d time.Duration)
}
// Address is a byte slice cast as a string that represents the address of a
@@ -316,6 +341,28 @@ const (
ShutdownWrite
)
+// PacketType is used to indicate the destination of the packet.
+type PacketType uint8
+
+const (
+ // PacketHost indicates a packet addressed to the local host.
+ PacketHost PacketType = iota
+
+ // PacketOtherHost indicates an outgoing packet addressed to
+ // another host caught by a NIC in promiscuous mode.
+ PacketOtherHost
+
+ // PacketOutgoing for a packet originating from the local host
+ // that is looped back to a packet socket.
+ PacketOutgoing
+
+ // PacketBroadcast indicates a link layer broadcast packet.
+ PacketBroadcast
+
+ // PacketMulticast indicates a link layer multicast packet.
+ PacketMulticast
+)
+
// FullAddress represents a full transport node address, as required by the
// Connect() and Bind() methods.
//
@@ -555,6 +602,9 @@ type Endpoint interface {
type LinkPacketInfo struct {
// Protocol is the NetworkProtocolNumber for the packet.
Protocol NetworkProtocolNumber
+
+ // PktType is used to indicate the destination of the packet.
+ PktType PacketType
}
// PacketEndpoint are additional methods that are only implemented by Packet
diff --git a/pkg/tcpip/time_unsafe.go b/pkg/tcpip/time_unsafe.go
index 7f172f978..f32d58091 100644
--- a/pkg/tcpip/time_unsafe.go
+++ b/pkg/tcpip/time_unsafe.go
@@ -20,7 +20,7 @@
package tcpip
import (
- _ "time" // Used with go:linkname.
+ "time" // Used with go:linkname.
_ "unsafe" // Required for go:linkname.
)
@@ -45,3 +45,31 @@ func (*StdClock) NowMonotonic() int64 {
_, _, mono := now()
return mono
}
+
+// AfterFunc implements Clock.AfterFunc.
+func (*StdClock) AfterFunc(d time.Duration, f func()) Timer {
+ return &stdTimer{
+ t: time.AfterFunc(d, f),
+ }
+}
+
+type stdTimer struct {
+ t *time.Timer
+}
+
+var _ Timer = (*stdTimer)(nil)
+
+// Stop implements Timer.Stop.
+func (st *stdTimer) Stop() bool {
+ return st.t.Stop()
+}
+
+// Reset implements Timer.Reset.
+func (st *stdTimer) Reset(d time.Duration) {
+ st.t.Reset(d)
+}
+
+// NewStdTimer returns a Timer implemented with the time package.
+func NewStdTimer(t *time.Timer) Timer {
+ return &stdTimer{t: t}
+}
diff --git a/pkg/tcpip/timer.go b/pkg/tcpip/timer.go
index 5554c573f..f1dd7c310 100644
--- a/pkg/tcpip/timer.go
+++ b/pkg/tcpip/timer.go
@@ -20,50 +20,49 @@ import (
"gvisor.dev/gvisor/pkg/sync"
)
-// cancellableTimerInstance is a specific instance of CancellableTimer.
+// jobInstance is a specific instance of Job.
//
-// Different instances are created each time CancellableTimer is Reset so each
-// timer has its own earlyReturn signal. This is to address a bug when a
-// CancellableTimer is stopped and reset in quick succession resulting in a
-// timer instance's earlyReturn signal being affected or seen by another timer
-// instance.
+// Different instances are created each time Job is scheduled so each timer has
+// its own earlyReturn signal. This is to address a bug when a Job is stopped
+// and reset in quick succession resulting in a timer instance's earlyReturn
+// signal being affected or seen by another timer instance.
//
// Consider the following sceneario where timer instances share a common
// earlyReturn signal (T1 creates, stops and resets a Cancellable timer under a
// lock L; T2, T3, T4 and T5 are goroutines that handle the first (A), second
// (B), third (C), and fourth (D) instance of the timer firing, respectively):
// T1: Obtain L
-// T1: Create a new CancellableTimer w/ lock L (create instance A)
+// T1: Create a new Job w/ lock L (create instance A)
// T2: instance A fires, blocked trying to obtain L.
// T1: Attempt to stop instance A (set earlyReturn = true)
-// T1: Reset timer (create instance B)
+// T1: Schedule timer (create instance B)
// T3: instance B fires, blocked trying to obtain L.
// T1: Attempt to stop instance B (set earlyReturn = true)
-// T1: Reset timer (create instance C)
+// T1: Schedule timer (create instance C)
// T4: instance C fires, blocked trying to obtain L.
// T1: Attempt to stop instance C (set earlyReturn = true)
-// T1: Reset timer (create instance D)
+// T1: Schedule timer (create instance D)
// T5: instance D fires, blocked trying to obtain L.
// T1: Release L
//
-// Now that T1 has released L, any of the 4 timer instances can take L and check
-// earlyReturn. If the timers simply check earlyReturn and then do nothing
-// further, then instance D will never early return even though it was not
-// requested to stop. If the timers reset earlyReturn before early returning,
-// then all but one of the timers will do work when only one was expected to.
-// If CancellableTimer resets earlyReturn when resetting, then all the timers
+// Now that T1 has released L, any of the 4 timer instances can take L and
+// check earlyReturn. If the timers simply check earlyReturn and then do
+// nothing further, then instance D will never early return even though it was
+// not requested to stop. If the timers reset earlyReturn before early
+// returning, then all but one of the timers will do work when only one was
+// expected to. If Job resets earlyReturn when resetting, then all the timers
// will fire (again, when only one was expected to).
//
// To address the above concerns the simplest solution was to give each timer
// its own earlyReturn signal.
-type cancellableTimerInstance struct {
- timer *time.Timer
+type jobInstance struct {
+ timer Timer
// Used to inform the timer to early return when it gets stopped while the
// lock the timer tries to obtain when fired is held (T1 is a goroutine that
// tries to cancel the timer and T2 is the goroutine that handles the timer
// firing):
- // T1: Obtain the lock, then call StopLocked()
+ // T1: Obtain the lock, then call Cancel()
// T2: timer fires, and gets blocked on obtaining the lock
// T1: Releases lock
// T2: Obtains lock does unintended work
@@ -74,29 +73,33 @@ type cancellableTimerInstance struct {
earlyReturn *bool
}
-// stop stops the timer instance t from firing if it hasn't fired already. If it
+// stop stops the job instance j from firing if it hasn't fired already. If it
// has fired and is blocked at obtaining the lock, earlyReturn will be set to
// true so that it will early return when it obtains the lock.
-func (t *cancellableTimerInstance) stop() {
- if t.timer != nil {
- t.timer.Stop()
- *t.earlyReturn = true
+func (j *jobInstance) stop() {
+ if j.timer != nil {
+ j.timer.Stop()
+ *j.earlyReturn = true
}
}
-// CancellableTimer is a timer that does some work and can be safely cancelled
-// when it fires at the same time some "related work" is being done.
+// Job represents some work that can be scheduled for execution. The work can
+// be safely cancelled when it fires at the same time some "related work" is
+// being done.
//
// The term "related work" is defined as some work that needs to be done while
// holding some lock that the timer must also hold while doing some work.
//
-// Note, it is not safe to copy a CancellableTimer as its timer instance creates
-// a closure over the address of the CancellableTimer.
-type CancellableTimer struct {
+// Note, it is not safe to copy a Job as its timer instance creates
+// a closure over the address of the Job.
+type Job struct {
_ sync.NoCopy
+ // The clock used to schedule the backing timer
+ clock Clock
+
// The active instance of a cancellable timer.
- instance cancellableTimerInstance
+ instance jobInstance
// locker is the lock taken by the timer immediately after it fires and must
// be held when attempting to stop the timer.
@@ -113,59 +116,91 @@ type CancellableTimer struct {
fn func()
}
-// StopLocked prevents the Timer from firing if it has not fired already.
+// Cancel prevents the Job from executing if it has not executed already.
//
-// If the timer is blocked on obtaining the t.locker lock when StopLocked is
-// called, it will early return instead of calling t.fn.
+// Cancel requires appropriate locking to be in place for any resources managed
+// by the Job. If the Job is blocked on obtaining the lock when Cancel is
+// called, it will early return.
//
// Note, t will be modified.
//
-// t.locker MUST be locked.
-func (t *CancellableTimer) StopLocked() {
- t.instance.stop()
+// j.locker MUST be locked.
+func (j *Job) Cancel() {
+ j.instance.stop()
// Nothing to do with the stopped instance anymore.
- t.instance = cancellableTimerInstance{}
+ j.instance = jobInstance{}
}
-// Reset changes the timer to expire after duration d.
+// Schedule schedules the Job for execution after duration d. This can be
+// called on cancelled or completed Jobs to schedule them again.
//
-// Note, t will be modified.
+// Schedule should be invoked only on unscheduled, cancelled, or completed
+// Jobs. To be safe, callers should always call Cancel before calling Schedule.
//
-// Reset should only be called on stopped or expired timers. To be safe, callers
-// should always call StopLocked before calling Reset.
-func (t *CancellableTimer) Reset(d time.Duration) {
+// Note, j will be modified.
+func (j *Job) Schedule(d time.Duration) {
// Create a new instance.
earlyReturn := false
// Capture the locker so that updating the timer does not cause a data race
// when a timer fires and tries to obtain the lock (read the timer's locker).
- locker := t.locker
- t.instance = cancellableTimerInstance{
- timer: time.AfterFunc(d, func() {
+ locker := j.locker
+ j.instance = jobInstance{
+ timer: j.clock.AfterFunc(d, func() {
locker.Lock()
defer locker.Unlock()
if earlyReturn {
// If we reach this point, it means that the timer fired while another
- // goroutine called StopLocked while it had the lock. Simply return
- // here and do nothing further.
+ // goroutine called Cancel while it had the lock. Simply return here
+ // and do nothing further.
earlyReturn = false
return
}
- t.fn()
+ j.fn()
}),
earlyReturn: &earlyReturn,
}
}
-// NewCancellableTimer returns an unscheduled CancellableTimer with the given
-// locker and fn.
-//
-// fn MUST NOT attempt to lock locker.
-//
-// Callers must call Reset to schedule the timer to fire.
-func NewCancellableTimer(locker sync.Locker, fn func()) *CancellableTimer {
- return &CancellableTimer{locker: locker, fn: fn}
+// NewJob returns a new Job that can be used to schedule f to run in its own
+// gorountine. l will be locked before calling f then unlocked after f returns.
+//
+// var clock tcpip.StdClock
+// var mu sync.Mutex
+// message := "foo"
+// job := tcpip.NewJob(&clock, &mu, func() {
+// fmt.Println(message)
+// })
+// job.Schedule(time.Second)
+//
+// mu.Lock()
+// message = "bar"
+// mu.Unlock()
+//
+// // Output: bar
+//
+// f MUST NOT attempt to lock l.
+//
+// l MUST be locked prior to calling the returned job's Cancel().
+//
+// var clock tcpip.StdClock
+// var mu sync.Mutex
+// message := "foo"
+// job := tcpip.NewJob(&clock, &mu, func() {
+// fmt.Println(message)
+// })
+// job.Schedule(time.Second)
+//
+// mu.Lock()
+// job.Cancel()
+// mu.Unlock()
+func NewJob(c Clock, l sync.Locker, f func()) *Job {
+ return &Job{
+ clock: c,
+ locker: l,
+ fn: f,
+ }
}
diff --git a/pkg/tcpip/timer_test.go b/pkg/tcpip/timer_test.go
index b4940e397..a82384c49 100644
--- a/pkg/tcpip/timer_test.go
+++ b/pkg/tcpip/timer_test.go
@@ -28,8 +28,8 @@ const (
longDuration = 1 * time.Second
)
-func TestCancellableTimerReassignment(t *testing.T) {
- var timer tcpip.CancellableTimer
+func TestJobReschedule(t *testing.T) {
+ var clock tcpip.StdClock
var wg sync.WaitGroup
var lock sync.Mutex
@@ -43,26 +43,27 @@ func TestCancellableTimerReassignment(t *testing.T) {
// that has an active timer (even if it has been stopped as a stopped
// timer may be blocked on a lock before it can check if it has been
// stopped while another goroutine holds the same lock).
- timer = *tcpip.NewCancellableTimer(&lock, func() {
+ job := tcpip.NewJob(&clock, &lock, func() {
wg.Done()
})
- timer.Reset(shortDuration)
+ job.Schedule(shortDuration)
lock.Unlock()
}()
}
wg.Wait()
}
-func TestCancellableTimerFire(t *testing.T) {
+func TestJobExecution(t *testing.T) {
t.Parallel()
- ch := make(chan struct{})
+ var clock tcpip.StdClock
var lock sync.Mutex
+ ch := make(chan struct{})
- timer := tcpip.NewCancellableTimer(&lock, func() {
+ job := tcpip.NewJob(&clock, &lock, func() {
ch <- struct{}{}
})
- timer.Reset(shortDuration)
+ job.Schedule(shortDuration)
// Wait for timer to fire.
select {
@@ -82,17 +83,18 @@ func TestCancellableTimerFire(t *testing.T) {
func TestCancellableTimerResetFromLongDuration(t *testing.T) {
t.Parallel()
- ch := make(chan struct{})
+ var clock tcpip.StdClock
var lock sync.Mutex
+ ch := make(chan struct{})
- timer := tcpip.NewCancellableTimer(&lock, func() { ch <- struct{}{} })
- timer.Reset(middleDuration)
+ job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} })
+ job.Schedule(middleDuration)
lock.Lock()
- timer.StopLocked()
+ job.Cancel()
lock.Unlock()
- timer.Reset(shortDuration)
+ job.Schedule(shortDuration)
// Wait for timer to fire.
select {
@@ -109,16 +111,17 @@ func TestCancellableTimerResetFromLongDuration(t *testing.T) {
}
}
-func TestCancellableTimerResetFromShortDuration(t *testing.T) {
+func TestJobRescheduleFromShortDuration(t *testing.T) {
t.Parallel()
- ch := make(chan struct{})
+ var clock tcpip.StdClock
var lock sync.Mutex
+ ch := make(chan struct{})
lock.Lock()
- timer := tcpip.NewCancellableTimer(&lock, func() { ch <- struct{}{} })
- timer.Reset(shortDuration)
- timer.StopLocked()
+ job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} })
+ job.Schedule(shortDuration)
+ job.Cancel()
lock.Unlock()
// Wait for timer to fire if it wasn't correctly stopped.
@@ -128,7 +131,7 @@ func TestCancellableTimerResetFromShortDuration(t *testing.T) {
case <-time.After(middleDuration):
}
- timer.Reset(shortDuration)
+ job.Schedule(shortDuration)
// Wait for timer to fire.
select {
@@ -145,17 +148,18 @@ func TestCancellableTimerResetFromShortDuration(t *testing.T) {
}
}
-func TestCancellableTimerImmediatelyStop(t *testing.T) {
+func TestJobImmediatelyCancel(t *testing.T) {
t.Parallel()
- ch := make(chan struct{})
+ var clock tcpip.StdClock
var lock sync.Mutex
+ ch := make(chan struct{})
for i := 0; i < 1000; i++ {
lock.Lock()
- timer := tcpip.NewCancellableTimer(&lock, func() { ch <- struct{}{} })
- timer.Reset(shortDuration)
- timer.StopLocked()
+ job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} })
+ job.Schedule(shortDuration)
+ job.Cancel()
lock.Unlock()
}
@@ -167,25 +171,26 @@ func TestCancellableTimerImmediatelyStop(t *testing.T) {
}
}
-func TestCancellableTimerStoppedResetWithoutLock(t *testing.T) {
+func TestJobCancelledRescheduleWithoutLock(t *testing.T) {
t.Parallel()
- ch := make(chan struct{})
+ var clock tcpip.StdClock
var lock sync.Mutex
+ ch := make(chan struct{})
lock.Lock()
- timer := tcpip.NewCancellableTimer(&lock, func() { ch <- struct{}{} })
- timer.Reset(shortDuration)
- timer.StopLocked()
+ job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} })
+ job.Schedule(shortDuration)
+ job.Cancel()
lock.Unlock()
for i := 0; i < 10; i++ {
- timer.Reset(middleDuration)
+ job.Schedule(middleDuration)
lock.Lock()
// Sleep until the timer fires and gets blocked trying to take the lock.
time.Sleep(middleDuration * 2)
- timer.StopLocked()
+ job.Cancel()
lock.Unlock()
}
@@ -201,17 +206,18 @@ func TestCancellableTimerStoppedResetWithoutLock(t *testing.T) {
func TestManyCancellableTimerResetAfterBlockedOnLock(t *testing.T) {
t.Parallel()
- ch := make(chan struct{})
+ var clock tcpip.StdClock
var lock sync.Mutex
+ ch := make(chan struct{})
lock.Lock()
- timer := tcpip.NewCancellableTimer(&lock, func() { ch <- struct{}{} })
- timer.Reset(shortDuration)
+ job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} })
+ job.Schedule(shortDuration)
for i := 0; i < 10; i++ {
// Sleep until the timer fires and gets blocked trying to take the lock.
time.Sleep(middleDuration)
- timer.StopLocked()
- timer.Reset(shortDuration)
+ job.Cancel()
+ job.Schedule(shortDuration)
}
lock.Unlock()
@@ -230,18 +236,19 @@ func TestManyCancellableTimerResetAfterBlockedOnLock(t *testing.T) {
}
}
-func TestManyCancellableTimerResetUnderLock(t *testing.T) {
+func TestManyJobReschedulesUnderLock(t *testing.T) {
t.Parallel()
- ch := make(chan struct{})
+ var clock tcpip.StdClock
var lock sync.Mutex
+ ch := make(chan struct{})
lock.Lock()
- timer := tcpip.NewCancellableTimer(&lock, func() { ch <- struct{}{} })
- timer.Reset(shortDuration)
+ job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} })
+ job.Schedule(shortDuration)
for i := 0; i < 10; i++ {
- timer.StopLocked()
- timer.Reset(shortDuration)
+ job.Cancel()
+ job.Schedule(shortDuration)
}
lock.Unlock()
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index 678f4e016..4612be4e7 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -797,7 +797,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk
e.rcvList.PushBack(packet)
e.rcvBufSize += packet.data.Size()
- packet.timestamp = e.stack.NowNanoseconds()
+ packet.timestamp = e.stack.Clock().NowNanoseconds()
e.rcvMu.Unlock()
e.stats.PacketsReceived.Increment()
diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go
index 7b2083a09..0e46e6355 100644
--- a/pkg/tcpip/transport/packet/endpoint.go
+++ b/pkg/tcpip/transport/packet/endpoint.go
@@ -441,6 +441,7 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress,
Addr: tcpip.Address(hdr.SourceAddress()),
}
packet.packetInfo.Protocol = netProto
+ packet.packetInfo.PktType = pkt.PktType
} else {
// Guess the would-be ethernet header.
packet.senderAddr = tcpip.FullAddress{
@@ -448,34 +449,57 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress,
Addr: tcpip.Address(localAddr),
}
packet.packetInfo.Protocol = netProto
+ packet.packetInfo.PktType = pkt.PktType
}
if ep.cooked {
// Cooked packets can simply be queued.
- packet.data = pkt.Data
+ switch pkt.PktType {
+ case tcpip.PacketHost:
+ packet.data = pkt.Data
+ case tcpip.PacketOutgoing:
+ // Strip Link Header from the Header.
+ pkt.Header = buffer.NewPrependableFromView(pkt.Header.View()[len(pkt.LinkHeader):])
+ combinedVV := pkt.Header.View().ToVectorisedView()
+ combinedVV.Append(pkt.Data)
+ packet.data = combinedVV
+ default:
+ panic(fmt.Sprintf("unexpected PktType in pkt: %+v", pkt))
+ }
+
} else {
// Raw packets need their ethernet headers prepended before
// queueing.
var linkHeader buffer.View
- if len(pkt.LinkHeader) == 0 {
- // We weren't provided with an actual ethernet header,
- // so fake one.
- ethFields := header.EthernetFields{
- SrcAddr: tcpip.LinkAddress([]byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00}),
- DstAddr: localAddr,
- Type: netProto,
+ var combinedVV buffer.VectorisedView
+ if pkt.PktType != tcpip.PacketOutgoing {
+ if len(pkt.LinkHeader) == 0 {
+ // We weren't provided with an actual ethernet header,
+ // so fake one.
+ ethFields := header.EthernetFields{
+ SrcAddr: tcpip.LinkAddress([]byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00}),
+ DstAddr: localAddr,
+ Type: netProto,
+ }
+ fakeHeader := make(header.Ethernet, header.EthernetMinimumSize)
+ fakeHeader.Encode(&ethFields)
+ linkHeader = buffer.View(fakeHeader)
+ } else {
+ linkHeader = append(buffer.View(nil), pkt.LinkHeader...)
}
- fakeHeader := make(header.Ethernet, header.EthernetMinimumSize)
- fakeHeader.Encode(&ethFields)
- linkHeader = buffer.View(fakeHeader)
- } else {
- linkHeader = append(buffer.View(nil), pkt.LinkHeader...)
+ combinedVV = linkHeader.ToVectorisedView()
+ }
+ if pkt.PktType == tcpip.PacketOutgoing {
+ // For outgoing packets the Link, Network and Transport
+ // headers are in the pkt.Header fields normally unless
+ // a Raw socket is in use in which case pkt.Header could
+ // be nil.
+ combinedVV.AppendView(pkt.Header.View())
}
- combinedVV := linkHeader.ToVectorisedView()
combinedVV.Append(pkt.Data)
packet.data = combinedVV
}
- packet.timestampNS = ep.stack.NowNanoseconds()
+ packet.timestampNS = ep.stack.Clock().NowNanoseconds()
ep.rcvList.PushBack(&packet)
ep.rcvBufSize += packet.data.Size()
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index c2e9fd29f..f85a68554 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -456,7 +456,7 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error {
defer e.mu.Unlock()
// If a local address was specified, verify that it's valid.
- if e.stack.CheckLocalAddress(addr.NIC, e.NetProto, addr.Addr) == 0 {
+ if len(addr.Addr) != 0 && e.stack.CheckLocalAddress(addr.NIC, e.NetProto, addr.Addr) == 0 {
return tcpip.ErrBadLocalAddress
}
@@ -700,7 +700,7 @@ func (e *endpoint) HandlePacket(route *stack.Route, pkt *stack.PacketBuffer) {
}
combinedVV.Append(pkt.Data)
packet.data = combinedVV
- packet.timestampNS = e.stack.NowNanoseconds()
+ packet.timestampNS = e.stack.Clock().NowNanoseconds()
e.rcvList.PushBack(packet)
e.rcvBufSize += packet.data.Size()
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index 81b740115..1798510bc 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -490,6 +490,9 @@ func (h *handshake) resolveRoute() *tcpip.Error {
<-h.ep.undrain
h.ep.mu.Lock()
}
+ if n&notifyError != 0 {
+ return h.ep.takeLastError()
+ }
}
// Wait for notification.
@@ -616,6 +619,9 @@ func (h *handshake) execute() *tcpip.Error {
<-h.ep.undrain
h.ep.mu.Lock()
}
+ if n&notifyError != 0 {
+ return h.ep.takeLastError()
+ }
case wakerForNewSegment:
if err := h.processSegments(); err != nil {
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 83dc10ed0..0f7487963 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -1209,6 +1209,14 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
e.owner = owner
}
+func (e *endpoint) takeLastError() *tcpip.Error {
+ e.lastErrorMu.Lock()
+ defer e.lastErrorMu.Unlock()
+ err := e.lastError
+ e.lastError = nil
+ return err
+}
+
// Read reads data from the endpoint.
func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
e.LockUser()
@@ -1956,11 +1964,7 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
switch o := opt.(type) {
case tcpip.ErrorOption:
- e.lastErrorMu.Lock()
- err := e.lastError
- e.lastError = nil
- e.lastErrorMu.Unlock()
- return err
+ return e.takeLastError()
case *tcpip.BindToDeviceOption:
e.LockUser()
@@ -2546,6 +2550,18 @@ func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.C
e.sndBufMu.Unlock()
e.notifyProtocolGoroutine(notifyMTUChanged)
+
+ case stack.ControlNoRoute:
+ e.lastErrorMu.Lock()
+ e.lastError = tcpip.ErrNoRoute
+ e.lastErrorMu.Unlock()
+ e.notifyProtocolGoroutine(notifyError)
+
+ case stack.ControlNetworkUnreachable:
+ e.lastErrorMu.Lock()
+ e.lastError = tcpip.ErrNetworkUnreachable
+ e.lastErrorMu.Unlock()
+ e.notifyProtocolGoroutine(notifyError)
}
}
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index a14643ae8..6e692da07 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -1451,7 +1451,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk
packet.tos, _ = header.IPv6(pkt.NetworkHeader).TOS()
}
- packet.timestamp = e.stack.NowNanoseconds()
+ packet.timestamp = e.stack.Clock().NowNanoseconds()
e.rcvMu.Unlock()
diff --git a/pkg/test/dockerutil/exec.go b/pkg/test/dockerutil/exec.go
index 921d1da9e..4c739c9e9 100644
--- a/pkg/test/dockerutil/exec.go
+++ b/pkg/test/dockerutil/exec.go
@@ -87,7 +87,6 @@ func (c *Container) doExec(ctx context.Context, r ExecOpts, args []string) (Proc
execid: resp.ID,
conn: hijack,
}, nil
-
}
func (c *Container) execConfig(r ExecOpts, cmd []string) types.ExecConfig {