diff options
Diffstat (limited to 'pkg/unet')
-rw-r--r-- | pkg/unet/BUILD | 26 | ||||
-rw-r--r-- | pkg/unet/unet.go | 569 | ||||
-rw-r--r-- | pkg/unet/unet_test.go | 693 | ||||
-rw-r--r-- | pkg/unet/unet_unsafe.go | 287 |
4 files changed, 1575 insertions, 0 deletions
diff --git a/pkg/unet/BUILD b/pkg/unet/BUILD new file mode 100644 index 000000000..e8e40315a --- /dev/null +++ b/pkg/unet/BUILD @@ -0,0 +1,26 @@ +package(licenses = ["notice"]) # Apache 2.0 + +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") + +go_library( + name = "unet", + srcs = [ + "unet.go", + "unet_unsafe.go", + ], + importpath = "gvisor.googlesource.com/gvisor/pkg/unet", + visibility = ["//visibility:public"], + deps = [ + "//pkg/abi/linux", + "//pkg/gate", + ], +) + +go_test( + name = "unet_test", + size = "small", + srcs = [ + "unet_test.go", + ], + embed = [":unet"], +) diff --git a/pkg/unet/unet.go b/pkg/unet/unet.go new file mode 100644 index 000000000..59b6c5568 --- /dev/null +++ b/pkg/unet/unet.go @@ -0,0 +1,569 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package unet provides a minimal net package based on Unix Domain Sockets. +// +// This does no pooling, and should only be used for a limited number of +// connections in a Go process. Don't use this package for arbitrary servers. +package unet + +import ( + "errors" + "sync/atomic" + "syscall" + + "gvisor.googlesource.com/gvisor/pkg/gate" +) + +// backlog is used for the listen request. +const backlog = 16 + +// errClosing is returned by wait if the Socket is in the process of closing. +var errClosing = errors.New("Socket is closing") + +// errMessageTruncated indicates that data was lost because the provided buffer +// was too small. +var errMessageTruncated = errors.New("message truncated") + +// socketType returns the appropriate type. +func socketType(packet bool) int { + if packet { + return syscall.SOCK_SEQPACKET + } + return syscall.SOCK_STREAM +} + +// socket creates a new host socket. +func socket(packet bool) (int, error) { + // Make a new socket. + fd, err := syscall.Socket(syscall.AF_UNIX, socketType(packet), 0) + if err != nil { + return 0, err + } + + return fd, nil +} + +// eventFD returns a new event FD with initial value 0. +func eventFD() (int, error) { + f, _, e := syscall.Syscall(syscall.SYS_EVENTFD2, 0, 0, 0) + if e != 0 { + return -1, e + } + return int(f), nil +} + +// Socket is a connected unix domain socket. +type Socket struct { + // gate protects use of fd. + gate gate.Gate + + // fd is the bound socket. + // + // fd must be read atomically, and only remains valid if read while + // within gate. + fd int32 + + // efd is an event FD that is signaled when the socket is closing. + // + // efd is immutable and remains valid until Close/Release. + efd int + + // race is an atomic variable used to avoid triggering the race + // detector. See comment in SocketPair below. + race *int32 +} + +// NewSocket returns a socket from an existing FD. +// +// NewSocket takes ownership of fd. +func NewSocket(fd int) (*Socket, error) { + // fd must be non-blocking for non-blocking syscall.Accept in + // ServerSocket.Accept. + if err := syscall.SetNonblock(fd, true); err != nil { + return nil, err + } + + efd, err := eventFD() + if err != nil { + return nil, err + } + + return &Socket{ + fd: int32(fd), + efd: efd, + }, nil +} + +// finish completes use of s.fd by evicting any waiters, closing the gate, and +// closing the event FD. +func (s *Socket) finish() error { + // Signal any blocked or future polls. + // + // N.B. eventfd writes must be 8 bytes. + if _, err := syscall.Write(s.efd, []byte{1, 0, 0, 0, 0, 0, 0, 0}); err != nil { + return err + } + + // Close the gate, blocking until all FD users leave. + s.gate.Close() + + return syscall.Close(s.efd) +} + +// Close closes the socket. +func (s *Socket) Close() error { + // Set the FD in the socket to -1, to ensure that all future calls to + // FD/Release get nothing and Close calls return immediately. + fd := int(atomic.SwapInt32(&s.fd, -1)) + if fd < 0 { + // Already closed or closing. + return syscall.EBADF + } + + // Shutdown the socket to cancel any pending accepts. + s.shutdown(fd) + + if err := s.finish(); err != nil { + return err + } + + return syscall.Close(fd) +} + +// Release releases ownership of the socket FD. +// +// The returned FD is non-blocking. +// +// Any concurrent or future callers of Socket methods will receive EBADF. +func (s *Socket) Release() (int, error) { + // Set the FD in the socket to -1, to ensure that all future calls to + // FD/Release get nothing and Close calls return immediately. + fd := int(atomic.SwapInt32(&s.fd, -1)) + if fd < 0 { + // Already closed or closing. + return -1, syscall.EBADF + } + + if err := s.finish(); err != nil { + return -1, err + } + + return fd, nil +} + +// FD returns the FD for this Socket. +// +// The FD is non-blocking and must not be made blocking. +// +// N.B. os.File.Fd makes the FD blocking. Use of Release instead of FD is +// strongly preferred. +// +// The returned FD cannot be used safely if there may be concurrent callers to +// Close or Release. +// +// Use Release to take ownership of the FD. +func (s *Socket) FD() int { + return int(atomic.LoadInt32(&s.fd)) +} + +// enterFD enters the FD gate and returns the FD value. +// +// If enterFD returns ok, s.gate.Leave must be called when done with the FD. +// Callers may only block while within the gate using s.wait. +// +// The returned FD is guaranteed to remain valid until s.gate.Leave. +func (s *Socket) enterFD() (int, bool) { + if !s.gate.Enter() { + return -1, false + } + + fd := int(atomic.LoadInt32(&s.fd)) + if fd < 0 { + s.gate.Leave() + return -1, false + } + + return fd, true +} + +// SocketPair creates a pair of connected sockets. +func SocketPair(packet bool) (*Socket, *Socket, error) { + // Make a new pair. + fds, err := syscall.Socketpair(syscall.AF_UNIX, socketType(packet), 0) + if err != nil { + return nil, nil, err + } + + // race is an atomic variable used to avoid triggering the race + // detector. We have to fool TSAN into thinking there is a race + // variable between our two sockets. We only use SocketPair in tests + // anyway. + // + // NOTE: This is purely due to the fact that the raw + // syscall does not serve as a boundary for the sanitizer. + var race int32 + a, err := NewSocket(fds[0]) + if err != nil { + syscall.Close(fds[0]) + syscall.Close(fds[1]) + return nil, nil, err + } + a.race = &race + b, err := NewSocket(fds[1]) + if err != nil { + a.Close() + syscall.Close(fds[1]) + return nil, nil, err + } + b.race = &race + return a, b, nil +} + +// Connect connects to a server. +func Connect(addr string, packet bool) (*Socket, error) { + fd, err := socket(packet) + if err != nil { + return nil, err + } + + // Connect the socket. + usa := &syscall.SockaddrUnix{Name: addr} + if err := syscall.Connect(fd, usa); err != nil { + syscall.Close(fd) + return nil, err + } + + return NewSocket(fd) +} + +// ControlMessage wraps around a byte array and provides functions for parsing +// as a Unix Domain Socket control message. +type ControlMessage []byte + +// EnableFDs enables receiving FDs via control message. +// +// This guarantees only a MINIMUM number of FDs received. You may receive MORE +// than this due to the way FDs are packed. To be specific, the number of +// receivable buffers will be rounded up to the nearest even number. +// +// This must be called prior to ReadVec if you want to receive FDs. +func (c *ControlMessage) EnableFDs(count int) { + *c = make([]byte, syscall.CmsgSpace(count*4)) +} + +// ExtractFDs returns the list of FDs in the control message. +// +// Either this or CloseFDs should be used after EnableFDs. +func (c *ControlMessage) ExtractFDs() ([]int, error) { + msgs, err := syscall.ParseSocketControlMessage(*c) + if err != nil { + return nil, err + } + var fds []int + for _, msg := range msgs { + thisFds, err := syscall.ParseUnixRights(&msg) + if err != nil { + // Different control message. + return nil, err + } + for _, fd := range thisFds { + if fd >= 0 { + fds = append(fds, fd) + } + } + } + return fds, nil +} + +// CloseFDs closes the list of FDs in the control message. +// +// Either this or ExtractFDs should be used after EnableFDs. +func (c *ControlMessage) CloseFDs() { + fds, _ := c.ExtractFDs() + for _, fd := range fds { + if fd >= 0 { + syscall.Close(fd) + } + } +} + +// PackFDs packs the given list of FDs in the control message. +// +// This must be used prior to WriteVec. +func (c *ControlMessage) PackFDs(fds ...int) { + *c = ControlMessage(syscall.UnixRights(fds...)) +} + +// UnpackFDs clears the control message. +func (c *ControlMessage) UnpackFDs() { + *c = nil +} + +// SocketWriter wraps an individual send operation. +// +// The normal entrypoint is WriteVec. +type SocketWriter struct { + socket *Socket + to []byte + blocking bool + race *int32 + + ControlMessage +} + +// Writer returns a writer for this socket. +func (s *Socket) Writer(blocking bool) SocketWriter { + return SocketWriter{socket: s, blocking: blocking, race: s.race} +} + +// Write implements io.Writer.Write. +func (s *Socket) Write(p []byte) (int, error) { + r := s.Writer(true) + return r.WriteVec([][]byte{p}) +} + +// GetSockOpt gets the given socket option. +func (s *Socket) GetSockOpt(level int, name int, b []byte) (uint32, error) { + fd, ok := s.enterFD() + if !ok { + return 0, syscall.EBADF + } + defer s.gate.Leave() + + return getsockopt(fd, level, name, b) +} + +// SetSockOpt sets the given socket option. +func (s *Socket) SetSockOpt(level, name int, b []byte) error { + fd, ok := s.enterFD() + if !ok { + return syscall.EBADF + } + defer s.gate.Leave() + + return setsockopt(fd, level, name, b) +} + +// GetSockName returns the socket name. +func (s *Socket) GetSockName() ([]byte, error) { + fd, ok := s.enterFD() + if !ok { + return nil, syscall.EBADF + } + defer s.gate.Leave() + + var buf []byte + l := syscall.SizeofSockaddrAny + + for { + // If the buffer is not large enough, allocate a new one with the hint. + buf = make([]byte, l) + l, err := getsockname(fd, buf) + if err != nil { + return nil, err + } + + if l <= uint32(len(buf)) { + return buf[:l], nil + } + } +} + +// GetPeerName returns the peer name. +func (s *Socket) GetPeerName() ([]byte, error) { + fd, ok := s.enterFD() + if !ok { + return nil, syscall.EBADF + } + defer s.gate.Leave() + + var buf []byte + l := syscall.SizeofSockaddrAny + + for { + // See above. + buf = make([]byte, l) + l, err := getpeername(fd, buf) + if err != nil { + return nil, err + } + + if l <= uint32(len(buf)) { + return buf[:l], nil + } + } +} + +// GetPeerCred returns the peer's unix credentials. +func (s *Socket) GetPeerCred() (*syscall.Ucred, error) { + fd, ok := s.enterFD() + if !ok { + return nil, syscall.EBADF + } + defer s.gate.Leave() + + return syscall.GetsockoptUcred(fd, syscall.SOL_SOCKET, syscall.SO_PEERCRED) +} + +// SocketReader wraps an individual receive operation. +// +// This may be used for doing vectorized reads and/or sending additional +// control messages (e.g. FDs). The normal entrypoint is ReadVec. +// +// One of ExtractFDs or DisposeFDs must be called if EnableFDs is used. +type SocketReader struct { + socket *Socket + source []byte + blocking bool + race *int32 + + ControlMessage +} + +// Reader returns a reader for this socket. +func (s *Socket) Reader(blocking bool) SocketReader { + return SocketReader{socket: s, blocking: blocking, race: s.race} +} + +// Read implements io.Reader.Read. +func (s *Socket) Read(p []byte) (int, error) { + r := s.Reader(true) + return r.ReadVec([][]byte{p}) +} + +func (s *Socket) shutdown(fd int) error { + // Shutdown the socket to cancel any pending accepts. + return syscall.Shutdown(fd, syscall.SHUT_RDWR) +} + +// Shutdown closes the socket for read and write. +func (s *Socket) Shutdown() error { + fd, ok := s.enterFD() + if !ok { + return syscall.EBADF + } + defer s.gate.Leave() + + return s.shutdown(fd) +} + +// ServerSocket is a bound unix domain socket. +type ServerSocket struct { + socket *Socket +} + +// NewServerSocket returns a socket from an existing FD. +func NewServerSocket(fd int) (*ServerSocket, error) { + s, err := NewSocket(fd) + if err != nil { + return nil, err + } + return &ServerSocket{socket: s}, nil +} + +// Bind creates and binds a new socket. +func Bind(addr string, packet bool) (*ServerSocket, error) { + fd, err := socket(packet) + if err != nil { + return nil, err + } + + // Do the bind. + usa := &syscall.SockaddrUnix{Name: addr} + if err := syscall.Bind(fd, usa); err != nil { + syscall.Close(fd) + return nil, err + } + + return NewServerSocket(fd) +} + +// BindAndListen creates, binds and listens on a new socket. +func BindAndListen(addr string, packet bool) (*ServerSocket, error) { + s, err := Bind(addr, packet) + if err != nil { + return nil, err + } + + // Start listening. + if err := s.Listen(); err != nil { + s.Close() + return nil, err + } + + return s, nil +} + +// Listen starts listening on the socket. +func (s *ServerSocket) Listen() error { + fd, ok := s.socket.enterFD() + if !ok { + return syscall.EBADF + } + defer s.socket.gate.Leave() + + return syscall.Listen(fd, backlog) +} + +// Accept accepts a new connection. +// +// This is always blocking. +// +// Preconditions: +// * ServerSocket is listening (Listen called). +func (s *ServerSocket) Accept() (*Socket, error) { + fd, ok := s.socket.enterFD() + if !ok { + return nil, syscall.EBADF + } + defer s.socket.gate.Leave() + + for { + nfd, _, err := syscall.Accept(fd) + switch err { + case nil: + return NewSocket(nfd) + case syscall.EAGAIN: + err = s.socket.wait(false) + if err == errClosing { + err = syscall.EBADF + } + } + if err != nil { + return nil, err + } + } +} + +// Close closes the server socket. +// +// This must only be called once. +func (s *ServerSocket) Close() error { + return s.socket.Close() +} + +// FD returns the socket's file descriptor. +// +// See Socket.FD. +func (s *ServerSocket) FD() int { + return s.socket.FD() +} + +// Release releases ownership of the socket's file descriptor. +// +// See Socket.Release. +func (s *ServerSocket) Release() (int, error) { + return s.socket.Release() +} diff --git a/pkg/unet/unet_test.go b/pkg/unet/unet_test.go new file mode 100644 index 000000000..6c546825f --- /dev/null +++ b/pkg/unet/unet_test.go @@ -0,0 +1,693 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package unet + +import ( + "io/ioutil" + "os" + "path/filepath" + "reflect" + "sync" + "syscall" + "testing" + "time" +) + +func randomFilename() (string, error) { + // Return a randomly generated file in the test dir. + f, err := ioutil.TempFile("", "unet-test") + if err != nil { + return "", err + } + file := f.Name() + os.Remove(file) + f.Close() + + cwd, err := os.Getwd() + if err != nil { + return "", err + } + + // NOTE: We try to use relative path if possible. This is + // to help conforming to the unix path length limit. + if rel, err := filepath.Rel(cwd, file); err == nil { + return rel, nil + } + + return file, nil +} + +func TestConnectFailure(t *testing.T) { + name, err := randomFilename() + if err != nil { + t.Fatalf("unable to generate file, got err %v expected nil", err) + } + + if _, err := Connect(name, false); err == nil { + t.Fatalf("connect was successful, expected err") + } +} + +func TestBindFailure(t *testing.T) { + name, err := randomFilename() + if err != nil { + t.Fatalf("unable to generate file, got err %v expected nil", err) + } + + ss, err := BindAndListen(name, false) + if err != nil { + t.Fatalf("first bind failed, got err %v expected nil", err) + } + defer ss.Close() + + if _, err = BindAndListen(name, false); err == nil { + t.Fatalf("second bind succeeded, expected non-nil err") + } +} + +func TestMultipleAccept(t *testing.T) { + name, err := randomFilename() + if err != nil { + t.Fatalf("unable to generate file, got err %v expected nil", err) + } + + ss, err := BindAndListen(name, false) + if err != nil { + t.Fatalf("first bind failed, got err %v expected nil", err) + } + defer ss.Close() + + // Connect backlog times asynchronously. + var wg sync.WaitGroup + defer wg.Wait() + for i := 0; i < backlog; i++ { + wg.Add(1) + go func() { + defer wg.Done() + s, err := Connect(name, false) + if err != nil { + t.Fatalf("connect failed, got err %v expected nil", err) + } + s.Close() + }() + } + + // Accept backlog times. + for i := 0; i < backlog; i++ { + s, err := ss.Accept() + if err != nil { + t.Errorf("accept failed, got err %v expected nil", err) + continue + } + s.Close() + } +} + +func TestServerClose(t *testing.T) { + name, err := randomFilename() + if err != nil { + t.Fatalf("unable to generate file, got err %v expected nil", err) + } + + ss, err := BindAndListen(name, false) + if err != nil { + t.Fatalf("first bind failed, got err %v expected nil", err) + } + + // Make sure the first close succeeds. + if err := ss.Close(); err != nil { + t.Fatalf("first close failed, got err %v expected nil", err) + } + + // The second one should fail. + if err := ss.Close(); err == nil { + t.Fatalf("second close succeeded, expected non-nil err") + } +} + +func socketPair(t *testing.T, packet bool) (*Socket, *Socket) { + name, err := randomFilename() + if err != nil { + t.Fatalf("unable to generate file, got err %v expected nil", err) + } + + // Bind a server. + ss, err := BindAndListen(name, packet) + if err != nil { + t.Fatalf("error binding, got %v expected nil", err) + } + defer ss.Close() + + // Accept a client. + acceptSocket := make(chan *Socket) + acceptErr := make(chan error) + go func() { + server, err := ss.Accept() + if err != nil { + acceptErr <- err + } + acceptSocket <- server + }() + + // Connect the client. + client, err := Connect(name, packet) + if err != nil { + t.Fatalf("error connecting, got %v expected nil", err) + } + + // Grab the server handle. + select { + case server := <-acceptSocket: + return server, client + case err := <-acceptErr: + t.Fatalf("accept error: %v", err) + } + panic("unreachable") +} + +func TestSendRecv(t *testing.T) { + server, client := socketPair(t, false) + defer server.Close() + defer client.Close() + + // Write on the client. + w := client.Writer(true) + if n, err := w.WriteVec([][]byte{{'a'}}); n != 1 || err != nil { + t.Fatalf("for client write, got n=%d err=%v, expected n=1 err=nil", n, err) + } + + // Read on the server. + b := [][]byte{{'b'}} + r := server.Reader(true) + if n, err := r.ReadVec(b); n != 1 || err != nil { + t.Fatalf("for server read, got n=%d err=%v, expected n=1 err=nil", n, err) + } + if b[0][0] != 'a' { + t.Fatalf("got bad read data, got %c, expected a", b[0][0]) + } +} + +// TestSymmetric exists to assert that the two sockets received from socketPair +// are interchangeable. They should be, this just provides a basic sanity check +// by running TestSendRecv "backwards". +func TestSymmetric(t *testing.T) { + server, client := socketPair(t, false) + defer server.Close() + defer client.Close() + + // Write on the server. + w := server.Writer(true) + if n, err := w.WriteVec([][]byte{{'a'}}); n != 1 || err != nil { + t.Fatalf("for server write, got n=%d err=%v, expected n=1 err=nil", n, err) + } + + // Read on the client. + b := [][]byte{{'b'}} + r := client.Reader(true) + if n, err := r.ReadVec(b); n != 1 || err != nil { + t.Fatalf("for client read, got n=%d err=%v, expected n=1 err=nil", n, err) + } + if b[0][0] != 'a' { + t.Fatalf("got bad read data, got %c, expected a", b[0][0]) + } +} + +func TestPacket(t *testing.T) { + server, client := socketPair(t, true) + defer server.Close() + defer client.Close() + + // Write on the client. + w := client.Writer(true) + if n, err := w.WriteVec([][]byte{{'a'}}); n != 1 || err != nil { + t.Fatalf("for client write, got n=%d err=%v, expected n=1 err=nil", n, err) + } + + // Write on the client again. + w = client.Writer(true) + if n, err := w.WriteVec([][]byte{{'a'}}); n != 1 || err != nil { + t.Fatalf("for client write, got n=%d err=%v, expected n=1 err=nil", n, err) + } + + // Read on the server. + // + // This should only get back a single byte, despite the buffer + // being size two. This is because it's a _packet_ buffer. + b := [][]byte{{'b', 'b'}} + r := server.Reader(true) + if n, err := r.ReadVec(b); n != 1 || err != nil { + t.Fatalf("for server read, got n=%d err=%v, expected n=1 err=nil", n, err) + } + if b[0][0] != 'a' { + t.Fatalf("got bad read data, got %c, expected a", b[0][0]) + } + + // Do it again. + r = server.Reader(true) + if n, err := r.ReadVec(b); n != 1 || err != nil { + t.Fatalf("for server read, got n=%d err=%v, expected n=1 err=nil", n, err) + } + if b[0][0] != 'a' { + t.Fatalf("got bad read data, got %c, expected a", b[0][0]) + } +} + +func TestClose(t *testing.T) { + server, client := socketPair(t, false) + defer server.Close() + + // Make sure the first close succeeds. + if err := client.Close(); err != nil { + t.Fatalf("first close failed, got err %v expected nil", err) + } + + // The second one should fail. + if err := client.Close(); err == nil { + t.Fatalf("second close succeeded, expected non-nil err") + } +} + +func TestNonBlockingSend(t *testing.T) { + server, client := socketPair(t, false) + defer server.Close() + defer client.Close() + + // Try up to 1000 writes, of 1000 bytes. + blockCount := 0 + for i := 0; i < 1000; i++ { + w := client.Writer(false) + if n, err := w.WriteVec([][]byte{make([]byte, 1000)}); n != 1000 || err != nil { + if err == syscall.EWOULDBLOCK || err == syscall.EAGAIN { + // We're good. That's what we wanted. + blockCount++ + } else { + t.Fatalf("for client write, got n=%d err=%v, expected n=1000 err=nil", n, err) + } + } + } + + if blockCount == 1000 { + // Shouldn't have _always_ blocked. + t.Fatalf("socket always blocked!") + } else if blockCount == 0 { + // Should have started blocking eventually. + t.Fatalf("socket never blocked!") + } +} + +func TestNonBlockingRecv(t *testing.T) { + server, client := socketPair(t, false) + defer server.Close() + defer client.Close() + + b := [][]byte{{'b'}} + r := client.Reader(false) + + // Expected to block immediately. + _, err := r.ReadVec(b) + if err != syscall.EWOULDBLOCK && err != syscall.EAGAIN { + t.Fatalf("read didn't block, got err %v expected blocking err", err) + } + + // Put some data in the pipe. + w := server.Writer(false) + if n, err := w.WriteVec(b); n != 1 || err != nil { + t.Fatalf("write failed with n=%d err=%v, expected n=1 err=nil", n, err) + } + + // Expect it not to block. + if n, err := r.ReadVec(b); n != 1 || err != nil { + t.Fatalf("read failed with n=%d err=%v, expected n=1 err=nil", n, err) + } + + // Expect it to return a block error again. + r = client.Reader(false) + _, err = r.ReadVec(b) + if err != syscall.EWOULDBLOCK && err != syscall.EAGAIN { + t.Fatalf("read didn't block, got err %v expected blocking err", err) + } +} + +func TestRecvVectors(t *testing.T) { + server, client := socketPair(t, false) + defer server.Close() + defer client.Close() + + // Write on the client. + w := client.Writer(true) + if n, err := w.WriteVec([][]byte{{'a', 'b'}}); n != 2 || err != nil { + t.Fatalf("for client write, got n=%d err=%v, expected n=2 err=nil", n, err) + } + + // Read on the server. + b := [][]byte{{'c'}, {'c'}} + r := server.Reader(true) + if n, err := r.ReadVec(b); n != 2 || err != nil { + t.Fatalf("for server read, got n=%d err=%v, expected n=2 err=nil", n, err) + } + if b[0][0] != 'a' || b[1][0] != 'b' { + t.Fatalf("got bad read data, got %c,%c, expected a,b", b[0][0], b[1][0]) + } +} + +func TestSendVectors(t *testing.T) { + server, client := socketPair(t, false) + defer server.Close() + defer client.Close() + + // Write on the client. + w := client.Writer(true) + if n, err := w.WriteVec([][]byte{{'a'}, {'b'}}); n != 2 || err != nil { + t.Fatalf("for client write, got n=%d err=%v, expected n=2 err=nil", n, err) + } + + // Read on the server. + b := [][]byte{{'c', 'c'}} + r := server.Reader(true) + if n, err := r.ReadVec(b); n != 2 || err != nil { + t.Fatalf("for server read, got n=%d err=%v, expected n=2 err=nil", n, err) + } + if b[0][0] != 'a' || b[0][1] != 'b' { + t.Fatalf("got bad read data, got %c,%c, expected a,b", b[0][0], b[0][1]) + } +} + +func TestSendFDsNotEnabled(t *testing.T) { + server, client := socketPair(t, false) + defer server.Close() + defer client.Close() + + // Write on the server. + w := server.Writer(true) + w.PackFDs(0, 1, 2) + if n, err := w.WriteVec([][]byte{{'a'}}); n != 1 || err != nil { + t.Fatalf("for server write, got n=%d err=%v, expected n=1 err=nil", n, err) + } + + // Read on the client, without enabling FDs. + b := [][]byte{{'b'}} + r := client.Reader(true) + if n, err := r.ReadVec(b); n != 1 || err != nil { + t.Fatalf("for client read, got n=%d err=%v, expected n=1 err=nil", n, err) + } + if b[0][0] != 'a' { + t.Fatalf("got bad read data, got %c, expected a", b[0][0]) + } + + // Make sure the FDs are not received. + fds, err := r.ExtractFDs() + if len(fds) != 0 || err != nil { + t.Fatalf("got fds=%v err=%v, expected len(fds)=0 err=nil", fds, err) + } +} + +func sendFDs(t *testing.T, s *Socket, fds []int) { + w := s.Writer(true) + w.PackFDs(fds...) + if n, err := w.WriteVec([][]byte{{'a'}}); n != 1 || err != nil { + t.Fatalf("for write, got n=%d err=%v, expected n=1 err=nil", n, err) + } +} + +func recvFDs(t *testing.T, s *Socket, enableSize int, origFDs []int) { + expected := len(origFDs) + + // Count the number of FDs. + preEntries, err := ioutil.ReadDir("/proc/self/fd") + if err != nil { + t.Fatalf("can't readdir, got err %v expected nil", err) + } + + // Read on the client. + b := [][]byte{{'b'}} + r := s.Reader(true) + if enableSize >= 0 { + r.EnableFDs(enableSize) + } + if n, err := r.ReadVec(b); n != 1 || err != nil { + t.Fatalf("for client read, got n=%d err=%v, expected n=1 err=nil", n, err) + } + if b[0][0] != 'a' { + t.Fatalf("got bad read data, got %c, expected a", b[0][0]) + } + + // Count the new number of FDs. + postEntries, err := ioutil.ReadDir("/proc/self/fd") + if err != nil { + t.Fatalf("can't readdir, got err %v expected nil", err) + } + if len(preEntries)+expected != len(postEntries) { + t.Errorf("process fd count isn't right, expected %d got %d", len(preEntries)+expected, len(postEntries)) + } + + // Make sure the FDs are there. + fds, err := r.ExtractFDs() + if len(fds) != expected || err != nil { + t.Fatalf("got fds=%v err=%v, expected len(fds)=%d err=nil", fds, err, expected) + } + + // Make sure they are different from the originals. + for i := 0; i < len(fds); i++ { + if fds[i] == origFDs[i] { + t.Errorf("got original fd for index %d, expected different", i) + } + } + + // Make sure they can be accessed as expected. + for i := 0; i < len(fds); i++ { + var st syscall.Stat_t + if err := syscall.Fstat(fds[i], &st); err != nil { + t.Errorf("fds[%d] can't be stated, got err %v expected nil", i, err) + } + } + + // Close them off. + r.CloseFDs() + + // Make sure the count is back to normal. + finalEntries, err := ioutil.ReadDir("/proc/self/fd") + if err != nil { + t.Fatalf("can't readdir, got err %v expected nil", err) + } + if len(finalEntries) != len(preEntries) { + t.Errorf("process fd count isn't right, expected %d got %d", len(preEntries), len(finalEntries)) + } +} + +func TestFDsSingle(t *testing.T) { + server, client := socketPair(t, false) + defer server.Close() + defer client.Close() + + sendFDs(t, server, []int{0}) + recvFDs(t, client, 1, []int{0}) +} + +func TestFDsMultiple(t *testing.T) { + server, client := socketPair(t, false) + defer server.Close() + defer client.Close() + + // Basic case, multiple FDs. + sendFDs(t, server, []int{0, 1, 2}) + recvFDs(t, client, 3, []int{0, 1, 2}) +} + +// See TestSymmetric above. +func TestFDsSymmetric(t *testing.T) { + server, client := socketPair(t, false) + defer server.Close() + defer client.Close() + + sendFDs(t, server, []int{0, 1, 2}) + recvFDs(t, client, 3, []int{0, 1, 2}) +} + +func TestFDsReceiveLargeBuffer(t *testing.T) { + server, client := socketPair(t, false) + defer server.Close() + defer client.Close() + + sendFDs(t, server, []int{0}) + recvFDs(t, client, 3, []int{0}) +} + +func TestFDsReceiveSmallBuffer(t *testing.T) { + server, client := socketPair(t, false) + defer server.Close() + defer client.Close() + + sendFDs(t, server, []int{0, 1, 2}) + + // Per the spec, we may still receive more than the buffer. In fact, + // it'll be rounded up and we can receive two with a size one buffer. + recvFDs(t, client, 1, []int{0, 1}) +} + +func TestFDsReceiveNotEnabled(t *testing.T) { + server, client := socketPair(t, false) + defer server.Close() + defer client.Close() + + sendFDs(t, server, []int{0}) + recvFDs(t, client, -1, []int{}) +} + +func TestFDsReceiveSizeZero(t *testing.T) { + server, client := socketPair(t, false) + defer server.Close() + defer client.Close() + + sendFDs(t, server, []int{0}) + recvFDs(t, client, 0, []int{}) +} + +func TestGetPeerCred(t *testing.T) { + server, client := socketPair(t, false) + defer server.Close() + defer client.Close() + + want := &syscall.Ucred{ + Pid: int32(os.Getpid()), + Uid: uint32(os.Getuid()), + Gid: uint32(os.Getgid()), + } + + if got, err := client.GetPeerCred(); err != nil || !reflect.DeepEqual(got, want) { + t.Errorf("got GetPeerCred() = %v, %v, want = %+v, %+v", got, err, want, nil) + } +} + +func newClosedSocket() (*Socket, error) { + fd, err := syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM, 0) + if err != nil { + return nil, err + } + + s, err := NewSocket(fd) + if err != nil { + syscall.Close(fd) + return nil, err + } + + return s, s.Close() +} + +func TestGetPeerCredFailure(t *testing.T) { + s, err := newClosedSocket() + if err != nil { + t.Fatalf("newClosedSocket got error %v want nil", err) + } + + want := "bad file descriptor" + if _, err := s.GetPeerCred(); err == nil || err.Error() != want { + t.Errorf("got s.GetPeerCred() = %v, want = %s", err, want) + } +} + +func TestAcceptClosed(t *testing.T) { + name, err := randomFilename() + if err != nil { + t.Fatalf("unable to generate file, got err %v expected nil", err) + } + + ss, err := BindAndListen(name, false) + if err != nil { + t.Fatalf("bind failed, got err %v expected nil", err) + } + + if err := ss.Close(); err != nil { + t.Fatalf("close failed, got err %v expected nil", err) + } + + if _, err := ss.Accept(); err == nil { + t.Errorf("accept on closed SocketServer, got err %v, want != nil", err) + } +} + +func TestCloseAfterAcceptStart(t *testing.T) { + name, err := randomFilename() + if err != nil { + t.Fatalf("unable to generate file, got err %v expected nil", err) + } + + ss, err := BindAndListen(name, false) + if err != nil { + t.Fatalf("bind failed, got err %v expected nil", err) + } + + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + time.Sleep(50 * time.Millisecond) + if err := ss.Close(); err != nil { + t.Fatalf("close failed, got err %v expected nil", err) + } + wg.Done() + }() + + if _, err := ss.Accept(); err == nil { + t.Errorf("accept on closed SocketServer, got err %v, want != nil", err) + } + + wg.Wait() +} + +func TestReleaseAfterAcceptStart(t *testing.T) { + name, err := randomFilename() + if err != nil { + t.Fatalf("unable to generate file, got err %v expected nil", err) + } + + ss, err := BindAndListen(name, false) + if err != nil { + t.Fatalf("bind failed, got err %v expected nil", err) + } + + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + time.Sleep(50 * time.Millisecond) + fd, err := ss.Release() + if err != nil { + t.Fatalf("Release failed, got err %v expected nil", err) + } + syscall.Close(fd) + wg.Done() + }() + + if _, err := ss.Accept(); err == nil { + t.Errorf("accept on closed SocketServer, got err %v, want != nil", err) + } + + wg.Wait() +} + +func TestControlMessage(t *testing.T) { + for i := 0; i <= 10; i++ { + var want []int + for j := 0; j < i; j++ { + want = append(want, i+j+1) + } + + var cm ControlMessage + cm.EnableFDs(i) + cm.PackFDs(want...) + got, err := cm.ExtractFDs() + if err != nil || !reflect.DeepEqual(got, want) { + t.Errorf("got cm.ExtractFDs() = %v, %v, want = %v, %v", got, err, want, nil) + } + } +} diff --git a/pkg/unet/unet_unsafe.go b/pkg/unet/unet_unsafe.go new file mode 100644 index 000000000..fa15cf744 --- /dev/null +++ b/pkg/unet/unet_unsafe.go @@ -0,0 +1,287 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package unet + +import ( + "io" + "math" + "sync/atomic" + "syscall" + "unsafe" + + "gvisor.googlesource.com/gvisor/pkg/abi/linux" +) + +// wait blocks until the socket FD is ready for reading or writing, depending +// on the value of write. +// +// Returns errClosing if the Socket is in the process of closing. +func (s *Socket) wait(write bool) error { + for { + // Checking the FD on each loop is not strictly necessary, it + // just avoids an extra poll call. + fd := atomic.LoadInt32(&s.fd) + if fd < 0 { + return errClosing + } + + events := []linux.PollFD{ + { + // The actual socket FD. + FD: fd, + Events: linux.POLLIN, + }, + { + // The eventfd, signaled when we are closing. + FD: int32(s.efd), + Events: linux.POLLIN, + }, + } + if write { + events[0].Events = linux.POLLOUT + } + + _, _, e := syscall.Syscall(syscall.SYS_POLL, uintptr(unsafe.Pointer(&events[0])), 2, uintptr(math.MaxUint64)) + if e == syscall.EINTR { + continue + } + if e != 0 { + return e + } + + if events[1].REvents&linux.POLLIN == linux.POLLIN { + // eventfd signaled, we're closing. + return errClosing + } + + return nil + } +} + +// buildIovec builds an iovec slice from the given []byte slice. +// +// iovecs is used as an initial slice, to avoid excessive allocations. +func buildIovec(bufs [][]byte, iovecs []syscall.Iovec) ([]syscall.Iovec, int) { + var length int + for i := range bufs { + if l := len(bufs[i]); l > 0 { + iovecs = append(iovecs, syscall.Iovec{ + Base: &bufs[i][0], + Len: uint64(l), + }) + length += l + } + } + return iovecs, length +} + +// ReadVec implements vecio.Reader.ReadVec. +// +// This function is not guaranteed to read all available data, it +// returns as soon as a single recvmsg call succeeds. +func (r *SocketReader) ReadVec(bufs [][]byte) (int, error) { + iovecs, length := buildIovec(bufs, make([]syscall.Iovec, 0, 2)) + + var msg syscall.Msghdr + if len(r.source) != 0 { + msg.Name = &r.source[0] + msg.Namelen = uint32(len(r.source)) + } + + if len(r.ControlMessage) != 0 { + msg.Control = &r.ControlMessage[0] + msg.Controllen = uint64(len(r.ControlMessage)) + } + + if len(iovecs) != 0 { + msg.Iov = &iovecs[0] + msg.Iovlen = uint64(len(iovecs)) + } + + // n is the bytes received. + var n uintptr + + fd, ok := r.socket.enterFD() + if !ok { + return 0, syscall.EBADF + } + // Leave on returns below. + for { + var e syscall.Errno + + // Try a non-blocking recv first, so we don't give up the go runtime M. + n, _, e = syscall.RawSyscall(syscall.SYS_RECVMSG, uintptr(fd), uintptr(unsafe.Pointer(&msg)), syscall.MSG_DONTWAIT|syscall.MSG_TRUNC) + if e == 0 { + break + } + if e == syscall.EINTR { + continue + } + if !r.blocking { + r.socket.gate.Leave() + return 0, e + } + if e != syscall.EAGAIN && e != syscall.EWOULDBLOCK { + r.socket.gate.Leave() + return 0, e + } + + // Wait for the socket to become readable. + err := r.socket.wait(false) + if err == errClosing { + err = syscall.EBADF + } + if err != nil { + r.socket.gate.Leave() + return 0, err + } + } + + r.socket.gate.Leave() + + if msg.Controllen < uint64(len(r.ControlMessage)) { + r.ControlMessage = r.ControlMessage[:msg.Controllen] + } + + if msg.Namelen < uint32(len(r.source)) { + r.source = r.source[:msg.Namelen] + } + + // All unet sockets are SOCK_STREAM or SOCK_SEQPACKET, both of which + // indicate that the other end is closed by returning a 0-length read + // with no error. + if n == 0 { + return 0, io.EOF + } + + if r.race != nil { + // See comments on Socket.race. + atomic.AddInt32(r.race, 1) + } + + if int(n) > length { + return length, errMessageTruncated + } + + return int(n), nil +} + +// WriteVec implements vecio.Writer.WriteVec. +// +// This function is not guaranteed to send all data, it returns +// as soon as a single sendmsg call succeeds. +func (w *SocketWriter) WriteVec(bufs [][]byte) (int, error) { + iovecs, _ := buildIovec(bufs, make([]syscall.Iovec, 0, 2)) + + if w.race != nil { + // See comments on Socket.race. + atomic.AddInt32(w.race, 1) + } + + var msg syscall.Msghdr + if len(w.to) != 0 { + msg.Name = &w.to[0] + msg.Namelen = uint32(len(w.to)) + } + + if len(w.ControlMessage) != 0 { + msg.Control = &w.ControlMessage[0] + msg.Controllen = uint64(len(w.ControlMessage)) + } + + if len(iovecs) > 0 { + msg.Iov = &iovecs[0] + msg.Iovlen = uint64(len(iovecs)) + } + + fd, ok := w.socket.enterFD() + if !ok { + return 0, syscall.EBADF + } + // Leave on returns below. + for { + // Try a non-blocking send first, so we don't give up the go runtime M. + n, _, e := syscall.RawSyscall(syscall.SYS_SENDMSG, uintptr(fd), uintptr(unsafe.Pointer(&msg)), syscall.MSG_DONTWAIT|syscall.MSG_NOSIGNAL) + if e == 0 { + w.socket.gate.Leave() + return int(n), nil + } + if e == syscall.EINTR { + continue + } + if !w.blocking { + w.socket.gate.Leave() + return 0, e + } + if e != syscall.EAGAIN && e != syscall.EWOULDBLOCK { + w.socket.gate.Leave() + return 0, e + } + + // Wait for the socket to become writeable. + err := w.socket.wait(true) + if err == errClosing { + err = syscall.EBADF + } + if err != nil { + w.socket.gate.Leave() + return 0, err + } + } + // Unreachable, no s.gate.Leave needed. +} + +// getsockopt issues a getsockopt syscall. +func getsockopt(fd int, level int, optname int, buf []byte) (uint32, error) { + l := uint32(len(buf)) + _, _, e := syscall.RawSyscall6(syscall.SYS_GETSOCKOPT, uintptr(fd), uintptr(level), uintptr(optname), uintptr(unsafe.Pointer(&buf[0])), uintptr(unsafe.Pointer(&l)), 0) + if e != 0 { + return 0, e + } + + return l, nil +} + +// setsockopt issues a setsockopt syscall. +func setsockopt(fd int, level int, optname int, buf []byte) error { + _, _, e := syscall.RawSyscall6(syscall.SYS_SETSOCKOPT, uintptr(fd), uintptr(level), uintptr(optname), uintptr(unsafe.Pointer(&buf[0])), uintptr(len(buf)), 0) + if e != 0 { + return e + } + + return nil +} + +// getsockname issues a getsockname syscall. +func getsockname(fd int, buf []byte) (uint32, error) { + l := uint32(len(buf)) + _, _, e := syscall.RawSyscall(syscall.SYS_GETSOCKNAME, uintptr(fd), uintptr(unsafe.Pointer(&buf[0])), uintptr(unsafe.Pointer(&l))) + if e != 0 { + return 0, e + } + + return l, nil +} + +// getpeername issues a getpeername syscall. +func getpeername(fd int, buf []byte) (uint32, error) { + l := uint32(len(buf)) + _, _, e := syscall.RawSyscall(syscall.SYS_GETPEERNAME, uintptr(fd), uintptr(unsafe.Pointer(&buf[0])), uintptr(unsafe.Pointer(&l))) + if e != 0 { + return 0, e + } + + return l, nil +} |