diff options
Diffstat (limited to 'conn/bind_std.go')
-rw-r--r-- | conn/bind_std.go | 339 |
1 files changed, 239 insertions, 100 deletions
diff --git a/conn/bind_std.go b/conn/bind_std.go index 98fe23c..a164f56 100644 --- a/conn/bind_std.go +++ b/conn/bind_std.go @@ -6,32 +6,91 @@ package conn import ( + "context" "errors" "net" "net/netip" + "strconv" "sync" "syscall" + + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" +) + +var ( + _ Bind = (*StdNetBind)(nil) ) -// StdNetBind is meant to be a temporary solution on platforms for which -// the sticky socket / source caching behavior has not yet been implemented. -// It uses the Go's net package to implement networking. -// See LinuxSocketBind for a proper implementation on the Linux platform. +// StdNetBind implements Bind for all platforms except Windows. type StdNetBind struct { - mu sync.Mutex // protects following fields - ipv4 *net.UDPConn - ipv6 *net.UDPConn - blackhole4 bool - blackhole6 bool + mu sync.Mutex // protects following fields + ipv4 *net.UDPConn + ipv6 *net.UDPConn + blackhole4 bool + blackhole6 bool + ipv4PC *ipv4.PacketConn + ipv6PC *ipv6.PacketConn + batchSize int + udpAddrPool sync.Pool + ipv4MsgsPool sync.Pool + ipv6MsgsPool sync.Pool } -func NewStdNetBind() Bind { return &StdNetBind{} } +func NewStdNetBind() Bind { return NewStdNetBindBatch(DefaultBatchSize) } + +func NewStdNetBindBatch(maxBatchSize int) Bind { + if maxBatchSize == 0 { + maxBatchSize = DefaultBatchSize + } + return &StdNetBind{ + batchSize: maxBatchSize, + + udpAddrPool: sync.Pool{ + New: func() any { + return &net.UDPAddr{ + IP: make([]byte, 16), + } + }, + }, -type StdNetEndpoint netip.AddrPort + ipv4MsgsPool: sync.Pool{ + New: func() any { + msgs := make([]ipv4.Message, maxBatchSize) + for i := range msgs { + msgs[i].Buffers = make(net.Buffers, 1) + msgs[i].OOB = make([]byte, srcControlSize) + } + return &msgs + }, + }, + + ipv6MsgsPool: sync.Pool{ + New: func() any { + msgs := make([]ipv6.Message, maxBatchSize) + for i := range msgs { + msgs[i].Buffers = make(net.Buffers, 1) + msgs[i].OOB = make([]byte, srcControlSize) + } + return &msgs + }, + }, + } +} + +type StdNetEndpoint struct { + // AddrPort is the endpoint destination. + netip.AddrPort + // src is the current sticky source address and interface index, if supported. + src struct { + netip.Addr + ifidx int32 + } +} var ( _ Bind = (*StdNetBind)(nil) - _ Endpoint = StdNetEndpoint{} + _ Endpoint = &StdNetEndpoint{} ) func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) { @@ -39,31 +98,38 @@ func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) { return asEndpoint(e), err } -func (StdNetEndpoint) ClearSrc() {} +func (e *StdNetEndpoint) ClearSrc() { + e.src.ifidx = 0 + e.src.Addr = netip.Addr{} +} + +func (e *StdNetEndpoint) DstIP() netip.Addr { + return e.AddrPort.Addr() +} -func (e StdNetEndpoint) DstIP() netip.Addr { - return (netip.AddrPort)(e).Addr() +func (e *StdNetEndpoint) SrcIP() netip.Addr { + return e.src.Addr } -func (e StdNetEndpoint) SrcIP() netip.Addr { - return netip.Addr{} // not supported +func (e *StdNetEndpoint) SrcIfidx() int32 { + return e.src.ifidx } -func (e StdNetEndpoint) DstToBytes() []byte { - b, _ := (netip.AddrPort)(e).MarshalBinary() +func (e *StdNetEndpoint) DstToBytes() []byte { + b, _ := e.AddrPort.MarshalBinary() return b } -func (e StdNetEndpoint) DstToString() string { - return (netip.AddrPort)(e).String() +func (e *StdNetEndpoint) DstToString() string { + return e.AddrPort.String() } -func (e StdNetEndpoint) SrcToString() string { - return "" +func (e *StdNetEndpoint) SrcToString() string { + return e.src.Addr.String() } func listenNet(network string, port int) (*net.UDPConn, int, error) { - conn, err := net.ListenUDP(network, &net.UDPAddr{Port: port}) + conn, err := listenConfig().ListenPacket(context.Background(), network, ":"+strconv.Itoa(port)) if err != nil { return nil, 0, err } @@ -77,17 +143,17 @@ func listenNet(network string, port int) (*net.UDPConn, int, error) { if err != nil { return nil, 0, err } - return conn, uaddr.Port, nil + return conn.(*net.UDPConn), uaddr.Port, nil } -func (bind *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) { - bind.mu.Lock() - defer bind.mu.Unlock() +func (s *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) { + s.mu.Lock() + defer s.mu.Unlock() var err error var tries int - if bind.ipv4 != nil || bind.ipv6 != nil { + if s.ipv4 != nil || s.ipv6 != nil { return nil, 0, ErrBindAlreadyOpen } @@ -95,104 +161,121 @@ func (bind *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) { // If uport is 0, we can retry on failure. again: port := int(uport) - var ipv4, ipv6 *net.UDPConn + var v4conn, v6conn *net.UDPConn - ipv4, port, err = listenNet("udp4", port) + v4conn, port, err = listenNet("udp4", port) if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) { return nil, 0, err } // Listen on the same port as we're using for ipv4. - ipv6, port, err = listenNet("udp6", port) + v6conn, port, err = listenNet("udp6", port) if uport == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 { - ipv4.Close() + v4conn.Close() tries++ goto again } if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) { - ipv4.Close() + v4conn.Close() return nil, 0, err } var fns []ReceiveFunc - if ipv4 != nil { - fns = append(fns, bind.makeReceiveIPv4(ipv4)) - bind.ipv4 = ipv4 + if v4conn != nil { + fns = append(fns, s.receiveIPv4) + s.ipv4 = v4conn } - if ipv6 != nil { - fns = append(fns, bind.makeReceiveIPv6(ipv6)) - bind.ipv6 = ipv6 + if v6conn != nil { + fns = append(fns, s.receiveIPv6) + s.ipv6 = v6conn } if len(fns) == 0 { return nil, 0, syscall.EAFNOSUPPORT } - return fns, uint16(port), nil -} -func (bind *StdNetBind) BatchSize() int { - return 1 -} + s.ipv4PC = ipv4.NewPacketConn(s.ipv4) + s.ipv6PC = ipv6.NewPacketConn(s.ipv6) -func (bind *StdNetBind) Close() error { - bind.mu.Lock() - defer bind.mu.Unlock() + return fns, uint16(port), nil +} - var err1, err2 error - if bind.ipv4 != nil { - err1 = bind.ipv4.Close() - bind.ipv4 = nil +func (s *StdNetBind) receiveIPv4(buffs [][]byte, sizes []int, eps []Endpoint) (n int, err error) { + msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message) + defer s.ipv4MsgsPool.Put(msgs) + for i := range buffs { + (*msgs)[i].Buffers[0] = buffs[i] } - if bind.ipv6 != nil { - err2 = bind.ipv6.Close() - bind.ipv6 = nil + numMsgs, err := s.ipv4PC.ReadBatch(*msgs, 0) + if err != nil { + return 0, err } - bind.blackhole4 = false - bind.blackhole6 = false - if err1 != nil { - return err1 + for i := 0; i < numMsgs; i++ { + msg := &(*msgs)[i] + sizes[i] = msg.N + addrPort := msg.Addr.(*net.UDPAddr).AddrPort() + ep := asEndpoint(addrPort) + getSrcFromControl(msg.OOB, ep) + eps[i] = ep } - return err2 + return numMsgs, nil } -func (*StdNetBind) makeReceiveIPv4(conn *net.UDPConn) ReceiveFunc { - return func(buffs [][]byte, sizes []int, eps []Endpoint) (n int, err error) { - size, endpoint, err := conn.ReadFromUDPAddrPort(buffs[0]) - if err == nil { - sizes[0] = size - eps[0] = asEndpoint(endpoint) - return 1, nil - } +func (s *StdNetBind) receiveIPv6(buffs [][]byte, sizes []int, eps []Endpoint) (n int, err error) { + msgs := s.ipv6MsgsPool.Get().(*[]ipv6.Message) + defer s.ipv6MsgsPool.Put(msgs) + for i := range buffs { + (*msgs)[i].Buffers[0] = buffs[i] + } + numMsgs, err := s.ipv6PC.ReadBatch(*msgs, 0) + if err != nil { return 0, err } + for i := 0; i < numMsgs; i++ { + msg := &(*msgs)[i] + sizes[i] = msg.N + addrPort := msg.Addr.(*net.UDPAddr).AddrPort() + ep := asEndpoint(addrPort) + getSrcFromControl(msg.OOB, ep) + eps[i] = ep + } + return numMsgs, nil } -func (*StdNetBind) makeReceiveIPv6(conn *net.UDPConn) ReceiveFunc { - return func(buffs [][]byte, sizes []int, eps []Endpoint) (n int, err error) { - size, endpoint, err := conn.ReadFromUDPAddrPort(buffs[0]) - if err == nil { - sizes[0] = size - eps[0] = asEndpoint(endpoint) - return 1, nil - } - return 0, err - } +func (s *StdNetBind) BatchSize() int { + return s.batchSize } -func (bind *StdNetBind) Send(buffs [][]byte, endpoint Endpoint) error { - var err error - nend, ok := endpoint.(StdNetEndpoint) - if !ok { - return ErrWrongEndpointType +func (s *StdNetBind) Close() error { + s.mu.Lock() + defer s.mu.Unlock() + + var err1, err2 error + if s.ipv4 != nil { + err1 = s.ipv4.Close() + s.ipv4 = nil + } + if s.ipv6 != nil { + err2 = s.ipv6.Close() + s.ipv6 = nil + } + s.blackhole4 = false + s.blackhole6 = false + if err1 != nil { + return err1 } - addrPort := netip.AddrPort(nend) + return err2 +} - bind.mu.Lock() - blackhole := bind.blackhole4 - conn := bind.ipv4 - if addrPort.Addr().Is6() { - blackhole = bind.blackhole6 - conn = bind.ipv6 +func (s *StdNetBind) Send(buffs [][]byte, endpoint Endpoint) error { + s.mu.Lock() + blackhole := s.blackhole4 + conn := s.ipv4 + is6 := false + if endpoint.DstIP().Is6() { + blackhole = s.blackhole6 + conn = s.ipv6 + is6 = true } - bind.mu.Unlock() + s.mu.Unlock() if blackhole { return nil @@ -200,13 +283,69 @@ func (bind *StdNetBind) Send(buffs [][]byte, endpoint Endpoint) error { if conn == nil { return syscall.EAFNOSUPPORT } - for _, buff := range buffs { - _, err = conn.WriteToUDPAddrPort(buff, addrPort) - if err != nil { - return err + if is6 { + return s.send6(s.ipv6PC, endpoint, buffs) + } else { + return s.send4(s.ipv4PC, endpoint, buffs) + } +} + +func (s *StdNetBind) send4(conn *ipv4.PacketConn, ep Endpoint, buffs [][]byte) error { + ua := s.udpAddrPool.Get().(*net.UDPAddr) + as4 := ep.DstIP().As4() + copy(ua.IP, as4[:]) + ua.IP = ua.IP[:4] + ua.Port = int(ep.(*StdNetEndpoint).Port()) + msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message) + for i, buff := range buffs { + (*msgs)[i].Buffers[0] = buff + (*msgs)[i].Addr = ua + setSrcControl(&(*msgs)[i].OOB, ep.(*StdNetEndpoint)) + } + var ( + n int + err error + start int + ) + for { + n, err = conn.WriteBatch((*msgs)[start:len(buffs)], 0) + if err != nil || n == len((*msgs)[start:len(buffs)]) { + break + } + start += n + } + s.udpAddrPool.Put(ua) + s.ipv4MsgsPool.Put(msgs) + return err +} + +func (s *StdNetBind) send6(conn *ipv6.PacketConn, ep Endpoint, buffs [][]byte) error { + ua := s.udpAddrPool.Get().(*net.UDPAddr) + as16 := ep.DstIP().As16() + copy(ua.IP, as16[:]) + ua.IP = ua.IP[:16] + ua.Port = int(ep.(*StdNetEndpoint).Port()) + msgs := s.ipv6MsgsPool.Get().(*[]ipv6.Message) + for i, buff := range buffs { + (*msgs)[i].Buffers[0] = buff + (*msgs)[i].Addr = ua + setSrcControl(&(*msgs)[i].OOB, ep.(*StdNetEndpoint)) + } + var ( + n int + err error + start int + ) + for { + n, err = conn.WriteBatch((*msgs)[start:len(buffs)], 0) + if err != nil || n == len((*msgs)[start:len(buffs)]) { + break } + start += n } - return nil + s.udpAddrPool.Put(ua) + s.ipv6MsgsPool.Put(msgs) + return err } // endpointPool contains a re-usable set of mapping from netip.AddrPort to Endpoint. @@ -214,17 +353,17 @@ func (bind *StdNetBind) Send(buffs [][]byte, endpoint Endpoint) error { // but Endpoints are immutable, so we can re-use them. var endpointPool = sync.Pool{ New: func() any { - return make(map[netip.AddrPort]Endpoint) + return make(map[netip.AddrPort]*StdNetEndpoint) }, } // asEndpoint returns an Endpoint containing ap. -func asEndpoint(ap netip.AddrPort) Endpoint { - m := endpointPool.Get().(map[netip.AddrPort]Endpoint) +func asEndpoint(ap netip.AddrPort) *StdNetEndpoint { + m := endpointPool.Get().(map[netip.AddrPort]*StdNetEndpoint) defer endpointPool.Put(m) e, ok := m[ap] if !ok { - e = Endpoint(StdNetEndpoint(ap)) + e = &StdNetEndpoint{AddrPort: ap} m[ap] = e } return e |