diff options
-rw-r--r-- | pkg/server/bmp.go | 6 | ||||
-rw-r--r-- | pkg/server/fsm.go | 14 | ||||
-rw-r--r-- | pkg/server/peer.go | 3 | ||||
-rw-r--r-- | pkg/server/rpki.go | 8 | ||||
-rw-r--r-- | pkg/server/server.go | 234 |
5 files changed, 236 insertions, 29 deletions
diff --git a/pkg/server/bmp.go b/pkg/server/bmp.go index 67eebd0c..828bdd56 100644 --- a/pkg/server/bmp.go +++ b/pkg/server/bmp.go @@ -81,11 +81,11 @@ func (r ribout) update(p *table.Path) bool { return true } -func (b *bmpClient) tryConnect() *net.TCPConn { +func (b *bmpClient) tryConnect() net.Conn { interval := 1 for { log.WithFields(log.Fields{"Topic": "bmp"}).Debugf("Connecting BMP server:%s", b.host) - conn, err := net.Dial("tcp", b.host) + conn, err := transport.Dial("tcp", b.host) if err != nil { select { case <-b.dead: @@ -98,7 +98,7 @@ func (b *bmpClient) tryConnect() *net.TCPConn { } } else { log.WithFields(log.Fields{"Topic": "bmp"}).Infof("BMP server is connected:%s", b.host) - return conn.(*net.TCPConn) + return conn } } } diff --git a/pkg/server/fsm.go b/pkg/server/fsm.go index 2228b7ca..d7521f4e 100644 --- a/pkg/server/fsm.go +++ b/pkg/server/fsm.go @@ -179,8 +179,8 @@ type fsm struct { outgoingCh *channels.InfiniteChannel incomingCh *channels.InfiniteChannel reason *fsmStateReason - conn net.Conn - connCh chan net.Conn + conn Conn + connCh chan Conn idleHoldTime float64 opensentHoldTime float64 adminState adminState @@ -274,7 +274,7 @@ func newFSM(gConf *config.Global, pConf *config.Neighbor) *fsm { state: bgp.BGP_FSM_IDLE, outgoingCh: channels.NewInfiniteChannel(), incomingCh: channels.NewInfiniteChannel(), - connCh: make(chan net.Conn, 1), + connCh: make(chan Conn, 1), opensentHoldTime: float64(holdtimeOpensent), adminState: adminState, adminStateCh: make(chan adminStateOperation, 1), @@ -548,13 +548,13 @@ func (h *fsmHandler) connectLoop(ctx context.Context, wg *sync.WaitGroup) { } if err == nil { - d := net.Dialer{ + d := transport.NewDialer(&net.Dialer{ LocalAddr: laddr, Timeout: time.Duration(tick-1) * time.Second, Control: func(network, address string, c syscall.RawConn) error { return dialerControl(network, address, c, ttl, ttlMin, password, bindInterface) }, - } + }) conn, err := d.DialContext(ctx, "tcp", net.JoinHostPort(addr, strconv.Itoa(port))) select { @@ -638,7 +638,7 @@ func (h *fsmHandler) active(ctx context.Context) (bgp.FSMState, *fsmStateReason) ttl = int(fsm.pConf.Transport.Config.Ttl) } if ttl != 0 { - if err := setTCPTTLSockopt(conn.(*net.TCPConn), ttl); err != nil { + if err := conn.SetTCPTTLSockopt(ttl); err != nil { log.WithFields(log.Fields{ "Topic": "Peer", "Key": fsm.pConf.Config.NeighborAddress, @@ -647,7 +647,7 @@ func (h *fsmHandler) active(ctx context.Context) (bgp.FSMState, *fsmStateReason) } } if ttlMin != 0 { - if err := setTCPMinTTLSockopt(conn.(*net.TCPConn), ttlMin); err != nil { + if err := conn.SetTCPMinTTLSockopt(ttlMin); err != nil { log.WithFields(log.Fields{ "Topic": "Peer", "Key": fsm.pConf.Config.NeighborAddress, diff --git a/pkg/server/peer.go b/pkg/server/peer.go index 243a3dcc..b8946050 100644 --- a/pkg/server/peer.go +++ b/pkg/server/peer.go @@ -17,7 +17,6 @@ package server import ( "fmt" - "net" "time" "github.com/osrg/gobgp/internal/pkg/config" @@ -559,7 +558,7 @@ func (peer *peer) StaleAll(rfList []bgp.RouteFamily) []*table.Path { return peer.adjRibIn.StaleAll(rfList) } -func (peer *peer) PassConn(conn *net.TCPConn) { +func (peer *peer) PassConn(conn Conn) { select { case peer.fsm.connCh <- conn: default: diff --git a/pkg/server/rpki.go b/pkg/server/rpki.go index 00fbfcfa..af51fa02 100644 --- a/pkg/server/rpki.go +++ b/pkg/server/rpki.go @@ -53,7 +53,7 @@ type roaEvent struct { EventType roaEventType Src string Data []byte - conn *net.TCPConn + conn net.Conn } type roaManager struct { @@ -309,7 +309,7 @@ func (m *roaManager) GetServers() []*config.RpkiServer { type roaClient struct { host string - conn *net.TCPConn + conn net.Conn state config.RpkiServerState eventCh chan *roaEvent sessionID uint16 @@ -383,14 +383,14 @@ func (c *roaClient) tryConnect() { return default: } - if conn, err := net.Dial("tcp", c.host); err != nil { + if conn, err := transport.Dial("tcp", c.host); err != nil { // better to use context with timeout time.Sleep(connectRetryInterval * time.Second) } else { c.eventCh <- &roaEvent{ EventType: roaConnected, Src: c.host, - conn: conn.(*net.TCPConn), + conn: conn, } return } diff --git a/pkg/server/server.go b/pkg/server/server.go index f1320c5c..8ac9f44a 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -39,8 +39,216 @@ import ( "github.com/osrg/gobgp/pkg/packet/bgp" ) +type ListenTCP func(network string, laddr *net.TCPAddr) (*net.TCPListener, error) + +type Dial func(network, address string) (Conn, error) + +type Dialer interface { + DialContext(ctx context.Context, network, address string) (Conn, error) +} + +type Listener interface { + Accept() (Conn, error) +} + +type ListenConfig interface { + Listen(ctx context.Context, network, address string) (Listener, error) +} + +type NewDialer func(defaults *net.Dialer) Dialer + +type NewListenConfig func(defaults *net.ListenConfig) ListenConfig + +type TransportOptions struct { + Dial + ListenTCP + NewDialer + NewListenConfig +} + +type Conn interface { + net.Conn + + SetTCPTTLSockopt(ttl int) error + SetTCPMinTTLSockopt(ttl int) error +} + +type TCPConn interface { + Conn +} + +type TCPListener interface { + Listener + AcceptTCP() (TCPConn, error) + Addr() net.Addr + Close() error + SetTCPMD5SigSockopt(address string, key string) error +} + +// Golang net.TCPConn wrapper +type netTCPConn struct { + c *net.TCPConn +} + +func (c *netTCPConn) Close() error { + return c.c.Close() +} + +func (c *netTCPConn) LocalAddr() net.Addr { + return c.c.LocalAddr() +} + +func (c *netTCPConn) Read(b []byte) (int, error) { + return c.c.Read(b) +} + +func (c *netTCPConn) RemoteAddr() net.Addr { + return c.c.RemoteAddr() +} + +func (c *netTCPConn) SetDeadline(t time.Time) error { + return c.c.SetDeadline(t) +} + +func (c *netTCPConn) SetReadDeadline(t time.Time) error { + return c.c.SetReadDeadline(t) +} + +func (c *netTCPConn) SetWriteDeadline(t time.Time) error { + return c.c.SetWriteDeadline(t) +} + +func (c *netTCPConn) SetTCPTTLSockopt(ttl int) error { + return setTCPTTLSockopt(c.c, ttl) +} + +func (c *netTCPConn) SetTCPMinTTLSockopt(ttl int) error { + return setTCPMinTTLSockopt(c.c, ttl) +} + +func (c *netTCPConn) Write(b []byte) (int, error) { + return c.c.Write(b) +} + +// Golang net.TCPListener wrapper +type netTCPListener struct { + l *net.TCPListener +} + +func (l *netTCPListener) Accept() (Conn, error) { + conn, err := l.l.AcceptTCP() + if err != nil { + return nil, err + } + tconn, err := netDotConnToConn("tcp", conn) + if err != nil { + return nil, err + } + return tconn, nil +} + +func (l *netTCPListener) AcceptTCP() (TCPConn, error) { + conn, err := l.Accept () + if err != nil { + return nil, err + } + return conn.(TCPConn), nil +} + +func (l *netTCPListener) Addr() net.Addr { + return l.l.Addr() +} + +func (l *netTCPListener) Close() error { + return l.l.Close() +} + +func (l *netTCPListener) SetTCPMD5SigSockopt(address string, key string) error { + return setTCPMD5SigSockopt(l.l, address, key) +} + +func netDotConnToConn(network string, conn net.Conn) (Conn, error) { + switch network { + case "tcp": + return &netTCPConn{ + c: conn.(*net.TCPConn), + }, nil + default: + return nil, fmt.Errorf("unsupported network type %s", network) + } +} + +func netDotListenerToListener(network string, listener net.Listener) (TCPListener, error) { + log.Printf("netDotListenerToListener") + switch network { + case "tcp","tcp4","tcp6": + return &netTCPListener{ + l: listener.(*net.TCPListener), + }, nil + default: + return nil, fmt.Errorf("unsupported network type %s", network) + } +} + +var transport TransportOptions = TransportOptions{ + Dial: func(network, address string) (Conn, error) { + conn, err := net.Dial(network, address) + if err != nil { + return nil, err + } + + return netDotConnToConn(network, conn) + }, + NewDialer: newNetDialer, + NewListenConfig: newNetListenConfig, +} + +type netDialer struct { + d *net.Dialer +} + +type netListenConfig struct { + lc *net.ListenConfig +} + +func (d *netDialer) DialContext(ctx context.Context, network, address string) (Conn, error) { + conn, err := d.d.DialContext(ctx, network, address) + if err != nil { + return nil, err + } + + return netDotConnToConn(network, conn) +} + +func (lc *netListenConfig) Listen(ctx context.Context, network, address string) (Listener, error) { + listener, err := lc.lc.Listen(ctx, network, address) + if err != nil { + return nil, err + } + + return netDotListenerToListener(network, listener) +} + +func newNetDialer(defaults *net.Dialer) Dialer { + dialer := &netDialer{ + defaults, + } + return dialer +} + +func newNetListenConfig(defaults *net.ListenConfig) ListenConfig { + lc := &netListenConfig{ + defaults, + } + return lc +} + +func SetTransportOptions(options *TransportOptions) { + transport = *options +} + type tcpListener struct { - l *net.TCPListener + l TCPListener ch chan struct{} } @@ -53,7 +261,7 @@ func (l *tcpListener) Close() error { } // avoid mapped IPv6 address -func newTCPListener(address string, port uint32, bindToDev string, ch chan *net.TCPConn) (*tcpListener, error) { +func newTCPListener(address string, port uint32, bindToDev string, ch chan TCPConn) (*tcpListener, error) { proto := "tcp4" family := syscall.AF_INET if ip := net.ParseIP(address); ip == nil { @@ -89,13 +297,13 @@ func newTCPListener(address string, port uint32, bindToDev string, ch chan *net. return nil } - l, err := lc.Listen(context.Background(), proto, addr) + l, err := transport.NewListenConfig(&lc).Listen(context.Background(), proto, addr) if err != nil { return nil, err } - listener, ok := l.(*net.TCPListener) + listener, ok := l.(TCPListener) if !ok { - err = fmt.Errorf("unexpected connection listener (not for TCP)") + err = fmt.Errorf("unexpected connection listener (not for TCP) %T", l) return nil, err } @@ -141,7 +349,7 @@ func GrpcOption(opt []grpc.ServerOption) ServerOption { type BgpServer struct { bgpConfig config.Bgp - acceptCh chan *net.TCPConn + acceptCh chan TCPConn incomings []*channels.InfiniteChannel mgmtCh chan *mgmtOp policy *table.RoutingPolicy @@ -204,8 +412,8 @@ func (s *BgpServer) delIncoming(ch *channels.InfiniteChannel) { } } -func (s *BgpServer) listListeners(addr string) []*net.TCPListener { - list := make([]*net.TCPListener, 0, len(s.listeners)) +func (s *BgpServer) listListeners(addr string) []TCPListener { + list := make([]TCPListener, 0, len(s.listeners)) rhs := net.ParseIP(addr).To4() != nil for _, l := range s.listeners { host, _, _ := net.SplitHostPort(l.l.Addr().String()) @@ -251,7 +459,7 @@ func (s *BgpServer) mgmtOperation(f func() error, checkActive bool) (err error) return } -func (s *BgpServer) passConnToPeer(conn *net.TCPConn) { +func (s *BgpServer) passConnToPeer(conn Conn) { host, _, _ := net.SplitHostPort(conn.RemoteAddr().String()) ipaddr, _ := net.ResolveIPAddr("ip", host) remoteAddr := ipaddr.String() @@ -421,7 +629,7 @@ func (s *BgpServer) Serve() { op := value.Interface().(*mgmtOp) s.handleMGMTOp(op) case 1: - conn := value.Interface().(*net.TCPConn) + conn := value.Interface().(Conn) s.passConnToPeer(conn) case 2: ev := value.Interface().(*roaEvent) @@ -2146,7 +2354,7 @@ func (s *BgpServer) StartBgp(ctx context.Context, r *api.StartBgpRequest) error } if c.Config.Port > 0 { - acceptCh := make(chan *net.TCPConn, 4096) + acceptCh := make(chan TCPConn, 4096) for _, addr := range c.Config.LocalAddressList { l, err := newTCPListener(addr, uint32(c.Config.Port), g.BindToDevice, acceptCh) if err != nil { @@ -2934,7 +3142,7 @@ func (s *BgpServer) addNeighbor(c *config.Neighbor) error { if s.bgpConfig.Global.Config.Port > 0 { for _, l := range s.listListeners(addr) { if c.Config.AuthPassword != "" { - if err := setTCPMD5SigSockopt(l, addr, c.Config.AuthPassword); err != nil { + if err := l.SetTCPMD5SigSockopt(addr, c.Config.AuthPassword); err != nil { log.WithFields(log.Fields{ "Topic": "Peer", "Key": addr, @@ -3042,7 +3250,7 @@ func (s *BgpServer) deleteNeighbor(c *config.Neighbor, code, subcode uint8) erro return fmt.Errorf("can't delete a peer configuration for %s", addr) } for _, l := range s.listListeners(addr) { - if err := setTCPMD5SigSockopt(l, addr, ""); err != nil { + if err := l.SetTCPMD5SigSockopt(addr, ""); err != nil { log.WithFields(log.Fields{ "Topic": "Peer", "Key": addr, |