summaryrefslogtreecommitdiffhomepage
path: root/vendor/github.com/containerd/ttrpc/server.go
diff options
context:
space:
mode:
authorIan Lewis <ianlewis@google.com>2018-12-04 04:55:39 -0500
committerIan Lewis <ianlewis@google.com>2018-12-04 04:55:39 -0500
commitae2250b1dd07fcdb6a2212022b75d63d26966499 (patch)
tree7c155cffdcf66ea800a0d812dee11b9a6044eabd /vendor/github.com/containerd/ttrpc/server.go
Initial commit
Diffstat (limited to 'vendor/github.com/containerd/ttrpc/server.go')
-rw-r--r--vendor/github.com/containerd/ttrpc/server.go456
1 files changed, 456 insertions, 0 deletions
diff --git a/vendor/github.com/containerd/ttrpc/server.go b/vendor/github.com/containerd/ttrpc/server.go
new file mode 100644
index 000000000..263cb4583
--- /dev/null
+++ b/vendor/github.com/containerd/ttrpc/server.go
@@ -0,0 +1,456 @@
+/*
+ Copyright The containerd Authors.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+*/
+
+package ttrpc
+
+import (
+ "context"
+ "io"
+ "math/rand"
+ "net"
+ "sync"
+ "sync/atomic"
+ "time"
+
+ "github.com/pkg/errors"
+ "github.com/sirupsen/logrus"
+ "google.golang.org/grpc/codes"
+ "google.golang.org/grpc/status"
+)
+
+var (
+ ErrServerClosed = errors.New("ttrpc: server closed")
+)
+
+type Server struct {
+ config *serverConfig
+ services *serviceSet
+ codec codec
+
+ mu sync.Mutex
+ listeners map[net.Listener]struct{}
+ connections map[*serverConn]struct{} // all connections to current state
+ done chan struct{} // marks point at which we stop serving requests
+}
+
+func NewServer(opts ...ServerOpt) (*Server, error) {
+ config := &serverConfig{}
+ for _, opt := range opts {
+ if err := opt(config); err != nil {
+ return nil, err
+ }
+ }
+
+ return &Server{
+ config: config,
+ services: newServiceSet(),
+ done: make(chan struct{}),
+ listeners: make(map[net.Listener]struct{}),
+ connections: make(map[*serverConn]struct{}),
+ }, nil
+}
+
+func (s *Server) Register(name string, methods map[string]Method) {
+ s.services.register(name, methods)
+}
+
+func (s *Server) Serve(ctx context.Context, l net.Listener) error {
+ s.addListener(l)
+ defer s.closeListener(l)
+
+ var (
+ backoff time.Duration
+ handshaker = s.config.handshaker
+ )
+
+ if handshaker == nil {
+ handshaker = handshakerFunc(noopHandshake)
+ }
+
+ for {
+ conn, err := l.Accept()
+ if err != nil {
+ select {
+ case <-s.done:
+ return ErrServerClosed
+ default:
+ }
+
+ if terr, ok := err.(interface {
+ Temporary() bool
+ }); ok && terr.Temporary() {
+ if backoff == 0 {
+ backoff = time.Millisecond
+ } else {
+ backoff *= 2
+ }
+
+ if max := time.Second; backoff > max {
+ backoff = max
+ }
+
+ sleep := time.Duration(rand.Int63n(int64(backoff)))
+ logrus.WithError(err).Errorf("ttrpc: failed accept; backoff %v", sleep)
+ time.Sleep(sleep)
+ continue
+ }
+
+ return err
+ }
+
+ backoff = 0
+
+ approved, handshake, err := handshaker.Handshake(ctx, conn)
+ if err != nil {
+ logrus.WithError(err).Errorf("ttrpc: refusing connection after handshake")
+ conn.Close()
+ continue
+ }
+
+ sc := s.newConn(approved, handshake)
+ go sc.run(ctx)
+ }
+}
+
+func (s *Server) Shutdown(ctx context.Context) error {
+ s.mu.Lock()
+ select {
+ case <-s.done:
+ default:
+ // protected by mutex
+ close(s.done)
+ }
+ lnerr := s.closeListeners()
+ s.mu.Unlock()
+
+ ticker := time.NewTicker(200 * time.Millisecond)
+ defer ticker.Stop()
+ for {
+ if s.closeIdleConns() {
+ return lnerr
+ }
+ select {
+ case <-ctx.Done():
+ return ctx.Err()
+ case <-ticker.C:
+ }
+ }
+}
+
+// Close the server without waiting for active connections.
+func (s *Server) Close() error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ select {
+ case <-s.done:
+ default:
+ // protected by mutex
+ close(s.done)
+ }
+
+ err := s.closeListeners()
+ for c := range s.connections {
+ c.close()
+ delete(s.connections, c)
+ }
+
+ return err
+}
+
+func (s *Server) addListener(l net.Listener) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ s.listeners[l] = struct{}{}
+}
+
+func (s *Server) closeListener(l net.Listener) error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ return s.closeListenerLocked(l)
+}
+
+func (s *Server) closeListenerLocked(l net.Listener) error {
+ defer delete(s.listeners, l)
+ return l.Close()
+}
+
+func (s *Server) closeListeners() error {
+ var err error
+ for l := range s.listeners {
+ if cerr := s.closeListenerLocked(l); cerr != nil && err == nil {
+ err = cerr
+ }
+ }
+ return err
+}
+
+func (s *Server) addConnection(c *serverConn) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ s.connections[c] = struct{}{}
+}
+
+func (s *Server) closeIdleConns() bool {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ quiescent := true
+ for c := range s.connections {
+ st, ok := c.getState()
+ if !ok || st != connStateIdle {
+ quiescent = false
+ continue
+ }
+ c.close()
+ delete(s.connections, c)
+ }
+ return quiescent
+}
+
+type connState int
+
+const (
+ connStateActive = iota + 1 // outstanding requests
+ connStateIdle // no requests
+ connStateClosed // closed connection
+)
+
+func (cs connState) String() string {
+ switch cs {
+ case connStateActive:
+ return "active"
+ case connStateIdle:
+ return "idle"
+ case connStateClosed:
+ return "closed"
+ default:
+ return "unknown"
+ }
+}
+
+func (s *Server) newConn(conn net.Conn, handshake interface{}) *serverConn {
+ c := &serverConn{
+ server: s,
+ conn: conn,
+ handshake: handshake,
+ shutdown: make(chan struct{}),
+ }
+ c.setState(connStateIdle)
+ s.addConnection(c)
+ return c
+}
+
+type serverConn struct {
+ server *Server
+ conn net.Conn
+ handshake interface{} // data from handshake, not used for now
+ state atomic.Value
+
+ shutdownOnce sync.Once
+ shutdown chan struct{} // forced shutdown, used by close
+}
+
+func (c *serverConn) getState() (connState, bool) {
+ cs, ok := c.state.Load().(connState)
+ return cs, ok
+}
+
+func (c *serverConn) setState(newstate connState) {
+ c.state.Store(newstate)
+}
+
+func (c *serverConn) close() error {
+ c.shutdownOnce.Do(func() {
+ close(c.shutdown)
+ })
+
+ return nil
+}
+
+func (c *serverConn) run(sctx context.Context) {
+ type (
+ request struct {
+ id uint32
+ req *Request
+ }
+
+ response struct {
+ id uint32
+ resp *Response
+ }
+ )
+
+ var (
+ ch = newChannel(c.conn)
+ ctx, cancel = context.WithCancel(sctx)
+ active int
+ state connState = connStateIdle
+ responses = make(chan response)
+ requests = make(chan request)
+ recvErr = make(chan error, 1)
+ shutdown = c.shutdown
+ done = make(chan struct{})
+ )
+
+ defer c.conn.Close()
+ defer cancel()
+ defer close(done)
+
+ go func(recvErr chan error) {
+ defer close(recvErr)
+ sendImmediate := func(id uint32, st *status.Status) bool {
+ select {
+ case responses <- response{
+ // even though we've had an invalid stream id, we send it
+ // back on the same stream id so the client knows which
+ // stream id was bad.
+ id: id,
+ resp: &Response{
+ Status: st.Proto(),
+ },
+ }:
+ return true
+ case <-c.shutdown:
+ return false
+ case <-done:
+ return false
+ }
+ }
+
+ for {
+ select {
+ case <-c.shutdown:
+ return
+ case <-done:
+ return
+ default: // proceed
+ }
+
+ mh, p, err := ch.recv(ctx)
+ if err != nil {
+ status, ok := status.FromError(err)
+ if !ok {
+ recvErr <- err
+ return
+ }
+
+ // in this case, we send an error for that particular message
+ // when the status is defined.
+ if !sendImmediate(mh.StreamID, status) {
+ return
+ }
+
+ continue
+ }
+
+ if mh.Type != messageTypeRequest {
+ // we must ignore this for future compat.
+ continue
+ }
+
+ var req Request
+ if err := c.server.codec.Unmarshal(p, &req); err != nil {
+ ch.putmbuf(p)
+ if !sendImmediate(mh.StreamID, status.Newf(codes.InvalidArgument, "unmarshal request error: %v", err)) {
+ return
+ }
+ continue
+ }
+ ch.putmbuf(p)
+
+ if mh.StreamID%2 != 1 {
+ // enforce odd client initiated identifiers.
+ if !sendImmediate(mh.StreamID, status.Newf(codes.InvalidArgument, "StreamID must be odd for client initiated streams")) {
+ return
+ }
+ continue
+ }
+
+ // Forward the request to the main loop. We don't wait on s.done
+ // because we have already accepted the client request.
+ select {
+ case requests <- request{
+ id: mh.StreamID,
+ req: &req,
+ }:
+ case <-done:
+ return
+ }
+ }
+ }(recvErr)
+
+ for {
+ newstate := state
+ switch {
+ case active > 0:
+ newstate = connStateActive
+ shutdown = nil
+ case active == 0:
+ newstate = connStateIdle
+ shutdown = c.shutdown // only enable this branch in idle mode
+ }
+
+ if newstate != state {
+ c.setState(newstate)
+ state = newstate
+ }
+
+ select {
+ case request := <-requests:
+ active++
+ go func(id uint32) {
+ p, status := c.server.services.call(ctx, request.req.Service, request.req.Method, request.req.Payload)
+ resp := &Response{
+ Status: status.Proto(),
+ Payload: p,
+ }
+
+ select {
+ case responses <- response{
+ id: id,
+ resp: resp,
+ }:
+ case <-done:
+ }
+ }(request.id)
+ case response := <-responses:
+ p, err := c.server.codec.Marshal(response.resp)
+ if err != nil {
+ logrus.WithError(err).Error("failed marshaling response")
+ return
+ }
+
+ if err := ch.send(ctx, response.id, messageTypeResponse, p); err != nil {
+ logrus.WithError(err).Error("failed sending message on channel")
+ return
+ }
+
+ active--
+ case err := <-recvErr:
+ // TODO(stevvooe): Not wildly clear what we should do in this
+ // branch. Basically, it means that we are no longer receiving
+ // requests due to a terminal error.
+ recvErr = nil // connection is now "closing"
+ if err != nil && err != io.EOF {
+ logrus.WithError(err).Error("error receiving message")
+ }
+ case <-shutdown:
+ return
+ }
+ }
+}