summaryrefslogtreecommitdiffhomepage
path: root/pkg/sentry/socket/rpcinet/socket.go
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/sentry/socket/rpcinet/socket.go')
-rw-r--r--pkg/sentry/socket/rpcinet/socket.go21
1 files changed, 12 insertions, 9 deletions
diff --git a/pkg/sentry/socket/rpcinet/socket.go b/pkg/sentry/socket/rpcinet/socket.go
index 0e58819bc..ddb76d9d4 100644
--- a/pkg/sentry/socket/rpcinet/socket.go
+++ b/pkg/sentry/socket/rpcinet/socket.go
@@ -285,7 +285,7 @@ func rpcAccept(t *kernel.Task, fd uint32, peer bool) (*pb.AcceptResponse_ResultP
}
// Accept implements socket.Socket.Accept.
-func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, interface{}, uint32, *syserr.Error) {
+func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) {
payload, se := rpcAccept(t, s.fd, peerRequested)
// Check if we need to block.
@@ -328,6 +328,9 @@ func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int,
NonBlocking: flags&linux.SOCK_NONBLOCK != 0,
}
file := fs.NewFile(t, dirent, fileFlags, &socketOperations{
+ family: s.family,
+ stype: s.stype,
+ protocol: s.protocol,
wq: &wq,
fd: payload.Fd,
rpcConn: s.rpcConn,
@@ -344,7 +347,7 @@ func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int,
t.Kernel().RecordSocket(file)
if peerRequested {
- return fd, payload.Address.Address, payload.Address.Length, nil
+ return fd, socket.UnmarshalSockAddr(s.family, payload.Address.Address), payload.Address.Length, nil
}
return fd, nil, 0, nil
@@ -469,7 +472,7 @@ func (s *socketOperations) SetSockOpt(t *kernel.Task, level int, name int, opt [
}
// GetPeerName implements socket.Socket.GetPeerName.
-func (s *socketOperations) GetPeerName(t *kernel.Task) (interface{}, uint32, *syserr.Error) {
+func (s *socketOperations) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
stack := t.NetworkContext().(*Stack)
id, c := stack.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_GetPeerName{&pb.GetPeerNameRequest{Fd: s.fd}}}, false /* ignoreResult */)
<-c
@@ -480,11 +483,11 @@ func (s *socketOperations) GetPeerName(t *kernel.Task) (interface{}, uint32, *sy
}
addr := res.(*pb.GetPeerNameResponse_Address).Address
- return addr.Address, addr.Length, nil
+ return socket.UnmarshalSockAddr(s.family, addr.Address), addr.Length, nil
}
// GetSockName implements socket.Socket.GetSockName.
-func (s *socketOperations) GetSockName(t *kernel.Task) (interface{}, uint32, *syserr.Error) {
+func (s *socketOperations) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
stack := t.NetworkContext().(*Stack)
id, c := stack.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_GetSockName{&pb.GetSockNameRequest{Fd: s.fd}}}, false /* ignoreResult */)
<-c
@@ -495,7 +498,7 @@ func (s *socketOperations) GetSockName(t *kernel.Task) (interface{}, uint32, *sy
}
addr := res.(*pb.GetSockNameResponse_Address).Address
- return addr.Address, addr.Length, nil
+ return socket.UnmarshalSockAddr(s.family, addr.Address), addr.Length, nil
}
func rpcIoctl(t *kernel.Task, fd, cmd uint32, arg []byte) ([]byte, error) {
@@ -682,7 +685,7 @@ func (s *socketOperations) extractControlMessages(payload *pb.RecvmsgResponse_Re
}
// RecvMsg implements socket.Socket.RecvMsg.
-func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (int, int, interface{}, uint32, socket.ControlMessages, *syserr.Error) {
+func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) {
req := &pb.SyscallRequest_Recvmsg{&pb.RecvmsgRequest{
Fd: s.fd,
Length: uint32(dst.NumBytes()),
@@ -703,7 +706,7 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
}
}
c := s.extractControlMessages(res)
- return int(res.Length), 0, res.Address.GetAddress(), res.Address.GetLength(), c, syserr.FromError(e)
+ return int(res.Length), 0, socket.UnmarshalSockAddr(s.family, res.Address.GetAddress()), res.Address.GetLength(), c, syserr.FromError(e)
}
if err != syserr.ErrWouldBlock && err != syserr.ErrTryAgain || flags&linux.MSG_DONTWAIT != 0 {
return 0, 0, nil, 0, socket.ControlMessages{}, err
@@ -727,7 +730,7 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
}
}
c := s.extractControlMessages(res)
- return int(res.Length), 0, res.Address.GetAddress(), res.Address.GetLength(), c, syserr.FromError(e)
+ return int(res.Length), 0, socket.UnmarshalSockAddr(s.family, res.Address.GetAddress()), res.Address.GetLength(), c, syserr.FromError(e)
}
if err != syserr.ErrWouldBlock && err != syserr.ErrTryAgain {
return 0, 0, nil, 0, socket.ControlMessages{}, err