diff options
Diffstat (limited to 'server/server.go')
-rw-r--r-- | server/server.go | 94 |
1 files changed, 58 insertions, 36 deletions
diff --git a/server/server.go b/server/server.go index 846931c4..09d3fe31 100644 --- a/server/server.go +++ b/server/server.go @@ -87,6 +87,59 @@ func (ws Watchers) watching(typ watcherEventType) bool { return false } +type TCPListener struct { + l *net.TCPListener + ch chan struct{} +} + +func (l *TCPListener) Close() error { + if err := l.l.Close(); err != nil { + return err + } + t := time.NewTicker(time.Second) + select { + case <-l.ch: + case <-t.C: + return fmt.Errorf("close timeout") + } + return nil +} + +// avoid mapped IPv6 address +func NewTCPListener(address string, port uint32, ch chan *net.TCPConn) (*TCPListener, error) { + proto := "tcp4" + if ip := net.ParseIP(address); ip == nil { + return nil, fmt.Errorf("can't listen on %s", address) + } else if ip.To4() == nil { + proto = "tcp6" + } + addr, err := net.ResolveTCPAddr(proto, net.JoinHostPort(address, strconv.Itoa(int(port)))) + if err != nil { + return nil, err + } + + l, err := net.ListenTCP(proto, addr) + if err != nil { + return nil, err + } + closeCh := make(chan struct{}) + go func() error { + for { + conn, err := l.AcceptTCP() + if err != nil { + close(closeCh) + log.Warn(err) + return err + } + ch <- conn + } + }() + return &TCPListener{ + l: l, + ch: closeCh, + }, nil +} + type BgpServer struct { bgpConfig config.Bgp addedPeerCh chan config.Neighbor @@ -103,7 +156,7 @@ type BgpServer struct { policy *table.RoutingPolicy broadcastReqs []*GrpcRequest broadcastMsgs []broadcastMsg - listeners []*net.TCPListener + listeners []*TCPListener neighborMap map[string]*Peer globalRib *table.TableManager zclient *zebra.Client @@ -127,37 +180,6 @@ func NewBgpServer() *BgpServer { return &b } -// avoid mapped IPv6 address -func listenAndAccept(address string, port uint32, ch chan *net.TCPConn) (*net.TCPListener, error) { - proto := "tcp4" - if ip := net.ParseIP(address); ip == nil { - return nil, fmt.Errorf("can't listen on %s", address) - } else if ip.To4() == nil { - proto = "tcp6" - } - addr, err := net.ResolveTCPAddr(proto, net.JoinHostPort(address, strconv.Itoa(int(port)))) - if err != nil { - return nil, err - } - - l, err := net.ListenTCP(proto, addr) - if err != nil { - return nil, err - } - go func() error { - for { - conn, err := l.AcceptTCP() - if err != nil { - log.Error(err) - return err - } - ch <- conn - } - }() - - return l, nil -} - func (server *BgpServer) notify2watchers(typ watcherEventType, ev watcherEvent) error { for _, watcher := range server.watchers { if ch := watcher.notify(typ); ch != nil { @@ -174,10 +196,10 @@ func (server *BgpServer) Listeners(addr string) []*net.TCPListener { list := make([]*net.TCPListener, 0, len(server.listeners)) rhs := net.ParseIP(addr).To4() != nil for _, l := range server.listeners { - host, _, _ := net.SplitHostPort(l.Addr().String()) + host, _, _ := net.SplitHostPort(l.l.Addr().String()) lhs := net.ParseIP(host).To4() != nil if lhs == rhs { - list = append(list, l) + list = append(list, l.l) } } return list @@ -222,7 +244,7 @@ func (server *BgpServer) Serve() { } }(broadcastCh) - server.listeners = make([]*net.TCPListener, 0, 2) + server.listeners = make([]*TCPListener, 0, 2) server.fsmincomingCh = channels.NewInfiniteChannel() server.fsmStateCh = make(chan *FsmMsg, 4096) var senderMsgs []*SenderMsg @@ -1698,7 +1720,7 @@ func (server *BgpServer) handleModConfig(grpcReq *GrpcRequest) error { if c.ListenConfig.Port > 0 { acceptCh := make(chan *net.TCPConn, 4096) for _, addr := range c.ListenConfig.LocalAddressList { - l, err := listenAndAccept(addr, uint32(c.ListenConfig.Port), acceptCh) + l, err := NewTCPListener(addr, uint32(c.ListenConfig.Port), acceptCh) if err != nil { return err } |