diff options
Diffstat (limited to 'conn')
-rw-r--r-- | conn/bind_std.go | 192 | ||||
-rw-r--r-- | conn/bind_std_test.go | 22 |
2 files changed, 150 insertions, 64 deletions
diff --git a/conn/bind_std.go b/conn/bind_std.go index b9da4c3..a842b12 100644 --- a/conn/bind_std.go +++ b/conn/bind_std.go @@ -10,6 +10,7 @@ import ( "errors" "net" "net/netip" + "runtime" "strconv" "sync" "syscall" @@ -22,16 +23,21 @@ var ( _ Bind = (*StdNetBind)(nil) ) -// StdNetBind implements Bind for all platforms except Windows. +// StdNetBind implements Bind for all platforms. While Windows has its own Bind +// (see bind_windows.go), it may fall back to StdNetBind. +// TODO: Remove usage of ipv{4,6}.PacketConn when net.UDPConn has comparable +// methods for sending and receiving multiple datagrams per-syscall. See the +// proposal in https://github.com/golang/go/issues/45886#issuecomment-1218301564. type StdNetBind struct { - mu sync.Mutex // protects following fields - ipv4 *net.UDPConn - ipv6 *net.UDPConn - blackhole4 bool - blackhole6 bool - ipv4PC *ipv4.PacketConn - ipv6PC *ipv6.PacketConn - udpAddrPool sync.Pool + mu sync.Mutex // protects following fields + ipv4 *net.UDPConn + ipv6 *net.UDPConn + blackhole4 bool + blackhole6 bool + ipv4PC *ipv4.PacketConn // will be nil on non-Linux + ipv6PC *ipv6.PacketConn // will be nil on non-Linux + + udpAddrPool sync.Pool // following fields are not guarded by mu ipv4MsgsPool sync.Pool ipv6MsgsPool sync.Pool } @@ -154,6 +160,8 @@ func (s *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) { again: port := int(uport) var v4conn, v6conn *net.UDPConn + var v4pc *ipv4.PacketConn + var v6pc *ipv6.PacketConn v4conn, port, err = listenNet("udp4", port) if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) { @@ -173,63 +181,92 @@ again: } var fns []ReceiveFunc if v4conn != nil { - fns = append(fns, s.receiveIPv4) + if runtime.GOOS == "linux" { + v4pc = ipv4.NewPacketConn(v4conn) + s.ipv4PC = v4pc + } + fns = append(fns, s.makeReceiveIPv4(v4pc, v4conn)) s.ipv4 = v4conn } if v6conn != nil { - fns = append(fns, s.receiveIPv6) + if runtime.GOOS == "linux" { + v6pc = ipv6.NewPacketConn(v6conn) + s.ipv6PC = v6pc + } + fns = append(fns, s.makeReceiveIPv6(v6pc, v6conn)) s.ipv6 = v6conn } if len(fns) == 0 { return nil, 0, syscall.EAFNOSUPPORT } - s.ipv4PC = ipv4.NewPacketConn(s.ipv4) - s.ipv6PC = ipv6.NewPacketConn(s.ipv6) - return fns, uint16(port), nil } -func (s *StdNetBind) receiveIPv4(buffs [][]byte, sizes []int, eps []Endpoint) (n int, err error) { - msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message) - defer s.ipv4MsgsPool.Put(msgs) - for i := range buffs { - (*msgs)[i].Buffers[0] = buffs[i] - } - numMsgs, err := s.ipv4PC.ReadBatch(*msgs, 0) - if err != nil { - return 0, err - } - for i := 0; i < numMsgs; i++ { - msg := &(*msgs)[i] - sizes[i] = msg.N - addrPort := msg.Addr.(*net.UDPAddr).AddrPort() - ep := asEndpoint(addrPort) - getSrcFromControl(msg.OOB, ep) - eps[i] = ep +func (s *StdNetBind) makeReceiveIPv4(pc *ipv4.PacketConn, conn *net.UDPConn) ReceiveFunc { + return func(buffs [][]byte, sizes []int, eps []Endpoint) (n int, err error) { + msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message) + defer s.ipv4MsgsPool.Put(msgs) + for i := range buffs { + (*msgs)[i].Buffers[0] = buffs[i] + } + var numMsgs int + if runtime.GOOS == "linux" { + numMsgs, err = pc.ReadBatch(*msgs, 0) + if err != nil { + return 0, err + } + } else { + msg := &(*msgs)[0] + msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB) + if err != nil { + return 0, err + } + numMsgs = 1 + } + for i := 0; i < numMsgs; i++ { + msg := &(*msgs)[i] + sizes[i] = msg.N + addrPort := msg.Addr.(*net.UDPAddr).AddrPort() + ep := asEndpoint(addrPort) + getSrcFromControl(msg.OOB, ep) + eps[i] = ep + } + return numMsgs, nil } - return numMsgs, nil } -func (s *StdNetBind) receiveIPv6(buffs [][]byte, sizes []int, eps []Endpoint) (n int, err error) { - msgs := s.ipv6MsgsPool.Get().(*[]ipv6.Message) - defer s.ipv6MsgsPool.Put(msgs) - for i := range buffs { - (*msgs)[i].Buffers[0] = buffs[i] - } - numMsgs, err := s.ipv6PC.ReadBatch(*msgs, 0) - if err != nil { - return 0, err - } - for i := 0; i < numMsgs; i++ { - msg := &(*msgs)[i] - sizes[i] = msg.N - addrPort := msg.Addr.(*net.UDPAddr).AddrPort() - ep := asEndpoint(addrPort) - getSrcFromControl(msg.OOB, ep) - eps[i] = ep +func (s *StdNetBind) makeReceiveIPv6(pc *ipv6.PacketConn, conn *net.UDPConn) ReceiveFunc { + return func(buffs [][]byte, sizes []int, eps []Endpoint) (n int, err error) { + msgs := s.ipv4MsgsPool.Get().(*[]ipv6.Message) + defer s.ipv4MsgsPool.Put(msgs) + for i := range buffs { + (*msgs)[i].Buffers[0] = buffs[i] + } + var numMsgs int + if runtime.GOOS == "linux" { + numMsgs, err = pc.ReadBatch(*msgs, 0) + if err != nil { + return 0, err + } + } else { + msg := &(*msgs)[0] + msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB) + if err != nil { + return 0, err + } + numMsgs = 1 + } + for i := 0; i < numMsgs; i++ { + msg := &(*msgs)[i] + sizes[i] = msg.N + addrPort := msg.Addr.(*net.UDPAddr).AddrPort() + ep := asEndpoint(addrPort) + getSrcFromControl(msg.OOB, ep) + eps[i] = ep + } + return numMsgs, nil } - return numMsgs, nil } // TODO: When all Binds handle IdealBatchSize, remove this dynamic function and @@ -246,10 +283,12 @@ func (s *StdNetBind) Close() error { if s.ipv4 != nil { err1 = s.ipv4.Close() s.ipv4 = nil + s.ipv4PC = nil } if s.ipv6 != nil { err2 = s.ipv6.Close() s.ipv6 = nil + s.ipv6PC = nil } s.blackhole4 = false s.blackhole6 = false @@ -263,11 +302,18 @@ func (s *StdNetBind) Send(buffs [][]byte, endpoint Endpoint) error { s.mu.Lock() blackhole := s.blackhole4 conn := s.ipv4 + var ( + pc4 *ipv4.PacketConn + pc6 *ipv6.PacketConn + ) is6 := false if endpoint.DstIP().Is6() { blackhole = s.blackhole6 conn = s.ipv6 + pc6 = s.ipv6PC is6 = true + } else { + pc4 = s.ipv4PC } s.mu.Unlock() @@ -278,13 +324,13 @@ func (s *StdNetBind) Send(buffs [][]byte, endpoint Endpoint) error { return syscall.EAFNOSUPPORT } if is6 { - return s.send6(s.ipv6PC, endpoint, buffs) + return s.send6(conn, pc6, endpoint, buffs) } else { - return s.send4(s.ipv4PC, endpoint, buffs) + return s.send4(conn, pc4, endpoint, buffs) } } -func (s *StdNetBind) send4(conn *ipv4.PacketConn, ep Endpoint, buffs [][]byte) error { +func (s *StdNetBind) send4(conn *net.UDPConn, pc *ipv4.PacketConn, ep Endpoint, buffs [][]byte) error { ua := s.udpAddrPool.Get().(*net.UDPAddr) as4 := ep.DstIP().As4() copy(ua.IP, as4[:]) @@ -301,19 +347,28 @@ func (s *StdNetBind) send4(conn *ipv4.PacketConn, ep Endpoint, buffs [][]byte) e err error start int ) - for { - n, err = conn.WriteBatch((*msgs)[start:len(buffs)], 0) - if err != nil || n == len((*msgs)[start:len(buffs)]) { - break + if runtime.GOOS == "linux" { + for { + n, err = pc.WriteBatch((*msgs)[start:len(buffs)], 0) + if err != nil || n == len((*msgs)[start:len(buffs)]) { + break + } + start += n + } + } else { + for i, buff := range buffs { + _, _, err = conn.WriteMsgUDP(buff, (*msgs)[i].OOB, ua) + if err != nil { + break + } } - start += n } s.udpAddrPool.Put(ua) s.ipv4MsgsPool.Put(msgs) return err } -func (s *StdNetBind) send6(conn *ipv6.PacketConn, ep Endpoint, buffs [][]byte) error { +func (s *StdNetBind) send6(conn *net.UDPConn, pc *ipv6.PacketConn, ep Endpoint, buffs [][]byte) error { ua := s.udpAddrPool.Get().(*net.UDPAddr) as16 := ep.DstIP().As16() copy(ua.IP, as16[:]) @@ -330,12 +385,21 @@ func (s *StdNetBind) send6(conn *ipv6.PacketConn, ep Endpoint, buffs [][]byte) e err error start int ) - for { - n, err = conn.WriteBatch((*msgs)[start:len(buffs)], 0) - if err != nil || n == len((*msgs)[start:len(buffs)]) { - break + if runtime.GOOS == "linux" { + for { + n, err = pc.WriteBatch((*msgs)[start:len(buffs)], 0) + if err != nil || n == len((*msgs)[start:len(buffs)]) { + break + } + start += n + } + } else { + for i, buff := range buffs { + _, _, err = conn.WriteMsgUDP(buff, (*msgs)[i].OOB, ua) + if err != nil { + break + } } - start += n } s.udpAddrPool.Put(ua) s.ipv6MsgsPool.Put(msgs) diff --git a/conn/bind_std_test.go b/conn/bind_std_test.go new file mode 100644 index 0000000..76afa30 --- /dev/null +++ b/conn/bind_std_test.go @@ -0,0 +1,22 @@ +package conn + +import "testing" + +func TestStdNetBindReceiveFuncAfterClose(t *testing.T) { + bind := NewStdNetBind().(*StdNetBind) + fns, _, err := bind.Open(0) + if err != nil { + t.Fatal(err) + } + bind.Close() + buffs := make([][]byte, 1) + buffs[0] = make([]byte, 1) + sizes := make([]int, 1) + eps := make([]Endpoint, 1) + for _, fn := range fns { + // The ReceiveFuncs must not access conn-related fields on StdNetBind + // unguarded. Close() nils the conn-related fields resulting in a panic + // if they violate the mutex. + fn(buffs, sizes, eps) + } +} |