diff options
Diffstat (limited to 'pkg/lisafs/sock_test.go')
-rw-r--r-- | pkg/lisafs/sock_test.go | 217 |
1 files changed, 217 insertions, 0 deletions
diff --git a/pkg/lisafs/sock_test.go b/pkg/lisafs/sock_test.go new file mode 100644 index 000000000..387f4b7a8 --- /dev/null +++ b/pkg/lisafs/sock_test.go @@ -0,0 +1,217 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package lisafs + +import ( + "bytes" + "math/rand" + "reflect" + "testing" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/marshal" + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/unet" +) + +func runSocketTest(t *testing.T, fun1 func(*sockCommunicator), fun2 func(*sockCommunicator)) { + sock1, sock2, err := unet.SocketPair(false) + if err != nil { + t.Fatalf("socketpair got err %v expected nil", err) + } + defer sock1.Close() + defer sock2.Close() + + var testWg sync.WaitGroup + testWg.Add(2) + + go func() { + fun1(newSockComm(sock1)) + testWg.Done() + }() + + go func() { + fun2(newSockComm(sock2)) + testWg.Done() + }() + + testWg.Wait() +} + +func TestReadWrite(t *testing.T) { + // Create random data to send. + n := 10000 + data := make([]byte, n) + if _, err := rand.Read(data); err != nil { + t.Fatalf("rand.Read(data) failed: %v", err) + } + + runSocketTest(t, func(comm *sockCommunicator) { + // Scatter that data into two parts using Iovecs while sending. + mid := n / 2 + if err := writeTo(comm.sock, [][]byte{data[:mid], data[mid:]}, n, nil); err != nil { + t.Errorf("writeTo socket failed: %v", err) + } + }, func(comm *sockCommunicator) { + gotData := make([]byte, n) + if _, err := readFrom(comm.sock, gotData, 0); err != nil { + t.Fatalf("reading from socket failed: %v", err) + } + + // Make sure we got the right data. + if res := bytes.Compare(data, gotData); res != 0 { + t.Errorf("data received differs from data sent, want = %v, got = %v", data, gotData) + } + }) +} + +func TestFDDonation(t *testing.T) { + n := 10 + data := make([]byte, n) + if _, err := rand.Read(data); err != nil { + t.Fatalf("rand.Read(data) failed: %v", err) + } + + // Try donating FDs to these files. + path1 := "/dev/null" + path2 := "/dev" + path3 := "/dev/random" + + runSocketTest(t, func(comm *sockCommunicator) { + devNullFD, err := unix.Open(path1, unix.O_RDONLY, 0) + defer unix.Close(devNullFD) + if err != nil { + t.Fatalf("open(%s) failed: %v", path1, err) + } + devFD, err := unix.Open(path2, unix.O_RDONLY, 0) + defer unix.Close(devFD) + if err != nil { + t.Fatalf("open(%s) failed: %v", path2, err) + } + devRandomFD, err := unix.Open(path3, unix.O_RDONLY, 0) + defer unix.Close(devRandomFD) + if err != nil { + t.Fatalf("open(%s) failed: %v", path2, err) + } + if err := writeTo(comm.sock, [][]byte{data}, n, []int{devNullFD, devFD, devRandomFD}); err != nil { + t.Errorf("writeTo socket failed: %v", err) + } + }, func(comm *sockCommunicator) { + gotData := make([]byte, n) + fds, err := readFrom(comm.sock, gotData, 3) + if err != nil { + t.Fatalf("reading from socket failed: %v", err) + } + defer closeFDs(fds[:]) + + if res := bytes.Compare(data, gotData); res != 0 { + t.Errorf("data received differs from data sent, want = %v, got = %v", data, gotData) + } + + if len(fds) != 3 { + t.Fatalf("wanted 3 FD, got %d", len(fds)) + } + + // Check that the FDs actually point to the correct file. + compareFDWithFile(t, fds[0], path1) + compareFDWithFile(t, fds[1], path2) + compareFDWithFile(t, fds[2], path3) + }) +} + +func compareFDWithFile(t *testing.T, fd int, path string) { + var want unix.Stat_t + if err := unix.Stat(path, &want); err != nil { + t.Fatalf("stat(%s) failed: %v", path, err) + } + + var got unix.Stat_t + if err := unix.Fstat(fd, &got); err != nil { + t.Fatalf("fstat on donated FD failed: %v", err) + } + + if got.Ino != want.Ino || got.Dev != want.Dev { + t.Errorf("FD does not point to %s, want = %+v, got = %+v", path, want, got) + } +} + +func testSndMsg(comm *sockCommunicator, m MID, msg marshal.Marshallable) error { + var payloadLen uint32 + if msg != nil { + payloadLen = uint32(msg.SizeBytes()) + msg.MarshalUnsafe(comm.PayloadBuf(payloadLen)) + } + return comm.sndPrepopulatedMsg(m, payloadLen, nil) +} + +func TestSndRcvMessage(t *testing.T) { + req := &MsgSimple{} + req.Randomize() + reqM := MID(1) + + // Create a massive random response. + var resp MsgDynamic + resp.Randomize(100) + respM := MID(2) + + runSocketTest(t, func(comm *sockCommunicator) { + if err := testSndMsg(comm, reqM, req); err != nil { + t.Errorf("writeMessageTo failed: %v", err) + } + checkMessageReceive(t, comm, respM, &resp) + }, func(comm *sockCommunicator) { + checkMessageReceive(t, comm, reqM, req) + if err := testSndMsg(comm, respM, &resp); err != nil { + t.Errorf("writeMessageTo failed: %v", err) + } + }) +} + +func TestSndRcvMessageNoPayload(t *testing.T) { + reqM := MID(1) + respM := MID(2) + runSocketTest(t, func(comm *sockCommunicator) { + if err := testSndMsg(comm, reqM, nil); err != nil { + t.Errorf("writeMessageTo failed: %v", err) + } + checkMessageReceive(t, comm, respM, nil) + }, func(comm *sockCommunicator) { + checkMessageReceive(t, comm, reqM, nil) + if err := testSndMsg(comm, respM, nil); err != nil { + t.Errorf("writeMessageTo failed: %v", err) + } + }) +} + +func checkMessageReceive(t *testing.T, comm *sockCommunicator, wantM MID, wantMsg marshal.Marshallable) { + gotM, payloadLen, err := comm.rcvMsg(0) + if err != nil { + t.Fatalf("readMessageFrom failed: %v", err) + } + if gotM != wantM { + t.Errorf("got incorrect message ID: got = %d, want = %d", gotM, wantM) + } + if wantMsg == nil { + if payloadLen != 0 { + t.Errorf("no payload expect but got %d bytes", payloadLen) + } + } else { + gotMsg := reflect.New(reflect.ValueOf(wantMsg).Elem().Type()).Interface().(marshal.Marshallable) + gotMsg.UnmarshalUnsafe(comm.PayloadBuf(payloadLen)) + if !reflect.DeepEqual(wantMsg, gotMsg) { + t.Errorf("msg differs: want = %+v, got = %+v", wantMsg, gotMsg) + } + } +} |