diff options
author | Jordan Whited <jordan@tailscale.com> | 2023-03-02 15:08:28 -0800 |
---|---|---|
committer | Jason A. Donenfeld <Jason@zx2c4.com> | 2023-03-10 14:52:17 +0100 |
commit | 9e2f3860220280a5630971478b53c8ad9a991ca8 (patch) | |
tree | 218f1bd9a8dd649a8fdb50571a921d1ccff4cae5 /conn/bind_std.go | |
parent | 3bb8fec7e41fcc2138ddb4cba3f46100814fc523 (diff) |
conn, device, tun: implement vectorized I/O on Linux
Implement TCP offloading via TSO and GRO for the Linux tun.Device, which
is made possible by virtio extensions in the kernel's TUN driver.
Delete conn.LinuxSocketEndpoint in favor of a collapsed conn.StdNetBind.
conn.StdNetBind makes use of recvmmsg() and sendmmsg() on Linux. All
platforms now fall under conn.StdNetBind, except for Windows, which
remains in conn.WinRingBind, which still needs to be adjusted to handle
multiple packets.
Also refactor sticky sockets support to eventually be applicable on
platforms other than just Linux. However Linux remains the sole platform
that fully implements it for now.
Co-authored-by: James Tucker <james@tailscale.com>
Signed-off-by: James Tucker <james@tailscale.com>
Signed-off-by: Jordan Whited <jordan@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
Diffstat (limited to 'conn/bind_std.go')
-rw-r--r-- | conn/bind_std.go | 339 |
1 files changed, 239 insertions, 100 deletions
diff --git a/conn/bind_std.go b/conn/bind_std.go index 98fe23c..a164f56 100644 --- a/conn/bind_std.go +++ b/conn/bind_std.go @@ -6,32 +6,91 @@ package conn import ( + "context" "errors" "net" "net/netip" + "strconv" "sync" "syscall" + + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" +) + +var ( + _ Bind = (*StdNetBind)(nil) ) -// StdNetBind is meant to be a temporary solution on platforms for which -// the sticky socket / source caching behavior has not yet been implemented. -// It uses the Go's net package to implement networking. -// See LinuxSocketBind for a proper implementation on the Linux platform. +// StdNetBind implements Bind for all platforms except Windows. type StdNetBind struct { - mu sync.Mutex // protects following fields - ipv4 *net.UDPConn - ipv6 *net.UDPConn - blackhole4 bool - blackhole6 bool + mu sync.Mutex // protects following fields + ipv4 *net.UDPConn + ipv6 *net.UDPConn + blackhole4 bool + blackhole6 bool + ipv4PC *ipv4.PacketConn + ipv6PC *ipv6.PacketConn + batchSize int + udpAddrPool sync.Pool + ipv4MsgsPool sync.Pool + ipv6MsgsPool sync.Pool } -func NewStdNetBind() Bind { return &StdNetBind{} } +func NewStdNetBind() Bind { return NewStdNetBindBatch(DefaultBatchSize) } + +func NewStdNetBindBatch(maxBatchSize int) Bind { + if maxBatchSize == 0 { + maxBatchSize = DefaultBatchSize + } + return &StdNetBind{ + batchSize: maxBatchSize, + + udpAddrPool: sync.Pool{ + New: func() any { + return &net.UDPAddr{ + IP: make([]byte, 16), + } + }, + }, -type StdNetEndpoint netip.AddrPort + ipv4MsgsPool: sync.Pool{ + New: func() any { + msgs := make([]ipv4.Message, maxBatchSize) + for i := range msgs { + msgs[i].Buffers = make(net.Buffers, 1) + msgs[i].OOB = make([]byte, srcControlSize) + } + return &msgs + }, + }, + + ipv6MsgsPool: sync.Pool{ + New: func() any { + msgs := make([]ipv6.Message, maxBatchSize) + for i := range msgs { + msgs[i].Buffers = make(net.Buffers, 1) + msgs[i].OOB = make([]byte, srcControlSize) + } + return &msgs + }, + }, + } +} + +type StdNetEndpoint struct { + // AddrPort is the endpoint destination. + netip.AddrPort + // src is the current sticky source address and interface index, if supported. + src struct { + netip.Addr + ifidx int32 + } +} var ( _ Bind = (*StdNetBind)(nil) - _ Endpoint = StdNetEndpoint{} + _ Endpoint = &StdNetEndpoint{} ) func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) { @@ -39,31 +98,38 @@ func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) { return asEndpoint(e), err } -func (StdNetEndpoint) ClearSrc() {} +func (e *StdNetEndpoint) ClearSrc() { + e.src.ifidx = 0 + e.src.Addr = netip.Addr{} +} + +func (e *StdNetEndpoint) DstIP() netip.Addr { + return e.AddrPort.Addr() +} -func (e StdNetEndpoint) DstIP() netip.Addr { - return (netip.AddrPort)(e).Addr() +func (e *StdNetEndpoint) SrcIP() netip.Addr { + return e.src.Addr } -func (e StdNetEndpoint) SrcIP() netip.Addr { - return netip.Addr{} // not supported +func (e *StdNetEndpoint) SrcIfidx() int32 { + return e.src.ifidx } -func (e StdNetEndpoint) DstToBytes() []byte { - b, _ := (netip.AddrPort)(e).MarshalBinary() +func (e *StdNetEndpoint) DstToBytes() []byte { + b, _ := e.AddrPort.MarshalBinary() return b } -func (e StdNetEndpoint) DstToString() string { - return (netip.AddrPort)(e).String() +func (e *StdNetEndpoint) DstToString() string { + return e.AddrPort.String() } -func (e StdNetEndpoint) SrcToString() string { - return "" +func (e *StdNetEndpoint) SrcToString() string { + return e.src.Addr.String() } func listenNet(network string, port int) (*net.UDPConn, int, error) { - conn, err := net.ListenUDP(network, &net.UDPAddr{Port: port}) + conn, err := listenConfig().ListenPacket(context.Background(), network, ":"+strconv.Itoa(port)) if err != nil { return nil, 0, err } @@ -77,17 +143,17 @@ func listenNet(network string, port int) (*net.UDPConn, int, error) { if err != nil { return nil, 0, err } - return conn, uaddr.Port, nil + return conn.(*net.UDPConn), uaddr.Port, nil } -func (bind *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) { - bind.mu.Lock() - defer bind.mu.Unlock() +func (s *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) { + s.mu.Lock() + defer s.mu.Unlock() var err error var tries int - if bind.ipv4 != nil || bind.ipv6 != nil { + if s.ipv4 != nil || s.ipv6 != nil { return nil, 0, ErrBindAlreadyOpen } @@ -95,104 +161,121 @@ func (bind *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) { // If uport is 0, we can retry on failure. again: port := int(uport) - var ipv4, ipv6 *net.UDPConn + var v4conn, v6conn *net.UDPConn - ipv4, port, err = listenNet("udp4", port) + v4conn, port, err = listenNet("udp4", port) if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) { return nil, 0, err } // Listen on the same port as we're using for ipv4. - ipv6, port, err = listenNet("udp6", port) + v6conn, port, err = listenNet("udp6", port) if uport == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 { - ipv4.Close() + v4conn.Close() tries++ goto again } if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) { - ipv4.Close() + v4conn.Close() return nil, 0, err } var fns []ReceiveFunc - if ipv4 != nil { - fns = append(fns, bind.makeReceiveIPv4(ipv4)) - bind.ipv4 = ipv4 + if v4conn != nil { + fns = append(fns, s.receiveIPv4) + s.ipv4 = v4conn } - if ipv6 != nil { - fns = append(fns, bind.makeReceiveIPv6(ipv6)) - bind.ipv6 = ipv6 + if v6conn != nil { + fns = append(fns, s.receiveIPv6) + s.ipv6 = v6conn } if len(fns) == 0 { return nil, 0, syscall.EAFNOSUPPORT } - return fns, uint16(port), nil -} -func (bind *StdNetBind) BatchSize() int { - return 1 -} + s.ipv4PC = ipv4.NewPacketConn(s.ipv4) + s.ipv6PC = ipv6.NewPacketConn(s.ipv6) -func (bind *StdNetBind) Close() error { - bind.mu.Lock() - defer bind.mu.Unlock() + return fns, uint16(port), nil +} - var err1, err2 error - if bind.ipv4 != nil { - err1 = bind.ipv4.Close() - bind.ipv4 = nil +func (s *StdNetBind) receiveIPv4(buffs [][]byte, sizes []int, eps []Endpoint) (n int, err error) { + msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message) + defer s.ipv4MsgsPool.Put(msgs) + for i := range buffs { + (*msgs)[i].Buffers[0] = buffs[i] } - if bind.ipv6 != nil { - err2 = bind.ipv6.Close() - bind.ipv6 = nil + numMsgs, err := s.ipv4PC.ReadBatch(*msgs, 0) + if err != nil { + return 0, err } - bind.blackhole4 = false - bind.blackhole6 = false - if err1 != nil { - return err1 + for i := 0; i < numMsgs; i++ { + msg := &(*msgs)[i] + sizes[i] = msg.N + addrPort := msg.Addr.(*net.UDPAddr).AddrPort() + ep := asEndpoint(addrPort) + getSrcFromControl(msg.OOB, ep) + eps[i] = ep } - return err2 + return numMsgs, nil } -func (*StdNetBind) makeReceiveIPv4(conn *net.UDPConn) ReceiveFunc { - return func(buffs [][]byte, sizes []int, eps []Endpoint) (n int, err error) { - size, endpoint, err := conn.ReadFromUDPAddrPort(buffs[0]) - if err == nil { - sizes[0] = size - eps[0] = asEndpoint(endpoint) - return 1, nil - } +func (s *StdNetBind) receiveIPv6(buffs [][]byte, sizes []int, eps []Endpoint) (n int, err error) { + msgs := s.ipv6MsgsPool.Get().(*[]ipv6.Message) + defer s.ipv6MsgsPool.Put(msgs) + for i := range buffs { + (*msgs)[i].Buffers[0] = buffs[i] + } + numMsgs, err := s.ipv6PC.ReadBatch(*msgs, 0) + if err != nil { return 0, err } + for i := 0; i < numMsgs; i++ { + msg := &(*msgs)[i] + sizes[i] = msg.N + addrPort := msg.Addr.(*net.UDPAddr).AddrPort() + ep := asEndpoint(addrPort) + getSrcFromControl(msg.OOB, ep) + eps[i] = ep + } + return numMsgs, nil } -func (*StdNetBind) makeReceiveIPv6(conn *net.UDPConn) ReceiveFunc { - return func(buffs [][]byte, sizes []int, eps []Endpoint) (n int, err error) { - size, endpoint, err := conn.ReadFromUDPAddrPort(buffs[0]) - if err == nil { - sizes[0] = size - eps[0] = asEndpoint(endpoint) - return 1, nil - } - return 0, err - } +func (s *StdNetBind) BatchSize() int { + return s.batchSize } -func (bind *StdNetBind) Send(buffs [][]byte, endpoint Endpoint) error { - var err error - nend, ok := endpoint.(StdNetEndpoint) - if !ok { - return ErrWrongEndpointType +func (s *StdNetBind) Close() error { + s.mu.Lock() + defer s.mu.Unlock() + + var err1, err2 error + if s.ipv4 != nil { + err1 = s.ipv4.Close() + s.ipv4 = nil + } + if s.ipv6 != nil { + err2 = s.ipv6.Close() + s.ipv6 = nil + } + s.blackhole4 = false + s.blackhole6 = false + if err1 != nil { + return err1 } - addrPort := netip.AddrPort(nend) + return err2 +} - bind.mu.Lock() - blackhole := bind.blackhole4 - conn := bind.ipv4 - if addrPort.Addr().Is6() { - blackhole = bind.blackhole6 - conn = bind.ipv6 +func (s *StdNetBind) Send(buffs [][]byte, endpoint Endpoint) error { + s.mu.Lock() + blackhole := s.blackhole4 + conn := s.ipv4 + is6 := false + if endpoint.DstIP().Is6() { + blackhole = s.blackhole6 + conn = s.ipv6 + is6 = true } - bind.mu.Unlock() + s.mu.Unlock() if blackhole { return nil @@ -200,13 +283,69 @@ func (bind *StdNetBind) Send(buffs [][]byte, endpoint Endpoint) error { if conn == nil { return syscall.EAFNOSUPPORT } - for _, buff := range buffs { - _, err = conn.WriteToUDPAddrPort(buff, addrPort) - if err != nil { - return err + if is6 { + return s.send6(s.ipv6PC, endpoint, buffs) + } else { + return s.send4(s.ipv4PC, endpoint, buffs) + } +} + +func (s *StdNetBind) send4(conn *ipv4.PacketConn, ep Endpoint, buffs [][]byte) error { + ua := s.udpAddrPool.Get().(*net.UDPAddr) + as4 := ep.DstIP().As4() + copy(ua.IP, as4[:]) + ua.IP = ua.IP[:4] + ua.Port = int(ep.(*StdNetEndpoint).Port()) + msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message) + for i, buff := range buffs { + (*msgs)[i].Buffers[0] = buff + (*msgs)[i].Addr = ua + setSrcControl(&(*msgs)[i].OOB, ep.(*StdNetEndpoint)) + } + var ( + n int + err error + start int + ) + for { + n, err = conn.WriteBatch((*msgs)[start:len(buffs)], 0) + if err != nil || n == len((*msgs)[start:len(buffs)]) { + break + } + start += n + } + s.udpAddrPool.Put(ua) + s.ipv4MsgsPool.Put(msgs) + return err +} + +func (s *StdNetBind) send6(conn *ipv6.PacketConn, ep Endpoint, buffs [][]byte) error { + ua := s.udpAddrPool.Get().(*net.UDPAddr) + as16 := ep.DstIP().As16() + copy(ua.IP, as16[:]) + ua.IP = ua.IP[:16] + ua.Port = int(ep.(*StdNetEndpoint).Port()) + msgs := s.ipv6MsgsPool.Get().(*[]ipv6.Message) + for i, buff := range buffs { + (*msgs)[i].Buffers[0] = buff + (*msgs)[i].Addr = ua + setSrcControl(&(*msgs)[i].OOB, ep.(*StdNetEndpoint)) + } + var ( + n int + err error + start int + ) + for { + n, err = conn.WriteBatch((*msgs)[start:len(buffs)], 0) + if err != nil || n == len((*msgs)[start:len(buffs)]) { + break } + start += n } - return nil + s.udpAddrPool.Put(ua) + s.ipv6MsgsPool.Put(msgs) + return err } // endpointPool contains a re-usable set of mapping from netip.AddrPort to Endpoint. @@ -214,17 +353,17 @@ func (bind *StdNetBind) Send(buffs [][]byte, endpoint Endpoint) error { // but Endpoints are immutable, so we can re-use them. var endpointPool = sync.Pool{ New: func() any { - return make(map[netip.AddrPort]Endpoint) + return make(map[netip.AddrPort]*StdNetEndpoint) }, } // asEndpoint returns an Endpoint containing ap. -func asEndpoint(ap netip.AddrPort) Endpoint { - m := endpointPool.Get().(map[netip.AddrPort]Endpoint) +func asEndpoint(ap netip.AddrPort) *StdNetEndpoint { + m := endpointPool.Get().(map[netip.AddrPort]*StdNetEndpoint) defer endpointPool.Put(m) e, ok := m[ap] if !ok { - e = Endpoint(StdNetEndpoint(ap)) + e = &StdNetEndpoint{AddrPort: ap} m[ap] = e } return e |