summaryrefslogtreecommitdiffhomepage
path: root/pkg/p9
diff options
context:
space:
mode:
authorFabricio Voznika <fvoznika@google.com>2019-04-29 15:32:45 -0700
committerShentubot <shentubot@google.com>2019-04-29 15:33:47 -0700
commitddab854b9a895603664fa4abfa525f6a29047083 (patch)
tree349b428a042c7e7e0d7de71e56a319f7df3ae29b /pkg/p9
parent4d52a5520101a88424fb63dd99412a1db33fbd06 (diff)
Reduce memory allocations on serving path
Cache last used messages and reuse them for subsequent requests. If more messages are needed, they are created outside the cache on demand. PiperOrigin-RevId: 245836910 Change-Id: Icf099ddff95df420db8e09f5cdd41dcdce406c61
Diffstat (limited to 'pkg/p9')
-rw-r--r--pkg/p9/buffer.go3
-rw-r--r--pkg/p9/client.go6
-rw-r--r--pkg/p9/messages.go232
-rw-r--r--pkg/p9/messages_test.go46
-rw-r--r--pkg/p9/server.go4
-rw-r--r--pkg/p9/transport_test.go12
6 files changed, 187 insertions, 116 deletions
diff --git a/pkg/p9/buffer.go b/pkg/p9/buffer.go
index 4c8c6555d..249536d8a 100644
--- a/pkg/p9/buffer.go
+++ b/pkg/p9/buffer.go
@@ -20,7 +20,8 @@ import (
// encoder is used for messages and 9P primitives.
type encoder interface {
- // Decode decodes from the given buffer.
+ // Decode decodes from the given buffer. Decode may be called more than once
+ // to reuse the instance. It must clear any previous state.
//
// This may not fail, exhaustion will be recorded in the buffer.
Decode(b *buffer)
diff --git a/pkg/p9/client.go b/pkg/p9/client.go
index 2f9c716d0..56587e2cf 100644
--- a/pkg/p9/client.go
+++ b/pkg/p9/client.go
@@ -110,16 +110,16 @@ type Client struct {
// You should not use the same socket for multiple clients.
func NewClient(socket *unet.Socket, messageSize uint32, version string) (*Client, error) {
// Need at least one byte of payload.
- if messageSize <= largestFixedSize {
+ if messageSize <= msgRegistry.largestFixedSize {
return nil, &ErrMessageTooLarge{
size: messageSize,
- msize: largestFixedSize,
+ msize: msgRegistry.largestFixedSize,
}
}
// Compute a payload size and round to 512 (normal block size)
// if it's larger than a single block.
- payloadSize := messageSize - largestFixedSize
+ payloadSize := messageSize - msgRegistry.largestFixedSize
if payloadSize > 512 && payloadSize%512 != 0 {
payloadSize -= (payloadSize % 512)
}
diff --git a/pkg/p9/messages.go b/pkg/p9/messages.go
index 833defbd6..3c7898cc1 100644
--- a/pkg/p9/messages.go
+++ b/pkg/p9/messages.go
@@ -193,6 +193,7 @@ func (t *Twalk) Decode(b *buffer) {
t.FID = b.ReadFID()
t.NewFID = b.ReadFID()
n := b.Read16()
+ t.Names = t.Names[:0]
for i := 0; i < int(n); i++ {
t.Names = append(t.Names, b.ReadString())
}
@@ -227,6 +228,7 @@ type Rwalk struct {
// Decode implements encoder.Decode.
func (r *Rwalk) Decode(b *buffer) {
n := b.Read16()
+ r.QIDs = r.QIDs[:0]
for i := 0; i < int(n); i++ {
var q QID
q.Decode(b)
@@ -1608,6 +1610,7 @@ type Rreaddir struct {
func (r *Rreaddir) Decode(b *buffer) {
r.Count = b.Read32()
entriesBuf := buffer{data: r.payload}
+ r.Entries = r.Entries[:0]
for {
var d Dirent
d.Decode(&entriesBuf)
@@ -1827,6 +1830,7 @@ func (t *Twalkgetattr) Decode(b *buffer) {
t.FID = b.ReadFID()
t.NewFID = b.ReadFID()
n := b.Read16()
+ t.Names = t.Names[:0]
for i := 0; i < int(n); i++ {
t.Names = append(t.Names, b.ReadString())
}
@@ -1869,6 +1873,7 @@ func (r *Rwalkgetattr) Decode(b *buffer) {
r.Valid.Decode(b)
r.Attr.Decode(b)
n := b.Read16()
+ r.QIDs = r.QIDs[:0]
for i := 0; i < int(n); i++ {
var q QID
q.Decode(b)
@@ -2139,34 +2144,80 @@ func (r *Rlconnect) String() string {
return fmt.Sprintf("Rlconnect{File: %v}", r.File)
}
-// messageRegistry indexes all messages by type.
-var messageRegistry = make([]func() message, math.MaxUint8)
+const maxCacheSize = 3
-// messageByType creates a new message by type.
+// msgFactory is used to reduce allocations by caching messages for reuse.
+type msgFactory struct {
+ create func() message
+ cache chan message
+}
+
+// msgRegistry indexes all message factories by type.
+var msgRegistry registry
+
+type registry struct {
+ factories [math.MaxUint8]msgFactory
+
+ // largestFixedSize is computed so that given some message size M, you can
+ // compute the maximum payload size (e.g. for Twrite, Rread) with
+ // M-largestFixedSize. You could do this individual on a per-message basis,
+ // but it's easier to compute a single maximum safe payload.
+ largestFixedSize uint32
+}
+
+// get returns a new message by type.
//
// An error is returned in the case of an unknown message.
//
// This takes, and ignores, a message tag so that it may be used directly as a
// lookupTagAndType function for recv (by design).
-func messageByType(_ Tag, t MsgType) (message, error) {
- fn := messageRegistry[t]
- if fn == nil {
+func (r *registry) get(_ Tag, t MsgType) (message, error) {
+ entry := &r.factories[t]
+ if entry.create == nil {
return nil, &ErrInvalidMsgType{t}
}
- return fn(), nil
+
+ select {
+ case msg := <-entry.cache:
+ return msg, nil
+ default:
+ return entry.create(), nil
+ }
+}
+
+func (r *registry) put(msg message) {
+ if p, ok := msg.(payloader); ok {
+ p.SetPayload(nil)
+ }
+ if f, ok := msg.(filer); ok {
+ f.SetFilePayload(nil)
+ }
+
+ entry := &r.factories[msg.Type()]
+ select {
+ case entry.cache <- msg:
+ default:
+ }
}
// register registers the given message type.
//
// This may cause panic on failure and should only be used from init.
-func register(t MsgType, fn func() message) {
- if int(t) >= len(messageRegistry) {
- panic(fmt.Sprintf("message type %d is too large. It must be smaller than %d", t, len(messageRegistry)))
+func (r *registry) register(t MsgType, fn func() message) {
+ if int(t) >= len(r.factories) {
+ panic(fmt.Sprintf("message type %d is too large. It must be smaller than %d", t, len(r.factories)))
}
- if messageRegistry[t] != nil {
- panic(fmt.Sprintf("duplicate message type %d: first is %T, second is %T", t, messageRegistry[t](), fn()))
+ if r.factories[t].create != nil {
+ panic(fmt.Sprintf("duplicate message type %d: first is %T, second is %T", t, r.factories[t].create(), fn()))
+ }
+ r.factories[t] = msgFactory{
+ create: fn,
+ cache: make(chan message, maxCacheSize),
+ }
+
+ if size := calculateSize(fn()); size > r.largestFixedSize {
+ r.largestFixedSize = size
}
- messageRegistry[t] = fn
}
func calculateSize(m message) uint32 {
@@ -2178,93 +2229,72 @@ func calculateSize(m message) uint32 {
return uint32(len(dataBuf.data))
}
-// largestFixedSize is computed within calculateLargestSize.
-//
-// This is computed so that given some message size M, you can compute
-// the maximum payload size (e.g. for Twrite, Rread) with M-largestFixedSize.
-// You could do this individual on a per-message basis, but it's easier to
-// compute a single maximum safe payload.
-var largestFixedSize uint32
-
-// calculateLargestFixedSize is called from within init.
-func calculateLargestFixedSize() {
- for _, fn := range messageRegistry {
- if fn != nil {
- if size := calculateSize(fn()); size > largestFixedSize {
- largestFixedSize = size
- }
- }
- }
-}
-
func init() {
- register(MsgRlerror, func() message { return &Rlerror{} })
- register(MsgTstatfs, func() message { return &Tstatfs{} })
- register(MsgRstatfs, func() message { return &Rstatfs{} })
- register(MsgTlopen, func() message { return &Tlopen{} })
- register(MsgRlopen, func() message { return &Rlopen{} })
- register(MsgTlcreate, func() message { return &Tlcreate{} })
- register(MsgRlcreate, func() message { return &Rlcreate{} })
- register(MsgTsymlink, func() message { return &Tsymlink{} })
- register(MsgRsymlink, func() message { return &Rsymlink{} })
- register(MsgTmknod, func() message { return &Tmknod{} })
- register(MsgRmknod, func() message { return &Rmknod{} })
- register(MsgTrename, func() message { return &Trename{} })
- register(MsgRrename, func() message { return &Rrename{} })
- register(MsgTreadlink, func() message { return &Treadlink{} })
- register(MsgRreadlink, func() message { return &Rreadlink{} })
- register(MsgTgetattr, func() message { return &Tgetattr{} })
- register(MsgRgetattr, func() message { return &Rgetattr{} })
- register(MsgTsetattr, func() message { return &Tsetattr{} })
- register(MsgRsetattr, func() message { return &Rsetattr{} })
- register(MsgTxattrwalk, func() message { return &Txattrwalk{} })
- register(MsgRxattrwalk, func() message { return &Rxattrwalk{} })
- register(MsgTxattrcreate, func() message { return &Txattrcreate{} })
- register(MsgRxattrcreate, func() message { return &Rxattrcreate{} })
- register(MsgTreaddir, func() message { return &Treaddir{} })
- register(MsgRreaddir, func() message { return &Rreaddir{} })
- register(MsgTfsync, func() message { return &Tfsync{} })
- register(MsgRfsync, func() message { return &Rfsync{} })
- register(MsgTlink, func() message { return &Tlink{} })
- register(MsgRlink, func() message { return &Rlink{} })
- register(MsgTmkdir, func() message { return &Tmkdir{} })
- register(MsgRmkdir, func() message { return &Rmkdir{} })
- register(MsgTrenameat, func() message { return &Trenameat{} })
- register(MsgRrenameat, func() message { return &Rrenameat{} })
- register(MsgTunlinkat, func() message { return &Tunlinkat{} })
- register(MsgRunlinkat, func() message { return &Runlinkat{} })
- register(MsgTversion, func() message { return &Tversion{} })
- register(MsgRversion, func() message { return &Rversion{} })
- register(MsgTauth, func() message { return &Tauth{} })
- register(MsgRauth, func() message { return &Rauth{} })
- register(MsgTattach, func() message { return &Tattach{} })
- register(MsgRattach, func() message { return &Rattach{} })
- register(MsgTflush, func() message { return &Tflush{} })
- register(MsgRflush, func() message { return &Rflush{} })
- register(MsgTwalk, func() message { return &Twalk{} })
- register(MsgRwalk, func() message { return &Rwalk{} })
- register(MsgTread, func() message { return &Tread{} })
- register(MsgRread, func() message { return &Rread{} })
- register(MsgTwrite, func() message { return &Twrite{} })
- register(MsgRwrite, func() message { return &Rwrite{} })
- register(MsgTclunk, func() message { return &Tclunk{} })
- register(MsgRclunk, func() message { return &Rclunk{} })
- register(MsgTremove, func() message { return &Tremove{} })
- register(MsgRremove, func() message { return &Rremove{} })
- register(MsgTflushf, func() message { return &Tflushf{} })
- register(MsgRflushf, func() message { return &Rflushf{} })
- register(MsgTwalkgetattr, func() message { return &Twalkgetattr{} })
- register(MsgRwalkgetattr, func() message { return &Rwalkgetattr{} })
- register(MsgTucreate, func() message { return &Tucreate{} })
- register(MsgRucreate, func() message { return &Rucreate{} })
- register(MsgTumkdir, func() message { return &Tumkdir{} })
- register(MsgRumkdir, func() message { return &Rumkdir{} })
- register(MsgTumknod, func() message { return &Tumknod{} })
- register(MsgRumknod, func() message { return &Rumknod{} })
- register(MsgTusymlink, func() message { return &Tusymlink{} })
- register(MsgRusymlink, func() message { return &Rusymlink{} })
- register(MsgTlconnect, func() message { return &Tlconnect{} })
- register(MsgRlconnect, func() message { return &Rlconnect{} })
-
- calculateLargestFixedSize()
+ msgRegistry.register(MsgRlerror, func() message { return &Rlerror{} })
+ msgRegistry.register(MsgTstatfs, func() message { return &Tstatfs{} })
+ msgRegistry.register(MsgRstatfs, func() message { return &Rstatfs{} })
+ msgRegistry.register(MsgTlopen, func() message { return &Tlopen{} })
+ msgRegistry.register(MsgRlopen, func() message { return &Rlopen{} })
+ msgRegistry.register(MsgTlcreate, func() message { return &Tlcreate{} })
+ msgRegistry.register(MsgRlcreate, func() message { return &Rlcreate{} })
+ msgRegistry.register(MsgTsymlink, func() message { return &Tsymlink{} })
+ msgRegistry.register(MsgRsymlink, func() message { return &Rsymlink{} })
+ msgRegistry.register(MsgTmknod, func() message { return &Tmknod{} })
+ msgRegistry.register(MsgRmknod, func() message { return &Rmknod{} })
+ msgRegistry.register(MsgTrename, func() message { return &Trename{} })
+ msgRegistry.register(MsgRrename, func() message { return &Rrename{} })
+ msgRegistry.register(MsgTreadlink, func() message { return &Treadlink{} })
+ msgRegistry.register(MsgRreadlink, func() message { return &Rreadlink{} })
+ msgRegistry.register(MsgTgetattr, func() message { return &Tgetattr{} })
+ msgRegistry.register(MsgRgetattr, func() message { return &Rgetattr{} })
+ msgRegistry.register(MsgTsetattr, func() message { return &Tsetattr{} })
+ msgRegistry.register(MsgRsetattr, func() message { return &Rsetattr{} })
+ msgRegistry.register(MsgTxattrwalk, func() message { return &Txattrwalk{} })
+ msgRegistry.register(MsgRxattrwalk, func() message { return &Rxattrwalk{} })
+ msgRegistry.register(MsgTxattrcreate, func() message { return &Txattrcreate{} })
+ msgRegistry.register(MsgRxattrcreate, func() message { return &Rxattrcreate{} })
+ msgRegistry.register(MsgTreaddir, func() message { return &Treaddir{} })
+ msgRegistry.register(MsgRreaddir, func() message { return &Rreaddir{} })
+ msgRegistry.register(MsgTfsync, func() message { return &Tfsync{} })
+ msgRegistry.register(MsgRfsync, func() message { return &Rfsync{} })
+ msgRegistry.register(MsgTlink, func() message { return &Tlink{} })
+ msgRegistry.register(MsgRlink, func() message { return &Rlink{} })
+ msgRegistry.register(MsgTmkdir, func() message { return &Tmkdir{} })
+ msgRegistry.register(MsgRmkdir, func() message { return &Rmkdir{} })
+ msgRegistry.register(MsgTrenameat, func() message { return &Trenameat{} })
+ msgRegistry.register(MsgRrenameat, func() message { return &Rrenameat{} })
+ msgRegistry.register(MsgTunlinkat, func() message { return &Tunlinkat{} })
+ msgRegistry.register(MsgRunlinkat, func() message { return &Runlinkat{} })
+ msgRegistry.register(MsgTversion, func() message { return &Tversion{} })
+ msgRegistry.register(MsgRversion, func() message { return &Rversion{} })
+ msgRegistry.register(MsgTauth, func() message { return &Tauth{} })
+ msgRegistry.register(MsgRauth, func() message { return &Rauth{} })
+ msgRegistry.register(MsgTattach, func() message { return &Tattach{} })
+ msgRegistry.register(MsgRattach, func() message { return &Rattach{} })
+ msgRegistry.register(MsgTflush, func() message { return &Tflush{} })
+ msgRegistry.register(MsgRflush, func() message { return &Rflush{} })
+ msgRegistry.register(MsgTwalk, func() message { return &Twalk{} })
+ msgRegistry.register(MsgRwalk, func() message { return &Rwalk{} })
+ msgRegistry.register(MsgTread, func() message { return &Tread{} })
+ msgRegistry.register(MsgRread, func() message { return &Rread{} })
+ msgRegistry.register(MsgTwrite, func() message { return &Twrite{} })
+ msgRegistry.register(MsgRwrite, func() message { return &Rwrite{} })
+ msgRegistry.register(MsgTclunk, func() message { return &Tclunk{} })
+ msgRegistry.register(MsgRclunk, func() message { return &Rclunk{} })
+ msgRegistry.register(MsgTremove, func() message { return &Tremove{} })
+ msgRegistry.register(MsgRremove, func() message { return &Rremove{} })
+ msgRegistry.register(MsgTflushf, func() message { return &Tflushf{} })
+ msgRegistry.register(MsgRflushf, func() message { return &Rflushf{} })
+ msgRegistry.register(MsgTwalkgetattr, func() message { return &Twalkgetattr{} })
+ msgRegistry.register(MsgRwalkgetattr, func() message { return &Rwalkgetattr{} })
+ msgRegistry.register(MsgTucreate, func() message { return &Tucreate{} })
+ msgRegistry.register(MsgRucreate, func() message { return &Rucreate{} })
+ msgRegistry.register(MsgTumkdir, func() message { return &Tumkdir{} })
+ msgRegistry.register(MsgRumkdir, func() message { return &Rumkdir{} })
+ msgRegistry.register(MsgTumknod, func() message { return &Tumknod{} })
+ msgRegistry.register(MsgRumknod, func() message { return &Rumknod{} })
+ msgRegistry.register(MsgTusymlink, func() message { return &Tusymlink{} })
+ msgRegistry.register(MsgRusymlink, func() message { return &Rusymlink{} })
+ msgRegistry.register(MsgTlconnect, func() message { return &Tlconnect{} })
+ msgRegistry.register(MsgRlconnect, func() message { return &Rlconnect{} })
}
diff --git a/pkg/p9/messages_test.go b/pkg/p9/messages_test.go
index 10a0587cf..513b30e8b 100644
--- a/pkg/p9/messages_test.go
+++ b/pkg/p9/messages_test.go
@@ -399,8 +399,9 @@ func TestEncodeDecode(t *testing.T) {
}
func TestMessageStrings(t *testing.T) {
- for typ, fn := range messageRegistry {
- if fn != nil {
+ for typ := range msgRegistry.factories {
+ entry := &msgRegistry.factories[typ]
+ if entry.create != nil {
name := fmt.Sprintf("%+v", typ)
t.Run(name, func(t *testing.T) {
defer func() { // Ensure no panic.
@@ -408,7 +409,7 @@ func TestMessageStrings(t *testing.T) {
t.Errorf("printing %s failed: %v", name, r)
}
}()
- m := fn()
+ m := entry.create()
_ = fmt.Sprintf("%v", m)
err := ErrInvalidMsgType{MsgType(typ)}
_ = err.Error()
@@ -426,5 +427,42 @@ func TestRegisterDuplicate(t *testing.T) {
}()
// Register a duplicate.
- register(MsgRlerror, func() message { return &Rlerror{} })
+ msgRegistry.register(MsgRlerror, func() message { return &Rlerror{} })
+}
+
+func TestMsgCache(t *testing.T) {
+ // Cache starts empty.
+ if got, want := len(msgRegistry.factories[MsgRlerror].cache), 0; got != want {
+ t.Errorf("Wrong cache size, got: %d, want: %d", got, want)
+ }
+
+ // Message can be created with an empty cache.
+ msg, err := msgRegistry.get(0, MsgRlerror)
+ if err != nil {
+ t.Errorf("msgRegistry.get(): %v", err)
+ }
+ if got, want := len(msgRegistry.factories[MsgRlerror].cache), 0; got != want {
+ t.Errorf("Wrong cache size, got: %d, want: %d", got, want)
+ }
+
+ // Check that message is added to the cache when returned.
+ msgRegistry.put(msg)
+ if got, want := len(msgRegistry.factories[MsgRlerror].cache), 1; got != want {
+ t.Errorf("Wrong cache size, got: %d, want: %d", got, want)
+ }
+
+ // Check that returned message is reused.
+ if got, err := msgRegistry.get(0, MsgRlerror); err != nil {
+ t.Errorf("msgRegistry.get(): %v", err)
+ } else if msg != got {
+ t.Errorf("Message not reused, got: %d, want: %d", got, msg)
+ }
+
+ // Check that cache doesn't grow beyond max size.
+ for i := 0; i < maxCacheSize+1; i++ {
+ msgRegistry.put(&Rlerror{})
+ }
+ if got, want := len(msgRegistry.factories[MsgRlerror].cache), maxCacheSize; got != want {
+ t.Errorf("Wrong cache size, got: %d, want: %d", got, want)
+ }
}
diff --git a/pkg/p9/server.go b/pkg/p9/server.go
index b2a86d8fa..f377a6557 100644
--- a/pkg/p9/server.go
+++ b/pkg/p9/server.go
@@ -395,7 +395,7 @@ func (cs *connState) handleRequest() {
}
// Receive a message.
- tag, m, err := recv(cs.conn, messageSize, messageByType)
+ tag, m, err := recv(cs.conn, messageSize, msgRegistry.get)
if errSocket, ok := err.(ErrSocket); ok {
// Connection problem; stop serving.
cs.recvDone <- errSocket.error
@@ -458,6 +458,8 @@ func (cs *connState) handleRequest() {
// Produce an ENOSYS error.
r = newErr(syscall.ENOSYS)
}
+ msgRegistry.put(m)
+ m = nil // 'm' should not be touched after this point.
}
func (cs *connState) handleRequests() {
diff --git a/pkg/p9/transport_test.go b/pkg/p9/transport_test.go
index c833d1c9c..0f88ff249 100644
--- a/pkg/p9/transport_test.go
+++ b/pkg/p9/transport_test.go
@@ -41,7 +41,7 @@ func TestSendRecv(t *testing.T) {
t.Fatalf("send got err %v expected nil", err)
}
- tag, m, err := recv(server, maximumLength, messageByType)
+ tag, m, err := recv(server, maximumLength, msgRegistry.get)
if err != nil {
t.Fatalf("recv got err %v expected nil", err)
}
@@ -73,7 +73,7 @@ func TestRecvOverrun(t *testing.T) {
t.Fatalf("send got err %v expected nil", err)
}
- if _, _, err := recv(server, maximumLength, messageByType); err == nil {
+ if _, _, err := recv(server, maximumLength, msgRegistry.get); err == nil {
t.Fatalf("recv got err %v expected ErrSocket{ErrNoValidMessage}", err)
}
}
@@ -98,7 +98,7 @@ func TestRecvInvalidType(t *testing.T) {
t.Fatalf("send got err %v expected nil", err)
}
- _, _, err = recv(server, maximumLength, messageByType)
+ _, _, err = recv(server, maximumLength, msgRegistry.get)
if _, ok := err.(*ErrInvalidMsgType); !ok {
t.Fatalf("recv got err %v expected ErrInvalidMsgType", err)
}
@@ -129,7 +129,7 @@ func TestSendRecvWithFile(t *testing.T) {
}
// Enable withFile.
- tag, m, err := recv(server, maximumLength, messageByType)
+ tag, m, err := recv(server, maximumLength, msgRegistry.get)
if err != nil {
t.Fatalf("recv got err %v expected nil", err)
}
@@ -153,7 +153,7 @@ func TestRecvClosed(t *testing.T) {
defer server.Close()
client.Close()
- _, _, err = recv(server, maximumLength, messageByType)
+ _, _, err = recv(server, maximumLength, msgRegistry.get)
if err == nil {
t.Fatalf("got err nil expected non-nil")
}
@@ -180,5 +180,5 @@ func TestSendClosed(t *testing.T) {
}
func init() {
- register(MsgTypeBadDecode, func() message { return &badDecode{} })
+ msgRegistry.register(MsgTypeBadDecode, func() message { return &badDecode{} })
}