diff options
Diffstat (limited to 'tun/tun_linux.go')
-rw-r--r-- | tun/tun_linux.go | 275 |
1 files changed, 205 insertions, 70 deletions
diff --git a/tun/tun_linux.go b/tun/tun_linux.go index 21984ca..d56e3c1 100644 --- a/tun/tun_linux.go +++ b/tun/tun_linux.go @@ -17,9 +17,8 @@ import ( "time" "unsafe" - "golang.org/x/net/ipv6" "golang.org/x/sys/unix" - + "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/rwcancel" ) @@ -33,17 +32,25 @@ type NativeTun struct { index int32 // if index errors chan error // async error handling events chan Event // device related events - nopi bool // the device was passed IFF_NO_PI netlinkSock int netlinkCancel *rwcancel.RWCancel hackListenerClosed sync.Mutex statusListenersShutdown chan struct{} + batchSize int + vnetHdr bool closeOnce sync.Once nameOnce sync.Once // guards calling initNameCache, which sets following fields nameCache string // name of interface nameErr error + + readOpMu sync.Mutex // readOpMu guards readBuff + readBuff [virtioNetHdrLen + 65535]byte // if vnetHdr every read() is prefixed by virtioNetHdr + + writeOpMu sync.Mutex // writeOpMu guards toWrite, tcp4GROTable, tcp6GROTable + toWrite []int + tcp4GROTable, tcp6GROTable *tcpGROTable } func (tun *NativeTun) File() *os.File { @@ -323,60 +330,142 @@ func (tun *NativeTun) nameSlow() (string, error) { return unix.ByteSliceToString(ifr[:]), nil } -func (tun *NativeTun) Write(buffs [][]byte, offset int) (n int, err error) { - var buf []byte - if tun.nopi { - buf = buffs[0][offset:] +func (tun *NativeTun) Write(buffs [][]byte, offset int) (int, error) { + tun.writeOpMu.Lock() + defer func() { + tun.tcp4GROTable.reset() + tun.tcp6GROTable.reset() + tun.writeOpMu.Unlock() + }() + var ( + errs []error + total int + ) + tun.toWrite = tun.toWrite[:0] + if tun.vnetHdr { + err := handleGRO(buffs, offset, tun.tcp4GROTable, tun.tcp6GROTable, &tun.toWrite) + if err != nil { + return 0, err + } + offset -= virtioNetHdrLen } else { - // reserve space for header - buf = buffs[0][offset-4:] - - // add packet information header - buf[0] = 0x00 - buf[1] = 0x00 - if buf[4]>>4 == ipv6.Version { - buf[2] = 0x86 - buf[3] = 0xdd + for i := range buffs { + tun.toWrite = append(tun.toWrite, i) + } + } + for _, buffsI := range tun.toWrite { + n, err := tun.tunFile.Write(buffs[buffsI][offset:]) + if errors.Is(err, syscall.EBADFD) { + return total, os.ErrClosed + } + if err != nil { + errs = append(errs, err) } else { - buf[2] = 0x08 - buf[3] = 0x00 + total += n + } + } + return total, ErrorBatch(errs) +} + +// handleVirtioRead splits in into buffs, leaving offset bytes at the front of +// each buffer. It mutates sizes to reflect the size of each element of buffs, +// and returns the number of packets read. +func handleVirtioRead(in []byte, buffs [][]byte, sizes []int, offset int) (int, error) { + var hdr virtioNetHdr + err := hdr.decode(in) + if err != nil { + return 0, err + } + in = in[virtioNetHdrLen:] + if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_NONE { + if hdr.flags&unix.VIRTIO_NET_HDR_F_NEEDS_CSUM != 0 { + // This means CHECKSUM_PARTIAL in skb context. We are responsible + // for computing the checksum starting at hdr.csumStart and placing + // at hdr.csumOffset. + err = gsoNoneChecksum(in, hdr.csumStart, hdr.csumOffset) + if err != nil { + return 0, err + } + } + if len(in) > len(buffs[0][offset:]) { + return 0, fmt.Errorf("read len %d overflows buffs element len %d", len(in), len(buffs[0][offset:])) } + n := copy(buffs[0][offset:], in) + sizes[0] = n + return 1, nil + } + if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 { + return 0, fmt.Errorf("unsupported virtio GSO type: %d", hdr.gsoType) + } + + ipVersion := in[0] >> 4 + switch ipVersion { + case 4: + if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 { + return 0, fmt.Errorf("ip header version: %d, GSO type: %d", ipVersion, hdr.gsoType) + } + case 6: + if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 { + return 0, fmt.Errorf("ip header version: %d, GSO type: %d", ipVersion, hdr.gsoType) + } + default: + return 0, fmt.Errorf("invalid ip header version: %d", ipVersion) + } + + if len(in) <= int(hdr.csumStart+12) { + return 0, errors.New("packet is too short") + } + // Don't trust hdr.hdrLen from the kernel as it can be equal to the length + // of the entire first packet when the kernel is handling it as part of a + // FORWARD path. Instead, parse the TCP header length and add it onto + // csumStart, which is synonymous for IP header length. + tcpHLen := uint16(in[hdr.csumStart+12] >> 4 * 4) + if tcpHLen < 20 || tcpHLen > 60 { + // A TCP header must be between 20 and 60 bytes in length. + return 0, fmt.Errorf("tcp header len is invalid: %d", tcpHLen) + } + hdr.hdrLen = hdr.csumStart + tcpHLen + + if len(in) < int(hdr.hdrLen) { + return 0, fmt.Errorf("length of packet (%d) < virtioNetHdr.hdrLen (%d)", len(in), hdr.hdrLen) } - _, err = tun.tunFile.Write(buf) - if errors.Is(err, syscall.EBADFD) { - err = os.ErrClosed - } else if err == nil { - n = 1 + if hdr.hdrLen < hdr.csumStart { + return 0, fmt.Errorf("virtioNetHdr.hdrLen (%d) < virtioNetHdr.csumStart (%d)", hdr.hdrLen, hdr.csumStart) } - return n, err + cSumAt := int(hdr.csumStart + hdr.csumOffset) + if cSumAt+1 >= len(in) { + return 0, fmt.Errorf("end of checksum offset (%d) exceeds packet length (%d)", cSumAt+1, len(in)) + } + + return tcpTSO(in, hdr, buffs, sizes, offset) } -func (tun *NativeTun) Read(buffs [][]byte, sizes []int, offset int) (n int, err error) { +func (tun *NativeTun) Read(buffs [][]byte, sizes []int, offset int) (int, error) { + tun.readOpMu.Lock() + defer tun.readOpMu.Unlock() select { - case err = <-tun.errors: + case err := <-tun.errors: + return 0, err default: - if tun.nopi { - sizes[0], err = tun.tunFile.Read(buffs[0][offset:]) - if err == nil { - n = 1 - } + readInto := buffs[0][offset:] + if tun.vnetHdr { + readInto = tun.readBuff[:] + } + n, err := tun.tunFile.Read(readInto) + if errors.Is(err, syscall.EBADFD) { + err = os.ErrClosed + } + if err != nil { + return 0, err + } + if tun.vnetHdr { + return handleVirtioRead(readInto[:n], buffs, sizes, offset) } else { - buff := buffs[0][offset-4:] - sizes[0], err = tun.tunFile.Read(buff[:]) - if errors.Is(err, syscall.EBADFD) { - err = os.ErrClosed - } else if err == nil { - n = 1 - } - if sizes[0] < 4 { - sizes[0] = 0 - } else { - sizes[0] -= 4 - } + sizes[0] = n + return 1, nil } } - return } func (tun *NativeTun) Events() <-chan Event { @@ -403,9 +492,49 @@ func (tun *NativeTun) Close() error { } func (tun *NativeTun) BatchSize() int { - return 1 + return tun.batchSize } +const ( + // TODO: support TSO with ECN bits + tunOffloads = unix.TUN_F_CSUM | unix.TUN_F_TSO4 | unix.TUN_F_TSO6 +) + +func (tun *NativeTun) initFromFlags(name string) error { + sc, err := tun.tunFile.SyscallConn() + if err != nil { + return err + } + if e := sc.Control(func(fd uintptr) { + var ( + ifr *unix.Ifreq + ) + ifr, err = unix.NewIfreq(name) + if err != nil { + return + } + err = unix.IoctlIfreq(int(fd), unix.TUNGETIFF, ifr) + if err != nil { + return + } + got := ifr.Uint16() + if got&unix.IFF_VNET_HDR != 0 { + err = unix.IoctlSetInt(int(fd), unix.TUNSETOFFLOAD, tunOffloads) + if err != nil { + return + } + tun.vnetHdr = true + tun.batchSize = conn.DefaultBatchSize + } else { + tun.batchSize = 1 + } + }); e != nil { + return e + } + return err +} + +// CreateTUN creates a Device with the provided name and MTU. func CreateTUN(name string, mtu int) (Device, error) { nfd, err := unix.Open(cloneDevicePath, unix.O_RDWR|unix.O_CLOEXEC, 0) if err != nil { @@ -415,25 +544,16 @@ func CreateTUN(name string, mtu int) (Device, error) { return nil, err } - var ifr [ifReqSize]byte - var flags uint16 = unix.IFF_TUN // | unix.IFF_NO_PI (disabled for TUN status hack) - nameBytes := []byte(name) - if len(nameBytes) >= unix.IFNAMSIZ { - unix.Close(nfd) - return nil, fmt.Errorf("interface name too long: %w", unix.ENAMETOOLONG) + ifr, err := unix.NewIfreq(name) + if err != nil { + return nil, err } - copy(ifr[:], nameBytes) - *(*uint16)(unsafe.Pointer(&ifr[unix.IFNAMSIZ])) = flags - - _, _, errno := unix.Syscall( - unix.SYS_IOCTL, - uintptr(nfd), - uintptr(unix.TUNSETIFF), - uintptr(unsafe.Pointer(&ifr[0])), - ) - if errno != 0 { - unix.Close(nfd) - return nil, errno + // IFF_VNET_HDR enables the "tun status hack" via routineHackListener() + // where a null write will return EINVAL indicating the TUN is up. + ifr.SetUint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_VNET_HDR) + err = unix.IoctlIfreq(nfd, unix.TUNSETIFF, ifr) + if err != nil { + return nil, err } err = unix.SetNonblock(nfd, true) @@ -448,13 +568,16 @@ func CreateTUN(name string, mtu int) (Device, error) { return CreateTUNFromFile(fd, mtu) } +// CreateTUNFromFile creates a Device from an os.File with the provided MTU. func CreateTUNFromFile(file *os.File, mtu int) (Device, error) { tun := &NativeTun{ tunFile: file, events: make(chan Event, 5), errors: make(chan error, 5), statusListenersShutdown: make(chan struct{}), - nopi: false, + tcp4GROTable: newTCPGROTable(), + tcp6GROTable: newTCPGROTable(), + toWrite: make([]int, 0, conn.DefaultBatchSize), } name, err := tun.Name() @@ -462,8 +585,12 @@ func CreateTUNFromFile(file *os.File, mtu int) (Device, error) { return nil, err } - // start event listener + err = tun.initFromFlags(name) + if err != nil { + return nil, err + } + // start event listener tun.index, err = getIFIndex(name) if err != nil { return nil, err @@ -492,6 +619,8 @@ func CreateTUNFromFile(file *os.File, mtu int) (Device, error) { return tun, nil } +// CreateUnmonitoredTUNFromFD creates a Device from the provided file +// descriptor. func CreateUnmonitoredTUNFromFD(fd int) (Device, string, error) { err := unix.SetNonblock(fd, true) if err != nil { @@ -499,14 +628,20 @@ func CreateUnmonitoredTUNFromFD(fd int) (Device, string, error) { } file := os.NewFile(uintptr(fd), "/dev/tun") tun := &NativeTun{ - tunFile: file, - events: make(chan Event, 5), - errors: make(chan error, 5), - nopi: true, + tunFile: file, + events: make(chan Event, 5), + errors: make(chan error, 5), + tcp4GROTable: newTCPGROTable(), + tcp6GROTable: newTCPGROTable(), + toWrite: make([]int, 0, conn.DefaultBatchSize), } name, err := tun.Name() if err != nil { return nil, "", err } - return tun, name, nil + err = tun.initFromFlags(name) + if err != nil { + return nil, "", err + } + return tun, name, err } |