summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--pkg/p9/buffer.go3
-rw-r--r--pkg/p9/client.go6
-rw-r--r--pkg/p9/messages.go232
-rw-r--r--pkg/p9/messages_test.go46
-rw-r--r--pkg/p9/server.go4
-rw-r--r--pkg/p9/transport_test.go12
6 files changed, 187 insertions, 116 deletions
diff --git a/pkg/p9/buffer.go b/pkg/p9/buffer.go
index 4c8c6555d..249536d8a 100644
--- a/pkg/p9/buffer.go
+++ b/pkg/p9/buffer.go
@@ -20,7 +20,8 @@ import (
// encoder is used for messages and 9P primitives.
type encoder interface {
- // Decode decodes from the given buffer.
+ // Decode decodes from the given buffer. Decode may be called more than once
+ // to reuse the instance. It must clear any previous state.
//
// This may not fail, exhaustion will be recorded in the buffer.
Decode(b *buffer)
diff --git a/pkg/p9/client.go b/pkg/p9/client.go
index 2f9c716d0..56587e2cf 100644
--- a/pkg/p9/client.go
+++ b/pkg/p9/client.go
@@ -110,16 +110,16 @@ type Client struct {
// You should not use the same socket for multiple clients.
func NewClient(socket *unet.Socket, messageSize uint32, version string) (*Client, error) {
// Need at least one byte of payload.
- if messageSize <= largestFixedSize {
+ if messageSize <= msgRegistry.largestFixedSize {
return nil, &ErrMessageTooLarge{
size: messageSize,
- msize: largestFixedSize,
+ msize: msgRegistry.largestFixedSize,
}
}
// Compute a payload size and round to 512 (normal block size)
// if it's larger than a single block.
- payloadSize := messageSize - largestFixedSize
+ payloadSize := messageSize - msgRegistry.largestFixedSize
if payloadSize > 512 && payloadSize%512 != 0 {
payloadSize -= (payloadSize % 512)
}
diff --git a/pkg/p9/messages.go b/pkg/p9/messages.go
index 833defbd6..3c7898cc1 100644
--- a/pkg/p9/messages.go
+++ b/pkg/p9/messages.go
@@ -193,6 +193,7 @@ func (t *Twalk) Decode(b *buffer) {
t.FID = b.ReadFID()
t.NewFID = b.ReadFID()
n := b.Read16()
+ t.Names = t.Names[:0]
for i := 0; i < int(n); i++ {
t.Names = append(t.Names, b.ReadString())
}
@@ -227,6 +228,7 @@ type Rwalk struct {
// Decode implements encoder.Decode.
func (r *Rwalk) Decode(b *buffer) {
n := b.Read16()
+ r.QIDs = r.QIDs[:0]
for i := 0; i < int(n); i++ {
var q QID
q.Decode(b)
@@ -1608,6 +1610,7 @@ type Rreaddir struct {
func (r *Rreaddir) Decode(b *buffer) {
r.Count = b.Read32()
entriesBuf := buffer{data: r.payload}
+ r.Entries = r.Entries[:0]
for {
var d Dirent
d.Decode(&entriesBuf)
@@ -1827,6 +1830,7 @@ func (t *Twalkgetattr) Decode(b *buffer) {
t.FID = b.ReadFID()
t.NewFID = b.ReadFID()
n := b.Read16()
+ t.Names = t.Names[:0]
for i := 0; i < int(n); i++ {
t.Names = append(t.Names, b.ReadString())
}
@@ -1869,6 +1873,7 @@ func (r *Rwalkgetattr) Decode(b *buffer) {
r.Valid.Decode(b)
r.Attr.Decode(b)
n := b.Read16()
+ r.QIDs = r.QIDs[:0]
for i := 0; i < int(n); i++ {
var q QID
q.Decode(b)
@@ -2139,34 +2144,80 @@ func (r *Rlconnect) String() string {
return fmt.Sprintf("Rlconnect{File: %v}", r.File)
}
-// messageRegistry indexes all messages by type.
-var messageRegistry = make([]func() message, math.MaxUint8)
+const maxCacheSize = 3
-// messageByType creates a new message by type.
+// msgFactory is used to reduce allocations by caching messages for reuse.
+type msgFactory struct {
+ create func() message
+ cache chan message
+}
+
+// msgRegistry indexes all message factories by type.
+var msgRegistry registry
+
+type registry struct {
+ factories [math.MaxUint8]msgFactory
+
+ // largestFixedSize is computed so that given some message size M, you can
+ // compute the maximum payload size (e.g. for Twrite, Rread) with
+ // M-largestFixedSize. You could do this individual on a per-message basis,
+ // but it's easier to compute a single maximum safe payload.
+ largestFixedSize uint32
+}
+
+// get returns a new message by type.
//
// An error is returned in the case of an unknown message.
//
// This takes, and ignores, a message tag so that it may be used directly as a
// lookupTagAndType function for recv (by design).
-func messageByType(_ Tag, t MsgType) (message, error) {
- fn := messageRegistry[t]
- if fn == nil {
+func (r *registry) get(_ Tag, t MsgType) (message, error) {
+ entry := &r.factories[t]
+ if entry.create == nil {
return nil, &ErrInvalidMsgType{t}
}
- return fn(), nil
+
+ select {
+ case msg := <-entry.cache:
+ return msg, nil
+ default:
+ return entry.create(), nil
+ }
+}
+
+func (r *registry) put(msg message) {
+ if p, ok := msg.(payloader); ok {
+ p.SetPayload(nil)
+ }
+ if f, ok := msg.(filer); ok {
+ f.SetFilePayload(nil)
+ }
+
+ entry := &r.factories[msg.Type()]
+ select {
+ case entry.cache <- msg:
+ default:
+ }
}
// register registers the given message type.
//
// This may cause panic on failure and should only be used from init.
-func register(t MsgType, fn func() message) {
- if int(t) >= len(messageRegistry) {
- panic(fmt.Sprintf("message type %d is too large. It must be smaller than %d", t, len(messageRegistry)))
+func (r *registry) register(t MsgType, fn func() message) {
+ if int(t) >= len(r.factories) {
+ panic(fmt.Sprintf("message type %d is too large. It must be smaller than %d", t, len(r.factories)))
}
- if messageRegistry[t] != nil {
- panic(fmt.Sprintf("duplicate message type %d: first is %T, second is %T", t, messageRegistry[t](), fn()))
+ if r.factories[t].create != nil {
+ panic(fmt.Sprintf("duplicate message type %d: first is %T, second is %T", t, r.factories[t].create(), fn()))
+ }
+ r.factories[t] = msgFactory{
+ create: fn,
+ cache: make(chan message, maxCacheSize),
+ }
+
+ if size := calculateSize(fn()); size > r.largestFixedSize {
+ r.largestFixedSize = size
}
- messageRegistry[t] = fn
}
func calculateSize(m message) uint32 {
@@ -2178,93 +2229,72 @@ func calculateSize(m message) uint32 {
return uint32(len(dataBuf.data))
}
-// largestFixedSize is computed within calculateLargestSize.
-//
-// This is computed so that given some message size M, you can compute
-// the maximum payload size (e.g. for Twrite, Rread) with M-largestFixedSize.
-// You could do this individual on a per-message basis, but it's easier to
-// compute a single maximum safe payload.
-var largestFixedSize uint32
-
-// calculateLargestFixedSize is called from within init.
-func calculateLargestFixedSize() {
- for _, fn := range messageRegistry {
- if fn != nil {
- if size := calculateSize(fn()); size > largestFixedSize {
- largestFixedSize = size
- }
- }
- }
-}
-
func init() {
- register(MsgRlerror, func() message { return &Rlerror{} })
- register(MsgTstatfs, func() message { return &Tstatfs{} })
- register(MsgRstatfs, func() message { return &Rstatfs{} })
- register(MsgTlopen, func() message { return &Tlopen{} })
- register(MsgRlopen, func() message { return &Rlopen{} })
- register(MsgTlcreate, func() message { return &Tlcreate{} })
- register(MsgRlcreate, func() message { return &Rlcreate{} })
- register(MsgTsymlink, func() message { return &Tsymlink{} })
- register(MsgRsymlink, func() message { return &Rsymlink{} })
- register(MsgTmknod, func() message { return &Tmknod{} })
- register(MsgRmknod, func() message { return &Rmknod{} })
- register(MsgTrename, func() message { return &Trename{} })
- register(MsgRrename, func() message { return &Rrename{} })
- register(MsgTreadlink, func() message { return &Treadlink{} })
- register(MsgRreadlink, func() message { return &Rreadlink{} })
- register(MsgTgetattr, func() message { return &Tgetattr{} })
- register(MsgRgetattr, func() message { return &Rgetattr{} })
- register(MsgTsetattr, func() message { return &Tsetattr{} })
- register(MsgRsetattr, func() message { return &Rsetattr{} })
- register(MsgTxattrwalk, func() message { return &Txattrwalk{} })
- register(MsgRxattrwalk, func() message { return &Rxattrwalk{} })
- register(MsgTxattrcreate, func() message { return &Txattrcreate{} })
- register(MsgRxattrcreate, func() message { return &Rxattrcreate{} })
- register(MsgTreaddir, func() message { return &Treaddir{} })
- register(MsgRreaddir, func() message { return &Rreaddir{} })
- register(MsgTfsync, func() message { return &Tfsync{} })
- register(MsgRfsync, func() message { return &Rfsync{} })
- register(MsgTlink, func() message { return &Tlink{} })
- register(MsgRlink, func() message { return &Rlink{} })
- register(MsgTmkdir, func() message { return &Tmkdir{} })
- register(MsgRmkdir, func() message { return &Rmkdir{} })
- register(MsgTrenameat, func() message { return &Trenameat{} })
- register(MsgRrenameat, func() message { return &Rrenameat{} })
- register(MsgTunlinkat, func() message { return &Tunlinkat{} })
- register(MsgRunlinkat, func() message { return &Runlinkat{} })
- register(MsgTversion, func() message { return &Tversion{} })
- register(MsgRversion, func() message { return &Rversion{} })
- register(MsgTauth, func() message { return &Tauth{} })
- register(MsgRauth, func() message { return &Rauth{} })
- register(MsgTattach, func() message { return &Tattach{} })
- register(MsgRattach, func() message { return &Rattach{} })
- register(MsgTflush, func() message { return &Tflush{} })
- register(MsgRflush, func() message { return &Rflush{} })
- register(MsgTwalk, func() message { return &Twalk{} })
- register(MsgRwalk, func() message { return &Rwalk{} })
- register(MsgTread, func() message { return &Tread{} })
- register(MsgRread, func() message { return &Rread{} })
- register(MsgTwrite, func() message { return &Twrite{} })
- register(MsgRwrite, func() message { return &Rwrite{} })
- register(MsgTclunk, func() message { return &Tclunk{} })
- register(MsgRclunk, func() message { return &Rclunk{} })
- register(MsgTremove, func() message { return &Tremove{} })
- register(MsgRremove, func() message { return &Rremove{} })
- register(MsgTflushf, func() message { return &Tflushf{} })
- register(MsgRflushf, func() message { return &Rflushf{} })
- register(MsgTwalkgetattr, func() message { return &Twalkgetattr{} })
- register(MsgRwalkgetattr, func() message { return &Rwalkgetattr{} })
- register(MsgTucreate, func() message { return &Tucreate{} })
- register(MsgRucreate, func() message { return &Rucreate{} })
- register(MsgTumkdir, func() message { return &Tumkdir{} })
- register(MsgRumkdir, func() message { return &Rumkdir{} })
- register(MsgTumknod, func() message { return &Tumknod{} })
- register(MsgRumknod, func() message { return &Rumknod{} })
- register(MsgTusymlink, func() message { return &Tusymlink{} })
- register(MsgRusymlink, func() message { return &Rusymlink{} })
- register(MsgTlconnect, func() message { return &Tlconnect{} })
- register(MsgRlconnect, func() message { return &Rlconnect{} })
-
- calculateLargestFixedSize()
+ msgRegistry.register(MsgRlerror, func() message { return &Rlerror{} })
+ msgRegistry.register(MsgTstatfs, func() message { return &Tstatfs{} })
+ msgRegistry.register(MsgRstatfs, func() message { return &Rstatfs{} })
+ msgRegistry.register(MsgTlopen, func() message { return &Tlopen{} })
+ msgRegistry.register(MsgRlopen, func() message { return &Rlopen{} })
+ msgRegistry.register(MsgTlcreate, func() message { return &Tlcreate{} })
+ msgRegistry.register(MsgRlcreate, func() message { return &Rlcreate{} })
+ msgRegistry.register(MsgTsymlink, func() message { return &Tsymlink{} })
+ msgRegistry.register(MsgRsymlink, func() message { return &Rsymlink{} })
+ msgRegistry.register(MsgTmknod, func() message { return &Tmknod{} })
+ msgRegistry.register(MsgRmknod, func() message { return &Rmknod{} })
+ msgRegistry.register(MsgTrename, func() message { return &Trename{} })
+ msgRegistry.register(MsgRrename, func() message { return &Rrename{} })
+ msgRegistry.register(MsgTreadlink, func() message { return &Treadlink{} })
+ msgRegistry.register(MsgRreadlink, func() message { return &Rreadlink{} })
+ msgRegistry.register(MsgTgetattr, func() message { return &Tgetattr{} })
+ msgRegistry.register(MsgRgetattr, func() message { return &Rgetattr{} })
+ msgRegistry.register(MsgTsetattr, func() message { return &Tsetattr{} })
+ msgRegistry.register(MsgRsetattr, func() message { return &Rsetattr{} })
+ msgRegistry.register(MsgTxattrwalk, func() message { return &Txattrwalk{} })
+ msgRegistry.register(MsgRxattrwalk, func() message { return &Rxattrwalk{} })
+ msgRegistry.register(MsgTxattrcreate, func() message { return &Txattrcreate{} })
+ msgRegistry.register(MsgRxattrcreate, func() message { return &Rxattrcreate{} })
+ msgRegistry.register(MsgTreaddir, func() message { return &Treaddir{} })
+ msgRegistry.register(MsgRreaddir, func() message { return &Rreaddir{} })
+ msgRegistry.register(MsgTfsync, func() message { return &Tfsync{} })
+ msgRegistry.register(MsgRfsync, func() message { return &Rfsync{} })
+ msgRegistry.register(MsgTlink, func() message { return &Tlink{} })
+ msgRegistry.register(MsgRlink, func() message { return &Rlink{} })
+ msgRegistry.register(MsgTmkdir, func() message { return &Tmkdir{} })
+ msgRegistry.register(MsgRmkdir, func() message { return &Rmkdir{} })
+ msgRegistry.register(MsgTrenameat, func() message { return &Trenameat{} })
+ msgRegistry.register(MsgRrenameat, func() message { return &Rrenameat{} })
+ msgRegistry.register(MsgTunlinkat, func() message { return &Tunlinkat{} })
+ msgRegistry.register(MsgRunlinkat, func() message { return &Runlinkat{} })
+ msgRegistry.register(MsgTversion, func() message { return &Tversion{} })
+ msgRegistry.register(MsgRversion, func() message { return &Rversion{} })
+ msgRegistry.register(MsgTauth, func() message { return &Tauth{} })
+ msgRegistry.register(MsgRauth, func() message { return &Rauth{} })
+ msgRegistry.register(MsgTattach, func() message { return &Tattach{} })
+ msgRegistry.register(MsgRattach, func() message { return &Rattach{} })
+ msgRegistry.register(MsgTflush, func() message { return &Tflush{} })
+ msgRegistry.register(MsgRflush, func() message { return &Rflush{} })
+ msgRegistry.register(MsgTwalk, func() message { return &Twalk{} })
+ msgRegistry.register(MsgRwalk, func() message { return &Rwalk{} })
+ msgRegistry.register(MsgTread, func() message { return &Tread{} })
+ msgRegistry.register(MsgRread, func() message { return &Rread{} })
+ msgRegistry.register(MsgTwrite, func() message { return &Twrite{} })
+ msgRegistry.register(MsgRwrite, func() message { return &Rwrite{} })
+ msgRegistry.register(MsgTclunk, func() message { return &Tclunk{} })
+ msgRegistry.register(MsgRclunk, func() message { return &Rclunk{} })
+ msgRegistry.register(MsgTremove, func() message { return &Tremove{} })
+ msgRegistry.register(MsgRremove, func() message { return &Rremove{} })
+ msgRegistry.register(MsgTflushf, func() message { return &Tflushf{} })
+ msgRegistry.register(MsgRflushf, func() message { return &Rflushf{} })
+ msgRegistry.register(MsgTwalkgetattr, func() message { return &Twalkgetattr{} })
+ msgRegistry.register(MsgRwalkgetattr, func() message { return &Rwalkgetattr{} })
+ msgRegistry.register(MsgTucreate, func() message { return &Tucreate{} })
+ msgRegistry.register(MsgRucreate, func() message { return &Rucreate{} })
+ msgRegistry.register(MsgTumkdir, func() message { return &Tumkdir{} })
+ msgRegistry.register(MsgRumkdir, func() message { return &Rumkdir{} })
+ msgRegistry.register(MsgTumknod, func() message { return &Tumknod{} })
+ msgRegistry.register(MsgRumknod, func() message { return &Rumknod{} })
+ msgRegistry.register(MsgTusymlink, func() message { return &Tusymlink{} })
+ msgRegistry.register(MsgRusymlink, func() message { return &Rusymlink{} })
+ msgRegistry.register(MsgTlconnect, func() message { return &Tlconnect{} })
+ msgRegistry.register(MsgRlconnect, func() message { return &Rlconnect{} })
}
diff --git a/pkg/p9/messages_test.go b/pkg/p9/messages_test.go
index 10a0587cf..513b30e8b 100644
--- a/pkg/p9/messages_test.go
+++ b/pkg/p9/messages_test.go
@@ -399,8 +399,9 @@ func TestEncodeDecode(t *testing.T) {
}
func TestMessageStrings(t *testing.T) {
- for typ, fn := range messageRegistry {
- if fn != nil {
+ for typ := range msgRegistry.factories {
+ entry := &msgRegistry.factories[typ]
+ if entry.create != nil {
name := fmt.Sprintf("%+v", typ)
t.Run(name, func(t *testing.T) {
defer func() { // Ensure no panic.
@@ -408,7 +409,7 @@ func TestMessageStrings(t *testing.T) {
t.Errorf("printing %s failed: %v", name, r)
}
}()
- m := fn()
+ m := entry.create()
_ = fmt.Sprintf("%v", m)
err := ErrInvalidMsgType{MsgType(typ)}
_ = err.Error()
@@ -426,5 +427,42 @@ func TestRegisterDuplicate(t *testing.T) {
}()
// Register a duplicate.
- register(MsgRlerror, func() message { return &Rlerror{} })
+ msgRegistry.register(MsgRlerror, func() message { return &Rlerror{} })
+}
+
+func TestMsgCache(t *testing.T) {
+ // Cache starts empty.
+ if got, want := len(msgRegistry.factories[MsgRlerror].cache), 0; got != want {
+ t.Errorf("Wrong cache size, got: %d, want: %d", got, want)
+ }
+
+ // Message can be created with an empty cache.
+ msg, err := msgRegistry.get(0, MsgRlerror)
+ if err != nil {
+ t.Errorf("msgRegistry.get(): %v", err)
+ }
+ if got, want := len(msgRegistry.factories[MsgRlerror].cache), 0; got != want {
+ t.Errorf("Wrong cache size, got: %d, want: %d", got, want)
+ }
+
+ // Check that message is added to the cache when returned.
+ msgRegistry.put(msg)
+ if got, want := len(msgRegistry.factories[MsgRlerror].cache), 1; got != want {
+ t.Errorf("Wrong cache size, got: %d, want: %d", got, want)
+ }
+
+ // Check that returned message is reused.
+ if got, err := msgRegistry.get(0, MsgRlerror); err != nil {
+ t.Errorf("msgRegistry.get(): %v", err)
+ } else if msg != got {
+ t.Errorf("Message not reused, got: %d, want: %d", got, msg)
+ }
+
+ // Check that cache doesn't grow beyond max size.
+ for i := 0; i < maxCacheSize+1; i++ {
+ msgRegistry.put(&Rlerror{})
+ }
+ if got, want := len(msgRegistry.factories[MsgRlerror].cache), maxCacheSize; got != want {
+ t.Errorf("Wrong cache size, got: %d, want: %d", got, want)
+ }
}
diff --git a/pkg/p9/server.go b/pkg/p9/server.go
index b2a86d8fa..f377a6557 100644
--- a/pkg/p9/server.go
+++ b/pkg/p9/server.go
@@ -395,7 +395,7 @@ func (cs *connState) handleRequest() {
}
// Receive a message.
- tag, m, err := recv(cs.conn, messageSize, messageByType)
+ tag, m, err := recv(cs.conn, messageSize, msgRegistry.get)
if errSocket, ok := err.(ErrSocket); ok {
// Connection problem; stop serving.
cs.recvDone <- errSocket.error
@@ -458,6 +458,8 @@ func (cs *connState) handleRequest() {
// Produce an ENOSYS error.
r = newErr(syscall.ENOSYS)
}
+ msgRegistry.put(m)
+ m = nil // 'm' should not be touched after this point.
}
func (cs *connState) handleRequests() {
diff --git a/pkg/p9/transport_test.go b/pkg/p9/transport_test.go
index c833d1c9c..0f88ff249 100644
--- a/pkg/p9/transport_test.go
+++ b/pkg/p9/transport_test.go
@@ -41,7 +41,7 @@ func TestSendRecv(t *testing.T) {
t.Fatalf("send got err %v expected nil", err)
}
- tag, m, err := recv(server, maximumLength, messageByType)
+ tag, m, err := recv(server, maximumLength, msgRegistry.get)
if err != nil {
t.Fatalf("recv got err %v expected nil", err)
}
@@ -73,7 +73,7 @@ func TestRecvOverrun(t *testing.T) {
t.Fatalf("send got err %v expected nil", err)
}
- if _, _, err := recv(server, maximumLength, messageByType); err == nil {
+ if _, _, err := recv(server, maximumLength, msgRegistry.get); err == nil {
t.Fatalf("recv got err %v expected ErrSocket{ErrNoValidMessage}", err)
}
}
@@ -98,7 +98,7 @@ func TestRecvInvalidType(t *testing.T) {
t.Fatalf("send got err %v expected nil", err)
}
- _, _, err = recv(server, maximumLength, messageByType)
+ _, _, err = recv(server, maximumLength, msgRegistry.get)
if _, ok := err.(*ErrInvalidMsgType); !ok {
t.Fatalf("recv got err %v expected ErrInvalidMsgType", err)
}
@@ -129,7 +129,7 @@ func TestSendRecvWithFile(t *testing.T) {
}
// Enable withFile.
- tag, m, err := recv(server, maximumLength, messageByType)
+ tag, m, err := recv(server, maximumLength, msgRegistry.get)
if err != nil {
t.Fatalf("recv got err %v expected nil", err)
}
@@ -153,7 +153,7 @@ func TestRecvClosed(t *testing.T) {
defer server.Close()
client.Close()
- _, _, err = recv(server, maximumLength, messageByType)
+ _, _, err = recv(server, maximumLength, msgRegistry.get)
if err == nil {
t.Fatalf("got err nil expected non-nil")
}
@@ -180,5 +180,5 @@ func TestSendClosed(t *testing.T) {
}
func init() {
- register(MsgTypeBadDecode, func() message { return &badDecode{} })
+ msgRegistry.register(MsgTypeBadDecode, func() message { return &badDecode{} })
}