summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--conn/bind_linux.go587
-rw-r--r--conn/bind_std.go339
-rw-r--r--conn/boundif_android.go8
-rw-r--r--conn/conn.go2
-rw-r--r--conn/controlfns.go36
-rw-r--r--conn/controlfns_linux.go41
-rw-r--r--conn/controlfns_unix.go28
-rw-r--r--conn/default.go2
-rw-r--r--conn/mark_default.go2
-rw-r--r--conn/mark_unix.go10
-rw-r--r--conn/sticky_default.go26
-rw-r--r--conn/sticky_linux.go111
-rw-r--r--conn/sticky_linux_test.go207
-rw-r--r--device/queueconstants_android.go4
-rw-r--r--device/queueconstants_default.go4
-rw-r--r--device/sticky_linux.go14
-rw-r--r--go.mod6
-rw-r--r--go.sum12
-rw-r--r--tun/checksum.go42
-rw-r--r--tun/tcp_offload_linux.go612
-rw-r--r--tun/tcp_offload_linux_test.go273
-rw-r--r--tun/testdata/fuzz/Fuzz_handleGRO/032aec0105f26f709c118365e4830d6dc087cab24cd1e154c2e790589a309b778
-rw-r--r--tun/testdata/fuzz/Fuzz_handleGRO/0da283f9a2098dec30d1c86784411a8ce2e8e03aa3384105e581f2c67494700d8
-rw-r--r--tun/tun_linux.go275
24 files changed, 1870 insertions, 787 deletions
diff --git a/conn/bind_linux.go b/conn/bind_linux.go
deleted file mode 100644
index b6bc0dc..0000000
--- a/conn/bind_linux.go
+++ /dev/null
@@ -1,587 +0,0 @@
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
- */
-
-package conn
-
-import (
- "errors"
- "net"
- "net/netip"
- "strconv"
- "sync"
- "syscall"
- "unsafe"
-
- "golang.org/x/sys/unix"
-)
-
-type ipv4Source struct {
- Src [4]byte
- Ifindex int32
-}
-
-type ipv6Source struct {
- src [16]byte
- // ifindex belongs in dst.ZoneId
-}
-
-type LinuxSocketEndpoint struct {
- mu sync.Mutex
- dst [unsafe.Sizeof(unix.SockaddrInet6{})]byte
- src [unsafe.Sizeof(ipv6Source{})]byte
- isV6 bool
-}
-
-func (endpoint *LinuxSocketEndpoint) Src4() *ipv4Source { return endpoint.src4() }
-func (endpoint *LinuxSocketEndpoint) Dst4() *unix.SockaddrInet4 { return endpoint.dst4() }
-func (endpoint *LinuxSocketEndpoint) IsV6() bool { return endpoint.isV6 }
-
-func (endpoint *LinuxSocketEndpoint) src4() *ipv4Source {
- return (*ipv4Source)(unsafe.Pointer(&endpoint.src[0]))
-}
-
-func (endpoint *LinuxSocketEndpoint) src6() *ipv6Source {
- return (*ipv6Source)(unsafe.Pointer(&endpoint.src[0]))
-}
-
-func (endpoint *LinuxSocketEndpoint) dst4() *unix.SockaddrInet4 {
- return (*unix.SockaddrInet4)(unsafe.Pointer(&endpoint.dst[0]))
-}
-
-func (endpoint *LinuxSocketEndpoint) dst6() *unix.SockaddrInet6 {
- return (*unix.SockaddrInet6)(unsafe.Pointer(&endpoint.dst[0]))
-}
-
-// LinuxSocketBind uses sendmsg and recvmsg to implement a full bind with sticky sockets on Linux.
-type LinuxSocketBind struct {
- // mu guards sock4 and sock6 and the associated fds.
- // As long as someone holds mu (read or write), the associated fds are valid.
- mu sync.RWMutex
- sock4 int
- sock6 int
-}
-
-func NewLinuxSocketBind() Bind { return &LinuxSocketBind{sock4: -1, sock6: -1} }
-func NewDefaultBind() Bind { return NewLinuxSocketBind() }
-
-var (
- _ Endpoint = (*LinuxSocketEndpoint)(nil)
- _ Bind = (*LinuxSocketBind)(nil)
-)
-
-func (*LinuxSocketBind) ParseEndpoint(s string) (Endpoint, error) {
- var end LinuxSocketEndpoint
- e, err := netip.ParseAddrPort(s)
- if err != nil {
- return nil, err
- }
-
- if e.Addr().Is4() {
- dst := end.dst4()
- end.isV6 = false
- dst.Port = int(e.Port())
- dst.Addr = e.Addr().As4()
- end.ClearSrc()
- return &end, nil
- }
-
- if e.Addr().Is6() {
- zone, err := zoneToUint32(e.Addr().Zone())
- if err != nil {
- return nil, err
- }
- dst := end.dst6()
- end.isV6 = true
- dst.Port = int(e.Port())
- dst.ZoneId = zone
- dst.Addr = e.Addr().As16()
- end.ClearSrc()
- return &end, nil
- }
-
- return nil, errors.New("invalid IP address")
-}
-
-func (bind *LinuxSocketBind) Open(port uint16) ([]ReceiveFunc, uint16, error) {
- bind.mu.Lock()
- defer bind.mu.Unlock()
-
- var err error
- var newPort uint16
- var tries int
-
- if bind.sock4 != -1 || bind.sock6 != -1 {
- return nil, 0, ErrBindAlreadyOpen
- }
-
- originalPort := port
-
-again:
- port = originalPort
- var sock4, sock6 int
- // Attempt ipv6 bind, update port if successful.
- sock6, newPort, err = create6(port)
- if err != nil {
- if !errors.Is(err, syscall.EAFNOSUPPORT) {
- return nil, 0, err
- }
- } else {
- port = newPort
- }
-
- // Attempt ipv4 bind, update port if successful.
- sock4, newPort, err = create4(port)
- if err != nil {
- if originalPort == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 {
- unix.Close(sock6)
- tries++
- goto again
- }
- if !errors.Is(err, syscall.EAFNOSUPPORT) {
- unix.Close(sock6)
- return nil, 0, err
- }
- } else {
- port = newPort
- }
-
- var fns []ReceiveFunc
- if sock4 != -1 {
- bind.sock4 = sock4
- fns = append(fns, bind.receiveIPv4)
- }
- if sock6 != -1 {
- bind.sock6 = sock6
- fns = append(fns, bind.receiveIPv6)
- }
- if len(fns) == 0 {
- return nil, 0, syscall.EAFNOSUPPORT
- }
- return fns, port, nil
-}
-
-func (bind *LinuxSocketBind) SetMark(value uint32) error {
- bind.mu.RLock()
- defer bind.mu.RUnlock()
-
- if bind.sock6 != -1 {
- err := unix.SetsockoptInt(
- bind.sock6,
- unix.SOL_SOCKET,
- unix.SO_MARK,
- int(value),
- )
- if err != nil {
- return err
- }
- }
-
- if bind.sock4 != -1 {
- err := unix.SetsockoptInt(
- bind.sock4,
- unix.SOL_SOCKET,
- unix.SO_MARK,
- int(value),
- )
- if err != nil {
- return err
- }
- }
-
- return nil
-}
-
-func (bind *LinuxSocketBind) BatchSize() int {
- return 1
-}
-
-func (bind *LinuxSocketBind) Close() error {
- // Take a readlock to shut down the sockets...
- bind.mu.RLock()
- if bind.sock6 != -1 {
- unix.Shutdown(bind.sock6, unix.SHUT_RDWR)
- }
- if bind.sock4 != -1 {
- unix.Shutdown(bind.sock4, unix.SHUT_RDWR)
- }
- bind.mu.RUnlock()
- // ...and a write lock to close the fd.
- // This ensures that no one else is using the fd.
- bind.mu.Lock()
- defer bind.mu.Unlock()
- var err1, err2 error
- if bind.sock6 != -1 {
- err1 = unix.Close(bind.sock6)
- bind.sock6 = -1
- }
- if bind.sock4 != -1 {
- err2 = unix.Close(bind.sock4)
- bind.sock4 = -1
- }
-
- if err1 != nil {
- return err1
- }
- return err2
-}
-
-func (bind *LinuxSocketBind) receiveIPv4(buffs [][]byte, sizes []int, eps []Endpoint) (int, error) {
- bind.mu.RLock()
- defer bind.mu.RUnlock()
- if bind.sock4 == -1 {
- return 0, net.ErrClosed
- }
- var end LinuxSocketEndpoint
- n, err := receive4(bind.sock4, buffs[0], &end)
- if err != nil {
- return 0, err
- }
- eps[0] = &end
- sizes[0] = n
- return 1, nil
-}
-
-func (bind *LinuxSocketBind) receiveIPv6(buffs [][]byte, sizes []int, eps []Endpoint) (int, error) {
- bind.mu.RLock()
- defer bind.mu.RUnlock()
- if bind.sock6 == -1 {
- return 0, net.ErrClosed
- }
- var end LinuxSocketEndpoint
- n, err := receive6(bind.sock6, buffs[0], &end)
- if err != nil {
- return 0, err
- }
- eps[0] = &end
- sizes[0] = n
- return 1, nil
-}
-
-func (bind *LinuxSocketBind) Send(buffs [][]byte, end Endpoint) error {
- nend, ok := end.(*LinuxSocketEndpoint)
- if !ok {
- return ErrWrongEndpointType
- }
- bind.mu.RLock()
- defer bind.mu.RUnlock()
- if !nend.isV6 {
- if bind.sock4 == -1 {
- return net.ErrClosed
- }
- for _, buff := range buffs {
- err := send4(bind.sock4, nend, buff)
- if err != nil {
- return err
- }
- }
- } else {
- if bind.sock6 == -1 {
- return net.ErrClosed
- }
- for _, buff := range buffs {
- err := send6(bind.sock6, nend, buff)
- if err != nil {
- return err
- }
- }
- }
- return nil
-}
-
-func (end *LinuxSocketEndpoint) SrcIP() netip.Addr {
- if !end.isV6 {
- return netip.AddrFrom4(end.src4().Src)
- } else {
- return netip.AddrFrom16(end.src6().src)
- }
-}
-
-func (end *LinuxSocketEndpoint) DstIP() netip.Addr {
- if !end.isV6 {
- return netip.AddrFrom4(end.dst4().Addr)
- } else {
- return netip.AddrFrom16(end.dst6().Addr)
- }
-}
-
-func (end *LinuxSocketEndpoint) DstToBytes() []byte {
- if !end.isV6 {
- return (*[unsafe.Offsetof(end.dst4().Addr) + unsafe.Sizeof(end.dst4().Addr)]byte)(unsafe.Pointer(end.dst4()))[:]
- } else {
- return (*[unsafe.Offsetof(end.dst6().Addr) + unsafe.Sizeof(end.dst6().Addr)]byte)(unsafe.Pointer(end.dst6()))[:]
- }
-}
-
-func (end *LinuxSocketEndpoint) SrcToString() string {
- return end.SrcIP().String()
-}
-
-func (end *LinuxSocketEndpoint) DstToString() string {
- var port int
- if !end.isV6 {
- port = end.dst4().Port
- } else {
- port = end.dst6().Port
- }
- return netip.AddrPortFrom(end.DstIP(), uint16(port)).String()
-}
-
-func (end *LinuxSocketEndpoint) ClearDst() {
- for i := range end.dst {
- end.dst[i] = 0
- }
-}
-
-func (end *LinuxSocketEndpoint) ClearSrc() {
- for i := range end.src {
- end.src[i] = 0
- }
-}
-
-func zoneToUint32(zone string) (uint32, error) {
- if zone == "" {
- return 0, nil
- }
- if intr, err := net.InterfaceByName(zone); err == nil {
- return uint32(intr.Index), nil
- }
- n, err := strconv.ParseUint(zone, 10, 32)
- return uint32(n), err
-}
-
-func create4(port uint16) (int, uint16, error) {
- // create socket
-
- fd, err := unix.Socket(
- unix.AF_INET,
- unix.SOCK_DGRAM|unix.SOCK_CLOEXEC,
- 0,
- )
- if err != nil {
- return -1, 0, err
- }
-
- addr := unix.SockaddrInet4{
- Port: int(port),
- }
-
- // set sockopts and bind
-
- if err := func() error {
- if err := unix.SetsockoptInt(
- fd,
- unix.IPPROTO_IP,
- unix.IP_PKTINFO,
- 1,
- ); err != nil {
- return err
- }
-
- return unix.Bind(fd, &addr)
- }(); err != nil {
- unix.Close(fd)
- return -1, 0, err
- }
-
- sa, err := unix.Getsockname(fd)
- if err == nil {
- addr.Port = sa.(*unix.SockaddrInet4).Port
- }
-
- return fd, uint16(addr.Port), err
-}
-
-func create6(port uint16) (int, uint16, error) {
- // create socket
-
- fd, err := unix.Socket(
- unix.AF_INET6,
- unix.SOCK_DGRAM|unix.SOCK_CLOEXEC,
- 0,
- )
- if err != nil {
- return -1, 0, err
- }
-
- // set sockopts and bind
-
- addr := unix.SockaddrInet6{
- Port: int(port),
- }
-
- if err := func() error {
- if err := unix.SetsockoptInt(
- fd,
- unix.IPPROTO_IPV6,
- unix.IPV6_RECVPKTINFO,
- 1,
- ); err != nil {
- return err
- }
-
- if err := unix.SetsockoptInt(
- fd,
- unix.IPPROTO_IPV6,
- unix.IPV6_V6ONLY,
- 1,
- ); err != nil {
- return err
- }
-
- return unix.Bind(fd, &addr)
- }(); err != nil {
- unix.Close(fd)
- return -1, 0, err
- }
-
- sa, err := unix.Getsockname(fd)
- if err == nil {
- addr.Port = sa.(*unix.SockaddrInet6).Port
- }
-
- return fd, uint16(addr.Port), err
-}
-
-func send4(sock int, end *LinuxSocketEndpoint, buff []byte) error {
- // construct message header
-
- cmsg := struct {
- cmsghdr unix.Cmsghdr
- pktinfo unix.Inet4Pktinfo
- }{
- unix.Cmsghdr{
- Level: unix.IPPROTO_IP,
- Type: unix.IP_PKTINFO,
- Len: unix.SizeofInet4Pktinfo + unix.SizeofCmsghdr,
- },
- unix.Inet4Pktinfo{
- Spec_dst: end.src4().Src,
- Ifindex: end.src4().Ifindex,
- },
- }
-
- end.mu.Lock()
- _, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0)
- end.mu.Unlock()
-
- if err == nil {
- return nil
- }
-
- // clear src and retry
-
- if err == unix.EINVAL {
- end.ClearSrc()
- cmsg.pktinfo = unix.Inet4Pktinfo{}
- end.mu.Lock()
- _, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0)
- end.mu.Unlock()
- }
-
- return err
-}
-
-func send6(sock int, end *LinuxSocketEndpoint, buff []byte) error {
- // construct message header
-
- cmsg := struct {
- cmsghdr unix.Cmsghdr
- pktinfo unix.Inet6Pktinfo
- }{
- unix.Cmsghdr{
- Level: unix.IPPROTO_IPV6,
- Type: unix.IPV6_PKTINFO,
- Len: unix.SizeofInet6Pktinfo + unix.SizeofCmsghdr,
- },
- unix.Inet6Pktinfo{
- Addr: end.src6().src,
- Ifindex: end.dst6().ZoneId,
- },
- }
-
- if cmsg.pktinfo.Addr == [16]byte{} {
- cmsg.pktinfo.Ifindex = 0
- }
-
- end.mu.Lock()
- _, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0)
- end.mu.Unlock()
-
- if err == nil {
- return nil
- }
-
- // clear src and retry
-
- if err == unix.EINVAL {
- end.ClearSrc()
- cmsg.pktinfo = unix.Inet6Pktinfo{}
- end.mu.Lock()
- _, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0)
- end.mu.Unlock()
- }
-
- return err
-}
-
-func receive4(sock int, buff []byte, end *LinuxSocketEndpoint) (int, error) {
- // construct message header
-
- var cmsg struct {
- cmsghdr unix.Cmsghdr
- pktinfo unix.Inet4Pktinfo
- }
-
- size, _, _, newDst, err := unix.Recvmsg(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], 0)
- if err != nil {
- return 0, err
- }
- end.isV6 = false
-
- if newDst4, ok := newDst.(*unix.SockaddrInet4); ok {
- *end.dst4() = *newDst4
- }
-
- // update source cache
-
- if cmsg.cmsghdr.Level == unix.IPPROTO_IP &&
- cmsg.cmsghdr.Type == unix.IP_PKTINFO &&
- cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo {
- end.src4().Src = cmsg.pktinfo.Spec_dst
- end.src4().Ifindex = cmsg.pktinfo.Ifindex
- }
-
- return size, nil
-}
-
-func receive6(sock int, buff []byte, end *LinuxSocketEndpoint) (int, error) {
- // construct message header
-
- var cmsg struct {
- cmsghdr unix.Cmsghdr
- pktinfo unix.Inet6Pktinfo
- }
-
- size, _, _, newDst, err := unix.Recvmsg(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], 0)
- if err != nil {
- return 0, err
- }
- end.isV6 = true
-
- if newDst6, ok := newDst.(*unix.SockaddrInet6); ok {
- *end.dst6() = *newDst6
- }
-
- // update source cache
-
- if cmsg.cmsghdr.Level == unix.IPPROTO_IPV6 &&
- cmsg.cmsghdr.Type == unix.IPV6_PKTINFO &&
- cmsg.cmsghdr.Len >= unix.SizeofInet6Pktinfo {
- end.src6().src = cmsg.pktinfo.Addr
- end.dst6().ZoneId = cmsg.pktinfo.Ifindex
- }
-
- return size, nil
-}
diff --git a/conn/bind_std.go b/conn/bind_std.go
index 98fe23c..a164f56 100644
--- a/conn/bind_std.go
+++ b/conn/bind_std.go
@@ -6,32 +6,91 @@
package conn
import (
+ "context"
"errors"
"net"
"net/netip"
+ "strconv"
"sync"
"syscall"
+
+ "golang.org/x/net/ipv4"
+ "golang.org/x/net/ipv6"
+)
+
+var (
+ _ Bind = (*StdNetBind)(nil)
)
-// StdNetBind is meant to be a temporary solution on platforms for which
-// the sticky socket / source caching behavior has not yet been implemented.
-// It uses the Go's net package to implement networking.
-// See LinuxSocketBind for a proper implementation on the Linux platform.
+// StdNetBind implements Bind for all platforms except Windows.
type StdNetBind struct {
- mu sync.Mutex // protects following fields
- ipv4 *net.UDPConn
- ipv6 *net.UDPConn
- blackhole4 bool
- blackhole6 bool
+ mu sync.Mutex // protects following fields
+ ipv4 *net.UDPConn
+ ipv6 *net.UDPConn
+ blackhole4 bool
+ blackhole6 bool
+ ipv4PC *ipv4.PacketConn
+ ipv6PC *ipv6.PacketConn
+ batchSize int
+ udpAddrPool sync.Pool
+ ipv4MsgsPool sync.Pool
+ ipv6MsgsPool sync.Pool
}
-func NewStdNetBind() Bind { return &StdNetBind{} }
+func NewStdNetBind() Bind { return NewStdNetBindBatch(DefaultBatchSize) }
+
+func NewStdNetBindBatch(maxBatchSize int) Bind {
+ if maxBatchSize == 0 {
+ maxBatchSize = DefaultBatchSize
+ }
+ return &StdNetBind{
+ batchSize: maxBatchSize,
+
+ udpAddrPool: sync.Pool{
+ New: func() any {
+ return &net.UDPAddr{
+ IP: make([]byte, 16),
+ }
+ },
+ },
-type StdNetEndpoint netip.AddrPort
+ ipv4MsgsPool: sync.Pool{
+ New: func() any {
+ msgs := make([]ipv4.Message, maxBatchSize)
+ for i := range msgs {
+ msgs[i].Buffers = make(net.Buffers, 1)
+ msgs[i].OOB = make([]byte, srcControlSize)
+ }
+ return &msgs
+ },
+ },
+
+ ipv6MsgsPool: sync.Pool{
+ New: func() any {
+ msgs := make([]ipv6.Message, maxBatchSize)
+ for i := range msgs {
+ msgs[i].Buffers = make(net.Buffers, 1)
+ msgs[i].OOB = make([]byte, srcControlSize)
+ }
+ return &msgs
+ },
+ },
+ }
+}
+
+type StdNetEndpoint struct {
+ // AddrPort is the endpoint destination.
+ netip.AddrPort
+ // src is the current sticky source address and interface index, if supported.
+ src struct {
+ netip.Addr
+ ifidx int32
+ }
+}
var (
_ Bind = (*StdNetBind)(nil)
- _ Endpoint = StdNetEndpoint{}
+ _ Endpoint = &StdNetEndpoint{}
)
func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) {
@@ -39,31 +98,38 @@ func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) {
return asEndpoint(e), err
}
-func (StdNetEndpoint) ClearSrc() {}
+func (e *StdNetEndpoint) ClearSrc() {
+ e.src.ifidx = 0
+ e.src.Addr = netip.Addr{}
+}
+
+func (e *StdNetEndpoint) DstIP() netip.Addr {
+ return e.AddrPort.Addr()
+}
-func (e StdNetEndpoint) DstIP() netip.Addr {
- return (netip.AddrPort)(e).Addr()
+func (e *StdNetEndpoint) SrcIP() netip.Addr {
+ return e.src.Addr
}
-func (e StdNetEndpoint) SrcIP() netip.Addr {
- return netip.Addr{} // not supported
+func (e *StdNetEndpoint) SrcIfidx() int32 {
+ return e.src.ifidx
}
-func (e StdNetEndpoint) DstToBytes() []byte {
- b, _ := (netip.AddrPort)(e).MarshalBinary()
+func (e *StdNetEndpoint) DstToBytes() []byte {
+ b, _ := e.AddrPort.MarshalBinary()
return b
}
-func (e StdNetEndpoint) DstToString() string {
- return (netip.AddrPort)(e).String()
+func (e *StdNetEndpoint) DstToString() string {
+ return e.AddrPort.String()
}
-func (e StdNetEndpoint) SrcToString() string {
- return ""
+func (e *StdNetEndpoint) SrcToString() string {
+ return e.src.Addr.String()
}
func listenNet(network string, port int) (*net.UDPConn, int, error) {
- conn, err := net.ListenUDP(network, &net.UDPAddr{Port: port})
+ conn, err := listenConfig().ListenPacket(context.Background(), network, ":"+strconv.Itoa(port))
if err != nil {
return nil, 0, err
}
@@ -77,17 +143,17 @@ func listenNet(network string, port int) (*net.UDPConn, int, error) {
if err != nil {
return nil, 0, err
}
- return conn, uaddr.Port, nil
+ return conn.(*net.UDPConn), uaddr.Port, nil
}
-func (bind *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
- bind.mu.Lock()
- defer bind.mu.Unlock()
+func (s *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
var err error
var tries int
- if bind.ipv4 != nil || bind.ipv6 != nil {
+ if s.ipv4 != nil || s.ipv6 != nil {
return nil, 0, ErrBindAlreadyOpen
}
@@ -95,104 +161,121 @@ func (bind *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
// If uport is 0, we can retry on failure.
again:
port := int(uport)
- var ipv4, ipv6 *net.UDPConn
+ var v4conn, v6conn *net.UDPConn
- ipv4, port, err = listenNet("udp4", port)
+ v4conn, port, err = listenNet("udp4", port)
if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
return nil, 0, err
}
// Listen on the same port as we're using for ipv4.
- ipv6, port, err = listenNet("udp6", port)
+ v6conn, port, err = listenNet("udp6", port)
if uport == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 {
- ipv4.Close()
+ v4conn.Close()
tries++
goto again
}
if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
- ipv4.Close()
+ v4conn.Close()
return nil, 0, err
}
var fns []ReceiveFunc
- if ipv4 != nil {
- fns = append(fns, bind.makeReceiveIPv4(ipv4))
- bind.ipv4 = ipv4
+ if v4conn != nil {
+ fns = append(fns, s.receiveIPv4)
+ s.ipv4 = v4conn
}
- if ipv6 != nil {
- fns = append(fns, bind.makeReceiveIPv6(ipv6))
- bind.ipv6 = ipv6
+ if v6conn != nil {
+ fns = append(fns, s.receiveIPv6)
+ s.ipv6 = v6conn
}
if len(fns) == 0 {
return nil, 0, syscall.EAFNOSUPPORT
}
- return fns, uint16(port), nil
-}
-func (bind *StdNetBind) BatchSize() int {
- return 1
-}
+ s.ipv4PC = ipv4.NewPacketConn(s.ipv4)
+ s.ipv6PC = ipv6.NewPacketConn(s.ipv6)
-func (bind *StdNetBind) Close() error {
- bind.mu.Lock()
- defer bind.mu.Unlock()
+ return fns, uint16(port), nil
+}
- var err1, err2 error
- if bind.ipv4 != nil {
- err1 = bind.ipv4.Close()
- bind.ipv4 = nil
+func (s *StdNetBind) receiveIPv4(buffs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
+ msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message)
+ defer s.ipv4MsgsPool.Put(msgs)
+ for i := range buffs {
+ (*msgs)[i].Buffers[0] = buffs[i]
}
- if bind.ipv6 != nil {
- err2 = bind.ipv6.Close()
- bind.ipv6 = nil
+ numMsgs, err := s.ipv4PC.ReadBatch(*msgs, 0)
+ if err != nil {
+ return 0, err
}
- bind.blackhole4 = false
- bind.blackhole6 = false
- if err1 != nil {
- return err1
+ for i := 0; i < numMsgs; i++ {
+ msg := &(*msgs)[i]
+ sizes[i] = msg.N
+ addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
+ ep := asEndpoint(addrPort)
+ getSrcFromControl(msg.OOB, ep)
+ eps[i] = ep
}
- return err2
+ return numMsgs, nil
}
-func (*StdNetBind) makeReceiveIPv4(conn *net.UDPConn) ReceiveFunc {
- return func(buffs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
- size, endpoint, err := conn.ReadFromUDPAddrPort(buffs[0])
- if err == nil {
- sizes[0] = size
- eps[0] = asEndpoint(endpoint)
- return 1, nil
- }
+func (s *StdNetBind) receiveIPv6(buffs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
+ msgs := s.ipv6MsgsPool.Get().(*[]ipv6.Message)
+ defer s.ipv6MsgsPool.Put(msgs)
+ for i := range buffs {
+ (*msgs)[i].Buffers[0] = buffs[i]
+ }
+ numMsgs, err := s.ipv6PC.ReadBatch(*msgs, 0)
+ if err != nil {
return 0, err
}
+ for i := 0; i < numMsgs; i++ {
+ msg := &(*msgs)[i]
+ sizes[i] = msg.N
+ addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
+ ep := asEndpoint(addrPort)
+ getSrcFromControl(msg.OOB, ep)
+ eps[i] = ep
+ }
+ return numMsgs, nil
}
-func (*StdNetBind) makeReceiveIPv6(conn *net.UDPConn) ReceiveFunc {
- return func(buffs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
- size, endpoint, err := conn.ReadFromUDPAddrPort(buffs[0])
- if err == nil {
- sizes[0] = size
- eps[0] = asEndpoint(endpoint)
- return 1, nil
- }
- return 0, err
- }
+func (s *StdNetBind) BatchSize() int {
+ return s.batchSize
}
-func (bind *StdNetBind) Send(buffs [][]byte, endpoint Endpoint) error {
- var err error
- nend, ok := endpoint.(StdNetEndpoint)
- if !ok {
- return ErrWrongEndpointType
+func (s *StdNetBind) Close() error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ var err1, err2 error
+ if s.ipv4 != nil {
+ err1 = s.ipv4.Close()
+ s.ipv4 = nil
+ }
+ if s.ipv6 != nil {
+ err2 = s.ipv6.Close()
+ s.ipv6 = nil
+ }
+ s.blackhole4 = false
+ s.blackhole6 = false
+ if err1 != nil {
+ return err1
}
- addrPort := netip.AddrPort(nend)
+ return err2
+}
- bind.mu.Lock()
- blackhole := bind.blackhole4
- conn := bind.ipv4
- if addrPort.Addr().Is6() {
- blackhole = bind.blackhole6
- conn = bind.ipv6
+func (s *StdNetBind) Send(buffs [][]byte, endpoint Endpoint) error {
+ s.mu.Lock()
+ blackhole := s.blackhole4
+ conn := s.ipv4
+ is6 := false
+ if endpoint.DstIP().Is6() {
+ blackhole = s.blackhole6
+ conn = s.ipv6
+ is6 = true
}
- bind.mu.Unlock()
+ s.mu.Unlock()
if blackhole {
return nil
@@ -200,13 +283,69 @@ func (bind *StdNetBind) Send(buffs [][]byte, endpoint Endpoint) error {
if conn == nil {
return syscall.EAFNOSUPPORT
}
- for _, buff := range buffs {
- _, err = conn.WriteToUDPAddrPort(buff, addrPort)
- if err != nil {
- return err
+ if is6 {
+ return s.send6(s.ipv6PC, endpoint, buffs)
+ } else {
+ return s.send4(s.ipv4PC, endpoint, buffs)
+ }
+}
+
+func (s *StdNetBind) send4(conn *ipv4.PacketConn, ep Endpoint, buffs [][]byte) error {
+ ua := s.udpAddrPool.Get().(*net.UDPAddr)
+ as4 := ep.DstIP().As4()
+ copy(ua.IP, as4[:])
+ ua.IP = ua.IP[:4]
+ ua.Port = int(ep.(*StdNetEndpoint).Port())
+ msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message)
+ for i, buff := range buffs {
+ (*msgs)[i].Buffers[0] = buff
+ (*msgs)[i].Addr = ua
+ setSrcControl(&(*msgs)[i].OOB, ep.(*StdNetEndpoint))
+ }
+ var (
+ n int
+ err error
+ start int
+ )
+ for {
+ n, err = conn.WriteBatch((*msgs)[start:len(buffs)], 0)
+ if err != nil || n == len((*msgs)[start:len(buffs)]) {
+ break
+ }
+ start += n
+ }
+ s.udpAddrPool.Put(ua)
+ s.ipv4MsgsPool.Put(msgs)
+ return err
+}
+
+func (s *StdNetBind) send6(conn *ipv6.PacketConn, ep Endpoint, buffs [][]byte) error {
+ ua := s.udpAddrPool.Get().(*net.UDPAddr)
+ as16 := ep.DstIP().As16()
+ copy(ua.IP, as16[:])
+ ua.IP = ua.IP[:16]
+ ua.Port = int(ep.(*StdNetEndpoint).Port())
+ msgs := s.ipv6MsgsPool.Get().(*[]ipv6.Message)
+ for i, buff := range buffs {
+ (*msgs)[i].Buffers[0] = buff
+ (*msgs)[i].Addr = ua
+ setSrcControl(&(*msgs)[i].OOB, ep.(*StdNetEndpoint))
+ }
+ var (
+ n int
+ err error
+ start int
+ )
+ for {
+ n, err = conn.WriteBatch((*msgs)[start:len(buffs)], 0)
+ if err != nil || n == len((*msgs)[start:len(buffs)]) {
+ break
}
+ start += n
}
- return nil
+ s.udpAddrPool.Put(ua)
+ s.ipv6MsgsPool.Put(msgs)
+ return err
}
// endpointPool contains a re-usable set of mapping from netip.AddrPort to Endpoint.
@@ -214,17 +353,17 @@ func (bind *StdNetBind) Send(buffs [][]byte, endpoint Endpoint) error {
// but Endpoints are immutable, so we can re-use them.
var endpointPool = sync.Pool{
New: func() any {
- return make(map[netip.AddrPort]Endpoint)
+ return make(map[netip.AddrPort]*StdNetEndpoint)
},
}
// asEndpoint returns an Endpoint containing ap.
-func asEndpoint(ap netip.AddrPort) Endpoint {
- m := endpointPool.Get().(map[netip.AddrPort]Endpoint)
+func asEndpoint(ap netip.AddrPort) *StdNetEndpoint {
+ m := endpointPool.Get().(map[netip.AddrPort]*StdNetEndpoint)
defer endpointPool.Put(m)
e, ok := m[ap]
if !ok {
- e = Endpoint(StdNetEndpoint(ap))
+ e = &StdNetEndpoint{AddrPort: ap}
m[ap] = e
}
return e
diff --git a/conn/boundif_android.go b/conn/boundif_android.go
index 818e4e6..dd3ca5b 100644
--- a/conn/boundif_android.go
+++ b/conn/boundif_android.go
@@ -5,8 +5,8 @@
package conn
-func (bind *StdNetBind) PeekLookAtSocketFd4() (fd int, err error) {
- sysconn, err := bind.ipv4.SyscallConn()
+func (s *StdNetBind) PeekLookAtSocketFd4() (fd int, err error) {
+ sysconn, err := s.ipv4.SyscallConn()
if err != nil {
return -1, err
}
@@ -19,8 +19,8 @@ func (bind *StdNetBind) PeekLookAtSocketFd4() (fd int, err error) {
return
}
-func (bind *StdNetBind) PeekLookAtSocketFd6() (fd int, err error) {
- sysconn, err := bind.ipv6.SyscallConn()
+func (s *StdNetBind) PeekLookAtSocketFd6() (fd int, err error) {
+ sysconn, err := s.ipv6.SyscallConn()
if err != nil {
return -1, err
}
diff --git a/conn/conn.go b/conn/conn.go
index 8c0a827..9cbd0af 100644
--- a/conn/conn.go
+++ b/conn/conn.go
@@ -16,7 +16,7 @@ import (
)
const (
- DefaultBatchSize = 1 // maximum number of packets handled per read and write
+ DefaultBatchSize = 128 // maximum number of packets handled per read and write
)
// A ReceiveFunc receives at least one packet from the network and writes them
diff --git a/conn/controlfns.go b/conn/controlfns.go
new file mode 100644
index 0000000..fe32871
--- /dev/null
+++ b/conn/controlfns.go
@@ -0,0 +1,36 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+import (
+ "net"
+ "syscall"
+)
+
+// controlFn is the callback function signature from net.ListenConfig.Control.
+// It is used to apply platform specific configuration to the socket prior to
+// bind.
+type controlFn func(network, address string, c syscall.RawConn) error
+
+// controlFns is a list of functions that are called from the listen config
+// that can apply socket options.
+var controlFns = []controlFn{}
+
+// listenConfig returns a net.ListenConfig that applies the controlFns to the
+// socket prior to bind. This is used to apply socket buffer sizing and packet
+// information OOB configuration for sticky sockets.
+func listenConfig() *net.ListenConfig {
+ return &net.ListenConfig{
+ Control: func(network, address string, c syscall.RawConn) error {
+ for _, fn := range controlFns {
+ if err := fn(network, address, c); err != nil {
+ return err
+ }
+ }
+ return nil
+ },
+ }
+}
diff --git a/conn/controlfns_linux.go b/conn/controlfns_linux.go
new file mode 100644
index 0000000..9e26d95
--- /dev/null
+++ b/conn/controlfns_linux.go
@@ -0,0 +1,41 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+import (
+ "fmt"
+ "syscall"
+
+ "golang.org/x/sys/unix"
+)
+
+func init() {
+ controlFns = append(controlFns,
+
+ // Enable receiving of the packet information (IP_PKTINFO for IPv4,
+ // IPV6_PKTINFO for IPv6) that is used to implement sticky socket support.
+ func(network, address string, c syscall.RawConn) error {
+ var err error
+ switch network {
+ case "udp4":
+ c.Control(func(fd uintptr) {
+ err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_PKTINFO, 1)
+ })
+ case "udp6":
+ c.Control(func(fd uintptr) {
+ err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVPKTINFO, 1)
+ if err != nil {
+ return
+ }
+ err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_V6ONLY, 1)
+ })
+ default:
+ err = fmt.Errorf("unhandled network: %s: %w", network, unix.EINVAL)
+ }
+ return err
+ },
+ )
+}
diff --git a/conn/controlfns_unix.go b/conn/controlfns_unix.go
new file mode 100644
index 0000000..9738c73
--- /dev/null
+++ b/conn/controlfns_unix.go
@@ -0,0 +1,28 @@
+//go:build !windows && !linux && !js
+
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+import (
+ "syscall"
+
+ "golang.org/x/sys/unix"
+)
+
+func init() {
+ controlFns = append(controlFns,
+ func(network, address string, c syscall.RawConn) error {
+ var err error
+ if network == "udp6" {
+ c.Control(func(fd uintptr) {
+ err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_V6ONLY, 1)
+ })
+ }
+ return err
+ },
+ )
+}
diff --git a/conn/default.go b/conn/default.go
index c7b4a84..b6f761b 100644
--- a/conn/default.go
+++ b/conn/default.go
@@ -1,4 +1,4 @@
-//go:build !linux && !windows
+//go:build !windows
/* SPDX-License-Identifier: MIT
*
diff --git a/conn/mark_default.go b/conn/mark_default.go
index 9944c38..3102384 100644
--- a/conn/mark_default.go
+++ b/conn/mark_default.go
@@ -7,6 +7,6 @@
package conn
-func (bind *StdNetBind) SetMark(mark uint32) error {
+func (s *StdNetBind) SetMark(mark uint32) error {
return nil
}
diff --git a/conn/mark_unix.go b/conn/mark_unix.go
index 5566b28..d9e46ee 100644
--- a/conn/mark_unix.go
+++ b/conn/mark_unix.go
@@ -26,13 +26,13 @@ func init() {
}
}
-func (bind *StdNetBind) SetMark(mark uint32) error {
+func (s *StdNetBind) SetMark(mark uint32) error {
var operr error
if fwmarkIoctl == 0 {
return nil
}
- if bind.ipv4 != nil {
- fd, err := bind.ipv4.SyscallConn()
+ if s.ipv4 != nil {
+ fd, err := s.ipv4.SyscallConn()
if err != nil {
return err
}
@@ -46,8 +46,8 @@ func (bind *StdNetBind) SetMark(mark uint32) error {
return err
}
}
- if bind.ipv6 != nil {
- fd, err := bind.ipv6.SyscallConn()
+ if s.ipv6 != nil {
+ fd, err := s.ipv6.SyscallConn()
if err != nil {
return err
}
diff --git a/conn/sticky_default.go b/conn/sticky_default.go
new file mode 100644
index 0000000..3ce9a56
--- /dev/null
+++ b/conn/sticky_default.go
@@ -0,0 +1,26 @@
+//go:build !linux
+// +build !linux
+
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+// TODO: macOS, FreeBSD and other BSDs likely do support this feature set, but
+// use alternatively named flags and need ports and require testing.
+
+// getSrcFromControl parses the control for PKTINFO and if found updates ep with
+// the source information found.
+func getSrcFromControl(control []byte, ep *StdNetEndpoint) {
+}
+
+// setSrcControl parses the control for PKTINFO and if found updates ep with
+// the source information found.
+func setSrcControl(control *[]byte, ep *StdNetEndpoint) {
+}
+
+// srcControlSize returns the recommended buffer size for pooling sticky control
+// data.
+const srcControlSize = 0
diff --git a/conn/sticky_linux.go b/conn/sticky_linux.go
new file mode 100644
index 0000000..bf17839
--- /dev/null
+++ b/conn/sticky_linux.go
@@ -0,0 +1,111 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+import (
+ "net/netip"
+ "unsafe"
+
+ "golang.org/x/sys/unix"
+)
+
+// getSrcFromControl parses the control for PKTINFO and if found updates ep with
+// the source information found.
+func getSrcFromControl(control []byte, ep *StdNetEndpoint) {
+ ep.ClearSrc()
+
+ var (
+ hdr unix.Cmsghdr
+ data []byte
+ rem []byte = control
+ err error
+ )
+
+ for len(rem) > unix.SizeofCmsghdr {
+ hdr, data, rem, err = unix.ParseOneSocketControlMessage(control)
+ if err != nil {
+ return
+ }
+
+ if hdr.Level == unix.IPPROTO_IP &&
+ hdr.Type == unix.IP_PKTINFO {
+
+ info := pktInfoFromBuf[unix.Inet4Pktinfo](data)
+ ep.src.Addr = netip.AddrFrom4(info.Spec_dst)
+ ep.src.ifidx = info.Ifindex
+
+ return
+ }
+
+ if hdr.Level == unix.IPPROTO_IPV6 &&
+ hdr.Type == unix.IPV6_PKTINFO {
+
+ info := pktInfoFromBuf[unix.Inet6Pktinfo](data)
+ ep.src.Addr = netip.AddrFrom16(info.Addr)
+ ep.src.ifidx = int32(info.Ifindex)
+
+ return
+ }
+ }
+}
+
+// pktInfoFromBuf returns type T populated from the provided buf via copy(). It
+// panics if buf is of insufficient size.
+func pktInfoFromBuf[T unix.Inet4Pktinfo | unix.Inet6Pktinfo](buf []byte) (t T) {
+ size := int(unsafe.Sizeof(t))
+ if len(buf) < size {
+ panic("pktInfoFromBuf: buffer too small")
+ }
+ copy(unsafe.Slice((*byte)(unsafe.Pointer(&t)), size), buf)
+ return t
+}
+
+// setSrcControl parses the control for PKTINFO and if found updates ep with
+// the source information found.
+func setSrcControl(control *[]byte, ep *StdNetEndpoint) {
+ *control = (*control)[:cap(*control)]
+ if len(*control) < int(unsafe.Sizeof(unix.Cmsghdr{})) {
+ *control = (*control)[:0]
+ return
+ }
+
+ if ep.src.ifidx == 0 && !ep.SrcIP().IsValid() {
+ *control = (*control)[:0]
+ return
+ }
+
+ if len(*control) < srcControlSize {
+ *control = (*control)[:0]
+ return
+ }
+
+ hdr := (*unix.Cmsghdr)(unsafe.Pointer(&(*control)[0]))
+ if ep.SrcIP().Is4() {
+ hdr.Level = unix.IPPROTO_IP
+ hdr.Type = unix.IP_PKTINFO
+ hdr.SetLen(unix.CmsgLen(unix.SizeofInet4Pktinfo))
+
+ info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&(*control)[unix.SizeofCmsghdr]))
+ info.Ifindex = ep.src.ifidx
+ if ep.SrcIP().IsValid() {
+ info.Spec_dst = ep.SrcIP().As4()
+ }
+ } else {
+ hdr.Level = unix.IPPROTO_IPV6
+ hdr.Type = unix.IPV6_PKTINFO
+ hdr.Len = unix.SizeofCmsghdr + unix.SizeofInet6Pktinfo
+
+ info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&(*control)[unix.SizeofCmsghdr]))
+ info.Ifindex = uint32(ep.src.ifidx)
+ if ep.SrcIP().IsValid() {
+ info.Addr = ep.SrcIP().As16()
+ }
+ }
+
+ *control = (*control)[:hdr.Len]
+}
+
+var srcControlSize = unix.CmsgLen(unix.SizeofInet6Pktinfo)
diff --git a/conn/sticky_linux_test.go b/conn/sticky_linux_test.go
new file mode 100644
index 0000000..a42c89e
--- /dev/null
+++ b/conn/sticky_linux_test.go
@@ -0,0 +1,207 @@
+//go:build linux
+// +build linux
+
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+import (
+ "context"
+ "net"
+ "net/netip"
+ "runtime"
+ "testing"
+ "unsafe"
+
+ "golang.org/x/sys/unix"
+)
+
+func Test_setSrcControl(t *testing.T) {
+ t.Run("IPv4", func(t *testing.T) {
+ ep := &StdNetEndpoint{
+ AddrPort: netip.MustParseAddrPort("127.0.0.1:1234"),
+ }
+ ep.src.Addr = netip.MustParseAddr("127.0.0.1")
+ ep.src.ifidx = 5
+
+ control := make([]byte, srcControlSize)
+
+ setSrcControl(&control, ep)
+
+ hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
+ if hdr.Level != unix.IPPROTO_IP {
+ t.Errorf("unexpected level: %d", hdr.Level)
+ }
+ if hdr.Type != unix.IP_PKTINFO {
+ t.Errorf("unexpected type: %d", hdr.Type)
+ }
+ if hdr.Len != uint64(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet4Pktinfo{})))) {
+ t.Errorf("unexpected length: %d", hdr.Len)
+ }
+ info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
+ if info.Spec_dst[0] != 127 || info.Spec_dst[1] != 0 || info.Spec_dst[2] != 0 || info.Spec_dst[3] != 1 {
+ t.Errorf("unexpected address: %v", info.Spec_dst)
+ }
+ if info.Ifindex != 5 {
+ t.Errorf("unexpected ifindex: %d", info.Ifindex)
+ }
+ })
+
+ t.Run("IPv6", func(t *testing.T) {
+ ep := &StdNetEndpoint{
+ AddrPort: netip.MustParseAddrPort("[::1]:1234"),
+ }
+ ep.src.Addr = netip.MustParseAddr("::1")
+ ep.src.ifidx = 5
+
+ control := make([]byte, srcControlSize)
+
+ setSrcControl(&control, ep)
+
+ hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
+ if hdr.Level != unix.IPPROTO_IPV6 {
+ t.Errorf("unexpected level: %d", hdr.Level)
+ }
+ if hdr.Type != unix.IPV6_PKTINFO {
+ t.Errorf("unexpected type: %d", hdr.Type)
+ }
+ if hdr.Len != uint64(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet6Pktinfo{})))) {
+ t.Errorf("unexpected length: %d", hdr.Len)
+ }
+ info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
+ if info.Addr != ep.SrcIP().As16() {
+ t.Errorf("unexpected address: %v", info.Addr)
+ }
+ if info.Ifindex != 5 {
+ t.Errorf("unexpected ifindex: %d", info.Ifindex)
+ }
+ })
+
+ t.Run("ClearOnNoSrc", func(t *testing.T) {
+ control := make([]byte, srcControlSize)
+ hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
+ hdr.Level = 1
+ hdr.Type = 2
+ hdr.Len = 3
+
+ setSrcControl(&control, &StdNetEndpoint{})
+
+ if len(control) != 0 {
+ t.Errorf("unexpected control: %v", control)
+ }
+ })
+}
+
+func Test_getSrcFromControl(t *testing.T) {
+ t.Run("IPv4", func(t *testing.T) {
+ control := make([]byte, srcControlSize)
+ hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
+ hdr.Level = unix.IPPROTO_IP
+ hdr.Type = unix.IP_PKTINFO
+ hdr.Len = uint64(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet4Pktinfo{}))))
+ info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
+ info.Spec_dst = [4]byte{127, 0, 0, 1}
+ info.Ifindex = 5
+
+ ep := &StdNetEndpoint{}
+ getSrcFromControl(control, ep)
+
+ if ep.src.Addr != netip.MustParseAddr("127.0.0.1") {
+ t.Errorf("unexpected address: %v", ep.src.Addr)
+ }
+ if ep.src.ifidx != 5 {
+ t.Errorf("unexpected ifindex: %d", ep.src.ifidx)
+ }
+ })
+ t.Run("IPv6", func(t *testing.T) {
+ control := make([]byte, srcControlSize)
+ hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
+ hdr.Level = unix.IPPROTO_IPV6
+ hdr.Type = unix.IPV6_PKTINFO
+ hdr.Len = uint64(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet6Pktinfo{}))))
+ info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
+ info.Addr = [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}
+ info.Ifindex = 5
+
+ ep := &StdNetEndpoint{}
+ getSrcFromControl(control, ep)
+
+ if ep.SrcIP() != netip.MustParseAddr("::1") {
+ t.Errorf("unexpected address: %v", ep.SrcIP())
+ }
+ if ep.src.ifidx != 5 {
+ t.Errorf("unexpected ifindex: %d", ep.src.ifidx)
+ }
+ })
+ t.Run("ClearOnEmpty", func(t *testing.T) {
+ control := make([]byte, srcControlSize)
+ ep := &StdNetEndpoint{}
+ ep.src.Addr = netip.MustParseAddr("::1")
+ ep.src.ifidx = 5
+
+ getSrcFromControl(control, ep)
+ if ep.SrcIP().IsValid() {
+ t.Errorf("unexpected address: %v", ep.src.Addr)
+ }
+ if ep.src.ifidx != 0 {
+ t.Errorf("unexpected ifindex: %d", ep.src.ifidx)
+ }
+ })
+}
+
+func Test_listenConfig(t *testing.T) {
+ t.Run("IPv4", func(t *testing.T) {
+ conn, err := listenConfig().ListenPacket(context.Background(), "udp4", ":0")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer conn.Close()
+ sc, err := conn.(*net.UDPConn).SyscallConn()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if runtime.GOOS == "linux" {
+ var i int
+ sc.Control(func(fd uintptr) {
+ i, err = unix.GetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_PKTINFO)
+ })
+ if err != nil {
+ t.Fatal(err)
+ }
+ if i != 1 {
+ t.Error("IP_PKTINFO not set!")
+ }
+ } else {
+ t.Logf("listenConfig() does not set IPV6_RECVPKTINFO on %s", runtime.GOOS)
+ }
+ })
+ t.Run("IPv6", func(t *testing.T) {
+ conn, err := listenConfig().ListenPacket(context.Background(), "udp6", ":0")
+ if err != nil {
+ t.Fatal(err)
+ }
+ sc, err := conn.(*net.UDPConn).SyscallConn()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if runtime.GOOS == "linux" {
+ var i int
+ sc.Control(func(fd uintptr) {
+ i, err = unix.GetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVPKTINFO)
+ })
+ if err != nil {
+ t.Fatal(err)
+ }
+ if i != 1 {
+ t.Error("IPV6_PKTINFO not set!")
+ }
+ } else {
+ t.Logf("listenConfig() does not set IPV6_RECVPKTINFO on %s", runtime.GOOS)
+ }
+ })
+}
diff --git a/device/queueconstants_android.go b/device/queueconstants_android.go
index fc937b3..1158387 100644
--- a/device/queueconstants_android.go
+++ b/device/queueconstants_android.go
@@ -5,10 +5,12 @@
package device
+import "golang.zx2c4.com/wireguard/conn"
+
/* Reduce memory consumption for Android */
const (
- QueueStagedSize = 128
+ QueueStagedSize = conn.DefaultBatchSize
QueueOutboundSize = 1024
QueueInboundSize = 1024
QueueHandshakeSize = 1024
diff --git a/device/queueconstants_default.go b/device/queueconstants_default.go
index 6b69150..7ed70a1 100644
--- a/device/queueconstants_default.go
+++ b/device/queueconstants_default.go
@@ -7,8 +7,10 @@
package device
+import "golang.zx2c4.com/wireguard/conn"
+
const (
- QueueStagedSize = 128
+ QueueStagedSize = conn.DefaultBatchSize
QueueOutboundSize = 1024
QueueInboundSize = 1024
QueueHandshakeSize = 1024
diff --git a/device/sticky_linux.go b/device/sticky_linux.go
index 7afdf28..3ce0769 100644
--- a/device/sticky_linux.go
+++ b/device/sticky_linux.go
@@ -25,7 +25,7 @@ import (
)
func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) {
- if _, ok := bind.(*conn.LinuxSocketBind); !ok {
+ if _, ok := bind.(*conn.StdNetBind); !ok {
return nil, nil
}
@@ -112,11 +112,11 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl
pePtr.peer.Unlock()
break
}
- if uint32(pePtr.peer.endpoint.(*conn.LinuxSocketEndpoint).Src4().Ifindex) == ifidx {
+ if uint32(pePtr.peer.endpoint.(*conn.StdNetEndpoint).SrcIfidx()) == ifidx {
pePtr.peer.Unlock()
break
}
- pePtr.peer.endpoint.(*conn.LinuxSocketEndpoint).ClearSrc()
+ pePtr.peer.endpoint.(*conn.StdNetEndpoint).ClearSrc()
pePtr.peer.Unlock()
}
attr = attr[attrhdr.Len:]
@@ -136,12 +136,12 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl
peer.RUnlock()
continue
}
- nativeEP, _ := peer.endpoint.(*conn.LinuxSocketEndpoint)
+ nativeEP, _ := peer.endpoint.(*conn.StdNetEndpoint)
if nativeEP == nil {
peer.RUnlock()
continue
}
- if nativeEP.IsV6() || nativeEP.Src4().Ifindex == 0 {
+ if nativeEP.DstIP().Is6() || nativeEP.SrcIfidx() == 0 {
peer.RUnlock()
break
}
@@ -169,12 +169,12 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl
Len: 8,
Type: unix.RTA_DST,
},
- nativeEP.Dst4().Addr,
+ nativeEP.DstIP().As4(),
unix.RtAttr{
Len: 8,
Type: unix.RTA_SRC,
},
- nativeEP.Src4().Src,
+ nativeEP.SrcIP().As4(),
unix.RtAttr{
Len: 8,
Type: unix.RTA_MARK,
diff --git a/go.mod b/go.mod
index 46427df..a6f7254 100644
--- a/go.mod
+++ b/go.mod
@@ -3,9 +3,9 @@ module golang.zx2c4.com/wireguard
go 1.19
require (
- golang.org/x/crypto v0.0.0-20220315160706-3147a52a75dd
- golang.org/x/net v0.0.0-20220225172249-27dd8689420f
- golang.org/x/sys v0.2.0
+ golang.org/x/crypto v0.3.0
+ golang.org/x/net v0.2.0
+ golang.org/x/sys v0.5.1-0.20230222185716-a3b23cc77e89
golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224
gvisor.dev/gvisor v0.0.0-20221203005347-703fd9b7fbc0
)
diff --git a/go.sum b/go.sum
index 97ae95d..a967a10 100644
--- a/go.sum
+++ b/go.sum
@@ -1,11 +1,11 @@
github.com/google/btree v1.0.1 h1:gK4Kx5IaGY9CD5sPJ36FHiBJ6ZXl0kilRiiCj+jdYp4=
github.com/google/btree v1.0.1/go.mod h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9uaxA=
-golang.org/x/crypto v0.0.0-20220315160706-3147a52a75dd h1:XcWmESyNjXJMLahc3mqVQJcgSTDxFxhETVlfk9uGc38=
-golang.org/x/crypto v0.0.0-20220315160706-3147a52a75dd/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
-golang.org/x/net v0.0.0-20220225172249-27dd8689420f h1:oA4XRj0qtSt8Yo1Zms0CUlsT3KG69V2UGQWPBxujDmc=
-golang.org/x/net v0.0.0-20220225172249-27dd8689420f/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk=
-golang.org/x/sys v0.2.0 h1:ljd4t30dBnAvMZaQCevtY0xLLD0A+bRZXbgLMLU1F/A=
-golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/crypto v0.3.0 h1:a06MkbcxBrEFc0w0QIZWXrH/9cCX6KJyWbBOIwAn+7A=
+golang.org/x/crypto v0.3.0/go.mod h1:hebNnKkNXi2UzZN1eVRvBB7co0a+JxK6XbPiWVs/3J4=
+golang.org/x/net v0.2.0 h1:sZfSu1wtKLGlWI4ZZayP0ck9Y73K1ynO6gqzTdBVdPU=
+golang.org/x/net v0.2.0/go.mod h1:KqCZLdyyvdV855qA2rE3GC2aiw5xGR5TEjj8smXukLY=
+golang.org/x/sys v0.5.1-0.20230222185716-a3b23cc77e89 h1:260HNjMTPDya+jq5AM1zZLgG9pv9GASPAGiEEJUbRg4=
+golang.org/x/sys v0.5.1-0.20230222185716-a3b23cc77e89/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/time v0.0.0-20191024005414-555d28b269f0 h1:/5xXl8Y5W96D+TtHSlonuFqGHIWVuyCkGJLwGh9JJFs=
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224 h1:Ug9qvr1myri/zFN6xL17LSCBGFDnphBBhzmILHsM5TY=
diff --git a/tun/checksum.go b/tun/checksum.go
new file mode 100644
index 0000000..f4f8471
--- /dev/null
+++ b/tun/checksum.go
@@ -0,0 +1,42 @@
+package tun
+
+import "encoding/binary"
+
+// TODO: Explore SIMD and/or other assembly optimizations.
+func checksumNoFold(b []byte, initial uint64) uint64 {
+ ac := initial
+ i := 0
+ n := len(b)
+ for n >= 4 {
+ ac += uint64(binary.BigEndian.Uint32(b[i : i+4]))
+ n -= 4
+ i += 4
+ }
+ for n >= 2 {
+ ac += uint64(binary.BigEndian.Uint16(b[i : i+2]))
+ n -= 2
+ i += 2
+ }
+ if n == 1 {
+ ac += uint64(b[i]) << 8
+ }
+ return ac
+}
+
+func checksum(b []byte, initial uint64) uint16 {
+ ac := checksumNoFold(b, initial)
+ ac = (ac >> 16) + (ac & 0xffff)
+ ac = (ac >> 16) + (ac & 0xffff)
+ ac = (ac >> 16) + (ac & 0xffff)
+ ac = (ac >> 16) + (ac & 0xffff)
+ return uint16(ac)
+}
+
+func pseudoHeaderChecksumNoFold(protocol uint8, srcAddr, dstAddr []byte, totalLen uint16) uint64 {
+ sum := checksumNoFold(srcAddr, 0)
+ sum = checksumNoFold(dstAddr, sum)
+ sum = checksumNoFold([]byte{0, protocol}, sum)
+ tmp := make([]byte, 2)
+ binary.BigEndian.PutUint16(tmp, totalLen)
+ return checksumNoFold(tmp, sum)
+}
diff --git a/tun/tcp_offload_linux.go b/tun/tcp_offload_linux.go
new file mode 100644
index 0000000..f3ffa75
--- /dev/null
+++ b/tun/tcp_offload_linux.go
@@ -0,0 +1,612 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package tun
+
+import (
+ "bytes"
+ "encoding/binary"
+ "errors"
+ "io"
+ "unsafe"
+
+ "golang.org/x/sys/unix"
+ "golang.zx2c4.com/wireguard/conn"
+)
+
+const tcpFlagsOffset = 13
+
+const (
+ tcpFlagFIN uint8 = 0x01
+ tcpFlagPSH uint8 = 0x08
+ tcpFlagACK uint8 = 0x10
+)
+
+// virtioNetHdr is defined in the kernel in include/uapi/linux/virtio_net.h. The
+// kernel symbol is virtio_net_hdr.
+type virtioNetHdr struct {
+ flags uint8
+ gsoType uint8
+ hdrLen uint16
+ gsoSize uint16
+ csumStart uint16
+ csumOffset uint16
+}
+
+func (v *virtioNetHdr) decode(b []byte) error {
+ if len(b) < virtioNetHdrLen {
+ return io.ErrShortBuffer
+ }
+ copy(unsafe.Slice((*byte)(unsafe.Pointer(v)), virtioNetHdrLen), b[:virtioNetHdrLen])
+ return nil
+}
+
+func (v *virtioNetHdr) encode(b []byte) error {
+ if len(b) < virtioNetHdrLen {
+ return io.ErrShortBuffer
+ }
+ copy(b[:virtioNetHdrLen], unsafe.Slice((*byte)(unsafe.Pointer(v)), virtioNetHdrLen))
+ return nil
+}
+
+const (
+ // virtioNetHdrLen is the length in bytes of virtioNetHdr. This matches the
+ // shape of the C ABI for its kernel counterpart -- sizeof(virtio_net_hdr).
+ virtioNetHdrLen = int(unsafe.Sizeof(virtioNetHdr{}))
+)
+
+// flowKey represents the key for a flow.
+type flowKey struct {
+ srcAddr, dstAddr [16]byte
+ srcPort, dstPort uint16
+ rxAck uint32 // varying ack values should not be coalesced. Treat them as separate flows.
+}
+
+// tcpGROTable holds flow and coalescing information for the purposes of GRO.
+type tcpGROTable struct {
+ itemsByFlow map[flowKey][]tcpGROItem
+ itemsPool [][]tcpGROItem
+}
+
+func newTCPGROTable() *tcpGROTable {
+ t := &tcpGROTable{
+ itemsByFlow: make(map[flowKey][]tcpGROItem, conn.DefaultBatchSize),
+ itemsPool: make([][]tcpGROItem, conn.DefaultBatchSize),
+ }
+ for i := range t.itemsPool {
+ t.itemsPool[i] = make([]tcpGROItem, 0, conn.DefaultBatchSize)
+ }
+ return t
+}
+
+func newFlowKey(pkt []byte, srcAddr, dstAddr, tcphOffset int) flowKey {
+ key := flowKey{}
+ addrSize := dstAddr - srcAddr
+ copy(key.srcAddr[:], pkt[srcAddr:dstAddr])
+ copy(key.dstAddr[:], pkt[dstAddr:dstAddr+addrSize])
+ key.srcPort = binary.BigEndian.Uint16(pkt[tcphOffset:])
+ key.dstPort = binary.BigEndian.Uint16(pkt[tcphOffset+2:])
+ key.rxAck = binary.BigEndian.Uint32(pkt[tcphOffset+8:])
+ return key
+}
+
+// lookupOrInsert looks up a flow for the provided packet and metadata,
+// returning the packets found for the flow, or inserting a new one if none
+// is found.
+func (t *tcpGROTable) lookupOrInsert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, buffsIndex int) ([]tcpGROItem, bool) {
+ key := newFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset)
+ items, ok := t.itemsByFlow[key]
+ if ok {
+ return items, ok
+ }
+ // TODO: insert() performs another map lookup. This could be rearranged to avoid.
+ t.insert(pkt, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, buffsIndex)
+ return nil, false
+}
+
+// insert an item in the table for the provided packet and packet metadata.
+func (t *tcpGROTable) insert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, buffsIndex int) {
+ key := newFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset)
+ item := tcpGROItem{
+ key: key,
+ buffsIndex: uint16(buffsIndex),
+ gsoSize: uint16(len(pkt[tcphOffset+tcphLen:])),
+ iphLen: uint8(tcphOffset),
+ tcphLen: uint8(tcphLen),
+ sentSeq: binary.BigEndian.Uint32(pkt[tcphOffset+4:]),
+ pshSet: pkt[tcphOffset+tcpFlagsOffset]&tcpFlagPSH != 0,
+ }
+ items, ok := t.itemsByFlow[key]
+ if !ok {
+ items = t.newItems()
+ }
+ items = append(items, item)
+ t.itemsByFlow[key] = items
+}
+
+func (t *tcpGROTable) updateAt(item tcpGROItem, i int) {
+ items, _ := t.itemsByFlow[item.key]
+ items[i] = item
+}
+
+func (t *tcpGROTable) deleteAt(key flowKey, i int) {
+ items, _ := t.itemsByFlow[key]
+ items = append(items[:i], items[i+1:]...)
+ t.itemsByFlow[key] = items
+}
+
+// tcpGROItem represents bookkeeping data for a TCP packet during the lifetime
+// of a GRO evaluation across a vector of packets.
+type tcpGROItem struct {
+ key flowKey
+ sentSeq uint32 // the sequence number
+ buffsIndex uint16 // the index into the original buffs slice
+ numMerged uint16 // the number of packets merged into this item
+ gsoSize uint16 // payload size
+ iphLen uint8 // ip header len
+ tcphLen uint8 // tcp header len
+ pshSet bool // psh flag is set
+}
+
+func (t *tcpGROTable) newItems() []tcpGROItem {
+ var items []tcpGROItem
+ items, t.itemsPool = t.itemsPool[len(t.itemsPool)-1], t.itemsPool[:len(t.itemsPool)-1]
+ return items
+}
+
+func (t *tcpGROTable) reset() {
+ for k, items := range t.itemsByFlow {
+ items = items[:0]
+ t.itemsPool = append(t.itemsPool, items)
+ delete(t.itemsByFlow, k)
+ }
+}
+
+// canCoalesce represents the outcome of checking if two TCP packets are
+// candidates for coalescing.
+type canCoalesce int
+
+const (
+ coalescePrepend canCoalesce = -1
+ coalesceUnavailable canCoalesce = 0
+ coalesceAppend canCoalesce = 1
+)
+
+// tcpPacketsCanCoalesce evaluates if pkt can be coalesced with the packet
+// described by item. This function makes considerations that match the kernel's
+// GRO self tests, which can be found in tools/testing/selftests/net/gro.c.
+func tcpPacketsCanCoalesce(pkt []byte, iphLen, tcphLen uint8, seq uint32, pshSet bool, gsoSize uint16, item tcpGROItem, buffs [][]byte, buffsOffset int) canCoalesce {
+ pktTarget := buffs[item.buffsIndex][buffsOffset:]
+ if tcphLen != item.tcphLen {
+ // cannot coalesce with unequal tcp options len
+ return coalesceUnavailable
+ }
+ if tcphLen > 20 {
+ if !bytes.Equal(pkt[iphLen+20:iphLen+tcphLen], pktTarget[item.iphLen+20:iphLen+tcphLen]) {
+ // cannot coalesce with unequal tcp options
+ return coalesceUnavailable
+ }
+ }
+ if pkt[1] != pktTarget[1] {
+ // cannot coalesce with unequal ToS values
+ return coalesceUnavailable
+ }
+ if pkt[6]>>5 != pktTarget[6]>>5 {
+ // cannot coalesce with unequal DF or reserved bits. MF is checked
+ // further up the stack.
+ return coalesceUnavailable
+ }
+ // seq adjacency
+ lhsLen := item.gsoSize
+ lhsLen += item.numMerged * item.gsoSize
+ if seq == item.sentSeq+uint32(lhsLen) { // pkt aligns following item from a seq num perspective
+ if item.pshSet {
+ // We cannot append to a segment that has the PSH flag set, PSH
+ // can only be set on the final segment in a reassembled group.
+ return coalesceUnavailable
+ }
+ if len(pktTarget[iphLen+tcphLen:])%int(item.gsoSize) != 0 {
+ // A smaller than gsoSize packet has been appended previously.
+ // Nothing can come after a smaller packet on the end.
+ return coalesceUnavailable
+ }
+ if gsoSize > item.gsoSize {
+ // We cannot have a larger packet following a smaller one.
+ return coalesceUnavailable
+ }
+ return coalesceAppend
+ } else if seq+uint32(gsoSize) == item.sentSeq { // pkt aligns in front of item from a seq num perspective
+ if pshSet {
+ // We cannot prepend with a segment that has the PSH flag set, PSH
+ // can only be set on the final segment in a reassembled group.
+ return coalesceUnavailable
+ }
+ if gsoSize < item.gsoSize {
+ // We cannot have a larger packet following a smaller one.
+ return coalesceUnavailable
+ }
+ if gsoSize > item.gsoSize && item.numMerged > 0 {
+ // There's at least one previous merge, and we're larger than all
+ // previous. This would put multiple smaller packets on the end.
+ return coalesceUnavailable
+ }
+ return coalescePrepend
+ }
+ return coalesceUnavailable
+}
+
+func tcpChecksumValid(pkt []byte, iphLen uint8, isV6 bool) bool {
+ srcAddrAt := ipv4SrcAddrOffset
+ addrSize := 4
+ if isV6 {
+ srcAddrAt = ipv6SrcAddrOffset
+ addrSize = 16
+ }
+ tcpTotalLen := uint16(len(pkt) - int(iphLen))
+ tcpCSumNoFold := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, pkt[srcAddrAt:srcAddrAt+addrSize], pkt[srcAddrAt+addrSize:srcAddrAt+addrSize*2], tcpTotalLen)
+ return ^checksum(pkt[iphLen:], tcpCSumNoFold) == 0
+}
+
+// coalesceResult represents the result of attempting to coalesce two TCP
+// packets.
+type coalesceResult int
+
+const (
+ coalesceInsufficientCap coalesceResult = 0
+ coalescePSHEnding coalesceResult = 1
+ coalesceItemInvalidCSum coalesceResult = 2
+ coalescePktInvalidCSum coalesceResult = 3
+ coalesceSuccess coalesceResult = 4
+)
+
+// coalesceTCPPackets attempts to coalesce pkt with the packet described by
+// item, returning the outcome. This function may swap buffs elements in the
+// event of a prepend as item's buffs index is already being tracked for writing
+// to a Device.
+func coalesceTCPPackets(mode canCoalesce, pkt []byte, pktBuffsIndex int, gsoSize uint16, seq uint32, pshSet bool, item *tcpGROItem, buffs [][]byte, buffsOffset int, isV6 bool) coalesceResult {
+ var pktHead []byte // the packet that will end up at the front
+ headersLen := item.iphLen + item.tcphLen
+ coalescedLen := len(buffs[item.buffsIndex][buffsOffset:]) + len(pkt) - int(headersLen)
+
+ // Copy data
+ if mode == coalescePrepend {
+ pktHead = pkt
+ if cap(pkt)-buffsOffset < coalescedLen {
+ // We don't want to allocate a new underlying array if capacity is
+ // too small.
+ return coalesceInsufficientCap
+ }
+ if pshSet {
+ return coalescePSHEnding
+ }
+ if item.numMerged == 0 {
+ if !tcpChecksumValid(buffs[item.buffsIndex][buffsOffset:], item.iphLen, isV6) {
+ return coalesceItemInvalidCSum
+ }
+ }
+ if !tcpChecksumValid(pkt, item.iphLen, isV6) {
+ return coalescePktInvalidCSum
+ }
+ item.sentSeq = seq
+ extendBy := coalescedLen - len(pktHead)
+ buffs[pktBuffsIndex] = append(buffs[pktBuffsIndex], make([]byte, extendBy)...)
+ copy(buffs[pktBuffsIndex][buffsOffset+len(pkt):], buffs[item.buffsIndex][buffsOffset+int(headersLen):])
+ // Flip the slice headers in buffs as part of prepend. The index of item
+ // is already being tracked for writing.
+ buffs[item.buffsIndex], buffs[pktBuffsIndex] = buffs[pktBuffsIndex], buffs[item.buffsIndex]
+ } else {
+ pktHead = buffs[item.buffsIndex][buffsOffset:]
+ if cap(pktHead)-buffsOffset < coalescedLen {
+ // We don't want to allocate a new underlying array if capacity is
+ // too small.
+ return coalesceInsufficientCap
+ }
+ if item.numMerged == 0 {
+ if !tcpChecksumValid(buffs[item.buffsIndex][buffsOffset:], item.iphLen, isV6) {
+ return coalesceItemInvalidCSum
+ }
+ }
+ if !tcpChecksumValid(pkt, item.iphLen, isV6) {
+ return coalescePktInvalidCSum
+ }
+ if pshSet {
+ // We are appending a segment with PSH set.
+ item.pshSet = pshSet
+ pktHead[item.iphLen+tcpFlagsOffset] |= tcpFlagPSH
+ }
+ extendBy := len(pkt) - int(headersLen)
+ buffs[item.buffsIndex] = append(buffs[item.buffsIndex], make([]byte, extendBy)...)
+ copy(buffs[item.buffsIndex][buffsOffset+len(pktHead):], pkt[headersLen:])
+ }
+
+ if gsoSize > item.gsoSize {
+ item.gsoSize = gsoSize
+ }
+ hdr := virtioNetHdr{
+ flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, // this turns into CHECKSUM_PARTIAL in the skb
+ hdrLen: uint16(headersLen),
+ gsoSize: uint16(item.gsoSize),
+ csumStart: uint16(item.iphLen),
+ csumOffset: 16,
+ }
+
+ // Recalculate the total len (IPv4) or payload len (IPv6). Recalculate the
+ // (IPv4) header checksum.
+ if isV6 {
+ hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV6
+ binary.BigEndian.PutUint16(pktHead[4:], uint16(coalescedLen)-uint16(item.iphLen)) // set new payload len
+ } else {
+ hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV4
+ pktHead[10], pktHead[11] = 0, 0 // clear checksum field
+ binary.BigEndian.PutUint16(pktHead[2:], uint16(coalescedLen)) // set new total length
+ iphCSum := ^checksum(pktHead[:item.iphLen], 0) // compute checksum
+ binary.BigEndian.PutUint16(pktHead[10:], iphCSum) // set checksum field
+ }
+ hdr.encode(buffs[item.buffsIndex][buffsOffset-virtioNetHdrLen:])
+
+ // Calculate the pseudo header checksum and place it at the TCP checksum
+ // offset. Downstream checksum offloading will combine this with computation
+ // of the tcp header and payload checksum.
+ addrLen := 4
+ addrOffset := ipv4SrcAddrOffset
+ if isV6 {
+ addrLen = 16
+ addrOffset = ipv6SrcAddrOffset
+ }
+ srcAddrAt := buffsOffset + addrOffset
+ srcAddr := buffs[item.buffsIndex][srcAddrAt : srcAddrAt+addrLen]
+ dstAddr := buffs[item.buffsIndex][srcAddrAt+addrLen : srcAddrAt+addrLen*2]
+ psum := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, srcAddr, dstAddr, uint16(coalescedLen-int(item.iphLen)))
+ binary.BigEndian.PutUint16(pktHead[hdr.csumStart+hdr.csumOffset:], checksum([]byte{}, psum))
+
+ item.numMerged++
+ return coalesceSuccess
+}
+
+const (
+ ipv4FlagMoreFragments = 0x80
+)
+
+const (
+ ipv4SrcAddrOffset = 12
+ ipv6SrcAddrOffset = 8
+ maxUint16 = 1<<16 - 1
+)
+
+// tcpGRO evaluates the TCP packet at pktI in buffs for coalescing with
+// existing packets tracked in table. It will return false when pktI is not
+// coalesced, otherwise true. This indicates to the caller if buffs[pktI]
+// should be written to the Device.
+func tcpGRO(buffs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool) (pktCoalesced bool) {
+ pkt := buffs[pktI][offset:]
+ if len(pkt) > maxUint16 {
+ // A valid IPv4 or IPv6 packet will never exceed this.
+ return false
+ }
+ iphLen := int((pkt[0] & 0x0F) * 4)
+ if isV6 {
+ iphLen = 40
+ ipv6HPayloadLen := int(binary.BigEndian.Uint16(pkt[4:]))
+ if ipv6HPayloadLen != len(pkt)-iphLen {
+ return false
+ }
+ } else {
+ totalLen := int(binary.BigEndian.Uint16(pkt[2:]))
+ if totalLen != len(pkt) {
+ return false
+ }
+ if iphLen < 20 || iphLen > 60 {
+ return false
+ }
+ }
+ if len(pkt) < iphLen {
+ return false
+ }
+ tcphLen := int((pkt[iphLen+12] >> 4) * 4)
+ if tcphLen < 20 || tcphLen > 60 {
+ return false
+ }
+ if len(pkt) < iphLen+tcphLen {
+ return false
+ }
+ if !isV6 {
+ if pkt[6]&ipv4FlagMoreFragments != 0 || (pkt[6]<<3 != 0 || pkt[7] != 0) {
+ // no GRO support for fragmented segments for now
+ return false
+ }
+ }
+ tcpFlags := pkt[iphLen+tcpFlagsOffset]
+ var pshSet bool
+ // not a candidate if any non-ACK flags (except PSH+ACK) are set
+ if tcpFlags != tcpFlagACK {
+ if pkt[iphLen+tcpFlagsOffset] != tcpFlagACK|tcpFlagPSH {
+ return false
+ }
+ pshSet = true
+ }
+ gsoSize := uint16(len(pkt) - tcphLen - iphLen)
+ // not a candidate if payload len is 0
+ if gsoSize < 1 {
+ return false
+ }
+ seq := binary.BigEndian.Uint32(pkt[iphLen+4:])
+ srcAddrOffset := ipv4SrcAddrOffset
+ addrLen := 4
+ if isV6 {
+ srcAddrOffset = ipv6SrcAddrOffset
+ addrLen = 16
+ }
+ items, existing := table.lookupOrInsert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI)
+ if !existing {
+ return false
+ }
+ for i := len(items) - 1; i >= 0; i-- {
+ // In the best case of packets arriving in order iterating in reverse is
+ // more efficient if there are multiple items for a given flow. This
+ // also enables a natural table.deleteAt() in the
+ // coalesceItemInvalidCSum case without the need for index tracking.
+ // This algorithm makes a best effort to coalesce in the event of
+ // unordered packets, where pkt may land anywhere in items from a
+ // sequence number perspective, however once an item is inserted into
+ // the table it is never compared across other items later.
+ item := items[i]
+ can := tcpPacketsCanCoalesce(pkt, uint8(iphLen), uint8(tcphLen), seq, pshSet, gsoSize, item, buffs, offset)
+ if can != coalesceUnavailable {
+ result := coalesceTCPPackets(can, pkt, pktI, gsoSize, seq, pshSet, &item, buffs, offset, isV6)
+ switch result {
+ case coalesceSuccess:
+ table.updateAt(item, i)
+ return true
+ case coalesceItemInvalidCSum:
+ // delete the item with an invalid csum
+ table.deleteAt(item.key, i)
+ case coalescePktInvalidCSum:
+ // no point in inserting an item that we can't coalesce
+ return false
+ default:
+ }
+ }
+ }
+ // failed to coalesce with any other packets; store the item in the flow
+ table.insert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI)
+ return false
+}
+
+func isTCP4(b []byte) bool {
+ if len(b) < 40 {
+ return false
+ }
+ if b[0]>>4 != 4 {
+ return false
+ }
+ if b[9] != unix.IPPROTO_TCP {
+ return false
+ }
+ return true
+}
+
+func isTCP6NoEH(b []byte) bool {
+ if len(b) < 60 {
+ return false
+ }
+ if b[0]>>4 != 6 {
+ return false
+ }
+ if b[6] != unix.IPPROTO_TCP {
+ return false
+ }
+ return true
+}
+
+// handleGRO evaluates buffs for GRO, and writes the indices of the resulting
+// packets into toWrite. toWrite, tcp4Table, and tcp6Table should initially be
+// empty (but non-nil), and are passed in to save allocs as the caller may reset
+// and recycle them across vectors of packets.
+func handleGRO(buffs [][]byte, offset int, tcp4Table, tcp6Table *tcpGROTable, toWrite *[]int) error {
+ for i := range buffs {
+ if offset < virtioNetHdrLen || offset > len(buffs[i])-1 {
+ return errors.New("invalid offset")
+ }
+ var coalesced bool
+ switch {
+ case isTCP4(buffs[i][offset:]):
+ coalesced = tcpGRO(buffs, offset, i, tcp4Table, false)
+ case isTCP6NoEH(buffs[i][offset:]): // ipv6 packets w/extension headers do not coalesce
+ coalesced = tcpGRO(buffs, offset, i, tcp6Table, true)
+ }
+ if !coalesced {
+ hdr := virtioNetHdr{}
+ err := hdr.encode(buffs[i][offset-virtioNetHdrLen:])
+ if err != nil {
+ return err
+ }
+ *toWrite = append(*toWrite, i)
+ }
+ }
+ return nil
+}
+
+// tcpTSO splits packets from in into outBuffs, writing the size of each
+// element into sizes. It returns the number of buffers populated, and/or an
+// error.
+func tcpTSO(in []byte, hdr virtioNetHdr, outBuffs [][]byte, sizes []int, outOffset int) (int, error) {
+ iphLen := int(hdr.csumStart)
+ srcAddrOffset := ipv6SrcAddrOffset
+ addrLen := 16
+ if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_TCPV4 {
+ in[10], in[11] = 0, 0 // clear ipv4 header checksum
+ srcAddrOffset = ipv4SrcAddrOffset
+ addrLen = 4
+ }
+ tcpCSumAt := int(hdr.csumStart + hdr.csumOffset)
+ in[tcpCSumAt], in[tcpCSumAt+1] = 0, 0 // clear tcp checksum
+ firstTCPSeqNum := binary.BigEndian.Uint32(in[hdr.csumStart+4:])
+ nextSegmentDataAt := int(hdr.hdrLen)
+ i := 0
+ for ; nextSegmentDataAt < len(in); i++ {
+ if i == len(outBuffs) {
+ return i - 1, ErrTooManySegments
+ }
+ nextSegmentEnd := nextSegmentDataAt + int(hdr.gsoSize)
+ if nextSegmentEnd > len(in) {
+ nextSegmentEnd = len(in)
+ }
+ segmentDataLen := nextSegmentEnd - nextSegmentDataAt
+ totalLen := int(hdr.hdrLen) + segmentDataLen
+ sizes[i] = totalLen
+ out := outBuffs[i][outOffset:]
+
+ copy(out, in[:iphLen])
+ if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_TCPV4 {
+ // For IPv4 we are responsible for incrementing the ID field,
+ // updating the total len field, and recalculating the header
+ // checksum.
+ if i > 0 {
+ id := binary.BigEndian.Uint16(out[4:])
+ id += uint16(i)
+ binary.BigEndian.PutUint16(out[4:], id)
+ }
+ binary.BigEndian.PutUint16(out[2:], uint16(totalLen))
+ ipv4CSum := ^checksum(out[:iphLen], 0)
+ binary.BigEndian.PutUint16(out[10:], ipv4CSum)
+ } else {
+ // For IPv6 we are responsible for updating the payload length field.
+ binary.BigEndian.PutUint16(out[4:], uint16(totalLen-iphLen))
+ }
+
+ // TCP header
+ copy(out[hdr.csumStart:hdr.hdrLen], in[hdr.csumStart:hdr.hdrLen])
+ tcpSeq := firstTCPSeqNum + uint32(hdr.gsoSize*uint16(i))
+ binary.BigEndian.PutUint32(out[hdr.csumStart+4:], tcpSeq)
+ if nextSegmentEnd != len(in) {
+ // FIN and PSH should only be set on last segment
+ clearFlags := tcpFlagFIN | tcpFlagPSH
+ out[hdr.csumStart+tcpFlagsOffset] &^= clearFlags
+ }
+
+ // payload
+ copy(out[hdr.hdrLen:], in[nextSegmentDataAt:nextSegmentEnd])
+
+ // TCP checksum
+ tcpHLen := int(hdr.hdrLen - hdr.csumStart)
+ tcpLenForPseudo := uint16(tcpHLen + segmentDataLen)
+ tcpCSumNoFold := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, in[srcAddrOffset:srcAddrOffset+addrLen], in[srcAddrOffset+addrLen:srcAddrOffset+addrLen*2], tcpLenForPseudo)
+ tcpCSum := ^checksum(out[hdr.csumStart:totalLen], tcpCSumNoFold)
+ binary.BigEndian.PutUint16(out[hdr.csumStart+hdr.csumOffset:], tcpCSum)
+
+ nextSegmentDataAt += int(hdr.gsoSize)
+ }
+ return i, nil
+}
+
+func gsoNoneChecksum(in []byte, cSumStart, cSumOffset uint16) error {
+ cSumAt := cSumStart + cSumOffset
+ // The initial value at the checksum offset should be summed with the
+ // checksum we compute. This is typically the pseudo-header checksum.
+ initial := binary.BigEndian.Uint16(in[cSumAt:])
+ in[cSumAt], in[cSumAt+1] = 0, 0
+ binary.BigEndian.PutUint16(in[cSumAt:], ^checksum(in[cSumStart:], uint64(initial)))
+ return nil
+}
diff --git a/tun/tcp_offload_linux_test.go b/tun/tcp_offload_linux_test.go
new file mode 100644
index 0000000..7fa0777
--- /dev/null
+++ b/tun/tcp_offload_linux_test.go
@@ -0,0 +1,273 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package tun
+
+import (
+ "net/netip"
+ "testing"
+
+ "golang.org/x/sys/unix"
+ "golang.zx2c4.com/wireguard/conn"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+)
+
+const (
+ offset = virtioNetHdrLen
+)
+
+var (
+ ip4PortA = netip.MustParseAddrPort("192.0.2.1:1")
+ ip4PortB = netip.MustParseAddrPort("192.0.2.2:1")
+ ip4PortC = netip.MustParseAddrPort("192.0.2.3:1")
+ ip6PortA = netip.MustParseAddrPort("[2001:db8::1]:1")
+ ip6PortB = netip.MustParseAddrPort("[2001:db8::2]:1")
+ ip6PortC = netip.MustParseAddrPort("[2001:db8::3]:1")
+)
+
+func tcp4Packet(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32) []byte {
+ totalLen := 40 + segmentSize
+ b := make([]byte, offset+int(totalLen), 65535)
+ ipv4H := header.IPv4(b[offset:])
+ srcAs4 := srcIPPort.Addr().As4()
+ dstAs4 := dstIPPort.Addr().As4()
+ ipv4H.Encode(&header.IPv4Fields{
+ SrcAddr: tcpip.Address(srcAs4[:]),
+ DstAddr: tcpip.Address(dstAs4[:]),
+ Protocol: unix.IPPROTO_TCP,
+ TTL: 64,
+ TotalLength: uint16(totalLen),
+ })
+ tcpH := header.TCP(b[offset+20:])
+ tcpH.Encode(&header.TCPFields{
+ SrcPort: srcIPPort.Port(),
+ DstPort: dstIPPort.Port(),
+ SeqNum: seq,
+ AckNum: 1,
+ DataOffset: 20,
+ Flags: flags,
+ WindowSize: 3000,
+ })
+ ipv4H.SetChecksum(^ipv4H.CalculateChecksum())
+ pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_TCP, ipv4H.SourceAddress(), ipv4H.DestinationAddress(), uint16(20+segmentSize))
+ tcpH.SetChecksum(^tcpH.CalculateChecksum(pseudoCsum))
+ return b
+}
+
+func tcp6Packet(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32) []byte {
+ totalLen := 60 + segmentSize
+ b := make([]byte, offset+int(totalLen), 65535)
+ ipv6H := header.IPv6(b[offset:])
+ srcAs16 := srcIPPort.Addr().As16()
+ dstAs16 := dstIPPort.Addr().As16()
+ ipv6H.Encode(&header.IPv6Fields{
+ SrcAddr: tcpip.Address(srcAs16[:]),
+ DstAddr: tcpip.Address(dstAs16[:]),
+ TransportProtocol: unix.IPPROTO_TCP,
+ HopLimit: 64,
+ PayloadLength: uint16(segmentSize + 20),
+ })
+ tcpH := header.TCP(b[offset+40:])
+ tcpH.Encode(&header.TCPFields{
+ SrcPort: srcIPPort.Port(),
+ DstPort: dstIPPort.Port(),
+ SeqNum: seq,
+ AckNum: 1,
+ DataOffset: 20,
+ Flags: flags,
+ WindowSize: 3000,
+ })
+ pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_TCP, ipv6H.SourceAddress(), ipv6H.DestinationAddress(), uint16(20+segmentSize))
+ tcpH.SetChecksum(^tcpH.CalculateChecksum(pseudoCsum))
+ return b
+}
+
+func Test_handleVirtioRead(t *testing.T) {
+ tests := []struct {
+ name string
+ hdr virtioNetHdr
+ pktIn []byte
+ wantLens []int
+ wantErr bool
+ }{
+ {
+ "tcp4",
+ virtioNetHdr{
+ flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
+ gsoType: unix.VIRTIO_NET_HDR_GSO_TCPV4,
+ gsoSize: 100,
+ hdrLen: 40,
+ csumStart: 20,
+ csumOffset: 16,
+ },
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck|header.TCPFlagPsh, 200, 1),
+ []int{140, 140},
+ false,
+ },
+ {
+ "tcp6",
+ virtioNetHdr{
+ flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
+ gsoType: unix.VIRTIO_NET_HDR_GSO_TCPV6,
+ gsoSize: 100,
+ hdrLen: 60,
+ csumStart: 40,
+ csumOffset: 16,
+ },
+ tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck|header.TCPFlagPsh, 200, 1),
+ []int{160, 160},
+ false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ out := make([][]byte, conn.DefaultBatchSize)
+ sizes := make([]int, conn.DefaultBatchSize)
+ for i := range out {
+ out[i] = make([]byte, 65535)
+ }
+ tt.hdr.encode(tt.pktIn)
+ n, err := handleVirtioRead(tt.pktIn, out, sizes, offset)
+ if err != nil {
+ if tt.wantErr {
+ return
+ }
+ t.Fatalf("got err: %v", err)
+ }
+ if n != len(tt.wantLens) {
+ t.Fatalf("got %d packets, wanted %d", n, len(tt.wantLens))
+ }
+ for i := range tt.wantLens {
+ if tt.wantLens[i] != sizes[i] {
+ t.Fatalf("wantLens[%d]: %d != outSizes: %d", i, tt.wantLens[i], sizes[i])
+ }
+ }
+ })
+ }
+}
+
+func flipTCP4Checksum(b []byte) []byte {
+ at := virtioNetHdrLen + 20 + 16 // 20 byte ipv4 header; tcp csum offset is 16
+ b[at] ^= 0xFF
+ b[at+1] ^= 0xFF
+ return b
+}
+
+func Fuzz_handleGRO(f *testing.F) {
+ pkt0 := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1)
+ pkt1 := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101)
+ pkt2 := tcp4Packet(ip4PortA, ip4PortC, header.TCPFlagAck, 100, 201)
+ pkt3 := tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1)
+ pkt4 := tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101)
+ pkt5 := tcp6Packet(ip6PortA, ip6PortC, header.TCPFlagAck, 100, 201)
+ f.Add(pkt0, pkt1, pkt2, pkt3, pkt4, pkt5, offset)
+ f.Fuzz(func(t *testing.T, pkt0, pkt1, pkt2, pkt3, pkt4, pkt5 []byte, offset int) {
+ pkts := [][]byte{pkt0, pkt1, pkt2, pkt3, pkt4, pkt5}
+ toWrite := make([]int, 0, len(pkts))
+ handleGRO(pkts, offset, newTCPGROTable(), newTCPGROTable(), &toWrite)
+ if len(toWrite) > len(pkts) {
+ t.Errorf("len(toWrite): %d > len(pkts): %d", len(toWrite), len(pkts))
+ }
+ seenWriteI := make(map[int]bool)
+ for _, writeI := range toWrite {
+ if writeI < 0 || writeI > len(pkts)-1 {
+ t.Errorf("toWrite value (%d) outside bounds of len(pkts): %d", writeI, len(pkts))
+ }
+ if seenWriteI[writeI] {
+ t.Errorf("duplicate toWrite value: %d", writeI)
+ }
+ seenWriteI[writeI] = true
+ }
+ })
+}
+
+func Test_handleGRO(t *testing.T) {
+ tests := []struct {
+ name string
+ pktsIn [][]byte
+ wantToWrite []int
+ wantLens []int
+ wantErr bool
+ }{
+ {
+ "multiple flows",
+ [][]byte{
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // v4 flow 1
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // v4 flow 1
+ tcp4Packet(ip4PortA, ip4PortC, header.TCPFlagAck, 100, 201), // v4 flow 2
+ tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), // v6 flow 1
+ tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101), // v6 flow 1
+ tcp6Packet(ip6PortA, ip6PortC, header.TCPFlagAck, 100, 201), // v6 flow 2
+ },
+ []int{0, 2, 3, 5},
+ []int{240, 140, 260, 160},
+ false,
+ },
+ {
+ "PSH interleaved",
+ [][]byte{
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // v4 flow 1
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck|header.TCPFlagPsh, 100, 101), // v4 flow 1
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 301), // v4 flow 1
+ tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), // v6 flow 1
+ tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck|header.TCPFlagPsh, 100, 101), // v6 flow 1
+ tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 201), // v6 flow 1
+ tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 301), // v6 flow 1
+ },
+ []int{0, 2, 4, 6},
+ []int{240, 240, 260, 260},
+ false,
+ },
+ {
+ "coalesceItemInvalidCSum",
+ [][]byte{
+ flipTCP4Checksum(tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1)), // v4 flow 1 seq 1 len 100
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // v4 flow 1 seq 101 len 100
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1 seq 201 len 100
+ },
+ []int{0, 1},
+ []int{140, 240},
+ false,
+ },
+ {
+ "out of order",
+ [][]byte{
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // v4 flow 1 seq 101 len 100
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // v4 flow 1 seq 1 len 100
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1 seq 201 len 100
+ },
+ []int{0},
+ []int{340},
+ false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ toWrite := make([]int, 0, len(tt.pktsIn))
+ err := handleGRO(tt.pktsIn, offset, newTCPGROTable(), newTCPGROTable(), &toWrite)
+ if err != nil {
+ if tt.wantErr {
+ return
+ }
+ t.Fatalf("got err: %v", err)
+ }
+ if len(toWrite) != len(tt.wantToWrite) {
+ t.Fatalf("got %d packets, wanted %d", len(toWrite), len(tt.wantToWrite))
+ }
+ for i, pktI := range tt.wantToWrite {
+ if tt.wantToWrite[i] != toWrite[i] {
+ t.Fatalf("wantToWrite[%d]: %d != toWrite: %d", i, tt.wantToWrite[i], toWrite[i])
+ }
+ if tt.wantLens[i] != len(tt.pktsIn[pktI][offset:]) {
+ t.Errorf("wanted len %d packet at %d, got: %d", tt.wantLens[i], i, len(tt.pktsIn[pktI][offset:]))
+ }
+ }
+ })
+ }
+}
diff --git a/tun/testdata/fuzz/Fuzz_handleGRO/032aec0105f26f709c118365e4830d6dc087cab24cd1e154c2e790589a309b77 b/tun/testdata/fuzz/Fuzz_handleGRO/032aec0105f26f709c118365e4830d6dc087cab24cd1e154c2e790589a309b77
new file mode 100644
index 0000000..5461e79
--- /dev/null
+++ b/tun/testdata/fuzz/Fuzz_handleGRO/032aec0105f26f709c118365e4830d6dc087cab24cd1e154c2e790589a309b77
@@ -0,0 +1,8 @@
+go test fuzz v1
+[]byte("0")
+[]byte("0")
+[]byte("0")
+[]byte("0")
+[]byte("0")
+[]byte("0")
+int(34)
diff --git a/tun/testdata/fuzz/Fuzz_handleGRO/0da283f9a2098dec30d1c86784411a8ce2e8e03aa3384105e581f2c67494700d b/tun/testdata/fuzz/Fuzz_handleGRO/0da283f9a2098dec30d1c86784411a8ce2e8e03aa3384105e581f2c67494700d
new file mode 100644
index 0000000..b441819
--- /dev/null
+++ b/tun/testdata/fuzz/Fuzz_handleGRO/0da283f9a2098dec30d1c86784411a8ce2e8e03aa3384105e581f2c67494700d
@@ -0,0 +1,8 @@
+go test fuzz v1
+[]byte("0")
+[]byte("0")
+[]byte("0")
+[]byte("0")
+[]byte("0")
+[]byte("0")
+int(-48)
diff --git a/tun/tun_linux.go b/tun/tun_linux.go
index 21984ca..d56e3c1 100644
--- a/tun/tun_linux.go
+++ b/tun/tun_linux.go
@@ -17,9 +17,8 @@ import (
"time"
"unsafe"
- "golang.org/x/net/ipv6"
"golang.org/x/sys/unix"
-
+ "golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/rwcancel"
)
@@ -33,17 +32,25 @@ type NativeTun struct {
index int32 // if index
errors chan error // async error handling
events chan Event // device related events
- nopi bool // the device was passed IFF_NO_PI
netlinkSock int
netlinkCancel *rwcancel.RWCancel
hackListenerClosed sync.Mutex
statusListenersShutdown chan struct{}
+ batchSize int
+ vnetHdr bool
closeOnce sync.Once
nameOnce sync.Once // guards calling initNameCache, which sets following fields
nameCache string // name of interface
nameErr error
+
+ readOpMu sync.Mutex // readOpMu guards readBuff
+ readBuff [virtioNetHdrLen + 65535]byte // if vnetHdr every read() is prefixed by virtioNetHdr
+
+ writeOpMu sync.Mutex // writeOpMu guards toWrite, tcp4GROTable, tcp6GROTable
+ toWrite []int
+ tcp4GROTable, tcp6GROTable *tcpGROTable
}
func (tun *NativeTun) File() *os.File {
@@ -323,60 +330,142 @@ func (tun *NativeTun) nameSlow() (string, error) {
return unix.ByteSliceToString(ifr[:]), nil
}
-func (tun *NativeTun) Write(buffs [][]byte, offset int) (n int, err error) {
- var buf []byte
- if tun.nopi {
- buf = buffs[0][offset:]
+func (tun *NativeTun) Write(buffs [][]byte, offset int) (int, error) {
+ tun.writeOpMu.Lock()
+ defer func() {
+ tun.tcp4GROTable.reset()
+ tun.tcp6GROTable.reset()
+ tun.writeOpMu.Unlock()
+ }()
+ var (
+ errs []error
+ total int
+ )
+ tun.toWrite = tun.toWrite[:0]
+ if tun.vnetHdr {
+ err := handleGRO(buffs, offset, tun.tcp4GROTable, tun.tcp6GROTable, &tun.toWrite)
+ if err != nil {
+ return 0, err
+ }
+ offset -= virtioNetHdrLen
} else {
- // reserve space for header
- buf = buffs[0][offset-4:]
-
- // add packet information header
- buf[0] = 0x00
- buf[1] = 0x00
- if buf[4]>>4 == ipv6.Version {
- buf[2] = 0x86
- buf[3] = 0xdd
+ for i := range buffs {
+ tun.toWrite = append(tun.toWrite, i)
+ }
+ }
+ for _, buffsI := range tun.toWrite {
+ n, err := tun.tunFile.Write(buffs[buffsI][offset:])
+ if errors.Is(err, syscall.EBADFD) {
+ return total, os.ErrClosed
+ }
+ if err != nil {
+ errs = append(errs, err)
} else {
- buf[2] = 0x08
- buf[3] = 0x00
+ total += n
+ }
+ }
+ return total, ErrorBatch(errs)
+}
+
+// handleVirtioRead splits in into buffs, leaving offset bytes at the front of
+// each buffer. It mutates sizes to reflect the size of each element of buffs,
+// and returns the number of packets read.
+func handleVirtioRead(in []byte, buffs [][]byte, sizes []int, offset int) (int, error) {
+ var hdr virtioNetHdr
+ err := hdr.decode(in)
+ if err != nil {
+ return 0, err
+ }
+ in = in[virtioNetHdrLen:]
+ if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_NONE {
+ if hdr.flags&unix.VIRTIO_NET_HDR_F_NEEDS_CSUM != 0 {
+ // This means CHECKSUM_PARTIAL in skb context. We are responsible
+ // for computing the checksum starting at hdr.csumStart and placing
+ // at hdr.csumOffset.
+ err = gsoNoneChecksum(in, hdr.csumStart, hdr.csumOffset)
+ if err != nil {
+ return 0, err
+ }
+ }
+ if len(in) > len(buffs[0][offset:]) {
+ return 0, fmt.Errorf("read len %d overflows buffs element len %d", len(in), len(buffs[0][offset:]))
}
+ n := copy(buffs[0][offset:], in)
+ sizes[0] = n
+ return 1, nil
+ }
+ if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 {
+ return 0, fmt.Errorf("unsupported virtio GSO type: %d", hdr.gsoType)
+ }
+
+ ipVersion := in[0] >> 4
+ switch ipVersion {
+ case 4:
+ if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 {
+ return 0, fmt.Errorf("ip header version: %d, GSO type: %d", ipVersion, hdr.gsoType)
+ }
+ case 6:
+ if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 {
+ return 0, fmt.Errorf("ip header version: %d, GSO type: %d", ipVersion, hdr.gsoType)
+ }
+ default:
+ return 0, fmt.Errorf("invalid ip header version: %d", ipVersion)
+ }
+
+ if len(in) <= int(hdr.csumStart+12) {
+ return 0, errors.New("packet is too short")
+ }
+ // Don't trust hdr.hdrLen from the kernel as it can be equal to the length
+ // of the entire first packet when the kernel is handling it as part of a
+ // FORWARD path. Instead, parse the TCP header length and add it onto
+ // csumStart, which is synonymous for IP header length.
+ tcpHLen := uint16(in[hdr.csumStart+12] >> 4 * 4)
+ if tcpHLen < 20 || tcpHLen > 60 {
+ // A TCP header must be between 20 and 60 bytes in length.
+ return 0, fmt.Errorf("tcp header len is invalid: %d", tcpHLen)
+ }
+ hdr.hdrLen = hdr.csumStart + tcpHLen
+
+ if len(in) < int(hdr.hdrLen) {
+ return 0, fmt.Errorf("length of packet (%d) < virtioNetHdr.hdrLen (%d)", len(in), hdr.hdrLen)
}
- _, err = tun.tunFile.Write(buf)
- if errors.Is(err, syscall.EBADFD) {
- err = os.ErrClosed
- } else if err == nil {
- n = 1
+ if hdr.hdrLen < hdr.csumStart {
+ return 0, fmt.Errorf("virtioNetHdr.hdrLen (%d) < virtioNetHdr.csumStart (%d)", hdr.hdrLen, hdr.csumStart)
}
- return n, err
+ cSumAt := int(hdr.csumStart + hdr.csumOffset)
+ if cSumAt+1 >= len(in) {
+ return 0, fmt.Errorf("end of checksum offset (%d) exceeds packet length (%d)", cSumAt+1, len(in))
+ }
+
+ return tcpTSO(in, hdr, buffs, sizes, offset)
}
-func (tun *NativeTun) Read(buffs [][]byte, sizes []int, offset int) (n int, err error) {
+func (tun *NativeTun) Read(buffs [][]byte, sizes []int, offset int) (int, error) {
+ tun.readOpMu.Lock()
+ defer tun.readOpMu.Unlock()
select {
- case err = <-tun.errors:
+ case err := <-tun.errors:
+ return 0, err
default:
- if tun.nopi {
- sizes[0], err = tun.tunFile.Read(buffs[0][offset:])
- if err == nil {
- n = 1
- }
+ readInto := buffs[0][offset:]
+ if tun.vnetHdr {
+ readInto = tun.readBuff[:]
+ }
+ n, err := tun.tunFile.Read(readInto)
+ if errors.Is(err, syscall.EBADFD) {
+ err = os.ErrClosed
+ }
+ if err != nil {
+ return 0, err
+ }
+ if tun.vnetHdr {
+ return handleVirtioRead(readInto[:n], buffs, sizes, offset)
} else {
- buff := buffs[0][offset-4:]
- sizes[0], err = tun.tunFile.Read(buff[:])
- if errors.Is(err, syscall.EBADFD) {
- err = os.ErrClosed
- } else if err == nil {
- n = 1
- }
- if sizes[0] < 4 {
- sizes[0] = 0
- } else {
- sizes[0] -= 4
- }
+ sizes[0] = n
+ return 1, nil
}
}
- return
}
func (tun *NativeTun) Events() <-chan Event {
@@ -403,9 +492,49 @@ func (tun *NativeTun) Close() error {
}
func (tun *NativeTun) BatchSize() int {
- return 1
+ return tun.batchSize
}
+const (
+ // TODO: support TSO with ECN bits
+ tunOffloads = unix.TUN_F_CSUM | unix.TUN_F_TSO4 | unix.TUN_F_TSO6
+)
+
+func (tun *NativeTun) initFromFlags(name string) error {
+ sc, err := tun.tunFile.SyscallConn()
+ if err != nil {
+ return err
+ }
+ if e := sc.Control(func(fd uintptr) {
+ var (
+ ifr *unix.Ifreq
+ )
+ ifr, err = unix.NewIfreq(name)
+ if err != nil {
+ return
+ }
+ err = unix.IoctlIfreq(int(fd), unix.TUNGETIFF, ifr)
+ if err != nil {
+ return
+ }
+ got := ifr.Uint16()
+ if got&unix.IFF_VNET_HDR != 0 {
+ err = unix.IoctlSetInt(int(fd), unix.TUNSETOFFLOAD, tunOffloads)
+ if err != nil {
+ return
+ }
+ tun.vnetHdr = true
+ tun.batchSize = conn.DefaultBatchSize
+ } else {
+ tun.batchSize = 1
+ }
+ }); e != nil {
+ return e
+ }
+ return err
+}
+
+// CreateTUN creates a Device with the provided name and MTU.
func CreateTUN(name string, mtu int) (Device, error) {
nfd, err := unix.Open(cloneDevicePath, unix.O_RDWR|unix.O_CLOEXEC, 0)
if err != nil {
@@ -415,25 +544,16 @@ func CreateTUN(name string, mtu int) (Device, error) {
return nil, err
}
- var ifr [ifReqSize]byte
- var flags uint16 = unix.IFF_TUN // | unix.IFF_NO_PI (disabled for TUN status hack)
- nameBytes := []byte(name)
- if len(nameBytes) >= unix.IFNAMSIZ {
- unix.Close(nfd)
- return nil, fmt.Errorf("interface name too long: %w", unix.ENAMETOOLONG)
+ ifr, err := unix.NewIfreq(name)
+ if err != nil {
+ return nil, err
}
- copy(ifr[:], nameBytes)
- *(*uint16)(unsafe.Pointer(&ifr[unix.IFNAMSIZ])) = flags
-
- _, _, errno := unix.Syscall(
- unix.SYS_IOCTL,
- uintptr(nfd),
- uintptr(unix.TUNSETIFF),
- uintptr(unsafe.Pointer(&ifr[0])),
- )
- if errno != 0 {
- unix.Close(nfd)
- return nil, errno
+ // IFF_VNET_HDR enables the "tun status hack" via routineHackListener()
+ // where a null write will return EINVAL indicating the TUN is up.
+ ifr.SetUint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_VNET_HDR)
+ err = unix.IoctlIfreq(nfd, unix.TUNSETIFF, ifr)
+ if err != nil {
+ return nil, err
}
err = unix.SetNonblock(nfd, true)
@@ -448,13 +568,16 @@ func CreateTUN(name string, mtu int) (Device, error) {
return CreateTUNFromFile(fd, mtu)
}
+// CreateTUNFromFile creates a Device from an os.File with the provided MTU.
func CreateTUNFromFile(file *os.File, mtu int) (Device, error) {
tun := &NativeTun{
tunFile: file,
events: make(chan Event, 5),
errors: make(chan error, 5),
statusListenersShutdown: make(chan struct{}),
- nopi: false,
+ tcp4GROTable: newTCPGROTable(),
+ tcp6GROTable: newTCPGROTable(),
+ toWrite: make([]int, 0, conn.DefaultBatchSize),
}
name, err := tun.Name()
@@ -462,8 +585,12 @@ func CreateTUNFromFile(file *os.File, mtu int) (Device, error) {
return nil, err
}
- // start event listener
+ err = tun.initFromFlags(name)
+ if err != nil {
+ return nil, err
+ }
+ // start event listener
tun.index, err = getIFIndex(name)
if err != nil {
return nil, err
@@ -492,6 +619,8 @@ func CreateTUNFromFile(file *os.File, mtu int) (Device, error) {
return tun, nil
}
+// CreateUnmonitoredTUNFromFD creates a Device from the provided file
+// descriptor.
func CreateUnmonitoredTUNFromFD(fd int) (Device, string, error) {
err := unix.SetNonblock(fd, true)
if err != nil {
@@ -499,14 +628,20 @@ func CreateUnmonitoredTUNFromFD(fd int) (Device, string, error) {
}
file := os.NewFile(uintptr(fd), "/dev/tun")
tun := &NativeTun{
- tunFile: file,
- events: make(chan Event, 5),
- errors: make(chan error, 5),
- nopi: true,
+ tunFile: file,
+ events: make(chan Event, 5),
+ errors: make(chan error, 5),
+ tcp4GROTable: newTCPGROTable(),
+ tcp6GROTable: newTCPGROTable(),
+ toWrite: make([]int, 0, conn.DefaultBatchSize),
}
name, err := tun.Name()
if err != nil {
return nil, "", err
}
- return tun, name, nil
+ err = tun.initFromFlags(name)
+ if err != nil {
+ return nil, "", err
+ }
+ return tun, name, err
}