From 7bfad8ebb6ce71c0fe90a1e4f5897f62809fa58b Mon Sep 17 00:00:00 2001 From: Rahat Mahmood Date: Thu, 8 Aug 2019 16:49:18 -0700 Subject: Return a well-defined socket address type from socket funtions. Previously we were representing socket addresses as an interface{}, which allowed any type which could be binary.Marshal()ed to be used as a socket address. This is fine when the address is passed to userspace via the linux ABI, but is problematic when used from within the sentry such as by networking procfs files. PiperOrigin-RevId: 262460640 --- pkg/sentry/socket/epsocket/epsocket.go | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) (limited to 'pkg/sentry/socket/epsocket/epsocket.go') diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/epsocket/epsocket.go index 1a4442959..8cb5c823f 100644 --- a/pkg/sentry/socket/epsocket/epsocket.go +++ b/pkg/sentry/socket/epsocket/epsocket.go @@ -548,7 +548,7 @@ func (s *SocketOperations) blockingAccept(t *kernel.Task) (tcpip.Endpoint, *wait // Accept implements the linux syscall accept(2) for sockets backed by // tcpip.Endpoint. -func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, interface{}, uint32, *syserr.Error) { +func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) { // Issue the accept request to get the new endpoint. ep, wq, terr := s.Endpoint.Accept() if terr != nil { @@ -575,7 +575,7 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, ns.SetFlags(flags.Settable()) } - var addr interface{} + var addr linux.SockAddr var addrLen uint32 if peerRequested { // Get address of the peer and write it to peer slice. @@ -1056,7 +1056,7 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfac a, _ := ConvertAddress(linux.AF_INET, tcpip.FullAddress{Addr: v.InterfaceAddr}) - return a.(linux.SockAddrInet).Addr, nil + return a.(*linux.SockAddrInet).Addr, nil case linux.IP_MULTICAST_LOOP: if outLen < sizeOfInt32 { @@ -1686,7 +1686,7 @@ func isLinkLocal(addr tcpip.Address) bool { } // ConvertAddress converts the given address to a native format. -func ConvertAddress(family int, addr tcpip.FullAddress) (interface{}, uint32) { +func ConvertAddress(family int, addr tcpip.FullAddress) (linux.SockAddr, uint32) { switch family { case linux.AF_UNIX: var out linux.SockAddrUnix @@ -1702,15 +1702,15 @@ func ConvertAddress(family int, addr tcpip.FullAddress) (interface{}, uint32) { // address length is the max. Abstract and empty paths always return // the full exact length. if l == 0 || out.Path[0] == 0 || l == len(out.Path) { - return out, uint32(2 + l) + return &out, uint32(2 + l) } - return out, uint32(3 + l) + return &out, uint32(3 + l) case linux.AF_INET: var out linux.SockAddrInet copy(out.Addr[:], addr.Addr) out.Family = linux.AF_INET out.Port = htons(addr.Port) - return out, uint32(binary.Size(out)) + return &out, uint32(binary.Size(out)) case linux.AF_INET6: var out linux.SockAddrInet6 if len(addr.Addr) == 4 { @@ -1726,7 +1726,7 @@ func ConvertAddress(family int, addr tcpip.FullAddress) (interface{}, uint32) { if isLinkLocal(addr.Addr) { out.Scope_id = uint32(addr.NIC) } - return out, uint32(binary.Size(out)) + return &out, uint32(binary.Size(out)) default: return nil, 0 } @@ -1734,7 +1734,7 @@ func ConvertAddress(family int, addr tcpip.FullAddress) (interface{}, uint32) { // GetSockName implements the linux syscall getsockname(2) for sockets backed by // tcpip.Endpoint. -func (s *SocketOperations) GetSockName(t *kernel.Task) (interface{}, uint32, *syserr.Error) { +func (s *SocketOperations) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { addr, err := s.Endpoint.GetLocalAddress() if err != nil { return nil, 0, syserr.TranslateNetstackError(err) @@ -1746,7 +1746,7 @@ func (s *SocketOperations) GetSockName(t *kernel.Task) (interface{}, uint32, *sy // GetPeerName implements the linux syscall getpeername(2) for sockets backed by // tcpip.Endpoint. -func (s *SocketOperations) GetPeerName(t *kernel.Task) (interface{}, uint32, *syserr.Error) { +func (s *SocketOperations) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { addr, err := s.Endpoint.GetRemoteAddress() if err != nil { return nil, 0, syserr.TranslateNetstackError(err) @@ -1819,7 +1819,7 @@ func (s *SocketOperations) fillCmsgInq(cmsg *socket.ControlMessages) { // nonBlockingRead issues a non-blocking read. // // TODO(b/78348848): Support timestamps for stream sockets. -func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSequence, peek, trunc, senderRequested bool) (int, int, interface{}, uint32, socket.ControlMessages, *syserr.Error) { +func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSequence, peek, trunc, senderRequested bool) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) { isPacket := s.isPacketBased() // Fast path for regular reads from stream (e.g., TCP) endpoints. Note @@ -1867,7 +1867,7 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe if err == nil { s.updateTimestamp() } - var addr interface{} + var addr linux.SockAddr var addrLen uint32 if isPacket && senderRequested { addr, addrLen = ConvertAddress(s.family, s.sender) @@ -1942,7 +1942,7 @@ func (s *SocketOperations) updateTimestamp() { // RecvMsg implements the linux syscall recvmsg(2) for sockets backed by // tcpip.Endpoint. -func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr interface{}, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) { +func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr linux.SockAddr, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) { trunc := flags&linux.MSG_TRUNC != 0 peek := flags&linux.MSG_PEEK != 0 dontWait := flags&linux.MSG_DONTWAIT != 0 -- cgit v1.2.3