From 7bfad8ebb6ce71c0fe90a1e4f5897f62809fa58b Mon Sep 17 00:00:00 2001
From: Rahat Mahmood <rahat@google.com>
Date: Thu, 8 Aug 2019 16:49:18 -0700
Subject: Return a well-defined socket address type from socket funtions.

Previously we were representing socket addresses as an interface{},
which allowed any type which could be binary.Marshal()ed to be used as
a socket address. This is fine when the address is passed to userspace
via the linux ABI, but is problematic when used from within the sentry
such as by networking procfs files.

PiperOrigin-RevId: 262460640
---
 pkg/sentry/socket/unix/unix.go | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

(limited to 'pkg/sentry/socket/unix/unix.go')

diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go
index 6b1f8679c..9032b7580 100644
--- a/pkg/sentry/socket/unix/unix.go
+++ b/pkg/sentry/socket/unix/unix.go
@@ -137,7 +137,7 @@ func extractPath(sockaddr []byte) (string, *syserr.Error) {
 
 // GetPeerName implements the linux syscall getpeername(2) for sockets backed by
 // a transport.Endpoint.
-func (s *SocketOperations) GetPeerName(t *kernel.Task) (interface{}, uint32, *syserr.Error) {
+func (s *SocketOperations) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
 	addr, err := s.ep.GetRemoteAddress()
 	if err != nil {
 		return nil, 0, syserr.TranslateNetstackError(err)
@@ -149,7 +149,7 @@ func (s *SocketOperations) GetPeerName(t *kernel.Task) (interface{}, uint32, *sy
 
 // GetSockName implements the linux syscall getsockname(2) for sockets backed by
 // a transport.Endpoint.
-func (s *SocketOperations) GetSockName(t *kernel.Task) (interface{}, uint32, *syserr.Error) {
+func (s *SocketOperations) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
 	addr, err := s.ep.GetLocalAddress()
 	if err != nil {
 		return nil, 0, syserr.TranslateNetstackError(err)
@@ -199,7 +199,7 @@ func (s *SocketOperations) blockingAccept(t *kernel.Task) (transport.Endpoint, *
 
 // Accept implements the linux syscall accept(2) for sockets backed by
 // a transport.Endpoint.
-func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, interface{}, uint32, *syserr.Error) {
+func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) {
 	// Issue the accept request to get the new endpoint.
 	ep, err := s.ep.Accept()
 	if err != nil {
@@ -223,7 +223,7 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int,
 		ns.SetFlags(flags.Settable())
 	}
 
-	var addr interface{}
+	var addr linux.SockAddr
 	var addrLen uint32
 	if peerRequested {
 		// Get address of the peer.
@@ -505,7 +505,7 @@ func (s *SocketOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOS
 
 // RecvMsg implements the linux syscall recvmsg(2) for sockets backed by
 // a transport.Endpoint.
-func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr interface{}, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) {
+func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr linux.SockAddr, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) {
 	trunc := flags&linux.MSG_TRUNC != 0
 	peek := flags&linux.MSG_PEEK != 0
 	dontWait := flags&linux.MSG_DONTWAIT != 0
@@ -543,7 +543,7 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
 	}
 	var total int64
 	if n, err := dst.CopyOutFrom(t, &r); err != syserror.ErrWouldBlock || dontWait {
-		var from interface{}
+		var from linux.SockAddr
 		var fromLen uint32
 		if r.From != nil && len([]byte(r.From.Addr)) != 0 {
 			from, fromLen = epsocket.ConvertAddress(linux.AF_UNIX, *r.From)
@@ -578,7 +578,7 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
 
 	for {
 		if n, err := dst.CopyOutFrom(t, &r); err != syserror.ErrWouldBlock {
-			var from interface{}
+			var from linux.SockAddr
 			var fromLen uint32
 			if r.From != nil {
 				from, fromLen = epsocket.ConvertAddress(linux.AF_UNIX, *r.From)
-- 
cgit v1.2.3