summaryrefslogtreecommitdiffhomepage
path: root/pkg/server/server.go
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/server/server.go')
-rw-r--r--pkg/server/server.go234
1 files changed, 221 insertions, 13 deletions
diff --git a/pkg/server/server.go b/pkg/server/server.go
index f1320c5c..8ac9f44a 100644
--- a/pkg/server/server.go
+++ b/pkg/server/server.go
@@ -39,8 +39,216 @@ import (
"github.com/osrg/gobgp/pkg/packet/bgp"
)
+type ListenTCP func(network string, laddr *net.TCPAddr) (*net.TCPListener, error)
+
+type Dial func(network, address string) (Conn, error)
+
+type Dialer interface {
+ DialContext(ctx context.Context, network, address string) (Conn, error)
+}
+
+type Listener interface {
+ Accept() (Conn, error)
+}
+
+type ListenConfig interface {
+ Listen(ctx context.Context, network, address string) (Listener, error)
+}
+
+type NewDialer func(defaults *net.Dialer) Dialer
+
+type NewListenConfig func(defaults *net.ListenConfig) ListenConfig
+
+type TransportOptions struct {
+ Dial
+ ListenTCP
+ NewDialer
+ NewListenConfig
+}
+
+type Conn interface {
+ net.Conn
+
+ SetTCPTTLSockopt(ttl int) error
+ SetTCPMinTTLSockopt(ttl int) error
+}
+
+type TCPConn interface {
+ Conn
+}
+
+type TCPListener interface {
+ Listener
+ AcceptTCP() (TCPConn, error)
+ Addr() net.Addr
+ Close() error
+ SetTCPMD5SigSockopt(address string, key string) error
+}
+
+// Golang net.TCPConn wrapper
+type netTCPConn struct {
+ c *net.TCPConn
+}
+
+func (c *netTCPConn) Close() error {
+ return c.c.Close()
+}
+
+func (c *netTCPConn) LocalAddr() net.Addr {
+ return c.c.LocalAddr()
+}
+
+func (c *netTCPConn) Read(b []byte) (int, error) {
+ return c.c.Read(b)
+}
+
+func (c *netTCPConn) RemoteAddr() net.Addr {
+ return c.c.RemoteAddr()
+}
+
+func (c *netTCPConn) SetDeadline(t time.Time) error {
+ return c.c.SetDeadline(t)
+}
+
+func (c *netTCPConn) SetReadDeadline(t time.Time) error {
+ return c.c.SetReadDeadline(t)
+}
+
+func (c *netTCPConn) SetWriteDeadline(t time.Time) error {
+ return c.c.SetWriteDeadline(t)
+}
+
+func (c *netTCPConn) SetTCPTTLSockopt(ttl int) error {
+ return setTCPTTLSockopt(c.c, ttl)
+}
+
+func (c *netTCPConn) SetTCPMinTTLSockopt(ttl int) error {
+ return setTCPMinTTLSockopt(c.c, ttl)
+}
+
+func (c *netTCPConn) Write(b []byte) (int, error) {
+ return c.c.Write(b)
+}
+
+// Golang net.TCPListener wrapper
+type netTCPListener struct {
+ l *net.TCPListener
+}
+
+func (l *netTCPListener) Accept() (Conn, error) {
+ conn, err := l.l.AcceptTCP()
+ if err != nil {
+ return nil, err
+ }
+ tconn, err := netDotConnToConn("tcp", conn)
+ if err != nil {
+ return nil, err
+ }
+ return tconn, nil
+}
+
+func (l *netTCPListener) AcceptTCP() (TCPConn, error) {
+ conn, err := l.Accept ()
+ if err != nil {
+ return nil, err
+ }
+ return conn.(TCPConn), nil
+}
+
+func (l *netTCPListener) Addr() net.Addr {
+ return l.l.Addr()
+}
+
+func (l *netTCPListener) Close() error {
+ return l.l.Close()
+}
+
+func (l *netTCPListener) SetTCPMD5SigSockopt(address string, key string) error {
+ return setTCPMD5SigSockopt(l.l, address, key)
+}
+
+func netDotConnToConn(network string, conn net.Conn) (Conn, error) {
+ switch network {
+ case "tcp":
+ return &netTCPConn{
+ c: conn.(*net.TCPConn),
+ }, nil
+ default:
+ return nil, fmt.Errorf("unsupported network type %s", network)
+ }
+}
+
+func netDotListenerToListener(network string, listener net.Listener) (TCPListener, error) {
+ log.Printf("netDotListenerToListener")
+ switch network {
+ case "tcp","tcp4","tcp6":
+ return &netTCPListener{
+ l: listener.(*net.TCPListener),
+ }, nil
+ default:
+ return nil, fmt.Errorf("unsupported network type %s", network)
+ }
+}
+
+var transport TransportOptions = TransportOptions{
+ Dial: func(network, address string) (Conn, error) {
+ conn, err := net.Dial(network, address)
+ if err != nil {
+ return nil, err
+ }
+
+ return netDotConnToConn(network, conn)
+ },
+ NewDialer: newNetDialer,
+ NewListenConfig: newNetListenConfig,
+}
+
+type netDialer struct {
+ d *net.Dialer
+}
+
+type netListenConfig struct {
+ lc *net.ListenConfig
+}
+
+func (d *netDialer) DialContext(ctx context.Context, network, address string) (Conn, error) {
+ conn, err := d.d.DialContext(ctx, network, address)
+ if err != nil {
+ return nil, err
+ }
+
+ return netDotConnToConn(network, conn)
+}
+
+func (lc *netListenConfig) Listen(ctx context.Context, network, address string) (Listener, error) {
+ listener, err := lc.lc.Listen(ctx, network, address)
+ if err != nil {
+ return nil, err
+ }
+
+ return netDotListenerToListener(network, listener)
+}
+
+func newNetDialer(defaults *net.Dialer) Dialer {
+ dialer := &netDialer{
+ defaults,
+ }
+ return dialer
+}
+
+func newNetListenConfig(defaults *net.ListenConfig) ListenConfig {
+ lc := &netListenConfig{
+ defaults,
+ }
+ return lc
+}
+
+func SetTransportOptions(options *TransportOptions) {
+ transport = *options
+}
+
type tcpListener struct {
- l *net.TCPListener
+ l TCPListener
ch chan struct{}
}
@@ -53,7 +261,7 @@ func (l *tcpListener) Close() error {
}
// avoid mapped IPv6 address
-func newTCPListener(address string, port uint32, bindToDev string, ch chan *net.TCPConn) (*tcpListener, error) {
+func newTCPListener(address string, port uint32, bindToDev string, ch chan TCPConn) (*tcpListener, error) {
proto := "tcp4"
family := syscall.AF_INET
if ip := net.ParseIP(address); ip == nil {
@@ -89,13 +297,13 @@ func newTCPListener(address string, port uint32, bindToDev string, ch chan *net.
return nil
}
- l, err := lc.Listen(context.Background(), proto, addr)
+ l, err := transport.NewListenConfig(&lc).Listen(context.Background(), proto, addr)
if err != nil {
return nil, err
}
- listener, ok := l.(*net.TCPListener)
+ listener, ok := l.(TCPListener)
if !ok {
- err = fmt.Errorf("unexpected connection listener (not for TCP)")
+ err = fmt.Errorf("unexpected connection listener (not for TCP) %T", l)
return nil, err
}
@@ -141,7 +349,7 @@ func GrpcOption(opt []grpc.ServerOption) ServerOption {
type BgpServer struct {
bgpConfig config.Bgp
- acceptCh chan *net.TCPConn
+ acceptCh chan TCPConn
incomings []*channels.InfiniteChannel
mgmtCh chan *mgmtOp
policy *table.RoutingPolicy
@@ -204,8 +412,8 @@ func (s *BgpServer) delIncoming(ch *channels.InfiniteChannel) {
}
}
-func (s *BgpServer) listListeners(addr string) []*net.TCPListener {
- list := make([]*net.TCPListener, 0, len(s.listeners))
+func (s *BgpServer) listListeners(addr string) []TCPListener {
+ list := make([]TCPListener, 0, len(s.listeners))
rhs := net.ParseIP(addr).To4() != nil
for _, l := range s.listeners {
host, _, _ := net.SplitHostPort(l.l.Addr().String())
@@ -251,7 +459,7 @@ func (s *BgpServer) mgmtOperation(f func() error, checkActive bool) (err error)
return
}
-func (s *BgpServer) passConnToPeer(conn *net.TCPConn) {
+func (s *BgpServer) passConnToPeer(conn Conn) {
host, _, _ := net.SplitHostPort(conn.RemoteAddr().String())
ipaddr, _ := net.ResolveIPAddr("ip", host)
remoteAddr := ipaddr.String()
@@ -421,7 +629,7 @@ func (s *BgpServer) Serve() {
op := value.Interface().(*mgmtOp)
s.handleMGMTOp(op)
case 1:
- conn := value.Interface().(*net.TCPConn)
+ conn := value.Interface().(Conn)
s.passConnToPeer(conn)
case 2:
ev := value.Interface().(*roaEvent)
@@ -2146,7 +2354,7 @@ func (s *BgpServer) StartBgp(ctx context.Context, r *api.StartBgpRequest) error
}
if c.Config.Port > 0 {
- acceptCh := make(chan *net.TCPConn, 4096)
+ acceptCh := make(chan TCPConn, 4096)
for _, addr := range c.Config.LocalAddressList {
l, err := newTCPListener(addr, uint32(c.Config.Port), g.BindToDevice, acceptCh)
if err != nil {
@@ -2934,7 +3142,7 @@ func (s *BgpServer) addNeighbor(c *config.Neighbor) error {
if s.bgpConfig.Global.Config.Port > 0 {
for _, l := range s.listListeners(addr) {
if c.Config.AuthPassword != "" {
- if err := setTCPMD5SigSockopt(l, addr, c.Config.AuthPassword); err != nil {
+ if err := l.SetTCPMD5SigSockopt(addr, c.Config.AuthPassword); err != nil {
log.WithFields(log.Fields{
"Topic": "Peer",
"Key": addr,
@@ -3042,7 +3250,7 @@ func (s *BgpServer) deleteNeighbor(c *config.Neighbor, code, subcode uint8) erro
return fmt.Errorf("can't delete a peer configuration for %s", addr)
}
for _, l := range s.listListeners(addr) {
- if err := setTCPMD5SigSockopt(l, addr, ""); err != nil {
+ if err := l.SetTCPMD5SigSockopt(addr, ""); err != nil {
log.WithFields(log.Fields{
"Topic": "Peer",
"Key": addr,