diff options
Diffstat (limited to 'pkg')
-rw-r--r-- | pkg/abi/linux/socket.go | 14 | ||||
-rw-r--r-- | pkg/sentry/fs/proc/net.go | 4 | ||||
-rw-r--r-- | pkg/sentry/socket/BUILD | 1 | ||||
-rw-r--r-- | pkg/sentry/socket/epsocket/epsocket.go | 26 | ||||
-rw-r--r-- | pkg/sentry/socket/hostinet/socket.go | 35 | ||||
-rw-r--r-- | pkg/sentry/socket/hostinet/socket_unsafe.go | 10 | ||||
-rw-r--r-- | pkg/sentry/socket/netlink/socket.go | 14 | ||||
-rw-r--r-- | pkg/sentry/socket/rpcinet/socket.go | 21 | ||||
-rw-r--r-- | pkg/sentry/socket/socket.go | 38 | ||||
-rw-r--r-- | pkg/sentry/socket/unix/unix.go | 14 |
10 files changed, 116 insertions, 61 deletions
diff --git a/pkg/abi/linux/socket.go b/pkg/abi/linux/socket.go index 6d22002c4..e53165622 100644 --- a/pkg/abi/linux/socket.go +++ b/pkg/abi/linux/socket.go @@ -267,6 +267,20 @@ type SockAddrUnix struct { Path [UnixPathMax]int8 } +// SockAddr represents a union of valid socket address types. This is logically +// equivalent to struct sockaddr. SockAddr ensures that a well-defined set of +// types can be used as socket addresses. +type SockAddr interface { + // implementsSockAddr exists purely to allow a type to indicate that they + // implement this interface. This method is a no-op and shouldn't be called. + implementsSockAddr() +} + +func (s *SockAddrInet) implementsSockAddr() {} +func (s *SockAddrInet6) implementsSockAddr() {} +func (s *SockAddrUnix) implementsSockAddr() {} +func (s *SockAddrNetlink) implementsSockAddr() {} + // Linger is struct linger, from include/linux/socket.h. type Linger struct { OnOff int32 diff --git a/pkg/sentry/fs/proc/net.go b/pkg/sentry/fs/proc/net.go index 6b839685b..9adb23608 100644 --- a/pkg/sentry/fs/proc/net.go +++ b/pkg/sentry/fs/proc/net.go @@ -348,7 +348,7 @@ func (n *netTCP) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]se // Field: local_adddress. var localAddr linux.SockAddrInet if local, _, err := sops.GetSockName(t); err == nil { - localAddr = local.(linux.SockAddrInet) + localAddr = *local.(*linux.SockAddrInet) } binary.LittleEndian.PutUint16(portBuf, localAddr.Port) fmt.Fprintf(&buf, "%08X:%04X ", @@ -358,7 +358,7 @@ func (n *netTCP) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]se // Field: rem_address. var remoteAddr linux.SockAddrInet if remote, _, err := sops.GetPeerName(t); err == nil { - remoteAddr = remote.(linux.SockAddrInet) + remoteAddr = *remote.(*linux.SockAddrInet) } binary.LittleEndian.PutUint16(portBuf, remoteAddr.Port) fmt.Fprintf(&buf, "%08X:%04X ", diff --git a/pkg/sentry/socket/BUILD b/pkg/sentry/socket/BUILD index 2b03ea87c..3300f9a6b 100644 --- a/pkg/sentry/socket/BUILD +++ b/pkg/sentry/socket/BUILD @@ -9,6 +9,7 @@ go_library( visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", + "//pkg/binary", "//pkg/sentry/context", "//pkg/sentry/device", "//pkg/sentry/fs", diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/epsocket/epsocket.go index 1a4442959..8cb5c823f 100644 --- a/pkg/sentry/socket/epsocket/epsocket.go +++ b/pkg/sentry/socket/epsocket/epsocket.go @@ -548,7 +548,7 @@ func (s *SocketOperations) blockingAccept(t *kernel.Task) (tcpip.Endpoint, *wait // Accept implements the linux syscall accept(2) for sockets backed by // tcpip.Endpoint. -func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, interface{}, uint32, *syserr.Error) { +func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) { // Issue the accept request to get the new endpoint. ep, wq, terr := s.Endpoint.Accept() if terr != nil { @@ -575,7 +575,7 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, ns.SetFlags(flags.Settable()) } - var addr interface{} + var addr linux.SockAddr var addrLen uint32 if peerRequested { // Get address of the peer and write it to peer slice. @@ -1056,7 +1056,7 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfac a, _ := ConvertAddress(linux.AF_INET, tcpip.FullAddress{Addr: v.InterfaceAddr}) - return a.(linux.SockAddrInet).Addr, nil + return a.(*linux.SockAddrInet).Addr, nil case linux.IP_MULTICAST_LOOP: if outLen < sizeOfInt32 { @@ -1686,7 +1686,7 @@ func isLinkLocal(addr tcpip.Address) bool { } // ConvertAddress converts the given address to a native format. -func ConvertAddress(family int, addr tcpip.FullAddress) (interface{}, uint32) { +func ConvertAddress(family int, addr tcpip.FullAddress) (linux.SockAddr, uint32) { switch family { case linux.AF_UNIX: var out linux.SockAddrUnix @@ -1702,15 +1702,15 @@ func ConvertAddress(family int, addr tcpip.FullAddress) (interface{}, uint32) { // address length is the max. Abstract and empty paths always return // the full exact length. if l == 0 || out.Path[0] == 0 || l == len(out.Path) { - return out, uint32(2 + l) + return &out, uint32(2 + l) } - return out, uint32(3 + l) + return &out, uint32(3 + l) case linux.AF_INET: var out linux.SockAddrInet copy(out.Addr[:], addr.Addr) out.Family = linux.AF_INET out.Port = htons(addr.Port) - return out, uint32(binary.Size(out)) + return &out, uint32(binary.Size(out)) case linux.AF_INET6: var out linux.SockAddrInet6 if len(addr.Addr) == 4 { @@ -1726,7 +1726,7 @@ func ConvertAddress(family int, addr tcpip.FullAddress) (interface{}, uint32) { if isLinkLocal(addr.Addr) { out.Scope_id = uint32(addr.NIC) } - return out, uint32(binary.Size(out)) + return &out, uint32(binary.Size(out)) default: return nil, 0 } @@ -1734,7 +1734,7 @@ func ConvertAddress(family int, addr tcpip.FullAddress) (interface{}, uint32) { // GetSockName implements the linux syscall getsockname(2) for sockets backed by // tcpip.Endpoint. -func (s *SocketOperations) GetSockName(t *kernel.Task) (interface{}, uint32, *syserr.Error) { +func (s *SocketOperations) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { addr, err := s.Endpoint.GetLocalAddress() if err != nil { return nil, 0, syserr.TranslateNetstackError(err) @@ -1746,7 +1746,7 @@ func (s *SocketOperations) GetSockName(t *kernel.Task) (interface{}, uint32, *sy // GetPeerName implements the linux syscall getpeername(2) for sockets backed by // tcpip.Endpoint. -func (s *SocketOperations) GetPeerName(t *kernel.Task) (interface{}, uint32, *syserr.Error) { +func (s *SocketOperations) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { addr, err := s.Endpoint.GetRemoteAddress() if err != nil { return nil, 0, syserr.TranslateNetstackError(err) @@ -1819,7 +1819,7 @@ func (s *SocketOperations) fillCmsgInq(cmsg *socket.ControlMessages) { // nonBlockingRead issues a non-blocking read. // // TODO(b/78348848): Support timestamps for stream sockets. -func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSequence, peek, trunc, senderRequested bool) (int, int, interface{}, uint32, socket.ControlMessages, *syserr.Error) { +func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSequence, peek, trunc, senderRequested bool) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) { isPacket := s.isPacketBased() // Fast path for regular reads from stream (e.g., TCP) endpoints. Note @@ -1867,7 +1867,7 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe if err == nil { s.updateTimestamp() } - var addr interface{} + var addr linux.SockAddr var addrLen uint32 if isPacket && senderRequested { addr, addrLen = ConvertAddress(s.family, s.sender) @@ -1942,7 +1942,7 @@ func (s *SocketOperations) updateTimestamp() { // RecvMsg implements the linux syscall recvmsg(2) for sockets backed by // tcpip.Endpoint. -func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr interface{}, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) { +func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr linux.SockAddr, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) { trunc := flags&linux.MSG_TRUNC != 0 peek := flags&linux.MSG_PEEK != 0 dontWait := flags&linux.MSG_DONTWAIT != 0 diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go index 618ea42c7..92beb1bcf 100644 --- a/pkg/sentry/socket/hostinet/socket.go +++ b/pkg/sentry/socket/hostinet/socket.go @@ -189,15 +189,16 @@ func (s *socketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking boo } // Accept implements socket.Socket.Accept. -func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, interface{}, uint32, *syserr.Error) { - var peerAddr []byte +func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) { + var peerAddr linux.SockAddr + var peerAddrBuf []byte var peerAddrlen uint32 var peerAddrPtr *byte var peerAddrlenPtr *uint32 if peerRequested { - peerAddr = make([]byte, sizeofSockaddr) - peerAddrlen = uint32(len(peerAddr)) - peerAddrPtr = &peerAddr[0] + peerAddrBuf = make([]byte, sizeofSockaddr) + peerAddrlen = uint32(len(peerAddrBuf)) + peerAddrPtr = &peerAddrBuf[0] peerAddrlenPtr = &peerAddrlen } @@ -222,7 +223,7 @@ func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, } if peerRequested { - peerAddr = peerAddr[:peerAddrlen] + peerAddr = socket.UnmarshalSockAddr(s.family, peerAddrBuf[:peerAddrlen]) } if syscallErr != nil { return 0, peerAddr, peerAddrlen, syserr.FromError(syscallErr) @@ -353,7 +354,7 @@ func (s *socketOperations) SetSockOpt(t *kernel.Task, level int, name int, opt [ } // RecvMsg implements socket.Socket.RecvMsg. -func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (int, int, interface{}, uint32, socket.ControlMessages, *syserr.Error) { +func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) { // Whitelist flags. // // FIXME(jamieliu): We can't support MSG_ERRQUEUE because it uses ancillary @@ -363,9 +364,10 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags return 0, 0, nil, 0, socket.ControlMessages{}, syserr.ErrInvalidArgument } - var senderAddr []byte + var senderAddr linux.SockAddr + var senderAddrBuf []byte if senderRequested { - senderAddr = make([]byte, sizeofSockaddr) + senderAddrBuf = make([]byte, sizeofSockaddr) } var msgFlags int @@ -384,7 +386,7 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags if dsts.NumBlocks() == 1 { // Skip allocating []syscall.Iovec. - return recvfrom(s.fd, dsts.Head().ToSlice(), sysflags, &senderAddr) + return recvfrom(s.fd, dsts.Head().ToSlice(), sysflags, &senderAddrBuf) } iovs := iovecsFromBlockSeq(dsts) @@ -392,15 +394,15 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags Iov: &iovs[0], Iovlen: uint64(len(iovs)), } - if len(senderAddr) != 0 { - msg.Name = &senderAddr[0] - msg.Namelen = uint32(len(senderAddr)) + if len(senderAddrBuf) != 0 { + msg.Name = &senderAddrBuf[0] + msg.Namelen = uint32(len(senderAddrBuf)) } n, err := recvmsg(s.fd, &msg, sysflags) if err != nil { return 0, err } - senderAddr = senderAddr[:msg.Namelen] + senderAddrBuf = senderAddrBuf[:msg.Namelen] msgFlags = int(msg.Flags) return n, nil }) @@ -431,7 +433,10 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags // We don't allow control messages. msgFlags &^= linux.MSG_CTRUNC - return int(n), msgFlags, senderAddr, uint32(len(senderAddr)), socket.ControlMessages{}, syserr.FromError(err) + if senderRequested { + senderAddr = socket.UnmarshalSockAddr(s.family, senderAddrBuf) + } + return int(n), msgFlags, senderAddr, uint32(len(senderAddrBuf)), socket.ControlMessages{}, syserr.FromError(err) } // SendMsg implements socket.Socket.SendMsg. diff --git a/pkg/sentry/socket/hostinet/socket_unsafe.go b/pkg/sentry/socket/hostinet/socket_unsafe.go index 6c69ba9c7..e69ec38c2 100644 --- a/pkg/sentry/socket/hostinet/socket_unsafe.go +++ b/pkg/sentry/socket/hostinet/socket_unsafe.go @@ -18,10 +18,12 @@ import ( "syscall" "unsafe" + "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/socket" "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/syserror" @@ -91,25 +93,25 @@ func getsockopt(fd int, level, name int, optlen int) ([]byte, error) { } // GetSockName implements socket.Socket.GetSockName. -func (s *socketOperations) GetSockName(t *kernel.Task) (interface{}, uint32, *syserr.Error) { +func (s *socketOperations) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { addr := make([]byte, sizeofSockaddr) addrlen := uint32(len(addr)) _, _, errno := syscall.Syscall(syscall.SYS_GETSOCKNAME, uintptr(s.fd), uintptr(unsafe.Pointer(&addr[0])), uintptr(unsafe.Pointer(&addrlen))) if errno != 0 { return nil, 0, syserr.FromError(errno) } - return addr[:addrlen], addrlen, nil + return socket.UnmarshalSockAddr(s.family, addr), addrlen, nil } // GetPeerName implements socket.Socket.GetPeerName. -func (s *socketOperations) GetPeerName(t *kernel.Task) (interface{}, uint32, *syserr.Error) { +func (s *socketOperations) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { addr := make([]byte, sizeofSockaddr) addrlen := uint32(len(addr)) _, _, errno := syscall.Syscall(syscall.SYS_GETPEERNAME, uintptr(s.fd), uintptr(unsafe.Pointer(&addr[0])), uintptr(unsafe.Pointer(&addrlen))) if errno != 0 { return nil, 0, syserr.FromError(errno) } - return addr[:addrlen], addrlen, nil + return socket.UnmarshalSockAddr(s.family, addr), addrlen, nil } func recvfrom(fd int, dst []byte, flags int, from *[]byte) (uint64, error) { diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go index 81c488b29..eccbd527a 100644 --- a/pkg/sentry/socket/netlink/socket.go +++ b/pkg/sentry/socket/netlink/socket.go @@ -271,7 +271,7 @@ func (s *Socket) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr } // Accept implements socket.Socket.Accept. -func (s *Socket) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, interface{}, uint32, *syserr.Error) { +func (s *Socket) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) { // Netlink sockets never support accept. return 0, nil, 0, syserr.ErrNotSupported } @@ -379,11 +379,11 @@ func (s *Socket) SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *sy } // GetSockName implements socket.Socket.GetSockName. -func (s *Socket) GetSockName(t *kernel.Task) (interface{}, uint32, *syserr.Error) { +func (s *Socket) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { s.mu.Lock() defer s.mu.Unlock() - sa := linux.SockAddrNetlink{ + sa := &linux.SockAddrNetlink{ Family: linux.AF_NETLINK, PortID: uint32(s.portID), } @@ -391,8 +391,8 @@ func (s *Socket) GetSockName(t *kernel.Task) (interface{}, uint32, *syserr.Error } // GetPeerName implements socket.Socket.GetPeerName. -func (s *Socket) GetPeerName(t *kernel.Task) (interface{}, uint32, *syserr.Error) { - sa := linux.SockAddrNetlink{ +func (s *Socket) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { + sa := &linux.SockAddrNetlink{ Family: linux.AF_NETLINK, // TODO(b/68878065): Support non-kernel peers. For now the peer // must be the kernel. @@ -402,8 +402,8 @@ func (s *Socket) GetPeerName(t *kernel.Task) (interface{}, uint32, *syserr.Error } // RecvMsg implements socket.Socket.RecvMsg. -func (s *Socket) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (int, int, interface{}, uint32, socket.ControlMessages, *syserr.Error) { - from := linux.SockAddrNetlink{ +func (s *Socket) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) { + from := &linux.SockAddrNetlink{ Family: linux.AF_NETLINK, PortID: 0, } diff --git a/pkg/sentry/socket/rpcinet/socket.go b/pkg/sentry/socket/rpcinet/socket.go index 0e58819bc..ddb76d9d4 100644 --- a/pkg/sentry/socket/rpcinet/socket.go +++ b/pkg/sentry/socket/rpcinet/socket.go @@ -285,7 +285,7 @@ func rpcAccept(t *kernel.Task, fd uint32, peer bool) (*pb.AcceptResponse_ResultP } // Accept implements socket.Socket.Accept. -func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, interface{}, uint32, *syserr.Error) { +func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) { payload, se := rpcAccept(t, s.fd, peerRequested) // Check if we need to block. @@ -328,6 +328,9 @@ func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, NonBlocking: flags&linux.SOCK_NONBLOCK != 0, } file := fs.NewFile(t, dirent, fileFlags, &socketOperations{ + family: s.family, + stype: s.stype, + protocol: s.protocol, wq: &wq, fd: payload.Fd, rpcConn: s.rpcConn, @@ -344,7 +347,7 @@ func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, t.Kernel().RecordSocket(file) if peerRequested { - return fd, payload.Address.Address, payload.Address.Length, nil + return fd, socket.UnmarshalSockAddr(s.family, payload.Address.Address), payload.Address.Length, nil } return fd, nil, 0, nil @@ -469,7 +472,7 @@ func (s *socketOperations) SetSockOpt(t *kernel.Task, level int, name int, opt [ } // GetPeerName implements socket.Socket.GetPeerName. -func (s *socketOperations) GetPeerName(t *kernel.Task) (interface{}, uint32, *syserr.Error) { +func (s *socketOperations) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { stack := t.NetworkContext().(*Stack) id, c := stack.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_GetPeerName{&pb.GetPeerNameRequest{Fd: s.fd}}}, false /* ignoreResult */) <-c @@ -480,11 +483,11 @@ func (s *socketOperations) GetPeerName(t *kernel.Task) (interface{}, uint32, *sy } addr := res.(*pb.GetPeerNameResponse_Address).Address - return addr.Address, addr.Length, nil + return socket.UnmarshalSockAddr(s.family, addr.Address), addr.Length, nil } // GetSockName implements socket.Socket.GetSockName. -func (s *socketOperations) GetSockName(t *kernel.Task) (interface{}, uint32, *syserr.Error) { +func (s *socketOperations) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { stack := t.NetworkContext().(*Stack) id, c := stack.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_GetSockName{&pb.GetSockNameRequest{Fd: s.fd}}}, false /* ignoreResult */) <-c @@ -495,7 +498,7 @@ func (s *socketOperations) GetSockName(t *kernel.Task) (interface{}, uint32, *sy } addr := res.(*pb.GetSockNameResponse_Address).Address - return addr.Address, addr.Length, nil + return socket.UnmarshalSockAddr(s.family, addr.Address), addr.Length, nil } func rpcIoctl(t *kernel.Task, fd, cmd uint32, arg []byte) ([]byte, error) { @@ -682,7 +685,7 @@ func (s *socketOperations) extractControlMessages(payload *pb.RecvmsgResponse_Re } // RecvMsg implements socket.Socket.RecvMsg. -func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (int, int, interface{}, uint32, socket.ControlMessages, *syserr.Error) { +func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) { req := &pb.SyscallRequest_Recvmsg{&pb.RecvmsgRequest{ Fd: s.fd, Length: uint32(dst.NumBytes()), @@ -703,7 +706,7 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags } } c := s.extractControlMessages(res) - return int(res.Length), 0, res.Address.GetAddress(), res.Address.GetLength(), c, syserr.FromError(e) + return int(res.Length), 0, socket.UnmarshalSockAddr(s.family, res.Address.GetAddress()), res.Address.GetLength(), c, syserr.FromError(e) } if err != syserr.ErrWouldBlock && err != syserr.ErrTryAgain || flags&linux.MSG_DONTWAIT != 0 { return 0, 0, nil, 0, socket.ControlMessages{}, err @@ -727,7 +730,7 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags } } c := s.extractControlMessages(res) - return int(res.Length), 0, res.Address.GetAddress(), res.Address.GetLength(), c, syserr.FromError(e) + return int(res.Length), 0, socket.UnmarshalSockAddr(s.family, res.Address.GetAddress()), res.Address.GetLength(), c, syserr.FromError(e) } if err != syserr.ErrWouldBlock && err != syserr.ErrTryAgain { return 0, 0, nil, 0, socket.ControlMessages{}, err diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go index 30f0a3167..8c250c325 100644 --- a/pkg/sentry/socket/socket.go +++ b/pkg/sentry/socket/socket.go @@ -20,8 +20,10 @@ package socket import ( "fmt" "sync/atomic" + "syscall" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/device" "gvisor.dev/gvisor/pkg/sentry/fs" @@ -52,7 +54,7 @@ type Socket interface { // Accept implements the accept4(2) linux syscall. // Returns fd, real peer address length and error. Real peer address // length is only set if len(peer) > 0. - Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, interface{}, uint32, *syserr.Error) + Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) // Bind implements the bind(2) linux syscall. Bind(t *kernel.Task, sockaddr []byte) *syserr.Error @@ -73,13 +75,13 @@ type Socket interface { // // addrLen is the address length to be returned to the application, not // necessarily the actual length of the address. - GetSockName(t *kernel.Task) (addr interface{}, addrLen uint32, err *syserr.Error) + GetSockName(t *kernel.Task) (addr linux.SockAddr, addrLen uint32, err *syserr.Error) // GetPeerName implements the getpeername(2) linux syscall. // // addrLen is the address length to be returned to the application, not // necessarily the actual length of the address. - GetPeerName(t *kernel.Task) (addr interface{}, addrLen uint32, err *syserr.Error) + GetPeerName(t *kernel.Task) (addr linux.SockAddr, addrLen uint32, err *syserr.Error) // RecvMsg implements the recvmsg(2) linux syscall. // @@ -92,7 +94,7 @@ type Socket interface { // msgFlags. In that case, the caller should set MSG_CTRUNC appropriately. // // If err != nil, the recv was not successful. - RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr interface{}, senderAddrLen uint32, controlMessages ControlMessages, err *syserr.Error) + RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr linux.SockAddr, senderAddrLen uint32, controlMessages ControlMessages, err *syserr.Error) // SendMsg implements the sendmsg(2) linux syscall. SendMsg does not take // ownership of the ControlMessage on error. @@ -340,3 +342,31 @@ func emitUnimplementedEvent(t *kernel.Task, name int) { t.Kernel().EmitUnimplementedEvent(t) } } + +// UnmarshalSockAddr unmarshals memory representing a struct sockaddr to one of +// the ABI socket address types. +// +// Precondition: data must be long enough to represent a socket address of the +// given family. +func UnmarshalSockAddr(family int, data []byte) linux.SockAddr { + switch family { + case syscall.AF_INET: + var addr linux.SockAddrInet + binary.Unmarshal(data[:syscall.SizeofSockaddrInet4], usermem.ByteOrder, &addr) + return &addr + case syscall.AF_INET6: + var addr linux.SockAddrInet6 + binary.Unmarshal(data[:syscall.SizeofSockaddrInet6], usermem.ByteOrder, &addr) + return &addr + case syscall.AF_UNIX: + var addr linux.SockAddrUnix + binary.Unmarshal(data[:syscall.SizeofSockaddrUnix], usermem.ByteOrder, &addr) + return &addr + case syscall.AF_NETLINK: + var addr linux.SockAddrNetlink + binary.Unmarshal(data[:syscall.SizeofSockaddrNetlink], usermem.ByteOrder, &addr) + return &addr + default: + panic(fmt.Sprintf("Unsupported socket family %v", family)) + } +} diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go index 6b1f8679c..9032b7580 100644 --- a/pkg/sentry/socket/unix/unix.go +++ b/pkg/sentry/socket/unix/unix.go @@ -137,7 +137,7 @@ func extractPath(sockaddr []byte) (string, *syserr.Error) { // GetPeerName implements the linux syscall getpeername(2) for sockets backed by // a transport.Endpoint. -func (s *SocketOperations) GetPeerName(t *kernel.Task) (interface{}, uint32, *syserr.Error) { +func (s *SocketOperations) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { addr, err := s.ep.GetRemoteAddress() if err != nil { return nil, 0, syserr.TranslateNetstackError(err) @@ -149,7 +149,7 @@ func (s *SocketOperations) GetPeerName(t *kernel.Task) (interface{}, uint32, *sy // GetSockName implements the linux syscall getsockname(2) for sockets backed by // a transport.Endpoint. -func (s *SocketOperations) GetSockName(t *kernel.Task) (interface{}, uint32, *syserr.Error) { +func (s *SocketOperations) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { addr, err := s.ep.GetLocalAddress() if err != nil { return nil, 0, syserr.TranslateNetstackError(err) @@ -199,7 +199,7 @@ func (s *SocketOperations) blockingAccept(t *kernel.Task) (transport.Endpoint, * // Accept implements the linux syscall accept(2) for sockets backed by // a transport.Endpoint. -func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, interface{}, uint32, *syserr.Error) { +func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) { // Issue the accept request to get the new endpoint. ep, err := s.ep.Accept() if err != nil { @@ -223,7 +223,7 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, ns.SetFlags(flags.Settable()) } - var addr interface{} + var addr linux.SockAddr var addrLen uint32 if peerRequested { // Get address of the peer. @@ -505,7 +505,7 @@ func (s *SocketOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOS // RecvMsg implements the linux syscall recvmsg(2) for sockets backed by // a transport.Endpoint. -func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr interface{}, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) { +func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr linux.SockAddr, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) { trunc := flags&linux.MSG_TRUNC != 0 peek := flags&linux.MSG_PEEK != 0 dontWait := flags&linux.MSG_DONTWAIT != 0 @@ -543,7 +543,7 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags } var total int64 if n, err := dst.CopyOutFrom(t, &r); err != syserror.ErrWouldBlock || dontWait { - var from interface{} + var from linux.SockAddr var fromLen uint32 if r.From != nil && len([]byte(r.From.Addr)) != 0 { from, fromLen = epsocket.ConvertAddress(linux.AF_UNIX, *r.From) @@ -578,7 +578,7 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags for { if n, err := dst.CopyOutFrom(t, &r); err != syserror.ErrWouldBlock { - var from interface{} + var from linux.SockAddr var fromLen uint32 if r.From != nil { from, fromLen = epsocket.ConvertAddress(linux.AF_UNIX, *r.From) |