summaryrefslogtreecommitdiffhomepage
path: root/pkg/sentry/fs/host/socket_test.go
diff options
context:
space:
mode:
authorGoogler <noreply@google.com>2018-04-27 10:37:02 -0700
committerAdin Scannell <ascannell@google.com>2018-04-28 01:44:26 -0400
commitd02b74a5dcfed4bfc8f2f8e545bca4d2afabb296 (patch)
tree54f95eef73aee6bacbfc736fffc631be2605ed53 /pkg/sentry/fs/host/socket_test.go
parentf70210e742919f40aa2f0934a22f1c9ba6dada62 (diff)
Check in gVisor.
PiperOrigin-RevId: 194583126 Change-Id: Ica1d8821a90f74e7e745962d71801c598c652463
Diffstat (limited to 'pkg/sentry/fs/host/socket_test.go')
-rw-r--r--pkg/sentry/fs/host/socket_test.go401
1 files changed, 401 insertions, 0 deletions
diff --git a/pkg/sentry/fs/host/socket_test.go b/pkg/sentry/fs/host/socket_test.go
new file mode 100644
index 000000000..80c46dcfa
--- /dev/null
+++ b/pkg/sentry/fs/host/socket_test.go
@@ -0,0 +1,401 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package host
+
+import (
+ "reflect"
+ "syscall"
+ "testing"
+
+ "gvisor.googlesource.com/gvisor/pkg/fd"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context/contexttest"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/socket"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/syserr"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/unix"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+ "gvisor.googlesource.com/gvisor/pkg/waiter/fdnotifier"
+)
+
+var (
+ // Make sure that connectedEndpoint implements unix.ConnectedEndpoint.
+ _ = unix.ConnectedEndpoint(new(connectedEndpoint))
+
+ // Make sure that connectedEndpoint implements unix.Receiver.
+ _ = unix.Receiver(new(connectedEndpoint))
+)
+
+func getFl(fd int) (uint32, error) {
+ fl, _, err := syscall.RawSyscall(syscall.SYS_FCNTL, uintptr(fd), syscall.F_GETFL, 0)
+ if err == 0 {
+ return uint32(fl), nil
+ }
+ return 0, err
+}
+
+func TestSocketIsBlocking(t *testing.T) {
+ // Using socketpair here because it's already connected.
+ pair, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_STREAM, 0)
+ if err != nil {
+ t.Fatalf("host socket creation failed: %v", err)
+ }
+
+ fl, err := getFl(pair[0])
+ if err != nil {
+ t.Fatalf("getFl: fcntl(%v, GETFL) => %v", pair[0], err)
+ }
+ if fl&syscall.O_NONBLOCK == syscall.O_NONBLOCK {
+ t.Fatalf("Expected socket %v to be blocking", pair[0])
+ }
+ if fl, err = getFl(pair[1]); err != nil {
+ t.Fatalf("getFl: fcntl(%v, GETFL) => %v", pair[1], err)
+ }
+ if fl&syscall.O_NONBLOCK == syscall.O_NONBLOCK {
+ t.Fatalf("Expected socket %v to be blocking", pair[1])
+ }
+ sock, err := newSocket(contexttest.Context(t), pair[0], false)
+ if err != nil {
+ t.Fatalf("newSocket(%v) failed => %v", pair[0], err)
+ }
+ defer sock.DecRef()
+ // Test that the socket now is non blocking.
+ if fl, err = getFl(pair[0]); err != nil {
+ t.Fatalf("getFl: fcntl(%v, GETFL) => %v", pair[0], err)
+ }
+ if fl&syscall.O_NONBLOCK != syscall.O_NONBLOCK {
+ t.Errorf("Expected socket %v to have becoming non blocking", pair[0])
+ }
+ if fl, err = getFl(pair[1]); err != nil {
+ t.Fatalf("getFl: fcntl(%v, GETFL) => %v", pair[1], err)
+ }
+ if fl&syscall.O_NONBLOCK == syscall.O_NONBLOCK {
+ t.Errorf("Did not expect socket %v to become non blocking", pair[1])
+ }
+}
+
+func TestSocketWritev(t *testing.T) {
+ // Using socketpair here because it's already connected.
+ pair, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_STREAM, 0)
+ if err != nil {
+ t.Fatalf("host socket creation failed: %v", err)
+ }
+ socket, err := newSocket(contexttest.Context(t), pair[0], false)
+ if err != nil {
+ t.Fatalf("newSocket(%v) => %v", pair[0], err)
+ }
+ defer socket.DecRef()
+ buf := []byte("hello world\n")
+ n, err := socket.Writev(contexttest.Context(t), usermem.BytesIOSequence(buf))
+ if err != nil {
+ t.Fatalf("socket writev failed: %v", err)
+ }
+
+ if n != int64(len(buf)) {
+ t.Fatalf("socket writev wrote incorrect bytes: %d", n)
+ }
+}
+
+func TestSocketWritevLen0(t *testing.T) {
+ // Using socketpair here because it's already connected.
+ pair, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_STREAM, 0)
+ if err != nil {
+ t.Fatalf("host socket creation failed: %v", err)
+ }
+ socket, err := newSocket(contexttest.Context(t), pair[0], false)
+ if err != nil {
+ t.Fatalf("newSocket(%v) => %v", pair[0], err)
+ }
+ defer socket.DecRef()
+ n, err := socket.Writev(contexttest.Context(t), usermem.BytesIOSequence(nil))
+ if err != nil {
+ t.Fatalf("socket writev failed: %v", err)
+ }
+
+ if n != 0 {
+ t.Fatalf("socket writev wrote incorrect bytes: %d", n)
+ }
+}
+
+func TestSocketSendMsgLen0(t *testing.T) {
+ // Using socketpair here because it's already connected.
+ pair, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_STREAM, 0)
+ if err != nil {
+ t.Fatalf("host socket creation failed: %v", err)
+ }
+ sfile, err := newSocket(contexttest.Context(t), pair[0], false)
+ if err != nil {
+ t.Fatalf("newSocket(%v) => %v", pair[0], err)
+ }
+ defer sfile.DecRef()
+
+ s := sfile.FileOperations.(socket.Socket)
+ n, terr := s.SendMsg(nil, usermem.BytesIOSequence(nil), []byte{}, 0, unix.ControlMessages{})
+ if n != 0 {
+ t.Fatalf("socket sendmsg() failed: %v wrote: %d", terr, n)
+ }
+
+ if terr != nil {
+ t.Fatalf("socket sendmsg() failed: %v", terr)
+ }
+}
+
+func TestListen(t *testing.T) {
+ pair, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_STREAM, 0)
+ if err != nil {
+ t.Fatalf("syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM, 0) => %v", err)
+ }
+ sfile1, err := newSocket(contexttest.Context(t), pair[0], false)
+ if err != nil {
+ t.Fatalf("newSocket(%v) => %v", pair[0], err)
+ }
+ defer sfile1.DecRef()
+ socket1 := sfile1.FileOperations.(socket.Socket)
+
+ sfile2, err := newSocket(contexttest.Context(t), pair[1], false)
+ if err != nil {
+ t.Fatalf("newSocket(%v) => %v", pair[1], err)
+ }
+ defer sfile2.DecRef()
+ socket2 := sfile2.FileOperations.(socket.Socket)
+
+ // Socketpairs can not be listened to.
+ if err := socket1.Listen(nil, 64); err != syserr.ErrInvalidEndpointState {
+ t.Fatalf("socket1.Listen(nil, 64) => %v, want syserr.ErrInvalidEndpointState", err)
+ }
+ if err := socket2.Listen(nil, 64); err != syserr.ErrInvalidEndpointState {
+ t.Fatalf("socket2.Listen(nil, 64) => %v, want syserr.ErrInvalidEndpointState", err)
+ }
+
+ // Create a Unix socket, do not bind it.
+ sock, err := syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM, 0)
+ if err != nil {
+ t.Fatalf("syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM, 0) => %v", err)
+ }
+ sfile3, err := newSocket(contexttest.Context(t), sock, false)
+ if err != nil {
+ t.Fatalf("newSocket(%v) => %v", sock, err)
+ }
+ defer sfile3.DecRef()
+ socket3 := sfile3.FileOperations.(socket.Socket)
+
+ // This socket is not bound so we can't listen on it.
+ if err := socket3.Listen(nil, 64); err != syserr.ErrInvalidEndpointState {
+ t.Fatalf("socket3.Listen(nil, 64) => %v, want syserr.ErrInvalidEndpointState", err)
+ }
+}
+
+func TestSend(t *testing.T) {
+ e := connectedEndpoint{writeClosed: true}
+ if _, _, err := e.Send(nil, unix.ControlMessages{}, tcpip.FullAddress{}); err != tcpip.ErrClosedForSend {
+ t.Errorf("Got %#v.Send() = %v, want = %v", e, err, tcpip.ErrClosedForSend)
+ }
+}
+
+func TestRecv(t *testing.T) {
+ e := connectedEndpoint{readClosed: true}
+ if _, _, _, _, _, err := e.Recv(nil, false, 0, false); err != tcpip.ErrClosedForReceive {
+ t.Errorf("Got %#v.Recv() = %v, want = %v", e, err, tcpip.ErrClosedForReceive)
+ }
+}
+
+func TestPasscred(t *testing.T) {
+ e := connectedEndpoint{}
+ if got, want := e.Passcred(), false; got != want {
+ t.Errorf("Got %#v.Passcred() = %t, want = %t", e, got, want)
+ }
+}
+
+func TestGetLocalAddress(t *testing.T) {
+ e := connectedEndpoint{path: "foo"}
+ want := tcpip.FullAddress{Addr: tcpip.Address("foo")}
+ if got, err := e.GetLocalAddress(); err != nil || got != want {
+ t.Errorf("Got %#v.GetLocalAddress() = %#v, %v, want = %#v, %v", e, got, err, want, nil)
+ }
+}
+
+func TestQueuedSize(t *testing.T) {
+ e := connectedEndpoint{}
+ tests := []struct {
+ name string
+ f func() int64
+ }{
+ {"SendQueuedSize", e.SendQueuedSize},
+ {"RecvQueuedSize", e.RecvQueuedSize},
+ }
+
+ for _, test := range tests {
+ if got, want := test.f(), int64(-1); got != want {
+ t.Errorf("Got %#v.%s() = %d, want = %d", e, test.name, got, want)
+ }
+ }
+}
+
+func TestReadable(t *testing.T) {
+ e := connectedEndpoint{readClosed: true}
+ if got, want := e.Readable(), true; got != want {
+ t.Errorf("Got %#v.Readable() = %t, want = %t", e, got, want)
+ }
+}
+
+func TestWritable(t *testing.T) {
+ e := connectedEndpoint{writeClosed: true}
+ if got, want := e.Writable(), true; got != want {
+ t.Errorf("Got %#v.Writable() = %t, want = %t", e, got, want)
+ }
+}
+
+func TestRelease(t *testing.T) {
+ f, err := syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, 0)
+ if err != nil {
+ t.Fatal("Creating socket:", err)
+ }
+ c := &connectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f)}
+ want := &connectedEndpoint{queue: c.queue}
+ want.ref.DecRef()
+ fdnotifier.AddFD(int32(c.file.FD()), nil)
+ c.Release()
+ if !reflect.DeepEqual(c, want) {
+ t.Errorf("got = %#v, want = %#v", c, want)
+ }
+}
+
+func TestClose(t *testing.T) {
+ type testCase struct {
+ name string
+ cep *connectedEndpoint
+ addFD bool
+ f func()
+ want *connectedEndpoint
+ }
+
+ var tests []testCase
+
+ // nil is the value used by connectedEndpoint to indicate a closed file.
+ // Non-nil files are used to check if the file gets closed.
+
+ f, err := syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, 0)
+ if err != nil {
+ t.Fatal("Creating socket:", err)
+ }
+ c := &connectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f)}
+ tests = append(tests, testCase{
+ name: "First CloseRecv",
+ cep: c,
+ addFD: false,
+ f: c.CloseRecv,
+ want: &connectedEndpoint{queue: c.queue, file: c.file, readClosed: true},
+ })
+
+ f, err = syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, 0)
+ if err != nil {
+ t.Fatal("Creating socket:", err)
+ }
+ c = &connectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f), readClosed: true}
+ tests = append(tests, testCase{
+ name: "Second CloseRecv",
+ cep: c,
+ addFD: false,
+ f: c.CloseRecv,
+ want: &connectedEndpoint{queue: c.queue, file: c.file, readClosed: true},
+ })
+
+ f, err = syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, 0)
+ if err != nil {
+ t.Fatal("Creating socket:", err)
+ }
+ c = &connectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f)}
+ tests = append(tests, testCase{
+ name: "First CloseSend",
+ cep: c,
+ addFD: false,
+ f: c.CloseSend,
+ want: &connectedEndpoint{queue: c.queue, file: c.file, writeClosed: true},
+ })
+
+ f, err = syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, 0)
+ if err != nil {
+ t.Fatal("Creating socket:", err)
+ }
+ c = &connectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f), writeClosed: true}
+ tests = append(tests, testCase{
+ name: "Second CloseSend",
+ cep: c,
+ addFD: false,
+ f: c.CloseSend,
+ want: &connectedEndpoint{queue: c.queue, file: c.file, writeClosed: true},
+ })
+
+ f, err = syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, 0)
+ if err != nil {
+ t.Fatal("Creating socket:", err)
+ }
+ c = &connectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f), writeClosed: true}
+ tests = append(tests, testCase{
+ name: "CloseSend then CloseRecv",
+ cep: c,
+ addFD: true,
+ f: c.CloseRecv,
+ want: &connectedEndpoint{queue: c.queue, file: c.file, readClosed: true, writeClosed: true},
+ })
+
+ f, err = syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, 0)
+ if err != nil {
+ t.Fatal("Creating socket:", err)
+ }
+ c = &connectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f), readClosed: true}
+ tests = append(tests, testCase{
+ name: "CloseRecv then CloseSend",
+ cep: c,
+ addFD: true,
+ f: c.CloseSend,
+ want: &connectedEndpoint{queue: c.queue, file: c.file, readClosed: true, writeClosed: true},
+ })
+
+ f, err = syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, 0)
+ if err != nil {
+ t.Fatal("Creating socket:", err)
+ }
+ c = &connectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f), readClosed: true, writeClosed: true}
+ tests = append(tests, testCase{
+ name: "Full close then CloseRecv",
+ cep: c,
+ addFD: false,
+ f: c.CloseRecv,
+ want: &connectedEndpoint{queue: c.queue, file: c.file, readClosed: true, writeClosed: true},
+ })
+
+ f, err = syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, 0)
+ if err != nil {
+ t.Fatal("Creating socket:", err)
+ }
+ c = &connectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f), readClosed: true, writeClosed: true}
+ tests = append(tests, testCase{
+ name: "Full close then CloseSend",
+ cep: c,
+ addFD: false,
+ f: c.CloseSend,
+ want: &connectedEndpoint{queue: c.queue, file: c.file, readClosed: true, writeClosed: true},
+ })
+
+ for _, test := range tests {
+ if test.addFD {
+ fdnotifier.AddFD(int32(test.cep.file.FD()), nil)
+ }
+ if test.f(); !reflect.DeepEqual(test.cep, test.want) {
+ t.Errorf("%s: got = %#v, want = %#v", test.name, test.cep, test.want)
+ }
+ }
+}