summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--gobgp/cmd/common.go23
-rw-r--r--gobgp/cmd/root.go4
-rw-r--r--gobgpd/main.go32
3 files changed, 49 insertions, 10 deletions
diff --git a/gobgp/cmd/common.go b/gobgp/cmd/common.go
index fe5a6975..a104c567 100644
--- a/gobgp/cmd/common.go
+++ b/gobgp/cmd/common.go
@@ -24,11 +24,14 @@ import (
"sort"
"strconv"
"strings"
+ "time"
cli "github.com/osrg/gobgp/client"
"github.com/osrg/gobgp/config"
"github.com/osrg/gobgp/packet/bgp"
"github.com/osrg/gobgp/table"
+ "google.golang.org/grpc"
+ "google.golang.org/grpc/credentials"
)
const (
@@ -223,8 +226,26 @@ func (v vrfs) Less(i, j int) bool {
}
func newClient() *cli.Client {
+ var grpcOpts []grpc.DialOption
+ if globalOpts.TLS {
+ var creds credentials.TransportCredentials
+ if globalOpts.CaFile == "" {
+ creds = credentials.NewClientTLSFromCert(nil, "")
+ } else {
+ var err error
+ creds, err = credentials.NewClientTLSFromFile(globalOpts.CaFile, "")
+ if err != nil {
+ exitWithError(err)
+ }
+ }
+ grpcOpts = []grpc.DialOption{
+ grpc.WithTimeout(time.Second),
+ grpc.WithBlock(),
+ grpc.WithTransportCredentials(creds),
+ }
+ }
target := net.JoinHostPort(globalOpts.Host, strconv.Itoa(globalOpts.Port))
- client, err := cli.New(target)
+ client, err := cli.New(target, grpcOpts...)
if err != nil {
exitWithError(err)
}
diff --git a/gobgp/cmd/root.go b/gobgp/cmd/root.go
index d3fc321b..c9281b8c 100644
--- a/gobgp/cmd/root.go
+++ b/gobgp/cmd/root.go
@@ -33,6 +33,8 @@ var globalOpts struct {
GenCmpl bool
BashCmplFile string
PprofPort int
+ TLS bool
+ CaFile string
}
var cmds []string
@@ -77,6 +79,8 @@ func NewRootCmd() *cobra.Command {
rootCmd.PersistentFlags().BoolVarP(&globalOpts.GenCmpl, "gen-cmpl", "c", false, "generate completion file")
rootCmd.PersistentFlags().StringVarP(&globalOpts.BashCmplFile, "bash-cmpl-file", "", "gobgp-completion.bash", "bash cmpl filename")
rootCmd.PersistentFlags().IntVarP(&globalOpts.PprofPort, "pprof-port", "r", 0, "pprof port")
+ rootCmd.PersistentFlags().BoolVarP(&globalOpts.TLS, "tls", "", false, "connection uses TLS if true, else plain TCP")
+ rootCmd.PersistentFlags().StringVarP(&globalOpts.CaFile, "tls-ca-file", "", "", "The file containing the CA root cert file")
globalCmd := NewGlobalCmd()
neighborCmd := NewNeighborCmd()
diff --git a/gobgpd/main.go b/gobgpd/main.go
index 68afcb23..91e98f6c 100644
--- a/gobgpd/main.go
+++ b/gobgpd/main.go
@@ -16,6 +16,14 @@
package main
import (
+ "io/ioutil"
+ "net/http"
+ _ "net/http/pprof"
+ "os"
+ "os/signal"
+ "runtime"
+ "syscall"
+
log "github.com/Sirupsen/logrus"
"github.com/jessevdk/go-flags"
p "github.com/kr/pretty"
@@ -24,13 +32,8 @@ import (
"github.com/osrg/gobgp/packet/bgp"
"github.com/osrg/gobgp/server"
"github.com/osrg/gobgp/table"
- "io/ioutil"
- "net/http"
- _ "net/http/pprof"
- "os"
- "os/signal"
- "runtime"
- "syscall"
+ "google.golang.org/grpc"
+ "google.golang.org/grpc/credentials"
)
func main() {
@@ -51,6 +54,9 @@ func main() {
Dry bool `short:"d" long:"dry-run" description:"check configuration"`
PProfHost string `long:"pprof-host" description:"specify the host that gobgpd listens on for pprof" default:"localhost:6060"`
PProfDisable bool `long:"pprof-disable" description:"disable pprof profiling"`
+ TLS bool `long:"tls" description:"enable TLS authentication for gRPC API"`
+ TLSCertFile string `long:"tls-cert-file" description:"The TLS cert file"`
+ TLSKeyFile string `long:"tls-key-file" description:"The TLS key file"`
}
_, err := flags.Parse(&opts)
if err != nil {
@@ -118,10 +124,18 @@ func main() {
bgpServer := server.NewBgpServer()
go bgpServer.Serve()
+ var grpcOpts []grpc.ServerOption
+ if opts.TLS {
+ creds, err := credentials.NewServerTLSFromFile(opts.TLSCertFile, opts.TLSKeyFile)
+ if err != nil {
+ log.Fatalf("Failed to generate credentials: %v", err)
+ }
+ grpcOpts = []grpc.ServerOption{grpc.Creds(creds)}
+ }
// start grpc Server
- grpcServer := api.NewGrpcServer(bgpServer, opts.GrpcHosts)
+ apiServer := api.NewServer(bgpServer, grpc.NewServer(grpcOpts...), opts.GrpcHosts)
go func() {
- if err := grpcServer.Serve(); err != nil {
+ if err := apiServer.Serve(); err != nil {
log.Fatalf("failed to listen grpc port: %s", err)
}
}()