From 7bfad8ebb6ce71c0fe90a1e4f5897f62809fa58b Mon Sep 17 00:00:00 2001 From: Rahat Mahmood Date: Thu, 8 Aug 2019 16:49:18 -0700 Subject: Return a well-defined socket address type from socket funtions. Previously we were representing socket addresses as an interface{}, which allowed any type which could be binary.Marshal()ed to be used as a socket address. This is fine when the address is passed to userspace via the linux ABI, but is problematic when used from within the sentry such as by networking procfs files. PiperOrigin-RevId: 262460640 --- pkg/sentry/socket/BUILD | 1 + pkg/sentry/socket/epsocket/epsocket.go | 26 ++++++++++---------- pkg/sentry/socket/hostinet/socket.go | 35 ++++++++++++++------------ pkg/sentry/socket/hostinet/socket_unsafe.go | 10 +++++--- pkg/sentry/socket/netlink/socket.go | 14 +++++------ pkg/sentry/socket/rpcinet/socket.go | 21 +++++++++------- pkg/sentry/socket/socket.go | 38 ++++++++++++++++++++++++++--- pkg/sentry/socket/unix/unix.go | 14 +++++------ 8 files changed, 100 insertions(+), 59 deletions(-) (limited to 'pkg/sentry/socket') diff --git a/pkg/sentry/socket/BUILD b/pkg/sentry/socket/BUILD index 2b03ea87c..3300f9a6b 100644 --- a/pkg/sentry/socket/BUILD +++ b/pkg/sentry/socket/BUILD @@ -9,6 +9,7 @@ go_library( visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", + "//pkg/binary", "//pkg/sentry/context", "//pkg/sentry/device", "//pkg/sentry/fs", diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/epsocket/epsocket.go index 1a4442959..8cb5c823f 100644 --- a/pkg/sentry/socket/epsocket/epsocket.go +++ b/pkg/sentry/socket/epsocket/epsocket.go @@ -548,7 +548,7 @@ func (s *SocketOperations) blockingAccept(t *kernel.Task) (tcpip.Endpoint, *wait // Accept implements the linux syscall accept(2) for sockets backed by // tcpip.Endpoint. -func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, interface{}, uint32, *syserr.Error) { +func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) { // Issue the accept request to get the new endpoint. ep, wq, terr := s.Endpoint.Accept() if terr != nil { @@ -575,7 +575,7 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, ns.SetFlags(flags.Settable()) } - var addr interface{} + var addr linux.SockAddr var addrLen uint32 if peerRequested { // Get address of the peer and write it to peer slice. @@ -1056,7 +1056,7 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfac a, _ := ConvertAddress(linux.AF_INET, tcpip.FullAddress{Addr: v.InterfaceAddr}) - return a.(linux.SockAddrInet).Addr, nil + return a.(*linux.SockAddrInet).Addr, nil case linux.IP_MULTICAST_LOOP: if outLen < sizeOfInt32 { @@ -1686,7 +1686,7 @@ func isLinkLocal(addr tcpip.Address) bool { } // ConvertAddress converts the given address to a native format. -func ConvertAddress(family int, addr tcpip.FullAddress) (interface{}, uint32) { +func ConvertAddress(family int, addr tcpip.FullAddress) (linux.SockAddr, uint32) { switch family { case linux.AF_UNIX: var out linux.SockAddrUnix @@ -1702,15 +1702,15 @@ func ConvertAddress(family int, addr tcpip.FullAddress) (interface{}, uint32) { // address length is the max. Abstract and empty paths always return // the full exact length. if l == 0 || out.Path[0] == 0 || l == len(out.Path) { - return out, uint32(2 + l) + return &out, uint32(2 + l) } - return out, uint32(3 + l) + return &out, uint32(3 + l) case linux.AF_INET: var out linux.SockAddrInet copy(out.Addr[:], addr.Addr) out.Family = linux.AF_INET out.Port = htons(addr.Port) - return out, uint32(binary.Size(out)) + return &out, uint32(binary.Size(out)) case linux.AF_INET6: var out linux.SockAddrInet6 if len(addr.Addr) == 4 { @@ -1726,7 +1726,7 @@ func ConvertAddress(family int, addr tcpip.FullAddress) (interface{}, uint32) { if isLinkLocal(addr.Addr) { out.Scope_id = uint32(addr.NIC) } - return out, uint32(binary.Size(out)) + return &out, uint32(binary.Size(out)) default: return nil, 0 } @@ -1734,7 +1734,7 @@ func ConvertAddress(family int, addr tcpip.FullAddress) (interface{}, uint32) { // GetSockName implements the linux syscall getsockname(2) for sockets backed by // tcpip.Endpoint. -func (s *SocketOperations) GetSockName(t *kernel.Task) (interface{}, uint32, *syserr.Error) { +func (s *SocketOperations) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { addr, err := s.Endpoint.GetLocalAddress() if err != nil { return nil, 0, syserr.TranslateNetstackError(err) @@ -1746,7 +1746,7 @@ func (s *SocketOperations) GetSockName(t *kernel.Task) (interface{}, uint32, *sy // GetPeerName implements the linux syscall getpeername(2) for sockets backed by // tcpip.Endpoint. -func (s *SocketOperations) GetPeerName(t *kernel.Task) (interface{}, uint32, *syserr.Error) { +func (s *SocketOperations) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { addr, err := s.Endpoint.GetRemoteAddress() if err != nil { return nil, 0, syserr.TranslateNetstackError(err) @@ -1819,7 +1819,7 @@ func (s *SocketOperations) fillCmsgInq(cmsg *socket.ControlMessages) { // nonBlockingRead issues a non-blocking read. // // TODO(b/78348848): Support timestamps for stream sockets. -func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSequence, peek, trunc, senderRequested bool) (int, int, interface{}, uint32, socket.ControlMessages, *syserr.Error) { +func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSequence, peek, trunc, senderRequested bool) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) { isPacket := s.isPacketBased() // Fast path for regular reads from stream (e.g., TCP) endpoints. Note @@ -1867,7 +1867,7 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe if err == nil { s.updateTimestamp() } - var addr interface{} + var addr linux.SockAddr var addrLen uint32 if isPacket && senderRequested { addr, addrLen = ConvertAddress(s.family, s.sender) @@ -1942,7 +1942,7 @@ func (s *SocketOperations) updateTimestamp() { // RecvMsg implements the linux syscall recvmsg(2) for sockets backed by // tcpip.Endpoint. -func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr interface{}, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) { +func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr linux.SockAddr, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) { trunc := flags&linux.MSG_TRUNC != 0 peek := flags&linux.MSG_PEEK != 0 dontWait := flags&linux.MSG_DONTWAIT != 0 diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go index 618ea42c7..92beb1bcf 100644 --- a/pkg/sentry/socket/hostinet/socket.go +++ b/pkg/sentry/socket/hostinet/socket.go @@ -189,15 +189,16 @@ func (s *socketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking boo } // Accept implements socket.Socket.Accept. -func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, interface{}, uint32, *syserr.Error) { - var peerAddr []byte +func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) { + var peerAddr linux.SockAddr + var peerAddrBuf []byte var peerAddrlen uint32 var peerAddrPtr *byte var peerAddrlenPtr *uint32 if peerRequested { - peerAddr = make([]byte, sizeofSockaddr) - peerAddrlen = uint32(len(peerAddr)) - peerAddrPtr = &peerAddr[0] + peerAddrBuf = make([]byte, sizeofSockaddr) + peerAddrlen = uint32(len(peerAddrBuf)) + peerAddrPtr = &peerAddrBuf[0] peerAddrlenPtr = &peerAddrlen } @@ -222,7 +223,7 @@ func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, } if peerRequested { - peerAddr = peerAddr[:peerAddrlen] + peerAddr = socket.UnmarshalSockAddr(s.family, peerAddrBuf[:peerAddrlen]) } if syscallErr != nil { return 0, peerAddr, peerAddrlen, syserr.FromError(syscallErr) @@ -353,7 +354,7 @@ func (s *socketOperations) SetSockOpt(t *kernel.Task, level int, name int, opt [ } // RecvMsg implements socket.Socket.RecvMsg. -func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (int, int, interface{}, uint32, socket.ControlMessages, *syserr.Error) { +func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) { // Whitelist flags. // // FIXME(jamieliu): We can't support MSG_ERRQUEUE because it uses ancillary @@ -363,9 +364,10 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags return 0, 0, nil, 0, socket.ControlMessages{}, syserr.ErrInvalidArgument } - var senderAddr []byte + var senderAddr linux.SockAddr + var senderAddrBuf []byte if senderRequested { - senderAddr = make([]byte, sizeofSockaddr) + senderAddrBuf = make([]byte, sizeofSockaddr) } var msgFlags int @@ -384,7 +386,7 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags if dsts.NumBlocks() == 1 { // Skip allocating []syscall.Iovec. - return recvfrom(s.fd, dsts.Head().ToSlice(), sysflags, &senderAddr) + return recvfrom(s.fd, dsts.Head().ToSlice(), sysflags, &senderAddrBuf) } iovs := iovecsFromBlockSeq(dsts) @@ -392,15 +394,15 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags Iov: &iovs[0], Iovlen: uint64(len(iovs)), } - if len(senderAddr) != 0 { - msg.Name = &senderAddr[0] - msg.Namelen = uint32(len(senderAddr)) + if len(senderAddrBuf) != 0 { + msg.Name = &senderAddrBuf[0] + msg.Namelen = uint32(len(senderAddrBuf)) } n, err := recvmsg(s.fd, &msg, sysflags) if err != nil { return 0, err } - senderAddr = senderAddr[:msg.Namelen] + senderAddrBuf = senderAddrBuf[:msg.Namelen] msgFlags = int(msg.Flags) return n, nil }) @@ -431,7 +433,10 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags // We don't allow control messages. msgFlags &^= linux.MSG_CTRUNC - return int(n), msgFlags, senderAddr, uint32(len(senderAddr)), socket.ControlMessages{}, syserr.FromError(err) + if senderRequested { + senderAddr = socket.UnmarshalSockAddr(s.family, senderAddrBuf) + } + return int(n), msgFlags, senderAddr, uint32(len(senderAddrBuf)), socket.ControlMessages{}, syserr.FromError(err) } // SendMsg implements socket.Socket.SendMsg. diff --git a/pkg/sentry/socket/hostinet/socket_unsafe.go b/pkg/sentry/socket/hostinet/socket_unsafe.go index 6c69ba9c7..e69ec38c2 100644 --- a/pkg/sentry/socket/hostinet/socket_unsafe.go +++ b/pkg/sentry/socket/hostinet/socket_unsafe.go @@ -18,10 +18,12 @@ import ( "syscall" "unsafe" + "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/socket" "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/syserror" @@ -91,25 +93,25 @@ func getsockopt(fd int, level, name int, optlen int) ([]byte, error) { } // GetSockName implements socket.Socket.GetSockName. -func (s *socketOperations) GetSockName(t *kernel.Task) (interface{}, uint32, *syserr.Error) { +func (s *socketOperations) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { addr := make([]byte, sizeofSockaddr) addrlen := uint32(len(addr)) _, _, errno := syscall.Syscall(syscall.SYS_GETSOCKNAME, uintptr(s.fd), uintptr(unsafe.Pointer(&addr[0])), uintptr(unsafe.Pointer(&addrlen))) if errno != 0 { return nil, 0, syserr.FromError(errno) } - return addr[:addrlen], addrlen, nil + return socket.UnmarshalSockAddr(s.family, addr), addrlen, nil } // GetPeerName implements socket.Socket.GetPeerName. -func (s *socketOperations) GetPeerName(t *kernel.Task) (interface{}, uint32, *syserr.Error) { +func (s *socketOperations) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { addr := make([]byte, sizeofSockaddr) addrlen := uint32(len(addr)) _, _, errno := syscall.Syscall(syscall.SYS_GETPEERNAME, uintptr(s.fd), uintptr(unsafe.Pointer(&addr[0])), uintptr(unsafe.Pointer(&addrlen))) if errno != 0 { return nil, 0, syserr.FromError(errno) } - return addr[:addrlen], addrlen, nil + return socket.UnmarshalSockAddr(s.family, addr), addrlen, nil } func recvfrom(fd int, dst []byte, flags int, from *[]byte) (uint64, error) { diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go index 81c488b29..eccbd527a 100644 --- a/pkg/sentry/socket/netlink/socket.go +++ b/pkg/sentry/socket/netlink/socket.go @@ -271,7 +271,7 @@ func (s *Socket) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr } // Accept implements socket.Socket.Accept. -func (s *Socket) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, interface{}, uint32, *syserr.Error) { +func (s *Socket) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) { // Netlink sockets never support accept. return 0, nil, 0, syserr.ErrNotSupported } @@ -379,11 +379,11 @@ func (s *Socket) SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *sy } // GetSockName implements socket.Socket.GetSockName. -func (s *Socket) GetSockName(t *kernel.Task) (interface{}, uint32, *syserr.Error) { +func (s *Socket) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { s.mu.Lock() defer s.mu.Unlock() - sa := linux.SockAddrNetlink{ + sa := &linux.SockAddrNetlink{ Family: linux.AF_NETLINK, PortID: uint32(s.portID), } @@ -391,8 +391,8 @@ func (s *Socket) GetSockName(t *kernel.Task) (interface{}, uint32, *syserr.Error } // GetPeerName implements socket.Socket.GetPeerName. -func (s *Socket) GetPeerName(t *kernel.Task) (interface{}, uint32, *syserr.Error) { - sa := linux.SockAddrNetlink{ +func (s *Socket) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { + sa := &linux.SockAddrNetlink{ Family: linux.AF_NETLINK, // TODO(b/68878065): Support non-kernel peers. For now the peer // must be the kernel. @@ -402,8 +402,8 @@ func (s *Socket) GetPeerName(t *kernel.Task) (interface{}, uint32, *syserr.Error } // RecvMsg implements socket.Socket.RecvMsg. -func (s *Socket) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (int, int, interface{}, uint32, socket.ControlMessages, *syserr.Error) { - from := linux.SockAddrNetlink{ +func (s *Socket) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) { + from := &linux.SockAddrNetlink{ Family: linux.AF_NETLINK, PortID: 0, } diff --git a/pkg/sentry/socket/rpcinet/socket.go b/pkg/sentry/socket/rpcinet/socket.go index 0e58819bc..ddb76d9d4 100644 --- a/pkg/sentry/socket/rpcinet/socket.go +++ b/pkg/sentry/socket/rpcinet/socket.go @@ -285,7 +285,7 @@ func rpcAccept(t *kernel.Task, fd uint32, peer bool) (*pb.AcceptResponse_ResultP } // Accept implements socket.Socket.Accept. -func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, interface{}, uint32, *syserr.Error) { +func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) { payload, se := rpcAccept(t, s.fd, peerRequested) // Check if we need to block. @@ -328,6 +328,9 @@ func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, NonBlocking: flags&linux.SOCK_NONBLOCK != 0, } file := fs.NewFile(t, dirent, fileFlags, &socketOperations{ + family: s.family, + stype: s.stype, + protocol: s.protocol, wq: &wq, fd: payload.Fd, rpcConn: s.rpcConn, @@ -344,7 +347,7 @@ func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, t.Kernel().RecordSocket(file) if peerRequested { - return fd, payload.Address.Address, payload.Address.Length, nil + return fd, socket.UnmarshalSockAddr(s.family, payload.Address.Address), payload.Address.Length, nil } return fd, nil, 0, nil @@ -469,7 +472,7 @@ func (s *socketOperations) SetSockOpt(t *kernel.Task, level int, name int, opt [ } // GetPeerName implements socket.Socket.GetPeerName. -func (s *socketOperations) GetPeerName(t *kernel.Task) (interface{}, uint32, *syserr.Error) { +func (s *socketOperations) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { stack := t.NetworkContext().(*Stack) id, c := stack.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_GetPeerName{&pb.GetPeerNameRequest{Fd: s.fd}}}, false /* ignoreResult */) <-c @@ -480,11 +483,11 @@ func (s *socketOperations) GetPeerName(t *kernel.Task) (interface{}, uint32, *sy } addr := res.(*pb.GetPeerNameResponse_Address).Address - return addr.Address, addr.Length, nil + return socket.UnmarshalSockAddr(s.family, addr.Address), addr.Length, nil } // GetSockName implements socket.Socket.GetSockName. -func (s *socketOperations) GetSockName(t *kernel.Task) (interface{}, uint32, *syserr.Error) { +func (s *socketOperations) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { stack := t.NetworkContext().(*Stack) id, c := stack.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_GetSockName{&pb.GetSockNameRequest{Fd: s.fd}}}, false /* ignoreResult */) <-c @@ -495,7 +498,7 @@ func (s *socketOperations) GetSockName(t *kernel.Task) (interface{}, uint32, *sy } addr := res.(*pb.GetSockNameResponse_Address).Address - return addr.Address, addr.Length, nil + return socket.UnmarshalSockAddr(s.family, addr.Address), addr.Length, nil } func rpcIoctl(t *kernel.Task, fd, cmd uint32, arg []byte) ([]byte, error) { @@ -682,7 +685,7 @@ func (s *socketOperations) extractControlMessages(payload *pb.RecvmsgResponse_Re } // RecvMsg implements socket.Socket.RecvMsg. -func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (int, int, interface{}, uint32, socket.ControlMessages, *syserr.Error) { +func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) { req := &pb.SyscallRequest_Recvmsg{&pb.RecvmsgRequest{ Fd: s.fd, Length: uint32(dst.NumBytes()), @@ -703,7 +706,7 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags } } c := s.extractControlMessages(res) - return int(res.Length), 0, res.Address.GetAddress(), res.Address.GetLength(), c, syserr.FromError(e) + return int(res.Length), 0, socket.UnmarshalSockAddr(s.family, res.Address.GetAddress()), res.Address.GetLength(), c, syserr.FromError(e) } if err != syserr.ErrWouldBlock && err != syserr.ErrTryAgain || flags&linux.MSG_DONTWAIT != 0 { return 0, 0, nil, 0, socket.ControlMessages{}, err @@ -727,7 +730,7 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags } } c := s.extractControlMessages(res) - return int(res.Length), 0, res.Address.GetAddress(), res.Address.GetLength(), c, syserr.FromError(e) + return int(res.Length), 0, socket.UnmarshalSockAddr(s.family, res.Address.GetAddress()), res.Address.GetLength(), c, syserr.FromError(e) } if err != syserr.ErrWouldBlock && err != syserr.ErrTryAgain { return 0, 0, nil, 0, socket.ControlMessages{}, err diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go index 30f0a3167..8c250c325 100644 --- a/pkg/sentry/socket/socket.go +++ b/pkg/sentry/socket/socket.go @@ -20,8 +20,10 @@ package socket import ( "fmt" "sync/atomic" + "syscall" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/device" "gvisor.dev/gvisor/pkg/sentry/fs" @@ -52,7 +54,7 @@ type Socket interface { // Accept implements the accept4(2) linux syscall. // Returns fd, real peer address length and error. Real peer address // length is only set if len(peer) > 0. - Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, interface{}, uint32, *syserr.Error) + Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) // Bind implements the bind(2) linux syscall. Bind(t *kernel.Task, sockaddr []byte) *syserr.Error @@ -73,13 +75,13 @@ type Socket interface { // // addrLen is the address length to be returned to the application, not // necessarily the actual length of the address. - GetSockName(t *kernel.Task) (addr interface{}, addrLen uint32, err *syserr.Error) + GetSockName(t *kernel.Task) (addr linux.SockAddr, addrLen uint32, err *syserr.Error) // GetPeerName implements the getpeername(2) linux syscall. // // addrLen is the address length to be returned to the application, not // necessarily the actual length of the address. - GetPeerName(t *kernel.Task) (addr interface{}, addrLen uint32, err *syserr.Error) + GetPeerName(t *kernel.Task) (addr linux.SockAddr, addrLen uint32, err *syserr.Error) // RecvMsg implements the recvmsg(2) linux syscall. // @@ -92,7 +94,7 @@ type Socket interface { // msgFlags. In that case, the caller should set MSG_CTRUNC appropriately. // // If err != nil, the recv was not successful. - RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr interface{}, senderAddrLen uint32, controlMessages ControlMessages, err *syserr.Error) + RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr linux.SockAddr, senderAddrLen uint32, controlMessages ControlMessages, err *syserr.Error) // SendMsg implements the sendmsg(2) linux syscall. SendMsg does not take // ownership of the ControlMessage on error. @@ -340,3 +342,31 @@ func emitUnimplementedEvent(t *kernel.Task, name int) { t.Kernel().EmitUnimplementedEvent(t) } } + +// UnmarshalSockAddr unmarshals memory representing a struct sockaddr to one of +// the ABI socket address types. +// +// Precondition: data must be long enough to represent a socket address of the +// given family. +func UnmarshalSockAddr(family int, data []byte) linux.SockAddr { + switch family { + case syscall.AF_INET: + var addr linux.SockAddrInet + binary.Unmarshal(data[:syscall.SizeofSockaddrInet4], usermem.ByteOrder, &addr) + return &addr + case syscall.AF_INET6: + var addr linux.SockAddrInet6 + binary.Unmarshal(data[:syscall.SizeofSockaddrInet6], usermem.ByteOrder, &addr) + return &addr + case syscall.AF_UNIX: + var addr linux.SockAddrUnix + binary.Unmarshal(data[:syscall.SizeofSockaddrUnix], usermem.ByteOrder, &addr) + return &addr + case syscall.AF_NETLINK: + var addr linux.SockAddrNetlink + binary.Unmarshal(data[:syscall.SizeofSockaddrNetlink], usermem.ByteOrder, &addr) + return &addr + default: + panic(fmt.Sprintf("Unsupported socket family %v", family)) + } +} diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go index 6b1f8679c..9032b7580 100644 --- a/pkg/sentry/socket/unix/unix.go +++ b/pkg/sentry/socket/unix/unix.go @@ -137,7 +137,7 @@ func extractPath(sockaddr []byte) (string, *syserr.Error) { // GetPeerName implements the linux syscall getpeername(2) for sockets backed by // a transport.Endpoint. -func (s *SocketOperations) GetPeerName(t *kernel.Task) (interface{}, uint32, *syserr.Error) { +func (s *SocketOperations) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { addr, err := s.ep.GetRemoteAddress() if err != nil { return nil, 0, syserr.TranslateNetstackError(err) @@ -149,7 +149,7 @@ func (s *SocketOperations) GetPeerName(t *kernel.Task) (interface{}, uint32, *sy // GetSockName implements the linux syscall getsockname(2) for sockets backed by // a transport.Endpoint. -func (s *SocketOperations) GetSockName(t *kernel.Task) (interface{}, uint32, *syserr.Error) { +func (s *SocketOperations) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { addr, err := s.ep.GetLocalAddress() if err != nil { return nil, 0, syserr.TranslateNetstackError(err) @@ -199,7 +199,7 @@ func (s *SocketOperations) blockingAccept(t *kernel.Task) (transport.Endpoint, * // Accept implements the linux syscall accept(2) for sockets backed by // a transport.Endpoint. -func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, interface{}, uint32, *syserr.Error) { +func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) { // Issue the accept request to get the new endpoint. ep, err := s.ep.Accept() if err != nil { @@ -223,7 +223,7 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, ns.SetFlags(flags.Settable()) } - var addr interface{} + var addr linux.SockAddr var addrLen uint32 if peerRequested { // Get address of the peer. @@ -505,7 +505,7 @@ func (s *SocketOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOS // RecvMsg implements the linux syscall recvmsg(2) for sockets backed by // a transport.Endpoint. -func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr interface{}, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) { +func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr linux.SockAddr, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) { trunc := flags&linux.MSG_TRUNC != 0 peek := flags&linux.MSG_PEEK != 0 dontWait := flags&linux.MSG_DONTWAIT != 0 @@ -543,7 +543,7 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags } var total int64 if n, err := dst.CopyOutFrom(t, &r); err != syserror.ErrWouldBlock || dontWait { - var from interface{} + var from linux.SockAddr var fromLen uint32 if r.From != nil && len([]byte(r.From.Addr)) != 0 { from, fromLen = epsocket.ConvertAddress(linux.AF_UNIX, *r.From) @@ -578,7 +578,7 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags for { if n, err := dst.CopyOutFrom(t, &r); err != syserror.ErrWouldBlock { - var from interface{} + var from linux.SockAddr var fromLen uint32 if r.From != nil { from, fromLen = epsocket.ConvertAddress(linux.AF_UNIX, *r.From) -- cgit v1.2.3