diff options
Diffstat (limited to 'pkg/sentry/socket/rpcinet/socket.go')
-rw-r--r-- | pkg/sentry/socket/rpcinet/socket.go | 21 |
1 files changed, 12 insertions, 9 deletions
diff --git a/pkg/sentry/socket/rpcinet/socket.go b/pkg/sentry/socket/rpcinet/socket.go index 0e58819bc..ddb76d9d4 100644 --- a/pkg/sentry/socket/rpcinet/socket.go +++ b/pkg/sentry/socket/rpcinet/socket.go @@ -285,7 +285,7 @@ func rpcAccept(t *kernel.Task, fd uint32, peer bool) (*pb.AcceptResponse_ResultP } // Accept implements socket.Socket.Accept. -func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, interface{}, uint32, *syserr.Error) { +func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) { payload, se := rpcAccept(t, s.fd, peerRequested) // Check if we need to block. @@ -328,6 +328,9 @@ func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, NonBlocking: flags&linux.SOCK_NONBLOCK != 0, } file := fs.NewFile(t, dirent, fileFlags, &socketOperations{ + family: s.family, + stype: s.stype, + protocol: s.protocol, wq: &wq, fd: payload.Fd, rpcConn: s.rpcConn, @@ -344,7 +347,7 @@ func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, t.Kernel().RecordSocket(file) if peerRequested { - return fd, payload.Address.Address, payload.Address.Length, nil + return fd, socket.UnmarshalSockAddr(s.family, payload.Address.Address), payload.Address.Length, nil } return fd, nil, 0, nil @@ -469,7 +472,7 @@ func (s *socketOperations) SetSockOpt(t *kernel.Task, level int, name int, opt [ } // GetPeerName implements socket.Socket.GetPeerName. -func (s *socketOperations) GetPeerName(t *kernel.Task) (interface{}, uint32, *syserr.Error) { +func (s *socketOperations) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { stack := t.NetworkContext().(*Stack) id, c := stack.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_GetPeerName{&pb.GetPeerNameRequest{Fd: s.fd}}}, false /* ignoreResult */) <-c @@ -480,11 +483,11 @@ func (s *socketOperations) GetPeerName(t *kernel.Task) (interface{}, uint32, *sy } addr := res.(*pb.GetPeerNameResponse_Address).Address - return addr.Address, addr.Length, nil + return socket.UnmarshalSockAddr(s.family, addr.Address), addr.Length, nil } // GetSockName implements socket.Socket.GetSockName. -func (s *socketOperations) GetSockName(t *kernel.Task) (interface{}, uint32, *syserr.Error) { +func (s *socketOperations) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { stack := t.NetworkContext().(*Stack) id, c := stack.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_GetSockName{&pb.GetSockNameRequest{Fd: s.fd}}}, false /* ignoreResult */) <-c @@ -495,7 +498,7 @@ func (s *socketOperations) GetSockName(t *kernel.Task) (interface{}, uint32, *sy } addr := res.(*pb.GetSockNameResponse_Address).Address - return addr.Address, addr.Length, nil + return socket.UnmarshalSockAddr(s.family, addr.Address), addr.Length, nil } func rpcIoctl(t *kernel.Task, fd, cmd uint32, arg []byte) ([]byte, error) { @@ -682,7 +685,7 @@ func (s *socketOperations) extractControlMessages(payload *pb.RecvmsgResponse_Re } // RecvMsg implements socket.Socket.RecvMsg. -func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (int, int, interface{}, uint32, socket.ControlMessages, *syserr.Error) { +func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) { req := &pb.SyscallRequest_Recvmsg{&pb.RecvmsgRequest{ Fd: s.fd, Length: uint32(dst.NumBytes()), @@ -703,7 +706,7 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags } } c := s.extractControlMessages(res) - return int(res.Length), 0, res.Address.GetAddress(), res.Address.GetLength(), c, syserr.FromError(e) + return int(res.Length), 0, socket.UnmarshalSockAddr(s.family, res.Address.GetAddress()), res.Address.GetLength(), c, syserr.FromError(e) } if err != syserr.ErrWouldBlock && err != syserr.ErrTryAgain || flags&linux.MSG_DONTWAIT != 0 { return 0, 0, nil, 0, socket.ControlMessages{}, err @@ -727,7 +730,7 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags } } c := s.extractControlMessages(res) - return int(res.Length), 0, res.Address.GetAddress(), res.Address.GetLength(), c, syserr.FromError(e) + return int(res.Length), 0, socket.UnmarshalSockAddr(s.family, res.Address.GetAddress()), res.Address.GetLength(), c, syserr.FromError(e) } if err != syserr.ErrWouldBlock && err != syserr.ErrTryAgain { return 0, 0, nil, 0, socket.ControlMessages{}, err |