diff options
Diffstat (limited to 'tun/tcp_offload_linux_test.go')
-rw-r--r-- | tun/tcp_offload_linux_test.go | 411 |
1 files changed, 0 insertions, 411 deletions
diff --git a/tun/tcp_offload_linux_test.go b/tun/tcp_offload_linux_test.go deleted file mode 100644 index ddddc48..0000000 --- a/tun/tcp_offload_linux_test.go +++ /dev/null @@ -1,411 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. - */ - -package tun - -import ( - "net/netip" - "testing" - - "golang.org/x/sys/unix" - "golang.zx2c4.com/wireguard/conn" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/header" -) - -const ( - offset = virtioNetHdrLen -) - -var ( - ip4PortA = netip.MustParseAddrPort("192.0.2.1:1") - ip4PortB = netip.MustParseAddrPort("192.0.2.2:1") - ip4PortC = netip.MustParseAddrPort("192.0.2.3:1") - ip6PortA = netip.MustParseAddrPort("[2001:db8::1]:1") - ip6PortB = netip.MustParseAddrPort("[2001:db8::2]:1") - ip6PortC = netip.MustParseAddrPort("[2001:db8::3]:1") -) - -func tcp4PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32, ipFn func(*header.IPv4Fields)) []byte { - totalLen := 40 + segmentSize - b := make([]byte, offset+int(totalLen), 65535) - ipv4H := header.IPv4(b[offset:]) - srcAs4 := srcIPPort.Addr().As4() - dstAs4 := dstIPPort.Addr().As4() - ipFields := &header.IPv4Fields{ - SrcAddr: tcpip.AddrFromSlice(srcAs4[:]), - DstAddr: tcpip.AddrFromSlice(dstAs4[:]), - Protocol: unix.IPPROTO_TCP, - TTL: 64, - TotalLength: uint16(totalLen), - } - if ipFn != nil { - ipFn(ipFields) - } - ipv4H.Encode(ipFields) - tcpH := header.TCP(b[offset+20:]) - tcpH.Encode(&header.TCPFields{ - SrcPort: srcIPPort.Port(), - DstPort: dstIPPort.Port(), - SeqNum: seq, - AckNum: 1, - DataOffset: 20, - Flags: flags, - WindowSize: 3000, - }) - ipv4H.SetChecksum(^ipv4H.CalculateChecksum()) - pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_TCP, ipv4H.SourceAddress(), ipv4H.DestinationAddress(), uint16(20+segmentSize)) - tcpH.SetChecksum(^tcpH.CalculateChecksum(pseudoCsum)) - return b -} - -func tcp4Packet(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32) []byte { - return tcp4PacketMutateIPFields(srcIPPort, dstIPPort, flags, segmentSize, seq, nil) -} - -func tcp6PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32, ipFn func(*header.IPv6Fields)) []byte { - totalLen := 60 + segmentSize - b := make([]byte, offset+int(totalLen), 65535) - ipv6H := header.IPv6(b[offset:]) - srcAs16 := srcIPPort.Addr().As16() - dstAs16 := dstIPPort.Addr().As16() - ipFields := &header.IPv6Fields{ - SrcAddr: tcpip.AddrFromSlice(srcAs16[:]), - DstAddr: tcpip.AddrFromSlice(dstAs16[:]), - TransportProtocol: unix.IPPROTO_TCP, - HopLimit: 64, - PayloadLength: uint16(segmentSize + 20), - } - if ipFn != nil { - ipFn(ipFields) - } - ipv6H.Encode(ipFields) - tcpH := header.TCP(b[offset+40:]) - tcpH.Encode(&header.TCPFields{ - SrcPort: srcIPPort.Port(), - DstPort: dstIPPort.Port(), - SeqNum: seq, - AckNum: 1, - DataOffset: 20, - Flags: flags, - WindowSize: 3000, - }) - pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_TCP, ipv6H.SourceAddress(), ipv6H.DestinationAddress(), uint16(20+segmentSize)) - tcpH.SetChecksum(^tcpH.CalculateChecksum(pseudoCsum)) - return b -} - -func tcp6Packet(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32) []byte { - return tcp6PacketMutateIPFields(srcIPPort, dstIPPort, flags, segmentSize, seq, nil) -} - -func Test_handleVirtioRead(t *testing.T) { - tests := []struct { - name string - hdr virtioNetHdr - pktIn []byte - wantLens []int - wantErr bool - }{ - { - "tcp4", - virtioNetHdr{ - flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, - gsoType: unix.VIRTIO_NET_HDR_GSO_TCPV4, - gsoSize: 100, - hdrLen: 40, - csumStart: 20, - csumOffset: 16, - }, - tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck|header.TCPFlagPsh, 200, 1), - []int{140, 140}, - false, - }, - { - "tcp6", - virtioNetHdr{ - flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, - gsoType: unix.VIRTIO_NET_HDR_GSO_TCPV6, - gsoSize: 100, - hdrLen: 60, - csumStart: 40, - csumOffset: 16, - }, - tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck|header.TCPFlagPsh, 200, 1), - []int{160, 160}, - false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - out := make([][]byte, conn.IdealBatchSize) - sizes := make([]int, conn.IdealBatchSize) - for i := range out { - out[i] = make([]byte, 65535) - } - tt.hdr.encode(tt.pktIn) - n, err := handleVirtioRead(tt.pktIn, out, sizes, offset) - if err != nil { - if tt.wantErr { - return - } - t.Fatalf("got err: %v", err) - } - if n != len(tt.wantLens) { - t.Fatalf("got %d packets, wanted %d", n, len(tt.wantLens)) - } - for i := range tt.wantLens { - if tt.wantLens[i] != sizes[i] { - t.Fatalf("wantLens[%d]: %d != outSizes: %d", i, tt.wantLens[i], sizes[i]) - } - } - }) - } -} - -func flipTCP4Checksum(b []byte) []byte { - at := virtioNetHdrLen + 20 + 16 // 20 byte ipv4 header; tcp csum offset is 16 - b[at] ^= 0xFF - b[at+1] ^= 0xFF - return b -} - -func Fuzz_handleGRO(f *testing.F) { - pkt0 := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1) - pkt1 := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101) - pkt2 := tcp4Packet(ip4PortA, ip4PortC, header.TCPFlagAck, 100, 201) - pkt3 := tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1) - pkt4 := tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101) - pkt5 := tcp6Packet(ip6PortA, ip6PortC, header.TCPFlagAck, 100, 201) - f.Add(pkt0, pkt1, pkt2, pkt3, pkt4, pkt5, offset) - f.Fuzz(func(t *testing.T, pkt0, pkt1, pkt2, pkt3, pkt4, pkt5 []byte, offset int) { - pkts := [][]byte{pkt0, pkt1, pkt2, pkt3, pkt4, pkt5} - toWrite := make([]int, 0, len(pkts)) - handleGRO(pkts, offset, newTCPGROTable(), newTCPGROTable(), &toWrite) - if len(toWrite) > len(pkts) { - t.Errorf("len(toWrite): %d > len(pkts): %d", len(toWrite), len(pkts)) - } - seenWriteI := make(map[int]bool) - for _, writeI := range toWrite { - if writeI < 0 || writeI > len(pkts)-1 { - t.Errorf("toWrite value (%d) outside bounds of len(pkts): %d", writeI, len(pkts)) - } - if seenWriteI[writeI] { - t.Errorf("duplicate toWrite value: %d", writeI) - } - seenWriteI[writeI] = true - } - }) -} - -func Test_handleGRO(t *testing.T) { - tests := []struct { - name string - pktsIn [][]byte - wantToWrite []int - wantLens []int - wantErr bool - }{ - { - "multiple flows", - [][]byte{ - tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // v4 flow 1 - tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // v4 flow 1 - tcp4Packet(ip4PortA, ip4PortC, header.TCPFlagAck, 100, 201), // v4 flow 2 - tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), // v6 flow 1 - tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101), // v6 flow 1 - tcp6Packet(ip6PortA, ip6PortC, header.TCPFlagAck, 100, 201), // v6 flow 2 - }, - []int{0, 2, 3, 5}, - []int{240, 140, 260, 160}, - false, - }, - { - "PSH interleaved", - [][]byte{ - tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // v4 flow 1 - tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck|header.TCPFlagPsh, 100, 101), // v4 flow 1 - tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1 - tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 301), // v4 flow 1 - tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), // v6 flow 1 - tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck|header.TCPFlagPsh, 100, 101), // v6 flow 1 - tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 201), // v6 flow 1 - tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 301), // v6 flow 1 - }, - []int{0, 2, 4, 6}, - []int{240, 240, 260, 260}, - false, - }, - { - "coalesceItemInvalidCSum", - [][]byte{ - flipTCP4Checksum(tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1)), // v4 flow 1 seq 1 len 100 - tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // v4 flow 1 seq 101 len 100 - tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1 seq 201 len 100 - }, - []int{0, 1}, - []int{140, 240}, - false, - }, - { - "out of order", - [][]byte{ - tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // v4 flow 1 seq 101 len 100 - tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // v4 flow 1 seq 1 len 100 - tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1 seq 201 len 100 - }, - []int{0}, - []int{340}, - false, - }, - { - "tcp4 unequal TTL", - [][]byte{ - tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), - tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) { - fields.TTL++ - }), - }, - []int{0, 1}, - []int{140, 140}, - false, - }, - { - "tcp4 unequal ToS", - [][]byte{ - tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), - tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) { - fields.TOS++ - }), - }, - []int{0, 1}, - []int{140, 140}, - false, - }, - { - "tcp4 unequal flags more fragments set", - [][]byte{ - tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), - tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) { - fields.Flags = 1 - }), - }, - []int{0, 1}, - []int{140, 140}, - false, - }, - { - "tcp4 unequal flags DF set", - [][]byte{ - tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), - tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) { - fields.Flags = 2 - }), - }, - []int{0, 1}, - []int{140, 140}, - false, - }, - { - "tcp6 unequal hop limit", - [][]byte{ - tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), - tcp6PacketMutateIPFields(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv6Fields) { - fields.HopLimit++ - }), - }, - []int{0, 1}, - []int{160, 160}, - false, - }, - { - "tcp6 unequal traffic class", - [][]byte{ - tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), - tcp6PacketMutateIPFields(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv6Fields) { - fields.TrafficClass++ - }), - }, - []int{0, 1}, - []int{160, 160}, - false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - toWrite := make([]int, 0, len(tt.pktsIn)) - err := handleGRO(tt.pktsIn, offset, newTCPGROTable(), newTCPGROTable(), &toWrite) - if err != nil { - if tt.wantErr { - return - } - t.Fatalf("got err: %v", err) - } - if len(toWrite) != len(tt.wantToWrite) { - t.Fatalf("got %d packets, wanted %d", len(toWrite), len(tt.wantToWrite)) - } - for i, pktI := range tt.wantToWrite { - if tt.wantToWrite[i] != toWrite[i] { - t.Fatalf("wantToWrite[%d]: %d != toWrite: %d", i, tt.wantToWrite[i], toWrite[i]) - } - if tt.wantLens[i] != len(tt.pktsIn[pktI][offset:]) { - t.Errorf("wanted len %d packet at %d, got: %d", tt.wantLens[i], i, len(tt.pktsIn[pktI][offset:])) - } - } - }) - } -} - -func Test_isTCP4NoIPOptions(t *testing.T) { - valid := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1)[virtioNetHdrLen:] - invalidLen := valid[:39] - invalidHeaderLen := make([]byte, len(valid)) - copy(invalidHeaderLen, valid) - invalidHeaderLen[0] = 0x46 - invalidProtocol := make([]byte, len(valid)) - copy(invalidProtocol, valid) - invalidProtocol[9] = unix.IPPROTO_TCP + 1 - - tests := []struct { - name string - b []byte - want bool - }{ - { - "valid", - valid, - true, - }, - { - "invalid length", - invalidLen, - false, - }, - { - "invalid version", - []byte{0x00}, - false, - }, - { - "invalid header len", - invalidHeaderLen, - false, - }, - { - "invalid protocol", - invalidProtocol, - false, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := isTCP4NoIPOptions(tt.b); got != tt.want { - t.Errorf("isTCP4NoIPOptions() = %v, want %v", got, tt.want) - } - }) - } -} |