diff options
Diffstat (limited to 'tun')
-rw-r--r-- | tun/checksum.go | 42 | ||||
-rw-r--r-- | tun/tcp_offload_linux.go | 612 | ||||
-rw-r--r-- | tun/tcp_offload_linux_test.go | 273 | ||||
-rw-r--r-- | tun/testdata/fuzz/Fuzz_handleGRO/032aec0105f26f709c118365e4830d6dc087cab24cd1e154c2e790589a309b77 | 8 | ||||
-rw-r--r-- | tun/testdata/fuzz/Fuzz_handleGRO/0da283f9a2098dec30d1c86784411a8ce2e8e03aa3384105e581f2c67494700d | 8 | ||||
-rw-r--r-- | tun/tun_linux.go | 275 |
6 files changed, 1148 insertions, 70 deletions
diff --git a/tun/checksum.go b/tun/checksum.go new file mode 100644 index 0000000..f4f8471 --- /dev/null +++ b/tun/checksum.go @@ -0,0 +1,42 @@ +package tun + +import "encoding/binary" + +// TODO: Explore SIMD and/or other assembly optimizations. +func checksumNoFold(b []byte, initial uint64) uint64 { + ac := initial + i := 0 + n := len(b) + for n >= 4 { + ac += uint64(binary.BigEndian.Uint32(b[i : i+4])) + n -= 4 + i += 4 + } + for n >= 2 { + ac += uint64(binary.BigEndian.Uint16(b[i : i+2])) + n -= 2 + i += 2 + } + if n == 1 { + ac += uint64(b[i]) << 8 + } + return ac +} + +func checksum(b []byte, initial uint64) uint16 { + ac := checksumNoFold(b, initial) + ac = (ac >> 16) + (ac & 0xffff) + ac = (ac >> 16) + (ac & 0xffff) + ac = (ac >> 16) + (ac & 0xffff) + ac = (ac >> 16) + (ac & 0xffff) + return uint16(ac) +} + +func pseudoHeaderChecksumNoFold(protocol uint8, srcAddr, dstAddr []byte, totalLen uint16) uint64 { + sum := checksumNoFold(srcAddr, 0) + sum = checksumNoFold(dstAddr, sum) + sum = checksumNoFold([]byte{0, protocol}, sum) + tmp := make([]byte, 2) + binary.BigEndian.PutUint16(tmp, totalLen) + return checksumNoFold(tmp, sum) +} diff --git a/tun/tcp_offload_linux.go b/tun/tcp_offload_linux.go new file mode 100644 index 0000000..f3ffa75 --- /dev/null +++ b/tun/tcp_offload_linux.go @@ -0,0 +1,612 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package tun + +import ( + "bytes" + "encoding/binary" + "errors" + "io" + "unsafe" + + "golang.org/x/sys/unix" + "golang.zx2c4.com/wireguard/conn" +) + +const tcpFlagsOffset = 13 + +const ( + tcpFlagFIN uint8 = 0x01 + tcpFlagPSH uint8 = 0x08 + tcpFlagACK uint8 = 0x10 +) + +// virtioNetHdr is defined in the kernel in include/uapi/linux/virtio_net.h. The +// kernel symbol is virtio_net_hdr. +type virtioNetHdr struct { + flags uint8 + gsoType uint8 + hdrLen uint16 + gsoSize uint16 + csumStart uint16 + csumOffset uint16 +} + +func (v *virtioNetHdr) decode(b []byte) error { + if len(b) < virtioNetHdrLen { + return io.ErrShortBuffer + } + copy(unsafe.Slice((*byte)(unsafe.Pointer(v)), virtioNetHdrLen), b[:virtioNetHdrLen]) + return nil +} + +func (v *virtioNetHdr) encode(b []byte) error { + if len(b) < virtioNetHdrLen { + return io.ErrShortBuffer + } + copy(b[:virtioNetHdrLen], unsafe.Slice((*byte)(unsafe.Pointer(v)), virtioNetHdrLen)) + return nil +} + +const ( + // virtioNetHdrLen is the length in bytes of virtioNetHdr. This matches the + // shape of the C ABI for its kernel counterpart -- sizeof(virtio_net_hdr). + virtioNetHdrLen = int(unsafe.Sizeof(virtioNetHdr{})) +) + +// flowKey represents the key for a flow. +type flowKey struct { + srcAddr, dstAddr [16]byte + srcPort, dstPort uint16 + rxAck uint32 // varying ack values should not be coalesced. Treat them as separate flows. +} + +// tcpGROTable holds flow and coalescing information for the purposes of GRO. +type tcpGROTable struct { + itemsByFlow map[flowKey][]tcpGROItem + itemsPool [][]tcpGROItem +} + +func newTCPGROTable() *tcpGROTable { + t := &tcpGROTable{ + itemsByFlow: make(map[flowKey][]tcpGROItem, conn.DefaultBatchSize), + itemsPool: make([][]tcpGROItem, conn.DefaultBatchSize), + } + for i := range t.itemsPool { + t.itemsPool[i] = make([]tcpGROItem, 0, conn.DefaultBatchSize) + } + return t +} + +func newFlowKey(pkt []byte, srcAddr, dstAddr, tcphOffset int) flowKey { + key := flowKey{} + addrSize := dstAddr - srcAddr + copy(key.srcAddr[:], pkt[srcAddr:dstAddr]) + copy(key.dstAddr[:], pkt[dstAddr:dstAddr+addrSize]) + key.srcPort = binary.BigEndian.Uint16(pkt[tcphOffset:]) + key.dstPort = binary.BigEndian.Uint16(pkt[tcphOffset+2:]) + key.rxAck = binary.BigEndian.Uint32(pkt[tcphOffset+8:]) + return key +} + +// lookupOrInsert looks up a flow for the provided packet and metadata, +// returning the packets found for the flow, or inserting a new one if none +// is found. +func (t *tcpGROTable) lookupOrInsert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, buffsIndex int) ([]tcpGROItem, bool) { + key := newFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset) + items, ok := t.itemsByFlow[key] + if ok { + return items, ok + } + // TODO: insert() performs another map lookup. This could be rearranged to avoid. + t.insert(pkt, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, buffsIndex) + return nil, false +} + +// insert an item in the table for the provided packet and packet metadata. +func (t *tcpGROTable) insert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, buffsIndex int) { + key := newFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset) + item := tcpGROItem{ + key: key, + buffsIndex: uint16(buffsIndex), + gsoSize: uint16(len(pkt[tcphOffset+tcphLen:])), + iphLen: uint8(tcphOffset), + tcphLen: uint8(tcphLen), + sentSeq: binary.BigEndian.Uint32(pkt[tcphOffset+4:]), + pshSet: pkt[tcphOffset+tcpFlagsOffset]&tcpFlagPSH != 0, + } + items, ok := t.itemsByFlow[key] + if !ok { + items = t.newItems() + } + items = append(items, item) + t.itemsByFlow[key] = items +} + +func (t *tcpGROTable) updateAt(item tcpGROItem, i int) { + items, _ := t.itemsByFlow[item.key] + items[i] = item +} + +func (t *tcpGROTable) deleteAt(key flowKey, i int) { + items, _ := t.itemsByFlow[key] + items = append(items[:i], items[i+1:]...) + t.itemsByFlow[key] = items +} + +// tcpGROItem represents bookkeeping data for a TCP packet during the lifetime +// of a GRO evaluation across a vector of packets. +type tcpGROItem struct { + key flowKey + sentSeq uint32 // the sequence number + buffsIndex uint16 // the index into the original buffs slice + numMerged uint16 // the number of packets merged into this item + gsoSize uint16 // payload size + iphLen uint8 // ip header len + tcphLen uint8 // tcp header len + pshSet bool // psh flag is set +} + +func (t *tcpGROTable) newItems() []tcpGROItem { + var items []tcpGROItem + items, t.itemsPool = t.itemsPool[len(t.itemsPool)-1], t.itemsPool[:len(t.itemsPool)-1] + return items +} + +func (t *tcpGROTable) reset() { + for k, items := range t.itemsByFlow { + items = items[:0] + t.itemsPool = append(t.itemsPool, items) + delete(t.itemsByFlow, k) + } +} + +// canCoalesce represents the outcome of checking if two TCP packets are +// candidates for coalescing. +type canCoalesce int + +const ( + coalescePrepend canCoalesce = -1 + coalesceUnavailable canCoalesce = 0 + coalesceAppend canCoalesce = 1 +) + +// tcpPacketsCanCoalesce evaluates if pkt can be coalesced with the packet +// described by item. This function makes considerations that match the kernel's +// GRO self tests, which can be found in tools/testing/selftests/net/gro.c. +func tcpPacketsCanCoalesce(pkt []byte, iphLen, tcphLen uint8, seq uint32, pshSet bool, gsoSize uint16, item tcpGROItem, buffs [][]byte, buffsOffset int) canCoalesce { + pktTarget := buffs[item.buffsIndex][buffsOffset:] + if tcphLen != item.tcphLen { + // cannot coalesce with unequal tcp options len + return coalesceUnavailable + } + if tcphLen > 20 { + if !bytes.Equal(pkt[iphLen+20:iphLen+tcphLen], pktTarget[item.iphLen+20:iphLen+tcphLen]) { + // cannot coalesce with unequal tcp options + return coalesceUnavailable + } + } + if pkt[1] != pktTarget[1] { + // cannot coalesce with unequal ToS values + return coalesceUnavailable + } + if pkt[6]>>5 != pktTarget[6]>>5 { + // cannot coalesce with unequal DF or reserved bits. MF is checked + // further up the stack. + return coalesceUnavailable + } + // seq adjacency + lhsLen := item.gsoSize + lhsLen += item.numMerged * item.gsoSize + if seq == item.sentSeq+uint32(lhsLen) { // pkt aligns following item from a seq num perspective + if item.pshSet { + // We cannot append to a segment that has the PSH flag set, PSH + // can only be set on the final segment in a reassembled group. + return coalesceUnavailable + } + if len(pktTarget[iphLen+tcphLen:])%int(item.gsoSize) != 0 { + // A smaller than gsoSize packet has been appended previously. + // Nothing can come after a smaller packet on the end. + return coalesceUnavailable + } + if gsoSize > item.gsoSize { + // We cannot have a larger packet following a smaller one. + return coalesceUnavailable + } + return coalesceAppend + } else if seq+uint32(gsoSize) == item.sentSeq { // pkt aligns in front of item from a seq num perspective + if pshSet { + // We cannot prepend with a segment that has the PSH flag set, PSH + // can only be set on the final segment in a reassembled group. + return coalesceUnavailable + } + if gsoSize < item.gsoSize { + // We cannot have a larger packet following a smaller one. + return coalesceUnavailable + } + if gsoSize > item.gsoSize && item.numMerged > 0 { + // There's at least one previous merge, and we're larger than all + // previous. This would put multiple smaller packets on the end. + return coalesceUnavailable + } + return coalescePrepend + } + return coalesceUnavailable +} + +func tcpChecksumValid(pkt []byte, iphLen uint8, isV6 bool) bool { + srcAddrAt := ipv4SrcAddrOffset + addrSize := 4 + if isV6 { + srcAddrAt = ipv6SrcAddrOffset + addrSize = 16 + } + tcpTotalLen := uint16(len(pkt) - int(iphLen)) + tcpCSumNoFold := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, pkt[srcAddrAt:srcAddrAt+addrSize], pkt[srcAddrAt+addrSize:srcAddrAt+addrSize*2], tcpTotalLen) + return ^checksum(pkt[iphLen:], tcpCSumNoFold) == 0 +} + +// coalesceResult represents the result of attempting to coalesce two TCP +// packets. +type coalesceResult int + +const ( + coalesceInsufficientCap coalesceResult = 0 + coalescePSHEnding coalesceResult = 1 + coalesceItemInvalidCSum coalesceResult = 2 + coalescePktInvalidCSum coalesceResult = 3 + coalesceSuccess coalesceResult = 4 +) + +// coalesceTCPPackets attempts to coalesce pkt with the packet described by +// item, returning the outcome. This function may swap buffs elements in the +// event of a prepend as item's buffs index is already being tracked for writing +// to a Device. +func coalesceTCPPackets(mode canCoalesce, pkt []byte, pktBuffsIndex int, gsoSize uint16, seq uint32, pshSet bool, item *tcpGROItem, buffs [][]byte, buffsOffset int, isV6 bool) coalesceResult { + var pktHead []byte // the packet that will end up at the front + headersLen := item.iphLen + item.tcphLen + coalescedLen := len(buffs[item.buffsIndex][buffsOffset:]) + len(pkt) - int(headersLen) + + // Copy data + if mode == coalescePrepend { + pktHead = pkt + if cap(pkt)-buffsOffset < coalescedLen { + // We don't want to allocate a new underlying array if capacity is + // too small. + return coalesceInsufficientCap + } + if pshSet { + return coalescePSHEnding + } + if item.numMerged == 0 { + if !tcpChecksumValid(buffs[item.buffsIndex][buffsOffset:], item.iphLen, isV6) { + return coalesceItemInvalidCSum + } + } + if !tcpChecksumValid(pkt, item.iphLen, isV6) { + return coalescePktInvalidCSum + } + item.sentSeq = seq + extendBy := coalescedLen - len(pktHead) + buffs[pktBuffsIndex] = append(buffs[pktBuffsIndex], make([]byte, extendBy)...) + copy(buffs[pktBuffsIndex][buffsOffset+len(pkt):], buffs[item.buffsIndex][buffsOffset+int(headersLen):]) + // Flip the slice headers in buffs as part of prepend. The index of item + // is already being tracked for writing. + buffs[item.buffsIndex], buffs[pktBuffsIndex] = buffs[pktBuffsIndex], buffs[item.buffsIndex] + } else { + pktHead = buffs[item.buffsIndex][buffsOffset:] + if cap(pktHead)-buffsOffset < coalescedLen { + // We don't want to allocate a new underlying array if capacity is + // too small. + return coalesceInsufficientCap + } + if item.numMerged == 0 { + if !tcpChecksumValid(buffs[item.buffsIndex][buffsOffset:], item.iphLen, isV6) { + return coalesceItemInvalidCSum + } + } + if !tcpChecksumValid(pkt, item.iphLen, isV6) { + return coalescePktInvalidCSum + } + if pshSet { + // We are appending a segment with PSH set. + item.pshSet = pshSet + pktHead[item.iphLen+tcpFlagsOffset] |= tcpFlagPSH + } + extendBy := len(pkt) - int(headersLen) + buffs[item.buffsIndex] = append(buffs[item.buffsIndex], make([]byte, extendBy)...) + copy(buffs[item.buffsIndex][buffsOffset+len(pktHead):], pkt[headersLen:]) + } + + if gsoSize > item.gsoSize { + item.gsoSize = gsoSize + } + hdr := virtioNetHdr{ + flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, // this turns into CHECKSUM_PARTIAL in the skb + hdrLen: uint16(headersLen), + gsoSize: uint16(item.gsoSize), + csumStart: uint16(item.iphLen), + csumOffset: 16, + } + + // Recalculate the total len (IPv4) or payload len (IPv6). Recalculate the + // (IPv4) header checksum. + if isV6 { + hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV6 + binary.BigEndian.PutUint16(pktHead[4:], uint16(coalescedLen)-uint16(item.iphLen)) // set new payload len + } else { + hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV4 + pktHead[10], pktHead[11] = 0, 0 // clear checksum field + binary.BigEndian.PutUint16(pktHead[2:], uint16(coalescedLen)) // set new total length + iphCSum := ^checksum(pktHead[:item.iphLen], 0) // compute checksum + binary.BigEndian.PutUint16(pktHead[10:], iphCSum) // set checksum field + } + hdr.encode(buffs[item.buffsIndex][buffsOffset-virtioNetHdrLen:]) + + // Calculate the pseudo header checksum and place it at the TCP checksum + // offset. Downstream checksum offloading will combine this with computation + // of the tcp header and payload checksum. + addrLen := 4 + addrOffset := ipv4SrcAddrOffset + if isV6 { + addrLen = 16 + addrOffset = ipv6SrcAddrOffset + } + srcAddrAt := buffsOffset + addrOffset + srcAddr := buffs[item.buffsIndex][srcAddrAt : srcAddrAt+addrLen] + dstAddr := buffs[item.buffsIndex][srcAddrAt+addrLen : srcAddrAt+addrLen*2] + psum := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, srcAddr, dstAddr, uint16(coalescedLen-int(item.iphLen))) + binary.BigEndian.PutUint16(pktHead[hdr.csumStart+hdr.csumOffset:], checksum([]byte{}, psum)) + + item.numMerged++ + return coalesceSuccess +} + +const ( + ipv4FlagMoreFragments = 0x80 +) + +const ( + ipv4SrcAddrOffset = 12 + ipv6SrcAddrOffset = 8 + maxUint16 = 1<<16 - 1 +) + +// tcpGRO evaluates the TCP packet at pktI in buffs for coalescing with +// existing packets tracked in table. It will return false when pktI is not +// coalesced, otherwise true. This indicates to the caller if buffs[pktI] +// should be written to the Device. +func tcpGRO(buffs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool) (pktCoalesced bool) { + pkt := buffs[pktI][offset:] + if len(pkt) > maxUint16 { + // A valid IPv4 or IPv6 packet will never exceed this. + return false + } + iphLen := int((pkt[0] & 0x0F) * 4) + if isV6 { + iphLen = 40 + ipv6HPayloadLen := int(binary.BigEndian.Uint16(pkt[4:])) + if ipv6HPayloadLen != len(pkt)-iphLen { + return false + } + } else { + totalLen := int(binary.BigEndian.Uint16(pkt[2:])) + if totalLen != len(pkt) { + return false + } + if iphLen < 20 || iphLen > 60 { + return false + } + } + if len(pkt) < iphLen { + return false + } + tcphLen := int((pkt[iphLen+12] >> 4) * 4) + if tcphLen < 20 || tcphLen > 60 { + return false + } + if len(pkt) < iphLen+tcphLen { + return false + } + if !isV6 { + if pkt[6]&ipv4FlagMoreFragments != 0 || (pkt[6]<<3 != 0 || pkt[7] != 0) { + // no GRO support for fragmented segments for now + return false + } + } + tcpFlags := pkt[iphLen+tcpFlagsOffset] + var pshSet bool + // not a candidate if any non-ACK flags (except PSH+ACK) are set + if tcpFlags != tcpFlagACK { + if pkt[iphLen+tcpFlagsOffset] != tcpFlagACK|tcpFlagPSH { + return false + } + pshSet = true + } + gsoSize := uint16(len(pkt) - tcphLen - iphLen) + // not a candidate if payload len is 0 + if gsoSize < 1 { + return false + } + seq := binary.BigEndian.Uint32(pkt[iphLen+4:]) + srcAddrOffset := ipv4SrcAddrOffset + addrLen := 4 + if isV6 { + srcAddrOffset = ipv6SrcAddrOffset + addrLen = 16 + } + items, existing := table.lookupOrInsert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI) + if !existing { + return false + } + for i := len(items) - 1; i >= 0; i-- { + // In the best case of packets arriving in order iterating in reverse is + // more efficient if there are multiple items for a given flow. This + // also enables a natural table.deleteAt() in the + // coalesceItemInvalidCSum case without the need for index tracking. + // This algorithm makes a best effort to coalesce in the event of + // unordered packets, where pkt may land anywhere in items from a + // sequence number perspective, however once an item is inserted into + // the table it is never compared across other items later. + item := items[i] + can := tcpPacketsCanCoalesce(pkt, uint8(iphLen), uint8(tcphLen), seq, pshSet, gsoSize, item, buffs, offset) + if can != coalesceUnavailable { + result := coalesceTCPPackets(can, pkt, pktI, gsoSize, seq, pshSet, &item, buffs, offset, isV6) + switch result { + case coalesceSuccess: + table.updateAt(item, i) + return true + case coalesceItemInvalidCSum: + // delete the item with an invalid csum + table.deleteAt(item.key, i) + case coalescePktInvalidCSum: + // no point in inserting an item that we can't coalesce + return false + default: + } + } + } + // failed to coalesce with any other packets; store the item in the flow + table.insert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI) + return false +} + +func isTCP4(b []byte) bool { + if len(b) < 40 { + return false + } + if b[0]>>4 != 4 { + return false + } + if b[9] != unix.IPPROTO_TCP { + return false + } + return true +} + +func isTCP6NoEH(b []byte) bool { + if len(b) < 60 { + return false + } + if b[0]>>4 != 6 { + return false + } + if b[6] != unix.IPPROTO_TCP { + return false + } + return true +} + +// handleGRO evaluates buffs for GRO, and writes the indices of the resulting +// packets into toWrite. toWrite, tcp4Table, and tcp6Table should initially be +// empty (but non-nil), and are passed in to save allocs as the caller may reset +// and recycle them across vectors of packets. +func handleGRO(buffs [][]byte, offset int, tcp4Table, tcp6Table *tcpGROTable, toWrite *[]int) error { + for i := range buffs { + if offset < virtioNetHdrLen || offset > len(buffs[i])-1 { + return errors.New("invalid offset") + } + var coalesced bool + switch { + case isTCP4(buffs[i][offset:]): + coalesced = tcpGRO(buffs, offset, i, tcp4Table, false) + case isTCP6NoEH(buffs[i][offset:]): // ipv6 packets w/extension headers do not coalesce + coalesced = tcpGRO(buffs, offset, i, tcp6Table, true) + } + if !coalesced { + hdr := virtioNetHdr{} + err := hdr.encode(buffs[i][offset-virtioNetHdrLen:]) + if err != nil { + return err + } + *toWrite = append(*toWrite, i) + } + } + return nil +} + +// tcpTSO splits packets from in into outBuffs, writing the size of each +// element into sizes. It returns the number of buffers populated, and/or an +// error. +func tcpTSO(in []byte, hdr virtioNetHdr, outBuffs [][]byte, sizes []int, outOffset int) (int, error) { + iphLen := int(hdr.csumStart) + srcAddrOffset := ipv6SrcAddrOffset + addrLen := 16 + if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_TCPV4 { + in[10], in[11] = 0, 0 // clear ipv4 header checksum + srcAddrOffset = ipv4SrcAddrOffset + addrLen = 4 + } + tcpCSumAt := int(hdr.csumStart + hdr.csumOffset) + in[tcpCSumAt], in[tcpCSumAt+1] = 0, 0 // clear tcp checksum + firstTCPSeqNum := binary.BigEndian.Uint32(in[hdr.csumStart+4:]) + nextSegmentDataAt := int(hdr.hdrLen) + i := 0 + for ; nextSegmentDataAt < len(in); i++ { + if i == len(outBuffs) { + return i - 1, ErrTooManySegments + } + nextSegmentEnd := nextSegmentDataAt + int(hdr.gsoSize) + if nextSegmentEnd > len(in) { + nextSegmentEnd = len(in) + } + segmentDataLen := nextSegmentEnd - nextSegmentDataAt + totalLen := int(hdr.hdrLen) + segmentDataLen + sizes[i] = totalLen + out := outBuffs[i][outOffset:] + + copy(out, in[:iphLen]) + if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_TCPV4 { + // For IPv4 we are responsible for incrementing the ID field, + // updating the total len field, and recalculating the header + // checksum. + if i > 0 { + id := binary.BigEndian.Uint16(out[4:]) + id += uint16(i) + binary.BigEndian.PutUint16(out[4:], id) + } + binary.BigEndian.PutUint16(out[2:], uint16(totalLen)) + ipv4CSum := ^checksum(out[:iphLen], 0) + binary.BigEndian.PutUint16(out[10:], ipv4CSum) + } else { + // For IPv6 we are responsible for updating the payload length field. + binary.BigEndian.PutUint16(out[4:], uint16(totalLen-iphLen)) + } + + // TCP header + copy(out[hdr.csumStart:hdr.hdrLen], in[hdr.csumStart:hdr.hdrLen]) + tcpSeq := firstTCPSeqNum + uint32(hdr.gsoSize*uint16(i)) + binary.BigEndian.PutUint32(out[hdr.csumStart+4:], tcpSeq) + if nextSegmentEnd != len(in) { + // FIN and PSH should only be set on last segment + clearFlags := tcpFlagFIN | tcpFlagPSH + out[hdr.csumStart+tcpFlagsOffset] &^= clearFlags + } + + // payload + copy(out[hdr.hdrLen:], in[nextSegmentDataAt:nextSegmentEnd]) + + // TCP checksum + tcpHLen := int(hdr.hdrLen - hdr.csumStart) + tcpLenForPseudo := uint16(tcpHLen + segmentDataLen) + tcpCSumNoFold := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, in[srcAddrOffset:srcAddrOffset+addrLen], in[srcAddrOffset+addrLen:srcAddrOffset+addrLen*2], tcpLenForPseudo) + tcpCSum := ^checksum(out[hdr.csumStart:totalLen], tcpCSumNoFold) + binary.BigEndian.PutUint16(out[hdr.csumStart+hdr.csumOffset:], tcpCSum) + + nextSegmentDataAt += int(hdr.gsoSize) + } + return i, nil +} + +func gsoNoneChecksum(in []byte, cSumStart, cSumOffset uint16) error { + cSumAt := cSumStart + cSumOffset + // The initial value at the checksum offset should be summed with the + // checksum we compute. This is typically the pseudo-header checksum. + initial := binary.BigEndian.Uint16(in[cSumAt:]) + in[cSumAt], in[cSumAt+1] = 0, 0 + binary.BigEndian.PutUint16(in[cSumAt:], ^checksum(in[cSumStart:], uint64(initial))) + return nil +} diff --git a/tun/tcp_offload_linux_test.go b/tun/tcp_offload_linux_test.go new file mode 100644 index 0000000..7fa0777 --- /dev/null +++ b/tun/tcp_offload_linux_test.go @@ -0,0 +1,273 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package tun + +import ( + "net/netip" + "testing" + + "golang.org/x/sys/unix" + "golang.zx2c4.com/wireguard/conn" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" +) + +const ( + offset = virtioNetHdrLen +) + +var ( + ip4PortA = netip.MustParseAddrPort("192.0.2.1:1") + ip4PortB = netip.MustParseAddrPort("192.0.2.2:1") + ip4PortC = netip.MustParseAddrPort("192.0.2.3:1") + ip6PortA = netip.MustParseAddrPort("[2001:db8::1]:1") + ip6PortB = netip.MustParseAddrPort("[2001:db8::2]:1") + ip6PortC = netip.MustParseAddrPort("[2001:db8::3]:1") +) + +func tcp4Packet(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32) []byte { + totalLen := 40 + segmentSize + b := make([]byte, offset+int(totalLen), 65535) + ipv4H := header.IPv4(b[offset:]) + srcAs4 := srcIPPort.Addr().As4() + dstAs4 := dstIPPort.Addr().As4() + ipv4H.Encode(&header.IPv4Fields{ + SrcAddr: tcpip.Address(srcAs4[:]), + DstAddr: tcpip.Address(dstAs4[:]), + Protocol: unix.IPPROTO_TCP, + TTL: 64, + TotalLength: uint16(totalLen), + }) + tcpH := header.TCP(b[offset+20:]) + tcpH.Encode(&header.TCPFields{ + SrcPort: srcIPPort.Port(), + DstPort: dstIPPort.Port(), + SeqNum: seq, + AckNum: 1, + DataOffset: 20, + Flags: flags, + WindowSize: 3000, + }) + ipv4H.SetChecksum(^ipv4H.CalculateChecksum()) + pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_TCP, ipv4H.SourceAddress(), ipv4H.DestinationAddress(), uint16(20+segmentSize)) + tcpH.SetChecksum(^tcpH.CalculateChecksum(pseudoCsum)) + return b +} + +func tcp6Packet(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32) []byte { + totalLen := 60 + segmentSize + b := make([]byte, offset+int(totalLen), 65535) + ipv6H := header.IPv6(b[offset:]) + srcAs16 := srcIPPort.Addr().As16() + dstAs16 := dstIPPort.Addr().As16() + ipv6H.Encode(&header.IPv6Fields{ + SrcAddr: tcpip.Address(srcAs16[:]), + DstAddr: tcpip.Address(dstAs16[:]), + TransportProtocol: unix.IPPROTO_TCP, + HopLimit: 64, + PayloadLength: uint16(segmentSize + 20), + }) + tcpH := header.TCP(b[offset+40:]) + tcpH.Encode(&header.TCPFields{ + SrcPort: srcIPPort.Port(), + DstPort: dstIPPort.Port(), + SeqNum: seq, + AckNum: 1, + DataOffset: 20, + Flags: flags, + WindowSize: 3000, + }) + pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_TCP, ipv6H.SourceAddress(), ipv6H.DestinationAddress(), uint16(20+segmentSize)) + tcpH.SetChecksum(^tcpH.CalculateChecksum(pseudoCsum)) + return b +} + +func Test_handleVirtioRead(t *testing.T) { + tests := []struct { + name string + hdr virtioNetHdr + pktIn []byte + wantLens []int + wantErr bool + }{ + { + "tcp4", + virtioNetHdr{ + flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, + gsoType: unix.VIRTIO_NET_HDR_GSO_TCPV4, + gsoSize: 100, + hdrLen: 40, + csumStart: 20, + csumOffset: 16, + }, + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck|header.TCPFlagPsh, 200, 1), + []int{140, 140}, + false, + }, + { + "tcp6", + virtioNetHdr{ + flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, + gsoType: unix.VIRTIO_NET_HDR_GSO_TCPV6, + gsoSize: 100, + hdrLen: 60, + csumStart: 40, + csumOffset: 16, + }, + tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck|header.TCPFlagPsh, 200, 1), + []int{160, 160}, + false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + out := make([][]byte, conn.DefaultBatchSize) + sizes := make([]int, conn.DefaultBatchSize) + for i := range out { + out[i] = make([]byte, 65535) + } + tt.hdr.encode(tt.pktIn) + n, err := handleVirtioRead(tt.pktIn, out, sizes, offset) + if err != nil { + if tt.wantErr { + return + } + t.Fatalf("got err: %v", err) + } + if n != len(tt.wantLens) { + t.Fatalf("got %d packets, wanted %d", n, len(tt.wantLens)) + } + for i := range tt.wantLens { + if tt.wantLens[i] != sizes[i] { + t.Fatalf("wantLens[%d]: %d != outSizes: %d", i, tt.wantLens[i], sizes[i]) + } + } + }) + } +} + +func flipTCP4Checksum(b []byte) []byte { + at := virtioNetHdrLen + 20 + 16 // 20 byte ipv4 header; tcp csum offset is 16 + b[at] ^= 0xFF + b[at+1] ^= 0xFF + return b +} + +func Fuzz_handleGRO(f *testing.F) { + pkt0 := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1) + pkt1 := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101) + pkt2 := tcp4Packet(ip4PortA, ip4PortC, header.TCPFlagAck, 100, 201) + pkt3 := tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1) + pkt4 := tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101) + pkt5 := tcp6Packet(ip6PortA, ip6PortC, header.TCPFlagAck, 100, 201) + f.Add(pkt0, pkt1, pkt2, pkt3, pkt4, pkt5, offset) + f.Fuzz(func(t *testing.T, pkt0, pkt1, pkt2, pkt3, pkt4, pkt5 []byte, offset int) { + pkts := [][]byte{pkt0, pkt1, pkt2, pkt3, pkt4, pkt5} + toWrite := make([]int, 0, len(pkts)) + handleGRO(pkts, offset, newTCPGROTable(), newTCPGROTable(), &toWrite) + if len(toWrite) > len(pkts) { + t.Errorf("len(toWrite): %d > len(pkts): %d", len(toWrite), len(pkts)) + } + seenWriteI := make(map[int]bool) + for _, writeI := range toWrite { + if writeI < 0 || writeI > len(pkts)-1 { + t.Errorf("toWrite value (%d) outside bounds of len(pkts): %d", writeI, len(pkts)) + } + if seenWriteI[writeI] { + t.Errorf("duplicate toWrite value: %d", writeI) + } + seenWriteI[writeI] = true + } + }) +} + +func Test_handleGRO(t *testing.T) { + tests := []struct { + name string + pktsIn [][]byte + wantToWrite []int + wantLens []int + wantErr bool + }{ + { + "multiple flows", + [][]byte{ + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // v4 flow 1 + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // v4 flow 1 + tcp4Packet(ip4PortA, ip4PortC, header.TCPFlagAck, 100, 201), // v4 flow 2 + tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), // v6 flow 1 + tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101), // v6 flow 1 + tcp6Packet(ip6PortA, ip6PortC, header.TCPFlagAck, 100, 201), // v6 flow 2 + }, + []int{0, 2, 3, 5}, + []int{240, 140, 260, 160}, + false, + }, + { + "PSH interleaved", + [][]byte{ + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // v4 flow 1 + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck|header.TCPFlagPsh, 100, 101), // v4 flow 1 + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1 + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 301), // v4 flow 1 + tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), // v6 flow 1 + tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck|header.TCPFlagPsh, 100, 101), // v6 flow 1 + tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 201), // v6 flow 1 + tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 301), // v6 flow 1 + }, + []int{0, 2, 4, 6}, + []int{240, 240, 260, 260}, + false, + }, + { + "coalesceItemInvalidCSum", + [][]byte{ + flipTCP4Checksum(tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1)), // v4 flow 1 seq 1 len 100 + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // v4 flow 1 seq 101 len 100 + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1 seq 201 len 100 + }, + []int{0, 1}, + []int{140, 240}, + false, + }, + { + "out of order", + [][]byte{ + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // v4 flow 1 seq 101 len 100 + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // v4 flow 1 seq 1 len 100 + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1 seq 201 len 100 + }, + []int{0}, + []int{340}, + false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + toWrite := make([]int, 0, len(tt.pktsIn)) + err := handleGRO(tt.pktsIn, offset, newTCPGROTable(), newTCPGROTable(), &toWrite) + if err != nil { + if tt.wantErr { + return + } + t.Fatalf("got err: %v", err) + } + if len(toWrite) != len(tt.wantToWrite) { + t.Fatalf("got %d packets, wanted %d", len(toWrite), len(tt.wantToWrite)) + } + for i, pktI := range tt.wantToWrite { + if tt.wantToWrite[i] != toWrite[i] { + t.Fatalf("wantToWrite[%d]: %d != toWrite: %d", i, tt.wantToWrite[i], toWrite[i]) + } + if tt.wantLens[i] != len(tt.pktsIn[pktI][offset:]) { + t.Errorf("wanted len %d packet at %d, got: %d", tt.wantLens[i], i, len(tt.pktsIn[pktI][offset:])) + } + } + }) + } +} diff --git a/tun/testdata/fuzz/Fuzz_handleGRO/032aec0105f26f709c118365e4830d6dc087cab24cd1e154c2e790589a309b77 b/tun/testdata/fuzz/Fuzz_handleGRO/032aec0105f26f709c118365e4830d6dc087cab24cd1e154c2e790589a309b77 new file mode 100644 index 0000000..5461e79 --- /dev/null +++ b/tun/testdata/fuzz/Fuzz_handleGRO/032aec0105f26f709c118365e4830d6dc087cab24cd1e154c2e790589a309b77 @@ -0,0 +1,8 @@ +go test fuzz v1 +[]byte("0") +[]byte("0") +[]byte("0") +[]byte("0") +[]byte("0") +[]byte("0") +int(34) diff --git a/tun/testdata/fuzz/Fuzz_handleGRO/0da283f9a2098dec30d1c86784411a8ce2e8e03aa3384105e581f2c67494700d b/tun/testdata/fuzz/Fuzz_handleGRO/0da283f9a2098dec30d1c86784411a8ce2e8e03aa3384105e581f2c67494700d new file mode 100644 index 0000000..b441819 --- /dev/null +++ b/tun/testdata/fuzz/Fuzz_handleGRO/0da283f9a2098dec30d1c86784411a8ce2e8e03aa3384105e581f2c67494700d @@ -0,0 +1,8 @@ +go test fuzz v1 +[]byte("0") +[]byte("0") +[]byte("0") +[]byte("0") +[]byte("0") +[]byte("0") +int(-48) diff --git a/tun/tun_linux.go b/tun/tun_linux.go index 21984ca..d56e3c1 100644 --- a/tun/tun_linux.go +++ b/tun/tun_linux.go @@ -17,9 +17,8 @@ import ( "time" "unsafe" - "golang.org/x/net/ipv6" "golang.org/x/sys/unix" - + "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/rwcancel" ) @@ -33,17 +32,25 @@ type NativeTun struct { index int32 // if index errors chan error // async error handling events chan Event // device related events - nopi bool // the device was passed IFF_NO_PI netlinkSock int netlinkCancel *rwcancel.RWCancel hackListenerClosed sync.Mutex statusListenersShutdown chan struct{} + batchSize int + vnetHdr bool closeOnce sync.Once nameOnce sync.Once // guards calling initNameCache, which sets following fields nameCache string // name of interface nameErr error + + readOpMu sync.Mutex // readOpMu guards readBuff + readBuff [virtioNetHdrLen + 65535]byte // if vnetHdr every read() is prefixed by virtioNetHdr + + writeOpMu sync.Mutex // writeOpMu guards toWrite, tcp4GROTable, tcp6GROTable + toWrite []int + tcp4GROTable, tcp6GROTable *tcpGROTable } func (tun *NativeTun) File() *os.File { @@ -323,60 +330,142 @@ func (tun *NativeTun) nameSlow() (string, error) { return unix.ByteSliceToString(ifr[:]), nil } -func (tun *NativeTun) Write(buffs [][]byte, offset int) (n int, err error) { - var buf []byte - if tun.nopi { - buf = buffs[0][offset:] +func (tun *NativeTun) Write(buffs [][]byte, offset int) (int, error) { + tun.writeOpMu.Lock() + defer func() { + tun.tcp4GROTable.reset() + tun.tcp6GROTable.reset() + tun.writeOpMu.Unlock() + }() + var ( + errs []error + total int + ) + tun.toWrite = tun.toWrite[:0] + if tun.vnetHdr { + err := handleGRO(buffs, offset, tun.tcp4GROTable, tun.tcp6GROTable, &tun.toWrite) + if err != nil { + return 0, err + } + offset -= virtioNetHdrLen } else { - // reserve space for header - buf = buffs[0][offset-4:] - - // add packet information header - buf[0] = 0x00 - buf[1] = 0x00 - if buf[4]>>4 == ipv6.Version { - buf[2] = 0x86 - buf[3] = 0xdd + for i := range buffs { + tun.toWrite = append(tun.toWrite, i) + } + } + for _, buffsI := range tun.toWrite { + n, err := tun.tunFile.Write(buffs[buffsI][offset:]) + if errors.Is(err, syscall.EBADFD) { + return total, os.ErrClosed + } + if err != nil { + errs = append(errs, err) } else { - buf[2] = 0x08 - buf[3] = 0x00 + total += n + } + } + return total, ErrorBatch(errs) +} + +// handleVirtioRead splits in into buffs, leaving offset bytes at the front of +// each buffer. It mutates sizes to reflect the size of each element of buffs, +// and returns the number of packets read. +func handleVirtioRead(in []byte, buffs [][]byte, sizes []int, offset int) (int, error) { + var hdr virtioNetHdr + err := hdr.decode(in) + if err != nil { + return 0, err + } + in = in[virtioNetHdrLen:] + if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_NONE { + if hdr.flags&unix.VIRTIO_NET_HDR_F_NEEDS_CSUM != 0 { + // This means CHECKSUM_PARTIAL in skb context. We are responsible + // for computing the checksum starting at hdr.csumStart and placing + // at hdr.csumOffset. + err = gsoNoneChecksum(in, hdr.csumStart, hdr.csumOffset) + if err != nil { + return 0, err + } + } + if len(in) > len(buffs[0][offset:]) { + return 0, fmt.Errorf("read len %d overflows buffs element len %d", len(in), len(buffs[0][offset:])) } + n := copy(buffs[0][offset:], in) + sizes[0] = n + return 1, nil + } + if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 { + return 0, fmt.Errorf("unsupported virtio GSO type: %d", hdr.gsoType) + } + + ipVersion := in[0] >> 4 + switch ipVersion { + case 4: + if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 { + return 0, fmt.Errorf("ip header version: %d, GSO type: %d", ipVersion, hdr.gsoType) + } + case 6: + if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 { + return 0, fmt.Errorf("ip header version: %d, GSO type: %d", ipVersion, hdr.gsoType) + } + default: + return 0, fmt.Errorf("invalid ip header version: %d", ipVersion) + } + + if len(in) <= int(hdr.csumStart+12) { + return 0, errors.New("packet is too short") + } + // Don't trust hdr.hdrLen from the kernel as it can be equal to the length + // of the entire first packet when the kernel is handling it as part of a + // FORWARD path. Instead, parse the TCP header length and add it onto + // csumStart, which is synonymous for IP header length. + tcpHLen := uint16(in[hdr.csumStart+12] >> 4 * 4) + if tcpHLen < 20 || tcpHLen > 60 { + // A TCP header must be between 20 and 60 bytes in length. + return 0, fmt.Errorf("tcp header len is invalid: %d", tcpHLen) + } + hdr.hdrLen = hdr.csumStart + tcpHLen + + if len(in) < int(hdr.hdrLen) { + return 0, fmt.Errorf("length of packet (%d) < virtioNetHdr.hdrLen (%d)", len(in), hdr.hdrLen) } - _, err = tun.tunFile.Write(buf) - if errors.Is(err, syscall.EBADFD) { - err = os.ErrClosed - } else if err == nil { - n = 1 + if hdr.hdrLen < hdr.csumStart { + return 0, fmt.Errorf("virtioNetHdr.hdrLen (%d) < virtioNetHdr.csumStart (%d)", hdr.hdrLen, hdr.csumStart) } - return n, err + cSumAt := int(hdr.csumStart + hdr.csumOffset) + if cSumAt+1 >= len(in) { + return 0, fmt.Errorf("end of checksum offset (%d) exceeds packet length (%d)", cSumAt+1, len(in)) + } + + return tcpTSO(in, hdr, buffs, sizes, offset) } -func (tun *NativeTun) Read(buffs [][]byte, sizes []int, offset int) (n int, err error) { +func (tun *NativeTun) Read(buffs [][]byte, sizes []int, offset int) (int, error) { + tun.readOpMu.Lock() + defer tun.readOpMu.Unlock() select { - case err = <-tun.errors: + case err := <-tun.errors: + return 0, err default: - if tun.nopi { - sizes[0], err = tun.tunFile.Read(buffs[0][offset:]) - if err == nil { - n = 1 - } + readInto := buffs[0][offset:] + if tun.vnetHdr { + readInto = tun.readBuff[:] + } + n, err := tun.tunFile.Read(readInto) + if errors.Is(err, syscall.EBADFD) { + err = os.ErrClosed + } + if err != nil { + return 0, err + } + if tun.vnetHdr { + return handleVirtioRead(readInto[:n], buffs, sizes, offset) } else { - buff := buffs[0][offset-4:] - sizes[0], err = tun.tunFile.Read(buff[:]) - if errors.Is(err, syscall.EBADFD) { - err = os.ErrClosed - } else if err == nil { - n = 1 - } - if sizes[0] < 4 { - sizes[0] = 0 - } else { - sizes[0] -= 4 - } + sizes[0] = n + return 1, nil } } - return } func (tun *NativeTun) Events() <-chan Event { @@ -403,9 +492,49 @@ func (tun *NativeTun) Close() error { } func (tun *NativeTun) BatchSize() int { - return 1 + return tun.batchSize } +const ( + // TODO: support TSO with ECN bits + tunOffloads = unix.TUN_F_CSUM | unix.TUN_F_TSO4 | unix.TUN_F_TSO6 +) + +func (tun *NativeTun) initFromFlags(name string) error { + sc, err := tun.tunFile.SyscallConn() + if err != nil { + return err + } + if e := sc.Control(func(fd uintptr) { + var ( + ifr *unix.Ifreq + ) + ifr, err = unix.NewIfreq(name) + if err != nil { + return + } + err = unix.IoctlIfreq(int(fd), unix.TUNGETIFF, ifr) + if err != nil { + return + } + got := ifr.Uint16() + if got&unix.IFF_VNET_HDR != 0 { + err = unix.IoctlSetInt(int(fd), unix.TUNSETOFFLOAD, tunOffloads) + if err != nil { + return + } + tun.vnetHdr = true + tun.batchSize = conn.DefaultBatchSize + } else { + tun.batchSize = 1 + } + }); e != nil { + return e + } + return err +} + +// CreateTUN creates a Device with the provided name and MTU. func CreateTUN(name string, mtu int) (Device, error) { nfd, err := unix.Open(cloneDevicePath, unix.O_RDWR|unix.O_CLOEXEC, 0) if err != nil { @@ -415,25 +544,16 @@ func CreateTUN(name string, mtu int) (Device, error) { return nil, err } - var ifr [ifReqSize]byte - var flags uint16 = unix.IFF_TUN // | unix.IFF_NO_PI (disabled for TUN status hack) - nameBytes := []byte(name) - if len(nameBytes) >= unix.IFNAMSIZ { - unix.Close(nfd) - return nil, fmt.Errorf("interface name too long: %w", unix.ENAMETOOLONG) + ifr, err := unix.NewIfreq(name) + if err != nil { + return nil, err } - copy(ifr[:], nameBytes) - *(*uint16)(unsafe.Pointer(&ifr[unix.IFNAMSIZ])) = flags - - _, _, errno := unix.Syscall( - unix.SYS_IOCTL, - uintptr(nfd), - uintptr(unix.TUNSETIFF), - uintptr(unsafe.Pointer(&ifr[0])), - ) - if errno != 0 { - unix.Close(nfd) - return nil, errno + // IFF_VNET_HDR enables the "tun status hack" via routineHackListener() + // where a null write will return EINVAL indicating the TUN is up. + ifr.SetUint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_VNET_HDR) + err = unix.IoctlIfreq(nfd, unix.TUNSETIFF, ifr) + if err != nil { + return nil, err } err = unix.SetNonblock(nfd, true) @@ -448,13 +568,16 @@ func CreateTUN(name string, mtu int) (Device, error) { return CreateTUNFromFile(fd, mtu) } +// CreateTUNFromFile creates a Device from an os.File with the provided MTU. func CreateTUNFromFile(file *os.File, mtu int) (Device, error) { tun := &NativeTun{ tunFile: file, events: make(chan Event, 5), errors: make(chan error, 5), statusListenersShutdown: make(chan struct{}), - nopi: false, + tcp4GROTable: newTCPGROTable(), + tcp6GROTable: newTCPGROTable(), + toWrite: make([]int, 0, conn.DefaultBatchSize), } name, err := tun.Name() @@ -462,8 +585,12 @@ func CreateTUNFromFile(file *os.File, mtu int) (Device, error) { return nil, err } - // start event listener + err = tun.initFromFlags(name) + if err != nil { + return nil, err + } + // start event listener tun.index, err = getIFIndex(name) if err != nil { return nil, err @@ -492,6 +619,8 @@ func CreateTUNFromFile(file *os.File, mtu int) (Device, error) { return tun, nil } +// CreateUnmonitoredTUNFromFD creates a Device from the provided file +// descriptor. func CreateUnmonitoredTUNFromFD(fd int) (Device, string, error) { err := unix.SetNonblock(fd, true) if err != nil { @@ -499,14 +628,20 @@ func CreateUnmonitoredTUNFromFD(fd int) (Device, string, error) { } file := os.NewFile(uintptr(fd), "/dev/tun") tun := &NativeTun{ - tunFile: file, - events: make(chan Event, 5), - errors: make(chan error, 5), - nopi: true, + tunFile: file, + events: make(chan Event, 5), + errors: make(chan error, 5), + tcp4GROTable: newTCPGROTable(), + tcp6GROTable: newTCPGROTable(), + toWrite: make([]int, 0, conn.DefaultBatchSize), } name, err := tun.Name() if err != nil { return nil, "", err } - return tun, name, nil + err = tun.initFromFlags(name) + if err != nil { + return nil, "", err + } + return tun, name, err } |