From 26adb3c4747288aba2475cb403d66d68481793dc Mon Sep 17 00:00:00 2001 From: Fabricio Voznika Date: Wed, 28 Apr 2021 17:00:47 -0700 Subject: Automated rollback of changelist 369686285 PiperOrigin-RevId: 371015541 --- pkg/p9/client_file.go | 16 ++++++++++ pkg/p9/file.go | 60 +++++++++++++++++++++++++++++++++++- pkg/p9/handlers.go | 28 +++++++++++++++++ pkg/p9/messages.go | 84 ++++++++++++++++++++++++++++++++++++++++++++++++--- pkg/p9/p9.go | 28 +++++++++++++++++ pkg/p9/version.go | 8 ++++- 6 files changed, 218 insertions(+), 6 deletions(-) (limited to 'pkg/p9') diff --git a/pkg/p9/client_file.go b/pkg/p9/client_file.go index 7abc82e1b..28396b0ea 100644 --- a/pkg/p9/client_file.go +++ b/pkg/p9/client_file.go @@ -121,6 +121,22 @@ func (c *clientFile) WalkGetAttr(components []string) ([]QID, File, AttrMask, At return rwalkgetattr.QIDs, c.client.newFile(FID(fid)), rwalkgetattr.Valid, rwalkgetattr.Attr, nil } +func (c *clientFile) MultiGetAttr(names []string) ([]FullStat, error) { + if atomic.LoadUint32(&c.closed) != 0 { + return nil, unix.EBADF + } + + if !versionSupportsTmultiGetAttr(c.client.version) { + return DefaultMultiGetAttr(c, names) + } + + rmultigetattr := Rmultigetattr{} + if err := c.client.sendRecv(&Tmultigetattr{FID: c.fid, Names: names}, &rmultigetattr); err != nil { + return nil, err + } + return rmultigetattr.Stats, nil +} + // StatFS implements File.StatFS. func (c *clientFile) StatFS() (FSStat, error) { if atomic.LoadUint32(&c.closed) != 0 { diff --git a/pkg/p9/file.go b/pkg/p9/file.go index c59c6a65b..97e0231d6 100644 --- a/pkg/p9/file.go +++ b/pkg/p9/file.go @@ -15,6 +15,8 @@ package p9 import ( + "errors" + "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/fd" ) @@ -72,6 +74,15 @@ type File interface { // On the server, WalkGetAttr has a read concurrency guarantee. WalkGetAttr([]string) ([]QID, File, AttrMask, Attr, error) + // MultiGetAttr batches up multiple calls to GetAttr(). names is a list of + // path components similar to Walk(). If the first component name is empty, + // the current file is stat'd and included in the results. If the walk reaches + // a file that doesn't exist or not a directory, MultiGetAttr returns the + // partial result with no error. + // + // On the server, MultiGetAttr has a read concurrency guarantee. + MultiGetAttr(names []string) ([]FullStat, error) + // StatFS returns information about the file system associated with // this file. // @@ -306,6 +317,53 @@ func (DisallowClientCalls) SetAttrClose(SetAttrMask, SetAttr) error { type DisallowServerCalls struct{} // Renamed implements File.Renamed. -func (*clientFile) Renamed(File, string) { +func (*DisallowServerCalls) Renamed(File, string) { panic("Renamed should not be called on the client") } + +// DefaultMultiGetAttr implements File.MultiGetAttr() on top of File. +func DefaultMultiGetAttr(start File, names []string) ([]FullStat, error) { + stats := make([]FullStat, 0, len(names)) + parent := start + mask := AttrMaskAll() + for i, name := range names { + if len(name) == 0 && i == 0 { + qid, valid, attr, err := parent.GetAttr(mask) + if err != nil { + return nil, err + } + stats = append(stats, FullStat{ + QID: qid, + Valid: valid, + Attr: attr, + }) + continue + } + qids, child, valid, attr, err := parent.WalkGetAttr([]string{name}) + if parent != start { + _ = parent.Close() + } + if err != nil { + if errors.Is(err, unix.ENOENT) { + return stats, nil + } + return nil, err + } + stats = append(stats, FullStat{ + QID: qids[0], + Valid: valid, + Attr: attr, + }) + if attr.Mode.FileType() != ModeDirectory { + // Doesn't need to continue if entry is not a dir. Including symlinks + // that cannot be followed. + _ = child.Close() + break + } + parent = child + } + if parent != start { + _ = parent.Close() + } + return stats, nil +} diff --git a/pkg/p9/handlers.go b/pkg/p9/handlers.go index 58312d0cc..758e11b13 100644 --- a/pkg/p9/handlers.go +++ b/pkg/p9/handlers.go @@ -1421,3 +1421,31 @@ func (t *Tchannel) handle(cs *connState) message { } return rchannel } + +// handle implements handler.handle. +func (t *Tmultigetattr) handle(cs *connState) message { + for i, name := range t.Names { + if len(name) == 0 && i == 0 { + // Empty name is allowed on the first entry to indicate that the current + // FID needs to be included in the result. + continue + } + if err := checkSafeName(name); err != nil { + return newErr(err) + } + } + ref, ok := cs.LookupFID(t.FID) + if !ok { + return newErr(unix.EBADF) + } + defer ref.DecRef() + + var stats []FullStat + if err := ref.safelyRead(func() (err error) { + stats, err = ref.file.MultiGetAttr(t.Names) + return err + }); err != nil { + return newErr(err) + } + return &Rmultigetattr{Stats: stats} +} diff --git a/pkg/p9/messages.go b/pkg/p9/messages.go index cf13cbb69..2ff4694c0 100644 --- a/pkg/p9/messages.go +++ b/pkg/p9/messages.go @@ -254,8 +254,8 @@ func (r *Rwalk) decode(b *buffer) { // encode implements encoder.encode. func (r *Rwalk) encode(b *buffer) { b.Write16(uint16(len(r.QIDs))) - for _, q := range r.QIDs { - q.encode(b) + for i := range r.QIDs { + r.QIDs[i].encode(b) } } @@ -2243,8 +2243,8 @@ func (r *Rwalkgetattr) encode(b *buffer) { r.Valid.encode(b) r.Attr.encode(b) b.Write16(uint16(len(r.QIDs))) - for _, q := range r.QIDs { - q.encode(b) + for i := range r.QIDs { + r.QIDs[i].encode(b) } } @@ -2552,6 +2552,80 @@ func (r *Rchannel) String() string { return fmt.Sprintf("Rchannel{Offset: %d, Length: %d}", r.Offset, r.Length) } +// Tmultigetattr is a multi-getattr request. +type Tmultigetattr struct { + // FID is the FID to be walked. + FID FID + + // Names are the set of names to be walked. + Names []string +} + +// decode implements encoder.decode. +func (t *Tmultigetattr) decode(b *buffer) { + t.FID = b.ReadFID() + n := b.Read16() + t.Names = t.Names[:0] + for i := 0; i < int(n); i++ { + t.Names = append(t.Names, b.ReadString()) + } +} + +// encode implements encoder.encode. +func (t *Tmultigetattr) encode(b *buffer) { + b.WriteFID(t.FID) + b.Write16(uint16(len(t.Names))) + for _, name := range t.Names { + b.WriteString(name) + } +} + +// Type implements message.Type. +func (*Tmultigetattr) Type() MsgType { + return MsgTmultigetattr +} + +// String implements fmt.Stringer. +func (t *Tmultigetattr) String() string { + return fmt.Sprintf("Tmultigetattr{FID: %d, Names: %v}", t.FID, t.Names) +} + +// Rmultigetattr is a multi-getattr response. +type Rmultigetattr struct { + // Stats are the set of FullStat returned for each of the names in the + // request. + Stats []FullStat +} + +// decode implements encoder.decode. +func (r *Rmultigetattr) decode(b *buffer) { + n := b.Read16() + r.Stats = r.Stats[:0] + for i := 0; i < int(n); i++ { + var fs FullStat + fs.decode(b) + r.Stats = append(r.Stats, fs) + } +} + +// encode implements encoder.encode. +func (r *Rmultigetattr) encode(b *buffer) { + b.Write16(uint16(len(r.Stats))) + for i := range r.Stats { + r.Stats[i].encode(b) + } +} + +// Type implements message.Type. +func (*Rmultigetattr) Type() MsgType { + return MsgRmultigetattr +} + +// String implements fmt.Stringer. +func (r *Rmultigetattr) String() string { + return fmt.Sprintf("Rmultigetattr{Stats: %v}", r.Stats) +} + const maxCacheSize = 3 // msgFactory is used to reduce allocations by caching messages for reuse. @@ -2717,6 +2791,8 @@ func init() { msgRegistry.register(MsgRallocate, func() message { return &Rallocate{} }) msgRegistry.register(MsgTsetattrclunk, func() message { return &Tsetattrclunk{} }) msgRegistry.register(MsgRsetattrclunk, func() message { return &Rsetattrclunk{} }) + msgRegistry.register(MsgTmultigetattr, func() message { return &Tmultigetattr{} }) + msgRegistry.register(MsgRmultigetattr, func() message { return &Rmultigetattr{} }) msgRegistry.register(MsgTchannel, func() message { return &Tchannel{} }) msgRegistry.register(MsgRchannel, func() message { return &Rchannel{} }) } diff --git a/pkg/p9/p9.go b/pkg/p9/p9.go index 648cf4b49..3d452a0bd 100644 --- a/pkg/p9/p9.go +++ b/pkg/p9/p9.go @@ -402,6 +402,8 @@ const ( MsgRallocate MsgType = 139 MsgTsetattrclunk MsgType = 140 MsgRsetattrclunk MsgType = 141 + MsgTmultigetattr MsgType = 142 + MsgRmultigetattr MsgType = 143 MsgTchannel MsgType = 250 MsgRchannel MsgType = 251 ) @@ -1178,3 +1180,29 @@ func (a *AllocateMode) encode(b *buffer) { } b.Write32(mask) } + +// FullStat is used in the result of a MultiGetAttr call. +type FullStat struct { + QID QID + Valid AttrMask + Attr Attr +} + +// String implements fmt.Stringer. +func (f *FullStat) String() string { + return fmt.Sprintf("FullStat{QID: %v, Valid: %v, Attr: %v}", f.QID, f.Valid, f.Attr) +} + +// decode implements encoder.decode. +func (f *FullStat) decode(b *buffer) { + f.QID.decode(b) + f.Valid.decode(b) + f.Attr.decode(b) +} + +// encode implements encoder.encode. +func (f *FullStat) encode(b *buffer) { + f.QID.encode(b) + f.Valid.encode(b) + f.Attr.encode(b) +} diff --git a/pkg/p9/version.go b/pkg/p9/version.go index 8d7168ef5..950236162 100644 --- a/pkg/p9/version.go +++ b/pkg/p9/version.go @@ -26,7 +26,7 @@ const ( // // Clients are expected to start requesting this version number and // to continuously decrement it until a Tversion request succeeds. - highestSupportedVersion uint32 = 12 + highestSupportedVersion uint32 = 13 // lowestSupportedVersion is the lowest supported version X in a // version string of the format 9P2000.L.Google.X. @@ -179,3 +179,9 @@ func versionSupportsListRemoveXattr(v uint32) bool { func versionSupportsTsetattrclunk(v uint32) bool { return v >= 12 } + +// versionSupportsTmultiGetAttr returns true if version v supports +// the TmultiGetAttr message. +func versionSupportsTmultiGetAttr(v uint32) bool { + return v >= 13 +} -- cgit v1.2.3