diff options
author | bbassingthwaite <bbassingthwaite@digitalocean.com> | 2021-03-16 08:44:34 -0600 |
---|---|---|
committer | bbassingthwaite <bbassingthwaite@digitalocean.com> | 2021-03-16 08:44:34 -0600 |
commit | dfc8ec6437ca3b41d32e7c79239f3e56363147ed (patch) | |
tree | da5487c31b1b379c9ba29b1507e7660023e5b585 | |
parent | 390e3127cea7a4b3fb9fcc089cfc7ffd02f108e1 (diff) |
Add support for the gRPC server to listen on a unix domain socket
Fixes #2230
-rw-r--r-- | pkg/server/grpc_server.go | 11 | ||||
-rw-r--r-- | pkg/server/grpc_server_test.go | 43 |
2 files changed, 53 insertions, 1 deletions
diff --git a/pkg/server/grpc_server.go b/pkg/server/grpc_server.go index 89bbb497..2d25734a 100644 --- a/pkg/server/grpc_server.go +++ b/pkg/server/grpc_server.go @@ -64,8 +64,9 @@ func (s *server) serve() error { l := []net.Listener{} var err error for _, host := range strings.Split(s.hosts, ",") { + network, address := parseHost(host) var lis net.Listener - lis, err = net.Listen("tcp", host) + lis, err = net.Listen(network, address) if err != nil { log.WithFields(log.Fields{ "Topic": "grpc", @@ -101,6 +102,14 @@ func (s *server) serve() error { return nil } +func parseHost(host string) (string, string) { + const unixScheme = "unix://" + if strings.HasPrefix(host, unixScheme) { + return "unix", host[len(unixScheme):] + } + return "tcp", host +} + func (s *server) ListPeer(r *api.ListPeerRequest, stream api.GobgpApi_ListPeerServer) error { ctx, cancel := context.WithCancel(context.Background()) defer cancel() diff --git a/pkg/server/grpc_server_test.go b/pkg/server/grpc_server_test.go new file mode 100644 index 00000000..6d8ee682 --- /dev/null +++ b/pkg/server/grpc_server_test.go @@ -0,0 +1,43 @@ +package server + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestParseHost(t *testing.T) { + tsts := []struct { + name string + host string + expectNetwork string + expectAddr string + }{ + { + name: "schemeless tcp host defaults to tcp", + host: "127.0.0.1:50051", + expectNetwork: "tcp", + expectAddr: "127.0.0.1:50051", + }, + { + name: "schemeless with only port defaults to tcp", + host: ":50051", + expectNetwork: "tcp", + expectAddr: ":50051", + }, + { + name: "unix socket", + host: "unix:///var/run/gobgp.socket", + expectNetwork: "unix", + expectAddr: "/var/run/gobgp.socket", + }, + } + + for _, tst := range tsts { + t.Run(tst.name, func(t *testing.T) { + gotNetwork, gotAddr := parseHost(tst.host) + assert.Equal(t, tst.expectNetwork, gotNetwork) + assert.Equal(t, tst.expectAddr, gotAddr) + }) + } +} |