diff options
-rw-r--r-- | gobgp/cmd/common.go | 23 | ||||
-rw-r--r-- | gobgp/cmd/root.go | 4 | ||||
-rw-r--r-- | gobgpd/main.go | 32 |
3 files changed, 49 insertions, 10 deletions
diff --git a/gobgp/cmd/common.go b/gobgp/cmd/common.go index fe5a6975..a104c567 100644 --- a/gobgp/cmd/common.go +++ b/gobgp/cmd/common.go @@ -24,11 +24,14 @@ import ( "sort" "strconv" "strings" + "time" cli "github.com/osrg/gobgp/client" "github.com/osrg/gobgp/config" "github.com/osrg/gobgp/packet/bgp" "github.com/osrg/gobgp/table" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials" ) const ( @@ -223,8 +226,26 @@ func (v vrfs) Less(i, j int) bool { } func newClient() *cli.Client { + var grpcOpts []grpc.DialOption + if globalOpts.TLS { + var creds credentials.TransportCredentials + if globalOpts.CaFile == "" { + creds = credentials.NewClientTLSFromCert(nil, "") + } else { + var err error + creds, err = credentials.NewClientTLSFromFile(globalOpts.CaFile, "") + if err != nil { + exitWithError(err) + } + } + grpcOpts = []grpc.DialOption{ + grpc.WithTimeout(time.Second), + grpc.WithBlock(), + grpc.WithTransportCredentials(creds), + } + } target := net.JoinHostPort(globalOpts.Host, strconv.Itoa(globalOpts.Port)) - client, err := cli.New(target) + client, err := cli.New(target, grpcOpts...) if err != nil { exitWithError(err) } diff --git a/gobgp/cmd/root.go b/gobgp/cmd/root.go index d3fc321b..c9281b8c 100644 --- a/gobgp/cmd/root.go +++ b/gobgp/cmd/root.go @@ -33,6 +33,8 @@ var globalOpts struct { GenCmpl bool BashCmplFile string PprofPort int + TLS bool + CaFile string } var cmds []string @@ -77,6 +79,8 @@ func NewRootCmd() *cobra.Command { rootCmd.PersistentFlags().BoolVarP(&globalOpts.GenCmpl, "gen-cmpl", "c", false, "generate completion file") rootCmd.PersistentFlags().StringVarP(&globalOpts.BashCmplFile, "bash-cmpl-file", "", "gobgp-completion.bash", "bash cmpl filename") rootCmd.PersistentFlags().IntVarP(&globalOpts.PprofPort, "pprof-port", "r", 0, "pprof port") + rootCmd.PersistentFlags().BoolVarP(&globalOpts.TLS, "tls", "", false, "connection uses TLS if true, else plain TCP") + rootCmd.PersistentFlags().StringVarP(&globalOpts.CaFile, "tls-ca-file", "", "", "The file containing the CA root cert file") globalCmd := NewGlobalCmd() neighborCmd := NewNeighborCmd() diff --git a/gobgpd/main.go b/gobgpd/main.go index 68afcb23..91e98f6c 100644 --- a/gobgpd/main.go +++ b/gobgpd/main.go @@ -16,6 +16,14 @@ package main import ( + "io/ioutil" + "net/http" + _ "net/http/pprof" + "os" + "os/signal" + "runtime" + "syscall" + log "github.com/Sirupsen/logrus" "github.com/jessevdk/go-flags" p "github.com/kr/pretty" @@ -24,13 +32,8 @@ import ( "github.com/osrg/gobgp/packet/bgp" "github.com/osrg/gobgp/server" "github.com/osrg/gobgp/table" - "io/ioutil" - "net/http" - _ "net/http/pprof" - "os" - "os/signal" - "runtime" - "syscall" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials" ) func main() { @@ -51,6 +54,9 @@ func main() { Dry bool `short:"d" long:"dry-run" description:"check configuration"` PProfHost string `long:"pprof-host" description:"specify the host that gobgpd listens on for pprof" default:"localhost:6060"` PProfDisable bool `long:"pprof-disable" description:"disable pprof profiling"` + TLS bool `long:"tls" description:"enable TLS authentication for gRPC API"` + TLSCertFile string `long:"tls-cert-file" description:"The TLS cert file"` + TLSKeyFile string `long:"tls-key-file" description:"The TLS key file"` } _, err := flags.Parse(&opts) if err != nil { @@ -118,10 +124,18 @@ func main() { bgpServer := server.NewBgpServer() go bgpServer.Serve() + var grpcOpts []grpc.ServerOption + if opts.TLS { + creds, err := credentials.NewServerTLSFromFile(opts.TLSCertFile, opts.TLSKeyFile) + if err != nil { + log.Fatalf("Failed to generate credentials: %v", err) + } + grpcOpts = []grpc.ServerOption{grpc.Creds(creds)} + } // start grpc Server - grpcServer := api.NewGrpcServer(bgpServer, opts.GrpcHosts) + apiServer := api.NewServer(bgpServer, grpc.NewServer(grpcOpts...), opts.GrpcHosts) go func() { - if err := grpcServer.Serve(); err != nil { + if err := apiServer.Serve(); err != nil { log.Fatalf("failed to listen grpc port: %s", err) } }() |