diff options
Diffstat (limited to 'pkg')
217 files changed, 9689 insertions, 1836 deletions
diff --git a/pkg/abi/linux/BUILD b/pkg/abi/linux/BUILD index 2b789c4ec..a4bb62013 100644 --- a/pkg/abi/linux/BUILD +++ b/pkg/abi/linux/BUILD @@ -72,6 +72,9 @@ go_library( "//pkg/abi", "//pkg/binary", "//pkg/bits", + "//pkg/usermem", + "//tools/go_marshal/marshal", + "//tools/go_marshal/primitive", ], ) diff --git a/pkg/abi/linux/futex.go b/pkg/abi/linux/futex.go index 08bfde3b5..8138088a6 100644 --- a/pkg/abi/linux/futex.go +++ b/pkg/abi/linux/futex.go @@ -60,3 +60,21 @@ const ( FUTEX_WAITERS = 0x80000000 FUTEX_OWNER_DIED = 0x40000000 ) + +// FUTEX_BITSET_MATCH_ANY has all bits set. +const FUTEX_BITSET_MATCH_ANY = 0xffffffff + +// ROBUST_LIST_LIMIT protects against a deliberately circular list. +const ROBUST_LIST_LIMIT = 2048 + +// RobustListHead corresponds to Linux's struct robust_list_head. +// +// +marshal +type RobustListHead struct { + List uint64 + FutexOffset uint64 + ListOpPending uint64 +} + +// SizeOfRobustListHead is the size of a RobustListHead struct. +var SizeOfRobustListHead = (*RobustListHead)(nil).SizeBytes() diff --git a/pkg/abi/linux/ioctl.go b/pkg/abi/linux/ioctl.go index 2062e6a4b..2c5e56ae5 100644 --- a/pkg/abi/linux/ioctl.go +++ b/pkg/abi/linux/ioctl.go @@ -67,10 +67,29 @@ const ( // ioctl(2) requests provided by uapi/linux/sockios.h const ( - SIOCGIFMEM = 0x891f - SIOCGIFPFLAGS = 0x8935 - SIOCGMIIPHY = 0x8947 - SIOCGMIIREG = 0x8948 + SIOCGIFNAME = 0x8910 + SIOCGIFCONF = 0x8912 + SIOCGIFFLAGS = 0x8913 + SIOCGIFADDR = 0x8915 + SIOCGIFDSTADDR = 0x8917 + SIOCGIFBRDADDR = 0x8919 + SIOCGIFNETMASK = 0x891b + SIOCGIFMETRIC = 0x891d + SIOCGIFMTU = 0x8921 + SIOCGIFMEM = 0x891f + SIOCGIFHWADDR = 0x8927 + SIOCGIFINDEX = 0x8933 + SIOCGIFPFLAGS = 0x8935 + SIOCGIFTXQLEN = 0x8942 + SIOCETHTOOL = 0x8946 + SIOCGMIIPHY = 0x8947 + SIOCGMIIREG = 0x8948 + SIOCGIFMAP = 0x8970 +) + +// ioctl(2) requests provided by uapi/asm-generic/sockios.h +const ( + SIOCGSTAMP = 0x8906 ) // ioctl(2) directions. Used to calculate requests number. diff --git a/pkg/abi/linux/netfilter.go b/pkg/abi/linux/netfilter.go index 46d8b0b42..a91f9f018 100644 --- a/pkg/abi/linux/netfilter.go +++ b/pkg/abi/linux/netfilter.go @@ -14,6 +14,14 @@ package linux +import ( + "io" + + "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/tools/go_marshal/marshal" + "gvisor.dev/gvisor/tools/go_marshal/primitive" +) + // This file contains structures required to support netfilter, specifically // the iptables tool. @@ -76,6 +84,8 @@ const ( // IPTEntry is an iptable rule. It corresponds to struct ipt_entry in // include/uapi/linux/netfilter_ipv4/ip_tables.h. +// +// +marshal type IPTEntry struct { // IP is used to filter packets based on the IP header. IP IPTIP @@ -112,21 +122,41 @@ type IPTEntry struct { // SizeOfIPTEntry is the size of an IPTEntry. const SizeOfIPTEntry = 112 -// KernelIPTEntry is identical to IPTEntry, but includes the Elems field. This -// struct marshaled via the binary package to write an IPTEntry to userspace. +// KernelIPTEntry is identical to IPTEntry, but includes the Elems field. +// KernelIPTEntry itself is not Marshallable but it implements some methods of +// marshal.Marshallable that help in other implementations of Marshallable. type KernelIPTEntry struct { - IPTEntry + Entry IPTEntry // Elems holds the data for all this rule's matches followed by the // target. It is variable length -- users have to iterate over any // matches and use TargetOffset and NextOffset to make sense of the // data. - Elems []byte + Elems primitive.ByteSlice +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (ke *KernelIPTEntry) SizeBytes() int { + return ke.Entry.SizeBytes() + ke.Elems.SizeBytes() +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (ke *KernelIPTEntry) MarshalBytes(dst []byte) { + ke.Entry.MarshalBytes(dst) + ke.Elems.MarshalBytes(dst[ke.Entry.SizeBytes():]) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (ke *KernelIPTEntry) UnmarshalBytes(src []byte) { + ke.Entry.UnmarshalBytes(src) + ke.Elems.UnmarshalBytes(src[ke.Entry.SizeBytes():]) } // IPTIP contains information for matching a packet's IP header. // It corresponds to struct ipt_ip in // include/uapi/linux/netfilter_ipv4/ip_tables.h. +// +// +marshal type IPTIP struct { // Src is the source IP address. Src InetAddr @@ -189,6 +219,8 @@ const SizeOfIPTIP = 84 // XTCounters holds packet and byte counts for a rule. It corresponds to struct // xt_counters in include/uapi/linux/netfilter/x_tables.h. +// +// +marshal type XTCounters struct { // Pcnt is the packet count. Pcnt uint64 @@ -321,6 +353,8 @@ const SizeOfXTRedirectTarget = 56 // IPTGetinfo is the argument for the IPT_SO_GET_INFO sockopt. It corresponds // to struct ipt_getinfo in include/uapi/linux/netfilter_ipv4/ip_tables.h. +// +// +marshal type IPTGetinfo struct { Name TableName ValidHooks uint32 @@ -336,6 +370,8 @@ const SizeOfIPTGetinfo = 84 // IPTGetEntries is the argument for the IPT_SO_GET_ENTRIES sockopt. It // corresponds to struct ipt_get_entries in // include/uapi/linux/netfilter_ipv4/ip_tables.h. +// +// +marshal type IPTGetEntries struct { Name TableName Size uint32 @@ -350,13 +386,103 @@ type IPTGetEntries struct { const SizeOfIPTGetEntries = 40 // KernelIPTGetEntries is identical to IPTGetEntries, but includes the -// Entrytable field. This struct marshaled via the binary package to write an -// KernelIPTGetEntries to userspace. +// Entrytable field. This has been manually made marshal.Marshallable since it +// is dynamically sized. type KernelIPTGetEntries struct { IPTGetEntries Entrytable []KernelIPTEntry } +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (ke *KernelIPTGetEntries) SizeBytes() int { + res := ke.IPTGetEntries.SizeBytes() + for _, entry := range ke.Entrytable { + res += entry.SizeBytes() + } + return res +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (ke *KernelIPTGetEntries) MarshalBytes(dst []byte) { + ke.IPTGetEntries.MarshalBytes(dst) + marshalledUntil := ke.IPTGetEntries.SizeBytes() + for i := 0; i < len(ke.Entrytable); i++ { + ke.Entrytable[i].MarshalBytes(dst[marshalledUntil:]) + marshalledUntil += ke.Entrytable[i].SizeBytes() + } +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (ke *KernelIPTGetEntries) UnmarshalBytes(src []byte) { + ke.IPTGetEntries.UnmarshalBytes(src) + unmarshalledUntil := ke.IPTGetEntries.SizeBytes() + for i := 0; i < len(ke.Entrytable); i++ { + ke.Entrytable[i].UnmarshalBytes(src[unmarshalledUntil:]) + unmarshalledUntil += ke.Entrytable[i].SizeBytes() + } +} + +// Packed implements marshal.Marshallable.Packed. +func (ke *KernelIPTGetEntries) Packed() bool { + // KernelIPTGetEntries isn't packed because the ke.Entrytable contains an + // indirection to the actual data we want to marshal (the slice data + // pointer), and the memory for KernelIPTGetEntries contains the slice + // header which we don't want to marshal. + return false +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (ke *KernelIPTGetEntries) MarshalUnsafe(dst []byte) { + // Fall back to safe Marshal because the type in not packed. + ke.MarshalBytes(dst) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (ke *KernelIPTGetEntries) UnmarshalUnsafe(src []byte) { + // Fall back to safe Unmarshal because the type in not packed. + ke.UnmarshalBytes(src) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +func (ke *KernelIPTGetEntries) CopyIn(task marshal.Task, addr usermem.Addr) (int, error) { + buf := task.CopyScratchBuffer(ke.SizeBytes()) // escapes: okay. + length, err := task.CopyInBytes(addr, buf) // escapes: okay. + // Unmarshal unconditionally. If we had a short copy-in, this results in a + // partially unmarshalled struct. + ke.UnmarshalBytes(buf) // escapes: fallback. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +func (ke *KernelIPTGetEntries) CopyOut(task marshal.Task, addr usermem.Addr) (int, error) { + // Type KernelIPTGetEntries doesn't have a packed layout in memory, fall + // back to MarshalBytes. + return task.CopyOutBytes(addr, ke.marshalAll(task)) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +func (ke *KernelIPTGetEntries) CopyOutN(task marshal.Task, addr usermem.Addr, limit int) (int, error) { + // Type KernelIPTGetEntries doesn't have a packed layout in memory, fall + // back to MarshalBytes. + return task.CopyOutBytes(addr, ke.marshalAll(task)[:limit]) +} + +func (ke *KernelIPTGetEntries) marshalAll(task marshal.Task) []byte { + buf := task.CopyScratchBuffer(ke.SizeBytes()) + ke.MarshalBytes(buf) + return buf +} + +// WriteTo implements io.WriterTo.WriteTo. +func (ke *KernelIPTGetEntries) WriteTo(w io.Writer) (int64, error) { + buf := make([]byte, ke.SizeBytes()) + ke.MarshalBytes(buf) + length, err := w.Write(buf) + return int64(length), err +} + +var _ marshal.Marshallable = (*KernelIPTGetEntries)(nil) + // IPTReplace is the argument for the IPT_SO_SET_REPLACE sockopt. It // corresponds to struct ipt_replace in // include/uapi/linux/netfilter_ipv4/ip_tables.h. @@ -374,12 +500,6 @@ type IPTReplace struct { // Entries [0]IPTEntry } -// KernelIPTReplace is identical to IPTReplace, but includes the Entries field. -type KernelIPTReplace struct { - IPTReplace - Entries [0]IPTEntry -} - // SizeOfIPTReplace is the size of an IPTReplace. const SizeOfIPTReplace = 96 @@ -392,6 +512,8 @@ func (en ExtensionName) String() string { } // TableName holds the name of a netfilter table. +// +// +marshal type TableName [XT_TABLE_MAXNAMELEN]byte // String implements fmt.Stringer. diff --git a/pkg/abi/linux/netlink_route.go b/pkg/abi/linux/netlink_route.go index 40bec566c..ceda0a8d3 100644 --- a/pkg/abi/linux/netlink_route.go +++ b/pkg/abi/linux/netlink_route.go @@ -187,6 +187,8 @@ const ( // Device types, from uapi/linux/if_arp.h. const ( + ARPHRD_NONE = 65534 + ARPHRD_ETHER = 1 ARPHRD_LOOPBACK = 772 ) diff --git a/pkg/abi/linux/socket.go b/pkg/abi/linux/socket.go index 4a14ef691..c24a8216e 100644 --- a/pkg/abi/linux/socket.go +++ b/pkg/abi/linux/socket.go @@ -134,6 +134,15 @@ const ( SHUT_RDWR = 2 ) +// Packet types from <linux/if_packet.h> +const ( + PACKET_HOST = 0 // To us + PACKET_BROADCAST = 1 // To all + PACKET_MULTICAST = 2 // To group + PACKET_OTHERHOST = 3 // To someone else + PACKET_OUTGOING = 4 // Outgoing of any type +) + // Socket options from socket.h. const ( SO_DEBUG = 1 @@ -225,6 +234,8 @@ const ( const SockAddrMax = 128 // InetAddr is struct in_addr, from uapi/linux/in.h. +// +// +marshal type InetAddr [4]byte // SockAddrInet is struct sockaddr_in, from uapi/linux/in.h. @@ -294,6 +305,8 @@ func (s *SockAddrUnix) implementsSockAddr() {} func (s *SockAddrNetlink) implementsSockAddr() {} // Linger is struct linger, from include/linux/socket.h. +// +// +marshal type Linger struct { OnOff int32 Linger int32 @@ -308,6 +321,8 @@ const SizeOfLinger = 8 // the end of this struct or within existing unusued space, so its size grows // over time. The current iteration is based on linux v4.17. New versions are // always backwards compatible. +// +// +marshal type TCPInfo struct { State uint8 CaState uint8 @@ -405,6 +420,8 @@ var SizeOfControlMessageHeader = int(binary.Size(ControlMessageHeader{})) // A ControlMessageCredentials is an SCM_CREDENTIALS socket control message. // // ControlMessageCredentials represents struct ucred from linux/socket.h. +// +// +marshal type ControlMessageCredentials struct { PID int32 UID uint32 diff --git a/pkg/iovec/BUILD b/pkg/iovec/BUILD new file mode 100644 index 000000000..eda82cfc1 --- /dev/null +++ b/pkg/iovec/BUILD @@ -0,0 +1,18 @@ +load("//tools:defs.bzl", "go_library", "go_test") + +package(licenses = ["notice"]) + +go_library( + name = "iovec", + srcs = ["iovec.go"], + visibility = ["//:sandbox"], + deps = ["//pkg/abi/linux"], +) + +go_test( + name = "iovec_test", + size = "small", + srcs = ["iovec_test.go"], + library = ":iovec", + deps = ["@org_golang_x_sys//unix:go_default_library"], +) diff --git a/pkg/iovec/iovec.go b/pkg/iovec/iovec.go new file mode 100644 index 000000000..dd70fe80f --- /dev/null +++ b/pkg/iovec/iovec.go @@ -0,0 +1,75 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build linux + +// Package iovec provides helpers to interact with vectorized I/O on host +// system. +package iovec + +import ( + "syscall" + + "gvisor.dev/gvisor/pkg/abi/linux" +) + +// MaxIovs is the maximum number of iovecs host platform can accept. +var MaxIovs = linux.UIO_MAXIOV + +// Builder is a builder for slice of syscall.Iovec. +type Builder struct { + iovec []syscall.Iovec + storage [8]syscall.Iovec + + // overflow tracks the last buffer when iovec length is at MaxIovs. + overflow []byte +} + +// Add adds buf to b preparing to be written. Zero-length buf won't be added. +func (b *Builder) Add(buf []byte) { + if len(buf) == 0 { + return + } + if b.iovec == nil { + b.iovec = b.storage[:0] + } + if len(b.iovec) >= MaxIovs { + b.addByAppend(buf) + return + } + b.iovec = append(b.iovec, syscall.Iovec{ + Base: &buf[0], + Len: uint64(len(buf)), + }) + // Keep the last buf if iovec is at max capacity. We will need to append to it + // for later bufs. + if len(b.iovec) == MaxIovs { + n := len(buf) + b.overflow = buf[:n:n] + } +} + +func (b *Builder) addByAppend(buf []byte) { + b.overflow = append(b.overflow, buf...) + b.iovec[len(b.iovec)-1] = syscall.Iovec{ + Base: &b.overflow[0], + Len: uint64(len(b.overflow)), + } +} + +// Build returns the final Iovec slice. The length of returned iovec will not +// excceed MaxIovs. +func (b *Builder) Build() []syscall.Iovec { + return b.iovec +} diff --git a/pkg/iovec/iovec_test.go b/pkg/iovec/iovec_test.go new file mode 100644 index 000000000..a3900c299 --- /dev/null +++ b/pkg/iovec/iovec_test.go @@ -0,0 +1,121 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build linux + +package iovec + +import ( + "bytes" + "fmt" + "syscall" + "testing" + "unsafe" + + "golang.org/x/sys/unix" +) + +func TestBuilderEmpty(t *testing.T) { + var builder Builder + iovecs := builder.Build() + if got, want := len(iovecs), 0; got != want { + t.Errorf("len(iovecs) = %d, want %d", got, want) + } +} + +func TestBuilderBuild(t *testing.T) { + a := []byte{1, 2} + b := []byte{3, 4, 5} + + var builder Builder + builder.Add(a) + builder.Add(b) + builder.Add(nil) // Nil slice won't be added. + builder.Add([]byte{}) // Empty slice won't be added. + iovecs := builder.Build() + + if got, want := len(iovecs), 2; got != want { + t.Fatalf("len(iovecs) = %d, want %d", got, want) + } + for i, data := range [][]byte{a, b} { + if got, want := *iovecs[i].Base, data[0]; got != want { + t.Fatalf("*iovecs[%d].Base = %d, want %d", i, got, want) + } + if got, want := iovecs[i].Len, uint64(len(data)); got != want { + t.Fatalf("iovecs[%d].Len = %d, want %d", i, got, want) + } + } +} + +func TestBuilderBuildMaxIov(t *testing.T) { + for _, test := range []struct { + numIov int + }{ + { + numIov: MaxIovs - 1, + }, + { + numIov: MaxIovs, + }, + { + numIov: MaxIovs + 1, + }, + { + numIov: MaxIovs + 10, + }, + } { + name := fmt.Sprintf("numIov=%v", test.numIov) + t.Run(name, func(t *testing.T) { + var data []byte + var builder Builder + for i := 0; i < test.numIov; i++ { + buf := []byte{byte(i)} + builder.Add(buf) + data = append(data, buf...) + } + iovec := builder.Build() + + // Check the expected length of iovec. + wantNum := test.numIov + if wantNum > MaxIovs { + wantNum = MaxIovs + } + if got, want := len(iovec), wantNum; got != want { + t.Errorf("len(iovec) = %d, want %d", got, want) + } + + // Test a real read-write. + var fds [2]int + if err := unix.Pipe(fds[:]); err != nil { + t.Fatalf("Pipe: %v", err) + } + defer syscall.Close(fds[0]) + defer syscall.Close(fds[1]) + + wrote, _, e := syscall.RawSyscall(syscall.SYS_WRITEV, uintptr(fds[1]), uintptr(unsafe.Pointer(&iovec[0])), uintptr(len(iovec))) + if int(wrote) != len(data) || e != 0 { + t.Fatalf("writev: %v, %v; want %v, 0", wrote, e, len(data)) + } + + got := make([]byte, len(data)) + if n, err := syscall.Read(fds[0], got); n != len(got) || err != nil { + t.Fatalf("read: %v, %v; want %v, nil", n, err, len(got)) + } + + if !bytes.Equal(got, data) { + t.Errorf("read: got data %v, want %v", got, data) + } + }) + } +} diff --git a/pkg/p9/messages.go b/pkg/p9/messages.go index 57b89ad7d..2cb59f934 100644 --- a/pkg/p9/messages.go +++ b/pkg/p9/messages.go @@ -2506,7 +2506,7 @@ type msgFactory struct { var msgRegistry registry type registry struct { - factories [math.MaxUint8]msgFactory + factories [math.MaxUint8 + 1]msgFactory // largestFixedSize is computed so that given some message size M, you can // compute the maximum payload size (e.g. for Twrite, Rread) with diff --git a/pkg/sentry/arch/arch_aarch64.go b/pkg/sentry/arch/arch_aarch64.go index daba8b172..fd95eb2d2 100644 --- a/pkg/sentry/arch/arch_aarch64.go +++ b/pkg/sentry/arch/arch_aarch64.go @@ -28,7 +28,14 @@ import ( ) // Registers represents the CPU registers for this architecture. -type Registers = linux.PtraceRegs +// +// +stateify savable +type Registers struct { + linux.PtraceRegs + + // TPIDR_EL0 is the EL0 Read/Write Software Thread ID Register. + TPIDR_EL0 uint64 +} const ( // SyscallWidth is the width of insturctions. @@ -101,9 +108,6 @@ type State struct { // Our floating point state. aarch64FPState `state:"wait"` - // TLS pointer - TPValue uint64 - // FeatureSet is a pointer to the currently active feature set. FeatureSet *cpuid.FeatureSet @@ -157,7 +161,6 @@ func (s *State) Fork() State { return State{ Regs: s.Regs, aarch64FPState: s.aarch64FPState.fork(), - TPValue: s.TPValue, FeatureSet: s.FeatureSet, OrigR0: s.OrigR0, } @@ -241,18 +244,18 @@ func (s *State) ptraceGetRegs() Registers { return s.Regs } -var registersSize = (*Registers)(nil).SizeBytes() +var ptraceRegistersSize = (*linux.PtraceRegs)(nil).SizeBytes() // PtraceSetRegs implements Context.PtraceSetRegs. func (s *State) PtraceSetRegs(src io.Reader) (int, error) { var regs Registers - buf := make([]byte, registersSize) + buf := make([]byte, ptraceRegistersSize) if _, err := io.ReadFull(src, buf); err != nil { return 0, err } regs.UnmarshalUnsafe(buf) s.Regs = regs - return registersSize, nil + return ptraceRegistersSize, nil } // PtraceGetFPRegs implements Context.PtraceGetFPRegs. @@ -278,7 +281,7 @@ const ( func (s *State) PtraceGetRegSet(regset uintptr, dst io.Writer, maxlen int) (int, error) { switch regset { case _NT_PRSTATUS: - if maxlen < registersSize { + if maxlen < ptraceRegistersSize { return 0, syserror.EFAULT } return s.PtraceGetRegs(dst) @@ -291,7 +294,7 @@ func (s *State) PtraceGetRegSet(regset uintptr, dst io.Writer, maxlen int) (int, func (s *State) PtraceSetRegSet(regset uintptr, src io.Reader, maxlen int) (int, error) { switch regset { case _NT_PRSTATUS: - if maxlen < registersSize { + if maxlen < ptraceRegistersSize { return 0, syserror.EFAULT } return s.PtraceSetRegs(src) diff --git a/pkg/sentry/arch/arch_amd64.go b/pkg/sentry/arch/arch_amd64.go index 3b3a0a272..1c3e3c14c 100644 --- a/pkg/sentry/arch/arch_amd64.go +++ b/pkg/sentry/arch/arch_amd64.go @@ -300,7 +300,7 @@ func (c *context64) PtracePeekUser(addr uintptr) (interface{}, error) { // PTRACE_PEEKUSER and PTRACE_POKEUSER are only effective on regs and // u_debugreg, returning 0 or silently no-oping for other fields // respectively. - if addr < uintptr(registersSize) { + if addr < uintptr(ptraceRegistersSize) { regs := c.ptraceGetRegs() buf := make([]byte, regs.SizeBytes()) regs.MarshalUnsafe(buf) @@ -315,7 +315,7 @@ func (c *context64) PtracePokeUser(addr, data uintptr) error { if addr&7 != 0 || addr >= userStructSize { return syscall.EIO } - if addr < uintptr(registersSize) { + if addr < uintptr(ptraceRegistersSize) { regs := c.ptraceGetRegs() buf := make([]byte, regs.SizeBytes()) regs.MarshalUnsafe(buf) diff --git a/pkg/sentry/arch/arch_arm64.go b/pkg/sentry/arch/arch_arm64.go index ada7ac7b8..cabbf60e0 100644 --- a/pkg/sentry/arch/arch_arm64.go +++ b/pkg/sentry/arch/arch_arm64.go @@ -142,7 +142,7 @@ func (c *context64) SetStack(value uintptr) { // TLS returns the current TLS pointer. func (c *context64) TLS() uintptr { - return uintptr(c.TPValue) + return uintptr(c.Regs.TPIDR_EL0) } // SetTLS sets the current TLS pointer. Returns false if value is invalid. @@ -151,7 +151,7 @@ func (c *context64) SetTLS(value uintptr) bool { return false } - c.TPValue = uint64(value) + c.Regs.TPIDR_EL0 = uint64(value) return true } diff --git a/pkg/sentry/arch/arch_x86.go b/pkg/sentry/arch/arch_x86.go index dc458b37f..b9405b320 100644 --- a/pkg/sentry/arch/arch_x86.go +++ b/pkg/sentry/arch/arch_x86.go @@ -31,7 +31,11 @@ import ( ) // Registers represents the CPU registers for this architecture. -type Registers = linux.PtraceRegs +// +// +stateify savable +type Registers struct { + linux.PtraceRegs +} // System-related constants for x86. const ( @@ -311,12 +315,12 @@ func (s *State) ptraceGetRegs() Registers { return regs } -var registersSize = (*Registers)(nil).SizeBytes() +var ptraceRegistersSize = (*linux.PtraceRegs)(nil).SizeBytes() // PtraceSetRegs implements Context.PtraceSetRegs. func (s *State) PtraceSetRegs(src io.Reader) (int, error) { var regs Registers - buf := make([]byte, registersSize) + buf := make([]byte, ptraceRegistersSize) if _, err := io.ReadFull(src, buf); err != nil { return 0, err } @@ -374,7 +378,7 @@ func (s *State) PtraceSetRegs(src io.Reader) (int, error) { } regs.Eflags = (s.Regs.Eflags &^ eflagsPtraceMutable) | (regs.Eflags & eflagsPtraceMutable) s.Regs = regs - return registersSize, nil + return ptraceRegistersSize, nil } // isUserSegmentSelector returns true if the given segment selector specifies a @@ -543,7 +547,7 @@ const ( func (s *State) PtraceGetRegSet(regset uintptr, dst io.Writer, maxlen int) (int, error) { switch regset { case _NT_PRSTATUS: - if maxlen < registersSize { + if maxlen < ptraceRegistersSize { return 0, syserror.EFAULT } return s.PtraceGetRegs(dst) @@ -563,7 +567,7 @@ func (s *State) PtraceGetRegSet(regset uintptr, dst io.Writer, maxlen int) (int, func (s *State) PtraceSetRegSet(regset uintptr, src io.Reader, maxlen int) (int, error) { switch regset { case _NT_PRSTATUS: - if maxlen < registersSize { + if maxlen < ptraceRegistersSize { return 0, syserror.EFAULT } return s.PtraceSetRegs(src) diff --git a/pkg/sentry/fs/host/BUILD b/pkg/sentry/fs/host/BUILD index aabce6cc9..d41d23a43 100644 --- a/pkg/sentry/fs/host/BUILD +++ b/pkg/sentry/fs/host/BUILD @@ -30,6 +30,7 @@ go_library( "//pkg/context", "//pkg/fd", "//pkg/fdnotifier", + "//pkg/iovec", "//pkg/log", "//pkg/refs", "//pkg/safemem", diff --git a/pkg/sentry/fs/host/socket_iovec.go b/pkg/sentry/fs/host/socket_iovec.go index 5c18dbd5e..905afb50d 100644 --- a/pkg/sentry/fs/host/socket_iovec.go +++ b/pkg/sentry/fs/host/socket_iovec.go @@ -17,15 +17,12 @@ package host import ( "syscall" - "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/iovec" "gvisor.dev/gvisor/pkg/syserror" ) // LINT.IfChange -// maxIovs is the maximum number of iovecs to pass to the host. -var maxIovs = linux.UIO_MAXIOV - // copyToMulti copies as many bytes from src to dst as possible. func copyToMulti(dst [][]byte, src []byte) { for _, d := range dst { @@ -76,7 +73,7 @@ func buildIovec(bufs [][]byte, maxlen int64, truncate bool) (length int64, iovec } } - if iovsRequired > maxIovs { + if iovsRequired > iovec.MaxIovs { // The kernel will reject our call if we pass this many iovs. // Use a single intermediate buffer instead. b := make([]byte, stopLen) diff --git a/pkg/sentry/fsimpl/devpts/master.go b/pkg/sentry/fsimpl/devpts/master.go index 69879498a..1081fff52 100644 --- a/pkg/sentry/fsimpl/devpts/master.go +++ b/pkg/sentry/fsimpl/devpts/master.go @@ -67,8 +67,8 @@ func (mi *masterInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vf } // Stat implements kernfs.Inode.Stat. -func (mi *masterInode) Stat(vfsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) { - statx, err := mi.InodeAttrs.Stat(vfsfs, opts) +func (mi *masterInode) Stat(ctx context.Context, vfsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) { + statx, err := mi.InodeAttrs.Stat(ctx, vfsfs, opts) if err != nil { return linux.Statx{}, err } @@ -186,7 +186,7 @@ func (mfd *masterFileDescription) SetStat(ctx context.Context, opts vfs.SetStatO // Stat implements vfs.FileDescriptionImpl.Stat. func (mfd *masterFileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) { fs := mfd.vfsfd.VirtualDentry().Mount().Filesystem() - return mfd.inode.Stat(fs, opts) + return mfd.inode.Stat(ctx, fs, opts) } // LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. diff --git a/pkg/sentry/fsimpl/devpts/slave.go b/pkg/sentry/fsimpl/devpts/slave.go index cf1a0f0ac..a91cae3ef 100644 --- a/pkg/sentry/fsimpl/devpts/slave.go +++ b/pkg/sentry/fsimpl/devpts/slave.go @@ -73,8 +73,8 @@ func (si *slaveInode) Valid(context.Context) bool { } // Stat implements kernfs.Inode.Stat. -func (si *slaveInode) Stat(vfsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) { - statx, err := si.InodeAttrs.Stat(vfsfs, opts) +func (si *slaveInode) Stat(ctx context.Context, vfsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) { + statx, err := si.InodeAttrs.Stat(ctx, vfsfs, opts) if err != nil { return linux.Statx{}, err } @@ -132,7 +132,7 @@ func (sfd *slaveFileDescription) Write(ctx context.Context, src usermem.IOSequen return sfd.inode.t.ld.outputQueueWrite(ctx, src) } -// Ioctl implements vfs.FileDescripionImpl.Ioctl. +// Ioctl implements vfs.FileDescriptionImpl.Ioctl. func (sfd *slaveFileDescription) Ioctl(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) { switch cmd := args[1].Uint(); cmd { case linux.FIONREAD: // linux.FIONREAD == linux.TIOCINQ @@ -183,7 +183,7 @@ func (sfd *slaveFileDescription) SetStat(ctx context.Context, opts vfs.SetStatOp // Stat implements vfs.FileDescriptionImpl.Stat. func (sfd *slaveFileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) { fs := sfd.vfsfd.VirtualDentry().Mount().Filesystem() - return sfd.inode.Stat(fs, opts) + return sfd.inode.Stat(ctx, fs, opts) } // LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. diff --git a/pkg/sentry/fsimpl/ext/BUILD b/pkg/sentry/fsimpl/ext/BUILD index ef24f8159..abc610ef3 100644 --- a/pkg/sentry/fsimpl/ext/BUILD +++ b/pkg/sentry/fsimpl/ext/BUILD @@ -96,7 +96,7 @@ go_test( "//pkg/syserror", "//pkg/test/testutil", "//pkg/usermem", - "@com_github_google_go-cmp//cmp:go_default_library", - "@com_github_google_go-cmp//cmp/cmpopts:go_default_library", + "@com_github_google_go_cmp//cmp:go_default_library", + "@com_github_google_go_cmp//cmp/cmpopts:go_default_library", ], ) diff --git a/pkg/sentry/fsimpl/fuse/BUILD b/pkg/sentry/fsimpl/fuse/BUILD index 41567967d..737007748 100644 --- a/pkg/sentry/fsimpl/fuse/BUILD +++ b/pkg/sentry/fsimpl/fuse/BUILD @@ -6,12 +6,17 @@ go_library( name = "fuse", srcs = [ "dev.go", + "fusefs.go", ], visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", "//pkg/context", + "//pkg/log", "//pkg/sentry/fsimpl/devtmpfs", + "//pkg/sentry/fsimpl/kernfs", + "//pkg/sentry/kernel", + "//pkg/sentry/kernel/auth", "//pkg/sentry/vfs", "//pkg/syserror", "//pkg/usermem", diff --git a/pkg/sentry/fsimpl/fuse/dev.go b/pkg/sentry/fsimpl/fuse/dev.go index f6a67d005..c9e12a94f 100644 --- a/pkg/sentry/fsimpl/fuse/dev.go +++ b/pkg/sentry/fsimpl/fuse/dev.go @@ -18,6 +18,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/fsimpl/devtmpfs" + "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" @@ -30,6 +31,10 @@ type fuseDevice struct{} // Open implements vfs.Device.Open. func (fuseDevice) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { + if !kernel.FUSEEnabled { + return nil, syserror.ENOENT + } + var fd DeviceFD if err := fd.vfsfd.Init(&fd, opts.Flags, mnt, vfsd, &vfs.FileDescriptionOptions{ UseDentryMetadata: true, @@ -46,6 +51,9 @@ type DeviceFD struct { vfs.DentryMetadataFileDescriptionImpl vfs.NoLockFD + // mounted specifies whether a FUSE filesystem was mounted using the DeviceFD. + mounted bool + // TODO(gvisor.dev/issue/2987): Add all the data structures needed to enqueue // and deque requests, control synchronization and establish communication // between the FUSE kernel module and the /dev/fuse character device. @@ -56,26 +64,51 @@ func (fd *DeviceFD) Release() {} // PRead implements vfs.FileDescriptionImpl.PRead. func (fd *DeviceFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { + // Operations on /dev/fuse don't make sense until a FUSE filesystem is mounted. + if !fd.mounted { + return 0, syserror.EPERM + } + return 0, syserror.ENOSYS } // Read implements vfs.FileDescriptionImpl.Read. func (fd *DeviceFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) { + // Operations on /dev/fuse don't make sense until a FUSE filesystem is mounted. + if !fd.mounted { + return 0, syserror.EPERM + } + return 0, syserror.ENOSYS } // PWrite implements vfs.FileDescriptionImpl.PWrite. func (fd *DeviceFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) { + // Operations on /dev/fuse don't make sense until a FUSE filesystem is mounted. + if !fd.mounted { + return 0, syserror.EPERM + } + return 0, syserror.ENOSYS } // Write implements vfs.FileDescriptionImpl.Write. func (fd *DeviceFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) { + // Operations on /dev/fuse don't make sense until a FUSE filesystem is mounted. + if !fd.mounted { + return 0, syserror.EPERM + } + return 0, syserror.ENOSYS } // Seek implements vfs.FileDescriptionImpl.Seek. func (fd *DeviceFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) { + // Operations on /dev/fuse don't make sense until a FUSE filesystem is mounted. + if !fd.mounted { + return 0, syserror.EPERM + } + return 0, syserror.ENOSYS } diff --git a/pkg/sentry/fsimpl/fuse/fusefs.go b/pkg/sentry/fsimpl/fuse/fusefs.go new file mode 100644 index 000000000..f7775fb9b --- /dev/null +++ b/pkg/sentry/fsimpl/fuse/fusefs.go @@ -0,0 +1,200 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package fuse implements fusefs. +package fuse + +import ( + "strconv" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" +) + +// Name is the default filesystem name. +const Name = "fuse" + +// FilesystemType implements vfs.FilesystemType. +type FilesystemType struct{} + +type filesystemOptions struct { + // userID specifies the numeric uid of the mount owner. + // This option should not be specified by the filesystem owner. + // It is set by libfuse (or, if libfuse is not used, must be set + // by the filesystem itself). For more information, see man page + // for fuse(8) + userID uint32 + + // groupID specifies the numeric gid of the mount owner. + // This option should not be specified by the filesystem owner. + // It is set by libfuse (or, if libfuse is not used, must be set + // by the filesystem itself). For more information, see man page + // for fuse(8) + groupID uint32 + + // rootMode specifies the the file mode of the filesystem's root. + rootMode linux.FileMode +} + +// filesystem implements vfs.FilesystemImpl. +type filesystem struct { + kernfs.Filesystem + devMinor uint32 + + // fuseFD is the FD returned when opening /dev/fuse. It is used for communication + // between the FUSE server daemon and the sentry fusefs. + fuseFD *DeviceFD + + // opts is the options the fusefs is initialized with. + opts filesystemOptions +} + +// Name implements vfs.FilesystemType.Name. +func (FilesystemType) Name() string { + return Name +} + +// GetFilesystem implements vfs.FilesystemType.GetFilesystem. +func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) { + devMinor, err := vfsObj.GetAnonBlockDevMinor() + if err != nil { + return nil, nil, err + } + + var fsopts filesystemOptions + mopts := vfs.GenericParseMountOptions(opts.Data) + deviceDescriptorStr, ok := mopts["fd"] + if !ok { + log.Warningf("%s.GetFilesystem: communication file descriptor N (obtained by opening /dev/fuse) must be specified as 'fd=N'", fsType.Name()) + return nil, nil, syserror.EINVAL + } + delete(mopts, "fd") + + deviceDescriptor, err := strconv.ParseInt(deviceDescriptorStr, 10 /* base */, 32 /* bitSize */) + if err != nil { + return nil, nil, err + } + + kernelTask := kernel.TaskFromContext(ctx) + if kernelTask == nil { + log.Warningf("%s.GetFilesystem: couldn't get kernel task from context", fsType.Name()) + return nil, nil, syserror.EINVAL + } + fuseFd := kernelTask.GetFileVFS2(int32(deviceDescriptor)) + + // Parse and set all the other supported FUSE mount options. + // TODO: Expand the supported mount options. + if userIDStr, ok := mopts["user_id"]; ok { + delete(mopts, "user_id") + userID, err := strconv.ParseUint(userIDStr, 10, 32) + if err != nil { + log.Warningf("%s.GetFilesystem: invalid user_id: user_id=%s", fsType.Name(), userIDStr) + return nil, nil, syserror.EINVAL + } + fsopts.userID = uint32(userID) + } + + if groupIDStr, ok := mopts["group_id"]; ok { + delete(mopts, "group_id") + groupID, err := strconv.ParseUint(groupIDStr, 10, 32) + if err != nil { + log.Warningf("%s.GetFilesystem: invalid group_id: group_id=%s", fsType.Name(), groupIDStr) + return nil, nil, syserror.EINVAL + } + fsopts.groupID = uint32(groupID) + } + + rootMode := linux.FileMode(0777) + modeStr, ok := mopts["rootmode"] + if ok { + delete(mopts, "rootmode") + mode, err := strconv.ParseUint(modeStr, 8, 32) + if err != nil { + log.Warningf("%s.GetFilesystem: invalid mode: %q", fsType.Name(), modeStr) + return nil, nil, syserror.EINVAL + } + rootMode = linux.FileMode(mode) + } + fsopts.rootMode = rootMode + + // Check for unparsed options. + if len(mopts) != 0 { + log.Warningf("%s.GetFilesystem: unknown options: %v", fsType.Name(), mopts) + return nil, nil, syserror.EINVAL + } + + // Mark the device as ready so it can be used. /dev/fuse can only be used if the FD was used to + // mount a FUSE filesystem. + fuseFD := fuseFd.Impl().(*DeviceFD) + fuseFD.mounted = true + + fs := &filesystem{ + devMinor: devMinor, + fuseFD: fuseFD, + opts: fsopts, + } + + fs.VFSFilesystem().Init(vfsObj, &fsType, fs) + + // TODO: dispatch a FUSE_INIT request to the FUSE daemon server before + // returning. Mount will not block on this dispatched request. + + // root is the fusefs root directory. + root := fs.newInode(creds, fsopts.rootMode) + + return fs.VFSFilesystem(), root.VFSDentry(), nil +} + +// Release implements vfs.FilesystemImpl.Release. +func (fs *filesystem) Release() { + fs.Filesystem.VFSFilesystem().VirtualFilesystem().PutAnonBlockDevMinor(fs.devMinor) + fs.Filesystem.Release() +} + +// Inode implements kernfs.Inode. +type Inode struct { + kernfs.InodeAttrs + kernfs.InodeNoDynamicLookup + kernfs.InodeNotSymlink + kernfs.InodeDirectoryNoNewChildren + kernfs.OrderedChildren + + locks vfs.FileLocks + + dentry kernfs.Dentry +} + +func (fs *filesystem) newInode(creds *auth.Credentials, mode linux.FileMode) *kernfs.Dentry { + i := &Inode{} + i.InodeAttrs.Init(creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0755) + i.OrderedChildren.Init(kernfs.OrderedChildrenOptions{}) + i.dentry.Init(i) + + return &i.dentry +} + +// Open implements kernfs.Inode.Open. +func (i *Inode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { + fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &i.OrderedChildren, &i.locks, &opts) + if err != nil { + return nil, err + } + return fd.VFSFileDescription(), nil +} diff --git a/pkg/sentry/fsimpl/gofer/directory.go b/pkg/sentry/fsimpl/gofer/directory.go index 5d83fe363..8c7c8e1b3 100644 --- a/pkg/sentry/fsimpl/gofer/directory.go +++ b/pkg/sentry/fsimpl/gofer/directory.go @@ -85,6 +85,7 @@ func (d *dentry) createSyntheticChildLocked(opts *createSyntheticOpts) { d2 := &dentry{ refs: 1, // held by d fs: d.fs, + ino: d.fs.nextSyntheticIno(), mode: uint32(opts.mode), uid: uint32(opts.kuid), gid: uint32(opts.kgid), @@ -184,13 +185,13 @@ func (d *dentry) getDirents(ctx context.Context) ([]vfs.Dirent, error) { { Name: ".", Type: linux.DT_DIR, - Ino: d.ino, + Ino: uint64(d.ino), NextOff: 1, }, { Name: "..", Type: uint8(atomic.LoadUint32(&parent.mode) >> 12), - Ino: parent.ino, + Ino: uint64(parent.ino), NextOff: 2, }, } @@ -226,7 +227,7 @@ func (d *dentry) getDirents(ctx context.Context) ([]vfs.Dirent, error) { } dirent := vfs.Dirent{ Name: p9d.Name, - Ino: p9d.QID.Path, + Ino: uint64(inoFromPath(p9d.QID.Path)), NextOff: int64(len(dirents) + 1), } // p9 does not expose 9P2000.U's DMDEVICE, DMNAMEDPIPE, or @@ -259,7 +260,7 @@ func (d *dentry) getDirents(ctx context.Context) ([]vfs.Dirent, error) { dirents = append(dirents, vfs.Dirent{ Name: child.name, Type: uint8(atomic.LoadUint32(&child.mode) >> 12), - Ino: child.ino, + Ino: uint64(child.ino), NextOff: int64(len(dirents) + 1), }) } diff --git a/pkg/sentry/fsimpl/gofer/filesystem.go b/pkg/sentry/fsimpl/gofer/filesystem.go index 7bcc99b29..00e3c99cd 100644 --- a/pkg/sentry/fsimpl/gofer/filesystem.go +++ b/pkg/sentry/fsimpl/gofer/filesystem.go @@ -150,11 +150,9 @@ afterSymlink: return nil, err } if d != d.parent && !d.cachedMetadataAuthoritative() { - _, attrMask, attr, err := d.parent.file.getAttr(ctx, dentryAttrMask()) - if err != nil { + if err := d.parent.updateFromGetattr(ctx); err != nil { return nil, err } - d.parent.updateFromP9Attrs(attrMask, &attr) } rp.Advance() return d.parent, nil @@ -209,18 +207,28 @@ func (fs *filesystem) getChildLocked(ctx context.Context, vfsObj *vfs.VirtualFil // Preconditions: As for getChildLocked. !parent.isSynthetic(). func (fs *filesystem) revalidateChildLocked(ctx context.Context, vfsObj *vfs.VirtualFilesystem, parent *dentry, name string, child *dentry, ds **[]*dentry) (*dentry, error) { + if child != nil { + // Need to lock child.metadataMu because we might be updating child + // metadata. We need to hold the lock *before* getting metadata from the + // server and release it after updating local metadata. + child.metadataMu.Lock() + } qid, file, attrMask, attr, err := parent.file.walkGetAttrOne(ctx, name) if err != nil && err != syserror.ENOENT { + if child != nil { + child.metadataMu.Unlock() + } return nil, err } if child != nil { - if !file.isNil() && qid.Path == child.ino { - // The file at this path hasn't changed. Just update cached - // metadata. + if !file.isNil() && inoFromPath(qid.Path) == child.ino { + // The file at this path hasn't changed. Just update cached metadata. file.close(ctx) - child.updateFromP9Attrs(attrMask, &attr) + child.updateFromP9AttrsLocked(attrMask, &attr) + child.metadataMu.Unlock() return child, nil } + child.metadataMu.Unlock() if file.isNil() && child.isSynthetic() { // We have a synthetic file, and no remote file has arisen to // replace it. @@ -1326,7 +1334,7 @@ func (fs *filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts fs.renameMuRUnlockAndCheckCaching(&ds) return err } - if err := d.setStat(ctx, rp.Credentials(), &opts.Stat, rp.Mount()); err != nil { + if err := d.setStat(ctx, rp.Credentials(), &opts, rp.Mount()); err != nil { fs.renameMuRUnlockAndCheckCaching(&ds) return err } @@ -1499,3 +1507,7 @@ func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDe defer fs.renameMu.RUnlock() return genericPrependPath(vfsroot, vd.Mount(), vd.Dentry().Impl().(*dentry), b) } + +func (fs *filesystem) nextSyntheticIno() inodeNumber { + return inodeNumber(atomic.AddUint64(&fs.syntheticSeq, 1) | syntheticInoMask) +} diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go index 8e74e60a5..e20de84b5 100644 --- a/pkg/sentry/fsimpl/gofer/gofer.go +++ b/pkg/sentry/fsimpl/gofer/gofer.go @@ -110,6 +110,26 @@ type filesystem struct { syncMu sync.Mutex syncableDentries map[*dentry]struct{} specialFileFDs map[*specialFileFD]struct{} + + // syntheticSeq stores a counter to used to generate unique inodeNumber for + // synthetic dentries. + syntheticSeq uint64 +} + +// inodeNumber represents inode number reported in Dirent.Ino. For regular +// dentries, it comes from QID.Path from the 9P server. Synthetic dentries +// have have their inodeNumber generated sequentially, with the MSB reserved to +// prevent conflicts with regular dentries. +type inodeNumber uint64 + +// Reserve MSB for synthetic mounts. +const syntheticInoMask = uint64(1) << 63 + +func inoFromPath(path uint64) inodeNumber { + if path&syntheticInoMask != 0 { + log.Warningf("Dropping MSB from ino, collision is possible. Original: %d, new: %d", path, path&^syntheticInoMask) + } + return inodeNumber(path &^ syntheticInoMask) } type filesystemOptions struct { @@ -582,21 +602,27 @@ type dentry struct { // returned by the server. dirents is protected by dirMu. dirents []vfs.Dirent - // Cached metadata; protected by metadataMu and accessed using atomic - // memory operations unless otherwise specified. + // Cached metadata; protected by metadataMu. + // To access: + // - In situations where consistency is not required (like stat), these + // can be accessed using atomic operations only (without locking). + // - Lock metadataMu and can access without atomic operations. + // To mutate: + // - Lock metadataMu and use atomic operations to update because we might + // have atomic readers that don't hold the lock. metadataMu sync.Mutex - ino uint64 // immutable - mode uint32 // type is immutable, perms are mutable - uid uint32 // auth.KUID, but stored as raw uint32 for sync/atomic - gid uint32 // auth.KGID, but ... - blockSize uint32 // 0 if unknown + ino inodeNumber // immutable + mode uint32 // type is immutable, perms are mutable + uid uint32 // auth.KUID, but stored as raw uint32 for sync/atomic + gid uint32 // auth.KGID, but ... + blockSize uint32 // 0 if unknown // Timestamps, all nsecs from the Unix epoch. atime int64 mtime int64 ctime int64 btime int64 // File size, protected by both metadataMu and dataMu (i.e. both must be - // locked to mutate it). + // locked to mutate it; locking either is sufficient to access it). size uint64 // nlink counts the number of hard links to this dentry. It's updated and @@ -704,7 +730,7 @@ func (fs *filesystem) newDentry(ctx context.Context, file p9file, qid p9.QID, ma d := &dentry{ fs: fs, file: file, - ino: qid.Path, + ino: inoFromPath(qid.Path), mode: uint32(attr.Mode), uid: uint32(fs.opts.dfltuid), gid: uint32(fs.opts.dfltgid), @@ -759,8 +785,8 @@ func (d *dentry) cachedMetadataAuthoritative() bool { // updateFromP9Attrs is called to update d's metadata after an update from the // remote filesystem. -func (d *dentry) updateFromP9Attrs(mask p9.AttrMask, attr *p9.Attr) { - d.metadataMu.Lock() +// Precondition: d.metadataMu must be locked. +func (d *dentry) updateFromP9AttrsLocked(mask p9.AttrMask, attr *p9.Attr) { if mask.Mode { if got, want := uint32(attr.Mode.FileType()), d.fileType(); got != want { d.metadataMu.Unlock() @@ -796,7 +822,6 @@ func (d *dentry) updateFromP9Attrs(mask p9.AttrMask, attr *p9.Attr) { if mask.Size { d.updateFileSizeLocked(attr.Size) } - d.metadataMu.Unlock() } // Preconditions: !d.isSynthetic() @@ -808,6 +833,10 @@ func (d *dentry) updateFromGetattr(ctx context.Context) error { file p9file handleMuRLocked bool ) + // d.metadataMu must be locked *before* we getAttr so that we do not end up + // updating stale attributes in d.updateFromP9AttrsLocked(). + d.metadataMu.Lock() + defer d.metadataMu.Unlock() d.handleMu.RLock() if !d.handle.file.isNil() { file = d.handle.file @@ -823,7 +852,7 @@ func (d *dentry) updateFromGetattr(ctx context.Context) error { if err != nil { return err } - d.updateFromP9Attrs(attrMask, &attr) + d.updateFromP9AttrsLocked(attrMask, &attr) return nil } @@ -846,7 +875,7 @@ func (d *dentry) statTo(stat *linux.Statx) { stat.UID = atomic.LoadUint32(&d.uid) stat.GID = atomic.LoadUint32(&d.gid) stat.Mode = uint16(atomic.LoadUint32(&d.mode)) - stat.Ino = d.ino + stat.Ino = uint64(d.ino) stat.Size = atomic.LoadUint64(&d.size) // This is consistent with regularFileFD.Seek(), which treats regular files // as having no holes. @@ -859,7 +888,8 @@ func (d *dentry) statTo(stat *linux.Statx) { stat.DevMinor = d.fs.devMinor } -func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, stat *linux.Statx, mnt *vfs.Mount) error { +func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, opts *vfs.SetStatOptions, mnt *vfs.Mount) error { + stat := &opts.Stat if stat.Mask == 0 { return nil } @@ -867,7 +897,7 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, stat *lin return syserror.EPERM } mode := linux.FileMode(atomic.LoadUint32(&d.mode)) - if err := vfs.CheckSetStat(ctx, creds, stat, mode, auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid))); err != nil { + if err := vfs.CheckSetStat(ctx, creds, opts, mode, auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid))); err != nil { return err } if err := mnt.CheckBeginWrite(); err != nil { @@ -884,14 +914,14 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, stat *lin // Prepare for truncate. if stat.Mask&linux.STATX_SIZE != 0 { - switch d.mode & linux.S_IFMT { - case linux.S_IFREG: + switch mode.FileType() { + case linux.ModeRegular: if !setLocalMtime { // Truncate updates mtime. setLocalMtime = true stat.Mtime.Nsec = linux.UTIME_NOW } - case linux.S_IFDIR: + case linux.ModeDirectory: return syserror.EISDIR default: return syserror.EINVAL @@ -908,6 +938,17 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, stat *lin } if !d.isSynthetic() { if stat.Mask != 0 { + if stat.Mask&linux.STATX_SIZE != 0 { + // Check whether to allow a truncate request to be made. + switch d.mode & linux.S_IFMT { + case linux.S_IFREG: + // Allow. + case linux.S_IFDIR: + return syserror.EISDIR + default: + return syserror.EINVAL + } + } if err := d.file.setAttr(ctx, p9.SetAttrMask{ Permissions: stat.Mask&linux.STATX_MODE != 0, UID: stat.Mask&linux.STATX_UID != 0, @@ -974,7 +1015,7 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, stat *lin func (d *dentry) updateFileSizeLocked(newSize uint64) { d.dataMu.Lock() oldSize := d.size - d.size = newSize + atomic.StoreUint64(&d.size, newSize) // d.dataMu must be unlocked to lock d.mapsMu and invalidate mappings // below. This allows concurrent calls to Read/Translate/etc. These // functions synchronize with truncation by refusing to use cache @@ -1320,8 +1361,8 @@ func (d *dentry) removexattr(ctx context.Context, creds *auth.Credentials, name // Extended attributes in the user.* namespace are only supported for regular // files and directories. func (d *dentry) userXattrSupported() bool { - filetype := linux.S_IFMT & atomic.LoadUint32(&d.mode) - return filetype == linux.S_IFREG || filetype == linux.S_IFDIR + filetype := linux.FileMode(atomic.LoadUint32(&d.mode)).FileType() + return filetype == linux.ModeRegular || filetype == linux.ModeDirectory } // Preconditions: !d.isSynthetic(). d.isRegularFile() || d.isDir(). @@ -1469,7 +1510,7 @@ func (fd *fileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linu // SetStat implements vfs.FileDescriptionImpl.SetStat. func (fd *fileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions) error { - if err := fd.dentry().setStat(ctx, auth.CredentialsFromContext(ctx), &opts.Stat, fd.vfsfd.Mount()); err != nil { + if err := fd.dentry().setStat(ctx, auth.CredentialsFromContext(ctx), &opts, fd.vfsfd.Mount()); err != nil { return err } if ev := vfs.InotifyEventFromStatMask(opts.Stat.Mask); ev != 0 { diff --git a/pkg/sentry/fsimpl/gofer/handle.go b/pkg/sentry/fsimpl/gofer/handle.go index 724a3f1f7..8792ca4f2 100644 --- a/pkg/sentry/fsimpl/gofer/handle.go +++ b/pkg/sentry/fsimpl/gofer/handle.go @@ -126,11 +126,16 @@ func (h *handle) writeFromBlocksAt(ctx context.Context, srcs safemem.BlockSeq, o } func (h *handle) sync(ctx context.Context) error { + // Handle most common case first. if h.fd >= 0 { ctx.UninterruptibleSleepStart(false) err := syscall.Fsync(int(h.fd)) ctx.UninterruptibleSleepFinish(false) return err } + if h.file.isNil() { + // File hasn't been touched, there is nothing to sync. + return nil + } return h.file.fsync(ctx) } diff --git a/pkg/sentry/fsimpl/gofer/regular_file.go b/pkg/sentry/fsimpl/gofer/regular_file.go index 3d2d3530a..02317a133 100644 --- a/pkg/sentry/fsimpl/gofer/regular_file.go +++ b/pkg/sentry/fsimpl/gofer/regular_file.go @@ -89,7 +89,9 @@ func (fd *regularFileFD) Allocate(ctx context.Context, mode, offset, length uint if err != nil { return err } - d.size = size + d.dataMu.Lock() + atomic.StoreUint64(&d.size, size) + d.dataMu.Unlock() if !d.cachedMetadataAuthoritative() { d.touchCMtimeLocked() } @@ -153,26 +155,53 @@ func (fd *regularFileFD) Read(ctx context.Context, dst usermem.IOSequence, opts // PWrite implements vfs.FileDescriptionImpl.PWrite. func (fd *regularFileFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) { + n, _, err := fd.pwrite(ctx, src, offset, opts) + return n, err +} + +// pwrite returns the number of bytes written, final offset, error. The final +// offset should be ignored by PWrite. +func (fd *regularFileFD) pwrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (written, finalOff int64, err error) { if offset < 0 { - return 0, syserror.EINVAL + return 0, offset, syserror.EINVAL } // Check that flags are supported. // // TODO(gvisor.dev/issue/2601): Support select pwritev2 flags. if opts.Flags&^linux.RWF_HIPRI != 0 { - return 0, syserror.EOPNOTSUPP + return 0, offset, syserror.EOPNOTSUPP } + d := fd.dentry() + // If the fd was opened with O_APPEND, make sure the file size is updated. + // There is a possible race here if size is modified externally after + // metadata cache is updated. + if fd.vfsfd.StatusFlags()&linux.O_APPEND != 0 && !d.cachedMetadataAuthoritative() { + if err := d.updateFromGetattr(ctx); err != nil { + return 0, offset, err + } + } + + d.metadataMu.Lock() + defer d.metadataMu.Unlock() + // Set offset to file size if the fd was opened with O_APPEND. + if fd.vfsfd.StatusFlags()&linux.O_APPEND != 0 { + // Holding d.metadataMu is sufficient for reading d.size. + offset = int64(d.size) + } limit, err := vfs.CheckLimit(ctx, offset, src.NumBytes()) if err != nil { - return 0, err + return 0, offset, err } src = src.TakeFirst64(limit) + n, err := fd.pwriteLocked(ctx, src, offset, opts) + return n, offset + n, err +} +// Preconditions: fd.dentry().metatdataMu must be locked. +func (fd *regularFileFD) pwriteLocked(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) { d := fd.dentry() - d.metadataMu.Lock() - defer d.metadataMu.Unlock() if d.fs.opts.interop != InteropModeShared { // Compare Linux's mm/filemap.c:__generic_file_write_iter() => // file_update_time(). This is d.touchCMtime(), but without locking @@ -235,8 +264,8 @@ func (fd *regularFileFD) PWrite(ctx context.Context, src usermem.IOSequence, off // Write implements vfs.FileDescriptionImpl.Write. func (fd *regularFileFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) { fd.mu.Lock() - n, err := fd.PWrite(ctx, src, fd.off, opts) - fd.off += n + n, off, err := fd.pwrite(ctx, src, fd.off, opts) + fd.off = off fd.mu.Unlock() return n, err } @@ -582,20 +611,19 @@ func (fd *regularFileFD) Sync(ctx context.Context) error { func (d *dentry) syncSharedHandle(ctx context.Context) error { d.handleMu.RLock() - if !d.handleWritable { - d.handleMu.RUnlock() - return nil - } - d.dataMu.Lock() - // Write dirty cached data to the remote file. - err := fsutil.SyncDirtyAll(ctx, &d.cache, &d.dirty, d.size, d.fs.mfp.MemoryFile(), d.handle.writeFromBlocksAt) - d.dataMu.Unlock() - if err == nil { - // Sync the remote file. - err = d.handle.sync(ctx) + defer d.handleMu.RUnlock() + + if d.handleWritable { + d.dataMu.Lock() + // Write dirty cached data to the remote file. + err := fsutil.SyncDirtyAll(ctx, &d.cache, &d.dirty, d.size, d.fs.mfp.MemoryFile(), d.handle.writeFromBlocksAt) + d.dataMu.Unlock() + if err != nil { + return err + } } - d.handleMu.RUnlock() - return err + // Sync the remote file. + return d.handle.sync(ctx) } // ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap. diff --git a/pkg/sentry/fsimpl/gofer/special_file.go b/pkg/sentry/fsimpl/gofer/special_file.go index 3c4e7e2e4..811528982 100644 --- a/pkg/sentry/fsimpl/gofer/special_file.go +++ b/pkg/sentry/fsimpl/gofer/special_file.go @@ -16,6 +16,7 @@ package gofer import ( "sync" + "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" @@ -28,9 +29,9 @@ import ( ) // specialFileFD implements vfs.FileDescriptionImpl for pipes, sockets, device -// special files, and (when filesystemOptions.specialRegularFiles is in effect) -// regular files. specialFileFD differs from regularFileFD by using per-FD -// handles instead of shared per-dentry handles, and never buffering I/O. +// special files, and (when filesystemOptions.regularFilesUseSpecialFileFD is +// in effect) regular files. specialFileFD differs from regularFileFD by using +// per-FD handles instead of shared per-dentry handles, and never buffering I/O. type specialFileFD struct { fileDescription @@ -41,10 +42,10 @@ type specialFileFD struct { // file offset is significant, i.e. a regular file. seekable is immutable. seekable bool - // mayBlock is true if this file description represents a file for which - // queue may send I/O readiness events. mayBlock is immutable. - mayBlock bool - queue waiter.Queue + // haveQueue is true if this file description represents a file for which + // queue may send I/O readiness events. haveQueue is immutable. + haveQueue bool + queue waiter.Queue // If seekable is true, off is the file offset. off is protected by mu. mu sync.Mutex @@ -54,14 +55,14 @@ type specialFileFD struct { func newSpecialFileFD(h handle, mnt *vfs.Mount, d *dentry, locks *vfs.FileLocks, flags uint32) (*specialFileFD, error) { ftype := d.fileType() seekable := ftype == linux.S_IFREG - mayBlock := ftype == linux.S_IFIFO || ftype == linux.S_IFSOCK + haveQueue := (ftype == linux.S_IFIFO || ftype == linux.S_IFSOCK) && h.fd >= 0 fd := &specialFileFD{ - handle: h, - seekable: seekable, - mayBlock: mayBlock, + handle: h, + seekable: seekable, + haveQueue: haveQueue, } fd.LockFD.Init(locks) - if mayBlock && h.fd >= 0 { + if haveQueue { if err := fdnotifier.AddFD(h.fd, &fd.queue); err != nil { return nil, err } @@ -70,7 +71,7 @@ func newSpecialFileFD(h handle, mnt *vfs.Mount, d *dentry, locks *vfs.FileLocks, DenyPRead: !seekable, DenyPWrite: !seekable, }); err != nil { - if mayBlock && h.fd >= 0 { + if haveQueue { fdnotifier.RemoveFD(h.fd) } return nil, err @@ -80,7 +81,7 @@ func newSpecialFileFD(h handle, mnt *vfs.Mount, d *dentry, locks *vfs.FileLocks, // Release implements vfs.FileDescriptionImpl.Release. func (fd *specialFileFD) Release() { - if fd.mayBlock && fd.handle.fd >= 0 { + if fd.haveQueue { fdnotifier.RemoveFD(fd.handle.fd) } fd.handle.close(context.Background()) @@ -100,7 +101,7 @@ func (fd *specialFileFD) OnClose(ctx context.Context) error { // Readiness implements waiter.Waitable.Readiness. func (fd *specialFileFD) Readiness(mask waiter.EventMask) waiter.EventMask { - if fd.mayBlock { + if fd.haveQueue { return fdnotifier.NonBlockingPoll(fd.handle.fd, mask) } return fd.fileDescription.Readiness(mask) @@ -108,8 +109,9 @@ func (fd *specialFileFD) Readiness(mask waiter.EventMask) waiter.EventMask { // EventRegister implements waiter.Waitable.EventRegister. func (fd *specialFileFD) EventRegister(e *waiter.Entry, mask waiter.EventMask) { - if fd.mayBlock { + if fd.haveQueue { fd.queue.EventRegister(e, mask) + fdnotifier.UpdateFD(fd.handle.fd) return } fd.fileDescription.EventRegister(e, mask) @@ -117,8 +119,9 @@ func (fd *specialFileFD) EventRegister(e *waiter.Entry, mask waiter.EventMask) { // EventUnregister implements waiter.Waitable.EventUnregister. func (fd *specialFileFD) EventUnregister(e *waiter.Entry) { - if fd.mayBlock { + if fd.haveQueue { fd.queue.EventUnregister(e) + fdnotifier.UpdateFD(fd.handle.fd) return } fd.fileDescription.EventUnregister(e) @@ -142,7 +145,7 @@ func (fd *specialFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offs // mmap due to lock ordering; MM locks precede dentry.dataMu. That doesn't // hold here since specialFileFD doesn't client-cache data. Just buffer the // read instead. - if d := fd.dentry(); d.fs.opts.interop != InteropModeShared { + if d := fd.dentry(); d.cachedMetadataAuthoritative() { d.touchAtime(fd.vfsfd.Mount()) } buf := make([]byte, dst.NumBytes()) @@ -174,39 +177,76 @@ func (fd *specialFileFD) Read(ctx context.Context, dst usermem.IOSequence, opts // PWrite implements vfs.FileDescriptionImpl.PWrite. func (fd *specialFileFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) { + n, _, err := fd.pwrite(ctx, src, offset, opts) + return n, err +} + +// pwrite returns the number of bytes written, final offset, error. The final +// offset should be ignored by PWrite. +func (fd *specialFileFD) pwrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (written, finalOff int64, err error) { if fd.seekable && offset < 0 { - return 0, syserror.EINVAL + return 0, offset, syserror.EINVAL } // Check that flags are supported. // // TODO(gvisor.dev/issue/2601): Support select pwritev2 flags. if opts.Flags&^linux.RWF_HIPRI != 0 { - return 0, syserror.EOPNOTSUPP + return 0, offset, syserror.EOPNOTSUPP + } + + d := fd.dentry() + // If the regular file fd was opened with O_APPEND, make sure the file size + // is updated. There is a possible race here if size is modified externally + // after metadata cache is updated. + if fd.seekable && fd.vfsfd.StatusFlags()&linux.O_APPEND != 0 && !d.cachedMetadataAuthoritative() { + if err := d.updateFromGetattr(ctx); err != nil { + return 0, offset, err + } } if fd.seekable { + // We need to hold the metadataMu *while* writing to a regular file. + d.metadataMu.Lock() + defer d.metadataMu.Unlock() + + // Set offset to file size if the regular file was opened with O_APPEND. + if fd.vfsfd.StatusFlags()&linux.O_APPEND != 0 { + // Holding d.metadataMu is sufficient for reading d.size. + offset = int64(d.size) + } limit, err := vfs.CheckLimit(ctx, offset, src.NumBytes()) if err != nil { - return 0, err + return 0, offset, err } src = src.TakeFirst64(limit) } // Do a buffered write. See rationale in PRead. - if d := fd.dentry(); d.fs.opts.interop != InteropModeShared { + if d.cachedMetadataAuthoritative() { d.touchCMtime() } buf := make([]byte, src.NumBytes()) // Don't do partial writes if we get a partial read from src. if _, err := src.CopyIn(ctx, buf); err != nil { - return 0, err + return 0, offset, err } n, err := fd.handle.writeFromBlocksAt(ctx, safemem.BlockSeqOf(safemem.BlockFromSafeSlice(buf)), uint64(offset)) if err == syserror.EAGAIN { err = syserror.ErrWouldBlock } - return int64(n), err + finalOff = offset + // Update file size for regular files. + if fd.seekable { + finalOff += int64(n) + // d.metadataMu is already locked at this point. + if uint64(finalOff) > d.size { + d.dataMu.Lock() + defer d.dataMu.Unlock() + atomic.StoreUint64(&d.size, uint64(finalOff)) + } + } + return int64(n), finalOff, err } // Write implements vfs.FileDescriptionImpl.Write. @@ -216,8 +256,8 @@ func (fd *specialFileFD) Write(ctx context.Context, src usermem.IOSequence, opts } fd.mu.Lock() - n, err := fd.PWrite(ctx, src, fd.off, opts) - fd.off += n + n, off, err := fd.pwrite(ctx, src, fd.off, opts) + fd.off = off fd.mu.Unlock() return n, err } diff --git a/pkg/sentry/fsimpl/host/BUILD b/pkg/sentry/fsimpl/host/BUILD index 44a09d87a..e86fbe2d5 100644 --- a/pkg/sentry/fsimpl/host/BUILD +++ b/pkg/sentry/fsimpl/host/BUILD @@ -22,6 +22,7 @@ go_library( "//pkg/context", "//pkg/fdnotifier", "//pkg/fspath", + "//pkg/iovec", "//pkg/log", "//pkg/refs", "//pkg/safemem", diff --git a/pkg/sentry/fsimpl/host/host.go b/pkg/sentry/fsimpl/host/host.go index 1cd2982cb..c894f2ca0 100644 --- a/pkg/sentry/fsimpl/host/host.go +++ b/pkg/sentry/fsimpl/host/host.go @@ -259,7 +259,7 @@ func (i *inode) Mode() linux.FileMode { } // Stat implements kernfs.Inode. -func (i *inode) Stat(vfsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) { +func (i *inode) Stat(ctx context.Context, vfsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) { if opts.Mask&linux.STATX__RESERVED != 0 { return linux.Statx{}, syserror.EINVAL } @@ -373,7 +373,7 @@ func (i *inode) fstat(fs *filesystem) (linux.Statx, error) { // SetStat implements kernfs.Inode. func (i *inode) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Credentials, opts vfs.SetStatOptions) error { - s := opts.Stat + s := &opts.Stat m := s.Mask if m == 0 { @@ -386,7 +386,7 @@ func (i *inode) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Cre if err := syscall.Fstat(i.hostFD, &hostStat); err != nil { return err } - if err := vfs.CheckSetStat(ctx, creds, &s, linux.FileMode(hostStat.Mode&linux.PermissionsMask), auth.KUID(hostStat.Uid), auth.KGID(hostStat.Gid)); err != nil { + if err := vfs.CheckSetStat(ctx, creds, &opts, linux.FileMode(hostStat.Mode), auth.KUID(hostStat.Uid), auth.KGID(hostStat.Gid)); err != nil { return err } @@ -396,6 +396,9 @@ func (i *inode) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Cre } } if m&linux.STATX_SIZE != 0 { + if hostStat.Mode&linux.S_IFMT != linux.S_IFREG { + return syserror.EINVAL + } if err := syscall.Ftruncate(i.hostFD, int64(s.Size)); err != nil { return err } @@ -534,8 +537,8 @@ func (f *fileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions) } // Stat implements vfs.FileDescriptionImpl. -func (f *fileDescription) Stat(_ context.Context, opts vfs.StatOptions) (linux.Statx, error) { - return f.inode.Stat(f.vfsfd.Mount().Filesystem(), opts) +func (f *fileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) { + return f.inode.Stat(ctx, f.vfsfd.Mount().Filesystem(), opts) } // Release implements vfs.FileDescriptionImpl. diff --git a/pkg/sentry/fsimpl/host/socket_iovec.go b/pkg/sentry/fsimpl/host/socket_iovec.go index 584c247d2..fc0d5fd38 100644 --- a/pkg/sentry/fsimpl/host/socket_iovec.go +++ b/pkg/sentry/fsimpl/host/socket_iovec.go @@ -17,13 +17,10 @@ package host import ( "syscall" - "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/iovec" "gvisor.dev/gvisor/pkg/syserror" ) -// maxIovs is the maximum number of iovecs to pass to the host. -var maxIovs = linux.UIO_MAXIOV - // copyToMulti copies as many bytes from src to dst as possible. func copyToMulti(dst [][]byte, src []byte) { for _, d := range dst { @@ -74,7 +71,7 @@ func buildIovec(bufs [][]byte, maxlen int64, truncate bool) (length int64, iovec } } - if iovsRequired > maxIovs { + if iovsRequired > iovec.MaxIovs { // The kernel will reject our call if we pass this many iovs. // Use a single intermediate buffer instead. b := make([]byte, stopLen) diff --git a/pkg/sentry/fsimpl/kernfs/BUILD b/pkg/sentry/fsimpl/kernfs/BUILD index 179df6c1e..3835557fe 100644 --- a/pkg/sentry/fsimpl/kernfs/BUILD +++ b/pkg/sentry/fsimpl/kernfs/BUILD @@ -70,6 +70,6 @@ go_test( "//pkg/sentry/vfs", "//pkg/syserror", "//pkg/usermem", - "@com_github_google_go-cmp//cmp:go_default_library", + "@com_github_google_go_cmp//cmp:go_default_library", ], ) diff --git a/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go b/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go index 6886b0876..c6c4472e7 100644 --- a/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go +++ b/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go @@ -127,7 +127,7 @@ func (fd *DynamicBytesFD) Release() {} // Stat implements vfs.FileDescriptionImpl.Stat. func (fd *DynamicBytesFD) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) { fs := fd.vfsfd.VirtualDentry().Mount().Filesystem() - return fd.inode.Stat(fs, opts) + return fd.inode.Stat(ctx, fs, opts) } // SetStat implements vfs.FileDescriptionImpl.SetStat. diff --git a/pkg/sentry/fsimpl/kernfs/fd_impl_util.go b/pkg/sentry/fsimpl/kernfs/fd_impl_util.go index ca8b8c63b..1d37ccb98 100644 --- a/pkg/sentry/fsimpl/kernfs/fd_impl_util.go +++ b/pkg/sentry/fsimpl/kernfs/fd_impl_util.go @@ -112,7 +112,7 @@ func (fd *GenericDirectoryFD) PWrite(ctx context.Context, src usermem.IOSequence return fd.DirectoryFileDescriptionDefaultImpl.PWrite(ctx, src, offset, opts) } -// Release implements vfs.FileDecriptionImpl.Release. +// Release implements vfs.FileDescriptionImpl.Release. func (fd *GenericDirectoryFD) Release() {} func (fd *GenericDirectoryFD) filesystem() *vfs.Filesystem { @@ -123,7 +123,7 @@ func (fd *GenericDirectoryFD) inode() Inode { return fd.vfsfd.VirtualDentry().Dentry().Impl().(*Dentry).inode } -// IterDirents implements vfs.FileDecriptionImpl.IterDirents. IterDirents holds +// IterDirents implements vfs.FileDescriptionImpl.IterDirents. IterDirents holds // o.mu when calling cb. func (fd *GenericDirectoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback) error { fd.mu.Lock() @@ -132,7 +132,7 @@ func (fd *GenericDirectoryFD) IterDirents(ctx context.Context, cb vfs.IterDirent opts := vfs.StatOptions{Mask: linux.STATX_INO} // Handle ".". if fd.off == 0 { - stat, err := fd.inode().Stat(fd.filesystem(), opts) + stat, err := fd.inode().Stat(ctx, fd.filesystem(), opts) if err != nil { return err } @@ -152,7 +152,7 @@ func (fd *GenericDirectoryFD) IterDirents(ctx context.Context, cb vfs.IterDirent if fd.off == 1 { vfsd := fd.vfsfd.VirtualDentry().Dentry() parentInode := genericParentOrSelf(vfsd.Impl().(*Dentry)).inode - stat, err := parentInode.Stat(fd.filesystem(), opts) + stat, err := parentInode.Stat(ctx, fd.filesystem(), opts) if err != nil { return err } @@ -176,7 +176,7 @@ func (fd *GenericDirectoryFD) IterDirents(ctx context.Context, cb vfs.IterDirent childIdx := fd.off - 2 for it := fd.children.nthLocked(childIdx); it != nil; it = it.Next() { inode := it.Dentry.Impl().(*Dentry).inode - stat, err := inode.Stat(fd.filesystem(), opts) + stat, err := inode.Stat(ctx, fd.filesystem(), opts) if err != nil { return err } @@ -198,7 +198,7 @@ func (fd *GenericDirectoryFD) IterDirents(ctx context.Context, cb vfs.IterDirent return err } -// Seek implements vfs.FileDecriptionImpl.Seek. +// Seek implements vfs.FileDescriptionImpl.Seek. func (fd *GenericDirectoryFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) { fd.mu.Lock() defer fd.mu.Unlock() @@ -226,7 +226,7 @@ func (fd *GenericDirectoryFD) Seek(ctx context.Context, offset int64, whence int func (fd *GenericDirectoryFD) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) { fs := fd.filesystem() inode := fd.inode() - return inode.Stat(fs, opts) + return inode.Stat(ctx, fs, opts) } // SetStat implements vfs.FileDescriptionImpl.SetStat. diff --git a/pkg/sentry/fsimpl/kernfs/filesystem.go b/pkg/sentry/fsimpl/kernfs/filesystem.go index 8939871c1..61a36cff9 100644 --- a/pkg/sentry/fsimpl/kernfs/filesystem.go +++ b/pkg/sentry/fsimpl/kernfs/filesystem.go @@ -684,7 +684,7 @@ func (fs *Filesystem) StatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf if err != nil { return linux.Statx{}, err } - return inode.Stat(fs.VFSFilesystem(), opts) + return inode.Stat(ctx, fs.VFSFilesystem(), opts) } // StatFSAt implements vfs.FilesystemImpl.StatFSAt. diff --git a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go index 4cb885d87..579e627f0 100644 --- a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go +++ b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go @@ -243,7 +243,7 @@ func (a *InodeAttrs) Mode() linux.FileMode { // Stat partially implements Inode.Stat. Note that this function doesn't provide // all the stat fields, and the embedder should consider extending the result // with filesystem-specific fields. -func (a *InodeAttrs) Stat(*vfs.Filesystem, vfs.StatOptions) (linux.Statx, error) { +func (a *InodeAttrs) Stat(context.Context, *vfs.Filesystem, vfs.StatOptions) (linux.Statx, error) { var stat linux.Statx stat.Mask = linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_UID | linux.STATX_GID | linux.STATX_INO | linux.STATX_NLINK stat.DevMajor = a.devMajor @@ -267,7 +267,7 @@ func (a *InodeAttrs) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *aut if opts.Stat.Mask&^(linux.STATX_MODE|linux.STATX_UID|linux.STATX_GID) != 0 { return syserror.EPERM } - if err := vfs.CheckSetStat(ctx, creds, &opts.Stat, a.Mode(), auth.KUID(atomic.LoadUint32(&a.uid)), auth.KGID(atomic.LoadUint32(&a.gid))); err != nil { + if err := vfs.CheckSetStat(ctx, creds, &opts, a.Mode(), auth.KUID(atomic.LoadUint32(&a.uid)), auth.KGID(atomic.LoadUint32(&a.gid))); err != nil { return err } diff --git a/pkg/sentry/fsimpl/kernfs/kernfs.go b/pkg/sentry/fsimpl/kernfs/kernfs.go index 596de1edf..46f207664 100644 --- a/pkg/sentry/fsimpl/kernfs/kernfs.go +++ b/pkg/sentry/fsimpl/kernfs/kernfs.go @@ -346,7 +346,7 @@ type inodeMetadata interface { // Stat returns the metadata for this inode. This corresponds to // vfs.FilesystemImpl.StatAt. - Stat(fs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) + Stat(ctx context.Context, fs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) // SetStat updates the metadata for this inode. This corresponds to // vfs.FilesystemImpl.SetStatAt. Implementations are responsible for checking diff --git a/pkg/sentry/fsimpl/overlay/filesystem.go b/pkg/sentry/fsimpl/overlay/filesystem.go index ff82e1f20..6b705e955 100644 --- a/pkg/sentry/fsimpl/overlay/filesystem.go +++ b/pkg/sentry/fsimpl/overlay/filesystem.go @@ -1104,7 +1104,7 @@ func (fs *filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts } mode := linux.FileMode(atomic.LoadUint32(&d.mode)) - if err := vfs.CheckSetStat(ctx, rp.Credentials(), &opts.Stat, mode, auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid))); err != nil { + if err := vfs.CheckSetStat(ctx, rp.Credentials(), &opts, mode, auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid))); err != nil { return err } mnt := rp.Mount() diff --git a/pkg/sentry/fsimpl/overlay/non_directory.go b/pkg/sentry/fsimpl/overlay/non_directory.go index a3c1f7a8d..c0749e711 100644 --- a/pkg/sentry/fsimpl/overlay/non_directory.go +++ b/pkg/sentry/fsimpl/overlay/non_directory.go @@ -151,7 +151,7 @@ func (fd *nonDirectoryFD) Stat(ctx context.Context, opts vfs.StatOptions) (linux func (fd *nonDirectoryFD) SetStat(ctx context.Context, opts vfs.SetStatOptions) error { d := fd.dentry() mode := linux.FileMode(atomic.LoadUint32(&d.mode)) - if err := vfs.CheckSetStat(ctx, auth.CredentialsFromContext(ctx), &opts.Stat, mode, auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid))); err != nil { + if err := vfs.CheckSetStat(ctx, auth.CredentialsFromContext(ctx), &opts, mode, auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid))); err != nil { return err } mnt := fd.vfsfd.Mount() @@ -176,7 +176,7 @@ func (fd *nonDirectoryFD) SetStat(ctx context.Context, opts vfs.SetStatOptions) return nil } -// StatFS implements vfs.FileDesciptionImpl.StatFS. +// StatFS implements vfs.FileDescriptionImpl.StatFS. func (fd *nonDirectoryFD) StatFS(ctx context.Context) (linux.Statfs, error) { return fd.filesystem().statFS(ctx) } diff --git a/pkg/sentry/fsimpl/pipefs/pipefs.go b/pkg/sentry/fsimpl/pipefs/pipefs.go index dd7eaf4a8..811f80a5f 100644 --- a/pkg/sentry/fsimpl/pipefs/pipefs.go +++ b/pkg/sentry/fsimpl/pipefs/pipefs.go @@ -115,7 +115,7 @@ func (i *inode) Mode() linux.FileMode { } // Stat implements kernfs.Inode.Stat. -func (i *inode) Stat(vfsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) { +func (i *inode) Stat(_ context.Context, vfsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) { ts := linux.NsecToStatxTimestamp(i.ctime.Nanoseconds()) return linux.Statx{ Mask: linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_NLINK | linux.STATX_UID | linux.STATX_GID | linux.STATX_ATIME | linux.STATX_MTIME | linux.STATX_CTIME | linux.STATX_INO | linux.STATX_SIZE | linux.STATX_BLOCKS, diff --git a/pkg/sentry/fsimpl/proc/subtasks.go b/pkg/sentry/fsimpl/proc/subtasks.go index 36a89540c..79c2725f3 100644 --- a/pkg/sentry/fsimpl/proc/subtasks.go +++ b/pkg/sentry/fsimpl/proc/subtasks.go @@ -128,7 +128,7 @@ func (fd *subtasksFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallbac return fd.GenericDirectoryFD.IterDirents(ctx, cb) } -// Seek implements vfs.FileDecriptionImpl.Seek. +// Seek implements vfs.FileDescriptionImpl.Seek. func (fd *subtasksFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) { if fd.task.ExitState() >= kernel.TaskExitZombie { return 0, syserror.ENOENT @@ -165,8 +165,8 @@ func (i *subtasksInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *v } // Stat implements kernfs.Inode. -func (i *subtasksInode) Stat(vsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) { - stat, err := i.InodeAttrs.Stat(vsfs, opts) +func (i *subtasksInode) Stat(ctx context.Context, vsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) { + stat, err := i.InodeAttrs.Stat(ctx, vsfs, opts) if err != nil { return linux.Statx{}, err } diff --git a/pkg/sentry/fsimpl/proc/task.go b/pkg/sentry/fsimpl/proc/task.go index 8bb2b0ce1..a5c7aa470 100644 --- a/pkg/sentry/fsimpl/proc/task.go +++ b/pkg/sentry/fsimpl/proc/task.go @@ -156,8 +156,8 @@ func (fs *filesystem) newTaskOwnedDir(task *kernel.Task, ino uint64, perm linux. } // Stat implements kernfs.Inode. -func (i *taskOwnedInode) Stat(fs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) { - stat, err := i.Inode.Stat(fs, opts) +func (i *taskOwnedInode) Stat(ctx context.Context, fs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) { + stat, err := i.Inode.Stat(ctx, fs, opts) if err != nil { return linux.Statx{}, err } diff --git a/pkg/sentry/fsimpl/proc/task_files.go b/pkg/sentry/fsimpl/proc/task_files.go index 9af43b859..859b7d727 100644 --- a/pkg/sentry/fsimpl/proc/task_files.go +++ b/pkg/sentry/fsimpl/proc/task_files.go @@ -876,7 +876,7 @@ var _ vfs.FileDescriptionImpl = (*namespaceFD)(nil) // Stat implements FileDescriptionImpl. func (fd *namespaceFD) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) { vfs := fd.vfsfd.VirtualDentry().Mount().Filesystem() - return fd.inode.Stat(vfs, opts) + return fd.inode.Stat(ctx, vfs, opts) } // SetStat implements FileDescriptionImpl. diff --git a/pkg/sentry/fsimpl/proc/tasks.go b/pkg/sentry/fsimpl/proc/tasks.go index 2f214d0c2..6d2b90a8b 100644 --- a/pkg/sentry/fsimpl/proc/tasks.go +++ b/pkg/sentry/fsimpl/proc/tasks.go @@ -206,8 +206,8 @@ func (i *tasksInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs. return fd.VFSFileDescription(), nil } -func (i *tasksInode) Stat(vsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) { - stat, err := i.InodeAttrs.Stat(vsfs, opts) +func (i *tasksInode) Stat(ctx context.Context, vsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) { + stat, err := i.InodeAttrs.Stat(ctx, vsfs, opts) if err != nil { return linux.Statx{}, err } diff --git a/pkg/sentry/fsimpl/sys/BUILD b/pkg/sentry/fsimpl/sys/BUILD index a741e2bb6..1b548ccd4 100644 --- a/pkg/sentry/fsimpl/sys/BUILD +++ b/pkg/sentry/fsimpl/sys/BUILD @@ -29,6 +29,6 @@ go_test( "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", "//pkg/sentry/vfs", - "@com_github_google_go-cmp//cmp:go_default_library", + "@com_github_google_go_cmp//cmp:go_default_library", ], ) diff --git a/pkg/sentry/fsimpl/testutil/BUILD b/pkg/sentry/fsimpl/testutil/BUILD index 0e4053a46..400a97996 100644 --- a/pkg/sentry/fsimpl/testutil/BUILD +++ b/pkg/sentry/fsimpl/testutil/BUILD @@ -32,6 +32,6 @@ go_library( "//pkg/sentry/vfs", "//pkg/sync", "//pkg/usermem", - "@com_github_google_go-cmp//cmp:go_default_library", + "@com_github_google_go_cmp//cmp:go_default_library", ], ) diff --git a/pkg/sentry/fsimpl/testutil/kernel.go b/pkg/sentry/fsimpl/testutil/kernel.go index c16a36cdb..e743e8114 100644 --- a/pkg/sentry/fsimpl/testutil/kernel.go +++ b/pkg/sentry/fsimpl/testutil/kernel.go @@ -62,6 +62,7 @@ func Boot() (*kernel.Kernel, error) { return nil, fmt.Errorf("creating platform: %v", err) } + kernel.VFS2Enabled = true k := &kernel.Kernel{ Platform: plat, } @@ -73,7 +74,7 @@ func Boot() (*kernel.Kernel, error) { k.SetMemoryFile(mf) // Pass k as the platform since it is savable, unlike the actual platform. - vdso, err := loader.PrepareVDSO(nil, k) + vdso, err := loader.PrepareVDSO(k) if err != nil { return nil, fmt.Errorf("creating vdso: %v", err) } @@ -103,11 +104,6 @@ func Boot() (*kernel.Kernel, error) { return nil, fmt.Errorf("initializing kernel: %v", err) } - kernel.VFS2Enabled = true - - if err := k.VFS().Init(); err != nil { - return nil, fmt.Errorf("VFS init: %v", err) - } k.VFS().MustRegisterFilesystemType(tmpfs.Name, &tmpfs.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{ AllowUserMount: true, AllowUserList: true, diff --git a/pkg/sentry/fsimpl/tmpfs/filesystem.go b/pkg/sentry/fsimpl/tmpfs/filesystem.go index ed40f6b52..ef210a69b 100644 --- a/pkg/sentry/fsimpl/tmpfs/filesystem.go +++ b/pkg/sentry/fsimpl/tmpfs/filesystem.go @@ -277,7 +277,7 @@ func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v creds := rp.Credentials() var childInode *inode switch opts.Mode.FileType() { - case 0, linux.S_IFREG: + case linux.S_IFREG: childInode = fs.newRegularFile(creds.EffectiveKUID, creds.EffectiveKGID, opts.Mode) case linux.S_IFIFO: childInode = fs.newNamedPipe(creds.EffectiveKUID, creds.EffectiveKGID, opts.Mode) @@ -649,7 +649,7 @@ func (fs *filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts fs.mu.RUnlock() return err } - if err := d.inode.setStat(ctx, rp.Credentials(), &opts.Stat); err != nil { + if err := d.inode.setStat(ctx, rp.Credentials(), &opts); err != nil { fs.mu.RUnlock() return err } diff --git a/pkg/sentry/fsimpl/tmpfs/regular_file.go b/pkg/sentry/fsimpl/tmpfs/regular_file.go index 1cdb46e6f..abbaa5d60 100644 --- a/pkg/sentry/fsimpl/tmpfs/regular_file.go +++ b/pkg/sentry/fsimpl/tmpfs/regular_file.go @@ -325,8 +325,15 @@ func (fd *regularFileFD) Read(ctx context.Context, dst usermem.IOSequence, opts // PWrite implements vfs.FileDescriptionImpl.PWrite. func (fd *regularFileFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) { + n, _, err := fd.pwrite(ctx, src, offset, opts) + return n, err +} + +// pwrite returns the number of bytes written, final offset and error. The +// final offset should be ignored by PWrite. +func (fd *regularFileFD) pwrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (written, finalOff int64, err error) { if offset < 0 { - return 0, syserror.EINVAL + return 0, offset, syserror.EINVAL } // Check that flags are supported. RWF_DSYNC/RWF_SYNC can be ignored since @@ -334,40 +341,44 @@ func (fd *regularFileFD) PWrite(ctx context.Context, src usermem.IOSequence, off // // TODO(gvisor.dev/issue/2601): Support select preadv2 flags. if opts.Flags&^(linux.RWF_HIPRI|linux.RWF_DSYNC|linux.RWF_SYNC) != 0 { - return 0, syserror.EOPNOTSUPP + return 0, offset, syserror.EOPNOTSUPP } srclen := src.NumBytes() if srclen == 0 { - return 0, nil + return 0, offset, nil } f := fd.inode().impl.(*regularFile) + f.inode.mu.Lock() + defer f.inode.mu.Unlock() + // If the file is opened with O_APPEND, update offset to file size. + if fd.vfsfd.StatusFlags()&linux.O_APPEND != 0 { + // Locking f.inode.mu is sufficient for reading f.size. + offset = int64(f.size) + } if end := offset + srclen; end < offset { // Overflow. - return 0, syserror.EINVAL + return 0, offset, syserror.EINVAL } - var err error srclen, err = vfs.CheckLimit(ctx, offset, srclen) if err != nil { - return 0, err + return 0, offset, err } src = src.TakeFirst64(srclen) - f.inode.mu.Lock() rw := getRegularFileReadWriter(f, offset) n, err := src.CopyInTo(ctx, rw) - fd.inode().touchCMtimeLocked() - f.inode.mu.Unlock() + f.inode.touchCMtimeLocked() putRegularFileReadWriter(rw) - return n, err + return n, n + offset, err } // Write implements vfs.FileDescriptionImpl.Write. func (fd *regularFileFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) { fd.offMu.Lock() - n, err := fd.PWrite(ctx, src, fd.off, opts) - fd.off += n + n, off, err := fd.pwrite(ctx, src, fd.off, opts) + fd.off = off fd.offMu.Unlock() return n, err } diff --git a/pkg/sentry/fsimpl/tmpfs/tmpfs.go b/pkg/sentry/fsimpl/tmpfs/tmpfs.go index d7f4f0779..2545d88e9 100644 --- a/pkg/sentry/fsimpl/tmpfs/tmpfs.go +++ b/pkg/sentry/fsimpl/tmpfs/tmpfs.go @@ -452,7 +452,8 @@ func (i *inode) statTo(stat *linux.Statx) { } } -func (i *inode) setStat(ctx context.Context, creds *auth.Credentials, stat *linux.Statx) error { +func (i *inode) setStat(ctx context.Context, creds *auth.Credentials, opts *vfs.SetStatOptions) error { + stat := &opts.Stat if stat.Mask == 0 { return nil } @@ -460,7 +461,7 @@ func (i *inode) setStat(ctx context.Context, creds *auth.Credentials, stat *linu return syserror.EPERM } mode := linux.FileMode(atomic.LoadUint32(&i.mode)) - if err := vfs.CheckSetStat(ctx, creds, stat, mode, auth.KUID(atomic.LoadUint32(&i.uid)), auth.KGID(atomic.LoadUint32(&i.gid))); err != nil { + if err := vfs.CheckSetStat(ctx, creds, opts, mode, auth.KUID(atomic.LoadUint32(&i.uid)), auth.KGID(atomic.LoadUint32(&i.gid))); err != nil { return err } i.mu.Lock() @@ -695,7 +696,7 @@ func (fd *fileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linu func (fd *fileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions) error { creds := auth.CredentialsFromContext(ctx) d := fd.dentry() - if err := d.inode.setStat(ctx, creds, &opts.Stat); err != nil { + if err := d.inode.setStat(ctx, creds, &opts); err != nil { return err } diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD index 25fe1921b..f6886a758 100644 --- a/pkg/sentry/kernel/BUILD +++ b/pkg/sentry/kernel/BUILD @@ -132,6 +132,7 @@ go_library( "task_stop.go", "task_syscall.go", "task_usermem.go", + "task_work.go", "thread_group.go", "threads.go", "timekeeper.go", diff --git a/pkg/sentry/kernel/futex/futex.go b/pkg/sentry/kernel/futex/futex.go index 732e66da4..bcc1b29a8 100644 --- a/pkg/sentry/kernel/futex/futex.go +++ b/pkg/sentry/kernel/futex/futex.go @@ -717,10 +717,10 @@ func (m *Manager) lockPILocked(w *Waiter, t Target, addr usermem.Addr, tid uint3 } } -// UnlockPI unlock the futex following the Priority-inheritance futex -// rules. The address provided must contain the caller's TID. If there are -// waiters, TID of the next waiter (FIFO) is set to the given address, and the -// waiter woken up. If there are no waiters, 0 is set to the address. +// UnlockPI unlocks the futex following the Priority-inheritance futex rules. +// The address provided must contain the caller's TID. If there are waiters, +// TID of the next waiter (FIFO) is set to the given address, and the waiter +// woken up. If there are no waiters, 0 is set to the address. func (m *Manager) UnlockPI(t Target, addr usermem.Addr, tid uint32, private bool) error { k, err := getKey(t, addr, private) if err != nil { diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go index 2177b785a..15dae0f5b 100644 --- a/pkg/sentry/kernel/kernel.go +++ b/pkg/sentry/kernel/kernel.go @@ -81,6 +81,10 @@ import ( // easy access everywhere. To be removed once VFS2 becomes the default. var VFS2Enabled = false +// FUSEEnabled is set to true when FUSE is enabled. Added as a global for allow +// easy access everywhere. To be removed once FUSE is completed. +var FUSEEnabled = false + // Kernel represents an emulated Linux kernel. It must be initialized by calling // Init() or LoadFrom(). // @@ -1465,6 +1469,11 @@ func (k *Kernel) NowMonotonic() int64 { return now } +// AfterFunc implements tcpip.Clock.AfterFunc. +func (k *Kernel) AfterFunc(d time.Duration, f func()) tcpip.Timer { + return ktime.TcpipAfterFunc(k.realtimeClock, d, f) +} + // SetMemoryFile sets Kernel.mf. SetMemoryFile must be called before Init or // LoadFrom. func (k *Kernel) SetMemoryFile(mf *pgalloc.MemoryFile) { diff --git a/pkg/sentry/kernel/syslog.go b/pkg/sentry/kernel/syslog.go index 4607cde2f..a83ce219c 100644 --- a/pkg/sentry/kernel/syslog.go +++ b/pkg/sentry/kernel/syslog.go @@ -98,6 +98,15 @@ func (s *syslog) Log() []byte { s.msg = append(s.msg, []byte(fmt.Sprintf(format, time, selectMessage()))...) } + if VFS2Enabled { + time += rand.Float64() / 2 + s.msg = append(s.msg, []byte(fmt.Sprintf(format, time, "Setting up VFS2..."))...) + if FUSEEnabled { + time += rand.Float64() / 2 + s.msg = append(s.msg, []byte(fmt.Sprintf(format, time, "Setting up FUSE..."))...) + } + } + time += rand.Float64() / 2 s.msg = append(s.msg, []byte(fmt.Sprintf(format, time, "Ready!"))...) diff --git a/pkg/sentry/kernel/task.go b/pkg/sentry/kernel/task.go index f48247c94..c4db05bd8 100644 --- a/pkg/sentry/kernel/task.go +++ b/pkg/sentry/kernel/task.go @@ -68,6 +68,21 @@ type Task struct { // runState is exclusive to the task goroutine. runState taskRunState + // taskWorkCount represents the current size of the task work queue. It is + // used to avoid acquiring taskWorkMu when the queue is empty. + // + // Must accessed with atomic memory operations. + taskWorkCount int32 + + // taskWorkMu protects taskWork. + taskWorkMu sync.Mutex `state:"nosave"` + + // taskWork is a queue of work to be executed before resuming user execution. + // It is similar to the task_work mechanism in Linux. + // + // taskWork is exclusive to the task goroutine. + taskWork []TaskWorker + // haveSyscallReturn is true if tc.Arch().Return() represents a value // returned by a syscall (or set by ptrace after a syscall). // @@ -550,6 +565,10 @@ type Task struct { // futexWaiter is exclusive to the task goroutine. futexWaiter *futex.Waiter `state:"nosave"` + // robustList is a pointer to the head of the tasks's robust futex + // list. + robustList usermem.Addr + // startTime is the real time at which the task started. It is set when // a Task is created or invokes execve(2). // diff --git a/pkg/sentry/kernel/task_exec.go b/pkg/sentry/kernel/task_exec.go index 9b69f3cbe..7803b98d0 100644 --- a/pkg/sentry/kernel/task_exec.go +++ b/pkg/sentry/kernel/task_exec.go @@ -207,6 +207,9 @@ func (r *runSyscallAfterExecStop) execute(t *Task) taskRunState { return flags.CloseOnExec }) + // Handle the robust futex list. + t.exitRobustList() + // NOTE(b/30815691): We currently do not implement privileged // executables (set-user/group-ID bits and file capabilities). This // allows us to unconditionally enable user dumpability on the new mm. diff --git a/pkg/sentry/kernel/task_exit.go b/pkg/sentry/kernel/task_exit.go index c4ade6e8e..231ac548a 100644 --- a/pkg/sentry/kernel/task_exit.go +++ b/pkg/sentry/kernel/task_exit.go @@ -253,6 +253,9 @@ func (*runExitMain) execute(t *Task) taskRunState { } } + // Handle the robust futex list. + t.exitRobustList() + // Deactivate the address space and update max RSS before releasing the // task's MM. t.Deactivate() diff --git a/pkg/sentry/kernel/task_futex.go b/pkg/sentry/kernel/task_futex.go index a53e77c9f..4b535c949 100644 --- a/pkg/sentry/kernel/task_futex.go +++ b/pkg/sentry/kernel/task_futex.go @@ -15,6 +15,7 @@ package kernel import ( + "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/kernel/futex" "gvisor.dev/gvisor/pkg/usermem" ) @@ -52,3 +53,127 @@ func (t *Task) LoadUint32(addr usermem.Addr) (uint32, error) { func (t *Task) GetSharedKey(addr usermem.Addr) (futex.Key, error) { return t.MemoryManager().GetSharedFutexKey(t, addr) } + +// GetRobustList sets the robust futex list for the task. +func (t *Task) GetRobustList() usermem.Addr { + t.mu.Lock() + addr := t.robustList + t.mu.Unlock() + return addr +} + +// SetRobustList sets the robust futex list for the task. +func (t *Task) SetRobustList(addr usermem.Addr) { + t.mu.Lock() + t.robustList = addr + t.mu.Unlock() +} + +// exitRobustList walks the robust futex list, marking locks dead and notifying +// wakers. It corresponds to Linux's exit_robust_list(). Following Linux, +// errors are silently ignored. +func (t *Task) exitRobustList() { + t.mu.Lock() + addr := t.robustList + t.robustList = 0 + t.mu.Unlock() + + if addr == 0 { + return + } + + var rl linux.RobustListHead + if _, err := rl.CopyIn(t, usermem.Addr(addr)); err != nil { + return + } + + next := rl.List + done := 0 + var pendingLockAddr usermem.Addr + if rl.ListOpPending != 0 { + pendingLockAddr = usermem.Addr(rl.ListOpPending + rl.FutexOffset) + } + + // Wake up normal elements. + for usermem.Addr(next) != addr { + // We traverse to the next element of the list before we + // actually wake anything. This prevents the race where waking + // this futex causes a modification of the list. + thisLockAddr := usermem.Addr(next + rl.FutexOffset) + + // Try to decode the next element in the list before waking the + // current futex. But don't check the error until after we've + // woken the current futex. Linux does it in this order too + _, nextErr := t.CopyIn(usermem.Addr(next), &next) + + // Wakeup the current futex if it's not pending. + if thisLockAddr != pendingLockAddr { + t.wakeRobustListOne(thisLockAddr) + } + + // If there was an error copying the next futex, we must bail. + if nextErr != nil { + break + } + + // This is a user structure, so it could be a massive list, or + // even contain a loop if they are trying to mess with us. We + // cap traversal to prevent that. + done++ + if done >= linux.ROBUST_LIST_LIMIT { + break + } + } + + // Is there a pending entry to wake? + if pendingLockAddr != 0 { + t.wakeRobustListOne(pendingLockAddr) + } +} + +// wakeRobustListOne wakes a single futex from the robust list. +func (t *Task) wakeRobustListOne(addr usermem.Addr) { + // Bit 0 in address signals PI futex. + pi := addr&1 == 1 + addr = addr &^ 1 + + // Load the futex. + f, err := t.LoadUint32(addr) + if err != nil { + // Can't read this single value? Ignore the problem. + // We can wake the other futexes in the list. + return + } + + tid := uint32(t.ThreadID()) + for { + // Is this held by someone else? + if f&linux.FUTEX_TID_MASK != tid { + return + } + + // This thread is dying and it's holding this futex. We need to + // set the owner died bit and wake up any waiters. + newF := (f & linux.FUTEX_WAITERS) | linux.FUTEX_OWNER_DIED + if curF, err := t.CompareAndSwapUint32(addr, f, newF); err != nil { + return + } else if curF != f { + // Futex changed out from under us. Try again... + f = curF + continue + } + + // Wake waiters if there are any. + if f&linux.FUTEX_WAITERS != 0 { + private := f&linux.FUTEX_PRIVATE_FLAG != 0 + if pi { + t.Futex().UnlockPI(t, addr, tid, private) + return + } + t.Futex().Wake(t, addr, private, linux.FUTEX_BITSET_MATCH_ANY, 1) + } + + // Done. + return + } +} diff --git a/pkg/sentry/kernel/task_run.go b/pkg/sentry/kernel/task_run.go index d654dd997..7d4f44caf 100644 --- a/pkg/sentry/kernel/task_run.go +++ b/pkg/sentry/kernel/task_run.go @@ -167,7 +167,22 @@ func (app *runApp) execute(t *Task) taskRunState { return (*runInterrupt)(nil) } - // We're about to switch to the application again. If there's still a + // Execute any task work callbacks before returning to user space. + if atomic.LoadInt32(&t.taskWorkCount) > 0 { + t.taskWorkMu.Lock() + queue := t.taskWork + t.taskWork = nil + atomic.StoreInt32(&t.taskWorkCount, 0) + t.taskWorkMu.Unlock() + + // Do not hold taskWorkMu while executing task work, which may register + // more work. + for _, work := range queue { + work.TaskWork(t) + } + } + + // We're about to switch to the application again. If there's still an // unhandled SyscallRestartErrno that wasn't translated to an EINTR, // restart the syscall that was interrupted. If there's a saved signal // mask, restore it. (Note that restoring the saved signal mask may unblock diff --git a/pkg/sentry/kernel/task_work.go b/pkg/sentry/kernel/task_work.go new file mode 100644 index 000000000..dda5a433a --- /dev/null +++ b/pkg/sentry/kernel/task_work.go @@ -0,0 +1,38 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package kernel + +import "sync/atomic" + +// TaskWorker is a deferred task. +// +// This must be savable. +type TaskWorker interface { + // TaskWork will be executed prior to returning to user space. Note that + // TaskWork may call RegisterWork again, but this will not be executed until + // the next return to user space, unlike in Linux. This effectively allows + // registration of indefinite user return hooks, but not by default. + TaskWork(t *Task) +} + +// RegisterWork can be used to register additional task work that will be +// performed prior to returning to user space. See TaskWorker.TaskWork for +// semantics regarding registration. +func (t *Task) RegisterWork(work TaskWorker) { + t.taskWorkMu.Lock() + defer t.taskWorkMu.Unlock() + atomic.AddInt32(&t.taskWorkCount, 1) + t.taskWork = append(t.taskWork, work) +} diff --git a/pkg/sentry/kernel/time/BUILD b/pkg/sentry/kernel/time/BUILD index 7ba7dc50c..2817aa3ba 100644 --- a/pkg/sentry/kernel/time/BUILD +++ b/pkg/sentry/kernel/time/BUILD @@ -6,6 +6,7 @@ go_library( name = "time", srcs = [ "context.go", + "tcpip.go", "time.go", ], visibility = ["//pkg/sentry:internal"], diff --git a/pkg/sentry/kernel/time/tcpip.go b/pkg/sentry/kernel/time/tcpip.go new file mode 100644 index 000000000..c4474c0cf --- /dev/null +++ b/pkg/sentry/kernel/time/tcpip.go @@ -0,0 +1,131 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package time + +import ( + "sync" + "time" +) + +// TcpipAfterFunc waits for duration to elapse according to clock then runs fn. +// The timer is started immediately and will fire exactly once. +func TcpipAfterFunc(clock Clock, duration time.Duration, fn func()) *TcpipTimer { + timer := &TcpipTimer{ + clock: clock, + } + timer.notifier = functionNotifier{ + fn: func() { + // tcpip.Timer.Stop() explicitly states that the function is called in a + // separate goroutine that Stop() does not synchronize with. + // Timer.Destroy() synchronizes with calls to TimerListener.Notify(). + // This is semantically meaningful because, in the former case, it's + // legal to call tcpip.Timer.Stop() while holding locks that may also be + // taken by the function, but this isn't so in the latter case. Most + // immediately, Timer calls TimerListener.Notify() while holding + // Timer.mu. A deadlock occurs without spawning a goroutine: + // T1: (Timer expires) + // => Timer.Tick() <- Timer.mu.Lock() called + // => TimerListener.Notify() + // => Timer.Stop() + // => Timer.Destroy() <- Timer.mu.Lock() called, deadlock! + // + // Spawning a goroutine avoids the deadlock: + // T1: (Timer expires) + // => Timer.Tick() <- Timer.mu.Lock() called + // => TimerListener.Notify() <- Launches T2 + // T2: + // => Timer.Stop() + // => Timer.Destroy() <- Timer.mu.Lock() called, blocks + // T1: + // => (returns) <- Timer.mu.Unlock() called + // T2: + // => (continues) <- No deadlock! + go func() { + timer.Stop() + fn() + }() + }, + } + timer.Reset(duration) + return timer +} + +// TcpipTimer is a resettable timer with variable duration expirations. +// Implements tcpip.Timer, which does not define a Destroy method; instead, all +// resources are released after timer expiration and calls to Timer.Stop. +// +// Must be created by AfterFunc. +type TcpipTimer struct { + // clock is the time source. clock is immutable. + clock Clock + + // notifier is called when the Timer expires. notifier is immutable. + notifier functionNotifier + + // mu protects t. + mu sync.Mutex + + // t stores the latest running Timer. This is replaced whenever Reset is + // called since Timer cannot be restarted once it has been Destroyed by Stop. + // + // This field is nil iff Stop has been called. + t *Timer +} + +// Stop implements tcpip.Timer.Stop. +func (r *TcpipTimer) Stop() bool { + r.mu.Lock() + defer r.mu.Unlock() + + if r.t == nil { + return false + } + _, lastSetting := r.t.Swap(Setting{}) + r.t.Destroy() + r.t = nil + return lastSetting.Enabled +} + +// Reset implements tcpip.Timer.Reset. +func (r *TcpipTimer) Reset(d time.Duration) { + r.mu.Lock() + defer r.mu.Unlock() + + if r.t == nil { + r.t = NewTimer(r.clock, &r.notifier) + } + + r.t.Swap(Setting{ + Enabled: true, + Period: 0, + Next: r.clock.Now().Add(d), + }) +} + +// functionNotifier is a TimerListener that runs a function. +// +// functionNotifier cannot be saved or loaded. +type functionNotifier struct { + fn func() +} + +// Notify implements ktime.TimerListener.Notify. +func (f *functionNotifier) Notify(uint64, Setting) (Setting, bool) { + f.fn() + return Setting{}, false +} + +// Destroy implements ktime.TimerListener.Destroy. +func (f *functionNotifier) Destroy() {} diff --git a/pkg/sentry/kernel/timekeeper.go b/pkg/sentry/kernel/timekeeper.go index 0adf25691..5f3908d8b 100644 --- a/pkg/sentry/kernel/timekeeper.go +++ b/pkg/sentry/kernel/timekeeper.go @@ -210,9 +210,6 @@ func (t *Timekeeper) startUpdater() { p.realtimeBaseRef = int64(realtimeParams.BaseRef) p.realtimeFrequency = realtimeParams.Frequency } - - log.Debugf("Updating VDSO parameters: %+v", p) - return p }); err != nil { log.Warningf("Unable to update VDSO parameter page: %v", err) diff --git a/pkg/sentry/loader/BUILD b/pkg/sentry/loader/BUILD index c6aa65f28..34bdb0b69 100644 --- a/pkg/sentry/loader/BUILD +++ b/pkg/sentry/loader/BUILD @@ -30,9 +30,6 @@ go_library( "//pkg/rand", "//pkg/safemem", "//pkg/sentry/arch", - "//pkg/sentry/fs", - "//pkg/sentry/fs/anon", - "//pkg/sentry/fs/fsutil", "//pkg/sentry/fsbridge", "//pkg/sentry/kernel/auth", "//pkg/sentry/limits", @@ -45,6 +42,5 @@ go_library( "//pkg/syserr", "//pkg/syserror", "//pkg/usermem", - "//pkg/waiter", ], ) diff --git a/pkg/sentry/loader/elf.go b/pkg/sentry/loader/elf.go index 616fafa2c..ddeaff3db 100644 --- a/pkg/sentry/loader/elf.go +++ b/pkg/sentry/loader/elf.go @@ -90,14 +90,23 @@ type elfInfo struct { sharedObject bool } +// fullReader interface extracts the ReadFull method from fsbridge.File so that +// client code does not need to define an entire fsbridge.File when only read +// functionality is needed. +// +// TODO(gvisor.dev/issue/1035): Once VFS2 ships, rewrite this to wrap +// vfs.FileDescription's PRead/Read instead. +type fullReader interface { + // ReadFull is the same as fsbridge.File.ReadFull. + ReadFull(ctx context.Context, dst usermem.IOSequence, offset int64) (int64, error) +} + // parseHeader parse the ELF header, verifying that this is a supported ELF // file and returning the ELF program headers. // // This is similar to elf.NewFile, except that it is more strict about what it // accepts from the ELF, and it doesn't parse unnecessary parts of the file. -// -// ctx may be nil if f does not need it. -func parseHeader(ctx context.Context, f fsbridge.File) (elfInfo, error) { +func parseHeader(ctx context.Context, f fullReader) (elfInfo, error) { // Check ident first; it will tell us the endianness of the rest of the // structs. var ident [elf.EI_NIDENT]byte diff --git a/pkg/sentry/loader/loader.go b/pkg/sentry/loader/loader.go index 88449fe95..986c7fb4d 100644 --- a/pkg/sentry/loader/loader.go +++ b/pkg/sentry/loader/loader.go @@ -27,7 +27,6 @@ import ( "gvisor.dev/gvisor/pkg/cpuid" "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/sentry/arch" - "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fsbridge" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/mm" @@ -80,22 +79,6 @@ type LoadArgs struct { Features *cpuid.FeatureSet } -// readFull behaves like io.ReadFull for an *fs.File. -func readFull(ctx context.Context, f *fs.File, dst usermem.IOSequence, offset int64) (int64, error) { - var total int64 - for dst.NumBytes() > 0 { - n, err := f.Preadv(ctx, dst, offset+total) - total += n - if err == io.EOF && total != 0 { - return total, io.ErrUnexpectedEOF - } else if err != nil { - return total, err - } - dst = dst.DropFirst64(n) - } - return total, nil -} - // openPath opens args.Filename and checks that it is valid for loading. // // openPath returns an *fs.Dirent and *fs.File for args.Filename, which is not @@ -238,14 +221,14 @@ func Load(ctx context.Context, args LoadArgs, extraAuxv []arch.AuxEntry, vdso *V // Load the executable itself. loaded, ac, file, newArgv, err := loadExecutable(ctx, args) if err != nil { - return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("Failed to load %s: %v", args.Filename, err), syserr.FromError(err).ToLinux()) + return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("failed to load %s: %v", args.Filename, err), syserr.FromError(err).ToLinux()) } defer file.DecRef() // Load the VDSO. vdsoAddr, err := loadVDSO(ctx, args.MemoryManager, vdso, loaded) if err != nil { - return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("Error loading VDSO: %v", err), syserr.FromError(err).ToLinux()) + return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("error loading VDSO: %v", err), syserr.FromError(err).ToLinux()) } // Setup the heap. brk starts at the next page after the end of the diff --git a/pkg/sentry/loader/vdso.go b/pkg/sentry/loader/vdso.go index 165869028..05a294fe6 100644 --- a/pkg/sentry/loader/vdso.go +++ b/pkg/sentry/loader/vdso.go @@ -26,10 +26,6 @@ import ( "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/arch" - "gvisor.dev/gvisor/pkg/sentry/fs" - "gvisor.dev/gvisor/pkg/sentry/fs/anon" - "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" - "gvisor.dev/gvisor/pkg/sentry/fsbridge" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/mm" "gvisor.dev/gvisor/pkg/sentry/pgalloc" @@ -37,7 +33,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" - "gvisor.dev/gvisor/pkg/waiter" ) const vdsoPrelink = 0xffffffffff700000 @@ -55,52 +50,11 @@ func (f *fileContext) Value(key interface{}) interface{} { } } -// byteReader implements fs.FileOperations for reading from a []byte source. -type byteReader struct { - fsutil.FileNoFsync `state:"nosave"` - fsutil.FileNoIoctl `state:"nosave"` - fsutil.FileNoMMap `state:"nosave"` - fsutil.FileNoSplice `state:"nosave"` - fsutil.FileNoopFlush `state:"nosave"` - fsutil.FileNoopRelease `state:"nosave"` - fsutil.FileNotDirReaddir `state:"nosave"` - fsutil.FilePipeSeek `state:"nosave"` - fsutil.FileUseInodeUnstableAttr `state:"nosave"` - waiter.AlwaysReady `state:"nosave"` - +type byteFullReader struct { data []byte } -var _ fs.FileOperations = (*byteReader)(nil) - -// newByteReaderFile creates a fake file to read data from. -// -// TODO(gvisor.dev/issue/2921): Convert to VFS2. -func newByteReaderFile(ctx context.Context, data []byte) *fs.File { - // Create a fake inode. - inode := fs.NewInode( - ctx, - &fsutil.SimpleFileInode{}, - fs.NewPseudoMountSource(ctx), - fs.StableAttr{ - Type: fs.Anonymous, - DeviceID: anon.PseudoDevice.DeviceID(), - InodeID: anon.PseudoDevice.NextIno(), - BlockSize: usermem.PageSize, - }) - - // Use the fake inode to create a fake dirent. - dirent := fs.NewTransientDirent(inode) - defer dirent.DecRef() - - // Use the fake dirent to make a fake file. - flags := fs.FileFlags{Read: true, Pread: true} - return fs.NewFile(&fileContext{Context: context.Background()}, dirent, flags, &byteReader{ - data: data, - }) -} - -func (b *byteReader) Read(ctx context.Context, file *fs.File, dst usermem.IOSequence, offset int64) (int64, error) { +func (b *byteFullReader) ReadFull(ctx context.Context, dst usermem.IOSequence, offset int64) (int64, error) { if offset < 0 { return 0, syserror.EINVAL } @@ -111,10 +65,6 @@ func (b *byteReader) Read(ctx context.Context, file *fs.File, dst usermem.IOSequ return int64(n), err } -func (b *byteReader) Write(ctx context.Context, file *fs.File, src usermem.IOSequence, offset int64) (int64, error) { - panic("Write not supported") -} - // validateVDSO checks that the VDSO can be loaded by loadVDSO. // // VDSOs are special (see below). Since we are going to map the VDSO directly @@ -130,7 +80,7 @@ func (b *byteReader) Write(ctx context.Context, file *fs.File, src usermem.IOSeq // * PT_LOAD segments don't extend beyond the end of the file. // // ctx may be nil if f does not need it. -func validateVDSO(ctx context.Context, f fsbridge.File, size uint64) (elfInfo, error) { +func validateVDSO(ctx context.Context, f fullReader, size uint64) (elfInfo, error) { info, err := parseHeader(ctx, f) if err != nil { log.Infof("Unable to parse VDSO header: %v", err) @@ -248,13 +198,12 @@ func getSymbolValueFromVDSO(symbol string) (uint64, error) { // PrepareVDSO validates the system VDSO and returns a VDSO, containing the // param page for updating by the kernel. -func PrepareVDSO(ctx context.Context, mfp pgalloc.MemoryFileProvider) (*VDSO, error) { - vdsoFile := fsbridge.NewFSFile(newByteReaderFile(ctx, vdsoBin)) +func PrepareVDSO(mfp pgalloc.MemoryFileProvider) (*VDSO, error) { + vdsoFile := &byteFullReader{data: vdsoBin} // First make sure the VDSO is valid. vdsoFile does not use ctx, so a // nil context can be passed. info, err := validateVDSO(nil, vdsoFile, uint64(len(vdsoBin))) - vdsoFile.DecRef() if err != nil { return nil, err } diff --git a/pkg/sentry/platform/kvm/BUILD b/pkg/sentry/platform/kvm/BUILD index 4792454c4..10a10bfe2 100644 --- a/pkg/sentry/platform/kvm/BUILD +++ b/pkg/sentry/platform/kvm/BUILD @@ -60,6 +60,7 @@ go_library( go_test( name = "kvm_test", srcs = [ + "kvm_amd64_test.go", "kvm_test.go", "virtual_map_test.go", ], diff --git a/pkg/sentry/platform/kvm/kvm_amd64_test.go b/pkg/sentry/platform/kvm/kvm_amd64_test.go new file mode 100644 index 000000000..c0b4fd374 --- /dev/null +++ b/pkg/sentry/platform/kvm/kvm_amd64_test.go @@ -0,0 +1,51 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build amd64 + +package kvm + +import ( + "testing" + + "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/platform" + "gvisor.dev/gvisor/pkg/sentry/platform/kvm/testutil" + "gvisor.dev/gvisor/pkg/sentry/platform/ring0" + "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables" +) + +func TestSegments(t *testing.T) { + applicationTest(t, true, testutil.TwiddleSegments, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool { + testutil.SetTestSegments(regs) + for { + var si arch.SignalInfo + if _, err := c.SwitchToUser(ring0.SwitchOpts{ + Registers: regs, + FloatingPointState: dummyFPState, + PageTables: pt, + FullRestore: true, + }, &si); err == platform.ErrContextInterrupt { + continue // Retry. + } else if err != nil { + t.Errorf("application segment check with full restore got unexpected error: %v", err) + } + if err := testutil.CheckTestSegments(regs); err != nil { + t.Errorf("application segment check with full restore failed: %v", err) + } + break // Done. + } + return false + }) +} diff --git a/pkg/sentry/platform/kvm/kvm_test.go b/pkg/sentry/platform/kvm/kvm_test.go index 6c8f4fa28..45b3180f1 100644 --- a/pkg/sentry/platform/kvm/kvm_test.go +++ b/pkg/sentry/platform/kvm/kvm_test.go @@ -262,30 +262,6 @@ func TestRegistersFault(t *testing.T) { }) } -func TestSegments(t *testing.T) { - applicationTest(t, true, testutil.TwiddleSegments, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool { - testutil.SetTestSegments(regs) - for { - var si arch.SignalInfo - if _, err := c.SwitchToUser(ring0.SwitchOpts{ - Registers: regs, - FloatingPointState: dummyFPState, - PageTables: pt, - FullRestore: true, - }, &si); err == platform.ErrContextInterrupt { - continue // Retry. - } else if err != nil { - t.Errorf("application segment check with full restore got unexpected error: %v", err) - } - if err := testutil.CheckTestSegments(regs); err != nil { - t.Errorf("application segment check with full restore failed: %v", err) - } - break // Done. - } - return false - }) -} - func TestBounce(t *testing.T) { applicationTest(t, true, testutil.SpinLoop, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool { go func() { diff --git a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go index 8bed34922..3de309c1a 100644 --- a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go +++ b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go @@ -78,19 +78,6 @@ func (c *vCPU) initArchState() error { return err } - // sctlr_el1 - regGet.id = _KVM_ARM64_REGS_SCTLR_EL1 - if err := c.getOneRegister(®Get); err != nil { - return err - } - - dataGet |= (_SCTLR_M | _SCTLR_C | _SCTLR_I) - data = dataGet - reg.id = _KVM_ARM64_REGS_SCTLR_EL1 - if err := c.setOneRegister(®); err != nil { - return err - } - // tcr_el1 data = _TCR_TXSZ_VA48 | _TCR_CACHE_FLAGS | _TCR_SHARED | _TCR_TG_FLAGS | _TCR_ASID16 | _TCR_IPS_40BITS reg.id = _KVM_ARM64_REGS_TCR_EL1 diff --git a/pkg/sentry/platform/kvm/testutil/testutil_arm64.s b/pkg/sentry/platform/kvm/testutil/testutil_arm64.s index 0bebee852..07658144e 100644 --- a/pkg/sentry/platform/kvm/testutil/testutil_arm64.s +++ b/pkg/sentry/platform/kvm/testutil/testutil_arm64.s @@ -104,3 +104,9 @@ TEXT ·TwiddleRegsSyscall(SB),NOSPLIT,$0 TWIDDLE_REGS() SVC RET // never reached + +TEXT ·TwiddleRegsFault(SB),NOSPLIT,$0 + TWIDDLE_REGS() + // Branch to Register branches unconditionally to an address in <Rn>. + JMP (R4) // <=> br x4, must fault + RET // never reached diff --git a/pkg/sentry/platform/ring0/entry_arm64.s b/pkg/sentry/platform/ring0/entry_arm64.s index 2bc5f3ecd..6ed73699b 100644 --- a/pkg/sentry/platform/ring0/entry_arm64.s +++ b/pkg/sentry/platform/ring0/entry_arm64.s @@ -40,6 +40,14 @@ #define FPEN_ENABLE (FPEN_NOTRAP << FPEN_SHIFT) +// sctlr_el1: system control register el1. +#define SCTLR_M 1 << 0 +#define SCTLR_C 1 << 2 +#define SCTLR_I 1 << 12 +#define SCTLR_UCT 1 << 15 + +#define SCTLR_EL1_DEFAULT (SCTLR_M | SCTLR_C | SCTLR_I | SCTLR_UCT) + // Saves a register set. // // This is a macro because it may need to executed in contents where a stack is @@ -496,6 +504,11 @@ TEXT ·kernelExitToEl1(SB),NOSPLIT,$0 // Start is the CPU entrypoint. TEXT ·Start(SB),NOSPLIT,$0 IRQ_DISABLE + + // Init. + MOVD $SCTLR_EL1_DEFAULT, R1 + MSR R1, SCTLR_EL1 + MOVD R8, RSV_REG ORR $0xffff000000000000, RSV_REG, RSV_REG WORD $0xd518d092 //MSR R18, TPIDR_EL1 diff --git a/pkg/sentry/platform/ring0/kernel_arm64.go b/pkg/sentry/platform/ring0/kernel_arm64.go index ccacaea6b..fca3a5478 100644 --- a/pkg/sentry/platform/ring0/kernel_arm64.go +++ b/pkg/sentry/platform/ring0/kernel_arm64.go @@ -58,7 +58,13 @@ func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) { regs.Pstate &= ^uint64(UserFlagsClear) regs.Pstate |= UserFlagsSet + + SetTLS(regs.TPIDR_EL0) + kernelExitToEl0() + + regs.TPIDR_EL0 = GetTLS() + vector = c.vecCode // Perform the switch. diff --git a/pkg/sentry/socket/BUILD b/pkg/sentry/socket/BUILD index c40c6d673..c0fd3425b 100644 --- a/pkg/sentry/socket/BUILD +++ b/pkg/sentry/socket/BUILD @@ -20,5 +20,6 @@ go_library( "//pkg/syserr", "//pkg/tcpip", "//pkg/usermem", + "//tools/go_marshal/marshal", ], ) diff --git a/pkg/sentry/socket/hostinet/BUILD b/pkg/sentry/socket/hostinet/BUILD index ff81ea6e6..e76e498de 100644 --- a/pkg/sentry/socket/hostinet/BUILD +++ b/pkg/sentry/socket/hostinet/BUILD @@ -40,6 +40,8 @@ go_library( "//pkg/tcpip/stack", "//pkg/usermem", "//pkg/waiter", + "//tools/go_marshal/marshal", + "//tools/go_marshal/primitive", "@org_golang_x_sys//unix:go_default_library", ], ) diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go index fda19e7bb..532a1ea5d 100644 --- a/pkg/sentry/socket/hostinet/socket.go +++ b/pkg/sentry/socket/hostinet/socket.go @@ -36,6 +36,8 @@ import ( "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" + "gvisor.dev/gvisor/tools/go_marshal/marshal" + "gvisor.dev/gvisor/tools/go_marshal/primitive" ) const ( @@ -319,7 +321,7 @@ func (s *socketOpsCommon) Shutdown(t *kernel.Task, how int) *syserr.Error { } // GetSockOpt implements socket.Socket.GetSockOpt. -func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) { +func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) { if outLen < 0 { return nil, syserr.ErrInvalidArgument } @@ -364,7 +366,8 @@ func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr if err != nil { return nil, syserr.FromError(err) } - return opt, nil + optP := primitive.ByteSlice(opt) + return &optP, nil } // SetSockOpt implements socket.Socket.SetSockOpt. diff --git a/pkg/sentry/socket/hostinet/socket_vfs2.go b/pkg/sentry/socket/hostinet/socket_vfs2.go index 8f192c62f..8a1d52ebf 100644 --- a/pkg/sentry/socket/hostinet/socket_vfs2.go +++ b/pkg/sentry/socket/hostinet/socket_vfs2.go @@ -71,6 +71,7 @@ func newVFS2Socket(t *kernel.Task, family int, stype linux.SockType, protocol in DenyPWrite: true, UseDentryMetadata: true, }); err != nil { + fdnotifier.RemoveFD(int32(s.fd)) return nil, syserr.FromError(err) } return vfsfd, nil diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go index f7abe77d3..d9394055d 100644 --- a/pkg/sentry/socket/netfilter/netfilter.go +++ b/pkg/sentry/socket/netfilter/netfilter.go @@ -145,7 +145,7 @@ func convertNetstackToBinary(stack *stack.Stack, tablename linux.TableName) (lin // Each rule corresponds to an entry. entry := linux.KernelIPTEntry{ - IPTEntry: linux.IPTEntry{ + Entry: linux.IPTEntry{ IP: linux.IPTIP{ Protocol: uint16(rule.Filter.Protocol), }, @@ -153,20 +153,20 @@ func convertNetstackToBinary(stack *stack.Stack, tablename linux.TableName) (lin TargetOffset: linux.SizeOfIPTEntry, }, } - copy(entry.IPTEntry.IP.Dst[:], rule.Filter.Dst) - copy(entry.IPTEntry.IP.DstMask[:], rule.Filter.DstMask) - copy(entry.IPTEntry.IP.Src[:], rule.Filter.Src) - copy(entry.IPTEntry.IP.SrcMask[:], rule.Filter.SrcMask) - copy(entry.IPTEntry.IP.OutputInterface[:], rule.Filter.OutputInterface) - copy(entry.IPTEntry.IP.OutputInterfaceMask[:], rule.Filter.OutputInterfaceMask) + copy(entry.Entry.IP.Dst[:], rule.Filter.Dst) + copy(entry.Entry.IP.DstMask[:], rule.Filter.DstMask) + copy(entry.Entry.IP.Src[:], rule.Filter.Src) + copy(entry.Entry.IP.SrcMask[:], rule.Filter.SrcMask) + copy(entry.Entry.IP.OutputInterface[:], rule.Filter.OutputInterface) + copy(entry.Entry.IP.OutputInterfaceMask[:], rule.Filter.OutputInterfaceMask) if rule.Filter.DstInvert { - entry.IPTEntry.IP.InverseFlags |= linux.IPT_INV_DSTIP + entry.Entry.IP.InverseFlags |= linux.IPT_INV_DSTIP } if rule.Filter.SrcInvert { - entry.IPTEntry.IP.InverseFlags |= linux.IPT_INV_SRCIP + entry.Entry.IP.InverseFlags |= linux.IPT_INV_SRCIP } if rule.Filter.OutputInterfaceInvert { - entry.IPTEntry.IP.InverseFlags |= linux.IPT_INV_VIA_OUT + entry.Entry.IP.InverseFlags |= linux.IPT_INV_VIA_OUT } for _, matcher := range rule.Matchers { @@ -178,8 +178,8 @@ func convertNetstackToBinary(stack *stack.Stack, tablename linux.TableName) (lin panic(fmt.Sprintf("matcher %T is not 64-bit aligned", matcher)) } entry.Elems = append(entry.Elems, serialized...) - entry.NextOffset += uint16(len(serialized)) - entry.TargetOffset += uint16(len(serialized)) + entry.Entry.NextOffset += uint16(len(serialized)) + entry.Entry.TargetOffset += uint16(len(serialized)) } // Serialize and append the target. @@ -188,11 +188,11 @@ func convertNetstackToBinary(stack *stack.Stack, tablename linux.TableName) (lin panic(fmt.Sprintf("target %T is not 64-bit aligned", rule.Target)) } entry.Elems = append(entry.Elems, serialized...) - entry.NextOffset += uint16(len(serialized)) + entry.Entry.NextOffset += uint16(len(serialized)) nflog("convert to binary: adding entry: %+v", entry) - entries.Size += uint32(entry.NextOffset) + entries.Size += uint32(entry.Entry.NextOffset) entries.Entrytable = append(entries.Entrytable, entry) info.NumEntries++ } @@ -342,10 +342,10 @@ func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error { // TODO(gvisor.dev/issue/170): Support other tables. var table stack.Table switch replace.Name.String() { - case stack.TablenameFilter: + case stack.FilterTable: table = stack.EmptyFilterTable() - case stack.TablenameNat: - table = stack.EmptyNatTable() + case stack.NATTable: + table = stack.EmptyNATTable() default: nflog("we don't yet support writing to the %q table (gvisor.dev/issue/170)", replace.Name.String()) return syserr.ErrInvalidArgument @@ -431,6 +431,8 @@ func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error { for hook, _ := range replace.HookEntry { if table.ValidHooks()&(1<<hook) != 0 { hk := hookFromLinux(hook) + table.BuiltinChains[hk] = stack.HookUnset + table.Underflows[hk] = stack.HookUnset for offset, ruleIdx := range offsets { if offset == replace.HookEntry[hook] { table.BuiltinChains[hk] = ruleIdx @@ -456,8 +458,7 @@ func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error { // Add the user chains. for ruleIdx, rule := range table.Rules { - target, ok := rule.Target.(stack.UserChainTarget) - if !ok { + if _, ok := rule.Target.(stack.UserChainTarget); !ok { continue } @@ -473,7 +474,6 @@ func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error { nflog("user chain's first node must have no matchers") return syserr.ErrInvalidArgument } - table.UserChains[target.Name] = ruleIdx + 1 } // Set each jump to point to the appropriate rule. Right now they hold byte @@ -499,7 +499,10 @@ func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error { // Since we only support modifying the INPUT, PREROUTING and OUTPUT chain right now, // make sure all other chains point to ACCEPT rules. for hook, ruleIdx := range table.BuiltinChains { - if hook == stack.Forward || hook == stack.Postrouting { + if hook := stack.Hook(hook); hook == stack.Forward || hook == stack.Postrouting { + if ruleIdx == stack.HookUnset { + continue + } if !isUnconditionalAccept(table.Rules[ruleIdx]) { nflog("hook %d is unsupported.", hook) return syserr.ErrInvalidArgument @@ -512,9 +515,7 @@ func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error { // - There are no chains without an unconditional final rule. // - There are no chains without an unconditional underflow rule. - stk.IPTables().ReplaceTable(replace.Name.String(), table) - - return nil + return syserr.TranslateNetstackError(stk.IPTables().ReplaceTable(replace.Name.String(), table)) } // parseMatchers parses 0 or more matchers from optVal. optVal should contain diff --git a/pkg/sentry/socket/netlink/BUILD b/pkg/sentry/socket/netlink/BUILD index d5ca3ac56..0546801bf 100644 --- a/pkg/sentry/socket/netlink/BUILD +++ b/pkg/sentry/socket/netlink/BUILD @@ -36,6 +36,8 @@ go_library( "//pkg/tcpip", "//pkg/usermem", "//pkg/waiter", + "//tools/go_marshal/marshal", + "//tools/go_marshal/primitive", ], ) diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go index 81f34c5a2..98ca7add0 100644 --- a/pkg/sentry/socket/netlink/socket.go +++ b/pkg/sentry/socket/netlink/socket.go @@ -38,6 +38,8 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" + "gvisor.dev/gvisor/tools/go_marshal/marshal" + "gvisor.dev/gvisor/tools/go_marshal/primitive" ) const sizeOfInt32 int = 4 @@ -330,7 +332,7 @@ func (s *socketOpsCommon) Shutdown(t *kernel.Task, how int) *syserr.Error { } // GetSockOpt implements socket.Socket.GetSockOpt. -func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) { +func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) { switch level { case linux.SOL_SOCKET: switch name { @@ -340,24 +342,26 @@ func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr } s.mu.Lock() defer s.mu.Unlock() - return int32(s.sendBufferSize), nil + sendBufferSizeP := primitive.Int32(s.sendBufferSize) + return &sendBufferSizeP, nil case linux.SO_RCVBUF: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } // We don't have limit on receiving size. - return int32(math.MaxInt32), nil + recvBufferSizeP := primitive.Int32(math.MaxInt32) + return &recvBufferSizeP, nil case linux.SO_PASSCRED: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - var passcred int32 + var passcred primitive.Int32 if s.Passcred() { passcred = 1 } - return passcred, nil + return &passcred, nil default: socket.GetSockOptEmitUnimplementedEvent(t, name) diff --git a/pkg/sentry/socket/netstack/BUILD b/pkg/sentry/socket/netstack/BUILD index ea6ebd0e2..1fb777a6c 100644 --- a/pkg/sentry/socket/netstack/BUILD +++ b/pkg/sentry/socket/netstack/BUILD @@ -51,6 +51,8 @@ go_library( "//pkg/tcpip/transport/udp", "//pkg/usermem", "//pkg/waiter", + "//tools/go_marshal/marshal", + "//tools/go_marshal/primitive", "@org_golang_x_sys//unix:go_default_library", ], ) diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index e7d2c83d7..9856ab8c5 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -26,6 +26,7 @@ package netstack import ( "bytes" + "fmt" "io" "math" "reflect" @@ -61,6 +62,8 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" + "gvisor.dev/gvisor/tools/go_marshal/marshal" + "gvisor.dev/gvisor/tools/go_marshal/primitive" ) func mustCreateMetric(name, description string) *tcpip.StatCounter { @@ -192,6 +195,7 @@ var Metrics = tcpip.Stats{ PacketsSent: mustCreateMetric("/netstack/udp/packets_sent", "Number of UDP datagrams sent."), PacketSendErrors: mustCreateMetric("/netstack/udp/packet_send_errors", "Number of UDP datagrams failed to be sent."), ChecksumErrors: mustCreateMetric("/netstack/udp/checksum_errors", "Number of UDP datagrams dropped due to bad checksums."), + InvalidSourceAddress: mustCreateMetric("/netstack/udp/invalid_source", "Number of UDP datagrams dropped due to invalid source address."), }, } @@ -296,8 +300,9 @@ type socketOpsCommon struct { readView buffer.View // readCM holds control message information for the last packet read // from Endpoint. - readCM tcpip.ControlMessages - sender tcpip.FullAddress + readCM tcpip.ControlMessages + sender tcpip.FullAddress + linkPacketInfo tcpip.LinkPacketInfo // sockOptTimestamp corresponds to SO_TIMESTAMP. When true, timestamps // of returned messages can be returned via control messages. When @@ -446,8 +451,21 @@ func (s *socketOpsCommon) fetchReadView() *syserr.Error { } s.readView = nil s.sender = tcpip.FullAddress{} + s.linkPacketInfo = tcpip.LinkPacketInfo{} - v, cms, err := s.Endpoint.Read(&s.sender) + var v buffer.View + var cms tcpip.ControlMessages + var err *tcpip.Error + + switch e := s.Endpoint.(type) { + // The ordering of these interfaces matters. The most specific + // interfaces must be specified before the more generic Endpoint + // interface. + case tcpip.PacketEndpoint: + v, cms, err = e.ReadPacket(&s.sender, &s.linkPacketInfo) + case tcpip.Endpoint: + v, cms, err = e.Read(&s.sender) + } if err != nil { atomic.StoreUint32(&s.readViewHasData, 0) return syserr.TranslateNetstackError(err) @@ -894,7 +912,7 @@ func (s *socketOpsCommon) Shutdown(t *kernel.Task, how int) *syserr.Error { // GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by // tcpip.Endpoint. -func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) { +func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) { // TODO(b/78348848): Unlike other socket options, SO_TIMESTAMP is // implemented specifically for netstack.SocketOperations rather than // commonEndpoint. commonEndpoint should be extended to support socket @@ -904,25 +922,25 @@ func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr us if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - val := int32(0) + val := primitive.Int32(0) s.readMu.Lock() defer s.readMu.Unlock() if s.sockOptTimestamp { val = 1 } - return val, nil + return &val, nil } if level == linux.SOL_TCP && name == linux.TCP_INQ { if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - val := int32(0) + val := primitive.Int32(0) s.readMu.Lock() defer s.readMu.Unlock() if s.sockOptInq { val = 1 } - return val, nil + return &val, nil } if s.skType == linux.SOCK_RAW && level == linux.IPPROTO_IP { @@ -940,7 +958,7 @@ func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr us if err != nil { return nil, err } - return info, nil + return &info, nil case linux.IPT_SO_GET_ENTRIES: if outLen < linux.SizeOfIPTGetEntries { @@ -955,7 +973,7 @@ func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr us if err != nil { return nil, err } - return entries, nil + return &entries, nil } } @@ -965,7 +983,7 @@ func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr us // GetSockOpt can be used to implement the linux syscall getsockopt(2) for // sockets backed by a commonEndpoint. -func GetSockOpt(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family int, skType linux.SockType, level, name, outLen int) (interface{}, *syserr.Error) { +func GetSockOpt(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family int, skType linux.SockType, level, name, outLen int) (marshal.Marshallable, *syserr.Error) { switch level { case linux.SOL_SOCKET: return getSockOptSocket(t, s, ep, family, skType, name, outLen) @@ -998,7 +1016,7 @@ func boolToInt32(v bool) int32 { } // getSockOptSocket implements GetSockOpt when level is SOL_SOCKET. -func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family int, skType linux.SockType, name, outLen int) (interface{}, *syserr.Error) { +func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family int, skType linux.SockType, name, outLen int) (marshal.Marshallable, *syserr.Error) { // TODO(b/124056281): Stop rejecting short optLen values in getsockopt. switch name { case linux.SO_ERROR: @@ -1009,9 +1027,12 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam // Get the last error and convert it. err := ep.GetSockOpt(tcpip.ErrorOption{}) if err == nil { - return int32(0), nil + optP := primitive.Int32(0) + return &optP, nil } - return int32(syserr.TranslateNetstackError(err).ToLinux().Number()), nil + + optP := primitive.Int32(syserr.TranslateNetstackError(err).ToLinux().Number()) + return &optP, nil case linux.SO_PEERCRED: if family != linux.AF_UNIX || outLen < syscall.SizeofUcred { @@ -1019,11 +1040,12 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam } tcred := t.Credentials() - return syscall.Ucred{ - Pid: int32(t.ThreadGroup().ID()), - Uid: uint32(tcred.EffectiveKUID.In(tcred.UserNamespace).OrOverflow()), - Gid: uint32(tcred.EffectiveKGID.In(tcred.UserNamespace).OrOverflow()), - }, nil + creds := linux.ControlMessageCredentials{ + PID: int32(t.ThreadGroup().ID()), + UID: uint32(tcred.EffectiveKUID.In(tcred.UserNamespace).OrOverflow()), + GID: uint32(tcred.EffectiveKGID.In(tcred.UserNamespace).OrOverflow()), + } + return &creds, nil case linux.SO_PASSCRED: if outLen < sizeOfInt32 { @@ -1034,7 +1056,9 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(v), nil + + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil case linux.SO_SNDBUF: if outLen < sizeOfInt32 { @@ -1050,7 +1074,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam size = math.MaxInt32 } - return int32(size), nil + sizeP := primitive.Int32(size) + return &sizeP, nil case linux.SO_RCVBUF: if outLen < sizeOfInt32 { @@ -1066,7 +1091,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam size = math.MaxInt32 } - return int32(size), nil + sizeP := primitive.Int32(size) + return &sizeP, nil case linux.SO_REUSEADDR: if outLen < sizeOfInt32 { @@ -1077,7 +1103,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(v), nil + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil case linux.SO_REUSEPORT: if outLen < sizeOfInt32 { @@ -1088,7 +1115,9 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(v), nil + + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil case linux.SO_BINDTODEVICE: var v tcpip.BindToDeviceOption @@ -1096,7 +1125,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam return nil, syserr.TranslateNetstackError(err) } if v == 0 { - return []byte{}, nil + var b primitive.ByteSlice + return &b, nil } if outLen < linux.IFNAMSIZ { return nil, syserr.ErrInvalidArgument @@ -1111,7 +1141,9 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam // interface was removed. return nil, syserr.ErrUnknownDevice } - return append([]byte(nic.Name), 0), nil + + name := primitive.ByteSlice(append([]byte(nic.Name), 0)) + return &name, nil case linux.SO_BROADCAST: if outLen < sizeOfInt32 { @@ -1122,7 +1154,9 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(v), nil + + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil case linux.SO_KEEPALIVE: if outLen < sizeOfInt32 { @@ -1133,13 +1167,17 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(v), nil + + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil case linux.SO_LINGER: if outLen < linux.SizeOfLinger { return nil, syserr.ErrInvalidArgument } - return linux.Linger{}, nil + + linger := linux.Linger{} + return &linger, nil case linux.SO_SNDTIMEO: // TODO(igudger): Linux allows shorter lengths for partial results. @@ -1147,7 +1185,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam return nil, syserr.ErrInvalidArgument } - return linux.NsecToTimeval(s.SendTimeout()), nil + sendTimeout := linux.NsecToTimeval(s.SendTimeout()) + return &sendTimeout, nil case linux.SO_RCVTIMEO: // TODO(igudger): Linux allows shorter lengths for partial results. @@ -1155,7 +1194,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam return nil, syserr.ErrInvalidArgument } - return linux.NsecToTimeval(s.RecvTimeout()), nil + recvTimeout := linux.NsecToTimeval(s.RecvTimeout()) + return &recvTimeout, nil case linux.SO_OOBINLINE: if outLen < sizeOfInt32 { @@ -1167,7 +1207,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam return nil, syserr.TranslateNetstackError(err) } - return int32(v), nil + vP := primitive.Int32(v) + return &vP, nil case linux.SO_NO_CHECK: if outLen < sizeOfInt32 { @@ -1178,7 +1219,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(v), nil + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil default: socket.GetSockOptEmitUnimplementedEvent(t, name) @@ -1187,7 +1229,7 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam } // getSockOptTCP implements GetSockOpt when level is SOL_TCP. -func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interface{}, *syserr.Error) { +func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (marshal.Marshallable, *syserr.Error) { switch name { case linux.TCP_NODELAY: if outLen < sizeOfInt32 { @@ -1198,7 +1240,9 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(!v), nil + + vP := primitive.Int32(boolToInt32(!v)) + return &vP, nil case linux.TCP_CORK: if outLen < sizeOfInt32 { @@ -1209,7 +1253,9 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(v), nil + + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil case linux.TCP_QUICKACK: if outLen < sizeOfInt32 { @@ -1220,7 +1266,9 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(v), nil + + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil case linux.TCP_MAXSEG: if outLen < sizeOfInt32 { @@ -1231,8 +1279,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa if err != nil { return nil, syserr.TranslateNetstackError(err) } - - return int32(v), nil + vP := primitive.Int32(v) + return &vP, nil case linux.TCP_KEEPIDLE: if outLen < sizeOfInt32 { @@ -1243,8 +1291,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa if err := ep.GetSockOpt(&v); err != nil { return nil, syserr.TranslateNetstackError(err) } - - return int32(time.Duration(v) / time.Second), nil + keepAliveIdle := primitive.Int32(time.Duration(v) / time.Second) + return &keepAliveIdle, nil case linux.TCP_KEEPINTVL: if outLen < sizeOfInt32 { @@ -1255,8 +1303,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa if err := ep.GetSockOpt(&v); err != nil { return nil, syserr.TranslateNetstackError(err) } - - return int32(time.Duration(v) / time.Second), nil + keepAliveInterval := primitive.Int32(time.Duration(v) / time.Second) + return &keepAliveInterval, nil case linux.TCP_KEEPCNT: if outLen < sizeOfInt32 { @@ -1267,8 +1315,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa if err != nil { return nil, syserr.TranslateNetstackError(err) } - - return int32(v), nil + vP := primitive.Int32(v) + return &vP, nil case linux.TCP_USER_TIMEOUT: if outLen < sizeOfInt32 { @@ -1279,8 +1327,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa if err := ep.GetSockOpt(&v); err != nil { return nil, syserr.TranslateNetstackError(err) } - - return int32(time.Duration(v) / time.Millisecond), nil + tcpUserTimeout := primitive.Int32(time.Duration(v) / time.Millisecond) + return &tcpUserTimeout, nil case linux.TCP_INFO: var v tcpip.TCPInfoOption @@ -1293,12 +1341,13 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa info := linux.TCPInfo{} // Linux truncates the output binary to outLen. - ib := binary.Marshal(nil, usermem.ByteOrder, &info) - if len(ib) > outLen { - ib = ib[:outLen] + buf := t.CopyScratchBuffer(info.SizeBytes()) + info.MarshalUnsafe(buf) + if len(buf) > outLen { + buf = buf[:outLen] } - - return ib, nil + bufP := primitive.ByteSlice(buf) + return &bufP, nil case linux.TCP_CC_INFO, linux.TCP_NOTSENT_LOWAT, @@ -1328,7 +1377,9 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa } b := make([]byte, toCopy) copy(b, v) - return b, nil + + bP := primitive.ByteSlice(b) + return &bP, nil case linux.TCP_LINGER2: if outLen < sizeOfInt32 { @@ -1340,7 +1391,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa return nil, syserr.TranslateNetstackError(err) } - return int32(time.Duration(v) / time.Second), nil + lingerTimeout := primitive.Int32(time.Duration(v) / time.Second) + return &lingerTimeout, nil case linux.TCP_DEFER_ACCEPT: if outLen < sizeOfInt32 { @@ -1352,7 +1404,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa return nil, syserr.TranslateNetstackError(err) } - return int32(time.Duration(v) / time.Second), nil + tcpDeferAccept := primitive.Int32(time.Duration(v) / time.Second) + return &tcpDeferAccept, nil case linux.TCP_SYNCNT: if outLen < sizeOfInt32 { @@ -1363,8 +1416,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa if err != nil { return nil, syserr.TranslateNetstackError(err) } - - return int32(v), nil + vP := primitive.Int32(v) + return &vP, nil case linux.TCP_WINDOW_CLAMP: if outLen < sizeOfInt32 { @@ -1375,8 +1428,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa if err != nil { return nil, syserr.TranslateNetstackError(err) } - - return int32(v), nil + vP := primitive.Int32(v) + return &vP, nil default: emitUnimplementedEventTCP(t, name) } @@ -1384,7 +1437,7 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa } // getSockOptIPv6 implements GetSockOpt when level is SOL_IPV6. -func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interface{}, *syserr.Error) { +func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (marshal.Marshallable, *syserr.Error) { switch name { case linux.IPV6_V6ONLY: if outLen < sizeOfInt32 { @@ -1395,7 +1448,9 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(v), nil + + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil case linux.IPV6_PATHMTU: t.Kernel().EmitUnimplementedEvent(t) @@ -1403,21 +1458,24 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf case linux.IPV6_TCLASS: // Length handling for parity with Linux. if outLen == 0 { - return make([]byte, 0), nil + var b primitive.ByteSlice + return &b, nil } v, err := ep.GetSockOptInt(tcpip.IPv6TrafficClassOption) if err != nil { return nil, syserr.TranslateNetstackError(err) } - uintv := uint32(v) + uintv := primitive.Uint32(v) // Linux truncates the output binary to outLen. - ib := binary.Marshal(nil, usermem.ByteOrder, &uintv) + ib := t.CopyScratchBuffer(uintv.SizeBytes()) + uintv.MarshalUnsafe(ib) // Handle cases where outLen is lesser than sizeOfInt32. if len(ib) > outLen { ib = ib[:outLen] } - return ib, nil + ibP := primitive.ByteSlice(ib) + return &ibP, nil case linux.IPV6_RECVTCLASS: if outLen < sizeOfInt32 { @@ -1428,7 +1486,9 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(v), nil + + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil default: emitUnimplementedEventIPv6(t, name) @@ -1437,7 +1497,7 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf } // getSockOptIP implements GetSockOpt when level is SOL_IP. -func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family int) (interface{}, *syserr.Error) { +func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family int) (marshal.Marshallable, *syserr.Error) { switch name { case linux.IP_TTL: if outLen < sizeOfInt32 { @@ -1450,11 +1510,12 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in } // Fill in the default value, if needed. - if v == 0 { - v = DefaultTTL + vP := primitive.Int32(v) + if vP == 0 { + vP = DefaultTTL } - return int32(v), nil + return &vP, nil case linux.IP_MULTICAST_TTL: if outLen < sizeOfInt32 { @@ -1466,7 +1527,8 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in return nil, syserr.TranslateNetstackError(err) } - return int32(v), nil + vP := primitive.Int32(v) + return &vP, nil case linux.IP_MULTICAST_IF: if outLen < len(linux.InetAddr{}) { @@ -1480,7 +1542,7 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in a, _ := ConvertAddress(linux.AF_INET, tcpip.FullAddress{Addr: v.InterfaceAddr}) - return a.(*linux.SockAddrInet).Addr, nil + return &a.(*linux.SockAddrInet).Addr, nil case linux.IP_MULTICAST_LOOP: if outLen < sizeOfInt32 { @@ -1491,21 +1553,26 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(v), nil + + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil case linux.IP_TOS: // Length handling for parity with Linux. if outLen == 0 { - return []byte(nil), nil + var b primitive.ByteSlice + return &b, nil } v, err := ep.GetSockOptInt(tcpip.IPv4TOSOption) if err != nil { return nil, syserr.TranslateNetstackError(err) } if outLen < sizeOfInt32 { - return uint8(v), nil + vP := primitive.Uint8(v) + return &vP, nil } - return int32(v), nil + vP := primitive.Int32(v) + return &vP, nil case linux.IP_RECVTOS: if outLen < sizeOfInt32 { @@ -1516,7 +1583,9 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(v), nil + + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil case linux.IP_PKTINFO: if outLen < sizeOfInt32 { @@ -1527,7 +1596,9 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(v), nil + + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil default: emitUnimplementedEventIP(t, name) @@ -1753,6 +1824,11 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam return nil + case linux.SO_DETACH_FILTER: + // optval is ignored. + var v tcpip.SocketDetachFilterOption + return syserr.TranslateNetstackError(ep.SetSockOpt(v)) + default: socket.SetSockOptEmitUnimplementedEvent(t, name) } @@ -2112,13 +2188,22 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s } return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReceiveIPPacketInfoOption, v != 0)) + case linux.IP_HDRINCL: + if len(optVal) == 0 { + return nil + } + v, err := parseIntOrChar(optVal) + if err != nil { + return err + } + return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.IPHdrIncludedOption, v != 0)) + case linux.IP_ADD_SOURCE_MEMBERSHIP, linux.IP_BIND_ADDRESS_NO_PORT, linux.IP_BLOCK_SOURCE, linux.IP_CHECKSUM, linux.IP_DROP_SOURCE_MEMBERSHIP, linux.IP_FREEBIND, - linux.IP_HDRINCL, linux.IP_IPSEC_POLICY, linux.IP_MINTTL, linux.IP_MSFILTER, @@ -2439,6 +2524,23 @@ func (s *socketOpsCommon) fillCmsgInq(cmsg *socket.ControlMessages) { cmsg.IP.Inq = int32(len(s.readView) + rcvBufUsed) } +func toLinuxPacketType(pktType tcpip.PacketType) uint8 { + switch pktType { + case tcpip.PacketHost: + return linux.PACKET_HOST + case tcpip.PacketOtherHost: + return linux.PACKET_OTHERHOST + case tcpip.PacketOutgoing: + return linux.PACKET_OUTGOING + case tcpip.PacketBroadcast: + return linux.PACKET_BROADCAST + case tcpip.PacketMulticast: + return linux.PACKET_MULTICAST + default: + panic(fmt.Sprintf("unknown packet type: %d", pktType)) + } +} + // nonBlockingRead issues a non-blocking read. // // TODO(b/78348848): Support timestamps for stream sockets. @@ -2494,6 +2596,11 @@ func (s *socketOpsCommon) nonBlockingRead(ctx context.Context, dst usermem.IOSeq var addrLen uint32 if isPacket && senderRequested { addr, addrLen = ConvertAddress(s.family, s.sender) + switch v := addr.(type) { + case *linux.SockAddrLink: + v.Protocol = htons(uint16(s.linkPacketInfo.Protocol)) + v.PacketType = toLinuxPacketType(s.linkPacketInfo.PktType) + } } if peek { @@ -2732,7 +2839,7 @@ func (s *socketOpsCommon) ioctl(ctx context.Context, io usermem.IO, args arch.Sy // sockets. // TODO(b/78348848): Add a commonEndpoint method to support SIOCGSTAMP. switch args[1].Int() { - case syscall.SIOCGSTAMP: + case linux.SIOCGSTAMP: s.readMu.Lock() defer s.readMu.Unlock() if !s.timestampValid { @@ -2773,18 +2880,19 @@ func (s *socketOpsCommon) ioctl(ctx context.Context, io usermem.IO, args arch.Sy // Ioctl performs a socket ioctl. func Ioctl(ctx context.Context, ep commonEndpoint, io usermem.IO, args arch.SyscallArguments) (uintptr, error) { switch arg := int(args[1].Int()); arg { - case syscall.SIOCGIFFLAGS, - syscall.SIOCGIFADDR, - syscall.SIOCGIFBRDADDR, - syscall.SIOCGIFDSTADDR, - syscall.SIOCGIFHWADDR, - syscall.SIOCGIFINDEX, - syscall.SIOCGIFMAP, - syscall.SIOCGIFMETRIC, - syscall.SIOCGIFMTU, - syscall.SIOCGIFNAME, - syscall.SIOCGIFNETMASK, - syscall.SIOCGIFTXQLEN: + case linux.SIOCGIFFLAGS, + linux.SIOCGIFADDR, + linux.SIOCGIFBRDADDR, + linux.SIOCGIFDSTADDR, + linux.SIOCGIFHWADDR, + linux.SIOCGIFINDEX, + linux.SIOCGIFMAP, + linux.SIOCGIFMETRIC, + linux.SIOCGIFMTU, + linux.SIOCGIFNAME, + linux.SIOCGIFNETMASK, + linux.SIOCGIFTXQLEN, + linux.SIOCETHTOOL: var ifr linux.IFReq if _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &ifr, usermem.IOOpts{ @@ -2800,7 +2908,7 @@ func Ioctl(ctx context.Context, ep commonEndpoint, io usermem.IO, args arch.Sysc }) return 0, err - case syscall.SIOCGIFCONF: + case linux.SIOCGIFCONF: // Return a list of interface addresses or the buffer size // necessary to hold the list. var ifc linux.IFConf @@ -2874,7 +2982,7 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe // SIOCGIFNAME uses ifr.ifr_ifindex rather than ifr.ifr_name to // identify a device. - if arg == syscall.SIOCGIFNAME { + if arg == linux.SIOCGIFNAME { // Gets the name of the interface given the interface index // stored in ifr_ifindex. index = int32(usermem.ByteOrder.Uint32(ifr.Data[:4])) @@ -2897,21 +3005,28 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe } switch arg { - case syscall.SIOCGIFINDEX: + case linux.SIOCGIFINDEX: // Copy out the index to the data. usermem.ByteOrder.PutUint32(ifr.Data[:], uint32(index)) - case syscall.SIOCGIFHWADDR: + case linux.SIOCGIFHWADDR: // Copy the hardware address out. - ifr.Data[0] = 6 // IEEE802.2 arp type. - ifr.Data[1] = 0 + // + // Refer: https://linux.die.net/man/7/netdevice + // SIOCGIFHWADDR, SIOCSIFHWADDR + // + // Get or set the hardware address of a device using + // ifr_hwaddr. The hardware address is specified in a struct + // sockaddr. sa_family contains the ARPHRD_* device type, + // sa_data the L2 hardware address starting from byte 0. Setting + // the hardware address is a privileged operation. + usermem.ByteOrder.PutUint16(ifr.Data[:], iface.DeviceType) n := copy(ifr.Data[2:], iface.Addr) for i := 2 + n; i < len(ifr.Data); i++ { ifr.Data[i] = 0 // Clear padding. } - usermem.ByteOrder.PutUint16(ifr.Data[:2], uint16(n)) - case syscall.SIOCGIFFLAGS: + case linux.SIOCGIFFLAGS: f, err := interfaceStatusFlags(stack, iface.Name) if err != nil { return err @@ -2920,7 +3035,7 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe // matches Linux behavior. usermem.ByteOrder.PutUint16(ifr.Data[:2], uint16(f)) - case syscall.SIOCGIFADDR: + case linux.SIOCGIFADDR: // Copy the IPv4 address out. for _, addr := range stack.InterfaceAddrs()[index] { // This ioctl is only compatible with AF_INET addresses. @@ -2931,32 +3046,32 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe break } - case syscall.SIOCGIFMETRIC: + case linux.SIOCGIFMETRIC: // Gets the metric of the device. As per netdevice(7), this // always just sets ifr_metric to 0. usermem.ByteOrder.PutUint32(ifr.Data[:4], 0) - case syscall.SIOCGIFMTU: + case linux.SIOCGIFMTU: // Gets the MTU of the device. usermem.ByteOrder.PutUint32(ifr.Data[:4], iface.MTU) - case syscall.SIOCGIFMAP: + case linux.SIOCGIFMAP: // Gets the hardware parameters of the device. // TODO(gvisor.dev/issue/505): Implement. - case syscall.SIOCGIFTXQLEN: + case linux.SIOCGIFTXQLEN: // Gets the transmit queue length of the device. // TODO(gvisor.dev/issue/505): Implement. - case syscall.SIOCGIFDSTADDR: + case linux.SIOCGIFDSTADDR: // Gets the destination address of a point-to-point device. // TODO(gvisor.dev/issue/505): Implement. - case syscall.SIOCGIFBRDADDR: + case linux.SIOCGIFBRDADDR: // Gets the broadcast address of a device. // TODO(gvisor.dev/issue/505): Implement. - case syscall.SIOCGIFNETMASK: + case linux.SIOCGIFNETMASK: // Gets the network mask of a device. for _, addr := range stack.InterfaceAddrs()[index] { // This ioctl is only compatible with AF_INET addresses. @@ -2973,6 +3088,14 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe break } + case linux.SIOCETHTOOL: + // Stubbed out for now, Ideally we should implement the required + // sub-commands for ETHTOOL + // + // See: + // https://github.com/torvalds/linux/blob/aa0c9086b40c17a7ad94425b3b70dd1fdd7497bf/net/core/dev_ioctl.c + return syserr.ErrEndpointOperation + default: // Not a valid call. return syserr.ErrInvalidArgument diff --git a/pkg/sentry/socket/netstack/netstack_vfs2.go b/pkg/sentry/socket/netstack/netstack_vfs2.go index d65a89316..a9025b0ec 100644 --- a/pkg/sentry/socket/netstack/netstack_vfs2.go +++ b/pkg/sentry/socket/netstack/netstack_vfs2.go @@ -31,6 +31,8 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" + "gvisor.dev/gvisor/tools/go_marshal/marshal" + "gvisor.dev/gvisor/tools/go_marshal/primitive" ) // SocketVFS2 encapsulates all the state needed to represent a network stack @@ -200,7 +202,7 @@ func (s *SocketVFS2) Ioctl(ctx context.Context, uio usermem.IO, args arch.Syscal // GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by // tcpip.Endpoint. -func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) { +func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) { // TODO(b/78348848): Unlike other socket options, SO_TIMESTAMP is // implemented specifically for netstack.SocketVFS2 rather than // commonEndpoint. commonEndpoint should be extended to support socket @@ -210,25 +212,25 @@ func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem. if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - val := int32(0) + val := primitive.Int32(0) s.readMu.Lock() defer s.readMu.Unlock() if s.sockOptTimestamp { val = 1 } - return val, nil + return &val, nil } if level == linux.SOL_TCP && name == linux.TCP_INQ { if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - val := int32(0) + val := primitive.Int32(0) s.readMu.Lock() defer s.readMu.Unlock() if s.sockOptInq { val = 1 } - return val, nil + return &val, nil } if s.skType == linux.SOCK_RAW && level == linux.IPPROTO_IP { @@ -246,7 +248,7 @@ func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem. if err != nil { return nil, err } - return info, nil + return &info, nil case linux.IPT_SO_GET_ENTRIES: if outLen < linux.SizeOfIPTGetEntries { @@ -261,7 +263,7 @@ func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem. if err != nil { return nil, err } - return entries, nil + return &entries, nil } } diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go index 548442b96..67737ae87 100644 --- a/pkg/sentry/socket/netstack/stack.go +++ b/pkg/sentry/socket/netstack/stack.go @@ -15,6 +15,8 @@ package netstack import ( + "fmt" + "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/inet" @@ -40,19 +42,29 @@ func (s *Stack) SupportsIPv6() bool { return s.Stack.CheckNetworkProtocol(ipv6.ProtocolNumber) } +// Converts Netstack's ARPHardwareType to equivalent linux constants. +func toLinuxARPHardwareType(t header.ARPHardwareType) uint16 { + switch t { + case header.ARPHardwareNone: + return linux.ARPHRD_NONE + case header.ARPHardwareLoopback: + return linux.ARPHRD_LOOPBACK + case header.ARPHardwareEther: + return linux.ARPHRD_ETHER + default: + panic(fmt.Sprintf("unknown ARPHRD type: %d", t)) + } +} + // Interfaces implements inet.Stack.Interfaces. func (s *Stack) Interfaces() map[int32]inet.Interface { is := make(map[int32]inet.Interface) for id, ni := range s.Stack.NICInfo() { - var devType uint16 - if ni.Flags.Loopback { - devType = linux.ARPHRD_LOOPBACK - } is[int32(id)] = inet.Interface{ Name: ni.Name, Addr: []byte(ni.LinkAddress), Flags: uint32(nicStateFlagsToLinux(ni.Flags)), - DeviceType: devType, + DeviceType: toLinuxARPHardwareType(ni.ARPHardwareType), MTU: ni.MTU, } } diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go index fcd7f9d7f..d112757fb 100644 --- a/pkg/sentry/socket/socket.go +++ b/pkg/sentry/socket/socket.go @@ -35,6 +35,7 @@ import ( "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/tools/go_marshal/marshal" ) // ControlMessages represents the union of unix control messages and tcpip @@ -86,7 +87,7 @@ type SocketOps interface { Shutdown(t *kernel.Task, how int) *syserr.Error // GetSockOpt implements the getsockopt(2) linux syscall. - GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) + GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) // SetSockOpt implements the setsockopt(2) linux syscall. SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *syserr.Error diff --git a/pkg/sentry/socket/unix/BUILD b/pkg/sentry/socket/unix/BUILD index cca5e70f1..061a689a9 100644 --- a/pkg/sentry/socket/unix/BUILD +++ b/pkg/sentry/socket/unix/BUILD @@ -35,5 +35,6 @@ go_library( "//pkg/tcpip", "//pkg/usermem", "//pkg/waiter", + "//tools/go_marshal/marshal", ], ) diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go index 4bb2b6ff4..0482d33cf 100644 --- a/pkg/sentry/socket/unix/unix.go +++ b/pkg/sentry/socket/unix/unix.go @@ -40,6 +40,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" + "gvisor.dev/gvisor/tools/go_marshal/marshal" ) // SocketOperations is a Unix socket. It is similar to a netstack socket, @@ -184,7 +185,7 @@ func (s *SocketOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO, // GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by // a transport.Endpoint. -func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) { +func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) { return netstack.GetSockOpt(t, s, s.ep, linux.AF_UNIX, s.ep.Type(), level, name, outLen) } diff --git a/pkg/sentry/socket/unix/unix_vfs2.go b/pkg/sentry/socket/unix/unix_vfs2.go index ff2149250..05c16fcfe 100644 --- a/pkg/sentry/socket/unix/unix_vfs2.go +++ b/pkg/sentry/socket/unix/unix_vfs2.go @@ -32,6 +32,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" + "gvisor.dev/gvisor/tools/go_marshal/marshal" ) // SocketVFS2 implements socket.SocketVFS2 (and by extension, @@ -89,7 +90,7 @@ func NewFileDescription(ep transport.Endpoint, stype linux.SockType, flags uint3 // GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by // a transport.Endpoint. -func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) { +func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) { return netstack.GetSockOpt(t, s, s.ep, linux.AF_UNIX, s.ep.Type(), level, name, outLen) } diff --git a/pkg/sentry/syscalls/linux/BUILD b/pkg/sentry/syscalls/linux/BUILD index 217fcfef2..4a9b04fd0 100644 --- a/pkg/sentry/syscalls/linux/BUILD +++ b/pkg/sentry/syscalls/linux/BUILD @@ -99,5 +99,7 @@ go_library( "//pkg/syserror", "//pkg/usermem", "//pkg/waiter", + "//tools/go_marshal/marshal", + "//tools/go_marshal/primitive", ], ) diff --git a/pkg/sentry/syscalls/linux/linux64.go b/pkg/sentry/syscalls/linux/linux64.go index ea4f9b1a7..80c65164a 100644 --- a/pkg/sentry/syscalls/linux/linux64.go +++ b/pkg/sentry/syscalls/linux/linux64.go @@ -325,8 +325,8 @@ var AMD64 = &kernel.SyscallTable{ 270: syscalls.Supported("pselect", Pselect), 271: syscalls.Supported("ppoll", Ppoll), 272: syscalls.PartiallySupported("unshare", Unshare, "Mount, cgroup namespaces not supported. Network namespaces supported but must be empty.", nil), - 273: syscalls.Error("set_robust_list", syserror.ENOSYS, "Obsolete.", nil), - 274: syscalls.Error("get_robust_list", syserror.ENOSYS, "Obsolete.", nil), + 273: syscalls.Supported("set_robust_list", SetRobustList), + 274: syscalls.Supported("get_robust_list", GetRobustList), 275: syscalls.Supported("splice", Splice), 276: syscalls.Supported("tee", Tee), 277: syscalls.PartiallySupported("sync_file_range", SyncFileRange, "Full data flush is not guaranteed at this time.", nil), diff --git a/pkg/sentry/syscalls/linux/sys_futex.go b/pkg/sentry/syscalls/linux/sys_futex.go index b68261f72..f04d78856 100644 --- a/pkg/sentry/syscalls/linux/sys_futex.go +++ b/pkg/sentry/syscalls/linux/sys_futex.go @@ -198,7 +198,7 @@ func Futex(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall switch cmd { case linux.FUTEX_WAIT: // WAIT uses a relative timeout. - mask = ^uint32(0) + mask = linux.FUTEX_BITSET_MATCH_ANY var timeoutDur time.Duration if !forever { timeoutDur = time.Duration(timespec.ToNsecCapped()) * time.Nanosecond @@ -286,3 +286,49 @@ func Futex(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall return 0, nil, syserror.ENOSYS } } + +// SetRobustList implements linux syscall set_robust_list(2). +func SetRobustList(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + // Despite the syscall using the name 'pid' for this variable, it is + // very much a tid. + head := args[0].Pointer() + length := args[1].SizeT() + + if length != uint(linux.SizeOfRobustListHead) { + return 0, nil, syserror.EINVAL + } + t.SetRobustList(head) + return 0, nil, nil +} + +// GetRobustList implements linux syscall get_robust_list(2). +func GetRobustList(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + // Despite the syscall using the name 'pid' for this variable, it is + // very much a tid. + tid := args[0].Int() + head := args[1].Pointer() + size := args[2].Pointer() + + if tid < 0 { + return 0, nil, syserror.EINVAL + } + + ot := t + if tid != 0 { + if ot = t.PIDNamespace().TaskWithID(kernel.ThreadID(tid)); ot == nil { + return 0, nil, syserror.ESRCH + } + } + + // Copy out head pointer. + if _, err := t.CopyOut(head, uint64(ot.GetRobustList())); err != nil { + return 0, nil, err + } + + // Copy out size, which is a constant. + if _, err := t.CopyOut(size, uint64(linux.SizeOfRobustListHead)); err != nil { + return 0, nil, err + } + + return 0, nil, nil +} diff --git a/pkg/sentry/syscalls/linux/sys_socket.go b/pkg/sentry/syscalls/linux/sys_socket.go index 0760af77b..414fce8e3 100644 --- a/pkg/sentry/syscalls/linux/sys_socket.go +++ b/pkg/sentry/syscalls/linux/sys_socket.go @@ -29,6 +29,8 @@ import ( "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/tools/go_marshal/marshal" + "gvisor.dev/gvisor/tools/go_marshal/primitive" ) // LINT.IfChange @@ -474,7 +476,7 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy } if v != nil { - if _, err := t.CopyOut(optValAddr, v); err != nil { + if _, err := v.CopyOut(t, optValAddr); err != nil { return 0, nil, err } } @@ -484,7 +486,7 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy // getSockOpt tries to handle common socket options, or dispatches to a specific // socket implementation. -func getSockOpt(t *kernel.Task, s socket.Socket, level, name int, optValAddr usermem.Addr, len int) (interface{}, *syserr.Error) { +func getSockOpt(t *kernel.Task, s socket.Socket, level, name int, optValAddr usermem.Addr, len int) (marshal.Marshallable, *syserr.Error) { if level == linux.SOL_SOCKET { switch name { case linux.SO_TYPE, linux.SO_DOMAIN, linux.SO_PROTOCOL: @@ -496,13 +498,16 @@ func getSockOpt(t *kernel.Task, s socket.Socket, level, name int, optValAddr use switch name { case linux.SO_TYPE: _, skType, _ := s.Type() - return int32(skType), nil + v := primitive.Int32(skType) + return &v, nil case linux.SO_DOMAIN: family, _, _ := s.Type() - return int32(family), nil + v := primitive.Int32(family) + return &v, nil case linux.SO_PROTOCOL: _, _, protocol := s.Type() - return int32(protocol), nil + v := primitive.Int32(protocol) + return &v, nil } } @@ -539,7 +544,7 @@ func SetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy return 0, nil, syserror.EINVAL } buf := t.CopyScratchBuffer(int(optLen)) - if _, err := t.CopyIn(optValAddr, &buf); err != nil { + if _, err := t.CopyInBytes(optValAddr, buf); err != nil { return 0, nil, err } diff --git a/pkg/sentry/syscalls/linux/vfs2/BUILD b/pkg/sentry/syscalls/linux/vfs2/BUILD index 0c740335b..64696b438 100644 --- a/pkg/sentry/syscalls/linux/vfs2/BUILD +++ b/pkg/sentry/syscalls/linux/vfs2/BUILD @@ -72,5 +72,7 @@ go_library( "//pkg/syserror", "//pkg/usermem", "//pkg/waiter", + "//tools/go_marshal/marshal", + "//tools/go_marshal/primitive", ], ) diff --git a/pkg/sentry/syscalls/linux/vfs2/filesystem.go b/pkg/sentry/syscalls/linux/vfs2/filesystem.go index b12b5967b..6b14c2bef 100644 --- a/pkg/sentry/syscalls/linux/vfs2/filesystem.go +++ b/pkg/sentry/syscalls/linux/vfs2/filesystem.go @@ -107,7 +107,7 @@ func Mknod(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall addr := args[0].Pointer() mode := args[1].ModeT() dev := args[2].Uint() - return 0, nil, mknodat(t, linux.AT_FDCWD, addr, mode, dev) + return 0, nil, mknodat(t, linux.AT_FDCWD, addr, linux.FileMode(mode), dev) } // Mknodat implements Linux syscall mknodat(2). @@ -116,10 +116,10 @@ func Mknodat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca addr := args[1].Pointer() mode := args[2].ModeT() dev := args[3].Uint() - return 0, nil, mknodat(t, dirfd, addr, mode, dev) + return 0, nil, mknodat(t, dirfd, addr, linux.FileMode(mode), dev) } -func mknodat(t *kernel.Task, dirfd int32, addr usermem.Addr, mode uint, dev uint32) error { +func mknodat(t *kernel.Task, dirfd int32, addr usermem.Addr, mode linux.FileMode, dev uint32) error { path, err := copyInPath(t, addr) if err != nil { return err @@ -129,9 +129,14 @@ func mknodat(t *kernel.Task, dirfd int32, addr usermem.Addr, mode uint, dev uint return err } defer tpop.Release() + + // "Zero file type is equivalent to type S_IFREG." - mknod(2) + if mode.FileType() == 0 { + mode |= linux.ModeRegular + } major, minor := linux.DecodeDeviceID(dev) return t.Kernel().VFS().MknodAt(t, t.Credentials(), &tpop.pop, &vfs.MknodOptions{ - Mode: linux.FileMode(mode &^ t.FSContext().Umask()), + Mode: mode &^ linux.FileMode(t.FSContext().Umask()), DevMajor: uint32(major), DevMinor: minor, }) diff --git a/pkg/sentry/syscalls/linux/vfs2/mount.go b/pkg/sentry/syscalls/linux/vfs2/mount.go index adeaa39cc..ea337de7c 100644 --- a/pkg/sentry/syscalls/linux/vfs2/mount.go +++ b/pkg/sentry/syscalls/linux/vfs2/mount.go @@ -77,8 +77,7 @@ func Mount(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall // Silently allow MS_NOSUID, since we don't implement set-id bits // anyway. - const unsupportedFlags = linux.MS_NODEV | - linux.MS_NODIRATIME | linux.MS_STRICTATIME + const unsupportedFlags = linux.MS_NODIRATIME | linux.MS_STRICTATIME // Linux just allows passing any flags to mount(2) - it won't fail when // unknown or unsupported flags are passed. Since we don't implement @@ -94,6 +93,12 @@ func Mount(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall if flags&linux.MS_NOEXEC == linux.MS_NOEXEC { opts.Flags.NoExec = true } + if flags&linux.MS_NODEV == linux.MS_NODEV { + opts.Flags.NoDev = true + } + if flags&linux.MS_NOSUID == linux.MS_NOSUID { + opts.Flags.NoSUID = true + } if flags&linux.MS_RDONLY == linux.MS_RDONLY { opts.ReadOnly = true } diff --git a/pkg/sentry/syscalls/linux/vfs2/setstat.go b/pkg/sentry/syscalls/linux/vfs2/setstat.go index 09ecfed26..6daedd173 100644 --- a/pkg/sentry/syscalls/linux/vfs2/setstat.go +++ b/pkg/sentry/syscalls/linux/vfs2/setstat.go @@ -178,6 +178,7 @@ func Truncate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc Mask: linux.STATX_SIZE, Size: uint64(length), }, + NeedWritePerm: true, }) return 0, nil, handleSetSizeError(t, err) } @@ -197,6 +198,10 @@ func Ftruncate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys } defer file.DecRef() + if !file.IsWritable() { + return 0, nil, syserror.EINVAL + } + err := file.SetStat(t, vfs.SetStatOptions{ Stat: linux.Statx{ Mask: linux.STATX_SIZE, diff --git a/pkg/sentry/syscalls/linux/vfs2/socket.go b/pkg/sentry/syscalls/linux/vfs2/socket.go index 10b668477..8096a8f9c 100644 --- a/pkg/sentry/syscalls/linux/vfs2/socket.go +++ b/pkg/sentry/syscalls/linux/vfs2/socket.go @@ -30,6 +30,8 @@ import ( "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/tools/go_marshal/marshal" + "gvisor.dev/gvisor/tools/go_marshal/primitive" ) // minListenBacklog is the minimum reasonable backlog for listening sockets. @@ -477,7 +479,7 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy } if v != nil { - if _, err := t.CopyOut(optValAddr, v); err != nil { + if _, err := v.CopyOut(t, optValAddr); err != nil { return 0, nil, err } } @@ -487,7 +489,7 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy // getSockOpt tries to handle common socket options, or dispatches to a specific // socket implementation. -func getSockOpt(t *kernel.Task, s socket.SocketVFS2, level, name int, optValAddr usermem.Addr, len int) (interface{}, *syserr.Error) { +func getSockOpt(t *kernel.Task, s socket.SocketVFS2, level, name int, optValAddr usermem.Addr, len int) (marshal.Marshallable, *syserr.Error) { if level == linux.SOL_SOCKET { switch name { case linux.SO_TYPE, linux.SO_DOMAIN, linux.SO_PROTOCOL: @@ -499,13 +501,16 @@ func getSockOpt(t *kernel.Task, s socket.SocketVFS2, level, name int, optValAddr switch name { case linux.SO_TYPE: _, skType, _ := s.Type() - return int32(skType), nil + v := primitive.Int32(skType) + return &v, nil case linux.SO_DOMAIN: family, _, _ := s.Type() - return int32(family), nil + v := primitive.Int32(family) + return &v, nil case linux.SO_PROTOCOL: _, _, protocol := s.Type() - return int32(protocol), nil + v := primitive.Int32(protocol) + return &v, nil } } @@ -542,7 +547,7 @@ func SetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy return 0, nil, syserror.EINVAL } buf := t.CopyScratchBuffer(int(optLen)) - if _, err := t.CopyIn(optValAddr, &buf); err != nil { + if _, err := t.CopyInBytes(optValAddr, buf); err != nil { return 0, nil, err } diff --git a/pkg/sentry/syscalls/linux/vfs2/splice.go b/pkg/sentry/syscalls/linux/vfs2/splice.go index 945a364a7..63ab11f8c 100644 --- a/pkg/sentry/syscalls/linux/vfs2/splice.go +++ b/pkg/sentry/syscalls/linux/vfs2/splice.go @@ -15,12 +15,15 @@ package vfs2 import ( + "io" + "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/pipe" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) @@ -110,16 +113,20 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal // Move data. var ( - n int64 - err error - inCh chan struct{} - outCh chan struct{} + n int64 + err error ) + dw := dualWaiter{ + inFile: inFile, + outFile: outFile, + } + defer dw.destroy() for { // If both input and output are pipes, delegate to the pipe - // implementation. Otherwise, exactly one end is a pipe, which we - // ensure is consistently ordered after the non-pipe FD's locks by - // passing the pipe FD as usermem.IO to the non-pipe end. + // implementation. Otherwise, exactly one end is a pipe, which + // we ensure is consistently ordered after the non-pipe FD's + // locks by passing the pipe FD as usermem.IO to the non-pipe + // end. switch { case inIsPipe && outIsPipe: n, err = pipe.Splice(t, outPipeFD, inPipeFD, count) @@ -137,38 +144,15 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal } else { n, err = inFile.Read(t, outPipeFD.IOSequence(count), vfs.ReadOptions{}) } + default: + panic("not possible") } + if n != 0 || err != syserror.ErrWouldBlock || nonBlock { break } - - // Note that the blocking behavior here is a bit different than the - // normal pattern. Because we need to have both data to read and data - // to write simultaneously, we actually explicitly block on both of - // these cases in turn before returning to the splice operation. - if inFile.Readiness(eventMaskRead)&eventMaskRead == 0 { - if inCh == nil { - inCh = make(chan struct{}, 1) - inW, _ := waiter.NewChannelEntry(inCh) - inFile.EventRegister(&inW, eventMaskRead) - defer inFile.EventUnregister(&inW) - continue // Need to refresh readiness. - } - if err = t.Block(inCh); err != nil { - break - } - } - if outFile.Readiness(eventMaskWrite)&eventMaskWrite == 0 { - if outCh == nil { - outCh = make(chan struct{}, 1) - outW, _ := waiter.NewChannelEntry(outCh) - outFile.EventRegister(&outW, eventMaskWrite) - defer outFile.EventUnregister(&outW) - continue // Need to refresh readiness. - } - if err = t.Block(outCh); err != nil { - break - } + if err = dw.waitForBoth(t); err != nil { + break } } @@ -247,45 +231,256 @@ func Tee(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallCo // Copy data. var ( - inCh chan struct{} - outCh chan struct{} + n int64 + err error ) + dw := dualWaiter{ + inFile: inFile, + outFile: outFile, + } + defer dw.destroy() for { - n, err := pipe.Tee(t, outPipeFD, inPipeFD, count) - if n != 0 { - return uintptr(n), nil, nil + n, err = pipe.Tee(t, outPipeFD, inPipeFD, count) + if n != 0 || err != syserror.ErrWouldBlock || nonBlock { + break + } + if err = dw.waitForBoth(t); err != nil { + break + } + } + if n == 0 { + return 0, nil, err + } + outFile.Dentry().InotifyWithParent(linux.IN_MODIFY, 0, vfs.PathEvent) + return uintptr(n), nil, nil +} + +// Sendfile implements linux system call sendfile(2). +func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + outFD := args[0].Int() + inFD := args[1].Int() + offsetAddr := args[2].Pointer() + count := int64(args[3].SizeT()) + + inFile := t.GetFileVFS2(inFD) + if inFile == nil { + return 0, nil, syserror.EBADF + } + defer inFile.DecRef() + if !inFile.IsReadable() { + return 0, nil, syserror.EBADF + } + + outFile := t.GetFileVFS2(outFD) + if outFile == nil { + return 0, nil, syserror.EBADF + } + defer outFile.DecRef() + if !outFile.IsWritable() { + return 0, nil, syserror.EBADF + } + + // Verify that the outFile Append flag is not set. + if outFile.StatusFlags()&linux.O_APPEND != 0 { + return 0, nil, syserror.EINVAL + } + + // Verify that inFile is a regular file or block device. This is a + // requirement; the same check appears in Linux + // (fs/splice.c:splice_direct_to_actor). + if stat, err := inFile.Stat(t, vfs.StatOptions{Mask: linux.STATX_TYPE}); err != nil { + return 0, nil, err + } else if stat.Mask&linux.STATX_TYPE == 0 || + (stat.Mode&linux.S_IFMT != linux.S_IFREG && stat.Mode&linux.S_IFMT != linux.S_IFBLK) { + return 0, nil, syserror.EINVAL + } + + // Copy offset if it exists. + offset := int64(-1) + if offsetAddr != 0 { + if inFile.Options().DenyPRead { + return 0, nil, syserror.ESPIPE } - if err != syserror.ErrWouldBlock || nonBlock { + if _, err := t.CopyIn(offsetAddr, &offset); err != nil { return 0, nil, err } + if offset < 0 { + return 0, nil, syserror.EINVAL + } + if offset+count < 0 { + return 0, nil, syserror.EINVAL + } + } + + // Validate count. This must come after offset checks. + if count < 0 { + return 0, nil, syserror.EINVAL + } + if count == 0 { + return 0, nil, nil + } + if count > int64(kernel.MAX_RW_COUNT) { + count = int64(kernel.MAX_RW_COUNT) + } - // Note that the blocking behavior here is a bit different than the - // normal pattern. Because we need to have both data to read and data - // to write simultaneously, we actually explicitly block on both of - // these cases in turn before returning to the tee operation. - if inFile.Readiness(eventMaskRead)&eventMaskRead == 0 { - if inCh == nil { - inCh = make(chan struct{}, 1) - inW, _ := waiter.NewChannelEntry(inCh) - inFile.EventRegister(&inW, eventMaskRead) - defer inFile.EventUnregister(&inW) - continue // Need to refresh readiness. + // Copy data. + var ( + n int64 + err error + ) + dw := dualWaiter{ + inFile: inFile, + outFile: outFile, + } + defer dw.destroy() + outPipeFD, outIsPipe := outFile.Impl().(*pipe.VFSPipeFD) + // Reading from input file should never block, since it is regular or + // block device. We only need to check if writing to the output file + // can block. + nonBlock := outFile.StatusFlags()&linux.O_NONBLOCK != 0 + if outIsPipe { + for n < count { + var spliceN int64 + if offset != -1 { + spliceN, err = inFile.PRead(t, outPipeFD.IOSequence(count), offset, vfs.ReadOptions{}) + offset += spliceN + } else { + spliceN, err = inFile.Read(t, outPipeFD.IOSequence(count), vfs.ReadOptions{}) } - if err := t.Block(inCh); err != nil { - return 0, nil, err + n += spliceN + if err == syserror.ErrWouldBlock && !nonBlock { + err = dw.waitForBoth(t) + } + if err != nil { + break } } - if outFile.Readiness(eventMaskWrite)&eventMaskWrite == 0 { - if outCh == nil { - outCh = make(chan struct{}, 1) - outW, _ := waiter.NewChannelEntry(outCh) - outFile.EventRegister(&outW, eventMaskWrite) - defer outFile.EventUnregister(&outW) - continue // Need to refresh readiness. + } else { + // Read inFile to buffer, then write the contents to outFile. + buf := make([]byte, count) + for n < count { + var readN int64 + if offset != -1 { + readN, err = inFile.PRead(t, usermem.BytesIOSequence(buf), offset, vfs.ReadOptions{}) + offset += readN + } else { + readN, err = inFile.Read(t, usermem.BytesIOSequence(buf), vfs.ReadOptions{}) + } + if readN == 0 && err == io.EOF { + // We reached the end of the file. Eat the + // error and exit the loop. + err = nil + break } - if err := t.Block(outCh); err != nil { - return 0, nil, err + n += readN + if err != nil { + break + } + + // Write all of the bytes that we read. This may need + // multiple write calls to complete. + wbuf := buf[:n] + for len(wbuf) > 0 { + var writeN int64 + writeN, err = outFile.Write(t, usermem.BytesIOSequence(wbuf), vfs.WriteOptions{}) + wbuf = wbuf[writeN:] + if err == syserror.ErrWouldBlock && !nonBlock { + err = dw.waitForOut(t) + } + if err != nil { + // We didn't complete the write. Only + // report the bytes that were actually + // written, and rewind the offset. + notWritten := int64(len(wbuf)) + n -= notWritten + if offset != -1 { + offset -= notWritten + } + break + } + } + if err == syserror.ErrWouldBlock && !nonBlock { + err = dw.waitForBoth(t) } + if err != nil { + break + } + } + } + + if offsetAddr != 0 { + // Copy out the new offset. + if _, err := t.CopyOut(offsetAddr, offset); err != nil { + return 0, nil, err + } + } + + if n == 0 { + return 0, nil, err + } + + inFile.Dentry().InotifyWithParent(linux.IN_ACCESS, 0, vfs.PathEvent) + outFile.Dentry().InotifyWithParent(linux.IN_MODIFY, 0, vfs.PathEvent) + return uintptr(n), nil, nil +} + +// dualWaiter is used to wait on one or both vfs.FileDescriptions. It is not +// thread-safe, and does not take a reference on the vfs.FileDescriptions. +// +// Users must call destroy() when finished. +type dualWaiter struct { + inFile *vfs.FileDescription + outFile *vfs.FileDescription + + inW waiter.Entry + inCh chan struct{} + outW waiter.Entry + outCh chan struct{} +} + +// waitForBoth waits for both dw.inFile and dw.outFile to be ready. +func (dw *dualWaiter) waitForBoth(t *kernel.Task) error { + if dw.inFile.Readiness(eventMaskRead)&eventMaskRead == 0 { + if dw.inCh == nil { + dw.inW, dw.inCh = waiter.NewChannelEntry(nil) + dw.inFile.EventRegister(&dw.inW, eventMaskRead) + // We might be ready now. Try again before blocking. + return nil + } + if err := t.Block(dw.inCh); err != nil { + return err + } + } + return dw.waitForOut(t) +} + +// waitForOut waits for dw.outfile to be read. +func (dw *dualWaiter) waitForOut(t *kernel.Task) error { + if dw.outFile.Readiness(eventMaskWrite)&eventMaskWrite == 0 { + if dw.outCh == nil { + dw.outW, dw.outCh = waiter.NewChannelEntry(nil) + dw.outFile.EventRegister(&dw.outW, eventMaskWrite) + // We might be ready now. Try again before blocking. + return nil } + if err := t.Block(dw.outCh); err != nil { + return err + } + } + return nil +} + +// destroy cleans up resources help by dw. No more calls to wait* can occur +// after destroy is called. +func (dw *dualWaiter) destroy() { + if dw.inCh != nil { + dw.inFile.EventUnregister(&dw.inW) + dw.inCh = nil + } + if dw.outCh != nil { + dw.outFile.EventUnregister(&dw.outW) + dw.outCh = nil } + dw.inFile = nil + dw.outFile = nil } diff --git a/pkg/sentry/syscalls/linux/vfs2/vfs2.go b/pkg/sentry/syscalls/linux/vfs2/vfs2.go index 8f497ecc7..1b2cfad7d 100644 --- a/pkg/sentry/syscalls/linux/vfs2/vfs2.go +++ b/pkg/sentry/syscalls/linux/vfs2/vfs2.go @@ -44,7 +44,7 @@ func Override() { s.Table[23] = syscalls.Supported("select", Select) s.Table[32] = syscalls.Supported("dup", Dup) s.Table[33] = syscalls.Supported("dup2", Dup2) - delete(s.Table, 40) // sendfile + s.Table[40] = syscalls.Supported("sendfile", Sendfile) s.Table[41] = syscalls.Supported("socket", Socket) s.Table[42] = syscalls.Supported("connect", Connect) s.Table[43] = syscalls.Supported("accept", Accept) diff --git a/pkg/sentry/time/parameters.go b/pkg/sentry/time/parameters.go index 65868cb26..cd1b95117 100644 --- a/pkg/sentry/time/parameters.go +++ b/pkg/sentry/time/parameters.go @@ -228,11 +228,15 @@ func errorAdjust(prevParams Parameters, newParams Parameters, now TSCValue) (Par // // The log level is determined by the error severity. func logErrorAdjustment(clock ClockID, errorNS ReferenceNS, orig, adjusted Parameters) { - fn := log.Debugf - if int64(errorNS.Magnitude()) > time.Millisecond.Nanoseconds() { + magNS := int64(errorNS.Magnitude()) + if magNS <= 10*time.Microsecond.Nanoseconds() { + // Don't log small errors. + return + } + fn := log.Infof + if magNS > time.Millisecond.Nanoseconds() { + // Upgrade large errors to warning. fn = log.Warningf - } else if int64(errorNS.Magnitude()) > 10*time.Microsecond.Nanoseconds() { - fn = log.Infof } fn("Clock(%v): error: %v ns, adjusted frequency from %v Hz to %v Hz", clock, errorNS, orig.Frequency, adjusted.Frequency) diff --git a/pkg/sentry/vfs/inotify.go b/pkg/sentry/vfs/inotify.go index c2e21ac5f..167b731ac 100644 --- a/pkg/sentry/vfs/inotify.go +++ b/pkg/sentry/vfs/inotify.go @@ -179,12 +179,12 @@ func (i *Inotify) Readiness(mask waiter.EventMask) waiter.EventMask { return mask & ready } -// PRead implements FileDescriptionImpl. +// PRead implements FileDescriptionImpl.PRead. func (*Inotify) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts ReadOptions) (int64, error) { return 0, syserror.ESPIPE } -// PWrite implements FileDescriptionImpl. +// PWrite implements FileDescriptionImpl.PWrite. func (*Inotify) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts WriteOptions) (int64, error) { return 0, syserror.ESPIPE } @@ -243,7 +243,7 @@ func (i *Inotify) Read(ctx context.Context, dst usermem.IOSequence, opts ReadOpt return writeLen, nil } -// Ioctl implements fs.FileOperations.Ioctl. +// Ioctl implements FileDescriptionImpl.Ioctl. func (i *Inotify) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) { switch args[1].Int() { case linux.FIONREAD: diff --git a/pkg/sentry/vfs/options.go b/pkg/sentry/vfs/options.go index f223aeda8..dfc8573fd 100644 --- a/pkg/sentry/vfs/options.go +++ b/pkg/sentry/vfs/options.go @@ -79,6 +79,17 @@ type MountFlags struct { // NoATime is equivalent to MS_NOATIME and indicates that the // filesystem should not update access time in-place. NoATime bool + + // NoDev is equivalent to MS_NODEV and indicates that the + // filesystem should not allow access to devices (special files). + // TODO(gVisor.dev/issue/3186): respect this flag in non FUSE + // filesystems. + NoDev bool + + // NoSUID is equivalent to MS_NOSUID and indicates that the + // filesystem should not honor set-user-ID and set-group-ID bits or + // file capabilities when executing programs. + NoSUID bool } // MountOptions contains options to VirtualFilesystem.MountAt(). @@ -153,6 +164,12 @@ type SetStatOptions struct { // == UTIME_OMIT (VFS users must unset the corresponding bit in Stat.Mask // instead). Stat linux.Statx + + // NeedWritePerm indicates that write permission on the file is needed for + // this operation. This is needed for truncate(2) (note that ftruncate(2) + // does not require the same check--instead, it checks that the fd is + // writable). + NeedWritePerm bool } // BoundEndpointOptions contains options to VirtualFilesystem.BoundEndpointAt() diff --git a/pkg/sentry/vfs/permissions.go b/pkg/sentry/vfs/permissions.go index 9cb050597..33389c1df 100644 --- a/pkg/sentry/vfs/permissions.go +++ b/pkg/sentry/vfs/permissions.go @@ -183,7 +183,8 @@ func MayWriteFileWithOpenFlags(flags uint32) bool { // CheckSetStat checks that creds has permission to change the metadata of a // file with the given permissions, UID, and GID as specified by stat, subject // to the rules of Linux's fs/attr.c:setattr_prepare(). -func CheckSetStat(ctx context.Context, creds *auth.Credentials, stat *linux.Statx, mode linux.FileMode, kuid auth.KUID, kgid auth.KGID) error { +func CheckSetStat(ctx context.Context, creds *auth.Credentials, opts *SetStatOptions, mode linux.FileMode, kuid auth.KUID, kgid auth.KGID) error { + stat := &opts.Stat if stat.Mask&linux.STATX_SIZE != 0 { limit, err := CheckLimit(ctx, 0, int64(stat.Size)) if err != nil { @@ -215,6 +216,11 @@ func CheckSetStat(ctx context.Context, creds *auth.Credentials, stat *linux.Stat return syserror.EPERM } } + if opts.NeedWritePerm && !creds.HasCapability(linux.CAP_DAC_OVERRIDE) { + if err := GenericCheckPermissions(creds, MayWrite, mode, kuid, kgid); err != nil { + return err + } + } if stat.Mask&(linux.STATX_ATIME|linux.STATX_MTIME|linux.STATX_CTIME) != 0 { if !CanActAsOwner(creds, kuid) { if (stat.Mask&linux.STATX_ATIME != 0 && stat.Atime.Nsec != linux.UTIME_NOW) || diff --git a/pkg/sentry/vfs/vfs.go b/pkg/sentry/vfs/vfs.go index 58c7ad778..522e27475 100644 --- a/pkg/sentry/vfs/vfs.go +++ b/pkg/sentry/vfs/vfs.go @@ -123,6 +123,9 @@ type VirtualFilesystem struct { // Init initializes a new VirtualFilesystem with no mounts or FilesystemTypes. func (vfs *VirtualFilesystem) Init() error { + if vfs.mountpoints != nil { + panic("VFS already initialized") + } vfs.mountpoints = make(map[*Dentry]map[*Mount]struct{}) vfs.devices = make(map[devTuple]*registeredDevice) vfs.anonBlockDevMinorNext = 1 diff --git a/pkg/shim/runsc/BUILD b/pkg/shim/runsc/BUILD new file mode 100644 index 000000000..f08599ebd --- /dev/null +++ b/pkg/shim/runsc/BUILD @@ -0,0 +1,16 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "runsc", + srcs = [ + "runsc.go", + "utils.go", + ], + visibility = ["//:sandbox"], + deps = [ + "@com_github_containerd_go_runc//:go_default_library", + "@com_github_opencontainers_runtime_spec//specs-go:go_default_library", + ], +) diff --git a/pkg/shim/runsc/runsc.go b/pkg/shim/runsc/runsc.go new file mode 100644 index 000000000..c5cf68efa --- /dev/null +++ b/pkg/shim/runsc/runsc.go @@ -0,0 +1,514 @@ +// Copyright 2018 The containerd Authors. +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package runsc + +import ( + "context" + "encoding/json" + "fmt" + "io" + "io/ioutil" + "os" + "os/exec" + "path/filepath" + "strconv" + "syscall" + "time" + + runc "github.com/containerd/go-runc" + specs "github.com/opencontainers/runtime-spec/specs-go" +) + +var Monitor runc.ProcessMonitor = runc.Monitor + +// DefaultCommand is the default command for Runsc. +const DefaultCommand = "runsc" + +// Runsc is the client to the runsc cli. +type Runsc struct { + Command string + PdeathSignal syscall.Signal + Setpgid bool + Root string + Log string + LogFormat runc.Format + Config map[string]string +} + +// List returns all containers created inside the provided runsc root directory. +func (r *Runsc) List(context context.Context) ([]*runc.Container, error) { + data, err := cmdOutput(r.command(context, "list", "--format=json"), false) + if err != nil { + return nil, err + } + var out []*runc.Container + if err := json.Unmarshal(data, &out); err != nil { + return nil, err + } + return out, nil +} + +// State returns the state for the container provided by id. +func (r *Runsc) State(context context.Context, id string) (*runc.Container, error) { + data, err := cmdOutput(r.command(context, "state", id), true) + if err != nil { + return nil, fmt.Errorf("%s: %s", err, data) + } + var c runc.Container + if err := json.Unmarshal(data, &c); err != nil { + return nil, err + } + return &c, nil +} + +type CreateOpts struct { + runc.IO + ConsoleSocket runc.ConsoleSocket + + // PidFile is a path to where a pid file should be created. + PidFile string + + // UserLog is a path to where runsc user log should be generated. + UserLog string +} + +func (o *CreateOpts) args() (out []string, err error) { + if o.PidFile != "" { + abs, err := filepath.Abs(o.PidFile) + if err != nil { + return nil, err + } + out = append(out, "--pid-file", abs) + } + if o.ConsoleSocket != nil { + out = append(out, "--console-socket", o.ConsoleSocket.Path()) + } + if o.UserLog != "" { + out = append(out, "--user-log", o.UserLog) + } + return out, nil +} + +// Create creates a new container and returns its pid if it was created successfully. +func (r *Runsc) Create(context context.Context, id, bundle string, opts *CreateOpts) error { + args := []string{"create", "--bundle", bundle} + if opts != nil { + oargs, err := opts.args() + if err != nil { + return err + } + args = append(args, oargs...) + } + cmd := r.command(context, append(args, id)...) + if opts != nil && opts.IO != nil { + opts.Set(cmd) + } + + if cmd.Stdout == nil && cmd.Stderr == nil { + data, err := cmdOutput(cmd, true) + if err != nil { + return fmt.Errorf("%s: %s", err, data) + } + return nil + } + ec, err := Monitor.Start(cmd) + if err != nil { + return err + } + if opts != nil && opts.IO != nil { + if c, ok := opts.IO.(runc.StartCloser); ok { + if err := c.CloseAfterStart(); err != nil { + return err + } + } + } + status, err := Monitor.Wait(cmd, ec) + if err == nil && status != 0 { + err = fmt.Errorf("%s did not terminate sucessfully", cmd.Args[0]) + } + + return err +} + +// Start will start an already created container. +func (r *Runsc) Start(context context.Context, id string, cio runc.IO) error { + cmd := r.command(context, "start", id) + if cio != nil { + cio.Set(cmd) + } + + if cmd.Stdout == nil && cmd.Stderr == nil { + data, err := cmdOutput(cmd, true) + if err != nil { + return fmt.Errorf("%s: %s", err, data) + } + return nil + } + + ec, err := Monitor.Start(cmd) + if err != nil { + return err + } + if cio != nil { + if c, ok := cio.(runc.StartCloser); ok { + if err := c.CloseAfterStart(); err != nil { + return err + } + } + } + status, err := Monitor.Wait(cmd, ec) + if err == nil && status != 0 { + err = fmt.Errorf("%s did not terminate sucessfully", cmd.Args[0]) + } + + return err +} + +type waitResult struct { + ID string `json:"id"` + ExitStatus int `json:"exitStatus"` +} + +// Wait will wait for a running container, and return its exit status. +// +// TODO(random-liu): Add exec process support. +func (r *Runsc) Wait(context context.Context, id string) (int, error) { + data, err := cmdOutput(r.command(context, "wait", id), true) + if err != nil { + return 0, fmt.Errorf("%s: %s", err, data) + } + var res waitResult + if err := json.Unmarshal(data, &res); err != nil { + return 0, err + } + return res.ExitStatus, nil +} + +type ExecOpts struct { + runc.IO + PidFile string + InternalPidFile string + ConsoleSocket runc.ConsoleSocket + Detach bool +} + +func (o *ExecOpts) args() (out []string, err error) { + if o.ConsoleSocket != nil { + out = append(out, "--console-socket", o.ConsoleSocket.Path()) + } + if o.Detach { + out = append(out, "--detach") + } + if o.PidFile != "" { + abs, err := filepath.Abs(o.PidFile) + if err != nil { + return nil, err + } + out = append(out, "--pid-file", abs) + } + if o.InternalPidFile != "" { + abs, err := filepath.Abs(o.InternalPidFile) + if err != nil { + return nil, err + } + out = append(out, "--internal-pid-file", abs) + } + return out, nil +} + +// Exec executes an additional process inside the container based on a full OCI +// Process specification. +func (r *Runsc) Exec(context context.Context, id string, spec specs.Process, opts *ExecOpts) error { + f, err := ioutil.TempFile(os.Getenv("XDG_RUNTIME_DIR"), "runsc-process") + if err != nil { + return err + } + defer os.Remove(f.Name()) + err = json.NewEncoder(f).Encode(spec) + f.Close() + if err != nil { + return err + } + args := []string{"exec", "--process", f.Name()} + if opts != nil { + oargs, err := opts.args() + if err != nil { + return err + } + args = append(args, oargs...) + } + cmd := r.command(context, append(args, id)...) + if opts != nil && opts.IO != nil { + opts.Set(cmd) + } + if cmd.Stdout == nil && cmd.Stderr == nil { + data, err := cmdOutput(cmd, true) + if err != nil { + return fmt.Errorf("%s: %s", err, data) + } + return nil + } + ec, err := Monitor.Start(cmd) + if err != nil { + return err + } + if opts != nil && opts.IO != nil { + if c, ok := opts.IO.(runc.StartCloser); ok { + if err := c.CloseAfterStart(); err != nil { + return err + } + } + } + status, err := Monitor.Wait(cmd, ec) + if err == nil && status != 0 { + err = fmt.Errorf("%s did not terminate sucessfully", cmd.Args[0]) + } + return err +} + +// Run runs the create, start, delete lifecycle of the container and returns +// its exit status after it has exited. +func (r *Runsc) Run(context context.Context, id, bundle string, opts *CreateOpts) (int, error) { + args := []string{"run", "--bundle", bundle} + if opts != nil { + oargs, err := opts.args() + if err != nil { + return -1, err + } + args = append(args, oargs...) + } + cmd := r.command(context, append(args, id)...) + if opts != nil && opts.IO != nil { + opts.Set(cmd) + } + ec, err := Monitor.Start(cmd) + if err != nil { + return -1, err + } + return Monitor.Wait(cmd, ec) +} + +type DeleteOpts struct { + Force bool +} + +func (o *DeleteOpts) args() (out []string) { + if o.Force { + out = append(out, "--force") + } + return out +} + +// Delete deletes the container. +func (r *Runsc) Delete(context context.Context, id string, opts *DeleteOpts) error { + args := []string{"delete"} + if opts != nil { + args = append(args, opts.args()...) + } + return r.runOrError(r.command(context, append(args, id)...)) +} + +// KillOpts specifies options for killing a container and its processes. +type KillOpts struct { + All bool + Pid int +} + +func (o *KillOpts) args() (out []string) { + if o.All { + out = append(out, "--all") + } + if o.Pid != 0 { + out = append(out, "--pid", strconv.Itoa(o.Pid)) + } + return out +} + +// Kill sends the specified signal to the container. +func (r *Runsc) Kill(context context.Context, id string, sig int, opts *KillOpts) error { + args := []string{ + "kill", + } + if opts != nil { + args = append(args, opts.args()...) + } + return r.runOrError(r.command(context, append(args, id, strconv.Itoa(sig))...)) +} + +// Stats return the stats for a container like cpu, memory, and I/O. +func (r *Runsc) Stats(context context.Context, id string) (*runc.Stats, error) { + cmd := r.command(context, "events", "--stats", id) + rd, err := cmd.StdoutPipe() + if err != nil { + return nil, err + } + ec, err := Monitor.Start(cmd) + if err != nil { + return nil, err + } + defer func() { + rd.Close() + Monitor.Wait(cmd, ec) + }() + var e runc.Event + if err := json.NewDecoder(rd).Decode(&e); err != nil { + return nil, err + } + return e.Stats, nil +} + +// Events returns an event stream from runsc for a container with stats and OOM notifications. +func (r *Runsc) Events(context context.Context, id string, interval time.Duration) (chan *runc.Event, error) { + cmd := r.command(context, "events", fmt.Sprintf("--interval=%ds", int(interval.Seconds())), id) + rd, err := cmd.StdoutPipe() + if err != nil { + return nil, err + } + ec, err := Monitor.Start(cmd) + if err != nil { + rd.Close() + return nil, err + } + var ( + dec = json.NewDecoder(rd) + c = make(chan *runc.Event, 128) + ) + go func() { + defer func() { + close(c) + rd.Close() + Monitor.Wait(cmd, ec) + }() + for { + var e runc.Event + if err := dec.Decode(&e); err != nil { + if err == io.EOF { + return + } + e = runc.Event{ + Type: "error", + Err: err, + } + } + c <- &e + } + }() + return c, nil +} + +// Ps lists all the processes inside the container returning their pids. +func (r *Runsc) Ps(context context.Context, id string) ([]int, error) { + data, err := cmdOutput(r.command(context, "ps", "--format", "json", id), true) + if err != nil { + return nil, fmt.Errorf("%s: %s", err, data) + } + var pids []int + if err := json.Unmarshal(data, &pids); err != nil { + return nil, err + } + return pids, nil +} + +// Top lists all the processes inside the container returning the full ps data. +func (r *Runsc) Top(context context.Context, id string) (*runc.TopResults, error) { + data, err := cmdOutput(r.command(context, "ps", "--format", "table", id), true) + if err != nil { + return nil, fmt.Errorf("%s: %s", err, data) + } + + topResults, err := runc.ParsePSOutput(data) + if err != nil { + return nil, fmt.Errorf("%s: ", err) + } + return topResults, nil +} + +func (r *Runsc) args() []string { + var args []string + if r.Root != "" { + args = append(args, fmt.Sprintf("--root=%s", r.Root)) + } + if r.Log != "" { + args = append(args, fmt.Sprintf("--log=%s", r.Log)) + } + if r.LogFormat != "" { + args = append(args, fmt.Sprintf("--log-format=%s", r.LogFormat)) + } + for k, v := range r.Config { + args = append(args, fmt.Sprintf("--%s=%s", k, v)) + } + return args +} + +// runOrError will run the provided command. +// +// If an error is encountered and neither Stdout or Stderr was set the error +// will be returned in the format of <error>: <stderr>. +func (r *Runsc) runOrError(cmd *exec.Cmd) error { + if cmd.Stdout != nil || cmd.Stderr != nil { + ec, err := Monitor.Start(cmd) + if err != nil { + return err + } + status, err := Monitor.Wait(cmd, ec) + if err == nil && status != 0 { + err = fmt.Errorf("%s did not terminate sucessfully", cmd.Args[0]) + } + return err + } + data, err := cmdOutput(cmd, true) + if err != nil { + return fmt.Errorf("%s: %s", err, data) + } + return nil +} + +func (r *Runsc) command(context context.Context, args ...string) *exec.Cmd { + command := r.Command + if command == "" { + command = DefaultCommand + } + cmd := exec.CommandContext(context, command, append(r.args(), args...)...) + cmd.SysProcAttr = &syscall.SysProcAttr{ + Setpgid: r.Setpgid, + } + if r.PdeathSignal != 0 { + cmd.SysProcAttr.Pdeathsig = r.PdeathSignal + } + + return cmd +} + +func cmdOutput(cmd *exec.Cmd, combined bool) ([]byte, error) { + b := getBuf() + defer putBuf(b) + + cmd.Stdout = b + if combined { + cmd.Stderr = b + } + ec, err := Monitor.Start(cmd) + if err != nil { + return nil, err + } + + status, err := Monitor.Wait(cmd, ec) + if err == nil && status != 0 { + err = fmt.Errorf("%s did not terminate sucessfully", cmd.Args[0]) + } + + return b.Bytes(), err +} diff --git a/pkg/shim/runsc/utils.go b/pkg/shim/runsc/utils.go new file mode 100644 index 000000000..c514b3bc7 --- /dev/null +++ b/pkg/shim/runsc/utils.go @@ -0,0 +1,44 @@ +// Copyright 2018 The containerd Authors. +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package runsc + +import ( + "bytes" + "strings" + "sync" +) + +var bytesBufferPool = sync.Pool{ + New: func() interface{} { + return bytes.NewBuffer(nil) + }, +} + +func getBuf() *bytes.Buffer { + return bytesBufferPool.Get().(*bytes.Buffer) +} + +func putBuf(b *bytes.Buffer) { + b.Reset() + bytesBufferPool.Put(b) +} + +// FormatLogPath parses runsc config, and fill in %ID% in the log path. +func FormatLogPath(id string, config map[string]string) { + if path, ok := config["debug-log"]; ok { + config["debug-log"] = strings.Replace(path, "%ID%", id, -1) + } +} diff --git a/pkg/shim/v1/proc/BUILD b/pkg/shim/v1/proc/BUILD new file mode 100644 index 000000000..4377306af --- /dev/null +++ b/pkg/shim/v1/proc/BUILD @@ -0,0 +1,36 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "proc", + srcs = [ + "deleted_state.go", + "exec.go", + "exec_state.go", + "init.go", + "init_state.go", + "io.go", + "process.go", + "types.go", + "utils.go", + ], + visibility = [ + "//pkg/shim:__subpackages__", + "//shim:__subpackages__", + ], + deps = [ + "//pkg/shim/runsc", + "@com_github_containerd_console//:go_default_library", + "@com_github_containerd_containerd//errdefs:go_default_library", + "@com_github_containerd_containerd//log:go_default_library", + "@com_github_containerd_containerd//mount:go_default_library", + "@com_github_containerd_containerd//pkg/process:go_default_library", + "@com_github_containerd_containerd//pkg/stdio:go_default_library", + "@com_github_containerd_fifo//:go_default_library", + "@com_github_containerd_go_runc//:go_default_library", + "@com_github_gogo_protobuf//types:go_default_library", + "@com_github_opencontainers_runtime_spec//specs-go:go_default_library", + "@org_golang_x_sys//unix:go_default_library", + ], +) diff --git a/pkg/shim/v1/proc/deleted_state.go b/pkg/shim/v1/proc/deleted_state.go new file mode 100644 index 000000000..d9b970c4d --- /dev/null +++ b/pkg/shim/v1/proc/deleted_state.go @@ -0,0 +1,49 @@ +// Copyright 2018 The containerd Authors. +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package proc + +import ( + "context" + "fmt" + + "github.com/containerd/console" + "github.com/containerd/containerd/errdefs" + "github.com/containerd/containerd/pkg/process" +) + +type deletedState struct{} + +func (*deletedState) Resize(ws console.WinSize) error { + return fmt.Errorf("cannot resize a deleted process.ss") +} + +func (*deletedState) Start(ctx context.Context) error { + return fmt.Errorf("cannot start a deleted process.ss") +} + +func (*deletedState) Delete(ctx context.Context) error { + return fmt.Errorf("cannot delete a deleted process.ss: %w", errdefs.ErrNotFound) +} + +func (*deletedState) Kill(ctx context.Context, sig uint32, all bool) error { + return fmt.Errorf("cannot kill a deleted process.ss: %w", errdefs.ErrNotFound) +} + +func (*deletedState) SetExited(status int) {} + +func (*deletedState) Exec(ctx context.Context, path string, r *ExecConfig) (process.Process, error) { + return nil, fmt.Errorf("cannot exec in a deleted state") +} diff --git a/pkg/shim/v1/proc/exec.go b/pkg/shim/v1/proc/exec.go new file mode 100644 index 000000000..1d1d90488 --- /dev/null +++ b/pkg/shim/v1/proc/exec.go @@ -0,0 +1,281 @@ +// Copyright 2018 The containerd Authors. +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package proc + +import ( + "context" + "fmt" + "io" + "os" + "path/filepath" + "sync" + "syscall" + "time" + + "github.com/containerd/console" + "github.com/containerd/containerd/errdefs" + "github.com/containerd/containerd/pkg/stdio" + "github.com/containerd/fifo" + runc "github.com/containerd/go-runc" + specs "github.com/opencontainers/runtime-spec/specs-go" + "golang.org/x/sys/unix" + + "gvisor.dev/gvisor/pkg/shim/runsc" +) + +type execProcess struct { + wg sync.WaitGroup + + execState execState + + mu sync.Mutex + id string + console console.Console + io runc.IO + status int + exited time.Time + pid int + internalPid int + closers []io.Closer + stdin io.Closer + stdio stdio.Stdio + path string + spec specs.Process + + parent *Init + waitBlock chan struct{} +} + +func (e *execProcess) Wait() { + <-e.waitBlock +} + +func (e *execProcess) ID() string { + return e.id +} + +func (e *execProcess) Pid() int { + e.mu.Lock() + defer e.mu.Unlock() + return e.pid +} + +func (e *execProcess) ExitStatus() int { + e.mu.Lock() + defer e.mu.Unlock() + return e.status +} + +func (e *execProcess) ExitedAt() time.Time { + e.mu.Lock() + defer e.mu.Unlock() + return e.exited +} + +func (e *execProcess) SetExited(status int) { + e.mu.Lock() + defer e.mu.Unlock() + + e.execState.SetExited(status) +} + +func (e *execProcess) setExited(status int) { + e.status = status + e.exited = time.Now() + e.parent.Platform.ShutdownConsole(context.Background(), e.console) + close(e.waitBlock) +} + +func (e *execProcess) Delete(ctx context.Context) error { + e.mu.Lock() + defer e.mu.Unlock() + + return e.execState.Delete(ctx) +} + +func (e *execProcess) delete(ctx context.Context) error { + e.wg.Wait() + if e.io != nil { + for _, c := range e.closers { + c.Close() + } + e.io.Close() + } + pidfile := filepath.Join(e.path, fmt.Sprintf("%s.pid", e.id)) + // silently ignore error + os.Remove(pidfile) + internalPidfile := filepath.Join(e.path, fmt.Sprintf("%s-internal.pid", e.id)) + // silently ignore error + os.Remove(internalPidfile) + return nil +} + +func (e *execProcess) Resize(ws console.WinSize) error { + e.mu.Lock() + defer e.mu.Unlock() + + return e.execState.Resize(ws) +} + +func (e *execProcess) resize(ws console.WinSize) error { + if e.console == nil { + return nil + } + return e.console.Resize(ws) +} + +func (e *execProcess) Kill(ctx context.Context, sig uint32, _ bool) error { + e.mu.Lock() + defer e.mu.Unlock() + + return e.execState.Kill(ctx, sig, false) +} + +func (e *execProcess) kill(ctx context.Context, sig uint32, _ bool) error { + internalPid := e.internalPid + if internalPid != 0 { + if err := e.parent.runtime.Kill(ctx, e.parent.id, int(sig), &runsc.KillOpts{ + Pid: internalPid, + }); err != nil { + // If this returns error, consider the process has + // already stopped. + // + // TODO: Fix after signal handling is fixed. + return fmt.Errorf("%s: %w", err.Error(), errdefs.ErrNotFound) + } + } + return nil +} + +func (e *execProcess) Stdin() io.Closer { + return e.stdin +} + +func (e *execProcess) Stdio() stdio.Stdio { + return e.stdio +} + +func (e *execProcess) Start(ctx context.Context) error { + e.mu.Lock() + defer e.mu.Unlock() + + return e.execState.Start(ctx) +} + +func (e *execProcess) start(ctx context.Context) (err error) { + var ( + socket *runc.Socket + pidfile = filepath.Join(e.path, fmt.Sprintf("%s.pid", e.id)) + internalPidfile = filepath.Join(e.path, fmt.Sprintf("%s-internal.pid", e.id)) + ) + if e.stdio.Terminal { + if socket, err = runc.NewTempConsoleSocket(); err != nil { + return fmt.Errorf("failed to create runc console socket: %w", err) + } + defer socket.Close() + } else if e.stdio.IsNull() { + if e.io, err = runc.NewNullIO(); err != nil { + return fmt.Errorf("creating new NULL IO: %w", err) + } + } else { + if e.io, err = runc.NewPipeIO(e.parent.IoUID, e.parent.IoGID, withConditionalIO(e.stdio)); err != nil { + return fmt.Errorf("failed to create runc io pipes: %w", err) + } + } + opts := &runsc.ExecOpts{ + PidFile: pidfile, + InternalPidFile: internalPidfile, + IO: e.io, + Detach: true, + } + if socket != nil { + opts.ConsoleSocket = socket + } + eventCh := e.parent.Monitor.Subscribe() + defer func() { + // Unsubscribe if an error is returned. + if err != nil { + e.parent.Monitor.Unsubscribe(eventCh) + } + }() + if err := e.parent.runtime.Exec(ctx, e.parent.id, e.spec, opts); err != nil { + close(e.waitBlock) + return e.parent.runtimeError(err, "OCI runtime exec failed") + } + if e.stdio.Stdin != "" { + sc, err := fifo.OpenFifo(context.Background(), e.stdio.Stdin, syscall.O_WRONLY|syscall.O_NONBLOCK, 0) + if err != nil { + return fmt.Errorf("failed to open stdin fifo %s: %w", e.stdio.Stdin, err) + } + e.closers = append(e.closers, sc) + e.stdin = sc + } + ctx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + if socket != nil { + console, err := socket.ReceiveMaster() + if err != nil { + return fmt.Errorf("failed to retrieve console master: %w", err) + } + if e.console, err = e.parent.Platform.CopyConsole(ctx, console, e.stdio.Stdin, e.stdio.Stdout, e.stdio.Stderr, &e.wg); err != nil { + return fmt.Errorf("failed to start console copy: %w", err) + } + } else if !e.stdio.IsNull() { + if err := copyPipes(ctx, e.io, e.stdio.Stdin, e.stdio.Stdout, e.stdio.Stderr, &e.wg); err != nil { + return fmt.Errorf("failed to start io pipe copy: %w", err) + } + } + pid, err := runc.ReadPidFile(opts.PidFile) + if err != nil { + return fmt.Errorf("failed to retrieve OCI runtime exec pid: %w", err) + } + e.pid = pid + internalPid, err := runc.ReadPidFile(opts.InternalPidFile) + if err != nil { + return fmt.Errorf("failed to retrieve OCI runtime exec internal pid: %w", err) + } + e.internalPid = internalPid + go func() { + defer e.parent.Monitor.Unsubscribe(eventCh) + for event := range eventCh { + if event.Pid == e.pid { + ExitCh <- Exit{ + Timestamp: event.Timestamp, + ID: e.id, + Status: event.Status, + } + break + } + } + }() + return nil +} + +func (e *execProcess) Status(ctx context.Context) (string, error) { + e.mu.Lock() + defer e.mu.Unlock() + // if we don't have a pid then the exec process has just been created + if e.pid == 0 { + return "created", nil + } + // if we have a pid and it can be signaled, the process is running + // TODO(random-liu): Use `runsc kill --pid`. + if err := unix.Kill(e.pid, 0); err == nil { + return "running", nil + } + // else if we have a pid but it can nolonger be signaled, it has stopped + return "stopped", nil +} diff --git a/pkg/shim/v1/proc/exec_state.go b/pkg/shim/v1/proc/exec_state.go new file mode 100644 index 000000000..4dcda8b44 --- /dev/null +++ b/pkg/shim/v1/proc/exec_state.go @@ -0,0 +1,154 @@ +// Copyright 2018 The containerd Authors. +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package proc + +import ( + "context" + "fmt" + + "github.com/containerd/console" +) + +type execState interface { + Resize(console.WinSize) error + Start(context.Context) error + Delete(context.Context) error + Kill(context.Context, uint32, bool) error + SetExited(int) +} + +type execCreatedState struct { + p *execProcess +} + +func (s *execCreatedState) transition(name string) error { + switch name { + case "running": + s.p.execState = &execRunningState{p: s.p} + case "stopped": + s.p.execState = &execStoppedState{p: s.p} + case "deleted": + s.p.execState = &deletedState{} + default: + return fmt.Errorf("invalid state transition %q to %q", stateName(s), name) + } + return nil +} + +func (s *execCreatedState) Resize(ws console.WinSize) error { + return s.p.resize(ws) +} + +func (s *execCreatedState) Start(ctx context.Context) error { + if err := s.p.start(ctx); err != nil { + return err + } + return s.transition("running") +} + +func (s *execCreatedState) Delete(ctx context.Context) error { + if err := s.p.delete(ctx); err != nil { + return err + } + return s.transition("deleted") +} + +func (s *execCreatedState) Kill(ctx context.Context, sig uint32, all bool) error { + return s.p.kill(ctx, sig, all) +} + +func (s *execCreatedState) SetExited(status int) { + s.p.setExited(status) + + if err := s.transition("stopped"); err != nil { + panic(err) + } +} + +type execRunningState struct { + p *execProcess +} + +func (s *execRunningState) transition(name string) error { + switch name { + case "stopped": + s.p.execState = &execStoppedState{p: s.p} + default: + return fmt.Errorf("invalid state transition %q to %q", stateName(s), name) + } + return nil +} + +func (s *execRunningState) Resize(ws console.WinSize) error { + return s.p.resize(ws) +} + +func (s *execRunningState) Start(ctx context.Context) error { + return fmt.Errorf("cannot start a running process") +} + +func (s *execRunningState) Delete(ctx context.Context) error { + return fmt.Errorf("cannot delete a running process") +} + +func (s *execRunningState) Kill(ctx context.Context, sig uint32, all bool) error { + return s.p.kill(ctx, sig, all) +} + +func (s *execRunningState) SetExited(status int) { + s.p.setExited(status) + + if err := s.transition("stopped"); err != nil { + panic(err) + } +} + +type execStoppedState struct { + p *execProcess +} + +func (s *execStoppedState) transition(name string) error { + switch name { + case "deleted": + s.p.execState = &deletedState{} + default: + return fmt.Errorf("invalid state transition %q to %q", stateName(s), name) + } + return nil +} + +func (s *execStoppedState) Resize(ws console.WinSize) error { + return fmt.Errorf("cannot resize a stopped container") +} + +func (s *execStoppedState) Start(ctx context.Context) error { + return fmt.Errorf("cannot start a stopped process") +} + +func (s *execStoppedState) Delete(ctx context.Context) error { + if err := s.p.delete(ctx); err != nil { + return err + } + return s.transition("deleted") +} + +func (s *execStoppedState) Kill(ctx context.Context, sig uint32, all bool) error { + return s.p.kill(ctx, sig, all) +} + +func (s *execStoppedState) SetExited(status int) { + // no op +} diff --git a/pkg/shim/v1/proc/init.go b/pkg/shim/v1/proc/init.go new file mode 100644 index 000000000..dab3123d6 --- /dev/null +++ b/pkg/shim/v1/proc/init.go @@ -0,0 +1,460 @@ +// Copyright 2018 The containerd Authors. +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package proc + +import ( + "context" + "encoding/json" + "fmt" + "io" + "path/filepath" + "strings" + "sync" + "syscall" + "time" + + "github.com/containerd/console" + "github.com/containerd/containerd/errdefs" + "github.com/containerd/containerd/log" + "github.com/containerd/containerd/mount" + "github.com/containerd/containerd/pkg/process" + "github.com/containerd/containerd/pkg/stdio" + "github.com/containerd/fifo" + runc "github.com/containerd/go-runc" + specs "github.com/opencontainers/runtime-spec/specs-go" + + "gvisor.dev/gvisor/pkg/shim/runsc" +) + +// InitPidFile name of the file that contains the init pid. +const InitPidFile = "init.pid" + +// Init represents an initial process for a container. +type Init struct { + wg sync.WaitGroup + initState initState + + // mu is used to ensure that `Start()` and `Exited()` calls return in + // the right order when invoked in separate go routines. This is the + // case within the shim implementation as it makes use of the reaper + // interface. + mu sync.Mutex + + waitBlock chan struct{} + + WorkDir string + + id string + Bundle string + console console.Console + Platform stdio.Platform + io runc.IO + runtime *runsc.Runsc + status int + exited time.Time + pid int + closers []io.Closer + stdin io.Closer + stdio stdio.Stdio + Rootfs string + IoUID int + IoGID int + Sandbox bool + UserLog string + Monitor ProcessMonitor +} + +// NewRunsc returns a new runsc instance for a process. +func NewRunsc(root, path, namespace, runtime string, config map[string]string) *runsc.Runsc { + if root == "" { + root = RunscRoot + } + return &runsc.Runsc{ + Command: runtime, + PdeathSignal: syscall.SIGKILL, + Log: filepath.Join(path, "log.json"), + LogFormat: runc.JSON, + Root: filepath.Join(root, namespace), + Config: config, + } +} + +// New returns a new init process. +func New(id string, runtime *runsc.Runsc, stdio stdio.Stdio) *Init { + p := &Init{ + id: id, + runtime: runtime, + stdio: stdio, + status: 0, + waitBlock: make(chan struct{}), + } + p.initState = &createdState{p: p} + return p +} + +// Create the process with the provided config. +func (p *Init) Create(ctx context.Context, r *CreateConfig) (err error) { + var socket *runc.Socket + if r.Terminal { + if socket, err = runc.NewTempConsoleSocket(); err != nil { + return fmt.Errorf("failed to create OCI runtime console socket: %w", err) + } + defer socket.Close() + } else if hasNoIO(r) { + if p.io, err = runc.NewNullIO(); err != nil { + return fmt.Errorf("creating new NULL IO: %w", err) + } + } else { + if p.io, err = runc.NewPipeIO(p.IoUID, p.IoGID, withConditionalIO(p.stdio)); err != nil { + return fmt.Errorf("failed to create OCI runtime io pipes: %w", err) + } + } + pidFile := filepath.Join(p.Bundle, InitPidFile) + opts := &runsc.CreateOpts{ + PidFile: pidFile, + } + if socket != nil { + opts.ConsoleSocket = socket + } + if p.Sandbox { + opts.IO = p.io + // UserLog is only useful for sandbox. + opts.UserLog = p.UserLog + } + if err := p.runtime.Create(ctx, r.ID, r.Bundle, opts); err != nil { + return p.runtimeError(err, "OCI runtime create failed") + } + if r.Stdin != "" { + sc, err := fifo.OpenFifo(context.Background(), r.Stdin, syscall.O_WRONLY|syscall.O_NONBLOCK, 0) + if err != nil { + return fmt.Errorf("failed to open stdin fifo %s: %w", r.Stdin, err) + } + p.stdin = sc + p.closers = append(p.closers, sc) + } + ctx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + if socket != nil { + console, err := socket.ReceiveMaster() + if err != nil { + return fmt.Errorf("failed to retrieve console master: %w", err) + } + console, err = p.Platform.CopyConsole(ctx, console, r.Stdin, r.Stdout, r.Stderr, &p.wg) + if err != nil { + return fmt.Errorf("failed to start console copy: %w", err) + } + p.console = console + } else if !hasNoIO(r) { + if err := copyPipes(ctx, p.io, r.Stdin, r.Stdout, r.Stderr, &p.wg); err != nil { + return fmt.Errorf("failed to start io pipe copy: %w", err) + } + } + pid, err := runc.ReadPidFile(pidFile) + if err != nil { + return fmt.Errorf("failed to retrieve OCI runtime container pid: %w", err) + } + p.pid = pid + return nil +} + +// Wait waits for the process to exit. +func (p *Init) Wait() { + <-p.waitBlock +} + +// ID returns the ID of the process. +func (p *Init) ID() string { + return p.id +} + +// Pid returns the PID of the process. +func (p *Init) Pid() int { + return p.pid +} + +// ExitStatus returns the exit status of the process. +func (p *Init) ExitStatus() int { + p.mu.Lock() + defer p.mu.Unlock() + return p.status +} + +// ExitedAt returns the time when the process exited. +func (p *Init) ExitedAt() time.Time { + p.mu.Lock() + defer p.mu.Unlock() + return p.exited +} + +// Status returns the status of the process. +func (p *Init) Status(ctx context.Context) (string, error) { + p.mu.Lock() + defer p.mu.Unlock() + c, err := p.runtime.State(ctx, p.id) + if err != nil { + if strings.Contains(err.Error(), "does not exist") { + return "stopped", nil + } + return "", p.runtimeError(err, "OCI runtime state failed") + } + return p.convertStatus(c.Status), nil +} + +// Start starts the init process. +func (p *Init) Start(ctx context.Context) error { + p.mu.Lock() + defer p.mu.Unlock() + + return p.initState.Start(ctx) +} + +func (p *Init) start(ctx context.Context) error { + var cio runc.IO + if !p.Sandbox { + cio = p.io + } + if err := p.runtime.Start(ctx, p.id, cio); err != nil { + return p.runtimeError(err, "OCI runtime start failed") + } + go func() { + status, err := p.runtime.Wait(context.Background(), p.id) + if err != nil { + log.G(ctx).WithError(err).Errorf("Failed to wait for container %q", p.id) + // TODO(random-liu): Handle runsc kill error. + if err := p.killAll(ctx); err != nil { + log.G(ctx).WithError(err).Errorf("Failed to kill container %q", p.id) + } + status = internalErrorCode + } + ExitCh <- Exit{ + Timestamp: time.Now(), + ID: p.id, + Status: status, + } + }() + return nil +} + +// SetExited set the exit stauts of the init process. +func (p *Init) SetExited(status int) { + p.mu.Lock() + defer p.mu.Unlock() + + p.initState.SetExited(status) +} + +func (p *Init) setExited(status int) { + p.exited = time.Now() + p.status = status + p.Platform.ShutdownConsole(context.Background(), p.console) + close(p.waitBlock) +} + +// Delete deletes the init process. +func (p *Init) Delete(ctx context.Context) error { + p.mu.Lock() + defer p.mu.Unlock() + + return p.initState.Delete(ctx) +} + +func (p *Init) delete(ctx context.Context) error { + p.killAll(ctx) + p.wg.Wait() + err := p.runtime.Delete(ctx, p.id, nil) + // ignore errors if a runtime has already deleted the process + // but we still hold metadata and pipes + // + // this is common during a checkpoint, runc will delete the container state + // after a checkpoint and the container will no longer exist within runc + if err != nil { + if strings.Contains(err.Error(), "does not exist") { + err = nil + } else { + err = p.runtimeError(err, "failed to delete task") + } + } + if p.io != nil { + for _, c := range p.closers { + c.Close() + } + p.io.Close() + } + if err2 := mount.UnmountAll(p.Rootfs, 0); err2 != nil { + log.G(ctx).WithError(err2).Warn("failed to cleanup rootfs mount") + if err == nil { + err = fmt.Errorf("failed rootfs umount: %w", err2) + } + } + return err +} + +// Resize resizes the init processes console. +func (p *Init) Resize(ws console.WinSize) error { + p.mu.Lock() + defer p.mu.Unlock() + + if p.console == nil { + return nil + } + return p.console.Resize(ws) +} + +func (p *Init) resize(ws console.WinSize) error { + if p.console == nil { + return nil + } + return p.console.Resize(ws) +} + +// Kill kills the init process. +func (p *Init) Kill(ctx context.Context, signal uint32, all bool) error { + p.mu.Lock() + defer p.mu.Unlock() + + return p.initState.Kill(ctx, signal, all) +} + +func (p *Init) kill(context context.Context, signal uint32, all bool) error { + var ( + killErr error + backoff = 100 * time.Millisecond + ) + timeout := 1 * time.Second + for start := time.Now(); time.Now().Sub(start) < timeout; { + c, err := p.runtime.State(context, p.id) + if err != nil { + if strings.Contains(err.Error(), "does not exist") { + return fmt.Errorf("no such process: %w", errdefs.ErrNotFound) + } + return p.runtimeError(err, "OCI runtime state failed") + } + // For runsc, signal only works when container is running state. + // If the container is not in running state, directly return + // "no such process" + if p.convertStatus(c.Status) == "stopped" { + return fmt.Errorf("no such process: %w", errdefs.ErrNotFound) + } + killErr = p.runtime.Kill(context, p.id, int(signal), &runsc.KillOpts{ + All: all, + }) + if killErr == nil { + return nil + } + time.Sleep(backoff) + backoff *= 2 + } + return p.runtimeError(killErr, "kill timeout") +} + +// KillAll kills all processes belonging to the init process. +func (p *Init) KillAll(context context.Context) error { + p.mu.Lock() + defer p.mu.Unlock() + return p.killAll(context) +} + +func (p *Init) killAll(context context.Context) error { + p.runtime.Kill(context, p.id, int(syscall.SIGKILL), &runsc.KillOpts{ + All: true, + }) + // Ignore error handling for `runsc kill --all` for now. + // * If it doesn't return error, it is good; + // * If it returns error, consider the container has already stopped. + // TODO: Fix `runsc kill --all` error handling. + return nil +} + +// Stdin returns the stdin of the process. +func (p *Init) Stdin() io.Closer { + return p.stdin +} + +// Runtime returns the OCI runtime configured for the init process. +func (p *Init) Runtime() *runsc.Runsc { + return p.runtime +} + +// Exec returns a new child process. +func (p *Init) Exec(ctx context.Context, path string, r *ExecConfig) (process.Process, error) { + p.mu.Lock() + defer p.mu.Unlock() + + return p.initState.Exec(ctx, path, r) +} + +// exec returns a new exec'd process. +func (p *Init) exec(ctx context.Context, path string, r *ExecConfig) (process.Process, error) { + // process exec request + var spec specs.Process + if err := json.Unmarshal(r.Spec.Value, &spec); err != nil { + return nil, err + } + spec.Terminal = r.Terminal + + e := &execProcess{ + id: r.ID, + path: path, + parent: p, + spec: spec, + stdio: stdio.Stdio{ + Stdin: r.Stdin, + Stdout: r.Stdout, + Stderr: r.Stderr, + Terminal: r.Terminal, + }, + waitBlock: make(chan struct{}), + } + e.execState = &execCreatedState{p: e} + return e, nil +} + +// Stdio returns the stdio of the process. +func (p *Init) Stdio() stdio.Stdio { + return p.stdio +} + +func (p *Init) runtimeError(rErr error, msg string) error { + if rErr == nil { + return nil + } + + rMsg, err := getLastRuntimeError(p.runtime) + switch { + case err != nil: + return fmt.Errorf("%s: %w (unable to retrieve OCI runtime error: %v)", msg, rErr, err) + case rMsg == "": + return fmt.Errorf("%s: %w", msg, rErr) + default: + return fmt.Errorf("%s: %s", msg, rMsg) + } +} + +func (p *Init) convertStatus(status string) string { + if status == "created" && !p.Sandbox && p.status == internalErrorCode { + // Treat start failure state for non-root container as stopped. + return "stopped" + } + return status +} + +func withConditionalIO(c stdio.Stdio) runc.IOOpt { + return func(o *runc.IOOption) { + o.OpenStdin = c.Stdin != "" + o.OpenStdout = c.Stdout != "" + o.OpenStderr = c.Stderr != "" + } +} diff --git a/pkg/shim/v1/proc/init_state.go b/pkg/shim/v1/proc/init_state.go new file mode 100644 index 000000000..9233ecc85 --- /dev/null +++ b/pkg/shim/v1/proc/init_state.go @@ -0,0 +1,182 @@ +// Copyright 2018 The containerd Authors. +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package proc + +import ( + "context" + "fmt" + + "github.com/containerd/console" + "github.com/containerd/containerd/errdefs" + "github.com/containerd/containerd/pkg/process" +) + +type initState interface { + Resize(console.WinSize) error + Start(context.Context) error + Delete(context.Context) error + Exec(context.Context, string, *ExecConfig) (process.Process, error) + Kill(context.Context, uint32, bool) error + SetExited(int) +} + +type createdState struct { + p *Init +} + +func (s *createdState) transition(name string) error { + switch name { + case "running": + s.p.initState = &runningState{p: s.p} + case "stopped": + s.p.initState = &stoppedState{p: s.p} + case "deleted": + s.p.initState = &deletedState{} + default: + return fmt.Errorf("invalid state transition %q to %q", stateName(s), name) + } + return nil +} + +func (s *createdState) Resize(ws console.WinSize) error { + return s.p.resize(ws) +} + +func (s *createdState) Start(ctx context.Context) error { + if err := s.p.start(ctx); err != nil { + // Containerd doesn't allow deleting container in created state. + // However, for gvisor, a non-root container in created state can + // only go to running state. If the container can't be started, + // it can only stay in created state, and never be deleted. + // To work around that, we treat non-root container in start failure + // state as stopped. + if !s.p.Sandbox { + s.p.io.Close() + s.p.setExited(internalErrorCode) + if err := s.transition("stopped"); err != nil { + panic(err) + } + } + return err + } + return s.transition("running") +} + +func (s *createdState) Delete(ctx context.Context) error { + if err := s.p.delete(ctx); err != nil { + return err + } + return s.transition("deleted") +} + +func (s *createdState) Kill(ctx context.Context, sig uint32, all bool) error { + return s.p.kill(ctx, sig, all) +} + +func (s *createdState) SetExited(status int) { + s.p.setExited(status) + + if err := s.transition("stopped"); err != nil { + panic(err) + } +} + +func (s *createdState) Exec(ctx context.Context, path string, r *ExecConfig) (process.Process, error) { + return s.p.exec(ctx, path, r) +} + +type runningState struct { + p *Init +} + +func (s *runningState) transition(name string) error { + switch name { + case "stopped": + s.p.initState = &stoppedState{p: s.p} + default: + return fmt.Errorf("invalid state transition %q to %q", stateName(s), name) + } + return nil +} + +func (s *runningState) Resize(ws console.WinSize) error { + return s.p.resize(ws) +} + +func (s *runningState) Start(ctx context.Context) error { + return fmt.Errorf("cannot start a running process.ss") +} + +func (s *runningState) Delete(ctx context.Context) error { + return fmt.Errorf("cannot delete a running process.ss") +} + +func (s *runningState) Kill(ctx context.Context, sig uint32, all bool) error { + return s.p.kill(ctx, sig, all) +} + +func (s *runningState) SetExited(status int) { + s.p.setExited(status) + + if err := s.transition("stopped"); err != nil { + panic(err) + } +} + +func (s *runningState) Exec(ctx context.Context, path string, r *ExecConfig) (process.Process, error) { + return s.p.exec(ctx, path, r) +} + +type stoppedState struct { + p *Init +} + +func (s *stoppedState) transition(name string) error { + switch name { + case "deleted": + s.p.initState = &deletedState{} + default: + return fmt.Errorf("invalid state transition %q to %q", stateName(s), name) + } + return nil +} + +func (s *stoppedState) Resize(ws console.WinSize) error { + return fmt.Errorf("cannot resize a stopped container") +} + +func (s *stoppedState) Start(ctx context.Context) error { + return fmt.Errorf("cannot start a stopped process.ss") +} + +func (s *stoppedState) Delete(ctx context.Context) error { + if err := s.p.delete(ctx); err != nil { + return err + } + return s.transition("deleted") +} + +func (s *stoppedState) Kill(ctx context.Context, sig uint32, all bool) error { + return errdefs.ToGRPCf(errdefs.ErrNotFound, "process.ss %s not found", s.p.id) +} + +func (s *stoppedState) SetExited(status int) { + // no op +} + +func (s *stoppedState) Exec(ctx context.Context, path string, r *ExecConfig) (process.Process, error) { + return nil, fmt.Errorf("cannot exec in a stopped state") +} diff --git a/pkg/shim/v1/proc/io.go b/pkg/shim/v1/proc/io.go new file mode 100644 index 000000000..34d825fb7 --- /dev/null +++ b/pkg/shim/v1/proc/io.go @@ -0,0 +1,162 @@ +// Copyright 2018 The containerd Authors. +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package proc + +import ( + "context" + "fmt" + "io" + "os" + "sync" + "sync/atomic" + "syscall" + + "github.com/containerd/containerd/log" + "github.com/containerd/fifo" + runc "github.com/containerd/go-runc" +) + +// TODO(random-liu): This file can be a util. + +var bufPool = sync.Pool{ + New: func() interface{} { + buffer := make([]byte, 32<<10) + return &buffer + }, +} + +func copyPipes(ctx context.Context, rio runc.IO, stdin, stdout, stderr string, wg *sync.WaitGroup) error { + var sameFile *countingWriteCloser + for _, i := range []struct { + name string + dest func(wc io.WriteCloser, rc io.Closer) + }{ + { + name: stdout, + dest: func(wc io.WriteCloser, rc io.Closer) { + wg.Add(1) + go func() { + p := bufPool.Get().(*[]byte) + defer bufPool.Put(p) + if _, err := io.CopyBuffer(wc, rio.Stdout(), *p); err != nil { + log.G(ctx).Warn("error copying stdout") + } + wg.Done() + wc.Close() + if rc != nil { + rc.Close() + } + }() + }, + }, { + name: stderr, + dest: func(wc io.WriteCloser, rc io.Closer) { + wg.Add(1) + go func() { + p := bufPool.Get().(*[]byte) + defer bufPool.Put(p) + if _, err := io.CopyBuffer(wc, rio.Stderr(), *p); err != nil { + log.G(ctx).Warn("error copying stderr") + } + wg.Done() + wc.Close() + if rc != nil { + rc.Close() + } + }() + }, + }, + } { + ok, err := isFifo(i.name) + if err != nil { + return err + } + var ( + fw io.WriteCloser + fr io.Closer + ) + if ok { + if fw, err = fifo.OpenFifo(ctx, i.name, syscall.O_WRONLY, 0); err != nil { + return fmt.Errorf("gvisor-containerd-shim: opening %s failed: %s", i.name, err) + } + if fr, err = fifo.OpenFifo(ctx, i.name, syscall.O_RDONLY, 0); err != nil { + return fmt.Errorf("gvisor-containerd-shim: opening %s failed: %s", i.name, err) + } + } else { + if sameFile != nil { + sameFile.count++ + i.dest(sameFile, nil) + continue + } + if fw, err = os.OpenFile(i.name, syscall.O_WRONLY|syscall.O_APPEND, 0); err != nil { + return fmt.Errorf("gvisor-containerd-shim: opening %s failed: %s", i.name, err) + } + if stdout == stderr { + sameFile = &countingWriteCloser{ + WriteCloser: fw, + count: 1, + } + } + } + i.dest(fw, fr) + } + if stdin == "" { + return nil + } + f, err := fifo.OpenFifo(context.Background(), stdin, syscall.O_RDONLY|syscall.O_NONBLOCK, 0) + if err != nil { + return fmt.Errorf("gvisor-containerd-shim: opening %s failed: %s", stdin, err) + } + go func() { + p := bufPool.Get().(*[]byte) + defer bufPool.Put(p) + + io.CopyBuffer(rio.Stdin(), f, *p) + rio.Stdin().Close() + f.Close() + }() + return nil +} + +// countingWriteCloser masks io.Closer() until close has been invoked a certain number of times. +type countingWriteCloser struct { + io.WriteCloser + count int64 +} + +func (c *countingWriteCloser) Close() error { + if atomic.AddInt64(&c.count, -1) > 0 { + return nil + } + return c.WriteCloser.Close() +} + +// isFifo checks if a file is a fifo. +// +// If the file does not exist then it returns false. +func isFifo(path string) (bool, error) { + stat, err := os.Stat(path) + if err != nil { + if os.IsNotExist(err) { + return false, nil + } + return false, err + } + if stat.Mode()&os.ModeNamedPipe == os.ModeNamedPipe { + return true, nil + } + return false, nil +} diff --git a/pkg/shim/v1/proc/process.go b/pkg/shim/v1/proc/process.go new file mode 100644 index 000000000..d462c3eef --- /dev/null +++ b/pkg/shim/v1/proc/process.go @@ -0,0 +1,37 @@ +// Copyright 2018 The containerd Authors. +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package proc + +import ( + "fmt" +) + +// RunscRoot is the path to the root runsc state directory. +const RunscRoot = "/run/containerd/runsc" + +func stateName(v interface{}) string { + switch v.(type) { + case *runningState, *execRunningState: + return "running" + case *createdState, *execCreatedState: + return "created" + case *deletedState: + return "deleted" + case *stoppedState: + return "stopped" + } + panic(fmt.Errorf("invalid state %v", v)) +} diff --git a/pkg/shim/v1/proc/types.go b/pkg/shim/v1/proc/types.go new file mode 100644 index 000000000..2b0df4663 --- /dev/null +++ b/pkg/shim/v1/proc/types.go @@ -0,0 +1,69 @@ +// Copyright 2018 The containerd Authors. +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package proc + +import ( + "time" + + runc "github.com/containerd/go-runc" + "github.com/gogo/protobuf/types" +) + +// Mount holds filesystem mount configuration. +type Mount struct { + Type string + Source string + Target string + Options []string +} + +// CreateConfig hold task creation configuration. +type CreateConfig struct { + ID string + Bundle string + Runtime string + Rootfs []Mount + Terminal bool + Stdin string + Stdout string + Stderr string + Options *types.Any +} + +// ExecConfig holds exec creation configuration. +type ExecConfig struct { + ID string + Terminal bool + Stdin string + Stdout string + Stderr string + Spec *types.Any +} + +// Exit is the type of exit events. +type Exit struct { + Timestamp time.Time + ID string + Status int +} + +// ProcessMonitor monitors process exit changes. +type ProcessMonitor interface { + // Subscribe to process exit changes + Subscribe() chan runc.Exit + // Unsubscribe to process exit changes + Unsubscribe(c chan runc.Exit) +} diff --git a/pkg/shim/v1/proc/utils.go b/pkg/shim/v1/proc/utils.go new file mode 100644 index 000000000..716de2f59 --- /dev/null +++ b/pkg/shim/v1/proc/utils.go @@ -0,0 +1,90 @@ +// Copyright 2018 The containerd Authors. +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package proc + +import ( + "encoding/json" + "io" + "os" + "strings" + "time" + + "gvisor.dev/gvisor/pkg/shim/runsc" +) + +const ( + internalErrorCode = 128 + bufferSize = 32 +) + +// ExitCh is the exit events channel for containers and exec processes +// inside the sandbox. +var ExitCh = make(chan Exit, bufferSize) + +// TODO(mlaventure): move to runc package? +func getLastRuntimeError(r *runsc.Runsc) (string, error) { + if r.Log == "" { + return "", nil + } + + f, err := os.OpenFile(r.Log, os.O_RDONLY, 0400) + if err != nil { + return "", err + } + + var ( + errMsg string + log struct { + Level string + Msg string + Time time.Time + } + ) + + dec := json.NewDecoder(f) + for err = nil; err == nil; { + if err = dec.Decode(&log); err != nil && err != io.EOF { + return "", err + } + if log.Level == "error" { + errMsg = strings.TrimSpace(log.Msg) + } + } + + return errMsg, nil +} + +func copyFile(to, from string) error { + ff, err := os.Open(from) + if err != nil { + return err + } + defer ff.Close() + tt, err := os.Create(to) + if err != nil { + return err + } + defer tt.Close() + + p := bufPool.Get().(*[]byte) + defer bufPool.Put(p) + _, err = io.CopyBuffer(tt, ff, *p) + return err +} + +func hasNoIO(r *CreateConfig) bool { + return r.Stdin == "" && r.Stdout == "" && r.Stderr == "" +} diff --git a/pkg/shim/v1/shim/BUILD b/pkg/shim/v1/shim/BUILD new file mode 100644 index 000000000..05c595bc9 --- /dev/null +++ b/pkg/shim/v1/shim/BUILD @@ -0,0 +1,40 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "shim", + srcs = [ + "api.go", + "platform.go", + "service.go", + ], + visibility = [ + "//pkg/shim:__subpackages__", + "//shim:__subpackages__", + ], + deps = [ + "//pkg/shim/runsc", + "//pkg/shim/v1/proc", + "//pkg/shim/v1/utils", + "@com_github_containerd_console//:go_default_library", + "@com_github_containerd_containerd//api/events:go_default_library", + "@com_github_containerd_containerd//api/types/task:go_default_library", + "@com_github_containerd_containerd//errdefs:go_default_library", + "@com_github_containerd_containerd//events:go_default_library", + "@com_github_containerd_containerd//log:go_default_library", + "@com_github_containerd_containerd//mount:go_default_library", + "@com_github_containerd_containerd//namespaces:go_default_library", + "@com_github_containerd_containerd//pkg/process:go_default_library", + "@com_github_containerd_containerd//pkg/stdio:go_default_library", + "@com_github_containerd_containerd//runtime:go_default_library", + "@com_github_containerd_containerd//runtime/linux/runctypes:go_default_library", + "@com_github_containerd_containerd//runtime/v1/shim/v1:go_default_library", + "@com_github_containerd_containerd//sys/reaper:go_default_library", + "@com_github_containerd_fifo//:go_default_library", + "@com_github_containerd_typeurl//:go_default_library", + "@com_github_gogo_protobuf//types:go_default_library", + "@org_golang_google_grpc//codes:go_default_library", + "@org_golang_google_grpc//status:go_default_library", + ], +) diff --git a/pkg/shim/v1/shim/api.go b/pkg/shim/v1/shim/api.go new file mode 100644 index 000000000..5dd8ff172 --- /dev/null +++ b/pkg/shim/v1/shim/api.go @@ -0,0 +1,28 @@ +// Copyright 2018 The containerd Authors. +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package shim + +import ( + "github.com/containerd/containerd/api/events" +) + +type TaskCreate = events.TaskCreate +type TaskStart = events.TaskStart +type TaskOOM = events.TaskOOM +type TaskExit = events.TaskExit +type TaskDelete = events.TaskDelete +type TaskExecAdded = events.TaskExecAdded +type TaskExecStarted = events.TaskExecStarted diff --git a/pkg/shim/v1/shim/platform.go b/pkg/shim/v1/shim/platform.go new file mode 100644 index 000000000..f590f80ef --- /dev/null +++ b/pkg/shim/v1/shim/platform.go @@ -0,0 +1,106 @@ +// Copyright 2018 The containerd Authors. +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package shim + +import ( + "context" + "fmt" + "io" + "sync" + "syscall" + + "github.com/containerd/console" + "github.com/containerd/fifo" +) + +type linuxPlatform struct { + epoller *console.Epoller +} + +func (p *linuxPlatform) CopyConsole(ctx context.Context, console console.Console, stdin, stdout, stderr string, wg *sync.WaitGroup) (console.Console, error) { + if p.epoller == nil { + return nil, fmt.Errorf("uninitialized epoller") + } + + epollConsole, err := p.epoller.Add(console) + if err != nil { + return nil, err + } + + if stdin != "" { + in, err := fifo.OpenFifo(ctx, stdin, syscall.O_RDONLY, 0) + if err != nil { + return nil, err + } + go func() { + p := bufPool.Get().(*[]byte) + defer bufPool.Put(p) + io.CopyBuffer(epollConsole, in, *p) + }() + } + + outw, err := fifo.OpenFifo(ctx, stdout, syscall.O_WRONLY, 0) + if err != nil { + return nil, err + } + outr, err := fifo.OpenFifo(ctx, stdout, syscall.O_RDONLY, 0) + if err != nil { + return nil, err + } + wg.Add(1) + go func() { + p := bufPool.Get().(*[]byte) + defer bufPool.Put(p) + io.CopyBuffer(outw, epollConsole, *p) + epollConsole.Close() + outr.Close() + outw.Close() + wg.Done() + }() + return epollConsole, nil +} + +func (p *linuxPlatform) ShutdownConsole(ctx context.Context, cons console.Console) error { + if p.epoller == nil { + return fmt.Errorf("uninitialized epoller") + } + epollConsole, ok := cons.(*console.EpollConsole) + if !ok { + return fmt.Errorf("expected EpollConsole, got %#v", cons) + } + return epollConsole.Shutdown(p.epoller.CloseConsole) +} + +func (p *linuxPlatform) Close() error { + return p.epoller.Close() +} + +// initialize a single epoll fd to manage our consoles. `initPlatform` should +// only be called once. +func (s *Service) initPlatform() error { + if s.platform != nil { + return nil + } + epoller, err := console.NewEpoller() + if err != nil { + return fmt.Errorf("failed to initialize epoller: %w", err) + } + s.platform = &linuxPlatform{ + epoller: epoller, + } + go epoller.Wait() + return nil +} diff --git a/pkg/shim/v1/shim/service.go b/pkg/shim/v1/shim/service.go new file mode 100644 index 000000000..84a810cb2 --- /dev/null +++ b/pkg/shim/v1/shim/service.go @@ -0,0 +1,573 @@ +// Copyright 2018 The containerd Authors. +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package shim + +import ( + "context" + "fmt" + "os" + "path/filepath" + "sync" + + "github.com/containerd/console" + "github.com/containerd/containerd/api/types/task" + "github.com/containerd/containerd/errdefs" + "github.com/containerd/containerd/events" + "github.com/containerd/containerd/log" + "github.com/containerd/containerd/mount" + "github.com/containerd/containerd/namespaces" + "github.com/containerd/containerd/pkg/process" + "github.com/containerd/containerd/pkg/stdio" + "github.com/containerd/containerd/runtime" + "github.com/containerd/containerd/runtime/linux/runctypes" + shim "github.com/containerd/containerd/runtime/v1/shim/v1" + "github.com/containerd/containerd/sys/reaper" + "github.com/containerd/typeurl" + "github.com/gogo/protobuf/types" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + "gvisor.dev/gvisor/pkg/shim/runsc" + "gvisor.dev/gvisor/pkg/shim/v1/proc" + "gvisor.dev/gvisor/pkg/shim/v1/utils" +) + +var ( + empty = &types.Empty{} + bufPool = sync.Pool{ + New: func() interface{} { + buffer := make([]byte, 32<<10) + return &buffer + }, + } +) + +// Config contains shim specific configuration. +type Config struct { + Path string + Namespace string + WorkDir string + RuntimeRoot string + RunscConfig map[string]string +} + +// NewService returns a new shim service that can be used via GRPC. +func NewService(config Config, publisher events.Publisher) (*Service, error) { + if config.Namespace == "" { + return nil, fmt.Errorf("shim namespace cannot be empty") + } + ctx := namespaces.WithNamespace(context.Background(), config.Namespace) + s := &Service{ + config: config, + context: ctx, + processes: make(map[string]process.Process), + events: make(chan interface{}, 128), + ec: proc.ExitCh, + } + go s.processExits() + if err := s.initPlatform(); err != nil { + return nil, fmt.Errorf("failed to initialized platform behavior: %w", err) + } + go s.forward(publisher) + return s, nil +} + +// Service is the shim implementation of a remote shim over GRPC. +type Service struct { + mu sync.Mutex + + config Config + context context.Context + processes map[string]process.Process + events chan interface{} + platform stdio.Platform + ec chan proc.Exit + + // Filled by Create() + id string + bundle string +} + +// Create creates a new initial process and container with the underlying OCI runtime. +func (s *Service) Create(ctx context.Context, r *shim.CreateTaskRequest) (_ *shim.CreateTaskResponse, err error) { + s.mu.Lock() + defer s.mu.Unlock() + + var mounts []proc.Mount + for _, m := range r.Rootfs { + mounts = append(mounts, proc.Mount{ + Type: m.Type, + Source: m.Source, + Target: m.Target, + Options: m.Options, + }) + } + + rootfs := filepath.Join(r.Bundle, "rootfs") + if err := os.Mkdir(rootfs, 0711); err != nil && !os.IsExist(err) { + return nil, err + } + + config := &proc.CreateConfig{ + ID: r.ID, + Bundle: r.Bundle, + Runtime: r.Runtime, + Rootfs: mounts, + Terminal: r.Terminal, + Stdin: r.Stdin, + Stdout: r.Stdout, + Stderr: r.Stderr, + Options: r.Options, + } + defer func() { + if err != nil { + if err2 := mount.UnmountAll(rootfs, 0); err2 != nil { + log.G(ctx).WithError(err2).Warn("Failed to cleanup rootfs mount") + } + } + }() + for _, rm := range mounts { + m := &mount.Mount{ + Type: rm.Type, + Source: rm.Source, + Options: rm.Options, + } + if err := m.Mount(rootfs); err != nil { + return nil, fmt.Errorf("failed to mount rootfs component %v: %w", m, err) + } + } + process, err := newInit( + ctx, + s.config.Path, + s.config.WorkDir, + s.config.RuntimeRoot, + s.config.Namespace, + s.config.RunscConfig, + s.platform, + config, + ) + if err := process.Create(ctx, config); err != nil { + return nil, errdefs.ToGRPC(err) + } + // Save the main task id and bundle to the shim for additional + // requests. + s.id = r.ID + s.bundle = r.Bundle + pid := process.Pid() + s.processes[r.ID] = process + return &shim.CreateTaskResponse{ + Pid: uint32(pid), + }, nil +} + +// Start starts a process. +func (s *Service) Start(ctx context.Context, r *shim.StartRequest) (*shim.StartResponse, error) { + p, err := s.getExecProcess(r.ID) + if err != nil { + return nil, err + } + if err := p.Start(ctx); err != nil { + return nil, err + } + return &shim.StartResponse{ + ID: p.ID(), + Pid: uint32(p.Pid()), + }, nil +} + +// Delete deletes the initial process and container. +func (s *Service) Delete(ctx context.Context, r *types.Empty) (*shim.DeleteResponse, error) { + p, err := s.getInitProcess() + if err != nil { + return nil, err + } + if err := p.Delete(ctx); err != nil { + return nil, err + } + s.mu.Lock() + delete(s.processes, s.id) + s.mu.Unlock() + s.platform.Close() + return &shim.DeleteResponse{ + ExitStatus: uint32(p.ExitStatus()), + ExitedAt: p.ExitedAt(), + Pid: uint32(p.Pid()), + }, nil +} + +// DeleteProcess deletes an exec'd process. +func (s *Service) DeleteProcess(ctx context.Context, r *shim.DeleteProcessRequest) (*shim.DeleteResponse, error) { + if r.ID == s.id { + return nil, status.Errorf(codes.InvalidArgument, "cannot delete init process with DeleteProcess") + } + p, err := s.getExecProcess(r.ID) + if err != nil { + return nil, err + } + if err := p.Delete(ctx); err != nil { + return nil, err + } + s.mu.Lock() + delete(s.processes, r.ID) + s.mu.Unlock() + return &shim.DeleteResponse{ + ExitStatus: uint32(p.ExitStatus()), + ExitedAt: p.ExitedAt(), + Pid: uint32(p.Pid()), + }, nil +} + +// Exec spawns an additional process inside the container. +func (s *Service) Exec(ctx context.Context, r *shim.ExecProcessRequest) (*types.Empty, error) { + s.mu.Lock() + + if p := s.processes[r.ID]; p != nil { + s.mu.Unlock() + return nil, errdefs.ToGRPCf(errdefs.ErrAlreadyExists, "id %s", r.ID) + } + + p := s.processes[s.id] + s.mu.Unlock() + if p == nil { + return nil, errdefs.ToGRPCf(errdefs.ErrFailedPrecondition, "container must be created") + } + + process, err := p.(*proc.Init).Exec(ctx, s.config.Path, &proc.ExecConfig{ + ID: r.ID, + Terminal: r.Terminal, + Stdin: r.Stdin, + Stdout: r.Stdout, + Stderr: r.Stderr, + Spec: r.Spec, + }) + if err != nil { + return nil, errdefs.ToGRPC(err) + } + s.mu.Lock() + s.processes[r.ID] = process + s.mu.Unlock() + return empty, nil +} + +// ResizePty resises the terminal of a process. +func (s *Service) ResizePty(ctx context.Context, r *shim.ResizePtyRequest) (*types.Empty, error) { + if r.ID == "" { + return nil, errdefs.ToGRPCf(errdefs.ErrInvalidArgument, "id not provided") + } + ws := console.WinSize{ + Width: uint16(r.Width), + Height: uint16(r.Height), + } + p, err := s.getExecProcess(r.ID) + if err != nil { + return nil, err + } + if err := p.Resize(ws); err != nil { + return nil, errdefs.ToGRPC(err) + } + return empty, nil +} + +// State returns runtime state information for a process. +func (s *Service) State(ctx context.Context, r *shim.StateRequest) (*shim.StateResponse, error) { + p, err := s.getExecProcess(r.ID) + if err != nil { + return nil, err + } + st, err := p.Status(ctx) + if err != nil { + return nil, err + } + status := task.StatusUnknown + switch st { + case "created": + status = task.StatusCreated + case "running": + status = task.StatusRunning + case "stopped": + status = task.StatusStopped + } + sio := p.Stdio() + return &shim.StateResponse{ + ID: p.ID(), + Bundle: s.bundle, + Pid: uint32(p.Pid()), + Status: status, + Stdin: sio.Stdin, + Stdout: sio.Stdout, + Stderr: sio.Stderr, + Terminal: sio.Terminal, + ExitStatus: uint32(p.ExitStatus()), + ExitedAt: p.ExitedAt(), + }, nil +} + +// Pause pauses the container. +func (s *Service) Pause(ctx context.Context, r *types.Empty) (*types.Empty, error) { + return empty, errdefs.ToGRPC(errdefs.ErrNotImplemented) +} + +// Resume resumes the container. +func (s *Service) Resume(ctx context.Context, r *types.Empty) (*types.Empty, error) { + return empty, errdefs.ToGRPC(errdefs.ErrNotImplemented) +} + +// Kill kills a process with the provided signal. +func (s *Service) Kill(ctx context.Context, r *shim.KillRequest) (*types.Empty, error) { + if r.ID == "" { + p, err := s.getInitProcess() + if err != nil { + return nil, err + } + if err := p.Kill(ctx, r.Signal, r.All); err != nil { + return nil, errdefs.ToGRPC(err) + } + return empty, nil + } + + p, err := s.getExecProcess(r.ID) + if err != nil { + return nil, err + } + if err := p.Kill(ctx, r.Signal, r.All); err != nil { + return nil, errdefs.ToGRPC(err) + } + return empty, nil +} + +// ListPids returns all pids inside the container. +func (s *Service) ListPids(ctx context.Context, r *shim.ListPidsRequest) (*shim.ListPidsResponse, error) { + pids, err := s.getContainerPids(ctx, r.ID) + if err != nil { + return nil, errdefs.ToGRPC(err) + } + var processes []*task.ProcessInfo + for _, pid := range pids { + pInfo := task.ProcessInfo{ + Pid: pid, + } + for _, p := range s.processes { + if p.Pid() == int(pid) { + d := &runctypes.ProcessDetails{ + ExecID: p.ID(), + } + a, err := typeurl.MarshalAny(d) + if err != nil { + return nil, fmt.Errorf("failed to marshal process %d info: %w", pid, err) + } + pInfo.Info = a + break + } + } + processes = append(processes, &pInfo) + } + return &shim.ListPidsResponse{ + Processes: processes, + }, nil +} + +// CloseIO closes the I/O context of a process. +func (s *Service) CloseIO(ctx context.Context, r *shim.CloseIORequest) (*types.Empty, error) { + p, err := s.getExecProcess(r.ID) + if err != nil { + return nil, err + } + if stdin := p.Stdin(); stdin != nil { + if err := stdin.Close(); err != nil { + return nil, fmt.Errorf("close stdin: %w", err) + } + } + return empty, nil +} + +// Checkpoint checkpoints the container. +func (s *Service) Checkpoint(ctx context.Context, r *shim.CheckpointTaskRequest) (*types.Empty, error) { + return empty, errdefs.ToGRPC(errdefs.ErrNotImplemented) +} + +// ShimInfo returns shim information such as the shim's pid. +func (s *Service) ShimInfo(ctx context.Context, r *types.Empty) (*shim.ShimInfoResponse, error) { + return &shim.ShimInfoResponse{ + ShimPid: uint32(os.Getpid()), + }, nil +} + +// Update updates a running container. +func (s *Service) Update(ctx context.Context, r *shim.UpdateTaskRequest) (*types.Empty, error) { + return empty, errdefs.ToGRPC(errdefs.ErrNotImplemented) +} + +// Wait waits for a process to exit. +func (s *Service) Wait(ctx context.Context, r *shim.WaitRequest) (*shim.WaitResponse, error) { + p, err := s.getExecProcess(r.ID) + if err != nil { + return nil, err + } + p.Wait() + + return &shim.WaitResponse{ + ExitStatus: uint32(p.ExitStatus()), + ExitedAt: p.ExitedAt(), + }, nil +} + +func (s *Service) processExits() { + for e := range s.ec { + s.checkProcesses(e) + } +} + +func (s *Service) allProcesses() []process.Process { + s.mu.Lock() + defer s.mu.Unlock() + + res := make([]process.Process, 0, len(s.processes)) + for _, p := range s.processes { + res = append(res, p) + } + return res +} + +func (s *Service) checkProcesses(e proc.Exit) { + for _, p := range s.allProcesses() { + if p.ID() == e.ID { + if ip, ok := p.(*proc.Init); ok { + // Ensure all children are killed. + if err := ip.KillAll(s.context); err != nil { + log.G(s.context).WithError(err).WithField("id", ip.ID()). + Error("failed to kill init's children") + } + } + p.SetExited(e.Status) + s.events <- &TaskExit{ + ContainerID: s.id, + ID: p.ID(), + Pid: uint32(p.Pid()), + ExitStatus: uint32(e.Status), + ExitedAt: p.ExitedAt(), + } + return + } + } +} + +func (s *Service) getContainerPids(ctx context.Context, id string) ([]uint32, error) { + p, err := s.getInitProcess() + if err != nil { + return nil, err + } + + ps, err := p.(*proc.Init).Runtime().Ps(ctx, id) + if err != nil { + return nil, err + } + pids := make([]uint32, 0, len(ps)) + for _, pid := range ps { + pids = append(pids, uint32(pid)) + } + return pids, nil +} + +func (s *Service) forward(publisher events.Publisher) { + for e := range s.events { + if err := publisher.Publish(s.context, getTopic(s.context, e), e); err != nil { + log.G(s.context).WithError(err).Error("post event") + } + } +} + +// getInitProcess returns the init process. +func (s *Service) getInitProcess() (process.Process, error) { + s.mu.Lock() + defer s.mu.Unlock() + p := s.processes[s.id] + if p == nil { + return nil, errdefs.ToGRPCf(errdefs.ErrFailedPrecondition, "container must be created") + } + return p, nil +} + +// getExecProcess returns the given exec process. +func (s *Service) getExecProcess(id string) (process.Process, error) { + s.mu.Lock() + defer s.mu.Unlock() + p := s.processes[id] + if p == nil { + return nil, errdefs.ToGRPCf(errdefs.ErrNotFound, "process %s does not exist", id) + } + return p, nil +} + +func getTopic(ctx context.Context, e interface{}) string { + switch e.(type) { + case *TaskCreate: + return runtime.TaskCreateEventTopic + case *TaskStart: + return runtime.TaskStartEventTopic + case *TaskOOM: + return runtime.TaskOOMEventTopic + case *TaskExit: + return runtime.TaskExitEventTopic + case *TaskDelete: + return runtime.TaskDeleteEventTopic + case *TaskExecAdded: + return runtime.TaskExecAddedEventTopic + case *TaskExecStarted: + return runtime.TaskExecStartedEventTopic + default: + log.L.Printf("no topic for type %#v", e) + } + return runtime.TaskUnknownTopic +} + +func newInit(ctx context.Context, path, workDir, runtimeRoot, namespace string, config map[string]string, platform stdio.Platform, r *proc.CreateConfig) (*proc.Init, error) { + var options runctypes.CreateOptions + if r.Options != nil { + v, err := typeurl.UnmarshalAny(r.Options) + if err != nil { + return nil, err + } + options = *v.(*runctypes.CreateOptions) + } + + spec, err := utils.ReadSpec(r.Bundle) + if err != nil { + return nil, fmt.Errorf("read oci spec: %w", err) + } + if err := utils.UpdateVolumeAnnotations(r.Bundle, spec); err != nil { + return nil, fmt.Errorf("update volume annotations: %w", err) + } + + runsc.FormatLogPath(r.ID, config) + rootfs := filepath.Join(path, "rootfs") + runtime := proc.NewRunsc(runtimeRoot, path, namespace, r.Runtime, config) + p := proc.New(r.ID, runtime, stdio.Stdio{ + Stdin: r.Stdin, + Stdout: r.Stdout, + Stderr: r.Stderr, + Terminal: r.Terminal, + }) + p.Bundle = r.Bundle + p.Platform = platform + p.Rootfs = rootfs + p.WorkDir = workDir + p.IoUID = int(options.IoUid) + p.IoGID = int(options.IoGid) + p.Sandbox = utils.IsSandbox(spec) + p.UserLog = utils.UserLogPath(spec) + p.Monitor = reaper.Default + return p, nil +} diff --git a/pkg/shim/v1/utils/BUILD b/pkg/shim/v1/utils/BUILD new file mode 100644 index 000000000..54a0aabb7 --- /dev/null +++ b/pkg/shim/v1/utils/BUILD @@ -0,0 +1,27 @@ +load("//tools:defs.bzl", "go_library", "go_test") + +package(licenses = ["notice"]) + +go_library( + name = "utils", + srcs = [ + "annotations.go", + "utils.go", + "volumes.go", + ], + visibility = [ + "//pkg/shim:__subpackages__", + "//shim:__subpackages__", + ], + deps = [ + "@com_github_opencontainers_runtime_spec//specs-go:go_default_library", + ], +) + +go_test( + name = "utils_test", + size = "small", + srcs = ["volumes_test.go"], + library = ":utils", + deps = ["@com_github_opencontainers_runtime_spec//specs-go:go_default_library"], +) diff --git a/pkg/shim/v1/utils/annotations.go b/pkg/shim/v1/utils/annotations.go new file mode 100644 index 000000000..1e9d3f365 --- /dev/null +++ b/pkg/shim/v1/utils/annotations.go @@ -0,0 +1,25 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +// Annotations from the CRI annotations package. +// +// These are vendor due to import conflicts. +const ( + sandboxLogDirAnnotation = "io.kubernetes.cri.sandbox-log-directory" + containerTypeAnnotation = "io.kubernetes.cri.container-type" + containerTypeSandbox = "sandbox" + containerTypeContainer = "container" +) diff --git a/pkg/shim/v1/utils/utils.go b/pkg/shim/v1/utils/utils.go new file mode 100644 index 000000000..07e346654 --- /dev/null +++ b/pkg/shim/v1/utils/utils.go @@ -0,0 +1,56 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "encoding/json" + "io/ioutil" + "os" + "path/filepath" + + specs "github.com/opencontainers/runtime-spec/specs-go" +) + +// ReadSpec reads OCI spec from the bundle directory. +func ReadSpec(bundle string) (*specs.Spec, error) { + f, err := os.Open(filepath.Join(bundle, "config.json")) + if err != nil { + return nil, err + } + b, err := ioutil.ReadAll(f) + if err != nil { + return nil, err + } + var spec specs.Spec + if err := json.Unmarshal(b, &spec); err != nil { + return nil, err + } + return &spec, nil +} + +// IsSandbox checks whether a container is a sandbox container. +func IsSandbox(spec *specs.Spec) bool { + t, ok := spec.Annotations[containerTypeAnnotation] + return !ok || t == containerTypeSandbox +} + +// UserLogPath gets user log path from OCI annotation. +func UserLogPath(spec *specs.Spec) string { + sandboxLogDir := spec.Annotations[sandboxLogDirAnnotation] + if sandboxLogDir == "" { + return "" + } + return filepath.Join(sandboxLogDir, "gvisor.log") +} diff --git a/pkg/shim/v1/utils/volumes.go b/pkg/shim/v1/utils/volumes.go new file mode 100644 index 000000000..52a428179 --- /dev/null +++ b/pkg/shim/v1/utils/volumes.go @@ -0,0 +1,155 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "encoding/json" + "fmt" + "io/ioutil" + "path/filepath" + "strings" + + specs "github.com/opencontainers/runtime-spec/specs-go" +) + +const volumeKeyPrefix = "dev.gvisor.spec.mount." + +var kubeletPodsDir = "/var/lib/kubelet/pods" + +// volumeName gets volume name from volume annotation key, example: +// dev.gvisor.spec.mount.NAME.share +func volumeName(k string) string { + return strings.SplitN(strings.TrimPrefix(k, volumeKeyPrefix), ".", 2)[0] +} + +// volumeFieldName gets volume field name from volume annotation key, example: +// `type` is the field of dev.gvisor.spec.mount.NAME.type +func volumeFieldName(k string) string { + parts := strings.Split(strings.TrimPrefix(k, volumeKeyPrefix), ".") + return parts[len(parts)-1] +} + +// podUID gets pod UID from the pod log path. +func podUID(s *specs.Spec) (string, error) { + sandboxLogDir := s.Annotations[sandboxLogDirAnnotation] + if sandboxLogDir == "" { + return "", fmt.Errorf("no sandbox log path annotation") + } + fields := strings.Split(filepath.Base(sandboxLogDir), "_") + switch len(fields) { + case 1: // This is the old CRI logging path. + return fields[0], nil + case 3: // This is the new CRI logging path. + return fields[2], nil + } + return "", fmt.Errorf("unexpected sandbox log path %q", sandboxLogDir) +} + +// isVolumeKey checks whether an annotation key is for volume. +func isVolumeKey(k string) bool { + return strings.HasPrefix(k, volumeKeyPrefix) +} + +// volumeSourceKey constructs the annotation key for volume source. +func volumeSourceKey(volume string) string { + return volumeKeyPrefix + volume + ".source" +} + +// volumePath searches the volume path in the kubelet pod directory. +func volumePath(volume, uid string) (string, error) { + // TODO: Support subpath when gvisor supports pod volume bind mount. + volumeSearchPath := fmt.Sprintf("%s/%s/volumes/*/%s", kubeletPodsDir, uid, volume) + dirs, err := filepath.Glob(volumeSearchPath) + if err != nil { + return "", err + } + if len(dirs) != 1 { + return "", fmt.Errorf("unexpected matched volume list %v", dirs) + } + return dirs[0], nil +} + +// isVolumePath checks whether a string is the volume path. +func isVolumePath(volume, path string) (bool, error) { + // TODO: Support subpath when gvisor supports pod volume bind mount. + volumeSearchPath := fmt.Sprintf("%s/*/volumes/*/%s", kubeletPodsDir, volume) + return filepath.Match(volumeSearchPath, path) +} + +// UpdateVolumeAnnotations add necessary OCI annotations for gvisor +// volume optimization. +func UpdateVolumeAnnotations(bundle string, s *specs.Spec) error { + var ( + uid string + err error + ) + if IsSandbox(s) { + uid, err = podUID(s) + if err != nil { + // Skip if we can't get pod UID, because this doesn't work + // for containerd 1.1. + return nil + } + } + var updated bool + for k, v := range s.Annotations { + if !isVolumeKey(k) { + continue + } + if volumeFieldName(k) != "type" { + continue + } + volume := volumeName(k) + if uid != "" { + // This is a sandbox. + path, err := volumePath(volume, uid) + if err != nil { + return fmt.Errorf("get volume path for %q: %w", volume, err) + } + s.Annotations[volumeSourceKey(volume)] = path + updated = true + } else { + // This is a container. + for i := range s.Mounts { + // An error is returned for sandbox if source + // annotation is not successfully applied, so + // it is guaranteed that the source annotation + // for sandbox has already been successfully + // applied at this point. + // + // The volume name is unique inside a pod, so + // matching without podUID is fine here. + // + // TODO: Pass podUID down to shim for containers to do + // more accurate matching. + if yes, _ := isVolumePath(volume, s.Mounts[i].Source); yes { + // gVisor requires the container mount type to match + // sandbox mount type. + s.Mounts[i].Type = v + updated = true + } + } + } + } + if !updated { + return nil + } + // Update bundle. + b, err := json.Marshal(s) + if err != nil { + return err + } + return ioutil.WriteFile(filepath.Join(bundle, "config.json"), b, 0666) +} diff --git a/pkg/shim/v1/utils/volumes_test.go b/pkg/shim/v1/utils/volumes_test.go new file mode 100644 index 000000000..3e02c6151 --- /dev/null +++ b/pkg/shim/v1/utils/volumes_test.go @@ -0,0 +1,308 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "encoding/json" + "fmt" + "io/ioutil" + "os" + "path/filepath" + "reflect" + "testing" + + specs "github.com/opencontainers/runtime-spec/specs-go" +) + +func TestUpdateVolumeAnnotations(t *testing.T) { + dir, err := ioutil.TempDir("", "test-update-volume-annotations") + if err != nil { + t.Fatalf("create tempdir: %v", err) + } + defer os.RemoveAll(dir) + kubeletPodsDir = dir + + const ( + testPodUID = "testuid" + testVolumeName = "testvolume" + testLogDirPath = "/var/log/pods/testns_testname_" + testPodUID + testLegacyLogDirPath = "/var/log/pods/" + testPodUID + ) + testVolumePath := fmt.Sprintf("%s/%s/volumes/kubernetes.io~empty-dir/%s", dir, testPodUID, testVolumeName) + + if err := os.MkdirAll(testVolumePath, 0755); err != nil { + t.Fatalf("Create test volume: %v", err) + } + + for _, test := range []struct { + desc string + spec *specs.Spec + expected *specs.Spec + expectErr bool + expectUpdate bool + }{ + { + desc: "volume annotations for sandbox", + spec: &specs.Spec{ + Annotations: map[string]string{ + sandboxLogDirAnnotation: testLogDirPath, + containerTypeAnnotation: containerTypeSandbox, + "dev.gvisor.spec.mount." + testVolumeName + ".share": "pod", + "dev.gvisor.spec.mount." + testVolumeName + ".type": "tmpfs", + "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro", + }, + }, + expected: &specs.Spec{ + Annotations: map[string]string{ + sandboxLogDirAnnotation: testLogDirPath, + containerTypeAnnotation: containerTypeSandbox, + "dev.gvisor.spec.mount." + testVolumeName + ".share": "pod", + "dev.gvisor.spec.mount." + testVolumeName + ".type": "tmpfs", + "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro", + "dev.gvisor.spec.mount." + testVolumeName + ".source": testVolumePath, + }, + }, + expectUpdate: true, + }, + { + desc: "volume annotations for sandbox with legacy log path", + spec: &specs.Spec{ + Annotations: map[string]string{ + sandboxLogDirAnnotation: testLegacyLogDirPath, + containerTypeAnnotation: containerTypeSandbox, + "dev.gvisor.spec.mount." + testVolumeName + ".share": "pod", + "dev.gvisor.spec.mount." + testVolumeName + ".type": "tmpfs", + "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro", + }, + }, + expected: &specs.Spec{ + Annotations: map[string]string{ + sandboxLogDirAnnotation: testLegacyLogDirPath, + containerTypeAnnotation: containerTypeSandbox, + "dev.gvisor.spec.mount." + testVolumeName + ".share": "pod", + "dev.gvisor.spec.mount." + testVolumeName + ".type": "tmpfs", + "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro", + "dev.gvisor.spec.mount." + testVolumeName + ".source": testVolumePath, + }, + }, + expectUpdate: true, + }, + { + desc: "tmpfs: volume annotations for container", + spec: &specs.Spec{ + Mounts: []specs.Mount{ + { + Destination: "/test", + Type: "bind", + Source: testVolumePath, + Options: []string{"ro"}, + }, + { + Destination: "/random", + Type: "bind", + Source: "/random", + Options: []string{"ro"}, + }, + }, + Annotations: map[string]string{ + containerTypeAnnotation: containerTypeContainer, + "dev.gvisor.spec.mount." + testVolumeName + ".share": "pod", + "dev.gvisor.spec.mount." + testVolumeName + ".type": "tmpfs", + "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro", + }, + }, + expected: &specs.Spec{ + Mounts: []specs.Mount{ + { + Destination: "/test", + Type: "tmpfs", + Source: testVolumePath, + Options: []string{"ro"}, + }, + { + Destination: "/random", + Type: "bind", + Source: "/random", + Options: []string{"ro"}, + }, + }, + Annotations: map[string]string{ + containerTypeAnnotation: containerTypeContainer, + "dev.gvisor.spec.mount." + testVolumeName + ".share": "pod", + "dev.gvisor.spec.mount." + testVolumeName + ".type": "tmpfs", + "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro", + }, + }, + expectUpdate: true, + }, + { + desc: "bind: volume annotations for container", + spec: &specs.Spec{ + Mounts: []specs.Mount{ + { + Destination: "/test", + Type: "bind", + Source: testVolumePath, + Options: []string{"ro"}, + }, + }, + Annotations: map[string]string{ + containerTypeAnnotation: containerTypeContainer, + "dev.gvisor.spec.mount." + testVolumeName + ".share": "container", + "dev.gvisor.spec.mount." + testVolumeName + ".type": "bind", + "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro", + }, + }, + expected: &specs.Spec{ + Mounts: []specs.Mount{ + { + Destination: "/test", + Type: "bind", + Source: testVolumePath, + Options: []string{"ro"}, + }, + }, + Annotations: map[string]string{ + containerTypeAnnotation: containerTypeContainer, + "dev.gvisor.spec.mount." + testVolumeName + ".share": "container", + "dev.gvisor.spec.mount." + testVolumeName + ".type": "bind", + "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro", + }, + }, + expectUpdate: true, + }, + { + desc: "should not return error without pod log directory", + spec: &specs.Spec{ + Annotations: map[string]string{ + containerTypeAnnotation: containerTypeSandbox, + "dev.gvisor.spec.mount." + testVolumeName + ".share": "pod", + "dev.gvisor.spec.mount." + testVolumeName + ".type": "tmpfs", + "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro", + }, + }, + expected: &specs.Spec{ + Annotations: map[string]string{ + containerTypeAnnotation: containerTypeSandbox, + "dev.gvisor.spec.mount." + testVolumeName + ".share": "pod", + "dev.gvisor.spec.mount." + testVolumeName + ".type": "tmpfs", + "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro", + }, + }, + }, + { + desc: "should return error if volume path does not exist", + spec: &specs.Spec{ + Annotations: map[string]string{ + sandboxLogDirAnnotation: testLogDirPath, + containerTypeAnnotation: containerTypeSandbox, + "dev.gvisor.spec.mount.notexist.share": "pod", + "dev.gvisor.spec.mount.notexist.type": "tmpfs", + "dev.gvisor.spec.mount.notexist.options": "ro", + }, + }, + expectErr: true, + }, + { + desc: "no volume annotations for sandbox", + spec: &specs.Spec{ + Annotations: map[string]string{ + sandboxLogDirAnnotation: testLogDirPath, + containerTypeAnnotation: containerTypeSandbox, + }, + }, + expected: &specs.Spec{ + Annotations: map[string]string{ + sandboxLogDirAnnotation: testLogDirPath, + containerTypeAnnotation: containerTypeSandbox, + }, + }, + }, + { + desc: "no volume annotations for container", + spec: &specs.Spec{ + Mounts: []specs.Mount{ + { + Destination: "/test", + Type: "bind", + Source: "/test", + Options: []string{"ro"}, + }, + { + Destination: "/random", + Type: "bind", + Source: "/random", + Options: []string{"ro"}, + }, + }, + Annotations: map[string]string{ + containerTypeAnnotation: containerTypeContainer, + }, + }, + expected: &specs.Spec{ + Mounts: []specs.Mount{ + { + Destination: "/test", + Type: "bind", + Source: "/test", + Options: []string{"ro"}, + }, + { + Destination: "/random", + Type: "bind", + Source: "/random", + Options: []string{"ro"}, + }, + }, + Annotations: map[string]string{ + containerTypeAnnotation: containerTypeContainer, + }, + }, + }, + } { + t.Run(test.desc, func(t *testing.T) { + bundle, err := ioutil.TempDir(dir, "test-bundle") + if err != nil { + t.Fatalf("Create test bundle: %v", err) + } + err = UpdateVolumeAnnotations(bundle, test.spec) + if test.expectErr { + if err == nil { + t.Fatal("Expected error, but got nil") + } + return + } + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + if !reflect.DeepEqual(test.expected, test.spec) { + t.Fatalf("Expected %+v, got %+v", test.expected, test.spec) + } + if test.expectUpdate { + b, err := ioutil.ReadFile(filepath.Join(bundle, "config.json")) + if err != nil { + t.Fatalf("Read spec from bundle: %v", err) + } + var spec specs.Spec + if err := json.Unmarshal(b, &spec); err != nil { + t.Fatalf("Unmarshal spec: %v", err) + } + if !reflect.DeepEqual(test.expected, &spec) { + t.Fatalf("Expected %+v, got %+v", test.expected, &spec) + } + } + }) + } +} diff --git a/pkg/shim/v2/BUILD b/pkg/shim/v2/BUILD new file mode 100644 index 000000000..7e0a114a0 --- /dev/null +++ b/pkg/shim/v2/BUILD @@ -0,0 +1,43 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "v2", + srcs = [ + "api.go", + "epoll.go", + "service.go", + "service_linux.go", + ], + visibility = ["//shim:__subpackages__"], + deps = [ + "//pkg/shim/runsc", + "//pkg/shim/v1/proc", + "//pkg/shim/v1/utils", + "//pkg/shim/v2/options", + "//pkg/shim/v2/runtimeoptions", + "//runsc/specutils", + "@com_github_burntsushi_toml//:go_default_library", + "@com_github_containerd_cgroups//:go_default_library", + "@com_github_containerd_console//:go_default_library", + "@com_github_containerd_containerd//api/events:go_default_library", + "@com_github_containerd_containerd//api/types/task:go_default_library", + "@com_github_containerd_containerd//errdefs:go_default_library", + "@com_github_containerd_containerd//events:go_default_library", + "@com_github_containerd_containerd//log:go_default_library", + "@com_github_containerd_containerd//mount:go_default_library", + "@com_github_containerd_containerd//namespaces:go_default_library", + "@com_github_containerd_containerd//pkg/process:go_default_library", + "@com_github_containerd_containerd//pkg/stdio:go_default_library", + "@com_github_containerd_containerd//runtime:go_default_library", + "@com_github_containerd_containerd//runtime/linux/runctypes:go_default_library", + "@com_github_containerd_containerd//runtime/v2/shim:go_default_library", + "@com_github_containerd_containerd//runtime/v2/task:go_default_library", + "@com_github_containerd_containerd//sys/reaper:go_default_library", + "@com_github_containerd_fifo//:go_default_library", + "@com_github_containerd_typeurl//:go_default_library", + "@com_github_gogo_protobuf//types:go_default_library", + "@org_golang_x_sys//unix:go_default_library", + ], +) diff --git a/pkg/shim/v2/api.go b/pkg/shim/v2/api.go new file mode 100644 index 000000000..dbe5c59f6 --- /dev/null +++ b/pkg/shim/v2/api.go @@ -0,0 +1,22 @@ +// Copyright 2018 The containerd Authors. +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package v2 + +import ( + "github.com/containerd/containerd/api/events" +) + +type TaskOOM = events.TaskOOM diff --git a/pkg/shim/v2/epoll.go b/pkg/shim/v2/epoll.go new file mode 100644 index 000000000..41232cca8 --- /dev/null +++ b/pkg/shim/v2/epoll.go @@ -0,0 +1,129 @@ +// Copyright 2018 The containerd Authors. +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build linux + +package v2 + +import ( + "context" + "fmt" + "sync" + + "github.com/containerd/cgroups" + "github.com/containerd/containerd/events" + "github.com/containerd/containerd/runtime" + "golang.org/x/sys/unix" +) + +func newOOMEpoller(publisher events.Publisher) (*epoller, error) { + fd, err := unix.EpollCreate1(unix.EPOLL_CLOEXEC) + if err != nil { + return nil, err + } + return &epoller{ + fd: fd, + publisher: publisher, + set: make(map[uintptr]*item), + }, nil +} + +type epoller struct { + mu sync.Mutex + + fd int + publisher events.Publisher + set map[uintptr]*item +} + +type item struct { + id string + cg cgroups.Cgroup +} + +func (e *epoller) Close() error { + return unix.Close(e.fd) +} + +func (e *epoller) run(ctx context.Context) { + var events [128]unix.EpollEvent + for { + select { + case <-ctx.Done(): + e.Close() + return + default: + n, err := unix.EpollWait(e.fd, events[:], -1) + if err != nil { + if err == unix.EINTR || err == unix.EAGAIN { + continue + } + // Should not happen. + panic(fmt.Errorf("cgroups: epoll wait: %w", err)) + } + for i := 0; i < n; i++ { + e.process(ctx, uintptr(events[i].Fd)) + } + } + } +} + +func (e *epoller) add(id string, cg cgroups.Cgroup) error { + e.mu.Lock() + defer e.mu.Unlock() + fd, err := cg.OOMEventFD() + if err != nil { + return err + } + e.set[fd] = &item{ + id: id, + cg: cg, + } + event := unix.EpollEvent{ + Fd: int32(fd), + Events: unix.EPOLLHUP | unix.EPOLLIN | unix.EPOLLERR, + } + return unix.EpollCtl(e.fd, unix.EPOLL_CTL_ADD, int(fd), &event) +} + +func (e *epoller) process(ctx context.Context, fd uintptr) { + flush(fd) + e.mu.Lock() + i, ok := e.set[fd] + if !ok { + e.mu.Unlock() + return + } + e.mu.Unlock() + if i.cg.State() == cgroups.Deleted { + e.mu.Lock() + delete(e.set, fd) + e.mu.Unlock() + unix.Close(int(fd)) + return + } + if err := e.publisher.Publish(ctx, runtime.TaskOOMEventTopic, &TaskOOM{ + ContainerID: i.id, + }); err != nil { + // Should not happen. + panic(fmt.Errorf("publish OOM event: %w", err)) + } +} + +func flush(fd uintptr) error { + var buf [8]byte + _, err := unix.Read(int(fd), buf[:]) + return err +} diff --git a/pkg/shim/v2/options/BUILD b/pkg/shim/v2/options/BUILD new file mode 100644 index 000000000..ca212e874 --- /dev/null +++ b/pkg/shim/v2/options/BUILD @@ -0,0 +1,11 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "options", + srcs = [ + "options.go", + ], + visibility = ["//:sandbox"], +) diff --git a/pkg/shim/v2/options/options.go b/pkg/shim/v2/options/options.go new file mode 100644 index 000000000..de09f2f79 --- /dev/null +++ b/pkg/shim/v2/options/options.go @@ -0,0 +1,33 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package options + +const OptionType = "io.containerd.runsc.v1.options" + +// Options is runtime options for io.containerd.runsc.v1. +type Options struct { + // ShimCgroup is the cgroup the shim should be in. + ShimCgroup string `toml:"shim_cgroup"` + // IoUid is the I/O's pipes uid. + IoUid uint32 `toml:"io_uid"` + // IoUid is the I/O's pipes gid. + IoGid uint32 `toml:"io_gid"` + // BinaryName is the binary name of the runsc binary. + BinaryName string `toml:"binary_name"` + // Root is the runsc root directory. + Root string `toml:"root"` + // RunscConfig is a key/value map of all runsc flags. + RunscConfig map[string]string `toml:"runsc_config"` +} diff --git a/pkg/shim/v2/runtimeoptions/BUILD b/pkg/shim/v2/runtimeoptions/BUILD new file mode 100644 index 000000000..01716034c --- /dev/null +++ b/pkg/shim/v2/runtimeoptions/BUILD @@ -0,0 +1,20 @@ +load("//tools:defs.bzl", "go_library", "proto_library") + +package(licenses = ["notice"]) + +proto_library( + name = "api", + srcs = [ + "runtimeoptions.proto", + ], +) + +go_library( + name = "runtimeoptions", + srcs = ["runtimeoptions.go"], + visibility = ["//pkg/shim/v2:__pkg__"], + deps = [ + "//pkg/shim/v2/runtimeoptions:api_go_proto", + "@com_github_gogo_protobuf//proto:go_default_library", + ], +) diff --git a/pkg/shim/v2/runtimeoptions/runtimeoptions.go b/pkg/shim/v2/runtimeoptions/runtimeoptions.go new file mode 100644 index 000000000..1c1a0c5d1 --- /dev/null +++ b/pkg/shim/v2/runtimeoptions/runtimeoptions.go @@ -0,0 +1,27 @@ +// Copyright 2018 The containerd Authors. +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package runtimeoptions + +import ( + proto "github.com/gogo/protobuf/proto" + pb "gvisor.dev/gvisor/pkg/shim/v2/runtimeoptions/api_go_proto" +) + +type Options = pb.Options + +func init() { + proto.RegisterType((*Options)(nil), "cri.runtimeoptions.v1.Options") +} diff --git a/pkg/shim/v2/runtimeoptions/runtimeoptions.proto b/pkg/shim/v2/runtimeoptions/runtimeoptions.proto new file mode 100644 index 000000000..edb19020a --- /dev/null +++ b/pkg/shim/v2/runtimeoptions/runtimeoptions.proto @@ -0,0 +1,25 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto3"; + +package runtimeoptions; + +// This is a version of the runtimeoptions CRI API that is vendored. +// +// Imported the full CRI package is a nightmare. +message Options { + string type_url = 1; + string config_path = 2; +} diff --git a/pkg/shim/v2/service.go b/pkg/shim/v2/service.go new file mode 100644 index 000000000..1534152fc --- /dev/null +++ b/pkg/shim/v2/service.go @@ -0,0 +1,824 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package v2 + +import ( + "context" + "fmt" + "io/ioutil" + "os" + "os/exec" + "path/filepath" + "sync" + "syscall" + "time" + + "github.com/BurntSushi/toml" + "github.com/containerd/cgroups" + "github.com/containerd/console" + "github.com/containerd/containerd/api/events" + "github.com/containerd/containerd/api/types/task" + "github.com/containerd/containerd/errdefs" + "github.com/containerd/containerd/log" + "github.com/containerd/containerd/mount" + "github.com/containerd/containerd/namespaces" + "github.com/containerd/containerd/pkg/process" + "github.com/containerd/containerd/pkg/stdio" + "github.com/containerd/containerd/runtime" + "github.com/containerd/containerd/runtime/linux/runctypes" + "github.com/containerd/containerd/runtime/v2/shim" + taskAPI "github.com/containerd/containerd/runtime/v2/task" + "github.com/containerd/containerd/sys/reaper" + "github.com/containerd/typeurl" + "github.com/gogo/protobuf/types" + "golang.org/x/sys/unix" + + "gvisor.dev/gvisor/pkg/shim/runsc" + "gvisor.dev/gvisor/pkg/shim/v1/proc" + "gvisor.dev/gvisor/pkg/shim/v1/utils" + "gvisor.dev/gvisor/pkg/shim/v2/options" + "gvisor.dev/gvisor/pkg/shim/v2/runtimeoptions" + "gvisor.dev/gvisor/runsc/specutils" +) + +var ( + empty = &types.Empty{} + bufPool = sync.Pool{ + New: func() interface{} { + buffer := make([]byte, 32<<10) + return &buffer + }, + } +) + +var _ = (taskAPI.TaskService)(&service{}) + +// configFile is the default config file name. For containerd 1.2, +// we assume that a config.toml should exist in the runtime root. +const configFile = "config.toml" + +// New returns a new shim service that can be used via GRPC. +func New(ctx context.Context, id string, publisher shim.Publisher, cancel func()) (shim.Shim, error) { + ep, err := newOOMEpoller(publisher) + if err != nil { + return nil, err + } + go ep.run(ctx) + s := &service{ + id: id, + context: ctx, + processes: make(map[string]process.Process), + events: make(chan interface{}, 128), + ec: proc.ExitCh, + oomPoller: ep, + cancel: cancel, + } + go s.processExits() + runsc.Monitor = reaper.Default + if err := s.initPlatform(); err != nil { + cancel() + return nil, fmt.Errorf("failed to initialized platform behavior: %w", err) + } + go s.forward(publisher) + return s, nil +} + +// service is the shim implementation of a remote shim over GRPC. +type service struct { + mu sync.Mutex + + context context.Context + task process.Process + processes map[string]process.Process + events chan interface{} + platform stdio.Platform + opts options.Options + ec chan proc.Exit + oomPoller *epoller + + id string + bundle string + cancel func() +} + +func newCommand(ctx context.Context, containerdBinary, containerdAddress string) (*exec.Cmd, error) { + ns, err := namespaces.NamespaceRequired(ctx) + if err != nil { + return nil, err + } + self, err := os.Executable() + if err != nil { + return nil, err + } + cwd, err := os.Getwd() + if err != nil { + return nil, err + } + args := []string{ + "-namespace", ns, + "-address", containerdAddress, + "-publish-binary", containerdBinary, + } + cmd := exec.Command(self, args...) + cmd.Dir = cwd + cmd.Env = append(os.Environ(), "GOMAXPROCS=2") + cmd.SysProcAttr = &syscall.SysProcAttr{ + Setpgid: true, + } + return cmd, nil +} + +func (s *service) StartShim(ctx context.Context, id, containerdBinary, containerdAddress, containerdTTRPCAddress string) (string, error) { + cmd, err := newCommand(ctx, containerdBinary, containerdAddress) + if err != nil { + return "", err + } + address, err := shim.SocketAddress(ctx, id) + if err != nil { + return "", err + } + socket, err := shim.NewSocket(address) + if err != nil { + return "", err + } + defer socket.Close() + f, err := socket.File() + if err != nil { + return "", err + } + defer f.Close() + + cmd.ExtraFiles = append(cmd.ExtraFiles, f) + + if err := cmd.Start(); err != nil { + return "", err + } + defer func() { + if err != nil { + cmd.Process.Kill() + } + }() + // make sure to wait after start + go cmd.Wait() + if err := shim.WritePidFile("shim.pid", cmd.Process.Pid); err != nil { + return "", err + } + if err := shim.WriteAddress("address", address); err != nil { + return "", err + } + if err := shim.SetScore(cmd.Process.Pid); err != nil { + return "", fmt.Errorf("failed to set OOM Score on shim: %w", err) + } + return address, nil +} + +func (s *service) Cleanup(ctx context.Context) (*taskAPI.DeleteResponse, error) { + path, err := os.Getwd() + if err != nil { + return nil, err + } + ns, err := namespaces.NamespaceRequired(ctx) + if err != nil { + return nil, err + } + runtime, err := s.readRuntime(path) + if err != nil { + return nil, err + } + r := proc.NewRunsc(s.opts.Root, path, ns, runtime, nil) + if err := r.Delete(ctx, s.id, &runsc.DeleteOpts{ + Force: true, + }); err != nil { + log.L.Printf("failed to remove runc container: %v", err) + } + if err := mount.UnmountAll(filepath.Join(path, "rootfs"), 0); err != nil { + log.L.Printf("failed to cleanup rootfs mount: %v", err) + } + return &taskAPI.DeleteResponse{ + ExitedAt: time.Now(), + ExitStatus: 128 + uint32(unix.SIGKILL), + }, nil +} + +func (s *service) readRuntime(path string) (string, error) { + data, err := ioutil.ReadFile(filepath.Join(path, "runtime")) + if err != nil { + return "", err + } + return string(data), nil +} + +func (s *service) writeRuntime(path, runtime string) error { + return ioutil.WriteFile(filepath.Join(path, "runtime"), []byte(runtime), 0600) +} + +// Create creates a new initial process and container with the underlying OCI +// runtime. +func (s *service) Create(ctx context.Context, r *taskAPI.CreateTaskRequest) (_ *taskAPI.CreateTaskResponse, err error) { + s.mu.Lock() + defer s.mu.Unlock() + + ns, err := namespaces.NamespaceRequired(ctx) + if err != nil { + return nil, fmt.Errorf("create namespace: %w", err) + } + + // Read from root for now. + var opts options.Options + if r.Options != nil { + v, err := typeurl.UnmarshalAny(r.Options) + if err != nil { + return nil, err + } + var path string + switch o := v.(type) { + case *runctypes.CreateOptions: // containerd 1.2.x + opts.IoUid = o.IoUid + opts.IoGid = o.IoGid + opts.ShimCgroup = o.ShimCgroup + case *runctypes.RuncOptions: // containerd 1.2.x + root := proc.RunscRoot + if o.RuntimeRoot != "" { + root = o.RuntimeRoot + } + + opts.BinaryName = o.Runtime + + path = filepath.Join(root, configFile) + if _, err := os.Stat(path); err != nil { + if !os.IsNotExist(err) { + return nil, fmt.Errorf("stat config file %q: %w", path, err) + } + // A config file in runtime root is not required. + path = "" + } + case *runtimeoptions.Options: // containerd 1.3.x+ + if o.ConfigPath == "" { + break + } + if o.TypeUrl != options.OptionType { + return nil, fmt.Errorf("unsupported option type %q", o.TypeUrl) + } + path = o.ConfigPath + default: + return nil, fmt.Errorf("unsupported option type %q", r.Options.TypeUrl) + } + if path != "" { + if _, err = toml.DecodeFile(path, &opts); err != nil { + return nil, fmt.Errorf("decode config file %q: %w", path, err) + } + } + } + + var mounts []proc.Mount + for _, m := range r.Rootfs { + mounts = append(mounts, proc.Mount{ + Type: m.Type, + Source: m.Source, + Target: m.Target, + Options: m.Options, + }) + } + + rootfs := filepath.Join(r.Bundle, "rootfs") + if err := os.Mkdir(rootfs, 0711); err != nil && !os.IsExist(err) { + return nil, err + } + + config := &proc.CreateConfig{ + ID: r.ID, + Bundle: r.Bundle, + Runtime: opts.BinaryName, + Rootfs: mounts, + Terminal: r.Terminal, + Stdin: r.Stdin, + Stdout: r.Stdout, + Stderr: r.Stderr, + Options: r.Options, + } + if err := s.writeRuntime(r.Bundle, opts.BinaryName); err != nil { + return nil, err + } + defer func() { + if err != nil { + if err := mount.UnmountAll(rootfs, 0); err != nil { + log.L.Printf("failed to cleanup rootfs mount: %v", err) + } + } + }() + for _, rm := range mounts { + m := &mount.Mount{ + Type: rm.Type, + Source: rm.Source, + Options: rm.Options, + } + if err := m.Mount(rootfs); err != nil { + return nil, fmt.Errorf("failed to mount rootfs component %v: %w", m, err) + } + } + process, err := newInit( + ctx, + r.Bundle, + filepath.Join(r.Bundle, "work"), + ns, + s.platform, + config, + &opts, + rootfs, + ) + if err != nil { + return nil, errdefs.ToGRPC(err) + } + if err := process.Create(ctx, config); err != nil { + return nil, errdefs.ToGRPC(err) + } + // Save the main task id and bundle to the shim for additional + // requests. + s.id = r.ID + s.bundle = r.Bundle + + // Set up OOM notification on the sandbox's cgroup. This is done on + // sandbox create since the sandbox process will be created here. + pid := process.Pid() + if pid > 0 { + cg, err := cgroups.Load(cgroups.V1, cgroups.PidPath(pid)) + if err != nil { + return nil, fmt.Errorf("loading cgroup for %d: %w", pid, err) + } + if err := s.oomPoller.add(s.id, cg); err != nil { + return nil, fmt.Errorf("add cg to OOM monitor: %w", err) + } + } + s.task = process + s.opts = opts + return &taskAPI.CreateTaskResponse{ + Pid: uint32(process.Pid()), + }, nil + +} + +// Start starts a process. +func (s *service) Start(ctx context.Context, r *taskAPI.StartRequest) (*taskAPI.StartResponse, error) { + p, err := s.getProcess(r.ExecID) + if err != nil { + return nil, err + } + if err := p.Start(ctx); err != nil { + return nil, err + } + // TODO: Set the cgroup and oom notifications on restore. + // https://github.com/google/gvisor-containerd-shim/issues/58 + return &taskAPI.StartResponse{ + Pid: uint32(p.Pid()), + }, nil +} + +// Delete deletes the initial process and container. +func (s *service) Delete(ctx context.Context, r *taskAPI.DeleteRequest) (*taskAPI.DeleteResponse, error) { + p, err := s.getProcess(r.ExecID) + if err != nil { + return nil, err + } + if p == nil { + return nil, errdefs.ToGRPCf(errdefs.ErrFailedPrecondition, "container must be created") + } + if err := p.Delete(ctx); err != nil { + return nil, err + } + isTask := r.ExecID == "" + if !isTask { + s.mu.Lock() + delete(s.processes, r.ExecID) + s.mu.Unlock() + } + if isTask && s.platform != nil { + s.platform.Close() + } + return &taskAPI.DeleteResponse{ + ExitStatus: uint32(p.ExitStatus()), + ExitedAt: p.ExitedAt(), + Pid: uint32(p.Pid()), + }, nil +} + +// Exec spawns an additional process inside the container. +func (s *service) Exec(ctx context.Context, r *taskAPI.ExecProcessRequest) (*types.Empty, error) { + s.mu.Lock() + p := s.processes[r.ExecID] + s.mu.Unlock() + if p != nil { + return nil, errdefs.ToGRPCf(errdefs.ErrAlreadyExists, "id %s", r.ExecID) + } + p = s.task + if p == nil { + return nil, errdefs.ToGRPCf(errdefs.ErrFailedPrecondition, "container must be created") + } + process, err := p.(*proc.Init).Exec(ctx, s.bundle, &proc.ExecConfig{ + ID: r.ExecID, + Terminal: r.Terminal, + Stdin: r.Stdin, + Stdout: r.Stdout, + Stderr: r.Stderr, + Spec: r.Spec, + }) + if err != nil { + return nil, errdefs.ToGRPC(err) + } + s.mu.Lock() + s.processes[r.ExecID] = process + s.mu.Unlock() + return empty, nil +} + +// ResizePty resizes the terminal of a process. +func (s *service) ResizePty(ctx context.Context, r *taskAPI.ResizePtyRequest) (*types.Empty, error) { + p, err := s.getProcess(r.ExecID) + if err != nil { + return nil, err + } + ws := console.WinSize{ + Width: uint16(r.Width), + Height: uint16(r.Height), + } + if err := p.Resize(ws); err != nil { + return nil, errdefs.ToGRPC(err) + } + return empty, nil +} + +// State returns runtime state information for a process. +func (s *service) State(ctx context.Context, r *taskAPI.StateRequest) (*taskAPI.StateResponse, error) { + p, err := s.getProcess(r.ExecID) + if err != nil { + return nil, err + } + st, err := p.Status(ctx) + if err != nil { + return nil, err + } + status := task.StatusUnknown + switch st { + case "created": + status = task.StatusCreated + case "running": + status = task.StatusRunning + case "stopped": + status = task.StatusStopped + } + sio := p.Stdio() + return &taskAPI.StateResponse{ + ID: p.ID(), + Bundle: s.bundle, + Pid: uint32(p.Pid()), + Status: status, + Stdin: sio.Stdin, + Stdout: sio.Stdout, + Stderr: sio.Stderr, + Terminal: sio.Terminal, + ExitStatus: uint32(p.ExitStatus()), + ExitedAt: p.ExitedAt(), + }, nil +} + +// Pause the container. +func (s *service) Pause(ctx context.Context, r *taskAPI.PauseRequest) (*types.Empty, error) { + return empty, errdefs.ToGRPC(errdefs.ErrNotImplemented) +} + +// Resume the container. +func (s *service) Resume(ctx context.Context, r *taskAPI.ResumeRequest) (*types.Empty, error) { + return empty, errdefs.ToGRPC(errdefs.ErrNotImplemented) +} + +// Kill a process with the provided signal. +func (s *service) Kill(ctx context.Context, r *taskAPI.KillRequest) (*types.Empty, error) { + p, err := s.getProcess(r.ExecID) + if err != nil { + return nil, err + } + if p == nil { + return nil, errdefs.ToGRPCf(errdefs.ErrFailedPrecondition, "container must be created") + } + if err := p.Kill(ctx, r.Signal, r.All); err != nil { + return nil, errdefs.ToGRPC(err) + } + return empty, nil +} + +// Pids returns all pids inside the container. +func (s *service) Pids(ctx context.Context, r *taskAPI.PidsRequest) (*taskAPI.PidsResponse, error) { + pids, err := s.getContainerPids(ctx, r.ID) + if err != nil { + return nil, errdefs.ToGRPC(err) + } + var processes []*task.ProcessInfo + for _, pid := range pids { + pInfo := task.ProcessInfo{ + Pid: pid, + } + for _, p := range s.processes { + if p.Pid() == int(pid) { + d := &runctypes.ProcessDetails{ + ExecID: p.ID(), + } + a, err := typeurl.MarshalAny(d) + if err != nil { + return nil, fmt.Errorf("failed to marshal process %d info: %w", pid, err) + } + pInfo.Info = a + break + } + } + processes = append(processes, &pInfo) + } + return &taskAPI.PidsResponse{ + Processes: processes, + }, nil +} + +// CloseIO closes the I/O context of a process. +func (s *service) CloseIO(ctx context.Context, r *taskAPI.CloseIORequest) (*types.Empty, error) { + p, err := s.getProcess(r.ExecID) + if err != nil { + return nil, err + } + if stdin := p.Stdin(); stdin != nil { + if err := stdin.Close(); err != nil { + return nil, fmt.Errorf("close stdin: %w", err) + } + } + return empty, nil +} + +// Checkpoint checkpoints the container. +func (s *service) Checkpoint(ctx context.Context, r *taskAPI.CheckpointTaskRequest) (*types.Empty, error) { + return empty, errdefs.ToGRPC(errdefs.ErrNotImplemented) +} + +// Connect returns shim information such as the shim's pid. +func (s *service) Connect(ctx context.Context, r *taskAPI.ConnectRequest) (*taskAPI.ConnectResponse, error) { + var pid int + if s.task != nil { + pid = s.task.Pid() + } + return &taskAPI.ConnectResponse{ + ShimPid: uint32(os.Getpid()), + TaskPid: uint32(pid), + }, nil +} + +func (s *service) Shutdown(ctx context.Context, r *taskAPI.ShutdownRequest) (*types.Empty, error) { + s.cancel() + os.Exit(0) + return empty, nil +} + +func (s *service) Stats(ctx context.Context, r *taskAPI.StatsRequest) (*taskAPI.StatsResponse, error) { + path, err := os.Getwd() + if err != nil { + return nil, err + } + ns, err := namespaces.NamespaceRequired(ctx) + if err != nil { + return nil, err + } + runtime, err := s.readRuntime(path) + if err != nil { + return nil, err + } + rs := proc.NewRunsc(s.opts.Root, path, ns, runtime, nil) + stats, err := rs.Stats(ctx, s.id) + if err != nil { + return nil, err + } + + // gvisor currently (as of 2020-03-03) only returns the total memory + // usage and current PID value[0]. However, we copy the common fields here + // so that future updates will propagate correct information. We're + // using the cgroups.Metrics structure so we're returning the same type + // as runc. + // + // [0]: https://github.com/google/gvisor/blob/277a0d5a1fbe8272d4729c01ee4c6e374d047ebc/runsc/boot/events.go#L61-L81 + data, err := typeurl.MarshalAny(&cgroups.Metrics{ + CPU: &cgroups.CPUStat{ + Usage: &cgroups.CPUUsage{ + Total: stats.Cpu.Usage.Total, + Kernel: stats.Cpu.Usage.Kernel, + User: stats.Cpu.Usage.User, + PerCPU: stats.Cpu.Usage.Percpu, + }, + Throttling: &cgroups.Throttle{ + Periods: stats.Cpu.Throttling.Periods, + ThrottledPeriods: stats.Cpu.Throttling.ThrottledPeriods, + ThrottledTime: stats.Cpu.Throttling.ThrottledTime, + }, + }, + Memory: &cgroups.MemoryStat{ + Cache: stats.Memory.Cache, + Usage: &cgroups.MemoryEntry{ + Limit: stats.Memory.Usage.Limit, + Usage: stats.Memory.Usage.Usage, + Max: stats.Memory.Usage.Max, + Failcnt: stats.Memory.Usage.Failcnt, + }, + Swap: &cgroups.MemoryEntry{ + Limit: stats.Memory.Swap.Limit, + Usage: stats.Memory.Swap.Usage, + Max: stats.Memory.Swap.Max, + Failcnt: stats.Memory.Swap.Failcnt, + }, + Kernel: &cgroups.MemoryEntry{ + Limit: stats.Memory.Kernel.Limit, + Usage: stats.Memory.Kernel.Usage, + Max: stats.Memory.Kernel.Max, + Failcnt: stats.Memory.Kernel.Failcnt, + }, + KernelTCP: &cgroups.MemoryEntry{ + Limit: stats.Memory.KernelTCP.Limit, + Usage: stats.Memory.KernelTCP.Usage, + Max: stats.Memory.KernelTCP.Max, + Failcnt: stats.Memory.KernelTCP.Failcnt, + }, + }, + Pids: &cgroups.PidsStat{ + Current: stats.Pids.Current, + Limit: stats.Pids.Limit, + }, + }) + if err != nil { + return nil, err + } + return &taskAPI.StatsResponse{ + Stats: data, + }, nil +} + +// Update updates a running container. +func (s *service) Update(ctx context.Context, r *taskAPI.UpdateTaskRequest) (*types.Empty, error) { + return empty, errdefs.ToGRPC(errdefs.ErrNotImplemented) +} + +// Wait waits for a process to exit. +func (s *service) Wait(ctx context.Context, r *taskAPI.WaitRequest) (*taskAPI.WaitResponse, error) { + p, err := s.getProcess(r.ExecID) + if err != nil { + return nil, err + } + if p == nil { + return nil, errdefs.ToGRPCf(errdefs.ErrFailedPrecondition, "container must be created") + } + p.Wait() + + return &taskAPI.WaitResponse{ + ExitStatus: uint32(p.ExitStatus()), + ExitedAt: p.ExitedAt(), + }, nil +} + +func (s *service) processExits() { + for e := range s.ec { + s.checkProcesses(e) + } +} + +func (s *service) checkProcesses(e proc.Exit) { + // TODO(random-liu): Add `shouldKillAll` logic if container pid + // namespace is supported. + for _, p := range s.allProcesses() { + if p.ID() == e.ID { + if ip, ok := p.(*proc.Init); ok { + // Ensure all children are killed. + if err := ip.KillAll(s.context); err != nil { + log.G(s.context).WithError(err).WithField("id", ip.ID()). + Error("failed to kill init's children") + } + } + p.SetExited(e.Status) + s.events <- &events.TaskExit{ + ContainerID: s.id, + ID: p.ID(), + Pid: uint32(p.Pid()), + ExitStatus: uint32(e.Status), + ExitedAt: p.ExitedAt(), + } + return + } + } +} + +func (s *service) allProcesses() (o []process.Process) { + s.mu.Lock() + defer s.mu.Unlock() + for _, p := range s.processes { + o = append(o, p) + } + if s.task != nil { + o = append(o, s.task) + } + return o +} + +func (s *service) getContainerPids(ctx context.Context, id string) ([]uint32, error) { + s.mu.Lock() + p := s.task + s.mu.Unlock() + if p == nil { + return nil, fmt.Errorf("container must be created: %w", errdefs.ErrFailedPrecondition) + } + ps, err := p.(*proc.Init).Runtime().Ps(ctx, id) + if err != nil { + return nil, err + } + pids := make([]uint32, 0, len(ps)) + for _, pid := range ps { + pids = append(pids, uint32(pid)) + } + return pids, nil +} + +func (s *service) forward(publisher shim.Publisher) { + for e := range s.events { + ctx, cancel := context.WithTimeout(s.context, 5*time.Second) + err := publisher.Publish(ctx, getTopic(e), e) + cancel() + if err != nil { + // Should not happen. + panic(fmt.Errorf("post event: %w", err)) + } + } +} + +func (s *service) getProcess(execID string) (process.Process, error) { + s.mu.Lock() + defer s.mu.Unlock() + if execID == "" { + return s.task, nil + } + p := s.processes[execID] + if p == nil { + return nil, errdefs.ToGRPCf(errdefs.ErrNotFound, "process does not exist %s", execID) + } + return p, nil +} + +func getTopic(e interface{}) string { + switch e.(type) { + case *events.TaskCreate: + return runtime.TaskCreateEventTopic + case *events.TaskStart: + return runtime.TaskStartEventTopic + case *events.TaskOOM: + return runtime.TaskOOMEventTopic + case *events.TaskExit: + return runtime.TaskExitEventTopic + case *events.TaskDelete: + return runtime.TaskDeleteEventTopic + case *events.TaskExecAdded: + return runtime.TaskExecAddedEventTopic + case *events.TaskExecStarted: + return runtime.TaskExecStartedEventTopic + default: + log.L.Printf("no topic for type %#v", e) + } + return runtime.TaskUnknownTopic +} + +func newInit(ctx context.Context, path, workDir, namespace string, platform stdio.Platform, r *proc.CreateConfig, options *options.Options, rootfs string) (*proc.Init, error) { + spec, err := utils.ReadSpec(r.Bundle) + if err != nil { + return nil, fmt.Errorf("read oci spec: %w", err) + } + if err := utils.UpdateVolumeAnnotations(r.Bundle, spec); err != nil { + return nil, fmt.Errorf("update volume annotations: %w", err) + } + runsc.FormatLogPath(r.ID, options.RunscConfig) + runtime := proc.NewRunsc(options.Root, path, namespace, options.BinaryName, options.RunscConfig) + p := proc.New(r.ID, runtime, stdio.Stdio{ + Stdin: r.Stdin, + Stdout: r.Stdout, + Stderr: r.Stderr, + Terminal: r.Terminal, + }) + p.Bundle = r.Bundle + p.Platform = platform + p.Rootfs = rootfs + p.WorkDir = workDir + p.IoUID = int(options.IoUid) + p.IoGID = int(options.IoGid) + p.Sandbox = specutils.SpecContainerType(spec) == specutils.ContainerTypeSandbox + p.UserLog = utils.UserLogPath(spec) + p.Monitor = reaper.Default + return p, nil +} diff --git a/pkg/shim/v2/service_linux.go b/pkg/shim/v2/service_linux.go new file mode 100644 index 000000000..1800ab90b --- /dev/null +++ b/pkg/shim/v2/service_linux.go @@ -0,0 +1,108 @@ +// Copyright 2018 The containerd Authors. +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build linux + +package v2 + +import ( + "context" + "fmt" + "io" + "sync" + "syscall" + + "github.com/containerd/console" + "github.com/containerd/fifo" +) + +type linuxPlatform struct { + epoller *console.Epoller +} + +func (p *linuxPlatform) CopyConsole(ctx context.Context, console console.Console, stdin, stdout, stderr string, wg *sync.WaitGroup) (console.Console, error) { + if p.epoller == nil { + return nil, fmt.Errorf("uninitialized epoller") + } + + epollConsole, err := p.epoller.Add(console) + if err != nil { + return nil, err + } + + if stdin != "" { + in, err := fifo.OpenFifo(context.Background(), stdin, syscall.O_RDONLY|syscall.O_NONBLOCK, 0) + if err != nil { + return nil, err + } + go func() { + p := bufPool.Get().(*[]byte) + defer bufPool.Put(p) + io.CopyBuffer(epollConsole, in, *p) + }() + } + + outw, err := fifo.OpenFifo(ctx, stdout, syscall.O_WRONLY, 0) + if err != nil { + return nil, err + } + outr, err := fifo.OpenFifo(ctx, stdout, syscall.O_RDONLY, 0) + if err != nil { + return nil, err + } + wg.Add(1) + go func() { + p := bufPool.Get().(*[]byte) + defer bufPool.Put(p) + io.CopyBuffer(outw, epollConsole, *p) + epollConsole.Close() + outr.Close() + outw.Close() + wg.Done() + }() + return epollConsole, nil +} + +func (p *linuxPlatform) ShutdownConsole(ctx context.Context, cons console.Console) error { + if p.epoller == nil { + return fmt.Errorf("uninitialized epoller") + } + epollConsole, ok := cons.(*console.EpollConsole) + if !ok { + return fmt.Errorf("expected EpollConsole, got %#v", cons) + } + return epollConsole.Shutdown(p.epoller.CloseConsole) +} + +func (p *linuxPlatform) Close() error { + return p.epoller.Close() +} + +// initialize a single epoll fd to manage our consoles. `initPlatform` should +// only be called once. +func (s *service) initPlatform() error { + if s.platform != nil { + return nil + } + epoller, err := console.NewEpoller() + if err != nil { + return fmt.Errorf("failed to initialize epoller: %w", err) + } + s.platform = &linuxPlatform{ + epoller: epoller, + } + go epoller.Wait() + return nil +} diff --git a/pkg/sleep/BUILD b/pkg/sleep/BUILD index e131455f7..ae0fe1522 100644 --- a/pkg/sleep/BUILD +++ b/pkg/sleep/BUILD @@ -12,6 +12,7 @@ go_library( "sleep_unsafe.go", ], visibility = ["//:sandbox"], + deps = ["//pkg/sync"], ) go_test( diff --git a/pkg/sleep/sleep_test.go b/pkg/sleep/sleep_test.go index af47e2ba1..1dd11707d 100644 --- a/pkg/sleep/sleep_test.go +++ b/pkg/sleep/sleep_test.go @@ -379,10 +379,7 @@ func TestRace(t *testing.T) { // TestRaceInOrder tests that multiple wakers can continuously send wake requests to // the sleeper and that the wakers are retrieved in the order asserted. func TestRaceInOrder(t *testing.T) { - const wakers = 100 - const wakeRequests = 10000 - - w := make([]Waker, wakers) + w := make([]Waker, 10000) s := Sleeper{} // Associate each waker and start goroutines that will assert them. @@ -390,19 +387,16 @@ func TestRaceInOrder(t *testing.T) { s.AddWaker(&w[i], i) } go func() { - n := 0 - for n < wakeRequests { - wk := w[n%len(w)] - wk.Assert() - n++ + for i := range w { + w[i].Assert() } }() // Wait for all wake up notifications from all wakers. - for i := 0; i < wakeRequests; i++ { - v, _ := s.Fetch(true) - if got, want := v, i%wakers; got != want { - t.Fatalf("got %d want %d", got, want) + for want := range w { + got, _ := s.Fetch(true) + if got != want { + t.Fatalf("got %d want %d", got, want) } } } diff --git a/pkg/sleep/sleep_unsafe.go b/pkg/sleep/sleep_unsafe.go index f68c12620..118805492 100644 --- a/pkg/sleep/sleep_unsafe.go +++ b/pkg/sleep/sleep_unsafe.go @@ -75,6 +75,8 @@ package sleep import ( "sync/atomic" "unsafe" + + "gvisor.dev/gvisor/pkg/sync" ) const ( @@ -323,7 +325,12 @@ func (s *Sleeper) enqueueAssertedWaker(w *Waker) { // // This struct is thread-safe, that is, its methods can be called concurrently // by multiple goroutines. +// +// Note, it is not safe to copy a Waker as its fields are modified by value +// (the pointer fields are individually modified with atomic operations). type Waker struct { + _ sync.NoCopy + // s is the sleeper that this waker can wake up. Only one sleeper at a // time is allowed. This field can have three classes of values: // nil -- the waker is not asserted: it either is not associated with diff --git a/pkg/sync/BUILD b/pkg/sync/BUILD index d0d77e19c..4d47207f7 100644 --- a/pkg/sync/BUILD +++ b/pkg/sync/BUILD @@ -33,6 +33,7 @@ go_library( "aliases.go", "memmove_unsafe.go", "mutex_unsafe.go", + "nocopy.go", "norace_unsafe.go", "race_unsafe.go", "rwmutex_unsafe.go", diff --git a/pkg/sync/nocopy.go b/pkg/sync/nocopy.go new file mode 100644 index 000000000..722b29501 --- /dev/null +++ b/pkg/sync/nocopy.go @@ -0,0 +1,28 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sync + +// NoCopy may be embedded into structs which must not be copied +// after the first use. +// +// See https://golang.org/issues/8005#issuecomment-190753527 +// for details. +type NoCopy struct{} + +// Lock is a no-op used by -copylocks checker from `go vet`. +func (*NoCopy) Lock() {} + +// Unlock is a no-op used by -copylocks checker from `go vet`. +func (*NoCopy) Unlock() {} diff --git a/pkg/syserr/netstack.go b/pkg/syserr/netstack.go index 8ff922c69..5ae10939d 100644 --- a/pkg/syserr/netstack.go +++ b/pkg/syserr/netstack.go @@ -22,7 +22,7 @@ import ( // Mapping for tcpip.Error types. var ( ErrUnknownProtocol = New(tcpip.ErrUnknownProtocol.String(), linux.EINVAL) - ErrUnknownNICID = New(tcpip.ErrUnknownNICID.String(), linux.EINVAL) + ErrUnknownNICID = New(tcpip.ErrUnknownNICID.String(), linux.ENODEV) ErrUnknownDevice = New(tcpip.ErrUnknownDevice.String(), linux.ENODEV) ErrUnknownProtocolOption = New(tcpip.ErrUnknownProtocolOption.String(), linux.ENOPROTOOPT) ErrDuplicateNICID = New(tcpip.ErrDuplicateNICID.String(), linux.EEXIST) diff --git a/pkg/tcpip/header/BUILD b/pkg/tcpip/header/BUILD index 0cde694dc..d87797617 100644 --- a/pkg/tcpip/header/BUILD +++ b/pkg/tcpip/header/BUILD @@ -48,7 +48,7 @@ go_test( "//pkg/rand", "//pkg/tcpip", "//pkg/tcpip/buffer", - "@com_github_google_go-cmp//cmp:go_default_library", + "@com_github_google_go_cmp//cmp:go_default_library", ], ) @@ -64,6 +64,6 @@ go_test( deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", - "@com_github_google_go-cmp//cmp:go_default_library", + "@com_github_google_go_cmp//cmp:go_default_library", ], ) diff --git a/pkg/tcpip/header/arp.go b/pkg/tcpip/header/arp.go index 718a4720a..83189676e 100644 --- a/pkg/tcpip/header/arp.go +++ b/pkg/tcpip/header/arp.go @@ -14,14 +14,33 @@ package header -import "gvisor.dev/gvisor/pkg/tcpip" +import ( + "encoding/binary" + + "gvisor.dev/gvisor/pkg/tcpip" +) const ( // ARPProtocolNumber is the ARP network protocol number. ARPProtocolNumber tcpip.NetworkProtocolNumber = 0x0806 // ARPSize is the size of an IPv4-over-Ethernet ARP packet. - ARPSize = 2 + 2 + 1 + 1 + 2 + 2*6 + 2*4 + ARPSize = 28 +) + +// ARPHardwareType is the hardware type for LinkEndpoint in an ARP header. +type ARPHardwareType uint16 + +// Typical ARP HardwareType values. Some of the constants have to be specific +// values as they are egressed on the wire in the HTYPE field of an ARP header. +const ( + ARPHardwareNone ARPHardwareType = 0 + // ARPHardwareEther specifically is the HTYPE for Ethernet as specified + // in the IANA list here: + // + // https://www.iana.org/assignments/arp-parameters/arp-parameters.xhtml#arp-parameters-2 + ARPHardwareEther ARPHardwareType = 1 + ARPHardwareLoopback ARPHardwareType = 2 ) // ARPOp is an ARP opcode. @@ -36,54 +55,64 @@ const ( // ARP is an ARP packet stored in a byte array as described in RFC 826. type ARP []byte -func (a ARP) hardwareAddressSpace() uint16 { return uint16(a[0])<<8 | uint16(a[1]) } -func (a ARP) protocolAddressSpace() uint16 { return uint16(a[2])<<8 | uint16(a[3]) } -func (a ARP) hardwareAddressSize() int { return int(a[4]) } -func (a ARP) protocolAddressSize() int { return int(a[5]) } +const ( + hTypeOffset = 0 + protocolOffset = 2 + haAddressSizeOffset = 4 + protoAddressSizeOffset = 5 + opCodeOffset = 6 + senderHAAddressOffset = 8 + senderProtocolAddressOffset = senderHAAddressOffset + EthernetAddressSize + targetHAAddressOffset = senderProtocolAddressOffset + IPv4AddressSize + targetProtocolAddressOffset = targetHAAddressOffset + EthernetAddressSize +) + +func (a ARP) hardwareAddressType() ARPHardwareType { + return ARPHardwareType(binary.BigEndian.Uint16(a[hTypeOffset:])) +} + +func (a ARP) protocolAddressSpace() uint16 { return binary.BigEndian.Uint16(a[protocolOffset:]) } +func (a ARP) hardwareAddressSize() int { return int(a[haAddressSizeOffset]) } +func (a ARP) protocolAddressSize() int { return int(a[protoAddressSizeOffset]) } // Op is the ARP opcode. -func (a ARP) Op() ARPOp { return ARPOp(a[6])<<8 | ARPOp(a[7]) } +func (a ARP) Op() ARPOp { return ARPOp(binary.BigEndian.Uint16(a[opCodeOffset:])) } // SetOp sets the ARP opcode. func (a ARP) SetOp(op ARPOp) { - a[6] = uint8(op >> 8) - a[7] = uint8(op) + binary.BigEndian.PutUint16(a[opCodeOffset:], uint16(op)) } // SetIPv4OverEthernet configures the ARP packet for IPv4-over-Ethernet. func (a ARP) SetIPv4OverEthernet() { - a[0], a[1] = 0, 1 // htypeEthernet - a[2], a[3] = 0x08, 0x00 // IPv4ProtocolNumber - a[4] = 6 // macSize - a[5] = uint8(IPv4AddressSize) + binary.BigEndian.PutUint16(a[hTypeOffset:], uint16(ARPHardwareEther)) + binary.BigEndian.PutUint16(a[protocolOffset:], uint16(IPv4ProtocolNumber)) + a[haAddressSizeOffset] = EthernetAddressSize + a[protoAddressSizeOffset] = uint8(IPv4AddressSize) } // HardwareAddressSender is the link address of the sender. // It is a view on to the ARP packet so it can be used to set the value. func (a ARP) HardwareAddressSender() []byte { - const s = 8 - return a[s : s+6] + return a[senderHAAddressOffset : senderHAAddressOffset+EthernetAddressSize] } // ProtocolAddressSender is the protocol address of the sender. // It is a view on to the ARP packet so it can be used to set the value. func (a ARP) ProtocolAddressSender() []byte { - const s = 8 + 6 - return a[s : s+4] + return a[senderProtocolAddressOffset : senderProtocolAddressOffset+IPv4AddressSize] } // HardwareAddressTarget is the link address of the target. // It is a view on to the ARP packet so it can be used to set the value. func (a ARP) HardwareAddressTarget() []byte { - const s = 8 + 6 + 4 - return a[s : s+6] + return a[targetHAAddressOffset : targetHAAddressOffset+EthernetAddressSize] } // ProtocolAddressTarget is the protocol address of the target. // It is a view on to the ARP packet so it can be used to set the value. func (a ARP) ProtocolAddressTarget() []byte { - const s = 8 + 6 + 4 + 6 - return a[s : s+4] + return a[targetProtocolAddressOffset : targetProtocolAddressOffset+IPv4AddressSize] } // IsValid reports whether this is an ARP packet for IPv4 over Ethernet. @@ -91,10 +120,8 @@ func (a ARP) IsValid() bool { if len(a) < ARPSize { return false } - const htypeEthernet = 1 - const macSize = 6 - return a.hardwareAddressSpace() == htypeEthernet && + return a.hardwareAddressType() == ARPHardwareEther && a.protocolAddressSpace() == uint16(IPv4ProtocolNumber) && - a.hardwareAddressSize() == macSize && + a.hardwareAddressSize() == EthernetAddressSize && a.protocolAddressSize() == IPv4AddressSize } diff --git a/pkg/tcpip/header/eth.go b/pkg/tcpip/header/eth.go index b1e92d2d7..eaface8cb 100644 --- a/pkg/tcpip/header/eth.go +++ b/pkg/tcpip/header/eth.go @@ -53,6 +53,10 @@ const ( // (all bits set to 0). unspecifiedEthernetAddress = tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00") + // EthernetBroadcastAddress is an ethernet address that addresses every node + // on a local link. + EthernetBroadcastAddress = tcpip.LinkAddress("\xff\xff\xff\xff\xff\xff") + // unicastMulticastFlagMask is the mask of the least significant bit in // the first octet (in network byte order) of an ethernet address that // determines whether the ethernet address is a unicast or multicast. If diff --git a/pkg/tcpip/header/icmpv4.go b/pkg/tcpip/header/icmpv4.go index 7908c5744..1a631b31a 100644 --- a/pkg/tcpip/header/icmpv4.go +++ b/pkg/tcpip/header/icmpv4.go @@ -72,6 +72,7 @@ const ( // Values for ICMP code as defined in RFC 792. const ( ICMPv4TTLExceeded = 0 + ICMPv4HostUnreachable = 1 ICMPv4PortUnreachable = 3 ICMPv4FragmentationNeeded = 4 ) diff --git a/pkg/tcpip/header/icmpv6.go b/pkg/tcpip/header/icmpv6.go index c7ee2de57..a13b4b809 100644 --- a/pkg/tcpip/header/icmpv6.go +++ b/pkg/tcpip/header/icmpv6.go @@ -110,9 +110,16 @@ const ( ICMPv6RedirectMsg ICMPv6Type = 137 ) -// Values for ICMP code as defined in RFC 4443. +// Values for ICMP destination unreachable code as defined in RFC 4443 section +// 3.1. const ( - ICMPv6PortUnreachable = 4 + ICMPv6NetworkUnreachable = 0 + ICMPv6Prohibited = 1 + ICMPv6BeyondScope = 2 + ICMPv6AddressUnreachable = 3 + ICMPv6PortUnreachable = 4 + ICMPv6Policy = 5 + ICMPv6RejectRoute = 6 ) // Type is the ICMP type field. diff --git a/pkg/tcpip/link/channel/BUILD b/pkg/tcpip/link/channel/BUILD index b8b93e78e..39ca774ef 100644 --- a/pkg/tcpip/link/channel/BUILD +++ b/pkg/tcpip/link/channel/BUILD @@ -10,6 +10,7 @@ go_library( "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", + "//pkg/tcpip/header", "//pkg/tcpip/stack", ], ) diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go index 20b183da0..e12a5929b 100644 --- a/pkg/tcpip/link/channel/channel.go +++ b/pkg/tcpip/link/channel/channel.go @@ -23,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -296,3 +297,12 @@ func (e *Endpoint) AddNotify(notify Notification) *NotificationHandle { func (e *Endpoint) RemoveNotify(handle *NotificationHandle) { e.q.RemoveNotify(handle) } + +// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. +func (*Endpoint) ARPHardwareType() header.ARPHardwareType { + return header.ARPHardwareNone +} + +// AddHeader implements stack.LinkEndpoint.AddHeader. +func (e *Endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { +} diff --git a/pkg/tcpip/link/fdbased/BUILD b/pkg/tcpip/link/fdbased/BUILD index aa6db9aea..507b44abc 100644 --- a/pkg/tcpip/link/fdbased/BUILD +++ b/pkg/tcpip/link/fdbased/BUILD @@ -15,6 +15,7 @@ go_library( visibility = ["//visibility:public"], deps = [ "//pkg/binary", + "//pkg/iovec", "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go index f34082e1a..c18bb91fb 100644 --- a/pkg/tcpip/link/fdbased/endpoint.go +++ b/pkg/tcpip/link/fdbased/endpoint.go @@ -45,6 +45,7 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/iovec" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -385,26 +386,35 @@ const ( _VIRTIO_NET_HDR_GSO_TCPV6 = 4 ) -// WritePacket writes outbound packets to the file descriptor. If it is not -// currently writable, the packet is dropped. -func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { +// AddHeader implements stack.LinkEndpoint.AddHeader. +func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { if e.hdrSize > 0 { // Add ethernet header if needed. eth := header.Ethernet(pkt.Header.Prepend(header.EthernetMinimumSize)) pkt.LinkHeader = buffer.View(eth) ethHdr := &header.EthernetFields{ - DstAddr: r.RemoteLinkAddress, + DstAddr: remote, Type: protocol, } // Preserve the src address if it's set in the route. - if r.LocalLinkAddress != "" { - ethHdr.SrcAddr = r.LocalLinkAddress + if local != "" { + ethHdr.SrcAddr = local } else { ethHdr.SrcAddr = e.addr } eth.Encode(ethHdr) } +} + +// WritePacket writes outbound packets to the file descriptor. If it is not +// currently writable, the packet is dropped. +func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { + if e.hdrSize > 0 { + e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt) + } + + var builder iovec.Builder fd := e.fds[pkt.Hash%uint32(len(e.fds))] if e.Capabilities()&stack.CapabilityHardwareGSO != 0 { @@ -430,47 +440,28 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne } vnetHdrBuf := binary.Marshal(make([]byte, 0, virtioNetHdrSize), binary.LittleEndian, vnetHdr) - return rawfile.NonBlockingWrite3(fd, vnetHdrBuf, pkt.Header.View(), pkt.Data.ToView()) + builder.Add(vnetHdrBuf) } - if pkt.Data.Size() == 0 { - return rawfile.NonBlockingWrite(fd, pkt.Header.View()) - } - if pkt.Header.UsedLength() == 0 { - return rawfile.NonBlockingWrite(fd, pkt.Data.ToView()) + builder.Add(pkt.Header.View()) + for _, v := range pkt.Data.Views() { + builder.Add(v) } - return rawfile.NonBlockingWrite3(fd, pkt.Header.View(), pkt.Data.ToView(), nil) + return rawfile.NonBlockingWriteIovec(fd, builder.Build()) } func (e *endpoint) sendBatch(batchFD int, batch []*stack.PacketBuffer) (int, *tcpip.Error) { // Send a batch of packets through batchFD. mmsgHdrs := make([]rawfile.MMsgHdr, 0, len(batch)) for _, pkt := range batch { - var ethHdrBuf []byte - iovLen := 0 if e.hdrSize > 0 { - // Add ethernet header if needed. - ethHdrBuf = make([]byte, header.EthernetMinimumSize) - eth := header.Ethernet(ethHdrBuf) - ethHdr := &header.EthernetFields{ - DstAddr: pkt.EgressRoute.RemoteLinkAddress, - Type: pkt.NetworkProtocolNumber, - } - - // Preserve the src address if it's set in the route. - if pkt.EgressRoute.LocalLinkAddress != "" { - ethHdr.SrcAddr = pkt.EgressRoute.LocalLinkAddress - } else { - ethHdr.SrcAddr = e.addr - } - eth.Encode(ethHdr) - iovLen++ + e.AddHeader(pkt.EgressRoute.LocalLinkAddress, pkt.EgressRoute.RemoteLinkAddress, pkt.NetworkProtocolNumber, pkt) } - vnetHdr := virtioNetHdr{} var vnetHdrBuf []byte if e.Capabilities()&stack.CapabilityHardwareGSO != 0 { + vnetHdr := virtioNetHdr{} if pkt.GSOOptions != nil { vnetHdr.hdrLen = uint16(pkt.Header.UsedLength()) if pkt.GSOOptions.NeedsCsum { @@ -491,45 +482,19 @@ func (e *endpoint) sendBatch(batchFD int, batch []*stack.PacketBuffer) (int, *tc } } vnetHdrBuf = binary.Marshal(make([]byte, 0, virtioNetHdrSize), binary.LittleEndian, vnetHdr) - iovLen++ } - iovecs := make([]syscall.Iovec, iovLen+1+len(pkt.Data.Views())) + var builder iovec.Builder + builder.Add(vnetHdrBuf) + builder.Add(pkt.Header.View()) + for _, v := range pkt.Data.Views() { + builder.Add(v) + } + iovecs := builder.Build() + var mmsgHdr rawfile.MMsgHdr mmsgHdr.Msg.Iov = &iovecs[0] - iovecIdx := 0 - if vnetHdrBuf != nil { - v := &iovecs[iovecIdx] - v.Base = &vnetHdrBuf[0] - v.Len = uint64(len(vnetHdrBuf)) - iovecIdx++ - } - if ethHdrBuf != nil { - v := &iovecs[iovecIdx] - v.Base = ðHdrBuf[0] - v.Len = uint64(len(ethHdrBuf)) - iovecIdx++ - } - pktSize := uint64(0) - // Encode L3 Header - v := &iovecs[iovecIdx] - hdr := &pkt.Header - hdrView := hdr.View() - v.Base = &hdrView[0] - v.Len = uint64(len(hdrView)) - pktSize += v.Len - iovecIdx++ - - // Now encode the Transport Payload. - pktViews := pkt.Data.Views() - for i := range pktViews { - vec := &iovecs[iovecIdx] - iovecIdx++ - vec.Base = &pktViews[i][0] - vec.Len = uint64(len(pktViews[i])) - pktSize += vec.Len - } - mmsgHdr.Msg.Iovlen = uint64(iovecIdx) + mmsgHdr.Msg.Iovlen = uint64(len(iovecs)) mmsgHdrs = append(mmsgHdrs, mmsgHdr) } @@ -626,6 +591,14 @@ func (e *endpoint) GSOMaxSize() uint32 { return e.gsoMaxSize } +// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. +func (e *endpoint) ARPHardwareType() header.ARPHardwareType { + if e.hdrSize > 0 { + return header.ARPHardwareEther + } + return header.ARPHardwareNone +} + // InjectableEndpoint is an injectable fd-based endpoint. The endpoint writes // to the FD, but does not read from it. All reads come from injected packets. type InjectableEndpoint struct { diff --git a/pkg/tcpip/link/fdbased/endpoint_test.go b/pkg/tcpip/link/fdbased/endpoint_test.go index eaee7e5d7..7b995b85a 100644 --- a/pkg/tcpip/link/fdbased/endpoint_test.go +++ b/pkg/tcpip/link/fdbased/endpoint_test.go @@ -107,6 +107,10 @@ func (c *context) DeliverNetworkPacket(remote tcpip.LinkAddress, local tcpip.Lin c.ch <- packetInfo{remote, protocol, pkt} } +func (c *context) DeliverOutboundPacket(remote tcpip.LinkAddress, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + panic("unimplemented") +} + func TestNoEthernetProperties(t *testing.T) { c := newContext(t, &Options{MTU: mtu}) defer c.cleanup() @@ -500,3 +504,80 @@ func TestRecvMMsgDispatcherCapLength(t *testing.T) { } } + +// fakeNetworkDispatcher delivers packets to pkts. +type fakeNetworkDispatcher struct { + pkts []*stack.PacketBuffer +} + +func (d *fakeNetworkDispatcher) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + d.pkts = append(d.pkts, pkt) +} + +func (d *fakeNetworkDispatcher) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + panic("unimplemented") +} + +func TestDispatchPacketFormat(t *testing.T) { + for _, test := range []struct { + name string + newDispatcher func(fd int, e *endpoint) (linkDispatcher, error) + }{ + { + name: "readVDispatcher", + newDispatcher: newReadVDispatcher, + }, + { + name: "recvMMsgDispatcher", + newDispatcher: newRecvMMsgDispatcher, + }, + } { + t.Run(test.name, func(t *testing.T) { + // Create a socket pair to send/recv. + fds, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_DGRAM, 0) + if err != nil { + t.Fatal(err) + } + defer syscall.Close(fds[0]) + defer syscall.Close(fds[1]) + + data := []byte{ + // Ethernet header. + 1, 2, 3, 4, 5, 60, + 1, 2, 3, 4, 5, 61, + 8, 0, + // Mock network header. + 40, 41, 42, 43, + } + err = syscall.Sendmsg(fds[1], data, nil, nil, 0) + if err != nil { + t.Fatal(err) + } + + // Create and run dispatcher once. + sink := &fakeNetworkDispatcher{} + d, err := test.newDispatcher(fds[0], &endpoint{ + hdrSize: header.EthernetMinimumSize, + dispatcher: sink, + }) + if err != nil { + t.Fatal(err) + } + if ok, err := d.dispatch(); !ok || err != nil { + t.Fatalf("d.dispatch() = %v, %v", ok, err) + } + + // Verify packet. + if got, want := len(sink.pkts), 1; got != want { + t.Fatalf("len(sink.pkts) = %d, want %d", got, want) + } + pkt := sink.pkts[0] + if got, want := len(pkt.LinkHeader), header.EthernetMinimumSize; got != want { + t.Errorf("len(pkt.LinkHeader) = %d, want %d", got, want) + } + if got, want := pkt.Data.Size(), 4; got != want { + t.Errorf("pkt.Data.Size() = %d, want %d", got, want) + } + }) + } +} diff --git a/pkg/tcpip/link/fdbased/packet_dispatchers.go b/pkg/tcpip/link/fdbased/packet_dispatchers.go index f04738cfb..d8f2504b3 100644 --- a/pkg/tcpip/link/fdbased/packet_dispatchers.go +++ b/pkg/tcpip/link/fdbased/packet_dispatchers.go @@ -278,7 +278,7 @@ func (d *recvMMsgDispatcher) dispatch() (bool, *tcpip.Error) { eth header.Ethernet ) if d.e.hdrSize > 0 { - eth = header.Ethernet(d.views[k][0]) + eth = header.Ethernet(d.views[k][0][:header.EthernetMinimumSize]) p = eth.Type() remote = eth.SourceAddress() local = eth.DestinationAddress() diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go index 568c6874f..781cdd317 100644 --- a/pkg/tcpip/link/loopback/loopback.go +++ b/pkg/tcpip/link/loopback/loopback.go @@ -113,3 +113,11 @@ func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { return nil } + +// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. +func (*endpoint) ARPHardwareType() header.ARPHardwareType { + return header.ARPHardwareLoopback +} + +func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { +} diff --git a/pkg/tcpip/link/muxed/BUILD b/pkg/tcpip/link/muxed/BUILD index 82b441b79..e7493e5c5 100644 --- a/pkg/tcpip/link/muxed/BUILD +++ b/pkg/tcpip/link/muxed/BUILD @@ -9,6 +9,7 @@ go_library( deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", + "//pkg/tcpip/header", "//pkg/tcpip/stack", ], ) diff --git a/pkg/tcpip/link/muxed/injectable.go b/pkg/tcpip/link/muxed/injectable.go index c69d6b7e9..56a611825 100644 --- a/pkg/tcpip/link/muxed/injectable.go +++ b/pkg/tcpip/link/muxed/injectable.go @@ -18,6 +18,7 @@ package muxed import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -129,6 +130,15 @@ func (m *InjectableEndpoint) Wait() { } } +// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. +func (*InjectableEndpoint) ARPHardwareType() header.ARPHardwareType { + panic("unsupported operation") +} + +// AddHeader implements stack.LinkEndpoint.AddHeader. +func (*InjectableEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { +} + // NewInjectableEndpoint creates a new multi-endpoint injectable endpoint. func NewInjectableEndpoint(routes map[tcpip.Address]stack.InjectableLinkEndpoint) *InjectableEndpoint { return &InjectableEndpoint{ diff --git a/pkg/tcpip/link/nested/BUILD b/pkg/tcpip/link/nested/BUILD index bdd5276ad..2cdb23475 100644 --- a/pkg/tcpip/link/nested/BUILD +++ b/pkg/tcpip/link/nested/BUILD @@ -12,6 +12,7 @@ go_library( "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", + "//pkg/tcpip/header", "//pkg/tcpip/stack", ], ) diff --git a/pkg/tcpip/link/nested/nested.go b/pkg/tcpip/link/nested/nested.go index 2998f9c4f..d40de54df 100644 --- a/pkg/tcpip/link/nested/nested.go +++ b/pkg/tcpip/link/nested/nested.go @@ -20,6 +20,7 @@ import ( "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -60,6 +61,16 @@ func (e *Endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protoco } } +// DeliverOutboundPacket implements stack.NetworkDispatcher.DeliverOutboundPacket. +func (e *Endpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + e.mu.RLock() + d := e.dispatcher + e.mu.RUnlock() + if d != nil { + d.DeliverOutboundPacket(remote, local, protocol, pkt) + } +} + // Attach implements stack.LinkEndpoint. func (e *Endpoint) Attach(dispatcher stack.NetworkDispatcher) { e.mu.Lock() @@ -129,3 +140,13 @@ func (e *Endpoint) GSOMaxSize() uint32 { } return 0 } + +// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType +func (e *Endpoint) ARPHardwareType() header.ARPHardwareType { + return e.child.ARPHardwareType() +} + +// AddHeader implements stack.LinkEndpoint.AddHeader. +func (e *Endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + e.child.AddHeader(local, remote, protocol, pkt) +} diff --git a/pkg/tcpip/link/nested/nested_test.go b/pkg/tcpip/link/nested/nested_test.go index c1a219f02..7d9249c1c 100644 --- a/pkg/tcpip/link/nested/nested_test.go +++ b/pkg/tcpip/link/nested/nested_test.go @@ -55,6 +55,10 @@ func (d *counterDispatcher) DeliverNetworkPacket(tcpip.LinkAddress, tcpip.LinkAd d.count++ } +func (d *counterDispatcher) DeliverOutboundPacket(tcpip.LinkAddress, tcpip.LinkAddress, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) { + panic("unimplemented") +} + func TestNestedLinkEndpoint(t *testing.T) { const emptyAddress = tcpip.LinkAddress("") diff --git a/pkg/tcpip/link/packetsocket/BUILD b/pkg/tcpip/link/packetsocket/BUILD new file mode 100644 index 000000000..6fff160ce --- /dev/null +++ b/pkg/tcpip/link/packetsocket/BUILD @@ -0,0 +1,14 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "packetsocket", + srcs = ["endpoint.go"], + visibility = ["//visibility:public"], + deps = [ + "//pkg/tcpip", + "//pkg/tcpip/link/nested", + "//pkg/tcpip/stack", + ], +) diff --git a/pkg/tcpip/link/packetsocket/endpoint.go b/pkg/tcpip/link/packetsocket/endpoint.go new file mode 100644 index 000000000..3922c2a04 --- /dev/null +++ b/pkg/tcpip/link/packetsocket/endpoint.go @@ -0,0 +1,50 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package packetsocket provides a link layer endpoint that provides the ability +// to loop outbound packets to any AF_PACKET sockets that may be interested in +// the outgoing packet. +package packetsocket + +import ( + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/link/nested" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +type endpoint struct { + nested.Endpoint +} + +// New creates a new packetsocket LinkEndpoint. +func New(lower stack.LinkEndpoint) stack.LinkEndpoint { + e := &endpoint{} + e.Endpoint.Init(lower, e) + return e +} + +// WritePacket implements stack.LinkEndpoint.WritePacket. +func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { + e.Endpoint.DeliverOutboundPacket(r.RemoteLinkAddress, r.LocalLinkAddress, protocol, pkt) + return e.Endpoint.WritePacket(r, gso, protocol, pkt) +} + +// WritePackets implements stack.LinkEndpoint.WritePackets. +func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, proto tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { + e.Endpoint.DeliverOutboundPacket(pkt.EgressRoute.RemoteLinkAddress, pkt.EgressRoute.LocalLinkAddress, pkt.NetworkProtocolNumber, pkt) + } + + return e.Endpoint.WritePackets(r, gso, pkts, proto) +} diff --git a/pkg/tcpip/link/qdisc/fifo/BUILD b/pkg/tcpip/link/qdisc/fifo/BUILD index 054c213bc..1d0079bd6 100644 --- a/pkg/tcpip/link/qdisc/fifo/BUILD +++ b/pkg/tcpip/link/qdisc/fifo/BUILD @@ -14,6 +14,7 @@ go_library( "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", + "//pkg/tcpip/header", "//pkg/tcpip/stack", ], ) diff --git a/pkg/tcpip/link/qdisc/fifo/endpoint.go b/pkg/tcpip/link/qdisc/fifo/endpoint.go index b5dfb7850..467083239 100644 --- a/pkg/tcpip/link/qdisc/fifo/endpoint.go +++ b/pkg/tcpip/link/qdisc/fifo/endpoint.go @@ -22,6 +22,7 @@ import ( "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -106,6 +107,11 @@ func (e *endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protoco e.dispatcher.DeliverNetworkPacket(remote, local, protocol, pkt) } +// DeliverOutboundPacket implements stack.NetworkDispatcher.DeliverOutboundPacket. +func (e *endpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + e.dispatcher.DeliverOutboundPacket(remote, local, protocol, pkt) +} + // Attach implements stack.LinkEndpoint.Attach. func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) { e.dispatcher = dispatcher @@ -193,6 +199,8 @@ func (e *endpoint) WritePackets(_ *stack.Route, _ *stack.GSO, pkts stack.PacketB // WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { + // TODO(gvisor.dev/issue/3267/): Queue these packets as well once + // WriteRawPacket takes PacketBuffer instead of VectorisedView. return e.lower.WriteRawPacket(vv) } @@ -207,3 +215,13 @@ func (e *endpoint) Wait() { e.wg.Wait() } + +// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType +func (e *endpoint) ARPHardwareType() header.ARPHardwareType { + return e.lower.ARPHardwareType() +} + +// AddHeader implements stack.LinkEndpoint.AddHeader. +func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + e.lower.AddHeader(local, remote, protocol, pkt) +} diff --git a/pkg/tcpip/link/rawfile/rawfile_unsafe.go b/pkg/tcpip/link/rawfile/rawfile_unsafe.go index 44e25d475..f4c32c2da 100644 --- a/pkg/tcpip/link/rawfile/rawfile_unsafe.go +++ b/pkg/tcpip/link/rawfile/rawfile_unsafe.go @@ -66,39 +66,14 @@ func NonBlockingWrite(fd int, buf []byte) *tcpip.Error { return nil } -// NonBlockingWrite3 writes up to three byte slices to a file descriptor in a -// single syscall. It fails if partial data is written. -func NonBlockingWrite3(fd int, b1, b2, b3 []byte) *tcpip.Error { - // If the is no second buffer, issue a regular write. - if len(b2) == 0 { - return NonBlockingWrite(fd, b1) - } - - // We have two buffers. Build the iovec that represents them and issue - // a writev syscall. - iovec := [3]syscall.Iovec{ - { - Base: &b1[0], - Len: uint64(len(b1)), - }, - { - Base: &b2[0], - Len: uint64(len(b2)), - }, - } - iovecLen := uintptr(2) - - if len(b3) > 0 { - iovecLen++ - iovec[2].Base = &b3[0] - iovec[2].Len = uint64(len(b3)) - } - +// NonBlockingWriteIovec writes iovec to a file descriptor in a single syscall. +// It fails if partial data is written. +func NonBlockingWriteIovec(fd int, iovec []syscall.Iovec) *tcpip.Error { + iovecLen := uintptr(len(iovec)) _, _, e := syscall.RawSyscall(syscall.SYS_WRITEV, uintptr(fd), uintptr(unsafe.Pointer(&iovec[0])), iovecLen) if e != 0 { return TranslateErrno(e) } - return nil } diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go index 0374a2441..507c76b76 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem.go +++ b/pkg/tcpip/link/sharedmem/sharedmem.go @@ -183,22 +183,29 @@ func (e *endpoint) LinkAddress() tcpip.LinkAddress { return e.addr } -// WritePacket writes outbound packets to the file descriptor. If it is not -// currently writable, the packet is dropped. -func (e *endpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { - // Add the ethernet header here. +// AddHeader implements stack.LinkEndpoint.AddHeader. +func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + // Add ethernet header if needed. eth := header.Ethernet(pkt.Header.Prepend(header.EthernetMinimumSize)) pkt.LinkHeader = buffer.View(eth) ethHdr := &header.EthernetFields{ - DstAddr: r.RemoteLinkAddress, + DstAddr: remote, Type: protocol, } - if r.LocalLinkAddress != "" { - ethHdr.SrcAddr = r.LocalLinkAddress + + // Preserve the src address if it's set in the route. + if local != "" { + ethHdr.SrcAddr = local } else { ethHdr.SrcAddr = e.addr } eth.Encode(ethHdr) +} + +// WritePacket writes outbound packets to the file descriptor. If it is not +// currently writable, the packet is dropped. +func (e *endpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { + e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt) v := pkt.Data.ToView() // Transmit the packet. @@ -287,3 +294,8 @@ func (e *endpoint) dispatchLoop(d stack.NetworkDispatcher) { e.completed.Done() } + +// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType +func (*endpoint) ARPHardwareType() header.ARPHardwareType { + return header.ARPHardwareEther +} diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go index 28a2e88ba..8f3cd9449 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem_test.go +++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go @@ -143,6 +143,10 @@ func (c *testContext) DeliverNetworkPacket(remoteLinkAddr, localLinkAddr tcpip.L c.packetCh <- struct{}{} } +func (c *testContext) DeliverOutboundPacket(remoteLinkAddr, localLinkAddr tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + panic("unimplemented") +} + func (c *testContext) cleanup() { c.ep.Close() closeFDs(&c.txCfg) diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go index d9cd4e83a..509076643 100644 --- a/pkg/tcpip/link/sniffer/sniffer.go +++ b/pkg/tcpip/link/sniffer/sniffer.go @@ -123,6 +123,11 @@ func (e *endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protoco e.Endpoint.DeliverNetworkPacket(remote, local, protocol, pkt) } +// DeliverOutboundPacket implements stack.NetworkDispatcher.DeliverOutboundPacket. +func (e *endpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + e.Endpoint.DeliverOutboundPacket(remote, local, protocol, pkt) +} + func (e *endpoint) dumpPacket(prefix string, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { writer := e.writer if writer == nil && atomic.LoadUint32(&LogPackets) == 1 { diff --git a/pkg/tcpip/link/tun/device.go b/pkg/tcpip/link/tun/device.go index 6bc9033d0..04ae58e59 100644 --- a/pkg/tcpip/link/tun/device.go +++ b/pkg/tcpip/link/tun/device.go @@ -139,6 +139,7 @@ func attachOrCreateNIC(s *stack.Stack, name, prefix string, linkCaps stack.LinkE stack: s, nicID: id, name: name, + isTap: prefix == "tap", } endpoint.Endpoint.LinkEPCapabilities = linkCaps if endpoint.name == "" { @@ -271,21 +272,9 @@ func (d *Device) encodePkt(info *channel.PacketInfo) (buffer.View, bool) { if d.hasFlags(linux.IFF_TAP) { // Add ethernet header if not provided. if info.Pkt.LinkHeader == nil { - hdr := &header.EthernetFields{ - SrcAddr: info.Route.LocalLinkAddress, - DstAddr: info.Route.RemoteLinkAddress, - Type: info.Proto, - } - if hdr.SrcAddr == "" { - hdr.SrcAddr = d.endpoint.LinkAddress() - } - - eth := make(header.Ethernet, header.EthernetMinimumSize) - eth.Encode(hdr) - vv.AppendView(buffer.View(eth)) - } else { - vv.AppendView(info.Pkt.LinkHeader) + d.endpoint.AddHeader(info.Route.LocalLinkAddress, info.Route.RemoteLinkAddress, info.Proto, info.Pkt) } + vv.AppendView(info.Pkt.LinkHeader) } // Append upper headers. @@ -348,6 +337,7 @@ type tunEndpoint struct { stack *stack.Stack nicID tcpip.NICID name string + isTap bool } // DecRef decrements refcount of e, removes NIC if refcount goes to 0. @@ -356,3 +346,38 @@ func (e *tunEndpoint) DecRef() { e.stack.RemoveNIC(e.nicID) }) } + +// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. +func (e *tunEndpoint) ARPHardwareType() header.ARPHardwareType { + if e.isTap { + return header.ARPHardwareEther + } + return header.ARPHardwareNone +} + +// AddHeader implements stack.LinkEndpoint.AddHeader. +func (e *tunEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + if !e.isTap { + return + } + eth := header.Ethernet(pkt.Header.Prepend(header.EthernetMinimumSize)) + pkt.LinkHeader = buffer.View(eth) + hdr := &header.EthernetFields{ + SrcAddr: local, + DstAddr: remote, + Type: protocol, + } + if hdr.SrcAddr == "" { + hdr.SrcAddr = e.LinkAddress() + } + + eth.Encode(hdr) +} + +// MaxHeaderLength returns the maximum size of the link layer header. +func (e *tunEndpoint) MaxHeaderLength() uint16 { + if e.isTap { + return header.EthernetMinimumSize + } + return 0 +} diff --git a/pkg/tcpip/link/waitable/BUILD b/pkg/tcpip/link/waitable/BUILD index 0956d2c65..ee84c3d96 100644 --- a/pkg/tcpip/link/waitable/BUILD +++ b/pkg/tcpip/link/waitable/BUILD @@ -12,6 +12,7 @@ go_library( "//pkg/gate", "//pkg/tcpip", "//pkg/tcpip/buffer", + "//pkg/tcpip/header", "//pkg/tcpip/stack", ], ) @@ -25,6 +26,7 @@ go_test( deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", + "//pkg/tcpip/header", "//pkg/tcpip/stack", ], ) diff --git a/pkg/tcpip/link/waitable/waitable.go b/pkg/tcpip/link/waitable/waitable.go index 949b3f2b2..b152a0f26 100644 --- a/pkg/tcpip/link/waitable/waitable.go +++ b/pkg/tcpip/link/waitable/waitable.go @@ -25,6 +25,7 @@ import ( "gvisor.dev/gvisor/pkg/gate" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -59,6 +60,15 @@ func (e *Endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protoco e.dispatchGate.Leave() } +// DeliverOutboundPacket implements stack.NetworkDispatcher.DeliverOutboundPacket. +func (e *Endpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + if !e.dispatchGate.Enter() { + return + } + e.dispatcher.DeliverOutboundPacket(remote, local, protocol, pkt) + e.dispatchGate.Leave() +} + // Attach implements stack.LinkEndpoint.Attach. It saves the dispatcher and // registers with the lower endpoint as its dispatcher so that "e" is called // for inbound packets. @@ -147,3 +157,13 @@ func (e *Endpoint) WaitDispatch() { // Wait implements stack.LinkEndpoint.Wait. func (e *Endpoint) Wait() {} + +// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. +func (e *Endpoint) ARPHardwareType() header.ARPHardwareType { + return e.lower.ARPHardwareType() +} + +// AddHeader implements stack.LinkEndpoint.AddHeader. +func (e *Endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + e.lower.AddHeader(local, remote, protocol, pkt) +} diff --git a/pkg/tcpip/link/waitable/waitable_test.go b/pkg/tcpip/link/waitable/waitable_test.go index 63bf40562..c448a888f 100644 --- a/pkg/tcpip/link/waitable/waitable_test.go +++ b/pkg/tcpip/link/waitable/waitable_test.go @@ -19,6 +19,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -39,6 +40,10 @@ func (e *countedEndpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, e.dispatchCount++ } +func (e *countedEndpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + panic("unimplemented") +} + func (e *countedEndpoint) Attach(dispatcher stack.NetworkDispatcher) { e.attachCount++ e.dispatcher = dispatcher @@ -81,9 +86,19 @@ func (e *countedEndpoint) WriteRawPacket(buffer.VectorisedView) *tcpip.Error { return nil } +// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. +func (*countedEndpoint) ARPHardwareType() header.ARPHardwareType { + panic("unimplemented") +} + // Wait implements stack.LinkEndpoint.Wait. func (*countedEndpoint) Wait() {} +// AddHeader implements stack.LinkEndpoint.AddHeader. +func (e *countedEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + panic("unimplemented") +} + func TestWaitWrite(t *testing.T) { ep := &countedEndpoint{} wep := New(ep) diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index 7f27a840d..b0f57040c 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -162,7 +162,7 @@ func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { // LinkAddressRequest implements stack.LinkAddressResolver.LinkAddressRequest. func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.LinkEndpoint) *tcpip.Error { r := &stack.Route{ - RemoteLinkAddress: broadcastMAC, + RemoteLinkAddress: header.EthernetBroadcastAddress, } hdr := buffer.NewPrependable(int(linkEP.MaxHeaderLength()) + header.ARPSize) @@ -181,7 +181,7 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack. // ResolveStaticAddress implements stack.LinkAddressResolver.ResolveStaticAddress. func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) { if addr == header.IPv4Broadcast { - return broadcastMAC, true + return header.EthernetBroadcastAddress, true } if header.IsV4MulticastAddress(addr) { return header.EthernetAddressFromMulticastIPv4Address(addr), true @@ -216,8 +216,6 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNu return 0, false, true } -var broadcastMAC = tcpip.LinkAddress([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff}) - // NewProtocol returns an ARP network protocol. func NewProtocol() stack.NetworkProtocol { return &protocol{} diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index 7c8fb3e0a..615bae648 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -172,14 +172,24 @@ func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.Ne } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (t *testObject) WritePackets(_ *stack.Route, _ *stack.GSO, pkt stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (*testObject) WritePackets(_ *stack.Route, _ *stack.GSO, pkt stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { panic("not implemented") } -func (t *testObject) WriteRawPacket(_ buffer.VectorisedView) *tcpip.Error { +func (*testObject) WriteRawPacket(_ buffer.VectorisedView) *tcpip.Error { return tcpip.ErrNotSupported } +// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. +func (*testObject) ARPHardwareType() header.ARPHardwareType { + panic("not implemented") +} + +// AddHeader implements stack.LinkEndpoint.AddHeader. +func (*testObject) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + panic("not implemented") +} + func buildIPv4Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD index 78420d6e6..d142b4ffa 100644 --- a/pkg/tcpip/network/ipv4/BUILD +++ b/pkg/tcpip/network/ipv4/BUILD @@ -34,6 +34,6 @@ go_test( "//pkg/tcpip/transport/tcp", "//pkg/tcpip/transport/udp", "//pkg/waiter", - "@com_github_google_go-cmp//cmp:go_default_library", + "@com_github_google_go_cmp//cmp:go_default_library", ], ) diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index 1b67aa066..83e71cb8c 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -129,6 +129,9 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) { pkt.Data.TrimFront(header.ICMPv4MinimumSize) switch h.Code() { + case header.ICMPv4HostUnreachable: + e.handleControl(stack.ControlNoRoute, 0, pkt) + case header.ICMPv4PortUnreachable: e.handleControl(stack.ControlPortUnreachable, 0, pkt) diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 7e9f16c90..b1776e5ee 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -225,12 +225,10 @@ func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int, func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadSize int, params stack.NetworkHeaderParams) header.IPv4 { ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) length := uint16(hdr.UsedLength() + payloadSize) - id := uint32(0) - if length > header.IPv4MaximumHeaderSize+8 { - // Packets of 68 bytes or less are required by RFC 791 to not be - // fragmented, so we only assign ids to larger packets. - id = atomic.AddUint32(&e.protocol.ids[hashRoute(r, params.Protocol, e.protocol.hashIV)%buckets], 1) - } + // RFC 6864 section 4.3 mandates uniqueness of ID values for non-atomic + // datagrams. Since the DF bit is never being set here, all datagrams + // are non-atomic and need an ID. + id := atomic.AddUint32(&e.protocol.ids[hashRoute(r, params.Protocol, e.protocol.hashIV)%buckets], 1) ip.Encode(&header.IPv4Fields{ IHL: header.IPv4MinimumSize, TotalLength: length, @@ -376,13 +374,12 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu // Set the packet ID when zero. if ip.ID() == 0 { - id := uint32(0) - if pkt.Data.Size() > header.IPv4MaximumHeaderSize+8 { - // Packets of 68 bytes or less are required by RFC 791 to not be - // fragmented, so we only assign ids to larger packets. - id = atomic.AddUint32(&e.protocol.ids[hashRoute(r, 0 /* protocol */, e.protocol.hashIV)%buckets], 1) + // RFC 6864 section 4.3 mandates uniqueness of ID values for + // non-atomic datagrams, so assign an ID to all such datagrams + // according to the definition given in RFC 6864 section 4. + if ip.Flags()&header.IPv4FlagDontFragment == 0 || ip.Flags()&header.IPv4FlagMoreFragments != 0 || ip.FragmentOffset() > 0 { + ip.SetID(uint16(atomic.AddUint32(&e.protocol.ids[hashRoute(r, 0 /* protocol */, e.protocol.hashIV)%buckets], 1))) } - ip.SetID(uint16(id)) } // Always set the checksum. diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD index 3f71fc520..feada63dc 100644 --- a/pkg/tcpip/network/ipv6/BUILD +++ b/pkg/tcpip/network/ipv6/BUILD @@ -39,6 +39,6 @@ go_test( "//pkg/tcpip/transport/icmp", "//pkg/tcpip/transport/udp", "//pkg/waiter", - "@com_github_google_go-cmp//cmp:go_default_library", + "@com_github_google_go_cmp//cmp:go_default_library", ], ) diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index 2ff7eedf4..ff1cb53dd 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -128,6 +128,8 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme } pkt.Data.TrimFront(header.ICMPv6DstUnreachableMinimumSize) switch header.ICMPv6(hdr).Code() { + case header.ICMPv6NetworkUnreachable: + e.handleControl(stack.ControlNetworkUnreachable, 0, pkt) case header.ICMPv6PortUnreachable: e.handleControl(stack.ControlPortUnreachable, 0, pkt) } @@ -494,8 +496,6 @@ const ( icmpV6LengthOffset = 25 ) -var broadcastMAC = tcpip.LinkAddress([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff}) - var _ stack.LinkAddressResolver = (*protocol)(nil) // LinkAddressProtocol implements stack.LinkAddressResolver. diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index 794ddb5c8..6b9a6b316 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -27,6 +27,18 @@ go_template_instance( }, ) +go_template_instance( + name = "tuple_list", + out = "tuple_list.go", + package = "stack", + prefix = "tuple", + template = "//pkg/ilist:generic_list", + types = { + "Element": "*tuple", + "Linker": "*tuple", + }, +) + go_library( name = "stack", srcs = [ @@ -35,6 +47,7 @@ go_library( "forwarder.go", "icmp_rate_limit.go", "iptables.go", + "iptables_state.go", "iptables_targets.go", "iptables_types.go", "linkaddrcache.go", @@ -50,6 +63,7 @@ go_library( "stack_global_state.go", "stack_options.go", "transport_demuxer.go", + "tuple_list.go", ], visibility = ["//visibility:public"], deps = [ @@ -79,6 +93,7 @@ go_test( "transport_demuxer_test.go", "transport_test.go", ], + shard_count = 20, deps = [ ":stack", "//pkg/rand", @@ -94,7 +109,7 @@ go_test( "//pkg/tcpip/transport/icmp", "//pkg/tcpip/transport/udp", "//pkg/waiter", - "@com_github_google_go-cmp//cmp:go_default_library", + "@com_github_google_go_cmp//cmp:go_default_library", ], ) diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go index af9c325ca..559a1c4dd 100644 --- a/pkg/tcpip/stack/conntrack.go +++ b/pkg/tcpip/stack/conntrack.go @@ -15,9 +15,12 @@ package stack import ( + "encoding/binary" "sync" + "time" "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/transport/tcpconntrack" ) @@ -30,6 +33,10 @@ import ( // // Currently, only TCP tracking is supported. +// Our hash table has 16K buckets. +// TODO(gvisor.dev/issue/170): These should be tunable. +const numBuckets = 1 << 14 + // Direction of the tuple. type direction int @@ -42,13 +49,19 @@ const ( type manipType int const ( - manipDstPrerouting manipType = iota + manipNone manipType = iota + manipDstPrerouting manipDstOutput ) // tuple holds a connection's identifying and manipulating data in one // direction. It is immutable. +// +// +stateify savable type tuple struct { + // tupleEntry is used to build an intrusive list of tuples. + tupleEntry + tupleID // conn is the connection tracking entry this tuple belongs to. @@ -61,6 +74,8 @@ type tuple struct { // tupleID uniquely identifies a connection in one direction. It currently // contains enough information to distinguish between any TCP or UDP // connection, and will need to be extended to support other protocols. +// +// +stateify savable type tupleID struct { srcAddr tcpip.Address srcPort uint16 @@ -83,6 +98,8 @@ func (ti tupleID) reply() tupleID { } // conn is a tracked connection. +// +// +stateify savable type conn struct { // original is the tuple in original direction. It is immutable. original tuple @@ -97,23 +114,84 @@ type conn struct { // update the state of tcb. It is immutable. tcbHook Hook - // mu protects tcb. - mu sync.Mutex - + // mu protects all mutable state. + mu sync.Mutex `state:"nosave"` // tcb is TCB control block. It is used to keep track of states // of tcp connection and is protected by mu. tcb tcpconntrack.TCB + // lastUsed is the last time the connection saw a relevant packet, and + // is updated by each packet on the connection. It is protected by mu. + lastUsed time.Time `state:".(unixTime)"` +} + +// timedOut returns whether the connection timed out based on its state. +func (cn *conn) timedOut(now time.Time) bool { + const establishedTimeout = 5 * 24 * time.Hour + const defaultTimeout = 120 * time.Second + cn.mu.Lock() + defer cn.mu.Unlock() + if cn.tcb.State() == tcpconntrack.ResultAlive { + // Use the same default as Linux, which doesn't delete + // established connections for 5(!) days. + return now.Sub(cn.lastUsed) > establishedTimeout + } + // Use the same default as Linux, which lets connections in most states + // other than established remain for <= 120 seconds. + return now.Sub(cn.lastUsed) > defaultTimeout +} + +// update the connection tracking state. +// +// Precondition: ct.mu must be held. +func (ct *conn) updateLocked(tcpHeader header.TCP, hook Hook) { + // Update the state of tcb. tcb assumes it's always initialized on the + // client. However, we only need to know whether the connection is + // established or not, so the client/server distinction isn't important. + // TODO(gvisor.dev/issue/170): Add support in tcpconntrack to handle + // other tcp states. + if ct.tcb.IsEmpty() { + ct.tcb.Init(tcpHeader) + } else if hook == ct.tcbHook { + ct.tcb.UpdateStateOutbound(tcpHeader) + } else { + ct.tcb.UpdateStateInbound(tcpHeader) + } } // ConnTrack tracks all connections created for NAT rules. Most users are -// expected to only call handlePacket and createConnFor. +// expected to only call handlePacket, insertRedirectConn, and maybeInsertNoop. +// +// ConnTrack keeps all connections in a slice of buckets, each of which holds a +// linked list of tuples. This gives us some desirable properties: +// - Each bucket has its own lock, lessening lock contention. +// - The slice is large enough that lists stay short (<10 elements on average). +// Thus traversal is fast. +// - During linked list traversal we reap expired connections. This amortizes +// the cost of reaping them and makes reapUnused faster. +// +// Locks are ordered by their location in the buckets slice. That is, a +// goroutine that locks buckets[i] can only lock buckets[j] s.t. i < j. +// +// +stateify savable type ConnTrack struct { - // mu protects conns. - mu sync.RWMutex + // seed is a one-time random value initialized at stack startup + // and is used in the calculation of hash keys for the list of buckets. + // It is immutable. + seed uint32 - // conns maintains a map of tuples needed for connection tracking for - // iptables NAT rules. It is protected by mu. - conns map[tupleID]tuple + // mu protects the buckets slice, but not buckets' contents. Only take + // the write lock if you are modifying the slice or saving for S/R. + mu sync.RWMutex `state:"nosave"` + + // buckets is protected by mu. + buckets []bucket +} + +// +stateify savable +type bucket struct { + // mu protects tuples. + mu sync.Mutex `state:"nosave"` + tuples tupleList } // packetToTupleID converts packet to a tuple ID. It fails when pkt lacks a valid @@ -143,8 +221,9 @@ func packetToTupleID(pkt *PacketBuffer) (tupleID, *tcpip.Error) { // newConn creates new connection. func newConn(orig, reply tupleID, manip manipType, hook Hook) *conn { conn := conn{ - manip: manip, - tcbHook: hook, + manip: manip, + tcbHook: hook, + lastUsed: time.Now(), } conn.original = tuple{conn: &conn, tupleID: orig} conn.reply = tuple{conn: &conn, tupleID: reply, direction: dirReply} @@ -162,18 +241,31 @@ func (ct *ConnTrack) connFor(pkt *PacketBuffer) (*conn, direction) { return nil, dirOriginal } - ct.mu.Lock() - defer ct.mu.Unlock() - - tuple, ok := ct.conns[tid] - if !ok { - return nil, dirOriginal + bucket := ct.bucket(tid) + now := time.Now() + + ct.mu.RLock() + defer ct.mu.RUnlock() + ct.buckets[bucket].mu.Lock() + defer ct.buckets[bucket].mu.Unlock() + + // Iterate over the tuples in a bucket, cleaning up any unused + // connections we find. + for other := ct.buckets[bucket].tuples.Front(); other != nil; other = other.Next() { + // Clean up any timed-out connections we happen to find. + if ct.reapTupleLocked(other, bucket, now) { + // The tuple expired. + continue + } + if tid == other.tupleID { + return other.conn, other.direction + } } - return tuple.conn, tuple.direction + + return nil, dirOriginal } -// createConnFor creates a new conn for pkt. -func (ct *ConnTrack) createConnFor(pkt *PacketBuffer, hook Hook, rt RedirectTarget) *conn { +func (ct *ConnTrack) insertRedirectConn(pkt *PacketBuffer, hook Hook, rt RedirectTarget) *conn { tid, err := packetToTupleID(pkt) if err != nil { return nil @@ -196,21 +288,59 @@ func (ct *ConnTrack) createConnFor(pkt *PacketBuffer, hook Hook, rt RedirectTarg manip = manipDstOutput } conn := newConn(tid, replyTID, manip, hook) + ct.insertConn(conn) + return conn +} - // Add the changed tuple to the map. - // TODO(gvisor.dev/issue/170): Need to support collisions using linked - // list. - ct.mu.Lock() - defer ct.mu.Unlock() - ct.conns[tid] = conn.original - ct.conns[replyTID] = conn.reply +// insertConn inserts conn into the appropriate table bucket. +func (ct *ConnTrack) insertConn(conn *conn) { + // Lock the buckets in the correct order. + tupleBucket := ct.bucket(conn.original.tupleID) + replyBucket := ct.bucket(conn.reply.tupleID) + ct.mu.RLock() + defer ct.mu.RUnlock() + if tupleBucket < replyBucket { + ct.buckets[tupleBucket].mu.Lock() + ct.buckets[replyBucket].mu.Lock() + } else if tupleBucket > replyBucket { + ct.buckets[replyBucket].mu.Lock() + ct.buckets[tupleBucket].mu.Lock() + } else { + // Both tuples are in the same bucket. + ct.buckets[tupleBucket].mu.Lock() + } - return conn + // Now that we hold the locks, ensure the tuple hasn't been inserted by + // another thread. + alreadyInserted := false + for other := ct.buckets[tupleBucket].tuples.Front(); other != nil; other = other.Next() { + if other.tupleID == conn.original.tupleID { + alreadyInserted = true + break + } + } + + if !alreadyInserted { + // Add the tuple to the map. + ct.buckets[tupleBucket].tuples.PushFront(&conn.original) + ct.buckets[replyBucket].tuples.PushFront(&conn.reply) + } + + // Unlocking can happen in any order. + ct.buckets[tupleBucket].mu.Unlock() + if tupleBucket != replyBucket { + ct.buckets[replyBucket].mu.Unlock() + } } // handlePacketPrerouting manipulates ports for packets in Prerouting hook. // TODO(gvisor.dev/issue/170): Change address for Prerouting hook. func handlePacketPrerouting(pkt *PacketBuffer, conn *conn, dir direction) { + // If this is a noop entry, don't do anything. + if conn.manip == manipNone { + return + } + netHeader := header.IPv4(pkt.NetworkHeader) tcpHeader := header.TCP(pkt.TransportHeader) @@ -228,12 +358,22 @@ func handlePacketPrerouting(pkt *PacketBuffer, conn *conn, dir direction) { netHeader.SetSourceAddress(conn.original.dstAddr) } + // TODO(gvisor.dev/issue/170): TCP checksums aren't usually validated + // on inbound packets, so we don't recalculate them. However, we should + // support cases when they are validated, e.g. when we can't offload + // receive checksumming. + netHeader.SetChecksum(0) netHeader.SetChecksum(^netHeader.CalculateChecksum()) } // handlePacketOutput manipulates ports for packets in Output hook. func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir direction) { + // If this is a noop entry, don't do anything. + if conn.manip == manipNone { + return + } + netHeader := header.IPv4(pkt.NetworkHeader) tcpHeader := header.TCP(pkt.TransportHeader) @@ -268,20 +408,31 @@ func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir d } // handlePacket will manipulate the port and address of the packet if the -// connection exists. -func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Route) { +// connection exists. Returns whether, after the packet traverses the tables, +// it should create a new entry in the table. +func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Route) bool { if pkt.NatDone { - return + return false } if hook != Prerouting && hook != Output { - return + return false + } + + // TODO(gvisor.dev/issue/170): Support other transport protocols. + if pkt.NetworkHeader == nil || header.IPv4(pkt.NetworkHeader).TransportProtocol() != header.TCPProtocolNumber { + return false } conn, dir := ct.connFor(pkt) + // Connection or Rule not found for the packet. if conn == nil { - // Connection not found for the packet or the packet is invalid. - return + return true + } + + tcpHeader := header.TCP(pkt.TransportHeader) + if tcpHeader == nil { + return false } switch hook { @@ -297,35 +448,159 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Rou // other tcp states. conn.mu.Lock() defer conn.mu.Unlock() - var st tcpconntrack.Result - tcpHeader := header.TCP(pkt.TransportHeader) - if conn.tcb.IsEmpty() { - conn.tcb.Init(tcpHeader) - conn.tcbHook = hook - } else { - switch hook { - case conn.tcbHook: - st = conn.tcb.UpdateStateOutbound(tcpHeader) - default: - st = conn.tcb.UpdateStateInbound(tcpHeader) - } + + // Mark the connection as having been used recently so it isn't reaped. + conn.lastUsed = time.Now() + // Update connection state. + conn.updateLocked(header.TCP(pkt.TransportHeader), hook) + + return false +} + +// maybeInsertNoop tries to insert a no-op connection entry to keep connections +// from getting clobbered when replies arrive. It only inserts if there isn't +// already a connection for pkt. +// +// This should be called after traversing iptables rules only, to ensure that +// pkt.NatDone is set correctly. +func (ct *ConnTrack) maybeInsertNoop(pkt *PacketBuffer, hook Hook) { + // If there were a rule applying to this packet, it would be marked + // with NatDone. + if pkt.NatDone { + return } - // Delete conn if tcp connection is closed. - if st == tcpconntrack.ResultClosedByPeer || st == tcpconntrack.ResultClosedBySelf || st == tcpconntrack.ResultReset { - ct.deleteConn(conn) + // We only track TCP connections. + if pkt.NetworkHeader == nil || header.IPv4(pkt.NetworkHeader).TransportProtocol() != header.TCPProtocolNumber { + return } -} -// deleteConn deletes the connection. -func (ct *ConnTrack) deleteConn(conn *conn) { - if conn == nil { + // This is the first packet we're seeing for the TCP connection. Insert + // the noop entry (an identity mapping) so that the response doesn't + // get NATed, breaking the connection. + tid, err := packetToTupleID(pkt) + if err != nil { return } + conn := newConn(tid, tid.reply(), manipNone, hook) + conn.updateLocked(header.TCP(pkt.TransportHeader), hook) + ct.insertConn(conn) +} + +// bucket gets the conntrack bucket for a tupleID. +func (ct *ConnTrack) bucket(id tupleID) int { + h := jenkins.Sum32(ct.seed) + h.Write([]byte(id.srcAddr)) + h.Write([]byte(id.dstAddr)) + shortBuf := make([]byte, 2) + binary.LittleEndian.PutUint16(shortBuf, id.srcPort) + h.Write([]byte(shortBuf)) + binary.LittleEndian.PutUint16(shortBuf, id.dstPort) + h.Write([]byte(shortBuf)) + binary.LittleEndian.PutUint16(shortBuf, uint16(id.transProto)) + h.Write([]byte(shortBuf)) + binary.LittleEndian.PutUint16(shortBuf, uint16(id.netProto)) + h.Write([]byte(shortBuf)) + ct.mu.RLock() + defer ct.mu.RUnlock() + return int(h.Sum32()) % len(ct.buckets) +} + +// reapUnused deletes timed out entries from the conntrack map. The rules for +// reaping are: +// - Most reaping occurs in connFor, which is called on each packet. connFor +// cleans up the bucket the packet's connection maps to. Thus calls to +// reapUnused should be fast. +// - Each call to reapUnused traverses a fraction of the conntrack table. +// Specifically, it traverses len(ct.buckets)/fractionPerReaping. +// - After reaping, reapUnused decides when it should next run based on the +// ratio of expired connections to examined connections. If the ratio is +// greater than maxExpiredPct, it schedules the next run quickly. Otherwise it +// slightly increases the interval between runs. +// - maxFullTraversal caps the time it takes to traverse the entire table. +// +// reapUnused returns the next bucket that should be checked and the time after +// which it should be called again. +func (ct *ConnTrack) reapUnused(start int, prevInterval time.Duration) (int, time.Duration) { + // TODO(gvisor.dev/issue/170): This can be more finely controlled, as + // it is in Linux via sysctl. + const fractionPerReaping = 128 + const maxExpiredPct = 50 + const maxFullTraversal = 60 * time.Second + const minInterval = 10 * time.Millisecond + const maxInterval = maxFullTraversal / fractionPerReaping + + now := time.Now() + checked := 0 + expired := 0 + var idx int + ct.mu.RLock() + defer ct.mu.RUnlock() + for i := 0; i < len(ct.buckets)/fractionPerReaping; i++ { + idx = (i + start) % len(ct.buckets) + ct.buckets[idx].mu.Lock() + for tuple := ct.buckets[idx].tuples.Front(); tuple != nil; tuple = tuple.Next() { + checked++ + if ct.reapTupleLocked(tuple, idx, now) { + expired++ + } + } + ct.buckets[idx].mu.Unlock() + } + // We already checked buckets[idx]. + idx++ + + // If half or more of the connections are expired, the table has gotten + // stale. Reschedule quickly. + expiredPct := 0 + if checked != 0 { + expiredPct = expired * 100 / checked + } + if expiredPct > maxExpiredPct { + return idx, minInterval + } + if interval := prevInterval + minInterval; interval <= maxInterval { + // Increment the interval between runs. + return idx, interval + } + // We've hit the maximum interval. + return idx, maxInterval +} - ct.mu.Lock() - defer ct.mu.Unlock() +// reapTupleLocked tries to remove tuple and its reply from the table. It +// returns whether the tuple's connection has timed out. +// +// Preconditions: ct.mu is locked for reading and bucket is locked. +func (ct *ConnTrack) reapTupleLocked(tuple *tuple, bucket int, now time.Time) bool { + if !tuple.conn.timedOut(now) { + return false + } + + // To maintain lock order, we can only reap these tuples if the reply + // appears later in the table. + replyBucket := ct.bucket(tuple.reply()) + if bucket > replyBucket { + return true + } + + // Don't re-lock if both tuples are in the same bucket. + differentBuckets := bucket != replyBucket + if differentBuckets { + ct.buckets[replyBucket].mu.Lock() + } + + // We have the buckets locked and can remove both tuples. + if tuple.direction == dirOriginal { + ct.buckets[replyBucket].tuples.Remove(&tuple.conn.reply) + } else { + ct.buckets[replyBucket].tuples.Remove(&tuple.conn.original) + } + ct.buckets[bucket].tuples.Remove(tuple) + + // Don't re-unlock if both tuples are in the same bucket. + if differentBuckets { + ct.buckets[replyBucket].mu.Unlock() + } - delete(ct.conns, conn.original.tupleID) - delete(ct.conns, conn.reply.tupleID) + return true } diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarder_test.go index a6546cef0..bca1d940b 100644 --- a/pkg/tcpip/stack/forwarder_test.go +++ b/pkg/tcpip/stack/forwarder_test.go @@ -22,6 +22,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" ) const ( @@ -301,6 +302,16 @@ func (e *fwdTestLinkEndpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Er // Wait implements stack.LinkEndpoint.Wait. func (*fwdTestLinkEndpoint) Wait() {} +// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. +func (*fwdTestLinkEndpoint) ARPHardwareType() header.ARPHardwareType { + panic("not implemented") +} + +// AddHeader implements stack.LinkEndpoint.AddHeader. +func (e *fwdTestLinkEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { + panic("not implemented") +} + func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (ep1, ep2 *fwdTestLinkEndpoint) { // Create a stack with the network protocol and two NICs. s := New(Options{ diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index 974d77c36..cbbae4224 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -16,39 +16,49 @@ package stack import ( "fmt" + "time" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" ) -// Table names. +// tableID is an index into IPTables.tables. +type tableID int + const ( - TablenameNat = "nat" - TablenameMangle = "mangle" - TablenameFilter = "filter" + natID tableID = iota + mangleID + filterID + numTables ) -// Chain names as defined by net/ipv4/netfilter/ip_tables.c. +// Table names. const ( - ChainNamePrerouting = "PREROUTING" - ChainNameInput = "INPUT" - ChainNameForward = "FORWARD" - ChainNameOutput = "OUTPUT" - ChainNamePostrouting = "POSTROUTING" + NATTable = "nat" + MangleTable = "mangle" + FilterTable = "filter" ) +// nameToID is immutable. +var nameToID = map[string]tableID{ + NATTable: natID, + MangleTable: mangleID, + FilterTable: filterID, +} + // HookUnset indicates that there is no hook set for an entrypoint or // underflow. const HookUnset = -1 +// reaperDelay is how long to wait before starting to reap connections. +const reaperDelay = 5 * time.Second + // DefaultTables returns a default set of tables. Each chain is set to accept // all packets. func DefaultTables() *IPTables { - // TODO(gvisor.dev/issue/170): We may be able to swap out some strings for - // iotas. return &IPTables{ - tables: map[string]Table{ - TablenameNat: Table{ + tables: [numTables]Table{ + natID: Table{ Rules: []Rule{ Rule{Target: AcceptTarget{}}, Rule{Target: AcceptTarget{}}, @@ -56,64 +66,71 @@ func DefaultTables() *IPTables { Rule{Target: AcceptTarget{}}, Rule{Target: ErrorTarget{}}, }, - BuiltinChains: map[Hook]int{ + BuiltinChains: [NumHooks]int{ Prerouting: 0, Input: 1, + Forward: HookUnset, Output: 2, Postrouting: 3, }, - Underflows: map[Hook]int{ + Underflows: [NumHooks]int{ Prerouting: 0, Input: 1, + Forward: HookUnset, Output: 2, Postrouting: 3, }, - UserChains: map[string]int{}, }, - TablenameMangle: Table{ + mangleID: Table{ Rules: []Rule{ Rule{Target: AcceptTarget{}}, Rule{Target: AcceptTarget{}}, Rule{Target: ErrorTarget{}}, }, - BuiltinChains: map[Hook]int{ + BuiltinChains: [NumHooks]int{ Prerouting: 0, Output: 1, }, - Underflows: map[Hook]int{ - Prerouting: 0, - Output: 1, + Underflows: [NumHooks]int{ + Prerouting: 0, + Input: HookUnset, + Forward: HookUnset, + Output: 1, + Postrouting: HookUnset, }, - UserChains: map[string]int{}, }, - TablenameFilter: Table{ + filterID: Table{ Rules: []Rule{ Rule{Target: AcceptTarget{}}, Rule{Target: AcceptTarget{}}, Rule{Target: AcceptTarget{}}, Rule{Target: ErrorTarget{}}, }, - BuiltinChains: map[Hook]int{ - Input: 0, - Forward: 1, - Output: 2, + BuiltinChains: [NumHooks]int{ + Prerouting: HookUnset, + Input: 0, + Forward: 1, + Output: 2, + Postrouting: HookUnset, }, - Underflows: map[Hook]int{ - Input: 0, - Forward: 1, - Output: 2, + Underflows: [NumHooks]int{ + Prerouting: HookUnset, + Input: 0, + Forward: 1, + Output: 2, + Postrouting: HookUnset, }, - UserChains: map[string]int{}, }, }, - priorities: map[Hook][]string{ - Input: []string{TablenameNat, TablenameFilter}, - Prerouting: []string{TablenameMangle, TablenameNat}, - Output: []string{TablenameMangle, TablenameNat, TablenameFilter}, + priorities: [NumHooks][]tableID{ + Prerouting: []tableID{mangleID, natID}, + Input: []tableID{natID, filterID}, + Output: []tableID{mangleID, natID, filterID}, }, connections: ConnTrack{ - conns: make(map[tupleID]tuple), + seed: generateRandUint32(), }, + reaperDone: make(chan struct{}, 1), } } @@ -122,62 +139,59 @@ func DefaultTables() *IPTables { func EmptyFilterTable() Table { return Table{ Rules: []Rule{}, - BuiltinChains: map[Hook]int{ - Input: HookUnset, - Forward: HookUnset, - Output: HookUnset, + BuiltinChains: [NumHooks]int{ + Prerouting: HookUnset, + Postrouting: HookUnset, }, - Underflows: map[Hook]int{ - Input: HookUnset, - Forward: HookUnset, - Output: HookUnset, + Underflows: [NumHooks]int{ + Prerouting: HookUnset, + Postrouting: HookUnset, }, - UserChains: map[string]int{}, } } -// EmptyNatTable returns a Table with no rules and the filter table chains +// EmptyNATTable returns a Table with no rules and the filter table chains // mapped to HookUnset. -func EmptyNatTable() Table { +func EmptyNATTable() Table { return Table{ Rules: []Rule{}, - BuiltinChains: map[Hook]int{ - Prerouting: HookUnset, - Input: HookUnset, - Output: HookUnset, - Postrouting: HookUnset, + BuiltinChains: [NumHooks]int{ + Forward: HookUnset, }, - Underflows: map[Hook]int{ - Prerouting: HookUnset, - Input: HookUnset, - Output: HookUnset, - Postrouting: HookUnset, + Underflows: [NumHooks]int{ + Forward: HookUnset, }, - UserChains: map[string]int{}, } } -// GetTable returns table by name. +// GetTable returns a table by name. func (it *IPTables) GetTable(name string) (Table, bool) { + id, ok := nameToID[name] + if !ok { + return Table{}, false + } it.mu.RLock() defer it.mu.RUnlock() - t, ok := it.tables[name] - return t, ok + return it.tables[id], true } // ReplaceTable replaces or inserts table by name. -func (it *IPTables) ReplaceTable(name string, table Table) { +func (it *IPTables) ReplaceTable(name string, table Table) *tcpip.Error { + id, ok := nameToID[name] + if !ok { + return tcpip.ErrInvalidOptionValue + } it.mu.Lock() defer it.mu.Unlock() + // If iptables is being enabled, initialize the conntrack table and + // reaper. + if !it.modified { + it.connections.buckets = make([]bucket, numBuckets) + it.startReaper(reaperDelay) + } it.modified = true - it.tables[name] = table -} - -// GetPriorities returns slice of priorities associated with hook. -func (it *IPTables) GetPriorities(hook Hook) []string { - it.mu.RLock() - defer it.mu.RUnlock() - return it.priorities[hook] + it.tables[id] = table + return nil } // A chainVerdict is what a table decides should be done with a packet. @@ -212,11 +226,19 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, addr // Packets are manipulated only if connection and matching // NAT rule exists. - it.connections.handlePacket(pkt, hook, gso, r) + shouldTrack := it.connections.handlePacket(pkt, hook, gso, r) // Go through each table containing the hook. - for _, tablename := range it.GetPriorities(hook) { - table, _ := it.GetTable(tablename) + it.mu.RLock() + defer it.mu.RUnlock() + priorities := it.priorities[hook] + for _, tableID := range priorities { + // If handlePacket already NATed the packet, we don't need to + // check the NAT table. + if tableID == natID && pkt.NatDone { + continue + } + table := it.tables[tableID] ruleIdx := table.BuiltinChains[hook] switch verdict := it.checkChain(hook, pkt, table, ruleIdx, gso, r, address, nicName); verdict { // If the table returns Accept, move on to the next table. @@ -245,17 +267,59 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, addr } } + // If this connection should be tracked, try to add an entry for it. If + // traversing the nat table didn't end in adding an entry, + // maybeInsertNoop will add a no-op entry for the connection. This is + // needeed when establishing connections so that the SYN/ACK reply to an + // outgoing SYN is delivered to the correct endpoint rather than being + // redirected by a prerouting rule. + // + // From the iptables documentation: "If there is no rule, a `null' + // binding is created: this usually does not map the packet, but exists + // to ensure we don't map another stream over an existing one." + if shouldTrack { + it.connections.maybeInsertNoop(pkt, hook) + } + // Every table returned Accept. return true } +// beforeSave is invoked by stateify. +func (it *IPTables) beforeSave() { + // Ensure the reaper exits cleanly. + it.reaperDone <- struct{}{} + // Prevent others from modifying the connection table. + it.connections.mu.Lock() +} + +// afterLoad is invoked by stateify. +func (it *IPTables) afterLoad() { + it.startReaper(reaperDelay) +} + +// startReaper starts a goroutine that wakes up periodically to reap timed out +// connections. +func (it *IPTables) startReaper(interval time.Duration) { + go func() { // S/R-SAFE: reaperDone is signalled when iptables is saved. + bucket := 0 + for { + select { + case <-it.reaperDone: + return + case <-time.After(interval): + bucket, interval = it.connections.reapUnused(bucket, interval) + } + } + }() +} + // CheckPackets runs pkts through the rules for hook and returns a map of packets that // should not go forward. // -// Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize. -// -// TODO(gvisor.dev/issue/170): pk.NetworkHeader will always be set as a -// precondition. +// Preconditions: +// - pkt is a IPv4 packet of at least length header.IPv4MinimumSize. +// - pkt.NetworkHeader is not nil. // // NOTE: unlike the Check API the returned map contains packets that should be // dropped. @@ -279,9 +343,9 @@ func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, gso *GSO, r * return drop, natPkts } -// Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize. -// TODO(gvisor.dev/issue/170): pkt.NetworkHeader will always be set as a -// precondition. +// Preconditions: +// - pkt is a IPv4 packet of at least length header.IPv4MinimumSize. +// - pkt.NetworkHeader is not nil. func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, address tcpip.Address, nicName string) chainVerdict { // Start from ruleIdx and walk the list of rules until a rule gives us // a verdict. @@ -326,23 +390,12 @@ func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleId return chainDrop } -// Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize. -// TODO(gvisor.dev/issue/170): pkt.NetworkHeader will always be set as a -// precondition. +// Preconditions: +// - pkt is a IPv4 packet of at least length header.IPv4MinimumSize. +// - pkt.NetworkHeader is not nil. func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, address tcpip.Address, nicName string) (RuleVerdict, int) { rule := table.Rules[ruleIdx] - // If pkt.NetworkHeader hasn't been set yet, it will be contained in - // pkt.Data. - if pkt.NetworkHeader == nil { - var ok bool - pkt.NetworkHeader, ok = pkt.Data.PullUp(header.IPv4MinimumSize) - if !ok { - // Precondition has been violated. - panic(fmt.Sprintf("iptables checks require IPv4 headers of at least %d bytes", header.IPv4MinimumSize)) - } - } - // Check whether the packet matches the IP header filter. if !rule.Filter.match(header.IPv4(pkt.NetworkHeader), hook, nicName) { // Continue on to the next rule. diff --git a/pkg/tcpip/stack/iptables_state.go b/pkg/tcpip/stack/iptables_state.go new file mode 100644 index 000000000..529e02a07 --- /dev/null +++ b/pkg/tcpip/stack/iptables_state.go @@ -0,0 +1,40 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "time" +) + +// +stateify savable +type unixTime struct { + second int64 + nano int64 +} + +// saveLastUsed is invoked by stateify. +func (cn *conn) saveLastUsed() unixTime { + return unixTime{cn.lastUsed.Unix(), cn.lastUsed.UnixNano()} +} + +// loadLastUsed is invoked by stateify. +func (cn *conn) loadLastUsed(unix unixTime) { + cn.lastUsed = time.Unix(unix.second, unix.nano) +} + +// beforeSave is invoked by stateify. +func (ct *ConnTrack) beforeSave() { + ct.mu.Lock() +} diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go index d43f60c67..dc88033c7 100644 --- a/pkg/tcpip/stack/iptables_targets.go +++ b/pkg/tcpip/stack/iptables_targets.go @@ -153,7 +153,7 @@ func (rt RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso // Set up conection for matching NAT rule. Only the first // packet of the connection comes here. Other packets will be // manipulated in connection tracking. - if conn := ct.createConnFor(pkt, hook, rt); conn != nil { + if conn := ct.insertRedirectConn(pkt, hook, rt); conn != nil { ct.handlePacket(pkt, hook, gso, r) } default: diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go index c528ec381..73274ada9 100644 --- a/pkg/tcpip/stack/iptables_types.go +++ b/pkg/tcpip/stack/iptables_types.go @@ -78,18 +78,20 @@ const ( ) // IPTables holds all the tables for a netstack. +// +// +stateify savable type IPTables struct { // mu protects tables, priorities, and modified. mu sync.RWMutex - // tables maps table names to tables. User tables have arbitrary names. - // mu needs to be locked for accessing. - tables map[string]Table + // tables maps tableIDs to tables. Holds builtin tables only, not user + // tables. mu must be locked for accessing. + tables [numTables]Table // priorities maps each hook to a list of table names. The order of the // list is the order in which each table should be visited for that // hook. mu needs to be locked for accessing. - priorities map[Hook][]string + priorities [NumHooks][]tableID // modified is whether tables have been modified at least once. It is // used to elide the iptables performance overhead for workloads that @@ -97,31 +99,34 @@ type IPTables struct { modified bool connections ConnTrack + + // reaperDone can be signalled to stop the reaper goroutine. + reaperDone chan struct{} } // A Table defines a set of chains and hooks into the network stack. It is // really just a list of rules. +// +// +stateify savable type Table struct { // Rules holds the rules that make up the table. Rules []Rule // BuiltinChains maps builtin chains to their entrypoint rule in Rules. - BuiltinChains map[Hook]int + BuiltinChains [NumHooks]int // Underflows maps builtin chains to their underflow rule in Rules // (i.e. the rule to execute if the chain returns without a verdict). - Underflows map[Hook]int - - // UserChains holds user-defined chains for the keyed by name. Users - // can give their chains arbitrary names. - UserChains map[string]int + Underflows [NumHooks]int } // ValidHooks returns a bitmap of the builtin hooks for the given table. func (table *Table) ValidHooks() uint32 { hooks := uint32(0) - for hook := range table.BuiltinChains { - hooks |= 1 << hook + for hook, ruleIdx := range table.BuiltinChains { + if ruleIdx != HookUnset { + hooks |= 1 << hook + } } return hooks } @@ -130,6 +135,8 @@ func (table *Table) ValidHooks() uint32 { // contains zero or more matchers, each of which is a specification of which // packets this rule applies to. If there are no matchers in the rule, it // applies to any packet. +// +// +stateify savable type Rule struct { // Filter holds basic IP filtering fields common to every rule. Filter IPHeaderFilter @@ -142,6 +149,8 @@ type Rule struct { } // IPHeaderFilter holds basic IP filtering data common to every rule. +// +// +stateify savable type IPHeaderFilter struct { // Protocol matches the transport protocol. Protocol tcpip.TransportProtocolNumber diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go index e28c23d66..9dce11a97 100644 --- a/pkg/tcpip/stack/ndp.go +++ b/pkg/tcpip/stack/ndp.go @@ -469,7 +469,7 @@ type ndpState struct { rtrSolicit struct { // The timer used to send the next router solicitation message. - timer *time.Timer + timer tcpip.Timer // Used to let the Router Solicitation timer know that it has been stopped. // @@ -503,7 +503,7 @@ type ndpState struct { // to the DAD goroutine that DAD should stop. type dadState struct { // The DAD timer to send the next NS message, or resolve the address. - timer *time.Timer + timer tcpip.Timer // Used to let the DAD timer know that it has been stopped. // @@ -515,38 +515,38 @@ type dadState struct { // defaultRouterState holds data associated with a default router discovered by // a Router Advertisement (RA). type defaultRouterState struct { - // Timer to invalidate the default router. + // Job to invalidate the default router. // // Must not be nil. - invalidationTimer *tcpip.CancellableTimer + invalidationJob *tcpip.Job } // onLinkPrefixState holds data associated with an on-link prefix discovered by // a Router Advertisement's Prefix Information option (PI) when the NDP // configurations was configured to do so. type onLinkPrefixState struct { - // Timer to invalidate the on-link prefix. + // Job to invalidate the on-link prefix. // // Must not be nil. - invalidationTimer *tcpip.CancellableTimer + invalidationJob *tcpip.Job } // tempSLAACAddrState holds state associated with a temporary SLAAC address. type tempSLAACAddrState struct { - // Timer to deprecate the temporary SLAAC address. + // Job to deprecate the temporary SLAAC address. // // Must not be nil. - deprecationTimer *tcpip.CancellableTimer + deprecationJob *tcpip.Job - // Timer to invalidate the temporary SLAAC address. + // Job to invalidate the temporary SLAAC address. // // Must not be nil. - invalidationTimer *tcpip.CancellableTimer + invalidationJob *tcpip.Job - // Timer to regenerate the temporary SLAAC address. + // Job to regenerate the temporary SLAAC address. // // Must not be nil. - regenTimer *tcpip.CancellableTimer + regenJob *tcpip.Job createdAt time.Time @@ -561,15 +561,15 @@ type tempSLAACAddrState struct { // slaacPrefixState holds state associated with a SLAAC prefix. type slaacPrefixState struct { - // Timer to deprecate the prefix. + // Job to deprecate the prefix. // // Must not be nil. - deprecationTimer *tcpip.CancellableTimer + deprecationJob *tcpip.Job - // Timer to invalidate the prefix. + // Job to invalidate the prefix. // // Must not be nil. - invalidationTimer *tcpip.CancellableTimer + invalidationJob *tcpip.Job // Nonzero only when the address is not valid forever. validUntil time.Time @@ -651,12 +651,12 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref } var done bool - var timer *time.Timer + var timer tcpip.Timer // We initially start a timer to fire immediately because some of the DAD work // cannot be done while holding the NIC's lock. This is effectively the same // as starting a goroutine but we use a timer that fires immediately so we can // reset it for the next DAD iteration. - timer = time.AfterFunc(0, func() { + timer = ndp.nic.stack.Clock().AfterFunc(0, func() { ndp.nic.mu.Lock() defer ndp.nic.mu.Unlock() @@ -871,9 +871,9 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) { case ok && rl != 0: // This is an already discovered default router. Update - // the invalidation timer. - rtr.invalidationTimer.StopLocked() - rtr.invalidationTimer.Reset(rl) + // the invalidation job. + rtr.invalidationJob.Cancel() + rtr.invalidationJob.Schedule(rl) ndp.defaultRouters[ip] = rtr case ok && rl == 0: @@ -950,7 +950,7 @@ func (ndp *ndpState) invalidateDefaultRouter(ip tcpip.Address) { return } - rtr.invalidationTimer.StopLocked() + rtr.invalidationJob.Cancel() delete(ndp.defaultRouters, ip) // Let the integrator know a discovered default router is invalidated. @@ -979,12 +979,12 @@ func (ndp *ndpState) rememberDefaultRouter(ip tcpip.Address, rl time.Duration) { } state := defaultRouterState{ - invalidationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() { + invalidationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() { ndp.invalidateDefaultRouter(ip) }), } - state.invalidationTimer.Reset(rl) + state.invalidationJob.Schedule(rl) ndp.defaultRouters[ip] = state } @@ -1009,13 +1009,13 @@ func (ndp *ndpState) rememberOnLinkPrefix(prefix tcpip.Subnet, l time.Duration) } state := onLinkPrefixState{ - invalidationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() { + invalidationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() { ndp.invalidateOnLinkPrefix(prefix) }), } if l < header.NDPInfiniteLifetime { - state.invalidationTimer.Reset(l) + state.invalidationJob.Schedule(l) } ndp.onLinkPrefixes[prefix] = state @@ -1033,7 +1033,7 @@ func (ndp *ndpState) invalidateOnLinkPrefix(prefix tcpip.Subnet) { return } - s.invalidationTimer.StopLocked() + s.invalidationJob.Cancel() delete(ndp.onLinkPrefixes, prefix) // Let the integrator know a discovered on-link prefix is invalidated. @@ -1082,14 +1082,14 @@ func (ndp *ndpState) handleOnLinkPrefixInformation(pi header.NDPPrefixInformatio // This is an already discovered on-link prefix with a // new non-zero valid lifetime. // - // Update the invalidation timer. + // Update the invalidation job. - prefixState.invalidationTimer.StopLocked() + prefixState.invalidationJob.Cancel() if vl < header.NDPInfiniteLifetime { - // Prefix is valid for a finite lifetime, reset the timer to expire after + // Prefix is valid for a finite lifetime, schedule the job to execute after // the new valid lifetime. - prefixState.invalidationTimer.Reset(vl) + prefixState.invalidationJob.Schedule(vl) } ndp.onLinkPrefixes[prefix] = prefixState @@ -1154,7 +1154,7 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) { } state := slaacPrefixState{ - deprecationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() { + deprecationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() { state, ok := ndp.slaacPrefixes[prefix] if !ok { panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for the deprecated SLAAC prefix %s", prefix)) @@ -1162,7 +1162,7 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) { ndp.deprecateSLAACAddress(state.stableAddr.ref) }), - invalidationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() { + invalidationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() { state, ok := ndp.slaacPrefixes[prefix] if !ok { panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for the invalidated SLAAC prefix %s", prefix)) @@ -1184,19 +1184,19 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) { if !ndp.generateSLAACAddr(prefix, &state) { // We were unable to generate an address for the prefix, we do not nothing - // further as there is no reason to maintain state or timers for a prefix we + // further as there is no reason to maintain state or jobs for a prefix we // do not have an address for. return } - // Setup the initial timers to deprecate and invalidate prefix. + // Setup the initial jobs to deprecate and invalidate prefix. if pl < header.NDPInfiniteLifetime && pl != 0 { - state.deprecationTimer.Reset(pl) + state.deprecationJob.Schedule(pl) } if vl < header.NDPInfiniteLifetime { - state.invalidationTimer.Reset(vl) + state.invalidationJob.Schedule(vl) state.validUntil = now.Add(vl) } @@ -1428,7 +1428,7 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla } state := tempSLAACAddrState{ - deprecationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() { + deprecationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() { prefixState, ok := ndp.slaacPrefixes[prefix] if !ok { panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for %s to deprecate temporary address %s", prefix, generatedAddr)) @@ -1441,7 +1441,7 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla ndp.deprecateSLAACAddress(tempAddrState.ref) }), - invalidationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() { + invalidationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() { prefixState, ok := ndp.slaacPrefixes[prefix] if !ok { panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for %s to invalidate temporary address %s", prefix, generatedAddr)) @@ -1454,7 +1454,7 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla ndp.invalidateTempSLAACAddr(prefixState.tempAddrs, generatedAddr.Address, tempAddrState) }), - regenTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() { + regenJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() { prefixState, ok := ndp.slaacPrefixes[prefix] if !ok { panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for %s to regenerate temporary address after %s", prefix, generatedAddr)) @@ -1481,9 +1481,9 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla ref: ref, } - state.deprecationTimer.Reset(pl) - state.invalidationTimer.Reset(vl) - state.regenTimer.Reset(pl - ndp.configs.RegenAdvanceDuration) + state.deprecationJob.Schedule(pl) + state.invalidationJob.Schedule(vl) + state.regenJob.Schedule(pl - ndp.configs.RegenAdvanceDuration) prefixState.generationAttempts++ prefixState.tempAddrs[generatedAddr.Address] = state @@ -1518,16 +1518,16 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat prefixState.stableAddr.ref.deprecated = false } - // If prefix was preferred for some finite lifetime before, stop the - // deprecation timer so it can be reset. - prefixState.deprecationTimer.StopLocked() + // If prefix was preferred for some finite lifetime before, cancel the + // deprecation job so it can be reset. + prefixState.deprecationJob.Cancel() now := time.Now() - // Reset the deprecation timer if prefix has a finite preferred lifetime. + // Schedule the deprecation job if prefix has a finite preferred lifetime. if pl < header.NDPInfiniteLifetime { if !deprecated { - prefixState.deprecationTimer.Reset(pl) + prefixState.deprecationJob.Schedule(pl) } prefixState.preferredUntil = now.Add(pl) } else { @@ -1546,9 +1546,9 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat // 3) Otherwise, reset the valid lifetime of the prefix to 2 hours. if vl >= header.NDPInfiniteLifetime { - // Handle the infinite valid lifetime separately as we do not keep a timer - // in this case. - prefixState.invalidationTimer.StopLocked() + // Handle the infinite valid lifetime separately as we do not schedule a + // job in this case. + prefixState.invalidationJob.Cancel() prefixState.validUntil = time.Time{} } else { var effectiveVl time.Duration @@ -1569,8 +1569,8 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat } if effectiveVl != 0 { - prefixState.invalidationTimer.StopLocked() - prefixState.invalidationTimer.Reset(effectiveVl) + prefixState.invalidationJob.Cancel() + prefixState.invalidationJob.Schedule(effectiveVl) prefixState.validUntil = now.Add(effectiveVl) } } @@ -1582,7 +1582,7 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat } // Note, we do not need to update the entries in the temporary address map - // after updating the timers because the timers are held as pointers. + // after updating the jobs because the jobs are held as pointers. var regenForAddr tcpip.Address allAddressesRegenerated := true for tempAddr, tempAddrState := range prefixState.tempAddrs { @@ -1596,14 +1596,14 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat } // If the address is no longer valid, invalidate it immediately. Otherwise, - // reset the invalidation timer. + // reset the invalidation job. newValidLifetime := validUntil.Sub(now) if newValidLifetime <= 0 { ndp.invalidateTempSLAACAddr(prefixState.tempAddrs, tempAddr, tempAddrState) continue } - tempAddrState.invalidationTimer.StopLocked() - tempAddrState.invalidationTimer.Reset(newValidLifetime) + tempAddrState.invalidationJob.Cancel() + tempAddrState.invalidationJob.Schedule(newValidLifetime) // As per RFC 4941 section 3.3 step 4, the preferred lifetime of a temporary // address is the lower of the preferred lifetime of the stable address or @@ -1616,17 +1616,17 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat } // If the address is no longer preferred, deprecate it immediately. - // Otherwise, reset the deprecation timer. + // Otherwise, schedule the deprecation job again. newPreferredLifetime := preferredUntil.Sub(now) - tempAddrState.deprecationTimer.StopLocked() + tempAddrState.deprecationJob.Cancel() if newPreferredLifetime <= 0 { ndp.deprecateSLAACAddress(tempAddrState.ref) } else { tempAddrState.ref.deprecated = false - tempAddrState.deprecationTimer.Reset(newPreferredLifetime) + tempAddrState.deprecationJob.Schedule(newPreferredLifetime) } - tempAddrState.regenTimer.StopLocked() + tempAddrState.regenJob.Cancel() if tempAddrState.regenerated { } else { allAddressesRegenerated = false @@ -1637,7 +1637,7 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat // immediately after we finish iterating over the temporary addresses. regenForAddr = tempAddr } else { - tempAddrState.regenTimer.Reset(newPreferredLifetime - ndp.configs.RegenAdvanceDuration) + tempAddrState.regenJob.Schedule(newPreferredLifetime - ndp.configs.RegenAdvanceDuration) } } } @@ -1717,7 +1717,7 @@ func (ndp *ndpState) cleanupSLAACAddrResourcesAndNotify(addr tcpip.AddressWithPr ndp.cleanupSLAACPrefixResources(prefix, state) } -// cleanupSLAACPrefixResources cleansup a SLAAC prefix's timers and entry. +// cleanupSLAACPrefixResources cleans up a SLAAC prefix's jobs and entry. // // Panics if the SLAAC prefix is not known. // @@ -1729,8 +1729,8 @@ func (ndp *ndpState) cleanupSLAACPrefixResources(prefix tcpip.Subnet, state slaa } state.stableAddr.ref = nil - state.deprecationTimer.StopLocked() - state.invalidationTimer.StopLocked() + state.deprecationJob.Cancel() + state.invalidationJob.Cancel() delete(ndp.slaacPrefixes, prefix) } @@ -1775,13 +1775,13 @@ func (ndp *ndpState) cleanupTempSLAACAddrResourcesAndNotify(addr tcpip.AddressWi } // cleanupTempSLAACAddrResourcesAndNotify cleans up a temporary SLAAC address's -// timers and entry. +// jobs and entry. // // The NIC that ndp belongs to MUST be locked. func (ndp *ndpState) cleanupTempSLAACAddrResources(tempAddrs map[tcpip.Address]tempSLAACAddrState, tempAddr tcpip.Address, tempAddrState tempSLAACAddrState) { - tempAddrState.deprecationTimer.StopLocked() - tempAddrState.invalidationTimer.StopLocked() - tempAddrState.regenTimer.StopLocked() + tempAddrState.deprecationJob.Cancel() + tempAddrState.invalidationJob.Cancel() + tempAddrState.regenJob.Cancel() delete(tempAddrs, tempAddr) } @@ -1860,7 +1860,7 @@ func (ndp *ndpState) startSolicitingRouters() { var done bool ndp.rtrSolicit.done = &done - ndp.rtrSolicit.timer = time.AfterFunc(delay, func() { + ndp.rtrSolicit.timer = ndp.nic.stack.Clock().AfterFunc(delay, func() { ndp.nic.mu.Lock() if done { // If we reach this point, it means that the RS timer fired after another diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 6f86abc98..644ba7c33 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -1254,7 +1254,7 @@ func TestRouterDiscovery(t *testing.T) { default: } - // Wait for lladdr2's router invalidation timer to fire. The lifetime + // Wait for lladdr2's router invalidation job to execute. The lifetime // of the router should have been updated to the most recent (smaller) // lifetime. // @@ -1271,7 +1271,7 @@ func TestRouterDiscovery(t *testing.T) { e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 0)) expectRouterEvent(llAddr2, false) - // Wait for lladdr3's router invalidation timer to fire. The lifetime + // Wait for lladdr3's router invalidation job to execute. The lifetime // of the router should have been updated to the most recent (smaller) // lifetime. // @@ -1502,7 +1502,7 @@ func TestPrefixDiscovery(t *testing.T) { default: } - // Wait for prefix2's most recent invalidation timer plus some buffer to + // Wait for prefix2's most recent invalidation job plus some buffer to // expire. select { case e := <-ndpDisp.prefixC: @@ -2395,7 +2395,7 @@ func TestAutoGenTempAddrRegen(t *testing.T) { for _, addr := range tempAddrs { // Wait for a deprecation then invalidation event, or just an invalidation // event. We need to cover both cases but cannot deterministically hit both - // cases because the deprecation and invalidation timers could fire in any + // cases because the deprecation and invalidation jobs could execute in any // order. select { case e := <-ndpDisp.autoGenAddrC: @@ -2432,9 +2432,9 @@ func TestAutoGenTempAddrRegen(t *testing.T) { } } -// TestAutoGenTempAddrRegenTimerUpdates tests that a temporary address's -// regeneration timer gets updated when refreshing the address's lifetimes. -func TestAutoGenTempAddrRegenTimerUpdates(t *testing.T) { +// TestAutoGenTempAddrRegenJobUpdates tests that a temporary address's +// regeneration job gets updated when refreshing the address's lifetimes. +func TestAutoGenTempAddrRegenJobUpdates(t *testing.T) { const ( nicID = 1 regenAfter = 2 * time.Second @@ -2533,7 +2533,7 @@ func TestAutoGenTempAddrRegenTimerUpdates(t *testing.T) { // // A new temporary address should immediately be generated since the // regeneration time has already passed since the last address was generated - // - this regeneration does not depend on a timer. + // - this regeneration does not depend on a job. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100)) expectAutoGenAddrEvent(tempAddr2, newAddr) @@ -2559,11 +2559,11 @@ func TestAutoGenTempAddrRegenTimerUpdates(t *testing.T) { } // Set the maximum lifetimes for temporary addresses such that on the next - // RA, the regeneration timer gets reset. + // RA, the regeneration job gets scheduled again. // // The maximum lifetime is the sum of the minimum lifetimes for temporary // addresses + the time that has already passed since the last address was - // generated so that the regeneration timer is needed to generate the next + // generated so that the regeneration job is needed to generate the next // address. newLifetimes := newMinVLDuration + regenAfter + defaultAsyncNegativeEventTimeout ndpConfigs.MaxTempAddrValidLifetime = newLifetimes @@ -2993,9 +2993,9 @@ func TestAutoGenAddrDeprecateFromPI(t *testing.T) { expectPrimaryAddr(addr2) } -// TestAutoGenAddrTimerDeprecation tests that an address is properly deprecated +// TestAutoGenAddrJobDeprecation tests that an address is properly deprecated // when its preferred lifetime expires. -func TestAutoGenAddrTimerDeprecation(t *testing.T) { +func TestAutoGenAddrJobDeprecation(t *testing.T) { const nicID = 1 const newMinVL = 2 newMinVLDuration := newMinVL * time.Second @@ -3513,8 +3513,8 @@ func TestAutoGenAddrRemoval(t *testing.T) { } expectAutoGenAddrEvent(addr, invalidatedAddr) - // Wait for the original valid lifetime to make sure the original timer - // got stopped/cleaned up. + // Wait for the original valid lifetime to make sure the original job got + // cancelled/cleaned up. select { case <-ndpDisp.autoGenAddrC: t.Fatal("unexpectedly received an auto gen addr event") diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index afb7dfeaf..fea0ce7e8 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -1200,15 +1200,13 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp // Are any packet sockets listening for this network protocol? packetEPs := n.mu.packetEPs[protocol] - // Check whether there are packet sockets listening for every protocol. - // If we received a packet with protocol EthernetProtocolAll, then the - // previous for loop will have handled it. - if protocol != header.EthernetProtocolAll { - packetEPs = append(packetEPs, n.mu.packetEPs[header.EthernetProtocolAll]...) - } + // Add any other packet sockets that maybe listening for all protocols. + packetEPs = append(packetEPs, n.mu.packetEPs[header.EthernetProtocolAll]...) n.mu.RUnlock() for _, ep := range packetEPs { - ep.HandlePacket(n.id, local, protocol, pkt.Clone()) + p := pkt.Clone() + p.PktType = tcpip.PacketHost + ep.HandlePacket(n.id, local, protocol, p) } if netProto.Number() == header.IPv4ProtocolNumber || netProto.Number() == header.IPv6ProtocolNumber { @@ -1311,6 +1309,24 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp } } +// DeliverOutboundPacket implements NetworkDispatcher.DeliverOutboundPacket. +func (n *NIC) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { + n.mu.RLock() + // We do not deliver to protocol specific packet endpoints as on Linux + // only ETH_P_ALL endpoints get outbound packets. + // Add any other packet sockets that maybe listening for all protocols. + packetEPs := n.mu.packetEPs[header.EthernetProtocolAll] + n.mu.RUnlock() + for _, ep := range packetEPs { + p := pkt.Clone() + p.PktType = tcpip.PacketOutgoing + // Add the link layer header as outgoing packets are intercepted + // before the link layer header is created. + n.linkEP.AddHeader(local, remote, protocol, p) + ep.HandlePacket(n.id, local, protocol, p) + } +} + func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { // TODO(b/143425874) Decrease the TTL field in forwarded packets. // TODO(b/151227689): Avoid copying the packet when forwarding. We can do this @@ -1358,16 +1374,19 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN // TransportHeader is nil only when pkt is an ICMP packet or was reassembled // from fragments. if pkt.TransportHeader == nil { - // TODO(gvisor.dev/issue/170): ICMP packets don't have their - // TransportHeader fields set. See icmp/protocol.go:protocol.Parse for a + // TODO(gvisor.dev/issue/170): ICMP packets don't have their TransportHeader + // fields set yet, parse it here. See icmp/protocol.go:protocol.Parse for a // full explanation. if protocol == header.ICMPv4ProtocolNumber || protocol == header.ICMPv6ProtocolNumber { + // ICMP packets may be longer, but until icmp.Parse is implemented, here + // we parse it using the minimum size. transHeader, ok := pkt.Data.PullUp(transProto.MinimumPacketSize()) if !ok { n.stack.stats.MalformedRcvdPackets.Increment() return } pkt.TransportHeader = transHeader + pkt.Data.TrimFront(len(pkt.TransportHeader)) } else { // This is either a bad packet or was re-assembled from fragments. transProto.Parse(pkt) diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go index 31f865260..c477e31d8 100644 --- a/pkg/tcpip/stack/nic_test.go +++ b/pkg/tcpip/stack/nic_test.go @@ -84,6 +84,16 @@ func (e *testLinkEndpoint) WriteRawPacket(buffer.VectorisedView) *tcpip.Error { return tcpip.ErrNotSupported } +// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. +func (*testLinkEndpoint) ARPHardwareType() header.ARPHardwareType { + panic("not implemented") +} + +// AddHeader implements stack.LinkEndpoint.AddHeader. +func (e *testLinkEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { + panic("not implemented") +} + var _ NetworkEndpoint = (*testIPv6Endpoint)(nil) // An IPv6 NetworkEndpoint that throws away outgoing packets. diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index 1b5da6017..5d6865e35 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -14,6 +14,7 @@ package stack import ( + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" ) @@ -24,7 +25,7 @@ import ( // multiple endpoints. Clone() should be called in such cases so that // modifications to the Data field do not affect other copies. type PacketBuffer struct { - _ noCopy + _ sync.NoCopy // PacketBufferEntry is used to build an intrusive list of // PacketBuffers. @@ -78,6 +79,10 @@ type PacketBuffer struct { // NatDone indicates if the packet has been manipulated as per NAT // iptables rule. NatDone bool + + // PktType indicates the SockAddrLink.PacketType of the packet as defined in + // https://www.man7.org/linux/man-pages/man7/packet.7.html. + PktType tcpip.PacketType } // Clone makes a copy of pk. It clones the Data field, which creates a new @@ -102,14 +107,3 @@ func (pk *PacketBuffer) Clone() *PacketBuffer { NatDone: pk.NatDone, } } - -// noCopy may be embedded into structs which must not be copied -// after the first use. -// -// See https://golang.org/issues/8005#issuecomment-190753527 -// for details. -type noCopy struct{} - -// Lock is a no-op used by -copylocks checker from `go vet`. -func (*noCopy) Lock() {} -func (*noCopy) Unlock() {} diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 5cbc946b6..9e1b2d25f 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -18,6 +18,7 @@ import ( "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/waiter" ) @@ -51,8 +52,11 @@ type TransportEndpointID struct { type ControlType int // The following are the allowed values for ControlType values. +// TODO(http://gvisor.dev/issue/3210): Support time exceeded messages. const ( - ControlPacketTooBig ControlType = iota + ControlNetworkUnreachable ControlType = iota + ControlNoRoute + ControlPacketTooBig ControlPortUnreachable ControlUnknown ) @@ -329,8 +333,7 @@ type NetworkProtocol interface { } // NetworkDispatcher contains the methods used by the network stack to deliver -// packets to the appropriate network endpoint after it has been handled by -// the data link layer. +// inbound/outbound packets to the appropriate network/packet(if any) endpoints. type NetworkDispatcher interface { // DeliverNetworkPacket finds the appropriate network protocol endpoint // and hands the packet over for further processing. @@ -341,6 +344,16 @@ type NetworkDispatcher interface { // // DeliverNetworkPacket takes ownership of pkt. DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) + + // DeliverOutboundPacket is called by link layer when a packet is being + // sent out. + // + // pkt.LinkHeader may or may not be set before calling + // DeliverOutboundPacket. Some packets do not have link headers (e.g. + // packets sent via loopback), and won't have the field set. + // + // DeliverOutboundPacket takes ownership of pkt. + DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) } // LinkEndpointCapabilities is the type associated with the capabilities @@ -436,6 +449,15 @@ type LinkEndpoint interface { // Wait will not block if the endpoint hasn't started any goroutines // yet, even if it might later. Wait() + + // ARPHardwareType returns the ARPHRD_TYPE of the link endpoint. + // + // See: + // https://github.com/torvalds/linux/blob/aa0c9086b40c17a7ad94425b3b70dd1fdd7497bf/include/uapi/linux/if_arp.h#L30 + ARPHardwareType() header.ARPHardwareType + + // AddHeader adds a link layer header to pkt if required. + AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) } // InjectableLinkEndpoint is a LinkEndpoint where inbound packets are diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index cdcfb8321..a6faa22c2 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -425,6 +425,7 @@ type Stack struct { handleLocal bool // tables are the iptables packet filtering and manipulation rules. + // TODO(gvisor.dev/issue/170): S/R this field. tables *IPTables // resumableEndpoints is a list of endpoints that need to be resumed if the @@ -727,6 +728,11 @@ func New(opts Options) *Stack { return s } +// newJob returns a tcpip.Job using the Stack clock. +func (s *Stack) newJob(l sync.Locker, f func()) *tcpip.Job { + return tcpip.NewJob(s.clock, l, f) +} + // UniqueID returns a unique identifier. func (s *Stack) UniqueID() uint64 { return s.uniqueIDGenerator.UniqueID() @@ -800,9 +806,10 @@ func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h f } } -// NowNanoseconds implements tcpip.Clock.NowNanoseconds. -func (s *Stack) NowNanoseconds() int64 { - return s.clock.NowNanoseconds() +// Clock returns the Stack's clock for retrieving the current time and +// scheduling work. +func (s *Stack) Clock() tcpip.Clock { + return s.clock } // Stats returns a mutable copy of the current stats. @@ -1094,6 +1101,11 @@ type NICInfo struct { // Context is user-supplied data optionally supplied in CreateNICWithOptions. // See type NICOptions for more details. Context NICContext + + // ARPHardwareType holds the ARP Hardware type of the NIC. This is the + // value sent in haType field of an ARP Request sent by this NIC and the + // value expected in the haType field of an ARP response. + ARPHardwareType header.ARPHardwareType } // HasNIC returns true if the NICID is defined in the stack. @@ -1125,6 +1137,7 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo { MTU: nic.linkEP.MTU(), Stats: nic.stats, Context: nic.context, + ARPHardwareType: nic.linkEP.ARPHardwareType(), } } return nics diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 2be1c107a..21aafb0a2 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -192,7 +192,7 @@ func (e ErrSaveRejection) Error() string { return "save rejected due to unsupported networking state: " + e.Err.Error() } -// A Clock provides the current time. +// A Clock provides the current time and schedules work for execution. // // Times returned by a Clock should always be used for application-visible // time. Only monotonic times should be used for netstack internal timekeeping. @@ -203,6 +203,31 @@ type Clock interface { // NowMonotonic returns a monotonic time value. NowMonotonic() int64 + + // AfterFunc waits for the duration to elapse and then calls f in its own + // goroutine. It returns a Timer that can be used to cancel the call using + // its Stop method. + AfterFunc(d time.Duration, f func()) Timer +} + +// Timer represents a single event. A Timer must be created with +// Clock.AfterFunc. +type Timer interface { + // Stop prevents the Timer from firing. It returns true if the call stops the + // timer, false if the timer has already expired or been stopped. + // + // If Stop returns false, then the timer has already expired and the function + // f of Clock.AfterFunc(d, f) has been started in its own goroutine; Stop + // does not wait for f to complete before returning. If the caller needs to + // know whether f is completed, it must coordinate with f explicitly. + Stop() bool + + // Reset changes the timer to expire after duration d. + // + // Reset should be invoked only on stopped or expired timers. If the timer is + // known to have expired, Reset can be used directly. Otherwise, the caller + // must coordinate with the function f of Clock.AfterFunc(d, f). + Reset(d time.Duration) } // Address is a byte slice cast as a string that represents the address of a @@ -316,6 +341,28 @@ const ( ShutdownWrite ) +// PacketType is used to indicate the destination of the packet. +type PacketType uint8 + +const ( + // PacketHost indicates a packet addressed to the local host. + PacketHost PacketType = iota + + // PacketOtherHost indicates an outgoing packet addressed to + // another host caught by a NIC in promiscuous mode. + PacketOtherHost + + // PacketOutgoing for a packet originating from the local host + // that is looped back to a packet socket. + PacketOutgoing + + // PacketBroadcast indicates a link layer broadcast packet. + PacketBroadcast + + // PacketMulticast indicates a link layer multicast packet. + PacketMulticast +) + // FullAddress represents a full transport node address, as required by the // Connect() and Bind() methods. // @@ -549,6 +596,28 @@ type Endpoint interface { SetOwner(owner PacketOwner) } +// LinkPacketInfo holds Link layer information for a received packet. +// +// +stateify savable +type LinkPacketInfo struct { + // Protocol is the NetworkProtocolNumber for the packet. + Protocol NetworkProtocolNumber + + // PktType is used to indicate the destination of the packet. + PktType PacketType +} + +// PacketEndpoint are additional methods that are only implemented by Packet +// endpoints. +type PacketEndpoint interface { + // ReadPacket reads a datagram/packet from the endpoint and optionally + // returns the sender and additional LinkPacketInfo. + // + // This method does not block if there is no data pending. It will also + // either return an error or data, never both. + ReadPacket(*FullAddress, *LinkPacketInfo) (buffer.View, ControlMessages, *Error) +} + // EndpointInfo is the interface implemented by each endpoint info struct. type EndpointInfo interface { // IsEndpointInfo is an empty method to implement the tcpip.EndpointInfo @@ -648,6 +717,11 @@ const ( // whether an IPv6 socket is to be restricted to sending and receiving // IPv6 packets only. V6OnlyOption + + // IPHdrIncludedOption is used by SetSockOpt to indicate for a raw + // endpoint that all packets being written have an IP header and the + // endpoint should not attach an IP header. + IPHdrIncludedOption ) // SockOptInt represents socket options which values have the int type. @@ -673,6 +747,13 @@ const ( // TCP_MAXSEG option. MaxSegOption + // MTUDiscoverOption is used to set/get the path MTU discovery setting. + // + // NOTE: Setting this option to any other value than PMTUDiscoveryDont + // is not supported and will fail as such, and getting this option will + // always return PMTUDiscoveryDont. + MTUDiscoverOption + // MulticastTTLOption is used by SetSockOptInt/GetSockOptInt to control // the default TTL value for multicast messages. The default is 1. MulticastTTLOption @@ -714,6 +795,24 @@ const ( TCPWindowClampOption ) +const ( + // PMTUDiscoveryWant is a setting of the MTUDiscoverOption to use + // per-route settings. + PMTUDiscoveryWant int = iota + + // PMTUDiscoveryDont is a setting of the MTUDiscoverOption to disable + // path MTU discovery. + PMTUDiscoveryDont + + // PMTUDiscoveryDo is a setting of the MTUDiscoverOption to always do + // path MTU discovery. + PMTUDiscoveryDo + + // PMTUDiscoveryProbe is a setting of the MTUDiscoverOption to set DF + // but ignore path MTU. + PMTUDiscoveryProbe +) + // ErrorOption is used in GetSockOpt to specify that the last error reported by // the endpoint should be cleared and returned. type ErrorOption struct{} @@ -752,7 +851,7 @@ type CongestionControlOption string // control algorithms. type AvailableCongestionControlOption string -// buffer moderation. +// ModerateReceiveBufferOption is used by buffer moderation. type ModerateReceiveBufferOption bool // TCPLingerTimeoutOption is used by SetSockOpt/GetSockOpt to set/get the @@ -825,7 +924,10 @@ type OutOfBandInlineOption int // a default TTL. type DefaultTTLOption uint8 -// +// SocketDetachFilterOption is used by SetSockOpt to detach a previously attached +// classic BPF filter on a given endpoint. +type SocketDetachFilterOption int + // IPPacketInfo is the message structure for IP_PKTINFO. // // +stateify savable @@ -1214,6 +1316,9 @@ type UDPStats struct { // ChecksumErrors is the number of datagrams dropped due to bad checksums. ChecksumErrors *StatCounter + + // InvalidSourceAddress is the number of invalid sourced datagrams dropped. + InvalidSourceAddress *StatCounter } // Stats holds statistics about the networking stack. diff --git a/pkg/tcpip/time_unsafe.go b/pkg/tcpip/time_unsafe.go index 7f172f978..f32d58091 100644 --- a/pkg/tcpip/time_unsafe.go +++ b/pkg/tcpip/time_unsafe.go @@ -20,7 +20,7 @@ package tcpip import ( - _ "time" // Used with go:linkname. + "time" // Used with go:linkname. _ "unsafe" // Required for go:linkname. ) @@ -45,3 +45,31 @@ func (*StdClock) NowMonotonic() int64 { _, _, mono := now() return mono } + +// AfterFunc implements Clock.AfterFunc. +func (*StdClock) AfterFunc(d time.Duration, f func()) Timer { + return &stdTimer{ + t: time.AfterFunc(d, f), + } +} + +type stdTimer struct { + t *time.Timer +} + +var _ Timer = (*stdTimer)(nil) + +// Stop implements Timer.Stop. +func (st *stdTimer) Stop() bool { + return st.t.Stop() +} + +// Reset implements Timer.Reset. +func (st *stdTimer) Reset(d time.Duration) { + st.t.Reset(d) +} + +// NewStdTimer returns a Timer implemented with the time package. +func NewStdTimer(t *time.Timer) Timer { + return &stdTimer{t: t} +} diff --git a/pkg/tcpip/timer.go b/pkg/tcpip/timer.go index 59f3b391f..f1dd7c310 100644 --- a/pkg/tcpip/timer.go +++ b/pkg/tcpip/timer.go @@ -15,54 +15,54 @@ package tcpip import ( - "sync" "time" + + "gvisor.dev/gvisor/pkg/sync" ) -// cancellableTimerInstance is a specific instance of CancellableTimer. +// jobInstance is a specific instance of Job. // -// Different instances are created each time CancellableTimer is Reset so each -// timer has its own earlyReturn signal. This is to address a bug when a -// CancellableTimer is stopped and reset in quick succession resulting in a -// timer instance's earlyReturn signal being affected or seen by another timer -// instance. +// Different instances are created each time Job is scheduled so each timer has +// its own earlyReturn signal. This is to address a bug when a Job is stopped +// and reset in quick succession resulting in a timer instance's earlyReturn +// signal being affected or seen by another timer instance. // // Consider the following sceneario where timer instances share a common // earlyReturn signal (T1 creates, stops and resets a Cancellable timer under a // lock L; T2, T3, T4 and T5 are goroutines that handle the first (A), second // (B), third (C), and fourth (D) instance of the timer firing, respectively): // T1: Obtain L -// T1: Create a new CancellableTimer w/ lock L (create instance A) +// T1: Create a new Job w/ lock L (create instance A) // T2: instance A fires, blocked trying to obtain L. // T1: Attempt to stop instance A (set earlyReturn = true) -// T1: Reset timer (create instance B) +// T1: Schedule timer (create instance B) // T3: instance B fires, blocked trying to obtain L. // T1: Attempt to stop instance B (set earlyReturn = true) -// T1: Reset timer (create instance C) +// T1: Schedule timer (create instance C) // T4: instance C fires, blocked trying to obtain L. // T1: Attempt to stop instance C (set earlyReturn = true) -// T1: Reset timer (create instance D) +// T1: Schedule timer (create instance D) // T5: instance D fires, blocked trying to obtain L. // T1: Release L // -// Now that T1 has released L, any of the 4 timer instances can take L and check -// earlyReturn. If the timers simply check earlyReturn and then do nothing -// further, then instance D will never early return even though it was not -// requested to stop. If the timers reset earlyReturn before early returning, -// then all but one of the timers will do work when only one was expected to. -// If CancellableTimer resets earlyReturn when resetting, then all the timers +// Now that T1 has released L, any of the 4 timer instances can take L and +// check earlyReturn. If the timers simply check earlyReturn and then do +// nothing further, then instance D will never early return even though it was +// not requested to stop. If the timers reset earlyReturn before early +// returning, then all but one of the timers will do work when only one was +// expected to. If Job resets earlyReturn when resetting, then all the timers // will fire (again, when only one was expected to). // // To address the above concerns the simplest solution was to give each timer // its own earlyReturn signal. -type cancellableTimerInstance struct { - timer *time.Timer +type jobInstance struct { + timer Timer // Used to inform the timer to early return when it gets stopped while the // lock the timer tries to obtain when fired is held (T1 is a goroutine that // tries to cancel the timer and T2 is the goroutine that handles the timer // firing): - // T1: Obtain the lock, then call StopLocked() + // T1: Obtain the lock, then call Cancel() // T2: timer fires, and gets blocked on obtaining the lock // T1: Releases lock // T2: Obtains lock does unintended work @@ -73,27 +73,33 @@ type cancellableTimerInstance struct { earlyReturn *bool } -// stop stops the timer instance t from firing if it hasn't fired already. If it +// stop stops the job instance j from firing if it hasn't fired already. If it // has fired and is blocked at obtaining the lock, earlyReturn will be set to // true so that it will early return when it obtains the lock. -func (t *cancellableTimerInstance) stop() { - if t.timer != nil { - t.timer.Stop() - *t.earlyReturn = true +func (j *jobInstance) stop() { + if j.timer != nil { + j.timer.Stop() + *j.earlyReturn = true } } -// CancellableTimer is a timer that does some work and can be safely cancelled -// when it fires at the same time some "related work" is being done. +// Job represents some work that can be scheduled for execution. The work can +// be safely cancelled when it fires at the same time some "related work" is +// being done. // // The term "related work" is defined as some work that needs to be done while // holding some lock that the timer must also hold while doing some work. // -// Note, it is not safe to copy a CancellableTimer as its timer instance creates -// a closure over the address of the CancellableTimer. -type CancellableTimer struct { +// Note, it is not safe to copy a Job as its timer instance creates +// a closure over the address of the Job. +type Job struct { + _ sync.NoCopy + + // The clock used to schedule the backing timer + clock Clock + // The active instance of a cancellable timer. - instance cancellableTimerInstance + instance jobInstance // locker is the lock taken by the timer immediately after it fires and must // be held when attempting to stop the timer. @@ -110,75 +116,91 @@ type CancellableTimer struct { fn func() } -// StopLocked prevents the Timer from firing if it has not fired already. +// Cancel prevents the Job from executing if it has not executed already. // -// If the timer is blocked on obtaining the t.locker lock when StopLocked is -// called, it will early return instead of calling t.fn. +// Cancel requires appropriate locking to be in place for any resources managed +// by the Job. If the Job is blocked on obtaining the lock when Cancel is +// called, it will early return. // // Note, t will be modified. // -// t.locker MUST be locked. -func (t *CancellableTimer) StopLocked() { - t.instance.stop() +// j.locker MUST be locked. +func (j *Job) Cancel() { + j.instance.stop() // Nothing to do with the stopped instance anymore. - t.instance = cancellableTimerInstance{} + j.instance = jobInstance{} } -// Reset changes the timer to expire after duration d. +// Schedule schedules the Job for execution after duration d. This can be +// called on cancelled or completed Jobs to schedule them again. // -// Note, t will be modified. +// Schedule should be invoked only on unscheduled, cancelled, or completed +// Jobs. To be safe, callers should always call Cancel before calling Schedule. // -// Reset should only be called on stopped or expired timers. To be safe, callers -// should always call StopLocked before calling Reset. -func (t *CancellableTimer) Reset(d time.Duration) { +// Note, j will be modified. +func (j *Job) Schedule(d time.Duration) { // Create a new instance. earlyReturn := false // Capture the locker so that updating the timer does not cause a data race // when a timer fires and tries to obtain the lock (read the timer's locker). - locker := t.locker - t.instance = cancellableTimerInstance{ - timer: time.AfterFunc(d, func() { + locker := j.locker + j.instance = jobInstance{ + timer: j.clock.AfterFunc(d, func() { locker.Lock() defer locker.Unlock() if earlyReturn { // If we reach this point, it means that the timer fired while another - // goroutine called StopLocked while it had the lock. Simply return - // here and do nothing further. + // goroutine called Cancel while it had the lock. Simply return here + // and do nothing further. earlyReturn = false return } - t.fn() + j.fn() }), earlyReturn: &earlyReturn, } } -// Lock is a no-op used by the copylocks checker from go vet. -// -// See CancellableTimer for details about why it shouldn't be copied. -// -// See https://github.com/golang/go/issues/8005#issuecomment-190753527 for more -// details about the copylocks checker. -func (*CancellableTimer) Lock() {} - -// Unlock is a no-op used by the copylocks checker from go vet. -// -// See CancellableTimer for details about why it shouldn't be copied. -// -// See https://github.com/golang/go/issues/8005#issuecomment-190753527 for more -// details about the copylocks checker. -func (*CancellableTimer) Unlock() {} - -// NewCancellableTimer returns an unscheduled CancellableTimer with the given -// locker and fn. -// -// fn MUST NOT attempt to lock locker. -// -// Callers must call Reset to schedule the timer to fire. -func NewCancellableTimer(locker sync.Locker, fn func()) *CancellableTimer { - return &CancellableTimer{locker: locker, fn: fn} +// NewJob returns a new Job that can be used to schedule f to run in its own +// gorountine. l will be locked before calling f then unlocked after f returns. +// +// var clock tcpip.StdClock +// var mu sync.Mutex +// message := "foo" +// job := tcpip.NewJob(&clock, &mu, func() { +// fmt.Println(message) +// }) +// job.Schedule(time.Second) +// +// mu.Lock() +// message = "bar" +// mu.Unlock() +// +// // Output: bar +// +// f MUST NOT attempt to lock l. +// +// l MUST be locked prior to calling the returned job's Cancel(). +// +// var clock tcpip.StdClock +// var mu sync.Mutex +// message := "foo" +// job := tcpip.NewJob(&clock, &mu, func() { +// fmt.Println(message) +// }) +// job.Schedule(time.Second) +// +// mu.Lock() +// job.Cancel() +// mu.Unlock() +func NewJob(c Clock, l sync.Locker, f func()) *Job { + return &Job{ + clock: c, + locker: l, + fn: f, + } } diff --git a/pkg/tcpip/timer_test.go b/pkg/tcpip/timer_test.go index b4940e397..a82384c49 100644 --- a/pkg/tcpip/timer_test.go +++ b/pkg/tcpip/timer_test.go @@ -28,8 +28,8 @@ const ( longDuration = 1 * time.Second ) -func TestCancellableTimerReassignment(t *testing.T) { - var timer tcpip.CancellableTimer +func TestJobReschedule(t *testing.T) { + var clock tcpip.StdClock var wg sync.WaitGroup var lock sync.Mutex @@ -43,26 +43,27 @@ func TestCancellableTimerReassignment(t *testing.T) { // that has an active timer (even if it has been stopped as a stopped // timer may be blocked on a lock before it can check if it has been // stopped while another goroutine holds the same lock). - timer = *tcpip.NewCancellableTimer(&lock, func() { + job := tcpip.NewJob(&clock, &lock, func() { wg.Done() }) - timer.Reset(shortDuration) + job.Schedule(shortDuration) lock.Unlock() }() } wg.Wait() } -func TestCancellableTimerFire(t *testing.T) { +func TestJobExecution(t *testing.T) { t.Parallel() - ch := make(chan struct{}) + var clock tcpip.StdClock var lock sync.Mutex + ch := make(chan struct{}) - timer := tcpip.NewCancellableTimer(&lock, func() { + job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} }) - timer.Reset(shortDuration) + job.Schedule(shortDuration) // Wait for timer to fire. select { @@ -82,17 +83,18 @@ func TestCancellableTimerFire(t *testing.T) { func TestCancellableTimerResetFromLongDuration(t *testing.T) { t.Parallel() - ch := make(chan struct{}) + var clock tcpip.StdClock var lock sync.Mutex + ch := make(chan struct{}) - timer := tcpip.NewCancellableTimer(&lock, func() { ch <- struct{}{} }) - timer.Reset(middleDuration) + job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} }) + job.Schedule(middleDuration) lock.Lock() - timer.StopLocked() + job.Cancel() lock.Unlock() - timer.Reset(shortDuration) + job.Schedule(shortDuration) // Wait for timer to fire. select { @@ -109,16 +111,17 @@ func TestCancellableTimerResetFromLongDuration(t *testing.T) { } } -func TestCancellableTimerResetFromShortDuration(t *testing.T) { +func TestJobRescheduleFromShortDuration(t *testing.T) { t.Parallel() - ch := make(chan struct{}) + var clock tcpip.StdClock var lock sync.Mutex + ch := make(chan struct{}) lock.Lock() - timer := tcpip.NewCancellableTimer(&lock, func() { ch <- struct{}{} }) - timer.Reset(shortDuration) - timer.StopLocked() + job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} }) + job.Schedule(shortDuration) + job.Cancel() lock.Unlock() // Wait for timer to fire if it wasn't correctly stopped. @@ -128,7 +131,7 @@ func TestCancellableTimerResetFromShortDuration(t *testing.T) { case <-time.After(middleDuration): } - timer.Reset(shortDuration) + job.Schedule(shortDuration) // Wait for timer to fire. select { @@ -145,17 +148,18 @@ func TestCancellableTimerResetFromShortDuration(t *testing.T) { } } -func TestCancellableTimerImmediatelyStop(t *testing.T) { +func TestJobImmediatelyCancel(t *testing.T) { t.Parallel() - ch := make(chan struct{}) + var clock tcpip.StdClock var lock sync.Mutex + ch := make(chan struct{}) for i := 0; i < 1000; i++ { lock.Lock() - timer := tcpip.NewCancellableTimer(&lock, func() { ch <- struct{}{} }) - timer.Reset(shortDuration) - timer.StopLocked() + job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} }) + job.Schedule(shortDuration) + job.Cancel() lock.Unlock() } @@ -167,25 +171,26 @@ func TestCancellableTimerImmediatelyStop(t *testing.T) { } } -func TestCancellableTimerStoppedResetWithoutLock(t *testing.T) { +func TestJobCancelledRescheduleWithoutLock(t *testing.T) { t.Parallel() - ch := make(chan struct{}) + var clock tcpip.StdClock var lock sync.Mutex + ch := make(chan struct{}) lock.Lock() - timer := tcpip.NewCancellableTimer(&lock, func() { ch <- struct{}{} }) - timer.Reset(shortDuration) - timer.StopLocked() + job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} }) + job.Schedule(shortDuration) + job.Cancel() lock.Unlock() for i := 0; i < 10; i++ { - timer.Reset(middleDuration) + job.Schedule(middleDuration) lock.Lock() // Sleep until the timer fires and gets blocked trying to take the lock. time.Sleep(middleDuration * 2) - timer.StopLocked() + job.Cancel() lock.Unlock() } @@ -201,17 +206,18 @@ func TestCancellableTimerStoppedResetWithoutLock(t *testing.T) { func TestManyCancellableTimerResetAfterBlockedOnLock(t *testing.T) { t.Parallel() - ch := make(chan struct{}) + var clock tcpip.StdClock var lock sync.Mutex + ch := make(chan struct{}) lock.Lock() - timer := tcpip.NewCancellableTimer(&lock, func() { ch <- struct{}{} }) - timer.Reset(shortDuration) + job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} }) + job.Schedule(shortDuration) for i := 0; i < 10; i++ { // Sleep until the timer fires and gets blocked trying to take the lock. time.Sleep(middleDuration) - timer.StopLocked() - timer.Reset(shortDuration) + job.Cancel() + job.Schedule(shortDuration) } lock.Unlock() @@ -230,18 +236,19 @@ func TestManyCancellableTimerResetAfterBlockedOnLock(t *testing.T) { } } -func TestManyCancellableTimerResetUnderLock(t *testing.T) { +func TestManyJobReschedulesUnderLock(t *testing.T) { t.Parallel() - ch := make(chan struct{}) + var clock tcpip.StdClock var lock sync.Mutex + ch := make(chan struct{}) lock.Lock() - timer := tcpip.NewCancellableTimer(&lock, func() { ch <- struct{}{} }) - timer.Reset(shortDuration) + job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} }) + job.Schedule(shortDuration) for i := 0; i < 10; i++ { - timer.StopLocked() - timer.Reset(shortDuration) + job.Cancel() + job.Schedule(shortDuration) } lock.Unlock() diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 8ce294002..4612be4e7 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -344,6 +344,10 @@ func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { // SetSockOpt sets a socket option. func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { + switch opt.(type) { + case tcpip.SocketDetachFilterOption: + return nil + } return nil } @@ -744,15 +748,15 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk // Only accept echo replies. switch e.NetProto { case header.IPv4ProtocolNumber: - h, ok := pkt.Data.PullUp(header.ICMPv4MinimumSize) - if !ok || header.ICMPv4(h).Type() != header.ICMPv4EchoReply { + h := header.ICMPv4(pkt.TransportHeader) + if len(h) < header.ICMPv4MinimumSize || h.Type() != header.ICMPv4EchoReply { e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() return } case header.IPv6ProtocolNumber: - h, ok := pkt.Data.PullUp(header.ICMPv6MinimumSize) - if !ok || header.ICMPv6(h).Type() != header.ICMPv6EchoReply { + h := header.ICMPv6(pkt.TransportHeader) + if len(h) < header.ICMPv6MinimumSize || h.Type() != header.ICMPv6EchoReply { e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() return @@ -786,12 +790,14 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk }, } - packet.data = pkt.Data + // ICMP socket's data includes ICMP header. + packet.data = pkt.TransportHeader.ToVectorisedView() + packet.data.Append(pkt.Data) e.rcvList.PushBack(packet) e.rcvBufSize += packet.data.Size() - packet.timestamp = e.stack.NowNanoseconds() + packet.timestamp = e.stack.Clock().NowNanoseconds() e.rcvMu.Unlock() e.stats.PacketsReceived.Increment() diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index baf08eda6..0e46e6355 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -25,6 +25,8 @@ package packet import ( + "fmt" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -43,6 +45,9 @@ type packet struct { timestampNS int64 // senderAddr is the network address of the sender. senderAddr tcpip.FullAddress + // packetInfo holds additional information like the protocol + // of the packet etc. + packetInfo tcpip.LinkPacketInfo } // endpoint is the packet socket implementation of tcpip.Endpoint. It is legal @@ -71,11 +76,17 @@ type endpoint struct { rcvClosed bool // The following fields are protected by mu. - mu sync.RWMutex `state:"nosave"` - sndBufSize int - closed bool - stats tcpip.TransportEndpointStats `state:"nosave"` - bound bool + mu sync.RWMutex `state:"nosave"` + sndBufSize int + sndBufSizeMax int + closed bool + stats tcpip.TransportEndpointStats `state:"nosave"` + bound bool + boundNIC tcpip.NICID + + // lastErrorMu protects lastError. + lastErrorMu sync.Mutex `state:"nosave"` + lastError *tcpip.Error `state:".(string)"` } // NewEndpoint returns a new packet endpoint. @@ -92,6 +103,17 @@ func NewEndpoint(s *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumb sndBufSize: 32 * 1024, } + // Override with stack defaults. + var ss stack.SendBufferSizeOption + if err := s.Option(&ss); err == nil { + ep.sndBufSizeMax = ss.Default + } + + var rs stack.ReceiveBufferSizeOption + if err := s.Option(&rs); err == nil { + ep.rcvBufSizeMax = rs.Default + } + if err := s.RegisterPacketEndpoint(0, netProto, ep); err != nil { return nil, err } @@ -132,8 +154,8 @@ func (ep *endpoint) Close() { // ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf. func (ep *endpoint) ModerateRecvBuf(copied int) {} -// Read implements tcpip.Endpoint.Read. -func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { +// Read implements tcpip.PacketEndpoint.ReadPacket. +func (ep *endpoint) ReadPacket(addr *tcpip.FullAddress, info *tcpip.LinkPacketInfo) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { ep.rcvMu.Lock() // If there's no data to read, return that read would block or that the @@ -158,9 +180,18 @@ func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMes *addr = packet.senderAddr } + if info != nil { + *info = packet.packetInfo + } + return packet.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: packet.timestampNS}, nil } +// Read implements tcpip.Endpoint.Read. +func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { + return ep.ReadPacket(addr, nil) +} + func (ep *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { // TODO(b/129292371): Implement. return 0, nil, tcpip.ErrInvalidOptionValue @@ -215,12 +246,14 @@ func (ep *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { ep.mu.Lock() defer ep.mu.Unlock() - if ep.bound { - return tcpip.ErrAlreadyBound + if ep.bound && ep.boundNIC == addr.NIC { + // If the NIC being bound is the same then just return success. + return nil } // Unregister endpoint with all the nics. ep.stack.UnregisterPacketEndpoint(0, ep.netProto, ep) + ep.bound = false // Bind endpoint to receive packets from specific interface. if err := ep.stack.RegisterPacketEndpoint(addr.NIC, ep.netProto, ep); err != nil { @@ -228,6 +261,7 @@ func (ep *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { } ep.bound = true + ep.boundNIC = addr.NIC return nil } @@ -264,7 +298,13 @@ func (ep *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // used with SetSockOpt, and this function always returns // tcpip.ErrNotSupported. func (ep *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { - return tcpip.ErrUnknownProtocolOption + switch opt.(type) { + case tcpip.SocketDetachFilterOption: + return nil + + default: + return tcpip.ErrUnknownProtocolOption + } } // SetSockOptBool implements tcpip.Endpoint.SetSockOptBool. @@ -274,11 +314,63 @@ func (ep *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { // SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. func (ep *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { - return tcpip.ErrUnknownProtocolOption + switch opt { + case tcpip.SendBufferSizeOption: + // Make sure the send buffer size is within the min and max + // allowed. + var ss stack.SendBufferSizeOption + if err := ep.stack.Option(&ss); err != nil { + panic(fmt.Sprintf("s.Option(%#v) = %s", ss, err)) + } + if v > ss.Max { + v = ss.Max + } + if v < ss.Min { + v = ss.Min + } + ep.mu.Lock() + ep.sndBufSizeMax = v + ep.mu.Unlock() + return nil + + case tcpip.ReceiveBufferSizeOption: + // Make sure the receive buffer size is within the min and max + // allowed. + var rs stack.ReceiveBufferSizeOption + if err := ep.stack.Option(&rs); err != nil { + panic(fmt.Sprintf("s.Option(%#v) = %s", rs, err)) + } + if v > rs.Max { + v = rs.Max + } + if v < rs.Min { + v = rs.Min + } + ep.rcvMu.Lock() + ep.rcvBufSizeMax = v + ep.rcvMu.Unlock() + return nil + + default: + return tcpip.ErrUnknownProtocolOption + } +} + +func (ep *endpoint) takeLastError() *tcpip.Error { + ep.lastErrorMu.Lock() + defer ep.lastErrorMu.Unlock() + + err := ep.lastError + ep.lastError = nil + return err } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. func (ep *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { + switch opt.(type) { + case tcpip.ErrorOption: + return ep.takeLastError() + } return tcpip.ErrNotSupported } @@ -289,7 +381,32 @@ func (ep *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. func (ep *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { - return 0, tcpip.ErrNotSupported + switch opt { + case tcpip.ReceiveQueueSizeOption: + v := 0 + ep.rcvMu.Lock() + if !ep.rcvList.Empty() { + p := ep.rcvList.Front() + v = p.data.Size() + } + ep.rcvMu.Unlock() + return v, nil + + case tcpip.SendBufferSizeOption: + ep.mu.Lock() + v := ep.sndBufSizeMax + ep.mu.Unlock() + return v, nil + + case tcpip.ReceiveBufferSizeOption: + ep.rcvMu.Lock() + v := ep.rcvBufSizeMax + ep.rcvMu.Unlock() + return v, nil + + default: + return -1, tcpip.ErrUnknownProtocolOption + } } // HandlePacket implements stack.PacketEndpoint.HandlePacket. @@ -323,40 +440,66 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, NIC: nicID, Addr: tcpip.Address(hdr.SourceAddress()), } + packet.packetInfo.Protocol = netProto + packet.packetInfo.PktType = pkt.PktType } else { // Guess the would-be ethernet header. packet.senderAddr = tcpip.FullAddress{ NIC: nicID, Addr: tcpip.Address(localAddr), } + packet.packetInfo.Protocol = netProto + packet.packetInfo.PktType = pkt.PktType } if ep.cooked { // Cooked packets can simply be queued. - packet.data = pkt.Data + switch pkt.PktType { + case tcpip.PacketHost: + packet.data = pkt.Data + case tcpip.PacketOutgoing: + // Strip Link Header from the Header. + pkt.Header = buffer.NewPrependableFromView(pkt.Header.View()[len(pkt.LinkHeader):]) + combinedVV := pkt.Header.View().ToVectorisedView() + combinedVV.Append(pkt.Data) + packet.data = combinedVV + default: + panic(fmt.Sprintf("unexpected PktType in pkt: %+v", pkt)) + } + } else { // Raw packets need their ethernet headers prepended before // queueing. var linkHeader buffer.View - if len(pkt.LinkHeader) == 0 { - // We weren't provided with an actual ethernet header, - // so fake one. - ethFields := header.EthernetFields{ - SrcAddr: tcpip.LinkAddress([]byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00}), - DstAddr: localAddr, - Type: netProto, + var combinedVV buffer.VectorisedView + if pkt.PktType != tcpip.PacketOutgoing { + if len(pkt.LinkHeader) == 0 { + // We weren't provided with an actual ethernet header, + // so fake one. + ethFields := header.EthernetFields{ + SrcAddr: tcpip.LinkAddress([]byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00}), + DstAddr: localAddr, + Type: netProto, + } + fakeHeader := make(header.Ethernet, header.EthernetMinimumSize) + fakeHeader.Encode(ðFields) + linkHeader = buffer.View(fakeHeader) + } else { + linkHeader = append(buffer.View(nil), pkt.LinkHeader...) } - fakeHeader := make(header.Ethernet, header.EthernetMinimumSize) - fakeHeader.Encode(ðFields) - linkHeader = buffer.View(fakeHeader) - } else { - linkHeader = append(buffer.View(nil), pkt.LinkHeader...) + combinedVV = linkHeader.ToVectorisedView() + } + if pkt.PktType == tcpip.PacketOutgoing { + // For outgoing packets the Link, Network and Transport + // headers are in the pkt.Header fields normally unless + // a Raw socket is in use in which case pkt.Header could + // be nil. + combinedVV.AppendView(pkt.Header.View()) } - combinedVV := linkHeader.ToVectorisedView() combinedVV.Append(pkt.Data) packet.data = combinedVV } - packet.timestampNS = ep.stack.NowNanoseconds() + packet.timestampNS = ep.stack.Clock().NowNanoseconds() ep.rcvList.PushBack(&packet) ep.rcvBufSize += packet.data.Size() diff --git a/pkg/tcpip/transport/packet/endpoint_state.go b/pkg/tcpip/transport/packet/endpoint_state.go index 9b88f17e4..e2fa96d17 100644 --- a/pkg/tcpip/transport/packet/endpoint_state.go +++ b/pkg/tcpip/transport/packet/endpoint_state.go @@ -15,6 +15,7 @@ package packet import ( + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -70,3 +71,21 @@ func (ep *endpoint) afterLoad() { panic(*err) } } + +// saveLastError is invoked by stateify. +func (ep *endpoint) saveLastError() string { + if ep.lastError == nil { + return "" + } + + return ep.lastError.String() +} + +// loadLastError is invoked by stateify. +func (ep *endpoint) loadLastError(s string) { + if s == "" { + return + } + + ep.lastError = tcpip.StringToError(s) +} diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 766c7648e..f85a68554 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -63,6 +63,7 @@ type endpoint struct { stack *stack.Stack `state:"manual"` waiterQueue *waiter.Queue associated bool + hdrIncluded bool // The following fields are used to manage the receive queue and are // protected by rcvMu. @@ -108,6 +109,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt rcvBufSizeMax: 32 * 1024, sndBufSizeMax: 32 * 1024, associated: associated, + hdrIncluded: !associated, } // Override with stack defaults. @@ -182,10 +184,6 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { // Read implements tcpip.Endpoint.Read. func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { - if !e.associated { - return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrInvalidOptionValue - } - e.rcvMu.Lock() // If there's no data to read, return that read would block or that the @@ -263,7 +261,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c // If this is an unassociated socket and callee provided a nonzero // destination address, route using that address. - if !e.associated { + if e.hdrIncluded { ip := header.IPv4(payloadBytes) if !ip.IsValid(len(payloadBytes)) { e.mu.RUnlock() @@ -353,7 +351,7 @@ func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, } } - if !e.associated { + if e.hdrIncluded { if err := route.WriteHeaderIncludedPacket(&stack.PacketBuffer{ Data: buffer.View(payloadBytes).ToVectorisedView(), }); err != nil { @@ -458,7 +456,7 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { defer e.mu.Unlock() // If a local address was specified, verify that it's valid. - if e.stack.CheckLocalAddress(addr.NIC, e.NetProto, addr.Addr) == 0 { + if len(addr.Addr) != 0 && e.stack.CheckLocalAddress(addr.NIC, e.NetProto, addr.Addr) == 0 { return tcpip.ErrBadLocalAddress } @@ -508,11 +506,24 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // SetSockOpt implements tcpip.Endpoint.SetSockOpt. func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { - return tcpip.ErrUnknownProtocolOption + switch opt.(type) { + case tcpip.SocketDetachFilterOption: + return nil + + default: + return tcpip.ErrUnknownProtocolOption + } } // SetSockOptBool implements tcpip.Endpoint.SetSockOptBool. func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { + switch opt { + case tcpip.IPHdrIncludedOption: + e.mu.Lock() + e.hdrIncluded = v + e.mu.Unlock() + return nil + } return tcpip.ErrUnknownProtocolOption } @@ -577,6 +588,12 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { case tcpip.KeepaliveEnabledOption: return false, nil + case tcpip.IPHdrIncludedOption: + e.mu.Lock() + v := e.hdrIncluded + e.mu.Unlock() + return v, nil + default: return false, tcpip.ErrUnknownProtocolOption } @@ -616,8 +633,15 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { func (e *endpoint) HandlePacket(route *stack.Route, pkt *stack.PacketBuffer) { e.rcvMu.Lock() - // Drop the packet if our buffer is currently full. - if e.rcvClosed { + // Drop the packet if our buffer is currently full or if this is an unassociated + // endpoint (i.e endpoint created w/ IPPROTO_RAW). Such endpoints are send only + // See: https://man7.org/linux/man-pages/man7/raw.7.html + // + // An IPPROTO_RAW socket is send only. If you really want to receive + // all IP packets, use a packet(7) socket with the ETH_P_IP protocol. + // Note that packet sockets don't reassemble IP fragments, unlike raw + // sockets. + if e.rcvClosed || !e.associated { e.rcvMu.Unlock() e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.ClosedReceiver.Increment() @@ -676,7 +700,7 @@ func (e *endpoint) HandlePacket(route *stack.Route, pkt *stack.PacketBuffer) { } combinedVV.Append(pkt.Data) packet.data = combinedVV - packet.timestampNS = e.stack.NowNanoseconds() + packet.timestampNS = e.stack.Clock().NowNanoseconds() e.rcvList.PushBack(packet) e.rcvBufSize += packet.data.Size() diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index 3601207be..18ff89ffc 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -58,7 +58,6 @@ go_library( imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"], visibility = ["//visibility:public"], deps = [ - "//pkg/binary", "//pkg/log", "//pkg/rand", "//pkg/sleep", @@ -87,6 +86,7 @@ go_test( "tcp_test.go", "tcp_timestamp_test.go", ], + shard_count = 10, deps = [ ":tcp", "//pkg/sync", diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 81b740115..1798510bc 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -490,6 +490,9 @@ func (h *handshake) resolveRoute() *tcpip.Error { <-h.ep.undrain h.ep.mu.Lock() } + if n¬ifyError != 0 { + return h.ep.takeLastError() + } } // Wait for notification. @@ -616,6 +619,9 @@ func (h *handshake) execute() *tcpip.Error { <-h.ep.undrain h.ep.mu.Lock() } + if n¬ifyError != 0 { + return h.ep.takeLastError() + } case wakerForNewSegment: if err := h.processSegments(); err != nil { diff --git a/pkg/tcpip/transport/tcp/dispatcher.go b/pkg/tcpip/transport/tcp/dispatcher.go index 43b76bee5..98aecab9e 100644 --- a/pkg/tcpip/transport/tcp/dispatcher.go +++ b/pkg/tcpip/transport/tcp/dispatcher.go @@ -15,7 +15,8 @@ package tcp import ( - "gvisor.dev/gvisor/pkg/binary" + "encoding/binary" + "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/sync" diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index bd3ec5a8d..0f7487963 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -1209,6 +1209,14 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { e.owner = owner } +func (e *endpoint) takeLastError() *tcpip.Error { + e.lastErrorMu.Lock() + defer e.lastErrorMu.Unlock() + err := e.lastError + e.lastError = nil + return err +} + // Read reads data from the endpoint. func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { e.LockUser() @@ -1589,6 +1597,13 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { e.UnlockUser() e.notifyProtocolGoroutine(notifyMSSChanged) + case tcpip.MTUDiscoverOption: + // Return not supported if attempting to set this option to + // anything other than path MTU discovery disabled. + if v != tcpip.PMTUDiscoveryDont { + return tcpip.ErrNotSupported + } + case tcpip.ReceiveBufferSizeOption: // Make sure the receive buffer size is within the min and max // allowed. @@ -1785,6 +1800,9 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.deferAccept = time.Duration(v) e.UnlockUser() + case tcpip.SocketDetachFilterOption: + return nil + default: return nil } @@ -1896,6 +1914,11 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { v := header.TCPDefaultMSS return v, nil + case tcpip.MTUDiscoverOption: + // Always return the path MTU discovery disabled setting since + // it's the only one supported. + return tcpip.PMTUDiscoveryDont, nil + case tcpip.ReceiveQueueSizeOption: return e.readyReceiveSize() @@ -1941,11 +1964,7 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { switch o := opt.(type) { case tcpip.ErrorOption: - e.lastErrorMu.Lock() - err := e.lastError - e.lastError = nil - e.lastErrorMu.Unlock() - return err + return e.takeLastError() case *tcpip.BindToDeviceOption: e.LockUser() @@ -2531,6 +2550,18 @@ func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.C e.sndBufMu.Unlock() e.notifyProtocolGoroutine(notifyMTUChanged) + + case stack.ControlNoRoute: + e.lastErrorMu.Lock() + e.lastError = tcpip.ErrNoRoute + e.lastErrorMu.Unlock() + e.notifyProtocolGoroutine(notifyError) + + case stack.ControlNetworkUnreachable: + e.lastErrorMu.Lock() + e.lastError = tcpip.ErrNetworkUnreachable + e.lastErrorMu.Unlock() + e.notifyProtocolGoroutine(notifyError) } } diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 169adb16b..e67ec42b1 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -3095,6 +3095,63 @@ func TestMaxRTO(t *testing.T) { } } +// TestRetransmitIPv4IDUniqueness tests that the IPv4 Identification field is +// unique on retransmits. +func TestRetransmitIPv4IDUniqueness(t *testing.T) { + for _, tc := range []struct { + name string + size int + }{ + {"1Byte", 1}, + {"512Bytes", 512}, + } { + t.Run(tc.name, func(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateConnected(789 /* iss */, 30000 /* rcvWnd */, -1 /* epRcvBuf */) + + // Disabling PMTU discovery causes all packets sent from this socket to + // have DF=0. This needs to be done because the IPv4 ID uniqueness + // applies only to non-atomic IPv4 datagrams as defined in RFC 6864 + // Section 4, and datagrams with DF=0 are non-atomic. + if err := c.EP.SetSockOptInt(tcpip.MTUDiscoverOption, tcpip.PMTUDiscoveryDont); err != nil { + t.Fatalf("disabling PMTU discovery via sockopt to force DF=0 failed: %s", err) + } + + if _, _, err := c.EP.Write(tcpip.SlicePayload(buffer.NewView(tc.size)), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %s", err) + } + pkt := c.GetPacket() + checker.IPv4(t, pkt, + checker.FragmentFlags(0), + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + ) + idSet := map[uint16]struct{}{header.IPv4(pkt).ID(): struct{}{}} + // Expect two retransmitted packets, and that all packets received have + // unique IPv4 ID values. + for i := 0; i <= 2; i++ { + pkt := c.GetPacket() + checker.IPv4(t, pkt, + checker.FragmentFlags(0), + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + ) + id := header.IPv4(pkt).ID() + if _, exists := idSet[id]; exists { + t.Fatalf("duplicate IPv4 ID=%d found in retransmitted packet", id) + } + idSet[id] = struct{}{} + } + }) + } +} + func TestFinImmediately(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() diff --git a/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go index 12bc1b5b5..558b06df0 100644 --- a/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go +++ b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go @@ -106,6 +106,11 @@ func (t *TCB) UpdateStateOutbound(tcp header.TCP) Result { return st } +// State returns the current state of the TCB. +func (t *TCB) State() Result { + return t.state +} + // IsAlive returns true as long as the connection is established(Alive) // or connecting state. func (t *TCB) IsAlive() bool { diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index cae29fbff..6e692da07 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -612,6 +612,13 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { // SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { switch opt { + case tcpip.MTUDiscoverOption: + // Return not supported if the value is not disabling path + // MTU discovery. + if v != tcpip.PMTUDiscoveryDont { + return tcpip.ErrNotSupported + } + case tcpip.MulticastTTLOption: e.mu.Lock() e.multicastTTL = uint8(v) @@ -809,6 +816,9 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.mu.Lock() e.bindToDevice = id e.mu.Unlock() + + case tcpip.SocketDetachFilterOption: + return nil } return nil } @@ -906,6 +916,10 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { e.mu.RUnlock() return v, nil + case tcpip.MTUDiscoverOption: + // The only supported setting is path MTU discovery disabled. + return tcpip.PMTUDiscoveryDont, nil + case tcpip.MulticastTTLOption: e.mu.Lock() v := int(e.multicastTTL) @@ -1366,6 +1380,15 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk return } + // Never receive from a multicast address. + if header.IsV4MulticastAddress(id.RemoteAddress) || + header.IsV6MulticastAddress(id.RemoteAddress) { + e.stack.Stats().UDP.InvalidSourceAddress.Increment() + e.stack.Stats().IP.InvalidSourceAddressesReceived.Increment() + e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() + return + } + // Verify checksum unless RX checksum offload is enabled. // On IPv4, UDP checksum is optional, and a zero value means // the transmitter omitted the checksum generation (RFC768). @@ -1384,10 +1407,10 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk } } - e.rcvMu.Lock() e.stack.Stats().UDP.PacketsReceived.Increment() e.stats.PacketsReceived.Increment() + e.rcvMu.Lock() // Drop the packet if our buffer is currently full. if !e.rcvReady || e.rcvClosed { e.rcvMu.Unlock() @@ -1428,7 +1451,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk packet.tos, _ = header.IPv6(pkt.NetworkHeader).TOS() } - packet.timestamp = e.stack.NowNanoseconds() + packet.timestamp = e.stack.Clock().NowNanoseconds() e.rcvMu.Unlock() diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index db59eb5a0..90781cf49 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -83,16 +83,18 @@ type header4Tuple struct { type testFlow int const ( - unicastV4 testFlow = iota // V4 unicast on a V4 socket - unicastV4in6 // V4-mapped unicast on a V6-dual socket - unicastV6 // V6 unicast on a V6 socket - unicastV6Only // V6 unicast on a V6-only socket - multicastV4 // V4 multicast on a V4 socket - multicastV4in6 // V4-mapped multicast on a V6-dual socket - multicastV6 // V6 multicast on a V6 socket - multicastV6Only // V6 multicast on a V6-only socket - broadcast // V4 broadcast on a V4 socket - broadcastIn6 // V4-mapped broadcast on a V6-dual socket + unicastV4 testFlow = iota // V4 unicast on a V4 socket + unicastV4in6 // V4-mapped unicast on a V6-dual socket + unicastV6 // V6 unicast on a V6 socket + unicastV6Only // V6 unicast on a V6-only socket + multicastV4 // V4 multicast on a V4 socket + multicastV4in6 // V4-mapped multicast on a V6-dual socket + multicastV6 // V6 multicast on a V6 socket + multicastV6Only // V6 multicast on a V6-only socket + broadcast // V4 broadcast on a V4 socket + broadcastIn6 // V4-mapped broadcast on a V6-dual socket + reverseMulticast4 // V4 multicast src. Must fail. + reverseMulticast6 // V6 multicast src. Must fail. ) func (flow testFlow) String() string { @@ -117,6 +119,10 @@ func (flow testFlow) String() string { return "broadcast" case broadcastIn6: return "broadcastIn6" + case reverseMulticast4: + return "reverseMulticast4" + case reverseMulticast6: + return "reverseMulticast6" default: return "unknown" } @@ -168,6 +174,9 @@ func (flow testFlow) header4Tuple(d packetDirection) header4Tuple { h.dstAddr.Addr = multicastV6Addr } } + if flow.isReverseMulticast() { + h.srcAddr.Addr = flow.getMcastAddr() + } return h } @@ -199,9 +208,9 @@ func (flow testFlow) netProto() tcpip.NetworkProtocolNumber { // endpoint for this flow. func (flow testFlow) sockProto() tcpip.NetworkProtocolNumber { switch flow { - case unicastV4in6, unicastV6, unicastV6Only, multicastV4in6, multicastV6, multicastV6Only, broadcastIn6: + case unicastV4in6, unicastV6, unicastV6Only, multicastV4in6, multicastV6, multicastV6Only, broadcastIn6, reverseMulticast6: return ipv6.ProtocolNumber - case unicastV4, multicastV4, broadcast: + case unicastV4, multicastV4, broadcast, reverseMulticast4: return ipv4.ProtocolNumber default: panic(fmt.Sprintf("invalid testFlow given: %d", flow)) @@ -224,7 +233,7 @@ func (flow testFlow) isV6Only() bool { switch flow { case unicastV6Only, multicastV6Only: return true - case unicastV4, unicastV4in6, unicastV6, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6: + case unicastV4, unicastV4in6, unicastV6, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6, reverseMulticast4, reverseMulticast6: return false default: panic(fmt.Sprintf("invalid testFlow given: %d", flow)) @@ -235,7 +244,7 @@ func (flow testFlow) isMulticast() bool { switch flow { case multicastV4, multicastV4in6, multicastV6, multicastV6Only: return true - case unicastV4, unicastV4in6, unicastV6, unicastV6Only, broadcast, broadcastIn6: + case unicastV4, unicastV4in6, unicastV6, unicastV6Only, broadcast, broadcastIn6, reverseMulticast4, reverseMulticast6: return false default: panic(fmt.Sprintf("invalid testFlow given: %d", flow)) @@ -246,7 +255,7 @@ func (flow testFlow) isBroadcast() bool { switch flow { case broadcast, broadcastIn6: return true - case unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, multicastV6Only: + case unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, multicastV6Only, reverseMulticast4, reverseMulticast6: return false default: panic(fmt.Sprintf("invalid testFlow given: %d", flow)) @@ -257,13 +266,22 @@ func (flow testFlow) isMapped() bool { switch flow { case unicastV4in6, multicastV4in6, broadcastIn6: return true - case unicastV4, unicastV6, unicastV6Only, multicastV4, multicastV6, multicastV6Only, broadcast: + case unicastV4, unicastV6, unicastV6Only, multicastV4, multicastV6, multicastV6Only, broadcast, reverseMulticast4, reverseMulticast6: return false default: panic(fmt.Sprintf("invalid testFlow given: %d", flow)) } } +func (flow testFlow) isReverseMulticast() bool { + switch flow { + case reverseMulticast4, reverseMulticast6: + return true + default: + return false + } +} + type testContext struct { t *testing.T linkEP *channel.Endpoint @@ -872,6 +890,60 @@ func TestV4ReadOnBoundToBroadcast(t *testing.T) { } } +// TestReadFromMulticast checks that an endpoint will NOT receive a packet +// that was sent with multicast SOURCE address. +func TestReadFromMulticast(t *testing.T) { + for _, flow := range []testFlow{reverseMulticast4, reverseMulticast6} { + t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpointForFlow(flow) + + if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + testFailingRead(c, flow, false /* expectReadError */) + }) + } +} + +// TestReadFromMulticaststats checks that a discarded packet +// that that was sent with multicast SOURCE address increments +// the correct counters and that a regular packet does not. +func TestReadFromMulticastStats(t *testing.T) { + t.Helper() + for _, flow := range []testFlow{reverseMulticast4, reverseMulticast6, unicastV4} { + t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpointForFlow(flow) + + if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + + payload := newPayload() + c.injectPacket(flow, payload) + + var want uint64 = 0 + if flow.isReverseMulticast() { + want = 1 + } + if got := c.s.Stats().IP.InvalidSourceAddressesReceived.Value(); got != want { + t.Errorf("got stats.IP.InvalidSourceAddressesReceived.Value() = %d, want = %d", got, want) + } + if got := c.s.Stats().UDP.InvalidSourceAddress.Value(); got != want { + t.Errorf("got stats.UDP.InvalidSourceAddress.Value() = %d, want = %d", got, want) + } + if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.MalformedPacketsReceived.Value(); got != want { + t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %d, want = %d", got, want) + } + }) + } +} + // TestV4ReadBroadcastOnBoundToWildcard checks that an endpoint can bind to ANY // and receive broadcast and unicast data. func TestV4ReadBroadcastOnBoundToWildcard(t *testing.T) { @@ -1721,9 +1793,11 @@ func TestIncrementMalformedPacketsReceived(t *testing.T) { payload := newPayload() h := unicastV6.header4Tuple(incoming) buf := c.buildV6Packet(payload, &h) - // Invalidate the packet length field in the UDP header by adding one. + + // Invalidate the UDP header length field. u := header.UDP(buf[header.IPv6MinimumSize:]) u.SetLength(u.Length() + 1) + c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) @@ -1803,9 +1877,16 @@ func TestIncrementChecksumErrorsV4(t *testing.T) { payload := newPayload() h := unicastV4.header4Tuple(incoming) buf := c.buildV4Packet(payload, &h) - // Invalidate the checksum field in the UDP header by adding one. - u := header.UDP(buf[header.IPv4MinimumSize:]) - u.SetChecksum(u.Checksum() + 1) + + // Invalidate the UDP header checksum field, taking care to avoid + // overflow to zero, which would disable checksum validation. + for u := header.UDP(buf[header.IPv4MinimumSize:]); ; { + u.SetChecksum(u.Checksum() + 1) + if u.Checksum() != 0 { + break + } + } + c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) @@ -1834,9 +1915,11 @@ func TestIncrementChecksumErrorsV6(t *testing.T) { payload := newPayload() h := unicastV6.header4Tuple(incoming) buf := c.buildV6Packet(payload, &h) - // Invalidate the checksum field in the UDP header by adding one. + + // Invalidate the UDP header checksum field. u := header.UDP(buf[header.IPv6MinimumSize:]) u.SetChecksum(u.Checksum() + 1) + c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) diff --git a/pkg/test/criutil/criutil.go b/pkg/test/criutil/criutil.go index 8fed29ff5..70945f234 100644 --- a/pkg/test/criutil/criutil.go +++ b/pkg/test/criutil/criutil.go @@ -22,6 +22,9 @@ import ( "fmt" "os" "os/exec" + "path" + "regexp" + "strconv" "strings" "time" @@ -33,28 +36,44 @@ import ( type Crictl struct { logger testutil.Logger endpoint string + runpArgs []string cleanup []func() } -// resolvePath attempts to find binary paths. It may set the path to invalid, +// ResolvePath attempts to find binary paths. It may set the path to invalid, // which will cause the execution to fail with a sensible error. -func resolvePath(executable string) string { +func ResolvePath(executable string) string { + runtime, err := dockerutil.RuntimePath() + if err == nil { + // Check first the directory of the runtime itself. + if dir := path.Dir(runtime); dir != "" && dir != "." { + guess := path.Join(dir, executable) + if fi, err := os.Stat(guess); err == nil && (fi.Mode()&0111) != 0 { + return guess + } + } + } + + // Try to find via the path. guess, err := exec.LookPath(executable) - if err != nil { - guess = fmt.Sprintf("/usr/local/bin/%s", executable) + if err == nil { + return guess } - return guess + + // Return a default path. + return fmt.Sprintf("/usr/local/bin/%s", executable) } // NewCrictl returns a Crictl configured with a timeout and an endpoint over // which it will talk to containerd. -func NewCrictl(logger testutil.Logger, endpoint string) *Crictl { +func NewCrictl(logger testutil.Logger, endpoint string, runpArgs []string) *Crictl { // Attempt to find the executable, but don't bother propagating the // error at this point. The first command executed will return with a // binary not found error. return &Crictl{ logger: logger, endpoint: endpoint, + runpArgs: runpArgs, } } @@ -67,8 +86,8 @@ func (cc *Crictl) CleanUp() { } // RunPod creates a sandbox. It corresponds to `crictl runp`. -func (cc *Crictl) RunPod(sbSpecFile string) (string, error) { - podID, err := cc.run("runp", sbSpecFile) +func (cc *Crictl) RunPod(runtime, sbSpecFile string) (string, error) { + podID, err := cc.run("runp", "--runtime", runtime, sbSpecFile) if err != nil { return "", fmt.Errorf("runp failed: %v", err) } @@ -79,10 +98,42 @@ func (cc *Crictl) RunPod(sbSpecFile string) (string, error) { // Create creates a container within a sandbox. It corresponds to `crictl // create`. func (cc *Crictl) Create(podID, contSpecFile, sbSpecFile string) (string, error) { - podID, err := cc.run("create", podID, contSpecFile, sbSpecFile) + // In version 1.16.0, crictl annoying starting attempting to pull the + // container, even if it was already available locally. We therefore + // need to parse the version and add an appropriate --no-pull argument + // since the image has already been loaded locally. + out, err := cc.run("-v") + if err != nil { + return "", err + } + r := regexp.MustCompile("crictl version ([0-9]+)\\.([0-9]+)\\.([0-9+])") + vs := r.FindStringSubmatch(out) + if len(vs) != 4 { + return "", fmt.Errorf("crictl -v had unexpected output: %s", out) + } + major, err := strconv.ParseUint(vs[1], 10, 64) if err != nil { + return "", fmt.Errorf("crictl had invalid version: %v (%s)", err, out) + } + minor, err := strconv.ParseUint(vs[2], 10, 64) + if err != nil { + return "", fmt.Errorf("crictl had invalid version: %v (%s)", err, out) + } + + args := []string{"create"} + if (major == 1 && minor >= 16) || major > 1 { + args = append(args, "--no-pull") + } + args = append(args, podID) + args = append(args, contSpecFile) + args = append(args, sbSpecFile) + + podID, err = cc.run(args...) + if err != nil { + time.Sleep(10 * time.Minute) // XXX return "", fmt.Errorf("create failed: %v", err) } + // Strip the trailing newline from crictl output. return strings.TrimSpace(podID), nil } @@ -179,7 +230,7 @@ func (cc *Crictl) Import(image string) error { // be pushing a lot of bytes in order to import the image. The connect // timeout stays the same and is inherited from the Crictl instance. cmd := testutil.Command(cc.logger, - resolvePath("ctr"), + ResolvePath("ctr"), fmt.Sprintf("--connect-timeout=%s", 30*time.Second), fmt.Sprintf("--address=%s", cc.endpoint), "-n", "k8s.io", "images", "import", "-") @@ -260,7 +311,7 @@ func (cc *Crictl) StopContainer(contID string) error { // StartPodAndContainer starts a sandbox and container in that sandbox. It // returns the pod ID and container ID. -func (cc *Crictl) StartPodAndContainer(image, sbSpec, contSpec string) (string, string, error) { +func (cc *Crictl) StartPodAndContainer(runtime, image, sbSpec, contSpec string) (string, string, error) { if err := cc.Import(image); err != nil { return "", "", err } @@ -277,7 +328,7 @@ func (cc *Crictl) StartPodAndContainer(image, sbSpec, contSpec string) (string, } cc.cleanup = append(cc.cleanup, cleanup) - podID, err := cc.RunPod(sbSpecFile) + podID, err := cc.RunPod(runtime, sbSpecFile) if err != nil { return "", "", err } @@ -307,7 +358,7 @@ func (cc *Crictl) StopPodAndContainer(podID, contID string) error { // run runs crictl with the given args. func (cc *Crictl) run(args ...string) (string, error) { defaultArgs := []string{ - resolvePath("crictl"), + ResolvePath("crictl"), "--image-endpoint", fmt.Sprintf("unix://%s", cc.endpoint), "--runtime-endpoint", fmt.Sprintf("unix://%s", cc.endpoint), } diff --git a/pkg/test/dockerutil/BUILD b/pkg/test/dockerutil/BUILD index 7c8758e35..83b80c8bc 100644 --- a/pkg/test/dockerutil/BUILD +++ b/pkg/test/dockerutil/BUILD @@ -5,10 +5,21 @@ package(licenses = ["notice"]) go_library( name = "dockerutil", testonly = 1, - srcs = ["dockerutil.go"], + srcs = [ + "container.go", + "dockerutil.go", + "exec.go", + "network.go", + ], visibility = ["//:sandbox"], deps = [ "//pkg/test/testutil", - "@com_github_kr_pty//:go_default_library", + "@com_github_docker_docker//api/types:go_default_library", + "@com_github_docker_docker//api/types/container:go_default_library", + "@com_github_docker_docker//api/types/mount:go_default_library", + "@com_github_docker_docker//api/types/network:go_default_library", + "@com_github_docker_docker//client:go_default_library", + "@com_github_docker_docker//pkg/stdcopy:go_default_library", + "@com_github_docker_go_connections//nat:go_default_library", ], ) diff --git a/pkg/test/dockerutil/container.go b/pkg/test/dockerutil/container.go new file mode 100644 index 000000000..17acdaf6f --- /dev/null +++ b/pkg/test/dockerutil/container.go @@ -0,0 +1,501 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dockerutil + +import ( + "bytes" + "context" + "fmt" + "io/ioutil" + "net" + "os" + "path" + "regexp" + "strconv" + "strings" + "time" + + "github.com/docker/docker/api/types" + "github.com/docker/docker/api/types/container" + "github.com/docker/docker/api/types/mount" + "github.com/docker/docker/api/types/network" + "github.com/docker/docker/client" + "github.com/docker/docker/pkg/stdcopy" + "github.com/docker/go-connections/nat" + "gvisor.dev/gvisor/pkg/test/testutil" +) + +// Container represents a Docker Container allowing +// user to configure and control as one would with the 'docker' +// client. Container is backed by the offical golang docker API. +// See: https://pkg.go.dev/github.com/docker/docker. +type Container struct { + Name string + Runtime string + + logger testutil.Logger + client *client.Client + id string + mounts []mount.Mount + links []string + cleanups []func() + copyErr error + + // Stores streams attached to the container. Used by WaitForOutputSubmatch. + streams types.HijackedResponse + + // stores previously read data from the attached streams. + streamBuf bytes.Buffer +} + +// RunOpts are options for running a container. +type RunOpts struct { + // Image is the image relative to images/. This will be mangled + // appropriately, to ensure that only first-party images are used. + Image string + + // Memory is the memory limit in bytes. + Memory int + + // Cpus in which to allow execution. ("0", "1", "0-2"). + CpusetCpus string + + // Ports are the ports to be allocated. + Ports []int + + // WorkDir sets the working directory. + WorkDir string + + // ReadOnly sets the read-only flag. + ReadOnly bool + + // Env are additional environment variables. + Env []string + + // User is the user to use. + User string + + // Privileged enables privileged mode. + Privileged bool + + // CapAdd are the extra set of capabilities to add. + CapAdd []string + + // CapDrop are the extra set of capabilities to drop. + CapDrop []string + + // Mounts is the list of directories/files to be mounted inside the container. + Mounts []mount.Mount + + // Links is the list of containers to be connected to the container. + Links []string +} + +// MakeContainer sets up the struct for a Docker container. +// +// Names of containers will be unique. +func MakeContainer(ctx context.Context, logger testutil.Logger) *Container { + // Slashes are not allowed in container names. + name := testutil.RandomID(logger.Name()) + name = strings.ReplaceAll(name, "/", "-") + client, err := client.NewClientWithOpts(client.FromEnv) + if err != nil { + return nil + } + + client.NegotiateAPIVersion(ctx) + + return &Container{ + logger: logger, + Name: name, + Runtime: *runtime, + client: client, + } +} + +// Spawn is analogous to 'docker run -d'. +func (c *Container) Spawn(ctx context.Context, r RunOpts, args ...string) error { + if err := c.create(ctx, r, args); err != nil { + return err + } + return c.Start(ctx) +} + +// SpawnProcess is analogous to 'docker run -it'. It returns a process +// which represents the root process. +func (c *Container) SpawnProcess(ctx context.Context, r RunOpts, args ...string) (Process, error) { + config, hostconf, netconf := c.ConfigsFrom(r, args...) + config.Tty = true + config.OpenStdin = true + + if err := c.CreateFrom(ctx, config, hostconf, netconf); err != nil { + return Process{}, err + } + + if err := c.Start(ctx); err != nil { + return Process{}, err + } + + return Process{container: c, conn: c.streams}, nil +} + +// Run is analogous to 'docker run'. +func (c *Container) Run(ctx context.Context, r RunOpts, args ...string) (string, error) { + if err := c.create(ctx, r, args); err != nil { + return "", err + } + + if err := c.Start(ctx); err != nil { + return "", err + } + + if err := c.Wait(ctx); err != nil { + return "", err + } + + return c.Logs(ctx) +} + +// ConfigsFrom returns container configs from RunOpts and args. The caller should call 'CreateFrom' +// and Start. +func (c *Container) ConfigsFrom(r RunOpts, args ...string) (*container.Config, *container.HostConfig, *network.NetworkingConfig) { + return c.config(r, args), c.hostConfig(r), &network.NetworkingConfig{} +} + +// MakeLink formats a link to add to a RunOpts. +func (c *Container) MakeLink(target string) string { + return fmt.Sprintf("%s:%s", c.Name, target) +} + +// CreateFrom creates a container from the given configs. +func (c *Container) CreateFrom(ctx context.Context, conf *container.Config, hostconf *container.HostConfig, netconf *network.NetworkingConfig) error { + cont, err := c.client.ContainerCreate(ctx, conf, hostconf, netconf, c.Name) + if err != nil { + return err + } + c.id = cont.ID + return nil +} + +// Create is analogous to 'docker create'. +func (c *Container) Create(ctx context.Context, r RunOpts, args ...string) error { + return c.create(ctx, r, args) +} + +func (c *Container) create(ctx context.Context, r RunOpts, args []string) error { + conf := c.config(r, args) + hostconf := c.hostConfig(r) + cont, err := c.client.ContainerCreate(ctx, conf, hostconf, nil, c.Name) + if err != nil { + return err + } + c.id = cont.ID + return nil +} + +func (c *Container) config(r RunOpts, args []string) *container.Config { + ports := nat.PortSet{} + for _, p := range r.Ports { + port := nat.Port(fmt.Sprintf("%d", p)) + ports[port] = struct{}{} + } + env := append(r.Env, fmt.Sprintf("RUNSC_TEST_NAME=%s", c.Name)) + + return &container.Config{ + Image: testutil.ImageByName(r.Image), + Cmd: args, + ExposedPorts: ports, + Env: env, + WorkingDir: r.WorkDir, + User: r.User, + } +} + +func (c *Container) hostConfig(r RunOpts) *container.HostConfig { + c.mounts = append(c.mounts, r.Mounts...) + + return &container.HostConfig{ + Runtime: c.Runtime, + Mounts: c.mounts, + PublishAllPorts: true, + Links: r.Links, + CapAdd: r.CapAdd, + CapDrop: r.CapDrop, + Privileged: r.Privileged, + ReadonlyRootfs: r.ReadOnly, + Resources: container.Resources{ + Memory: int64(r.Memory), // In bytes. + CpusetCpus: r.CpusetCpus, + }, + } +} + +// Start is analogous to 'docker start'. +func (c *Container) Start(ctx context.Context) error { + + // Open a connection to the container for parsing logs and for TTY. + streams, err := c.client.ContainerAttach(ctx, c.id, + types.ContainerAttachOptions{ + Stream: true, + Stdin: true, + Stdout: true, + Stderr: true, + }) + if err != nil { + return fmt.Errorf("failed to connect to container: %v", err) + } + + c.streams = streams + c.cleanups = append(c.cleanups, func() { + c.streams.Close() + }) + + return c.client.ContainerStart(ctx, c.id, types.ContainerStartOptions{}) +} + +// Stop is analogous to 'docker stop'. +func (c *Container) Stop(ctx context.Context) error { + return c.client.ContainerStop(ctx, c.id, nil) +} + +// Pause is analogous to'docker pause'. +func (c *Container) Pause(ctx context.Context) error { + return c.client.ContainerPause(ctx, c.id) +} + +// Unpause is analogous to 'docker unpause'. +func (c *Container) Unpause(ctx context.Context) error { + return c.client.ContainerUnpause(ctx, c.id) +} + +// Checkpoint is analogous to 'docker checkpoint'. +func (c *Container) Checkpoint(ctx context.Context, name string) error { + return c.client.CheckpointCreate(ctx, c.Name, types.CheckpointCreateOptions{CheckpointID: name, Exit: true}) +} + +// Restore is analogous to 'docker start --checkname [name]'. +func (c *Container) Restore(ctx context.Context, name string) error { + return c.client.ContainerStart(ctx, c.id, types.ContainerStartOptions{CheckpointID: name}) +} + +// Logs is analogous 'docker logs'. +func (c *Container) Logs(ctx context.Context) (string, error) { + var out bytes.Buffer + err := c.logs(ctx, &out, &out) + return out.String(), err +} + +func (c *Container) logs(ctx context.Context, stdout, stderr *bytes.Buffer) error { + opts := types.ContainerLogsOptions{ShowStdout: true, ShowStderr: true} + writer, err := c.client.ContainerLogs(ctx, c.id, opts) + if err != nil { + return err + } + defer writer.Close() + _, err = stdcopy.StdCopy(stdout, stderr, writer) + + return err +} + +// ID returns the container id. +func (c *Container) ID() string { + return c.id +} + +// SandboxPid returns the container's pid. +func (c *Container) SandboxPid(ctx context.Context) (int, error) { + resp, err := c.client.ContainerInspect(ctx, c.id) + if err != nil { + return -1, err + } + return resp.ContainerJSONBase.State.Pid, nil +} + +// FindIP returns the IP address of the container. +func (c *Container) FindIP(ctx context.Context) (net.IP, error) { + resp, err := c.client.ContainerInspect(ctx, c.id) + if err != nil { + return nil, err + } + + ip := net.ParseIP(resp.NetworkSettings.DefaultNetworkSettings.IPAddress) + if ip == nil { + return net.IP{}, fmt.Errorf("invalid IP: %q", ip) + } + return ip, nil +} + +// FindPort returns the host port that is mapped to 'sandboxPort'. +func (c *Container) FindPort(ctx context.Context, sandboxPort int) (int, error) { + desc, err := c.client.ContainerInspect(ctx, c.id) + if err != nil { + return -1, fmt.Errorf("error retrieving port: %v", err) + } + + format := fmt.Sprintf("%d/tcp", sandboxPort) + ports, ok := desc.NetworkSettings.Ports[nat.Port(format)] + if !ok { + return -1, fmt.Errorf("error retrieving port: %v", err) + + } + + port, err := strconv.Atoi(ports[0].HostPort) + if err != nil { + return -1, fmt.Errorf("error parsing port %q: %v", port, err) + } + return port, nil +} + +// CopyFiles copies in and mounts the given files. They are always ReadOnly. +func (c *Container) CopyFiles(opts *RunOpts, target string, sources ...string) { + dir, err := ioutil.TempDir("", c.Name) + if err != nil { + c.copyErr = fmt.Errorf("ioutil.TempDir failed: %v", err) + return + } + c.cleanups = append(c.cleanups, func() { os.RemoveAll(dir) }) + if err := os.Chmod(dir, 0755); err != nil { + c.copyErr = fmt.Errorf("os.Chmod(%q, 0755) failed: %v", dir, err) + return + } + for _, name := range sources { + src, err := testutil.FindFile(name) + if err != nil { + c.copyErr = fmt.Errorf("testutil.FindFile(%q) failed: %v", name, err) + return + } + dst := path.Join(dir, path.Base(name)) + if err := testutil.Copy(src, dst); err != nil { + c.copyErr = fmt.Errorf("testutil.Copy(%q, %q) failed: %v", src, dst, err) + return + } + c.logger.Logf("copy: %s -> %s", src, dst) + } + opts.Mounts = append(opts.Mounts, mount.Mount{ + Type: mount.TypeBind, + Source: dir, + Target: target, + ReadOnly: false, + }) +} + +// Status inspects the container returns its status. +func (c *Container) Status(ctx context.Context) (types.ContainerState, error) { + resp, err := c.client.ContainerInspect(ctx, c.id) + if err != nil { + return types.ContainerState{}, err + } + return *resp.State, err +} + +// Wait waits for the container to exit. +func (c *Container) Wait(ctx context.Context) error { + statusChan, errChan := c.client.ContainerWait(ctx, c.id, container.WaitConditionNotRunning) + select { + case err := <-errChan: + return err + case <-statusChan: + return nil + } +} + +// WaitTimeout waits for the container to exit with a timeout. +func (c *Container) WaitTimeout(ctx context.Context, timeout time.Duration) error { + timeoutChan := time.After(timeout) + statusChan, errChan := c.client.ContainerWait(ctx, c.id, container.WaitConditionNotRunning) + select { + case err := <-errChan: + return err + case <-statusChan: + return nil + case <-timeoutChan: + return fmt.Errorf("container %s timed out after %v seconds", c.Name, timeout.Seconds()) + } +} + +// WaitForOutput searches container logs for pattern and returns or timesout. +func (c *Container) WaitForOutput(ctx context.Context, pattern string, timeout time.Duration) (string, error) { + matches, err := c.WaitForOutputSubmatch(ctx, pattern, timeout) + if err != nil { + return "", err + } + if len(matches) == 0 { + return "", fmt.Errorf("didn't find pattern %s logs", pattern) + } + return matches[0], nil +} + +// WaitForOutputSubmatch searches container logs for the given +// pattern or times out. It returns any regexp submatches as well. +func (c *Container) WaitForOutputSubmatch(ctx context.Context, pattern string, timeout time.Duration) ([]string, error) { + re := regexp.MustCompile(pattern) + if matches := re.FindStringSubmatch(c.streamBuf.String()); matches != nil { + return matches, nil + } + + for exp := time.Now().Add(timeout); time.Now().Before(exp); { + c.streams.Conn.SetDeadline(time.Now().Add(50 * time.Millisecond)) + _, err := stdcopy.StdCopy(&c.streamBuf, &c.streamBuf, c.streams.Reader) + + if err != nil { + // check that it wasn't a timeout + if nerr, ok := err.(net.Error); !ok || !nerr.Timeout() { + return nil, err + } + } + + if matches := re.FindStringSubmatch(c.streamBuf.String()); matches != nil { + return matches, nil + } + } + + return nil, fmt.Errorf("timeout waiting for output %q: out: %s", re.String(), c.streamBuf.String()) +} + +// Kill kills the container. +func (c *Container) Kill(ctx context.Context) error { + return c.client.ContainerKill(ctx, c.id, "") +} + +// Remove is analogous to 'docker rm'. +func (c *Container) Remove(ctx context.Context) error { + // Remove the image. + remove := types.ContainerRemoveOptions{ + RemoveVolumes: c.mounts != nil, + RemoveLinks: c.links != nil, + Force: true, + } + return c.client.ContainerRemove(ctx, c.Name, remove) +} + +// CleanUp kills and deletes the container (best effort). +func (c *Container) CleanUp(ctx context.Context) { + // Kill the container. + if err := c.Kill(ctx); err != nil && !strings.Contains(err.Error(), "is not running") { + // Just log; can't do anything here. + c.logger.Logf("error killing container %q: %v", c.Name, err) + } + // Remove the image. + if err := c.Remove(ctx); err != nil { + c.logger.Logf("error removing container %q: %v", c.Name, err) + } + // Forget all mounts. + c.mounts = nil + // Execute all cleanups. + for _, c := range c.cleanups { + c() + } + c.cleanups = nil +} diff --git a/pkg/test/dockerutil/dockerutil.go b/pkg/test/dockerutil/dockerutil.go index 819dd0a59..df09babf3 100644 --- a/pkg/test/dockerutil/dockerutil.go +++ b/pkg/test/dockerutil/dockerutil.go @@ -22,17 +22,10 @@ import ( "io" "io/ioutil" "log" - "net" - "os" "os/exec" - "path" "regexp" "strconv" - "strings" - "syscall" - "time" - "github.com/kr/pty" "gvisor.dev/gvisor/pkg/test/testutil" ) @@ -127,595 +120,7 @@ func Save(logger testutil.Logger, image string, w io.Writer) error { return cmd.Run() } -// MountMode describes if the mount should be ro or rw. -type MountMode int - -const ( - // ReadOnly is what the name says. - ReadOnly MountMode = iota - // ReadWrite is what the name says. - ReadWrite -) - -// String returns the mount mode argument for this MountMode. -func (m MountMode) String() string { - switch m { - case ReadOnly: - return "ro" - case ReadWrite: - return "rw" - } - panic(fmt.Sprintf("invalid mode: %d", m)) -} - -// DockerNetwork contains the name of a docker network. -type DockerNetwork struct { - logger testutil.Logger - Name string - Subnet *net.IPNet - containers []*Docker -} - -// NewDockerNetwork sets up the struct for a Docker network. Names of networks -// will be unique. -func NewDockerNetwork(logger testutil.Logger) *DockerNetwork { - return &DockerNetwork{ - logger: logger, - Name: testutil.RandomID(logger.Name()), - } -} - -// Create calls 'docker network create'. -func (n *DockerNetwork) Create(args ...string) error { - a := []string{"docker", "network", "create"} - if n.Subnet != nil { - a = append(a, fmt.Sprintf("--subnet=%s", n.Subnet)) - } - a = append(a, args...) - a = append(a, n.Name) - return testutil.Command(n.logger, a...).Run() -} - -// Connect calls 'docker network connect' with the arguments provided. -func (n *DockerNetwork) Connect(container *Docker, args ...string) error { - a := []string{"docker", "network", "connect"} - a = append(a, args...) - a = append(a, n.Name, container.Name) - if err := testutil.Command(n.logger, a...).Run(); err != nil { - return err - } - n.containers = append(n.containers, container) - return nil -} - -// Cleanup cleans up the docker network and all the containers attached to it. -func (n *DockerNetwork) Cleanup() error { - for _, c := range n.containers { - // Don't propagate the error, it might be that the container - // was already cleaned up. - if err := c.Kill(); err != nil { - n.logger.Logf("unable to kill container during cleanup: %s", err) - } - } - - if err := testutil.Command(n.logger, "docker", "network", "rm", n.Name).Run(); err != nil { - return err - } - return nil -} - -// Docker contains the name and the runtime of a docker container. -type Docker struct { - logger testutil.Logger - Runtime string - Name string - copyErr error - cleanups []func() -} - -// MakeDocker sets up the struct for a Docker container. -// -// Names of containers will be unique. -func MakeDocker(logger testutil.Logger) *Docker { - // Slashes are not allowed in container names. - name := testutil.RandomID(logger.Name()) - name = strings.ReplaceAll(name, "/", "-") - - return &Docker{ - logger: logger, - Name: name, - Runtime: *runtime, - } -} - -// CopyFiles copies in and mounts the given files. They are always ReadOnly. -func (d *Docker) CopyFiles(opts *RunOpts, targetDir string, sources ...string) { - dir, err := ioutil.TempDir("", d.Name) - if err != nil { - d.copyErr = fmt.Errorf("ioutil.TempDir failed: %v", err) - return - } - d.cleanups = append(d.cleanups, func() { os.RemoveAll(dir) }) - if err := os.Chmod(dir, 0755); err != nil { - d.copyErr = fmt.Errorf("os.Chmod(%q, 0755) failed: %v", dir, err) - return - } - for _, name := range sources { - src, err := testutil.FindFile(name) - if err != nil { - d.copyErr = fmt.Errorf("testutil.FindFile(%q) failed: %v", name, err) - return - } - dst := path.Join(dir, path.Base(name)) - if err := testutil.Copy(src, dst); err != nil { - d.copyErr = fmt.Errorf("testutil.Copy(%q, %q) failed: %v", src, dst, err) - return - } - d.logger.Logf("copy: %s -> %s", src, dst) - } - opts.Mounts = append(opts.Mounts, Mount{ - Source: dir, - Target: targetDir, - Mode: ReadOnly, - }) -} - -// Mount describes a mount point inside the container. -type Mount struct { - // Source is the path outside the container. - Source string - - // Target is the path inside the container. - Target string - - // Mode tells whether the mount inside the container should be readonly. - Mode MountMode -} - -// Link informs dockers that a given container needs to be made accessible from -// the container being configured. -type Link struct { - // Source is the container to connect to. - Source *Docker - - // Target is the alias for the container. - Target string -} - -// RunOpts are options for running a container. -type RunOpts struct { - // Image is the image relative to images/. This will be mangled - // appropriately, to ensure that only first-party images are used. - Image string - - // Memory is the memory limit in kB. - Memory int - - // Ports are the ports to be allocated. - Ports []int - - // WorkDir sets the working directory. - WorkDir string - - // ReadOnly sets the read-only flag. - ReadOnly bool - - // Env are additional environment variables. - Env []string - - // User is the user to use. - User string - - // Privileged enables privileged mode. - Privileged bool - - // CapAdd are the extra set of capabilities to add. - CapAdd []string - - // CapDrop are the extra set of capabilities to drop. - CapDrop []string - - // Pty indicates that a pty will be allocated. If this is non-nil, then - // this will run after start-up with the *exec.Command and Pty file - // passed in to the function. - Pty func(*exec.Cmd, *os.File) - - // Foreground indicates that the container should be run in the - // foreground. If this is true, then the output will be available as a - // return value from the Run function. - Foreground bool - - // Mounts is the list of directories/files to be mounted inside the container. - Mounts []Mount - - // Links is the list of containers to be connected to the container. - Links []Link - - // Extra are extra arguments that may be passed. - Extra []string -} - -// args returns common arguments. -// -// Note that this does not define the complete behavior. -func (d *Docker) argsFor(r *RunOpts, command string, p []string) (rv []string) { - isExec := command == "exec" - isRun := command == "run" - - if isRun || isExec { - rv = append(rv, "-i") - } - if r.Pty != nil { - rv = append(rv, "-t") - } - if r.User != "" { - rv = append(rv, fmt.Sprintf("--user=%s", r.User)) - } - if r.Privileged { - rv = append(rv, "--privileged") - } - for _, c := range r.CapAdd { - rv = append(rv, fmt.Sprintf("--cap-add=%s", c)) - } - for _, c := range r.CapDrop { - rv = append(rv, fmt.Sprintf("--cap-drop=%s", c)) - } - for _, e := range r.Env { - rv = append(rv, fmt.Sprintf("--env=%s", e)) - } - if r.WorkDir != "" { - rv = append(rv, fmt.Sprintf("--workdir=%s", r.WorkDir)) - } - if !isExec { - if r.Memory != 0 { - rv = append(rv, fmt.Sprintf("--memory=%dk", r.Memory)) - } - for _, p := range r.Ports { - rv = append(rv, fmt.Sprintf("--publish=%d", p)) - } - if r.ReadOnly { - rv = append(rv, fmt.Sprintf("--read-only")) - } - if len(p) > 0 { - rv = append(rv, "--entrypoint=") - } - } - - // Always attach the test environment & Extra. - rv = append(rv, fmt.Sprintf("--env=RUNSC_TEST_NAME=%s", d.Name)) - rv = append(rv, r.Extra...) - - // Attach necessary bits. - if isExec { - rv = append(rv, d.Name) - } else { - for _, m := range r.Mounts { - rv = append(rv, fmt.Sprintf("-v=%s:%s:%v", m.Source, m.Target, m.Mode)) - } - for _, l := range r.Links { - rv = append(rv, fmt.Sprintf("--link=%s:%s", l.Source.Name, l.Target)) - } - - if len(d.Runtime) > 0 { - rv = append(rv, fmt.Sprintf("--runtime=%s", d.Runtime)) - } - rv = append(rv, fmt.Sprintf("--name=%s", d.Name)) - rv = append(rv, testutil.ImageByName(r.Image)) - } - - // Attach other arguments. - rv = append(rv, p...) - return rv -} - -// run runs a complete command. -func (d *Docker) run(r RunOpts, command string, p ...string) (string, error) { - if d.copyErr != nil { - return "", d.copyErr - } - basicArgs := []string{"docker"} - if command == "spawn" { - command = "run" - basicArgs = append(basicArgs, command) - basicArgs = append(basicArgs, "-d") - } else { - basicArgs = append(basicArgs, command) - } - customArgs := d.argsFor(&r, command, p) - cmd := testutil.Command(d.logger, append(basicArgs, customArgs...)...) - if r.Pty != nil { - // If allocating a terminal, then we just ignore the output - // from the command. - ptmx, err := pty.Start(cmd.Cmd) - if err != nil { - return "", err - } - defer cmd.Wait() // Best effort. - r.Pty(cmd.Cmd, ptmx) - } else { - // Can't support PTY or streaming. - out, err := cmd.CombinedOutput() - return string(out), err - } - return "", nil -} - -// Create calls 'docker create' with the arguments provided. -func (d *Docker) Create(r RunOpts, args ...string) error { - out, err := d.run(r, "create", args...) - if strings.Contains(out, "Unable to find image") { - return fmt.Errorf("unable to find image, did you remember to `make load-%s`: %w", r.Image, err) - } - return err -} - -// Start calls 'docker start'. -func (d *Docker) Start() error { - return testutil.Command(d.logger, "docker", "start", d.Name).Run() -} - -// Stop calls 'docker stop'. -func (d *Docker) Stop() error { - return testutil.Command(d.logger, "docker", "stop", d.Name).Run() -} - -// Run calls 'docker run' with the arguments provided. -func (d *Docker) Run(r RunOpts, args ...string) (string, error) { - return d.run(r, "run", args...) -} - -// Spawn starts the container and detaches. -func (d *Docker) Spawn(r RunOpts, args ...string) error { - _, err := d.run(r, "spawn", args...) - return err -} - -// Logs calls 'docker logs'. -func (d *Docker) Logs() (string, error) { - // Don't capture the output; since it will swamp the logs. - out, err := exec.Command("docker", "logs", d.Name).CombinedOutput() - return string(out), err -} - -// Exec calls 'docker exec' with the arguments provided. -func (d *Docker) Exec(r RunOpts, args ...string) (string, error) { - return d.run(r, "exec", args...) -} - -// Pause calls 'docker pause'. -func (d *Docker) Pause() error { - return testutil.Command(d.logger, "docker", "pause", d.Name).Run() -} - -// Unpause calls 'docker pause'. -func (d *Docker) Unpause() error { - return testutil.Command(d.logger, "docker", "unpause", d.Name).Run() -} - -// Checkpoint calls 'docker checkpoint'. -func (d *Docker) Checkpoint(name string) error { - return testutil.Command(d.logger, "docker", "checkpoint", "create", d.Name, name).Run() -} - -// Restore calls 'docker start --checkname [name]'. -func (d *Docker) Restore(name string) error { - return testutil.Command(d.logger, "docker", "start", fmt.Sprintf("--checkpoint=%s", name), d.Name).Run() -} - -// Kill calls 'docker kill'. -func (d *Docker) Kill() error { - // Skip logging this command, it will likely be an error. - out, err := exec.Command("docker", "kill", d.Name).CombinedOutput() - if err != nil && !strings.Contains(string(out), "is not running") { - return err - } - return nil -} - -// Remove calls 'docker rm'. -func (d *Docker) Remove() error { - return testutil.Command(d.logger, "docker", "rm", d.Name).Run() -} - -// CleanUp kills and deletes the container (best effort). -func (d *Docker) CleanUp() { - // Kill the container. - if err := d.Kill(); err != nil { - // Just log; can't do anything here. - d.logger.Logf("error killing container %q: %v", d.Name, err) - } - // Remove the image. - if err := d.Remove(); err != nil { - d.logger.Logf("error removing container %q: %v", d.Name, err) - } - // Execute all cleanups. - for _, c := range d.cleanups { - c() - } - d.cleanups = nil -} - -// FindPort returns the host port that is mapped to 'sandboxPort'. This calls -// docker to allocate a free port in the host and prevent conflicts. -func (d *Docker) FindPort(sandboxPort int) (int, error) { - format := fmt.Sprintf(`{{ (index (index .NetworkSettings.Ports "%d/tcp") 0).HostPort }}`, sandboxPort) - out, err := testutil.Command(d.logger, "docker", "inspect", "-f", format, d.Name).CombinedOutput() - if err != nil { - return -1, fmt.Errorf("error retrieving port: %v", err) - } - port, err := strconv.Atoi(strings.TrimSuffix(string(out), "\n")) - if err != nil { - return -1, fmt.Errorf("error parsing port %q: %v", out, err) - } - return port, nil -} - -// FindIP returns the IP address of the container. -func (d *Docker) FindIP() (net.IP, error) { - const format = `{{range .NetworkSettings.Networks}}{{.IPAddress}}{{end}}` - out, err := testutil.Command(d.logger, "docker", "inspect", "-f", format, d.Name).CombinedOutput() - if err != nil { - return net.IP{}, fmt.Errorf("error retrieving IP: %v", err) - } - ip := net.ParseIP(strings.TrimSpace(string(out))) - if ip == nil { - return net.IP{}, fmt.Errorf("invalid IP: %q", string(out)) - } - return ip, nil -} - -// A NetworkInterface is container's network interface information. -type NetworkInterface struct { - IPv4 net.IP - MAC net.HardwareAddr -} - -// ListNetworks returns the network interfaces of the container, keyed by -// Docker network name. -func (d *Docker) ListNetworks() (map[string]NetworkInterface, error) { - const format = `{{json .NetworkSettings.Networks}}` - out, err := testutil.Command(d.logger, "docker", "inspect", "-f", format, d.Name).CombinedOutput() - if err != nil { - return nil, fmt.Errorf("error network interfaces: %q: %w", string(out), err) - } - - networks := map[string]map[string]string{} - if err := json.Unmarshal(out, &networks); err != nil { - return nil, fmt.Errorf("error decoding network interfaces: %w", err) - } - - interfaces := map[string]NetworkInterface{} - for name, iface := range networks { - var netface NetworkInterface - - rawIP := strings.TrimSpace(iface["IPAddress"]) - if rawIP != "" { - ip := net.ParseIP(rawIP) - if ip == nil { - return nil, fmt.Errorf("invalid IP: %q", rawIP) - } - // Docker's IPAddress field is IPv4. The IPv6 address - // is stored in the GlobalIPv6Address field. - netface.IPv4 = ip - } - - rawMAC := strings.TrimSpace(iface["MacAddress"]) - if rawMAC != "" { - mac, err := net.ParseMAC(rawMAC) - if err != nil { - return nil, fmt.Errorf("invalid MAC: %q: %w", rawMAC, err) - } - netface.MAC = mac - } - - interfaces[name] = netface - } - - return interfaces, nil -} - -// SandboxPid returns the PID to the sandbox process. -func (d *Docker) SandboxPid() (int, error) { - out, err := testutil.Command(d.logger, "docker", "inspect", "-f={{.State.Pid}}", d.Name).CombinedOutput() - if err != nil { - return -1, fmt.Errorf("error retrieving pid: %v", err) - } - pid, err := strconv.Atoi(strings.TrimSuffix(string(out), "\n")) - if err != nil { - return -1, fmt.Errorf("error parsing pid %q: %v", out, err) - } - return pid, nil -} - -// ID returns the container ID. -func (d *Docker) ID() (string, error) { - out, err := testutil.Command(d.logger, "docker", "inspect", "-f={{.Id}}", d.Name).CombinedOutput() - if err != nil { - return "", fmt.Errorf("error retrieving ID: %v", err) - } - return strings.TrimSpace(string(out)), nil -} - -// Wait waits for container to exit, up to the given timeout. Returns error if -// wait fails or timeout is hit. Returns the application return code otherwise. -// Note that the application may have failed even if err == nil, always check -// the exit code. -func (d *Docker) Wait(timeout time.Duration) (syscall.WaitStatus, error) { - timeoutChan := time.After(timeout) - waitChan := make(chan (syscall.WaitStatus)) - errChan := make(chan (error)) - - go func() { - out, err := testutil.Command(d.logger, "docker", "wait", d.Name).CombinedOutput() - if err != nil { - errChan <- fmt.Errorf("error waiting for container %q: %v", d.Name, err) - } - exit, err := strconv.Atoi(strings.TrimSuffix(string(out), "\n")) - if err != nil { - errChan <- fmt.Errorf("error parsing exit code %q: %v", out, err) - } - waitChan <- syscall.WaitStatus(uint32(exit)) - }() - - select { - case ws := <-waitChan: - return ws, nil - case err := <-errChan: - return syscall.WaitStatus(1), err - case <-timeoutChan: - return syscall.WaitStatus(1), fmt.Errorf("timeout waiting for container %q", d.Name) - } -} - -// WaitForOutput calls 'docker logs' to retrieve containers output and searches -// for the given pattern. -func (d *Docker) WaitForOutput(pattern string, timeout time.Duration) (string, error) { - matches, err := d.WaitForOutputSubmatch(pattern, timeout) - if err != nil { - return "", err - } - if len(matches) == 0 { - return "", nil - } - return matches[0], nil -} - -// WaitForOutputSubmatch calls 'docker logs' to retrieve containers output and -// searches for the given pattern. It returns any regexp submatches as well. -func (d *Docker) WaitForOutputSubmatch(pattern string, timeout time.Duration) ([]string, error) { - re := regexp.MustCompile(pattern) - var ( - lastOut string - stopped bool - ) - for exp := time.Now().Add(timeout); time.Now().Before(exp); { - out, err := d.Logs() - if err != nil { - return nil, err - } - if out != lastOut { - if lastOut == "" { - d.logger.Logf("output (start): %s", out) - } else if strings.HasPrefix(out, lastOut) { - d.logger.Logf("output (contn): %s", out[len(lastOut):]) - } else { - d.logger.Logf("output (trunc): %s", out) - } - lastOut = out // Save for future. - if matches := re.FindStringSubmatch(lastOut); matches != nil { - return matches, nil // Success! - } - } else if stopped { - // The sandbox stopped and we looked at the - // logs at least once since determining that. - return nil, fmt.Errorf("no longer running: %v", err) - } else if pid, err := d.SandboxPid(); pid == 0 || err != nil { - // The sandbox may have stopped, but it's - // possible that it has emitted the terminal - // line between the last call to Logs and here. - stopped = true - } - time.Sleep(100 * time.Millisecond) - } - return nil, fmt.Errorf("timeout waiting for output %q: %s", re.String(), lastOut) +// Runtime returns the value of the flag runtime. +func Runtime() string { + return *runtime } diff --git a/pkg/test/dockerutil/exec.go b/pkg/test/dockerutil/exec.go new file mode 100644 index 000000000..4c739c9e9 --- /dev/null +++ b/pkg/test/dockerutil/exec.go @@ -0,0 +1,193 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dockerutil + +import ( + "bytes" + "context" + "fmt" + "time" + + "github.com/docker/docker/api/types" + "github.com/docker/docker/pkg/stdcopy" +) + +// ExecOpts holds arguments for Exec calls. +type ExecOpts struct { + // Env are additional environment variables. + Env []string + + // Privileged enables privileged mode. + Privileged bool + + // User is the user to use. + User string + + // Enables Tty and stdin for the created process. + UseTTY bool + + // WorkDir is the working directory of the process. + WorkDir string +} + +// Exec creates a process inside the container. +func (c *Container) Exec(ctx context.Context, opts ExecOpts, args ...string) (string, error) { + p, err := c.doExec(ctx, opts, args) + if err != nil { + return "", err + } + + if exitStatus, err := p.WaitExitStatus(ctx); err != nil { + return "", err + } else if exitStatus != 0 { + out, _ := p.Logs() + return out, fmt.Errorf("process terminated with status: %d", exitStatus) + } + + return p.Logs() +} + +// ExecProcess creates a process inside the container and returns a process struct +// for the caller to use. +func (c *Container) ExecProcess(ctx context.Context, opts ExecOpts, args ...string) (Process, error) { + return c.doExec(ctx, opts, args) +} + +func (c *Container) doExec(ctx context.Context, r ExecOpts, args []string) (Process, error) { + config := c.execConfig(r, args) + resp, err := c.client.ContainerExecCreate(ctx, c.id, config) + if err != nil { + return Process{}, fmt.Errorf("exec create failed with err: %v", err) + } + + hijack, err := c.client.ContainerExecAttach(ctx, resp.ID, types.ExecStartCheck{}) + if err != nil { + return Process{}, fmt.Errorf("exec attach failed with err: %v", err) + } + + if err := c.client.ContainerExecStart(ctx, resp.ID, types.ExecStartCheck{}); err != nil { + hijack.Close() + return Process{}, fmt.Errorf("exec start failed with err: %v", err) + } + + return Process{ + container: c, + execid: resp.ID, + conn: hijack, + }, nil +} + +func (c *Container) execConfig(r ExecOpts, cmd []string) types.ExecConfig { + env := append(r.Env, fmt.Sprintf("RUNSC_TEST_NAME=%s", c.Name)) + return types.ExecConfig{ + AttachStdin: r.UseTTY, + AttachStderr: true, + AttachStdout: true, + Cmd: cmd, + Privileged: r.Privileged, + WorkingDir: r.WorkDir, + Env: env, + Tty: r.UseTTY, + User: r.User, + } + +} + +// Process represents a containerized process. +type Process struct { + container *Container + execid string + conn types.HijackedResponse +} + +// Write writes buf to the process's stdin. +func (p *Process) Write(timeout time.Duration, buf []byte) (int, error) { + p.conn.Conn.SetDeadline(time.Now().Add(timeout)) + return p.conn.Conn.Write(buf) +} + +// Read returns process's stdout and stderr. +func (p *Process) Read() (string, string, error) { + var stdout, stderr bytes.Buffer + if err := p.read(&stdout, &stderr); err != nil { + return "", "", err + } + return stdout.String(), stderr.String(), nil +} + +// Logs returns combined stdout/stderr from the process. +func (p *Process) Logs() (string, error) { + var out bytes.Buffer + if err := p.read(&out, &out); err != nil { + return "", err + } + return out.String(), nil +} + +func (p *Process) read(stdout, stderr *bytes.Buffer) error { + _, err := stdcopy.StdCopy(stdout, stderr, p.conn.Reader) + return err +} + +// ExitCode returns the process's exit code. +func (p *Process) ExitCode(ctx context.Context) (int, error) { + _, exitCode, err := p.runningExitCode(ctx) + return exitCode, err +} + +// IsRunning checks if the process is running. +func (p *Process) IsRunning(ctx context.Context) (bool, error) { + running, _, err := p.runningExitCode(ctx) + return running, err +} + +// WaitExitStatus until process completes and returns exit status. +func (p *Process) WaitExitStatus(ctx context.Context) (int, error) { + waitChan := make(chan (int)) + errChan := make(chan (error)) + + go func() { + for { + running, exitcode, err := p.runningExitCode(ctx) + if err != nil { + errChan <- fmt.Errorf("error waiting process %s: container %v", p.execid, p.container.Name) + } + if !running { + waitChan <- exitcode + } + time.Sleep(time.Millisecond * 500) + } + }() + + select { + case ws := <-waitChan: + return ws, nil + case err := <-errChan: + return -1, err + } +} + +// runningExitCode collects if the process is running and the exit code. +// The exit code is only valid if the process has exited. +func (p *Process) runningExitCode(ctx context.Context) (bool, int, error) { + // If execid is not empty, this is a execed process. + if p.execid != "" { + status, err := p.container.client.ContainerExecInspect(ctx, p.execid) + return status.Running, status.ExitCode, err + } + // else this is the root process. + status, err := p.container.Status(ctx) + return status.Running, status.ExitCode, err +} diff --git a/pkg/test/dockerutil/network.go b/pkg/test/dockerutil/network.go new file mode 100644 index 000000000..047091e75 --- /dev/null +++ b/pkg/test/dockerutil/network.go @@ -0,0 +1,113 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dockerutil + +import ( + "context" + "net" + + "github.com/docker/docker/api/types" + "github.com/docker/docker/api/types/network" + "github.com/docker/docker/client" + "gvisor.dev/gvisor/pkg/test/testutil" +) + +// Network is a docker network. +type Network struct { + client *client.Client + id string + logger testutil.Logger + Name string + containers []*Container + Subnet *net.IPNet +} + +// NewNetwork sets up the struct for a Docker network. Names of networks +// will be unique. +func NewNetwork(ctx context.Context, logger testutil.Logger) *Network { + client, err := client.NewClientWithOpts(client.FromEnv) + if err != nil { + logger.Logf("create client failed with: %v", err) + return nil + } + client.NegotiateAPIVersion(ctx) + + return &Network{ + logger: logger, + Name: testutil.RandomID(logger.Name()), + client: client, + } +} + +func (n *Network) networkCreate() types.NetworkCreate { + + var subnet string + if n.Subnet != nil { + subnet = n.Subnet.String() + } + + ipam := network.IPAM{ + Config: []network.IPAMConfig{{ + Subnet: subnet, + }}, + } + + return types.NetworkCreate{ + CheckDuplicate: true, + IPAM: &ipam, + } +} + +// Create is analogous to 'docker network create'. +func (n *Network) Create(ctx context.Context) error { + + opts := n.networkCreate() + resp, err := n.client.NetworkCreate(ctx, n.Name, opts) + if err != nil { + return err + } + n.id = resp.ID + return nil +} + +// Connect is analogous to 'docker network connect' with the arguments provided. +func (n *Network) Connect(ctx context.Context, container *Container, ipv4, ipv6 string) error { + settings := network.EndpointSettings{ + IPAMConfig: &network.EndpointIPAMConfig{ + IPv4Address: ipv4, + IPv6Address: ipv6, + }, + } + err := n.client.NetworkConnect(ctx, n.id, container.id, &settings) + if err == nil { + n.containers = append(n.containers, container) + } + return err +} + +// Inspect returns this network's info. +func (n *Network) Inspect(ctx context.Context) (types.NetworkResource, error) { + return n.client.NetworkInspect(ctx, n.id, types.NetworkInspectOptions{Verbose: true}) +} + +// Cleanup cleans up the docker network and all the containers attached to it. +func (n *Network) Cleanup(ctx context.Context) error { + for _, c := range n.containers { + c.CleanUp(ctx) + } + n.containers = nil + + return n.client.NetworkRemove(ctx, n.id) +} diff --git a/pkg/test/testutil/BUILD b/pkg/test/testutil/BUILD index 03b1b4677..2d8f56bc0 100644 --- a/pkg/test/testutil/BUILD +++ b/pkg/test/testutil/BUILD @@ -15,6 +15,6 @@ go_library( "//runsc/boot", "//runsc/specutils", "@com_github_cenkalti_backoff//:go_default_library", - "@com_github_opencontainers_runtime-spec//specs-go:go_default_library", + "@com_github_opencontainers_runtime_spec//specs-go:go_default_library", ], ) diff --git a/pkg/test/testutil/testutil.go b/pkg/test/testutil/testutil.go index f21d6769a..64c292698 100644 --- a/pkg/test/testutil/testutil.go +++ b/pkg/test/testutil/testutil.go @@ -482,6 +482,21 @@ func IsStatic(filename string) (bool, error) { return true, nil } +// TouchShardStatusFile indicates to Bazel that the test runner supports +// sharding by creating or updating the last modified date of the file +// specified by TEST_SHARD_STATUS_FILE. +// +// See https://docs.bazel.build/versions/master/test-encyclopedia.html#role-of-the-test-runner. +func TouchShardStatusFile() error { + if statusFile := os.Getenv("TEST_SHARD_STATUS_FILE"); statusFile != "" { + cmd := exec.Command("touch", statusFile) + if b, err := cmd.CombinedOutput(); err != nil { + return fmt.Errorf("touch %q failed:\n output: %s\n error: %s", statusFile, string(b), err.Error()) + } + } + return nil +} + // TestIndicesForShard returns indices for this test shard based on the // TEST_SHARD_INDEX and TEST_TOTAL_SHARDS environment vars. // |