summaryrefslogtreecommitdiffhomepage
path: root/pkg/p9/transport_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/p9/transport_test.go')
-rw-r--r--pkg/p9/transport_test.go184
1 files changed, 184 insertions, 0 deletions
diff --git a/pkg/p9/transport_test.go b/pkg/p9/transport_test.go
new file mode 100644
index 000000000..e3ee3e9bd
--- /dev/null
+++ b/pkg/p9/transport_test.go
@@ -0,0 +1,184 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package p9
+
+import (
+ "io/ioutil"
+ "os"
+ "testing"
+
+ "gvisor.googlesource.com/gvisor/pkg/fd"
+ "gvisor.googlesource.com/gvisor/pkg/unet"
+)
+
+const (
+ MsgTypeBadEncode = iota + 252
+ MsgTypeBadDecode
+ MsgTypeUnregistered
+)
+
+func TestSendRecv(t *testing.T) {
+ server, client, err := unet.SocketPair(false)
+ if err != nil {
+ t.Fatalf("socketpair got err %v expected nil", err)
+ }
+ defer server.Close()
+ defer client.Close()
+
+ if err := send(client, Tag(1), &Tlopen{}); err != nil {
+ t.Fatalf("send got err %v expected nil", err)
+ }
+
+ tag, m, err := recv(server, maximumLength, messageByType)
+ if err != nil {
+ t.Fatalf("recv got err %v expected nil", err)
+ }
+ if tag != Tag(1) {
+ t.Fatalf("got tag %v expected 1", tag)
+ }
+ if _, ok := m.(*Tlopen); !ok {
+ t.Fatalf("got message %v expected *Tlopen", m)
+ }
+}
+
+// badDecode overruns on decode.
+type badDecode struct{}
+
+func (*badDecode) Decode(b *buffer) { b.markOverrun() }
+func (*badDecode) Encode(b *buffer) {}
+func (*badDecode) Type() MsgType { return MsgTypeBadDecode }
+func (*badDecode) String() string { return "badDecode{}" }
+
+func TestRecvOverrun(t *testing.T) {
+ server, client, err := unet.SocketPair(false)
+ if err != nil {
+ t.Fatalf("socketpair got err %v expected nil", err)
+ }
+ defer server.Close()
+ defer client.Close()
+
+ if err := send(client, Tag(1), &badDecode{}); err != nil {
+ t.Fatalf("send got err %v expected nil", err)
+ }
+
+ if _, _, err := recv(server, maximumLength, messageByType); err != ErrNoValidMessage {
+ t.Fatalf("recv got err %v expected ErrNoValidMessage", err)
+ }
+}
+
+// unregistered is not registered on decode.
+type unregistered struct{}
+
+func (*unregistered) Decode(b *buffer) {}
+func (*unregistered) Encode(b *buffer) {}
+func (*unregistered) Type() MsgType { return MsgTypeUnregistered }
+func (*unregistered) String() string { return "unregistered{}" }
+
+func TestRecvInvalidType(t *testing.T) {
+ server, client, err := unet.SocketPair(false)
+ if err != nil {
+ t.Fatalf("socketpair got err %v expected nil", err)
+ }
+ defer server.Close()
+ defer client.Close()
+
+ if err := send(client, Tag(1), &unregistered{}); err != nil {
+ t.Fatalf("send got err %v expected nil", err)
+ }
+
+ _, _, err = recv(server, maximumLength, messageByType)
+ if _, ok := err.(*ErrInvalidMsgType); !ok {
+ t.Fatalf("recv got err %v expected ErrInvalidMsgType", err)
+ }
+}
+
+func TestSendRecvWithFile(t *testing.T) {
+ server, client, err := unet.SocketPair(false)
+ if err != nil {
+ t.Fatalf("socketpair got err %v expected nil", err)
+ }
+ defer server.Close()
+ defer client.Close()
+
+ // Create a tempfile.
+ osf, err := ioutil.TempFile("", "p9")
+ if err != nil {
+ t.Fatalf("tempfile got err %v expected nil", err)
+ }
+ os.Remove(osf.Name())
+ f, err := fd.NewFromFile(osf)
+ osf.Close()
+ if err != nil {
+ t.Fatalf("unable to create file: %v", err)
+ }
+
+ if err := send(client, Tag(1), &Rlopen{File: f}); err != nil {
+ t.Fatalf("send got err %v expected nil", err)
+ }
+
+ // Enable withFile.
+ tag, m, err := recv(server, maximumLength, messageByType)
+ if err != nil {
+ t.Fatalf("recv got err %v expected nil", err)
+ }
+ if tag != Tag(1) {
+ t.Fatalf("got tag %v expected 1", tag)
+ }
+ rlopen, ok := m.(*Rlopen)
+ if !ok {
+ t.Fatalf("got m %v expected *Rlopen", m)
+ }
+ if rlopen.File == nil {
+ t.Fatalf("got nil file expected non-nil")
+ }
+}
+
+func TestRecvClosed(t *testing.T) {
+ server, client, err := unet.SocketPair(false)
+ if err != nil {
+ t.Fatalf("socketpair got err %v expected nil", err)
+ }
+ defer server.Close()
+ client.Close()
+
+ _, _, err = recv(server, maximumLength, messageByType)
+ if err == nil {
+ t.Fatalf("got err nil expected non-nil")
+ }
+ if _, ok := err.(ErrSocket); !ok {
+ t.Fatalf("got err %v expected ErrSocket", err)
+ }
+}
+
+func TestSendClosed(t *testing.T) {
+ server, client, err := unet.SocketPair(false)
+ if err != nil {
+ t.Fatalf("socketpair got err %v expected nil", err)
+ }
+ server.Close()
+ defer client.Close()
+
+ err = send(client, Tag(1), &Tlopen{})
+ if err == nil {
+ t.Fatalf("send got err nil expected non-nil")
+ }
+ if _, ok := err.(ErrSocket); !ok {
+ t.Fatalf("got err %v expected ErrSocket", err)
+ }
+}
+
+func init() {
+ register(&badDecode{})
+}