author    Jordan Whited <jordan@tailscale.com>    2023-03-02 15:08:28 -0800
committer Jason A. Donenfeld <Jason@zx2c4.com>    2023-03-10 14:52:17 +0100
commit    9e2f3860220280a5630971478b53c8ad9a991ca8 (patch)
tree      218f1bd9a8dd649a8fdb50571a921d1ccff4cae5 /tun
parent    3bb8fec7e41fcc2138ddb4cba3f46100814fc523 (diff)
conn, device, tun: implement vectorized I/O on Linux
Implement TCP offloading via TSO and GRO for the Linux tun.Device, which
is made possible by virtio extensions in the kernel's TUN driver.

Delete conn.LinuxSocketEndpoint in favor of a collapsed conn.StdNetBind.
conn.StdNetBind makes use of recvmmsg() and sendmmsg() on Linux. All
platforms now fall under conn.StdNetBind, except for Windows, which
remains in conn.WinRingBind, which still needs to be adjusted to handle
multiple packets.

Also refactor sticky sockets support to eventually be applicable on
platforms other than just Linux. However, Linux remains the sole platform
that fully implements it for now.

Co-authored-by: James Tucker <james@tailscale.com>
Signed-off-by: James Tucker <james@tailscale.com>
Signed-off-by: Jordan Whited <jordan@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
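For orientation before the diff: a minimal, hypothetical sketch of how a caller drives the batched tun.Device that this series builds on. The Read/Write/BatchSize signatures come from this series; the device name, MTU, offset value, and logging are illustrative only.

	package main

	import (
		"log"

		"golang.zx2c4.com/wireguard/tun"
	)

	func main() {
		dev, err := tun.CreateTUN("tun-test", 1420) // name and MTU are illustrative
		if err != nil {
			log.Fatal(err)
		}
		defer dev.Close()
		// Leave headroom at the front of each buffer; the Linux device uses
		// part of it for a virtio_net_hdr when vnet hdr offloads are enabled.
		const offset = 16
		buffs := make([][]byte, dev.BatchSize())
		for i := range buffs {
			buffs[i] = make([]byte, offset+65535)
		}
		sizes := make([]int, dev.BatchSize())
		n, err := dev.Read(buffs, sizes, offset) // up to BatchSize() packets per call
		if err != nil {
			log.Fatal(err)
		}
		for i := 0; i < n; i++ {
			log.Printf("packet %d: %d bytes", i, sizes[i])
		}
	}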
Diffstat (limited to 'tun')
-rw-r--r--   tun/checksum.go                                                   42
-rw-r--r--   tun/tcp_offload_linux.go                                         612
-rw-r--r--   tun/tcp_offload_linux_test.go                                    273
-rw-r--r--   tun/testdata/fuzz/Fuzz_handleGRO/032aec0105f26f709c118365e4830d6dc087cab24cd1e154c2e790589a309b77    8
-rw-r--r--   tun/testdata/fuzz/Fuzz_handleGRO/0da283f9a2098dec30d1c86784411a8ce2e8e03aa3384105e581f2c67494700d    8
-rw-r--r--   tun/tun_linux.go                                                 275
6 files changed, 1148 insertions, 70 deletions
diff --git a/tun/checksum.go b/tun/checksum.go
new file mode 100644
index 0000000..f4f8471
--- /dev/null
+++ b/tun/checksum.go
@@ -0,0 +1,42 @@
+package tun
+
+import "encoding/binary"
+
+// TODO: Explore SIMD and/or other assembly optimizations.
+func checksumNoFold(b []byte, initial uint64) uint64 {
+ ac := initial
+ i := 0
+ n := len(b)
+ for n >= 4 {
+ ac += uint64(binary.BigEndian.Uint32(b[i : i+4]))
+ n -= 4
+ i += 4
+ }
+ for n >= 2 {
+ ac += uint64(binary.BigEndian.Uint16(b[i : i+2]))
+ n -= 2
+ i += 2
+ }
+ if n == 1 {
+ ac += uint64(b[i]) << 8
+ }
+ return ac
+}
+
+func checksum(b []byte, initial uint64) uint16 {
+ ac := checksumNoFold(b, initial)
+ ac = (ac >> 16) + (ac & 0xffff)
+ ac = (ac >> 16) + (ac & 0xffff)
+ ac = (ac >> 16) + (ac & 0xffff)
+ ac = (ac >> 16) + (ac & 0xffff)
+ return uint16(ac)
+}
+
+func pseudoHeaderChecksumNoFold(protocol uint8, srcAddr, dstAddr []byte, totalLen uint16) uint64 {
+ sum := checksumNoFold(srcAddr, 0)
+ sum = checksumNoFold(dstAddr, sum)
+ sum = checksumNoFold([]byte{0, protocol}, sum)
+ tmp := make([]byte, 2)
+ binary.BigEndian.PutUint16(tmp, totalLen)
+ return checksumNoFold(tmp, sum)
+}
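A small usage sketch, not part of the diff: assuming package-internal access to the helpers above and golang.org/x/sys/unix, a final TCP checksum over a segment is computed the same way the offload code below does it.

	// tcpChecksum folds the TCP/IP pseudo-header sum into the ones'-complement
	// checksum of the TCP header plus payload.
	func tcpChecksum(tcpSegment, srcAddr, dstAddr []byte) uint16 {
		psum := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, srcAddr, dstAddr, uint16(len(tcpSegment)))
		return ^checksum(tcpSegment, psum)
	}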
diff --git a/tun/tcp_offload_linux.go b/tun/tcp_offload_linux.go
new file mode 100644
index 0000000..f3ffa75
--- /dev/null
+++ b/tun/tcp_offload_linux.go
@@ -0,0 +1,612 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package tun
+
+import (
+ "bytes"
+ "encoding/binary"
+ "errors"
+ "io"
+ "unsafe"
+
+ "golang.org/x/sys/unix"
+ "golang.zx2c4.com/wireguard/conn"
+)
+
+const tcpFlagsOffset = 13
+
+const (
+ tcpFlagFIN uint8 = 0x01
+ tcpFlagPSH uint8 = 0x08
+ tcpFlagACK uint8 = 0x10
+)
+
+// virtioNetHdr is defined in the kernel in include/uapi/linux/virtio_net.h. The
+// kernel symbol is virtio_net_hdr.
+type virtioNetHdr struct {
+ flags uint8
+ gsoType uint8
+ hdrLen uint16
+ gsoSize uint16
+ csumStart uint16
+ csumOffset uint16
+}
+
+func (v *virtioNetHdr) decode(b []byte) error {
+ if len(b) < virtioNetHdrLen {
+ return io.ErrShortBuffer
+ }
+ copy(unsafe.Slice((*byte)(unsafe.Pointer(v)), virtioNetHdrLen), b[:virtioNetHdrLen])
+ return nil
+}
+
+func (v *virtioNetHdr) encode(b []byte) error {
+ if len(b) < virtioNetHdrLen {
+ return io.ErrShortBuffer
+ }
+ copy(b[:virtioNetHdrLen], unsafe.Slice((*byte)(unsafe.Pointer(v)), virtioNetHdrLen))
+ return nil
+}
+
+const (
+ // virtioNetHdrLen is the length in bytes of virtioNetHdr. This matches the
+ // shape of the C ABI for its kernel counterpart -- sizeof(virtio_net_hdr).
+ virtioNetHdrLen = int(unsafe.Sizeof(virtioNetHdr{}))
+)
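A quick round-trip of the encode/decode pair above (package-internal sketch; the field values are illustrative):

	hdr := virtioNetHdr{
		flags:   unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
		gsoType: unix.VIRTIO_NET_HDR_GSO_TCPV4,
		hdrLen:  40,
		gsoSize: 1360,
	}
	buf := make([]byte, virtioNetHdrLen)
	_ = hdr.encode(buf) // raw struct copy matching the kernel ABI layout
	var got virtioNetHdr
	_ = got.decode(buf) // got now equals hdr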
+
+// flowKey represents the key for a flow.
+type flowKey struct {
+ srcAddr, dstAddr [16]byte
+ srcPort, dstPort uint16
+ rxAck uint32 // varying ack values should not be coalesced. Treat them as separate flows.
+}
+
+// tcpGROTable holds flow and coalescing information for the purposes of GRO.
+type tcpGROTable struct {
+ itemsByFlow map[flowKey][]tcpGROItem
+ itemsPool [][]tcpGROItem
+}
+
+func newTCPGROTable() *tcpGROTable {
+ t := &tcpGROTable{
+ itemsByFlow: make(map[flowKey][]tcpGROItem, conn.DefaultBatchSize),
+ itemsPool: make([][]tcpGROItem, conn.DefaultBatchSize),
+ }
+ for i := range t.itemsPool {
+ t.itemsPool[i] = make([]tcpGROItem, 0, conn.DefaultBatchSize)
+ }
+ return t
+}
+
+func newFlowKey(pkt []byte, srcAddr, dstAddr, tcphOffset int) flowKey {
+ key := flowKey{}
+ addrSize := dstAddr - srcAddr
+ copy(key.srcAddr[:], pkt[srcAddr:dstAddr])
+ copy(key.dstAddr[:], pkt[dstAddr:dstAddr+addrSize])
+ key.srcPort = binary.BigEndian.Uint16(pkt[tcphOffset:])
+ key.dstPort = binary.BigEndian.Uint16(pkt[tcphOffset+2:])
+ key.rxAck = binary.BigEndian.Uint32(pkt[tcphOffset+8:])
+ return key
+}
+
+// lookupOrInsert looks up a flow for the provided packet and metadata,
+// returning the packets found for the flow, or inserting a new one if none
+// is found.
+func (t *tcpGROTable) lookupOrInsert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, buffsIndex int) ([]tcpGROItem, bool) {
+ key := newFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset)
+ items, ok := t.itemsByFlow[key]
+ if ok {
+ return items, ok
+ }
+ // TODO: insert() performs another map lookup. This could be rearranged to avoid.
+ t.insert(pkt, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, buffsIndex)
+ return nil, false
+}
+
+// insert an item in the table for the provided packet and packet metadata.
+func (t *tcpGROTable) insert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, buffsIndex int) {
+ key := newFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset)
+ item := tcpGROItem{
+ key: key,
+ buffsIndex: uint16(buffsIndex),
+ gsoSize: uint16(len(pkt[tcphOffset+tcphLen:])),
+ iphLen: uint8(tcphOffset),
+ tcphLen: uint8(tcphLen),
+ sentSeq: binary.BigEndian.Uint32(pkt[tcphOffset+4:]),
+ pshSet: pkt[tcphOffset+tcpFlagsOffset]&tcpFlagPSH != 0,
+ }
+ items, ok := t.itemsByFlow[key]
+ if !ok {
+ items = t.newItems()
+ }
+ items = append(items, item)
+ t.itemsByFlow[key] = items
+}
+
+func (t *tcpGROTable) updateAt(item tcpGROItem, i int) {
+ items, _ := t.itemsByFlow[item.key]
+ items[i] = item
+}
+
+func (t *tcpGROTable) deleteAt(key flowKey, i int) {
+ items, _ := t.itemsByFlow[key]
+ items = append(items[:i], items[i+1:]...)
+ t.itemsByFlow[key] = items
+}
+
+// tcpGROItem represents bookkeeping data for a TCP packet during the lifetime
+// of a GRO evaluation across a vector of packets.
+type tcpGROItem struct {
+ key flowKey
+ sentSeq uint32 // the sequence number
+ buffsIndex uint16 // the index into the original buffs slice
+ numMerged uint16 // the number of packets merged into this item
+ gsoSize uint16 // payload size
+ iphLen uint8 // ip header len
+ tcphLen uint8 // tcp header len
+ pshSet bool // psh flag is set
+}
+
+func (t *tcpGROTable) newItems() []tcpGROItem {
+ var items []tcpGROItem
+ items, t.itemsPool = t.itemsPool[len(t.itemsPool)-1], t.itemsPool[:len(t.itemsPool)-1]
+ return items
+}
+
+func (t *tcpGROTable) reset() {
+ for k, items := range t.itemsByFlow {
+ items = items[:0]
+ t.itemsPool = append(t.itemsPool, items)
+ delete(t.itemsByFlow, k)
+ }
+}
+
+// canCoalesce represents the outcome of checking if two TCP packets are
+// candidates for coalescing.
+type canCoalesce int
+
+const (
+ coalescePrepend canCoalesce = -1
+ coalesceUnavailable canCoalesce = 0
+ coalesceAppend canCoalesce = 1
+)
+
+// tcpPacketsCanCoalesce evaluates if pkt can be coalesced with the packet
+// described by item. This function makes considerations that match the kernel's
+// GRO self tests, which can be found in tools/testing/selftests/net/gro.c.
+func tcpPacketsCanCoalesce(pkt []byte, iphLen, tcphLen uint8, seq uint32, pshSet bool, gsoSize uint16, item tcpGROItem, buffs [][]byte, buffsOffset int) canCoalesce {
+ pktTarget := buffs[item.buffsIndex][buffsOffset:]
+ if tcphLen != item.tcphLen {
+ // cannot coalesce with unequal tcp options len
+ return coalesceUnavailable
+ }
+ if tcphLen > 20 {
+ if !bytes.Equal(pkt[iphLen+20:iphLen+tcphLen], pktTarget[item.iphLen+20:item.iphLen+tcphLen]) {
+ // cannot coalesce with unequal tcp options
+ return coalesceUnavailable
+ }
+ }
+ if pkt[1] != pktTarget[1] {
+ // cannot coalesce with unequal ToS values
+ return coalesceUnavailable
+ }
+ if pkt[6]>>5 != pktTarget[6]>>5 {
+ // cannot coalesce with unequal DF or reserved bits. MF is checked
+ // further up the stack.
+ return coalesceUnavailable
+ }
+ // seq adjacency
+ lhsLen := item.gsoSize
+ lhsLen += item.numMerged * item.gsoSize
+ if seq == item.sentSeq+uint32(lhsLen) { // pkt aligns following item from a seq num perspective
+ if item.pshSet {
+ // We cannot append to a segment that has the PSH flag set, PSH
+ // can only be set on the final segment in a reassembled group.
+ return coalesceUnavailable
+ }
+ if len(pktTarget[iphLen+tcphLen:])%int(item.gsoSize) != 0 {
+ // A smaller than gsoSize packet has been appended previously.
+ // Nothing can come after a smaller packet on the end.
+ return coalesceUnavailable
+ }
+ if gsoSize > item.gsoSize {
+ // We cannot have a larger packet following a smaller one.
+ return coalesceUnavailable
+ }
+ return coalesceAppend
+ } else if seq+uint32(gsoSize) == item.sentSeq { // pkt aligns in front of item from a seq num perspective
+ if pshSet {
+ // We cannot prepend with a segment that has the PSH flag set, PSH
+ // can only be set on the final segment in a reassembled group.
+ return coalesceUnavailable
+ }
+ if gsoSize < item.gsoSize {
+ // We cannot have a larger packet following a smaller one.
+ return coalesceUnavailable
+ }
+ if gsoSize > item.gsoSize && item.numMerged > 0 {
+ // There's at least one previous merge, and we're larger than all
+ // previous. This would put multiple smaller packets on the end.
+ return coalesceUnavailable
+ }
+ return coalescePrepend
+ }
+ return coalesceUnavailable
+}
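To make the sequence-number arithmetic above concrete, a hypothetical worked example:

	// item.sentSeq = 201, item.gsoSize = 100, item.numMerged = 1
	// lhsLen = 100 + 1*100 = 200, so item spans seq [201, 401)
	// pkt with seq 401              -> coalesceAppend  (401 == 201 + 200)
	// pkt with seq 101, gsoSize 100 -> coalescePrepend (101 + 100 == 201)
	// any other seq                 -> coalesceUnavailable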
+
+func tcpChecksumValid(pkt []byte, iphLen uint8, isV6 bool) bool {
+ srcAddrAt := ipv4SrcAddrOffset
+ addrSize := 4
+ if isV6 {
+ srcAddrAt = ipv6SrcAddrOffset
+ addrSize = 16
+ }
+ tcpTotalLen := uint16(len(pkt) - int(iphLen))
+ tcpCSumNoFold := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, pkt[srcAddrAt:srcAddrAt+addrSize], pkt[srcAddrAt+addrSize:srcAddrAt+addrSize*2], tcpTotalLen)
+ return ^checksum(pkt[iphLen:], tcpCSumNoFold) == 0
+}
+
+// coalesceResult represents the result of attempting to coalesce two TCP
+// packets.
+type coalesceResult int
+
+const (
+ coalesceInsufficientCap coalesceResult = 0
+ coalescePSHEnding coalesceResult = 1
+ coalesceItemInvalidCSum coalesceResult = 2
+ coalescePktInvalidCSum coalesceResult = 3
+ coalesceSuccess coalesceResult = 4
+)
+
+// coalesceTCPPackets attempts to coalesce pkt with the packet described by
+// item, returning the outcome. This function may swap buffs elements in the
+// event of a prepend as item's buffs index is already being tracked for writing
+// to a Device.
+func coalesceTCPPackets(mode canCoalesce, pkt []byte, pktBuffsIndex int, gsoSize uint16, seq uint32, pshSet bool, item *tcpGROItem, buffs [][]byte, buffsOffset int, isV6 bool) coalesceResult {
+ var pktHead []byte // the packet that will end up at the front
+ headersLen := item.iphLen + item.tcphLen
+ coalescedLen := len(buffs[item.buffsIndex][buffsOffset:]) + len(pkt) - int(headersLen)
+
+ // Copy data
+ if mode == coalescePrepend {
+ pktHead = pkt
+ if cap(pkt)-buffsOffset < coalescedLen {
+ // We don't want to allocate a new underlying array if capacity is
+ // too small.
+ return coalesceInsufficientCap
+ }
+ if pshSet {
+ return coalescePSHEnding
+ }
+ if item.numMerged == 0 {
+ if !tcpChecksumValid(buffs[item.buffsIndex][buffsOffset:], item.iphLen, isV6) {
+ return coalesceItemInvalidCSum
+ }
+ }
+ if !tcpChecksumValid(pkt, item.iphLen, isV6) {
+ return coalescePktInvalidCSum
+ }
+ item.sentSeq = seq
+ extendBy := coalescedLen - len(pktHead)
+ buffs[pktBuffsIndex] = append(buffs[pktBuffsIndex], make([]byte, extendBy)...)
+ copy(buffs[pktBuffsIndex][buffsOffset+len(pkt):], buffs[item.buffsIndex][buffsOffset+int(headersLen):])
+ // Flip the slice headers in buffs as part of prepend. The index of item
+ // is already being tracked for writing.
+ buffs[item.buffsIndex], buffs[pktBuffsIndex] = buffs[pktBuffsIndex], buffs[item.buffsIndex]
+ } else {
+ pktHead = buffs[item.buffsIndex][buffsOffset:]
+ if cap(pktHead)-buffsOffset < coalescedLen {
+ // We don't want to allocate a new underlying array if capacity is
+ // too small.
+ return coalesceInsufficientCap
+ }
+ if item.numMerged == 0 {
+ if !tcpChecksumValid(buffs[item.buffsIndex][buffsOffset:], item.iphLen, isV6) {
+ return coalesceItemInvalidCSum
+ }
+ }
+ if !tcpChecksumValid(pkt, item.iphLen, isV6) {
+ return coalescePktInvalidCSum
+ }
+ if pshSet {
+ // We are appending a segment with PSH set.
+ item.pshSet = pshSet
+ pktHead[item.iphLen+tcpFlagsOffset] |= tcpFlagPSH
+ }
+ extendBy := len(pkt) - int(headersLen)
+ buffs[item.buffsIndex] = append(buffs[item.buffsIndex], make([]byte, extendBy)...)
+ copy(buffs[item.buffsIndex][buffsOffset+len(pktHead):], pkt[headersLen:])
+ }
+
+ if gsoSize > item.gsoSize {
+ item.gsoSize = gsoSize
+ }
+ hdr := virtioNetHdr{
+ flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, // this turns into CHECKSUM_PARTIAL in the skb
+ hdrLen: uint16(headersLen),
+ gsoSize: uint16(item.gsoSize),
+ csumStart: uint16(item.iphLen),
+ csumOffset: 16,
+ }
+
+ // Recalculate the total len (IPv4) or payload len (IPv6). Recalculate the
+ // (IPv4) header checksum.
+ if isV6 {
+ hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV6
+ binary.BigEndian.PutUint16(pktHead[4:], uint16(coalescedLen)-uint16(item.iphLen)) // set new payload len
+ } else {
+ hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV4
+ pktHead[10], pktHead[11] = 0, 0 // clear checksum field
+ binary.BigEndian.PutUint16(pktHead[2:], uint16(coalescedLen)) // set new total length
+ iphCSum := ^checksum(pktHead[:item.iphLen], 0) // compute checksum
+ binary.BigEndian.PutUint16(pktHead[10:], iphCSum) // set checksum field
+ }
+ hdr.encode(buffs[item.buffsIndex][buffsOffset-virtioNetHdrLen:])
+
+ // Calculate the pseudo header checksum and place it at the TCP checksum
+ // offset. Downstream checksum offloading will combine this with computation
+ // of the tcp header and payload checksum.
+ addrLen := 4
+ addrOffset := ipv4SrcAddrOffset
+ if isV6 {
+ addrLen = 16
+ addrOffset = ipv6SrcAddrOffset
+ }
+ srcAddrAt := buffsOffset + addrOffset
+ srcAddr := buffs[item.buffsIndex][srcAddrAt : srcAddrAt+addrLen]
+ dstAddr := buffs[item.buffsIndex][srcAddrAt+addrLen : srcAddrAt+addrLen*2]
+ psum := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, srcAddr, dstAddr, uint16(coalescedLen-int(item.iphLen)))
+ binary.BigEndian.PutUint16(pktHead[hdr.csumStart+hdr.csumOffset:], checksum([]byte{}, psum))
+
+ item.numMerged++
+ return coalesceSuccess
+}
+
+const (
+ ipv4FlagMoreFragments = 0x80
+)
+
+const (
+ ipv4SrcAddrOffset = 12
+ ipv6SrcAddrOffset = 8
+ maxUint16 = 1<<16 - 1
+)
+
+// tcpGRO evaluates the TCP packet at pktI in buffs for coalescing with
+// existing packets tracked in table. It will return false when pktI is not
+// coalesced, otherwise true. This indicates to the caller if buffs[pktI]
+// should be written to the Device.
+func tcpGRO(buffs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool) (pktCoalesced bool) {
+ pkt := buffs[pktI][offset:]
+ if len(pkt) > maxUint16 {
+ // A valid IPv4 or IPv6 packet will never exceed this.
+ return false
+ }
+ iphLen := int((pkt[0] & 0x0F) * 4)
+ if isV6 {
+ iphLen = 40
+ ipv6HPayloadLen := int(binary.BigEndian.Uint16(pkt[4:]))
+ if ipv6HPayloadLen != len(pkt)-iphLen {
+ return false
+ }
+ } else {
+ totalLen := int(binary.BigEndian.Uint16(pkt[2:]))
+ if totalLen != len(pkt) {
+ return false
+ }
+ if iphLen < 20 || iphLen > 60 {
+ return false
+ }
+ }
+ if len(pkt) < iphLen {
+ return false
+ }
+ tcphLen := int((pkt[iphLen+12] >> 4) * 4)
+ if tcphLen < 20 || tcphLen > 60 {
+ return false
+ }
+ if len(pkt) < iphLen+tcphLen {
+ return false
+ }
+ if !isV6 {
+ if pkt[6]&ipv4FlagMoreFragments != 0 || (pkt[6]<<3 != 0 || pkt[7] != 0) {
+ // no GRO support for fragmented segments for now
+ return false
+ }
+ }
+ tcpFlags := pkt[iphLen+tcpFlagsOffset]
+ var pshSet bool
+ // not a candidate if any non-ACK flags (except PSH+ACK) are set
+ if tcpFlags != tcpFlagACK {
+ if pkt[iphLen+tcpFlagsOffset] != tcpFlagACK|tcpFlagPSH {
+ return false
+ }
+ pshSet = true
+ }
+ gsoSize := uint16(len(pkt) - tcphLen - iphLen)
+ // not a candidate if payload len is 0
+ if gsoSize < 1 {
+ return false
+ }
+ seq := binary.BigEndian.Uint32(pkt[iphLen+4:])
+ srcAddrOffset := ipv4SrcAddrOffset
+ addrLen := 4
+ if isV6 {
+ srcAddrOffset = ipv6SrcAddrOffset
+ addrLen = 16
+ }
+ items, existing := table.lookupOrInsert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI)
+ if !existing {
+ return false
+ }
+ for i := len(items) - 1; i >= 0; i-- {
+ // In the best case of packets arriving in order iterating in reverse is
+ // more efficient if there are multiple items for a given flow. This
+ // also enables a natural table.deleteAt() in the
+ // coalesceItemInvalidCSum case without the need for index tracking.
+ // This algorithm makes a best effort to coalesce in the event of
+ // unordered packets, where pkt may land anywhere in items from a
+ // sequence number perspective, however once an item is inserted into
+ // the table it is never compared across other items later.
+ item := items[i]
+ can := tcpPacketsCanCoalesce(pkt, uint8(iphLen), uint8(tcphLen), seq, pshSet, gsoSize, item, buffs, offset)
+ if can != coalesceUnavailable {
+ result := coalesceTCPPackets(can, pkt, pktI, gsoSize, seq, pshSet, &item, buffs, offset, isV6)
+ switch result {
+ case coalesceSuccess:
+ table.updateAt(item, i)
+ return true
+ case coalesceItemInvalidCSum:
+ // delete the item with an invalid csum
+ table.deleteAt(item.key, i)
+ case coalescePktInvalidCSum:
+ // no point in inserting an item that we can't coalesce
+ return false
+ default:
+ }
+ }
+ }
+ // failed to coalesce with any other packets; store the item in the flow
+ table.insert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI)
+ return false
+}
+
+func isTCP4(b []byte) bool {
+ if len(b) < 40 {
+ return false
+ }
+ if b[0]>>4 != 4 {
+ return false
+ }
+ if b[9] != unix.IPPROTO_TCP {
+ return false
+ }
+ return true
+}
+
+func isTCP6NoEH(b []byte) bool {
+ if len(b) < 60 {
+ return false
+ }
+ if b[0]>>4 != 6 {
+ return false
+ }
+ if b[6] != unix.IPPROTO_TCP {
+ return false
+ }
+ return true
+}
+
+// handleGRO evaluates buffs for GRO, and writes the indices of the resulting
+// packets into toWrite. toWrite, tcp4Table, and tcp6Table should initially be
+// empty (but non-nil), and are passed in to save allocs as the caller may reset
+// and recycle them across vectors of packets.
+func handleGRO(buffs [][]byte, offset int, tcp4Table, tcp6Table *tcpGROTable, toWrite *[]int) error {
+ for i := range buffs {
+ if offset < virtioNetHdrLen || offset > len(buffs[i])-1 {
+ return errors.New("invalid offset")
+ }
+ var coalesced bool
+ switch {
+ case isTCP4(buffs[i][offset:]):
+ coalesced = tcpGRO(buffs, offset, i, tcp4Table, false)
+ case isTCP6NoEH(buffs[i][offset:]): // ipv6 packets w/extension headers do not coalesce
+ coalesced = tcpGRO(buffs, offset, i, tcp6Table, true)
+ }
+ if !coalesced {
+ hdr := virtioNetHdr{}
+ err := hdr.encode(buffs[i][offset-virtioNetHdrLen:])
+ if err != nil {
+ return err
+ }
+ *toWrite = append(*toWrite, i)
+ }
+ }
+ return nil
+}
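A minimal sketch of one GRO pass over a vector of packets, using the function above from inside the package (this mirrors what the Write path in tun_linux.go below does):

	func groPass(buffs [][]byte, offset int) ([]int, error) {
		toWrite := make([]int, 0, len(buffs))
		err := handleGRO(buffs, offset, newTCPGROTable(), newTCPGROTable(), &toWrite)
		// toWrite holds the indices of buffs that still need writing;
		// packets coalesced into a survivor do not appear in it.
		return toWrite, err
	}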
+
+// tcpTSO splits packets from in into outBuffs, writing the size of each
+// element into sizes. It returns the number of buffers populated, or an
+// error.
+func tcpTSO(in []byte, hdr virtioNetHdr, outBuffs [][]byte, sizes []int, outOffset int) (int, error) {
+ iphLen := int(hdr.csumStart)
+ srcAddrOffset := ipv6SrcAddrOffset
+ addrLen := 16
+ if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_TCPV4 {
+ in[10], in[11] = 0, 0 // clear ipv4 header checksum
+ srcAddrOffset = ipv4SrcAddrOffset
+ addrLen = 4
+ }
+ tcpCSumAt := int(hdr.csumStart + hdr.csumOffset)
+ in[tcpCSumAt], in[tcpCSumAt+1] = 0, 0 // clear tcp checksum
+ firstTCPSeqNum := binary.BigEndian.Uint32(in[hdr.csumStart+4:])
+ nextSegmentDataAt := int(hdr.hdrLen)
+ i := 0
+ for ; nextSegmentDataAt < len(in); i++ {
+ if i == len(outBuffs) {
+ return i - 1, ErrTooManySegments
+ }
+ nextSegmentEnd := nextSegmentDataAt + int(hdr.gsoSize)
+ if nextSegmentEnd > len(in) {
+ nextSegmentEnd = len(in)
+ }
+ segmentDataLen := nextSegmentEnd - nextSegmentDataAt
+ totalLen := int(hdr.hdrLen) + segmentDataLen
+ sizes[i] = totalLen
+ out := outBuffs[i][outOffset:]
+
+ copy(out, in[:iphLen])
+ if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_TCPV4 {
+ // For IPv4 we are responsible for incrementing the ID field,
+ // updating the total len field, and recalculating the header
+ // checksum.
+ if i > 0 {
+ id := binary.BigEndian.Uint16(out[4:])
+ id += uint16(i)
+ binary.BigEndian.PutUint16(out[4:], id)
+ }
+ binary.BigEndian.PutUint16(out[2:], uint16(totalLen))
+ ipv4CSum := ^checksum(out[:iphLen], 0)
+ binary.BigEndian.PutUint16(out[10:], ipv4CSum)
+ } else {
+ // For IPv6 we are responsible for updating the payload length field.
+ binary.BigEndian.PutUint16(out[4:], uint16(totalLen-iphLen))
+ }
+
+ // TCP header
+ copy(out[hdr.csumStart:hdr.hdrLen], in[hdr.csumStart:hdr.hdrLen])
+ tcpSeq := firstTCPSeqNum + uint32(hdr.gsoSize*uint16(i))
+ binary.BigEndian.PutUint32(out[hdr.csumStart+4:], tcpSeq)
+ if nextSegmentEnd != len(in) {
+ // FIN and PSH should only be set on last segment
+ clearFlags := tcpFlagFIN | tcpFlagPSH
+ out[hdr.csumStart+tcpFlagsOffset] &^= clearFlags
+ }
+
+ // payload
+ copy(out[hdr.hdrLen:], in[nextSegmentDataAt:nextSegmentEnd])
+
+ // TCP checksum
+ tcpHLen := int(hdr.hdrLen - hdr.csumStart)
+ tcpLenForPseudo := uint16(tcpHLen + segmentDataLen)
+ tcpCSumNoFold := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, in[srcAddrOffset:srcAddrOffset+addrLen], in[srcAddrOffset+addrLen:srcAddrOffset+addrLen*2], tcpLenForPseudo)
+ tcpCSum := ^checksum(out[hdr.csumStart:totalLen], tcpCSumNoFold)
+ binary.BigEndian.PutUint16(out[hdr.csumStart+hdr.csumOffset:], tcpCSum)
+
+ nextSegmentDataAt += int(hdr.gsoSize)
+ }
+ return i, nil
+}
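A hypothetical worked example of the segmentation arithmetic above:

	// len(in) = 4060, hdr.hdrLen = 40, hdr.gsoSize = 1360
	// payload = 4060 - 40 = 4020 bytes -> 3 segments:
	//   i=0: 1360 payload, sizes[0] = 1400, seq = firstTCPSeqNum
	//   i=1: 1360 payload, sizes[1] = 1400, seq = firstTCPSeqNum + 1360, IPv4 ID += 1
	//   i=2: 1300 payload, sizes[2] = 1340, seq = firstTCPSeqNum + 2720, IPv4 ID += 2
	// FIN and PSH, if present, survive only on the final segment.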
+
+func gsoNoneChecksum(in []byte, cSumStart, cSumOffset uint16) error {
+ cSumAt := cSumStart + cSumOffset
+ // The initial value at the checksum offset should be summed with the
+ // checksum we compute. This is typically the pseudo-header checksum.
+ initial := binary.BigEndian.Uint16(in[cSumAt:])
+ in[cSumAt], in[cSumAt+1] = 0, 0
+ binary.BigEndian.PutUint16(in[cSumAt:], ^checksum(in[cSumStart:], uint64(initial)))
+ return nil
+}
diff --git a/tun/tcp_offload_linux_test.go b/tun/tcp_offload_linux_test.go
new file mode 100644
index 0000000..7fa0777
--- /dev/null
+++ b/tun/tcp_offload_linux_test.go
@@ -0,0 +1,273 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package tun
+
+import (
+ "net/netip"
+ "testing"
+
+ "golang.org/x/sys/unix"
+ "golang.zx2c4.com/wireguard/conn"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+)
+
+const (
+ offset = virtioNetHdrLen
+)
+
+var (
+ ip4PortA = netip.MustParseAddrPort("192.0.2.1:1")
+ ip4PortB = netip.MustParseAddrPort("192.0.2.2:1")
+ ip4PortC = netip.MustParseAddrPort("192.0.2.3:1")
+ ip6PortA = netip.MustParseAddrPort("[2001:db8::1]:1")
+ ip6PortB = netip.MustParseAddrPort("[2001:db8::2]:1")
+ ip6PortC = netip.MustParseAddrPort("[2001:db8::3]:1")
+)
+
+func tcp4Packet(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32) []byte {
+ totalLen := 40 + segmentSize
+ b := make([]byte, offset+int(totalLen), 65535)
+ ipv4H := header.IPv4(b[offset:])
+ srcAs4 := srcIPPort.Addr().As4()
+ dstAs4 := dstIPPort.Addr().As4()
+ ipv4H.Encode(&header.IPv4Fields{
+ SrcAddr: tcpip.Address(srcAs4[:]),
+ DstAddr: tcpip.Address(dstAs4[:]),
+ Protocol: unix.IPPROTO_TCP,
+ TTL: 64,
+ TotalLength: uint16(totalLen),
+ })
+ tcpH := header.TCP(b[offset+20:])
+ tcpH.Encode(&header.TCPFields{
+ SrcPort: srcIPPort.Port(),
+ DstPort: dstIPPort.Port(),
+ SeqNum: seq,
+ AckNum: 1,
+ DataOffset: 20,
+ Flags: flags,
+ WindowSize: 3000,
+ })
+ ipv4H.SetChecksum(^ipv4H.CalculateChecksum())
+ pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_TCP, ipv4H.SourceAddress(), ipv4H.DestinationAddress(), uint16(20+segmentSize))
+ tcpH.SetChecksum(^tcpH.CalculateChecksum(pseudoCsum))
+ return b
+}
+
+func tcp6Packet(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32) []byte {
+ totalLen := 60 + segmentSize
+ b := make([]byte, offset+int(totalLen), 65535)
+ ipv6H := header.IPv6(b[offset:])
+ srcAs16 := srcIPPort.Addr().As16()
+ dstAs16 := dstIPPort.Addr().As16()
+ ipv6H.Encode(&header.IPv6Fields{
+ SrcAddr: tcpip.Address(srcAs16[:]),
+ DstAddr: tcpip.Address(dstAs16[:]),
+ TransportProtocol: unix.IPPROTO_TCP,
+ HopLimit: 64,
+ PayloadLength: uint16(segmentSize + 20),
+ })
+ tcpH := header.TCP(b[offset+40:])
+ tcpH.Encode(&header.TCPFields{
+ SrcPort: srcIPPort.Port(),
+ DstPort: dstIPPort.Port(),
+ SeqNum: seq,
+ AckNum: 1,
+ DataOffset: 20,
+ Flags: flags,
+ WindowSize: 3000,
+ })
+ pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_TCP, ipv6H.SourceAddress(), ipv6H.DestinationAddress(), uint16(20+segmentSize))
+ tcpH.SetChecksum(^tcpH.CalculateChecksum(pseudoCsum))
+ return b
+}
+
+func Test_handleVirtioRead(t *testing.T) {
+ tests := []struct {
+ name string
+ hdr virtioNetHdr
+ pktIn []byte
+ wantLens []int
+ wantErr bool
+ }{
+ {
+ "tcp4",
+ virtioNetHdr{
+ flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
+ gsoType: unix.VIRTIO_NET_HDR_GSO_TCPV4,
+ gsoSize: 100,
+ hdrLen: 40,
+ csumStart: 20,
+ csumOffset: 16,
+ },
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck|header.TCPFlagPsh, 200, 1),
+ []int{140, 140},
+ false,
+ },
+ {
+ "tcp6",
+ virtioNetHdr{
+ flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
+ gsoType: unix.VIRTIO_NET_HDR_GSO_TCPV6,
+ gsoSize: 100,
+ hdrLen: 60,
+ csumStart: 40,
+ csumOffset: 16,
+ },
+ tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck|header.TCPFlagPsh, 200, 1),
+ []int{160, 160},
+ false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ out := make([][]byte, conn.DefaultBatchSize)
+ sizes := make([]int, conn.DefaultBatchSize)
+ for i := range out {
+ out[i] = make([]byte, 65535)
+ }
+ tt.hdr.encode(tt.pktIn)
+ n, err := handleVirtioRead(tt.pktIn, out, sizes, offset)
+ if err != nil {
+ if tt.wantErr {
+ return
+ }
+ t.Fatalf("got err: %v", err)
+ }
+ if n != len(tt.wantLens) {
+ t.Fatalf("got %d packets, wanted %d", n, len(tt.wantLens))
+ }
+ for i := range tt.wantLens {
+ if tt.wantLens[i] != sizes[i] {
+ t.Fatalf("wantLens[%d]: %d != outSizes: %d", i, tt.wantLens[i], sizes[i])
+ }
+ }
+ })
+ }
+}
+
+func flipTCP4Checksum(b []byte) []byte {
+ at := virtioNetHdrLen + 20 + 16 // 20 byte ipv4 header; tcp csum offset is 16
+ b[at] ^= 0xFF
+ b[at+1] ^= 0xFF
+ return b
+}
+
+func Fuzz_handleGRO(f *testing.F) {
+ pkt0 := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1)
+ pkt1 := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101)
+ pkt2 := tcp4Packet(ip4PortA, ip4PortC, header.TCPFlagAck, 100, 201)
+ pkt3 := tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1)
+ pkt4 := tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101)
+ pkt5 := tcp6Packet(ip6PortA, ip6PortC, header.TCPFlagAck, 100, 201)
+ f.Add(pkt0, pkt1, pkt2, pkt3, pkt4, pkt5, offset)
+ f.Fuzz(func(t *testing.T, pkt0, pkt1, pkt2, pkt3, pkt4, pkt5 []byte, offset int) {
+ pkts := [][]byte{pkt0, pkt1, pkt2, pkt3, pkt4, pkt5}
+ toWrite := make([]int, 0, len(pkts))
+ handleGRO(pkts, offset, newTCPGROTable(), newTCPGROTable(), &toWrite)
+ if len(toWrite) > len(pkts) {
+ t.Errorf("len(toWrite): %d > len(pkts): %d", len(toWrite), len(pkts))
+ }
+ seenWriteI := make(map[int]bool)
+ for _, writeI := range toWrite {
+ if writeI < 0 || writeI > len(pkts)-1 {
+ t.Errorf("toWrite value (%d) outside bounds of len(pkts): %d", writeI, len(pkts))
+ }
+ if seenWriteI[writeI] {
+ t.Errorf("duplicate toWrite value: %d", writeI)
+ }
+ seenWriteI[writeI] = true
+ }
+ })
+}
+
+func Test_handleGRO(t *testing.T) {
+ tests := []struct {
+ name string
+ pktsIn [][]byte
+ wantToWrite []int
+ wantLens []int
+ wantErr bool
+ }{
+ {
+ "multiple flows",
+ [][]byte{
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // v4 flow 1
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // v4 flow 1
+ tcp4Packet(ip4PortA, ip4PortC, header.TCPFlagAck, 100, 201), // v4 flow 2
+ tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), // v6 flow 1
+ tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101), // v6 flow 1
+ tcp6Packet(ip6PortA, ip6PortC, header.TCPFlagAck, 100, 201), // v6 flow 2
+ },
+ []int{0, 2, 3, 5},
+ []int{240, 140, 260, 160},
+ false,
+ },
+ {
+ "PSH interleaved",
+ [][]byte{
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // v4 flow 1
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck|header.TCPFlagPsh, 100, 101), // v4 flow 1
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 301), // v4 flow 1
+ tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), // v6 flow 1
+ tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck|header.TCPFlagPsh, 100, 101), // v6 flow 1
+ tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 201), // v6 flow 1
+ tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 301), // v6 flow 1
+ },
+ []int{0, 2, 4, 6},
+ []int{240, 240, 260, 260},
+ false,
+ },
+ {
+ "coalesceItemInvalidCSum",
+ [][]byte{
+ flipTCP4Checksum(tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1)), // v4 flow 1 seq 1 len 100
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // v4 flow 1 seq 101 len 100
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1 seq 201 len 100
+ },
+ []int{0, 1},
+ []int{140, 240},
+ false,
+ },
+ {
+ "out of order",
+ [][]byte{
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // v4 flow 1 seq 101 len 100
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // v4 flow 1 seq 1 len 100
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1 seq 201 len 100
+ },
+ []int{0},
+ []int{340},
+ false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ toWrite := make([]int, 0, len(tt.pktsIn))
+ err := handleGRO(tt.pktsIn, offset, newTCPGROTable(), newTCPGROTable(), &toWrite)
+ if err != nil {
+ if tt.wantErr {
+ return
+ }
+ t.Fatalf("got err: %v", err)
+ }
+ if len(toWrite) != len(tt.wantToWrite) {
+ t.Fatalf("got %d packets, wanted %d", len(toWrite), len(tt.wantToWrite))
+ }
+ for i, pktI := range tt.wantToWrite {
+ if tt.wantToWrite[i] != toWrite[i] {
+ t.Fatalf("wantToWrite[%d]: %d != toWrite: %d", i, tt.wantToWrite[i], toWrite[i])
+ }
+ if tt.wantLens[i] != len(tt.pktsIn[pktI][offset:]) {
+ t.Errorf("wanted len %d packet at %d, got: %d", tt.wantLens[i], i, len(tt.pktsIn[pktI][offset:]))
+ }
+ }
+ })
+ }
+}
diff --git a/tun/testdata/fuzz/Fuzz_handleGRO/032aec0105f26f709c118365e4830d6dc087cab24cd1e154c2e790589a309b77 b/tun/testdata/fuzz/Fuzz_handleGRO/032aec0105f26f709c118365e4830d6dc087cab24cd1e154c2e790589a309b77
new file mode 100644
index 0000000..5461e79
--- /dev/null
+++ b/tun/testdata/fuzz/Fuzz_handleGRO/032aec0105f26f709c118365e4830d6dc087cab24cd1e154c2e790589a309b77
@@ -0,0 +1,8 @@
+go test fuzz v1
+[]byte("0")
+[]byte("0")
+[]byte("0")
+[]byte("0")
+[]byte("0")
+[]byte("0")
+int(34)
diff --git a/tun/testdata/fuzz/Fuzz_handleGRO/0da283f9a2098dec30d1c86784411a8ce2e8e03aa3384105e581f2c67494700d b/tun/testdata/fuzz/Fuzz_handleGRO/0da283f9a2098dec30d1c86784411a8ce2e8e03aa3384105e581f2c67494700d
new file mode 100644
index 0000000..b441819
--- /dev/null
+++ b/tun/testdata/fuzz/Fuzz_handleGRO/0da283f9a2098dec30d1c86784411a8ce2e8e03aa3384105e581f2c67494700d
@@ -0,0 +1,8 @@
+go test fuzz v1
+[]byte("0")
+[]byte("0")
+[]byte("0")
+[]byte("0")
+[]byte("0")
+[]byte("0")
+int(-48)
diff --git a/tun/tun_linux.go b/tun/tun_linux.go
index 21984ca..d56e3c1 100644
--- a/tun/tun_linux.go
+++ b/tun/tun_linux.go
@@ -17,9 +17,8 @@ import (
"time"
"unsafe"
- "golang.org/x/net/ipv6"
"golang.org/x/sys/unix"
-
+ "golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/rwcancel"
)
@@ -33,17 +32,25 @@ type NativeTun struct {
index int32 // if index
errors chan error // async error handling
events chan Event // device related events
- nopi bool // the device was passed IFF_NO_PI
netlinkSock int
netlinkCancel *rwcancel.RWCancel
hackListenerClosed sync.Mutex
statusListenersShutdown chan struct{}
+ batchSize int
+ vnetHdr bool
closeOnce sync.Once
nameOnce sync.Once // guards calling initNameCache, which sets following fields
nameCache string // name of interface
nameErr error
+
+ readOpMu sync.Mutex // readOpMu guards readBuff
+ readBuff [virtioNetHdrLen + 65535]byte // if vnetHdr every read() is prefixed by virtioNetHdr
+
+ writeOpMu sync.Mutex // writeOpMu guards toWrite, tcp4GROTable, tcp6GROTable
+ toWrite []int
+ tcp4GROTable, tcp6GROTable *tcpGROTable
}
func (tun *NativeTun) File() *os.File {
@@ -323,60 +330,142 @@ func (tun *NativeTun) nameSlow() (string, error) {
return unix.ByteSliceToString(ifr[:]), nil
}
-func (tun *NativeTun) Write(buffs [][]byte, offset int) (n int, err error) {
- var buf []byte
- if tun.nopi {
- buf = buffs[0][offset:]
+func (tun *NativeTun) Write(buffs [][]byte, offset int) (int, error) {
+ tun.writeOpMu.Lock()
+ defer func() {
+ tun.tcp4GROTable.reset()
+ tun.tcp6GROTable.reset()
+ tun.writeOpMu.Unlock()
+ }()
+ var (
+ errs []error
+ total int
+ )
+ tun.toWrite = tun.toWrite[:0]
+ if tun.vnetHdr {
+ err := handleGRO(buffs, offset, tun.tcp4GROTable, tun.tcp6GROTable, &tun.toWrite)
+ if err != nil {
+ return 0, err
+ }
+ offset -= virtioNetHdrLen
} else {
- // reserve space for header
- buf = buffs[0][offset-4:]
-
- // add packet information header
- buf[0] = 0x00
- buf[1] = 0x00
- if buf[4]>>4 == ipv6.Version {
- buf[2] = 0x86
- buf[3] = 0xdd
+ for i := range buffs {
+ tun.toWrite = append(tun.toWrite, i)
+ }
+ }
+ for _, buffsI := range tun.toWrite {
+ n, err := tun.tunFile.Write(buffs[buffsI][offset:])
+ if errors.Is(err, syscall.EBADFD) {
+ return total, os.ErrClosed
+ }
+ if err != nil {
+ errs = append(errs, err)
} else {
- buf[2] = 0x08
- buf[3] = 0x00
+ total += n
+ }
+ }
+ return total, ErrorBatch(errs)
+}
+
+// handleVirtioRead splits in into buffs, leaving offset bytes at the front of
+// each buffer. It mutates sizes to reflect the size of each element of buffs,
+// and returns the number of packets read.
+func handleVirtioRead(in []byte, buffs [][]byte, sizes []int, offset int) (int, error) {
+ var hdr virtioNetHdr
+ err := hdr.decode(in)
+ if err != nil {
+ return 0, err
+ }
+ in = in[virtioNetHdrLen:]
+ if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_NONE {
+ if hdr.flags&unix.VIRTIO_NET_HDR_F_NEEDS_CSUM != 0 {
+ // This means CHECKSUM_PARTIAL in skb context. We are responsible
+ // for computing the checksum starting at hdr.csumStart and placing
+ // at hdr.csumOffset.
+ err = gsoNoneChecksum(in, hdr.csumStart, hdr.csumOffset)
+ if err != nil {
+ return 0, err
+ }
+ }
+ if len(in) > len(buffs[0][offset:]) {
+ return 0, fmt.Errorf("read len %d overflows buffs element len %d", len(in), len(buffs[0][offset:]))
}
+ n := copy(buffs[0][offset:], in)
+ sizes[0] = n
+ return 1, nil
+ }
+ if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 {
+ return 0, fmt.Errorf("unsupported virtio GSO type: %d", hdr.gsoType)
+ }
+
+ ipVersion := in[0] >> 4
+ switch ipVersion {
+ case 4:
+ if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 {
+ return 0, fmt.Errorf("ip header version: %d, GSO type: %d", ipVersion, hdr.gsoType)
+ }
+ case 6:
+ if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 {
+ return 0, fmt.Errorf("ip header version: %d, GSO type: %d", ipVersion, hdr.gsoType)
+ }
+ default:
+ return 0, fmt.Errorf("invalid ip header version: %d", ipVersion)
+ }
+
+ if len(in) <= int(hdr.csumStart+12) {
+ return 0, errors.New("packet is too short")
+ }
+ // Don't trust hdr.hdrLen from the kernel as it can be equal to the length
+ // of the entire first packet when the kernel is handling it as part of a
+ // FORWARD path. Instead, parse the TCP header length and add it onto
+ // csumStart, which is synonymous for IP header length.
+ tcpHLen := uint16(in[hdr.csumStart+12] >> 4 * 4)
+ if tcpHLen < 20 || tcpHLen > 60 {
+ // A TCP header must be between 20 and 60 bytes in length.
+ return 0, fmt.Errorf("tcp header len is invalid: %d", tcpHLen)
+ }
+ hdr.hdrLen = hdr.csumStart + tcpHLen
+
+ if len(in) < int(hdr.hdrLen) {
+ return 0, fmt.Errorf("length of packet (%d) < virtioNetHdr.hdrLen (%d)", len(in), hdr.hdrLen)
}
- _, err = tun.tunFile.Write(buf)
- if errors.Is(err, syscall.EBADFD) {
- err = os.ErrClosed
- } else if err == nil {
- n = 1
+ if hdr.hdrLen < hdr.csumStart {
+ return 0, fmt.Errorf("virtioNetHdr.hdrLen (%d) < virtioNetHdr.csumStart (%d)", hdr.hdrLen, hdr.csumStart)
}
- return n, err
+ cSumAt := int(hdr.csumStart + hdr.csumOffset)
+ if cSumAt+1 >= len(in) {
+ return 0, fmt.Errorf("end of checksum offset (%d) exceeds packet length (%d)", cSumAt+1, len(in))
+ }
+
+ return tcpTSO(in, hdr, buffs, sizes, offset)
}
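Tying this to the "tcp4" case in the tests above, the values work out as follows:

	// hdr: flags=VIRTIO_NET_HDR_F_NEEDS_CSUM, gsoType=VIRTIO_NET_HDR_GSO_TCPV4,
	//      gsoSize=100, hdrLen=40, csumStart=20, csumOffset=16
	// in:  240 bytes = 20B IPv4 header + 20B TCP header + 200B payload
	// tcpHLen is re-derived from the TCP data-offset nibble: 5 * 4 = 20, so
	// hdr.hdrLen = csumStart + tcpHLen = 40, and tcpTSO emits two packets
	// with sizes[0] = sizes[1] = 140.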
-func (tun *NativeTun) Read(buffs [][]byte, sizes []int, offset int) (n int, err error) {
+func (tun *NativeTun) Read(buffs [][]byte, sizes []int, offset int) (int, error) {
+ tun.readOpMu.Lock()
+ defer tun.readOpMu.Unlock()
select {
- case err = <-tun.errors:
+ case err := <-tun.errors:
+ return 0, err
default:
- if tun.nopi {
- sizes[0], err = tun.tunFile.Read(buffs[0][offset:])
- if err == nil {
- n = 1
- }
+ readInto := buffs[0][offset:]
+ if tun.vnetHdr {
+ readInto = tun.readBuff[:]
+ }
+ n, err := tun.tunFile.Read(readInto)
+ if errors.Is(err, syscall.EBADFD) {
+ err = os.ErrClosed
+ }
+ if err != nil {
+ return 0, err
+ }
+ if tun.vnetHdr {
+ return handleVirtioRead(readInto[:n], buffs, sizes, offset)
} else {
- buff := buffs[0][offset-4:]
- sizes[0], err = tun.tunFile.Read(buff[:])
- if errors.Is(err, syscall.EBADFD) {
- err = os.ErrClosed
- } else if err == nil {
- n = 1
- }
- if sizes[0] < 4 {
- sizes[0] = 0
- } else {
- sizes[0] -= 4
- }
+ sizes[0] = n
+ return 1, nil
}
}
- return
}
func (tun *NativeTun) Events() <-chan Event {
@@ -403,9 +492,49 @@ func (tun *NativeTun) Close() error {
}
func (tun *NativeTun) BatchSize() int {
- return 1
+ return tun.batchSize
}
+const (
+ // TODO: support TSO with ECN bits
+ tunOffloads = unix.TUN_F_CSUM | unix.TUN_F_TSO4 | unix.TUN_F_TSO6
+)
+
+func (tun *NativeTun) initFromFlags(name string) error {
+ sc, err := tun.tunFile.SyscallConn()
+ if err != nil {
+ return err
+ }
+ if e := sc.Control(func(fd uintptr) {
+ var (
+ ifr *unix.Ifreq
+ )
+ ifr, err = unix.NewIfreq(name)
+ if err != nil {
+ return
+ }
+ err = unix.IoctlIfreq(int(fd), unix.TUNGETIFF, ifr)
+ if err != nil {
+ return
+ }
+ got := ifr.Uint16()
+ if got&unix.IFF_VNET_HDR != 0 {
+ err = unix.IoctlSetInt(int(fd), unix.TUNSETOFFLOAD, tunOffloads)
+ if err != nil {
+ return
+ }
+ tun.vnetHdr = true
+ tun.batchSize = conn.DefaultBatchSize
+ } else {
+ tun.batchSize = 1
+ }
+ }); e != nil {
+ return e
+ }
+ return err
+}
+
+// CreateTUN creates a Device with the provided name and MTU.
func CreateTUN(name string, mtu int) (Device, error) {
nfd, err := unix.Open(cloneDevicePath, unix.O_RDWR|unix.O_CLOEXEC, 0)
if err != nil {
@@ -415,25 +544,16 @@ func CreateTUN(name string, mtu int) (Device, error) {
return nil, err
}
- var ifr [ifReqSize]byte
- var flags uint16 = unix.IFF_TUN // | unix.IFF_NO_PI (disabled for TUN status hack)
- nameBytes := []byte(name)
- if len(nameBytes) >= unix.IFNAMSIZ {
- unix.Close(nfd)
- return nil, fmt.Errorf("interface name too long: %w", unix.ENAMETOOLONG)
+ ifr, err := unix.NewIfreq(name)
+ if err != nil {
+ return nil, err
}
- copy(ifr[:], nameBytes)
- *(*uint16)(unsafe.Pointer(&ifr[unix.IFNAMSIZ])) = flags
-
- _, _, errno := unix.Syscall(
- unix.SYS_IOCTL,
- uintptr(nfd),
- uintptr(unix.TUNSETIFF),
- uintptr(unsafe.Pointer(&ifr[0])),
- )
- if errno != 0 {
- unix.Close(nfd)
- return nil, errno
+ // IFF_VNET_HDR enables the "tun status hack" via routineHackListener()
+ // where a null write will return EINVAL indicating the TUN is up.
+ ifr.SetUint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_VNET_HDR)
+ err = unix.IoctlIfreq(nfd, unix.TUNSETIFF, ifr)
+ if err != nil {
+ return nil, err
}
err = unix.SetNonblock(nfd, true)
@@ -448,13 +568,16 @@ func CreateTUN(name string, mtu int) (Device, error) {
return CreateTUNFromFile(fd, mtu)
}
+// CreateTUNFromFile creates a Device from an os.File with the provided MTU.
func CreateTUNFromFile(file *os.File, mtu int) (Device, error) {
tun := &NativeTun{
tunFile: file,
events: make(chan Event, 5),
errors: make(chan error, 5),
statusListenersShutdown: make(chan struct{}),
- nopi: false,
+ tcp4GROTable: newTCPGROTable(),
+ tcp6GROTable: newTCPGROTable(),
+ toWrite: make([]int, 0, conn.DefaultBatchSize),
}
name, err := tun.Name()
@@ -462,8 +585,12 @@ func CreateTUNFromFile(file *os.File, mtu int) (Device, error) {
return nil, err
}
- // start event listener
+ err = tun.initFromFlags(name)
+ if err != nil {
+ return nil, err
+ }
+ // start event listener
tun.index, err = getIFIndex(name)
if err != nil {
return nil, err
@@ -492,6 +619,8 @@ func CreateTUNFromFile(file *os.File, mtu int) (Device, error) {
return tun, nil
}
+// CreateUnmonitoredTUNFromFD creates a Device from the provided file
+// descriptor.
func CreateUnmonitoredTUNFromFD(fd int) (Device, string, error) {
err := unix.SetNonblock(fd, true)
if err != nil {
@@ -499,14 +628,20 @@ func CreateUnmonitoredTUNFromFD(fd int) (Device, string, error) {
}
file := os.NewFile(uintptr(fd), "/dev/tun")
tun := &NativeTun{
- tunFile: file,
- events: make(chan Event, 5),
- errors: make(chan error, 5),
- nopi: true,
+ tunFile: file,
+ events: make(chan Event, 5),
+ errors: make(chan error, 5),
+ tcp4GROTable: newTCPGROTable(),
+ tcp6GROTable: newTCPGROTable(),
+ toWrite: make([]int, 0, conn.DefaultBatchSize),
}
name, err := tun.Name()
if err != nil {
return nil, "", err
}
- return tun, name, nil
+ err = tun.initFromFlags(name)
+ if err != nil {
+ return nil, "", err
+ }
+ return tun, name, err
}
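Finally, a sketch of the write side under the same assumptions as the read sketch near the top (tun.Device and its Write method are from this series; the wrapper itself is hypothetical):

	// writeBatch pushes a vector of packets to the device in one call. Each
	// element must carry `offset` bytes of headroom so the GRO path can
	// prepend a virtio_net_hdr; per the Write implementation above, the
	// return value is the total number of bytes written across the batch.
	func writeBatch(dev tun.Device, pkts [][]byte, offset int) (int, error) {
		return dev.Write(pkts, offset)
	}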