From 9e2f3860220280a5630971478b53c8ad9a991ca8 Mon Sep 17 00:00:00 2001 From: Jordan Whited Date: Thu, 2 Mar 2023 15:08:28 -0800 Subject: conn, device, tun: implement vectorized I/O on Linux Implement TCP offloading via TSO and GRO for the Linux tun.Device, which is made possible by virtio extensions in the kernel's TUN driver. Delete conn.LinuxSocketEndpoint in favor of a collapsed conn.StdNetBind. conn.StdNetBind makes use of recvmmsg() and sendmmsg() on Linux. All platforms now fall under conn.StdNetBind, except for Windows, which remains in conn.WinRingBind, which still needs to be adjusted to handle multiple packets. Also refactor sticky sockets support to eventually be applicable on platforms other than just Linux. However Linux remains the sole platform that fully implements it for now. Co-authored-by: James Tucker Signed-off-by: James Tucker Signed-off-by: Jordan Whited Signed-off-by: Jason A. Donenfeld --- tun/tun_linux.go | 275 +++++++++++++++++++++++++++++++++++++++++-------------- 1 file changed, 205 insertions(+), 70 deletions(-) (limited to 'tun/tun_linux.go') diff --git a/tun/tun_linux.go b/tun/tun_linux.go index 21984ca..d56e3c1 100644 --- a/tun/tun_linux.go +++ b/tun/tun_linux.go @@ -17,9 +17,8 @@ import ( "time" "unsafe" - "golang.org/x/net/ipv6" "golang.org/x/sys/unix" - + "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/rwcancel" ) @@ -33,17 +32,25 @@ type NativeTun struct { index int32 // if index errors chan error // async error handling events chan Event // device related events - nopi bool // the device was passed IFF_NO_PI netlinkSock int netlinkCancel *rwcancel.RWCancel hackListenerClosed sync.Mutex statusListenersShutdown chan struct{} + batchSize int + vnetHdr bool closeOnce sync.Once nameOnce sync.Once // guards calling initNameCache, which sets following fields nameCache string // name of interface nameErr error + + readOpMu sync.Mutex // readOpMu guards readBuff + readBuff [virtioNetHdrLen + 65535]byte // if vnetHdr every read() is prefixed by virtioNetHdr + + writeOpMu sync.Mutex // writeOpMu guards toWrite, tcp4GROTable, tcp6GROTable + toWrite []int + tcp4GROTable, tcp6GROTable *tcpGROTable } func (tun *NativeTun) File() *os.File { @@ -323,60 +330,142 @@ func (tun *NativeTun) nameSlow() (string, error) { return unix.ByteSliceToString(ifr[:]), nil } -func (tun *NativeTun) Write(buffs [][]byte, offset int) (n int, err error) { - var buf []byte - if tun.nopi { - buf = buffs[0][offset:] +func (tun *NativeTun) Write(buffs [][]byte, offset int) (int, error) { + tun.writeOpMu.Lock() + defer func() { + tun.tcp4GROTable.reset() + tun.tcp6GROTable.reset() + tun.writeOpMu.Unlock() + }() + var ( + errs []error + total int + ) + tun.toWrite = tun.toWrite[:0] + if tun.vnetHdr { + err := handleGRO(buffs, offset, tun.tcp4GROTable, tun.tcp6GROTable, &tun.toWrite) + if err != nil { + return 0, err + } + offset -= virtioNetHdrLen } else { - // reserve space for header - buf = buffs[0][offset-4:] - - // add packet information header - buf[0] = 0x00 - buf[1] = 0x00 - if buf[4]>>4 == ipv6.Version { - buf[2] = 0x86 - buf[3] = 0xdd + for i := range buffs { + tun.toWrite = append(tun.toWrite, i) + } + } + for _, buffsI := range tun.toWrite { + n, err := tun.tunFile.Write(buffs[buffsI][offset:]) + if errors.Is(err, syscall.EBADFD) { + return total, os.ErrClosed + } + if err != nil { + errs = append(errs, err) } else { - buf[2] = 0x08 - buf[3] = 0x00 + total += n + } + } + return total, ErrorBatch(errs) +} + +// handleVirtioRead splits in into buffs, leaving offset bytes at the front of +// each buffer. It mutates sizes to reflect the size of each element of buffs, +// and returns the number of packets read. +func handleVirtioRead(in []byte, buffs [][]byte, sizes []int, offset int) (int, error) { + var hdr virtioNetHdr + err := hdr.decode(in) + if err != nil { + return 0, err + } + in = in[virtioNetHdrLen:] + if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_NONE { + if hdr.flags&unix.VIRTIO_NET_HDR_F_NEEDS_CSUM != 0 { + // This means CHECKSUM_PARTIAL in skb context. We are responsible + // for computing the checksum starting at hdr.csumStart and placing + // at hdr.csumOffset. + err = gsoNoneChecksum(in, hdr.csumStart, hdr.csumOffset) + if err != nil { + return 0, err + } + } + if len(in) > len(buffs[0][offset:]) { + return 0, fmt.Errorf("read len %d overflows buffs element len %d", len(in), len(buffs[0][offset:])) } + n := copy(buffs[0][offset:], in) + sizes[0] = n + return 1, nil + } + if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 { + return 0, fmt.Errorf("unsupported virtio GSO type: %d", hdr.gsoType) + } + + ipVersion := in[0] >> 4 + switch ipVersion { + case 4: + if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 { + return 0, fmt.Errorf("ip header version: %d, GSO type: %d", ipVersion, hdr.gsoType) + } + case 6: + if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 { + return 0, fmt.Errorf("ip header version: %d, GSO type: %d", ipVersion, hdr.gsoType) + } + default: + return 0, fmt.Errorf("invalid ip header version: %d", ipVersion) + } + + if len(in) <= int(hdr.csumStart+12) { + return 0, errors.New("packet is too short") + } + // Don't trust hdr.hdrLen from the kernel as it can be equal to the length + // of the entire first packet when the kernel is handling it as part of a + // FORWARD path. Instead, parse the TCP header length and add it onto + // csumStart, which is synonymous for IP header length. + tcpHLen := uint16(in[hdr.csumStart+12] >> 4 * 4) + if tcpHLen < 20 || tcpHLen > 60 { + // A TCP header must be between 20 and 60 bytes in length. + return 0, fmt.Errorf("tcp header len is invalid: %d", tcpHLen) + } + hdr.hdrLen = hdr.csumStart + tcpHLen + + if len(in) < int(hdr.hdrLen) { + return 0, fmt.Errorf("length of packet (%d) < virtioNetHdr.hdrLen (%d)", len(in), hdr.hdrLen) } - _, err = tun.tunFile.Write(buf) - if errors.Is(err, syscall.EBADFD) { - err = os.ErrClosed - } else if err == nil { - n = 1 + if hdr.hdrLen < hdr.csumStart { + return 0, fmt.Errorf("virtioNetHdr.hdrLen (%d) < virtioNetHdr.csumStart (%d)", hdr.hdrLen, hdr.csumStart) } - return n, err + cSumAt := int(hdr.csumStart + hdr.csumOffset) + if cSumAt+1 >= len(in) { + return 0, fmt.Errorf("end of checksum offset (%d) exceeds packet length (%d)", cSumAt+1, len(in)) + } + + return tcpTSO(in, hdr, buffs, sizes, offset) } -func (tun *NativeTun) Read(buffs [][]byte, sizes []int, offset int) (n int, err error) { +func (tun *NativeTun) Read(buffs [][]byte, sizes []int, offset int) (int, error) { + tun.readOpMu.Lock() + defer tun.readOpMu.Unlock() select { - case err = <-tun.errors: + case err := <-tun.errors: + return 0, err default: - if tun.nopi { - sizes[0], err = tun.tunFile.Read(buffs[0][offset:]) - if err == nil { - n = 1 - } + readInto := buffs[0][offset:] + if tun.vnetHdr { + readInto = tun.readBuff[:] + } + n, err := tun.tunFile.Read(readInto) + if errors.Is(err, syscall.EBADFD) { + err = os.ErrClosed + } + if err != nil { + return 0, err + } + if tun.vnetHdr { + return handleVirtioRead(readInto[:n], buffs, sizes, offset) } else { - buff := buffs[0][offset-4:] - sizes[0], err = tun.tunFile.Read(buff[:]) - if errors.Is(err, syscall.EBADFD) { - err = os.ErrClosed - } else if err == nil { - n = 1 - } - if sizes[0] < 4 { - sizes[0] = 0 - } else { - sizes[0] -= 4 - } + sizes[0] = n + return 1, nil } } - return } func (tun *NativeTun) Events() <-chan Event { @@ -403,9 +492,49 @@ func (tun *NativeTun) Close() error { } func (tun *NativeTun) BatchSize() int { - return 1 + return tun.batchSize } +const ( + // TODO: support TSO with ECN bits + tunOffloads = unix.TUN_F_CSUM | unix.TUN_F_TSO4 | unix.TUN_F_TSO6 +) + +func (tun *NativeTun) initFromFlags(name string) error { + sc, err := tun.tunFile.SyscallConn() + if err != nil { + return err + } + if e := sc.Control(func(fd uintptr) { + var ( + ifr *unix.Ifreq + ) + ifr, err = unix.NewIfreq(name) + if err != nil { + return + } + err = unix.IoctlIfreq(int(fd), unix.TUNGETIFF, ifr) + if err != nil { + return + } + got := ifr.Uint16() + if got&unix.IFF_VNET_HDR != 0 { + err = unix.IoctlSetInt(int(fd), unix.TUNSETOFFLOAD, tunOffloads) + if err != nil { + return + } + tun.vnetHdr = true + tun.batchSize = conn.DefaultBatchSize + } else { + tun.batchSize = 1 + } + }); e != nil { + return e + } + return err +} + +// CreateTUN creates a Device with the provided name and MTU. func CreateTUN(name string, mtu int) (Device, error) { nfd, err := unix.Open(cloneDevicePath, unix.O_RDWR|unix.O_CLOEXEC, 0) if err != nil { @@ -415,25 +544,16 @@ func CreateTUN(name string, mtu int) (Device, error) { return nil, err } - var ifr [ifReqSize]byte - var flags uint16 = unix.IFF_TUN // | unix.IFF_NO_PI (disabled for TUN status hack) - nameBytes := []byte(name) - if len(nameBytes) >= unix.IFNAMSIZ { - unix.Close(nfd) - return nil, fmt.Errorf("interface name too long: %w", unix.ENAMETOOLONG) + ifr, err := unix.NewIfreq(name) + if err != nil { + return nil, err } - copy(ifr[:], nameBytes) - *(*uint16)(unsafe.Pointer(&ifr[unix.IFNAMSIZ])) = flags - - _, _, errno := unix.Syscall( - unix.SYS_IOCTL, - uintptr(nfd), - uintptr(unix.TUNSETIFF), - uintptr(unsafe.Pointer(&ifr[0])), - ) - if errno != 0 { - unix.Close(nfd) - return nil, errno + // IFF_VNET_HDR enables the "tun status hack" via routineHackListener() + // where a null write will return EINVAL indicating the TUN is up. + ifr.SetUint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_VNET_HDR) + err = unix.IoctlIfreq(nfd, unix.TUNSETIFF, ifr) + if err != nil { + return nil, err } err = unix.SetNonblock(nfd, true) @@ -448,13 +568,16 @@ func CreateTUN(name string, mtu int) (Device, error) { return CreateTUNFromFile(fd, mtu) } +// CreateTUNFromFile creates a Device from an os.File with the provided MTU. func CreateTUNFromFile(file *os.File, mtu int) (Device, error) { tun := &NativeTun{ tunFile: file, events: make(chan Event, 5), errors: make(chan error, 5), statusListenersShutdown: make(chan struct{}), - nopi: false, + tcp4GROTable: newTCPGROTable(), + tcp6GROTable: newTCPGROTable(), + toWrite: make([]int, 0, conn.DefaultBatchSize), } name, err := tun.Name() @@ -462,8 +585,12 @@ func CreateTUNFromFile(file *os.File, mtu int) (Device, error) { return nil, err } - // start event listener + err = tun.initFromFlags(name) + if err != nil { + return nil, err + } + // start event listener tun.index, err = getIFIndex(name) if err != nil { return nil, err @@ -492,6 +619,8 @@ func CreateTUNFromFile(file *os.File, mtu int) (Device, error) { return tun, nil } +// CreateUnmonitoredTUNFromFD creates a Device from the provided file +// descriptor. func CreateUnmonitoredTUNFromFD(fd int) (Device, string, error) { err := unix.SetNonblock(fd, true) if err != nil { @@ -499,14 +628,20 @@ func CreateUnmonitoredTUNFromFD(fd int) (Device, string, error) { } file := os.NewFile(uintptr(fd), "/dev/tun") tun := &NativeTun{ - tunFile: file, - events: make(chan Event, 5), - errors: make(chan error, 5), - nopi: true, + tunFile: file, + events: make(chan Event, 5), + errors: make(chan error, 5), + tcp4GROTable: newTCPGROTable(), + tcp6GROTable: newTCPGROTable(), + toWrite: make([]int, 0, conn.DefaultBatchSize), } name, err := tun.Name() if err != nil { return nil, "", err } - return tun, name, nil + err = tun.initFromFlags(name) + if err != nil { + return nil, "", err + } + return tun, name, err } -- cgit v1.2.3