summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/tests
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/tests')
-rw-r--r--pkg/tcpip/tests/integration/link_resolution_test.go125
1 files changed, 125 insertions, 0 deletions
diff --git a/pkg/tcpip/tests/integration/link_resolution_test.go b/pkg/tcpip/tests/integration/link_resolution_test.go
index af32d3009..7afc4702c 100644
--- a/pkg/tcpip/tests/integration/link_resolution_test.go
+++ b/pkg/tcpip/tests/integration/link_resolution_test.go
@@ -23,6 +23,7 @@ import (
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/checker"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/pipe"
@@ -32,6 +33,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -456,3 +458,126 @@ func TestGetLinkAddress(t *testing.T) {
})
}
}
+
+func TestWritePacketsLinkResolution(t *testing.T) {
+ const (
+ host1NICID = 1
+ host2NICID = 4
+ )
+
+ tests := []struct {
+ name string
+ netProto tcpip.NetworkProtocolNumber
+ remoteAddr tcpip.Address
+ expectedWriteErr *tcpip.Error
+ }{
+ {
+ name: "IPv4",
+ netProto: ipv4.ProtocolNumber,
+ remoteAddr: ipv4Addr2.AddressWithPrefix.Address,
+ expectedWriteErr: nil,
+ },
+ {
+ name: "IPv6",
+ netProto: ipv6.ProtocolNumber,
+ remoteAddr: ipv6Addr2.AddressWithPrefix.Address,
+ expectedWriteErr: nil,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ stackOpts := stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
+ }
+
+ host1Stack, host2Stack := setupStack(t, stackOpts, host1NICID, host2NICID)
+
+ var serverWQ waiter.Queue
+ serverWE, serverCH := waiter.NewChannelEntry(nil)
+ serverWQ.EventRegister(&serverWE, waiter.EventIn)
+ serverEP, err := host2Stack.NewEndpoint(udp.ProtocolNumber, test.netProto, &serverWQ)
+ if err != nil {
+ t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, test.netProto, err)
+ }
+ defer serverEP.Close()
+
+ serverAddr := tcpip.FullAddress{Port: 1234}
+ if err := serverEP.Bind(serverAddr); err != nil {
+ t.Fatalf("serverEP.Bind(%#v): %s", serverAddr, err)
+ }
+
+ r, err := host1Stack.FindRoute(host1NICID, "", test.remoteAddr, test.netProto, false /* multicastLoop */)
+ if err != nil {
+ t.Fatalf("host1Stack.FindRoute(%d, '', %s, %d, false): %s", host1NICID, test.remoteAddr, test.netProto, err)
+ }
+ defer r.Release()
+
+ data := []byte{1, 2}
+ var pkts stack.PacketBufferList
+ for _, d := range data {
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: header.UDPMinimumSize + int(r.MaxHeaderLength()),
+ Data: buffer.View([]byte{d}).ToVectorisedView(),
+ })
+ pkt.TransportProtocolNumber = udp.ProtocolNumber
+ length := uint16(pkt.Size())
+ udpHdr := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize))
+ udpHdr.Encode(&header.UDPFields{
+ SrcPort: 5555,
+ DstPort: serverAddr.Port,
+ Length: length,
+ })
+ xsum := r.PseudoHeaderChecksum(udp.ProtocolNumber, length)
+ for _, v := range pkt.Data.Views() {
+ xsum = header.Checksum(v, xsum)
+ }
+ udpHdr.SetChecksum(^udpHdr.CalculateChecksum(xsum))
+
+ pkts.PushBack(pkt)
+ }
+
+ params := stack.NetworkHeaderParams{
+ Protocol: udp.ProtocolNumber,
+ TTL: 64,
+ TOS: stack.DefaultTOS,
+ }
+
+ if n, err := r.WritePackets(nil /* gso */, pkts, params); err != nil {
+ t.Fatalf("r.WritePackets(nil, %#v, _): %s", params, err)
+ } else if want := pkts.Len(); want != n {
+ t.Fatalf("got r.WritePackets(nil, %#v, _) = %d, want = %d", n, params, want)
+ }
+
+ var writer bytes.Buffer
+ count := 0
+ for {
+ var rOpts tcpip.ReadOptions
+ res, err := serverEP.Read(&writer, rOpts)
+ if err != nil {
+ if err == tcpip.ErrWouldBlock {
+ // Should not have anymore bytes to read after we read the sent
+ // number of bytes.
+ if count == len(data) {
+ break
+ }
+
+ <-serverCH
+ continue
+ }
+
+ t.Fatalf("serverEP.Read(_, %#v): %s", rOpts, err)
+ }
+ count += res.Count
+ }
+
+ if got, want := host2Stack.Stats().UDP.PacketsReceived.Value(), uint64(len(data)); got != want {
+ t.Errorf("got host2Stack.Stats().UDP.PacketsReceived.Value() = %d, want = %d", got, want)
+ }
+ if diff := cmp.Diff(data, writer.Bytes()); diff != "" {
+ t.Errorf("read bytes mismatch (-want +got):\n%s", diff)
+ }
+ })
+ }
+}