Diffstat (limited to 'pkg')
244 files changed, 11525 insertions, 3597 deletions
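The NT_* constants added to pkg/abi/linux/elf.go below select register sets for ELF core files and for ptrace PTRACE_GETREGSET/SETREGSET. A minimal sketch of the usual consumption pattern (illustrative only, not part of this diff; it assumes golang.org/x/sys/unix on linux/amd64 and a tracee already stopped under ptrace):

// Hypothetical illustration of how an NT_PRSTATUS-style regset selector is
// used with PTRACE_GETREGSET; not part of this change.
package main

import (
	"fmt"
	"unsafe"

	"golang.org/x/sys/unix"
)

func dumpRegs(pid int) error {
	var regs unix.PtraceRegs
	iov := unix.Iovec{
		Base: (*byte)(unsafe.Pointer(&regs)),
		Len:  uint64(unsafe.Sizeof(regs)),
	}
	// unix.NT_PRSTATUS (0x1) selects the general purpose register set,
	// matching the NT_PRSTATUS constant added below.
	if _, _, errno := unix.Syscall6(unix.SYS_PTRACE, unix.PTRACE_GETREGSET,
		uintptr(pid), uintptr(unix.NT_PRSTATUS),
		uintptr(unsafe.Pointer(&iov)), 0, 0); errno != 0 {
		return fmt.Errorf("PTRACE_GETREGSET failed: %v", errno)
	}
	fmt.Printf("tracee rip: %#x\n", regs.Rip)
	return nil
}

func main() {}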
diff --git a/pkg/abi/linux/elf.go b/pkg/abi/linux/elf.go
index fb1c679d2..40f0459a0 100644
--- a/pkg/abi/linux/elf.go
+++ b/pkg/abi/linux/elf.go
@@ -89,3 +89,17 @@ const (
 	// AT_SYSINFO_EHDR is the address of the VDSO.
 	AT_SYSINFO_EHDR = 33
 )
+
+// ELF ET_CORE and ptrace GETREGSET/SETREGSET register set types.
+//
+// See include/uapi/linux/elf.h.
+const (
+	// NT_PRSTATUS is for general purpose registers.
+	NT_PRSTATUS = 0x1
+
+	// NT_PRFPREG is for floating point registers.
+	NT_PRFPREG = 0x2
+
+	// NT_X86_XSTATE is for x86 extended state using xsave.
+	NT_X86_XSTATE = 0x202
+)
diff --git a/pkg/abi/linux/file.go b/pkg/abi/linux/file.go
index 615e72646..7d742871a 100644
--- a/pkg/abi/linux/file.go
+++ b/pkg/abi/linux/file.go
@@ -178,6 +178,8 @@ const (
 
 // Values for preadv2/pwritev2.
 const (
+	// Note: gVisor does not implement the RWF_HIPRI feature, but the flag is
+	// accepted as a valid flag argument for preadv2/pwritev2.
 	RWF_HIPRI = 0x00000001
 	RWF_DSYNC = 0x00000002
 	RWF_SYNC  = 0x00000004
diff --git a/pkg/abi/linux/netlink.go b/pkg/abi/linux/netlink.go
index e8b6544b4..0ba086c76 100644
--- a/pkg/abi/linux/netlink.go
+++ b/pkg/abi/linux/netlink.go
@@ -122,3 +122,9 @@ const (
 	NETLINK_EXT_ACK         = 11
 	NETLINK_DUMP_STRICT_CHK = 12
 )
+
+// NetlinkErrorMessage is struct nlmsgerr, from uapi/linux/netlink.h.
+type NetlinkErrorMessage struct {
+	Error  int32
+	Header NetlinkMessageHeader
+}
diff --git a/pkg/abi/linux/netlink_route.go b/pkg/abi/linux/netlink_route.go
index dd698e2bc..152f6b144 100644
--- a/pkg/abi/linux/netlink_route.go
+++ b/pkg/abi/linux/netlink_route.go
@@ -189,3 +189,139 @@ const (
 const (
 	ARPHRD_LOOPBACK = 772
 )
+
+// RouteMessage is struct rtmsg, from uapi/linux/rtnetlink.h.
+type RouteMessage struct {
+	Family uint8
+	DstLen uint8
+	SrcLen uint8
+	TOS    uint8
+
+	Table    uint8
+	Protocol uint8
+	Scope    uint8
+	Type     uint8
+
+	Flags uint32
+}
+
+// Route types, from uapi/linux/rtnetlink.h.
+const (
+	// RTN_UNSPEC represents an unspecified route type.
+	RTN_UNSPEC = 0
+
+	// RTN_UNICAST represents a unicast route.
+	RTN_UNICAST = 1
+
+	// RTN_LOCAL represents a route that is accepted locally.
+	RTN_LOCAL = 2
+
+	// RTN_BROADCAST represents a broadcast route (traffic is accepted locally
+	// as broadcast, and sent as broadcast).
+	RTN_BROADCAST = 3
+
+	// RTN_ANYCAST represents an anycast route (traffic is accepted locally as
+	// broadcast, but sent as unicast).
+	RTN_ANYCAST = 4
+
+	// RTN_MULTICAST represents a multicast route.
+	RTN_MULTICAST = 5
+
+	// RTN_BLACKHOLE represents a route where all traffic is dropped.
+	RTN_BLACKHOLE = 6
+
+	// RTN_UNREACHABLE represents a route where the destination is unreachable.
+	RTN_UNREACHABLE = 7
+
+	RTN_PROHIBIT = 8
+	RTN_THROW    = 9
+	RTN_NAT      = 10
+	RTN_XRESOLVE = 11
+)
+
+// Route protocols/origins, from uapi/linux/rtnetlink.h.
+const (
+	RTPROT_UNSPEC   = 0
+	RTPROT_REDIRECT = 1
+	RTPROT_KERNEL   = 2
+	RTPROT_BOOT     = 3
+	RTPROT_STATIC   = 4
+	RTPROT_GATED    = 8
+	RTPROT_RA       = 9
+	RTPROT_MRT      = 10
+	RTPROT_ZEBRA    = 11
+	RTPROT_BIRD     = 12
+	RTPROT_DNROUTED = 13
+	RTPROT_XORP     = 14
+	RTPROT_NTK      = 15
+	RTPROT_DHCP     = 16
+	RTPROT_MROUTED  = 17
+	RTPROT_BABEL    = 42
+	RTPROT_BGP      = 186
+	RTPROT_ISIS     = 187
+	RTPROT_OSPF     = 188
+	RTPROT_RIP      = 189
+	RTPROT_EIGRP    = 192
+)
+
+// Route scopes, from uapi/linux/rtnetlink.h.
+const (
+	RT_SCOPE_UNIVERSE = 0
+	RT_SCOPE_SITE     = 200
+	RT_SCOPE_LINK     = 253
+	RT_SCOPE_HOST     = 254
+	RT_SCOPE_NOWHERE  = 255
+)
+
+// Route flags, from uapi/linux/rtnetlink.h.
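+// These flag values are carried in RouteMessage.Flags; RTM_F_CLONED, for
+// example, marks a route that was cloned from another entry.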
+const (
+	RTM_F_NOTIFY       = 0x100
+	RTM_F_CLONED       = 0x200
+	RTM_F_EQUALIZE     = 0x400
+	RTM_F_PREFIX       = 0x800
+	RTM_F_LOOKUP_TABLE = 0x1000
+	RTM_F_FIB_MATCH    = 0x2000
+)
+
+// Route tables, from uapi/linux/rtnetlink.h.
+const (
+	RT_TABLE_UNSPEC  = 0
+	RT_TABLE_COMPAT  = 252
+	RT_TABLE_DEFAULT = 253
+	RT_TABLE_MAIN    = 254
+	RT_TABLE_LOCAL   = 255
+)
+
+// Route attributes, from uapi/linux/rtnetlink.h.
+const (
+	RTA_UNSPEC        = 0
+	RTA_DST           = 1
+	RTA_SRC           = 2
+	RTA_IIF           = 3
+	RTA_OIF           = 4
+	RTA_GATEWAY       = 5
+	RTA_PRIORITY      = 6
+	RTA_PREFSRC       = 7
+	RTA_METRICS       = 8
+	RTA_MULTIPATH     = 9
+	RTA_PROTOINFO     = 10
+	RTA_FLOW          = 11
+	RTA_CACHEINFO     = 12
+	RTA_SESSION       = 13
+	RTA_MP_ALGO       = 14
+	RTA_TABLE         = 15
+	RTA_MARK          = 16
+	RTA_MFC_STATS     = 17
+	RTA_VIA           = 18
+	RTA_NEWDST        = 19
+	RTA_PREF          = 20
+	RTA_ENCAP_TYPE    = 21
+	RTA_ENCAP         = 22
+	RTA_EXPIRES       = 23
+	RTA_PAD           = 24
+	RTA_UID           = 25
+	RTA_TTL_PROPAGATE = 26
+	RTA_IP_PROTO      = 27
+	RTA_SPORT         = 28
+	RTA_DPORT         = 29
+)
diff --git a/pkg/abi/linux/socket.go b/pkg/abi/linux/socket.go
index 6d22002c4..d5b731390 100644
--- a/pkg/abi/linux/socket.go
+++ b/pkg/abi/linux/socket.go
@@ -267,6 +267,20 @@ type SockAddrUnix struct {
 	Path [UnixPathMax]int8
 }
 
+// SockAddr represents a union of valid socket address types. This is logically
+// equivalent to struct sockaddr. SockAddr ensures that a well-defined set of
+// types can be used as socket addresses.
+type SockAddr interface {
+	// implementsSockAddr exists purely to allow a type to indicate that it
+	// implements this interface. This method is a no-op and shouldn't be
+	// called.
+	implementsSockAddr()
+}
+
+func (s *SockAddrInet) implementsSockAddr()    {}
+func (s *SockAddrInet6) implementsSockAddr()   {}
+func (s *SockAddrUnix) implementsSockAddr()    {}
+func (s *SockAddrNetlink) implementsSockAddr() {}
+
 // Linger is struct linger, from include/linux/socket.h.
 type Linger struct {
 	OnOff int32
@@ -278,7 +292,10 @@ const SizeOfLinger = 8
 
 // TCPInfo is a collection of TCP statistics.
 //
-// From uapi/linux/tcp.h.
+// From uapi/linux/tcp.h. Newer versions of Linux continue to add new fields to
+// the end of this struct or within existing unused space, so its size grows
+// over time. The current iteration is based on linux v4.17. New versions are
+// always backwards compatible.
 type TCPInfo struct {
 	State   uint8
 	CaState uint8
@@ -352,7 +369,7 @@ type TCPInfo struct {
 }
 
 // SizeOfTCPInfo is the binary size of a TCPInfo struct.
-const SizeOfTCPInfo = 104
+var SizeOfTCPInfo = int(binary.Size(TCPInfo{}))
 
 // Control message types, from linux/socket.h.
 const (
diff --git a/pkg/cpuid/cpuid.go b/pkg/cpuid/cpuid.go
index 3fabaf445..5d61dc2ff 100644
--- a/pkg/cpuid/cpuid.go
+++ b/pkg/cpuid/cpuid.go
@@ -418,6 +418,73 @@ var x86FeatureParseOnlyStrings = map[Feature]string{
 	X86FeaturePREFETCHWT1: "prefetchwt1",
 }
 
+// An intelCacheDescriptor describes a cache or TLB on the system. Descriptors
+// are returned in the registers for eax=2. Intel only.
+type intelCacheDescriptor uint8
+
+// Valid cache/TLB descriptors. All descriptors can be found in Intel SDM Vol.
+// 2, Ch. 3.2, "CPUID", Table 3-12 "Encoding of CPUID Leaf 2 Descriptors".
+const (
+	intelNullDescriptor    intelCacheDescriptor = 0
+	intelNoTLBDescriptor   intelCacheDescriptor = 0xfe
+	intelNoCacheDescriptor intelCacheDescriptor = 0xff
+
+	// Most descriptors omitted for brevity as they are currently unused.
+)
+
+// CacheType describes the type of a cache, as returned in eax[4:0] for eax=4.
+type CacheType uint8
+
+const (
+	// cacheNull indicates that there are no more entries.
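+	// Enumeration of CPUID leaf 4 stops at the first index that reports
+	// this type (see HostFeatureSet below).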
+	cacheNull CacheType = iota
+
+	// CacheData is a data cache.
+	CacheData
+
+	// CacheInstruction is an instruction cache.
+	CacheInstruction
+
+	// CacheUnified is a unified instruction and data cache.
+	CacheUnified
+)
+
+// Cache describes the parameters of a single cache on the system.
+//
+// +stateify savable
+type Cache struct {
+	// Level is the hierarchical level of this cache (L1, L2, etc).
+	Level uint32
+
+	// Type is the type of cache.
+	Type CacheType
+
+	// FullyAssociative indicates that entries may be placed in any block.
+	FullyAssociative bool
+
+	// Partitions is the number of physical partitions in the cache.
+	Partitions uint32
+
+	// Ways is the number of ways of associativity in the cache.
+	Ways uint32
+
+	// Sets is the number of sets in the cache.
+	Sets uint32
+
+	// InvalidateHierarchical indicates that WBINVD/INVD from threads
+	// sharing this cache acts upon lower level caches for threads sharing
+	// this cache.
+	InvalidateHierarchical bool
+
+	// Inclusive indicates that this cache is inclusive of lower cache
+	// levels.
+	Inclusive bool
+
+	// DirectMapped indicates that this cache is directly mapped from
+	// address, rather than using a hash function.
+	DirectMapped bool
+}
+
 // Just a way to wrap cpuid function numbers.
 type cpuidFunction uint32
 
@@ -494,7 +561,7 @@ func (f Feature) flagString(cpuinfoOnly bool) string {
 	return ""
 }
 
-// FeatureSet is a set of Features for a cpu.
+// FeatureSet is a set of Features for a CPU.
 //
 // +stateify savable
 type FeatureSet struct {
@@ -521,6 +588,15 @@ type FeatureSet struct {
 
 	// SteppingID is part of the processor signature.
 	SteppingID uint8
+
+	// Caches describes the caches on the CPU.
+	Caches []Cache
+
+	// CacheLine is the size of a cache line in bytes.
+	//
+	// All caches use the same line size. This is not enforced in the CPUID
+	// encoding, but is true on all known x86 processors.
+	CacheLine uint32
 }
 
 // FlagsString prints out supported CPU flags. If cpuinfoOnly is true, it is
@@ -557,22 +633,27 @@ func (fs FeatureSet) CPUInfo(cpu uint) string {
 	fmt.Fprintln(&b, "wp\t\t: yes")
 	fmt.Fprintf(&b, "flags\t\t: %s\n", fs.FlagsString(true))
 	fmt.Fprintf(&b, "bogomips\t: %.02f\n", cpuFreqMHz) // It's bogus anyway.
-	fmt.Fprintf(&b, "clflush size\t: %d\n", 64)
-	fmt.Fprintf(&b, "cache_alignment\t: %d\n", 64)
+	fmt.Fprintf(&b, "clflush size\t: %d\n", fs.CacheLine)
+	fmt.Fprintf(&b, "cache_alignment\t: %d\n", fs.CacheLine)
 	fmt.Fprintf(&b, "address sizes\t: %d bits physical, %d bits virtual\n", 46, 48)
 	fmt.Fprintln(&b, "power management:") // This is always here, but can be blank.
 	fmt.Fprintln(&b, "") // The /proc/cpuinfo file ends with an extra newline.
 	return b.String()
 }
 
+const (
+	amdVendorID   = "AuthenticAMD"
+	intelVendorID = "GenuineIntel"
+)
+
 // AMD returns true if fs describes an AMD CPU.
 func (fs *FeatureSet) AMD() bool {
-	return fs.VendorID == "AuthenticAMD"
+	return fs.VendorID == amdVendorID
 }
 
 // Intel returns true if fs describes an Intel CPU.
 func (fs *FeatureSet) Intel() bool {
-	return fs.VendorID == "GenuineIntel"
+	return fs.VendorID == intelVendorID
 }
 
 // ErrIncompatible is returned by FeatureSet.HostCompatible if fs is not a
@@ -589,9 +670,18 @@ func (e ErrIncompatible) Error() string {
 
 // CheckHostCompatible returns nil if fs is a subset of the host feature set.
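+// fs must also have the same cache line size as the host (see below).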
 func (fs *FeatureSet) CheckHostCompatible() error {
 	hfs := HostFeatureSet()
+
 	if diff := fs.Subtract(hfs); diff != nil {
 		return ErrIncompatible{fmt.Sprintf("CPU feature set %v incompatible with host feature set %v (missing: %v)", fs.FlagsString(false), hfs.FlagsString(false), diff)}
 	}
+
+	// The size of a cache line must match, as it is critical to correctly
+	// utilizing CLFLUSH. Other cache properties are allowed to change, as
+	// they are not important to correctness.
+	if fs.CacheLine != hfs.CacheLine {
+		return ErrIncompatible{fmt.Sprintf("CPU cache line size %d incompatible with host cache line size %d", fs.CacheLine, hfs.CacheLine)}
+	}
+
 	return nil
 }
 
@@ -732,14 +822,6 @@ func (fs *FeatureSet) HasFeature(feature Feature) bool {
 	return fs.Set[feature]
 }
 
-// IsSubset returns true if the FeatureSet is a subset of the FeatureSet passed in.
-// This is useful if you want to see if a FeatureSet is compatible with another
-// FeatureSet, since you can only run with a given FeatureSet if it's a subset of
-// the host's.
-func (fs *FeatureSet) IsSubset(other *FeatureSet) bool {
-	return fs.Subtract(other) == nil
-}
-
 // Subtract returns the features present in fs that are not present in other.
 // If all features in fs are present in other, Subtract returns nil.
 func (fs *FeatureSet) Subtract(other *FeatureSet) (diff map[Feature]bool) {
@@ -755,17 +837,6 @@ func (fs *FeatureSet) Subtract(other *FeatureSet) (diff map[Feature]bool) {
 	return
 }
 
-// TakeFeatureIntersection will set the features in `fs` to the intersection of
-// the features in `fs` and `other` (effectively clearing any feature bits on
-// `fs` that are not also set in `other`).
-func (fs *FeatureSet) TakeFeatureIntersection(other *FeatureSet) {
-	for f := range fs.Set {
-		if !other.Set[f] {
-			delete(fs.Set, f)
-		}
-	}
-}
-
 // EmulateID emulates a cpuid instruction based on the feature set.
 func (fs *FeatureSet) EmulateID(origAx, origCx uint32) (ax, bx, cx, dx uint32) {
 	switch cpuidFunction(origAx) {
@@ -773,9 +844,8 @@ func (fs *FeatureSet) EmulateID(origAx, origCx uint32) (ax, bx, cx, dx uint32) {
 		ax = uint32(xSaveInfo) // 0xd (xSaveInfo) is the highest function we support.
 		bx, dx, cx = fs.vendorIDRegs()
 	case featureInfo:
-		// clflush line size (ebx bits[15:8]) hardcoded as 8. This
-		// means cache lines of size 64 bytes.
-		bx = 8 << 8
+		// CLFLUSH line size is encoded in quadwords. Other fields in bx
+		// unsupported.
+		bx = (fs.CacheLine / 8) << 8
 		cx = fs.blockMask(block(0))
 		dx = fs.blockMask(block(1))
 		ax = fs.signature()
@@ -789,10 +859,46 @@ func (fs *FeatureSet) EmulateID(origAx, origCx uint32) (ax, bx, cx, dx uint32) {
 		// will always return 01H. Software should ignore this value
 		// and not interpret it as an informational descriptor." - SDM
 		//
-		// We do not support exposing cache information, but we do set
-		// this fixed field because some language runtimes (dlang) get
-		// confused by ax = 0 and will loop infinitely.
-		ax = 1
+		// We only support reporting cache parameters via
+		// intelDeterministicCacheParams; report as much here.
+		//
+		// We do not support exposing TLB information at all.
+		ax = 1 | (uint32(intelNoCacheDescriptor) << 8)
+	case intelDeterministicCacheParams:
+		if !fs.Intel() {
+			// Reserved on non-Intel.
+			return 0, 0, 0, 0
+		}
+
+		// cx is the index of the cache to describe.
+		if int(origCx) >= len(fs.Caches) {
+			return uint32(cacheNull), 0, 0, 0
+		}
+		c := fs.Caches[origCx]
+
+		ax = uint32(c.Type)
+		ax |= c.Level << 5
+		ax |= 1 << 8 // Always claim the cache is "self-initializing".
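+		// Per the SDM encoding for this leaf, the line size, partition
+		// count, and way count in bx, and the set count in cx, are all
+		// reported minus one.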
+		if c.FullyAssociative {
+			ax |= 1 << 9
+		}
+		// Processor topology not supported.
+
+		bx = fs.CacheLine - 1
+		bx |= (c.Partitions - 1) << 12
+		bx |= (c.Ways - 1) << 22
+
+		cx = c.Sets - 1
+
+		if !c.InvalidateHierarchical {
+			dx |= 1
+		}
+		if c.Inclusive {
+			dx |= 1 << 1
+		}
+		if !c.DirectMapped {
+			dx |= 1 << 2
+		}
 	case xSaveInfo:
 		if !fs.UseXsave() {
 			return 0, 0, 0, 0
@@ -845,10 +951,41 @@ func HostFeatureSet() *FeatureSet {
 	vendorID := vendorIDFromRegs(bx, cx, dx)
 
 	// eax=1 gets basic features in ecx:edx.
-	ax, _, cx, dx := HostID(1, 0)
+	ax, bx, cx, dx := HostID(1, 0)
 	featureBlock0 := cx
 	featureBlock1 := dx
 	ef, em, pt, f, m, sid := signatureSplit(ax)
+	cacheLine := 8 * ((bx >> 8) & 0xff)
+
+	// eax=4, ecx=i gets details about cache index i. Only supported on Intel.
+	var caches []Cache
+	if vendorID == intelVendorID {
+		// ecx selects the cache index until a null type is returned.
+		for i := uint32(0); ; i++ {
+			ax, bx, cx, dx := HostID(4, i)
+			t := CacheType(ax & 0xf)
+			if t == cacheNull {
+				break
+			}
+
+			lineSize := (bx & 0xfff) + 1
+			if lineSize != cacheLine {
+				panic(fmt.Sprintf("Mismatched cache line size: %d vs %d", lineSize, cacheLine))
+			}
+
+			caches = append(caches, Cache{
+				Type:                   t,
+				Level:                  (ax >> 5) & 0x7,
+				FullyAssociative:       ((ax >> 9) & 1) == 1,
+				Partitions:             ((bx >> 12) & 0x3ff) + 1,
+				Ways:                   ((bx >> 22) & 0x3ff) + 1,
+				Sets:                   cx + 1,
+				InvalidateHierarchical: (dx & 1) == 0,
+				Inclusive:              ((dx >> 1) & 1) == 1,
+				DirectMapped:           ((dx >> 2) & 1) == 0,
+			})
+		}
+	}
 
 	// eax=7, ecx=0 gets extended features in ecx:ebx.
 	_, bx, cx, _ = HostID(7, 0)
@@ -883,6 +1020,8 @@ func HostFeatureSet() *FeatureSet {
 		Family:     f,
 		Model:      m,
 		SteppingID: sid,
+		CacheLine:  cacheLine,
+		Caches:     caches,
 	}
 }
diff --git a/pkg/cpuid/cpuid_test.go b/pkg/cpuid/cpuid_test.go
index 6ae14d2da..a707ebb55 100644
--- a/pkg/cpuid/cpuid_test.go
+++ b/pkg/cpuid/cpuid_test.go
@@ -57,24 +57,13 @@ var justFPUandPAE = &FeatureSet{
 	X86FeaturePAE: true,
 }}
 
-func TestIsSubset(t *testing.T) {
-	if !justFPU.IsSubset(justFPUandPAE) {
-		t.Errorf("Got %v is not subset of %v, want IsSubset being true", justFPU, justFPUandPAE)
+func TestSubtract(t *testing.T) {
+	if diff := justFPU.Subtract(justFPUandPAE); diff != nil {
+		t.Errorf("Got %v is not a subset of %v, want diff (%v) to be nil", justFPU, justFPUandPAE, diff)
 	}
-	if justFPUandPAE.IsSubset(justFPU) {
-		t.Errorf("Got %v is a subset of %v, want IsSubset being false", justFPU, justFPUandPAE)
-	}
-}
-
-func TestTakeFeatureIntersection(t *testing.T) {
-	testFeatures := HostFeatureSet()
-	testFeatures.TakeFeatureIntersection(justFPU)
-	if !testFeatures.IsSubset(justFPU) {
-		t.Errorf("Got more features than expected after intersecting host features with justFPU: %v, want %v", testFeatures.Set, justFPU.Set)
-	}
-	if !testFeatures.HasFeature(X86FeatureFPU) {
-		t.Errorf("Got no features in testFeatures after intersecting, want %v", X86FeatureFPU)
+	if justFPUandPAE.Subtract(justFPU) == nil {
+		t.Errorf("Got %v is a subset of %v, want diff to be non-nil", justFPUandPAE, justFPU)
 	}
 }
 
@@ -83,7 +72,7 @@ func TestTakeFeatureIntersection(t *testing.T) {
 // if HostFeatureSet gives back junk bits.
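+// (Any host running these tests is expected to at least have an FPU and
+// support PAE.)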
 func TestHostFeatureSet(t *testing.T) {
 	hostFeatures := HostFeatureSet()
-	if !justFPUandPAE.IsSubset(hostFeatures) {
+	if justFPUandPAE.Subtract(hostFeatures) != nil {
 		t.Errorf("Got invalid feature set %v from HostFeatureSet()", hostFeatures)
 	}
 }
@@ -175,6 +164,7 @@ func TestEmulateIDBasicFeatures(t *testing.T) {
 	testFeatures := newEmptyFeatureSet()
 	testFeatures.Add(X86FeatureCLFSH)
 	testFeatures.Add(X86FeatureAVX)
+	testFeatures.CacheLine = 64
 
 	ax, bx, cx, dx := testFeatures.EmulateID(1, 0)
 	ECXAVXBit := uint32(1 << uint(X86FeatureAVX))
diff --git a/pkg/eventchannel/BUILD b/pkg/eventchannel/BUILD
index 4c336ea84..9961baaa9 100644
--- a/pkg/eventchannel/BUILD
+++ b/pkg/eventchannel/BUILD
@@ -1,4 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
 load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library")
 
 package(licenses = ["notice"])
@@ -7,6 +7,7 @@ go_library(
     name = "eventchannel",
     srcs = [
         "event.go",
+        "rate.go",
     ],
     importpath = "gvisor.dev/gvisor/pkg/eventchannel",
     visibility = ["//:sandbox"],
@@ -16,6 +17,7 @@ go_library(
         "//pkg/unet",
         "@com_github_golang_protobuf//proto:go_default_library",
        "@com_github_golang_protobuf//ptypes:go_default_library_gen",
+        "@org_golang_x_time//rate:go_default_library",
     ],
 )
 
@@ -30,3 +32,12 @@ go_proto_library(
     proto = ":eventchannel_proto",
     visibility = ["//:sandbox"],
 )
+
+go_test(
+    name = "eventchannel_test",
+    srcs = ["event_test.go"],
+    embed = [":eventchannel"],
+    deps = [
+        "@com_github_golang_protobuf//proto:go_default_library",
+    ],
+)
diff --git a/pkg/eventchannel/event.go b/pkg/eventchannel/event.go
index f6d26532b..d37ad0428 100644
--- a/pkg/eventchannel/event.go
+++ b/pkg/eventchannel/event.go
@@ -43,18 +43,36 @@ type Emitter interface {
 	Close() error
 }
 
-var (
-	mu       sync.Mutex
-	emitters = make(map[Emitter]struct{})
-)
+// DefaultEmitter is the default emitter. Calls to Emit and AddEmitter are sent
+// to this Emitter.
+var DefaultEmitter = &multiEmitter{}
 
-// Emit emits a message using all added emitters.
+// Emit is a helper method that calls DefaultEmitter.Emit.
 func Emit(msg proto.Message) error {
-	mu.Lock()
-	defer mu.Unlock()
+	_, err := DefaultEmitter.Emit(msg)
+	return err
+}
+
+// AddEmitter is a helper method that calls DefaultEmitter.AddEmitter.
+func AddEmitter(e Emitter) {
+	DefaultEmitter.AddEmitter(e)
+}
+
+// multiEmitter is an Emitter that forwards messages to multiple Emitters.
+type multiEmitter struct {
+	// mu protects emitters.
+	mu sync.Mutex
+	// emitters is initialized lazily in AddEmitter.
+	emitters map[Emitter]struct{}
+}
+
+// Emit emits a message using all added emitters.
+func (me *multiEmitter) Emit(msg proto.Message) (bool, error) {
+	me.mu.Lock()
+	defer me.mu.Unlock()
 
 	var err error
-	for e := range emitters {
+	for e := range me.emitters {
 		hangup, eerr := e.Emit(msg)
 		if eerr != nil {
 			if err == nil {
@@ -68,18 +86,36 @@
 		}
 		if hangup {
 			log.Infof("Hangup on eventchannel emitter %v.", e)
-			delete(emitters, e)
+			delete(me.emitters, e)
 		}
 	}
 
-	return err
+	return false, err
 }
 
 // AddEmitter adds a new emitter.
-func AddEmitter(e Emitter) {
-	mu.Lock()
-	defer mu.Unlock()
-	emitters[e] = struct{}{}
+func (me *multiEmitter) AddEmitter(e Emitter) {
+	me.mu.Lock()
+	defer me.mu.Unlock()
+	if me.emitters == nil {
+		me.emitters = make(map[Emitter]struct{})
+	}
+	me.emitters[e] = struct{}{}
+}
+
+// Close closes all emitters. If any Close call fails, it returns the first
+// error encountered.
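+// Emitters are removed from the set as they are closed, regardless of errors.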
+func (me *multiEmitter) Close() error {
+	me.mu.Lock()
+	defer me.mu.Unlock()
+	var err error
+	for e := range me.emitters {
+		if eerr := e.Close(); err == nil && eerr != nil {
+			err = eerr
+		}
+		delete(me.emitters, e)
+	}
+	return err
+}
 
 func marshal(msg proto.Message) ([]byte, error) {
diff --git a/pkg/eventchannel/event_test.go b/pkg/eventchannel/event_test.go
new file mode 100644
index 000000000..3649097d6
--- /dev/null
+++ b/pkg/eventchannel/event_test.go
@@ -0,0 +1,146 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package eventchannel
+
+import (
+	"fmt"
+	"sync"
+	"testing"
+	"time"
+
+	"github.com/golang/protobuf/proto"
+)
+
+// testEmitter is an emitter that can be used in tests. It records all events
+// emitted, and whether it has been closed.
+type testEmitter struct {
+	// mu protects all fields below.
+	mu sync.Mutex
+
+	// events contains all emitted events.
+	events []proto.Message
+
+	// closed records whether Close() was called.
+	closed bool
+}
+
+// Emit implements Emitter.Emit.
+func (te *testEmitter) Emit(msg proto.Message) (bool, error) {
+	te.mu.Lock()
+	defer te.mu.Unlock()
+	te.events = append(te.events, msg)
+	return false, nil
+}
+
+// Close implements Emitter.Close.
+func (te *testEmitter) Close() error {
+	te.mu.Lock()
+	defer te.mu.Unlock()
+	if te.closed {
+		return fmt.Errorf("Close called twice")
+	}
+	te.closed = true
+	return nil
+}
+
+// testMessage implements proto.Message for testing.
+type testMessage struct {
+	proto.Message
+
+	// name is the name of the message, used by tests to compare messages.
+	name string
+}
+
+func TestMultiEmitter(t *testing.T) {
+	// Create three testEmitters, tied together in a multiEmitter.
+	me := &multiEmitter{}
+	var emitters []*testEmitter
+	for i := 0; i < 3; i++ {
+		te := &testEmitter{}
+		emitters = append(emitters, te)
+		me.AddEmitter(te)
+	}
+
+	// Emit three messages to multiEmitter.
+	names := []string{"foo", "bar", "baz"}
+	for _, name := range names {
+		m := testMessage{name: name}
+		if _, err := me.Emit(m); err != nil {
+			t.Fatalf("me.Emit(%v) failed: %v", m, err)
+		}
+	}
+
+	// All three emitters should have all three events.
+	for _, te := range emitters {
+		if got, want := len(te.events), len(names); got != want {
+			t.Fatalf("emitter got %d events, want %d", got, want)
+		}
+		for i, name := range names {
+			if got := te.events[i].(testMessage).name; got != name {
+				t.Errorf("emitter got message with name %q, want %q", got, name)
+			}
+		}
+	}
+
+	// Close multiEmitter.
+	if err := me.Close(); err != nil {
+		t.Fatalf("me.Close() failed: %v", err)
+	}
+
+	// All testEmitters should be closed.
+	for _, te := range emitters {
+		if !te.closed {
+			t.Errorf("te.closed got false, want true")
+		}
+	}
+}
+
+func TestRateLimitedEmitter(t *testing.T) {
+	// Create a RateLimitedEmitter that wraps a testEmitter.
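+	// The limiter's token bucket starts full, so the first `burst` events
+	// pass through immediately.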
+	te := &testEmitter{}
+	max := float64(5) // events per second
+	burst := 10       // events
+	rle := RateLimitedEmitterFrom(te, max, burst)
+
+	// Send 50 messages in one shot.
+	for i := 0; i < 50; i++ {
+		if _, err := rle.Emit(testMessage{}); err != nil {
+			t.Fatalf("rle.Emit failed: %v", err)
+		}
+	}
+
+	// We should have received only 10 messages.
+	if got, want := len(te.events), 10; got != want {
+		t.Errorf("got %d events, want %d", got, want)
+	}
+
+	// Sleep for a second and then send another 50.
+	time.Sleep(1 * time.Second)
+	for i := 0; i < 50; i++ {
+		if _, err := rle.Emit(testMessage{}); err != nil {
+			t.Fatalf("rle.Emit failed: %v", err)
+		}
+	}
+
+	// We should have at least 5 more messages, plus maybe a few more if the
+	// test ran slowly.
+	got, wantAtLeast, wantAtMost := len(te.events), 15, 20
+	if got < wantAtLeast {
+		t.Errorf("got %d events, want at least %d", got, wantAtLeast)
+	}
+	if got > wantAtMost {
+		t.Errorf("got %d events, want at most %d", got, wantAtMost)
+	}
+}
diff --git a/pkg/eventchannel/rate.go b/pkg/eventchannel/rate.go
new file mode 100644
index 000000000..179226c92
--- /dev/null
+++ b/pkg/eventchannel/rate.go
@@ -0,0 +1,54 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package eventchannel
+
+import (
+	"github.com/golang/protobuf/proto"
+	"golang.org/x/time/rate"
+)
+
+// rateLimitedEmitter wraps an emitter and limits events to the given limits.
+// Events that would exceed the limit are discarded.
+type rateLimitedEmitter struct {
+	inner   Emitter
+	limiter *rate.Limiter
+}
+
+// RateLimitedEmitterFrom creates a new event channel emitter that wraps the
+// existing emitter and enforces rate limits. The limits are imposed via a
+// token bucket, with `maxRate` events per second, with burst size of `burst`
+// events. See the golang.org/x/time/rate package and
+// https://en.wikipedia.org/wiki/Token_bucket for more information about token
+// buckets generally.
+func RateLimitedEmitterFrom(inner Emitter, maxRate float64, burst int) Emitter {
+	return &rateLimitedEmitter{
+		inner:   inner,
+		limiter: rate.NewLimiter(rate.Limit(maxRate), burst),
+	}
+}
+
+// Emit implements Emitter.Emit.
+func (rle *rateLimitedEmitter) Emit(msg proto.Message) (bool, error) {
+	if !rle.limiter.Allow() {
+		// Drop event.
+		return false, nil
+	}
+	return rle.inner.Emit(msg)
+}
+
+// Close implements Emitter.Close.
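+// Closing is never rate limited; the wrapped emitter is always closed.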
+func (rle *rateLimitedEmitter) Close() error {
+	return rle.inner.Close()
+}
diff --git a/pkg/fdnotifier/BUILD b/pkg/fdnotifier/BUILD
index d0552c06e..aca2d8a82 100644
--- a/pkg/fdnotifier/BUILD
+++ b/pkg/fdnotifier/BUILD
@@ -10,5 +10,8 @@ go_library(
     ],
     importpath = "gvisor.dev/gvisor/pkg/fdnotifier",
     visibility = ["//:sandbox"],
-    deps = ["//pkg/waiter"],
+    deps = [
+        "//pkg/waiter",
+        "@org_golang_x_sys//unix:go_default_library",
+    ],
 )
diff --git a/pkg/fdnotifier/fdnotifier.go b/pkg/fdnotifier/fdnotifier.go
index 58529f99f..f4aae1953 100644
--- a/pkg/fdnotifier/fdnotifier.go
+++ b/pkg/fdnotifier/fdnotifier.go
@@ -25,6 +25,7 @@ import (
 	"sync"
 	"syscall"
 
+	"golang.org/x/sys/unix"
 	"gvisor.dev/gvisor/pkg/waiter"
 )
 
@@ -72,7 +73,7 @@ func (n *notifier) waitFD(fd int32, fi *fdInfo, mask waiter.EventMask) error {
 	}
 
 	e := syscall.EpollEvent{
-		Events: mask.ToLinux() | -syscall.EPOLLET,
+		Events: mask.ToLinux() | unix.EPOLLET,
 		Fd:     fd,
 	}
 
diff --git a/pkg/fdnotifier/poll_unsafe.go b/pkg/fdnotifier/poll_unsafe.go
index ab8857b5e..4225b04dd 100644
--- a/pkg/fdnotifier/poll_unsafe.go
+++ b/pkg/fdnotifier/poll_unsafe.go
@@ -35,8 +35,14 @@ func NonBlockingPoll(fd int32, mask waiter.EventMask) waiter.EventMask {
 		events: int16(mask.ToLinux()),
 	}
 
+	ts := syscall.Timespec{
+		Sec:  0,
+		Nsec: 0,
+	}
+
 	for {
-		n, _, err := syscall.RawSyscall(syscall.SYS_POLL, uintptr(unsafe.Pointer(&e)), 1, 0)
+		n, _, err := syscall.RawSyscall6(syscall.SYS_PPOLL, uintptr(unsafe.Pointer(&e)), 1,
+			uintptr(unsafe.Pointer(&ts)), 0, 0, 0)
 
 		// Interrupted by signal, try again.
 		if err == syscall.EINTR {
 			continue
diff --git a/pkg/flipcall/BUILD b/pkg/flipcall/BUILD
index 7126fc45f..bd1d614b6 100644
--- a/pkg/flipcall/BUILD
+++ b/pkg/flipcall/BUILD
@@ -5,10 +5,11 @@ package(licenses = ["notice"])
 go_library(
     name = "flipcall",
     srcs = [
-        "endpoint_futex.go",
-        "endpoint_unsafe.go",
+        "ctrl_futex.go",
         "flipcall.go",
+        "flipcall_unsafe.go",
         "futex_linux.go",
+        "io.go",
         "packet_window_allocator.go",
     ],
     importpath = "gvisor.dev/gvisor/pkg/flipcall",
diff --git a/pkg/flipcall/ctrl_futex.go b/pkg/flipcall/ctrl_futex.go
new file mode 100644
index 000000000..865b6f640
--- /dev/null
+++ b/pkg/flipcall/ctrl_futex.go
@@ -0,0 +1,146 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package flipcall
+
+import (
+	"encoding/json"
+	"fmt"
+	"math"
+	"sync/atomic"
+
+	"gvisor.dev/gvisor/pkg/log"
+)
+
+type endpointControlImpl struct {
+	state int32
+}
+
+// Bits in endpointControlImpl.state.
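+// Both bits may be set simultaneously, e.g. when shutdown races with a
+// blocked waiter.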
+const (
+	epsBlocked = 1 << iota
+	epsShutdown
+)
+
+func (ep *Endpoint) ctrlInit(opts ...EndpointOption) error {
+	if len(opts) != 0 {
+		return fmt.Errorf("unknown EndpointOption: %T", opts[0])
+	}
+	return nil
+}
+
+type ctrlHandshakeRequest struct{}
+
+type ctrlHandshakeResponse struct{}
+
+func (ep *Endpoint) ctrlConnect() error {
+	if err := ep.enterFutexWait(); err != nil {
+		return err
+	}
+	_, err := ep.futexConnect(&ctrlHandshakeRequest{})
+	ep.exitFutexWait()
+	return err
+}
+
+func (ep *Endpoint) ctrlWaitFirst() error {
+	if err := ep.enterFutexWait(); err != nil {
+		return err
+	}
+	defer ep.exitFutexWait()
+
+	// Wait for the handshake request.
+	if err := ep.futexSwitchFromPeer(); err != nil {
+		return err
+	}
+
+	// Read the handshake request.
+	reqLen := atomic.LoadUint32(ep.dataLen())
+	if reqLen > ep.dataCap {
+		return fmt.Errorf("invalid handshake request length %d (maximum %d)", reqLen, ep.dataCap)
+	}
+	var req ctrlHandshakeRequest
+	if err := json.NewDecoder(ep.NewReader(reqLen)).Decode(&req); err != nil {
+		return fmt.Errorf("error reading handshake request: %v", err)
+	}
+
+	// Write the handshake response.
+	w := ep.NewWriter()
+	if err := json.NewEncoder(w).Encode(ctrlHandshakeResponse{}); err != nil {
+		return fmt.Errorf("error writing handshake response: %v", err)
+	}
+	*ep.dataLen() = w.Len()
+
+	// Return control to the client.
+	if err := ep.futexSwitchToPeer(); err != nil {
+		return err
+	}
+
+	// Wait for the first non-handshake message.
+	return ep.futexSwitchFromPeer()
+}
+
+func (ep *Endpoint) ctrlRoundTrip() error {
+	if err := ep.futexSwitchToPeer(); err != nil {
+		return err
+	}
+	if err := ep.enterFutexWait(); err != nil {
+		return err
+	}
+	err := ep.futexSwitchFromPeer()
+	ep.exitFutexWait()
+	return err
+}
+
+func (ep *Endpoint) ctrlWakeLast() error {
+	return ep.futexSwitchToPeer()
+}
+
+func (ep *Endpoint) enterFutexWait() error {
+	switch eps := atomic.AddInt32(&ep.ctrl.state, epsBlocked); eps {
+	case epsBlocked:
+		return nil
+	case epsBlocked | epsShutdown:
+		atomic.AddInt32(&ep.ctrl.state, -epsBlocked)
+		return shutdownError{}
+	default:
+		// Most likely due to ep.enterFutexWait() being called concurrently
+		// from multiple goroutines.
+		panic(fmt.Sprintf("invalid flipcall.Endpoint.ctrl.state before flipcall.Endpoint.enterFutexWait(): %v", eps-epsBlocked))
+	}
+}
+
+func (ep *Endpoint) exitFutexWait() {
+	atomic.AddInt32(&ep.ctrl.state, -epsBlocked)
+}
+
+func (ep *Endpoint) ctrlShutdown() {
+	// Set epsShutdown to ensure that future calls to ep.enterFutexWait() fail.
+	if atomic.AddInt32(&ep.ctrl.state, epsShutdown)&epsBlocked != 0 {
+		// Wake the blocked thread. This must loop because it's possible that
+		// FUTEX_WAKE occurs after the waiter sets epsBlocked, but before it
+		// blocks in FUTEX_WAIT.
+		for {
+			// Wake MaxInt32 threads to prevent a broken or malicious peer from
+			// swallowing our wakeup by FUTEX_WAITing from multiple threads.
+			if err := ep.futexWakeConnState(math.MaxInt32); err != nil {
+				log.Warningf("failed to FUTEX_WAKE Endpoints: %v", err)
+				break
+			}
+			yieldThread()
+			if atomic.LoadInt32(&ep.ctrl.state)&epsBlocked == 0 {
+				break
+			}
+		}
+	}
+}
diff --git a/pkg/flipcall/endpoint_unsafe.go b/pkg/flipcall/endpoint_unsafe.go
deleted file mode 100644
index 8319955e0..000000000
--- a/pkg/flipcall/endpoint_unsafe.go
+++ /dev/null
@@ -1,238 +0,0 @@
-// Copyright 2019 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package flipcall
-
-import (
-	"fmt"
-	"math"
-	"reflect"
-	"sync/atomic"
-	"syscall"
-	"unsafe"
-)
-
-// An Endpoint provides the ability to synchronously transfer data and control
-// to a connected peer Endpoint, which may be in another process.
-//
-// Since the Endpoint control transfer model is synchronous, at any given time
-// one Endpoint "has control" (designated the *active* Endpoint), and the other
-// is "waiting for control" (designated the *inactive* Endpoint). Users of the
-// flipcall package arbitrarily designate one Endpoint as initially-active, and
-// the other as initially-inactive; in a client/server protocol, the client
-// Endpoint is usually initially-active (able to send a request) and the server
-// Endpoint is usually initially-inactive (waiting for a request). The
-// initially-active Endpoint writes data to be sent to Endpoint.Data(), and
-// then synchronously transfers control to the inactive Endpoint by calling
-// Endpoint.SendRecv(), becoming the inactive Endpoint in the process. The
-// initially-inactive Endpoint waits for control by calling
-// Endpoint.RecvFirst(); receiving control causes it to become the active
-// Endpoint. After this, the protocol is symmetric: the active Endpoint reads
-// data sent by the peer by reading from Endpoint.Data(), writes data to be
-// sent to the peer into Endpoint.Data(), and then calls Endpoint.SendRecv() to
-// exchange roles with the peer, which blocks until the peer has done the same.
-type Endpoint struct {
-	// shutdown is non-zero if Endpoint.Shutdown() has been called. shutdown is
-	// accessed using atomic memory operations.
-	shutdown uint32
-
-	// dataCap is the size of the datagram part of the packet window in bytes.
-	// dataCap is immutable.
-	dataCap uint32
-
-	// packet is the beginning of the packet window. packet is immutable.
-	packet unsafe.Pointer
-
-	ctrl endpointControlState
-}
-
-// Init must be called on zero-value Endpoints before first use. If it
-// succeeds, Destroy() must be called once the Endpoint is no longer in use.
-//
-// ctrlMode specifies how connected Endpoints will exchange control. Both
-// connected Endpoints must specify the same value for ctrlMode.
-//
-// pwd represents the packet window used to exchange data with the peer
-// Endpoint. FD may differ between Endpoints if they are in different
-// processes, but must represent the same file. The packet window must
-// initially be filled with zero bytes.
-func (ep *Endpoint) Init(ctrlMode ControlMode, pwd PacketWindowDescriptor) error {
-	if pwd.Length < pageSize {
-		return fmt.Errorf("packet window size (%d) less than minimum (%d)", pwd.Length, pageSize)
-	}
-	if pwd.Length > math.MaxUint32 {
-		return fmt.Errorf("packet window size (%d) exceeds maximum (%d)", pwd.Length, math.MaxUint32)
-	}
-	m, _, e := syscall.Syscall6(syscall.SYS_MMAP, 0, uintptr(pwd.Length), syscall.PROT_READ|syscall.PROT_WRITE, syscall.MAP_SHARED, uintptr(pwd.FD), uintptr(pwd.Offset))
-	if e != 0 {
-		return fmt.Errorf("failed to mmap packet window: %v", e)
-	}
-	ep.dataCap = uint32(pwd.Length) - uint32(packetHeaderBytes)
-	ep.packet = (unsafe.Pointer)(m)
-	if err := ep.initControlState(ctrlMode); err != nil {
-		ep.unmapPacket()
-		return err
-	}
-	return nil
-}
-
-// NewEndpoint is a convenience function that returns an initialized Endpoint
-// allocated on the heap.
-func NewEndpoint(ctrlMode ControlMode, pwd PacketWindowDescriptor) (*Endpoint, error) {
-	var ep Endpoint
-	if err := ep.Init(ctrlMode, pwd); err != nil {
-		return nil, err
-	}
-	return &ep, nil
-}
-
-func (ep *Endpoint) unmapPacket() {
-	syscall.Syscall(syscall.SYS_MUNMAP, uintptr(ep.packet), uintptr(ep.dataCap)+packetHeaderBytes, 0)
-	ep.dataCap = 0
-	ep.packet = nil
-}
-
-// Destroy releases resources owned by ep. No other Endpoint methods may be
-// called after Destroy.
-func (ep *Endpoint) Destroy() {
-	ep.unmapPacket()
-}
-
-// Packets consist of an 8-byte header followed by an arbitrarily-sized
-// datagram. The header consists of:
-//
-// - A 4-byte native-endian sequence number, which is incremented by the active
-// Endpoint after it finishes writing to the packet window. The sequence number
-// is needed to handle spurious wakeups.
-//
-// - A 4-byte native-endian datagram length in bytes.
-const (
-	sizeofUint32      = unsafe.Sizeof(uint32(0))
-	packetHeaderBytes = 2 * sizeofUint32
-)
-
-func (ep *Endpoint) seq() *uint32 {
-	return (*uint32)(ep.packet)
-}
-
-func (ep *Endpoint) dataLen() *uint32 {
-	return (*uint32)((unsafe.Pointer)(uintptr(ep.packet) + sizeofUint32))
-}
-
-// DataCap returns the maximum datagram size supported by ep in bytes.
-func (ep *Endpoint) DataCap() uint32 {
-	return ep.dataCap
-}
-
-func (ep *Endpoint) data() unsafe.Pointer {
-	return unsafe.Pointer(uintptr(ep.packet) + packetHeaderBytes)
-}
-
-// Data returns the datagram part of ep's packet window as a byte slice.
-//
-// Note that the packet window is shared with the potentially-untrusted peer
-// Endpoint, which may concurrently mutate the contents of the packet window.
-// Thus:
-//
-// - Readers must not assume that two reads of the same byte in Data() will
-// return the same result. In other words, readers should read any given byte
-// in Data() at most once.
-//
-// - Writers must not assume that they will read back the same data that they
-// have written. In other words, writers should avoid reading from Data() at
-// all.
-func (ep *Endpoint) Data() []byte {
-	var bs []byte
-	bsReflect := (*reflect.SliceHeader)((unsafe.Pointer)(&bs))
-	bsReflect.Data = uintptr(ep.data())
-	bsReflect.Len = int(ep.DataCap())
-	bsReflect.Cap = bsReflect.Len
-	return bs
-}
-
-// SendRecv transfers control to the peer Endpoint, causing its call to
-// Endpoint.SendRecv() or Endpoint.RecvFirst() to return with the given
-// datagram length, then blocks until the peer Endpoint calls
-// Endpoint.SendRecv() or Endpoint.SendLast().
-//
-// Preconditions: No previous call to ep.SendRecv() or ep.RecvFirst() has
-// returned an error. ep.SendLast() has never been called.
-func (ep *Endpoint) SendRecv(dataLen uint32) (uint32, error) {
-	dataCap := ep.DataCap()
-	if dataLen > dataCap {
-		return 0, fmt.Errorf("can't send packet with datagram length %d (maximum %d)", dataLen, dataCap)
-	}
-	atomic.StoreUint32(ep.dataLen(), dataLen)
-	if err := ep.doRoundTrip(); err != nil {
-		return 0, err
-	}
-	recvDataLen := atomic.LoadUint32(ep.dataLen())
-	if recvDataLen > dataCap {
-		return 0, fmt.Errorf("received packet with invalid datagram length %d (maximum %d)", recvDataLen, dataCap)
-	}
-	return recvDataLen, nil
-}
-
-// RecvFirst blocks until the peer Endpoint calls Endpoint.SendRecv(), then
-// returns the datagram length specified by that call.
-//
-// Preconditions: ep.SendRecv(), ep.RecvFirst(), and ep.SendLast() have never
-// been called.
-func (ep *Endpoint) RecvFirst() (uint32, error) {
-	if err := ep.doWaitFirst(); err != nil {
-		return 0, err
-	}
-	recvDataLen := atomic.LoadUint32(ep.dataLen())
-	if dataCap := ep.DataCap(); recvDataLen > dataCap {
-		return 0, fmt.Errorf("received packet with invalid datagram length %d (maximum %d)", recvDataLen, dataCap)
-	}
-	return recvDataLen, nil
-}
-
-// SendLast causes the peer Endpoint's call to Endpoint.SendRecv() or
-// Endpoint.RecvFirst() to return with the given datagram length.
-//
-// Preconditions: No previous call to ep.SendRecv() or ep.RecvFirst() has
-// returned an error. ep.SendLast() has never been called.
-func (ep *Endpoint) SendLast(dataLen uint32) error {
-	dataCap := ep.DataCap()
-	if dataLen > dataCap {
-		return fmt.Errorf("can't send packet with datagram length %d (maximum %d)", dataLen, dataCap)
-	}
-	atomic.StoreUint32(ep.dataLen(), dataLen)
-	if err := ep.doNotifyLast(); err != nil {
-		return err
-	}
-	return nil
-}
-
-// Shutdown causes concurrent and future calls to ep.SendRecv(),
-// ep.RecvFirst(), and ep.SendLast() to unblock and return errors. It does not
-// wait for concurrent calls to return.
-func (ep *Endpoint) Shutdown() {
-	if atomic.SwapUint32(&ep.shutdown, 1) == 0 {
-		ep.interruptForShutdown()
-	}
-}
-
-func (ep *Endpoint) isShutdown() bool {
-	return atomic.LoadUint32(&ep.shutdown) != 0
-}
-
-type endpointShutdownError struct{}
-
-// Error implements error.Error.
-func (endpointShutdownError) Error() string {
-	return "Endpoint.Shutdown() has been called"
-}
diff --git a/pkg/flipcall/flipcall.go b/pkg/flipcall/flipcall.go
index 79a1e418a..5c9212c33 100644
--- a/pkg/flipcall/flipcall.go
+++ b/pkg/flipcall/flipcall.go
@@ -13,20 +13,217 @@
 // limitations under the License.
 
 // Package flipcall implements a protocol providing Fast Local Interprocess
-// Procedure Calls.
+// Procedure Calls between mutually-distrusting processes.
 package flipcall
 
-// ControlMode defines how control is exchanged across a connection.
-type ControlMode uint8
+import (
+	"fmt"
+	"math"
+	"sync/atomic"
+	"syscall"
+)
 
-const (
-	// ControlModeInvalid is invalid, and exists so that ControlMode fields in
-	// structs must be explicitly initialized.
-	ControlModeInvalid ControlMode = iota
+// An Endpoint provides the ability to synchronously transfer data and control
+// to a connected peer Endpoint, which may be in another process.
+//
+// Since the Endpoint control transfer model is synchronous, at any given time
+// one Endpoint "has control" (designated the active Endpoint), and the other
+// is "waiting for control" (designated the inactive Endpoint). Users of the
+// flipcall package designate one Endpoint as the client, which is initially
+// active, and the other as the server, which is initially inactive. See
+// flipcall_example_test.go for usage.
+type Endpoint struct {
+	// packet is a pointer to the beginning of the packet window. (Since this
+	// is a raw OS memory mapping and not a Go object, it does not need to be
+	// represented as an unsafe.Pointer.) packet is immutable.
+	packet uintptr
+
+	// dataCap is the size of the datagram part of the packet window in bytes.
+	// dataCap is immutable.
+	dataCap uint32
+
+	// shutdown is non-zero if Endpoint.Shutdown() has been called, or if the
+	// Endpoint has acknowledged shutdown initiated by the peer. shutdown is
+	// accessed using atomic memory operations.
+	shutdown uint32
+
+	// activeState is csClientActive if this is a client Endpoint and
+	// csServerActive if this is a server Endpoint.
+	activeState uint32
+
+	// inactiveState is csServerActive if this is a client Endpoint and
+	// csClientActive if this is a server Endpoint.
+	inactiveState uint32
+
+	ctrl endpointControlImpl
+}
+
+// Init must be called on zero-value Endpoints before first use. If it
+// succeeds, ep.Destroy() must be called once the Endpoint is no longer in use.
+//
+// pwd represents the packet window used to exchange data with the peer
+// Endpoint. FD may differ between Endpoints if they are in different
+// processes, but must represent the same file. The packet window must
+// initially be filled with zero bytes.
+func (ep *Endpoint) Init(pwd PacketWindowDescriptor, opts ...EndpointOption) error {
+	if pwd.Length < pageSize {
+		return fmt.Errorf("packet window size (%d) less than minimum (%d)", pwd.Length, pageSize)
+	}
+	if pwd.Length > math.MaxUint32 {
+		return fmt.Errorf("packet window size (%d) exceeds maximum (%d)", pwd.Length, math.MaxUint32)
+	}
+	m, _, e := syscall.RawSyscall6(syscall.SYS_MMAP, 0, uintptr(pwd.Length), syscall.PROT_READ|syscall.PROT_WRITE, syscall.MAP_SHARED, uintptr(pwd.FD), uintptr(pwd.Offset))
+	if e != 0 {
+		return fmt.Errorf("failed to mmap packet window: %v", e)
+	}
+	ep.packet = m
+	ep.dataCap = uint32(pwd.Length) - uint32(PacketHeaderBytes)
+	// These will be overwritten by ep.Connect() for client Endpoints.
+	ep.activeState = csServerActive
+	ep.inactiveState = csClientActive
+	if err := ep.ctrlInit(opts...); err != nil {
+		ep.unmapPacket()
+		return err
+	}
+	return nil
+}
+
+// NewEndpoint is a convenience function that returns an initialized Endpoint
+// allocated on the heap.
+func NewEndpoint(pwd PacketWindowDescriptor, opts ...EndpointOption) (*Endpoint, error) {
+	var ep Endpoint
+	if err := ep.Init(pwd, opts...); err != nil {
+		return nil, err
+	}
+	return &ep, nil
+}
+
+// An EndpointOption configures an Endpoint.
+type EndpointOption interface {
+	isEndpointOption()
+}
+
+// Destroy releases resources owned by ep. No other Endpoint methods may be
+// called after Destroy.
+func (ep *Endpoint) Destroy() {
+	ep.unmapPacket()
+}
+
+func (ep *Endpoint) unmapPacket() {
+	syscall.RawSyscall(syscall.SYS_MUNMAP, ep.packet, uintptr(ep.dataCap)+PacketHeaderBytes, 0)
+	ep.packet = 0
+}
 
-	// ControlModeFutex uses shared futex operations on packet control words.
-	ControlModeFutex
+// Shutdown causes concurrent and future calls to ep.Connect(), ep.SendRecv(),
+// ep.RecvFirst(), and ep.SendLast() to unblock and return errors. It does not
+// wait for concurrent calls to return. The effect of Shutdown on the peer
+// Endpoint is unspecified. Successive calls to Shutdown have no effect.
+//
+// Shutdown is the only Endpoint method that may be called concurrently with
+// other methods on the same Endpoint.
+func (ep *Endpoint) Shutdown() {
+	if atomic.SwapUint32(&ep.shutdown, 1) != 0 {
+		// ep.Shutdown() has previously been called.
+		return
+	}
+	ep.ctrlShutdown()
+}
+
+// isShutdownLocally returns true if ep.Shutdown() has been called.
+func (ep *Endpoint) isShutdownLocally() bool {
+	return atomic.LoadUint32(&ep.shutdown) != 0
+}
+
+type shutdownError struct{}
+
+// Error implements error.Error.
+func (shutdownError) Error() string {
+	return "flipcall connection shutdown"
+}
 
-	// controlModeCount is the number of ControlModes in this list.
-	controlModeCount
+// DataCap returns the maximum datagram size supported by ep. Equivalently,
+// DataCap returns len(ep.Data()).
+func (ep *Endpoint) DataCap() uint32 {
+	return ep.dataCap
+}
+
+// Connection state.
+const (
+	// The client is, by definition, initially active, so this must be 0.
+	csClientActive = 0
+	csServerActive = 1
 )
+
+// Connect designates ep as a client Endpoint and blocks until the peer
+// Endpoint has called Endpoint.RecvFirst().
+//
+// Preconditions: ep.Connect(), ep.RecvFirst(), ep.SendRecv(), and
+// ep.SendLast() have never been called.
+func (ep *Endpoint) Connect() error {
+	ep.activeState = csClientActive
+	ep.inactiveState = csServerActive
+	return ep.ctrlConnect()
+}
+
+// RecvFirst blocks until the peer Endpoint calls Endpoint.SendRecv(), then
+// returns the datagram length specified by that call.
+//
+// Preconditions: ep.SendRecv(), ep.RecvFirst(), and ep.SendLast() have never
+// been called.
+func (ep *Endpoint) RecvFirst() (uint32, error) {
+	if err := ep.ctrlWaitFirst(); err != nil {
+		return 0, err
+	}
+	recvDataLen := atomic.LoadUint32(ep.dataLen())
+	if recvDataLen > ep.dataCap {
+		return 0, fmt.Errorf("received packet with invalid datagram length %d (maximum %d)", recvDataLen, ep.dataCap)
+	}
+	return recvDataLen, nil
+}
+
+// SendRecv transfers control to the peer Endpoint, causing its call to
+// Endpoint.SendRecv() or Endpoint.RecvFirst() to return with the given
+// datagram length, then blocks until the peer Endpoint calls
+// Endpoint.SendRecv() or Endpoint.SendLast().
+//
+// Preconditions: dataLen <= ep.DataCap(). No previous call to ep.SendRecv() or
+// ep.RecvFirst() has returned an error. ep.SendLast() has never been called.
+// If ep is a client Endpoint, ep.Connect() has previously been called and
+// returned nil.
+func (ep *Endpoint) SendRecv(dataLen uint32) (uint32, error) {
+	if dataLen > ep.dataCap {
+		panic(fmt.Sprintf("attempting to send packet with datagram length %d (maximum %d)", dataLen, ep.dataCap))
+	}
+	// This store can safely be non-atomic: Under correct operation we should
+	// be the only thread writing ep.dataLen(), and ep.ctrlRoundTrip() will
+	// synchronize with the receiver. We will not read from ep.dataLen() until
+	// after ep.ctrlRoundTrip(), so if the peer is mutating it concurrently
+	// then they can only shoot themselves in the foot.
+	*ep.dataLen() = dataLen
+	if err := ep.ctrlRoundTrip(); err != nil {
+		return 0, err
+	}
+	recvDataLen := atomic.LoadUint32(ep.dataLen())
+	if recvDataLen > ep.dataCap {
+		return 0, fmt.Errorf("received packet with invalid datagram length %d (maximum %d)", recvDataLen, ep.dataCap)
+	}
+	return recvDataLen, nil
+}
+
+// SendLast causes the peer Endpoint's call to Endpoint.SendRecv() or
+// Endpoint.RecvFirst() to return with the given datagram length.
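+// Unlike SendRecv, SendLast does not block waiting for the peer to respond.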
+//
+// Preconditions: dataLen <= ep.DataCap(). No previous call to ep.SendRecv() or
+// ep.RecvFirst() has returned an error. ep.SendLast() has never been called.
+// If ep is a client Endpoint, ep.Connect() has previously been called and
+// returned nil.
+func (ep *Endpoint) SendLast(dataLen uint32) error {
+	if dataLen > ep.dataCap {
+		panic(fmt.Sprintf("attempting to send packet with datagram length %d (maximum %d)", dataLen, ep.dataCap))
+	}
+	*ep.dataLen() = dataLen
+	if err := ep.ctrlWakeLast(); err != nil {
+		return err
+	}
+	return nil
+}
diff --git a/pkg/flipcall/flipcall_example_test.go b/pkg/flipcall/flipcall_example_test.go
index 572a1f119..edb6a8bef 100644
--- a/pkg/flipcall/flipcall_example_test.go
+++ b/pkg/flipcall/flipcall_example_test.go
@@ -17,6 +17,7 @@ package flipcall
 import (
 	"bytes"
 	"fmt"
+	"sync"
 )
 
 func Example() {
@@ -36,20 +37,21 @@ func Example() {
 	if err != nil {
 		panic(err)
 	}
-	clientEP, err := NewEndpoint(ControlModeFutex, pwd)
-	if err != nil {
+	var clientEP Endpoint
+	if err := clientEP.Init(pwd); err != nil {
 		panic(err)
 	}
 	defer clientEP.Destroy()
-	serverEP, err := NewEndpoint(ControlModeFutex, pwd)
-	if err != nil {
+	var serverEP Endpoint
+	if err := serverEP.Init(pwd); err != nil {
 		panic(err)
 	}
 	defer serverEP.Destroy()
 
-	serverDone := make(chan struct{})
+	var serverRun sync.WaitGroup
+	serverRun.Add(1)
 	go func() {
-		defer func() { serverDone <- struct{}{} }()
+		defer serverRun.Done()
 		i := 0
 		var buf bytes.Buffer
 		// wait for first request
@@ -76,9 +78,13 @@ func Example() {
 	}()
 	defer func() {
 		serverEP.Shutdown()
-		<-serverDone
+		serverRun.Wait()
 	}()
 
+	// establish connection as client
+	if err := clientEP.Connect(); err != nil {
+		panic(err)
+	}
 	var buf bytes.Buffer
 	for i := 0; i < count; i++ {
 		// write request
diff --git a/pkg/flipcall/flipcall_test.go b/pkg/flipcall/flipcall_test.go
index 20d3002f0..da9d736ab 100644
--- a/pkg/flipcall/flipcall_test.go
+++ b/pkg/flipcall/flipcall_test.go
@@ -15,197 +15,240 @@
 package flipcall
 
 import (
+	"runtime"
+	"sync"
 	"testing"
 	"time"
 )
 
 var testPacketWindowSize = pageSize
 
-func testSendRecv(t *testing.T, ctrlMode ControlMode) {
-	pwa, err := NewPacketWindowAllocator()
-	if err != nil {
-		t.Fatalf("failed to create PacketWindowAllocator: %v", err)
+type testConnection struct {
+	pwa      PacketWindowAllocator
+	clientEP Endpoint
+	serverEP Endpoint
+}
+
+func newTestConnectionWithOptions(tb testing.TB, clientOpts, serverOpts []EndpointOption) *testConnection {
+	c := &testConnection{}
+	if err := c.pwa.Init(); err != nil {
+		tb.Fatalf("failed to create PacketWindowAllocator: %v", err)
 	}
-	defer pwa.Destroy()
-	pwd, err := pwa.Allocate(testPacketWindowSize)
+	pwd, err := c.pwa.Allocate(testPacketWindowSize)
 	if err != nil {
-		t.Fatalf("PacketWindowAllocator.Allocate() failed: %v", err)
+		c.pwa.Destroy()
+		tb.Fatalf("PacketWindowAllocator.Allocate() failed: %v", err)
 	}
-
-	sendEP, err := NewEndpoint(ctrlMode, pwd)
-	if err != nil {
-		t.Fatalf("failed to create Endpoint: %v", err)
+	if err := c.clientEP.Init(pwd, clientOpts...); err != nil {
+		c.pwa.Destroy()
+		tb.Fatalf("failed to create client Endpoint: %v", err)
 	}
-	defer sendEP.Destroy()
-	recvEP, err := NewEndpoint(ctrlMode, pwd)
-	if err != nil {
-		t.Fatalf("failed to create Endpoint: %v", err)
+	if err := c.serverEP.Init(pwd, serverOpts...); err != nil {
+		c.pwa.Destroy()
+		c.clientEP.Destroy()
+		tb.Fatalf("failed to create server Endpoint: %v", err)
 	}
-	defer recvEP.Destroy()
+	return c
+}
 
-	otherThreadDone := make(chan struct{})
+func newTestConnection(tb testing.TB) *testConnection {
+	return newTestConnectionWithOptions(tb, nil, nil)
+}
+
+func (c *testConnection) destroy() {
+	c.pwa.Destroy()
+	c.clientEP.Destroy()
+	c.serverEP.Destroy()
+}
+
+func testSendRecv(t *testing.T, c *testConnection) {
+	var serverRun sync.WaitGroup
+	serverRun.Add(1)
 	go func() {
-		defer func() { otherThreadDone <- struct{}{} }()
-		t.Logf("initially-inactive Endpoint waiting for packet 1")
-		if _, err := recvEP.RecvFirst(); err != nil {
-			t.Fatalf("initially-inactive Endpoint.RecvFirst() failed: %v", err)
+		defer serverRun.Done()
+		t.Logf("server Endpoint waiting for packet 1")
+		if _, err := c.serverEP.RecvFirst(); err != nil {
+			t.Fatalf("server Endpoint.RecvFirst() failed: %v", err)
 		}
-		t.Logf("initially-inactive Endpoint got packet 1, sending packet 2 and waiting for packet 3")
-		if _, err := recvEP.SendRecv(0); err != nil {
-			t.Fatalf("initially-inactive Endpoint.SendRecv() failed: %v", err)
+		t.Logf("server Endpoint got packet 1, sending packet 2 and waiting for packet 3")
+		if _, err := c.serverEP.SendRecv(0); err != nil {
+			t.Fatalf("server Endpoint.SendRecv() failed: %v", err)
 		}
-		t.Logf("initially-inactive Endpoint got packet 3")
+		t.Logf("server Endpoint got packet 3")
 	}()
 	defer func() {
-		t.Logf("waiting for initially-inactive Endpoint goroutine to complete")
-		<-otherThreadDone
+		// Ensure that the server goroutine is cleaned up before
+		// c.serverEP.Destroy(), even if the test fails.
+		c.serverEP.Shutdown()
+		serverRun.Wait()
 	}()
 
-	t.Logf("initially-active Endpoint sending packet 1 and waiting for packet 2")
-	if _, err := sendEP.SendRecv(0); err != nil {
-		t.Fatalf("initially-active Endpoint.SendRecv() failed: %v", err)
+	t.Logf("client Endpoint establishing connection")
+	if err := c.clientEP.Connect(); err != nil {
+		t.Fatalf("client Endpoint.Connect() failed: %v", err)
 	}
-	t.Logf("initially-active Endpoint got packet 2, sending packet 3")
-	if err := sendEP.SendLast(0); err != nil {
-		t.Fatalf("initially-active Endpoint.SendLast() failed: %v", err)
+	t.Logf("client Endpoint sending packet 1 and waiting for packet 2")
+	if _, err := c.clientEP.SendRecv(0); err != nil {
+		t.Fatalf("client Endpoint.SendRecv() failed: %v", err)
 	}
+	t.Logf("client Endpoint got packet 2, sending packet 3")
+	if err := c.clientEP.SendLast(0); err != nil {
+		t.Fatalf("client Endpoint.SendLast() failed: %v", err)
+	}
+	t.Logf("waiting for server goroutine to complete")
+	serverRun.Wait()
 }
 
-func TestFutexSendRecv(t *testing.T) {
-	testSendRecv(t, ControlModeFutex)
+func TestSendRecv(t *testing.T) {
+	c := newTestConnection(t)
+	defer c.destroy()
+	testSendRecv(t, c)
 }
 
-func testRecvFirstShutdown(t *testing.T, ctrlMode ControlMode) {
-	pwa, err := NewPacketWindowAllocator()
-	if err != nil {
-		t.Fatalf("failed to create PacketWindowAllocator: %v", err)
-	}
-	defer pwa.Destroy()
-	pwd, err := pwa.Allocate(testPacketWindowSize)
-	if err != nil {
-		t.Fatalf("PacketWindowAllocator.Allocate() failed: %v", err)
-	}
+func testShutdownConnect(t *testing.T, c *testConnection) {
+	var clientRun sync.WaitGroup
+	clientRun.Add(1)
+	go func() {
+		defer clientRun.Done()
+		if err := c.clientEP.Connect(); err == nil {
+			t.Errorf("client Endpoint.Connect() succeeded unexpectedly")
+		}
+	}()
+	time.Sleep(time.Second) // to allow c.clientEP.Connect() to block
+	c.clientEP.Shutdown()
+	clientRun.Wait()
+}
 
-	ep, err := NewEndpoint(ctrlMode, pwd)
-	if err != nil {
-		t.Fatalf("failed to create Endpoint: %v", err)
-	}
-	defer ep.Destroy()
+func TestShutdownConnect(t *testing.T) {
newTestConnection(t) + defer c.destroy() + testShutdownConnect(t, c) +} - otherThreadDone := make(chan struct{}) +func testShutdownRecvFirstBeforeConnect(t *testing.T, c *testConnection) { + var serverRun sync.WaitGroup + serverRun.Add(1) go func() { - defer func() { otherThreadDone <- struct{}{} }() - _, err := ep.RecvFirst() + defer serverRun.Done() + _, err := c.serverEP.RecvFirst() if err == nil { - t.Errorf("Endpoint.RecvFirst() succeeded unexpectedly") + t.Errorf("server Endpoint.RecvFirst() succeeded unexpectedly") } }() - - time.Sleep(time.Second) // to ensure ep.RecvFirst() has blocked - ep.Shutdown() - <-otherThreadDone + time.Sleep(time.Second) // to allow c.serverEP.RecvFirst() to block + c.serverEP.Shutdown() + serverRun.Wait() } -func TestFutexRecvFirstShutdown(t *testing.T) { - testRecvFirstShutdown(t, ControlModeFutex) +func TestShutdownRecvFirstBeforeConnect(t *testing.T) { + c := newTestConnection(t) + defer c.destroy() + testShutdownRecvFirstBeforeConnect(t, c) } -func testSendRecvShutdown(t *testing.T, ctrlMode ControlMode) { - pwa, err := NewPacketWindowAllocator() - if err != nil { - t.Fatalf("failed to create PacketWindowAllocator: %v", err) - } - defer pwa.Destroy() - pwd, err := pwa.Allocate(testPacketWindowSize) - if err != nil { - t.Fatalf("PacketWindowAllocator.Allocate() failed: %v", err) - } - - sendEP, err := NewEndpoint(ctrlMode, pwd) - if err != nil { - t.Fatalf("failed to create Endpoint: %v", err) - } - defer sendEP.Destroy() - recvEP, err := NewEndpoint(ctrlMode, pwd) - if err != nil { - t.Fatalf("failed to create Endpoint: %v", err) - } - defer recvEP.Destroy() - - otherThreadDone := make(chan struct{}) +func testShutdownRecvFirstAfterConnect(t *testing.T, c *testConnection) { + var serverRun sync.WaitGroup + serverRun.Add(1) go func() { - defer func() { otherThreadDone <- struct{}{} }() - if _, err := recvEP.RecvFirst(); err != nil { - t.Fatalf("initially-inactive Endpoint.RecvFirst() failed: %v", err) - } - if _, err := recvEP.SendRecv(0); err == nil { - t.Errorf("initially-inactive Endpoint.SendRecv() succeeded unexpectedly") + defer serverRun.Done() + if _, err := c.serverEP.RecvFirst(); err == nil { + t.Fatalf("server Endpoint.RecvFirst() succeeded unexpectedly") } }() - - if _, err := sendEP.SendRecv(0); err != nil { - t.Fatalf("initially-active Endpoint.SendRecv() failed: %v", err) + defer func() { + // Ensure that the server goroutine is cleaned up before + // c.serverEP.Destroy(), even if the test fails. 
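+		// Shutdown() causes calls blocked on c.serverEP to fail rather than +		// hang, so the Wait() below cannot deadlock.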
+ c.serverEP.Shutdown() + serverRun.Wait() + }() + if err := c.clientEP.Connect(); err != nil { + t.Fatalf("client Endpoint.Connect() failed: %v", err) } - time.Sleep(time.Second) // to ensure recvEP.SendRecv() has blocked - recvEP.Shutdown() - <-otherThreadDone + c.serverEP.Shutdown() + serverRun.Wait() } -func TestFutexSendRecvShutdown(t *testing.T) { - testSendRecvShutdown(t, ControlModeFutex) +func TestShutdownRecvFirstAfterConnect(t *testing.T) { + c := newTestConnection(t) + defer c.destroy() + testShutdownRecvFirstAfterConnect(t, c) } -func benchmarkSendRecv(b *testing.B, ctrlMode ControlMode) { - pwa, err := NewPacketWindowAllocator() - if err != nil { - b.Fatalf("failed to create PacketWindowAllocator: %v", err) +func testShutdownSendRecv(t *testing.T, c *testConnection) { + var serverRun sync.WaitGroup + serverRun.Add(1) + go func() { + defer serverRun.Done() + if _, err := c.serverEP.RecvFirst(); err != nil { + t.Fatalf("server Endpoint.RecvFirst() failed: %v", err) + } + if _, err := c.serverEP.SendRecv(0); err == nil { + t.Errorf("server Endpoint.SendRecv() succeeded unexpectedly") + } + }() + defer func() { + // Ensure that the server goroutine is cleaned up before + // c.serverEP.Destroy(), even if the test fails. + c.serverEP.Shutdown() + serverRun.Wait() + }() + if err := c.clientEP.Connect(); err != nil { + t.Fatalf("client Endpoint.Connect() failed: %v", err) } - defer pwa.Destroy() - pwd, err := pwa.Allocate(testPacketWindowSize) - if err != nil { - b.Fatalf("PacketWindowAllocator.Allocate() failed: %v", err) + if _, err := c.clientEP.SendRecv(0); err != nil { + t.Fatalf("client Endpoint.SendRecv() failed: %v", err) } + time.Sleep(time.Second) // to allow serverEP.SendRecv() to block + c.serverEP.Shutdown() + serverRun.Wait() +} - sendEP, err := NewEndpoint(ctrlMode, pwd) - if err != nil { - b.Fatalf("failed to create Endpoint: %v", err) - } - defer sendEP.Destroy() - recvEP, err := NewEndpoint(ctrlMode, pwd) - if err != nil { - b.Fatalf("failed to create Endpoint: %v", err) - } - defer recvEP.Destroy() +func TestShutdownSendRecv(t *testing.T) { + c := newTestConnection(t) + defer c.destroy() + testShutdownSendRecv(t, c) +} - otherThreadDone := make(chan struct{}) +func benchmarkSendRecv(b *testing.B, c *testConnection) { + var serverRun sync.WaitGroup + serverRun.Add(1) go func() { - defer func() { otherThreadDone <- struct{}{} }() + defer serverRun.Done() if b.N == 0 { return } - if _, err := recvEP.RecvFirst(); err != nil { - b.Fatalf("initially-inactive Endpoint.RecvFirst() failed: %v", err) + if _, err := c.serverEP.RecvFirst(); err != nil { + b.Fatalf("server Endpoint.RecvFirst() failed: %v", err) } for i := 1; i < b.N; i++ { - if _, err := recvEP.SendRecv(0); err != nil { - b.Fatalf("initially-inactive Endpoint.SendRecv() failed: %v", err) + if _, err := c.serverEP.SendRecv(0); err != nil { + b.Fatalf("server Endpoint.SendRecv() failed: %v", err) } } - if err := recvEP.SendLast(0); err != nil { - b.Fatalf("initially-inactive Endpoint.SendLast() failed: %v", err) + if err := c.serverEP.SendLast(0); err != nil { + b.Fatalf("server Endpoint.SendLast() failed: %v", err) } }() - defer func() { <-otherThreadDone }() + defer func() { + c.serverEP.Shutdown() + serverRun.Wait() + }() + if err := c.clientEP.Connect(); err != nil { + b.Fatalf("client Endpoint.Connect() failed: %v", err) + } + runtime.GC() b.ResetTimer() for i := 0; i < b.N; i++ { - if _, err := sendEP.SendRecv(0); err != nil { - b.Fatalf("initially-active Endpoint.SendRecv() failed: %v", err) + if _, err := 
c.clientEP.SendRecv(0); err != nil { + b.Fatalf("client Endpoint.SendRecv() failed: %v", err) } } b.StopTimer() } -func BenchmarkFutexSendRecv(b *testing.B) { - benchmarkSendRecv(b, ControlModeFutex) +func BenchmarkSendRecv(b *testing.B) { + c := newTestConnection(b) + defer c.destroy() + benchmarkSendRecv(b, c) } diff --git a/pkg/flipcall/flipcall_unsafe.go b/pkg/flipcall/flipcall_unsafe.go new file mode 100644 index 000000000..7c8977893 --- /dev/null +++ b/pkg/flipcall/flipcall_unsafe.go @@ -0,0 +1,69 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package flipcall + +import ( + "reflect" + "unsafe" +) + +// Packets consist of an 8-byte header followed by an arbitrarily-sized +// datagram. The header consists of: +// +// - A 4-byte native-endian connection state. +// +// - A 4-byte native-endian datagram length in bytes. +const ( + sizeofUint32 = unsafe.Sizeof(uint32(0)) + + // PacketHeaderBytes is the size of a flipcall packet header in bytes. The + // maximum datagram size supported by a flipcall connection is equal to the + // length of the packet window minus PacketHeaderBytes. + // + // PacketHeaderBytes is exported to support its use in constant + // expressions. Non-constant expressions may prefer to use + // PacketWindowLengthForDataCap(). + PacketHeaderBytes = 2 * sizeofUint32 +) + +func (ep *Endpoint) connState() *uint32 { + return (*uint32)((unsafe.Pointer)(ep.packet)) +} + +func (ep *Endpoint) dataLen() *uint32 { + return (*uint32)((unsafe.Pointer)(ep.packet + sizeofUint32)) +} + +// Data returns the datagram part of ep's packet window as a byte slice. +// +// Note that the packet window is shared with the potentially-untrusted peer +// Endpoint, which may concurrently mutate the contents of the packet window. +// Thus: +// +// - Readers must not assume that two reads of the same byte in Data() will +// return the same result. In other words, readers should read any given byte +// in Data() at most once. +// +// - Writers must not assume that they will read back the same data that they +// have written. In other words, writers should avoid reading from Data() at +// all. 
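+// +// For example, a sender would typically fill the datagram with +// n := copy(ep.Data(), src) and then pass uint32(n) as the dataLen +// argument to ep.SendRecv() or ep.SendLast().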
+func (ep *Endpoint) Data() []byte { + var bs []byte + bsReflect := (*reflect.SliceHeader)((unsafe.Pointer)(&bs)) + bsReflect.Data = ep.packet + PacketHeaderBytes + bsReflect.Len = int(ep.dataCap) + bsReflect.Cap = int(ep.dataCap) + return bs +} diff --git a/pkg/flipcall/futex_linux.go b/pkg/flipcall/futex_linux.go index 3f592ad16..e7dd812b3 100644 --- a/pkg/flipcall/futex_linux.go +++ b/pkg/flipcall/futex_linux.go @@ -17,78 +17,95 @@ package flipcall import ( + "encoding/json" "fmt" - "math" + "runtime" "sync/atomic" "syscall" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/log" ) -func (ep *Endpoint) doFutexRoundTrip() error { - ourSeq, err := ep.doFutexNotifySeq() - if err != nil { - return err +func (ep *Endpoint) futexConnect(req *ctrlHandshakeRequest) (ctrlHandshakeResponse, error) { + var resp ctrlHandshakeResponse + + // Write the handshake request. + w := ep.NewWriter() + if err := json.NewEncoder(w).Encode(req); err != nil { + return resp, fmt.Errorf("error writing handshake request: %v", err) } - return ep.doFutexWaitSeq(ourSeq) -} + *ep.dataLen() = w.Len() -func (ep *Endpoint) doFutexWaitFirst() error { - return ep.doFutexWaitSeq(0) -} + // Exchange control with the server. + if err := ep.futexSwitchToPeer(); err != nil { + return resp, err + } + if err := ep.futexSwitchFromPeer(); err != nil { + return resp, err + } -func (ep *Endpoint) doFutexNotifyLast() error { - _, err := ep.doFutexNotifySeq() - return err + // Read the handshake response. + respLen := atomic.LoadUint32(ep.dataLen()) + if respLen > ep.dataCap { + return resp, fmt.Errorf("invalid handshake response length %d (maximum %d)", respLen, ep.dataCap) + } + if err := json.NewDecoder(ep.NewReader(respLen)).Decode(&resp); err != nil { + return resp, fmt.Errorf("error reading handshake response: %v", err) + } + + return resp, nil } -func (ep *Endpoint) doFutexNotifySeq() (uint32, error) { - ourSeq := atomic.AddUint32(ep.seq(), 1) - if err := ep.futexWake(1); err != nil { - return ourSeq, fmt.Errorf("failed to FUTEX_WAKE peer Endpoint: %v", err) +func (ep *Endpoint) futexSwitchToPeer() error { + // Update connection state to indicate that the peer should be active. + if !atomic.CompareAndSwapUint32(ep.connState(), ep.activeState, ep.inactiveState) { + return fmt.Errorf("unexpected connection state before FUTEX_WAKE: %v", atomic.LoadUint32(ep.connState())) } - return ourSeq, nil + + // Wake the peer's Endpoint.futexSwitchFromPeer(). + if err := ep.futexWakeConnState(1); err != nil { + return fmt.Errorf("failed to FUTEX_WAKE peer Endpoint: %v", err) + } + return nil } -func (ep *Endpoint) doFutexWaitSeq(prevSeq uint32) error { - nextSeq := prevSeq + 1 +func (ep *Endpoint) futexSwitchFromPeer() error { for { - if ep.isShutdown() { - return endpointShutdownError{} - } - if err := ep.futexWait(prevSeq); err != nil { - return fmt.Errorf("failed to FUTEX_WAIT for peer Endpoint: %v", err) - } - seq := atomic.LoadUint32(ep.seq()) - if seq == nextSeq { + switch cs := atomic.LoadUint32(ep.connState()); cs { + case ep.activeState: return nil + case ep.inactiveState: + // Continue to FUTEX_WAIT. 
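+			// The peer Endpoint still owns the packet window; fall out of the +			// switch and wait for it to switch control back.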
+ default: + return fmt.Errorf("unexpected connection state before FUTEX_WAIT: %v", cs) } - if seq != prevSeq { - return fmt.Errorf("invalid packet sequence number %d (expected %d or %d)", seq, prevSeq, nextSeq) + if ep.isShutdownLocally() { + return shutdownError{} + } + if err := ep.futexWaitConnState(ep.inactiveState); err != nil { + return fmt.Errorf("failed to FUTEX_WAIT for peer Endpoint: %v", err) } } } -func (ep *Endpoint) doFutexInterruptForShutdown() { - // Wake MaxInt32 threads to prevent a malicious or broken peer from - // swallowing our wakeup by FUTEX_WAITing from multiple threads. - if err := ep.futexWake(math.MaxInt32); err != nil { - log.Warningf("failed to FUTEX_WAKE Endpoint: %v", err) - } -} - -func (ep *Endpoint) futexWake(numThreads int32) error { - if _, _, e := syscall.RawSyscall(syscall.SYS_FUTEX, uintptr(ep.packet), linux.FUTEX_WAKE, uintptr(numThreads)); e != 0 { +func (ep *Endpoint) futexWakeConnState(numThreads int32) error { + if _, _, e := syscall.RawSyscall(syscall.SYS_FUTEX, ep.packet, linux.FUTEX_WAKE, uintptr(numThreads)); e != 0 { return e } return nil } -func (ep *Endpoint) futexWait(seq uint32) error { - _, _, e := syscall.Syscall6(syscall.SYS_FUTEX, uintptr(ep.packet), linux.FUTEX_WAIT, uintptr(seq), 0, 0, 0) +func (ep *Endpoint) futexWaitConnState(curState uint32) error { + _, _, e := syscall.Syscall6(syscall.SYS_FUTEX, ep.packet, linux.FUTEX_WAIT, uintptr(curState), 0, 0, 0) if e != 0 && e != syscall.EAGAIN && e != syscall.EINTR { return e } return nil } + +func yieldThread() { + syscall.Syscall(syscall.SYS_SCHED_YIELD, 0, 0, 0) + // The thread we're trying to yield to may be waiting for a Go runtime P. + // runtime.Gosched() will hand off ours if necessary. + runtime.Gosched() +} diff --git a/pkg/flipcall/io.go b/pkg/flipcall/io.go new file mode 100644 index 000000000..85e40b932 --- /dev/null +++ b/pkg/flipcall/io.go @@ -0,0 +1,113 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package flipcall + +import ( + "fmt" + "io" +) + +// DatagramReader implements io.Reader by reading a datagram from an Endpoint's +// packet window. Its use is optional; users that can use Endpoint.Data() more +// efficiently are advised to do so. +type DatagramReader struct { + ep *Endpoint + off uint32 + end uint32 +} + +// Init must be called on zero-value DatagramReaders before first use. +// +// Preconditions: dataLen is 0, or was returned by a previous call to +// ep.RecvFirst() or ep.SendRecv(). +func (r *DatagramReader) Init(ep *Endpoint, dataLen uint32) { + r.ep = ep + r.Reset(dataLen) +} + +// Reset causes r to begin reading a new datagram of the given length from the +// associated Endpoint. +// +// Preconditions: dataLen is 0, or was returned by a previous call to the +// associated Endpoint's RecvFirst() or SendRecv() methods. 
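+// +// A single DatagramReader may thus be reused across datagrams by calling +// Reset() with each new length returned by the associated Endpoint's +// RecvFirst() or SendRecv() methods.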
+func (r *DatagramReader) Reset(dataLen uint32) { + if dataLen > r.ep.dataCap { + panic(fmt.Sprintf("invalid dataLen (%d) > ep.dataCap (%d)", dataLen, r.ep.dataCap)) + } + r.off = 0 + r.end = dataLen +} + +// NewReader is a convenience function that returns an initialized +// DatagramReader allocated on the heap. +// +// Preconditions: dataLen was returned by a previous call to ep.RecvFirst() or +// ep.SendRecv(). +func (ep *Endpoint) NewReader(dataLen uint32) *DatagramReader { + r := &DatagramReader{} + r.Init(ep, dataLen) + return r +} + +// Read implements io.Reader.Read. +func (r *DatagramReader) Read(dst []byte) (int, error) { + n := copy(dst, r.ep.Data()[r.off:r.end]) + r.off += uint32(n) + if r.off == r.end { + return n, io.EOF + } + return n, nil +} + +// DatagramWriter implements io.Writer by writing a datagram to an Endpoint's +// packet window. Its use is optional; users that can use Endpoint.Data() more +// efficiently are advised to do so. +type DatagramWriter struct { + ep *Endpoint + off uint32 +} + +// Init must be called on zero-value DatagramWriters before first use. +func (w *DatagramWriter) Init(ep *Endpoint) { + w.ep = ep +} + +// Reset causes w to begin writing a new datagram to the associated Endpoint. +func (w *DatagramWriter) Reset() { + w.off = 0 +} + +// NewWriter is a convenience function that returns an initialized +// DatagramWriter allocated on the heap. +func (ep *Endpoint) NewWriter() *DatagramWriter { + w := &DatagramWriter{} + w.Init(ep) + return w +} + +// Write implements io.Writer.Write. +func (w *DatagramWriter) Write(src []byte) (int, error) { + n := copy(w.ep.Data()[w.off:w.ep.dataCap], src) + w.off += uint32(n) + if n != len(src) { + return n, fmt.Errorf("datagram would exceed maximum size of %d bytes", w.ep.dataCap) + } + return n, nil +} + +// Len returns the length of the written datagram. +func (w *DatagramWriter) Len() uint32 { + return w.off +} diff --git a/pkg/flipcall/packet_window_allocator.go b/pkg/flipcall/packet_window_allocator.go index 7b455b24d..ccb918fab 100644 --- a/pkg/flipcall/packet_window_allocator.go +++ b/pkg/flipcall/packet_window_allocator.go @@ -34,10 +34,10 @@ func init() { // This is depended on by roundUpToPage(). panic(fmt.Sprintf("system page size (%d) is not a power of 2", pageSize)) } - if uintptr(pageSize) < packetHeaderBytes { + if uintptr(pageSize) < PacketHeaderBytes { // This is required since Endpoint.Init() imposes a minimum packet // window size of 1 page. - panic(fmt.Sprintf("system page size (%d) is less than packet header size (%d)", pageSize, packetHeaderBytes)) + panic(fmt.Sprintf("system page size (%d) is less than packet header size (%d)", pageSize, PacketHeaderBytes)) } } @@ -59,7 +59,7 @@ type PacketWindowDescriptor struct { // PacketWindowLengthForDataCap returns the minimum packet window size required // to accommodate datagrams of the given size in bytes. func PacketWindowLengthForDataCap(dataCap uint32) int { - return roundUpToPage(int(dataCap) + int(packetHeaderBytes)) + return roundUpToPage(int(dataCap) + int(PacketHeaderBytes)) } func roundUpToPage(x int) int { diff --git a/pkg/refs/refcounter.go b/pkg/refs/refcounter.go index 8c3e3d5ab..ad69e0757 100644 --- a/pkg/refs/refcounter.go +++ b/pkg/refs/refcounter.go @@ -215,8 +215,8 @@ type AtomicRefCount struct { type LeakMode uint32 const ( - // uninitializedLeakChecking indicates that the leak checker has not yet been initialized. 
- uninitializedLeakChecking LeakMode = iota + // UninitializedLeakChecking indicates that the leak checker has not yet been initialized. + UninitializedLeakChecking LeakMode = iota // NoLeakChecking indicates that no effort should be made to check for // leaks. @@ -231,20 +231,6 @@ const ( LeaksLogTraces ) -// String returns LeakMode's string representation. -func (l LeakMode) String() string { - switch l { - case NoLeakChecking: - return "disabled" - case LeaksLogWarning: - return "log-names" - case LeaksLogTraces: - return "log-traces" - default: - panic(fmt.Sprintf("Invalid leakmode: %d", l)) - } -} - // leakMode stores the current mode for the reference leak checker. // // Values must be one of the LeakMode values. @@ -332,13 +318,15 @@ func (r *AtomicRefCount) finalize() { switch LeakMode(atomic.LoadUint32(&leakMode)) { case NoLeakChecking: return - case uninitializedLeakChecking: + case UninitializedLeakChecking: note = "(Leak checker uninitialized): " } if n := r.ReadRefs(); n != 0 { msg := fmt.Sprintf("%sAtomicRefCount %p owned by %q garbage collected with ref count of %d (want 0)", note, r, r.name, n) if len(r.stack) != 0 { msg += ":\nCaller:\n" + formatStack(r.stack) + } else { + msg += " (enable trace logging to debug)" } log.Warningf(msg) } diff --git a/pkg/seccomp/seccomp_test_victim.go b/pkg/seccomp/seccomp_test_victim.go index 62ae1fd9f..48413f1fb 100644 --- a/pkg/seccomp/seccomp_test_victim.go +++ b/pkg/seccomp/seccomp_test_victim.go @@ -70,7 +70,7 @@ func main() { syscall.SYS_NANOSLEEP: {}, syscall.SYS_NEWFSTATAT: {}, syscall.SYS_OPEN: {}, - syscall.SYS_POLL: {}, + syscall.SYS_PPOLL: {}, syscall.SYS_PREAD64: {}, syscall.SYS_PSELECT6: {}, syscall.SYS_PWRITE64: {}, diff --git a/pkg/sentry/control/proc.go b/pkg/sentry/control/proc.go index 3f9772b87..c35faeb4c 100644 --- a/pkg/sentry/control/proc.go +++ b/pkg/sentry/control/proc.go @@ -56,15 +56,10 @@ type ExecArgs struct { // MountNamespace is the mount namespace to execute the new process in. // A reference on MountNamespace must be held for the lifetime of the - // ExecArgs. If MountNamespace is nil, it will default to the kernel's - // root MountNamespace. + // ExecArgs. If MountNamespace is nil, it will default to the init + // process's MountNamespace. MountNamespace *fs.MountNamespace - // Root defines the root directory for the new process. A reference on - // Root must be held for the lifetime of the ExecArgs. If Root is nil, - // it will default to the VFS root. - Root *fs.Dirent - // WorkingDirectory defines the working directory for the new process. WorkingDirectory string `json:"wd"` @@ -155,7 +150,6 @@ func (proc *Proc) execAsync(args *ExecArgs) (*kernel.ThreadGroup, kernel.ThreadI Envv: args.Envv, WorkingDirectory: args.WorkingDirectory, MountNamespace: args.MountNamespace, - Root: args.Root, Credentials: creds, FDTable: fdTable, Umask: 0022, @@ -167,11 +161,6 @@ func (proc *Proc) execAsync(args *ExecArgs) (*kernel.ThreadGroup, kernel.ThreadI ContainerID: args.ContainerID, PIDNamespace: args.PIDNamespace, } - if initArgs.Root != nil { - // initArgs must hold a reference on Root, which will be - // donated to the new process in CreateProcess. - initArgs.Root.IncRef() - } if initArgs.MountNamespace != nil { // initArgs must hold a reference on MountNamespace, which will // be donated to the new process in CreateProcess. 
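The effect of the ExecArgs change above is easiest to see from the caller's side. A minimal, hypothetical sketch (field values are illustrative; only the nil-MountNamespace defaulting comes from this change):

	args := &control.ExecArgs{
		Argv:           []string{"/bin/sh"},
		ContainerID:    "c1",
		MountNamespace: nil, // now defaults to the init process's MountNamespace rather than the kernel's root
	}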
@@ -184,7 +173,7 @@ func (proc *Proc) execAsync(args *ExecArgs) (*kernel.ThreadGroup, kernel.ThreadI paths := fs.GetPath(initArgs.Envv) mns := initArgs.MountNamespace if mns == nil { - mns = proc.Kernel.RootMountNamespace() + mns = proc.Kernel.GlobalInit().Leader().MountNamespace() } f, err := mns.ResolveExecutablePath(ctx, initArgs.WorkingDirectory, initArgs.Argv[0], paths) if err != nil { diff --git a/pkg/sentry/fs/attr.go b/pkg/sentry/fs/attr.go index 9fc6a5bc2..4f3d6410e 100644 --- a/pkg/sentry/fs/attr.go +++ b/pkg/sentry/fs/attr.go @@ -111,6 +111,50 @@ func (n InodeType) LinuxType() uint32 { } } +// ToDirentType converts an InodeType to a linux dirent type field. +func ToDirentType(nodeType InodeType) uint8 { + switch nodeType { + case RegularFile, SpecialFile: + return linux.DT_REG + case Symlink: + return linux.DT_LNK + case Directory, SpecialDirectory: + return linux.DT_DIR + case Pipe: + return linux.DT_FIFO + case CharacterDevice: + return linux.DT_CHR + case BlockDevice: + return linux.DT_BLK + case Socket: + return linux.DT_SOCK + default: + return linux.DT_UNKNOWN + } +} + +// ToInodeType converts a linux file type to InodeType. +func ToInodeType(linuxFileType linux.FileMode) InodeType { + switch linuxFileType { + case linux.ModeRegular: + return RegularFile + case linux.ModeDirectory: + return Directory + case linux.ModeSymlink: + return Symlink + case linux.ModeNamedPipe: + return Pipe + case linux.ModeCharacterDevice: + return CharacterDevice + case linux.ModeBlockDevice: + return BlockDevice + case linux.ModeSocket: + return Socket + default: + panic(fmt.Sprintf("unknown file mode: %d", linuxFileType)) + } +} + // StableAttr contains Inode attributes that will be stable throughout the // lifetime of the Inode. // diff --git a/pkg/sentry/fs/context.go b/pkg/sentry/fs/context.go index 51b4c7ee1..dd427de5d 100644 --- a/pkg/sentry/fs/context.go +++ b/pkg/sentry/fs/context.go @@ -112,3 +112,27 @@ func DirentCacheLimiterFromContext(ctx context.Context) *DirentCacheLimiter { } return nil } + +type rootContext struct { + context.Context + root *Dirent +} + +// WithRoot returns a copy of ctx with the given root. +func WithRoot(ctx context.Context, root *Dirent) context.Context { + return &rootContext{ + Context: ctx, + root: root, + } +} + +// Value implements Context.Value. 
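+// +// Note that the *Dirent returned for CtxRoot carries a new reference, which +// the caller is responsible for dropping.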
+func (rc rootContext) Value(key interface{}) interface{} { + switch key { + case CtxRoot: + rc.root.IncRef() + return rc.root + default: + return rc.Context.Value(key) + } +} diff --git a/pkg/sentry/fs/ext/BUILD b/pkg/sentry/fs/ext/BUILD deleted file mode 100644 index 2c15875f5..000000000 --- a/pkg/sentry/fs/ext/BUILD +++ /dev/null @@ -1,54 +0,0 @@ -package(licenses = ["notice"]) - -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") - -go_library( - name = "ext", - srcs = [ - "dentry.go", - "ext.go", - "filesystem.go", - "inode.go", - "utils.go", - ], - importpath = "gvisor.dev/gvisor/pkg/sentry/fs/ext", - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/binary", - "//pkg/sentry/context", - "//pkg/sentry/fs/ext/disklayout", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/vfs", - "//pkg/syserror", - ], -) - -go_test( - name = "ext_test", - size = "small", - srcs = [ - "ext_test.go", - "extent_test.go", - ], - data = [ - "//pkg/sentry/fs/ext:assets/bigfile.txt", - "//pkg/sentry/fs/ext:assets/file.txt", - "//pkg/sentry/fs/ext:assets/tiny.ext2", - "//pkg/sentry/fs/ext:assets/tiny.ext3", - "//pkg/sentry/fs/ext:assets/tiny.ext4", - ], - embed = [":ext"], - deps = [ - "//pkg/abi/linux", - "//pkg/binary", - "//pkg/sentry/context", - "//pkg/sentry/context/contexttest", - "//pkg/sentry/fs/ext/disklayout", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/vfs", - "//runsc/test/testutil", - "@com_github_google_go-cmp//cmp:go_default_library", - "@com_github_google_go-cmp//cmp/cmpopts:go_default_library", - ], -) diff --git a/pkg/sentry/fs/ext/ext.go b/pkg/sentry/fs/ext/ext.go deleted file mode 100644 index 10e235fb1..000000000 --- a/pkg/sentry/fs/ext/ext.go +++ /dev/null @@ -1,97 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package ext implements readonly ext(2/3/4) filesystems. -package ext - -import ( - "errors" - "fmt" - "io" - "os" - - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/context" - "gvisor.dev/gvisor/pkg/sentry/fs/ext/disklayout" - "gvisor.dev/gvisor/pkg/sentry/kernel/auth" - "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/syserror" -) - -// filesystemType implements vfs.FilesystemType. -type filesystemType struct{} - -// Compiles only if filesystemType implements vfs.FilesystemType. -var _ vfs.FilesystemType = (*filesystemType)(nil) - -// getDeviceFd returns the read seeker to the underlying device. -// Currently there are two ways of mounting an ext(2/3/4) fs: -// 1. Specify a mount with our internal special MountType in the OCI spec. -// 2. Expose the device to the container and mount it from application layer. -func getDeviceFd(source string, opts vfs.NewFilesystemOptions) (io.ReadSeeker, error) { - if opts.InternalData == nil { - // User mount call. - // TODO(b/134676337): Open the device specified by `source` and return that. - panic("unimplemented") - } - - // NewFilesystem call originated from within the sentry. 
- fd, ok := opts.InternalData.(uintptr) - if !ok { - return nil, errors.New("internal data for ext fs must be a uintptr containing the file descriptor to device") - } - - // We do not close this file because that would close the underlying device - // file descriptor (which is required for reading the fs from disk). - // TODO(b/134676337): Use pkg/fd instead. - deviceFile := os.NewFile(fd, source) - if deviceFile == nil { - return nil, fmt.Errorf("ext4 device file descriptor is not valid: %d", fd) - } - - return deviceFile, nil -} - -// NewFilesystem implements vfs.FilesystemType.NewFilesystem. -func (fstype filesystemType) NewFilesystem(ctx context.Context, creds *auth.Credentials, source string, opts vfs.NewFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) { - dev, err := getDeviceFd(source, opts) - if err != nil { - return nil, nil, err - } - - fs := filesystem{dev: dev, inodeCache: make(map[uint32]*inode)} - fs.vfsfs.Init(&fs) - fs.sb, err = readSuperBlock(dev) - if err != nil { - return nil, nil, err - } - - if fs.sb.Magic() != linux.EXT_SUPER_MAGIC { - // mount(2) specifies that EINVAL should be returned if the superblock is - // invalid. - return nil, nil, syserror.EINVAL - } - - fs.bgs, err = readBlockGroups(dev, fs.sb) - if err != nil { - return nil, nil, err - } - - rootInode, err := fs.getOrCreateInode(disklayout.RootDirInode) - if err != nil { - return nil, nil, err - } - - return &fs.vfsfs, &newDentry(rootInode).vfsd, nil -} diff --git a/pkg/sentry/fs/ext/ext_test.go b/pkg/sentry/fs/ext/ext_test.go deleted file mode 100644 index ee7f7907c..000000000 --- a/pkg/sentry/fs/ext/ext_test.go +++ /dev/null @@ -1,407 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ext - -import ( - "fmt" - "os" - "path" - "testing" - - "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/context" - "gvisor.dev/gvisor/pkg/sentry/context/contexttest" - "gvisor.dev/gvisor/pkg/sentry/fs/ext/disklayout" - "gvisor.dev/gvisor/pkg/sentry/kernel/auth" - "gvisor.dev/gvisor/pkg/sentry/vfs" - - "gvisor.dev/gvisor/runsc/test/testutil" -) - -const ( - assetsDir = "pkg/sentry/fs/ext/assets" -) - -var ( - ext2ImagePath = path.Join(assetsDir, "tiny.ext2") - ext3ImagePath = path.Join(assetsDir, "tiny.ext3") - ext4ImagePath = path.Join(assetsDir, "tiny.ext4") -) - -func beginning(_ uint64) uint64 { - return 0 -} - -func middle(i uint64) uint64 { - return i / 2 -} - -func end(i uint64) uint64 { - return i -} - -// setUp opens imagePath as an ext Filesystem and returns all necessary -// elements required to run tests. If error is non-nil, it also returns a tear -// down function which must be called after the test is run for clean up. 
-func setUp(t *testing.T, imagePath string) (context.Context, *vfs.Filesystem, *vfs.Dentry, func(), error) { - localImagePath, err := testutil.FindFile(imagePath) - if err != nil { - return nil, nil, nil, nil, fmt.Errorf("failed to open local image at path %s: %v", imagePath, err) - } - - f, err := os.Open(localImagePath) - if err != nil { - return nil, nil, nil, nil, err - } - - // Mount the ext4 fs and retrieve the inode structure for the file. - mockCtx := contexttest.Context(t) - fs, d, err := filesystemType{}.NewFilesystem(mockCtx, nil, localImagePath, vfs.NewFilesystemOptions{InternalData: f.Fd()}) - if err != nil { - f.Close() - return nil, nil, nil, nil, err - } - - tearDown := func() { - if err := f.Close(); err != nil { - t.Fatalf("tearDown failed: %v", err) - } - } - return mockCtx, fs, d, tearDown, nil -} - -// TestRootDir tests that the root directory inode is correctly initialized and -// returned from setUp. -func TestRootDir(t *testing.T) { - type inodeProps struct { - Mode linux.FileMode - UID auth.KUID - GID auth.KGID - Size uint64 - InodeSize uint16 - Links uint16 - Flags disklayout.InodeFlags - } - - type rootDirTest struct { - name string - image string - wantInode inodeProps - } - - tests := []rootDirTest{ - { - name: "ext4 root dir", - image: ext4ImagePath, - wantInode: inodeProps{ - Mode: linux.ModeDirectory | 0755, - Size: 0x400, - InodeSize: 0x80, - Links: 3, - Flags: disklayout.InodeFlags{Extents: true}, - }, - }, - { - name: "ext3 root dir", - image: ext3ImagePath, - wantInode: inodeProps{ - Mode: linux.ModeDirectory | 0755, - Size: 0x400, - InodeSize: 0x80, - Links: 3, - }, - }, - { - name: "ext2 root dir", - image: ext2ImagePath, - wantInode: inodeProps{ - Mode: linux.ModeDirectory | 0755, - Size: 0x400, - InodeSize: 0x80, - Links: 3, - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - _, _, vfsd, tearDown, err := setUp(t, test.image) - if err != nil { - t.Fatalf("setUp failed: %v", err) - } - defer tearDown() - - d, ok := vfsd.Impl().(*dentry) - if !ok { - t.Fatalf("ext dentry of incorrect type: %T", vfsd.Impl()) - } - - // Offload inode contents into local structs for comparison. - gotInode := inodeProps{ - Mode: d.inode.diskInode.Mode(), - UID: d.inode.diskInode.UID(), - GID: d.inode.diskInode.GID(), - Size: d.inode.diskInode.Size(), - InodeSize: d.inode.diskInode.InodeSize(), - Links: d.inode.diskInode.LinksCount(), - Flags: d.inode.diskInode.Flags(), - } - - if diff := cmp.Diff(gotInode, test.wantInode); diff != "" { - t.Errorf("inode mismatch (-want +got):\n%s", diff) - } - }) - } -} - -// TestFilesystemInit tests that the filesystem superblock and block group -// descriptors are correctly read in and initialized. -func TestFilesystemInit(t *testing.T) { - // sb only contains the immutable properties of the superblock. - type sb struct { - InodesCount uint32 - BlocksCount uint64 - MaxMountCount uint16 - FirstDataBlock uint32 - BlockSize uint64 - BlocksPerGroup uint32 - ClusterSize uint64 - ClustersPerGroup uint32 - InodeSize uint16 - InodesPerGroup uint32 - BgDescSize uint16 - Magic uint16 - Revision disklayout.SbRevision - CompatFeatures disklayout.CompatFeatures - IncompatFeatures disklayout.IncompatFeatures - RoCompatFeatures disklayout.RoCompatFeatures - } - - // bg only contains the immutable properties of the block group descriptor. 
- type bg struct { - InodeTable uint64 - BlockBitmap uint64 - InodeBitmap uint64 - ExclusionBitmap uint64 - Flags disklayout.BGFlags - } - - type fsInitTest struct { - name string - image string - wantSb sb - wantBgs []bg - } - - tests := []fsInitTest{ - { - name: "ext4 filesystem init", - image: ext4ImagePath, - wantSb: sb{ - InodesCount: 0x10, - BlocksCount: 0x40, - MaxMountCount: 0xffff, - FirstDataBlock: 0x1, - BlockSize: 0x400, - BlocksPerGroup: 0x2000, - ClusterSize: 0x400, - ClustersPerGroup: 0x2000, - InodeSize: 0x80, - InodesPerGroup: 0x10, - BgDescSize: 0x40, - Magic: linux.EXT_SUPER_MAGIC, - Revision: disklayout.DynamicRev, - CompatFeatures: disklayout.CompatFeatures{ - ExtAttr: true, - ResizeInode: true, - DirIndex: true, - }, - IncompatFeatures: disklayout.IncompatFeatures{ - DirentFileType: true, - Extents: true, - Is64Bit: true, - FlexBg: true, - }, - RoCompatFeatures: disklayout.RoCompatFeatures{ - Sparse: true, - LargeFile: true, - HugeFile: true, - DirNlink: true, - ExtraIsize: true, - MetadataCsum: true, - }, - }, - wantBgs: []bg{ - { - InodeTable: 0x23, - BlockBitmap: 0x3, - InodeBitmap: 0x13, - Flags: disklayout.BGFlags{ - InodeZeroed: true, - }, - }, - }, - }, - { - name: "ext3 filesystem init", - image: ext3ImagePath, - wantSb: sb{ - InodesCount: 0x10, - BlocksCount: 0x40, - MaxMountCount: 0xffff, - FirstDataBlock: 0x1, - BlockSize: 0x400, - BlocksPerGroup: 0x2000, - ClusterSize: 0x400, - ClustersPerGroup: 0x2000, - InodeSize: 0x80, - InodesPerGroup: 0x10, - BgDescSize: 0x20, - Magic: linux.EXT_SUPER_MAGIC, - Revision: disklayout.DynamicRev, - CompatFeatures: disklayout.CompatFeatures{ - ExtAttr: true, - ResizeInode: true, - DirIndex: true, - }, - IncompatFeatures: disklayout.IncompatFeatures{ - DirentFileType: true, - }, - RoCompatFeatures: disklayout.RoCompatFeatures{ - Sparse: true, - LargeFile: true, - }, - }, - wantBgs: []bg{ - { - InodeTable: 0x5, - BlockBitmap: 0x3, - InodeBitmap: 0x4, - Flags: disklayout.BGFlags{ - InodeZeroed: true, - }, - }, - }, - }, - { - name: "ext2 filesystem init", - image: ext2ImagePath, - wantSb: sb{ - InodesCount: 0x10, - BlocksCount: 0x40, - MaxMountCount: 0xffff, - FirstDataBlock: 0x1, - BlockSize: 0x400, - BlocksPerGroup: 0x2000, - ClusterSize: 0x400, - ClustersPerGroup: 0x2000, - InodeSize: 0x80, - InodesPerGroup: 0x10, - BgDescSize: 0x20, - Magic: linux.EXT_SUPER_MAGIC, - Revision: disklayout.DynamicRev, - CompatFeatures: disklayout.CompatFeatures{ - ExtAttr: true, - ResizeInode: true, - DirIndex: true, - }, - IncompatFeatures: disklayout.IncompatFeatures{ - DirentFileType: true, - }, - RoCompatFeatures: disklayout.RoCompatFeatures{ - Sparse: true, - LargeFile: true, - }, - }, - wantBgs: []bg{ - { - InodeTable: 0x5, - BlockBitmap: 0x3, - InodeBitmap: 0x4, - Flags: disklayout.BGFlags{ - InodeZeroed: true, - }, - }, - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - _, vfsfs, _, tearDown, err := setUp(t, test.image) - if err != nil { - t.Fatalf("setUp failed: %v", err) - } - defer tearDown() - - fs, ok := vfsfs.Impl().(*filesystem) - if !ok { - t.Fatalf("ext filesystem of incorrect type: %T", vfsfs.Impl()) - } - - // Offload superblock and block group descriptors contents into - // local structs for comparison. 
- totalFreeInodes := uint32(0) - totalFreeBlocks := uint64(0) - gotSb := sb{ - InodesCount: fs.sb.InodesCount(), - BlocksCount: fs.sb.BlocksCount(), - MaxMountCount: fs.sb.MaxMountCount(), - FirstDataBlock: fs.sb.FirstDataBlock(), - BlockSize: fs.sb.BlockSize(), - BlocksPerGroup: fs.sb.BlocksPerGroup(), - ClusterSize: fs.sb.ClusterSize(), - ClustersPerGroup: fs.sb.ClustersPerGroup(), - InodeSize: fs.sb.InodeSize(), - InodesPerGroup: fs.sb.InodesPerGroup(), - BgDescSize: fs.sb.BgDescSize(), - Magic: fs.sb.Magic(), - Revision: fs.sb.Revision(), - CompatFeatures: fs.sb.CompatibleFeatures(), - IncompatFeatures: fs.sb.IncompatibleFeatures(), - RoCompatFeatures: fs.sb.ReadOnlyCompatibleFeatures(), - } - gotNumBgs := len(fs.bgs) - gotBgs := make([]bg, gotNumBgs) - for i := 0; i < gotNumBgs; i++ { - gotBgs[i].InodeTable = fs.bgs[i].InodeTable() - gotBgs[i].BlockBitmap = fs.bgs[i].BlockBitmap() - gotBgs[i].InodeBitmap = fs.bgs[i].InodeBitmap() - gotBgs[i].ExclusionBitmap = fs.bgs[i].ExclusionBitmap() - gotBgs[i].Flags = fs.bgs[i].Flags() - - totalFreeInodes += fs.bgs[i].FreeInodesCount() - totalFreeBlocks += uint64(fs.bgs[i].FreeBlocksCount()) - } - - if diff := cmp.Diff(gotSb, test.wantSb); diff != "" { - t.Errorf("superblock mismatch (-want +got):\n%s", diff) - } - - if diff := cmp.Diff(gotBgs, test.wantBgs); diff != "" { - t.Errorf("block group descriptors mismatch (-want +got):\n%s", diff) - } - - if diff := cmp.Diff(totalFreeInodes, fs.sb.FreeInodesCount()); diff != "" { - t.Errorf("total free inodes mismatch (-want +got):\n%s", diff) - } - - if diff := cmp.Diff(totalFreeBlocks, fs.sb.FreeBlocksCount()); diff != "" { - t.Errorf("total free blocks mismatch (-want +got):\n%s", diff) - } - }) - } -} diff --git a/pkg/sentry/fs/ext/filesystem.go b/pkg/sentry/fs/ext/filesystem.go deleted file mode 100644 index 7150e75a5..000000000 --- a/pkg/sentry/fs/ext/filesystem.go +++ /dev/null @@ -1,137 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ext - -import ( - "io" - "sync" - - "gvisor.dev/gvisor/pkg/sentry/context" - "gvisor.dev/gvisor/pkg/sentry/fs/ext/disklayout" - "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/syserror" -) - -// filesystem implements vfs.FilesystemImpl. -type filesystem struct { - // TODO(b/134676337): Remove when all methods have been implemented. - vfs.FilesystemImpl - - vfsfs vfs.Filesystem - - // mu serializes changes to the Dentry tree and the usage of the read seeker. - mu sync.Mutex - - // dev is the ReadSeeker for the underlying fs device. It is protected by mu. - // - // The ext filesystems aim to maximize locality, i.e. place all the data - // blocks of a file close together. On a spinning disk, locality reduces the - // amount of movement of the head hence speeding up IO operations. On an SSD - // there are no moving parts but locality increases the size of each transer - // request. 
Hence, having mutual exclusion on the read seeker while reading a - // file *should* help in achieving the intended performance gains. - // - // Note: This synchronization was not coupled with the ReadSeeker itself - // because we want to synchronize across read/seek operations for the - // performance gains mentioned above. Helps enforcing one-file-at-a-time IO. - dev io.ReadSeeker - - // inodeCache maps absolute inode numbers to the corresponding Inode struct. - // Inodes should be removed from this once their reference count hits 0. - // - // Protected by mu because every addition and removal from this corresponds to - // a change in the dentry tree. - inodeCache map[uint32]*inode - - // sb represents the filesystem superblock. Immutable after initialization. - sb disklayout.SuperBlock - - // bgs represents all the block group descriptors for the filesystem. - // Immutable after initialization. - bgs []disklayout.BlockGroup -} - -// Compiles only if filesystem implements vfs.FilesystemImpl. -var _ vfs.FilesystemImpl = (*filesystem)(nil) - -// getOrCreateInode gets the inode corresponding to the inode number passed in. -// It creates a new one with the given inode number if one does not exist. -// -// Preconditions: must be holding fs.mu. -func (fs *filesystem) getOrCreateInode(inodeNum uint32) (*inode, error) { - if in, ok := fs.inodeCache[inodeNum]; ok { - return in, nil - } - - in, err := newInode(fs.dev, fs.sb, fs.bgs, inodeNum) - if err != nil { - return nil, err - } - - fs.inodeCache[inodeNum] = in - return in, nil -} - -// Release implements vfs.FilesystemImpl.Release. -func (fs *filesystem) Release() { -} - -// Sync implements vfs.FilesystemImpl.Sync. -func (fs *filesystem) Sync(ctx context.Context) error { - // This is a readonly filesystem for now. - return nil -} - -// The vfs.FilesystemImpl functions below return EROFS because their respective -// man pages say that EROFS must be returned if the path resolves to a file on -// a read-only filesystem. - -// TODO(b/134676337): Implement path traversal and return EROFS only if the -// path resolves to a Dentry within ext fs. - -// MkdirAt implements vfs.FilesystemImpl.MkdirAt. -func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MkdirOptions) error { - return syserror.EROFS -} - -// MknodAt implements vfs.FilesystemImpl.MknodAt. -func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MknodOptions) error { - return syserror.EROFS -} - -// RenameAt implements vfs.FilesystemImpl.RenameAt. -func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs.VirtualDentry, opts vfs.RenameOptions) error { - return syserror.EROFS -} - -// RmdirAt implements vfs.FilesystemImpl.RmdirAt. -func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error { - return syserror.EROFS -} - -// SetStatAt implements vfs.FilesystemImpl.SetStatAt. -func (fs *filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetStatOptions) error { - return syserror.EROFS -} - -// SymlinkAt implements vfs.FilesystemImpl.SymlinkAt. -func (fs *filesystem) SymlinkAt(ctx context.Context, rp *vfs.ResolvingPath, target string) error { - return syserror.EROFS -} - -// UnlinkAt implements vfs.FilesystemImpl.UnlinkAt. 
-func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error { - return syserror.EROFS -} diff --git a/pkg/sentry/fs/ext/inode.go b/pkg/sentry/fs/ext/inode.go deleted file mode 100644 index df1ea0bda..000000000 --- a/pkg/sentry/fs/ext/inode.go +++ /dev/null @@ -1,209 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ext - -import ( - "io" - "sync/atomic" - - "gvisor.dev/gvisor/pkg/binary" - "gvisor.dev/gvisor/pkg/sentry/fs/ext/disklayout" - "gvisor.dev/gvisor/pkg/syserror" -) - -// inode represents an ext inode. -type inode struct { - // refs is a reference count. refs is accessed using atomic memory operations. - refs int64 - - // inodeNum is the inode number of this inode on disk. This is used to - // identify inodes within the ext filesystem. - inodeNum uint32 - - // diskInode gives us access to the inode struct on disk. Immutable. - diskInode disklayout.Inode - - // root is the root extent node. This lives in the 60 byte diskInode.Blocks(). - // Immutable. Nil if the inode does not use extents. - root *disklayout.ExtentNode -} - -// incRef increments the inode ref count. -func (in *inode) incRef() { - atomic.AddInt64(&in.refs, 1) -} - -// tryIncRef tries to increment the ref count. Returns true if successful. -func (in *inode) tryIncRef() bool { - for { - refs := atomic.LoadInt64(&in.refs) - if refs == 0 { - return false - } - if atomic.CompareAndSwapInt64(&in.refs, refs, refs+1) { - return true - } - } -} - -// decRef decrements the inode ref count and releases the inode resources if -// the ref count hits 0. -// -// Preconditions: Must have locked fs.mu. -func (in *inode) decRef(fs *filesystem) { - if refs := atomic.AddInt64(&in.refs, -1); refs == 0 { - delete(fs.inodeCache, in.inodeNum) - } else if refs < 0 { - panic("ext.inode.decRef() called without holding a reference") - } -} - -// newInode is the inode constructor. Reads the inode off disk. Identifies -// inodes based on the absolute inode number on disk. -// -// Preconditions: Must hold the mutex of the filesystem containing dev. -func newInode(dev io.ReadSeeker, sb disklayout.SuperBlock, bgs []disklayout.BlockGroup, inodeNum uint32) (*inode, error) { - if inodeNum == 0 { - panic("inode number 0 on ext filesystems is not possible") - } - - in := &inode{refs: 1, inodeNum: inodeNum} - inodeRecordSize := sb.InodeSize() - if inodeRecordSize == disklayout.OldInodeSize { - in.diskInode = &disklayout.InodeOld{} - } else { - in.diskInode = &disklayout.InodeNew{} - } - - // Calculate where the inode is actually placed. - inodesPerGrp := sb.InodesPerGroup() - blkSize := sb.BlockSize() - inodeTableOff := bgs[getBGNum(inodeNum, inodesPerGrp)].InodeTable() * blkSize - inodeOff := inodeTableOff + uint64(uint32(inodeRecordSize)*getBGOff(inodeNum, inodesPerGrp)) - - // Read it from disk and figure out which type of inode this is. 
- if err := readFromDisk(dev, int64(inodeOff), in.diskInode); err != nil { - return nil, err - } - - if in.diskInode.Flags().Extents { - in.buildExtTree(dev, blkSize) - } - - return in, nil -} - -// getBGNum returns the block group number that a given inode belongs to. -func getBGNum(inodeNum uint32, inodesPerGrp uint32) uint32 { - return (inodeNum - 1) / inodesPerGrp -} - -// getBGOff returns the offset at which the given inode lives in the block -// group's inode table, i.e. the index of the inode in the inode table. -func getBGOff(inodeNum uint32, inodesPerGrp uint32) uint32 { - return (inodeNum - 1) % inodesPerGrp -} - -// buildExtTree builds the extent tree by reading it from disk by doing -// running a simple DFS. It first reads the root node from the inode struct in -// memory. Then it recursively builds the rest of the tree by reading it off -// disk. -// -// Preconditions: -// - Must hold the mutex of the filesystem containing dev. -// - Inode flag InExtents must be set. -func (in *inode) buildExtTree(dev io.ReadSeeker, blkSize uint64) error { - rootNodeData := in.diskInode.Data() - - var rootHeader disklayout.ExtentHeader - binary.Unmarshal(rootNodeData[:disklayout.ExtentStructsSize], binary.LittleEndian, &rootHeader) - - // Root node can not have more than 4 entries: 60 bytes = 1 header + 4 entries. - if rootHeader.NumEntries > 4 { - // read(2) specifies that EINVAL should be returned if the file is unsuitable - // for reading. - return syserror.EINVAL - } - - rootEntries := make([]disklayout.ExtentEntryPair, rootHeader.NumEntries) - for i, off := uint16(0), disklayout.ExtentStructsSize; i < rootHeader.NumEntries; i, off = i+1, off+disklayout.ExtentStructsSize { - var curEntry disklayout.ExtentEntry - if rootHeader.Height == 0 { - // Leaf node. - curEntry = &disklayout.Extent{} - } else { - // Internal node. - curEntry = &disklayout.ExtentIdx{} - } - binary.Unmarshal(rootNodeData[off:off+disklayout.ExtentStructsSize], binary.LittleEndian, curEntry) - rootEntries[i].Entry = curEntry - } - - // If this node is internal, perform DFS. - if rootHeader.Height > 0 { - for i := uint16(0); i < rootHeader.NumEntries; i++ { - var err error - if rootEntries[i].Node, err = buildExtTreeFromDisk(dev, rootEntries[i].Entry, blkSize); err != nil { - return err - } - } - } - - in.root = &disklayout.ExtentNode{rootHeader, rootEntries} - return nil -} - -// buildExtTreeFromDisk reads the extent tree nodes from disk and recursively -// builds the tree. Performs a simple DFS. It returns the ExtentNode pointed to -// by the ExtentEntry. -// -// Preconditions: Must hold the mutex of the filesystem containing dev. -func buildExtTreeFromDisk(dev io.ReadSeeker, entry disklayout.ExtentEntry, blkSize uint64) (*disklayout.ExtentNode, error) { - var header disklayout.ExtentHeader - off := entry.PhysicalBlock() * blkSize - if err := readFromDisk(dev, int64(off), &header); err != nil { - return nil, err - } - - entries := make([]disklayout.ExtentEntryPair, header.NumEntries) - for i, off := uint16(0), off+disklayout.ExtentStructsSize; i < header.NumEntries; i, off = i+1, off+disklayout.ExtentStructsSize { - var curEntry disklayout.ExtentEntry - if header.Height == 0 { - // Leaf node. - curEntry = &disklayout.Extent{} - } else { - // Internal node. - curEntry = &disklayout.ExtentIdx{} - } - - if err := readFromDisk(dev, int64(off), curEntry); err != nil { - return nil, err - } - entries[i].Entry = curEntry - } - - // If this node is internal, perform DFS. 
- if header.Height > 0 { - for i := uint16(0); i < header.NumEntries; i++ { - var err error - entries[i].Node, err = buildExtTreeFromDisk(dev, entries[i].Entry, blkSize) - if err != nil { - return nil, err - } - } - } - - return &disklayout.ExtentNode{header, entries}, nil -} diff --git a/pkg/sentry/fs/fdpipe/pipe.go b/pkg/sentry/fs/fdpipe/pipe.go index 5a0a67eab..669ffcb75 100644 --- a/pkg/sentry/fs/fdpipe/pipe.go +++ b/pkg/sentry/fs/fdpipe/pipe.go @@ -87,7 +87,7 @@ func (p *pipeOperations) init() error { log.Warningf("pipe: cannot stat fd %d: %v", p.file.FD(), err) return syscall.EINVAL } - if s.Mode&syscall.S_IFIFO != syscall.S_IFIFO { + if (s.Mode & syscall.S_IFMT) != syscall.S_IFIFO { log.Warningf("pipe: cannot load fd %d as pipe, file type: %o", p.file.FD(), s.Mode) return syscall.EINVAL } diff --git a/pkg/sentry/fs/fsutil/inode_cached.go b/pkg/sentry/fs/fsutil/inode_cached.go index ed62049a9..20cb9a367 100644 --- a/pkg/sentry/fs/fsutil/inode_cached.go +++ b/pkg/sentry/fs/fsutil/inode_cached.go @@ -66,10 +66,8 @@ type CachingInodeOperations struct { // mfp is used to allocate memory that caches backingFile's contents. mfp pgalloc.MemoryFileProvider - // forcePageCache indicates the sentry page cache should be used regardless - // of whether the platform supports host mapped I/O or not. This must not be - // modified after inode creation. - forcePageCache bool + // opts contains options. opts is immutable. + opts CachingInodeOperationsOptions attrMu sync.Mutex `state:"nosave"` @@ -116,6 +114,20 @@ type CachingInodeOperations struct { refs frameRefSet } +// CachingInodeOperationsOptions configures a CachingInodeOperations. +// +// +stateify savable +type CachingInodeOperationsOptions struct { + // If ForcePageCache is true, use the sentry page cache even if a host file + // descriptor is available. + ForcePageCache bool + + // If LimitHostFDTranslation is true, apply maxFillRange() constraints to + // host file descriptor mappings returned by + // CachingInodeOperations.Translate(). + LimitHostFDTranslation bool +} + // CachedFileObject is a file that may require caching. type CachedFileObject interface { // ReadToBlocksAt reads up to dsts.NumBytes() bytes from the file to dsts, @@ -159,7 +171,7 @@ type CachedFileObject interface { // NewCachingInodeOperations returns a new CachingInodeOperations backed by // a CachedFileObject and its initial unstable attributes. -func NewCachingInodeOperations(ctx context.Context, backingFile CachedFileObject, uattr fs.UnstableAttr, forcePageCache bool) *CachingInodeOperations { +func NewCachingInodeOperations(ctx context.Context, backingFile CachedFileObject, uattr fs.UnstableAttr, opts CachingInodeOperationsOptions) *CachingInodeOperations { mfp := pgalloc.MemoryFileProviderFromContext(ctx) if mfp == nil { panic(fmt.Sprintf("context.Context %T lacks non-nil value for key %T", ctx, pgalloc.CtxMemoryFileProvider)) @@ -167,7 +179,7 @@ func NewCachingInodeOperations(ctx context.Context, backingFile CachedFileObject return &CachingInodeOperations{ backingFile: backingFile, mfp: mfp, - forcePageCache: forcePageCache, + opts: opts, attr: uattr, hostFileMapper: NewHostFileMapper(), } @@ -568,21 +580,30 @@ type inodeReadWriter struct { // ReadToBlocks implements safemem.Reader.ReadToBlocks. func (rw *inodeReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) { + mem := rw.c.mfp.MemoryFile() + fillCache := !rw.c.useHostPageCache() && mem.ShouldCacheEvictable() + // Hot path. Avoid defers. 
- rw.c.dataMu.RLock() + var unlock func() + if fillCache { + rw.c.dataMu.Lock() + unlock = rw.c.dataMu.Unlock + } else { + rw.c.dataMu.RLock() + unlock = rw.c.dataMu.RUnlock + } // Compute the range to read. if rw.offset >= rw.c.attr.Size { - rw.c.dataMu.RUnlock() + unlock() return 0, io.EOF } end := fs.ReadEndOffset(rw.offset, int64(dsts.NumBytes()), rw.c.attr.Size) if end == rw.offset { // dsts.NumBytes() == 0? - rw.c.dataMu.RUnlock() + unlock() return 0, nil } - mem := rw.c.mfp.MemoryFile() var done uint64 seg, gap := rw.c.cache.Find(uint64(rw.offset)) for rw.offset < end { @@ -592,7 +613,7 @@ func (rw *inodeReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) { // Get internal mappings from the cache. ims, err := mem.MapInternal(seg.FileRangeOf(seg.Range().Intersect(mr)), usermem.Read) if err != nil { - rw.c.dataMu.RUnlock() + unlock() return done, err } @@ -602,7 +623,7 @@ func (rw *inodeReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) { rw.offset += int64(n) dsts = dsts.DropFirst64(n) if err != nil { - rw.c.dataMu.RUnlock() + unlock() return done, err } @@ -610,27 +631,48 @@ func (rw *inodeReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) { seg, gap = seg.NextNonEmpty() case gap.Ok(): - // Read directly from the backing file. - gapmr := gap.Range().Intersect(mr) - dst := dsts.TakeFirst64(gapmr.Length()) - n, err := rw.c.backingFile.ReadToBlocksAt(rw.ctx, dst, gapmr.Start) - done += n - rw.offset += int64(n) - dsts = dsts.DropFirst64(n) - // Partial reads are fine. But we must stop reading. - if n != dst.NumBytes() || err != nil { - rw.c.dataMu.RUnlock() - return done, err + gapMR := gap.Range().Intersect(mr) + if fillCache { + // Read into the cache, then re-enter the loop to read from the + // cache. + reqMR := memmap.MappableRange{ + Start: uint64(usermem.Addr(gapMR.Start).RoundDown()), + End: fs.OffsetPageEnd(int64(gapMR.End)), + } + optMR := gap.Range() + err := rw.c.cache.Fill(rw.ctx, reqMR, maxFillRange(reqMR, optMR), mem, usage.PageCache, rw.c.backingFile.ReadToBlocksAt) + mem.MarkEvictable(rw.c, pgalloc.EvictableRange{optMR.Start, optMR.End}) + seg, gap = rw.c.cache.Find(uint64(rw.offset)) + if !seg.Ok() { + unlock() + return done, err + } + // err might have occurred in part of gap.Range() outside + // gapMR. Forget about it for now; if the error matters and + // persists, we'll run into it again in a later iteration of + // this loop. + } else { + // Read directly from the backing file. + dst := dsts.TakeFirst64(gapMR.Length()) + n, err := rw.c.backingFile.ReadToBlocksAt(rw.ctx, dst, gapMR.Start) + done += n + rw.offset += int64(n) + dsts = dsts.DropFirst64(n) + // Partial reads are fine. But we must stop reading. + if n != dst.NumBytes() || err != nil { + unlock() + return done, err + } + + // Continue. + seg, gap = gap.NextSegment(), FileRangeGapIterator{} } - // Continue. - seg, gap = gap.NextSegment(), FileRangeGapIterator{} - default: break } } - rw.c.dataMu.RUnlock() + unlock() return done, nil } @@ -700,7 +742,10 @@ func (rw *inodeReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error seg, gap = seg.NextNonEmpty() case gap.Ok() && gap.Start() < mr.End: - // Write directly to the backing file. + // Write directly to the backing file. At present, we never fill + // the cache when writing, since doing so can convert small writes + // into inefficient read-modify-write cycles, and we have no + // mechanism for detecting or avoiding this. 
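+			// (For example, caching a one-byte write would first require +			// reading in the rest of the surrounding page from the backing +			// file.)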
gapmr := gap.Range().Intersect(mr) src := srcs.TakeFirst64(gapmr.Length()) n, err := rw.c.backingFile.WriteFromBlocksAt(rw.ctx, src, gapmr.Start) @@ -730,7 +775,7 @@ func (rw *inodeReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error // and memory mappings, and false if c.cache may contain data cached from // c.backingFile. func (c *CachingInodeOperations) useHostPageCache() bool { - return !c.forcePageCache && c.backingFile.FD() >= 0 + return !c.opts.ForcePageCache && c.backingFile.FD() >= 0 } // AddMapping implements memmap.Mappable.AddMapping. @@ -802,11 +847,15 @@ func (c *CachingInodeOperations) CopyMapping(ctx context.Context, ms memmap.Mapp func (c *CachingInodeOperations) Translate(ctx context.Context, required, optional memmap.MappableRange, at usermem.AccessType) ([]memmap.Translation, error) { // Hot path. Avoid defer. if c.useHostPageCache() { + mr := optional + if c.opts.LimitHostFDTranslation { + mr = maxFillRange(required, optional) + } return []memmap.Translation{ { - Source: optional, + Source: mr, File: c, - Offset: optional.Start, + Offset: mr.Start, Perms: usermem.AnyAccess, }, }, nil diff --git a/pkg/sentry/fs/fsutil/inode_cached_test.go b/pkg/sentry/fs/fsutil/inode_cached_test.go index dc19255ed..eb5730c35 100644 --- a/pkg/sentry/fs/fsutil/inode_cached_test.go +++ b/pkg/sentry/fs/fsutil/inode_cached_test.go @@ -61,7 +61,7 @@ func TestSetPermissions(t *testing.T) { uattr := fs.WithCurrentTime(ctx, fs.UnstableAttr{ Perms: fs.FilePermsFromMode(0444), }) - iops := NewCachingInodeOperations(ctx, noopBackingFile{}, uattr, false /*forcePageCache*/) + iops := NewCachingInodeOperations(ctx, noopBackingFile{}, uattr, CachingInodeOperationsOptions{}) defer iops.Release() perms := fs.FilePermsFromMode(0777) @@ -150,7 +150,7 @@ func TestSetTimestamps(t *testing.T) { ModificationTime: epoch, StatusChangeTime: epoch, } - iops := NewCachingInodeOperations(ctx, noopBackingFile{}, uattr, false /*forcePageCache*/) + iops := NewCachingInodeOperations(ctx, noopBackingFile{}, uattr, CachingInodeOperationsOptions{}) defer iops.Release() if err := iops.SetTimestamps(ctx, nil, test.ts); err != nil { @@ -188,7 +188,7 @@ func TestTruncate(t *testing.T) { uattr := fs.UnstableAttr{ Size: 0, } - iops := NewCachingInodeOperations(ctx, noopBackingFile{}, uattr, false /*forcePageCache*/) + iops := NewCachingInodeOperations(ctx, noopBackingFile{}, uattr, CachingInodeOperationsOptions{}) defer iops.Release() if err := iops.Truncate(ctx, nil, uattr.Size); err != nil { @@ -280,7 +280,7 @@ func TestRead(t *testing.T) { uattr := fs.UnstableAttr{ Size: int64(len(buf)), } - iops := NewCachingInodeOperations(ctx, newSliceBackingFile(buf), uattr, false /*forcePageCache*/) + iops := NewCachingInodeOperations(ctx, newSliceBackingFile(buf), uattr, CachingInodeOperationsOptions{}) defer iops.Release() // Expect the cache to be initially empty. @@ -336,7 +336,7 @@ func TestWrite(t *testing.T) { uattr := fs.UnstableAttr{ Size: int64(len(buf)), } - iops := NewCachingInodeOperations(ctx, newSliceBackingFile(buf), uattr, false /*forcePageCache*/) + iops := NewCachingInodeOperations(ctx, newSliceBackingFile(buf), uattr, CachingInodeOperationsOptions{}) defer iops.Release() // Expect the cache to be initially empty. diff --git a/pkg/sentry/fs/gofer/fs.go b/pkg/sentry/fs/gofer/fs.go index 69999dc28..8f8ab5d29 100644 --- a/pkg/sentry/fs/gofer/fs.go +++ b/pkg/sentry/fs/gofer/fs.go @@ -54,6 +54,10 @@ const ( // sandbox using files backed by the gofer. 
If set to false, unix sockets // cannot be bound to gofer files without an overlay on top. privateUnixSocketKey = "privateunixsocket" + + // If present, sets CachingInodeOperationsOptions.LimitHostFDTranslation to + // true. + limitHostFDTranslationKey = "limit_host_fd_translation" ) // defaultAname is the default attach name. @@ -134,12 +138,13 @@ func (f *filesystem) Mount(ctx context.Context, device string, flags fs.MountSou // opts are parsed 9p mount options. type opts struct { - fd int - aname string - policy cachePolicy - msize uint32 - version string - privateunixsocket bool + fd int + aname string + policy cachePolicy + msize uint32 + version string + privateunixsocket bool + limitHostFDTranslation bool } // options parses mount(2) data into structured options. @@ -237,6 +242,11 @@ func options(data string) (opts, error) { delete(options, privateUnixSocketKey) } + if _, ok := options[limitHostFDTranslationKey]; ok { + o.limitHostFDTranslation = true + delete(options, limitHostFDTranslationKey) + } + // Fail to attach if the caller wanted us to do something that we // don't support. if len(options) > 0 { diff --git a/pkg/sentry/fs/gofer/session.go b/pkg/sentry/fs/gofer/session.go index 69d08a627..50da865c1 100644 --- a/pkg/sentry/fs/gofer/session.go +++ b/pkg/sentry/fs/gofer/session.go @@ -117,6 +117,11 @@ type session struct { // Flags provided to the mount. superBlockFlags fs.MountSourceFlags `state:"wait"` + // limitHostFDTranslation is the value used for + // CachingInodeOperationsOptions.LimitHostFDTranslation for all + // CachingInodeOperations created by the session. + limitHostFDTranslation bool + // connID is a unique identifier for the session connection. connID string `state:"wait"` @@ -218,8 +223,11 @@ func newInodeOperations(ctx context.Context, s *session, file contextFile, qid p uattr := unstable(ctx, valid, attr, s.mounter, s.client) return sattr, &inodeOperations{ - fileState: fileState, - cachingInodeOps: fsutil.NewCachingInodeOperations(ctx, fileState, uattr, s.superBlockFlags.ForcePageCache), + fileState: fileState, + cachingInodeOps: fsutil.NewCachingInodeOperations(ctx, fileState, uattr, fsutil.CachingInodeOperationsOptions{ + ForcePageCache: s.superBlockFlags.ForcePageCache, + LimitHostFDTranslation: s.limitHostFDTranslation, + }), } } @@ -242,13 +250,14 @@ func Root(ctx context.Context, dev string, filesystem fs.Filesystem, superBlockF // Construct the session. s := session{ - connID: dev, - msize: o.msize, - version: o.version, - cachePolicy: o.policy, - aname: o.aname, - superBlockFlags: superBlockFlags, - mounter: mounter, + connID: dev, + msize: o.msize, + version: o.version, + cachePolicy: o.policy, + aname: o.aname, + superBlockFlags: superBlockFlags, + limitHostFDTranslation: o.limitHostFDTranslation, + mounter: mounter, } s.EnableLeakCheck("gofer.session") diff --git a/pkg/sentry/fs/host/inode.go b/pkg/sentry/fs/host/inode.go index 679d8321a..894ab01f0 100644 --- a/pkg/sentry/fs/host/inode.go +++ b/pkg/sentry/fs/host/inode.go @@ -200,8 +200,10 @@ func newInode(ctx context.Context, msrc *fs.MountSource, fd int, saveable bool, // Build the fs.InodeOperations. 
uattr := unstableAttr(msrc.MountSourceOperations.(*superOperations), &s) iops := &inodeOperations{ - fileState: fileState, - cachingInodeOps: fsutil.NewCachingInodeOperations(ctx, fileState, uattr, msrc.Flags.ForcePageCache), + fileState: fileState, + cachingInodeOps: fsutil.NewCachingInodeOperations(ctx, fileState, uattr, fsutil.CachingInodeOperationsOptions{ + ForcePageCache: msrc.Flags.ForcePageCache, + }), } // Return the fs.Inode. diff --git a/pkg/sentry/fs/host/socket.go b/pkg/sentry/fs/host/socket.go index 44c4ee5f2..2392787cb 100644 --- a/pkg/sentry/fs/host/socket.go +++ b/pkg/sentry/fs/host/socket.go @@ -65,7 +65,7 @@ type ConnectedEndpoint struct { // GetSockOpt and message splitting/rejection in SendMsg, but do not // prevent lots of small messages from filling the real send buffer // size on the host. - sndbuf int `state:"nosave"` + sndbuf int64 `state:"nosave"` // mu protects the fields below. mu sync.RWMutex `state:"nosave"` @@ -107,7 +107,7 @@ func (c *ConnectedEndpoint) init() *syserr.Error { } c.stype = linux.SockType(stype) - c.sndbuf = sndbuf + c.sndbuf = int64(sndbuf) return nil } @@ -202,7 +202,7 @@ func newSocket(ctx context.Context, orgfd int, saveable bool) (*fs.File, error) } // Send implements transport.ConnectedEndpoint.Send. -func (c *ConnectedEndpoint) Send(data [][]byte, controlMessages transport.ControlMessages, from tcpip.FullAddress) (uintptr, bool, *syserr.Error) { +func (c *ConnectedEndpoint) Send(data [][]byte, controlMessages transport.ControlMessages, from tcpip.FullAddress) (int64, bool, *syserr.Error) { c.mu.RLock() defer c.mu.RUnlock() @@ -279,7 +279,7 @@ func (c *ConnectedEndpoint) EventUpdate() { } // Recv implements transport.Receiver.Recv. -func (c *ConnectedEndpoint) Recv(data [][]byte, creds bool, numRights uintptr, peek bool) (uintptr, uintptr, transport.ControlMessages, bool, tcpip.FullAddress, bool, *syserr.Error) { +func (c *ConnectedEndpoint) Recv(data [][]byte, creds bool, numRights int, peek bool) (int64, int64, transport.ControlMessages, bool, tcpip.FullAddress, bool, *syserr.Error) { c.mu.RLock() defer c.mu.RUnlock() diff --git a/pkg/sentry/fs/host/socket_iovec.go b/pkg/sentry/fs/host/socket_iovec.go index 05d7c79ad..af6955675 100644 --- a/pkg/sentry/fs/host/socket_iovec.go +++ b/pkg/sentry/fs/host/socket_iovec.go @@ -55,19 +55,19 @@ func copyFromMulti(dst []byte, src [][]byte) { // // If intermediate != nil, iovecs references intermediate rather than bufs and // the caller must copy to/from bufs as necessary. 
-func buildIovec(bufs [][]byte, maxlen int, truncate bool) (length uintptr, iovecs []syscall.Iovec, intermediate []byte, err error) { +func buildIovec(bufs [][]byte, maxlen int64, truncate bool) (length int64, iovecs []syscall.Iovec, intermediate []byte, err error) { var iovsRequired int for _, b := range bufs { - length += uintptr(len(b)) + length += int64(len(b)) if len(b) > 0 { iovsRequired++ } } stopLen := length - if length > uintptr(maxlen) { + if length > maxlen { if truncate { - stopLen = uintptr(maxlen) + stopLen = maxlen err = syserror.EAGAIN } else { return 0, nil, nil, syserror.EMSGSIZE @@ -85,7 +85,7 @@ func buildIovec(bufs [][]byte, maxlen int, truncate bool) (length uintptr, iovec }}, b, err } - var total uintptr + var total int64 iovecs = make([]syscall.Iovec, 0, iovsRequired) for i := range bufs { l := len(bufs[i]) @@ -93,9 +93,9 @@ func buildIovec(bufs [][]byte, maxlen int, truncate bool) (length uintptr, iovec continue } - stop := l - if total+uintptr(stop) > stopLen { - stop = int(stopLen - total) + stop := int64(l) + if total+stop > stopLen { + stop = stopLen - total } iovecs = append(iovecs, syscall.Iovec{ @@ -103,7 +103,7 @@ func buildIovec(bufs [][]byte, maxlen int, truncate bool) (length uintptr, iovec Len: uint64(stop), }) - total += uintptr(stop) + total += stop if total >= stopLen { break } diff --git a/pkg/sentry/fs/host/socket_unsafe.go b/pkg/sentry/fs/host/socket_unsafe.go index e57be0506..f3bbed7ea 100644 --- a/pkg/sentry/fs/host/socket_unsafe.go +++ b/pkg/sentry/fs/host/socket_unsafe.go @@ -23,7 +23,7 @@ import ( // // If the total length of bufs is > maxlen, fdReadVec will do a partial read // and err will indicate why the message was truncated. -func fdReadVec(fd int, bufs [][]byte, control []byte, peek bool, maxlen int) (readLen uintptr, msgLen uintptr, controlLen uint64, controlTrunc bool, err error) { +func fdReadVec(fd int, bufs [][]byte, control []byte, peek bool, maxlen int64) (readLen int64, msgLen int64, controlLen uint64, controlTrunc bool, err error) { flags := uintptr(syscall.MSG_DONTWAIT | syscall.MSG_TRUNC) if peek { flags |= syscall.MSG_PEEK @@ -48,11 +48,12 @@ func fdReadVec(fd int, bufs [][]byte, control []byte, peek bool, maxlen int) (re msg.Iovlen = uint64(len(iovecs)) } - n, _, e := syscall.RawSyscall(syscall.SYS_RECVMSG, uintptr(fd), uintptr(unsafe.Pointer(&msg)), flags) + rawN, _, e := syscall.RawSyscall(syscall.SYS_RECVMSG, uintptr(fd), uintptr(unsafe.Pointer(&msg)), flags) if e != 0 { // N.B. prioritize the syscall error over the buildIovec error. return 0, 0, 0, false, e } + n := int64(rawN) // Copy data back to bufs. if intermediate != nil { @@ -72,7 +73,7 @@ func fdReadVec(fd int, bufs [][]byte, control []byte, peek bool, maxlen int) (re // // If the total length of bufs is > maxlen && truncate, fdWriteVec will do a // partial write and err will indicate why the message was truncated. -func fdWriteVec(fd int, bufs [][]byte, maxlen int, truncate bool) (uintptr, uintptr, error) { +func fdWriteVec(fd int, bufs [][]byte, maxlen int64, truncate bool) (int64, int64, error) { length, iovecs, intermediate, err := buildIovec(bufs, maxlen, truncate) if err != nil && len(iovecs) == 0 { // No partial write to do, return error immediately. 
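The `int64` conversions above preserve buildIovec's truncation contract. As a rough editorial sketch (not part of the change; helper names are hypothetical): the returned length always reflects the full size of `bufs`, while the iovecs cover only `min(length, maxlen)` bytes when `truncate` is set, with `EAGAIN` signalling the partial transfer.

```go
package main

import "fmt"

// truncatedLengths mirrors buildIovec's length bookkeeping: total is the sum
// of all buffer lengths, covered is how many bytes the iovecs would address.
func truncatedLengths(bufLens []int, maxlen int64, truncate bool) (total, covered int64) {
	for _, l := range bufLens {
		total += int64(l)
	}
	covered = total
	if total > maxlen && truncate {
		covered = maxlen // partial transfer; buildIovec also returns EAGAIN here
	}
	return total, covered
}

func main() {
	total, covered := truncatedLengths([]int{4096, 4096}, 6000, true)
	fmt.Println(total, covered) // prints: 8192 6000
}
```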
@@ -96,5 +97,5 @@ func fdWriteVec(fd int, bufs [][]byte, maxlen int, truncate bool) (uintptr, uint return 0, length, e } - return n, length, err + return int64(n), length, err } diff --git a/pkg/sentry/fs/mounts.go b/pkg/sentry/fs/mounts.go index 693ffc760..ac0398bd9 100644 --- a/pkg/sentry/fs/mounts.go +++ b/pkg/sentry/fs/mounts.go @@ -171,8 +171,6 @@ type MountNamespace struct { // NewMountNamespace returns a new MountNamespace, with the provided node at the // root, and the given cache size. A root must always be provided. func NewMountNamespace(ctx context.Context, root *Inode) (*MountNamespace, error) { - creds := auth.CredentialsFromContext(ctx) - // Set the root dirent and id on the root mount. The reference returned from // NewDirent will be donated to the MountNamespace constructed below. d := NewDirent(ctx, root, "/") @@ -181,6 +179,7 @@ func NewMountNamespace(ctx context.Context, root *Inode) (*MountNamespace, error d: newRootMount(1, d), } + creds := auth.CredentialsFromContext(ctx) mns := MountNamespace{ userns: creds.UserNamespace, root: d, @@ -219,6 +218,13 @@ func (mns *MountNamespace) flushMountSourceRefsLocked() { } } + if mns.root == nil { + // No root? This MountSource must have already been destroyed. + // This can happen when a Save is triggered while a process is + // exiting. There is nothing to flush. + return + } + // Flush root's MountSource references. mns.root.Inode.MountSource.FlushDirentRefs() } @@ -249,6 +255,10 @@ func (mns *MountNamespace) destroy() { // Drop reference on the root. mns.root.DecRef() + // Ensure that root cannot be accessed via this MountNamespace any + // more. + mns.root = nil + // Wait for asynchronous work (queued by dropping Dirent references // above) to complete before destroying this MountNamespace. AsyncBarrier() @@ -678,7 +688,7 @@ func (mns *MountNamespace) ResolveExecutablePath(ctx context.Context, wd, name s return "", syserror.ENOENT } -// GetPath returns the PATH as a slice of strings given the environemnt +// GetPath returns the PATH as a slice of strings given the environment // variables. func GetPath(env []string) []string { const prefix = "PATH=" diff --git a/pkg/sentry/fs/proc/net.go b/pkg/sentry/fs/proc/net.go index 6b839685b..9adb23608 100644 --- a/pkg/sentry/fs/proc/net.go +++ b/pkg/sentry/fs/proc/net.go @@ -348,7 +348,7 @@ func (n *netTCP) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]se // Field: local_adddress. var localAddr linux.SockAddrInet if local, _, err := sops.GetSockName(t); err == nil { - localAddr = local.(linux.SockAddrInet) + localAddr = *local.(*linux.SockAddrInet) } binary.LittleEndian.PutUint16(portBuf, localAddr.Port) fmt.Fprintf(&buf, "%08X:%04X ", @@ -358,7 +358,7 @@ func (n *netTCP) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]se // Field: rem_address. 
var remoteAddr linux.SockAddrInet if remote, _, err := sops.GetPeerName(t); err == nil { - remoteAddr = remote.(linux.SockAddrInet) + remoteAddr = *remote.(*linux.SockAddrInet) } binary.LittleEndian.PutUint16(portBuf, remoteAddr.Port) fmt.Fprintf(&buf, "%08X:%04X ", diff --git a/pkg/sentry/fs/ramfs/dir.go b/pkg/sentry/fs/ramfs/dir.go index f3e984c24..78e082b8e 100644 --- a/pkg/sentry/fs/ramfs/dir.go +++ b/pkg/sentry/fs/ramfs/dir.go @@ -53,7 +53,6 @@ type Dir struct { fsutil.InodeGenericChecker `state:"nosave"` fsutil.InodeIsDirAllocate `state:"nosave"` fsutil.InodeIsDirTruncate `state:"nosave"` - fsutil.InodeNoopRelease `state:"nosave"` fsutil.InodeNoopWriteOut `state:"nosave"` fsutil.InodeNotMappable `state:"nosave"` fsutil.InodeNotSocket `state:"nosave"` @@ -84,7 +83,8 @@ type Dir struct { var _ fs.InodeOperations = (*Dir)(nil) -// NewDir returns a new Dir with the given contents and attributes. +// NewDir returns a new Dir with the given contents and attributes. A reference +// on each fs.Inode in the `contents` map will be donated to this Dir. func NewDir(ctx context.Context, contents map[string]*fs.Inode, owner fs.FileOwner, perms fs.FilePermissions) *Dir { d := &Dir{ InodeSimpleAttributes: fsutil.NewInodeSimpleAttributes(ctx, owner, perms, linux.RAMFS_MAGIC), @@ -138,7 +138,7 @@ func (d *Dir) addChildLocked(ctx context.Context, name string, inode *fs.Inode) d.NotifyModificationAndStatusChange(ctx) } -// AddChild adds a child to this dir. +// AddChild adds a child to this dir, inheriting its reference. func (d *Dir) AddChild(ctx context.Context, name string, inode *fs.Inode) { d.mu.Lock() defer d.mu.Unlock() @@ -172,7 +172,9 @@ func (d *Dir) Children() ([]string, map[string]fs.DentAttr) { return namesCopy, entriesCopy } -// removeChildLocked attempts to remove an entry from this directory. +// removeChildLocked attempts to remove an entry from this directory. It +// returns the removed fs.Inode along with its reference, which callers are +// responsible for decrementing. func (d *Dir) removeChildLocked(ctx context.Context, name string) (*fs.Inode, error) { inode, ok := d.children[name] if !ok { @@ -253,7 +255,8 @@ func (d *Dir) RemoveDirectory(ctx context.Context, _ *fs.Inode, name string) err return nil } -// Lookup loads an inode at p into a Dirent. +// Lookup loads an inode at p into a Dirent. It returns the fs.Dirent along +// with a reference. func (d *Dir) Lookup(ctx context.Context, _ *fs.Inode, p string) (*fs.Dirent, error) { if len(p) > linux.NAME_MAX { return nil, syserror.ENAMETOOLONG @@ -408,6 +411,16 @@ func (*Dir) Rename(ctx context.Context, inode *fs.Inode, oldParent *fs.Inode, ol return Rename(ctx, oldParent.InodeOperations, oldName, newParent.InodeOperations, newName, replacement) } +// Release implements fs.InodeOperation.Release. +func (d *Dir) Release(_ context.Context) { + // Drop references on all children. + d.mu.Lock() + for _, i := range d.children { + i.DecRef() + } + d.mu.Unlock() +} + // dirFileOperations implements fs.FileOperations for a ramfs directory. 
// // +stateify savable diff --git a/pkg/sentry/fs/tmpfs/tmpfs.go b/pkg/sentry/fs/tmpfs/tmpfs.go index 0f4497cd6..159fb7c08 100644 --- a/pkg/sentry/fs/tmpfs/tmpfs.go +++ b/pkg/sentry/fs/tmpfs/tmpfs.go @@ -56,7 +56,6 @@ func rename(ctx context.Context, oldParent *fs.Inode, oldName string, newParent type Dir struct { fsutil.InodeGenericChecker `state:"nosave"` fsutil.InodeIsDirTruncate `state:"nosave"` - fsutil.InodeNoopRelease `state:"nosave"` fsutil.InodeNoopWriteOut `state:"nosave"` fsutil.InodeNotMappable `state:"nosave"` fsutil.InodeNotSocket `state:"nosave"` @@ -252,6 +251,11 @@ func (d *Dir) Allocate(ctx context.Context, node *fs.Inode, offset, length int64 return d.ramfsDir.Allocate(ctx, node, offset, length) } +// Release implements fs.InodeOperations.Release. +func (d *Dir) Release(ctx context.Context) { + d.ramfsDir.Release(ctx) +} + // Symlink is a symlink. // // +stateify savable diff --git a/pkg/sentry/fs/tty/BUILD b/pkg/sentry/fs/tty/BUILD index 5e9327aec..291164986 100644 --- a/pkg/sentry/fs/tty/BUILD +++ b/pkg/sentry/fs/tty/BUILD @@ -23,6 +23,7 @@ go_library( "//pkg/sentry/device", "//pkg/sentry/fs", "//pkg/sentry/fs/fsutil", + "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", "//pkg/sentry/safemem", "//pkg/sentry/socket/unix/transport", diff --git a/pkg/sentry/fs/tty/dir.go b/pkg/sentry/fs/tty/dir.go index 1d128532b..2f639c823 100644 --- a/pkg/sentry/fs/tty/dir.go +++ b/pkg/sentry/fs/tty/dir.go @@ -129,6 +129,9 @@ func newDir(ctx context.Context, m *fs.MountSource) *fs.Inode { // Release implements fs.InodeOperations.Release. func (d *dirInodeOperations) Release(ctx context.Context) { + d.mu.Lock() + defer d.mu.Unlock() + d.master.DecRef() if len(d.slaves) != 0 { panic(fmt.Sprintf("devpts directory still contains active terminals: %+v", d)) diff --git a/pkg/sentry/fs/tty/master.go b/pkg/sentry/fs/tty/master.go index 92ec1ca18..19b7557d5 100644 --- a/pkg/sentry/fs/tty/master.go +++ b/pkg/sentry/fs/tty/master.go @@ -172,6 +172,19 @@ func (mf *masterFileOperations) Ioctl(ctx context.Context, _ *fs.File, io userme return 0, mf.t.ld.windowSize(ctx, io, args) case linux.TIOCSWINSZ: return 0, mf.t.ld.setWindowSize(ctx, io, args) + case linux.TIOCSCTTY: + // Make the given terminal the controlling terminal of the + // calling process. + return 0, mf.t.setControllingTTY(ctx, io, args, true /* isMaster */) + case linux.TIOCNOTTY: + // Release this process's controlling terminal. + return 0, mf.t.releaseControllingTTY(ctx, io, args, true /* isMaster */) + case linux.TIOCGPGRP: + // Get the foreground process group. + return mf.t.foregroundProcessGroup(ctx, io, args, true /* isMaster */) + case linux.TIOCSPGRP: + // Set the foreground process group. 
+ return mf.t.setForegroundProcessGroup(ctx, io, args, true /* isMaster */) default: maybeEmitUnimplementedEvent(ctx, cmd) return 0, syserror.ENOTTY @@ -185,8 +198,6 @@ func maybeEmitUnimplementedEvent(ctx context.Context, cmd uint32) { linux.TCSETS, linux.TCSETSW, linux.TCSETSF, - linux.TIOCGPGRP, - linux.TIOCSPGRP, linux.TIOCGWINSZ, linux.TIOCSWINSZ, linux.TIOCSETD, @@ -200,8 +211,6 @@ func maybeEmitUnimplementedEvent(ctx context.Context, cmd uint32) { linux.TIOCEXCL, linux.TIOCNXCL, linux.TIOCGEXCL, - linux.TIOCNOTTY, - linux.TIOCSCTTY, linux.TIOCGSID, linux.TIOCGETD, linux.TIOCVHANGUP, diff --git a/pkg/sentry/fs/tty/slave.go b/pkg/sentry/fs/tty/slave.go index e30266404..944c4ada1 100644 --- a/pkg/sentry/fs/tty/slave.go +++ b/pkg/sentry/fs/tty/slave.go @@ -152,9 +152,16 @@ func (sf *slaveFileOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem case linux.TIOCSCTTY: // Make the given terminal the controlling terminal of the // calling process. - // TODO(b/129283598): Implement once we have support for job - // control. - return 0, nil + return 0, sf.si.t.setControllingTTY(ctx, io, args, false /* isMaster */) + case linux.TIOCNOTTY: + // Release this process's controlling terminal. + return 0, sf.si.t.releaseControllingTTY(ctx, io, args, false /* isMaster */) + case linux.TIOCGPGRP: + // Get the foreground process group. + return sf.si.t.foregroundProcessGroup(ctx, io, args, false /* isMaster */) + case linux.TIOCSPGRP: + // Set the foreground process group. + return sf.si.t.setForegroundProcessGroup(ctx, io, args, false /* isMaster */) default: maybeEmitUnimplementedEvent(ctx, cmd) return 0, syserror.ENOTTY diff --git a/pkg/sentry/fs/tty/terminal.go b/pkg/sentry/fs/tty/terminal.go index b7cecb2ed..ff8138820 100644 --- a/pkg/sentry/fs/tty/terminal.go +++ b/pkg/sentry/fs/tty/terminal.go @@ -17,7 +17,10 @@ package tty import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/refs" + "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/usermem" ) // Terminal is a pseudoterminal. @@ -26,23 +29,100 @@ import ( type Terminal struct { refs.AtomicRefCount - // n is the terminal index. + // n is the terminal index. It is immutable. n uint32 - // d is the containing directory. + // d is the containing directory. It is immutable. d *dirInodeOperations - // ld is the line discipline of the terminal. + // ld is the line discipline of the terminal. It is immutable. ld *lineDiscipline + + // masterKTTY contains the controlling process of the master end of + // this terminal. This field is immutable. + masterKTTY *kernel.TTY + + // slaveKTTY contains the controlling process of the slave end of this + // terminal. This field is immutable. + slaveKTTY *kernel.TTY } func newTerminal(ctx context.Context, d *dirInodeOperations, n uint32) *Terminal { termios := linux.DefaultSlaveTermios t := Terminal{ - d: d, - n: n, - ld: newLineDiscipline(termios), + d: d, + n: n, + ld: newLineDiscipline(termios), + masterKTTY: &kernel.TTY{}, + slaveKTTY: &kernel.TTY{}, } t.EnableLeakCheck("tty.Terminal") return &t } + +// setControllingTTY makes tm the controlling terminal of the calling thread +// group. 
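+//
+// The integer ioctl argument (args[2] below) follows Linux's TIOCSCTTY
+// semantics: it is the "steal" flag, which lets a suitably privileged caller
+// take a terminal that is already the controlling terminal of another
+// session.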
+func (tm *Terminal) setControllingTTY(ctx context.Context, io usermem.IO, args arch.SyscallArguments, isMaster bool) error {
+	task := kernel.TaskFromContext(ctx)
+	if task == nil {
+		panic("setControllingTTY must be called from a task context")
+	}
+
+	return task.ThreadGroup().SetControllingTTY(tm.tty(isMaster), args[2].Int())
+}
+
+// releaseControllingTTY removes tm as the controlling terminal of the calling
+// thread group.
+func (tm *Terminal) releaseControllingTTY(ctx context.Context, io usermem.IO, args arch.SyscallArguments, isMaster bool) error {
+	task := kernel.TaskFromContext(ctx)
+	if task == nil {
+		panic("releaseControllingTTY must be called from a task context")
+	}
+
+	return task.ThreadGroup().ReleaseControllingTTY(tm.tty(isMaster))
+}
+
+// foregroundProcessGroup gets the process group ID of tm's foreground process
+// group.
+func (tm *Terminal) foregroundProcessGroup(ctx context.Context, io usermem.IO, args arch.SyscallArguments, isMaster bool) (uintptr, error) {
+	task := kernel.TaskFromContext(ctx)
+	if task == nil {
+		panic("foregroundProcessGroup must be called from a task context")
+	}
+
+	ret, err := task.ThreadGroup().ForegroundProcessGroup(tm.tty(isMaster))
+	if err != nil {
+		return 0, err
+	}
+
+	// Write it out to the address in args[2].
+	_, err = usermem.CopyObjectOut(ctx, io, args[2].Pointer(), int32(ret), usermem.IOOpts{
+		AddressSpaceActive: true,
+	})
+	return 0, err
+}
+
+// setForegroundProcessGroup sets tm's foreground process group.
+func (tm *Terminal) setForegroundProcessGroup(ctx context.Context, io usermem.IO, args arch.SyscallArguments, isMaster bool) (uintptr, error) {
+	task := kernel.TaskFromContext(ctx)
+	if task == nil {
+		panic("setForegroundProcessGroup must be called from a task context")
+	}
+
+	// Read in the process group ID.
+ var pgid int32 + if _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &pgid, usermem.IOOpts{ + AddressSpaceActive: true, + }); err != nil { + return 0, err + } + + ret, err := task.ThreadGroup().SetForegroundProcessGroup(tm.tty(isMaster), kernel.ProcessGroupID(pgid)) + return uintptr(ret), err +} + +func (tm *Terminal) tty(isMaster bool) *kernel.TTY { + if isMaster { + return tm.masterKTTY + } + return tm.slaveKTTY +} diff --git a/pkg/sentry/fsimpl/ext/BUILD b/pkg/sentry/fsimpl/ext/BUILD new file mode 100644 index 000000000..a41101339 --- /dev/null +++ b/pkg/sentry/fsimpl/ext/BUILD @@ -0,0 +1,86 @@ +package(licenses = ["notice"]) + +load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_generics:defs.bzl", "go_template_instance") + +go_template_instance( + name = "dirent_list", + out = "dirent_list.go", + package = "ext", + prefix = "dirent", + template = "//pkg/ilist:generic_list", + types = { + "Element": "*dirent", + "Linker": "*dirent", + }, +) + +go_library( + name = "ext", + srcs = [ + "block_map_file.go", + "dentry.go", + "directory.go", + "dirent_list.go", + "ext.go", + "extent_file.go", + "file_description.go", + "filesystem.go", + "inode.go", + "regular_file.go", + "symlink.go", + "utils.go", + ], + importpath = "gvisor.dev/gvisor/pkg/sentry/fsimpl/ext", + visibility = ["//pkg/sentry:internal"], + deps = [ + "//pkg/abi/linux", + "//pkg/binary", + "//pkg/fd", + "//pkg/log", + "//pkg/sentry/arch", + "//pkg/sentry/context", + "//pkg/sentry/fs", + "//pkg/sentry/fsimpl/ext/disklayout", + "//pkg/sentry/kernel/auth", + "//pkg/sentry/memmap", + "//pkg/sentry/safemem", + "//pkg/sentry/syscalls/linux", + "//pkg/sentry/usermem", + "//pkg/sentry/vfs", + "//pkg/syserror", + "//pkg/waiter", + ], +) + +go_test( + name = "ext_test", + size = "small", + srcs = [ + "block_map_test.go", + "ext_test.go", + "extent_test.go", + ], + data = [ + "//pkg/sentry/fsimpl/ext:assets/bigfile.txt", + "//pkg/sentry/fsimpl/ext:assets/file.txt", + "//pkg/sentry/fsimpl/ext:assets/tiny.ext2", + "//pkg/sentry/fsimpl/ext:assets/tiny.ext3", + "//pkg/sentry/fsimpl/ext:assets/tiny.ext4", + ], + embed = [":ext"], + deps = [ + "//pkg/abi/linux", + "//pkg/binary", + "//pkg/sentry/context", + "//pkg/sentry/context/contexttest", + "//pkg/sentry/fsimpl/ext/disklayout", + "//pkg/sentry/kernel/auth", + "//pkg/sentry/usermem", + "//pkg/sentry/vfs", + "//pkg/syserror", + "//runsc/test/testutil", + "@com_github_google_go-cmp//cmp:go_default_library", + "@com_github_google_go-cmp//cmp/cmpopts:go_default_library", + ], +) diff --git a/pkg/sentry/fsimpl/ext/README.md b/pkg/sentry/fsimpl/ext/README.md new file mode 100644 index 000000000..af00cfda8 --- /dev/null +++ b/pkg/sentry/fsimpl/ext/README.md @@ -0,0 +1,117 @@ +## EXT(2/3/4) File System + +This is a filesystem driver which supports ext2, ext3 and ext4 filesystems. +Linux has specialized drivers for each variant but none which supports all. This +library takes advantage of ext's backward compatibility and understands the +internal organization of on-disk structures to support all variants. + +This driver implementation diverges from the Linux implementations in being more +forgiving about versioning. For instance, if a filesystem contains both extent +based inodes and classical block map based inodes, this driver will not complain +and interpret them both correctly. While in Linux this would be an issue. This +blurs the line between the three ext fs variants. 
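+For illustration, a sketch of the per-inode dispatch described above
+(`pickFileImpl` is a hypothetical helper assuming the package's internal
+types; the real constructors live in `extent_file.go` and
+`block_map_file.go`):
+
+```go
+// pickFileImpl dispatches on the inode's own flags rather than on a
+// filesystem-wide version, so extent-based and block-map-based inodes can
+// coexist in a single image.
+func pickFileImpl(regFile regularFile) (io.ReaderAt, error) {
+	if regFile.inode.diskInode.Flags().Extents {
+		return newExtentFile(regFile) // ext4-style extent tree
+	}
+	return newBlockMapFile(regFile) // classic ext2-style block map
+}
+```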
+
+Ext2 is considered deprecated as of Red Hat Enterprise Linux 7, and ext3 has
+been superseded by ext4, which brings large performance gains. Thus it is
+recommended to upgrade older filesystem images to ext4 using e2fsprogs for
+better performance.
+
+### Read Only
+
+This driver currently only allows read-only operations. A lot of the design
+decisions are based on this feature. There are plans to implement write (the
+process for which is documented in the future work section).
+
+### Performance
+
+One of the biggest wins of this driver is that it talks directly to the
+underlying block device (or whatever persistent storage is being used), instead
+of making expensive RPCs to a gofer.
+
+Another advantage is that ext fs supports fast concurrent reads. Currently the
+device is represented using an `io.ReaderAt`, which allows for concurrent
+reads. All reads are passed directly to the device driver, which intelligently
+serves the read requests in the optimal order. There is no congestion due to
+locking while reading at the filesystem level.
+
+Reads are optimized further in the way file data is transferred over to user
+memory. Ext fs directly copies file data from disk into user memory with no
+additional allocations on the way. We can only get faster by preloading file
+data into memory (see the future work section).
+
+The internal structures used to represent files, inodes and file descriptors
+use a lot of inheritance. With the level of indirection that an interface adds
+with an internal pointer, it can quickly fragment a structure across memory. As
+this runs alongside a full-blown kernel (which is memory intensive), having a
+fragmented struct might hurt performance. Hence these internal structures,
+though interfaced, are tightly packed in memory using the same inheritance
+pattern that pkg/sentry/vfs uses. The pkg/sentry/fsimpl/ext/disklayout package
+makes an exception to this pattern for reasons documented in the package.
+
+### Security
+
+This driver also intends to help sandbox the container better by reducing the
+surface of the host kernel that the application touches. It prevents the
+application from exploiting vulnerabilities in the host filesystem driver. All
+`io.ReaderAt.ReadAt()` calls are translated to `pread(2)`, which is passed
+directly to the device driver in the kernel. Hence this reduces the surface
+for attack.
+
+The application cannot affect any host filesystems other than the one passed
+in by the user via the block device.
+
+### Future Work
+
+#### Write
+
+To support write operations we would need to modify the block device
+underneath. Currently, the driver does not modify the device at all, not even
+to update access times on reads. Modifying the filesystem incorrectly can
+corrupt it and render it unreadable by other correct ext(x) drivers. Hence
+caution must be exercised when modifying metadata structures.
+
+Ext4 specifically is built for performance and has added a lot of complexity
+as to how metadata structures are modified. For instance, files are organized
+via an extent tree, which must be kept balanced, and file data blocks must be
+placed in the same extent as much as possible to increase locality. Such
+properties must be maintained while modifying the tree.
+
+Ext filesystems emphasize locality, which plays a big role in their
+performance. The block allocation algorithm in Linux does a good job of keeping
+related data together.
This behavior must be maintained as much as
+possible, else we might end up degrading the filesystem's performance over
+time.
+
+Ext4 also supports a wide variety of features which are specialized for
+varying use cases. Implementing all of them can get difficult very quickly.
+
+Ext(x) checksums all its metadata structures to check for corruption, so
+modifying any metadata struct must be accompanied by re-checksumming it. Linux
+filesystem drivers also order on-disk updates intelligently so as not to
+corrupt the filesystem while remaining performant. The in-memory metadata
+structures must be kept in sync with what is on disk.
+
+There is also replication of some important structures across the filesystem.
+All replicas must be updated when their original copy is updated. There is
+also provisioning for snapshotting, which must be kept in mind, although it
+should not affect this implementation unless we allow users to create
+filesystem snapshots.
+
+Ext4 also introduced journaling (jbd2). The journal must be updated
+appropriately.
+
+#### Performance
+
+To improve performance we should implement a buffer cache and, optionally,
+read-ahead for small files. While doing so we must also keep in mind the
+memory usage and have a reasonable cap on how much file data we want to hold
+in memory.
+
+#### Features
+
+Our current implementation will work with most ext4 filesystems for read-only
+purposes. However, the following features are not supported yet:
+
+- Journal
+- Snapshotting
+- Extended Attributes
+- Hash Tree Directories
+- Meta Block Groups
+- Multiple Mount Protection
+- Bigalloc
diff --git a/pkg/sentry/fs/ext/assets/README.md b/pkg/sentry/fsimpl/ext/assets/README.md index 6f1e81b3a..6f1e81b3a 100644 --- a/pkg/sentry/fs/ext/assets/README.md +++ b/pkg/sentry/fsimpl/ext/assets/README.md
diff --git a/pkg/sentry/fs/ext/assets/bigfile.txt b/pkg/sentry/fsimpl/ext/assets/bigfile.txt index 3857cf516..3857cf516 100644 --- a/pkg/sentry/fs/ext/assets/bigfile.txt +++ b/pkg/sentry/fsimpl/ext/assets/bigfile.txt
diff --git a/pkg/sentry/fs/ext/assets/file.txt b/pkg/sentry/fsimpl/ext/assets/file.txt index 980a0d5f1..980a0d5f1 100644 --- a/pkg/sentry/fs/ext/assets/file.txt +++ b/pkg/sentry/fsimpl/ext/assets/file.txt
diff --git a/pkg/sentry/fs/ext/assets/symlink.txt b/pkg/sentry/fsimpl/ext/assets/symlink.txt index 4c330738c..4c330738c 120000 --- a/pkg/sentry/fs/ext/assets/symlink.txt +++ b/pkg/sentry/fsimpl/ext/assets/symlink.txt
diff --git a/pkg/sentry/fs/ext/assets/tiny.ext2 b/pkg/sentry/fsimpl/ext/assets/tiny.ext2 Binary files differ index 381ade9bf..381ade9bf 100644 --- a/pkg/sentry/fs/ext/assets/tiny.ext2 +++ b/pkg/sentry/fsimpl/ext/assets/tiny.ext2
diff --git a/pkg/sentry/fs/ext/assets/tiny.ext3 b/pkg/sentry/fsimpl/ext/assets/tiny.ext3 Binary files differ index 0e97a324c..0e97a324c 100644 --- a/pkg/sentry/fs/ext/assets/tiny.ext3 +++ b/pkg/sentry/fsimpl/ext/assets/tiny.ext3
diff --git a/pkg/sentry/fs/ext/assets/tiny.ext4 b/pkg/sentry/fsimpl/ext/assets/tiny.ext4 Binary files differ index a6859736d..a6859736d 100644 --- a/pkg/sentry/fs/ext/assets/tiny.ext4 +++ b/pkg/sentry/fsimpl/ext/assets/tiny.ext4
diff --git a/pkg/sentry/fsimpl/ext/benchmark/BUILD b/pkg/sentry/fsimpl/ext/benchmark/BUILD new file mode 100644 index 000000000..9fddb4c4c --- /dev/null +++ b/pkg/sentry/fsimpl/ext/benchmark/BUILD @@ -0,0 +1,16 @@
+load("//tools/go_stateify:defs.bzl", "go_test")
+
+package(licenses = ["notice"])
+
+go_test(
+    name = "benchmark_test",
+    size = "small",
+    srcs = ["benchmark_test.go"],
+    deps = [
"//pkg/sentry/context", + "//pkg/sentry/context/contexttest", + "//pkg/sentry/fsimpl/ext", + "//pkg/sentry/kernel/auth", + "//pkg/sentry/vfs", + ], +) diff --git a/pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go b/pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go new file mode 100644 index 000000000..10a8083a0 --- /dev/null +++ b/pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go @@ -0,0 +1,193 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// These benchmarks emulate memfs benchmarks. Ext4 images must be created +// before this benchmark is run using the `make_deep_ext4.sh` script at +// /tmp/image-{depth}.ext4 for all the depths tested below. +package benchmark_test + +import ( + "fmt" + "os" + "runtime" + "strings" + "testing" + + "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/sentry/context/contexttest" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/ext" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/vfs" +) + +var depths = []int{1, 2, 3, 8, 64, 100} + +const filename = "file.txt" + +// setUp opens imagePath as an ext Filesystem and returns all necessary +// elements required to run tests. If error is nil, it also returns a tear +// down function which must be called after the test is run for clean up. +func setUp(b *testing.B, imagePath string) (context.Context, *vfs.VirtualFilesystem, *vfs.VirtualDentry, func(), error) { + f, err := os.Open(imagePath) + if err != nil { + return nil, nil, nil, nil, err + } + + ctx := contexttest.Context(b) + creds := auth.CredentialsFromContext(ctx) + + // Create VFS. + vfsObj := vfs.New() + vfsObj.MustRegisterFilesystemType("extfs", ext.FilesystemType{}) + mntns, err := vfsObj.NewMountNamespace(ctx, creds, imagePath, "extfs", &vfs.NewFilesystemOptions{InternalData: int(f.Fd())}) + if err != nil { + f.Close() + return nil, nil, nil, nil, err + } + + root := mntns.Root() + + tearDown := func() { + root.DecRef() + + if err := f.Close(); err != nil { + b.Fatalf("tearDown failed: %v", err) + } + } + return ctx, vfsObj, &root, tearDown, nil +} + +// mount mounts extfs at the path operation passed. Returns a tear down +// function which must be called after the test is run for clean up. +func mount(b *testing.B, imagePath string, vfsfs *vfs.VirtualFilesystem, pop *vfs.PathOperation) func() { + b.Helper() + + f, err := os.Open(imagePath) + if err != nil { + b.Fatalf("could not open image at %s: %v", imagePath, err) + } + + ctx := contexttest.Context(b) + creds := auth.CredentialsFromContext(ctx) + + if err := vfsfs.NewMount(ctx, creds, imagePath, pop, "extfs", &vfs.NewFilesystemOptions{InternalData: int(f.Fd())}); err != nil { + b.Fatalf("failed to mount tmpfs submount: %v", err) + } + return func() { + if err := f.Close(); err != nil { + b.Fatalf("tearDown failed: %v", err) + } + } +} + +// BenchmarkVFS2Ext4fsStat emulates BenchmarkVFS2MemfsStat. 
+func BenchmarkVFS2Ext4fsStat(b *testing.B) { + for _, depth := range depths { + b.Run(fmt.Sprintf("%d", depth), func(b *testing.B) { + ctx, vfsfs, root, tearDown, err := setUp(b, fmt.Sprintf("/tmp/image-%d.ext4", depth)) + if err != nil { + b.Fatalf("setUp failed: %v", err) + } + defer tearDown() + + creds := auth.CredentialsFromContext(ctx) + var filePathBuilder strings.Builder + filePathBuilder.WriteByte('/') + for i := 1; i <= depth; i++ { + filePathBuilder.WriteString(fmt.Sprintf("%d", i)) + filePathBuilder.WriteByte('/') + } + filePathBuilder.WriteString(filename) + filePath := filePathBuilder.String() + + runtime.GC() + b.ResetTimer() + for i := 0; i < b.N; i++ { + stat, err := vfsfs.StatAt(ctx, creds, &vfs.PathOperation{ + Root: *root, + Start: *root, + Pathname: filePath, + FollowFinalSymlink: true, + }, &vfs.StatOptions{}) + if err != nil { + b.Fatalf("stat(%q) failed: %v", filePath, err) + } + // Sanity check. + if stat.Size > 0 { + b.Fatalf("got wrong file size (%d)", stat.Size) + } + } + }) + } +} + +// BenchmarkVFS2ExtfsMountStat emulates BenchmarkVFS2MemfsMountStat. +func BenchmarkVFS2ExtfsMountStat(b *testing.B) { + for _, depth := range depths { + b.Run(fmt.Sprintf("%d", depth), func(b *testing.B) { + // Create root extfs with depth 1 so we can mount extfs again at /1/. + ctx, vfsfs, root, tearDown, err := setUp(b, fmt.Sprintf("/tmp/image-%d.ext4", 1)) + if err != nil { + b.Fatalf("setUp failed: %v", err) + } + defer tearDown() + + creds := auth.CredentialsFromContext(ctx) + mountPointName := "/1/" + pop := vfs.PathOperation{ + Root: *root, + Start: *root, + Pathname: mountPointName, + } + + // Save the mount point for later use. + mountPoint, err := vfsfs.GetDentryAt(ctx, creds, &pop, &vfs.GetDentryOptions{}) + if err != nil { + b.Fatalf("failed to walk to mount point: %v", err) + } + defer mountPoint.DecRef() + + // Create extfs submount. + mountTearDown := mount(b, fmt.Sprintf("/tmp/image-%d.ext4", depth), vfsfs, &pop) + defer mountTearDown() + + var filePathBuilder strings.Builder + filePathBuilder.WriteString(mountPointName) + for i := 1; i <= depth; i++ { + filePathBuilder.WriteString(fmt.Sprintf("%d", i)) + filePathBuilder.WriteByte('/') + } + filePathBuilder.WriteString(filename) + filePath := filePathBuilder.String() + + runtime.GC() + b.ResetTimer() + for i := 0; i < b.N; i++ { + stat, err := vfsfs.StatAt(ctx, creds, &vfs.PathOperation{ + Root: *root, + Start: *root, + Pathname: filePath, + FollowFinalSymlink: true, + }, &vfs.StatOptions{}) + if err != nil { + b.Fatalf("stat(%q) failed: %v", filePath, err) + } + // Sanity check. touch(1) always creates files of size 0 (empty). + if stat.Size > 0 { + b.Fatalf("got wrong file size (%d)", stat.Size) + } + } + }) + } +} diff --git a/pkg/sentry/fsimpl/ext/benchmark/make_deep_ext4.sh b/pkg/sentry/fsimpl/ext/benchmark/make_deep_ext4.sh new file mode 100755 index 000000000..d0910da1f --- /dev/null +++ b/pkg/sentry/fsimpl/ext/benchmark/make_deep_ext4.sh @@ -0,0 +1,72 @@ +#!/bin/bash + +# Copyright 2019 The gVisor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +# This script creates an ext4 image with $1 depth of directories and a file in +# the inner most directory. The created file is at path /1/2/.../depth/file.txt. +# The ext4 image is written to $2. The image is temporarily mounted at +# /tmp/mountpoint. This script must be run with sudo privileges. + +# Usage: +# sudo bash make_deep_ext4.sh {depth} {output path} + +# Check positional arguments. +if [ "$#" -ne 2 ]; then + echo "Usage: sudo bash make_deep_ext4.sh {depth} {output path}" + exit 1 +fi + +# Make sure depth is a non-negative number. +if ! [[ "$1" =~ ^[0-9]+$ ]]; then + echo "Depth must be a non-negative number." + exit 1 +fi + +# Create a 1 MB filesystem image at the requested output path. +rm -f $2 +fallocate -l 1M $2 +if [ $? -ne 0 ]; then + echo "fallocate failed" + exit $? +fi + +# Convert that blank into an ext4 image. +mkfs.ext4 -j $2 +if [ $? -ne 0 ]; then + echo "mkfs.ext4 failed" + exit $? +fi + +# Mount the image. +MOUNTPOINT=/tmp/mountpoint +mkdir -p $MOUNTPOINT +mount -o loop $2 $MOUNTPOINT +if [ $? -ne 0 ]; then + echo "mount failed" + exit $? +fi + +# Create nested directories and the file. +if [ "$1" -eq 0 ]; then + FILEPATH=$MOUNTPOINT/file.txt +else + FILEPATH=$MOUNTPOINT/$(seq -s '/' 1 $1)/file.txt +fi +mkdir -p $(dirname $FILEPATH) || exit +touch $FILEPATH + +# Clean up. +umount $MOUNTPOINT +rm -rf $MOUNTPOINT diff --git a/pkg/sentry/fsimpl/ext/block_map_file.go b/pkg/sentry/fsimpl/ext/block_map_file.go new file mode 100644 index 000000000..cea89bcd9 --- /dev/null +++ b/pkg/sentry/fsimpl/ext/block_map_file.go @@ -0,0 +1,200 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ext + +import ( + "io" + "math" + + "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/syserror" +) + +const ( + // numDirectBlks is the number of direct blocks in ext block map inodes. + numDirectBlks = 12 +) + +// blockMapFile is a type of regular file which uses direct/indirect block +// addressing to store file data. This was deprecated in ext4. +type blockMapFile struct { + regFile regularFile + + // directBlks are the direct blocks numbers. The physical blocks pointed by + // these holds file data. Contains file blocks 0 to 11. + directBlks [numDirectBlks]uint32 + + // indirectBlk is the physical block which contains (blkSize/4) direct block + // numbers (as uint32 integers). + indirectBlk uint32 + + // doubleIndirectBlk is the physical block which contains (blkSize/4) indirect + // block numbers (as uint32 integers). + doubleIndirectBlk uint32 + + // tripleIndirectBlk is the physical block which contains (blkSize/4) doubly + // indirect block numbers (as uint32 integers). + tripleIndirectBlk uint32 + + // coverage at (i)th index indicates the amount of file data a node at + // height (i) covers. Height 0 is the direct block. + coverage [4]uint64 +} + +// Compiles only if blockMapFile implements io.ReaderAt. 
+var _ io.ReaderAt = (*blockMapFile)(nil) + +// newBlockMapFile is the blockMapFile constructor. It initializes the file to +// physical blocks map with (at most) the first 12 (direct) blocks. +func newBlockMapFile(regFile regularFile) (*blockMapFile, error) { + file := &blockMapFile{regFile: regFile} + file.regFile.impl = file + + for i := uint(0); i < 4; i++ { + file.coverage[i] = getCoverage(regFile.inode.blkSize, i) + } + + blkMap := regFile.inode.diskInode.Data() + binary.Unmarshal(blkMap[:numDirectBlks*4], binary.LittleEndian, &file.directBlks) + binary.Unmarshal(blkMap[numDirectBlks*4:(numDirectBlks+1)*4], binary.LittleEndian, &file.indirectBlk) + binary.Unmarshal(blkMap[(numDirectBlks+1)*4:(numDirectBlks+2)*4], binary.LittleEndian, &file.doubleIndirectBlk) + binary.Unmarshal(blkMap[(numDirectBlks+2)*4:(numDirectBlks+3)*4], binary.LittleEndian, &file.tripleIndirectBlk) + return file, nil +} + +// ReadAt implements io.ReaderAt.ReadAt. +func (f *blockMapFile) ReadAt(dst []byte, off int64) (int, error) { + if len(dst) == 0 { + return 0, nil + } + + if off < 0 { + return 0, syserror.EINVAL + } + + offset := uint64(off) + size := f.regFile.inode.diskInode.Size() + if offset >= size { + return 0, io.EOF + } + + // dirBlksEnd is the file offset until which direct blocks cover file data. + // Direct blocks cover 0 <= file offset < dirBlksEnd. + dirBlksEnd := numDirectBlks * f.coverage[0] + + // indirBlkEnd is the file offset until which the indirect block covers file + // data. The indirect block covers dirBlksEnd <= file offset < indirBlkEnd. + indirBlkEnd := dirBlksEnd + f.coverage[1] + + // doubIndirBlkEnd is the file offset until which the double indirect block + // covers file data. The double indirect block covers the range + // indirBlkEnd <= file offset < doubIndirBlkEnd. + doubIndirBlkEnd := indirBlkEnd + f.coverage[2] + + read := 0 + toRead := len(dst) + if uint64(toRead)+offset > size { + toRead = int(size - offset) + } + for read < toRead { + var err error + var curR int + + // Figure out which block to delegate the read to. + switch { + case offset < dirBlksEnd: + // Direct block. + curR, err = f.read(f.directBlks[offset/f.regFile.inode.blkSize], offset%f.regFile.inode.blkSize, 0, dst[read:]) + case offset < indirBlkEnd: + // Indirect block. + curR, err = f.read(f.indirectBlk, offset-dirBlksEnd, 1, dst[read:]) + case offset < doubIndirBlkEnd: + // Doubly indirect block. + curR, err = f.read(f.doubleIndirectBlk, offset-indirBlkEnd, 2, dst[read:]) + default: + // Triply indirect block. + curR, err = f.read(f.tripleIndirectBlk, offset-doubIndirBlkEnd, 3, dst[read:]) + } + + read += curR + offset += uint64(curR) + if err != nil { + return read, err + } + } + + if read < len(dst) { + return read, io.EOF + } + return read, nil +} + +// read is the recursive step of the ReadAt function. It relies on knowing the +// current node's location on disk (curPhyBlk) and its height in the block map +// tree. A height of 0 shows that the current node is actually holding file +// data. relFileOff tells the offset from which we need to start to reading +// under the current node. It is completely relative to the current node. 
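+//
+// For example, with a 4 KiB block size each block holds 4096/4 = 1024 block
+// numbers, so a height-1 (indirect) node covers 4096 * 1024 = 4 MiB of file
+// data and a height-2 (doubly indirect) node covers 4 GiB; see getCoverage
+// below.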
+func (f *blockMapFile) read(curPhyBlk uint32, relFileOff uint64, height uint, dst []byte) (int, error) { + curPhyBlkOff := int64(curPhyBlk) * int64(f.regFile.inode.blkSize) + if height == 0 { + toRead := int(f.regFile.inode.blkSize - relFileOff) + if len(dst) < toRead { + toRead = len(dst) + } + + n, _ := f.regFile.inode.dev.ReadAt(dst[:toRead], curPhyBlkOff+int64(relFileOff)) + if n < toRead { + return n, syserror.EIO + } + return n, nil + } + + childCov := f.coverage[height-1] + startIdx := relFileOff / childCov + endIdx := f.regFile.inode.blkSize / 4 // This is exclusive. + wantEndIdx := (relFileOff + uint64(len(dst))) / childCov + wantEndIdx++ // Make this exclusive. + if wantEndIdx < endIdx { + endIdx = wantEndIdx + } + + read := 0 + curChildOff := relFileOff % childCov + for i := startIdx; i < endIdx; i++ { + var childPhyBlk uint32 + err := readFromDisk(f.regFile.inode.dev, curPhyBlkOff+int64(i*4), &childPhyBlk) + if err != nil { + return read, err + } + + n, err := f.read(childPhyBlk, curChildOff, height-1, dst[read:]) + read += n + if err != nil { + return read, err + } + + curChildOff = 0 + } + + return read, nil +} + +// getCoverage returns the number of bytes a node at the given height covers. +// Height 0 is the file data block itself. Height 1 is the indirect block. +// +// Formula: blkSize * ((blkSize / 4)^height) +func getCoverage(blkSize uint64, height uint) uint64 { + return blkSize * uint64(math.Pow(float64(blkSize/4), float64(height))) +} diff --git a/pkg/sentry/fsimpl/ext/block_map_test.go b/pkg/sentry/fsimpl/ext/block_map_test.go new file mode 100644 index 000000000..213aa3919 --- /dev/null +++ b/pkg/sentry/fsimpl/ext/block_map_test.go @@ -0,0 +1,157 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ext + +import ( + "bytes" + "math/rand" + "testing" + + "github.com/google/go-cmp/cmp" + "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout" +) + +// These consts are for mocking the block map tree. +const ( + mockBMBlkSize = uint32(16) + mockBMDiskSize = 2500 +) + +// TestBlockMapReader stress tests block map reader functionality. It performs +// random length reads from all possible positions in the block map structure. +func TestBlockMapReader(t *testing.T) { + mockBMFile, want := blockMapSetUp(t) + n := len(want) + + for from := 0; from < n; from++ { + got := make([]byte, n-from) + + if read, err := mockBMFile.ReadAt(got, int64(from)); err != nil { + t.Fatalf("file read operation from offset %d to %d only read %d bytes: %v", from, n, read, err) + } + + if diff := cmp.Diff(got, want[from:]); diff != "" { + t.Fatalf("file data from offset %d to %d mismatched (-want +got):\n%s", from, n, diff) + } + } +} + +// blkNumGen is a number generator which gives block numbers for building the +// block map file on disk. It gives unique numbers in a random order which +// facilitates in creating an extremely fragmented filesystem. 
+type blkNumGen struct {
+	nums []uint32
+}
+
+// newBlkNumGen is the blkNumGen constructor.
+func newBlkNumGen() *blkNumGen {
+	blkNums := &blkNumGen{}
+	lim := mockBMDiskSize / mockBMBlkSize
+	blkNums.nums = make([]uint32, lim)
+	for i := uint32(0); i < lim; i++ {
+		blkNums.nums[i] = i
+	}
+
+	rand.Shuffle(int(lim), func(i, j int) {
+		blkNums.nums[i], blkNums.nums[j] = blkNums.nums[j], blkNums.nums[i]
+	})
+	return blkNums
+}
+
+// next returns the next random block number.
+func (n *blkNumGen) next() uint32 {
+	ret := n.nums[0]
+	n.nums = n.nums[1:]
+	return ret
+}
+
+// blockMapSetUp creates a mock disk and a block map file. It initializes the
+// block map file with 12 direct blocks, 1 indirect block, 1 doubly indirect
+// block and 1 triply indirect block (i.e. it fills the block map to the
+// brim). It initializes the disk to reflect the inode. It also returns the
+// file data that the inode covers and that is written to disk.
+func blockMapSetUp(t *testing.T) (*blockMapFile, []byte) {
+	mockDisk := make([]byte, mockBMDiskSize)
+	regFile := regularFile{
+		inode: inode{
+			diskInode: &disklayout.InodeNew{
+				InodeOld: disklayout.InodeOld{
+					SizeLo: getMockBMFileFize(),
+				},
+			},
+			dev:     bytes.NewReader(mockDisk),
+			blkSize: uint64(mockBMBlkSize),
+		},
+	}
+
+	var fileData []byte
+	blkNums := newBlkNumGen()
+	var data []byte
+
+	// Write the direct blocks.
+	for i := 0; i < numDirectBlks; i++ {
+		curBlkNum := blkNums.next()
+		data = binary.Marshal(data, binary.LittleEndian, curBlkNum)
+		fileData = append(fileData, writeFileDataToBlock(mockDisk, curBlkNum, 0, blkNums)...)
+	}
+
+	// Write to the indirect block.
+	indirectBlk := blkNums.next()
+	data = binary.Marshal(data, binary.LittleEndian, indirectBlk)
+	fileData = append(fileData, writeFileDataToBlock(mockDisk, indirectBlk, 1, blkNums)...)
+
+	// Write to the doubly indirect block.
+	doublyIndirectBlk := blkNums.next()
+	data = binary.Marshal(data, binary.LittleEndian, doublyIndirectBlk)
+	fileData = append(fileData, writeFileDataToBlock(mockDisk, doublyIndirectBlk, 2, blkNums)...)
+
+	// Write to the triply indirect block.
+	triplyIndirectBlk := blkNums.next()
+	data = binary.Marshal(data, binary.LittleEndian, triplyIndirectBlk)
+	fileData = append(fileData, writeFileDataToBlock(mockDisk, triplyIndirectBlk, 3, blkNums)...)
+
+	copy(regFile.inode.diskInode.Data(), data)
+
+	mockFile, err := newBlockMapFile(regFile)
+	if err != nil {
+		t.Fatalf("newBlockMapFile failed: %v", err)
+	}
+	return mockFile, fileData
+}
+
+// writeFileDataToBlock writes random bytes to the block on disk when height
+// is 0; for height > 0 it writes child block numbers into the block and
+// recurses into them. It returns the file data written under the block.
+func writeFileDataToBlock(disk []byte, blkNum uint32, height uint, blkNums *blkNumGen) []byte {
+	if height == 0 {
+		start := blkNum * mockBMBlkSize
+		end := start + mockBMBlkSize
+		rand.Read(disk[start:end])
+		return disk[start:end]
+	}
+
+	var fileData []byte
+	for off := blkNum * mockBMBlkSize; off < (blkNum+1)*mockBMBlkSize; off += 4 {
+		curBlkNum := blkNums.next()
+		copy(disk[off:off+4], binary.Marshal(nil, binary.LittleEndian, curBlkNum))
+		fileData = append(fileData, writeFileDataToBlock(disk, curBlkNum, height-1, blkNums)...)
+	}
+	return fileData
+}
+
+// getMockBMFileFize gets the size of the mock block map file which is used
+// for testing.
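+// With mockBMBlkSize = 16, each block holds 16/4 = 4 block numbers, so this
+// comes to 12*16 + 64 + 256 + 1024 = 1536 bytes (coverage of heights 0-3).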
+func getMockBMFileFize() uint32 { + return uint32(numDirectBlks*getCoverage(uint64(mockBMBlkSize), 0) + getCoverage(uint64(mockBMBlkSize), 1) + getCoverage(uint64(mockBMBlkSize), 2) + getCoverage(uint64(mockBMBlkSize), 3)) +} diff --git a/pkg/sentry/fs/ext/dentry.go b/pkg/sentry/fsimpl/ext/dentry.go index 054fb42b6..054fb42b6 100644 --- a/pkg/sentry/fs/ext/dentry.go +++ b/pkg/sentry/fsimpl/ext/dentry.go diff --git a/pkg/sentry/fsimpl/ext/directory.go b/pkg/sentry/fsimpl/ext/directory.go new file mode 100644 index 000000000..b51f3e18d --- /dev/null +++ b/pkg/sentry/fsimpl/ext/directory.go @@ -0,0 +1,308 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ext + +import ( + "sync" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/sentry/fs" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout" + "gvisor.dev/gvisor/pkg/sentry/memmap" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" +) + +// directory represents a directory inode. It holds the childList in memory. +type directory struct { + inode inode + + // mu serializes the changes to childList. + // Lock Order (outermost locks must be taken first): + // directory.mu + // filesystem.mu + mu sync.Mutex + + // childList is a list containing (1) child dirents and (2) fake dirents + // (with diskDirent == nil) that represent the iteration position of + // directoryFDs. childList is used to support directoryFD.IterDirents() + // efficiently. childList is protected by mu. + childList direntList + + // childMap maps the child's filename to the dirent structure stored in + // childList. This adds some data replication but helps in faster path + // traversal. For consistency, key == childMap[key].diskDirent.FileName(). + // Immutable. + childMap map[string]*dirent +} + +// newDirectroy is the directory constructor. +func newDirectroy(inode inode, newDirent bool) (*directory, error) { + file := &directory{inode: inode, childMap: make(map[string]*dirent)} + file.inode.impl = file + + // Initialize childList by reading dirents from the underlying file. + if inode.diskInode.Flags().Index { + // TODO(b/134676337): Support hash tree directories. Currently only the '.' + // and '..' entries are read in. + + // Users cannot navigate this hash tree directory yet. + log.Warningf("hash tree directory being used which is unsupported") + return file, nil + } + + // The dirents are organized in a linear array in the file data. + // Extract the file data and decode the dirents. + regFile, err := newRegularFile(inode) + if err != nil { + return nil, err + } + + // buf is used as scratch space for reading in dirents from disk and + // unmarshalling them into dirent structs. 
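+	// Note that each read below is clamped to toRead, so the final record
+	// read cannot run past the end of the directory's data.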
+	buf := make([]byte, disklayout.DirentSize)
+	size := inode.diskInode.Size()
+	for off, inc := uint64(0), uint64(0); off < size; off += inc {
+		toRead := size - off
+		if toRead > disklayout.DirentSize {
+			toRead = disklayout.DirentSize
+		}
+		if n, err := regFile.impl.ReadAt(buf[:toRead], int64(off)); uint64(n) < toRead {
+			return nil, err
+		}
+
+		var curDirent dirent
+		if newDirent {
+			curDirent.diskDirent = &disklayout.DirentNew{}
+		} else {
+			curDirent.diskDirent = &disklayout.DirentOld{}
+		}
+		binary.Unmarshal(buf, binary.LittleEndian, curDirent.diskDirent)
+
+		if curDirent.diskDirent.Inode() != 0 && len(curDirent.diskDirent.FileName()) != 0 {
+			// An inode number and name length of 0 indicate an unused dirent.
+			file.childList.PushBack(&curDirent)
+			file.childMap[curDirent.diskDirent.FileName()] = &curDirent
+		}
+
+		// The next dirent is placed exactly after this dirent record on disk.
+		inc = uint64(curDirent.diskDirent.RecordSize())
+	}
+
+	return file, nil
+}
+
+func (i *inode) isDir() bool {
+	_, ok := i.impl.(*directory)
+	return ok
+}
+
+// dirent is the directory.childList node.
+type dirent struct {
+	diskDirent disklayout.Dirent
+
+	// direntEntry links dirents into their parent directory.childList.
+	direntEntry
+}
+
+// directoryFD represents a directory file description. It implements
+// vfs.FileDescriptionImpl.
+type directoryFD struct {
+	fileDescription
+	vfs.DirectoryFileDescriptionDefaultImpl
+
+	// Protected by directory.mu.
+	iter *dirent
+	off  int64
+}
+
+// Compiles only if directoryFD implements vfs.FileDescriptionImpl.
+var _ vfs.FileDescriptionImpl = (*directoryFD)(nil)
+
+// Release implements vfs.FileDescriptionImpl.Release.
+func (fd *directoryFD) Release() {
+	if fd.iter == nil {
+		return
+	}
+
+	dir := fd.inode().impl.(*directory)
+	dir.mu.Lock()
+	dir.childList.Remove(fd.iter)
+	dir.mu.Unlock()
+	fd.iter = nil
+}
+
+// IterDirents implements vfs.FileDescriptionImpl.IterDirents.
+func (fd *directoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback) error {
+	extfs := fd.filesystem()
+	dir := fd.inode().impl.(*directory)
+
+	dir.mu.Lock()
+	defer dir.mu.Unlock()
+
+	// Ensure that fd.iter exists and is not linked into dir.childList.
+	var child *dirent
+	if fd.iter == nil {
+		// Start iteration at the beginning of dir.
+		child = dir.childList.Front()
+		fd.iter = &dirent{}
+	} else {
+		// Continue iteration from where we left off.
+		child = fd.iter.Next()
+		dir.childList.Remove(fd.iter)
+	}
+	for ; child != nil; child = child.Next() {
+		// Skip other directoryFD iterators.
+		if child.diskDirent != nil {
+			childType, ok := child.diskDirent.FileType()
+			if !ok {
+				// We will need to read the inode off disk. Do not increment
+				// ref count here because this inode is not being added to the
+				// dentry tree.
+				extfs.mu.Lock()
+				childInode, err := extfs.getOrCreateInodeLocked(child.diskDirent.Inode())
+				extfs.mu.Unlock()
+				if err != nil {
+					// Usage of the file description after the error is
+					// undefined. This implementation would continue reading
+					// from the next dirent.
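+					// Skip past the failing entry and park the cursor
+					// after it, so that a retried IterDirents resumes at
+					// the next dirent.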
+					fd.off++
+					dir.childList.InsertAfter(child, fd.iter)
+					return err
+				}
+				childType = fs.ToInodeType(childInode.diskInode.Mode().FileType())
+			}
+
+			if !cb.Handle(vfs.Dirent{
+				Name: child.diskDirent.FileName(),
+				Type: fs.ToDirentType(childType),
+				Ino:  uint64(child.diskDirent.Inode()),
+				Off:  fd.off,
+			}) {
+				dir.childList.InsertBefore(child, fd.iter)
+				return nil
+			}
+			fd.off++
+		}
+	}
+	dir.childList.PushBack(fd.iter)
+	return nil
+}
+
+// Seek implements vfs.FileDescriptionImpl.Seek.
+func (fd *directoryFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) {
+	if whence != linux.SEEK_SET && whence != linux.SEEK_CUR {
+		return 0, syserror.EINVAL
+	}
+
+	dir := fd.inode().impl.(*directory)
+
+	dir.mu.Lock()
+	defer dir.mu.Unlock()
+
+	// Find the resulting offset.
+	if whence == linux.SEEK_CUR {
+		offset += fd.off
+	}
+
+	if offset < 0 {
+		// lseek(2) specifies that EINVAL should be returned if the resulting offset
+		// is negative.
+		return 0, syserror.EINVAL
+	}
+
+	n := int64(len(dir.childMap))
+	realWantOff := offset
+	if realWantOff > n {
+		realWantOff = n
+	}
+	realCurOff := fd.off
+	if realCurOff > n {
+		realCurOff = n
+	}
+
+	// Ensure that fd.iter exists and is linked into dir.childList so we can
+	// intelligently seek from the optimal position.
+	if fd.iter == nil {
+		fd.iter = &dirent{}
+		dir.childList.PushFront(fd.iter)
+	}
+
+	// Guess that iterating from the current position is optimal.
+	child := fd.iter
+	diff := realWantOff - realCurOff // Shows direction and magnitude of travel.
+
+	// See if starting from the beginning or end is better.
+	abDiff := diff
+	if diff < 0 {
+		abDiff = -diff
+	}
+	if abDiff > realWantOff {
+		// Starting from the beginning is best.
+		child = dir.childList.Front()
+		diff = realWantOff
+	} else if abDiff > (n - realWantOff) {
+		// Starting from the end is best.
+		child = dir.childList.Back()
+		// (n - 1) because the last non-nil dirent represents the (n-1)th offset.
+		diff = realWantOff - (n - 1)
+	}
+
+	for child != nil {
+		// Skip other directoryFD iterators.
+		if child.diskDirent != nil {
+			if diff == 0 {
+				if child != fd.iter {
+					dir.childList.Remove(fd.iter)
+					dir.childList.InsertBefore(child, fd.iter)
+				}
+
+				fd.off = offset
+				return offset, nil
+			}
+
+			if diff < 0 {
+				diff++
+				child = child.Prev()
+			} else {
+				diff--
+				child = child.Next()
+			}
+			continue
+		}
+
+		if diff < 0 {
+			child = child.Prev()
+		} else {
+			child = child.Next()
+		}
+	}
+
+	// Reaching here indicates that the offset is beyond the end of the childList.
+	dir.childList.Remove(fd.iter)
+	dir.childList.PushBack(fd.iter)
+	fd.off = offset
+	return offset, nil
+}
+
+// ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap.
+func (fd *directoryFD) ConfigureMMap(ctx context.Context, opts memmap.MMapOpts) error {
+	// mmap(2) specifies that EACCES should be returned for non-regular file fds.
+ return syserror.EACCES +} diff --git a/pkg/sentry/fs/ext/disklayout/BUILD b/pkg/sentry/fsimpl/ext/disklayout/BUILD index dde15110d..907d35b7e 100644 --- a/pkg/sentry/fs/ext/disklayout/BUILD +++ b/pkg/sentry/fsimpl/ext/disklayout/BUILD @@ -22,7 +22,7 @@ go_library( "superblock_old.go", "test_utils.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/fs/ext/disklayout", + importpath = "gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", diff --git a/pkg/sentry/fs/ext/disklayout/block_group.go b/pkg/sentry/fsimpl/ext/disklayout/block_group.go index ad6f4fef8..ad6f4fef8 100644 --- a/pkg/sentry/fs/ext/disklayout/block_group.go +++ b/pkg/sentry/fsimpl/ext/disklayout/block_group.go diff --git a/pkg/sentry/fs/ext/disklayout/block_group_32.go b/pkg/sentry/fsimpl/ext/disklayout/block_group_32.go index 3e16c76db..3e16c76db 100644 --- a/pkg/sentry/fs/ext/disklayout/block_group_32.go +++ b/pkg/sentry/fsimpl/ext/disklayout/block_group_32.go diff --git a/pkg/sentry/fs/ext/disklayout/block_group_64.go b/pkg/sentry/fsimpl/ext/disklayout/block_group_64.go index 9a809197a..9a809197a 100644 --- a/pkg/sentry/fs/ext/disklayout/block_group_64.go +++ b/pkg/sentry/fsimpl/ext/disklayout/block_group_64.go diff --git a/pkg/sentry/fs/ext/disklayout/block_group_test.go b/pkg/sentry/fsimpl/ext/disklayout/block_group_test.go index 0ef4294c0..0ef4294c0 100644 --- a/pkg/sentry/fs/ext/disklayout/block_group_test.go +++ b/pkg/sentry/fsimpl/ext/disklayout/block_group_test.go diff --git a/pkg/sentry/fs/ext/disklayout/dirent.go b/pkg/sentry/fsimpl/ext/disklayout/dirent.go index 685bf57b8..417b6cf65 100644 --- a/pkg/sentry/fs/ext/disklayout/dirent.go +++ b/pkg/sentry/fsimpl/ext/disklayout/dirent.go @@ -21,6 +21,9 @@ import ( const ( // MaxFileName is the maximum length of an ext fs file's name. MaxFileName = 255 + + // DirentSize is the size of ext dirent structures. + DirentSize = 263 ) var ( diff --git a/pkg/sentry/fs/ext/disklayout/dirent_new.go b/pkg/sentry/fsimpl/ext/disklayout/dirent_new.go index 29ae4a5c2..29ae4a5c2 100644 --- a/pkg/sentry/fs/ext/disklayout/dirent_new.go +++ b/pkg/sentry/fsimpl/ext/disklayout/dirent_new.go diff --git a/pkg/sentry/fs/ext/disklayout/dirent_old.go b/pkg/sentry/fsimpl/ext/disklayout/dirent_old.go index 6fff12a6e..6fff12a6e 100644 --- a/pkg/sentry/fs/ext/disklayout/dirent_old.go +++ b/pkg/sentry/fsimpl/ext/disklayout/dirent_old.go diff --git a/pkg/sentry/fs/ext/disklayout/dirent_test.go b/pkg/sentry/fsimpl/ext/disklayout/dirent_test.go index cc6dff2c9..934919f8a 100644 --- a/pkg/sentry/fs/ext/disklayout/dirent_test.go +++ b/pkg/sentry/fsimpl/ext/disklayout/dirent_test.go @@ -21,8 +21,6 @@ import ( // TestDirentSize tests that the dirent structs are of the correct // size. 
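// Both layouts add up to 263 bytes: the old layout packs 4 (inode) +
// 2 (record length) + 2 (name length) + 255 (name) bytes, while the new
// layout splits the name length field into one byte of length and one byte
// of file type.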
func TestDirentSize(t *testing.T) { - want := uintptr(263) - - assertSize(t, DirentOld{}, want) - assertSize(t, DirentNew{}, want) + assertSize(t, DirentOld{}, uintptr(DirentSize)) + assertSize(t, DirentNew{}, uintptr(DirentSize)) } diff --git a/pkg/sentry/fs/ext/disklayout/disklayout.go b/pkg/sentry/fsimpl/ext/disklayout/disklayout.go index bdf4e2132..bdf4e2132 100644 --- a/pkg/sentry/fs/ext/disklayout/disklayout.go +++ b/pkg/sentry/fsimpl/ext/disklayout/disklayout.go diff --git a/pkg/sentry/fs/ext/disklayout/extent.go b/pkg/sentry/fsimpl/ext/disklayout/extent.go index 567523d32..567523d32 100644 --- a/pkg/sentry/fs/ext/disklayout/extent.go +++ b/pkg/sentry/fsimpl/ext/disklayout/extent.go diff --git a/pkg/sentry/fs/ext/disklayout/extent_test.go b/pkg/sentry/fsimpl/ext/disklayout/extent_test.go index b0fad9b71..b0fad9b71 100644 --- a/pkg/sentry/fs/ext/disklayout/extent_test.go +++ b/pkg/sentry/fsimpl/ext/disklayout/extent_test.go diff --git a/pkg/sentry/fs/ext/disklayout/inode.go b/pkg/sentry/fsimpl/ext/disklayout/inode.go index 88ae913f5..88ae913f5 100644 --- a/pkg/sentry/fs/ext/disklayout/inode.go +++ b/pkg/sentry/fsimpl/ext/disklayout/inode.go diff --git a/pkg/sentry/fs/ext/disklayout/inode_new.go b/pkg/sentry/fsimpl/ext/disklayout/inode_new.go index 8f9f574ce..8f9f574ce 100644 --- a/pkg/sentry/fs/ext/disklayout/inode_new.go +++ b/pkg/sentry/fsimpl/ext/disklayout/inode_new.go diff --git a/pkg/sentry/fs/ext/disklayout/inode_old.go b/pkg/sentry/fsimpl/ext/disklayout/inode_old.go index db25b11b6..db25b11b6 100644 --- a/pkg/sentry/fs/ext/disklayout/inode_old.go +++ b/pkg/sentry/fsimpl/ext/disklayout/inode_old.go diff --git a/pkg/sentry/fs/ext/disklayout/inode_test.go b/pkg/sentry/fsimpl/ext/disklayout/inode_test.go index dd03ee50e..dd03ee50e 100644 --- a/pkg/sentry/fs/ext/disklayout/inode_test.go +++ b/pkg/sentry/fsimpl/ext/disklayout/inode_test.go diff --git a/pkg/sentry/fs/ext/disklayout/superblock.go b/pkg/sentry/fsimpl/ext/disklayout/superblock.go index 7a337a5e0..8bb327006 100644 --- a/pkg/sentry/fs/ext/disklayout/superblock.go +++ b/pkg/sentry/fsimpl/ext/disklayout/superblock.go @@ -221,7 +221,7 @@ func CompatFeaturesFromInt(f uint32) CompatFeatures { // This is not exhaustive, unused features are not listed. const ( // SbDirentFileType indicates that directory entries record the file type. - // We should use struct ext4_dir_entry_2 for dirents then. + // We should use struct DirentNew for dirents then. SbDirentFileType = 0x2 // SbRecovery indicates that the filesystem needs recovery. 
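The SbDirentFileType feature above determines which on-disk dirent layout a reader must decode; this is what the newDirent flag threaded through the directory constructor encodes. A minimal sketch of that selection, with pickDirent as a hypothetical helper that is not part of this change:

func pickDirent(sb disklayout.SuperBlock) disklayout.Dirent {
	// With the filetype feature set, each dirent carries a file type byte
	// and must be decoded as DirentNew; otherwise the old layout applies.
	if sb.IncompatibleFeatures().DirentFileType {
		return &disklayout.DirentNew{}
	}
	return &disklayout.DirentOld{}
}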
diff --git a/pkg/sentry/fs/ext/disklayout/superblock_32.go b/pkg/sentry/fsimpl/ext/disklayout/superblock_32.go
index 53e515fd3..53e515fd3 100644
--- a/pkg/sentry/fs/ext/disklayout/superblock_32.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/superblock_32.go
diff --git a/pkg/sentry/fs/ext/disklayout/superblock_64.go b/pkg/sentry/fsimpl/ext/disklayout/superblock_64.go
index 7c1053fb4..7c1053fb4 100644
--- a/pkg/sentry/fs/ext/disklayout/superblock_64.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/superblock_64.go
diff --git a/pkg/sentry/fs/ext/disklayout/superblock_old.go b/pkg/sentry/fsimpl/ext/disklayout/superblock_old.go
index 9221e0251..9221e0251 100644
--- a/pkg/sentry/fs/ext/disklayout/superblock_old.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/superblock_old.go
diff --git a/pkg/sentry/fs/ext/disklayout/superblock_test.go b/pkg/sentry/fsimpl/ext/disklayout/superblock_test.go
index 463b5ba21..463b5ba21 100644
--- a/pkg/sentry/fs/ext/disklayout/superblock_test.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/superblock_test.go
diff --git a/pkg/sentry/fs/ext/disklayout/test_utils.go b/pkg/sentry/fsimpl/ext/disklayout/test_utils.go
index 9c63f04c0..9c63f04c0 100644
--- a/pkg/sentry/fs/ext/disklayout/test_utils.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/test_utils.go
diff --git a/pkg/sentry/fsimpl/ext/ext.go b/pkg/sentry/fsimpl/ext/ext.go
new file mode 100644
index 000000000..f10accafc
--- /dev/null
+++ b/pkg/sentry/fsimpl/ext/ext.go
@@ -0,0 +1,135 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package ext implements readonly ext(2/3/4) filesystems.
+package ext
+
+import (
+	"errors"
+	"fmt"
+	"io"
+
+	"gvisor.dev/gvisor/pkg/abi/linux"
+	"gvisor.dev/gvisor/pkg/fd"
+	"gvisor.dev/gvisor/pkg/log"
+	"gvisor.dev/gvisor/pkg/sentry/context"
+	"gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout"
+	"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+	"gvisor.dev/gvisor/pkg/sentry/vfs"
+	"gvisor.dev/gvisor/pkg/syserror"
+)
+
+// FilesystemType implements vfs.FilesystemType.
+type FilesystemType struct{}
+
+// Compiles only if FilesystemType implements vfs.FilesystemType.
+var _ vfs.FilesystemType = (*FilesystemType)(nil)
+
+// getDeviceFd returns an io.ReaderAt to the underlying device.
+// Currently there are two ways of mounting an ext(2/3/4) fs:
+//   1. Specify a mount with our internal special MountType in the OCI spec.
+//   2. Expose the device to the container and mount it from the application layer.
+func getDeviceFd(source string, opts vfs.NewFilesystemOptions) (io.ReaderAt, error) {
+	if opts.InternalData == nil {
+		// User mount call.
+		// TODO(b/134676337): Open the device specified by `source` and return that.
+		panic("unimplemented")
+	}
+
+	// NewFilesystem call originated from within the sentry.
+	devFd, ok := opts.InternalData.(int)
+	if !ok {
+		return nil, errors.New("internal data for ext fs must be an int containing the file descriptor to device")
+	}
+
+	if devFd < 0 {
+		return nil, fmt.Errorf("ext device file descriptor is not valid: %d", devFd)
+	}
+
+	// The fd.ReadWriter returned from fd.NewReadWriter() does not take ownership
+	// of the file descriptor and hence will not close it when it is garbage
+	// collected.
+	return fd.NewReadWriter(devFd), nil
+}
+
+// isCompatible checks if the superblock has feature sets which are compatible.
+// We only need to check the superblock incompatible feature set since we are
+// mounting readonly. We will also need to check the readonly compatible
+// feature set when mounting for read/write.
+func isCompatible(sb disklayout.SuperBlock) bool {
+	// Note that these checks are limited because we are mounting readonly and
+	// not journaling. When mounting read/write or with a journal, this must be
+	// reevaluated.
+	incompatFeatures := sb.IncompatibleFeatures()
+	if incompatFeatures.MetaBG {
+		log.Warningf("ext fs: meta block groups are not supported")
+		return false
+	}
+	if incompatFeatures.MMP {
+		log.Warningf("ext fs: multiple mount protection is not supported")
+		return false
+	}
+	if incompatFeatures.Encrypted {
+		log.Warningf("ext fs: encrypted inodes not supported")
+		return false
+	}
+	if incompatFeatures.InlineData {
+		log.Warningf("ext fs: inline files not supported")
+		return false
+	}
+	return true
+}
+
+// NewFilesystem implements vfs.FilesystemType.NewFilesystem.
+func (FilesystemType) NewFilesystem(ctx context.Context, creds *auth.Credentials, source string, opts vfs.NewFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
+	// TODO(b/134676337): Ensure that the user is mounting readonly. If not,
+	// EACCES should be returned according to mount(2). Filesystem independent
+	// flags (like readonly) are currently not available in pkg/sentry/vfs.
+
+	dev, err := getDeviceFd(source, opts)
+	if err != nil {
+		return nil, nil, err
+	}
+
+	fs := filesystem{dev: dev, inodeCache: make(map[uint32]*inode)}
+	fs.vfsfs.Init(&fs)
+	fs.sb, err = readSuperBlock(dev)
+	if err != nil {
+		return nil, nil, err
+	}
+
+	if fs.sb.Magic() != linux.EXT_SUPER_MAGIC {
+		// mount(2) specifies that EINVAL should be returned if the superblock is
+		// invalid.
+		return nil, nil, syserror.EINVAL
+	}
+
+	// Refuse to mount if the filesystem is incompatible.
+	if !isCompatible(fs.sb) {
+		return nil, nil, syserror.EINVAL
+	}
+
+	fs.bgs, err = readBlockGroups(dev, fs.sb)
+	if err != nil {
+		return nil, nil, err
+	}
+
+	rootInode, err := fs.getOrCreateInodeLocked(disklayout.RootDirInode)
+	if err != nil {
+		return nil, nil, err
+	}
+	rootInode.incRef()
+
+	return &fs.vfsfs, &newDentry(rootInode).vfsd, nil
+}
diff --git a/pkg/sentry/fsimpl/ext/ext_test.go b/pkg/sentry/fsimpl/ext/ext_test.go
new file mode 100644
index 000000000..49b57a2d6
--- /dev/null
+++ b/pkg/sentry/fsimpl/ext/ext_test.go
@@ -0,0 +1,917 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ext
+
+import (
+	"fmt"
+	"io"
+	"os"
+	"path"
+	"sort"
+	"testing"
+
+	"github.com/google/go-cmp/cmp"
+	"github.com/google/go-cmp/cmp/cmpopts"
+	"gvisor.dev/gvisor/pkg/abi/linux"
+	"gvisor.dev/gvisor/pkg/sentry/context"
+	"gvisor.dev/gvisor/pkg/sentry/context/contexttest"
+	"gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout"
+	"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+	"gvisor.dev/gvisor/pkg/sentry/usermem"
+	"gvisor.dev/gvisor/pkg/sentry/vfs"
+	"gvisor.dev/gvisor/pkg/syserror"
+
+	"gvisor.dev/gvisor/runsc/test/testutil"
+)
+
+const (
+	assetsDir = "pkg/sentry/fsimpl/ext/assets"
+)
+
+var (
+	ext2ImagePath = path.Join(assetsDir, "tiny.ext2")
+	ext3ImagePath = path.Join(assetsDir, "tiny.ext3")
+	ext4ImagePath = path.Join(assetsDir, "tiny.ext4")
+)
+
+// setUp opens imagePath as an ext Filesystem and returns all necessary
+// elements required to run tests. If error is nil, it also returns a tear
+// down function which must be called after the test is run for cleanup.
+func setUp(t *testing.T, imagePath string) (context.Context, *vfs.VirtualFilesystem, *vfs.VirtualDentry, func(), error) {
+	localImagePath, err := testutil.FindFile(imagePath)
+	if err != nil {
+		return nil, nil, nil, nil, fmt.Errorf("failed to open local image at path %s: %v", imagePath, err)
+	}
+
+	f, err := os.Open(localImagePath)
+	if err != nil {
+		return nil, nil, nil, nil, err
+	}
+
+	ctx := contexttest.Context(t)
+	creds := auth.CredentialsFromContext(ctx)
+
+	// Create VFS.
+	vfsObj := vfs.New()
+	vfsObj.MustRegisterFilesystemType("extfs", FilesystemType{})
+	mntns, err := vfsObj.NewMountNamespace(ctx, creds, localImagePath, "extfs", &vfs.NewFilesystemOptions{InternalData: int(f.Fd())})
+	if err != nil {
+		f.Close()
+		return nil, nil, nil, nil, err
+	}
+
+	root := mntns.Root()
+
+	tearDown := func() {
+		root.DecRef()
+
+		if err := f.Close(); err != nil {
+			t.Fatalf("tearDown failed: %v", err)
+		}
+	}
+	return ctx, vfsObj, &root, tearDown, nil
+}
+
+// TODO(b/134676337): Test vfs.FilesystemImpl.ReadlinkAt and
+// vfs.FilesystemImpl.StatFSAt which are not implemented in
+// vfs.VirtualFilesystem yet.
+
+// TestSeek tests vfs.FileDescriptionImpl.Seek functionality.
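+// It covers SEEK_SET, SEEK_CUR and SEEK_END, seeking past the end of file
+// (which lseek(2) permits), and the EINVAL cases where the resulting offset
+// would be negative.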
+func TestSeek(t *testing.T) { + type seekTest struct { + name string + image string + path string + } + + tests := []seekTest{ + { + name: "ext4 root dir seek", + image: ext4ImagePath, + path: "/", + }, + { + name: "ext3 root dir seek", + image: ext3ImagePath, + path: "/", + }, + { + name: "ext2 root dir seek", + image: ext2ImagePath, + path: "/", + }, + { + name: "ext4 reg file seek", + image: ext4ImagePath, + path: "/file.txt", + }, + { + name: "ext3 reg file seek", + image: ext3ImagePath, + path: "/file.txt", + }, + { + name: "ext2 reg file seek", + image: ext2ImagePath, + path: "/file.txt", + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ctx, vfsfs, root, tearDown, err := setUp(t, test.image) + if err != nil { + t.Fatalf("setUp failed: %v", err) + } + defer tearDown() + + fd, err := vfsfs.OpenAt( + ctx, + auth.CredentialsFromContext(ctx), + &vfs.PathOperation{Root: *root, Start: *root, Pathname: test.path}, + &vfs.OpenOptions{}, + ) + if err != nil { + t.Fatalf("vfsfs.OpenAt failed: %v", err) + } + + if n, err := fd.Impl().Seek(ctx, 0, linux.SEEK_SET); n != 0 || err != nil { + t.Errorf("expected seek position 0, got %d and error %v", n, err) + } + + stat, err := fd.Impl().Stat(ctx, vfs.StatOptions{}) + if err != nil { + t.Errorf("fd.stat failed for file %s in image %s: %v", test.path, test.image, err) + } + + // We should be able to seek beyond the end of file. + size := int64(stat.Size) + if n, err := fd.Impl().Seek(ctx, size, linux.SEEK_SET); n != size || err != nil { + t.Errorf("expected seek position %d, got %d and error %v", size, n, err) + } + + // EINVAL should be returned if the resulting offset is negative. + if _, err := fd.Impl().Seek(ctx, -1, linux.SEEK_SET); err != syserror.EINVAL { + t.Errorf("expected error EINVAL but got %v", err) + } + + if n, err := fd.Impl().Seek(ctx, 3, linux.SEEK_CUR); n != size+3 || err != nil { + t.Errorf("expected seek position %d, got %d and error %v", size+3, n, err) + } + + // Make sure negative offsets work with SEEK_CUR. + if n, err := fd.Impl().Seek(ctx, -2, linux.SEEK_CUR); n != size+1 || err != nil { + t.Errorf("expected seek position %d, got %d and error %v", size+1, n, err) + } + + // EINVAL should be returned if the resulting offset is negative. + if _, err := fd.Impl().Seek(ctx, -(size + 2), linux.SEEK_CUR); err != syserror.EINVAL { + t.Errorf("expected error EINVAL but got %v", err) + } + + // Make sure SEEK_END works with regular files. + switch fd.Impl().(type) { + case *regularFileFD: + // Seek back to 0. + if n, err := fd.Impl().Seek(ctx, -size, linux.SEEK_END); n != 0 || err != nil { + t.Errorf("expected seek position %d, got %d and error %v", 0, n, err) + } + + // Seek forward beyond EOF. + if n, err := fd.Impl().Seek(ctx, 1, linux.SEEK_END); n != size+1 || err != nil { + t.Errorf("expected seek position %d, got %d and error %v", size+1, n, err) + } + + // EINVAL should be returned if the resulting offset is negative. + if _, err := fd.Impl().Seek(ctx, -(size + 1), linux.SEEK_END); err != syserror.EINVAL { + t.Errorf("expected error EINVAL but got %v", err) + } + } + }) + } +} + +// TestStatAt tests filesystem.StatAt functionality. 
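+// Statx fields that StatAt does not fill in yet, or that vary with the
+// backing image (timestamps, inode numbers), are filtered out of the
+// comparison via ignoredFields below.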
+func TestStatAt(t *testing.T) { + type statAtTest struct { + name string + image string + path string + want linux.Statx + } + + tests := []statAtTest{ + { + name: "ext4 statx small file", + image: ext4ImagePath, + path: "/file.txt", + want: linux.Statx{ + Blksize: 0x400, + Nlink: 1, + UID: 0, + GID: 0, + Mode: 0644 | linux.ModeRegular, + Size: 13, + }, + }, + { + name: "ext3 statx small file", + image: ext3ImagePath, + path: "/file.txt", + want: linux.Statx{ + Blksize: 0x400, + Nlink: 1, + UID: 0, + GID: 0, + Mode: 0644 | linux.ModeRegular, + Size: 13, + }, + }, + { + name: "ext2 statx small file", + image: ext2ImagePath, + path: "/file.txt", + want: linux.Statx{ + Blksize: 0x400, + Nlink: 1, + UID: 0, + GID: 0, + Mode: 0644 | linux.ModeRegular, + Size: 13, + }, + }, + { + name: "ext4 statx big file", + image: ext4ImagePath, + path: "/bigfile.txt", + want: linux.Statx{ + Blksize: 0x400, + Nlink: 1, + UID: 0, + GID: 0, + Mode: 0644 | linux.ModeRegular, + Size: 13042, + }, + }, + { + name: "ext3 statx big file", + image: ext3ImagePath, + path: "/bigfile.txt", + want: linux.Statx{ + Blksize: 0x400, + Nlink: 1, + UID: 0, + GID: 0, + Mode: 0644 | linux.ModeRegular, + Size: 13042, + }, + }, + { + name: "ext2 statx big file", + image: ext2ImagePath, + path: "/bigfile.txt", + want: linux.Statx{ + Blksize: 0x400, + Nlink: 1, + UID: 0, + GID: 0, + Mode: 0644 | linux.ModeRegular, + Size: 13042, + }, + }, + { + name: "ext4 statx symlink file", + image: ext4ImagePath, + path: "/symlink.txt", + want: linux.Statx{ + Blksize: 0x400, + Nlink: 1, + UID: 0, + GID: 0, + Mode: 0777 | linux.ModeSymlink, + Size: 8, + }, + }, + { + name: "ext3 statx symlink file", + image: ext3ImagePath, + path: "/symlink.txt", + want: linux.Statx{ + Blksize: 0x400, + Nlink: 1, + UID: 0, + GID: 0, + Mode: 0777 | linux.ModeSymlink, + Size: 8, + }, + }, + { + name: "ext2 statx symlink file", + image: ext2ImagePath, + path: "/symlink.txt", + want: linux.Statx{ + Blksize: 0x400, + Nlink: 1, + UID: 0, + GID: 0, + Mode: 0777 | linux.ModeSymlink, + Size: 8, + }, + }, + } + + // Ignore the fields that are not supported by filesystem.StatAt yet and + // those which are likely to change as the image does. + ignoredFields := map[string]bool{ + "Attributes": true, + "AttributesMask": true, + "Atime": true, + "Blocks": true, + "Btime": true, + "Ctime": true, + "DevMajor": true, + "DevMinor": true, + "Ino": true, + "Mask": true, + "Mtime": true, + "RdevMajor": true, + "RdevMinor": true, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ctx, vfsfs, root, tearDown, err := setUp(t, test.image) + if err != nil { + t.Fatalf("setUp failed: %v", err) + } + defer tearDown() + + got, err := vfsfs.StatAt(ctx, + auth.CredentialsFromContext(ctx), + &vfs.PathOperation{Root: *root, Start: *root, Pathname: test.path}, + &vfs.StatOptions{}, + ) + if err != nil { + t.Fatalf("vfsfs.StatAt failed for file %s in image %s: %v", test.path, test.image, err) + } + + cmpIgnoreFields := cmp.FilterPath(func(p cmp.Path) bool { + _, ok := ignoredFields[p.String()] + return ok + }, cmp.Ignore()) + if diff := cmp.Diff(got, test.want, cmpIgnoreFields, cmpopts.IgnoreUnexported(linux.Statx{})); diff != "" { + t.Errorf("stat mismatch (-want +got):\n%s", diff) + } + }) + } +} + +// TestRead tests the read functionality for vfs file descriptions. 
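+// Every byte returned by the vfs file description is cross-checked against an
+// os.File reading the same backing asset from the host.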
+func TestRead(t *testing.T) {
+	type readTest struct {
+		name    string
+		image   string
+		absPath string
+	}
+
+	tests := []readTest{
+		{
+			name:    "ext4 read small file",
+			image:   ext4ImagePath,
+			absPath: "/file.txt",
+		},
+		{
+			name:    "ext3 read small file",
+			image:   ext3ImagePath,
+			absPath: "/file.txt",
+		},
+		{
+			name:    "ext2 read small file",
+			image:   ext2ImagePath,
+			absPath: "/file.txt",
+		},
+		{
+			name:    "ext4 read big file",
+			image:   ext4ImagePath,
+			absPath: "/bigfile.txt",
+		},
+		{
+			name:    "ext3 read big file",
+			image:   ext3ImagePath,
+			absPath: "/bigfile.txt",
+		},
+		{
+			name:    "ext2 read big file",
+			image:   ext2ImagePath,
+			absPath: "/bigfile.txt",
+		},
+	}
+
+	for _, test := range tests {
+		t.Run(test.name, func(t *testing.T) {
+			ctx, vfsfs, root, tearDown, err := setUp(t, test.image)
+			if err != nil {
+				t.Fatalf("setUp failed: %v", err)
+			}
+			defer tearDown()
+
+			fd, err := vfsfs.OpenAt(
+				ctx,
+				auth.CredentialsFromContext(ctx),
+				&vfs.PathOperation{Root: *root, Start: *root, Pathname: test.absPath},
+				&vfs.OpenOptions{},
+			)
+			if err != nil {
+				t.Fatalf("vfsfs.OpenAt failed: %v", err)
+			}
+
+			// Get a local file descriptor and compare its functionality with a vfs file
+			// description for the same file.
+			localFile, err := testutil.FindFile(path.Join(assetsDir, test.absPath))
+			if err != nil {
+				t.Fatalf("testutil.FindFile failed for %s: %v", test.absPath, err)
+			}
+
+			f, err := os.Open(localFile)
+			if err != nil {
+				t.Fatalf("os.Open failed for %s: %v", localFile, err)
+			}
+			defer f.Close()
+
+			// Read the entire file by reading one byte repeatedly. Doing this stress
+			// tests the underlying file reader implementation.
+			got := make([]byte, 1)
+			want := make([]byte, 1)
+			for {
+				n, err := f.Read(want)
+				fd.Impl().Read(ctx, usermem.BytesIOSequence(got), vfs.ReadOptions{})
+
+				if diff := cmp.Diff(got, want); diff != "" {
+					t.Errorf("file data mismatch (-want +got):\n%s", diff)
+				}
+
+				// Make sure there is no more file data left after getting EOF.
+				if n == 0 || err == io.EOF {
+					if n, _ := fd.Impl().Read(ctx, usermem.BytesIOSequence(got), vfs.ReadOptions{}); n != 0 {
+						t.Errorf("extra unexpected file data in file %s in image %s", test.absPath, test.image)
+					}
+
+					break
+				}
+
+				if err != nil {
+					t.Fatalf("read failed: %v", err)
+				}
+			}
+		})
+	}
+}
+
+// iterDirentsCb is a simple callback which just keeps adding the dirents to an
+// internal list. Implements vfs.IterDirentsCallback.
+type iterDirentsCb struct {
+	dirents []vfs.Dirent
+}
+
+// Compiles only if iterDirentCb implements vfs.IterDirentsCallback.
+var _ vfs.IterDirentsCallback = (*iterDirentsCb)(nil)
+
+// newIterDirentCb is the iterDirentsCb constructor.
+func newIterDirentCb() *iterDirentsCb {
+	return &iterDirentsCb{dirents: make([]vfs.Dirent, 0)}
+}
+
+// Handle implements vfs.IterDirentsCallback.Handle.
+func (cb *iterDirentsCb) Handle(dirent vfs.Dirent) bool {
+	cb.dirents = append(cb.dirents, dirent)
+	return true
+}
+
+// TestIterDirents tests the FileDescriptionImpl.IterDirents functionality.
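+// The dirents collected by iterDirentsCb are compared against the known
+// contents of the tiny image's root directory, ignoring Ino and Off, which
+// depend on the image.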
+func TestIterDirents(t *testing.T) { + type iterDirentTest struct { + name string + image string + path string + want []vfs.Dirent + } + + wantDirents := []vfs.Dirent{ + vfs.Dirent{ + Name: ".", + Type: linux.DT_DIR, + }, + vfs.Dirent{ + Name: "..", + Type: linux.DT_DIR, + }, + vfs.Dirent{ + Name: "lost+found", + Type: linux.DT_DIR, + }, + vfs.Dirent{ + Name: "file.txt", + Type: linux.DT_REG, + }, + vfs.Dirent{ + Name: "bigfile.txt", + Type: linux.DT_REG, + }, + vfs.Dirent{ + Name: "symlink.txt", + Type: linux.DT_LNK, + }, + } + tests := []iterDirentTest{ + { + name: "ext4 root dir iteration", + image: ext4ImagePath, + path: "/", + want: wantDirents, + }, + { + name: "ext3 root dir iteration", + image: ext3ImagePath, + path: "/", + want: wantDirents, + }, + { + name: "ext2 root dir iteration", + image: ext2ImagePath, + path: "/", + want: wantDirents, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ctx, vfsfs, root, tearDown, err := setUp(t, test.image) + if err != nil { + t.Fatalf("setUp failed: %v", err) + } + defer tearDown() + + fd, err := vfsfs.OpenAt( + ctx, + auth.CredentialsFromContext(ctx), + &vfs.PathOperation{Root: *root, Start: *root, Pathname: test.path}, + &vfs.OpenOptions{}, + ) + if err != nil { + t.Fatalf("vfsfs.OpenAt failed: %v", err) + } + + cb := &iterDirentsCb{} + if err = fd.Impl().IterDirents(ctx, cb); err != nil { + t.Fatalf("dir fd.IterDirents() failed: %v", err) + } + + sort.Slice(cb.dirents, func(i int, j int) bool { return cb.dirents[i].Name < cb.dirents[j].Name }) + sort.Slice(test.want, func(i int, j int) bool { return test.want[i].Name < test.want[j].Name }) + + // Ignore the inode number and offset of dirents because those are likely to + // change as the underlying image changes. + cmpIgnoreFields := cmp.FilterPath(func(p cmp.Path) bool { + return p.String() == "Ino" || p.String() == "Off" + }, cmp.Ignore()) + if diff := cmp.Diff(cb.dirents, test.want, cmpIgnoreFields); diff != "" { + t.Errorf("dirents mismatch (-want +got):\n%s", diff) + } + }) + } +} + +// TestRootDir tests that the root directory inode is correctly initialized and +// returned from setUp. +func TestRootDir(t *testing.T) { + type inodeProps struct { + Mode linux.FileMode + UID auth.KUID + GID auth.KGID + Size uint64 + InodeSize uint16 + Links uint16 + Flags disklayout.InodeFlags + } + + type rootDirTest struct { + name string + image string + wantInode inodeProps + } + + tests := []rootDirTest{ + { + name: "ext4 root dir", + image: ext4ImagePath, + wantInode: inodeProps{ + Mode: linux.ModeDirectory | 0755, + Size: 0x400, + InodeSize: 0x80, + Links: 3, + Flags: disklayout.InodeFlags{Extents: true}, + }, + }, + { + name: "ext3 root dir", + image: ext3ImagePath, + wantInode: inodeProps{ + Mode: linux.ModeDirectory | 0755, + Size: 0x400, + InodeSize: 0x80, + Links: 3, + }, + }, + { + name: "ext2 root dir", + image: ext2ImagePath, + wantInode: inodeProps{ + Mode: linux.ModeDirectory | 0755, + Size: 0x400, + InodeSize: 0x80, + Links: 3, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + _, _, vd, tearDown, err := setUp(t, test.image) + if err != nil { + t.Fatalf("setUp failed: %v", err) + } + defer tearDown() + + d, ok := vd.Dentry().Impl().(*dentry) + if !ok { + t.Fatalf("ext dentry of incorrect type: %T", vd.Dentry().Impl()) + } + + // Offload inode contents into local structs for comparison. 
+ gotInode := inodeProps{ + Mode: d.inode.diskInode.Mode(), + UID: d.inode.diskInode.UID(), + GID: d.inode.diskInode.GID(), + Size: d.inode.diskInode.Size(), + InodeSize: d.inode.diskInode.InodeSize(), + Links: d.inode.diskInode.LinksCount(), + Flags: d.inode.diskInode.Flags(), + } + + if diff := cmp.Diff(gotInode, test.wantInode); diff != "" { + t.Errorf("inode mismatch (-want +got):\n%s", diff) + } + }) + } +} + +// TestFilesystemInit tests that the filesystem superblock and block group +// descriptors are correctly read in and initialized. +func TestFilesystemInit(t *testing.T) { + // sb only contains the immutable properties of the superblock. + type sb struct { + InodesCount uint32 + BlocksCount uint64 + MaxMountCount uint16 + FirstDataBlock uint32 + BlockSize uint64 + BlocksPerGroup uint32 + ClusterSize uint64 + ClustersPerGroup uint32 + InodeSize uint16 + InodesPerGroup uint32 + BgDescSize uint16 + Magic uint16 + Revision disklayout.SbRevision + CompatFeatures disklayout.CompatFeatures + IncompatFeatures disklayout.IncompatFeatures + RoCompatFeatures disklayout.RoCompatFeatures + } + + // bg only contains the immutable properties of the block group descriptor. + type bg struct { + InodeTable uint64 + BlockBitmap uint64 + InodeBitmap uint64 + ExclusionBitmap uint64 + Flags disklayout.BGFlags + } + + type fsInitTest struct { + name string + image string + wantSb sb + wantBgs []bg + } + + tests := []fsInitTest{ + { + name: "ext4 filesystem init", + image: ext4ImagePath, + wantSb: sb{ + InodesCount: 0x10, + BlocksCount: 0x40, + MaxMountCount: 0xffff, + FirstDataBlock: 0x1, + BlockSize: 0x400, + BlocksPerGroup: 0x2000, + ClusterSize: 0x400, + ClustersPerGroup: 0x2000, + InodeSize: 0x80, + InodesPerGroup: 0x10, + BgDescSize: 0x40, + Magic: linux.EXT_SUPER_MAGIC, + Revision: disklayout.DynamicRev, + CompatFeatures: disklayout.CompatFeatures{ + ExtAttr: true, + ResizeInode: true, + DirIndex: true, + }, + IncompatFeatures: disklayout.IncompatFeatures{ + DirentFileType: true, + Extents: true, + Is64Bit: true, + FlexBg: true, + }, + RoCompatFeatures: disklayout.RoCompatFeatures{ + Sparse: true, + LargeFile: true, + HugeFile: true, + DirNlink: true, + ExtraIsize: true, + MetadataCsum: true, + }, + }, + wantBgs: []bg{ + { + InodeTable: 0x23, + BlockBitmap: 0x3, + InodeBitmap: 0x13, + Flags: disklayout.BGFlags{ + InodeZeroed: true, + }, + }, + }, + }, + { + name: "ext3 filesystem init", + image: ext3ImagePath, + wantSb: sb{ + InodesCount: 0x10, + BlocksCount: 0x40, + MaxMountCount: 0xffff, + FirstDataBlock: 0x1, + BlockSize: 0x400, + BlocksPerGroup: 0x2000, + ClusterSize: 0x400, + ClustersPerGroup: 0x2000, + InodeSize: 0x80, + InodesPerGroup: 0x10, + BgDescSize: 0x20, + Magic: linux.EXT_SUPER_MAGIC, + Revision: disklayout.DynamicRev, + CompatFeatures: disklayout.CompatFeatures{ + ExtAttr: true, + ResizeInode: true, + DirIndex: true, + }, + IncompatFeatures: disklayout.IncompatFeatures{ + DirentFileType: true, + }, + RoCompatFeatures: disklayout.RoCompatFeatures{ + Sparse: true, + LargeFile: true, + }, + }, + wantBgs: []bg{ + { + InodeTable: 0x5, + BlockBitmap: 0x3, + InodeBitmap: 0x4, + Flags: disklayout.BGFlags{ + InodeZeroed: true, + }, + }, + }, + }, + { + name: "ext2 filesystem init", + image: ext2ImagePath, + wantSb: sb{ + InodesCount: 0x10, + BlocksCount: 0x40, + MaxMountCount: 0xffff, + FirstDataBlock: 0x1, + BlockSize: 0x400, + BlocksPerGroup: 0x2000, + ClusterSize: 0x400, + ClustersPerGroup: 0x2000, + InodeSize: 0x80, + InodesPerGroup: 0x10, + BgDescSize: 0x20, + Magic: 
linux.EXT_SUPER_MAGIC, + Revision: disklayout.DynamicRev, + CompatFeatures: disklayout.CompatFeatures{ + ExtAttr: true, + ResizeInode: true, + DirIndex: true, + }, + IncompatFeatures: disklayout.IncompatFeatures{ + DirentFileType: true, + }, + RoCompatFeatures: disklayout.RoCompatFeatures{ + Sparse: true, + LargeFile: true, + }, + }, + wantBgs: []bg{ + { + InodeTable: 0x5, + BlockBitmap: 0x3, + InodeBitmap: 0x4, + Flags: disklayout.BGFlags{ + InodeZeroed: true, + }, + }, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + _, _, vd, tearDown, err := setUp(t, test.image) + if err != nil { + t.Fatalf("setUp failed: %v", err) + } + defer tearDown() + + fs, ok := vd.Mount().Filesystem().Impl().(*filesystem) + if !ok { + t.Fatalf("ext filesystem of incorrect type: %T", vd.Mount().Filesystem().Impl()) + } + + // Offload superblock and block group descriptors contents into + // local structs for comparison. + totalFreeInodes := uint32(0) + totalFreeBlocks := uint64(0) + gotSb := sb{ + InodesCount: fs.sb.InodesCount(), + BlocksCount: fs.sb.BlocksCount(), + MaxMountCount: fs.sb.MaxMountCount(), + FirstDataBlock: fs.sb.FirstDataBlock(), + BlockSize: fs.sb.BlockSize(), + BlocksPerGroup: fs.sb.BlocksPerGroup(), + ClusterSize: fs.sb.ClusterSize(), + ClustersPerGroup: fs.sb.ClustersPerGroup(), + InodeSize: fs.sb.InodeSize(), + InodesPerGroup: fs.sb.InodesPerGroup(), + BgDescSize: fs.sb.BgDescSize(), + Magic: fs.sb.Magic(), + Revision: fs.sb.Revision(), + CompatFeatures: fs.sb.CompatibleFeatures(), + IncompatFeatures: fs.sb.IncompatibleFeatures(), + RoCompatFeatures: fs.sb.ReadOnlyCompatibleFeatures(), + } + gotNumBgs := len(fs.bgs) + gotBgs := make([]bg, gotNumBgs) + for i := 0; i < gotNumBgs; i++ { + gotBgs[i].InodeTable = fs.bgs[i].InodeTable() + gotBgs[i].BlockBitmap = fs.bgs[i].BlockBitmap() + gotBgs[i].InodeBitmap = fs.bgs[i].InodeBitmap() + gotBgs[i].ExclusionBitmap = fs.bgs[i].ExclusionBitmap() + gotBgs[i].Flags = fs.bgs[i].Flags() + + totalFreeInodes += fs.bgs[i].FreeInodesCount() + totalFreeBlocks += uint64(fs.bgs[i].FreeBlocksCount()) + } + + if diff := cmp.Diff(gotSb, test.wantSb); diff != "" { + t.Errorf("superblock mismatch (-want +got):\n%s", diff) + } + + if diff := cmp.Diff(gotBgs, test.wantBgs); diff != "" { + t.Errorf("block group descriptors mismatch (-want +got):\n%s", diff) + } + + if diff := cmp.Diff(totalFreeInodes, fs.sb.FreeInodesCount()); diff != "" { + t.Errorf("total free inodes mismatch (-want +got):\n%s", diff) + } + + if diff := cmp.Diff(totalFreeBlocks, fs.sb.FreeBlocksCount()); diff != "" { + t.Errorf("total free blocks mismatch (-want +got):\n%s", diff) + } + }) + } +} diff --git a/pkg/sentry/fsimpl/ext/extent_file.go b/pkg/sentry/fsimpl/ext/extent_file.go new file mode 100644 index 000000000..38b68a2d3 --- /dev/null +++ b/pkg/sentry/fsimpl/ext/extent_file.go @@ -0,0 +1,237 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+package ext
+
+import (
+	"io"
+	"sort"
+
+	"gvisor.dev/gvisor/pkg/binary"
+	"gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout"
+	"gvisor.dev/gvisor/pkg/syserror"
+)
+
+// extentFile is a type of regular file which uses extents to store file data.
+type extentFile struct {
+	regFile regularFile
+
+	// root is the root extent node. This lives in the 60 byte diskInode.Data().
+	// Immutable.
+	root disklayout.ExtentNode
+}
+
+// Compiles only if extentFile implements io.ReaderAt.
+var _ io.ReaderAt = (*extentFile)(nil)
+
+// newExtentFile is the extent file constructor. It reads the entire extent
+// tree into memory.
+// TODO(b/134676337): Build extent tree on demand to reduce memory usage.
+func newExtentFile(regFile regularFile) (*extentFile, error) {
+	file := &extentFile{regFile: regFile}
+	file.regFile.impl = file
+	err := file.buildExtTree()
+	if err != nil {
+		return nil, err
+	}
+	return file, nil
+}
+
+// buildExtTree builds the extent tree by reading it from disk, running a
+// simple DFS. It first reads the root node from the inode struct in memory.
+// Then it recursively builds the rest of the tree by reading it off disk.
+//
+// Precondition: inode flag InExtents must be set.
+func (f *extentFile) buildExtTree() error {
+	rootNodeData := f.regFile.inode.diskInode.Data()
+
+	binary.Unmarshal(rootNodeData[:disklayout.ExtentStructsSize], binary.LittleEndian, &f.root.Header)
+
+	// The root node cannot have more than 4 entries: 60 bytes = 1 header + 4 entries.
+	if f.root.Header.NumEntries > 4 {
+		// read(2) specifies that EINVAL should be returned if the file is unsuitable
+		// for reading.
+		return syserror.EINVAL
+	}
+
+	f.root.Entries = make([]disklayout.ExtentEntryPair, f.root.Header.NumEntries)
+	for i, off := uint16(0), disklayout.ExtentStructsSize; i < f.root.Header.NumEntries; i, off = i+1, off+disklayout.ExtentStructsSize {
+		var curEntry disklayout.ExtentEntry
+		if f.root.Header.Height == 0 {
+			// Leaf node.
+			curEntry = &disklayout.Extent{}
+		} else {
+			// Internal node.
+			curEntry = &disklayout.ExtentIdx{}
+		}
+		binary.Unmarshal(rootNodeData[off:off+disklayout.ExtentStructsSize], binary.LittleEndian, curEntry)
+		f.root.Entries[i].Entry = curEntry
+	}
+
+	// If this node is internal, perform DFS.
+	if f.root.Header.Height > 0 {
+		for i := uint16(0); i < f.root.Header.NumEntries; i++ {
+			var err error
+			if f.root.Entries[i].Node, err = f.buildExtTreeFromDisk(f.root.Entries[i].Entry); err != nil {
+				return err
+			}
+		}
+	}
+
+	return nil
+}
+
+// buildExtTreeFromDisk reads the extent tree nodes from disk and recursively
+// builds the tree. Performs a simple DFS. It returns the ExtentNode pointed to
+// by the ExtentEntry.
+func (f *extentFile) buildExtTreeFromDisk(entry disklayout.ExtentEntry) (*disklayout.ExtentNode, error) {
+	var header disklayout.ExtentHeader
+	off := entry.PhysicalBlock() * f.regFile.inode.blkSize
+	err := readFromDisk(f.regFile.inode.dev, int64(off), &header)
+	if err != nil {
+		return nil, err
+	}
+
+	entries := make([]disklayout.ExtentEntryPair, header.NumEntries)
+	for i, off := uint16(0), off+disklayout.ExtentStructsSize; i < header.NumEntries; i, off = i+1, off+disklayout.ExtentStructsSize {
+		var curEntry disklayout.ExtentEntry
+		if header.Height == 0 {
+			// Leaf node.
+			curEntry = &disklayout.Extent{}
+		} else {
+			// Internal node.
+			curEntry = &disklayout.ExtentIdx{}
+		}
+
+		err := readFromDisk(f.regFile.inode.dev, int64(off), curEntry)
+		if err != nil {
+			return nil, err
+		}
+		entries[i].Entry = curEntry
+	}
+
+	// If this node is internal, perform DFS.
+	if header.Height > 0 {
+		for i := uint16(0); i < header.NumEntries; i++ {
+			var err error
+			entries[i].Node, err = f.buildExtTreeFromDisk(entries[i].Entry)
+			if err != nil {
+				return nil, err
+			}
+		}
+	}
+
+	return &disklayout.ExtentNode{header, entries}, nil
+}
+
+// ReadAt implements io.ReaderAt.ReadAt.
+func (f *extentFile) ReadAt(dst []byte, off int64) (int, error) {
+	if len(dst) == 0 {
+		return 0, nil
+	}
+
+	if off < 0 {
+		return 0, syserror.EINVAL
+	}
+
+	if uint64(off) >= f.regFile.inode.diskInode.Size() {
+		return 0, io.EOF
+	}
+
+	n, err := f.read(&f.root, uint64(off), dst)
+	if n < len(dst) && err == nil {
+		err = io.EOF
+	}
+	return n, err
+}
+
+// read is the recursive step of extentFile.ReadAt which traverses the extent
+// tree from the node passed and reads file data.
+func (f *extentFile) read(node *disklayout.ExtentNode, off uint64, dst []byte) (int, error) {
+	// Perform a binary search for the node covering bytes starting at off.
+	// A highly fragmented filesystem can have up to 340 entries and so linear
+	// search should be avoided. Finds the first entry which does not cover the
+	// file block we want and subtracts 1 to get the desired index.
+	fileBlk := uint32(off / f.regFile.inode.blkSize)
+	n := len(node.Entries)
+	found := sort.Search(n, func(i int) bool {
+		return node.Entries[i].Entry.FileBlock() > fileBlk
+	}) - 1
+
+	// We should be in this recursive step only if the data we want exists under
+	// the current node.
+	if found < 0 {
+		panic("searching for a file block in an extent entry which does not cover it")
+	}
+
+	read := 0
+	toRead := len(dst)
+	var curR int
+	var err error
+	for i := found; i < n && read < toRead; i++ {
+		if node.Header.Height == 0 {
+			curR, err = f.readFromExtent(node.Entries[i].Entry.(*disklayout.Extent), off, dst[read:])
+		} else {
+			curR, err = f.read(node.Entries[i].Node, off, dst[read:])
+		}
+
+		read += curR
+		off += uint64(curR)
+		if err != nil {
+			return read, err
+		}
+	}
+
+	return read, nil
+}
+
+// readFromExtent reads file data from the extent. It takes advantage of the
+// sequential nature of extents and reads file data from multiple blocks in one
+// call.
+//
+// A non-nil error indicates that this is a partial read and there is probably
+// more to read from this extent. The caller should propagate the error upward
+// and not move to the next extent in the tree.
+//
+// A subsequent call to extentFile.ReadAt should continue reading from where we
+// left off as expected.
+func (f *extentFile) readFromExtent(ex *disklayout.Extent, off uint64, dst []byte) (int, error) {
+	curFileBlk := uint32(off / f.regFile.inode.blkSize)
+	exFirstFileBlk := ex.FileBlock()
+	exLastFileBlk := exFirstFileBlk + uint32(ex.Length) // This is exclusive.
+
+	// We should be in this recursive step only if the data we want exists under
+	// the current extent.
+	if curFileBlk < exFirstFileBlk || exLastFileBlk <= curFileBlk {
+		panic("searching for a file block in an extent which does not cover it")
+	}
+
+	curPhyBlk := uint64(curFileBlk-exFirstFileBlk) + ex.PhysicalBlock()
+	readStart := curPhyBlk*f.regFile.inode.blkSize + (off % f.regFile.inode.blkSize)
+
+	endPhyBlk := ex.PhysicalBlock() + uint64(ex.Length)
+	extentEnd := endPhyBlk * f.regFile.inode.blkSize // This is exclusive.
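+	// Cap the read at the end of this extent's contiguous run on disk;
+	// anything past extentEnd belongs to a different extent.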
+ + toRead := int(extentEnd - readStart) + if len(dst) < toRead { + toRead = len(dst) + } + + n, _ := f.regFile.inode.dev.ReadAt(dst[:toRead], int64(readStart)) + if n < toRead { + return n, syserror.EIO + } + return n, nil +} diff --git a/pkg/sentry/fs/ext/extent_test.go b/pkg/sentry/fsimpl/ext/extent_test.go index b3f342c8e..42d0a484b 100644 --- a/pkg/sentry/fs/ext/extent_test.go +++ b/pkg/sentry/fsimpl/ext/extent_test.go @@ -16,17 +16,23 @@ package ext import ( "bytes" + "math/rand" "testing" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "gvisor.dev/gvisor/pkg/binary" - "gvisor.dev/gvisor/pkg/sentry/fs/ext/disklayout" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout" ) -// TestExtentTree tests the extent tree building logic. +const ( + // mockExtentBlkSize is the mock block size used for testing. + // No block has more than 1 header + 4 entries. + mockExtentBlkSize = uint64(64) +) + +// The tree described below looks like: // -// Test tree: // 0.{Head}[Idx][Idx] // / \ // / \ @@ -44,12 +50,8 @@ import ( // // Please note that ext4 might not construct extent trees looking like this. // This is purely for testing the tree traversal logic. -func TestExtentTree(t *testing.T) { - blkSize := uint64(64) // No block has more than 1 header + 4 entries. - mockDisk := make([]byte, blkSize*10) - mockInode := &inode{diskInode: &disklayout.InodeNew{}} - - node3 := &disklayout.ExtentNode{ +var ( + node3 = &disklayout.ExtentNode{ Header: disklayout.ExtentHeader{ Magic: disklayout.ExtentMagic, NumEntries: 1, @@ -68,7 +70,7 @@ func TestExtentTree(t *testing.T) { }, } - node2 := &disklayout.ExtentNode{ + node2 = &disklayout.ExtentNode{ Header: disklayout.ExtentHeader{ Magic: disklayout.ExtentMagic, NumEntries: 1, @@ -86,7 +88,7 @@ func TestExtentTree(t *testing.T) { }, } - node1 := &disklayout.ExtentNode{ + node1 = &disklayout.ExtentNode{ Header: disklayout.ExtentHeader{ Magic: disklayout.ExtentMagic, NumEntries: 2, @@ -113,7 +115,7 @@ func TestExtentTree(t *testing.T) { }, } - node0 := &disklayout.ExtentNode{ + node0 = &disklayout.ExtentNode{ Header: disklayout.ExtentHeader{ Magic: disklayout.ExtentMagic, NumEntries: 2, @@ -137,22 +139,69 @@ func TestExtentTree(t *testing.T) { }, }, } +) - writeTree(mockInode, mockDisk, node0, blkSize) +// TestExtentReader stress tests extentReader functionality. It performs random +// length reads from all possible positions in the extent tree. +func TestExtentReader(t *testing.T) { + mockExtentFile, want := extentTreeSetUp(t, node0) + n := len(want) - r := bytes.NewReader(mockDisk) - if err := mockInode.buildExtTree(r, blkSize); err != nil { - t.Fatalf("inode.buildExtTree failed: %v", err) + for from := 0; from < n; from++ { + got := make([]byte, n-from) + + if read, err := mockExtentFile.ReadAt(got, int64(from)); err != nil { + t.Fatalf("file read operation from offset %d to %d only read %d bytes: %v", from, n, read, err) + } + + if diff := cmp.Diff(got, want[from:]); diff != "" { + t.Fatalf("file data from offset %d to %d mismatched (-want +got):\n%s", from, n, diff) + } } +} + +// TestBuildExtentTree tests the extent tree building logic. 
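+// The in-memory tree produced by buildExtTree must match node0 exactly,
+// ignoring unexported fields.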
+func TestBuildExtentTree(t *testing.T) {
+	mockExtentFile, _ := extentTreeSetUp(t, node0)
 
 	opt := cmpopts.IgnoreUnexported(disklayout.ExtentIdx{}, disklayout.ExtentHeader{})
-	if diff := cmp.Diff(mockInode.root, node0, opt); diff != "" {
+	if diff := cmp.Diff(&mockExtentFile.root, node0, opt); diff != "" {
 		t.Errorf("extent tree mismatch (-want +got):\n%s", diff)
 	}
 }
 
-// writeTree writes the tree represented by `root` to the inode and disk passed.
-func writeTree(in *inode, disk []byte, root *disklayout.ExtentNode, blkSize uint64) {
+// extentTreeSetUp writes the passed extent tree to a mock disk as an extent
+// tree. It also constructs a mock extent file with the same tree built in it.
+// It also writes random file data to disk and returns it.
+func extentTreeSetUp(t *testing.T, root *disklayout.ExtentNode) (*extentFile, []byte) {
+	t.Helper()
+
+	mockDisk := make([]byte, mockExtentBlkSize*10)
+	mockExtentFile := &extentFile{
+		regFile: regularFile{
+			inode: inode{
+				diskInode: &disklayout.InodeNew{
+					InodeOld: disklayout.InodeOld{
+						SizeLo: uint32(mockExtentBlkSize) * getNumPhyBlks(root),
+					},
+				},
+				blkSize: mockExtentBlkSize,
+				dev:     bytes.NewReader(mockDisk),
+			},
+		},
+	}
+
+	fileData := writeTree(&mockExtentFile.regFile.inode, mockDisk, root, mockExtentBlkSize)
+
+	if err := mockExtentFile.buildExtTree(); err != nil {
+		t.Fatalf("inode.buildExtTree failed: %v", err)
+	}
+	return mockExtentFile, fileData
+}
+
+// writeTree writes the tree represented by `root` to the inode and disk. It
+// also writes random file data on disk.
+func writeTree(in *inode, disk []byte, root *disklayout.ExtentNode, mockExtentBlkSize uint64) []byte {
 	rootData := binary.Marshal(nil, binary.LittleEndian, root.Header)
 	for _, ep := range root.Entries {
 		rootData = binary.Marshal(rootData, binary.LittleEndian, ep.Entry)
@@ -160,26 +209,57 @@ func writeTree(in *inode, disk []byte, root *disklayout.ExtentNode, blkSize uint
 	copy(in.diskInode.Data(), rootData)
 
-	if root.Header.Height > 0 {
-		for _, ep := range root.Entries {
-			writeTreeToDisk(disk, ep, blkSize)
+	var fileData []byte
+	for _, ep := range root.Entries {
+		if root.Header.Height == 0 {
+			fileData = append(fileData, writeFileDataToExtent(disk, ep.Entry.(*disklayout.Extent))...)
+		} else {
+			fileData = append(fileData, writeTreeToDisk(disk, ep)...)
 		}
 	}
+	return fileData
 }
 
 // writeTreeToDisk is the recursive step for writeTree which writes the tree
-// on the disk only.
-func writeTreeToDisk(disk []byte, curNode disklayout.ExtentEntryPair, blkSize uint64) {
+// on the disk only. Also writes random file data on disk.
+func writeTreeToDisk(disk []byte, curNode disklayout.ExtentEntryPair) []byte {
 	nodeData := binary.Marshal(nil, binary.LittleEndian, curNode.Node.Header)
 	for _, ep := range curNode.Node.Entries {
 		nodeData = binary.Marshal(nodeData, binary.LittleEndian, ep.Entry)
 	}
-	copy(disk[curNode.Entry.PhysicalBlock()*blkSize:], nodeData)
+	copy(disk[curNode.Entry.PhysicalBlock()*mockExtentBlkSize:], nodeData)
+
+	var fileData []byte
+	for _, ep := range curNode.Node.Entries {
+		if curNode.Node.Header.Height == 0 {
+			fileData = append(fileData, writeFileDataToExtent(disk, ep.Entry.(*disklayout.Extent))...)
+		} else {
+			fileData = append(fileData, writeTreeToDisk(disk, ep)...)
+		}
+	}
+	return fileData
+}
+
+// writeFileDataToExtent writes random bytes to the blocks on disk that the
+// passed extent points to.
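+// It returns the written slice so that callers can accumulate the expected
+// file contents.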
+func writeFileDataToExtent(disk []byte, ex *disklayout.Extent) []byte { + phyExStartBlk := ex.PhysicalBlock() + phyExStartOff := phyExStartBlk * mockExtentBlkSize + phyExEndOff := phyExStartOff + uint64(ex.Length)*mockExtentBlkSize + rand.Read(disk[phyExStartOff:phyExEndOff]) + return disk[phyExStartOff:phyExEndOff] +} - if curNode.Node.Header.Height > 0 { - for _, ep := range curNode.Node.Entries { - writeTreeToDisk(disk, ep, blkSize) +// getNumPhyBlks returns the number of physical blocks covered under the node. +func getNumPhyBlks(node *disklayout.ExtentNode) uint32 { + var res uint32 + for _, ep := range node.Entries { + if node.Header.Height == 0 { + res += uint32(ep.Entry.(*disklayout.Extent).Length) + } else { + res += getNumPhyBlks(ep.Node) } } + return res } diff --git a/pkg/sentry/fsimpl/ext/file_description.go b/pkg/sentry/fsimpl/ext/file_description.go new file mode 100644 index 000000000..a0065343b --- /dev/null +++ b/pkg/sentry/fsimpl/ext/file_description.go @@ -0,0 +1,86 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ext + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" +) + +// fileDescription is embedded by ext implementations of +// vfs.FileDescriptionImpl. +type fileDescription struct { + vfsfd vfs.FileDescription + vfs.FileDescriptionDefaultImpl + + // flags is the same as vfs.OpenOptions.Flags which are passed to + // vfs.FilesystemImpl.OpenAt. + // TODO(b/134676337): syscalls like read(2), write(2), fchmod(2), fchown(2), + // fgetxattr(2), ioctl(2), mmap(2) should fail with EBADF if O_PATH is set. + // Only close(2), fstat(2), fstatfs(2) should work. + flags uint32 +} + +func (fd *fileDescription) filesystem() *filesystem { + return fd.vfsfd.VirtualDentry().Mount().Filesystem().Impl().(*filesystem) +} + +func (fd *fileDescription) inode() *inode { + return fd.vfsfd.VirtualDentry().Dentry().Impl().(*dentry).inode +} + +// OnClose implements vfs.FileDescriptionImpl.OnClose. +func (fd *fileDescription) OnClose() error { return nil } + +// StatusFlags implements vfs.FileDescriptionImpl.StatusFlags. +func (fd *fileDescription) StatusFlags(ctx context.Context) (uint32, error) { + return fd.flags, nil +} + +// SetStatusFlags implements vfs.FileDescriptionImpl.SetStatusFlags. +func (fd *fileDescription) SetStatusFlags(ctx context.Context, flags uint32) error { + // None of the flags settable by fcntl(F_SETFL) are supported, so this is a + // no-op. + return nil +} + +// Stat implements vfs.FileDescriptionImpl.Stat. +func (fd *fileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) { + var stat linux.Statx + fd.inode().statTo(&stat) + return stat, nil +} + +// SetStat implements vfs.FileDescriptionImpl.SetStat. 
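+// The filesystem is read-only, so any request to change an attribute fails
+// with EPERM.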
+func (fd *fileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions) error {
+	if opts.Stat.Mask == 0 {
+		return nil
+	}
+	return syserror.EPERM
+}
+
+// StatFS implements vfs.FileDescriptionImpl.StatFS.
+func (fd *fileDescription) StatFS(ctx context.Context) (linux.Statfs, error) {
+	var stat linux.Statfs
+	fd.filesystem().statTo(&stat)
+	return stat, nil
+}
+
+// Sync implements vfs.FileDescriptionImpl.Sync.
+func (fd *fileDescription) Sync(ctx context.Context) error {
+	return nil
+}
diff --git a/pkg/sentry/fsimpl/ext/filesystem.go b/pkg/sentry/fsimpl/ext/filesystem.go
new file mode 100644
index 000000000..2d15e8aaf
--- /dev/null
+++ b/pkg/sentry/fsimpl/ext/filesystem.go
@@ -0,0 +1,443 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ext
+
+import (
+	"errors"
+	"io"
+	"sync"
+
+	"gvisor.dev/gvisor/pkg/abi/linux"
+	"gvisor.dev/gvisor/pkg/sentry/context"
+	"gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout"
+	"gvisor.dev/gvisor/pkg/sentry/vfs"
+	"gvisor.dev/gvisor/pkg/syserror"
+)
+
+var (
+	// errResolveDirent indicates that the vfs.ResolvingPath.Component() does
+	// not exist in the dentry tree but does exist on disk, so it has to be
+	// read in using the in-memory dirent and added to the dentry tree. This
+	// usually indicates the need to lock filesystem.mu for writing.
+	errResolveDirent = errors.New("resolve path component using dirent")
+)
+
+// filesystem implements vfs.FilesystemImpl.
+type filesystem struct {
+	vfsfs vfs.Filesystem
+
+	// mu serializes changes to the Dentry tree.
+	mu sync.RWMutex
+
+	// dev represents the underlying fs device. It does not require protection
+	// because io.ReaderAt permits concurrent read calls to it. It translates to
+	// the pread syscall which passes on the read request directly to the device
+	// driver. Device drivers can serve multiple concurrent read requests in the
+	// optimal order (taking locality into consideration).
+	dev io.ReaderAt
+
+	// inodeCache maps absolute inode numbers to the corresponding inode struct.
+	// Inodes should be removed from this once their reference count hits 0.
+	//
+	// Protected by mu because most additions (see IterDirents) and all removals
+	// from this correspond to a change in the dentry tree.
+	inodeCache map[uint32]*inode
+
+	// sb represents the filesystem superblock. Immutable after initialization.
+	sb disklayout.SuperBlock
+
+	// bgs represents all the block group descriptors for the filesystem.
+	// Immutable after initialization.
+	bgs []disklayout.BlockGroup
+}
+
+// Compiles only if filesystem implements vfs.FilesystemImpl.
+var _ vfs.FilesystemImpl = (*filesystem)(nil)
+
+// stepLocked resolves rp.Component() in parent directory vfsd. The write
+// parameter tells whether the caller has acquired filesystem.mu for writing.
+// If true, an existing inode on disk can be added to the dentry tree if not
+// present already.
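+// Callers are expected to first try the walk with write set to false under a
+// read lock, and to retry with write set to true only after getting
+// errResolveDirent (see filesystem.walk below).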
+//
+// stepLocked is loosely analogous to fs/namei.c:walk_component().
+//
+// Preconditions:
+// - filesystem.mu must be locked (for writing if write param is true).
+// - !rp.Done().
+// - inode == vfsd.Impl().(*dentry).inode.
+func stepLocked(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, inode *inode, write bool) (*vfs.Dentry, *inode, error) {
+	if !inode.isDir() {
+		return nil, nil, syserror.ENOTDIR
+	}
+	if err := inode.checkPermissions(rp.Credentials(), vfs.MayExec); err != nil {
+		return nil, nil, err
+	}
+
+	for {
+		nextVFSD, err := rp.ResolveComponent(vfsd)
+		if err != nil {
+			return nil, nil, err
+		}
+		if nextVFSD == nil {
+			// Since the Dentry tree is not the sole source of truth for extfs, if it's
+			// not in the Dentry tree, it might need to be pulled from disk.
+			childDirent, ok := inode.impl.(*directory).childMap[rp.Component()]
+			if !ok {
+				// The underlying inode does not exist on disk.
+				return nil, nil, syserror.ENOENT
+			}
+
+			if !write {
+				// filesystem.mu must be held for writing to add to the dentry tree.
+				return nil, nil, errResolveDirent
+			}
+
+			// Create and add the component's dirent to the dentry tree.
+			fs := rp.Mount().Filesystem().Impl().(*filesystem)
+			childInode, err := fs.getOrCreateInodeLocked(childDirent.diskDirent.Inode())
+			if err != nil {
+				return nil, nil, err
+			}
+			// incRef because this is being added to the dentry tree.
+			childInode.incRef()
+			child := newDentry(childInode)
+			vfsd.InsertChild(&child.vfsd, rp.Component())
+
+			// Continue as usual now that nextVFSD is not nil.
+			nextVFSD = &child.vfsd
+		}
+		nextInode := nextVFSD.Impl().(*dentry).inode
+		if nextInode.isSymlink() && rp.ShouldFollowSymlink() {
+			if err := rp.HandleSymlink(nextInode.impl.(*symlink).target); err != nil {
+				return nil, nil, err
+			}
+			continue
+		}
+		rp.Advance()
+		return nextVFSD, nextInode, nil
+	}
+}
+
+// walkLocked resolves rp to an existing file. The write parameter tells
+// whether the caller has acquired filesystem.mu for writing. If true,
+// additions can be made to the dentry tree while walking.
+// If errResolveDirent is returned, the walk needs to be continued with an
+// upgraded filesystem.mu.
+//
+// walkLocked is loosely analogous to Linux's fs/namei.c:path_lookupat().
+//
+// Preconditions:
+// - filesystem.mu must be locked (for writing if write param is true).
+func walkLocked(rp *vfs.ResolvingPath, write bool) (*vfs.Dentry, *inode, error) {
+	vfsd := rp.Start()
+	inode := vfsd.Impl().(*dentry).inode
+	for !rp.Done() {
+		var err error
+		vfsd, inode, err = stepLocked(rp, vfsd, inode, write)
+		if err != nil {
+			return nil, nil, err
+		}
+	}
+	if rp.MustBeDir() && !inode.isDir() {
+		return nil, nil, syserror.ENOTDIR
+	}
+	return vfsd, inode, nil
+}
+
+// walkParentLocked resolves all but the last path component of rp to an
+// existing directory. It does not check that the returned directory is
+// searchable by the provider of rp. The write parameter tells whether the
+// caller has acquired filesystem.mu for writing. If true, additions can be
+// made to the dentry tree while walking.
+// If errResolveDirent is returned, the walk needs to be continued with an
+// upgraded filesystem.mu.
+//
+// walkParentLocked is loosely analogous to Linux's fs/namei.c:path_parentat().
+//
+// Preconditions:
+// - filesystem.mu must be locked (for writing if write param is true).
+// - !rp.Done().
+func walkParentLocked(rp *vfs.ResolvingPath, write bool) (*vfs.Dentry, *inode, error) {
+	vfsd := rp.Start()
+	inode := vfsd.Impl().(*dentry).inode
+	for !rp.Final() {
+		var err error
+		vfsd, inode, err = stepLocked(rp, vfsd, inode, write)
+		if err != nil {
+			return nil, nil, err
+		}
+	}
+	if !inode.isDir() {
+		return nil, nil, syserror.ENOTDIR
+	}
+	return vfsd, inode, nil
+}
+
+// walk resolves rp to an existing file. If parent is true, it resolves rp
+// until the parent of the last component, which must be an existing
+// directory. If parent is false, it resolves rp entirely. It attempts to
+// resolve the path as far as it can with a read lock and upgrades the lock
+// if needed.
+func (fs *filesystem) walk(rp *vfs.ResolvingPath, parent bool) (*vfs.Dentry, *inode, error) {
+	var (
+		vfsd  *vfs.Dentry
+		inode *inode
+		err   error
+	)
+
+	// Try walking with the hope that all dentries have already been pulled out
+	// of disk. This reduces congestion (allows concurrent walks).
+	fs.mu.RLock()
+	if parent {
+		vfsd, inode, err = walkParentLocked(rp, false)
+	} else {
+		vfsd, inode, err = walkLocked(rp, false)
+	}
+	fs.mu.RUnlock()
+
+	if err == errResolveDirent {
+		// Upgrade the lock and continue walking. Upgrading the lock in the middle
+		// of the walk is fine as this is a read-only filesystem.
+		fs.mu.Lock()
+		if parent {
+			vfsd, inode, err = walkParentLocked(rp, true)
+		} else {
+			vfsd, inode, err = walkLocked(rp, true)
+		}
+		fs.mu.Unlock()
+	}
+
+	return vfsd, inode, err
+}
+
+// getOrCreateInodeLocked gets the inode corresponding to the inode number passed in.
+// It creates a new one with the given inode number if one does not exist.
+// The caller must increment the ref count if adding this to the dentry tree.
+//
+// Precondition: must be holding fs.mu for writing.
+func (fs *filesystem) getOrCreateInodeLocked(inodeNum uint32) (*inode, error) {
+	if in, ok := fs.inodeCache[inodeNum]; ok {
+		return in, nil
+	}
+
+	in, err := newInode(fs, inodeNum)
+	if err != nil {
+		return nil, err
+	}
+
+	fs.inodeCache[inodeNum] = in
+	return in, nil
+}
+
+// statTo writes the statfs fields to the output parameter.
+func (fs *filesystem) statTo(stat *linux.Statfs) {
+	stat.Type = uint64(fs.sb.Magic())
+	stat.BlockSize = int64(fs.sb.BlockSize())
+	stat.Blocks = fs.sb.BlocksCount()
+	stat.BlocksFree = fs.sb.FreeBlocksCount()
+	stat.BlocksAvailable = fs.sb.FreeBlocksCount()
+	stat.Files = uint64(fs.sb.InodesCount())
+	stat.FilesFree = uint64(fs.sb.FreeInodesCount())
+	stat.NameLength = disklayout.MaxFileName
+	stat.FragmentSize = int64(fs.sb.BlockSize())
+	// TODO(b/134676337): Set Statfs.Flags and Statfs.FSID.
+}
+
+// GetDentryAt implements vfs.FilesystemImpl.GetDentryAt.
+func (fs *filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetDentryOptions) (*vfs.Dentry, error) {
+	vfsd, inode, err := fs.walk(rp, false)
+	if err != nil {
+		return nil, err
+	}
+
+	if opts.CheckSearchable {
+		if !inode.isDir() {
+			return nil, syserror.ENOTDIR
+		}
+		if err := inode.checkPermissions(rp.Credentials(), vfs.MayExec); err != nil {
+			return nil, err
+		}
+	}
+
+	inode.incRef()
+	return vfsd, nil
+}
+
+// OpenAt implements vfs.FilesystemImpl.OpenAt.
+func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+	vfsd, inode, err := fs.walk(rp, false)
+	if err != nil {
+		return nil, err
+	}
+
+	// EROFS is returned if write access is needed.
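+	// That covers both opens for writing (vfs.MayWriteFileWithOpenFlags) and
+	// the file-creating flags O_CREAT, O_EXCL and O_TMPFILE checked below.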
+ if vfs.MayWriteFileWithOpenFlags(opts.Flags) || opts.Flags&(linux.O_CREAT|linux.O_EXCL|linux.O_TMPFILE) != 0 { + return nil, syserror.EROFS + } + return inode.open(rp, vfsd, opts.Flags) +} + +// ReadlinkAt implements vfs.FilesystemImpl.ReadlinkAt. +func (fs *filesystem) ReadlinkAt(ctx context.Context, rp *vfs.ResolvingPath) (string, error) { + _, inode, err := fs.walk(rp, false) + if err != nil { + return "", err + } + symlink, ok := inode.impl.(*symlink) + if !ok { + return "", syserror.EINVAL + } + return symlink.target, nil +} + +// StatAt implements vfs.FilesystemImpl.StatAt. +func (fs *filesystem) StatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.StatOptions) (linux.Statx, error) { + _, inode, err := fs.walk(rp, false) + if err != nil { + return linux.Statx{}, err + } + var stat linux.Statx + inode.statTo(&stat) + return stat, nil +} + +// StatFSAt implements vfs.FilesystemImpl.StatFSAt. +func (fs *filesystem) StatFSAt(ctx context.Context, rp *vfs.ResolvingPath) (linux.Statfs, error) { + if _, _, err := fs.walk(rp, false); err != nil { + return linux.Statfs{}, err + } + + var stat linux.Statfs + fs.statTo(&stat) + return stat, nil +} + +// Release implements vfs.FilesystemImpl.Release. +func (fs *filesystem) Release() {} + +// Sync implements vfs.FilesystemImpl.Sync. +func (fs *filesystem) Sync(ctx context.Context) error { + // This is a readonly filesystem for now. + return nil +} + +// The vfs.FilesystemImpl functions below return EROFS because their respective +// man pages say that EROFS must be returned if the path resolves to a file on +// this read-only filesystem. + +// LinkAt implements vfs.FilesystemImpl.LinkAt. +func (fs *filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs.VirtualDentry) error { + if rp.Done() { + return syserror.EEXIST + } + + if _, _, err := fs.walk(rp, true); err != nil { + return err + } + + return syserror.EROFS +} + +// MkdirAt implements vfs.FilesystemImpl.MkdirAt. +func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MkdirOptions) error { + if rp.Done() { + return syserror.EEXIST + } + + if _, _, err := fs.walk(rp, true); err != nil { + return err + } + + return syserror.EROFS +} + +// MknodAt implements vfs.FilesystemImpl.MknodAt. +func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MknodOptions) error { + if rp.Done() { + return syserror.EEXIST + } + + _, _, err := fs.walk(rp, true) + if err != nil { + return err + } + + return syserror.EROFS +} + +// RenameAt implements vfs.FilesystemImpl.RenameAt. +func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs.VirtualDentry, opts vfs.RenameOptions) error { + if rp.Done() { + return syserror.ENOENT + } + + _, _, err := fs.walk(rp, false) + if err != nil { + return err + } + + return syserror.EROFS +} + +// RmdirAt implements vfs.FilesystemImpl.RmdirAt. +func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error { + _, inode, err := fs.walk(rp, false) + if err != nil { + return err + } + + if !inode.isDir() { + return syserror.ENOTDIR + } + + return syserror.EROFS +} + +// SetStatAt implements vfs.FilesystemImpl.SetStatAt. +func (fs *filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetStatOptions) error { + _, _, err := fs.walk(rp, false) + if err != nil { + return err + } + + return syserror.EROFS +} + +// SymlinkAt implements vfs.FilesystemImpl.SymlinkAt. 
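+// As with the other mutating calls above, the path is still walked first so
+// that path resolution errors take precedence over EROFS.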
+func (fs *filesystem) SymlinkAt(ctx context.Context, rp *vfs.ResolvingPath, target string) error { + if rp.Done() { + return syserror.EEXIST + } + + _, _, err := fs.walk(rp, true) + if err != nil { + return err + } + + return syserror.EROFS +} + +// UnlinkAt implements vfs.FilesystemImpl.UnlinkAt. +func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error { + _, inode, err := fs.walk(rp, false) + if err != nil { + return err + } + + if inode.isDir() { + return syserror.EISDIR + } + + return syserror.EROFS +} diff --git a/pkg/sentry/fsimpl/ext/inode.go b/pkg/sentry/fsimpl/ext/inode.go new file mode 100644 index 000000000..e6c847a71 --- /dev/null +++ b/pkg/sentry/fsimpl/ext/inode.go @@ -0,0 +1,219 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ext + +import ( + "fmt" + "io" + "sync/atomic" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" +) + +// inode represents an ext inode. +// +// inode uses the same inheritance pattern that pkg/sentry/vfs structures use. +// This has been done to increase memory locality. +// +// Implementations: +// inode -- +// |-- dir +// |-- symlink +// |-- regular-- +// |-- extent file +// |-- block map file +type inode struct { + // refs is a reference count. refs is accessed using atomic memory operations. + refs int64 + + // inodeNum is the inode number of this inode on disk. This is used to + // identify inodes within the ext filesystem. + inodeNum uint32 + + // dev represents the underlying device. Same as filesystem.dev. + dev io.ReaderAt + + // blkSize is the fs data block size. Same as filesystem.sb.BlockSize(). + blkSize uint64 + + // diskInode gives us access to the inode struct on disk. Immutable. + diskInode disklayout.Inode + + // This is immutable. The first field of the implementations must have inode + // as the first field to ensure temporality. + impl interface{} +} + +// incRef increments the inode ref count. +func (in *inode) incRef() { + atomic.AddInt64(&in.refs, 1) +} + +// tryIncRef tries to increment the ref count. Returns true if successful. +func (in *inode) tryIncRef() bool { + for { + refs := atomic.LoadInt64(&in.refs) + if refs == 0 { + return false + } + if atomic.CompareAndSwapInt64(&in.refs, refs, refs+1) { + return true + } + } +} + +// decRef decrements the inode ref count and releases the inode resources if +// the ref count hits 0. +// +// Precondition: Must have locked fs.mu. +func (in *inode) decRef(fs *filesystem) { + if refs := atomic.AddInt64(&in.refs, -1); refs == 0 { + delete(fs.inodeCache, in.inodeNum) + } else if refs < 0 { + panic("ext.inode.decRef() called without holding a reference") + } +} + +// newInode is the inode constructor. Reads the inode off disk. Identifies +// inodes based on the absolute inode number on disk. 
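+//
+// For illustration, with 8192 inodes per group (a made-up number; the real
+// value comes from the superblock), inode number 8194 would live at index
+// (8194-1)%8192 = 1 of the inode table of block group (8194-1)/8192 = 1.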
+func newInode(fs *filesystem, inodeNum uint32) (*inode, error) { + if inodeNum == 0 { + panic("inode number 0 on ext filesystems is not possible") + } + + inodeRecordSize := fs.sb.InodeSize() + var diskInode disklayout.Inode + if inodeRecordSize == disklayout.OldInodeSize { + diskInode = &disklayout.InodeOld{} + } else { + diskInode = &disklayout.InodeNew{} + } + + // Calculate where the inode is actually placed. + inodesPerGrp := fs.sb.InodesPerGroup() + blkSize := fs.sb.BlockSize() + inodeTableOff := fs.bgs[getBGNum(inodeNum, inodesPerGrp)].InodeTable() * blkSize + inodeOff := inodeTableOff + uint64(uint32(inodeRecordSize)*getBGOff(inodeNum, inodesPerGrp)) + + if err := readFromDisk(fs.dev, int64(inodeOff), diskInode); err != nil { + return nil, err + } + + // Build the inode based on its type. + inode := inode{ + inodeNum: inodeNum, + dev: fs.dev, + blkSize: blkSize, + diskInode: diskInode, + } + + switch diskInode.Mode().FileType() { + case linux.ModeSymlink: + f, err := newSymlink(inode) + if err != nil { + return nil, err + } + return &f.inode, nil + case linux.ModeRegular: + f, err := newRegularFile(inode) + if err != nil { + return nil, err + } + return &f.inode, nil + case linux.ModeDirectory: + f, err := newDirectroy(inode, fs.sb.IncompatibleFeatures().DirentFileType) + if err != nil { + return nil, err + } + return &f.inode, nil + default: + // TODO(b/134676337): Return appropriate errors for sockets, pipes and devices. + return nil, syserror.EINVAL + } +} + +// open creates and returns a file description for the dentry passed in. +func (in *inode) open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32) (*vfs.FileDescription, error) { + ats := vfs.AccessTypesForOpenFlags(flags) + if err := in.checkPermissions(rp.Credentials(), ats); err != nil { + return nil, err + } + switch in.impl.(type) { + case *regularFile: + var fd regularFileFD + fd.flags = flags + fd.vfsfd.Init(&fd, rp.Mount(), vfsd) + return &fd.vfsfd, nil + case *directory: + // Can't open directories writably. This check is not necessary for a read + // only filesystem but will be required when write is implemented. + if ats&vfs.MayWrite != 0 { + return nil, syserror.EISDIR + } + var fd directoryFD + fd.vfsfd.Init(&fd, rp.Mount(), vfsd) + fd.flags = flags + return &fd.vfsfd, nil + case *symlink: + if flags&linux.O_PATH == 0 { + // Can't open symlinks without O_PATH. + return nil, syserror.ELOOP + } + var fd symlinkFD + fd.flags = flags + fd.vfsfd.Init(&fd, rp.Mount(), vfsd) + return &fd.vfsfd, nil + default: + panic(fmt.Sprintf("unknown inode type: %T", in.impl)) + } +} + +func (in *inode) checkPermissions(creds *auth.Credentials, ats vfs.AccessTypes) error { + return vfs.GenericCheckPermissions(creds, ats, in.isDir(), uint16(in.diskInode.Mode()), in.diskInode.UID(), in.diskInode.GID()) +} + +// statTo writes the statx fields to the output parameter. 
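+// A typical call site (see Stat in file_description.go) looks like:
+//
+//	var stat linux.Statx
+//	fd.inode().statTo(&stat)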
+func (in *inode) statTo(stat *linux.Statx) {
+	stat.Mask = linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_NLINK |
+		linux.STATX_UID | linux.STATX_GID | linux.STATX_INO | linux.STATX_SIZE |
+		linux.STATX_ATIME | linux.STATX_CTIME | linux.STATX_MTIME
+	stat.Blksize = uint32(in.blkSize)
+	stat.Mode = uint16(in.diskInode.Mode())
+	stat.Nlink = uint32(in.diskInode.LinksCount())
+	stat.UID = uint32(in.diskInode.UID())
+	stat.GID = uint32(in.diskInode.GID())
+	stat.Ino = uint64(in.inodeNum)
+	stat.Size = in.diskInode.Size()
+	stat.Atime = in.diskInode.AccessTime().StatxTimestamp()
+	stat.Ctime = in.diskInode.ChangeTime().StatxTimestamp()
+	stat.Mtime = in.diskInode.ModificationTime().StatxTimestamp()
+	// TODO(b/134676337): Set stat.Blocks which is the number of 512 byte blocks
+	// (including metadata blocks) required to represent this file.
+}
+
+// getBGNum returns the block group number that a given inode belongs to.
+func getBGNum(inodeNum uint32, inodesPerGrp uint32) uint32 {
+	return (inodeNum - 1) / inodesPerGrp
+}
+
+// getBGOff returns the offset at which the given inode lives in the block
+// group's inode table, i.e. the index of the inode in the inode table.
+func getBGOff(inodeNum uint32, inodesPerGrp uint32) uint32 {
+	return (inodeNum - 1) % inodesPerGrp
+}
diff --git a/pkg/sentry/fsimpl/ext/regular_file.go b/pkg/sentry/fsimpl/ext/regular_file.go
new file mode 100644
index 000000000..ffc76ba5b
--- /dev/null
+++ b/pkg/sentry/fsimpl/ext/regular_file.go
@@ -0,0 +1,159 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ext
+
+import (
+	"io"
+	"sync"
+
+	"gvisor.dev/gvisor/pkg/abi/linux"
+	"gvisor.dev/gvisor/pkg/sentry/context"
+	"gvisor.dev/gvisor/pkg/sentry/memmap"
+	"gvisor.dev/gvisor/pkg/sentry/safemem"
+	"gvisor.dev/gvisor/pkg/sentry/usermem"
+	"gvisor.dev/gvisor/pkg/sentry/vfs"
+	"gvisor.dev/gvisor/pkg/syserror"
+)
+
+// regularFile represents a regular file's inode. This too follows the
+// inheritance pattern prevalent in the vfs layer described in
+// pkg/sentry/vfs/README.md.
+type regularFile struct {
+	inode inode
+
+	// This is immutable. The first field of fileReader implementations must be
+	// regularFile to ensure temporality.
+	// io.ReaderAt is more strict than io.Reader in the sense that a partial read
+	// is always accompanied by an error. If a read spans past the end of file, a
+	// partial read (within file range) is done and io.EOF is returned.
+	impl io.ReaderAt
+}
+
+// newRegularFile is the regularFile constructor. It figures out what kind of
+// file this is and initializes the fileReader.
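+// Whether the file uses extents or the older block maps is determined from
+// the inode flags (see the Extents check below).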
+func newRegularFile(inode inode) (*regularFile, error) {
+	regFile := regularFile{
+		inode: inode,
+	}
+
+	inodeFlags := inode.diskInode.Flags()
+
+	if inodeFlags.Extents {
+		file, err := newExtentFile(regFile)
+		if err != nil {
+			return nil, err
+		}
+
+		file.regFile.inode.impl = &file.regFile
+		return &file.regFile, nil
+	}
+
+	file, err := newBlockMapFile(regFile)
+	if err != nil {
+		return nil, err
+	}
+	file.regFile.inode.impl = &file.regFile
+	return &file.regFile, nil
+}
+
+func (in *inode) isRegular() bool {
+	_, ok := in.impl.(*regularFile)
+	return ok
+}
+
+// regularFileFD represents a regular file's file description. It implements
+// vfs.FileDescriptionImpl.
+type regularFileFD struct {
+	fileDescription
+
+	// off is the file offset. off is accessed using atomic memory operations.
+	off int64
+
+	// offMu serializes operations that may mutate off.
+	offMu sync.Mutex
+}
+
+// Release implements vfs.FileDescriptionImpl.Release.
+func (fd *regularFileFD) Release() {}
+
+// PRead implements vfs.FileDescriptionImpl.PRead.
+func (fd *regularFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
+	safeReader := safemem.FromIOReaderAt{
+		ReaderAt: fd.inode().impl.(*regularFile).impl,
+		Offset:   offset,
+	}
+
+	// Copies data from disk directly into usermem without any intermediate
+	// allocations (if dst is converted into BlockSeq such that it does not need
+	// safe copying).
+	return dst.CopyOutFrom(ctx, safeReader)
+}
+
+// Read implements vfs.FileDescriptionImpl.Read.
+func (fd *regularFileFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
+	n, err := fd.PRead(ctx, dst, fd.off, opts)
+	fd.offMu.Lock()
+	fd.off += n
+	fd.offMu.Unlock()
+	return n, err
+}
+
+// PWrite implements vfs.FileDescriptionImpl.PWrite.
+func (fd *regularFileFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
+	// write(2) specifies that EBADF must be returned if the fd is not open for
+	// writing.
+	return 0, syserror.EBADF
+}
+
+// Write implements vfs.FileDescriptionImpl.Write.
+func (fd *regularFileFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
+	n, err := fd.PWrite(ctx, src, fd.off, opts)
+	fd.offMu.Lock()
+	fd.off += n
+	fd.offMu.Unlock()
+	return n, err
+}
+
+// IterDirents implements vfs.FileDescriptionImpl.IterDirents.
+func (fd *regularFileFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback) error {
+	return syserror.ENOTDIR
+}
+
+// Seek implements vfs.FileDescriptionImpl.Seek.
+func (fd *regularFileFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) {
+	fd.offMu.Lock()
+	defer fd.offMu.Unlock()
+	switch whence {
+	case linux.SEEK_SET:
+		// Use offset as specified.
+	case linux.SEEK_CUR:
+		offset += fd.off
+	case linux.SEEK_END:
+		offset += int64(fd.inode().diskInode.Size())
+	default:
+		return 0, syserror.EINVAL
+	}
+	if offset < 0 {
+		return 0, syserror.EINVAL
+	}
+	fd.off = offset
+	return offset, nil
+}
+
+// ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap.
+func (fd *regularFileFD) ConfigureMMap(ctx context.Context, opts memmap.MMapOpts) error {
+	// TODO(b/134676337): Implement mmap(2).
+	return syserror.ENODEV
+}
diff --git a/pkg/sentry/fsimpl/ext/symlink.go b/pkg/sentry/fsimpl/ext/symlink.go
new file mode 100644
index 000000000..e06548a98
--- /dev/null
+++ b/pkg/sentry/fsimpl/ext/symlink.go
@@ -0,0 +1,111 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ext
+
+import (
+	"gvisor.dev/gvisor/pkg/sentry/context"
+	"gvisor.dev/gvisor/pkg/sentry/memmap"
+	"gvisor.dev/gvisor/pkg/sentry/usermem"
+	"gvisor.dev/gvisor/pkg/sentry/vfs"
+	"gvisor.dev/gvisor/pkg/syserror"
+)
+
+// symlink represents a symlink inode.
+type symlink struct {
+	inode  inode
+	target string // immutable
+}
+
+// newSymlink is the symlink constructor. It reads out the symlink target from
+// the inode (however it might have been stored).
+func newSymlink(inode inode) (*symlink, error) {
+	var file *symlink
+	var link []byte
+
+	// If the symlink target is less than 60 bytes, it is stored in
+	// inode.Data(). Otherwise either extents or block maps will be used to
+	// store the link.
+	size := inode.diskInode.Size()
+	if size < 60 {
+		link = inode.diskInode.Data()[:size]
+	} else {
+		// Create a regular file out of this inode and read out the target.
+		regFile, err := newRegularFile(inode)
+		if err != nil {
+			return nil, err
+		}
+
+		link = make([]byte, size)
+		if n, err := regFile.impl.ReadAt(link, 0); uint64(n) < size {
+			return nil, err
+		}
+	}
+
+	file = &symlink{inode: inode, target: string(link)}
+	file.inode.impl = file
+	return file, nil
+}
+
+func (in *inode) isSymlink() bool {
+	_, ok := in.impl.(*symlink)
+	return ok
+}
+
+// symlinkFD represents a symlink file description and implements
+// vfs.FileDescriptionImpl, which may only be used if the open options contain
+// O_PATH. For this reason most of the functions return EBADF.
+type symlinkFD struct {
+	fileDescription
+}
+
+// Compiles only if symlinkFD implements vfs.FileDescriptionImpl.
+var _ vfs.FileDescriptionImpl = (*symlinkFD)(nil)
+
+// Release implements vfs.FileDescriptionImpl.Release.
+func (fd *symlinkFD) Release() {}
+
+// PRead implements vfs.FileDescriptionImpl.PRead.
+func (fd *symlinkFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
+	return 0, syserror.EBADF
+}
+
+// Read implements vfs.FileDescriptionImpl.Read.
+func (fd *symlinkFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
+	return 0, syserror.EBADF
+}
+
+// PWrite implements vfs.FileDescriptionImpl.PWrite.
+func (fd *symlinkFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
+	return 0, syserror.EBADF
+}
+
+// Write implements vfs.FileDescriptionImpl.Write.
+func (fd *symlinkFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
+	return 0, syserror.EBADF
+}
+
+// IterDirents implements vfs.FileDescriptionImpl.IterDirents.
+func (fd *symlinkFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback) error {
+	return syserror.ENOTDIR
+}
+
+// Seek implements vfs.FileDescriptionImpl.Seek.
+func (fd *symlinkFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) {
+	return 0, syserror.EBADF
+}
+
+// ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap.
+func (fd *symlinkFD) ConfigureMMap(ctx context.Context, opts memmap.MMapOpts) error {
+	return syserror.EBADF
+}
diff --git a/pkg/sentry/fs/ext/utils.go b/pkg/sentry/fsimpl/ext/utils.go
index 3472c5fa8..d8b728f8c 100644
--- a/pkg/sentry/fs/ext/utils.go
+++ b/pkg/sentry/fsimpl/ext/utils.go
@@ -15,38 +15,30 @@
 package ext
 
 import (
-	"encoding/binary"
 	"io"
 
-	"gvisor.dev/gvisor/pkg/sentry/fs/ext/disklayout"
+	"gvisor.dev/gvisor/pkg/binary"
+	"gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout"
 	"gvisor.dev/gvisor/pkg/syserror"
 )
 
 // readFromDisk performs a binary read from disk into the given struct from
 // the absolute offset provided.
-//
-// All disk reads should use this helper so we avoid reading from stale
-// previously used offsets. This function forces the offset parameter.
-//
-// Precondition: Must hold the mutex of the filesystem containing dev.
-func readFromDisk(dev io.ReadSeeker, abOff int64, v interface{}) error {
-	if _, err := dev.Seek(abOff, io.SeekStart); err != nil {
-		return syserror.EIO
-	}
-
-	if err := binary.Read(dev, binary.LittleEndian, v); err != nil {
+func readFromDisk(dev io.ReaderAt, abOff int64, v interface{}) error {
+	n := binary.Size(v)
+	buf := make([]byte, n)
+	if read, _ := dev.ReadAt(buf, abOff); read < int(n) {
 		return syserror.EIO
 	}
+	binary.Unmarshal(buf, binary.LittleEndian, v)
 	return nil
 }
 
 // readSuperBlock reads the SuperBlock from block group 0 in the underlying
 // device. There are three versions of the superblock. This function identifies
 // and returns the correct version.
-//
-// Precondition: Must hold the mutex of the filesystem containing dev.
-func readSuperBlock(dev io.ReadSeeker) (disklayout.SuperBlock, error) {
+func readSuperBlock(dev io.ReaderAt) (disklayout.SuperBlock, error) {
 	var sb disklayout.SuperBlock = &disklayout.SuperBlockOld{}
 	if err := readFromDisk(dev, disklayout.SbOffset, sb); err != nil {
 		return nil, err
@@ -76,19 +68,12 @@ func blockGroupsCount(sb disklayout.SuperBlock) uint64 {
 	blocksPerGroup := uint64(sb.BlocksPerGroup())
 
 	// Round up the result. float64 can compromise precision so do it manually.
-	bgCount := blocksCount / blocksPerGroup
-	if blocksCount%blocksPerGroup != 0 {
-		bgCount++
-	}
-
-	return bgCount
+	return (blocksCount + blocksPerGroup - 1) / blocksPerGroup
 }
 
 // readBlockGroups reads the block group descriptor table from block group 0 in
 // the underlying device.
-//
-// Precondition: Must hold the mutex of the filesystem containing dev.
-func readBlockGroups(dev io.ReadSeeker, sb disklayout.SuperBlock) ([]disklayout.BlockGroup, error) { +func readBlockGroups(dev io.ReaderAt, sb disklayout.SuperBlock) ([]disklayout.BlockGroup, error) { bgCount := blockGroupsCount(sb) bgdSize := uint64(sb.BgDescSize()) is64Bit := sb.IncompatibleFeatures().Is64Bit diff --git a/pkg/sentry/fsimpl/memfs/BUILD b/pkg/sentry/fsimpl/memfs/BUILD index d5d4f68df..d2450e810 100644 --- a/pkg/sentry/fsimpl/memfs/BUILD +++ b/pkg/sentry/fsimpl/memfs/BUILD @@ -11,8 +11,8 @@ go_template_instance( prefix = "dentry", template = "//pkg/ilist:generic_list", types = { - "Element": "*Dentry", - "Linker": "*Dentry", + "Element": "*dentry", + "Linker": "*dentry", }, ) diff --git a/pkg/sentry/fsimpl/memfs/directory.go b/pkg/sentry/fsimpl/memfs/directory.go index b0c3ea39a..c52dc781c 100644 --- a/pkg/sentry/fsimpl/memfs/directory.go +++ b/pkg/sentry/fsimpl/memfs/directory.go @@ -23,23 +23,23 @@ import ( ) type directory struct { - inode Inode + inode inode // childList is a list containing (1) child Dentries and (2) fake Dentries // (with inode == nil) that represent the iteration position of // directoryFDs. childList is used to support directoryFD.IterDirents() - // efficiently. childList is protected by Filesystem.mu. + // efficiently. childList is protected by filesystem.mu. childList dentryList } -func (fs *Filesystem) newDirectory(creds *auth.Credentials, mode uint16) *Inode { +func (fs *filesystem) newDirectory(creds *auth.Credentials, mode uint16) *inode { dir := &directory{} dir.inode.init(dir, fs, creds, mode) dir.inode.nlink = 2 // from "." and parent directory or ".." for root return &dir.inode } -func (i *Inode) isDir() bool { +func (i *inode) isDir() bool { _, ok := i.impl.(*directory) return ok } @@ -48,8 +48,8 @@ type directoryFD struct { fileDescription vfs.DirectoryFileDescriptionDefaultImpl - // Protected by Filesystem.mu. - iter *Dentry + // Protected by filesystem.mu. + iter *dentry off int64 } @@ -68,7 +68,7 @@ func (fd *directoryFD) Release() { // IterDirents implements vfs.FileDescriptionImpl.IterDirents. func (fd *directoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback) error { fs := fd.filesystem() - d := fd.vfsfd.VirtualDentry().Dentry() + vfsd := fd.vfsfd.VirtualDentry().Dentry() fs.mu.Lock() defer fs.mu.Unlock() @@ -77,7 +77,7 @@ func (fd *directoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallba if !cb.Handle(vfs.Dirent{ Name: ".", Type: linux.DT_DIR, - Ino: d.Impl().(*Dentry).inode.ino, + Ino: vfsd.Impl().(*dentry).inode.ino, Off: 0, }) { return nil @@ -85,7 +85,7 @@ func (fd *directoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallba fd.off++ } if fd.off == 1 { - parentInode := d.ParentOrSelf().Impl().(*Dentry).inode + parentInode := vfsd.ParentOrSelf().Impl().(*dentry).inode if !cb.Handle(vfs.Dirent{ Name: "..", Type: parentInode.direntType(), @@ -97,12 +97,12 @@ func (fd *directoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallba fd.off++ } - dir := d.Impl().(*Dentry).inode.impl.(*directory) - var child *Dentry + dir := vfsd.Impl().(*dentry).inode.impl.(*directory) + var child *dentry if fd.iter == nil { // Start iteration at the beginning of dir. child = dir.childList.Front() - fd.iter = &Dentry{} + fd.iter = &dentry{} } else { // Continue iteration from where we left off. child = fd.iter.Next() @@ -130,32 +130,41 @@ func (fd *directoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallba // Seek implements vfs.FileDescriptionImpl.Seek. 
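+// Offsets 0 and 1 correspond to the synthetic "." and ".." entries emitted by
+// IterDirents above; offsets of 2 and above index into childList.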
func (fd *directoryFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) { - if whence != linux.SEEK_SET { - // TODO: Linux also allows SEEK_CUR. + fs := fd.filesystem() + fs.mu.Lock() + defer fs.mu.Unlock() + + switch whence { + case linux.SEEK_SET: + // Use offset as given. + case linux.SEEK_CUR: + offset += fd.off + default: return 0, syserror.EINVAL } if offset < 0 { return 0, syserror.EINVAL } + // If the offset isn't changing (e.g. due to lseek(0, SEEK_CUR)), don't + // seek even if doing so might reposition the iterator due to concurrent + // mutation of the directory. Compare fs/libfs.c:dcache_dir_lseek(). + if fd.off == offset { + return offset, nil + } + fd.off = offset // Compensate for "." and "..". - var remChildren int64 - if offset < 2 { - remChildren = 0 - } else { + remChildren := int64(0) + if offset >= 2 { remChildren = offset - 2 } - fs := fd.filesystem() dir := fd.inode().impl.(*directory) - fs.mu.Lock() - defer fs.mu.Unlock() - // Ensure that fd.iter exists and is not linked into dir.childList. if fd.iter == nil { - fd.iter = &Dentry{} + fd.iter = &dentry{} } else { dir.childList.Remove(fd.iter) } diff --git a/pkg/sentry/fsimpl/memfs/filesystem.go b/pkg/sentry/fsimpl/memfs/filesystem.go index 4d989eeaf..f79e2d9c8 100644 --- a/pkg/sentry/fsimpl/memfs/filesystem.go +++ b/pkg/sentry/fsimpl/memfs/filesystem.go @@ -28,9 +28,9 @@ import ( // // stepLocked is loosely analogous to fs/namei.c:walk_component(). // -// Preconditions: Filesystem.mu must be locked. !rp.Done(). inode == -// vfsd.Impl().(*Dentry).inode. -func stepLocked(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, inode *Inode) (*vfs.Dentry, *Inode, error) { +// Preconditions: filesystem.mu must be locked. !rp.Done(). inode == +// vfsd.Impl().(*dentry).inode. +func stepLocked(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, inode *inode) (*vfs.Dentry, *inode, error) { if !inode.isDir() { return nil, nil, syserror.ENOTDIR } @@ -47,7 +47,7 @@ afterSymlink: // not in the Dentry tree, it doesn't exist. return nil, nil, syserror.ENOENT } - nextInode := nextVFSD.Impl().(*Dentry).inode + nextInode := nextVFSD.Impl().(*dentry).inode if symlink, ok := nextInode.impl.(*symlink); ok && rp.ShouldFollowSymlink() { // TODO: symlink traversals update access time if err := rp.HandleSymlink(symlink.target); err != nil { @@ -64,10 +64,10 @@ afterSymlink: // walkExistingLocked is loosely analogous to Linux's // fs/namei.c:path_lookupat(). // -// Preconditions: Filesystem.mu must be locked. -func walkExistingLocked(rp *vfs.ResolvingPath) (*vfs.Dentry, *Inode, error) { +// Preconditions: filesystem.mu must be locked. +func walkExistingLocked(rp *vfs.ResolvingPath) (*vfs.Dentry, *inode, error) { vfsd := rp.Start() - inode := vfsd.Impl().(*Dentry).inode + inode := vfsd.Impl().(*dentry).inode for !rp.Done() { var err error vfsd, inode, err = stepLocked(rp, vfsd, inode) @@ -88,10 +88,10 @@ func walkExistingLocked(rp *vfs.ResolvingPath) (*vfs.Dentry, *Inode, error) { // walkParentDirLocked is loosely analogous to Linux's // fs/namei.c:path_parentat(). // -// Preconditions: Filesystem.mu must be locked. !rp.Done(). -func walkParentDirLocked(rp *vfs.ResolvingPath) (*vfs.Dentry, *Inode, error) { +// Preconditions: filesystem.mu must be locked. !rp.Done(). 
+func walkParentDirLocked(rp *vfs.ResolvingPath) (*vfs.Dentry, *inode, error) { vfsd := rp.Start() - inode := vfsd.Impl().(*Dentry).inode + inode := vfsd.Impl().(*dentry).inode for !rp.Final() { var err error vfsd, inode, err = stepLocked(rp, vfsd, inode) @@ -108,9 +108,9 @@ func walkParentDirLocked(rp *vfs.ResolvingPath) (*vfs.Dentry, *Inode, error) { // checkCreateLocked checks that a file named rp.Component() may be created in // directory parentVFSD, then returns rp.Component(). // -// Preconditions: Filesystem.mu must be locked. parentInode == -// parentVFSD.Impl().(*Dentry).inode. parentInode.isDir() == true. -func checkCreateLocked(rp *vfs.ResolvingPath, parentVFSD *vfs.Dentry, parentInode *Inode) (string, error) { +// Preconditions: filesystem.mu must be locked. parentInode == +// parentVFSD.Impl().(*dentry).inode. parentInode.isDir() == true. +func checkCreateLocked(rp *vfs.ResolvingPath, parentVFSD *vfs.Dentry, parentInode *inode) (string, error) { if err := parentInode.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec, true); err != nil { return "", err } @@ -144,7 +144,7 @@ func checkDeleteLocked(vfsd *vfs.Dentry) error { } // GetDentryAt implements vfs.FilesystemImpl.GetDentryAt. -func (fs *Filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetDentryOptions) (*vfs.Dentry, error) { +func (fs *filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetDentryOptions) (*vfs.Dentry, error) { fs.mu.RLock() defer fs.mu.RUnlock() vfsd, inode, err := walkExistingLocked(rp) @@ -164,7 +164,7 @@ func (fs *Filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, op } // LinkAt implements vfs.FilesystemImpl.LinkAt. -func (fs *Filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs.VirtualDentry) error { +func (fs *filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs.VirtualDentry) error { if rp.Done() { return syserror.EEXIST } @@ -185,7 +185,7 @@ func (fs *Filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs. return err } defer rp.Mount().EndWrite() - d := vd.Dentry().Impl().(*Dentry) + d := vd.Dentry().Impl().(*dentry) if d.inode.isDir() { return syserror.EPERM } @@ -197,7 +197,7 @@ func (fs *Filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs. } // MkdirAt implements vfs.FilesystemImpl.MkdirAt. -func (fs *Filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MkdirOptions) error { +func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MkdirOptions) error { if rp.Done() { return syserror.EEXIST } @@ -223,7 +223,7 @@ func (fs *Filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts v } // MknodAt implements vfs.FilesystemImpl.MknodAt. -func (fs *Filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MknodOptions) error { +func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MknodOptions) error { if rp.Done() { return syserror.EEXIST } @@ -246,7 +246,7 @@ func (fs *Filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v } // OpenAt implements vfs.FilesystemImpl.OpenAt. -func (fs *Filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.OpenOptions) (*vfs.FileDescription, error) { +func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.OpenOptions) (*vfs.FileDescription, error) { // Filter out flags that are not supported by memfs. 
O_DIRECTORY and // O_NOFOLLOW have no effect here (they're handled by VFS by setting // appropriate bits in rp), but are returned by @@ -265,11 +265,10 @@ func (fs *Filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf mustCreate := opts.Flags&linux.O_EXCL != 0 vfsd := rp.Start() - inode := vfsd.Impl().(*Dentry).inode + inode := vfsd.Impl().(*dentry).inode fs.mu.Lock() defer fs.mu.Unlock() if rp.Done() { - // FIXME: ??? if rp.MustBeDir() { return nil, syserror.EISDIR } @@ -327,7 +326,7 @@ afterTrailingSymlink: if mustCreate { return nil, syserror.EEXIST } - childInode := childVFSD.Impl().(*Dentry).inode + childInode := childVFSD.Impl().(*dentry).inode if symlink, ok := childInode.impl.(*symlink); ok && rp.ShouldFollowSymlink() { // TODO: symlink traversals update access time if err := rp.HandleSymlink(symlink.target); err != nil { @@ -340,7 +339,7 @@ afterTrailingSymlink: return childInode.open(rp, childVFSD, opts.Flags, false) } -func (i *Inode) open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32, afterCreate bool) (*vfs.FileDescription, error) { +func (i *inode) open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32, afterCreate bool) (*vfs.FileDescription, error) { ats := vfs.AccessTypesForOpenFlags(flags) if !afterCreate { if err := i.checkPermissions(rp.Credentials(), ats, i.isDir()); err != nil { @@ -385,7 +384,7 @@ func (i *Inode) open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32, afte } // ReadlinkAt implements vfs.FilesystemImpl.ReadlinkAt. -func (fs *Filesystem) ReadlinkAt(ctx context.Context, rp *vfs.ResolvingPath) (string, error) { +func (fs *filesystem) ReadlinkAt(ctx context.Context, rp *vfs.ResolvingPath) (string, error) { fs.mu.RLock() _, inode, err := walkExistingLocked(rp) fs.mu.RUnlock() @@ -400,9 +399,8 @@ func (fs *Filesystem) ReadlinkAt(ctx context.Context, rp *vfs.ResolvingPath) (st } // RenameAt implements vfs.FilesystemImpl.RenameAt. -func (fs *Filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs.VirtualDentry, opts vfs.RenameOptions) error { +func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs.VirtualDentry, opts vfs.RenameOptions) error { if rp.Done() { - // FIXME return syserror.ENOENT } fs.mu.Lock() @@ -424,7 +422,7 @@ func (fs *Filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, vd vf } // RmdirAt implements vfs.FilesystemImpl.RmdirAt. -func (fs *Filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error { +func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error { fs.mu.Lock() defer fs.mu.Unlock() vfsd, inode, err := walkExistingLocked(rp) @@ -447,12 +445,14 @@ func (fs *Filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error if err := rp.VirtualFilesystem().DeleteDentry(vfs.MountNamespaceFromContext(ctx), vfsd); err != nil { return err } + // Remove from parent directory's childList. + vfsd.Parent().Impl().(*dentry).inode.impl.(*directory).childList.Remove(vfsd.Impl().(*dentry)) inode.decRef() return nil } // SetStatAt implements vfs.FilesystemImpl.SetStatAt. 
-func (fs *Filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetStatOptions) error { +func (fs *filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetStatOptions) error { fs.mu.RLock() _, _, err := walkExistingLocked(rp) fs.mu.RUnlock() @@ -462,12 +462,12 @@ func (fs *Filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts if opts.Stat.Mask == 0 { return nil } - // TODO: implement Inode.setStat + // TODO: implement inode.setStat return syserror.EPERM } // StatAt implements vfs.FilesystemImpl.StatAt. -func (fs *Filesystem) StatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.StatOptions) (linux.Statx, error) { +func (fs *filesystem) StatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.StatOptions) (linux.Statx, error) { fs.mu.RLock() _, inode, err := walkExistingLocked(rp) fs.mu.RUnlock() @@ -480,7 +480,7 @@ func (fs *Filesystem) StatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf } // StatFSAt implements vfs.FilesystemImpl.StatFSAt. -func (fs *Filesystem) StatFSAt(ctx context.Context, rp *vfs.ResolvingPath) (linux.Statfs, error) { +func (fs *filesystem) StatFSAt(ctx context.Context, rp *vfs.ResolvingPath) (linux.Statfs, error) { fs.mu.RLock() _, _, err := walkExistingLocked(rp) fs.mu.RUnlock() @@ -492,7 +492,7 @@ func (fs *Filesystem) StatFSAt(ctx context.Context, rp *vfs.ResolvingPath) (linu } // SymlinkAt implements vfs.FilesystemImpl.SymlinkAt. -func (fs *Filesystem) SymlinkAt(ctx context.Context, rp *vfs.ResolvingPath, target string) error { +func (fs *filesystem) SymlinkAt(ctx context.Context, rp *vfs.ResolvingPath, target string) error { if rp.Done() { return syserror.EEXIST } @@ -517,7 +517,7 @@ func (fs *Filesystem) SymlinkAt(ctx context.Context, rp *vfs.ResolvingPath, targ } // UnlinkAt implements vfs.FilesystemImpl.UnlinkAt. -func (fs *Filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error { +func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error { fs.mu.Lock() defer fs.mu.Unlock() vfsd, inode, err := walkExistingLocked(rp) @@ -537,6 +537,8 @@ func (fs *Filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error if err := rp.VirtualFilesystem().DeleteDentry(vfs.MountNamespaceFromContext(ctx), vfsd); err != nil { return err } + // Remove from parent directory's childList. + vfsd.Parent().Impl().(*dentry).inode.impl.(*directory).childList.Remove(vfsd.Impl().(*dentry)) inode.decLinksLocked() return nil } diff --git a/pkg/sentry/fsimpl/memfs/memfs.go b/pkg/sentry/fsimpl/memfs/memfs.go index f381e1a88..45cd42b3e 100644 --- a/pkg/sentry/fsimpl/memfs/memfs.go +++ b/pkg/sentry/fsimpl/memfs/memfs.go @@ -21,10 +21,10 @@ // // Lock order: // -// Filesystem.mu +// filesystem.mu // regularFileFD.offMu // regularFile.mu -// Inode.mu +// inode.mu package memfs import ( @@ -42,8 +42,8 @@ import ( // FilesystemType implements vfs.FilesystemType. type FilesystemType struct{} -// Filesystem implements vfs.FilesystemImpl. -type Filesystem struct { +// filesystem implements vfs.FilesystemImpl. +type filesystem struct { vfsfs vfs.Filesystem // mu serializes changes to the Dentry tree. @@ -54,44 +54,44 @@ type Filesystem struct { // NewFilesystem implements vfs.FilesystemType.NewFilesystem. 
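+// The root directory is created with mode 01777 (world-writable with the
+// sticky bit set), matching tmpfs defaults.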
func (fstype FilesystemType) NewFilesystem(ctx context.Context, creds *auth.Credentials, source string, opts vfs.NewFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) { - var fs Filesystem + var fs filesystem fs.vfsfs.Init(&fs) root := fs.newDentry(fs.newDirectory(creds, 01777)) return &fs.vfsfs, &root.vfsd, nil } // Release implements vfs.FilesystemImpl.Release. -func (fs *Filesystem) Release() { +func (fs *filesystem) Release() { } // Sync implements vfs.FilesystemImpl.Sync. -func (fs *Filesystem) Sync(ctx context.Context) error { +func (fs *filesystem) Sync(ctx context.Context) error { // All filesystem state is in-memory. return nil } -// Dentry implements vfs.DentryImpl. -type Dentry struct { +// dentry implements vfs.DentryImpl. +type dentry struct { vfsd vfs.Dentry - // inode is the inode represented by this Dentry. Multiple Dentries may - // share a single non-directory Inode (with hard links). inode is + // inode is the inode represented by this dentry. Multiple Dentries may + // share a single non-directory inode (with hard links). inode is // immutable. - inode *Inode + inode *inode - // memfs doesn't count references on Dentries; because the Dentry tree is + // memfs doesn't count references on dentries; because the dentry tree is // the sole source of truth, it is by definition always consistent with the - // state of the filesystem. However, it does count references on Inodes, - // because Inode resources are released when all references are dropped. + // state of the filesystem. However, it does count references on inodes, + // because inode resources are released when all references are dropped. // (memfs doesn't really have resources to release, but we implement // reference counting because tmpfs regular files will.) - // dentryEntry (ugh) links Dentries into their parent directory.childList. + // dentryEntry (ugh) links dentries into their parent directory.childList. dentryEntry } -func (fs *Filesystem) newDentry(inode *Inode) *Dentry { - d := &Dentry{ +func (fs *filesystem) newDentry(inode *inode) *dentry { + d := &dentry{ inode: inode, } d.vfsd.Init(d) @@ -99,37 +99,37 @@ func (fs *Filesystem) newDentry(inode *Inode) *Dentry { } // IncRef implements vfs.DentryImpl.IncRef. -func (d *Dentry) IncRef(vfsfs *vfs.Filesystem) { +func (d *dentry) IncRef(vfsfs *vfs.Filesystem) { d.inode.incRef() } // TryIncRef implements vfs.DentryImpl.TryIncRef. -func (d *Dentry) TryIncRef(vfsfs *vfs.Filesystem) bool { +func (d *dentry) TryIncRef(vfsfs *vfs.Filesystem) bool { return d.inode.tryIncRef() } // DecRef implements vfs.DentryImpl.DecRef. -func (d *Dentry) DecRef(vfsfs *vfs.Filesystem) { +func (d *dentry) DecRef(vfsfs *vfs.Filesystem) { d.inode.decRef() } -// Inode represents a filesystem object. -type Inode struct { +// inode represents a filesystem object. +type inode struct { // refs is a reference count. refs is accessed using atomic memory // operations. // - // A reference is held on all Inodes that are reachable in the filesystem + // A reference is held on all inodes that are reachable in the filesystem // tree. For non-directories (which may have multiple hard links), this // means that a reference is dropped when nlink reaches 0. For directories, // nlink never reaches 0 due to the "." entry; instead, - // Filesystem.RmdirAt() drops the reference. + // filesystem.RmdirAt() drops the reference. refs int64 // Inode metadata; protected by mu and accessed using atomic memory // operations unless otherwise specified. 
mu sync.RWMutex mode uint32 // excluding file type bits, which are based on impl - nlink uint32 // protected by Filesystem.mu instead of Inode.mu + nlink uint32 // protected by filesystem.mu instead of inode.mu uid uint32 // auth.KUID, but stored as raw uint32 for sync/atomic gid uint32 // auth.KGID, but ... ino uint64 // immutable @@ -137,7 +137,7 @@ type Inode struct { impl interface{} // immutable } -func (i *Inode) init(impl interface{}, fs *Filesystem, creds *auth.Credentials, mode uint16) { +func (i *inode) init(impl interface{}, fs *filesystem, creds *auth.Credentials, mode uint16) { i.refs = 1 i.mode = uint32(mode) i.uid = uint32(creds.EffectiveKUID) @@ -147,29 +147,29 @@ func (i *Inode) init(impl interface{}, fs *Filesystem, creds *auth.Credentials, i.impl = impl } -// Preconditions: Filesystem.mu must be locked for writing. -func (i *Inode) incLinksLocked() { +// Preconditions: filesystem.mu must be locked for writing. +func (i *inode) incLinksLocked() { if atomic.AddUint32(&i.nlink, 1) <= 1 { - panic("memfs.Inode.incLinksLocked() called with no existing links") + panic("memfs.inode.incLinksLocked() called with no existing links") } } -// Preconditions: Filesystem.mu must be locked for writing. -func (i *Inode) decLinksLocked() { +// Preconditions: filesystem.mu must be locked for writing. +func (i *inode) decLinksLocked() { if nlink := atomic.AddUint32(&i.nlink, ^uint32(0)); nlink == 0 { i.decRef() } else if nlink == ^uint32(0) { // negative overflow - panic("memfs.Inode.decLinksLocked() called with no existing links") + panic("memfs.inode.decLinksLocked() called with no existing links") } } -func (i *Inode) incRef() { +func (i *inode) incRef() { if atomic.AddInt64(&i.refs, 1) <= 1 { - panic("memfs.Inode.incRef() called without holding a reference") + panic("memfs.inode.incRef() called without holding a reference") } } -func (i *Inode) tryIncRef() bool { +func (i *inode) tryIncRef() bool { for { refs := atomic.LoadInt64(&i.refs) if refs == 0 { @@ -181,7 +181,7 @@ func (i *Inode) tryIncRef() bool { } } -func (i *Inode) decRef() { +func (i *inode) decRef() { if refs := atomic.AddInt64(&i.refs, -1); refs == 0 { // This is unnecessary; it's mostly to simulate what tmpfs would do. if regfile, ok := i.impl.(*regularFile); ok { @@ -191,18 +191,18 @@ func (i *Inode) decRef() { regfile.mu.Unlock() } } else if refs < 0 { - panic("memfs.Inode.decRef() called without holding a reference") + panic("memfs.inode.decRef() called without holding a reference") } } -func (i *Inode) checkPermissions(creds *auth.Credentials, ats vfs.AccessTypes, isDir bool) error { +func (i *inode) checkPermissions(creds *auth.Credentials, ats vfs.AccessTypes, isDir bool) error { return vfs.GenericCheckPermissions(creds, ats, isDir, uint16(atomic.LoadUint32(&i.mode)), auth.KUID(atomic.LoadUint32(&i.uid)), auth.KGID(atomic.LoadUint32(&i.gid))) } // Go won't inline this function, and returning linux.Statx (which is quite // big) means spending a lot of time in runtime.duffcopy(), so instead it's an // output parameter. 
-func (i *Inode) statTo(stat *linux.Statx) { +func (i *inode) statTo(stat *linux.Statx) { stat.Mask = linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_NLINK | linux.STATX_UID | linux.STATX_GID | linux.STATX_INO stat.Blksize = 1 // usermem.PageSize in tmpfs stat.Nlink = atomic.LoadUint32(&i.nlink) @@ -241,7 +241,7 @@ func allocatedBlocksForSize(size uint64) uint64 { return (size + 511) / 512 } -func (i *Inode) direntType() uint8 { +func (i *inode) direntType() uint8 { switch i.impl.(type) { case *regularFile: return linux.DT_REG @@ -258,16 +258,17 @@ func (i *Inode) direntType() uint8 { // vfs.FileDescriptionImpl. type fileDescription struct { vfsfd vfs.FileDescription + vfs.FileDescriptionDefaultImpl flags uint32 // status flags; immutable } -func (fd *fileDescription) filesystem() *Filesystem { - return fd.vfsfd.VirtualDentry().Mount().Filesystem().Impl().(*Filesystem) +func (fd *fileDescription) filesystem() *filesystem { + return fd.vfsfd.VirtualDentry().Mount().Filesystem().Impl().(*filesystem) } -func (fd *fileDescription) inode() *Inode { - return fd.vfsfd.VirtualDentry().Dentry().Impl().(*Dentry).inode +func (fd *fileDescription) inode() *inode { + return fd.vfsfd.VirtualDentry().Dentry().Impl().(*dentry).inode } // StatusFlags implements vfs.FileDescriptionImpl.StatusFlags. @@ -294,6 +295,6 @@ func (fd *fileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions) if opts.Stat.Mask == 0 { return nil } - // TODO: implement Inode.setStat + // TODO: implement inode.setStat return syserror.EPERM } diff --git a/pkg/sentry/fsimpl/memfs/regular_file.go b/pkg/sentry/fsimpl/memfs/regular_file.go index 4a3603cc8..55f869798 100644 --- a/pkg/sentry/fsimpl/memfs/regular_file.go +++ b/pkg/sentry/fsimpl/memfs/regular_file.go @@ -28,16 +28,16 @@ import ( ) type regularFile struct { - inode Inode + inode inode mu sync.RWMutex data []byte // dataLen is len(data), but accessed using atomic memory operations to - // avoid locking in Inode.stat(). + // avoid locking in inode.stat(). dataLen int64 } -func (fs *Filesystem) newRegularFile(creds *auth.Credentials, mode uint16) *Inode { +func (fs *filesystem) newRegularFile(creds *auth.Credentials, mode uint16) *inode { file := ®ularFile{} file.inode.init(file, fs, creds, mode) file.inode.nlink = 1 // from parent directory @@ -46,7 +46,6 @@ func (fs *Filesystem) newRegularFile(creds *auth.Credentials, mode uint16) *Inod type regularFileFD struct { fileDescription - vfs.FileDescriptionDefaultImpl // These are immutable. 
readable bool
diff --git a/pkg/sentry/fsimpl/memfs/symlink.go b/pkg/sentry/fsimpl/memfs/symlink.go
index e002d1727..b2ac2cbeb 100644
--- a/pkg/sentry/fsimpl/memfs/symlink.go
+++ b/pkg/sentry/fsimpl/memfs/symlink.go
@@ -19,11 +19,11 @@ import (
)

type symlink struct {
-	inode Inode
+	inode inode
	target string // immutable
}

-func (fs *Filesystem) newSymlink(creds *auth.Credentials, target string) *Inode {
+func (fs *filesystem) newSymlink(creds *auth.Credentials, target string) *inode {
	link := &symlink{
		target: target,
	}
diff --git a/pkg/sentry/fsimpl/proc/BUILD b/pkg/sentry/fsimpl/proc/BUILD
new file mode 100644
index 000000000..3d8a4deaf
--- /dev/null
+++ b/pkg/sentry/fsimpl/proc/BUILD
@@ -0,0 +1,49 @@
+load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+    name = "proc",
+    srcs = [
+        "filesystems.go",
+        "loadavg.go",
+        "meminfo.go",
+        "mounts.go",
+        "net.go",
+        "proc.go",
+        "stat.go",
+        "sys.go",
+        "task.go",
+        "version.go",
+    ],
+    importpath = "gvisor.dev/gvisor/pkg/sentry/fsimpl/proc",
+    deps = [
+        "//pkg/abi/linux",
+        "//pkg/binary",
+        "//pkg/log",
+        "//pkg/sentry/context",
+        "//pkg/sentry/fs",
+        "//pkg/sentry/inet",
+        "//pkg/sentry/kernel",
+        "//pkg/sentry/limits",
+        "//pkg/sentry/mm",
+        "//pkg/sentry/socket",
+        "//pkg/sentry/socket/unix",
+        "//pkg/sentry/socket/unix/transport",
+        "//pkg/sentry/usage",
+        "//pkg/sentry/usermem",
+        "//pkg/sentry/vfs",
+    ],
+)
+
+go_test(
+    name = "proc_test",
+    size = "small",
+    srcs = ["net_test.go"],
+    embed = [":proc"],
+    deps = [
+        "//pkg/abi/linux",
+        "//pkg/sentry/context/contexttest",
+        "//pkg/sentry/inet",
+    ],
+)
diff --git a/pkg/sentry/fsimpl/proc/filesystems.go b/pkg/sentry/fsimpl/proc/filesystems.go
new file mode 100644
index 000000000..c36c4aff5
--- /dev/null
+++ b/pkg/sentry/fsimpl/proc/filesystems.go
@@ -0,0 +1,25 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+// filesystemsData implements vfs.DynamicBytesSource for /proc/filesystems.
+//
+// +stateify savable
+type filesystemsData struct{}
+
+// TODO(b/138862512): Implement vfs.DynamicBytesSource.Generate for
+// filesystemsData. We would need to retrieve filesystem names from
+// vfs.VirtualFilesystem. Also needs a vfs replacement for
+// fs.Filesystem.AllowUserList() and fs.FilesystemRequiresDev.
diff --git a/pkg/sentry/fsimpl/proc/loadavg.go b/pkg/sentry/fsimpl/proc/loadavg.go
new file mode 100644
index 000000000..9135afef1
--- /dev/null
+++ b/pkg/sentry/fsimpl/proc/loadavg.go
@@ -0,0 +1,40 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+	"bytes"
+	"fmt"
+
+	"gvisor.dev/gvisor/pkg/sentry/context"
+	"gvisor.dev/gvisor/pkg/sentry/vfs"
+)
+
+// loadavgData backs /proc/loadavg.
+//
+// +stateify savable
+type loadavgData struct{}
+
+var _ vfs.DynamicBytesSource = (*loadavgData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (d *loadavgData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+	// TODO(b/62345059): Include real data in fields.
+	// Column 1-3: CPU and IO utilization of the last 1, 5, and 15 minute periods.
+	// Column 4-5: currently running processes and the total number of processes.
+	// Column 6: the last process ID used.
+	fmt.Fprintf(buf, "%.2f %.2f %.2f %d/%d %d\n", 0.00, 0.00, 0.00, 0, 0, 0)
+	return nil
+}
diff --git a/pkg/sentry/fsimpl/proc/meminfo.go b/pkg/sentry/fsimpl/proc/meminfo.go
new file mode 100644
index 000000000..9a827cd66
--- /dev/null
+++ b/pkg/sentry/fsimpl/proc/meminfo.go
@@ -0,0 +1,77 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+	"bytes"
+	"fmt"
+
+	"gvisor.dev/gvisor/pkg/sentry/context"
+	"gvisor.dev/gvisor/pkg/sentry/kernel"
+	"gvisor.dev/gvisor/pkg/sentry/usage"
+	"gvisor.dev/gvisor/pkg/sentry/usermem"
+	"gvisor.dev/gvisor/pkg/sentry/vfs"
+)
+
+// meminfoData implements vfs.DynamicBytesSource for /proc/meminfo.
+//
+// +stateify savable
+type meminfoData struct {
+	// k is the owning Kernel.
+	k *kernel.Kernel
+}
+
+var _ vfs.DynamicBytesSource = (*meminfoData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (d *meminfoData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+	mf := d.k.MemoryFile()
+	mf.UpdateUsage()
+	snapshot, totalUsage := usage.MemoryAccounting.Copy()
+	totalSize := usage.TotalMemory(mf.TotalSize(), totalUsage)
+	anon := snapshot.Anonymous + snapshot.Tmpfs
+	file := snapshot.PageCache + snapshot.Mapped
+	// We don't actually have active/inactive LRUs, so just make up numbers.
+	activeFile := (file / 2) &^ (usermem.PageSize - 1)
+	inactiveFile := file - activeFile
+
+	fmt.Fprintf(buf, "MemTotal: %8d kB\n", totalSize/1024)
+	memFree := (totalSize - totalUsage) / 1024
+	// We use MemFree as MemAvailable because we don't swap.
+	// TODO(rahat): When reclaim is implemented, the value of MemAvailable
+	// should change.
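loadavgData above is the simplest possible vfs.DynamicBytesSource: a stateless type whose Generate rebuilds the whole file body into a buffer on every read. A self-contained sketch of the same contract, using the standard library context package in place of gVisor's sentry/context and redeclaring the interface locally, so nothing here is the real vfs API:

```go
package main

import (
	"bytes"
	"context"
	"fmt"
)

// dynamicBytesSource mirrors the shape of vfs.DynamicBytesSource: file
// contents are regenerated from kernel state on demand.
type dynamicBytesSource interface {
	Generate(ctx context.Context, buf *bytes.Buffer) error
}

// loadavg emits a fixed /proc/loadavg-style line, like the stub in the diff.
type loadavg struct{}

func (loadavg) Generate(ctx context.Context, buf *bytes.Buffer) error {
	_, err := fmt.Fprintf(buf, "%.2f %.2f %.2f %d/%d %d\n", 0.00, 0.00, 0.00, 0, 0, 0)
	return err
}

func main() {
	var src dynamicBytesSource = loadavg{}
	var buf bytes.Buffer
	if err := src.Generate(context.Background(), &buf); err != nil {
		panic(err)
	}
	fmt.Print(buf.String()) // prints "0.00 0.00 0.00 0/0 0"
}
```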
+ fmt.Fprintf(buf, "MemFree: %8d kB\n", memFree) + fmt.Fprintf(buf, "MemAvailable: %8d kB\n", memFree) + fmt.Fprintf(buf, "Buffers: 0 kB\n") // memory usage by block devices + fmt.Fprintf(buf, "Cached: %8d kB\n", (file+snapshot.Tmpfs)/1024) + // Emulate a system with no swap, which disables inactivation of anon pages. + fmt.Fprintf(buf, "SwapCache: 0 kB\n") + fmt.Fprintf(buf, "Active: %8d kB\n", (anon+activeFile)/1024) + fmt.Fprintf(buf, "Inactive: %8d kB\n", inactiveFile/1024) + fmt.Fprintf(buf, "Active(anon): %8d kB\n", anon/1024) + fmt.Fprintf(buf, "Inactive(anon): 0 kB\n") + fmt.Fprintf(buf, "Active(file): %8d kB\n", activeFile/1024) + fmt.Fprintf(buf, "Inactive(file): %8d kB\n", inactiveFile/1024) + fmt.Fprintf(buf, "Unevictable: 0 kB\n") // TODO(b/31823263) + fmt.Fprintf(buf, "Mlocked: 0 kB\n") // TODO(b/31823263) + fmt.Fprintf(buf, "SwapTotal: 0 kB\n") + fmt.Fprintf(buf, "SwapFree: 0 kB\n") + fmt.Fprintf(buf, "Dirty: 0 kB\n") + fmt.Fprintf(buf, "Writeback: 0 kB\n") + fmt.Fprintf(buf, "AnonPages: %8d kB\n", anon/1024) + fmt.Fprintf(buf, "Mapped: %8d kB\n", file/1024) // doesn't count mapped tmpfs, which we don't know + fmt.Fprintf(buf, "Shmem: %8d kB\n", snapshot.Tmpfs/1024) + return nil +} diff --git a/pkg/sentry/fsimpl/proc/mounts.go b/pkg/sentry/fsimpl/proc/mounts.go new file mode 100644 index 000000000..e81b1e910 --- /dev/null +++ b/pkg/sentry/fsimpl/proc/mounts.go @@ -0,0 +1,33 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package proc + +import "gvisor.dev/gvisor/pkg/sentry/kernel" + +// TODO(b/138862512): Implement mountInfoFile and mountsFile. + +// mountInfoFile implements vfs.DynamicBytesSource for /proc/[pid]/mountinfo. +// +// +stateify savable +type mountInfoFile struct { + t *kernel.Task +} + +// mountsFile implements vfs.DynamicBytesSource for /proc/[pid]/mounts. +// +// +stateify savable +type mountsFile struct { + t *kernel.Task +} diff --git a/pkg/sentry/fsimpl/proc/net.go b/pkg/sentry/fsimpl/proc/net.go new file mode 100644 index 000000000..fd46eebf8 --- /dev/null +++ b/pkg/sentry/fsimpl/proc/net.go @@ -0,0 +1,338 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+package proc
+
+import (
+	"bytes"
+	"fmt"
+
+	"gvisor.dev/gvisor/pkg/abi/linux"
+	"gvisor.dev/gvisor/pkg/binary"
+	"gvisor.dev/gvisor/pkg/log"
+	"gvisor.dev/gvisor/pkg/sentry/context"
+	"gvisor.dev/gvisor/pkg/sentry/fs"
+	"gvisor.dev/gvisor/pkg/sentry/inet"
+	"gvisor.dev/gvisor/pkg/sentry/kernel"
+	"gvisor.dev/gvisor/pkg/sentry/socket"
+	"gvisor.dev/gvisor/pkg/sentry/socket/unix"
+	"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+	"gvisor.dev/gvisor/pkg/sentry/vfs"
+)
+
+// ifinet6 implements vfs.DynamicBytesSource for /proc/net/if_inet6.
+//
+// +stateify savable
+type ifinet6 struct {
+	s inet.Stack
+}
+
+var _ vfs.DynamicBytesSource = (*ifinet6)(nil)
+
+func (n *ifinet6) contents() []string {
+	var lines []string
+	nics := n.s.Interfaces()
+	for id, naddrs := range n.s.InterfaceAddrs() {
+		nic, ok := nics[id]
+		if !ok {
+			// NIC was added after Interfaces was called. We'll just
+			// ignore it.
+			continue
+		}
+
+		for _, a := range naddrs {
+			// IPv6 only.
+			if a.Family != linux.AF_INET6 {
+				continue
+			}
+
+			// Fields:
+			// IPv6 address displayed in 32 hexadecimal chars without colons
+			// Netlink device number (interface index) in hexadecimal (use nic id)
+			// Prefix length in hexadecimal
+			// Scope value (use 0)
+			// Interface flags
+			// Device name
+			lines = append(lines, fmt.Sprintf("%032x %02x %02x %02x %02x %8s\n", a.Addr, id, a.PrefixLen, 0, a.Flags, nic.Name))
+		}
+	}
+	return lines
+}
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (n *ifinet6) Generate(ctx context.Context, buf *bytes.Buffer) error {
+	for _, l := range n.contents() {
+		buf.WriteString(l)
+	}
+	return nil
+}
+
+// netDev implements vfs.DynamicBytesSource for /proc/net/dev.
+//
+// +stateify savable
+type netDev struct {
+	s inet.Stack
+}
+
+var _ vfs.DynamicBytesSource = (*netDev)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (n *netDev) Generate(ctx context.Context, buf *bytes.Buffer) error {
+	interfaces := n.s.Interfaces()
+	buf.WriteString("Inter-| Receive | Transmit\n")
+	buf.WriteString(" face |bytes packets errs drop fifo frame compressed multicast|bytes packets errs drop fifo colls carrier compressed\n")
+
+	for _, i := range interfaces {
+		// Implements the same format as
+		// net/core/net-procfs.c:dev_seq_printf_stats.
+		var stats inet.StatDev
+		if err := n.s.Statistics(&stats, i.Name); err != nil {
+			log.Warningf("Failed to retrieve interface statistics for %v: %v", i.Name, err)
+			continue
+		}
+		fmt.Fprintf(
+			buf,
+			"%6s: %7d %7d %4d %4d %4d %5d %10d %9d %8d %7d %4d %4d %4d %5d %7d %10d\n",
+			i.Name,
+			// Received
+			stats[0], // bytes
+			stats[1], // packets
+			stats[2], // errors
+			stats[3], // dropped
+			stats[4], // fifo
+			stats[5], // frame
+			stats[6], // compressed
+			stats[7], // multicast
+			// Transmitted
+			stats[8],  // bytes
+			stats[9],  // packets
+			stats[10], // errors
+			stats[11], // dropped
+			stats[12], // fifo
+			stats[13], // colls
+			stats[14], // carrier
+			stats[15], // compressed
+		)
+	}
+
+	return nil
+}
+
+// netUnix implements vfs.DynamicBytesSource for /proc/net/unix.
+//
+// +stateify savable
+type netUnix struct {
+	k *kernel.Kernel
+}
+
+var _ vfs.DynamicBytesSource = (*netUnix)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (n *netUnix) Generate(ctx context.Context, buf *bytes.Buffer) error {
+	buf.WriteString("Num RefCount Protocol Flags Type St Inode Path\n")
+	for _, se := range n.k.ListSockets() {
+		s := se.Sock.Get()
+		if s == nil {
+			log.Debugf("Couldn't resolve weakref %v in socket table, racing with destruction?", se.Sock)
+			continue
+		}
+		sfile := s.(*fs.File)
+		if family, _, _ := sfile.FileOperations.(socket.Socket).Type(); family != linux.AF_UNIX {
+			s.DecRef()
+			// Not a unix socket.
+			continue
+		}
+		sops := sfile.FileOperations.(*unix.SocketOperations)
+
+		addr, err := sops.Endpoint().GetLocalAddress()
+		if err != nil {
+			log.Warningf("Failed to retrieve socket name from %+v: %v", sfile, err)
+			addr.Addr = "<unknown>"
+		}
+
+		sockFlags := 0
+		if ce, ok := sops.Endpoint().(transport.ConnectingEndpoint); ok {
+			if ce.Listening() {
+				// For unix domain sockets, Linux reports a single flag
+				// value, __SO_ACCEPTCON, if the socket is listening.
+				sockFlags = linux.SO_ACCEPTCON
+			}
+		}
+
+		// In the socket entry below, the value for the 'Num' field requires
+		// some consideration. Linux prints the address to the struct
+		// unix_sock representing a socket in the kernel, but may redact the
+		// value for unprivileged users depending on the kptr_restrict
+		// sysctl.
+		//
+		// One use for this field is to allow a privileged user to
+		// introspect into the kernel memory to determine information about
+		// a socket not available through procfs, such as the socket's peer.
+		//
+		// In gVisor, returning a pointer to our internal structures would
+		// be pointless, as it wouldn't match the memory layout for struct
+		// unix_sock, making introspection difficult. We could populate a
+		// struct unix_sock with the appropriate data, but even that
+		// requires consideration for which kernel version to emulate, as
+		// the definition of this struct changes over time.
+		//
+		// For now, we always redact this pointer.
+		fmt.Fprintf(buf, "%#016p: %08X %08X %08X %04X %02X %5d",
+			(*unix.SocketOperations)(nil), // Num, pointer to kernel socket struct.
+			sfile.ReadRefs()-1, // RefCount, don't count our own ref.
+			0, // Protocol, always 0 for UDS.
+			sockFlags, // Flags.
+			sops.Endpoint().Type(), // Type.
+			sops.State(), // State.
+			sfile.InodeID(), // Inode.
+		)
+
+		// Path
+		if len(addr.Addr) != 0 {
+			if addr.Addr[0] == 0 {
+				// Abstract path.
+				fmt.Fprintf(buf, " @%s", string(addr.Addr[1:]))
+			} else {
+				fmt.Fprintf(buf, " %s", string(addr.Addr))
+			}
+		}
+		fmt.Fprintf(buf, "\n")
+
+		s.DecRef()
+	}
+	return nil
+}
+
+// netTCP implements vfs.DynamicBytesSource for /proc/net/tcp.
+//
+// +stateify savable
+type netTCP struct {
+	k *kernel.Kernel
+}
+
+var _ vfs.DynamicBytesSource = (*netTCP)(nil)
+
+func (n *netTCP) Generate(ctx context.Context, buf *bytes.Buffer) error {
+	t := kernel.TaskFromContext(ctx)
+	buf.WriteString(" sl local_address rem_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode \n")
+	for _, se := range n.k.ListSockets() {
+		s := se.Sock.Get()
+		if s == nil {
+			log.Debugf("Couldn't resolve weakref %+v in socket table, racing with destruction?", se.Sock)
+			continue
+		}
+		sfile := s.(*fs.File)
+		sops, ok := sfile.FileOperations.(socket.Socket)
+		if !ok {
+			panic(fmt.Sprintf("Found non-socket file in socket table: %+v", sfile))
+		}
+		if family, stype, _ := sops.Type(); !(family == linux.AF_INET && stype == linux.SOCK_STREAM) {
+			s.DecRef()
+			// Not a tcp4 socket.
+			continue
+		}
+
+		// Linux's documentation for the fields below can be found at
+		// https://www.kernel.org/doc/Documentation/networking/proc_net_tcp.txt.
+		// For Linux's implementation, see net/ipv4/tcp_ipv4.c:get_tcp4_sock().
+		// Note that the header doesn't contain labels for all the fields.
+
+		// Field: sl; entry number.
+		fmt.Fprintf(buf, "%4d: ", se.ID)
+
+		portBuf := make([]byte, 2)
+
+		// Field: local_address.
+		var localAddr linux.SockAddrInet
+		if local, _, err := sops.GetSockName(t); err == nil {
+			localAddr = *local.(*linux.SockAddrInet)
+		}
+		binary.LittleEndian.PutUint16(portBuf, localAddr.Port)
+		fmt.Fprintf(buf, "%08X:%04X ",
+			binary.LittleEndian.Uint32(localAddr.Addr[:]),
+			portBuf)
+
+		// Field: rem_address.
+		var remoteAddr linux.SockAddrInet
+		if remote, _, err := sops.GetPeerName(t); err == nil {
+			remoteAddr = *remote.(*linux.SockAddrInet)
+		}
+		binary.LittleEndian.PutUint16(portBuf, remoteAddr.Port)
+		fmt.Fprintf(buf, "%08X:%04X ",
+			binary.LittleEndian.Uint32(remoteAddr.Addr[:]),
+			portBuf)
+
+		// Field: state; socket state.
+		fmt.Fprintf(buf, "%02X ", sops.State())
+
+		// Field: tx_queue, rx_queue; number of packets in the transmit and
+		// receive queue. Unimplemented.
+		fmt.Fprintf(buf, "%08X:%08X ", 0, 0)
+
+		// Field: tr, tm->when; timer active state and number of jiffies
+		// until timer expires. Unimplemented.
+		fmt.Fprintf(buf, "%02X:%08X ", 0, 0)
+
+		// Field: retrnsmt; number of unrecovered RTO timeouts.
+		// Unimplemented.
+		fmt.Fprintf(buf, "%08X ", 0)
+
+		// Field: uid.
+		uattr, err := sfile.Dirent.Inode.UnstableAttr(ctx)
+		if err != nil {
+			log.Warningf("Failed to retrieve unstable attr for socket file: %v", err)
+			fmt.Fprintf(buf, "%5d ", 0)
+		} else {
+			fmt.Fprintf(buf, "%5d ", uint32(uattr.Owner.UID.In(t.UserNamespace()).OrOverflow()))
+		}
+
+		// Field: timeout; number of unanswered 0-window probes.
+		// Unimplemented.
+		fmt.Fprintf(buf, "%8d ", 0)
+
+		// Field: inode.
+		fmt.Fprintf(buf, "%8d ", sfile.InodeID())
+
+		// Field: refcount. Don't count the ref we obtain while dereferencing
+		// the weakref to this socket.
+		fmt.Fprintf(buf, "%d ", sfile.ReadRefs()-1)
+
+		// Field: Socket struct address. Redacted due to the same reason as
+		// the 'Num' field in /proc/net/unix, see netUnix.Generate.
+		fmt.Fprintf(buf, "%#016p ", (*socket.Socket)(nil))
+
+		// Field: retransmit timeout. Unimplemented.
+		fmt.Fprintf(buf, "%d ", 0)
+
+		// Field: predicted tick of soft clock (delayed ACK control data).
+		// Unimplemented.
+		fmt.Fprintf(buf, "%d ", 0)
+
+		// Field: (ack.quick<<1)|ack.pingpong. Unimplemented.
+		fmt.Fprintf(buf, "%d ", 0)
+
+		// Field: sending congestion window. Unimplemented.
+		fmt.Fprintf(buf, "%d ", 0)
+
+		// Field: Slow start size threshold, -1 if threshold >= 0xFFFF.
+		// Unimplemented, report as large threshold.
+		fmt.Fprintf(buf, "%d", -1)
+
+		fmt.Fprintf(buf, "\n")
+
+		s.DecRef()
+	}
+
+	return nil
+}
diff --git a/pkg/sentry/fsimpl/proc/net_test.go b/pkg/sentry/fsimpl/proc/net_test.go
new file mode 100644
index 000000000..20a77a8ca
--- /dev/null
+++ b/pkg/sentry/fsimpl/proc/net_test.go
@@ -0,0 +1,78 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
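The address columns written above are not dotted quads: like Linux, the code dumps the raw sockaddr bytes, so 127.0.0.1:8080 renders as 0100007F:1F90 on little-endian hardware. A short sketch of just that formatting step, with the byte values hard-coded and assuming the sockaddr carries both fields in network byte order, as the kernel struct does:

```go
package main

import (
	"encoding/binary"
	"fmt"
)

func main() {
	// sockaddr_in fields as they sit in memory: address and port are both
	// in network byte order.
	addr := []byte{127, 0, 0, 1} // 127.0.0.1
	port := []byte{0x1F, 0x90}   // 8080

	// /proc/net/tcp prints the address as a host-order 32-bit word, so on
	// little-endian machines the octets appear reversed, while the port
	// keeps its network-order hex digits.
	fmt.Printf("%08X:%04X\n",
		binary.LittleEndian.Uint32(addr), // 0100007F
		port)                             // 1F90
}
```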
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package proc + +import ( + "bytes" + "reflect" + "testing" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/context/contexttest" + "gvisor.dev/gvisor/pkg/sentry/inet" +) + +func newIPv6TestStack() *inet.TestStack { + s := inet.NewTestStack() + s.SupportsIPv6Flag = true + return s +} + +func TestIfinet6NoAddresses(t *testing.T) { + n := &ifinet6{s: newIPv6TestStack()} + var buf bytes.Buffer + n.Generate(contexttest.Context(t), &buf) + if buf.Len() > 0 { + t.Errorf("n.Generate() generated = %v, want = %v", buf.Bytes(), []byte{}) + } +} + +func TestIfinet6(t *testing.T) { + s := newIPv6TestStack() + s.InterfacesMap[1] = inet.Interface{Name: "eth0"} + s.InterfaceAddrsMap[1] = []inet.InterfaceAddr{ + { + Family: linux.AF_INET6, + PrefixLen: 128, + Addr: []byte("\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f"), + }, + } + s.InterfacesMap[2] = inet.Interface{Name: "eth1"} + s.InterfaceAddrsMap[2] = []inet.InterfaceAddr{ + { + Family: linux.AF_INET6, + PrefixLen: 128, + Addr: []byte("\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f"), + }, + } + want := map[string]struct{}{ + "000102030405060708090a0b0c0d0e0f 01 80 00 00 eth0\n": {}, + "101112131415161718191a1b1c1d1e1f 02 80 00 00 eth1\n": {}, + } + + n := &ifinet6{s: s} + contents := n.contents() + if len(contents) != len(want) { + t.Errorf("Got len(n.contents()) = %d, want = %d", len(contents), len(want)) + } + got := map[string]struct{}{} + for _, l := range contents { + got[l] = struct{}{} + } + + if !reflect.DeepEqual(got, want) { + t.Errorf("Got n.contents() = %v, want = %v", got, want) + } +} diff --git a/pkg/sentry/fsimpl/proc/proc.go b/pkg/sentry/fsimpl/proc/proc.go new file mode 100644 index 000000000..31dec36de --- /dev/null +++ b/pkg/sentry/fsimpl/proc/proc.go @@ -0,0 +1,16 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package proc implements a partial in-memory file system for procfs. +package proc diff --git a/pkg/sentry/fsimpl/proc/stat.go b/pkg/sentry/fsimpl/proc/stat.go new file mode 100644 index 000000000..720db3828 --- /dev/null +++ b/pkg/sentry/fsimpl/proc/stat.go @@ -0,0 +1,127 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package proc + +import ( + "bytes" + "fmt" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/vfs" +) + +// cpuStats contains the breakdown of CPU time for /proc/stat. +type cpuStats struct { + // user is time spent in userspace tasks with non-positive niceness. + user uint64 + + // nice is time spent in userspace tasks with positive niceness. + nice uint64 + + // system is time spent in non-interrupt kernel context. + system uint64 + + // idle is time spent idle. + idle uint64 + + // ioWait is time spent waiting for IO. + ioWait uint64 + + // irq is time spent in interrupt context. + irq uint64 + + // softirq is time spent in software interrupt context. + softirq uint64 + + // steal is involuntary wait time. + steal uint64 + + // guest is time spent in guests with non-positive niceness. + guest uint64 + + // guestNice is time spent in guests with positive niceness. + guestNice uint64 +} + +// String implements fmt.Stringer. +func (c cpuStats) String() string { + return fmt.Sprintf("%d %d %d %d %d %d %d %d %d %d", c.user, c.nice, c.system, c.idle, c.ioWait, c.irq, c.softirq, c.steal, c.guest, c.guestNice) +} + +// statData implements vfs.DynamicBytesSource for /proc/stat. +// +// +stateify savable +type statData struct { + // k is the owning Kernel. + k *kernel.Kernel +} + +var _ vfs.DynamicBytesSource = (*statData)(nil) + +// Generate implements vfs.DynamicBytesSource.Generate. +func (s *statData) Generate(ctx context.Context, buf *bytes.Buffer) error { + // TODO(b/37226836): We currently export only zero CPU stats. We could + // at least provide some aggregate stats. + var cpu cpuStats + fmt.Fprintf(buf, "cpu %s\n", cpu) + + for c, max := uint(0), s.k.ApplicationCores(); c < max; c++ { + fmt.Fprintf(buf, "cpu%d %s\n", c, cpu) + } + + // The total number of interrupts is dependent on the CPUs and PCI + // devices on the system. See arch_probe_nr_irqs. + // + // Since we don't report real interrupt stats, just choose an arbitrary + // value from a representative VM. + const numInterrupts = 256 + + // The Kernel doesn't handle real interrupts, so report all zeroes. + // TODO(b/37226836): We could count page faults as #PF. + fmt.Fprintf(buf, "intr 0") // total + for i := 0; i < numInterrupts; i++ { + fmt.Fprintf(buf, " 0") + } + fmt.Fprintf(buf, "\n") + + // Total number of context switches. + // TODO(b/37226836): Count this. + fmt.Fprintf(buf, "ctxt 0\n") + + // CLOCK_REALTIME timestamp from boot, in seconds. + fmt.Fprintf(buf, "btime %d\n", s.k.Timekeeper().BootTime().Seconds()) + + // Total number of clones. + // TODO(b/37226836): Count this. + fmt.Fprintf(buf, "processes 0\n") + + // Number of runnable tasks. + // TODO(b/37226836): Count this. + fmt.Fprintf(buf, "procs_running 0\n") + + // Number of tasks waiting on IO. + // TODO(b/37226836): Count this. + fmt.Fprintf(buf, "procs_blocked 0\n") + + // Number of each softirq handled. 
+ fmt.Fprintf(buf, "softirq 0") // total + for i := 0; i < linux.NumSoftIRQ; i++ { + fmt.Fprintf(buf, " 0") + } + fmt.Fprintf(buf, "\n") + return nil +} diff --git a/pkg/sentry/fsimpl/proc/sys.go b/pkg/sentry/fsimpl/proc/sys.go new file mode 100644 index 000000000..b88256e12 --- /dev/null +++ b/pkg/sentry/fsimpl/proc/sys.go @@ -0,0 +1,51 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package proc + +import ( + "bytes" + "fmt" + + "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/vfs" +) + +// mmapMinAddrData implements vfs.DynamicBytesSource for +// /proc/sys/vm/mmap_min_addr. +// +// +stateify savable +type mmapMinAddrData struct { + k *kernel.Kernel +} + +var _ vfs.DynamicBytesSource = (*mmapMinAddrData)(nil) + +// Generate implements vfs.DynamicBytesSource.Generate. +func (d *mmapMinAddrData) Generate(ctx context.Context, buf *bytes.Buffer) error { + fmt.Fprintf(buf, "%d\n", d.k.Platform.MinUserAddress()) + return nil +} + +// +stateify savable +type overcommitMemory struct{} + +var _ vfs.DynamicBytesSource = (*overcommitMemory)(nil) + +// Generate implements vfs.DynamicBytesSource.Generate. +func (d *overcommitMemory) Generate(ctx context.Context, buf *bytes.Buffer) error { + fmt.Fprintf(buf, "0\n") + return nil +} diff --git a/pkg/sentry/fsimpl/proc/task.go b/pkg/sentry/fsimpl/proc/task.go new file mode 100644 index 000000000..c46e05c3a --- /dev/null +++ b/pkg/sentry/fsimpl/proc/task.go @@ -0,0 +1,261 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package proc + +import ( + "bytes" + "fmt" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/limits" + "gvisor.dev/gvisor/pkg/sentry/mm" + "gvisor.dev/gvisor/pkg/sentry/usage" + "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sentry/vfs" +) + +// mapsCommon is embedded by mapsData and smapsData. +type mapsCommon struct { + t *kernel.Task +} + +// mm gets the kernel task's MemoryManager. No additional reference is taken on +// mm here. This is safe because MemoryManager.destroy is required to leave the +// MemoryManager in a state where it's still usable as a DynamicBytesSource. 
+func (md *mapsCommon) mm() *mm.MemoryManager { + var tmm *mm.MemoryManager + md.t.WithMuLocked(func(t *kernel.Task) { + if mm := t.MemoryManager(); mm != nil { + tmm = mm + } + }) + return tmm +} + +// mapsData implements vfs.DynamicBytesSource for /proc/[pid]/maps. +// +// +stateify savable +type mapsData struct { + mapsCommon +} + +var _ vfs.DynamicBytesSource = (*mapsData)(nil) + +// Generate implements vfs.DynamicBytesSource.Generate. +func (md *mapsData) Generate(ctx context.Context, buf *bytes.Buffer) error { + if mm := md.mm(); mm != nil { + mm.ReadMapsDataInto(ctx, buf) + } + return nil +} + +// smapsData implements vfs.DynamicBytesSource for /proc/[pid]/smaps. +// +// +stateify savable +type smapsData struct { + mapsCommon +} + +var _ vfs.DynamicBytesSource = (*smapsData)(nil) + +// Generate implements vfs.DynamicBytesSource.Generate. +func (sd *smapsData) Generate(ctx context.Context, buf *bytes.Buffer) error { + if mm := sd.mm(); mm != nil { + mm.ReadSmapsDataInto(ctx, buf) + } + return nil +} + +// +stateify savable +type taskStatData struct { + t *kernel.Task + + // If tgstats is true, accumulate fault stats (not implemented) and CPU + // time across all tasks in t's thread group. + tgstats bool + + // pidns is the PID namespace associated with the proc filesystem that + // includes the file using this statData. + pidns *kernel.PIDNamespace +} + +var _ vfs.DynamicBytesSource = (*taskStatData)(nil) + +// Generate implements vfs.DynamicBytesSource.Generate. +func (s *taskStatData) Generate(ctx context.Context, buf *bytes.Buffer) error { + fmt.Fprintf(buf, "%d ", s.pidns.IDOfTask(s.t)) + fmt.Fprintf(buf, "(%s) ", s.t.Name()) + fmt.Fprintf(buf, "%c ", s.t.StateStatus()[0]) + ppid := kernel.ThreadID(0) + if parent := s.t.Parent(); parent != nil { + ppid = s.pidns.IDOfThreadGroup(parent.ThreadGroup()) + } + fmt.Fprintf(buf, "%d ", ppid) + fmt.Fprintf(buf, "%d ", s.pidns.IDOfProcessGroup(s.t.ThreadGroup().ProcessGroup())) + fmt.Fprintf(buf, "%d ", s.pidns.IDOfSession(s.t.ThreadGroup().Session())) + fmt.Fprintf(buf, "0 0 " /* tty_nr tpgid */) + fmt.Fprintf(buf, "0 " /* flags */) + fmt.Fprintf(buf, "0 0 0 0 " /* minflt cminflt majflt cmajflt */) + var cputime usage.CPUStats + if s.tgstats { + cputime = s.t.ThreadGroup().CPUStats() + } else { + cputime = s.t.CPUStats() + } + fmt.Fprintf(buf, "%d %d ", linux.ClockTFromDuration(cputime.UserTime), linux.ClockTFromDuration(cputime.SysTime)) + cputime = s.t.ThreadGroup().JoinedChildCPUStats() + fmt.Fprintf(buf, "%d %d ", linux.ClockTFromDuration(cputime.UserTime), linux.ClockTFromDuration(cputime.SysTime)) + fmt.Fprintf(buf, "%d %d ", s.t.Priority(), s.t.Niceness()) + fmt.Fprintf(buf, "%d ", s.t.ThreadGroup().Count()) + + // itrealvalue. Since kernel 2.6.17, this field is no longer + // maintained, and is hard coded as 0. + fmt.Fprintf(buf, "0 ") + + // Start time is relative to boot time, expressed in clock ticks. + fmt.Fprintf(buf, "%d ", linux.ClockTFromDuration(s.t.StartTime().Sub(s.t.Kernel().Timekeeper().BootTime()))) + + var vss, rss uint64 + s.t.WithMuLocked(func(t *kernel.Task) { + if mm := t.MemoryManager(); mm != nil { + vss = mm.VirtualMemorySize() + rss = mm.ResidentSetSize() + } + }) + fmt.Fprintf(buf, "%d %d ", vss, rss/usermem.PageSize) + + // rsslim. 
+ fmt.Fprintf(buf, "%d ", s.t.ThreadGroup().Limits().Get(limits.Rss).Cur) + + fmt.Fprintf(buf, "0 0 0 0 0 " /* startcode endcode startstack kstkesp kstkeip */) + fmt.Fprintf(buf, "0 0 0 0 0 " /* signal blocked sigignore sigcatch wchan */) + fmt.Fprintf(buf, "0 0 " /* nswap cnswap */) + terminationSignal := linux.Signal(0) + if s.t == s.t.ThreadGroup().Leader() { + terminationSignal = s.t.ThreadGroup().TerminationSignal() + } + fmt.Fprintf(buf, "%d ", terminationSignal) + fmt.Fprintf(buf, "0 0 0 " /* processor rt_priority policy */) + fmt.Fprintf(buf, "0 0 0 " /* delayacct_blkio_ticks guest_time cguest_time */) + fmt.Fprintf(buf, "0 0 0 0 0 0 0 " /* start_data end_data start_brk arg_start arg_end env_start env_end */) + fmt.Fprintf(buf, "0\n" /* exit_code */) + + return nil +} + +// statmData implements vfs.DynamicBytesSource for /proc/[pid]/statm. +// +// +stateify savable +type statmData struct { + t *kernel.Task +} + +var _ vfs.DynamicBytesSource = (*statmData)(nil) + +// Generate implements vfs.DynamicBytesSource.Generate. +func (s *statmData) Generate(ctx context.Context, buf *bytes.Buffer) error { + var vss, rss uint64 + s.t.WithMuLocked(func(t *kernel.Task) { + if mm := t.MemoryManager(); mm != nil { + vss = mm.VirtualMemorySize() + rss = mm.ResidentSetSize() + } + }) + + fmt.Fprintf(buf, "%d %d 0 0 0 0 0\n", vss/usermem.PageSize, rss/usermem.PageSize) + return nil +} + +// statusData implements vfs.DynamicBytesSource for /proc/[pid]/status. +// +// +stateify savable +type statusData struct { + t *kernel.Task + pidns *kernel.PIDNamespace +} + +var _ vfs.DynamicBytesSource = (*statusData)(nil) + +// Generate implements vfs.DynamicBytesSource.Generate. +func (s *statusData) Generate(ctx context.Context, buf *bytes.Buffer) error { + fmt.Fprintf(buf, "Name:\t%s\n", s.t.Name()) + fmt.Fprintf(buf, "State:\t%s\n", s.t.StateStatus()) + fmt.Fprintf(buf, "Tgid:\t%d\n", s.pidns.IDOfThreadGroup(s.t.ThreadGroup())) + fmt.Fprintf(buf, "Pid:\t%d\n", s.pidns.IDOfTask(s.t)) + ppid := kernel.ThreadID(0) + if parent := s.t.Parent(); parent != nil { + ppid = s.pidns.IDOfThreadGroup(parent.ThreadGroup()) + } + fmt.Fprintf(buf, "PPid:\t%d\n", ppid) + tpid := kernel.ThreadID(0) + if tracer := s.t.Tracer(); tracer != nil { + tpid = s.pidns.IDOfTask(tracer) + } + fmt.Fprintf(buf, "TracerPid:\t%d\n", tpid) + var fds int + var vss, rss, data uint64 + s.t.WithMuLocked(func(t *kernel.Task) { + if fdTable := t.FDTable(); fdTable != nil { + fds = fdTable.Size() + } + if mm := t.MemoryManager(); mm != nil { + vss = mm.VirtualMemorySize() + rss = mm.ResidentSetSize() + data = mm.VirtualDataSize() + } + }) + fmt.Fprintf(buf, "FDSize:\t%d\n", fds) + fmt.Fprintf(buf, "VmSize:\t%d kB\n", vss>>10) + fmt.Fprintf(buf, "VmRSS:\t%d kB\n", rss>>10) + fmt.Fprintf(buf, "VmData:\t%d kB\n", data>>10) + fmt.Fprintf(buf, "Threads:\t%d\n", s.t.ThreadGroup().Count()) + creds := s.t.Credentials() + fmt.Fprintf(buf, "CapInh:\t%016x\n", creds.InheritableCaps) + fmt.Fprintf(buf, "CapPrm:\t%016x\n", creds.PermittedCaps) + fmt.Fprintf(buf, "CapEff:\t%016x\n", creds.EffectiveCaps) + fmt.Fprintf(buf, "CapBnd:\t%016x\n", creds.BoundingCaps) + fmt.Fprintf(buf, "Seccomp:\t%d\n", s.t.SeccompMode()) + return nil +} + +// ioUsage is the /proc/<pid>/io and /proc/<pid>/task/<tid>/io data provider. +type ioUsage interface { + // IOUsage returns the io usage data. 
+	IOUsage() *usage.IO
+}
+
+// +stateify savable
+type ioData struct {
+	ioUsage
+}
+
+var _ vfs.DynamicBytesSource = (*ioData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (i *ioData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+	io := usage.IO{}
+	io.Accumulate(i.IOUsage())
+
+	fmt.Fprintf(buf, "char: %d\n", io.CharsRead)
+	fmt.Fprintf(buf, "wchar: %d\n", io.CharsWritten)
+	fmt.Fprintf(buf, "syscr: %d\n", io.ReadSyscalls)
+	fmt.Fprintf(buf, "syscw: %d\n", io.WriteSyscalls)
+	fmt.Fprintf(buf, "read_bytes: %d\n", io.BytesRead)
+	fmt.Fprintf(buf, "write_bytes: %d\n", io.BytesWritten)
+	fmt.Fprintf(buf, "cancelled_write_bytes: %d\n", io.BytesWriteCancelled)
+	return nil
+}
diff --git a/pkg/sentry/fsimpl/proc/version.go b/pkg/sentry/fsimpl/proc/version.go
new file mode 100644
index 000000000..e1643d4e0
--- /dev/null
+++ b/pkg/sentry/fsimpl/proc/version.go
@@ -0,0 +1,68 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+	"bytes"
+	"fmt"
+
+	"gvisor.dev/gvisor/pkg/sentry/context"
+	"gvisor.dev/gvisor/pkg/sentry/kernel"
+	"gvisor.dev/gvisor/pkg/sentry/vfs"
+)
+
+// versionData implements vfs.DynamicBytesSource for /proc/version.
+//
+// +stateify savable
+type versionData struct {
+	// k is the owning Kernel.
+	k *kernel.Kernel
+}
+
+var _ vfs.DynamicBytesSource = (*versionData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (v *versionData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+	init := v.k.GlobalInit()
+	if init == nil {
+		// Attempted to read before the init Task is created. This can
+		// only occur during startup, which should never need to read
+		// this file.
+		panic("Attempted to read version before initial Task is available")
+	}
+
+	// /proc/version takes the form:
+	//
+	// "SYSNAME version RELEASE (COMPILE_USER@COMPILE_HOST)
+	// (COMPILER_VERSION) VERSION"
+	//
+	// where:
+	// - SYSNAME, RELEASE, and VERSION are the same as returned by
+	// sys_utsname
+	// - COMPILE_USER is the user that built the kernel
+	// - COMPILE_HOST is the hostname of the machine on which the kernel
+	// was built
+	// - COMPILER_VERSION is the version reported by the building compiler
+	//
+	// Since we don't really want to expose build information to
+	// applications, those fields are omitted.
+	//
+	// FIXME(mpratt): Using Version from the init task SyscallTable
+	// disregards the different version a task may have (e.g., in a uts
+	// namespace).
+	ver := init.Leader().SyscallTable().Version
+	fmt.Fprintf(buf, "%s version %s %s\n", ver.Sysname, ver.Release, ver.Version)
+	return nil
+}
diff --git a/pkg/sentry/inet/inet.go b/pkg/sentry/inet/inet.go
index 5b75a4a06..80f227dbe 100644
--- a/pkg/sentry/inet/inet.go
+++ b/pkg/sentry/inet/inet.go
@@ -52,12 +52,16 @@ type Stack interface {
	// Statistics reports stack statistics.
	Statistics(stat interface{}, arg string) error
+
+	// RouteTable returns the network stack's route table.
+	RouteTable() []Route
+
+	// Resume restarts the network stack after restore.
+	Resume()
}

// Interface contains information about a network interface.
type Interface struct {
-	// Keep these fields sorted in the order they appear in rtnetlink(7).
-
	// DeviceType is the device type, a Linux ARPHRD_* constant.
	DeviceType uint16
@@ -77,8 +81,6 @@ type Interface struct {
// InterfaceAddr contains information about a network interface address.
type InterfaceAddr struct {
-	// Keep these fields sorted in the order they appear in rtnetlink(7).
-
	// Family is the address family, a Linux AF_* constant.
	Family uint8
@@ -109,3 +111,45 @@ type TCPBufferSize struct {
// StatDev describes one line of /proc/net/dev, i.e., stats for one network
// interface.
type StatDev [16]uint64
+
+// Route contains information about a network route.
+type Route struct {
+	// Family is the address family, a Linux AF_* constant.
+	Family uint8
+
+	// DstLen is the length of the destination address.
+	DstLen uint8
+
+	// SrcLen is the length of the source address.
+	SrcLen uint8
+
+	// TOS is the Type of Service filter.
+	TOS uint8
+
+	// Table is the routing table ID.
+	Table uint8
+
+	// Protocol is the route origin, a Linux RTPROT_* constant.
+	Protocol uint8
+
+	// Scope is the distance to destination, a Linux RT_SCOPE_* constant.
+	Scope uint8
+
+	// Type is the route type, a Linux RTN_* constant.
+	Type uint8
+
+	// Flags are route flags. See rtnetlink(7) under "rtm_flags".
+	Flags uint32
+
+	// DstAddr is the route destination address (RTA_DST).
+	DstAddr []byte
+
+	// SrcAddr is the route source address (RTA_SRC).
+	SrcAddr []byte
+
+	// OutputInterface is the output interface index (RTA_OIF).
+	OutputInterface int32
+
+	// GatewayAddr is the route gateway address (RTA_GATEWAY).
+	GatewayAddr []byte
+}
diff --git a/pkg/sentry/inet/test_stack.go b/pkg/sentry/inet/test_stack.go
index 75f9e7a77..b9eed7c3a 100644
--- a/pkg/sentry/inet/test_stack.go
+++ b/pkg/sentry/inet/test_stack.go
@@ -18,6 +18,7 @@ package inet
type TestStack struct {
	InterfacesMap map[int32]Interface
	InterfaceAddrsMap map[int32][]InterfaceAddr
+	RouteList []Route
	SupportsIPv6Flag bool
	TCPRecvBufSize TCPBufferSize
	TCPSendBufSize TCPBufferSize
@@ -86,3 +87,12 @@ func (s *TestStack) SetTCPSACKEnabled(enabled bool) error {
func (s *TestStack) Statistics(stat interface{}, arg string) error {
	return nil
}
+
+// RouteTable implements Stack.RouteTable.
+func (s *TestStack) RouteTable() []Route {
+	return s.RouteList
+}
+
+// Resume implements Stack.Resume.
+func (s *TestStack) Resume() {
+}
diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD
index e61d39c82..41bee9a22 100644
--- a/pkg/sentry/kernel/BUILD
+++ b/pkg/sentry/kernel/BUILD
@@ -144,6 +144,7 @@ go_library(
	"threads.go",
	"timekeeper.go",
	"timekeeper_state.go",
+	"tty.go",
	"uts_namespace.go",
	"vdso.go",
	"version.go",
diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go
index 4c2d48e65..8c1f79ab5 100644
--- a/pkg/sentry/kernel/kernel.go
+++ b/pkg/sentry/kernel/kernel.go
@@ -112,11 +112,6 @@ type Kernel struct {
	rootIPCNamespace *IPCNamespace
	rootAbstractSocketNamespace *AbstractSocketNamespace
-	// mounts holds the state of the virtual filesystem. mounts is initially
-	// nil, and must be set by calling Kernel.SetRootMountNamespace before
-	// Kernel.CreateProcess can succeed.
-	mounts *fs.MountNamespace
-
	// futexes is the "root" futex.Manager, from which all others are forked.
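For a sense of how the new inet.Route fields map onto rtnetlink, here is a hedged sketch of what a default IPv4 route via a gateway might look like. The struct is copied locally and the Linux constant values are spelled out numerically (AF_INET=2, RT_TABLE_MAIN=254, RTPROT_BOOT=3, RT_SCOPE_UNIVERSE=0, RTN_UNICAST=1) rather than imported; the addresses are illustrative:

```go
package main

import "fmt"

// route mirrors the new inet.Route struct from pkg/sentry/inet/inet.go.
type route struct {
	Family          uint8
	DstLen          uint8
	SrcLen          uint8
	TOS             uint8
	Table           uint8
	Protocol        uint8
	Scope           uint8
	Type            uint8
	Flags           uint32
	DstAddr         []byte
	SrcAddr         []byte
	OutputInterface int32
	GatewayAddr     []byte
}

func main() {
	// A default route: destination prefix length 0 matches everything,
	// so DstAddr can be left empty and all traffic goes to the gateway.
	def := route{
		Family:          2,   // AF_INET
		DstLen:          0,   // 0.0.0.0/0
		Table:           254, // RT_TABLE_MAIN
		Protocol:        3,   // RTPROT_BOOT
		Scope:           0,   // RT_SCOPE_UNIVERSE
		Type:            1,   // RTN_UNICAST
		OutputInterface: 2,   // hypothetical NIC index
		GatewayAddr:     []byte{192, 168, 1, 1},
	}
	fmt.Printf("%+v\n", def)
}
```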
// This is necessary to ensure that shared futexes are coherent across all // tasks, including those created by CreateProcess. @@ -197,6 +192,15 @@ type Kernel struct { // caches. Not all caches use it, only the caches that use host resources use // the limiter. It may be nil if disabled. DirentCacheLimiter *fs.DirentCacheLimiter + + // unimplementedSyscallEmitterOnce is used in the initialization of + // unimplementedSyscallEmitter. + unimplementedSyscallEmitterOnce sync.Once `state:"nosave"` + + // unimplementedSyscallEmitter is used to emit unimplemented syscall + // events. This is initialized lazily on the first unimplemented + // syscall. + unimplementedSyscallEmitter eventchannel.Emitter `state:"nosave"` } // InitKernelArgs holds arguments to Init. @@ -290,7 +294,6 @@ func (k *Kernel) Init(args InitKernelArgs) error { k.monotonicClock = &timekeeperClock{tk: args.Timekeeper, c: sentrytime.Monotonic} k.futexes = futex.NewManager() k.netlinkPorts = port.New() - return nil } @@ -384,11 +387,7 @@ func (k *Kernel) SaveTo(w io.Writer) error { // flushMountSourceRefs flushes the MountSources for all mounted filesystems // and open FDs. func (k *Kernel) flushMountSourceRefs() error { - // Flush all mount sources for currently mounted filesystems in the - // root mount namespace. - k.mounts.FlushMountSourceRefs() - - // Some tasks may have other mount namespaces; flush those as well. + // Flush all mount sources for currently mounted filesystems in each task. flushed := make(map[*fs.MountNamespace]struct{}) k.tasks.mu.RLock() k.tasks.forEachThreadGroupLocked(func(tg *ThreadGroup) { @@ -497,7 +496,7 @@ func (ts *TaskSet) unregisterEpollWaiters() { } // LoadFrom returns a new Kernel loaded from args. -func (k *Kernel) LoadFrom(r io.Reader, net inet.Stack) error { +func (k *Kernel) LoadFrom(r io.Reader, net inet.Stack, clocks sentrytime.Clocks) error { loadStart := time.Now() k.networkStack = net @@ -541,6 +540,11 @@ func (k *Kernel) LoadFrom(r io.Reader, net inet.Stack) error { log.Infof("Overall load took [%s]", time.Since(loadStart)) + k.Timekeeper().SetClocks(clocks) + if net != nil { + net.Resume() + } + // Ensure that all pending asynchronous work is complete: // - namedpipe opening // - inode file opening @@ -550,7 +554,7 @@ func (k *Kernel) LoadFrom(r io.Reader, net inet.Stack) error { tcpip.AsyncLoading.Wait() - log.Infof("Overall load took [%s]", time.Since(loadStart)) + log.Infof("Overall load took [%s] after async work", time.Since(loadStart)) // Applications may size per-cpu structures based on k.applicationCores, so // it can't change across save/restore. When we are virtualizing CPU @@ -565,16 +569,6 @@ func (k *Kernel) LoadFrom(r io.Reader, net inet.Stack) error { return nil } -// Destroy releases resources owned by k. -// -// Preconditions: There must be no task goroutines running in k. -func (k *Kernel) Destroy() { - if k.mounts != nil { - k.mounts.DecRef() - k.mounts = nil - } -} - // UniqueID returns a unique identifier. func (k *Kernel) UniqueID() uint64 { id := atomic.AddUint64(&k.uniqueID, 1) @@ -586,11 +580,17 @@ func (k *Kernel) UniqueID() uint64 { // CreateProcessArgs holds arguments to kernel.CreateProcess. type CreateProcessArgs struct { - // Filename is the filename to load. + // Filename is the filename to load as the init binary. // - // If this is provided as "", then the file will be guessed via Argv[0]. + // If this is provided as "", File will be checked, then the file will be + // guessed via Argv[0]. 
Filename string

+	// File is a passed host FD pointing to a file to load as the init binary.
+	//
+	// This is checked if and only if Filename is "".
+	File *fs.File
+
	// Argv is a list of arguments.
	Argv []string
@@ -632,19 +632,12 @@ type CreateProcessArgs struct {
	AbstractSocketNamespace *AbstractSocketNamespace

	// MountNamespace optionally contains the mount namespace for this
-	// process. If nil, the kernel's mount namespace is used.
+	// process. If nil, the init process's mount namespace is used.
	//
	// Anyone setting MountNamespace must donate a reference (i.e.
	// increment it).
	MountNamespace *fs.MountNamespace

-	// Root optionally contains the dirent that serves as the root for the
-	// process. If nil, the mount namespace's root is used as the process'
-	// root.
-	//
-	// Anyone setting Root must donate a reference (i.e. increment it).
-	Root *fs.Dirent
-
	// ContainerID is the container that the process belongs to.
	ContainerID string
}
@@ -682,16 +675,10 @@ func (ctx *createProcessContext) Value(key interface{}) interface{} {
	case auth.CtxCredentials:
		return ctx.args.Credentials
	case fs.CtxRoot:
-		if ctx.args.Root != nil {
-			// Take a reference on the root dirent that will be
-			// given to the caller.
-			ctx.args.Root.IncRef()
-			return ctx.args.Root
-		}
-		if ctx.k.mounts != nil {
-			// MountNamespace.Root() will take a reference on the
-			// root dirent for us.
-			return ctx.k.mounts.Root()
+		if ctx.args.MountNamespace != nil {
+			// MountNamespace.Root() will take a reference on the root
+			// dirent for us.
+			return ctx.args.MountNamespace.Root()
		}
		return nil
	case fs.CtxDirentCacheLimiter:
@@ -735,30 +722,18 @@ func (k *Kernel) CreateProcess(args CreateProcessArgs) (*ThreadGroup, ThreadID,
	defer k.extMu.Unlock()
	log.Infof("EXEC: %v", args.Argv)

-	if k.mounts == nil {
-		return nil, 0, fmt.Errorf("no kernel MountNamespace")
-	}
-
	// Grab the mount namespace.
	mounts := args.MountNamespace
	if mounts == nil {
-		// If no MountNamespace was configured, then use the kernel's
-		// root mount namespace, with an extra reference that will be
-		// donated to the task.
-		mounts = k.mounts
+		mounts = k.GlobalInit().Leader().MountNamespace()
		mounts.IncRef()
	}

	tg := k.newThreadGroup(mounts, args.PIDNamespace, NewSignalHandlers(), linux.SIGCHLD, args.Limits, k.monotonicClock)
	ctx := args.NewContext(k)

-	// Grab the root directory.
-	root := args.Root
-	if root == nil {
-		// If no Root was configured, then get it from the
-		// MountNamespace.
-		root = mounts.Root()
-	}
+	// Get the root directory from the MountNamespace.
+	root := mounts.Root()
	// The call to newFSContext below will take a reference on root, so we
	// don't need to hold this one.
	defer root.DecRef()
@@ -768,15 +743,23 @@ func (k *Kernel) CreateProcess(args CreateProcessArgs) (*ThreadGroup, ThreadID,
	wd := root // Default.
	if args.WorkingDirectory != "" {
		var err error
-		wd, err = k.mounts.FindInode(ctx, root, nil, args.WorkingDirectory, &remainingTraversals)
+		wd, err = mounts.FindInode(ctx, root, nil, args.WorkingDirectory, &remainingTraversals)
		if err != nil {
			return nil, 0, fmt.Errorf("failed to find initial working directory %q: %v", args.WorkingDirectory, err)
		}
		defer wd.DecRef()
	}

-	if args.Filename == "" {
-		// Was anything provided?
+	// Check which file to start from.
+	switch {
+	case args.Filename != "":
+		// If a filename is given, take that.
+		// Set File to nil so we resolve the path in LoadTaskImage.
+		args.File = nil
+	case args.File != nil:
+		// If File is set, take the File provided directly.
+ default: + // Otherwise look at Argv and see if the first argument is a valid path. if len(args.Argv) == 0 { return nil, 0, fmt.Errorf("no filename or command provided") } @@ -788,7 +771,8 @@ func (k *Kernel) CreateProcess(args CreateProcessArgs) (*ThreadGroup, ThreadID, // Create a fresh task context. remainingTraversals = uint(args.MaxSymlinkTraversals) - tc, se := k.LoadTaskImage(ctx, k.mounts, root, wd, &remainingTraversals, args.Filename, args.Argv, args.Envv, k.featureSet) + + tc, se := k.LoadTaskImage(ctx, mounts, root, wd, &remainingTraversals, args.Filename, args.File, args.Argv, args.Envv, k.featureSet) if se != nil { return nil, 0, errors.New(se.String()) } @@ -1032,20 +1016,6 @@ func (k *Kernel) RootAbstractSocketNamespace() *AbstractSocketNamespace { return k.rootAbstractSocketNamespace } -// RootMountNamespace returns the MountNamespace. -func (k *Kernel) RootMountNamespace() *fs.MountNamespace { - k.extMu.Lock() - defer k.extMu.Unlock() - return k.mounts -} - -// SetRootMountNamespace sets the MountNamespace. -func (k *Kernel) SetRootMountNamespace(mounts *fs.MountNamespace) { - k.extMu.Lock() - defer k.extMu.Unlock() - k.mounts = mounts -} - // NetworkStack returns the network stack. NetworkStack may return nil if no // network stack is available. func (k *Kernel) NetworkStack() inet.Stack { @@ -1168,16 +1138,6 @@ func (k *Kernel) SupervisorContext() context.Context { } } -// EmitUnimplementedEvent emits an UnimplementedSyscall event via the event -// channel. -func (k *Kernel) EmitUnimplementedEvent(ctx context.Context) { - t := TaskFromContext(ctx) - eventchannel.Emit(&uspb.UnimplementedSyscall{ - Tid: int32(t.ThreadID()), - Registers: t.Arch().StateData().Proto(), - }) -} - // SocketEntry represents a socket recorded in Kernel.sockets. It implements // refs.WeakRefUser for sockets stored in the socket table. // @@ -1246,7 +1206,10 @@ func (ctx supervisorContext) Value(key interface{}) interface{} { // The supervisor context is global root. return auth.NewRootCredentials(ctx.k.rootUserNamespace) case fs.CtxRoot: - return ctx.k.mounts.Root() + if ctx.k.globalInit != nil { + return ctx.k.globalInit.mounts.Root() + } + return nil case fs.CtxDirentCacheLimiter: return ctx.k.DirentCacheLimiter case ktime.CtxRealtimeClock: @@ -1272,3 +1235,23 @@ func (ctx supervisorContext) Value(key interface{}) interface{} { return nil } } + +// Rate limits for the number of unimplemented syscall events. +const ( + unimplementedSyscallsMaxRate = 100 // events per second + unimplementedSyscallBurst = 1000 // events +) + +// EmitUnimplementedEvent emits an UnimplementedSyscall event via the event +// channel. +func (k *Kernel) EmitUnimplementedEvent(ctx context.Context) { + k.unimplementedSyscallEmitterOnce.Do(func() { + k.unimplementedSyscallEmitter = eventchannel.RateLimitedEmitterFrom(eventchannel.DefaultEmitter, unimplementedSyscallsMaxRate, unimplementedSyscallBurst) + }) + + t := TaskFromContext(ctx) + k.unimplementedSyscallEmitter.Emit(&uspb.UnimplementedSyscall{ + Tid: int32(t.ThreadID()), + Registers: t.Arch().StateData().Proto(), + }) +} diff --git a/pkg/sentry/kernel/sessions.go b/pkg/sentry/kernel/sessions.go index 81fcd8258..e5f297478 100644 --- a/pkg/sentry/kernel/sessions.go +++ b/pkg/sentry/kernel/sessions.go @@ -47,6 +47,11 @@ type Session struct { // The id is immutable. id SessionID + // foreground is the foreground process group. + // + // This is protected by TaskSet.mu. + foreground *ProcessGroup + // ProcessGroups is a list of process groups in this Session. 
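EmitUnimplementedEvent above combines two patterns: sync.Once for lazy, race-free construction of the emitter on the first unimplemented syscall, and a token-bucket rate limit so a syscall-spamming application cannot flood the event channel. A standalone sketch of the same idea using golang.org/x/time/rate in place of the eventchannel helper; the emitter type and event payload here are stand-ins, not sentry code:

```go
package main

import (
	"fmt"
	"sync"

	"golang.org/x/time/rate"
)

// emitter stands in for the event sink; the real code wraps an
// eventchannel.Emitter.
type emitter struct {
	initOnce sync.Once
	limiter  *rate.Limiter
}

// emit lazily builds the limiter on first use, then drops events that
// exceed the configured rate and burst.
func (e *emitter) emit(event string) {
	e.initOnce.Do(func() {
		e.limiter = rate.NewLimiter(100, 1000) // 100 events/sec, burst 1000
	})
	if !e.limiter.Allow() {
		return // rate limited; silently drop the event
	}
	fmt.Println("emit:", event)
}

func main() {
	var e emitter
	for i := 0; i < 3; i++ {
		e.emit("unimplemented syscall")
	}
}
```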
This is
// protected by TaskSet.mu.
	processGroups processGroupList
@@ -260,12 +265,14 @@ func (pg *ProcessGroup) SendSignal(info *arch.SignalInfo) error {
func (tg *ThreadGroup) CreateSession() error {
	tg.pidns.owner.mu.Lock()
	defer tg.pidns.owner.mu.Unlock()
+	tg.signalHandlers.mu.Lock()
+	defer tg.signalHandlers.mu.Unlock()
	return tg.createSession()
}

// createSession creates a new session for a threadgroup.
//
-// Precondition: callers must hold TaskSet.mu for writing.
+// Precondition: callers must hold TaskSet.mu and the signal mutex for writing.
func (tg *ThreadGroup) createSession() error {
	// Get the ID for this thread in the current namespace.
	id := tg.pidns.tgids[tg]
@@ -346,6 +353,9 @@ func (tg *ThreadGroup) createSession() error {
		ns.processGroups[ProcessGroupID(local)] = pg
	}

+	// Disconnect from the controlling terminal.
+	tg.tty = nil
+
	return nil
}
diff --git a/pkg/sentry/kernel/task_block.go b/pkg/sentry/kernel/task_block.go
index 2a2e6f662..dd69939f9 100644
--- a/pkg/sentry/kernel/task_block.go
+++ b/pkg/sentry/kernel/task_block.go
@@ -15,6 +15,7 @@
package kernel

import (
+	"runtime"
	"time"

	ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
@@ -121,6 +122,17 @@ func (t *Task) block(C <-chan struct{}, timerChan <-chan struct{}) error {
	// Deactivate our address space; we don't need it.
	interrupt := t.SleepStart()

+	// If the request is not completed, but the timer has already expired,
+	// then ensure that we run through a scheduler cycle. This is because
+	// we may see applications relying on timer slack to yield the thread.
+	// For example, they may attempt to sleep for some number of nanoseconds,
+	// and expect that this will actually yield the CPU and sleep for at
+	// least a few microseconds, e.g.:
+	// https://github.com/LMAX-Exchange/disruptor/commit/6ca210f2bcd23f703c479804d583718e16f43c07
+	if len(timerChan) > 0 {
+		runtime.Gosched()
+	}
+
	select {
	case <-C:
		t.SleepFinish(true)
diff --git a/pkg/sentry/kernel/task_context.go b/pkg/sentry/kernel/task_context.go
index 54b1676b0..8639d379f 100644
--- a/pkg/sentry/kernel/task_context.go
+++ b/pkg/sentry/kernel/task_context.go
@@ -140,15 +140,22 @@ func (t *Task) Stack() *arch.Stack {
// * wd: Working directory to look up filename under
// * maxTraversals: maximum number of symlinks to follow
// * filename: path to binary to load
+// * file: an open fs.File object of the binary to load. If set,
+//   file will be loaded and not filename.
// * argv: Binary argv
// * envv: Binary envv
// * fs: Binary FeatureSet
-func (k *Kernel) LoadTaskImage(ctx context.Context, mounts *fs.MountNamespace, root, wd *fs.Dirent, maxTraversals *uint, filename string, argv, envv []string, fs *cpuid.FeatureSet) (*TaskContext, *syserr.Error) {
+func (k *Kernel) LoadTaskImage(ctx context.Context, mounts *fs.MountNamespace, root, wd *fs.Dirent, maxTraversals *uint, filename string, file *fs.File, argv, envv []string, fs *cpuid.FeatureSet) (*TaskContext, *syserr.Error) {
+	// If File is not nil, we should load that instead of resolving filename.
+	if file != nil {
+		filename = file.MappedName(ctx)
+	}
+
	// Prepare a new user address space to load into.
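The block() change above peeks at the buffered timer channel without blocking (len on a buffered channel never blocks) and forces a scheduler pass when the timer has already fired. A minimal sketch of that idiom with time.Timer, whose channel has capacity 1; sleepAtLeast is a hypothetical wrapper, not sentry code:

```go
package main

import (
	"fmt"
	"runtime"
	"time"
)

func sleepAtLeast(d time.Duration) {
	timer := time.NewTimer(d)
	defer timer.Stop()

	// If the timer has already expired, its buffered channel is non-empty.
	// Still force a trip through the scheduler: callers of tiny sleeps are
	// often relying on the yield, not the exact duration.
	if len(timer.C) > 0 {
		runtime.Gosched()
	}

	<-timer.C
}

func main() {
	start := time.Now()
	sleepAtLeast(time.Nanosecond)
	fmt.Println("slept", time.Since(start))
}
```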
m := mm.NewMemoryManager(k, k) defer m.DecUsers(ctx) - os, ac, name, err := loader.Load(ctx, m, mounts, root, wd, maxTraversals, fs, filename, argv, envv, k.extraAuxv, k.vdso) + os, ac, name, err := loader.Load(ctx, m, mounts, root, wd, maxTraversals, fs, filename, file, argv, envv, k.extraAuxv, k.vdso) if err != nil { return nil, err } diff --git a/pkg/sentry/kernel/task_start.go b/pkg/sentry/kernel/task_start.go index d60cd62c7..ae6fc4025 100644 --- a/pkg/sentry/kernel/task_start.go +++ b/pkg/sentry/kernel/task_start.go @@ -172,9 +172,10 @@ func (ts *TaskSet) newTask(cfg *TaskConfig) (*Task, error) { if parentPG := tg.parentPG(); parentPG == nil { tg.createSession() } else { - // Inherit the process group. + // Inherit the process group and terminal. parentPG.incRefWithParent(parentPG) tg.processGroup = parentPG + tg.tty = t.parent.tg.tty } } tg.tasks.PushBack(t) diff --git a/pkg/sentry/kernel/thread_group.go b/pkg/sentry/kernel/thread_group.go index 2a97e3e8e..0eef24bfb 100644 --- a/pkg/sentry/kernel/thread_group.go +++ b/pkg/sentry/kernel/thread_group.go @@ -19,10 +19,13 @@ import ( "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/limits" "gvisor.dev/gvisor/pkg/sentry/usage" + "gvisor.dev/gvisor/pkg/syserror" ) // A ThreadGroup is a logical grouping of tasks that has widespread @@ -245,6 +248,12 @@ type ThreadGroup struct { // // mounts is immutable. mounts *fs.MountNamespace + + // tty is the thread group's controlling terminal. If nil, there is no + // controlling terminal. + // + // tty is protected by the signal mutex. + tty *TTY } // newThreadGroup returns a new, empty thread group in PID namespace ns. The @@ -324,6 +333,176 @@ func (tg *ThreadGroup) forEachChildThreadGroupLocked(fn func(*ThreadGroup)) { } } +// SetControllingTTY sets tty as the controlling terminal of tg. +func (tg *ThreadGroup) SetControllingTTY(tty *TTY, arg int32) error { + tty.mu.Lock() + defer tty.mu.Unlock() + + // We might be asked to set the controlling terminal of multiple + // processes, so we lock both the TaskSet and SignalHandlers. + tg.pidns.owner.mu.Lock() + defer tg.pidns.owner.mu.Unlock() + tg.signalHandlers.mu.Lock() + defer tg.signalHandlers.mu.Unlock() + + // "The calling process must be a session leader and not have a + // controlling terminal already." - tty_ioctl(4) + if tg.processGroup.session.leader != tg || tg.tty != nil { + return syserror.EINVAL + } + + // "If this terminal is already the controlling terminal of a different + // session group, then the ioctl fails with EPERM, unless the caller + // has the CAP_SYS_ADMIN capability and arg equals 1, in which case the + // terminal is stolen, and all processes that had it as controlling + // terminal lose it." - tty_ioctl(4) + if tty.tg != nil && tg.processGroup.session != tty.tg.processGroup.session { + if !auth.CredentialsFromContext(tg.leader).HasCapability(linux.CAP_SYS_ADMIN) || arg != 1 { + return syserror.EPERM + } + // Steal the TTY away. Unlike TIOCNOTTY, don't send signals. + for othertg := range tg.pidns.owner.Root.tgids { + // This won't deadlock by locking tg.signalHandlers + // because at this point: + // - We only lock signalHandlers if it's in the same + // session as the tty's controlling thread group. + // - We know that the calling thread group is not in + // the same session as the tty's controlling thread + // group. 
+ if othertg.processGroup.session == tty.tg.processGroup.session { + othertg.signalHandlers.mu.Lock() + othertg.tty = nil + othertg.signalHandlers.mu.Unlock() + } + } + } + + // Set the controlling terminal and foreground process group. + tg.tty = tty + tg.processGroup.session.foreground = tg.processGroup + // Set this as the controlling process of the terminal. + tty.tg = tg + + return nil +} + +// ReleaseControllingTTY gives up tty as the controlling tty of tg. +func (tg *ThreadGroup) ReleaseControllingTTY(tty *TTY) error { + tty.mu.Lock() + defer tty.mu.Unlock() + + // We might be asked to set the controlling terminal of multiple + // processes, so we lock both the TaskSet and SignalHandlers. + tg.pidns.owner.mu.RLock() + defer tg.pidns.owner.mu.RUnlock() + + // Just below, we may re-lock signalHandlers in order to send signals. + // Thus we can't defer Unlock here. + tg.signalHandlers.mu.Lock() + + if tg.tty == nil || tg.tty != tty { + tg.signalHandlers.mu.Unlock() + return syserror.ENOTTY + } + + // "If the process was session leader, then send SIGHUP and SIGCONT to + // the foreground process group and all processes in the current + // session lose their controlling terminal." - tty_ioctl(4) + // Remove tty as the controlling tty for each process in the session, + // then send them SIGHUP and SIGCONT. + + // If we're not the session leader, we don't have to do much. + if tty.tg != tg { + tg.tty = nil + tg.signalHandlers.mu.Unlock() + return nil + } + + tg.signalHandlers.mu.Unlock() + + // We're the session leader. SIGHUP and SIGCONT the foreground process + // group and remove all controlling terminals in the session. + var lastErr error + for othertg := range tg.pidns.owner.Root.tgids { + if othertg.processGroup.session == tg.processGroup.session { + othertg.signalHandlers.mu.Lock() + othertg.tty = nil + if othertg.processGroup == tg.processGroup.session.foreground { + if err := othertg.leader.sendSignalLocked(&arch.SignalInfo{Signo: int32(linux.SIGHUP)}, true /* group */); err != nil { + lastErr = err + } + if err := othertg.leader.sendSignalLocked(&arch.SignalInfo{Signo: int32(linux.SIGCONT)}, true /* group */); err != nil { + lastErr = err + } + } + othertg.signalHandlers.mu.Unlock() + } + } + + return lastErr +} + +// ForegroundProcessGroup returns the process group ID of the foreground +// process group. +func (tg *ThreadGroup) ForegroundProcessGroup(tty *TTY) (int32, error) { + tty.mu.Lock() + defer tty.mu.Unlock() + + tg.pidns.owner.mu.Lock() + defer tg.pidns.owner.mu.Unlock() + tg.signalHandlers.mu.Lock() + defer tg.signalHandlers.mu.Unlock() + + // "When fd does not refer to the controlling terminal of the calling + // process, -1 is returned" - tcgetpgrp(3) + if tg.tty != tty { + return -1, syserror.ENOTTY + } + + return int32(tg.processGroup.session.foreground.id), nil +} + +// SetForegroundProcessGroup sets the foreground process group of tty to pgid. +func (tg *ThreadGroup) SetForegroundProcessGroup(tty *TTY, pgid ProcessGroupID) (int32, error) { + tty.mu.Lock() + defer tty.mu.Unlock() + + tg.pidns.owner.mu.Lock() + defer tg.pidns.owner.mu.Unlock() + tg.signalHandlers.mu.Lock() + defer tg.signalHandlers.mu.Unlock() + + // TODO(b/129283598): "If tcsetpgrp() is called by a member of a + // background process group in its session, and the calling process is + // not blocking or ignoring SIGTTOU, a SIGTTOU signal is sent to all + // members of this background process group." + + // tty must be the controlling terminal. 
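For reference, the semantics implemented by ForegroundProcessGroup and SetForegroundProcessGroup here are those of the TIOCGPGRP/TIOCSPGRP ioctls behind tcgetpgrp(3) and tcsetpgrp(3). A userspace sketch using raw ioctls; the error mapping matches the kernel code above (ENOTTY for a non-controlling terminal, EINVAL for a negative pgid, ESRCH for an empty group, EPERM for a foreign session):

package main

import (
	"fmt"
	"syscall"
	"unsafe"
)

// tcgetpgrp returns the foreground process group of the terminal on fd.
func tcgetpgrp(fd uintptr) (int32, error) {
	var pgid int32
	if _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, fd,
		syscall.TIOCGPGRP, uintptr(unsafe.Pointer(&pgid))); errno != 0 {
		return -1, errno // ENOTTY if fd is not our controlling terminal.
	}
	return pgid, nil
}

// tcsetpgrp sets the foreground process group of the terminal on fd.
func tcsetpgrp(fd uintptr, pgid int32) error {
	if _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, fd,
		syscall.TIOCSPGRP, uintptr(unsafe.Pointer(&pgid))); errno != 0 {
		return errno // EPERM, ESRCH, or EINVAL, as enforced above.
	}
	return nil
}

func main() {
	pgid, err := tcgetpgrp(0) // stdin
	fmt.Println(pgid, err)
}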
+ if tg.tty != tty { + return -1, syserror.ENOTTY + } + + // pgid must be positive. + if pgid < 0 { + return -1, syserror.EINVAL + } + + // pg must not be empty. Empty process groups are removed from their + // pid namespaces. + pg, ok := tg.pidns.processGroups[pgid] + if !ok { + return -1, syserror.ESRCH + } + + // pg must be part of this process's session. + if tg.processGroup.session != pg.session { + return -1, syserror.EPERM + } + + tg.processGroup.session.foreground.id = pgid + return 0, nil +} + // itimerRealListener implements ktime.Listener for ITIMER_REAL expirations. // // +stateify savable diff --git a/pkg/sentry/kernel/tty.go b/pkg/sentry/kernel/tty.go new file mode 100644 index 000000000..34f84487a --- /dev/null +++ b/pkg/sentry/kernel/tty.go @@ -0,0 +1,28 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package kernel + +import "sync" + +// TTY defines the relationship between a thread group and its controlling +// terminal. +// +// +stateify savable +type TTY struct { + mu sync.Mutex `state:"nosave"` + + // tg is protected by mu. + tg *ThreadGroup +} diff --git a/pkg/sentry/loader/elf.go b/pkg/sentry/loader/elf.go index bc5b841fb..ba9c9ce12 100644 --- a/pkg/sentry/loader/elf.go +++ b/pkg/sentry/loader/elf.go @@ -464,7 +464,7 @@ func loadParsedELF(ctx context.Context, m *mm.MemoryManager, f *fs.File, info el // base address big enough to fit all segments, so we first create a // mapping for the total size just to find a region that is big enough. // - // It is safe to unmap it immediately with racing with another mapping + // It is safe to unmap it immediately without racing with another mapping // because we are the only one in control of the MemoryManager. // // Note that the vaddr of the first PT_LOAD segment is ignored when diff --git a/pkg/sentry/loader/loader.go b/pkg/sentry/loader/loader.go index baa12d9a0..f6f1ae762 100644 --- a/pkg/sentry/loader/loader.go +++ b/pkg/sentry/loader/loader.go @@ -67,8 +67,64 @@ func openPath(ctx context.Context, mm *fs.MountNamespace, root, wd *fs.Dirent, m if err != nil { return nil, nil, err } + + // openFile takes its own reference to the Dirent, so release this one. defer d.DecRef() + return openFile(ctx, nil, d, name) +} + +// openFile performs checks on a file to be executed. If provided a *fs.File, +// openFile takes that file's Dirent and performs checks on it. If provided a +// *fs.Dirent and not a *fs.File, it creates a *fs.File object from the Dirent's +// Inode and performs checks on that. +// +// openFile returns an *fs.Dirent and an *fs.File, and the caller takes ownership +// of both. +// +// "dirent" and "file" must not both be nil; whichever is provided must point to a readable, executable, regular file. +func openFile(ctx context.Context, file *fs.File, dirent *fs.Dirent, name string) (*fs.Dirent, *fs.File, error) { + // file and dirent must not both be nil.
+ if dirent == nil && file == nil { + ctx.Infof("dirent and file cannot both be nil.") + return nil, nil, syserror.ENOENT + } + + if file != nil { + dirent = file.Dirent + } + + // Perform permission checks on the file. + if err := checkFile(ctx, dirent, name); err != nil { + return nil, nil, err + } + + if file == nil { + var ferr error + if file, ferr = dirent.Inode.GetFile(ctx, dirent, fs.FileFlags{Read: true}); ferr != nil { + return nil, nil, ferr + } + } else { + // GetFile takes a reference to the file it creates, so take an extra + // reference here to match it, since this file was passed in already open. + file.IncRef() + } + + // We must be able to read at arbitrary offsets. + if !file.Flags().Pread { + file.DecRef() + ctx.Infof("%s cannot be read at an offset: %+v", file.MappedName(ctx), file.Flags()) + return nil, nil, syserror.EACCES + } + + // Grab a reference for the caller. + dirent.IncRef() + return dirent, file, nil +} + +// checkFile performs permission checks on binaries opened by openPath +// and openFile. +func checkFile(ctx context.Context, d *fs.Dirent, name string) error { perms := fs.PermMask{ // TODO(gvisor.dev/issue/160): Linux requires only execute // permission, not read. However, our backing filesystems may @@ -80,7 +136,7 @@ func openPath(ctx context.Context, mm *fs.MountNamespace, root, wd *fs.Dirent, m Execute: true, } if err := d.Inode.CheckPermission(ctx, perms); err != nil { - return nil, nil, err + return err } // If they claim it's a directory, then make sure. @@ -88,31 +144,17 @@ func openPath(ctx context.Context, mm *fs.MountNamespace, root, wd *fs.Dirent, m // N.B. we reject directories below, but we must first reject // non-directories passed as directories. if len(name) > 0 && name[len(name)-1] == '/' && !fs.IsDir(d.Inode.StableAttr) { - return nil, nil, syserror.ENOTDIR + return syserror.ENOTDIR } // No exec-ing directories, pipes, etc! if !fs.IsRegular(d.Inode.StableAttr) { ctx.Infof("%s is not regular: %v", name, d.Inode.StableAttr) - return nil, nil, syserror.EACCES + return syserror.EACCES } - // Create a new file. - file, err := d.Inode.GetFile(ctx, d, fs.FileFlags{Read: true}) - if err != nil { - return nil, nil, err - } + return nil - // We must be able to read at arbitrary offsets. - if !file.Flags().Pread { - file.DecRef() - ctx.Infof("%s cannot be read at an offset: %+v", name, file.Flags()) - return nil, nil, syserror.EACCES - } - - // Grab a reference for the caller. - d.IncRef() - return d, file, nil } // allocStack allocates and maps a stack into any available part of the address space. @@ -131,16 +173,30 @@ const ( maxLoaderAttempts = 6 ) -// loadPath resolves filename to a binary and loads it. +// loadBinary loads a binary that is pointed to by "file". If nil, the path +// "filename" is resolved and loaded.
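loadBinary (continued below) retries up to maxLoaderAttempts times, swapping in the interpreter when it finds a script. A simplified sketch of that retry loop under stated assumptions: resolveBinary is a hypothetical helper, and the real code dispatches on the same ELF and "#!" magic but loads rather than merely resolving:

package main

import (
	"bufio"
	"bytes"
	"errors"
	"fmt"
	"os"
	"strings"
)

const maxAttempts = 6 // mirrors maxLoaderAttempts

func resolveBinary(path string) (string, error) {
	for i := 0; i < maxAttempts; i++ {
		f, err := os.Open(path)
		if err != nil {
			return "", err
		}
		hdr := make([]byte, 4)
		n, _ := f.Read(hdr)

		switch {
		case n >= 4 && bytes.Equal(hdr[:4], []byte("\x7fELF")):
			f.Close()
			return path, nil // Terminal case: a real executable.
		case n >= 2 && bytes.Equal(hdr[:2], []byte("#!")):
			// Interpreter script: parse the "#!" line and loop on the
			// interpreter, like the interpreter-script branch above.
			f.Seek(0, 0)
			line, _ := bufio.NewReader(f).ReadString('\n')
			f.Close()
			fields := strings.Fields(strings.TrimPrefix(line, "#!"))
			if len(fields) == 0 {
				return "", errors.New("empty interpreter line")
			}
			path = fields[0]
		default:
			f.Close()
			return "", errors.New("unknown binary format")
		}
	}
	// Matches the syserror.ELOOP return after maxLoaderAttempts.
	return "", errors.New("too many interpreter levels")
}

func main() {
	fmt.Println(resolveBinary("/bin/sh"))
}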
// // It returns: // * loadedELF, description of the loaded binary // * arch.Context matching the binary arch // * fs.Dirent of the binary file // * Possibly updated argv -func loadPath(ctx context.Context, m *mm.MemoryManager, mounts *fs.MountNamespace, root, wd *fs.Dirent, remainingTraversals *uint, fs *cpuid.FeatureSet, filename string, argv []string) (loadedELF, arch.Context, *fs.Dirent, []string, error) { +func loadBinary(ctx context.Context, m *mm.MemoryManager, mounts *fs.MountNamespace, root, wd *fs.Dirent, remainingTraversals *uint, features *cpuid.FeatureSet, filename string, passedFile *fs.File, argv []string) (loadedELF, arch.Context, *fs.Dirent, []string, error) { for i := 0; i < maxLoaderAttempts; i++ { - d, f, err := openPath(ctx, mounts, root, wd, remainingTraversals, filename) + var ( + d *fs.Dirent + f *fs.File + err error + ) + if passedFile == nil { + d, f, err = openPath(ctx, mounts, root, wd, remainingTraversals, filename) + + } else { + d, f, err = openFile(ctx, passedFile, nil, "") + // Set to nil in case we loop on an interpreter script. + passedFile = nil + } + if err != nil { ctx.Infof("Error opening %s: %v", filename, err) return loadedELF{}, nil, nil, nil, err @@ -165,7 +221,7 @@ switch { case bytes.Equal(hdr[:], []byte(elfMagic)): - loaded, ac, err := loadELF(ctx, m, mounts, root, wd, remainingTraversals, fs, f) + loaded, ac, err := loadELF(ctx, m, mounts, root, wd, remainingTraversals, features, f) if err != nil { ctx.Infof("Error loading ELF: %v", err) return loadedELF{}, nil, nil, nil, err @@ -190,7 +246,8 @@ return loadedELF{}, nil, nil, nil, syserror.ELOOP } -// Load loads filename into a MemoryManager. +// Load loads "file" into a MemoryManager. If file is nil, the path "filename" +// is resolved and loaded instead. // // If Load returns ErrSwitchFile it should be called again with the returned // path and argv. // // Preconditions: // * The Task MemoryManager is empty. // * Load is called on the Task goroutine. -func Load(ctx context.Context, m *mm.MemoryManager, mounts *fs.MountNamespace, root, wd *fs.Dirent, maxTraversals *uint, fs *cpuid.FeatureSet, filename string, argv, envv []string, extraAuxv []arch.AuxEntry, vdso *VDSO) (abi.OS, arch.Context, string, *syserr.Error) { +func Load(ctx context.Context, m *mm.MemoryManager, mounts *fs.MountNamespace, root, wd *fs.Dirent, maxTraversals *uint, fs *cpuid.FeatureSet, filename string, file *fs.File, argv, envv []string, extraAuxv []arch.AuxEntry, vdso *VDSO) (abi.OS, arch.Context, string, *syserr.Error) { // Load the binary itself. - loaded, ac, d, argv, err := loadPath(ctx, m, mounts, root, wd, maxTraversals, fs, filename, argv) + loaded, ac, d, argv, err := loadBinary(ctx, m, mounts, root, wd, maxTraversals, fs, filename, file, argv) if err != nil { return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("Failed to load %s: %v", filename, err), syserr.FromError(err).ToLinux()) } diff --git a/pkg/sentry/mm/procfs.go b/pkg/sentry/mm/procfs.go index a8819aa84..8c2246bb4 100644 --- a/pkg/sentry/mm/procfs.go +++ b/pkg/sentry/mm/procfs.go @@ -58,6 +58,34 @@ func (mm *MemoryManager) NeedsUpdate(generation int64) bool { return true } +// ReadMapsDataInto is called by fsimpl/proc.mapsData.Generate to +// implement /proc/[pid]/maps.
+func (mm *MemoryManager) ReadMapsDataInto(ctx context.Context, buf *bytes.Buffer) { + mm.mappingMu.RLock() + defer mm.mappingMu.RUnlock() + var start usermem.Addr + + for vseg := mm.vmas.LowerBoundSegment(start); vseg.Ok(); vseg = vseg.NextSegment() { + // FIXME(b/30793614): If we use a usermem.Addr for the handle, we get + // "panic: autosave error: type usermem.Addr is not registered". + mm.appendVMAMapsEntryLocked(ctx, vseg, buf) + } + + // We always emulate vsyscall, so advertise it here. Everything about a + // vsyscall region is static, so just hard code the maps entry since we + // don't have a real vma backing it. The vsyscall region is at the end of + // the virtual address space so nothing should be mapped after it (if + // something is really mapped in the tiny ~10 MiB segment afterwards, we'll + // get the sorting on the maps file wrong at worst; but that's not possible + // on any current platform). + // + // Artificially adjust the seqfile handle so we only output the vsyscall entry once. + if start != vsyscallEnd { + // FIXME(b/30793614): Can't get a pointer to constant vsyscallEnd. + buf.WriteString(vsyscallMapsEntry) + } +} + // ReadMapsSeqFileData is called by fs/proc.mapsData.ReadSeqFileData to // implement /proc/[pid]/maps. func (mm *MemoryManager) ReadMapsSeqFileData(ctx context.Context, handle seqfile.SeqHandle) ([]seqfile.SeqData, int64) { @@ -151,6 +179,27 @@ func (mm *MemoryManager) appendVMAMapsEntryLocked(ctx context.Context, vseg vmaI b.WriteString("\n") } +// ReadSmapsDataInto is called by fsimpl/proc.smapsData.Generate to +// implement /proc/[pid]/smaps. +func (mm *MemoryManager) ReadSmapsDataInto(ctx context.Context, buf *bytes.Buffer) { + mm.mappingMu.RLock() + defer mm.mappingMu.RUnlock() + var start usermem.Addr + + for vseg := mm.vmas.LowerBoundSegment(start); vseg.Ok(); vseg = vseg.NextSegment() { + // FIXME(b/30793614): If we use a usermem.Addr for the handle, we get + // "panic: autosave error: type usermem.Addr is not registered". + mm.vmaSmapsEntryIntoLocked(ctx, vseg, buf) + } + + // We always emulate vsyscall, so advertise it here. See + // ReadMapsSeqFileData for additional commentary. + if start != vsyscallEnd { + // FIXME(b/30793614): Can't get a pointer to constant vsyscallEnd. + buf.WriteString(vsyscallSmapsEntry) + } +} + // ReadSmapsSeqFileData is called by fs/proc.smapsData.ReadSeqFileData to // implement /proc/[pid]/smaps. func (mm *MemoryManager) ReadSmapsSeqFileData(ctx context.Context, handle seqfile.SeqHandle) ([]seqfile.SeqData, int64) { @@ -190,7 +239,12 @@ func (mm *MemoryManager) ReadSmapsSeqFil // Preconditions: mm.mappingMu must be locked. func (mm *MemoryManager) vmaSmapsEntryLocked(ctx context.Context, vseg vmaIterator) []byte { var b bytes.Buffer - mm.appendVMAMapsEntryLocked(ctx, vseg, &b) + mm.vmaSmapsEntryIntoLocked(ctx, vseg, &b) + return b.Bytes() +} + +func (mm *MemoryManager) vmaSmapsEntryIntoLocked(ctx context.Context, vseg vmaIterator, b *bytes.Buffer) { + mm.appendVMAMapsEntryLocked(ctx, vseg, b) vma := vseg.ValuePtr() // We take mm.activeMu here in each call to vmaSmapsEntryLocked, instead of @@ -211,40 +265,40 @@ } mm.activeMu.RUnlock() - fmt.Fprintf(&b, "Size: %8d kB\n", vseg.Range().Length()/1024) - fmt.Fprintf(&b, "Rss: %8d kB\n", rss/1024) + fmt.Fprintf(b, "Size: %8d kB\n", vseg.Range().Length()/1024) + fmt.Fprintf(b, "Rss: %8d kB\n", rss/1024) // Currently we report PSS = RSS, i.e.
we pretend each page mapped by a pma // is only mapped by that pma. This avoids having to query memmap.Mappables // for reference count information on each page. As a corollary, all pages // are accounted as "private" whether or not the vma is private; compare // Linux's fs/proc/task_mmu.c:smaps_account(). - fmt.Fprintf(&b, "Pss: %8d kB\n", rss/1024) - fmt.Fprintf(&b, "Shared_Clean: %8d kB\n", 0) - fmt.Fprintf(&b, "Shared_Dirty: %8d kB\n", 0) + fmt.Fprintf(b, "Pss: %8d kB\n", rss/1024) + fmt.Fprintf(b, "Shared_Clean: %8d kB\n", 0) + fmt.Fprintf(b, "Shared_Dirty: %8d kB\n", 0) // Pretend that all pages are dirty if the vma is writable, and clean otherwise. clean := rss if vma.effectivePerms.Write { clean = 0 } - fmt.Fprintf(&b, "Private_Clean: %8d kB\n", clean/1024) - fmt.Fprintf(&b, "Private_Dirty: %8d kB\n", (rss-clean)/1024) + fmt.Fprintf(b, "Private_Clean: %8d kB\n", clean/1024) + fmt.Fprintf(b, "Private_Dirty: %8d kB\n", (rss-clean)/1024) // Pretend that all pages are "referenced" (recently touched). - fmt.Fprintf(&b, "Referenced: %8d kB\n", rss/1024) - fmt.Fprintf(&b, "Anonymous: %8d kB\n", anon/1024) + fmt.Fprintf(b, "Referenced: %8d kB\n", rss/1024) + fmt.Fprintf(b, "Anonymous: %8d kB\n", anon/1024) // Hugepages (hugetlb and THP) are not implemented. - fmt.Fprintf(&b, "AnonHugePages: %8d kB\n", 0) - fmt.Fprintf(&b, "Shared_Hugetlb: %8d kB\n", 0) - fmt.Fprintf(&b, "Private_Hugetlb: %7d kB\n", 0) + fmt.Fprintf(b, "AnonHugePages: %8d kB\n", 0) + fmt.Fprintf(b, "Shared_Hugetlb: %8d kB\n", 0) + fmt.Fprintf(b, "Private_Hugetlb: %7d kB\n", 0) // Swap is not implemented. - fmt.Fprintf(&b, "Swap: %8d kB\n", 0) - fmt.Fprintf(&b, "SwapPss: %8d kB\n", 0) - fmt.Fprintf(&b, "KernelPageSize: %8d kB\n", usermem.PageSize/1024) - fmt.Fprintf(&b, "MMUPageSize: %8d kB\n", usermem.PageSize/1024) + fmt.Fprintf(b, "Swap: %8d kB\n", 0) + fmt.Fprintf(b, "SwapPss: %8d kB\n", 0) + fmt.Fprintf(b, "KernelPageSize: %8d kB\n", usermem.PageSize/1024) + fmt.Fprintf(b, "MMUPageSize: %8d kB\n", usermem.PageSize/1024) locked := rss if vma.mlockMode == memmap.MLockNone { locked = 0 } - fmt.Fprintf(&b, "Locked: %8d kB\n", locked/1024) + fmt.Fprintf(b, "Locked: %8d kB\n", locked/1024) b.WriteString("VmFlags: ") if vma.realPerms.Read { @@ -284,6 +338,4 @@ func (mm *MemoryManager) vmaSmapsEntryLocked(ctx context.Context, vseg vmaIterat b.WriteString("ac ") } b.WriteString("\n") - - return b.Bytes() } diff --git a/pkg/sentry/pgalloc/pgalloc.go b/pkg/sentry/pgalloc/pgalloc.go index 8bd3e885d..f7f7298c4 100644 --- a/pkg/sentry/pgalloc/pgalloc.go +++ b/pkg/sentry/pgalloc/pgalloc.go @@ -285,7 +285,10 @@ func NewMemoryFile(file *os.File, opts MemoryFileOpts) (*MemoryFile, error) { switch opts.DelayedEviction { case DelayedEvictionDefault: opts.DelayedEviction = DelayedEvictionEnabled - case DelayedEvictionDisabled, DelayedEvictionEnabled, DelayedEvictionManual: + case DelayedEvictionDisabled, DelayedEvictionManual: + opts.UseHostMemcgPressure = false + case DelayedEvictionEnabled: + // ok default: return nil, fmt.Errorf("invalid MemoryFileOpts.DelayedEviction: %v", opts.DelayedEviction) } @@ -777,6 +780,14 @@ func (f *MemoryFile) MarkAllUnevictable(user EvictableMemoryUser) { } } +// ShouldCacheEvictable returns true if f is meaningfully delaying evictions of +// evictable memory, such that it may be advantageous to cache data in +// evictable memory. The value returned by ShouldCacheEvictable may change +// between calls. 
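The smaps hunk above switches every fmt.Fprintf from a locally owned &b to a caller-supplied *bytes.Buffer, which lets both the old seqfile path and the new ReadSmapsDataInto path share one entry writer. The pattern in isolation, with hypothetical names:

package main

import (
	"bytes"
	"fmt"
)

// writeSizeEntry writes one smaps-style line into the caller's buffer.
// Threading *bytes.Buffer through avoids one []byte allocation and copy
// per vma compared with building and returning a fresh slice each time.
func writeSizeEntry(b *bytes.Buffer, name string, kb uint64) {
	fmt.Fprintf(b, "%-16s%8d kB\n", name+":", kb)
}

func main() {
	var b bytes.Buffer
	writeSizeEntry(&b, "Size", 8)
	writeSizeEntry(&b, "Rss", 4)
	fmt.Print(b.String())
}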
+func (f *MemoryFile) ShouldCacheEvictable() bool { + return f.opts.DelayedEviction == DelayedEvictionManual || f.opts.UseHostMemcgPressure +} + // UpdateUsage ensures that the memory usage statistics in // usage.MemoryAccounting are up to date. func (f *MemoryFile) UpdateUsage() error { diff --git a/pkg/sentry/platform/ptrace/BUILD b/pkg/sentry/platform/ptrace/BUILD index 1b6c54e96..ebcc8c098 100644 --- a/pkg/sentry/platform/ptrace/BUILD +++ b/pkg/sentry/platform/ptrace/BUILD @@ -7,13 +7,17 @@ go_library( srcs = [ "filters.go", "ptrace.go", + "ptrace_amd64.go", + "ptrace_arm64.go", "ptrace_unsafe.go", "stub_amd64.s", + "stub_arm64.s", "stub_unsafe.go", "subprocess.go", "subprocess_amd64.go", + "subprocess_arm64.go", "subprocess_linux.go", - "subprocess_linux_amd64_unsafe.go", + "subprocess_linux_unsafe.go", "subprocess_unsafe.go", ], importpath = "gvisor.dev/gvisor/pkg/sentry/platform/ptrace", diff --git a/pkg/sentry/platform/ptrace/ptrace.go b/pkg/sentry/platform/ptrace/ptrace.go index 6fd30ed25..7b120a15d 100644 --- a/pkg/sentry/platform/ptrace/ptrace.go +++ b/pkg/sentry/platform/ptrace/ptrace.go @@ -60,7 +60,7 @@ var ( // maximum user address. This is valid only after a call to stubInit. // // We attempt to link the stub here, and adjust downward as needed. - stubStart uintptr = 0x7fffffff0000 + stubStart uintptr = stubInitAddress // stubEnd is the first byte past the end of the stub, as with // stubStart this is valid only after a call to stubInit. diff --git a/pkg/flipcall/endpoint_futex.go b/pkg/sentry/platform/ptrace/ptrace_amd64.go index 5cab02b1d..db0212538 100644 --- a/pkg/flipcall/endpoint_futex.go +++ b/pkg/sentry/platform/ptrace/ptrace_amd64.go @@ -12,34 +12,22 @@ // See the License for the specific language governing permissions and // limitations under the License. -package flipcall +package ptrace import ( - "fmt" -) + "syscall" -type endpointControlState struct{} + "gvisor.dev/gvisor/pkg/abi/linux" +) -func (ep *Endpoint) initControlState(ctrlMode ControlMode) error { - if ctrlMode != ControlModeFutex { - return fmt.Errorf("unsupported control mode: %v", ctrlMode) +// fpRegSet returns the GETREGSET/SETREGSET register set type to be used. +func fpRegSet(useXsave bool) uintptr { + if useXsave { + return linux.NT_X86_XSTATE } - return nil -} - -func (ep *Endpoint) doRoundTrip() error { - return ep.doFutexRoundTrip() -} - -func (ep *Endpoint) doWaitFirst() error { - return ep.doFutexWaitFirst() -} - -func (ep *Endpoint) doNotifyLast() error { - return ep.doFutexNotifyLast() + return linux.NT_PRFPREG } -// Preconditions: ep.isShutdown() == true. -func (ep *Endpoint) interruptForShutdown() { - ep.doFutexInterruptForShutdown() +func stackPointer(r *syscall.PtraceRegs) uintptr { + return uintptr(r.Rsp) } diff --git a/pkg/sentry/platform/ptrace/ptrace_arm64.go b/pkg/sentry/platform/ptrace/ptrace_arm64.go new file mode 100644 index 000000000..4db28c534 --- /dev/null +++ b/pkg/sentry/platform/ptrace/ptrace_arm64.go @@ -0,0 +1,30 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +package ptrace + +import ( + "syscall" + + "gvisor.dev/gvisor/pkg/abi/linux" +) + +// fpRegSet returns the GETREGSET/SETREGSET register set type to be used. +func fpRegSet(_ bool) uintptr { + return linux.NT_PRFPREG +} + +func stackPointer(r *syscall.PtraceRegs) uintptr { + return uintptr(r.Sp) +} diff --git a/pkg/sentry/platform/ptrace/ptrace_unsafe.go b/pkg/sentry/platform/ptrace/ptrace_unsafe.go index 2706039a5..47957bb3b 100644 --- a/pkg/sentry/platform/ptrace/ptrace_unsafe.go +++ b/pkg/sentry/platform/ptrace/ptrace_unsafe.go @@ -18,37 +18,23 @@ import ( "syscall" "unsafe" + "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/usermem" ) -// GETREGSET/SETREGSET register set types. -// -// See include/uapi/linux/elf.h. -const ( - // _NT_PRFPREG is for x86 floating-point state without using xsave. - _NT_PRFPREG = 0x2 - - // _NT_X86_XSTATE is for x86 extended state using xsave. - _NT_X86_XSTATE = 0x202 -) - -// fpRegSet returns the GETREGSET/SETREGSET register set type to be used. -func fpRegSet(useXsave bool) uintptr { - if useXsave { - return _NT_X86_XSTATE - } - return _NT_PRFPREG -} - -// getRegs sets the regular register set. +// getRegs gets the general purpose register set. func (t *thread) getRegs(regs *syscall.PtraceRegs) error { + iovec := syscall.Iovec{ + Base: (*byte)(unsafe.Pointer(regs)), + Len: uint64(unsafe.Sizeof(*regs)), + } _, _, errno := syscall.RawSyscall6( syscall.SYS_PTRACE, - syscall.PTRACE_GETREGS, + syscall.PTRACE_GETREGSET, uintptr(t.tid), - 0, - uintptr(unsafe.Pointer(regs)), + linux.NT_PRSTATUS, + uintptr(unsafe.Pointer(&iovec)), 0, 0) if errno != 0 { return errno @@ -56,14 +42,18 @@ func (t *thread) getRegs(regs *syscall.PtraceRegs) error { return nil } -// setRegs sets the regular register set. +// setRegs sets the general purpose register set. func (t *thread) setRegs(regs *syscall.PtraceRegs) error { + iovec := syscall.Iovec{ + Base: (*byte)(unsafe.Pointer(regs)), + Len: uint64(unsafe.Sizeof(*regs)), + } _, _, errno := syscall.RawSyscall6( syscall.SYS_PTRACE, - syscall.PTRACE_SETREGS, + syscall.PTRACE_SETREGSET, uintptr(t.tid), - 0, - uintptr(unsafe.Pointer(regs)), + linux.NT_PRSTATUS, + uintptr(unsafe.Pointer(&iovec)), 0, 0) if errno != 0 { return errno @@ -131,7 +121,7 @@ func (t *thread) getSignalInfo(si *arch.SignalInfo) error { // // Precondition: the OS thread must be locked and own t. func (t *thread) clone() (*thread, error) { - r, ok := usermem.Addr(t.initRegs.Rsp).RoundUp() + r, ok := usermem.Addr(stackPointer(&t.initRegs)).RoundUp() if !ok { return nil, syscall.EINVAL } diff --git a/pkg/sentry/platform/ptrace/stub_arm64.s b/pkg/sentry/platform/ptrace/stub_arm64.s new file mode 100644 index 000000000..2c5e4d5cb --- /dev/null +++ b/pkg/sentry/platform/ptrace/stub_arm64.s @@ -0,0 +1,106 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
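The ptrace_unsafe.go hunk above replaces PTRACE_GETREGS/SETREGS with PTRACE_GETREGSET/SETREGSET, which take a register-set type plus a struct iovec and therefore work on both amd64 and arm64. A standalone sketch of the GETREGSET call; the request constant is defined locally, and the target tid must be a ptrace-stopped tracee of the calling thread:

package main

import (
	"fmt"
	"syscall"
	"unsafe"
)

const (
	ptraceGetRegSet = 0x4204 // PTRACE_GETREGSET
	ntPRStatus      = 0x1    // NT_PRSTATUS, the general purpose register set
)

// getRegs fetches the general purpose registers of a stopped tracee.
func getRegs(tid int) (syscall.PtraceRegs, error) {
	var regs syscall.PtraceRegs
	iov := syscall.Iovec{
		Base: (*byte)(unsafe.Pointer(&regs)),
		Len:  uint64(unsafe.Sizeof(regs)),
	}
	_, _, errno := syscall.RawSyscall6(
		syscall.SYS_PTRACE,
		ptraceGetRegSet,
		uintptr(tid),
		ntPRStatus,
		uintptr(unsafe.Pointer(&iov)),
		0, 0)
	if errno != 0 {
		return regs, errno
	}
	return regs, nil
}

func main() {
	regs, err := getRegs(1234) // hypothetical tracee tid
	fmt.Println(regs, err)
}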
+ +#include "funcdata.h" +#include "textflag.h" + +#define SYS_GETPID 172 +#define SYS_EXIT 93 +#define SYS_KILL 129 +#define SYS_GETPPID 173 +#define SYS_PRCTL 167 + +#define SIGKILL 9 +#define SIGSTOP 19 + +#define PR_SET_PDEATHSIG 1 + +// stub bootstraps the child and sends itself SIGSTOP to wait for attach. +// +// R7 contains the expected PPID. +// +// This should not be used outside the context of a new ptrace child (as the +// function is otherwise a bunch of nonsense). +TEXT ·stub(SB),NOSPLIT,$0 +begin: + // N.B. This loop only executes in the context of a single-threaded + // fork child. + + MOVD $SYS_PRCTL, R8 + MOVD $PR_SET_PDEATHSIG, R0 + MOVD $SIGKILL, R1 + SVC + + CMN $4095, R0 + BCS error + + // If the parent already died before we called PR_SET_PDEATHSIG, then + // we'll have an unexpected PPID. + MOVD $SYS_GETPPID, R8 + SVC + + CMP R0, R7 + BNE parent_dead + + MOVD $SYS_GETPID, R8 + SVC + + CMP $0x0, R0 + BLT error + + // SIGSTOP to wait for attach. + // + // The SVC instruction will be used for future syscall injection by + // thread.syscall. + MOVD $SYS_KILL, R8 + MOVD $SIGSTOP, R1 + SVC + // The tracer may "detach" and/or allow code execution here in three cases: + // + // 1. New (traced) stub threads are explicitly detached by the + // goroutine in newSubprocess. However, they are detached while in + // group-stop, so they do not execute code here. + // + // 2. If a tracer thread exits, it implicitly detaches from the stub, + // potentially allowing code execution here. However, the Go runtime + // never exits individual threads, so this case never occurs. + // + // 3. subprocess.createStub clones a new stub process that is untraced, + // thus executing this code. We set up the PDEATHSIG before SIGSTOPing + // ourselves for attach by the tracer. + // + // R7 has been updated with the expected PPID. + B begin + +error: + // Exit with -errno. + NEG R0, R0 + MOVD $SYS_EXIT, R8 + SVC + HLT + +parent_dead: + MOVD $SYS_EXIT, R8 + MOVD $1, R0 + SVC + HLT + +// stubCall calls the stub function at the given address with the given PPID. +// +// This is a distinct function because stub, above, may be mapped at any +// arbitrary location, and stub has a specific binary API (see above). +TEXT ·stubCall(SB),NOSPLIT,$0-16 + MOVD addr+0(FP), R0 + MOVD pid+8(FP), R7 + B (R0) diff --git a/pkg/sentry/platform/ptrace/subprocess.go b/pkg/sentry/platform/ptrace/subprocess.go index 15e84735e..6bf7cd097 100644 --- a/pkg/sentry/platform/ptrace/subprocess.go +++ b/pkg/sentry/platform/ptrace/subprocess.go @@ -28,6 +28,16 @@ import ( "gvisor.dev/gvisor/pkg/sentry/usermem" ) +// Linux kernel errnos which "should never be seen by user programs", but will +// be revealed to ptrace syscall exit tracing. +// +// These constants are only used in subprocess.go. +const ( + ERESTARTSYS = syscall.Errno(512) + ERESTARTNOINTR = syscall.Errno(513) + ERESTARTNOHAND = syscall.Errno(514) +) + // globalPool exists to solve two distinct problems: // // 1) Subprocesses can't always be killed properly (see Release). @@ -282,7 +292,7 @@ func (t *thread) grabInitRegs() { if err := t.getRegs(&t.initRegs); err != nil { panic(fmt.Sprintf("ptrace get regs failed: %v", err)) } - t.initRegs.Rip -= initRegsRipAdjustment + t.adjustInitRegsRip() } // detach detaches from the thread. @@ -344,6 +354,9 @@ func (t *thread) wait(outcome waitOutcome) syscall.Signal { continue // Spurious stop.
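The arm64 stub above mirrors the amd64 bootstrap: install PDEATHSIG, re-check the PPID to close the race where the parent died first, then SIGSTOP to wait for attach. The same sequence expressed as ordinary userspace Go, purely as illustration; the real stub is assembly precisely because it runs in a raw forked child where the Go runtime cannot be used:

package main

import (
	"os"
	"syscall"
)

// stubInit sketches the stub's bootstrap; expectedPPID plays the role of
// the value passed in R7.
func stubInit(expectedPPID int) {
	// Ask the kernel to SIGKILL us if the tracer dies after this point.
	if _, _, errno := syscall.RawSyscall(syscall.SYS_PRCTL,
		syscall.PR_SET_PDEATHSIG, uintptr(syscall.SIGKILL), 0); errno != 0 {
		os.Exit(1)
	}
	// Close the race: if the parent died before prctl took effect, our
	// PPID has already changed, just like the GETPPID/CMP check above.
	if os.Getppid() != expectedPPID {
		os.Exit(1)
	}
	// Group-stop and wait for the tracer to attach.
	syscall.Kill(os.Getpid(), syscall.SIGSTOP)
}

func main() {
	stubInit(os.Getppid())
}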
} if stopSig == syscall.SIGTRAP { + if status.TrapCause() == syscall.PTRACE_EVENT_EXIT { + t.dumpAndPanic("wait failed: the process exited") + } // Re-encode the trap cause the way it's expected. return stopSig | syscall.Signal(status.TrapCause()<<8) } } diff --git a/pkg/sentry/platform/ptrace/subprocess_amd64.go b/pkg/sentry/platform/ptrace/subprocess_amd64.go index a70512913..4649a94a7 100644 --- a/pkg/sentry/platform/ptrace/subprocess_amd64.go +++ b/pkg/sentry/platform/ptrace/subprocess_amd64.go @@ -28,20 +28,13 @@ const ( // maximumUserAddress is the largest possible user address. maximumUserAddress = 0x7ffffffff000 + // stubInitAddress is the initial attempt link address for the stub. + stubInitAddress = 0x7fffffff0000 + // initRegsRipAdjustment is the size of the syscall instruction. initRegsRipAdjustment = 2 ) -// Linux kernel errnos which "should never be seen by user programs", but will -// be revealed to ptrace syscall exit tracing. -// -// These constants are used in subprocess.go. -const ( - ERESTARTSYS = syscall.Errno(512) - ERESTARTNOINTR = syscall.Errno(513) - ERESTARTNOHAND = syscall.Errno(514) -) - // resetSysemuRegs sets up emulation registers. // // This should be called prior to calling sysemu. @@ -139,3 +132,14 @@ func dumpRegs(regs *syscall.PtraceRegs) string { return m.String() } + +// adjustInitRegsRip adjusts the current RIP register value to +// be just before the system call instruction execution. +func (t *thread) adjustInitRegsRip() { + t.initRegs.Rip -= initRegsRipAdjustment +} + +// Pass the expected PPID to the child via R15 when creating the stub process. +func initChildProcessPPID(initregs *syscall.PtraceRegs, ppid int32) { + initregs.R15 = uint64(ppid) +} diff --git a/pkg/sentry/platform/ptrace/subprocess_arm64.go b/pkg/sentry/platform/ptrace/subprocess_arm64.go new file mode 100644 index 000000000..bec884ba5 --- /dev/null +++ b/pkg/sentry/platform/ptrace/subprocess_arm64.go @@ -0,0 +1,126 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build arm64 + +package ptrace + +import ( + "fmt" + "strings" + "syscall" + + "gvisor.dev/gvisor/pkg/sentry/arch" +) + +const ( + // maximumUserAddress is the largest possible user address. + maximumUserAddress = 0xfffffffff000 + + // stubInitAddress is the initial attempt link address for the stub. + // Only 48-bit VAs are supported currently. + stubInitAddress = 0xffffffff0000 + + // initRegsRipAdjustment is the size of the svc instruction. + initRegsRipAdjustment = 4 ) + +// resetSysemuRegs sets up emulation registers. +// +// This should be called prior to calling sysemu. +func (s *subprocess) resetSysemuRegs(regs *syscall.PtraceRegs) { +} + +// createSyscallRegs sets up syscall registers. +// +// This should be called to generate registers for a system call. +func createSyscallRegs(initRegs *syscall.PtraceRegs, sysno uintptr, args ...arch.SyscallArgument) syscall.PtraceRegs { + // Copy initial registers (Pc, Sp, etc.). + regs := *initRegs + + // Set our syscall number.
+ // r8 holds the syscall number. + // r0-r5 are used to store the parameters. + regs.Regs[8] = uint64(sysno) + if len(args) >= 1 { + regs.Regs[0] = args[0].Uint64() + } + if len(args) >= 2 { + regs.Regs[1] = args[1].Uint64() + } + if len(args) >= 3 { + regs.Regs[2] = args[2].Uint64() + } + if len(args) >= 4 { + regs.Regs[3] = args[3].Uint64() + } + if len(args) >= 5 { + regs.Regs[4] = args[4].Uint64() + } + if len(args) >= 6 { + regs.Regs[5] = args[5].Uint64() + } + + return regs +} + +// isSingleStepping determines if the registers indicate single-stepping. +func isSingleStepping(regs *syscall.PtraceRegs) bool { + // Refer to the ARM SDM D2.12.3: software step state machine + // return (regs.Pstate.SS == 1) && (MDSCR_EL1.SS == 1). + // + // Since the host Linux kernel will set MDSCR_EL1.SS on our behalf + // when we call a single-step ptrace command, we only need to check + // the Pstate.SS bit here. + return (regs.Pstate & arch.ARMTrapFlag) != 0 +} + +// updateSyscallRegs updates registers after finishing sysemu. +func updateSyscallRegs(regs *syscall.PtraceRegs) { + // No special work is necessary. + return +} + +// syscallReturnValue extracts a sensible return from registers. +func syscallReturnValue(regs *syscall.PtraceRegs) (uintptr, error) { + rval := int64(regs.Regs[0]) + if rval < 0 { + return 0, syscall.Errno(-rval) + } + return uintptr(rval), nil +} + +func dumpRegs(regs *syscall.PtraceRegs) string { + var m strings.Builder + + fmt.Fprintf(&m, "Registers:\n") + + for i := 0; i < 31; i++ { + fmt.Fprintf(&m, "\tRegs[%d]\t = %016x\n", i, regs.Regs[i]) + } + fmt.Fprintf(&m, "\tSp\t = %016x\n", regs.Sp) + fmt.Fprintf(&m, "\tPc\t = %016x\n", regs.Pc) + fmt.Fprintf(&m, "\tPstate\t = %016x\n", regs.Pstate) + + return m.String() +} + +// adjustInitRegsRip adjusts the current PC register value to +// be just before the system call instruction execution. +func (t *thread) adjustInitRegsRip() { + t.initRegs.Pc -= initRegsRipAdjustment +} + +// Pass the expected PPID to the child via X7 when creating the stub process. +func initChildProcessPPID(initregs *syscall.PtraceRegs, ppid int32) { + initregs.Regs[7] = uint64(ppid) +} diff --git a/pkg/sentry/platform/ptrace/subprocess_linux.go b/pkg/sentry/platform/ptrace/subprocess_linux.go index 87ded0bbd..f09b0b3d0 100644 --- a/pkg/sentry/platform/ptrace/subprocess_linux.go +++ b/pkg/sentry/platform/ptrace/subprocess_linux.go @@ -284,7 +284,7 @@ func (s *subprocess) createStub() (*thread, error) { // Pass the expected PPID to the child via R15. regs := t.initRegs - regs.R15 = uint64(t.tgid) + initChildProcessPPID(&regs, t.tgid) // Call fork in a subprocess. // diff --git a/pkg/sentry/platform/ptrace/subprocess_linux_amd64_unsafe.go b/pkg/sentry/platform/ptrace/subprocess_linux_unsafe.go index e977992f9..de6783fb0 100644 --- a/pkg/sentry/platform/ptrace/subprocess_linux_amd64_unsafe.go +++ b/pkg/sentry/platform/ptrace/subprocess_linux_unsafe.go @@ -12,7 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -// +build amd64 linux +// +build linux +// +build amd64 arm64 package ptrace diff --git a/pkg/sentry/safemem/io.go b/pkg/sentry/safemem/io.go index 5c3d73eb7..f039a5c34 100644 --- a/pkg/sentry/safemem/io.go +++ b/pkg/sentry/safemem/io.go @@ -157,7 +157,8 @@ func (w ToIOWriter) Write(src []byte) (int, error) { } // FromIOReader implements Reader for an io.Reader by repeatedly invoking -// io.Reader.Read until it returns an error or partial read.
+// io.Reader.Read until it returns an error or partial read. This is not +// thread-safe. // // FromIOReader will return a successful partial read iff Reader.Read does so. type FromIOReader struct { @@ -206,6 +207,58 @@ func (r FromIOReader) readToBlock(dst Block, buf []byte) (int, []byte, error) { return wbn, buf, rerr } +// FromIOReaderAt implements Reader for an io.ReaderAt. It does not repeatedly +// invoke io.ReaderAt.ReadAt because ReadAt is stricter than Read: a partial +// read always indicates an error. This is not thread-safe. +type FromIOReaderAt struct { + ReaderAt io.ReaderAt + Offset int64 +} + +// ReadToBlocks implements Reader.ReadToBlocks. +func (r FromIOReaderAt) ReadToBlocks(dsts BlockSeq) (uint64, error) { + var buf []byte + var done uint64 + for !dsts.IsEmpty() { + dst := dsts.Head() + var n int + var err error + n, buf, err = r.readToBlock(dst, buf) + done += uint64(n) + if n != dst.Len() { + return done, err + } + dsts = dsts.Tail() + if err != nil { + if dsts.IsEmpty() && err == io.EOF { + return done, nil + } + return done, err + } + } + return done, nil +} + +func (r FromIOReaderAt) readToBlock(dst Block, buf []byte) (int, []byte, error) { + // io.ReaderAt isn't safecopy-aware, so we have to buffer Blocks that require + // safecopy. + if !dst.NeedSafecopy() { + n, err := r.ReaderAt.ReadAt(dst.ToSlice(), r.Offset) + r.Offset += int64(n) + return n, buf, err + } + if len(buf) < dst.Len() { + buf = make([]byte, dst.Len()) + } + rn, rerr := r.ReaderAt.ReadAt(buf[:dst.Len()], r.Offset) + r.Offset += int64(rn) + wbn, wberr := Copy(dst, BlockFromSafeSlice(buf[:rn])) + if wberr != nil { + return wbn, buf, wberr + } + return wbn, buf, rerr +} + // FromIOWriter implements Writer for an io.Writer by repeatedly invoking // io.Writer.Write until it returns an error or partial write. // diff --git a/pkg/sentry/socket/BUILD b/pkg/sentry/socket/BUILD index 2b03ea87c..3300f9a6b 100644 --- a/pkg/sentry/socket/BUILD +++ b/pkg/sentry/socket/BUILD @@ -9,6 +9,7 @@ go_library( visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", + "//pkg/binary", "//pkg/sentry/context", "//pkg/sentry/device", "//pkg/sentry/fs", diff --git a/pkg/sentry/socket/epsocket/BUILD b/pkg/sentry/socket/epsocket/BUILD index 1f014f399..e927821e1 100644 --- a/pkg/sentry/socket/epsocket/BUILD +++ b/pkg/sentry/socket/epsocket/BUILD @@ -31,6 +31,7 @@ go_library( "//pkg/sentry/kernel/time", "//pkg/sentry/safemem", "//pkg/sentry/socket", + "//pkg/sentry/socket/netfilter", "//pkg/sentry/unimpl", "//pkg/sentry/usermem", "//pkg/syserr", @@ -38,6 +39,7 @@ go_library( "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", + "//pkg/tcpip/iptables", "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/ipv6", "//pkg/tcpip/stack", diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/epsocket/epsocket.go index e57aed927..635042263 100644 --- a/pkg/sentry/socket/epsocket/epsocket.go +++ b/pkg/sentry/socket/epsocket/epsocket.go @@ -43,6 +43,7 @@ import ( ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/safemem" "gvisor.dev/gvisor/pkg/sentry/socket" + "gvisor.dev/gvisor/pkg/sentry/socket/netfilter" "gvisor.dev/gvisor/pkg/sentry/unimpl" "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserr" @@ -290,18 +291,22 @@ func bytesToIPAddress(addr []byte) tcpip.Address { return tcpip.Address(addr) } -// GetAddress reads an sockaddr struct from the given address and converts it -// to the FullAddress format. It supports AF_UNIX, AF_INET and AF_INET6 -// addresses.
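FromIOReaderAt above relies on the io.ReaderAt contract: unlike Read, ReadAt must fill the whole buffer or return a non-nil error, so the adapter needs no retry loop. A quick demonstration of that strictness:

package main

import (
	"fmt"
	"strings"
)

func main() {
	r := strings.NewReader("hello world")
	buf := make([]byte, 8)
	// Only 5 bytes ("world") are available at offset 6, so ReadAt
	// returns n=5 together with io.EOF rather than a bare short read.
	n, err := r.ReadAt(buf, 6)
	fmt.Println(n, err, string(buf[:n]))
}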
-func GetAddress(sfamily int, addr []byte, strict bool) (tcpip.FullAddress, *syserr.Error) { +// AddressAndFamily reads a sockaddr struct from the given address and +// converts it to the FullAddress format. It supports AF_UNIX, AF_INET and +// AF_INET6 addresses. +// +// strict indicates whether addresses with the AF_UNSPEC family are accepted or not. +// +// AddressAndFamily returns an address and its family. +func AddressAndFamily(sfamily int, addr []byte, strict bool) (tcpip.FullAddress, uint16, *syserr.Error) { // Make sure we have at least 2 bytes for the address family. if len(addr) < 2 { - return tcpip.FullAddress{}, syserr.ErrInvalidArgument + return tcpip.FullAddress{}, 0, syserr.ErrInvalidArgument } family := usermem.ByteOrder.Uint16(addr) if family != uint16(sfamily) && (!strict && family != linux.AF_UNSPEC) { - return tcpip.FullAddress{}, syserr.ErrAddressFamilyNotSupported + return tcpip.FullAddress{}, family, syserr.ErrAddressFamilyNotSupported } // Get the rest of the fields based on the address family. @@ -309,7 +314,7 @@ case linux.AF_UNIX: path := addr[2:] if len(path) > linux.UnixPathMax { - return tcpip.FullAddress{}, syserr.ErrInvalidArgument + return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument } // Drop the terminating NUL (if one exists) and everything after // it for filesystem (non-abstract) addresses. @@ -320,12 +325,12 @@ } return tcpip.FullAddress{ Addr: tcpip.Address(path), - }, nil + }, family, nil case linux.AF_INET: var a linux.SockAddrInet if len(addr) < sockAddrInetSize { - return tcpip.FullAddress{}, syserr.ErrInvalidArgument + return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument } binary.Unmarshal(addr[:sockAddrInetSize], usermem.ByteOrder, &a) @@ -333,12 +338,12 @@ Addr: bytesToIPAddress(a.Addr[:]), Port: ntohs(a.Port), } - return out, nil + return out, family, nil case linux.AF_INET6: var a linux.SockAddrInet6 if len(addr) < sockAddrInet6Size { - return tcpip.FullAddress{}, syserr.ErrInvalidArgument + return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument } binary.Unmarshal(addr[:sockAddrInet6Size], usermem.ByteOrder, &a) @@ -349,13 +354,13 @@ if isLinkLocal(out.Addr) { out.NIC = tcpip.NICID(a.Scope_id) } - return out, nil + return out, family, nil case linux.AF_UNSPEC: - return tcpip.FullAddress{}, nil + return tcpip.FullAddress{}, family, nil default: - return tcpip.FullAddress{}, syserr.ErrAddressFamilyNotSupported + return tcpip.FullAddress{}, 0, syserr.ErrAddressFamilyNotSupported } } @@ -428,6 +433,11 @@ func (i *ioSequencePayload) Size() int { return int(i.src.NumBytes()) } +// DropFirst drops the first n bytes from the underlying src. +func (i *ioSequencePayload) DropFirst(n int) { + i.src = i.src.DropFirst(int(n)) +} + // Write implements fs.FileOperations.Write. func (s *SocketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, _ int64) (int64, error) { f := &ioSequencePayload{ctx: ctx, src: src} @@ -476,11 +486,18 @@ func (s *SocketOperations) Readiness(mask waiter.EventMask) waiter.EventMask { // Connect implements the linux syscall connect(2) for sockets backed by // tcpip.Endpoint.
func (s *SocketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error { - addr, err := GetAddress(s.family, sockaddr, false /* strict */) + addr, family, err := AddressAndFamily(s.family, sockaddr, false /* strict */) if err != nil { return err } + if family == linux.AF_UNSPEC { + err := s.Endpoint.Disconnect() + if err == tcpip.ErrNotSupported { + return syserr.ErrAddressFamilyNotSupported + } + return syserr.TranslateNetstackError(err) + } // Always return right away in the non-blocking case. if !blocking { return syserr.TranslateNetstackError(s.Endpoint.Connect(addr)) @@ -509,7 +526,7 @@ func (s *SocketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking boo // Bind implements the linux syscall bind(2) for sockets backed by // tcpip.Endpoint. func (s *SocketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { - addr, err := GetAddress(s.family, sockaddr, true /* strict */) + addr, _, err := AddressAndFamily(s.family, sockaddr, true /* strict */) if err != nil { return err } @@ -547,7 +564,7 @@ func (s *SocketOperations) blockingAccept(t *kernel.Task) (tcpip.Endpoint, *wait // Accept implements the linux syscall accept(2) for sockets backed by // tcpip.Endpoint. -func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, interface{}, uint32, *syserr.Error) { +func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) { // Issue the accept request to get the new endpoint. ep, wq, terr := s.Endpoint.Accept() if terr != nil { @@ -574,7 +591,7 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, ns.SetFlags(flags.Settable()) } - var addr interface{} + var addr linux.SockAddr var addrLen uint32 if peerRequested { // Get address of the peer and write it to peer slice. @@ -624,7 +641,7 @@ func (s *SocketOperations) Shutdown(t *kernel.Task, how int) *syserr.Error { // GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by // tcpip.Endpoint. -func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name, outLen int) (interface{}, *syserr.Error) { +func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) { // TODO(b/78348848): Unlike other socket options, SO_TIMESTAMP is // implemented specifically for epsocket.SocketOperations rather than // commonEndpoint. 
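The AF_UNSPEC branch added to Connect above implements the POSIX rule that connecting a datagram socket to an address with family AF_UNSPEC dissolves its association, hence the new Endpoint.Disconnect call. A userspace sketch of triggering that path on Linux; the raw syscall is needed because syscall.Sockaddr has no AF_UNSPEC variant:

package main

import (
	"fmt"
	"syscall"
	"unsafe"
)

func main() {
	fd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, 0)
	if err != nil {
		panic(err)
	}
	defer syscall.Close(fd)

	// Associate the UDP socket with a remote address.
	sa := &syscall.SockaddrInet4{Port: 9999, Addr: [4]byte{127, 0, 0, 1}}
	if err := syscall.Connect(fd, sa); err != nil {
		panic(err)
	}

	// Dissolve the association by connecting to an AF_UNSPEC address.
	var raw syscall.RawSockaddrAny
	raw.Addr.Family = syscall.AF_UNSPEC
	_, _, errno := syscall.Syscall(syscall.SYS_CONNECT, uintptr(fd),
		uintptr(unsafe.Pointer(&raw)), syscall.SizeofSockaddrAny)
	if errno != 0 {
		fmt.Println("disconnect failed:", errno)
		return
	}
	fmt.Println("association dissolved")
}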
commonEndpoint should be extended to support socket @@ -655,6 +672,33 @@ func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name, outLen int) ( return val, nil } + if s.skType == linux.SOCK_RAW && level == linux.IPPROTO_IP { + switch name { + case linux.IPT_SO_GET_INFO: + if outLen < linux.SizeOfIPTGetinfo { + return nil, syserr.ErrInvalidArgument + } + + info, err := netfilter.GetInfo(t, s.Endpoint, outPtr) + if err != nil { + return nil, err + } + return info, nil + + case linux.IPT_SO_GET_ENTRIES: + if outLen < linux.SizeOfIPTGetEntries { + return nil, syserr.ErrInvalidArgument + } + + entries, err := netfilter.GetEntries(t, s.Endpoint, outPtr, outLen) + if err != nil { + return nil, err + } + return entries, nil + + } + } + return GetSockOpt(t, s, s.Endpoint, s.family, s.skType, level, name, outLen) } @@ -1028,7 +1072,7 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfac a, _ := ConvertAddress(linux.AF_INET, tcpip.FullAddress{Addr: v.InterfaceAddr}) - return a.(linux.SockAddrInet).Addr, nil + return a.(*linux.SockAddrInet).Addr, nil case linux.IP_MULTICAST_LOOP: if outLen < sizeOfInt32 { @@ -1658,7 +1702,7 @@ func isLinkLocal(addr tcpip.Address) bool { } // ConvertAddress converts the given address to a native format. -func ConvertAddress(family int, addr tcpip.FullAddress) (interface{}, uint32) { +func ConvertAddress(family int, addr tcpip.FullAddress) (linux.SockAddr, uint32) { switch family { case linux.AF_UNIX: var out linux.SockAddrUnix @@ -1674,15 +1718,15 @@ func ConvertAddress(family int, addr tcpip.FullAddress) (interface{}, uint32) { // address length is the max. Abstract and empty paths always return // the full exact length. if l == 0 || out.Path[0] == 0 || l == len(out.Path) { - return out, uint32(2 + l) + return &out, uint32(2 + l) } - return out, uint32(3 + l) + return &out, uint32(3 + l) case linux.AF_INET: var out linux.SockAddrInet copy(out.Addr[:], addr.Addr) out.Family = linux.AF_INET out.Port = htons(addr.Port) - return out, uint32(binary.Size(out)) + return &out, uint32(binary.Size(out)) case linux.AF_INET6: var out linux.SockAddrInet6 if len(addr.Addr) == 4 { @@ -1698,7 +1742,7 @@ func ConvertAddress(family int, addr tcpip.FullAddress) (interface{}, uint32) { if isLinkLocal(addr.Addr) { out.Scope_id = uint32(addr.NIC) } - return out, uint32(binary.Size(out)) + return &out, uint32(binary.Size(out)) default: return nil, 0 } @@ -1706,7 +1750,7 @@ func ConvertAddress(family int, addr tcpip.FullAddress) (interface{}, uint32) { // GetSockName implements the linux syscall getsockname(2) for sockets backed by // tcpip.Endpoint. -func (s *SocketOperations) GetSockName(t *kernel.Task) (interface{}, uint32, *syserr.Error) { +func (s *SocketOperations) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { addr, err := s.Endpoint.GetLocalAddress() if err != nil { return nil, 0, syserr.TranslateNetstackError(err) @@ -1718,7 +1762,7 @@ func (s *SocketOperations) GetSockName(t *kernel.Task) (interface{}, uint32, *sy // GetPeerName implements the linux syscall getpeername(2) for sockets backed by // tcpip.Endpoint. 
-func (s *SocketOperations) GetPeerName(t *kernel.Task) (interface{}, uint32, *syserr.Error) { +func (s *SocketOperations) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { addr, err := s.Endpoint.GetRemoteAddress() if err != nil { return nil, 0, syserr.TranslateNetstackError(err) @@ -1791,7 +1835,7 @@ func (s *SocketOperations) fillCmsgInq(cmsg *socket.ControlMessages) { // nonBlockingRead issues a non-blocking read. // // TODO(b/78348848): Support timestamps for stream sockets. -func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSequence, peek, trunc, senderRequested bool) (int, int, interface{}, uint32, socket.ControlMessages, *syserr.Error) { +func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSequence, peek, trunc, senderRequested bool) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) { isPacket := s.isPacketBased() // Fast path for regular reads from stream (e.g., TCP) endpoints. Note @@ -1839,7 +1883,7 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe if err == nil { s.updateTimestamp() } - var addr interface{} + var addr linux.SockAddr var addrLen uint32 if isPacket && senderRequested { addr, addrLen = ConvertAddress(s.family, s.sender) @@ -1914,7 +1958,7 @@ func (s *SocketOperations) updateTimestamp() { // RecvMsg implements the linux syscall recvmsg(2) for sockets backed by // tcpip.Endpoint. -func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr interface{}, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) { +func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr linux.SockAddr, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) { trunc := flags&linux.MSG_TRUNC != 0 peek := flags&linux.MSG_PEEK != 0 dontWait := flags&linux.MSG_DONTWAIT != 0 @@ -1990,7 +2034,7 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to [] var addr *tcpip.FullAddress if len(to) > 0 { - addrBuf, err := GetAddress(s.family, to, true /* strict */) + addrBuf, _, err := AddressAndFamily(s.family, to, true /* strict */) if err != nil { return 0, err } @@ -1998,28 +2042,22 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to [] addr = &addrBuf } - v := buffer.NewView(int(src.NumBytes())) - - // Copy all the data into the buffer. - if _, err := src.CopyIn(t, v); err != nil { - return 0, syserr.FromError(err) - } - opts := tcpip.WriteOptions{ To: addr, More: flags&linux.MSG_MORE != 0, EndOfRecord: flags&linux.MSG_EOR != 0, } - n, resCh, err := s.Endpoint.Write(tcpip.SlicePayload(v), opts) + v := &ioSequencePayload{t, src} + n, resCh, err := s.Endpoint.Write(v, opts) if resCh != nil { if err := t.Block(resCh); err != nil { return 0, syserr.FromError(err) } - n, _, err = s.Endpoint.Write(tcpip.SlicePayload(v), opts) + n, _, err = s.Endpoint.Write(v, opts) } dontWait := flags&linux.MSG_DONTWAIT != 0 - if err == nil && (n >= uintptr(len(v)) || dontWait) { + if err == nil && (n >= int64(v.Size()) || dontWait) { // Complete write. 
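The SendMsg rewrite here (continued below) stops copying the whole IOSequence into a buffer.View and instead passes an ioSequencePayload, trimming it with DropFirst after each partial write. The shape of that loop against a plain writer, with hypothetical names:

package main

import (
	"bytes"
	"fmt"
	"io"
)

// writeAll keeps trimming the front of the payload by however much was
// written, mirroring v.DropFirst(int(n)) in the SendMsg loop.
func writeAll(w io.Writer, payload []byte) (int, error) {
	total := 0
	for len(payload) > 0 {
		n, err := w.Write(payload)
		total += n
		payload = payload[n:] // DropFirst(n)
		if err != nil {
			return total, err
		}
	}
	return total, nil
}

func main() {
	var b bytes.Buffer
	n, err := writeAll(&b, []byte("hello world"))
	fmt.Println(n, err, b.String())
}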
return int(n), nil } @@ -2033,18 +2071,18 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to [] s.EventRegister(&e, waiter.EventOut) defer s.EventUnregister(&e) - v.TrimFront(int(n)) + v.DropFirst(int(n)) total := n for { - n, _, err = s.Endpoint.Write(tcpip.SlicePayload(v), opts) - v.TrimFront(int(n)) + n, _, err = s.Endpoint.Write(v, opts) + v.DropFirst(int(n)) total += n if err != nil && err != tcpip.ErrWouldBlock && total == 0 { return 0, syserr.TranslateNetstackError(err) } - if err == nil && len(v) == 0 || err != nil && err != tcpip.ErrWouldBlock { + if err == nil && v.Size() == 0 || err != nil && err != tcpip.ErrWouldBlock { return int(total), nil } @@ -2252,19 +2290,19 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe case syscall.SIOCGIFMAP: // Gets the hardware parameters of the device. - // TODO(b/71872867): Implement. + // TODO(gvisor.dev/issue/505): Implement. case syscall.SIOCGIFTXQLEN: // Gets the transmit queue length of the device. - // TODO(b/71872867): Implement. + // TODO(gvisor.dev/issue/505): Implement. case syscall.SIOCGIFDSTADDR: // Gets the destination address of a point-to-point device. - // TODO(b/71872867): Implement. + // TODO(gvisor.dev/issue/505): Implement. case syscall.SIOCGIFBRDADDR: // Gets the broadcast address of a device. - // TODO(b/71872867): Implement. + // TODO(gvisor.dev/issue/505): Implement. case syscall.SIOCGIFNETMASK: // Gets the network mask of a device. diff --git a/pkg/sentry/socket/epsocket/stack.go b/pkg/sentry/socket/epsocket/stack.go index 8fe489c0e..7cf7ff735 100644 --- a/pkg/sentry/socket/epsocket/stack.go +++ b/pkg/sentry/socket/epsocket/stack.go @@ -18,7 +18,10 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/inet" + "gvisor.dev/gvisor/pkg/sentry/socket/netfilter" "gvisor.dev/gvisor/pkg/syserr" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/iptables" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -143,3 +146,57 @@ func (s *Stack) SetTCPSACKEnabled(enabled bool) error { func (s *Stack) Statistics(stat interface{}, arg string) error { return syserr.ErrEndpointOperation.ToError() } + +// RouteTable implements inet.Stack.RouteTable. +func (s *Stack) RouteTable() []inet.Route { + var routeTable []inet.Route + + for _, rt := range s.Stack.GetRouteTable() { + var family uint8 + switch len(rt.Destination.ID()) { + case header.IPv4AddressSize: + family = linux.AF_INET + case header.IPv6AddressSize: + family = linux.AF_INET6 + default: + log.Warningf("Unknown network protocol in route %+v", rt) + continue + } + + routeTable = append(routeTable, inet.Route{ + Family: family, + DstLen: uint8(rt.Destination.Prefix()), // The CIDR prefix for the destination. + + // Always return unspecified protocol since we have no notion of + // protocol for routes. + Protocol: linux.RTPROT_UNSPEC, + // Set statically to LINK scope for now. + // + // TODO(gvisor.dev/issue/595): Set scope for routes. + Scope: linux.RT_SCOPE_LINK, + Type: linux.RTN_UNICAST, + + DstAddr: []byte(rt.Destination.ID()), + OutputInterface: int32(rt.NIC), + GatewayAddr: []byte(rt.Gateway), + }) + } + + return routeTable +} + +// IPTables returns the stack's iptables. 
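+// (The error return is always nil for this implementation.)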
+func (s *Stack) IPTables() (iptables.IPTables, error) { + return s.Stack.IPTables(), nil +} + +// FillDefaultIPTables sets the stack's iptables to the default tables, which +// allow and do not modify all traffic. +func (s *Stack) FillDefaultIPTables() { + netfilter.FillDefaultIPTables(s.Stack) +} + +// Resume implements inet.Stack.Resume. +func (s *Stack) Resume() { + s.Stack.Resume() +} diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go index 7f69406b7..92beb1bcf 100644 --- a/pkg/sentry/socket/hostinet/socket.go +++ b/pkg/sentry/socket/hostinet/socket.go @@ -189,15 +189,16 @@ func (s *socketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking boo } // Accept implements socket.Socket.Accept. -func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, interface{}, uint32, *syserr.Error) { - var peerAddr []byte +func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) { + var peerAddr linux.SockAddr + var peerAddrBuf []byte var peerAddrlen uint32 var peerAddrPtr *byte var peerAddrlenPtr *uint32 if peerRequested { - peerAddr = make([]byte, sizeofSockaddr) - peerAddrlen = uint32(len(peerAddr)) - peerAddrPtr = &peerAddr[0] + peerAddrBuf = make([]byte, sizeofSockaddr) + peerAddrlen = uint32(len(peerAddrBuf)) + peerAddrPtr = &peerAddrBuf[0] peerAddrlenPtr = &peerAddrlen } @@ -222,7 +223,7 @@ func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, } if peerRequested { - peerAddr = peerAddr[:peerAddrlen] + peerAddr = socket.UnmarshalSockAddr(s.family, peerAddrBuf[:peerAddrlen]) } if syscallErr != nil { return 0, peerAddr, peerAddrlen, syserr.FromError(syscallErr) @@ -272,7 +273,7 @@ func (s *socketOperations) Shutdown(t *kernel.Task, how int) *syserr.Error { } // GetSockOpt implements socket.Socket.GetSockOpt. -func (s *socketOperations) GetSockOpt(t *kernel.Task, level int, name int, outLen int) (interface{}, *syserr.Error) { +func (s *socketOperations) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) { if outLen < 0 { return nil, syserr.ErrInvalidArgument } @@ -353,7 +354,7 @@ func (s *socketOperations) SetSockOpt(t *kernel.Task, level int, name int, opt [ } // RecvMsg implements socket.Socket.RecvMsg. -func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (int, int, interface{}, uint32, socket.ControlMessages, *syserr.Error) { +func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) { // Whitelist flags. // // FIXME(jamieliu): We can't support MSG_ERRQUEUE because it uses ancillary @@ -363,9 +364,10 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags return 0, 0, nil, 0, socket.ControlMessages{}, syserr.ErrInvalidArgument } - var senderAddr []byte + var senderAddr linux.SockAddr + var senderAddrBuf []byte if senderRequested { - senderAddr = make([]byte, sizeofSockaddr) + senderAddrBuf = make([]byte, sizeofSockaddr) } var msgFlags int @@ -384,7 +386,7 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags if dsts.NumBlocks() == 1 { // Skip allocating []syscall.Iovec. 
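+			// (recvfrom takes a single buffer; multi-block reads
+			// fall through to recvmsg with a full iovec array.)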
- return recvfrom(s.fd, dsts.Head().ToSlice(), sysflags, &senderAddr) + return recvfrom(s.fd, dsts.Head().ToSlice(), sysflags, &senderAddrBuf) } iovs := iovecsFromBlockSeq(dsts) @@ -392,15 +394,15 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags Iov: &iovs[0], Iovlen: uint64(len(iovs)), } - if len(senderAddr) != 0 { - msg.Name = &senderAddr[0] - msg.Namelen = uint32(len(senderAddr)) + if len(senderAddrBuf) != 0 { + msg.Name = &senderAddrBuf[0] + msg.Namelen = uint32(len(senderAddrBuf)) } n, err := recvmsg(s.fd, &msg, sysflags) if err != nil { return 0, err } - senderAddr = senderAddr[:msg.Namelen] + senderAddrBuf = senderAddrBuf[:msg.Namelen] msgFlags = int(msg.Flags) return n, nil }) @@ -431,7 +433,10 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags // We don't allow control messages. msgFlags &^= linux.MSG_CTRUNC - return int(n), msgFlags, senderAddr, uint32(len(senderAddr)), socket.ControlMessages{}, syserr.FromError(err) + if senderRequested { + senderAddr = socket.UnmarshalSockAddr(s.family, senderAddrBuf) + } + return int(n), msgFlags, senderAddr, uint32(len(senderAddrBuf)), socket.ControlMessages{}, syserr.FromError(err) } // SendMsg implements socket.Socket.SendMsg. diff --git a/pkg/sentry/socket/hostinet/socket_unsafe.go b/pkg/sentry/socket/hostinet/socket_unsafe.go index 6c69ba9c7..e69ec38c2 100644 --- a/pkg/sentry/socket/hostinet/socket_unsafe.go +++ b/pkg/sentry/socket/hostinet/socket_unsafe.go @@ -18,10 +18,12 @@ import ( "syscall" "unsafe" + "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/socket" "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/syserror" @@ -91,25 +93,25 @@ func getsockopt(fd int, level, name int, optlen int) ([]byte, error) { } // GetSockName implements socket.Socket.GetSockName. -func (s *socketOperations) GetSockName(t *kernel.Task) (interface{}, uint32, *syserr.Error) { +func (s *socketOperations) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { addr := make([]byte, sizeofSockaddr) addrlen := uint32(len(addr)) _, _, errno := syscall.Syscall(syscall.SYS_GETSOCKNAME, uintptr(s.fd), uintptr(unsafe.Pointer(&addr[0])), uintptr(unsafe.Pointer(&addrlen))) if errno != 0 { return nil, 0, syserr.FromError(errno) } - return addr[:addrlen], addrlen, nil + return socket.UnmarshalSockAddr(s.family, addr), addrlen, nil } // GetPeerName implements socket.Socket.GetPeerName. -func (s *socketOperations) GetPeerName(t *kernel.Task) (interface{}, uint32, *syserr.Error) { +func (s *socketOperations) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { addr := make([]byte, sizeofSockaddr) addrlen := uint32(len(addr)) _, _, errno := syscall.Syscall(syscall.SYS_GETPEERNAME, uintptr(s.fd), uintptr(unsafe.Pointer(&addr[0])), uintptr(unsafe.Pointer(&addrlen))) if errno != 0 { return nil, 0, syserr.FromError(errno) } - return addr[:addrlen], addrlen, nil + return socket.UnmarshalSockAddr(s.family, addr), addrlen, nil } func recvfrom(fd int, dst []byte, flags int, from *[]byte) (uint64, error) { diff --git a/pkg/sentry/socket/hostinet/stack.go b/pkg/sentry/socket/hostinet/stack.go index cc1f66fa1..3a4fdec47 100644 --- a/pkg/sentry/socket/hostinet/stack.go +++ b/pkg/sentry/socket/hostinet/stack.go @@ -46,6 +46,7 @@ type Stack struct { // Stack is immutable. 
 	interfaces     map[int32]inet.Interface
 	interfaceAddrs map[int32][]inet.InterfaceAddr
+	routes         []inet.Route
 	supportsIPv6   bool
 	tcpRecvBufSize inet.TCPBufferSize
 	tcpSendBufSize inet.TCPBufferSize
@@ -66,6 +67,10 @@ func (s *Stack) Configure() error {
 		return err
 	}

+	if err := addHostRoutes(s); err != nil {
+		return err
+	}
+
 	if _, err := os.Stat("/proc/net/if_inet6"); err == nil {
 		s.supportsIPv6 = true
 	}
@@ -161,6 +166,60 @@ func ExtractHostInterfaces(links []syscall.NetlinkMessage, addrs []syscall.Netli
 	return nil
 }

+// ExtractHostRoutes parses the given netlink route messages and returns the
+// routes they describe.
+func ExtractHostRoutes(routeMsgs []syscall.NetlinkMessage) ([]inet.Route, error) {
+	var routes []inet.Route
+	for _, routeMsg := range routeMsgs {
+		if routeMsg.Header.Type != syscall.RTM_NEWROUTE {
+			continue
+		}
+
+		var ifRoute syscall.RtMsg
+		binary.Unmarshal(routeMsg.Data[:syscall.SizeofRtMsg], usermem.ByteOrder, &ifRoute)
+		inetRoute := inet.Route{
+			Family:   ifRoute.Family,
+			DstLen:   ifRoute.Dst_len,
+			SrcLen:   ifRoute.Src_len,
+			TOS:      ifRoute.Tos,
+			Table:    ifRoute.Table,
+			Protocol: ifRoute.Protocol,
+			Scope:    ifRoute.Scope,
+			Type:     ifRoute.Type,
+			Flags:    ifRoute.Flags,
+		}
+
+		// Not clearly documented: syscall.ParseNetlinkRouteAttr will check the
+		// syscall.NetlinkMessage.Header.Type and skip the struct rtmsg
+		// accordingly.
+		attrs, err := syscall.ParseNetlinkRouteAttr(&routeMsg)
+		if err != nil {
+			return nil, fmt.Errorf("RTM_GETROUTE returned RTM_NEWROUTE message with invalid rtattrs: %v", err)
+		}
+
+		for _, attr := range attrs {
+			switch attr.Attr.Type {
+			case syscall.RTA_DST:
+				inetRoute.DstAddr = attr.Value
+			case syscall.RTA_SRC:
+				inetRoute.SrcAddr = attr.Value
+			case syscall.RTA_GATEWAY:
+				inetRoute.GatewayAddr = attr.Value
+			case syscall.RTA_OIF:
+				expected := int(binary.Size(inetRoute.OutputInterface))
+				if len(attr.Value) != expected {
+					return nil, fmt.Errorf("RTM_GETROUTE returned RTM_NEWROUTE message with invalid attribute data length (%d bytes, expected %d bytes)", len(attr.Value), expected)
+				}
+				binary.Unmarshal(attr.Value, usermem.ByteOrder, &inetRoute.OutputInterface)
+			}
+		}
+
+		routes = append(routes, inetRoute)
+	}
+
+	return routes, nil
+}
+
 func addHostInterfaces(s *Stack) error {
 	links, err := doNetlinkRouteRequest(syscall.RTM_GETLINK)
 	if err != nil {
@@ -175,6 +234,20 @@ func addHostInterfaces(s *Stack) error {
 	return ExtractHostInterfaces(links, addrs, s.interfaces, s.interfaceAddrs)
 }

+func addHostRoutes(s *Stack) error {
+	routes, err := doNetlinkRouteRequest(syscall.RTM_GETROUTE)
+	if err != nil {
+		return fmt.Errorf("RTM_GETROUTE failed: %v", err)
+	}
+
+	s.routes, err = ExtractHostRoutes(routes)
+	if err != nil {
+		return err
+	}
+
+	return nil
+}
+
 func doNetlinkRouteRequest(req int) ([]syscall.NetlinkMessage, error) {
 	data, err := syscall.NetlinkRIB(req, syscall.AF_UNSPEC)
 	if err != nil {
@@ -202,12 +275,20 @@ func readTCPBufferSizeFile(filename string) (inet.TCPBufferSize, error) {

 // Interfaces implements inet.Stack.Interfaces.
 func (s *Stack) Interfaces() map[int32]inet.Interface {
-	return s.interfaces
+	interfaces := make(map[int32]inet.Interface)
+	for k, v := range s.interfaces {
+		interfaces[k] = v
+	}
+	return interfaces
 }

 // InterfaceAddrs implements inet.Stack.InterfaceAddrs.
 func (s *Stack) InterfaceAddrs() map[int32][]inet.InterfaceAddr {
-	return s.interfaceAddrs
+	addrs := make(map[int32][]inet.InterfaceAddr)
+	for k, v := range s.interfaceAddrs {
+		addrs[k] = append([]inet.InterfaceAddr(nil), v...)
+ } + return addrs } // SupportsIPv6 implements inet.Stack.SupportsIPv6. @@ -249,3 +330,11 @@ func (s *Stack) SetTCPSACKEnabled(enabled bool) error { func (s *Stack) Statistics(stat interface{}, arg string) error { return syserror.EOPNOTSUPP } + +// RouteTable implements inet.Stack.RouteTable. +func (s *Stack) RouteTable() []inet.Route { + return append([]inet.Route(nil), s.routes...) +} + +// Resume implements inet.Stack.Resume. +func (s *Stack) Resume() {} diff --git a/pkg/sentry/socket/netfilter/BUILD b/pkg/sentry/socket/netfilter/BUILD new file mode 100644 index 000000000..354a0d6ee --- /dev/null +++ b/pkg/sentry/socket/netfilter/BUILD @@ -0,0 +1,24 @@ +package(licenses = ["notice"]) + +load("//tools/go_stateify:defs.bzl", "go_library") + +go_library( + name = "netfilter", + srcs = [ + "netfilter.go", + ], + importpath = "gvisor.dev/gvisor/pkg/sentry/socket/netfilter", + # This target depends on netstack and should only be used by epsocket, + # which is allowed to depend on netstack. + visibility = ["//pkg/sentry:internal"], + deps = [ + "//pkg/abi/linux", + "//pkg/binary", + "//pkg/sentry/kernel", + "//pkg/sentry/usermem", + "//pkg/syserr", + "//pkg/tcpip", + "//pkg/tcpip/iptables", + "//pkg/tcpip/stack", + ], +) diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go new file mode 100644 index 000000000..9f87c32f1 --- /dev/null +++ b/pkg/sentry/socket/netfilter/netfilter.go @@ -0,0 +1,286 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package netfilter helps the sentry interact with netstack's netfilter +// capabilities. +package netfilter + +import ( + "fmt" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/syserr" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/iptables" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +// errorTargetName is used to mark targets as error targets. Error targets +// shouldn't be reached - an error has occurred if we fall through to one. +const errorTargetName = "ERROR" + +// metadata is opaque to netstack. It holds data that we need to translate +// between Linux's and netstack's iptables representations. +type metadata struct { + HookEntry [linux.NF_INET_NUMHOOKS]uint32 + Underflow [linux.NF_INET_NUMHOOKS]uint32 + NumEntries uint32 + Size uint32 +} + +// GetInfo returns information about iptables. +func GetInfo(t *kernel.Task, ep tcpip.Endpoint, outPtr usermem.Addr) (linux.IPTGetinfo, *syserr.Error) { + // Read in the struct and table name. + var info linux.IPTGetinfo + if _, err := t.CopyIn(outPtr, &info); err != nil { + return linux.IPTGetinfo{}, syserr.FromError(err) + } + + // Find the appropriate table. + table, err := findTable(ep, info.TableName()) + if err != nil { + return linux.IPTGetinfo{}, err + } + + // Get the hooks that apply to this table. 
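+	// (ValidHooks is a bitmap with one bit set for each netfilter hook
+	// this table installs chains on.)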
+	info.ValidHooks = table.ValidHooks()
+
+	// Grab the metadata struct, which is used to store information (e.g.
+	// the number of entries) that applies to the user's encoding of
+	// iptables, but not netstack's.
+	metadata := table.Metadata().(metadata)
+
+	// Set values from metadata.
+	info.HookEntry = metadata.HookEntry
+	info.Underflow = metadata.Underflow
+	info.NumEntries = metadata.NumEntries
+	info.Size = metadata.Size
+
+	return info, nil
+}
+
+// GetEntries returns netstack's iptables rules encoded for the iptables tool.
+func GetEntries(t *kernel.Task, ep tcpip.Endpoint, outPtr usermem.Addr, outLen int) (linux.KernelIPTGetEntries, *syserr.Error) {
+	// Read in the struct and table name.
+	var userEntries linux.IPTGetEntries
+	if _, err := t.CopyIn(outPtr, &userEntries); err != nil {
+		return linux.KernelIPTGetEntries{}, syserr.FromError(err)
+	}
+
+	// Find the appropriate table.
+	table, err := findTable(ep, userEntries.TableName())
+	if err != nil {
+		return linux.KernelIPTGetEntries{}, err
+	}
+
+	// Convert netstack's iptables rules to something that the iptables
+	// tool can understand.
+	entries, _, err := convertNetstackToBinary(userEntries.TableName(), table)
+	if err != nil {
+		return linux.KernelIPTGetEntries{}, err
+	}
+	if binary.Size(entries) > uintptr(outLen) {
+		return linux.KernelIPTGetEntries{}, syserr.ErrInvalidArgument
+	}
+
+	return entries, nil
+}
+
+func findTable(ep tcpip.Endpoint, tableName string) (iptables.Table, *syserr.Error) {
+	ipt, err := ep.IPTables()
+	if err != nil {
+		return iptables.Table{}, syserr.FromError(err)
+	}
+	table, ok := ipt.Tables[tableName]
+	if !ok {
+		return iptables.Table{}, syserr.ErrInvalidArgument
+	}
+	return table, nil
+}
+
+// FillDefaultIPTables sets stack's IPTables to the default tables and
+// populates them with metadata.
+func FillDefaultIPTables(stack *stack.Stack) {
+	ipt := iptables.DefaultTables()
+
+	// In order to fill in the metadata, we have to translate ipt from its
+	// netstack format to Linux's giant-binary-blob format.
+	for name, table := range ipt.Tables {
+		_, metadata, err := convertNetstackToBinary(name, table)
+		if err != nil {
+			panic(fmt.Errorf("unable to set default IP tables: %v", err))
+		}
+		table.SetMetadata(metadata)
+		ipt.Tables[name] = table
+	}
+
+	stack.SetIPTables(ipt)
+}
+
+// convertNetstackToBinary converts the iptables as stored in netstack to the
+// format expected by the iptables tool. Linux stores each table as a binary
+// blob that can only be traversed by parsing a bit, reading some offsets,
+// jumping to those offsets, parsing again, etc.
+func convertNetstackToBinary(name string, table iptables.Table) (linux.KernelIPTGetEntries, metadata, *syserr.Error) {
+	// Return values.
+	var entries linux.KernelIPTGetEntries
+	var meta metadata
+
+	// The table name has to fit in the struct.
+	if linux.XT_TABLE_MAXNAMELEN < len(name) {
+		return linux.KernelIPTGetEntries{}, metadata{}, syserr.ErrInvalidArgument
+	}
+	copy(entries.Name[:], name)
+
+	// Deal with the built-in chains first (INPUT, OUTPUT, etc.). Each of
+	// these chains ends with an unconditional policy entry.
+	for hook := iptables.Prerouting; hook < iptables.NumHooks; hook++ {
+		chain, ok := table.BuiltinChains[hook]
+		if !ok {
+			// This table doesn't support this hook.
+			continue
+		}
+
+		// Sanity check.
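+		// A built-in chain always holds at least one rule: its
+		// unconditional policy (underflow) entry.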
+ if len(chain.Rules) < 1 { + return linux.KernelIPTGetEntries{}, metadata{}, syserr.ErrInvalidArgument + } + + for ruleIdx, rule := range chain.Rules { + // If this is the first rule of a builtin chain, set + // the metadata hook entry point. + if ruleIdx == 0 { + meta.HookEntry[hook] = entries.Size + } + + // Each rule corresponds to an entry. + entry := linux.KernelIPTEntry{ + IPTEntry: linux.IPTEntry{ + NextOffset: linux.SizeOfIPTEntry, + TargetOffset: linux.SizeOfIPTEntry, + }, + } + + for _, matcher := range rule.Matchers { + // Serialize the matcher and add it to the + // entry. + serialized := marshalMatcher(matcher) + entry.Elems = append(entry.Elems, serialized...) + entry.NextOffset += uint16(len(serialized)) + entry.TargetOffset += uint16(len(serialized)) + } + + // Serialize and append the target. + serialized := marshalTarget(rule.Target) + entry.Elems = append(entry.Elems, serialized...) + entry.NextOffset += uint16(len(serialized)) + + // The underflow rule is the last rule in the chain, + // and is an unconditional rule (i.e. it matches any + // packet). This is enforced when saving iptables. + if ruleIdx == len(chain.Rules)-1 { + meta.Underflow[hook] = entries.Size + } + + entries.Size += uint32(entry.NextOffset) + entries.Entrytable = append(entries.Entrytable, entry) + meta.NumEntries++ + } + + } + + // TODO(gvisor.dev/issue/170): Deal with the user chains here. Each of + // these starts with an error node holding the chain's name and ends + // with an unconditional return. + + // Lastly, each table ends with an unconditional error target rule as + // its final entry. + errorEntry := linux.KernelIPTEntry{ + IPTEntry: linux.IPTEntry{ + NextOffset: linux.SizeOfIPTEntry, + TargetOffset: linux.SizeOfIPTEntry, + }, + } + var errorTarget linux.XTErrorTarget + errorTarget.Target.TargetSize = linux.SizeOfXTErrorTarget + copy(errorTarget.ErrorName[:], errorTargetName) + copy(errorTarget.Target.Name[:], errorTargetName) + + // Serialize and add it to the list of entries. + errorTargetBuf := make([]byte, 0, linux.SizeOfXTErrorTarget) + serializedErrorTarget := binary.Marshal(errorTargetBuf, usermem.ByteOrder, errorTarget) + errorEntry.Elems = append(errorEntry.Elems, serializedErrorTarget...) + errorEntry.NextOffset += uint16(len(serializedErrorTarget)) + + entries.Size += uint32(errorEntry.NextOffset) + entries.Entrytable = append(entries.Entrytable, errorEntry) + meta.NumEntries++ + meta.Size = entries.Size + + return entries, meta, nil +} + +func marshalMatcher(matcher iptables.Matcher) []byte { + switch matcher.(type) { + default: + // TODO(gvisor.dev/issue/170): We don't support any matchers yet, so + // any call to marshalMatcher will panic. + panic(fmt.Errorf("unknown matcher of type %T", matcher)) + } +} + +func marshalTarget(target iptables.Target) []byte { + switch target.(type) { + case iptables.UnconditionalAcceptTarget: + return marshalUnconditionalAcceptTarget() + default: + panic(fmt.Errorf("unknown target of type %T", target)) + } +} + +func marshalUnconditionalAcceptTarget() []byte { + // The target's name will be the empty string. + target := linux.XTStandardTarget{ + Target: linux.XTEntryTarget{ + TargetSize: linux.SizeOfXTStandardTarget, + }, + Verdict: translateStandardVerdict(iptables.Accept), + } + + ret := make([]byte, 0, linux.SizeOfXTStandardTarget) + return binary.Marshal(ret, usermem.ByteOrder, target) +} + +// translateStandardVerdict translates verdicts the same way as the iptables +// tool. 
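+// Standard verdicts are encoded as -(verdict)-1 so they cannot collide with
+// the positive values iptables uses as jump offsets.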
+func translateStandardVerdict(verdict iptables.Verdict) int32 {
+	switch verdict {
+	case iptables.Accept:
+		return -linux.NF_ACCEPT - 1
+	case iptables.Drop:
+		return -linux.NF_DROP - 1
+	case iptables.Queue:
+		return -linux.NF_QUEUE - 1
+	case iptables.Return:
+		return linux.NF_RETURN
+	case iptables.Jump:
+		// TODO(gvisor.dev/issue/170): Support Jump.
+		panic("Jump isn't supported yet")
+	default:
+		panic(fmt.Sprintf("unknown standard verdict: %d", verdict))
+	}
+}
diff --git a/pkg/sentry/socket/netlink/route/protocol.go b/pkg/sentry/socket/netlink/route/protocol.go
index fb1ff329c..cc70ac237 100644
--- a/pkg/sentry/socket/netlink/route/protocol.go
+++ b/pkg/sentry/socket/netlink/route/protocol.go
@@ -110,7 +110,7 @@ func (p *Protocol) dumpLinks(ctx context.Context, hdr linux.NetlinkMessageHeader
 		m.PutAttr(linux.IFLA_ADDRESS, mac)
 		m.PutAttr(linux.IFLA_BROADCAST, brd)

-		// TODO(b/68878065): There are many more attributes.
+		// TODO(gvisor.dev/issue/578): There are many more attributes.
 	}

 	return nil
@@ -151,13 +151,69 @@ func (p *Protocol) dumpAddrs(ctx context.Context, hdr linux.NetlinkMessageHeader

 			m.PutAttr(linux.IFA_ADDRESS, []byte(a.Addr))

-			// TODO(b/68878065): There are many more attributes.
+			// TODO(gvisor.dev/issue/578): There are many more attributes.
 		}
 	}

 	return nil
 }

+// dumpRoutes handles RTM_GETROUTE + NLM_F_DUMP requests.
+func (p *Protocol) dumpRoutes(ctx context.Context, hdr linux.NetlinkMessageHeader, data []byte, ms *netlink.MessageSet) *syserr.Error {
+	// RTM_GETROUTE dump requests need not contain anything more than the
+	// netlink header and 1 byte protocol family common to all
+	// NETLINK_ROUTE requests.
+
+	// We always send back an NLMSG_DONE.
+	ms.Multi = true
+
+	stack := inet.StackFromContext(ctx)
+	if stack == nil {
+		// No network routes.
+		return nil
+	}
+
+	for _, rt := range stack.RouteTable() {
+		m := ms.AddMessage(linux.NetlinkMessageHeader{
+			Type: linux.RTM_NEWROUTE,
+		})
+
+		m.Put(linux.RouteMessage{
+			Family: rt.Family,
+			DstLen: rt.DstLen,
+			SrcLen: rt.SrcLen,
+			TOS:    rt.TOS,
+
+			// Always return the main table since we don't have multiple
+			// routing tables.
+			Table:    linux.RT_TABLE_MAIN,
+			Protocol: rt.Protocol,
+			Scope:    rt.Scope,
+			Type:     rt.Type,
+
+			Flags: rt.Flags,
+		})
+
+		if rt.DstLen > 0 {
+			m.PutAttr(linux.RTA_DST, rt.DstAddr)
+		}
+		if rt.SrcLen > 0 {
+			m.PutAttr(linux.RTA_SRC, rt.SrcAddr)
+		}
+		if rt.OutputInterface != 0 {
+			m.PutAttr(linux.RTA_OIF, rt.OutputInterface)
+		}
+		if len(rt.GatewayAddr) > 0 {
+			m.PutAttr(linux.RTA_GATEWAY, rt.GatewayAddr)
+		}
+
+		// TODO(gvisor.dev/issue/578): There are many more attributes.
+	}
+
+	return nil
+}
+
 // ProcessMessage implements netlink.Protocol.ProcessMessage.
 func (p *Protocol) ProcessMessage(ctx context.Context, hdr linux.NetlinkMessageHeader, data []byte, ms *netlink.MessageSet) *syserr.Error {
 	// All messages start with a 1 byte protocol family.
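For context, dumpRoutes above is what answers a standard RTM_GETROUTE dump from userspace. A minimal sketch of such a client, using the same syscall-package netlink helpers that hostinet's doNetlinkRouteRequest relies on (illustrative only, not part of this change):

    package main

    import (
    	"fmt"
    	"syscall"
    )

    func main() {
    	// NetlinkRIB issues RTM_GETROUTE with NLM_F_DUMP and gathers the
    	// reply pages up to NLMSG_DONE.
    	data, err := syscall.NetlinkRIB(syscall.RTM_GETROUTE, syscall.AF_UNSPEC)
    	if err != nil {
    		panic(err)
    	}
    	msgs, err := syscall.ParseNetlinkMessage(data)
    	if err != nil {
    		panic(err)
    	}
    	for _, m := range msgs {
    		if m.Header.Type != syscall.RTM_NEWROUTE {
    			continue
    		}
    		// ParseNetlinkRouteAttr skips the leading struct rtmsg for us.
    		attrs, err := syscall.ParseNetlinkRouteAttr(&m)
    		if err != nil {
    			panic(err)
    		}
    		for _, a := range attrs {
    			if a.Attr.Type == syscall.RTA_DST {
    				fmt.Printf("dst %v\n", a.Value)
    			}
    		}
    	}
    }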
@@ -186,6 +242,8 @@ func (p *Protocol) ProcessMessage(ctx context.Context, hdr linux.NetlinkMessageH return p.dumpLinks(ctx, hdr, data, ms) case linux.RTM_GETADDR: return p.dumpAddrs(ctx, hdr, data, ms) + case linux.RTM_GETROUTE: + return p.dumpRoutes(ctx, hdr, data, ms) default: return syserr.ErrNotSupported } diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go index f3d6c1e9b..d0aab293d 100644 --- a/pkg/sentry/socket/netlink/socket.go +++ b/pkg/sentry/socket/netlink/socket.go @@ -271,7 +271,7 @@ func (s *Socket) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr } // Accept implements socket.Socket.Accept. -func (s *Socket) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, interface{}, uint32, *syserr.Error) { +func (s *Socket) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) { // Netlink sockets never support accept. return 0, nil, 0, syserr.ErrNotSupported } @@ -289,7 +289,7 @@ func (s *Socket) Shutdown(t *kernel.Task, how int) *syserr.Error { } // GetSockOpt implements socket.Socket.GetSockOpt. -func (s *Socket) GetSockOpt(t *kernel.Task, level int, name int, outLen int) (interface{}, *syserr.Error) { +func (s *Socket) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) { switch level { case linux.SOL_SOCKET: switch name { @@ -379,11 +379,11 @@ func (s *Socket) SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *sy } // GetSockName implements socket.Socket.GetSockName. -func (s *Socket) GetSockName(t *kernel.Task) (interface{}, uint32, *syserr.Error) { +func (s *Socket) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { s.mu.Lock() defer s.mu.Unlock() - sa := linux.SockAddrNetlink{ + sa := &linux.SockAddrNetlink{ Family: linux.AF_NETLINK, PortID: uint32(s.portID), } @@ -391,8 +391,8 @@ func (s *Socket) GetSockName(t *kernel.Task) (interface{}, uint32, *syserr.Error } // GetPeerName implements socket.Socket.GetPeerName. -func (s *Socket) GetPeerName(t *kernel.Task) (interface{}, uint32, *syserr.Error) { - sa := linux.SockAddrNetlink{ +func (s *Socket) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { + sa := &linux.SockAddrNetlink{ Family: linux.AF_NETLINK, // TODO(b/68878065): Support non-kernel peers. For now the peer // must be the kernel. @@ -402,8 +402,8 @@ func (s *Socket) GetPeerName(t *kernel.Task) (interface{}, uint32, *syserr.Error } // RecvMsg implements socket.Socket.RecvMsg. 
-func (s *Socket) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (int, int, interface{}, uint32, socket.ControlMessages, *syserr.Error) {
-	from := linux.SockAddrNetlink{
+func (s *Socket) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) {
+	from := &linux.SockAddrNetlink{
 		Family: linux.AF_NETLINK,
 		PortID: 0,
 	}
@@ -511,6 +511,19 @@ func (s *Socket) sendResponse(ctx context.Context, ms *MessageSet) *syserr.Error
 	return nil
 }

+func (s *Socket) dumpErrorMessage(ctx context.Context, hdr linux.NetlinkMessageHeader, ms *MessageSet, err *syserr.Error) *syserr.Error {
+	m := ms.AddMessage(linux.NetlinkMessageHeader{
+		Type: linux.NLMSG_ERROR,
+	})
+
+	m.Put(linux.NetlinkErrorMessage{
+		Error:  int32(-err.ToLinux().Number()),
+		Header: hdr,
+	})
+	return nil
+}
+
 // processMessages handles each message in buf, passing it to the protocol
 // handler for final handling.
 func (s *Socket) processMessages(ctx context.Context, buf []byte) *syserr.Error {
@@ -545,14 +558,20 @@ func (s *Socket) processMessages(ctx context.Context, buf []byte) *syserr.Error
 			continue
 		}

+		ms := NewMessageSet(s.portID, hdr.Seq)
+		var err *syserr.Error
 		// TODO(b/68877377): ACKs not supported yet.
 		if hdr.Flags&linux.NLM_F_ACK == linux.NLM_F_ACK {
-			return syserr.ErrNotSupported
-		}
+			err = syserr.ErrNotSupported
+		} else {
-		ms := NewMessageSet(s.portID, hdr.Seq)
-		if err := s.protocol.ProcessMessage(ctx, hdr, data, ms); err != nil {
-			return err
+			err = s.protocol.ProcessMessage(ctx, hdr, data, ms)
+		}
+		if err != nil {
+			ms = NewMessageSet(s.portID, hdr.Seq)
+			if err := s.dumpErrorMessage(ctx, hdr, ms, err); err != nil {
+				return err
+			}
 		}

 		if err := s.sendResponse(ctx, ms); err != nil {
diff --git a/pkg/sentry/socket/rpcinet/notifier/BUILD b/pkg/sentry/socket/rpcinet/notifier/BUILD
index a536f2e44..a3585e10d 100644
--- a/pkg/sentry/socket/rpcinet/notifier/BUILD
+++ b/pkg/sentry/socket/rpcinet/notifier/BUILD
@@ -6,10 +6,11 @@ go_library(
     name = "notifier",
     srcs = ["notifier.go"],
     importpath = "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet/notifier",
-    visibility = ["//pkg/sentry:internal"],
+    visibility = ["//:sandbox"],
     deps = [
         "//pkg/sentry/socket/rpcinet:syscall_rpc_go_proto",
         "//pkg/sentry/socket/rpcinet/conn",
         "//pkg/waiter",
+        "@org_golang_x_sys//unix:go_default_library",
     ],
 )
diff --git a/pkg/sentry/socket/rpcinet/notifier/notifier.go b/pkg/sentry/socket/rpcinet/notifier/notifier.go
index aa157dd51..7efe4301f 100644
--- a/pkg/sentry/socket/rpcinet/notifier/notifier.go
+++ b/pkg/sentry/socket/rpcinet/notifier/notifier.go
@@ -20,6 +20,7 @@ import (
 	"sync"
 	"syscall"

+	"golang.org/x/sys/unix"
 	"gvisor.dev/gvisor/pkg/sentry/socket/rpcinet/conn"
 	pb "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet/syscall_rpc_go_proto"
 	"gvisor.dev/gvisor/pkg/waiter"
 )
@@ -76,7 +77,7 @@ func (n *Notifier) waitFD(fd uint32, fi *fdInfo, mask waiter.EventMask) error {
 	}

 	e := pb.EpollEvent{
-		Events: mask.ToLinux() | -syscall.EPOLLET,
+		Events: mask.ToLinux() | unix.EPOLLET,
 		Fd:     fd,
 	}
diff --git a/pkg/sentry/socket/rpcinet/socket.go b/pkg/sentry/socket/rpcinet/socket.go
index ccaaddbfc..ddb76d9d4 100644
--- a/pkg/sentry/socket/rpcinet/socket.go
+++ b/pkg/sentry/socket/rpcinet/socket.go
@@ -285,7 +285,7 @@ func rpcAccept(t *kernel.Task, fd uint32, peer bool) (*pb.AcceptResponse_ResultP
 }

 // Accept
implements socket.Socket.Accept. -func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, interface{}, uint32, *syserr.Error) { +func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) { payload, se := rpcAccept(t, s.fd, peerRequested) // Check if we need to block. @@ -328,6 +328,9 @@ func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, NonBlocking: flags&linux.SOCK_NONBLOCK != 0, } file := fs.NewFile(t, dirent, fileFlags, &socketOperations{ + family: s.family, + stype: s.stype, + protocol: s.protocol, wq: &wq, fd: payload.Fd, rpcConn: s.rpcConn, @@ -344,7 +347,7 @@ func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, t.Kernel().RecordSocket(file) if peerRequested { - return fd, payload.Address.Address, payload.Address.Length, nil + return fd, socket.UnmarshalSockAddr(s.family, payload.Address.Address), payload.Address.Length, nil } return fd, nil, 0, nil @@ -395,7 +398,7 @@ func (s *socketOperations) Shutdown(t *kernel.Task, how int) *syserr.Error { } // GetSockOpt implements socket.Socket.GetSockOpt. -func (s *socketOperations) GetSockOpt(t *kernel.Task, level int, name int, outLen int) (interface{}, *syserr.Error) { +func (s *socketOperations) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) { // SO_RCVTIMEO and SO_SNDTIMEO are special because blocking is performed // within the sentry. if level == linux.SOL_SOCKET && name == linux.SO_RCVTIMEO { @@ -469,7 +472,7 @@ func (s *socketOperations) SetSockOpt(t *kernel.Task, level int, name int, opt [ } // GetPeerName implements socket.Socket.GetPeerName. -func (s *socketOperations) GetPeerName(t *kernel.Task) (interface{}, uint32, *syserr.Error) { +func (s *socketOperations) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { stack := t.NetworkContext().(*Stack) id, c := stack.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_GetPeerName{&pb.GetPeerNameRequest{Fd: s.fd}}}, false /* ignoreResult */) <-c @@ -480,11 +483,11 @@ func (s *socketOperations) GetPeerName(t *kernel.Task) (interface{}, uint32, *sy } addr := res.(*pb.GetPeerNameResponse_Address).Address - return addr.Address, addr.Length, nil + return socket.UnmarshalSockAddr(s.family, addr.Address), addr.Length, nil } // GetSockName implements socket.Socket.GetSockName. -func (s *socketOperations) GetSockName(t *kernel.Task) (interface{}, uint32, *syserr.Error) { +func (s *socketOperations) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { stack := t.NetworkContext().(*Stack) id, c := stack.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_GetSockName{&pb.GetSockNameRequest{Fd: s.fd}}}, false /* ignoreResult */) <-c @@ -495,7 +498,7 @@ func (s *socketOperations) GetSockName(t *kernel.Task) (interface{}, uint32, *sy } addr := res.(*pb.GetSockNameResponse_Address).Address - return addr.Address, addr.Length, nil + return socket.UnmarshalSockAddr(s.family, addr.Address), addr.Length, nil } func rpcIoctl(t *kernel.Task, fd, cmd uint32, arg []byte) ([]byte, error) { @@ -682,7 +685,7 @@ func (s *socketOperations) extractControlMessages(payload *pb.RecvmsgResponse_Re } // RecvMsg implements socket.Socket.RecvMsg. 
-func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (int, int, interface{}, uint32, socket.ControlMessages, *syserr.Error) { +func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) { req := &pb.SyscallRequest_Recvmsg{&pb.RecvmsgRequest{ Fd: s.fd, Length: uint32(dst.NumBytes()), @@ -703,7 +706,7 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags } } c := s.extractControlMessages(res) - return int(res.Length), 0, res.Address.GetAddress(), res.Address.GetLength(), c, syserr.FromError(e) + return int(res.Length), 0, socket.UnmarshalSockAddr(s.family, res.Address.GetAddress()), res.Address.GetLength(), c, syserr.FromError(e) } if err != syserr.ErrWouldBlock && err != syserr.ErrTryAgain || flags&linux.MSG_DONTWAIT != 0 { return 0, 0, nil, 0, socket.ControlMessages{}, err @@ -727,7 +730,7 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags } } c := s.extractControlMessages(res) - return int(res.Length), 0, res.Address.GetAddress(), res.Address.GetLength(), c, syserr.FromError(e) + return int(res.Length), 0, socket.UnmarshalSockAddr(s.family, res.Address.GetAddress()), res.Address.GetLength(), c, syserr.FromError(e) } if err != syserr.ErrWouldBlock && err != syserr.ErrTryAgain { return 0, 0, nil, 0, socket.ControlMessages{}, err diff --git a/pkg/sentry/socket/rpcinet/stack.go b/pkg/sentry/socket/rpcinet/stack.go index 49bd3a220..5dcb6b455 100644 --- a/pkg/sentry/socket/rpcinet/stack.go +++ b/pkg/sentry/socket/rpcinet/stack.go @@ -30,6 +30,7 @@ import ( type Stack struct { interfaces map[int32]inet.Interface interfaceAddrs map[int32][]inet.InterfaceAddr + routes []inet.Route rpcConn *conn.RPCConnection notifier *notifier.Notifier } @@ -69,6 +70,16 @@ func NewStack(fd int32) (*Stack, error) { return nil, e } + routes, err := stack.DoNetlinkRouteRequest(syscall.RTM_GETROUTE) + if err != nil { + return nil, fmt.Errorf("RTM_GETROUTE failed: %v", err) + } + + stack.routes, e = hostinet.ExtractHostRoutes(routes) + if e != nil { + return nil, e + } + return stack, nil } @@ -89,12 +100,20 @@ func (s *Stack) RPCWriteFile(path string, data []byte) (int64, *syserr.Error) { // Interfaces implements inet.Stack.Interfaces. func (s *Stack) Interfaces() map[int32]inet.Interface { - return s.interfaces + interfaces := make(map[int32]inet.Interface) + for k, v := range s.interfaces { + interfaces[k] = v + } + return interfaces } // InterfaceAddrs implements inet.Stack.InterfaceAddrs. func (s *Stack) InterfaceAddrs() map[int32][]inet.InterfaceAddr { - return s.interfaceAddrs + addrs := make(map[int32][]inet.InterfaceAddr) + for k, v := range s.interfaceAddrs { + addrs[k] = append([]inet.InterfaceAddr(nil), v...) + } + return addrs } // SupportsIPv6 implements inet.Stack.SupportsIPv6. @@ -138,3 +157,11 @@ func (s *Stack) SetTCPSACKEnabled(enabled bool) error { func (s *Stack) Statistics(stat interface{}, arg string) error { return syserr.ErrEndpointOperation.ToError() } + +// RouteTable implements inet.Stack.RouteTable. +func (s *Stack) RouteTable() []inet.Route { + return append([]inet.Route(nil), s.routes...) +} + +// Resume implements inet.Stack.Resume. 
+func (s *Stack) Resume() {} diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go index 0efa58a58..8c250c325 100644 --- a/pkg/sentry/socket/socket.go +++ b/pkg/sentry/socket/socket.go @@ -20,8 +20,10 @@ package socket import ( "fmt" "sync/atomic" + "syscall" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/device" "gvisor.dev/gvisor/pkg/sentry/fs" @@ -52,7 +54,7 @@ type Socket interface { // Accept implements the accept4(2) linux syscall. // Returns fd, real peer address length and error. Real peer address // length is only set if len(peer) > 0. - Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, interface{}, uint32, *syserr.Error) + Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) // Bind implements the bind(2) linux syscall. Bind(t *kernel.Task, sockaddr []byte) *syserr.Error @@ -64,7 +66,7 @@ type Socket interface { Shutdown(t *kernel.Task, how int) *syserr.Error // GetSockOpt implements the getsockopt(2) linux syscall. - GetSockOpt(t *kernel.Task, level int, name int, outLen int) (interface{}, *syserr.Error) + GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) // SetSockOpt implements the setsockopt(2) linux syscall. SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *syserr.Error @@ -73,13 +75,13 @@ type Socket interface { // // addrLen is the address length to be returned to the application, not // necessarily the actual length of the address. - GetSockName(t *kernel.Task) (addr interface{}, addrLen uint32, err *syserr.Error) + GetSockName(t *kernel.Task) (addr linux.SockAddr, addrLen uint32, err *syserr.Error) // GetPeerName implements the getpeername(2) linux syscall. // // addrLen is the address length to be returned to the application, not // necessarily the actual length of the address. - GetPeerName(t *kernel.Task) (addr interface{}, addrLen uint32, err *syserr.Error) + GetPeerName(t *kernel.Task) (addr linux.SockAddr, addrLen uint32, err *syserr.Error) // RecvMsg implements the recvmsg(2) linux syscall. // @@ -92,7 +94,7 @@ type Socket interface { // msgFlags. In that case, the caller should set MSG_CTRUNC appropriately. // // If err != nil, the recv was not successful. - RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr interface{}, senderAddrLen uint32, controlMessages ControlMessages, err *syserr.Error) + RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr linux.SockAddr, senderAddrLen uint32, controlMessages ControlMessages, err *syserr.Error) // SendMsg implements the sendmsg(2) linux syscall. SendMsg does not take // ownership of the ControlMessage on error. @@ -340,3 +342,31 @@ func emitUnimplementedEvent(t *kernel.Task, name int) { t.Kernel().EmitUnimplementedEvent(t) } } + +// UnmarshalSockAddr unmarshals memory representing a struct sockaddr to one of +// the ABI socket address types. +// +// Precondition: data must be long enough to represent a socket address of the +// given family. 
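+//
+// Only AF_INET, AF_INET6, AF_UNIX, and AF_NETLINK are handled; any other
+// family panics.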
+func UnmarshalSockAddr(family int, data []byte) linux.SockAddr { + switch family { + case syscall.AF_INET: + var addr linux.SockAddrInet + binary.Unmarshal(data[:syscall.SizeofSockaddrInet4], usermem.ByteOrder, &addr) + return &addr + case syscall.AF_INET6: + var addr linux.SockAddrInet6 + binary.Unmarshal(data[:syscall.SizeofSockaddrInet6], usermem.ByteOrder, &addr) + return &addr + case syscall.AF_UNIX: + var addr linux.SockAddrUnix + binary.Unmarshal(data[:syscall.SizeofSockaddrUnix], usermem.ByteOrder, &addr) + return &addr + case syscall.AF_NETLINK: + var addr linux.SockAddrNetlink + binary.Unmarshal(data[:syscall.SizeofSockaddrNetlink], usermem.ByteOrder, &addr) + return &addr + default: + panic(fmt.Sprintf("Unsupported socket family %v", family)) + } +} diff --git a/pkg/sentry/socket/unix/io.go b/pkg/sentry/socket/unix/io.go index 760c7beab..2ec1a662d 100644 --- a/pkg/sentry/socket/unix/io.go +++ b/pkg/sentry/socket/unix/io.go @@ -62,7 +62,7 @@ type EndpointReader struct { Creds bool // NumRights is the number of SCM_RIGHTS FDs requested. - NumRights uintptr + NumRights int // Peek indicates that the data should not be consumed from the // endpoint. @@ -70,7 +70,7 @@ type EndpointReader struct { // MsgSize is the size of the message that was read from. For stream // sockets, it is the amount read. - MsgSize uintptr + MsgSize int64 // From, if not nil, will be set with the address read from. From *tcpip.FullAddress diff --git a/pkg/sentry/socket/unix/transport/connectioned.go b/pkg/sentry/socket/unix/transport/connectioned.go index 73d2df15d..4bd15808a 100644 --- a/pkg/sentry/socket/unix/transport/connectioned.go +++ b/pkg/sentry/socket/unix/transport/connectioned.go @@ -436,7 +436,7 @@ func (e *connectionedEndpoint) Bind(addr tcpip.FullAddress, commit func() *syser // SendMsg writes data and a control message to the endpoint's peer. // This method does not block if the data cannot be written. -func (e *connectionedEndpoint) SendMsg(ctx context.Context, data [][]byte, c ControlMessages, to BoundEndpoint) (uintptr, *syserr.Error) { +func (e *connectionedEndpoint) SendMsg(ctx context.Context, data [][]byte, c ControlMessages, to BoundEndpoint) (int64, *syserr.Error) { // Stream sockets do not support specifying the endpoint. Seqpacket // sockets ignore the passed endpoint. if e.stype == linux.SOCK_STREAM && to != nil { diff --git a/pkg/sentry/socket/unix/transport/connectionless.go b/pkg/sentry/socket/unix/transport/connectionless.go index c7f7c5b16..0322dec0b 100644 --- a/pkg/sentry/socket/unix/transport/connectionless.go +++ b/pkg/sentry/socket/unix/transport/connectionless.go @@ -99,7 +99,7 @@ func (e *connectionlessEndpoint) UnidirectionalConnect(ctx context.Context) (Con // SendMsg writes data and a control message to the specified endpoint. // This method does not block if the data cannot be written. 
-func (e *connectionlessEndpoint) SendMsg(ctx context.Context, data [][]byte, c ControlMessages, to BoundEndpoint) (uintptr, *syserr.Error) { +func (e *connectionlessEndpoint) SendMsg(ctx context.Context, data [][]byte, c ControlMessages, to BoundEndpoint) (int64, *syserr.Error) { if to == nil { return e.baseEndpoint.SendMsg(ctx, data, c, nil) } diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go index 7fb9cb1e0..2b0ad6395 100644 --- a/pkg/sentry/socket/unix/transport/unix.go +++ b/pkg/sentry/socket/unix/transport/unix.go @@ -121,13 +121,13 @@ type Endpoint interface { // CMTruncated indicates that the numRights hint was used to receive fewer // than the total available SCM_RIGHTS FDs. Additional truncation may be // required by the caller. - RecvMsg(ctx context.Context, data [][]byte, creds bool, numRights uintptr, peek bool, addr *tcpip.FullAddress) (recvLen, msgLen uintptr, cm ControlMessages, CMTruncated bool, err *syserr.Error) + RecvMsg(ctx context.Context, data [][]byte, creds bool, numRights int, peek bool, addr *tcpip.FullAddress) (recvLen, msgLen int64, cm ControlMessages, CMTruncated bool, err *syserr.Error) // SendMsg writes data and a control message to the endpoint's peer. // This method does not block if the data cannot be written. // // SendMsg does not take ownership of any of its arguments on error. - SendMsg(context.Context, [][]byte, ControlMessages, BoundEndpoint) (uintptr, *syserr.Error) + SendMsg(context.Context, [][]byte, ControlMessages, BoundEndpoint) (int64, *syserr.Error) // Connect connects this endpoint directly to another. // @@ -291,7 +291,7 @@ type Receiver interface { // See Endpoint.RecvMsg for documentation on shared arguments. // // notify indicates if RecvNotify should be called. - Recv(data [][]byte, creds bool, numRights uintptr, peek bool) (recvLen, msgLen uintptr, cm ControlMessages, CMTruncated bool, source tcpip.FullAddress, notify bool, err *syserr.Error) + Recv(data [][]byte, creds bool, numRights int, peek bool) (recvLen, msgLen int64, cm ControlMessages, CMTruncated bool, source tcpip.FullAddress, notify bool, err *syserr.Error) // RecvNotify notifies the Receiver of a successful Recv. This must not be // called while holding any endpoint locks. @@ -331,7 +331,7 @@ type queueReceiver struct { } // Recv implements Receiver.Recv. -func (q *queueReceiver) Recv(data [][]byte, creds bool, numRights uintptr, peek bool) (uintptr, uintptr, ControlMessages, bool, tcpip.FullAddress, bool, *syserr.Error) { +func (q *queueReceiver) Recv(data [][]byte, creds bool, numRights int, peek bool) (int64, int64, ControlMessages, bool, tcpip.FullAddress, bool, *syserr.Error) { var m *message var notify bool var err *syserr.Error @@ -344,13 +344,13 @@ func (q *queueReceiver) Recv(data [][]byte, creds bool, numRights uintptr, peek return 0, 0, ControlMessages{}, false, tcpip.FullAddress{}, false, err } src := []byte(m.Data) - var copied uintptr + var copied int64 for i := 0; i < len(data) && len(src) > 0; i++ { n := copy(data[i], src) - copied += uintptr(n) + copied += int64(n) src = src[n:] } - return copied, uintptr(len(m.Data)), m.Control, false, m.Address, notify, nil + return copied, int64(len(m.Data)), m.Control, false, m.Address, notify, nil } // RecvNotify implements Receiver.RecvNotify. 
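The Recv contract above separates recvLen (bytes actually scattered into data) from msgLen (the full message size), so recvLen < msgLen signals truncation. A standalone sketch of the scatter step, mirroring the copy loop in queueReceiver.Recv (scatter is a hypothetical helper, for illustration only):

    // scatter copies one message across the destination slices the way
    // queueReceiver.Recv fills data [][]byte; the result can be short of
    // len(msg) when the destination is too small.
    func scatter(data [][]byte, msg []byte) int64 {
    	var copied int64
    	src := msg
    	for i := 0; i < len(data) && len(src) > 0; i++ {
    		n := copy(data[i], src)
    		copied += int64(n)
    		src = src[n:]
    	}
    	return copied
    }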
@@ -401,11 +401,11 @@ type streamQueueReceiver struct { addr tcpip.FullAddress } -func vecCopy(data [][]byte, buf []byte) (uintptr, [][]byte, []byte) { - var copied uintptr +func vecCopy(data [][]byte, buf []byte) (int64, [][]byte, []byte) { + var copied int64 for len(data) > 0 && len(buf) > 0 { n := copy(data[0], buf) - copied += uintptr(n) + copied += int64(n) buf = buf[n:] data[0] = data[0][n:] if len(data[0]) == 0 { @@ -443,7 +443,7 @@ func (q *streamQueueReceiver) RecvMaxQueueSize() int64 { } // Recv implements Receiver.Recv. -func (q *streamQueueReceiver) Recv(data [][]byte, wantCreds bool, numRights uintptr, peek bool) (uintptr, uintptr, ControlMessages, bool, tcpip.FullAddress, bool, *syserr.Error) { +func (q *streamQueueReceiver) Recv(data [][]byte, wantCreds bool, numRights int, peek bool) (int64, int64, ControlMessages, bool, tcpip.FullAddress, bool, *syserr.Error) { q.mu.Lock() defer q.mu.Unlock() @@ -464,7 +464,7 @@ func (q *streamQueueReceiver) Recv(data [][]byte, wantCreds bool, numRights uint q.addr = m.Address } - var copied uintptr + var copied int64 if peek { // Don't consume control message if we are peeking. c := q.control.Clone() @@ -531,7 +531,7 @@ func (q *streamQueueReceiver) Recv(data [][]byte, wantCreds bool, numRights uint break } - var cpd uintptr + var cpd int64 cpd, data, q.buffer = vecCopy(data, q.buffer) copied += cpd @@ -569,7 +569,7 @@ type ConnectedEndpoint interface { // // syserr.ErrWouldBlock can be returned along with a partial write if // the caller should block to send the rest of the data. - Send(data [][]byte, controlMessages ControlMessages, from tcpip.FullAddress) (n uintptr, notify bool, err *syserr.Error) + Send(data [][]byte, controlMessages ControlMessages, from tcpip.FullAddress) (n int64, notify bool, err *syserr.Error) // SendNotify notifies the ConnectedEndpoint of a successful Send. This // must not be called while holding any endpoint locks. @@ -637,7 +637,7 @@ func (e *connectedEndpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) } // Send implements ConnectedEndpoint.Send. -func (e *connectedEndpoint) Send(data [][]byte, controlMessages ControlMessages, from tcpip.FullAddress) (uintptr, bool, *syserr.Error) { +func (e *connectedEndpoint) Send(data [][]byte, controlMessages ControlMessages, from tcpip.FullAddress) (int64, bool, *syserr.Error) { var l int64 for _, d := range data { l += int64(len(d)) @@ -665,7 +665,7 @@ func (e *connectedEndpoint) Send(data [][]byte, controlMessages ControlMessages, } l, notify, err := e.writeQueue.Enqueue(&message{Data: buffer.View(v), Control: controlMessages, Address: from}, truncate) - return uintptr(l), notify, err + return int64(l), notify, err } // SendNotify implements ConnectedEndpoint.SendNotify. @@ -781,7 +781,7 @@ func (e *baseEndpoint) Connected() bool { } // RecvMsg reads data and a control message from the endpoint. -func (e *baseEndpoint) RecvMsg(ctx context.Context, data [][]byte, creds bool, numRights uintptr, peek bool, addr *tcpip.FullAddress) (uintptr, uintptr, ControlMessages, bool, *syserr.Error) { +func (e *baseEndpoint) RecvMsg(ctx context.Context, data [][]byte, creds bool, numRights int, peek bool, addr *tcpip.FullAddress) (int64, int64, ControlMessages, bool, *syserr.Error) { e.Lock() if e.receiver == nil { @@ -807,7 +807,7 @@ func (e *baseEndpoint) RecvMsg(ctx context.Context, data [][]byte, creds bool, n // SendMsg writes data and a control message to the endpoint's peer. // This method does not block if the data cannot be written. 
-func (e *baseEndpoint) SendMsg(ctx context.Context, data [][]byte, c ControlMessages, to BoundEndpoint) (uintptr, *syserr.Error) { +func (e *baseEndpoint) SendMsg(ctx context.Context, data [][]byte, c ControlMessages, to BoundEndpoint) (int64, *syserr.Error) { e.Lock() if !e.Connected() { e.Unlock() diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go index eb262ecaf..0d0cb68df 100644 --- a/pkg/sentry/socket/unix/unix.go +++ b/pkg/sentry/socket/unix/unix.go @@ -116,7 +116,7 @@ func (s *SocketOperations) Endpoint() transport.Endpoint { // extractPath extracts and validates the address. func extractPath(sockaddr []byte) (string, *syserr.Error) { - addr, err := epsocket.GetAddress(linux.AF_UNIX, sockaddr, true /* strict */) + addr, _, err := epsocket.AddressAndFamily(linux.AF_UNIX, sockaddr, true /* strict */) if err != nil { return "", err } @@ -137,7 +137,7 @@ func extractPath(sockaddr []byte) (string, *syserr.Error) { // GetPeerName implements the linux syscall getpeername(2) for sockets backed by // a transport.Endpoint. -func (s *SocketOperations) GetPeerName(t *kernel.Task) (interface{}, uint32, *syserr.Error) { +func (s *SocketOperations) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { addr, err := s.ep.GetRemoteAddress() if err != nil { return nil, 0, syserr.TranslateNetstackError(err) @@ -149,7 +149,7 @@ func (s *SocketOperations) GetPeerName(t *kernel.Task) (interface{}, uint32, *sy // GetSockName implements the linux syscall getsockname(2) for sockets backed by // a transport.Endpoint. -func (s *SocketOperations) GetSockName(t *kernel.Task) (interface{}, uint32, *syserr.Error) { +func (s *SocketOperations) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { addr, err := s.ep.GetLocalAddress() if err != nil { return nil, 0, syserr.TranslateNetstackError(err) @@ -166,7 +166,7 @@ func (s *SocketOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO, // GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by // a transport.Endpoint. -func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name, outLen int) (interface{}, *syserr.Error) { +func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) { return epsocket.GetSockOpt(t, s, s.ep, linux.AF_UNIX, s.ep.Type(), level, name, outLen) } @@ -199,7 +199,7 @@ func (s *SocketOperations) blockingAccept(t *kernel.Task) (transport.Endpoint, * // Accept implements the linux syscall accept(2) for sockets backed by // a transport.Endpoint. -func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, interface{}, uint32, *syserr.Error) { +func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) { // Issue the accept request to get the new endpoint. ep, err := s.ep.Accept() if err != nil { @@ -223,7 +223,7 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, ns.SetFlags(flags.Settable()) } - var addr interface{} + var addr linux.SockAddr var addrLen uint32 if peerRequested { // Get address of the peer. @@ -505,7 +505,7 @@ func (s *SocketOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOS // RecvMsg implements the linux syscall recvmsg(2) for sockets backed by // a transport.Endpoint. 
-func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr interface{}, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) { +func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr linux.SockAddr, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) { trunc := flags&linux.MSG_TRUNC != 0 peek := flags&linux.MSG_PEEK != 0 dontWait := flags&linux.MSG_DONTWAIT != 0 @@ -535,7 +535,7 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags Ctx: t, Endpoint: s.ep, Creds: wantCreds, - NumRights: uintptr(numRights), + NumRights: numRights, Peek: peek, } if senderRequested { @@ -543,7 +543,7 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags } var total int64 if n, err := dst.CopyOutFrom(t, &r); err != syserror.ErrWouldBlock || dontWait { - var from interface{} + var from linux.SockAddr var fromLen uint32 if r.From != nil && len([]byte(r.From.Addr)) != 0 { from, fromLen = epsocket.ConvertAddress(linux.AF_UNIX, *r.From) @@ -578,7 +578,7 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags for { if n, err := dst.CopyOutFrom(t, &r); err != syserror.ErrWouldBlock { - var from interface{} + var from linux.SockAddr var fromLen uint32 if r.From != nil { from, fromLen = epsocket.ConvertAddress(linux.AF_UNIX, *r.From) diff --git a/pkg/sentry/state/BUILD b/pkg/sentry/state/BUILD index f297ef3b7..88765f4d6 100644 --- a/pkg/sentry/state/BUILD +++ b/pkg/sentry/state/BUILD @@ -16,6 +16,7 @@ go_library( "//pkg/log", "//pkg/sentry/inet", "//pkg/sentry/kernel", + "//pkg/sentry/time", "//pkg/sentry/watchdog", "//pkg/state/statefile", "//pkg/syserror", diff --git a/pkg/sentry/state/state.go b/pkg/sentry/state/state.go index 026549756..9eb626b76 100644 --- a/pkg/sentry/state/state.go +++ b/pkg/sentry/state/state.go @@ -22,6 +22,7 @@ import ( "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/inet" "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/time" "gvisor.dev/gvisor/pkg/sentry/watchdog" "gvisor.dev/gvisor/pkg/state/statefile" "gvisor.dev/gvisor/pkg/syserror" @@ -104,7 +105,7 @@ type LoadOpts struct { } // Load loads the given kernel, setting the provided platform and stack. -func (opts LoadOpts) Load(k *kernel.Kernel, n inet.Stack) error { +func (opts LoadOpts) Load(k *kernel.Kernel, n inet.Stack, clocks time.Clocks) error { // Open the file. r, m, err := statefile.NewReader(opts.Source, opts.Key) if err != nil { @@ -114,5 +115,5 @@ func (opts LoadOpts) Load(k *kernel.Kernel, n inet.Stack) error { previousMetadata = m // Restore the Kernel object graph. 
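+	// The clocks are passed through so the kernel's timekeeper can be
+	// restarted from the new host's clocks as part of the load.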
- return k.LoadFrom(r, n) + return k.LoadFrom(r, n, clocks) } diff --git a/pkg/sentry/strace/socket.go b/pkg/sentry/strace/socket.go index 386b40af7..f779186ad 100644 --- a/pkg/sentry/strace/socket.go +++ b/pkg/sentry/strace/socket.go @@ -332,7 +332,7 @@ func sockAddr(t *kernel.Task, addr usermem.Addr, length uint32) string { switch family { case linux.AF_INET, linux.AF_INET6, linux.AF_UNIX: - fa, err := epsocket.GetAddress(int(family), b, true /* strict */) + fa, _, err := epsocket.AddressAndFamily(int(family), b, true /* strict */) if err != nil { return fmt.Sprintf("%#x {Family: %s, error extracting address: %v}", addr, familyStr, err) } diff --git a/pkg/sentry/syscalls/linux/error.go b/pkg/sentry/syscalls/linux/error.go index 264301bfa..1d9018c96 100644 --- a/pkg/sentry/syscalls/linux/error.go +++ b/pkg/sentry/syscalls/linux/error.go @@ -91,6 +91,10 @@ func handleIOError(t *kernel.Task, partialResult bool, err, intr error, op strin // TODO(gvisor.dev/issue/161): In some cases SIGPIPE should // also be sent to the application. return nil + case syserror.ENOSPC: + // Similar to EPIPE. Return what we wrote this time, and let + // ENOSPC be returned on the next call. + return nil case syserror.ECONNRESET: // For TCP sendfile connections, we may have a reset. But we // should just return n as the result. diff --git a/pkg/sentry/syscalls/linux/linux64.go b/pkg/sentry/syscalls/linux/linux64.go index 51db2d8f7..ed996ba51 100644 --- a/pkg/sentry/syscalls/linux/linux64.go +++ b/pkg/sentry/syscalls/linux/linux64.go @@ -30,8 +30,7 @@ import ( const _AUDIT_ARCH_X86_64 = 0xc000003e // AMD64 is a table of Linux amd64 syscall API with the corresponding syscall -// numbers from Linux 4.4. The entries commented out are those syscalls we -// don't currently support. +// numbers from Linux 4.4. var AMD64 = &kernel.SyscallTable{ OS: abi.Linux, Arch: arch.AMD64, diff --git a/pkg/sentry/syscalls/linux/sys_epoll.go b/pkg/sentry/syscalls/linux/sys_epoll.go index 4a2b9f061..65b4a227b 100644 --- a/pkg/sentry/syscalls/linux/sys_epoll.go +++ b/pkg/sentry/syscalls/linux/sys_epoll.go @@ -107,19 +107,20 @@ func EpollCtl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc // copyOutEvents copies epoll events from the kernel to user memory. 
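The rewrite that follows batches every event into one scratch buffer and issues a single CopyOutBytes, rather than one copy-out per event. itemLen is 12 because Linux's struct epoll_event is packed: a 32-bit events mask immediately followed by 64 bits of user data, with no padding. A standalone restatement of the per-event encoding, assuming encoding/binary is imported (binary.LittleEndian stands in here for usermem.ByteOrder on amd64):

// encodeEvent writes one 12-byte epoll_event record: the events mask, then
// the two 32-bit halves of the user data, exactly as the loop below does.
// Illustration only, not part of the patch.
func encodeEvent(b []byte, events uint32, data [2]int32) {
	binary.LittleEndian.PutUint32(b[0:], events)
	binary.LittleEndian.PutUint32(b[4:], uint32(data[0]))
	binary.LittleEndian.PutUint32(b[8:], uint32(data[1]))
}

The rewritten function follows.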
func copyOutEvents(t *kernel.Task, addr usermem.Addr, e []epoll.Event) error { const itemLen = 12 - if _, ok := addr.AddLength(uint64(len(e)) * itemLen); !ok { + buffLen := len(e) * itemLen + if _, ok := addr.AddLength(uint64(buffLen)); !ok { return syserror.EFAULT } - b := t.CopyScratchBuffer(itemLen) + b := t.CopyScratchBuffer(buffLen) for i := range e { - usermem.ByteOrder.PutUint32(b[0:], e[i].Events) - usermem.ByteOrder.PutUint32(b[4:], uint32(e[i].Data[0])) - usermem.ByteOrder.PutUint32(b[8:], uint32(e[i].Data[1])) - if _, err := t.CopyOutBytes(addr, b); err != nil { - return err - } - addr += itemLen + usermem.ByteOrder.PutUint32(b[i*itemLen:], e[i].Events) + usermem.ByteOrder.PutUint32(b[i*itemLen+4:], uint32(e[i].Data[0])) + usermem.ByteOrder.PutUint32(b[i*itemLen+8:], uint32(e[i].Data[1])) + } + + if _, err := t.CopyOutBytes(addr, b); err != nil { + return err } return nil diff --git a/pkg/sentry/syscalls/linux/sys_getdents.go b/pkg/sentry/syscalls/linux/sys_getdents.go index 63e2c5a5d..912cbe4ff 100644 --- a/pkg/sentry/syscalls/linux/sys_getdents.go +++ b/pkg/sentry/syscalls/linux/sys_getdents.go @@ -120,7 +120,7 @@ func newDirent(width uint, name string, attr fs.DentAttr, offset uint64) *dirent Ino: attr.InodeID, Off: offset, }, - Typ: toType(attr.Type), + Typ: fs.ToDirentType(attr.Type), }, Name: []byte(name), } @@ -142,28 +142,6 @@ func smallestDirent64(a arch.Context) uint { return uint(binary.Size(d.Hdr)) + a.Width() } -// toType converts an fs.InodeOperationsInfo to a linux dirent typ field. -func toType(nodeType fs.InodeType) uint8 { - switch nodeType { - case fs.RegularFile, fs.SpecialFile: - return linux.DT_REG - case fs.Symlink: - return linux.DT_LNK - case fs.Directory, fs.SpecialDirectory: - return linux.DT_DIR - case fs.Pipe: - return linux.DT_FIFO - case fs.CharacterDevice: - return linux.DT_CHR - case fs.BlockDevice: - return linux.DT_BLK - case fs.Socket: - return linux.DT_SOCK - default: - return linux.DT_UNKNOWN - } -} - // padRec pads the name field until the rec length is a multiple of the width, // which must be a power of 2. It returns the padded rec length. func (d *dirent) padRec(width int) uint16 { diff --git a/pkg/sentry/syscalls/linux/sys_mount.go b/pkg/sentry/syscalls/linux/sys_mount.go index 9080a10c3..8c13e2d82 100644 --- a/pkg/sentry/syscalls/linux/sys_mount.go +++ b/pkg/sentry/syscalls/linux/sys_mount.go @@ -109,9 +109,17 @@ func Mount(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall return 0, nil, syserror.EINVAL } - return 0, nil, fileOpOn(t, linux.AT_FDCWD, targetPath, true /* resolve */, func(root *fs.Dirent, d *fs.Dirent, _ uint) error { + if err := fileOpOn(t, linux.AT_FDCWD, targetPath, true /* resolve */, func(root *fs.Dirent, d *fs.Dirent, _ uint) error { + // Mount will take a reference on rootInode if successful. return t.MountNamespace().Mount(t, d, rootInode) - }) + }); err != nil { + // Something went wrong. Drop our ref on rootInode before + // returning the error. + rootInode.DecRef() + return 0, nil, err + } + + return 0, nil, nil } // Umount2 implements Linux syscall umount2(2). diff --git a/pkg/sentry/syscalls/linux/sys_read.go b/pkg/sentry/syscalls/linux/sys_read.go index b2474e60d..3ab54271c 100644 --- a/pkg/sentry/syscalls/linux/sys_read.go +++ b/pkg/sentry/syscalls/linux/sys_read.go @@ -191,7 +191,6 @@ func Preadv(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal } // Preadv2 implements linux syscall preadv2(2). -// TODO(b/120162627): Implement RWF_HIPRI functionality. 
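The TODO comes out because the behavior is now deliberate: RWF_HIPRI passes validation but has no effect, as the note added inside Preadv2 below records. The validation idiom, restated as a standalone sketch using constants and error values that appear in this tree:

// checkRWFlags mirrors the preadv2/pwritev2 flag policy: any bit outside
// RWF_VALID fails with EOPNOTSUPP, while RWF_HIPRI is accepted and ignored.
// Sketch only.
func checkRWFlags(flags int) error {
	if flags&^linux.RWF_VALID != 0 {
		return syserror.EOPNOTSUPP
	}
	return nil
}

Preadv2 continues below.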
func Preadv2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { // While the syscall is // preadv2(int fd, struct iovec* iov, int iov_cnt, off_t offset, int flags) @@ -228,6 +227,8 @@ func Preadv2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca } // Check flags field. + // Note: gVisor does not implement the RWF_HIPRI feature, but the flag is + // accepted as a valid flag argument for preadv2. if flags&^linux.RWF_VALID != 0 { return 0, nil, syserror.EOPNOTSUPP } diff --git a/pkg/sentry/syscalls/linux/sys_socket.go b/pkg/sentry/syscalls/linux/sys_socket.go index fa568a660..3bac4d90d 100644 --- a/pkg/sentry/syscalls/linux/sys_socket.go +++ b/pkg/sentry/syscalls/linux/sys_socket.go @@ -460,7 +460,7 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy } // Call syscall implementation then copy both value and value len out. - v, e := getSockOpt(t, s, int(level), int(name), int(optLen)) + v, e := getSockOpt(t, s, int(level), int(name), optValAddr, int(optLen)) if e != nil { return 0, nil, e.ToError() } @@ -483,7 +483,7 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy // getSockOpt tries to handle common socket options, or dispatches to a specific // socket implementation. -func getSockOpt(t *kernel.Task, s socket.Socket, level, name, len int) (interface{}, *syserr.Error) { +func getSockOpt(t *kernel.Task, s socket.Socket, level, name int, optValAddr usermem.Addr, len int) (interface{}, *syserr.Error) { if level == linux.SOL_SOCKET { switch name { case linux.SO_TYPE, linux.SO_DOMAIN, linux.SO_PROTOCOL: @@ -505,7 +505,7 @@ func getSockOpt(t *kernel.Task, s socket.Socket, level, name, len int) (interfac } } - return s.GetSockOpt(t, level, name, len) + return s.GetSockOpt(t, level, name, optValAddr, len) } // SetSockOpt implements the linux syscall setsockopt(2). diff --git a/pkg/sentry/syscalls/linux/sys_splice.go b/pkg/sentry/syscalls/linux/sys_splice.go index a7c98efcb..8a98fedcb 100644 --- a/pkg/sentry/syscalls/linux/sys_splice.go +++ b/pkg/sentry/syscalls/linux/sys_splice.go @@ -91,22 +91,29 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc } // Get files. + inFile := t.GetFile(inFD) + if inFile == nil { + return 0, nil, syserror.EBADF + } + defer inFile.DecRef() + + if !inFile.Flags().Read { + return 0, nil, syserror.EBADF + } + outFile := t.GetFile(outFD) if outFile == nil { return 0, nil, syserror.EBADF } defer outFile.DecRef() - inFile := t.GetFile(inFD) - if inFile == nil { + if !outFile.Flags().Write { return 0, nil, syserror.EBADF } - defer inFile.DecRef() - // Verify that the outfile Append flag is not set. Note that fs.Splice - // itself validates that the output file is writable. + // Verify that the outfile Append flag is not set. if outFile.Flags().Append { - return 0, nil, syserror.EBADF + return 0, nil, syserror.EINVAL } // Verify that we have a regular infile. 
This is a requirement; the @@ -207,6 +214,10 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal return 0, nil, syserror.ESPIPE } if outOffset != 0 { + if !outFile.Flags().Pwrite { + return 0, nil, syserror.EINVAL + } + var offset int64 if _, err := t.CopyIn(outOffset, &offset); err != nil { return 0, nil, err @@ -220,6 +231,10 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal return 0, nil, syserror.ESPIPE } if inOffset != 0 { + if !inFile.Flags().Pread { + return 0, nil, syserror.EINVAL + } + var offset int64 if _, err := t.CopyIn(inOffset, &offset); err != nil { return 0, nil, err diff --git a/pkg/sentry/syscalls/linux/sys_thread.go b/pkg/sentry/syscalls/linux/sys_thread.go index 595eb9155..8ab7ffa25 100644 --- a/pkg/sentry/syscalls/linux/sys_thread.go +++ b/pkg/sentry/syscalls/linux/sys_thread.go @@ -96,7 +96,7 @@ func Execve(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal // Load the new TaskContext. maxTraversals := uint(linux.MaxSymlinkTraversals) - tc, se := t.Kernel().LoadTaskImage(t, t.MountNamespace(), root, wd, &maxTraversals, filename, argv, envv, t.Arch().FeatureSet()) + tc, se := t.Kernel().LoadTaskImage(t, t.MountNamespace(), root, wd, &maxTraversals, filename, nil, argv, envv, t.Arch().FeatureSet()) if se != nil { return 0, nil, se.ToError() } diff --git a/pkg/sentry/syscalls/linux/sys_write.go b/pkg/sentry/syscalls/linux/sys_write.go index 5278c96a6..27cd2c336 100644 --- a/pkg/sentry/syscalls/linux/sys_write.go +++ b/pkg/sentry/syscalls/linux/sys_write.go @@ -191,7 +191,6 @@ func Pwritev(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca } // Pwritev2 implements linux syscall pwritev2(2). -// TODO(b/120162627): Implement RWF_HIPRI functionality. // TODO(b/120161091): Implement O_SYNC and D_SYNC functionality. func Pwritev2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { // While the syscall is @@ -227,6 +226,8 @@ func Pwritev2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc return 0, nil, syserror.ESPIPE } + // Note: gVisor does not implement the RWF_HIPRI feature, but the flag is + // accepted as a valid flag argument for pwritev2. 
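One more change from the splice(2) hunks above deserves a note: passing an explicit offset now requires the corresponding file to support positional I/O, and the failure is EINVAL rather than a silent misuse. A condensed sketch reusing names from the diff (the real code performs each check inside its respective offset branch):

// An offset argument demands pread/pwrite capability on that end.
if inOffset != 0 && !inFile.Flags().Pread {
	return 0, nil, syserror.EINVAL
}
if outOffset != 0 && !outFile.Flags().Pwrite {
	return 0, nil, syserror.EINVAL
}

The pwritev2 flags check follows.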
if flags&^linux.RWF_VALID != 0 { return uintptr(flags), nil, syserror.EOPNOTSUPP } diff --git a/pkg/sentry/vfs/BUILD b/pkg/sentry/vfs/BUILD index 4de6c41cf..0f247bf77 100644 --- a/pkg/sentry/vfs/BUILD +++ b/pkg/sentry/vfs/BUILD @@ -18,6 +18,7 @@ go_library( "permissions.go", "resolving_path.go", "syscalls.go", + "testutil.go", "vfs.go", ], importpath = "gvisor.dev/gvisor/pkg/sentry/vfs", @@ -40,7 +41,16 @@ go_test( name = "vfs_test", size = "small", srcs = [ + "file_description_impl_util_test.go", "mount_test.go", ], embed = [":vfs"], + deps = [ + "//pkg/abi/linux", + "//pkg/sentry/context", + "//pkg/sentry/context/contexttest", + "//pkg/sentry/kernel/auth", + "//pkg/sentry/usermem", + "//pkg/syserror", + ], ) diff --git a/pkg/sentry/vfs/file_description_impl_util.go b/pkg/sentry/vfs/file_description_impl_util.go index 486893e70..ba230da72 100644 --- a/pkg/sentry/vfs/file_description_impl_util.go +++ b/pkg/sentry/vfs/file_description_impl_util.go @@ -15,6 +15,10 @@ package vfs import ( + "bytes" + "io" + "sync" + "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/context" @@ -24,6 +28,16 @@ import ( "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/waiter" ) +// The following design pattern is strongly recommended for filesystem +// implementations to adopt: +// - Have a local fileDescription struct (containing FileDescription) which +// embeds FileDescriptionDefaultImpl and overrides the default methods +// which are common to all fd implementations for that filesystem +// like StatusFlags, SetStatusFlags, Stat, SetStat, StatFS, etc. +// - This should be embedded in all file description implementations as the +// first field by value. +// - Directory FDs would also embed DirectoryFileDescriptionDefaultImpl. + // FileDescriptionDefaultImpl may be embedded by implementations of // FileDescriptionImpl to obtain implementations of many FileDescriptionImpl // methods with default behavior analogous to Linux's. @@ -115,11 +129,8 @@ func (FileDescriptionDefaultImpl) Ioctl(ctx context.Context, uio usermem.IO, arg // DirectoryFileDescriptionDefaultImpl may be embedded by implementations of // FileDescriptionImpl that always represent directories to obtain -// implementations of non-directory I/O methods that return EISDIR, and -// implementations of other methods consistent with FileDescriptionDefaultImpl. -type DirectoryFileDescriptionDefaultImpl struct { - FileDescriptionDefaultImpl -} +// implementations of non-directory I/O methods that return EISDIR. type DirectoryFileDescriptionDefaultImpl struct{} // PRead implements FileDescriptionImpl.PRead. func (DirectoryFileDescriptionDefaultImpl) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts ReadOptions) (int64, error) { @@ -140,3 +151,104 @@ func (DirectoryFileDescriptionDefaultImpl) PWrite(ctx context.Context, src userm func (DirectoryFileDescriptionDefaultImpl) Write(ctx context.Context, src usermem.IOSequence, opts WriteOptions) (int64, error) { return 0, syserror.EISDIR } + +// DynamicBytesFileDescriptionImpl may be embedded by implementations of +// FileDescriptionImpl that represent read-only regular files whose contents +// are backed by a bytes.Buffer that is regenerated when necessary, consistent +// with Linux's fs/seq_file.c:single_open(). +// +// DynamicBytesFileDescriptionImpl.SetDataSource() must be called before first +// use.
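To make the contract concrete before the type definition, here is a hypothetical DynamicBytesSource; uptimeSource and its field are invented for illustration, and the snippet assumes imports of bytes, fmt, time, and the sentry's context package:

// uptimeSource regenerates the file's contents on every traversal from
// offset 0, analogous to a Linux seq_file. Hypothetical example.
type uptimeSource struct {
	start time.Time
}

func (s *uptimeSource) Generate(ctx context.Context, buf *bytes.Buffer) error {
	fmt.Fprintf(buf, "%d\n", int64(time.Since(s.start).Seconds()))
	return nil
}

The embedding type follows.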
+type DynamicBytesFileDescriptionImpl struct { + data DynamicBytesSource // immutable + mu sync.Mutex // protects the following fields + buf bytes.Buffer + off int64 + lastRead int64 // offset at which the last Read, PRead, or Seek ended +} + +// DynamicBytesSource represents a data source for a +// DynamicBytesFileDescriptionImpl. +type DynamicBytesSource interface { + // Generate writes the file's contents to buf. + Generate(ctx context.Context, buf *bytes.Buffer) error +} + +// SetDataSource must be called exactly once on fd before first use. +func (fd *DynamicBytesFileDescriptionImpl) SetDataSource(data DynamicBytesSource) { + fd.data = data +} + +// Preconditions: fd.mu must be locked. +func (fd *DynamicBytesFileDescriptionImpl) preadLocked(ctx context.Context, dst usermem.IOSequence, offset int64, opts *ReadOptions) (int64, error) { + // Regenerate the buffer if it's empty, or before pread() at a new offset. + // Compare fs/seq_file.c:seq_read() => traverse(). + switch { + case offset != fd.lastRead: + fd.buf.Reset() + fallthrough + case fd.buf.Len() == 0: + if err := fd.data.Generate(ctx, &fd.buf); err != nil { + fd.buf.Reset() + // fd.off is not updated in this case. + fd.lastRead = 0 + return 0, err + } + } + bs := fd.buf.Bytes() + if offset >= int64(len(bs)) { + return 0, io.EOF + } + n, err := dst.CopyOut(ctx, bs[offset:]) + fd.lastRead = offset + int64(n) + return int64(n), err +} + +// PRead implements FileDescriptionImpl.PRead. +func (fd *DynamicBytesFileDescriptionImpl) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts ReadOptions) (int64, error) { + fd.mu.Lock() + n, err := fd.preadLocked(ctx, dst, offset, &opts) + fd.mu.Unlock() + return n, err +} + +// Read implements FileDescriptionImpl.Read. +func (fd *DynamicBytesFileDescriptionImpl) Read(ctx context.Context, dst usermem.IOSequence, opts ReadOptions) (int64, error) { + fd.mu.Lock() + n, err := fd.preadLocked(ctx, dst, fd.off, &opts) + fd.off += n + fd.mu.Unlock() + return n, err +} + +// Seek implements FileDescriptionImpl.Seek. +func (fd *DynamicBytesFileDescriptionImpl) Seek(ctx context.Context, offset int64, whence int32) (int64, error) { + fd.mu.Lock() + defer fd.mu.Unlock() + switch whence { + case linux.SEEK_SET: + // Use offset as given. + case linux.SEEK_CUR: + offset += fd.off + default: + // fs/seq_file:seq_lseek() rejects SEEK_END etc. + return 0, syserror.EINVAL + } + if offset < 0 { + return 0, syserror.EINVAL + } + if offset != fd.lastRead { + // Regenerate the file's contents immediately. Compare + // fs/seq_file.c:seq_lseek() => traverse(). + fd.buf.Reset() + if err := fd.data.Generate(ctx, &fd.buf); err != nil { + fd.buf.Reset() + fd.off = 0 + fd.lastRead = 0 + return 0, err + } + fd.lastRead = offset + } + fd.off = offset + return offset, nil +} diff --git a/pkg/sentry/vfs/file_description_impl_util_test.go b/pkg/sentry/vfs/file_description_impl_util_test.go new file mode 100644 index 000000000..511b829fc --- /dev/null +++ b/pkg/sentry/vfs/file_description_impl_util_test.go @@ -0,0 +1,141 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +package vfs + +import ( + "bytes" + "fmt" + "io" + "sync/atomic" + "testing" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/sentry/context/contexttest" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/syserror" +) + +// fileDescription is the common fd struct which a filesystem implementation +// embeds in all of its file description implementations as required. +type fileDescription struct { + vfsfd FileDescription + FileDescriptionDefaultImpl +} + +// genCountFD is a read-only FileDescriptionImpl representing a regular file +// that contains the number of times its DynamicBytesSource.Generate() +// implementation has been called. +type genCountFD struct { + fileDescription + DynamicBytesFileDescriptionImpl + + count uint64 // accessed using atomic memory ops +} + +func newGenCountFD(mnt *Mount, vfsd *Dentry) *FileDescription { + var fd genCountFD + fd.vfsfd.Init(&fd, mnt, vfsd) + fd.DynamicBytesFileDescriptionImpl.SetDataSource(&fd) + return &fd.vfsfd +} + +// Release implements FileDescriptionImpl.Release. +func (fd *genCountFD) Release() { +} + +// StatusFlags implements FileDescriptionImpl.StatusFlags. +func (fd *genCountFD) StatusFlags(ctx context.Context) (uint32, error) { + return 0, nil +} + +// SetStatusFlags implements FileDescriptionImpl.SetStatusFlags. +func (fd *genCountFD) SetStatusFlags(ctx context.Context, flags uint32) error { + return syserror.EPERM +} + +// Stat implements FileDescriptionImpl.Stat. +func (fd *genCountFD) Stat(ctx context.Context, opts StatOptions) (linux.Statx, error) { + // Note that Statx.Mask == 0 in the return value. + return linux.Statx{}, nil +} + +// SetStat implements FileDescriptionImpl.SetStat. +func (fd *genCountFD) SetStat(ctx context.Context, opts SetStatOptions) error { + return syserror.EPERM +} + +// Generate implements DynamicBytesSource.Generate. +func (fd *genCountFD) Generate(ctx context.Context, buf *bytes.Buffer) error { + fmt.Fprintf(buf, "%d", atomic.AddUint64(&fd.count, 1)) + return nil +} + +func TestGenCountFD(t *testing.T) { + ctx := contexttest.Context(t) + creds := auth.CredentialsFromContext(ctx) + + vfsObj := New() // vfs.New() + vfsObj.MustRegisterFilesystemType("testfs", FDTestFilesystemType{}) + mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "testfs", &NewFilesystemOptions{}) + if err != nil { + t.Fatalf("failed to create testfs root mount: %v", err) + } + vd := mntns.Root() + defer vd.DecRef() + + fd := newGenCountFD(vd.Mount(), vd.Dentry()) + defer fd.DecRef() + + // The first read causes Generate to be called to fill the FD's buffer. + buf := make([]byte, 2) + ioseq := usermem.BytesIOSequence(buf) + n, err := fd.Impl().Read(ctx, ioseq, ReadOptions{}) + if n != 1 || (err != nil && err != io.EOF) { + t.Fatalf("first Read: got (%d, %v), wanted (1, nil or EOF)", n, err) + } + if want := byte('1'); buf[0] != want { + t.Errorf("first Read: got byte %c, wanted %c", buf[0], want) + } + + // A second read without seeking is still at EOF. + n, err = fd.Impl().Read(ctx, ioseq, ReadOptions{}) + if n != 0 || err != io.EOF { + t.Fatalf("second Read: got (%d, %v), wanted (0, EOF)", n, err) + } + + // Seeking to the beginning of the file causes it to be regenerated. 
+ n, err = fd.Impl().Seek(ctx, 0, linux.SEEK_SET) + if n != 0 || err != nil { + t.Fatalf("Seek: got (%d, %v), wanted (0, nil)", n, err) + } + n, err = fd.Impl().Read(ctx, ioseq, ReadOptions{}) + if n != 1 || (err != nil && err != io.EOF) { + t.Fatalf("Read after Seek: got (%d, %v), wanted (1, nil or EOF)", n, err) + } + if want := byte('2'); buf[0] != want { + t.Errorf("Read after Seek: got byte %c, wanted %c", buf[0], want) + } + + // PRead at the beginning of the file also causes it to be regenerated. + n, err = fd.Impl().PRead(ctx, ioseq, 0, ReadOptions{}) + if n != 1 || (err != nil && err != io.EOF) { + t.Fatalf("PRead: got (%d, %v), wanted (1, nil or EOF)", n, err) + } + if want := byte('3'); buf[0] != want { + t.Errorf("PRead: got byte %c, wanted %c", buf[0], want) + } +} diff --git a/pkg/sentry/vfs/testutil.go b/pkg/sentry/vfs/testutil.go new file mode 100644 index 000000000..70b192ece --- /dev/null +++ b/pkg/sentry/vfs/testutil.go @@ -0,0 +1,139 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package vfs + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/syserror" +) + +// FDTestFilesystemType is a test-only FilesystemType that produces Filesystems +// for which all FilesystemImpl methods taking a path return EPERM. It is used +// to produce Mounts and Dentries for testing of FileDescriptionImpls that do +// not depend on their originating Filesystem. +type FDTestFilesystemType struct{} + +// FDTestFilesystem is a test-only FilesystemImpl produced by +// FDTestFilesystemType. +type FDTestFilesystem struct { + vfsfs Filesystem +} + +// NewFilesystem implements FilesystemType.NewFilesystem. +func (fstype FDTestFilesystemType) NewFilesystem(ctx context.Context, creds *auth.Credentials, source string, opts NewFilesystemOptions) (*Filesystem, *Dentry, error) { + var fs FDTestFilesystem + fs.vfsfs.Init(&fs) + return &fs.vfsfs, fs.NewDentry(), nil +} + +// Release implements FilesystemImpl.Release. +func (fs *FDTestFilesystem) Release() { +} + +// Sync implements FilesystemImpl.Sync. +func (fs *FDTestFilesystem) Sync(ctx context.Context) error { + return nil +} + +// GetDentryAt implements FilesystemImpl.GetDentryAt. +func (fs *FDTestFilesystem) GetDentryAt(ctx context.Context, rp *ResolvingPath, opts GetDentryOptions) (*Dentry, error) { + return nil, syserror.EPERM +} + +// LinkAt implements FilesystemImpl.LinkAt. +func (fs *FDTestFilesystem) LinkAt(ctx context.Context, rp *ResolvingPath, vd VirtualDentry) error { + return syserror.EPERM +} + +// MkdirAt implements FilesystemImpl.MkdirAt. +func (fs *FDTestFilesystem) MkdirAt(ctx context.Context, rp *ResolvingPath, opts MkdirOptions) error { + return syserror.EPERM +} + +// MknodAt implements FilesystemImpl.MknodAt. 
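These stubs exist so that FileDescriptionImpl tests can obtain a real Mount and Dentry without a working filesystem; every path-taking operation simply refuses with EPERM. Typical use from a test, mirroring TestGenCountFD above (all names are from this patch; ctx, creds, and t are assumed to be in scope):

// Obtain a Mount/Dentry pair backed by the EPERM-stub filesystem.
vfsObj := New()
vfsObj.MustRegisterFilesystemType("testfs", FDTestFilesystemType{})
mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "testfs", &NewFilesystemOptions{})
if err != nil {
	t.Fatalf("failed to create testfs root mount: %v", err)
}
vd := mntns.Root()
defer vd.DecRef()

MknodAt and the remaining path-operation stubs continue below.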
+func (fs *FDTestFilesystem) MknodAt(ctx context.Context, rp *ResolvingPath, opts MknodOptions) error { + return syserror.EPERM +} + +// OpenAt implements FilesystemImpl.OpenAt. +func (fs *FDTestFilesystem) OpenAt(ctx context.Context, rp *ResolvingPath, opts OpenOptions) (*FileDescription, error) { + return nil, syserror.EPERM +} + +// ReadlinkAt implements FilesystemImpl.ReadlinkAt. +func (fs *FDTestFilesystem) ReadlinkAt(ctx context.Context, rp *ResolvingPath) (string, error) { + return "", syserror.EPERM +} + +// RenameAt implements FilesystemImpl.RenameAt. +func (fs *FDTestFilesystem) RenameAt(ctx context.Context, rp *ResolvingPath, vd VirtualDentry, opts RenameOptions) error { + return syserror.EPERM +} + +// RmdirAt implements FilesystemImpl.RmdirAt. +func (fs *FDTestFilesystem) RmdirAt(ctx context.Context, rp *ResolvingPath) error { + return syserror.EPERM +} + +// SetStatAt implements FilesystemImpl.SetStatAt. +func (fs *FDTestFilesystem) SetStatAt(ctx context.Context, rp *ResolvingPath, opts SetStatOptions) error { + return syserror.EPERM +} + +// StatAt implements FilesystemImpl.StatAt. +func (fs *FDTestFilesystem) StatAt(ctx context.Context, rp *ResolvingPath, opts StatOptions) (linux.Statx, error) { + return linux.Statx{}, syserror.EPERM +} + +// StatFSAt implements FilesystemImpl.StatFSAt. +func (fs *FDTestFilesystem) StatFSAt(ctx context.Context, rp *ResolvingPath) (linux.Statfs, error) { + return linux.Statfs{}, syserror.EPERM +} + +// SymlinkAt implements FilesystemImpl.SymlinkAt. +func (fs *FDTestFilesystem) SymlinkAt(ctx context.Context, rp *ResolvingPath, target string) error { + return syserror.EPERM +} + +// UnlinkAt implements FilesystemImpl.UnlinkAt. +func (fs *FDTestFilesystem) UnlinkAt(ctx context.Context, rp *ResolvingPath) error { + return syserror.EPERM +} + +type fdTestDentry struct { + vfsd Dentry +} + +// NewDentry returns a new Dentry. +func (fs *FDTestFilesystem) NewDentry() *Dentry { + var d fdTestDentry + d.vfsd.Init(&d) + return &d.vfsd +} + +// IncRef implements DentryImpl.IncRef. +func (d *fdTestDentry) IncRef(vfsfs *Filesystem) { +} + +// TryIncRef implements DentryImpl.TryIncRef. +func (d *fdTestDentry) TryIncRef(vfsfs *Filesystem) bool { + return true +} + +// DecRef implements DentryImpl.DecRef. +func (d *fdTestDentry) DecRef(vfsfs *Filesystem) { +} diff --git a/pkg/tcpip/BUILD b/pkg/tcpip/BUILD index 047f8329a..df37c7d5a 100644 --- a/pkg/tcpip/BUILD +++ b/pkg/tcpip/BUILD @@ -12,6 +12,7 @@ go_library( visibility = ["//visibility:public"], deps = [ "//pkg/tcpip/buffer", + "//pkg/tcpip/iptables", "//pkg/waiter", ], ) diff --git a/pkg/tcpip/adapters/gonet/BUILD b/pkg/tcpip/adapters/gonet/BUILD index c40924852..0d2637ee4 100644 --- a/pkg/tcpip/adapters/gonet/BUILD +++ b/pkg/tcpip/adapters/gonet/BUILD @@ -24,6 +24,7 @@ go_test( embed = [":gonet"], deps = [ "//pkg/tcpip", + "//pkg/tcpip/header", "//pkg/tcpip/link/loopback", "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/ipv6", diff --git a/pkg/tcpip/adapters/gonet/gonet.go b/pkg/tcpip/adapters/gonet/gonet.go index 308f620e5..cd6ce930a 100644 --- a/pkg/tcpip/adapters/gonet/gonet.go +++ b/pkg/tcpip/adapters/gonet/gonet.go @@ -404,7 +404,7 @@ func (c *Conn) Write(b []byte) (int, error) { } } - var n uintptr + var n int64 var resCh <-chan struct{} n, resCh, err = c.ep.Write(tcpip.SlicePayload(v), tcpip.WriteOptions{}) nbytes += int(n) @@ -556,32 +556,50 @@ type PacketConn struct { wq *waiter.Queue } -// NewPacketConn creates a new PacketConn. 
-func NewPacketConn(s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*PacketConn, error) { - // Create UDP endpoint and bind it. +// DialUDP creates a new PacketConn. +// +// If laddr is nil, a local address is automatically chosen. +// +// If raddr is nil, the PacketConn is left unconnected. +func DialUDP(s *stack.Stack, laddr, raddr *tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*PacketConn, error) { var wq waiter.Queue ep, err := s.NewEndpoint(udp.ProtocolNumber, network, &wq) if err != nil { return nil, errors.New(err.String()) } - if err := ep.Bind(addr); err != nil { - ep.Close() - return nil, &net.OpError{ - Op: "bind", - Net: "udp", - Addr: fullToUDPAddr(addr), - Err: errors.New(err.String()), + if laddr != nil { + if err := ep.Bind(*laddr); err != nil { + ep.Close() + return nil, &net.OpError{ + Op: "bind", + Net: "udp", + Addr: fullToUDPAddr(*laddr), + Err: errors.New(err.String()), + } } } - c := &PacketConn{ + c := PacketConn{ stack: s, ep: ep, wq: &wq, } c.deadlineTimer.init() - return c, nil + + if raddr != nil { + if err := c.ep.Connect(*raddr); err != nil { + c.ep.Close() + return nil, &net.OpError{ + Op: "connect", + Net: "udp", + Addr: fullToUDPAddr(*raddr), + Err: errors.New(err.String()), + } + } + } + + return &c, nil } func (c *PacketConn) newOpError(op string, err error) *net.OpError { diff --git a/pkg/tcpip/adapters/gonet/gonet_test.go b/pkg/tcpip/adapters/gonet/gonet_test.go index 39efe44c7..672f026b2 100644 --- a/pkg/tcpip/adapters/gonet/gonet_test.go +++ b/pkg/tcpip/adapters/gonet/gonet_test.go @@ -26,6 +26,7 @@ import ( "golang.org/x/net/nettest" "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/loopback" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" @@ -69,17 +70,13 @@ func newLoopbackStack() (*stack.Stack, *tcpip.Error) { s.SetRouteTable([]tcpip.Route{ // IPv4 { - Destination: tcpip.Address(strings.Repeat("\x00", 4)), - Mask: tcpip.AddressMask(strings.Repeat("\x00", 4)), - Gateway: "", + Destination: header.IPv4EmptySubnet, NIC: NICID, }, // IPv6 { - Destination: tcpip.Address(strings.Repeat("\x00", 16)), - Mask: tcpip.AddressMask(strings.Repeat("\x00", 16)), - Gateway: "", + Destination: header.IPv6EmptySubnet, NIC: NICID, }, }) @@ -371,9 +368,9 @@ func TestUDPForwarder(t *testing.T) { }) s.SetTransportProtocolHandler(udp.ProtocolNumber, fwd.HandlePacket) - c2, err := NewPacketConn(s, addr2, ipv4.ProtocolNumber) + c2, err := DialUDP(s, &addr2, nil, ipv4.ProtocolNumber) if err != nil { - t.Fatal("NewPacketConn(port 5):", err) + t.Fatal("DialUDP(bind port 5):", err) } sent := "abc123" @@ -452,13 +449,13 @@ func TestPacketConnTransfer(t *testing.T) { addr2 := tcpip.FullAddress{NICID, ip2, 11311} s.AddAddress(NICID, ipv4.ProtocolNumber, ip2) - c1, err := NewPacketConn(s, addr1, ipv4.ProtocolNumber) + c1, err := DialUDP(s, &addr1, nil, ipv4.ProtocolNumber) if err != nil { - t.Fatal("NewPacketConn(port 4):", err) + t.Fatal("DialUDP(bind port 4):", err) } - c2, err := NewPacketConn(s, addr2, ipv4.ProtocolNumber) + c2, err := DialUDP(s, &addr2, nil, ipv4.ProtocolNumber) if err != nil { - t.Fatal("NewPacketConn(port 5):", err) + t.Fatal("DialUDP(bind port 5):", err) } c1.SetDeadline(time.Now().Add(time.Second)) @@ -491,6 +488,50 @@ func TestPacketConnTransfer(t *testing.T) { } } +func TestConnectedPacketConnTransfer(t *testing.T) { + s, e := newLoopbackStack() + if e != nil { + t.Fatalf("newLoopbackStack() = %v", e) + } + + ip := 
tcpip.Address(net.IPv4(169, 254, 10, 1).To4()) + addr := tcpip.FullAddress{NICID, ip, 11211} + s.AddAddress(NICID, ipv4.ProtocolNumber, ip) + + c1, err := DialUDP(s, &addr, nil, ipv4.ProtocolNumber) + if err != nil { + t.Fatal("DialUDP(bind port 4):", err) + } + c2, err := DialUDP(s, nil, &addr, ipv4.ProtocolNumber) + if err != nil { + t.Fatal("DialUDP(bind port 5):", err) + } + + c1.SetDeadline(time.Now().Add(time.Second)) + c2.SetDeadline(time.Now().Add(time.Second)) + + sent := "abc123" + if n, err := c2.Write([]byte(sent)); err != nil || n != len(sent) { + t.Errorf("got c2.Write(%q) = %d, %v, want = %d, %v", sent, n, err, len(sent), nil) + } + recv := make([]byte, len(sent)) + n, err := c1.Read(recv) + if err != nil || n != len(recv) { + t.Errorf("got c1.Read() = %d, %v, want = %d, %v", n, err, len(recv), nil) + } + + if recv := string(recv); recv != sent { + t.Errorf("got recv = %q, want = %q", recv, sent) + } + + if err := c1.Close(); err != nil { + t.Error("c1.Close():", err) + } + if err := c2.Close(); err != nil { + t.Error("c2.Close():", err) + } +} + func makePipe() (c1, c2 net.Conn, stop func(), err error) { s, e := newLoopbackStack() if e != nil { diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go index 94a3af289..17fc9c68e 100644 --- a/pkg/tcpip/header/ipv4.go +++ b/pkg/tcpip/header/ipv4.go @@ -111,6 +111,15 @@ const ( IPv4FlagDontFragment ) +// IPv4EmptySubnet is the empty IPv4 subnet. +var IPv4EmptySubnet = func() tcpip.Subnet { + subnet, err := tcpip.NewSubnet(IPv4Any, tcpip.AddressMask(IPv4Any)) + if err != nil { + panic(err) + } + return subnet +}() + // IPVersion returns the version of IP used in the given packet. It returns -1 // if the packet is not large enough to contain the version field. func IPVersion(b []byte) int { diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go index 95fe8bfc3..bc4e56535 100644 --- a/pkg/tcpip/header/ipv6.go +++ b/pkg/tcpip/header/ipv6.go @@ -27,7 +27,7 @@ const ( nextHdr = 6 hopLimit = 7 v6SrcAddr = 8 - v6DstAddr = 24 + v6DstAddr = v6SrcAddr + IPv6AddressSize ) // IPv6Fields contains the fields of an IPv6 packet. It is used to describe the @@ -82,6 +82,15 @@ const ( IPv6Any tcpip.Address = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" ) +// IPv6EmptySubnet is the empty IPv6 subnet. +var IPv6EmptySubnet = func() tcpip.Subnet { + subnet, err := tcpip.NewSubnet(IPv6Any, tcpip.AddressMask(IPv6Any)) + if err != nil { + panic(err) + } + return subnet +}() + // PayloadLength returns the value of the "payload length" field of the ipv6 // header. func (b IPv6) PayloadLength() uint16 { @@ -110,13 +119,13 @@ func (b IPv6) Payload() []byte { // SourceAddress returns the "source address" field of the ipv6 header. func (b IPv6) SourceAddress() tcpip.Address { - return tcpip.Address(b[v6SrcAddr : v6SrcAddr+IPv6AddressSize]) + return tcpip.Address(b[v6SrcAddr:][:IPv6AddressSize]) } // DestinationAddress returns the "destination address" field of the ipv6 // header. func (b IPv6) DestinationAddress() tcpip.Address { - return tcpip.Address(b[v6DstAddr : v6DstAddr+IPv6AddressSize]) + return tcpip.Address(b[v6DstAddr:][:IPv6AddressSize]) } // Checksum implements Network.Checksum. Given that IPv6 doesn't have a @@ -144,13 +153,13 @@ func (b IPv6) SetPayloadLength(payloadLength uint16) { // SetSourceAddress sets the "source address" field of the ipv6 header. 
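The accessors in this file move to the b[off:][:n] reslicing form, which takes exactly n bytes starting at off and states the field length once instead of repeating the offset arithmetic; v6DstAddr is likewise now derived from v6SrcAddr plus the address size, and SetSourceAddress below is updated the same way. A quick standalone illustration:

// Slicing an IPv6 fixed header (40 bytes) with the b[off:][:n] idiom;
// the zero-filled addresses are placeholders for illustration.
buf := make([]byte, 40)
src := buf[8:][:16]  // source address: offset 8, length 16
dst := buf[24:][:16] // destination address follows immediately
copy(src, make([]byte, 16))
copy(dst, make([]byte, 16))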
func (b IPv6) SetSourceAddress(addr tcpip.Address) { - copy(b[v6SrcAddr:v6SrcAddr+IPv6AddressSize], addr) + copy(b[v6SrcAddr:][:IPv6AddressSize], addr) } // SetDestinationAddress sets the "destination address" field of the ipv6 // header. func (b IPv6) SetDestinationAddress(addr tcpip.Address) { - copy(b[v6DstAddr:v6DstAddr+IPv6AddressSize], addr) + copy(b[v6DstAddr:][:IPv6AddressSize], addr) } // SetNextHeader sets the value of the "next header" field of the ipv6 header. @@ -169,8 +178,8 @@ func (b IPv6) Encode(i *IPv6Fields) { b.SetPayloadLength(i.PayloadLength) b[nextHdr] = i.NextHeader b[hopLimit] = i.HopLimit - copy(b[v6SrcAddr:v6SrcAddr+IPv6AddressSize], i.SrcAddr) - copy(b[v6DstAddr:v6DstAddr+IPv6AddressSize], i.DstAddr) + b.SetSourceAddress(i.SrcAddr) + b.SetDestinationAddress(i.DstAddr) } // IsValid performs basic validation on the packet. diff --git a/pkg/tcpip/iptables/BUILD b/pkg/tcpip/iptables/BUILD index fc9abbb55..3fc14bacd 100644 --- a/pkg/tcpip/iptables/BUILD +++ b/pkg/tcpip/iptables/BUILD @@ -11,8 +11,5 @@ go_library( ], importpath = "gvisor.dev/gvisor/pkg/tcpip/iptables", visibility = ["//visibility:public"], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/buffer", - ], + deps = ["//pkg/tcpip/buffer"], ) diff --git a/pkg/tcpip/iptables/iptables.go b/pkg/tcpip/iptables/iptables.go index f1e1d1fad..68c68d4aa 100644 --- a/pkg/tcpip/iptables/iptables.go +++ b/pkg/tcpip/iptables/iptables.go @@ -32,8 +32,8 @@ const ( // DefaultTables returns a default set of tables. Each chain is set to accept // all packets. -func DefaultTables() *IPTables { - return &IPTables{ +func DefaultTables() IPTables { + return IPTables{ Tables: map[string]Table{ tablenameNat: Table{ BuiltinChains: map[Hook]Chain{ diff --git a/pkg/tcpip/iptables/types.go b/pkg/tcpip/iptables/types.go index 600bd9a10..42a79ef9f 100644 --- a/pkg/tcpip/iptables/types.go +++ b/pkg/tcpip/iptables/types.go @@ -15,7 +15,6 @@ package iptables import ( - "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" ) @@ -128,15 +127,29 @@ type Table struct { // UserChains, and its purpose is to make looking up tables by name // fast. Chains map[string]*Chain + + // Metadata holds information about the Table that is useful to users + // of IPTables, but not to the netstack IPTables code itself. + metadata interface{} } // ValidHooks returns a bitmap of the builtin hooks for the given table. -func (table *Table) ValidHooks() (uint32, *tcpip.Error) { +func (table *Table) ValidHooks() uint32 { hooks := uint32(0) for hook, _ := range table.BuiltinChains { hooks |= 1 << hook } - return hooks, nil + return hooks +} + +// Metadata returns the metadata object stored in table. +func (table *Table) Metadata() interface{} { + return table.metadata +} + +// SetMetadata sets the metadata object stored in table. +func (table *Table) SetMetadata(metadata interface{}) { + table.metadata = metadata } // A Chain defines a list of rules for packet processing. 
When a packet diff --git a/pkg/tcpip/link/fdbased/BUILD b/pkg/tcpip/link/fdbased/BUILD index d786d8fdf..74fbbb896 100644 --- a/pkg/tcpip/link/fdbased/BUILD +++ b/pkg/tcpip/link/fdbased/BUILD @@ -8,8 +8,8 @@ go_library( "endpoint.go", "endpoint_unsafe.go", "mmap.go", - "mmap_amd64.go", - "mmap_amd64_unsafe.go", + "mmap_stub.go", + "mmap_unsafe.go", "packet_dispatchers.go", ], importpath = "gvisor.dev/gvisor/pkg/tcpip/link/fdbased", diff --git a/pkg/tcpip/link/fdbased/mmap.go b/pkg/tcpip/link/fdbased/mmap.go index fe19c2bc2..8bfeb97e4 100644 --- a/pkg/tcpip/link/fdbased/mmap.go +++ b/pkg/tcpip/link/fdbased/mmap.go @@ -12,14 +12,183 @@ // See the License for the specific language governing permissions and // limitations under the License. -// +build !linux !amd64 +// +build linux,amd64 linux,arm64 package fdbased -import "gvisor.dev/gvisor/pkg/tcpip" +import ( + "encoding/binary" + "syscall" -// Stubbed out version for non-linux/non-amd64 platforms. + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/rawfile" +) -func newPacketMMapDispatcher(fd int, e *endpoint) (linkDispatcher, *tcpip.Error) { - return nil, nil +const ( + tPacketAlignment = uintptr(16) + tpStatusKernel = 0 + tpStatusUser = 1 + tpStatusCopy = 2 + tpStatusLosing = 4 +) + +// We overallocate the frame size to accommodate space for the +// TPacketHdr+RawSockAddrLinkLayer+MAC header and any padding. +// +// Memory allocated for the ring buffer: tpBlockSize * tpBlockNR = 2 MiB +// +// NOTE: +// Frames need to be aligned at 16-byte boundaries. +// BlockSize needs to be page aligned. +// +// For details see PACKET_MMAP setting constraints in +// https://www.kernel.org/doc/Documentation/networking/packet_mmap.txt +const ( + tpFrameSize = 65536 + 128 + tpBlockSize = tpFrameSize * 32 + tpBlockNR = 1 + tpFrameNR = (tpBlockSize * tpBlockNR) / tpFrameSize +) + +// tPacketAlign aligns the pointer v at a tPacketAlignment boundary. Direct +// translation of the TPACKET_ALIGN macro in <linux/if_packet.h>. +func tPacketAlign(v uintptr) uintptr { + return (v + tPacketAlignment - 1) & uintptr(^(tPacketAlignment - 1)) +} + +// tPacketReq is the tpacket_req structure as described in +// https://www.kernel.org/doc/Documentation/networking/packet_mmap.txt +type tPacketReq struct { + tpBlockSize uint32 + tpBlockNR uint32 + tpFrameSize uint32 + tpFrameNR uint32 +} + +// tPacketHdr is the tpacket_hdr structure as described in <linux/if_packet.h>. +type tPacketHdr []byte + +const ( + tpStatusOffset = 0 + tpLenOffset = 8 + tpSnapLenOffset = 12 + tpMacOffset = 16 + tpNetOffset = 18 + tpSecOffset = 20 + tpUSecOffset = 24 +) + +func (t tPacketHdr) tpLen() uint32 { + return binary.LittleEndian.Uint32(t[tpLenOffset:]) +} + +func (t tPacketHdr) tpSnapLen() uint32 { + return binary.LittleEndian.Uint32(t[tpSnapLenOffset:]) +} + +func (t tPacketHdr) tpMac() uint16 { + return binary.LittleEndian.Uint16(t[tpMacOffset:]) +} + +func (t tPacketHdr) tpNet() uint16 { + return binary.LittleEndian.Uint16(t[tpNetOffset:]) +} + +func (t tPacketHdr) tpSec() uint32 { + return binary.LittleEndian.Uint32(t[tpSecOffset:]) +} + +func (t tPacketHdr) tpUSec() uint32 { + return binary.LittleEndian.Uint32(t[tpUSecOffset:]) +} + +func (t tPacketHdr) Payload() []byte { + return t[uint32(t.tpMac()) : uint32(t.tpMac())+t.tpSnapLen()] +} + +// packetMMapDispatcher uses a PACKET_RX_RING to read/dispatch inbound packets. +// See: mmap_unsafe.go for implementation details.
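Before the dispatcher type, a sketch of the frame-ownership handshake these constants encode: the kernel owns each ring frame until it sets TP_STATUS_USER in the frame's status word; userspace copies the packet out and returns the frame by storing TP_STATUS_KERNEL. Condensed from readMMappedPacket below, with fd and hdr assumed in scope (tpStatus and setTPStatus are defined in the renamed mmap_unsafe.go, since they require atomic access to the shared ring):

// Wait for the kernel to publish the frame, copy the packet out of the
// shared ring, then hand the frame back. Condensed sketch, not the full
// loop: truncated-frame (tpStatusCopy) handling is omitted.
for hdr.tpStatus()&tpStatusUser == 0 {
	event := rawfile.PollEvent{FD: int32(fd), Events: unix.POLLIN}
	if _, errno := rawfile.BlockingPoll(&event, 1, nil); errno != 0 && errno != syscall.EINTR {
		return nil, rawfile.TranslateErrno(errno)
	}
}
pkt := append([]byte(nil), hdr.Payload()...) // copy out of the shared ring
hdr.setTPStatus(tpStatusKernel)              // return ownership to the kernel

The dispatcher type follows.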
+type packetMMapDispatcher struct { + // fd is the file descriptor used to send and receive packets. + fd int + + // e is the endpoint this dispatcher is attached to. + e *endpoint + + // ringBuffer is only used when PacketMMap dispatcher is used and points + // to the start of the mmapped PACKET_RX_RING buffer. + ringBuffer []byte + + // ringOffset is the current offset into the ring buffer where the next + // inbound packet will be placed by the kernel. + ringOffset int +} + +func (d *packetMMapDispatcher) readMMappedPacket() ([]byte, *tcpip.Error) { + hdr := tPacketHdr(d.ringBuffer[d.ringOffset*tpFrameSize:]) + for hdr.tpStatus()&tpStatusUser == 0 { + event := rawfile.PollEvent{ + FD: int32(d.fd), + Events: unix.POLLIN | unix.POLLERR, + } + if _, errno := rawfile.BlockingPoll(&event, 1, nil); errno != 0 { + if errno == syscall.EINTR { + continue + } + return nil, rawfile.TranslateErrno(errno) + } + if hdr.tpStatus()&tpStatusCopy != 0 { + // This frame is truncated so skip it after flipping the + // buffer to the kernel. + hdr.setTPStatus(tpStatusKernel) + d.ringOffset = (d.ringOffset + 1) % tpFrameNR + hdr = (tPacketHdr)(d.ringBuffer[d.ringOffset*tpFrameSize:]) + continue + } + } + + // Copy out the packet from the mmapped frame to a locally owned buffer. + pkt := make([]byte, hdr.tpSnapLen()) + copy(pkt, hdr.Payload()) + // Release packet to kernel. + hdr.setTPStatus(tpStatusKernel) + d.ringOffset = (d.ringOffset + 1) % tpFrameNR + return pkt, nil +} + +// dispatch reads packets from an mmapped ring buffer and dispatches them to the +// network stack. +func (d *packetMMapDispatcher) dispatch() (bool, *tcpip.Error) { + pkt, err := d.readMMappedPacket() + if err != nil { + return false, err + } + var ( + p tcpip.NetworkProtocolNumber + remote, local tcpip.LinkAddress + ) + if d.e.hdrSize > 0 { + eth := header.Ethernet(pkt) + p = eth.Type() + remote = eth.SourceAddress() + local = eth.DestinationAddress() + } else { + // We don't get any indication of what the packet is, so try to guess + // if it's an IPv4 or IPv6 packet. + switch header.IPVersion(pkt) { + case header.IPv4Version: + p = header.IPv4ProtocolNumber + case header.IPv6Version: + p = header.IPv6ProtocolNumber + default: + return true, nil + } + } + + pkt = pkt[d.e.hdrSize:] + d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, buffer.NewVectorisedView(len(pkt), []buffer.View{buffer.View(pkt)})) + return true, nil } diff --git a/pkg/tcpip/link/fdbased/mmap_amd64.go b/pkg/tcpip/link/fdbased/mmap_amd64.go deleted file mode 100644 index 8bbb4f9ab..000000000 --- a/pkg/tcpip/link/fdbased/mmap_amd64.go +++ /dev/null @@ -1,194 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License.
- -// +build linux,amd64 - -package fdbased - -import ( - "encoding/binary" - "syscall" - - "golang.org/x/sys/unix" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/rawfile" -) - -const ( - tPacketAlignment = uintptr(16) - tpStatusKernel = 0 - tpStatusUser = 1 - tpStatusCopy = 2 - tpStatusLosing = 4 -) - -// We overallocate the frame size to accommodate space for the -// TPacketHdr+RawSockAddrLinkLayer+MAC header and any padding. -// -// Memory allocated for the ring buffer: tpBlockSize * tpBlockNR = 2 MiB -// -// NOTE: -// Frames need to be aligned at 16 byte boundaries. -// BlockSize needs to be page aligned. -// -// For details see PACKET_MMAP setting constraints in -// https://www.kernel.org/doc/Documentation/networking/packet_mmap.txt -const ( - tpFrameSize = 65536 + 128 - tpBlockSize = tpFrameSize * 32 - tpBlockNR = 1 - tpFrameNR = (tpBlockSize * tpBlockNR) / tpFrameSize -) - -// tPacketAlign aligns the pointer v at a tPacketAlignment boundary. Direct -// translation of the TPACKET_ALIGN macro in <linux/if_packet.h>. -func tPacketAlign(v uintptr) uintptr { - return (v + tPacketAlignment - 1) & uintptr(^(tPacketAlignment - 1)) -} - -// tPacketReq is the tpacket_req structure as described in -// https://www.kernel.org/doc/Documentation/networking/packet_mmap.txt -type tPacketReq struct { - tpBlockSize uint32 - tpBlockNR uint32 - tpFrameSize uint32 - tpFrameNR uint32 -} - -// tPacketHdr is tpacket_hdr structure as described in <linux/if_packet.h> -type tPacketHdr []byte - -const ( - tpStatusOffset = 0 - tpLenOffset = 8 - tpSnapLenOffset = 12 - tpMacOffset = 16 - tpNetOffset = 18 - tpSecOffset = 20 - tpUSecOffset = 24 -) - -func (t tPacketHdr) tpLen() uint32 { - return binary.LittleEndian.Uint32(t[tpLenOffset:]) -} - -func (t tPacketHdr) tpSnapLen() uint32 { - return binary.LittleEndian.Uint32(t[tpSnapLenOffset:]) -} - -func (t tPacketHdr) tpMac() uint16 { - return binary.LittleEndian.Uint16(t[tpMacOffset:]) -} - -func (t tPacketHdr) tpNet() uint16 { - return binary.LittleEndian.Uint16(t[tpNetOffset:]) -} - -func (t tPacketHdr) tpSec() uint32 { - return binary.LittleEndian.Uint32(t[tpSecOffset:]) -} - -func (t tPacketHdr) tpUSec() uint32 { - return binary.LittleEndian.Uint32(t[tpUSecOffset:]) -} - -func (t tPacketHdr) Payload() []byte { - return t[uint32(t.tpMac()) : uint32(t.tpMac())+t.tpSnapLen()] -} - -// packetMMapDispatcher uses PACKET_RX_RING's to read/dispatch inbound packets. -// See: mmap_amd64_unsafe.go for implementation details. -type packetMMapDispatcher struct { - // fd is the file descriptor used to send and receive packets. - fd int - - // e is the endpoint this dispatcher is attached to. - e *endpoint - - // ringBuffer is only used when PacketMMap dispatcher is used and points - // to the start of the mmapped PACKET_RX_RING buffer. - ringBuffer []byte - - // ringOffset is the current offset into the ring buffer where the next - // inbound packet will be placed by the kernel. 
- ringOffset int -} - -func (d *packetMMapDispatcher) readMMappedPacket() ([]byte, *tcpip.Error) { - hdr := tPacketHdr(d.ringBuffer[d.ringOffset*tpFrameSize:]) - for hdr.tpStatus()&tpStatusUser == 0 { - event := rawfile.PollEvent{ - FD: int32(d.fd), - Events: unix.POLLIN | unix.POLLERR, - } - if _, errno := rawfile.BlockingPoll(&event, 1, -1); errno != 0 { - if errno == syscall.EINTR { - continue - } - return nil, rawfile.TranslateErrno(errno) - } - if hdr.tpStatus()&tpStatusCopy != 0 { - // This frame is truncated so skip it after flipping the - // buffer to the kernel. - hdr.setTPStatus(tpStatusKernel) - d.ringOffset = (d.ringOffset + 1) % tpFrameNR - hdr = (tPacketHdr)(d.ringBuffer[d.ringOffset*tpFrameSize:]) - continue - } - } - - // Copy out the packet from the mmapped frame to a locally owned buffer. - pkt := make([]byte, hdr.tpSnapLen()) - copy(pkt, hdr.Payload()) - // Release packet to kernel. - hdr.setTPStatus(tpStatusKernel) - d.ringOffset = (d.ringOffset + 1) % tpFrameNR - return pkt, nil -} - -// dispatch reads packets from an mmaped ring buffer and dispatches them to the -// network stack. -func (d *packetMMapDispatcher) dispatch() (bool, *tcpip.Error) { - pkt, err := d.readMMappedPacket() - if err != nil { - return false, err - } - var ( - p tcpip.NetworkProtocolNumber - remote, local tcpip.LinkAddress - ) - if d.e.hdrSize > 0 { - eth := header.Ethernet(pkt) - p = eth.Type() - remote = eth.SourceAddress() - local = eth.DestinationAddress() - } else { - // We don't get any indication of what the packet is, so try to guess - // if it's an IPv4 or IPv6 packet. - switch header.IPVersion(pkt) { - case header.IPv4Version: - p = header.IPv4ProtocolNumber - case header.IPv6Version: - p = header.IPv6ProtocolNumber - default: - return true, nil - } - } - - pkt = pkt[d.e.hdrSize:] - d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, buffer.NewVectorisedView(len(pkt), []buffer.View{buffer.View(pkt)})) - return true, nil -} diff --git a/pkg/tcpip/link/fdbased/mmap_stub.go b/pkg/tcpip/link/fdbased/mmap_stub.go new file mode 100644 index 000000000..67be52d67 --- /dev/null +++ b/pkg/tcpip/link/fdbased/mmap_stub.go @@ -0,0 +1,23 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build !linux !amd64,!arm64 + +package fdbased + +// Stubbed out version for non-linux/non-amd64/non-arm64 platforms. + +func newPacketMMapDispatcher(fd int, e *endpoint) (linkDispatcher, error) { + return nil, nil +} diff --git a/pkg/tcpip/link/fdbased/mmap_amd64_unsafe.go b/pkg/tcpip/link/fdbased/mmap_unsafe.go index 47cb1d1cc..3894185ae 100644 --- a/pkg/tcpip/link/fdbased/mmap_amd64_unsafe.go +++ b/pkg/tcpip/link/fdbased/mmap_unsafe.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-// +build linux,amd64 +// +build linux,amd64 linux,arm64 package fdbased diff --git a/pkg/tcpip/link/rawfile/BUILD b/pkg/tcpip/link/rawfile/BUILD index 6e3a7a9d7..088eb8a21 100644 --- a/pkg/tcpip/link/rawfile/BUILD +++ b/pkg/tcpip/link/rawfile/BUILD @@ -6,8 +6,10 @@ go_library( name = "rawfile", srcs = [ "blockingpoll_amd64.s", - "blockingpoll_amd64_unsafe.go", + "blockingpoll_arm64.s", + "blockingpoll_noyield_unsafe.go", "blockingpoll_unsafe.go", + "blockingpoll_yield_unsafe.go", "errors.go", "rawfile_unsafe.go", ], diff --git a/pkg/tcpip/link/rawfile/blockingpoll_amd64.s b/pkg/tcpip/link/rawfile/blockingpoll_amd64.s index b54131573..298bad55d 100644 --- a/pkg/tcpip/link/rawfile/blockingpoll_amd64.s +++ b/pkg/tcpip/link/rawfile/blockingpoll_amd64.s @@ -14,17 +14,18 @@ #include "textflag.h" -// BlockingPoll makes the poll() syscall while calling the version of +// BlockingPoll makes the ppoll() syscall while calling the version of // entersyscall that relinquishes the P so that other Gs can run. This is meant // to be called in cases when the syscall is expected to block. // -// func BlockingPoll(fds *PollEvent, nfds int, timeout int64) (n int, err syscall.Errno) +// func BlockingPoll(fds *PollEvent, nfds int, timeout *syscall.Timespec) (n int, err syscall.Errno) TEXT ·BlockingPoll(SB),NOSPLIT,$0-40 CALL ·callEntersyscallblock(SB) MOVQ fds+0(FP), DI MOVQ nfds+8(FP), SI MOVQ timeout+16(FP), DX - MOVQ $0x7, AX // SYS_POLL + MOVQ $0x0, R10 // sigmask parameter which isn't used here + MOVQ $0x10f, AX // SYS_PPOLL SYSCALL CMPQ AX, $0xfffffffffffff001 JLS ok diff --git a/pkg/tcpip/link/rawfile/blockingpoll_arm64.s b/pkg/tcpip/link/rawfile/blockingpoll_arm64.s new file mode 100644 index 000000000..b62888b93 --- /dev/null +++ b/pkg/tcpip/link/rawfile/blockingpoll_arm64.s @@ -0,0 +1,42 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "textflag.h" + +// BlockingPoll makes the ppoll() syscall while calling the version of +// entersyscall that relinquishes the P so that other Gs can run. This is meant +// to be called in cases when the syscall is expected to block. +// +// func BlockingPoll(fds *PollEvent, nfds int, timeout *syscall.Timespec) (n int, err syscall.Errno) +TEXT ·BlockingPoll(SB),NOSPLIT,$0-40 + BL ·callEntersyscallblock(SB) + MOVD fds+0(FP), R0 + MOVD nfds+8(FP), R1 + MOVD timeout+16(FP), R2 + MOVD $0x0, R3 // sigmask parameter which isn't used here + MOVD $0x49, R8 // SYS_PPOLL + SVC + CMP $0xfffffffffffff001, R0 + BLS ok + MOVD $-1, R1 + MOVD R1, n+24(FP) + NEG R0, R0 + MOVD R0, err+32(FP) + BL ·callExitsyscall(SB) + RET +ok: + MOVD R0, n+24(FP) + MOVD $0, err+32(FP) + BL ·callExitsyscall(SB) + RET diff --git a/pkg/tcpip/link/rawfile/blockingpoll_noyield_unsafe.go b/pkg/tcpip/link/rawfile/blockingpoll_noyield_unsafe.go new file mode 100644 index 000000000..621ab8d29 --- /dev/null +++ b/pkg/tcpip/link/rawfile/blockingpoll_noyield_unsafe.go @@ -0,0 +1,31 @@ +// Copyright 2018 The gVisor Authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build linux,!amd64,!arm64 + +package rawfile + +import ( + "syscall" + "unsafe" +) + +// BlockingPoll is just a stub function that forwards to the ppoll() system call +// on non-amd64 and non-arm64 platforms. +func BlockingPoll(fds *PollEvent, nfds int, timeout *syscall.Timespec) (int, syscall.Errno) { + n, _, e := syscall.Syscall6(syscall.SYS_PPOLL, uintptr(unsafe.Pointer(fds)), + uintptr(nfds), uintptr(unsafe.Pointer(timeout)), 0, 0, 0) + + return int(n), e +} diff --git a/pkg/tcpip/link/rawfile/blockingpoll_unsafe.go b/pkg/tcpip/link/rawfile/blockingpoll_unsafe.go index 4eab77c74..84dc0e918 100644 --- a/pkg/tcpip/link/rawfile/blockingpoll_unsafe.go +++ b/pkg/tcpip/link/rawfile/blockingpoll_unsafe.go @@ -21,9 +21,11 @@ import ( "unsafe" ) -// BlockingPoll is just a stub function that forwards to the poll() system call +// BlockingPoll is just a stub function that forwards to the ppoll() system call // on non-amd64 platforms. -func BlockingPoll(fds *PollEvent, nfds int, timeout int64) (int, syscall.Errno) { - n, _, e := syscall.Syscall(syscall.SYS_POLL, uintptr(unsafe.Pointer(fds)), uintptr(nfds), uintptr(timeout)) +func BlockingPoll(fds *PollEvent, nfds int, timeout *syscall.Timespec) (int, syscall.Errno) { + n, _, e := syscall.Syscall6(syscall.SYS_PPOLL, uintptr(unsafe.Pointer(fds)), + uintptr(nfds), uintptr(unsafe.Pointer(timeout)), 0, 0, 0) + return int(n), e } diff --git a/pkg/tcpip/link/rawfile/blockingpoll_amd64_unsafe.go b/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go index c87268610..dda3b10a6 100644 --- a/pkg/tcpip/link/rawfile/blockingpoll_amd64_unsafe.go +++ b/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -// +build linux,amd64 +// +build linux,amd64 linux,arm64 // +build go1.12 // +build !go1.14 @@ -25,8 +25,14 @@ import ( _ "unsafe" // for go:linkname ) +// BlockingPoll on amd64/arm64 makes the ppoll() syscall while calling the +// version of entersyscall that relinquishes the P so that other Gs can +// run. This is meant to be called in cases when the syscall is expected to +// block. On non amd64/arm64 platforms it just forwards to the ppoll() system +// call. +// //go:noescape -func BlockingPoll(fds *PollEvent, nfds int, timeout int64) (int, syscall.Errno) +func BlockingPoll(fds *PollEvent, nfds int, timeout *syscall.Timespec) (int, syscall.Errno) // Use go:linkname to call into the runtime. 
As of Go 1.12 this has to // be done from Go code so that we make an ABIInternal call to an diff --git a/pkg/tcpip/link/rawfile/rawfile_unsafe.go b/pkg/tcpip/link/rawfile/rawfile_unsafe.go index e3fbb15c2..7e286a3a6 100644 --- a/pkg/tcpip/link/rawfile/rawfile_unsafe.go +++ b/pkg/tcpip/link/rawfile/rawfile_unsafe.go @@ -123,7 +123,7 @@ func BlockingRead(fd int, b []byte) (int, *tcpip.Error) { Events: 1, // POLLIN } - _, e = BlockingPoll(&event, 1, -1) + _, e = BlockingPoll(&event, 1, nil) if e != 0 && e != syscall.EINTR { return 0, TranslateErrno(e) } @@ -145,7 +145,7 @@ func BlockingReadv(fd int, iovecs []syscall.Iovec) (int, *tcpip.Error) { Events: 1, // POLLIN } - _, e = BlockingPoll(&event, 1, -1) + _, e = BlockingPoll(&event, 1, nil) if e != 0 && e != syscall.EINTR { return 0, TranslateErrno(e) } @@ -175,7 +175,7 @@ func BlockingRecvMMsg(fd int, msgHdrs []MMsgHdr) (int, *tcpip.Error) { Events: 1, // POLLIN } - if _, e := BlockingPoll(&event, 1, -1); e != 0 && e != syscall.EINTR { + if _, e := BlockingPoll(&event, 1, nil); e != 0 && e != syscall.EINTR { return 0, TranslateErrno(e) } } diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go index fc584c6a4..36c8c46fc 100644 --- a/pkg/tcpip/link/sniffer/sniffer.go +++ b/pkg/tcpip/link/sniffer/sniffer.go @@ -360,10 +360,9 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie if fragmentOffset == 0 && len(udp) >= header.UDPMinimumSize { srcPort = udp.SourcePort() dstPort = udp.DestinationPort() + details = fmt.Sprintf("xsum: 0x%x", udp.Checksum()) + size -= header.UDPMinimumSize } - size -= header.UDPMinimumSize - - details = fmt.Sprintf("xsum: 0x%x", udp.Checksum()) case header.TCPProtocolNumber: transName = "tcp" diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index 60070874d..fd6395fc1 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -109,13 +109,10 @@ func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { pkt.SetOp(header.ARPReply) copy(pkt.HardwareAddressSender(), r.LocalLinkAddress[:]) copy(pkt.ProtocolAddressSender(), h.ProtocolAddressTarget()) + copy(pkt.HardwareAddressTarget(), h.HardwareAddressSender()) copy(pkt.ProtocolAddressTarget(), h.ProtocolAddressSender()) e.linkEP.WritePacket(r, nil /* gso */, hdr, buffer.VectorisedView{}, ProtocolNumber) - fallthrough // also fill the cache from requests case header.ARPReply: - addr := tcpip.Address(h.ProtocolAddressSender()) - linkAddr := tcpip.LinkAddress(h.HardwareAddressSender()) - e.linkAddrCache.AddLinkAddress(e.nicid, addr, linkAddr) } } diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go index 66c55821b..4c4b54469 100644 --- a/pkg/tcpip/network/arp/arp_test.go +++ b/pkg/tcpip/network/arp/arp_test.go @@ -15,6 +15,7 @@ package arp_test import ( + "strconv" "testing" "time" @@ -65,9 +66,7 @@ func newTestContext(t *testing.T) *testContext { } s.SetRouteTable([]tcpip.Route{{ - Destination: "\x00\x00\x00\x00", - Mask: "\x00\x00\x00\x00", - Gateway: "", + Destination: header.IPv4EmptySubnet, NIC: 1, }}) @@ -101,40 +100,30 @@ func TestDirectRequest(t *testing.T) { c.linkEP.Inject(arp.ProtocolNumber, v.ToVectorisedView()) } - inject(stackAddr1) - { - pkt := <-c.linkEP.C - if pkt.Proto != arp.ProtocolNumber { - t.Fatalf("stackAddr1: expected ARP response, got network protocol number %v", pkt.Proto) - } - rep := header.ARP(pkt.Header) - if !rep.IsValid() { - t.Fatalf("stackAddr1: invalid ARP response len(pkt.Header)=%d", 
len(pkt.Header)) - } - if tcpip.Address(rep.ProtocolAddressSender()) != stackAddr1 { - t.Errorf("stackAddr1: expected sender to be set") - } - if got := tcpip.LinkAddress(rep.HardwareAddressSender()); got != stackLinkAddr { - t.Errorf("stackAddr1: expected sender to be stackLinkAddr, got %q", got) - } - } - - inject(stackAddr2) - { - pkt := <-c.linkEP.C - if pkt.Proto != arp.ProtocolNumber { - t.Fatalf("stackAddr2: expected ARP response, got network protocol number %v", pkt.Proto) - } - rep := header.ARP(pkt.Header) - if !rep.IsValid() { - t.Fatalf("stackAddr2: invalid ARP response len(pkt.Header)=%d", len(pkt.Header)) - } - if tcpip.Address(rep.ProtocolAddressSender()) != stackAddr2 { - t.Errorf("stackAddr2: expected sender to be set") - } - if got := tcpip.LinkAddress(rep.HardwareAddressSender()); got != stackLinkAddr { - t.Errorf("stackAddr2: expected sender to be stackLinkAddr, got %q", got) - } + for i, address := range []tcpip.Address{stackAddr1, stackAddr2} { + t.Run(strconv.Itoa(i), func(t *testing.T) { + inject(address) + pkt := <-c.linkEP.C + if pkt.Proto != arp.ProtocolNumber { + t.Fatalf("expected ARP response, got network protocol number %d", pkt.Proto) + } + rep := header.ARP(pkt.Header) + if !rep.IsValid() { + t.Fatalf("invalid ARP response len(pkt.Header)=%d", len(pkt.Header)) + } + if got, want := tcpip.LinkAddress(rep.HardwareAddressSender()), stackLinkAddr; got != want { + t.Errorf("got HardwareAddressSender = %s, want = %s", got, want) + } + if got, want := tcpip.Address(rep.ProtocolAddressSender()), tcpip.Address(h.ProtocolAddressTarget()); got != want { + t.Errorf("got ProtocolAddressSender = %s, want = %s", got, want) + } + if got, want := tcpip.LinkAddress(rep.HardwareAddressTarget()), tcpip.LinkAddress(h.HardwareAddressSender()); got != want { + t.Errorf("got HardwareAddressTarget = %s, want = %s", got, want) + } + if got, want := tcpip.Address(rep.ProtocolAddressTarget()), tcpip.Address(h.ProtocolAddressSender()); got != want { + t.Errorf("got ProtocolAddressTarget = %s, want = %s", got, want) + } + }) } inject(stackAddrBad) diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index 55e9eec99..6bbfcd97f 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -173,8 +173,7 @@ func buildIPv4Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) { s.CreateNIC(1, loopback.New()) s.AddAddress(1, ipv4.ProtocolNumber, local) s.SetRouteTable([]tcpip.Route{{ - Destination: ipv4SubnetAddr, - Mask: ipv4SubnetMask, + Destination: header.IPv4EmptySubnet, Gateway: ipv4Gateway, NIC: 1, }}) @@ -187,8 +186,7 @@ func buildIPv6Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) { s.CreateNIC(1, loopback.New()) s.AddAddress(1, ipv6.ProtocolNumber, local) s.SetRouteTable([]tcpip.Route{{ - Destination: ipv6SubnetAddr, - Mask: ipv6SubnetMask, + Destination: header.IPv6EmptySubnet, Gateway: ipv6Gateway, NIC: 1, }}) diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index fbef6947d..497164cbb 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -94,6 +94,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) copy(pkt, h) pkt.SetType(header.ICMPv4EchoReply) + pkt.SetChecksum(0) pkt.SetChecksum(^header.Checksum(pkt, header.ChecksumVV(vv, 0))) sent := stats.ICMP.V4PacketsSent if err := r.WritePacket(nil /* gso */, hdr, vv, header.ICMPv4ProtocolNumber, r.DefaultTTL()); err != nil { 
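The pkt.SetChecksum(0) added above matters because copy(pkt, h) reproduces the echo request verbatim, stale checksum field included: RFC 1071 one's-complement arithmetic assumes the checksum field is zero while the sum is computed, so without the reset the reply's checksum is computed over the request's old checksum and comes out wrong. A minimal standalone sketch of the arithmetic (the helper name is illustrative, not gVisor's header.Checksum):

// internetChecksum computes the RFC 1071 one's-complement checksum over
// data. The checksum field embedded in data must be zeroed by the
// caller first, exactly as the icmp.go fix above does.
func internetChecksum(data []byte) uint16 {
	var sum uint32
	for i := 0; i+1 < len(data); i += 2 {
		sum += uint32(data[i])<<8 | uint32(data[i+1])
	}
	if len(data)%2 != 0 {
		// Odd trailing byte is padded with a zero byte.
		sum += uint32(data[len(data)-1]) << 8
	}
	for sum>>16 != 0 {
		// Fold the carries back into the low 16 bits.
		sum = (sum >> 16) + (sum & 0xffff)
	}
	return ^uint16(sum)
}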
diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index 3207a3d46..1b5a55bea 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -52,9 +52,7 @@ func TestExcludeBroadcast(t *testing.T) { } s.SetRouteTable([]tcpip.Route{{ - Destination: "\x00\x00\x00\x00", - Mask: "\x00\x00\x00\x00", - Gateway: "", + Destination: header.IPv4EmptySubnet, NIC: 1, }}) @@ -247,14 +245,22 @@ func buildContext(t *testing.T, packetCollectorErrors []*tcpip.Error, mtu uint32 _, linkEP := newErrorChannel(100 /* Enough for all tests. */, mtu, "", packetCollectorErrors) linkEPId := stack.RegisterLinkEndpoint(linkEP) s.CreateNIC(1, linkEPId) - s.AddAddress(1, ipv4.ProtocolNumber, "\x10\x00\x00\x01") - s.SetRouteTable([]tcpip.Route{{ - Destination: "\x10\x00\x00\x02", - Mask: "\xff\xff\xff\xff", - Gateway: "", - NIC: 1, - }}) - r, err := s.FindRoute(0, "\x10\x00\x00\x01", "\x10\x00\x00\x02", ipv4.ProtocolNumber, false /* multicastLoop */) + const ( + src = "\x10\x00\x00\x01" + dst = "\x10\x00\x00\x02" + ) + s.AddAddress(1, ipv4.ProtocolNumber, src) + { + subnet, err := tcpip.NewSubnet(dst, tcpip.AddressMask(header.IPv4Broadcast)) + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{{ + Destination: subnet, + NIC: 1, + }}) + } + r, err := s.FindRoute(0, src, dst, ipv4.ProtocolNumber, false /* multicastLoop */) if err != nil { t.Fatalf("s.FindRoute got %v, want %v", err, nil) } diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index 5e6a59e91..1689af16f 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -100,13 +100,11 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V case header.ICMPv6NeighborSolicit: received.NeighborSolicit.Increment() - e.linkAddrCache.AddLinkAddress(e.nicid, r.RemoteAddress, r.RemoteLinkAddress) - if len(v) < header.ICMPv6NeighborSolicitMinimumSize { received.Invalid.Increment() return } - targetAddr := tcpip.Address(v[8:][:16]) + targetAddr := tcpip.Address(v[8:][:header.IPv6AddressSize]) if e.linkAddrCache.CheckLocalAddress(e.nicid, ProtocolNumber, targetAddr) == 0 { // We don't have a useful answer; the best we can do is ignore the request. 
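// The v[8:][:header.IPv6AddressSize] indexing above is fixed by the
// Neighbor Solicitation layout from RFC 4861, section 4.3, which is why
// the literal 16 can safely become header.IPv6AddressSize. A sketch with
// hypothetical names of the offsets being indexed:
//
//	v[0]    type (135)   v[1]    code (0)
//	v[2:4]  checksum     v[4:8]  reserved
//	v[8:24] target address (16 bytes)
func nsTargetAddress(v []byte) tcpip.Address {
	const ndpNSTargetOffset = 8 // ICMPv6 header (4) + reserved (4)
	return tcpip.Address(v[ndpNSTargetOffset:][:header.IPv6AddressSize])
}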
return @@ -146,7 +144,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V received.Invalid.Increment() return } - targetAddr := tcpip.Address(v[8:][:16]) + targetAddr := tcpip.Address(v[8:][:header.IPv6AddressSize]) e.linkAddrCache.AddLinkAddress(e.nicid, targetAddr, r.RemoteLinkAddress) if targetAddr != r.RemoteAddress { e.linkAddrCache.AddLinkAddress(e.nicid, r.RemoteAddress, r.RemoteLinkAddress) diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index 726362c87..d0dc72506 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -91,13 +91,18 @@ func TestICMPCounts(t *testing.T) { t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err) } } - s.SetRouteTable( - []tcpip.Route{{ - Destination: lladdr1, - Mask: tcpip.AddressMask(strings.Repeat("\xff", 16)), - NIC: 1, - }}, - ) + { + subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) + if err != nil { + t.Fatal(err) + } + s.SetRouteTable( + []tcpip.Route{{ + Destination: subnet, + NIC: 1, + }}, + ) + } netProto := s.NetworkProtocolInstance(ProtocolNumber) if netProto == nil { @@ -237,17 +242,23 @@ func newTestContext(t *testing.T) *testContext { t.Fatalf("AddAddress sn lladdr1: %v", err) } + subnet0, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) + if err != nil { + t.Fatal(err) + } c.s0.SetRouteTable( []tcpip.Route{{ - Destination: lladdr1, - Mask: tcpip.AddressMask(strings.Repeat("\xff", 16)), + Destination: subnet0, NIC: 1, }}, ) + subnet1, err := tcpip.NewSubnet(lladdr0, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr0)))) + if err != nil { + t.Fatal(err) + } c.s1.SetRouteTable( []tcpip.Route{{ - Destination: lladdr0, - Mask: tcpip.AddressMask(strings.Repeat("\xff", 16)), + Destination: subnet1, NIC: 1, }}, ) diff --git a/pkg/tcpip/sample/tun_tcp_connect/BUILD b/pkg/tcpip/sample/tun_tcp_connect/BUILD index 996939581..a57752a7c 100644 --- a/pkg/tcpip/sample/tun_tcp_connect/BUILD +++ b/pkg/tcpip/sample/tun_tcp_connect/BUILD @@ -8,6 +8,7 @@ go_binary( deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", + "//pkg/tcpip/header", "//pkg/tcpip/link/fdbased", "//pkg/tcpip/link/rawfile", "//pkg/tcpip/link/sniffer", diff --git a/pkg/tcpip/sample/tun_tcp_connect/main.go b/pkg/tcpip/sample/tun_tcp_connect/main.go index 3ac381631..e2021cd15 100644 --- a/pkg/tcpip/sample/tun_tcp_connect/main.go +++ b/pkg/tcpip/sample/tun_tcp_connect/main.go @@ -52,6 +52,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/fdbased" "gvisor.dev/gvisor/pkg/tcpip/link/rawfile" "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" @@ -152,9 +153,7 @@ func main() { // Add default route. s.SetRouteTable([]tcpip.Route{ { - Destination: "\x00\x00\x00\x00", - Mask: "\x00\x00\x00\x00", - Gateway: "", + Destination: header.IPv4EmptySubnet, NIC: 1, }, }) diff --git a/pkg/tcpip/sample/tun_tcp_echo/main.go b/pkg/tcpip/sample/tun_tcp_echo/main.go index da425394a..1716be285 100644 --- a/pkg/tcpip/sample/tun_tcp_echo/main.go +++ b/pkg/tcpip/sample/tun_tcp_echo/main.go @@ -149,12 +149,15 @@ func main() { log.Fatal(err) } + subnet, err := tcpip.NewSubnet(tcpip.Address(strings.Repeat("\x00", len(addr))), tcpip.AddressMask(strings.Repeat("\x00", len(addr)))) + if err != nil { + log.Fatal(err) + } + // Add default route. 
s.SetRouteTable([]tcpip.Route{ { - Destination: tcpip.Address(strings.Repeat("\x00", len(addr))), - Mask: tcpip.AddressMask(strings.Repeat("\x00", len(addr))), - Gateway: "", + Destination: subnet, NIC: 1, }, }) diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index 28d11c797..b692c60ce 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -1,11 +1,25 @@ package(licenses = ["notice"]) +load("//tools/go_generics:defs.bzl", "go_template_instance") load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +go_template_instance( + name = "linkaddrentry_list", + out = "linkaddrentry_list.go", + package = "stack", + prefix = "linkAddrEntry", + template = "//pkg/ilist:generic_list", + types = { + "Element": "*linkAddrEntry", + "Linker": "*linkAddrEntry", + }, +) + go_library( name = "stack", srcs = [ "linkaddrcache.go", + "linkaddrentry_list.go", "nic.go", "registration.go", "route.go", @@ -24,6 +38,7 @@ go_library( "//pkg/tcpip/buffer", "//pkg/tcpip/hash/jenkins", "//pkg/tcpip/header", + "//pkg/tcpip/iptables", "//pkg/tcpip/ports", "//pkg/tcpip/seqnum", "//pkg/waiter", @@ -42,6 +57,7 @@ go_test( "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", + "//pkg/tcpip/iptables", "//pkg/tcpip/link/channel", "//pkg/tcpip/link/loopback", "//pkg/waiter", @@ -58,3 +74,11 @@ go_test( "//pkg/tcpip", ], ) + +filegroup( + name = "autogen", + srcs = [ + "linkaddrentry_list.go", + ], + visibility = ["//:sandbox"], +) diff --git a/pkg/tcpip/stack/linkaddrcache.go b/pkg/tcpip/stack/linkaddrcache.go index 77bb0ccb9..267df60d1 100644 --- a/pkg/tcpip/stack/linkaddrcache.go +++ b/pkg/tcpip/stack/linkaddrcache.go @@ -42,10 +42,11 @@ type linkAddrCache struct { // resolved before failing. resolutionAttempts int - mu sync.Mutex - cache map[tcpip.FullAddress]*linkAddrEntry - next int // array index of next available entry - entries [linkAddrCacheSize]linkAddrEntry + cache struct { + sync.Mutex + table map[tcpip.FullAddress]*linkAddrEntry + lru linkAddrEntryList + } } // entryState controls the state of a single entry in the cache. @@ -60,9 +61,6 @@ const ( // failed means that address resolution timed out and the address // could not be resolved. failed - // expired means that the cache entry has expired and the address must be - // resolved again. - expired ) // String implements Stringer. @@ -74,8 +72,6 @@ func (s entryState) String() string { return "ready" case failed: return "failed" - case expired: - return "expired" default: return fmt.Sprintf("unknown(%d)", s) } @@ -84,64 +80,46 @@ func (s entryState) String() string { // A linkAddrEntry is an entry in the linkAddrCache. // This struct is thread-compatible. type linkAddrEntry struct { + linkAddrEntryEntry + addr tcpip.FullAddress linkAddr tcpip.LinkAddress expiration time.Time s entryState // wakers is a set of waiters for address resolution result. Anytime - // state transitions out of 'incomplete' these waiters are notified. + // state transitions out of incomplete these waiters are notified. wakers map[*sleep.Waker]struct{} + // done is used to allow callers to wait on address resolution. It is nil iff + // s is incomplete and resolution is not yet in progress. done chan struct{} } -func (e *linkAddrEntry) state() entryState { - if e.s != expired && time.Now().After(e.expiration) { - // Force the transition to ensure waiters are notified. - e.changeState(expired) - } - return e.s -} - -func (e *linkAddrEntry) changeState(ns entryState) { - if e.s == ns { - return - } - - // Validate state transition. 
- switch e.s { - case incomplete: - // All transitions are valid. - case ready, failed: - if ns != expired { - panic(fmt.Sprintf("invalid state transition from %s to %s", e.s, ns)) - } - case expired: - // Terminal state. - panic(fmt.Sprintf("invalid state transition from %s to %s", e.s, ns)) - default: - panic(fmt.Sprintf("invalid state: %s", e.s)) - } - +// changeState sets the entry's state to ns, notifying any waiters. +// +// The entry's expiration is bumped up to the greater of itself and the passed +// expiration; the zero value indicates immediate expiration, and is set +// unconditionally - this is an implementation detail that allows for entries +// to be reused. +func (e *linkAddrEntry) changeState(ns entryState, expiration time.Time) { // Notify whoever is waiting on address resolution when transitioning - // out of 'incomplete'. - if e.s == incomplete { + // out of incomplete. + if e.s == incomplete && ns != incomplete { for w := range e.wakers { w.Assert() } e.wakers = nil - if e.done != nil { - close(e.done) + if ch := e.done; ch != nil { + close(ch) } + e.done = nil } - e.s = ns -} -func (e *linkAddrEntry) maybeAddWaker(w *sleep.Waker) { - if w != nil { - e.wakers[w] = struct{}{} + if expiration.IsZero() || expiration.After(e.expiration) { + e.expiration = expiration } + e.s = ns } func (e *linkAddrEntry) removeWaker(w *sleep.Waker) { @@ -150,53 +128,54 @@ func (e *linkAddrEntry) removeWaker(w *sleep.Waker) { // add adds a k -> v mapping to the cache. func (c *linkAddrCache) add(k tcpip.FullAddress, v tcpip.LinkAddress) { - c.mu.Lock() - defer c.mu.Unlock() - - entry, ok := c.cache[k] - if ok { - s := entry.state() - if s != expired && entry.linkAddr == v { - // Disregard repeated calls. - return - } - // Check if entry is waiting for address resolution. - if s == incomplete { - entry.linkAddr = v - } else { - // Otherwise create a new entry to replace it. - entry = c.makeAndAddEntry(k, v) - } - } else { - entry = c.makeAndAddEntry(k, v) - } + // Calculate expiration time before acquiring the lock, since expiration is + // relative to the time when information was learned, rather than when it + // happened to be inserted into the cache. + expiration := time.Now().Add(c.ageLimit) - entry.changeState(ready) + c.cache.Lock() + entry := c.getOrCreateEntryLocked(k) + entry.linkAddr = v + + entry.changeState(ready, expiration) + c.cache.Unlock() } -// makeAndAddEntry is a helper function to create and add a new -// entry to the cache map and evict older entry as needed. -func (c *linkAddrCache) makeAndAddEntry(k tcpip.FullAddress, v tcpip.LinkAddress) *linkAddrEntry { - // Take over the next entry. - entry := &c.entries[c.next] - if c.cache[entry.addr] == entry { - delete(c.cache, entry.addr) +// getOrCreateEntryLocked retrieves a cache entry associated with k. The +// returned entry is always refreshed in the cache (it is reachable via the +// map, and its place is bumped in LRU). +// +// If a matching entry exists in the cache, it is returned. If no matching +// entry exists and the cache is full, an existing entry is evicted via LRU, +// reset to state incomplete, and returned. If no matching entry exists and the +// cache is not full, a new entry with state incomplete is allocated and +// returned. 
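The comment above fully specifies the eviction policy; the same get-or-create-with-LRU shape can be sketched in isolation with the standard container/list (gVisor instead uses its generated intrusive linkAddrEntryList and keys by tcpip.FullAddress; names here are illustrative):

type entry struct {
	key, value string
}

type lruCache struct {
	capacity int
	table    map[string]*list.Element // key -> element holding an *entry
	lru      *list.List               // front is most recently used
}

// getOrCreate implements the policy described above: a hit refreshes the
// entry's LRU position; a miss on a full cache evicts the back (least
// recently used) entry; a miss otherwise allocates a fresh entry.
func (c *lruCache) getOrCreate(key string) *entry {
	if el, ok := c.table[key]; ok {
		c.lru.MoveToFront(el)
		return el.Value.(*entry)
	}
	if c.lru.Len() == c.capacity {
		victim := c.lru.Back()
		delete(c.table, victim.Value.(*entry).key)
		c.lru.Remove(victim)
	}
	e := &entry{key: key}
	c.table[key] = c.lru.PushFront(e)
	return e
}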
+func (c *linkAddrCache) getOrCreateEntryLocked(k tcpip.FullAddress) *linkAddrEntry { + if entry, ok := c.cache.table[k]; ok { + c.cache.lru.Remove(entry) + c.cache.lru.PushFront(entry) + return entry } + var entry *linkAddrEntry + if len(c.cache.table) == linkAddrCacheSize { + entry = c.cache.lru.Back() - // Mark the soon-to-be-replaced entry as expired, just in case there is - // someone waiting for address resolution on it. - entry.changeState(expired) + delete(c.cache.table, entry.addr) + c.cache.lru.Remove(entry) - *entry = linkAddrEntry{ - addr: k, - linkAddr: v, - expiration: time.Now().Add(c.ageLimit), - wakers: make(map[*sleep.Waker]struct{}), - done: make(chan struct{}), + // Wake waiters and mark the soon-to-be-reused entry as expired. Note + // that the state passed doesn't matter when the zero time is passed. + entry.changeState(failed, time.Time{}) + } else { + entry = new(linkAddrEntry) } - c.cache[k] = entry - c.next = (c.next + 1) % len(c.entries) + *entry = linkAddrEntry{ + addr: k, + s: incomplete, + } + c.cache.table[k] = entry + c.cache.lru.PushFront(entry) return entry } @@ -208,43 +187,55 @@ func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, lo } } - c.mu.Lock() - defer c.mu.Unlock() - if entry, ok := c.cache[k]; ok { - switch s := entry.state(); s { - case expired: - case ready: - return entry.linkAddr, nil, nil - case failed: - return "", nil, tcpip.ErrNoLinkAddress - case incomplete: - // Address resolution is still in progress. - entry.maybeAddWaker(waker) - return "", entry.done, tcpip.ErrWouldBlock - default: - panic(fmt.Sprintf("invalid cache entry state: %s", s)) + c.cache.Lock() + defer c.cache.Unlock() + entry := c.getOrCreateEntryLocked(k) + switch s := entry.s; s { + case ready, failed: + if !time.Now().After(entry.expiration) { + // Not expired. + switch s { + case ready: + return entry.linkAddr, nil, nil + case failed: + return entry.linkAddr, nil, tcpip.ErrNoLinkAddress + default: + panic(fmt.Sprintf("invalid cache entry state: %s", s)) + } } - } - if linkRes == nil { - return "", nil, tcpip.ErrNoLinkAddress - } + entry.changeState(incomplete, time.Time{}) + fallthrough + case incomplete: + if waker != nil { + if entry.wakers == nil { + entry.wakers = make(map[*sleep.Waker]struct{}) + } + entry.wakers[waker] = struct{}{} + } - // Add 'incomplete' entry in the cache to mark that resolution is in progress. - e := c.makeAndAddEntry(k, "") - e.maybeAddWaker(waker) + if entry.done == nil { + // Address resolution needs to be initiated. + if linkRes == nil { + return entry.linkAddr, nil, tcpip.ErrNoLinkAddress + } - go c.startAddressResolution(k, linkRes, localAddr, linkEP, e.done) // S/R-SAFE: link non-savable; wakers dropped synchronously. + entry.done = make(chan struct{}) + go c.startAddressResolution(k, linkRes, localAddr, linkEP, entry.done) // S/R-SAFE: link non-savable; wakers dropped synchronously. + } - return "", e.done, tcpip.ErrWouldBlock + return entry.linkAddr, entry.done, tcpip.ErrWouldBlock + default: + panic(fmt.Sprintf("invalid cache entry state: %s", s)) + } } // removeWaker removes a waker previously added through get(). 
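After this rewrite get() still returns the entry's done channel alongside tcpip.ErrWouldBlock while resolution is in flight. A sketch of the caller-side contract (it mirrors getBlocking in the tests below, which waits with a sleep.Waker instead of the channel):

// resolve blocks until addr's link address is known or resolution fails,
// assuming c is a *linkAddrCache and linkRes a LinkAddressResolver.
func resolve(c *linkAddrCache, addr tcpip.FullAddress, linkRes LinkAddressResolver) (tcpip.LinkAddress, *tcpip.Error) {
	for {
		linkAddr, ch, err := c.get(addr, linkRes, "", nil, nil)
		if err != tcpip.ErrWouldBlock {
			// nil on success, ErrNoLinkAddress on cached failure.
			return linkAddr, err
		}
		<-ch // resolution in flight; wait for it to settle, then retry
	}
}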
func (c *linkAddrCache) removeWaker(k tcpip.FullAddress, waker *sleep.Waker) { - c.mu.Lock() - defer c.mu.Unlock() + c.cache.Lock() + defer c.cache.Unlock() - if entry, ok := c.cache[k]; ok { + if entry, ok := c.cache.table[k]; ok { entry.removeWaker(waker) } } @@ -256,8 +247,8 @@ func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes Link linkRes.LinkAddressRequest(k.Addr, localAddr, linkEP) select { - case <-time.After(c.resolutionTimeout): - if stop := c.checkLinkRequest(k, i); stop { + case now := <-time.After(c.resolutionTimeout): + if stop := c.checkLinkRequest(now, k, i); stop { return } case <-done: @@ -269,38 +260,36 @@ func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes Link // checkLinkRequest checks whether the previous attempt to resolve the address // has succeeded and marks the entry accordingly (e.g. ready, failed). It returns // true if the request can stop, false if another request should be sent. -func (c *linkAddrCache) checkLinkRequest(k tcpip.FullAddress, attempt int) bool { - c.mu.Lock() - defer c.mu.Unlock() - - entry, ok := c.cache[k] +func (c *linkAddrCache) checkLinkRequest(now time.Time, k tcpip.FullAddress, attempt int) bool { + c.cache.Lock() + defer c.cache.Unlock() + entry, ok := c.cache.table[k] if !ok { // Entry was evicted from the cache. return true } - - switch s := entry.state(); s { - case ready, failed, expired: + switch s := entry.s; s { + case ready, failed: // Entry was made ready by resolver or failed. Either way we're done. - return true case incomplete: - if attempt+1 >= c.resolutionAttempts { - // Max number of retries reached, mark entry as failed. - entry.changeState(failed) - return true + if attempt+1 < c.resolutionAttempts { + // No response yet, need to send another ARP request. + return false } - // No response yet, need to send another ARP request. - return false + // Max number of retries reached, mark entry as failed.
+ entry.changeState(failed, now.Add(c.ageLimit)) default: panic(fmt.Sprintf("invalid cache entry state: %s", s)) } + return true } func newLinkAddrCache(ageLimit, resolutionTimeout time.Duration, resolutionAttempts int) *linkAddrCache { - return &linkAddrCache{ + c := &linkAddrCache{ ageLimit: ageLimit, resolutionTimeout: resolutionTimeout, resolutionAttempts: resolutionAttempts, - cache: make(map[tcpip.FullAddress]*linkAddrEntry, linkAddrCacheSize), } + c.cache.table = make(map[tcpip.FullAddress]*linkAddrEntry, linkAddrCacheSize) + return c } diff --git a/pkg/tcpip/stack/linkaddrcache_test.go b/pkg/tcpip/stack/linkaddrcache_test.go index 924f4d240..9946b8fe8 100644 --- a/pkg/tcpip/stack/linkaddrcache_test.go +++ b/pkg/tcpip/stack/linkaddrcache_test.go @@ -17,6 +17,7 @@ package stack import ( "fmt" "sync" + "sync/atomic" "testing" "time" @@ -29,25 +30,34 @@ type testaddr struct { linkAddr tcpip.LinkAddress } -var testaddrs []testaddr +var testAddrs = func() []testaddr { + var addrs []testaddr + for i := 0; i < 4*linkAddrCacheSize; i++ { + addr := fmt.Sprintf("Addr%06d", i) + addrs = append(addrs, testaddr{ + addr: tcpip.FullAddress{NIC: 1, Addr: tcpip.Address(addr)}, + linkAddr: tcpip.LinkAddress("Link" + addr), + }) + } + return addrs +}() type testLinkAddressResolver struct { - cache *linkAddrCache - delay time.Duration + cache *linkAddrCache + delay time.Duration + onLinkAddressRequest func() } func (r *testLinkAddressResolver) LinkAddressRequest(addr, _ tcpip.Address, _ LinkEndpoint) *tcpip.Error { - go func() { - if r.delay > 0 { - time.Sleep(r.delay) - } - r.fakeRequest(addr) - }() + time.AfterFunc(r.delay, func() { r.fakeRequest(addr) }) + if f := r.onLinkAddressRequest; f != nil { + f() + } return nil } func (r *testLinkAddressResolver) fakeRequest(addr tcpip.Address) { - for _, ta := range testaddrs { + for _, ta := range testAddrs { if ta.addr.Addr == addr { r.cache.add(ta.addr, ta.linkAddr) break @@ -80,20 +90,10 @@ func getBlocking(c *linkAddrCache, addr tcpip.FullAddress, linkRes LinkAddressRe } } -func init() { - for i := 0; i < 4*linkAddrCacheSize; i++ { - addr := fmt.Sprintf("Addr%06d", i) - testaddrs = append(testaddrs, testaddr{ - addr: tcpip.FullAddress{NIC: 1, Addr: tcpip.Address(addr)}, - linkAddr: tcpip.LinkAddress("Link" + addr), - }) - } -} - func TestCacheOverflow(t *testing.T) { c := newLinkAddrCache(1<<63-1, 1*time.Second, 3) - for i := len(testaddrs) - 1; i >= 0; i-- { - e := testaddrs[i] + for i := len(testAddrs) - 1; i >= 0; i-- { + e := testAddrs[i] c.add(e.addr, e.linkAddr) got, _, err := c.get(e.addr, nil, "", nil, nil) if err != nil { @@ -105,7 +105,7 @@ func TestCacheOverflow(t *testing.T) { } // Expect to find at least half of the most recent entries. for i := 0; i < linkAddrCacheSize/2; i++ { - e := testaddrs[i] + e := testAddrs[i] got, _, err := c.get(e.addr, nil, "", nil, nil) if err != nil { t.Errorf("check %d, c.get(%q)=%q, got error: %v", i, string(e.addr.Addr), got, err) @@ -115,8 +115,8 @@ func TestCacheOverflow(t *testing.T) { } } // The earliest entries should no longer be in the cache. 
- for i := len(testaddrs) - 1; i >= len(testaddrs)-linkAddrCacheSize; i-- { - e := testaddrs[i] + for i := len(testAddrs) - 1; i >= len(testAddrs)-linkAddrCacheSize; i-- { + e := testAddrs[i] if _, _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress { t.Errorf("check %d, c.get(%q), got error: %v, want: error ErrNoLinkAddress", i, string(e.addr.Addr), err) } @@ -130,7 +130,7 @@ func TestCacheConcurrent(t *testing.T) { for r := 0; r < 16; r++ { wg.Add(1) go func() { - for _, e := range testaddrs { + for _, e := range testAddrs { c.add(e.addr, e.linkAddr) c.get(e.addr, nil, "", nil, nil) // make work for gotsan } @@ -142,7 +142,7 @@ func TestCacheConcurrent(t *testing.T) { // All goroutines add in the same order and add more values than // can fit in the cache, so our eviction strategy requires that // the last entry be present and the first be missing. - e := testaddrs[len(testaddrs)-1] + e := testAddrs[len(testAddrs)-1] got, _, err := c.get(e.addr, nil, "", nil, nil) if err != nil { t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err) @@ -151,7 +151,7 @@ func TestCacheConcurrent(t *testing.T) { t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, e.linkAddr) } - e = testaddrs[0] + e = testAddrs[0] if _, _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress { t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err) } @@ -159,7 +159,7 @@ func TestCacheConcurrent(t *testing.T) { func TestCacheAgeLimit(t *testing.T) { c := newLinkAddrCache(1*time.Millisecond, 1*time.Second, 3) - e := testaddrs[0] + e := testAddrs[0] c.add(e.addr, e.linkAddr) time.Sleep(50 * time.Millisecond) if _, _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress { @@ -169,7 +169,7 @@ func TestCacheAgeLimit(t *testing.T) { func TestCacheReplace(t *testing.T) { c := newLinkAddrCache(1<<63-1, 1*time.Second, 3) - e := testaddrs[0] + e := testAddrs[0] l2 := e.linkAddr + "2" c.add(e.addr, e.linkAddr) got, _, err := c.get(e.addr, nil, "", nil, nil) @@ -193,7 +193,7 @@ func TestCacheReplace(t *testing.T) { func TestCacheResolution(t *testing.T) { c := newLinkAddrCache(1<<63-1, 250*time.Millisecond, 1) linkRes := &testLinkAddressResolver{cache: c} - for i, ta := range testaddrs { + for i, ta := range testAddrs { got, err := getBlocking(c, ta.addr, linkRes) if err != nil { t.Errorf("check %d, c.get(%q)=%q, got error: %v", i, string(ta.addr.Addr), got, err) @@ -205,7 +205,7 @@ func TestCacheResolution(t *testing.T) { // Check that after resolved, address stays in the cache and never returns WouldBlock. for i := 0; i < 10; i++ { - e := testaddrs[len(testaddrs)-1] + e := testAddrs[len(testAddrs)-1] got, _, err := c.get(e.addr, linkRes, "", nil, nil) if err != nil { t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err) @@ -220,8 +220,13 @@ func TestCacheResolutionFailed(t *testing.T) { c := newLinkAddrCache(1<<63-1, 10*time.Millisecond, 5) linkRes := &testLinkAddressResolver{cache: c} + var requestCount uint32 + linkRes.onLinkAddressRequest = func() { + atomic.AddUint32(&requestCount, 1) + } + // First, sanity check that resolution is working... 
- e := testaddrs[0] + e := testAddrs[0] got, err := getBlocking(c, e.addr, linkRes) if err != nil { t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err) @@ -230,10 +235,16 @@ func TestCacheResolutionFailed(t *testing.T) { t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, e.linkAddr) } + before := atomic.LoadUint32(&requestCount) + e.addr.Addr += "2" if _, err := getBlocking(c, e.addr, linkRes); err != tcpip.ErrNoLinkAddress { t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err) } + + if got, want := int(atomic.LoadUint32(&requestCount)-before), c.resolutionAttempts; got != want { + t.Errorf("got link address request count = %d, want = %d", got, want) + } } func TestCacheResolutionTimeout(t *testing.T) { @@ -242,7 +253,7 @@ func TestCacheResolutionTimeout(t *testing.T) { c := newLinkAddrCache(expiration, 1*time.Millisecond, 3) linkRes := &testLinkAddressResolver{cache: c, delay: resolverDelay} - e := testaddrs[0] + e := testAddrs[0] if _, err := getBlocking(c, e.addr, linkRes); err != tcpip.ErrNoLinkAddress { t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err) } } diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 3e6ff4afb..89b4c5960 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -139,7 +139,7 @@ func (n *NIC) getMainNICAddress(protocol tcpip.NetworkProtocolNumber) (tcpip.Add if list, ok := n.primary[protocol]; ok { for e := list.Front(); e != nil; e = e.Next() { ref := e.(*referencedNetworkEndpoint) - if ref.holdsInsertRef && ref.tryIncRef() { + if ref.kind == permanent && ref.tryIncRef() { r = ref break } @@ -178,7 +178,7 @@ func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber) *referencedN case header.IPv4Broadcast, header.IPv4Any: continue } - if r.tryIncRef() { + if r.isValidForOutgoing() && r.tryIncRef() { return r } } @@ -186,82 +186,155 @@ func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber) *referencedN return nil } +func (n *NIC) getRef(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address) *referencedNetworkEndpoint { + return n.getRefOrCreateTemp(protocol, dst, CanBePrimaryEndpoint, n.promiscuous) +} + // findEndpoint finds the endpoint, if any, with the given address. func (n *NIC) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior) *referencedNetworkEndpoint { + return n.getRefOrCreateTemp(protocol, address, peb, n.spoofing) +} + +// getRefOrCreateTemp returns the referenced network endpoint for the given +// protocol and address. If none exists, a temporary one may be created if +// we are in promiscuous mode or spoofing. +func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior, spoofingOrPromiscuous bool) *referencedNetworkEndpoint { id := NetworkEndpointID{address} n.mu.RLock() - ref := n.endpoints[id] - if ref != nil && !ref.tryIncRef() { - ref = nil + + if ref, ok := n.endpoints[id]; ok { + // An endpoint with this id exists, check if it can be used and return it. + switch ref.kind { + case permanentExpired: + if !spoofingOrPromiscuous { + n.mu.RUnlock() + return nil + } + fallthrough + case temporary, permanent: + if ref.tryIncRef() { + n.mu.RUnlock() + return ref + } + } + } + + // A usable reference was not found, create a temporary one if requested by + // the caller or if the address is found in the NIC's subnets.
+ createTempEP := spoofingOrPromiscuous + if !createTempEP { + for _, sn := range n.subnets { + if sn.Contains(address) { + createTempEP = true + break + } + } } - spoofing := n.spoofing + n.mu.RUnlock() - if ref != nil || !spoofing { - return ref + if !createTempEP { + return nil } // Try again with the lock in exclusive mode. If we still can't get the // endpoint, create a new "temporary" endpoint. It will only exist while // there's a route through it. n.mu.Lock() - ref = n.endpoints[id] - if ref == nil || !ref.tryIncRef() { - if netProto, ok := n.stack.networkProtocols[protocol]; ok { - addrWithPrefix := tcpip.AddressWithPrefix{address, netProto.DefaultPrefixLen()} - ref, _ = n.addAddressLocked(protocol, addrWithPrefix, peb, true) - if ref != nil { - ref.holdsInsertRef = false - } + if ref, ok := n.endpoints[id]; ok { + // No need to check the type as we are ok with expired endpoints at this + // point. + if ref.tryIncRef() { + n.mu.Unlock() + return ref } + // tryIncRef failing means the endpoint is scheduled to be removed once the + // lock is released. Remove it here so we can create a new (temporary) one. + // The removal logic waiting for the lock handles this case. + n.removeEndpointLocked(ref) } - n.mu.Unlock() - return ref -} -func (n *NIC) addAddressLocked(protocol tcpip.NetworkProtocolNumber, addrWithPrefix tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior, replace bool) (*referencedNetworkEndpoint, *tcpip.Error) { + // Add a new temporary endpoint. netProto, ok := n.stack.networkProtocols[protocol] if !ok { - return nil, tcpip.ErrUnknownProtocol + n.mu.Unlock() + return nil } + ref, _ := n.addAddressLocked(tcpip.ProtocolAddress{ + Protocol: protocol, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: address, + PrefixLen: netProto.DefaultPrefixLen(), + }, + }, peb, temporary) - // Create the new network endpoint. - ep, err := netProto.NewEndpoint(n.id, addrWithPrefix, n.stack, n, n.linkEP) - if err != nil { - return nil, err - } + n.mu.Unlock() + return ref +} - id := *ep.ID() +func (n *NIC) addPermanentAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) (*referencedNetworkEndpoint, *tcpip.Error) { + id := NetworkEndpointID{protocolAddress.AddressWithPrefix.Address} if ref, ok := n.endpoints[id]; ok { - if !replace { + switch ref.kind { + case permanent: + // The NIC already has a permanent endpoint with that address. return nil, tcpip.ErrDuplicateAddress + case permanentExpired, temporary: + // Promote the endpoint to become permanent. + if ref.tryIncRef() { + ref.kind = permanent + return ref, nil + } + // tryIncRef failing means the endpoint is scheduled to be removed once + // the lock is released. Remove it here so we can create a new + // (permanent) one. The removal logic waiting for the lock handles this + // case. + n.removeEndpointLocked(ref) } + } + return n.addAddressLocked(protocolAddress, peb, permanent) +} - n.removeEndpointLocked(ref) +func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior, kind networkEndpointKind) (*referencedNetworkEndpoint, *tcpip.Error) { + // Sanity check. + id := NetworkEndpointID{protocolAddress.AddressWithPrefix.Address} + if _, ok := n.endpoints[id]; ok { + // Endpoint already exists. + return nil, tcpip.ErrDuplicateAddress } + netProto, ok := n.stack.networkProtocols[protocolAddress.Protocol] + if !ok { + return nil, tcpip.ErrUnknownProtocol + } + + // Create the new network endpoint.
+ ep, err := netProto.NewEndpoint(n.id, protocolAddress.AddressWithPrefix, n.stack, n, n.linkEP) + if err != nil { + return nil, err + } ref := &referencedNetworkEndpoint{ - refs: 1, - ep: ep, - nic: n, - protocol: protocol, - holdsInsertRef: true, + refs: 1, + ep: ep, + nic: n, + protocol: protocolAddress.Protocol, + kind: kind, } // Set up cache if link address resolution exists for this protocol. if n.linkEP.Capabilities()&CapabilityResolutionRequired != 0 { - if _, ok := n.stack.linkAddrResolvers[protocol]; ok { + if _, ok := n.stack.linkAddrResolvers[protocolAddress.Protocol]; ok { ref.linkCache = n.stack } } n.endpoints[id] = ref - l, ok := n.primary[protocol] + l, ok := n.primary[protocolAddress.Protocol] if !ok { l = &ilist.List{} - n.primary[protocol] = l + n.primary[protocolAddress.Protocol] = l } switch peb { @@ -276,10 +349,10 @@ func (n *NIC) addAddressLocked(protocol tcpip.NetworkProtocolNumber, addrWithPre // AddAddress adds a new address to n, so that it starts accepting packets // targeted at the given address (and network protocol). -func (n *NIC) AddAddress(protocol tcpip.NetworkProtocolNumber, addrWithPrefix tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior) *tcpip.Error { +func (n *NIC) AddAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) *tcpip.Error { // Add the endpoint. n.mu.Lock() - _, err := n.addAddressLocked(protocol, addrWithPrefix, peb, false) + _, err := n.addPermanentAddressLocked(protocolAddress, peb) n.mu.Unlock() return err @@ -291,6 +364,12 @@ func (n *NIC) Addresses() []tcpip.ProtocolAddress { defer n.mu.RUnlock() addrs := make([]tcpip.ProtocolAddress, 0, len(n.endpoints)) for nid, ref := range n.endpoints { + // Don't include expired or temporary endpoints to avoid confusion and + // prevent the caller from using those. + switch ref.kind { + case permanentExpired, temporary: + continue + } addrs = append(addrs, tcpip.ProtocolAddress{ Protocol: ref.protocol, AddressWithPrefix: tcpip.AddressWithPrefix{ @@ -356,13 +435,16 @@ func (n *NIC) Subnets() []tcpip.Subnet { func (n *NIC) removeEndpointLocked(r *referencedNetworkEndpoint) { id := *r.ep.ID() - // Nothing to do if the reference has already been replaced with a - // different one. + // Nothing to do if the reference has already been replaced with a different + // one. This happens in the case where 1) this endpoint's ref count hit zero + // and was waiting (on the lock) to be removed and 2) the same address was + // re-added in the meantime by removing this endpoint from the list and + // adding a new one.
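The race described in that comment is why removal is driven entirely by the reference count. A sketch (using sync/atomic) of the tryIncRef/decRef discipline these endpoints rely on; the real decRef hands the endpoint back to the NIC for removal, and names here are illustrative:

type refCounted struct {
	refs int32 // zero means removal has already begun
}

// tryIncRef takes a reference only while the count is non-zero, so a
// racing caller can never resurrect an endpoint that is being removed.
func (r *refCounted) tryIncRef() bool {
	for {
		v := atomic.LoadInt32(&r.refs)
		if v == 0 {
			return false
		}
		if atomic.CompareAndSwapInt32(&r.refs, v, v+1) {
			return true
		}
	}
}

// decRef drops a reference; whoever takes the count to zero is
// responsible for triggering removal.
func (r *refCounted) decRef(remove func()) {
	if atomic.AddInt32(&r.refs, -1) == 0 {
		remove()
	}
}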
if n.endpoints[id] != r { return } - if r.holdsInsertRef { + if r.kind == permanent { panic("Reference count dropped to zero before being removed") } @@ -381,14 +463,13 @@ func (n *NIC) removeEndpoint(r *referencedNetworkEndpoint) { n.mu.Unlock() } -func (n *NIC) removeAddressLocked(addr tcpip.Address) *tcpip.Error { +func (n *NIC) removePermanentAddressLocked(addr tcpip.Address) *tcpip.Error { r := n.endpoints[NetworkEndpointID{addr}] - if r == nil || !r.holdsInsertRef { + if r == nil || r.kind != permanent { return tcpip.ErrBadLocalAddress } - r.holdsInsertRef = false - + r.kind = permanentExpired r.decRefLocked() return nil @@ -398,7 +479,7 @@ func (n *NIC) removeAddressLocked(addr tcpip.Address) *tcpip.Error { func (n *NIC) RemoveAddress(addr tcpip.Address) *tcpip.Error { n.mu.Lock() defer n.mu.Unlock() - return n.removeAddressLocked(addr) + return n.removePermanentAddressLocked(addr) } // joinGroup adds a new endpoint for the given multicast address, if none @@ -414,8 +495,13 @@ func (n *NIC) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address if !ok { return tcpip.ErrUnknownProtocol } - addrWithPrefix := tcpip.AddressWithPrefix{addr, netProto.DefaultPrefixLen()} - if _, err := n.addAddressLocked(protocol, addrWithPrefix, NeverPrimaryEndpoint, false); err != nil { + if _, err := n.addPermanentAddressLocked(tcpip.ProtocolAddress{ + Protocol: protocol, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: addr, + PrefixLen: netProto.DefaultPrefixLen(), + }, + }, NeverPrimaryEndpoint); err != nil { return err } } @@ -437,7 +523,7 @@ func (n *NIC) leaveGroup(addr tcpip.Address) *tcpip.Error { return tcpip.ErrBadLocalAddress case 1: // This is the last one, clean up. - if err := n.removeAddressLocked(addr); err != nil { + if err := n.removePermanentAddressLocked(addr); err != nil { return err } } @@ -445,6 +531,13 @@ func (n *NIC) leaveGroup(addr tcpip.Address) *tcpip.Error { return nil } +func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, localLinkAddr, remotelinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint, vv buffer.VectorisedView) { + r := makeRoute(protocol, dst, src, localLinkAddr, ref, false /* handleLocal */, false /* multicastLoop */) + r.RemoteLinkAddress = remotelinkAddr + ref.ep.HandlePacket(&r, vv) + ref.decRef() +} + // DeliverNetworkPacket finds the appropriate network protocol endpoint and // hands the packet over for further processing. This function is called when // the NIC receives a packet from the physical interface. @@ -472,6 +565,8 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddr src, dst := netProto.ParseAddresses(vv.First()) + n.stack.AddLinkAddress(n.id, src, remote) + // If the packet is destined to the IPv4 Broadcast address, then make a // route to each IPv4 network endpoint and let each endpoint handle the // packet. @@ -479,11 +574,8 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddr // n.endpoints is mutex protected so acquire lock. 
n.mu.RLock() for _, ref := range n.endpoints { - if ref.protocol == header.IPv4ProtocolNumber && ref.tryIncRef() { - r := makeRoute(protocol, dst, src, linkEP.LinkAddress(), ref, false /* handleLocal */, false /* multicastLoop */) - r.RemoteLinkAddress = remote - ref.ep.HandlePacket(&r, vv) - ref.decRef() + if ref.isValidForIncoming() && ref.protocol == header.IPv4ProtocolNumber && ref.tryIncRef() { + handlePacket(protocol, dst, src, linkEP.LinkAddress(), remote, ref, vv) } } n.mu.RUnlock() @@ -491,10 +583,7 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddr } if ref := n.getRef(protocol, dst); ref != nil { - r := makeRoute(protocol, dst, src, linkEP.LinkAddress(), ref, false /* handleLocal */, false /* multicastLoop */) - r.RemoteLinkAddress = remote - ref.ep.HandlePacket(&r, vv) - ref.decRef() + handlePacket(protocol, dst, src, linkEP.LinkAddress(), remote, ref, vv) return } @@ -517,8 +606,9 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddr n := r.ref.nic n.mu.RLock() ref, ok := n.endpoints[NetworkEndpointID{dst}] + ok = ok && ref.isValidForOutgoing() && ref.tryIncRef() n.mu.RUnlock() - if ok && ref.tryIncRef() { + if ok { r.RemoteAddress = src // TODO(b/123449044): Update the source NIC as well. ref.ep.HandlePacket(&r, vv) @@ -543,52 +633,6 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddr n.stack.stats.IP.InvalidAddressesReceived.Increment() } -func (n *NIC) getRef(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address) *referencedNetworkEndpoint { - id := NetworkEndpointID{dst} - - n.mu.RLock() - if ref, ok := n.endpoints[id]; ok && ref.tryIncRef() { - n.mu.RUnlock() - return ref - } - - promiscuous := n.promiscuous - // Check if the packet is for a subnet this NIC cares about. - if !promiscuous { - for _, sn := range n.subnets { - if sn.Contains(dst) { - promiscuous = true - break - } - } - } - n.mu.RUnlock() - if promiscuous { - // Try again with the lock in exclusive mode. If we still can't - // get the endpoint, create a new "temporary" one. It will only - // exist while there's a route through it. - n.mu.Lock() - if ref, ok := n.endpoints[id]; ok && ref.tryIncRef() { - n.mu.Unlock() - return ref - } - netProto, ok := n.stack.networkProtocols[protocol] - if !ok { - n.mu.Unlock() - return nil - } - addrWithPrefix := tcpip.AddressWithPrefix{dst, netProto.DefaultPrefixLen()} - ref, err := n.addAddressLocked(protocol, addrWithPrefix, CanBePrimaryEndpoint, true) - n.mu.Unlock() - if err == nil { - ref.holdsInsertRef = false - return ref - } - } - - return nil -} - // DeliverTransportPacket delivers the packets to the appropriate transport // protocol endpoint. func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView) { @@ -676,9 +720,33 @@ func (n *NIC) ID() tcpip.NICID { return n.id } +type networkEndpointKind int + +const ( + // A permanent endpoint is created by adding a permanent address (vs. a + // temporary one) to the NIC. Its reference count is biased by 1 to avoid + // removal when no route holds a reference to it. It is removed by explicitly + // removing the permanent address from the NIC. + permanent networkEndpointKind = iota + + // An expired permanent endpoint is a permanent endpoint that had its address + // removed from the NIC, and it is waiting to be removed once no more routes + // hold a reference to it. This is achieved by decreasing its reference count + // by 1.
If its address is re-added before the endpoint is removed, its type + // changes back to permanent and its reference count increases by 1 again. + permanentExpired + + // A temporary endpoint is created for spoofing outgoing packets, or when in + // promiscuous mode and accepting incoming packets that don't match any + // permanent endpoint. Its reference count is not biased by 1 and the + // endpoint is removed immediately when no more routes hold a reference to + // it. A temporary endpoint can be promoted to permanent if its address + // is added permanently. + temporary +) + type referencedNetworkEndpoint struct { ilist.Entry - refs int32 ep NetworkEndpoint nic *NIC protocol tcpip.NetworkProtocolNumber @@ -687,11 +755,25 @@ type referencedNetworkEndpoint struct { // protocol. Set to nil otherwise. linkCache LinkAddressCache - // holdsInsertRef is protected by the NIC's mutex. It indicates whether - // the reference count is biased by 1 due to the insertion of the - // endpoint. It is reset to false when RemoveAddress is called on the - // NIC. - holdsInsertRef bool + // refs counts the references held for this endpoint. When refs hits zero it + // triggers the automatic removal of the endpoint from the NIC. + refs int32 + + kind networkEndpointKind +} + +// isValidForOutgoing returns true if the endpoint can be used to send out a +// packet. It requires the endpoint to not be marked expired (i.e., its address +// has been removed), or the NIC to be in spoofing mode. +func (r *referencedNetworkEndpoint) isValidForOutgoing() bool { + return r.kind != permanentExpired || r.nic.spoofing +} + +// isValidForIncoming returns true if the endpoint can accept an incoming +// packet. It requires the endpoint to not be marked expired (i.e., its address +// has been removed), or the NIC to be in promiscuous mode. +func (r *referencedNetworkEndpoint) isValidForIncoming() bool { + return r.kind != permanentExpired || r.nic.promiscuous } // decRef decrements the ref count and cleans up the endpoint once it reaches diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index 391ab4344..e52cdd674 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -148,11 +148,15 @@ func (r *Route) RemoveWaker(waker *sleep.Waker) { // IsResolutionRequired returns true if Resolve() must be called to resolve // the link address before this route can be written to. func (r *Route) IsResolutionRequired() bool { - return r.ref.linkCache != nil && r.RemoteLinkAddress == "" + return r.ref.isValidForOutgoing() && r.ref.linkCache != nil && r.RemoteLinkAddress == "" } // WritePacket writes the packet through the given route. func (r *Route) WritePacket(gso *GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8) *tcpip.Error { + if !r.ref.isValidForOutgoing() { + return tcpip.ErrInvalidEndpointState + } + err := r.ref.ep.WritePacket(r, gso, hdr, payload, protocol, ttl, r.loop) if err != nil { r.Stats().IP.OutgoingPacketErrors.Increment() @@ -166,6 +170,10 @@ func (r *Route) WritePacket(gso *GSO, hdr buffer.Prependable, payload buffer.Vec // WriteHeaderIncludedPacket writes a packet already containing a network // header through the given route.
func (r *Route) WriteHeaderIncludedPacket(payload buffer.VectorisedView) *tcpip.Error { + if !r.ref.isValidForOutgoing() { + return tcpip.ErrInvalidEndpointState + } + if err := r.ref.ep.WriteHeaderIncludedPacket(r, payload, r.loop); err != nil { r.Stats().IP.OutgoingPacketErrors.Increment() return err diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 57b8a9994..d69162ba1 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -32,6 +32,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/iptables" "gvisor.dev/gvisor/pkg/tcpip/ports" "gvisor.dev/gvisor/pkg/tcpip/seqnum" "gvisor.dev/gvisor/pkg/waiter" @@ -333,6 +334,15 @@ type TCPEndpointState struct { Sender TCPSenderState } +// ResumableEndpoint is an endpoint that needs to be resumed after restore. +type ResumableEndpoint interface { + // Resume resumes an endpoint after restore. This can be used to restart + // background workers such as protocol goroutines. This must be called after + // all indirect dependencies of the endpoint have been restored, which + // generally implies at the end of the restore process. + Resume(*Stack) +} + // Stack is a networking stack, with all supported protocols, NICs, and route // table. type Stack struct { @@ -372,6 +382,13 @@ type Stack struct { // handleLocal allows non-loopback interfaces to loop packets. handleLocal bool + + // tables are the iptables packet filtering and manipulation rules. + tables iptables.IPTables + + // resumableEndpoints is a list of endpoints that need to be resumed if the + // stack is being restored. + resumableEndpoints []ResumableEndpoint } // Options contains optional Stack configuration. @@ -751,10 +768,10 @@ func (s *Stack) AddAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, return s.AddAddressWithOptions(id, protocol, addr, CanBePrimaryEndpoint) } -// AddAddressWithPrefix adds a new network-layer address/prefixLen to the +// AddProtocolAddress adds a new network-layer protocol address to the // specified NIC. -func (s *Stack) AddAddressWithPrefix(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addrWithPrefix tcpip.AddressWithPrefix) *tcpip.Error { - return s.AddAddressWithPrefixAndOptions(id, protocol, addrWithPrefix, CanBePrimaryEndpoint) +func (s *Stack) AddProtocolAddress(id tcpip.NICID, protocolAddress tcpip.ProtocolAddress) *tcpip.Error { + return s.AddProtocolAddressWithOptions(id, protocolAddress, CanBePrimaryEndpoint) } // AddAddressWithOptions is the same as AddAddress, but allows you to specify @@ -764,13 +781,18 @@ func (s *Stack) AddAddressWithOptions(id tcpip.NICID, protocol tcpip.NetworkProt if !ok { return tcpip.ErrUnknownProtocol } - addrWithPrefix := tcpip.AddressWithPrefix{addr, netProto.DefaultPrefixLen()} - return s.AddAddressWithPrefixAndOptions(id, protocol, addrWithPrefix, peb) -} - -// AddAddressWithPrefixAndOptions is the same as AddAddressWithPrefixLen, // but allows you to specify whether the new endpoint can be primary or not.
-func (s *Stack) AddAddressWithPrefixAndOptions(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addrWithPrefix tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior) *tcpip.Error { + return s.AddProtocolAddressWithOptions(id, tcpip.ProtocolAddress{ + Protocol: protocol, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: addr, + PrefixLen: netProto.DefaultPrefixLen(), + }, + }, peb) +} + +// AddProtocolAddressWithOptions is the same as AddProtocolAddress, but allows +// you to specify whether the new endpoint can be primary or not. +func (s *Stack) AddProtocolAddressWithOptions(id tcpip.NICID, protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) *tcpip.Error { s.mu.RLock() defer s.mu.RUnlock() @@ -779,7 +801,7 @@ func (s *Stack) AddAddressWithPrefixAndOptions(id tcpip.NICID, protocol tcpip.Ne return tcpip.ErrUnknownNICID } - return nic.AddAddress(protocol, addrWithPrefix, peb) + return nic.AddAddress(protocolAddress, peb) } // AddSubnet adds a subnet range to the specified NIC. @@ -873,7 +895,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n } } else { for _, route := range s.routeTable { - if (id != 0 && id != route.NIC) || (len(remoteAddr) != 0 && !route.Match(remoteAddr)) { + if (id != 0 && id != route.NIC) || (len(remoteAddr) != 0 && !isBroadcast && !route.Destination.Contains(remoteAddr)) { continue } if nic, ok := s.nics[route.NIC]; ok { @@ -1082,6 +1104,28 @@ func (s *Stack) UnregisterRawTransportEndpoint(nicID tcpip.NICID, netProto tcpip } } +// RegisterRestoredEndpoint records e as an endpoint that has been restored on +// this stack. +func (s *Stack) RegisterRestoredEndpoint(e ResumableEndpoint) { + s.mu.Lock() + s.resumableEndpoints = append(s.resumableEndpoints, e) + s.mu.Unlock() +} + +// Resume restarts the stack after a restore. This must be called after the +// entire system has been restored. +func (s *Stack) Resume() { + // ResumableEndpoint.Resume() may call other methods on s, so we can't hold + // s.mu while resuming the endpoints. + s.mu.Lock() + eps := s.resumableEndpoints + s.resumableEndpoints = nil + s.mu.Unlock() + for _, e := range eps { + e.Resume(s) + } +} + // NetworkProtocolInstance returns the protocol instance in the stack for the // specified network protocol. This method is public for protocol implementers // and tests to use. @@ -1161,3 +1205,13 @@ func (s *Stack) LeaveGroup(protocol tcpip.NetworkProtocolNumber, nicID tcpip.NIC } return tcpip.ErrUnknownNICID } + +// IPTables returns the stack's iptables. +func (s *Stack) IPTables() iptables.IPTables { + return s.tables +} + +// SetIPTables sets the stack's iptables. 
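The ResumableEndpoint/Resume pair introduced in stack.go above gives restored endpoints a hook to restart their background work. A sketch with hypothetical names, assuming s is the restored *stack.Stack:

// workerEndpoint is an endpoint whose protocol goroutine was stopped by
// save and must be restarted after restore.
type workerEndpoint struct{ s *stack.Stack }

// Resume implements stack.ResumableEndpoint.
func (e *workerEndpoint) Resume(s *stack.Stack) {
	e.s = s
	go e.protocolLoop() // restart the background worker
}

func (e *workerEndpoint) protocolLoop() { /* protocol work */ }

During restore each such endpoint calls s.RegisterRestoredEndpoint(ep); once the entire system is restored, the sandbox calls s.Resume(), which invokes every registered Resume outside the stack's lock.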
+func (s *Stack) SetIPTables(ipt iptables.IPTables) { + s.tables = ipt +} diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 9d082bba4..4debd1eec 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -181,6 +181,10 @@ func (f *fakeNetworkProtocol) DefaultPrefixLen() int { return fakeDefaultPrefixLen } +func (f *fakeNetworkProtocol) PacketCount(intfAddr byte) int { + return f.packetCount[int(intfAddr)%len(f.packetCount)] +} + func (*fakeNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { return tcpip.Address(v[1:2]), tcpip.Address(v[0:1]) } @@ -188,7 +192,7 @@ func (*fakeNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Addres func (f *fakeNetworkProtocol) NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) { return &fakeNetworkEndpoint{ nicid: nicid, - id: stack.NetworkEndpointID{addrWithPrefix.Address}, + id: stack.NetworkEndpointID{LocalAddress: addrWithPrefix.Address}, prefixLen: addrWithPrefix.PrefixLen, proto: f, dispatcher: dispatcher, @@ -289,16 +293,75 @@ func TestNetworkReceive(t *testing.T) { } } -func sendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, payload buffer.View) { +func sendTo(s *stack.Stack, addr tcpip.Address, payload buffer.View) *tcpip.Error { r, err := s.FindRoute(0, "", addr, fakeNetNumber, false /* multicastLoop */) if err != nil { - t.Fatal("FindRoute failed:", err) + return err } defer r.Release() + return send(r, payload) +} +func send(r stack.Route, payload buffer.View) *tcpip.Error { hdr := buffer.NewPrependable(int(r.MaxHeaderLength())) - if err := r.WritePacket(nil /* gso */, hdr, payload.ToVectorisedView(), fakeTransNumber, 123); err != nil { - t.Error("WritePacket failed:", err) + return r.WritePacket(nil /* gso */, hdr, payload.ToVectorisedView(), fakeTransNumber, 123) +} + +func testSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, linkEP *channel.Endpoint, payload buffer.View) { + t.Helper() + linkEP.Drain() + if err := sendTo(s, addr, payload); err != nil { + t.Error("sendTo failed:", err) + } + if got, want := linkEP.Drain(), 1; got != want { + t.Errorf("sendTo packet count: got = %d, want %d", got, want) + } +} + +func testSend(t *testing.T, r stack.Route, linkEP *channel.Endpoint, payload buffer.View) { + t.Helper() + linkEP.Drain() + if err := send(r, payload); err != nil { + t.Error("send failed:", err) + } + if got, want := linkEP.Drain(), 1; got != want { + t.Errorf("send packet count: got = %d, want %d", got, want) + } +} + +func testFailingSend(t *testing.T, r stack.Route, linkEP *channel.Endpoint, payload buffer.View, wantErr *tcpip.Error) { + t.Helper() + if gotErr := send(r, payload); gotErr != wantErr { + t.Errorf("send failed: got = %s, want = %s ", gotErr, wantErr) + } +} + +func testFailingSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, linkEP *channel.Endpoint, payload buffer.View, wantErr *tcpip.Error) { + t.Helper() + if gotErr := sendTo(s, addr, payload); gotErr != wantErr { + t.Errorf("sendto failed: got = %s, want = %s ", gotErr, wantErr) + } +} + +func testRecv(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, linkEP *channel.Endpoint, buf buffer.View) { + t.Helper() + // testRecvInternal injects one packet, and we expect to receive it. 
+ want := fakeNet.PacketCount(localAddrByte) + 1 + testRecvInternal(t, fakeNet, localAddrByte, linkEP, buf, want) +} + +func testFailingRecv(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, linkEP *channel.Endpoint, buf buffer.View) { + t.Helper() + // testRecvInternal injects one packet, and we do NOT expect to receive it. + want := fakeNet.PacketCount(localAddrByte) + testRecvInternal(t, fakeNet, localAddrByte, linkEP, buf, want) +} + +func testRecvInternal(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, linkEP *channel.Endpoint, buf buffer.View, want int) { + t.Helper() + linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + if got := fakeNet.PacketCount(localAddrByte); got != want { + t.Errorf("receive packet count: got = %d, want %d", got, want) } } @@ -312,17 +375,20 @@ func TestNetworkSend(t *testing.T) { t.Fatal("NewNIC failed:", err) } - s.SetRouteTable([]tcpip.Route{{"\x00", "\x00", "\x00", 1}}) + { + subnet, err := tcpip.NewSubnet("\x00", "\x00") + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) + } if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { t.Fatal("AddAddress failed:", err) } // Make sure that the link-layer endpoint received the outbound packet. - sendTo(t, s, "\x03", nil) - if c := linkEP.Drain(); c != 1 { - t.Errorf("packetCount = %d, want %d", c, 1) - } + testSendTo(t, s, "\x03", linkEP, nil) } func TestNetworkSendMultiRoute(t *testing.T) { @@ -360,24 +426,26 @@ func TestNetworkSendMultiRoute(t *testing.T) { // Set a route table that sends all packets with odd destination // addresses through the first NIC, and all even destination address // through the second one. - s.SetRouteTable([]tcpip.Route{ - {"\x01", "\x01", "\x00", 1}, - {"\x00", "\x01", "\x00", 2}, - }) + { + subnet0, err := tcpip.NewSubnet("\x00", "\x01") + if err != nil { + t.Fatal(err) + } + subnet1, err := tcpip.NewSubnet("\x01", "\x01") + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{ + {Destination: subnet1, Gateway: "\x00", NIC: 1}, + {Destination: subnet0, Gateway: "\x00", NIC: 2}, + }) + } // Send a packet to an odd destination. - sendTo(t, s, "\x05", nil) - - if c := linkEP1.Drain(); c != 1 { - t.Errorf("packetCount = %d, want %d", c, 1) - } + testSendTo(t, s, "\x05", linkEP1, nil) // Send a packet to an even destination. - sendTo(t, s, "\x06", nil) - - if c := linkEP2.Drain(); c != 1 { - t.Errorf("packetCount = %d, want %d", c, 1) - } + testSendTo(t, s, "\x06", linkEP2, nil) } func testRoute(t *testing.T, s *stack.Stack, nic tcpip.NICID, srcAddr, dstAddr, expectedSrcAddr tcpip.Address) { @@ -439,10 +507,20 @@ func TestRoutes(t *testing.T) { // Set a route table that sends all packets with odd destination // addresses through the first NIC, and all even destination address // through the second one. - s.SetRouteTable([]tcpip.Route{ - {"\x01", "\x01", "\x00", 1}, - {"\x00", "\x01", "\x00", 2}, - }) + { + subnet0, err := tcpip.NewSubnet("\x00", "\x01") + if err != nil { + t.Fatal(err) + } + subnet1, err := tcpip.NewSubnet("\x01", "\x01") + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{ + {Destination: subnet1, Gateway: "\x00", NIC: 1}, + {Destination: subnet0, Gateway: "\x00", NIC: 2}, + }) + } // Test routes to odd address. 
testRoute(t, s, 0, "", "\x05", "\x01") @@ -472,6 +550,10 @@ func TestRoutes(t *testing.T) { } func TestAddressRemoval(t *testing.T) { + const localAddrByte byte = 0x01 + localAddr := tcpip.Address([]byte{localAddrByte}) + remoteAddr := tcpip.Address("\x02") + s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) id, linkEP := channel.New(10, defaultMTU, "") @@ -479,99 +561,285 @@ func TestAddressRemoval(t *testing.T) { t.Fatal("CreateNIC failed:", err) } - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { + if err := s.AddAddress(1, fakeNetNumber, localAddr); err != nil { t.Fatal("AddAddress failed:", err) } + { + subnet, err := tcpip.NewSubnet("\x00", "\x00") + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) + } fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) buf := buffer.NewView(30) - // Write a packet, and check that it gets delivered. - fakeNet.packetCount[1] = 0 - buf[0] = 1 - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) - if fakeNet.packetCount[1] != 1 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) - } + // Send and receive packets, and verify they are received. + buf[0] = localAddrByte + testRecv(t, fakeNet, localAddrByte, linkEP, buf) + testSendTo(t, s, remoteAddr, linkEP, nil) - // Remove the address, then check that packet doesn't get delivered - // anymore. - if err := s.RemoveAddress(1, "\x01"); err != nil { + // Remove the address, then check that send/receive doesn't work anymore. + if err := s.RemoveAddress(1, localAddr); err != nil { t.Fatal("RemoveAddress failed:", err) } - - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) - if fakeNet.packetCount[1] != 1 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) - } + testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf) + testFailingSendTo(t, s, remoteAddr, linkEP, nil, tcpip.ErrNoRoute) // Check that removing the same address fails. - if err := s.RemoveAddress(1, "\x01"); err != tcpip.ErrBadLocalAddress { + if err := s.RemoveAddress(1, localAddr); err != tcpip.ErrBadLocalAddress { t.Fatalf("RemoveAddress returned unexpected error, got = %v, want = %s", err, tcpip.ErrBadLocalAddress) } } -func TestDelayedRemovalDueToRoute(t *testing.T) { +func TestAddressRemovalWithRouteHeld(t *testing.T) { + const localAddrByte byte = 0x01 + localAddr := tcpip.Address([]byte{localAddrByte}) + remoteAddr := tcpip.Address("\x02") + s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) id, linkEP := channel.New(10, defaultMTU, "") if err := s.CreateNIC(1, id); err != nil { t.Fatal("CreateNIC failed:", err) } - - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatal("AddAddress failed:", err) - } - - s.SetRouteTable([]tcpip.Route{ - {"\x00", "\x00", "\x00", 1}, - }) - fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) - buf := buffer.NewView(30) - // Write a packet, and check that it gets delivered. - fakeNet.packetCount[1] = 0 - buf[0] = 1 - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) - if fakeNet.packetCount[1] != 1 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) + if err := s.AddAddress(1, fakeNetNumber, localAddr); err != nil { + t.Fatal("AddAddress failed:", err) + } + { + subnet, err := tcpip.NewSubnet("\x00", "\x00") + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) } - // Get a route, check that packet is still deliverable. 
- r, err := s.FindRoute(0, "", "\x02", fakeNetNumber, false /* multicastLoop */) + r, err := s.FindRoute(0, "", remoteAddr, fakeNetNumber, false /* multicastLoop */) if err != nil { t.Fatal("FindRoute failed:", err) } - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) - if fakeNet.packetCount[1] != 2 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 2) - } + // Send and receive packets, and verify they are received. + buf[0] = localAddrByte + testRecv(t, fakeNet, localAddrByte, linkEP, buf) + testSend(t, r, linkEP, nil) + testSendTo(t, s, remoteAddr, linkEP, nil) - // Remove the address, then check that packet is still deliverable - // because the route is keeping the address alive. - if err := s.RemoveAddress(1, "\x01"); err != nil { + // Remove the address, then check that send/receive doesn't work anymore. + if err := s.RemoveAddress(1, localAddr); err != nil { t.Fatal("RemoveAddress failed:", err) } - - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) - if fakeNet.packetCount[1] != 3 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 3) - } + testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf) + testFailingSend(t, r, linkEP, nil, tcpip.ErrInvalidEndpointState) + testFailingSendTo(t, s, remoteAddr, linkEP, nil, tcpip.ErrNoRoute) // Check that removing the same address fails. - if err := s.RemoveAddress(1, "\x01"); err != tcpip.ErrBadLocalAddress { + if err := s.RemoveAddress(1, localAddr); err != tcpip.ErrBadLocalAddress { t.Fatalf("RemoveAddress returned unexpected error, got = %v, want = %s", err, tcpip.ErrBadLocalAddress) } +} - // Release the route, then check that packet is not deliverable anymore. - r.Release() - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) - if fakeNet.packetCount[1] != 3 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 3) +func verifyAddress(t *testing.T, s *stack.Stack, nicid tcpip.NICID, addr tcpip.Address) { + t.Helper() + info, ok := s.NICInfo()[nicid] + if !ok { + t.Fatalf("NICInfo() failed to find nicid=%d", nicid) + } + if len(addr) == 0 { + // No address given, verify that there is no address assigned to the NIC. + for _, a := range info.ProtocolAddresses { + if a.Protocol == fakeNetNumber && a.AddressWithPrefix != (tcpip.AddressWithPrefix{}) { + t.Errorf("verify no-address: got = %s, want = %s", a.AddressWithPrefix, (tcpip.AddressWithPrefix{})) + } + } + return + } + // Address given, verify the address is assigned to the NIC and no other + // address is. 
+ found := false + for _, a := range info.ProtocolAddresses { + if a.Protocol == fakeNetNumber { + if a.AddressWithPrefix.Address == addr { + found = true + } else { + t.Errorf("verify address: got = %s, want = %s", a.AddressWithPrefix.Address, addr) + } + } + } + if !found { + t.Errorf("verify address: couldn't find %s on the NIC", addr) + } +} + +func TestEndpointExpiration(t *testing.T) { + const ( + localAddrByte byte = 0x01 + remoteAddr tcpip.Address = "\x03" + noAddr tcpip.Address = "" + nicid tcpip.NICID = 1 + ) + localAddr := tcpip.Address([]byte{localAddrByte}) + + for _, promiscuous := range []bool{true, false} { + for _, spoofing := range []bool{true, false} { + t.Run(fmt.Sprintf("promiscuous=%t spoofing=%t", promiscuous, spoofing), func(t *testing.T) { + s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + + id, linkEP := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(nicid, id); err != nil { + t.Fatal("CreateNIC failed:", err) + } + + { + subnet, err := tcpip.NewSubnet("\x00", "\x00") + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) + } + + fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) + buf := buffer.NewView(30) + buf[0] = localAddrByte + + if promiscuous { + if err := s.SetPromiscuousMode(nicid, true); err != nil { + t.Fatal("SetPromiscuousMode failed:", err) + } + } + + if spoofing { + if err := s.SetSpoofing(nicid, true); err != nil { + t.Fatal("SetSpoofing failed:", err) + } + } + + // 1. No Address yet, send should only work for spoofing, receive for + // promiscuous mode. + //----------------------- + verifyAddress(t, s, nicid, noAddr) + if promiscuous { + testRecv(t, fakeNet, localAddrByte, linkEP, buf) + } else { + testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf) + } + if spoofing { + // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. + // testSendTo(t, s, remoteAddr, linkEP, nil) + } else { + testFailingSendTo(t, s, remoteAddr, linkEP, nil, tcpip.ErrNoRoute) + } + + // 2. Add Address, everything should work. + //----------------------- + if err := s.AddAddress(nicid, fakeNetNumber, localAddr); err != nil { + t.Fatal("AddAddress failed:", err) + } + verifyAddress(t, s, nicid, localAddr) + testRecv(t, fakeNet, localAddrByte, linkEP, buf) + testSendTo(t, s, remoteAddr, linkEP, nil) + + // 3. Remove the address, send should only work for spoofing, receive + // for promiscuous mode. + //----------------------- + if err := s.RemoveAddress(nicid, localAddr); err != nil { + t.Fatal("RemoveAddress failed:", err) + } + verifyAddress(t, s, nicid, noAddr) + if promiscuous { + testRecv(t, fakeNet, localAddrByte, linkEP, buf) + } else { + testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf) + } + if spoofing { + // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. + // testSendTo(t, s, remoteAddr, linkEP, nil) + } else { + testFailingSendTo(t, s, remoteAddr, linkEP, nil, tcpip.ErrNoRoute) + } + + // 4. Add Address back, everything should work again. + //----------------------- + if err := s.AddAddress(nicid, fakeNetNumber, localAddr); err != nil { + t.Fatal("AddAddress failed:", err) + } + verifyAddress(t, s, nicid, localAddr) + testRecv(t, fakeNet, localAddrByte, linkEP, buf) + testSendTo(t, s, remoteAddr, linkEP, nil) + + // 5. Take a reference to the endpoint by getting a route. Verify that + // we can still send/receive, including sending using the route. 
+ //----------------------- + r, err := s.FindRoute(0, "", remoteAddr, fakeNetNumber, false /* multicastLoop */) + if err != nil { + t.Fatal("FindRoute failed:", err) + } + testRecv(t, fakeNet, localAddrByte, linkEP, buf) + testSendTo(t, s, remoteAddr, linkEP, nil) + testSend(t, r, linkEP, nil) + + // 6. Remove the address. Send should only work for spoofing, receive + // for promiscuous mode. + //----------------------- + if err := s.RemoveAddress(nicid, localAddr); err != nil { + t.Fatal("RemoveAddress failed:", err) + } + verifyAddress(t, s, nicid, noAddr) + if promiscuous { + testRecv(t, fakeNet, localAddrByte, linkEP, buf) + } else { + testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf) + } + if spoofing { + testSend(t, r, linkEP, nil) + testSendTo(t, s, remoteAddr, linkEP, nil) + } else { + testFailingSend(t, r, linkEP, nil, tcpip.ErrInvalidEndpointState) + testFailingSendTo(t, s, remoteAddr, linkEP, nil, tcpip.ErrNoRoute) + } + + // 7. Add Address back, everything should work again. + //----------------------- + if err := s.AddAddress(nicid, fakeNetNumber, localAddr); err != nil { + t.Fatal("AddAddress failed:", err) + } + verifyAddress(t, s, nicid, localAddr) + testRecv(t, fakeNet, localAddrByte, linkEP, buf) + testSendTo(t, s, remoteAddr, linkEP, nil) + testSend(t, r, linkEP, nil) + + // 8. Remove the route, sendTo/recv should still work. + //----------------------- + r.Release() + verifyAddress(t, s, nicid, localAddr) + testRecv(t, fakeNet, localAddrByte, linkEP, buf) + testSendTo(t, s, remoteAddr, linkEP, nil) + + // 9. Remove the address. Send should only work for spoofing, receive + // for promiscuous mode. + //----------------------- + if err := s.RemoveAddress(nicid, localAddr); err != nil { + t.Fatal("RemoveAddress failed:", err) + } + verifyAddress(t, s, nicid, noAddr) + if promiscuous { + testRecv(t, fakeNet, localAddrByte, linkEP, buf) + } else { + testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf) + } + if spoofing { + // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. + // testSendTo(t, s, remoteAddr, linkEP, nil) + } else { + testFailingSendTo(t, s, remoteAddr, linkEP, nil, tcpip.ErrNoRoute) + } + }) + } } } @@ -583,9 +851,13 @@ func TestPromiscuousMode(t *testing.T) { t.Fatal("CreateNIC failed:", err) } - s.SetRouteTable([]tcpip.Route{ - {"\x00", "\x00", "\x00", 1}, - }) + { + subnet, err := tcpip.NewSubnet("\x00", "\x00") + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) + } fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) @@ -593,22 +865,15 @@ func TestPromiscuousMode(t *testing.T) { // Write a packet, and check that it doesn't get delivered as we don't // have a matching endpoint. - fakeNet.packetCount[1] = 0 - buf[0] = 1 - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) - if fakeNet.packetCount[1] != 0 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 0) - } + const localAddrByte byte = 0x01 + buf[0] = localAddrByte + testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf) // Set promiscuous mode, then check that packet is delivered. if err := s.SetPromiscuousMode(1, true); err != nil { t.Fatal("SetPromiscuousMode failed:", err) } - - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) - if fakeNet.packetCount[1] != 1 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) - } + testRecv(t, fakeNet, localAddrByte, linkEP, buf) // Check that we can't get a route as there is no local address. 
_, err := s.FindRoute(0, "", "\x02", fakeNetNumber, false /* multicastLoop */) @@ -621,54 +886,120 @@ func TestPromiscuousMode(t *testing.T) { if err := s.SetPromiscuousMode(1, false); err != nil { t.Fatal("SetPromiscuousMode failed:", err) } + testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf) +} - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) - if fakeNet.packetCount[1] != 1 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) +func TestSpoofingWithAddress(t *testing.T) { + localAddr := tcpip.Address("\x01") + nonExistentLocalAddr := tcpip.Address("\x02") + dstAddr := tcpip.Address("\x03") + + s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + + id, linkEP := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, id); err != nil { + t.Fatal("CreateNIC failed:", err) + } + + if err := s.AddAddress(1, fakeNetNumber, localAddr); err != nil { + t.Fatal("AddAddress failed:", err) + } + + { + subnet, err := tcpip.NewSubnet("\x00", "\x00") + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) + } + + // With address spoofing disabled, FindRoute does not permit an address + // that was not added to the NIC to be used as the source. + r, err := s.FindRoute(0, nonExistentLocalAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) + if err == nil { + t.Errorf("FindRoute succeeded with route %+v when it should have failed", r) + } + + // With address spoofing enabled, FindRoute permits any address to be used + // as the source. + if err := s.SetSpoofing(1, true); err != nil { + t.Fatal("SetSpoofing failed:", err) + } + r, err = s.FindRoute(0, nonExistentLocalAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) + if err != nil { + t.Fatal("FindRoute failed:", err) + } + if r.LocalAddress != nonExistentLocalAddr { + t.Errorf("Route has wrong local address: got %v, wanted %v", r.LocalAddress, nonExistentLocalAddr) + } + if r.RemoteAddress != dstAddr { + t.Errorf("Route has wrong remote address: got %v, wanted %v", r.RemoteAddress, dstAddr) } + // Sending a packet works. + testSendTo(t, s, dstAddr, linkEP, nil) + testSend(t, r, linkEP, nil) + + // FindRoute should also work with a local address that exists on the NIC. + r, err = s.FindRoute(0, localAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) + if err != nil { + t.Fatal("FindRoute failed:", err) + } + if r.LocalAddress != localAddr { + t.Errorf("Route has wrong local address: got %v, wanted %v", r.LocalAddress, nonExistentLocalAddr) + } + if r.RemoteAddress != dstAddr { + t.Errorf("Route has wrong remote address: got %v, wanted %v", r.RemoteAddress, dstAddr) + } + // Sending a packet using the route works. 
+ testSend(t, r, linkEP, nil) } -func TestAddressSpoofing(t *testing.T) { - srcAddr := tcpip.Address("\x01") +func TestSpoofingNoAddress(t *testing.T) { + nonExistentLocalAddr := tcpip.Address("\x01") dstAddr := tcpip.Address("\x02") s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, _ := channel.New(10, defaultMTU, "") + id, linkEP := channel.New(10, defaultMTU, "") if err := s.CreateNIC(1, id); err != nil { t.Fatal("CreateNIC failed:", err) } - if err := s.AddAddress(1, fakeNetNumber, dstAddr); err != nil { - t.Fatal("AddAddress failed:", err) + { + subnet, err := tcpip.NewSubnet("\x00", "\x00") + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) } - s.SetRouteTable([]tcpip.Route{ - {"\x00", "\x00", "\x00", 1}, - }) - // With address spoofing disabled, FindRoute does not permit an address // that was not added to the NIC to be used as the source. - r, err := s.FindRoute(0, srcAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) + r, err := s.FindRoute(0, nonExistentLocalAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) if err == nil { t.Errorf("FindRoute succeeded with route %+v when it should have failed", r) } + // Sending a packet fails. + testFailingSendTo(t, s, dstAddr, linkEP, nil, tcpip.ErrNoRoute) // With address spoofing enabled, FindRoute permits any address to be used // as the source. if err := s.SetSpoofing(1, true); err != nil { t.Fatal("SetSpoofing failed:", err) } - r, err = s.FindRoute(0, srcAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) + r, err = s.FindRoute(0, nonExistentLocalAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) if err != nil { t.Fatal("FindRoute failed:", err) } - if r.LocalAddress != srcAddr { - t.Errorf("Route has wrong local address: got %v, wanted %v", r.LocalAddress, srcAddr) + if r.LocalAddress != nonExistentLocalAddr { + t.Errorf("Route has wrong local address: got %v, wanted %v", r.LocalAddress, nonExistentLocalAddr) } if r.RemoteAddress != dstAddr { t.Errorf("Route has wrong remote address: got %v, wanted %v", r.RemoteAddress, dstAddr) } + // Sending a packet works. + // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. + // testSendTo(t, s, remoteAddr, linkEP, nil) } func TestBroadcastNeedsNoRoute(t *testing.T) { @@ -806,16 +1137,20 @@ func TestSubnetAcceptsMatchingPacket(t *testing.T) { t.Fatal("CreateNIC failed:", err) } - s.SetRouteTable([]tcpip.Route{ - {"\x00", "\x00", "\x00", 1}, - }) + { + subnet, err := tcpip.NewSubnet("\x00", "\x00") + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) + } fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) buf := buffer.NewView(30) - buf[0] = 1 - fakeNet.packetCount[1] = 0 + const localAddrByte byte = 0x01 + buf[0] = localAddrByte subnet, err := tcpip.NewSubnet(tcpip.Address("\x00"), tcpip.AddressMask("\xF0")) if err != nil { t.Fatal("NewSubnet failed:", err) @@ -824,9 +1159,52 @@ func TestSubnetAcceptsMatchingPacket(t *testing.T) { t.Fatal("AddSubnet failed:", err) } - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) - if fakeNet.packetCount[1] != 1 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) + testRecv(t, fakeNet, localAddrByte, linkEP, buf) +} + +// Set the subnet, then check that CheckLocalAddress returns the correct NIC. 
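+// The subnet used below is \xa0/\xf0: a 4-bit prefix on a one-byte address,
+// so it spans the 1<<(8-4) = 16 addresses 0xa0 through 0xaf. Every address
+// in that range should resolve to the NIC, and the next one (0xb0) should
+// not.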
+func TestCheckLocalAddressForSubnet(t *testing.T) { + const nicID tcpip.NICID = 1 + s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + + id, _ := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(nicID, id); err != nil { + t.Fatal("CreateNIC failed:", err) + } + + { + subnet, err := tcpip.NewSubnet("\x00", "\x00") + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: nicID}}) + } + + subnet, err := tcpip.NewSubnet(tcpip.Address("\xa0"), tcpip.AddressMask("\xf0")) + + if err != nil { + t.Fatal("NewSubnet failed:", err) + } + if err := s.AddSubnet(nicID, fakeNetNumber, subnet); err != nil { + t.Fatal("AddSubnet failed:", err) + } + + // Loop over all subnet addresses and check them. + numOfAddresses := 1 << uint(8-subnet.Prefix()) + if numOfAddresses < 1 || numOfAddresses > 255 { + t.Fatalf("got numOfAddresses = %d, want = [1 .. 255] (subnet=%s)", numOfAddresses, subnet) + } + addr := []byte(subnet.ID()) + for i := 0; i < numOfAddresses; i++ { + if gotNicID := s.CheckLocalAddress(0, fakeNetNumber, tcpip.Address(addr)); gotNicID != nicID { + t.Errorf("got CheckLocalAddress(0, %d, %s) = %d, want = %d", fakeNetNumber, tcpip.Address(addr), gotNicID, nicID) + } + addr[0]++ + } + + // Trying the next address should fail since it is outside the subnet range. + if gotNicID := s.CheckLocalAddress(0, fakeNetNumber, tcpip.Address(addr)); gotNicID != 0 { + t.Errorf("got CheckLocalAddress(0, %d, %s) = %d, want = %d", fakeNetNumber, tcpip.Address(addr), gotNicID, 0) } } @@ -839,16 +1217,20 @@ func TestSubnetRejectsNonmatchingPacket(t *testing.T) { t.Fatal("CreateNIC failed:", err) } - s.SetRouteTable([]tcpip.Route{ - {"\x00", "\x00", "\x00", 1}, - }) + { + subnet, err := tcpip.NewSubnet("\x00", "\x00") + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) + } fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) buf := buffer.NewView(30) - buf[0] = 1 - fakeNet.packetCount[1] = 0 + const localAddrByte byte = 0x01 + buf[0] = localAddrByte subnet, err := tcpip.NewSubnet(tcpip.Address("\x10"), tcpip.AddressMask("\xF0")) if err != nil { t.Fatal("NewSubnet failed:", err) @@ -856,10 +1238,7 @@ func TestSubnetRejectsNonmatchingPacket(t *testing.T) { if err := s.AddSubnet(1, fakeNetNumber, subnet); err != nil { t.Fatal("AddSubnet failed:", err) } - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) - if fakeNet.packetCount[1] != 0 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 0) - } + testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf) } func TestNetworkOptions(t *testing.T) { @@ -969,12 +1348,18 @@ func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) { // prefixLen. address := tcpip.Address(bytes.Repeat([]byte{byte(i)}, addrLen)) if behavior == stack.CanBePrimaryEndpoint { - addressWithPrefix := tcpip.AddressWithPrefix{address, addrLen * 8} - if err := s.AddAddressWithPrefixAndOptions(1, fakeNetNumber, addressWithPrefix, behavior); err != nil { - t.Fatal("AddAddressWithPrefixAndOptions failed:", err) + protocolAddress := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: address, + PrefixLen: addrLen * 8, + }, + } + if err := s.AddProtocolAddressWithOptions(1, protocolAddress, behavior); err != nil { + t.Fatal("AddProtocolAddressWithOptions failed:", err) } // Remember the address/prefix. 
- primaryAddrAdded[addressWithPrefix] = struct{}{} + primaryAddrAdded[protocolAddress.AddressWithPrefix] = struct{}{} } else { if err := s.AddAddressWithOptions(1, fakeNetNumber, address, behavior); err != nil { t.Fatal("AddAddressWithOptions failed:", err) @@ -1024,20 +1409,25 @@ func TestGetMainNICAddressAddRemove(t *testing.T) { {"IPv6", "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", 116}, } { t.Run(tc.name, func(t *testing.T) { - addressWithPrefix := tcpip.AddressWithPrefix{tc.address, tc.prefixLen} - - if err := s.AddAddressWithPrefix(1, fakeNetNumber, addressWithPrefix); err != nil { - t.Fatal("AddAddressWithPrefix failed:", err) + protocolAddress := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: tc.address, + PrefixLen: tc.prefixLen, + }, + } + if err := s.AddProtocolAddress(1, protocolAddress); err != nil { + t.Fatal("AddProtocolAddress failed:", err) } // Check that we get the right initial address and prefix length. if gotAddressWithPrefix, err := s.GetMainNICAddress(1, fakeNetNumber); err != nil { t.Fatal("GetMainNICAddress failed:", err) - } else if gotAddressWithPrefix != addressWithPrefix { - t.Fatalf("got GetMainNICAddress = %+v, want = %+v", gotAddressWithPrefix, addressWithPrefix) + } else if gotAddressWithPrefix != protocolAddress.AddressWithPrefix { + t.Fatalf("got GetMainNICAddress = %+v, want = %+v", gotAddressWithPrefix, protocolAddress.AddressWithPrefix) } - if err := s.RemoveAddress(1, addressWithPrefix.Address); err != nil { + if err := s.RemoveAddress(1, protocolAddress.AddressWithPrefix.Address); err != nil { t.Fatal("RemoveAddress failed:", err) } @@ -1102,7 +1492,7 @@ func TestAddAddress(t *testing.T) { verifyAddresses(t, expectedAddresses, gotAddresses) } -func TestAddAddressWithPrefix(t *testing.T) { +func TestAddProtocolAddress(t *testing.T) { const nicid = 1 s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) id, _ := channel.New(10, defaultMTU, "") @@ -1116,14 +1506,17 @@ func TestAddAddressWithPrefix(t *testing.T) { expectedAddresses := make([]tcpip.ProtocolAddress, 0, len(addrLenRange)*len(prefixLenRange)) for _, addrLen := range addrLenRange { for _, prefixLen := range prefixLenRange { - address := addrGen.next(addrLen) - if err := s.AddAddressWithPrefix(nicid, fakeNetNumber, tcpip.AddressWithPrefix{address, prefixLen}); err != nil { - t.Errorf("AddAddressWithPrefix(address=%s, prefixLen=%d) failed: %s", address, prefixLen, err) + protocolAddress := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: addrGen.next(addrLen), + PrefixLen: prefixLen, + }, } - expectedAddresses = append(expectedAddresses, tcpip.ProtocolAddress{ - Protocol: fakeNetNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{address, prefixLen}, - }) + if err := s.AddProtocolAddress(nicid, protocolAddress); err != nil { + t.Errorf("AddProtocolAddress(%+v) failed: %s", protocolAddress, err) + } + expectedAddresses = append(expectedAddresses, protocolAddress) } } @@ -1160,7 +1553,7 @@ func TestAddAddressWithOptions(t *testing.T) { verifyAddresses(t, expectedAddresses, gotAddresses) } -func TestAddAddressWithPrefixAndOptions(t *testing.T) { +func TestAddProtocolAddressWithOptions(t *testing.T) { const nicid = 1 s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) id, _ := channel.New(10, defaultMTU, "") @@ -1176,14 +1569,17 @@ func TestAddAddressWithPrefixAndOptions(t *testing.T) { for _, addrLen := range addrLenRange { for _, prefixLen := 
range prefixLenRange { for _, behavior := range behaviorRange { - address := addrGen.next(addrLen) - if err := s.AddAddressWithPrefixAndOptions(nicid, fakeNetNumber, tcpip.AddressWithPrefix{address, prefixLen}, behavior); err != nil { - t.Fatalf("AddAddressWithPrefixAndOptions(address=%s, prefixLen=%d, behavior=%d) failed: %s", address, prefixLen, behavior, err) + protocolAddress := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: addrGen.next(addrLen), + PrefixLen: prefixLen, + }, } - expectedAddresses = append(expectedAddresses, tcpip.ProtocolAddress{ - Protocol: fakeNetNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{address, prefixLen}, - }) + if err := s.AddProtocolAddressWithOptions(nicid, protocolAddress, behavior); err != nil { + t.Fatalf("AddProtocolAddressWithOptions(%+v, %d) failed: %s", protocolAddress, behavior, err) + } + expectedAddresses = append(expectedAddresses, protocolAddress) } } } @@ -1196,15 +1592,19 @@ func TestNICStats(t *testing.T) { s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) id1, linkEP1 := channel.New(10, defaultMTU, "") if err := s.CreateNIC(1, id1); err != nil { - t.Fatal("CreateNIC failed:", err) + t.Fatal("CreateNIC failed: ", err) } if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { t.Fatal("AddAddress failed:", err) } // Route all packets for address \x01 to NIC 1. - s.SetRouteTable([]tcpip.Route{ - {"\x01", "\xff", "\x00", 1}, - }) + { + subnet, err := tcpip.NewSubnet("\x01", "\xff") + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) + } // Send a packet to address 1. buf := buffer.NewView(30) @@ -1219,7 +1619,9 @@ func TestNICStats(t *testing.T) { payload := buffer.NewView(10) // Write a packet out via the address for NIC 1 - sendTo(t, s, "\x01", payload) + if err := sendTo(s, "\x01", payload); err != nil { + t.Fatal("sendTo failed: ", err) + } want := uint64(linkEP1.Drain()) if got := s.NICInfo()[1].Stats.Tx.Packets.Value(); got != want { t.Errorf("got Tx.Packets.Value() = %d, linkEP1.Drain() = %d", got, want) @@ -1253,9 +1655,13 @@ func TestNICForwarding(t *testing.T) { } // Route all packets to address 3 to NIC 2. - s.SetRouteTable([]tcpip.Route{ - {"\x03", "\xff", "\x00", 2}, - }) + { + subnet, err := tcpip.NewSubnet("\x03", "\xff") + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 2}}) + } // Send a packet to address 3. 
buf := buffer.NewView(30) diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index b418db046..5335897f5 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -19,6 +19,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/iptables" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/link/loopback" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -64,7 +65,7 @@ func (*fakeTransportEndpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.Contr return buffer.View{}, tcpip.ControlMessages{}, nil } -func (f *fakeTransportEndpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-chan struct{}, *tcpip.Error) { +func (f *fakeTransportEndpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { if len(f.route.RemoteAddress) == 0 { return 0, nil, tcpip.ErrNoRoute } @@ -78,10 +79,10 @@ func (f *fakeTransportEndpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) return 0, nil, err } - return uintptr(len(v)), nil, nil + return int64(len(v)), nil, nil } -func (f *fakeTransportEndpoint) Peek([][]byte) (uintptr, tcpip.ControlMessages, *tcpip.Error) { +func (f *fakeTransportEndpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { return 0, tcpip.ControlMessages{}, nil } @@ -104,6 +105,11 @@ func (*fakeTransportEndpoint) GetSockOpt(opt interface{}) *tcpip.Error { return tcpip.ErrInvalidEndpointState } +// Disconnect implements tcpip.Endpoint.Disconnect. +func (*fakeTransportEndpoint) Disconnect() *tcpip.Error { + return tcpip.ErrNotSupported +} + func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { f.peerAddr = addr.Addr @@ -200,6 +206,13 @@ func (f *fakeTransportEndpoint) State() uint32 { func (f *fakeTransportEndpoint) ModerateRecvBuf(copied int) { } +func (f *fakeTransportEndpoint) IPTables() (iptables.IPTables, error) { + return iptables.IPTables{}, nil +} + +func (f *fakeTransportEndpoint) Resume(*stack.Stack) { +} + type fakeTransportGoodOption bool type fakeTransportBadOption bool @@ -271,7 +284,13 @@ func TestTransportReceive(t *testing.T) { t.Fatalf("CreateNIC failed: %v", err) } - s.SetRouteTable([]tcpip.Route{{"\x00", "\x00", "\x00", 1}}) + { + subnet, err := tcpip.NewSubnet("\x00", "\x00") + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) + } if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { t.Fatalf("AddAddress failed: %v", err) @@ -327,7 +346,13 @@ func TestTransportControlReceive(t *testing.T) { t.Fatalf("CreateNIC failed: %v", err) } - s.SetRouteTable([]tcpip.Route{{"\x00", "\x00", "\x00", 1}}) + { + subnet, err := tcpip.NewSubnet("\x00", "\x00") + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) + } if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { t.Fatalf("AddAddress failed: %v", err) @@ -393,7 +418,13 @@ func TestTransportSend(t *testing.T) { t.Fatalf("AddAddress failed: %v", err) } - s.SetRouteTable([]tcpip.Route{{"\x00", "\x00", "\x00", 1}}) + { + subnet, err := tcpip.NewSubnet("\x00", "\x00") + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) + } // Create endpoint and bind it. 
wq := waiter.Queue{} @@ -484,10 +515,20 @@ func TestTransportForwarding(t *testing.T) { // Route all packets to address 3 to NIC 2 and all packets to address // 1 to NIC 1. - s.SetRouteTable([]tcpip.Route{ - {"\x03", "\xff", "\x00", 2}, - {"\x01", "\xff", "\x00", 1}, - }) + { + subnet0, err := tcpip.NewSubnet("\x03", "\xff") + if err != nil { + t.Fatal(err) + } + subnet1, err := tcpip.NewSubnet("\x01", "\xff") + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{ + {Destination: subnet0, Gateway: "\x00", NIC: 2}, + {Destination: subnet1, Gateway: "\x00", NIC: 1}, + }) + } wq := waiter.Queue{} ep, err := s.NewEndpoint(fakeTransNumber, fakeNetNumber, &wq) diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 4208c0303..8f9b86cce 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -31,6 +31,7 @@ package tcpip import ( "errors" "fmt" + "math/bits" "reflect" "strconv" "strings" @@ -39,6 +40,7 @@ import ( "time" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/iptables" "gvisor.dev/gvisor/pkg/waiter" ) @@ -144,8 +146,17 @@ type Address string type AddressMask string // String implements Stringer. -func (a AddressMask) String() string { - return Address(a).String() +func (m AddressMask) String() string { + return Address(m).String() +} + +// Prefix returns the number of bits before the first host bit. +func (m AddressMask) Prefix() int { + p := 0 + for _, b := range []byte(m) { + p += bits.LeadingZeros8(^b) + } + return p } // Subnet is a subnet defined by its address and mask. @@ -167,6 +178,11 @@ func NewSubnet(a Address, m AddressMask) (Subnet, error) { return Subnet{a, m}, nil } +// String implements Stringer. +func (s Subnet) String() string { + return fmt.Sprintf("%s/%d", s.ID(), s.Prefix()) +} + // Contains returns true iff the address is of the same length and matches the // subnet address and mask. func (s *Subnet) Contains(a Address) bool { @@ -189,28 +205,13 @@ func (s *Subnet) ID() Address { // Bits returns the number of ones (network bits) and zeros (host bits) in the // subnet mask. func (s *Subnet) Bits() (ones int, zeros int) { - for _, b := range []byte(s.mask) { - for i := uint(0); i < 8; i++ { - if b&(1<<i) == 0 { - zeros++ - } else { - ones++ - } - } - } - return + ones = s.mask.Prefix() + return ones, len(s.mask)*8 - ones } // Prefix returns the number of bits before the first host bit. func (s *Subnet) Prefix() int { - for i, b := range []byte(s.mask) { - for j := 7; j >= 0; j-- { - if b&(1<<uint(j)) == 0 { - return i*8 + 7 - j - } - } - } - return len(s.mask) * 8 + return s.mask.Prefix() } // Mask returns the subnet mask. @@ -328,12 +329,12 @@ type Endpoint interface { // ErrNoLinkAddress and a notification channel is returned for the caller to // block. Channel is closed once address resolution is complete (success or // not). The channel is only non-nil in this case. - Write(Payload, WriteOptions) (uintptr, <-chan struct{}, *Error) + Write(Payload, WriteOptions) (int64, <-chan struct{}, *Error) // Peek reads data without consuming it from the endpoint. // // This method does not block if there is no data pending. - Peek([][]byte) (uintptr, ControlMessages, *Error) + Peek([][]byte) (int64, ControlMessages, *Error) // Connect connects the endpoint to its peer. Specifying a NIC is // optional. @@ -352,6 +353,9 @@ type Endpoint interface { // ErrAddressFamilyNotSupported must be returned. Connect(address FullAddress) *Error + // Disconnect disconnects the endpoint from its peer. 
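	//
	// Endpoints with no meaningful disconnect operation (for example the
	// icmp and raw endpoints in this change) return ErrNotSupported.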
+ Disconnect() *Error + // Shutdown closes the read and/or write end of the endpoint connection // to its peer. Shutdown(flags ShutdownFlags) *Error @@ -403,6 +407,9 @@ type Endpoint interface { // // NOTE: This method is a no-op for sockets other than TCP. ModerateRecvBuf(copied int) + + // IPTables returns the iptables for this endpoint's stack. + IPTables() (iptables.IPTables, error) } // WriteOptions contains options for Endpoint.Write. @@ -563,13 +570,8 @@ type BroadcastOption int // gateway) sets of packets should be routed. A row is considered viable if the // masked target address matches the destination address in the row. type Route struct { - // Destination is the address that must be matched against the masked - // target address to check if this row is viable. - Destination Address - - // Mask specifies which bits of the Destination and the target address - // must match for this row to be viable. - Mask AddressMask + // Destination must contain the target address for this row to be viable. + Destination Subnet // Gateway is the gateway to be used if this row is viable. Gateway Address @@ -578,25 +580,15 @@ type Route struct { NIC NICID } -// Match determines if r is viable for the given destination address. -func (r *Route) Match(addr Address) bool { - if len(addr) != len(r.Destination) { - return false - } - - // Using header.Ipv4Broadcast would introduce an import cycle, so - // we'll use a literal instead. - if addr == "\xff\xff\xff\xff" { - return true - } - - for i := 0; i < len(r.Destination); i++ { - if (addr[i] & r.Mask[i]) != r.Destination[i] { - return false - } +// String implements the fmt.Stringer interface. +func (r Route) String() string { + var out strings.Builder + fmt.Fprintf(&out, "%s", r.Destination) + if len(r.Gateway) > 0 { + fmt.Fprintf(&out, " via %s", r.Gateway) } - - return true + fmt.Fprintf(&out, " nic %d", r.NIC) + return out.String() } // LinkEndpointID represents a data link layer endpoint. @@ -1068,6 +1060,11 @@ type AddressWithPrefix struct { PrefixLen int } +// String implements the fmt.Stringer interface. +func (a AddressWithPrefix) String() string { + return fmt.Sprintf("%s/%d", a.Address, a.PrefixLen) +} + // ProtocolAddress is an address and the network protocol it is associated // with. type ProtocolAddress struct { @@ -1078,11 +1075,13 @@ type ProtocolAddress struct { AddressWithPrefix AddressWithPrefix } -// danglingEndpointsMu protects access to danglingEndpoints. -var danglingEndpointsMu sync.Mutex +var ( + // danglingEndpointsMu protects access to danglingEndpoints. + danglingEndpointsMu sync.Mutex -// danglingEndpoints tracks all dangling endpoints no longer owned by the app. -var danglingEndpoints = make(map[Endpoint]struct{}) + // danglingEndpoints tracks all dangling endpoints no longer owned by the app. + danglingEndpoints = make(map[Endpoint]struct{}) +) // GetDanglingEndpoints returns all dangling endpoints. 
func GetDanglingEndpoints() []Endpoint { diff --git a/pkg/tcpip/tcpip_test.go b/pkg/tcpip/tcpip_test.go index ebb1c1b56..fb3a0a5ee 100644 --- a/pkg/tcpip/tcpip_test.go +++ b/pkg/tcpip/tcpip_test.go @@ -60,12 +60,12 @@ func TestSubnetBits(t *testing.T) { }{ {"\x00", 0, 8}, {"\x00\x00", 0, 16}, - {"\x36", 4, 4}, - {"\x5c", 4, 4}, - {"\x5c\x5c", 8, 8}, - {"\x5c\x36", 8, 8}, - {"\x36\x5c", 8, 8}, - {"\x36\x36", 8, 8}, + {"\x36", 0, 8}, + {"\x5c", 0, 8}, + {"\x5c\x5c", 0, 16}, + {"\x5c\x36", 0, 16}, + {"\x36\x5c", 0, 16}, + {"\x36\x36", 0, 16}, {"\xff", 8, 0}, {"\xff\xff", 16, 0}, } @@ -122,26 +122,6 @@ func TestSubnetCreation(t *testing.T) { } } -func TestRouteMatch(t *testing.T) { - tests := []struct { - d Address - m AddressMask - a Address - want bool - }{ - {"\xc2\x80", "\xff\xf0", "\xc2\x80", true}, - {"\xc2\x80", "\xff\xf0", "\xc2\x00", false}, - {"\xc2\x00", "\xff\xf0", "\xc2\x00", true}, - {"\xc2\x00", "\xff\xf0", "\xc2\x80", false}, - } - for _, tt := range tests { - r := Route{Destination: tt.d, Mask: tt.m} - if got := r.Match(tt.a); got != tt.want { - t.Errorf("Route(%v).Match(%v) = %v, want %v", r, tt.a, got, tt.want) - } - } -} - func TestAddressString(t *testing.T) { for _, want := range []string{ // Taken from stdlib. diff --git a/pkg/tcpip/transport/icmp/BUILD b/pkg/tcpip/transport/icmp/BUILD index 62182a3e6..d78a162b8 100644 --- a/pkg/tcpip/transport/icmp/BUILD +++ b/pkg/tcpip/transport/icmp/BUILD @@ -31,6 +31,7 @@ go_library( "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", + "//pkg/tcpip/iptables", "//pkg/tcpip/stack", "//pkg/tcpip/transport/raw", "//pkg/tcpip/transport/tcp", diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index ba6671c26..451d3880e 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -21,6 +21,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/iptables" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/waiter" ) @@ -130,6 +131,11 @@ func (e *endpoint) Close() { // ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf. func (e *endpoint) ModerateRecvBuf(copied int) {} +// IPTables implements tcpip.Endpoint.IPTables. +func (e *endpoint) IPTables() (iptables.IPTables, error) { + return e.stack.IPTables(), nil +} + // Read reads data from the endpoint. This method does not block if // there is no data pending. func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { @@ -199,7 +205,7 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi // Write writes data to the endpoint's peer. This method does not block // if the data cannot be written. -func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-chan struct{}, *tcpip.Error) { +func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { // MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.) if opts.More { return 0, nil, tcpip.ErrInvalidOptionValue @@ -301,11 +307,11 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-c return 0, nil, err } - return uintptr(len(v)), nil, nil + return int64(len(v)), nil, nil } // Peek only returns data from a single datagram, so do nothing here. 
-func (e *endpoint) Peek([][]byte) (uintptr, tcpip.ControlMessages, *tcpip.Error) { +func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { return 0, tcpip.ControlMessages{}, nil } @@ -422,16 +428,16 @@ func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (t return netProto, nil } +// Disconnect implements tcpip.Endpoint.Disconnect. +func (*endpoint) Disconnect() *tcpip.Error { + return tcpip.ErrNotSupported +} + // Connect connects the endpoint to its peer. Specifying a NIC is optional. func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { e.mu.Lock() defer e.mu.Unlock() - if addr.Addr == "" { - // AF_UNSPEC isn't supported. - return tcpip.ErrAddressFamilyNotSupported - } - nicid := addr.NIC localPort := uint16(0) switch e.state { diff --git a/pkg/tcpip/transport/icmp/endpoint_state.go b/pkg/tcpip/transport/icmp/endpoint_state.go index 99b8c4093..c587b96b6 100644 --- a/pkg/tcpip/transport/icmp/endpoint_state.go +++ b/pkg/tcpip/transport/icmp/endpoint_state.go @@ -63,7 +63,12 @@ func (e *endpoint) loadRcvBufSizeMax(max int) { // afterLoad is invoked by stateify. func (e *endpoint) afterLoad() { - e.stack = stack.StackFromEnv + stack.StackFromEnv.RegisterRestoredEndpoint(e) +} + +// Resume implements tcpip.ResumableEndpoint.Resume. +func (e *endpoint) Resume(s *stack.Stack) { + e.stack = s if e.state != stateBound && e.state != stateConnected { return @@ -73,7 +78,7 @@ func (e *endpoint) afterLoad() { if e.state == stateConnected { e.route, err = e.stack.FindRoute(e.regNICID, e.bindAddr, e.id.RemoteAddress, e.netProto, false /* multicastLoop */) if err != nil { - panic(*err) + panic(err) } e.id.LocalAddress = e.route.LocalAddress @@ -85,6 +90,6 @@ func (e *endpoint) afterLoad() { e.id, err = e.registerWithStack(e.regNICID, []tcpip.NetworkProtocolNumber{e.netProto}, e.id) if err != nil { - panic(*err) + panic(err) } } diff --git a/pkg/tcpip/transport/raw/BUILD b/pkg/tcpip/transport/raw/BUILD index bc4b255b4..7241f6c19 100644 --- a/pkg/tcpip/transport/raw/BUILD +++ b/pkg/tcpip/transport/raw/BUILD @@ -32,6 +32,7 @@ go_library( "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", + "//pkg/tcpip/iptables", "//pkg/tcpip/stack", "//pkg/waiter", ], diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index b633cd9d8..13e17e2a6 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -32,6 +32,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/iptables" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/waiter" ) @@ -168,6 +169,11 @@ func (ep *endpoint) Close() { // ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf. func (ep *endpoint) ModerateRecvBuf(copied int) {} +// IPTables implements tcpip.Endpoint.IPTables. +func (ep *endpoint) IPTables() (iptables.IPTables, error) { + return ep.stack.IPTables(), nil +} + // Read implements tcpip.Endpoint.Read. func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { if !ep.associated { @@ -201,7 +207,7 @@ func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMes } // Write implements tcpip.Endpoint.Write. 
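+// (The byte count is returned as int64 rather than uintptr, matching the
+// tcpip.Endpoint interface change elsewhere in this diff.)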
-func (ep *endpoint) Write(payload tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-chan struct{}, *tcpip.Error) { +func (ep *endpoint) Write(payload tcpip.Payload, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { // MSG_MORE is unimplemented. This also means that MSG_EOR is a no-op. if opts.More { return 0, nil, tcpip.ErrInvalidOptionValue @@ -305,7 +311,7 @@ func (ep *endpoint) Write(payload tcpip.Payload, opts tcpip.WriteOptions) (uintp // finishWrite writes the payload to a route. It resolves the route if // necessary. It's really just a helper to make defer unnecessary in Write. -func (ep *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (uintptr, <-chan struct{}, *tcpip.Error) { +func (ep *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, <-chan struct{}, *tcpip.Error) { // We may need to resolve the route (match a link layer address to the // network address). If that requires blocking (e.g. to use ARP), // return a channel on which the caller can wait. @@ -335,24 +341,24 @@ func (ep *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (uintpt return 0, nil, tcpip.ErrUnknownProtocol } - return uintptr(len(payloadBytes)), nil, nil + return int64(len(payloadBytes)), nil, nil } // Peek implements tcpip.Endpoint.Peek. -func (ep *endpoint) Peek([][]byte) (uintptr, tcpip.ControlMessages, *tcpip.Error) { +func (ep *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { return 0, tcpip.ControlMessages{}, nil } +// Disconnect implements tcpip.Endpoint.Disconnect. +func (*endpoint) Disconnect() *tcpip.Error { + return tcpip.ErrNotSupported +} + // Connect implements tcpip.Endpoint.Connect. func (ep *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { ep.mu.Lock() defer ep.mu.Unlock() - if addr.Addr == "" { - // AF_UNSPEC isn't supported. - return tcpip.ErrAddressFamilyNotSupported - } - if ep.closed { return tcpip.ErrInvalidEndpointState } @@ -484,7 +490,7 @@ func (ep *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // SetSockOpt implements tcpip.Endpoint.SetSockOpt. func (ep *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { - return nil + return tcpip.ErrUnknownProtocolOption } // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. diff --git a/pkg/tcpip/transport/raw/endpoint_state.go b/pkg/tcpip/transport/raw/endpoint_state.go index cb5534d90..168953dec 100644 --- a/pkg/tcpip/transport/raw/endpoint_state.go +++ b/pkg/tcpip/transport/raw/endpoint_state.go @@ -63,19 +63,23 @@ func (ep *endpoint) loadRcvBufSizeMax(max int) { // afterLoad is invoked by stateify. func (ep *endpoint) afterLoad() { - // StackFromEnv is a stack used specifically for save/restore. - ep.stack = stack.StackFromEnv + stack.StackFromEnv.RegisterRestoredEndpoint(ep) +} + +// Resume implements tcpip.ResumableEndpoint.Resume. +func (ep *endpoint) Resume(s *stack.Stack) { + ep.stack = s - // If the endpoint is connected, re-connect via the save/restore stack. + // If the endpoint is connected, re-connect. if ep.connected { var err *tcpip.Error ep.route, err = ep.stack.FindRoute(ep.registeredNIC, ep.boundAddr, ep.route.RemoteAddress, ep.netProto, false) if err != nil { - panic(*err) + panic(err) } } - // If the endpoint is bound, re-bind via the save/restore stack. + // If the endpoint is bound, re-bind. 
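+	// (Both the local-address check and the re-registration below panic
+	// on failure, mirroring the connected case above.)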
if ep.bound { if ep.stack.CheckLocalAddress(ep.registeredNIC, ep.netProto, ep.boundAddr) == 0 { panic(tcpip.ErrBadLocalAddress) @@ -83,6 +87,6 @@ func (ep *endpoint) afterLoad() { } if err := ep.stack.RegisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep); err != nil { - panic(*err) + panic(err) } } diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index 4cd25e8e2..1ee1a53f8 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -48,6 +48,7 @@ go_library( "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", + "//pkg/tcpip/iptables", "//pkg/tcpip/seqnum", "//pkg/tcpip/stack", "//pkg/tcpip/transport/raw", diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 52fd1bfa3..e9c5099ea 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -96,6 +96,17 @@ type listenContext struct { hasher hash.Hash v6only bool netProto tcpip.NetworkProtocolNumber + // pendingMu protects pendingEndpoints. This should only be accessed + // by the listening endpoint's worker goroutine. + // + // Lock Ordering: listenEP.workerMu -> pendingMu + pendingMu sync.Mutex + // pending is used to wait for all pendingEndpoints to finish when + // a socket is closed. + pending sync.WaitGroup + // pendingEndpoints is a map of all endpoints for which a handshake is + // in progress. + pendingEndpoints map[stack.TransportEndpointID]*endpoint } // timeStamp returns an 8-bit timestamp with a granularity of 64 seconds. @@ -133,14 +144,15 @@ func decSynRcvdCount() { } // newListenContext creates a new listen context. -func newListenContext(stack *stack.Stack, listenEP *endpoint, rcvWnd seqnum.Size, v6only bool, netProto tcpip.NetworkProtocolNumber) *listenContext { +func newListenContext(stk *stack.Stack, listenEP *endpoint, rcvWnd seqnum.Size, v6only bool, netProto tcpip.NetworkProtocolNumber) *listenContext { l := &listenContext{ - stack: stack, - rcvWnd: rcvWnd, - hasher: sha1.New(), - v6only: v6only, - netProto: netProto, - listenEP: listenEP, + stack: stk, + rcvWnd: rcvWnd, + hasher: sha1.New(), + v6only: v6only, + netProto: netProto, + listenEP: listenEP, + pendingEndpoints: make(map[stack.TransportEndpointID]*endpoint), } rand.Read(l.nonce[0][:]) @@ -253,6 +265,17 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head return nil, err } + // listenEP is nil when listenContext is used by tcp.Forwarder. + if l.listenEP != nil { + l.listenEP.mu.Lock() + if l.listenEP.state != StateListen { + l.listenEP.mu.Unlock() + return nil, tcpip.ErrConnectionAborted + } + l.addPendingEndpoint(ep) + l.listenEP.mu.Unlock() + } + // Perform the 3-way handshake. 
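+	// If the handshake below fails, the endpoint is closed and, when there
+	// is a listening endpoint, removed from the pending set so that
+	// closeAllPendingEndpoints does not wait on it.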
h := newHandshake(ep, seqnum.Size(ep.initialReceiveWindow())) @@ -260,6 +283,9 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head if err := h.execute(); err != nil { ep.stack.Stats().TCP.FailedConnectionAttempts.Increment() ep.Close() + if l.listenEP != nil { + l.removePendingEndpoint(ep) + } return nil, err } ep.mu.Lock() @@ -274,15 +300,41 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head return ep, nil } +func (l *listenContext) addPendingEndpoint(n *endpoint) { + l.pendingMu.Lock() + l.pendingEndpoints[n.id] = n + l.pending.Add(1) + l.pendingMu.Unlock() +} + +func (l *listenContext) removePendingEndpoint(n *endpoint) { + l.pendingMu.Lock() + delete(l.pendingEndpoints, n.id) + l.pending.Done() + l.pendingMu.Unlock() +} + +func (l *listenContext) closeAllPendingEndpoints() { + l.pendingMu.Lock() + for _, n := range l.pendingEndpoints { + n.notifyProtocolGoroutine(notifyClose) + } + l.pendingMu.Unlock() + l.pending.Wait() +} + // deliverAccepted delivers the newly-accepted endpoint to the listener. If the // endpoint has transitioned out of the listen state, the new endpoint is closed // instead. func (e *endpoint) deliverAccepted(n *endpoint) { - e.mu.RLock() + e.mu.Lock() state := e.state - e.mu.RUnlock() + e.pendingAccepted.Add(1) + defer e.pendingAccepted.Done() + acceptedChan := e.acceptedChan + e.mu.Unlock() if state == StateListen { - e.acceptedChan <- n + acceptedChan <- n e.waiterQueue.Notify(waiter.EventIn) } else { n.Close() @@ -304,7 +356,7 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header e.stack.Stats().TCP.FailedConnectionAttempts.Increment() return } - + ctx.removePendingEndpoint(n) e.deliverAccepted(n) } @@ -451,6 +503,11 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { // protocolListenLoop is the main loop of a listening TCP endpoint. It runs in // its own goroutine and is responsible for handling connection requests. func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error { + e.mu.Lock() + v6only := e.v6only + e.mu.Unlock() + ctx := newListenContext(e.stack, e, rcvWnd, v6only, e.netProto) + defer func() { // Mark endpoint as closed. This will prevent goroutines running // handleSynSegment() from attempting to queue new connections @@ -458,6 +515,9 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error { e.mu.Lock() e.state = StateClose + // close any endpoints in SYN-RCVD state. + ctx.closeAllPendingEndpoints() + // Do cleanup if needed. e.completeWorkerLocked() @@ -470,12 +530,6 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error { e.waiterQueue.Notify(waiter.EventIn | waiter.EventOut) }() - e.mu.Lock() - v6only := e.v6only - e.mu.Unlock() - - ctx := newListenContext(e.stack, e, rcvWnd, v6only, e.netProto) - s := sleep.Sleeper{} s.AddWaker(&e.notificationWaker, wakerForNotification) s.AddWaker(&e.newSegmentWaker, wakerForNewSegment) @@ -492,7 +546,6 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error { e.handleListenSegment(ctx, s) s.decRef() } - synRcvdCount.pending.Wait() close(e.drainDone) <-e.undrain } diff --git a/pkg/tcpip/transport/tcp/dual_stack_test.go b/pkg/tcpip/transport/tcp/dual_stack_test.go index d9f79e8c5..c54610a87 100644 --- a/pkg/tcpip/transport/tcp/dual_stack_test.go +++ b/pkg/tcpip/transport/tcp/dual_stack_test.go @@ -570,3 +570,89 @@ func TestV4AcceptOnV4(t *testing.T) { // Test acceptance. 
testV4Accept(t, c) } + +func testV4ListenClose(t *testing.T, c *context.Context) { + // Set the SynRcvd threshold to zero to force a syn-cookie based accept + // to happen. + saved := tcp.SynRcvdCountThreshold + defer func() { + tcp.SynRcvdCountThreshold = saved + }() + tcp.SynRcvdCountThreshold = 0 + const n = uint16(32) + + // Start listening. + if err := c.EP.Listen(int(tcp.SynRcvdCountThreshold + 1)); err != nil { + t.Fatalf("Listen failed: %v", err) + } + + irs := seqnum.Value(789) + for i := uint16(0); i < n; i++ { + // Send a SYN request. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort + i, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: irs, + RcvWnd: 30000, + }) + } + + // Each of these ACKs will cause a syn-cookie based connection to be + // accepted and delivered to the listening endpoint. + for i := uint16(0); i < n; i++ { + b := c.GetPacket() + tcp := header.TCP(header.IPv4(b).Payload()) + iss := seqnum.Value(tcp.SequenceNumber()) + // Send ACK. + c.SendPacket(nil, &context.Headers{ + SrcPort: tcp.DestinationPort(), + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: irs + 1, + AckNum: iss + 1, + RcvWnd: 30000, + }) + } + + // Try to accept the connection. + we, ch := waiter.NewChannelEntry(nil) + c.WQ.EventRegister(&we, waiter.EventIn) + defer c.WQ.EventUnregister(&we) + nep, _, err := c.EP.Accept() + if err == tcpip.ErrWouldBlock { + // Wait for connection to be established. + select { + case <-ch: + nep, _, err = c.EP.Accept() + if err != nil { + t.Fatalf("Accept failed: %v", err) + } + + case <-time.After(10 * time.Second): + t.Fatalf("Timed out waiting for accept") + } + } + nep.Close() + c.EP.Close() +} + +func TestV4ListenCloseOnV4(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + // Create TCP endpoint. + var err *tcpip.Error + c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) + if err != nil { + t.Fatalf("NewEndpoint failed: %v", err) + } + + // Bind to wildcard. + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %v", err) + } + + // Test acceptance. + testV4ListenClose(t, c) +} diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index cc49c8272..ac927569a 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -27,6 +27,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/iptables" "gvisor.dev/gvisor/pkg/tcpip/seqnum" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tmutex" @@ -361,6 +362,12 @@ type endpoint struct { // without hearing a response, the connection is closed. keepalive keepalive + // pendingAccepted is a synchronization primitive used to track the number + // of connections that are queued up to be delivered to the accepted + // channel. We use this to ensure that all goroutines blocked on writing + // to the acceptedChan below terminate before we close acceptedChan. + pendingAccepted sync.WaitGroup `state:"nosave"` + // acceptedChan is used by a listening endpoint protocol goroutine to // send newly accepted connections to the endpoint so that they can be // read by Accept() calls. @@ -374,7 +381,11 @@ type endpoint struct { // The goroutine drain completion notification channel. drainDone chan struct{} `state:"nosave"` - // The goroutine undrain notification channel. + // The goroutine undrain notification channel.
This is currently used as + a way to block the worker goroutines. Today nothing closes or writes to + this channel, so any goroutine waiting on it simply blocks. This is used + during save/restore to prevent worker goroutines from mutating state as + it's being saved. undrain chan struct{} `state:"nosave"` // probe if not nil is invoked on every received segment. It is passed @@ -574,6 +585,34 @@ func (e *endpoint) Close() { e.mu.Unlock() } +// closePendingAcceptableConnectionsLocked closes all connections that have +// completed the handshake but have not yet been delivered to the application. +func (e *endpoint) closePendingAcceptableConnectionsLocked() { + done := make(chan struct{}) + // Spin up a goroutine, as ranging over e.acceptedChan will just block when + // there are no more connections in the channel. Using a non-blocking + // select does not work as it can potentially select the default case + // even when there are pending writes that have not yet been written to + // the channel. + go func() { + defer close(done) + for n := range e.acceptedChan { + n.mu.Lock() + n.resetConnectionLocked(tcpip.ErrConnectionAborted) + n.mu.Unlock() + n.Close() + } + }() + // pendingAccepted (see endpoint.deliverAccepted) tracks the number of + // endpoints that have completed the handshake but are not yet written to + // the e.acceptedChan. We wait here until the goroutine above can drain + // all such connections from e.acceptedChan. + e.pendingAccepted.Wait() + close(e.acceptedChan) + <-done + e.acceptedChan = nil +} + // cleanupLocked frees all resources associated with the endpoint. It is called // after Close() is called and the worker goroutine (if any) is done with its // work. @@ -581,14 +620,7 @@ func (e *endpoint) cleanupLocked() { // Close all endpoints that might have been accepted by TCP but not by // the client. if e.acceptedChan != nil { - close(e.acceptedChan) - for n := range e.acceptedChan { - n.mu.Lock() - n.resetConnectionLocked(tcpip.ErrConnectionAborted) - n.mu.Unlock() - n.Close() - } - e.acceptedChan = nil + e.closePendingAcceptableConnectionsLocked() } e.workerCleanup = false @@ -683,6 +715,11 @@ func (e *endpoint) ModerateRecvBuf(copied int) { e.rcvListMu.Unlock() } +// IPTables implements tcpip.Endpoint.IPTables. +func (e *endpoint) IPTables() (iptables.IPTables, error) { + return e.stack.IPTables(), nil +} + // Read reads data from the endpoint. func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { e.mu.RLock() @@ -740,60 +777,95 @@ func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) { return v, nil } -// Write writes data to the endpoint's peer. -func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-chan struct{}, *tcpip.Error) { - // Linux completely ignores any address passed to sendto(2) for TCP sockets - // (without the MSG_FASTOPEN flag). Corking is unimplemented, so opts.More - // and opts.EndOfRecord are also ignored. - - e.mu.RLock() - defer e.mu.RUnlock() - +// isEndpointWritableLocked checks if a given endpoint is writable +// and also returns the number of bytes that can be written at this +// moment. If the endpoint is not writable then it returns an error +// indicating the reason why it's not writable. +// Caller must hold e.mu and e.sndBufMu. +func (e *endpoint) isEndpointWritableLocked() (int, *tcpip.Error) { // The endpoint cannot be written to if it's not connected.
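closePendingAcceptableConnectionsLocked above is an instance of a classic Go problem: a channel may only be closed once no sender can still be blocked on it, because a send on a closed channel panics. The recipe used here is to keep a drainer goroutine receiving from the channel while a WaitGroup covering all potential senders drains to zero, and only then close. A runnable sketch with illustrative names:

package main

import (
	"fmt"
	"sync"
)

func main() {
	accepted := make(chan int, 2) // stands in for e.acceptedChan
	var pendingAccepted sync.WaitGroup

	// Producers: as in deliverAccepted, each registers before it may block
	// sending to the channel, and signals Done once the send has finished.
	for i := 0; i < 5; i++ {
		pendingAccepted.Add(1)
		go func(n int) {
			defer pendingAccepted.Done()
			accepted <- n
		}(i)
	}

	// Drainer: as in closePendingAcceptableConnectionsLocked, keep receiving
	// so that no producer stays blocked. A non-blocking select would not do,
	// since it could hit the default case while sends are still pending.
	done := make(chan struct{})
	go func() {
		defer close(done)
		for n := range accepted {
			fmt.Println("aborting queued connection", n)
		}
	}()

	pendingAccepted.Wait() // all sends have completed
	close(accepted)        // now safe: nobody can send anymore
	<-done                 // drainer saw the close and exited
}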
if !e.state.connected() { switch e.state { case StateError: - return 0, nil, e.hardError + return 0, e.hardError default: - return 0, nil, tcpip.ErrClosedForSend + return 0, tcpip.ErrClosedForSend } } - // Nothing to do if the buffer is empty. - if p.Size() == 0 { - return 0, nil, nil - } - - e.sndBufMu.Lock() - // Check if the connection has already been closed for sends. if e.sndClosed { - e.sndBufMu.Unlock() - return 0, nil, tcpip.ErrClosedForSend + return 0, tcpip.ErrClosedForSend } - // Check against the limit. avail := e.sndBufSize - e.sndBufUsed if avail <= 0 { + return 0, tcpip.ErrWouldBlock + } + return avail, nil +} + +// Write writes data to the endpoint's peer. +func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { + // Linux completely ignores any address passed to sendto(2) for TCP sockets + // (without the MSG_FASTOPEN flag). Corking is unimplemented, so opts.More + // and opts.EndOfRecord are also ignored. + + e.mu.RLock() + e.sndBufMu.Lock() + + avail, err := e.isEndpointWritableLocked() + if err != nil { e.sndBufMu.Unlock() - return 0, nil, tcpip.ErrWouldBlock + e.mu.RUnlock() + return 0, nil, err } + e.sndBufMu.Unlock() + e.mu.RUnlock() + + // Nothing to do if the buffer is empty. + if p.Size() == 0 { + return 0, nil, nil + } + + // Copy in memory without holding sndBufMu so that the worker goroutine can + // make progress independent of this operation. v, perr := p.Get(avail) if perr != nil { - e.sndBufMu.Unlock() return 0, nil, perr } - l := len(v) - s := newSegmentFromView(&e.route, e.id, v) + e.mu.RLock() + e.sndBufMu.Lock() + + // Because we released the lock before copying, check state again + // to make sure the endpoint is still in a valid state for a + // write. + avail, err = e.isEndpointWritableLocked() + if err != nil { + e.sndBufMu.Unlock() + e.mu.RUnlock() + return 0, nil, err + } + + // Discard any excess data copied in, since avail may have been reduced by + // a simultaneous write call to the socket. + if avail < len(v) { + v = v[:avail] + } // Add data to the send queue. + l := len(v) + s := newSegmentFromView(&e.route, e.id, v) e.sndBufUsed += l e.sndBufInQueue += seqnum.Size(l) e.sndQueue.PushBack(s) e.sndBufMu.Unlock() + // Release the endpoint lock to prevent deadlocks due to lock + // order inversion when acquiring workMu. + e.mu.RUnlock() if e.workMu.TryLock() { // Do the work inline. @@ -803,13 +875,13 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-c // Let the protocol goroutine do the work. e.sndWaker.Assert() } - return uintptr(l), nil, nil + return int64(l), nil, nil } // Peek reads data without consuming it from the endpoint. // // This method does not block if there is no data pending. -func (e *endpoint) Peek(vec [][]byte) (uintptr, tcpip.ControlMessages, *tcpip.Error) { +func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { e.mu.RLock() defer e.mu.RUnlock() @@ -835,8 +907,7 @@ func (e *endpoint) Peek(vec [][]byte) (uintptr, tcpip.ControlMessages, *tcpip.Er // Make a copy of vec so we can modify the slice headers. vec = append([][]byte(nil), vec...)
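The restructured Write above uses a check/unlock/copy/recheck shape: validate writability under e.mu and e.sndBufMu, release both locks for the potentially slow copy from user memory, then re-acquire and re-validate, trimming the copied data if available buffer space shrank in the meantime. A self-contained sketch of the same shape, with hypothetical names and a single sync.Mutex standing in for the endpoint's locks:

package main

import (
	"errors"
	"fmt"
	"sync"
)

var errWouldBlock = errors.New("would block")

type sendBuf struct {
	mu   sync.Mutex
	size int
	used int
}

// writableLocked mirrors isEndpointWritableLocked: caller must hold b.mu.
func (b *sendBuf) writableLocked() (int, error) {
	avail := b.size - b.used
	if avail <= 0 {
		return 0, errWouldBlock
	}
	return avail, nil
}

// write checks under the lock, drops the lock for the (potentially slow)
// copy, then re-checks because another writer may have consumed space.
func (b *sendBuf) write(p []byte) (int, error) {
	b.mu.Lock()
	avail, err := b.writableLocked()
	b.mu.Unlock()
	if err != nil {
		return 0, err
	}
	if len(p) > avail {
		p = p[:avail]
	}
	v := append([]byte(nil), p...) // copy without holding the lock

	b.mu.Lock()
	defer b.mu.Unlock()
	avail, err = b.writableLocked() // state may have changed; re-check
	if err != nil {
		return 0, err
	}
	if len(v) > avail { // discard excess if avail shrank concurrently
		v = v[:avail]
	}
	b.used += len(v)
	return len(v), nil
}

func main() {
	b := &sendBuf{size: 8}
	n, err := b.write([]byte("hello world"))
	fmt.Println(n, err) // 8 <nil>
}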
- var num uintptr - + var num int64 for s := e.rcvList.Front(); s != nil; s = s.Next() { views := s.data.Views() @@ -855,7 +926,7 @@ func (e *endpoint) Peek(vec [][]byte) (uintptr, tcpip.ControlMessages, *tcpip.Er n := copy(vec[0], v) v = v[n:] vec[0] = vec[0][n:] - num += uintptr(n) + num += int64(n) } } } @@ -1277,7 +1348,7 @@ func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocol netProto = header.IPv4ProtocolNumber addr.Addr = addr.Addr[header.IPv6AddressSize-header.IPv4AddressSize:] - if addr.Addr == "\x00\x00\x00\x00" { + if addr.Addr == header.IPv4Any { addr.Addr = "" } } @@ -1291,13 +1362,13 @@ func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocol return netProto, nil } +// Disconnect implements tcpip.Endpoint.Disconnect. +func (*endpoint) Disconnect() *tcpip.Error { + return tcpip.ErrNotSupported +} + // Connect connects the endpoint to its peer. func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { - if addr.Addr == "" && addr.Port == 0 { - // AF_UNSPEC isn't supported. - return tcpip.ErrAddressFamilyNotSupported - } - return e.connect(addr, true, true) } diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index b3f0f6c5d..831389ec7 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -165,7 +165,12 @@ func (e *endpoint) loadState(state EndpointState) { // afterLoad is invoked by stateify. func (e *endpoint) afterLoad() { - e.stack = stack.StackFromEnv + stack.StackFromEnv.RegisterRestoredEndpoint(e) +} + +// Resume implements tcpip.ResumableEndpoint.Resume. +func (e *endpoint) Resume(s *stack.Stack) { + e.stack = s e.segmentQueue.setLimit(MaxUnprocessedSegments) e.workMu.Init() @@ -197,14 +202,13 @@ func (e *endpoint) afterLoad() { case StateEstablished, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing: bind() if len(e.connectingAddress) == 0 { + e.connectingAddress = e.id.RemoteAddress // This endpoint is accepted by netstack but not yet by // the app. If the endpoint is IPv6 but the remote // address is IPv4, we need to connect as IPv6 so that // dual-stack mode can be properly activated. if e.netProto == header.IPv6ProtocolNumber && len(e.id.RemoteAddress) != header.IPv6AddressSize { e.connectingAddress = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + e.id.RemoteAddress - } else { - e.connectingAddress = e.id.RemoteAddress } } // Reset the scoreboard to reinitialize the sack information as diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index 0fee7ab72..735edfe55 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -39,6 +39,28 @@ const ( nDupAckThreshold = 3 ) +// ccState indicates the current congestion control state for this sender. +type ccState int + +const ( + // Open indicates that the sender is receiving ACKs in order and + // no loss or duplicate ACKs have been detected. + Open ccState = iota + // RTORecovery indicates that an RTO has occurred and the sender + // has entered an RTO-based recovery phase. + RTORecovery + // FastRecovery indicates that the sender has entered FastRecovery + // after receiving nDupAckThreshold duplicate ACKs. This state is + // entered only when SACK is not in use. + FastRecovery + // SACKRecovery indicates that the sender has entered SACK-based + // recovery. + SACKRecovery + // Disorder indicates the sender has either received some SACK blocks + // or duplicate ACKs.
+ Disorder +) + // congestionControl is an interface that must be implemented by any supported // congestion control algorithm. type congestionControl interface { @@ -138,6 +160,9 @@ type sender struct { // maxSentAck is the maximum acknowledgement actually sent. maxSentAck seqnum.Value + // state is the current state of congestion control for this endpoint. + state ccState + // cc is the congestion control algorithm in use for this sender. cc congestionControl } @@ -435,6 +460,7 @@ func (s *sender) retransmitTimerExpired() bool { s.leaveFastRecovery() } + s.state = RTORecovery s.cc.HandleRTOExpired() // Mark the next segment to be sent as the first unacknowledged one and @@ -638,7 +664,14 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se segEnd = seg.sequenceNumber.Add(1) // Transition to FIN-WAIT1 state since we're initiating an active close. s.ep.mu.Lock() - s.ep.state = StateFinWait1 + switch s.ep.state { + case StateCloseWait: + // We've already received a FIN and are now sending our own. The + // sender is now awaiting a final ACK for this FIN. + s.ep.state = StateLastAck + default: + s.ep.state = StateFinWait1 + } s.ep.mu.Unlock() } else { // We're sending a non-FIN segment. @@ -820,9 +853,11 @@ func (s *sender) enterFastRecovery() { s.fr.last = s.sndNxt - 1 s.fr.maxCwnd = s.sndCwnd + s.outstanding if s.ep.sackPermitted { + s.state = SACKRecovery s.ep.stack.Stats().TCP.SACKRecovery.Increment() return } + s.state = FastRecovery s.ep.stack.Stats().TCP.FastRecovery.Increment() } @@ -981,6 +1016,7 @@ func (s *sender) checkDuplicateAck(seg *segment) (rtx bool) { s.fr.highRxt = s.sndUna - 1 // Run SetPipe() to calculate the outstanding segments. s.SetPipe() + s.state = Disorder return false } @@ -1112,6 +1148,9 @@ func (s *sender) handleRcvdSegment(seg *segment) { // window based on the number of acknowledged packets.
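Taken together, the new s.state assignments trace the sender through a small congestion-control state machine: an RTO moves it to RTORecovery, duplicate ACKs or SACK blocks move it to Disorder, entering fast recovery moves it to FastRecovery or SACKRecovery depending on whether SACK is permitted, and it returns to Open once sndUna passes the recovery point (s.fr.last). The diff does not add a String method for ccState; a sketch of one, which could be handy for probes or logging, might look like this (an assumption, not part of the change):

package main

import "fmt"

type ccState int

const (
	Open ccState = iota
	RTORecovery
	FastRecovery
	SACKRecovery
	Disorder
)

// String is a hypothetical human-readable form for ccState; the diff itself
// does not define one.
func (s ccState) String() string {
	switch s {
	case Open:
		return "Open"
	case RTORecovery:
		return "RTORecovery"
	case FastRecovery:
		return "FastRecovery"
	case SACKRecovery:
		return "SACKRecovery"
	case Disorder:
		return "Disorder"
	default:
		return fmt.Sprintf("ccState(%d)", int(s))
	}
}

func main() {
	fmt.Println(RTORecovery) // RTORecovery
}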
if !s.fr.active { s.cc.Update(originalOutstanding - s.outstanding) + if s.fr.last.LessThan(s.sndUna) { + s.state = Open + } } // It is possible for s.outstanding to drop below zero if we get diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 915a98047..f79b8ec5f 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -2874,15 +2874,11 @@ func makeStack() (*stack.Stack, *tcpip.Error) { s.SetRouteTable([]tcpip.Route{ { - Destination: "\x00\x00\x00\x00", - Mask: "\x00\x00\x00\x00", - Gateway: "", + Destination: header.IPv4EmptySubnet, NIC: 1, }, { - Destination: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", - Mask: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", - Gateway: "", + Destination: header.IPv6EmptySubnet, NIC: 1, }, }) diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index bcc0f3e28..272481aa0 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -168,15 +168,11 @@ func New(t *testing.T, mtu uint32) *Context { s.SetRouteTable([]tcpip.Route{ { - Destination: "\x00\x00\x00\x00", - Mask: "\x00\x00\x00\x00", - Gateway: "", + Destination: header.IPv4EmptySubnet, NIC: 1, }, { - Destination: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", - Mask: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", - Gateway: "", + Destination: header.IPv6EmptySubnet, NIC: 1, }, }) diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD index 6dac66b50..ac2666f69 100644 --- a/pkg/tcpip/transport/udp/BUILD +++ b/pkg/tcpip/transport/udp/BUILD @@ -32,6 +32,7 @@ go_library( "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", + "//pkg/tcpip/iptables", "//pkg/tcpip/stack", "//pkg/tcpip/transport/raw", "//pkg/waiter", diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 91f89a781..ac5905772 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -21,6 +21,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/iptables" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/waiter" ) @@ -172,6 +173,11 @@ func (e *endpoint) Close() { // ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf. func (e *endpoint) ModerateRecvBuf(copied int) {} +// IPTables implements tcpip.Endpoint.IPTables. +func (e *endpoint) IPTables() (iptables.IPTables, error) { + return e.stack.IPTables(), nil +} + // Read reads data from the endpoint. This method does not block if // there is no data pending. func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { @@ -241,13 +247,13 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi // connectRoute establishes a route to the specified interface or the // configured multicast interface if no interface is specified and the // specified address is a multicast address. 
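The route-table updates in the tests above replace hand-written all-zero Destination/Mask strings with header.IPv4EmptySubnet and header.IPv6EmptySubnet, i.e. the catch-all 0.0.0.0/0 and ::/0 subnets that a default route matches against. A small standalone illustration of why a zero address under a zero mask contains every destination, using only the standard library for demonstration:

package main

import (
	"fmt"
	"net"
)

func main() {
	// A default route's Destination is the empty subnet: the zero address
	// with a zero-length prefix, which therefore contains every address.
	_, defaultV4, _ := net.ParseCIDR("0.0.0.0/0")
	_, defaultV6, _ := net.ParseCIDR("::/0")
	fmt.Println(defaultV4.Contains(net.ParseIP("192.0.2.1")))   // true
	fmt.Println(defaultV6.Contains(net.ParseIP("2001:db8::1"))) // true
}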
-func (e *endpoint) connectRoute(nicid tcpip.NICID, addr tcpip.FullAddress) (stack.Route, tcpip.NICID, tcpip.NetworkProtocolNumber, *tcpip.Error) { - netProto, err := e.checkV4Mapped(&addr, false) - if err != nil { - return stack.Route{}, 0, 0, err +func (e *endpoint) connectRoute(nicid tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (stack.Route, tcpip.NICID, *tcpip.Error) { + localAddr := e.id.LocalAddress + if isBroadcastOrMulticast(localAddr) { + // A packet can only originate from a unicast address (i.e., an interface). + localAddr = "" } - localAddr := e.id.LocalAddress if header.IsV4MulticastAddress(addr.Addr) || header.IsV6MulticastAddress(addr.Addr) { if nicid == 0 { nicid = e.multicastNICID @@ -260,14 +266,14 @@ func (e *endpoint) connectRoute(nicid tcpip.NICID, addr tcpip.FullAddress) (stac // Find a route to the desired destination. r, err := e.stack.FindRoute(nicid, localAddr, addr.Addr, netProto, e.multicastLoop) if err != nil { - return stack.Route{}, 0, 0, err + return stack.Route{}, 0, err } - return r, nicid, netProto, nil + return r, nicid, nil } // Write writes data to the endpoint's peer. This method does not block // if the data cannot be written. -func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-chan struct{}, *tcpip.Error) { +func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { // MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.) if opts.More { return 0, nil, tcpip.ErrInvalidOptionValue @@ -336,7 +342,12 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-c return 0, nil, tcpip.ErrBroadcastDisabled } - r, _, _, err := e.connectRoute(nicid, *to) + netProto, err := e.checkV4Mapped(to, false) + if err != nil { + return 0, nil, err + } + + r, _, err := e.connectRoute(nicid, *to, netProto) if err != nil { return 0, nil, err } @@ -368,11 +379,11 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-c if err := sendUDP(route, buffer.View(v).ToVectorisedView(), e.id.LocalPort, dstPort, ttl); err != nil { return 0, nil, err } - return uintptr(len(v)), nil, nil + return int64(len(v)), nil, nil } // Peek only returns data from a single datagram, so do nothing here. -func (e *endpoint) Peek([][]byte) (uintptr, tcpip.ControlMessages, *tcpip.Error) { +func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { return 0, tcpip.ControlMessages{}, nil } @@ -442,7 +453,12 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { } nicID := v.NIC - if v.InterfaceAddr == header.IPv4Any { + + // The interface address is considered not-set if it is empty or contains + // all-zeros. The former represents the zero value in Go, the latter the + // same in a setsockopt(IP_ADD_MEMBERSHIP, &ip_mreqn) syscall.
+ allZeros := header.IPv4Any + if len(v.InterfaceAddr) == 0 || v.InterfaceAddr == allZeros { if nicID == 0 { r, err := e.stack.FindRoute(0, "", v.MulticastAddr, header.IPv4ProtocolNumber, false /* multicastLoop */) if err == nil { @@ -686,7 +702,7 @@ func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (t netProto = header.IPv4ProtocolNumber addr.Addr = addr.Addr[header.IPv6AddressSize-header.IPv4AddressSize:] - if addr.Addr == "\x00\x00\x00\x00" { + if addr.Addr == header.IPv4Any { addr.Addr = "" } @@ -705,7 +721,8 @@ func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (t return netProto, nil } -func (e *endpoint) disconnect() *tcpip.Error { +// Disconnect implements tcpip.Endpoint.Disconnect. +func (e *endpoint) Disconnect() *tcpip.Error { e.mu.Lock() defer e.mu.Unlock() @@ -740,8 +757,9 @@ func (e *endpoint) disconnect() *tcpip.Error { // Connect connects the endpoint to its peer. Specifying a NIC is optional. func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { - if addr.Addr == "" { - return e.disconnect() + netProto, err := e.checkV4Mapped(&addr, false) + if err != nil { + return err } if addr.Port == 0 { // We don't support connecting to port zero. @@ -770,7 +788,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { return tcpip.ErrInvalidEndpointState } - r, nicid, netProto, err := e.connectRoute(nicid, addr) + r, nicid, err := e.connectRoute(nicid, addr, netProto) if err != nil { return err } @@ -906,8 +924,8 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error { } nicid := addr.NIC - if len(addr.Addr) != 0 { - // A local address was specified, verify that it's valid. + if len(addr.Addr) != 0 && !isBroadcastOrMulticast(addr.Addr) { + // A local unicast address was specified, verify that it's valid. nicid = e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr) if nicid == 0 { return tcpip.ErrBadLocalAddress @@ -1056,3 +1074,7 @@ func (e *endpoint) State() uint32 { // TODO(b/112063468): Translate internal state to values returned by Linux. return 0 } + +func isBroadcastOrMulticast(a tcpip.Address) bool { + return a == header.IPv4Broadcast || header.IsV4MulticastAddress(a) || header.IsV6MulticastAddress(a) +} diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go index 18e786397..5cbb56120 100644 --- a/pkg/tcpip/transport/udp/endpoint_state.go +++ b/pkg/tcpip/transport/udp/endpoint_state.go @@ -64,7 +64,12 @@ func (e *endpoint) loadRcvBufSizeMax(max int) { // afterLoad is invoked by stateify. func (e *endpoint) afterLoad() { - e.stack = stack.StackFromEnv + stack.StackFromEnv.RegisterRestoredEndpoint(e) +} + +// Resume implements tcpip.ResumableEndpoint.Resume. +func (e *endpoint) Resume(s *stack.Stack) { + e.stack = s for _, m := range e.multicastMemberships { if err := e.stack.JoinGroup(e.netProto, m.nicID, m.multicastAddr); err != nil { @@ -90,9 +95,10 @@ func (e *endpoint) afterLoad() { if e.state == stateConnected { e.route, err = e.stack.FindRoute(e.regNICID, e.id.LocalAddress, e.id.RemoteAddress, netProto, e.multicastLoop) if err != nil { - panic(*err) + panic(err) } - } else if len(e.id.LocalAddress) != 0 { // stateBound + } else if len(e.id.LocalAddress) != 0 && !isBroadcastOrMulticast(e.id.LocalAddress) { // stateBound + // A local unicast address is specified, verify that it's valid. 
if e.stack.CheckLocalAddress(e.regNICID, netProto, e.id.LocalAddress) == 0 { panic(tcpip.ErrBadLocalAddress) } @@ -105,6 +111,6 @@ e.id.LocalPort = 0 e.id, err = e.registerWithStack(e.regNICID, e.effectiveNetProtos, id) if err != nil { - panic(*err) + panic(err) } } diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 56c285f88..9da6edce2 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -16,6 +16,7 @@ package udp_test import ( "bytes" + "fmt" "math" "math/rand" "testing" @@ -34,13 +35,19 @@ import ( "gvisor.dev/gvisor/pkg/waiter" ) +// Addresses and ports used for testing. It is recommended that tests stick to +// using these addresses as it allows using the testFlow helper. +// Naming rules: 'stack*' denotes local addresses and ports, while 'test*' +// represents the remote endpoint. const ( + v4MappedAddrPrefix = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" stackV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" testV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" - stackV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + stackAddr - testV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + testAddr - multicastV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + multicastAddr - V4MappedWildcardAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x00" + stackV4MappedAddr = v4MappedAddrPrefix + stackAddr + testV4MappedAddr = v4MappedAddrPrefix + testAddr + multicastV4MappedAddr = v4MappedAddrPrefix + multicastAddr + broadcastV4MappedAddr = v4MappedAddrPrefix + broadcastAddr + v4MappedWildcardAddr = v4MappedAddrPrefix + "\x00\x00\x00\x00" stackAddr = "\x0a\x00\x00\x01" stackPort = 1234 @@ -48,7 +55,7 @@ const ( testPort = 4096 multicastAddr = "\xe8\x2b\xd3\xea" multicastV6Addr = "\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" - multicastPort = 1234 + broadcastAddr = header.IPv4Broadcast // defaultMTU is the MTU, in bytes, used throughout the tests, except // where another value is explicitly used. It is chosen to match the MTU @@ -56,6 +63,205 @@ const ( defaultMTU = 65536 ) +// header4Tuple stores the 4-tuple {src-IP, src-port, dst-IP, dst-port} used in +// a packet header. These values are used to populate a header or verify one. +// Note that because they are used in packet headers, the addresses are never in +// a V4-mapped format. +type header4Tuple struct { + srcAddr tcpip.FullAddress + dstAddr tcpip.FullAddress +} + +// testFlow implements a helper type used for sending and receiving test +// packets. A given test flow value defines 1) the socket endpoint used for the +// test and 2) the type of packet sent or received on the endpoint. E.g., a +// multicastV6Only flow is a V6 multicast packet passing through a V6-only +// endpoint. The type provides helper methods to characterize the flow (e.g., +// isV4) as well as return a proper header4Tuple for it.
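The new v4MappedAddrPrefix constant above factors out the ::ffff:0:0/96 prefix: a V4-mapped V6 address is just that 12-byte prefix concatenated with the 4-byte V4 address, which is why the test file can build stackV4MappedAddr and friends by plain string concatenation. A quick standalone check of that layout; the net package is used here only for illustration:

package main

import (
	"fmt"
	"net"
)

func main() {
	// v4MappedAddrPrefix and stackAddr as defined in the test file above.
	v4MappedAddrPrefix := "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff"
	stackAddr := "\x0a\x00\x00\x01" // 10.0.0.1
	stackV4MappedAddr := v4MappedAddrPrefix + stackAddr

	// The result is a 16-byte address that stdlib recognizes as V4-mapped.
	ip := net.IP([]byte(stackV4MappedAddr))
	fmt.Println(len(ip), ip, ip.To4() != nil) // 16 10.0.0.1 true
}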
+type testFlow int + +const ( + unicastV4 testFlow = iota // V4 unicast on a V4 socket + unicastV4in6 // V4-mapped unicast on a V6-dual socket + unicastV6 // V6 unicast on a V6 socket + unicastV6Only // V6 unicast on a V6-only socket + multicastV4 // V4 multicast on a V4 socket + multicastV4in6 // V4-mapped multicast on a V6-dual socket + multicastV6 // V6 multicast on a V6 socket + multicastV6Only // V6 multicast on a V6-only socket + broadcast // V4 broadcast on a V4 socket + broadcastIn6 // V4-mapped broadcast on a V6-dual socket +) + +func (flow testFlow) String() string { + switch flow { + case unicastV4: + return "unicastV4" + case unicastV6: + return "unicastV6" + case unicastV6Only: + return "unicastV6Only" + case unicastV4in6: + return "unicastV4in6" + case multicastV4: + return "multicastV4" + case multicastV6: + return "multicastV6" + case multicastV6Only: + return "multicastV6Only" + case multicastV4in6: + return "multicastV4in6" + case broadcast: + return "broadcast" + case broadcastIn6: + return "broadcastIn6" + default: + return "unknown" + } +} + +// packetDirection explains if a flow is incoming (read) or outgoing (write). +type packetDirection int + +const ( + incoming packetDirection = iota + outgoing +) + +// header4Tuple returns the header4Tuple for the given flow and direction. Note +// that the tuple contains no mapped addresses as those only exist at the socket +// level but not at the packet header level. +func (flow testFlow) header4Tuple(d packetDirection) header4Tuple { + var h header4Tuple + if flow.isV4() { + if d == outgoing { + h = header4Tuple{ + srcAddr: tcpip.FullAddress{Addr: stackAddr, Port: stackPort}, + dstAddr: tcpip.FullAddress{Addr: testAddr, Port: testPort}, + } + } else { + h = header4Tuple{ + srcAddr: tcpip.FullAddress{Addr: testAddr, Port: testPort}, + dstAddr: tcpip.FullAddress{Addr: stackAddr, Port: stackPort}, + } + } + if flow.isMulticast() { + h.dstAddr.Addr = multicastAddr + } else if flow.isBroadcast() { + h.dstAddr.Addr = broadcastAddr + } + } else { // IPv6 + if d == outgoing { + h = header4Tuple{ + srcAddr: tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort}, + dstAddr: tcpip.FullAddress{Addr: testV6Addr, Port: testPort}, + } + } else { + h = header4Tuple{ + srcAddr: tcpip.FullAddress{Addr: testV6Addr, Port: testPort}, + dstAddr: tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort}, + } + } + if flow.isMulticast() { + h.dstAddr.Addr = multicastV6Addr + } + } + return h +} + +func (flow testFlow) getMcastAddr() tcpip.Address { + if flow.isV4() { + return multicastAddr + } + return multicastV6Addr +} + +// mapAddrIfApplicable converts the given V4 address into its V4-mapped version +// if it is applicable to the flow. +func (flow testFlow) mapAddrIfApplicable(v4Addr tcpip.Address) tcpip.Address { + if flow.isMapped() { + return v4MappedAddrPrefix + v4Addr + } + return v4Addr +} + +// netProto returns the protocol number used for the network packet. +func (flow testFlow) netProto() tcpip.NetworkProtocolNumber { + if flow.isV4() { + return ipv4.ProtocolNumber + } + return ipv6.ProtocolNumber +} + +// sockProto returns the protocol number used when creating the socket +// endpoint for this flow. 
+func (flow testFlow) sockProto() tcpip.NetworkProtocolNumber { + switch flow { + case unicastV4in6, unicastV6, unicastV6Only, multicastV4in6, multicastV6, multicastV6Only, broadcastIn6: + return ipv6.ProtocolNumber + case unicastV4, multicastV4, broadcast: + return ipv4.ProtocolNumber + default: + panic(fmt.Sprintf("invalid testFlow given: %d", flow)) + } +} + +func (flow testFlow) checkerFn() func(*testing.T, []byte, ...checker.NetworkChecker) { + if flow.isV4() { + return checker.IPv4 + } + return checker.IPv6 +} + +func (flow testFlow) isV6() bool { return !flow.isV4() } +func (flow testFlow) isV4() bool { + return flow.sockProto() == ipv4.ProtocolNumber || flow.isMapped() +} + +func (flow testFlow) isV6Only() bool { + switch flow { + case unicastV6Only, multicastV6Only: + return true + case unicastV4, unicastV4in6, unicastV6, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6: + return false + default: + panic(fmt.Sprintf("invalid testFlow given: %d", flow)) + } +} + +func (flow testFlow) isMulticast() bool { + switch flow { + case multicastV4, multicastV4in6, multicastV6, multicastV6Only: + return true + case unicastV4, unicastV4in6, unicastV6, unicastV6Only, broadcast, broadcastIn6: + return false + default: + panic(fmt.Sprintf("invalid testFlow given: %d", flow)) + } +} + +func (flow testFlow) isBroadcast() bool { + switch flow { + case broadcast, broadcastIn6: + return true + case unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, multicastV6Only: + return false + default: + panic(fmt.Sprintf("invalid testFlow given: %d", flow)) + } +} + +func (flow testFlow) isMapped() bool { + switch flow { + case unicastV4in6, multicastV4in6, broadcastIn6: + return true + case unicastV4, unicastV6, unicastV6Only, multicastV4, multicastV6, multicastV6Only, broadcast: + return false + default: + panic(fmt.Sprintf("invalid testFlow given: %d", flow)) + } +} + type testContext struct { t *testing.T linkEP *channel.Endpoint @@ -65,12 +271,9 @@ type testContext struct { wq waiter.Queue } -type headers struct { - srcPort uint16 - dstPort uint16 -} - func newDualTestContext(t *testing.T, mtu uint32) *testContext { + t.Helper() + s := stack.New([]string{ipv4.ProtocolName, ipv6.ProtocolName}, []string{udp.ProtocolName}, stack.Options{}) id, linkEP := channel.New(256, mtu, "") @@ -91,15 +294,11 @@ func newDualTestContext(t *testing.T, mtu uint32) *testContext { s.SetRouteTable([]tcpip.Route{ { - Destination: "\x00\x00\x00\x00", - Mask: "\x00\x00\x00\x00", - Gateway: "", + Destination: header.IPv4EmptySubnet, NIC: 1, }, { - Destination: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", - Mask: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", - Gateway: "", + Destination: header.IPv6EmptySubnet, NIC: 1, }, }) @@ -117,51 +316,54 @@ func (c *testContext) cleanup() { } } -func (c *testContext) createV6Endpoint(v6only bool) { +func (c *testContext) createEndpoint(proto tcpip.NetworkProtocolNumber) { + c.t.Helper() + var err *tcpip.Error - c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &c.wq) + c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, proto, &c.wq) if err != nil { - c.t.Fatalf("NewEndpoint failed: %v", err) + c.t.Fatal("NewEndpoint failed: ", err) } +} - var v tcpip.V6OnlyOption - if v6only { - v = 1 - } - if err := c.ep.SetSockOpt(v); err != nil { - c.t.Fatalf("SetSockOpt failed failed: %v", err) +func (c *testContext) createEndpointForFlow(flow testFlow) { + c.t.Helper() + + 
c.createEndpoint(flow.sockProto()) + if flow.isV6Only() { + if err := c.ep.SetSockOpt(tcpip.V6OnlyOption(1)); err != nil { + c.t.Fatalf("SetSockOpt failed: %v", err) + } + } else if flow.isBroadcast() { + if err := c.ep.SetSockOpt(tcpip.BroadcastOption(1)); err != nil { + c.t.Fatal("SetSockOpt failed:", err) + } } } -func (c *testContext) getPacket(protocolNumber tcpip.NetworkProtocolNumber, multicast bool) []byte { +// getPacketAndVerify reads a packet from the link endpoint and verifies the +// header against expected values from the given test flow. In addition, it +// calls any extra checker functions provided. +func (c *testContext) getPacketAndVerify(flow testFlow, checkers ...checker.NetworkChecker) []byte { + c.t.Helper() + select { case p := <-c.linkEP.C: - if p.Proto != protocolNumber { - c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, protocolNumber) + if p.Proto != flow.netProto() { + c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, flow.netProto()) } b := make([]byte, len(p.Header)+len(p.Payload)) copy(b, p.Header) copy(b[len(p.Header):], p.Payload) - var checkerFn func(*testing.T, []byte, ...checker.NetworkChecker) - var srcAddr, dstAddr tcpip.Address - switch protocolNumber { - case ipv4.ProtocolNumber: - checkerFn = checker.IPv4 - srcAddr, dstAddr = stackAddr, testAddr - if multicast { - dstAddr = multicastAddr - } - case ipv6.ProtocolNumber: - checkerFn = checker.IPv6 - srcAddr, dstAddr = stackV6Addr, testV6Addr - if multicast { - dstAddr = multicastV6Addr - } - default: - c.t.Fatalf("unknown protocol %d", protocolNumber) - } - checkerFn(c.t, b, checker.SrcAddr(srcAddr), checker.DstAddr(dstAddr)) + h := flow.header4Tuple(outgoing) + checkers := append( + checkers, + checker.SrcAddr(h.srcAddr.Addr), + checker.DstAddr(h.dstAddr.Addr), + checker.UDP(checker.DstPort(h.dstAddr.Port)), + ) + flow.checkerFn()(c.t, b, checkers...) return b case <-time.After(2 * time.Second): @@ -171,7 +373,22 @@ func (c *testContext) getPacket(protocolNumber tcpip.NetworkProtocolNumber, mult return nil } -func (c *testContext) sendV6Packet(payload []byte, h *headers) { +// injectPacket creates a packet of the given flow and with the given payload, +// and injects it into the link endpoint. +func (c *testContext) injectPacket(flow testFlow, payload []byte) { + c.t.Helper() + + h := flow.header4Tuple(incoming) + if flow.isV4() { + c.injectV4Packet(payload, &h) + } else { + c.injectV6Packet(payload, &h) + } +} + +// injectV6Packet creates a V6 test packet with the given payload and header +// values, and injects it into the link endpoint. +func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple) { // Allocate a buffer for data and headers. buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload)) copy(buf[len(buf)-len(payload):], payload) @@ -182,20 +399,20 @@ func (c *testContext) sendV6Packet(payload []byte, h *headers) { PayloadLength: uint16(header.UDPMinimumSize + len(payload)), NextHeader: uint8(udp.ProtocolNumber), HopLimit: 65, - SrcAddr: testV6Addr, - DstAddr: stackV6Addr, + SrcAddr: h.srcAddr.Addr, + DstAddr: h.dstAddr.Addr, }) // Initialize the UDP header. u := header.UDP(buf[header.IPv6MinimumSize:]) u.Encode(&header.UDPFields{ - SrcPort: h.srcPort, - DstPort: h.dstPort, + SrcPort: h.srcAddr.Port, + DstPort: h.dstAddr.Port, Length: uint16(header.UDPMinimumSize + len(payload)), }) // Calculate the UDP pseudo-header checksum. 
- xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, testV6Addr, stackV6Addr, uint16(len(u))) + xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, h.srcAddr.Addr, h.dstAddr.Addr, uint16(len(u))) // Calculate the UDP checksum and set it. xsum = header.Checksum(payload, xsum) @@ -205,7 +422,9 @@ func (c *testContext) sendV6Packet(payload []byte, h *headers) { c.linkEP.Inject(ipv6.ProtocolNumber, buf.ToVectorisedView()) } -func (c *testContext) sendPacket(payload []byte, h *headers) { +// injectV4Packet creates a V4 test packet with the given payload and header +// values, and injects it into the link endpoint. +func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple) { // Allocate a buffer for data and headers. buf := buffer.NewView(header.UDPMinimumSize + header.IPv4MinimumSize + len(payload)) copy(buf[len(buf)-len(payload):], payload) @@ -217,21 +436,21 @@ func (c *testContext) sendPacket(payload []byte, h *headers) { TotalLength: uint16(len(buf)), TTL: 65, Protocol: uint8(udp.ProtocolNumber), - SrcAddr: testAddr, - DstAddr: stackAddr, + SrcAddr: h.srcAddr.Addr, + DstAddr: h.dstAddr.Addr, }) ip.SetChecksum(^ip.CalculateChecksum()) // Initialize the UDP header. u := header.UDP(buf[header.IPv4MinimumSize:]) u.Encode(&header.UDPFields{ - SrcPort: h.srcPort, - DstPort: h.dstPort, + SrcPort: h.srcAddr.Port, + DstPort: h.dstAddr.Port, Length: uint16(header.UDPMinimumSize + len(payload)), }) // Calculate the UDP pseudo-header checksum. - xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, testAddr, stackAddr, uint16(len(u))) + xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, h.srcAddr.Addr, h.dstAddr.Addr, uint16(len(u))) // Calculate the UDP checksum and set it. xsum = header.Checksum(payload, xsum) @@ -253,7 +472,7 @@ func TestBindPortReuse(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() - c.createV6Endpoint(false) + c.createEndpoint(ipv6.ProtocolNumber) var eps [5]tcpip.Endpoint reusePortOpt := tcpip.ReusePortOption(1) @@ -296,9 +515,9 @@ func TestBindPortReuse(t *testing.T) { // Send a packet. port := uint16(i % nports) payload := newPayload() - c.sendV6Packet(payload, &headers{ - srcPort: testPort + port, - dstPort: stackPort, + c.injectV6Packet(payload, &header4Tuple{ + srcAddr: tcpip.FullAddress{Addr: testV6Addr, Port: testPort + port}, + dstAddr: tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort}, }) var addr tcpip.FullAddress @@ -333,13 +552,14 @@ func TestBindPortReuse(t *testing.T) { } } -func testV4Read(c *testContext) { - // Send a packet. +// testRead sends a packet of the given test flow into the stack by injecting it +// into the link endpoint. It then reads it from the UDP endpoint and verifies +// its correctness. +func testRead(c *testContext, flow testFlow) { + c.t.Helper() + payload := newPayload() - c.sendPacket(payload, &headers{ - srcPort: testPort, - dstPort: stackPort, - }) + c.injectPacket(flow, payload) // Try to receive the data. we, ch := waiter.NewChannelEntry(nil) @@ -363,8 +583,9 @@ func testV4Read(c *testContext) { } // Check the peer address. - if addr.Addr != testAddr { - c.t.Fatalf("Unexpected remote address: got %v, want %v", addr.Addr, testAddr) + h := flow.header4Tuple(incoming) + if addr.Addr != h.srcAddr.Addr { + c.t.Fatalf("Unexpected remote address: got %v, want %v", addr.Addr, h.srcAddr.Addr) } // Check the payload.
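Both inject helpers above finish the same way: seed the UDP checksum with header.PseudoHeaderChecksum, then fold the payload in with header.Checksum. The arithmetic underneath is the standard 16-bit one's-complement sum with end-around carry. A simplified standalone sketch; unlike the real code, it omits summing the UDP header itself (which is included with its checksum field zeroed):

package main

import "fmt"

// onesComplementSum adds b into sum using 16-bit one's-complement arithmetic,
// folding carries back into the low 16 bits.
func onesComplementSum(b []byte, sum uint32) uint16 {
	for i := 0; i+1 < len(b); i += 2 {
		sum += uint32(b[i])<<8 | uint32(b[i+1])
	}
	if len(b)%2 == 1 {
		sum += uint32(b[len(b)-1]) << 8 // odd trailing byte is zero-padded
	}
	for sum>>16 != 0 {
		sum = sum>>16 + sum&0xffff
	}
	return uint16(sum)
}

func main() {
	// IPv4 pseudo-header for UDP: src, dst, zero byte, protocol, UDP length.
	pseudo := []byte{
		10, 0, 0, 1, // src 10.0.0.1
		10, 0, 0, 2, // dst 10.0.0.2
		0, 17, // zero byte, protocol (UDP is 17)
		0, 12, // UDP length: 8-byte header + 4 payload bytes
	}
	payload := []byte{0xde, 0xad, 0xbe, 0xef}

	seed := onesComplementSum(pseudo, 0)
	// The transmitted checksum is the bitwise complement of the total sum.
	fmt.Printf("%#04x\n", ^onesComplementSum(payload, uint32(seed)))
}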
@@ -377,7 +598,7 @@ func TestBindEphemeralPort(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() - c.createV6Endpoint(false) + c.createEndpoint(ipv6.ProtocolNumber) if err := c.ep.Bind(tcpip.FullAddress{}); err != nil { t.Fatalf("ep.Bind(...) failed: %v", err) @@ -388,7 +609,7 @@ func TestBindReservedPort(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() - c.createV6Endpoint(false) + c.createEndpoint(ipv6.ProtocolNumber) if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil { c.t.Fatalf("Connect failed: %v", err) @@ -447,7 +668,7 @@ func TestV4ReadOnV6(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() - c.createV6Endpoint(false) + c.createEndpointForFlow(unicastV4in6) // Bind to wildcard. if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { @@ -455,29 +676,29 @@ func TestV4ReadOnV6(t *testing.T) { } // Test acceptance. - testV4Read(c) + testRead(c, unicastV4in6) } func TestV4ReadOnBoundToV4MappedWildcard(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() - c.createV6Endpoint(false) + c.createEndpointForFlow(unicastV4in6) // Bind to v4 mapped wildcard. - if err := c.ep.Bind(tcpip.FullAddress{Addr: V4MappedWildcardAddr, Port: stackPort}); err != nil { + if err := c.ep.Bind(tcpip.FullAddress{Addr: v4MappedWildcardAddr, Port: stackPort}); err != nil { c.t.Fatalf("Bind failed: %v", err) } // Test acceptance. - testV4Read(c) + testRead(c, unicastV4in6) } func TestV4ReadOnBoundToV4Mapped(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() - c.createV6Endpoint(false) + c.createEndpointForFlow(unicastV4in6) // Bind to local address. if err := c.ep.Bind(tcpip.FullAddress{Addr: stackV4MappedAddr, Port: stackPort}); err != nil { @@ -485,69 +706,29 @@ func TestV4ReadOnBoundToV4Mapped(t *testing.T) { } // Test acceptance. - testV4Read(c) + testRead(c, unicastV4in6) } func TestV6ReadOnV6(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() - c.createV6Endpoint(false) + c.createEndpointForFlow(unicastV6) // Bind to wildcard. if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { c.t.Fatalf("Bind failed: %v", err) } - // Send a packet. - payload := newPayload() - c.sendV6Packet(payload, &headers{ - srcPort: testPort, - dstPort: stackPort, - }) - - // Try to receive the data. - we, ch := waiter.NewChannelEntry(nil) - c.wq.EventRegister(&we, waiter.EventIn) - defer c.wq.EventUnregister(&we) - - var addr tcpip.FullAddress - v, _, err := c.ep.Read(&addr) - if err == tcpip.ErrWouldBlock { - // Wait for data to become available. - select { - case <-ch: - v, _, err = c.ep.Read(&addr) - if err != nil { - c.t.Fatalf("Read failed: %v", err) - } - - case <-time.After(1 * time.Second): - c.t.Fatalf("Timed out waiting for data") - } - } - - // Check the peer address. - if addr.Addr != testV6Addr { - c.t.Fatalf("Unexpected remote address: got %v, want %v", addr.Addr, testAddr) - } - - // Check the payload. - if !bytes.Equal(payload, v) { - c.t.Fatalf("Bad payload: got %x, want %x", v, payload) - } + // Test acceptance. + testRead(c, unicastV6) } func TestV4ReadOnV4(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() - // Create v4 UDP endpoint. - var err *tcpip.Error - c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &c.wq) - if err != nil { - c.t.Fatalf("NewEndpoint failed: %v", err) - } + c.createEndpointForFlow(unicastV4) // Bind to wildcard. 
if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { @@ -555,62 +736,123 @@ func TestV4ReadOnV4(t *testing.T) { } // Test acceptance. - testV4Read(c) + testRead(c, unicastV4) } -func testV4Write(c *testContext) uint16 { - // Write to V4 mapped address. - payload := buffer.View(newPayload()) - n, _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{ - To: &tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}, - }) - if err != nil { - c.t.Fatalf("Write failed: %v", err) +// TestReadOnBoundToMulticast checks that an endpoint can bind to a multicast +// address and receive data sent to that address. +func TestReadOnBoundToMulticast(t *testing.T) { + // FIXME(b/128189410): multicastV4in6 currently doesn't work as + // AddMembershipOption doesn't handle V4in6 addresses. + for _, flow := range []testFlow{multicastV4, multicastV6, multicastV6Only} { + t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpointForFlow(flow) + + // Bind to multicast address. + mcastAddr := flow.mapAddrIfApplicable(flow.getMcastAddr()) + if err := c.ep.Bind(tcpip.FullAddress{Addr: mcastAddr, Port: stackPort}); err != nil { + c.t.Fatal("Bind failed:", err) + } + + // Join multicast group. + ifoptSet := tcpip.AddMembershipOption{NIC: 1, MulticastAddr: mcastAddr} + if err := c.ep.SetSockOpt(ifoptSet); err != nil { + c.t.Fatal("SetSockOpt failed:", err) + } + + testRead(c, flow) + }) } - if n != uintptr(len(payload)) { - c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload)) +} + +// TestV4ReadOnBoundToBroadcast checks that an endpoint can bind to a broadcast +// address and receive broadcast data on it. +func TestV4ReadOnBoundToBroadcast(t *testing.T) { + for _, flow := range []testFlow{broadcast, broadcastIn6} { + t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpointForFlow(flow) + + // Bind to broadcast address. + bcastAddr := flow.mapAddrIfApplicable(broadcastAddr) + if err := c.ep.Bind(tcpip.FullAddress{Addr: bcastAddr, Port: stackPort}); err != nil { + c.t.Fatalf("Bind failed: %s", err) + } + + // Test acceptance. + testRead(c, flow) + }) } +} - // Check that we received the packet. - b := c.getPacket(ipv4.ProtocolNumber, false) - udp := header.UDP(header.IPv4(b).Payload()) - checker.IPv4(c.t, b, - checker.UDP( - checker.DstPort(testPort), - ), - ) +// testFailingWrite sends a packet of the given test flow into the UDP endpoint +// and verifies it fails with the provided error code. +func testFailingWrite(c *testContext, flow testFlow, wantErr *tcpip.Error) { + c.t.Helper() - // Check the payload. - if !bytes.Equal(payload, udp.Payload()) { - c.t.Fatalf("Bad payload: got %x, want %x", udp.Payload(), payload) + h := flow.header4Tuple(outgoing) + writeDstAddr := flow.mapAddrIfApplicable(h.dstAddr.Addr) + + payload := buffer.View(newPayload()) + _, _, gotErr := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{ + To: &tcpip.FullAddress{Addr: writeDstAddr, Port: h.dstAddr.Port}, + }) + if gotErr != wantErr { + c.t.Fatalf("Write returned unexpected error: got %v, want %v", gotErr, wantErr) } +} - return udp.SourcePort() +// testWrite sends a packet of the given test flow from the UDP endpoint to the +// flow's destination address:port. It then receives it from the link endpoint +// and verifies its correctness including any additional checker functions +// provided. 
+func testWrite(c *testContext, flow testFlow, checkers ...checker.NetworkChecker) uint16 { + c.t.Helper() + return testWriteInternal(c, flow, true, checkers...) } -func testV6Write(c *testContext) uint16 { - // Write to v6 address. +// testWriteWithoutDestination sends a packet of the given test flow from the +// UDP endpoint without giving a destination address:port. It then receives it +// from the link endpoint and verifies its correctness including any additional +// checker functions provided. +func testWriteWithoutDestination(c *testContext, flow testFlow, checkers ...checker.NetworkChecker) uint16 { + c.t.Helper() + return testWriteInternal(c, flow, false, checkers...) +} + +func testWriteInternal(c *testContext, flow testFlow, setDest bool, checkers ...checker.NetworkChecker) uint16 { + c.t.Helper() + + writeOpts := tcpip.WriteOptions{} + if setDest { + h := flow.header4Tuple(outgoing) + writeDstAddr := flow.mapAddrIfApplicable(h.dstAddr.Addr) + writeOpts = tcpip.WriteOptions{ + To: &tcpip.FullAddress{Addr: writeDstAddr, Port: h.dstAddr.Port}, + } + } payload := buffer.View(newPayload()) - n, _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{ - To: &tcpip.FullAddress{Addr: testV6Addr, Port: testPort}, - }) + n, _, err := c.ep.Write(tcpip.SlicePayload(payload), writeOpts) if err != nil { c.t.Fatalf("Write failed: %v", err) } - if n != uintptr(len(payload)) { + if n != int64(len(payload)) { c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload)) } - // Check that we received the packet. - b := c.getPacket(ipv6.ProtocolNumber, false) - udp := header.UDP(header.IPv6(b).Payload()) - checker.IPv6(c.t, b, - checker.UDP( - checker.DstPort(testPort), - ), - ) - - // Check the payload. + // Receive the packet and check the payload. + b := c.getPacketAndVerify(flow, checkers...) + var udp header.UDP + if flow.isV4() { + udp = header.UDP(header.IPv4(b).Payload()) + } else { + udp = header.UDP(header.IPv6(b).Payload()) + } if !bytes.Equal(payload, udp.Payload()) { c.t.Fatalf("Bad payload: got %x, want %x", udp.Payload(), payload) } @@ -619,8 +861,10 @@ func testV6Write(c *testContext) uint16 { return udp.SourcePort() } func testDualWrite(c *testContext) uint16 { - v4Port := testV4Write(c) - v6Port := testV6Write(c) + c.t.Helper() + + v4Port := testWrite(c, unicastV4in6) + v6Port := testWrite(c, unicastV6) if v4Port != v6Port { c.t.Fatalf("expected v4 and v6 ports to be equal: got v4Port = %d, v6Port = %d", v4Port, v6Port) } @@ -632,7 +876,7 @@ func TestDualWriteUnbound(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() - c.createV6Endpoint(false) + c.createEndpoint(ipv6.ProtocolNumber) testDualWrite(c) } @@ -641,7 +885,7 @@ func TestDualWriteBoundToWildcard(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() - c.createV6Endpoint(false) + c.createEndpoint(ipv6.ProtocolNumber) // Bind to wildcard. if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { @@ -658,69 +902,51 @@ func TestDualWriteConnectedToV6(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() - c.createV6Endpoint(false) + c.createEndpoint(ipv6.ProtocolNumber) // Connect to v6 address. if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil { c.t.Fatalf("Bind failed: %v", err) } - testV6Write(c) + testWrite(c, unicastV6) // Write to V4 mapped address.
- payload := buffer.View(newPayload()) - _, _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{ - To: &tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}, - }) - if err != tcpip.ErrNetworkUnreachable { - c.t.Fatalf("Write returned unexpected error: got %v, want %v", err, tcpip.ErrNetworkUnreachable) - } + testFailingWrite(c, unicastV4in6, tcpip.ErrNetworkUnreachable) } func TestDualWriteConnectedToV4Mapped(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() - c.createV6Endpoint(false) + c.createEndpoint(ipv6.ProtocolNumber) // Connect to v4 mapped address. if err := c.ep.Connect(tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}); err != nil { c.t.Fatalf("Bind failed: %v", err) } - testV4Write(c) + testWrite(c, unicastV4in6) // Write to v6 address. - payload := buffer.View(newPayload()) - _, _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{ - To: &tcpip.FullAddress{Addr: testV6Addr, Port: testPort}, - }) - if err != tcpip.ErrInvalidEndpointState { - c.t.Fatalf("Write returned unexpected error: got %v, want %v", err, tcpip.ErrInvalidEndpointState) - } + testFailingWrite(c, unicastV6, tcpip.ErrInvalidEndpointState) } func TestV4WriteOnV6Only(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() - c.createV6Endpoint(true) + c.createEndpointForFlow(unicastV6Only) // Write to V4 mapped address. - payload := buffer.View(newPayload()) - _, _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{ - To: &tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}, - }) - if err != tcpip.ErrNoRoute { - c.t.Fatalf("Write returned unexpected error: got %v, want %v", err, tcpip.ErrNoRoute) - } + testFailingWrite(c, unicastV4in6, tcpip.ErrNoRoute) } func TestV6WriteOnBoundToV4Mapped(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() - c.createV6Endpoint(false) + c.createEndpoint(ipv6.ProtocolNumber) // Bind to v4 mapped address. if err := c.ep.Bind(tcpip.FullAddress{Addr: stackV4MappedAddr, Port: stackPort}); err != nil { @@ -728,84 +954,154 @@ func TestV6WriteOnBoundToV4Mapped(t *testing.T) { } // Write to v6 address. - payload := buffer.View(newPayload()) - _, _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{ - To: &tcpip.FullAddress{Addr: testV6Addr, Port: testPort}, - }) - if err != tcpip.ErrInvalidEndpointState { - c.t.Fatalf("Write returned unexpected error: got %v, want %v", err, tcpip.ErrInvalidEndpointState) - } + testFailingWrite(c, unicastV6, tcpip.ErrInvalidEndpointState) } func TestV6WriteOnConnected(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() - c.createV6Endpoint(false) + c.createEndpoint(ipv6.ProtocolNumber) // Connect to v6 address. if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil { c.t.Fatalf("Connect failed: %v", err) } - // Write without destination. - payload := buffer.View(newPayload()) - n, _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{}) - if err != nil { - c.t.Fatalf("Write failed: %v", err) - } - if n != uintptr(len(payload)) { - c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload)) - } - - // Check that we received the packet. - b := c.getPacket(ipv6.ProtocolNumber, false) - udp := header.UDP(header.IPv6(b).Payload()) - checker.IPv6(c.t, b, - checker.UDP( - checker.DstPort(testPort), - ), - ) - - // Check the payload. 
- if !bytes.Equal(payload, udp.Payload()) { - c.t.Fatalf("Bad payload: got %x, want %x", udp.Payload(), payload) - } + testWriteWithoutDestination(c, unicastV6) } func TestV4WriteOnConnected(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() - c.createV6Endpoint(false) + c.createEndpoint(ipv6.ProtocolNumber) // Connect to v4 mapped address. if err := c.ep.Connect(tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}); err != nil { c.t.Fatalf("Connect failed: %v", err) } - // Write without destination. - payload := buffer.View(newPayload()) - n, _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{}) - if err != nil { - c.t.Fatalf("Write failed: %v", err) + testWriteWithoutDestination(c, unicastV4) +} + +// TestWriteOnBoundToV4Multicast checks that we can send packets out of a socket +// that is bound to a V4 multicast address. +func TestWriteOnBoundToV4Multicast(t *testing.T) { + for _, flow := range []testFlow{unicastV4, multicastV4, broadcast} { + t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpointForFlow(flow) + + // Bind to V4 mcast address. + if err := c.ep.Bind(tcpip.FullAddress{Addr: multicastAddr, Port: stackPort}); err != nil { + c.t.Fatal("Bind failed:", err) + } + + testWrite(c, flow) + }) } - if n != uintptr(len(payload)) { - c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload)) +} + +// TestWriteOnBoundToV4MappedMulticast checks that we can send packets out of a +// socket that is bound to a V4-mapped multicast address. +func TestWriteOnBoundToV4MappedMulticast(t *testing.T) { + for _, flow := range []testFlow{unicastV4in6, multicastV4in6, broadcastIn6} { + t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpointForFlow(flow) + + // Bind to V4Mapped mcast address. + if err := c.ep.Bind(tcpip.FullAddress{Addr: multicastV4MappedAddr, Port: stackPort}); err != nil { + c.t.Fatalf("Bind failed: %s", err) + } + + testWrite(c, flow) + }) } +} - // Check that we received the packet. - b := c.getPacket(ipv4.ProtocolNumber, false) - udp := header.UDP(header.IPv4(b).Payload()) - checker.IPv4(c.t, b, - checker.UDP( - checker.DstPort(testPort), - ), - ) +// TestWriteOnBoundToV6Multicast checks that we can send packets out of a +// socket that is bound to a V6 multicast address. +func TestWriteOnBoundToV6Multicast(t *testing.T) { + for _, flow := range []testFlow{unicastV6, multicastV6} { + t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() - // Check the payload. - if !bytes.Equal(payload, udp.Payload()) { - c.t.Fatalf("Bad payload: got %x, want %x", udp.Payload(), payload) + c.createEndpointForFlow(flow) + + // Bind to V6 mcast address. + if err := c.ep.Bind(tcpip.FullAddress{Addr: multicastV6Addr, Port: stackPort}); err != nil { + c.t.Fatalf("Bind failed: %s", err) + } + + testWrite(c, flow) + }) + } +} + +// TestWriteOnBoundToV6OnlyMulticast checks that we can send packets out of a +// V6-only socket that is bound to a V6 multicast address. +func TestWriteOnBoundToV6OnlyMulticast(t *testing.T) { + for _, flow := range []testFlow{unicastV6Only, multicastV6Only} { + t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpointForFlow(flow) + + // Bind to V6 mcast address.
+			if err := c.ep.Bind(tcpip.FullAddress{Addr: multicastV6Addr, Port: stackPort}); err != nil {
+				c.t.Fatalf("Bind failed: %s", err)
+			}
+
+			testWrite(c, flow)
+		})
+	}
+}
+
+// TestWriteOnBoundToBroadcast checks that we can send packets out of a
+// socket that is bound to the broadcast address.
+func TestWriteOnBoundToBroadcast(t *testing.T) {
+	for _, flow := range []testFlow{unicastV4, multicastV4, broadcast} {
+		t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) {
+			c := newDualTestContext(t, defaultMTU)
+			defer c.cleanup()
+
+			c.createEndpointForFlow(flow)
+
+			// Bind to V4 broadcast address.
+			if err := c.ep.Bind(tcpip.FullAddress{Addr: broadcastAddr, Port: stackPort}); err != nil {
+				c.t.Fatal("Bind failed:", err)
+			}
+
+			testWrite(c, flow)
+		})
+	}
+}
+
+// TestWriteOnBoundToV4MappedBroadcast checks that we can send packets out of a
+// socket that is bound to the V4-mapped broadcast address.
+func TestWriteOnBoundToV4MappedBroadcast(t *testing.T) {
+	for _, flow := range []testFlow{unicastV4in6, multicastV4in6, broadcastIn6} {
+		t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) {
+			c := newDualTestContext(t, defaultMTU)
+			defer c.cleanup()
+
+			c.createEndpointForFlow(flow)
+
+			// Bind to V4-mapped broadcast address.
+			if err := c.ep.Bind(tcpip.FullAddress{Addr: broadcastV4MappedAddr, Port: stackPort}); err != nil {
+				c.t.Fatalf("Bind failed: %s", err)
+			}
+
+			testWrite(c, flow)
+		})
	}
 }
@@ -814,18 +1110,14 @@ func TestReadIncrementsPacketsReceived(t *testing.T) {
 	defer c.cleanup()
 
-	// Create IPv4 UDP endpoint
-	var err *tcpip.Error
-	c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &c.wq)
-	if err != nil {
-		c.t.Fatalf("NewEndpoint failed: %v", err)
-	}
+	// Create a dual-stack UDP endpoint.
+	c.createEndpoint(ipv6.ProtocolNumber)
 
 	// Bind to wildcard.
 	if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
 		c.t.Fatalf("Bind failed: %v", err)
 	}
 
-	testV4Read(c)
+	testRead(c, unicastV4)
 
 	var want uint64 = 1
 	if got := c.s.Stats().UDP.PacketsReceived.Value(); got != want {
@@ -837,7 +1129,7 @@ func TestWriteIncrementsPacketsSent(t *testing.T) {
 	c := newDualTestContext(t, defaultMTU)
 	defer c.cleanup()
 
-	c.createV6Endpoint(false)
+	c.createEndpoint(ipv6.ProtocolNumber)
 
 	testDualWrite(c)
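[Editor's note: the next hunk replaces the setSockOptVariants helper, which threaded "v4"/"v6"/"dual"/"mapped" strings through nested switches, with a single testFlow value that each test iterates over in t.Run subtests. As a rough illustration of the pattern only — the names and methods below are simplified stand-ins, not the actual udp_test.go definitions introduced by this change — a minimal self-contained version looks like:

// Illustrative sketch of the table-driven subtest pattern; testFlow's
// real definition and helpers live elsewhere in this diff.
package udp_test

import "testing"

type testFlow int

const (
	unicastV4 testFlow = iota
	multicastV4
	broadcast
)

// String gives each subtest a stable, filterable name.
func (f testFlow) String() string {
	return [...]string{"unicastV4", "multicastV4", "broadcast"}[f]
}

func (f testFlow) isMulticast() bool { return f == multicastV4 }

func TestPerFlowSketch(t *testing.T) {
	for _, flow := range []testFlow{unicastV4, multicastV4, broadcast} {
		t.Run(flow.String(), func(t *testing.T) {
			// Flow-specific branching replaces the old string switches.
			if flow.isMulticast() {
				t.Log("would set a multicast-specific option here")
			}
		})
	}
}

Each flow becomes a named subtest, so failures are reported per flow and a single case can be selected with, e.g., go test -run 'TestPerFlowSketch/multicastV4'.]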
@@ -847,244 +1139,102 @@ func TestWriteIncrementsPacketsSent(t *testing.T) {
 	}
 }
 
-func setSockOptVariants(t *testing.T, optFunc func(*testing.T, string, tcpip.NetworkProtocolNumber, string)) {
-	for _, name := range []string{"v4", "v6", "dual"} {
-		t.Run(name, func(t *testing.T) {
-			var networkProtocolNumber tcpip.NetworkProtocolNumber
-			switch name {
-			case "v4":
-				networkProtocolNumber = ipv4.ProtocolNumber
-			case "v6", "dual":
-				networkProtocolNumber = ipv6.ProtocolNumber
-			default:
-				t.Fatal("unknown test variant")
-			}
-
-			var variants []string
-			switch name {
-			case "v4":
-				variants = []string{"v4"}
-			case "v6":
-				variants = []string{"v6"}
-			case "dual":
-				variants = []string{"v6", "mapped"}
-			}
-
-			for _, variant := range variants {
-				t.Run(variant, func(t *testing.T) {
-					optFunc(t, name, networkProtocolNumber, variant)
-				})
-			}
-		})
-	}
-}
-
-func TestTTL(t *testing.T) {
-	payload := tcpip.SlicePayload(buffer.View(newPayload()))
-
-	setSockOptVariants(t, func(t *testing.T, name string, networkProtocolNumber tcpip.NetworkProtocolNumber, variant string) {
-		for _, typ := range []string{"unicast", "multicast"} {
-			t.Run(typ, func(t *testing.T) {
-				var addr tcpip.Address
-				var port uint16
-				switch typ {
-				case "unicast":
-					port = testPort
-					switch variant {
-					case "v4":
-						addr = testAddr
-					case "mapped":
-						addr = testV4MappedAddr
-					case "v6":
-						addr = testV6Addr
-					default:
-						t.Fatal("unknown test variant")
-					}
-				case "multicast":
-					port = multicastPort
-					switch variant {
-					case "v4":
-						addr = multicastAddr
-					case "mapped":
-						addr = multicastV4MappedAddr
-					case "v6":
-						addr = multicastV6Addr
-					default:
-						t.Fatal("unknown test variant")
-					}
-				default:
-					t.Fatal("unknown test variant")
-				}
-
-				c := newDualTestContext(t, defaultMTU)
-				defer c.cleanup()
-
-				var err *tcpip.Error
-				c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, networkProtocolNumber, &c.wq)
-				if err != nil {
-					c.t.Fatalf("NewEndpoint failed: %v", err)
-				}
-
-				switch name {
-				case "v4":
-				case "v6":
-					if err := c.ep.SetSockOpt(tcpip.V6OnlyOption(1)); err != nil {
-						c.t.Fatalf("SetSockOpt failed: %v", err)
-					}
-				case "dual":
-					if err := c.ep.SetSockOpt(tcpip.V6OnlyOption(0)); err != nil {
-						c.t.Fatalf("SetSockOpt failed: %v", err)
-					}
-				default:
-					t.Fatal("unknown test variant")
-				}
-
-				const multicastTTL = 42
-				if err := c.ep.SetSockOpt(tcpip.MulticastTTLOption(multicastTTL)); err != nil {
-					c.t.Fatalf("SetSockOpt failed: %v", err)
-				}
-
-				n, _, err := c.ep.Write(payload, tcpip.WriteOptions{To: &tcpip.FullAddress{Addr: addr, Port: port}})
-				if err != nil {
-					c.t.Fatalf("Write failed: %v", err)
-				}
-				if n != uintptr(len(payload)) {
-					c.t.Fatalf("got c.ep.Write(...) = %d, want = %d", n, len(payload))
-				}
-
-				checkerFn := checker.IPv4
-				switch variant {
-				case "v4", "mapped":
-				case "v6":
-					checkerFn = checker.IPv6
-				default:
-					t.Fatal("unknown test variant")
-				}
-				var wantTTL uint8
-				var multicast bool
-				switch typ {
-				case "unicast":
-					multicast = false
-					switch variant {
-					case "v4", "mapped":
-						ep, err := ipv4.NewProtocol().NewEndpoint(0, tcpip.AddressWithPrefix{}, nil, nil, nil)
-						if err != nil {
-							t.Fatal(err)
-						}
-						wantTTL = ep.DefaultTTL()
-						ep.Close()
-					case "v6":
-						ep, err := ipv6.NewProtocol().NewEndpoint(0, tcpip.AddressWithPrefix{}, nil, nil, nil)
-						if err != nil {
-							t.Fatal(err)
-						}
-						wantTTL = ep.DefaultTTL()
-						ep.Close()
-					default:
-						t.Fatal("unknown test variant")
-					}
-				case "multicast":
-					wantTTL = multicastTTL
-					multicast = true
-				default:
-					t.Fatal("unknown test variant")
-				}
-
-				var networkProtocolNumber tcpip.NetworkProtocolNumber
-				switch variant {
-				case "v4", "mapped":
-					networkProtocolNumber = ipv4.ProtocolNumber
-				case "v6":
-					networkProtocolNumber = ipv6.ProtocolNumber
-				default:
-					t.Fatal("unknown test variant")
-				}
+func TestTTL(t *testing.T) {
+	for _, flow := range []testFlow{unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6} {
+		t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
+			c := newDualTestContext(t, defaultMTU)
+			defer c.cleanup()
+
+			c.createEndpointForFlow(flow)
+
+			const multicastTTL = 42
+			if err := c.ep.SetSockOpt(tcpip.MulticastTTLOption(multicastTTL)); err != nil {
+				c.t.Fatalf("SetSockOpt failed: %v", err)
+			}
+
+			var wantTTL uint8
+			if flow.isMulticast() {
+				wantTTL = multicastTTL
+			} else {
+				var p stack.NetworkProtocol
+				if flow.isV4() {
+					p = ipv4.NewProtocol()
+				} else {
+					p = ipv6.NewProtocol()
+				}
+				ep, err := p.NewEndpoint(0, tcpip.AddressWithPrefix{}, nil, nil, nil)
+				if err != nil {
+					t.Fatal(err)
+				}
+				wantTTL = ep.DefaultTTL()
+				ep.Close()
+			}
+
+			testWrite(c, flow, checker.TTL(wantTTL))
+		})
+	}
+}
 
+func TestMulticastInterfaceOption(t *testing.T) {
+	for _, flow := range []testFlow{multicastV4, multicastV4in6, multicastV6, multicastV6Only} {
+		t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
+			for _, bindTyp := range []string{"bound", "unbound"} {
+				t.Run(bindTyp, func(t *testing.T) {
+					for _, optTyp := range []string{"use local-addr", "use NICID", "use local-addr and NIC"} {
+						t.Run(optTyp, func(t *testing.T) {
+							h := flow.header4Tuple(outgoing)
+							mcastAddr := h.dstAddr.Addr
+							localIfAddr := h.srcAddr.Addr
+
+							var ifoptSet tcpip.MulticastInterfaceOption
+							switch optTyp {
+							case "use local-addr":
+								ifoptSet.InterfaceAddr = localIfAddr
+							case "use NICID":
+								ifoptSet.NIC = 1
+							case "use local-addr and NIC":
+								ifoptSet.InterfaceAddr = localIfAddr
+								ifoptSet.NIC = 1
+							default:
+								t.Fatal("unknown test variant")
+							}
+
+							c := newDualTestContext(t, defaultMTU)
+							defer c.cleanup()
+
+							c.createEndpoint(flow.sockProto())
+
+							if bindTyp == "bound" {
+								// Bind the socket by connecting to the multicast address.
+								// This may have an influence on how the multicast interface
+								// is set.
+								addr := tcpip.FullAddress{
+									Addr: flow.mapAddrIfApplicable(mcastAddr),
+									Port: stackPort,
+								}
+								if err := c.ep.Connect(addr); err != nil {
+									c.t.Fatalf("Connect failed: %v", err)
+								}
+							}
+
+							if err := c.ep.SetSockOpt(ifoptSet); err != nil {
+								c.t.Fatalf("SetSockOpt failed: %v", err)
+							}
+
+							// Verify multicast interface addr and NIC were set correctly.
+							// Note that NIC must be 1 since this is our outgoing interface.
+							ifoptWant := tcpip.MulticastInterfaceOption{NIC: 1, InterfaceAddr: ifoptSet.InterfaceAddr}
+							var ifoptGot tcpip.MulticastInterfaceOption
+							if err := c.ep.GetSockOpt(&ifoptGot); err != nil {
+								c.t.Fatalf("GetSockOpt failed: %v", err)
+							}
+							if ifoptGot != ifoptWant {
+								c.t.Errorf("got GetSockOpt() = %#v, want = %#v", ifoptGot, ifoptWant)
+							}
 
-				b := c.getPacket(networkProtocolNumber, multicast)
-				checkerFn(c.t, b,
-					checker.TTL(wantTTL),
-					checker.UDP(
-						checker.DstPort(port),
-					),
-				)
-			})
-		}
-	})
-}
-
-func TestMulticastInterfaceOption(t *testing.T) {
-	setSockOptVariants(t, func(t *testing.T, name string, networkProtocolNumber tcpip.NetworkProtocolNumber, variant string) {
-		for _, bindTyp := range []string{"bound", "unbound"} {
-			t.Run(bindTyp, func(t *testing.T) {
-				for _, optTyp := range []string{"use local-addr", "use NICID", "use local-addr and NIC"} {
-					t.Run(optTyp, func(t *testing.T) {
-						var mcastAddr, localIfAddr tcpip.Address
-						switch variant {
-						case "v4":
-							mcastAddr = multicastAddr
-							localIfAddr = stackAddr
-						case "mapped":
-							mcastAddr = multicastV4MappedAddr
-							localIfAddr = stackAddr
-						case "v6":
-							mcastAddr = multicastV6Addr
-							localIfAddr = stackV6Addr
-						default:
-							t.Fatal("unknown test variant")
-						}
-
-						var ifoptSet tcpip.MulticastInterfaceOption
-						switch optTyp {
-						case "use local-addr":
-							ifoptSet.InterfaceAddr = localIfAddr
-						case "use NICID":
-							ifoptSet.NIC = 1
-						case "use local-addr and NIC":
-							ifoptSet.InterfaceAddr = localIfAddr
-							ifoptSet.NIC = 1
-						default:
-							t.Fatal("unknown test variant")
-						}
-
-						c := newDualTestContext(t, defaultMTU)
-						defer c.cleanup()
-
-						var err *tcpip.Error
-						c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, networkProtocolNumber, &c.wq)
-						if err != nil {
-							c.t.Fatalf("NewEndpoint failed: %v", err)
-						}
-
-						if bindTyp == "bound" {
-							// Bind the socket by connecting to the multicast address.
-							// This may have an influence on how the multicast interface
-							// is set.
-							addr := tcpip.FullAddress{
-								Addr: mcastAddr,
-								Port: multicastPort,
-							}
-							if err := c.ep.Connect(addr); err != nil {
-								c.t.Fatalf("Connect failed: %v", err)
-							}
-						}
-
-						if err := c.ep.SetSockOpt(ifoptSet); err != nil {
-							c.t.Fatalf("SetSockOpt failed: %v", err)
-						}
-
-						// Verify multicast interface addr and NIC were set correctly.
-						// Note that NIC must be 1 since this is our outgoing interface.
-						ifoptWant := tcpip.MulticastInterfaceOption{NIC: 1, InterfaceAddr: ifoptSet.InterfaceAddr}
-						var ifoptGot tcpip.MulticastInterfaceOption
-						if err := c.ep.GetSockOpt(&ifoptGot); err != nil {
-							c.t.Fatalf("GetSockOpt failed: %v", err)
-						}
-						if ifoptGot != ifoptWant {
-							c.t.Errorf("got GetSockOpt() = %#v, want = %#v", ifoptGot, ifoptWant)
-						}
-					})
-				}
-			})
-		}
-	})
+						})
+					}
+				})
+			}
+		})
+	}
 }
diff --git a/pkg/unet/BUILD b/pkg/unet/BUILD
index 769509e80..cbd92fc05 100644
--- a/pkg/unet/BUILD
+++ b/pkg/unet/BUILD
@@ -11,8 +11,8 @@ go_library(
     importpath = "gvisor.dev/gvisor/pkg/unet",
     visibility = ["//visibility:public"],
     deps = [
-        "//pkg/abi/linux",
         "//pkg/gate",
+        "@org_golang_x_sys//unix:go_default_library",
     ],
 )
diff --git a/pkg/unet/unet_unsafe.go b/pkg/unet/unet_unsafe.go
index f8a42c914..85ef46edf 100644
--- a/pkg/unet/unet_unsafe.go
+++ b/pkg/unet/unet_unsafe.go
@@ -16,12 +16,11 @@ package unet
 
 import (
 	"io"
-	"math"
 	"sync/atomic"
 	"syscall"
 	"unsafe"
 
-	"gvisor.dev/gvisor/pkg/abi/linux"
+	"golang.org/x/sys/unix"
 )
 
 // wait blocks until the socket FD is ready for reading or writing, depending
@@ -37,23 +36,23 @@ func (s *Socket) wait(write bool) error {
 			return errClosing
 		}
 
-		events := []linux.PollFD{
+		events := []unix.PollFd{
 			{
 				// The actual socket FD.
-				FD:     fd,
-				Events: linux.POLLIN,
+				Fd:     fd,
+				Events: unix.POLLIN,
 			},
 			{
 				// The eventfd, signaled when we are closing.
-				FD:     int32(s.efd),
-				Events: linux.POLLIN,
+				Fd:     int32(s.efd),
+				Events: unix.POLLIN,
 			},
 		}
 		if write {
-			events[0].Events = linux.POLLOUT
+			events[0].Events = unix.POLLOUT
 		}
 
-		_, _, e := syscall.Syscall(syscall.SYS_POLL, uintptr(unsafe.Pointer(&events[0])), 2, uintptr(math.MaxUint64))
+		_, _, e := syscall.Syscall6(syscall.SYS_PPOLL, uintptr(unsafe.Pointer(&events[0])), 2, 0, 0, 0, 0)
 		if e == syscall.EINTR {
 			continue
 		}
@@ -61,7 +60,7 @@
 			return e
 		}
 
-		if events[1].REvents&linux.POLLIN == linux.POLLIN {
+		if events[1].Revents&unix.POLLIN == unix.POLLIN {
 			// eventfd signaled, we're closing.
 			return errClosing
 		}
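[Editor's note: the unet change above drops the hand-written poll(2) structures from pkg/abi/linux in favor of golang.org/x/sys/unix, and switches the raw syscall from SYS_POLL to SYS_PPOLL. With a zero (NULL) timeout and signal-mask argument, ppoll blocks indefinitely, matching the old poll call's -1 timeout; ppoll is also the only variant available on some architectures, such as arm64, which do not expose SYS_POLL at all. The x/sys/unix package additionally provides a Ppoll wrapper around the same syscall; the diff keeps the raw Syscall6 form, but a minimal sketch of an equivalent wait loop using the wrapper — not gVisor's code — might look like:

// Sketch: block until fd is readable using unix.Ppoll, retrying on EINTR.
// A nil *Timespec means no timeout, and a nil sigmask leaves the signal
// mask untouched, matching plain poll(2) semantics.
package main

import (
	"fmt"
	"os"

	"golang.org/x/sys/unix"
)

func waitReadable(fd int32) error {
	events := []unix.PollFd{{Fd: fd, Events: unix.POLLIN}}
	for {
		n, err := unix.Ppoll(events, nil, nil)
		if err == unix.EINTR {
			continue // interrupted by a signal; retry
		}
		if err != nil {
			return err
		}
		if n > 0 && events[0].Revents&unix.POLLIN != 0 {
			return nil
		}
	}
}

func main() {
	if err := waitReadable(int32(os.Stdin.Fd())); err != nil {
		fmt.Fprintln(os.Stderr, err)
		os.Exit(1)
	}
	fmt.Println("stdin is readable")
}
]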