diff options
author | Fabricio Voznika <fvoznika@google.com> | 2019-04-29 15:32:45 -0700 |
---|---|---|
committer | Shentubot <shentubot@google.com> | 2019-04-29 15:33:47 -0700 |
commit | ddab854b9a895603664fa4abfa525f6a29047083 (patch) | |
tree | 349b428a042c7e7e0d7de71e56a319f7df3ae29b /pkg/p9 | |
parent | 4d52a5520101a88424fb63dd99412a1db33fbd06 (diff) |
Reduce memory allocations on serving path
Cache last used messages and reuse them for subsequent requests.
If more messages are needed, they are created outside the cache
on demand.
PiperOrigin-RevId: 245836910
Change-Id: Icf099ddff95df420db8e09f5cdd41dcdce406c61
Diffstat (limited to 'pkg/p9')
-rw-r--r-- | pkg/p9/buffer.go | 3 | ||||
-rw-r--r-- | pkg/p9/client.go | 6 | ||||
-rw-r--r-- | pkg/p9/messages.go | 232 | ||||
-rw-r--r-- | pkg/p9/messages_test.go | 46 | ||||
-rw-r--r-- | pkg/p9/server.go | 4 | ||||
-rw-r--r-- | pkg/p9/transport_test.go | 12 |
6 files changed, 187 insertions, 116 deletions
diff --git a/pkg/p9/buffer.go b/pkg/p9/buffer.go index 4c8c6555d..249536d8a 100644 --- a/pkg/p9/buffer.go +++ b/pkg/p9/buffer.go @@ -20,7 +20,8 @@ import ( // encoder is used for messages and 9P primitives. type encoder interface { - // Decode decodes from the given buffer. + // Decode decodes from the given buffer. Decode may be called more than once + // to reuse the instance. It must clear any previous state. // // This may not fail, exhaustion will be recorded in the buffer. Decode(b *buffer) diff --git a/pkg/p9/client.go b/pkg/p9/client.go index 2f9c716d0..56587e2cf 100644 --- a/pkg/p9/client.go +++ b/pkg/p9/client.go @@ -110,16 +110,16 @@ type Client struct { // You should not use the same socket for multiple clients. func NewClient(socket *unet.Socket, messageSize uint32, version string) (*Client, error) { // Need at least one byte of payload. - if messageSize <= largestFixedSize { + if messageSize <= msgRegistry.largestFixedSize { return nil, &ErrMessageTooLarge{ size: messageSize, - msize: largestFixedSize, + msize: msgRegistry.largestFixedSize, } } // Compute a payload size and round to 512 (normal block size) // if it's larger than a single block. - payloadSize := messageSize - largestFixedSize + payloadSize := messageSize - msgRegistry.largestFixedSize if payloadSize > 512 && payloadSize%512 != 0 { payloadSize -= (payloadSize % 512) } diff --git a/pkg/p9/messages.go b/pkg/p9/messages.go index 833defbd6..3c7898cc1 100644 --- a/pkg/p9/messages.go +++ b/pkg/p9/messages.go @@ -193,6 +193,7 @@ func (t *Twalk) Decode(b *buffer) { t.FID = b.ReadFID() t.NewFID = b.ReadFID() n := b.Read16() + t.Names = t.Names[:0] for i := 0; i < int(n); i++ { t.Names = append(t.Names, b.ReadString()) } @@ -227,6 +228,7 @@ type Rwalk struct { // Decode implements encoder.Decode. func (r *Rwalk) Decode(b *buffer) { n := b.Read16() + r.QIDs = r.QIDs[:0] for i := 0; i < int(n); i++ { var q QID q.Decode(b) @@ -1608,6 +1610,7 @@ type Rreaddir struct { func (r *Rreaddir) Decode(b *buffer) { r.Count = b.Read32() entriesBuf := buffer{data: r.payload} + r.Entries = r.Entries[:0] for { var d Dirent d.Decode(&entriesBuf) @@ -1827,6 +1830,7 @@ func (t *Twalkgetattr) Decode(b *buffer) { t.FID = b.ReadFID() t.NewFID = b.ReadFID() n := b.Read16() + t.Names = t.Names[:0] for i := 0; i < int(n); i++ { t.Names = append(t.Names, b.ReadString()) } @@ -1869,6 +1873,7 @@ func (r *Rwalkgetattr) Decode(b *buffer) { r.Valid.Decode(b) r.Attr.Decode(b) n := b.Read16() + r.QIDs = r.QIDs[:0] for i := 0; i < int(n); i++ { var q QID q.Decode(b) @@ -2139,34 +2144,80 @@ func (r *Rlconnect) String() string { return fmt.Sprintf("Rlconnect{File: %v}", r.File) } -// messageRegistry indexes all messages by type. -var messageRegistry = make([]func() message, math.MaxUint8) +const maxCacheSize = 3 -// messageByType creates a new message by type. +// msgFactory is used to reduce allocations by caching messages for reuse. +type msgFactory struct { + create func() message + cache chan message +} + +// msgRegistry indexes all message factories by type. +var msgRegistry registry + +type registry struct { + factories [math.MaxUint8]msgFactory + + // largestFixedSize is computed so that given some message size M, you can + // compute the maximum payload size (e.g. for Twrite, Rread) with + // M-largestFixedSize. You could do this individual on a per-message basis, + // but it's easier to compute a single maximum safe payload. + largestFixedSize uint32 +} + +// get returns a new message by type. // // An error is returned in the case of an unknown message. // // This takes, and ignores, a message tag so that it may be used directly as a // lookupTagAndType function for recv (by design). -func messageByType(_ Tag, t MsgType) (message, error) { - fn := messageRegistry[t] - if fn == nil { +func (r *registry) get(_ Tag, t MsgType) (message, error) { + entry := &r.factories[t] + if entry.create == nil { return nil, &ErrInvalidMsgType{t} } - return fn(), nil + + select { + case msg := <-entry.cache: + return msg, nil + default: + return entry.create(), nil + } +} + +func (r *registry) put(msg message) { + if p, ok := msg.(payloader); ok { + p.SetPayload(nil) + } + if f, ok := msg.(filer); ok { + f.SetFilePayload(nil) + } + + entry := &r.factories[msg.Type()] + select { + case entry.cache <- msg: + default: + } } // register registers the given message type. // // This may cause panic on failure and should only be used from init. -func register(t MsgType, fn func() message) { - if int(t) >= len(messageRegistry) { - panic(fmt.Sprintf("message type %d is too large. It must be smaller than %d", t, len(messageRegistry))) +func (r *registry) register(t MsgType, fn func() message) { + if int(t) >= len(r.factories) { + panic(fmt.Sprintf("message type %d is too large. It must be smaller than %d", t, len(r.factories))) } - if messageRegistry[t] != nil { - panic(fmt.Sprintf("duplicate message type %d: first is %T, second is %T", t, messageRegistry[t](), fn())) + if r.factories[t].create != nil { + panic(fmt.Sprintf("duplicate message type %d: first is %T, second is %T", t, r.factories[t].create(), fn())) + } + r.factories[t] = msgFactory{ + create: fn, + cache: make(chan message, maxCacheSize), + } + + if size := calculateSize(fn()); size > r.largestFixedSize { + r.largestFixedSize = size } - messageRegistry[t] = fn } func calculateSize(m message) uint32 { @@ -2178,93 +2229,72 @@ func calculateSize(m message) uint32 { return uint32(len(dataBuf.data)) } -// largestFixedSize is computed within calculateLargestSize. -// -// This is computed so that given some message size M, you can compute -// the maximum payload size (e.g. for Twrite, Rread) with M-largestFixedSize. -// You could do this individual on a per-message basis, but it's easier to -// compute a single maximum safe payload. -var largestFixedSize uint32 - -// calculateLargestFixedSize is called from within init. -func calculateLargestFixedSize() { - for _, fn := range messageRegistry { - if fn != nil { - if size := calculateSize(fn()); size > largestFixedSize { - largestFixedSize = size - } - } - } -} - func init() { - register(MsgRlerror, func() message { return &Rlerror{} }) - register(MsgTstatfs, func() message { return &Tstatfs{} }) - register(MsgRstatfs, func() message { return &Rstatfs{} }) - register(MsgTlopen, func() message { return &Tlopen{} }) - register(MsgRlopen, func() message { return &Rlopen{} }) - register(MsgTlcreate, func() message { return &Tlcreate{} }) - register(MsgRlcreate, func() message { return &Rlcreate{} }) - register(MsgTsymlink, func() message { return &Tsymlink{} }) - register(MsgRsymlink, func() message { return &Rsymlink{} }) - register(MsgTmknod, func() message { return &Tmknod{} }) - register(MsgRmknod, func() message { return &Rmknod{} }) - register(MsgTrename, func() message { return &Trename{} }) - register(MsgRrename, func() message { return &Rrename{} }) - register(MsgTreadlink, func() message { return &Treadlink{} }) - register(MsgRreadlink, func() message { return &Rreadlink{} }) - register(MsgTgetattr, func() message { return &Tgetattr{} }) - register(MsgRgetattr, func() message { return &Rgetattr{} }) - register(MsgTsetattr, func() message { return &Tsetattr{} }) - register(MsgRsetattr, func() message { return &Rsetattr{} }) - register(MsgTxattrwalk, func() message { return &Txattrwalk{} }) - register(MsgRxattrwalk, func() message { return &Rxattrwalk{} }) - register(MsgTxattrcreate, func() message { return &Txattrcreate{} }) - register(MsgRxattrcreate, func() message { return &Rxattrcreate{} }) - register(MsgTreaddir, func() message { return &Treaddir{} }) - register(MsgRreaddir, func() message { return &Rreaddir{} }) - register(MsgTfsync, func() message { return &Tfsync{} }) - register(MsgRfsync, func() message { return &Rfsync{} }) - register(MsgTlink, func() message { return &Tlink{} }) - register(MsgRlink, func() message { return &Rlink{} }) - register(MsgTmkdir, func() message { return &Tmkdir{} }) - register(MsgRmkdir, func() message { return &Rmkdir{} }) - register(MsgTrenameat, func() message { return &Trenameat{} }) - register(MsgRrenameat, func() message { return &Rrenameat{} }) - register(MsgTunlinkat, func() message { return &Tunlinkat{} }) - register(MsgRunlinkat, func() message { return &Runlinkat{} }) - register(MsgTversion, func() message { return &Tversion{} }) - register(MsgRversion, func() message { return &Rversion{} }) - register(MsgTauth, func() message { return &Tauth{} }) - register(MsgRauth, func() message { return &Rauth{} }) - register(MsgTattach, func() message { return &Tattach{} }) - register(MsgRattach, func() message { return &Rattach{} }) - register(MsgTflush, func() message { return &Tflush{} }) - register(MsgRflush, func() message { return &Rflush{} }) - register(MsgTwalk, func() message { return &Twalk{} }) - register(MsgRwalk, func() message { return &Rwalk{} }) - register(MsgTread, func() message { return &Tread{} }) - register(MsgRread, func() message { return &Rread{} }) - register(MsgTwrite, func() message { return &Twrite{} }) - register(MsgRwrite, func() message { return &Rwrite{} }) - register(MsgTclunk, func() message { return &Tclunk{} }) - register(MsgRclunk, func() message { return &Rclunk{} }) - register(MsgTremove, func() message { return &Tremove{} }) - register(MsgRremove, func() message { return &Rremove{} }) - register(MsgTflushf, func() message { return &Tflushf{} }) - register(MsgRflushf, func() message { return &Rflushf{} }) - register(MsgTwalkgetattr, func() message { return &Twalkgetattr{} }) - register(MsgRwalkgetattr, func() message { return &Rwalkgetattr{} }) - register(MsgTucreate, func() message { return &Tucreate{} }) - register(MsgRucreate, func() message { return &Rucreate{} }) - register(MsgTumkdir, func() message { return &Tumkdir{} }) - register(MsgRumkdir, func() message { return &Rumkdir{} }) - register(MsgTumknod, func() message { return &Tumknod{} }) - register(MsgRumknod, func() message { return &Rumknod{} }) - register(MsgTusymlink, func() message { return &Tusymlink{} }) - register(MsgRusymlink, func() message { return &Rusymlink{} }) - register(MsgTlconnect, func() message { return &Tlconnect{} }) - register(MsgRlconnect, func() message { return &Rlconnect{} }) - - calculateLargestFixedSize() + msgRegistry.register(MsgRlerror, func() message { return &Rlerror{} }) + msgRegistry.register(MsgTstatfs, func() message { return &Tstatfs{} }) + msgRegistry.register(MsgRstatfs, func() message { return &Rstatfs{} }) + msgRegistry.register(MsgTlopen, func() message { return &Tlopen{} }) + msgRegistry.register(MsgRlopen, func() message { return &Rlopen{} }) + msgRegistry.register(MsgTlcreate, func() message { return &Tlcreate{} }) + msgRegistry.register(MsgRlcreate, func() message { return &Rlcreate{} }) + msgRegistry.register(MsgTsymlink, func() message { return &Tsymlink{} }) + msgRegistry.register(MsgRsymlink, func() message { return &Rsymlink{} }) + msgRegistry.register(MsgTmknod, func() message { return &Tmknod{} }) + msgRegistry.register(MsgRmknod, func() message { return &Rmknod{} }) + msgRegistry.register(MsgTrename, func() message { return &Trename{} }) + msgRegistry.register(MsgRrename, func() message { return &Rrename{} }) + msgRegistry.register(MsgTreadlink, func() message { return &Treadlink{} }) + msgRegistry.register(MsgRreadlink, func() message { return &Rreadlink{} }) + msgRegistry.register(MsgTgetattr, func() message { return &Tgetattr{} }) + msgRegistry.register(MsgRgetattr, func() message { return &Rgetattr{} }) + msgRegistry.register(MsgTsetattr, func() message { return &Tsetattr{} }) + msgRegistry.register(MsgRsetattr, func() message { return &Rsetattr{} }) + msgRegistry.register(MsgTxattrwalk, func() message { return &Txattrwalk{} }) + msgRegistry.register(MsgRxattrwalk, func() message { return &Rxattrwalk{} }) + msgRegistry.register(MsgTxattrcreate, func() message { return &Txattrcreate{} }) + msgRegistry.register(MsgRxattrcreate, func() message { return &Rxattrcreate{} }) + msgRegistry.register(MsgTreaddir, func() message { return &Treaddir{} }) + msgRegistry.register(MsgRreaddir, func() message { return &Rreaddir{} }) + msgRegistry.register(MsgTfsync, func() message { return &Tfsync{} }) + msgRegistry.register(MsgRfsync, func() message { return &Rfsync{} }) + msgRegistry.register(MsgTlink, func() message { return &Tlink{} }) + msgRegistry.register(MsgRlink, func() message { return &Rlink{} }) + msgRegistry.register(MsgTmkdir, func() message { return &Tmkdir{} }) + msgRegistry.register(MsgRmkdir, func() message { return &Rmkdir{} }) + msgRegistry.register(MsgTrenameat, func() message { return &Trenameat{} }) + msgRegistry.register(MsgRrenameat, func() message { return &Rrenameat{} }) + msgRegistry.register(MsgTunlinkat, func() message { return &Tunlinkat{} }) + msgRegistry.register(MsgRunlinkat, func() message { return &Runlinkat{} }) + msgRegistry.register(MsgTversion, func() message { return &Tversion{} }) + msgRegistry.register(MsgRversion, func() message { return &Rversion{} }) + msgRegistry.register(MsgTauth, func() message { return &Tauth{} }) + msgRegistry.register(MsgRauth, func() message { return &Rauth{} }) + msgRegistry.register(MsgTattach, func() message { return &Tattach{} }) + msgRegistry.register(MsgRattach, func() message { return &Rattach{} }) + msgRegistry.register(MsgTflush, func() message { return &Tflush{} }) + msgRegistry.register(MsgRflush, func() message { return &Rflush{} }) + msgRegistry.register(MsgTwalk, func() message { return &Twalk{} }) + msgRegistry.register(MsgRwalk, func() message { return &Rwalk{} }) + msgRegistry.register(MsgTread, func() message { return &Tread{} }) + msgRegistry.register(MsgRread, func() message { return &Rread{} }) + msgRegistry.register(MsgTwrite, func() message { return &Twrite{} }) + msgRegistry.register(MsgRwrite, func() message { return &Rwrite{} }) + msgRegistry.register(MsgTclunk, func() message { return &Tclunk{} }) + msgRegistry.register(MsgRclunk, func() message { return &Rclunk{} }) + msgRegistry.register(MsgTremove, func() message { return &Tremove{} }) + msgRegistry.register(MsgRremove, func() message { return &Rremove{} }) + msgRegistry.register(MsgTflushf, func() message { return &Tflushf{} }) + msgRegistry.register(MsgRflushf, func() message { return &Rflushf{} }) + msgRegistry.register(MsgTwalkgetattr, func() message { return &Twalkgetattr{} }) + msgRegistry.register(MsgRwalkgetattr, func() message { return &Rwalkgetattr{} }) + msgRegistry.register(MsgTucreate, func() message { return &Tucreate{} }) + msgRegistry.register(MsgRucreate, func() message { return &Rucreate{} }) + msgRegistry.register(MsgTumkdir, func() message { return &Tumkdir{} }) + msgRegistry.register(MsgRumkdir, func() message { return &Rumkdir{} }) + msgRegistry.register(MsgTumknod, func() message { return &Tumknod{} }) + msgRegistry.register(MsgRumknod, func() message { return &Rumknod{} }) + msgRegistry.register(MsgTusymlink, func() message { return &Tusymlink{} }) + msgRegistry.register(MsgRusymlink, func() message { return &Rusymlink{} }) + msgRegistry.register(MsgTlconnect, func() message { return &Tlconnect{} }) + msgRegistry.register(MsgRlconnect, func() message { return &Rlconnect{} }) } diff --git a/pkg/p9/messages_test.go b/pkg/p9/messages_test.go index 10a0587cf..513b30e8b 100644 --- a/pkg/p9/messages_test.go +++ b/pkg/p9/messages_test.go @@ -399,8 +399,9 @@ func TestEncodeDecode(t *testing.T) { } func TestMessageStrings(t *testing.T) { - for typ, fn := range messageRegistry { - if fn != nil { + for typ := range msgRegistry.factories { + entry := &msgRegistry.factories[typ] + if entry.create != nil { name := fmt.Sprintf("%+v", typ) t.Run(name, func(t *testing.T) { defer func() { // Ensure no panic. @@ -408,7 +409,7 @@ func TestMessageStrings(t *testing.T) { t.Errorf("printing %s failed: %v", name, r) } }() - m := fn() + m := entry.create() _ = fmt.Sprintf("%v", m) err := ErrInvalidMsgType{MsgType(typ)} _ = err.Error() @@ -426,5 +427,42 @@ func TestRegisterDuplicate(t *testing.T) { }() // Register a duplicate. - register(MsgRlerror, func() message { return &Rlerror{} }) + msgRegistry.register(MsgRlerror, func() message { return &Rlerror{} }) +} + +func TestMsgCache(t *testing.T) { + // Cache starts empty. + if got, want := len(msgRegistry.factories[MsgRlerror].cache), 0; got != want { + t.Errorf("Wrong cache size, got: %d, want: %d", got, want) + } + + // Message can be created with an empty cache. + msg, err := msgRegistry.get(0, MsgRlerror) + if err != nil { + t.Errorf("msgRegistry.get(): %v", err) + } + if got, want := len(msgRegistry.factories[MsgRlerror].cache), 0; got != want { + t.Errorf("Wrong cache size, got: %d, want: %d", got, want) + } + + // Check that message is added to the cache when returned. + msgRegistry.put(msg) + if got, want := len(msgRegistry.factories[MsgRlerror].cache), 1; got != want { + t.Errorf("Wrong cache size, got: %d, want: %d", got, want) + } + + // Check that returned message is reused. + if got, err := msgRegistry.get(0, MsgRlerror); err != nil { + t.Errorf("msgRegistry.get(): %v", err) + } else if msg != got { + t.Errorf("Message not reused, got: %d, want: %d", got, msg) + } + + // Check that cache doesn't grow beyond max size. + for i := 0; i < maxCacheSize+1; i++ { + msgRegistry.put(&Rlerror{}) + } + if got, want := len(msgRegistry.factories[MsgRlerror].cache), maxCacheSize; got != want { + t.Errorf("Wrong cache size, got: %d, want: %d", got, want) + } } diff --git a/pkg/p9/server.go b/pkg/p9/server.go index b2a86d8fa..f377a6557 100644 --- a/pkg/p9/server.go +++ b/pkg/p9/server.go @@ -395,7 +395,7 @@ func (cs *connState) handleRequest() { } // Receive a message. - tag, m, err := recv(cs.conn, messageSize, messageByType) + tag, m, err := recv(cs.conn, messageSize, msgRegistry.get) if errSocket, ok := err.(ErrSocket); ok { // Connection problem; stop serving. cs.recvDone <- errSocket.error @@ -458,6 +458,8 @@ func (cs *connState) handleRequest() { // Produce an ENOSYS error. r = newErr(syscall.ENOSYS) } + msgRegistry.put(m) + m = nil // 'm' should not be touched after this point. } func (cs *connState) handleRequests() { diff --git a/pkg/p9/transport_test.go b/pkg/p9/transport_test.go index c833d1c9c..0f88ff249 100644 --- a/pkg/p9/transport_test.go +++ b/pkg/p9/transport_test.go @@ -41,7 +41,7 @@ func TestSendRecv(t *testing.T) { t.Fatalf("send got err %v expected nil", err) } - tag, m, err := recv(server, maximumLength, messageByType) + tag, m, err := recv(server, maximumLength, msgRegistry.get) if err != nil { t.Fatalf("recv got err %v expected nil", err) } @@ -73,7 +73,7 @@ func TestRecvOverrun(t *testing.T) { t.Fatalf("send got err %v expected nil", err) } - if _, _, err := recv(server, maximumLength, messageByType); err == nil { + if _, _, err := recv(server, maximumLength, msgRegistry.get); err == nil { t.Fatalf("recv got err %v expected ErrSocket{ErrNoValidMessage}", err) } } @@ -98,7 +98,7 @@ func TestRecvInvalidType(t *testing.T) { t.Fatalf("send got err %v expected nil", err) } - _, _, err = recv(server, maximumLength, messageByType) + _, _, err = recv(server, maximumLength, msgRegistry.get) if _, ok := err.(*ErrInvalidMsgType); !ok { t.Fatalf("recv got err %v expected ErrInvalidMsgType", err) } @@ -129,7 +129,7 @@ func TestSendRecvWithFile(t *testing.T) { } // Enable withFile. - tag, m, err := recv(server, maximumLength, messageByType) + tag, m, err := recv(server, maximumLength, msgRegistry.get) if err != nil { t.Fatalf("recv got err %v expected nil", err) } @@ -153,7 +153,7 @@ func TestRecvClosed(t *testing.T) { defer server.Close() client.Close() - _, _, err = recv(server, maximumLength, messageByType) + _, _, err = recv(server, maximumLength, msgRegistry.get) if err == nil { t.Fatalf("got err nil expected non-nil") } @@ -180,5 +180,5 @@ func TestSendClosed(t *testing.T) { } func init() { - register(MsgTypeBadDecode, func() message { return &badDecode{} }) + msgRegistry.register(MsgTypeBadDecode, func() message { return &badDecode{} }) } |