diff options
Diffstat (limited to 'pkg/tcpip/link/fdbased/endpoint_test.go')
-rw-r--r-- | pkg/tcpip/link/fdbased/endpoint_test.go | 336 |
1 files changed, 336 insertions, 0 deletions
diff --git a/pkg/tcpip/link/fdbased/endpoint_test.go b/pkg/tcpip/link/fdbased/endpoint_test.go new file mode 100644 index 000000000..f7bbb28e1 --- /dev/null +++ b/pkg/tcpip/link/fdbased/endpoint_test.go @@ -0,0 +1,336 @@ +// Copyright 2016 The Netstack Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package fdbased + +import ( + "fmt" + "math/rand" + "reflect" + "syscall" + "testing" + "time" + + "gvisor.googlesource.com/gvisor/pkg/tcpip" + "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" + "gvisor.googlesource.com/gvisor/pkg/tcpip/header" + "gvisor.googlesource.com/gvisor/pkg/tcpip/stack" +) + +type packetInfo struct { + raddr tcpip.LinkAddress + proto tcpip.NetworkProtocolNumber + contents buffer.View +} + +type context struct { + t *testing.T + fds [2]int + ep stack.LinkEndpoint + ch chan packetInfo + done chan struct{} +} + +func newContext(t *testing.T, opt *Options) *context { + fds, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_SEQPACKET, 0) + if err != nil { + t.Fatalf("Socketpair failed: %v", err) + } + + done := make(chan struct{}, 1) + opt.ClosedFunc = func(*tcpip.Error) { + done <- struct{}{} + } + + opt.FD = fds[1] + ep := stack.FindLinkEndpoint(New(opt)).(*endpoint) + + c := &context{ + t: t, + fds: fds, + ep: ep, + ch: make(chan packetInfo, 100), + done: done, + } + + ep.Attach(c) + + return c +} + +func (c *context) cleanup() { + syscall.Close(c.fds[0]) + <-c.done + syscall.Close(c.fds[1]) +} + +func (c *context) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remoteLinkAddr tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv *buffer.VectorisedView) { + c.ch <- packetInfo{remoteLinkAddr, protocol, vv.ToView()} +} + +func TestNoEthernetProperties(t *testing.T) { + const mtu = 1500 + c := newContext(t, &Options{MTU: mtu}) + defer c.cleanup() + + if want, v := uint16(0), c.ep.MaxHeaderLength(); want != v { + t.Fatalf("MaxHeaderLength() = %v, want %v", v, want) + } + + if want, v := uint32(mtu), c.ep.MTU(); want != v { + t.Fatalf("MTU() = %v, want %v", v, want) + } +} + +func TestEthernetProperties(t *testing.T) { + const mtu = 1500 + c := newContext(t, &Options{EthernetHeader: true, MTU: mtu}) + defer c.cleanup() + + if want, v := uint16(header.EthernetMinimumSize), c.ep.MaxHeaderLength(); want != v { + t.Fatalf("MaxHeaderLength() = %v, want %v", v, want) + } + + if want, v := uint32(mtu), c.ep.MTU(); want != v { + t.Fatalf("MTU() = %v, want %v", v, want) + } +} + +func TestAddress(t *testing.T) { + const mtu = 1500 + addrs := []tcpip.LinkAddress{"", "abc", "def"} + for _, a := range addrs { + t.Run(fmt.Sprintf("Address: %q", a), func(t *testing.T) { + c := newContext(t, &Options{Address: a, MTU: mtu}) + defer c.cleanup() + + if want, v := a, c.ep.LinkAddress(); want != v { + t.Fatalf("LinkAddress() = %v, want %v", v, want) + } + }) + } +} + +func TestWritePacket(t *testing.T) { + const ( + mtu = 1500 + laddr = tcpip.LinkAddress("\x11\x22\x33\x44\x55\x66") + raddr = tcpip.LinkAddress("\x77\x88\x99\xaa\xbb\xcc") + proto = 10 + ) + + lengths := []int{0, 100, 1000} + eths := []bool{true, false} + + for _, eth := range eths { + for _, plen := range lengths { + t.Run(fmt.Sprintf("Eth=%v,PayloadLen=%v", eth, plen), func(t *testing.T) { + c := newContext(t, &Options{Address: laddr, MTU: mtu, EthernetHeader: eth}) + defer c.cleanup() + + r := &stack.Route{ + RemoteLinkAddress: raddr, + } + + // Build header. + hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()) + 100) + b := hdr.Prepend(100) + for i := range b { + b[i] = uint8(rand.Intn(256)) + } + + // Buiild payload and write. + payload := make([]byte, plen) + for i := range payload { + payload[i] = uint8(rand.Intn(256)) + } + want := append(hdr.UsedBytes(), payload...) + if err := c.ep.WritePacket(r, &hdr, payload, proto); err != nil { + t.Fatalf("WritePacket failed: %v", err) + } + + // Read from fd, then compare with what we wrote. + b = make([]byte, mtu) + n, err := syscall.Read(c.fds[0], b) + if err != nil { + t.Fatalf("Read failed: %v", err) + } + b = b[:n] + if eth { + h := header.Ethernet(b) + b = b[header.EthernetMinimumSize:] + + if a := h.SourceAddress(); a != laddr { + t.Fatalf("SourceAddress() = %v, want %v", a, laddr) + } + + if a := h.DestinationAddress(); a != raddr { + t.Fatalf("DestinationAddress() = %v, want %v", a, raddr) + } + + if et := h.Type(); et != proto { + t.Fatalf("Type() = %v, want %v", et, proto) + } + } + if len(b) != len(want) { + t.Fatalf("Read returned %v bytes, want %v", len(b), len(want)) + } + if !reflect.DeepEqual(b, want) { + t.Fatalf("Read returned %x, want %x", b, want) + } + }) + } + } +} + +func TestDeliverPacket(t *testing.T) { + const ( + mtu = 1500 + laddr = tcpip.LinkAddress("\x11\x22\x33\x44\x55\x66") + raddr = tcpip.LinkAddress("\x77\x88\x99\xaa\xbb\xcc") + proto = 10 + ) + + lengths := []int{100, 1000} + eths := []bool{true, false} + + for _, eth := range eths { + for _, plen := range lengths { + t.Run(fmt.Sprintf("Eth=%v,PayloadLen=%v", eth, plen), func(t *testing.T) { + c := newContext(t, &Options{Address: laddr, MTU: mtu, EthernetHeader: eth}) + defer c.cleanup() + + // Build packet. + b := make([]byte, plen) + all := b + for i := range b { + b[i] = uint8(rand.Intn(256)) + } + + if !eth { + // So that it looks like an IPv4 packet. + b[0] = 0x40 + } else { + hdr := make(header.Ethernet, header.EthernetMinimumSize) + hdr.Encode(&header.EthernetFields{ + SrcAddr: raddr, + DstAddr: laddr, + Type: proto, + }) + all = append(hdr, b...) + } + + // Write packet via the file descriptor. + if _, err := syscall.Write(c.fds[0], all); err != nil { + t.Fatalf("Write failed: %v", err) + } + + // Receive packet through the endpoint. + select { + case pi := <-c.ch: + want := packetInfo{ + raddr: raddr, + proto: proto, + contents: b, + } + if !eth { + want.proto = header.IPv4ProtocolNumber + want.raddr = "" + } + if !reflect.DeepEqual(want, pi) { + t.Fatalf("Unexpected received packet: %+v, want %+v", pi, want) + } + case <-time.After(10 * time.Second): + t.Fatalf("Timed out waiting for packet") + } + }) + } + } +} + +func TestBufConfigMaxLength(t *testing.T) { + got := 0 + for _, i := range BufConfig { + got += i + } + want := header.MaxIPPacketSize // maximum TCP packet size + if got < want { + t.Errorf("total buffer size is invalid: got %d, want >= %d", got, want) + } +} + +func TestBufConfigFirst(t *testing.T) { + // The stack assumes that the TCP/IP header is enterily contained in the first view. + // Therefore, the first view needs to be large enough to contain the maximum TCP/IP + // header, which is 120 bytes (60 bytes for IP + 60 bytes for TCP). + want := 120 + got := BufConfig[0] + if got < want { + t.Errorf("first view has an invalid size: got %d, want >= %d", got, want) + } +} + +func build(bufConfig []int) *endpoint { + e := &endpoint{ + views: make([]buffer.View, len(bufConfig)), + iovecs: make([]syscall.Iovec, len(bufConfig)), + } + e.allocateViews(bufConfig) + return e +} + +var capLengthTestCases = []struct { + comment string + config []int + n int + wantUsed int + wantLengths []int +}{ + { + comment: "Single slice", + config: []int{2}, + n: 1, + wantUsed: 1, + wantLengths: []int{1}, + }, + { + comment: "Multiple slices", + config: []int{1, 2}, + n: 2, + wantUsed: 2, + wantLengths: []int{1, 1}, + }, + { + comment: "Entire buffer", + config: []int{1, 2}, + n: 3, + wantUsed: 2, + wantLengths: []int{1, 2}, + }, + { + comment: "Entire buffer but not on the last slice", + config: []int{1, 2, 3}, + n: 3, + wantUsed: 2, + wantLengths: []int{1, 2, 3}, + }, +} + +func TestCapLength(t *testing.T) { + for _, c := range capLengthTestCases { + e := build(c.config) + used := e.capViews(c.n, c.config) + if used != c.wantUsed { + t.Errorf("Test \"%s\" failed when calling capViews(%d, %v). Got %d. Want %d", c.comment, c.n, c.config, used, c.wantUsed) + } + lengths := make([]int, len(e.views)) + for i, v := range e.views { + lengths[i] = len(v) + } + if !reflect.DeepEqual(lengths, c.wantLengths) { + t.Errorf("Test \"%s\" failed when calling capViews(%d, %v). Got %v. Want %v", c.comment, c.n, c.config, lengths, c.wantLengths) + } + + } +} |