summaryrefslogtreecommitdiffhomepage
path: root/tun/offload_linux_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'tun/offload_linux_test.go')
-rw-r--r--tun/offload_linux_test.go752
1 files changed, 752 insertions, 0 deletions
diff --git a/tun/offload_linux_test.go b/tun/offload_linux_test.go
new file mode 100644
index 0000000..ae55c8c
--- /dev/null
+++ b/tun/offload_linux_test.go
@@ -0,0 +1,752 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package tun
+
+import (
+ "net/netip"
+ "testing"
+
+ "golang.org/x/sys/unix"
+ "golang.zx2c4.com/wireguard/conn"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+)
+
+const (
+ offset = virtioNetHdrLen
+)
+
+var (
+ ip4PortA = netip.MustParseAddrPort("192.0.2.1:1")
+ ip4PortB = netip.MustParseAddrPort("192.0.2.2:1")
+ ip4PortC = netip.MustParseAddrPort("192.0.2.3:1")
+ ip6PortA = netip.MustParseAddrPort("[2001:db8::1]:1")
+ ip6PortB = netip.MustParseAddrPort("[2001:db8::2]:1")
+ ip6PortC = netip.MustParseAddrPort("[2001:db8::3]:1")
+)
+
+func udp4PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, payloadLen int, ipFn func(*header.IPv4Fields)) []byte {
+ totalLen := 28 + payloadLen
+ b := make([]byte, offset+int(totalLen), 65535)
+ ipv4H := header.IPv4(b[offset:])
+ srcAs4 := srcIPPort.Addr().As4()
+ dstAs4 := dstIPPort.Addr().As4()
+ ipFields := &header.IPv4Fields{
+ SrcAddr: tcpip.AddrFromSlice(srcAs4[:]),
+ DstAddr: tcpip.AddrFromSlice(dstAs4[:]),
+ Protocol: unix.IPPROTO_UDP,
+ TTL: 64,
+ TotalLength: uint16(totalLen),
+ }
+ if ipFn != nil {
+ ipFn(ipFields)
+ }
+ ipv4H.Encode(ipFields)
+ udpH := header.UDP(b[offset+20:])
+ udpH.Encode(&header.UDPFields{
+ SrcPort: srcIPPort.Port(),
+ DstPort: dstIPPort.Port(),
+ Length: uint16(payloadLen + udphLen),
+ })
+ ipv4H.SetChecksum(^ipv4H.CalculateChecksum())
+ pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_UDP, ipv4H.SourceAddress(), ipv4H.DestinationAddress(), uint16(udphLen+payloadLen))
+ udpH.SetChecksum(^udpH.CalculateChecksum(pseudoCsum))
+ return b
+}
+
+func udp6Packet(srcIPPort, dstIPPort netip.AddrPort, payloadLen int) []byte {
+ return udp6PacketMutateIPFields(srcIPPort, dstIPPort, payloadLen, nil)
+}
+
+func udp6PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, payloadLen int, ipFn func(*header.IPv6Fields)) []byte {
+ totalLen := 48 + payloadLen
+ b := make([]byte, offset+int(totalLen), 65535)
+ ipv6H := header.IPv6(b[offset:])
+ srcAs16 := srcIPPort.Addr().As16()
+ dstAs16 := dstIPPort.Addr().As16()
+ ipFields := &header.IPv6Fields{
+ SrcAddr: tcpip.AddrFromSlice(srcAs16[:]),
+ DstAddr: tcpip.AddrFromSlice(dstAs16[:]),
+ TransportProtocol: unix.IPPROTO_UDP,
+ HopLimit: 64,
+ PayloadLength: uint16(payloadLen + udphLen),
+ }
+ if ipFn != nil {
+ ipFn(ipFields)
+ }
+ ipv6H.Encode(ipFields)
+ udpH := header.UDP(b[offset+40:])
+ udpH.Encode(&header.UDPFields{
+ SrcPort: srcIPPort.Port(),
+ DstPort: dstIPPort.Port(),
+ Length: uint16(payloadLen + udphLen),
+ })
+ pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_UDP, ipv6H.SourceAddress(), ipv6H.DestinationAddress(), uint16(udphLen+payloadLen))
+ udpH.SetChecksum(^udpH.CalculateChecksum(pseudoCsum))
+ return b
+}
+
+func udp4Packet(srcIPPort, dstIPPort netip.AddrPort, payloadLen int) []byte {
+ return udp4PacketMutateIPFields(srcIPPort, dstIPPort, payloadLen, nil)
+}
+
+func tcp4PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32, ipFn func(*header.IPv4Fields)) []byte {
+ totalLen := 40 + segmentSize
+ b := make([]byte, offset+int(totalLen), 65535)
+ ipv4H := header.IPv4(b[offset:])
+ srcAs4 := srcIPPort.Addr().As4()
+ dstAs4 := dstIPPort.Addr().As4()
+ ipFields := &header.IPv4Fields{
+ SrcAddr: tcpip.AddrFromSlice(srcAs4[:]),
+ DstAddr: tcpip.AddrFromSlice(dstAs4[:]),
+ Protocol: unix.IPPROTO_TCP,
+ TTL: 64,
+ TotalLength: uint16(totalLen),
+ }
+ if ipFn != nil {
+ ipFn(ipFields)
+ }
+ ipv4H.Encode(ipFields)
+ tcpH := header.TCP(b[offset+20:])
+ tcpH.Encode(&header.TCPFields{
+ SrcPort: srcIPPort.Port(),
+ DstPort: dstIPPort.Port(),
+ SeqNum: seq,
+ AckNum: 1,
+ DataOffset: 20,
+ Flags: flags,
+ WindowSize: 3000,
+ })
+ ipv4H.SetChecksum(^ipv4H.CalculateChecksum())
+ pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_TCP, ipv4H.SourceAddress(), ipv4H.DestinationAddress(), uint16(20+segmentSize))
+ tcpH.SetChecksum(^tcpH.CalculateChecksum(pseudoCsum))
+ return b
+}
+
+func tcp4Packet(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32) []byte {
+ return tcp4PacketMutateIPFields(srcIPPort, dstIPPort, flags, segmentSize, seq, nil)
+}
+
+func tcp6PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32, ipFn func(*header.IPv6Fields)) []byte {
+ totalLen := 60 + segmentSize
+ b := make([]byte, offset+int(totalLen), 65535)
+ ipv6H := header.IPv6(b[offset:])
+ srcAs16 := srcIPPort.Addr().As16()
+ dstAs16 := dstIPPort.Addr().As16()
+ ipFields := &header.IPv6Fields{
+ SrcAddr: tcpip.AddrFromSlice(srcAs16[:]),
+ DstAddr: tcpip.AddrFromSlice(dstAs16[:]),
+ TransportProtocol: unix.IPPROTO_TCP,
+ HopLimit: 64,
+ PayloadLength: uint16(segmentSize + 20),
+ }
+ if ipFn != nil {
+ ipFn(ipFields)
+ }
+ ipv6H.Encode(ipFields)
+ tcpH := header.TCP(b[offset+40:])
+ tcpH.Encode(&header.TCPFields{
+ SrcPort: srcIPPort.Port(),
+ DstPort: dstIPPort.Port(),
+ SeqNum: seq,
+ AckNum: 1,
+ DataOffset: 20,
+ Flags: flags,
+ WindowSize: 3000,
+ })
+ pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_TCP, ipv6H.SourceAddress(), ipv6H.DestinationAddress(), uint16(20+segmentSize))
+ tcpH.SetChecksum(^tcpH.CalculateChecksum(pseudoCsum))
+ return b
+}
+
+func tcp6Packet(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32) []byte {
+ return tcp6PacketMutateIPFields(srcIPPort, dstIPPort, flags, segmentSize, seq, nil)
+}
+
+func Test_handleVirtioRead(t *testing.T) {
+ tests := []struct {
+ name string
+ hdr virtioNetHdr
+ pktIn []byte
+ wantLens []int
+ wantErr bool
+ }{
+ {
+ "tcp4",
+ virtioNetHdr{
+ flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
+ gsoType: unix.VIRTIO_NET_HDR_GSO_TCPV4,
+ gsoSize: 100,
+ hdrLen: 40,
+ csumStart: 20,
+ csumOffset: 16,
+ },
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck|header.TCPFlagPsh, 200, 1),
+ []int{140, 140},
+ false,
+ },
+ {
+ "tcp6",
+ virtioNetHdr{
+ flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
+ gsoType: unix.VIRTIO_NET_HDR_GSO_TCPV6,
+ gsoSize: 100,
+ hdrLen: 60,
+ csumStart: 40,
+ csumOffset: 16,
+ },
+ tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck|header.TCPFlagPsh, 200, 1),
+ []int{160, 160},
+ false,
+ },
+ {
+ "udp4",
+ virtioNetHdr{
+ flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
+ gsoType: unix.VIRTIO_NET_HDR_GSO_UDP_L4,
+ gsoSize: 100,
+ hdrLen: 28,
+ csumStart: 20,
+ csumOffset: 6,
+ },
+ udp4Packet(ip4PortA, ip4PortB, 200),
+ []int{128, 128},
+ false,
+ },
+ {
+ "udp6",
+ virtioNetHdr{
+ flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
+ gsoType: unix.VIRTIO_NET_HDR_GSO_UDP_L4,
+ gsoSize: 100,
+ hdrLen: 48,
+ csumStart: 40,
+ csumOffset: 6,
+ },
+ udp6Packet(ip6PortA, ip6PortB, 200),
+ []int{148, 148},
+ false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ out := make([][]byte, conn.IdealBatchSize)
+ sizes := make([]int, conn.IdealBatchSize)
+ for i := range out {
+ out[i] = make([]byte, 65535)
+ }
+ tt.hdr.encode(tt.pktIn)
+ n, err := handleVirtioRead(tt.pktIn, out, sizes, offset)
+ if err != nil {
+ if tt.wantErr {
+ return
+ }
+ t.Fatalf("got err: %v", err)
+ }
+ if n != len(tt.wantLens) {
+ t.Fatalf("got %d packets, wanted %d", n, len(tt.wantLens))
+ }
+ for i := range tt.wantLens {
+ if tt.wantLens[i] != sizes[i] {
+ t.Fatalf("wantLens[%d]: %d != outSizes: %d", i, tt.wantLens[i], sizes[i])
+ }
+ }
+ })
+ }
+}
+
+func flipTCP4Checksum(b []byte) []byte {
+ at := virtioNetHdrLen + 20 + 16 // 20 byte ipv4 header; tcp csum offset is 16
+ b[at] ^= 0xFF
+ b[at+1] ^= 0xFF
+ return b
+}
+
+func flipUDP4Checksum(b []byte) []byte {
+ at := virtioNetHdrLen + 20 + 6 // 20 byte ipv4 header; udp csum offset is 6
+ b[at] ^= 0xFF
+ b[at+1] ^= 0xFF
+ return b
+}
+
+func Fuzz_handleGRO(f *testing.F) {
+ pkt0 := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1)
+ pkt1 := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101)
+ pkt2 := tcp4Packet(ip4PortA, ip4PortC, header.TCPFlagAck, 100, 201)
+ pkt3 := tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1)
+ pkt4 := tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101)
+ pkt5 := tcp6Packet(ip6PortA, ip6PortC, header.TCPFlagAck, 100, 201)
+ pkt6 := udp4Packet(ip4PortA, ip4PortB, 100)
+ pkt7 := udp4Packet(ip4PortA, ip4PortB, 100)
+ pkt8 := udp4Packet(ip4PortA, ip4PortC, 100)
+ pkt9 := udp6Packet(ip6PortA, ip6PortB, 100)
+ pkt10 := udp6Packet(ip6PortA, ip6PortB, 100)
+ pkt11 := udp6Packet(ip6PortA, ip6PortC, 100)
+ f.Add(pkt0, pkt1, pkt2, pkt3, pkt4, pkt5, pkt6, pkt7, pkt8, pkt9, pkt10, pkt11, true, offset)
+ f.Fuzz(func(t *testing.T, pkt0, pkt1, pkt2, pkt3, pkt4, pkt5, pkt6, pkt7, pkt8, pkt9, pkt10, pkt11 []byte, canUDPGRO bool, offset int) {
+ pkts := [][]byte{pkt0, pkt1, pkt2, pkt3, pkt4, pkt5, pkt6, pkt7, pkt8, pkt9, pkt10, pkt11}
+ toWrite := make([]int, 0, len(pkts))
+ handleGRO(pkts, offset, newTCPGROTable(), newUDPGROTable(), canUDPGRO, &toWrite)
+ if len(toWrite) > len(pkts) {
+ t.Errorf("len(toWrite): %d > len(pkts): %d", len(toWrite), len(pkts))
+ }
+ seenWriteI := make(map[int]bool)
+ for _, writeI := range toWrite {
+ if writeI < 0 || writeI > len(pkts)-1 {
+ t.Errorf("toWrite value (%d) outside bounds of len(pkts): %d", writeI, len(pkts))
+ }
+ if seenWriteI[writeI] {
+ t.Errorf("duplicate toWrite value: %d", writeI)
+ }
+ seenWriteI[writeI] = true
+ }
+ })
+}
+
+func Test_handleGRO(t *testing.T) {
+ tests := []struct {
+ name string
+ pktsIn [][]byte
+ canUDPGRO bool
+ wantToWrite []int
+ wantLens []int
+ wantErr bool
+ }{
+ {
+ "multiple protocols and flows",
+ [][]byte{
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // tcp4 flow 1
+ udp4Packet(ip4PortA, ip4PortB, 100), // udp4 flow 1
+ udp4Packet(ip4PortA, ip4PortC, 100), // udp4 flow 2
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // tcp4 flow 1
+ tcp4Packet(ip4PortA, ip4PortC, header.TCPFlagAck, 100, 201), // tcp4 flow 2
+ tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), // tcp6 flow 1
+ tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101), // tcp6 flow 1
+ tcp6Packet(ip6PortA, ip6PortC, header.TCPFlagAck, 100, 201), // tcp6 flow 2
+ udp4Packet(ip4PortA, ip4PortB, 100), // udp4 flow 1
+ udp6Packet(ip6PortA, ip6PortB, 100), // udp6 flow 1
+ udp6Packet(ip6PortA, ip6PortB, 100), // udp6 flow 1
+ },
+ true,
+ []int{0, 1, 2, 4, 5, 7, 9},
+ []int{240, 228, 128, 140, 260, 160, 248},
+ false,
+ },
+ {
+ "multiple protocols and flows no UDP GRO",
+ [][]byte{
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // tcp4 flow 1
+ udp4Packet(ip4PortA, ip4PortB, 100), // udp4 flow 1
+ udp4Packet(ip4PortA, ip4PortC, 100), // udp4 flow 2
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // tcp4 flow 1
+ tcp4Packet(ip4PortA, ip4PortC, header.TCPFlagAck, 100, 201), // tcp4 flow 2
+ tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), // tcp6 flow 1
+ tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101), // tcp6 flow 1
+ tcp6Packet(ip6PortA, ip6PortC, header.TCPFlagAck, 100, 201), // tcp6 flow 2
+ udp4Packet(ip4PortA, ip4PortB, 100), // udp4 flow 1
+ udp6Packet(ip6PortA, ip6PortB, 100), // udp6 flow 1
+ udp6Packet(ip6PortA, ip6PortB, 100), // udp6 flow 1
+ },
+ false,
+ []int{0, 1, 2, 4, 5, 7, 8, 9, 10},
+ []int{240, 128, 128, 140, 260, 160, 128, 148, 148},
+ false,
+ },
+ {
+ "PSH interleaved",
+ [][]byte{
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // v4 flow 1
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck|header.TCPFlagPsh, 100, 101), // v4 flow 1
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 301), // v4 flow 1
+ tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), // v6 flow 1
+ tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck|header.TCPFlagPsh, 100, 101), // v6 flow 1
+ tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 201), // v6 flow 1
+ tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 301), // v6 flow 1
+ },
+ true,
+ []int{0, 2, 4, 6},
+ []int{240, 240, 260, 260},
+ false,
+ },
+ {
+ "coalesceItemInvalidCSum",
+ [][]byte{
+ flipTCP4Checksum(tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1)), // v4 flow 1 seq 1 len 100
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // v4 flow 1 seq 101 len 100
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1 seq 201 len 100
+ flipUDP4Checksum(udp4Packet(ip4PortA, ip4PortB, 100)),
+ udp4Packet(ip4PortA, ip4PortB, 100),
+ udp4Packet(ip4PortA, ip4PortB, 100),
+ },
+ true,
+ []int{0, 1, 3, 4},
+ []int{140, 240, 128, 228},
+ false,
+ },
+ {
+ "out of order",
+ [][]byte{
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // v4 flow 1 seq 101 len 100
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // v4 flow 1 seq 1 len 100
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1 seq 201 len 100
+ },
+ true,
+ []int{0},
+ []int{340},
+ false,
+ },
+ {
+ "unequal TTL",
+ [][]byte{
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1),
+ tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) {
+ fields.TTL++
+ }),
+ udp4Packet(ip4PortA, ip4PortB, 100),
+ udp4PacketMutateIPFields(ip4PortA, ip4PortB, 100, func(fields *header.IPv4Fields) {
+ fields.TTL++
+ }),
+ },
+ true,
+ []int{0, 1, 2, 3},
+ []int{140, 140, 128, 128},
+ false,
+ },
+ {
+ "unequal ToS",
+ [][]byte{
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1),
+ tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) {
+ fields.TOS++
+ }),
+ udp4Packet(ip4PortA, ip4PortB, 100),
+ udp4PacketMutateIPFields(ip4PortA, ip4PortB, 100, func(fields *header.IPv4Fields) {
+ fields.TOS++
+ }),
+ },
+ true,
+ []int{0, 1, 2, 3},
+ []int{140, 140, 128, 128},
+ false,
+ },
+ {
+ "unequal flags more fragments set",
+ [][]byte{
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1),
+ tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) {
+ fields.Flags = 1
+ }),
+ udp4Packet(ip4PortA, ip4PortB, 100),
+ udp4PacketMutateIPFields(ip4PortA, ip4PortB, 100, func(fields *header.IPv4Fields) {
+ fields.Flags = 1
+ }),
+ },
+ true,
+ []int{0, 1, 2, 3},
+ []int{140, 140, 128, 128},
+ false,
+ },
+ {
+ "unequal flags DF set",
+ [][]byte{
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1),
+ tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) {
+ fields.Flags = 2
+ }),
+ udp4Packet(ip4PortA, ip4PortB, 100),
+ udp4PacketMutateIPFields(ip4PortA, ip4PortB, 100, func(fields *header.IPv4Fields) {
+ fields.Flags = 2
+ }),
+ },
+ true,
+ []int{0, 1, 2, 3},
+ []int{140, 140, 128, 128},
+ false,
+ },
+ {
+ "ipv6 unequal hop limit",
+ [][]byte{
+ tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1),
+ tcp6PacketMutateIPFields(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv6Fields) {
+ fields.HopLimit++
+ }),
+ udp6Packet(ip6PortA, ip6PortB, 100),
+ udp6PacketMutateIPFields(ip6PortA, ip6PortB, 100, func(fields *header.IPv6Fields) {
+ fields.HopLimit++
+ }),
+ },
+ true,
+ []int{0, 1, 2, 3},
+ []int{160, 160, 148, 148},
+ false,
+ },
+ {
+ "ipv6 unequal traffic class",
+ [][]byte{
+ tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1),
+ tcp6PacketMutateIPFields(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv6Fields) {
+ fields.TrafficClass++
+ }),
+ udp6Packet(ip6PortA, ip6PortB, 100),
+ udp6PacketMutateIPFields(ip6PortA, ip6PortB, 100, func(fields *header.IPv6Fields) {
+ fields.TrafficClass++
+ }),
+ },
+ true,
+ []int{0, 1, 2, 3},
+ []int{160, 160, 148, 148},
+ false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ toWrite := make([]int, 0, len(tt.pktsIn))
+ err := handleGRO(tt.pktsIn, offset, newTCPGROTable(), newUDPGROTable(), tt.canUDPGRO, &toWrite)
+ if err != nil {
+ if tt.wantErr {
+ return
+ }
+ t.Fatalf("got err: %v", err)
+ }
+ if len(toWrite) != len(tt.wantToWrite) {
+ t.Fatalf("got %d packets, wanted %d", len(toWrite), len(tt.wantToWrite))
+ }
+ for i, pktI := range tt.wantToWrite {
+ if tt.wantToWrite[i] != toWrite[i] {
+ t.Fatalf("wantToWrite[%d]: %d != toWrite: %d", i, tt.wantToWrite[i], toWrite[i])
+ }
+ if tt.wantLens[i] != len(tt.pktsIn[pktI][offset:]) {
+ t.Errorf("wanted len %d packet at %d, got: %d", tt.wantLens[i], i, len(tt.pktsIn[pktI][offset:]))
+ }
+ }
+ })
+ }
+}
+
+func Test_packetIsGROCandidate(t *testing.T) {
+ tcp4 := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1)[virtioNetHdrLen:]
+ tcp4TooShort := tcp4[:39]
+ ip4InvalidHeaderLen := make([]byte, len(tcp4))
+ copy(ip4InvalidHeaderLen, tcp4)
+ ip4InvalidHeaderLen[0] = 0x46
+ ip4InvalidProtocol := make([]byte, len(tcp4))
+ copy(ip4InvalidProtocol, tcp4)
+ ip4InvalidProtocol[9] = unix.IPPROTO_GRE
+
+ tcp6 := tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1)[virtioNetHdrLen:]
+ tcp6TooShort := tcp6[:59]
+ ip6InvalidProtocol := make([]byte, len(tcp6))
+ copy(ip6InvalidProtocol, tcp6)
+ ip6InvalidProtocol[6] = unix.IPPROTO_GRE
+
+ udp4 := udp4Packet(ip4PortA, ip4PortB, 100)[virtioNetHdrLen:]
+ udp4TooShort := udp4[:27]
+
+ udp6 := udp6Packet(ip6PortA, ip6PortB, 100)[virtioNetHdrLen:]
+ udp6TooShort := udp6[:47]
+
+ tests := []struct {
+ name string
+ b []byte
+ canUDPGRO bool
+ want groCandidateType
+ }{
+ {
+ "tcp4",
+ tcp4,
+ true,
+ tcp4GROCandidate,
+ },
+ {
+ "tcp6",
+ tcp6,
+ true,
+ tcp6GROCandidate,
+ },
+ {
+ "udp4",
+ udp4,
+ true,
+ udp4GROCandidate,
+ },
+ {
+ "udp4 no support",
+ udp4,
+ false,
+ notGROCandidate,
+ },
+ {
+ "udp6",
+ udp6,
+ true,
+ udp6GROCandidate,
+ },
+ {
+ "udp6 no support",
+ udp6,
+ false,
+ notGROCandidate,
+ },
+ {
+ "udp4 too short",
+ udp4TooShort,
+ true,
+ notGROCandidate,
+ },
+ {
+ "udp6 too short",
+ udp6TooShort,
+ true,
+ notGROCandidate,
+ },
+ {
+ "tcp4 too short",
+ tcp4TooShort,
+ true,
+ notGROCandidate,
+ },
+ {
+ "tcp6 too short",
+ tcp6TooShort,
+ true,
+ notGROCandidate,
+ },
+ {
+ "invalid IP version",
+ []byte{0x00},
+ true,
+ notGROCandidate,
+ },
+ {
+ "invalid IP header len",
+ ip4InvalidHeaderLen,
+ true,
+ notGROCandidate,
+ },
+ {
+ "ip4 invalid protocol",
+ ip4InvalidProtocol,
+ true,
+ notGROCandidate,
+ },
+ {
+ "ip6 invalid protocol",
+ ip6InvalidProtocol,
+ true,
+ notGROCandidate,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ if got := packetIsGROCandidate(tt.b, tt.canUDPGRO); got != tt.want {
+ t.Errorf("packetIsGROCandidate() = %v, want %v", got, tt.want)
+ }
+ })
+ }
+}
+
+func Test_udpPacketsCanCoalesce(t *testing.T) {
+ udp4a := udp4Packet(ip4PortA, ip4PortB, 100)
+ udp4b := udp4Packet(ip4PortA, ip4PortB, 100)
+ udp4c := udp4Packet(ip4PortA, ip4PortB, 110)
+
+ type args struct {
+ pkt []byte
+ iphLen uint8
+ gsoSize uint16
+ item udpGROItem
+ bufs [][]byte
+ bufsOffset int
+ }
+ tests := []struct {
+ name string
+ args args
+ want canCoalesce
+ }{
+ {
+ "coalesceAppend equal gso",
+ args{
+ pkt: udp4a[offset:],
+ iphLen: 20,
+ gsoSize: 100,
+ item: udpGROItem{
+ gsoSize: 100,
+ iphLen: 20,
+ },
+ bufs: [][]byte{
+ udp4a,
+ udp4b,
+ },
+ bufsOffset: offset,
+ },
+ coalesceAppend,
+ },
+ {
+ "coalesceAppend smaller gso",
+ args{
+ pkt: udp4a[offset : len(udp4a)-90],
+ iphLen: 20,
+ gsoSize: 10,
+ item: udpGROItem{
+ gsoSize: 100,
+ iphLen: 20,
+ },
+ bufs: [][]byte{
+ udp4a,
+ udp4b,
+ },
+ bufsOffset: offset,
+ },
+ coalesceAppend,
+ },
+ {
+ "coalesceUnavailable smaller gso previously appended",
+ args{
+ pkt: udp4a[offset:],
+ iphLen: 20,
+ gsoSize: 100,
+ item: udpGROItem{
+ gsoSize: 100,
+ iphLen: 20,
+ },
+ bufs: [][]byte{
+ udp4c,
+ udp4b,
+ },
+ bufsOffset: offset,
+ },
+ coalesceUnavailable,
+ },
+ {
+ "coalesceUnavailable larger following smaller",
+ args{
+ pkt: udp4c[offset:],
+ iphLen: 20,
+ gsoSize: 110,
+ item: udpGROItem{
+ gsoSize: 100,
+ iphLen: 20,
+ },
+ bufs: [][]byte{
+ udp4a,
+ udp4c,
+ },
+ bufsOffset: offset,
+ },
+ coalesceUnavailable,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ if got := udpPacketsCanCoalesce(tt.args.pkt, tt.args.iphLen, tt.args.gsoSize, tt.args.item, tt.args.bufs, tt.args.bufsOffset); got != tt.want {
+ t.Errorf("udpPacketsCanCoalesce() = %v, want %v", got, tt.want)
+ }
+ })
+ }
+}