// Copyright 2021 The gVisor Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package connection_test import ( "reflect" "testing" "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/lisafs" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/unet" ) const ( dynamicMsgID = lisafs.Channel + 1 versionMsgID = dynamicMsgID + 1 ) var handlers = [...]lisafs.RPCHandler{ lisafs.Error: lisafs.ErrorHandler, lisafs.Mount: lisafs.MountHandler, lisafs.Channel: lisafs.ChannelHandler, dynamicMsgID: dynamicMsgHandler, versionMsgID: versionHandler, } // testServer implements lisafs.ServerImpl. type testServer struct { lisafs.Server } var _ lisafs.ServerImpl = (*testServer)(nil) type testControlFD struct { lisafs.ControlFD lisafs.ControlFDImpl } func (fd *testControlFD) FD() *lisafs.ControlFD { return &fd.ControlFD } // Mount implements lisafs.Mount. func (s *testServer) Mount(c *lisafs.Connection, mountPath string) (lisafs.ControlFDImpl, lisafs.Inode, error) { return &testControlFD{}, lisafs.Inode{ControlFD: 1}, nil } // MaxMessageSize implements lisafs.MaxMessageSize. func (s *testServer) MaxMessageSize() uint32 { return lisafs.MaxMessageSize() } // SupportedMessages implements lisafs.ServerImpl.SupportedMessages. func (s *testServer) SupportedMessages() []lisafs.MID { return []lisafs.MID{ lisafs.Mount, lisafs.Channel, dynamicMsgID, versionMsgID, } } func runServerClient(t testing.TB, clientFn func(c *lisafs.Client)) { serverSocket, clientSocket, err := unet.SocketPair(false) if err != nil { t.Fatalf("socketpair got err %v expected nil", err) } ts := &testServer{} ts.Server.InitTestOnly(ts, handlers[:]) conn, err := ts.CreateConnection(serverSocket, false /* readonly */) if err != nil { t.Fatalf("starting connection failed: %v", err) return } ts.StartConnection(conn) c, _, err := lisafs.NewClient(clientSocket, "/") if err != nil { t.Fatalf("client creation failed: %v", err) } clientFn(c) c.Close() // This should trigger client and server shutdown. ts.Wait() } // TestStartUp tests that the server and client can be started up correctly. func TestStartUp(t *testing.T) { runServerClient(t, func(c *lisafs.Client) { if c.IsSupported(lisafs.Error) { t.Errorf("sending error messages should not be supported") } }) } func TestUnsupportedMessage(t *testing.T) { unsupportedM := lisafs.MID(len(handlers)) runServerClient(t, func(c *lisafs.Client) { if err := c.SndRcvMessage(unsupportedM, 0, lisafs.NoopMarshal, lisafs.NoopUnmarshal, nil); err != unix.EOPNOTSUPP { t.Errorf("expected EOPNOTSUPP but got err: %v", err) } }) } func dynamicMsgHandler(c *lisafs.Connection, comm lisafs.Communicator, payloadLen uint32) (uint32, error) { var req lisafs.MsgDynamic req.UnmarshalBytes(comm.PayloadBuf(payloadLen)) // Just echo back the message. respPayloadLen := uint32(req.SizeBytes()) req.MarshalBytes(comm.PayloadBuf(respPayloadLen)) return respPayloadLen, nil } // TestStress stress tests sending many messages from various goroutines. func TestStress(t *testing.T) { runServerClient(t, func(c *lisafs.Client) { concurrency := 8 numMsgPerGoroutine := 5000 var clientWg sync.WaitGroup for i := 0; i < concurrency; i++ { clientWg.Add(1) go func() { defer clientWg.Done() for j := 0; j < numMsgPerGoroutine; j++ { // Create a massive random message. var req lisafs.MsgDynamic req.Randomize(100) var resp lisafs.MsgDynamic if err := c.SndRcvMessage(dynamicMsgID, uint32(req.SizeBytes()), req.MarshalBytes, resp.UnmarshalBytes, nil); err != nil { t.Errorf("SndRcvMessage: received unexpected error %v", err) return } if !reflect.DeepEqual(&req, &resp) { t.Errorf("response should be the same as request: request = %+v, response = %+v", req, resp) } } }() } clientWg.Wait() }) } func versionHandler(c *lisafs.Connection, comm lisafs.Communicator, payloadLen uint32) (uint32, error) { // To be fair, usually handlers will create their own objects and return a // pointer to those. Might be tempting to reuse above variables, but don't. var rv lisafs.P9Version rv.UnmarshalBytes(comm.PayloadBuf(payloadLen)) // Create a new response. sv := lisafs.P9Version{ MSize: rv.MSize, Version: "9P2000.L.Google.11", } respPayloadLen := uint32(sv.SizeBytes()) sv.MarshalBytes(comm.PayloadBuf(respPayloadLen)) return respPayloadLen, nil } // BenchmarkSendRecv exists to compete against p9's BenchmarkSendRecvChannel. func BenchmarkSendRecv(b *testing.B) { b.ReportAllocs() sendV := lisafs.P9Version{ MSize: 1 << 20, Version: "9P2000.L.Google.12", } var recvV lisafs.P9Version runServerClient(b, func(c *lisafs.Client) { for i := 0; i < b.N; i++ { if err := c.SndRcvMessage(versionMsgID, uint32(sendV.SizeBytes()), sendV.MarshalBytes, recvV.UnmarshalBytes, nil); err != nil { b.Fatalf("unexpected error occurred: %v", err) } } }) }