diff options
Diffstat (limited to 'pkg/server/server.go')
-rw-r--r-- | pkg/server/server.go | 234 |
1 files changed, 221 insertions, 13 deletions
diff --git a/pkg/server/server.go b/pkg/server/server.go index f1320c5c..8ac9f44a 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -39,8 +39,216 @@ import ( "github.com/osrg/gobgp/pkg/packet/bgp" ) +type ListenTCP func(network string, laddr *net.TCPAddr) (*net.TCPListener, error) + +type Dial func(network, address string) (Conn, error) + +type Dialer interface { + DialContext(ctx context.Context, network, address string) (Conn, error) +} + +type Listener interface { + Accept() (Conn, error) +} + +type ListenConfig interface { + Listen(ctx context.Context, network, address string) (Listener, error) +} + +type NewDialer func(defaults *net.Dialer) Dialer + +type NewListenConfig func(defaults *net.ListenConfig) ListenConfig + +type TransportOptions struct { + Dial + ListenTCP + NewDialer + NewListenConfig +} + +type Conn interface { + net.Conn + + SetTCPTTLSockopt(ttl int) error + SetTCPMinTTLSockopt(ttl int) error +} + +type TCPConn interface { + Conn +} + +type TCPListener interface { + Listener + AcceptTCP() (TCPConn, error) + Addr() net.Addr + Close() error + SetTCPMD5SigSockopt(address string, key string) error +} + +// Golang net.TCPConn wrapper +type netTCPConn struct { + c *net.TCPConn +} + +func (c *netTCPConn) Close() error { + return c.c.Close() +} + +func (c *netTCPConn) LocalAddr() net.Addr { + return c.c.LocalAddr() +} + +func (c *netTCPConn) Read(b []byte) (int, error) { + return c.c.Read(b) +} + +func (c *netTCPConn) RemoteAddr() net.Addr { + return c.c.RemoteAddr() +} + +func (c *netTCPConn) SetDeadline(t time.Time) error { + return c.c.SetDeadline(t) +} + +func (c *netTCPConn) SetReadDeadline(t time.Time) error { + return c.c.SetReadDeadline(t) +} + +func (c *netTCPConn) SetWriteDeadline(t time.Time) error { + return c.c.SetWriteDeadline(t) +} + +func (c *netTCPConn) SetTCPTTLSockopt(ttl int) error { + return setTCPTTLSockopt(c.c, ttl) +} + +func (c *netTCPConn) SetTCPMinTTLSockopt(ttl int) error { + return setTCPMinTTLSockopt(c.c, ttl) +} + +func (c *netTCPConn) Write(b []byte) (int, error) { + return c.c.Write(b) +} + +// Golang net.TCPListener wrapper +type netTCPListener struct { + l *net.TCPListener +} + +func (l *netTCPListener) Accept() (Conn, error) { + conn, err := l.l.AcceptTCP() + if err != nil { + return nil, err + } + tconn, err := netDotConnToConn("tcp", conn) + if err != nil { + return nil, err + } + return tconn, nil +} + +func (l *netTCPListener) AcceptTCP() (TCPConn, error) { + conn, err := l.Accept () + if err != nil { + return nil, err + } + return conn.(TCPConn), nil +} + +func (l *netTCPListener) Addr() net.Addr { + return l.l.Addr() +} + +func (l *netTCPListener) Close() error { + return l.l.Close() +} + +func (l *netTCPListener) SetTCPMD5SigSockopt(address string, key string) error { + return setTCPMD5SigSockopt(l.l, address, key) +} + +func netDotConnToConn(network string, conn net.Conn) (Conn, error) { + switch network { + case "tcp": + return &netTCPConn{ + c: conn.(*net.TCPConn), + }, nil + default: + return nil, fmt.Errorf("unsupported network type %s", network) + } +} + +func netDotListenerToListener(network string, listener net.Listener) (TCPListener, error) { + log.Printf("netDotListenerToListener") + switch network { + case "tcp","tcp4","tcp6": + return &netTCPListener{ + l: listener.(*net.TCPListener), + }, nil + default: + return nil, fmt.Errorf("unsupported network type %s", network) + } +} + +var transport TransportOptions = TransportOptions{ + Dial: func(network, address string) (Conn, error) { + conn, err := net.Dial(network, address) + if err != nil { + return nil, err + } + + return netDotConnToConn(network, conn) + }, + NewDialer: newNetDialer, + NewListenConfig: newNetListenConfig, +} + +type netDialer struct { + d *net.Dialer +} + +type netListenConfig struct { + lc *net.ListenConfig +} + +func (d *netDialer) DialContext(ctx context.Context, network, address string) (Conn, error) { + conn, err := d.d.DialContext(ctx, network, address) + if err != nil { + return nil, err + } + + return netDotConnToConn(network, conn) +} + +func (lc *netListenConfig) Listen(ctx context.Context, network, address string) (Listener, error) { + listener, err := lc.lc.Listen(ctx, network, address) + if err != nil { + return nil, err + } + + return netDotListenerToListener(network, listener) +} + +func newNetDialer(defaults *net.Dialer) Dialer { + dialer := &netDialer{ + defaults, + } + return dialer +} + +func newNetListenConfig(defaults *net.ListenConfig) ListenConfig { + lc := &netListenConfig{ + defaults, + } + return lc +} + +func SetTransportOptions(options *TransportOptions) { + transport = *options +} + type tcpListener struct { - l *net.TCPListener + l TCPListener ch chan struct{} } @@ -53,7 +261,7 @@ func (l *tcpListener) Close() error { } // avoid mapped IPv6 address -func newTCPListener(address string, port uint32, bindToDev string, ch chan *net.TCPConn) (*tcpListener, error) { +func newTCPListener(address string, port uint32, bindToDev string, ch chan TCPConn) (*tcpListener, error) { proto := "tcp4" family := syscall.AF_INET if ip := net.ParseIP(address); ip == nil { @@ -89,13 +297,13 @@ func newTCPListener(address string, port uint32, bindToDev string, ch chan *net. return nil } - l, err := lc.Listen(context.Background(), proto, addr) + l, err := transport.NewListenConfig(&lc).Listen(context.Background(), proto, addr) if err != nil { return nil, err } - listener, ok := l.(*net.TCPListener) + listener, ok := l.(TCPListener) if !ok { - err = fmt.Errorf("unexpected connection listener (not for TCP)") + err = fmt.Errorf("unexpected connection listener (not for TCP) %T", l) return nil, err } @@ -141,7 +349,7 @@ func GrpcOption(opt []grpc.ServerOption) ServerOption { type BgpServer struct { bgpConfig config.Bgp - acceptCh chan *net.TCPConn + acceptCh chan TCPConn incomings []*channels.InfiniteChannel mgmtCh chan *mgmtOp policy *table.RoutingPolicy @@ -204,8 +412,8 @@ func (s *BgpServer) delIncoming(ch *channels.InfiniteChannel) { } } -func (s *BgpServer) listListeners(addr string) []*net.TCPListener { - list := make([]*net.TCPListener, 0, len(s.listeners)) +func (s *BgpServer) listListeners(addr string) []TCPListener { + list := make([]TCPListener, 0, len(s.listeners)) rhs := net.ParseIP(addr).To4() != nil for _, l := range s.listeners { host, _, _ := net.SplitHostPort(l.l.Addr().String()) @@ -251,7 +459,7 @@ func (s *BgpServer) mgmtOperation(f func() error, checkActive bool) (err error) return } -func (s *BgpServer) passConnToPeer(conn *net.TCPConn) { +func (s *BgpServer) passConnToPeer(conn Conn) { host, _, _ := net.SplitHostPort(conn.RemoteAddr().String()) ipaddr, _ := net.ResolveIPAddr("ip", host) remoteAddr := ipaddr.String() @@ -421,7 +629,7 @@ func (s *BgpServer) Serve() { op := value.Interface().(*mgmtOp) s.handleMGMTOp(op) case 1: - conn := value.Interface().(*net.TCPConn) + conn := value.Interface().(Conn) s.passConnToPeer(conn) case 2: ev := value.Interface().(*roaEvent) @@ -2146,7 +2354,7 @@ func (s *BgpServer) StartBgp(ctx context.Context, r *api.StartBgpRequest) error } if c.Config.Port > 0 { - acceptCh := make(chan *net.TCPConn, 4096) + acceptCh := make(chan TCPConn, 4096) for _, addr := range c.Config.LocalAddressList { l, err := newTCPListener(addr, uint32(c.Config.Port), g.BindToDevice, acceptCh) if err != nil { @@ -2934,7 +3142,7 @@ func (s *BgpServer) addNeighbor(c *config.Neighbor) error { if s.bgpConfig.Global.Config.Port > 0 { for _, l := range s.listListeners(addr) { if c.Config.AuthPassword != "" { - if err := setTCPMD5SigSockopt(l, addr, c.Config.AuthPassword); err != nil { + if err := l.SetTCPMD5SigSockopt(addr, c.Config.AuthPassword); err != nil { log.WithFields(log.Fields{ "Topic": "Peer", "Key": addr, @@ -3042,7 +3250,7 @@ func (s *BgpServer) deleteNeighbor(c *config.Neighbor, code, subcode uint8) erro return fmt.Errorf("can't delete a peer configuration for %s", addr) } for _, l := range s.listListeners(addr) { - if err := setTCPMD5SigSockopt(l, addr, ""); err != nil { + if err := l.SetTCPMD5SigSockopt(addr, ""); err != nil { log.WithFields(log.Fields{ "Topic": "Peer", "Key": addr, |