summaryrefslogtreecommitdiffhomepage
path: root/pkg/sentry/socket/rpcinet/socket.go
diff options
context:
space:
mode:
authorRahat Mahmood <rahat@google.com>2019-08-08 16:49:18 -0700
committergVisor bot <gvisor-bot@google.com>2019-08-08 16:50:33 -0700
commit7bfad8ebb6ce71c0fe90a1e4f5897f62809fa58b (patch)
treea579fecdddb331e141f44bcfb61dfe7bdbcba84d /pkg/sentry/socket/rpcinet/socket.go
parent13a98df49ea1b36cd21c528293b626a6a3639f0b (diff)
Return a well-defined socket address type from socket funtions.
Previously we were representing socket addresses as an interface{}, which allowed any type which could be binary.Marshal()ed to be used as a socket address. This is fine when the address is passed to userspace via the linux ABI, but is problematic when used from within the sentry such as by networking procfs files. PiperOrigin-RevId: 262460640
Diffstat (limited to 'pkg/sentry/socket/rpcinet/socket.go')
-rw-r--r--pkg/sentry/socket/rpcinet/socket.go21
1 files changed, 12 insertions, 9 deletions
diff --git a/pkg/sentry/socket/rpcinet/socket.go b/pkg/sentry/socket/rpcinet/socket.go
index 0e58819bc..ddb76d9d4 100644
--- a/pkg/sentry/socket/rpcinet/socket.go
+++ b/pkg/sentry/socket/rpcinet/socket.go
@@ -285,7 +285,7 @@ func rpcAccept(t *kernel.Task, fd uint32, peer bool) (*pb.AcceptResponse_ResultP
}
// Accept implements socket.Socket.Accept.
-func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, interface{}, uint32, *syserr.Error) {
+func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) {
payload, se := rpcAccept(t, s.fd, peerRequested)
// Check if we need to block.
@@ -328,6 +328,9 @@ func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int,
NonBlocking: flags&linux.SOCK_NONBLOCK != 0,
}
file := fs.NewFile(t, dirent, fileFlags, &socketOperations{
+ family: s.family,
+ stype: s.stype,
+ protocol: s.protocol,
wq: &wq,
fd: payload.Fd,
rpcConn: s.rpcConn,
@@ -344,7 +347,7 @@ func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int,
t.Kernel().RecordSocket(file)
if peerRequested {
- return fd, payload.Address.Address, payload.Address.Length, nil
+ return fd, socket.UnmarshalSockAddr(s.family, payload.Address.Address), payload.Address.Length, nil
}
return fd, nil, 0, nil
@@ -469,7 +472,7 @@ func (s *socketOperations) SetSockOpt(t *kernel.Task, level int, name int, opt [
}
// GetPeerName implements socket.Socket.GetPeerName.
-func (s *socketOperations) GetPeerName(t *kernel.Task) (interface{}, uint32, *syserr.Error) {
+func (s *socketOperations) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
stack := t.NetworkContext().(*Stack)
id, c := stack.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_GetPeerName{&pb.GetPeerNameRequest{Fd: s.fd}}}, false /* ignoreResult */)
<-c
@@ -480,11 +483,11 @@ func (s *socketOperations) GetPeerName(t *kernel.Task) (interface{}, uint32, *sy
}
addr := res.(*pb.GetPeerNameResponse_Address).Address
- return addr.Address, addr.Length, nil
+ return socket.UnmarshalSockAddr(s.family, addr.Address), addr.Length, nil
}
// GetSockName implements socket.Socket.GetSockName.
-func (s *socketOperations) GetSockName(t *kernel.Task) (interface{}, uint32, *syserr.Error) {
+func (s *socketOperations) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
stack := t.NetworkContext().(*Stack)
id, c := stack.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_GetSockName{&pb.GetSockNameRequest{Fd: s.fd}}}, false /* ignoreResult */)
<-c
@@ -495,7 +498,7 @@ func (s *socketOperations) GetSockName(t *kernel.Task) (interface{}, uint32, *sy
}
addr := res.(*pb.GetSockNameResponse_Address).Address
- return addr.Address, addr.Length, nil
+ return socket.UnmarshalSockAddr(s.family, addr.Address), addr.Length, nil
}
func rpcIoctl(t *kernel.Task, fd, cmd uint32, arg []byte) ([]byte, error) {
@@ -682,7 +685,7 @@ func (s *socketOperations) extractControlMessages(payload *pb.RecvmsgResponse_Re
}
// RecvMsg implements socket.Socket.RecvMsg.
-func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (int, int, interface{}, uint32, socket.ControlMessages, *syserr.Error) {
+func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) {
req := &pb.SyscallRequest_Recvmsg{&pb.RecvmsgRequest{
Fd: s.fd,
Length: uint32(dst.NumBytes()),
@@ -703,7 +706,7 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
}
}
c := s.extractControlMessages(res)
- return int(res.Length), 0, res.Address.GetAddress(), res.Address.GetLength(), c, syserr.FromError(e)
+ return int(res.Length), 0, socket.UnmarshalSockAddr(s.family, res.Address.GetAddress()), res.Address.GetLength(), c, syserr.FromError(e)
}
if err != syserr.ErrWouldBlock && err != syserr.ErrTryAgain || flags&linux.MSG_DONTWAIT != 0 {
return 0, 0, nil, 0, socket.ControlMessages{}, err
@@ -727,7 +730,7 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
}
}
c := s.extractControlMessages(res)
- return int(res.Length), 0, res.Address.GetAddress(), res.Address.GetLength(), c, syserr.FromError(e)
+ return int(res.Length), 0, socket.UnmarshalSockAddr(s.family, res.Address.GetAddress()), res.Address.GetLength(), c, syserr.FromError(e)
}
if err != syserr.ErrWouldBlock && err != syserr.ErrTryAgain {
return 0, 0, nil, 0, socket.ControlMessages{}, err