diff options
Diffstat (limited to 'pkg/urpc')
-rw-r--r-- | pkg/urpc/BUILD | 23 | ||||
-rw-r--r-- | pkg/urpc/urpc.go | 636 | ||||
-rw-r--r-- | pkg/urpc/urpc_test.go | 210 |
3 files changed, 869 insertions, 0 deletions
diff --git a/pkg/urpc/BUILD b/pkg/urpc/BUILD new file mode 100644 index 000000000..850c34ed0 --- /dev/null +++ b/pkg/urpc/BUILD @@ -0,0 +1,23 @@ +load("//tools:defs.bzl", "go_library", "go_test") + +package(licenses = ["notice"]) + +go_library( + name = "urpc", + srcs = ["urpc.go"], + visibility = ["//:sandbox"], + deps = [ + "//pkg/fd", + "//pkg/log", + "//pkg/sync", + "//pkg/unet", + ], +) + +go_test( + name = "urpc_test", + size = "small", + srcs = ["urpc_test.go"], + library = ":urpc", + deps = ["//pkg/unet"], +) diff --git a/pkg/urpc/urpc.go b/pkg/urpc/urpc.go new file mode 100644 index 000000000..13b2ea314 --- /dev/null +++ b/pkg/urpc/urpc.go @@ -0,0 +1,636 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package urpc provides a minimal RPC package based on unet. +// +// RPC requests are _not_ concurrent and methods must be explicitly +// registered. However, files may be send as part of the payload. +package urpc + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "io" + "os" + "reflect" + "runtime" + + "gvisor.dev/gvisor/pkg/fd" + "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/unet" +) + +// maxFiles determines the maximum file payload. +const maxFiles = 32 + +// ErrTooManyFiles is returned when too many file descriptors are mapped. +var ErrTooManyFiles = errors.New("too many files") + +// ErrUnknownMethod is returned when a method is not known. +var ErrUnknownMethod = errors.New("unknown method") + +// errStopped is an internal error indicating the server has been stopped. +var errStopped = errors.New("stopped") + +// RemoteError is an error returned by the remote invocation. +// +// This indicates that the RPC transport was correct, but that the called +// function itself returned an error. +type RemoteError struct { + // Message is the result of calling Error() on the remote error. + Message string +} + +// Error returns the remote error string. +func (r RemoteError) Error() string { + return r.Message +} + +// FilePayload may be _embedded_ in another type in order to send or receive a +// file as a result of an RPC. These are not actually serialized, rather they +// are sent via an accompanying SCM_RIGHTS message (plumbed through the unet +// package). +// +// When embedding a FilePayload in an argument struct, the argument type _must_ +// be a pointer to the struct rather than the struct type itself. This is +// because the urpc package defines pointer methods on FilePayload. +type FilePayload struct { + Files []*os.File `json:"-"` +} + +// ReleaseFD releases the FD at the specified index. +func (f *FilePayload) ReleaseFD(index int) (*fd.FD, error) { + return fd.NewFromFile(f.Files[index]) +} + +// filePayload returns the file. It may be nil. +func (f *FilePayload) filePayload() []*os.File { + return f.Files +} + +// setFilePayload sets the payload. +func (f *FilePayload) setFilePayload(fs []*os.File) { + f.Files = fs +} + +// closeAll closes a slice of files. +func closeAll(files []*os.File) { + for _, f := range files { + f.Close() + } +} + +// filePayloader is implemented only by FilePayload and will be implicitly +// implemented by types that have the FilePayload embedded. Note that there is +// no way to implement these methods other than by embedding FilePayload, due +// to the way unexported method names are mangled. +type filePayloader interface { + filePayload() []*os.File + setFilePayload([]*os.File) +} + +// clientCall is the client=>server method call on the client side. +type clientCall struct { + Method string `json:"method"` + Arg interface{} `json:"arg"` +} + +// serverCall is the client=>server method call on the server side. +type serverCall struct { + Method string `json:"method"` + Arg json.RawMessage `json:"arg"` +} + +// callResult is the server=>client method call result. +type callResult struct { + Success bool `json:"success"` + Err string `json:"err"` + Result interface{} `json:"result"` +} + +// registeredMethod is method registered with the server. +type registeredMethod struct { + // fn is the underlying function. + fn reflect.Value + + // rcvr is the receiver value. + rcvr reflect.Value + + // argType is a typed argument. + argType reflect.Type + + // resultType is also a type result. + resultType reflect.Type +} + +// clientState is client metadata. +// +// The following are valid states: +// +// idle - not processing any requests, no close request. +// processing - actively processing, no close request. +// closeRequested - actively processing, pending close. +// closed - client connection has been closed. +// +// The following transitions are possible: +// +// idle -> processing, closed +// processing -> idle, closeRequested +// closeRequested -> closed +// +type clientState int + +// See clientState. +const ( + idle clientState = iota + processing + closeRequested + closed +) + +// Server is an RPC server. +type Server struct { + // mu protects all fields, except wg. + mu sync.Mutex + + // methods is the set of server methods. + methods map[string]registeredMethod + + // clients is a map of clients. + clients map[*unet.Socket]clientState + + // wg is a wait group for all outstanding clients. + wg sync.WaitGroup + + // afterRPCCallback is called after each RPC is successfully completed. + afterRPCCallback func() +} + +// NewServer returns a new server. +func NewServer() *Server { + return NewServerWithCallback(nil) +} + +// NewServerWithCallback returns a new server, who upon completion of each RPC +// calls the given function. +func NewServerWithCallback(afterRPCCallback func()) *Server { + return &Server{ + methods: make(map[string]registeredMethod), + clients: make(map[*unet.Socket]clientState), + afterRPCCallback: afterRPCCallback, + } +} + +// Register registers the given object as an RPC receiver. +// +// This functions is the same way as the built-in RPC package, but it does not +// tolerate any object with non-conforming methods. Any non-confirming methods +// will lead to an immediate panic, instead of being skipped or an error. +// Panics will also be generated by anonymous objects and duplicate entries. +func (s *Server) Register(obj interface{}) { + s.mu.Lock() + defer s.mu.Unlock() + + typ := reflect.TypeOf(obj) + + // If we got a pointer, deref it to the underlying object. We need this to + // obtain the name of the underlying type. + typDeref := typ + if typ.Kind() == reflect.Ptr { + typDeref = typ.Elem() + } + + for m := 0; m < typ.NumMethod(); m++ { + method := typ.Method(m) + + if typDeref.Name() == "" { + // Can't be anonymous. + panic("type not named.") + } + + prettyName := typDeref.Name() + "." + method.Name + if _, ok := s.methods[prettyName]; ok { + // Duplicate entry. + panic(fmt.Sprintf("method %s is duplicated.", prettyName)) + } + + if method.PkgPath != "" { + // Must be exported. + panic(fmt.Sprintf("method %s is not exported.", prettyName)) + } + mtype := method.Type + if mtype.NumIn() != 3 { + // Need exactly two arguments (+ receiver). + panic(fmt.Sprintf("method %s has wrong number of arguments.", prettyName)) + } + argType := mtype.In(1) + if argType.Kind() != reflect.Ptr { + // Need arg pointer. + panic(fmt.Sprintf("method %s has non-pointer first argument.", prettyName)) + } + resultType := mtype.In(2) + if resultType.Kind() != reflect.Ptr { + // Need result pointer. + panic(fmt.Sprintf("method %s has non-pointer second argument.", prettyName)) + } + if mtype.NumOut() != 1 { + // Need single return. + panic(fmt.Sprintf("method %s has wrong number of returns.", prettyName)) + } + if returnType := mtype.Out(0); returnType != reflect.TypeOf((*error)(nil)).Elem() { + // Need error return. + panic(fmt.Sprintf("method %s has non-error return value.", prettyName)) + } + + // Register the method. + s.methods[prettyName] = registeredMethod{ + fn: method.Func, + rcvr: reflect.ValueOf(obj), + argType: argType, + resultType: resultType, + } + } +} + +// lookup looks up the given method. +func (s *Server) lookup(method string) (registeredMethod, bool) { + s.mu.Lock() + defer s.mu.Unlock() + rm, ok := s.methods[method] + return rm, ok +} + +// handleOne handles a single call. +func (s *Server) handleOne(client *unet.Socket) error { + // Unmarshal the call. + var c serverCall + newFs, err := unmarshal(client, &c) + if err != nil { + // Client is dead. + return err + } + + defer func() { + if s.afterRPCCallback != nil { + s.afterRPCCallback() + } + }() + // Explicitly close all these files after the call. + // + // This is also explicitly a reference to the files after the call, + // which means they are kept open for the duration of the call. + defer closeAll(newFs) + + // Start the request. + if !s.clientBeginRequest(client) { + // Client is dead; don't process this call. + return errStopped + } + defer s.clientEndRequest(client) + + // Lookup the method. + rm, ok := s.lookup(c.Method) + if !ok { + // Try to serialize the error. + return marshal(client, &callResult{Err: ErrUnknownMethod.Error()}, nil) + } + + // Unmarshal the arguments now that we know the type. + na := reflect.New(rm.argType.Elem()) + if err := json.Unmarshal(c.Arg, na.Interface()); err != nil { + return marshal(client, &callResult{Err: err.Error()}, nil) + } + + // Set the file payload as an argument. + if fp, ok := na.Interface().(filePayloader); ok { + fp.setFilePayload(newFs) + } + + // Call the method. + re := reflect.New(rm.resultType.Elem()) + rValues := rm.fn.Call([]reflect.Value{rm.rcvr, na, re}) + if errVal := rValues[0].Interface(); errVal != nil { + return marshal(client, &callResult{Err: errVal.(error).Error()}, nil) + } + + // Set the resulting payload. + var fs []*os.File + if fp, ok := re.Interface().(filePayloader); ok { + fs = fp.filePayload() + if len(fs) > maxFiles { + // Ugh. Send an error to the client, despite success. + return marshal(client, &callResult{Err: ErrTooManyFiles.Error()}, nil) + } + } + + // Marshal the result. + return marshal(client, &callResult{Success: true, Result: re.Interface()}, fs) +} + +// clientBeginRequest begins a request. +// +// If true is returned, the request may be processed. If false is returned, +// then the server has been stopped and the request should be skipped. +func (s *Server) clientBeginRequest(client *unet.Socket) bool { + s.mu.Lock() + defer s.mu.Unlock() + switch state := s.clients[client]; state { + case idle: + // Mark as processing. + s.clients[client] = processing + return true + case closed: + // Whoops, how did this happen? Must have closed immediately + // following the deserialization. Don't let the RPC actually go + // through, since we won't be able to serialize a proper + // response. + return false + default: + // Should not happen. + panic(fmt.Sprintf("expected idle or closed, got %d", state)) + } +} + +// clientEndRequest ends a request. +func (s *Server) clientEndRequest(client *unet.Socket) { + s.mu.Lock() + defer s.mu.Unlock() + switch state := s.clients[client]; state { + case processing: + // Return to idle. + s.clients[client] = idle + case closeRequested: + // Close the connection. + client.Close() + s.clients[client] = closed + default: + // Should not happen. + panic(fmt.Sprintf("expected processing or requestClose, got %d", state)) + } +} + +// clientRegister registers a connection. +// +// See Stop for more context. +func (s *Server) clientRegister(client *unet.Socket) { + s.mu.Lock() + defer s.mu.Unlock() + s.clients[client] = idle + s.wg.Add(1) +} + +// clientUnregister unregisters and closes a connection if necessary. +// +// See Stop for more context. +func (s *Server) clientUnregister(client *unet.Socket) { + s.mu.Lock() + defer s.mu.Unlock() + switch state := s.clients[client]; state { + case idle: + // Close the connection. + client.Close() + case closed: + // Already done. + default: + // Should not happen. + panic(fmt.Sprintf("expected idle or closed, got %d", state)) + } + delete(s.clients, client) + s.wg.Done() +} + +// handleRegistered handles calls from a registered client. +func (s *Server) handleRegistered(client *unet.Socket) error { + for { + // Handle one call. + if err := s.handleOne(client); err != nil { + // Client is dead. + return err + } + } +} + +// Handle synchronously handles a single client over a connection. +func (s *Server) Handle(client *unet.Socket) error { + s.clientRegister(client) + defer s.clientUnregister(client) + return s.handleRegistered(client) +} + +// StartHandling creates a goroutine that handles a single client over a +// connection. +func (s *Server) StartHandling(client *unet.Socket) { + s.clientRegister(client) + go func() { // S/R-SAFE: out of scope + defer s.clientUnregister(client) + s.handleRegistered(client) + }() +} + +// Stop safely terminates outstanding clients. +// +// No new requests should be initiated after calling Stop. Existing clients +// will be closed after completing any pending RPCs. This method will block +// until all clients have disconnected. +func (s *Server) Stop() { + // Wait for all outstanding requests. + defer s.wg.Wait() + + // Close all known clients. + s.mu.Lock() + defer s.mu.Unlock() + for client, state := range s.clients { + switch state { + case idle: + // Close connection now. + client.Close() + s.clients[client] = closed + case processing: + // Request close when done. + s.clients[client] = closeRequested + } + } +} + +// Client is a urpc client. +type Client struct { + // mu protects all members. + // + // It also enforces single-call semantics. + mu sync.Mutex + + // Socket is the underlying socket for this client. + // + // This _must_ be provided and must be closed manually by calling + // Close. + Socket *unet.Socket +} + +// NewClient returns a new client. +func NewClient(socket *unet.Socket) *Client { + return &Client{ + Socket: socket, + } +} + +// marshal sends the given FD and json struct. +func marshal(s *unet.Socket, v interface{}, fs []*os.File) error { + // Marshal to a buffer. + data, err := json.Marshal(v) + if err != nil { + log.Warningf("urpc: error marshalling %s: %s", fmt.Sprintf("%v", v), err.Error()) + return err + } + + // Write to the socket. + w := s.Writer(true) + if fs != nil { + var fds []int + for _, f := range fs { + fds = append(fds, int(f.Fd())) + } + w.PackFDs(fds...) + } + + // Send. + for n := 0; n < len(data); { + cur, err := w.WriteVec([][]byte{data[n:]}) + if n == 0 && cur < len(data) { + // Don't send FDs anymore. This call is only made on + // the first successful call to WriteVec, assuming cur + // is not sufficient to fill the entire buffer. + w.PackFDs() + } + n += cur + if err != nil { + log.Warningf("urpc: error writing %v: %s", data[n:], err.Error()) + return err + } + } + + // We're done sending the fds to the client. Explicitly prevent fs from + // being GCed until here. Urpc rpcs often unlink the file to send, relying + // on the kernel to automatically delete it once the last reference is + // dropped. Until we successfully call sendmsg(2), fs may contain the last + // references to these files. Without this explicit reference to fs here, + // the go runtime is free to assume we're done with fs after the fd + // collection loop above, since it just sees us copying ints. + runtime.KeepAlive(fs) + + log.Debugf("urpc: successfully marshalled %d bytes.", len(data)) + return nil +} + +// unmarhsal receives an FD (optional) and unmarshals the given struct. +func unmarshal(s *unet.Socket, v interface{}) ([]*os.File, error) { + // Receive a single byte. + r := s.Reader(true) + r.EnableFDs(maxFiles) + firstByte := make([]byte, 1) + + // Extract any FDs that may be there. + if _, err := r.ReadVec([][]byte{firstByte}); err != nil { + return nil, err + } + fds, err := r.ExtractFDs() + if err != nil { + log.Warningf("urpc: error extracting fds: %s", err.Error()) + return nil, err + } + var fs []*os.File + for _, fd := range fds { + fs = append(fs, os.NewFile(uintptr(fd), "urpc")) + } + + // Read the rest. + d := json.NewDecoder(io.MultiReader(bytes.NewBuffer(firstByte), s)) + // urpc internally decodes / re-encodes the data with interface{} as the + // intermediate type. We have to unmarshal integers to json.Number type + // instead of the default float type for those intermediate values, such + // that when they get re-encoded, their values are not printed out in + // floating-point formats such as 1e9, which could not be decoded to + // explicitly typed intergers later. + d.UseNumber() + if err := d.Decode(v); err != nil { + log.Warningf("urpc: error decoding: %s", err.Error()) + for _, f := range fs { + f.Close() + } + return nil, err + } + + // All set. + log.Debugf("urpc: unmarshal success.") + return fs, nil +} + +// Call calls a function. +func (c *Client) Call(method string, arg interface{}, result interface{}) error { + c.mu.Lock() + defer c.mu.Unlock() + + // If arg is a FilePayload, not a *FilePayload, files won't actually be + // sent, so error out. + if _, ok := arg.(FilePayload); ok { + return fmt.Errorf("argument is a FilePayload, but should be a *FilePayload") + } + + // Are there files to send? + var fs []*os.File + if fp, ok := arg.(filePayloader); ok { + fs = fp.filePayload() + if len(fs) > maxFiles { + return ErrTooManyFiles + } + } + + // Marshal the data. + if err := marshal(c.Socket, &clientCall{Method: method, Arg: arg}, fs); err != nil { + return err + } + + // Wait for the response. + callR := callResult{Result: result} + newFs, err := unmarshal(c.Socket, &callR) + if err != nil { + return fmt.Errorf("urpc method %q failed: %v", method, err) + } + + // Set the file payload. + if fp, ok := result.(filePayloader); ok { + fp.setFilePayload(newFs) + } else { + closeAll(newFs) + } + + // Did an error occur? + if !callR.Success { + return RemoteError{Message: callR.Err} + } + + // All set. + return nil +} + +// Close closes the underlying socket. +// +// Further calls to the client may result in undefined behavior. +func (c *Client) Close() error { + c.mu.Lock() + defer c.mu.Unlock() + return c.Socket.Close() +} diff --git a/pkg/urpc/urpc_test.go b/pkg/urpc/urpc_test.go new file mode 100644 index 000000000..c6c7ce9d4 --- /dev/null +++ b/pkg/urpc/urpc_test.go @@ -0,0 +1,210 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package urpc + +import ( + "errors" + "os" + "testing" + + "gvisor.dev/gvisor/pkg/unet" +) + +type test struct { +} + +type testArg struct { + StringArg string + IntArg int + FilePayload +} + +type testResult struct { + StringResult string + IntResult int + FilePayload +} + +func (t test) Func(a *testArg, r *testResult) error { + r.StringResult = a.StringArg + r.IntResult = a.IntArg + return nil +} + +func (t test) Err(a *testArg, r *testResult) error { + return errors.New("test error") +} + +func (t test) FailNoFile(a *testArg, r *testResult) error { + if a.Files == nil { + return errors.New("no file found") + } + + return nil +} + +func (t test) SendFile(a *testArg, r *testResult) error { + r.Files = []*os.File{os.Stdin, os.Stdout, os.Stderr} + return nil +} + +func (t test) TooManyFiles(a *testArg, r *testResult) error { + for i := 0; i <= maxFiles; i++ { + r.Files = append(r.Files, os.Stdin) + } + return nil +} + +func startServer(socket *unet.Socket) { + s := NewServer() + s.Register(test{}) + s.StartHandling(socket) +} + +func testClient() (*Client, error) { + serverSock, clientSock, err := unet.SocketPair(false) + if err != nil { + return nil, err + } + startServer(serverSock) + + return NewClient(clientSock), nil +} + +func TestCall(t *testing.T) { + c, err := testClient() + if err != nil { + t.Fatalf("error creating test client: %v", err) + } + defer c.Close() + + var r testResult + if err := c.Call("test.Func", &testArg{}, &r); err != nil { + t.Errorf("basic call failed: %v", err) + } else if r.StringResult != "" || r.IntResult != 0 { + t.Errorf("unexpected result, got %v expected zero value", r) + } + if err := c.Call("test.Func", &testArg{StringArg: "hello"}, &r); err != nil { + t.Errorf("basic call failed: %v", err) + } else if r.StringResult != "hello" { + t.Errorf("unexpected result, got %v expected hello", r.StringResult) + } + if err := c.Call("test.Func", &testArg{IntArg: 1}, &r); err != nil { + t.Errorf("basic call failed: %v", err) + } else if r.IntResult != 1 { + t.Errorf("unexpected result, got %v expected 1", r.IntResult) + } +} + +func TestUnknownMethod(t *testing.T) { + c, err := testClient() + if err != nil { + t.Fatalf("error creating test client: %v", err) + } + defer c.Close() + + var r testResult + if err := c.Call("test.Unknown", &testArg{}, &r); err == nil { + t.Errorf("expected non-nil err, got nil") + } else if err.Error() != ErrUnknownMethod.Error() { + t.Errorf("expected test error, got %v", err) + } +} + +func TestErr(t *testing.T) { + c, err := testClient() + if err != nil { + t.Fatalf("error creating test client: %v", err) + } + defer c.Close() + + var r testResult + if err := c.Call("test.Err", &testArg{}, &r); err == nil { + t.Errorf("expected non-nil err, got nil") + } else if err.Error() != "test error" { + t.Errorf("expected test error, got %v", err) + } +} + +func TestSendFile(t *testing.T) { + c, err := testClient() + if err != nil { + t.Fatalf("error creating test client: %v", err) + } + defer c.Close() + + var r testResult + if err := c.Call("test.FailNoFile", &testArg{}, &r); err == nil { + t.Errorf("expected non-nil err, got nil") + } + if err := c.Call("test.FailNoFile", &testArg{FilePayload: FilePayload{Files: []*os.File{os.Stdin, os.Stdout, os.Stdin}}}, &r); err != nil { + t.Errorf("expected nil err, got %v", err) + } +} + +func TestRecvFile(t *testing.T) { + c, err := testClient() + if err != nil { + t.Fatalf("error creating test client: %v", err) + } + defer c.Close() + + var r testResult + if err := c.Call("test.SendFile", &testArg{}, &r); err != nil { + t.Errorf("expected nil err, got %v", err) + } + if r.Files == nil { + t.Errorf("expected file, got nil") + } +} + +func TestShutdown(t *testing.T) { + serverSock, clientSock, err := unet.SocketPair(false) + if err != nil { + t.Fatalf("error creating test client: %v", err) + } + clientSock.Close() + + s := NewServer() + if err := s.Handle(serverSock); err == nil { + t.Errorf("expected non-nil err, got nil") + } +} + +func TestTooManyFiles(t *testing.T) { + c, err := testClient() + if err != nil { + t.Fatalf("error creating test client: %v", err) + } + defer c.Close() + + var r testResult + var a testArg + for i := 0; i <= maxFiles; i++ { + a.Files = append(a.Files, os.Stdin) + } + + // Client-side error. + if err := c.Call("test.Func", &a, &r); err != ErrTooManyFiles { + t.Errorf("expected ErrTooManyFiles, got %v", err) + } + + // Server-side error. + if err := c.Call("test.TooManyFiles", &testArg{}, &r); err == nil { + t.Errorf("expected non-nil err, got nil") + } else if err.Error() != "too many files" { + t.Errorf("expected too many files, got %v", err.Error()) + } +} |