diff options
Diffstat (limited to 'pkg')
-rw-r--r-- | pkg/p9/messages.go | 165 | ||||
-rw-r--r-- | pkg/p9/messages_test.go | 28 | ||||
-rw-r--r-- | pkg/p9/transport_test.go | 2 |
3 files changed, 98 insertions, 97 deletions
diff --git a/pkg/p9/messages.go b/pkg/p9/messages.go index 972c37344..97decd3cc 100644 --- a/pkg/p9/messages.go +++ b/pkg/p9/messages.go @@ -16,7 +16,7 @@ package p9 import ( "fmt" - "reflect" + "math" "gvisor.googlesource.com/gvisor/pkg/fd" ) @@ -2140,7 +2140,7 @@ func (r *Rlconnect) String() string { } // messageRegistry indexes all messages by type. -var messageRegistry = make(map[MsgType]func() message) +var messageRegistry = make([]func() message, math.MaxUint8) // messageByType creates a new message by type. // @@ -2149,8 +2149,8 @@ var messageRegistry = make(map[MsgType]func() message) // This takes, and ignores, a message tag so that it may be used directly as a // lookupTagAndType function for recv (by design). func messageByType(_ Tag, t MsgType) (message, error) { - fn, ok := messageRegistry[t] - if !ok { + fn := messageRegistry[t] + if fn == nil { return nil, &ErrInvalidMsgType{t} } return fn(), nil @@ -2158,18 +2158,15 @@ func messageByType(_ Tag, t MsgType) (message, error) { // register registers the given message type. // -// This uses reflection and records only the type. This may cause panic on -// failure and should only be used from init. -func register(m message) { - t := m.Type() - if fn, ok := messageRegistry[t]; ok { - panic(fmt.Sprintf("duplicate message type %d: first is %#v, second is %#v", t, fn(), m)) +// This may cause panic on failure and should only be used from init. +func register(t MsgType, fn func() message) { + if int(t) >= len(messageRegistry) { + panic(fmt.Sprintf("message type %d is too large. It must be smaller than %d", t, len(messageRegistry))) } - - to := reflect.ValueOf(m).Elem().Type() - messageRegistry[t] = func() message { - return reflect.New(to).Interface().(message) + if messageRegistry[t] != nil { + panic(fmt.Sprintf("duplicate message type %d: first is %T, second is %T", t, messageRegistry[t](), fn())) } + messageRegistry[t] = fn } func calculateSize(m message) uint32 { @@ -2192,80 +2189,82 @@ var largestFixedSize uint32 // calculateLargestFixedSize is called from within init. func calculateLargestFixedSize() { for _, fn := range messageRegistry { - if size := calculateSize(fn()); size > largestFixedSize { - largestFixedSize = size + if fn != nil { + if size := calculateSize(fn()); size > largestFixedSize { + largestFixedSize = size + } } } } func init() { - register(&Rlerror{}) - register(&Tstatfs{}) - register(&Rstatfs{}) - register(&Tlopen{}) - register(&Rlopen{}) - register(&Tlcreate{}) - register(&Rlcreate{}) - register(&Tsymlink{}) - register(&Rsymlink{}) - register(&Tmknod{}) - register(&Rmknod{}) - register(&Trename{}) - register(&Rrename{}) - register(&Treadlink{}) - register(&Rreadlink{}) - register(&Tgetattr{}) - register(&Rgetattr{}) - register(&Tsetattr{}) - register(&Rsetattr{}) - register(&Txattrwalk{}) - register(&Rxattrwalk{}) - register(&Txattrcreate{}) - register(&Rxattrcreate{}) - register(&Treaddir{}) - register(&Rreaddir{}) - register(&Tfsync{}) - register(&Rfsync{}) - register(&Tlink{}) - register(&Rlink{}) - register(&Tmkdir{}) - register(&Rmkdir{}) - register(&Trenameat{}) - register(&Rrenameat{}) - register(&Tunlinkat{}) - register(&Runlinkat{}) - register(&Tversion{}) - register(&Rversion{}) - register(&Tauth{}) - register(&Rauth{}) - register(&Tattach{}) - register(&Rattach{}) - register(&Tflush{}) - register(&Rflush{}) - register(&Twalk{}) - register(&Rwalk{}) - register(&Tread{}) - register(&Rread{}) - register(&Twrite{}) - register(&Rwrite{}) - register(&Tclunk{}) - register(&Rclunk{}) - register(&Tremove{}) - register(&Rremove{}) - register(&Tflushf{}) - register(&Rflushf{}) - register(&Twalkgetattr{}) - register(&Rwalkgetattr{}) - register(&Tucreate{}) - register(&Rucreate{}) - register(&Tumkdir{}) - register(&Rumkdir{}) - register(&Tumknod{}) - register(&Rumknod{}) - register(&Tusymlink{}) - register(&Rusymlink{}) - register(&Tlconnect{}) - register(&Rlconnect{}) + register(MsgRlerror, func() message { return &Rlerror{} }) + register(MsgTstatfs, func() message { return &Tstatfs{} }) + register(MsgRstatfs, func() message { return &Rstatfs{} }) + register(MsgTlopen, func() message { return &Tlopen{} }) + register(MsgRlopen, func() message { return &Rlopen{} }) + register(MsgTlcreate, func() message { return &Tlcreate{} }) + register(MsgRlcreate, func() message { return &Rlcreate{} }) + register(MsgTsymlink, func() message { return &Tsymlink{} }) + register(MsgRsymlink, func() message { return &Rsymlink{} }) + register(MsgTmknod, func() message { return &Tmknod{} }) + register(MsgRmknod, func() message { return &Rmknod{} }) + register(MsgTrename, func() message { return &Trename{} }) + register(MsgRrename, func() message { return &Rrename{} }) + register(MsgTreadlink, func() message { return &Treadlink{} }) + register(MsgRreadlink, func() message { return &Rreadlink{} }) + register(MsgTgetattr, func() message { return &Tgetattr{} }) + register(MsgRgetattr, func() message { return &Rgetattr{} }) + register(MsgTsetattr, func() message { return &Tsetattr{} }) + register(MsgRsetattr, func() message { return &Rsetattr{} }) + register(MsgTxattrwalk, func() message { return &Txattrwalk{} }) + register(MsgRxattrwalk, func() message { return &Rxattrwalk{} }) + register(MsgTxattrcreate, func() message { return &Txattrcreate{} }) + register(MsgRxattrcreate, func() message { return &Rxattrcreate{} }) + register(MsgTreaddir, func() message { return &Treaddir{} }) + register(MsgRreaddir, func() message { return &Rreaddir{} }) + register(MsgTfsync, func() message { return &Tfsync{} }) + register(MsgRfsync, func() message { return &Rfsync{} }) + register(MsgTlink, func() message { return &Tlink{} }) + register(MsgRlink, func() message { return &Rlink{} }) + register(MsgTmkdir, func() message { return &Tmkdir{} }) + register(MsgRmkdir, func() message { return &Rmkdir{} }) + register(MsgTrenameat, func() message { return &Trenameat{} }) + register(MsgRrenameat, func() message { return &Rrenameat{} }) + register(MsgTunlinkat, func() message { return &Tunlinkat{} }) + register(MsgRunlinkat, func() message { return &Runlinkat{} }) + register(MsgTversion, func() message { return &Tversion{} }) + register(MsgRversion, func() message { return &Rversion{} }) + register(MsgTauth, func() message { return &Tauth{} }) + register(MsgRauth, func() message { return &Rauth{} }) + register(MsgTattach, func() message { return &Tattach{} }) + register(MsgRattach, func() message { return &Rattach{} }) + register(MsgTflush, func() message { return &Tflush{} }) + register(MsgRflush, func() message { return &Rflush{} }) + register(MsgTwalk, func() message { return &Twalk{} }) + register(MsgRwalk, func() message { return &Rwalk{} }) + register(MsgTread, func() message { return &Tread{} }) + register(MsgRread, func() message { return &Rread{} }) + register(MsgTwrite, func() message { return &Twrite{} }) + register(MsgRwrite, func() message { return &Rwrite{} }) + register(MsgTclunk, func() message { return &Tclunk{} }) + register(MsgRclunk, func() message { return &Rclunk{} }) + register(MsgTremove, func() message { return &Tremove{} }) + register(MsgRremove, func() message { return &Rremove{} }) + register(MsgTflushf, func() message { return &Tflushf{} }) + register(MsgRflushf, func() message { return &Rflushf{} }) + register(MsgTwalkgetattr, func() message { return &Twalkgetattr{} }) + register(MsgRwalkgetattr, func() message { return &Rwalkgetattr{} }) + register(MsgTucreate, func() message { return &Tucreate{} }) + register(MsgRucreate, func() message { return &Rucreate{} }) + register(MsgTumkdir, func() message { return &Tumkdir{} }) + register(MsgRumkdir, func() message { return &Rumkdir{} }) + register(MsgTumknod, func() message { return &Tumknod{} }) + register(MsgRumknod, func() message { return &Rumknod{} }) + register(MsgTusymlink, func() message { return &Tusymlink{} }) + register(MsgRusymlink, func() message { return &Rusymlink{} }) + register(MsgTlconnect, func() message { return &Tlconnect{} }) + register(MsgRlconnect, func() message { return &Rlconnect{} }) calculateLargestFixedSize() } diff --git a/pkg/p9/messages_test.go b/pkg/p9/messages_test.go index c0d65d82c..68395a396 100644 --- a/pkg/p9/messages_test.go +++ b/pkg/p9/messages_test.go @@ -400,18 +400,20 @@ func TestEncodeDecode(t *testing.T) { func TestMessageStrings(t *testing.T) { for typ, fn := range messageRegistry { - name := fmt.Sprintf("%+v", typ) - t.Run(name, func(t *testing.T) { - defer func() { // Ensure no panic. - if r := recover(); r != nil { - t.Errorf("printing %s failed: %v", name, r) - } - }() - m := fn() - _ = fmt.Sprintf("%v", m) - err := ErrInvalidMsgType{typ} - _ = err.Error() - }) + if fn != nil { + name := fmt.Sprintf("%+v", typ) + t.Run(name, func(t *testing.T) { + defer func() { // Ensure no panic. + if r := recover(); r != nil { + t.Errorf("printing %s failed: %v", name, r) + } + }() + m := fn() + _ = fmt.Sprintf("%v", m) + err := ErrInvalidMsgType{MsgType(typ)} + _ = err.Error() + }) + } } } @@ -424,5 +426,5 @@ func TestRegisterDuplicate(t *testing.T) { }() // Register a duplicate. - register(&Rlerror{}) + register(MsgRlerror, func() message { return &Rlerror{} }) } diff --git a/pkg/p9/transport_test.go b/pkg/p9/transport_test.go index 3352a5205..b7b7825bd 100644 --- a/pkg/p9/transport_test.go +++ b/pkg/p9/transport_test.go @@ -180,5 +180,5 @@ func TestSendClosed(t *testing.T) { } func init() { - register(&badDecode{}) + register(MsgTypeBadDecode, func() message { return &badDecode{} }) } |