diff options
Diffstat (limited to 'pkg/lisafs/sock.go')
-rw-r--r-- | pkg/lisafs/sock.go | 208 |
1 files changed, 208 insertions, 0 deletions
diff --git a/pkg/lisafs/sock.go b/pkg/lisafs/sock.go new file mode 100644 index 000000000..88210242f --- /dev/null +++ b/pkg/lisafs/sock.go @@ -0,0 +1,208 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package lisafs + +import ( + "io" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/unet" +) + +var ( + sockHeaderLen = uint32((*sockHeader)(nil).SizeBytes()) +) + +// sockHeader is the header present in front of each message received on a UDS. +// +// +marshal +type sockHeader struct { + payloadLen uint32 + message MID + _ uint16 // Need to make struct packed. +} + +// sockCommunicator implements Communicator. This is not thread safe. +type sockCommunicator struct { + fdTracker + sock *unet.Socket + buf []byte +} + +var _ Communicator = (*sockCommunicator)(nil) + +func newSockComm(sock *unet.Socket) *sockCommunicator { + return &sockCommunicator{ + sock: sock, + buf: make([]byte, sockHeaderLen), + } +} + +func (s *sockCommunicator) FD() int { + return s.sock.FD() +} + +func (s *sockCommunicator) destroy() { + s.sock.Close() +} + +func (s *sockCommunicator) shutdown() { + if err := s.sock.Shutdown(); err != nil { + log.Warningf("Socket.Shutdown() failed (FD: %d): %v", s.sock.FD(), err) + } +} + +func (s *sockCommunicator) resizeBuf(size uint32) { + if cap(s.buf) < int(size) { + s.buf = s.buf[:cap(s.buf)] + s.buf = append(s.buf, make([]byte, int(size)-cap(s.buf))...) + } else { + s.buf = s.buf[:size] + } +} + +// PayloadBuf implements Communicator.PayloadBuf. +func (s *sockCommunicator) PayloadBuf(size uint32) []byte { + s.resizeBuf(sockHeaderLen + size) + return s.buf[sockHeaderLen : sockHeaderLen+size] +} + +// SndRcvMessage implements Communicator.SndRcvMessage. +func (s *sockCommunicator) SndRcvMessage(m MID, payloadLen uint32, wantFDs uint8) (MID, uint32, error) { + if err := s.sndPrepopulatedMsg(m, payloadLen, nil); err != nil { + return 0, 0, err + } + + return s.rcvMsg(wantFDs) +} + +// sndPrepopulatedMsg assumes that s.buf has already been populated with +// `payloadLen` bytes of data. +func (s *sockCommunicator) sndPrepopulatedMsg(m MID, payloadLen uint32, fds []int) error { + header := sockHeader{payloadLen: payloadLen, message: m} + header.MarshalUnsafe(s.buf) + dataLen := sockHeaderLen + payloadLen + return writeTo(s.sock, [][]byte{s.buf[:dataLen]}, int(dataLen), fds) +} + +// writeTo writes the passed iovec to the UDS and donates any passed FDs. +func writeTo(sock *unet.Socket, iovec [][]byte, dataLen int, fds []int) error { + w := sock.Writer(true) + if len(fds) > 0 { + w.PackFDs(fds...) + } + + fdsUnpacked := false + for n := 0; n < dataLen; { + cur, err := w.WriteVec(iovec) + if err != nil { + return err + } + n += cur + + // Fast common path. + if n >= dataLen { + break + } + + // Consume iovecs. + for consumed := 0; consumed < cur; { + if len(iovec[0]) <= cur-consumed { + consumed += len(iovec[0]) + iovec = iovec[1:] + } else { + iovec[0] = iovec[0][cur-consumed:] + break + } + } + + if n > 0 && !fdsUnpacked { + // Don't resend any control message. + fdsUnpacked = true + w.UnpackFDs() + } + } + return nil +} + +// rcvMsg reads the message header and payload from the UDS. It also populates +// fds with any donated FDs. +func (s *sockCommunicator) rcvMsg(wantFDs uint8) (MID, uint32, error) { + fds, err := readFrom(s.sock, s.buf[:sockHeaderLen], wantFDs) + if err != nil { + return 0, 0, err + } + for _, fd := range fds { + s.TrackFD(fd) + } + + var header sockHeader + header.UnmarshalUnsafe(s.buf) + + // No payload? We are done. + if header.payloadLen == 0 { + return header.message, 0, nil + } + + if _, err := readFrom(s.sock, s.PayloadBuf(header.payloadLen), 0); err != nil { + return 0, 0, err + } + + return header.message, header.payloadLen, nil +} + +// readFrom fills the passed buffer with data from the socket. It also returns +// any donated FDs. +func readFrom(sock *unet.Socket, buf []byte, wantFDs uint8) ([]int, error) { + r := sock.Reader(true) + r.EnableFDs(int(wantFDs)) + + var ( + fds []int + fdInit bool + ) + n := len(buf) + for got := 0; got < n; { + cur, err := r.ReadVec([][]byte{buf[got:]}) + + // Ignore EOF if cur > 0. + if err != nil && (err != io.EOF || cur == 0) { + r.CloseFDs() + return nil, err + } + + if !fdInit && cur > 0 { + fds, err = r.ExtractFDs() + if err != nil { + return nil, err + } + + fdInit = true + r.EnableFDs(0) + } + + got += cur + } + return fds, nil +} + +func closeFDs(fds []int) { + for _, fd := range fds { + if fd >= 0 { + unix.Close(fd) + } + } +} |